├── .gitignore
├── README.md
├── __init__.py
├── architectures
│   ├── __init__.py
│   ├── build_architecture.py
│   └── mobileunetr.py
├── augmentations
│   ├── __init__.py
│   ├── segmentation_augmentations.py
│   └── segmentation_augmentations_tv.py
├── dataloaders
│   ├── __init__.py
│   ├── build_dataset.py
│   └── isic_dataset.py
├── example_configs
│   ├── config_mobileunetr_s.yaml
│   ├── config_mobileunetr_xs.yaml
│   └── config_mobileunetr_xxs.yaml
├── experiments_medical
│   ├── isic_2016
│   │   └── exp_2_dice_b8_a2
│   │       ├── config.yaml
│   │       ├── generate_images.ipynb
│   │       ├── inference_model_weights
│   │       │   └── .gitignore
│   │       ├── metrics.ipynb
│   │       ├── run_experiment.py
│   │       └── visualize.ipynb
│   ├── isic_2017
│   │   └── exp_2_dice_b8_a2
│   │       ├── config.yaml
│   │       ├── inference_model_weights
│   │       │   └── .gitignore
│   │       ├── metrics.ipynb
│   │       └── run_experiment.py
│   ├── isic_2018
│   │   └── exp_2_dice_b8_a2
│   │       ├── config.yaml
│   │       ├── inference_model_weights
│   │       │   └── .gitignore
│   │       ├── metrics.ipynb
│   │       ├── run_experiment.py
│   │       └── visualize.ipynb
│   └── ph2
│       └── exp_2_dice_b8_a2
│           ├── config.yaml
│           ├── inference_model_weights
│           │   └── .gitignore
│           ├── metrics.ipynb
│           ├── run_experiment.py
│           └── visualize.ipynb
├── losses
│   ├── __init__.py
│   └── losses.py
├── misc
│   ├── flop_counter.py
│   ├── mobileunetr.py
│   └── profiler.ipynb
├── optimizers
│   ├── __init__.py
│   ├── optimizers.py
│   └── schedulers.py
├── requirements.txt
├── resources
│   ├── .gitignore
│   ├── adv_arch.png
│   ├── citiscapes1.png
│   ├── cityscapes2.png
│   ├── cityscapes3.png
│   ├── cityscapes4.png
│   ├── cityscapes_table.png
│   ├── isic_2016.png
│   ├── isic_2017.png
│   ├── isic_2018.png
│   ├── muvit_architecture.png
│   ├── params.png
│   ├── ph2.png
│   ├── postdam1_gt.png
│   ├── postdam1_pred.png
│   ├── postdam2_gt.png
│   ├── postdam2_pred.png
│   ├── potsdam.png
│   └── vaihigen.png
└── train_scripts
    ├── __init__.py
    ├── ema.py
    ├── segmentation_trainer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | **/wandb
3 | **/__pycache__
4 | **.npy
5 | **.pth
6 | **.svg
7 | **.jpg
8 | **.csv
9 | **.bin
10 | **/data
11 | **.tex
12 | **/model_checkpoints
13 | **.pkl
14 | #**.png
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MobileUNETR
2 | ## A Lightweight End-To-End Hybrid Vision Transformer For Efficient Medical Image Segmentation: [ECCV 2024 -- BioImage Computing] (ORAL) https://arxiv.org/abs/2409.03062
3 |
4 | ## Architecture
5 |
6 | *(figure: MobileUNETR architecture)*
7 |
11 | ## Parameter Distribution and Computational Complexity
12 |
13 | *(figure: parameter distribution and computational complexity)*
14 |
18 | ## :rocket: News
19 | * Repository Construction Complete 06/09/2024
20 | * We will continue to update the repository with new experiments on a wide range of datasets, so be sure to check back regularly.
21 | * In the meantime, check out our other project: https://github.com/OSUPCVLab/SegFormer3D
22 |
23 | ## Overview
24 | * Segmentation approaches broadly fall into 2 categories:
25 |   1. End-to-end CNN-based segmentation methods.
26 |   2. Transformer-based encoders paired with a CNN-based decoder.
27 | * Many Transformer-based segmentation approaches rely primarily on CNN-based decoders, overlooking the benefits of the Transformer architecture within the decoder.
28 | * We address the need for an efficient/lightweight segmentation architecture by introducing MobileUNETR, which aims to overcome the performance constraints associated with both CNNs and Transformers while minimizing model size, presenting a promising stride towards efficient image segmentation.
29 | * MobileUNETR has 3 main features:
30 |   1. A lightweight hybrid CNN-Transformer encoder that balances local and global contextual feature extraction in an efficient manner.
31 |   2. A novel hybrid decoder that simultaneously utilizes low-level and global features at different resolutions within the decoding stage for accurate mask generation.
32 |   3. Surpassing large and complex architectures, MobileUNETR achieves superior performance with only 3 million parameters and a computational complexity of 1.3 GFLOPs.
33 |
34 | ## Stand Alone Model [Please Read]
35 | To improve ease of use, the MobileUNETR architecture is implemented as a single stand-alone file. If you want to use the model outside of the provided code base, simply grab the mobileunetr.py file from the architectures folder and drop it into your own project.
36 |
37 | * Example:
38 | ```python
39 | # import the model builders from the stand-alone mobileunetr.py file
40 | from mobileunetr import build_mobileunetr_s, build_mobileunetr_xs, build_mobileunetr_xxs
41 | import torch
42 |
43 | # create a model (three sizes are available: s, xs, xxs)
44 | mobileunetr_s = build_mobileunetr_s(num_classes=1, image_size=512)
45 |
46 | mobileunetr_xs = build_mobileunetr_xs(num_classes=1, image_size=512)
47 |
48 | mobileunetr_xxs = build_mobileunetr_xxs(num_classes=1, image_size=512)
49 |
50 | # forward pass on a dummy batch
51 | data = torch.randn((4, 3, 512, 512))
52 | out = mobileunetr_xxs(data)
53 | print(f"input tensor: {data.shape}")
54 | print(f"output tensor: {out.shape}")
55 |
56 | ```
57 |
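With `num_classes=1` and `image_size=512`, the forward pass above should report an input shape of `[4, 3, 512, 512]` and a single-channel, full-resolution output of `[4, 1, 512, 512]`; the channel count follows the `num_classes` you pass to the builder.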
58 | ## Data and Data Processing
59 | * ISIC Data -- https://challenge.isic-archive.com/data/
60 | * PH2 Data -- https://www.fc.up.pt/addi/ph2%20database.html
61 |
62 | * Data Preprocessing
63 |   * For each dataset (ISIC 2016, ISIC 2017, etc.), simply create a CSV file with N rows and 2 columns, where N is the number of items in the dataset and the 2 columns ["image", "mask"] hold the path to the input image and the path to the corresponding target mask.
64 |   * Once you have a train.csv and a test.csv (let's assume for ISIC 2016), update the train and test CSV paths inside experiments_medical/isic_2016/exp_2_dice_b8_a2/config.yaml, then follow the steps below to run the experiment. A minimal sketch for generating these CSV files is shown below.
65 |
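As a reference, here is a minimal sketch for generating these CSV files. The directory layout (`images/`, `masks/`) and the output path are assumptions for illustration; only the two column names are required by the dataloaders.

```python
import glob

import pandas as pd

# hypothetical layout: point these globs at wherever your ISIC data lives
image_paths = sorted(glob.glob("data/isic_2016/train/images/*.jpg"))
mask_paths = sorted(glob.glob("data/isic_2016/train/masks/*.png"))
assert len(image_paths) == len(mask_paths), "every image needs a matching mask"

# the datasets expect exactly two columns named "image" and "mask"
pd.DataFrame({"image": image_paths, "mask": mask_paths}).to_csv(
    "data/train_2016.csv", index=False
)
```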
66 | ## Run Your Experiment
67 | In order to run an experiment, we provide a template folder placed under `MobileUNETR_HOME_PATH/experiments_medical/isic_2016/exp_2_dice_b8_a2` that you can use to set up your experiment. While inside the experiment folder, run your experiment on a single GPU with:
68 | ```shell
69 | cd MobileUNETR
70 | cd experiments_medical/isic_2016/exp_2_dice_b8_a2/
71 | # the default gpu device is set to cuda:0 (you can change it)
72 | CUDA_VISIBLE_DEVICES="0" accelerate launch run_experiment.py
73 | ```
74 | You might want to change the hyperparameters (batch size, learning rate, weight decay, etc.) of your experiment. For that, edit the `config.yaml` file inside your experiment folder; the relevant entries are shown below.
75 |
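For reference, the entries you would typically edit look like this in the provided ISIC 2016 config (values shown are the defaults):

```yaml
variables:
  batch_size: &batch_size 8

optimizer:
  optimizer_type: "adamw"
  optimizer_args:
    lr: 0.00006
    weight_decay: 0.01
```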
76 | As the experiment is running, the logs (train loss, validation loss and dice score) will be written to the terminal. You can log your experiment on [wandb](https://wandb.ai/site)
77 | (you need to set up an account there) if you set `mode: "online"` in the `wandb_parameters` section of the `config.yaml`. The default value is `mode: "offline"`. If you want to log the results to your wandb account, put your wandb info into the `wandb_parameters` section of the `config.yaml`, and your entire experiment will be logged under your wandb entity (e.g. `pcvlab`) page. An example of this section is shown below.
78 |
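A sketch of the `wandb_parameters` section as it appears in the provided configs; the commented `entity` line is an assumption for illustration (set it to your own account or team):

```yaml
project: mobileunetr
wandb_parameters:
  # entity: your_wandb_team   # assumption: your wandb account/team name
  group: isic_2016
  name: exp_2_dice_b8_a2
  mode: "offline"   # set to "online" to log to your wandb account
  resume: False
  tags: ["tr16ts16", "dice", "xxs", "adam"]
```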
79 | ## ISIC 2016 Performance
80 |
81 | *(figure: ISIC 2016 results)*
82 |
86 | ## ISIC 2017 Performance
87 |
88 | *(figure: ISIC 2017 results)*
89 |
93 | ## ISIC 2018 Performance
94 |
95 | *(figure: ISIC 2018 results)*
96 |
100 | ## PH2 Performance
101 |
102 | *(figure: PH2 results)*
103 |
107 | ## Advanced Architectures and Training Methods
108 |
109 | *(figure: advanced architectures and training methods)*
110 |
114 | ## Experiments: Extending to Complex Real World Scenes (Cityscapes, Potsdam and Vaihingen)
115 |
116 | ### Cityscapes Results
117 | *(figures: Cityscapes qualitative results)*
118 |
148 | ### Potsdam and Vaihingen Results (GT [Left], Prediction Overlay [Right])
149 | *(figures: Potsdam and Vaihingen ground truth and prediction overlays)*
150 |
164 | ### Potsdam (Left Table) and Vaihingen (Right Table)
165 | *(tables: Potsdam and Vaihingen quantitative results)*
166 |
172 | ## Citation
173 | If you liked our paper, please consider citing it [we will update the TBD fields soon]:
174 | ```bibtex
175 | @inproceedings{perera2024mobileunetr,
176 | title={MobileUNETR: A Lightweight End-To-End Hybrid Vision Transformer For Efficient Medical Image Segmentation},
177 | author={Perera, Shehan and Erzurumlu, Yunus and Gulati, Deepak and Yilmaz, Alper},
178 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
179 | pages={TBD},
180 | year={2024}
181 | }
182 | ```
183 | ```bibtex
184 | @article{perera2024mobileunetr,
185 | title={MobileUNETR: A Lightweight End-To-End Hybrid Vision Transformer For Efficient Medical Image Segmentation},
186 | author={Perera, Shehan and Erzurumlu, Yunus and Gulati, Deepak and Yilmaz, Alper},
187 | journal={arXiv preprint arXiv:2409.03062},
188 | year={2024}
189 | }
190 | ```
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/__init__.py
--------------------------------------------------------------------------------
/architectures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/architectures/__init__.py
--------------------------------------------------------------------------------
/architectures/build_architecture.py:
--------------------------------------------------------------------------------
1 | """
2 | To select the architecture based on a config file we need to ensure
3 | we import each of the architectures into this file. Once we have that
4 | we can use a keyword from the config file to build the model.
5 | """
6 |
7 |
8 | ######################################################################
9 | def build_architecture(config):
10 | if config["model_name"] == "mobileunetr_xs":
11 | from .mobileunetr import build_mobileunetr_xs
12 |
13 | model = build_mobileunetr_xs(config=config)
14 | return model
15 |
16 | elif config["model_name"] == "mobileunetr_xxs":
17 | from .mobileunetr import build_mobileunetr_xxs
18 |
19 | model = build_mobileunetr_xxs(config=config)
20 | return model
21 |
22 | elif config["model_name"] == "mobileunetr_s":
23 | from .mobileunetr import build_mobileunetr_s
24 |
25 | model = build_mobileunetr_s(config=config)
26 | return model
27 |
28 | else:
29 |         raise ValueError(
30 |             "specified model is not supported; edit the build_architecture.py file"
31 | )
32 |
--------------------------------------------------------------------------------
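The builder is driven by the same YAML configs used by the experiments (the metrics notebooks further down follow this pattern). A minimal usage sketch; the config path is illustrative, and any of the example_configs files works:

```python
import yaml

from architectures.build_architecture import build_architecture

# "model_name" in the config selects the variant (mobileunetr_s / _xs / _xxs)
with open("example_configs/config_mobileunetr_xxs.yaml", "r") as f:
    config = yaml.safe_load(f)

model = build_architecture(config=config)
print(sum(p.numel() for p in model.parameters()))  # roughly 3M for the xxs variant
```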
/augmentations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/augmentations/__init__.py
--------------------------------------------------------------------------------
/augmentations/segmentation_augmentations.py:
--------------------------------------------------------------------------------
1 | import albumentations as A
2 | from albumentations.pytorch import ToTensorV2
3 | import cv2
4 | from typing import Dict
5 |
6 |
7 | #######################################################################################
8 | def build_augmentations(train: bool = True, augmentation_args: Dict = None):
9 |
10 | mean = augmentation_args["mean"]
11 | std = augmentation_args["std"]
12 | image_size = augmentation_args["image_size"]
13 |
14 | if train:
15 | train_transform = A.Compose(
16 | [
17 | A.Resize(image_size[0], image_size[1], interpolation=cv2.INTER_CUBIC),
18 | A.HorizontalFlip(p=0.5),
19 | A.VerticalFlip(p=0.5),
20 | A.RandomBrightnessContrast(
21 | brightness_limit=0.5,
22 | contrast_limit=0.5,
23 | p=0.6,
24 | ),
25 | A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
26 | A.RandomRotate90(p=0.5),
27 | A.RandomResizedCrop(
28 | image_size[0],
29 | image_size[1],
30 | interpolation=cv2.INTER_CUBIC,
31 | p=0.5,
32 | ),
33 | A.Normalize(mean, std),
34 | ToTensorV2(),
35 | ]
36 | )
37 |
38 | return train_transform
39 | else:
40 | test_transform = A.Compose(
41 | [
42 | A.Resize(
43 | image_size[0],
44 | image_size[1],
45 | interpolation=cv2.INTER_CUBIC,
46 | ),
47 | A.Normalize(mean, std),
48 | ToTensorV2(),
49 | ]
50 | )
51 | return test_transform
52 |
53 |
54 | #######################################################################################
55 | def build_augmentations_v2(train: bool = True, augmentation_args: Dict = None):
56 |
57 | mean = augmentation_args["mean"]
58 | std = augmentation_args["std"]
59 | image_size = augmentation_args["image_size"]
60 |
61 | if train:
62 | train_transform = A.Compose(
63 | [
64 | A.Resize(image_size[0], image_size[1], interpolation=cv2.INTER_CUBIC),
65 | A.HorizontalFlip(p=0.5),
66 | A.VerticalFlip(p=0.5),
67 | A.RandomBrightnessContrast(
68 | brightness_limit=0.5,
69 | contrast_limit=0.5,
70 | p=0.6,
71 | ),
72 | A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
73 | A.RandomRotate90(p=0.5),
74 | A.OneOf(
75 | [
76 | A.RandomResizedCrop(
77 | image_size[0],
78 | image_size[1],
79 | interpolation=cv2.INTER_CUBIC,
80 | p=1,
81 | ),
82 | A.Compose(
83 | [
84 | A.CropNonEmptyMaskIfExists(128, 128, p=1),
85 | A.Resize(
86 | image_size[0],
87 | image_size[1],
88 | interpolation=cv2.INTER_CUBIC,
89 | ),
90 | ]
91 | ),
92 | A.Compose(
93 | [
94 | A.CropNonEmptyMaskIfExists(64, 64, p=1),
95 | A.Resize(
96 | image_size[0],
97 | image_size[1],
98 | interpolation=cv2.INTER_CUBIC,
99 | ),
100 | ]
101 | ),
102 | A.Compose(
103 | [
104 | A.CropNonEmptyMaskIfExists(256, 256, p=1),
105 | A.Resize(
106 | image_size[0],
107 | image_size[1],
108 | interpolation=cv2.INTER_CUBIC,
109 | ),
110 | ]
111 | ),
112 | A.Compose(
113 | [
114 | A.CropNonEmptyMaskIfExists(160, 160, p=1),
115 | A.Resize(
116 | image_size[0],
117 | image_size[1],
118 | interpolation=cv2.INTER_CUBIC,
119 | ),
120 | ]
121 | ),
122 | ],
123 | p=0.5,
124 | ),
125 | A.Normalize(mean, std),
126 | ToTensorV2(),
127 | ]
128 | )
129 |
130 | return train_transform
131 | else:
132 | test_transform = A.Compose(
133 | [
134 | A.Resize(image_size[0], image_size[1], interpolation=cv2.INTER_CUBIC),
135 | A.HorizontalFlip(p=0.3),
136 | A.VerticalFlip(p=0.3),
137 | A.Normalize(mean, std),
138 | ToTensorV2(),
139 | ]
140 | )
141 | return test_transform
142 |
143 |
144 | #######################################################################################
145 |
--------------------------------------------------------------------------------
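A small sketch of how these pipelines are meant to be driven; the mean/std/image_size values mirror the provided configs, and the random arrays stand in for a real image/mask pair:

```python
import numpy as np

from augmentations.segmentation_augmentations import build_augmentations

augmentation_args = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
    "image_size": [512, 512],
}

transform = build_augmentations(train=True, augmentation_args=augmentation_args)

# albumentations expects HWC uint8 images and a same-sized mask
image = np.random.randint(0, 256, (768, 1024, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (768, 1024), dtype=np.uint8)
out = transform(image=image, mask=mask)
print(out["image"].shape, out["mask"].shape)  # torch tensors after ToTensorV2
```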
/augmentations/segmentation_augmentations_tv.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 | from typing import Dict
3 |
4 |
5 | # [0.7128, 0.6000, 0.5532], [0.1577, 0.1662, 0.1829]
6 | #######################################################################################
7 | def build_augmentations(train: bool = True, augmentation_args: Dict = None):
8 | mean = augmentation_args["mean"]
9 | std = augmentation_args["std"]
10 | image_size = augmentation_args["image_size"]
11 |
12 | if train:
13 | train_transform = transforms.Compose(
14 | [
15 | transforms.Resize(
16 | (image_size[0], image_size[1]),
17 | interpolation=transforms.InterpolationMode.BICUBIC,
18 | ),
19 | transforms.autoaugment.TrivialAugmentWide(
20 | interpolation=transforms.InterpolationMode.BILINEAR
21 | ),
22 | transforms.ToTensor(),
23 | transforms.Normalize(mean, std),
24 | ]
25 | )
26 |
27 | return train_transform
28 | else:
29 | test_transform = transforms.Compose(
30 | [
31 | # transforms.RandomRotation(180),
32 | transforms.Resize(
33 | (image_size[0], image_size[1]),
34 | interpolation=transforms.InterpolationMode.BICUBIC,
35 | ),
36 | transforms.ToTensor(),
37 | transforms.Normalize(mean, std),
38 | ]
39 | )
40 |
41 | return test_transform
42 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/dataloaders/__init__.py
--------------------------------------------------------------------------------
/dataloaders/build_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append("../")
4 | from typing import Dict
5 | # monai's DataLoader subclasses torch.utils.data.DataLoader
6 | from monai.data import DataLoader
7 |
8 |
9 | #############################################################################################
10 | def build_dataset(dataset_type: str, dataset_args: Dict, augmentation_args: Dict):
11 | if dataset_type == "isic_torchvision":
12 | from .isic_dataset import ISICDataset
13 | from augmentations.segmentation_augmentations_tv import build_augmentations
14 |
15 | dataset = ISICDataset(
16 | data=dataset_args["data_path"],
17 | image_size=dataset_args["image_size"],
18 | input_transforms=build_augmentations(
19 | dataset_args["train"],
20 | augmentation_args,
21 | ),
22 | target_transforms=True,
23 | )
24 | return dataset
25 | elif dataset_type == "isic_albumentation_v2":
26 | from .isic_dataset import ISICDatasetA
27 | from augmentations.segmentation_augmentations import build_augmentations_v2
28 |
29 | dataset = ISICDatasetA(
30 | data=dataset_args["data_path"],
31 | transforms=build_augmentations_v2(dataset_args["train"], augmentation_args),
32 | )
33 | return dataset
34 | else:
35 |         raise NotImplementedError(f"dataset type '{dataset_type}' is not supported")
36 |
37 |
38 | #############################################################################################
39 | def build_dataloader(
40 | dataset,
41 | dataloader_args: Dict,
42 | config: Dict = None,
43 | train: bool = True,
44 | ) -> DataLoader:
45 | """builds the dataloader for given dataset
46 |
47 | Args:
48 | dataset (_type_): _description_
49 | dataloader_args (Dict): _description_
50 | config (Dict, optional): _description_. Defaults to None.
51 | train (bool, optional): _description_. Defaults to True.
52 |
53 | Returns:
54 | DataLoader: _description_
55 | """
56 | dataloader = DataLoader(
57 | dataset=dataset,
58 | batch_size=dataloader_args["batch_size"],
59 | shuffle=dataloader_args["shuffle"],
60 | num_workers=dataloader_args["num_workers"],
61 | drop_last=dataloader_args["drop_last"],
62 | pin_memory=dataloader_args["pin_memory"],
63 | )
64 | return dataloader
65 |
66 |
67 | #############################################################################################
68 |
--------------------------------------------------------------------------------
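Putting the two builders together, mirroring how the metrics notebooks use them; this assumes the CSV paths inside the config point at real files:

```python
import yaml

from dataloaders.build_dataset import build_dataset, build_dataloader

with open("example_configs/config_mobileunetr_xxs.yaml", "r") as f:
    config = yaml.safe_load(f)

trainset = build_dataset(
    dataset_type=config["dataset_parameters"]["dataset_type"],
    dataset_args=config["dataset_parameters"]["train_dataset_args"],
    augmentation_args=config["train_augmentation_args"],
)
trainloader = build_dataloader(
    dataset=trainset,
    dataloader_args=config["dataset_parameters"]["train_dataloader_args"],
    train=True,
)

batch = next(iter(trainloader))
print(batch["image"].shape, batch["mask"].shape)  # [B, 3, H, W], [B, 1, H, W]
```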
/dataloaders/isic_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision
4 | import pandas as pd
5 | from PIL import Image
6 | import torchvision.transforms.functional as tvF
7 |
8 |
9 | ##########################################################################################################
10 | class ISICDataset(torch.utils.data.Dataset):
11 | def __init__(self, data, image_size, input_transforms, target_transforms):
12 | """
13 | Initialize Class
14 | Args:
15 | config (_type_): _description_
16 | data (_type_): _description_
17 | input_transforms (_type_): _description_
18 | target_transforms (_type_): _description_
19 | """
20 | self.data = pd.read_csv(data)
21 | self.image_size = image_size # tuple
22 | self.input_transforms = input_transforms
23 | self.target_transforms = target_transforms
24 |
25 | def __len__(self):
26 | return len(self.data)
27 |
28 | def __getitem__(self, index):
29 | # get the associated_image
30 | image_id = self.data.iloc[index]
31 |
32 | # get the image
33 | image_name = image_id["image"]
34 | image = Image.open(image_name).convert("RGB")
35 |
36 | # get the mask path
37 | mask_id = image_id["mask"]
38 | mask = Image.open(mask_id)
39 |
40 | # transform input image
41 | if self.input_transforms:
42 | image = self.input_transforms(image)
43 |
44 | # transform target image
45 | if self.target_transforms:
46 | mask = tvF.resize(
47 | mask,
48 | self.image_size,
49 | torchvision.transforms.InterpolationMode.NEAREST,
50 | )
51 | mask = np.array(mask).astype(int)
52 | mask[mask == 255] = 1
53 | mask = torch.tensor(mask, dtype=torch.long).unsqueeze(0)
54 |
55 | out = {
56 | "image": image, # [3, H, W]
57 | "mask": mask, # [1, H, W] 1 b/c binary mask
58 | }
59 |
60 | return out
61 |
62 |
63 | ##########################################################################################################
64 | class ISICDatasetA(torch.utils.data.Dataset):
65 | """
66 | Dataset used for albumentation augmentations
67 | """
68 |
69 | def __init__(self, data, transforms):
70 | """
71 | Initialize Class
72 | Args:
73 | config (_type_): _description_
74 | data (_type_): _description_
75 | input_transforms (_type_): _description_
76 | target_transforms (_type_): _description_
77 | """
78 | self.data = pd.read_csv(data)
79 | # self.data = pd.concat([self.data] * 2).sample(frac=1).reset_index(drop=True)
80 | self.transforms = transforms
81 |
82 | def __len__(self):
83 | return len(self.data)
84 |
85 | def __getitem__(self, index):
86 | # get the associated_image
87 | image_id = self.data.iloc[index]
88 |
89 | # get the image
90 | image_name = image_id["image"]
91 | image = Image.open(image_name).convert("RGB")
92 | image = np.array(image, dtype=np.uint8)
93 |
94 | # get the mask path
95 | mask_id = image_id["mask"]
96 | mask = Image.open(mask_id)
97 | mask = np.array(mask, dtype=np.uint8)
98 | mask[mask == 255] = 1
99 |
100 | # transform
101 | if self.transforms:
102 | out = self.transforms(image=image, mask=mask)
103 | image = out["image"]
104 | mask = out["mask"]
105 |
106 | out = {
107 | "image": image, # [3, H, W]
108 | "mask": mask.unsqueeze(0), # [1, H, W] 1 b/c binary mask
109 | }
110 |
111 | return out
112 |
--------------------------------------------------------------------------------
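A quick sanity check of the albumentations-backed dataset; the CSV path is illustrative and the normalization stats mirror the provided configs:

```python
from augmentations.segmentation_augmentations import build_augmentations
from dataloaders.isic_dataset import ISICDatasetA

transforms = build_augmentations(
    train=False,
    augmentation_args={
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
        "image_size": [512, 512],
    },
)
dataset = ISICDatasetA(data="data/test_2016.csv", transforms=transforms)
sample = dataset[0]
print(sample["image"].shape, sample["mask"].shape)  # [3, 512, 512], [1, 512, 512]
```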
/example_configs/config_mobileunetr_s.yaml:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | # Commonly Used Variables
3 | ##############################################################
4 | variables:
5 | batch_size: &batch_size 8
6 | num_channels: &num_channels 3
7 | num_classes: &num_classes 1
8 | num_epochs: &num_epochs 400
9 |
10 | ##############################################################
11 | # Wandb Model Tracking
12 | ##############################################################
13 | project: mobileunetr
14 | wandb_parameters:
15 | group: isic_2016
16 | name: exp_2_dice_b8_a2
17 | mode: "online"
18 | resume: False
19 | tags: ["tr16ts16", "dice", "s", "adam"]
20 |
21 | ##############################################################
22 | # Model Hyper-Parameters
23 | ##############################################################
24 | model_name: mobileunetr_s
25 | model_parameters:
26 | encoder: None
27 | bottle_neck:
28 | dims: [240]
29 | depths: [3]
30 | expansion: 4
31 | kernel_size: 3
32 | patch_size: [2,2]
33 | channels: [160, 196, 196]
34 | decoder:
35 | dims: [144, 192, 240]
36 | channels: [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 196, 196, 640]
37 | num_classes: 1
38 | image_size: 512
39 |
40 | ##############################################################
41 | # Loss Function
42 | ##############################################################
43 | loss_fn:
44 | loss_type: "dice"
45 | loss_args: None
46 |
47 | ##############################################################
48 | # Metrics
49 | ##############################################################
50 | metrics:
51 | type: "binary"
52 | mean_iou:
53 | enabled: True
54 | mean_iou_args:
55 | include_background: True
56 | reduction: "mean"
57 | get_not_nans: False
58 | ignore_empty: True
59 | dice:
60 | enabled: False
61 | dice_args:
62 | include_background: True
63 | reduction: "mean"
64 | get_not_nans: False
65 | ignore_empty: True
66 | num_classes: *num_classes
67 |
68 | ##############################################################
69 | # Optimizers
70 | ##############################################################
71 | optimizer:
72 | optimizer_type: "adamw"
73 | optimizer_args:
74 | lr: 0.00006
75 | weight_decay: 0.01
76 |
77 | ##############################################################
78 | # Learning Rate Schedulers
79 | ##############################################################
80 | warmup_scheduler:
81 | enabled: True # should be always true
82 | warmup_epochs: 30
83 |
84 | # train scheduler
85 | train_scheduler:
86 | scheduler_type: 'cosine_annealing_wr'
87 | scheduler_args:
88 | t_0_epochs: 600
89 | t_mult: 1
90 | min_lr: 0.000006
91 |
92 | ##############################################################
93 | # EMA (Exponential Moving Average)
94 | ##############################################################
95 | ema:
96 | enabled: False
97 | ema_decay: 0.999
98 | print_ema_every: 20
99 |
100 | ##############################################################
101 | # Gradient Clipping
102 | ##############################################################
103 | clip_gradients:
104 | enabled: False
105 | clip_gradients_value: 0.1
106 |
107 | ##############################################################
108 | # Training Hyperparameters
109 | ##############################################################
110 | training_parameters:
111 | seed: 42
112 | num_epochs: 1000
113 | cutoff_epoch: 600
114 | load_optimizer: False
115 | print_every: 1500
116 | calculate_metrics: True
117 | grad_accumulate_steps: 1 # default: 1
118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2"
119 | load_checkpoint: # not implemented yet
120 | load_full_checkpoint: False
121 | load_model_only: False
122 | load_checkpoint_path: None
123 |
124 | ##############################################################
125 | # Dataset and Dataloader Args
126 | ##############################################################
127 | dataset_parameters:
128 | dataset_type: "isic_albumentation_v2"
129 | train_dataset_args:
130 | data_path: "../../../data/train_2016.csv"
131 | train: True
132 | image_size: [512, 512]
133 | val_dataset_args:
134 | data_path: "../../../data/test_2016.csv"
135 | train: False
136 | image_size: [512, 512]
137 | test_dataset_args:
138 | data_path: "../../../data/test_2016.csv"
139 | train: False
140 | image_size: [512, 512]
141 | train_dataloader_args:
142 | batch_size: *batch_size
143 | shuffle: True
144 | num_workers: 4
145 | drop_last: True
146 | pin_memory: True
147 | val_dataloader_args:
148 | batch_size: *batch_size
149 | shuffle: False
150 | num_workers: 4
151 | drop_last: False
152 | pin_memory: True
153 | test_dataloader_args:
154 | batch_size: 1
155 | shuffle: False
156 | num_workers: 4
157 | drop_last: False
158 | pin_memory: True
159 |
160 | ##############################################################
161 | # Data Augmentation Args
162 | ##############################################################
163 | train_augmentation_args:
164 | mean: [0.485, 0.456, 0.406]
165 | std: [0.229, 0.224, 0.225]
166 | image_size: [512, 512] # [H, W]
167 |
168 | test_augmentation_args:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 | image_size: [512, 512] # [H, W]
--------------------------------------------------------------------------------
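A note on the `&`/`*` syntax in these configs: `&name` defines a YAML anchor and `*name` is an alias that resolves to it, so shared values such as the batch size are declared once under `variables` and reused elsewhere. A stripped-down illustration:

```yaml
variables:
  batch_size: &batch_size 8      # anchor: declare the value once

dataset_parameters:
  train_dataloader_args:
    batch_size: *batch_size      # alias: resolves to 8
```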
/example_configs/config_mobileunetr_xs.yaml:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | # Commonly Used Variables
3 | ##############################################################
4 | variables:
5 | batch_size: &batch_size 8
6 | num_channels: &num_channels 3
7 | num_classes: &num_classes 1
8 | num_epochs: &num_epochs 400
9 |
10 | ##############################################################
11 | # Wandb Model Tracking
12 | ##############################################################
13 | project: mobileunetr
14 | wandb_parameters:
15 | group: isic_2016
16 | name: exp_2_dice_b8_a2
17 | mode: "online"
18 | resume: False
19 | tags: ["tr16ts16", "dice", "xs", "adam"]
20 |
21 | ##############################################################
22 | # Model Hyper-Parameters
23 | ##############################################################
24 | model_name: mobileunetr_xs
25 | model_parameters:
26 | encoder: None
27 | bottle_neck:
28 | dims: [144]
29 | depths: [3]
30 | expansion: 4
31 | kernel_size: 3
32 | patch_size: [2,2]
33 | channels: [96, 128, 128]
34 | decoder:
35 | dims: [96, 120, 144]
36 | channels: [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 128, 128, 384]
37 | num_classes: 1
38 | image_size: 512
39 |
40 | ##############################################################
41 | # Loss Function
42 | ##############################################################
43 | loss_fn:
44 | loss_type: "dice"
45 | loss_args: None
46 |
47 | ##############################################################
48 | # Metrics
49 | ##############################################################
50 | metrics:
51 | type: "binary"
52 | mean_iou:
53 | enabled: True
54 | mean_iou_args:
55 | include_background: True
56 | reduction: "mean"
57 | get_not_nans: False
58 | ignore_empty: True
59 | dice:
60 | enabled: False
61 | dice_args:
62 | include_background: True
63 | reduction: "mean"
64 | get_not_nans: False
65 | ignore_empty: True
66 | num_classes: *num_classes
67 |
68 | ##############################################################
69 | # Optimizers
70 | ##############################################################
71 | optimizer:
72 | optimizer_type: "adamw"
73 | optimizer_args:
74 | lr: 0.00006
75 | weight_decay: 0.01
76 |
77 | ##############################################################
78 | # Learning Rate Schedulers
79 | ##############################################################
80 | warmup_scheduler:
81 | enabled: True # should be always true
82 | warmup_epochs: 30
83 |
84 | # train scheduler
85 | train_scheduler:
86 | scheduler_type: 'cosine_annealing_wr'
87 | scheduler_args:
88 | t_0_epochs: 600
89 | t_mult: 1
90 | min_lr: 0.000006
91 |
92 | ##############################################################
93 | # EMA (Exponential Moving Average)
94 | ##############################################################
95 | ema:
96 | enabled: False
97 | ema_decay: 0.999
98 | print_ema_every: 20
99 |
100 | ##############################################################
101 | # Gradient Clipping
102 | ##############################################################
103 | clip_gradients:
104 | enabled: False
105 | clip_gradients_value: 0.1
106 |
107 | ##############################################################
108 | # Training Hyperparameters
109 | ##############################################################
110 | training_parameters:
111 | seed: 42
112 | num_epochs: 1000
113 | cutoff_epoch: 600
114 | load_optimizer: False
115 | print_every: 1500
116 | calculate_metrics: True
117 | grad_accumulate_steps: 1 # default: 1
118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2"
119 | load_checkpoint: # not implemented yet
120 | load_full_checkpoint: False
121 | load_model_only: False
122 | load_checkpoint_path: None
123 |
124 | ##############################################################
125 | # Dataset and Dataloader Args
126 | ##############################################################
127 | dataset_parameters:
128 | dataset_type: "isic_albumentation_v2"
129 | train_dataset_args:
130 | data_path: "../../../data/train_2016.csv"
131 | train: True
132 | image_size: [512, 512]
133 | val_dataset_args:
134 | data_path: "../../../data/test_2016.csv"
135 | train: False
136 | image_size: [512, 512]
137 | test_dataset_args:
138 | data_path: "../../../data/test_2016.csv"
139 | train: False
140 | image_size: [512, 512]
141 | train_dataloader_args:
142 | batch_size: *batch_size
143 | shuffle: True
144 | num_workers: 4
145 | drop_last: True
146 | pin_memory: True
147 | val_dataloader_args:
148 | batch_size: *batch_size
149 | shuffle: False
150 | num_workers: 4
151 | drop_last: False
152 | pin_memory: True
153 | test_dataloader_args:
154 | batch_size: 1
155 | shuffle: False
156 | num_workers: 4
157 | drop_last: False
158 | pin_memory: True
159 |
160 | ##############################################################
161 | # Data Augmentation Args
162 | ##############################################################
163 | train_augmentation_args:
164 | mean: [0.485, 0.456, 0.406]
165 | std: [0.229, 0.224, 0.225]
166 | image_size: [512, 512] # [H, W]
167 |
168 | test_augmentation_args:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 | image_size: [512, 512] # [H, W]
--------------------------------------------------------------------------------
/example_configs/config_mobileunetr_xxs.yaml:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | # Commonly Used Variables
3 | ##############################################################
4 | variables:
5 | batch_size: &batch_size 8
6 | num_channels: &num_channels 3
7 | num_classes: &num_classes 1
8 | num_epochs: &num_epochs 400
9 |
10 | ##############################################################
11 | # Wandb Model Tracking
12 | ##############################################################
13 | project: mobileunetr
14 | wandb_parameters:
15 | group: isic_2016
16 | name: exp_2_dice_b8_a2
17 | mode: "online"
18 | resume: False
19 | tags: ["tr16ts16", "dice", "xxs", "adam"]
20 |
21 | ##############################################################
22 | # Model Hyper-Parameters
23 | ##############################################################
24 | model_name: mobileunetr_xxs
25 | model_parameters:
26 | encoder: None
27 | bottle_neck:
28 | dims: [96]
29 | depths: [3]
30 | expansion: 4
31 | kernel_size: 3
32 | patch_size: [2,2]
33 | channels: [80, 96, 96]
34 | decoder:
35 | dims: [64, 80, 96]
36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320]
37 | num_classes: 1
38 | image_size: 512
39 |
40 | ##############################################################
41 | # Loss Function
42 | ##############################################################
43 | loss_fn:
44 | loss_type: "dice"
45 | loss_args: None
46 |
47 | ##############################################################
48 | # Metrics
49 | ##############################################################
50 | metrics:
51 | type: "binary"
52 | mean_iou:
53 | enabled: True
54 | mean_iou_args:
55 | include_background: True
56 | reduction: "mean"
57 | get_not_nans: False
58 | ignore_empty: True
59 | dice:
60 | enabled: False
61 | dice_args:
62 | include_background: True
63 | reduction: "mean"
64 | get_not_nans: False
65 | ignore_empty: True
66 | num_classes: *num_classes
67 |
68 | ##############################################################
69 | # Optimizers
70 | ##############################################################
71 | optimizer:
72 | optimizer_type: "adamw"
73 | optimizer_args:
74 | lr: 0.00006
75 | weight_decay: 0.01
76 |
77 | ##############################################################
78 | # Learning Rate Schedulers
79 | ##############################################################
80 | warmup_scheduler:
81 | enabled: True # should be always true
82 | warmup_epochs: 30
83 |
84 | # train scheduler
85 | train_scheduler:
86 | scheduler_type: 'cosine_annealing_wr'
87 | scheduler_args:
88 | t_0_epochs: 600
89 | t_mult: 1
90 | min_lr: 0.000006
91 |
92 | ##############################################################
93 | # EMA (Exponential Moving Average)
94 | ##############################################################
95 | ema:
96 | enabled: False
97 | ema_decay: 0.999
98 | print_ema_every: 20
99 |
100 | ##############################################################
101 | # Gradient Clipping
102 | ##############################################################
103 | clip_gradients:
104 | enabled: False
105 | clip_gradients_value: 0.1
106 |
107 | ##############################################################
108 | # Training Hyperparameters
109 | ##############################################################
110 | training_parameters:
111 | seed: 42
112 | num_epochs: 1000
113 | cutoff_epoch: 600
114 | load_optimizer: False
115 | print_every: 1500
116 | calculate_metrics: True
117 | grad_accumulate_steps: 1 # default: 1
118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2"
119 | load_checkpoint: # not implemented yet
120 | load_full_checkpoint: False
121 | load_model_only: False
122 | load_checkpoint_path: None
123 |
124 | ##############################################################
125 | # Dataset and Dataloader Args
126 | ##############################################################
127 | dataset_parameters:
128 | dataset_type: "isic_albumentation_v2"
129 | train_dataset_args:
130 | data_path: "../../../data/train_2016.csv"
131 | train: True
132 | image_size: [512, 512]
133 | val_dataset_args:
134 | data_path: "../../../data/test_2016.csv"
135 | train: False
136 | image_size: [512, 512]
137 | test_dataset_args:
138 | data_path: "../../../data/test_2016.csv"
139 | train: False
140 | image_size: [512, 512]
141 | train_dataloader_args:
142 | batch_size: *batch_size
143 | shuffle: True
144 | num_workers: 4
145 | drop_last: True
146 | pin_memory: True
147 | val_dataloader_args:
148 | batch_size: *batch_size
149 | shuffle: False
150 | num_workers: 4
151 | drop_last: False
152 | pin_memory: True
153 | test_dataloader_args:
154 | batch_size: 1
155 | shuffle: False
156 | num_workers: 4
157 | drop_last: False
158 | pin_memory: True
159 |
160 | ##############################################################
161 | # Data Augmentation Args
162 | ##############################################################
163 | train_augmentation_args:
164 | mean: [0.485, 0.456, 0.406]
165 | std: [0.229, 0.224, 0.225]
166 | image_size: [512, 512] # [H, W]
167 |
168 | test_augmentation_args:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 | image_size: [512, 512] # [H, W]
--------------------------------------------------------------------------------
/experiments_medical/isic_2016/exp_2_dice_b8_a2/config.yaml:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | # Commonly Used Variables
3 | ##############################################################
4 | variables:
5 | batch_size: &batch_size 8
6 | num_channels: &num_channels 3
7 | num_classes: &num_classes 1
8 | num_epochs: &num_epochs 400
9 |
10 | ##############################################################
11 | # Wandb Model Tracking
12 | ##############################################################
13 | project: mobileunetr
14 | wandb_parameters:
15 | group: isic_2016
16 | name: exp_2_dice_b8_a2
17 | mode: "online"
18 | resume: False
19 | tags: ["tr16ts16", "dice", "xxs", "adam"]
20 |
21 | ##############################################################
22 | # Model Hyper-Parameters
23 | ##############################################################
24 | model_name: mobileunetr_xxs
25 | model_parameters:
26 | encoder: None
27 | bottle_neck:
28 | dims: [96]
29 | depths: [3]
30 | expansion: 4
31 | kernel_size: 3
32 | patch_size: [2,2]
33 | channels: [80, 96, 96]
34 | decoder:
35 | dims: [64, 80, 96]
36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320]
37 | num_classes: 1
38 | image_size: 512
39 |
40 | ##############################################################
41 | # Loss Function
42 | ##############################################################
43 | loss_fn:
44 | loss_type: "dice"
45 | loss_args: None
46 |
47 | ##############################################################
48 | # Metrics
49 | ##############################################################
50 | metrics:
51 | type: "binary"
52 | mean_iou:
53 | enabled: True
54 | mean_iou_args:
55 | include_background: True
56 | reduction: "mean"
57 | get_not_nans: False
58 | ignore_empty: True
59 | dice:
60 | enabled: False
61 | dice_args:
62 | include_background: True
63 | reduction: "mean"
64 | get_not_nans: False
65 | ignore_empty: True
66 | num_classes: *num_classes
67 |
68 | ##############################################################
69 | # Optimizers
70 | ##############################################################
71 | optimizer:
72 | optimizer_type: "adamw"
73 | optimizer_args:
74 | lr: 0.00006
75 | weight_decay: 0.01
76 |
77 | ##############################################################
78 | # Learning Rate Schedulers
79 | ##############################################################
80 | warmup_scheduler:
81 | enabled: True # should be always true
82 | warmup_epochs: 30
83 |
84 | # train scheduler
85 | train_scheduler:
86 | scheduler_type: 'cosine_annealing_wr'
87 | scheduler_args:
88 | t_0_epochs: 600
89 | t_mult: 1
90 | min_lr: 0.000006
91 |
92 | ##############################################################
93 | # EMA (Exponential Moving Average)
94 | ##############################################################
95 | ema:
96 | enabled: False
97 | ema_decay: 0.999
98 | print_ema_every: 20
99 |
100 | ##############################################################
101 | # Gradient Clipping
102 | ##############################################################
103 | clip_gradients:
104 | enabled: False
105 | clip_gradients_value: 0.1
106 |
107 | ##############################################################
108 | # Training Hyperparameters
109 | ##############################################################
110 | training_parameters:
111 | seed: 42
112 | num_epochs: 1000
113 | cutoff_epoch: 600
114 | load_optimizer: False
115 | print_every: 1500
116 | calculate_metrics: True
117 | grad_accumulate_steps: 1 # default: 1
118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2"
119 | load_checkpoint: # not implemented yet
120 | load_full_checkpoint: False
121 | load_model_only: False
122 | load_checkpoint_path: None
123 |
124 | ##############################################################
125 | # Dataset and Dataloader Args
126 | ##############################################################
127 | dataset_parameters:
128 | dataset_type: "isic_albumentation_v2"
129 | train_dataset_args:
130 | data_path: "../../../data/train_2016.csv"
131 | train: True
132 | image_size: [512, 512]
133 | val_dataset_args:
134 | data_path: "../../../data/test_2016.csv"
135 | train: False
136 | image_size: [512, 512]
137 | test_dataset_args:
138 | data_path: "../../../data/test_2016.csv"
139 | train: False
140 | image_size: [512, 512]
141 | train_dataloader_args:
142 | batch_size: *batch_size
143 | shuffle: True
144 | num_workers: 4
145 | drop_last: True
146 | pin_memory: True
147 | val_dataloader_args:
148 | batch_size: *batch_size
149 | shuffle: False
150 | num_workers: 4
151 | drop_last: False
152 | pin_memory: True
153 | test_dataloader_args:
154 | batch_size: 1
155 | shuffle: False
156 | num_workers: 4
157 | drop_last: False
158 | pin_memory: True
159 |
160 | ##############################################################
161 | # Data Augmentation Args
162 | ##############################################################
163 | train_augmentation_args:
164 | mean: [0.485, 0.456, 0.406]
165 | std: [0.229, 0.224, 0.225]
166 | image_size: [512, 512] # [H, W]
167 |
168 | test_augmentation_args:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 | image_size: [512, 512] # [H, W]
--------------------------------------------------------------------------------
/experiments_medical/isic_2016/exp_2_dice_b8_a2/generate_images.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import numpy as np\n",
11 | "from PIL import Image\n",
12 | "import matplotlib.image as mpimg\n",
13 | "import matplotlib.pyplot as plt"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "img_1 = [\n",
23 | " \"predictions/rgb/image_8.png\",\n",
24 | " \"predictions/gt/gt_8.png\",\n",
25 | " \"predictions/pred/pred_8.png\",\n",
26 | " None,\n",
27 | "]\n",
28 | "img_2 = [\n",
29 | " \"predictions/rgb/image_32.png\",\n",
30 | " \"predictions/gt/gt_32.png\",\n",
31 | " \"predictions/pred/pred_32.png\",\n",
32 | " None,\n",
33 | "]\n",
34 | "img_3 = [\n",
35 | " \"predictions/rgb/image_14.png\",\n",
36 | " \"predictions/gt/gt_14.png\",\n",
37 | " \"predictions/pred/pred_14.png\",\n",
38 | " None,\n",
39 | "]\n",
40 | "img_4 = [\n",
41 | " \"predictions/rgb/image_28.png\",\n",
42 | " \"predictions/gt/gt_28.png\",\n",
43 | " \"predictions/pred/pred_28.png\",\n",
44 | " None,\n",
45 | "]\n",
46 | "img_5 = [\n",
47 | " \"predictions/rgb/image_13.png\",\n",
48 | " \"predictions/gt/gt_13.png\",\n",
49 | " \"predictions/pred/pred_13.png\",\n",
50 | " None,\n",
51 | "]\n",
52 | "\n",
53 | "read_rgb = lambda x: np.array(Image.open(x).convert(\"RGB\"))\n",
54 | "read_gray = lambda x: np.array(Image.open(x).convert(\"L\"))\n",
55 | "\n",
56 | "data = [img_1, img_2, img_3, img_4, img_5]"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "import numpy as np\n",
66 | "import cv2\n",
67 | "from PIL import Image, ImageDraw\n",
68 | "\n",
69 | "\n",
70 | "# def create_overlay_image(original_img, ground_truth_mask, predicted_mask):\n",
71 | "# # Convert binary masks to uint8 format\n",
72 | "# ground_truth_mask = ground_truth_mask.astype(np.uint8) * 255\n",
73 | "# predicted_mask = predicted_mask.astype(np.uint8) * 255\n",
74 | "\n",
75 | "# # Create an RGB version of the original image\n",
76 | "# original_rgb = original_img\n",
77 | "\n",
78 | "# # Create an alpha channel for overlaying masks\n",
79 | "# alpha_channel = np.zeros_like(original_img)\n",
80 | "\n",
81 | "# # Set red for ground truth and blue for predicted mask in alpha channel\n",
82 | "# alpha_channel[ground_truth_mask > 0] = [255, 0, 0]\n",
83 | "# alpha_channel[predicted_mask > 0] = [0, 0, 255]\n",
84 | "\n",
85 | "# # Combine original image and masks using alpha blending\n",
86 | "# overlayed_img = cv2.addWeighted(original_rgb, 0.7, alpha_channel, 0.3, 0)\n",
87 | "\n",
88 | "# return overlayed_img\n",
89 | "\n",
90 | "import numpy as np\n",
91 | "import matplotlib.pyplot as plt\n",
92 | "from skimage import measure, color\n",
93 | "\n",
94 | "def create_overlay_image(image, ground_truth_mask, predicted_mask):\n",
95 | " \"\"\"\n",
96 | " Overlays the boundaries of the ground truth mask and the predicted mask on the image.\n",
97 | " \n",
98 | " Parameters:\n",
99 | " - image: numpy array, the original image\n",
100 | " - ground_truth_mask: numpy array, binary ground truth mask\n",
101 | " - predicted_mask: numpy array, binary predicted mask\n",
102 | " \n",
103 | " Returns:\n",
104 | " - overlay_image: numpy array, the original image with overlaid boundaries\n",
105 | " \"\"\"\n",
106 | " def generate_boundary(mask):\n",
107 | " # Ensure the mask is binary\n",
108 | " mask = mask.astype(bool)\n",
109 | " \n",
110 | " # Find contours at a constant value of 0.5\n",
111 | " contours = measure.find_contours(mask, 0.5)\n",
112 | " \n",
113 | " # Create an empty image to draw the boundaries\n",
114 | " boundary = np.zeros_like(mask, dtype=np.uint8)\n",
115 | " \n",
116 | " # Draw the contours on the boundary image\n",
117 | " for contour in contours:\n",
118 | " for y, x in contour:\n",
119 | " boundary[int(y), int(x)] = 1\n",
120 | " \n",
121 | " return boundary\n",
122 | " \n",
123 | " # Generate boundaries\n",
124 | " ground_truth_boundary = generate_boundary(ground_truth_mask)\n",
125 | " predicted_boundary = generate_boundary(predicted_mask)\n",
126 | " \n",
127 | " # Convert the original image to RGB if it's grayscale\n",
128 | " if len(image.shape) == 2 or image.shape[2] == 1:\n",
129 | " image = color.gray2rgb(image)\n",
130 | " \n",
131 | " # Create a copy of the image to overlay the boundaries\n",
132 | " overlay_image = image.copy()\n",
133 | " \n",
134 | " # Define colors for boundaries\n",
135 | " ground_truth_color = [255, 0, 0] # Red\n",
136 | " predicted_color = [0, 0, 255] # Blue\n",
137 | " \n",
138 | " # Overlay the ground truth boundary\n",
139 | " overlay_image[ground_truth_boundary == 1] = ground_truth_color\n",
140 | " \n",
141 | " # Overlay the predicted boundary\n",
142 | " overlay_image[predicted_boundary == 1] = predicted_color\n",
143 | " \n",
144 | " return overlay_image\n",
145 | "\n",
146 | "\n",
147 | "import matplotlib.pyplot as plt\n",
148 | "import matplotlib.gridspec as gridspec\n",
149 | "from matplotlib.ticker import MultipleLocator\n",
150 | "from PIL import Image\n",
151 | "\n",
152 | "# Assuming 'data' is your list of lists\n",
153 | "# Each inner list should contain [image, ground truth, prediction]\n",
154 | "\n",
155 | "# Create a grid layout\n",
156 | "rows = 4\n",
157 | "cols = len(data)\n",
158 | "fig = plt.figure(figsize=(cols * 2, rows * 2))\n",
159 | "gs = gridspec.GridSpec(rows, cols, wspace=0.01, hspace=0.01)\n",
160 | "\n",
161 | "# Plot each image, ground truth, and prediction\n",
162 | "for col, sample in enumerate(data):\n",
163 | " for row, img_type in enumerate(sample):\n",
164 | "\n",
165 | " if row != 3:\n",
166 | " ax = plt.subplot(gs[row, col])\n",
167 | " img = Image.open(img_type)\n",
168 | "\n",
169 | " if row == 3:\n",
170 | " img = create_overlay_image(\n",
171 | " np.array(Image.open(sample[0])),\n",
172 | " np.array(Image.open(sample[1]).convert(\"L\")),\n",
173 | " np.array(Image.open(sample[2]).convert(\"L\")),\n",
174 | " )\n",
175 | "\n",
176 | " # Display image\n",
177 | " ax.imshow(img, cmap=\"gray\") # Assuming images are in grayscale\n",
178 | " ax.axis(\"off\")\n",
179 | "\n",
180 | " # Set titles on the leftmost side\n",
181 | " if col == 0 and row != 3:\n",
182 | " ax.text(\n",
183 | " -0.01,\n",
184 | " 0.5,\n",
185 | " f\"({chr(97+row)})\",\n",
186 | " transform=ax.transAxes,\n",
187 | " fontsize=12,\n",
188 | " va=\"center\",\n",
189 | " ha=\"right\",\n",
190 | " )\n",
191 | "\n",
192 | " # # Set titles for each row\n",
193 | " # if row == 0:\n",
194 | " # ax.set_title('Image')\n",
195 | " # elif row == 1:\n",
196 | " # ax.set_title('Ground Truth')\n",
197 | " # elif row == 2:\n",
198 | " # ax.set_title('Prediction')\n",
199 | "\n",
200 | "# Adjust layout\n",
201 | "plt.tight_layout()\n",
202 | "plt.savefig(\"combined_2016_new.png\", bbox_inches=\"tight\")\n",
203 | "plt.show()"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "# def plot_images(input_list):\n",
213 | "# num_images = len(input_list)\n",
214 | "\n",
215 | "# # Set up the subplots\n",
216 | "# fig, axes = plt.subplots(3, num_images, squeeze=True)\n",
217 | "\n",
218 | "# # Set titles on the left side\n",
219 | "# titles = [\"Image\", \"Mask\", \"Prediction\"]\n",
220 | "# for i, title in enumerate(titles):\n",
221 | "# fig.text(\n",
222 | "# 0.03,\n",
223 | "# 0.5,\n",
224 | "# f\"{title}\",\n",
225 | "# va=\"center\",\n",
226 | "# ha=\"center\",\n",
227 | "# rotation=\"vertical\",\n",
228 | "# fontsize=9,\n",
229 | "# )\n",
230 | "\n",
231 | "# for i, input_paths in enumerate(input_list):\n",
232 | "# img_path, mask_path, pred_path = input_paths\n",
233 | "\n",
234 | "# # Load images\n",
235 | "# img = mpimg.imread(img_path)\n",
236 | "# mask = mpimg.imread(mask_path)\n",
237 | "# pred = mpimg.imread(pred_path)\n",
238 | "\n",
239 | "# axes[2, i].set_aspect(\"equal\")\n",
240 | "\n",
241 | "# # Plot the images\n",
242 | "# axes[0, i].imshow(img)\n",
243 | "# axes[1, i].imshow(mask, cmap=\"gray\")\n",
244 | "# axes[2, i].imshow(pred, cmap=\"gray\")\n",
245 | "\n",
246 | "# # Turn off axis labels\n",
247 | "# for ax in axes[:, i]:\n",
248 | "# ax.axis(\"off\")\n",
249 | "\n",
250 | "# plt.tight_layout(h_pad=0, w_pad=0, pad=0.0001)\n",
251 | "# plt.subplots_adjust(wspace=0.01, hspace=0.01)\n",
252 | "# plt.savefig(\"png_2016.png\")\n",
253 | "# plt.show()"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "# plot_images(data)"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": null,
268 | "metadata": {},
269 | "outputs": [],
270 | "source": [
271 | "# from PIL import Image, ImageDraw\n",
272 | "# import math\n",
273 | "\n",
274 | "\n",
275 | "# def combine_images_with_masks(image_data, output_path):\n",
276 | "# num_images = len(image_data)\n",
277 | "# cols = num_images # Number of columns in the output image grid\n",
278 | "# rows = 3 # Number of rows (images, masks, overlays)\n",
279 | "\n",
280 | "# # Create a blank canvas for the combined image\n",
281 | "# combined_width = cols * 200 # Adjust the width of each image as needed\n",
282 | "# combined_height = rows * 200 # Adjust the height of each image as needed\n",
283 | "# combined_image = Image.new(\"RGB\", (combined_width, combined_height), color=\"white\")\n",
284 | "\n",
285 | "# # Paste each image, mask, and overlay onto the canvas in separate rows\n",
286 | "# for i in range(num_images):\n",
287 | "# image_path, mask_path, overlay_path = image_data[i]\n",
288 | "\n",
289 | "# # Load image, mask, and overlay\n",
290 | "# img = Image.open(image_path)\n",
291 | "# mask = Image.open(mask_path).convert(\"L\") # Convert to grayscale if needed\n",
292 | "# overlay = Image.open(overlay_path).convert(\"L\")\n",
293 | "\n",
294 | "# # Resize each image, mask, and overlay as needed\n",
295 | "# img = img.resize((200, 200))\n",
296 | "# mask = mask.resize((200, 200))\n",
297 | "# overlay = overlay.resize((200, 200))\n",
298 | "\n",
299 | "# # Paste image onto the canvas (row 1)\n",
300 | "# x = i * 200\n",
301 | "# y = 0\n",
302 | "# combined_image.paste(img, (x, y))\n",
303 | "\n",
304 | "# # Paste mask onto the canvas (row 2)\n",
305 | "# y = 200\n",
306 | "# combined_image.paste(mask, (x, y))\n",
307 | "\n",
308 | "# # Paste overlay onto the canvas (row 3)\n",
309 | "# y = 400\n",
310 | "# combined_image.paste(overlay, (x, y))\n",
311 | "\n",
312 | "# # Save or display the combined image\n",
313 | "# combined_image.save(output_path)\n",
314 | "# combined_image.show()\n",
315 | "\n",
316 | "\n",
317 | "# output_image_path = \"combined_2016.jpg\"\n",
318 | "# combine_images_with_masks(data, output_image_path)"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": null,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": []
327 | }
328 | ],
329 | "metadata": {
330 | "kernelspec": {
331 | "display_name": "corev2",
332 | "language": "python",
333 | "name": "python3"
334 | },
335 | "language_info": {
336 | "codemirror_mode": {
337 | "name": "ipython",
338 | "version": 3
339 | },
340 | "file_extension": ".py",
341 | "mimetype": "text/x-python",
342 | "name": "python",
343 | "nbconvert_exporter": "python",
344 | "pygments_lexer": "ipython3",
345 | "version": "3.11.7"
346 | }
347 | },
348 | "nbformat": 4,
349 | "nbformat_minor": 2
350 | }
351 |
--------------------------------------------------------------------------------
/experiments_medical/isic_2016/exp_2_dice_b8_a2/inference_model_weights/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/isic_2016/exp_2_dice_b8_a2/inference_model_weights/.gitignore
--------------------------------------------------------------------------------
/experiments_medical/isic_2016/exp_2_dice_b8_a2/metrics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "/opt/anaconda3/envs/core/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13 | " from .autonotebook import tqdm as notebook_tqdm\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "import os\n",
19 | "import sys\n",
20 | "\n",
21 | "sys.path.append(\"../../../\")\n",
22 | "\n",
23 | "import yaml\n",
24 | "import torch\n",
25 | "import numpy as np\n",
26 | "from typing import Dict\n",
27 | "from architectures.build_architecture import build_architecture\n",
28 | "from dataloaders.build_dataset import build_dataset\n",
29 | "from typing import Tuple, Dict\n",
30 | "from fvcore.nn import FlopCountAnalysis\n",
31 | "from tqdm.notebook import tqdm\n",
32 | "from sklearn.metrics import (\n",
33 | " jaccard_score,\n",
34 | " accuracy_score,\n",
35 | " confusion_matrix,\n",
36 | ")\n",
37 | "import monai"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "Load Config File"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "def load_config(config_path: str) -> Dict:\n",
54 | " \"\"\"loads the yaml config file\n",
55 | "\n",
56 | " Args:\n",
57 | " config_path (str): _description_\n",
58 | "\n",
59 | " Returns:\n",
60 | " Dict: _description_\n",
61 | " \"\"\"\n",
62 | " with open(config_path, \"r\") as file:\n",
63 | " config = yaml.safe_load(file)\n",
64 | " return config\n",
65 | "\n",
66 | "\n",
67 | "config = load_config(\"config.yaml\")"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "Build Dataset and DataLoaders"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "# build validation dataset & validataion data loader\n",
84 | "testset = build_dataset(\n",
85 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n",
86 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n",
87 | " augmentation_args=config[\"test_augmentation_args\"],\n",
88 | ")\n",
89 | "\n",
90 | "testloader = torch.utils.data.DataLoader(\n",
91 | " testset, batch_size=1, shuffle=False, num_workers=1\n",
92 | ")"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "Build Model and Load Trained Weights"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 12,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "model = build_architecture(config=config)\n",
109 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n",
110 | "model.load_state_dict(checkpoint)\n",
111 | "model = model.to(\"cpu\")\n",
112 | "model = model.eval()"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "Model Complexity"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 14,
125 | "metadata": {},
126 | "outputs": [
127 | {
128 | "name": "stdout",
129 | "output_type": "stream",
130 | "text": [
131 | "Computational complexity: 1.3 GMac\n",
132 | "Number of parameters: 3.01 M \n"
133 | ]
134 | }
135 | ],
136 | "source": [
137 | "import torchvision.models as models\n",
138 | "import torch\n",
139 | "from ptflops import get_model_complexity_info\n",
140 | "\n",
141 | "with torch.cuda.device(0):\n",
142 | " net = model\n",
143 | " macs, params = get_model_complexity_info(\n",
144 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n",
145 | " )\n",
146 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n",
147 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 15,
153 | "metadata": {},
154 | "outputs": [
155 | {
156 | "name": "stderr",
157 | "output_type": "stream",
158 | "text": [
159 | "Unsupported operator aten::silu encountered 84 time(s)\n",
160 | "Unsupported operator aten::add encountered 45 time(s)\n",
161 | "Unsupported operator aten::div encountered 15 time(s)\n",
162 | "Unsupported operator aten::ceil encountered 6 time(s)\n",
163 | "Unsupported operator aten::mul encountered 118 time(s)\n",
164 | "Unsupported operator aten::softmax encountered 21 time(s)\n",
165 | "Unsupported operator aten::clone encountered 4 time(s)\n",
166 | "Unsupported operator aten::mul_ encountered 24 time(s)\n",
167 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n",
168 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
169 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n"
170 | ]
171 | },
172 | {
173 | "name": "stdout",
174 | "output_type": "stream",
175 | "text": [
176 | "Computational complexity: 3.01 \n",
177 | "Number of parameters: 1.24 \n"
178 | ]
179 | }
180 | ],
181 | "source": [
182 | "def flop_count_analysis(\n",
183 | " model: torch.nn.Module,\n",
184 | " input_dim: Tuple,\n",
185 | ") -> Dict:\n",
186 | " \"\"\"_summary_\n",
187 | "\n",
188 | " Args:\n",
189 | " input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))\n",
190 | " model (torch.nn.Module): _description_\n",
191 | " \"\"\"\n",
192 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
193 | " input_tensor = torch.ones(()).new_empty(\n",
194 | " (1, *input_dim),\n",
195 | " dtype=next(model.parameters()).dtype,\n",
196 | " device=next(model.parameters()).device,\n",
197 | " )\n",
198 | " flops = FlopCountAnalysis(model, input_tensor)\n",
199 | " model_flops = flops.total()\n",
200 | " # print(f\"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M\")\n",
201 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n",
202 | "\n",
203 | " out = {\n",
204 | " \"params\": round(trainable_params * 1e-6, 2),\n",
205 | " \"flops\": round(model_flops * 1e-9, 2),\n",
206 | " }\n",
207 | "\n",
208 | " return out\n",
209 | "\n",
210 | "\n",
211 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n",
212 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"params\"]))\n",
213 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"flops\"]))"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "Calculate IoU Metric"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "iou = []\n",
230 | "with torch.no_grad():\n",
231 | " for idx, data in tqdm(enumerate(testloader)):\n",
232 | " image = data[\"image\"].cuda()\n",
233 | " mask = data[\"mask\"].cuda()\n",
234 | " out = model.forward(image)\n",
235 | " out = torch.sigmoid(out)\n",
236 | " out[out < 0.5] = 0\n",
237 | " out[out >= 0.5] = 1\n",
238 | " mean_iou = jaccard_score(\n",
239 | " mask.detach().cpu().numpy().ravel(),\n",
240 | " out.detach().cpu().numpy().ravel(),\n",
241 | " average=\"binary\",\n",
242 | " pos_label=1,\n",
243 | " )\n",
244 | " iou.append(mean_iou.item())\n",
245 | "\n",
246 | "print(f\"test iou: {np.mean(iou)}\")"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | "Accuracy"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "accuracy = []\n",
263 | "with torch.no_grad():\n",
264 | " for idx, data in tqdm(enumerate(testloader)):\n",
265 | " image = data[\"image\"].cuda()\n",
266 | " mask = data[\"mask\"].cuda()\n",
267 | " out = model.forward(image)\n",
268 | " out = torch.sigmoid(out)\n",
269 | " out[out < 0.5] = 0\n",
270 | " out[out >= 0.5] = 1\n",
271 | " acc = accuracy_score(\n",
272 | " mask.detach().cpu().numpy().ravel(),\n",
273 | " out.detach().cpu().numpy().ravel(),\n",
274 | " )\n",
275 | " accuracy.append(acc.item())\n",
276 | "\n",
277 | "print(f\"test accuracy: {np.mean(accuracy)}\")"
278 | ]
279 | },
280 | {
281 | "attachments": {},
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "Calculate Dice"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": null,
291 | "metadata": {},
292 | "outputs": [],
293 | "source": [
294 | "dice = []\n",
295 | "with torch.no_grad():\n",
296 | " for idx, data in tqdm(enumerate(testloader)):\n",
297 | " image = data[\"image\"].cuda()\n",
298 | " mask = data[\"mask\"].cuda()\n",
299 | " out = model.forward(image)\n",
300 | " out = torch.sigmoid(out)\n",
301 | " out[out < 0.5] = 0\n",
302 | " out[out >= 0.5] = 1\n",
303 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n",
304 | " dice.append(mean_dice.item())\n",
305 | "\n",
306 | "print(f\"test dice: {np.mean(dice)}\")"
307 | ]
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 | "Calculate Specificity"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "metadata": {},
320 | "outputs": [],
321 | "source": [
322 | "specificity = []\n",
323 | "with torch.no_grad():\n",
324 | " for idx, data in tqdm(enumerate(testloader)):\n",
325 | " image = data[\"image\"].cuda()\n",
326 | " mask = data[\"mask\"].cuda()\n",
327 | " out = model.forward(image)\n",
328 | " out = torch.sigmoid(out)\n",
329 | " out[out < 0.5] = 0\n",
330 | " out[out >= 0.5] = 1\n",
331 | " confusion = confusion_matrix(\n",
332 | " mask.detach().cpu().numpy().ravel(),\n",
333 | " out.detach().cpu().numpy().ravel(),\n",
334 | " )\n",
335 | " if float(confusion[0, 0] + confusion[0, 1]) != 0:\n",
336 | " sp = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])\n",
337 | "\n",
338 | " specificity.append(sp)\n",
339 | "\n",
340 | "print(f\"test specificity: {np.mean(specificity)}\")"
341 | ]
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "metadata": {},
346 | "source": [
347 | "Calculate Sensitivity"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "metadata": {},
354 | "outputs": [],
355 | "source": [
356 | "sensitivity = []\n",
357 | "with torch.no_grad():\n",
358 | " for idx, data in tqdm(enumerate(testloader)):\n",
359 | " image = data[\"image\"].cuda()\n",
360 | " mask = data[\"mask\"].cuda()\n",
361 | " out = model.forward(image)\n",
362 | " out = torch.sigmoid(out)\n",
363 | " out[out < 0.5] = 0\n",
364 | " out[out >= 0.5] = 1\n",
365 | " confusion = confusion_matrix(\n",
366 | " mask.detach().cpu().numpy().ravel(),\n",
367 | " out.detach().cpu().numpy().ravel(),\n",
368 | " )\n",
369 | " if float(confusion[1, 1] + confusion[1, 0]) != 0:\n",
370 | " se = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])\n",
371 | "\n",
372 | " sensitivity.append(se)\n",
373 | "\n",
374 | "print(f\"test sensitivity: {np.mean(sensitivity)}\")"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": null,
380 | "metadata": {},
381 | "outputs": [],
382 | "source": [
383 | "# DONE"
384 | ]
385 | }
386 | ],
387 | "metadata": {
388 | "kernelspec": {
389 | "display_name": "core",
390 | "language": "python",
391 | "name": "python3"
392 | },
393 | "language_info": {
394 | "codemirror_mode": {
395 | "name": "ipython",
396 | "version": 3
397 | },
398 | "file_extension": ".py",
399 | "mimetype": "text/x-python",
400 | "name": "python",
401 | "nbconvert_exporter": "python",
402 | "pygments_lexer": "ipython3",
403 | "version": "3.11.9"
404 | },
405 | "orig_nbformat": 4,
406 | "vscode": {
407 | "interpreter": {
408 | "hash": "db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea"
409 | }
410 | }
411 | },
412 | "nbformat": 4,
413 | "nbformat_minor": 2
414 | }
415 |
--------------------------------------------------------------------------------
/experiments_medical/isic_2016/exp_2_dice_b8_a2/run_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 |
5 | sys.path.append("../../../")
6 |
7 | import yaml
8 | import torch
9 | import argparse
10 | import numpy as np
11 | from typing import Dict
12 | from termcolor import colored
13 | from accelerate import Accelerator
14 | from losses.losses import build_loss_fn
15 | from optimizers.optimizers import build_optimizer
16 | from optimizers.schedulers import build_scheduler
17 | from train_scripts.segmentation_trainer import Segmentation_Trainer
18 | from architectures.build_architecture import build_architecture
19 | from dataloaders.build_dataset import build_dataset, build_dataloader
20 |
21 |
22 | ##################################################################################################
 23 | def launch_experiment(config_path) -> None:
 24 |     """
 25 |     Builds and runs the experiment
 26 |     Args:
 27 |         config_path (str): path to the yaml config file
 28 | 
 29 |     Returns:
 30 |         None
 31 |     """
32 | # load config
33 | config = load_config(config_path)
34 |
35 | # set seed
36 | seed_everything(config)
37 |
38 | # build directories
39 | build_directories(config)
40 |
41 | # build training dataset & training data loader
42 | trainset = build_dataset(
43 | dataset_type=config["dataset_parameters"]["dataset_type"],
44 | dataset_args=config["dataset_parameters"]["train_dataset_args"],
45 | augmentation_args=config["train_augmentation_args"],
46 | )
47 | trainloader = build_dataloader(
48 | dataset=trainset,
49 | dataloader_args=config["dataset_parameters"]["train_dataloader_args"],
50 | config=config,
51 | train=True,
52 | )
53 |
 54 |     # build validation dataset & validation data loader
55 | valset = build_dataset(
56 | dataset_type=config["dataset_parameters"]["dataset_type"],
57 | dataset_args=config["dataset_parameters"]["val_dataset_args"],
58 | augmentation_args=config["test_augmentation_args"],
59 | )
60 | valloader = build_dataloader(
61 | dataset=valset,
62 | dataloader_args=config["dataset_parameters"]["val_dataloader_args"],
63 | config=config,
64 | train=False,
65 | )
66 |
67 | # build the Model
68 | model = build_architecture(config)
69 |
70 | # set up the loss function
71 | criterion = build_loss_fn(
72 | loss_type=config["loss_fn"]["loss_type"],
73 | loss_args=config["loss_fn"]["loss_args"],
74 | )
75 |
76 | # set up the optimizer
77 | optimizer = build_optimizer(
78 | model=model,
79 | optimizer_type=config["optimizer"]["optimizer_type"],
80 | optimizer_args=config["optimizer"]["optimizer_args"],
81 | )
82 |
83 | # set up schedulers
84 | warmup_scheduler = build_scheduler(
85 | optimizer=optimizer, scheduler_type="warmup_scheduler", config=config
86 | )
87 | training_scheduler = build_scheduler(
88 | optimizer=optimizer,
89 | scheduler_type="training_scheduler",
90 | config=config,
91 | )
92 |
 93 |     # use accelerate
94 | accelerator = Accelerator(
95 | log_with="wandb",
96 | gradient_accumulation_steps=config["training_parameters"][
97 | "grad_accumulate_steps"
98 | ],
99 | )
100 | accelerator.init_trackers(
101 | project_name=config["project"],
102 | config=config,
103 | init_kwargs={"wandb": config["wandb_parameters"]},
104 | )
105 |
106 | # display experiment info
107 | display_info(config, accelerator, trainset, valset, model)
108 |
109 | # convert all components to accelerate
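 110 |     # prepare_* wraps each piece so the same script runs on CPU, a single GPU, or multi-GPU DDP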
110 | model = accelerator.prepare_model(model=model)
111 | optimizer = accelerator.prepare_optimizer(optimizer=optimizer)
112 | trainloader = accelerator.prepare_data_loader(data_loader=trainloader)
113 | valloader = accelerator.prepare_data_loader(data_loader=valloader)
114 | warmup_scheduler = accelerator.prepare_scheduler(scheduler=warmup_scheduler)
115 | training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler)
116 |
117 | # create a single dict to hold all parameters
118 | storage = {
119 | "model": model,
120 | "trainloader": trainloader,
121 | "valloader": valloader,
122 | "criterion": criterion,
123 | "optimizer": optimizer,
124 | "warmup_scheduler": warmup_scheduler,
125 | "training_scheduler": training_scheduler,
126 | }
127 |
128 | # set up trainer
129 | trainer = Segmentation_Trainer(
130 | config=config,
131 | model=storage["model"],
132 | optimizer=storage["optimizer"],
133 | criterion=storage["criterion"],
134 | train_dataloader=storage["trainloader"],
135 | val_dataloader=storage["valloader"],
136 | warmup_scheduler=storage["warmup_scheduler"],
137 | training_scheduler=storage["training_scheduler"],
138 | accelerator=accelerator,
139 | )
140 |
141 | # run train
142 | trainer.train()
143 |
144 |
145 | ##################################################################################################
146 | def seed_everything(config) -> None:
147 | seed = config["training_parameters"]["seed"]
148 | os.environ["PYTHONHASHSEED"] = str(seed)
149 | random.seed(seed)
150 | np.random.seed(seed)
151 | torch.manual_seed(seed)
152 | torch.cuda.manual_seed(seed)
153 | torch.cuda.manual_seed_all(seed)
154 | torch.backends.cudnn.deterministic = True
155 | torch.backends.cudnn.benchmark = False
156 |
157 |
158 | ##################################################################################################
159 | def load_config(config_path: str) -> Dict:
160 | """loads the yaml config file
161 |
162 | Args:
 163 |         config_path (str): path to the yaml config file
164 |
165 | Returns:
 166 |         Dict: parsed configuration
167 | """
168 | with open(config_path, "r") as file:
169 | config = yaml.safe_load(file)
170 | return config
171 |
172 |
173 | ##################################################################################################
174 | def build_directories(config: Dict) -> None:
175 | # create necessary directories
176 | if not os.path.exists(config["training_parameters"]["checkpoint_save_dir"]):
177 | os.makedirs(config["training_parameters"]["checkpoint_save_dir"])
178 |
179 | if os.listdir(config["training_parameters"]["checkpoint_save_dir"]):
 180 |         raise ValueError("checkpoint exists -- preventing file override -- rename file")
181 |
182 |
183 | ##################################################################################################
184 | def display_info(config, accelerator, trainset, valset, model):
185 | # print experiment info
186 | accelerator.print(f"-------------------------------------------------------")
187 | accelerator.print(f"[info]: Experiment Info")
188 | accelerator.print(
189 | f"[info] ----- Project: {colored(config['project'], color='red')}"
190 | )
191 | accelerator.print(
192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}"
193 | )
194 | accelerator.print(
195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}"
196 | )
197 | accelerator.print(
198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['val_dataloader_args']['batch_size'], color='red')}"
199 | )
200 | accelerator.print(
201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}"
202 | )
203 | accelerator.print(
204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}"
205 | )
206 | accelerator.print(
207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}"
208 | )
209 | accelerator.print(
210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}"
211 | )
212 | accelerator.print(
213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}"
214 | )
215 |
216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
217 | accelerator.print(
218 | f"[info] ----- Distributed Training: {colored('True' if torch.cuda.device_count() > 1 else 'False', color='red')}"
219 | )
220 | accelerator.print(
 221 |         f"[info] ----- Num Classes: {colored(config['variables']['num_classes'], color='red')}"
222 | )
223 | accelerator.print(
224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}"
225 | )
226 | accelerator.print(
227 | f"[info] ----- Load From Checkpoint: {colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}"
228 | )
229 | accelerator.print(
230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}"
231 | )
232 | accelerator.print(f"-------------------------------------------------------")
233 |
234 |
235 | ##################################################################################################
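 236 | # usage: python run_experiment.py --config config.yaml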
236 | if __name__ == "__main__":
237 | parser = argparse.ArgumentParser(description="Simple example of training script.")
238 | parser.add_argument(
239 | "--config", type=str, default="config.yaml", help="path to yaml config file"
240 | )
241 | args = parser.parse_args()
242 | launch_experiment(args.config)
243 |
--------------------------------------------------------------------------------
/experiments_medical/isic_2016/exp_2_dice_b8_a2/visualize.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "\n",
12 | "sys.path.append(\"../../../\")\n",
13 | "\n",
14 | "import yaml\n",
15 | "import torch\n",
16 | "from typing import Dict\n",
17 | "from architectures.build_architecture import build_architecture\n",
18 | "from dataloaders.build_dataset import build_dataset\n",
19 | "from torchvision.utils import save_image"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 4,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "def load_config(config_path: str) -> Dict:\n",
29 | " \"\"\"loads the yaml config file\n",
30 | "\n",
31 | " Args:\n",
32 | " config_path (str): _description_\n",
33 | "\n",
34 | " Returns:\n",
35 | " Dict: _description_\n",
36 | " \"\"\"\n",
37 | " with open(config_path, \"r\") as file:\n",
38 | " config = yaml.safe_load(file)\n",
39 | " return config"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 5,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "config = load_config(\"config.yaml\")"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "Set Up"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 7,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "model = build_architecture(config=config)\n",
65 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n",
66 | "model.load_state_dict(checkpoint)\n",
67 | "model = model.to(\"cpu\")\n",
68 | "model = model.eval()"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 9,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "# build validation dataset & validataion data loader\n",
78 | "valset = build_dataset(\n",
79 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n",
80 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n",
81 | " augmentation_args=config[\"test_augmentation_args\"],\n",
82 | ")\n",
83 | "testloader = torch.utils.data.DataLoader(\n",
84 | " valset, batch_size=1, shuffle=False, num_workers=1\n",
85 | ")"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "Inference"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 10,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "if not os.path.isdir(\"predictions\"):\n",
102 | " os.makedirs(\"predictions/rgb\")\n",
103 | " os.makedirs(\"predictions/gt\")\n",
104 | " os.makedirs(\"predictions/pred\")"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 12,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "torch.Size([1, 3, 512, 512])\n",
117 | "torch.Size([1, 3, 512, 512])\n",
118 | "torch.Size([1, 3, 512, 512])\n",
119 | "torch.Size([1, 3, 512, 512])\n",
120 | "torch.Size([1, 3, 512, 512])\n",
121 | "torch.Size([1, 3, 512, 512])\n",
122 | "torch.Size([1, 3, 512, 512])\n",
123 | "torch.Size([1, 3, 512, 512])\n",
124 | "torch.Size([1, 3, 512, 512])\n",
125 | "torch.Size([1, 3, 512, 512])\n",
126 | "torch.Size([1, 3, 512, 512])\n",
127 | "torch.Size([1, 3, 512, 512])\n",
128 | "torch.Size([1, 3, 512, 512])\n",
129 | "torch.Size([1, 3, 512, 512])\n",
130 | "torch.Size([1, 3, 512, 512])\n",
131 | "torch.Size([1, 3, 512, 512])\n",
132 | "torch.Size([1, 3, 512, 512])\n",
133 | "torch.Size([1, 3, 512, 512])\n",
134 | "torch.Size([1, 3, 512, 512])\n",
135 | "torch.Size([1, 3, 512, 512])\n",
136 | "torch.Size([1, 3, 512, 512])\n",
137 | "torch.Size([1, 3, 512, 512])\n",
138 | "torch.Size([1, 3, 512, 512])\n",
139 | "torch.Size([1, 3, 512, 512])\n",
140 | "torch.Size([1, 3, 512, 512])\n",
141 | "torch.Size([1, 3, 512, 512])\n",
142 | "torch.Size([1, 3, 512, 512])\n",
143 | "torch.Size([1, 3, 512, 512])\n",
144 | "torch.Size([1, 3, 512, 512])\n",
145 | "torch.Size([1, 3, 512, 512])\n",
146 | "torch.Size([1, 3, 512, 512])\n",
147 | "torch.Size([1, 3, 512, 512])\n",
148 | "torch.Size([1, 3, 512, 512])\n",
149 | "torch.Size([1, 3, 512, 512])\n",
150 | "torch.Size([1, 3, 512, 512])\n",
151 | "torch.Size([1, 3, 512, 512])\n",
152 | "torch.Size([1, 3, 512, 512])\n",
153 | "torch.Size([1, 3, 512, 512])\n",
154 | "torch.Size([1, 3, 512, 512])\n",
155 | "torch.Size([1, 3, 512, 512])\n",
156 | "torch.Size([1, 3, 512, 512])\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "model.eval()\n",
162 | "counter = 40\n",
163 | "for idx, data in enumerate(testloader):\n",
164 | " image = data[\"image\"].cuda()\n",
165 | " mask = data[\"mask\"].cuda()\n",
166 | " out = model.forward(image)\n",
167 | " out = torch.sigmoid(out)\n",
168 | " out[out < 0.5] = 0\n",
169 | " out[out >= 0.5] = 1\n",
170 | "\n",
171 | " # rgb input\n",
172 | " img = data[\"image\"]\n",
173 | " img = img.detach()\n",
174 | " print(img.shape)\n",
175 | " img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n",
176 | " img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n",
177 | " img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n",
178 | " save_image(img, f\"predictions/rgb/image_{idx}.png\")\n",
179 | "\n",
180 | " # prediction\n",
181 | " pred = out.detach()\n",
182 | " pred = pred * 255.0\n",
183 | " save_image(pred, f\"predictions/pred/pred_{idx}.png\")\n",
184 | "\n",
185 | " # ground truth\n",
186 | " gt = data[\"mask\"]\n",
187 | " gt = gt.detach()\n",
188 | " gt = gt * 255.0\n",
189 | " save_image(gt, f\"predictions/gt/gt_{idx}.png\")\n",
190 | "\n",
191 | " if idx == counter:\n",
192 | " break"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {},
198 | "source": [
199 | "Save RGB, GT and Pred"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": 14,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "# # RGB INPUT\n",
209 | "\n",
210 | "# img = data[\"image\"]\n",
211 | "# img = img.detach()\n",
212 | "# print(img.shape)\n",
213 | "# img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n",
214 | "# img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n",
215 | "# img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n",
216 | "# save_image(img, f\"predictions/rgb/image_{idx}.png\")"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 15,
222 | "metadata": {},
223 | "outputs": [],
224 | "source": [
225 | "# # PREDICTION\n",
226 | "\n",
227 | "# pred = out.detach()\n",
228 | "# pred = pred * 255.0\n",
229 | "# save_image(pred, f\"predictions/pred/pred_{idx}.png\")"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": 16,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "# # GROUND TRUTH\n",
239 | "\n",
240 | "# gt = data[1]\n",
241 | "# gt = gt.detach()\n",
242 | "# gt = gt * 255.0\n",
243 | "# save_image(gt, f\"predictions/gt/gt_{idx}.png\")"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {},
250 | "outputs": [],
251 | "source": []
252 | }
253 | ],
254 | "metadata": {
255 | "kernelspec": {
256 | "display_name": "corev2",
257 | "language": "python",
258 | "name": "python3"
259 | },
260 | "language_info": {
261 | "codemirror_mode": {
262 | "name": "ipython",
263 | "version": 3
264 | },
265 | "file_extension": ".py",
266 | "mimetype": "text/x-python",
267 | "name": "python",
268 | "nbconvert_exporter": "python",
269 | "pygments_lexer": "ipython3",
270 | "version": "3.11.7"
271 | }
272 | },
273 | "nbformat": 4,
274 | "nbformat_minor": 2
275 | }
276 |
--------------------------------------------------------------------------------
/experiments_medical/isic_2017/exp_2_dice_b8_a2/config.yaml:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | # Commonly Used Variables
3 | ##############################################################
4 | variables:
5 | batch_size: &batch_size 8
6 | num_channels: &num_channels 3
7 | num_classes: &num_classes 1
8 | num_epochs: &num_epochs 400
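 9 |   # "&name" defines a YAML anchor; "*name" re-uses the value later in this file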
9 |
10 | ##############################################################
11 | # Wandb Model Tracking
12 | ##############################################################
13 | project: mobileunetr
14 | wandb_parameters:
15 | group: isic_2017
16 | name: exp_2_dice_b8_a2
17 | mode: "online"
18 | resume: False
19 | tags: ["tr17ts17", "dice", "xxs", "adam"]
20 |
21 | ##############################################################
22 | # Model Hyper-Parameters
23 | ##############################################################
24 | model_name: mobileunetr_xxs
25 | model_parameters:
26 | encoder: None
27 | bottle_neck:
28 | dims: [96]
29 | depths: [3]
30 | expansion: 4
31 | kernel_size: 3
32 | patch_size: [2,2]
33 | channels: [80, 96, 96]
34 | decoder:
35 | dims: [64, 80, 96]
36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320]
37 | num_classes: 1
38 | image_size: 512
39 |
40 | ##############################################################
41 | # Loss Function
42 | ##############################################################
43 | loss_fn:
44 | loss_type: "dice"
45 | loss_args: None
46 |
47 | ##############################################################
48 | # Metrics
49 | ##############################################################
50 | metrics:
51 | type: "binary"
52 | mean_iou:
53 | enabled: True
54 | mean_iou_args:
55 | include_background: True
56 | reduction: "mean"
57 | get_not_nans: False
58 | ignore_empty: True
59 | dice:
60 | enabled: False
61 | dice_args:
62 | include_background: True
63 | reduction: "mean"
64 | get_not_nans: False
65 | ignore_empty: True
66 | num_classes: *num_classes
67 |
68 | ##############################################################
69 | # Optimizers
70 | ##############################################################
71 | optimizer:
72 | optimizer_type: "adamw"
73 | optimizer_args:
74 | lr: 0.00006
75 | weight_decay: 0.01
76 |
77 | ##############################################################
78 | # Learning Rate Schedulers
79 | ##############################################################
80 | warmup_scheduler:
 81 |   enabled: True # should always remain True
82 | warmup_epochs: 30
83 |
84 | # train scheduler
85 | train_scheduler:
86 | scheduler_type: 'cosine_annealing_wr'
87 | scheduler_args:
88 | t_0_epochs: 600
89 | t_mult: 1
90 | min_lr: 0.000006
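 91 | # warmup runs for warmup_epochs, then cosine annealing with warm restarts
 92 | # (cycle length t_0_epochs, multiplier t_mult) decays the lr toward min_lr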
91 |
92 | ##############################################################
93 | # EMA (Exponential Moving Average)
94 | ##############################################################
95 | ema:
96 | enabled: False
97 | ema_decay: 0.999
98 | print_ema_every: 20
99 |
100 | ##############################################################
101 | # Gradient Clipping
102 | ##############################################################
103 | clip_gradients:
104 | enabled: False
105 | clip_gradients_value: 0.1
106 |
107 | ##############################################################
108 | # Training Hyperparameters
109 | ##############################################################
110 | training_parameters:
111 | seed: 42
112 | num_epochs: 1000
113 | cutoff_epoch: 600
114 | load_optimizer: False
115 | print_every: 1500
116 | calculate_metrics: True
117 | grad_accumulate_steps: 1 # default: 1
118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2"
119 | load_checkpoint: # not implemented yet
120 | load_full_checkpoint: False
121 | load_model_only: False
122 | load_checkpoint_path: None
123 |
124 | ##############################################################
125 | # Dataset and Dataloader Args
126 | ##############################################################
127 | dataset_parameters:
128 | dataset_type: "isic_albumentation_v2"
129 | train_dataset_args:
130 | data_path: "../../../data/train_2017.csv"
131 | train: True
132 | image_size: [512, 512]
133 | val_dataset_args:
134 | data_path: "../../../data/test_2017.csv"
135 | train: False
136 | image_size: [512, 512]
137 | test_dataset_args:
138 | data_path: "../../../data/test_2017.csv"
139 | train: False
140 | image_size: [512, 512]
141 | train_dataloader_args:
142 | batch_size: *batch_size
143 | shuffle: True
144 | num_workers: 4
145 | drop_last: True
146 | pin_memory: True
147 | val_dataloader_args:
148 | batch_size: *batch_size
149 | shuffle: False
150 | num_workers: 4
151 | drop_last: False
152 | pin_memory: True
153 | test_dataloader_args:
154 | batch_size: 1
155 | shuffle: False
156 | num_workers: 4
157 | drop_last: False
158 | pin_memory: True
159 |
160 | ##############################################################
161 | # Data Augmentation Args
162 | ##############################################################
163 | train_augmentation_args:
164 | mean: [0.485, 0.456, 0.406]
165 | std: [0.229, 0.224, 0.225]
166 | image_size: [512, 512] # [H, W]
167 |
168 | test_augmentation_args:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 | image_size: [512, 512] # [H, W]
--------------------------------------------------------------------------------
/experiments_medical/isic_2017/exp_2_dice_b8_a2/inference_model_weights/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/isic_2017/exp_2_dice_b8_a2/inference_model_weights/.gitignore
--------------------------------------------------------------------------------
/experiments_medical/isic_2017/exp_2_dice_b8_a2/metrics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "\n",
12 | "sys.path.append(\"../../../\")\n",
13 | "\n",
14 | "import yaml\n",
15 | "import torch\n",
16 | "import numpy as np\n",
17 | "from typing import Dict\n",
18 | "from architectures.build_architecture import build_architecture\n",
19 | "from dataloaders.build_dataset import build_dataset\n",
20 | "from typing import Tuple, Dict\n",
21 | "from fvcore.nn import FlopCountAnalysis\n",
22 | "from tqdm.notebook import tqdm\n",
23 | "from sklearn.metrics import (\n",
24 | " jaccard_score,\n",
25 | " accuracy_score,\n",
26 | " confusion_matrix,\n",
27 | ")\n",
28 | "import monai"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "Load Config File"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "def load_config(config_path: str) -> Dict:\n",
45 | " \"\"\"loads the yaml config file\n",
46 | "\n",
47 | " Args:\n",
48 | " config_path (str): _description_\n",
49 | "\n",
50 | " Returns:\n",
51 | " Dict: _description_\n",
52 | " \"\"\"\n",
53 | " with open(config_path, \"r\") as file:\n",
54 | " config = yaml.safe_load(file)\n",
55 | " return config\n",
56 | "\n",
57 | "\n",
58 | "config = load_config(\"config.yaml\")"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "Build Dataset and DataLoaders"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "600\n"
78 | ]
79 | }
80 | ],
81 | "source": [
82 | "# build validation dataset & validataion data loader\n",
83 | "testset = build_dataset(\n",
84 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n",
85 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n",
86 | " augmentation_args=config[\"test_augmentation_args\"],\n",
87 | ")\n",
88 | "\n",
89 | "testloader = torch.utils.data.DataLoader(\n",
90 | " testset, batch_size=1, shuffle=False, num_workers=1\n",
91 | ")\n",
92 | "\n",
93 | "print(len(testset))"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 4,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "model = build_architecture(config=config)\n",
103 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n",
104 | "model.load_state_dict(checkpoint)\n",
105 | "model = model.to(\"cpu\")\n",
106 | "model = model.eval()"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "Model Complexity"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 5,
119 | "metadata": {},
120 | "outputs": [
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "Computational complexity: 1.3 GMac\n",
126 | "Number of parameters: 3.01 M \n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "import torchvision.models as models\n",
132 | "import torch\n",
133 | "from ptflops import get_model_complexity_info\n",
134 | "\n",
135 | "with torch.cuda.device(0):\n",
136 | " net = model\n",
137 | " macs, params = get_model_complexity_info(\n",
138 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n",
139 | " )\n",
140 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n",
141 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 6,
147 | "metadata": {},
148 | "outputs": [
149 | {
150 | "name": "stderr",
151 | "output_type": "stream",
152 | "text": [
153 | "Unsupported operator aten::silu encountered 84 time(s)\n",
154 | "Unsupported operator aten::add encountered 45 time(s)\n",
155 | "Unsupported operator aten::div encountered 15 time(s)\n",
156 | "Unsupported operator aten::ceil encountered 6 time(s)\n",
157 | "Unsupported operator aten::mul encountered 118 time(s)\n",
158 | "Unsupported operator aten::softmax encountered 21 time(s)\n",
159 | "Unsupported operator aten::clone encountered 4 time(s)\n",
160 | "Unsupported operator aten::mul_ encountered 24 time(s)\n",
161 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n",
162 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
163 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n"
164 | ]
165 | },
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "Computational complexity: 3.01 \n",
171 | "Number of parameters: 1.24 \n"
172 | ]
173 | }
174 | ],
175 | "source": [
176 | "def flop_count_analysis(\n",
177 | " model: torch.nn.Module,\n",
178 | " input_dim: Tuple,\n",
179 | ") -> Dict:\n",
180 | " \"\"\"_summary_\n",
181 | "\n",
182 | " Args:\n",
183 | " input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))\n",
184 | " model (torch.nn.Module): _description_\n",
185 | " \"\"\"\n",
186 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
187 | " input_tensor = torch.ones(()).new_empty(\n",
188 | " (1, *input_dim),\n",
189 | " dtype=next(model.parameters()).dtype,\n",
190 | " device=next(model.parameters()).device,\n",
191 | " )\n",
192 | " flops = FlopCountAnalysis(model, input_tensor)\n",
193 | " model_flops = flops.total()\n",
194 | " # print(f\"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M\")\n",
195 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n",
196 | "\n",
197 | " out = {\n",
198 | " \"params\": round(trainable_params * 1e-6, 2),\n",
199 | " \"flops\": round(model_flops * 1e-9, 2),\n",
200 | " }\n",
201 | "\n",
202 | " return out\n",
203 | "\n",
204 | "\n",
205 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n",
206 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"params\"]))\n",
207 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"flops\"]))"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "Calculate IoU Metric"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 7,
220 | "metadata": {},
221 | "outputs": [
222 | {
223 | "data": {
224 | "application/vnd.jupyter.widget-view+json": {
225 | "model_id": "a074d62a4c3d4a14ac6e4f99ca7abe73",
226 | "version_major": 2,
227 | "version_minor": 0
228 | },
229 | "text/plain": [
230 | "0it [00:00, ?it/s]"
231 | ]
232 | },
233 | "metadata": {},
234 | "output_type": "display_data"
235 | },
236 | {
237 | "name": "stdout",
238 | "output_type": "stream",
239 | "text": [
240 | "test iou: 0.7900467007312545\n"
241 | ]
242 | }
243 | ],
244 | "source": [
245 | "iou = []\n",
246 | "with torch.no_grad():\n",
247 | " for idx, data in tqdm(enumerate(testloader)):\n",
248 | " image = data[\"image\"].cuda()\n",
249 | " mask = data[\"mask\"].cuda()\n",
250 | " out = model.forward(image)\n",
251 | " out = torch.sigmoid(out)\n",
252 | " out[out < 0.5] = 0\n",
253 | " out[out >= 0.5] = 1\n",
254 | " mean_iou = jaccard_score(\n",
255 | " mask.detach().cpu().numpy().ravel(),\n",
256 | " out.detach().cpu().numpy().ravel(),\n",
257 | " average=\"binary\",\n",
258 | " pos_label=1,\n",
259 | " )\n",
260 | " iou.append(mean_iou.item())\n",
261 | "\n",
262 | "print(f\"test iou: {np.mean(iou)}\")"
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {},
268 | "source": [
269 | "Accuracy"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 8,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "data": {
279 | "application/vnd.jupyter.widget-view+json": {
280 | "model_id": "63a0af5770784eeabded405e89276b4f",
281 | "version_major": 2,
282 | "version_minor": 0
283 | },
284 | "text/plain": [
285 | "0it [00:00, ?it/s]"
286 | ]
287 | },
288 | "metadata": {},
289 | "output_type": "display_data"
290 | },
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "test accuracy: 0.9445873578389485\n"
296 | ]
297 | }
298 | ],
299 | "source": [
300 | "accuracy = []\n",
301 | "with torch.no_grad():\n",
302 | " for idx, data in tqdm(enumerate(testloader)):\n",
303 | " image = data[\"image\"].cuda()\n",
304 | " mask = data[\"mask\"].cuda()\n",
305 | " out = model.forward(image)\n",
306 | " out = torch.sigmoid(out)\n",
307 | " out[out < 0.5] = 0\n",
308 | " out[out >= 0.5] = 1\n",
309 | " acc = accuracy_score(\n",
310 | " mask.detach().cpu().numpy().ravel(),\n",
311 | " out.detach().cpu().numpy().ravel(),\n",
312 | " )\n",
313 | " accuracy.append(acc.item())\n",
314 | "\n",
315 | "print(f\"test accuracy: {np.mean(accuracy)}\")"
316 | ]
317 | },
318 | {
319 | "attachments": {},
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "Calculate Dice"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 9,
329 | "metadata": {},
330 | "outputs": [
331 | {
332 | "data": {
333 | "application/vnd.jupyter.widget-view+json": {
334 | "model_id": "cf1cb2544a624ec6a99862601b08d61b",
335 | "version_major": 2,
336 | "version_minor": 0
337 | },
338 | "text/plain": [
339 | "0it [00:00, ?it/s]"
340 | ]
341 | },
342 | "metadata": {},
343 | "output_type": "display_data"
344 | },
345 | {
346 | "name": "stdout",
347 | "output_type": "stream",
348 | "text": [
349 | "test dice: 0.8683788800487916\n"
350 | ]
351 | }
352 | ],
353 | "source": [
354 | "dice = []\n",
355 | "with torch.no_grad():\n",
356 | " for idx, data in tqdm(enumerate(testloader)):\n",
357 | " image = data[\"image\"].cuda()\n",
358 | " mask = data[\"mask\"].cuda()\n",
359 | " out = model.forward(image)\n",
360 | " out = torch.sigmoid(out)\n",
361 | " out[out < 0.5] = 0\n",
362 | " out[out >= 0.5] = 1\n",
363 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n",
364 | " dice.append(mean_dice.item())\n",
365 | "\n",
366 | "print(f\"test dice: {np.mean(dice)}\")"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "Calculate Specificity"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": 10,
379 | "metadata": {},
380 | "outputs": [
381 | {
382 | "data": {
383 | "application/vnd.jupyter.widget-view+json": {
384 | "model_id": "f5a40452c5c648a196704e679826d4d5",
385 | "version_major": 2,
386 | "version_minor": 0
387 | },
388 | "text/plain": [
389 | "0it [00:00, ?it/s]"
390 | ]
391 | },
392 | "metadata": {},
393 | "output_type": "display_data"
394 | },
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "test specificity: 0.969295418576087\n"
400 | ]
401 | }
402 | ],
403 | "source": [
404 | "specificity = []\n",
405 | "with torch.no_grad():\n",
406 | " for idx, data in tqdm(enumerate(testloader)):\n",
407 | " image = data[\"image\"].cuda()\n",
408 | " mask = data[\"mask\"].cuda()\n",
409 | " out = model.forward(image)\n",
410 | " out = torch.sigmoid(out)\n",
411 | " out[out < 0.5] = 0\n",
412 | " out[out >= 0.5] = 1\n",
413 | " confusion = confusion_matrix(\n",
414 | " mask.detach().cpu().numpy().ravel(),\n",
415 | " out.detach().cpu().numpy().ravel(),\n",
416 | " )\n",
417 | " if float(confusion[0, 0] + confusion[0, 1]) != 0:\n",
418 | " sp = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])\n",
419 | "\n",
420 | " specificity.append(sp)\n",
421 | "\n",
422 | "print(f\"test specificity: {np.mean(specificity)}\")"
423 | ]
424 | },
425 | {
426 | "cell_type": "markdown",
427 | "metadata": {},
428 | "source": [
429 | "Calculate Sensitivity"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 11,
435 | "metadata": {},
436 | "outputs": [
437 | {
438 | "data": {
439 | "application/vnd.jupyter.widget-view+json": {
440 | "model_id": "f880c143f40c429e9125160dfba20fc7",
441 | "version_major": 2,
442 | "version_minor": 0
443 | },
444 | "text/plain": [
445 | "0it [00:00, ?it/s]"
446 | ]
447 | },
448 | "metadata": {},
449 | "output_type": "display_data"
450 | },
451 | {
452 | "name": "stdout",
453 | "output_type": "stream",
454 | "text": [
455 | "test sensitivity: 0.8518193023236439\n"
456 | ]
457 | }
458 | ],
459 | "source": [
460 | "sensitivity = []\n",
461 | "with torch.no_grad():\n",
462 | " for idx, data in tqdm(enumerate(testloader)):\n",
463 | " image = data[\"image\"].cuda()\n",
464 | " mask = data[\"mask\"].cuda()\n",
465 | " out = model.forward(image)\n",
466 | " out = torch.sigmoid(out)\n",
467 | " out[out < 0.5] = 0\n",
468 | " out[out >= 0.5] = 1\n",
469 | " confusion = confusion_matrix(\n",
470 | " mask.detach().cpu().numpy().ravel(),\n",
471 | " out.detach().cpu().numpy().ravel(),\n",
472 | " )\n",
473 | " if float(confusion[1, 1] + confusion[1, 0]) != 0:\n",
474 | " se = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])\n",
475 | "\n",
476 | " sensitivity.append(se)\n",
477 | "\n",
478 | "print(f\"test sensitivity: {np.mean(sensitivity)}\")"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 12,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": [
487 | "# DONE"
488 | ]
489 | }
490 | ],
491 | "metadata": {
492 | "kernelspec": {
493 | "display_name": "core",
494 | "language": "python",
495 | "name": "python3"
496 | },
497 | "language_info": {
498 | "codemirror_mode": {
499 | "name": "ipython",
500 | "version": 3
501 | },
502 | "file_extension": ".py",
503 | "mimetype": "text/x-python",
504 | "name": "python",
505 | "nbconvert_exporter": "python",
506 | "pygments_lexer": "ipython3",
507 | "version": "3.11.7"
508 | },
509 | "orig_nbformat": 4,
510 | "vscode": {
511 | "interpreter": {
512 | "hash": "db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea"
513 | }
514 | }
515 | },
516 | "nbformat": 4,
517 | "nbformat_minor": 2
518 | }
519 |
--------------------------------------------------------------------------------
/experiments_medical/isic_2017/exp_2_dice_b8_a2/run_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 |
5 | sys.path.append("../../../")
6 |
7 | import yaml
8 | import torch
9 | import argparse
10 | import numpy as np
11 | from typing import Dict
12 | from termcolor import colored
13 | from accelerate import Accelerator
14 | from losses.losses import build_loss_fn
15 | from optimizers.optimizers import build_optimizer
16 | from optimizers.schedulers import build_scheduler
17 | from train_scripts.segmentation_trainer import Segmentation_Trainer
18 | from architectures.build_architecture import build_architecture
19 | from dataloaders.build_dataset import build_dataset, build_dataloader
20 |
21 |
22 | ##################################################################################################
 23 | def launch_experiment(config_path) -> None:
 24 |     """
 25 |     Builds and runs the experiment
 26 |     Args:
 27 |         config_path (str): path to the yaml config file
 28 | 
 29 |     Returns:
 30 |         None
 31 |     """
32 | # load config
33 | config = load_config(config_path)
34 |
35 | # set seed
36 | seed_everything(config)
37 |
38 | # build directories
39 | build_directories(config)
40 |
41 | # build training dataset & training data loader
42 | trainset = build_dataset(
43 | dataset_type=config["dataset_parameters"]["dataset_type"],
44 | dataset_args=config["dataset_parameters"]["train_dataset_args"],
45 | augmentation_args=config["train_augmentation_args"],
46 | )
47 | trainloader = build_dataloader(
48 | dataset=trainset,
49 | dataloader_args=config["dataset_parameters"]["train_dataloader_args"],
50 | config=config,
51 | train=True,
52 | )
53 |
 54 |     # build validation dataset & validation data loader
55 | valset = build_dataset(
56 | dataset_type=config["dataset_parameters"]["dataset_type"],
57 | dataset_args=config["dataset_parameters"]["val_dataset_args"],
58 | augmentation_args=config["test_augmentation_args"],
59 | )
60 | valloader = build_dataloader(
61 | dataset=valset,
62 | dataloader_args=config["dataset_parameters"]["val_dataloader_args"],
63 | config=config,
64 | train=False,
65 | )
66 |
67 | # build the Model
68 | model = build_architecture(config)
69 |
70 | # set up the loss function
71 | criterion = build_loss_fn(
72 | loss_type=config["loss_fn"]["loss_type"],
73 | loss_args=config["loss_fn"]["loss_args"],
74 | )
75 |
76 | # set up the optimizer
77 | optimizer = build_optimizer(
78 | model=model,
79 | optimizer_type=config["optimizer"]["optimizer_type"],
80 | optimizer_args=config["optimizer"]["optimizer_args"],
81 | )
82 |
83 | # set up schedulers
84 | warmup_scheduler = build_scheduler(
85 | optimizer=optimizer, scheduler_type="warmup_scheduler", config=config
86 | )
87 | training_scheduler = build_scheduler(
88 | optimizer=optimizer,
89 | scheduler_type="training_scheduler",
90 | config=config,
91 | )
92 |
 93 |     # use accelerate
94 | accelerator = Accelerator(
95 | log_with="wandb",
96 | gradient_accumulation_steps=config["training_parameters"][
97 | "grad_accumulate_steps"
98 | ],
99 | )
100 | accelerator.init_trackers(
101 | project_name=config["project"],
102 | config=config,
103 | init_kwargs={"wandb": config["wandb_parameters"]},
104 | )
105 |
106 | # display experiment info
107 | display_info(config, accelerator, trainset, valset, model)
108 |
109 | # convert all components to accelerate
110 | model = accelerator.prepare_model(model=model)
111 | optimizer = accelerator.prepare_optimizer(optimizer=optimizer)
112 | trainloader = accelerator.prepare_data_loader(data_loader=trainloader)
113 | valloader = accelerator.prepare_data_loader(data_loader=valloader)
114 | warmup_scheduler = accelerator.prepare_scheduler(scheduler=warmup_scheduler)
115 | training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler)
116 |
117 | # create a single dict to hold all parameters
118 | storage = {
119 | "model": model,
120 | "trainloader": trainloader,
121 | "valloader": valloader,
122 | "criterion": criterion,
123 | "optimizer": optimizer,
124 | "warmup_scheduler": warmup_scheduler,
125 | "training_scheduler": training_scheduler,
126 | }
127 |
128 | # set up trainer
129 | trainer = Segmentation_Trainer(
130 | config=config,
131 | model=storage["model"],
132 | optimizer=storage["optimizer"],
133 | criterion=storage["criterion"],
134 | train_dataloader=storage["trainloader"],
135 | val_dataloader=storage["valloader"],
136 | warmup_scheduler=storage["warmup_scheduler"],
137 | training_scheduler=storage["training_scheduler"],
138 | accelerator=accelerator,
139 | )
140 |
141 | # run train
142 | trainer.train()
143 |
144 |
145 | ##################################################################################################
146 | def seed_everything(config) -> None:
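 147 |     # seed python, numpy and torch (cpu + all gpus); deterministic cudnn trades speed for reproducibility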
147 | seed = config["training_parameters"]["seed"]
148 | os.environ["PYTHONHASHSEED"] = str(seed)
149 | random.seed(seed)
150 | np.random.seed(seed)
151 | torch.manual_seed(seed)
152 | torch.cuda.manual_seed(seed)
153 | torch.cuda.manual_seed_all(seed)
154 | torch.backends.cudnn.deterministic = True
155 | torch.backends.cudnn.benchmark = False
156 |
157 |
158 | ##################################################################################################
159 | def load_config(config_path: str) -> Dict:
160 | """loads the yaml config file
161 |
162 | Args:
 163 |         config_path (str): path to the yaml config file
164 |
165 | Returns:
 166 |         Dict: parsed configuration
167 | """
168 | with open(config_path, "r") as file:
169 | config = yaml.safe_load(file)
170 | return config
171 |
172 |
173 | ##################################################################################################
174 | def build_directories(config: Dict) -> None:
175 | # create necessary directories
176 | if not os.path.exists(config["training_parameters"]["checkpoint_save_dir"]):
177 | os.makedirs(config["training_parameters"]["checkpoint_save_dir"])
178 |
179 | if os.listdir(config["training_parameters"]["checkpoint_save_dir"]):
 180 |         raise ValueError("checkpoint exists -- preventing file override -- rename file")
181 |
182 |
183 | ##################################################################################################
184 | def display_info(config, accelerator, trainset, valset, model):
185 | # print experiment info
186 | accelerator.print(f"-------------------------------------------------------")
187 | accelerator.print(f"[info]: Experiment Info")
188 | accelerator.print(
189 | f"[info] ----- Project: {colored(config['project'], color='red')}"
190 | )
191 | accelerator.print(
192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}"
193 | )
194 | accelerator.print(
195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}"
196 | )
197 | accelerator.print(
198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['val_dataloader_args']['batch_size'], color='red')}"
199 | )
200 | accelerator.print(
201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}"
202 | )
203 | accelerator.print(
204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}"
205 | )
206 | accelerator.print(
207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}"
208 | )
209 | accelerator.print(
210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}"
211 | )
212 | accelerator.print(
213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}"
214 | )
215 |
216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
217 | accelerator.print(
218 | f"[info] ----- Distributed Training: {colored('True' if torch.cuda.device_count() > 1 else 'False', color='red')}"
219 | )
220 | accelerator.print(
221 | f"[info] ----- Num Clases: {colored(config['variables']['num_classes'], color='red')}"
222 | )
223 | accelerator.print(
224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}"
225 | )
226 | accelerator.print(
227 | f"[info] ----- Load From Checkpoint: {colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}"
228 | )
229 | accelerator.print(
230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}"
231 | )
232 | accelerator.print(f"-------------------------------------------------------")
233 |
234 |
235 | ##################################################################################################
236 | if __name__ == "__main__":
237 |     parser = argparse.ArgumentParser(description="Simple example of a training script.")
238 | parser.add_argument(
239 | "--config", type=str, default="config.yaml", help="path to yaml config file"
240 | )
241 | args = parser.parse_args()
242 | launch_experiment(args.config)
243 |
--------------------------------------------------------------------------------
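With the config in place, the script is launched from its experiment directory, e.g. `python run_experiment.py --config config.yaml`; since the trainer is built on Hugging Face accelerate, multi-GPU runs can also be started with `accelerate launch run_experiment.py --config config.yaml`.
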
/experiments_medical/isic_2018/exp_2_dice_b8_a2/config.yaml:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | # Commonly Used Variables
3 | ##############################################################
4 | variables:
5 | batch_size: &batch_size 8
6 | num_channels: &num_channels 3
7 | num_classes: &num_classes 1
8 | num_epochs: &num_epochs 400
9 |
10 | ##############################################################
11 | # Wandb Model Tracking
12 | ##############################################################
13 | project: mobileunetr
14 | wandb_parameters:
15 | group: isic_2018
16 | name: exp_2_dice_b8_a
17 | mode: "online"
18 | resume: False
19 | tags: ["tr18ts18", "dice", "xxs", "adam"]
20 |
21 | ##############################################################
22 | # Model Hyper-Parameters
23 | ##############################################################
24 | model_name: mobileunetr_xxs
25 | model_parameters:
26 | encoder: None
27 | bottle_neck:
28 | dims: [96]
29 | depths: [3]
30 | expansion: 4
31 | kernel_size: 3
32 | patch_size: [2,2]
33 | channels: [80, 96, 96]
34 | decoder:
35 | dims: [64, 80, 96]
36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320]
37 | num_classes: 1
38 | image_size: 512
39 |
40 | ##############################################################
41 | # Loss Function
42 | ##############################################################
43 | loss_fn:
44 | loss_type: "dice"
45 | loss_args: None
46 |
47 | ##############################################################
48 | # Metrics
49 | ##############################################################
50 | metrics:
51 | type: "binary"
52 | mean_iou:
53 | enabled: True
54 | mean_iou_args:
55 | include_background: True
56 | reduction: "mean"
57 | get_not_nans: False
58 | ignore_empty: True
59 | dice:
60 | enabled: False
61 | dice_args:
62 | include_background: True
63 | reduction: "mean"
64 | get_not_nans: False
65 | ignore_empty: True
66 | num_classes: *num_classes
67 |
68 | ##############################################################
69 | # Optimizers
70 | ##############################################################
71 | optimizer:
72 | optimizer_type: "adamw"
73 | optimizer_args:
74 | lr: 0.00006
75 | weight_decay: 0.01
76 |
77 | ##############################################################
78 | # Learning Rate Schedulers
79 | ##############################################################
80 | warmup_scheduler:
81 |   enabled: True # must always be True
82 | warmup_epochs: 30
83 |
84 | # train scheduler
85 | train_scheduler:
86 | scheduler_type: 'cosine_annealing_wr'
87 | scheduler_args:
88 | t_0_epochs: 600
89 | t_mult: 1
90 | min_lr: 0.000006
91 |
92 | ##############################################################
93 | # EMA (Exponential Moving Average)
94 | ##############################################################
95 | ema:
96 | enabled: False
97 | ema_decay: 0.999
98 | print_ema_every: 20
99 |
100 | ##############################################################
101 | # Gradient Clipping
102 | ##############################################################
103 | clip_gradients:
104 | enabled: False
105 | clip_gradients_value: 0.1
106 |
107 | ##############################################################
108 | # Training Hyperparameters
109 | ##############################################################
110 | training_parameters:
111 | seed: 42
112 | num_epochs: 1000
113 | cutoff_epoch: 600
114 | load_optimizer: False
115 | print_every: 1500
116 | calculate_metrics: True
117 | grad_accumulate_steps: 1 # default: 1
118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2"
119 | load_checkpoint: # not implemented yet
120 | load_full_checkpoint: False
121 | load_model_only: False
122 | load_checkpoint_path: None
123 |
124 | ##############################################################
125 | # Dataset and Dataloader Args
126 | ##############################################################
127 | dataset_parameters:
128 | dataset_type: "isic_albumentation_v2"
129 | train_dataset_args:
130 | data_path: "../../../data/train_2018.csv"
131 | train: True
132 | image_size: [512, 512]
133 | val_dataset_args:
134 | data_path: "../../../data/test_2018.csv"
135 | train: False
136 | image_size: [512, 512]
137 | test_dataset_args:
138 | data_path: "../../../data/test_2018.csv"
139 | train: False
140 | image_size: [512, 512]
141 | train_dataloader_args:
142 | batch_size: *batch_size
143 | shuffle: True
144 | num_workers: 4
145 | drop_last: True
146 | pin_memory: True
147 | val_dataloader_args:
148 | batch_size: *batch_size
149 | shuffle: False
150 | num_workers: 4
151 | drop_last: False
152 | pin_memory: True
153 | test_dataloader_args:
154 | batch_size: 1
155 | shuffle: False
156 | num_workers: 4
157 | drop_last: False
158 | pin_memory: True
159 |
160 | ##############################################################
161 | # Data Augmentation Args
162 | ##############################################################
163 | train_augmentation_args:
164 | mean: [0.485, 0.456, 0.406]
165 | std: [0.229, 0.224, 0.225]
166 | image_size: [512, 512] # [H, W]
167 |
168 | test_augmentation_args:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 | image_size: [512, 512] # [H, W]
--------------------------------------------------------------------------------
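A note on the config format: the `&batch_size` / `*batch_size` pairs above are standard YAML anchors and aliases, which `yaml.safe_load` resolves at parse time, and bare `None` values (e.g. `loss_args: None`) come back as the string `"None"`, not Python `None`. A minimal sketch illustrating both behaviors:

```python
# Minimal sketch of how PyYAML handles the constructs used in these configs.
import yaml

snippet = """
variables:
  batch_size: &batch_size 8
dataset_parameters:
  train_dataloader_args:
    batch_size: *batch_size
loss_fn:
  loss_args: None
"""

config = yaml.safe_load(snippet)
# the alias resolves to the anchored value
assert config["dataset_parameters"]["train_dataloader_args"]["batch_size"] == 8
# bare None is a plain YAML scalar, i.e. the string "None"; use `null` for a true null
assert config["loss_fn"]["loss_args"] == "None"
```
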
/experiments_medical/isic_2018/exp_2_dice_b8_a2/inference_model_weights/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/isic_2018/exp_2_dice_b8_a2/inference_model_weights/.gitignore
--------------------------------------------------------------------------------
/experiments_medical/isic_2018/exp_2_dice_b8_a2/metrics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "\n",
12 | "sys.path.append(\"../../../\")\n",
13 | "\n",
14 | "import yaml\n",
15 | "import torch\n",
16 | "import numpy as np\n",
17 | "from typing import Dict\n",
18 | "from architectures.build_architecture import build_architecture\n",
19 | "from dataloaders.build_dataset import build_dataset\n",
20 | "from typing import Tuple, Dict\n",
21 | "from fvcore.nn import FlopCountAnalysis\n",
22 | "from tqdm.notebook import tqdm\n",
23 | "from sklearn.metrics import (\n",
24 | " jaccard_score,\n",
25 | " accuracy_score,\n",
26 | " confusion_matrix,\n",
27 | ")\n",
28 | "import monai"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "Load Config File"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "def load_config(config_path: str) -> Dict:\n",
45 | " \"\"\"loads the yaml config file\n",
46 | "\n",
47 | " Args:\n",
48 | " config_path (str): _description_\n",
49 | "\n",
50 | " Returns:\n",
51 | " Dict: _description_\n",
52 | " \"\"\"\n",
53 | " with open(config_path, \"r\") as file:\n",
54 | " config = yaml.safe_load(file)\n",
55 | " return config\n",
56 | "\n",
57 | "\n",
58 | "config = load_config(\"config.yaml\")"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "Build Dataset and DataLoaders"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "1000\n"
78 | ]
79 | }
80 | ],
81 | "source": [
82 | "# build validation dataset & validataion data loader\n",
83 | "testset = build_dataset(\n",
84 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n",
85 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n",
86 | " augmentation_args=config[\"test_augmentation_args\"],\n",
87 | ")\n",
88 | "\n",
89 | "testloader = torch.utils.data.DataLoader(\n",
90 | " testset, batch_size=1, shuffle=False, num_workers=1\n",
91 | ")\n",
92 | "\n",
93 | "print(len(testset))"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 4,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "model = build_architecture(config=config)\n",
103 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n",
104 | "model.load_state_dict(checkpoint)\n",
105 | "model = model.to(\"cpu\")\n",
106 | "model = model.eval()"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "Model Complexity"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 5,
119 | "metadata": {},
120 | "outputs": [
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "Computational complexity: 1.3 GMac\n",
126 | "Number of parameters: 3.01 M \n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "import torchvision.models as models\n",
132 | "import torch\n",
133 | "from ptflops import get_model_complexity_info\n",
134 | "\n",
135 | "with torch.cuda.device(0):\n",
136 | " net = model\n",
137 | " macs, params = get_model_complexity_info(\n",
138 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n",
139 | " )\n",
140 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n",
141 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 6,
147 | "metadata": {},
148 | "outputs": [
149 | {
150 | "name": "stderr",
151 | "output_type": "stream",
152 | "text": [
153 | "Unsupported operator aten::silu encountered 84 time(s)\n",
154 | "Unsupported operator aten::add encountered 45 time(s)\n",
155 | "Unsupported operator aten::div encountered 15 time(s)\n",
156 | "Unsupported operator aten::ceil encountered 6 time(s)\n",
157 | "Unsupported operator aten::mul encountered 118 time(s)\n",
158 | "Unsupported operator aten::softmax encountered 21 time(s)\n",
159 | "Unsupported operator aten::clone encountered 4 time(s)\n",
160 | "Unsupported operator aten::mul_ encountered 24 time(s)\n",
161 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n",
162 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
163 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n"
164 | ]
165 | },
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "Computational complexity: 3.01 \n",
171 | "Number of parameters: 1.24 \n"
172 | ]
173 | }
174 | ],
175 | "source": [
176 | "def flop_count_analysis(\n",
177 | " model: torch.nn.Module,\n",
178 | " input_dim: Tuple,\n",
179 | ") -> Dict:\n",
180 | " \"\"\"_summary_\n",
181 | "\n",
182 | " Args:\n",
183 | " input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))\n",
184 | " model (torch.nn.Module): _description_\n",
185 | " \"\"\"\n",
186 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
187 | " input_tensor = torch.ones(()).new_empty(\n",
188 | " (1, *input_dim),\n",
189 | " dtype=next(model.parameters()).dtype,\n",
190 | " device=next(model.parameters()).device,\n",
191 | " )\n",
192 | " flops = FlopCountAnalysis(model, input_tensor)\n",
193 | " model_flops = flops.total()\n",
194 | " # print(f\"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M\")\n",
195 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n",
196 | "\n",
197 | " out = {\n",
198 | " \"params\": round(trainable_params * 1e-6, 2),\n",
199 | " \"flops\": round(model_flops * 1e-9, 2),\n",
200 | " }\n",
201 | "\n",
202 | " return out\n",
203 | "\n",
204 | "\n",
205 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n",
206 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"params\"]))\n",
207 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"flops\"]))"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "Calculate IoU Metric"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 7,
220 | "metadata": {},
221 | "outputs": [
222 | {
223 | "data": {
224 | "application/vnd.jupyter.widget-view+json": {
225 | "model_id": "a50f810cae5446b687648fc4c979ab68",
226 | "version_major": 2,
227 | "version_minor": 0
228 | },
229 | "text/plain": [
230 | "0it [00:00, ?it/s]"
231 | ]
232 | },
233 | "metadata": {},
234 | "output_type": "display_data"
235 | },
236 | {
237 | "name": "stdout",
238 | "output_type": "stream",
239 | "text": [
240 | "test iou: 0.8456581013355624\n"
241 | ]
242 | }
243 | ],
244 | "source": [
245 | "iou = []\n",
246 | "with torch.no_grad():\n",
247 | " for idx, data in tqdm(enumerate(testloader)):\n",
248 | " image = data[\"image\"].cuda()\n",
249 | " mask = data[\"mask\"].cuda()\n",
250 | " out = model.forward(image)\n",
251 | " out = torch.sigmoid(out)\n",
252 | " out[out < 0.5] = 0\n",
253 | " out[out >= 0.5] = 1\n",
254 | " mean_iou = jaccard_score(\n",
255 | " mask.detach().cpu().numpy().ravel(),\n",
256 | " out.detach().cpu().numpy().ravel(),\n",
257 | " average=\"binary\",\n",
258 | " pos_label=1,\n",
259 | " )\n",
260 | " iou.append(mean_iou.item())\n",
261 | "\n",
262 | "print(f\"test iou: {np.mean(iou)}\")"
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {},
268 | "source": [
269 | "Accuracy"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 8,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "data": {
279 | "application/vnd.jupyter.widget-view+json": {
280 | "model_id": "f727b9ca0a604a139f3c41a4ab971605",
281 | "version_major": 2,
282 | "version_minor": 0
283 | },
284 | "text/plain": [
285 | "0it [00:00, ?it/s]"
286 | ]
287 | },
288 | "metadata": {},
289 | "output_type": "display_data"
290 | },
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "test accuracy: 0.9440349769592286\n"
296 | ]
297 | }
298 | ],
299 | "source": [
300 | "accuracy = []\n",
301 | "with torch.no_grad():\n",
302 | " for idx, data in tqdm(enumerate(testloader)):\n",
303 | " image = data[\"image\"].cuda()\n",
304 | " mask = data[\"mask\"].cuda()\n",
305 | " out = model.forward(image)\n",
306 | " out = torch.sigmoid(out)\n",
307 | " out[out < 0.5] = 0\n",
308 | " out[out >= 0.5] = 1\n",
309 | " acc = accuracy_score(\n",
310 | " mask.detach().cpu().numpy().ravel(),\n",
311 | " out.detach().cpu().numpy().ravel(),\n",
312 | " )\n",
313 | " accuracy.append(acc.item())\n",
314 | "\n",
315 | "print(f\"test accuracy: {np.mean(accuracy)}\")"
316 | ]
317 | },
318 | {
319 | "attachments": {},
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "Calculate Dice"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 9,
329 | "metadata": {},
330 | "outputs": [
331 | {
332 | "data": {
333 | "application/vnd.jupyter.widget-view+json": {
334 | "model_id": "b957f714a8e544eb85f1132ee9d24439",
335 | "version_major": 2,
336 | "version_minor": 0
337 | },
338 | "text/plain": [
339 | "0it [00:00, ?it/s]"
340 | ]
341 | },
342 | "metadata": {},
343 | "output_type": "display_data"
344 | },
345 | {
346 | "name": "stdout",
347 | "output_type": "stream",
348 | "text": [
349 | "test dice: 0.9073548163622618\n"
350 | ]
351 | }
352 | ],
353 | "source": [
354 | "dice = []\n",
355 | "with torch.no_grad():\n",
356 | " for idx, data in tqdm(enumerate(testloader)):\n",
357 | " image = data[\"image\"].cuda()\n",
358 | " mask = data[\"mask\"].cuda()\n",
359 | " out = model.forward(image)\n",
360 | " out = torch.sigmoid(out)\n",
361 | " out[out < 0.5] = 0\n",
362 | " out[out >= 0.5] = 1\n",
363 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n",
364 | " dice.append(mean_dice.item())\n",
365 | "\n",
366 | "print(f\"test dice: {np.mean(dice)}\")"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "Calculate Specificity"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": 10,
379 | "metadata": {},
380 | "outputs": [
381 | {
382 | "data": {
383 | "application/vnd.jupyter.widget-view+json": {
384 | "model_id": "d9064584bfb248829f465cbd888c5528",
385 | "version_major": 2,
386 | "version_minor": 0
387 | },
388 | "text/plain": [
389 | "0it [00:00, ?it/s]"
390 | ]
391 | },
392 | "metadata": {},
393 | "output_type": "display_data"
394 | },
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "test specificity: 0.9503119752056314\n"
400 | ]
401 | }
402 | ],
403 | "source": [
404 | "specificity = []\n",
405 | "with torch.no_grad():\n",
406 | " for idx, data in tqdm(enumerate(testloader)):\n",
407 | " image = data[\"image\"].cuda()\n",
408 | " mask = data[\"mask\"].cuda()\n",
409 | " out = model.forward(image)\n",
410 | " out = torch.sigmoid(out)\n",
411 | " out[out < 0.5] = 0\n",
412 | " out[out >= 0.5] = 1\n",
413 | " confusion = confusion_matrix(\n",
414 | " mask.detach().cpu().numpy().ravel(),\n",
415 | " out.detach().cpu().numpy().ravel(),\n",
416 | " )\n",
417 | " if float(confusion[0, 0] + confusion[0, 1]) != 0:\n",
418 | " sp = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])\n",
419 | "\n",
420 | " specificity.append(sp)\n",
421 | "\n",
422 | "print(f\"test specificity: {np.mean(specificity)}\")"
423 | ]
424 | },
425 | {
426 | "cell_type": "markdown",
427 | "metadata": {},
428 | "source": [
429 | "Calculate Sensitivity"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 11,
435 | "metadata": {},
436 | "outputs": [
437 | {
438 | "data": {
439 | "application/vnd.jupyter.widget-view+json": {
440 | "model_id": "cff60dfd365b4cdebb2c5841e6750f03",
441 | "version_major": 2,
442 | "version_minor": 0
443 | },
444 | "text/plain": [
445 | "0it [00:00, ?it/s]"
446 | ]
447 | },
448 | "metadata": {},
449 | "output_type": "display_data"
450 | },
451 | {
452 | "name": "stdout",
453 | "output_type": "stream",
454 | "text": [
455 | "test sensitivity: 0.9255471263739157\n"
456 | ]
457 | }
458 | ],
459 | "source": [
460 | "sensitivity = []\n",
461 | "with torch.no_grad():\n",
462 | " for idx, data in tqdm(enumerate(testloader)):\n",
463 | " image = data[\"image\"].cuda()\n",
464 | " mask = data[\"mask\"].cuda()\n",
465 | " out = model.forward(image)\n",
466 | " out = torch.sigmoid(out)\n",
467 | " out[out < 0.5] = 0\n",
468 | " out[out >= 0.5] = 1\n",
469 | " confusion = confusion_matrix(\n",
470 | " mask.detach().cpu().numpy().ravel(),\n",
471 | " out.detach().cpu().numpy().ravel(),\n",
472 | " )\n",
473 | " if float(confusion[1, 1] + confusion[1, 0]) != 0:\n",
474 | " se = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])\n",
475 | "\n",
476 | " sensitivity.append(se)\n",
477 | "\n",
478 | "print(f\"test sensitivity: {np.mean(sensitivity)}\")"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 12,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": [
487 | "# DONE"
488 | ]
489 | }
490 | ],
491 | "metadata": {
492 | "kernelspec": {
493 | "display_name": "core",
494 | "language": "python",
495 | "name": "python3"
496 | },
497 | "language_info": {
498 | "codemirror_mode": {
499 | "name": "ipython",
500 | "version": 3
501 | },
502 | "file_extension": ".py",
503 | "mimetype": "text/x-python",
504 | "name": "python",
505 | "nbconvert_exporter": "python",
506 | "pygments_lexer": "ipython3",
507 | "version": "3.11.7"
508 | },
509 | "orig_nbformat": 4,
510 | "vscode": {
511 | "interpreter": {
512 | "hash": "db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea"
513 | }
514 | }
515 | },
516 | "nbformat": 4,
517 | "nbformat_minor": 2
518 | }
519 |
--------------------------------------------------------------------------------
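The notebook above makes one pass over the test set per metric. As an alternative (a sketch, not part of the repository), the same per-image metrics can be collected in a single pass; passing `labels=[0, 1]` to `confusion_matrix` keeps the matrix 2x2 even when an image contains only one class, which the cells above implicitly assume. `model` and `testloader` are taken as built in the earlier cells:

```python
# A consolidated, single-pass variant (not from the repository) of the
# metric cells above; `model` and `testloader` are assumed to be the
# objects built earlier in this notebook.
import numpy as np
import torch
from sklearn.metrics import accuracy_score, confusion_matrix, jaccard_score

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

scores = {"iou": [], "acc": [], "dice": [], "sens": [], "spec": []}
with torch.no_grad():
    for data in testloader:
        image = data["image"].to(device)
        mask = data["mask"].to(device)
        pred = (torch.sigmoid(model(image)) >= 0.5).float()
        y_true = mask.cpu().numpy().ravel()
        y_pred = pred.cpu().numpy().ravel()
        # labels=[0, 1] keeps the matrix 2x2 even if an image has one class
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        scores["iou"].append(jaccard_score(y_true, y_pred, average="binary", pos_label=1))
        scores["acc"].append(accuracy_score(y_true, y_pred))
        if 2 * tp + fp + fn > 0:
            scores["dice"].append(2 * tp / (2 * tp + fp + fn))
        if tp + fn > 0:
            scores["sens"].append(tp / (tp + fn))
        if tn + fp > 0:
            scores["spec"].append(tn / (tn + fp))

for name, values in scores.items():
    print(f"test {name}: {np.mean(values):.4f}")
```
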
/experiments_medical/isic_2018/exp_2_dice_b8_a2/run_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 |
5 | sys.path.append("../../../")
6 |
7 | import yaml
8 | import torch
9 | import argparse
10 | import numpy as np
11 | from typing import Dict
12 | from termcolor import colored
13 | from accelerate import Accelerator
14 | from losses.losses import build_loss_fn
15 | from optimizers.optimizers import build_optimizer
16 | from optimizers.schedulers import build_scheduler
17 | from train_scripts.segmentation_trainer import Segmentation_Trainer
18 | from architectures.build_architecture import build_architecture
19 | from dataloaders.build_dataset import build_dataset, build_dataloader
20 |
21 |
22 | ##################################################################################################
23 | def launch_experiment(config_path: str) -> None:
24 |     """
25 |     Builds and runs the experiment described by a yaml config file.
26 |     Args:
27 |         config_path (str): path to the yaml configuration file
28 | 
29 |     Returns:
30 |         None
31 |     """
32 | # load config
33 | config = load_config(config_path)
34 |
35 | # set seed
36 | seed_everything(config)
37 |
38 | # build directories
39 | build_directories(config)
40 |
41 | # build training dataset & training data loader
42 | trainset = build_dataset(
43 | dataset_type=config["dataset_parameters"]["dataset_type"],
44 | dataset_args=config["dataset_parameters"]["train_dataset_args"],
45 | augmentation_args=config["train_augmentation_args"],
46 | )
47 | trainloader = build_dataloader(
48 | dataset=trainset,
49 | dataloader_args=config["dataset_parameters"]["train_dataloader_args"],
50 | config=config,
51 | train=True,
52 | )
53 |
54 |     # build validation dataset & validation data loader
55 | valset = build_dataset(
56 | dataset_type=config["dataset_parameters"]["dataset_type"],
57 | dataset_args=config["dataset_parameters"]["val_dataset_args"],
58 | augmentation_args=config["test_augmentation_args"],
59 | )
60 | valloader = build_dataloader(
61 | dataset=valset,
62 | dataloader_args=config["dataset_parameters"]["val_dataloader_args"],
63 | config=config,
64 | train=False,
65 | )
66 |
67 | # build the Model
68 | model = build_architecture(config)
69 |
70 | # set up the loss function
71 | criterion = build_loss_fn(
72 | loss_type=config["loss_fn"]["loss_type"],
73 | loss_args=config["loss_fn"]["loss_args"],
74 | )
75 |
76 | # set up the optimizer
77 | optimizer = build_optimizer(
78 | model=model,
79 | optimizer_type=config["optimizer"]["optimizer_type"],
80 | optimizer_args=config["optimizer"]["optimizer_args"],
81 | )
82 |
83 | # set up schedulers
84 | warmup_scheduler = build_scheduler(
85 | optimizer=optimizer, scheduler_type="warmup_scheduler", config=config
86 | )
87 | training_scheduler = build_scheduler(
88 | optimizer=optimizer,
89 | scheduler_type="training_scheduler",
90 | config=config,
91 | )
92 |
93 |     # set up accelerate
94 | accelerator = Accelerator(
95 | log_with="wandb",
96 | gradient_accumulation_steps=config["training_parameters"][
97 | "grad_accumulate_steps"
98 | ],
99 | )
100 | accelerator.init_trackers(
101 | project_name=config["project"],
102 | config=config,
103 | init_kwargs={"wandb": config["wandb_parameters"]},
104 | )
105 |
106 | # display experiment info
107 | display_info(config, accelerator, trainset, valset, model)
108 |
109 | # convert all components to accelerate
110 | model = accelerator.prepare_model(model=model)
111 | optimizer = accelerator.prepare_optimizer(optimizer=optimizer)
112 | trainloader = accelerator.prepare_data_loader(data_loader=trainloader)
113 | valloader = accelerator.prepare_data_loader(data_loader=valloader)
114 | warmup_scheduler = accelerator.prepare_scheduler(scheduler=warmup_scheduler)
115 | training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler)
116 |
117 | # create a single dict to hold all parameters
118 | storage = {
119 | "model": model,
120 | "trainloader": trainloader,
121 | "valloader": valloader,
122 | "criterion": criterion,
123 | "optimizer": optimizer,
124 | "warmup_scheduler": warmup_scheduler,
125 | "training_scheduler": training_scheduler,
126 | }
127 |
128 | # set up trainer
129 | trainer = Segmentation_Trainer(
130 | config=config,
131 | model=storage["model"],
132 | optimizer=storage["optimizer"],
133 | criterion=storage["criterion"],
134 | train_dataloader=storage["trainloader"],
135 | val_dataloader=storage["valloader"],
136 | warmup_scheduler=storage["warmup_scheduler"],
137 | training_scheduler=storage["training_scheduler"],
138 | accelerator=accelerator,
139 | )
140 |
141 | # run train
142 | trainer.train()
143 |
144 |
145 | ##################################################################################################
146 | def seed_everything(config) -> None:
147 | seed = config["training_parameters"]["seed"]
148 | os.environ["PYTHONHASHSEED"] = str(seed)
149 | random.seed(seed)
150 | np.random.seed(seed)
151 | torch.manual_seed(seed)
152 | torch.cuda.manual_seed(seed)
153 | torch.cuda.manual_seed_all(seed)
154 | torch.backends.cudnn.deterministic = True
155 | torch.backends.cudnn.benchmark = False
156 |
157 |
158 | ##################################################################################################
159 | def load_config(config_path: str) -> Dict:
160 | """loads the yaml config file
161 |
162 | Args:
163 | config_path (str): _description_
164 |
165 | Returns:
166 | Dict: _description_
167 | """
168 | with open(config_path, "r") as file:
169 | config = yaml.safe_load(file)
170 | return config
171 |
172 |
173 | ##################################################################################################
174 | def build_directories(config: Dict) -> None:
175 | # create necessary directories
176 | if not os.path.exists(config["training_parameters"]["checkpoint_save_dir"]):
177 | os.makedirs(config["training_parameters"]["checkpoint_save_dir"])
178 |
179 | if os.listdir(config["training_parameters"]["checkpoint_save_dir"]):
180 | raise ValueError("checkpoint exits -- preventing file override -- rename file")
181 |
182 |
183 | ##################################################################################################
184 | def display_info(config, accelerator, trainset, valset, model):
185 | # print experiment info
186 | accelerator.print(f"-------------------------------------------------------")
187 | accelerator.print(f"[info]: Experiment Info")
188 | accelerator.print(
189 | f"[info] ----- Project: {colored(config['project'], color='red')}"
190 | )
191 | accelerator.print(
192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}"
193 | )
194 | accelerator.print(
195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}"
196 | )
197 | accelerator.print(
198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['val_dataloader_args']['batch_size'], color='red')}"
199 | )
200 | accelerator.print(
201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}"
202 | )
203 | accelerator.print(
204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}"
205 | )
206 | accelerator.print(
207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}"
208 | )
209 | accelerator.print(
210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}"
211 | )
212 | accelerator.print(
213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}"
214 | )
215 |
216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
217 | accelerator.print(
218 | f"[info] ----- Distributed Training: {colored('True' if torch.cuda.device_count() > 1 else 'False', color='red')}"
219 | )
220 | accelerator.print(
221 | f"[info] ----- Num Clases: {colored(config['variables']['num_classes'], color='red')}"
222 | )
223 | accelerator.print(
224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}"
225 | )
226 | accelerator.print(
227 | f"[info] ----- Load From Checkpoint: {colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}"
228 | )
229 | accelerator.print(
230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}"
231 | )
232 | accelerator.print(f"-------------------------------------------------------")
233 |
234 |
235 | ##################################################################################################
236 | if __name__ == "__main__":
237 |     parser = argparse.ArgumentParser(description="Simple example of a training script.")
238 | parser.add_argument(
239 | "--config", type=str, default="config.yaml", help="path to yaml config file"
240 | )
241 | args = parser.parse_args()
242 | launch_experiment(args.config)
243 |
--------------------------------------------------------------------------------
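One detail worth noting in `launch_experiment`: the script prepares each object individually via `prepare_model`, `prepare_optimizer`, and so on. accelerate also exposes a single `Accelerator.prepare` call that wraps everything at once and returns the objects in the order given; a sketch of the equivalent one-call form, reusing the script's variable names (assumed to be the objects built earlier in the function):

```python
# One-call equivalent of the per-object prepare_* calls in launch_experiment;
# `accelerator` and the six objects are assumed to be those built above.
(
    model,
    optimizer,
    trainloader,
    valloader,
    warmup_scheduler,
    training_scheduler,
) = accelerator.prepare(
    model, optimizer, trainloader, valloader, warmup_scheduler, training_scheduler
)
```
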
/experiments_medical/ph2/exp_2_dice_b8_a2/config.yaml:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | # Commonly Used Variables
3 | ##############################################################
4 | variables:
5 | batch_size: &batch_size 8
6 | num_channels: &num_channels 3
7 | num_classes: &num_classes 1
8 | num_epochs: &num_epochs 400
9 |
10 | ##############################################################
11 | # Wandb Model Tracking
12 | ##############################################################
13 | project: mobileunetr
14 | wandb_parameters:
15 | group: ph2
16 | name: exp_2_dice_b8_a
17 | mode: "online"
18 | resume: False
19 | tags: ["trph2tsph2", "dice", "xxs", "adam"]
20 |
21 | ##############################################################
22 | # Model Hyper-Parameters
23 | ##############################################################
24 | model_name: mobileunetr_xxs
25 | model_parameters:
26 | encoder: None
27 | bottle_neck:
28 | dims: [96]
29 | depths: [3]
30 | expansion: 4
31 | kernel_size: 3
32 | patch_size: [2,2]
33 | channels: [80, 96, 96]
34 | decoder:
35 | dims: [64, 80, 96]
36 | channels: [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 96, 96, 320]
37 | num_classes: 1
38 | image_size: 512
39 |
40 | ##############################################################
41 | # Loss Function
42 | ##############################################################
43 | loss_fn:
44 | loss_type: "dice"
45 | loss_args: None
46 |
47 | ##############################################################
48 | # Metrics
49 | ##############################################################
50 | metrics:
51 | type: "binary"
52 | mean_iou:
53 | enabled: True
54 | mean_iou_args:
55 | include_background: True
56 | reduction: "mean"
57 | get_not_nans: False
58 | ignore_empty: True
59 | dice:
60 | enabled: False
61 | dice_args:
62 | include_background: True
63 | reduction: "mean"
64 | get_not_nans: False
65 | ignore_empty: True
66 | num_classes: *num_classes
67 |
68 | ##############################################################
69 | # Optimizers
70 | ##############################################################
71 | optimizer:
72 | optimizer_type: "adamw"
73 | optimizer_args:
74 | lr: 0.00006
75 | weight_decay: 0.01
76 |
77 | ##############################################################
78 | # Learning Rate Schedulers
79 | ##############################################################
80 | warmup_scheduler:
81 |   enabled: True # must always be True
82 | warmup_epochs: 30
83 |
84 | # train scheduler
85 | train_scheduler:
86 | scheduler_type: 'cosine_annealing_wr'
87 | scheduler_args:
88 | t_0_epochs: 600
89 | t_mult: 1
90 | min_lr: 0.000006
91 |
92 | ##############################################################
93 | # EMA (Exponential Moving Average)
94 | ##############################################################
95 | ema:
96 | enabled: False
97 | ema_decay: 0.999
98 | print_ema_every: 20
99 |
100 | ##############################################################
101 | # Gradient Clipping
102 | ##############################################################
103 | clip_gradients:
104 | enabled: False
105 | clip_gradients_value: 0.1
106 |
107 | ##############################################################
108 | # Training Hyperparameters
109 | ##############################################################
110 | training_parameters:
111 | seed: 42
112 | num_epochs: 1000
113 | cutoff_epoch: 600
114 | load_optimizer: False
115 | print_every: 1500
116 | calculate_metrics: True
117 | grad_accumulate_steps: 1 # default: 1
118 | checkpoint_save_dir: "model_checkpoints/best_iou_checkpoint2"
119 | load_checkpoint: # not implemented yet
120 | load_full_checkpoint: False
121 | load_model_only: False
122 | load_checkpoint_path: None
123 |
124 | ##############################################################
125 | # Dataset and Dataloader Args
126 | ##############################################################
127 | dataset_parameters:
128 | dataset_type: "isic_albumentation_v2"
129 | train_dataset_args:
130 | data_path: "../../../data/train_ph2.csv"
131 | train: True
132 | image_size: [512, 512]
133 | val_dataset_args:
134 | data_path: "../../../data/test_ph2.csv"
135 | train: False
136 | image_size: [512, 512]
137 | test_dataset_args:
138 | data_path: "../../../data/test_ph2.csv"
139 | train: False
140 | image_size: [512, 512]
141 | train_dataloader_args:
142 | batch_size: *batch_size
143 | shuffle: True
144 | num_workers: 4
145 | drop_last: True
146 | pin_memory: True
147 | val_dataloader_args:
148 | batch_size: *batch_size
149 | shuffle: False
150 | num_workers: 4
151 | drop_last: False
152 | pin_memory: True
153 | test_dataloader_args:
154 | batch_size: 1
155 | shuffle: False
156 | num_workers: 4
157 | drop_last: False
158 | pin_memory: True
159 |
160 | ##############################################################
161 | # Data Augmentation Args
162 | ##############################################################
163 | train_augmentation_args:
164 | mean: [0.485, 0.456, 0.406]
165 | std: [0.229, 0.224, 0.225]
166 | image_size: [512, 512] # [H, W]
167 |
168 | test_augmentation_args:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 | image_size: [512, 512] # [H, W]
--------------------------------------------------------------------------------
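For readers mapping the scheduler section onto PyTorch: `cosine_annealing_wr` with `t_0_epochs: 600`, `t_mult: 1`, and `min_lr: 0.000006` corresponds naturally to `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`, preceded by the 30-epoch warmup. A minimal sketch under that assumption (the repository's own `build_scheduler` may differ in detail):

```python
# Assumed mapping of the scheduler settings above onto torch.optim:
# 30 epochs of linear warmup, then cosine annealing with warm restarts
# (T_0=600, T_mult=1, eta_min=6e-6). Not the repository's build_scheduler.
import torch

model = torch.nn.Linear(8, 1)  # stand-in model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-5, weight_decay=0.01)

warmup_epochs = 30
warmup = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup_epochs)
)
cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=600, T_mult=1, eta_min=6e-6
)

for epoch in range(1000):
    # ... one training epoch would run here ...
    if epoch < warmup_epochs:
        warmup.step()
    else:
        cosine.step(epoch - warmup_epochs)
```
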
/experiments_medical/ph2/exp_2_dice_b8_a2/inference_model_weights/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/experiments_medical/ph2/exp_2_dice_b8_a2/inference_model_weights/.gitignore
--------------------------------------------------------------------------------
/experiments_medical/ph2/exp_2_dice_b8_a2/metrics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "\n",
12 | "sys.path.append(\"../../../\")\n",
13 | "\n",
14 | "import yaml\n",
15 | "import torch\n",
16 | "import numpy as np\n",
17 | "from typing import Dict\n",
18 | "from architectures.build_architecture import build_architecture\n",
19 | "from dataloaders.build_dataset import build_dataset\n",
20 | "from typing import Tuple, Dict\n",
21 | "from fvcore.nn import FlopCountAnalysis\n",
22 | "from tqdm.notebook import tqdm\n",
23 | "from sklearn.metrics import (\n",
24 | " jaccard_score,\n",
25 | " accuracy_score,\n",
26 | " confusion_matrix,\n",
27 | ")\n",
28 | "import monai"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "Load Config File"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "def load_config(config_path: str) -> Dict:\n",
45 | " \"\"\"loads the yaml config file\n",
46 | "\n",
47 | " Args:\n",
48 | " config_path (str): _description_\n",
49 | "\n",
50 | " Returns:\n",
51 | " Dict: _description_\n",
52 | " \"\"\"\n",
53 | " with open(config_path, \"r\") as file:\n",
54 | " config = yaml.safe_load(file)\n",
55 | " return config\n",
56 | "\n",
57 | "\n",
58 | "config = load_config(\"config.yaml\")"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "Build Dataset and DataLoaders"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "40\n"
78 | ]
79 | }
80 | ],
81 | "source": [
82 | "# build validation dataset & validataion data loader\n",
83 | "testset = build_dataset(\n",
84 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n",
85 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n",
86 | " augmentation_args=config[\"test_augmentation_args\"],\n",
87 | ")\n",
88 | "\n",
89 | "testloader = torch.utils.data.DataLoader(\n",
90 | " testset, batch_size=1, shuffle=False, num_workers=1\n",
91 | ")\n",
92 | "\n",
93 | "print(len(testset))"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 4,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "model = build_architecture(config=config)\n",
103 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n",
104 | "model.load_state_dict(checkpoint)\n",
105 | "model = model.to(\"cpu\")\n",
106 | "model = model.eval()"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "Model Complexity"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 5,
119 | "metadata": {},
120 | "outputs": [
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "Computational complexity: 1.3 GMac\n",
126 | "Number of parameters: 3.01 M \n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "import torchvision.models as models\n",
132 | "import torch\n",
133 | "from ptflops import get_model_complexity_info\n",
134 | "\n",
135 | "with torch.cuda.device(0):\n",
136 | " net = model\n",
137 | " macs, params = get_model_complexity_info(\n",
138 | " net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False\n",
139 | " )\n",
140 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n",
141 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 6,
147 | "metadata": {},
148 | "outputs": [
149 | {
150 | "name": "stderr",
151 | "output_type": "stream",
152 | "text": [
153 | "Unsupported operator aten::silu encountered 84 time(s)\n",
154 | "Unsupported operator aten::add encountered 45 time(s)\n",
155 | "Unsupported operator aten::div encountered 15 time(s)\n",
156 | "Unsupported operator aten::ceil encountered 6 time(s)\n",
157 | "Unsupported operator aten::mul encountered 118 time(s)\n",
158 | "Unsupported operator aten::softmax encountered 21 time(s)\n",
159 | "Unsupported operator aten::clone encountered 4 time(s)\n",
160 | "Unsupported operator aten::mul_ encountered 24 time(s)\n",
161 | "Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)\n",
162 | "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
163 | "encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization\n"
164 | ]
165 | },
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "Computational complexity: 3.01 \n",
171 | "Number of parameters: 1.24 \n"
172 | ]
173 | }
174 | ],
175 | "source": [
176 | "def flop_count_analysis(\n",
177 | " model: torch.nn.Module,\n",
178 | " input_dim: Tuple,\n",
179 | ") -> Dict:\n",
180 | " \"\"\"_summary_\n",
181 | "\n",
182 | " Args:\n",
183 | " input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))\n",
184 | " model (torch.nn.Module): _description_\n",
185 | " \"\"\"\n",
186 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
187 | " input_tensor = torch.ones(()).new_empty(\n",
188 | " (1, *input_dim),\n",
189 | " dtype=next(model.parameters()).dtype,\n",
190 | " device=next(model.parameters()).device,\n",
191 | " )\n",
192 | " flops = FlopCountAnalysis(model, input_tensor)\n",
193 | " model_flops = flops.total()\n",
194 | " # print(f\"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M\")\n",
195 | " # print(f\"MAdds: {round(model_flops * 1e-9, 2)} G\")\n",
196 | "\n",
197 | " out = {\n",
198 | " \"params\": round(trainable_params * 1e-6, 2),\n",
199 | " \"flops\": round(model_flops * 1e-9, 2),\n",
200 | " }\n",
201 | "\n",
202 | " return out\n",
203 | "\n",
204 | "\n",
205 | "inference_result = flop_count_analysis(model, (3, 256, 256))\n",
206 | "print(\"{:<30} {:<8}\".format(\"Computational complexity: \", inference_result[\"params\"]))\n",
207 | "print(\"{:<30} {:<8}\".format(\"Number of parameters: \", inference_result[\"flops\"]))"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "Calculate IoU Metric"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 7,
220 | "metadata": {},
221 | "outputs": [
222 | {
223 | "data": {
224 | "application/vnd.jupyter.widget-view+json": {
225 | "model_id": "8697bc384ea947628f121854eb4d091c",
226 | "version_major": 2,
227 | "version_minor": 0
228 | },
229 | "text/plain": [
230 | "0it [00:00, ?it/s]"
231 | ]
232 | },
233 | "metadata": {},
234 | "output_type": "display_data"
235 | },
236 | {
237 | "name": "stdout",
238 | "output_type": "stream",
239 | "text": [
240 | "test iou: 0.9229712867991552\n"
241 | ]
242 | }
243 | ],
244 | "source": [
245 | "iou = []\n",
246 | "with torch.no_grad():\n",
247 | " for idx, data in tqdm(enumerate(testloader)):\n",
248 | " image = data[\"image\"].cuda()\n",
249 | " mask = data[\"mask\"].cuda()\n",
250 | " out = model.forward(image)\n",
251 | " out = torch.sigmoid(out)\n",
252 | " out[out < 0.5] = 0\n",
253 | " out[out >= 0.5] = 1\n",
254 | " mean_iou = jaccard_score(\n",
255 | " mask.detach().cpu().numpy().ravel(),\n",
256 | " out.detach().cpu().numpy().ravel(),\n",
257 | " average=\"binary\",\n",
258 | " pos_label=1,\n",
259 | " )\n",
260 | " iou.append(mean_iou.item())\n",
261 | "\n",
262 | "print(f\"test iou: {np.mean(iou)}\")"
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {},
268 | "source": [
269 | "Accuracy"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 8,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "data": {
279 | "application/vnd.jupyter.widget-view+json": {
280 | "model_id": "a451270dd9f04d849a2c32923785143c",
281 | "version_major": 2,
282 | "version_minor": 0
283 | },
284 | "text/plain": [
285 | "0it [00:00, ?it/s]"
286 | ]
287 | },
288 | "metadata": {},
289 | "output_type": "display_data"
290 | },
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "test accuracy: 0.9771324157714844\n"
296 | ]
297 | }
298 | ],
299 | "source": [
300 | "accuracy = []\n",
301 | "with torch.no_grad():\n",
302 | " for idx, data in tqdm(enumerate(testloader)):\n",
303 | " image = data[\"image\"].cuda()\n",
304 | " mask = data[\"mask\"].cuda()\n",
305 | " out = model.forward(image)\n",
306 | " out = torch.sigmoid(out)\n",
307 | " out[out < 0.5] = 0\n",
308 | " out[out >= 0.5] = 1\n",
309 | " acc = accuracy_score(\n",
310 | " mask.detach().cpu().numpy().ravel(),\n",
311 | " out.detach().cpu().numpy().ravel(),\n",
312 | " )\n",
313 | " accuracy.append(acc.item())\n",
314 | "\n",
315 | "print(f\"test accuracy: {np.mean(accuracy)}\")"
316 | ]
317 | },
318 | {
319 | "attachments": {},
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "Calculate Dice"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 9,
329 | "metadata": {},
330 | "outputs": [
331 | {
332 | "data": {
333 | "application/vnd.jupyter.widget-view+json": {
334 | "model_id": "0c465b3238b842bb8ee3f6b37a1b21b9",
335 | "version_major": 2,
336 | "version_minor": 0
337 | },
338 | "text/plain": [
339 | "0it [00:00, ?it/s]"
340 | ]
341 | },
342 | "metadata": {},
343 | "output_type": "display_data"
344 | },
345 | {
346 | "name": "stdout",
347 | "output_type": "stream",
348 | "text": [
349 | "test dice: 0.9569526433944702\n"
350 | ]
351 | }
352 | ],
353 | "source": [
354 | "dice = []\n",
355 | "with torch.no_grad():\n",
356 | " for idx, data in tqdm(enumerate(testloader)):\n",
357 | " image = data[\"image\"].cuda()\n",
358 | " mask = data[\"mask\"].cuda()\n",
359 | " out = model.forward(image)\n",
360 | " out = torch.sigmoid(out)\n",
361 | " out[out < 0.5] = 0\n",
362 | " out[out >= 0.5] = 1\n",
363 | " mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))\n",
364 | " dice.append(mean_dice.item())\n",
365 | "\n",
366 | "print(f\"test dice: {np.mean(dice)}\")"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "Calculate Specificity"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": 10,
379 | "metadata": {},
380 | "outputs": [
381 | {
382 | "data": {
383 | "application/vnd.jupyter.widget-view+json": {
384 | "model_id": "924fdb4171ca424297fcf300084bb352",
385 | "version_major": 2,
386 | "version_minor": 0
387 | },
388 | "text/plain": [
389 | "0it [00:00, ?it/s]"
390 | ]
391 | },
392 | "metadata": {},
393 | "output_type": "display_data"
394 | },
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "test specificity: 0.9660395899750096\n"
400 | ]
401 | }
402 | ],
403 | "source": [
404 | "specificity = []\n",
405 | "with torch.no_grad():\n",
406 | " for idx, data in tqdm(enumerate(testloader)):\n",
407 | " image = data[\"image\"].cuda()\n",
408 | " mask = data[\"mask\"].cuda()\n",
409 | " out = model.forward(image)\n",
410 | " out = torch.sigmoid(out)\n",
411 | " out[out < 0.5] = 0\n",
412 | " out[out >= 0.5] = 1\n",
413 | " confusion = confusion_matrix(\n",
414 | " mask.detach().cpu().numpy().ravel(),\n",
415 | " out.detach().cpu().numpy().ravel(),\n",
416 | " )\n",
417 | " if float(confusion[0, 0] + confusion[0, 1]) != 0:\n",
418 | " sp = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])\n",
419 | "\n",
420 | " specificity.append(sp)\n",
421 | "\n",
422 | "print(f\"test specificity: {np.mean(specificity)}\")"
423 | ]
424 | },
425 | {
426 | "cell_type": "markdown",
427 | "metadata": {},
428 | "source": [
429 | "Calculate Sensitivity"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 11,
435 | "metadata": {},
436 | "outputs": [
437 | {
438 | "data": {
439 | "application/vnd.jupyter.widget-view+json": {
440 | "model_id": "7502c6211be24f269ba862fe067fe054",
441 | "version_major": 2,
442 | "version_minor": 0
443 | },
444 | "text/plain": [
445 | "0it [00:00, ?it/s]"
446 | ]
447 | },
448 | "metadata": {},
449 | "output_type": "display_data"
450 | },
451 | {
452 | "name": "stdout",
453 | "output_type": "stream",
454 | "text": [
455 | "test sensitivity: 0.9604657601219436\n"
456 | ]
457 | }
458 | ],
459 | "source": [
460 | "sensitivity = []\n",
461 | "with torch.no_grad():\n",
462 | " for idx, data in tqdm(enumerate(testloader)):\n",
463 | " image = data[\"image\"].cuda()\n",
464 | " mask = data[\"mask\"].cuda()\n",
465 | " out = model.forward(image)\n",
466 | " out = torch.sigmoid(out)\n",
467 | " out[out < 0.5] = 0\n",
468 | " out[out >= 0.5] = 1\n",
469 | " confusion = confusion_matrix(\n",
470 | " mask.detach().cpu().numpy().ravel(),\n",
471 | " out.detach().cpu().numpy().ravel(),\n",
472 | " )\n",
473 | " if float(confusion[1, 1] + confusion[1, 0]) != 0:\n",
474 | " se = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])\n",
475 | "\n",
476 | " sensitivity.append(se)\n",
477 | "\n",
478 | "print(f\"test sensitivity: {np.mean(sensitivity)}\")"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 12,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": [
487 | "# DONE"
488 | ]
489 | }
490 | ],
491 | "metadata": {
492 | "kernelspec": {
493 | "display_name": "core",
494 | "language": "python",
495 | "name": "python3"
496 | },
497 | "language_info": {
498 | "codemirror_mode": {
499 | "name": "ipython",
500 | "version": 3
501 | },
502 | "file_extension": ".py",
503 | "mimetype": "text/x-python",
504 | "name": "python",
505 | "nbconvert_exporter": "python",
506 | "pygments_lexer": "ipython3",
507 | "version": "3.11.7"
508 | },
509 | "orig_nbformat": 4,
510 | "vscode": {
511 | "interpreter": {
512 | "hash": "db5989e82860003de3542e01be4c3e7827261da67de3613f2a961c26d75654ea"
513 | }
514 | }
515 | },
516 | "nbformat": 4,
517 | "nbformat_minor": 2
518 | }
519 |
--------------------------------------------------------------------------------
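A quick consistency check on the numbers these notebooks report: for a single image, Dice and IoU (Jaccard) are deterministically related by Dice = 2·IoU / (1 + IoU), so either metric can be derived from the other. A small sketch:

```python
# Per-image sanity check: Dice and IoU are linked by dice = 2*iou / (1 + iou).
def dice_from_iou(iou: float) -> float:
    return 2 * iou / (1 + iou)

assert abs(dice_from_iou(0.5) - 2 / 3) < 1e-12
# Caveat: the identity is per image; dataset *means* of Dice and IoU are
# related only approximately, since averaging does not commute with it.
```
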
/experiments_medical/ph2/exp_2_dice_b8_a2/run_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 |
5 | sys.path.append("../../../")
6 |
7 | import yaml
8 | import torch
9 | import argparse
10 | import numpy as np
11 | from typing import Dict
12 | from termcolor import colored
13 | from accelerate import Accelerator
14 | from losses.losses import build_loss_fn
15 | from optimizers.optimizers import build_optimizer
16 | from optimizers.schedulers import build_scheduler
17 | from train_scripts.segmentation_trainer import Segmentation_Trainer
18 | from architectures.build_architecture import build_architecture
19 | from dataloaders.build_dataset import build_dataset, build_dataloader
20 |
21 |
22 | ##################################################################################################
23 | def launch_experiment(config_path: str) -> None:
24 |     """
25 |     Builds and runs the experiment described by a yaml config file.
26 |     Args:
27 |         config_path (str): path to the yaml configuration file
28 | 
29 |     Returns:
30 |         None
31 |     """
32 | # load config
33 | config = load_config(config_path)
34 |
35 | # set seed
36 | seed_everything(config)
37 |
38 | # build directories
39 | build_directories(config)
40 |
41 | # build training dataset & training data loader
42 | trainset = build_dataset(
43 | dataset_type=config["dataset_parameters"]["dataset_type"],
44 | dataset_args=config["dataset_parameters"]["train_dataset_args"],
45 | augmentation_args=config["train_augmentation_args"],
46 | )
47 | trainloader = build_dataloader(
48 | dataset=trainset,
49 | dataloader_args=config["dataset_parameters"]["train_dataloader_args"],
50 | config=config,
51 | train=True,
52 | )
53 |
54 |     # build validation dataset & validation data loader
55 | valset = build_dataset(
56 | dataset_type=config["dataset_parameters"]["dataset_type"],
57 | dataset_args=config["dataset_parameters"]["val_dataset_args"],
58 | augmentation_args=config["test_augmentation_args"],
59 | )
60 | valloader = build_dataloader(
61 | dataset=valset,
62 | dataloader_args=config["dataset_parameters"]["val_dataloader_args"],
63 | config=config,
64 | train=False,
65 | )
66 |
67 | # build the Model
68 | model = build_architecture(config)
69 |
70 | # set up the loss function
71 | criterion = build_loss_fn(
72 | loss_type=config["loss_fn"]["loss_type"],
73 | loss_args=config["loss_fn"]["loss_args"],
74 | )
75 |
76 | # set up the optimizer
77 | optimizer = build_optimizer(
78 | model=model,
79 | optimizer_type=config["optimizer"]["optimizer_type"],
80 | optimizer_args=config["optimizer"]["optimizer_args"],
81 | )
82 |
83 | # set up schedulers
84 | warmup_scheduler = build_scheduler(
85 | optimizer=optimizer, scheduler_type="warmup_scheduler", config=config
86 | )
87 | training_scheduler = build_scheduler(
88 | optimizer=optimizer,
89 | scheduler_type="training_scheduler",
90 | config=config,
91 | )
92 |
93 |     # set up accelerate
94 | accelerator = Accelerator(
95 | log_with="wandb",
96 | gradient_accumulation_steps=config["training_parameters"][
97 | "grad_accumulate_steps"
98 | ],
99 | )
100 | accelerator.init_trackers(
101 | project_name=config["project"],
102 | config=config,
103 | init_kwargs={"wandb": config["wandb_parameters"]},
104 | )
105 |
106 | # display experiment info
107 | display_info(config, accelerator, trainset, valset, model)
108 |
109 | # convert all components to accelerate
110 | model = accelerator.prepare_model(model=model)
111 | optimizer = accelerator.prepare_optimizer(optimizer=optimizer)
112 | trainloader = accelerator.prepare_data_loader(data_loader=trainloader)
113 | valloader = accelerator.prepare_data_loader(data_loader=valloader)
114 | warmup_scheduler = accelerator.prepare_scheduler(scheduler=warmup_scheduler)
115 | training_scheduler = accelerator.prepare_scheduler(scheduler=training_scheduler)
116 |
117 | # create a single dict to hold all parameters
118 | storage = {
119 | "model": model,
120 | "trainloader": trainloader,
121 | "valloader": valloader,
122 | "criterion": criterion,
123 | "optimizer": optimizer,
124 | "warmup_scheduler": warmup_scheduler,
125 | "training_scheduler": training_scheduler,
126 | }
127 |
128 | # set up trainer
129 | trainer = Segmentation_Trainer(
130 | config=config,
131 | model=storage["model"],
132 | optimizer=storage["optimizer"],
133 | criterion=storage["criterion"],
134 | train_dataloader=storage["trainloader"],
135 | val_dataloader=storage["valloader"],
136 | warmup_scheduler=storage["warmup_scheduler"],
137 | training_scheduler=storage["training_scheduler"],
138 | accelerator=accelerator,
139 | )
140 |
141 | # run train
142 | trainer.train()
143 |
144 |
145 | ##################################################################################################
146 | def seed_everything(config) -> None:
147 | seed = config["training_parameters"]["seed"]
148 | os.environ["PYTHONHASHSEED"] = str(seed)
149 | random.seed(seed)
150 | np.random.seed(seed)
151 | torch.manual_seed(seed)
152 | torch.cuda.manual_seed(seed)
153 | torch.cuda.manual_seed_all(seed)
154 | torch.backends.cudnn.deterministic = True
155 | torch.backends.cudnn.benchmark = False
156 |
157 |
158 | ##################################################################################################
159 | def load_config(config_path: str) -> Dict:
160 | """loads the yaml config file
161 |
162 | Args:
163 | config_path (str): _description_
164 |
165 | Returns:
166 | Dict: _description_
167 | """
168 | with open(config_path, "r") as file:
169 | config = yaml.safe_load(file)
170 | return config
171 |
172 |
173 | ##################################################################################################
174 | def build_directories(config: Dict) -> None:
175 | # create necessary directories
176 | if not os.path.exists(config["training_parameters"]["checkpoint_save_dir"]):
177 | os.makedirs(config["training_parameters"]["checkpoint_save_dir"])
178 |
179 | if os.listdir(config["training_parameters"]["checkpoint_save_dir"]):
180 | raise ValueError("checkpoint exists -- preventing file overwrite -- rename the save directory")
181 |
182 |
183 | ##################################################################################################
184 | def display_info(config, accelerator, trainset, valset, model):
185 | # print experiment info
186 | accelerator.print(f"-------------------------------------------------------")
187 | accelerator.print(f"[info]: Experiment Info")
188 | accelerator.print(
189 | f"[info] ----- Project: {colored(config['project'], color='red')}"
190 | )
191 | accelerator.print(
192 | f"[info] ----- Group: {colored(config['wandb_parameters']['group'], color='red')}"
193 | )
194 | accelerator.print(
195 | f"[info] ----- Name: {colored(config['wandb_parameters']['name'], color='red')}"
196 | )
197 | accelerator.print(
198 | f"[info] ----- Batch Size: {colored(config['dataset_parameters']['val_dataloader_args']['batch_size'], color='red')}"
199 | )
200 | accelerator.print(
201 | f"[info] ----- Num Epochs: {colored(config['training_parameters']['num_epochs'], color='red')}"
202 | )
203 | accelerator.print(
204 | f"[info] ----- Loss: {colored(config['loss_fn']['loss_type'], color='red')}"
205 | )
206 | accelerator.print(
207 | f"[info] ----- Optimizer: {colored(config['optimizer']['optimizer_type'], color='red')}"
208 | )
209 | accelerator.print(
210 | f"[info] ----- Train Dataset Size: {colored(len(trainset), color='red')}"
211 | )
212 | accelerator.print(
213 | f"[info] ----- Test Dataset Size: {colored(len(valset), color='red')}"
214 | )
215 |
216 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
217 | accelerator.print(
218 | f"[info] ----- Distributed Training: {colored('True' if torch.cuda.device_count() > 1 else 'False', color='red')}"
219 | )
220 | accelerator.print(
221 | f"[info] ----- Num Clases: {colored(config['variables']['num_classes'], color='red')}"
222 | )
223 | accelerator.print(
224 | f"[info] ----- EMA: {colored(config['ema']['enabled'], color='red')}"
225 | )
226 | accelerator.print(
227 | f"[info] ----- Load From Checkpoint: {colored(config['training_parameters']['load_checkpoint']['load_full_checkpoint'], color='red')}"
228 | )
229 | accelerator.print(
230 | f"[info] ----- Params: {colored(pytorch_total_params, color='red')}"
231 | )
232 | accelerator.print(f"-------------------------------------------------------")
233 |
234 |
235 | ##################################################################################################
236 | if __name__ == "__main__":
237 | parser = argparse.ArgumentParser(description="Simple example of training script.")
238 | parser.add_argument(
239 | "--config", type=str, default="config.yaml", help="path to yaml config file"
240 | )
241 | args = parser.parse_args()
242 | launch_experiment(args.config)
243 |
--------------------------------------------------------------------------------
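Usage note: this script is normally run from inside one of the experiment folders (e.g. experiments_medical/isic_2018/exp_2_dice_b8_a2) as `python run_experiment.py --config config.yaml`, or under `accelerate launch` for multi-GPU runs. It can also be driven programmatically; a minimal sketch, assuming the working directory is an experiment folder containing a config.yaml:

```python
# minimal programmatic launch; launch_experiment is the entry point defined
# in run_experiment.py above (the config path is illustrative)
from run_experiment import launch_experiment

launch_experiment("config.yaml")
```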
/experiments_medical/ph2/exp_2_dice_b8_a2/visualize.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "\n",
12 | "sys.path.append(\"../../../\")\n",
13 | "\n",
14 | "import yaml\n",
15 | "import torch\n",
16 | "from typing import Dict\n",
17 | "from architectures.build_architecture import build_architecture\n",
18 | "from dataloaders.build_dataset import build_dataset\n",
19 | "from torchvision.utils import save_image"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "def load_config(config_path: str) -> Dict:\n",
29 | " \"\"\"loads the yaml config file\n",
30 | "\n",
31 | " Args:\n",
32 | " config_path (str): _description_\n",
33 | "\n",
34 | " Returns:\n",
35 | " Dict: _description_\n",
36 | " \"\"\"\n",
37 | " with open(config_path, \"r\") as file:\n",
38 | " config = yaml.safe_load(file)\n",
39 | " return config"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "config = load_config(\"config.yaml\")"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "Set Up"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 4,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "model = build_architecture(config=config)\n",
65 | "checkpoint = torch.load(\"pytorch_model.bin\", map_location=\"cpu\")\n",
66 | "model.load_state_dict(checkpoint)\n",
67 | "model = model.to(\"cpu\")\n",
68 | "model = model.eval()"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 5,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "# build validation dataset & validataion data loader\n",
78 | "valset = build_dataset(\n",
79 | " dataset_type=config[\"dataset_parameters\"][\"dataset_type\"],\n",
80 | " dataset_args=config[\"dataset_parameters\"][\"val_dataset_args\"],\n",
81 | " augmentation_args=config[\"test_augmentation_args\"],\n",
82 | ")\n",
83 | "testloader = torch.utils.data.DataLoader(\n",
84 | " valset, batch_size=1, shuffle=False, num_workers=1\n",
85 | ")"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "Inference"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 6,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "if not os.path.isdir(\"predictions\"):\n",
102 | " os.makedirs(\"predictions/rgb\")\n",
103 | " os.makedirs(\"predictions/gt\")\n",
104 | " os.makedirs(\"predictions/pred\")"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 7,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "torch.Size([1, 3, 512, 512])\n",
117 | "torch.Size([1, 3, 512, 512])\n",
118 | "torch.Size([1, 3, 512, 512])\n",
119 | "torch.Size([1, 3, 512, 512])\n",
120 | "torch.Size([1, 3, 512, 512])\n",
121 | "torch.Size([1, 3, 512, 512])\n",
122 | "torch.Size([1, 3, 512, 512])\n",
123 | "torch.Size([1, 3, 512, 512])\n",
124 | "torch.Size([1, 3, 512, 512])\n",
125 | "torch.Size([1, 3, 512, 512])\n",
126 | "torch.Size([1, 3, 512, 512])\n",
127 | "torch.Size([1, 3, 512, 512])\n",
128 | "torch.Size([1, 3, 512, 512])\n",
129 | "torch.Size([1, 3, 512, 512])\n",
130 | "torch.Size([1, 3, 512, 512])\n",
131 | "torch.Size([1, 3, 512, 512])\n",
132 | "torch.Size([1, 3, 512, 512])\n",
133 | "torch.Size([1, 3, 512, 512])\n",
134 | "torch.Size([1, 3, 512, 512])\n",
135 | "torch.Size([1, 3, 512, 512])\n",
136 | "torch.Size([1, 3, 512, 512])\n",
137 | "torch.Size([1, 3, 512, 512])\n",
138 | "torch.Size([1, 3, 512, 512])\n",
139 | "torch.Size([1, 3, 512, 512])\n",
140 | "torch.Size([1, 3, 512, 512])\n",
141 | "torch.Size([1, 3, 512, 512])\n",
142 | "torch.Size([1, 3, 512, 512])\n",
143 | "torch.Size([1, 3, 512, 512])\n",
144 | "torch.Size([1, 3, 512, 512])\n",
145 | "torch.Size([1, 3, 512, 512])\n",
146 | "torch.Size([1, 3, 512, 512])\n",
147 | "torch.Size([1, 3, 512, 512])\n",
148 | "torch.Size([1, 3, 512, 512])\n",
149 | "torch.Size([1, 3, 512, 512])\n",
150 | "torch.Size([1, 3, 512, 512])\n",
151 | "torch.Size([1, 3, 512, 512])\n",
152 | "torch.Size([1, 3, 512, 512])\n",
153 | "torch.Size([1, 3, 512, 512])\n",
154 | "torch.Size([1, 3, 512, 512])\n",
155 | "torch.Size([1, 3, 512, 512])\n",
156 | "torch.Size([1, 3, 512, 512])\n",
157 | "torch.Size([1, 3, 512, 512])\n",
158 | "torch.Size([1, 3, 512, 512])\n",
159 | "torch.Size([1, 3, 512, 512])\n",
160 | "torch.Size([1, 3, 512, 512])\n",
161 | "torch.Size([1, 3, 512, 512])\n",
162 | "torch.Size([1, 3, 512, 512])\n",
163 | "torch.Size([1, 3, 512, 512])\n",
164 | "torch.Size([1, 3, 512, 512])\n",
165 | "torch.Size([1, 3, 512, 512])\n",
166 | "torch.Size([1, 3, 512, 512])\n",
167 | "torch.Size([1, 3, 512, 512])\n",
168 | "torch.Size([1, 3, 512, 512])\n",
169 | "torch.Size([1, 3, 512, 512])\n",
170 | "torch.Size([1, 3, 512, 512])\n",
171 | "torch.Size([1, 3, 512, 512])\n",
172 | "torch.Size([1, 3, 512, 512])\n",
173 | "torch.Size([1, 3, 512, 512])\n",
174 | "torch.Size([1, 3, 512, 512])\n",
175 | "torch.Size([1, 3, 512, 512])\n",
176 | "torch.Size([1, 3, 512, 512])\n",
177 | "torch.Size([1, 3, 512, 512])\n",
178 | "torch.Size([1, 3, 512, 512])\n",
179 | "torch.Size([1, 3, 512, 512])\n",
180 | "torch.Size([1, 3, 512, 512])\n",
181 | "torch.Size([1, 3, 512, 512])\n",
182 | "torch.Size([1, 3, 512, 512])\n",
183 | "torch.Size([1, 3, 512, 512])\n",
184 | "torch.Size([1, 3, 512, 512])\n",
185 | "torch.Size([1, 3, 512, 512])\n",
186 | "torch.Size([1, 3, 512, 512])\n",
187 | "torch.Size([1, 3, 512, 512])\n",
188 | "torch.Size([1, 3, 512, 512])\n",
189 | "torch.Size([1, 3, 512, 512])\n",
190 | "torch.Size([1, 3, 512, 512])\n",
191 | "torch.Size([1, 3, 512, 512])\n",
192 | "torch.Size([1, 3, 512, 512])\n",
193 | "torch.Size([1, 3, 512, 512])\n",
194 | "torch.Size([1, 3, 512, 512])\n",
195 | "torch.Size([1, 3, 512, 512])\n"
196 | ]
197 | }
198 | ],
199 | "source": [
200 | "model.eval()\n",
201 | "counter = 4000\n",
202 | "for idx, data in enumerate(testloader):\n",
203 | " image = data[\"image\"].cuda()\n",
204 | " mask = data[\"mask\"].cuda()\n",
205 | " out = model.forward(image)\n",
206 | " out = torch.sigmoid(out)\n",
207 | " out[out < 0.5] = 0\n",
208 | " out[out >= 0.5] = 1\n",
209 | "\n",
210 | " # rgb input\n",
211 | " img = data[\"image\"]\n",
212 | " img = img.detach()\n",
213 | " print(img.shape)\n",
214 | " img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n",
215 | " img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n",
216 | " img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n",
217 | " save_image(img, f\"predictions/rgb/image_{idx}.png\")\n",
218 | "\n",
219 | " # prediction\n",
220 | " pred = out.detach()\n",
221 | " pred = pred * 255.0\n",
222 | " save_image(pred, f\"predictions/pred/pred_{idx}.png\")\n",
223 | "\n",
224 | " # ground truth\n",
225 | " gt = data[\"mask\"]\n",
226 | " gt = gt.detach()\n",
227 | " gt = gt * 255.0\n",
228 | " save_image(gt, f\"predictions/gt/gt_{idx}.png\")\n",
229 | "\n",
230 | " if idx == counter:\n",
231 | " break"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "Save RGB, GT and Pred"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 14,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "# # RGB INPUT\n",
248 | "\n",
249 | "# img = data[\"image\"]\n",
250 | "# img = img.detach()\n",
251 | "# print(img.shape)\n",
252 | "# img[:, 0, :, :] = (img[:, 0, :, :] * 0.1577) + 0.7128\n",
253 | "# img[:, 1, :, :] = (img[:, 1, :, :] * 0.1662) + 0.6000\n",
254 | "# img[:, 2, :, :] = (img[:, 2, :, :] * 0.1829) + 0.5532\n",
255 | "# save_image(img, f\"predictions/rgb/image_{idx}.png\")"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 15,
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 | "# # PREDICTION\n",
265 | "\n",
266 | "# pred = out.detach()\n",
267 | "# pred = pred * 255.0\n",
268 | "# save_image(pred, f\"predictions/pred/pred_{idx}.png\")"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 16,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "# # GROUND TRUTH\n",
278 | "\n",
279 | "# gt = data[1]\n",
280 | "# gt = gt.detach()\n",
281 | "# gt = gt * 255.0\n",
282 | "# save_image(gt, f\"predictions/gt/gt_{idx}.png\")"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": null,
288 | "metadata": {},
289 | "outputs": [],
290 | "source": []
291 | }
292 | ],
293 | "metadata": {
294 | "kernelspec": {
295 | "display_name": "corev2",
296 | "language": "python",
297 | "name": "python3"
298 | },
299 | "language_info": {
300 | "codemirror_mode": {
301 | "name": "ipython",
302 | "version": 3
303 | },
304 | "file_extension": ".py",
305 | "mimetype": "text/x-python",
306 | "name": "python",
307 | "nbconvert_exporter": "python",
308 | "pygments_lexer": "ipython3",
309 | "version": "3.11.7"
310 | }
311 | },
312 | "nbformat": 4,
313 | "nbformat_minor": 2
314 | }
315 |
--------------------------------------------------------------------------------
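The inference cell above binarizes raw logits with a sigmoid followed by a 0.5 threshold. The same step as a small standalone helper, for reference only (this function is not part of the repo):

```python
import torch


def binarize(logits: torch.Tensor, thresh: float = 0.5) -> torch.Tensor:
    """sigmoid + hard threshold, as in the visualization notebook above."""
    probs = torch.sigmoid(logits)
    return (probs >= thresh).float()


# e.g. logits of shape (1, 1, 512, 512) -> a {0, 1} mask of the same shape
mask = binarize(torch.randn(1, 1, 512, 512))
```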
/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/losses/__init__.py
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import monai
3 | import torch.nn as nn
4 | from typing import Dict
5 | import torch.nn.functional as F
6 |
7 |
8 | #####################################################################
9 | class CrossEntropyLoss(nn.Module):
10 | def __init__(self, loss_args=None):
11 | super().__init__()
12 | self._loss = nn.CrossEntropyLoss(reduction="mean")
13 |
14 | def forward(self, predictions, targets):
15 | loss = self._loss(predictions, targets)
16 | return loss
17 |
18 |
19 | #####################################################################
20 | class BinaryCrossEntropyWithLogits:
21 | def __init__(self, loss_args=None):
22 | super().__init__()
23 | self._loss = nn.BCEWithLogitsLoss(reduction="mean")
24 |
25 | def __call__(self, predictions, targets):
26 | loss = self._loss(predictions, targets)
27 | return loss
28 |
29 |
30 | #####################################################################
31 | class MSELoss:
32 | def __init__(self, loss_args=None):
33 | super().__init__()
34 | self._loss = nn.MSELoss()
35 |
36 | def __call__(self, predicted, target):
37 | loss = self._loss(predicted, target)
38 | return loss
39 |
40 |
41 | #####################################################################
42 | class L1Loss:
43 | def __init__(self, loss_args=None):
44 | super().__init__()
45 | self._loss = nn.L1Loss()
46 |
47 | def __call__(self, predicted, target):
48 | loss = self._loss(predicted, target)
49 | return loss
50 |
51 |
52 | #####################################################################
53 | class DistillationLoss(nn.Module):
54 | # TODO: loss function not 100% verified
55 | def __init__(self, loss_args: Dict):
56 | """_summary_
57 |
58 | Args:
59 | loss_args (Dict): _description_
60 | """
61 | super().__init__()
62 | self.kl_div = nn.KLDivLoss(reduction="batchmean")
63 | self.temperature = loss_args["temperature"]
64 | self.lambda_param = loss_args["lambda"]
65 | self.dice = monai.losses.DiceLoss(
66 | to_onehot_y=False,
67 | sigmoid=True,
68 | )
69 |
70 | def forward(
71 | self,
72 | teacher_predictions: torch.Tensor,
73 | student_predictions: torch.Tensor,
74 | targets: torch.Tensor,
75 | ) -> torch.Tensor:
76 | """_summary_
77 |
78 | Args:
79 | teacher_predictions (torch.Tensor): _description_
80 | student_predictions (torch.Tensor): _description_
81 | targets (torch.Tensor): _description_
82 |
83 | Returns:
84 | torch.Tensor: _description_
85 | """
86 | # compute temperature-scaled probabilities
87 | soft_teacher = F.softmax(
88 | teacher_predictions.view(-1) / self.temperature,
89 | dim=-1,
90 | )
91 | soft_student = F.log_softmax(
92 | student_predictions.view(-1) / self.temperature,
93 | dim=-1,
94 | )
95 |
96 | # compute kl div loss
97 | distillation_loss = self.kl_div(soft_student, soft_teacher) * (
98 | self.temperature**2
99 | )
100 |
101 | # compute dice loss on student
102 | dice_loss = self.dice(student_predictions, targets)
103 |
104 | # combine via lambda: (1 - lambda) * dice + lambda * distillation
105 | loss = (
106 | 1.0 - self.lambda_param
107 | ) * dice_loss + self.lambda_param * distillation_loss
108 |
109 | return loss
110 |
111 |
112 | class DiceBCELoss(nn.Module):
113 | def __init__(self, loss_args=None):
114 | super().__init__()
115 | self.dice_loss = monai.losses.DiceLoss(sigmoid=True, to_onehot_y=False)
116 | self.bce_loss = torch.nn.BCEWithLogitsLoss()
117 |
118 | def forward(self, predicted, target):
119 | dice_loss = self.dice_loss(predicted, target) * 1
120 | bce_loss = self.bce_loss(predicted, target.float()) * 1
121 | final = dice_loss + bce_loss
122 | return final
123 |
124 |
125 | #####################################################################
126 | def build_loss_fn(loss_type: str, loss_args: Dict = None):
127 | if loss_type == "crossentropy":
128 | return CrossEntropyLoss()
129 |
130 | elif loss_type == "binarycrossentropy":
131 | return BinaryCrossEntropyWithLogits()
132 |
133 | elif loss_type == "MSE":
134 | return MSELoss()
135 |
136 | elif loss_type == "L1":
137 | return L1Loss()
138 |
139 | elif loss_type == "dice":
140 | return monai.losses.DiceLoss(to_onehot_y=False, sigmoid=True)
141 |
142 | elif loss_type == "dicebce":
143 | return DiceBCELoss()
144 |
145 | elif loss_type == "dicece":
146 | return monai.losses.DiceCELoss(to_onehot_y=True, softmax=True)
147 | else:
148 | raise ValueError(f"unsupported loss_type: {loss_type}")
149 |
--------------------------------------------------------------------------------
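For reference, a short sketch of how build_loss_fn is consumed; the shapes and batch size are illustrative (the skin-lesion experiments above use loss_type "dice" on single-channel logits):

```python
import torch
from losses.losses import build_loss_fn

criterion = build_loss_fn(loss_type="dice")  # monai DiceLoss with sigmoid=True
logits = torch.randn(8, 1, 512, 512)                   # raw model outputs, no sigmoid applied
masks = torch.randint(0, 2, (8, 1, 512, 512)).float()  # binary ground-truth masks
loss = criterion(logits, masks)                        # scalar tensor
```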
/misc/flop_counter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple, Dict
3 | from fvcore.nn import FlopCountAnalysis
4 |
5 |
6 | ##########################################################################################
7 | def flop_count_analysis(
8 | input_dim: Tuple,
9 | model: torch.nn.Module,
10 | ) -> Dict:
11 | """_summary_
12 |
13 | Args:
14 | input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))
15 | model (torch.nn.Module): _description_
16 | """
17 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
18 | input_tensor = torch.ones(()).new_empty(
19 | (1, *input_dim),
20 | dtype=next(model.parameters()).dtype,
21 | device=next(model.parameters()).device,
22 | )
23 | flops = FlopCountAnalysis(model, input_tensor)
24 | model_flops = flops.total()
25 | print(f"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M")
26 | print(f"MAdds: {round(model_flops * 1e-9, 2)} G")
27 |
28 | out = {
29 | "params": round(trainable_params * 1e-6, 2),
30 | "flops": round(model_flops * 1e-9, 2),
31 | }
32 |
33 | return out
34 |
--------------------------------------------------------------------------------
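A usage sketch for flop_count_analysis; the torchvision model is only a stand-in, since building MobileUNETR itself requires a config dict:

```python
import torchvision
from misc.flop_counter import flop_count_analysis

model = torchvision.models.mobilenet_v2()
stats = flop_count_analysis(input_dim=(3, 224, 224), model=model)  # batch dim added internally
print(stats)  # {"params": <millions>, "flops": <giga multiply-adds>}
```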
/misc/profiler.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "\n",
12 | "sys.path.append(\"../../../\")\n",
13 | "\n",
14 | "from config import *\n",
15 | "import wandb\n",
16 | "import torch\n",
17 | "from typing import Dict\n",
18 | "from losses.losses import build_loss_fn\n",
19 | "from optimizers.optimizers import build_optimizer\n",
20 | "from train_scripts.train_functions import train_val\n",
21 | "from architectures.leavisnet_v3 import build_model\n",
22 | "\n",
23 | "# from datasets.build_dataset import build_architecture, build_dataloader\n",
24 | "\n",
25 | "from collections import namedtuple\n",
26 | "from torch.utils.data import DataLoader\n",
27 | "import monai"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "model = build_test_model(None)\n",
37 | "with torch.cuda.device(0):\n",
38 | " net = model\n",
39 | " macs, params = get_model_complexity_info(\n",
40 | " net, (3, 224, 224), as_strings=True, print_per_layer_stat=True, verbose=True\n",
41 | " )\n",
42 | " print(\"{:<30} {:<8}\".format(\"Computational complexity: \", macs))\n",
43 | " print(\"{:<30} {:<8}\".format(\"Number of parameters: \", params))"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "import torch\n",
53 | "from torch.profiler import profile, record_function, ProfilerActivity"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "model = build_test_model(None).cpu()\n",
63 | "inputs = torch.randn(1, 3, 224, 224)\n",
64 | "\n",
65 | "pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
66 | "print(\"\\nmodel parameter count = \", pytorch_total_params)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "def estimate_memory_inference(\n",
76 | " model, sample_input, batch_size=1, use_amp=False, device=0\n",
77 | "):\n",
78 | " \"\"\"Predict the maximum memory usage of the model.\n",
79 | " Args:\n",
80 | " optimizer_type (Type): the class name of the optimizer to instantiate\n",
81 | " model (nn.Module): the neural network model\n",
82 | " sample_input (torch.Tensor): A sample input to the network. It should be\n",
83 | " a single item, not a batch, and it will be replicated batch_size times.\n",
84 | " batch_size (int): the batch size\n",
85 | " use_amp (bool): whether to estimate based on using mixed precision\n",
86 | " device (torch.device): the device to use\n",
87 | " \"\"\"\n",
88 | " # Reset model and optimizer\n",
89 | " model.cpu()\n",
90 | " a = torch.cuda.memory_allocated(device)\n",
91 | " model.to(device)\n",
92 | " b = torch.cuda.memory_allocated(device)\n",
93 | " model_memory = b - a\n",
94 | " model_input = sample_input # .unsqueeze(0).repeat(batch_size, 1)\n",
95 | " output = model(model_input.to(device)).sum()\n",
96 | " total_memory = model_memory\n",
97 | "\n",
98 | " return total_memory\n",
99 | "\n",
100 | "\n",
101 | "estimate_memory_inference(model, inputs)"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "model.cpu()\n",
111 | "with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:\n",
112 | " with record_function(\"model_inference\"):\n",
113 | " model(inputs)\n",
114 | "\n",
115 | "print(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "model = model.cuda()\n",
125 | "inputs = torch.randn(1, 3, 224, 224).cuda()\n",
126 | "\n",
127 | "with profile(\n",
128 | " activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True\n",
129 | ") as prof:\n",
130 | " with record_function(\"model_inference\"):\n",
131 | " model(inputs)\n",
132 | "\n",
133 | "print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": []
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": []
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "import torch\n",
157 | "import matplotlib.pyplot as plt\n",
158 | "\n",
159 | "model = torch.nn.Linear(2, 1)\n",
160 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
161 | "lr_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n",
162 | " optimizer, T_0=20, T_mult=2, eta_min=0.001, last_epoch=-1\n",
163 | ")\n",
164 | "\n",
165 | "\n",
166 | "lrs = []\n",
167 | "\n",
168 | "for i in range(1000):\n",
169 | " lr_sched.step()\n",
170 | " lrs.append(optimizer.param_groups[0][\"lr\"])\n",
171 | "\n",
172 | "plt.plot(lrs)"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {},
178 | "source": []
179 | }
180 | ],
181 | "metadata": {
182 | "kernelspec": {
183 | "display_name": "corev2",
184 | "language": "python",
185 | "name": "python3"
186 | },
187 | "language_info": {
188 | "codemirror_mode": {
189 | "name": "ipython",
190 | "version": 3
191 | },
192 | "file_extension": ".py",
193 | "mimetype": "text/x-python",
194 | "name": "python",
195 | "nbconvert_exporter": "python",
196 | "pygments_lexer": "ipython3",
197 | "version": "3.10.10"
198 | },
199 | "orig_nbformat": 4
200 | },
201 | "nbformat": 4,
202 | "nbformat_minor": 2
203 | }
204 |
--------------------------------------------------------------------------------
/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/optimizers/__init__.py
--------------------------------------------------------------------------------
/optimizers/optimizers.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch.optim as optim
3 | from monai.optimizers import Novograd
4 |
5 | ######################################################################
6 | def optim_adam(model, optimizer_args):
7 | adam = optim.Adam(
8 | model.parameters(),
9 | lr=optimizer_args["lr"],
10 | weight_decay=optimizer_args.get("weight_decay", 0.0),
11 | )
12 | return adam
13 |
14 |
15 | ######################################################################
16 | def optim_sgd(model, optimizer_args):
17 | sgd = optim.SGD(
18 | model.parameters(),
19 | lr=optimizer_args["lr"],
20 | weight_decay=optimizer_args.get("weight_decay", 0.0),
21 | momentum=optimizer_args.get("momentum", 0.0),
22 | )
23 | return sgd
24 |
25 |
26 | ######################################################################
27 | def optim_adamw(model, optimizer_args):
28 | adam = optim.AdamW(
29 | model.parameters(),
30 | lr=optimizer_args["lr"],
31 | weight_decay=optimizer_args["weight_decay"],
32 | # amsgrad=True,
33 | )
34 | return adam
35 |
36 | ######################################################################
37 | def optim_novograd(model, optimizer_args):
38 | novograd = Novograd(
39 | model.parameters(),
40 | lr=optimizer_args["lr"],
41 | weight_decay=optimizer_args["weight_decay"],
42 | # amsgrad=True,
43 | )
44 | return novograd
45 |
46 |
47 | ######################################################################
48 | def build_optimizer(model, optimizer_type: str, optimizer_args: Dict):
49 | if optimizer_type == "adam":
50 | return optim_adam(model, optimizer_args)
51 | elif optimizer_type == "adamw":
52 | return optim_adamw(model, optimizer_args)
53 | elif optimizer_type == "sgd":
54 | return optim_sgd(model, optimizer_args)
55 | elif optimizer_type == "novograd":
56 | return optim_novograd(model, optimizer_args)
57 | else:
58 | raise ValueError("optimizer_type must be one of: adam, adamw, sgd, novograd")
59 |
--------------------------------------------------------------------------------
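A sketch of build_optimizer with an optimizer_args dict shaped like the "optimizer" block of the yaml configs; the model and hyperparameter values below are placeholders, not the paper's settings:

```python
import torch
from optimizers.optimizers import build_optimizer

model = torch.nn.Linear(16, 2)  # stand-in for MobileUNETR
optimizer = build_optimizer(
    model=model,
    optimizer_type="adamw",
    optimizer_args={"lr": 3e-4, "weight_decay": 0.01},
)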
/optimizers/schedulers.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch.optim as optim
3 | from torch.optim.lr_scheduler import LRScheduler
4 |
5 |
6 | ##################################################################################################
7 | def warmup_lr_scheduler(config, optimizer):
8 | """
9 | Linearly ramps up the learning rate over the first
10 | `warmup_epochs` epochs until it reaches the optimizer's base lr.
11 | Args:
12 | config (Dict): expects config["warmup_scheduler"]["warmup_epochs"]
13 | optimizer (optim.Optimizer): the optimizer whose learning
14 | rate is scheduled
15 | """
16 | lambda1 = lambda epoch: (
17 | (epoch + 1) * 1.0 / config["warmup_scheduler"]["warmup_epochs"]
18 | )
19 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1, verbose=False)
20 | return scheduler
21 |
22 |
23 | ##################################################################################################
24 | def training_lr_scheduler(config, optimizer):
25 | """
26 | Wraps a standard pytorch scheduler (reducelronplateau or cosine_annealing_wr)
27 | """
28 | scheduler_type = config["train_scheduler"]["scheduler_type"]
29 | if scheduler_type == "reducelronplateau":
30 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
31 | optimizer,
32 | factor=0.1,
33 | mode=config["train_scheduler"]["mode"],
34 | patience=config["train_scheduler"]["patience"],
35 | verbose=False,
36 | min_lr=config["train_scheduler"]["scheduler_args"]["min_lr"],
37 | )
38 | return scheduler
39 | elif scheduler_type == "cosine_annealing_wr":
40 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
41 | optimizer,
42 | T_0=config["train_scheduler"]["scheduler_args"]["t_0_epochs"],
43 | T_mult=config["train_scheduler"]["scheduler_args"]["t_mult"],
44 | eta_min=config["train_scheduler"]["scheduler_args"]["min_lr"],
45 | last_epoch=-1,
46 | verbose=False,
47 | )
48 | return scheduler
49 | else:
50 | raise NotImplementedError("Specified Scheduler Is Not Implemented")
51 |
52 |
53 | ##################################################################################################
54 | def build_scheduler(
55 | optimizer: optim.Optimizer, scheduler_type: str, config
56 | ) -> LRScheduler:
57 | """generates the learning rate scheduler
58 |
59 | Args:
60 | optimizer (optim.Optimizer): pytorch optimizer
61 | scheduler_type (str): "warmup_scheduler" or "training_scheduler"
62 |
63 | Returns:
64 | LRScheduler: the configured learning rate scheduler
65 | """
66 | if scheduler_type == "warmup_scheduler":
67 | scheduler = warmup_lr_scheduler(config=config, optimizer=optimizer)
68 | return scheduler
69 | elif scheduler_type == "training_scheduler":
70 | scheduler = training_lr_scheduler(config=config, optimizer=optimizer)
71 | return scheduler
72 | else:
73 | raise ValueError("Invalid Input -- Check scheduler_type")
74 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.23.0
2 | datasets==2.14.5
3 | evaluate==0.4.1
4 | einops==0.7.0
5 | h5py==3.9.0
6 | imageio==2.33.0
7 | kornia==0.7.2
8 | kornia-rs==0.1.3
9 | matplotlib==3.7.1
10 | matplotlib-inline==0.1.6
11 | monai==1.2.0
12 | nibabel==5.1.0
13 | opencv-contrib-python==4.7.0.72
14 | opencv-python==4.7.0.72
15 | opencv-python-headless==4.7.0.72
16 | opt-einsum==3.3.0
17 | protobuf==4.22.3
18 | requests==2.28.1
19 | safetensors==0.3.1
20 | scikit-image==0.23.1
21 | scikit-learn==1.2.2
22 | scipy==1.13.0
23 | termcolor==2.3.0
24 | timm==0.6.13
25 | torch-geometric==2.3.0
26 | torch-summary==1.4.5
27 | torchmetrics==0.11.4
28 | torchsummary==1.5.1
29 | tqdm==4.65.0
30 | typing_extensions==4.8.0
31 | wandb==0.15.12
32 |
--------------------------------------------------------------------------------
/resources/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/.gitignore
--------------------------------------------------------------------------------
/resources/adv_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/adv_arch.png
--------------------------------------------------------------------------------
/resources/citiscapes1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/citiscapes1.png
--------------------------------------------------------------------------------
/resources/cityscapes2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes2.png
--------------------------------------------------------------------------------
/resources/cityscapes3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes3.png
--------------------------------------------------------------------------------
/resources/cityscapes4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes4.png
--------------------------------------------------------------------------------
/resources/cityscapes_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/cityscapes_table.png
--------------------------------------------------------------------------------
/resources/isic_2016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/isic_2016.png
--------------------------------------------------------------------------------
/resources/isic_2017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/isic_2017.png
--------------------------------------------------------------------------------
/resources/isic_2018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/isic_2018.png
--------------------------------------------------------------------------------
/resources/muvit_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/muvit_architecture.png
--------------------------------------------------------------------------------
/resources/params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/params.png
--------------------------------------------------------------------------------
/resources/ph2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/ph2.png
--------------------------------------------------------------------------------
/resources/postdam1_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam1_gt.png
--------------------------------------------------------------------------------
/resources/postdam1_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam1_pred.png
--------------------------------------------------------------------------------
/resources/postdam2_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam2_gt.png
--------------------------------------------------------------------------------
/resources/postdam2_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/postdam2_pred.png
--------------------------------------------------------------------------------
/resources/potsdam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/potsdam.png
--------------------------------------------------------------------------------
/resources/vaihigen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/resources/vaihigen.png
--------------------------------------------------------------------------------
/train_scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OSUPCVLab/MobileUNETR/bac8eab0a8831cd07c9fabe90a9bceea5783d623/train_scripts/__init__.py
--------------------------------------------------------------------------------
/train_scripts/ema.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch.nn as nn
3 |
4 |
5 | ###############################################################################
6 | class EMA:
7 | "https://github.com/scott-yjyang/DiffMIC/blob/main/ema.py"
8 |
9 | def __init__(
10 | self,
11 | model: nn.Module,
12 | mu: float = 0.999,
13 | ) -> None:
14 | self.mu = mu
15 | self.ema_model = {}
16 | self.register(model)
17 | self.model_copy = copy.deepcopy(model)
18 |
19 | def register(self, module) -> None:
20 | for name, param in module.named_parameters():
21 | if param.requires_grad:
22 | self.ema_model[name] = param.data.clone()
23 |
24 | def update(self, module: nn.Module) -> None:
25 | for name, param in module.named_parameters():
26 | if param.requires_grad:
27 | self.ema_model[name].data = (
28 | 1.0 - self.mu
29 | ) * param.data + self.mu * self.ema_model[name].data
30 |
31 | def ema(self, module: nn.Module) -> None:
32 | for name, param in module.named_parameters():
33 | if param.requires_grad:
34 | param.data.copy_(self.ema_model[name].data)
35 |
36 | def ema_copy(self, module: nn.Module):
37 | """
38 | Returns the model with the ema weights inserted.
39 | Args:
40 | module (nn.Module): the live model whose architecture is copied
41 |
42 | Returns:
43 | nn.Module: a deep copy of `module` carrying the EMA weights
44 | """
45 | # module_copy = type(module)(module.config).to(module.config.device)
46 | module_copy = copy.deepcopy(module)
47 | # module_copy.load_state_dict(module.state_dict())
48 | self.ema(module_copy)
49 |
50 | return module_copy
51 |
52 | def state_dict(self):
53 | """
54 | Returns the ema parameter dictionary
55 | Returns:
56 | Dict[str, Tensor]: parameter name -> EMA weight tensor
57 | """
58 | return self.ema_model
59 |
60 | def load_state_dict(self, state_dict) -> None:
61 | self.ema_model = state_dict
62 |
63 |
64 | ###############################################################################
65 |
--------------------------------------------------------------------------------
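A sketch of the usual update pattern with this EMA class; the model and loop are placeholders, and the real wiring lives in segmentation_trainer.py:

```python
import torch
from train_scripts.ema import EMA

model = torch.nn.Linear(16, 2)
ema = EMA(model, mu=0.999)  # shadow weights registered at construction

for step in range(100):
    # ... forward / backward / optimizer.step() on `model` ...
    ema.update(model)  # shadow weights drift toward the live weights

eval_model = ema.ema_copy(model)  # deep copy carrying the EMA weights
```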
/train_scripts/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import wandb
4 | import torch
5 | import random
6 | import numpy as np
7 |
8 | """
9 | Utils File Used for Training/Validation/Testing
10 | """
11 |
12 |
13 | ##################################################################################################
14 | def log_metrics(**kwargs) -> None:
15 | # data to be logged
16 | log_data = {}
17 | log_data.update(kwargs)
18 |
19 | # log the data
20 | wandb.log(log_data)
21 |
22 |
23 | ##################################################################################################
24 | def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar") -> None:
25 | # print("=> Saving checkpoint")
26 | checkpoint = {
27 | "state_dict": model.state_dict(),
28 | "optimizer": optimizer.state_dict(),
29 | }
30 | torch.save(checkpoint, filename)
31 |
32 |
33 | ##################################################################################################
34 | def load_checkpoint(config, model, optimizer, load_optimizer=True):
35 | print("=> Loading checkpoint")
36 | checkpoint = torch.load(config.checkpoint_file_name, map_location=config.device)
37 | model.load_state_dict(checkpoint["state_dict"])
38 |
39 | if load_optimizer:
40 | optimizer.load_state_dict(checkpoint["optimizer"])
41 |
42 | # reset the learning rate to the configured value; otherwise the optimizer
43 | # silently keeps the stale learning rate stored in the checkpoint
44 | for param_group in optimizer.param_groups:
45 | param_group["lr"] = config.learning_rate
46 |
47 | return model, optimizer
48 |
49 |
50 | ##################################################################################################
51 | def seed_everything(seed: int = 42) -> None:
52 | os.environ["PYTHONHASHSEED"] = str(seed)
53 | random.seed(seed)
54 | np.random.seed(seed)
55 | torch.manual_seed(seed)
56 | torch.cuda.manual_seed(seed)
57 | torch.cuda.manual_seed_all(seed)
58 | torch.backends.cudnn.deterministic = True
59 | torch.backends.cudnn.benchmark = False
60 |
61 |
62 | ##################################################################################################
63 | def initialize_weights(m):
64 | if isinstance(m, torch.nn.Conv2d):
65 | torch.nn.init.xavier_uniform_(m.weight)
66 | if m.bias is not None:
67 | torch.nn.init.constant_(m.bias.data, 0)
68 | elif isinstance(m, torch.nn.BatchNorm2d):
69 | torch.nn.init.constant_(m.weight.data, 1)
70 | if m.bias is not None:
71 | torch.nn.init.constant_(m.bias.data, 0)
72 | elif isinstance(m, torch.nn.Linear):
73 | torch.nn.init.xavier_uniform_(m.weight)
74 | if m.bias is not None:
75 | torch.nn.init.constant_(m.bias.data, 0)
76 |
77 |
78 | ##################################################################################################
79 |
--------------------------------------------------------------------------------
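A short sketch of the two standalone helpers above, seed_everything and initialize_weights; the tiny model is a placeholder:

```python
import torch
from train_scripts.utils import seed_everything, initialize_weights

seed_everything(42)  # deterministic cudnn plus seeded python/numpy/torch RNGs

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
)
model.apply(initialize_weights)  # xavier for conv/linear, constant for batchnorm
```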