├── .gitignore ├── README.md ├── __init__.py ├── architectures ├── __init__.py ├── build_architecture.py └── mobileunetr.py ├── augmentations ├── __init__.py ├── segmentation_augmentations.py └── segmentation_augmentations_tv.py ├── dataloaders ├── __init__.py ├── build_dataset.py └── isic_dataset.py ├── example_configs ├── config_mobileunetr_s.yaml ├── config_mobileunetr_xs.yaml └── config_mobileunetr_xxs.yaml ├── experiments_medical ├── isic_2016 │ └── exp_2_dice_b8_a2 │ │ ├── config.yaml │ │ ├── generate_images.ipynb │ │ ├── inference_model_weights │ │ └── .gitignore │ │ ├── metrics.ipynb │ │ ├── run_experiment.py │ │ └── visualize.ipynb ├── isic_2017 │ └── exp_2_dice_b8_a2 │ │ ├── config.yaml │ │ ├── inference_model_weights │ │ └── .gitignore │ │ ├── metrics.ipynb │ │ └── run_experiment.py ├── isic_2018 │ └── exp_2_dice_b8_a2 │ │ ├── config.yaml │ │ ├── inference_model_weights │ │ └── .gitignore │ │ ├── metrics.ipynb │ │ ├── run_experiment.py │ │ └── visualize.ipynb └── ph2 │ └── exp_2_dice_b8_a2 │ ├── config.yaml │ ├── inference_model_weights │ └── .gitignore │ ├── metrics.ipynb │ ├── run_experiment.py │ └── visualize.ipynb ├── losses ├── __init__.py └── losses.py ├── misc ├── flop_counter.py ├── mobileunetr.py └── profiler.ipynb ├── optimizers ├── __init__.py ├── optimizers.py └── schedulers.py ├── requirements.txt ├── resources ├── .gitignore ├── adv_arch.png ├── citiscapes1.png ├── cityscapes2.png ├── cityscapes3.png ├── cityscapes4.png ├── cityscapes_table.png ├── isic_2016.png ├── isic_2017.png ├── isic_2018.png ├── muvit_architecture.png ├── params.png ├── ph2.png ├── postdam1_gt.png ├── postdam1_pred.png ├── postdam2_gt.png ├── postdam2_pred.png ├── potsdam.png └── vaihigen.png └── train_scripts ├── __init__.py ├── ema.py ├── segmentation_trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | **/wandb 3 | **/__pycache__ 4 | **.npy 5 | **.pth 6 | **.svg 7 | **.jpg 8 | **.csv 9 
| **.bin 10 | **/data 11 | **.tex 12 | **/model_checkpoints 13 | **.pkl 14 | #**.png 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MobileUNETR 2 | ## A Lightweight End-To-End Hybrid Vision Transformer For Efficient Medical Image Segmentation: [ECCV 2024 -- BioImage Computing] (ORAL) https://arxiv.org/abs/2409.03062 3 | 4 | ## Architecture 5 |

6 |

7 | Wide Image 8 |
9 |

10 | 11 | ## Parameter Distribution and Computational Complexity 12 |

13 |

14 | Wide Image 15 |
16 |

17 | 18 | ## :rocket: News 19 | * Repository Construction Complete 06/09/2024 20 | * We will continue to update the GitHub repository with new experiments on a wide range of datasets, so be sure to check back regularly. 21 | * In the meantime, check out our other projects: https://github.com/OSUPCVLab/SegFormer3D 22 | 23 | ## Overview: 24 | * Segmentation approaches broadly fall into 2 categories. 25 | 1. End-to-End CNN-Based Segmentation Methods 26 | 2. Transformer-Based Encoder with a CNN-Based Decoder. 27 | * Many Transformer-based segmentation approaches rely primarily on CNN-based decoders, overlooking the benefits of the Transformer architecture within the decoder. 28 | * We address the need for an efficient/lightweight segmentation architecture by introducing MobileUNETR, which aims to overcome the performance constraints associated with both CNNs and Transformers while minimizing model size, presenting a promising stride towards efficient image segmentation. 29 | * MobileUNETR has 3 main features. 30 | 1. MobileUNETR comprises a lightweight hybrid CNN-Transformer encoder that balances local and global contextual feature extraction in an efficient manner. 31 | 2. A novel hybrid decoder that simultaneously utilizes low-level and global features at different resolutions within the decoding stage for accurate mask generation. 32 | 3. Surpassing large and complex architectures, MobileUNETR achieves superior performance with 3 million parameters and a computational complexity of 1.3 GFLOPs. 33 | 34 | ## Stand Alone Model [Please Read] 35 | To make the MobileUNETR architecture easy to use, the model is implemented as a single standalone file. If you want to use the model outside of the provided code base, simply copy the mobileunetr.py file from the architectures folder into your own project.
36 | 37 | * Example: 38 | ``` 39 | # import from mobileunetr.py file 40 | from mobileunetr import build_mobileunetr_s, build_mobileunetr_xs, build_mobileunetr_xxs 41 | import torch 42 | 43 | # create model 44 | mobileunetr_s = build_mobileunetr_s(num_classes= 1, image_size=512) 45 | 46 | mobileunetr_xs = build_mobileunetr_xs(num_classes=1, image_size=512) 47 | 48 | mobileunetr_xxs = build_mobileunetr_xxs(num_classes=1, image_size= 512) 49 | 50 | # forward pass 51 | data = torch.randn((4, 3, 512, 512)) 52 | out = mobileunetr_xxs.forward(data) 53 | print(f"input tensor: {data.shape}") 54 | print(f"output tensor: {out.shape}") 55 | 56 | ``` 57 | 58 | ## Data and Data Processing 59 | * ISIC Data -- https://challenge.isic-archive.com/data/ 60 | * PH2 Data -- https://www.fc.up.pt/addi/ph2%20database.html 61 | 62 | * Data Preprocessing 63 | * For each dataset (ISIC 2016, ISIC 2017, etc.), simply create a CSV file with N rows and 2 columns, where N is the number of items in the dataset and the 2 columns, with headers ["image", "mask"], hold the path to the input image and the path to the target mask, respectively. 64 | * Once you have a train.csv and a test.csv (let's assume for ISIC 2016), update the data paths for the train and test CSV files inside experiments_medical/isic_2016/exp_2_dice_b8_a2/config.yaml, then follow the steps below to run the experiment. 65 | 66 | ## Run Your Experiment 67 | To run an experiment, we provide a template experiment folder under `MobileUNETR_HOME_PATH/experiments_medical/isic_2016/` (e.g. `exp_2_dice_b8_a2`) that you can use to set up your experiment. From inside that experiment folder, run your experiment on a single GPU with: 68 | ```shell 69 | cd MobileUNETR 70 | cd experiments_medical/isic_2016/exp_2_dice_b8_a2/ 71 | # the default gpu device is set to cuda:0 (you can change it) 72 | CUDA_VISIBLE_DEVICES="0" accelerate launch run_experiment.py 73 | ``` 74 | You might want to change the hyperparameters (batch size, learning rate, weight decay, etc.) of your experiment.
To do so, edit the `config.yaml` file inside your experiment folder. 75 | 76 | As the experiment is running, the logs (train loss, validation loss and dice score) will be written to the terminal. You can log your experiment on [wandb](https://wandb.ai/site) 77 | (you need to set up an account there) if you set `mode: "online"` in the `wandb_parameters` section of the `config.yaml`. The default value is `mode: "offline"` (note that the files in `example_configs/` ship with `mode: "online"`). If you want to log the result to your wandb account, put your wandb info into the `wandb_parameters` section of the `config.yaml` and your entire experiment will be logged under your wandb entity (e.g. `pcvlab`) page. 78 | 79 | ## ISIC 2016 Performance 80 |

81 |

82 | Wide Image 83 |
84 |

85 | 86 | ## ISIC 2017 Performance 87 |

88 |

89 | Wide Image 90 |
91 |

92 | 93 | ## ISIC 2018 Performance 94 |

95 |

96 | Wide Image 97 |
98 |

99 | 100 | ## PH2 Performance 101 |

102 |

103 | Wide Image 104 |
105 |

106 | 107 | ## Advanced Architectures and Training Methods 108 |

109 |

110 | Wide Image 111 |
112 |

113 | 114 | ## Experiments: Extending to Complex Real-World Scenes (Cityscapes, Potsdam and Vaihingen) 115 | 116 | ### Cityscapes Results 117 | 118 |

119 |

120 | Wide Image 121 |
122 |

123 | 124 |

125 |

126 | Wide Image 127 |
128 |

129 | 130 |

131 |

132 | Wide Image 133 |
134 |

135 | 136 |

137 |

138 | Wide Image 139 |
140 |

141 | 142 |

143 |

144 | Wide Image 145 |
146 |

147 | 148 | ### Potsdam and Vaihingen Results (GT [Left], Prediction Overlay [Right]) 149 | 150 |

151 |

152 | Wide Image 153 | Another Image 154 |
155 |

156 | 157 |

158 |

159 | Wide Image 160 | Another Image 161 |
162 |

163 | 164 | ### Potsdam (Left Table) and Vaihingen (Right Table) 165 |

166 |

167 | Wide Image 168 | Another Image 169 |
170 |

171 | 172 | ## Citation 173 | If you find our paper useful, please consider citing it [TBD fields will be updated soon] 174 | ```bibtex 175 | @inproceedings{perera2024mobileunetr, 176 | title={MobileUNETR: A Lightweight End-To-End Hybrid Vision Transformer For Efficient Medical Image Segmentation}, 177 | author={Perera, Shehan and Erzurumlu, Yunus and Gulati, Deepak and Yilmaz, Alper}, 178 | booktitle={Proceedings of the IEEE/CVF European Conference on Computer Vision (ECCV)}, 179 | pages={TBD}, 180 | year={2024} 181 | } 182 | ``` 183 | ```bibtex 184 | @article{perera2024mobileunetr, 185 | title={MobileUNETR: A Lightweight End-To-End Hybrid Vision Transformer For Efficient Medical Image Segmentation}, 186 | author={Perera, Shehan and Erzurumlu, Yunus and Gulati, Deepak and Yilmaz, Alper}, 187 | journal={arXiv preprint arXiv:2409.03062}, url={https://arxiv.org/abs/2409.03062}, 188 | year={2024} 189 | } 190 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/__init__.py -------------------------------------------------------------------------------- /architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/architectures/__init__.py -------------------------------------------------------------------------------- /architectures/build_architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | To select the architecture based on a config file we need to ensure 3 | we import each of the architectures into this file. Once we have that 4 | we can use a keyword from the config file to build the model.
5 | """ 6 | 7 | 8 | ###################################################################### 9 | def build_architecture(config): 10 | if config["model_name"] == "mobileunetr_xs": 11 | from .mobileunetr import build_mobileunetr_xs 12 | 13 | model = build_mobileunetr_xs(config=config) 14 | return model 15 | 16 | elif config["model_name"] == "mobileunetr_xxs": 17 | from .mobileunetr import build_mobileunetr_xxs 18 | 19 | model = build_mobileunetr_xxs(config=config) 20 | return model 21 | 22 | elif config["model_name"] == "mobileunetr_s": 23 | from .mobileunetr import build_mobileunetr_s 24 | 25 | model = build_mobileunetr_s(config=config) 26 | return model 27 | 28 | else: 29 | return ValueError( 30 | "specified model not supported, edit build_architecture.py file" 31 | ) 32 | -------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/augmentations/__init__.py -------------------------------------------------------------------------------- /augmentations/segmentation_augmentations.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | from albumentations.pytorch import ToTensorV2 3 | import cv2 4 | from typing import Dict 5 | 6 | 7 | ####################################################################################### 8 | def build_augmentations(train: bool = True, augmentation_args: Dict = None): 9 | 10 | mean = augmentation_args["mean"] 11 | std = augmentation_args["std"] 12 | image_size = augmentation_args["image_size"] 13 | 14 | if train: 15 | train_transform = A.Compose( 16 | [ 17 | A.Resize(image_size[0], image_size[1], interpolation=cv2.INTER_CUBIC), 18 | A.HorizontalFlip(p=0.5), 19 | A.VerticalFlip(p=0.5), 20 | A.RandomBrightnessContrast( 21 | brightness_limit=0.5, 22 | 
contrast_limit=0.5, 23 | p=0.6, 24 | ), 25 | A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), 26 | A.RandomRotate90(p=0.5), 27 | A.RandomResizedCrop( 28 | image_size[0], 29 | image_size[1], 30 | interpolation=cv2.INTER_CUBIC, 31 | p=0.5, 32 | ), 33 | A.Normalize(mean, std), 34 | ToTensorV2(), 35 | ] 36 | ) 37 | 38 | return train_transform 39 | else: 40 | test_transform = A.Compose( 41 | [ 42 | A.Resize( 43 | image_size[0], 44 | image_size[1], 45 | interpolation=cv2.INTER_CUBIC, 46 | ), 47 | A.Normalize(mean, std), 48 | ToTensorV2(), 49 | ] 50 | ) 51 | return test_transform 52 | 53 | 54 | ####################################################################################### 55 | def build_augmentations_v2(train: bool = True, augmentation_args: Dict = None): 56 | 57 | mean = augmentation_args["mean"] 58 | std = augmentation_args["std"] 59 | image_size = augmentation_args["image_size"] 60 | 61 | if train: 62 | train_transform = A.Compose( 63 | [ 64 | A.Resize(image_size[0], image_size[1], interpolation=cv2.INTER_CUBIC), 65 | A.HorizontalFlip(p=0.5), 66 | A.VerticalFlip(p=0.5), 67 | A.RandomBrightnessContrast( 68 | brightness_limit=0.5, 69 | contrast_limit=0.5, 70 | p=0.6, 71 | ), 72 | A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), 73 | A.RandomRotate90(p=0.5), 74 | A.OneOf( 75 | [ 76 | A.RandomResizedCrop( 77 | image_size[0], 78 | image_size[1], 79 | interpolation=cv2.INTER_CUBIC, 80 | p=1, 81 | ), 82 | A.Compose( 83 | [ 84 | A.CropNonEmptyMaskIfExists(128, 128, p=1), 85 | A.Resize( 86 | image_size[0], 87 | image_size[1], 88 | interpolation=cv2.INTER_CUBIC, 89 | ), 90 | ] 91 | ), 92 | A.Compose( 93 | [ 94 | A.CropNonEmptyMaskIfExists(64, 64, p=1), 95 | A.Resize( 96 | image_size[0], 97 | image_size[1], 98 | interpolation=cv2.INTER_CUBIC, 99 | ), 100 | ] 101 | ), 102 | A.Compose( 103 | [ 104 | A.CropNonEmptyMaskIfExists(256, 256, p=1), 105 | A.Resize( 106 | image_size[0], 107 | image_size[1], 108 | interpolation=cv2.INTER_CUBIC, 109 | ), 110 | 
] 111 | ), 112 | A.Compose( 113 | [ 114 | A.CropNonEmptyMaskIfExists(160, 160, p=1), 115 | A.Resize( 116 | image_size[0], 117 | image_size[1], 118 | interpolation=cv2.INTER_CUBIC, 119 | ), 120 | ] 121 | ), 122 | ], 123 | p=0.5, 124 | ), 125 | A.Normalize(mean, std), 126 | ToTensorV2(), 127 | ] 128 | ) 129 | 130 | return train_transform 131 | else: 132 | test_transform = A.Compose( 133 | [ 134 | A.Resize(image_size[0], image_size[1], interpolation=cv2.INTER_CUBIC), 135 | A.HorizontalFlip(p=0.3), 136 | A.VerticalFlip(p=0.3), 137 | A.Normalize(mean, std), 138 | ToTensorV2(), 139 | ] 140 | ) 141 | return test_transform 142 | 143 | 144 | ####################################################################################### 145 | -------------------------------------------------------------------------------- /augmentations/segmentation_augmentations_tv.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | from typing import Dict 3 | 4 | 5 | # [0.7128, 0.6000, 0.5532], [0.1577, 0.1662, 0.1829] 6 | ####################################################################################### 7 | def build_augmentations(train: bool = True, augmentation_args: Dict = None): 8 | mean = augmentation_args["mean"] 9 | std = augmentation_args["std"] 10 | image_size = augmentation_args["image_size"] 11 | 12 | if train: 13 | train_transform = transforms.Compose( 14 | [ 15 | transforms.Resize( 16 | (image_size[0], image_size[1]), 17 | interpolation=transforms.InterpolationMode.BICUBIC, 18 | ), 19 | transforms.autoaugment.TrivialAugmentWide( 20 | interpolation=transforms.InterpolationMode.BILINEAR 21 | ), 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean, std), 24 | ] 25 | ) 26 | 27 | return train_transform 28 | else: 29 | test_transform = transforms.Compose( 30 | [ 31 | # transforms.RandomRotation(180), 32 | transforms.Resize( 33 | (image_size[0], image_size[1]), 34 | 
interpolation=transforms.InterpolationMode.BICUBIC, 35 | ), 36 | transforms.ToTensor(), 37 | transforms.Normalize(mean, std), 38 | ] 39 | ) 40 | 41 | return test_transform 42 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/build_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("../") 4 | from typing import Dict 5 | from torch.utils.data import DataLoader 6 | from monai.data import DataLoader 7 | 8 | 9 | ############################################################################################# 10 | def build_dataset(dataset_type: str, dataset_args: Dict, augmentation_args: Dict): 11 | if dataset_type == "isic_torchvision": 12 | from .isic_dataset import ISICDataset 13 | from augmentations.segmentation_augmentations_tv import build_augmentations 14 | 15 | dataset = ISICDataset( 16 | data=dataset_args["data_path"], 17 | image_size=dataset_args["image_size"], 18 | input_transforms=build_augmentations( 19 | dataset_args["train"], 20 | augmentation_args, 21 | ), 22 | target_transforms=True, 23 | ) 24 | return dataset 25 | elif dataset_type == "isic_albumentation_v2": 26 | from .isic_dataset import ISICDatasetA 27 | from augmentations.segmentation_augmentations import build_augmentations_v2 28 | 29 | dataset = ISICDatasetA( 30 | data=dataset_args["data_path"], 31 | transforms=build_augmentations_v2(dataset_args["train"], augmentation_args), 32 | ) 33 | return dataset 34 | else: 35 | raise NotImplementedError("datasets are supported") 36 | 37 | 38 | 
############################################################################################# 39 | def build_dataloader( 40 | dataset, 41 | dataloader_args: Dict, 42 | config: Dict = None, 43 | train: bool = True, 44 | ) -> DataLoader: 45 | """builds the dataloader for given dataset 46 | 47 | Args: 48 | dataset (_type_): _description_ 49 | dataloader_args (Dict): _description_ 50 | config (Dict, optional): _description_. Defaults to None. 51 | train (bool, optional): _description_. Defaults to True. 52 | 53 | Returns: 54 | DataLoader: _description_ 55 | """ 56 | dataloader = DataLoader( 57 | dataset=dataset, 58 | batch_size=dataloader_args["batch_size"], 59 | shuffle=dataloader_args["shuffle"], 60 | num_workers=dataloader_args["num_workers"], 61 | drop_last=dataloader_args["drop_last"], 62 | pin_memory=dataloader_args["pin_memory"], 63 | ) 64 | return dataloader 65 | 66 | 67 | ############################################################################################# 68 | -------------------------------------------------------------------------------- /dataloaders/isic_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision 4 | import pandas as pd 5 | from PIL import Image 6 | import torchvision.transforms.functional as tvF 7 | 8 | 9 | ########################################################################################################## 10 | class ISICDataset(torch.utils.data.Dataset): 11 | def __init__(self, data, image_size, input_transforms, target_transforms): 12 | """ 13 | Initialize Class 14 | Args: 15 | config (_type_): _description_ 16 | data (_type_): _description_ 17 | input_transforms (_type_): _description_ 18 | target_transforms (_type_): _description_ 19 | """ 20 | self.data = pd.read_csv(data) 21 | self.image_size = image_size # tuple 22 | self.input_transforms = input_transforms 23 | self.target_transforms = target_transforms 24 | 25 | def 
__len__(self): 26 | return len(self.data) 27 | 28 | def __getitem__(self, index): 29 | # get the associated_image 30 | image_id = self.data.iloc[index] 31 | 32 | # get the image 33 | image_name = image_id["image"] 34 | image = Image.open(image_name).convert("RGB") 35 | 36 | # get the mask path 37 | mask_id = image_id["mask"] 38 | mask = Image.open(mask_id) 39 | 40 | # transform input image 41 | if self.input_transforms: 42 | image = self.input_transforms(image) 43 | 44 | # transform target image 45 | if self.target_transforms: 46 | mask = tvF.resize( 47 | mask, 48 | self.image_size, 49 | torchvision.transforms.InterpolationMode.NEAREST, 50 | ) 51 | mask = np.array(mask).astype(int) 52 | mask[mask == 255] = 1 53 | mask = torch.tensor(mask, dtype=torch.long).unsqueeze(0) 54 | 55 | out = { 56 | "image": image, # [3, H, W] 57 | "mask": mask, # [1, H, W] 1 b/c binary mask 58 | } 59 | 60 | return out 61 | 62 | 63 | ########################################################################################################## 64 | class ISICDatasetA(torch.utils.data.Dataset): 65 | """ 66 | Dataset used for albumentation augmentations 67 | """ 68 | 69 | def __init__(self, data, transforms): 70 | """ 71 | Initialize Class 72 | Args: 73 | config (_type_): _description_ 74 | data (_type_): _description_ 75 | input_transforms (_type_): _description_ 76 | target_transforms (_type_): _description_ 77 | """ 78 | self.data = pd.read_csv(data) 79 | # self.data = pd.concat([self.data] * 2).sample(frac=1).reset_index(drop=True) 80 | self.transforms = transforms 81 | 82 | def __len__(self): 83 | return len(self.data) 84 | 85 | def __getitem__(self, index): 86 | # get the associated_image 87 | image_id = self.data.iloc[index] 88 | 89 | # get the image 90 | image_name = image_id["image"] 91 | image = Image.open(image_name).convert("RGB") 92 | image = np.array(image, dtype=np.uint8) 93 | 94 | # get the mask path 95 | mask_id = image_id["mask"] 96 | mask = Image.open(mask_id) 97 | mask = 
np.array(mask, dtype=np.uint8) 98 | mask[mask == 255] = 1 99 | 100 | # transform 101 | if self.transforms: 102 | out = self.transforms(image=image, mask=mask) 103 | image = out["image"] 104 | mask = out["mask"] 105 | 106 | out = { 107 | "image": image, # [3, H, W] 108 | "mask": mask.unsqueeze(0), # [1, H, W] 1 b/c binary mask 109 | } 110 | 111 | return out 112 | -------------------------------------------------------------------------------- /example_configs/config_mobileunetr_s.yaml: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | # Commonly Used Variables 3 | ############################################################## 4 | variables: 5 | batch_size: &batch_size 8 6 | num_channels: &num_channels 3 7 | num_classes: &num_classes 1 8 | num_epochs: &num_epochs 400 9 | 10 | ############################################################## 11 | # Wandb Model Tracking 12 | ############################################################## 13 | project: mobileunetr 14 | wandb_parameters: 15 | group: isic_2016 16 | name: exp_2_dice_b8_a2 17 | mode: "online" 18 | resume: False 19 | tags: ["tr16ts16", "dice", "s", "adam"] 20 | 21 | ############################################################## 22 | # Model Hyper-Parameters 23 | ############################################################## 24 | model_name: mobileunetr_s 25 | model_parameters: 26 | encoder: None 27 | bottle_neck: 28 | dims: [240] 29 | depths: [3] 30 | expansion: 4 31 | kernel_size: 3 32 | patch_size: [2,2] 33 | channels: [160, 196, 196] 34 | decoder: 35 | dims: [144, 192, 240] 36 | channels: [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 196, 196, 640] 37 | num_classes: 1 38 | image_size: 512 39 | 40 | ############################################################## 41 | # Loss Function 42 | ############################################################## 43 | loss_fn: 44 | loss_type: "dice" 45 | loss_args: None 46 |
| 47 | ############################################################## 48 | # Metrics 49 | ############################################################## 50 | metrics: 51 | type: "binary" 52 | mean_iou: 53 | enabled: True 54 | mean_iou_args: 55 | include_background: True 56 | reduction: "mean" 57 | get_not_nans: False 58 | ignore_empty: True 59 | dice: 60 | enabled: False 61 | dice_args: 62 | include_background: True 63 | reduction: "mean" 64 | get_not_nans: False 65 | ignore_empty: True 66 | num_classes: *num_classes 67 | 68 | ############################################################## 69 | # Optimizers 70 | ############################################################## 71 | optimizer: 72 | optimizer_type: "adamw" 73 | optimizer_args: 74 | lr: 0.00006 75 | weight_decay: 0.01 76 | 77 | ############################################################## 78 | # Learning Rate Schedulers 79 | ############################################################## 80 | warmup_scheduler: 81 | enabled: True # should be always true 82 | warmup_epochs: 30 83 | 84 | # train scheduler 85 | train_scheduler: 86 | scheduler_type: 'cosine_annealing_wr' 87 | scheduler_args: 88 | t_0_epochs: 600 89 | t_mult: 1 90 | min_lr: 0.000006 91 | 92 | ############################################################## 93 | # EMA (Exponential Moving Average) 94 | ############################################################## 95 | ema: 96 | enabled: False 97 | ema_decay: 0.999 98 | print_ema_every: 20 99 | 100 | ############################################################## 101 | # Gradient Clipping 102 | ############################################################## 103 | clip_gradients: 104 | enabled: False 105 | clip_gradients_value: 0.1 106 | 107 | ############################################################## 108 | # Training Hyperparameters 109 | ############################################################## 110 | training_parameters: 111 | seed: 42 112 | num_epochs: 1000 113 | cutoff_epoch: 600 114 | 
load_optimizer: False 115 | print_every: 1500 116 | calculate_metrics: True 117 | grad_accumulate_steps: 1 # default: 1 118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2" 119 | load_checkpoint: # not implemented yet 120 | load_full_checkpoint: False 121 | load_model_only: False 122 | load_checkpoint_path: None 123 | 124 | ############################################################## 125 | # Dataset and Dataloader Args 126 | ############################################################## 127 | dataset_parameters: 128 | dataset_type: "isic_albumentation_v2" 129 | train_dataset_args: 130 | data_path: "../../../data/train_2016.csv" 131 | train: True 132 | image_size: [512, 512] 133 | val_dataset_args: 134 | data_path: "../../../data/test_2016.csv" 135 | train: False 136 | image_size: [512, 512] 137 | test_dataset_args: 138 | data_path: "../../../data/test_2016.csv" 139 | train: False 140 | image_size: [512, 512] 141 | train_dataloader_args: 142 | batch_size: *batch_size 143 | shuffle: True 144 | num_workers: 4 145 | drop_last: True 146 | pin_memory: True 147 | val_dataloader_args: 148 | batch_size: *batch_size 149 | shuffle: False 150 | num_workers: 4 151 | drop_last: False 152 | pin_memory: True 153 | test_dataloader_args: 154 | batch_size: 1 155 | shuffle: False 156 | num_workers: 4 157 | drop_last: False 158 | pin_memory: True 159 | 160 | ############################################################## 161 | # Data Augmentation Args 162 | ############################################################## 163 | train_augmentation_args: 164 | mean: [0.485, 0.456, 0.406] 165 | std: [0.229, 0.224, 0.225] 166 | image_size: [512, 512] # [H, W] 167 | 168 | test_augmentation_args: 169 | mean: [0.485, 0.456, 0.406] 170 | std: [0.229, 0.224, 0.225] 171 | image_size: [512, 512] # [H, W] -------------------------------------------------------------------------------- /example_configs/config_mobileunetr_xs.yaml: 
-------------------------------------------------------------------------------- 1 | ############################################################## 2 | # Commonly Used Variables 3 | ############################################################## 4 | variables: 5 | batch_size: &batch_size 8 6 | num_channels: &num_channels 3 7 | num_classes: &num_classes 1 8 | num_epochs: &num_epochs 400 9 | 10 | ############################################################## 11 | # Wandb Model Tracking 12 | ############################################################## 13 | project: mobileunetr 14 | wandb_parameters: 15 | group: isic_2016 16 | name: exp_2_dice_b8_a2 17 | mode: "online" 18 | resume: False 19 | tags: ["tr16ts16", "dice", "xs", "adam"] 20 | 21 | ############################################################## 22 | # Model Hyper-Parameters 23 | ############################################################## 24 | model_name: mobileunetr_xs 25 | model_parameters: 26 | encoder: None 27 | bottle_neck: 28 | dims: [144] 29 | depths: [3] 30 | expansion: 4 31 | kernel_size: 3 32 | patch_size: [2,2] 33 | channels: [96, 128, 128] 34 | decoder: 35 | dims: [96, 120, 144] 36 | channels: [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 128, 128, 384] 37 | num_classes: 1 38 | image_size: 512 39 | 40 | ############################################################## 41 | # Loss Function 42 | ############################################################## 43 | loss_fn: 44 | loss_type: "dice" 45 | loss_args: None 46 | 47 | ############################################################## 48 | # Metrics 49 | ############################################################## 50 | metrics: 51 | type: "binary" 52 | mean_iou: 53 | enabled: True 54 | mean_iou_args: 55 | include_background: True 56 | reduction: "mean" 57 | get_not_nans: False 58 | ignore_empty: True 59 | dice: 60 | enabled: False 61 | dice_args: 62 | include_background: True 63 | reduction: "mean" 64 | get_not_nans: False 65 | ignore_empty: True 66 | 
num_classes: *num_classes 67 | 68 | ############################################################## 69 | # Optimizers 70 | ############################################################## 71 | optimizer: 72 | optimizer_type: "adamw" 73 | optimizer_args: 74 | lr: 0.00006 75 | weight_decay: 0.01 76 | 77 | ############################################################## 78 | # Learning Rate Schedulers 79 | ############################################################## 80 | warmup_scheduler: 81 | enabled: True # should be always true 82 | warmup_epochs: 30 83 | 84 | # train scheduler 85 | train_scheduler: 86 | scheduler_type: 'cosine_annealing_wr' 87 | scheduler_args: 88 | t_0_epochs: 600 89 | t_mult: 1 90 | min_lr: 0.000006 91 | 92 | ############################################################## 93 | # EMA (Exponential Moving Average) 94 | ############################################################## 95 | ema: 96 | enabled: False 97 | ema_decay: 0.999 98 | print_ema_every: 20 99 | 100 | ############################################################## 101 | # Gradient Clipping 102 | ############################################################## 103 | clip_gradients: 104 | enabled: False 105 | clip_gradients_value: 0.1 106 | 107 | ############################################################## 108 | # Training Hyperparameters 109 | ############################################################## 110 | training_parameters: 111 | seed: 42 112 | num_epochs: 1000 113 | cutoff_epoch: 600 114 | load_optimizer: False 115 | print_every: 1500 116 | calculate_metrics: True 117 | grad_accumulate_steps: 1 # default: 1 118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2" 119 | load_checkpoint: # not implemented yet 120 | load_full_checkpoint: False 121 | load_model_only: False 122 | load_checkpoint_path: None 123 | 124 | ############################################################## 125 | # Dataset and Dataloader Args 126 | 
############################################################## 127 | dataset_parameters: 128 | dataset_type: "isic_albumentation_v2" 129 | train_dataset_args: 130 | data_path: "../../../data/train_2016.csv" 131 | train: True 132 | image_size: [512, 512] 133 | val_dataset_args: 134 | data_path: "../../../data/test_2016.csv" 135 | train: False 136 | image_size: [512, 512] 137 | test_dataset_args: 138 | data_path: "../../../data/test_2016.csv" 139 | train: False 140 | image_size: [512, 512] 141 | train_dataloader_args: 142 | batch_size: *batch_size 143 | shuffle: True 144 | num_workers: 4 145 | drop_last: True 146 | pin_memory: True 147 | val_dataloader_args: 148 | batch_size: *batch_size 149 | shuffle: False 150 | num_workers: 4 151 | drop_last: False 152 | pin_memory: True 153 | test_dataloader_args: 154 | batch_size: 1 155 | shuffle: False 156 | num_workers: 4 157 | drop_last: False 158 | pin_memory: True 159 | 160 | ############################################################## 161 | # Data Augmentation Args 162 | ############################################################## 163 | train_augmentation_args: 164 | mean: [0.485, 0.456, 0.406] 165 | std: [0.229, 0.224, 0.225] 166 | image_size: [512, 512] # [H, W] 167 | 168 | test_augmentation_args: 169 | mean: [0.485, 0.456, 0.406] 170 | std: [0.229, 0.224, 0.225] 171 | image_size: [512, 512] # [H, W] -------------------------------------------------------------------------------- /example_configs/config_mobileunetr_xxs.yaml: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | # Commonly Used Variables 3 | ############################################################## 4 | variables: 5 | batch_size: &batch_size 8 6 | num_channels: &num_channels 3 7 | num_classes: &num_classes 1 8 | num_epochs: &num_epochs 400 9 | 10 | ############################################################## 11 | # Wandb Model Tracking 12 | 
############################################################## 13 | project: mobileunetr 14 | wandb_parameters: 15 | group: isic_2016 16 | name: exp_2_dice_b8_a2 17 | mode: "online" 18 | resume: False 19 | tags: ["tr16ts16", "dice", "xxs", "adam"] 20 | 21 | ############################################################## 22 | # Model Hyper-Parameters 23 | ############################################################## 24 | model_name: mobileunetr_xxs 25 | model_parameters: 26 | encoder: None 27 | bottle_neck: 28 | dims: [96] 29 | depths: [3] 30 | expansion: 4 31 | kernel_size: 3 32 | patch_size: [2,2] 33 | channels: [80, 96, 96] 34 | decoder: 35 | dims: [64, 80, 96] 36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320] 37 | num_classes: 1 38 | image_size: 512 39 | 40 | ############################################################## 41 | # Loss Function 42 | ############################################################## 43 | loss_fn: 44 | loss_type: "dice" 45 | loss_args: None 46 | 47 | ############################################################## 48 | # Metrics 49 | ############################################################## 50 | metrics: 51 | type: "binary" 52 | mean_iou: 53 | enabled: True 54 | mean_iou_args: 55 | include_background: True 56 | reduction: "mean" 57 | get_not_nans: False 58 | ignore_empty: True 59 | dice: 60 | enabled: False 61 | dice_args: 62 | include_background: True 63 | reduction: "mean" 64 | get_not_nans: False 65 | ignore_empty: True 66 | num_classes: *num_classes 67 | 68 | ############################################################## 69 | # Optimizers 70 | ############################################################## 71 | optimizer: 72 | optimizer_type: "adamw" 73 | optimizer_args: 74 | lr: 0.00006 75 | weight_decay: 0.01 76 | 77 | ############################################################## 78 | # Learning Rate Schedulers 79 | ############################################################## 80 | warmup_scheduler: 81 | 
enabled: True # should be always true 82 | warmup_epochs: 30 83 | 84 | # train scheduler 85 | train_scheduler: 86 | scheduler_type: 'cosine_annealing_wr' 87 | scheduler_args: 88 | t_0_epochs: 600 89 | t_mult: 1 90 | min_lr: 0.000006 91 | 92 | ############################################################## 93 | # EMA (Exponential Moving Average) 94 | ############################################################## 95 | ema: 96 | enabled: False 97 | ema_decay: 0.999 98 | print_ema_every: 20 99 | 100 | ############################################################## 101 | # Gradient Clipping 102 | ############################################################## 103 | clip_gradients: 104 | enabled: False 105 | clip_gradients_value: 0.1 106 | 107 | ############################################################## 108 | # Training Hyperparameters 109 | ############################################################## 110 | training_parameters: 111 | seed: 42 112 | num_epochs: 1000 113 | cutoff_epoch: 600 114 | load_optimizer: False 115 | print_every: 1500 116 | calculate_metrics: True 117 | grad_accumulate_steps: 1 # default: 1 118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2" 119 | load_checkpoint: # not implemented yet 120 | load_full_checkpoint: False 121 | load_model_only: False 122 | load_checkpoint_path: None 123 | 124 | ############################################################## 125 | # Dataset and Dataloader Args 126 | ############################################################## 127 | dataset_parameters: 128 | dataset_type: "isic_albumentation_v2" 129 | train_dataset_args: 130 | data_path: "../../../data/train_2016.csv" 131 | train: True 132 | image_size: [512, 512] 133 | val_dataset_args: 134 | data_path: "../../../data/test_2016.csv" 135 | train: False 136 | image_size: [512, 512] 137 | test_dataset_args: 138 | data_path: "../../../data/test_2016.csv" 139 | train: False 140 | image_size: [512, 512] 141 | train_dataloader_args: 142 | batch_size: 
*batch_size 143 | shuffle: True 144 | num_workers: 4 145 | drop_last: True 146 | pin_memory: True 147 | val_dataloader_args: 148 | batch_size: *batch_size 149 | shuffle: False 150 | num_workers: 4 151 | drop_last: False 152 | pin_memory: True 153 | test_dataloader_args: 154 | batch_size: 1 155 | shuffle: False 156 | num_workers: 4 157 | drop_last: False 158 | pin_memory: True 159 | 160 | ############################################################## 161 | # Data Augmentation Args 162 | ############################################################## 163 | train_augmentation_args: 164 | mean: [0.485, 0.456, 0.406] 165 | std: [0.229, 0.224, 0.225] 166 | image_size: [512, 512] # [H, W] 167 | 168 | test_augmentation_args: 169 | mean: [0.485, 0.456, 0.406] 170 | std: [0.229, 0.224, 0.225] 171 | image_size: [512, 512] # [H, W] -------------------------------------------------------------------------------- /experiments_medical/isic_2016/exp_2_dice_b8_a2/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | # Commonly Used Variables 3 | ############################################################## 4 | variables: 5 | batch_size: &batch_size 8 6 | num_channels: &num_channels 3 7 | num_classes: &num_classes 1 8 | num_epochs: &num_epochs 400 9 | 10 | ############################################################## 11 | # Wandb Model Tracking 12 | ############################################################## 13 | project: mobileunetr 14 | wandb_parameters: 15 | group: isic_2016 16 | name: exp_2_dice_b8_a2 17 | mode: "online" 18 | resume: False 19 | tags: ["tr16ts16", "dice", "xxs", "adam"] 20 | 21 | ############################################################## 22 | # Model Hyper-Parameters 23 | ############################################################## 24 | model_name: mobileunetr_xxs 25 | model_parameters: 26 | encoder: None 27 | bottle_neck: 28 | dims: [96] 29 | 
depths: [3] 30 | expansion: 4 31 | kernel_size: 3 32 | patch_size: [2,2] 33 | channels: [80, 96, 96] 34 | decoder: 35 | dims: [64, 80, 96] 36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320] 37 | num_classes: 1 38 | image_size: 512 39 | 40 | ############################################################## 41 | # Loss Function 42 | ############################################################## 43 | loss_fn: 44 | loss_type: "dice" 45 | loss_args: None 46 | 47 | ############################################################## 48 | # Metrics 49 | ############################################################## 50 | metrics: 51 | type: "binary" 52 | mean_iou: 53 | enabled: True 54 | mean_iou_args: 55 | include_background: True 56 | reduction: "mean" 57 | get_not_nans: False 58 | ignore_empty: True 59 | dice: 60 | enabled: False 61 | dice_args: 62 | include_background: True 63 | reduction: "mean" 64 | get_not_nans: False 65 | ignore_empty: True 66 | num_classes: *num_classes 67 | 68 | ############################################################## 69 | # Optimizers 70 | ############################################################## 71 | optimizer: 72 | optimizer_type: "adamw" 73 | optimizer_args: 74 | lr: 0.00006 75 | weight_decay: 0.01 76 | 77 | ############################################################## 78 | # Learning Rate Schedulers 79 | ############################################################## 80 | warmup_scheduler: 81 | enabled: True # should be always true 82 | warmup_epochs: 30 83 | 84 | # train scheduler 85 | train_scheduler: 86 | scheduler_type: 'cosine_annealing_wr' 87 | scheduler_args: 88 | t_0_epochs: 600 89 | t_mult: 1 90 | min_lr: 0.000006 91 | 92 | ############################################################## 93 | # EMA (Exponential Moving Average) 94 | ############################################################## 95 | ema: 96 | enabled: False 97 | ema_decay: 0.999 98 | print_ema_every: 20 99 | 100 | 
############################################################## 101 | # Gradient Clipping 102 | ############################################################## 103 | clip_gradients: 104 | enabled: False 105 | clip_gradients_value: 0.1 106 | 107 | ############################################################## 108 | # Training Hyperparameters 109 | ############################################################## 110 | training_parameters: 111 | seed: 42 112 | num_epochs: 1000 113 | cutoff_epoch: 600 114 | load_optimizer: False 115 | print_every: 1500 116 | calculate_metrics: True 117 | grad_accumulate_steps: 1 # default: 1 118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2" 119 | load_checkpoint: # not implemented yet 120 | load_full_checkpoint: False 121 | load_model_only: False 122 | load_checkpoint_path: None 123 | 124 | ############################################################## 125 | # Dataset and Dataloader Args 126 | ############################################################## 127 | dataset_parameters: 128 | dataset_type: "isic_albumentation_v2" 129 | train_dataset_args: 130 | data_path: "../../../data/train_2016.csv" 131 | train: True 132 | image_size: [512, 512] 133 | val_dataset_args: 134 | data_path: "../../../data/test_2016.csv" 135 | train: False 136 | image_size: [512, 512] 137 | test_dataset_args: 138 | data_path: "../../../data/test_2016.csv" 139 | train: False 140 | image_size: [512, 512] 141 | train_dataloader_args: 142 | batch_size: *batch_size 143 | shuffle: True 144 | num_workers: 4 145 | drop_last: True 146 | pin_memory: True 147 | val_dataloader_args: 148 | batch_size: *batch_size 149 | shuffle: False 150 | num_workers: 4 151 | drop_last: False 152 | pin_memory: True 153 | test_dataloader_args: 154 | batch_size: 1 155 | shuffle: False 156 | num_workers: 4 157 | drop_last: False 158 | pin_memory: True 159 | 160 | ############################################################## 161 | # Data Augmentation Args 162 | 
############################################################## 163 | train_augmentation_args: 164 | mean: [0.485, 0.456, 0.406] 165 | std: [0.229, 0.224, 0.225] 166 | image_size: [512, 512] # [H, W] 167 | 168 | test_augmentation_args: 169 | mean: [0.485, 0.456, 0.406] 170 | std: [0.229, 0.224, 0.225] 171 | image_size: [512, 512] # [H, W] -------------------------------------------------------------------------------- /experiments_medical/isic_2016/exp_2_dice_b8_a2/generate_images.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import numpy as np\n", 11 | "from PIL import Image\n", 12 | "import matplotlib.image as mpimg\n", 13 | "import matplotlib.pyplot as plt" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "img_1 = [\n", 23 | " \"predictions/rgb/image_8.png\",\n", 24 | " \"predictions/gt/gt_8.png\",\n", 25 | " \"predictions/pred/pred_8.png\",\n", 26 | " None,\n", 27 | "]\n", 28 | "img_2 = [\n", 29 | " \"predictions/rgb/image_32.png\",\n", 30 | " \"predictions/gt/gt_32.png\",\n", 31 | " \"predictions/pred/pred_32.png\",\n", 32 | " None,\n", 33 | "]\n", 34 | "img_3 = [\n", 35 | " \"predictions/rgb/image_14.png\",\n", 36 | " \"predictions/gt/gt_14.png\",\n", 37 | " \"predictions/pred/pred_14.png\",\n", 38 | " None,\n", 39 | "]\n", 40 | "img_4 = [\n", 41 | " \"predictions/rgb/image_28.png\",\n", 42 | " \"predictions/gt/gt_28.png\",\n", 43 | " \"predictions/pred/pred_28.png\",\n", 44 | " None,\n", 45 | "]\n", 46 | "img_5 = [\n", 47 | " \"predictions/rgb/image_13.png\",\n", 48 | " \"predictions/gt/gt_13.png\",\n", 49 | " \"predictions/pred/pred_13.png\",\n", 50 | " None,\n", 51 | "]\n", 52 | "\n", 53 | "read_rgb = lambda x: np.array(Image.open(x).convert(\"RGB\"))\n", 54 | 
"read_gray = lambda x: np.array(Image.open(x).convert(\"L\"))\n", 55 | "\n", 56 | "data = [img_1, img_2, img_3, img_4, img_5]" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import numpy as np\n", 66 | "import cv2\n", 67 | "from PIL import Image, ImageDraw\n", 68 | "\n", 69 | "\n", 70 | "# def create_overlay_image(original_img, ground_truth_mask, predicted_mask):\n", 71 | "# # Convert binary masks to uint8 format\n", 72 | "# ground_truth_mask = ground_truth_mask.astype(np.uint8) * 255\n", 73 | "# predicted_mask = predicted_mask.astype(np.uint8) * 255\n", 74 | "\n", 75 | "# # Create an RGB version of the original image\n", 76 | "# original_rgb = original_img\n", 77 | "\n", 78 | "# # Create an alpha channel for overlaying masks\n", 79 | "# alpha_channel = np.zeros_like(original_img)\n", 80 | "\n", 81 | "# # Set red for ground truth and blue for predicted mask in alpha channel\n", 82 | "# alpha_channel[ground_truth_mask > 0] = [255, 0, 0]\n", 83 | "# alpha_channel[predicted_mask > 0] = [0, 0, 255]\n", 84 | "\n", 85 | "# # Combine original image and masks using alpha blending\n", 86 | "# overlayed_img = cv2.addWeighted(original_rgb, 0.7, alpha_channel, 0.3, 0)\n", 87 | "\n", 88 | "# return overlayed_img\n", 89 | "\n", 90 | "import numpy as np\n", 91 | "import matplotlib.pyplot as plt\n", 92 | "from skimage import measure, color\n", 93 | "\n", 94 | "def create_overlay_image(image, ground_truth_mask, predicted_mask):\n", 95 | " \"\"\"\n", 96 | " Overlays the boundaries of the ground truth mask and the predicted mask on the image.\n", 97 | " \n", 98 | " Parameters:\n", 99 | " - image: numpy array, the original image\n", 100 | " - ground_truth_mask: numpy array, binary ground truth mask\n", 101 | " - predicted_mask: numpy array, binary predicted mask\n", 102 | " \n", 103 | " Returns:\n", 104 | " - overlay_image: numpy array, the original image with overlaid boundaries\n", 105 | " 
\"\"\"\n", 106 | " def generate_boundary(mask):\n", 107 | " # Ensure the mask is binary\n", 108 | " mask = mask.astype(bool)\n", 109 | " \n", 110 | " # Find contours at a constant value of 0.5\n", 111 | " contours = measure.find_contours(mask, 0.5)\n", 112 | " \n", 113 | " # Create an empty image to draw the boundaries\n", 114 | " boundary = np.zeros_like(mask, dtype=np.uint8)\n", 115 | " \n", 116 | " # Draw the contours on the boundary image\n", 117 | " for contour in contours:\n", 118 | " for y, x in contour:\n", 119 | " boundary[int(y), int(x)] = 1\n", 120 | " \n", 121 | " return boundary\n", 122 | " \n", 123 | " # Generate boundaries\n", 124 | " ground_truth_boundary = generate_boundary(ground_truth_mask)\n", 125 | " predicted_boundary = generate_boundary(predicted_mask)\n", 126 | " \n", 127 | " # Convert the original image to RGB if it's grayscale\n", 128 | " if len(image.shape) == 2 or image.shape[2] == 1:\n", 129 | " image = color.gray2rgb(image)\n", 130 | " \n", 131 | " # Create a copy of the image to overlay the boundaries\n", 132 | " overlay_image = image.copy()\n", 133 | " \n", 134 | " # Define colors for boundaries\n", 135 | " ground_truth_color = [255, 0, 0] # Red\n", 136 | " predicted_color = [0, 0, 255] # Blue\n", 137 | " \n", 138 | " # Overlay the ground truth boundary\n", 139 | " overlay_image[ground_truth_boundary == 1] = ground_truth_color\n", 140 | " \n", 141 | " # Overlay the predicted boundary\n", 142 | " overlay_image[predicted_boundary == 1] = predicted_color\n", 143 | " \n", 144 | " return overlay_image\n", 145 | "\n", 146 | "\n", 147 | "import matplotlib.pyplot as plt\n", 148 | "import matplotlib.gridspec as gridspec\n", 149 | "from matplotlib.ticker import MultipleLocator\n", 150 | "from PIL import Image\n", 151 | "\n", 152 | "# 'data' is the list of per-sample path lists built in the previous cell\n", 153 | "# Each inner list holds [image path, ground-truth path, prediction path, None]; the None slot marks the overlay row\n", 154 | "\n", 155 | "# Create a grid layout\n", 156 | "rows = 4\n", 157 | "cols = len(data)\n",
158 | "fig = plt.figure(figsize=(cols * 2, rows * 2))\n", 159 | "gs = gridspec.GridSpec(rows, cols, wspace=0.01, hspace=0.01)\n", 160 | "\n", 161 | "# Plot each image, ground truth, and prediction\n", 162 | "for col, sample in enumerate(data):\n", 163 | " for row, img_type in enumerate(sample):\n", 164 | "\n", 165 | " if row != 3:\n", 166 | " ax = plt.subplot(gs[row, col])\n", 167 | " img = Image.open(img_type)\n", 168 | "\n", 169 | " if row == 3:\n", 170 | " img = create_overlay_image(\n", 171 | " np.array(Image.open(sample[0])),\n", 172 | " np.array(Image.open(sample[1]).convert(\"L\")),\n", 173 | " np.array(Image.open(sample[2]).convert(\"L\")),\n", 174 | " )\n", 175 | "\n", 176 | " # Display image\n", 177 | " ax.imshow(img, cmap=\"gray\") # Assuming images are in grayscale\n", 178 | " ax.axis(\"off\")\n", 179 | "\n", 180 | " # Set titles on the leftmost side\n", 181 | " if col == 0 and row != 3:\n", 182 | " ax.text(\n", 183 | " -0.01,\n", 184 | " 0.5,\n", 185 | " f\"({chr(97+row)})\",\n", 186 | " transform=ax.transAxes,\n", 187 | " fontsize=12,\n", 188 | " va=\"center\",\n", 189 | " ha=\"right\",\n", 190 | " )\n", 191 | "\n", 192 | " # # Set titles for each row\n", 193 | " # if row == 0:\n", 194 | " # ax.set_title('Image')\n", 195 | " # elif row == 1:\n", 196 | " # ax.set_title('Ground Truth')\n", 197 | " # elif row == 2:\n", 198 | " # ax.set_title('Prediction')\n", 199 | "\n", 200 | "# Adjust layout\n", 201 | "plt.tight_layout()\n", 202 | "plt.savefig(\"combined_2016_new.png\", bbox_inches=\"tight\")\n", 203 | "plt.show()" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# def plot_images(input_list):\n", 213 | "# num_images = len(input_list)\n", 214 | "\n", 215 | "# # Set up the subplots\n", 216 | "# fig, axes = plt.subplots(3, num_images, squeeze=True)\n", 217 | "\n", 218 | "# # Set titles on the left side\n", 219 | "# titles = [\"Image\", \"Mask\", 
\"Prediction\"]\n", 220 | "# for i, title in enumerate(titles):\n", 221 | "# fig.text(\n", 222 | "# 0.03,\n", 223 | "# 0.5,\n", 224 | "# f\"{title}\",\n", 225 | "# va=\"center\",\n", 226 | "# ha=\"center\",\n", 227 | "# rotation=\"vertical\",\n", 228 | "# fontsize=9,\n", 229 | "# )\n", 230 | "\n", 231 | "# for i, input_paths in enumerate(input_list):\n", 232 | "# img_path, mask_path, pred_path = input_paths\n", 233 | "\n", 234 | "# # Load images\n", 235 | "# img = mpimg.imread(img_path)\n", 236 | "# mask = mpimg.imread(mask_path)\n", 237 | "# pred = mpimg.imread(pred_path)\n", 238 | "\n", 239 | "# axes[2, i].set_aspect(\"equal\")\n", 240 | "\n", 241 | "# # Plot the images\n", 242 | "# axes[0, i].imshow(img)\n", 243 | "# axes[1, i].imshow(mask, cmap=\"gray\")\n", 244 | "# axes[2, i].imshow(pred, cmap=\"gray\")\n", 245 | "\n", 246 | "# # Turn off axis labels\n", 247 | "# for ax in axes[:, i]:\n", 248 | "# ax.axis(\"off\")\n", 249 | "\n", 250 | "# plt.tight_layout(h_pad=0, w_pad=0, pad=0.0001)\n", 251 | "# plt.subplots_adjust(wspace=0.01, hspace=0.01)\n", 252 | "# plt.savefig(\"png_2016.png\")\n", 253 | "# plt.show()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "# plot_images(data)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "# from PIL import Image, ImageDraw\n", 272 | "# import math\n", 273 | "\n", 274 | "\n", 275 | "# def combine_images_with_masks(image_data, output_path):\n", 276 | "# num_images = len(image_data)\n", 277 | "# cols = num_images # Number of columns in the output image grid\n", 278 | "# rows = 3 # Number of rows (images, masks, overlays)\n", 279 | "\n", 280 | "# # Create a blank canvas for the combined image\n", 281 | "# combined_width = cols * 200 # Adjust the width of each image as needed\n", 282 | "# combined_height = rows * 200 # Adjust 
the height of each image as needed\n", 283 | "# combined_image = Image.new(\"RGB\", (combined_width, combined_height), color=\"white\")\n", 284 | "\n", 285 | "# # Paste each image, mask, and overlay onto the canvas in separate rows\n", 286 | "# for i in range(num_images):\n", 287 | "# image_path, mask_path, overlay_path = image_data[i]\n", 288 | "\n", 289 | "# # Load image, mask, and overlay\n", 290 | "# img = Image.open(image_path)\n", 291 | "# mask = Image.open(mask_path).convert(\"L\") # Convert to grayscale if needed\n", 292 | "# overlay = Image.open(overlay_path).convert(\"L\")\n", 293 | "\n", 294 | "# # Resize each image, mask, and overlay as needed\n", 295 | "# img = img.resize((200, 200))\n", 296 | "# mask = mask.resize((200, 200))\n", 297 | "# overlay = overlay.resize((200, 200))\n", 298 | "\n", 299 | "# # Paste image onto the canvas (row 1)\n", 300 | "# x = i * 200\n", 301 | "# y = 0\n", 302 | "# combined_image.paste(img, (x, y))\n", 303 | "\n", 304 | "# # Paste mask onto the canvas (row 2)\n", 305 | "# y = 200\n", 306 | "# combined_image.paste(mask, (x, y))\n", 307 | "\n", 308 | "# # Paste overlay onto the canvas (row 3)\n", 309 | "# y = 400\n", 310 | "# combined_image.paste(overlay, (x, y))\n", 311 | "\n", 312 | "# # Save or display the combined image\n", 313 | "# combined_image.save(output_path)\n", 314 | "# combined_image.show()\n", 315 | "\n", 316 | "\n", 317 | "# output_image_path = \"combined_2016.jpg\"\n", 318 | "# combine_images_with_masks(data, output_image_path)" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [] 327 | } 328 | ], 329 | "metadata": { 330 | "kernelspec": { 331 | "display_name": "corev2", 332 | "language": "python", 333 | "name": "python3" 334 | }, 335 | "language_info": { 336 | "codemirror_mode": { 337 | "name": "ipython", 338 | "version": 3 339 | }, 340 | "file_extension": ".py", 341 | "mimetype": "text/x-python", 342 | "name": 
"python", 343 | "nbconvert_exporter": "python", 344 | "pygments_lexer": "ipython3", 345 | "version": "3.11.7" 346 | } 347 | }, 348 | "nbformat": 4, 349 | "nbformat_minor": 2 350 | } 351 | -------------------------------------------------------------------------------- /experiments_medical/isic_2016/exp_2_dice_b8_a2/inference_model_weights/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/isic_2016/exp_2_dice_b8_a2/inference_model_weights/.gitignore -------------------------------------------------------------------------------- /experiments_medical/isic_2016/exp_2_dice_b8_a2/metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/opt/anaconda3/envs/core/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import os\n", 19 | "import sys\n", 20 | "\n", 21 | "sys.path.append(\"../../../\")\n", 22 | "\n", 23 | "import yaml\n", 24 | "import torch\n", 25 | "import numpy as np\n", 26 | "from typing import Dict\n", 27 | "from architectures.build_architecture import build_architecture\n", 28 | "from dataloaders.build_dataset import build_dataset\n", 29 | "from typing import Tuple, Dict\n", 30 | "from fvcore.nn import FlopCountAnalysis\n", 31 | "from tqdm.notebook import tqdm\n", 32 | "from sklearn.metrics import (\n", 33 | " jaccard_score,\n", 34 | " accuracy_score,\n", 35 | " confusion_matrix,\n", 36 | ")\n", 37 | "import monai" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "Load Config File" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "def load_config(config_path: str) -> Dict:\n", 54 | " \"\"\"Load the YAML config file.\n", 55 | "\n", 56 | " Args:\n", 57 | " config_path (str): path to the YAML config file\n", 58 | "\n", 59 | " Returns:\n", 60 | " Dict: the parsed configuration (via yaml.safe_load)\n", 61 | " \"\"\"\n", 62 | " with open(config_path, \"r\") as file:\n", 63 | " config = yaml.safe_load(file)\n", 64 | " return config\n", 65 | "\n", 66 | "\n", 67 | "config = load_config(\"config.yaml\")" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "Build Dataset and DataLoaders" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# build validation dataset & validation data loader\n", 84 | "testset = build_dataset(\n", 85 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n", 86 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n", 87 | "
augmentation_args=config[\"test_augmentation_args\"],\n", 88 | ")\n", 89 | "\n", 90 | "testloader = torch.utils.data.DataLoader(\n", 91 | " testset, batch_size=1, shuffle=False, num_workers=1\n", 92 | ")" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "Build Model and Load Trained Weights" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 12, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "model = build_architecture(config=config)\n", 109 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n", 110 | "model.load_state_dict(checkpoint)\n", 111 | "model = model.to(\"cpu\")\n", 112 | "model = model.eval()" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "Model Complexity" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 14, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "Computational complexity: 1.3 GMac\n", 132 | "Number of parameters: 3.01 M \n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "import torchvision.models as models\n", 138 | "import torch\n", 139 | "from ptflops import get_model_complexity_info\n", 140 | "\n", 141 | "with torch.cuda.device(0):\n", 142 | " net = model\n", 143 | " macs, params = get_model_complexity_info(\n", 144 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n", 145 | " )\n", 146 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n", 147 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 15, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "name": "stderr", 157 | "output_type": "stream", 158 | "text": [ 159 | "Unsupported operator aten::silu encountered 84 time(s)\n", 160 | "Unsupported operator aten::add 
encountered 45 time(s)\n", 161 | "Unsupported operator aten::div encountered 15 time(s)\n", 162 | "Unsupported operator aten::ceil encountered 6 time(s)\n", 163 | "Unsupported operator aten::mul encountered 118 time(s)\n", 164 | "Unsupported operator aten::softmax encountered 21 time(s)\n", 165 | "Unsupported operator aten::clone encountered 4 time(s)\n", 166 | "Unsupported operator aten::mul_ encountered 24 time(s)\n", 167 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n", 168 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n", 169 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n" 170 | ] 171 | }, 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "Computational complexity: 3.01 \n", 177 | "Number of parameters: 1.24 \n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "def flop_count_analysis(\n", 183 | " model: torch.nn.Module,\n", 184 | " input_dim: Tuple,\n", 185 | ") -> Dict:\n", 186 | " \"\"\"_summary_\n", 187 | "\n", 188 | " Args:\n", 189 | " input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))\n", 190 | " model (torch.nn.Module): _description_\n", 191 | " \"\"\"\n", 192 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 193 | " input_tensor = torch.ones(()).new_empty(\n", 194 | " (1, *input_dim),\n", 195 | " dtype=next(model.parameters()).dtype,\n", 196 | " device=next(model.parameters()).device,\n", 197 | " )\n", 198 | " flops = FlopCountAnalysis(model, input_tensor)\n", 199 | " model_flops = flops.total()\n", 200 | " # print(f\"Total trainable parameters: 
{round(trainable_params * 1e-6, 2)} M\")\n", 201 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n", 202 | "\n", 203 | " out = {\n", 204 | " \"params\": round(trainable_params * 1e-6, 2),\n", 205 | " \"flops\": round(model_flops * 1e-9, 2),\n", 206 | " }\n", 207 | "\n", 208 | " return out\n", 209 | "\n", 210 | "\n", 211 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n", 212 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"params\"]))\n", 213 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"flops\"]))" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "Calculate IoU Metric" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "iou = []\n", 230 | "with torch.no_grad():\n", 231 | " for idx, data in tqdm(enumerate(testloader)):\n", 232 | " image = data[\"image\"].cuda()\n", 233 | " mask = data[\"mask\"].cuda()\n", 234 | " out = model.forward(image)\n", 235 | " out = torch.sigmoid(out)\n", 236 | " out[out < 0.5] = 0\n", 237 | " out[out >= 0.5] = 1\n", 238 | " mean_iou = jaccard_score(\n", 239 | " mask.detach().cpu().numpy().ravel(),\n", 240 | " out.detach().cpu().numpy().ravel(),\n", 241 | " average=\"binary\",\n", 242 | " pos_label=1,\n", 243 | " )\n", 244 | " iou.append(mean_iou.item())\n", 245 | "\n", 246 | "print(f\"test iou: {np.mean(iou)}\")" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Accuracy" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "accuracy = []\n", 263 | "with torch.no_grad():\n", 264 | " for idx, data in tqdm(enumerate(testloader)):\n", 265 | " image = data[\"image\"].cuda()\n", 266 | " mask = data[\"mask\"].cuda()\n", 267 | " out = 
model.forward(image)\n", 268 | " out = torch.sigmoid(out)\n", 269 | " out[out < 0.5] = 0\n", 270 | " out[out >= 0.5] = 1\n", 271 | " acc = accuracy_score(\n", 272 | " mask.detach().cpu().numpy().ravel(),\n", 273 | " out.detach().cpu().numpy().ravel(),\n", 274 | " )\n", 275 | " accuracy.append(acc.item())\n", 276 | "\n", 277 | "print(f\"test accuracy: {np.mean(accuracy)}\")" 278 | ] 279 | }, 280 | { 281 | "attachments": {}, 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "Calculate Dice" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "dice = []\n", 295 | "with torch.no_grad():\n", 296 | " for idx, data in tqdm(enumerate(testloader)):\n", 297 | " image = data[\"image\"].cuda()\n", 298 | " mask = data[\"mask\"].cuda()\n", 299 | " out = model.forward(image)\n", 300 | " out = torch.sigmoid(out)\n", 301 | " out[out < 0.5] = 0\n", 302 | " out[out >= 0.5] = 1\n", 303 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n", 304 | " dice.append(mean_dice.item())\n", 305 | "\n", 306 | "print(f\"test dice: {np.mean(dice)}\")" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "Calculate Specificity" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "specificity = []\n", 323 | "with torch.no_grad():\n", 324 | " for idx, data in tqdm(enumerate(testloader)):\n", 325 | " image = data[\"image\"].cuda()\n", 326 | " mask = data[\"mask\"].cuda()\n", 327 | " out = model.forward(image)\n", 328 | " out = torch.sigmoid(out)\n", 329 | " out[out < 0.5] = 0\n", 330 | " out[out >= 0.5] = 1\n", 331 | " confusion = confusion_matrix(\n", 332 | " mask.detach().cpu().numpy().ravel(),\n", 333 | " out.detach().cpu().numpy().ravel(),\n", 334 | " )\n", 335 | " if float(confusion[0, 0] + confusion[0, 1]) != 
0:\n", 336 | " sp = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])\n", 337 | "\n", 338 | " specificity.append(sp)\n", 339 | "\n", 340 | "print(f\"test specificity: {np.mean(specificity)}\")" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": {}, 346 | "source": [ 347 | "Calculate Sensitivity" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "sensitivity = []\n", 357 | "with torch.no_grad():\n", 358 | " for idx, data in tqdm(enumerate(testloader)):\n", 359 | " image = data[\"image\"].cuda()\n", 360 | " mask = data[\"mask\"].cuda()\n", 361 | " out = model.forward(image)\n", 362 | " out = torch.sigmoid(out)\n", 363 | " out[out < 0.5] = 0\n", 364 | " out[out >= 0.5] = 1\n", 365 | " confusion = confusion_matrix(\n", 366 | " mask.detach().cpu().numpy().ravel(),\n", 367 | " out.detach().cpu().numpy().ravel(),\n", 368 | " )\n", 369 | " if float(confusion[1, 1] + confusion[1, 0]) != 0:\n", 370 | " se = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])\n", 371 | "\n", 372 | " sensitivity.append(se)\n", 373 | "\n", 374 | "print(f\"test sensitivity: {np.mean(sensitivity)}\")" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "# DONE" 384 | ] 385 | } 386 | ], 387 | "metadata": { 388 | "kernelspec": { 389 | "display_name": "core", 390 | "language": "python", 391 | "name": "python3" 392 | }, 393 | "language_info": { 394 | "codemirror_mode": { 395 | "name": "ipython", 396 | "version": 3 397 | }, 398 | "file_extension": ".py", 399 | "mimetype": "text/x-python", 400 | "name": "python", 401 | "nbconvert_exporter": "python", 402 | "pygments_lexer": "ipython3", 403 | "version": "3.11.9" 404 | }, 405 | "orig_nbformat": 4, 406 | "vscode": { 407 | "interpreter": { 408 | "hash": 
"db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea" 409 | } 410 | } 411 | }, 412 | "nbformat": 4, 413 | "nbformat_minor": 2 414 | } 415 | -------------------------------------------------------------------------------- /experiments_medical/isic_2016/exp_2_dice_b8_a2/run_experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | sys.path.append("../../../") 6 | 7 | import yaml 8 | import torch 9 | import argparse 10 | import numpy as np 11 | from typing import Dict 12 | from termcolor import colored 13 | from accelerate import Accelerator 14 | from losses.losses import build_loss_fn 15 | from optimizers.optimizers import build_optimizer 16 | from optimizers.schedulers import build_scheduler 17 | from train_scripts.segmentation_trainer import Segmentation_Trainer 18 | from architectures.build_architecture import build_architecture 19 | from dataloaders.build_dataset import build_dataset, build_dataloader 20 | 21 | 22 | ################################################################################################## 23 | def launch_experiment(config_path) -> Dict: 24 | """ 25 | Builds Experiment 26 | Args: 27 | config (Dict): configuration file 28 | 29 | Returns: 30 | Dict: _description_ 31 | """ 32 | # load config 33 | config = load_config(config_path) 34 | 35 | # set seed 36 | seed_everything(config) 37 | 38 | # build directories 39 | build_directories(config) 40 | 41 | # build training dataset & training data loader 42 | trainset = build_dataset( 43 | dataset_type=config["dataset_parameters"]["dataset_type"], 44 | dataset_args=config["dataset_parameters"]["train_dataset_args"], 45 | augmentation_args=config["train_augmentation_args"], 46 | ) 47 | trainloader = build_dataloader( 48 | dataset=trainset, 49 | dataloader_args=config["dataset_parameters"]["train_dataloader_args"], 50 | config=config, 51 | train=True, 52 | ) 53 | 54 | # build validation dataset & 
validataion data loader 55 | valset = build_dataset( 56 | dataset_type=config["dataset_parameters"]["dataset_type"], 57 | dataset_args=config["dataset_parameters"]["val_dataset_args"], 58 | augmentation_args=config["test_augmentation_args"], 59 | ) 60 | valloader = build_dataloader( 61 | dataset=valset, 62 | dataloader_args=config["dataset_parameters"]["val_dataloader_args"], 63 | config=config, 64 | train=False, 65 | ) 66 | 67 | # build the Model 68 | model = build_architecture(config) 69 | 70 | # set up the loss function 71 | criterion = build_loss_fn( 72 | loss_type=config["loss_fn"]["loss_type"], 73 | loss_args=config["loss_fn"]["loss_args"], 74 | ) 75 | 76 | # set up the optimizer 77 | optimizer = build_optimizer( 78 | model=model, 79 | optimizer_type=config["optimizer"]["optimizer_type"], 80 | optimizer_args=config["optimizer"]["optimizer_args"], 81 | ) 82 | 83 | # set up schedulers 84 | warmup_scheduler = build_scheduler( 85 | optimizer=optimizer, scheduler_type="warmup_scheduler", config=config 86 | ) 87 | training_scheduler = build_scheduler( 88 | optimizer=optimizer, 89 | scheduler_type="training_scheduler", 90 | config=config, 91 | ) 92 | 93 | # use accelarate 94 | accelerator = Accelerator( 95 | log_with="wandb", 96 | gradient_accumulation_steps=config["training_parameters"][ 97 | "grad_accumulate_steps" 98 | ], 99 | ) 100 | accelerator.init_trackers( 101 | project_name=config["project"], 102 | config=config, 103 | init_kwargs={"wandb": config["wandb_parameters"]}, 104 | ) 105 | 106 | # display experiment info 107 | display_info(config, accelerator, trainset, valset, model) 108 | 109 | # convert all components to accelerate 110 | model = accelerator.prepare_model(model=model) 111 | optimizer = accelerator.prepare_optimizer(optimizer=optimizer) 112 | trainloader = accelerator.prepare_data_loader(data_loader=trainloader) 113 | valloader = accelerator.prepare_data_loader(data_loader=valloader) 114 | warmup_scheduler = 
accelerator.prepare_scheduler(scheduler=warmup_scheduler) 115 | training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler) 116 | 117 | # create a single dict to hold all parameters 118 | storage = { 119 | "model": model, 120 | "trainloader": trainloader, 121 | "valloader": valloader, 122 | "criterion": criterion, 123 | "optimizer": optimizer, 124 | "warmup_scheduler": warmup_scheduler, 125 | "training_scheduler": training_scheduler, 126 | } 127 | 128 | # set up trainer 129 | trainer = Segmentation_Trainer( 130 | config=config, 131 | model=storage["model"], 132 | optimizer=storage["optimizer"], 133 | criterion=storage["criterion"], 134 | train_dataloader=storage["trainloader"], 135 | val_dataloader=storage["valloader"], 136 | warmup_scheduler=storage["warmup_scheduler"], 137 | training_scheduler=storage["training_scheduler"], 138 | accelerator=accelerator, 139 | ) 140 | 141 | # run train 142 | trainer.train() 143 | 144 | 145 | ################################################################################################## 146 | def seed_everything(config) -> None: 147 | seed = config["training_parameters"]["seed"] 148 | os.environ["PYTHONHASHSEED"] = str(seed) 149 | random.seed(seed) 150 | np.random.seed(seed) 151 | torch.manual_seed(seed) 152 | torch.cuda.manual_seed(seed) 153 | torch.cuda.manual_seed_all(seed) 154 | torch.backends.cudnn.deterministic = True 155 | torch.backends.cudnn.benchmark = False 156 | 157 | 158 | ################################################################################################## 159 | def load_config(config_path: str) -> Dict: 160 | """loads the yaml config file 161 | 162 | Args: 163 | config_path (str): _description_ 164 | 165 | Returns: 166 | Dict: _description_ 167 | """ 168 | with open(config_path, "r") as file: 169 | config = yaml.safe_load(file) 170 | return config 171 | 172 | 173 | ################################################################################################## 174 | def 
build_directories(config: Dict) -> None: 175 | # create necessary directories 176 | if not os.path.exists(config["training_parameters"]["checkpoint_save_dir"]): 177 | os.makedirs(config["training_parameters"]["checkpoint_save_dir"]) 178 | 179 | if os.listdir(config["training_parameters"]["checkpoint_save_dir"]): 180 | raise ValueError("checkpoint exits -- preventing file override -- rename file") 181 | 182 | 183 | ################################################################################################## 184 | def display_info(config, accelerator, trainset, valset, model): 185 | # print experiment info 186 | accelerator.print(f"-------------------------------------------------------") 187 | accelerator.print(f"[info]: Experiment Info") 188 | accelerator.print( 189 | f"[info] ----- Project: {colored(config['project'], color='red')}" 190 | ) 191 | accelerator.print( 192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}" 193 | ) 194 | accelerator.print( 195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}" 196 | ) 197 | accelerator.print( 198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['val_dataloader_args']['batch_size'], color='red')}" 199 | ) 200 | accelerator.print( 201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}" 202 | ) 203 | accelerator.print( 204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}" 205 | ) 206 | accelerator.print( 207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}" 208 | ) 209 | accelerator.print( 210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}" 211 | ) 212 | accelerator.print( 213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}" 214 | ) 215 | 216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 217 | accelerator.print( 218 | f"[info] 
----- Distributed Training: {colored('True' if torch.cuda.device_count() > 1 else 'False', color='red')}" 219 | ) 220 | accelerator.print( 221 | f"[info] ----- Num Clases: {colored(config['variables']['num_classes'], color='red')}" 222 | ) 223 | accelerator.print( 224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}" 225 | ) 226 | accelerator.print( 227 | f"[info] ----- Load From Checkpoint: {colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}" 228 | ) 229 | accelerator.print( 230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}" 231 | ) 232 | accelerator.print(f"-------------------------------------------------------") 233 | 234 | 235 | ################################################################################################## 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser(description="Simple example of training script.") 238 | parser.add_argument( 239 | "--config", type=str, default="config.yaml", help="path to yaml config file" 240 | ) 241 | args = parser.parse_args() 242 | launch_experiment(args.config) 243 | -------------------------------------------------------------------------------- /experiments_medical/isic_2016/exp_2_dice_b8_a2/visualize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../../\")\n", 13 | "\n", 14 | "import yaml\n", 15 | "import torch\n", 16 | "from typing import Dict\n", 17 | "from architectures.build_architecture import build_architecture\n", 18 | "from dataloaders.build_dataset import build_dataset\n", 19 | "from torchvision.utils import save_image" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 4, 25 | "metadata": {}, 26 | "outputs": [], 27 | 
"source": [ 28 | "def load_config(config_path: str) -> Dict:\n", 29 | " \"\"\"loads the yaml config file\n", 30 | "\n", 31 | " Args:\n", 32 | " config_path (str): _description_\n", 33 | "\n", 34 | " Returns:\n", 35 | " Dict: _description_\n", 36 | " \"\"\"\n", 37 | " with open(config_path, \"r\") as file:\n", 38 | " config = yaml.safe_load(file)\n", 39 | " return config" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 5, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "config = load_config(\"config.yaml\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "Set Up" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 7, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "model = build_architecture(config=config)\n", 65 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n", 66 | "model.load_state_dict(checkpoint)\n", 67 | "model = model.to(\"cpu\")\n", 68 | "model = model.eval()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 9, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# build validation dataset & validataion data loader\n", 78 | "valset = build_dataset(\n", 79 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n", 80 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n", 81 | " augmentation_args=config[\"test_augmentation_args\"],\n", 82 | ")\n", 83 | "testloader = torch.utils.data.DataLoader(\n", 84 | " valset, batch_size=1, shuffle=False, num_workers=1\n", 85 | ")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "Inference" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 10, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "if not os.path.isdir(\"predictions\"):\n", 102 | " os.makedirs(\"predictions/rgb\")\n", 103 | " os.makedirs(\"predictions/gt\")\n", 104 | " 
os.makedirs(\"predictions/pred\")" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 12, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "torch.Size([1, 3, 512, 512])\n", 117 | "torch.Size([1, 3, 512, 512])\n", 118 | "torch.Size([1, 3, 512, 512])\n", 119 | "torch.Size([1, 3, 512, 512])\n", 120 | "torch.Size([1, 3, 512, 512])\n", 121 | "torch.Size([1, 3, 512, 512])\n", 122 | "torch.Size([1, 3, 512, 512])\n", 123 | "torch.Size([1, 3, 512, 512])\n", 124 | "torch.Size([1, 3, 512, 512])\n", 125 | "torch.Size([1, 3, 512, 512])\n", 126 | "torch.Size([1, 3, 512, 512])\n", 127 | "torch.Size([1, 3, 512, 512])\n", 128 | "torch.Size([1, 3, 512, 512])\n", 129 | "torch.Size([1, 3, 512, 512])\n", 130 | "torch.Size([1, 3, 512, 512])\n", 131 | "torch.Size([1, 3, 512, 512])\n", 132 | "torch.Size([1, 3, 512, 512])\n", 133 | "torch.Size([1, 3, 512, 512])\n", 134 | "torch.Size([1, 3, 512, 512])\n", 135 | "torch.Size([1, 3, 512, 512])\n", 136 | "torch.Size([1, 3, 512, 512])\n", 137 | "torch.Size([1, 3, 512, 512])\n", 138 | "torch.Size([1, 3, 512, 512])\n", 139 | "torch.Size([1, 3, 512, 512])\n", 140 | "torch.Size([1, 3, 512, 512])\n", 141 | "torch.Size([1, 3, 512, 512])\n", 142 | "torch.Size([1, 3, 512, 512])\n", 143 | "torch.Size([1, 3, 512, 512])\n", 144 | "torch.Size([1, 3, 512, 512])\n", 145 | "torch.Size([1, 3, 512, 512])\n", 146 | "torch.Size([1, 3, 512, 512])\n", 147 | "torch.Size([1, 3, 512, 512])\n", 148 | "torch.Size([1, 3, 512, 512])\n", 149 | "torch.Size([1, 3, 512, 512])\n", 150 | "torch.Size([1, 3, 512, 512])\n", 151 | "torch.Size([1, 3, 512, 512])\n", 152 | "torch.Size([1, 3, 512, 512])\n", 153 | "torch.Size([1, 3, 512, 512])\n", 154 | "torch.Size([1, 3, 512, 512])\n", 155 | "torch.Size([1, 3, 512, 512])\n", 156 | "torch.Size([1, 3, 512, 512])\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "model.eval()\n", 162 | "counter = 40\n", 163 | "for idx, data in 
enumerate(testloader):\n", 164 | " image = data[\"image\"].cuda()\n", 165 | " mask = data[\"mask\"].cuda()\n", 166 | " out = model.forward(image)\n", 167 | " out = torch.sigmoid(out)\n", 168 | " out[out < 0.5] = 0\n", 169 | " out[out >= 0.5] = 1\n", 170 | "\n", 171 | " # rgb input\n", 172 | " img = data[\"image\"]\n", 173 | " img = img.detach()\n", 174 | " print(img.shape)\n", 175 | " img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n", 176 | " img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n", 177 | " img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n", 178 | " save_image(img, f\"predictions/rgb/image_{idx}.png\")\n", 179 | "\n", 180 | " # prediction\n", 181 | " pred = out.detach()\n", 182 | " pred = pred * 255.0\n", 183 | " save_image(pred, f\"predictions/pred/pred_{idx}.png\")\n", 184 | "\n", 185 | " # ground truth\n", 186 | " gt = data[\"mask\"]\n", 187 | " gt = gt.detach()\n", 188 | " gt = gt * 255.0\n", 189 | " save_image(gt, f\"predictions/gt/gt_{idx}.png\")\n", 190 | "\n", 191 | " if idx == counter:\n", 192 | " break" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "Save RGB, GT and Pred" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 14, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# # RGB INPUT\n", 209 | "\n", 210 | "# img = data[\"image\"]\n", 211 | "# img = img.detach()\n", 212 | "# print(img.shape)\n", 213 | "# img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n", 214 | "# img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n", 215 | "# img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n", 216 | "# save_image(img, f\"predictions/rgb/image_{idx}.png\")" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 15, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "# # PREDICTION\n", 226 | "\n", 227 | "# pred = out.detach()\n", 228 | "# pred = pred * 255.0\n", 229 | "# save_image(pred, 
f\"predictions/pred/pred_{idx}.png\")" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 16, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "# # GROUND TRUTH\n", 239 | "\n", 240 | "# gt = data[1]\n", 241 | "# gt = gt.detach()\n", 242 | "# gt = gt * 255.0\n", 243 | "# save_image(gt, f\"predictions/gt/gt_{idx}.png\")" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "corev2", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.11.7" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 2 275 | } 276 | -------------------------------------------------------------------------------- /experiments_medical/isic_2017/exp_2_dice_b8_a2/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | # Commonly Used Variables 3 | ############################################################## 4 | variables: 5 | batch_size: &batch_size 8 6 | num_channels: &num_channels 3 7 | num_classes: &num_classes 1 8 | num_epochs: &num_epochs 400 9 | 10 | ############################################################## 11 | # Wandb Model Tracking 12 | ############################################################## 13 | project: mobileunetr 14 | wandb_parameters: 15 | group: isic_2017 16 | name: exp_2_dice_b8_a2 17 | mode: "online" 18 | resume: False 19 | tags: ["tr17ts17", "dice", "xxs", "adam"] 20 | 21 | ############################################################## 22 | # 
Model Hyper-Parameters 23 | ############################################################## 24 | model_name: mobileunetr_xxs 25 | model_parameters: 26 | encoder: None 27 | bottle_neck: 28 | dims: [96] 29 | depths: [3] 30 | expansion: 4 31 | kernel_size: 3 32 | patch_size: [2,2] 33 | channels: [80, 96, 96] 34 | decoder: 35 | dims: [64, 80, 96] 36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320] 37 | num_classes: 1 38 | image_size: 512 39 | 40 | ############################################################## 41 | # Loss Function 42 | ############################################################## 43 | loss_fn: 44 | loss_type: "dice" 45 | loss_args: None 46 | 47 | ############################################################## 48 | # Metrics 49 | ############################################################## 50 | metrics: 51 | type: "binary" 52 | mean_iou: 53 | enabled: True 54 | mean_iou_args: 55 | include_background: True 56 | reduction: "mean" 57 | get_not_nans: False 58 | ignore_empty: True 59 | dice: 60 | enabled: False 61 | dice_args: 62 | include_background: True 63 | reduction: "mean" 64 | get_not_nans: False 65 | ignore_empty: True 66 | num_classes: *num_classes 67 | 68 | ############################################################## 69 | # Optimizers 70 | ############################################################## 71 | optimizer: 72 | optimizer_type: "adamw" 73 | optimizer_args: 74 | lr: 0.00006 75 | weight_decay: 0.01 76 | 77 | ############################################################## 78 | # Learning Rate Schedulers 79 | ############################################################## 80 | warmup_scheduler: 81 | enabled: True # should be always true 82 | warmup_epochs: 30 83 | 84 | # train scheduler 85 | train_scheduler: 86 | scheduler_type: 'cosine_annealing_wr' 87 | scheduler_args: 88 | t_0_epochs: 600 89 | t_mult: 1 90 | min_lr: 0.000006 91 | 92 | ############################################################## 93 | # EMA (Exponential 
Moving Average) 94 | ############################################################## 95 | ema: 96 | enabled: False 97 | ema_decay: 0.999 98 | print_ema_every: 20 99 | 100 | ############################################################## 101 | # Gradient Clipping 102 | ############################################################## 103 | clip_gradients: 104 | enabled: False 105 | clip_gradients_value: 0.1 106 | 107 | ############################################################## 108 | # Training Hyperparameters 109 | ############################################################## 110 | training_parameters: 111 | seed: 42 112 | num_epochs: 1000 113 | cutoff_epoch: 600 114 | load_optimizer: False 115 | print_every: 1500 116 | calculate_metrics: True 117 | grad_accumulate_steps: 1 # default: 1 118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2" 119 | load_checkpoint: # not implemented yet 120 | load_full_checkpoint: False 121 | load_model_only: False 122 | load_checkpoint_path: None 123 | 124 | ############################################################## 125 | # Dataset and Dataloader Args 126 | ############################################################## 127 | dataset_parameters: 128 | dataset_type: "isic_albumentation_v2" 129 | train_dataset_args: 130 | data_path: "../../../data/train_2017.csv" 131 | train: True 132 | image_size: [512, 512] 133 | val_dataset_args: 134 | data_path: "../../../data/test_2017.csv" 135 | train: False 136 | image_size: [512, 512] 137 | test_dataset_args: 138 | data_path: "../../../data/test_2017.csv" 139 | train: False 140 | image_size: [512, 512] 141 | train_dataloader_args: 142 | batch_size: *batch_size 143 | shuffle: True 144 | num_workers: 4 145 | drop_last: True 146 | pin_memory: True 147 | val_dataloader_args: 148 | batch_size: *batch_size 149 | shuffle: False 150 | num_workers: 4 151 | drop_last: False 152 | pin_memory: True 153 | test_dataloader_args: 154 | batch_size: 1 155 | shuffle: False 156 | num_workers: 4 157 
| drop_last: False 158 | pin_memory: True 159 | 160 | ############################################################## 161 | # Data Augmentation Args 162 | ############################################################## 163 | train_augmentation_args: 164 | mean: [0.485, 0.456, 0.406] 165 | std: [0.229, 0.224, 0.225] 166 | image_size: [512, 512] # [H, W] 167 | 168 | test_augmentation_args: 169 | mean: [0.485, 0.456, 0.406] 170 | std: [0.229, 0.224, 0.225] 171 | image_size: [512, 512] # [H, W] -------------------------------------------------------------------------------- /experiments_medical/isic_2017/exp_2_dice_b8_a2/inference_model_weights/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/isic_2017/exp_2_dice_b8_a2/inference_model_weights/.gitignore -------------------------------------------------------------------------------- /experiments_medical/isic_2017/exp_2_dice_b8_a2/metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../../\")\n", 13 | "\n", 14 | "import yaml\n", 15 | "import torch\n", 16 | "import numpy as np\n", 17 | "from typing import Dict\n", 18 | "from architectures.build_architecture import build_architecture\n", 19 | "from dataloaders.build_dataset import build_dataset\n", 20 | "from typing import Tuple, Dict\n", 21 | "from fvcore.nn import FlopCountAnalysis\n", 22 | "from tqdm.notebook import tqdm\n", 23 | "from sklearn.metrics import (\n", 24 | " jaccard_score,\n", 25 | " accuracy_score,\n", 26 | " confusion_matrix,\n", 27 | ")\n", 28 | "import monai" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 
35 | "Load Config File" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def load_config(config_path: str) -> Dict:\n", 45 | " \"\"\"loads the yaml config file\n", 46 | "\n", 47 | " Args:\n", 48 | " config_path (str): _description_\n", 49 | "\n", 50 | " Returns:\n", 51 | " Dict: _description_\n", 52 | " \"\"\"\n", 53 | " with open(config_path, \"r\") as file:\n", 54 | " config = yaml.safe_load(file)\n", 55 | " return config\n", 56 | "\n", 57 | "\n", 58 | "config = load_config(\"config.yaml\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Build Dataset and DataLoaders" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "600\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "# build validation dataset & validataion data loader\n", 83 | "testset = build_dataset(\n", 84 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n", 85 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n", 86 | " augmentation_args=config[\"test_augmentation_args\"],\n", 87 | ")\n", 88 | "\n", 89 | "testloader = torch.utils.data.DataLoader(\n", 90 | " testset, batch_size=1, shuffle=False, num_workers=1\n", 91 | ")\n", 92 | "\n", 93 | "print(len(testset))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "model = build_architecture(config=config)\n", 103 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n", 104 | "model.load_state_dict(checkpoint)\n", 105 | "model = model.to(\"cpu\")\n", 106 | "model = model.eval()" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "Model Complexity" 114 | ] 115 | }, 116 | { 117 | 
"cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Computational complexity: 1.3 GMac\n", 126 | "Number of parameters: 3.01 M \n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "import torchvision.models as models\n", 132 | "import torch\n", 133 | "from ptflops import get_model_complexity_info\n", 134 | "\n", 135 | "with torch.cuda.device(0):\n", 136 | " net = model\n", 137 | " macs, params = get_model_complexity_info(\n", 138 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n", 139 | " )\n", 140 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n", 141 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stderr", 151 | "output_type": "stream", 152 | "text": [ 153 | "Unsupported operator aten::silu encountered 84 time(s)\n", 154 | "Unsupported operator aten::add encountered 45 time(s)\n", 155 | "Unsupported operator aten::div encountered 15 time(s)\n", 156 | "Unsupported operator aten::ceil encountered 6 time(s)\n", 157 | "Unsupported operator aten::mul encountered 118 time(s)\n", 158 | "Unsupported operator aten::softmax encountered 21 time(s)\n", 159 | "Unsupported operator aten::clone encountered 4 time(s)\n", 160 | "Unsupported operator aten::mul_ encountered 24 time(s)\n", 161 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n", 162 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. 
In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n", 163 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n" 164 | ] 165 | }, 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "Computational complexity: 1.24 \n", 171 | "Number of parameters: 3.01 \n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "def flop_count_analysis(\n", 177 | " model: torch.nn.Module,\n", 178 | " input_dim: Tuple,\n", 179 | ") -> Dict:\n", 180 | " \"\"\"Return trainable parameters (millions) and fvcore FLOPs (billions) for one dummy input.\n", 181 | "\n", 182 | " Args:\n", 183 | " input_dim (Tuple): per-sample shape (C, H, W[, D]); a batch dimension of 1 is prepended\n", 184 | " model (torch.nn.Module): network to analyze; the dummy input takes its dtype and device\n", 185 | " \"\"\"\n", 186 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 187 | " input_tensor = torch.ones(()).new_empty(\n", 188 | " (1, *input_dim),\n", 189 | " dtype=next(model.parameters()).dtype,\n", 190 | " device=next(model.parameters()).device,\n", 191 | " )\n", 192 | " flops = FlopCountAnalysis(model, input_tensor)\n", 193 | " model_flops = flops.total()\n", 194 | " # print(f\"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M\")\n", 195 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n", 196 | "\n", 197 | " out = {\n", 198 | " \"params\": round(trainable_params * 1e-6, 2),\n", 199 | " \"flops\": round(model_flops * 1e-9, 2),\n", 200 | " }\n", 201 | "\n", 202 | " return out\n", 203 | "\n", 204 | "\n", 205 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n", 206 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"flops\"]))\n", 207 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"params\"]))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "Calculate 
IoU Metric" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 7, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "application/vnd.jupyter.widget-view+json": { 225 | "model_id": "a074d62a4c3d4a14ac6e4f99ca7abe73", 226 | "version_major": 2, 227 | "version_minor": 0 228 | }, 229 | "text/plain": [ 230 | "0it [00:00, ?it/s]" 231 | ] 232 | }, 233 | "metadata": {}, 234 | "output_type": "display_data" 235 | }, 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "test iou: 0.7900467007312545\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "iou = []\n", 246 | "with torch.no_grad():\n", 247 | " for idx, data in tqdm(enumerate(testloader)):\n", 248 | " image = data[\"image\"].cuda()\n", 249 | " mask = data[\"mask\"].cuda()\n", 250 | " out = model.forward(image)\n", 251 | " out = torch.sigmoid(out)\n", 252 | " out[out < 0.5] = 0\n", 253 | " out[out >= 0.5] = 1\n", 254 | " mean_iou = jaccard_score(\n", 255 | " mask.detach().cpu().numpy().ravel(),\n", 256 | " out.detach().cpu().numpy().ravel(),\n", 257 | " average=\"binary\",\n", 258 | " pos_label=1,\n", 259 | " )\n", 260 | " iou.append(mean_iou.item())\n", 261 | "\n", 262 | "print(f\"test iou: {np.mean(iou)}\")" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "Accuracy" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 8, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "application/vnd.jupyter.widget-view+json": { 280 | "model_id": "63a0af5770784eeabded405e89276b4f", 281 | "version_major": 2, 282 | "version_minor": 0 283 | }, 284 | "text/plain": [ 285 | "0it [00:00, ?it/s]" 286 | ] 287 | }, 288 | "metadata": {}, 289 | "output_type": "display_data" 290 | }, 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "test accuracy: 0.9445873578389485\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "accuracy = 
[]\n", 301 | "with torch.no_grad():\n", 302 | " for idx, data in tqdm(enumerate(testloader)):\n", 303 | " image = data[\"image\"].cuda()\n", 304 | " mask = data[\"mask\"].cuda()\n", 305 | " out = model.forward(image)\n", 306 | " out = torch.sigmoid(out)\n", 307 | " out[out < 0.5] = 0\n", 308 | " out[out >= 0.5] = 1\n", 309 | " acc = accuracy_score(\n", 310 | " mask.detach().cpu().numpy().ravel(),\n", 311 | " out.detach().cpu().numpy().ravel(),\n", 312 | " )\n", 313 | " accuracy.append(acc.item())\n", 314 | "\n", 315 | "print(f\"test accuracy: {np.mean(accuracy)}\")" 316 | ] 317 | }, 318 | { 319 | "attachments": {}, 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "Calculate Dice" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 9, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "data": { 333 | "application/vnd.jupyter.widget-view+json": { 334 | "model_id": "cf1cb2544a624ec6a99862601b08d61b", 335 | "version_major": 2, 336 | "version_minor": 0 337 | }, 338 | "text/plain": [ 339 | "0it [00:00, ?it/s]" 340 | ] 341 | }, 342 | "metadata": {}, 343 | "output_type": "display_data" 344 | }, 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "test dice: 0.8683788800487916\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "dice = []\n", 355 | "with torch.no_grad():\n", 356 | " for idx, data in tqdm(enumerate(testloader)):\n", 357 | " image = data[\"image\"].cuda()\n", 358 | " mask = data[\"mask\"].cuda()\n", 359 | " out = model.forward(image)\n", 360 | " out = torch.sigmoid(out)\n", 361 | " out[out < 0.5] = 0\n", 362 | " out[out >= 0.5] = 1\n", 363 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n", 364 | " dice.append(mean_dice.item())\n", 365 | "\n", 366 | "print(f\"test dice: {np.mean(dice)}\")" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "Calculate Specificity" 374 | ] 375 | }, 376 | { 377 | 
"cell_type": "code", 378 | "execution_count": 10, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "application/vnd.jupyter.widget-view+json": { 384 | "model_id": "f5a40452c5c648a196704e679826d4d5", 385 | "version_major": 2, 386 | "version_minor": 0 387 | }, 388 | "text/plain": [ 389 | "0it [00:00, ?it/s]" 390 | ] 391 | }, 392 | "metadata": {}, 393 | "output_type": "display_data" 394 | }, 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "test specificity: 0.969295418576087\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "specificity = []\n", 405 | "with torch.no_grad():\n", 406 | " for idx, data in tqdm(enumerate(testloader)):\n", 407 | " image = data[\"image\"].cuda()\n", 408 | " mask = data[\"mask\"].cuda()\n", 409 | " out = model.forward(image)\n", 410 | " out = torch.sigmoid(out)\n", 411 | " out[out < 0.5] = 0\n", 412 | " out[out >= 0.5] = 1\n", 413 | " confusion = confusion_matrix(\n", 414 | " mask.detach().cpu().numpy().ravel(), out.detach().cpu().numpy().ravel(),\n", 415 | " labels=[0, 1],\n", 416 | " )\n", 417 | " denom = float(confusion[0, 0] + confusion[0, 1])\n", 418 | " sp = float(confusion[0, 0]) / denom if denom != 0 else float(\"nan\")\n", 419 | "\n", 420 | " specificity.append(sp)\n", 421 | "\n", 422 | "print(f\"test specificity: {np.nanmean(specificity)}\")" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "Calculate Sensitivity" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 11, 435 | "metadata": {}, 436 | "outputs": [ 437 | { 438 | "data": { 439 | "application/vnd.jupyter.widget-view+json": { 440 | "model_id": "f880c143f40c429e9125160dfba20fc7", 441 | "version_major": 2, 442 | "version_minor": 0 443 | }, 444 | "text/plain": [ 445 | "0it [00:00, ?it/s]" 446 | ] 447 | }, 448 | "metadata": {}, 449 | "output_type": "display_data" 450 | }, 451 | { 452 | "name": "stdout", 453 | "output_type": "stream", 454 | "text": [ 455 | 
"test sensitivity: 0.8518193023236439\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "sensitivity = []\n", 461 | "with torch.no_grad():\n", 462 | " for idx, data in tqdm(enumerate(testloader)):\n", 463 | " image = data[\"image\"].cuda()\n", 464 | " mask = data[\"mask\"].cuda()\n", 465 | " out = model.forward(image)\n", 466 | " out = torch.sigmoid(out)\n", 467 | " out[out < 0.5] = 0\n", 468 | " out[out >= 0.5] = 1\n", 469 | " confusion = confusion_matrix(\n", 470 | " mask.detach().cpu().numpy().ravel(), out.detach().cpu().numpy().ravel(),\n", 471 | " labels=[0, 1],\n", 472 | " )\n", 473 | " denom = float(confusion[1, 1] + confusion[1, 0])\n", 474 | " se = float(confusion[1, 1]) / denom if denom != 0 else float(\"nan\")\n", 475 | "\n", 476 | " sensitivity.append(se)\n", 477 | "\n", 478 | "print(f\"test sensitivity: {np.nanmean(sensitivity)}\")" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 12, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "# DONE" 488 | ] 489 | } 490 | ], 491 | "metadata": { 492 | "kernelspec": { 493 | "display_name": "core", 494 | "language": "python", 495 | "name": "python3" 496 | }, 497 | "language_info": { 498 | "codemirror_mode": { 499 | "name": "ipython", 500 | "version": 3 501 | }, 502 | "file_extension": ".py", 503 | "mimetype": "text/x-python", 504 | "name": "python", 505 | "nbconvert_exporter": "python", 506 | "pygments_lexer": "ipython3", 507 | "version": "3.11.7" 508 | }, 509 | "orig_nbformat": 4, 510 | "vscode": { 511 | "interpreter": { 512 | "hash": "db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea" 513 | } 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 2 518 | } 519 | -------------------------------------------------------------------------------- /experiments_medical/isic_2017/exp_2_dice_b8_a2/run_experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | 
sys.path.append("../../../") 6 | 7 | import yaml 8 | import torch 9 | import argparse 10 | import numpy as np 11 | from typing import Dict 12 | from termcolor import colored 13 | from accelerate import Accelerator 14 | from losses.losses import build_loss_fn 15 | from optimizers.optimizers import build_optimizer 16 | from optimizers.schedulers import build_scheduler 17 | from train_scripts.segmentation_trainer import Segmentation_Trainer 18 | from architectures.build_architecture import build_architecture 19 | from dataloaders.build_dataset import build_dataset, build_dataloader 20 | 21 | 22 | ################################################################################################## 23 | def launch_experiment(config_path: str) -> None: 24 | """ 25 | Build datasets, model, loss, optimizer, schedulers and accelerator from a yaml config, then train. 26 | Args: 27 | config_path (str): path to the yaml configuration file 28 | 29 | Returns: 30 | None 31 | """ 32 | # load config 33 | config = load_config(config_path) 34 | 35 | # set seed 36 | seed_everything(config) 37 | 38 | # build directories 39 | build_directories(config) 40 | 41 | # build training dataset & training data loader 42 | trainset = build_dataset( 43 | dataset_type=config["dataset_parameters"]["dataset_type"], 44 | dataset_args=config["dataset_parameters"]["train_dataset_args"], 45 | augmentation_args=config["train_augmentation_args"], 46 | ) 47 | trainloader = build_dataloader( 48 | dataset=trainset, 49 | dataloader_args=config["dataset_parameters"]["train_dataloader_args"], 50 | config=config, 51 | train=True, 52 | ) 53 | 54 | # build validation dataset & validation data loader 55 | valset = build_dataset( 56 | dataset_type=config["dataset_parameters"]["dataset_type"], 57 | dataset_args=config["dataset_parameters"]["val_dataset_args"], 58 | augmentation_args=config["test_augmentation_args"], 59 | ) 60 | valloader = build_dataloader( 61 | dataset=valset, 62 | dataloader_args=config["dataset_parameters"]["val_dataloader_args"], 63 | config=config, 64 | train=False, 65 | ) 66 | 67 | # 
build the Model 68 | model = build_architecture(config) 69 | 70 | # set up the loss function 71 | criterion = build_loss_fn( 72 | loss_type=config["loss_fn"]["loss_type"], 73 | loss_args=config["loss_fn"]["loss_args"], 74 | ) 75 | 76 | # set up the optimizer 77 | optimizer = build_optimizer( 78 | model=model, 79 | optimizer_type=config["optimizer"]["optimizer_type"], 80 | optimizer_args=config["optimizer"]["optimizer_args"], 81 | ) 82 | 83 | # set up schedulers 84 | warmup_scheduler = build_scheduler( 85 | optimizer=optimizer, scheduler_type="warmup_scheduler", config=config 86 | ) 87 | training_scheduler = build_scheduler( 88 | optimizer=optimizer, 89 | scheduler_type="training_scheduler", 90 | config=config, 91 | ) 92 | 93 | # use accelerate 94 | accelerator = Accelerator( 95 | log_with="wandb", 96 | gradient_accumulation_steps=config["training_parameters"][ 97 | "grad_accumulate_steps" 98 | ], 99 | ) 100 | accelerator.init_trackers( 101 | project_name=config["project"], 102 | config=config, 103 | init_kwargs={"wandb": config["wandb_parameters"]}, 104 | ) 105 | 106 | # display experiment info 107 | display_info(config, accelerator, trainset, valset, model) 108 | 109 | # convert all components to accelerate 110 | model = accelerator.prepare_model(model=model) 111 | optimizer = accelerator.prepare_optimizer(optimizer=optimizer) 112 | trainloader = accelerator.prepare_data_loader(data_loader=trainloader) 113 | valloader = accelerator.prepare_data_loader(data_loader=valloader) 114 | warmup_scheduler = accelerator.prepare_scheduler(scheduler=warmup_scheduler) 115 | training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler) 116 | 117 | # create a single dict to hold all parameters 118 | storage = { 119 | "model": model, 120 | "trainloader": trainloader, 121 | "valloader": valloader, 122 | "criterion": criterion, 123 | "optimizer": optimizer, 124 | "warmup_scheduler": warmup_scheduler, 125 | "training_scheduler": training_scheduler, 126 | } 127 | 
128 | # set up trainer 129 | trainer = Segmentation_Trainer( 130 | config=config, 131 | model=storage["model"], 132 | optimizer=storage["optimizer"], 133 | criterion=storage["criterion"], 134 | train_dataloader=storage["trainloader"], 135 | val_dataloader=storage["valloader"], 136 | warmup_scheduler=storage["warmup_scheduler"], 137 | training_scheduler=storage["training_scheduler"], 138 | accelerator=accelerator, 139 | ) 140 | 141 | # run train 142 | trainer.train() 143 | 144 | 145 | ################################################################################################## 146 | def seed_everything(config) -> None: 147 | seed = config["training_parameters"]["seed"] 148 | os.environ["PYTHONHASHSEED"] = str(seed) 149 | random.seed(seed) 150 | np.random.seed(seed) 151 | torch.manual_seed(seed) 152 | torch.cuda.manual_seed(seed) 153 | torch.cuda.manual_seed_all(seed) 154 | torch.backends.cudnn.deterministic = True 155 | torch.backends.cudnn.benchmark = False 156 | 157 | 158 | ################################################################################################## 159 | def load_config(config_path: str) -> Dict: 160 | """Load a yaml config file with yaml.safe_load. 161 | 162 | Args: 163 | config_path (str): path to the yaml configuration file 164 | 165 | Returns: 166 | Dict: the parsed configuration 167 | """ 168 | with open(config_path, "r", encoding="utf-8") as file: 169 | config = yaml.safe_load(file) 170 | return config 171 | 172 | 173 | ################################################################################################## 174 | def build_directories(config: Dict) -> None: 175 | # create the checkpoint directory and refuse to run if it already holds files 176 | checkpoint_dir = config["training_parameters"]["checkpoint_save_dir"] 177 | os.makedirs(checkpoint_dir, exist_ok=True) 178 | 179 | if os.listdir(checkpoint_dir): 180 | raise ValueError(f"checkpoint exists in {checkpoint_dir} -- preventing file override -- rename file") 181 | 182 | 183 | 
################################################################################################## 184 | def display_info(config, accelerator, trainset, valset, model): 185 | # print experiment info 186 | accelerator.print(f"-------------------------------------------------------") 187 | accelerator.print(f"[info]: Experiment Info") 188 | accelerator.print( 189 | f"[info] ----- Project: {colored(config['project'], color='red')}" 190 | ) 191 | accelerator.print( 192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}" 193 | ) 194 | accelerator.print( 195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}" 196 | ) 197 | accelerator.print( 198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['train_dataloader_args']['batch_size'], color='red')}" 199 | ) 200 | accelerator.print( 201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}" 202 | ) 203 | accelerator.print( 204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}" 205 | ) 206 | accelerator.print( 207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}" 208 | ) 209 | accelerator.print( 210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}" 211 | ) 212 | accelerator.print( 213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}" 214 | ) 215 | 216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 217 | accelerator.print( 218 | f"[info] ----- Distributed Training: {colored('True' if accelerator.num_processes > 1 else 'False', color='red')}" 219 | ) 220 | accelerator.print( 221 | f"[info] ----- Num Classes: {colored(config['variables']['num_classes'], color='red')}" 222 | ) 223 | accelerator.print( 224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}" 225 | ) 226 | accelerator.print( 227 | f"[info] ----- Load From Checkpoint: 
{colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}" 228 | ) 229 | accelerator.print( 230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}" 231 | ) 232 | accelerator.print(f"-------------------------------------------------------") 233 | 234 | 235 | ################################################################################################## 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser(description="Simple example of training script.") 238 | parser.add_argument( 239 | "--config", type=str, default="config.yaml", help="path to yaml config file" 240 | ) 241 | args = parser.parse_args() 242 | launch_experiment(args.config) 243 | -------------------------------------------------------------------------------- /experiments_medical/isic_2018/exp_2_dice_b8_a2/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | # Commonly Used Variables 3 | ############################################################## 4 | variables: 5 | batch_size: &batch_size 8 6 | num_channels: &num_channels 3 7 | num_classes: &num_classes 1 8 | num_epochs: &num_epochs 400 9 | 10 | ############################################################## 11 | # Wandb Model Tracking 12 | ############################################################## 13 | project: mobileunetr 14 | wandb_parameters: 15 | group: isic_2018 16 | name: exp_2_dice_b8_a 17 | mode: "online" 18 | resume: False 19 | tags: ["tr18ts18", "dice", "xxs", "adam"] 20 | 21 | ############################################################## 22 | # Model Hyper-Parameters 23 | ############################################################## 24 | model_name: mobileunetr_xxs 25 | model_parameters: 26 | encoder: None 27 | bottle_neck: 28 | dims: [96] 29 | depths: [3] 30 | expansion: 4 31 | kernel_size: 3 32 | patch_size: [2,2] 33 | channels: [80, 96, 96] 
34 | decoder: 35 | dims: [64, 80, 96] 36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320] 37 | num_classes: 1 38 | image_size: 512 39 | 40 | ############################################################## 41 | # Loss Function 42 | ############################################################## 43 | loss_fn: 44 | loss_type: "dice" 45 | loss_args: None # NOTE: yaml.safe_load reads bare None as the string "None", not null 46 | 47 | ############################################################## 48 | # Metrics 49 | ############################################################## 50 | metrics: 51 | type: "binary" 52 | mean_iou: 53 | enabled: True 54 | mean_iou_args: 55 | include_background: True 56 | reduction: "mean" 57 | get_not_nans: False 58 | ignore_empty: True 59 | dice: 60 | enabled: False 61 | dice_args: 62 | include_background: True 63 | reduction: "mean" 64 | get_not_nans: False 65 | ignore_empty: True 66 | num_classes: *num_classes 67 | 68 | ############################################################## 69 | # Optimizers 70 | ############################################################## 71 | optimizer: 72 | optimizer_type: "adamw" 73 | optimizer_args: 74 | lr: 0.00006 75 | weight_decay: 0.01 76 | 77 | ############################################################## 78 | # Learning Rate Schedulers 79 | ############################################################## 80 | warmup_scheduler: 81 | enabled: True # should be always true 82 | warmup_epochs: 30 83 | 84 | # train scheduler 85 | train_scheduler: 86 | scheduler_type: 'cosine_annealing_wr' 87 | scheduler_args: 88 | t_0_epochs: 600 89 | t_mult: 1 90 | min_lr: 0.000006 91 | 92 | ############################################################## 93 | # EMA (Exponential Moving Average) 94 | ############################################################## 95 | ema: 96 | enabled: False 97 | ema_decay: 0.999 98 | print_ema_every: 20 99 | 100 | ############################################################## 101 | # Gradient Clipping 102 | 
############################################################## 103 | clip_gradients: 104 | enabled: False 105 | clip_gradients_value: 0.1 106 | 107 | ############################################################## 108 | # Training Hyperparameters 109 | ############################################################## 110 | training_parameters: 111 | seed: 42 112 | num_epochs: 1000 # NOTE(review): not linked to variables.num_epochs (400) -- confirm which value is intended 113 | cutoff_epoch: 600 114 | load_optimizer: False 115 | print_every: 1500 116 | calculate_metrics: True 117 | grad_accumulate_steps: 1 # default: 1 118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2" 119 | load_checkpoint: # not implemented yet 120 | load_full_checkpoint: False 121 | load_model_only: False 122 | load_checkpoint_path: None # NOTE: parsed as the string "None", not null 123 | 124 | ############################################################## 125 | # Dataset and Dataloader Args 126 | ############################################################## 127 | dataset_parameters: 128 | dataset_type: "isic_albumentation_v2" 129 | train_dataset_args: 130 | data_path: "../../../data/train_2018.csv" 131 | train: True 132 | image_size: [512, 512] 133 | val_dataset_args: 134 | data_path: "../../../data/test_2018.csv" # NOTE(review): validation reuses the test split (same file as test_dataset_args) 135 | train: False 136 | image_size: [512, 512] 137 | test_dataset_args: 138 | data_path: "../../../data/test_2018.csv" 139 | train: False 140 | image_size: [512, 512] 141 | train_dataloader_args: 142 | batch_size: *batch_size 143 | shuffle: True 144 | num_workers: 4 145 | drop_last: True 146 | pin_memory: True 147 | val_dataloader_args: 148 | batch_size: *batch_size 149 | shuffle: False 150 | num_workers: 4 151 | drop_last: False 152 | pin_memory: True 153 | test_dataloader_args: 154 | batch_size: 1 155 | shuffle: False 156 | num_workers: 4 157 | drop_last: False 158 | pin_memory: True 159 | 160 | 
############################################################## 161 | # Data Augmentation Args 162 | ############################################################## 163 | train_augmentation_args: 164 | mean: 
[0.485, 0.456, 0.406] 165 | std: [0.229, 0.224, 0.225] 166 | image_size: [512, 512] # [H, W] 167 | 168 | test_augmentation_args: 169 | mean: [0.485, 0.456, 0.406] 170 | std: [0.229, 0.224, 0.225] 171 | image_size: [512, 512] # [H, W] -------------------------------------------------------------------------------- /experiments_medical/isic_2018/exp_2_dice_b8_a2/inference_model_weights/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/isic_2018/exp_2_dice_b8_a2/inference_model_weights/.gitignore -------------------------------------------------------------------------------- /experiments_medical/isic_2018/exp_2_dice_b8_a2/metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../../\")\n", 13 | "\n", 14 | "import yaml\n", 15 | "import torch\n", 16 | "import numpy as np\n", 17 | "from typing import Dict\n", 18 | "from architectures.build_architecture import build_architecture\n", 19 | "from dataloaders.build_dataset import build_dataset\n", 20 | "from typing import Tuple, Dict\n", 21 | "from fvcore.nn import FlopCountAnalysis\n", 22 | "from tqdm.notebook import tqdm\n", 23 | "from sklearn.metrics import (\n", 24 | " jaccard_score,\n", 25 | " accuracy_score,\n", 26 | " confusion_matrix,\n", 27 | ")\n", 28 | "import monai" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "Load Config File" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def load_config(config_path: str) -> Dict:\n", 45 | " \"\"\"loads the yaml config file\n", 46 | 
"\n", 47 | " Args:\n", 48 | " config_path (str): path to the yaml configuration file\n", 49 | "\n", 50 | " Returns:\n", 51 | " Dict: the parsed configuration\n", 52 | " \"\"\"\n", 53 | " with open(config_path, \"r\", encoding=\"utf-8\") as file:\n", 54 | " config = yaml.safe_load(file)\n", 55 | " return config\n", 56 | "\n", 57 | "\n", 58 | "config = load_config(\"config.yaml\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Build Dataset and DataLoaders" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "1000\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "# build validation dataset & validation data loader\n", 83 | "testset = build_dataset(\n", 84 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n", 85 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n", 86 | " augmentation_args=config[\"test_augmentation_args\"],\n", 87 | ")\n", 88 | "\n", 89 | "testloader = torch.utils.data.DataLoader(\n", 90 | " testset, batch_size=1, shuffle=False, num_workers=1\n", 91 | ")\n", 92 | "\n", 93 | "print(len(testset))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "model = build_architecture(config=config)\n", 103 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n", 104 | "model.load_state_dict(checkpoint)\n", 105 | "model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 106 | "model = model.eval()" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "Model Complexity" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Computational complexity: 1.3 GMac\n", 126 | "Number of parameters: 3.01 M \n" 127 | ] 128
| } 129 | ], 130 | "source": [ 131 | "import torchvision.models as models\n", 132 | "import torch\n", 133 | "from ptflops import get_model_complexity_info\n", 134 | "\n", 135 | "with torch.no_grad():\n", 136 | " net = model\n", 137 | " macs, params = get_model_complexity_info(\n", 138 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n", 139 | " )\n", 140 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n", 141 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stderr", 151 | "output_type": "stream", 152 | "text": [ 153 | "Unsupported operator aten::silu encountered 84 time(s)\n", 154 | "Unsupported operator aten::add encountered 45 time(s)\n", 155 | "Unsupported operator aten::div encountered 15 time(s)\n", 156 | "Unsupported operator aten::ceil encountered 6 time(s)\n", 157 | "Unsupported operator aten::mul encountered 118 time(s)\n", 158 | "Unsupported operator aten::softmax encountered 21 time(s)\n", 159 | "Unsupported operator aten::clone encountered 4 time(s)\n", 160 | "Unsupported operator aten::mul_ encountered 24 time(s)\n", 161 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n", 162 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. 
In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n", 163 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n" 164 | ] 165 | }, 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "Computational complexity: 1.24 \n", 171 | "Number of parameters: 3.01 \n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "def flop_count_analysis(\n", 177 | " model: torch.nn.Module,\n", 178 | " input_dim: Tuple,\n", 179 | ") -> Dict:\n", 180 | " \"\"\"Return trainable parameters (millions) and fvcore FLOPs (billions) for one dummy input.\n", 181 | "\n", 182 | " Args:\n", 183 | " input_dim (Tuple): per-sample shape (C, H, W[, D]); a batch dimension of 1 is prepended\n", 184 | " model (torch.nn.Module): network to analyze; the dummy input takes its dtype and device\n", 185 | " \"\"\"\n", 186 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 187 | " input_tensor = torch.ones(()).new_empty(\n", 188 | " (1, *input_dim),\n", 189 | " dtype=next(model.parameters()).dtype,\n", 190 | " device=next(model.parameters()).device,\n", 191 | " )\n", 192 | " flops = FlopCountAnalysis(model, input_tensor)\n", 193 | " model_flops = flops.total()\n", 194 | " # print(f\"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M\")\n", 195 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n", 196 | "\n", 197 | " out = {\n", 198 | " \"params\": round(trainable_params * 1e-6, 2),\n", 199 | " \"flops\": round(model_flops * 1e-9, 2),\n", 200 | " }\n", 201 | "\n", 202 | " return out\n", 203 | "\n", 204 | "\n", 205 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n", 206 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"flops\"]))\n", 207 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"params\"]))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "Calculate 
IoU Metric" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 7, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "application/vnd.jupyter.widget-view+json": { 225 | "model_id": "a50f810cae5446b687648fc4c979ab68", 226 | "version_major": 2, 227 | "version_minor": 0 228 | }, 229 | "text/plain": [ 230 | "0it [00:00, ?it/s]" 231 | ] 232 | }, 233 | "metadata": {}, 234 | "output_type": "display_data" 235 | }, 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "test iou: 0.8456581013355624\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "iou = []\n", 246 | "with torch.no_grad():\n", 247 | " for idx, data in tqdm(enumerate(testloader)):\n", 248 | " image = data[\"image\"].to(next(model.parameters()).device)\n", 249 | " mask = data[\"mask\"].to(next(model.parameters()).device)\n", 250 | " out = model.forward(image)\n", 251 | " out = torch.sigmoid(out)\n", 252 | " out[out < 0.5] = 0\n", 253 | " out[out >= 0.5] = 1\n", 254 | " mean_iou = jaccard_score(\n", 255 | " mask.detach().cpu().numpy().ravel(),\n", 256 | " out.detach().cpu().numpy().ravel(),\n", 257 | " average=\"binary\",\n", 258 | " pos_label=1,\n", 259 | " )\n", 260 | " iou.append(mean_iou.item())\n", 261 | "\n", 262 | "print(f\"test iou: {np.mean(iou)}\")" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "Accuracy" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 8, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "application/vnd.jupyter.widget-view+json": { 280 | "model_id": "f727b9ca0a604a139f3c41a4ab971605", 281 | "version_major": 2, 282 | "version_minor": 0 283 | }, 284 | "text/plain": [ 285 | "0it [00:00, ?it/s]" 286 | ] 287 | }, 288 | "metadata": {}, 289 | "output_type": "display_data" 290 | }, 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "test accuracy: 0.9440349769592286\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "accuracy = 
[]\n", 301 | "with torch.no_grad():\n", 302 | " for idx, data in tqdm(enumerate(testloader)):\n", 303 | " image = data[\"image\"].to(next(model.parameters()).device)\n", 304 | " mask = data[\"mask\"].to(next(model.parameters()).device)\n", 305 | " out = model.forward(image)\n", 306 | " out = torch.sigmoid(out)\n", 307 | " out[out < 0.5] = 0\n", 308 | " out[out >= 0.5] = 1\n", 309 | " acc = accuracy_score(\n", 310 | " mask.detach().cpu().numpy().ravel(),\n", 311 | " out.detach().cpu().numpy().ravel(),\n", 312 | " )\n", 313 | " accuracy.append(acc.item())\n", 314 | "\n", 315 | "print(f\"test accuracy: {np.mean(accuracy)}\")" 316 | ] 317 | }, 318 | { 319 | "attachments": {}, 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "Calculate Dice" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 9, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "data": { 333 | "application/vnd.jupyter.widget-view+json": { 334 | "model_id": "b957f714a8e544eb85f1132ee9d24439", 335 | "version_major": 2, 336 | "version_minor": 0 337 | }, 338 | "text/plain": [ 339 | "0it [00:00, ?it/s]" 340 | ] 341 | }, 342 | "metadata": {}, 343 | "output_type": "display_data" 344 | }, 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "test dice: 0.9073548163622618\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "dice = []\n", 355 | "with torch.no_grad():\n", 356 | " for idx, data in tqdm(enumerate(testloader)):\n", 357 | " image = data[\"image\"].to(next(model.parameters()).device)\n", 358 | " mask = data[\"mask\"].to(next(model.parameters()).device)\n", 359 | " out = model.forward(image)\n", 360 | " out = torch.sigmoid(out)\n", 361 | " out[out < 0.5] = 0\n", 362 | " out[out >= 0.5] = 1\n", 363 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n", 364 | " dice.append(mean_dice.item())\n", 365 | "\n", 366 | "print(f\"test dice: {np.mean(dice)}\")" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "Calculate Specificity" 374 | ] 375 | }, 376 | { 377 | 
"cell_type": "code", 378 | "execution_count": 10, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "application/vnd.jupyter.widget-view+json": { 384 | "model_id": "d9064584bfb248829f465cbd888c5528", 385 | "version_major": 2, 386 | "version_minor": 0 387 | }, 388 | "text/plain": [ 389 | "0it [00:00, ?it/s]" 390 | ] 391 | }, 392 | "metadata": {}, 393 | "output_type": "display_data" 394 | }, 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "test specificity: 0.9503119752056314\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "specificity = []\n", 405 | "with torch.no_grad():\n", 406 | " for idx, data in tqdm(enumerate(testloader)):\n", 407 | " image = data[\"image\"].to(next(model.parameters()).device)\n", 408 | " mask = data[\"mask\"].to(next(model.parameters()).device)\n", 409 | " out = model.forward(image)\n", 410 | " out = torch.sigmoid(out)\n", 411 | " out[out < 0.5] = 0\n", 412 | " out[out >= 0.5] = 1\n", 413 | " confusion = confusion_matrix(\n", 414 | " mask.detach().cpu().numpy().ravel(), out.detach().cpu().numpy().ravel(),\n", 415 | " labels=[0, 1],\n", 416 | " )\n", 417 | " denom = float(confusion[0, 0] + confusion[0, 1])\n", 418 | " sp = float(confusion[0, 0]) / denom if denom != 0 else float(\"nan\")\n", 419 | "\n", 420 | " specificity.append(sp)\n", 421 | "\n", 422 | "print(f\"test specificity: {np.nanmean(specificity)}\")" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "Calculate Sensitivity" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 11, 435 | "metadata": {}, 436 | "outputs": [ 437 | { 438 | "data": { 439 | "application/vnd.jupyter.widget-view+json": { 440 | "model_id": "cff60dfd365b4cdebb2c5841e6750f03", 441 | "version_major": 2, 442 | "version_minor": 0 443 | }, 444 | "text/plain": [ 445 | "0it [00:00, ?it/s]" 446 | ] 447 | }, 448 | "metadata": {}, 449 | "output_type": "display_data" 450 | }, 451 | { 452 | "name": "stdout", 453 | "output_type": "stream", 454 | "text": [ 455 | 
"test sensitivity: 0.9255471263739157\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "sensitivity = []\n", 461 | "with torch.no_grad():\n", 462 | " for idx, data in tqdm(enumerate(testloader)):\n", 463 | " image = data[\"image\"].to(next(model.parameters()).device)\n", 464 | " mask = data[\"mask\"].to(next(model.parameters()).device)\n", 465 | " out = model.forward(image)\n", 466 | " out = torch.sigmoid(out)\n", 467 | " out[out < 0.5] = 0\n", 468 | " out[out >= 0.5] = 1\n", 469 | " confusion = confusion_matrix(\n", 470 | " mask.detach().cpu().numpy().ravel(), out.detach().cpu().numpy().ravel(),\n", 471 | " labels=[0, 1],\n", 472 | " )\n", 473 | " denom = float(confusion[1, 1] + confusion[1, 0])\n", 474 | " se = float(confusion[1, 1]) / denom if denom != 0 else float(\"nan\")\n", 475 | "\n", 476 | " sensitivity.append(se)\n", 477 | "\n", 478 | "print(f\"test sensitivity: {np.nanmean(sensitivity)}\")" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 12, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "# DONE" 488 | ] 489 | } 490 | ], 491 | "metadata": { 492 | "kernelspec": { 493 | "display_name": "core", 494 | "language": "python", 495 | "name": "python3" 496 | }, 497 | "language_info": { 498 | "codemirror_mode": { 499 | "name": "ipython", 500 | "version": 3 501 | }, 502 | "file_extension": ".py", 503 | "mimetype": "text/x-python", 504 | "name": "python", 505 | "nbconvert_exporter": "python", 506 | "pygments_lexer": "ipython3", 507 | "version": "3.11.7" 508 | }, 509 | "orig_nbformat": 4, 510 | "vscode": { 511 | "interpreter": { 512 | "hash": "db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea" 513 | } 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 2 518 | } 519 | -------------------------------------------------------------------------------- /experiments_medical/isic_2018/exp_2_dice_b8_a2/run_experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | 
# Make the repository root importable so the project packages below resolve
# when this script is run from inside its experiment directory.
sys.path.append("../../../")

import yaml
import torch
import argparse
import numpy as np
from typing import Dict
from termcolor import colored
from accelerate import Accelerator
from losses.losses import build_loss_fn
from optimizers.optimizers import build_optimizer
from optimizers.schedulers import build_scheduler
from train_scripts.segmentation_trainer import Segmentation_Trainer
from architectures.build_architecture import build_architecture
from dataloaders.build_dataset import build_dataset, build_dataloader


##################################################################################################
def launch_experiment(config_path: str) -> None:
    """Build every experiment component from a YAML config and run training.

    Loads the config, seeds all RNGs, prepares an empty checkpoint directory,
    builds the train/val datasets and dataloaders, the model, loss, optimizer
    and the warmup/training schedulers, wraps them with ``accelerate`` (with a
    wandb tracker), prints a summary and finally runs
    ``Segmentation_Trainer.train()``.

    Args:
        config_path (str): path to the experiment's YAML configuration file.

    Raises:
        ValueError: if the configured checkpoint directory is not empty
            (raised by ``build_directories`` to avoid overwriting a past run).
    """
    config = load_config(config_path)
    seed_everything(config)
    build_directories(config)

    dataset_params = config["dataset_parameters"]

    # training dataset & dataloader
    trainset = build_dataset(
        dataset_type=dataset_params["dataset_type"],
        dataset_args=dataset_params["train_dataset_args"],
        augmentation_args=config["train_augmentation_args"],
    )
    trainloader = build_dataloader(
        dataset=trainset,
        dataloader_args=dataset_params["train_dataloader_args"],
        config=config,
        train=True,
    )

    # validation dataset & dataloader (built with the *test* augmentation args)
    valset = build_dataset(
        dataset_type=dataset_params["dataset_type"],
        dataset_args=dataset_params["val_dataset_args"],
        augmentation_args=config["test_augmentation_args"],
    )
    valloader = build_dataloader(
        dataset=valset,
        dataloader_args=dataset_params["val_dataloader_args"],
        config=config,
        train=False,
    )

    model = build_architecture(config)

    criterion = build_loss_fn(
        loss_type=config["loss_fn"]["loss_type"],
        loss_args=config["loss_fn"]["loss_args"],
    )

    optimizer = build_optimizer(
        model=model,
        optimizer_type=config["optimizer"]["optimizer_type"],
        optimizer_args=config["optimizer"]["optimizer_args"],
    )

    # both schedulers drive the same optimizer; they are wrapped by accelerate below
    warmup_scheduler = build_scheduler(
        optimizer=optimizer, scheduler_type="warmup_scheduler", config=config
    )
    training_scheduler = build_scheduler(
        optimizer=optimizer, scheduler_type="training_scheduler", config=config
    )

    # accelerate: wandb logging + the configured gradient-accumulation steps
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=config["training_parameters"][
            "grad_accumulate_steps"
        ],
    )
    accelerator.init_trackers(
        project_name=config["project"],
        config=config,
        init_kwargs={"wandb": config["wandb_parameters"]},
    )

    # summary is printed before the components are wrapped by accelerate
    display_info(config, accelerator, trainset, valset, model)

    model = accelerator.prepare_model(model=model)
    optimizer = accelerator.prepare_optimizer(optimizer=optimizer)
    trainloader = accelerator.prepare_data_loader(data_loader=trainloader)
    valloader = accelerator.prepare_data_loader(data_loader=valloader)
    warmup_scheduler = accelerator.prepare_scheduler(scheduler=warmup_scheduler)
    training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler)

    trainer = Segmentation_Trainer(
        config=config,
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=trainloader,
        val_dataloader=valloader,
        warmup_scheduler=warmup_scheduler,
        training_scheduler=training_scheduler,
        accelerator=accelerator,
    )
    trainer.train()


##################################################################################################
def seed_everything(config: Dict) -> None:
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning so
    runs are reproducible, possibly at some cost in speed.

    Args:
        config (Dict): experiment config; reads ``training_parameters.seed``.
    """
    seed = config["training_parameters"]["seed"]
    # NOTE: the running interpreter's hash seed is fixed at startup, so this
    # only affects child processes launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


##################################################################################################
def load_config(config_path: str) -> Dict:
    """Load the experiment's YAML config file.

    ``yaml.safe_load`` is used so the file cannot construct arbitrary Python
    objects.

    Args:
        config_path (str): path to the YAML file.

    Returns:
        Dict: the parsed configuration.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


##################################################################################################
def build_directories(config: Dict) -> None:
    """Create the checkpoint directory and refuse to reuse a non-empty one.

    Args:
        config (Dict): experiment config; reads
            ``training_parameters.checkpoint_save_dir``.

    Raises:
        ValueError: if the checkpoint directory already contains files, to
            avoid overwriting the checkpoints of a previous run.
    """
    save_dir = config["training_parameters"]["checkpoint_save_dir"]
    # exist_ok avoids the exists()/makedirs() race (FileExistsError) when the
    # script is started by several processes at once
    os.makedirs(save_dir, exist_ok=True)

    if os.listdir(save_dir):
        raise ValueError(
            f"checkpoint exists in '{save_dir}' -- preventing file override -- rename file"
        )
################################################################################################## 184 | def display_info(config, accelerator, trainset, valset, model): 185 | # Print a summary of the experiment setup read from config, the datasets and the model. NOTE(review): 'Batch Size' below reads val_dataloader_args (not train_dataloader_args), 'Test Dataset Size' is len(valset), and 'Distributed Training' only checks torch.cuda.device_count() > 1 (visible GPUs), not how accelerate was actually launched -- confirm these are intended. 186 | accelerator.print(f"-------------------------------------------------------") 187 | accelerator.print(f"[info]: Experiment Info") 188 | accelerator.print( 189 | f"[info] ----- Project: {colored(config['project'], color='red')}" 190 | ) 191 | accelerator.print( 192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}" 193 | ) 194 | accelerator.print( 195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}" 196 | ) 197 | accelerator.print( 198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['val_dataloader_args']['batch_size'], color='red')}" 199 | ) 200 | accelerator.print( 201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}" 202 | ) 203 | accelerator.print( 204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}" 205 | ) 206 | accelerator.print( 207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}" 208 | ) 209 | accelerator.print( 210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}" 211 | ) 212 | accelerator.print( 213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}" 214 | ) 215 | # count trainable parameters only (requires_grad=True) 216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 217 | accelerator.print( 218 | f"[info] ----- Distributed Training: {colored('True' if torch.cuda.device_count() > 1 else 'False', color='red')}" 219 | ) 220 | accelerator.print( 221 | f"[info] ----- Num Clases: {colored(config['variables']['num_classes'], color='red')}" 222 | ) 223 | accelerator.print( 224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}" 225 | ) 226 | accelerator.print( 227 | f"[info] ----- Load From Checkpoint:
{colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}" 228 | ) 229 | accelerator.print( 230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}" 231 | ) 232 | accelerator.print(f"-------------------------------------------------------") 233 | 234 | 235 | ################################################################################################## 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser(description="Simple example of training script.") 238 | parser.add_argument( 239 | "--config", type=str, default="config.yaml", help="path to yaml config file" 240 | ) 241 | args = parser.parse_args() 242 | launch_experiment(args.config) 243 | -------------------------------------------------------------------------------- /experiments_medical/ph2/exp_2_dice_b8_a2/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | # Commonly Used Variables 3 | ############################################################## 4 | variables: 5 | batch_size: &batch_size 8 6 | num_channels: &num_channels 3 7 | num_classes: &num_classes 1 8 | num_epochs: &num_epochs 400 9 | 10 | ############################################################## 11 | # Wandb Model Tracking 12 | ############################################################## 13 | project: mobileunetr 14 | wandb_parameters: 15 | group: ph2 16 | name: exp_2_dice_b8_a 17 | mode: "online" 18 | resume: False 19 | tags: ["trph2tsph2", "dice", "xxs", "adam"] 20 | 21 | ############################################################## 22 | # Model Hyper-Parameters 23 | ############################################################## 24 | model_name: mobileunetr_xxs 25 | model_parameters: 26 | encoder: None 27 | bottle_neck: 28 | dims: [96] 29 | depths: [3] 30 | expansion: 4 31 | kernel_size: 3 32 | patch_size: [2,2] 33 | channels: [80, 96, 96] 34 | 
decoder: 35 | dims: [64, 80, 96] 36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320] 37 | num_classes: 1 38 | image_size: 512 39 | 40 | ############################################################## 41 | # Loss Function 42 | ############################################################## 43 | loss_fn: 44 | loss_type: "dice" 45 | loss_args: None 46 | 47 | ############################################################## 48 | # Metrics 49 | ############################################################## 50 | metrics: 51 | type: "binary" 52 | mean_iou: 53 | enabled: True 54 | mean_iou_args: 55 | include_background: True 56 | reduction: "mean" 57 | get_not_nans: False 58 | ignore_empty: True 59 | dice: 60 | enabled: False 61 | dice_args: 62 | include_background: True 63 | reduction: "mean" 64 | get_not_nans: False 65 | ignore_empty: True 66 | num_classes: *num_classes 67 | 68 | ############################################################## 69 | # Optimizers 70 | ############################################################## 71 | optimizer: 72 | optimizer_type: "adamw" 73 | optimizer_args: 74 | lr: 0.00006 75 | weight_decay: 0.01 76 | 77 | ############################################################## 78 | # Learning Rate Schedulers 79 | ############################################################## 80 | warmup_scheduler: 81 | enabled: True # should be always true 82 | warmup_epochs: 30 83 | 84 | # train scheduler 85 | train_scheduler: 86 | scheduler_type: 'cosine_annealing_wr' 87 | scheduler_args: 88 | t_0_epochs: 600 89 | t_mult: 1 90 | min_lr: 0.000006 91 | 92 | ############################################################## 93 | # EMA (Exponential Moving Average) 94 | ############################################################## 95 | ema: 96 | enabled: False 97 | ema_decay: 0.999 98 | print_ema_every: 20 99 | 100 | ############################################################## 101 | # Gradient Clipping 102 | 
############################################################## 103 | clip_gradients: 104 | enabled: False 105 | clip_gradients_value: 0.1 106 | 107 | ############################################################## 108 | # Training Hyperparameters 109 | ############################################################## 110 | training_parameters: 111 | seed: 42 112 | num_epochs: 1000 113 | cutoff_epoch: 600 114 | load_optimizer: False 115 | print_every: 1500 116 | calculate_metrics: True 117 | grad_accumulate_steps: 1 # default: 1 118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2" 119 | load_checkpoint: # not implemented yet 120 | load_full_checkpoint: False 121 | load_model_only: False 122 | load_checkpoint_path: None 123 | 124 | ############################################################## 125 | # Dataset and Dataloader Args 126 | ############################################################## 127 | dataset_parameters: 128 | dataset_type: "isic_albumentation_v2" 129 | train_dataset_args: 130 | data_path: "../../../data/train_ph2.csv" 131 | train: True 132 | image_size: [512, 512] 133 | val_dataset_args: 134 | data_path: "../../../data/test_ph2.csv" 135 | train: False 136 | image_size: [512, 512] 137 | test_dataset_args: 138 | data_path: "../../../data/test_ph2.csv" 139 | train: False 140 | image_size: [512, 512] 141 | train_dataloader_args: 142 | batch_size: *batch_size 143 | shuffle: True 144 | num_workers: 4 145 | drop_last: True 146 | pin_memory: True 147 | val_dataloader_args: 148 | batch_size: *batch_size 149 | shuffle: False 150 | num_workers: 4 151 | drop_last: False 152 | pin_memory: True 153 | test_dataloader_args: 154 | batch_size: 1 155 | shuffle: False 156 | num_workers: 4 157 | drop_last: False 158 | pin_memory: True 159 | 160 | ############################################################## 161 | # Data Augmentation Args 162 | ############################################################## 163 | train_augmentation_args: 164 | mean: [0.485, 
0.456, 0.406] 165 | std: [0.229, 0.224, 0.225] 166 | image_size: [512, 512] # [H, W] 167 | 168 | test_augmentation_args: 169 | mean: [0.485, 0.456, 0.406] 170 | std: [0.229, 0.224, 0.225] 171 | image_size: [512, 512] # [H, W] -------------------------------------------------------------------------------- /experiments_medical/ph2/exp_2_dice_b8_a2/inference_model_weights/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/ph2/exp_2_dice_b8_a2/inference_model_weights/.gitignore -------------------------------------------------------------------------------- /experiments_medical/ph2/exp_2_dice_b8_a2/metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../../\")\n", 13 | "\n", 14 | "import yaml\n", 15 | "import torch\n", 16 | "import numpy as np\n", 17 | "from typing import Dict\n", 18 | "from architectures.build_architecture import build_architecture\n", 19 | "from dataloaders.build_dataset import build_dataset\n", 20 | "from typing import Tuple, Dict\n", 21 | "from fvcore.nn import FlopCountAnalysis\n", 22 | "from tqdm.notebook import tqdm\n", 23 | "from sklearn.metrics import (\n", 24 | " jaccard_score,\n", 25 | " accuracy_score,\n", 26 | " confusion_matrix,\n", 27 | ")\n", 28 | "import monai" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "Load Config File" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def load_config(config_path: str) -> Dict:\n", 45 | " \"\"\"loads the yaml config file\n", 46 | "\n", 47 | " Args:\n", 48 
| " config_path (str): _description_\n", 49 | "\n", 50 | " Returns:\n", 51 | " Dict: _description_\n", 52 | " \"\"\"\n", 53 | " with open(config_path, \"r\") as file:\n", 54 | " config = yaml.safe_load(file)\n", 55 | " return config\n", 56 | "\n", 57 | "\n", 58 | "config = load_config(\"config.yaml\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Build Dataset and DataLoaders" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "40\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "# build validation dataset & validataion data loader\n", 83 | "testset = build_dataset(\n", 84 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n", 85 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n", 86 | " augmentation_args=config[\"test_augmentation_args\"],\n", 87 | ")\n", 88 | "\n", 89 | "testloader = torch.utils.data.DataLoader(\n", 90 | " testset, batch_size=1, shuffle=False, num_workers=1\n", 91 | ")\n", 92 | "\n", 93 | "print(len(testset))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "model = build_architecture(config=config)\n", 103 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n", 104 | "model.load_state_dict(checkpoint)\n", 105 | "model = model.to(\"cpu\")\n", 106 | "model = model.eval()" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "Model Complexity" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Computational complexity: 1.3 GMac\n", 126 | "Number of parameters: 3.01 M \n" 127 | ] 128 | } 129 | ], 130 | 
"source": [ 131 | "import torchvision.models as models\n", 132 | "import torch\n", 133 | "from ptflops import get_model_complexity_info\n", 134 | "\n", 135 | "with torch.cuda.device(0):\n", 136 | " net = model\n", 137 | " macs, params = get_model_complexity_info(\n", 138 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n", 139 | " )\n", 140 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n", 141 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stderr", 151 | "output_type": "stream", 152 | "text": [ 153 | "Unsupported operator aten::silu encountered 84 time(s)\n", 154 | "Unsupported operator aten::add encountered 45 time(s)\n", 155 | "Unsupported operator aten::div encountered 15 time(s)\n", 156 | "Unsupported operator aten::ceil encountered 6 time(s)\n", 157 | "Unsupported operator aten::mul encountered 118 time(s)\n", 158 | "Unsupported operator aten::softmax encountered 21 time(s)\n", 159 | "Unsupported operator aten::clone encountered 4 time(s)\n", 160 | "Unsupported operator aten::mul_ encountered 24 time(s)\n", 161 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n", 162 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. 
In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n", 163 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n" 164 | ] 165 | }, 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "Computational complexity: 3.01 \n", 171 | "Number of parameters: 1.24 \n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "def flop_count_analysis(\n", 177 | " model: torch.nn.Module,\n", 178 | " input_dim: Tuple,\n", 179 | ") -> Dict:\n", 180 | " \"\"\"_summary_\n", 181 | "\n", 182 | " Args:\n", 183 | " input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))\n", 184 | " model (torch.nn.Module): _description_\n", 185 | " \"\"\"\n", 186 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 187 | " input_tensor = torch.ones(()).new_empty(\n", 188 | " (1, *input_dim),\n", 189 | " dtype=next(model.parameters()).dtype,\n", 190 | " device=next(model.parameters()).device,\n", 191 | " )\n", 192 | " flops = FlopCountAnalysis(model, input_tensor)\n", 193 | " model_flops = flops.total()\n", 194 | " # print(f\"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M\")\n", 195 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n", 196 | "\n", 197 | " out = {\n", 198 | " \"params\": round(trainable_params * 1e-6, 2),\n", 199 | " \"flops\": round(model_flops * 1e-9, 2),\n", 200 | " }\n", 201 | "\n", 202 | " return out\n", 203 | "\n", 204 | "\n", 205 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n", 206 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"params\"]))\n", 207 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"flops\"]))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "Calculate 
IoU Metric" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 7, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "application/vnd.jupyter.widget-view+json": { 225 | "model_id": "8697bc384ea947628f121854eb4d091c", 226 | "version_major": 2, 227 | "version_minor": 0 228 | }, 229 | "text/plain": [ 230 | "0it [00:00, ?it/s]" 231 | ] 232 | }, 233 | "metadata": {}, 234 | "output_type": "display_data" 235 | }, 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "test iou: 0.9229712867991552\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "iou = []\n", 246 | "with torch.no_grad():\n", 247 | " for idx, data in tqdm(enumerate(testloader)):\n", 248 | " image = data[\"image\"].cuda()\n", 249 | " mask = data[\"mask\"].cuda()\n", 250 | " out = model.forward(image)\n", 251 | " out = torch.sigmoid(out)\n", 252 | " out[out < 0.5] = 0\n", 253 | " out[out >= 0.5] = 1\n", 254 | " mean_iou = jaccard_score(\n", 255 | " mask.detach().cpu().numpy().ravel(),\n", 256 | " out.detach().cpu().numpy().ravel(),\n", 257 | " average=\"binary\",\n", 258 | " pos_label=1,\n", 259 | " )\n", 260 | " iou.append(mean_iou.item())\n", 261 | "\n", 262 | "print(f\"test iou: {np.mean(iou)}\")" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "Accuracy" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 8, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "application/vnd.jupyter.widget-view+json": { 280 | "model_id": "a451270dd9f04d849a2c32923785143c", 281 | "version_major": 2, 282 | "version_minor": 0 283 | }, 284 | "text/plain": [ 285 | "0it [00:00, ?it/s]" 286 | ] 287 | }, 288 | "metadata": {}, 289 | "output_type": "display_data" 290 | }, 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "test accuracy: 0.9771324157714844\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "accuracy = 
[]\n", 301 | "with torch.no_grad():\n", 302 | " for idx, data in tqdm(enumerate(testloader)):\n", 303 | " image = data[\"image\"].cuda()\n", 304 | " mask = data[\"mask\"].cuda()\n", 305 | " out = model.forward(image)\n", 306 | " out = torch.sigmoid(out)\n", 307 | " out[out < 0.5] = 0\n", 308 | " out[out >= 0.5] = 1\n", 309 | " acc = accuracy_score(\n", 310 | " mask.detach().cpu().numpy().ravel(),\n", 311 | " out.detach().cpu().numpy().ravel(),\n", 312 | " )\n", 313 | " accuracy.append(acc.item())\n", 314 | "\n", 315 | "print(f\"test accuracy: {np.mean(accuracy)}\")" 316 | ] 317 | }, 318 | { 319 | "attachments": {}, 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "Calculate Dice" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 9, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "data": { 333 | "application/vnd.jupyter.widget-view+json": { 334 | "model_id": "0c465b3238b842bb8ee3f6b37a1b21b9", 335 | "version_major": 2, 336 | "version_minor": 0 337 | }, 338 | "text/plain": [ 339 | "0it [00:00, ?it/s]" 340 | ] 341 | }, 342 | "metadata": {}, 343 | "output_type": "display_data" 344 | }, 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "test dice: 0.9569526433944702\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "dice = []\n", 355 | "with torch.no_grad():\n", 356 | " for idx, data in tqdm(enumerate(testloader)):\n", 357 | " image = data[\"image\"].cuda()\n", 358 | " mask = data[\"mask\"].cuda()\n", 359 | " out = model.forward(image)\n", 360 | " out = torch.sigmoid(out)\n", 361 | " out[out < 0.5] = 0\n", 362 | " out[out >= 0.5] = 1\n", 363 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n", 364 | " dice.append(mean_dice.item())\n", 365 | "\n", 366 | "print(f\"test dice: {np.mean(dice)}\")" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "Calculate Specificity" 374 | ] 375 | }, 376 | { 377 | 
"cell_type": "code", 378 | "execution_count": 10, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "application/vnd.jupyter.widget-view+json": { 384 | "model_id": "924fdb4171ca424297fcf300084bb352", 385 | "version_major": 2, 386 | "version_minor": 0 387 | }, 388 | "text/plain": [ 389 | "0it [00:00, ?it/s]" 390 | ] 391 | }, 392 | "metadata": {}, 393 | "output_type": "display_data" 394 | }, 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "test specificity: 0.9660395899750096\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "specificity = []\n", 405 | "with torch.no_grad():\n", 406 | " for idx, data in tqdm(enumerate(testloader)):\n", 407 | " image = data[\"image\"].cuda()\n", 408 | " mask = data[\"mask\"].cuda()\n", 409 | " out = model.forward(image)\n", 410 | " out = torch.sigmoid(out)\n", 411 | " out[out < 0.5] = 0\n", 412 | " out[out >= 0.5] = 1\n", 413 | " confusion = confusion_matrix(\n", 414 | " mask.detach().cpu().numpy().ravel(),\n", 415 | " out.detach().cpu().numpy().ravel(),\n", 416 | " )\n", 417 | " if float(confusion[0, 0] + confusion[0, 1]) != 0:\n", 418 | " sp = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])\n", 419 | "\n", 420 | " specificity.append(sp)\n", 421 | "\n", 422 | "print(f\"test specificity: {np.mean(specificity)}\")" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "Calculate Sensitivity" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 11, 435 | "metadata": {}, 436 | "outputs": [ 437 | { 438 | "data": { 439 | "application/vnd.jupyter.widget-view+json": { 440 | "model_id": "7502c6211be24f269ba862fe067fe054", 441 | "version_major": 2, 442 | "version_minor": 0 443 | }, 444 | "text/plain": [ 445 | "0it [00:00, ?it/s]" 446 | ] 447 | }, 448 | "metadata": {}, 449 | "output_type": "display_data" 450 | }, 451 | { 452 | "name": "stdout", 453 | "output_type": "stream", 454 | "text": [ 455 | 
"test sensitivity: 0.9604657601219436\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "sensitivity = []\n", 461 | "with torch.no_grad():\n", 462 | " for idx, data in tqdm(enumerate(testloader)):\n", 463 | " image = data[\"image\"].cuda()\n", 464 | " mask = data[\"mask\"].cuda()\n", 465 | " out = model.forward(image)\n", 466 | " out = torch.sigmoid(out)\n", 467 | " out[out < 0.5] = 0\n", 468 | " out[out >= 0.5] = 1\n", 469 | " confusion = confusion_matrix(\n", 470 | " mask.detach().cpu().numpy().ravel(),\n", 471 | " out.detach().cpu().numpy().ravel(),\n", 472 | " )\n", 473 | " if float(confusion[1, 1] + confusion[1, 0]) != 0:\n", 474 | " se = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])\n", 475 | "\n", 476 | " sensitivity.append(se)\n", 477 | "\n", 478 | "print(f\"test sensitivity: {np.mean(sensitivity)}\")" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 12, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "# DONE" 488 | ] 489 | } 490 | ], 491 | "metadata": { 492 | "kernelspec": { 493 | "display_name": "core", 494 | "language": "python", 495 | "name": "python3" 496 | }, 497 | "language_info": { 498 | "codemirror_mode": { 499 | "name": "ipython", 500 | "version": 3 501 | }, 502 | "file_extension": ".py", 503 | "mimetype": "text/x-python", 504 | "name": "python", 505 | "nbconvert_exporter": "python", 506 | "pygments_lexer": "ipython3", 507 | "version": "3.11.7" 508 | }, 509 | "orig_nbformat": 4, 510 | "vscode": { 511 | "interpreter": { 512 | "hash": "db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea" 513 | } 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 2 518 | } 519 | -------------------------------------------------------------------------------- /experiments_medical/ph2/exp_2_dice_b8_a2/run_experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | 
# Make the repository root importable so the project packages below resolve
# when this script is run from inside its experiment directory.
sys.path.append("../../../")

import yaml
import torch
import argparse
import numpy as np
from typing import Dict
from termcolor import colored
from accelerate import Accelerator
from losses.losses import build_loss_fn
from optimizers.optimizers import build_optimizer
from optimizers.schedulers import build_scheduler
from train_scripts.segmentation_trainer import Segmentation_Trainer
from architectures.build_architecture import build_architecture
from dataloaders.build_dataset import build_dataset, build_dataloader


##################################################################################################
def launch_experiment(config_path: str) -> None:
    """Build every experiment component from a YAML config and run training.

    Loads the config, seeds all RNGs, prepares an empty checkpoint directory,
    builds the train/val datasets and dataloaders, the model, loss, optimizer
    and the warmup/training schedulers, wraps them with ``accelerate`` (with a
    wandb tracker), prints a summary and finally runs
    ``Segmentation_Trainer.train()``.

    Args:
        config_path (str): path to the experiment's YAML configuration file.

    Raises:
        ValueError: if the configured checkpoint directory is not empty
            (raised by ``build_directories`` to avoid overwriting a past run).
    """
    config = load_config(config_path)
    seed_everything(config)
    build_directories(config)

    dataset_params = config["dataset_parameters"]

    # training dataset & dataloader
    trainset = build_dataset(
        dataset_type=dataset_params["dataset_type"],
        dataset_args=dataset_params["train_dataset_args"],
        augmentation_args=config["train_augmentation_args"],
    )
    trainloader = build_dataloader(
        dataset=trainset,
        dataloader_args=dataset_params["train_dataloader_args"],
        config=config,
        train=True,
    )

    # validation dataset & dataloader (built with the *test* augmentation args)
    valset = build_dataset(
        dataset_type=dataset_params["dataset_type"],
        dataset_args=dataset_params["val_dataset_args"],
        augmentation_args=config["test_augmentation_args"],
    )
    valloader = build_dataloader(
        dataset=valset,
        dataloader_args=dataset_params["val_dataloader_args"],
        config=config,
        train=False,
    )

    model = build_architecture(config)

    criterion = build_loss_fn(
        loss_type=config["loss_fn"]["loss_type"],
        loss_args=config["loss_fn"]["loss_args"],
    )

    optimizer = build_optimizer(
        model=model,
        optimizer_type=config["optimizer"]["optimizer_type"],
        optimizer_args=config["optimizer"]["optimizer_args"],
    )

    # both schedulers drive the same optimizer; they are wrapped by accelerate below
    warmup_scheduler = build_scheduler(
        optimizer=optimizer, scheduler_type="warmup_scheduler", config=config
    )
    training_scheduler = build_scheduler(
        optimizer=optimizer, scheduler_type="training_scheduler", config=config
    )

    # accelerate: wandb logging + the configured gradient-accumulation steps
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=config["training_parameters"][
            "grad_accumulate_steps"
        ],
    )
    accelerator.init_trackers(
        project_name=config["project"],
        config=config,
        init_kwargs={"wandb": config["wandb_parameters"]},
    )

    # summary is printed before the components are wrapped by accelerate
    display_info(config, accelerator, trainset, valset, model)

    model = accelerator.prepare_model(model=model)
    optimizer = accelerator.prepare_optimizer(optimizer=optimizer)
    trainloader = accelerator.prepare_data_loader(data_loader=trainloader)
    valloader = accelerator.prepare_data_loader(data_loader=valloader)
    warmup_scheduler = accelerator.prepare_scheduler(scheduler=warmup_scheduler)
    training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler)

    trainer = Segmentation_Trainer(
        config=config,
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=trainloader,
        val_dataloader=valloader,
        warmup_scheduler=warmup_scheduler,
        training_scheduler=training_scheduler,
        accelerator=accelerator,
    )
    trainer.train()


##################################################################################################
def seed_everything(config: Dict) -> None:
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning so
    runs are reproducible, possibly at some cost in speed.

    Args:
        config (Dict): experiment config; reads ``training_parameters.seed``.
    """
    seed = config["training_parameters"]["seed"]
    # NOTE: the running interpreter's hash seed is fixed at startup, so this
    # only affects child processes launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


##################################################################################################
def load_config(config_path: str) -> Dict:
    """Load the experiment's YAML config file.

    ``yaml.safe_load`` is used so the file cannot construct arbitrary Python
    objects.

    Args:
        config_path (str): path to the YAML file.

    Returns:
        Dict: the parsed configuration.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


##################################################################################################
def build_directories(config: Dict) -> None:
    """Create the checkpoint directory and refuse to reuse a non-empty one.

    Args:
        config (Dict): experiment config; reads
            ``training_parameters.checkpoint_save_dir``.

    Raises:
        ValueError: if the checkpoint directory already contains files, to
            avoid overwriting the checkpoints of a previous run.
    """
    save_dir = config["training_parameters"]["checkpoint_save_dir"]
    # exist_ok avoids the exists()/makedirs() race (FileExistsError) when the
    # script is started by several processes at once
    os.makedirs(save_dir, exist_ok=True)

    if os.listdir(save_dir):
        raise ValueError(
            f"checkpoint exists in '{save_dir}' -- preventing file override -- rename file"
        )
################################################################################################## 184 | def display_info(config, accelerator, trainset, valset, model): 185 | # print experiment info 186 | accelerator.print(f"-------------------------------------------------------") 187 | accelerator.print(f"[info]: Experiment Info") 188 | accelerator.print( 189 | f"[info] ----- Project: {colored(config['project'], color='red')}" 190 | ) 191 | accelerator.print( 192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}" 193 | ) 194 | accelerator.print( 195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}" 196 | ) 197 | accelerator.print( 198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['val_dataloader_args']['batch_size'], color='red')}" 199 | ) 200 | accelerator.print( 201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}" 202 | ) 203 | accelerator.print( 204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}" 205 | ) 206 | accelerator.print( 207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}" 208 | ) 209 | accelerator.print( 210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}" 211 | ) 212 | accelerator.print( 213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}" 214 | ) 215 | 216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 217 | accelerator.print( 218 | f"[info] ----- Distributed Training: {colored('True' if torch.cuda.device_count() > 1 else 'False', color='red')}" 219 | ) 220 | accelerator.print( 221 | f"[info] ----- Num Clases: {colored(config['variables']['num_classes'], color='red')}" 222 | ) 223 | accelerator.print( 224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}" 225 | ) 226 | accelerator.print( 227 | f"[info] ----- Load From Checkpoint: 
{colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}" 228 | ) 229 | accelerator.print( 230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}" 231 | ) 232 | accelerator.print(f"-------------------------------------------------------") 233 | 234 | 235 | ################################################################################################## 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser(description="Simple example of training script.") 238 | parser.add_argument( 239 | "--config", type=str, default="config.yaml", help="path to yaml config file" 240 | ) 241 | args = parser.parse_args() 242 | launch_experiment(args.config) 243 | -------------------------------------------------------------------------------- /experiments_medical/ph2/exp_2_dice_b8_a2/visualize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../../\")\n", 13 | "\n", 14 | "import yaml\n", 15 | "import torch\n", 16 | "from typing import Dict\n", 17 | "from architectures.build_architecture import build_architecture\n", 18 | "from dataloaders.build_dataset import build_dataset\n", 19 | "from torchvision.utils import save_image" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "def load_config(config_path: str) -> Dict:\n", 29 | " \"\"\"loads the yaml config file\n", 30 | "\n", 31 | " Args:\n", 32 | " config_path (str): _description_\n", 33 | "\n", 34 | " Returns:\n", 35 | " Dict: _description_\n", 36 | " \"\"\"\n", 37 | " with open(config_path, \"r\") as file:\n", 38 | " config = yaml.safe_load(file)\n", 39 | " return config" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | 
"execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "config = load_config(\"config.yaml\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "Set Up" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "model = build_architecture(config=config)\n", 65 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n", 66 | "model.load_state_dict(checkpoint)\n", 67 | "model = model.to(\"cpu\")\n", 68 | "model = model.eval()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# build validation dataset & validataion data loader\n", 78 | "valset = build_dataset(\n", 79 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n", 80 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n", 81 | " augmentation_args=config[\"test_augmentation_args\"],\n", 82 | ")\n", 83 | "testloader = torch.utils.data.DataLoader(\n", 84 | " valset, batch_size=1, shuffle=False, num_workers=1\n", 85 | ")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "Inference" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "if not os.path.isdir(\"predictions\"):\n", 102 | " os.makedirs(\"predictions/rgb\")\n", 103 | " os.makedirs(\"predictions/gt\")\n", 104 | " os.makedirs(\"predictions/pred\")" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 7, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "torch.Size([1, 3, 512, 512])\n", 117 | "torch.Size([1, 3, 512, 512])\n", 118 | "torch.Size([1, 3, 512, 512])\n", 119 | "torch.Size([1, 3, 512, 512])\n", 120 | "torch.Size([1, 3, 512, 
512])\n", 121 | "torch.Size([1, 3, 512, 512])\n", 122 | "torch.Size([1, 3, 512, 512])\n", 123 | "torch.Size([1, 3, 512, 512])\n", 124 | "torch.Size([1, 3, 512, 512])\n", 125 | "torch.Size([1, 3, 512, 512])\n", 126 | "torch.Size([1, 3, 512, 512])\n", 127 | "torch.Size([1, 3, 512, 512])\n", 128 | "torch.Size([1, 3, 512, 512])\n", 129 | "torch.Size([1, 3, 512, 512])\n", 130 | "torch.Size([1, 3, 512, 512])\n", 131 | "torch.Size([1, 3, 512, 512])\n", 132 | "torch.Size([1, 3, 512, 512])\n", 133 | "torch.Size([1, 3, 512, 512])\n", 134 | "torch.Size([1, 3, 512, 512])\n", 135 | "torch.Size([1, 3, 512, 512])\n", 136 | "torch.Size([1, 3, 512, 512])\n", 137 | "torch.Size([1, 3, 512, 512])\n", 138 | "torch.Size([1, 3, 512, 512])\n", 139 | "torch.Size([1, 3, 512, 512])\n", 140 | "torch.Size([1, 3, 512, 512])\n", 141 | "torch.Size([1, 3, 512, 512])\n", 142 | "torch.Size([1, 3, 512, 512])\n", 143 | "torch.Size([1, 3, 512, 512])\n", 144 | "torch.Size([1, 3, 512, 512])\n", 145 | "torch.Size([1, 3, 512, 512])\n", 146 | "torch.Size([1, 3, 512, 512])\n", 147 | "torch.Size([1, 3, 512, 512])\n", 148 | "torch.Size([1, 3, 512, 512])\n", 149 | "torch.Size([1, 3, 512, 512])\n", 150 | "torch.Size([1, 3, 512, 512])\n", 151 | "torch.Size([1, 3, 512, 512])\n", 152 | "torch.Size([1, 3, 512, 512])\n", 153 | "torch.Size([1, 3, 512, 512])\n", 154 | "torch.Size([1, 3, 512, 512])\n", 155 | "torch.Size([1, 3, 512, 512])\n", 156 | "torch.Size([1, 3, 512, 512])\n", 157 | "torch.Size([1, 3, 512, 512])\n", 158 | "torch.Size([1, 3, 512, 512])\n", 159 | "torch.Size([1, 3, 512, 512])\n", 160 | "torch.Size([1, 3, 512, 512])\n", 161 | "torch.Size([1, 3, 512, 512])\n", 162 | "torch.Size([1, 3, 512, 512])\n", 163 | "torch.Size([1, 3, 512, 512])\n", 164 | "torch.Size([1, 3, 512, 512])\n", 165 | "torch.Size([1, 3, 512, 512])\n", 166 | "torch.Size([1, 3, 512, 512])\n", 167 | "torch.Size([1, 3, 512, 512])\n", 168 | "torch.Size([1, 3, 512, 512])\n", 169 | "torch.Size([1, 3, 512, 512])\n", 170 | "torch.Size([1, 3, 512, 
512])\n", 171 | "torch.Size([1, 3, 512, 512])\n", 172 | "torch.Size([1, 3, 512, 512])\n", 173 | "torch.Size([1, 3, 512, 512])\n", 174 | "torch.Size([1, 3, 512, 512])\n", 175 | "torch.Size([1, 3, 512, 512])\n", 176 | "torch.Size([1, 3, 512, 512])\n", 177 | "torch.Size([1, 3, 512, 512])\n", 178 | "torch.Size([1, 3, 512, 512])\n", 179 | "torch.Size([1, 3, 512, 512])\n", 180 | "torch.Size([1, 3, 512, 512])\n", 181 | "torch.Size([1, 3, 512, 512])\n", 182 | "torch.Size([1, 3, 512, 512])\n", 183 | "torch.Size([1, 3, 512, 512])\n", 184 | "torch.Size([1, 3, 512, 512])\n", 185 | "torch.Size([1, 3, 512, 512])\n", 186 | "torch.Size([1, 3, 512, 512])\n", 187 | "torch.Size([1, 3, 512, 512])\n", 188 | "torch.Size([1, 3, 512, 512])\n", 189 | "torch.Size([1, 3, 512, 512])\n", 190 | "torch.Size([1, 3, 512, 512])\n", 191 | "torch.Size([1, 3, 512, 512])\n", 192 | "torch.Size([1, 3, 512, 512])\n", 193 | "torch.Size([1, 3, 512, 512])\n", 194 | "torch.Size([1, 3, 512, 512])\n", 195 | "torch.Size([1, 3, 512, 512])\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "model.eval()\n", 201 | "counter = 4000\n", 202 | "for idx, data in enumerate(testloader):\n", 203 | " image = data[\"image\"].cuda()\n", 204 | " mask = data[\"mask\"].cuda()\n", 205 | " out = model.forward(image)\n", 206 | " out = torch.sigmoid(out)\n", 207 | " out[out < 0.5] = 0\n", 208 | " out[out >= 0.5] = 1\n", 209 | "\n", 210 | " # rgb input\n", 211 | " img = data[\"image\"]\n", 212 | " img = img.detach()\n", 213 | " print(img.shape)\n", 214 | " img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n", 215 | " img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n", 216 | " img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n", 217 | " save_image(img, f\"predictions/rgb/image_{idx}.png\")\n", 218 | "\n", 219 | " # prediction\n", 220 | " pred = out.detach()\n", 221 | " pred = pred * 255.0\n", 222 | " save_image(pred, f\"predictions/pred/pred_{idx}.png\")\n", 223 | "\n", 224 | " # ground truth\n", 225 | " gt = 
data[\"mask\"]\n", 226 | " gt = gt.detach()\n", 227 | " gt = gt * 255.0\n", 228 | " save_image(gt, f\"predictions/gt/gt_{idx}.png\")\n", 229 | "\n", 230 | " if idx == counter:\n", 231 | " break" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "Save RGB, GT and Pred" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 14, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "# # RGB INPUT\n", 248 | "\n", 249 | "# img = data[\"image\"]\n", 250 | "# img = img.detach()\n", 251 | "# print(img.shape)\n", 252 | "# img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n", 253 | "# img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n", 254 | "# img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n", 255 | "# save_image(img, f\"predictions/rgb/image_{idx}.png\")" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 15, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "# # PREDICTION\n", 265 | "\n", 266 | "# pred = out.detach()\n", 267 | "# pred = pred * 255.0\n", 268 | "# save_image(pred, f\"predictions/pred/pred_{idx}.png\")" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 16, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "# # GROUND TRUTH\n", 278 | "\n", 279 | "# gt = data[1]\n", 280 | "# gt = gt.detach()\n", 281 | "# gt = gt * 255.0\n", 282 | "# save_image(gt, f\"predictions/gt/gt_{idx}.png\")" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "corev2", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | 
"nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.11.7" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 2 314 | } 315 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/losses/__init__.py -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import monai 3 | import torch.nn as nn 4 | from typing import Dict 5 | import torch.nn.functional as F 6 | 7 | 8 | ##################################################################### 9 | class CrossEntropyLoss(nn.Module): 10 | def __init__(self, loss_args=None): 11 | super().__init__() 12 | self._loss = nn.CrossEntropyLoss(reduction="mean") 13 | 14 | def __call__(self, predictions, targets): 15 | loss = self._loss(predictions, targets) 16 | return loss 17 | 18 | 19 | ##################################################################### 20 | class BinaryCrossEntropyWithLogits: 21 | def __init__(self, loss_args=None): 22 | super().__init__() 23 | self._loss = nn.BCEWithLogitsLoss(reduction="mean") 24 | 25 | def __call__(self, predictions, targets): 26 | loss = self._loss(predictions, targets) 27 | return loss 28 | 29 | 30 | ##################################################################### 31 | class MSELoss: 32 | def __init__(self, loss_args=None): 33 | super().__init__() 34 | self._loss = nn.MSELoss() 35 | 36 | def __call__(self, predicted, target): 37 | loss = self._loss(predicted, target) 38 | return loss 39 | 40 | 41 | ##################################################################### 42 | class L1Loss: 43 | def __init__(self, loss_args=None): 44 | super().__init__() 45 | self._loss = 
nn.L1Loss() 46 | 47 | def __call__(self, predicted, target): 48 | loss = self._loss(predicted, target) 49 | return loss 50 | 51 | 52 | ##################################################################### 53 | class DistillationLoss(nn.Module): 54 | # TODO: loss function not 100% verified 55 | def __init__(self, loss_args: Dict): 56 | """_summary_ 57 | 58 | Args: 59 | loss_args (Dict): _description_ 60 | """ 61 | super().__init__() 62 | self.kl_div = nn.KLDivLoss(reduction="batchmean") 63 | self.temperature = loss_args["temperature"] 64 | self.lambda_param = loss_args["lambda"] 65 | self.dice = monai.losses.DiceLoss( 66 | to_onehot_y=False, 67 | sigmoid=True, 68 | ) 69 | 70 | def forward( 71 | self, 72 | teacher_predictions: torch.Tensor, 73 | student_predictions: torch.Tensor, 74 | targets: torch.Tensor, 75 | ) -> torch.Tensor: 76 | """_summary_ 77 | 78 | Args: 79 | teacher_predictions (torch.Tensor): _description_ 80 | student_predictions (torch.Tensor): _description_ 81 | targets (torch.Tensor): _description_ 82 | 83 | Returns: 84 | torch.Tensor: _description_ 85 | """ 86 | # compute probabilties 87 | soft_teacher = F.softmax( 88 | teacher_predictions.view(-1) / self.temperature, 89 | dim=-1, 90 | ) 91 | soft_student = F.log_softmax( 92 | student_predictions.view(-1) / self.temperature, 93 | dim=-1, 94 | ) 95 | 96 | # compute kl div loss 97 | distillation_loss = self.kl_div(soft_student, soft_teacher) * ( 98 | self.temperature**2 99 | ) 100 | 101 | # compute dice loss on student 102 | dice_loss = self.dice(student_predictions, targets) 103 | 104 | # combine via lambda 105 | loss = (1.0 - self.lambda_param) * ( 106 | dice_loss + self.lambda_param * distillation_loss 107 | ) 108 | 109 | return loss 110 | 111 | 112 | class DiceBCELoss(nn.Module): 113 | def __init__(self, loss_args=None): 114 | super().__init__() 115 | self.dice_loss = monai.losses.DiceLoss(sigmoid=True, to_onehot_y=False) 116 | self.bce_loss = torch.nn.BCEWithLogitsLoss() 117 | 118 | def 
__call__(self, predicted, target): 119 | dice_loss = self.dice_loss(predicted, target) * 1 120 | bce_loss = self.bce_loss(predicted, target.float()) * 1 121 | final = dice_loss + bce_loss 122 | return final 123 | 124 | 125 | ##################################################################### 126 | def build_loss_fn(loss_type: str, loss_args: Dict = None): 127 | if loss_type == "crossentropy": 128 | return CrossEntropyLoss() 129 | 130 | elif loss_type == "binarycrossentropy": 131 | return BinaryCrossEntropyWithLogits() 132 | 133 | elif loss_type == "MSE": 134 | return MSELoss() 135 | 136 | elif loss_type == "L1": 137 | return L1Loss() 138 | 139 | elif loss_type == "dice": 140 | return monai.losses.DiceLoss(to_onehot_y=False, sigmoid=True) 141 | 142 | elif loss_type == "dicebce": 143 | return DiceBCELoss() 144 | 145 | elif loss_type == "dicece": 146 | return monai.losses.DiceCELoss(to_onehot_y=True, softmax=True) 147 | else: 148 | raise ValueError("must be cross entropy or soft dice loss for now!") 149 | -------------------------------------------------------------------------------- /misc/flop_counter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, Dict 3 | from fvcore.nn import FlopCountAnalysis 4 | 5 | 6 | ########################################################################################## 7 | def flop_count_analysis( 8 | input_dim: Tuple, 9 | model: torch.nn.Module, 10 | ) -> Dict: 11 | """_summary_ 12 | 13 | Args: 14 | input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional)) 15 | model (torch.nn.Module): _description_ 16 | """ 17 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 18 | input_tensor = torch.ones(()).new_empty( 19 | (1, *input_dim), 20 | dtype=next(model.parameters()).dtype, 21 | device=next(model.parameters()).device, 22 | ) 23 | flops = FlopCountAnalysis(model, input_tensor) 24 | model_flops = flops.total() 25 | 
print(f"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M") 26 | print(f"MAdds: {round(model_flops * 1e-9, 2)} G") 27 | 28 | out = { 29 | "params": round(trainable_params * 1e-6, 2), 30 | "flops": round(model_flops * 1e-9, 2), 31 | } 32 | 33 | return out 34 | -------------------------------------------------------------------------------- /misc/profiler.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../../\")\n", 13 | "\n", 14 | "from config import *\n", 15 | "import wandb\n", 16 | "import torch\n", 17 | "from typing import Dict\n", 18 | "from losses.losses import build_loss_fn\n", 19 | "from optimizers.optimizers import build_optimizer\n", 20 | "from train_scripts.train_functions import train_val\n", 21 | "from architectures.leavisnet_v3 import build_model\n", 22 | "\n", 23 | "# from datasets.build_dataset import build_architecture, build_dataloader\n", 24 | "\n", 25 | "from collections import namedtuple\n", 26 | "from torch.utils.data import DataLoader\n", 27 | "import monai" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "model = build_test_model(None)\n", 37 | "with torch.cuda.device(0):\n", 38 | " net = model\n", 39 | " macs, params = get_model_complexity_info(\n", 40 | " net, (3, 224, 224), as_strings=True, print_per_layer_stat=True, verbose=True\n", 41 | " )\n", 42 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n", 43 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "import torch\n", 53 | "from torch.profiler 
import profile, record_function, ProfilerActivity" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "model = build_test_model(None).cpu()\n", 63 | "inputs = torch.randn(1, 3, 224, 224)\n", 64 | "\n", 65 | "pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 66 | "print(\"\\nmodel parameter count = \", pytorch_total_params)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "def estimate_memory_inference(\n", 76 | " model, sample_input, batch_size=1, use_amp=False, device=0\n", 77 | "):\n", 78 | " \"\"\"Predict the maximum memory usage of the model.\n", 79 | " Args:\n", 80 | " optimizer_type (Type): the class name of the optimizer to instantiate\n", 81 | " model (nn.Module): the neural network model\n", 82 | " sample_input (torch.Tensor): A sample input to the network. It should be\n", 83 | " a single item, not a batch, and it will be replicated batch_size times.\n", 84 | " batch_size (int): the batch size\n", 85 | " use_amp (bool): whether to estimate based on using mixed precision\n", 86 | " device (torch.device): the device to use\n", 87 | " \"\"\"\n", 88 | " # Reset model and optimizer\n", 89 | " model.cpu()\n", 90 | " a = torch.cuda.memory_allocated(device)\n", 91 | " model.to(device)\n", 92 | " b = torch.cuda.memory_allocated(device)\n", 93 | " model_memory = b - a\n", 94 | " model_input = sample_input # .unsqueeze(0).repeat(batch_size, 1)\n", 95 | " output = model(model_input.to(device)).sum()\n", 96 | " total_memory = model_memory\n", 97 | "\n", 98 | " return total_memory\n", 99 | "\n", 100 | "\n", 101 | "estimate_memory_inference(model, inputs)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "model.cpu()\n", 111 | "with 
profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:\n", 112 | " with record_function(\"model_inference\"):\n", 113 | " model(inputs)\n", 114 | "\n", 115 | "print(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "model = model.cuda()\n", 125 | "inputs = torch.randn(1, 3, 224, 224).cuda()\n", 126 | "\n", 127 | "with profile(\n", 128 | " activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True\n", 129 | ") as prof:\n", 130 | " with record_function(\"model_inference\"):\n", 131 | " model(inputs)\n", 132 | "\n", 133 | "print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "import torch\n", 157 | "import matplotlib.pyplot as plt\n", 158 | "\n", 159 | "model = torch.nn.Linear(2, 1)\n", 160 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n", 161 | "lr_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n", 162 | " optimizer, T_0=20, T_mult=2, eta_min=0.001, last_epoch=-1\n", 163 | ")\n", 164 | "\n", 165 | "\n", 166 | "lrs = []\n", 167 | "\n", 168 | "for i in range(1000):\n", 169 | " lr_sched.step()\n", 170 | " lrs.append(optimizer.param_groups[0][\"lr\"])\n", 171 | "\n", 172 | "plt.plot(lrs)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [] 179 | } 180 | ], 181 | "metadata": { 182 | "kernelspec": { 183 | "display_name": "corev2", 184 | "language": "python", 
185 | "name": "python3" 186 | }, 187 | "language_info": { 188 | "codemirror_mode": { 189 | "name": "ipython", 190 | "version": 3 191 | }, 192 | "file_extension": ".py", 193 | "mimetype": "text/x-python", 194 | "name": "python", 195 | "nbconvert_exporter": "python", 196 | "pygments_lexer": "ipython3", 197 | "version": "3.10.10" 198 | }, 199 | "orig_nbformat": 4 200 | }, 201 | "nbformat": 4, 202 | "nbformat_minor": 2 203 | } 204 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/optimizers/__init__.py -------------------------------------------------------------------------------- /optimizers/optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch.optim as optim 3 | from monai.optimizers import Novograd 4 | 5 | ###################################################################### 6 | def optim_adam(model, optimizer_args): 7 | adam = optim.Adam( 8 | model.parameters(), 9 | lr=optimizer_args["lr"], 10 | weight_decay=optimizer_args.get("weight_decay"), 11 | ) 12 | return adam 13 | 14 | 15 | ###################################################################### 16 | def optim_sgd(model, optimizer_args): 17 | adam = optim.SGD( 18 | model.parameters(), 19 | lr=optimizer_args["lr"], 20 | weight_decay=optimizer_args.get("weight_decay"), 21 | momentum=optimizer_args.get("momentum"), 22 | ) 23 | return adam 24 | 25 | 26 | ###################################################################### 27 | def optim_adamw(model, optimizer_args): 28 | adam = optim.AdamW( 29 | model.parameters(), 30 | lr=optimizer_args["lr"], 31 | weight_decay=optimizer_args["weight_decay"], 32 | # amsgrad=True, 33 | ) 34 | return adam 35 | 36 | 
###################################################################### 37 | def optim_novograd(model, optimizer_args): 38 | novograd = Novograd( 39 | model.parameters(), 40 | lr=optimizer_args["lr"], 41 | weight_decay=optimizer_args["weight_decay"], 42 | # amsgrad=True, 43 | ) 44 | return novograd 45 | 46 | 47 | ###################################################################### 48 | def build_optimizer(model, optimizer_type: str, optimizer_args: Dict): 49 | if optimizer_type == "adam": 50 | return optim_adam(model, optimizer_args) 51 | elif optimizer_type == "adamw": 52 | return optim_adamw(model, optimizer_args) 53 | elif optimizer_type == "sgd": 54 | return optim_sgd(model, optimizer_args) 55 | elif optimizer_type == "novograd": 56 | return optim_novograd(model, optimizer_args) 57 | else: 58 | raise ValueError("must be adam or adamw for now") 59 | -------------------------------------------------------------------------------- /optimizers/schedulers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch.optim as optim 3 | from torch.optim.lr_scheduler import LRScheduler 4 | 5 | 6 | ################################################################################################## 7 | def warmup_lr_scheduler(config, optimizer): 8 | """ 9 | Linearly ramps up the learning rate within X 10 | number of epochs to the working epoch. 11 | Args: 12 | optimizer (_type_): _description_ 13 | warmup_epochs (_type_): _description_ 14 | warmup_lr (_type_): warmup lr should be the starting lr we want. 
15 | """ 16 | lambda1 = lambda epoch: ( 17 | (epoch + 1) * 1.0 / config["warmup_scheduler"]["warmup_epochs"] 18 | ) 19 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1, verbose=False) 20 | return scheduler 21 | 22 | 23 | ################################################################################################## 24 | def training_lr_scheduler(config, optimizer): 25 | """ 26 | Wraps a normal scheuler 27 | """ 28 | scheduler_type = config["train_scheduler"]["scheduler_type"] 29 | if scheduler_type == "reducelronplateau": 30 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 31 | optimizer, 32 | factor=0.1, 33 | mode=config["train_scheduler"]["mode"], 34 | patience=config["train_scheduler"]["patience"], 35 | verbose=False, 36 | min_lr=config["train_scheduler"]["scheduler_args"]["min_lr"], 37 | ) 38 | return scheduler 39 | elif scheduler_type == "cosine_annealing_wr": 40 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( 41 | optimizer, 42 | T_0=config["train_scheduler"]["scheduler_args"]["t_0_epochs"], 43 | T_mult=config["train_scheduler"]["scheduler_args"]["t_mult"], 44 | eta_min=config["train_scheduler"]["scheduler_args"]["min_lr"], 45 | last_epoch=-1, 46 | verbose=False, 47 | ) 48 | return scheduler 49 | else: 50 | raise NotImplementedError("Specified Scheduler Is Not Implemented") 51 | 52 | 53 | ################################################################################################## 54 | def build_scheduler( 55 | optimizer: optim.Optimizer, scheduler_type: str, config 56 | ) -> LRScheduler: 57 | """generates the learning rate scheduler 58 | 59 | Args: 60 | optimizer (optim.Optimizer): pytorch optimizer 61 | scheduler_type (str): type of scheduler 62 | 63 | Returns: 64 | LRScheduler: _description_ 65 | """ 66 | if scheduler_type == "warmup_scheduler": 67 | scheduler = warmup_lr_scheduler(config=config, optimizer=optimizer) 68 | return scheduler 69 | elif scheduler_type == "training_scheduler": 70 | scheduler = 
training_lr_scheduler(config=config, optimizer=optimizer) 71 | return scheduler 72 | else: 73 | raise ValueError("Invalid Input -- Check scheduler_type") 74 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.23.0 2 | datasets==2.14.5 3 | evaluate==0.4.1 4 | einops==0.7.0 5 | h5py==3.9.0 6 | imageio==2.33.0 7 | kornia==0.7.2 8 | kornia-rs==0.1.3 9 | matplotlib==3.7.1 10 | matplotlib-inline==0.1.6 11 | monai==1.2.0 12 | nibabel==5.1.0 13 | opencv-contrib-python==4.7.0.72 14 | opencv-python==4.7.0.72 15 | opencv-python-headless==4.7.0.72 16 | opt-einsum==3.3.0 17 | protobuf==4.22.3 18 | requests==2.28.1 19 | safetensors==0.3.1 20 | scikit-image==0.23.1 21 | scikit-learn==1.2.2 22 | scipy==1.13.0 23 | termcolor==2.3.0 24 | timm==0.6.13 25 | torch-geometric==2.3.0 26 | torch-summary==1.4.5 27 | torchmetrics==0.11.4 28 | torchsummary==1.5.1 29 | tqdm==4.65.0 30 | typing_extensions==4.8.0 31 | wandb==0.15.12 32 | -------------------------------------------------------------------------------- /resources/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/.gitignore -------------------------------------------------------------------------------- /resources/adv_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/adv_arch.png -------------------------------------------------------------------------------- /resources/citiscapes1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/citiscapes1.png 
-------------------------------------------------------------------------------- /resources/cityscapes2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes2.png -------------------------------------------------------------------------------- /resources/cityscapes3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes3.png -------------------------------------------------------------------------------- /resources/cityscapes4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes4.png -------------------------------------------------------------------------------- /resources/cityscapes_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes_table.png -------------------------------------------------------------------------------- /resources/isic_2016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/isic_2016.png -------------------------------------------------------------------------------- /resources/isic_2017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/isic_2017.png -------------------------------------------------------------------------------- /resources/isic_2018.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/isic_2018.png -------------------------------------------------------------------------------- /resources/muvit_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/muvit_architecture.png -------------------------------------------------------------------------------- /resources/params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/params.png -------------------------------------------------------------------------------- /resources/ph2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/ph2.png -------------------------------------------------------------------------------- /resources/postdam1_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam1_gt.png -------------------------------------------------------------------------------- /resources/postdam1_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam1_pred.png -------------------------------------------------------------------------------- /resources/postdam2_gt.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam2_gt.png -------------------------------------------------------------------------------- /resources/postdam2_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam2_pred.png -------------------------------------------------------------------------------- /resources/potsdam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/potsdam.png -------------------------------------------------------------------------------- /resources/vaihigen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/vaihigen.png -------------------------------------------------------------------------------- /train_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/train_scripts/__init__.py -------------------------------------------------------------------------------- /train_scripts/ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | 4 | 5 | ############################################################################### 6 | class EMA: 7 | "https://github.com/scott-yjyang/DiffMIC/blob/main/ema.py" 8 | 9 | def __init__( 10 | self, 11 | model: nn.Module, 12 | mu: float = 0.999, 13 | ) -> None: 14 | self.mu = mu 15 | self.ema_model = {} 16 | self.register(model) 17 | self.model_copy = copy.deepcopy(model) 18 | 19 | def register(self, 
module) -> None: 20 | for name, param in module.named_parameters(): 21 | if param.requires_grad: 22 | self.ema_model[name] = param.data.clone() 23 | 24 | def update(self, module: nn.Module) -> None: 25 | for name, param in module.named_parameters(): 26 | if param.requires_grad: 27 | self.ema_model[name].data = ( 28 | 1.0 - self.mu 29 | ) * param.data + self.mu * self.ema_model[name].data 30 | 31 | def ema(self, module: nn.Module) -> None: 32 | for name, param in module.named_parameters(): 33 | if param.requires_grad: 34 | param.data.copy_(self.ema_model[name].data) 35 | 36 | def ema_copy(self, module: nn.Module): 37 | """ 38 | Returns the model with the ema weights inserted. 39 | Args: 40 | module (nn.Module): _description_ 41 | 42 | Returns: 43 | _type_: _description_ 44 | """ 45 | # module_copy = type(module)(module.config).to(module.config.device) 46 | module_copy = copy.deepcopy(module) 47 | # module_copy.load_state_dict(module.state_dict()) 48 | self.ema(module_copy) 49 | 50 | return module_copy 51 | 52 | def state_dict(self): 53 | """ 54 | Returns ema model 55 | Returns: 56 | _type_: _description_ 57 | """ 58 | return self.ema_model 59 | 60 | def load_state_dict(self, state_dict) -> None: 61 | self.ema_model = state_dict 62 | 63 | 64 | ############################################################################### 65 | -------------------------------------------------------------------------------- /train_scripts/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import wandb 4 | import torch 5 | import random 6 | import numpy as np 7 | 8 | """ 9 | Utils File Used for Training/Validation/Testing 10 | """ 11 | 12 | 13 | ################################################################################################## 14 | def log_metrics(**kwargs) -> None: 15 | # data to be logged 16 | log_data = {} 17 | log_data.update(kwargs) 18 | 19 | # log the data 20 | wandb.log(log_data) 21 | 22 | 
##################################################################################################
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar") -> None:
    """Serialize model and optimizer state to ``filename`` with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored under the ``"state_dict"`` key.
        optimizer: optimizer whose ``state_dict()`` is stored under the ``"optimizer"`` key.
        filename: destination path of the checkpoint file.
    """
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


##################################################################################################
def load_checkpoint(config, model, optimizer, load_optimizer=True):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        config: object exposing ``checkpoint_file_name`` (path to load),
            ``device`` (passed as ``map_location``) and ``learning_rate``.
        model: module that receives ``checkpoint["state_dict"]``.
        optimizer: optimizer that receives ``checkpoint["optimizer"]``.
        load_optimizer: if False, the optimizer state is left untouched.

    Returns:
        Tuple ``(model, optimizer)``.

    Note:
        ``torch.load`` relies on pickle (unless ``weights_only`` is in effect
        for the installed torch version); only load trusted checkpoint files.
    """
    print("=> Loading checkpoint")
    checkpoint = torch.load(config.checkpoint_file_name, map_location=config.device)
    model.load_state_dict(checkpoint["state_dict"])

    if load_optimizer:
        optimizer.load_state_dict(checkpoint["optimizer"])

        # Loading the optimizer state also restores the learning rate stored in
        # the old checkpoint; override it with the configured learning rate,
        # otherwise training silently resumes at the stale value.
        for param_group in optimizer.param_groups:
            param_group["lr"] = config.learning_rate

    return model, optimizer


##################################################################################################
def seed_everything(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: value used for every random number generator.
    """
    # NOTE: setting PYTHONHASHSEED at runtime only affects child processes;
    # the current interpreter's hash randomization is already fixed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


##################################################################################################
def initialize_weights(m):
    """Re-initialize the parameters of a single module in place.

    Intended to be used as ``model.apply(initialize_weights)``:
    Conv2d / Linear weights get Xavier-uniform init and zero bias;
    BatchNorm2d gets weight 1 and bias 0. Other module types are untouched.

    Args:
        m: module to (possibly) re-initialize.
    """
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, torch.nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight=None and bias=None, so both
        # affine parameters must be guarded before initialization.
        if m.weight is not None:
            torch.nn.init.ones_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


##################################################################################################