├── LICENSE ├── README.md ├── cmd.txt ├── environment.yml ├── extractors.py ├── img ├── arch.png ├── medt.png ├── medt1.png └── poster.pdf ├── lib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── build_dataloader.cpython-36.pyc │ ├── build_dataloader.cpython-37.pyc │ ├── build_model.cpython-36.pyc │ ├── build_model.cpython-37.pyc │ ├── build_optimizer.cpython-36.pyc │ ├── build_optimizer.cpython-37.pyc │ ├── metrics.cpython-36.pyc │ └── metrics.cpython-37.pyc ├── build_dataloader.py ├── build_model.py ├── build_optimizer.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── imagenet1k.cpython-36.pyc │ │ └── imagenet1k.cpython-37.pyc │ └── imagenet1k.py ├── metrics.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── axialnet.cpython-36.pyc │ │ ├── axialnet.cpython-37.pyc │ │ ├── resnet.cpython-36.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── utils.cpython-36.pyc │ │ └── utils.cpython-37.pyc │ ├── axialnet.py │ ├── model_codes.py │ ├── resnet.py │ └── utils.py └── utils.py ├── metrics.py ├── performancemetrics_ax.m ├── performancemetrics_glas.m ├── performancemetrics_monuseg.m ├── requirements.txt ├── test.py ├── train.py ├── utils.py └── utils_gray.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jeya Maria Jose 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Medical-Transformer 2 | 3 | 4 | 5 | Pytorch code for the paper 6 | ["Medical Transformer: Gated Axial-Attention for 7 | Medical Image Segmentation"](https://arxiv.org/pdf/2102.10662.pdf), MICCAI 2021 8 | 9 | [Paper](https://arxiv.org/pdf/2102.10662.pdf) | [Poster](https://drive.google.com/file/d/1gMjc5guT_dYQFT6TEEwdHAFKwG5XkEc9/view?usp=sharing) 10 | 11 | ## News: 12 | 13 | :rocket: : Checkout our latest work [UNeXt](https://arxiv.org/abs/2203.04967), a faster and more efficient segmentation architecture which is also easy to train and implement! Code is available [here](https://github.com/jeya-maria-jose/UNeXt-pytorch). 
14 | 15 | ### About this repo: 16 | 17 | This repo hosts the code for the following networks: 18 | 19 | 1) Gated Axial Attention U-Net 20 | 2) MedT 21 | 22 | ## Introduction 23 | 24 | The majority of existing Transformer-based network architectures proposed for vision applications require large-scale 25 | datasets to train properly. However, compared to the datasets for vision 26 | applications, the number of data samples in medical imaging is relatively 27 | low, making it difficult to efficiently train transformers for medical 28 | applications. To this end, we propose a Gated Axial-Attention model which 29 | extends the existing architectures by introducing an additional control 30 | mechanism in the self-attention module. Furthermore, to train the model 31 | effectively on medical images, we propose a Local-Global training 32 | strategy (LoGo) which further improves the performance. Specifically, we 33 | operate on the whole image and on patches to learn global and local features, 34 | respectively. The proposed Medical Transformer (MedT) uses the LoGo training strategy with the Gated Axial Attention U-Net. 35 | 36 |
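To make the gating idea concrete, here is a minimal, illustrative sketch of a single-axis attention layer with learnable gates on the relative-position terms. It is a simplification for readability only; the repository's actual implementation is the `AxialAttention_dynamic` module in `lib/models/axialnet.py`, which also gates the value terms and uses batch-normalized similarities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedAxialAttention1D(nn.Module):
    """Toy single-axis self-attention with gated relative-position terms."""
    def __init__(self, dim, length):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        # relative position embeddings along one axis (height or width)
        self.r_q = nn.Parameter(torch.randn(length, length, dim) * 0.02)
        self.r_k = nn.Parameter(torch.randn(length, length, dim) * 0.02)
        # the "gates": learnable scalars that control how much the
        # positional terms contribute to the attention logits
        self.g_q = nn.Parameter(torch.tensor(0.1))
        self.g_k = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):                                   # x: (batch, length, dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        logits = torch.einsum('bic,bjc->bij', q, k)                              # content-content
        logits = logits + self.g_q * torch.einsum('bic,ijc->bij', q, self.r_q)   # gated query-position
        logits = logits + self.g_k * torch.einsum('bjc,ijc->bij', k, self.r_k)   # gated key-position
        attn = F.softmax(logits / x.shape[-1] ** 0.5, dim=-1)
        return torch.einsum('bij,bjc->bic', attn, v)

# quick shape check on a toy sequence
out = GatedAxialAttention1D(dim=16, length=8)(torch.randn(2, 8, 16))
print(out.shape)  # torch.Size([2, 8, 16])
```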

37 | 38 |
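The Local-Global (LoGo) training strategy can be pictured as two branches: a shallow global branch that sees the whole image and a deeper local branch that runs on patches whose outputs are stitched back together before a final fusion layer. Below is a rough, hypothetical sketch of the patch handling only; the actual branch definitions live in the `medt_net` class in `lib/models/axialnet.py`, and the patch size and fusion layer used here are placeholders.

```python
import torch
import torch.nn as nn

def logo_forward(global_branch, local_branch, fuse, x, patch=32):
    """Illustrative LoGo forward pass: whole image through the global branch,
    non-overlapping patches through the local branch, outputs added and fused."""
    out_global = global_branch(x)

    out_local = torch.zeros_like(out_global)
    _, _, height, width = x.shape
    for i in range(0, height, patch):
        for j in range(0, width, patch):
            out_local[:, :, i:i + patch, j:j + patch] = \
                local_branch(x[:, :, i:i + patch, j:j + patch])

    return fuse(out_global + out_local)

# toy check: both branches must preserve spatial size and channel count
branch = nn.Conv2d(3, 3, kernel_size=3, padding=1)
fuse = nn.Conv2d(3, 2, kernel_size=1)  # e.g. map to 2 segmentation classes
y = logo_forward(branch, branch, fuse, torch.randn(1, 3, 128, 128))
print(y.shape)  # torch.Size([1, 2, 128, 128])
```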

39 | 40 | ### Using the code: 41 | 42 | - Clone this repository: 43 | ```bash 44 | git clone https://github.com/jeya-maria-jose/Medical-Transformer 45 | cd Medical-Transformer 46 | ``` 47 | 48 | The code is stable with Python 3.6.10 and PyTorch 1.4.0. 49 | 50 | To install all the dependencies using conda: 51 | 52 | ```bash 53 | conda env create -f environment.yml 54 | conda activate medt 55 | ``` 56 | 57 | To install all the dependencies using pip: 58 | 59 | ```bash 60 | pip install -r requirements.txt 61 | ``` 62 | 63 | ### Links for downloading the public Datasets: 64 | 65 | 1) MoNuSeG Dataset - Link (Original) 66 | 2) GLAS Dataset - Link (Original) 67 | 3) Brain Anatomy US dataset from the paper will be made public soon! 68 | 69 | ## Using the Code for your dataset 70 | 71 | ### Dataset Preparation 72 | 73 | Prepare the dataset in the following format for easy use of the code. The train and test folders should each contain two subfolders: img and labelcol. Make sure the images and their corresponding segmentation masks are placed under these folders and share the same file name so they can be matched. If you prefer not to prepare the dataset in this format, change the data loaders to your needs (a minimal loader sketch for this layout is included after the Acknowledgement section below). 74 | 75 | 76 | 77 | ```bash 78 | Train Folder----- 79 | img---- 80 | 0001.png 81 | 0002.png 82 | ....... 83 | labelcol--- 84 | 0001.png 85 | 0002.png 86 | ....... 87 | Validation Folder----- 88 | img---- 89 | 0001.png 90 | 0002.png 91 | ....... 92 | labelcol--- 93 | 0001.png 94 | 0002.png 95 | ....... 96 | Test Folder----- 97 | img---- 98 | 0001.png 99 | 0002.png 100 | ....... 101 | labelcol--- 102 | 0001.png 103 | 0002.png 104 | ....... 105 | 106 | ``` 107 | 108 | - The ground truth images should have pixel values corresponding to the labels. Example: in the case of binary segmentation, the pixels in the GT should be 0 or 255. 109 | 110 | ### Training Command: 111 | 112 | ```bash 113 | python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no" 114 | ``` 115 | 116 | Change `--modelname` to "MedT" or "logo" to train those models instead. 117 | 118 | 119 | 120 | ### Testing Command: 121 | 122 | ```bash 123 | python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no" 124 | ``` 125 | 126 | The results, including the predicted segmentation maps, will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB to calculate the F1 score and mIoU. 127 | 128 | ### Notes: 129 | 130 | 1) These experiments were conducted on an Nvidia Quadro 8000 GPU with 48 GB of memory. 131 | 2) The Google Colab code is an unofficial implementation for quick train/test. Please follow the original code for proper training. 132 | 133 | ### Acknowledgement: 134 | 135 | The dataloader code is inspired by pytorch-UNet. The axial attention code is developed from axial-deeplab. 
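For reference, here is a minimal, hypothetical loader matching the folder layout described in the Dataset Preparation section above. The repository's own loaders are in `utils.py` and `utils_gray.py`; this sketch only illustrates the img/labelcol pairing by file name and the 0/255 binary-mask convention.

```python
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class FolderSegDataset(Dataset):
    """Pairs <root>/img/NNNN.png with <root>/labelcol/NNNN.png by file name."""
    def __init__(self, root, imgsize=128):
        self.img_dir = os.path.join(root, 'img')
        self.label_dir = os.path.join(root, 'labelcol')
        self.names = sorted(os.listdir(self.img_dir))
        self.imgsize = imgsize

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        img = Image.open(os.path.join(self.img_dir, name)).convert('RGB')
        mask = Image.open(os.path.join(self.label_dir, name)).convert('L')
        img = img.resize((self.imgsize, self.imgsize))
        mask = mask.resize((self.imgsize, self.imgsize), Image.NEAREST)

        img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        # binary masks are stored as 0/255; convert to class indices 0/1
        mask = torch.from_numpy((np.array(mask) > 127).astype(np.int64))
        return img, mask
```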
136 | 137 | # Citation: 138 | 139 | ```bash 140 | @InProceedings{jose2021medical, 141 | author="Valanarasu, Jeya Maria Jose 142 | and Oza, Poojan 143 | and Hacihaliloglu, Ilker 144 | and Patel, Vishal M.", 145 | title="Medical Transformer: Gated Axial-Attention for Medical Image Segmentation", 146 | booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2021", 147 | year="2021", 148 | publisher="Springer International Publishing", 149 | address="Cham", 150 | pages="36--46", 151 | isbn="978-3-030-87193-2" 152 | } 153 | 154 | ``` 155 | 156 | Open an issue or mail me directly in case of any queries or suggestions. 157 | -------------------------------------------------------------------------------- /cmd.txt: -------------------------------------------------------------------------------- 1 | python train.py --train_dataset "/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/train/" --val_dataset "/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/test/" --direc "./results/axial128_en/" --batch_size 4 --modelname "logo" --epoch 401 --save_freq 50 --learning_rate 0.0001 --imgsize 128 2 | 3 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: medt 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - argon2-cffi=20.1.0=py36h8c4c3a4_1 8 | - attrs=20.1.0=pyh9f0ad1d_0 9 | - backcall=0.2.0=pyh9f0ad1d_0 10 | - backports=1.0=py_2 11 | - backports.functools_lru_cache=1.6.1=py_0 12 | - blas=1.0=mkl 13 | - bleach=3.1.5=pyh9f0ad1d_0 14 | - brotlipy=0.7.0=py36h8c4c3a4_1000 15 | - ca-certificates=2020.6.20=hecda079_0 16 | - certifi=2020.6.20=py36h9f0ad1d_0 17 | - cffi=1.11.5=py36_0 18 | - chardet=3.0.4=py36h9f0ad1d_1006 19 | - cryptography=3.1=py36h45558ae_0 20 | - decorator=4.4.2=py_0 21 | - defusedxml=0.6.0=py_0 22 | - entrypoints=0.3=py36h9f0ad1d_1001 23 | - idna=2.10=pyh9f0ad1d_0 24 | - importlib-metadata=1.7.0=py36h9f0ad1d_0 25 | - importlib_metadata=1.7.0=0 26 | - intel-openmp=2020.1=217 27 | - ipykernel=5.3.4=py36h95af2a2_0 28 | - ipython=7.16.1=py36h95af2a2_0 29 | - ipython_genutils=0.2.0=py_1 30 | - ipywidgets=7.5.1=py_0 31 | - jedi=0.17.2=py36h9f0ad1d_0 32 | - jinja2=2.11.2=pyh9f0ad1d_0 33 | - json5=0.9.4=pyh9f0ad1d_0 34 | - jsonschema=3.2.0=py36h9f0ad1d_1 35 | - jupyter_client=6.1.7=py_0 36 | - jupyter_core=4.6.3=py36h9f0ad1d_1 37 | - jupyterlab=2.2.6=py_0 38 | - jupyterlab_server=1.2.0=py_0 39 | - ld_impl_linux-64=2.33.1=h53a641e_7 40 | - libedit=3.1.20191231=h7b6447c_0 41 | - libffi=3.3=he6710b0_1 42 | - libgcc-ng=9.1.0=hdf63c60_0 43 | - libgfortran-ng=7.3.0=hdf63c60_0 44 | - libsodium=1.0.18=h516909a_0 45 | - libstdcxx-ng=9.1.0=hdf63c60_0 46 | - markupsafe=1.1.1=py36h8c4c3a4_1 47 | - mistune=0.8.4=py36h8c4c3a4_1001 48 | - mkl=2020.1=217 49 | - mkl-service=2.3.0=py36he904b0f_0 50 | - mkl_fft=1.1.0=py36h23d657b_0 51 | - mkl_random=1.1.1=py36h0573a6f_0 52 | - nbconvert=5.6.1=py36h9f0ad1d_1 53 | - nbformat=5.0.7=py_0 54 | - ncurses=6.2=he6710b0_1 55 | - notebook=6.1.3=py36h9f0ad1d_0 56 | - numpy=1.18.5=py36ha1c710e_0 57 | - numpy-base=1.18.5=py36hde5b4d6_0 58 | - openssl=1.1.1g=h516909a_1 59 | - packaging=20.4=pyh9f0ad1d_0 60 | - pandoc=2.10.1=h516909a_0 61 | - pandocfilters=1.4.2=py_1 62 | - parso=0.7.1=pyh9f0ad1d_0 63 | - pexpect=4.8.0=py36h9f0ad1d_1 64 | - pickleshare=0.7.5=py36h9f0ad1d_1001 65 | - 
pip=20.1.1=py36_1 66 | - prometheus_client=0.8.0=pyh9f0ad1d_0 67 | - prompt-toolkit=3.0.7=py_0 68 | - ptyprocess=0.6.0=py_1001 69 | - pycparser=2.20=pyh9f0ad1d_2 70 | - pygments=2.6.1=py_0 71 | - pyopenssl=19.1.0=py_1 72 | - pyparsing=2.4.7=pyh9f0ad1d_0 73 | - pyrsistent=0.16.0=py36h8c4c3a4_0 74 | - pysocks=1.7.1=py36h9f0ad1d_1 75 | - python=3.6.10=h7579374_2 76 | - python-dateutil=2.8.1=py_0 77 | - python_abi=3.6=1_cp36m 78 | - pyzmq=19.0.2=py36h9947dbf_0 79 | - readline=8.0=h7b6447c_0 80 | - requests=2.24.0=pyh9f0ad1d_0 81 | - send2trash=1.5.0=py_0 82 | - setuptools=47.3.1=py36_0 83 | - six=1.15.0=py_0 84 | - sqlite=3.32.3=h62c20be_0 85 | - terminado=0.8.3=py36h9f0ad1d_1 86 | - testpath=0.4.4=py_0 87 | - tk=8.6.10=hbc83047_0 88 | - tornado=6.0.4=py36h8c4c3a4_1 89 | - traitlets=4.3.3=py36h9f0ad1d_1 90 | - urllib3=1.25.10=py_0 91 | - wcwidth=0.2.5=pyh9f0ad1d_1 92 | - webencodings=0.5.1=py_1 93 | - wheel=0.34.2=py36_0 94 | - widgetsnbextension=3.5.1=py36h9f0ad1d_1 95 | - xz=5.2.5=h7b6447c_0 96 | - yaml=0.2.5=h7b6447c_0 97 | - zeromq=4.3.2=he1b5a44_3 98 | - zipp=3.1.0=py_0 99 | - zlib=1.2.11=h7b6447c_3 100 | - pip: 101 | - ci-info==0.2.0 102 | - click==7.1.2 103 | - cython==0.29.20 104 | - et-xmlfile==1.0.1 105 | - etelemetry==0.2.1 106 | - filelock==3.0.12 107 | - isodate==0.6.0 108 | - jdcal==1.4.1 109 | - joblib==0.17.0 110 | - lxml==4.5.1 111 | - matplotlib==3.3.2 112 | - medpy==0.4.0 113 | - natsort==7.0.1 114 | - nibabel==3.1.0 115 | - nipype==1.5.0 116 | - openpyxl==3.0.4 117 | - prov==1.5.3 118 | - pydicom==2.0.0 119 | - pydot==1.4.1 120 | - pydotplus==2.0.2 121 | - pynrrd==0.4.2 122 | - rdflib==5.0.0 123 | - scikit-learn==0.23.2 124 | - scipy==1.5.3 125 | - setproctitle==1.1.10 126 | - simplejson==3.17.0 127 | - threadpoolctl==2.1.0 128 | - torch==1.4.0 129 | - torch-dwconv==0.1.0 130 | - torchvision==0.4.0 131 | - traits==6.1.0 132 | prefix: /home/jeyamariajose/anaconda3/envs/medt 133 | 134 | -------------------------------------------------------------------------------- /extractors.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils import model_zoo 8 | from torchvision.models.densenet import densenet121, densenet161 9 | from torchvision.models.squeezenet import squeezenet1_1 10 | 11 | 12 | def load_weights_sequential(target, source_state): 13 | new_dict = OrderedDict() 14 | for (k1, v1), (k2, v2) in zip(target.state_dict().items(), source_state.items()): 15 | new_dict[k1] = v2 16 | target.load_state_dict(new_dict) 17 | 18 | ''' 19 | Implementation of dilated ResNet-101 with deep supervision. 
Downsampling is changed to 8x 20 | ''' 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=dilation, dilation=dilation, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 39 | super(BasicBlock, self).__init__() 40 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 71 | super(Bottleneck, self).__init__() 72 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(planes) 74 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, 75 | padding=dilation, bias=False) 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(planes * 4) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | def __init__(self, block, layers=(3, 4, 23, 3)): 108 | self.inplanes = 64 109 | super(ResNet, self).__init__() 110 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 111 | bias=False) 112 | self.bn1 = nn.BatchNorm2d(64) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 115 | self.layer1 = self._make_layer(block, 64, layers[0]) 116 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 117 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 118 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | nn.Conv2d(self.inplanes, planes * block.expansion, 133 | kernel_size=1, stride=stride, bias=False), 134 | nn.BatchNorm2d(planes * block.expansion), 135 | ) 136 | 137 | layers = [block(self.inplanes, planes, stride, downsample)] 138 | self.inplanes = planes * block.expansion 139 | for i in range(1, blocks): 140 | layers.append(block(self.inplanes, planes, dilation=dilation)) 141 | 142 | return nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | x = self.conv1(x) 146 | x = self.bn1(x) 147 | x = self.relu(x) 148 | x = self.maxpool(x) 149 | 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x_3 = self.layer3(x) 153 | x = self.layer4(x_3) 154 | 155 | return x, x_3 156 | 157 | 158 | ''' 159 | Implementation of DenseNet with deep supervision. Downsampling is changed to 8x 160 | ''' 161 | 162 | 163 | class _DenseLayer(nn.Sequential): 164 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, index): 165 | super(_DenseLayer, self).__init__() 166 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 167 | self.add_module('relu1', nn.ReLU(inplace=True)), 168 | if index == 3: 169 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 170 | growth_rate, kernel_size=1, stride=1, bias=False)), 171 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 172 | self.add_module('relu2', nn.ReLU(inplace=True)), 173 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 174 | kernel_size=3, stride=1, dilation=2, padding=2, bias=False)), 175 | else: 176 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 177 | growth_rate, kernel_size=1, stride=1, bias=False)), 178 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 179 | self.add_module('relu2', nn.ReLU(inplace=True)), 180 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 181 | kernel_size=3, stride=1, padding=1, bias=False)), 182 | self.drop_rate = drop_rate 183 | 184 | def forward(self, x): 185 | new_features = super(_DenseLayer, self).forward(x) 186 | if self.drop_rate > 0: 187 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 188 | return torch.cat([x, new_features], 1) 189 | 190 | 191 | class _DenseBlock(nn.Sequential): 192 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, index): 193 | super(_DenseBlock, self).__init__() 194 | for i in range(num_layers): 195 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, index) 196 | self.add_module('denselayer%d' % (i + 1), layer) 197 | 198 | 199 | class _Transition(nn.Sequential): 200 | def __init__(self, num_input_features, num_output_features, downsample=True): 201 | super(_Transition, self).__init__() 202 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 203 | self.add_module('relu', nn.ReLU(inplace=True)) 204 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 205 | kernel_size=1, stride=1, bias=False)) 206 | if downsample: 207 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 208 | else: 209 | self.add_module('pool', nn.AvgPool2d(kernel_size=1, stride=1)) # compatibility hack 210 | 211 | 212 | class 
DenseNet(nn.Module): 213 | def __init__(self, growth_rate=8, block_config=(6, 12, 24, 16), 214 | num_init_features=16, bn_size=4, drop_rate=0, pretrained=False): 215 | 216 | super(DenseNet, self).__init__() 217 | 218 | # First convolution 219 | self.start_features = nn.Sequential(OrderedDict([ 220 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 221 | ('norm0', nn.BatchNorm2d(num_init_features)), 222 | ('relu0', nn.ReLU(inplace=True)), 223 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 224 | ])) 225 | 226 | # Each denseblock 227 | num_features = num_init_features 228 | 229 | init_weights = list(densenet121(pretrained=True).features.children()) 230 | start = 0 231 | for i, c in enumerate(self.start_features.children()): 232 | #if pretrained: 233 | #c.load_state_dict(init_weights[i].state_dict()) 234 | start += 1 235 | self.blocks = nn.ModuleList() 236 | for i, num_layers in enumerate(block_config): 237 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 238 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, index = i) 239 | if pretrained: 240 | block.load_state_dict(init_weights[start].state_dict()) 241 | start += 1 242 | self.blocks.append(block) 243 | setattr(self, 'denseblock%d' % (i + 1), block) 244 | 245 | num_features = num_features + num_layers * growth_rate 246 | if i != len(block_config) - 1: 247 | downsample = i < 1 248 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, 249 | downsample=downsample) 250 | if pretrained: 251 | trans.load_state_dict(init_weights[start].state_dict()) 252 | start += 1 253 | self.blocks.append(trans) 254 | setattr(self, 'transition%d' % (i + 1), trans) 255 | num_features = num_features // 2 256 | 257 | def forward(self, x): 258 | out = self.start_features(x) 259 | deep_features = None 260 | for i, block in enumerate(self.blocks): 261 | out = block(out) 262 | if i == 5: 263 | deep_features = out 264 | 265 | return out, deep_features 266 | 267 | 268 | class Fire(nn.Module): 269 | 270 | def __init__(self, inplanes, squeeze_planes, 271 | expand1x1_planes, expand3x3_planes, dilation=1): 272 | super(Fire, self).__init__() 273 | self.inplanes = inplanes 274 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 275 | self.squeeze_activation = nn.ReLU(inplace=True) 276 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 277 | kernel_size=1) 278 | self.expand1x1_activation = nn.ReLU(inplace=True) 279 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 280 | kernel_size=3, padding=dilation, dilation=dilation) 281 | self.expand3x3_activation = nn.ReLU(inplace=True) 282 | 283 | def forward(self, x): 284 | x = self.squeeze_activation(self.squeeze(x)) 285 | return torch.cat([ 286 | self.expand1x1_activation(self.expand1x1(x)), 287 | self.expand3x3_activation(self.expand3x3(x)) 288 | ], 1) 289 | 290 | 291 | class SqueezeNet(nn.Module): 292 | 293 | def __init__(self, pretrained=False): 294 | super(SqueezeNet, self).__init__() 295 | 296 | self.feat_1 = nn.Sequential( 297 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 298 | nn.ReLU(inplace=True) 299 | ) 300 | self.feat_2 = nn.Sequential( 301 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 302 | Fire(64, 16, 64, 64), 303 | Fire(128, 16, 64, 64) 304 | ) 305 | self.feat_3 = nn.Sequential( 306 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 307 | Fire(128, 32, 128, 128, 2), 308 | Fire(256, 32, 128, 128, 2) 309 | ) 310 | self.feat_4 
= nn.Sequential( 311 | Fire(256, 48, 192, 192, 4), 312 | Fire(384, 48, 192, 192, 4), 313 | Fire(384, 64, 256, 256, 4), 314 | Fire(512, 64, 256, 256, 4) 315 | ) 316 | if pretrained: 317 | weights = squeezenet1_1(pretrained=True).features.state_dict() 318 | load_weights_sequential(self, weights) 319 | 320 | def forward(self, x): 321 | f1 = self.feat_1(x) 322 | f2 = self.feat_2(f1) 323 | f3 = self.feat_3(f2) 324 | f4 = self.feat_4(f3) 325 | return f4, f3 326 | 327 | 328 | ''' 329 | Handy methods for construction 330 | ''' 331 | 332 | 333 | def squeezenet(pretrained=True): 334 | return SqueezeNet(pretrained) 335 | 336 | 337 | def densenet(pretrained=True): 338 | return DenseNet(pretrained=pretrained) 339 | 340 | 341 | def resnet18(pretrained=True): 342 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 343 | if pretrained: 344 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18'])) 345 | return model 346 | 347 | 348 | def resnet34(pretrained=True): 349 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 350 | if pretrained: 351 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34'])) 352 | return model 353 | 354 | 355 | def resnet50(pretrained=True): 356 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 357 | if pretrained: 358 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50'])) 359 | return model 360 | 361 | 362 | def resnet101(pretrained=True): 363 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 364 | if pretrained: 365 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101'])) 366 | return model 367 | 368 | 369 | def resnet152(pretrained=True): 370 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 371 | if pretrained: 372 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152'])) 373 | return model 374 | -------------------------------------------------------------------------------- /img/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/arch.png -------------------------------------------------------------------------------- /img/medt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/medt.png -------------------------------------------------------------------------------- /img/medt1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/medt1.png -------------------------------------------------------------------------------- /img/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/poster.pdf -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_dataloader import build_dataloader 2 | from .build_model import build_model 3 | from .build_optimizer import build_optimizer 4 | from .metrics import Metric 5 | 6 | 7 | __all__ = ['build_dataloader', 'build_model', 'build_optimizer', 'Metric'] -------------------------------------------------------------------------------- 
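Referring back to `extractors.py` above: every constructor (`resnet18` through `resnet152`, `densenet`, `squeezenet`) returns a backbone whose forward pass yields a `(features, deep_features)` pair. Below is a quick, assumed shape check for the dilated ResNet-50 variant; the expected shapes follow from the stride-8 design, and the snippet is illustrative, not part of the repository.

```python
import torch
from extractors import resnet50

# pretrained=False avoids downloading ImageNet weights for a quick check
backbone = resnet50(pretrained=False)
x = torch.randn(1, 3, 224, 224)
features, deep_features = backbone(x)
# layer3/layer4 use dilation instead of stride, so both maps stay at 1/8 resolution
print(features.shape)       # expected: torch.Size([1, 2048, 28, 28])
print(deep_features.shape)  # expected: torch.Size([1, 1024, 28, 28])
```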
/lib/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/build_dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/build_dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/build_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_model.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/build_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_model.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/build_optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/build_optimizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_optimizer.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /lib/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/metrics.cpython-37.pyc 
-------------------------------------------------------------------------------- /lib/build_dataloader.py: -------------------------------------------------------------------------------- 1 | from . import datasets 2 | 3 | 4 | def build_dataloader(args, distributed=False): 5 | return datasets.__dict__[args.dataset](args, distributed) 6 | -------------------------------------------------------------------------------- /lib/build_model.py: -------------------------------------------------------------------------------- 1 | from . import models 2 | 3 | 4 | def build_model(args): 5 | model = models.__dict__[args.model](num_classes=args.num_classes) 6 | return model 7 | -------------------------------------------------------------------------------- /lib/build_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def build_optimizer(args, model): 5 | if args.optim == 'sgd': 6 | optimizer = optim.SGD(model.parameters(), lr=args.lr, 7 | momentum=args.momentum, weight_decay=args.weight_decay, 8 | nesterov=args.nesterov) 9 | else: 10 | raise AssertionError 11 | return optimizer 12 | 13 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagenet1k import imagenet1k 2 | 3 | 4 | __all__ = ['imagenet1k'] 5 | -------------------------------------------------------------------------------- /lib/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/imagenet1k.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/imagenet1k.cpython-36.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/imagenet1k.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/imagenet1k.cpython-37.pyc -------------------------------------------------------------------------------- /lib/datasets/imagenet1k.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import datasets, transforms 4 | 5 | 6 | def imagenet1k(args, distributed=False): 7 | train_dirs = args.train_dirs 8 | val_dirs = args.val_dirs 9 | batch_size = args.batch_size 10 | val_batch_size = args.val_batch_size 11 | num_workers = args.num_workers 12 | color_jitter = args.color_jitter 13 | 14 | normalize = transforms.Normalize(mean=[0.485, 
0.456, 0.406], std=[0.229, 0.224, 0.225]) 15 | process = [ 16 | transforms.RandomResizedCrop(224), 17 | transforms.RandomHorizontalFlip(), 18 | ] 19 | if color_jitter: 20 | process += [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)] 21 | process += [ 22 | transforms.ToTensor(), 23 | normalize 24 | ] 25 | 26 | transform_train = transforms.Compose(process) 27 | 28 | train_set = datasets.ImageFolder(train_dirs, 29 | transform=transform_train) 30 | 31 | if distributed: 32 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 33 | else: 34 | train_sampler = None 35 | 36 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=(train_sampler is None), 37 | sampler=train_sampler, num_workers=num_workers, pin_memory=True) 38 | 39 | transform_val = transforms.Compose( 40 | [transforms.Resize(256), 41 | transforms.CenterCrop(224), 42 | transforms.ToTensor(), 43 | normalize]) 44 | 45 | val_set = datasets.ImageFolder(root=val_dirs, 46 | transform=transform_val) 47 | 48 | if distributed: 49 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_set) 50 | else: 51 | val_sampler = None 52 | 53 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=val_batch_size, shuffle=False, 54 | sampler=val_sampler, num_workers=num_workers, pin_memory=True) 55 | 56 | return train_loader, train_sampler, val_loader, val_sampler 57 | -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Metric(object): 5 | def __init__(self, name): 6 | self.name = name 7 | self.sum = torch.tensor(0.) 8 | self.n = torch.tensor(0.) 9 | 10 | def update(self, val): 11 | self.sum += val.detach().cpu() 12 | self.n += 1 13 | 14 | @property 15 | def avg(self): 16 | return self.sum / self.n -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .axialnet import * 3 | -------------------------------------------------------------------------------- /lib/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/axialnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/axialnet.cpython-36.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/axialnet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/axialnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /lib/models/axialnet.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .utils import * 7 | import pdb 8 | import matplotlib.pyplot as plt 9 | 10 | import random 11 | 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class AxialAttention(nn.Module): 20 | def __init__(self, in_planes, out_planes, groups=8, kernel_size=56, 21 | stride=1, bias=False, width=False): 22 | assert (in_planes % groups == 0) and (out_planes % groups == 0) 23 | super(AxialAttention, self).__init__() 24 | self.in_planes = in_planes 25 | self.out_planes = out_planes 26 | self.groups = groups 27 | self.group_planes = out_planes // groups 28 | self.kernel_size = kernel_size 29 | self.stride = stride 30 | self.bias = bias 31 | self.width = width 32 | 33 | # Multi-head self attention 34 | self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1, 35 | padding=0, bias=False) 36 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2) 37 | self.bn_similarity = nn.BatchNorm2d(groups * 3) 38 | 39 | self.bn_output = nn.BatchNorm1d(out_planes * 2) 40 | 41 | # Position embedding 42 | self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True) 43 | query_index = torch.arange(kernel_size).unsqueeze(0) 44 | key_index = torch.arange(kernel_size).unsqueeze(1) 45 | relative_index = key_index - query_index + kernel_size - 1 46 | self.register_buffer('flatten_index', relative_index.view(-1)) 47 | if stride > 1: 48 | self.pooling = nn.AvgPool2d(stride, stride=stride) 49 | 50 | self.reset_parameters() 51 | 52 | def 
forward(self, x): 53 | # pdb.set_trace() 54 | if self.width: 55 | x = x.permute(0, 2, 1, 3) 56 | else: 57 | x = x.permute(0, 3, 1, 2) # N, W, C, H 58 | N, W, C, H = x.shape 59 | x = x.contiguous().view(N * W, C, H) 60 | 61 | # Transformations 62 | qkv = self.bn_qkv(self.qkv_transform(x)) 63 | q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2) 64 | 65 | # Calculate position embedding 66 | all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size) 67 | q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0) 68 | 69 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding) 70 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) 71 | 72 | qk = torch.einsum('bgci, bgcj->bgij', q, k) 73 | 74 | stacked_similarity = torch.cat([qk, qr, kr], dim=1) 75 | stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1) 76 | #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk) 77 | # (N, groups, H, H, W) 78 | similarity = F.softmax(stacked_similarity, dim=3) 79 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v) 80 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) 81 | stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H) 82 | output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2) 83 | 84 | if self.width: 85 | output = output.permute(0, 2, 1, 3) 86 | else: 87 | output = output.permute(0, 2, 3, 1) 88 | 89 | if self.stride > 1: 90 | output = self.pooling(output) 91 | 92 | return output 93 | 94 | def reset_parameters(self): 95 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes)) 96 | #nn.init.uniform_(self.relative, -0.1, 0.1) 97 | nn.init.normal_(self.relative, 0., math.sqrt(1. 
/ self.group_planes)) 98 | 99 | class AxialAttention_dynamic(nn.Module): 100 | def __init__(self, in_planes, out_planes, groups=8, kernel_size=56, 101 | stride=1, bias=False, width=False): 102 | assert (in_planes % groups == 0) and (out_planes % groups == 0) 103 | super(AxialAttention_dynamic, self).__init__() 104 | self.in_planes = in_planes 105 | self.out_planes = out_planes 106 | self.groups = groups 107 | self.group_planes = out_planes // groups 108 | self.kernel_size = kernel_size 109 | self.stride = stride 110 | self.bias = bias 111 | self.width = width 112 | 113 | # Multi-head self attention 114 | self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1, 115 | padding=0, bias=False) 116 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2) 117 | self.bn_similarity = nn.BatchNorm2d(groups * 3) 118 | self.bn_output = nn.BatchNorm1d(out_planes * 2) 119 | 120 | # Priority on encoding 121 | 122 | ## Initial values 123 | 124 | self.f_qr = nn.Parameter(torch.tensor(0.1), requires_grad=False) 125 | self.f_kr = nn.Parameter(torch.tensor(0.1), requires_grad=False) 126 | self.f_sve = nn.Parameter(torch.tensor(0.1), requires_grad=False) 127 | self.f_sv = nn.Parameter(torch.tensor(1.0), requires_grad=False) 128 | 129 | 130 | # Position embedding 131 | self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True) 132 | query_index = torch.arange(kernel_size).unsqueeze(0) 133 | key_index = torch.arange(kernel_size).unsqueeze(1) 134 | relative_index = key_index - query_index + kernel_size - 1 135 | self.register_buffer('flatten_index', relative_index.view(-1)) 136 | if stride > 1: 137 | self.pooling = nn.AvgPool2d(stride, stride=stride) 138 | 139 | self.reset_parameters() 140 | # self.print_para() 141 | 142 | def forward(self, x): 143 | if self.width: 144 | x = x.permute(0, 2, 1, 3) 145 | else: 146 | x = x.permute(0, 3, 1, 2) # N, W, C, H 147 | N, W, C, H = x.shape 148 | x = x.contiguous().view(N * W, C, H) 149 | 150 | # Transformations 151 | qkv = self.bn_qkv(self.qkv_transform(x)) 152 | q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2) 153 | 154 | # Calculate position embedding 155 | all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size) 156 | q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0) 157 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding) 158 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) 159 | qk = torch.einsum('bgci, bgcj->bgij', q, k) 160 | 161 | 162 | # multiply by factors 163 | qr = torch.mul(qr, self.f_qr) 164 | kr = torch.mul(kr, self.f_kr) 165 | 166 | stacked_similarity = torch.cat([qk, qr, kr], dim=1) 167 | stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1) 168 | #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk) 169 | # (N, groups, H, H, W) 170 | similarity = F.softmax(stacked_similarity, dim=3) 171 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v) 172 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) 173 | 174 | # multiply by factors 175 | sv = torch.mul(sv, self.f_sv) 176 | sve = torch.mul(sve, self.f_sve) 177 | 178 | stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H) 179 | output = 
self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2) 180 | 181 | if self.width: 182 | output = output.permute(0, 2, 1, 3) 183 | else: 184 | output = output.permute(0, 2, 3, 1) 185 | 186 | if self.stride > 1: 187 | output = self.pooling(output) 188 | 189 | return output 190 | def reset_parameters(self): 191 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes)) 192 | #nn.init.uniform_(self.relative, -0.1, 0.1) 193 | nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes)) 194 | 195 | class AxialAttention_wopos(nn.Module): 196 | def __init__(self, in_planes, out_planes, groups=8, kernel_size=56, 197 | stride=1, bias=False, width=False): 198 | assert (in_planes % groups == 0) and (out_planes % groups == 0) 199 | super(AxialAttention_wopos, self).__init__() 200 | self.in_planes = in_planes 201 | self.out_planes = out_planes 202 | self.groups = groups 203 | self.group_planes = out_planes // groups 204 | self.kernel_size = kernel_size 205 | self.stride = stride 206 | self.bias = bias 207 | self.width = width 208 | 209 | # Multi-head self attention 210 | self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1, 211 | padding=0, bias=False) 212 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2) 213 | self.bn_similarity = nn.BatchNorm2d(groups ) 214 | 215 | self.bn_output = nn.BatchNorm1d(out_planes * 1) 216 | 217 | if stride > 1: 218 | self.pooling = nn.AvgPool2d(stride, stride=stride) 219 | 220 | self.reset_parameters() 221 | 222 | def forward(self, x): 223 | if self.width: 224 | x = x.permute(0, 2, 1, 3) 225 | else: 226 | x = x.permute(0, 3, 1, 2) # N, W, C, H 227 | N, W, C, H = x.shape 228 | x = x.contiguous().view(N * W, C, H) 229 | 230 | # Transformations 231 | qkv = self.bn_qkv(self.qkv_transform(x)) 232 | q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2) 233 | 234 | qk = torch.einsum('bgci, bgcj->bgij', q, k) 235 | 236 | stacked_similarity = self.bn_similarity(qk).reshape(N * W, 1, self.groups, H, H).sum(dim=1).contiguous() 237 | 238 | similarity = F.softmax(stacked_similarity, dim=3) 239 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v) 240 | 241 | sv = sv.reshape(N*W,self.out_planes * 1, H).contiguous() 242 | output = self.bn_output(sv).reshape(N, W, self.out_planes, 1, H).sum(dim=-2).contiguous() 243 | 244 | 245 | if self.width: 246 | output = output.permute(0, 2, 1, 3) 247 | else: 248 | output = output.permute(0, 2, 3, 1) 249 | 250 | if self.stride > 1: 251 | output = self.pooling(output) 252 | 253 | return output 254 | 255 | def reset_parameters(self): 256 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes)) 257 | #nn.init.uniform_(self.relative, -0.1, 0.1) 258 | # nn.init.normal_(self.relative, 0., math.sqrt(1. 
/ self.group_planes)) 259 | 260 | #end of attn definition 261 | 262 | class AxialBlock(nn.Module): 263 | expansion = 2 264 | 265 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 266 | base_width=64, dilation=1, norm_layer=None, kernel_size=56): 267 | super(AxialBlock, self).__init__() 268 | if norm_layer is None: 269 | norm_layer = nn.BatchNorm2d 270 | width = int(planes * (base_width / 64.)) 271 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 272 | self.conv_down = conv1x1(inplanes, width) 273 | self.bn1 = norm_layer(width) 274 | self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size) 275 | self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True) 276 | self.conv_up = conv1x1(width, planes * self.expansion) 277 | self.bn2 = norm_layer(planes * self.expansion) 278 | self.relu = nn.ReLU(inplace=True) 279 | self.downsample = downsample 280 | self.stride = stride 281 | 282 | def forward(self, x): 283 | identity = x 284 | 285 | out = self.conv_down(x) 286 | out = self.bn1(out) 287 | out = self.relu(out) 288 | # print(out.shape) 289 | out = self.hight_block(out) 290 | out = self.width_block(out) 291 | out = self.relu(out) 292 | 293 | out = self.conv_up(out) 294 | out = self.bn2(out) 295 | 296 | if self.downsample is not None: 297 | identity = self.downsample(x) 298 | 299 | out += identity 300 | out = self.relu(out) 301 | 302 | return out 303 | 304 | class AxialBlock_dynamic(nn.Module): 305 | expansion = 2 306 | 307 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 308 | base_width=64, dilation=1, norm_layer=None, kernel_size=56): 309 | super(AxialBlock_dynamic, self).__init__() 310 | if norm_layer is None: 311 | norm_layer = nn.BatchNorm2d 312 | width = int(planes * (base_width / 64.)) 313 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 314 | self.conv_down = conv1x1(inplanes, width) 315 | self.bn1 = norm_layer(width) 316 | self.hight_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size) 317 | self.width_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True) 318 | self.conv_up = conv1x1(width, planes * self.expansion) 319 | self.bn2 = norm_layer(planes * self.expansion) 320 | self.relu = nn.ReLU(inplace=True) 321 | self.downsample = downsample 322 | self.stride = stride 323 | 324 | def forward(self, x): 325 | identity = x 326 | 327 | out = self.conv_down(x) 328 | out = self.bn1(out) 329 | out = self.relu(out) 330 | 331 | out = self.hight_block(out) 332 | out = self.width_block(out) 333 | out = self.relu(out) 334 | 335 | out = self.conv_up(out) 336 | out = self.bn2(out) 337 | 338 | if self.downsample is not None: 339 | identity = self.downsample(x) 340 | 341 | out += identity 342 | out = self.relu(out) 343 | 344 | return out 345 | 346 | class AxialBlock_wopos(nn.Module): 347 | expansion = 2 348 | 349 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 350 | base_width=64, dilation=1, norm_layer=None, kernel_size=56): 351 | super(AxialBlock_wopos, self).__init__() 352 | if norm_layer is None: 353 | norm_layer = nn.BatchNorm2d 354 | # print(kernel_size) 355 | width = int(planes * (base_width / 64.)) 356 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 357 | self.conv_down = conv1x1(inplanes, width) 358 | self.conv1 = nn.Conv2d(width, width, 
kernel_size = 1) 359 | self.bn1 = norm_layer(width) 360 | self.hight_block = AxialAttention_wopos(width, width, groups=groups, kernel_size=kernel_size) 361 | self.width_block = AxialAttention_wopos(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True) 362 | self.conv_up = conv1x1(width, planes * self.expansion) 363 | self.bn2 = norm_layer(planes * self.expansion) 364 | self.relu = nn.ReLU(inplace=True) 365 | self.downsample = downsample 366 | self.stride = stride 367 | 368 | def forward(self, x): 369 | identity = x 370 | 371 | # pdb.set_trace() 372 | 373 | out = self.conv_down(x) 374 | out = self.bn1(out) 375 | out = self.relu(out) 376 | # print(out.shape) 377 | out = self.hight_block(out) 378 | out = self.width_block(out) 379 | 380 | out = self.relu(out) 381 | 382 | out = self.conv_up(out) 383 | out = self.bn2(out) 384 | 385 | if self.downsample is not None: 386 | identity = self.downsample(x) 387 | 388 | out += identity 389 | out = self.relu(out) 390 | 391 | return out 392 | 393 | 394 | #end of block definition 395 | 396 | 397 | class ResAxialAttentionUNet(nn.Module): 398 | 399 | def __init__(self, block, layers, num_classes=2, zero_init_residual=True, 400 | groups=8, width_per_group=64, replace_stride_with_dilation=None, 401 | norm_layer=None, s=0.125, img_size = 128,imgchan = 3): 402 | super(ResAxialAttentionUNet, self).__init__() 403 | if norm_layer is None: 404 | norm_layer = nn.BatchNorm2d 405 | self._norm_layer = norm_layer 406 | 407 | self.inplanes = int(64 * s) 408 | self.dilation = 1 409 | if replace_stride_with_dilation is None: 410 | replace_stride_with_dilation = [False, False, False] 411 | if len(replace_stride_with_dilation) != 3: 412 | raise ValueError("replace_stride_with_dilation should be None " 413 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 414 | self.groups = groups 415 | self.base_width = width_per_group 416 | self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3, 417 | bias=False) 418 | self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False) 419 | self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 420 | self.bn1 = norm_layer(self.inplanes) 421 | self.bn2 = norm_layer(128) 422 | self.bn3 = norm_layer(self.inplanes) 423 | self.relu = nn.ReLU(inplace=True) 424 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 425 | self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2)) 426 | self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2), 427 | dilate=replace_stride_with_dilation[0]) 428 | self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4), 429 | dilate=replace_stride_with_dilation[1]) 430 | self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8), 431 | dilate=replace_stride_with_dilation[2]) 432 | 433 | # Decoder 434 | self.decoder1 = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1) 435 | self.decoder2 = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1) 436 | self.decoder3 = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1) 437 | self.decoder4 = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1) 438 | self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1) 439 | self.adjust = nn.Conv2d(int(128*s) , 
num_classes, kernel_size=1, stride=1, padding=0) 440 | self.soft = nn.Softmax(dim=1) 441 | 442 | 443 | def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False): 444 | norm_layer = self._norm_layer 445 | downsample = None 446 | previous_dilation = self.dilation 447 | if dilate: 448 | self.dilation *= stride 449 | stride = 1 450 | if stride != 1 or self.inplanes != planes * block.expansion: 451 | downsample = nn.Sequential( 452 | conv1x1(self.inplanes, planes * block.expansion, stride), 453 | norm_layer(planes * block.expansion), 454 | ) 455 | 456 | layers = [] 457 | layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups, 458 | base_width=self.base_width, dilation=previous_dilation, 459 | norm_layer=norm_layer, kernel_size=kernel_size)) 460 | self.inplanes = planes * block.expansion 461 | if stride != 1: 462 | kernel_size = kernel_size // 2 463 | 464 | for _ in range(1, blocks): 465 | layers.append(block(self.inplanes, planes, groups=self.groups, 466 | base_width=self.base_width, dilation=self.dilation, 467 | norm_layer=norm_layer, kernel_size=kernel_size)) 468 | 469 | return nn.Sequential(*layers) 470 | 471 | def _forward_impl(self, x): 472 | 473 | # AxialAttention Encoder 474 | # pdb.set_trace() 475 | x = self.conv1(x) 476 | x = self.bn1(x) 477 | x = self.relu(x) 478 | x = self.conv2(x) 479 | x = self.bn2(x) 480 | x = self.relu(x) 481 | x = self.conv3(x) 482 | x = self.bn3(x) 483 | x = self.relu(x) 484 | 485 | x1 = self.layer1(x) 486 | 487 | x2 = self.layer2(x1) 488 | # print(x2.shape) 489 | x3 = self.layer3(x2) 490 | # print(x3.shape) 491 | x4 = self.layer4(x3) 492 | 493 | x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear')) 494 | x = torch.add(x, x4) 495 | x = F.relu(F.interpolate(self.decoder2(x) , scale_factor=(2,2), mode ='bilinear')) 496 | x = torch.add(x, x3) 497 | x = F.relu(F.interpolate(self.decoder3(x) , scale_factor=(2,2), mode ='bilinear')) 498 | x = torch.add(x, x2) 499 | x = F.relu(F.interpolate(self.decoder4(x) , scale_factor=(2,2), mode ='bilinear')) 500 | x = torch.add(x, x1) 501 | x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear')) 502 | x = self.adjust(F.relu(x)) 503 | # pdb.set_trace() 504 | return x 505 | 506 | def forward(self, x): 507 | return self._forward_impl(x) 508 | 509 | class medt_net(nn.Module): 510 | 511 | def __init__(self, block, block_2, layers, num_classes=2, zero_init_residual=True, 512 | groups=8, width_per_group=64, replace_stride_with_dilation=None, 513 | norm_layer=None, s=0.125, img_size = 128,imgchan = 3): 514 | super(medt_net, self).__init__() 515 | if norm_layer is None: 516 | norm_layer = nn.BatchNorm2d 517 | self._norm_layer = norm_layer 518 | 519 | self.inplanes = int(64 * s) 520 | self.dilation = 1 521 | if replace_stride_with_dilation is None: 522 | replace_stride_with_dilation = [False, False, False] 523 | if len(replace_stride_with_dilation) != 3: 524 | raise ValueError("replace_stride_with_dilation should be None " 525 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 526 | self.groups = groups 527 | self.base_width = width_per_group 528 | self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3, 529 | bias=False) 530 | self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False) 531 | self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 532 | self.bn1 = norm_layer(self.inplanes) 533 | self.bn2 = 
norm_layer(128) 534 | self.bn3 = norm_layer(self.inplanes) 535 | # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 536 | self.bn1 = norm_layer(self.inplanes) 537 | self.relu = nn.ReLU(inplace=True) 538 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 539 | self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2)) 540 | self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2), 541 | dilate=replace_stride_with_dilation[0]) 542 | # self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4), 543 | # dilate=replace_stride_with_dilation[1]) 544 | # self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8), 545 | # dilate=replace_stride_with_dilation[2]) 546 | 547 | # Decoder 548 | # self.decoder1 = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1) 549 | # self.decoder2 = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1) 550 | # self.decoder3 = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1) 551 | self.decoder4 = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1) 552 | self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1) 553 | self.adjust = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0) 554 | self.soft = nn.Softmax(dim=1) 555 | 556 | 557 | self.conv1_p = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3, 558 | bias=False) 559 | self.conv2_p = nn.Conv2d(self.inplanes,128, kernel_size=3, stride=1, padding=1, 560 | bias=False) 561 | self.conv3_p = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, 562 | bias=False) 563 | # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 564 | self.bn1_p = norm_layer(self.inplanes) 565 | self.bn2_p = norm_layer(128) 566 | self.bn3_p = norm_layer(self.inplanes) 567 | 568 | self.relu_p = nn.ReLU(inplace=True) 569 | 570 | img_size_p = img_size // 4 571 | 572 | self.layer1_p = self._make_layer(block_2, int(128 * s), layers[0], kernel_size= (img_size_p//2)) 573 | self.layer2_p = self._make_layer(block_2, int(256 * s), layers[1], stride=2, kernel_size=(img_size_p//2), 574 | dilate=replace_stride_with_dilation[0]) 575 | self.layer3_p = self._make_layer(block_2, int(512 * s), layers[2], stride=2, kernel_size=(img_size_p//4), 576 | dilate=replace_stride_with_dilation[1]) 577 | self.layer4_p = self._make_layer(block_2, int(1024 * s), layers[3], stride=2, kernel_size=(img_size_p//8), 578 | dilate=replace_stride_with_dilation[2]) 579 | 580 | # Decoder 581 | self.decoder1_p = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1) 582 | self.decoder2_p = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1) 583 | self.decoder3_p = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1) 584 | self.decoder4_p = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1) 585 | self.decoder5_p = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1) 586 | 587 | self.decoderf = nn.Conv2d(int(128*s) , int(128*s) , kernel_size=3, stride=1, padding=1) 588 | self.adjust_p = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0) 589 | self.soft_p = nn.Softmax(dim=1) 590 | 591 | 592 | def _make_layer(self, block, planes, blocks, kernel_size=56, 
stride=1, dilate=False): 593 | norm_layer = self._norm_layer 594 | downsample = None 595 | previous_dilation = self.dilation 596 | if dilate: 597 | self.dilation *= stride 598 | stride = 1 599 | if stride != 1 or self.inplanes != planes * block.expansion: 600 | downsample = nn.Sequential( 601 | conv1x1(self.inplanes, planes * block.expansion, stride), 602 | norm_layer(planes * block.expansion), 603 | ) 604 | 605 | layers = [] 606 | layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups, 607 | base_width=self.base_width, dilation=previous_dilation, 608 | norm_layer=norm_layer, kernel_size=kernel_size)) 609 | self.inplanes = planes * block.expansion 610 | if stride != 1: 611 | kernel_size = kernel_size // 2 612 | 613 | for _ in range(1, blocks): 614 | layers.append(block(self.inplanes, planes, groups=self.groups, 615 | base_width=self.base_width, dilation=self.dilation, 616 | norm_layer=norm_layer, kernel_size=kernel_size)) 617 | 618 | return nn.Sequential(*layers) 619 | 620 | def _forward_impl(self, x): 621 | 622 | xin = x.clone() 623 | x = self.conv1(x) 624 | x = self.bn1(x) 625 | x = self.relu(x) 626 | x = self.conv2(x) 627 | x = self.bn2(x) 628 | x = self.relu(x) 629 | x = self.conv3(x) 630 | x = self.bn3(x) 631 | # x = F.max_pool2d(x,2,2) 632 | x = self.relu(x) 633 | 634 | # x = self.maxpool(x) 635 | # pdb.set_trace() 636 | x1 = self.layer1(x) 637 | # print(x1.shape) 638 | x2 = self.layer2(x1) 639 | # print(x2.shape) 640 | # x3 = self.layer3(x2) 641 | # # print(x3.shape) 642 | # x4 = self.layer4(x3) 643 | # # print(x4.shape) 644 | # x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear')) 645 | # x = torch.add(x, x4) 646 | # x = F.relu(F.interpolate(self.decoder2(x4) , scale_factor=(2,2), mode ='bilinear')) 647 | # x = torch.add(x, x3) 648 | # x = F.relu(F.interpolate(self.decoder3(x3) , scale_factor=(2,2), mode ='bilinear')) 649 | # x = torch.add(x, x2) 650 | x = F.relu(F.interpolate(self.decoder4(x2) , scale_factor=(2,2), mode ='bilinear')) 651 | x = torch.add(x, x1) 652 | x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear')) 653 | # print(x.shape) 654 | 655 | # end of full image training 656 | 657 | # y_out = torch.ones((1,2,128,128)) 658 | x_loc = x.clone() 659 | # x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear')) 660 | #start 661 | for i in range(0,4): 662 | for j in range(0,4): 663 | 664 | x_p = xin[:,:,32*i:32*(i+1),32*j:32*(j+1)] 665 | # begin patch wise 666 | x_p = self.conv1_p(x_p) 667 | x_p = self.bn1_p(x_p) 668 | # x = F.max_pool2d(x,2,2) 669 | x_p = self.relu(x_p) 670 | 671 | x_p = self.conv2_p(x_p) 672 | x_p = self.bn2_p(x_p) 673 | # x = F.max_pool2d(x,2,2) 674 | x_p = self.relu(x_p) 675 | x_p = self.conv3_p(x_p) 676 | x_p = self.bn3_p(x_p) 677 | # x = F.max_pool2d(x,2,2) 678 | x_p = self.relu(x_p) 679 | 680 | # x = self.maxpool(x) 681 | # pdb.set_trace() 682 | x1_p = self.layer1_p(x_p) 683 | # print(x1.shape) 684 | x2_p = self.layer2_p(x1_p) 685 | # print(x2.shape) 686 | x3_p = self.layer3_p(x2_p) 687 | # # print(x3.shape) 688 | x4_p = self.layer4_p(x3_p) 689 | 690 | x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2,2), mode ='bilinear')) 691 | x_p = torch.add(x_p, x4_p) 692 | x_p = F.relu(F.interpolate(self.decoder2_p(x_p) , scale_factor=(2,2), mode ='bilinear')) 693 | x_p = torch.add(x_p, x3_p) 694 | x_p = F.relu(F.interpolate(self.decoder3_p(x_p) , scale_factor=(2,2), mode ='bilinear')) 695 | x_p = torch.add(x_p, x2_p) 696 | x_p = 
F.relu(F.interpolate(self.decoder4_p(x_p) , scale_factor=(2,2), mode ='bilinear')) 697 | x_p = torch.add(x_p, x1_p) 698 | x_p = F.relu(F.interpolate(self.decoder5_p(x_p) , scale_factor=(2,2), mode ='bilinear')) 699 | 700 | x_loc[:,:,32*i:32*(i+1),32*j:32*(j+1)] = x_p 701 | 702 | x = torch.add(x,x_loc) 703 | x = F.relu(self.decoderf(x)) 704 | 705 | x = self.adjust(F.relu(x)) 706 | 707 | # pdb.set_trace() 708 | return x 709 | 710 | def forward(self, x): 711 | return self._forward_impl(x) 712 | 713 | 714 | def axialunet(pretrained=False, **kwargs): 715 | model = ResAxialAttentionUNet(AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs) 716 | return model 717 | 718 | def gated(pretrained=False, **kwargs): 719 | model = ResAxialAttentionUNet(AxialBlock_dynamic, [1, 2, 4, 1], s= 0.125, **kwargs) 720 | return model 721 | 722 | def MedT(pretrained=False, **kwargs): 723 | model = medt_net(AxialBlock_dynamic,AxialBlock_wopos, [1, 2, 4, 1], s= 0.125, **kwargs) 724 | return model 725 | 726 | def logo(pretrained=False, **kwargs): 727 | model = medt_net(AxialBlock,AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs) 728 | return model 729 | 730 | # EOF -------------------------------------------------------------------------------- /lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['ResNet', 'resnet26', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152',] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=dilation, groups=groups, bias=False, dilation=dilation) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 37 | base_width=64, dilation=1, norm_layer=None): 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = 
downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 76 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 77 | # This variant is also known as ResNet V1.5 and improves accuracy according to 78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 79 | 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 83 | base_width=64, dilation=1, norm_layer=None): 84 | super(Bottleneck, self).__init__() 85 | if norm_layer is None: 86 | norm_layer = nn.BatchNorm2d 87 | width = int(planes * (base_width / 64.)) * groups 88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 89 | self.conv1 = conv1x1(inplanes, width) 90 | self.bn1 = norm_layer(width) 91 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 92 | self.bn2 = norm_layer(width) 93 | self.conv3 = conv1x1(width, planes * self.expansion) 94 | self.bn3 = norm_layer(planes * self.expansion) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | identity = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv3(out) 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | out += identity 117 | out = self.relu(out) 118 | 119 | return out 120 | 121 | 122 | class ResNet(nn.Module): 123 | 124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 125 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 126 | norm_layer=None): 127 | super(ResNet, self).__init__() 128 | if norm_layer is None: 129 | norm_layer = nn.BatchNorm2d 130 | self._norm_layer = norm_layer 131 | 132 | self.inplanes = 64 133 | self.dilation = 1 134 | if replace_stride_with_dilation is None: 135 | # each element in the tuple indicates if we should replace 136 | # the 2x2 stride with a dilated convolution instead 137 | replace_stride_with_dilation = [False, False, False] 138 | if len(replace_stride_with_dilation) != 3: 139 | raise ValueError("replace_stride_with_dilation should be None " 140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 141 | self.groups = groups 142 | self.base_width = width_per_group 143 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = norm_layer(self.inplanes) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 150 | dilate=replace_stride_with_dilation[0]) 151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 152 | 
dilate=replace_stride_with_dilation[1]) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 154 | dilate=replace_stride_with_dilation[2]) 155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 162 | nn.init.constant_(m.weight, 1) 163 | nn.init.constant_(m.bias, 0) 164 | 165 | # Zero-initialize the last BN in each residual branch, 166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 168 | if zero_init_residual: 169 | for m in self.modules(): 170 | if isinstance(m, Bottleneck): 171 | nn.init.constant_(m.bn3.weight, 0) 172 | elif isinstance(m, BasicBlock): 173 | nn.init.constant_(m.bn2.weight, 0) 174 | 175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 176 | norm_layer = self._norm_layer 177 | downsample = None 178 | previous_dilation = self.dilation 179 | if dilate: 180 | self.dilation *= stride 181 | stride = 1 182 | if stride != 1 or self.inplanes != planes * block.expansion: 183 | downsample = nn.Sequential( 184 | conv1x1(self.inplanes, planes * block.expansion, stride), 185 | norm_layer(planes * block.expansion), 186 | ) 187 | 188 | layers = [] 189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 190 | self.base_width, previous_dilation, norm_layer)) 191 | self.inplanes = planes * block.expansion 192 | for _ in range(1, blocks): 193 | layers.append(block(self.inplanes, planes, groups=self.groups, 194 | base_width=self.base_width, dilation=self.dilation, 195 | norm_layer=norm_layer)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def _forward_impl(self, x): 200 | # See note [TorchScript super()] 201 | x = self.conv1(x) 202 | x = self.bn1(x) 203 | x = self.relu(x) 204 | x = self.maxpool(x) 205 | 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | 211 | x = self.avgpool(x) 212 | x = torch.flatten(x, 1) 213 | x = self.fc(x) 214 | 215 | return x 216 | 217 | def forward(self, x): 218 | return self._forward_impl(x) 219 | 220 | 221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 222 | model = ResNet(block, layers, **kwargs) 223 | if pretrained: 224 | state_dict = load_state_dict_from_url(model_urls[arch], 225 | progress=progress) 226 | model.load_state_dict(state_dict) 227 | return model 228 | 229 | 230 | def resnet18(pretrained=False, progress=True, **kwargs): 231 | r"""ResNet-18 model from 232 | `"Deep Residual Learning for Image Recognition" `_ 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | progress (bool): If True, displays a progress bar of the download to stderr 236 | """ 237 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 238 | **kwargs) 239 | 240 | 241 | def resnet34(pretrained=False, progress=True, **kwargs): 242 | r"""ResNet-34 model from 243 | `"Deep Residual Learning for Image Recognition" `_ 244 | Args: 245 | pretrained (bool): If True, returns a model pre-trained on ImageNet 246 | progress (bool): If True, displays a progress bar of the download to stderr 247 | """ 248 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 249 | **kwargs) 
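# ---------------------------------------------------------------------------
# Editorial note (not part of the original lib/models/resnet.py): `_resnet`
# above calls `load_state_dict_from_url`, which this file never imports, so
# any factory invoked with pretrained=True would raise a NameError as written.
# The repo's own entry points only build these backbones with the default
# pretrained=False; a likely fix (an assumption, not the author's code) is
#     from torch.hub import load_state_dict_from_url
# at the top of the file. Minimal usage sketch of the factories defined here:
if __name__ == "__main__":
    _demo = resnet34(pretrained=False, num_classes=10)  # BasicBlock, layers [3, 4, 6, 3]
    _x = torch.randn(2, 3, 224, 224)
    print(_demo(_x).shape)  # expected: torch.Size([2, 10])
# ---------------------------------------------------------------------------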
250 | 251 | 252 | def resnet26(pretrained=False, progress=True, **kwargs): 253 | return _resnet('resnet26', Bottleneck, [1, 2, 4, 1], pretrained, progress, 254 | **kwargs) 255 | 256 | 257 | def resnet50(pretrained=False, progress=True, **kwargs): 258 | r"""ResNet-50 model from 259 | `"Deep Residual Learning for Image Recognition" `_ 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 265 | **kwargs) 266 | 267 | 268 | def resnet101(pretrained=False, progress=True, **kwargs): 269 | r"""ResNet-101 model from 270 | `"Deep Residual Learning for Image Recognition" `_ 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | progress (bool): If True, displays a progress bar of the download to stderr 274 | """ 275 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 276 | **kwargs) 277 | 278 | 279 | def resnet152(pretrained=False, progress=True, **kwargs): 280 | r"""ResNet-152 model from 281 | `"Deep Residual Learning for Image Recognition" `_ 282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | progress (bool): If True, displays a progress bar of the download to stderr 285 | """ 286 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 287 | **kwargs) 288 | -------------------------------------------------------------------------------- /lib/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class qkv_transform(nn.Conv1d): 5 | """Conv1d for qkv_transform""" 6 | 7 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def adjust_learning_rate(args, optimizer, epoch, batch_idx, data_nums, type="cosine"): 8 | if epoch < args.warmup_epochs: 9 | epoch += float(batch_idx + 1) / data_nums 10 | lr_adj = 1. * (epoch / args.warmup_epochs) 11 | elif type == "linear": 12 | if epoch < 30 + args.warmup_epochs: 13 | lr_adj = 1. 
14 | elif epoch < 60 + args.warmup_epochs: 15 | lr_adj = 1e-1 16 | elif epoch < 90 + args.warmup_epochs: 17 | lr_adj = 1e-2 18 | else: 19 | lr_adj = 1e-3 20 | elif type == "cosine": 21 | run_epochs = epoch - args.warmup_epochs 22 | total_epochs = args.epochs - args.warmup_epochs 23 | T_cur = float(run_epochs * data_nums) + batch_idx 24 | T_total = float(total_epochs * data_nums) 25 | 26 | lr_adj = 0.5 * (1 + math.cos(math.pi * T_cur / T_total)) 27 | 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = args.lr * lr_adj 30 | return args.lr * lr_adj 31 | 32 | 33 | def label_smoothing(pred, target, eta=0.1): 34 | ''' 35 | Refer from https://arxiv.org/pdf/1512.00567.pdf 36 | :param target: N, 37 | :param n_classes: int 38 | :param eta: float 39 | :return: 40 | N x C onehot smoothed vector 41 | ''' 42 | n_classes = pred.size(1) 43 | target = torch.unsqueeze(target, 1) 44 | onehot_target = torch.zeros_like(pred) 45 | onehot_target.scatter_(1, target, 1) 46 | return onehot_target * (1 - eta) + eta / n_classes * 1 47 | 48 | 49 | def cross_entropy_for_onehot(pred, target): 50 | return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1)) 51 | 52 | 53 | def cross_entropy_with_label_smoothing(pred, target, eta=0.1): 54 | onehot_target = label_smoothing(pred, target, eta=eta) 55 | return cross_entropy_for_onehot(pred, onehot_target) 56 | 57 | 58 | def accuracy(output, target): 59 | # get the index of the max log-probability 60 | pred = output.max(1, keepdim=True)[1] 61 | return pred.eq(target.view_as(pred)).cpu().float().mean() 62 | 63 | 64 | def save_model(model, optimizer, epoch, args): 65 | os.system('mkdir -p {}'.format(args.work_dirs)) 66 | if optimizer is not None: 67 | torch.save({ 68 | 'net': model.state_dict(), 69 | 'optim': optimizer.state_dict(), 70 | 'epoch': epoch 71 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch))) 72 | else: 73 | torch.save({ 74 | 'net': model.state_dict(), 75 | 'epoch': epoch 76 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch))) 77 | 78 | 79 | def dist_save_model(model, optimizer, epoch, ngpus_per_node, args): 80 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 81 | and args.rank % ngpus_per_node == 0): 82 | os.system('mkdir -p {}'.format(args.work_dirs)) 83 | if optimizer is not None: 84 | torch.save({ 85 | 'net': model.state_dict(), 86 | 'optim': optimizer.state_dict(), 87 | 'epoch': epoch 88 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch))) 89 | else: 90 | torch.save({ 91 | 'net': model.state_dict(), 92 | 'epoch': epoch 93 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch))) 94 | 95 | 96 | def load_model(network, args): 97 | if not os.path.exists(args.work_dirs): 98 | print("No such working directory!") 99 | raise AssertionError 100 | 101 | pths = [pth.split('.')[0] for pth in os.listdir(args.work_dirs) if 'pth' in pth] 102 | if len(pths) == 0: 103 | print("No model to load!") 104 | raise AssertionError 105 | 106 | pths = [int(pth) for pth in pths] 107 | if args.test_model == -1: 108 | pth = -1 109 | if pth in pths: 110 | pass 111 | else: 112 | pth = max(pths) 113 | else: 114 | pth = args.test_model 115 | try: 116 | if args.distributed: 117 | loc = 'cuda:{}'.format(args.gpu) 118 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth)), map_location=loc) 119 | except: 120 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth))) 121 | try: 122 | network.load_state_dict(model['net'], strict=True) 123 | except: 124 | 
network.load_state_dict(convert_model(model['net']), strict=True) 125 | return True 126 | 127 | 128 | def resume_model(network, optimizer, args): 129 | print("Loading the model...") 130 | if not os.path.exists(args.work_dirs): 131 | print("No such working directory!") 132 | return 0 133 | pths = [pth.split('.')[0] for pth in os.listdir(args.work_dirs) if 'pth' in pth] 134 | if len(pths) == 0: 135 | print("No model to load!") 136 | return 0 137 | pths = [int(pth) for pth in pths] 138 | if args.test_model == -1: 139 | pth = max(pths) 140 | else: 141 | pth = args.test_model 142 | try: 143 | if args.distributed: 144 | loc = 'cuda:{}'.format(args.gpu) 145 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth)), map_location=loc) 146 | except: 147 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth))) 148 | try: 149 | network.load_state_dict(model['net'], strict=True) 150 | except: 151 | network.load_state_dict(convert_model(model['net']), strict=True) 152 | optimizer.load_state_dict(model['optim']) 153 | for state in optimizer.state.values(): 154 | for k, v in state.items(): 155 | if torch.is_tensor(v): 156 | try: 157 | state[k] = v.cuda(args.gpu) 158 | except: 159 | state[k] = v.cuda() 160 | return model['epoch'] 161 | 162 | 163 | def convert_model(model): 164 | new_model = {} 165 | for k in model.keys(): 166 | new_model[k[7:]] = model[k] 167 | return new_model 168 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import cross_entropy 3 | from torch.nn.modules.loss import _WeightedLoss 4 | 5 | 6 | EPSILON = 1e-32 7 | 8 | 9 | class LogNLLLoss(_WeightedLoss): 10 | __constants__ = ['weight', 'reduction', 'ignore_index'] 11 | 12 | def __init__(self, weight=None, size_average=None, reduce=None, reduction=None, 13 | ignore_index=-100): 14 | super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction) 15 | self.ignore_index = ignore_index 16 | 17 | def forward(self, y_input, y_target): 18 | # y_input = torch.log(y_input + EPSILON) 19 | return cross_entropy(y_input, y_target, weight=self.weight, 20 | ignore_index=self.ignore_index) 21 | 22 | 23 | def classwise_iou(output, gt): 24 | """ 25 | Args: 26 | output: torch.Tensor of shape (n_batch, n_classes, image.shape) 27 | gt: torch.LongTensor of shape (n_batch, image.shape) 28 | """ 29 | dims = (0, *range(2, len(output.shape))) 30 | gt = torch.zeros_like(output).scatter_(1, gt[:, None, :], 1) 31 | intersection = output*gt 32 | union = output + gt - intersection 33 | classwise_iou = (intersection.sum(dim=dims).float() + EPSILON) / (union.sum(dim=dims) + EPSILON) 34 | 35 | return classwise_iou 36 | 37 | 38 | def classwise_f1(output, gt): 39 | """ 40 | Args: 41 | output: torch.Tensor of shape (n_batch, n_classes, image.shape) 42 | gt: torch.LongTensor of shape (n_batch, image.shape) 43 | """ 44 | 45 | epsilon = 1e-20 46 | n_classes = output.shape[1] 47 | 48 | output = torch.argmax(output, dim=1) 49 | true_positives = torch.tensor([((output == i) * (gt == i)).sum() for i in range(n_classes)]).float() 50 | selected = torch.tensor([(output == i).sum() for i in range(n_classes)]).float() 51 | relevant = torch.tensor([(gt == i).sum() for i in range(n_classes)]).float() 52 | 53 | precision = (true_positives + epsilon) / (selected + epsilon) 54 | recall = (true_positives + epsilon) / (relevant + epsilon) 55 | classwise_f1 = 2 * (precision * recall) 
/ (precision + recall) 56 | 57 | return classwise_f1 58 | 59 | 60 | def make_weighted_metric(classwise_metric): 61 | """ 62 | Args: 63 | classwise_metric: classwise metric like classwise_IOU or classwise_F1 64 | """ 65 | 66 | def weighted_metric(output, gt, weights=None): 67 | 68 | # dimensions to sum over 69 | dims = (0, *range(2, len(output.shape))) 70 | 71 | # default weights 72 | if weights == None: 73 | weights = torch.ones(output.shape[1]) / output.shape[1] 74 | else: 75 | # creating tensor if needed 76 | if len(weights) != output.shape[1]: 77 | raise ValueError("The number of weights must match with the number of classes") 78 | if not isinstance(weights, torch.Tensor): 79 | weights = torch.tensor(weights) 80 | # normalizing weights 81 | weights /= torch.sum(weights) 82 | 83 | classwise_scores = classwise_metric(output, gt).cpu() 84 | 85 | return classwise_scores 86 | 87 | return weighted_metric 88 | 89 | 90 | jaccard_index = make_weighted_metric(classwise_iou) 91 | f1_score = make_weighted_metric(classwise_f1) 92 | 93 | 94 | if __name__ == '__main__': 95 | output, gt = torch.zeros(3, 2, 5, 5), torch.zeros(3, 5, 5).long() 96 | print(classwise_iou(output, gt)) 97 | -------------------------------------------------------------------------------- /performancemetrics_ax.m: -------------------------------------------------------------------------------- 1 | 2 | % close all; 3 | % clear all; 4 | % clc; 5 | N = 328 6 | st = 0; 7 | Fsc=[]; 8 | MIU=[]; 9 | PA=[]; 10 | bestfsc=0; 11 | bestmiu=0; 12 | bestpa=0; 13 | bestep = 0; 14 | 15 | for k = 0:8 16 | k 17 | Fsc=[]; 18 | MIU=[]; 19 | PA=[]; 20 | for i = st:st+N 21 | i; 22 | %gname = strcat('./Brain_test/',num2str(i,'%04d'),'.png'); 23 | 24 | tname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Projects/axialseg/KiU-Net-pytorch/results/brainus/mix_3_gated_wopos/'; 25 | imgname = strcat(tname,num2str(50*k),'/',num2str(i,'%04d'),'.png'); 26 | lname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/test/labelcol/'; 27 | labelname = strcat(lname, num2str(i,'%04d'),'.png'); 28 | 29 | I = double(imread(imgname));tmp2=zeros(128,128); 30 | tmp2(I>131) = 255; 31 | tmp2(I<130) = 0; 32 | tmp = double(imread(labelname)); 33 | tmp = tmp(:,:,1); 34 | tmp(tmp<130)=0;tmp(tmp>131)=255; 35 | 36 | tp=0;fp=0;fn=0;tn=0;uni=0;ttp=0;lab=0; 37 | 38 | for p =1:128 39 | for q =1:128 40 | if tmp(p,q)==0 41 | if tmp2(p,q) == tmp(p,q) 42 | tn = tn+1; 43 | else 44 | fp = fp+1; 45 | uni = uni+1; 46 | ttp = ttp+1; 47 | end 48 | elseif tmp(p,q)==255 49 | lab = lab +1; 50 | if tmp2(p,q) == tmp(p,q) 51 | tp = tp+1; 52 | ttp = ttp+1; 53 | else 54 | fn = fn+1; 55 | end 56 | uni = uni+1; 57 | end 58 | 59 | end 60 | end 61 | 62 | if (tp~=0) 63 | F = (2*tp)/(2*tp+fp+fn); 64 | MIU=[MIU,(tp*1.0/uni)]; 65 | PA=[PA,(tp*1.0/ttp)]; 66 | Fsc=[Fsc;[i,F]]; 67 | else 68 | MIU=[MIU,1]; 69 | PA=[PA,1]; 70 | Fsc=[Fsc;[i,1]]; 71 | 72 | end 73 | 74 | 75 | 76 | end 77 | if bestfsc <= mean(Fsc) & (mean(Fsc) ~= 1) 78 | bestfsc = mean(Fsc); 79 | bestmiu = mean(MIU,2); 80 | bestpa = mean(PA,2); 81 | bestep = 50*k; 82 | 83 | end 84 | mean(Fsc) 85 | end 86 | 87 | bestfsc 88 | bestmiu 89 | bestpa 90 | bestep 91 | 92 | % plot(Fsc(:,1),Fsc(:,2),'-*') 93 | % hold on 94 | % plot(Fsc(:,1),Fsc1(:,2),'-s') 95 | % hold off 96 | % figure();plot(Fsc(:,1),PA,'-*');hold on 97 | % plot(Fsc(:,1),PA1,'-s');hold off 98 | % Fsc1=Fsc; 99 | % MIU1=MIU; 100 | % PA1=PA; 101 | -------------------------------------------------------------------------------- 
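The three MATLAB evaluation scripts (performancemetrics_ax.m above and the GlaS / MoNuSeg variants that follow) share one recipe: binarize the predicted and ground-truth masks around a fixed gray-level threshold, count per-pixel TP/FP/FN/TN, record per-image F1 (Dice), tp/uni (foreground IoU) and tp/ttp (i.e. tp/(tp+fp), which the scripts label PA), and keep the checkpoint epoch with the best mean F1. Below is a minimal NumPy sketch of that per-image computation, offered as an illustration only: the thresholds (126-131 depending on the dataset) and the convention of scoring images with tp == 0 as 1 follow the scripts, while paths and the epoch sweep are omitted.

import numpy as np

def per_image_scores(pred_gray, label_gray, thresh=130):
    # Binarize both masks the way the scripts do (exact threshold differs slightly per dataset).
    p = pred_gray.astype(np.float64) > thresh
    g = label_gray.astype(np.float64) > thresh
    tp = int(np.logical_and(p, g).sum())
    fp = int(np.logical_and(p, ~g).sum())
    fn = int(np.logical_and(~p, g).sum())
    if tp == 0:
        return 1.0, 1.0, 1.0                 # the scripts score such images as 1 across the board
    f1 = 2.0 * tp / (2 * tp + fp + fn)       # "Fsc"
    iou = tp / float(tp + fp + fn)           # "MIU": tp / uni
    prec = tp / float(tp + fp)               # "PA":  tp / ttp
    return f1, iou, prec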
/performancemetrics_glas.m: -------------------------------------------------------------------------------- 1 | 2 | % close all; 3 | % clear all; 4 | % clc; 5 | N = 79 6 | st = 1; 7 | Fsc=[]; 8 | MIU=[]; 9 | PA=[]; 10 | bestfsc=0; 11 | bestmiu=0; 12 | bestpa=0; 13 | bestep = 0; 14 | 15 | for k = 1:24 16 | k 17 | Fsc=[]; 18 | MIU=[]; 19 | PA=[]; 20 | for i = st:st+N 21 | i; 22 | %gname = strcat('./Brain_test/',num2str(i,'%04d'),'.png'); 23 | 24 | tname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Projects/axialseg/KiU-Net-pytorch/results/glas/medT/'; 25 | imgname = strcat(tname,num2str(50*k),'/',num2str(i,'%02d'),'.png'); 26 | lname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/glas/resized/test/labelcol/'; 27 | 28 | labelname = strcat(lname, num2str(i,'%02d'),'.png'); 29 | 30 | I = double(imread(imgname));tmp2=zeros(128,128); 31 | tmp2(I>130) = 255; 32 | tmp2(I<131) = 0; 33 | tmp = double(imread(labelname)); 34 | tmp = tmp(:,:,1); 35 | tmp(tmp<130)=0;tmp(tmp>131)=255; 36 | 37 | 38 | 39 | tp=0;fp=0;fn=0;tn=0;uni=0;ttp=0;lab=0; 40 | 41 | for p =1:128 42 | for q =1:128 43 | if tmp(p,q)==0 44 | if tmp2(p,q) == tmp(p,q) 45 | tn = tn+1; 46 | else 47 | fp = fp+1; 48 | uni = uni+1; 49 | ttp = ttp+1; 50 | end 51 | elseif tmp(p,q)==255 52 | lab = lab +1; 53 | if tmp2(p,q) == tmp(p,q) 54 | tp = tp+1; 55 | ttp = ttp+1; 56 | else 57 | fn = fn+1; 58 | end 59 | uni = uni+1; 60 | end 61 | 62 | end 63 | end 64 | 65 | 66 | if (tp~=0) 67 | F = (2*tp)/(2*tp+fp+fn); 68 | MIU=[MIU,(tp*1.0/uni)]; 69 | PA=[PA,(tp*1.0/ttp)]; 70 | Fsc=[Fsc;[i,F]]; 71 | 72 | else 73 | MIU=[MIU,1]; 74 | PA=[PA,1]; 75 | Fsc=[Fsc;[i,1]]; 76 | 77 | end 78 | 79 | 80 | 81 | end 82 | if bestfsc <= mean(Fsc) & (mean(Fsc) ~= 1) 83 | bestfsc = mean(Fsc); 84 | bestmiu = mean(MIU,2); 85 | bestpa = mean(PA,2); 86 | bestep = 50*k; 87 | 88 | end 89 | mean(Fsc) 90 | end 91 | 92 | bestfsc 93 | bestmiu 94 | bestpa 95 | bestep 96 | 97 | -------------------------------------------------------------------------------- /performancemetrics_monuseg.m: -------------------------------------------------------------------------------- 1 | 2 | % close all; 3 | % clear all; 4 | % clc; 5 | N = 328 6 | st = 0; 7 | Fsc=[]; 8 | MIU=[]; 9 | PA=[]; 10 | bestfsc=0; 11 | bestmiu=0; 12 | bestpa=0; 13 | bestep = 0; 14 | 15 | folder = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/monuseg/resized/test/labelcol/'; 16 | listinfo = dir(strcat(folder,'*.png')); 17 | lm = length(listinfo); 18 | 19 | 20 | for k = 1:10 21 | k 22 | Fsc=[]; 23 | MIU=[]; 24 | PA=[]; 25 | for i = 1:lm 26 | %I = double(imread(strcat(folder,listinfo(i).name))); 27 | imgfile = strcat(folder,listinfo(i).name); 28 | imgname = listinfo(i).name(1:27) ; 29 | i; 30 | %gname = strcat('./Brain_test/',num2str(i,'%04d'),'.png'); 31 | 32 | lname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Projects/axialseg/KiU-Net-pytorch/results/monuseg/medTr/'; 33 | labelname = strcat(lname, num2str(k*10),'/', imgname); 34 | %imgname 35 | I = double(imread(imgfile));tmp2=zeros(512,512); 36 | %I = rgb2gray(I); 37 | tmp2(I>127) = 255; 38 | tmp2(I<126) = 0; 39 | tmp = double(imread(labelname)); 40 | 41 | tmp(tmp<127)=0;tmp(tmp>126)=255; 42 | %tmp2 = I; 43 | tp=0;fp=0;fn=0;tn=0;uni=0;ttp=0;lab=0; 44 | 45 | for p =1:512 46 | for q =1:512 47 | if tmp(p,q)==0 48 | if tmp2(p,q) == tmp(p,q) 49 | tn = tn+1; 50 | else 51 | fp = fp+1; 52 | uni = uni+1; 53 | ttp = ttp+1; 54 | end 55 | elseif tmp(p,q)==255 56 | lab = lab +1; 57 | if tmp2(p,q) == tmp(p,q) 58 | tp 
= tp+1; 59 | ttp = ttp+1; 60 | else 61 | fn = fn+1; 62 | end 63 | uni = uni+1; 64 | end 65 | 66 | end 67 | end 68 | 69 | if (tp~=0) 70 | F = (2*tp)/(2*tp+fp+fn); 71 | MIU=[MIU,(tp*1.0/uni)]; 72 | PA=[PA,(tp*1.0/ttp)]; 73 | Fsc=[Fsc;[i,F]]; 74 | else 75 | MIU=[MIU,1]; 76 | PA=[PA,1]; 77 | Fsc=[Fsc;[i,1]]; 78 | 79 | end 80 | 81 | 82 | 83 | end 84 | 85 | if bestfsc <= mean(Fsc) & (mean(Fsc) ~= 1) 86 | bestfsc = mean(Fsc); 87 | bestmiu = mean(MIU,2); 88 | bestpa = mean(PA,2); 89 | bestep = 10*k; 90 | 91 | end 92 | mean(Fsc) 93 | end 94 | 95 | bestfsc 96 | bestmiu 97 | %bestpa 98 | bestep 99 | 100 | % plot(Fsc(:,1),Fsc(:,2),'-*') 101 | % hold on 102 | % plot(Fsc(:,1),Fsc1(:,2),'-s') 103 | % hold off 104 | % figure();plot(Fsc(:,1),PA,'-*');hold on 105 | % plot(Fsc(:,1),PA1,'-s');hold off 106 | % Fsc1=Fsc; 107 | % MIU1=MIU; 108 | % PA1=PA; 109 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | torchvision>=0.5.0 3 | scikit-learn==0.23.2 4 | scipy==1.5.3 5 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import lib 3 | import torch 4 | import torchvision 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | from torchvision.utils import save_image 10 | from torchvision.datasets import MNIST 11 | import torch.nn.functional as F 12 | import os 13 | import matplotlib.pyplot as plt 14 | import torch.utils.data as data 15 | from PIL import Image 16 | import numpy as np 17 | from torchvision.utils import save_image 18 | import torch 19 | import torch.nn.init as init 20 | from utils import JointTransform2D, ImageToImage2D, Image2D 21 | from metrics import jaccard_index, f1_score, LogNLLLoss,classwise_f1 22 | from utils import chk_mkdir, Logger, MetricList 23 | import cv2 24 | from functools import partial 25 | from random import randint 26 | 27 | 28 | parser = argparse.ArgumentParser(description='MedT') 29 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 30 | help='number of data loading workers (default: 8)') 31 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 32 | help='number of total epochs to run(default: 1)') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch_size', default=1, type=int, 36 | metavar='N', help='batch size (default: 8)') 37 | parser.add_argument('--learning_rate', default=1e-3, type=float, 38 | metavar='LR', help='initial learning rate (default: 0.01)') 39 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 40 | help='momentum') 41 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, 42 | metavar='W', help='weight decay (default: 1e-4)') 43 | parser.add_argument('--train_dataset', type=str) 44 | parser.add_argument('--val_dataset', type=str) 45 | parser.add_argument('--save_freq', type=int,default = 5) 46 | parser.add_argument('--modelname', default='off', type=str, 47 | help='name of the model to load') 48 | parser.add_argument('--cuda', default="on", type=str, 49 | help='switch on/off cuda option (default: off)') 50 | 51 | parser.add_argument('--direc', default='./results', type=str, 
52 | help='directory to save') 53 | parser.add_argument('--crop', type=int, default=None) 54 | parser.add_argument('--device', default='cuda', type=str) 55 | parser.add_argument('--loaddirec', default='load', type=str) 56 | parser.add_argument('--imgsize', type=int, default=None) 57 | parser.add_argument('--gray', default='no', type=str) 58 | args = parser.parse_args() 59 | 60 | direc = args.direc 61 | gray_ = args.gray 62 | aug = args.aug 63 | direc = args.direc 64 | modelname = args.modelname 65 | imgsize = args.imgsize 66 | loaddirec = args.loaddirec 67 | 68 | if gray_ == "yes": 69 | from utils_gray import JointTransform2D, ImageToImage2D, Image2D 70 | imgchant = 1 71 | else: 72 | from utils import JointTransform2D, ImageToImage2D, Image2D 73 | imgchant = 3 74 | 75 | if args.crop is not None: 76 | crop = (args.crop, args.crop) 77 | else: 78 | crop = None 79 | 80 | tf_train = JointTransform2D(crop=crop, p_flip=0.5, color_jitter_params=None, long_mask=True) 81 | tf_val = JointTransform2D(crop=crop, p_flip=0, color_jitter_params=None, long_mask=True) 82 | train_dataset = ImageToImage2D(args.train_dataset, tf_val) 83 | val_dataset = ImageToImage2D(args.val_dataset, tf_val) 84 | predict_dataset = Image2D(args.val_dataset) 85 | dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 86 | valloader = DataLoader(val_dataset, 1, shuffle=True) 87 | 88 | device = torch.device("cuda") 89 | 90 | if modelname == "axialunet": 91 | model = lib.models.axialunet(img_size = imgsize, imgchan = imgchant) 92 | elif modelname == "MedT": 93 | model = lib.models.axialnet.MedT(img_size = imgsize, imgchan = imgchant) 94 | elif modelname == "gatedaxialunet": 95 | model = lib.models.axialnet.gated(img_size = imgsize, imgchan = imgchant) 96 | elif modelname == "logo": 97 | model = lib.models.axialnet.logo(img_size = imgsize, imgchan = imgchant) 98 | 99 | if torch.cuda.device_count() > 1: 100 | print("Let's use", torch.cuda.device_count(), "GPUs!") 101 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] 
on 3 GPUs 102 | model = nn.DataParallel(model,device_ids=[0,1]).cuda() 103 | model.to(device) 104 | 105 | model.load_state_dict(torch.load(loaddirec)) 106 | model.eval() 107 | 108 | 109 | for batch_idx, (X_batch, y_batch, *rest) in enumerate(valloader): 110 | # print(batch_idx) 111 | if isinstance(rest[0][0], str): 112 | image_filename = rest[0][0] 113 | else: 114 | image_filename = '%s.png' % str(batch_idx + 1).zfill(3) 115 | 116 | X_batch = Variable(X_batch.to(device='cuda')) 117 | y_batch = Variable(y_batch.to(device='cuda')) 118 | 119 | y_out = model(X_batch) 120 | 121 | tmp2 = y_batch.detach().cpu().numpy() 122 | tmp = y_out.detach().cpu().numpy() 123 | tmp[tmp>=0.5] = 1 124 | tmp[tmp<0.5] = 0 125 | tmp2[tmp2>0] = 1 126 | tmp2[tmp2<=0] = 0 127 | tmp2 = tmp2.astype(int) 128 | tmp = tmp.astype(int) 129 | 130 | # print(np.unique(tmp2)) 131 | yHaT = tmp 132 | yval = tmp2 133 | 134 | epsilon = 1e-20 135 | 136 | del X_batch, y_batch,tmp,tmp2, y_out 137 | 138 | yHaT[yHaT==1] =255 139 | yval[yval==1] =255 140 | fulldir = direc+"/" 141 | 142 | if not os.path.isdir(fulldir): 143 | 144 | os.makedirs(fulldir) 145 | 146 | cv2.imwrite(fulldir+image_filename, yHaT[0,1,:,:]) 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Code for MedT 2 | 3 | import torch 4 | import lib 5 | import argparse 6 | import torch 7 | import torchvision 8 | from torch import nn 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | from torchvision import transforms 12 | from torchvision.utils import save_image 13 | import torch.nn.functional as F 14 | import os 15 | import matplotlib.pyplot as plt 16 | import torch.utils.data as data 17 | from PIL import Image 18 | import numpy as np 19 | from torchvision.utils import save_image 20 | import torch 21 | import torch.nn.init as init 22 | from utils import JointTransform2D, ImageToImage2D, Image2D 23 | from metrics import jaccard_index, f1_score, LogNLLLoss,classwise_f1 24 | from utils import chk_mkdir, Logger, MetricList 25 | import cv2 26 | from functools import partial 27 | from random import randint 28 | import timeit 29 | 30 | parser = argparse.ArgumentParser(description='MedT') 31 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 32 | help='number of data loading workers (default: 8)') 33 | parser.add_argument('--epochs', default=400, type=int, metavar='N', 34 | help='number of total epochs to run(default: 400)') 35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 36 | help='manual epoch number (useful on restarts)') 37 | parser.add_argument('-b', '--batch_size', default=1, type=int, 38 | metavar='N', help='batch size (default: 1)') 39 | parser.add_argument('--learning_rate', default=1e-3, type=float, 40 | metavar='LR', help='initial learning rate (default: 0.001)') 41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 42 | help='momentum') 43 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, 44 | metavar='W', help='weight decay (default: 1e-5)') 45 | parser.add_argument('--train_dataset', required=True, type=str) 46 | parser.add_argument('--val_dataset', type=str) 47 | parser.add_argument('--save_freq', type=int,default = 10) 48 | 49 | parser.add_argument('--modelname', default='MedT', type=str, 50 | help='type of model') 51 | parser.add_argument('--cuda', default="on", type=str, 52 | 
help='switch on/off cuda option (default: off)') 53 | parser.add_argument('--aug', default='off', type=str, 54 | help='turn on img augmentation (default: False)') 55 | parser.add_argument('--load', default='default', type=str, 56 | help='load a pretrained model') 57 | parser.add_argument('--save', default='default', type=str, 58 | help='save the model') 59 | parser.add_argument('--direc', default='./medt', type=str, 60 | help='directory to save') 61 | parser.add_argument('--crop', type=int, default=None) 62 | parser.add_argument('--imgsize', type=int, default=None) 63 | parser.add_argument('--device', default='cuda', type=str) 64 | parser.add_argument('--gray', default='no', type=str) 65 | 66 | args = parser.parse_args() 67 | gray_ = args.gray 68 | aug = args.aug 69 | direc = args.direc 70 | modelname = args.modelname 71 | imgsize = args.imgsize 72 | 73 | if gray_ == "yes": 74 | from utils_gray import JointTransform2D, ImageToImage2D, Image2D 75 | imgchant = 1 76 | else: 77 | from utils import JointTransform2D, ImageToImage2D, Image2D 78 | imgchant = 3 79 | 80 | if args.crop is not None: 81 | crop = (args.crop, args.crop) 82 | else: 83 | crop = None 84 | 85 | tf_train = JointTransform2D(crop=crop, p_flip=0.5, color_jitter_params=None, long_mask=True) 86 | tf_val = JointTransform2D(crop=crop, p_flip=0, color_jitter_params=None, long_mask=True) 87 | train_dataset = ImageToImage2D(args.train_dataset, tf_train) 88 | val_dataset = ImageToImage2D(args.val_dataset, tf_val) 89 | predict_dataset = Image2D(args.val_dataset) 90 | dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 91 | valloader = DataLoader(val_dataset, 1, shuffle=True) 92 | 93 | device = torch.device("cuda") 94 | 95 | if modelname == "axialunet": 96 | model = lib.models.axialunet(img_size = imgsize, imgchan = imgchant) 97 | elif modelname == "MedT": 98 | model = lib.models.axialnet.MedT(img_size = imgsize, imgchan = imgchant) 99 | elif modelname == "gatedaxialunet": 100 | model = lib.models.axialnet.gated(img_size = imgsize, imgchan = imgchant) 101 | elif modelname == "logo": 102 | model = lib.models.axialnet.logo(img_size = imgsize, imgchan = imgchant) 103 | 104 | if torch.cuda.device_count() > 1: 105 | print("Let's use", torch.cuda.device_count(), "GPUs!") 106 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] 
on 3 GPUs 107 | model = nn.DataParallel(model,device_ids=[0,1]).cuda() 108 | model.to(device) 109 | 110 | criterion = LogNLLLoss() 111 | optimizer = torch.optim.Adam(list(model.parameters()), lr=args.learning_rate, 112 | weight_decay=1e-5) 113 | 114 | 115 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 116 | print("Total_params: {}".format(pytorch_total_params)) 117 | 118 | seed = 3000 119 | np.random.seed(seed) 120 | torch.manual_seed(seed) 121 | torch.cuda.manual_seed(seed) 122 | # torch.set_deterministic(True) 123 | # random.seed(seed) 124 | 125 | 126 | for epoch in range(args.epochs): 127 | 128 | epoch_running_loss = 0 129 | 130 | for batch_idx, (X_batch, y_batch, *rest) in enumerate(dataloader): 131 | 132 | 133 | 134 | X_batch = Variable(X_batch.to(device ='cuda')) 135 | y_batch = Variable(y_batch.to(device='cuda')) 136 | 137 | # ===================forward===================== 138 | 139 | 140 | output = model(X_batch) 141 | 142 | tmp2 = y_batch.detach().cpu().numpy() 143 | tmp = output.detach().cpu().numpy() 144 | tmp[tmp>=0.5] = 1 145 | tmp[tmp<0.5] = 0 146 | tmp2[tmp2>0] = 1 147 | tmp2[tmp2<=0] = 0 148 | tmp2 = tmp2.astype(int) 149 | tmp = tmp.astype(int) 150 | 151 | yHaT = tmp 152 | yval = tmp2 153 | 154 | 155 | 156 | loss = criterion(output, y_batch) 157 | 158 | # ===================backward==================== 159 | optimizer.zero_grad() 160 | loss.backward() 161 | optimizer.step() 162 | epoch_running_loss += loss.item() 163 | 164 | # ===================log======================== 165 | print('epoch [{}/{}], loss:{:.4f}' 166 | .format(epoch, args.epochs, epoch_running_loss/(batch_idx+1))) 167 | 168 | 169 | if epoch == 10: 170 | for param in model.parameters(): 171 | param.requires_grad =True 172 | if (epoch % args.save_freq) ==0: 173 | 174 | for batch_idx, (X_batch, y_batch, *rest) in enumerate(valloader): 175 | # print(batch_idx) 176 | if isinstance(rest[0][0], str): 177 | image_filename = rest[0][0] 178 | else: 179 | image_filename = '%s.png' % str(batch_idx + 1).zfill(3) 180 | 181 | X_batch = Variable(X_batch.to(device='cuda')) 182 | y_batch = Variable(y_batch.to(device='cuda')) 183 | # start = timeit.default_timer() 184 | y_out = model(X_batch) 185 | # stop = timeit.default_timer() 186 | # print('Time: ', stop - start) 187 | tmp2 = y_batch.detach().cpu().numpy() 188 | tmp = y_out.detach().cpu().numpy() 189 | tmp[tmp>=0.5] = 1 190 | tmp[tmp<0.5] = 0 191 | tmp2[tmp2>0] = 1 192 | tmp2[tmp2<=0] = 0 193 | tmp2 = tmp2.astype(int) 194 | tmp = tmp.astype(int) 195 | 196 | # print(np.unique(tmp2)) 197 | yHaT = tmp 198 | yval = tmp2 199 | 200 | epsilon = 1e-20 201 | 202 | del X_batch, y_batch,tmp,tmp2, y_out 203 | 204 | 205 | yHaT[yHaT==1] =255 206 | yval[yval==1] =255 207 | fulldir = direc+"/{}/".format(epoch) 208 | # print(fulldir+image_filename) 209 | if not os.path.isdir(fulldir): 210 | 211 | os.makedirs(fulldir) 212 | 213 | cv2.imwrite(fulldir+image_filename, yHaT[0,1,:,:]) 214 | # cv2.imwrite(fulldir+'/gt_{}.png'.format(count), yval[0,:,:]) 215 | fulldir = direc+"/{}/".format(epoch) 216 | torch.save(model.state_dict(), fulldir+args.modelname+".pth") 217 | torch.save(model.state_dict(), direc+"final_model.pth") 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | from skimage import io,color 6 | from PIL import Image 7 | from 
torch.utils.data import Dataset 8 | from torchvision import transforms as T 9 | from torchvision.transforms import functional as F 10 | 11 | from typing import Callable 12 | import os 13 | import cv2 14 | import pandas as pd 15 | 16 | from numbers import Number 17 | from typing import Container 18 | from collections import defaultdict 19 | 20 | 21 | def to_long_tensor(pic): 22 | # handle numpy array 23 | img = torch.from_numpy(np.array(pic, np.uint8)) 24 | # backward compatibility 25 | return img.long() 26 | 27 | 28 | def correct_dims(*images): 29 | corr_images = [] 30 | # print(images) 31 | for img in images: 32 | if len(img.shape) == 2: 33 | corr_images.append(np.expand_dims(img, axis=2)) 34 | else: 35 | corr_images.append(img) 36 | 37 | if len(corr_images) == 1: 38 | return corr_images[0] 39 | else: 40 | return corr_images 41 | 42 | 43 | class JointTransform2D: 44 | """ 45 | Performs augmentation on image and mask when called. Due to the randomness of augmentation transforms, 46 | it is not enough to simply apply the same Transform from torchvision on the image and mask separetely. 47 | Doing this will result in messing up the ground truth mask. To circumvent this problem, this class can 48 | be used, which will take care of the problems above. 49 | 50 | Args: 51 | crop: tuple describing the size of the random crop. If bool(crop) evaluates to False, no crop will 52 | be taken. 53 | p_flip: float, the probability of performing a random horizontal flip. 54 | color_jitter_params: tuple describing the parameters of torchvision.transforms.ColorJitter. 55 | If bool(color_jitter_params) evaluates to false, no color jitter transformation will be used. 56 | p_random_affine: float, the probability of performing a random affine transform using 57 | torchvision.transforms.RandomAffine. 58 | long_mask: bool, if True, returns the mask as LongTensor in label-encoded format. 59 | """ 60 | def __init__(self, crop=(32, 32), p_flip=0.5, color_jitter_params=(0.1, 0.1, 0.1, 0.1), 61 | p_random_affine=0, long_mask=False): 62 | self.crop = crop 63 | self.p_flip = p_flip 64 | self.color_jitter_params = color_jitter_params 65 | if color_jitter_params: 66 | self.color_tf = T.ColorJitter(*color_jitter_params) 67 | self.p_random_affine = p_random_affine 68 | self.long_mask = long_mask 69 | 70 | def __call__(self, image, mask): 71 | # transforming to PIL image 72 | image, mask = F.to_pil_image(image), F.to_pil_image(mask) 73 | 74 | # random crop 75 | if self.crop: 76 | i, j, h, w = T.RandomCrop.get_params(image, self.crop) 77 | image, mask = F.crop(image, i, j, h, w), F.crop(mask, i, j, h, w) 78 | 79 | if np.random.rand() < self.p_flip: 80 | image, mask = F.hflip(image), F.hflip(mask) 81 | 82 | # color transforms || ONLY ON IMAGE 83 | if self.color_jitter_params: 84 | image = self.color_tf(image) 85 | 86 | # random affine transform 87 | if np.random.rand() < self.p_random_affine: 88 | affine_params = T.RandomAffine(180).get_params((-90, 90), (1, 1), (2, 2), (-45, 45), self.crop) 89 | image, mask = F.affine(image, *affine_params), F.affine(mask, *affine_params) 90 | 91 | # transforming to tensor 92 | image = F.to_tensor(image) 93 | if not self.long_mask: 94 | mask = F.to_tensor(mask) 95 | else: 96 | mask = to_long_tensor(mask) 97 | 98 | return image, mask 99 | 100 | 101 | class ImageToImage2D(Dataset): 102 | """ 103 | Reads the images and applies the augmentation transform on them. 104 | Usage: 105 | 1. 
If used without the unet.model.Model wrapper, an instance of this object should be passed to 106 | torch.utils.data.DataLoader. Iterating through this returns the tuple of image, mask and image 107 | filename. 108 | 2. With unet.model.Model wrapper, an instance of this object should be passed as train or validation 109 | datasets. 110 | 111 | Args: 112 | dataset_path: path to the dataset. Structure of the dataset should be: 113 | dataset_path 114 | |-- images 115 | |-- img001.png 116 | |-- img002.png 117 | |-- ... 118 | |-- masks 119 | |-- img001.png 120 | |-- img002.png 121 | |-- ... 122 | 123 | joint_transform: augmentation transform, an instance of JointTransform2D. If bool(joint_transform) 124 | evaluates to False, torchvision.transforms.ToTensor will be used on both image and mask. 125 | one_hot_mask: bool, if True, returns the mask in one-hot encoded form. 126 | """ 127 | 128 | def __init__(self, dataset_path: str, joint_transform: Callable = None, one_hot_mask: int = False) -> None: 129 | self.dataset_path = dataset_path 130 | self.input_path = os.path.join(dataset_path, 'img') 131 | self.output_path = os.path.join(dataset_path, 'labelcol') 132 | self.images_list = os.listdir(self.input_path) 133 | self.one_hot_mask = one_hot_mask 134 | 135 | if joint_transform: 136 | self.joint_transform = joint_transform 137 | else: 138 | to_tensor = T.ToTensor() 139 | self.joint_transform = lambda x, y: (to_tensor(x), to_tensor(y)) 140 | 141 | def __len__(self): 142 | return len(os.listdir(self.input_path)) 143 | 144 | def __getitem__(self, idx): 145 | image_filename = self.images_list[idx] 146 | #print(image_filename[: -3]) 147 | # read image 148 | # print(os.path.join(self.input_path, image_filename)) 149 | # print(os.path.join(self.output_path, image_filename[: -3] + "png")) 150 | # print(os.path.join(self.input_path, image_filename)) 151 | image = cv2.imread(os.path.join(self.input_path, image_filename)) 152 | # print(image.shape) 153 | # read mask image 154 | mask = cv2.imread(os.path.join(self.output_path, image_filename[: -3] + "png"),0) 155 | 156 | mask[mask<=127] = 0 157 | mask[mask>127] = 1 158 | # correct dimensions if needed 159 | image, mask = correct_dims(image, mask) 160 | # print(image.shape) 161 | 162 | if self.joint_transform: 163 | image, mask = self.joint_transform(image, mask) 164 | 165 | if self.one_hot_mask: 166 | assert self.one_hot_mask > 0, 'one_hot_mask must be nonnegative' 167 | mask = torch.zeros((self.one_hot_mask, mask.shape[1], mask.shape[2])).scatter_(0, mask.long(), 1) 168 | # mask = np.swapaxes(mask,2,0) 169 | # print(image.shape) 170 | # print(mask.shape) 171 | # mask = np.transpose(mask,(2,0,1)) 172 | # image = np.transpose(image,(2,0,1)) 173 | # print(image.shape) 174 | # print(mask.shape) 175 | 176 | return image, mask, image_filename 177 | 178 | 179 | class Image2D(Dataset): 180 | """ 181 | Reads the images and applies the augmentation transform on them. As opposed to ImageToImage2D, this 182 | reads a single image and requires a simple augmentation transform. 183 | Usage: 184 | 1. If used without the unet.model.Model wrapper, an instance of this object should be passed to 185 | torch.utils.data.DataLoader. Iterating through this returns the tuple of image and image 186 | filename. 187 | 2. With unet.model.Model wrapper, an instance of this object should be passed as a prediction 188 | dataset. 189 | 190 | Args: 191 | 192 | dataset_path: path to the dataset. 
193 |             dataset_path
194 |               |-- img
195 |                   |-- img001.png
196 |                   |-- img002.png
197 |                   |-- ...
198 |
199 |         transform: augmentation transform. If bool(transform) evaluates to False,
200 |             torchvision.transforms.ToTensor will be used.
201 |     """
202 |
203 |     def __init__(self, dataset_path: str, transform: Callable = None):
204 |
205 |         self.dataset_path = dataset_path
206 |         self.input_path = os.path.join(dataset_path, 'img')
207 |         self.images_list = os.listdir(self.input_path)
208 |
209 |         if transform:
210 |             self.transform = transform
211 |         else:
212 |             self.transform = T.ToTensor()
213 |
214 |     def __len__(self):
215 |         return len(os.listdir(self.input_path))
216 |
217 |     def __getitem__(self, idx):
218 |
219 |         image_filename = self.images_list[idx]
220 |
221 |         image = cv2.imread(os.path.join(self.input_path, image_filename))
222 |
223 |         # image = np.transpose(image,(2,0,1))
224 |
225 |         image = correct_dims(image)
226 |
227 |         image = self.transform(image)
228 |
229 |         # image = np.swapaxes(image,2,0)
230 |
231 |         return image, image_filename
232 |
233 | def chk_mkdir(*paths: Container) -> None:
234 |     """
235 |     Creates folders if they do not exist.
236 |
237 |     Args:
238 |         paths: Container of paths to be created.
239 |     """
240 |     for path in paths:
241 |         if not os.path.exists(path):
242 |             os.makedirs(path)
243 |
244 |
245 | class Logger:
246 |     def __init__(self, verbose=False):
247 |         self.logs = defaultdict(list)
248 |         self.verbose = verbose
249 |
250 |     def log(self, logs):
251 |         for key, value in logs.items():
252 |             self.logs[key].append(value)
253 |
254 |         if self.verbose:
255 |             print(logs)
256 |
257 |     def get_logs(self):
258 |         return self.logs
259 |
260 |     def to_csv(self, path):
261 |         pd.DataFrame(self.logs).to_csv(path, index=None)
262 |
263 |
264 | class MetricList:
265 |     def __init__(self, metrics):
266 |         assert isinstance(metrics, dict), '\'metrics\' must be a dictionary of callables'
267 |         self.metrics = metrics
268 |         self.results = {key: 0.0 for key in self.metrics.keys()}
269 |
270 |     def __call__(self, y_out, y_batch):
271 |         for key, value in self.metrics.items():
272 |             self.results[key] += value(y_out, y_batch)
273 |
274 |     def reset(self):
275 |         self.results = {key: 0.0 for key in self.metrics.keys()}
276 |
277 |     def get_results(self, normalize=False):
278 |         assert isinstance(normalize, bool) or isinstance(normalize, Number), '\'normalize\' must be boolean or a number'
279 |         if not normalize:
280 |             return self.results
281 |         else:
282 |             return {key: value/normalize for key, value in self.results.items()}
283 |
--------------------------------------------------------------------------------
/utils_gray.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 | from skimage import io,color
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms as T
9 | from torchvision.transforms import functional as F
10 |
11 | from typing import Callable
12 | import os
13 | import cv2
14 | import pandas as pd
15 |
16 | from numbers import Number
17 | from typing import Container
18 | from collections import defaultdict
19 |
20 |
21 | def to_long_tensor(pic):
22 |     # handle numpy array
23 |     img = torch.from_numpy(np.array(pic, np.uint8))
24 |     # backward compatibility
25 |     return img.long()
26 |
27 |
28 | def correct_dims(*images):
29 |     corr_images = []
30 |     # print(images)
31 |     for img in images:
32 |         if len(img.shape) == 2:
33 |             corr_images.append(np.expand_dims(img, axis=2))
34 |         else:
35 |             corr_images.append(img)
36 |
37 |     if len(corr_images) == 1:
38 |         return corr_images[0]
39 |     else:
40 |         return corr_images
41 |
42 |
43 | class JointTransform2D:
44 |     """
45 |     Performs augmentation on image and mask when called. Due to the randomness of augmentation transforms,
46 |     it is not enough to simply apply the same transform from torchvision on the image and mask separately.
47 |     Doing this will result in messing up the ground truth mask. To circumvent this problem, this class can
48 |     be used, which will take care of the problems above.
49 |
50 |     Args:
51 |         crop: tuple describing the size of the random crop. If bool(crop) evaluates to False, no crop will
52 |             be taken.
53 |         p_flip: float, the probability of performing a random horizontal flip.
54 |         color_jitter_params: tuple describing the parameters of torchvision.transforms.ColorJitter.
55 |             If bool(color_jitter_params) evaluates to False, no color jitter transformation will be used.
56 |         p_random_affine: float, the probability of performing a random affine transform using
57 |             torchvision.transforms.RandomAffine.
58 |         long_mask: bool, if True, returns the mask as LongTensor in label-encoded format.
59 |     """
60 |     def __init__(self, crop=(32, 32), p_flip=0.5, color_jitter_params=(0.1, 0.1, 0.1, 0.1),
61 |                  p_random_affine=0, long_mask=False):
62 |         self.crop = crop
63 |         self.p_flip = p_flip
64 |         self.color_jitter_params = color_jitter_params
65 |         if color_jitter_params:
66 |             self.color_tf = T.ColorJitter(*color_jitter_params)
67 |         self.p_random_affine = p_random_affine
68 |         self.long_mask = long_mask
69 |
70 |     def __call__(self, image, mask):
71 |         # transforming to PIL image
72 |         image, mask = F.to_pil_image(image), F.to_pil_image(mask)
73 |
74 |         # random crop
75 |         if self.crop:
76 |             i, j, h, w = T.RandomCrop.get_params(image, self.crop)
77 |             image, mask = F.crop(image, i, j, h, w), F.crop(mask, i, j, h, w)
78 |
79 |         if np.random.rand() < self.p_flip:
80 |             image, mask = F.hflip(image), F.hflip(mask)
81 |
82 |         # color transforms || ONLY ON IMAGE
83 |         if self.color_jitter_params:
84 |             image = self.color_tf(image)
85 |
86 |         # random affine transform
87 |         if np.random.rand() < self.p_random_affine:
88 |             affine_params = T.RandomAffine(180).get_params((-90, 90), (1, 1), (2, 2), (-45, 45), self.crop)
89 |             image, mask = F.affine(image, *affine_params), F.affine(mask, *affine_params)
90 |
91 |         # transforming to tensor
92 |         image = F.to_tensor(image)
93 |         if not self.long_mask:
94 |             mask = F.to_tensor(mask)
95 |         else:
96 |             mask = to_long_tensor(mask)
97 |
98 |         return image, mask
99 |
100 |
101 | class ImageToImage2D(Dataset):
102 |     """
103 |     Reads the images and applies the augmentation transform on them.
104 |     Usage:
105 |         1. If used without the unet.model.Model wrapper, an instance of this object should be passed to
106 |            torch.utils.data.DataLoader. Iterating through this returns the tuple of image, mask and image
107 |            filename.
108 |         2. With the unet.model.Model wrapper, an instance of this object should be passed as train or validation
109 |            datasets.
110 |
111 |     Args:
112 |         dataset_path: path to the dataset. Structure of the dataset should be:
113 |             dataset_path
114 |               |-- img
115 |                   |-- img001.png
116 |                   |-- img002.png
117 |                   |-- ...
118 |               |-- labelcol
119 |                   |-- img001.png
120 |                   |-- img002.png
121 |                   |-- ...
122 |
123 |         joint_transform: augmentation transform, an instance of JointTransform2D. If bool(joint_transform)
124 |             evaluates to False, torchvision.transforms.ToTensor will be used on both image and mask.
125 |         one_hot_mask: bool, if True, returns the mask in one-hot encoded form.
126 |     """
127 |
128 |     def __init__(self, dataset_path: str, joint_transform: Callable = None, one_hot_mask: int = False) -> None:
129 |         self.dataset_path = dataset_path
130 |         self.input_path = os.path.join(dataset_path, 'img')
131 |         self.output_path = os.path.join(dataset_path, 'labelcol')
132 |         self.images_list = os.listdir(self.input_path)
133 |         self.one_hot_mask = one_hot_mask
134 |
135 |         if joint_transform:
136 |             self.joint_transform = joint_transform
137 |         else:
138 |             to_tensor = T.ToTensor()
139 |             self.joint_transform = lambda x, y: (to_tensor(x), to_tensor(y))
140 |
141 |     def __len__(self):
142 |         return len(os.listdir(self.input_path))
143 |
144 |     def __getitem__(self, idx):
145 |         image_filename = self.images_list[idx]
146 |         #print(image_filename[: -3])
147 |         # read image
148 |         # print(os.path.join(self.input_path, image_filename))
149 |         # print(os.path.join(self.output_path, image_filename[: -3] + "png"))
150 |         # print(os.path.join(self.input_path, image_filename))
151 |         image = cv2.imread(os.path.join(self.input_path, image_filename),0)
152 |         # print(image.shape)
153 |         # read mask image
154 |         mask = cv2.imread(os.path.join(self.output_path, image_filename[: -3] + "png"),0)
155 |
156 |         # correct dimensions if needed
157 |         image, mask = correct_dims(image, mask)
158 |         # print(image.shape)
159 |         mask[mask<127] = 0
160 |         mask[mask>=127] = 1
161 |
162 |
163 |         if self.joint_transform:
164 |             image, mask = self.joint_transform(image, mask)
165 |
166 |         if self.one_hot_mask:
167 |             assert self.one_hot_mask > 0, 'one_hot_mask must be positive'
168 |             mask = torch.zeros((self.one_hot_mask, mask.shape[1], mask.shape[2])).scatter_(0, mask.long(), 1)
169 |         # mask = np.swapaxes(mask,2,0)
170 |         # print(image.shape)
171 |         # print(mask.shape)
172 |         # mask = np.transpose(mask,(2,0,1))
173 |         # image = np.transpose(image,(2,0,1))
174 |         # print(image.shape)
175 |         # print(mask.shape)
176 |
177 |         return image, mask, image_filename
178 |
179 |
180 | class Image2D(Dataset):
181 |     """
182 |     Reads the images and applies the augmentation transform on them. As opposed to ImageToImage2D, this
183 |     reads a single image and requires a simple augmentation transform.
184 |     Usage:
185 |         1. If used without the unet.model.Model wrapper, an instance of this object should be passed to
186 |            torch.utils.data.DataLoader. Iterating through this returns the tuple of image and image
187 |            filename.
188 |         2. With the unet.model.Model wrapper, an instance of this object should be passed as a prediction
189 |            dataset.
190 |
191 |     Args:
192 |
193 |         dataset_path: path to the dataset. Structure of the dataset should be:
194 |             dataset_path
195 |               |-- img
196 |                   |-- img001.png
197 |                   |-- img002.png
198 |                   |-- ...
199 |
200 |         transform: augmentation transform. If bool(transform) evaluates to False,
201 |             torchvision.transforms.ToTensor will be used.
202 |     """
203 |
204 |     def __init__(self, dataset_path: str, transform: Callable = None):
205 |
206 |         self.dataset_path = dataset_path
207 |         self.input_path = os.path.join(dataset_path, 'img')
208 |         self.images_list = os.listdir(self.input_path)
209 |
210 |         if transform:
211 |             self.transform = transform
212 |         else:
213 |             self.transform = T.ToTensor()
214 |
215 |     def __len__(self):
216 |         return len(os.listdir(self.input_path))
217 |
218 |     def __getitem__(self, idx):
219 |
220 |         image_filename = self.images_list[idx]
221 |
222 |         image = cv2.imread(os.path.join(self.input_path, image_filename),0)
223 |
224 |         # image = np.transpose(image,(2,0,1))
225 |
226 |         image = correct_dims(image)
227 |
228 |         image = self.transform(image)
229 |
230 |         # image = np.swapaxes(image,2,0)
231 |
232 |         return image, image_filename
233 |
234 | def chk_mkdir(*paths: Container) -> None:
235 |     """
236 |     Creates folders if they do not exist.
237 |
238 |     Args:
239 |         paths: Container of paths to be created.
240 |     """
241 |     for path in paths:
242 |         if not os.path.exists(path):
243 |             os.makedirs(path)
244 |
245 |
246 | class Logger:
247 |     def __init__(self, verbose=False):
248 |         self.logs = defaultdict(list)
249 |         self.verbose = verbose
250 |
251 |     def log(self, logs):
252 |         for key, value in logs.items():
253 |             self.logs[key].append(value)
254 |
255 |         if self.verbose:
256 |             print(logs)
257 |
258 |     def get_logs(self):
259 |         return self.logs
260 |
261 |     def to_csv(self, path):
262 |         pd.DataFrame(self.logs).to_csv(path, index=None)
263 |
264 |
265 | class MetricList:
266 |     def __init__(self, metrics):
267 |         assert isinstance(metrics, dict), '\'metrics\' must be a dictionary of callables'
268 |         self.metrics = metrics
269 |         self.results = {key: 0.0 for key in self.metrics.keys()}
270 |
271 |     def __call__(self, y_out, y_batch):
272 |         for key, value in self.metrics.items():
273 |             self.results[key] += value(y_out, y_batch)
274 |
275 |     def reset(self):
276 |         self.results = {key: 0.0 for key in self.metrics.keys()}
277 |
278 |     def get_results(self, normalize=False):
279 |         assert isinstance(normalize, bool) or isinstance(normalize, Number), '\'normalize\' must be boolean or a number'
280 |         if not normalize:
281 |             return self.results
282 |         else:
283 |             return {key: value/normalize for key, value in self.results.items()}
284 |
--------------------------------------------------------------------------------
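
For reference, below is a minimal usage sketch (not part of the repository) showing how the ImageToImage2D dataset and JointTransform2D augmentation defined in utils.py can be wired into a torch.utils.data.DataLoader. The ./Train_Folder path, crop size, and batch size are illustrative placeholders; the folder is assumed to contain the img/ and labelcol/ subfolders that ImageToImage2D expects.

# Usage sketch (illustrative only). Assumes ./Train_Folder/img and
# ./Train_Folder/labelcol exist; adjust the crop size to the image resolution.
from torch.utils.data import DataLoader

from utils import ImageToImage2D, JointTransform2D

# The joint transform keeps image and mask spatially aligned; long_mask=True
# returns the mask as a label-encoded LongTensor for cross-entropy style losses.
tf_train = JointTransform2D(crop=(128, 128), p_flip=0.5,
                            color_jitter_params=None, long_mask=True)

train_dataset = ImageToImage2D('./Train_Folder', joint_transform=tf_train)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

for image, mask, filename in train_loader:
    # image: (B, C, 128, 128) float tensor, mask: (B, 128, 128) long tensor
    print(filename, image.shape, mask.shape)
    break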