├── LICENSE
├── README.md
├── cmd.txt
├── environment.yml
├── extractors.py
├── img
│   ├── arch.png
│   ├── medt.png
│   ├── medt1.png
│   └── poster.pdf
├── lib
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── build_dataloader.cpython-36.pyc
│   │   ├── build_dataloader.cpython-37.pyc
│   │   ├── build_model.cpython-36.pyc
│   │   ├── build_model.cpython-37.pyc
│   │   ├── build_optimizer.cpython-36.pyc
│   │   ├── build_optimizer.cpython-37.pyc
│   │   ├── metrics.cpython-36.pyc
│   │   └── metrics.cpython-37.pyc
│   ├── build_dataloader.py
│   ├── build_model.py
│   ├── build_optimizer.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── imagenet1k.cpython-36.pyc
│   │   │   └── imagenet1k.cpython-37.pyc
│   │   └── imagenet1k.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── axialnet.cpython-36.pyc
│   │   │   ├── axialnet.cpython-37.pyc
│   │   │   ├── resnet.cpython-36.pyc
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   ├── utils.cpython-36.pyc
│   │   │   └── utils.cpython-37.pyc
│   │   ├── axialnet.py
│   │   ├── model_codes.py
│   │   ├── resnet.py
│   │   └── utils.py
│   └── utils.py
├── metrics.py
├── performancemetrics_ax.m
├── performancemetrics_glas.m
├── performancemetrics_monuseg.m
├── requirements.txt
├── test.py
├── train.py
├── utils.py
└── utils_gray.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jeya Maria Jose
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Medical-Transformer
2 |
3 |
4 |
5 | PyTorch code for the paper
6 | ["Medical Transformer: Gated Axial-Attention for
7 | Medical Image Segmentation"](https://arxiv.org/pdf/2102.10662.pdf), MICCAI 2021
8 |
9 | [Paper](https://arxiv.org/pdf/2102.10662.pdf) | [Poster](https://drive.google.com/file/d/1gMjc5guT_dYQFT6TEEwdHAFKwG5XkEc9/view?usp=sharing)
10 |
11 | ## News:
12 |
13 | :rocket: : Check out our latest work [UNeXt](https://arxiv.org/abs/2203.04967), a faster and more efficient segmentation architecture that is also easy to train and implement! Code is available [here](https://github.com/jeya-maria-jose/UNeXt-pytorch).
14 |
15 | ### About this repo:
16 |
17 | This repo hosts the code for the following networks:
18 |
19 | 1) Gated Axial Attention U-Net
20 | 2) MedT
21 |
22 | ## Introduction
23 |
24 | The majority of existing Transformer-based network architectures proposed for vision applications require large-scale
25 | datasets to train properly. However, compared to the datasets for vision
26 | applications, the number of data samples available for medical imaging is relatively
27 | low, making it difficult to efficiently train transformers for medical
28 | applications. To this end, we propose a Gated Axial-Attention model which
29 | extends the existing architectures by introducing an additional control
30 | mechanism in the self-attention module. Furthermore, to train the model
31 | effectively on medical images, we propose a Local-Global training
32 | strategy (LoGo) which further improves the performance. Specifically, we
33 | operate on the whole image and on patches to learn global and local features,
34 | respectively. The proposed Medical Transformer (MedT) uses the LoGo training strategy on the Gated Axial Attention U-Net.
35 |
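To make the gating concrete, below is a toy, self-contained PyTorch sketch of the idea (it is not the repository's layer; the actual implementation is `AxialAttention_dynamic` in `lib/models/axialnet.py`): the relative-position terms of axial attention are scaled by learnable gates before entering the softmax, so the model can down-weight positional cues that are hard to learn from small datasets.

```python
import torch
import torch.nn.functional as F

# Toy illustration of the gate in Gated Axial-Attention (see AxialAttention_dynamic
# in lib/models/axialnet.py for the real layer): the relative-position terms qr and kr
# are scaled by learnable gates before they enter the attention softmax.
B, G, C, L = 2, 8, 4, 16                                     # (batch*width, heads, per-head dim, axis length)
q, k = torch.randn(B, G, C, L), torch.randn(B, G, C, L)
v = torch.randn(B, G, 2 * C, L)
qr, kr = torch.randn(B, G, L, L), torch.randn(B, G, L, L)    # relative-position terms (precomputed here)
g_qr, g_kr = torch.nn.Parameter(torch.tensor(0.1)), torch.nn.Parameter(torch.tensor(0.1))

qk = torch.einsum('bgci,bgcj->bgij', q, k)                   # content-content similarity
similarity = F.softmax(qk + g_qr * qr + g_kr * kr, dim=3)    # gated similarity
out = torch.einsum('bgij,bgcj->bgci', similarity, v)         # (B, G, 2C, L)
```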
36 |
37 |
38 |
39 |
40 | ### Using the code:
41 |
42 | - Clone this repository:
43 | ```bash
44 | git clone https://github.com/jeya-maria-jose/Medical-Transformer
45 | cd Medical-Transformer
46 | ```
47 |
48 | The code is stable with Python 3.6.10 and PyTorch 1.4.0.
49 |
50 | To install all the dependencies using conda:
51 |
52 | ```bash
53 | conda env create -f environment.yml
54 | conda activate medt
55 | ```
56 |
57 | To install all the dependencies using pip:
58 |
59 | ```bash
60 | pip install -r requirements.txt
61 | ```
62 |
63 | ### Links for downloading the public Datasets:
64 |
65 | 1) MoNuSeg Dataset - Link (Original)
66 | 2) GlaS Dataset - Link (Original)
67 | 3) The Brain Anatomy US dataset from the paper will be made public soon!
68 |
69 | ## Using the Code for your dataset
70 |
71 | ### Dataset Preparation
72 |
73 | Prepare the dataset in the following format for easy use of the code. The train, validation, and test folders should each contain two subfolders: img and labelcol. Make sure the images and their corresponding segmentation masks are placed under these folders and share the same file name so they can be matched. If you prefer not to prepare the dataset in this format, change the data loaders to suit your needs.
74 |
75 |
76 |
77 | ```bash
78 | Train Folder-----
79 | img----
80 | 0001.png
81 | 0002.png
82 | .......
83 | labelcol---
84 | 0001.png
85 | 0002.png
86 | .......
87 | Validation Folder-----
88 | img----
89 | 0001.png
90 | 0002.png
91 | .......
92 | labelcol---
93 | 0001.png
94 | 0002.png
95 | .......
96 | Test Folder-----
97 | img----
98 | 0001.png
99 | 0002.png
100 | .......
101 | labelcol---
102 | 0001.png
103 | 0002.png
104 | .......
105 |
106 | ```
107 |
108 | - The ground truth masks should use pixel values that correspond to the class labels. Example: for binary segmentation, the pixels in the GT should be 0 or 255.
109 |
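Before training, a quick sanity check such as the sketch below (not part of this repository; it assumes the img/labelcol layout above) can confirm that every image has a same-named mask and that binary masks contain only 0 and 255:

```python
# Minimal layout check (not part of this repo): every image in img/ must have a
# same-named mask in labelcol/, and binary masks should only contain 0 and 255.
import os
import numpy as np
from PIL import Image

def check_split(split_dir):
    img_dir = os.path.join(split_dir, "img")
    label_dir = os.path.join(split_dir, "labelcol")
    for name in sorted(os.listdir(img_dir)):
        mask_path = os.path.join(label_dir, name)
        assert os.path.exists(mask_path), "missing mask for " + name
        mask = np.array(Image.open(mask_path).convert("L"))
        assert set(np.unique(mask)).issubset({0, 255}), "non-binary mask: " + name

for split in ["Train Folder", "Validation Folder", "Test Folder"]:
    check_split(split)
```
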
110 | ### Training Command:
111 |
112 | ```bash
113 | python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no"
114 | ```
115 |
116 |
117 | Change `--modelname` to "MedT" or "logo" to train those variants instead.
118 |
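For example, the same command with only the model name changed trains MedT (all other flags as in the gated axial U-Net command above):

```bash
python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "MedT" --learning_rate 0.001 --imgsize 128 --gray "no"
```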
119 |
120 | ### Testing Command:
121 |
122 | ```bash
123 | python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no"
124 | ```
125 |
126 | The results, including the predicted segmentation maps, will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB to calculate the F1 score and mIoU.
127 |
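If MATLAB is not available, per-image F1 (Dice) and IoU for binary masks can also be approximated with a few lines of NumPy. The sketch below is for convenience only; the reported results use the performancemetrics_*.m scripts in this repository.

```python
# Rough NumPy sketch of per-image F1 (Dice) and IoU for binary masks; the reported
# results use the performancemetrics_*.m MATLAB scripts in this repository.
import numpy as np

def f1_and_iou(pred, gt, eps=1e-7):
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    f1 = (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)
    iou = (inter + eps) / (np.logical_or(pred, gt).sum() + eps)
    return float(f1), float(iou)
```
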
128 | ### Notes:
129 |
130 | 1) Note that these experiments were conducted on an Nvidia Quadro 8000 GPU with 48 GB of memory.
131 | 2) The Google Colab code is an unofficial implementation for quick training/testing. Please follow the original code for proper training.
132 |
133 | ### Acknowledgement:
134 |
135 | The dataloader code is inspired by pytorch-UNet. The axial attention code is developed from axial-deeplab.
136 |
137 | # Citation:
138 |
139 | ```bibtex
140 | @InProceedings{jose2021medical,
141 | author="Valanarasu, Jeya Maria Jose
142 | and Oza, Poojan
143 | and Hacihaliloglu, Ilker
144 | and Patel, Vishal M.",
145 | title="Medical Transformer: Gated Axial-Attention for Medical Image Segmentation",
146 | booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2021",
147 | year="2021",
148 | publisher="Springer International Publishing",
149 | address="Cham",
150 | pages="36--46",
151 | isbn="978-3-030-87193-2"
152 | }
153 |
154 | ```
155 |
156 | Open an issue or mail me directly in case of any queries or suggestions.
157 |
--------------------------------------------------------------------------------
/cmd.txt:
--------------------------------------------------------------------------------
1 | python train.py --train_dataset "/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/train/" --val_dataset "/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/test/" --direc "./results/axial128_en/" --batch_size 4 --modelname "logo" --epoch 401 --save_freq 50 --learning_rate 0.0001 --imgsize 128
2 |
3 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: medt
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - argon2-cffi=20.1.0=py36h8c4c3a4_1
8 | - attrs=20.1.0=pyh9f0ad1d_0
9 | - backcall=0.2.0=pyh9f0ad1d_0
10 | - backports=1.0=py_2
11 | - backports.functools_lru_cache=1.6.1=py_0
12 | - blas=1.0=mkl
13 | - bleach=3.1.5=pyh9f0ad1d_0
14 | - brotlipy=0.7.0=py36h8c4c3a4_1000
15 | - ca-certificates=2020.6.20=hecda079_0
16 | - certifi=2020.6.20=py36h9f0ad1d_0
17 | - cffi=1.11.5=py36_0
18 | - chardet=3.0.4=py36h9f0ad1d_1006
19 | - cryptography=3.1=py36h45558ae_0
20 | - decorator=4.4.2=py_0
21 | - defusedxml=0.6.0=py_0
22 | - entrypoints=0.3=py36h9f0ad1d_1001
23 | - idna=2.10=pyh9f0ad1d_0
24 | - importlib-metadata=1.7.0=py36h9f0ad1d_0
25 | - importlib_metadata=1.7.0=0
26 | - intel-openmp=2020.1=217
27 | - ipykernel=5.3.4=py36h95af2a2_0
28 | - ipython=7.16.1=py36h95af2a2_0
29 | - ipython_genutils=0.2.0=py_1
30 | - ipywidgets=7.5.1=py_0
31 | - jedi=0.17.2=py36h9f0ad1d_0
32 | - jinja2=2.11.2=pyh9f0ad1d_0
33 | - json5=0.9.4=pyh9f0ad1d_0
34 | - jsonschema=3.2.0=py36h9f0ad1d_1
35 | - jupyter_client=6.1.7=py_0
36 | - jupyter_core=4.6.3=py36h9f0ad1d_1
37 | - jupyterlab=2.2.6=py_0
38 | - jupyterlab_server=1.2.0=py_0
39 | - ld_impl_linux-64=2.33.1=h53a641e_7
40 | - libedit=3.1.20191231=h7b6447c_0
41 | - libffi=3.3=he6710b0_1
42 | - libgcc-ng=9.1.0=hdf63c60_0
43 | - libgfortran-ng=7.3.0=hdf63c60_0
44 | - libsodium=1.0.18=h516909a_0
45 | - libstdcxx-ng=9.1.0=hdf63c60_0
46 | - markupsafe=1.1.1=py36h8c4c3a4_1
47 | - mistune=0.8.4=py36h8c4c3a4_1001
48 | - mkl=2020.1=217
49 | - mkl-service=2.3.0=py36he904b0f_0
50 | - mkl_fft=1.1.0=py36h23d657b_0
51 | - mkl_random=1.1.1=py36h0573a6f_0
52 | - nbconvert=5.6.1=py36h9f0ad1d_1
53 | - nbformat=5.0.7=py_0
54 | - ncurses=6.2=he6710b0_1
55 | - notebook=6.1.3=py36h9f0ad1d_0
56 | - numpy=1.18.5=py36ha1c710e_0
57 | - numpy-base=1.18.5=py36hde5b4d6_0
58 | - openssl=1.1.1g=h516909a_1
59 | - packaging=20.4=pyh9f0ad1d_0
60 | - pandoc=2.10.1=h516909a_0
61 | - pandocfilters=1.4.2=py_1
62 | - parso=0.7.1=pyh9f0ad1d_0
63 | - pexpect=4.8.0=py36h9f0ad1d_1
64 | - pickleshare=0.7.5=py36h9f0ad1d_1001
65 | - pip=20.1.1=py36_1
66 | - prometheus_client=0.8.0=pyh9f0ad1d_0
67 | - prompt-toolkit=3.0.7=py_0
68 | - ptyprocess=0.6.0=py_1001
69 | - pycparser=2.20=pyh9f0ad1d_2
70 | - pygments=2.6.1=py_0
71 | - pyopenssl=19.1.0=py_1
72 | - pyparsing=2.4.7=pyh9f0ad1d_0
73 | - pyrsistent=0.16.0=py36h8c4c3a4_0
74 | - pysocks=1.7.1=py36h9f0ad1d_1
75 | - python=3.6.10=h7579374_2
76 | - python-dateutil=2.8.1=py_0
77 | - python_abi=3.6=1_cp36m
78 | - pyzmq=19.0.2=py36h9947dbf_0
79 | - readline=8.0=h7b6447c_0
80 | - requests=2.24.0=pyh9f0ad1d_0
81 | - send2trash=1.5.0=py_0
82 | - setuptools=47.3.1=py36_0
83 | - six=1.15.0=py_0
84 | - sqlite=3.32.3=h62c20be_0
85 | - terminado=0.8.3=py36h9f0ad1d_1
86 | - testpath=0.4.4=py_0
87 | - tk=8.6.10=hbc83047_0
88 | - tornado=6.0.4=py36h8c4c3a4_1
89 | - traitlets=4.3.3=py36h9f0ad1d_1
90 | - urllib3=1.25.10=py_0
91 | - wcwidth=0.2.5=pyh9f0ad1d_1
92 | - webencodings=0.5.1=py_1
93 | - wheel=0.34.2=py36_0
94 | - widgetsnbextension=3.5.1=py36h9f0ad1d_1
95 | - xz=5.2.5=h7b6447c_0
96 | - yaml=0.2.5=h7b6447c_0
97 | - zeromq=4.3.2=he1b5a44_3
98 | - zipp=3.1.0=py_0
99 | - zlib=1.2.11=h7b6447c_3
100 | - pip:
101 | - ci-info==0.2.0
102 | - click==7.1.2
103 | - cython==0.29.20
104 | - et-xmlfile==1.0.1
105 | - etelemetry==0.2.1
106 | - filelock==3.0.12
107 | - isodate==0.6.0
108 | - jdcal==1.4.1
109 | - joblib==0.17.0
110 | - lxml==4.5.1
111 | - matplotlib==3.3.2
112 | - medpy==0.4.0
113 | - natsort==7.0.1
114 | - nibabel==3.1.0
115 | - nipype==1.5.0
116 | - openpyxl==3.0.4
117 | - prov==1.5.3
118 | - pydicom==2.0.0
119 | - pydot==1.4.1
120 | - pydotplus==2.0.2
121 | - pynrrd==0.4.2
122 | - rdflib==5.0.0
123 | - scikit-learn==0.23.2
124 | - scipy==1.5.3
125 | - setproctitle==1.1.10
126 | - simplejson==3.17.0
127 | - threadpoolctl==2.1.0
128 | - torch==1.4.0
129 | - torch-dwconv==0.1.0
130 | - torchvision==0.4.0
131 | - traits==6.1.0
132 | prefix: /home/jeyamariajose/anaconda3/envs/medt
133 |
134 |
--------------------------------------------------------------------------------
/extractors.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils import model_zoo
8 | from torchvision.models.densenet import densenet121, densenet161
9 | from torchvision.models.squeezenet import squeezenet1_1
10 |
11 |
12 | def load_weights_sequential(target, source_state):
13 | new_dict = OrderedDict()
14 | for (k1, v1), (k2, v2) in zip(target.state_dict().items(), source_state.items()):
15 | new_dict[k1] = v2
16 | target.load_state_dict(new_dict)
17 |
18 | '''
19 | Implementation of dilated ResNet-101 with deep supervision. Downsampling is changed to 8x
20 | '''
21 | model_urls = {
22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
32 | padding=dilation, dilation=dilation, bias=False)
33 |
34 |
35 | class BasicBlock(nn.Module):
36 | expansion = 1
37 |
38 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
39 | super(BasicBlock, self).__init__()
40 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
41 | self.bn1 = nn.BatchNorm2d(planes)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
44 | self.bn2 = nn.BatchNorm2d(planes)
45 | self.downsample = downsample
46 | self.stride = stride
47 |
48 | def forward(self, x):
49 | residual = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | out = self.bn2(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | out += residual
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class Bottleneck(nn.Module):
68 | expansion = 4
69 |
70 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
71 | super(Bottleneck, self).__init__()
72 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
73 | self.bn1 = nn.BatchNorm2d(planes)
74 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
75 | padding=dilation, bias=False)
76 | self.bn2 = nn.BatchNorm2d(planes)
77 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
78 | self.bn3 = nn.BatchNorm2d(planes * 4)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | residual = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv3(out)
95 | out = self.bn3(out)
96 |
97 | if self.downsample is not None:
98 | residual = self.downsample(x)
99 |
100 | out += residual
101 | out = self.relu(out)
102 |
103 | return out
104 |
105 |
106 | class ResNet(nn.Module):
107 | def __init__(self, block, layers=(3, 4, 23, 3)):
108 | self.inplanes = 64
109 | super(ResNet, self).__init__()
110 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
111 | bias=False)
112 | self.bn1 = nn.BatchNorm2d(64)
113 | self.relu = nn.ReLU(inplace=True)
114 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
115 | self.layer1 = self._make_layer(block, 64, layers[0])
116 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
117 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
118 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | m.weight.data.normal_(0, math.sqrt(2. / n))
124 | elif isinstance(m, nn.BatchNorm2d):
125 | m.weight.data.fill_(1)
126 | m.bias.data.zero_()
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | nn.Conv2d(self.inplanes, planes * block.expansion,
133 | kernel_size=1, stride=stride, bias=False),
134 | nn.BatchNorm2d(planes * block.expansion),
135 | )
136 |
137 | layers = [block(self.inplanes, planes, stride, downsample)]
138 | self.inplanes = planes * block.expansion
139 | for i in range(1, blocks):
140 | layers.append(block(self.inplanes, planes, dilation=dilation))
141 |
142 | return nn.Sequential(*layers)
143 |
144 | def forward(self, x):
145 | x = self.conv1(x)
146 | x = self.bn1(x)
147 | x = self.relu(x)
148 | x = self.maxpool(x)
149 |
150 | x = self.layer1(x)
151 | x = self.layer2(x)
152 | x_3 = self.layer3(x)
153 | x = self.layer4(x_3)
154 |
155 | return x, x_3
156 |
157 |
158 | '''
159 | Implementation of DenseNet with deep supervision. Downsampling is changed to 8x
160 | '''
161 |
162 |
163 | class _DenseLayer(nn.Sequential):
164 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, index):
165 | super(_DenseLayer, self).__init__()
166 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
167 | self.add_module('relu1', nn.ReLU(inplace=True)),
168 | if index == 3:
169 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
170 | growth_rate, kernel_size=1, stride=1, bias=False)),
171 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
172 | self.add_module('relu2', nn.ReLU(inplace=True)),
173 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
174 | kernel_size=3, stride=1, dilation=2, padding=2, bias=False)),
175 | else:
176 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
177 | growth_rate, kernel_size=1, stride=1, bias=False)),
178 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
179 | self.add_module('relu2', nn.ReLU(inplace=True)),
180 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
181 | kernel_size=3, stride=1, padding=1, bias=False)),
182 | self.drop_rate = drop_rate
183 |
184 | def forward(self, x):
185 | new_features = super(_DenseLayer, self).forward(x)
186 | if self.drop_rate > 0:
187 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
188 | return torch.cat([x, new_features], 1)
189 |
190 |
191 | class _DenseBlock(nn.Sequential):
192 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, index):
193 | super(_DenseBlock, self).__init__()
194 | for i in range(num_layers):
195 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, index)
196 | self.add_module('denselayer%d' % (i + 1), layer)
197 |
198 |
199 | class _Transition(nn.Sequential):
200 | def __init__(self, num_input_features, num_output_features, downsample=True):
201 | super(_Transition, self).__init__()
202 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
203 | self.add_module('relu', nn.ReLU(inplace=True))
204 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
205 | kernel_size=1, stride=1, bias=False))
206 | if downsample:
207 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
208 | else:
209 | self.add_module('pool', nn.AvgPool2d(kernel_size=1, stride=1)) # compatibility hack
210 |
211 |
212 | class DenseNet(nn.Module):
213 | def __init__(self, growth_rate=8, block_config=(6, 12, 24, 16),
214 | num_init_features=16, bn_size=4, drop_rate=0, pretrained=False):
215 |
216 | super(DenseNet, self).__init__()
217 |
218 | # First convolution
219 | self.start_features = nn.Sequential(OrderedDict([
220 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
221 | ('norm0', nn.BatchNorm2d(num_init_features)),
222 | ('relu0', nn.ReLU(inplace=True)),
223 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
224 | ]))
225 |
226 | # Each denseblock
227 | num_features = num_init_features
228 |
229 | init_weights = list(densenet121(pretrained=True).features.children())
230 | start = 0
231 | for i, c in enumerate(self.start_features.children()):
232 | #if pretrained:
233 | #c.load_state_dict(init_weights[i].state_dict())
234 | start += 1
235 | self.blocks = nn.ModuleList()
236 | for i, num_layers in enumerate(block_config):
237 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
238 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, index = i)
239 | if pretrained:
240 | block.load_state_dict(init_weights[start].state_dict())
241 | start += 1
242 | self.blocks.append(block)
243 | setattr(self, 'denseblock%d' % (i + 1), block)
244 |
245 | num_features = num_features + num_layers * growth_rate
246 | if i != len(block_config) - 1:
247 | downsample = i < 1
248 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
249 | downsample=downsample)
250 | if pretrained:
251 | trans.load_state_dict(init_weights[start].state_dict())
252 | start += 1
253 | self.blocks.append(trans)
254 | setattr(self, 'transition%d' % (i + 1), trans)
255 | num_features = num_features // 2
256 |
257 | def forward(self, x):
258 | out = self.start_features(x)
259 | deep_features = None
260 | for i, block in enumerate(self.blocks):
261 | out = block(out)
262 | if i == 5:
263 | deep_features = out
264 |
265 | return out, deep_features
266 |
267 |
268 | class Fire(nn.Module):
269 |
270 | def __init__(self, inplanes, squeeze_planes,
271 | expand1x1_planes, expand3x3_planes, dilation=1):
272 | super(Fire, self).__init__()
273 | self.inplanes = inplanes
274 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
275 | self.squeeze_activation = nn.ReLU(inplace=True)
276 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
277 | kernel_size=1)
278 | self.expand1x1_activation = nn.ReLU(inplace=True)
279 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
280 | kernel_size=3, padding=dilation, dilation=dilation)
281 | self.expand3x3_activation = nn.ReLU(inplace=True)
282 |
283 | def forward(self, x):
284 | x = self.squeeze_activation(self.squeeze(x))
285 | return torch.cat([
286 | self.expand1x1_activation(self.expand1x1(x)),
287 | self.expand3x3_activation(self.expand3x3(x))
288 | ], 1)
289 |
290 |
291 | class SqueezeNet(nn.Module):
292 |
293 | def __init__(self, pretrained=False):
294 | super(SqueezeNet, self).__init__()
295 |
296 | self.feat_1 = nn.Sequential(
297 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
298 | nn.ReLU(inplace=True)
299 | )
300 | self.feat_2 = nn.Sequential(
301 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
302 | Fire(64, 16, 64, 64),
303 | Fire(128, 16, 64, 64)
304 | )
305 | self.feat_3 = nn.Sequential(
306 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
307 | Fire(128, 32, 128, 128, 2),
308 | Fire(256, 32, 128, 128, 2)
309 | )
310 | self.feat_4 = nn.Sequential(
311 | Fire(256, 48, 192, 192, 4),
312 | Fire(384, 48, 192, 192, 4),
313 | Fire(384, 64, 256, 256, 4),
314 | Fire(512, 64, 256, 256, 4)
315 | )
316 | if pretrained:
317 | weights = squeezenet1_1(pretrained=True).features.state_dict()
318 | load_weights_sequential(self, weights)
319 |
320 | def forward(self, x):
321 | f1 = self.feat_1(x)
322 | f2 = self.feat_2(f1)
323 | f3 = self.feat_3(f2)
324 | f4 = self.feat_4(f3)
325 | return f4, f3
326 |
327 |
328 | '''
329 | Handy methods for construction
330 | '''
331 |
332 |
333 | def squeezenet(pretrained=True):
334 | return SqueezeNet(pretrained)
335 |
336 |
337 | def densenet(pretrained=True):
338 | return DenseNet(pretrained=pretrained)
339 |
340 |
341 | def resnet18(pretrained=True):
342 | model = ResNet(BasicBlock, [2, 2, 2, 2])
343 | if pretrained:
344 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))
345 | return model
346 |
347 |
348 | def resnet34(pretrained=True):
349 | model = ResNet(BasicBlock, [3, 4, 6, 3])
350 | if pretrained:
351 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34']))
352 | return model
353 |
354 |
355 | def resnet50(pretrained=True):
356 | model = ResNet(Bottleneck, [3, 4, 6, 3])
357 | if pretrained:
358 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))
359 | return model
360 |
361 |
362 | def resnet101(pretrained=True):
363 | model = ResNet(Bottleneck, [3, 4, 23, 3])
364 | if pretrained:
365 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101']))
366 | return model
367 |
368 |
369 | def resnet152(pretrained=True):
370 | model = ResNet(Bottleneck, [3, 8, 36, 3])
371 | if pretrained:
372 | load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152']))
373 | return model
374 |
--------------------------------------------------------------------------------
/img/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/arch.png
--------------------------------------------------------------------------------
/img/medt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/medt.png
--------------------------------------------------------------------------------
/img/medt1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/medt1.png
--------------------------------------------------------------------------------
/img/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/img/poster.pdf
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_dataloader import build_dataloader
2 | from .build_model import build_model
3 | from .build_optimizer import build_optimizer
4 | from .metrics import Metric
5 |
6 |
7 | __all__ = ['build_dataloader', 'build_model', 'build_optimizer', 'Metric']
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/build_dataloader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_dataloader.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/build_dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/build_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_model.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/build_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_model.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/build_optimizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_optimizer.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/build_optimizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/build_optimizer.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/__pycache__/metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/build_dataloader.py:
--------------------------------------------------------------------------------
1 | from . import datasets
2 |
3 |
4 | def build_dataloader(args, distributed=False):
5 | return datasets.__dict__[args.dataset](args, distributed)
6 |
--------------------------------------------------------------------------------
/lib/build_model.py:
--------------------------------------------------------------------------------
1 | from . import models
2 |
3 |
4 | def build_model(args):
5 | model = models.__dict__[args.model](num_classes=args.num_classes)
6 | return model
7 |
--------------------------------------------------------------------------------
/lib/build_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | def build_optimizer(args, model):
5 | if args.optim == 'sgd':
6 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
7 | momentum=args.momentum, weight_decay=args.weight_decay,
8 | nesterov=args.nesterov)
9 | else:
10 | raise AssertionError
11 | return optimizer
12 |
13 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .imagenet1k import imagenet1k
2 |
3 |
4 | __all__ = ['imagenet1k']
5 |
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/imagenet1k.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/imagenet1k.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/imagenet1k.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/datasets/__pycache__/imagenet1k.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/imagenet1k.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torchvision import datasets, transforms
4 |
5 |
6 | def imagenet1k(args, distributed=False):
7 | train_dirs = args.train_dirs
8 | val_dirs = args.val_dirs
9 | batch_size = args.batch_size
10 | val_batch_size = args.val_batch_size
11 | num_workers = args.num_workers
12 | color_jitter = args.color_jitter
13 |
14 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15 | process = [
16 | transforms.RandomResizedCrop(224),
17 | transforms.RandomHorizontalFlip(),
18 | ]
19 | if color_jitter:
20 | process += [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)]
21 | process += [
22 | transforms.ToTensor(),
23 | normalize
24 | ]
25 |
26 | transform_train = transforms.Compose(process)
27 |
28 | train_set = datasets.ImageFolder(train_dirs,
29 | transform=transform_train)
30 |
31 | if distributed:
32 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
33 | else:
34 | train_sampler = None
35 |
36 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=(train_sampler is None),
37 | sampler=train_sampler, num_workers=num_workers, pin_memory=True)
38 |
39 | transform_val = transforms.Compose(
40 | [transforms.Resize(256),
41 | transforms.CenterCrop(224),
42 | transforms.ToTensor(),
43 | normalize])
44 |
45 | val_set = datasets.ImageFolder(root=val_dirs,
46 | transform=transform_val)
47 |
48 | if distributed:
49 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_set)
50 | else:
51 | val_sampler = None
52 |
53 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=val_batch_size, shuffle=False,
54 | sampler=val_sampler, num_workers=num_workers, pin_memory=True)
55 |
56 | return train_loader, train_sampler, val_loader, val_sampler
57 |
--------------------------------------------------------------------------------
/lib/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Metric(object):
5 | def __init__(self, name):
6 | self.name = name
7 | self.sum = torch.tensor(0.)
8 | self.n = torch.tensor(0.)
9 |
10 | def update(self, val):
11 | self.sum += val.detach().cpu()
12 | self.n += 1
13 |
14 | @property
15 | def avg(self):
16 | return self.sum / self.n
--------------------------------------------------------------------------------
/lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .axialnet import *
3 |
--------------------------------------------------------------------------------
/lib/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/axialnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/axialnet.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/axialnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/axialnet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/models/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeya-maria-jose/Medical-Transformer/62f40a530c912d6b1cf297a52af7d22834ad6640/lib/models/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/models/axialnet.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .utils import *
7 | import pdb
8 | import matplotlib.pyplot as plt
9 |
10 | import random
11 |
12 |
13 |
14 | def conv1x1(in_planes, out_planes, stride=1):
15 | """1x1 convolution"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
17 |
18 |
19 | class AxialAttention(nn.Module):
20 | def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
21 | stride=1, bias=False, width=False):
22 | assert (in_planes % groups == 0) and (out_planes % groups == 0)
23 | super(AxialAttention, self).__init__()
24 | self.in_planes = in_planes
25 | self.out_planes = out_planes
26 | self.groups = groups
27 | self.group_planes = out_planes // groups
28 | self.kernel_size = kernel_size
29 | self.stride = stride
30 | self.bias = bias
31 | self.width = width
32 |
33 | # Multi-head self attention
34 | self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
35 | padding=0, bias=False)
36 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
37 | self.bn_similarity = nn.BatchNorm2d(groups * 3)
38 |
39 | self.bn_output = nn.BatchNorm1d(out_planes * 2)
40 |
41 | # Position embedding
42 | self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
43 | query_index = torch.arange(kernel_size).unsqueeze(0)
44 | key_index = torch.arange(kernel_size).unsqueeze(1)
45 | relative_index = key_index - query_index + kernel_size - 1
46 | self.register_buffer('flatten_index', relative_index.view(-1))
47 | if stride > 1:
48 | self.pooling = nn.AvgPool2d(stride, stride=stride)
49 |
50 | self.reset_parameters()
51 |
52 | def forward(self, x):
53 | # pdb.set_trace()
54 | if self.width:
55 | x = x.permute(0, 2, 1, 3)
56 | else:
57 | x = x.permute(0, 3, 1, 2) # N, W, C, H
58 | N, W, C, H = x.shape
59 | x = x.contiguous().view(N * W, C, H)
60 |
61 | # Transformations
62 | qkv = self.bn_qkv(self.qkv_transform(x))
63 | q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
64 |
65 | # Calculate position embedding
66 | all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
67 | q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
68 |
69 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
70 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
71 |
72 | qk = torch.einsum('bgci, bgcj->bgij', q, k)
73 |
74 | stacked_similarity = torch.cat([qk, qr, kr], dim=1)
75 | stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
76 | #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)
77 | # (N, groups, H, H, W)
78 | similarity = F.softmax(stacked_similarity, dim=3)
79 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
80 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
81 | stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
82 | output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)
83 |
84 | if self.width:
85 | output = output.permute(0, 2, 1, 3)
86 | else:
87 | output = output.permute(0, 2, 3, 1)
88 |
89 | if self.stride > 1:
90 | output = self.pooling(output)
91 |
92 | return output
93 |
94 | def reset_parameters(self):
95 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
96 | #nn.init.uniform_(self.relative, -0.1, 0.1)
97 | nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
98 |
99 | class AxialAttention_dynamic(nn.Module):
100 | def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
101 | stride=1, bias=False, width=False):
102 | assert (in_planes % groups == 0) and (out_planes % groups == 0)
103 | super(AxialAttention_dynamic, self).__init__()
104 | self.in_planes = in_planes
105 | self.out_planes = out_planes
106 | self.groups = groups
107 | self.group_planes = out_planes // groups
108 | self.kernel_size = kernel_size
109 | self.stride = stride
110 | self.bias = bias
111 | self.width = width
112 |
113 | # Multi-head self attention
114 | self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
115 | padding=0, bias=False)
116 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
117 | self.bn_similarity = nn.BatchNorm2d(groups * 3)
118 | self.bn_output = nn.BatchNorm1d(out_planes * 2)
119 |
120 | # Priority on encoding
121 |
122 | ## Initial values
123 |
124 | self.f_qr = nn.Parameter(torch.tensor(0.1), requires_grad=False)
125 | self.f_kr = nn.Parameter(torch.tensor(0.1), requires_grad=False)
126 | self.f_sve = nn.Parameter(torch.tensor(0.1), requires_grad=False)
127 | self.f_sv = nn.Parameter(torch.tensor(1.0), requires_grad=False)
128 |
129 |
130 | # Position embedding
131 | self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
132 | query_index = torch.arange(kernel_size).unsqueeze(0)
133 | key_index = torch.arange(kernel_size).unsqueeze(1)
134 | relative_index = key_index - query_index + kernel_size - 1
135 | self.register_buffer('flatten_index', relative_index.view(-1))
136 | if stride > 1:
137 | self.pooling = nn.AvgPool2d(stride, stride=stride)
138 |
139 | self.reset_parameters()
140 | # self.print_para()
141 |
142 | def forward(self, x):
143 | if self.width:
144 | x = x.permute(0, 2, 1, 3)
145 | else:
146 | x = x.permute(0, 3, 1, 2) # N, W, C, H
147 | N, W, C, H = x.shape
148 | x = x.contiguous().view(N * W, C, H)
149 |
150 | # Transformations
151 | qkv = self.bn_qkv(self.qkv_transform(x))
152 | q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
153 |
154 | # Calculate position embedding
155 | all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
156 | q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
157 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
158 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
159 | qk = torch.einsum('bgci, bgcj->bgij', q, k)
160 |
161 |
162 | # multiply by factors
163 | qr = torch.mul(qr, self.f_qr)
164 | kr = torch.mul(kr, self.f_kr)
165 |
166 | stacked_similarity = torch.cat([qk, qr, kr], dim=1)
167 | stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
168 | #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)
169 | # (N, groups, H, H, W)
170 | similarity = F.softmax(stacked_similarity, dim=3)
171 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
172 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
173 |
174 | # multiply by factors
175 | sv = torch.mul(sv, self.f_sv)
176 | sve = torch.mul(sve, self.f_sve)
177 |
178 | stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
179 | output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)
180 |
181 | if self.width:
182 | output = output.permute(0, 2, 1, 3)
183 | else:
184 | output = output.permute(0, 2, 3, 1)
185 |
186 | if self.stride > 1:
187 | output = self.pooling(output)
188 |
189 | return output
190 | def reset_parameters(self):
191 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
192 | #nn.init.uniform_(self.relative, -0.1, 0.1)
193 | nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
194 |
195 | class AxialAttention_wopos(nn.Module):
196 | def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
197 | stride=1, bias=False, width=False):
198 | assert (in_planes % groups == 0) and (out_planes % groups == 0)
199 | super(AxialAttention_wopos, self).__init__()
200 | self.in_planes = in_planes
201 | self.out_planes = out_planes
202 | self.groups = groups
203 | self.group_planes = out_planes // groups
204 | self.kernel_size = kernel_size
205 | self.stride = stride
206 | self.bias = bias
207 | self.width = width
208 |
209 | # Multi-head self attention
210 | self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
211 | padding=0, bias=False)
212 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
213 | self.bn_similarity = nn.BatchNorm2d(groups )
214 |
215 | self.bn_output = nn.BatchNorm1d(out_planes * 1)
216 |
217 | if stride > 1:
218 | self.pooling = nn.AvgPool2d(stride, stride=stride)
219 |
220 | self.reset_parameters()
221 |
222 | def forward(self, x):
223 | if self.width:
224 | x = x.permute(0, 2, 1, 3)
225 | else:
226 | x = x.permute(0, 3, 1, 2) # N, W, C, H
227 | N, W, C, H = x.shape
228 | x = x.contiguous().view(N * W, C, H)
229 |
230 | # Transformations
231 | qkv = self.bn_qkv(self.qkv_transform(x))
232 | q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
233 |
234 | qk = torch.einsum('bgci, bgcj->bgij', q, k)
235 |
236 | stacked_similarity = self.bn_similarity(qk).reshape(N * W, 1, self.groups, H, H).sum(dim=1).contiguous()
237 |
238 | similarity = F.softmax(stacked_similarity, dim=3)
239 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
240 |
241 | sv = sv.reshape(N*W,self.out_planes * 1, H).contiguous()
242 | output = self.bn_output(sv).reshape(N, W, self.out_planes, 1, H).sum(dim=-2).contiguous()
243 |
244 |
245 | if self.width:
246 | output = output.permute(0, 2, 1, 3)
247 | else:
248 | output = output.permute(0, 2, 3, 1)
249 |
250 | if self.stride > 1:
251 | output = self.pooling(output)
252 |
253 | return output
254 |
255 | def reset_parameters(self):
256 | self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
257 | #nn.init.uniform_(self.relative, -0.1, 0.1)
258 | # nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
259 |
260 | #end of attn definition
261 |
262 | class AxialBlock(nn.Module):
263 | expansion = 2
264 |
265 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
266 | base_width=64, dilation=1, norm_layer=None, kernel_size=56):
267 | super(AxialBlock, self).__init__()
268 | if norm_layer is None:
269 | norm_layer = nn.BatchNorm2d
270 | width = int(planes * (base_width / 64.))
271 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
272 | self.conv_down = conv1x1(inplanes, width)
273 | self.bn1 = norm_layer(width)
274 | self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size)
275 | self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
276 | self.conv_up = conv1x1(width, planes * self.expansion)
277 | self.bn2 = norm_layer(planes * self.expansion)
278 | self.relu = nn.ReLU(inplace=True)
279 | self.downsample = downsample
280 | self.stride = stride
281 |
282 | def forward(self, x):
283 | identity = x
284 |
285 | out = self.conv_down(x)
286 | out = self.bn1(out)
287 | out = self.relu(out)
288 | # print(out.shape)
289 | out = self.hight_block(out)
290 | out = self.width_block(out)
291 | out = self.relu(out)
292 |
293 | out = self.conv_up(out)
294 | out = self.bn2(out)
295 |
296 | if self.downsample is not None:
297 | identity = self.downsample(x)
298 |
299 | out += identity
300 | out = self.relu(out)
301 |
302 | return out
303 |
304 | class AxialBlock_dynamic(nn.Module):
305 | expansion = 2
306 |
307 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
308 | base_width=64, dilation=1, norm_layer=None, kernel_size=56):
309 | super(AxialBlock_dynamic, self).__init__()
310 | if norm_layer is None:
311 | norm_layer = nn.BatchNorm2d
312 | width = int(planes * (base_width / 64.))
313 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
314 | self.conv_down = conv1x1(inplanes, width)
315 | self.bn1 = norm_layer(width)
316 | self.hight_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size)
317 | self.width_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
318 | self.conv_up = conv1x1(width, planes * self.expansion)
319 | self.bn2 = norm_layer(planes * self.expansion)
320 | self.relu = nn.ReLU(inplace=True)
321 | self.downsample = downsample
322 | self.stride = stride
323 |
324 | def forward(self, x):
325 | identity = x
326 |
327 | out = self.conv_down(x)
328 | out = self.bn1(out)
329 | out = self.relu(out)
330 |
331 | out = self.hight_block(out)
332 | out = self.width_block(out)
333 | out = self.relu(out)
334 |
335 | out = self.conv_up(out)
336 | out = self.bn2(out)
337 |
338 | if self.downsample is not None:
339 | identity = self.downsample(x)
340 |
341 | out += identity
342 | out = self.relu(out)
343 |
344 | return out
345 |
346 | class AxialBlock_wopos(nn.Module):
347 | expansion = 2
348 |
349 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
350 | base_width=64, dilation=1, norm_layer=None, kernel_size=56):
351 | super(AxialBlock_wopos, self).__init__()
352 | if norm_layer is None:
353 | norm_layer = nn.BatchNorm2d
354 | # print(kernel_size)
355 | width = int(planes * (base_width / 64.))
356 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
357 | self.conv_down = conv1x1(inplanes, width)
358 | self.conv1 = nn.Conv2d(width, width, kernel_size = 1)
359 | self.bn1 = norm_layer(width)
360 | self.hight_block = AxialAttention_wopos(width, width, groups=groups, kernel_size=kernel_size)
361 | self.width_block = AxialAttention_wopos(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
362 | self.conv_up = conv1x1(width, planes * self.expansion)
363 | self.bn2 = norm_layer(planes * self.expansion)
364 | self.relu = nn.ReLU(inplace=True)
365 | self.downsample = downsample
366 | self.stride = stride
367 |
368 | def forward(self, x):
369 | identity = x
370 |
371 | # pdb.set_trace()
372 |
373 | out = self.conv_down(x)
374 | out = self.bn1(out)
375 | out = self.relu(out)
376 | # print(out.shape)
377 | out = self.hight_block(out)
378 | out = self.width_block(out)
379 |
380 | out = self.relu(out)
381 |
382 | out = self.conv_up(out)
383 | out = self.bn2(out)
384 |
385 | if self.downsample is not None:
386 | identity = self.downsample(x)
387 |
388 | out += identity
389 | out = self.relu(out)
390 |
391 | return out
392 |
393 |
394 | #end of block definition
395 |
396 |
397 | class ResAxialAttentionUNet(nn.Module):
398 |
399 | def __init__(self, block, layers, num_classes=2, zero_init_residual=True,
400 | groups=8, width_per_group=64, replace_stride_with_dilation=None,
401 | norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
402 | super(ResAxialAttentionUNet, self).__init__()
403 | if norm_layer is None:
404 | norm_layer = nn.BatchNorm2d
405 | self._norm_layer = norm_layer
406 |
407 | self.inplanes = int(64 * s)
408 | self.dilation = 1
409 | if replace_stride_with_dilation is None:
410 | replace_stride_with_dilation = [False, False, False]
411 | if len(replace_stride_with_dilation) != 3:
412 | raise ValueError("replace_stride_with_dilation should be None "
413 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
414 | self.groups = groups
415 | self.base_width = width_per_group
416 | self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
417 | bias=False)
418 | self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
419 | self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
420 | self.bn1 = norm_layer(self.inplanes)
421 | self.bn2 = norm_layer(128)
422 | self.bn3 = norm_layer(self.inplanes)
423 | self.relu = nn.ReLU(inplace=True)
424 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
425 | self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2))
426 | self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2),
427 | dilate=replace_stride_with_dilation[0])
428 | self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4),
429 | dilate=replace_stride_with_dilation[1])
430 | self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8),
431 | dilate=replace_stride_with_dilation[2])
432 |
433 | # Decoder
434 | self.decoder1 = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1)
435 | self.decoder2 = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1)
436 | self.decoder3 = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1)
437 | self.decoder4 = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1)
438 | self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
439 | self.adjust = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
440 | self.soft = nn.Softmax(dim=1)
441 |
442 |
443 | def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
444 | norm_layer = self._norm_layer
445 | downsample = None
446 | previous_dilation = self.dilation
447 | if dilate:
448 | self.dilation *= stride
449 | stride = 1
450 | if stride != 1 or self.inplanes != planes * block.expansion:
451 | downsample = nn.Sequential(
452 | conv1x1(self.inplanes, planes * block.expansion, stride),
453 | norm_layer(planes * block.expansion),
454 | )
455 |
456 | layers = []
457 | layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
458 | base_width=self.base_width, dilation=previous_dilation,
459 | norm_layer=norm_layer, kernel_size=kernel_size))
460 | self.inplanes = planes * block.expansion
461 | if stride != 1:
462 | kernel_size = kernel_size // 2
463 |
464 | for _ in range(1, blocks):
465 | layers.append(block(self.inplanes, planes, groups=self.groups,
466 | base_width=self.base_width, dilation=self.dilation,
467 | norm_layer=norm_layer, kernel_size=kernel_size))
468 |
469 | return nn.Sequential(*layers)
470 |
471 | def _forward_impl(self, x):
472 |
473 | # AxialAttention Encoder
474 | # pdb.set_trace()
475 | x = self.conv1(x)
476 | x = self.bn1(x)
477 | x = self.relu(x)
478 | x = self.conv2(x)
479 | x = self.bn2(x)
480 | x = self.relu(x)
481 | x = self.conv3(x)
482 | x = self.bn3(x)
483 | x = self.relu(x)
484 |
485 | x1 = self.layer1(x)
486 |
487 | x2 = self.layer2(x1)
488 | # print(x2.shape)
489 | x3 = self.layer3(x2)
490 | # print(x3.shape)
491 | x4 = self.layer4(x3)
492 |
493 | x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear'))
494 | x = torch.add(x, x4)
495 | x = F.relu(F.interpolate(self.decoder2(x) , scale_factor=(2,2), mode ='bilinear'))
496 | x = torch.add(x, x3)
497 | x = F.relu(F.interpolate(self.decoder3(x) , scale_factor=(2,2), mode ='bilinear'))
498 | x = torch.add(x, x2)
499 | x = F.relu(F.interpolate(self.decoder4(x) , scale_factor=(2,2), mode ='bilinear'))
500 | x = torch.add(x, x1)
501 | x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
502 | x = self.adjust(F.relu(x))
503 | # pdb.set_trace()
504 | return x
505 |
506 | def forward(self, x):
507 | return self._forward_impl(x)
508 |
509 | class medt_net(nn.Module):
510 |
511 | def __init__(self, block, block_2, layers, num_classes=2, zero_init_residual=True,
512 | groups=8, width_per_group=64, replace_stride_with_dilation=None,
513 | norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
514 | super(medt_net, self).__init__()
515 | if norm_layer is None:
516 | norm_layer = nn.BatchNorm2d
517 | self._norm_layer = norm_layer
518 |
519 | self.inplanes = int(64 * s)
520 | self.dilation = 1
521 | if replace_stride_with_dilation is None:
522 | replace_stride_with_dilation = [False, False, False]
523 | if len(replace_stride_with_dilation) != 3:
524 | raise ValueError("replace_stride_with_dilation should be None "
525 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
526 | self.groups = groups
527 | self.base_width = width_per_group
528 | self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
529 | bias=False)
530 | self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
531 | self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
532 | self.bn1 = norm_layer(self.inplanes)
533 | self.bn2 = norm_layer(128)
534 | self.bn3 = norm_layer(self.inplanes)
535 | # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
536 | self.bn1 = norm_layer(self.inplanes)
537 | self.relu = nn.ReLU(inplace=True)
538 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
539 | self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2))
540 | self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2),
541 | dilate=replace_stride_with_dilation[0])
542 | # self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4),
543 | # dilate=replace_stride_with_dilation[1])
544 | # self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8),
545 | # dilate=replace_stride_with_dilation[2])
546 |
547 | # Decoder
548 | # self.decoder1 = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1)
549 | # self.decoder2 = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1)
550 | # self.decoder3 = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1)
551 | self.decoder4 = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1)
552 | self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
553 | self.adjust = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
554 | self.soft = nn.Softmax(dim=1)
555 |
556 |
557 | self.conv1_p = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
558 | bias=False)
559 | self.conv2_p = nn.Conv2d(self.inplanes,128, kernel_size=3, stride=1, padding=1,
560 | bias=False)
561 | self.conv3_p = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1,
562 | bias=False)
563 | # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
564 | self.bn1_p = norm_layer(self.inplanes)
565 | self.bn2_p = norm_layer(128)
566 | self.bn3_p = norm_layer(self.inplanes)
567 |
568 | self.relu_p = nn.ReLU(inplace=True)
569 |
570 | img_size_p = img_size // 4
571 |
572 | self.layer1_p = self._make_layer(block_2, int(128 * s), layers[0], kernel_size= (img_size_p//2))
573 | self.layer2_p = self._make_layer(block_2, int(256 * s), layers[1], stride=2, kernel_size=(img_size_p//2),
574 | dilate=replace_stride_with_dilation[0])
575 | self.layer3_p = self._make_layer(block_2, int(512 * s), layers[2], stride=2, kernel_size=(img_size_p//4),
576 | dilate=replace_stride_with_dilation[1])
577 | self.layer4_p = self._make_layer(block_2, int(1024 * s), layers[3], stride=2, kernel_size=(img_size_p//8),
578 | dilate=replace_stride_with_dilation[2])
579 |
580 | # Decoder
581 | self.decoder1_p = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1)
582 | self.decoder2_p = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1)
583 | self.decoder3_p = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1)
584 | self.decoder4_p = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1)
585 | self.decoder5_p = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
586 |
587 | self.decoderf = nn.Conv2d(int(128*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
588 | self.adjust_p = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
589 | self.soft_p = nn.Softmax(dim=1)
590 |
591 |
592 | def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
593 | norm_layer = self._norm_layer
594 | downsample = None
595 | previous_dilation = self.dilation
596 | if dilate:
597 | self.dilation *= stride
598 | stride = 1
599 | if stride != 1 or self.inplanes != planes * block.expansion:
600 | downsample = nn.Sequential(
601 | conv1x1(self.inplanes, planes * block.expansion, stride),
602 | norm_layer(planes * block.expansion),
603 | )
604 |
605 | layers = []
606 | layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
607 | base_width=self.base_width, dilation=previous_dilation,
608 | norm_layer=norm_layer, kernel_size=kernel_size))
609 | self.inplanes = planes * block.expansion
610 | if stride != 1:
611 | kernel_size = kernel_size // 2
612 |
613 | for _ in range(1, blocks):
614 | layers.append(block(self.inplanes, planes, groups=self.groups,
615 | base_width=self.base_width, dilation=self.dilation,
616 | norm_layer=norm_layer, kernel_size=kernel_size))
617 |
618 | return nn.Sequential(*layers)
619 |
620 | def _forward_impl(self, x):
621 |
622 | xin = x.clone()
623 | x = self.conv1(x)
624 | x = self.bn1(x)
625 | x = self.relu(x)
626 | x = self.conv2(x)
627 | x = self.bn2(x)
628 | x = self.relu(x)
629 | x = self.conv3(x)
630 | x = self.bn3(x)
631 | # x = F.max_pool2d(x,2,2)
632 | x = self.relu(x)
633 |
634 | # x = self.maxpool(x)
635 | # pdb.set_trace()
636 | x1 = self.layer1(x)
637 | # print(x1.shape)
638 | x2 = self.layer2(x1)
639 | # print(x2.shape)
640 | # x3 = self.layer3(x2)
641 | # # print(x3.shape)
642 | # x4 = self.layer4(x3)
643 | # # print(x4.shape)
644 | # x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear'))
645 | # x = torch.add(x, x4)
646 | # x = F.relu(F.interpolate(self.decoder2(x4) , scale_factor=(2,2), mode ='bilinear'))
647 | # x = torch.add(x, x3)
648 | # x = F.relu(F.interpolate(self.decoder3(x3) , scale_factor=(2,2), mode ='bilinear'))
649 | # x = torch.add(x, x2)
650 | x = F.relu(F.interpolate(self.decoder4(x2) , scale_factor=(2,2), mode ='bilinear'))
651 | x = torch.add(x, x1)
652 | x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
653 | # print(x.shape)
654 |
655 | # end of global (full-image) branch
656 |
657 | # y_out = torch.ones((1,2,128,128))
658 | x_loc = x.clone()
659 | # x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
660 | # start of the local (patch-wise) branch: the input is split into a 4x4 grid of patches
661 | for i in range(0,4):
662 | for j in range(0,4):
663 |
664 | x_p = xin[:,:,32*i:32*(i+1),32*j:32*(j+1)]  # 32x32 patches; the hard-coded 32 assumes img_size == 128
665 | # begin patch wise
666 | x_p = self.conv1_p(x_p)
667 | x_p = self.bn1_p(x_p)
668 | # x = F.max_pool2d(x,2,2)
669 | x_p = self.relu(x_p)
670 |
671 | x_p = self.conv2_p(x_p)
672 | x_p = self.bn2_p(x_p)
673 | # x = F.max_pool2d(x,2,2)
674 | x_p = self.relu(x_p)
675 | x_p = self.conv3_p(x_p)
676 | x_p = self.bn3_p(x_p)
677 | # x = F.max_pool2d(x,2,2)
678 | x_p = self.relu(x_p)
679 |
680 | # x = self.maxpool(x)
681 | # pdb.set_trace()
682 | x1_p = self.layer1_p(x_p)
683 | # print(x1.shape)
684 | x2_p = self.layer2_p(x1_p)
685 | # print(x2.shape)
686 | x3_p = self.layer3_p(x2_p)
687 | # # print(x3.shape)
688 | x4_p = self.layer4_p(x3_p)
689 |
690 | x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2,2), mode ='bilinear'))
691 | x_p = torch.add(x_p, x4_p)
692 | x_p = F.relu(F.interpolate(self.decoder2_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
693 | x_p = torch.add(x_p, x3_p)
694 | x_p = F.relu(F.interpolate(self.decoder3_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
695 | x_p = torch.add(x_p, x2_p)
696 | x_p = F.relu(F.interpolate(self.decoder4_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
697 | x_p = torch.add(x_p, x1_p)
698 | x_p = F.relu(F.interpolate(self.decoder5_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
699 |
700 | x_loc[:,:,32*i:32*(i+1),32*j:32*(j+1)] = x_p
701 |
702 | x = torch.add(x,x_loc)
703 | x = F.relu(self.decoderf(x))
704 |
705 | x = self.adjust(F.relu(x))
706 |
707 | # pdb.set_trace()
708 | return x
709 |
710 | def forward(self, x):
711 | return self._forward_impl(x)
712 |
713 |
714 | def axialunet(pretrained=False, **kwargs):
715 | model = ResAxialAttentionUNet(AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)
716 | return model
717 |
718 | def gated(pretrained=False, **kwargs):
719 | model = ResAxialAttentionUNet(AxialBlock_dynamic, [1, 2, 4, 1], s= 0.125, **kwargs)
720 | return model
721 |
722 | def MedT(pretrained=False, **kwargs):
723 | model = medt_net(AxialBlock_dynamic,AxialBlock_wopos, [1, 2, 4, 1], s= 0.125, **kwargs)
724 | return model
725 |
726 | def logo(pretrained=False, **kwargs):
727 | model = medt_net(AxialBlock,AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)
728 | return model
729 |
730 | # EOF
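731 |
732 | # Minimal usage sketch (illustrative, not part of the original file). The factories above
733 | # only need the input size, channel count and number of classes; 'pretrained' is accepted
734 | # but ignored, and the hard-coded patch size in medt_net assumes img_size = 128.
735 | #   model = MedT(img_size=128, imgchan=3, num_classes=2)
736 | #   logits = model(torch.randn(1, 3, 128, 128))   # -> (1, 2, 128, 128)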
--------------------------------------------------------------------------------
/lib/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url  # needed by _resnet() when pretrained=True
4 |
5 | __all__ = ['ResNet', 'resnet26', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152',]
7 |
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
23 | """3x3 convolution with padding"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=dilation, groups=groups, bias=False, dilation=dilation)
26 |
27 |
28 | def conv1x1(in_planes, out_planes, stride=1):
29 | """1x1 convolution"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
31 |
32 |
33 | class BasicBlock(nn.Module):
34 | expansion = 1
35 |
36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
37 | base_width=64, dilation=1, norm_layer=None):
38 | super(BasicBlock, self).__init__()
39 | if norm_layer is None:
40 | norm_layer = nn.BatchNorm2d
41 | if groups != 1 or base_width != 64:
42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
43 | if dilation > 1:
44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
46 | self.conv1 = conv3x3(inplanes, planes, stride)
47 | self.bn1 = norm_layer(planes)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.conv2 = conv3x3(planes, planes)
50 | self.bn2 = norm_layer(planes)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | def forward(self, x):
55 | identity = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 |
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 |
64 | if self.downsample is not None:
65 | identity = self.downsample(x)
66 |
67 | out += identity
68 | out = self.relu(out)
69 |
70 | return out
71 |
72 |
73 | class Bottleneck(nn.Module):
74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
76 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
77 | # This variant is also known as ResNet V1.5 and improves accuracy according to
78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
79 |
80 | expansion = 4
81 |
82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
83 | base_width=64, dilation=1, norm_layer=None):
84 | super(Bottleneck, self).__init__()
85 | if norm_layer is None:
86 | norm_layer = nn.BatchNorm2d
87 | width = int(planes * (base_width / 64.)) * groups
88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
89 | self.conv1 = conv1x1(inplanes, width)
90 | self.bn1 = norm_layer(width)
91 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
92 | self.bn2 = norm_layer(width)
93 | self.conv3 = conv1x1(width, planes * self.expansion)
94 | self.bn3 = norm_layer(planes * self.expansion)
95 | self.relu = nn.ReLU(inplace=True)
96 | self.downsample = downsample
97 | self.stride = stride
98 |
99 | def forward(self, x):
100 | identity = x
101 |
102 | out = self.conv1(x)
103 | out = self.bn1(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv2(out)
107 | out = self.bn2(out)
108 | out = self.relu(out)
109 |
110 | out = self.conv3(out)
111 | out = self.bn3(out)
112 |
113 | if self.downsample is not None:
114 | identity = self.downsample(x)
115 |
116 | out += identity
117 | out = self.relu(out)
118 |
119 | return out
120 |
121 |
122 | class ResNet(nn.Module):
123 |
124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
125 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
126 | norm_layer=None):
127 | super(ResNet, self).__init__()
128 | if norm_layer is None:
129 | norm_layer = nn.BatchNorm2d
130 | self._norm_layer = norm_layer
131 |
132 | self.inplanes = 64
133 | self.dilation = 1
134 | if replace_stride_with_dilation is None:
135 | # each element in the tuple indicates if we should replace
136 | # the 2x2 stride with a dilated convolution instead
137 | replace_stride_with_dilation = [False, False, False]
138 | if len(replace_stride_with_dilation) != 3:
139 | raise ValueError("replace_stride_with_dilation should be None "
140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
141 | self.groups = groups
142 | self.base_width = width_per_group
143 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
144 | bias=False)
145 | self.bn1 = norm_layer(self.inplanes)
146 | self.relu = nn.ReLU(inplace=True)
147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
148 | self.layer1 = self._make_layer(block, 64, layers[0])
149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
150 | dilate=replace_stride_with_dilation[0])
151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
152 | dilate=replace_stride_with_dilation[1])
153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
154 | dilate=replace_stride_with_dilation[2])
155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
156 | self.fc = nn.Linear(512 * block.expansion, num_classes)
157 |
158 | for m in self.modules():
159 | if isinstance(m, nn.Conv2d):
160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
162 | nn.init.constant_(m.weight, 1)
163 | nn.init.constant_(m.bias, 0)
164 |
165 | # Zero-initialize the last BN in each residual branch,
166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
168 | if zero_init_residual:
169 | for m in self.modules():
170 | if isinstance(m, Bottleneck):
171 | nn.init.constant_(m.bn3.weight, 0)
172 | elif isinstance(m, BasicBlock):
173 | nn.init.constant_(m.bn2.weight, 0)
174 |
175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
176 | norm_layer = self._norm_layer
177 | downsample = None
178 | previous_dilation = self.dilation
179 | if dilate:
180 | self.dilation *= stride
181 | stride = 1
182 | if stride != 1 or self.inplanes != planes * block.expansion:
183 | downsample = nn.Sequential(
184 | conv1x1(self.inplanes, planes * block.expansion, stride),
185 | norm_layer(planes * block.expansion),
186 | )
187 |
188 | layers = []
189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
190 | self.base_width, previous_dilation, norm_layer))
191 | self.inplanes = planes * block.expansion
192 | for _ in range(1, blocks):
193 | layers.append(block(self.inplanes, planes, groups=self.groups,
194 | base_width=self.base_width, dilation=self.dilation,
195 | norm_layer=norm_layer))
196 |
197 | return nn.Sequential(*layers)
198 |
199 | def _forward_impl(self, x):
200 | # See note [TorchScript super()]
201 | x = self.conv1(x)
202 | x = self.bn1(x)
203 | x = self.relu(x)
204 | x = self.maxpool(x)
205 |
206 | x = self.layer1(x)
207 | x = self.layer2(x)
208 | x = self.layer3(x)
209 | x = self.layer4(x)
210 |
211 | x = self.avgpool(x)
212 | x = torch.flatten(x, 1)
213 | x = self.fc(x)
214 |
215 | return x
216 |
217 | def forward(self, x):
218 | return self._forward_impl(x)
219 |
220 |
221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
222 | model = ResNet(block, layers, **kwargs)
223 | if pretrained:
224 | state_dict = load_state_dict_from_url(model_urls[arch],
225 | progress=progress)
226 | model.load_state_dict(state_dict)
227 | return model
228 |
229 |
230 | def resnet18(pretrained=False, progress=True, **kwargs):
231 | r"""ResNet-18 model from
232 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
233 | Args:
234 | pretrained (bool): If True, returns a model pre-trained on ImageNet
235 | progress (bool): If True, displays a progress bar of the download to stderr
236 | """
237 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
238 | **kwargs)
239 |
240 |
241 | def resnet34(pretrained=False, progress=True, **kwargs):
242 | r"""ResNet-34 model from
243 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
244 | Args:
245 | pretrained (bool): If True, returns a model pre-trained on ImageNet
246 | progress (bool): If True, displays a progress bar of the download to stderr
247 | """
248 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
249 | **kwargs)
250 |
251 |
252 | def resnet26(pretrained=False, progress=True, **kwargs):
253 | return _resnet('resnet26', Bottleneck, [1, 2, 4, 1], pretrained, progress,
254 | **kwargs)
255 |
256 |
257 | def resnet50(pretrained=False, progress=True, **kwargs):
258 | r"""ResNet-50 model from
259 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
260 | Args:
261 | pretrained (bool): If True, returns a model pre-trained on ImageNet
262 | progress (bool): If True, displays a progress bar of the download to stderr
263 | """
264 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
265 | **kwargs)
266 |
267 |
268 | def resnet101(pretrained=False, progress=True, **kwargs):
269 | r"""ResNet-101 model from
270 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
271 | Args:
272 | pretrained (bool): If True, returns a model pre-trained on ImageNet
273 | progress (bool): If True, displays a progress bar of the download to stderr
274 | """
275 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
276 | **kwargs)
277 |
278 |
279 | def resnet152(pretrained=False, progress=True, **kwargs):
280 | r"""ResNet-152 model from
281 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
282 | Args:
283 | pretrained (bool): If True, returns a model pre-trained on ImageNet
284 | progress (bool): If True, displays a progress bar of the download to stderr
285 | """
286 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
287 | **kwargs)
288 |
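289 | # Illustrative sanity check (not part of the original file): each factory above returns a
290 | # standard ImageNet-style classifier, e.g.
291 | #   model = resnet50(pretrained=False)
292 | #   logits = model(torch.randn(1, 3, 224, 224))   # -> (1, 1000)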
--------------------------------------------------------------------------------
/lib/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class qkv_transform(nn.Conv1d):
5 | """Thin nn.Conv1d subclass used as the query/key/value projection inside the axial-attention layers."""
6 |
7 |
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def adjust_learning_rate(args, optimizer, epoch, batch_idx, data_nums, type="cosine"):
8 | if epoch < args.warmup_epochs:
9 | epoch += float(batch_idx + 1) / data_nums
10 | lr_adj = 1. * (epoch / args.warmup_epochs)
11 | elif type == "linear":
12 | if epoch < 30 + args.warmup_epochs:
13 | lr_adj = 1.
14 | elif epoch < 60 + args.warmup_epochs:
15 | lr_adj = 1e-1
16 | elif epoch < 90 + args.warmup_epochs:
17 | lr_adj = 1e-2
18 | else:
19 | lr_adj = 1e-3
20 | elif type == "cosine":
21 | run_epochs = epoch - args.warmup_epochs
22 | total_epochs = args.epochs - args.warmup_epochs
23 | T_cur = float(run_epochs * data_nums) + batch_idx
24 | T_total = float(total_epochs * data_nums)
25 |
26 | lr_adj = 0.5 * (1 + math.cos(math.pi * T_cur / T_total))
27 |
28 | for param_group in optimizer.param_groups:
29 | param_group['lr'] = args.lr * lr_adj
30 | return args.lr * lr_adj
31 |
32 |
33 | def label_smoothing(pred, target, eta=0.1):
34 | '''
35 | Label smoothing, following https://arxiv.org/pdf/1512.00567.pdf
36 | :param pred: N x C logits (only used to infer the number of classes)
37 | :param target: N, ground-truth class indices
38 | :param eta: float, smoothing factor
39 | :return:
40 | N x C one-hot smoothed vector
41 | '''
42 | n_classes = pred.size(1)
43 | target = torch.unsqueeze(target, 1)
44 | onehot_target = torch.zeros_like(pred)
45 | onehot_target.scatter_(1, target, 1)
46 | return onehot_target * (1 - eta) + eta / n_classes
47 |
48 |
49 | def cross_entropy_for_onehot(pred, target):
50 | return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))
51 |
52 |
53 | def cross_entropy_with_label_smoothing(pred, target, eta=0.1):
54 | onehot_target = label_smoothing(pred, target, eta=eta)
55 | return cross_entropy_for_onehot(pred, onehot_target)
56 |
57 |
58 | def accuracy(output, target):
59 | # get the index of the max log-probability
60 | pred = output.max(1, keepdim=True)[1]
61 | return pred.eq(target.view_as(pred)).cpu().float().mean()
62 |
63 |
64 | def save_model(model, optimizer, epoch, args):
65 | os.makedirs(args.work_dirs, exist_ok=True)
66 | if optimizer is not None:
67 | torch.save({
68 | 'net': model.state_dict(),
69 | 'optim': optimizer.state_dict(),
70 | 'epoch': epoch
71 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch)))
72 | else:
73 | torch.save({
74 | 'net': model.state_dict(),
75 | 'epoch': epoch
76 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch)))
77 |
78 |
79 | def dist_save_model(model, optimizer, epoch, ngpus_per_node, args):
80 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
81 | and args.rank % ngpus_per_node == 0):
82 | os.makedirs(args.work_dirs, exist_ok=True)
83 | if optimizer is not None:
84 | torch.save({
85 | 'net': model.state_dict(),
86 | 'optim': optimizer.state_dict(),
87 | 'epoch': epoch
88 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch)))
89 | else:
90 | torch.save({
91 | 'net': model.state_dict(),
92 | 'epoch': epoch
93 | }, os.path.join(args.work_dirs, '{}.pth'.format(epoch)))
94 |
95 |
96 | def load_model(network, args):
97 | if not os.path.exists(args.work_dirs):
98 | print("No such working directory!")
99 | raise AssertionError
100 |
101 | pths = [pth.split('.')[0] for pth in os.listdir(args.work_dirs) if 'pth' in pth]
102 | if len(pths) == 0:
103 | print("No model to load!")
104 | raise AssertionError
105 |
106 | pths = [int(pth) for pth in pths]
107 | if args.test_model == -1:
108 | pth = -1
109 | if pth in pths:
110 | pass
111 | else:
112 | pth = max(pths)
113 | else:
114 | pth = args.test_model
115 | try:
116 | if args.distributed:
117 | loc = 'cuda:{}'.format(args.gpu)
118 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth)), map_location=loc)
119 | except:
120 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth)))
121 | try:
122 | network.load_state_dict(model['net'], strict=True)
123 | except:
124 | network.load_state_dict(convert_model(model['net']), strict=True)
125 | return True
126 |
127 |
128 | def resume_model(network, optimizer, args):
129 | print("Loading the model...")
130 | if not os.path.exists(args.work_dirs):
131 | print("No such working directory!")
132 | return 0
133 | pths = [pth.split('.')[0] for pth in os.listdir(args.work_dirs) if 'pth' in pth]
134 | if len(pths) == 0:
135 | print("No model to load!")
136 | return 0
137 | pths = [int(pth) for pth in pths]
138 | if args.test_model == -1:
139 | pth = max(pths)
140 | else:
141 | pth = args.test_model
142 | try:
143 | if args.distributed:
144 | loc = 'cuda:{}'.format(args.gpu)
145 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth)), map_location=loc)
146 | except:
147 | model = torch.load(os.path.join(args.work_dirs, '{}.pth'.format(pth)))
148 | try:
149 | network.load_state_dict(model['net'], strict=True)
150 | except:
151 | network.load_state_dict(convert_model(model['net']), strict=True)
152 | optimizer.load_state_dict(model['optim'])
153 | for state in optimizer.state.values():
154 | for k, v in state.items():
155 | if torch.is_tensor(v):
156 | try:
157 | state[k] = v.cuda(args.gpu)
158 | except:
159 | state[k] = v.cuda()
160 | return model['epoch']
161 |
162 |
163 | def convert_model(model):
164 | new_model = {}  # strip the 'module.' prefix that nn.DataParallel adds to state_dict keys
165 | for k in model.keys():
166 | new_model[k[7:]] = model[k]
167 | return new_model
168 |
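169 | # Illustrative example (assumed shapes: batch of 8, 4 classes):
170 | #   pred, target = torch.randn(8, 4), torch.randint(0, 4, (8,))
171 | #   loss = cross_entropy_with_label_smoothing(pred, target, eta=0.1)   # scalar tensor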
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.functional import cross_entropy
3 | from torch.nn.modules.loss import _WeightedLoss
4 |
5 |
6 | EPSILON = 1e-32
7 |
8 |
9 | class LogNLLLoss(_WeightedLoss):
10 | __constants__ = ['weight', 'reduction', 'ignore_index']
11 |
12 | def __init__(self, weight=None, size_average=None, reduce=None, reduction=None,
13 | ignore_index=-100):
14 | super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction)
15 | self.ignore_index = ignore_index
16 |
17 | def forward(self, y_input, y_target):
18 | # y_input = torch.log(y_input + EPSILON)
19 | return cross_entropy(y_input, y_target, weight=self.weight,
20 | ignore_index=self.ignore_index)
21 |
22 |
23 | def classwise_iou(output, gt):
24 | """
25 | Args:
26 | output: torch.Tensor of shape (n_batch, n_classes, image.shape)
27 | gt: torch.LongTensor of shape (n_batch, image.shape)
28 | """
29 | dims = (0, *range(2, len(output.shape)))
30 | gt = torch.zeros_like(output).scatter_(1, gt[:, None, :], 1)
31 | intersection = output*gt
32 | union = output + gt - intersection
33 | classwise_iou = (intersection.sum(dim=dims).float() + EPSILON) / (union.sum(dim=dims) + EPSILON)
34 |
35 | return classwise_iou
36 |
37 |
38 | def classwise_f1(output, gt):
39 | """
40 | Args:
41 | output: torch.Tensor of shape (n_batch, n_classes, image.shape)
42 | gt: torch.LongTensor of shape (n_batch, image.shape)
43 | """
44 |
45 | epsilon = 1e-20
46 | n_classes = output.shape[1]
47 |
48 | output = torch.argmax(output, dim=1)
49 | true_positives = torch.tensor([((output == i) * (gt == i)).sum() for i in range(n_classes)]).float()
50 | selected = torch.tensor([(output == i).sum() for i in range(n_classes)]).float()
51 | relevant = torch.tensor([(gt == i).sum() for i in range(n_classes)]).float()
52 |
53 | precision = (true_positives + epsilon) / (selected + epsilon)
54 | recall = (true_positives + epsilon) / (relevant + epsilon)
55 | classwise_f1 = 2 * (precision * recall) / (precision + recall)
56 |
57 | return classwise_f1
58 |
59 |
60 | def make_weighted_metric(classwise_metric):
61 | """
62 | Args:
63 | classwise_metric: classwise metric like classwise_iou or classwise_f1
64 | """
65 |
66 | def weighted_metric(output, gt, weights=None):
67 |
68 | # dimensions to sum over
69 | dims = (0, *range(2, len(output.shape)))
70 |
71 | # default weights
72 | if weights is None:
73 | weights = torch.ones(output.shape[1]) / output.shape[1]
74 | else:
75 | # creating tensor if needed
76 | if len(weights) != output.shape[1]:
77 | raise ValueError("The number of weights must match with the number of classes")
78 | if not isinstance(weights, torch.Tensor):
79 | weights = torch.tensor(weights)
80 | # normalizing weights
81 | weights /= torch.sum(weights)
82 |
83 | classwise_scores = classwise_metric(output, gt).cpu()
84 |
85 | return classwise_scores  # note: the weights are validated above but not applied; classwise scores are returned
86 |
87 | return weighted_metric
88 |
89 |
90 | jaccard_index = make_weighted_metric(classwise_iou)
91 | f1_score = make_weighted_metric(classwise_f1)
92 |
93 |
94 | if __name__ == '__main__':
95 | output, gt = torch.zeros(3, 2, 5, 5), torch.zeros(3, 5, 5).long()
96 | print(classwise_iou(output, gt))
97 |
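98 | # illustrative: classwise F1 for the same dummy tensors (all-background prediction and target)
99 | print(classwise_f1(output, gt))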
--------------------------------------------------------------------------------
/performancemetrics_ax.m:
--------------------------------------------------------------------------------
1 |
2 | % close all;
3 | % clear all;
4 | % clc;
5 | N = 328
6 | st = 0;
7 | Fsc=[];
8 | MIU=[];
9 | PA=[];
10 | bestfsc=0;
11 | bestmiu=0;
12 | bestpa=0;
13 | bestep = 0;
14 |
15 | for k = 0:8
16 | k
17 | Fsc=[];
18 | MIU=[];
19 | PA=[];
20 | for i = st:st+N
21 | i;
22 | %gname = strcat('./Brain_test/',num2str(i,'%04d'),'.png');
23 |
24 | tname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Projects/axialseg/KiU-Net-pytorch/results/brainus/mix_3_gated_wopos/';
25 | imgname = strcat(tname,num2str(50*k),'/',num2str(i,'%04d'),'.png');
26 | lname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/test/labelcol/';
27 | labelname = strcat(lname, num2str(i,'%04d'),'.png');
28 |
29 | I = double(imread(imgname));tmp2=zeros(128,128);
30 | tmp2(I>131) = 255;
31 | tmp2(I<130) = 0;
32 | tmp = double(imread(labelname));
33 | tmp = tmp(:,:,1);
34 | tmp(tmp<130)=0;tmp(tmp>131)=255;
35 |
36 | tp=0;fp=0;fn=0;tn=0;uni=0;ttp=0;lab=0;
37 |
38 | for p =1:128
39 | for q =1:128
40 | if tmp(p,q)==0
41 | if tmp2(p,q) == tmp(p,q)
42 | tn = tn+1;
43 | else
44 | fp = fp+1;
45 | uni = uni+1;
46 | ttp = ttp+1;
47 | end
48 | elseif tmp(p,q)==255
49 | lab = lab +1;
50 | if tmp2(p,q) == tmp(p,q)
51 | tp = tp+1;
52 | ttp = ttp+1;
53 | else
54 | fn = fn+1;
55 | end
56 | uni = uni+1;
57 | end
58 |
59 | end
60 | end
61 |
62 | if (tp~=0)
63 | F = (2*tp)/(2*tp+fp+fn);
64 | MIU=[MIU,(tp*1.0/uni)];
65 | PA=[PA,(tp*1.0/ttp)];
66 | Fsc=[Fsc;[i,F]];
67 | else
68 | MIU=[MIU,1];
69 | PA=[PA,1];
70 | Fsc=[Fsc;[i,1]];
71 |
72 | end
73 |
74 |
75 |
76 | end
77 | if bestfsc <= mean(Fsc) & (mean(Fsc) ~= 1)
78 | bestfsc = mean(Fsc);
79 | bestmiu = mean(MIU,2);
80 | bestpa = mean(PA,2);
81 | bestep = 50*k;
82 |
83 | end
84 | mean(Fsc)
85 | end
86 |
87 | bestfsc
88 | bestmiu
89 | bestpa
90 | bestep
91 |
92 | % plot(Fsc(:,1),Fsc(:,2),'-*')
93 | % hold on
94 | % plot(Fsc(:,1),Fsc1(:,2),'-s')
95 | % hold off
96 | % figure();plot(Fsc(:,1),PA,'-*');hold on
97 | % plot(Fsc(:,1),PA1,'-s');hold off
98 | % Fsc1=Fsc;
99 | % MIU1=MIU;
100 | % PA1=PA;
101 |
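102 | % Reference for the per-image scores accumulated above (tp/fp/fn are pixel counts):
103 | %   F1  = 2*tp / (2*tp + fp + fn)
104 | %   IoU = tp / (tp + fp + fn)    (accumulated as tp/uni)
105 | %   PA  = tp / (tp + fp)         (accumulated as tp/ttp)
106 | % The glas and monuseg scripts below follow the same computation on their own result paths.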
--------------------------------------------------------------------------------
/performancemetrics_glas.m:
--------------------------------------------------------------------------------
1 |
2 | % close all;
3 | % clear all;
4 | % clc;
5 | N = 79
6 | st = 1;
7 | Fsc=[];
8 | MIU=[];
9 | PA=[];
10 | bestfsc=0;
11 | bestmiu=0;
12 | bestpa=0;
13 | bestep = 0;
14 |
15 | for k = 1:24
16 | k
17 | Fsc=[];
18 | MIU=[];
19 | PA=[];
20 | for i = st:st+N
21 | i;
22 | %gname = strcat('./Brain_test/',num2str(i,'%04d'),'.png');
23 |
24 | tname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Projects/axialseg/KiU-Net-pytorch/results/glas/medT/';
25 | imgname = strcat(tname,num2str(50*k),'/',num2str(i,'%02d'),'.png');
26 | lname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/glas/resized/test/labelcol/';
27 |
28 | labelname = strcat(lname, num2str(i,'%02d'),'.png');
29 |
30 | I = double(imread(imgname));tmp2=zeros(128,128);
31 | tmp2(I>130) = 255;
32 | tmp2(I<131) = 0;
33 | tmp = double(imread(labelname));
34 | tmp = tmp(:,:,1);
35 | tmp(tmp<130)=0;tmp(tmp>131)=255;
36 |
37 |
38 |
39 | tp=0;fp=0;fn=0;tn=0;uni=0;ttp=0;lab=0;
40 |
41 | for p =1:128
42 | for q =1:128
43 | if tmp(p,q)==0
44 | if tmp2(p,q) == tmp(p,q)
45 | tn = tn+1;
46 | else
47 | fp = fp+1;
48 | uni = uni+1;
49 | ttp = ttp+1;
50 | end
51 | elseif tmp(p,q)==255
52 | lab = lab +1;
53 | if tmp2(p,q) == tmp(p,q)
54 | tp = tp+1;
55 | ttp = ttp+1;
56 | else
57 | fn = fn+1;
58 | end
59 | uni = uni+1;
60 | end
61 |
62 | end
63 | end
64 |
65 |
66 | if (tp~=0)
67 | F = (2*tp)/(2*tp+fp+fn);
68 | MIU=[MIU,(tp*1.0/uni)];
69 | PA=[PA,(tp*1.0/ttp)];
70 | Fsc=[Fsc;[i,F]];
71 |
72 | else
73 | MIU=[MIU,1];
74 | PA=[PA,1];
75 | Fsc=[Fsc;[i,1]];
76 |
77 | end
78 |
79 |
80 |
81 | end
82 | if bestfsc <= mean(Fsc) & (mean(Fsc) ~= 1)
83 | bestfsc = mean(Fsc);
84 | bestmiu = mean(MIU,2);
85 | bestpa = mean(PA,2);
86 | bestep = 50*k;
87 |
88 | end
89 | mean(Fsc)
90 | end
91 |
92 | bestfsc
93 | bestmiu
94 | bestpa
95 | bestep
96 |
97 |
--------------------------------------------------------------------------------
/performancemetrics_monuseg.m:
--------------------------------------------------------------------------------
1 |
2 | % close all;
3 | % clear all;
4 | % clc;
5 | N = 328
6 | st = 0;
7 | Fsc=[];
8 | MIU=[];
9 | PA=[];
10 | bestfsc=0;
11 | bestmiu=0;
12 | bestpa=0;
13 | bestep = 0;
14 |
15 | folder = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/monuseg/resized/test/labelcol/';
16 | listinfo = dir(strcat(folder,'*.png'));
17 | lm = length(listinfo);
18 |
19 |
20 | for k = 1:10
21 | k
22 | Fsc=[];
23 | MIU=[];
24 | PA=[];
25 | for i = 1:lm
26 | %I = double(imread(strcat(folder,listinfo(i).name)));
27 | imgfile = strcat(folder,listinfo(i).name);
28 | imgname = listinfo(i).name(1:27) ;
29 | i;
30 | %gname = strcat('./Brain_test/',num2str(i,'%04d'),'.png');
31 |
32 | lname = '/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Projects/axialseg/KiU-Net-pytorch/results/monuseg/medTr/';
33 | labelname = strcat(lname, num2str(k*10),'/', imgname);
34 | %imgname
35 | I = double(imread(imgfile));tmp2=zeros(512,512);
36 | %I = rgb2gray(I);
37 | tmp2(I>127) = 255;
38 | tmp2(I<126) = 0;
39 | tmp = double(imread(labelname));
40 |
41 | tmp(tmp<127)=0;tmp(tmp>126)=255;
42 | %tmp2 = I;
43 | tp=0;fp=0;fn=0;tn=0;uni=0;ttp=0;lab=0;
44 |
45 | for p =1:512
46 | for q =1:512
47 | if tmp(p,q)==0
48 | if tmp2(p,q) == tmp(p,q)
49 | tn = tn+1;
50 | else
51 | fp = fp+1;
52 | uni = uni+1;
53 | ttp = ttp+1;
54 | end
55 | elseif tmp(p,q)==255
56 | lab = lab +1;
57 | if tmp2(p,q) == tmp(p,q)
58 | tp = tp+1;
59 | ttp = ttp+1;
60 | else
61 | fn = fn+1;
62 | end
63 | uni = uni+1;
64 | end
65 |
66 | end
67 | end
68 |
69 | if (tp~=0)
70 | F = (2*tp)/(2*tp+fp+fn);
71 | MIU=[MIU,(tp*1.0/uni)];
72 | PA=[PA,(tp*1.0/ttp)];
73 | Fsc=[Fsc;[i,F]];
74 | else
75 | MIU=[MIU,1];
76 | PA=[PA,1];
77 | Fsc=[Fsc;[i,1]];
78 |
79 | end
80 |
81 |
82 |
83 | end
84 |
85 | if bestfsc <= mean(Fsc) & (mean(Fsc) ~= 1)
86 | bestfsc = mean(Fsc);
87 | bestmiu = mean(MIU,2);
88 | bestpa = mean(PA,2);
89 | bestep = 10*k;
90 |
91 | end
92 | mean(Fsc)
93 | end
94 |
95 | bestfsc
96 | bestmiu
97 | %bestpa
98 | bestep
99 |
100 | % plot(Fsc(:,1),Fsc(:,2),'-*')
101 | % hold on
102 | % plot(Fsc(:,1),Fsc1(:,2),'-s')
103 | % hold off
104 | % figure();plot(Fsc(:,1),PA,'-*');hold on
105 | % plot(Fsc(:,1),PA1,'-s');hold off
106 | % Fsc1=Fsc;
107 | % MIU1=MIU;
108 | % PA1=PA;
109 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | torchvision>=0.5.0
3 | scikit-learn==0.23.2
4 | scipy==1.5.3
5 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import lib
3 | import torch
4 | import torchvision
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | from torchvision import transforms
9 | from torchvision.utils import save_image
10 | from torchvision.datasets import MNIST
11 | import torch.nn.functional as F
12 | import os
13 | import matplotlib.pyplot as plt
14 | import torch.utils.data as data
15 | from PIL import Image
16 | import numpy as np
17 | from torchvision.utils import save_image
18 | import torch
19 | import torch.nn.init as init
20 | from utils import JointTransform2D, ImageToImage2D, Image2D
21 | from metrics import jaccard_index, f1_score, LogNLLLoss,classwise_f1
22 | from utils import chk_mkdir, Logger, MetricList
23 | import cv2
24 | from functools import partial
25 | from random import randint
26 |
27 |
28 | parser = argparse.ArgumentParser(description='MedT')
29 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
30 | help='number of data loading workers (default: 16)')
31 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
32 | help='number of total epochs to run (default: 100)')
33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
34 | help='manual epoch number (useful on restarts)')
35 | parser.add_argument('-b', '--batch_size', default=1, type=int,
36 | metavar='N', help='batch size (default: 1)')
37 | parser.add_argument('--learning_rate', default=1e-3, type=float,
38 | metavar='LR', help='initial learning rate (default: 1e-3)')
39 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
40 | help='momentum')
41 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
42 | metavar='W', help='weight decay (default: 1e-5)')
43 | parser.add_argument('--train_dataset', type=str)
44 | parser.add_argument('--val_dataset', type=str)
45 | parser.add_argument('--save_freq', type=int,default = 5)
46 | parser.add_argument('--modelname', default='off', type=str,
47 | help='name of the model to load')
48 | parser.add_argument('--cuda', default="on", type=str,
49 | help='switch on/off cuda option (default: on)')
50 |
51 | parser.add_argument('--direc', default='./results', type=str,
52 | help='directory to save')
53 | parser.add_argument('--crop', type=int, default=None)
54 | parser.add_argument('--device', default='cuda', type=str)
55 | parser.add_argument('--loaddirec', default='load', type=str)
56 | parser.add_argument('--imgsize', type=int, default=None)
57 | parser.add_argument('--gray', default='no', type=str)
58 | args = parser.parse_args()
59 |
60 | direc = args.direc
61 | gray_ = args.gray
62 | # aug = args.aug  # note: '--aug' is only defined in train.py's parser; accessing it here would raise AttributeError
63 | direc = args.direc
64 | modelname = args.modelname
65 | imgsize = args.imgsize
66 | loaddirec = args.loaddirec
67 |
68 | if gray_ == "yes":
69 | from utils_gray import JointTransform2D, ImageToImage2D, Image2D
70 | imgchant = 1
71 | else:
72 | from utils import JointTransform2D, ImageToImage2D, Image2D
73 | imgchant = 3
74 |
75 | if args.crop is not None:
76 | crop = (args.crop, args.crop)
77 | else:
78 | crop = None
79 |
80 | tf_train = JointTransform2D(crop=crop, p_flip=0.5, color_jitter_params=None, long_mask=True)
81 | tf_val = JointTransform2D(crop=crop, p_flip=0, color_jitter_params=None, long_mask=True)
82 | train_dataset = ImageToImage2D(args.train_dataset, tf_val)
83 | val_dataset = ImageToImage2D(args.val_dataset, tf_val)
84 | predict_dataset = Image2D(args.val_dataset)
85 | dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
86 | valloader = DataLoader(val_dataset, 1, shuffle=True)
87 |
88 | device = torch.device("cuda")
89 |
90 | if modelname == "axialunet":
91 | model = lib.models.axialunet(img_size = imgsize, imgchan = imgchant)
92 | elif modelname == "MedT":
93 | model = lib.models.axialnet.MedT(img_size = imgsize, imgchan = imgchant)
94 | elif modelname == "gatedaxialunet":
95 | model = lib.models.axialnet.gated(img_size = imgsize, imgchan = imgchant)
96 | elif modelname == "logo":
97 | model = lib.models.axialnet.logo(img_size = imgsize, imgchan = imgchant)
98 |
99 | if torch.cuda.device_count() > 1:
100 | print("Let's use", torch.cuda.device_count(), "GPUs!")
101 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
102 | model = nn.DataParallel(model,device_ids=[0,1]).cuda()
103 | model.to(device)
104 |
105 | model.load_state_dict(torch.load(loaddirec))
106 | model.eval()
107 |
108 |
109 | for batch_idx, (X_batch, y_batch, *rest) in enumerate(valloader):
110 | # print(batch_idx)
111 | if isinstance(rest[0][0], str):
112 | image_filename = rest[0][0]
113 | else:
114 | image_filename = '%s.png' % str(batch_idx + 1).zfill(3)
115 |
116 | X_batch = Variable(X_batch.to(device='cuda'))
117 | y_batch = Variable(y_batch.to(device='cuda'))
118 |
119 | y_out = model(X_batch)
120 |
121 | tmp2 = y_batch.detach().cpu().numpy()
122 | tmp = y_out.detach().cpu().numpy()
123 | tmp[tmp>=0.5] = 1
124 | tmp[tmp<0.5] = 0
125 | tmp2[tmp2>0] = 1
126 | tmp2[tmp2<=0] = 0
127 | tmp2 = tmp2.astype(int)
128 | tmp = tmp.astype(int)
129 |
130 | # print(np.unique(tmp2))
131 | yHaT = tmp
132 | yval = tmp2
133 |
134 | epsilon = 1e-20
135 |
136 | del X_batch, y_batch,tmp,tmp2, y_out
137 |
138 | yHaT[yHaT==1] =255
139 | yval[yval==1] =255
140 | fulldir = direc+"/"
141 |
142 | if not os.path.isdir(fulldir):
143 |
144 | os.makedirs(fulldir)
145 |
146 | cv2.imwrite(fulldir+image_filename, yHaT[0,1,:,:])
147 |
148 |
149 |
150 |
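151 | # Illustrative invocation (paths are placeholders; --train_dataset is still required because a
152 | # train loader is built above even though only the validation loader is used for predictions):
153 | #   python test.py --train_dataset "path/to/train" --val_dataset "path/to/test" \
154 | #       --loaddirec "./trained_model.pth" --direc "./predictions" --modelname "MedT" --imgsize 128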
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Code for MedT
2 |
3 | import torch
4 | import lib
5 | import argparse
6 | import torch
7 | import torchvision
8 | from torch import nn
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 | from torchvision import transforms
12 | from torchvision.utils import save_image
13 | import torch.nn.functional as F
14 | import os
15 | import matplotlib.pyplot as plt
16 | import torch.utils.data as data
17 | from PIL import Image
18 | import numpy as np
19 | from torchvision.utils import save_image
20 | import torch
21 | import torch.nn.init as init
22 | from utils import JointTransform2D, ImageToImage2D, Image2D
23 | from metrics import jaccard_index, f1_score, LogNLLLoss,classwise_f1
24 | from utils import chk_mkdir, Logger, MetricList
25 | import cv2
26 | from functools import partial
27 | from random import randint
28 | import timeit
29 |
30 | parser = argparse.ArgumentParser(description='MedT')
31 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
32 | help='number of data loading workers (default: 16)')
33 | parser.add_argument('--epochs', default=400, type=int, metavar='N',
34 | help='number of total epochs to run (default: 400)')
35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
36 | help='manual epoch number (useful on restarts)')
37 | parser.add_argument('-b', '--batch_size', default=1, type=int,
38 | metavar='N', help='batch size (default: 1)')
39 | parser.add_argument('--learning_rate', default=1e-3, type=float,
40 | metavar='LR', help='initial learning rate (default: 0.001)')
41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
42 | help='momentum')
43 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
44 | metavar='W', help='weight decay (default: 1e-5)')
45 | parser.add_argument('--train_dataset', required=True, type=str)
46 | parser.add_argument('--val_dataset', type=str)
47 | parser.add_argument('--save_freq', type=int,default = 10)
48 |
49 | parser.add_argument('--modelname', default='MedT', type=str,
50 | help='type of model')
51 | parser.add_argument('--cuda', default="on", type=str,
52 | help='switch on/off cuda option (default: on)')
53 | parser.add_argument('--aug', default='off', type=str,
54 | help='turn on img augmentation (default: off)')
55 | parser.add_argument('--load', default='default', type=str,
56 | help='load a pretrained model')
57 | parser.add_argument('--save', default='default', type=str,
58 | help='save the model')
59 | parser.add_argument('--direc', default='./medt', type=str,
60 | help='directory to save')
61 | parser.add_argument('--crop', type=int, default=None)
62 | parser.add_argument('--imgsize', type=int, default=None)
63 | parser.add_argument('--device', default='cuda', type=str)
64 | parser.add_argument('--gray', default='no', type=str)
65 |
66 | args = parser.parse_args()
67 | gray_ = args.gray
68 | aug = args.aug
69 | direc = args.direc
70 | modelname = args.modelname
71 | imgsize = args.imgsize
72 |
73 | if gray_ == "yes":
74 | from utils_gray import JointTransform2D, ImageToImage2D, Image2D
75 | imgchant = 1
76 | else:
77 | from utils import JointTransform2D, ImageToImage2D, Image2D
78 | imgchant = 3
79 |
80 | if args.crop is not None:
81 | crop = (args.crop, args.crop)
82 | else:
83 | crop = None
84 |
85 | tf_train = JointTransform2D(crop=crop, p_flip=0.5, color_jitter_params=None, long_mask=True)
86 | tf_val = JointTransform2D(crop=crop, p_flip=0, color_jitter_params=None, long_mask=True)
87 | train_dataset = ImageToImage2D(args.train_dataset, tf_train)
88 | val_dataset = ImageToImage2D(args.val_dataset, tf_val)
89 | predict_dataset = Image2D(args.val_dataset)
90 | dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
91 | valloader = DataLoader(val_dataset, 1, shuffle=True)
92 |
93 | device = torch.device("cuda")
94 |
95 | if modelname == "axialunet":
96 | model = lib.models.axialunet(img_size = imgsize, imgchan = imgchant)
97 | elif modelname == "MedT":
98 | model = lib.models.axialnet.MedT(img_size = imgsize, imgchan = imgchant)
99 | elif modelname == "gatedaxialunet":
100 | model = lib.models.axialnet.gated(img_size = imgsize, imgchan = imgchant)
101 | elif modelname == "logo":
102 | model = lib.models.axialnet.logo(img_size = imgsize, imgchan = imgchant)
103 |
104 | if torch.cuda.device_count() > 1:
105 | print("Let's use", torch.cuda.device_count(), "GPUs!")
106 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
107 | model = nn.DataParallel(model,device_ids=[0,1]).cuda()
108 | model.to(device)
109 |
110 | criterion = LogNLLLoss()
111 | optimizer = torch.optim.Adam(list(model.parameters()), lr=args.learning_rate,
112 | weight_decay=1e-5)
113 |
114 |
115 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
116 | print("Total_params: {}".format(pytorch_total_params))
117 |
118 | seed = 3000
119 | np.random.seed(seed)
120 | torch.manual_seed(seed)
121 | torch.cuda.manual_seed(seed)
122 | # torch.set_deterministic(True)
123 | # random.seed(seed)
124 |
125 |
126 | for epoch in range(args.epochs):
127 |
128 | epoch_running_loss = 0
129 |
130 | for batch_idx, (X_batch, y_batch, *rest) in enumerate(dataloader):
131 |
132 |
133 |
134 | X_batch = Variable(X_batch.to(device ='cuda'))
135 | y_batch = Variable(y_batch.to(device='cuda'))
136 |
137 | # ===================forward=====================
138 |
139 |
140 | output = model(X_batch)
141 |
142 | tmp2 = y_batch.detach().cpu().numpy()
143 | tmp = output.detach().cpu().numpy()
144 | tmp[tmp>=0.5] = 1
145 | tmp[tmp<0.5] = 0
146 | tmp2[tmp2>0] = 1
147 | tmp2[tmp2<=0] = 0
148 | tmp2 = tmp2.astype(int)
149 | tmp = tmp.astype(int)
150 |
151 | yHaT = tmp
152 | yval = tmp2
153 |
154 |
155 |
156 | loss = criterion(output, y_batch)
157 |
158 | # ===================backward====================
159 | optimizer.zero_grad()
160 | loss.backward()
161 | optimizer.step()
162 | epoch_running_loss += loss.item()
163 |
164 | # ===================log========================
165 | print('epoch [{}/{}], loss:{:.4f}'
166 | .format(epoch, args.epochs, epoch_running_loss/(batch_idx+1)))
167 |
168 |
169 | if epoch == 10:
170 | for param in model.parameters():
171 | param.requires_grad =True
172 | if (epoch % args.save_freq) ==0:
173 |
174 | for batch_idx, (X_batch, y_batch, *rest) in enumerate(valloader):
175 | # print(batch_idx)
176 | if isinstance(rest[0][0], str):
177 | image_filename = rest[0][0]
178 | else:
179 | image_filename = '%s.png' % str(batch_idx + 1).zfill(3)
180 |
181 | X_batch = Variable(X_batch.to(device='cuda'))
182 | y_batch = Variable(y_batch.to(device='cuda'))
183 | # start = timeit.default_timer()
184 | y_out = model(X_batch)
185 | # stop = timeit.default_timer()
186 | # print('Time: ', stop - start)
187 | tmp2 = y_batch.detach().cpu().numpy()
188 | tmp = y_out.detach().cpu().numpy()
189 | tmp[tmp>=0.5] = 1
190 | tmp[tmp<0.5] = 0
191 | tmp2[tmp2>0] = 1
192 | tmp2[tmp2<=0] = 0
193 | tmp2 = tmp2.astype(int)
194 | tmp = tmp.astype(int)
195 |
196 | # print(np.unique(tmp2))
197 | yHaT = tmp
198 | yval = tmp2
199 |
200 | epsilon = 1e-20
201 |
202 | del X_batch, y_batch,tmp,tmp2, y_out
203 |
204 |
205 | yHaT[yHaT==1] =255
206 | yval[yval==1] =255
207 | fulldir = direc+"/{}/".format(epoch)
208 | # print(fulldir+image_filename)
209 | if not os.path.isdir(fulldir):
210 |
211 | os.makedirs(fulldir)
212 |
213 | cv2.imwrite(fulldir+image_filename, yHaT[0,1,:,:])
214 | # cv2.imwrite(fulldir+'/gt_{}.png'.format(count), yval[0,:,:])
215 | fulldir = direc+"/{}/".format(epoch)
216 | torch.save(model.state_dict(), fulldir+args.modelname+".pth")
217 | torch.save(model.state_dict(), direc+"/final_model.pth")
218 |
219 |
220 |
221 |
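222 | # Illustrative invocation (paths are placeholders, not from the original repo):
223 | #   python train.py --train_dataset "path/to/train" --val_dataset "path/to/val" --direc "./results" \
224 | #       --modelname "MedT" --imgsize 128 --gray "no" --batch_size 4 --epochs 400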
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 | from skimage import io,color
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms as T
9 | from torchvision.transforms import functional as F
10 |
11 | from typing import Callable
12 | import os
13 | import cv2
14 | import pandas as pd
15 |
16 | from numbers import Number
17 | from typing import Container
18 | from collections import defaultdict
19 |
20 |
21 | def to_long_tensor(pic):
22 | # handle numpy array
23 | img = torch.from_numpy(np.array(pic, np.uint8))
24 | # backward compatibility
25 | return img.long()
26 |
27 |
28 | def correct_dims(*images):
29 | corr_images = []
30 | # print(images)
31 | for img in images:
32 | if len(img.shape) == 2:
33 | corr_images.append(np.expand_dims(img, axis=2))
34 | else:
35 | corr_images.append(img)
36 |
37 | if len(corr_images) == 1:
38 | return corr_images[0]
39 | else:
40 | return corr_images
41 |
42 |
43 | class JointTransform2D:
44 | """
45 | Performs augmentation on image and mask when called. Due to the randomness of augmentation transforms,
46 | it is not enough to simply apply the same torchvision Transform to the image and mask separately,
47 | since each call would draw different random parameters and the mask would no longer match the image.
48 | This class applies identical random parameters to both, which avoids that problem.
49 |
50 | Args:
51 | crop: tuple describing the size of the random crop. If bool(crop) evaluates to False, no crop will
52 | be taken.
53 | p_flip: float, the probability of performing a random horizontal flip.
54 | color_jitter_params: tuple describing the parameters of torchvision.transforms.ColorJitter.
55 | If bool(color_jitter_params) evaluates to false, no color jitter transformation will be used.
56 | p_random_affine: float, the probability of performing a random affine transform using
57 | torchvision.transforms.RandomAffine.
58 | long_mask: bool, if True, returns the mask as LongTensor in label-encoded format.
59 | """
60 | def __init__(self, crop=(32, 32), p_flip=0.5, color_jitter_params=(0.1, 0.1, 0.1, 0.1),
61 | p_random_affine=0, long_mask=False):
62 | self.crop = crop
63 | self.p_flip = p_flip
64 | self.color_jitter_params = color_jitter_params
65 | if color_jitter_params:
66 | self.color_tf = T.ColorJitter(*color_jitter_params)
67 | self.p_random_affine = p_random_affine
68 | self.long_mask = long_mask
69 |
70 | def __call__(self, image, mask):
71 | # transforming to PIL image
72 | image, mask = F.to_pil_image(image), F.to_pil_image(mask)
73 |
74 | # random crop
75 | if self.crop:
76 | i, j, h, w = T.RandomCrop.get_params(image, self.crop)
77 | image, mask = F.crop(image, i, j, h, w), F.crop(mask, i, j, h, w)
78 |
79 | if np.random.rand() < self.p_flip:
80 | image, mask = F.hflip(image), F.hflip(mask)
81 |
82 | # color transforms || ONLY ON IMAGE
83 | if self.color_jitter_params:
84 | image = self.color_tf(image)
85 |
86 | # random affine transform
87 | if np.random.rand() < self.p_random_affine:
88 | affine_params = T.RandomAffine(180).get_params((-90, 90), (1, 1), (2, 2), (-45, 45), self.crop)
89 | image, mask = F.affine(image, *affine_params), F.affine(mask, *affine_params)
90 |
91 | # transforming to tensor
92 | image = F.to_tensor(image)
93 | if not self.long_mask:
94 | mask = F.to_tensor(mask)
95 | else:
96 | mask = to_long_tensor(mask)
97 |
98 | return image, mask
99 |
100 |
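# --- Usage sketch (editorial addition, not part of the original file) ---
# JointTransform2D samples one set of random crop/flip parameters and applies it to
# both image and mask, so the annotation stays aligned; color jitter touches only the
# image. Inputs are expected as HxWxC uint8 numpy arrays (masks HxWx1, e.g. after
# correct_dims):
#
#   tf = JointTransform2D(crop=(128, 128), p_flip=0.5, long_mask=True)
#   image_t, mask_t = tf(image, mask)   # image_t: float CxHxW, mask_t: long HxW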
101 | class ImageToImage2D(Dataset):
102 | """
103 | Reads the images and applies the augmentation transform on them.
104 | Usage:
105 | 1. If used without the unet.model.Model wrapper, an instance of this object should be passed to
106 | torch.utils.data.DataLoader. Iterating through this returns the tuple of image, mask and image
107 | filename.
108 | 2. With unet.model.Model wrapper, an instance of this object should be passed as train or validation
109 | datasets.
110 |
111 | Args:
112 | dataset_path: path to the dataset. Structure of the dataset should be:
113 |             dataset_path
114 |               |-- img
115 |                   |-- img001.png
116 |                   |-- img002.png
117 |                   |-- ...
118 |               |-- labelcol
119 |                   |-- img001.png
120 |                   |-- img002.png
121 |                   |-- ...
122 |
123 | joint_transform: augmentation transform, an instance of JointTransform2D. If bool(joint_transform)
124 | evaluates to False, torchvision.transforms.ToTensor will be used on both image and mask.
125 |         one_hot_mask: int, if positive, the mask is returned one-hot encoded with this many channels (the number of classes).
126 | """
127 |
128 | def __init__(self, dataset_path: str, joint_transform: Callable = None, one_hot_mask: int = False) -> None:
129 | self.dataset_path = dataset_path
130 | self.input_path = os.path.join(dataset_path, 'img')
131 | self.output_path = os.path.join(dataset_path, 'labelcol')
132 | self.images_list = os.listdir(self.input_path)
133 | self.one_hot_mask = one_hot_mask
134 |
135 | if joint_transform:
136 | self.joint_transform = joint_transform
137 | else:
138 | to_tensor = T.ToTensor()
139 | self.joint_transform = lambda x, y: (to_tensor(x), to_tensor(y))
140 |
141 | def __len__(self):
142 | return len(os.listdir(self.input_path))
143 |
144 | def __getitem__(self, idx):
145 | image_filename = self.images_list[idx]
146 | #print(image_filename[: -3])
147 | # read image
148 | # print(os.path.join(self.input_path, image_filename))
149 | # print(os.path.join(self.output_path, image_filename[: -3] + "png"))
150 | # print(os.path.join(self.input_path, image_filename))
151 | image = cv2.imread(os.path.join(self.input_path, image_filename))
152 | # print(image.shape)
153 | # read mask image
154 | mask = cv2.imread(os.path.join(self.output_path, image_filename[: -3] + "png"),0)
155 |
156 |         mask[mask<=127] = 0   # binarize the mask: background -> 0
157 |         mask[mask>127] = 1    # foreground -> 1
158 | # correct dimensions if needed
159 | image, mask = correct_dims(image, mask)
160 | # print(image.shape)
161 |
162 | if self.joint_transform:
163 | image, mask = self.joint_transform(image, mask)
164 |
165 | if self.one_hot_mask:
166 |             assert self.one_hot_mask > 0, 'one_hot_mask must be positive'
167 | mask = torch.zeros((self.one_hot_mask, mask.shape[1], mask.shape[2])).scatter_(0, mask.long(), 1)
168 | # mask = np.swapaxes(mask,2,0)
169 | # print(image.shape)
170 | # print(mask.shape)
171 | # mask = np.transpose(mask,(2,0,1))
172 | # image = np.transpose(image,(2,0,1))
173 | # print(image.shape)
174 | # print(mask.shape)
175 |
176 | return image, mask, image_filename
177 |
178 |
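# --- Usage sketch (editorial addition, not part of the original file) ---
# ImageToImage2D is a plain torch Dataset, so it is typically wrapped in a DataLoader;
# the dataset path below is hypothetical and must contain 'img' and 'labelcol' folders
# as documented above:
#
#   train_tf = JointTransform2D(crop=(128, 128), p_flip=0.5, long_mask=True)
#   train_ds = ImageToImage2D('./data/train', joint_transform=train_tf)
#   loader = torch.utils.data.DataLoader(train_ds, batch_size=4, shuffle=True)
#   for image, mask, filename in loader:
#       pass  # image: BxCxHxW float, mask: BxHxW long, filename: tuple of str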
179 | class Image2D(Dataset):
180 | """
181 | Reads the images and applies the augmentation transform on them. As opposed to ImageToImage2D, this
182 | reads a single image and requires a simple augmentation transform.
183 | Usage:
184 | 1. If used without the unet.model.Model wrapper, an instance of this object should be passed to
185 | torch.utils.data.DataLoader. Iterating through this returns the tuple of image and image
186 | filename.
187 | 2. With unet.model.Model wrapper, an instance of this object should be passed as a prediction
188 | dataset.
189 |
190 | Args:
191 |
192 | dataset_path: path to the dataset. Structure of the dataset should be:
193 |             dataset_path
194 |               |-- img
195 |                   |-- img001.png
196 |                   |-- img002.png
197 |                   |-- ...
198 |
199 |         transform: augmentation transform. If bool(transform) evaluates to False,
200 | torchvision.transforms.ToTensor will be used.
201 | """
202 |
203 | def __init__(self, dataset_path: str, transform: Callable = None):
204 |
205 | self.dataset_path = dataset_path
206 | self.input_path = os.path.join(dataset_path, 'img')
207 | self.images_list = os.listdir(self.input_path)
208 |
209 | if transform:
210 | self.transform = transform
211 | else:
212 | self.transform = T.ToTensor()
213 |
214 | def __len__(self):
215 | return len(os.listdir(self.input_path))
216 |
217 | def __getitem__(self, idx):
218 |
219 | image_filename = self.images_list[idx]
220 |
221 | image = cv2.imread(os.path.join(self.input_path, image_filename))
222 |
223 | # image = np.transpose(image,(2,0,1))
224 |
225 | image = correct_dims(image)
226 |
227 | image = self.transform(image)
228 |
229 | # image = np.swapaxes(image,2,0)
230 |
231 | return image, image_filename
232 |
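# --- Usage sketch (editorial addition, not part of the original file) ---
# Image2D is the label-free counterpart used at prediction time; it only needs an
# 'img' folder inside the dataset path (the path below is hypothetical):
#
#   test_ds = Image2D('./data/test')
#   image, filename = test_ds[0]   # image: float CxHxW tensor, filename: str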
233 | def chk_mkdir(*paths: Container) -> None:
234 | """
235 | Creates folders if they do not exist.
236 |
237 | Args:
238 | paths: Container of paths to be created.
239 | """
240 | for path in paths:
241 | if not os.path.exists(path):
242 | os.makedirs(path)
243 |
244 |
245 | class Logger:
246 | def __init__(self, verbose=False):
247 | self.logs = defaultdict(list)
248 | self.verbose = verbose
249 |
250 | def log(self, logs):
251 | for key, value in logs.items():
252 | self.logs[key].append(value)
253 |
254 | if self.verbose:
255 | print(logs)
256 |
257 | def get_logs(self):
258 | return self.logs
259 |
260 | def to_csv(self, path):
261 |         pd.DataFrame(self.logs).to_csv(path, index=False)
262 |
263 |
264 | class MetricList:
265 | def __init__(self, metrics):
266 | assert isinstance(metrics, dict), '\'metrics\' must be a dictionary of callables'
267 | self.metrics = metrics
268 | self.results = {key: 0.0 for key in self.metrics.keys()}
269 |
270 | def __call__(self, y_out, y_batch):
271 | for key, value in self.metrics.items():
272 | self.results[key] += value(y_out, y_batch)
273 |
274 | def reset(self):
275 | self.results = {key: 0.0 for key in self.metrics.keys()}
276 |
277 | def get_results(self, normalize=False):
278 | assert isinstance(normalize, bool) or isinstance(normalize, Number), '\'normalize\' must be boolean or a number'
279 | if not normalize:
280 | return self.results
281 | else:
282 | return {key: value/normalize for key, value in self.results.items()}
283 |
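
if __name__ == "__main__":
    # Editorial usage sketch (not part of the original repo): exercises Logger and
    # MetricList on made-up values so the helpers can be sanity-checked in isolation.
    logger = Logger(verbose=True)
    logger.log({'epoch': 1, 'train_loss': 0.42})
    logger.log({'epoch': 2, 'train_loss': 0.31})
    print(logger.get_logs())                   # per-key lists of logged values
    # logger.to_csv('logs.csv')                # uncomment to dump the logs to disk

    # MetricList accumulates each metric across calls; get_results(normalize=n)
    # divides by n (e.g. the number of batches) to report an average.
    metrics = MetricList({'dummy_sum': lambda y_out, y_batch: float((y_out + y_batch).sum())})
    metrics(torch.ones(2, 2), torch.zeros(2, 2))
    metrics(torch.ones(2, 2), torch.ones(2, 2))
    print(metrics.get_results(normalize=2))    # {'dummy_sum': 6.0}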
--------------------------------------------------------------------------------
/utils_gray.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 | from skimage import io,color
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms as T
9 | from torchvision.transforms import functional as F
10 |
11 | from typing import Callable
12 | import os
13 | import cv2
14 | import pandas as pd
15 |
16 | from numbers import Number
17 | from typing import Container
18 | from collections import defaultdict
19 |
20 |
21 | def to_long_tensor(pic):
22 | # handle numpy array
23 | img = torch.from_numpy(np.array(pic, np.uint8))
24 | # backward compatibility
25 | return img.long()
26 |
27 |
28 | def correct_dims(*images):
29 | corr_images = []
30 | # print(images)
31 | for img in images:
32 | if len(img.shape) == 2:
33 | corr_images.append(np.expand_dims(img, axis=2))
34 | else:
35 | corr_images.append(img)
36 |
37 | if len(corr_images) == 1:
38 | return corr_images[0]
39 | else:
40 | return corr_images
41 |
42 |
43 | class JointTransform2D:
44 | """
45 |     Performs augmentation on the image and mask when called. Because torchvision transforms draw their
46 |     random parameters independently, applying the same transform to the image and the mask separately
47 |     would misalign the ground-truth mask. This class samples the random parameters once and applies
48 |     them to both inputs, keeping image and mask aligned.
49 |
50 | Args:
51 | crop: tuple describing the size of the random crop. If bool(crop) evaluates to False, no crop will
52 | be taken.
53 | p_flip: float, the probability of performing a random horizontal flip.
54 | color_jitter_params: tuple describing the parameters of torchvision.transforms.ColorJitter.
55 | If bool(color_jitter_params) evaluates to false, no color jitter transformation will be used.
56 | p_random_affine: float, the probability of performing a random affine transform using
57 | torchvision.transforms.RandomAffine.
58 | long_mask: bool, if True, returns the mask as LongTensor in label-encoded format.
59 | """
60 | def __init__(self, crop=(32, 32), p_flip=0.5, color_jitter_params=(0.1, 0.1, 0.1, 0.1),
61 | p_random_affine=0, long_mask=False):
62 | self.crop = crop
63 | self.p_flip = p_flip
64 | self.color_jitter_params = color_jitter_params
65 | if color_jitter_params:
66 | self.color_tf = T.ColorJitter(*color_jitter_params)
67 | self.p_random_affine = p_random_affine
68 | self.long_mask = long_mask
69 |
70 | def __call__(self, image, mask):
71 | # transforming to PIL image
72 | image, mask = F.to_pil_image(image), F.to_pil_image(mask)
73 |
74 | # random crop
75 | if self.crop:
76 | i, j, h, w = T.RandomCrop.get_params(image, self.crop)
77 | image, mask = F.crop(image, i, j, h, w), F.crop(mask, i, j, h, w)
78 |
79 | if np.random.rand() < self.p_flip:
80 | image, mask = F.hflip(image), F.hflip(mask)
81 |
82 | # color transforms || ONLY ON IMAGE
83 | if self.color_jitter_params:
84 | image = self.color_tf(image)
85 |
86 | # random affine transform
87 | if np.random.rand() < self.p_random_affine:
88 | affine_params = T.RandomAffine(180).get_params((-90, 90), (1, 1), (2, 2), (-45, 45), self.crop)
89 | image, mask = F.affine(image, *affine_params), F.affine(mask, *affine_params)
90 |
91 | # transforming to tensor
92 | image = F.to_tensor(image)
93 | if not self.long_mask:
94 | mask = F.to_tensor(mask)
95 | else:
96 | mask = to_long_tensor(mask)
97 |
98 | return image, mask
99 |
100 |
101 | class ImageToImage2D(Dataset):
102 | """
103 | Reads the images and applies the augmentation transform on them.
104 | Usage:
105 | 1. If used without the unet.model.Model wrapper, an instance of this object should be passed to
106 | torch.utils.data.DataLoader. Iterating through this returns the tuple of image, mask and image
107 | filename.
108 | 2. With unet.model.Model wrapper, an instance of this object should be passed as train or validation
109 | datasets.
110 |
111 | Args:
112 | dataset_path: path to the dataset. Structure of the dataset should be:
113 |             dataset_path
114 |               |-- img
115 |                   |-- img001.png
116 |                   |-- img002.png
117 |                   |-- ...
118 |               |-- labelcol
119 |                   |-- img001.png
120 |                   |-- img002.png
121 |                   |-- ...
122 |
123 | joint_transform: augmentation transform, an instance of JointTransform2D. If bool(joint_transform)
124 | evaluates to False, torchvision.transforms.ToTensor will be used on both image and mask.
125 |         one_hot_mask: int, if positive, the mask is returned one-hot encoded with this many channels (the number of classes).
126 | """
127 |
128 | def __init__(self, dataset_path: str, joint_transform: Callable = None, one_hot_mask: int = False) -> None:
129 | self.dataset_path = dataset_path
130 | self.input_path = os.path.join(dataset_path, 'img')
131 | self.output_path = os.path.join(dataset_path, 'labelcol')
132 | self.images_list = os.listdir(self.input_path)
133 | self.one_hot_mask = one_hot_mask
134 |
135 | if joint_transform:
136 | self.joint_transform = joint_transform
137 | else:
138 | to_tensor = T.ToTensor()
139 | self.joint_transform = lambda x, y: (to_tensor(x), to_tensor(y))
140 |
141 | def __len__(self):
142 | return len(os.listdir(self.input_path))
143 |
144 | def __getitem__(self, idx):
145 | image_filename = self.images_list[idx]
146 | #print(image_filename[: -3])
147 | # read image
148 | # print(os.path.join(self.input_path, image_filename))
149 | # print(os.path.join(self.output_path, image_filename[: -3] + "png"))
150 | # print(os.path.join(self.input_path, image_filename))
151 |         image = cv2.imread(os.path.join(self.input_path, image_filename),0)  # flag 0: read as single-channel grayscale
152 | # print(image.shape)
153 | # read mask image
154 | mask = cv2.imread(os.path.join(self.output_path, image_filename[: -3] + "png"),0)
155 |
156 | # correct dimensions if needed
157 | image, mask = correct_dims(image, mask)
158 | # print(image.shape)
159 |         mask[mask<127] = 0    # binarize the mask: background -> 0
160 |         mask[mask>=127] = 1   # foreground -> 1
161 |
162 |
163 | if self.joint_transform:
164 | image, mask = self.joint_transform(image, mask)
165 |
166 | if self.one_hot_mask:
167 |             assert self.one_hot_mask > 0, 'one_hot_mask must be positive'
168 | mask = torch.zeros((self.one_hot_mask, mask.shape[1], mask.shape[2])).scatter_(0, mask.long(), 1)
169 | # mask = np.swapaxes(mask,2,0)
170 | # print(image.shape)
171 | # print(mask.shape)
172 | # mask = np.transpose(mask,(2,0,1))
173 | # image = np.transpose(image,(2,0,1))
174 | # print(image.shape)
175 | # print(mask.shape)
176 |
177 | return image, mask, image_filename
178 |
179 |
180 | class Image2D(Dataset):
181 | """
182 | Reads the images and applies the augmentation transform on them. As opposed to ImageToImage2D, this
183 | reads a single image and requires a simple augmentation transform.
184 | Usage:
185 | 1. If used without the unet.model.Model wrapper, an instance of this object should be passed to
186 | torch.utils.data.DataLoader. Iterating through this returns the tuple of image and image
187 | filename.
188 | 2. With unet.model.Model wrapper, an instance of this object should be passed as a prediction
189 | dataset.
190 |
191 | Args:
192 |
193 | dataset_path: path to the dataset. Structure of the dataset should be:
194 |             dataset_path
195 |               |-- img
196 |                   |-- img001.png
197 |                   |-- img002.png
198 |                   |-- ...
199 |
200 |         transform: augmentation transform. If bool(transform) evaluates to False,
201 | torchvision.transforms.ToTensor will be used.
202 | """
203 |
204 | def __init__(self, dataset_path: str, transform: Callable = None):
205 |
206 | self.dataset_path = dataset_path
207 | self.input_path = os.path.join(dataset_path, 'img')
208 | self.images_list = os.listdir(self.input_path)
209 |
210 | if transform:
211 | self.transform = transform
212 | else:
213 | self.transform = T.ToTensor()
214 |
215 | def __len__(self):
216 | return len(os.listdir(self.input_path))
217 |
218 | def __getitem__(self, idx):
219 |
220 | image_filename = self.images_list[idx]
221 |
222 | image = cv2.imread(os.path.join(self.input_path, image_filename),0)
223 |
224 | # image = np.transpose(image,(2,0,1))
225 |
226 | image = correct_dims(image)
227 |
228 | image = self.transform(image)
229 |
230 | # image = np.swapaxes(image,2,0)
231 |
232 | return image, image_filename
233 |
234 | def chk_mkdir(*paths: Container) -> None:
235 | """
236 | Creates folders if they do not exist.
237 |
238 | Args:
239 | paths: Container of paths to be created.
240 | """
241 | for path in paths:
242 | if not os.path.exists(path):
243 | os.makedirs(path)
244 |
245 |
246 | class Logger:
247 | def __init__(self, verbose=False):
248 | self.logs = defaultdict(list)
249 | self.verbose = verbose
250 |
251 | def log(self, logs):
252 | for key, value in logs.items():
253 | self.logs[key].append(value)
254 |
255 | if self.verbose:
256 | print(logs)
257 |
258 | def get_logs(self):
259 | return self.logs
260 |
261 | def to_csv(self, path):
262 |         pd.DataFrame(self.logs).to_csv(path, index=False)
263 |
264 |
265 | class MetricList:
266 | def __init__(self, metrics):
267 | assert isinstance(metrics, dict), '\'metrics\' must be a dictionary of callables'
268 | self.metrics = metrics
269 | self.results = {key: 0.0 for key in self.metrics.keys()}
270 |
271 | def __call__(self, y_out, y_batch):
272 | for key, value in self.metrics.items():
273 | self.results[key] += value(y_out, y_batch)
274 |
275 | def reset(self):
276 | self.results = {key: 0.0 for key in self.metrics.keys()}
277 |
278 | def get_results(self, normalize=False):
279 | assert isinstance(normalize, bool) or isinstance(normalize, Number), '\'normalize\' must be boolean or a number'
280 | if not normalize:
281 | return self.results
282 | else:
283 | return {key: value/normalize for key, value in self.results.items()}
284 |
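# Editorial note (not part of the original file): this module mirrors utils.py; the
# functional difference is that images are read with cv2.imread(..., 0), i.e. as
# single-channel grayscale, so the datasets here yield 1xHxW image tensors instead of
# 3xHxW. A hypothetical example with the default 32x32 crop:
#
#   gray_ds = ImageToImage2D('./data/train', joint_transform=JointTransform2D(long_mask=True))
#   image, mask, name = gray_ds[0]   # image: torch.Size([1, 32, 32]), mask: torch.Size([32, 32])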
--------------------------------------------------------------------------------