├── LEARN.md ├── .gitattributes ├── losses ├── __init__.py ├── focal_loss.py └── dice_loss.py ├── attentions ├── __init__.py ├── custom_attention.py └── cbam.py ├── CITATION.cff ├── models ├── __init__.py ├── unetplusplus.py ├── modules.py ├── dense_unet.py └── model.py ├── LICENSE ├── README.md ├── CONTRIBUTING.md ├── lookahead.py ├── dataset.py ├── .gitignore ├── train.py └── utils.py /LEARN.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .dice_loss import DiceBCELossLogitsLoss 2 | from .focal_loss import FocalLoss 3 | 4 | __all__ = ['DiceBCELossLogitsLoss', 'FocalLoss'] -------------------------------------------------------------------------------- /attentions/__init__.py: -------------------------------------------------------------------------------- 1 | from cbam import CBAM, ChannelGate, SpatialGate 2 | from custom_attention import SpatialAttention 3 | 4 | __all__ = ['CBAM', 'ChannelGate', 'SpatialGate', 'SpatialAttention'] -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | authors: 2 | - family-names: Ahmad 3 | given-names: Saeed 4 | message: "If you use this software, please cite it using these metadata." 5 | title: "Teeth Segmentation using PyTorch" 6 | -------------------------------------------------------------------------------- /losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['FocalLoss'] 6 | 7 | class FocalLoss(nn.Module): 8 | def __init__(self, alpha=0.25, gamma=2): 9 | super(FocalLoss, self).__init__() 10 | self.alpha = alpha 11 | self.gamma = gamma 12 | 13 | def forward(self, inputs, targets): 14 | bce_loss = F.binary_cross_entropy(inputs, targets.float()) 15 | loss = self.alpha * (1 - torch.exp(-bce_loss)) ** self.gamma * bce_loss 16 | return loss -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dense_unet import DenseUNet 2 | from .model import UNET, UNET_GN, Attention_UNET, CustomAttention_UNET, Inception_UNET, Inception_Attention_UNET, ResUNET, ResUNETPlus, ResUNET_with_CBAM, ResUNET_with_GN 3 | from .unetplusplus import NestedUNet as UNET_Plus 4 | from .modules import DoubleConv, DoubleConv_GN, Attention_block, InceptionBlock, ResNetBlock 5 | 6 | 7 | __all__ = ['UNET', 'UNET_GN', 'Attention_UNET', 'CustomAttention_UNET', 'Inception_UNET', 8 | 'Inception_Attention_UNET', 'ResUNET', 'ResUNETPlus', 'ResUNET_with_CBAM', 'ResUNET_with_GN', 9 | 'UNET_Plus', 'DenseUNet'] -------------------------------------------------------------------------------- /losses/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | __all__ = ['DiceBCELossLogitsLoss'] 7 | 8 | class 
DiceBCELossLogitsLoss(nn.Module): 9 | def __init__(self, weight=None, size_average=True): 10 | super(DiceBCELossLogitsLoss, self).__init__() 11 | 12 | def forward(self, inputs, targets, smooth=1): 13 | 14 | #comment this line out if your model already ends with a sigmoid or equivalent activation layer 15 | inputs = torch.sigmoid(inputs) 16 | 17 | #flatten label and prediction tensors 18 | inputs = inputs.view(-1) 19 | targets = targets.view(-1) 20 | 21 | intersection = (inputs * targets).sum() 22 | dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) 23 | BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') 24 | Dice_BCE = BCE + dice_loss 25 | 26 | return Dice_BCE -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Saeed Ahmad 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Teeth Segmentation Using UNet and Its Variants 2 | 3 | This repository contains the code for training and evaluating various UNet models for teeth segmentation. The models implemented include the original UNet, as well as some of its variants such as UNet++, ResUNet, and Attention UNet. 4 | 5 | ## Data 6 | The data used for this project is not publicly available, but you can request it by contacting me through the email address provided on the profile page. Once you have the data, make sure to update the paths accordingly. 7 | 8 | ## Usage 9 | Before running the code, modify the 'train.py' file and any other relevant files so that the paths and settings match your data and hardware.
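The settings you are most likely to change are the constants defined at the top of 'train.py'. The repository defaults are reproduced below for reference; point the directory paths at wherever you placed the data:

```python
import torch

# Key configuration constants from train.py (repository defaults)
LEARNING_RATE = 1e-4
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2
NUM_EPOCHS = 30
IMAGE_HEIGHT = 512   # resized from the original 1127
IMAGE_WIDTH = 768    # resized from the original 1991
TRAIN_IMG_DIR = "./train-val/train2018/"
TRAIN_MASK_DIR = "./train-val/masks/"
TEST_IMG_DIR = "./test/test2018/"
TEST_MASK_DIR = "./test/mask/"
MODEL_PATH = "./saved_models/customSpatialAttentionUnet2.pth"
```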
10 | 11 | To train the model, simply run the train.py script: 12 | 13 | ```bash 14 | python train.py 15 | ``` 16 | 17 | ## Results for the test data 18 | 19 | | UNet Variants | Test Accuracy | Test Dice Score | 20 | |----------|----------|----------| 21 | | Base UNet | 96.10 | 90.47 | 22 | | UNet with GN | 96.71 | 91.53 | 23 | | Attention UNet | 96.40 | 91.01 | 24 | | Spatial Attention UNet | 96.45 | 91.09 | 25 | | Inception UNet | 96.29 | 90.69 | 26 | | Residual UNet | 96.16 | 90.06 | 27 | | UNet++ | 96.11 | 90.33 | 28 | | Dense UNet with GN | 96.77 | 91.88 | 29 | | **Spatial Attention UNet2 ${\color{red}\^*}$** | **97.32** | **93.12** | 30 | 31 | ${\color{red}\*}$ increased the resolution from 256\*256 to 768\*512, reduced the batch size from 16 to 2, and used Group Normalization with the custom spatial attention module on the base UNet 32 | 33 | 34 | ${\color{red}Note}$ 35 | This project is solely for learning purposes; no standard practices are applied. Therefore, I am not claiming any state-of-the-art results. 36 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for considering contributing to our Teeth Segmentation project! We value your interest and efforts in enhancing the capabilities and quality of our software. 4 | 5 | To contribute, here's a step-by-step guide to get you started (a short worked example follows the steps): 6 | 7 | - **Fork the Repository:** Begin by forking the repository on GitHub. This creates a copy of the repository in your own GitHub account. 8 | 9 | - **Clone the Forked Repository:** Clone the forked repository to your local machine using the `git clone` command. This step allows you to work on the files locally. 10 | 11 | ```bash 12 | git clone [URL of your forked repository] 13 | ``` 14 | 15 | - **Create a New Branch:** Before making any changes, switch to a new branch using the `git checkout` command. You can name the branch anything you want, but it's recommended to use a descriptive name that reflects the changes you're going to make. 16 | 17 | ```bash 18 | git checkout -b [name_of_your_new_branch] 19 | ``` 20 | 21 | - **Make Changes:** Make the necessary changes or additions to the code in your local repository. Focus on making changes that are clear and address specific issues or enhancements. 22 | 23 | - **Test Your Changes:** After implementing your changes, test them thoroughly to ensure they work as intended and don't introduce any new issues. 24 | 25 | - **Commit Your Changes:** Commit your changes to the branch with a clear and detailed commit message. This helps us understand the changes you've made and why you made them. 26 | 27 | ```bash 28 | git add . 29 | git commit -m "your commit message" 30 | ``` 31 | 32 | - **Push Your Changes:** Push your changes from your local repository to the remote repository on GitHub. 33 | 34 | ```bash 35 | git push origin [name_of_your_new_branch] 36 | ``` 37 | 38 | - **Submit a Pull Request:** Once your changes are pushed, submit a pull request (PR) to the main repository. In your PR description, provide details about the changes and their purpose. Make sure to reference any related issues. 39 | 40 | - **Await Review:** Your PR will be reviewed by the maintainers of the project. Be responsive to any feedback or requests for changes.
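As a concrete, purely illustrative example (the branch name, file, and commit message below are placeholders), a small documentation fix would flow through the same steps like this:

```bash
git checkout -b fix-readme-typos                      # descriptive branch name
# ... edit the files ...
git add README.md                                     # stage only what you changed
git commit -m "Fix typos in the README usage section" # clear, specific message
git push origin fix-readme-typos                      # then open a PR against the main repository
```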
41 | -------------------------------------------------------------------------------- /lookahead.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import chain 3 | from torch.optim import Optimizer 4 | import torch 5 | import warnings 6 | 7 | class Lookahead(Optimizer): 8 | def __init__(self, optimizer, k=5, alpha=0.5): 9 | self.optimizer = optimizer 10 | self.k = k 11 | self.alpha = alpha 12 | self.param_groups = self.optimizer.param_groups 13 | self.state = defaultdict(dict) 14 | self.fast_state = self.optimizer.state 15 | for group in self.param_groups: 16 | group["counter"] = 0 17 | 18 | def update(self, group): 19 | for fast in group["params"]: 20 | param_state = self.state[fast] 21 | if "slow_param" not in param_state: 22 | param_state["slow_param"] = torch.zeros_like(fast.data) 23 | param_state["slow_param"].copy_(fast.data) 24 | slow = param_state["slow_param"] 25 | slow += (fast.data - slow) * self.alpha 26 | fast.data.copy_(slow) 27 | 28 | def update_lookahead(self): 29 | for group in self.param_groups: 30 | self.update(group) 31 | 32 | def step(self, closure=None): 33 | loss = self.optimizer.step(closure) 34 | for group in self.param_groups: 35 | if group["counter"] == 0: 36 | self.update(group) 37 | group["counter"] += 1 38 | if group["counter"] >= self.k: 39 | group["counter"] = 0 40 | return loss 41 | 42 | def state_dict(self): 43 | fast_state_dict = self.optimizer.state_dict() 44 | slow_state = { 45 | (id(k) if isinstance(k, torch.Tensor) else k): v 46 | for k, v in self.state.items() 47 | } 48 | fast_state = fast_state_dict["state"] 49 | param_groups = fast_state_dict["param_groups"] 50 | return { 51 | "fast_state": fast_state, 52 | "slow_state": slow_state, 53 | "param_groups": param_groups, 54 | } 55 | 56 | def load_state_dict(self, state_dict): 57 | slow_state_dict = { 58 | "state": state_dict["slow_state"], 59 | "param_groups": state_dict["param_groups"], 60 | } 61 | fast_state_dict = { 62 | "state": state_dict["fast_state"], 63 | "param_groups": state_dict["param_groups"], 64 | } 65 | super(Lookahead, self).load_state_dict(slow_state_dict) 66 | self.optimizer.load_state_dict(fast_state_dict) 67 | self.fast_state = self.optimizer.state 68 | 69 | def add_param_group(self, param_group): 70 | param_group["counter"] = 0 71 | self.optimizer.add_param_group(param_group) -------------------------------------------------------------------------------- /attentions/custom_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..models import Attention_block 5 | 6 | __all__ = ['SpatialAttention'] 7 | 8 | 9 | class ChannelPool(nn.Module): 10 | def forward(self, x): 11 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 12 | 13 | 14 | class spatial_block(nn.Module): 15 | def __init__(self, in_channels, out_channels): 16 | super(spatial_block, self).__init__() 17 | 18 | self.conv1 = nn.Sequential( 19 | nn.GroupNorm(num_groups=in_channels//8,num_channels=in_channels), 20 | nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, dilation=1), 21 | #nn.BatchNorm2d(out_channels), 22 | #nn.ReLU(inplace=True) 23 | ) 24 | self.conv2 = nn.Sequential( 25 | nn.GroupNorm(num_groups=in_channels//8,num_channels=in_channels), 26 | nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, dilation=1), 27 | 
#nn.BatchNorm2d(out_channels), 28 | #nn.ReLU(inplace=True) 29 | ) 30 | 31 | self.conv3 = nn.Sequential( 32 | nn.GroupNorm(num_groups=in_channels//8,num_channels=in_channels), 33 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1), 34 | #nn.BatchNorm2d(out_channels), 35 | #nn.ReLU(inplace=True) 36 | ) 37 | 38 | self.conv4 = nn.Sequential( 39 | nn.GroupNorm(num_groups=in_channels//8,num_channels=in_channels), 40 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1), 41 | #nn.BatchNorm2d(out_channels), 42 | #nn.ReLU(inplace=True) 43 | ) 44 | 45 | def forward(self, x): 46 | concat = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)] 47 | return torch.cat(concat, dim=1) 48 | 49 | class SpatialAttention(nn.Module): 50 | def __init__(self, in_channel, kernel_size = 7): 51 | super(SpatialAttention, self).__init__() 52 | 53 | self.x_pool = ChannelPool() 54 | self.g_pool = ChannelPool() 55 | 56 | self.x_block = spatial_block(in_channels=in_channel, out_channels=1) 57 | self.g_block = spatial_block(in_channels=in_channel, out_channels=1) 58 | 59 | self.scale_x = nn.Conv2d(6, 1,kernel_size=1, stride=1, padding=0, dilation=4) 60 | self.scale_g = nn.Conv2d(6, 1,kernel_size=1, stride=1, padding=0, dilation=4) 61 | 62 | 63 | def forward(self, x, g): 64 | x1 = self.x_pool(x) 65 | x2 = self.x_block(x) 66 | x_out = torch.cat((x1, x2), dim=1) 67 | 68 | g1 = self.g_pool(g) 69 | g2 = self.g_block(g) 70 | g_out = torch.cat((g1, g2), dim=1) 71 | 72 | scale_x = self.scale_x(x_out) 73 | scale_g = self.scale_g(g_out) 74 | 75 | output = torch.sigmoid(scale_x+scale_g) 76 | return x* output 77 | 78 | def test(): 79 | x = torch.randn((3, 64, 256, 256)) 80 | #attention = CBAM(gate_channels=64) 81 | #output = attention(x) 82 | output = SpatialAttention(in_channel=64) 83 | x2 = output(x, x) 84 | print(x2.shape) 85 | 86 | 87 | 88 | 89 | if __name__ == "__main__": 90 | test() -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import re 6 | from sklearn.model_selection import train_test_split 7 | 8 | class Teeth_Dataset(Dataset): 9 | def __init__(self, images_dir, masks_dir,data_dict, data_type, transform=None, target_transform=None): 10 | self.images_dir = images_dir 11 | self.masks_dir = masks_dir 12 | self.transfrom = transform 13 | self.target_transform = target_transform 14 | 15 | 16 | self.images = data_dict[f'{data_type}_images'] #os.listdir(images_dir) 17 | self.masks = data_dict[f'{data_type}_masks'] #os.listdir(masks_dir) 18 | 19 | 20 | def __len__(self): 21 | return len(self.images) 22 | 23 | 24 | def __getitem__(self, index): 25 | img_path = os.path.join(self.images_dir, str(self.images[index])) 26 | mask_path = os.path.join(self.masks_dir, str(self.masks[index])) 27 | 28 | image = Image.open(img_path) 29 | mask = Image.open(mask_path) 30 | 31 | if self.transfrom: 32 | image = self.transfrom(image) 33 | 34 | if self.target_transform: 35 | mask = self.target_transform(mask) 36 | 37 | return image, mask 38 | 39 | def split_category(images_path, masks_path): 40 | ''' 41 | This function read all the images and split the images category wise 42 | ''' 43 | a = os.listdir(images_path) 44 | b = os.listdir(masks_path) 45 | 46 | all_images = [image for image in a if image.endswith('.jpg')] 47 | all_masks = [mask for 
mask in b if mask.endswith('.bmp')] 48 | 49 | catogry_wise_images = [[] for _ in range(10)] 50 | catogry_wise_masks = [[] for _ in range(10)] 51 | 52 | for image, mask in zip(all_images, all_masks): 53 | image_cat = image.split('-')[0] 54 | image_cat = int((re.search(r"[0-9]+", image_cat)).group()) 55 | catogry_wise_images[image_cat-1].append(image) 56 | catogry_wise_masks[image_cat-1].append(mask) 57 | return catogry_wise_images, catogry_wise_masks 58 | 59 | 60 | def split_data(category_images, category_masks, test_train_ratio=0.7, train_valid_ratio=0.9): 61 | train_images, train_labels = [], [] 62 | validation_images, validation_labels = [], [] 63 | test_images, test_labels = [], [] 64 | 65 | for cat_images, cat_masks in zip(category_images, category_masks): 66 | train_valid_images, test_images_, train_valid_labels, test_labels_ = train_test_split(cat_images, cat_masks, test_size=test_train_ratio, shuffle=True, random_state=15) 67 | train_images_, valid_images_, train_labels_, valid_labels_ = train_test_split(train_valid_images, train_valid_labels, train_size=train_valid_ratio, shuffle=True,random_state=30) 68 | 69 | train_images.extend(train_images_) 70 | train_labels.extend(train_labels_) 71 | validation_images.extend(valid_images_) 72 | validation_labels.extend(valid_labels_) 73 | test_images.extend(test_images_) 74 | test_labels.extend(test_labels_) 75 | 76 | data_dict = { 77 | 'train_images': train_images, 78 | 'train_masks': train_labels, 79 | 'validation_images': validation_images, 80 | 'validation_masks': validation_labels, 81 | 'test_images': test_images, 82 | 'test_masks': test_labels, 83 | } 84 | 85 | return data_dict 86 | 87 | 88 | 89 | 90 | 91 | def test(): 92 | TEST_IMG_DIR = "./test/test2018/" 93 | TEST_MASK_DIR = "./test/mask/" 94 | cat_wise_images, cat_wise_masks = split_category(TEST_IMG_DIR, TEST_MASK_DIR) 95 | data_dict = split_data(cat_wise_images, cat_wise_images) 96 | print("hello") 97 | #ds = Teeth_Dataset("./train-val/train2018/", "./train-val/masks/") 98 | 99 | 100 | if __name__ == "__main__": 101 | test() 102 | 103 | 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | **/__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | 163 | ## directories 164 | test/ 165 | train-val/ 166 | saved_models/ 167 | runs/ 168 | 169 | 170 | ## all png and jpeg files 171 | *.jpeg 172 | *.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms as t 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import numpy as np 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | 9 | from utils import get_loaders, Fit, check_accuracy, plot_history, visualize_random_image 10 | from models import UNET, Attention_UNET, Inception_UNET, Inception_Attention_UNET, ResUNET, ResUNETPlus, ResUNET_with_GN, ResUNET_with_CBAM, UNET_GN, CustomAttention_UNET 11 | from models.unetplusplus import NestedUNet as UNET_Plus 12 | from dataset import split_data, split_category 13 | #from focal_loss import FocalLoss 14 | from lookahead import Lookahead 15 | from models.dense_unet import DenseUNet 16 | from losses import DiceBCELossLogitsLoss 17 | 18 | 19 | 20 | import os 21 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 22 | 23 | LEARNING_RATE = 1e-4 24 | DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" 25 | BATCH_SIZE = 2 26 | NUM_EPOCHS = 30 27 | IMAGE_HEIGHT = 512 #256 # 1127 originally 28 | IMAGE_WIDTH = 768 #256 # 1991 originally 29 | TRAIN_IMG_DIR = "./train-val/train2018/" 30 | TRAIN_MASK_DIR = "./train-val/masks/" 31 | TEST_IMG_DIR = "./test/test2018/" 32 | TEST_MASK_DIR = "./test/mask/" 33 | MODEL_PATH = "./saved_models/customSpatialAttentionUnet2.pth" 34 | 35 | 36 | 37 | def main(): 38 | 39 | 40 | ## transforms for train images 41 | train_images_transform = t.Compose( 42 | [ 43 | t.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)), 44 | t.ToTensor(), 45 | t.Normalize( 46 | mean = [0.477, 0.451, 0.411], 47 | std = [0.284, 0.280, 0.292], 48 | ), 49 | 50 | ] 51 | ) 52 | 53 | ## transforms for train masks 54 | train_masks_transform = t.Compose( 55 | [ 56 | t.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)), 57 | t.ToTensor(), 58 | ] 59 | ) 60 | 61 | ## transforms for test images and masks 62 | test_images_transform, test_masks_transform = train_images_transform, train_masks_transform 63 | 64 | ## splitting the data into train, validation and test subsets 65 | cat_wise_images, cat_wise_masks = split_category(TEST_IMG_DIR, TEST_MASK_DIR) 66 | data_dict = split_data(cat_wise_images, cat_wise_masks, test_train_ratio=0.7, train_valid_ratio=0.9) 67 | 68 | train_dl, validation_dl, test_dl = get_loaders( 69 | #train_dir= TRAIN_IMG_DIR, 70 | #train_maskdir= TRAIN_MASK_DIR, 71 | images_dir= TEST_IMG_DIR, 72 | masks_dir= TEST_MASK_DIR, 73 | batch_size= BATCH_SIZE, 74 | train_images_transform= train_images_transform, 75 | train_masks_transform= train_masks_transform, 76 | test_images_transform= test_images_transform, 77 | test_masks_transform= test_masks_transform, 78 | data_dict = data_dict, 79 | ) 80 | 81 | 82 | #loss_fn = nn.BCEWithLogitsLoss() 83 | loss_fn = DiceBCELossLogitsLoss() 84 | 85 | 86 | print("CustomAttentionwithGN_DL") 87 | writer = SummaryWriter("runs/CustomAttentionwithGN_DL3") # required by Fit below for TensorBoard logging 88 | model = CustomAttention_UNET(in_channels=3, out_channels=1) 89 | model.to(device=DEVICE) 90 | optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) 91 | #lookahead = Lookahead(optimizer, k=3, alpha=0.6) 92 | history = Fit(model=model, train_dl=train_dl, validation_dl=validation_dl, loss_fn=loss_fn, optimizer=optimizer, epochs=NUM_EPOCHS, device=DEVICE, writer=writer) 93 |
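# Note: Fit (defined in utils.py) logs per-epoch scalars to the SummaryWriter created above —
# training/validation loss, pixel accuracy and dice score — so the curves can be inspected with the
# standard TensorBoard CLI, e.g. `tensorboard --logdir runs`, during or after training.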
94 | torch.save(model.state_dict(),MODEL_PATH ) 95 | #model.load_state_dict(torch.load(MODEL_PATH)) 96 | 97 | visualize_random_image(model=model, loader=test_dl, device=DEVICE, threshold=0.85, width=IMAGE_WIDTH, height=IMAGE_HEIGHT) 98 | 99 | dict = check_accuracy(test_dl, model, device=DEVICE, threshold=0.85, test=True) 100 | 101 | print(f"\n\ntest_accuracy: {dict['accuracy']:.2f}") 102 | print(f"test dice score: {dict['dice_score']:.2f}") 103 | print(f"test precision: {dict['precision']:.2f}") 104 | print(f"test recall: {dict['recall']:.2f}") 105 | print(f"test specificity: {dict['specificity']:.2f}") 106 | print(f"test f1_score: {dict['f1_score']:.2f}") 107 | 108 | 109 | ### ploting graphs 110 | # plot_history(history) 111 | 112 | print("Completed") 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | 119 | 120 | -------------------------------------------------------------------------------- /attentions/cbam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['ChannelGate', 'SpatialGate', 'CBAM'] 6 | 7 | class BasicConv(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 9 | super(BasicConv, self).__init__() 10 | self.out_channels = out_planes 11 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 12 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 13 | self.relu = nn.ReLU() if relu else None 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | if self.bn is not None: 18 | x = self.bn(x) 19 | if self.relu is not None: 20 | x = self.relu(x) 21 | return x 22 | 23 | class Flatten(nn.Module): 24 | def forward(self, x): 25 | return x.view(x.size(0), -1) 26 | 27 | class ChannelGate(nn.Module): 28 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 29 | super(ChannelGate, self).__init__() 30 | self.gate_channels = gate_channels 31 | self.mlp = nn.Sequential( 32 | Flatten(), 33 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 34 | nn.ReLU(), 35 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 36 | ) 37 | self.pool_types = pool_types 38 | def forward(self, x): 39 | channel_att_sum = None 40 | for pool_type in self.pool_types: 41 | if pool_type=='avg': 42 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 43 | channel_att_raw = self.mlp( avg_pool ) 44 | elif pool_type=='max': 45 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 46 | channel_att_raw = self.mlp( max_pool ) 47 | elif pool_type=='lp': 48 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 49 | channel_att_raw = self.mlp( lp_pool ) 50 | elif pool_type=='lse': 51 | # LSE pool only 52 | lse_pool = logsumexp_2d(x) 53 | channel_att_raw = self.mlp( lse_pool ) 54 | 55 | if channel_att_sum is None: 56 | channel_att_sum = channel_att_raw 57 | else: 58 | channel_att_sum = channel_att_sum + channel_att_raw 59 | 60 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 61 | return x * scale 62 | 63 | def logsumexp_2d(tensor): 64 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 65 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 66 | outputs = s + 
(tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 67 | return outputs 68 | 69 | class ChannelPool(nn.Module): 70 | def forward(self, x): 71 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 72 | 73 | class SpatialGate(nn.Module): 74 | def __init__(self): 75 | super(SpatialGate, self).__init__() 76 | kernel_size = 7 77 | self.compress = ChannelPool() 78 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 79 | def forward(self, x): 80 | x_compress = self.compress(x) 81 | x_out = self.spatial(x_compress) 82 | scale = torch.sigmoid(x_out) # broadcasting 83 | return x * scale 84 | 85 | class CBAM(nn.Module): 86 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 87 | super(CBAM, self).__init__() 88 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 89 | self.no_spatial=no_spatial 90 | if not no_spatial: 91 | self.SpatialGate = SpatialGate() 92 | def forward(self, x): 93 | x_out = self.ChannelGate(x) 94 | if not self.no_spatial: 95 | x_out = self.SpatialGate(x_out) 96 | return x_out 97 | 98 | def test(): 99 | x = torch.randn((3, 64, 161, 161)) 100 | attention = CBAM(gate_channels=64) 101 | output = attention(x) 102 | print(output.shape) 103 | 104 | 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | test() -------------------------------------------------------------------------------- /models/unetplusplus.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | import torch 4 | from torchvision import models 5 | import torchvision 6 | 7 | __all__ = ['NestedUNet'] 8 | 9 | 10 | class DoubleConv(nn.Module): 11 | def __init__(self, in_ch, out_ch): 12 | super(DoubleConv, self).__init__() 13 | self.conv = nn.Sequential( 14 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 15 | nn.BatchNorm2d(out_ch), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 18 | nn.BatchNorm2d(out_ch), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | def forward(self, input): 23 | return self.conv(input) 24 | 25 | # class VGGBlock(nn.Module): 26 | # def __init__(self, in_channels, middle_channels, out_channels, act_func=nn.ReLU(inplace=True)): 27 | # super(VGGBlock, self).__init__() 28 | # self.act_func = act_func 29 | # self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) 30 | # self.bn1 = nn.BatchNorm2d(middle_channels) 31 | # self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) 32 | # self.bn2 = nn.BatchNorm2d(out_channels) 33 | # 34 | # def forward(self, x): 35 | # out = self.conv1(x) 36 | # out = self.bn1(out) 37 | # out = self.act_func(out) 38 | # 39 | # out = self.conv2(out) 40 | # out = self.bn2(out) 41 | # out = self.act_func(out) 42 | # return out 43 | 44 | class NestedUNet(nn.Module): 45 | def __init__(self, in_channel,out_channel, deepsupervision=True): 46 | super().__init__() 47 | 48 | self.deepsupervision = deepsupervision 49 | 50 | nb_filter = [32, 64, 128, 256, 512] 51 | 52 | self.pool = nn.MaxPool2d(2, 2) 53 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 54 | 55 | self.conv0_0 = DoubleConv(in_channel, nb_filter[0]) 56 | self.conv1_0 = DoubleConv(nb_filter[0], nb_filter[1]) 57 | self.conv2_0 = DoubleConv(nb_filter[1], nb_filter[2]) 58 | self.conv3_0 = DoubleConv(nb_filter[2], nb_filter[3]) 59 | self.conv4_0 = DoubleConv(nb_filter[3], nb_filter[4]) 60 | 61 | self.conv0_1 
= DoubleConv(nb_filter[0]+nb_filter[1], nb_filter[0]) 62 | self.conv1_1 = DoubleConv(nb_filter[1]+nb_filter[2], nb_filter[1]) 63 | self.conv2_1 = DoubleConv(nb_filter[2]+nb_filter[3], nb_filter[2]) 64 | self.conv3_1 = DoubleConv(nb_filter[3]+nb_filter[4], nb_filter[3]) 65 | 66 | self.conv0_2 = DoubleConv(nb_filter[0]*2+nb_filter[1], nb_filter[0]) 67 | self.conv1_2 = DoubleConv(nb_filter[1]*2+nb_filter[2], nb_filter[1]) 68 | self.conv2_2 = DoubleConv(nb_filter[2]*2+nb_filter[3], nb_filter[2]) 69 | 70 | self.conv0_3 = DoubleConv(nb_filter[0]*3+nb_filter[1], nb_filter[0]) 71 | self.conv1_3 = DoubleConv(nb_filter[1]*3+nb_filter[2], nb_filter[1]) 72 | 73 | self.conv0_4 = DoubleConv(nb_filter[0]*4+nb_filter[1], nb_filter[0]) 74 | #self.sigmoid = nn.Sigmoid() 75 | 76 | if self.deepsupervision: 77 | self.final1 = nn.Conv2d(nb_filter[0], out_channel, kernel_size=1) 78 | self.final2 = nn.Conv2d(nb_filter[0], out_channel, kernel_size=1) 79 | self.final3 = nn.Conv2d(nb_filter[0], out_channel, kernel_size=1) 80 | self.final4 = nn.Conv2d(nb_filter[0], out_channel, kernel_size=1) 81 | else: 82 | self.final = nn.Conv2d(nb_filter[0], out_channel, kernel_size=1) 83 | 84 | 85 | def forward(self, input): 86 | x0_0 = self.conv0_0(input) 87 | x1_0 = self.conv1_0(self.pool(x0_0)) 88 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 89 | 90 | x2_0 = self.conv2_0(self.pool(x1_0)) 91 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 92 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 93 | 94 | x3_0 = self.conv3_0(self.pool(x2_0)) 95 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) 96 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) 97 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 98 | 99 | x4_0 = self.conv4_0(self.pool(x3_0)) 100 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 101 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) 102 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1)) 103 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) 104 | 105 | if self.deepsupervision: 106 | output1 = self.final1(x0_1) 107 | #output1 = self.sigmoid(output1) 108 | output2 = self.final2(x0_2) 109 | #output2 = self.sigmoid(output2) 110 | output3 = self.final3(x0_3) 111 | #output3 = self.sigmoid(output3) 112 | output4 = self.final4(x0_4) 113 | #output4 = self.sigmoid(output4) 114 | return [output1, output2, output3, output4] 115 | 116 | else: 117 | output = self.final(x0_4) 118 | #output = self.sigmoid(output) 119 | return output 120 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchviz import make_dot 4 | 5 | __all__ = ['DoubleConv', 'DoubleConv_GN', 'Attention_block', 'InceptionBlock', 'ResNetBlock'] 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | 10 | def __init__(self, in_channels, out_channels): 11 | super(DoubleConv, self).__init__() 12 | self.conv = nn.Sequential( 13 | #first convolution 14 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), 15 | nn.BatchNorm2d(out_channels), 16 | #nn.Dropout2d(0.05, inplace=True), 17 | nn.ReLU(inplace=True), 18 | 19 | #2nd convolution 20 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), 21 | nn.BatchNorm2d(out_channels), 22 | #nn.Dropout2d(0.10, inplace=True), 23 | 
nn.ReLU(inplace=True), 24 | ) 25 | 26 | 27 | def forward(self, x): 28 | return self.conv(x) 29 | 30 | class DoubleConv_GN(nn.Module): 31 | 32 | def __init__(self, in_channels, out_channels): 33 | super(DoubleConv_GN, self).__init__() 34 | self.conv = nn.Sequential( 35 | #first convolution 36 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), 37 | nn.GroupNorm(out_channels//8, out_channels), 38 | nn.ReLU(inplace=True), 39 | 40 | #2nd convolution 41 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), 42 | nn.GroupNorm(out_channels//8, out_channels), 43 | nn.ReLU(inplace=True), 44 | ) 45 | 46 | 47 | def forward(self, x): 48 | return self.conv(x) 49 | 50 | 51 | class Attention_block(nn.Module): 52 | def __init__(self,F_g,F_l,F_int): 53 | super(Attention_block,self).__init__() 54 | self.W_g = nn.Sequential( 55 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 56 | nn.BatchNorm2d(F_int) 57 | ) 58 | 59 | self.W_x = nn.Sequential( 60 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 61 | nn.BatchNorm2d(F_int) 62 | ) 63 | 64 | self.psi = nn.Sequential( 65 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 66 | nn.BatchNorm2d(1), 67 | nn.Sigmoid() 68 | ) 69 | 70 | self.relu = nn.ReLU(inplace=True) 71 | 72 | def forward(self,g,x): 73 | g1 = self.W_g(g) 74 | x1 = self.W_x(x) 75 | psi = self.relu(g1+x1) 76 | psi = self.psi(psi) 77 | 78 | return x*psi 79 | 80 | 81 | 82 | class InceptionBlock(nn.Module): 83 | 84 | def __init__(self, in_channels, out_channels, mid_channels=None): 85 | super(InceptionBlock, self).__init__() 86 | 87 | if not mid_channels: 88 | mid_channels = out_channels 89 | 90 | self.double_conv1 = nn.Sequential( 91 | nn.MaxPool2d(2), 92 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 93 | nn.BatchNorm2d(mid_channels), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 96 | nn.BatchNorm2d(out_channels), 97 | nn.ReLU(inplace=True) 98 | ) 99 | 100 | self.double_conv2 = nn.Sequential( 101 | nn.MaxPool2d(2), 102 | nn.Conv2d(in_channels, mid_channels, kernel_size=5, padding=2), 103 | nn.BatchNorm2d(mid_channels), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(mid_channels, out_channels, kernel_size=5, padding=2), 106 | nn.BatchNorm2d(out_channels), 107 | nn.ReLU(inplace=True) 108 | ) 109 | 110 | self.double_conv3 = nn.Sequential( 111 | nn.MaxPool2d(2), 112 | nn.Conv2d(in_channels, mid_channels, kernel_size=1, padding=0), 113 | nn.BatchNorm2d(mid_channels), 114 | nn.ReLU(inplace=True), 115 | ) 116 | 117 | self.double_conv4 = nn.Sequential( 118 | nn.MaxPool2d(2), 119 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 120 | nn.BatchNorm2d(mid_channels), 121 | nn.ReLU(inplace=True), 122 | nn.Conv2d(mid_channels, out_channels, kernel_size=1, padding=0), 123 | nn.BatchNorm2d(out_channels), 124 | nn.ReLU(inplace=True) 125 | ) 126 | 127 | #self.upConv = nn.ConvTranspose2d(out_channels*4, out_channels*4, kernel_size=2, stride=2) 128 | 129 | def forward(self, x): 130 | 131 | concat = [self.double_conv1(x), self.double_conv2(x), self.double_conv3(x), self.double_conv4(x)] 132 | output = torch.cat(concat, dim=1) 133 | 134 | return output #self.upConv(output) 135 | 136 | 137 | class ResNetBlock(nn.Module): 138 | def __init__(self, in_channels, out_channels, stride=2, padding=1): 139 | super(ResNetBlock, self).__init__() 140 | 141 | self.conv_block = nn.Sequential( 142 | nn.BatchNorm2d(in_channels), 143 | 
#nn.GroupNorm(num_groups=in_channels//8,num_channels=in_channels), 144 | #nn.Dropout2d(0.05, inplace=True), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding), 147 | nn.BatchNorm2d(out_channels), 148 | #nn.GroupNorm(num_groups=out_channels//8,num_channels=out_channels), 149 | #nn.Dropout2d(0.10, inplace=True), 150 | nn.ReLU(inplace=True), 151 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 152 | ) 153 | 154 | self.conv_skip = nn.Sequential( 155 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding), 156 | nn.BatchNorm2d(out_channels), 157 | ) 158 | 159 | def forward(self, x): 160 | return self.conv_block(x) + self.conv_skip(x) 161 | 162 | -------------------------------------------------------------------------------- /models/dense_unet.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import DenseNet 2 | from torchvision.models.densenet import _Transition, _load_state_dict 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from collections import OrderedDict 7 | 8 | 9 | 10 | __all__ = ['DenseUNet'] 11 | 12 | 13 | class _DenseUNetEncoder(DenseNet): 14 | def __init__(self, skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, downsample): 15 | super(_DenseUNetEncoder, self).__init__(growth_rate, block_config, num_init_features, bn_size, drop_rate) 16 | 17 | self.skip_connections = skip_connections 18 | 19 | # remove last norm, classifier 20 | features = OrderedDict(list(self.features.named_children())[:-1]) 21 | delattr(self, 'classifier') 22 | if not downsample: 23 | features['conv0'].stride = 1 24 | del features['pool0'] 25 | self.features = nn.Sequential(features) 26 | 27 | for module in self.features.modules(): 28 | if isinstance(module, nn.AvgPool2d): 29 | module.register_forward_hook(lambda _, input, output : self.skip_connections.append(input[0])) 30 | 31 | def forward(self, x): 32 | return self.features(x) 33 | 34 | class _DenseUNetDecoder(DenseNet): 35 | def __init__(self, skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, upsample): 36 | super(_DenseUNetDecoder, self).__init__(growth_rate, block_config, num_init_features, bn_size, drop_rate) 37 | 38 | self.skip_connections = skip_connections 39 | self.upsample = upsample 40 | 41 | # remove conv0, norm0, relu0, pool0, last denseblock, last norm, classifier 42 | features = list(self.features.named_children())[4:-2] 43 | delattr(self, 'classifier') 44 | 45 | num_features = num_init_features 46 | num_features_list = [] 47 | for i, num_layers in enumerate(block_config): 48 | num_input_features = num_features + num_layers * growth_rate 49 | num_output_features = num_features // 2 50 | num_features_list.append((num_input_features, num_output_features)) 51 | num_features = num_input_features // 2 52 | 53 | for i in range(len(features)): 54 | name, module = features[i] 55 | if isinstance(module, _Transition): 56 | num_input_features, num_output_features = num_features_list.pop(1) 57 | features[i] = (name, _TransitionUp(num_input_features, num_output_features, skip_connections)) 58 | 59 | features.reverse() 60 | 61 | self.features = nn.Sequential(OrderedDict(features)) 62 | 63 | num_input_features, _ = num_features_list.pop(0) 64 | 65 | if upsample: 66 | self.features.add_module('upsample0', nn.Upsample(scale_factor=4, mode='bilinear')) 67 | # 
self.features.add_module('norm0', nn.BatchNorm2d(num_input_features)) 68 | self.add_module("group_norm0", nn.GroupNorm(num_groups= (num_input_features)//4,num_channels=num_input_features)) 69 | self.features.add_module('relu0', nn.ReLU(inplace=True)) 70 | self.features.add_module('conv0', nn.Conv2d(num_input_features, num_init_features, kernel_size=1, stride=1, bias=False)) 71 | #self.features.add_module('norm1', nn.BatchNorm2d(num_init_features)) 72 | self.add_module("group_norm2", nn.GroupNorm(num_groups= (num_init_features)//8,num_channels=num_init_features)) 73 | 74 | def forward(self, x): 75 | return self.features(x) 76 | 77 | 78 | class _Concatenate(nn.Module): 79 | def __init__(self, skip_connections): 80 | super(_Concatenate, self).__init__() 81 | self.skip_connections = skip_connections 82 | 83 | def forward(self, x): 84 | return torch.cat([x, self.skip_connections.pop()], 1) 85 | 86 | 87 | class _TransitionUp(nn.Sequential): 88 | def __init__(self, num_input_features, num_output_features, skip_connections): 89 | super(_TransitionUp, self).__init__() 90 | 91 | #self.add_module('norm1', nn.BatchNorm2d(num_input_features)) 92 | self.add_module("group_norm1", nn.GroupNorm(num_groups= (num_input_features)//4,num_channels=num_input_features)) 93 | self.add_module('relu1', nn.ReLU(inplace=True)) 94 | self.add_module('conv1', nn.Conv2d(num_input_features, num_output_features * 2, 95 | kernel_size=1, stride=1, bias=False)) 96 | 97 | self.add_module('upsample', nn.Upsample(scale_factor=2, mode='bilinear')) 98 | self.add_module('cat', _Concatenate(skip_connections)) 99 | #self.add_module('norm2', nn.BatchNorm2d(num_output_features * 4)) 100 | self.add_module("group_norm2", nn.GroupNorm(num_groups= (num_output_features * 4)//8,num_channels=num_output_features * 4)) 101 | self.add_module('relu2', nn.ReLU(inplace=True)) 102 | self.add_module('conv2', nn.Conv2d(num_output_features * 4, num_output_features, 103 | kernel_size=1, stride=1, bias=False)) 104 | 105 | class DenseUNet(nn.Module): 106 | def __init__(self, n_classes=1, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, downsample=False, pretrained_encoder_uri=None, progress=None): 107 | super(DenseUNet, self).__init__() 108 | self.skip_connections = [] 109 | self.encoder = _DenseUNetEncoder(self.skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, downsample) 110 | self.decoder = _DenseUNetDecoder(self.skip_connections, growth_rate, block_config, num_init_features, bn_size, drop_rate, downsample) 111 | self.output_layer = nn.Conv2d(num_init_features, n_classes, kernel_size=1, stride=1, bias=True) 112 | #self.softmax = nn.Softmax(dim=1) 113 | 114 | self.encoder._load_state_dict = self.encoder.load_state_dict 115 | self.encoder.load_state_dict = lambda state_dict : self.encoder._load_state_dict(state_dict, strict=False) 116 | if pretrained_encoder_uri: 117 | _load_state_dict(self.encoder, str(pretrained_encoder_uri), progress) 118 | self.encoder.load_state_dict = lambda state_dict : self.encoder._load_state_dict(state_dict, strict=True) 119 | 120 | def forward(self, x): 121 | x = self.encoder(x) 122 | x = self.decoder(x) 123 | #y = self.classifier(x) 124 | return self.output_layer(x) #self.softmax(y) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset import Teeth_Dataset 3 | from torch.utils.data import DataLoader 4 
| import matplotlib.pyplot as plt 5 | import numpy as np 6 | from torchvision import transforms as t 7 | 8 | def get_loaders( 9 | images_dir, 10 | masks_dir, 11 | batch_size, 12 | train_images_transform, 13 | train_masks_transform, 14 | test_images_transform, 15 | test_masks_transform, 16 | data_dict, 17 | ): 18 | 19 | train_ds = Teeth_Dataset( 20 | images_dir = images_dir, 21 | masks_dir = masks_dir, 22 | data_dict=data_dict, 23 | data_type='train', 24 | transform = train_images_transform, 25 | target_transform = train_masks_transform) 26 | 27 | validation_ds = Teeth_Dataset( 28 | images_dir = images_dir, 29 | masks_dir = masks_dir, 30 | data_dict=data_dict, 31 | data_type='validation', 32 | transform = test_images_transform, 33 | target_transform = test_masks_transform) 34 | 35 | test_ds = Teeth_Dataset( 36 | images_dir = images_dir, 37 | masks_dir = masks_dir, 38 | data_dict=data_dict, 39 | data_type='test', 40 | transform = test_images_transform, 41 | target_transform = test_masks_transform) 42 | 43 | train_dl = DataLoader( 44 | dataset = train_ds, 45 | batch_size = batch_size, 46 | shuffle = True, 47 | ) 48 | 49 | test_dl = DataLoader( 50 | dataset = test_ds, 51 | batch_size = batch_size, 52 | shuffle = False, 53 | ) 54 | validation_dl = DataLoader( 55 | dataset = validation_ds, 56 | batch_size = batch_size, 57 | shuffle = False, 58 | ) 59 | 60 | return train_dl, validation_dl, test_dl 61 | 62 | 63 | def evaluate(preds, targets): 64 | """ 65 | Returns specificty, precision, recall and f1_score 66 | 67 | """ 68 | 69 | confusion_vector = preds / targets 70 | # Element-wise division of the 2 tensors returns a new tensor which holds a 71 | # unique value for each case: 72 | # 1 where prediction and truth are 1 (True Positive) 73 | # inf where prediction is 1 and truth is 0 (False Positive) 74 | # nan where prediction and truth are 0 (True Negative) 75 | # 0 where prediction is 0 and truth is 1 (False Negative) 76 | 77 | true_positives = torch.sum(confusion_vector == 1).item() 78 | false_positives = torch.sum(confusion_vector == float('inf')).item() 79 | true_negatives = torch.sum(torch.isnan(confusion_vector)).item() 80 | false_negatives = torch.sum(confusion_vector == 0).item() 81 | 82 | ### precision, recall, f1_score and specificity 83 | specificity = true_negatives / (true_negatives + false_positives) 84 | precision = true_positives / (true_positives + false_positives) 85 | recall = true_positives / (true_positives + false_negatives) 86 | f1_score = (2.0 * (recall*precision)) / (recall + precision) 87 | 88 | dict = { 89 | 'specificity': specificity, 90 | 'precision': precision, 91 | 'recall': recall, 92 | 'f1_score': f1_score 93 | } 94 | 95 | return dict 96 | 97 | def dice_coeff(pred, target): 98 | smooth = 1. 99 | num = pred.size(0) 100 | m1 = pred.view(num, -1) # Flatten 101 | m2 = target.view(num, -1) # Flatten 102 | intersection = (m1 * m2).sum() 103 | 104 | return (2. 
* intersection + smooth) / (m1.sum() + m2.sum() + smooth) 105 | 106 | def check_accuracy(loader, model, device="cuda", threshold=0.5, test=False): 107 | num_correct = 0 108 | num_pixels = 0 109 | dice_score = 0 110 | model.eval() 111 | if test: 112 | f1_score, precision, recall, specificity = 0.0, 0.0 , 0.0 , 0.0 113 | 114 | with torch.no_grad(): 115 | for _, (x, y) in enumerate(loader): 116 | 117 | x = x.to(device) 118 | y = y.to(device) #.unsqueeze(1) 119 | 120 | ## for unet plus plus 121 | preds = torch.sigmoid((model(x))) 122 | 123 | # for unet 124 | 125 | preds = (preds > threshold).float() 126 | num_correct += (preds == y).sum() 127 | num_pixels += torch.numel(preds) 128 | dice_score += dice_coeff(preds, y) 129 | 130 | if test: 131 | temp_dict = evaluate(preds, y) 132 | f1_score += temp_dict['f1_score'] 133 | precision += temp_dict['precision'] 134 | recall += temp_dict['recall'] 135 | specificity += temp_dict['specificity'] 136 | 137 | 138 | 139 | 140 | 141 | accuracy = num_correct/num_pixels*100 142 | dice_score = (dice_score/(len(loader)))*100 143 | 144 | accuracy, dice_score = accuracy.detach().cpu().item() , dice_score.detach().cpu().item() 145 | 146 | if test: 147 | f1_score = (f1_score/(len(loader)))*100 148 | precision = (precision/(len(loader)))*100 149 | recall = (recall/(len(loader)))*100 150 | specificity = (specificity/(len(loader)))*100 151 | 152 | dict = { 153 | 'specificity': specificity, 154 | 'precision': precision, 155 | 'recall': recall, 156 | 'f1_score': f1_score, 157 | 'accuracy':accuracy, 158 | 'dice_score': dice_score, 159 | } 160 | return dict 161 | 162 | else: 163 | print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}" ) 164 | print(f"Dice score: {dice_score :.2f}") 165 | 166 | return accuracy, dice_score 167 | 168 | 169 | def validation_loss(model, validation_dl, loss_fn, device): 170 | total_loss = 0.0 171 | 172 | for x, y in validation_dl: 173 | x, y = x.to(device), y.to(device) 174 | 175 | 176 | preds = model(x) 177 | loss = loss_fn(preds, y) 178 | 179 | total_loss += loss.detach().cpu().item() 180 | 181 | return total_loss/len(validation_dl) 182 | 183 | 184 | def train_fn(train_dl, model, optimizer, loss_fn, device): 185 | mean_loss = 0 186 | 187 | for _, (data, targets) in enumerate(train_dl): 188 | 189 | data = data.to(device=device) 190 | targets = targets.to(device=device) 191 | 192 | optimizer.zero_grad() 193 | # forward 194 | predictions = model(data) 195 | 196 | loss = loss_fn(predictions, targets) 197 | # backward 198 | loss.backward() 199 | optimizer.step() 200 | mean_loss += loss.detach().cpu().item() 201 | 202 | return mean_loss/(len(train_dl)) 203 | 204 | 205 | def Fit(model, train_dl, validation_dl, loss_fn, optimizer, epochs, device, writer): 206 | train_accuracies = [] 207 | validation_accuracies = [] 208 | train_dice_scores = [] 209 | validation_dice_scores = [] 210 | train_losses = [] 211 | validation_losses = [] 212 | 213 | 214 | print("Training started ::: **************** ") 215 | for epoch in range(epochs): 216 | print("\nEpoch: ", epoch) 217 | train_loss = train_fn( 218 | train_dl=train_dl, 219 | model=model, 220 | optimizer=optimizer, 221 | loss_fn=loss_fn, 222 | device=device, 223 | ) 224 | 225 | ## Training accuracy 226 | print("\nResults for Training data: ") 227 | train_accuracy, train_ds = check_accuracy( 228 | loader=train_dl, 229 | model=model, 230 | device=device, 231 | threshold=0.5, 232 | ) 233 | 234 | ## Validation accuracy 235 | print("\nResults for Validation data: ") 236 | validation_accuracy, validation_ds 
= check_accuracy( 237 | loader=validation_dl, 238 | model=model, 239 | device=device, 240 | threshold=0.5, 241 | ) 242 | 243 | validation_loss_ = validation_loss(model, validation_dl, loss_fn, device) 244 | 245 | writer.add_scalar('Training Loss', train_loss, epoch) 246 | writer.add_scalar('Validation Loss', validation_loss_, epoch) 247 | writer.add_scalar('Training Accuracy', train_accuracy, epoch) 248 | writer.add_scalar('Validation Accuracy', validation_accuracy, epoch) 249 | writer.add_scalar('Training Dice Score', train_ds, epoch) 250 | writer.add_scalar('Validation Dice Score', validation_ds, epoch) 251 | 252 | 253 | train_accuracies.append(train_accuracy) 254 | validation_accuracies.append(validation_accuracy) 255 | 256 | train_dice_scores.append(train_ds) 257 | validation_dice_scores.append(validation_ds) 258 | 259 | train_losses.append(train_loss) 260 | validation_losses.append(validation_loss_) 261 | 262 | history = { 263 | 'model': model, 264 | 'epochs': epochs, 265 | 'train_losses':train_losses, 266 | 'validation_losses': validation_losses, 267 | 'train_accuracies': train_accuracies, 268 | 'train_dice_scores':train_dice_scores, 269 | 'validation_accuracies': validation_accuracies, 270 | 'validation_dice_scores': validation_dice_scores 271 | } 272 | 273 | print("Done") 274 | 275 | return history 276 | 277 | def plot_graph(x, y1, y2, x_label, y_label, title): 278 | 279 | plt.title(title) 280 | plt.plot(x, y1, '-b', label='train') 281 | plt.plot(x, y2, '-r', label='validation') 282 | plt.xlabel(x_label) 283 | plt.legend() 284 | #plt.ylabel(y_label) 285 | plt.savefig(f'{title}.png') 286 | plt.show() 287 | 288 | 289 | 290 | def plot_history(history): 291 | epochs_list = np.arange(0, history['epochs'], 1).tolist() 292 | 293 | plot_graph( 294 | x = epochs_list, 295 | y1 = history['train_losses'], 296 | y2 = history['validation_losses'], 297 | x_label= "n iterations", 298 | y_label= "losses", 299 | title= "Iteration vs losses", 300 | ) 301 | 302 | plot_graph( 303 | x = epochs_list, 304 | y1 = history['train_accuracies'], 305 | y2 = history['validation_accuracies'], 306 | x_label= "n iterations", 307 | y_label= "accuracies", 308 | title= "Iteration vs accuracies", 309 | ) 310 | 311 | plot_graph( 312 | x = epochs_list, 313 | y1 = history['train_dice_scores'], 314 | y2 = history['validation_dice_scores'], 315 | x_label= "n iterations", 316 | y_label= "dice scores", 317 | title= "Iteration vs dice scores", 318 | ) 319 | 320 | 321 | def visualize_random_image(model, loader, device, threshold, width, height): 322 | 323 | rand_batch = torch.randint(0, len(loader), (1,)).item() 324 | 325 | for batch, (x, y) in enumerate(loader): 326 | 327 | if batch == rand_batch: 328 | x = x.to(device) 329 | y = y.to(device) #.unsqueeze(1) 330 | 331 | preds = torch.sigmoid((model(x))) 332 | 333 | preds = (preds > threshold).float() * 255.0 334 | y = y * 255.0 335 | 336 | 337 | preds = preds[0].view(height, width) 338 | y = y[0].view(height, width) 339 | 340 | y, preds = y.detach().cpu(), preds.detach().cpu() 341 | 342 | 343 | figure = plt.figure(figsize=(4,4)) 344 | plt.title(f'test image plot batch size {rand_batch}, first sample. 
(orignal, predictions)') 345 | figure.add_subplot(1,2, 1) 346 | plt.imshow(y) 347 | figure.add_subplot(1,2, 2) 348 | plt.imshow(preds) 349 | 350 | plt.savefig(f'batch_{rand_batch}_sample_0 (orignal, predictions).png') 351 | plt.show() 352 | 353 | 354 | 355 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms.functional as TF 4 | from torchviz import make_dot 5 | 6 | from .modules import DoubleConv, Attention_block, InceptionBlock, ResNetBlock, DoubleConv_GN 7 | from ..attentions import CBAM 8 | from ..attentions import SpatialAttention 9 | 10 | __all__ = ['UNET', 'UNET_GN', 'Attention_UNET', 'CustomAttention_UNET', 'Inception_UNET', 'Inception_Attention_UNET', 'ResUNET', 'ResUNETPlus', 11 | 'ResUNET_with_CBAM', 'ResUNET_with_GN'] 12 | 13 | class UNET(nn.Module): 14 | def __init__( 15 | self, in_channels=3, out_channels=1, features=[64, 128, 256, 512], 16 | ): 17 | super(UNET, self).__init__() 18 | #module list for encoder layers 19 | self.downs = nn.ModuleList() 20 | #max pooling 21 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 22 | #module list for the decoder layers 23 | self.ups = nn.ModuleList() 24 | 25 | # Down part of UNET 26 | for feature in features: 27 | self.downs.append(DoubleConv(in_channels, feature)) 28 | in_channels = feature 29 | 30 | # Up part of UNET 31 | for feature in reversed(features): 32 | ### up convolution 33 | self.ups.append( 34 | nn.ConvTranspose2d( 35 | feature*2, feature, kernel_size=2, stride=2, 36 | ) 37 | ) 38 | 39 | ### double convolution 40 | self.ups.append(DoubleConv(feature*2, feature)) 41 | 42 | self.bottleneck = DoubleConv(features[-1], features[-1]*2) 43 | self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) 44 | 45 | def forward(self, x): 46 | skip_connections = [] 47 | 48 | for down in self.downs: 49 | x = down(x) 50 | skip_connections.append(x) 51 | x = self.pool(x) 52 | 53 | x = self.bottleneck(x) 54 | #reversing the list of the skip connections 55 | skip_connections = skip_connections[::-1] 56 | 57 | for idx in range(0, len(self.ups), 2): 58 | x = self.ups[idx](x) 59 | skip_connection = skip_connections[idx//2] 60 | 61 | #if the resolution of the input image is not completely devisible, then it will skip the reminder 62 | # and the resolution will not be equal in this case, so we are resizing it incase in they are not equal 63 | if x.shape != skip_connection.shape: 64 | x = TF.resize(x, size=skip_connection.shape[2:]) 65 | 66 | 67 | #concatenating the skip connections with x 68 | concat_skip = torch.cat((skip_connection, x), dim=1) 69 | 70 | #passing the concatenated ouptut, to the double convolutional layers 71 | x = self.ups[idx+1](concat_skip) 72 | 73 | return self.final_conv(x) 74 | 75 | class UNET_GN(nn.Module): 76 | def __init__( 77 | self, in_channels=3, out_channels=1, features=[64, 128, 256, 512], 78 | ): 79 | super(UNET_GN, self).__init__() 80 | #module list for encoder layers 81 | self.downs = nn.ModuleList() 82 | #max pooling 83 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 84 | #module list for the decoder layers 85 | self.ups = nn.ModuleList() 86 | 87 | # Down part of UNET 88 | for feature in features: 89 | self.downs.append(DoubleConv_GN(in_channels, feature)) 90 | in_channels = feature 91 | 92 | # Up part of UNET 93 | for feature in reversed(features): 94 | ### up convolution 95 | self.ups.append( 
96 | nn.ConvTranspose2d( 97 | feature*2, feature, kernel_size=2, stride=2, 98 | ) 99 | ) 100 | 101 | ### double convolution 102 | self.ups.append(DoubleConv_GN(feature*2, feature)) 103 | 104 | self.bottleneck = DoubleConv_GN(features[-1], features[-1]*2) 105 | self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) 106 | 107 | def forward(self, x): 108 | skip_connections = [] 109 | 110 | for down in self.downs: 111 | x = down(x) 112 | skip_connections.append(x) 113 | x = self.pool(x) 114 | 115 | x = self.bottleneck(x) 116 | #reversing the list of the skip connections 117 | skip_connections = skip_connections[::-1] 118 | 119 | for idx in range(0, len(self.ups), 2): 120 | x = self.ups[idx](x) 121 | skip_connection = skip_connections[idx//2] 122 | 123 | #if the resolution of the input image is not completely devisible, then it will skip the reminder 124 | # and the resolution will not be equal in this case, so we are resizing it incase in they are not equal 125 | if x.shape != skip_connection.shape: 126 | x = TF.resize(x, size=skip_connection.shape[2:]) 127 | 128 | 129 | #concatenating the skip connections with x 130 | concat_skip = torch.cat((skip_connection, x), dim=1) 131 | 132 | #passing the concatenated ouptut, to the double convolutional layers 133 | x = self.ups[idx+1](concat_skip) 134 | 135 | return self.final_conv(x) 136 | 137 | 138 | class Attention_UNET(nn.Module): 139 | def __init__( 140 | self, in_channels=3, out_channels=1, features=[64, 128, 256, 512], 141 | ): 142 | super(Attention_UNET, self).__init__() 143 | #module list for encoder layers 144 | self.downs = nn.ModuleList() 145 | #max pooling 146 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 147 | #module list for the decoder layers 148 | self.ups = nn.ModuleList() 149 | 150 | # Down part of UNET 151 | for feature in features: 152 | self.downs.append(DoubleConv(in_channels, feature)) 153 | in_channels = feature 154 | 155 | # Up part of UNET 156 | for feature in reversed(features): 157 | ### up convolution 158 | self.ups.append( 159 | nn.ConvTranspose2d( 160 | feature*2, feature, kernel_size=2, stride=2, 161 | ) 162 | ) 163 | 164 | ### attention module 165 | self.ups.append( 166 | Attention_block(F_g=feature, F_l=feature, F_int=feature//2) 167 | ) 168 | 169 | ### double convolution 170 | self.ups.append(DoubleConv(feature*2, feature)) 171 | 172 | self.bottleneck = DoubleConv(features[-1], features[-1]*2) 173 | self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) 174 | 175 | def forward(self, x): 176 | skip_connections = [] 177 | 178 | for down in self.downs: 179 | x = down(x) 180 | skip_connections.append(x) 181 | x = self.pool(x) 182 | 183 | x = self.bottleneck(x) 184 | #reversing the list of the skip connections 185 | skip_connections = skip_connections[::-1] 186 | 187 | for idx in range(0, len(self.ups), 3): 188 | x = self.ups[idx](x) 189 | skip_connection = skip_connections[idx//3] 190 | 191 | #if the resolution of the input image is not completely devisible, then it will skip the reminder 192 | # and the resolution will not be equal in this case, so we are resizing it incase in they are not equal 193 | if x.shape != skip_connection.shape: 194 | x = TF.resize(x, size=skip_connection.shape[2:]) 195 | 196 | 197 | #Attention module 198 | Attention_output = self.ups[idx+1](g= x , x=skip_connection) 199 | 200 | #concatenating the skip connections with x 201 | concat_skip = torch.cat((Attention_output, x), dim=1) 202 | 203 | #passing the concatenated ouptut, to the double convolutional 
layers 204 | x = self.ups[idx+2](concat_skip) 205 | 206 | return self.final_conv(x) 207 | 208 | 209 | class CustomAttention_UNET(nn.Module): 210 | def __init__( 211 | self, in_channels=3, out_channels=1, features=[64, 128, 256, 512], 212 | ): 213 | super(CustomAttention_UNET, self).__init__() 214 | #module list for encoder layers 215 | self.downs = nn.ModuleList() 216 | #max pooling 217 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 218 | #module list for the decoder layers 219 | self.ups = nn.ModuleList() 220 | 221 | # Down part of UNET 222 | for feature in features: 223 | self.downs.append(DoubleConv_GN(in_channels, feature)) 224 | in_channels = feature 225 | 226 | # Up part of UNET 227 | for feature in reversed(features): 228 | ### up convolution 229 | self.ups.append( 230 | nn.ConvTranspose2d( 231 | feature*2, feature, kernel_size=2, stride=2, 232 | ) 233 | ) 234 | 235 | ### attention module 236 | self.ups.append( 237 | SpatialAttention(in_channel=feature) 238 | ) 239 | 240 | ### attention module 241 | self.ups.append( 242 | Attention_block(F_g=feature, F_l=feature, F_int=feature//2) 243 | ) 244 | 245 | ### double convolution 246 | self.ups.append(DoubleConv_GN(feature*2, feature)) 247 | 248 | self.bottleneck = DoubleConv_GN(features[-1], features[-1]*2) 249 | self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) 250 | 251 | def forward(self, x): 252 | skip_connections = [] 253 | 254 | for down in self.downs: 255 | x = down(x) 256 | skip_connections.append(x) 257 | x = self.pool(x) 258 | 259 | x = self.bottleneck(x) 260 | #reversing the list of the skip connections 261 | skip_connections = skip_connections[::-1] 262 | 263 | for idx in range(0, len(self.ups), 4): 264 | x = self.ups[idx](x) 265 | skip_connection = skip_connections[idx//4] 266 | 267 | #if the resolution of the input image is not completely devisible, then it will skip the reminder 268 | # and the resolution will not be equal in this case, so we are resizing it incase in they are not equal 269 | if x.shape != skip_connection.shape: 270 | x = TF.resize(x, size=skip_connection.shape[2:]) 271 | 272 | 273 | #Attention module 274 | spatial_output = self.ups[idx+1](g= x , x=skip_connection) 275 | 276 | #attention unet 277 | # Attention_output = self.ups[idx+2](g= x , x=skip_connection)#x=spatial_output) 278 | 279 | ### 280 | #combined_attention = spatial_output * Attention_output 281 | 282 | #concatenating the skip connections with x 283 | concat_skip = torch.cat((spatial_output, x), dim=1) 284 | 285 | 286 | 287 | #passing the concatenated ouptut, to the double convolutional layers 288 | x = self.ups[idx+3](concat_skip) 289 | 290 | return self.final_conv(x) 291 | 292 | 293 | class Inception_UNET(nn.Module): 294 | def __init__( 295 | self, in_channels=3, out_channels=1, features=[64, 128, 256, 512], 296 | ): 297 | super(Inception_UNET, self).__init__() 298 | #module list for encoder layers 299 | self.downs = nn.ModuleList() 300 | #max pooling 301 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 302 | #module list for the decoder layers 303 | self.ups = nn.ModuleList() 304 | 305 | #module list for inpception blocks 306 | self.inception_blocks = nn.ModuleList() 307 | 308 | #up convolution for inception block 309 | self.ups_inception_block = nn.ModuleList() 310 | 311 | 312 | #inception blocks 313 | for index in range(0, len(features)): 314 | if features[index] == features[-1]: 315 | out_ch = features[index]//4 316 | else: 317 | out_ch = features[index+1]//4 318 | 
self.inception_blocks.append(InceptionBlock(features[index], out_ch)) 319 | self.ups_inception_block.append(nn.ConvTranspose2d(out_ch*4, out_ch*4, kernel_size=2, stride=2)) 320 | 321 | 322 | # Down part of UNET 323 | for feature in features: 324 | self.downs.append(DoubleConv(in_channels, feature)) 325 | in_channels = feature 326 | 327 | # Up part of UNET 328 | for feature, feature2 in zip(reversed(features), [1536,1024, 512, 256]): 329 | ### up convolution 330 | self.ups.append( 331 | nn.ConvTranspose2d( 332 | feature*2, feature, kernel_size=2, stride=2, 333 | ) 334 | ) 335 | 336 | ### double convolution 337 | self.ups.append(DoubleConv(feature2, feature)) 338 | 339 | 340 | 341 | self.bottleneck = DoubleConv(features[-1], features[-1]*2) 342 | self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) 343 | 344 | 345 | 346 | def forward(self, x): 347 | skip_connections = [] 348 | 349 | for down in self.downs: 350 | x = down(x) 351 | skip_connections.append(x) 352 | x = self.pool(x) 353 | 354 | x = self.bottleneck(x) 355 | 356 | 357 | #inception blocks 358 | output_blocks = [] 359 | for index in range(0, len(self.inception_blocks)): 360 | if index == 0: 361 | block = skip_connections[0] 362 | 363 | block = self.inception_blocks[index](block) 364 | output_blocks.append(block) 365 | 366 | #up convolution for the inception block 367 | up_convolved_block = [] 368 | for index, upconv in enumerate(self.ups_inception_block): 369 | up_convolved_block.append(upconv(output_blocks[index])) 370 | 371 | 372 | #reversing the list of the skip connections 373 | skip_connections = skip_connections[::-1] 374 | 375 | for idx in range(0, len(self.ups), 2): 376 | x = self.ups[idx](x) 377 | skip_connection = skip_connections[idx//2] 378 | 379 | #if the resolution of the input image is not completely devisible, then it will skip the reminder 380 | # and the resolution will not be equal in this case, so we are resizing it incase in they are not equal 381 | if x.shape != skip_connection.shape: 382 | x = TF.resize(x, size=skip_connection.shape[2:]) 383 | 384 | 385 | concat_skip = torch.cat((skip_connection,up_convolved_block[(-(idx//2 + 1))], x), dim=1) 386 | 387 | #passing the concatenated ouptut, to the double convolutional layers 388 | x = self.ups[idx+1](concat_skip) 389 | 390 | return self.final_conv(x) 391 | 392 | 393 | 394 | class Inception_Attention_UNET(nn.Module): 395 | def __init__( 396 | self, in_channels=3, out_channels=1, features=[64, 128, 256, 512], 397 | ): 398 | super(Inception_Attention_UNET, self).__init__() 399 | #module list for encoder layers 400 | self.downs = nn.ModuleList() 401 | #max pooling 402 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 403 | #module list for the decoder layers 404 | self.ups = nn.ModuleList() 405 | 406 | #module list for inpception blocks 407 | self.inception_blocks = nn.ModuleList() 408 | 409 | #up convolution for inception block 410 | self.ups_inception_block = nn.ModuleList() 411 | 412 | 413 | #inception blocks 414 | for index in range(0, len(features)): 415 | if features[index] == features[-1]: 416 | out_ch = features[index]//4 417 | else: 418 | out_ch = features[index+1]//4 419 | self.inception_blocks.append(InceptionBlock(features[index], out_ch)) 420 | self.ups_inception_block.append(nn.ConvTranspose2d(out_ch*4, out_ch*4, kernel_size=2, stride=2)) 421 | 422 | 423 | # Down part of UNET 424 | for feature in features: 425 | self.downs.append(DoubleConv(in_channels, feature)) 426 | in_channels = feature 427 | 428 | # Up part of UNET 429 | for 
feature, feature2 in zip(reversed(features), [1536,1024, 512, 256]): 430 | ### up convolution 431 | self.ups.append( 432 | nn.ConvTranspose2d( 433 | feature*2, feature, kernel_size=2, stride=2, 434 | ) 435 | ) 436 | 437 | ### attention module 438 | self.ups.append( 439 | Attention_block(F_g=feature, F_l=feature, F_int=feature//2) 440 | ) 441 | 442 | ### double convolution 443 | self.ups.append(DoubleConv(feature2, feature)) 444 | 445 | 446 | 447 | self.bottleneck = DoubleConv(features[-1], features[-1]*2) 448 | self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) 449 | 450 | 451 | 452 | def forward(self, x): 453 | skip_connections = [] 454 | 455 | for down in self.downs: 456 | x = down(x) 457 | skip_connections.append(x) 458 | x = self.pool(x) 459 | 460 | x = self.bottleneck(x) 461 | 462 | 463 | #inception blocks 464 | output_blocks = [] 465 | for index in range(0, len(self.inception_blocks)): 466 | if index == 0: 467 | block = skip_connections[0] 468 | 469 | block = self.inception_blocks[index](block) 470 | output_blocks.append(block) 471 | 472 | #up convolution for the inception block 473 | up_convolved_block = [] 474 | for index, upconv in enumerate(self.ups_inception_block): 475 | up_convolved_block.append(upconv(output_blocks[index])) 476 | 477 | 478 | #reversing the list of the skip connections 479 | skip_connections = skip_connections[::-1] 480 | 481 | for idx in range(0, len(self.ups), 3): 482 | x = self.ups[idx](x) 483 | skip_connection = skip_connections[idx//3] 484 | 485 | #if the resolution of the input image is not completely devisible, then it will skip the reminder 486 | # and the resolution will not be equal in this case, so we are resizing it incase in they are not equal 487 | if x.shape != skip_connection.shape: 488 | x = TF.resize(x, size=skip_connection.shape[2:]) 489 | 490 | #Attention module 491 | Attention_output = self.ups[idx+1](g=x, x=skip_connection) 492 | 493 | concat_skip = torch.cat((Attention_output,up_convolved_block[(-(idx//3 + 1))], x), dim=1) 494 | 495 | #passing the concatenated ouptut, to the double convolutional layers 496 | x = self.ups[idx+2](concat_skip) 497 | 498 | return self.final_conv(x) 499 | 500 | 501 | 502 | class ResUNET(nn.Module): 503 | def __init__(self, in_channels=3, out_channels=1, filters=[64, 128, 256, 512]): 504 | super(ResUNET, self).__init__() 505 | 506 | ## input and encoder blocks 507 | self.input_layer = nn.Sequential( 508 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1), 509 | nn.BatchNorm2d(filters[0]), 510 | nn.ReLU(), 511 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 512 | ) 513 | self.input_skip = nn.Sequential( 514 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) 515 | ) 516 | 517 | self.residual_conv_1 = ResNetBlock(filters[0], filters[1], stride=2, padding=1) 518 | self.residual_conv_2 = ResNetBlock(filters[1], filters[2], stride=2, padding=1) 519 | 520 | 521 | ## bridge 522 | self.bridge = ResNetBlock(filters[2], filters[3], stride=2, padding=1) 523 | 524 | ## decoder blocks 525 | self.upsample_1 = nn.ConvTranspose2d(filters[3],filters[3], kernel_size=2, stride=2) 526 | self.up_residual_conv_1 = ResNetBlock(filters[3]+filters[2],filters[2], stride=1, padding=1) 527 | 528 | self.upsample_2 = nn.ConvTranspose2d(filters[2],filters[2], kernel_size=2, stride=2) 529 | self.up_residual_conv_2 = ResNetBlock(filters[2]+filters[1],filters[1], stride=1, padding=1) 530 | 531 | self.upsample_3 = nn.ConvTranspose2d(filters[1],filters[1], kernel_size=2, stride=2) 532 | 
self.up_residual_conv_3 = ResNetBlock(filters[1]+filters[0],filters[0], stride=1, padding=1) 533 | 534 | ## output layer 535 | self.output_layer = nn.Conv2d(filters[0], out_channels, kernel_size=1, stride=1,) 536 | 537 | def forward(self, x): 538 | 539 | ## Encoder 540 | x1 = self.input_layer(x) + self.input_skip(x) 541 | x2 = self.residual_conv_1(x1) 542 | x3 = self.residual_conv_2(x2) 543 | 544 | ## Bridge 545 | x4 = self.bridge(x3) 546 | 547 | ## Decoder 548 | x4 = self.upsample_1(x4) 549 | x5 = torch.cat([x4, x3], dim=1) 550 | x6 = self.up_residual_conv_1(x5) 551 | 552 | x6 = self.upsample_2(x6) 553 | x7 = torch.cat([x6, x2], dim=1) 554 | x8 = self.up_residual_conv_2(x7) 555 | 556 | x8 = self.upsample_3(x8) 557 | x9 = torch.cat([x8, x1], dim=1) 558 | x10 = self.up_residual_conv_3(x9) 559 | 560 | output = self.output_layer(x10) 561 | 562 | return output 563 | 564 | 565 | class ResUNETPlus(nn.Module): 566 | def __init__(self, in_channels=3, out_channels=1, filters=[64, 128, 256, 512, 1024]): 567 | super(ResUNETPlus, self).__init__() 568 | 569 | ## input and encoder blocks 570 | self.input_layer = nn.Sequential( 571 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1), 572 | nn.BatchNorm2d(filters[0]), 573 | nn.ReLU(), 574 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 575 | ) 576 | self.input_skip = nn.Sequential( 577 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) 578 | ) 579 | 580 | self.residual_conv_1 = ResNetBlock(filters[0], filters[1], stride=2, padding=1) 581 | self.residual_conv_2 = ResNetBlock(filters[1], filters[2], stride=2, padding=1) 582 | self.residual_conv_3 = ResNetBlock(filters[2], filters[3], stride=2, padding=1) 583 | 584 | 585 | ## bridge 586 | self.bridge = ResNetBlock(filters[3], filters[4], stride=2, padding=1) 587 | 588 | ## decoder blocks 589 | self.upsample_1 = nn.ConvTranspose2d(filters[4],filters[4], kernel_size=2, stride=2) 590 | self.up_residual_conv_1 = ResNetBlock(filters[4]+filters[3],filters[3], stride=1, padding=1) 591 | 592 | self.upsample_2 = nn.ConvTranspose2d(filters[3],filters[3], kernel_size=2, stride=2) 593 | self.up_residual_conv_2 = ResNetBlock(filters[3]+filters[2],filters[2], stride=1, padding=1) 594 | 595 | self.upsample_3 = nn.ConvTranspose2d(filters[2],filters[2], kernel_size=2, stride=2) 596 | self.up_residual_conv_3 = ResNetBlock(filters[2]+filters[1],filters[1], stride=1, padding=1) 597 | 598 | self.upsample_4 = nn.ConvTranspose2d(filters[1],filters[1], kernel_size=2, stride=2) 599 | self.up_residual_conv_4 = ResNetBlock(filters[1]+filters[0],filters[0], stride=1, padding=1) 600 | 601 | ## output layer 602 | self.output_layer = nn.Sequential( 603 | nn.Conv2d(filters[0], out_channels, kernel_size=1, stride=1,), 604 | #nn.Sigmoid(), 605 | ) 606 | 607 | def forward(self, x): 608 | 609 | ## Encoder 610 | x1 = self.input_layer(x) + self.input_skip(x) 611 | x2 = self.residual_conv_1(x1) 612 | x3 = self.residual_conv_2(x2) 613 | x4 = self.residual_conv_3(x3) 614 | 615 | ## Bridge 616 | x5 = self.bridge(x4) 617 | 618 | ## Decoder 619 | x5 = self.upsample_1(x5) 620 | x6 = torch.cat([x5, x4], dim=1) 621 | x7 = self.up_residual_conv_1(x6) 622 | 623 | x7 = self.upsample_2(x7) 624 | x8 = torch.cat([x7, x3], dim=1) 625 | x9 = self.up_residual_conv_2(x8) 626 | 627 | x9 = self.upsample_3(x9) 628 | x10 = torch.cat([x9, x2], dim=1) 629 | x11 = self.up_residual_conv_3(x10) 630 | 631 | x11 = self.upsample_4(x11) 632 | x12 = torch.cat([x11, x1], dim=1) 633 | x13 = self.up_residual_conv_4(x12) 634 | 635 | output = 
self.output_layer(x13) 636 | 637 | return output 638 | 639 | 640 | 641 | class ResUNET_with_CBAM(nn.Module): 642 | def __init__(self, in_channels=3, out_channels=1, filters=[64, 128, 256, 512]): 643 | super(ResUNET_with_CBAM, self).__init__() 644 | 645 | ## input and encoder blocks 646 | self.input_layer = nn.Sequential( 647 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1), 648 | nn.BatchNorm2d(filters[0]), 649 | nn.ReLU(), 650 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 651 | ) 652 | self.input_skip = nn.Sequential( 653 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) 654 | ) 655 | 656 | 657 | self.residual_conv_1 = ResNetBlock(filters[0], filters[1], stride=2, padding=1) 658 | self.residual_conv_2 = ResNetBlock(filters[1], filters[2], stride=2, padding=1) 659 | 660 | 661 | ## bridge 662 | self.bridge = ResNetBlock(filters[2], filters[3], stride=2, padding=1) 663 | 664 | ## decoder blocks 665 | self.upsample_1 = nn.ConvTranspose2d(filters[3],filters[3], kernel_size=2, stride=2) 666 | self.up_residual_conv_1 = ResNetBlock(filters[3]+filters[2],filters[2], stride=1, padding=1) 667 | self.cbam_1 = CBAM(gate_channels=filters[3]+filters[2]) 668 | 669 | 670 | 671 | self.upsample_2 = nn.ConvTranspose2d(filters[2],filters[2], kernel_size=2, stride=2) 672 | self.up_residual_conv_2 = ResNetBlock(filters[2]+filters[1],filters[1], stride=1, padding=1) 673 | self.cbam_2 = CBAM(gate_channels=filters[2]+filters[1]) 674 | 675 | self.upsample_3 = nn.ConvTranspose2d(filters[1],filters[1], kernel_size=2, stride=2) 676 | self.up_residual_conv_3 = ResNetBlock(filters[1]+filters[0],filters[0], stride=1, padding=1) 677 | self.cbam_3 = CBAM(gate_channels=filters[1]+filters[0]) 678 | 679 | ## output layer 680 | self.output_layer = nn.Conv2d(filters[0], out_channels, kernel_size=1, stride=1,) 681 | 682 | def forward(self, x): 683 | 684 | ## Encoder 685 | x1 = self.input_layer(x) + self.input_skip(x) 686 | x2 = self.residual_conv_1(x1) 687 | x3 = self.residual_conv_2(x2) 688 | 689 | ## Bridge 690 | x4 = self.bridge(x3) 691 | 692 | ## Decoder 693 | x4 = self.upsample_1(x4) 694 | x5 = torch.cat([x4, x3], dim=1) 695 | x5 = self.cbam_1(x5) 696 | x6 = self.up_residual_conv_1(x5) 697 | 698 | 699 | x6 = self.upsample_2(x6) 700 | x7 = torch.cat([x6, x2], dim=1) 701 | x7 = self.cbam_2(x7) 702 | x8 = self.up_residual_conv_2(x7) 703 | 704 | 705 | x8 = self.upsample_3(x8) 706 | x9 = torch.cat([x8, x1], dim=1) 707 | x9 = self.cbam_3(x9) 708 | x10 = self.up_residual_conv_3(x9) 709 | 710 | 711 | output = self.output_layer(x10) 712 | 713 | return output 714 | 715 | 716 | class ResUNET_with_GN(nn.Module): 717 | def __init__(self, in_channels=3, out_channels=1, filters=[64, 128, 256, 512]): 718 | super(ResUNET_with_GN, self).__init__() 719 | 720 | ## input and encoder blocks 721 | self.input_layer = nn.Sequential( 722 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1), 723 | nn.GroupNorm(num_groups=filters[0]//8,num_channels=filters[0]), 724 | nn.ReLU(), 725 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 726 | ) 727 | self.input_skip = nn.Sequential( 728 | nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) 729 | ) 730 | 731 | self.residual_conv_1 = ResNetBlock(filters[0], filters[1], stride=2, padding=1) 732 | self.residual_conv_2 = ResNetBlock(filters[1], filters[2], stride=2, padding=1) 733 | 734 | 735 | ## bridge 736 | self.bridge = ResNetBlock(filters[2], filters[3], stride=2, padding=1) 737 | 738 | ## decoder blocks 739 | self.upsample_1 = 
nn.ConvTranspose2d(filters[3],filters[3], kernel_size=2, stride=2) 740 | self.up_residual_conv_1 = ResNetBlock(filters[3]+filters[2],filters[2], stride=1, padding=1) 741 | 742 | self.upsample_2 = nn.ConvTranspose2d(filters[2],filters[2], kernel_size=2, stride=2) 743 | self.up_residual_conv_2 = ResNetBlock(filters[2]+filters[1],filters[1], stride=1, padding=1) 744 | 745 | self.upsample_3 = nn.ConvTranspose2d(filters[1],filters[1], kernel_size=2, stride=2) 746 | self.up_residual_conv_3 = ResNetBlock(filters[1]+filters[0],filters[0], stride=1, padding=1) 747 | 748 | ## output layer 749 | self.output_layer = nn.Conv2d(filters[0], out_channels, kernel_size=1, stride=1,) 750 | 751 | def forward(self, x): 752 | 753 | ## Encoder 754 | x1 = self.input_layer(x) + self.input_skip(x) 755 | x2 = self.residual_conv_1(x1) 756 | x3 = self.residual_conv_2(x2) 757 | 758 | ## Bridge 759 | x4 = self.bridge(x3) 760 | 761 | ## Decoder 762 | x4 = self.upsample_1(x4) 763 | x5 = torch.cat([x4, x3], dim=1) 764 | x6 = self.up_residual_conv_1(x5) 765 | 766 | x6 = self.upsample_2(x6) 767 | x7 = torch.cat([x6, x2], dim=1) 768 | x8 = self.up_residual_conv_2(x7) 769 | 770 | x8 = self.upsample_3(x8) 771 | x9 = torch.cat([x8, x1], dim=1) 772 | x10 = self.up_residual_conv_3(x9) 773 | 774 | output = self.output_layer(x10) 775 | 776 | return output 777 | 778 | 779 | 780 | def test(): 781 | x = torch.randn((3, 1, 160, 160)) 782 | 783 | model = ResUNETPlus(in_channels=1, out_channels=1) 784 | print(model) 785 | preds = model(x) 786 | print(preds.shape, x.shape) 787 | #make_dot(preds, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png") 788 | #assert preds.shape == x.shape 789 | 790 | 791 | 792 | if __name__ == "__main__": 793 | test() --------------------------------------------------------------------------------
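
The usage sketch below is not part of the repository. It shows, under stated assumptions, how the model classes in `models/model.py` and the training utilities in `utils.py` (`Fit`, `plot_history`) might be wired together, in the spirit of `train.py` (the actual entry point). It assumes the repository root is on `PYTHONPATH` so that `models` and `utils` import as shown (and that their own dependencies, such as `torchviz` and the `attentions` package, resolve), and it substitutes small random tensors for the real DataLoaders built from `dataset.py`. `nn.BCEWithLogitsLoss` is used here only because `UNET` returns raw logits; any other logits-compatible loss could be swapped in.

```python
# Hypothetical wiring of the pieces above -- not repository code.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from models import UNET
from utils import Fit, plot_history

device = "cuda" if torch.cuda.is_available() else "cpu"

# Dummy (image, mask) pairs: 3-channel inputs, binary single-channel masks.
# In practice these loaders would come from the dataset defined in dataset.py.
images = torch.rand(8, 3, 64, 64)
masks = torch.randint(0, 2, (8, 1, 64, 64)).float()
dummy_ds = TensorDataset(images, masks)
train_dl = DataLoader(dummy_ds, batch_size=4, shuffle=True)
validation_dl = DataLoader(dummy_ds, batch_size=4)

model = UNET(in_channels=3, out_channels=1).to(device)
loss_fn = nn.BCEWithLogitsLoss()  # UNET's final_conv emits raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
writer = SummaryWriter("runs/unet_demo")  # Fit logs scalars to TensorBoard

history = Fit(
    model=model,
    train_dl=train_dl,
    validation_dl=validation_dl,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=2,
    device=device,
    writer=writer,
)

plot_history(history)  # saves loss/accuracy/dice curves as PNG files
writer.close()
```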