├── models
│   ├── conch_v1_5_config.py
│   ├── __init__.py
│   ├── unet.py
│   ├── lora.py
│   ├── dora.py
│   ├── utils.py
│   ├── losses.py
│   └── transformer_adapter.py
├── data
│   ├── example.json
│   ├── __init__.py
│   ├── seg_dataset.py
│   ├── transforms.py
│   └── utils.py
├── utils
│   ├── logs.py
│   ├── yaml_utils.py
│   ├── __init__.py
│   ├── metrics.py
│   ├── scheduler.py
│   └── evaluator.py
├── configs
│   ├── unet.yaml
│   ├── test.yaml
│   └── config.yaml
├── scripts
│   └── train.py
└── README.md
--------------------------------------------------------------------------------
/utils/logs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import sys
4 | 
5 | def setup_logging(log_dir: str):
6 |     """Setup logging configuration."""
7 |     os.makedirs(log_dir, exist_ok=True)
8 | 
9 |     level = logging.INFO
10 | 
11 |     logging.basicConfig(
12 |         level=level,
13 |         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
14 |         handlers=[
15 |             logging.FileHandler(os.path.join(log_dir, 'training.log')),
16 |             logging.StreamHandler(sys.stdout)
17 |         ]
18 |     )
--------------------------------------------------------------------------------
/utils/yaml_utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | 
3 | def load_config(config_path):
4 |     with open(config_path, 'r') as f:
5 |         return yaml.safe_load(f)
6 | 
7 | def update_config(config, key, value):
8 |     keys = key.split('.')
9 |     d = config
10 |     for k in keys[:-1]:
11 |         d = d.setdefault(k, {})
12 |     # Try to parse the value so it keeps its YAML type (int, float, bool, ...)
13 |     try:
14 |         d[keys[-1]] = yaml.safe_load(value)
15 |     except yaml.YAMLError:
16 |         d[keys[-1]] = value
17 | 
18 | def load_config_with_options(config_path, options=None):
19 |     config = load_config(config_path)
20 |     if options:
21 |         for opt in options:
22 |             if '=' not in opt:
23 |                 raise ValueError(f"Invalid option: {opt}")
24 |             key, value = opt.split('=', 1)
25 |             update_config(config, key, value)
26 |     return config
27 | 
--------------------------------------------------------------------------------
/models/conch_v1_5_config.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Any
3 | 
4 | from
transformers import PretrainedConfig 5 | 6 | class ConchConfig(PretrainedConfig): 7 | model_type = "conch" 8 | 9 | def __init__( 10 | self, 11 | patch_size: int = 16, 12 | context_dim: int = 1024, 13 | embed_dim: int = 768, 14 | depth: int = 24, 15 | num_heads: int = 16, 16 | mlp_ratio: float = 4.0, 17 | qkv_bias: bool = True, 18 | init_values: float = 1e-6, 19 | pooler_n_queries_contrast: int = 1, 20 | **kwargs: Any, 21 | ): 22 | self.patch_size = patch_size 23 | self.context_dim = context_dim 24 | self.embed_dim = embed_dim 25 | self.depth = depth 26 | self.num_heads = num_heads 27 | self.mlp_ratio = mlp_ratio 28 | self.qkv_bias = qkv_bias 29 | self.init_values = init_values 30 | self.pooler_n_queries_contrast = pooler_n_queries_contrast 31 | 32 | super().__init__(**kwargs) -------------------------------------------------------------------------------- /data/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_classes": 3, 3 | "data": { 4 | "train": [ 5 | { 6 | "image_path": "/images/train/patient001_slide01.jpg", 7 | "mask_path": "/masks/train/patient001_slide01.png" 8 | }, 9 | { 10 | "image_path": "/images/train/patient001_slide02.jpg", 11 | "mask_path": "masks/train/patient001_slide02.png" 12 | }, 13 | { 14 | "image_path": "/images/train/patient002_slide01.jpg", 15 | "mask_path": "/masks/train/patient002_slide01.png" 16 | } 17 | ], 18 | "val": [ 19 | { 20 | "image_path": "/images/val/patient003_slide01.jpg", 21 | "mask_path": "/masks/val/patient003_slide01.png" 22 | }, 23 | { 24 | "image_path": "/images/val/patient003_slide02.jpg", 25 | "mask_path": "/masks/val/patient003_slide02.png" 26 | } 27 | ], 28 | "test": [ 29 | { 30 | "image_path": "/images/val/patient004_slide01.jpg", 31 | "mask_path": "/masks/val/patient004_slide01.png" 32 | }, 33 | { 34 | "image_path": "/images/val/patient004_slide02.jpg", 35 | "mask_path": "/masks/val/patient004_slide02.png" 36 | } 37 | ] 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models package for semantic segmentation. 3 | 4 | This package contains model definitions, loss functions, and utilities 5 | for semantic segmentation tasks. 
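Example (illustrative; only the UNet factory and the parameter counter that are
defined in this package are used):

    from models import create_unet_model, count_parameters

    model = create_unet_model({'num_classes': 2})  # n_channels / bilinear fall back to defaults
    print(count_parameters(model))                 # parameter counts reported in millions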
6 | """ 7 | 8 | from .pfm_seg_models import create_pfm_segmentation_model, create_segmentation_model 9 | from .unet import create_unet_model, UNet 10 | from .lora import equip_model_with_lora 11 | from .cnn_adapter import equip_model_with_cnn_adapter 12 | from .transformer_adapter import equip_model_with_transformer_adapter 13 | from .losses import ( 14 | CrossEntropyLoss, DiceLoss, IoULoss, OHEMLoss,get_loss_function 15 | ) 16 | from .utils import ( 17 | count_parameters, initialize_weights, 18 | save_checkpoint, load_checkpoint, 19 | get_model_complexity, convert_to_onnx, print_model_summary 20 | ) 21 | 22 | __all__ = [ 23 | # Models 24 | 'create_pfm_segmentation_model', 25 | 'create_segmentation_model', 26 | 'create_unet_model', 27 | 'UNet', 28 | 'equip_model_with_lora', 29 | 'equip_model_with_cnn_adapter', 30 | 'equip_model_with_transformer_adapter', 31 | 32 | # Loss functions 33 | 'CrossEntropyLoss', 34 | 'DiceLoss', 35 | 'IoULoss', 36 | 'OHEMLoss', 37 | 'get_loss_function', 38 | 39 | # Utilities 40 | 'count_parameters', 41 | 'initialize_weights', 42 | 'save_checkpoint', 43 | 'load_checkpoint', 44 | 'get_model_complexity', 45 | 'convert_to_onnx', 46 | 'print_model_summary' 47 | ] 48 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data package for semantic segmentation. 3 | 4 | This package contains dataset classes, data transforms, and utilities 5 | for loading and preprocessing segmentation data. 6 | """ 7 | 8 | from data.seg_dataset import ( 9 | JSONSegmentationDataset, 10 | ) 11 | 12 | from .transforms import ( 13 | SegmentationTransforms, parse_transform_config, get_transforms, 14 | MixUp, CutMix, Mosaic, AdvancedAugmentationPipeline 15 | ) 16 | from .utils import ( 17 | create_dataloader, segmentation_collate_fn, compute_class_distribution, 18 | visualize_class_distribution, visualize_sample, create_color_map, 19 | analyze_dataset_quality, save_dataset_info, create_data_split 20 | ) 21 | 22 | __all__ = [ 23 | # Datasets 24 | 'BaseSegmentationDataset', 25 | 'CityscapesDataset', 26 | 'ADE20KDataset', 27 | 'PascalVOCDataset', 28 | 'CustomDataset', 29 | 'get_dataset', 30 | 'DatasetStatistics', 31 | 32 | # Transforms 33 | 'SegmentationTransforms', 34 | 'parse_transform_config', 35 | 'get_transforms', 36 | 'MixUp', 37 | 'CutMix', 38 | 'Mosaic', 39 | 'AdvancedAugmentationPipeline', 40 | 41 | # Utils 42 | 'create_dataloader', 43 | 'segmentation_collate_fn', 44 | 'compute_class_distribution', 45 | 'visualize_class_distribution', 46 | 'visualize_sample', 47 | 'create_color_map', 48 | 'analyze_dataset_quality', 49 | 'save_dataset_info', 50 | 'create_data_split' 51 | ] 52 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils package for semantic segmentation. 3 | 4 | This package contains training utilities, evaluation metrics, visualization tools, 5 | and other helper functions for semantic segmentation. 
6 | """ 7 | 8 | from .trainer import SegmentationTrainer 9 | from .logs import setup_logging 10 | from .evaluator import SegmentationEvaluator 11 | from .metrics import SegmentationMetrics, StreamingMetrics 12 | from .scheduler import ( 13 | CosineAnnealingWithWarmup, PolynomialLR, WarmupMultiStepLR, 14 | OneCycleLR, CyclicLR, get_scheduler, WarmupScheduler 15 | ) 16 | from .visualization import ( 17 | create_color_palette, tensor_to_image, 18 | apply_color_map, visualize_prediction, save_predictions, 19 | plot_training_history, plot_confusion_matrix, plot_class_metrics, 20 | create_interactive_training_dashboard, visualize_feature_maps 21 | ) 22 | 23 | __all__ = [ 24 | # Training 25 | 'SegmentationTrainer', 26 | 27 | # Evaluation 28 | 'SegmentationEvaluator', 29 | 30 | # Metrics 31 | 'SegmentationMetrics', 32 | 'StreamingMetrics', 33 | 34 | # Schedulers 35 | 'CosineAnnealingWithWarmup', 36 | 'PolynomialLR', 37 | 'WarmupMultiStepLR', 38 | 'OneCycleLR', 39 | 'CyclicLR', 40 | 'get_scheduler', 41 | 'WarmupScheduler', 42 | 43 | # Visualization 44 | 'setup_matplotlib_for_plotting', 45 | 'create_color_palette', 46 | 'tensor_to_image', 47 | 'apply_color_map', 48 | 'visualize_prediction', 49 | 'save_predictions', 50 | 'plot_training_history', 51 | 'plot_confusion_matrix', 52 | 'plot_class_metrics', 53 | 'create_interactive_training_dashboard', 54 | 'visualize_feature_maps' 55 | 56 | # Utility functions 57 | 'setup_logging' 58 | ] 59 | -------------------------------------------------------------------------------- /configs/unet.yaml: -------------------------------------------------------------------------------- 1 | # UNet Model Configuration Example 2 | # This is an example configuration file for training UNet model 3 | 4 | # Dataset configuration 5 | dataset: 6 | json_file: "/mnt/sdb/chenwm/PFM_Segmentation/dataset_json/BCSS.json" # Path to the JSON configuration file 7 | num_classes: 22 # Number of classes; must match num_classes in the JSON file 8 | ignore_index: 255 # Index value to ignore during training 9 | 10 | system: 11 | num_workers: 2 # Number of worker threads for data loading 12 | pin_memory: true # Whether to use pin_memory to accelerate data loading 13 | seed: 42 14 | device: "cuda:7" # Device to use, 'cuda' or 'cpu' 15 | 16 | # Model configuration for UNet 17 | model: 18 | pfm_name: "unet" # Use "unet" to specify UNet model 19 | # Alternatively, you can use: 20 | # model_type: "unet" # This also works 21 | 22 | # UNet specific parameters (optional) 23 | n_channels: 3 # Number of input channels (default: 3 for RGB) 24 | bilinear: true # Whether to use bilinear upsampling (default: true) 25 | 26 | # Required for UNet 27 | num_classes: 22 # Must match num_classes in the JSON file 28 | 29 | # Note: For UNet, the following PFM-specific parameters are NOT required: 30 | # - pfm_weights_path (not needed for UNet) 31 | # - emb_dim (not needed for UNet) 32 | # - finetune_mode (not needed for UNet) 33 | 34 | # Training configuration 35 | training: 36 | batch_size: 4 37 | epochs: 150 38 | learning_rate: 0.001 39 | weight_decay: 0.0001 40 | use_amp: true # Whether to use automatic mixed precision (AMP) for training 41 | accumulate_grad_batches: 1 # Number of batches to accumulate gradients over before performing an optimizer step 42 | clip_grad_norm: 5.0 # Gradient clipping value to prevent exploding gradients 43 | 44 | # Data augmentation 45 | augmentation: 46 | RandomResizedCropSize: 944 # Input image size for training (can be any size for UNet) 47 | 48 | # Optimizer settings 49 | 
optimizer: 50 | type: "Adam" # Options: SGD, Adam, AdamW 51 | 52 | # Learning rate scheduler 53 | scheduler: 54 | type: "cosine" # Options: cosine, step 55 | warmup_epochs: 2 56 | 57 | # Loss function 58 | loss: 59 | type: "cross_entropy" # Options: cross_entropy, dice, ohem, iou, bce_with_logits 60 | 61 | # Validation configuration 62 | validation: 63 | eval_interval: 1 # Validate every N epochs 64 | batch_size: 4 65 | augmentation: 66 | ResizedSize: 944 # Input image size for validation (can be any size for UNet) 67 | 68 | logging: 69 | log_dir: "/mnt/sdb/chenwm/PFM_Segmentation/logs_unet" 70 | experiment_name: "unet_test_BCSS_2" 71 | 72 | visualization: 73 | save_interval: 10 # Save visualization results every N epochs 74 | num_vis_samples: 8 # Number of samples to visualize 75 | 76 | -------------------------------------------------------------------------------- /configs/test.yaml: -------------------------------------------------------------------------------- 1 | # Dataset configuration 2 | dataset: 3 | json_file: "/mnt/sdb/chenwm/PFM_Segmentation/dataset_json/GlaS.json" # Path to the JSON configuration file 4 | num_classes: 2 # Number of classes; must match num_classes in the JSON file 5 | ignore_index: 255 # Index value to ignore during training 6 | 7 | system: 8 | num_workers: 2 # Number of worker threads for data loading 9 | pin_memory: true # Whether to use pin_memory to accelerate data loading 10 | seed: 42 11 | device: "cuda:5" # Device to use, 'cuda' or 'cpu' 12 | 13 | # Model configuration 14 | model: 15 | # Options: uni_v1, uni_v2, virchow_v1, virchow_v2, conch_v1_5, conch_v1, midnight12k, lunit_vits8, musk,PathOrchestra 16 | # Options: gigapath, phikon , patho3dmatrix-vision, phikon_v2, hoptimus_0, hoptimus_1, kaiko-vitl14, hibou_l 17 | pfm_name: "uni_v1" 18 | 19 | # midnight12k/hoptimus_0/hoptimus_1/uni_v2/gigapath: 1536 20 | # virchow_v1/virchow_v2: 1280 21 | # uni_v1/hibou_l/musk/phikon_v2/kaiko-vitl14/conch_v1_5/patho3dmatrix-vision/PathOrchestra: 1024 22 | # conch_v1/phikon: 768 23 | # lunit_vits8: 384 24 | emb_dim: 1024 25 | finetune_mode: 26 | type: dora # Options: lora, full, frozen, dora, cnn_adapter, transformer_adapter 27 | rank: 16 # only used when finetune_mode.type is lora or dora 28 | alpha: 1.0 # only used when finetune_mode.type is lora or dora 29 | transformer_adapter: 30 | depth: 2 # Vision Blocks 的数量 (默认: 2) 31 | mlp_ratio: 4.0 # MLP 隐藏层维度与嵌入维度的比率 (默认: 4.0) 32 | drop_rate: 0.0 # Dropout 率 (默认: 0.0) 33 | attn_drop_rate: 0.0 # Attention Dropout 率 (默认: 0.0) 34 | drop_path_rate: 0.1 # 随机深度率 (默认: 0.1) 35 | init_values: 1.0e-5 # LayerScale 初始值 (默认: 1e-5) 36 | qk_norm: false # 是否使用 QK 归一化 (默认: false) 37 | # CNN Adapter settings (only used when finetune_mode.type is cnn_adapter) 38 | # Reference: TransUNet ResNetV2 architecture 39 | cnn_adapter: 40 | width_factor: 1 # Width multiplier for CNN adapter (1 = 64 base channels) 41 | block_units: [3, 4, 9] # Number of bottleneck units in each block 42 | #Musk and hibou_l can only be loaded from huggingface.co,so we don't need to specify the path 43 | pfm_weights_path: '/mnt/sdb/chenwm/PFM_Segmentation/weight/UNI/pytorch_model.bin' # Path to the PFM model weights 44 | num_classes: 2 # Must match num_classes in the JSON file 45 | 46 | # Training configuration 47 | training: 48 | batch_size: 8 49 | epochs: 1 50 | learning_rate: 0.001 51 | weight_decay: 0.0001 52 | use_amp: true # Whether to use automatic mixed precision (AMP) for training 53 | accumulate_grad_batches: 1 # Number of batches to accumulate gradients over 
before performing an optimizer step 54 | clip_grad_norm: 5.0 # Gradient clipping value to prevent exploding gradients 55 | 56 | # Data augmentation 57 | augmentation: 58 | # virchow_v1,virchow_v2,uni_v2,midnight12k,kaiko-vitl14,hibou_l,hoptimus_0,hoptimus_1,: must be a multiple of 14 (token_size) 59 | # uni_v1,conch_v1_5,gigapath,conch_v1 ,phikon_v2,patho3dmatrix-vision,PathOrchestra,phikon,phikon_v2: must be a multiple of 16 (token_size) 60 | # special: musk: 384 61 | RandomResizedCropSize: 416 62 | 63 | # Optimizer settings 64 | optimizer: 65 | type: "Adam" # Options: SGD, Adam, AdamW 66 | 67 | # Learning rate scheduler 68 | scheduler: 69 | type: "cosine" # Options: cosine, step 70 | warmup_epochs: 2 71 | 72 | # Loss function 73 | loss: 74 | type: "dice" # Options: cross_entropy, dice, ohem, iou 75 | 76 | # Validation configuration 77 | validation: 78 | eval_interval: 1 # Validate every N epochs 79 | batch_size: 8 80 | augmentation: 81 | # virchow_v1,virchow_v2,uni_v2,midnight12k,kaiko-vitl14,hibou_l,hoptimus_0,hoptimus_1,: must be a multiple of 14 (token_size) 82 | # uni_v1,conch_v1_5,gigapath,conch_v1 ,phikon_v2,patho3dmatrix-vision,PathOrchestra,phikon,phikon_v2: must be a multiple of 16 (token_size) 83 | # special: musk: 384 84 | ResizedSize: 416 85 | 86 | logging: 87 | log_dir: "/mnt/sdb/chenwm/PFM_Segmentation/logs" 88 | experiment_name: "test_12-07" 89 | 90 | visualization: 91 | save_interval: 10 # Save visualization results every N epochs 92 | num_vis_samples: 8 # Number of samples to visualize -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # Dataset configuration 2 | dataset: 3 | json_file: "/mnt/sdb/chenwm/PFM_Segmentation/dataset_json/TNBC.json" # Path to the JSON configuration file 4 | num_classes: 2 # Number of classes; must match num_classes in the JSON file 5 | ignore_index: 255 # Index value to ignore during training 6 | 7 | system: 8 | num_workers: 2 # Number of worker threads for data loading 9 | pin_memory: true # Whether to use pin_memory to accelerate data loading 10 | seed: 42 11 | device: "cuda:0" # Device to use, 'cuda' or 'cpu' 12 | 13 | # Model configuration 14 | model: 15 | # Options: uni_v1, uni_v2, virchow_v1, virchow_v2, conch_v1_5, conch_v1, midnight12k, lunit_vits8, musk,PathOrchestra 16 | # Options: gigapath, phikon , patho3dmatrix-vision, phikon_v2, hoptimus_0, hoptimus_1, kaiko-vitl14, hibou_l 17 | pfm_name: "conch_v1" 18 | 19 | # midnight12k/hoptimus_0/hoptimus_1/uni_v2/gigapath: 1536 20 | # virchow_v1/virchow_v2: 1280 21 | # uni_v1/hibou_l/musk/phikon_v2/kaiko-vitl14/patho3dmatrix-vision/PathOrchestra/conch_v1_5: 1024 22 | # conch_v1/phikon: 768 23 | # lunit_vits8: 384 24 | emb_dim: 768 25 | finetune_mode: 26 | type: frozen # Options: lora, full, frozen, dora, cnn_adapter, transformer_adapter 27 | rank: 16 # only used when finetune_mode.type is lora or dora 28 | alpha: 1.0 # only used when finetune_mode.type is lora or dora 29 | transformer_adapter: 30 | depth: 2 # Vision Blocks 的数量 (默认: 2) 31 | mlp_ratio: 4.0 # MLP 隐藏层维度与嵌入维度的比率 (默认: 4.0) 32 | drop_rate: 0.0 # Dropout 率 (默认: 0.0) 33 | attn_drop_rate: 0.0 # Attention Dropout 率 (默认: 0.0) 34 | drop_path_rate: 0.1 # 随机深度率 (默认: 0.1) 35 | init_values: 1.0e-5 # LayerScale 初始值 (默认: 1e-5) 36 | qk_norm: false # 是否使用 QK 归一化 (默认: false) 37 | # CNN Adapter settings (only used when finetune_mode.type is cnn_adapter) 38 | # Reference: TransUNet ResNetV2 architecture 39 | 
cnn_adapter: 40 | width_factor: 1 # Width multiplier for CNN adapter (1 = 64 base channels) 41 | block_units: [3, 4, 9] # Number of bottleneck units in each block 42 | #Musk and hibou_l can only be loaded from huggingface.co,so we don't need to specify the path 43 | pfm_weights_path: '/mnt/sdb/chenwm/PFM_Segmentation/weight/conch_v1/pytorch_model.bin' # Path to the PFM model weights 44 | num_classes: 2 # Must match num_classes in the JSON file 45 | 46 | # Training configuration 47 | training: 48 | batch_size: 8 49 | epochs: 150 50 | learning_rate: 0.001 51 | weight_decay: 0.0001 52 | use_amp: true # Whether to use automatic mixed precision (AMP) for training 53 | accumulate_grad_batches: 1 # Number of batches to accumulate gradients over before performing an optimizer step 54 | clip_grad_norm: 5.0 # Gradient clipping value to prevent exploding gradients 55 | 56 | # Data augmentation 57 | augmentation: 58 | # virchow_v1,virchow_v2,uni_v2,midnight12k,kaiko-vitl14,hibou_l,hoptimus_0,hoptimus_1,: must be a multiple of 14 (token_size) 59 | # uni_v1,conch_v1_5,gigapath,conch_v1 ,phikon_v2,patho3dmatrix-vision,PathOrchestra,phikon,phikon_v2: must be a multiple of 16 (token_size) 60 | # special: musk: 384 61 | RandomResizedCropSize: 224 62 | 63 | # Optimizer settings 64 | optimizer: 65 | type: "Adam" # Options: SGD, Adam, AdamW 66 | 67 | # Learning rate scheduler 68 | scheduler: 69 | type: "cosine" # Options: cosine, step 70 | warmup_epochs: 2 71 | 72 | # Loss function 73 | loss: 74 | type: "dice" # Options: cross_entropy, dice, ohem, iou 75 | 76 | # Validation configuration 77 | validation: 78 | eval_interval: 1 # Validate every N epochs 79 | batch_size: 8 80 | augmentation: 81 | # virchow_v1,virchow_v2,uni_v2,midnight12k,kaiko-vitl14,hibou_l: must be a multiple of 14 (token_size) 82 | # uni_v1,conch_v1_5,gigapath,conch_v1,phikon,phikon_v2,patho3dmatrix-vision,PathOrchestra: must be a multiple of 16 (token_size) 83 | # lunit_vits8: must be a multiple of 8 (token_size) 84 | # special: H-optimus-1/H-optimus-0: 224 85 | # specia: musk: 384 86 | ResizedSize: 224 87 | 88 | logging: 89 | log_dir: "/mnt/sdb/chenwm/PFM_Segmentation_Output/logs_frozen" 90 | experiment_name: "test" 91 | 92 | visualization: 93 | save_interval: 10 # Save visualization results every N epochs 94 | num_vis_samples: 8 # Number of samples to visualize -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | UNet Model for Semantic Segmentation 3 | 4 | This module implements a standard UNet architecture for semantic segmentation tasks. 5 | UNet is a popular encoder-decoder architecture with skip connections. 
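Example (illustrative sketch of the forward interface):

    model = UNet(n_channels=3, n_classes=2)
    logits = model(torch.randn(1, 3, 256, 256))['out']  # shape: (1, 2, 256, 256)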
6 | 7 | Author: @Toby 8 | Function: UNet segmentation model 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from typing import Dict, Any, Optional 15 | 16 | 17 | class DoubleConv(nn.Module): 18 | """Double convolution block with batch normalization and ReLU activation.""" 19 | 20 | def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): 21 | super(DoubleConv, self).__init__() 22 | if not mid_channels: 23 | mid_channels = out_channels 24 | self.double_conv = nn.Sequential( 25 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 26 | nn.BatchNorm2d(mid_channels), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 29 | nn.BatchNorm2d(out_channels), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | return self.double_conv(x) 35 | 36 | 37 | class Down(nn.Module): 38 | """Downsampling block: max pooling + double convolution.""" 39 | 40 | def __init__(self, in_channels: int, out_channels: int): 41 | super(Down, self).__init__() 42 | self.maxpool_conv = nn.Sequential( 43 | nn.MaxPool2d(2), 44 | DoubleConv(in_channels, out_channels) 45 | ) 46 | 47 | def forward(self, x: torch.Tensor) -> torch.Tensor: 48 | return self.maxpool_conv(x) 49 | 50 | 51 | class Up(nn.Module): 52 | """Upsampling block: upsampling + concatenation + double convolution.""" 53 | 54 | def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True): 55 | super(Up, self).__init__() 56 | 57 | # if bilinear, use the normal convolutions to reduce the number of channels 58 | if bilinear: 59 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 60 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 61 | else: 62 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 63 | self.conv = DoubleConv(in_channels, out_channels) 64 | 65 | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 66 | x1 = self.up(x1) 67 | # input is CHW 68 | diffY = x2.size()[2] - x1.size()[2] 69 | diffX = x2.size()[3] - x1.size()[3] 70 | 71 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 72 | diffY // 2, diffY - diffY // 2]) 73 | # if you have padding issues, see 74 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3 75 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 76 | x = torch.cat([x2, x1], dim=1) 77 | return self.conv(x) 78 | 79 | 80 | class OutConv(nn.Module): 81 | """Output convolution layer.""" 82 | 83 | def __init__(self, in_channels: int, out_channels: int): 84 | super(OutConv, self).__init__() 85 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | return self.conv(x) 89 | 90 | 91 | class UNet(nn.Module): 92 | """ 93 | UNet model for semantic segmentation. 
94 | 95 | Args: 96 | n_channels (int): Number of input channels (default: 3 for RGB) 97 | n_classes (int): Number of output classes 98 | bilinear (bool): Whether to use bilinear upsampling (default: True) 99 | """ 100 | 101 | def __init__(self, n_channels: int = 3, n_classes: int = 2, bilinear: bool = True): 102 | super(UNet, self).__init__() 103 | self.n_channels = n_channels 104 | self.n_classes = n_classes 105 | self.bilinear = bilinear 106 | 107 | self.inc = DoubleConv(n_channels, 64) 108 | self.down1 = Down(64, 128) 109 | self.down2 = Down(128, 256) 110 | self.down3 = Down(256, 512) 111 | factor = 2 if bilinear else 1 112 | self.down4 = Down(512, 1024 // factor) 113 | self.up1 = Up(1024, 512 // factor, bilinear) 114 | self.up2 = Up(512, 256 // factor, bilinear) 115 | self.up3 = Up(256, 128 // factor, bilinear) 116 | self.up4 = Up(128, 64, bilinear) 117 | self.outc = OutConv(64, n_classes) 118 | 119 | def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: 120 | """ 121 | Forward pass through UNet. 122 | 123 | Args: 124 | x (torch.Tensor): Input tensor of shape (B, C, H, W) 125 | 126 | Returns: 127 | Dict[str, torch.Tensor]: Dictionary containing output predictions with key 'out' 128 | """ 129 | # Handle single channel images by repeating to 3 channels 130 | if x.size(1) == 1: 131 | x = x.repeat(1, 3, 1, 1) 132 | 133 | x1 = self.inc(x) 134 | x2 = self.down1(x1) 135 | x3 = self.down2(x2) 136 | x4 = self.down3(x3) 137 | x5 = self.down4(x4) 138 | x = self.up1(x5, x4) 139 | x = self.up2(x, x3) 140 | x = self.up3(x, x2) 141 | x = self.up4(x, x1) 142 | logits = self.outc(x) 143 | return {'out': logits} 144 | 145 | 146 | def create_unet_model(model_config: Dict[str, Any]) -> UNet: 147 | """ 148 | Factory function to create UNet model. 149 | 150 | Args: 151 | model_config (Dict[str, Any]): Model configuration dictionary 152 | Required keys: 153 | - num_classes: Number of segmentation classes 154 | Optional keys: 155 | - n_channels: Number of input channels (default: 3) 156 | - bilinear: Whether to use bilinear upsampling (default: True) 157 | 158 | Returns: 159 | UNet: Configured UNet model 160 | """ 161 | if 'num_classes' not in model_config: 162 | raise ValueError("Missing required configuration key: num_classes") 163 | 164 | n_channels = model_config.get('n_channels', 3) 165 | n_classes = model_config.get('num_classes', 2) 166 | bilinear = model_config.get('bilinear', True) 167 | 168 | model = UNet(n_channels=n_channels, n_classes=n_classes, bilinear=bilinear) 169 | return model 170 | 171 | -------------------------------------------------------------------------------- /models/lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import timm 3 | from timm.layers import use_fused_attn 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.init as init 9 | from einops import repeat 10 | from .build_conch_v1_5 import Conch_V1_5_Attention 11 | from typing import Optional 12 | 13 | 14 | class LoRALinear(nn.Module): 15 | def __init__(self, in_features, out_features, bias=True, r=4, lora_alpha=1.0): 16 | super().__init__() 17 | self.r = r 18 | self.in_features = in_features 19 | self.out_features = out_features 20 | self.bias = bias 21 | self.r = r 22 | self.lora_alpha = lora_alpha 23 | # original linear layer 24 | self.weight = nn.Parameter(torch.empty((out_features, in_features))) 25 | if bias: 26 | self.bias = nn.Parameter(torch.empty(out_features)) 27 | else: 28 | 
self.register_parameter('bias', None) 29 | 30 | 31 | # LoRA: low-rank adaptor 32 | self.lora_a = nn.Parameter(torch.zeros(in_features, r), requires_grad=True) 33 | self.lora_b = nn.Parameter(torch.zeros(r, out_features), requires_grad=True) 34 | self.scale = lora_alpha 35 | 36 | # initialization 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self) -> None: 40 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 41 | if self.bias is not None: 42 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 43 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 44 | init.uniform_(self.bias, -bound, bound) 45 | 46 | nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5)) 47 | nn.init.zeros_(self.lora_b) 48 | 49 | def forward(self, x): # shape [10000, 197, 1024] 50 | # compute original output 51 | ori_output = F.linear(x, self.weight, self.bias) 52 | lora_output = ((x @ self.lora_a) @ self.lora_b) * self.scale 53 | return ori_output + lora_output 54 | 55 | 56 | class LoRA_Attention(nn.Module): 57 | def __init__( 58 | self, 59 | dim: int, 60 | num_heads: int = 8, 61 | qkv_bias: bool = False, 62 | qk_norm: bool = False, 63 | attn_drop: float = 0., 64 | proj_drop: float = 0., 65 | norm_layer: nn.Module = nn.LayerNorm, 66 | lora_r: int = 16, 67 | lora_alpha: float = 1., 68 | ) -> None: 69 | super().__init__() 70 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 71 | self.num_heads = num_heads 72 | self.head_dim = dim // num_heads 73 | self.scale = self.head_dim ** -0.5 74 | self.fused_attn = use_fused_attn() 75 | 76 | self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=lora_r, lora_alpha=lora_alpha) 77 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 78 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 79 | self.attn_drop = nn.Dropout(attn_drop) 80 | self.proj = LoRALinear(dim, dim, r=lora_r, lora_alpha=lora_alpha) 81 | self.proj_drop = nn.Dropout(proj_drop) 82 | 83 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 84 | B, N, C = x.shape 85 | 86 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 87 | q, k, v = qkv.unbind(0) 88 | q, k = self.q_norm(q), self.k_norm(k) 89 | 90 | if self.fused_attn: 91 | x = F.scaled_dot_product_attention( 92 | q, k, v, 93 | attn_mask=attn_mask, 94 | dropout_p=self.attn_drop.p if self.training else 0., 95 | ) 96 | else: 97 | q = q * self.scale 98 | attn = q @ k.transpose(-2, -1) 99 | if attn_mask is not None: 100 | try: 101 | attn = attn + attn_mask 102 | except Exception: 103 | # attn_mask might be boolean mask; use masked_fill for that case 104 | attn = attn.masked_fill(attn_mask, float('-inf')) 105 | 106 | attn = attn.softmax(dim=-1) 107 | attn = self.attn_drop(attn) 108 | x = attn @ v 109 | 110 | x = x.transpose(1, 2).reshape(B, N, C) 111 | x = self.proj(x) 112 | x = self.proj_drop(x) 113 | return x 114 | 115 | 116 | 117 | 118 | def equip_model_with_lora(pfm_name, model, rank, alpha): 119 | """ 120 | Equip a PFM model with LoRA by replacing its attention layers with LoRA_Attention layers. 121 | This version also copies original attention weights into the LoRA-Attention module. 122 | 123 | Args: 124 | pfm_name (str): Name of the PFM model. 125 | model (nn.Module): The PFM model to be equipped with LoRA. 126 | rank (int): Rank of the low-rank adaptation. 127 | alpha (float): Scaling factor for the LoRA output. 128 | 129 | Returns: 130 | nn.Module: The PFM model with LoRA applied to its attention layers. 
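    Example (illustrative sketch; 'vit_large_patch16_224' is a stand-in for the
    actual PFM backbone, which is built and loaded elsewhere in the repo):

        import timm
        backbone = timm.create_model('vit_large_patch16_224', pretrained=False)
        backbone = equip_model_with_lora('uni_v1', backbone, rank=16, alpha=1.0)
        # After this call only the injected lora_a / lora_b tensors require grad.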
131 | """ 132 | def copy_weights(src_attn, dst_lora_attn): 133 | with torch.no_grad(): 134 | dst_lora_attn.qkv.weight.copy_(src_attn.qkv.weight) 135 | if src_attn.qkv.bias is not None: 136 | dst_lora_attn.qkv.bias.copy_(src_attn.qkv.bias) 137 | 138 | dst_lora_attn.proj.weight.copy_(src_attn.proj.weight) 139 | if src_attn.proj.bias is not None: 140 | dst_lora_attn.proj.bias.copy_(src_attn.proj.bias) 141 | 142 | if hasattr(src_attn, 'q_norm') and hasattr(dst_lora_attn, 'q_norm') and isinstance(dst_lora_attn.q_norm, nn.LayerNorm): 143 | dst_lora_attn.q_norm.load_state_dict(src_attn.q_norm.state_dict()) 144 | if hasattr(src_attn, 'k_norm') and hasattr(dst_lora_attn, 'k_norm') and isinstance(dst_lora_attn.k_norm, nn.LayerNorm): 145 | dst_lora_attn.k_norm.load_state_dict(src_attn.k_norm.state_dict()) 146 | 147 | if pfm_name in ['uni_v1', 'uni_v2', 'virchow_v2', 'gigapath','virchow_v1','phikon','phikon_v2','hibou_l','musk','lunit_vits8','midnight12k','hoptimus_0','hoptimus_1','patho3dmatrix-vision','kaiko-vitl14','conch_v1']: 148 | for name, module in model.named_modules(): 149 | if isinstance(module, timm.models.vision_transformer.Attention): 150 | lora_attn = LoRA_Attention( 151 | dim=module.qkv.in_features, 152 | num_heads=module.num_heads, 153 | qkv_bias=module.qkv.bias is not None, 154 | qk_norm=isinstance(module.q_norm, nn.LayerNorm), 155 | attn_drop=module.attn_drop.p, 156 | proj_drop=module.proj_drop.p, 157 | lora_r=rank, 158 | lora_alpha=alpha, 159 | ) 160 | 161 | copy_weights(module, lora_attn) 162 | 163 | parent_module = dict(model.named_modules())[name.rsplit(".", 1)[0]] 164 | setattr(parent_module, name.rsplit('.', 1)[-1], lora_attn) 165 | 166 | elif pfm_name == 'conch_v1_5': 167 | for name, module in model.named_modules(): 168 | if isinstance(module, Conch_V1_5_Attention): 169 | lora_attn = LoRA_Attention( 170 | dim=module.qkv.in_features, 171 | num_heads=module.num_heads, 172 | qkv_bias=module.qkv.bias is not None, 173 | qk_norm=isinstance(module.q_norm, nn.LayerNorm), 174 | attn_drop=module.attn_drop.p, 175 | proj_drop=module.proj_drop.p, 176 | lora_r=rank, 177 | lora_alpha=alpha, 178 | ) 179 | 180 | copy_weights(module, lora_attn) 181 | 182 | parent_module = dict(model.named_modules())[name.rsplit(".", 1)[0]] 183 | setattr(parent_module, name.rsplit('.', 1)[-1], lora_attn) 184 | 185 | for param in model.parameters(): 186 | param.requires_grad = False 187 | 188 | for name, param in model.named_parameters(): 189 | if 'lora_a' in name or 'lora_b' in name: 190 | param.requires_grad = True 191 | 192 | return model 193 | -------------------------------------------------------------------------------- /data/seg_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simplified JSON-based dataset class 3 | 4 | Supports basic img_path and mask_path format, designed for semantic segmentation tasks. 5 | """ 6 | 7 | import os 8 | import json 9 | import torch 10 | from torch.utils.data import Dataset 11 | from PIL import Image 12 | import numpy as np 13 | from typing import Dict, List, Optional, Callable 14 | 15 | 16 | class JSONSegmentationDataset(Dataset): 17 | """ 18 | Semantic segmentation dataset based on a JSON file. 19 | 20 | Expected JSON format: 21 | { 22 | "num_classes": 3, 23 | "data": { 24 | "train": [ 25 | {"image_path": "/path/to/image1.jpg", "mask_path": "/path/to/mask1.png"}, 26 | {"image_path": "/path/to/image2.jpg", "mask_path": "/path/to/mask2.png"} 27 | ], 28 | "val": [...], 29 | "test": [...] 
30 | } 31 | } 32 | 33 | Args: 34 | json_file (str): Path to the JSON config file. 35 | split (str): Dataset split ('train', 'val', or 'test'). 36 | transform (Optional[Callable]): Data transformation/augmentation function. 37 | """ 38 | 39 | def __init__(self, json_file: str, split: str = 'train', 40 | transform: Optional[Callable] = None): 41 | self.json_file = json_file 42 | self.split = split 43 | self.transform = transform 44 | 45 | # Load JSON configuration 46 | self.config = self._load_json_config() 47 | 48 | # Extract basic info 49 | self.num_classes = self.config.get('num_classes') 50 | self.ignore_index = 255 # fixed ignore label 51 | # Load data entries 52 | self.data_items = self._load_data_items() 53 | self.fixed_size = self._check_fixed_size() 54 | self.has_mask = self._check_has_mask() 55 | if not self.has_mask: 56 | self._reset_mask() 57 | print(f"Dataset loaded: split = {split}, samples = {len(self.data_items)}, classes = {self.num_classes}") 58 | 59 | def _load_json_config(self) -> Dict: 60 | """Load the JSON config file.""" 61 | try: 62 | with open(self.json_file, 'r', encoding='utf-8') as f: 63 | config = json.load(f) 64 | return config 65 | except FileNotFoundError: 66 | raise FileNotFoundError(f"JSON config file not found: {self.json_file}") 67 | except json.JSONDecodeError as e: 68 | raise ValueError(f"Invalid JSON format: {e}") 69 | 70 | def _check_has_mask(self) -> bool: 71 | """Check if the dataset has mask paths.""" 72 | for item in self.data_items: 73 | mask_path = item.get('mask_path') 74 | if mask_path == None: 75 | return False 76 | if not os.path.exists(mask_path): 77 | return False 78 | return True 79 | 80 | def _reset_mask(self) -> None: 81 | """Reset mask paths to None if they are not present.""" 82 | new_items = [] 83 | for item in self.data_items: 84 | item['mask_path'] = None 85 | new_items.append(item) 86 | self.data_items = new_items 87 | 88 | 89 | def _check_fixed_size(self) -> bool: 90 | """Check if the dataset has a fixed image size.""" 91 | _img_size = None 92 | for item in self.data_items: 93 | img_path = item.get('img_path', '') 94 | with Image.open(img_path) as img: 95 | if _img_size is None: 96 | _img_size = img.size 97 | elif _img_size != img.size: 98 | return False 99 | return True 100 | 101 | def _load_data_items(self) -> List[Dict]: 102 | """Load the data entries for the given split.""" 103 | data_config = self.config.get('data') 104 | split_data = data_config.get(self.split) 105 | 106 | if not split_data: 107 | raise ValueError(f"No data found for split '{self.split}'") 108 | 109 | processed_items = [] 110 | for item in split_data: 111 | processed_item = self._process_data_item(item) 112 | if processed_item: 113 | processed_items.append(processed_item) 114 | 115 | if not processed_items: 116 | raise ValueError(f"No valid items found in split '{self.split}'") 117 | 118 | return processed_items 119 | 120 | def _process_data_item(self, item: Dict) -> Optional[Dict]: 121 | """Process a single data entry.""" 122 | img_path = item.get('image_path', '') 123 | mask_path = item.get('mask_path', None) 124 | 125 | if not img_path or not mask_path: 126 | if self.split == 'train' or self.split == 'val': 127 | print(f"Missing image or mask path: {item}") 128 | return None 129 | 130 | return { 131 | 'img_path': img_path, 132 | 'mask_path': mask_path 133 | } 134 | 135 | def __len__(self) -> int: 136 | """Return the dataset size.""" 137 | return len(self.data_items) 138 | 139 | def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: 140 | """ 141 | 
Retrieve a single data entry. 142 | 143 | Args: 144 | index (int): Index of the data item 145 | 146 | Returns: 147 | Dict[str, torch.Tensor]: Dictionary containing image and label tensors 148 | """ 149 | item = self.data_items[index] 150 | 151 | image = Image.open(item['img_path']).convert('RGB') 152 | ori_size = image.size 153 | 154 | if self.has_mask: 155 | mask = Image.open(item['mask_path']) 156 | if mask.mode != 'L': 157 | mask = mask.convert('L') 158 | # 确保掩码尺寸与图像一致 159 | if mask.size != image.size: 160 | print(f"Warning: Resizing mask to match image size for {item['mask_path']}") 161 | print(f"Image size: {image.size}, Mask size: {mask.size}") 162 | mask = mask.resize(image.size, Image.Resampling.NEAREST) 163 | mask = np.array(mask, dtype=np.int64) 164 | else: 165 | mask = np.ones((ori_size[1],ori_size[0]), dtype=np.int64) * (-1) 166 | 167 | # Validate mask values (should be within [0, num_classes-1] or 255 as ignore index) 168 | unique_values = np.unique(mask) 169 | valid_values = set(range(self.num_classes)) | {self.ignore_index} 170 | invalid_values = set(unique_values) - valid_values 171 | 172 | if invalid_values and self.has_mask: 173 | print(f"Invalid label values {invalid_values} found in {item['mask_path']}") 174 | for invalid_val in invalid_values: 175 | mask[mask == invalid_val] = self.ignore_index 176 | 177 | # Apply transformation 178 | if self.transform: 179 | transformed = self.transform(image=np.array(image), mask=mask) 180 | image = transformed['image'] 181 | mask = transformed['mask'] 182 | else: 183 | image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0 184 | mask = torch.from_numpy(mask).long() 185 | 186 | return { 187 | 'image': image, 188 | 'label': mask, 189 | 'ori_size': ori_size, 190 | 'image_path': item['img_path'], 191 | 'label_path': item['mask_path'] 192 | } 193 | 194 | def get_class_weights(self) -> torch.Tensor: 195 | """ 196 | Compute class weights to handle class imbalance. 
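        Example (illustrative; passing the weights to a weighted loss is one
        possible use, not something the trainer is shown doing):

            weights = dataset.get_class_weights()
            criterion = torch.nn.CrossEntropyLoss(weight=weights, ignore_index=255)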
197 | 198 | Returns: 199 | torch.Tensor: Computed class weights 200 | """ 201 | print("Computing class weights...") 202 | 203 | class_counts = np.zeros(self.num_classes) 204 | total_pixels = 0 205 | 206 | for item in self.data_items: 207 | mask = Image.open(item['mask_path']) 208 | if mask.mode != 'L': 209 | mask = mask.convert('L') 210 | mask_array = np.array(mask) 211 | 212 | for class_id in range(self.num_classes): 213 | class_counts[class_id] += np.sum(mask_array == class_id) 214 | 215 | valid_pixels = mask_array != self.ignore_index 216 | total_pixels += np.sum(valid_pixels) 217 | 218 | class_counts = np.maximum(class_counts, 1) 219 | weights = total_pixels / (self.num_classes * class_counts) 220 | weights = weights / weights.sum() * self.num_classes 221 | 222 | print(f"Class weights: {weights}") 223 | return torch.from_numpy(weights).float() 224 | 225 | 226 | def get_dataset(data_configs, transforms, split): 227 | json_file = data_configs.get('json_file') 228 | return JSONSegmentationDataset( 229 | json_file=json_file, 230 | split=split, 231 | transform=transforms 232 | ) 233 | 234 | -------------------------------------------------------------------------------- /models/dora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import timm 3 | from timm.layers import use_fused_attn 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.init as init 9 | from einops import repeat 10 | from .build_conch_v1_5 import Conch_V1_5_Attention 11 | from typing import Optional 12 | 13 | 14 | class DoraLinear(nn.Module): 15 | def __init__(self, in_features, out_features, bias=True, r=4, lora_alpha=1.0): 16 | super().__init__() 17 | self.r = r 18 | self.in_features = in_features 19 | self.out_features = out_features 20 | self.bias = bias 21 | self.r = r 22 | self.lora_alpha = lora_alpha 23 | # original linear layer (frozen) 24 | self.weight = nn.Parameter(torch.empty((out_features, in_features)), requires_grad=False) 25 | if bias: 26 | self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False) 27 | else: 28 | self.register_parameter('bias', None) 29 | 30 | # LoRA: low-rank adaptor 31 | # Note: lora_a: [in_features, r], lora_b: [r, out_features] 32 | # So lora_a @ lora_b: [in_features, r] @ [r, out_features] = [in_features, out_features] 33 | self.lora_a = nn.Parameter(torch.zeros(in_features, r), requires_grad=True) 34 | self.lora_b = nn.Parameter(torch.zeros(r, out_features), requires_grad=True) 35 | self.scale = lora_alpha 36 | 37 | # Dora magnitude scaling parameter 38 | self.m = nn.Parameter(torch.ones(1), requires_grad=True) # Scalar magnitude 39 | 40 | # initialization 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self) -> None: 44 | # Initialize original weight and bias (frozen) 45 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 46 | if self.bias is not None: 47 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 48 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 49 | init.uniform_(self.bias, -bound, bound) 50 | 51 | nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5)) 52 | nn.init.zeros_(self.lora_b) 53 | 54 | # Initialize magnitude to 1 55 | nn.init.ones_(self.m) 56 | 57 | def forward(self, x): # shape [batch, seq_len, in_features] 58 | # compute LoRA update: lora_a @ lora_b 59 | lora_update = self.lora_a @ self.lora_b # Shape: [in_features, out_features] 60 | 61 | # Apply scale 62 | lora_update = lora_update * self.scale # Shape: 
[in_features, out_features] 63 | 64 | # compute original weight norm and updated weight norm (Frobenius norm) 65 | weight_norm = self.weight.norm(p='fro') 66 | updated_weight = self.weight + lora_update.t() # Shape: [out_features, in_features] 67 | updated_norm = updated_weight.norm(p='fro') 68 | 69 | # compute direction (unit vector in matrix space) 70 | direction = updated_weight / updated_norm # Shape: [out_features, in_features] 71 | 72 | # apply Dora: magnitude * direction * original_norm 73 | dora_weight = self.m * direction * weight_norm 74 | 75 | # compute output using Dora-adjusted weight 76 | output = F.linear(x, dora_weight, self.bias) 77 | return output 78 | 79 | 80 | class Dora_Attention(nn.Module): 81 | def __init__( 82 | self, 83 | dim: int, 84 | num_heads: int = 8, 85 | qkv_bias: bool = False, 86 | qk_norm: bool = False, 87 | attn_drop: float = 0., 88 | proj_drop: float = 0., 89 | norm_layer: nn.Module = nn.LayerNorm, 90 | dora_r: int = 16, 91 | dora_alpha: float = 1., 92 | ) -> None: 93 | super().__init__() 94 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 95 | self.num_heads = num_heads 96 | self.head_dim = dim // num_heads 97 | self.scale = self.head_dim ** -0.5 98 | self.fused_attn = use_fused_attn() 99 | 100 | self.qkv = DoraLinear(dim, dim * 3, bias=qkv_bias, r=dora_r, lora_alpha=dora_alpha) 101 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 102 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 103 | self.attn_drop = nn.Dropout(attn_drop) 104 | self.proj = DoraLinear(dim, dim, r=dora_r, lora_alpha=dora_alpha) 105 | self.proj_drop = nn.Dropout(proj_drop) 106 | 107 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 108 | B, N, C = x.shape 109 | 110 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 111 | q, k, v = qkv.unbind(0) 112 | q, k = self.q_norm(q), self.k_norm(k) 113 | 114 | if self.fused_attn: 115 | x = F.scaled_dot_product_attention( 116 | q, k, v, 117 | attn_mask=attn_mask, 118 | dropout_p=self.attn_drop.p if self.training else 0., 119 | ) 120 | else: 121 | q = q * self.scale 122 | attn = q @ k.transpose(-2, -1) 123 | if attn_mask is not None: 124 | try: 125 | attn = attn + attn_mask 126 | except Exception: 127 | # attn_mask might be boolean mask; use masked_fill for that case 128 | attn = attn.masked_fill(attn_mask, float('-inf')) 129 | 130 | attn = attn.softmax(dim=-1) 131 | attn = self.attn_drop(attn) 132 | x = attn @ v 133 | 134 | x = x.transpose(1, 2).reshape(B, N, C) 135 | x = self.proj(x) 136 | x = self.proj_drop(x) 137 | return x 138 | 139 | 140 | def equip_model_with_dora(pfm_name, model, rank, alpha): 141 | """ 142 | Equip a PFM model with Dora by replacing its attention layers with Dora_Attention layers. 143 | This version also copies original attention weights into the Dora-Attention module. 144 | 145 | Args: 146 | pfm_name (str): Name of the PFM model. 147 | model (nn.Module): The PFM model to be equipped with Dora. 148 | rank (int): Rank of the low-rank adaptation. 149 | alpha (float): Scaling factor for the Dora output. 150 | 151 | Returns: 152 | nn.Module: The PFM model with Dora applied to its attention layers. 
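    Example (illustrative sketch, mirroring the LoRA helper above; the timm model
    name is a stand-in for the real backbone):

        import timm
        backbone = timm.create_model('vit_large_patch16_224', pretrained=False)
        backbone = equip_model_with_dora('uni_v1', backbone, rank=16, alpha=1.0)
        # lora_a, lora_b and each DoraLinear's scalar magnitude `m` stay trainable;
        # every other parameter is frozen.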
153 | """ 154 | def copy_weights(src_attn, dst_dora_attn): 155 | with torch.no_grad(): 156 | # Copy original weights to the frozen part of Dora 157 | dst_dora_attn.qkv.weight.copy_(src_attn.qkv.weight) 158 | if src_attn.qkv.bias is not None: 159 | dst_dora_attn.qkv.bias.copy_(src_attn.qkv.bias) 160 | 161 | dst_dora_attn.proj.weight.copy_(src_attn.proj.weight) 162 | if src_attn.proj.bias is not None: 163 | dst_dora_attn.proj.bias.copy_(src_attn.proj.bias) 164 | 165 | if hasattr(src_attn, 'q_norm') and hasattr(dst_dora_attn, 'q_norm') and isinstance(dst_dora_attn.q_norm, nn.LayerNorm): 166 | dst_dora_attn.q_norm.load_state_dict(src_attn.q_norm.state_dict()) 167 | if hasattr(src_attn, 'k_norm') and hasattr(dst_dora_attn, 'k_norm') and isinstance(dst_dora_attn.k_norm, nn.LayerNorm): 168 | dst_dora_attn.k_norm.load_state_dict(src_attn.k_norm.state_dict()) 169 | 170 | if pfm_name in ['uni_v1', 'uni_v2', 'virchow_v2', 'gigapath','virchow_v1','phikon','phikon_v2','hibou_l','musk','lunit_vits8','midnight12k','hoptimus_0','hoptimus_1','patho3dmatrix-vision','kaiko-vitl14','conch_v1']: 171 | for name, module in model.named_modules(): 172 | if isinstance(module, timm.models.vision_transformer.Attention): 173 | dora_attn = Dora_Attention( 174 | dim=module.qkv.in_features, 175 | num_heads=module.num_heads, 176 | qkv_bias=module.qkv.bias is not None, 177 | qk_norm=isinstance(module.q_norm, nn.LayerNorm), 178 | attn_drop=module.attn_drop.p, 179 | proj_drop=module.proj_drop.p, 180 | dora_r=rank, 181 | dora_alpha=alpha, 182 | ) 183 | 184 | copy_weights(module, dora_attn) 185 | 186 | parent_module = dict(model.named_modules())[name.rsplit(".", 1)[0]] 187 | setattr(parent_module, name.rsplit('.', 1)[-1], dora_attn) 188 | 189 | elif pfm_name == 'conch_v1_5': 190 | for name, module in model.named_modules(): 191 | if isinstance(module, Conch_V1_5_Attention): 192 | dora_attn = Dora_Attention( 193 | dim=module.qkv.in_features, 194 | num_heads=module.num_heads, 195 | qkv_bias=module.qkv.bias is not None, 196 | qk_norm=isinstance(module.q_norm, nn.LayerNorm), 197 | attn_drop=module.attn_drop.p, 198 | proj_drop=module.proj_drop.p, 199 | dora_r=rank, 200 | dora_alpha=alpha, 201 | ) 202 | 203 | copy_weights(module, dora_attn) 204 | 205 | parent_module = dict(model.named_modules())[name.rsplit(".", 1)[0]] 206 | setattr(parent_module, name.rsplit('.', 1)[-1], dora_attn) 207 | 208 | # Freeze all original parameters 209 | for param in model.parameters(): 210 | param.requires_grad = False 211 | 212 | # Only train Dora-specific parameters: lora_a, lora_b, and m 213 | for name, param in model.named_parameters(): 214 | if 'lora_a' in name or 'lora_b' in name or name.endswith('.m'): 215 | param.requires_grad = True 216 | 217 | return model -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Utilities for Semantic Segmentation 3 | 4 | This module contains utility functions for model management, including 5 | model creation, weight initialization, and checkpoint handling. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from typing import Dict, Any, Optional, Union 11 | import os 12 | 13 | 14 | 15 | def count_parameters(model: nn.Module) -> Dict[str, float]: 16 | """ 17 | Count the number of parameters in a model and return in millions (M). 
18 | 19 | Args: 20 | model (nn.Module): PyTorch model 21 | 22 | Returns: 23 | Dict[str, float]: Dictionary containing parameter counts in millions (M) 24 | with 2 decimal places precision 25 | """ 26 | total_params = sum(p.numel() for p in model.parameters()) / 1e6 # Convert to millions 27 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 28 | 29 | return { 30 | 'total_parameters(M)': round(total_params, 2), 31 | 'trainable_parameters(M)': round(trainable_params, 2), 32 | 'non_trainable_parameters(M)': round(total_params - trainable_params, 2) 33 | } 34 | 35 | 36 | def initialize_weights(model: nn.Module, init_type: str = 'kaiming') -> None: 37 | """ 38 | Initialize model weights with specified initialization method. 39 | 40 | Args: 41 | model (nn.Module): PyTorch model 42 | init_type (str): Initialization type ('kaiming', 'xavier', 'normal', 'zero') 43 | """ 44 | for m in model.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | if init_type == 'kaiming': 47 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 48 | elif init_type == 'xavier': 49 | nn.init.xavier_normal_(m.weight) 50 | elif init_type == 'normal': 51 | nn.init.normal_(m.weight, 0, 0.01) 52 | elif init_type == 'zero': 53 | nn.init.zeros_(m.weight) 54 | 55 | if m.bias is not None: 56 | nn.init.constant_(m.bias, 0) 57 | 58 | elif isinstance(m, nn.BatchNorm2d): 59 | nn.init.constant_(m.weight, 1) 60 | nn.init.constant_(m.bias, 0) 61 | 62 | elif isinstance(m, nn.Linear): 63 | if init_type == 'kaiming': 64 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 65 | elif init_type == 'xavier': 66 | nn.init.xavier_normal_(m.weight) 67 | elif init_type == 'normal': 68 | nn.init.normal_(m.weight, 0, 0.01) 69 | elif init_type == 'zero': 70 | nn.init.zeros_(m.weight) 71 | 72 | if m.bias is not None: 73 | nn.init.constant_(m.bias, 0) 74 | 75 | 76 | def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, 77 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], 78 | epoch: int, loss: float, metrics: Dict[str, float], 79 | checkpoint_path: str, is_best: bool = False) -> None: 80 | """ 81 | Save model checkpoint. 82 | 83 | Args: 84 | model (nn.Module): PyTorch model 85 | optimizer (torch.optim.Optimizer): Optimizer 86 | scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]): Learning rate scheduler 87 | epoch (int): Current epoch 88 | loss (float): Current loss value 89 | metrics (Dict[str, float]): Evaluation metrics 90 | checkpoint_path (str): Path to save checkpoint 91 | is_best (bool): Whether this is the best checkpoint 92 | """ 93 | checkpoint = { 94 | 'epoch': epoch, 95 | 'model_state_dict': model.state_dict(), 96 | 'optimizer_state_dict': optimizer.state_dict(), 97 | 'scheduler_state_dict': scheduler.state_dict() if scheduler else None, 98 | 'loss': loss, 99 | 'metrics': metrics 100 | } 101 | 102 | torch.save(checkpoint, checkpoint_path) 103 | 104 | if is_best: 105 | best_path = os.path.join(os.path.dirname(checkpoint_path), 'best_model.pth') 106 | torch.save(checkpoint, best_path) 107 | 108 | 109 | def load_checkpoint(model: nn.Module, checkpoint_path: str, 110 | optimizer: Optional[torch.optim.Optimizer] = None, 111 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, 112 | device: str = 'cpu') -> Dict[str, Any]: 113 | """ 114 | Load model checkpoint. 
115 | 116 | Args: 117 | model (nn.Module): PyTorch model 118 | checkpoint_path (str): Path to checkpoint file 119 | optimizer (Optional[torch.optim.Optimizer]): Optimizer to load state 120 | scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]): Scheduler to load state 121 | device (str): Device to load checkpoint on 122 | 123 | Returns: 124 | Dict[str, Any]: Checkpoint information 125 | """ 126 | if not os.path.exists(checkpoint_path): 127 | raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") 128 | 129 | checkpoint = torch.load(checkpoint_path, map_location=device) 130 | 131 | # Load model state 132 | model.load_state_dict(checkpoint['model_state_dict']) 133 | 134 | # Load optimizer state if provided 135 | if optimizer and 'optimizer_state_dict' in checkpoint: 136 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 137 | 138 | # Load scheduler state if provided 139 | if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']: 140 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 141 | 142 | return { 143 | 'epoch': checkpoint.get('epoch', 0), 144 | 'loss': checkpoint.get('loss', float('inf')), 145 | 'metrics': checkpoint.get('metrics', {}) 146 | } 147 | 148 | 149 | def get_model_complexity(model: nn.Module, input_size: tuple = (1, 3, 512, 512)) -> Dict[str, Any]: 150 | """ 151 | Analyze model complexity including parameters, FLOPs, and memory usage. 152 | 153 | Args: 154 | model (nn.Module): PyTorch model 155 | input_size (tuple): Input tensor size 156 | 157 | Returns: 158 | Dict[str, Any]: Model complexity metrics 159 | """ 160 | # Count parameters 161 | param_stats = count_parameters(model) 162 | 163 | # Estimate model size in MB 164 | param_size = sum(p.numel() * p.element_size() for p in model.parameters()) 165 | buffer_size = sum(b.numel() * b.element_size() for b in model.buffers()) 166 | model_size_mb = (param_size + buffer_size) / (1024 ** 2) 167 | 168 | # Create dummy input for memory estimation 169 | dummy_input = torch.randn(input_size) 170 | model.eval() 171 | 172 | with torch.no_grad(): 173 | # Estimate memory usage 174 | torch.cuda.empty_cache() if torch.cuda.is_available() else None 175 | 176 | if torch.cuda.is_available(): 177 | model = model.cuda() 178 | dummy_input = dummy_input.cuda() 179 | 180 | # Measure memory before forward pass 181 | torch.cuda.synchronize() 182 | mem_before = torch.cuda.memory_allocated() 183 | 184 | # Forward pass 185 | _ = model(dummy_input) 186 | 187 | # Measure memory after forward pass 188 | torch.cuda.synchronize() 189 | mem_after = torch.cuda.memory_allocated() 190 | 191 | memory_usage_mb = (mem_after - mem_before) / (1024 ** 2) 192 | else: 193 | # CPU memory estimation (approximate) 194 | output = model(dummy_input) 195 | memory_usage_mb = sum( 196 | tensor.numel() * tensor.element_size() 197 | for tensor in [dummy_input, output['out']] 198 | ) / (1024 ** 2) 199 | 200 | return { 201 | 'parameters': param_stats, 202 | 'model_size_mb': model_size_mb, 203 | 'memory_usage_mb': memory_usage_mb, 204 | 'input_size': input_size 205 | } 206 | 207 | 208 | def convert_to_onnx(model: nn.Module, output_path: str, 209 | input_size: tuple = (1, 3, 512, 512), 210 | opset_version: int = 11) -> None: 211 | """ 212 | Convert PyTorch model to ONNX format. 
213 | 214 | Args: 215 | model (nn.Module): PyTorch model 216 | output_path (str): Path to save ONNX model 217 | input_size (tuple): Input tensor size 218 | opset_version (int): ONNX opset version 219 | """ 220 | try: 221 | import onnx 222 | import onnxruntime 223 | except ImportError: 224 | raise ImportError("ONNX and ONNXRuntime are required for ONNX conversion") 225 | 226 | model.eval() 227 | dummy_input = torch.randn(input_size) 228 | 229 | # Export to ONNX 230 | torch.onnx.export( 231 | model, 232 | dummy_input, 233 | output_path, 234 | export_params=True, 235 | opset_version=opset_version, 236 | do_constant_folding=True, 237 | input_names=['input'], 238 | output_names=['output'], 239 | dynamic_axes={ 240 | 'input': {0: 'batch_size'}, 241 | 'output': {0: 'batch_size'} 242 | } 243 | ) 244 | 245 | # Verify ONNX model 246 | onnx_model = onnx.load(output_path) 247 | onnx.checker.check_model(onnx_model) 248 | 249 | print(f"Model successfully converted to ONNX: {output_path}") 250 | 251 | 252 | def print_model_summary(model: nn.Module, input_size: tuple = (1, 3, 512, 512)) -> None: 253 | """ 254 | Print a comprehensive model summary. 255 | 256 | Args: 257 | model (nn.Module): PyTorch model 258 | input_size (tuple): Input tensor size 259 | """ 260 | print("=" * 80) 261 | print("MODEL SUMMARY") 262 | print("=" * 80) 263 | 264 | # Model architecture 265 | print(f"Model: {model.__class__.__name__}") 266 | print(f"Input size: {input_size}") 267 | 268 | # Parameter statistics (count_parameters returns values in millions, keyed with an '(M)' suffix) 269 | param_stats = count_parameters(model) 270 | print(f"Total parameters: {param_stats['total_parameters(M)']:.2f}M") 271 | print(f"Trainable parameters: {param_stats['trainable_parameters(M)']:.2f}M") 272 | print(f"Non-trainable parameters: {param_stats['non_trainable_parameters(M)']:.2f}M") 273 | 274 | # Model complexity 275 | complexity = get_model_complexity(model, input_size) 276 | print(f"Model size: {complexity['model_size_mb']:.2f} MB") 277 | print(f"Memory usage: {complexity['memory_usage_mb']:.2f} MB") 278 | 279 | print("=" * 80) 280 | 281 | 282 | 283 | 284 | if __name__ == "__main__": 285 | # Test model utilities 286 | from .PFM import create_pfm_model 287 | 288 | # Create a test model 289 | model = create_pfm_model(num_classes=19, img_size=512) 290 | 291 | # Print model summary 292 | print_model_summary(model) 293 | 294 | # Test other utilities 295 | initialize_weights(model, 'kaiming') 296 | print("Model weights initialized successfully") 297 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Training Script for Semantic Segmentation 4 | 5 | This script provides a complete training pipeline for semantic segmentation models 6 | with support for various datasets, augmentations, loss functions, and optimization techniques. 7 | 8 | Author: @Toby 9 | Function: Train a semantic segmentation model using a configuration file.
10 | """ 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | import argparse 14 | import os 15 | import sys 16 | import yaml 17 | import torch 18 | import torch.nn as nn 19 | import torch.optim as optim 20 | from torch.utils.data import DataLoader 21 | import random 22 | import albumentations as A 23 | import numpy as np 24 | import logging 25 | from typing import Dict, Any 26 | 27 | # Add project root to path 28 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 29 | 30 | from models import create_segmentation_model, count_parameters 31 | from models.losses import get_loss_function 32 | # from data.datasets import get_dataset 33 | # from data.transforms import get_transforms 34 | from data.utils import create_dataloader 35 | from data.transforms import get_transforms,SegmentationTransforms 36 | from data.seg_dataset import get_dataset 37 | from utils.trainer import SegmentationTrainer 38 | from utils.visualization import plot_training_history 39 | from utils.logs import setup_logging 40 | from utils.yaml_utils import load_config_with_options 41 | 42 | 43 | 44 | def parse_args(): 45 | """Parse command line arguments.""" 46 | parser = argparse.ArgumentParser(description='Training script for semantic segmentation') 47 | 48 | parser.add_argument('--config', type=str, default='/mnt/sdb/chenwm/PFM_Segmentation/configs/test.yaml', 49 | help='Path to configuration file') 50 | parser.add_argument('--options',nargs='+', 51 | help='override some settings in the used config, the key-value pair in xxx=yyy format will be merged into the yaml config file') 52 | parser.add_argument('--resume', type=str, default=None, 53 | help='Path to checkpoint to resume from') 54 | parser.add_argument('--device', type=str, default='', 55 | help='Device to use (cuda/cpu/auto)') 56 | 57 | return parser.parse_args() 58 | 59 | 60 | def load_config(config_path: str) -> Dict[str, Any]: 61 | """ 62 | Load configuration from YAML file. 63 | 64 | Args: 65 | config_path (str): Path to configuration file 66 | 67 | Returns: 68 | Dict[str, Any]: Configuration dictionary 69 | """ 70 | with open(config_path, 'r') as f: 71 | config = yaml.safe_load(f) 72 | return config 73 | 74 | 75 | 76 | 77 | def set_random_seed(seed: int): 78 | """Set random seed for reproducibility.""" 79 | random.seed(seed) 80 | np.random.seed(seed) 81 | torch.manual_seed(seed) 82 | torch.cuda.manual_seed(seed) 83 | torch.cuda.manual_seed_all(seed) 84 | torch.backends.cudnn.deterministic = True 85 | torch.backends.cudnn.benchmark = False 86 | 87 | 88 | 89 | def get_device(device_arg: str, config: Dict[str, Any]) -> str: 90 | """ 91 | Get device for training. 92 | 93 | Args: 94 | device_arg (str): Device argument from command line 95 | 96 | Returns: 97 | str: Device string 98 | """ 99 | # If device_arg is empty, get device from config or default to 'cuda' if available 100 | if not device_arg: 101 | device_arg = config['system'].get('device', 'cuda' if torch.cuda.is_available() else 'cpu') 102 | return torch.device(device_arg) 103 | 104 | def save_config(config: Dict[str, Any], save_path: str): 105 | """ 106 | Save configuration to a YAML file. 
107 | 108 | Args: 109 | config (Dict[str, Any]): Configuration dictionary 110 | save_path (str): Path to save the configuration file 111 | """ 112 | with open(save_path, 'w') as f: 113 | yaml.dump(config, f, default_flow_style=False) 114 | print(f"Configuration saved to {save_path}") 115 | 116 | 117 | 118 | 119 | def worker_init_fn(worker_id): 120 | """Initialize worker with a random seed based on worker ID.""" 121 | seed = 42 # Base seed 122 | random.seed(seed + worker_id) 123 | np.random.seed(seed+ worker_id) 124 | torch.manual_seed(seed+ worker_id) 125 | torch.cuda.manual_seed(seed+ worker_id) 126 | torch.cuda.manual_seed_all(seed+ worker_id) 127 | 128 | 129 | def create_optimizer(model: nn.Module, config: Dict[str, Any]) -> optim.Optimizer: 130 | """ 131 | Create optimizer based on configuration. 132 | 133 | Args: 134 | model (nn.Module): Model to optimize 135 | config (Dict[str, Any]): Training configuration 136 | 137 | Returns: 138 | optim.Optimizer: Optimizer 139 | """ 140 | optimizer_config = config['training'].get('optimizer') 141 | optimizer_type = optimizer_config.get('type', 'SGD').lower() 142 | 143 | lr = config['training']['learning_rate'] 144 | weight_decay = config['training']['optimizer'].get('weight_decay', 1e-4) 145 | 146 | if optimizer_type == 'sgd': 147 | momentum = config['training'].get('momentum', 0.9) 148 | nesterov = optimizer_config.get('nesterov', True) 149 | 150 | return optim.SGD( 151 | model.parameters(), 152 | lr=lr, 153 | momentum=momentum, 154 | weight_decay=weight_decay, 155 | nesterov=nesterov 156 | ) 157 | 158 | elif optimizer_type == 'adam': 159 | betas = optimizer_config.get('betas', (0.9, 0.999)) 160 | eps = optimizer_config.get('eps', 1e-8) 161 | 162 | return optim.Adam( 163 | model.parameters(), 164 | lr=lr, 165 | betas=betas, 166 | eps=eps, 167 | weight_decay=weight_decay 168 | ) 169 | 170 | elif optimizer_type == 'adamw': 171 | betas = optimizer_config.get('betas', (0.9, 0.999)) 172 | eps = optimizer_config.get('eps', 1e-8) 173 | 174 | return optim.AdamW( 175 | model.parameters(), 176 | lr=lr, 177 | betas=betas, 178 | eps=eps, 179 | weight_decay=weight_decay 180 | ) 181 | 182 | else: 183 | raise ValueError(f"Unsupported optimizer type: {optimizer_type}") 184 | 185 | 186 | def main(): 187 | """Main training function.""" 188 | args = parse_args() 189 | 190 | # Load configuration 191 | config = load_config_with_options(args.config, args.options) 192 | 193 | # Set random seed 194 | seed = config['system'].get('seed', 42) 195 | set_random_seed(seed) 196 | generator = torch.Generator() 197 | generator.manual_seed(seed) 198 | 199 | # Get device 200 | device = get_device(args.device,config) 201 | 202 | # Setup logging 203 | log_dir = config['logging'].get('log_dir') 204 | experiment_name = config['logging'].get('experiment_name') 205 | log_dir = os.path.join(log_dir, experiment_name) 206 | os.makedirs(log_dir, exist_ok=True) 207 | save_config(config, os.path.join(log_dir, 'config.yaml')) 208 | setup_logging(log_dir) 209 | logger = logging.getLogger(__name__) 210 | 211 | logger.info("Starting training...") 212 | logger.info(f"Configuration file: {args.config}") 213 | logger.info(f"Device: {device}") 214 | logger.info(f"Random seed: {seed}") 215 | 216 | # Create model 217 | pfm_name = config['model'].get('pfm_name', 'unet') 218 | model_type = config['model'].get('model_type', '') 219 | if pfm_name.lower() == 'unet' or model_type.lower() == 'unet': 220 | logger.info(f"Creating model: UNet...") 221 | model = create_segmentation_model(config['model']) 222 | 
else: 223 | logger.info(f"Creating model: {pfm_name}...") 224 | finetune_mode = config['model'].get('finetune_mode', {}) 225 | if finetune_mode: 226 | logger.info(f"Model finetune-mode: {finetune_mode}") 227 | model = create_segmentation_model(config['model']) 228 | model = model.to(device) 229 | 230 | # Log model information 231 | model_params_info_dict = count_parameters(model) 232 | logger.info(f"Model parameters info: {model_params_info_dict}") 233 | # Create datasets and data loaders 234 | logger.info("Creating datasets...") 235 | dataset_config = config['dataset'] 236 | 237 | # Training dataset 238 | # Get normalization values based on model name 239 | from data.transforms import get_model_normalization 240 | pfm_name = config['model'].get('pfm_name', 'unet') 241 | mean, std = get_model_normalization(pfm_name) 242 | train_transforms = SegmentationTransforms.get_training_transforms( 243 | img_size=config['training']['augmentation']['RandomResizedCropSize'], 244 | seed=seed, 245 | mean=mean, 246 | std=std 247 | ) 248 | train_dataset = get_dataset(dataset_config, train_transforms, split='train') 249 | 250 | train_loader = create_dataloader( 251 | train_dataset, 252 | batch_size=config['training']['batch_size'], 253 | shuffle=True, 254 | generator=generator, 255 | num_workers=config['system'].get('num_workers', 4), 256 | pin_memory=config['system'].get('pin_memory', True), 257 | worker_init_fn=worker_init_fn, 258 | drop_last=False, 259 | ) 260 | 261 | # Validation dataset 262 | val_transforms = SegmentationTransforms.get_validation_transforms( 263 | img_size=config['validation']['augmentation']['ResizedSize'], 264 | mean=mean, 265 | std=std 266 | ) 267 | val_dataset = get_dataset(dataset_config, val_transforms, split='val') 268 | 269 | val_loader = create_dataloader( 270 | val_dataset, 271 | batch_size=config['validation']['batch_size'], 272 | shuffle=False, 273 | num_workers=config['system'].get('num_workers', 4), 274 | pin_memory=config['system'].get('pin_memory', True), 275 | drop_last=False 276 | ) 277 | 278 | logger.info(f"Training samples: {len(train_dataset)}") 279 | logger.info(f"Validation samples: {len(val_dataset)}") 280 | logger.info(f"Training batches: {len(train_loader)}") 281 | logger.info(f"Validation batches: {len(val_loader)}") 282 | 283 | # Create loss function 284 | logger.info("Creating loss function...") 285 | criterion = get_loss_function(config['training']['loss']) 286 | criterion = criterion.to(device) 287 | 288 | # Create optimizer 289 | logger.info("Creating optimizer...") 290 | optimizer = create_optimizer(model, config) 291 | 292 | # Create trainer 293 | logger.info("Creating trainer...") 294 | trainer = SegmentationTrainer( 295 | model=model, 296 | train_loader=train_loader, 297 | val_loader=val_loader, 298 | criterion=criterion, 299 | optimizer=optimizer, 300 | config=config, 301 | device=device 302 | ) 303 | 304 | # Resume from checkpoint if specified 305 | if args.resume: 306 | logger.info(f"Resuming from checkpoint: {args.resume}") 307 | trainer.load_checkpoint(args.resume) 308 | 309 | # Start training 310 | 311 | trainer.train() 312 | 313 | # Plot training history 314 | logger.info("Generating training history plots...") 315 | training_stats = trainer.get_training_stats() 316 | 317 | history_plot_path = os.path.join(log_dir, 'training_history.png') 318 | 319 | plot_training_history( 320 | train_losses=training_stats['train_losses'], 321 | val_losses=training_stats['val_losses'], 322 | val_metrics=training_stats['val_mious'], 323 | metric_name='mIoU', 
324 | save_path=history_plot_path 325 | ) 326 | 327 | logger.info("Training completed successfully!") 328 | 329 | 330 | if __name__ == "__main__": 331 | main() 332 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss Functions for Semantic Segmentation 3 | 4 | This module contains various loss functions commonly used in semantic segmentation, 5 | including Cross Entropy, Focal Loss, Dice Loss, IoU Loss, and OHEM. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from typing import Optional, List, Tuple 12 | import numpy as np 13 | 14 | 15 | class CrossEntropyLoss(nn.Module): 16 | """ 17 | Standard Cross Entropy Loss for semantic segmentation. 18 | 19 | Args: 20 | ignore_index (int): Index to ignore in loss calculation 21 | weight (Optional[torch.Tensor]): Class weights for handling imbalanced datasets 22 | reduction (str): Reduction method ('mean', 'sum', 'none') 23 | """ 24 | 25 | def __init__(self, ignore_index: int = 255, weight: Optional[torch.Tensor] = None, 26 | reduction: str = 'mean'): 27 | super(CrossEntropyLoss, self).__init__() 28 | self.ignore_index = ignore_index 29 | self.weight = weight 30 | self.reduction = reduction 31 | 32 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 33 | """ 34 | Forward pass of Cross Entropy Loss. 35 | 36 | Args: 37 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 38 | target (torch.Tensor): Ground truth of shape (B, H, W) 39 | 40 | Returns: 41 | torch.Tensor: Loss value 42 | """ 43 | return F.cross_entropy( 44 | pred, target, 45 | weight=self.weight, 46 | ignore_index=self.ignore_index, 47 | reduction=self.reduction 48 | ) 49 | 50 | 51 | class DiceLoss(nn.Module): 52 | """ 53 | Dice Loss for semantic segmentation, particularly effective for small objects. 54 | 55 | Args: 56 | smooth (float): Smoothing factor to avoid division by zero 57 | ignore_index (int): Index to ignore in loss calculation 58 | reduction (str): Reduction method ('mean', 'sum', 'none') 59 | """ 60 | 61 | def __init__(self, smooth: float = 1e-5, ignore_index: int = 255, reduction: str = 'mean'): 62 | super(DiceLoss, self).__init__() 63 | self.smooth = smooth 64 | self.ignore_index = ignore_index 65 | self.reduction = reduction 66 | 67 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 68 | """ 69 | Forward pass of Dice Loss. 
70 | 71 | Args: 72 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 73 | target (torch.Tensor): Ground truth of shape (B, H, W) 74 | 75 | Returns: 76 | torch.Tensor: Loss value 77 | """ 78 | # Convert predictions to probabilities 79 | pred = F.softmax(pred, dim=1) 80 | 81 | # One-hot encode target (ignored pixels are remapped to class 0 so F.one_hot does not fail; the valid-pixel mask below zeroes them out) 82 | num_classes = pred.shape[1] 83 | target_one_hot = F.one_hot(torch.where(target == self.ignore_index, torch.zeros_like(target), target), num_classes).permute(0, 3, 1, 2).float() 84 | 85 | # Create mask for valid pixels 86 | mask = (target != self.ignore_index).unsqueeze(1).float() 87 | pred = pred * mask 88 | target_one_hot = target_one_hot * mask 89 | 90 | # Compute Dice coefficient 91 | intersection = (pred * target_one_hot).sum(dim=(2, 3)) 92 | union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3)) 93 | 94 | dice_coeff = (2.0 * intersection + self.smooth) / (union + self.smooth) 95 | dice_loss = 1.0 - dice_coeff 96 | 97 | if self.reduction == 'mean': 98 | return dice_loss.mean() 99 | elif self.reduction == 'sum': 100 | return dice_loss.sum() 101 | else: 102 | return dice_loss 103 | 104 | 105 | class IoULoss(nn.Module): 106 | """ 107 | IoU (Intersection over Union) Loss for semantic segmentation. 108 | 109 | Args: 110 | smooth (float): Smoothing factor to avoid division by zero 111 | ignore_index (int): Index to ignore in loss calculation 112 | reduction (str): Reduction method ('mean', 'sum', 'none') 113 | """ 114 | 115 | def __init__(self, smooth: float = 1e-5, ignore_index: int = 255, reduction: str = 'mean'): 116 | super(IoULoss, self).__init__() 117 | self.smooth = smooth 118 | self.ignore_index = ignore_index 119 | self.reduction = reduction 120 | 121 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 122 | """ 123 | Forward pass of IoU Loss. 124 | 125 | Args: 126 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 127 | target (torch.Tensor): Ground truth of shape (B, H, W) 128 | 129 | Returns: 130 | torch.Tensor: Loss value 131 | """ 132 | # Convert predictions to probabilities 133 | pred = F.softmax(pred, dim=1) 134 | 135 | # One-hot encode target (ignored pixels are remapped to class 0 so F.one_hot does not fail; the valid-pixel mask below zeroes them out) 136 | num_classes = pred.shape[1] 137 | target_one_hot = F.one_hot(torch.where(target == self.ignore_index, torch.zeros_like(target), target), num_classes).permute(0, 3, 1, 2).float() 138 | 139 | # Create mask for valid pixels 140 | mask = (target != self.ignore_index).unsqueeze(1).float() 141 | pred = pred * mask 142 | target_one_hot = target_one_hot * mask 143 | 144 | # Compute IoU 145 | intersection = (pred * target_one_hot).sum(dim=(2, 3)) 146 | union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3)) - intersection 147 | 148 | iou = (intersection + self.smooth) / (union + self.smooth) 149 | iou_loss = 1.0 - iou 150 | 151 | if self.reduction == 'mean': 152 | return iou_loss.mean() 153 | elif self.reduction == 'sum': 154 | return iou_loss.sum() 155 | else: 156 | return iou_loss 157 | 158 | 159 | class OHEMLoss(nn.Module): 160 | """ 161 | Online Hard Example Mining (OHEM) Loss for focusing on hard examples.
162 | 163 | Args: 164 | thresh (float): Threshold for hard example selection 165 | min_kept (int): Minimum number of pixels to keep 166 | ignore_index (int): Index to ignore in loss calculation 167 | base_loss (str): Base loss function ('ce', 'focal') 168 | """ 169 | 170 | def __init__(self, thresh: float = 0.7, min_kept: int = 100000, 171 | ignore_index: int = 255, base_loss: str = 'ce'): 172 | super(OHEMLoss, self).__init__() 173 | self.thresh = thresh 174 | self.min_kept = min_kept 175 | self.ignore_index = ignore_index 176 | 177 | if base_loss == 'ce': 178 | self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none') 179 | else: 180 | raise ValueError(f"Unsupported base loss: {base_loss}") 181 | 182 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 183 | """ 184 | Forward pass of OHEM Loss. 185 | 186 | Args: 187 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 188 | target (torch.Tensor): Ground truth of shape (B, H, W) 189 | 190 | Returns: 191 | torch.Tensor: Loss value 192 | """ 193 | # Compute pixel-wise loss 194 | pixel_losses = self.criterion(pred, target) 195 | 196 | # Create mask for valid pixels 197 | mask = (target != self.ignore_index).float() 198 | pixel_losses = pixel_losses * mask 199 | 200 | # Sort losses in descending order 201 | sorted_losses, _ = torch.sort(pixel_losses.view(-1), descending=True) 202 | 203 | # Determine number of pixels to keep 204 | valid_pixels = mask.sum().int().item() 205 | keep_num = max(self.min_kept, int(valid_pixels * self.thresh)) 206 | keep_num = min(keep_num, valid_pixels) 207 | 208 | # Keep only the hardest examples 209 | if keep_num < valid_pixels: 210 | threshold = sorted_losses[keep_num] 211 | hard_mask = (pixel_losses >= threshold).float() 212 | return (pixel_losses * hard_mask).sum() / hard_mask.sum() 213 | else: 214 | return pixel_losses.sum() / mask.sum() 215 | 216 | 217 | class BCEWithLogitsLoss(nn.Module): 218 | """ 219 | Binary Cross Entropy with Logits Loss for semantic segmentation. 220 | For multi-class segmentation, this applies BCE to each class independently. 221 | 222 | Args: 223 | ignore_index (int): Index to ignore in loss calculation 224 | weight (Optional[torch.Tensor]): Class weights for handling imbalanced datasets 225 | reduction (str): Reduction method ('mean', 'sum', 'none') 226 | pos_weight (Optional[torch.Tensor]): Weight of positive examples 227 | """ 228 | 229 | def __init__(self, ignore_index: int = 255, weight: Optional[torch.Tensor] = None, 230 | reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None): 231 | super(BCEWithLogitsLoss, self).__init__() 232 | self.ignore_index = ignore_index 233 | self.weight = weight 234 | self.reduction = reduction 235 | self.pos_weight = pos_weight 236 | self.bce_loss = nn.BCEWithLogitsLoss(weight=weight, reduction='none', pos_weight=pos_weight) 237 | 238 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 239 | """ 240 | Forward pass of BCE with Logits Loss. 
241 | 242 | Args: 243 | pred (torch.Tensor): Predictions of shape (B, C, H, W) 244 | target (torch.Tensor): Ground truth of shape (B, H, W) 245 | 246 | Returns: 247 | torch.Tensor: Loss value 248 | """ 249 | # Convert target to one-hot encoding 250 | num_classes = pred.shape[1] 251 | target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float() 252 | 253 | # Create mask for valid pixels 254 | mask = (target != self.ignore_index).unsqueeze(1).float() 255 | 256 | # Apply BCE loss to each class 257 | loss = self.bce_loss(pred, target_one_hot) 258 | 259 | # Apply mask to ignore invalid pixels 260 | loss = loss * mask 261 | 262 | if self.reduction == 'mean': 263 | # Average over valid pixels only 264 | return loss.sum() / mask.sum() 265 | elif self.reduction == 'sum': 266 | return loss.sum() 267 | else: 268 | return loss 269 | 270 | 271 | def get_loss_function(loss_config: dict) -> nn.Module: 272 | """ 273 | Factory function to create loss function based on configuration. 274 | 275 | Args: 276 | loss_config (dict): Loss configuration dictionary 277 | 278 | Returns: 279 | nn.Module: Loss function 280 | """ 281 | loss_type = loss_config.get('type', 'cross_entropy').lower() 282 | ignore_index = loss_config.get('ignore_index', 255) 283 | 284 | if loss_type == 'cross_entropy' or loss_type == 'ce': 285 | weight = loss_config.get('class_weights') 286 | if weight is not None: 287 | weight = torch.tensor(weight, dtype=torch.float32) 288 | return CrossEntropyLoss(ignore_index=ignore_index, weight=weight) 289 | 290 | elif loss_type == 'dice': 291 | smooth = loss_config.get('dice_smooth', 1e-5) 292 | return DiceLoss(smooth=smooth, ignore_index=ignore_index) 293 | 294 | elif loss_type == 'iou': 295 | smooth = loss_config.get('iou_smooth', 1e-5) 296 | return IoULoss(smooth=smooth, ignore_index=ignore_index) 297 | 298 | elif loss_type == 'ohem': 299 | thresh = loss_config.get('ohem_thresh', 0.7) 300 | min_kept = loss_config.get('ohem_min_kept', 100000) 301 | base_loss = loss_config.get('ohem_base_loss', 'ce') 302 | return OHEMLoss(thresh=thresh, min_kept=min_kept, 303 | ignore_index=ignore_index, base_loss=base_loss) 304 | 305 | elif loss_type == 'bce_with_logits' or loss_type == 'bce': 306 | weight = loss_config.get('class_weights') 307 | if weight is not None: 308 | weight = torch.tensor(weight, dtype=torch.float32) 309 | pos_weight = loss_config.get('pos_weight') 310 | if pos_weight is not None: 311 | pos_weight = torch.tensor(pos_weight, dtype=torch.float32) 312 | return BCEWithLogitsLoss(ignore_index=ignore_index, weight=weight, pos_weight=pos_weight) 313 | 314 | else: 315 | raise ValueError(f"Unsupported loss type: {loss_type}") 316 | 317 | 318 | if __name__ == "__main__": 319 | # Test loss functions 320 | batch_size, num_classes, height, width = 2, 19, 64, 64 321 | 322 | # Create dummy data 323 | pred = torch.randn(batch_size, num_classes, height, width) 324 | target = torch.randint(0, num_classes, (batch_size, height, width)) 325 | 326 | # Test different loss functions 327 | losses = { 328 | 'CrossEntropy': CrossEntropyLoss(), 329 | 'Dice': DiceLoss(), 330 | 'IoU': IoULoss(), 331 | 'OHEM': OHEMLoss(), 332 | } 333 | 334 | for loss_name, loss_fn in losses.items(): 335 | loss_value = loss_fn(pred, target) 336 | print(f"{loss_name} Loss: {loss_value.item():.4f}") 337 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🩺 Pathology Foundation Models Meet Semantic 
Segmentation 2 | 3 | A comprehensive semantic segmentation framework based on Pathology Foundation Models (PFMs), designed specifically for pathological image analysis, supporting multiple state-of-the-art pathology foundation models with complete training, inference, and evaluation capabilities. 4 | 5 | ## 🌟 Features 6 | 7 | - 🧬 **Support for SOTA Pathology Foundation Models**: uni_v1, uni_v2, conch_v1_5, gigapath, virchow_v2 8 | - 🔧 **Flexible Fine-tuning Strategies**: LoRA, full parameter fine-tuning, frozen backbone 9 | - 📊 **Complete Training Pipeline**: Mixed precision training, learning rate scheduling, gradient accumulation 10 | - 🎯 **Advanced Data Augmentation**: Integrated 10+ advanced data augmentations including spatial, color, and noise transformations 11 | - 📈 **Comprehensive Evaluation Metrics**: Integrated 10+ evaluation metrics including IoU/Dice and more 12 | - ⚡ **Advanced Inference Pipeline**: Support for arbitrary resolution sliding window inference 13 | 14 | ## 📋 Table of Contents 15 | 16 | - [Dataset Format](#-dataset-format) 17 | - [Configuration File Details](#-configuration-file-details) 18 | - [Training Script Usage](#-training-script-usage) 19 | - [Inference Script Usage](#-inference-script-usage) 20 | - [Pathology Foundation Models Details](#-pathology-foundation-models-details) 21 | 22 | ## 📁 Dataset Format 23 | 24 | ### JSON Configuration File Format 25 | 26 | The dataset uses JSON format for configuration, supporting train, validation, and test set splits: 27 | 28 | ```json 29 | { 30 | "num_classes": 3, 31 | "data": { 32 | "train": [ 33 | { 34 | "image_path": "/path/to/train/image1.jpg", 35 | "mask_path": "/path/to/train/mask1.png" 36 | }, 37 | ], 38 | "val": [ 39 | { 40 | "image_path": "/path/to/val/image1.jpg", 41 | "mask_path": "/path/to/val/mask1.png" 42 | } 43 | ], 44 | "test": [ 45 | { 46 | "image_path": "/path/to/test/image1.jpg", 47 | "mask_path": "/path/to/test/image2.png" 48 | } 49 | ] 50 | } 51 | } 52 | ``` 53 | 54 | During training, only the `train` and `val` fields are used. The `test` field is used when executing inference scripts. The `mask_path` in the test field can be null or missing, in which case the model will not compute metrics. If `mask_path` exists, metrics will be automatically calculated after inference. 
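If you need to generate this JSON from folders of paired images and masks, the snippet below is a minimal sketch. The directory layout (`<root>/<split>/images`, `<root>/<split>/masks`), the helper name `build_dataset_json`, and the stem-based pairing are illustrative assumptions, not part of this repository; adapt them to your own data organization.

```python
import json
from pathlib import Path


def build_dataset_json(root: str, out_path: str, num_classes: int) -> None:
    """Write a dataset JSON in the format described above (assumed folder layout)."""
    data = {}
    for split in ("train", "val", "test"):
        images_dir = Path(root) / split / "images"
        masks_dir = Path(root) / split / "masks"
        entries = []
        for img in sorted(images_dir.glob("*")):
            mask = masks_dir / f"{img.stem}.png"
            entries.append({
                "image_path": str(img),
                # For the test split, mask_path may be null; metrics are then skipped at inference.
                "mask_path": str(mask) if mask.exists() else None,
            })
        data[split] = entries

    with open(out_path, "w") as f:
        json.dump({"num_classes": num_classes, "data": data}, f, indent=2)


if __name__ == "__main__":
    build_dataset_json("/path/to/dataset_root", "dataset.json", num_classes=3)
```

Whatever tool you use to produce the file, make sure `num_classes` here matches `dataset.num_classes` in the YAML configuration described in the next section.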
55 | 56 | ## ⚙️ Configuration File Details 57 | 58 | The configuration file uses YAML format and includes the following main sections: 59 | 60 | ### Dataset Configuration (dataset) 61 | 62 | ```yaml 63 | dataset: 64 | json_file: "/path/to/dataset.json" # Path to dataset JSON configuration file 65 | num_classes: 3 # Number of classes, must match JSON file 66 | ignore_index: 255 # Pixel value to ignore for uncertain regions 67 | ``` 68 | 69 | ### System Configuration (system) 70 | 71 | ```yaml 72 | system: 73 | num_workers: 4 # Number of processes for data loading 74 | pin_memory: true # Whether to use pin_memory for faster data transfer 75 | seed: 42 # Random seed for reproducible experiments 76 | device: "cuda:0" # Device to use 77 | ``` 78 | 79 | ### Pathology Foundation Model Configuration (model) 🧬 80 | 81 | This is the most important section, controlling the selection and configuration of pathology foundation models: 82 | 83 | ```yaml 84 | model: 85 | # === Base Model Selection === 86 | pfm_name: "uni_v1" # Pathology foundation model name 87 | # Options: 88 | # - "uni_v1" : UNI model version 1 (1024 dim) 89 | # - "uni_v2" : UNI model version 2 (1536 dim) 90 | # - "conch_v1" : Conch model version 1 (1024 dim) 91 | # - "conch_v1_5" : Conch model version 1.5 (768 dim) 92 | # - "virchow_v1" : Virchow model version 2 (1280 dim) 93 | # - "virchow_v2" : Virchow model version 2 (1280 dim) 94 | # - "phikon" : Phikon model (768 dim) 95 | # - "phikon_v2" : Phikon-v2 model (1024 dim) 96 | # - "hoptimus_0" : H-Optimus-0 model (1536 dim) 97 | # - "hoptimus_1" : H-Optimus-1 model (1536 dim) 98 | # - "gigapath" : Gigapath model (1536 dim) 99 | # - "midnight12k" : Midnight-12k model (1536 dim) 100 | # - "kaiko-vitl14" : Kaiko-ViT-L14 model (1024 dim) 101 | # - "lunit_vits8" : Lunit-S8 model (384 dim) 102 | # - 'musk' : MUSK model (1024 dim) 103 | # - "patho3dmatrix-vision": Patho3DMatrix-Vision model (1536 dim) 104 | # - "pathorchestra": PathOrchestra model (1536 dim) 105 | # - "hibou_l" : Hibou-Large model (1024 dim) 106 | 107 | 108 | # === Model Parameter Configuration === 109 | emb_dim: 1024 # Embedding dimension, must match selected PFM model 110 | # Corresponding embedding dimensions for each model: 111 | # midnight12k/hoptimus_0/hoptimus_1/uni_v2/gigapath: 1536 112 | # virchow_v1/virchow_v2: 1280 113 | # uni_v1/hibou_l/musk/phikon_v2/kaiko-vitl14/conch_v1/patho3dmatrix-vision: 1024 114 | # conch_v1_5/phikon: 768 115 | # lunit_vits8: 384 116 | 117 | pfm_weights_path: '/path/to/pytorch_model.bin' # Path to pre-trained weights file 118 | 119 | # === Fine-tuning Strategy Configuration === 120 | finetune_mode: 121 | type: "lora" # Fine-tuning mode 122 | # Options: 123 | # - "lora" : LoRA low-rank adaptation, parameter efficient 124 | # - "full" : Full parameter fine-tuning, best performance but requires more memory 125 | # - "frozen" : Frozen backbone, only train segmentation head 126 | 127 | rank: 16 # LoRA rank, only used when type is "lora" 128 | alpha: 1.0 # LoRA scaling factor, only used when type is "lora" 129 | 130 | num_classes: 3 # Number of segmentation classes, must match dataset.num_classes 131 | ``` 132 | 133 | ### Training Configuration (training) 134 | 135 | ```yaml 136 | training: 137 | # === Basic Training Parameters === 138 | batch_size: 8 # Batch size 139 | epochs: 100 # Number of training epochs 140 | learning_rate: 0.001 # Initial learning rate 141 | weight_decay: 0.0001 # Weight decay 142 | 143 | # === Training Optimization Settings === 144 | use_amp: true # Whether to use mixed 
precision training 145 | accumulate_grad_batches: 1 # Number of gradient accumulation steps 146 | clip_grad_norm: 5.0 # Gradient clipping threshold 147 | 148 | # === Data Augmentation Configuration === 149 | augmentation: 150 | RandomResizedCropSize: 512 # Random crop size 151 | # Note: Different PFM models have input size requirements 152 | # virchow_v1,virchow_v2,uni_v2,midnight12k,kaiko-vitl14,hibou_l,H-optimus-1,H-optimus-0: must be a multiple of 14 (token_size) 153 | # uni_v1,conch_v1_5,gigapath,conch_v1 ,phikon,phikon_v2,patho3dmatrix-vision: must be a multiple of 16 (token_size) 154 | # lunit_vits8: must be a multiple of 8 (token_size) 155 | # special: musk: 384 156 | 157 | # === Optimizer Configuration === 158 | optimizer: 159 | type: "SGD" # Optimizer type: SGD, Adam, AdamW 160 | momentum: 0.9 # SGD momentum (SGD only) 161 | nesterov: true # Whether to use Nesterov momentum 162 | 163 | # === Learning Rate Scheduler === 164 | scheduler: 165 | type: "cosine" # Scheduler type: cosine, step 166 | warmup_epochs: 2 # Number of warmup epochs 167 | 168 | # === Loss Function === 169 | loss: 170 | type: "dice" # Loss function: cross_entropy, dice, ohem, iou 171 | ``` 172 | 173 | ### Validation Configuration (validation) 174 | 175 | ```yaml 176 | validation: 177 | eval_interval: 1 # Validate every N epochs 178 | batch_size: 16 # Validation batch size 179 | augmentation: 180 | ResizedSize: 512 # Image size during validation 181 | # Note: Different PFM models have input size requirements 182 | # virchow_v1,virchow_v2,uni_v2,midnight12k,kaiko-vitl14,hibou_l,H-optimus-1,H-optimus-0: must be a multiple of 14 (token_size) 183 | # uni_v1,conch_v1_5,gigapath,conch_v1 ,phikon,phikon_v2,patho3dmatrix-vision: must be a multiple of 16 (token_size) 184 | # lunit_vits8: must be a multiple of 8 (token_size) 185 | # special: musk: 384 186 | ``` 187 | 188 | ### Logging and Visualization Configuration 189 | 190 | ```yaml 191 | logging: 192 | log_dir: "/path/to/logs" # Log save directory 193 | experiment_name: "pfm_segmentation" # Experiment name 194 | 195 | visualization: 196 | save_interval: 2 # Save visualization results every N epochs 197 | num_vis_samples: 8 # Number of visualization samples to save 198 | ``` 199 | 200 | ## 🚀 Training Script Usage 201 | 202 | ### Basic Training Command 203 | 204 | ```bash 205 | python scripts/train.py --config configs/config.yaml 206 | ``` 207 | 208 | ### Training Script Parameters Details 209 | 210 | ```bash 211 | python scripts/train.py \ 212 | --config configs/config.yaml \ # Configuration file path 213 | --resume checkpoints/model.pth \ # Resume training from checkpoint (optional) 214 | --device cuda:0 # Specify device (optional, overrides config file) 215 | ``` 216 | 217 | ### Parameter Description 218 | 219 | - `--config`: **Required** Configuration file path containing all training settings 220 | - `--resume`: **Optional** Checkpoint file path for resuming interrupted training 221 | - `--device`: **Optional** Training device, overrides device setting in config file 222 | 223 | ### Training Output 224 | 225 | During training, the following files will be generated: 226 | 227 | ``` 228 | logs/experiment_name/ 229 | ├── config.yaml # Saved copy of configuration file 230 | ├── training.log # Training log 231 | ├── checkpoints/ # Model checkpoints 232 | │ ├── best_model.pth # Best model 233 | ├── visualizations/ # Visualization results 234 | │ ├── epoch_010_sample_00.png 235 | │ └── ... 
236 | └── training_history.png # Training curve plot 237 | ``` 238 | 239 | ### Training Monitoring 240 | 241 | During training, the following will be displayed: 242 | - Training loss and validation loss 243 | - Validation metrics (mIoU, Pixel Accuracy, etc.) 244 | - Learning rate changes 245 | - Time consumption per epoch 246 | 247 | ## 🔍 Inference Script Usage 248 | 249 | ### Basic Inference Command 250 | 251 | ```bash 252 | python scripts/infer.py \ 253 | --config logs/experiment_name/config.yaml \ 254 | --checkpoint logs/experiment_name/checkpoints/best_model.pth \ 255 | --input_json dataset/test.json \ 256 | --output_dir results/ 257 | ``` 258 | 259 | ### Inference Script Parameters Details 260 | 261 | ```bash 262 | python scripts/infer.py \ 263 | --config CONFIG_PATH \ # Configuration file used during training 264 | --checkpoint CHECKPOINT_PATH \ # Trained model weights 265 | --input_json INPUT_JSON \ # Input data JSON file 266 | --output_dir OUTPUT_DIR \ # Results save directory 267 | --device cuda:0 \ # Inference device 268 | --input_size 512 \ # Input image size 269 | --resize_or_windowslide windowslide \ # Inference mode 270 | --batch_size 4 # Inference batch size 271 | ``` 272 | 273 | ### Detailed Parameter Description 274 | 275 | | Parameter | Type | Required | Description | 276 | |-----------|------|----------|-------------| 277 | | `--config` | str | ✅ | Configuration file path used during training | 278 | | `--checkpoint` | str | ✅ | Path to model checkpoint file or checkpoint directory. For LoRA/DoRA mode, will automatically load both base model and LoRA/DoRA weights. | 279 | | `--input_json` | str | ✅ | JSON file containing data to be inferred | 280 | | `--output_dir` | str | ✅ | Inference results save directory | 281 | | `--device` | str | ✅ | Inference device, default cuda:0 | 282 | | `--input_size` | int | ✅ | Input image size for model, not original image size | 283 | | `--resize_or_windowslide` | str | ✅ | Inference mode, default windowslide | 284 | | `--batch_size` | int | ✅ | Inference batch size, default 2 | 285 | 286 | ### Inference Mode Selection 287 | 288 | 1. **Resize Mode** (`--resize_or_windowslide resize`) 289 | - Resize input images to fixed size (input_size) for inference 290 | - Resize prediction results back to original image size after inference 291 | 292 | 2. **Window Slide Mode** (`--resize_or_windowslide windowslide`) 293 | - Use sliding window (input_size) strategy to process large images 294 | - Maintains original resolution with higher accuracy 295 | - Merge back to original image size after inference 296 | 297 | ### Inference Output 298 | 299 | After inference completion, the following will be generated: 300 | 301 | ``` 302 | output_dir/ 303 | ├── predictions_masks/ # Prediction masks (grayscale images) 304 | │ ├── image001.png 305 | │ ├── image002.png 306 | │ └── ... 307 | └── predictions_overlays/ # Prediction result visualizations (colored overlay images) 308 | ├── image001.png 309 | ├── image002.png 310 | └── ... 
311 | ``` 312 | 313 | ### Inference Result Format 314 | 315 | - **Prediction Masks**: Grayscale PNG images with pixel values corresponding to class indices 316 | - **Visualization Overlays**: Colored overlays of original images with prediction results for intuitive viewing 317 | 318 | ## 🧬 Pathology Foundation Models Details 319 | 320 | ### Supported Models List 321 | 322 | | Model Name | Parameters | Embedding Dim | Token Size | HuggingFace | 323 | |------------|------------|---------------|------------|-------------| 324 | | UNI | 307M | 1024 | 16×16 | [MahmoodLab/UNI](https://huggingface.co/MahmoodLab/UNI) | 325 | | UNI2-h | 1.1B | 1536 | 14×14 | [MahmoodLab/UNI2-h](https://huggingface.co/MahmoodLab/UNI2-h) | 326 | | CONCH | 90M | 768 | 16×16 | [MahmoodLab/CONCH](https://huggingface.co/MahmoodLab/CONCH) | 327 | | CONCHv1.5 | 307M | 1024 | 16×16 | [MahmoodLab/conchv1_5](https://huggingface.co/MahmoodLab/conchv1_5) | 328 | | Virchow | 632M | 2560 | 14×14 | [paige-ai/Virchow](https://huggingface.co/paige-ai/Virchow) | 329 | | Virchow2 | 632M | 2560 | 14×14 | [paige-ai/Virchow2](https://huggingface.co/paige-ai/Virchow2) | 330 | | Phikon | 85.8M | 768 | 16×16 | [owkin/phikon](https://huggingface.co/owkin/phikon) | 331 | | Phikon-v2 | 300M | 1024 | 16×16 | [owkin/phikon-v2](https://huggingface.co/owkin/phikon-v2) | 332 | | Prov-Gigapath | 1.1B | 1536 | 16×16 | [prov-gigapath/prov-gigapath](https://huggingface.co/prov-gigapath/prov-gigapath) | 333 | | H-Optimus-0 | 1.1B | 1536 | 14×14 | [bioptimus/H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) | 334 | | H-Optimus-1 | 1.1B | 1536 | 14×14 | [bioptimus/H-optimus-1](https://huggingface.co/bioptimus/H-optimus-1) | 335 | | MUSK | - | 1024 | 32×32 | [xiangjx/musk](https://huggingface.co/xiangjx/musk) | 336 | | Midnight-12k | - | 3072 | 14×14 | [kaiko-ai/midnight](https://huggingface.co/kaiko-ai/midnight) | 337 | | Kaiko | Various | 384/768/1024 | Various (8×8 or 16×16 or 14×14) | [1aurent/kaikoai-models-66636c99d8e1e34bc6dcf795](https://huggingface.co/collections/1aurent/kaikoai-models) | 338 | | Lunit | 21.7M | 384 | 8×8 | [1aurent/vit_small_patch8_224.lunit_dino](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino) | 339 | | Hibou | - | 1024 | 14×14 | [histai/hibou-L](https://huggingface.co/histai/hibou-L) | 340 | | PathOrchestra | 307M | 1024 | 16×16 | [AI4Pathology/PathOrchestra](https://huggingface.co/AI4Pathology/PathOrchestra) | 341 | | patho3dmatrix-vision | 307M | 1024 | 16×16 | - | 342 | 343 | ## 🤝 Contributing 344 | 345 | Welcome to submit issues and feature requests! Please check the contribution guidelines for more information. 346 | 347 | ## 📞 Contact 348 | 349 | If you have questions or suggestions, please contact us through: 350 | - Submit GitHub Issue 351 | - Send email to: [lingxt23@mails.tsinghua.edu.cn] or [cwm25@mails.tsinghua.edu.cn] 352 | 353 | --- 354 | 355 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation Metrics for Semantic Segmentation 3 | 4 | This module contains comprehensive evaluation metrics including 5 | IoU, Pixel Accuracy, Dice Score, and class-wise statistics. 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | from typing import Dict, List, Optional, Tuple 11 | import torch.nn.functional as F 12 | 13 | 14 | class SegmentationMetrics: 15 | """ 16 | Comprehensive metrics for semantic segmentation evaluation. 
17 | 18 | Computes: 19 | - Mean IoU (mIoU) 20 | - Pixel Accuracy 21 | - Mean Accuracy 22 | - Frequency Weighted IoU 23 | - Per-class IoU and Accuracy 24 | - Dice Score 25 | - Precision and Recall 26 | 27 | Args: 28 | num_classes (int): Number of classes 29 | ignore_index (int): Index to ignore in calculations 30 | device (str): Device for computations 31 | """ 32 | 33 | def __init__(self, num_classes: int, ignore_index: int = 255, device: str = 'cpu'): 34 | self.num_classes = num_classes 35 | self.ignore_index = ignore_index 36 | self.device = device 37 | 38 | self.reset() 39 | 40 | def reset(self): 41 | """Reset all metrics.""" 42 | self.confusion_matrix = torch.zeros( 43 | (self.num_classes, self.num_classes), 44 | dtype=torch.int64, 45 | device=self.device 46 | ) 47 | self.total_samples = 0 48 | 49 | def update(self, predictions: torch.Tensor, targets: torch.Tensor): 50 | """ 51 | Update metrics with new predictions and targets. 52 | 53 | Args: 54 | predictions (torch.Tensor): Model predictions of shape (B, C, H, W) 55 | targets (torch.Tensor): Ground truth labels of shape (B, H, W) 56 | """ 57 | # Convert predictions to class indices 58 | if predictions.dim() == 4: # (B, C, H, W) 59 | predictions = torch.argmax(predictions, dim=1) 60 | 61 | # Flatten tensors 62 | predictions = predictions.flatten() 63 | targets = targets.flatten() 64 | 65 | # Create mask for valid pixels 66 | mask = (targets != self.ignore_index) 67 | predictions = predictions[mask] 68 | targets = targets[mask] 69 | 70 | # Update confusion matrix 71 | indices = self.num_classes * targets + predictions 72 | cm_update = torch.bincount(indices, minlength=self.num_classes**2) 73 | cm_update = cm_update.reshape(self.num_classes, self.num_classes) 74 | 75 | self.confusion_matrix += cm_update.to(self.device) 76 | self.total_samples += mask.sum().item() 77 | 78 | def compute_iou(self) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Compute IoU metrics. 81 | 82 | Returns: 83 | Tuple[torch.Tensor, torch.Tensor]: (per_class_iou, mean_iou) 84 | """ 85 | # IoU = TP / (TP + FP + FN) 86 | # TP: diagonal elements 87 | # FP: column sum - diagonal 88 | # FN: row sum - diagonal 89 | 90 | tp = torch.diag(self.confusion_matrix).float() 91 | fp = self.confusion_matrix.sum(dim=0) - tp 92 | fn = self.confusion_matrix.sum(dim=1) - tp 93 | 94 | # Avoid division by zero 95 | denominator = tp + fp + fn 96 | iou = tp / (denominator + 1e-8) 97 | 98 | # Set IoU to 0 for classes that don't appear in ground truth 99 | valid_classes = (denominator > 0) 100 | iou = iou * valid_classes.float() 101 | 102 | mean_iou = iou[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 103 | 104 | return iou, mean_iou 105 | 106 | def compute_pixel_accuracy(self) -> torch.Tensor: 107 | """ 108 | Compute pixel accuracy. 109 | 110 | Returns: 111 | torch.Tensor: Pixel accuracy 112 | """ 113 | correct_pixels = torch.diag(self.confusion_matrix).sum() 114 | total_pixels = self.confusion_matrix.sum() 115 | 116 | return correct_pixels / (total_pixels + 1e-8) 117 | 118 | def compute_mean_accuracy(self) -> torch.Tensor: 119 | """ 120 | Compute mean class accuracy. 
121 | 122 | Returns: 123 | torch.Tensor: Mean accuracy 124 | """ 125 | # Class accuracy = TP / (TP + FN) 126 | tp = torch.diag(self.confusion_matrix).float() 127 | total_per_class = self.confusion_matrix.sum(dim=1).float() 128 | 129 | class_accuracy = tp / (total_per_class + 1e-8) 130 | 131 | # Only consider classes that appear in ground truth 132 | valid_classes = (total_per_class > 0) 133 | mean_accuracy = class_accuracy[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 134 | 135 | return mean_accuracy 136 | 137 | def compute_frequency_weighted_iou(self) -> torch.Tensor: 138 | """ 139 | Compute frequency weighted IoU. 140 | 141 | Returns: 142 | torch.Tensor: Frequency weighted IoU 143 | """ 144 | iou, _ = self.compute_iou() 145 | 146 | # Class frequencies 147 | class_frequencies = self.confusion_matrix.sum(dim=1).float() 148 | total_pixels = class_frequencies.sum() 149 | weights = class_frequencies / (total_pixels + 1e-8) 150 | 151 | # Weighted IoU 152 | fwiou = (weights * iou).sum() 153 | 154 | return fwiou 155 | 156 | def compute_dice_score(self) -> Tuple[torch.Tensor, torch.Tensor]: 157 | """ 158 | Compute Dice score metrics. 159 | 160 | Returns: 161 | Tuple[torch.Tensor, torch.Tensor]: (per_class_dice, mean_dice) 162 | """ 163 | # Dice = 2 * TP / (2 * TP + FP + FN) 164 | tp = torch.diag(self.confusion_matrix).float() 165 | fp = self.confusion_matrix.sum(dim=0) - tp 166 | fn = self.confusion_matrix.sum(dim=1) - tp 167 | 168 | dice = (2 * tp) / (2 * tp + fp + fn + 1e-8) 169 | 170 | # Set Dice to 0 for classes that don't appear 171 | valid_classes = ((tp + fp + fn) > 0) 172 | dice = dice * valid_classes.float() 173 | 174 | mean_dice = dice[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 175 | 176 | return dice, mean_dice 177 | 178 | def compute_precision_recall(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 179 | """ 180 | Compute precision and recall metrics. 181 | 182 | Returns: 183 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 184 | (per_class_precision, mean_precision, per_class_recall, mean_recall) 185 | """ 186 | tp = torch.diag(self.confusion_matrix).float() 187 | fp = self.confusion_matrix.sum(dim=0) - tp 188 | fn = self.confusion_matrix.sum(dim=1) - tp 189 | 190 | # Precision = TP / (TP + FP) 191 | precision = tp / (tp + fp + 1e-8) 192 | valid_precision = ((tp + fp) > 0) 193 | precision = precision * valid_precision.float() 194 | mean_precision = precision[valid_precision].mean() if valid_precision.any() else torch.tensor(0.0) 195 | 196 | # Recall = TP / (TP + FN) 197 | recall = tp / (tp + fn + 1e-8) 198 | valid_recall = ((tp + fn) > 0) 199 | recall = recall * valid_recall.float() 200 | mean_recall = recall[valid_recall].mean() if valid_recall.any() else torch.tensor(0.0) 201 | 202 | return precision, mean_precision, recall, mean_recall 203 | 204 | def compute_f1_score(self) -> Tuple[torch.Tensor, torch.Tensor]: 205 | """ 206 | Compute F1 score metrics. 
207 | 208 | Returns: 209 | Tuple[torch.Tensor, torch.Tensor]: (per_class_f1, mean_f1) 210 | """ 211 | precision, _, recall, _ = self.compute_precision_recall() 212 | 213 | f1 = 2 * (precision * recall) / (precision + recall + 1e-8) 214 | 215 | # Only consider valid classes 216 | valid_classes = ((precision + recall) > 0) 217 | f1 = f1 * valid_classes.float() 218 | mean_f1 = f1[valid_classes].mean() if valid_classes.any() else torch.tensor(0.0) 219 | 220 | return f1, mean_f1 221 | 222 | def compute(self) -> Dict[str, float]: 223 | """ 224 | Compute all metrics and return as dictionary. 225 | 226 | Returns: 227 | Dict[str, float]: Dictionary containing all metrics 228 | """ 229 | # IoU metrics 230 | per_class_iou, mean_iou = self.compute_iou() 231 | 232 | # Accuracy metrics 233 | pixel_accuracy = self.compute_pixel_accuracy() 234 | mean_accuracy = self.compute_mean_accuracy() 235 | fwiou = self.compute_frequency_weighted_iou() 236 | 237 | # Dice score 238 | per_class_dice, mean_dice = self.compute_dice_score() 239 | 240 | # Precision and Recall 241 | per_class_precision, mean_precision, per_class_recall, mean_recall = self.compute_precision_recall() 242 | 243 | # F1 Score 244 | per_class_f1, mean_f1 = self.compute_f1_score() 245 | 246 | # Convert to float for logging 247 | metrics = { 248 | 'mIoU': mean_iou.item(), 249 | 'Pixel_Accuracy': pixel_accuracy.item(), 250 | 'Mean_Accuracy': mean_accuracy.item(), 251 | 'Frequency_Weighted_IoU': fwiou.item(), 252 | 'Mean_Dice': mean_dice.item(), 253 | 'Mean_Precision': mean_precision.item(), 254 | 'Mean_Recall': mean_recall.item(), 255 | 'Mean_F1': mean_f1.item() 256 | } 257 | 258 | # Add per-class metrics 259 | for i in range(self.num_classes): 260 | metrics[f'IoU_Class_{i}'] = per_class_iou[i].item() 261 | metrics[f'Dice_Class_{i}'] = per_class_dice[i].item() 262 | metrics[f'Precision_Class_{i}'] = per_class_precision[i].item() 263 | metrics[f'Recall_Class_{i}'] = per_class_recall[i].item() 264 | metrics[f'F1_Class_{i}'] = per_class_f1[i].item() 265 | 266 | return metrics 267 | 268 | def get_confusion_matrix(self) -> np.ndarray: 269 | """ 270 | Get confusion matrix as numpy array. 271 | 272 | Returns: 273 | np.ndarray: Confusion matrix 274 | """ 275 | return self.confusion_matrix.cpu().numpy() 276 | 277 | def print_class_metrics(self, class_names: Optional[List[str]] = None): 278 | """ 279 | Print detailed per-class metrics. 
280 | 281 | Args: 282 | class_names (Optional[List[str]]): Names of classes 283 | """ 284 | if class_names is None: 285 | class_names = [f"Class {i}" for i in range(self.num_classes)] 286 | 287 | # Compute metrics 288 | per_class_iou, mean_iou = self.compute_iou() 289 | per_class_dice, mean_dice = self.compute_dice_score() 290 | per_class_precision, mean_precision, per_class_recall, mean_recall = self.compute_precision_recall() 291 | per_class_f1, mean_f1 = self.compute_f1_score() 292 | 293 | print("\nPer-Class Metrics:") 294 | print("-" * 80) 295 | print(f"{'Class':<20} {'IoU':<8} {'Dice':<8} {'Precision':<10} {'Recall':<8} {'F1':<8}") 296 | print("-" * 80) 297 | 298 | for i in range(self.num_classes): 299 | class_name = class_names[i] if i < len(class_names) else f"Class {i}" 300 | print( 301 | f"{class_name:<20} " 302 | f"{per_class_iou[i]:.4f} " 303 | f"{per_class_dice[i]:.4f} " 304 | f"{per_class_precision[i]:.4f} " 305 | f"{per_class_recall[i]:.4f} " 306 | f"{per_class_f1[i]:.4f}" 307 | ) 308 | 309 | print("-" * 80) 310 | print( 311 | f"{'Mean':<20} " 312 | f"{mean_iou:.4f} " 313 | f"{mean_dice:.4f} " 314 | f"{mean_precision:.4f} " 315 | f"{mean_recall:.4f} " 316 | f"{mean_f1:.4f}" 317 | ) 318 | print("-" * 80) 319 | 320 | 321 | class StreamingMetrics: 322 | """ 323 | Streaming version of metrics for large datasets that don't fit in memory. 324 | """ 325 | 326 | def __init__(self, num_classes: int, ignore_index: int = 255): 327 | self.num_classes = num_classes 328 | self.ignore_index = ignore_index 329 | self.reset() 330 | 331 | def reset(self): 332 | """Reset metrics.""" 333 | self.tp = np.zeros(self.num_classes, dtype=np.int64) 334 | self.fp = np.zeros(self.num_classes, dtype=np.int64) 335 | self.fn = np.zeros(self.num_classes, dtype=np.int64) 336 | self.total_pixels = 0 337 | self.correct_pixels = 0 338 | 339 | def update(self, predictions: np.ndarray, targets: np.ndarray): 340 | """ 341 | Update metrics with new predictions and targets. 342 | 343 | Args: 344 | predictions (np.ndarray): Model predictions 345 | targets (np.ndarray): Ground truth labels 346 | """ 347 | # Flatten arrays 348 | predictions = predictions.flatten() 349 | targets = targets.flatten() 350 | 351 | # Create mask for valid pixels 352 | mask = (targets != self.ignore_index) 353 | predictions = predictions[mask] 354 | targets = targets[mask] 355 | 356 | # Update pixel counts 357 | self.total_pixels += len(targets) 358 | self.correct_pixels += np.sum(predictions == targets) 359 | 360 | # Update per-class counts 361 | for c in range(self.num_classes): 362 | pred_mask = (predictions == c) 363 | target_mask = (targets == c) 364 | 365 | self.tp[c] += np.sum(pred_mask & target_mask) 366 | self.fp[c] += np.sum(pred_mask & ~target_mask) 367 | self.fn[c] += np.sum(~pred_mask & target_mask) 368 | 369 | def compute_metrics(self) -> Dict[str, float]: 370 | """ 371 | Compute metrics from accumulated counts. 
372 | 373 | Returns: 374 | Dict[str, float]: Computed metrics 375 | """ 376 | # IoU 377 | iou = self.tp / (self.tp + self.fp + self.fn + 1e-8) 378 | valid_classes = (self.tp + self.fp + self.fn) > 0 379 | mean_iou = np.mean(iou[valid_classes]) if np.any(valid_classes) else 0.0 380 | 381 | # Pixel accuracy 382 | pixel_accuracy = self.correct_pixels / (self.total_pixels + 1e-8) 383 | 384 | # Mean accuracy 385 | class_accuracy = self.tp / (self.tp + self.fn + 1e-8) 386 | mean_accuracy = np.mean(class_accuracy[valid_classes]) if np.any(valid_classes) else 0.0 387 | 388 | # Precision and Recall 389 | precision = self.tp / (self.tp + self.fp + 1e-8) 390 | recall = self.tp / (self.tp + self.fn + 1e-8) 391 | 392 | valid_precision = (self.tp + self.fp) > 0 393 | valid_recall = (self.tp + self.fn) > 0 394 | 395 | mean_precision = np.mean(precision[valid_precision]) if np.any(valid_precision) else 0.0 396 | mean_recall = np.mean(recall[valid_recall]) if np.any(valid_recall) else 0.0 397 | 398 | # F1 Score 399 | f1 = 2 * (precision * recall) / (precision + recall + 1e-8) 400 | valid_f1 = (precision + recall) > 0 401 | mean_f1 = np.mean(f1[valid_f1]) if np.any(valid_f1) else 0.0 402 | 403 | return { 404 | 'mIoU': float(mean_iou), 405 | 'Pixel_Accuracy': float(pixel_accuracy), 406 | 'Mean_Accuracy': float(mean_accuracy), 407 | 'Mean_Precision': float(mean_precision), 408 | 'Mean_Recall': float(mean_recall), 409 | 'Mean_F1': float(mean_f1) 410 | } 411 | 412 | 413 | if __name__ == "__main__": 414 | # Test metrics 415 | num_classes = 19 416 | batch_size = 2 417 | height, width = 64, 64 418 | 419 | # Create dummy data 420 | predictions = torch.randn(batch_size, num_classes, height, width) 421 | targets = torch.randint(0, num_classes, (batch_size, height, width)) 422 | 423 | # Test SegmentationMetrics 424 | metrics = SegmentationMetrics(num_classes) 425 | metrics.update(predictions, targets) 426 | 427 | computed_metrics = metrics.compute() 428 | print("Computed metrics:") 429 | for key, value in computed_metrics.items(): 430 | if not key.startswith(('IoU_Class', 'Dice_Class', 'Precision_Class', 'Recall_Class', 'F1_Class')): 431 | print(f"{key}: {value:.4f}") 432 | 433 | # Test class-wise metrics 434 | metrics.print_class_metrics() 435 | 436 | print("\nMetrics module test completed successfully!") 437 | -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Transforms for Semantic Segmentation 3 | 4 | This module contains data augmentation and preprocessing transforms 5 | using albumentations library for robust training. 6 | """ 7 | 8 | import albumentations as A 9 | from albumentations.pytorch import ToTensorV2 10 | import cv2 11 | import numpy as np 12 | from typing import List, Dict, Any, Optional, Callable, Tuple 13 | import torch 14 | 15 | 16 | def get_model_normalization(pfm_name: str) -> Tuple[List[float], List[float]]: 17 | """ 18 | Get normalization mean and std values for a given PFM model. 
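    For example (illustrative):

        mean, std = get_model_normalization('uni_v1')
        # -> [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]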
19 | 20 | Args: 21 | pfm_name (str): Name of the PFM model 22 | 23 | Returns: 24 | Tuple[List[float], List[float]]: (mean, std) values for normalization 25 | 26 | Note: 27 | - conch_v1 uses CLIP normalization values 28 | - Other models use ImageNet normalization values 29 | """ 30 | pfm_name = pfm_name.lower() 31 | 32 | # Conch v1 uses CLIP normalization 33 | if pfm_name == 'conch_v1': 34 | mean = [0.48145466, 0.4578275, 0.40821073] 35 | std = [0.26862954, 0.26130258, 0.27577711] 36 | elif pfm_name == 'conch_v1_5': 37 | mean = (0.485, 0.456, 0.406) 38 | std = (0.229, 0.224, 0.225) 39 | elif pfm_name == 'virchow_v1': 40 | mean = [0.485, 0.456, 0.406] 41 | std = [0.229, 0.224, 0.225] 42 | elif pfm_name == 'virchow_v2': 43 | mean = [0.485, 0.456, 0.406] 44 | std = [0.229, 0.224, 0.225] 45 | elif pfm_name == 'gigapath': 46 | mean = (0.485, 0.456, 0.406) 47 | std = (0.229, 0.224, 0.225) 48 | elif pfm_name == 'patho3dmatrix-vision': 49 | mean = [0.485, 0.456, 0.406] 50 | std = [0.229, 0.224, 0.225] 51 | elif pfm_name == 'uni_v2': 52 | mean = [0.485, 0.456, 0.406] 53 | std = [0.229, 0.224, 0.225] 54 | elif pfm_name == 'uni_v1': 55 | mean = [0.485, 0.456, 0.406] 56 | std = [0.229, 0.224, 0.225] 57 | elif pfm_name == 'phikon' or pfm_name == 'phikon_v2': 58 | # Phikon uses ImageNet normalization 59 | mean = [0.485, 0.456, 0.406] 60 | std = [0.229, 0.224, 0.225] 61 | elif pfm_name == 'hoptimus_0' or pfm_name == 'hoptimus_1': 62 | mean=(0.707223, 0.578729, 0.703617) 63 | std=(0.211883, 0.230117, 0.177517) 64 | elif pfm_name == 'musk': 65 | mean = [0.485, 0.456, 0.406] 66 | std = [0.229, 0.224, 0.225] 67 | elif pfm_name == 'midnight12k': 68 | mean = [0.5, 0.5, 0.5] 69 | std = [0.5, 0.5, 0.5] 70 | elif pfm_name.startswith('kaiko-'): 71 | mean = [0.5, 0.5, 0.5] 72 | std = [0.5, 0.5, 0.5] 73 | elif pfm_name == 'hibou_l': 74 | mean = [0.7068,0.5755,0.722] 75 | std = [0.195,0.2316,0.1816] 76 | else: 77 | # Default ImageNet normalization for other models 78 | # (uni_v1, uni_v2, virchow_v1, virchow_v2, gigapath, 79 | # patho3dmatrix-vision, conch_v1_5, unet, phikon) 80 | mean = [0.485, 0.456, 0.406] 81 | std = [0.229, 0.224, 0.225] 82 | 83 | return mean, std 84 | 85 | class SegmentationTransforms: 86 | """ 87 | Collection of segmentation-specific transforms. 88 | """ 89 | 90 | @staticmethod 91 | def get_training_transforms(img_size: int = 512, 92 | mean: List[float] = [0.485, 0.456, 0.406], 93 | std: List[float] = [0.229, 0.224, 0.225], 94 | seed: int = 42) -> A.Compose: 95 | """ 96 | Get training transforms with strong augmentations. 
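        The returned pipeline is applied jointly to an image and its mask, e.g.
        (illustrative; `image` is an HxWx3 uint8 array and `mask` an HxW label map):

            transforms = SegmentationTransforms.get_training_transforms(img_size=512)
            out = transforms(image=image, mask=mask)
            image_tensor, mask_tensor = out['image'], out['mask']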
97 | 98 | Args: 99 | img_size (int): Target image size 100 | mean (List[float]): Normalization mean 101 | std (List[float]): Normalization standard deviation 102 | 103 | Returns: 104 | A.Compose: Composed transforms 105 | """ 106 | return A.Compose([ 107 | # Geometric transforms 108 | A.RandomResizedCrop(size=(img_size,img_size), scale=(0.5, 1.0), ratio=(0.75, 1.33), p=1.0), 109 | A.HorizontalFlip(p=0.5), 110 | A.VerticalFlip(p=0.1), 111 | A.RandomRotate90(p=0.3), 112 | A.Transpose(p=0.3), 113 | 114 | # Spatial transforms 115 | A.OneOf([ 116 | A.ElasticTransform(alpha=1, sigma=50, p=1.0), 117 | A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0), 118 | A.OpticalDistortion(distort_limit=0.2, p=1.0), 119 | ], p=0.3), 120 | 121 | # Color transforms 122 | A.OneOf([ 123 | A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=1.0), 124 | A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0), 125 | A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=1.0), 126 | ], p=0.5), 127 | 128 | # Noise and blur 129 | A.OneOf([ 130 | A.GaussNoise(p=1.0), 131 | A.MultiplicativeNoise(multiplier=(0.9, 1.1), per_channel=True, p=1.0), 132 | ], p=0.3), 133 | 134 | A.OneOf([ 135 | A.GaussianBlur(blur_limit=(3, 7), p=1.0), 136 | A.MotionBlur(blur_limit=7, p=1.0), 137 | A.MedianBlur(blur_limit=7, p=1.0), 138 | ], p=0.2), 139 | 140 | # Weather effects 141 | A.OneOf([ 142 | A.RandomRain(brightness_coefficient=0.7, p=1.0), 143 | A.RandomSnow(brightness_coeff=2.5, p=1.0), 144 | A.RandomFog(alpha_coef=0.08, p=1.0), 145 | ], p=0.2), 146 | 147 | # Lighting 148 | A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.4), 149 | A.RandomGamma(gamma_limit=(70, 130), p=0.3), 150 | 151 | # Cutout and mixing 152 | A.CoarseDropout(p=0.3), 153 | 154 | # Normalization 155 | A.Normalize(mean=mean, std=std), 156 | ToTensorV2(), 157 | ], seed=seed) 158 | 159 | @staticmethod 160 | def get_validation_transforms(img_size: int = 512, 161 | mean: List[float] = [0.485, 0.456, 0.406], 162 | std: List[float] = [0.229, 0.224, 0.225]) -> A.Compose: 163 | """ 164 | Get validation transforms with minimal augmentation. 165 | 166 | Args: 167 | img_size (int): Target image size 168 | mean (List[float]): Normalization mean 169 | std (List[float]): Normalization standard deviation 170 | 171 | Returns: 172 | A.Compose: Composed transforms 173 | """ 174 | if img_size != None: 175 | return A.Compose([ 176 | A.Resize(height=img_size, width=img_size), 177 | A.Normalize(mean=mean, std=std), 178 | ToTensorV2(), 179 | ]) 180 | else: 181 | return A.Compose([ 182 | A.Normalize(mean=mean, std=std), 183 | ToTensorV2(), 184 | ]) 185 | 186 | 187 | def parse_transform_config(config: Dict[str, Any]) -> A.Compose: 188 | """ 189 | Parse transform configuration and create albumentations transforms. 
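    Each entry must contain a 'type' key naming an albumentations transform
    (or 'ToTensorV2') plus its keyword arguments, e.g. (illustrative):

        config = [
            {'type': 'Resize', 'height': 512, 'width': 512},
            {'type': 'HorizontalFlip', 'p': 0.5},
            {'type': 'Normalize', 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]},
            {'type': 'ToTensorV2'},
        ]
        transforms = parse_transform_config(config)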
190 | 191 | Args: 192 | config (Dict[str, Any]): Transform configuration 193 | 194 | Returns: 195 | A.Compose: Composed transforms 196 | """ 197 | transforms = [] 198 | 199 | for transform_config in config: 200 | transform_type = transform_config['type'] 201 | transform_params = {k: v for k, v in transform_config.items() if k != 'type'} 202 | 203 | # Get transform class from albumentations 204 | if hasattr(A, transform_type): 205 | transform_class = getattr(A, transform_type) 206 | transforms.append(transform_class(**transform_params)) 207 | elif transform_type == 'ToTensorV2': 208 | transforms.append(ToTensorV2()) 209 | else: 210 | raise ValueError(f"Unknown transform type: {transform_type}") 211 | 212 | return A.Compose(transforms) 213 | 214 | 215 | def get_transforms(transform_config: List[Dict[str, Any]]) -> A.Compose: 216 | """ 217 | Factory function to create transforms from configuration. 218 | 219 | Args: 220 | transform_config (List[Dict[str, Any]]): List of transform configurations 221 | 222 | Returns: 223 | A.Compose: Composed transforms 224 | """ 225 | if isinstance(transform_config, list): 226 | return parse_transform_config(transform_config) 227 | else: 228 | raise ValueError("Transform config must be a list of dictionaries") 229 | 230 | 231 | class MixUp: 232 | """ 233 | MixUp augmentation for semantic segmentation. 234 | """ 235 | 236 | def __init__(self, alpha: float = 1.0, p: float = 0.5): 237 | self.alpha = alpha 238 | self.p = p 239 | 240 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 241 | """ 242 | Apply MixUp to a batch of data. 243 | 244 | Args: 245 | batch (Dict[str, torch.Tensor]): Batch containing images and labels 246 | 247 | Returns: 248 | Dict[str, torch.Tensor]: Mixed batch 249 | """ 250 | if np.random.random() > self.p: 251 | return batch 252 | 253 | images = batch['image'] 254 | labels = batch['label'] 255 | 256 | batch_size = images.size(0) 257 | indices = torch.randperm(batch_size) 258 | 259 | # Sample lambda from Beta distribution 260 | lam = np.random.beta(self.alpha, self.alpha) 261 | 262 | # Mix images 263 | mixed_images = lam * images + (1 - lam) * images[indices] 264 | 265 | # For segmentation, we need to handle labels differently 266 | # We can either use the original labels or create mixed labels 267 | mixed_labels = labels # Keep original labels for simplicity 268 | 269 | return { 270 | 'image': mixed_images, 271 | 'label': mixed_labels, 272 | 'lambda': lam, 273 | 'indices': indices 274 | } 275 | 276 | 277 | class CutMix: 278 | """ 279 | CutMix augmentation for semantic segmentation. 280 | """ 281 | 282 | def __init__(self, alpha: float = 1.0, p: float = 0.5): 283 | self.alpha = alpha 284 | self.p = p 285 | 286 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 287 | """ 288 | Apply CutMix to a batch of data. 
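        A mixing ratio lam is drawn from Beta(alpha, alpha) and a rectangular
        region covering roughly a (1 - lam) fraction of the image area (before
        clipping at the borders) is copied from a randomly permuted sample into
        both the images and the label maps.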
289 | 290 | Args: 291 | batch (Dict[str, torch.Tensor]): Batch containing images and labels 292 | 293 | Returns: 294 | Dict[str, torch.Tensor]: Cut-mixed batch 295 | """ 296 | if np.random.random() > self.p: 297 | return batch 298 | 299 | images = batch['image'] 300 | labels = batch['label'] 301 | 302 | batch_size, _, height, width = images.shape 303 | indices = torch.randperm(batch_size) 304 | 305 | # Sample lambda and bounding box 306 | lam = np.random.beta(self.alpha, self.alpha) 307 | 308 | # Generate random bounding box 309 | cut_ratio = np.sqrt(1.0 - lam) 310 | cut_w = int(width * cut_ratio) 311 | cut_h = int(height * cut_ratio) 312 | 313 | cx = np.random.randint(width) 314 | cy = np.random.randint(height) 315 | 316 | bbx1 = np.clip(cx - cut_w // 2, 0, width) 317 | bby1 = np.clip(cy - cut_h // 2, 0, height) 318 | bbx2 = np.clip(cx + cut_w // 2, 0, width) 319 | bby2 = np.clip(cy + cut_h // 2, 0, height) 320 | 321 | # Apply CutMix 322 | mixed_images = images.clone() 323 | mixed_labels = labels.clone() 324 | 325 | mixed_images[:, :, bby1:bby2, bbx1:bbx2] = images[indices, :, bby1:bby2, bbx1:bbx2] 326 | mixed_labels[:, bby1:bby2, bbx1:bbx2] = labels[indices, bby1:bby2, bbx1:bbx2] 327 | 328 | return { 329 | 'image': mixed_images, 330 | 'label': mixed_labels, 331 | 'lambda': lam, 332 | 'indices': indices, 333 | 'bbox': (bbx1, bby1, bbx2, bby2) 334 | } 335 | 336 | 337 | class Mosaic: 338 | """ 339 | Mosaic augmentation for semantic segmentation. 340 | """ 341 | 342 | def __init__(self, p: float = 0.5): 343 | self.p = p 344 | 345 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 346 | """ 347 | Apply Mosaic to a batch of data. 348 | 349 | Args: 350 | batch (Dict[str, torch.Tensor]): Batch containing images and labels 351 | 352 | Returns: 353 | Dict[str, torch.Tensor]: Mosaic batch 354 | """ 355 | if np.random.random() > self.p or batch['image'].size(0) < 4: 356 | return batch 357 | 358 | images = batch['image'] 359 | labels = batch['label'] 360 | 361 | batch_size, channels, height, width = images.shape 362 | 363 | # Create mosaic for first sample 364 | mosaic_image = torch.zeros(channels, height, width) 365 | mosaic_label = torch.zeros(height, width, dtype=labels.dtype) 366 | 367 | # Divide image into 4 quadrants 368 | h_mid = height // 2 369 | w_mid = width // 2 370 | 371 | indices = torch.randperm(batch_size)[:4] 372 | 373 | # Top-left 374 | mosaic_image[:, :h_mid, :w_mid] = images[indices[0], :, :h_mid, :w_mid] 375 | mosaic_label[:h_mid, :w_mid] = labels[indices[0], :h_mid, :w_mid] 376 | 377 | # Top-right 378 | mosaic_image[:, :h_mid, w_mid:] = images[indices[1], :, :h_mid, w_mid:] 379 | mosaic_label[:h_mid, w_mid:] = labels[indices[1], :h_mid, w_mid:] 380 | 381 | # Bottom-left 382 | mosaic_image[:, h_mid:, :w_mid] = images[indices[2], :, h_mid:, :w_mid] 383 | mosaic_label[h_mid:, :w_mid] = labels[indices[2], h_mid:, :w_mid] 384 | 385 | # Bottom-right 386 | mosaic_image[:, h_mid:, w_mid:] = images[indices[3], :, h_mid:, w_mid:] 387 | mosaic_label[h_mid:, w_mid:] = labels[indices[3], h_mid:, w_mid:] 388 | 389 | # Replace first sample with mosaic 390 | new_images = images.clone() 391 | new_labels = labels.clone() 392 | new_images[0] = mosaic_image 393 | new_labels[0] = mosaic_label 394 | 395 | return { 396 | 'image': new_images, 397 | 'label': new_labels 398 | } 399 | 400 | 401 | class AdvancedAugmentationPipeline: 402 | """ 403 | Advanced augmentation pipeline combining multiple techniques. 
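    Intended to be applied to an already-collated batch dict with 'image'
    (B, C, H, W) and 'label' (B, H, W) tensors, e.g. (illustrative;
    `train_loader` is a placeholder):

        pipeline = AdvancedAugmentationPipeline(mixup_p=0.3, cutmix_p=0.3, mosaic_p=0.2)
        for batch in train_loader:
            batch = pipeline(batch)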
404 |     """
405 | 
406 |     def __init__(self, mixup_p: float = 0.3, cutmix_p: float = 0.3, mosaic_p: float = 0.2):
407 |         self.mixup = MixUp(p=mixup_p)
408 |         self.cutmix = CutMix(p=cutmix_p)
409 |         self.mosaic = Mosaic(p=mosaic_p)
410 | 
411 |     def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
412 |         """
413 |         Apply advanced augmentations to batch.
414 | 
415 |         Args:
416 |             batch (Dict[str, torch.Tensor]): Input batch
417 | 
418 |         Returns:
419 |             Dict[str, torch.Tensor]: Augmented batch
420 |         """
421 |         # Apply augmentations in random order
422 |         augmentations = [self.mixup, self.cutmix, self.mosaic]
423 |         np.random.shuffle(augmentations)
424 | 
425 |         for aug in augmentations:
426 |             batch = aug(batch)
427 | 
428 |         return batch
429 | 
430 | 
431 | if __name__ == "__main__":
432 |     # Test transforms
433 |     from PIL import Image
434 |     import numpy as np
435 | 
436 |     # Create dummy data
437 |     image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
438 |     mask = np.random.randint(0, 19, (512, 512), dtype=np.uint8)
439 | 
440 |     # Test training transforms
441 |     train_transforms = SegmentationTransforms.get_training_transforms()
442 |     transformed = train_transforms(image=image, mask=mask)
443 | 
444 |     print(f"Original image shape: {image.shape}")
445 |     print(f"Transformed image shape: {transformed['image'].shape}")
446 |     print(f"Transformed mask shape: {transformed['mask'].shape}")
447 | 
448 |     # Test validation transforms
449 |     val_transforms = SegmentationTransforms.get_validation_transforms()
450 |     val_transformed = val_transforms(image=image, mask=mask)
451 | 
452 |     print(f"Validation image shape: {val_transformed['image'].shape}")
453 |     print(f"Validation mask shape: {val_transformed['mask'].shape}")
454 | 
455 |     # Test TTA transforms (only if the optional helper is defined in this module)
456 |     if hasattr(SegmentationTransforms, 'get_test_time_augmentation_transforms'):
457 |         tta_transforms = SegmentationTransforms.get_test_time_augmentation_transforms()
458 |         print(f"Number of TTA transforms: {len(tta_transforms)}")
459 | 
-------------------------------------------------------------------------------- /data/utils.py: --------------------------------------------------------------------------------
1 | """
2 | Data Utilities for Semantic Segmentation
3 | 
4 | This module contains utility functions for data loading, preprocessing,
5 | and dataset management.
6 | """
7 | 
8 | import torch
9 | from torch.utils.data import DataLoader, DistributedSampler
10 | import numpy as np
11 | from typing import Dict, List, Optional, Tuple, Any, Callable
12 | import cv2
13 | from PIL import Image
14 | import matplotlib
15 | matplotlib.use('Agg')
16 | import matplotlib.pyplot as plt
17 | import seaborn as sns
18 | from collections import Counter
19 | import os
20 | 
21 | 
22 | def create_dataloader(dataset, batch_size: int = 8, shuffle: bool = True,
23 |                       num_workers: int = 4, pin_memory: bool = True,
24 |                       drop_last: bool = True, distributed: bool = False, generator = None, worker_init_fn = None) -> DataLoader:
25 |     """
26 |     Create DataLoader with appropriate settings.
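    The loader always uses segmentation_collate_fn; when distributed=True a
    DistributedSampler is attached and shuffling is delegated to it. Example
    (illustrative; `train_dataset` is a placeholder):

        loader = create_dataloader(train_dataset, batch_size=8, shuffle=True,
                                   num_workers=4, distributed=True)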
27 | 28 | Args: 29 | dataset: PyTorch dataset 30 | batch_size (int): Batch size 31 | shuffle (bool): Whether to shuffle data 32 | num_workers (int): Number of worker processes 33 | pin_memory (bool): Whether to pin memory 34 | drop_last (bool): Whether to drop last incomplete batch 35 | distributed (bool): Whether to use distributed training 36 | generator: Random number generator for reproducibility 37 | worker_init_fn (Callable): Function to initialize workers 38 | 39 | Returns: 40 | DataLoader: Configured data loader 41 | """ 42 | sampler = None 43 | if distributed: 44 | sampler = DistributedSampler(dataset, shuffle=shuffle) 45 | shuffle = False # Disable shuffle when using sampler 46 | 47 | return DataLoader( 48 | dataset=dataset, 49 | batch_size=batch_size, 50 | shuffle=shuffle, 51 | num_workers=num_workers, 52 | pin_memory=pin_memory, 53 | drop_last=drop_last, 54 | sampler=sampler, 55 | generator=generator, 56 | collate_fn=segmentation_collate_fn, 57 | worker_init_fn=worker_init_fn, 58 | ) 59 | 60 | 61 | def segmentation_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 62 | """ 63 | Custom collate function for segmentation data. 64 | 65 | Args: 66 | batch (List[Dict[str, Any]]): List of sample dictionaries 67 | 68 | Returns: 69 | Dict[str, torch.Tensor]: Batched data 70 | """ 71 | images = [] 72 | labels = [] 73 | ori_sizes = [] 74 | image_paths = [] 75 | label_paths = [] 76 | 77 | for sample in batch: 78 | images.append(sample['image']) 79 | labels.append(sample['label']) 80 | image_paths.append(sample['image_path']) 81 | label_paths.append(sample['label_path']) 82 | ori_sizes.append(sample['ori_size']) 83 | 84 | # Stack images and labels 85 | images = torch.stack(images, dim=0) 86 | labels = torch.stack(labels, dim=0) 87 | 88 | return { 89 | 'image': images, 90 | 'label': labels, 91 | 'ori_size': ori_sizes, 92 | 'image_path': image_paths, 93 | 'label_path': label_paths 94 | } 95 | 96 | 97 | def compute_class_distribution(dataset, num_classes: int, ignore_index: int = 255) -> Dict[str, Any]: 98 | """ 99 | Compute class distribution statistics for a dataset. 100 | 101 | Args: 102 | dataset: Segmentation dataset 103 | num_classes (int): Number of classes 104 | ignore_index (int): Index to ignore in calculations 105 | 106 | Returns: 107 | Dict[str, Any]: Class distribution statistics 108 | """ 109 | class_counts = np.zeros(num_classes, dtype=np.int64) 110 | total_pixels = 0 111 | 112 | print("Computing class distribution...") 113 | for i, sample in enumerate(dataset): 114 | if i % 100 == 0: 115 | print(f"Processed {i}/{len(dataset)} samples") 116 | 117 | label = sample['label'].numpy() 118 | mask = (label != ignore_index) 119 | 120 | for c in range(num_classes): 121 | class_counts[c] += np.sum(label == c) 122 | total_pixels += np.sum(mask) 123 | 124 | # Compute statistics 125 | class_frequencies = class_counts / total_pixels 126 | class_weights = 1.0 / (class_frequencies + 1e-8) 127 | class_weights = class_weights / class_weights.sum() * num_classes 128 | 129 | return { 130 | 'class_counts': class_counts, 131 | 'class_frequencies': class_frequencies, 132 | 'class_weights': class_weights, 133 | 'total_pixels': total_pixels 134 | } 135 | 136 | 137 | def visualize_class_distribution(class_stats: Dict[str, Any], class_names: Optional[List[str]] = None, 138 | save_path: Optional[str] = None) -> None: 139 | """ 140 | Visualize class distribution statistics. 
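    Typically used together with compute_class_distribution, e.g. (illustrative):

        stats = compute_class_distribution(dataset, num_classes=19)
        visualize_class_distribution(stats, save_path='class_distribution.png')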
141 | 142 | Args: 143 | class_stats (Dict[str, Any]): Class statistics from compute_class_distribution 144 | class_names (Optional[List[str]]): Names of classes 145 | save_path (Optional[str]): Path to save the plot 146 | """ 147 | num_classes = len(class_stats['class_counts']) 148 | 149 | if class_names is None: 150 | class_names = [f"Class {i}" for i in range(num_classes)] 151 | 152 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) 153 | 154 | # Plot class counts 155 | bars1 = ax1.bar(range(num_classes), class_stats['class_counts']) 156 | ax1.set_xlabel('Class') 157 | ax1.set_ylabel('Pixel Count') 158 | ax1.set_title('Class Distribution (Pixel Counts)') 159 | ax1.set_xticks(range(num_classes)) 160 | ax1.set_xticklabels(class_names, rotation=45, ha='right') 161 | 162 | # Add value labels on bars 163 | for i, bar in enumerate(bars1): 164 | height = bar.get_height() 165 | ax1.text(bar.get_x() + bar.get_width()/2., height, 166 | f'{int(height):,}', ha='center', va='bottom', fontsize=8) 167 | 168 | # Plot class frequencies 169 | bars2 = ax2.bar(range(num_classes), class_stats['class_frequencies']) 170 | ax2.set_xlabel('Class') 171 | ax2.set_ylabel('Frequency') 172 | ax2.set_title('Class Distribution (Frequencies)') 173 | ax2.set_xticks(range(num_classes)) 174 | ax2.set_xticklabels(class_names, rotation=45, ha='right') 175 | 176 | # Add value labels on bars 177 | for i, bar in enumerate(bars2): 178 | height = bar.get_height() 179 | ax2.text(bar.get_x() + bar.get_width()/2., height, 180 | f'{height:.3f}', ha='center', va='bottom', fontsize=8) 181 | 182 | plt.tight_layout() 183 | 184 | if save_path: 185 | plt.savefig(save_path, dpi=300, bbox_inches='tight') 186 | print(f"Class distribution plot saved to: {save_path}") 187 | 188 | plt.show() 189 | 190 | 191 | def visualize_sample(sample: Dict[str, torch.Tensor], class_colors: Optional[List[List[int]]] = None, 192 | class_names: Optional[List[str]] = None, save_path: Optional[str] = None) -> None: 193 | """ 194 | Visualize a single sample with image and label overlay. 
195 | 196 | Args: 197 | sample (Dict[str, torch.Tensor]): Sample containing image and label 198 | class_colors (Optional[List[List[int]]]): RGB colors for each class 199 | class_names (Optional[List[str]]): Names of classes 200 | save_path (Optional[str]): Path to save the visualization 201 | """ 202 | image = sample['image'] 203 | label = sample['label'] 204 | 205 | # Convert tensors to numpy arrays 206 | if isinstance(image, torch.Tensor): 207 | if image.dim() == 3: # C, H, W 208 | image = image.permute(1, 2, 0) 209 | image = image.numpy() 210 | 211 | if isinstance(label, torch.Tensor): 212 | label = label.numpy() 213 | 214 | # Denormalize image if needed 215 | if image.max() <= 1.0: 216 | image = (image * 255).astype(np.uint8) 217 | 218 | # Create colored label map 219 | if class_colors is None: 220 | # Generate random colors 221 | num_classes = int(label.max()) + 1 222 | class_colors = plt.cm.tab20(np.linspace(0, 1, num_classes))[:, :3] * 255 223 | class_colors = class_colors.astype(np.uint8) 224 | 225 | colored_label = np.zeros((*label.shape, 3), dtype=np.uint8) 226 | for class_id, color in enumerate(class_colors): 227 | colored_label[label == class_id] = color 228 | 229 | # Create overlay 230 | alpha = 0.6 231 | overlay = cv2.addWeighted(image, 1 - alpha, colored_label, alpha, 0) 232 | 233 | # Create visualization 234 | fig, axes = plt.subplots(1, 3, figsize=(15, 5)) 235 | 236 | # Original image 237 | axes[0].imshow(image) 238 | axes[0].set_title('Original Image') 239 | axes[0].axis('off') 240 | 241 | # Label map 242 | axes[1].imshow(colored_label) 243 | axes[1].set_title('Label Map') 244 | axes[1].axis('off') 245 | 246 | # Overlay 247 | axes[2].imshow(overlay) 248 | axes[2].set_title('Overlay') 249 | axes[2].axis('off') 250 | 251 | plt.tight_layout() 252 | 253 | if save_path: 254 | plt.savefig(save_path, dpi=300, bbox_inches='tight') 255 | print(f"Sample visualization saved to: {save_path}") 256 | 257 | plt.show() 258 | 259 | 260 | def create_color_map(num_classes: int) -> np.ndarray: 261 | """ 262 | Create a color map for visualization. 263 | 264 | Args: 265 | num_classes (int): Number of classes 266 | 267 | Returns: 268 | np.ndarray: Color map of shape (num_classes, 3) 269 | """ 270 | colors = [] 271 | for i in range(num_classes): 272 | # Generate distinct colors using HSV space 273 | hue = i / num_classes 274 | saturation = 0.7 + 0.3 * (i % 2) # Alternate between high and higher saturation 275 | value = 0.8 + 0.2 * ((i // 2) % 2) # Alternate brightness 276 | 277 | # Convert HSV to RGB 278 | hsv = np.array([hue, saturation, value]).reshape(1, 1, 3) 279 | rgb = cv2.cvtColor((hsv * 255).astype(np.uint8), cv2.COLOR_HSV2RGB)[0, 0] 280 | colors.append(rgb) 281 | 282 | return np.array(colors) 283 | 284 | 285 | def analyze_dataset_quality(dataset, sample_ratio: float = 0.1) -> Dict[str, Any]: 286 | """ 287 | Analyze dataset quality metrics. 
288 | 289 | Args: 290 | dataset: Segmentation dataset 291 | sample_ratio (float): Ratio of samples to analyze 292 | 293 | Returns: 294 | Dict[str, Any]: Quality analysis results 295 | """ 296 | num_samples = int(len(dataset) * sample_ratio) 297 | indices = np.random.choice(len(dataset), num_samples, replace=False) 298 | 299 | image_sizes = [] 300 | label_coverage = [] # Percentage of labeled pixels 301 | class_diversity = [] # Number of unique classes per image 302 | 303 | print(f"Analyzing dataset quality on {num_samples} samples...") 304 | 305 | for i, idx in enumerate(indices): 306 | if i % 50 == 0: 307 | print(f"Processed {i}/{num_samples} samples") 308 | 309 | sample = dataset[idx] 310 | image = sample['image'] 311 | label = sample['label'] 312 | 313 | # Image size 314 | if isinstance(image, torch.Tensor): 315 | h, w = image.shape[-2:] 316 | else: 317 | h, w = image.shape[:2] 318 | image_sizes.append((h, w)) 319 | 320 | # Label coverage 321 | if isinstance(label, torch.Tensor): 322 | label_np = label.numpy() 323 | else: 324 | label_np = label 325 | 326 | valid_pixels = np.sum(label_np != 255) # Assuming 255 is ignore_index 327 | total_pixels = label_np.size 328 | coverage = valid_pixels / total_pixels 329 | label_coverage.append(coverage) 330 | 331 | # Class diversity 332 | unique_classes = len(np.unique(label_np[label_np != 255])) 333 | class_diversity.append(unique_classes) 334 | 335 | # Compute statistics 336 | unique_sizes = list(set(image_sizes)) 337 | size_consistency = len(unique_sizes) == 1 338 | 339 | return { 340 | 'num_samples_analyzed': num_samples, 341 | 'unique_image_sizes': unique_sizes, 342 | 'size_consistency': size_consistency, 343 | 'avg_label_coverage': np.mean(label_coverage), 344 | 'std_label_coverage': np.std(label_coverage), 345 | 'avg_class_diversity': np.mean(class_diversity), 346 | 'std_class_diversity': np.std(class_diversity), 347 | 'label_coverage_distribution': label_coverage, 348 | 'class_diversity_distribution': class_diversity 349 | } 350 | 351 | 352 | def save_dataset_info(dataset, output_dir: str, dataset_name: str = "dataset") -> None: 353 | """ 354 | Save comprehensive dataset information to files. 
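    Writes <dataset_name>_info.json, <dataset_name>_class_counts.npy,
    <dataset_name>_class_weights.npy, <dataset_name>_class_distribution.png and
    <dataset_name>_quality.json into output_dir (the class statistics are only
    produced when the dataset exposes a num_classes attribute).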
355 | 356 | Args: 357 | dataset: Segmentation dataset 358 | output_dir (str): Output directory 359 | dataset_name (str): Name of the dataset 360 | """ 361 | os.makedirs(output_dir, exist_ok=True) 362 | 363 | # Basic info 364 | info = { 365 | 'dataset_name': dataset_name, 366 | 'num_samples': len(dataset), 367 | 'num_classes': getattr(dataset, 'num_classes', 'unknown'), 368 | 'ignore_index': getattr(dataset, 'ignore_index', 255) 369 | } 370 | 371 | # Save basic info 372 | import json 373 | with open(os.path.join(output_dir, f'{dataset_name}_info.json'), 'w') as f: 374 | json.dump(info, f, indent=2) 375 | 376 | # Compute and save class distribution 377 | if hasattr(dataset, 'num_classes'): 378 | class_stats = compute_class_distribution(dataset, dataset.num_classes) 379 | 380 | # Save class statistics 381 | np.save(os.path.join(output_dir, f'{dataset_name}_class_counts.npy'), 382 | class_stats['class_counts']) 383 | np.save(os.path.join(output_dir, f'{dataset_name}_class_weights.npy'), 384 | class_stats['class_weights']) 385 | 386 | # Save class distribution plot 387 | visualize_class_distribution( 388 | class_stats, 389 | save_path=os.path.join(output_dir, f'{dataset_name}_class_distribution.png') 390 | ) 391 | 392 | # Dataset quality analysis 393 | quality_stats = analyze_dataset_quality(dataset) 394 | with open(os.path.join(output_dir, f'{dataset_name}_quality.json'), 'w') as f: 395 | # Convert numpy arrays to lists for JSON serialization 396 | quality_stats_serializable = {} 397 | for k, v in quality_stats.items(): 398 | if isinstance(v, np.ndarray): 399 | quality_stats_serializable[k] = v.tolist() 400 | elif isinstance(v, np.float64): 401 | quality_stats_serializable[k] = float(v) 402 | else: 403 | quality_stats_serializable[k] = v 404 | json.dump(quality_stats_serializable, f, indent=2) 405 | 406 | print(f"Dataset information saved to: {output_dir}") 407 | 408 | 409 | def create_data_split(image_dir: str, label_dir: str, 410 | train_ratio: float = 0.7, val_ratio: float = 0.2, test_ratio: float = 0.1, 411 | output_dir: str = "splits", seed: int = 42) -> None: 412 | """ 413 | Create train/val/test splits for a custom dataset. 
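    Images and labels are copied into output_dir/images/{train,val,test} and
    output_dir/labels/{train,val,test}; a label is expected to share its image's
    file name with a .png extension. Example (illustrative paths):

        create_data_split('raw/images', 'raw/labels',
                          train_ratio=0.7, val_ratio=0.2, test_ratio=0.1,
                          output_dir='splits', seed=42)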
414 | 415 | Args: 416 | image_dir (str): Directory containing images 417 | label_dir (str): Directory containing labels 418 | train_ratio (float): Ratio for training set 419 | val_ratio (float): Ratio for validation set 420 | test_ratio (float): Ratio for test set 421 | output_dir (str): Output directory for split files 422 | seed (int): Random seed for reproducibility 423 | """ 424 | import random 425 | import shutil 426 | 427 | assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1.0" 428 | 429 | # Get all image files 430 | image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))] 431 | 432 | # Filter files that have corresponding labels 433 | valid_files = [] 434 | for img_file in image_files: 435 | label_file = img_file.replace('.jpg', '.png').replace('.jpeg', '.png') 436 | if os.path.exists(os.path.join(label_dir, label_file)): 437 | valid_files.append(img_file) 438 | 439 | print(f"Found {len(valid_files)} valid image-label pairs") 440 | 441 | # Shuffle files 442 | random.seed(seed) 443 | random.shuffle(valid_files) 444 | 445 | # Compute split indices 446 | num_files = len(valid_files) 447 | train_end = int(num_files * train_ratio) 448 | val_end = train_end + int(num_files * val_ratio) 449 | 450 | # Split files 451 | train_files = valid_files[:train_end] 452 | val_files = valid_files[train_end:val_end] 453 | test_files = valid_files[val_end:] 454 | 455 | print(f"Split: Train={len(train_files)}, Val={len(val_files)}, Test={len(test_files)}") 456 | 457 | # Create output directories 458 | for split in ['train', 'val', 'test']: 459 | os.makedirs(os.path.join(output_dir, 'images', split), exist_ok=True) 460 | os.makedirs(os.path.join(output_dir, 'labels', split), exist_ok=True) 461 | 462 | # Copy files to respective directories 463 | for split, files in [('train', train_files), ('val', val_files), ('test', test_files)]: 464 | for img_file in files: 465 | label_file = img_file.replace('.jpg', '.png').replace('.jpeg', '.png') 466 | 467 | # Copy image 468 | shutil.copy2( 469 | os.path.join(image_dir, img_file), 470 | os.path.join(output_dir, 'images', split, img_file) 471 | ) 472 | 473 | # Copy label 474 | shutil.copy2( 475 | os.path.join(label_dir, label_file), 476 | os.path.join(output_dir, 'labels', split, label_file) 477 | ) 478 | 479 | print(f"Data split created in: {output_dir}") 480 | 481 | 482 | if __name__ == "__main__": 483 | # Test utilities 484 | print("Testing data utilities...") 485 | 486 | # Test color map creation 487 | color_map = create_color_map(19) 488 | print(f"Created color map with shape: {color_map.shape}") 489 | 490 | # Test other functions would require actual dataset 491 | print("Data utilities module loaded successfully!") 492 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Learning Rate Schedulers for Semantic Segmentation Training 3 | 4 | This module contains various learning rate scheduling strategies 5 | including cosine annealing, polynomial decay, and warmup schedules. 6 | """ 7 | 8 | import torch 9 | import torch.optim as optim 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | import math 12 | from typing import Dict, Any, Optional, List 13 | 14 | 15 | class CosineAnnealingWithWarmup(_LRScheduler): 16 | """ 17 | Cosine Annealing learning rate scheduler with linear warmup. 
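    During the first warmup_epochs epochs the learning rate grows linearly as
    base_lr * epoch / warmup_epochs; afterwards it follows
    eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T)) / 2, where t and T are
    the epoch index and schedule length counted from the end of the warmup.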
18 | 19 | Args: 20 | optimizer (torch.optim.Optimizer): Optimizer 21 | T_max (int): Maximum number of iterations/epochs 22 | eta_min (float): Minimum learning rate 23 | warmup_epochs (int): Number of warmup epochs 24 | last_epoch (int): The index of last epoch 25 | """ 26 | 27 | def __init__(self, optimizer: torch.optim.Optimizer, T_max: int, 28 | eta_min: float = 0, warmup_epochs: int = 0, last_epoch: int = -1): 29 | self.T_max = T_max 30 | self.eta_min = eta_min 31 | self.warmup_epochs = warmup_epochs 32 | super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch) 33 | 34 | def get_lr(self) -> List[float]: 35 | """Compute learning rate for current epoch.""" 36 | if self.last_epoch < self.warmup_epochs: 37 | # Linear warmup 38 | warmup_factor = self.last_epoch / self.warmup_epochs 39 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 40 | else: 41 | # Cosine annealing 42 | adjusted_epoch = self.last_epoch - self.warmup_epochs 43 | adjusted_T_max = self.T_max - self.warmup_epochs 44 | 45 | return [ 46 | self.eta_min + (base_lr - self.eta_min) * 47 | (1 + math.cos(math.pi * adjusted_epoch / adjusted_T_max)) / 2 48 | for base_lr in self.base_lrs 49 | ] 50 | 51 | 52 | class PolynomialLR(_LRScheduler): 53 | """ 54 | Polynomial learning rate decay scheduler. 55 | 56 | Args: 57 | optimizer (torch.optim.Optimizer): Optimizer 58 | total_epochs (int): Total number of training epochs 59 | power (float): Power for polynomial decay 60 | warmup_epochs (int): Number of warmup epochs 61 | last_epoch (int): The index of last epoch 62 | """ 63 | 64 | def __init__(self, optimizer: torch.optim.Optimizer, total_epochs: int, 65 | power: float = 0.9, warmup_epochs: int = 0, last_epoch: int = -1): 66 | self.total_epochs = total_epochs 67 | self.power = power 68 | self.warmup_epochs = warmup_epochs 69 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 70 | 71 | def get_lr(self) -> List[float]: 72 | """Compute learning rate for current epoch.""" 73 | if self.last_epoch < self.warmup_epochs: 74 | # Linear warmup 75 | warmup_factor = self.last_epoch / self.warmup_epochs 76 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 77 | else: 78 | # Polynomial decay 79 | factor = (1 - (self.last_epoch - self.warmup_epochs) / 80 | (self.total_epochs - self.warmup_epochs)) ** self.power 81 | return [base_lr * factor for base_lr in self.base_lrs] 82 | 83 | 84 | class WarmupMultiStepLR(_LRScheduler): 85 | """ 86 | Multi-step learning rate scheduler with warmup. 
87 | 88 | Args: 89 | optimizer (torch.optim.Optimizer): Optimizer 90 | milestones (List[int]): List of epoch indices for learning rate decay 91 | gamma (float): Multiplicative factor of learning rate decay 92 | warmup_epochs (int): Number of warmup epochs 93 | last_epoch (int): The index of last epoch 94 | """ 95 | 96 | def __init__(self, optimizer: torch.optim.Optimizer, milestones: List[int], 97 | gamma: float = 0.1, warmup_epochs: int = 0, last_epoch: int = -1): 98 | self.milestones = sorted(milestones) 99 | self.gamma = gamma 100 | self.warmup_epochs = warmup_epochs 101 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 102 | 103 | def get_lr(self) -> List[float]: 104 | """Compute learning rate for current epoch.""" 105 | if self.last_epoch < self.warmup_epochs: 106 | # Linear warmup 107 | warmup_factor = self.last_epoch / self.warmup_epochs 108 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 109 | else: 110 | # Multi-step decay 111 | adjusted_epoch = self.last_epoch - self.warmup_epochs 112 | adjusted_milestones = [m - self.warmup_epochs for m in self.milestones if m > self.warmup_epochs] 113 | 114 | decay_factor = self.gamma ** len([m for m in adjusted_milestones if m <= adjusted_epoch]) 115 | return [base_lr * decay_factor for base_lr in self.base_lrs] 116 | 117 | 118 | class OneCycleLR(_LRScheduler): 119 | """ 120 | One Cycle learning rate policy as described in "Super-Convergence". 121 | 122 | Args: 123 | optimizer (torch.optim.Optimizer): Optimizer 124 | max_lr (float): Maximum learning rate 125 | total_steps (int): Total number of training steps 126 | pct_start (float): Percentage of cycle spent increasing learning rate 127 | anneal_strategy (str): Annealing strategy ('cos' or 'linear') 128 | div_factor (float): Determines initial learning rate (max_lr / div_factor) 129 | final_div_factor (float): Determines minimum learning rate (max_lr / final_div_factor) 130 | last_epoch (int): The index of last epoch 131 | """ 132 | 133 | def __init__(self, optimizer: torch.optim.Optimizer, max_lr: float, total_steps: int, 134 | pct_start: float = 0.3, anneal_strategy: str = 'cos', 135 | div_factor: float = 25.0, final_div_factor: float = 1e4, last_epoch: int = -1): 136 | self.max_lr = max_lr 137 | self.total_steps = total_steps 138 | self.pct_start = pct_start 139 | self.anneal_strategy = anneal_strategy 140 | self.div_factor = div_factor 141 | self.final_div_factor = final_div_factor 142 | 143 | self.initial_lr = max_lr / div_factor 144 | self.min_lr = max_lr / final_div_factor 145 | 146 | super(OneCycleLR, self).__init__(optimizer, last_epoch) 147 | 148 | def get_lr(self) -> List[float]: 149 | """Compute learning rate for current step.""" 150 | step_num = self.last_epoch 151 | 152 | if step_num <= self.pct_start * self.total_steps: 153 | # Increasing phase 154 | pct = step_num / (self.pct_start * self.total_steps) 155 | lr = self.initial_lr + pct * (self.max_lr - self.initial_lr) 156 | else: 157 | # Decreasing phase 158 | pct = (step_num - self.pct_start * self.total_steps) / ((1 - self.pct_start) * self.total_steps) 159 | 160 | if self.anneal_strategy == 'cos': 161 | lr = self.min_lr + (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * pct)) / 2 162 | else: # linear 163 | lr = self.max_lr - pct * (self.max_lr - self.min_lr) 164 | 165 | return [lr for _ in self.base_lrs] 166 | 167 | 168 | class CyclicLR(_LRScheduler): 169 | """ 170 | Cyclic learning rate scheduler. 
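    In 'triangular' mode the learning rate ramps linearly from base_lr to max_lr
    over step_size_up iterations and back to base_lr over step_size_down
    iterations; 'triangular2' additionally halves the amplitude every cycle and
    'exp_range' scales it by gamma ** iteration.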
171 | 
172 |     Args:
173 |         optimizer (torch.optim.Optimizer): Optimizer
174 |         base_lr (float): Lower boundary of learning rate
175 |         max_lr (float): Upper boundary of learning rate
176 |         step_size_up (int): Number of training iterations in increasing half of cycle
177 |         step_size_down (Optional[int]): Number of training iterations in decreasing half of cycle
178 |         mode (str): One of 'triangular', 'triangular2', 'exp_range'
179 |         gamma (float): Constant in 'exp_range' scaling function
180 |         scale_fn (Optional[callable]): Custom scaling function
181 |         scale_mode (str): 'cycle' or 'iterations'
182 |         cycle_momentum (bool): Whether to cycle momentum inversely to learning rate
183 |         base_momentum (float): Lower boundary of momentum
184 |         max_momentum (float): Upper boundary of momentum
185 |         last_epoch (int): The index of last epoch
186 |     """
187 | 
188 |     def __init__(self, optimizer: torch.optim.Optimizer, base_lr: float, max_lr: float,
189 |                  step_size_up: int = 2000, step_size_down: Optional[int] = None,
190 |                  mode: str = 'triangular', gamma: float = 1.0, scale_fn: Optional[callable] = None,
191 |                  scale_mode: str = 'cycle', cycle_momentum: bool = True,
192 |                  base_momentum: float = 0.8, max_momentum: float = 0.9, last_epoch: int = -1):
193 | 
194 |         self.base_lr = base_lr
195 |         self.max_lr = max_lr
196 |         self.step_size_up = step_size_up
197 |         self.step_size_down = step_size_down or step_size_up
198 |         self.total_size = self.step_size_up + self.step_size_down
199 |         self.mode = mode
200 |         self.gamma = gamma
201 |         self.scale_fn = scale_fn
202 |         self.scale_mode = scale_mode
203 |         self.cycle_momentum = cycle_momentum
204 |         self.base_momentum = base_momentum
205 |         self.max_momentum = max_momentum
206 | 
207 |         super(CyclicLR, self).__init__(optimizer, last_epoch)
208 | 
209 |     def get_lr(self) -> List[float]:
210 |         """Compute learning rate for current step."""
211 |         cycle = math.floor(1 + self.last_epoch / self.total_size)
212 |         x = 1 + self.last_epoch / self.total_size - cycle
213 | 
214 |         # Fraction of the way from base_lr to max_lr: 0 -> 1 over the rising half, 1 -> 0 over the falling half
215 |         if x <= self.step_size_up / self.total_size:
216 |             scale_factor = x / (self.step_size_up / self.total_size)
217 |         else:
218 |             scale_factor = (1 - x) / (self.step_size_down / self.total_size)
219 | 
220 |         # Apply scaling based on mode
221 |         if self.scale_fn is None:
222 |             if self.mode == 'triangular':
223 |                 scale = 1.0
224 |             elif self.mode == 'triangular2':
225 |                 scale = 1.0 / (2 ** (cycle - 1))
226 |             elif self.mode == 'exp_range':
227 |                 scale = self.gamma ** self.last_epoch
228 |             else:
229 |                 raise ValueError(f"Unsupported cyclic mode: {self.mode}")
230 |         else:
231 |             scale = self.scale_fn(self.last_epoch if self.scale_mode == 'iterations' else cycle)
232 | 
233 |         lrs = [self.base_lr + (self.max_lr - self.base_lr) * scale_factor * scale
234 |                for _ in self.base_lrs]
235 |         return lrs
236 | 
237 | 
238 | def get_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> Optional[_LRScheduler]:
239 |     """
240 |     Factory function to create learning rate scheduler based on configuration.
241 | 242 | Args: 243 | optimizer (torch.optim.Optimizer): Optimizer 244 | scheduler_config (Dict[str, Any]): Scheduler configuration 245 | 246 | Returns: 247 | Optional[_LRScheduler]: Learning rate scheduler or None 248 | """ 249 | if not scheduler_config or scheduler_config.get('type') is None: 250 | return None 251 | 252 | scheduler_type = scheduler_config['type'].lower() 253 | 254 | if scheduler_type == 'cosine': 255 | T_max = scheduler_config.get('T_max', 100) 256 | eta_min = scheduler_config.get('min_lr', 0) 257 | warmup_epochs = scheduler_config.get('warmup_epochs', 0) 258 | 259 | return CosineAnnealingWithWarmup( 260 | optimizer, T_max=T_max, eta_min=eta_min, warmup_epochs=warmup_epochs 261 | ) 262 | 263 | elif scheduler_type == 'polynomial': 264 | total_epochs = scheduler_config.get('total_epochs', 100) 265 | power = scheduler_config.get('power', 0.9) 266 | warmup_epochs = scheduler_config.get('warmup_epochs', 0) 267 | 268 | return PolynomialLR( 269 | optimizer, total_epochs=total_epochs, power=power, warmup_epochs=warmup_epochs 270 | ) 271 | 272 | elif scheduler_type == 'step': 273 | step_size = scheduler_config.get('step_size', 30) 274 | gamma = scheduler_config.get('gamma', 0.1) 275 | 276 | return optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 277 | 278 | elif scheduler_type == 'multistep': 279 | milestones = scheduler_config.get('milestones', [60, 80]) 280 | gamma = scheduler_config.get('gamma', 0.1) 281 | warmup_epochs = scheduler_config.get('warmup_epochs', 0) 282 | 283 | if warmup_epochs > 0: 284 | return WarmupMultiStepLR( 285 | optimizer, milestones=milestones, gamma=gamma, warmup_epochs=warmup_epochs 286 | ) 287 | else: 288 | return optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma) 289 | 290 | elif scheduler_type == 'exponential': 291 | gamma = scheduler_config.get('gamma', 0.95) 292 | return optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) 293 | 294 | elif scheduler_type == 'reduce_on_plateau': 295 | mode = scheduler_config.get('mode', 'min') 296 | factor = scheduler_config.get('factor', 0.5) 297 | patience = scheduler_config.get('patience', 10) 298 | threshold = scheduler_config.get('threshold', 1e-4) 299 | 300 | return optim.lr_scheduler.ReduceLROnPlateau( 301 | optimizer, mode=mode, factor=factor, patience=patience, threshold=threshold 302 | ) 303 | 304 | elif scheduler_type == 'one_cycle': 305 | max_lr = scheduler_config.get('max_lr', 0.1) 306 | total_steps = scheduler_config.get('total_steps', 1000) 307 | pct_start = scheduler_config.get('pct_start', 0.3) 308 | anneal_strategy = scheduler_config.get('anneal_strategy', 'cos') 309 | div_factor = scheduler_config.get('div_factor', 25.0) 310 | final_div_factor = scheduler_config.get('final_div_factor', 1e4) 311 | 312 | return OneCycleLR( 313 | optimizer, max_lr=max_lr, total_steps=total_steps, pct_start=pct_start, 314 | anneal_strategy=anneal_strategy, div_factor=div_factor, final_div_factor=final_div_factor 315 | ) 316 | 317 | elif scheduler_type == 'cyclic': 318 | base_lr = scheduler_config.get('base_lr', 0.001) 319 | max_lr = scheduler_config.get('max_lr', 0.006) 320 | step_size_up = scheduler_config.get('step_size_up', 2000) 321 | mode = scheduler_config.get('mode', 'triangular') 322 | gamma = scheduler_config.get('gamma', 1.0) 323 | 324 | return CyclicLR( 325 | optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=step_size_up, 326 | mode=mode, gamma=gamma 327 | ) 328 | 329 | else: 330 | raise ValueError(f"Unsupported scheduler type: {scheduler_type}") 
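# Example (illustrative): a scheduler section as it might appear in one of the
# YAML configs, shown here as the equivalent Python dict. Any key other than
# 'type' is optional and falls back to the defaults used above.
#
#     scheduler_config = {
#         'type': 'cosine',
#         'T_max': 100,
#         'min_lr': 1.0e-6,
#         'warmup_epochs': 5,
#     }
#     scheduler = get_scheduler(optimizer, scheduler_config)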
331 | 332 | 333 | class WarmupScheduler: 334 | """ 335 | Wrapper for adding warmup to any scheduler. 336 | 337 | Args: 338 | optimizer (torch.optim.Optimizer): Optimizer 339 | scheduler (_LRScheduler): Base scheduler 340 | warmup_epochs (int): Number of warmup epochs 341 | warmup_method (str): Warmup method ('linear' or 'constant') 342 | warmup_factor (float): Warmup factor for 'constant' method 343 | """ 344 | 345 | def __init__(self, optimizer: torch.optim.Optimizer, scheduler: _LRScheduler, 346 | warmup_epochs: int, warmup_method: str = 'linear', warmup_factor: float = 0.1): 347 | self.optimizer = optimizer 348 | self.scheduler = scheduler 349 | self.warmup_epochs = warmup_epochs 350 | self.warmup_method = warmup_method 351 | self.warmup_factor = warmup_factor 352 | self.base_lrs = [group['lr'] for group in optimizer.param_groups] 353 | self.last_epoch = 0 354 | 355 | def step(self, epoch: Optional[int] = None): 356 | """Step the scheduler.""" 357 | if epoch is None: 358 | epoch = self.last_epoch + 1 359 | self.last_epoch = epoch 360 | 361 | if epoch < self.warmup_epochs: 362 | # Warmup phase 363 | if self.warmup_method == 'linear': 364 | warmup_factor = epoch / self.warmup_epochs 365 | else: # constant 366 | warmup_factor = self.warmup_factor 367 | 368 | for i, param_group in enumerate(self.optimizer.param_groups): 369 | param_group['lr'] = self.base_lrs[i] * warmup_factor 370 | else: 371 | # Normal scheduling 372 | self.scheduler.step(epoch - self.warmup_epochs) 373 | 374 | def state_dict(self): 375 | """Return state dict.""" 376 | return { 377 | 'scheduler': self.scheduler.state_dict(), 378 | 'last_epoch': self.last_epoch, 379 | 'warmup_epochs': self.warmup_epochs, 380 | 'warmup_method': self.warmup_method, 381 | 'warmup_factor': self.warmup_factor, 382 | 'base_lrs': self.base_lrs 383 | } 384 | 385 | def load_state_dict(self, state_dict): 386 | """Load state dict.""" 387 | self.scheduler.load_state_dict(state_dict['scheduler']) 388 | self.last_epoch = state_dict['last_epoch'] 389 | self.warmup_epochs = state_dict['warmup_epochs'] 390 | self.warmup_method = state_dict['warmup_method'] 391 | self.warmup_factor = state_dict['warmup_factor'] 392 | self.base_lrs = state_dict['base_lrs'] 393 | 394 | 395 | if __name__ == "__main__": 396 | # Test schedulers 397 | import matplotlib 398 | matplotlib.use('Agg') 399 | import matplotlib.pyplot as plt 400 | 401 | # Create dummy optimizer 402 | model = torch.nn.Linear(10, 1) 403 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 404 | 405 | # Test different schedulers 406 | schedulers = { 407 | 'Cosine with Warmup': CosineAnnealingWithWarmup(optimizer, T_max=100, warmup_epochs=10), 408 | 'Polynomial': PolynomialLR(optimizer, total_epochs=100, power=0.9, warmup_epochs=10), 409 | 'One Cycle': OneCycleLR(optimizer, max_lr=0.1, total_steps=100), 410 | } 411 | 412 | # Plot learning rate schedules 413 | fig, axes = plt.subplots(1, len(schedulers), figsize=(15, 5)) 414 | if len(schedulers) == 1: 415 | axes = [axes] 416 | 417 | for i, (name, scheduler) in enumerate(schedulers.items()): 418 | lrs = [] 419 | # Reset optimizer 420 | for param_group in optimizer.param_groups: 421 | param_group['lr'] = 0.1 422 | scheduler.last_epoch = -1 423 | 424 | for epoch in range(100): 425 | scheduler.step() 426 | lrs.append(optimizer.param_groups[0]['lr']) 427 | 428 | axes[i].plot(lrs) 429 | axes[i].set_title(name) 430 | axes[i].set_xlabel('Epoch') 431 | axes[i].set_ylabel('Learning Rate') 432 | axes[i].grid(True) 433 | 434 | plt.tight_layout() 435 | 
plt.savefig('/workspace/semantic_segmentation_project/scheduler_comparison.png') 436 | print("Scheduler comparison plot saved!") 437 | 438 | print("Scheduler module test completed successfully!") 439 | -------------------------------------------------------------------------------- /utils/evaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Evaluator for Semantic Segmentation 3 | 4 | This module contains comprehensive evaluation utilities including 5 | model testing, inference with TTA, and detailed analysis. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | import time 14 | import os 15 | from typing import Dict, List, Optional, Tuple, Any, Union 16 | from tqdm import tqdm 17 | import json 18 | import cv2 19 | from PIL import Image 20 | 21 | from .metrics import SegmentationMetrics, StreamingMetrics 22 | from .visualization import visualize_prediction, apply_color_map, create_color_palette 23 | 24 | 25 | class SegmentationEvaluator: 26 | """ 27 | Comprehensive evaluator for semantic segmentation models. 28 | 29 | Args: 30 | model (nn.Module): Segmentation model 31 | device (str): Device for evaluation 32 | num_classes (int): Number of classes 33 | ignore_index (int): Index to ignore in evaluation 34 | class_names (Optional[List[str]]): Names of classes 35 | class_colors (Optional[List[List[int]]]): Colors for each class 36 | """ 37 | 38 | def __init__(self, model: nn.Module, device: str = 'cuda', 39 | num_classes: int = 19, ignore_index: int = 255, 40 | class_names: Optional[List[str]] = None, 41 | class_colors: Optional[List[List[int]]] = None): 42 | self.model = model 43 | self.device = device 44 | self.num_classes = num_classes 45 | self.ignore_index = ignore_index 46 | self.class_names = class_names or [f"Class {i}" for i in range(num_classes)] 47 | 48 | # Create color palette 49 | if class_colors is not None: 50 | self.color_palette = np.array(class_colors, dtype=np.uint8) 51 | else: 52 | self.color_palette = create_color_palette(num_classes) 53 | 54 | # Initialize metrics 55 | self.metrics = SegmentationMetrics(num_classes, ignore_index, device) 56 | 57 | def evaluate_dataset(self, data_loader: DataLoader, 58 | use_tta: bool = False, 59 | tta_scales: List[float] = [0.75, 1.0, 1.25], 60 | tta_flip: bool = True, 61 | save_predictions: bool = False, 62 | output_dir: str = "eval_results") -> Dict[str, Any]: 63 | """ 64 | Evaluate model on a dataset. 
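        Example (illustrative; `model` and `val_loader` are placeholders):

            evaluator = SegmentationEvaluator(model, device='cuda', num_classes=19)
            results = evaluator.evaluate_dataset(val_loader, use_tta=True,
                                                 save_predictions=True,
                                                 output_dir='eval_results')
            print(results['mIoU'], results['fps'])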
65 | 66 | Args: 67 | data_loader (DataLoader): Data loader for evaluation 68 | use_tta (bool): Whether to use Test Time Augmentation 69 | tta_scales (List[float]): Scales for TTA 70 | tta_flip (bool): Whether to use horizontal flip in TTA 71 | save_predictions (bool): Whether to save predictions 72 | output_dir (str): Output directory for results 73 | 74 | Returns: 75 | Dict[str, Any]: Evaluation results 76 | """ 77 | self.model.eval() 78 | self.metrics.reset() 79 | 80 | if save_predictions: 81 | os.makedirs(output_dir, exist_ok=True) 82 | pred_dir = os.path.join(output_dir, 'predictions') 83 | vis_dir = os.path.join(output_dir, 'visualizations') 84 | os.makedirs(pred_dir, exist_ok=True) 85 | os.makedirs(vis_dir, exist_ok=True) 86 | 87 | total_time = 0 88 | num_samples = 0 89 | 90 | with torch.no_grad(): 91 | pbar = tqdm(data_loader, desc='Evaluating') 92 | 93 | for batch_idx, batch in enumerate(pbar): 94 | images = batch['image'].to(self.device, non_blocking=True) 95 | labels = batch['label'].to(self.device, non_blocking=True) 96 | 97 | start_time = time.time() 98 | 99 | if use_tta: 100 | predictions = self._predict_with_tta(images, tta_scales, tta_flip) 101 | else: 102 | outputs = self.model(images) 103 | if isinstance(outputs, dict): 104 | predictions = outputs['out'] 105 | else: 106 | predictions = outputs 107 | 108 | inference_time = time.time() - start_time 109 | total_time += inference_time 110 | num_samples += len(images) 111 | 112 | # Update metrics 113 | self.metrics.update(predictions, labels) 114 | 115 | # Save predictions and visualizations 116 | if save_predictions: 117 | self._save_batch_predictions( 118 | batch, predictions, batch_idx, pred_dir, vis_dir 119 | ) 120 | 121 | # Update progress bar 122 | current_metrics = self.metrics.compute() 123 | pbar.set_postfix({ 124 | 'mIoU': f'{current_metrics["mIoU"]:.4f}', 125 | 'Pixel Acc': f'{current_metrics["Pixel_Accuracy"]:.4f}' 126 | }) 127 | 128 | # Compute final metrics 129 | final_metrics = self.metrics.compute() 130 | 131 | # Add timing information 132 | final_metrics['inference_time_per_sample'] = total_time / num_samples 133 | final_metrics['fps'] = num_samples / total_time 134 | 135 | # Save metrics 136 | if save_predictions: 137 | self._save_evaluation_results(final_metrics, output_dir) 138 | 139 | return final_metrics 140 | 141 | def _predict_with_tta(self, images: torch.Tensor, 142 | scales: List[float], use_flip: bool) -> torch.Tensor: 143 | """ 144 | Perform prediction with Test Time Augmentation. 
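        Logits are accumulated over every scale (and, when enabled, the
        horizontally flipped version of each scale) after being resized back to
        the input resolution; the mean of the accumulated logits is returned.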
145 | 146 | Args: 147 | images (torch.Tensor): Input images 148 | scales (List[float]): Scale factors 149 | use_flip (bool): Whether to use horizontal flip 150 | 151 | Returns: 152 | torch.Tensor: Averaged predictions 153 | """ 154 | b, c, h, w = images.shape 155 | 156 | # Initialize aggregated predictions 157 | aggregated_preds = torch.zeros(b, self.num_classes, h, w, device=self.device) 158 | num_predictions = 0 159 | 160 | for scale in scales: 161 | # Resize images 162 | scaled_h, scaled_w = int(h * scale), int(w * scale) 163 | scaled_images = F.interpolate( 164 | images, size=(scaled_h, scaled_w), 165 | mode='bilinear', align_corners=False 166 | ) 167 | 168 | # Normal prediction 169 | outputs = self.model(scaled_images) 170 | if isinstance(outputs, dict): 171 | preds = outputs['out'] 172 | else: 173 | preds = outputs 174 | 175 | # Resize back to original size 176 | preds = F.interpolate( 177 | preds, size=(h, w), 178 | mode='bilinear', align_corners=False 179 | ) 180 | aggregated_preds += preds 181 | num_predictions += 1 182 | 183 | # Flipped prediction 184 | if use_flip: 185 | flipped_images = torch.flip(scaled_images, dims=[3]) 186 | outputs = self.model(flipped_images) 187 | if isinstance(outputs, dict): 188 | preds = outputs['out'] 189 | else: 190 | preds = outputs 191 | 192 | # Flip back and resize 193 | preds = torch.flip(preds, dims=[3]) 194 | preds = F.interpolate( 195 | preds, size=(h, w), 196 | mode='bilinear', align_corners=False 197 | ) 198 | aggregated_preds += preds 199 | num_predictions += 1 200 | 201 | # Average predictions 202 | averaged_preds = aggregated_preds / num_predictions 203 | 204 | return averaged_preds 205 | 206 | def _save_batch_predictions(self, batch: Dict[str, torch.Tensor], 207 | predictions: torch.Tensor, batch_idx: int, 208 | pred_dir: str, vis_dir: str) -> None: 209 | """Save batch predictions and visualizations.""" 210 | batch_size = len(batch['image']) 211 | 212 | for i in range(batch_size): 213 | sample_idx = batch_idx * batch_size + i 214 | 215 | # Get data 216 | image = batch['image'][i].cpu() 217 | label = batch['label'][i].cpu() 218 | pred = torch.argmax(predictions[i], dim=0).cpu() 219 | confidence = torch.max(torch.softmax(predictions[i], dim=0), dim=0)[0].cpu() 220 | 221 | # Save prediction mask 222 | pred_path = os.path.join(pred_dir, f'prediction_{sample_idx:06d}.png') 223 | pred_image = Image.fromarray(pred.numpy().astype(np.uint8)) 224 | pred_image.save(pred_path) 225 | 226 | # Save visualization 227 | vis_path = os.path.join(vis_dir, f'visualization_{sample_idx:06d}.png') 228 | visualize_prediction( 229 | image=image, 230 | label=label, 231 | prediction=pred, 232 | confidence=confidence, 233 | color_palette=self.color_palette, 234 | save_path=vis_path 235 | ) 236 | 237 | def _save_evaluation_results(self, metrics: Dict[str, float], output_dir: str) -> None: 238 | """Save evaluation results to files.""" 239 | # Save metrics as JSON 240 | metrics_path = os.path.join(output_dir, 'metrics.json') 241 | with open(metrics_path, 'w') as f: 242 | json.dump(metrics, f, indent=2) 243 | 244 | # Save detailed per-class metrics 245 | detailed_metrics = {} 246 | for i in range(self.num_classes): 247 | class_name = self.class_names[i] if i < len(self.class_names) else f"Class {i}" 248 | detailed_metrics[class_name] = { 249 | 'IoU': metrics.get(f'IoU_Class_{i}', 0.0), 250 | 'Dice': metrics.get(f'Dice_Class_{i}', 0.0), 251 | 'Precision': metrics.get(f'Precision_Class_{i}', 0.0), 252 | 'Recall': metrics.get(f'Recall_Class_{i}', 0.0), 253 | 'F1': 
metrics.get(f'F1_Class_{i}', 0.0) 254 | } 255 | 256 | detailed_path = os.path.join(output_dir, 'detailed_metrics.json') 257 | with open(detailed_path, 'w') as f: 258 | json.dump(detailed_metrics, f, indent=2) 259 | 260 | # Save confusion matrix 261 | confusion_matrix = self.metrics.get_confusion_matrix() 262 | np.save(os.path.join(output_dir, 'confusion_matrix.npy'), confusion_matrix) 263 | 264 | print(f"Evaluation results saved to: {output_dir}") 265 | 266 | def evaluate_single_image(self, image_path: str, 267 | use_tta: bool = False, 268 | save_result: bool = True, 269 | output_path: Optional[str] = None) -> Dict[str, Any]: 270 | """ 271 | Evaluate model on a single image. 272 | 273 | Args: 274 | image_path (str): Path to input image 275 | use_tta (bool): Whether to use TTA 276 | save_result (bool): Whether to save result 277 | output_path (Optional[str]): Output path for result 278 | 279 | Returns: 280 | Dict[str, Any]: Prediction results 281 | """ 282 | self.model.eval() 283 | 284 | # Load and preprocess image 285 | image = Image.open(image_path).convert('RGB') 286 | original_size = image.size 287 | 288 | # Convert to tensor (assuming normalization is handled in transforms) 289 | image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0 290 | image_tensor = image_tensor.unsqueeze(0).to(self.device) 291 | 292 | with torch.no_grad(): 293 | start_time = time.time() 294 | 295 | if use_tta: 296 | predictions = self._predict_with_tta( 297 | image_tensor, scales=[0.75, 1.0, 1.25], use_flip=True 298 | ) 299 | else: 300 | outputs = self.model(image_tensor) 301 | if isinstance(outputs, dict): 302 | predictions = outputs['out'] 303 | else: 304 | predictions = outputs 305 | 306 | inference_time = time.time() - start_time 307 | 308 | # Process predictions 309 | pred_mask = torch.argmax(predictions[0], dim=0).cpu().numpy() 310 | confidence_map = torch.max(torch.softmax(predictions[0], dim=0), dim=0)[0].cpu().numpy() 311 | 312 | # Create colored prediction 313 | colored_pred = apply_color_map(pred_mask, self.color_palette, self.ignore_index) 314 | 315 | results = { 316 | 'prediction_mask': pred_mask, 317 | 'confidence_map': confidence_map, 318 | 'colored_prediction': colored_pred, 319 | 'inference_time': inference_time, 320 | 'original_size': original_size 321 | } 322 | 323 | # Save results 324 | if save_result: 325 | if output_path is None: 326 | base_name = os.path.splitext(os.path.basename(image_path))[0] 327 | output_path = f"{base_name}_prediction.png" 328 | 329 | # Save colored prediction 330 | colored_pred_image = Image.fromarray(colored_pred) 331 | colored_pred_image.save(output_path) 332 | 333 | # Save raw prediction mask 334 | mask_path = output_path.replace('.png', '_mask.png') 335 | mask_image = Image.fromarray(pred_mask.astype(np.uint8)) 336 | mask_image.save(mask_path) 337 | 338 | print(f"Results saved to: {output_path}") 339 | 340 | return results 341 | 342 | def benchmark_model(self, data_loader: DataLoader, 343 | num_warmup: int = 10, 344 | num_iterations: int = 100) -> Dict[str, float]: 345 | """ 346 | Benchmark model performance. 
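Example (illustrative sketch; ``evaluator`` and ``val_loader`` are assumed to be constructed elsewhere and are not defined in this module):

    results = evaluator.benchmark_model(val_loader, num_warmup=10, num_iterations=100)
    # Keys match the dict assembled at the end of this method
    print(f"{results['fps']:.1f} samples/s, {results['avg_batch_time'] * 1000:.1f} ms/batch")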
347 | 348 | Args: 349 | data_loader (DataLoader): Data loader for benchmarking 350 | num_warmup (int): Number of warmup iterations 351 | num_iterations (int): Number of benchmark iterations 352 | 353 | Returns: 354 | Dict[str, float]: Benchmark results 355 | """ 356 | self.model.eval() 357 | 358 | # Warmup 359 | print("Warming up...") 360 | with torch.no_grad(): 361 | for i, batch in enumerate(data_loader): 362 | if i >= num_warmup: 363 | break 364 | 365 | images = batch['image'].to(self.device, non_blocking=True) 366 | outputs = self.model(images) 367 | 368 | if self.device == 'cuda': 369 | torch.cuda.synchronize() 370 | 371 | # Benchmark 372 | print("Benchmarking...") 373 | times = [] 374 | 375 | with torch.no_grad(): 376 | for i, batch in enumerate(data_loader): 377 | if i >= num_iterations: 378 | break 379 | 380 | images = batch['image'].to(self.device, non_blocking=True) 381 | 382 | if self.device == 'cuda': 383 | torch.cuda.synchronize() 384 | 385 | start_time = time.time() 386 | outputs = self.model(images) 387 | 388 | if self.device == 'cuda': 389 | torch.cuda.synchronize() 390 | 391 | end_time = time.time() 392 | times.append(end_time - start_time) 393 | 394 | # Compute statistics 395 | times = np.array(times) 396 | batch_size = len(batch['image']) 397 | 398 | results = { 399 | 'avg_batch_time': float(np.mean(times)), 400 | 'std_batch_time': float(np.std(times)), 401 | 'min_batch_time': float(np.min(times)), 402 | 'max_batch_time': float(np.max(times)), 403 | 'avg_sample_time': float(np.mean(times) / batch_size), 404 | 'fps': float(batch_size / np.mean(times)), 405 | 'throughput_samples_per_sec': float(num_iterations * batch_size / np.sum(times)) 406 | } 407 | 408 | return results 409 | 410 | def analyze_failure_cases(self, data_loader: DataLoader, 411 | iou_threshold: float = 0.3, 412 | max_cases: int = 50, 413 | output_dir: str = "failure_analysis") -> List[Dict[str, Any]]: 414 | """ 415 | Analyze failure cases where model performs poorly. 
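Example (illustrative sketch; ``evaluator`` and ``val_loader`` are assumed objects, not defined in this module):

    cases = evaluator.analyze_failure_cases(val_loader, iou_threshold=0.3, max_cases=20, output_dir='failure_analysis')
    # Each entry records per-sample mIoU plus batch/sample indices
    worst = sorted(cases, key=lambda c: c['miou'])[:5]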
416 | 417 | Args: 418 | data_loader (DataLoader): Data loader 419 | iou_threshold (float): IoU threshold below which samples are considered failures 420 | max_cases (int): Maximum number of failure cases to analyze 421 | output_dir (str): Output directory for analysis 422 | 423 | Returns: 424 | List[Dict[str, Any]]: List of failure case information 425 | """ 426 | self.model.eval() 427 | failure_cases = [] 428 | 429 | os.makedirs(output_dir, exist_ok=True) 430 | 431 | with torch.no_grad(): 432 | pbar = tqdm(data_loader, desc='Analyzing failure cases') 433 | 434 | for batch_idx, batch in enumerate(pbar): 435 | if len(failure_cases) >= max_cases: 436 | break 437 | 438 | images = batch['image'].to(self.device, non_blocking=True) 439 | labels = batch['label'].to(self.device, non_blocking=True) 440 | 441 | outputs = self.model(images) 442 | if isinstance(outputs, dict): 443 | predictions = outputs['out'] 444 | else: 445 | predictions = outputs 446 | 447 | # Compute per-sample IoU 448 | batch_size = len(images) 449 | for i in range(batch_size): 450 | sample_pred = predictions[i:i+1] 451 | sample_label = labels[i:i+1] 452 | 453 | # Compute sample metrics 454 | sample_metrics = SegmentationMetrics(self.num_classes, self.ignore_index, self.device) 455 | sample_metrics.update(sample_pred, sample_label) 456 | metrics = sample_metrics.compute() 457 | 458 | if metrics['mIoU'] < iou_threshold: 459 | # Save failure case 460 | sample_idx = len(failure_cases) 461 | 462 | case_info = { 463 | 'sample_index': sample_idx, 464 | 'batch_index': batch_idx, 465 | 'sample_in_batch': i, 466 | 'miou': metrics['mIoU'], 467 | 'pixel_accuracy': metrics['Pixel_Accuracy'], 468 | 'image_path': batch.get('image_path', [''])[i], 469 | 'label_path': batch.get('label_path', [''])[i] 470 | } 471 | 472 | # Save visualization 473 | vis_path = os.path.join(output_dir, f'failure_case_{sample_idx:03d}.png') 474 | visualize_prediction( 475 | image=images[i].cpu(), 476 | label=labels[i].cpu(), 477 | prediction=torch.argmax(predictions[i], dim=0).cpu(), 478 | confidence=torch.max(torch.softmax(predictions[i], dim=0), dim=0)[0].cpu(), 479 | color_palette=self.color_palette, 480 | save_path=vis_path 481 | ) 482 | 483 | failure_cases.append(case_info) 484 | 485 | # Save failure case summary 486 | summary_path = os.path.join(output_dir, 'failure_cases_summary.json') 487 | with open(summary_path, 'w') as f: 488 | json.dump(failure_cases, f, indent=2) 489 | 490 | print(f"Found {len(failure_cases)} failure cases. Analysis saved to: {output_dir}") 491 | 492 | return failure_cases 493 | 494 | 495 | if __name__ == "__main__": 496 | # Test evaluator functionality 497 | print("Testing evaluator module...") 498 | 499 | # This would require actual model, data loaders, etc. 500 | # For now, just test imports 501 | print("Evaluator module loaded successfully!") 502 | -------------------------------------------------------------------------------- /models/transformer_adapter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer Adapter Module for PFM Segmentation 3 | 4 | This module implements a Transformer Adapter inspired by DINOv2's architecture. 5 | The Transformer Adapter adds extra Vision Blocks (Transformer layers) after 6 | the frozen ViT encoder to enable efficient fine-tuning for segmentation tasks. 7 | 8 | The strategy: 9 | 1. Freeze the original PFM (ViT encoder) parameters 10 | 2. Add trainable Vision Blocks after the ViT encoder 11 | 3. 
Train only: Vision Blocks + Decoder + Segmentation Head 12 | 13 | Reference: DINOv2 Vision Transformer architecture 14 | 15 | Author: @chenwm 16 | """ 17 | 18 | import math 19 | from typing import Optional, Dict, Any, Tuple, List 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from timm.layers import use_fused_attn, DropPath, Mlp 25 | 26 | 27 | class VisionBlockAttention(nn.Module): 28 | """ 29 | Multi-head Self-Attention for Vision Block. 30 | 31 | Follows the standard ViT attention mechanism with optional QK normalization. 32 | 33 | Args: 34 | dim (int): Input dimension 35 | num_heads (int): Number of attention heads 36 | qkv_bias (bool): Whether to use bias in QKV projection 37 | qk_norm (bool): Whether to apply layer norm to Q and K 38 | attn_drop (float): Dropout rate for attention weights 39 | proj_drop (float): Dropout rate for output projection 40 | norm_layer (nn.Module): Normalization layer class 41 | """ 42 | 43 | def __init__( 44 | self, 45 | dim: int, 46 | num_heads: int = 8, 47 | qkv_bias: bool = True, 48 | qk_norm: bool = False, 49 | attn_drop: float = 0., 50 | proj_drop: float = 0., 51 | norm_layer: nn.Module = nn.LayerNorm, 52 | ) -> None: 53 | super().__init__() 54 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 55 | self.num_heads = num_heads 56 | self.head_dim = dim // num_heads 57 | self.scale = self.head_dim ** -0.5 58 | self.fused_attn = use_fused_attn() 59 | 60 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 61 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 62 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 63 | self.attn_drop = nn.Dropout(attn_drop) 64 | self.proj = nn.Linear(dim, dim) 65 | self.proj_drop = nn.Dropout(proj_drop) 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | B, N, C = x.shape 69 | 70 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 71 | q, k, v = qkv.unbind(0) 72 | q, k = self.q_norm(q), self.k_norm(k) 73 | 74 | if self.fused_attn: 75 | x = F.scaled_dot_product_attention( 76 | q, k, v, 77 | dropout_p=self.attn_drop.p if self.training else 0., 78 | ) 79 | else: 80 | q = q * self.scale 81 | attn = q @ k.transpose(-2, -1) 82 | attn = attn.softmax(dim=-1) 83 | attn = self.attn_drop(attn) 84 | x = attn @ v 85 | 86 | x = x.transpose(1, 2).reshape(B, N, C) 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | 91 | 92 | class VisionBlock(nn.Module): 93 | """ 94 | Vision Block (Transformer Block) for adapter. 95 | 96 | Standard Transformer block with: 97 | - Pre-normalization 98 | - Multi-head Self-Attention 99 | - MLP (Feed-Forward Network) 100 | - Residual connections 101 | - Optional DropPath for regularization 102 | 103 | This follows the DINOv2/ViT-style architecture. 
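Example (shape-preserving sanity check; the sizes below are arbitrary):

    >>> import torch
    >>> block = VisionBlock(dim=384, num_heads=6)
    >>> tokens = torch.randn(2, 197, 384)   # (B, N, dim)
    >>> block(tokens).shape
    torch.Size([2, 197, 384])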
104 | 105 | Args: 106 | dim (int): Input/output dimension 107 | num_heads (int): Number of attention heads 108 | mlp_ratio (float): Ratio of MLP hidden dim to embedding dim 109 | qkv_bias (bool): Whether to use bias in QKV projection 110 | qk_norm (bool): Whether to apply layer norm to Q and K 111 | drop (float): Dropout rate 112 | attn_drop (float): Attention dropout rate 113 | drop_path (float): Stochastic depth rate 114 | init_values (float): Initial value for layer scale (None to disable) 115 | act_layer (nn.Module): Activation layer class 116 | norm_layer (nn.Module): Normalization layer class 117 | mlp_layer (nn.Module): MLP layer class 118 | """ 119 | 120 | def __init__( 121 | self, 122 | dim: int, 123 | num_heads: int = 8, 124 | mlp_ratio: float = 4., 125 | qkv_bias: bool = True, 126 | qk_norm: bool = False, 127 | drop: float = 0., 128 | attn_drop: float = 0., 129 | drop_path: float = 0., 130 | init_values: Optional[float] = None, 131 | act_layer: nn.Module = nn.GELU, 132 | norm_layer: nn.Module = nn.LayerNorm, 133 | mlp_layer: nn.Module = Mlp, 134 | ) -> None: 135 | super().__init__() 136 | 137 | # Pre-normalization 138 | self.norm1 = norm_layer(dim) 139 | 140 | # Self-attention 141 | self.attn = VisionBlockAttention( 142 | dim=dim, 143 | num_heads=num_heads, 144 | qkv_bias=qkv_bias, 145 | qk_norm=qk_norm, 146 | attn_drop=attn_drop, 147 | proj_drop=drop, 148 | norm_layer=norm_layer, 149 | ) 150 | 151 | # Layer scale for attention 152 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 153 | 154 | # Drop path for stochastic depth 155 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 156 | 157 | # Pre-normalization for MLP 158 | self.norm2 = norm_layer(dim) 159 | 160 | # MLP (Feed-Forward Network) 161 | mlp_hidden_dim = int(dim * mlp_ratio) 162 | self.mlp = mlp_layer( 163 | in_features=dim, 164 | hidden_features=mlp_hidden_dim, 165 | act_layer=act_layer, 166 | drop=drop, 167 | ) 168 | 169 | # Layer scale for MLP 170 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 171 | 172 | # Drop path for stochastic depth 173 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 174 | 175 | def forward(self, x: torch.Tensor) -> torch.Tensor: 176 | # Self-attention with residual 177 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 178 | # MLP with residual 179 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 180 | return x 181 | 182 | 183 | class LayerScale(nn.Module): 184 | """ 185 | Layer Scale module for improved training stability. 186 | 187 | Reference: "Going deeper with Image Transformers" (CaiT paper) 188 | 189 | Args: 190 | dim (int): Dimension of the input 191 | init_values (float): Initial value for the scale parameters 192 | """ 193 | 194 | def __init__(self, dim: int, init_values: float = 1e-5) -> None: 195 | super().__init__() 196 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 197 | 198 | def forward(self, x: torch.Tensor) -> torch.Tensor: 199 | return x * self.gamma 200 | 201 | 202 | class TransformerAdapter(nn.Module): 203 | """ 204 | Transformer Adapter: A stack of Vision Blocks added after the frozen ViT encoder. 205 | 206 | This adapter processes the features from the frozen PFM encoder and learns 207 | task-specific representations through additional transformer layers. 
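Example (standalone sanity check; the dimensions below are arbitrary and not tied to any particular PFM):

    >>> import torch
    >>> adapter = TransformerAdapter(dim=768, depth=2, num_heads=12)
    >>> tokens = torch.randn(2, 261, 768)   # e.g. CLS + register + patch tokens
    >>> adapter(tokens).shape               # token count and dim are preserved
    torch.Size([2, 261, 768])
    >>> all(p.requires_grad for p in adapter.parameters())
    True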
208 | 209 | Args: 210 | dim (int): Input/output embedding dimension (must match PFM output) 211 | depth (int): Number of Vision Blocks 212 | num_heads (int): Number of attention heads 213 | mlp_ratio (float): Ratio of MLP hidden dim to embedding dim 214 | qkv_bias (bool): Whether to use bias in QKV projection 215 | qk_norm (bool): Whether to apply layer norm to Q and K 216 | drop_rate (float): Dropout rate 217 | attn_drop_rate (float): Attention dropout rate 218 | drop_path_rate (float): Stochastic depth rate 219 | init_values (float): Initial value for layer scale 220 | act_layer (nn.Module): Activation layer class 221 | norm_layer (nn.Module): Normalization layer class 222 | """ 223 | 224 | def __init__( 225 | self, 226 | dim: int, 227 | depth: int = 4, 228 | num_heads: int = 8, 229 | mlp_ratio: float = 4., 230 | qkv_bias: bool = True, 231 | qk_norm: bool = False, 232 | drop_rate: float = 0., 233 | attn_drop_rate: float = 0., 234 | drop_path_rate: float = 0., 235 | init_values: Optional[float] = 1e-5, 236 | act_layer: nn.Module = nn.GELU, 237 | norm_layer: nn.Module = nn.LayerNorm, 238 | ) -> None: 239 | super().__init__() 240 | self.dim = dim 241 | self.depth = depth 242 | self.num_heads = num_heads 243 | 244 | # Stochastic depth decay rule 245 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 246 | 247 | # Stack of Vision Blocks 248 | self.blocks = nn.ModuleList([ 249 | VisionBlock( 250 | dim=dim, 251 | num_heads=num_heads, 252 | mlp_ratio=mlp_ratio, 253 | qkv_bias=qkv_bias, 254 | qk_norm=qk_norm, 255 | drop=drop_rate, 256 | attn_drop=attn_drop_rate, 257 | drop_path=dpr[i], 258 | init_values=init_values, 259 | act_layer=act_layer, 260 | norm_layer=norm_layer, 261 | ) 262 | for i in range(depth) 263 | ]) 264 | 265 | # Final normalization 266 | self.norm = norm_layer(dim) 267 | 268 | # Initialize weights 269 | self.apply(self._init_weights) 270 | 271 | def _init_weights(self, m: nn.Module) -> None: 272 | """Initialize weights following ViT conventions.""" 273 | if isinstance(m, nn.Linear): 274 | # Use truncated normal initialization 275 | nn.init.trunc_normal_(m.weight, std=0.02) 276 | if m.bias is not None: 277 | nn.init.zeros_(m.bias) 278 | elif isinstance(m, nn.LayerNorm): 279 | nn.init.zeros_(m.bias) 280 | nn.init.ones_(m.weight) 281 | 282 | def forward(self, x: torch.Tensor) -> torch.Tensor: 283 | """ 284 | Forward pass through the Transformer Adapter. 
285 | 286 | Args: 287 | x (torch.Tensor): Input features from PFM encoder, shape (B, N, dim) 288 | 289 | Returns: 290 | torch.Tensor: Adapted features, shape (B, N, dim) 291 | """ 292 | for block in self.blocks: 293 | x = block(x) 294 | x = self.norm(x) 295 | return x 296 | 297 | 298 | # Number of tokens to skip for each PFM model (CLS token + register tokens) 299 | # This is used to extract patch tokens after vision blocks processing 300 | PFM_SKIP_TOKENS = { 301 | 'gigapath': 1, # CLS only 302 | 'uni_v1': 1, # CLS only 303 | 'uni_v2': 9, # CLS + 8 register tokens 304 | 'virchow_v1': 1, # CLS only 305 | 'virchow_v2': 5, # CLS + 4 register tokens 306 | 'conch_v1': 1, # CLS only 307 | 'conch_v1_5': 1, # CLS only 308 | 'phikon': 1, # CLS only 309 | 'phikon_v2': 1, # CLS only 310 | 'patho3dmatrix-vision': 1, # CLS only 311 | 'hoptimus_0': 5, # CLS + 4 register tokens 312 | 'hoptimus_1': 5, # CLS + 4 register tokens 313 | 'kaiko-vits8': 5, # CLS + 4 register tokens 314 | 'kaiko-vits16': 5, # CLS + 4 register tokens 315 | 'kaiko-vitb8': 5, # CLS + 4 register tokens 316 | 'kaiko-vitb16': 5, # CLS + 4 register tokens 317 | 'kaiko-vitl14': 5, # CLS + 4 register tokens 318 | 'midnight12k': 1, # CLS only 319 | 'lunit_vits8': 1, # CLS only 320 | 'hibou_l': 5, # CLS + 4 register tokens 321 | 'musk': 1, # CLS only 322 | 'PathOrchestra': 1, # CLS only 323 | } 324 | 325 | 326 | def get_skip_tokens(pfm_name: str) -> int: 327 | """ 328 | Get the number of tokens to skip (CLS + register tokens) for a specific PFM model. 329 | 330 | Args: 331 | pfm_name (str): Name of the PFM model 332 | 333 | Returns: 334 | int: Number of tokens to skip at the beginning of the sequence 335 | """ 336 | if pfm_name in PFM_SKIP_TOKENS: 337 | return PFM_SKIP_TOKENS[pfm_name] 338 | else: 339 | # Default: skip only CLS token 340 | return 1 341 | 342 | 343 | def get_num_heads_from_dim(emb_dim: int) -> int: 344 | """ 345 | Infer the number of attention heads from embedding dimension. 346 | 347 | Common ViT configurations: 348 | - dim=384: 6 heads (head_dim=64) 349 | - dim=768: 12 heads (head_dim=64) 350 | - dim=1024: 16 heads (head_dim=64) 351 | - dim=1280: 16 heads (head_dim=80) or 20 heads (head_dim=64) 352 | - dim=1536: 24 heads (head_dim=64) 353 | 354 | Args: 355 | emb_dim (int): Embedding dimension 356 | 357 | Returns: 358 | int: Number of attention heads 359 | """ 360 | # Standard head_dim is 64 for most ViT models 361 | head_dim = 64 362 | num_heads = emb_dim // head_dim 363 | 364 | # Ensure num_heads is valid (at least 1 and divides evenly) 365 | if emb_dim % head_dim != 0: 366 | # Try common head dimensions 367 | for hd in [64, 80, 96, 128]: 368 | if emb_dim % hd == 0: 369 | num_heads = emb_dim // hd 370 | break 371 | 372 | return max(1, num_heads) 373 | 374 | 375 | def equip_model_with_transformer_adapter( 376 | model: nn.Module, 377 | adapter_config: Dict[str, Any] 378 | ) -> nn.Module: 379 | """ 380 | Equip a PFMSegmentationModel with Transformer Adapter. 381 | 382 | Data flow: Image -> PFM (frozen, full output with CLS) -> Transformer Adapter (trainable) 383 | -> Extract patch tokens -> Decoder (trainable) -> Segmentation Head (trainable) 384 | 385 | The Transformer Adapter adds extra Vision Blocks after the frozen PFM encoder. 386 | All PFM output tokens (including CLS and register tokens) are passed to Vision Blocks, 387 | then patch tokens are extracted before the decoder. 
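Example (illustrative sketch; ``model`` is assumed to be a PFMSegmentationModel built elsewhere in this repo, and the config keys mirror the defaults documented below):

    adapter_config = {
        'depth': 4,
        'mlp_ratio': 4.0,
        'drop_path_rate': 0.1,
        'init_values': 1e-5,
    }
    model = equip_model_with_transformer_adapter(model, adapter_config)
    # The patched forward now routes all PFM tokens through the adapter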
388 | 389 | Args: 390 | model (nn.Module): PFMSegmentationModel instance 391 | adapter_config (dict): Transformer adapter configuration with keys: 392 | - depth (int): Number of Vision Blocks (default: 4) 393 | - num_heads (int): Number of attention heads (default: auto-inferred from emb_dim) 394 | - mlp_ratio (float): MLP hidden dim ratio (default: 4.0) 395 | - drop_rate (float): Dropout rate (default: 0.0) 396 | - attn_drop_rate (float): Attention dropout rate (default: 0.0) 397 | - drop_path_rate (float): Stochastic depth rate (default: 0.1) 398 | - init_values (float): Layer scale init value (default: 1e-5) 399 | - qk_norm (bool): Whether to use QK normalization (default: False) 400 | 401 | Returns: 402 | nn.Module: Model equipped with Transformer Adapter 403 | """ 404 | import types 405 | import logging 406 | 407 | logger = logging.getLogger(__name__) 408 | 409 | # Get PFM model properties - read directly from existing model 410 | PFM_name = model.PFM_name 411 | emb_dim = model.decoder.conv_more[0].in_channels # Get emb_dim from existing decoder 412 | 413 | # Get number of tokens to skip (CLS + register tokens) 414 | skip_tokens = get_skip_tokens(PFM_name) 415 | 416 | # Get adapter config 417 | depth = adapter_config.get('depth', 4) 418 | mlp_ratio = adapter_config.get('mlp_ratio', 4.0) 419 | drop_rate = adapter_config.get('drop_rate', 0.0) 420 | attn_drop_rate = adapter_config.get('attn_drop_rate', 0.0) 421 | drop_path_rate = adapter_config.get('drop_path_rate', 0.1) 422 | init_values = adapter_config.get('init_values', 1e-5) 423 | qk_norm = adapter_config.get('qk_norm', False) 424 | 425 | # Get num_heads: from config if provided, otherwise infer from emb_dim 426 | num_heads = adapter_config.get('num_heads', None) 427 | if num_heads is None: 428 | num_heads = get_num_heads_from_dim(emb_dim) 429 | 430 | # Create Transformer Adapter 431 | transformer_adapter = TransformerAdapter( 432 | dim=emb_dim, 433 | depth=depth, 434 | num_heads=num_heads, 435 | mlp_ratio=mlp_ratio, 436 | qkv_bias=True, 437 | qk_norm=qk_norm, 438 | drop_rate=drop_rate, 439 | attn_drop_rate=attn_drop_rate, 440 | drop_path_rate=drop_path_rate, 441 | init_values=init_values, 442 | ) 443 | 444 | # Add transformer adapter to model 445 | model.transformer_adapter = transformer_adapter 446 | 447 | # Store skip_tokens for use in forward method 448 | model._transformer_adapter_skip_tokens = skip_tokens 449 | 450 | logger.info(f"Transformer Adapter created: depth={depth}, num_heads={num_heads}, " 451 | f"dim={emb_dim}, mlp_ratio={mlp_ratio}, skip_tokens={skip_tokens}") 452 | 453 | # Define new forward method that uses transformer adapter 454 | def forward_with_transformer_adapter(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: 455 | """ 456 | Forward pass with Transformer Adapter. 457 | 458 | Flow: Image -> PFM (frozen, full output) -> Transformer Adapter (all tokens) 459 | -> Extract patch tokens -> Decoder -> Seg Head 460 | 461 | Key difference from original: PFM outputs ALL tokens (including CLS and register tokens) 462 | to Vision Blocks, then we extract patch tokens after Vision Blocks processing. 
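Shape walk-through (illustrative only; numbers assume a 224x224 input and a patch-16 ViT PFM with dim=1024 and 4 register tokens, i.e. skip_tokens=5):

    PFM output          -> (B, 1 + 4 + 196, 1024)
    Transformer Adapter -> (B, 201, 1024)   # token count unchanged
    drop skip_tokens=5  -> (B, 196, 1024)   # patch tokens only
    decoder + seg head  -> (B, num_classes, H, W)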
463 | """ 464 | # Handle single channel images 465 | if x.size(1) == 1: 466 | x = x.repeat(1, 3, 1, 1) 467 | 468 | # Step 1: Extract FULL features from frozen PFM (including CLS and register tokens) 469 | # Different PFM models have different output structures 470 | if self.PFM_name == 'virchow_v1': 471 | # Virchow v1: returns all tokens including CLS 472 | features = self.pfm(x) # (B, N+1, dim) where N is num_patches 473 | elif self.PFM_name == 'virchow_v2': 474 | # Virchow v2: returns all tokens including CLS + register tokens 475 | features = self.pfm(x) # (B, N+5, dim) 476 | elif self.PFM_name == 'conch_v1': 477 | # CONCH v1: access through visual.trunk 478 | features = self.pfm.visual.trunk.forward_features(x) # (B, N+1, dim) 479 | elif self.PFM_name == 'conch_v1_5': 480 | # CONCH v1.5: access through trunk 481 | features = self.pfm.trunk.forward_features(x) # (B, N+1, dim) 482 | elif self.PFM_name == 'phikon' or self.PFM_name == 'phikon_v2': 483 | # Phikon: transformers ViTModel wrapper 484 | features = self.pfm(x) # (B, N+1, dim) 485 | elif self.PFM_name == 'hibou_l': 486 | # Hibou-L: transformers AutoModel wrapper 487 | features = self.pfm(x) # (B, N+5, dim) 488 | elif self.PFM_name == 'musk': 489 | # MUSK: wrapper returns all tokens 490 | features = self.pfm.forward(x) # (B, N+1, dim) 491 | elif self.PFM_name == 'lunit_vits8': 492 | # Lunit: standard timm ViT 493 | features = self.pfm.forward_features(x) # (B, N+1, dim) 494 | elif self.PFM_name == 'midnight12k': 495 | # Midnight-12k: transformers AutoModel wrapper 496 | features = self.pfm.forward_features(x) # (B, N+1, dim) 497 | elif self.PFM_name.startswith('kaiko-'): 498 | # Kaiko models: standard timm ViT with register tokens 499 | features = self.pfm.forward_features(x) # (B, N+5, dim) 500 | elif self.PFM_name == 'hoptimus_0' or self.PFM_name == 'hoptimus_1': 501 | # H-Optimus: standard timm ViT with register tokens 502 | features = self.pfm.forward_features(x) # (B, N+5, dim) 503 | elif self.PFM_name == 'patho3dmatrix-vision': 504 | # Patho3DMatrix: standard timm ViT 505 | features = self.pfm.forward_features(x) # (B, N+1, dim) 506 | elif self.PFM_name == 'uni_v2': 507 | # UNI v2: standard timm ViT with register tokens 508 | features = self.pfm.forward_features(x) # (B, N+9, dim) 509 | elif self.PFM_name == 'PathOrchestra': 510 | # PathOrchestra: standard timm ViT 511 | features = self.pfm.forward_features(x) # (B, N+1, dim) 512 | else: 513 | # Default: assume standard timm ViT structure 514 | features = self.pfm.forward_features(x) # (B, N+1, dim) 515 | 516 | # Step 2: Apply Transformer Adapter to ALL tokens (including CLS and register tokens) 517 | features = self.transformer_adapter(features) 518 | 519 | # Step 3: Extract patch tokens (skip CLS and register tokens) 520 | skip_tokens = self._transformer_adapter_skip_tokens 521 | patch_features = features[:, skip_tokens:, :] # (B, N, dim) 522 | 523 | # Step 4: Decode features (only patch tokens) 524 | decoded_features = self.decoder(patch_features) 525 | 526 | # Step 5: Generate final predictions 527 | logits = self.segmentation_head(decoded_features) 528 | 529 | return {'out': logits} 530 | 531 | # Bind the new forward method to the model 532 | model.forward = types.MethodType(forward_with_transformer_adapter, model) 533 | 534 | return model 535 | 536 | --------------------------------------------------------------------------------