├── .github └── FUNDING.yml ├── .gitignore ├── .idea └── .gitignore ├── LICENSE ├── README.md ├── assets ├── arch.png └── outs.png ├── config.py ├── inference.py ├── main.py ├── requirements.txt ├── train.py ├── train_transunet.py └── utils ├── dataset.py ├── transforms.py ├── transunet.py ├── utils.py └── vit.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: mkara44 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | /transunet_val0665.pt
131 | /results/
132 | 
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 mkara44
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Panoramic Dental X-Ray Image Semantic Segmentation with TransUNet
2 | The unofficial implementation of [TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation](https://arxiv.org/abs/2102.04306) in PyTorch
3 | 
4 | ![Output](./assets/outs.png "Output")
5 | *Output of my implementation. (A) Original X-Ray Image; (B) Merged Image of the Predicted Segmentation Map and Original X-Ray; (C) Ground Truth; (D) Predicted Segmentation Map*
6 | 
7 | ## TransUNet
8 | - On various medical image segmentation tasks, the u-shaped architecture, also known as U-Net, has become the de-facto standard and achieved tremendous success. However, due to the intrinsic
9 | locality of convolution operations, U-Net generally demonstrates limitations in explicitly modeling long-range dependency. [1]
10 | - TransUNet employs a hybrid CNN-Transformer architecture to leverage both detailed high-resolution spatial information from CNN features and the global context encoded by Transformers. [1]
11 | 
12 | ## Model Architecture
13 | ![Model Architecture](./assets/arch.png "Model Architecture")
14 | 
15 | *TransUNet Architecture Figure from Official Paper*
16 | 
17 | ## Dependencies
18 | - Python 3.6+
19 | - `pip install -r requirements.txt`
20 | 
21 | ## Dataset
22 | - The UFBA_UESC_DENTAL_IMAGES [2] dataset was used for training.
23 | - The dataset can be accessed upon request [3].
24 | 
25 | ## Training
26 | - The training process can be started with the following command.
27 | - `python main.py --mode train --model_path ./path/to/model --train_path ./path/to/trainset --test_path ./path/to/testset`
28 | 
29 | ## Inference
30 | - After the model is trained, inference can be run with the following command.
31 | - `python main.py --mode inference --model_path ./path/to/model --image_path ./path/to/image`
32 | 
33 | ## Other Implementations
34 | - [Self Attention CV / The AI Summer](https://github.com/The-AI-Summer/self-attention-cv)
35 | - [SOTA Vision / 04RR](https://github.com/04RR/SOTA-Vision)
36 | 
37 | ## References
38 | - [1] [TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation](https://arxiv.org/abs/2102.04306)
39 | - [2] [Automatic segmenting teeth in X-ray images: Trends, a novel data set, benchmarking and future perspectives](https://www.sciencedirect.com/science/article/abs/pii/S0957417418302252)
40 | - [3] [GitHub Repository of Dataset](https://github.com/IvisionLab/dental-image)
41 | 
--------------------------------------------------------------------------------
/assets/arch.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/mkara44/transunet_pytorch/dcec54efd9d8539927ef495ece0dffba373f5ba3/assets/arch.png
--------------------------------------------------------------------------------
/assets/outs.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/mkara44/transunet_pytorch/dcec54efd9d8539927ef495ece0dffba373f5ba3/assets/outs.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict
2 | 
3 | cfg = EasyDict()
4 | cfg.batch_size = 16
5 | cfg.epoch = 200
6 | cfg.learning_rate = 1e-2
7 | cfg.momentum = 0.9
8 | cfg.weight_decay = 1e-4
9 | cfg.patience = 25
10 | cfg.inference_threshold = 0.75
11 | 
12 | cfg.transunet = EasyDict()
13 | cfg.transunet.img_dim = 512
14 | cfg.transunet.in_channels = 3
15 | cfg.transunet.out_channels = 128
16 | cfg.transunet.head_num = 4
17 | cfg.transunet.mlp_dim = 512
18 | cfg.transunet.block_num = 8
19 | cfg.transunet.patch_dim = 16
20 | cfg.transunet.class_num = 1
21 | 
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import datetime
6 | 
7 | # Additional Scripts
8 | from train_transunet import TransUNetSeg
9 | 
10 | from utils.utils import thresh_func
11 | from config import cfg
12 | 
13 | 
14 | class SegInference:
15 |     def __init__(self, model_path, device):
16 |         self.device = device
17 |         self.transunet = TransUNetSeg(device)
18 |         self.transunet.load_model(model_path)
19 | 
20 |         if not os.path.exists('./results'):
21 |             os.mkdir('./results')
22 | 
23 |     def read_and_preprocess(self, p):
24 |         img = cv2.imread(p)
25 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
26 | 
27 | 
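        # resize to the configured network input size, scale pixel values to [0, 1],
        # reorder HWC -> CHW, add a batch dimension, and move the float32 tensor to the target device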
img_torch = cv2.resize(img, (cfg.transunet.img_dim, cfg.transunet.img_dim)) 28 | img_torch = img_torch / 255. 29 | img_torch = img_torch.transpose((2, 0, 1)) 30 | img_torch = np.expand_dims(img_torch, axis=0) 31 | img_torch = torch.from_numpy(img_torch.astype('float32')).to(self.device) 32 | 33 | return img, img_torch 34 | 35 | def save_preds(self, preds): 36 | folder_path = './results/' + str(datetime.datetime.utcnow()).replace(' ', '_') 37 | 38 | os.mkdir(folder_path) 39 | for name, pred_mask in preds.items(): 40 | cv2.imwrite(f'{folder_path}/{name}', pred_mask) 41 | 42 | def infer(self, path, merged=True, save=True): 43 | path = [path] if isinstance(path, str) else path 44 | 45 | preds = {} 46 | for p in path: 47 | file_name = p.split('/')[-1] 48 | img, img_torch = self.read_and_preprocess(p) 49 | with torch.no_grad(): 50 | pred_mask = self.transunet.model(img_torch) 51 | pred_mask = torch.sigmoid(pred_mask) 52 | pred_mask = pred_mask.detach().cpu().numpy().transpose((0, 2, 3, 1)) 53 | 54 | orig_h, orig_w = img.shape[:2] 55 | pred_mask = cv2.resize(pred_mask[0, ...], (orig_w, orig_h)) 56 | pred_mask = thresh_func(pred_mask, thresh=cfg.inference_threshold) 57 | pred_mask *= 255 58 | 59 | if merged: 60 | pred_mask = cv2.bitwise_and(img, img, mask=pred_mask.astype('uint8')) 61 | 62 | preds[file_name] = pred_mask 63 | 64 | if save: 65 | self.save_preds(preds) 66 | 67 | return preds 68 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | 5 | # Additional Scripts 6 | from train import TrainTestPipe 7 | from inference import SegInference 8 | 9 | 10 | def main_pipeline(parser): 11 | device = 'cpu:0' 12 | if torch.cuda.is_available(): 13 | device = 'cuda:0' 14 | 15 | if parser.mode == 'train': 16 | ttp = TrainTestPipe(train_path=parser.train_path, 17 | test_path=parser.test_path, 18 | model_path=parser.model_path, 19 | device=device) 20 | 21 | ttp.train() 22 | 23 | elif parser.mode == 'inference': 24 | inf = SegInference(model_path=parser.model_path, 25 | device=device) 26 | 27 | _ = inf.infer(parser.image_path) 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--mode', type=str, required=True, choices=['train', 'inference']) 33 | parser.add_argument('--model_path', required=True, type=str, default=None) 34 | 35 | parser.add_argument('--train_path', required='train' in sys.argv, type=str, default=None) 36 | parser.add_argument('--test_path', required='train' in sys.argv, type=str, default=None) 37 | 38 | parser.add_argument('--image_path', required='infer' in sys.argv, type=str, default=None) 39 | parser = parser.parse_args() 40 | 41 | main_pipeline(parser) 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | torchvision==0.8.1 3 | tqdm==4.49.0 4 | torch==1.7.0 5 | einops==0.3.0 6 | easydict==1.9 7 | opencv_python==4.4.0.42 8 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader 5 | 6 | # Additional Scripts 7 | from utils import transforms as T 8 | from utils.dataset 
import DentalDataset 9 | from utils.utils import EpochCallback 10 | 11 | from config import cfg 12 | 13 | from train_transunet import TransUNetSeg 14 | 15 | 16 | class TrainTestPipe: 17 | def __init__(self, train_path, test_path, model_path, device): 18 | self.device = device 19 | self.model_path = model_path 20 | 21 | self.train_loader = self.__load_dataset(train_path, train=True) 22 | self.test_loader = self.__load_dataset(test_path) 23 | 24 | self.transunet = TransUNetSeg(self.device) 25 | 26 | def __load_dataset(self, path, train=False): 27 | shuffle = False 28 | transform = False 29 | 30 | if train: 31 | shuffle = True 32 | transform = transforms.Compose([T.RandomAugmentation(2)]) 33 | 34 | set = DentalDataset(path, transform) 35 | loader = DataLoader(set, batch_size=cfg.batch_size, shuffle=shuffle) 36 | 37 | return loader 38 | 39 | def __loop(self, loader, step_func, t): 40 | total_loss = 0 41 | 42 | for step, data in enumerate(loader): 43 | img, mask = data['img'], data['mask'] 44 | img = img.to(self.device) 45 | mask = mask.to(self.device) 46 | 47 | loss, cls_pred = step_func(img=img, mask=mask) 48 | 49 | total_loss += loss 50 | 51 | t.update() 52 | 53 | return total_loss 54 | 55 | def train(self): 56 | callback = EpochCallback(self.model_path, cfg.epoch, 57 | self.transunet.model, self.transunet.optimizer, 'test_loss', cfg.patience) 58 | 59 | for epoch in range(cfg.epoch): 60 | with tqdm(total=len(self.train_loader) + len(self.test_loader)) as t: 61 | train_loss = self.__loop(self.train_loader, self.transunet.train_step, t) 62 | 63 | test_loss = self.__loop(self.test_loader, self.transunet.test_step, t) 64 | 65 | callback.epoch_end(epoch + 1, 66 | {'loss': train_loss / len(self.train_loader), 67 | 'test_loss': test_loss / len(self.test_loader)}) 68 | 69 | if callback.end_training: 70 | break 71 | -------------------------------------------------------------------------------- /train_transunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import SGD 3 | 4 | # Additional Scripts 5 | from utils.transunet import TransUNet 6 | from utils.utils import dice_loss 7 | from config import cfg 8 | 9 | 10 | class TransUNetSeg: 11 | def __init__(self, device): 12 | self.device = device 13 | self.model = TransUNet(img_dim=cfg.transunet.img_dim, 14 | in_channels=cfg.transunet.in_channels, 15 | out_channels=cfg.transunet.out_channels, 16 | head_num=cfg.transunet.head_num, 17 | mlp_dim=cfg.transunet.mlp_dim, 18 | block_num=cfg.transunet.block_num, 19 | patch_dim=cfg.transunet.patch_dim, 20 | class_num=cfg.transunet.class_num).to(self.device) 21 | 22 | self.criterion = dice_loss 23 | self.optimizer = SGD(self.model.parameters(), lr=cfg.learning_rate, 24 | momentum=cfg.momentum, weight_decay=cfg.weight_decay) 25 | 26 | def load_model(self, path): 27 | ckpt = torch.load(path) 28 | self.model.load_state_dict(ckpt['model_state_dict']) 29 | self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) 30 | 31 | self.model.eval() 32 | 33 | def train_step(self, **params): 34 | self.model.train() 35 | 36 | self.optimizer.zero_grad() 37 | pred_mask = self.model(params['img']) 38 | loss = self.criterion(pred_mask, params['mask']) 39 | loss.backward() 40 | self.optimizer.step() 41 | 42 | return loss.item(), pred_mask 43 | 44 | def test_step(self, **params): 45 | self.model.eval() 46 | 47 | pred_mask = self.model(params['img']) 48 | loss = self.criterion(pred_mask, params['mask']) 49 | 50 | return loss.item(), pred_mask 51 | 
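# Usage sketch: a minimal, hypothetical example of how train.py's TrainTestPipe drives the
# two step functions above. Random tensors stand in for a real DataLoader batch; the shapes
# follow config.py (batch of 2, in_channels x img_dim x img_dim images, class_num-channel masks).
if __name__ == '__main__':
    seg = TransUNetSeg('cuda' if torch.cuda.is_available() else 'cpu')

    img = torch.rand(2, cfg.transunet.in_channels, cfg.transunet.img_dim, cfg.transunet.img_dim).to(seg.device)
    mask = torch.randint(0, 2, (2, cfg.transunet.class_num,
                                cfg.transunet.img_dim, cfg.transunet.img_dim)).float().to(seg.device)

    train_loss, _ = seg.train_step(img=img, mask=mask)   # forward + backward + SGD update
    test_loss, pred = seg.test_step(img=img, mask=mask)  # forward pass only
    print(train_loss, test_loss, pred.shape)             # pred: (2, class_num, img_dim, img_dim)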
-------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | # Additional Scripts 8 | from config import cfg 9 | 10 | 11 | class DentalDataset(Dataset): 12 | output_size = cfg.transunet.img_dim 13 | 14 | def __init__(self, path, transform): 15 | super().__init__() 16 | 17 | self.transform = transform 18 | 19 | img_folder = os.path.join(path, 'img') 20 | mask_folder = os.path.join(path, 'mask') 21 | 22 | self.img_paths = [] 23 | self.mask_paths = [] 24 | for p in os.listdir(img_folder): 25 | name = p.split('.')[0] 26 | 27 | self.img_paths.append(os.path.join(img_folder, name + '.jpg')) 28 | self.mask_paths.append(os.path.join(mask_folder, name + '.bmp')) 29 | 30 | def __getitem__(self, idx): 31 | if torch.is_tensor(idx): 32 | idx = idx.tolist() 33 | 34 | img = self.img_paths[idx] 35 | mask = self.mask_paths[idx] 36 | 37 | img = cv2.imread(img) 38 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 39 | img = cv2.resize(img, (self.output_size, self.output_size)) 40 | 41 | mask = cv2.imread(mask, 0) 42 | mask = cv2.resize(mask, (self.output_size, self.output_size), interpolation=cv2.INTER_NEAREST) 43 | mask = np.expand_dims(mask, axis=-1) 44 | 45 | sample = {'img': img, 'mask': mask} 46 | 47 | if self.transform: 48 | sample = self.transform(sample) 49 | 50 | img, mask = sample['img'], sample['mask'] 51 | 52 | img = img / 255. 53 | img = img.transpose((2, 0, 1)) 54 | img = torch.from_numpy(img.astype('float32')) 55 | 56 | mask = mask / 255. 57 | mask = mask.transpose((2, 0, 1)) 58 | mask = torch.from_numpy(mask.astype('float32')) 59 | 60 | return {'img': img, 'mask': mask} 61 | 62 | def __len__(self): 63 | return len(self.img_paths) 64 | 65 | 66 | if __name__ == '__main__': 67 | import torchvision.transforms as transforms 68 | from utils import transforms as T 69 | 70 | transform = transforms.Compose([T.BGR2RGB(), 71 | T.Rescale(cfg.input_size), 72 | T.RandomAugmentation(2), 73 | T.Normalize(), 74 | T.ToTensor()]) 75 | 76 | md = DentalDataset('/home/kara/Downloads/UFBA_UESC_DENTAL_IMAGES_DEEP/dataset_and_code/test/set/train', 77 | transform) 78 | 79 | for sample in md: 80 | print(sample['img'].shape) 81 | print(sample['mask'].shape) 82 | '''cv2.imshow('img', sample['img']) 83 | cv2.imshow('mask', sample['mask']) 84 | cv2.waitKey()''' 85 | 86 | break 87 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import random 4 | import numpy as np 5 | 6 | 7 | def flip_horizontal(img, mask): 8 | img = np.flip(img, axis=1) 9 | mask = np.flip(mask, axis=1) 10 | return img, mask 11 | 12 | 13 | def rotate(img, mask, angle_abs=5): 14 | h, w, _ = img.shape 15 | angle = random.choice([angle_abs, -angle_abs]) 16 | 17 | M = cv2.getRotationMatrix2D((h, w), angle, 1.0) 18 | img = cv2.warpAffine(img, M, (h, w), flags=cv2.INTER_CUBIC) 19 | mask = cv2.warpAffine(mask, M, (h, w), flags=cv2.INTER_CUBIC) 20 | mask = np.expand_dims(mask, axis=-1) 21 | return img, mask 22 | 23 | 24 | class RandomAugmentation: 25 | augmentations = [flip_horizontal, rotate] 26 | 27 | def __init__(self, max_augment_count): 28 | if max_augment_count <= len(self.augmentations): 29 | self.max_augment_count = max_augment_count 30 | else: 31 | 
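            # the requested count exceeds the available augmentations, so cap it at their number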
self.max_augment_count = len(self.augmentations) 32 | 33 | def __call__(self, sample): 34 | img, mask = sample['img'], sample['mask'] 35 | 36 | augmentation_count = random.randint(0, self.max_augment_count) 37 | selected_augmentations = random.sample(self.augmentations, k=augmentation_count) 38 | for augmentation in selected_augmentations: 39 | img, mask = augmentation(img, mask) 40 | 41 | return {'img': img, 'mask': mask} 42 | -------------------------------------------------------------------------------- /utils/transunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | from utils.vit import ViT 6 | 7 | 8 | class EncoderBottleneck(nn.Module): 9 | def __init__(self, in_channels, out_channels, stride=1, base_width=64): 10 | super().__init__() 11 | 12 | self.downsample = nn.Sequential( 13 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), 14 | nn.BatchNorm2d(out_channels) 15 | ) 16 | 17 | width = int(out_channels * (base_width / 64)) 18 | 19 | self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False) 20 | self.norm1 = nn.BatchNorm2d(width) 21 | 22 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=2, groups=1, padding=1, dilation=1, bias=False) 23 | self.norm2 = nn.BatchNorm2d(width) 24 | 25 | self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, bias=False) 26 | self.norm3 = nn.BatchNorm2d(out_channels) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | def forward(self, x): 31 | x_down = self.downsample(x) 32 | 33 | x = self.conv1(x) 34 | x = self.norm1(x) 35 | x = self.relu(x) 36 | 37 | x = self.conv2(x) 38 | x = self.norm2(x) 39 | x = self.relu(x) 40 | 41 | x = self.conv3(x) 42 | x = self.norm3(x) 43 | x = x + x_down 44 | x = self.relu(x) 45 | 46 | return x 47 | 48 | 49 | class DecoderBottleneck(nn.Module): 50 | def __init__(self, in_channels, out_channels, scale_factor=2): 51 | super().__init__() 52 | 53 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True) 54 | self.layer = nn.Sequential( 55 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 56 | nn.BatchNorm2d(out_channels), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 59 | nn.BatchNorm2d(out_channels), 60 | nn.ReLU(inplace=True) 61 | ) 62 | 63 | def forward(self, x, x_concat=None): 64 | x = self.upsample(x) 65 | 66 | if x_concat is not None: 67 | x = torch.cat([x_concat, x], dim=1) 68 | 69 | x = self.layer(x) 70 | return x 71 | 72 | 73 | class Encoder(nn.Module): 74 | def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim): 75 | super().__init__() 76 | 77 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False) 78 | self.norm1 = nn.BatchNorm2d(out_channels) 79 | self.relu = nn.ReLU(inplace=True) 80 | 81 | self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2) 82 | self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2) 83 | self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2) 84 | 85 | self.vit_img_dim = img_dim // patch_dim 86 | self.vit = ViT(self.vit_img_dim, out_channels * 8, out_channels * 8, 87 | head_num, mlp_dim, block_num, patch_dim=1, classification=False) 88 | 89 | self.conv2 = nn.Conv2d(out_channels * 8, 512, kernel_size=3, stride=1, padding=1) 
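        # conv2/norm2 project the ViT output, once it is reshaped back into an
        # (img_dim/patch_dim x img_dim/patch_dim) map with out_channels * 8 channels,
        # down to 512 channels before it is handed to the decoder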
90 | self.norm2 = nn.BatchNorm2d(512) 91 | 92 | def forward(self, x): 93 | x = self.conv1(x) 94 | x = self.norm1(x) 95 | x1 = self.relu(x) 96 | 97 | x2 = self.encoder1(x1) 98 | x3 = self.encoder2(x2) 99 | x = self.encoder3(x3) 100 | 101 | x = self.vit(x) 102 | x = rearrange(x, "b (x y) c -> b c x y", x=self.vit_img_dim, y=self.vit_img_dim) 103 | 104 | x = self.conv2(x) 105 | x = self.norm2(x) 106 | x = self.relu(x) 107 | 108 | return x, x1, x2, x3 109 | 110 | 111 | class Decoder(nn.Module): 112 | def __init__(self, out_channels, class_num): 113 | super().__init__() 114 | 115 | self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 2) 116 | self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels) 117 | self.decoder3 = DecoderBottleneck(out_channels * 2, int(out_channels * 1 / 2)) 118 | self.decoder4 = DecoderBottleneck(int(out_channels * 1 / 2), int(out_channels * 1 / 8)) 119 | 120 | self.conv1 = nn.Conv2d(int(out_channels * 1 / 8), class_num, kernel_size=1) 121 | 122 | def forward(self, x, x1, x2, x3): 123 | x = self.decoder1(x, x3) 124 | x = self.decoder2(x, x2) 125 | x = self.decoder3(x, x1) 126 | x = self.decoder4(x) 127 | x = self.conv1(x) 128 | 129 | return x 130 | 131 | 132 | class TransUNet(nn.Module): 133 | def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num): 134 | super().__init__() 135 | 136 | self.encoder = Encoder(img_dim, in_channels, out_channels, 137 | head_num, mlp_dim, block_num, patch_dim) 138 | 139 | self.decoder = Decoder(out_channels, class_num) 140 | 141 | def forward(self, x): 142 | x, x1, x2, x3 = self.encoder(x) 143 | x = self.decoder(x, x1, x2, x3) 144 | 145 | return x 146 | 147 | 148 | if __name__ == '__main__': 149 | import torch 150 | 151 | transunet = TransUNet(img_dim=128, 152 | in_channels=3, 153 | out_channels=128, 154 | head_num=4, 155 | mlp_dim=512, 156 | block_num=8, 157 | patch_dim=16, 158 | class_num=1) 159 | 160 | print(sum(p.numel() for p in transunet.parameters())) 161 | print(transunet(torch.randn(1, 3, 128, 128)).shape) 162 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def thresh_func(mask, thresh=0.5): 6 | mask[mask >= thresh] = 1 7 | mask[mask < thresh] = 0 8 | 9 | return mask 10 | 11 | 12 | def dice_loss(pred, target): 13 | pred = torch.sigmoid(pred) 14 | 15 | pred = pred.contiguous().view(-1) 16 | target = target.contiguous().view(-1) 17 | 18 | intersection = torch.sum(pred * target) 19 | pred_sum = torch.sum(pred * pred) 20 | target_sum = torch.sum(target * target) 21 | 22 | return 1 - ((2. 
* intersection + 1e-5) / (pred_sum + target_sum + 1e-5)) 23 | 24 | 25 | class EpochCallback: 26 | end_training = False 27 | not_improved_epoch = 0 28 | monitor_value = np.inf 29 | 30 | def __init__(self, model_name, total_epoch_num, model, optimizer, monitor=None, patience=None): 31 | if isinstance(model_name, str): 32 | model_name = [model_name] 33 | model = [model] 34 | optimizer = [optimizer] 35 | 36 | self.model_name = model_name 37 | self.total_epoch_num = total_epoch_num 38 | self.monitor = monitor 39 | self.patience = patience 40 | self.model = model 41 | self.optimizer = optimizer 42 | 43 | def __save_model(self): 44 | for m_name, m, opt in zip(self.model_name, self.model, self.optimizer): 45 | torch.save({'model_state_dict': m.state_dict(), 46 | 'optimizer_state_dict': opt.state_dict()}, 47 | m_name) 48 | 49 | print(f'Model saved to {m_name}') 50 | 51 | def epoch_end(self, epoch_num, hash): 52 | epoch_end_str = f'Epoch {epoch_num}/{self.total_epoch_num} - ' 53 | for name, value in hash.items(): 54 | epoch_end_str += f'{name}: {round(value, 4)} ' 55 | 56 | print(epoch_end_str) 57 | 58 | if self.monitor is None: 59 | self.__save_model() 60 | 61 | elif hash[self.monitor] < self.monitor_value: 62 | print(f'{self.monitor} decreased from {round(self.monitor_value, 4)} to {round(hash[self.monitor], 4)}') 63 | 64 | self.not_improved_epoch = 0 65 | self.monitor_value = hash[self.monitor] 66 | self.__save_model() 67 | else: 68 | print(f'{self.monitor} did not decrease from {round(self.monitor_value, 4)}, model did not save!') 69 | 70 | self.not_improved_epoch += 1 71 | if self.patience is not None and self.not_improved_epoch >= self.patience: 72 | print("Training was stopped by callback!") 73 | self.end_training = True 74 | -------------------------------------------------------------------------------- /utils/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from einops import rearrange, repeat 5 | 6 | 7 | class MultiHeadAttention(nn.Module): 8 | def __init__(self, embedding_dim, head_num): 9 | super().__init__() 10 | 11 | self.head_num = head_num 12 | self.dk = (embedding_dim // head_num) ** (1 / 2) 13 | 14 | self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False) 15 | self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False) 16 | 17 | def forward(self, x, mask=None): 18 | qkv = self.qkv_layer(x) 19 | 20 | query, key, value = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.head_num)) 21 | energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk 22 | 23 | if mask is not None: 24 | energy = energy.masked_fill(mask, -np.inf) 25 | 26 | attention = torch.softmax(energy, dim=-1) 27 | 28 | x = torch.einsum("... i j , ... j d -> ... 
i d", attention, value) 29 | 30 | x = rearrange(x, "b h t d -> b t (h d)") 31 | x = self.out_attention(x) 32 | 33 | return x 34 | 35 | 36 | class MLP(nn.Module): 37 | def __init__(self, embedding_dim, mlp_dim): 38 | super().__init__() 39 | 40 | self.mlp_layers = nn.Sequential( 41 | nn.Linear(embedding_dim, mlp_dim), 42 | nn.GELU(), 43 | nn.Dropout(0.1), 44 | nn.Linear(mlp_dim, embedding_dim), 45 | nn.Dropout(0.1) 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.mlp_layers(x) 50 | 51 | return x 52 | 53 | 54 | class TransformerEncoderBlock(nn.Module): 55 | def __init__(self, embedding_dim, head_num, mlp_dim): 56 | super().__init__() 57 | 58 | self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num) 59 | self.mlp = MLP(embedding_dim, mlp_dim) 60 | 61 | self.layer_norm1 = nn.LayerNorm(embedding_dim) 62 | self.layer_norm2 = nn.LayerNorm(embedding_dim) 63 | 64 | self.dropout = nn.Dropout(0.1) 65 | 66 | def forward(self, x): 67 | _x = self.multi_head_attention(x) 68 | _x = self.dropout(_x) 69 | x = x + _x 70 | x = self.layer_norm1(x) 71 | 72 | _x = self.mlp(x) 73 | x = x + _x 74 | x = self.layer_norm2(x) 75 | 76 | return x 77 | 78 | 79 | class TransformerEncoder(nn.Module): 80 | def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12): 81 | super().__init__() 82 | 83 | self.layer_blocks = nn.ModuleList( 84 | [TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)]) 85 | 86 | def forward(self, x): 87 | for layer_block in self.layer_blocks: 88 | x = layer_block(x) 89 | 90 | return x 91 | 92 | 93 | class ViT(nn.Module): 94 | def __init__(self, img_dim, in_channels, embedding_dim, head_num, mlp_dim, 95 | block_num, patch_dim, classification=True, num_classes=1): 96 | super().__init__() 97 | 98 | self.patch_dim = patch_dim 99 | self.classification = classification 100 | self.num_tokens = (img_dim // patch_dim) ** 2 101 | self.token_dim = in_channels * (patch_dim ** 2) 102 | 103 | self.projection = nn.Linear(self.token_dim, embedding_dim) 104 | self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim)) 105 | 106 | self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) 107 | 108 | self.dropout = nn.Dropout(0.1) 109 | 110 | self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num) 111 | 112 | if self.classification: 113 | self.mlp_head = nn.Linear(embedding_dim, num_classes) 114 | 115 | def forward(self, x): 116 | img_patches = rearrange(x, 117 | 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)', 118 | patch_x=self.patch_dim, patch_y=self.patch_dim) 119 | 120 | batch_size, tokens, _ = img_patches.shape 121 | 122 | project = self.projection(img_patches) 123 | token = repeat(self.cls_token, 'b ... -> (b batch_size) ...', 124 | batch_size=batch_size) 125 | 126 | patches = torch.cat([token, project], dim=1) 127 | patches += self.embedding[:tokens + 1, :] 128 | 129 | x = self.dropout(patches) 130 | x = self.transformer(x) 131 | x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :] 132 | 133 | return x 134 | 135 | 136 | if __name__ == '__main__': 137 | vit = ViT(img_dim=128, 138 | in_channels=3, 139 | patch_dim=16, 140 | embedding_dim=512, 141 | block_num=6, 142 | head_num=4, 143 | mlp_dim=1024) 144 | print(sum(p.numel() for p in vit.parameters())) 145 | print(vit(torch.rand(1, 3, 128, 128)).shape) 146 | --------------------------------------------------------------------------------