├── .gitignore
├── LICENSE
├── README.md
├── core
│   ├── datasets.py
│   ├── framework.py
│   ├── losses.py
│   ├── networks
│   │   ├── affnet.py
│   │   └── sdhnet.py
│   └── utils
│       ├── aug_transform.py
│       └── warp.py
├── eval.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | #  Usually these files are written by a python script from a template
32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | #   install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shenglong Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Distilled Hierarchical Network 2 | 3 | **[Self-Distilled Hierarchical Network for Unsupervised Deformable Image Registration](https://ieeexplore.ieee.org/abstract/document/10042453)** 4 | 5 | IEEE Transactions on Medical Imaging (TMI) 2023 6 | 7 | Shenglong Zhou, Bo Hu, Zhiwei Xiong and Feng Wu 8 | 9 | University of Science and Technology of China (USTC) 10 | 11 | ## Introduction 12 | 13 | ![Framework](https://user-images.githubusercontent.com/26156941/201927630-23340d83-52a0-45b6-a007-19c7fb603ea9.png) 14 | 15 | We present a novel unsupervised learning approach named Self-Distilled Hierarchical Network (SDHNet). 16 | By decomposing the registration procedure into several iterations, SDHNet generates hierarchical deformation fields (HDFs) simultaneously in each iteration and connects different iterations utilizing the learned hidden state. 17 | Hierarchical features are extracted to generate HDFs through several parallel GRUs, and HDFs are then fused adaptively conditioned on themselves as well as contextual features from the input image. 18 | Furthermore, different from common unsupervised methods that only apply similarity loss and regularization loss, SDHNet introduces a novel self-deformation distillation scheme. 19 | This scheme distills the final deformation field as the teacher guidance, which adds constraints for intermediate deformation fields. 
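
In code terms, one forward pass is a short recurrence over the iterations. The sketch below is a simplified, schematic view of `Framework.forward` in `core/framework.py`; helper names such as `warp` and `compose` are placeholders for this illustration, not functions of the repo:

```python
# Schematic SDHNet forward pass (simplified from core/framework.py).
affine = affnet(fixed, moving)                    # global affine pre-alignment
context, hidden = context_net(fixed)              # contextual features + initial hidden state
deform, hidden = defnet[0](fixed, warp(moving, affine.flow), context, hidden)
agg_flow = compose(affine, deform)                # affine and first deformable flow
for i in range(1, iters):
    warped = warp(moving, agg_flow)               # current estimate of the warped moving image
    deform, hidden = defnet[i](fixed, warped, context, hidden)   # parallel GRUs reuse the hidden state
    agg_flow = warp(agg_flow, deform.flow) + deform.flow         # compose displacement fields
```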
20 |
21 | ## Requirements
22 | The packages and the corresponding versions used in this repository are listed below.
23 | - Python 3
24 | - PyTorch 1.1
25 | - NumPy
26 | - SimpleITK
27 |
28 | ## Training
29 | After configuring the environment, use the following command to train the model.
30 | ```bash
31 | python -m torch.distributed.launch --nproc_per_node=4 train.py --name=SDHNet --iters=6 --dataset=brain --data_path=/xx/xx/ --base_path=/xx/xx/
32 |
33 | ```
34 |
35 | ## Testing
36 | Use the following command to obtain the testing results.
37 | ```bash
38 | python eval.py --name=SDHNet --model=SDHNet_lpba --dataset=brain --dataset_test=lpba --iters=6 --local_rank=0 --data_path=/xx/xx/ --base_path=/xx/xx/
39 | ```
40 |
41 | ## Datasets and Pre-trained Models (Based on Cascade VTN)
42 | We follow Cascade VTN to prepare the training and testing datasets; please refer to [Cascade VTN](https://github.com/microsoft/Recursive-Cascaded-Networks) for details.
43 |
44 | The related [pretrained models](https://drive.google.com/drive/folders/1BpxkIzL_SrPuKdqC_buiINawNZVMqoWc?usp=share_link) are available; please refer to the testing command above for evaluation.
45 |
46 | ## Citation
47 | If you find this work or code helpful in your research, please cite:
48 | ```
49 | @article{zhou2023self,
50 |   title={Self-Distilled Hierarchical Network for Unsupervised Deformable Image Registration},
51 |   author={Zhou, Shenglong and Hu, Bo and Xiong, Zhiwei and Wu, Feng},
52 |   journal={IEEE Transactions on Medical Imaging},
53 |   year={2023},
54 |   publisher={IEEE}
55 | }
56 | ```
57 |
58 | ## Contact
59 | As we are still exploring the self-distillation scheme, the corresponding part is temporarily not included in this repository.
60 |
61 | Please feel free to contact us by e-mail (slzhou96@mail.ustc.edu.cn) or WeChat (ZslBlcony) if you have any questions.
62 |
63 | ## Acknowledgements
64 | We follow the functional implementation in [Cascade VTN](https://github.com/microsoft/Recursive-Cascaded-Networks), and the overall code framework is adapted from [RAFT](https://github.com/princeton-vl/RAFT).
65 |
66 | Thanks a lot for their great contributions!
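
## Data Layout

This layout is inferred from `core/datasets.py` (it is not documented elsewhere in the repo): every case is a folder holding a `volume.tif`, test cases additionally hold a `segmentation.tif`, and the brain loaders register the LPBA atlas `Test/lpba/S01` against each case, e.g.

```
Data_MRIBrain/
├── Train/
│   └── {abide, abidef, adhd, adni}/<case>/volume.tif
└── Test/
    └── lpba/
        ├── S01/volume.tif + segmentation.tif   # atlas, used as the fixed image
        └── <case>/volume.tif + segmentation.tif
```

The liver setting follows the same pattern with `Train/{msd_hep, msd_pan, msd_liver, youyi_liver}` and the test sets of Cascade VTN (the `lspig` cases are evaluated as consecutive pairs).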
67 | 68 | 69 | -------------------------------------------------------------------------------- /core/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import random 4 | import numpy as np 5 | import tifffile as tif 6 | 7 | import torch 8 | import torch.utils.data as data 9 | 10 | 11 | def read_datasets(path, datasets): 12 | files = [] 13 | for d in datasets: 14 | files.extend([join(path, d, i) for i in os.listdir(join(path, d))]) 15 | return files 16 | 17 | 18 | def generate_pairs(files): 19 | pairs = [] 20 | for i, d1 in enumerate(files): 21 | for j, d2 in enumerate(files): 22 | if i != j: 23 | pairs.append([join(d1, 'volume.tif'), join(d2, 'volume.tif')]) 24 | return pairs 25 | 26 | 27 | def generate_pairs_val(files): 28 | pairs = [] 29 | labels = [] 30 | for i, d1 in enumerate(files): 31 | for j, d2 in enumerate(files): 32 | if i != j: 33 | pairs.append([join(d1, 'volume.tif'), join(d2, 'volume.tif')]) 34 | labels.append([join(d1, 'segmentation.tif'), join(d2, 'segmentation.tif')]) 35 | return pairs, labels 36 | 37 | 38 | def generate_lspig_val(files): 39 | pairs = [] 40 | labels = [] 41 | files.sort() 42 | for i in range(0, len(files), 2): 43 | d1 = files[i] 44 | d2 = files[i + 1] 45 | pairs.append([join(d1, 'volume.tif'), join(d2, 'volume.tif')]) 46 | labels.append([join(d1, 'segmentation.tif'), join(d2, 'segmentation.tif')]) 47 | pairs.append([join(d2, 'volume.tif'), join(d1, 'volume.tif')]) 48 | labels.append([join(d2, 'segmentation.tif'), join(d1, 'segmentation.tif')]) 49 | 50 | return pairs, labels 51 | 52 | 53 | def generate_atlas(atlas, files): 54 | pairs = [] 55 | for d in files: 56 | pairs.append([join(atlas, 'volume.tif'), join(d, 'volume.tif')]) 57 | 58 | return pairs 59 | 60 | 61 | def generate_atlas_val(atlas, files): 62 | pairs = [] 63 | labels = [] 64 | for d in files: 65 | if 'S01' in d: 66 | continue 67 | pairs.append([join(atlas, 'volume.tif'), join(d, 'volume.tif')]) 68 | labels.append([join(atlas, 'segmentation.tif'), join(d, 'segmentation.tif')]) 69 | return pairs, labels 70 | 71 | 72 | class LiverTrain(data.Dataset): 73 | def __init__(self, args): 74 | self.seed = False 75 | self.size = [128, 128, 128] 76 | self.datasets = ['msd_hep', 'msd_pan', 'msd_liver', 'youyi_liver'] 77 | self.train_path = join(args.data_path, 'Train') 78 | self.files = read_datasets(self.train_path, self.datasets) 79 | self.pairs = generate_pairs(self.files) 80 | 81 | def __getitem__(self, index): 82 | if not self.seed: 83 | random.seed(123) 84 | np.random.seed(123) 85 | torch.manual_seed(123) 86 | torch.cuda.manual_seed_all(123) 87 | self.seed = True 88 | 89 | index = index % len(self.pairs) 90 | data1, data2 = self.pairs[index] 91 | 92 | image1 = torch.from_numpy(tif.imread(data1)[np.newaxis]).float() / 255.0 93 | image2 = torch.from_numpy(tif.imread(data2)[np.newaxis]).float() / 255.0 94 | 95 | return image1, image2 96 | 97 | def __len__(self): 98 | return len(self.pairs) 99 | 100 | 101 | class LiverTest(data.Dataset): 102 | def __init__(self, args, datas): 103 | self.size = [128, 128, 128] 104 | self.datasets = [datas] 105 | self.test_path = join(args.data_path, 'Test') 106 | self.files = read_datasets(self.test_path, self.datasets) 107 | self.pairs, self.labels = generate_pairs_val(self.files) 108 | 109 | def __getitem__(self, index): 110 | data1, data2 = self.pairs[index] 111 | seg1, seg2 = self.labels[index] 112 | 113 | image1 = torch.from_numpy(tif.imread(data1)[np.newaxis]).float() / 255.0 
114 | image2 = torch.from_numpy(tif.imread(data2)[np.newaxis]).float() / 255.0 115 | 116 | label1 = torch.from_numpy(tif.imread(seg1)[np.newaxis]).float() 117 | label2 = torch.from_numpy(tif.imread(seg2)[np.newaxis]).float() 118 | 119 | return image1, image2, label1, label2 120 | 121 | def __len__(self): 122 | return len(self.pairs) 123 | 124 | 125 | class LspigTest(data.Dataset): 126 | def __init__(self, args, datas): 127 | self.size = [128, 128, 128] 128 | self.datasets = [datas] 129 | self.test_path = join(args.data_path, 'Test') 130 | self.files = read_datasets(self.test_path, self.datasets) 131 | self.pairs, self.labels = generate_lspig_val(self.files) 132 | 133 | def __getitem__(self, index): 134 | data1, data2 = self.pairs[index] 135 | seg1, seg2 = self.labels[index] 136 | 137 | image1 = torch.from_numpy(tif.imread(data1)[np.newaxis]).float() / 255.0 138 | image2 = torch.from_numpy(tif.imread(data2)[np.newaxis]).float() / 255.0 139 | 140 | label1 = torch.from_numpy(tif.imread(seg1)[np.newaxis]).float() 141 | label2 = torch.from_numpy(tif.imread(seg2)[np.newaxis]).float() 142 | 143 | return image1, image2, label1, label2 144 | 145 | def __len__(self): 146 | return len(self.pairs) 147 | 148 | 149 | class BrainTrain(data.Dataset): 150 | def __init__(self, args): 151 | self.seed = False 152 | self.size = [128, 128, 128] 153 | self.datasets = ['abide', 'abidef', 'adhd', 'adni'] 154 | self.train_path = join(args.data_path, 'Train') 155 | self.atlas = join(args.data_path, 'Test/lpba/S01') 156 | self.files = read_datasets(self.train_path, self.datasets) 157 | self.pairs = generate_atlas(self.atlas, self.files) 158 | 159 | def __getitem__(self, index): 160 | if not self.seed: 161 | random.seed(123) 162 | np.random.seed(123) 163 | torch.manual_seed(123) 164 | torch.cuda.manual_seed_all(123) 165 | self.seed = True 166 | 167 | index = index % len(self.pairs) 168 | data1, data2 = self.pairs[index] 169 | 170 | image1 = torch.from_numpy(tif.imread(data1)[np.newaxis]).float() / 255.0 171 | image2 = torch.from_numpy(tif.imread(data2)[np.newaxis]).float() / 255.0 172 | 173 | return image1, image2 174 | 175 | def __len__(self): 176 | return len(self.pairs) 177 | 178 | 179 | class BrainTest(data.Dataset): 180 | def __init__(self, args, datas): 181 | self.size = [128, 128, 128] 182 | self.datasets = [datas] 183 | self.test_path = join(args.data_path, 'Test') 184 | self.atlas = join(args.data_path, 'Test/lpba/S01') 185 | self.files = read_datasets(self.test_path, self.datasets) 186 | self.pairs, self.labels = generate_atlas_val(self.atlas, self.files) 187 | self.seg_values = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 41, 42, 43, 44, 45, 46, 47, 48, 49, 188 | 50, 61, 62, 63, 64, 65, 66, 67, 68, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 101, 102, 189 | 121, 122, 161, 162, 163, 164, 165, 166, 181, 182] 190 | 191 | def __getitem__(self, index): 192 | data1, data2 = self.pairs[index] 193 | seg1, seg2 = self.labels[index] 194 | 195 | image1 = torch.from_numpy(tif.imread(data1)[np.newaxis]).float() / 255.0 196 | image2 = torch.from_numpy(tif.imread(data2)[np.newaxis]).float() / 255.0 197 | 198 | label1 = torch.from_numpy(tif.imread(seg1)[np.newaxis]).float() 199 | label2 = torch.from_numpy(tif.imread(seg2)[np.newaxis]).float() 200 | 201 | return image1, image2, label1, label2 202 | 203 | def __len__(self): 204 | return len(self.pairs) 205 | -------------------------------------------------------------------------------- /core/framework.py: 
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from .utils import aug_transform, warp
5 | from .networks import affnet, sdhnet
6 |
7 |
8 | class Framework(nn.Module):
9 |     def __init__(self, args):
10 |         super(Framework, self).__init__()
11 |         self.args = args
12 |         self.cdim = 32
13 |         self.hdim = 16
14 |         self.flow_multiplier = 1.0 / args.iters
15 |         self.sample_power = aug_transform.sample_power
16 |         self.free_form_fields = aug_transform.free_form_fields
17 |         self.reconstruction = warp.warp3D()
18 |
19 |         self.affnet = affnet.AffineNet()
20 |         self.context = sdhnet.ContextNet(outputc=self.cdim + self.hdim)
21 |         self.defnet = nn.ModuleList([sdhnet.SDHNet(hdim=self.hdim, flow_multiplier=self.flow_multiplier) for _ in range(args.iters)])
22 |
23 |     def forward(self, Img1, Img2, augment=True):
24 |         if augment:
25 |             bs = Img1.shape[0]
26 |             imgs = Img1.shape[2:5]  # D, H, W
27 |
28 |             control_fields = (self.sample_power(-0.4, 0.4, 3, [bs, 5, 5, 5, 3]) *
29 |                               torch.Tensor(np.array(imgs).astype(float) // 4)).permute(0, 4, 1, 2, 3)
30 |             augFlow = (self.free_form_fields(imgs, control_fields)).cuda()  # B, C, D, H, W
31 |
32 |             augImg2 = self.reconstruction(Img2, augFlow)  # B, C, D, H, W
33 |         else:
34 |             augImg2 = Img2
35 |
36 |         affines = self.affnet(Img1, augImg2)
37 |
38 |         contexts = self.context(Img1)  # extract the features of the fixed image
39 |         cont, hid = torch.split(contexts, [self.cdim, self.hdim], dim=1)
40 |         cont = torch.relu(cont)
41 |         hid = [torch.tanh(hid),
42 |                torch.tanh(torch.max_pool3d(hid, kernel_size=2, stride=2)),
43 |                torch.tanh(torch.max_pool3d(hid, kernel_size=4, stride=4))]
44 |
45 |         augImg2_affine = self.reconstruction(augImg2, affines['flow'])
46 |         deforms_0, hid = self.defnet[0](Img1, augImg2_affine, cont, hid)
47 |
48 |         I = torch.cuda.FloatTensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
49 |         agg_flow_0 = torch.einsum('bij,bjxyz->bixyz', affines['W'] + I, deforms_0['flow']) + affines['flow']
50 |
51 |         warpImg = self.reconstruction(augImg2, agg_flow_0)
52 |         agg_flow = agg_flow_0
53 |
54 |         Deforms = [deforms_0]
55 |         agg_flows = []
56 |         for i in range(self.args.iters - 1):
57 |             deforms, hid = self.defnet[i + 1](Img1, warpImg, cont, hid)
58 |             agg_flow = self.reconstruction(agg_flow, deforms['flow']) + deforms['flow']
59 |             warpImg = self.reconstruction(augImg2, agg_flow)
60 |
61 |             Deforms.append(deforms)
62 |             agg_flows.append(agg_flow)
63 |
64 |         return augImg2, affines, Deforms, agg_flow, agg_flows
65 |
--------------------------------------------------------------------------------
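
The first aggregated flow above maps the deformable flow through the full affine matrix `A = W + I` before adding the affine flow, i.e. per voxel `agg(x) = A @ d(x) + aff(x)`. A small, self-contained check of that `einsum` (editorial example with hypothetical shapes, CPU-only, not part of the repo):

```python
import torch

# 'bij,bjxyz->bixyz' applies a batch of 3x3 matrices to every voxel of a flow field.
B, D, H, W = 1, 4, 4, 4
A = torch.eye(3).unsqueeze(0) + 0.01 * torch.randn(1, 3, 3)   # batch of affine matrices
d = torch.randn(B, 3, D, H, W)                                # deformable displacement field
out = torch.einsum('bij,bjxyz->bixyz', A, d)

# one voxel agrees with a plain matrix-vector product
assert torch.allclose(out[0, :, 0, 0, 0], A[0] @ d[0, :, 0, 0, 0], atol=1e-6)
```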
/core/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def det3x3(M):
6 |     M = [[M[:, i, j] for j in range(3)] for i in range(3)]
7 |
8 |     det = (M[0][0] * M[1][1] * M[2][2] + M[0][1] * M[1][2] * M[2][0] + M[0][2] * M[1][0] * M[2][1]) - \
9 |           (M[0][0] * M[1][2] * M[2][1] + M[0][1] * M[1][0] * M[2][2] + M[0][2] * M[1][1] * M[2][0])
10 |
11 |     return det
12 |
13 |
14 | def elem_sym_polys_of_eigen_values(M):
15 |     M = [[M[:, i, j] for j in range(3)] for i in range(3)]
16 |
17 |     sigma1 = (M[0][0] + M[1][1] + M[2][2])
18 |
19 |     sigma2 = (M[0][0] * M[1][1] + M[1][1] * M[2][2] + M[2][2] * M[0][0]) - \
20 |              (M[0][1] * M[1][0] + M[1][2] * M[2][1] + M[2][0] * M[0][2])
21 |
22 |     sigma3 = (M[0][0] * M[1][1] * M[2][2] + M[0][1] * M[1][2] * M[2][0] + M[0][2] * M[1][0] * M[2][1]) - \
23 |              (M[0][0] * M[1][2] * M[2][1] + M[0][1] * M[1][0] * M[2][2] + M[0][2] * M[1][1] * M[2][0])
24 |
25 |     return sigma1, sigma2, sigma3
26 |
27 |
28 | def similarity_loss(img1, img2_warped):
29 |     sizes = np.prod(img1.shape[1:])
30 |     flatten1 = img1.view(-1, sizes)
31 |     flatten2 = img2_warped.view(-1, sizes)
32 |
33 |     mean1 = torch.mean(flatten1, -1).view(-1, 1)
34 |     mean2 = torch.mean(flatten2, -1).view(-1, 1)
35 |     var1 = torch.mean((flatten1 - mean1) ** 2, -1)
36 |     var2 = torch.mean((flatten2 - mean2) ** 2, -1)
37 |
38 |     cov12 = torch.mean((flatten1 - mean1) * (flatten2 - mean2), -1)  # covariance between the two volumes
39 |     pearson_r = cov12 / torch.sqrt((var1 + 1e-6) * (var2 + 1e-6))
40 |
41 |     raw_loss = 1 - pearson_r
42 |     raw_loss = torch.sum(raw_loss)
43 |
44 |     return raw_loss
45 |
46 |
47 | def regularize_loss(flow):
48 |     ret = torch.sum((flow[:, :, 1:, :, :] - flow[:, :, :-1, :, :]) ** 2) / 2 + \
49 |           torch.sum((flow[:, :, :, 1:, :] - flow[:, :, :, :-1, :]) ** 2) / 2 + \
50 |           torch.sum((flow[:, :, :, :, 1:] - flow[:, :, :, :, :-1]) ** 2) / 2
51 |     ret = ret / np.prod(flow.shape[1:5])
52 |
53 |     return ret
54 |
--------------------------------------------------------------------------------
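
`similarity_loss` is one minus the Pearson correlation of the flattened volumes, summed over the batch, so identical inputs give roughly 0 and anti-correlated inputs roughly 2 per pair. A quick CPU check (editorial example, assuming the repository root is on `PYTHONPATH`):

```python
import torch
from core.losses import similarity_loss

img = torch.rand(2, 1, 8, 8, 8)          # B, C, D, H, W
print(similarity_loss(img, img))          # ~0: perfectly correlated pair, twice
print(similarity_loss(img, 1.0 - img))    # ~4: two pairs, each contributing 1 - (-1) = 2
```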
/core/networks/affnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AffineNet(nn.Module):
6 |     def __init__(self, multiplier=1):
7 |         super(AffineNet, self).__init__()
8 |         self.multiplier = multiplier
9 |
10 |         self.conv1 = nn.Conv3d(2, 16, kernel_size=3, stride=2, padding=1)  # 64 * 64 * 64
11 |
12 |         self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)  # 32 * 32 * 32
13 |
14 |         self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1)  # 16 * 16 * 16
15 |         self.conv3_1 = nn.Conv3d(64, 64, kernel_size=3, stride=1, padding=1)
16 |
17 |         self.conv4 = nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1)  # 8 * 8 * 8
18 |         self.conv4_1 = nn.Conv3d(128, 128, kernel_size=3, stride=1, padding=1)
19 |
20 |         self.conv5 = nn.Conv3d(128, 256, kernel_size=3, stride=2, padding=1)  # 4 * 4 * 4
21 |         self.conv5_1 = nn.Conv3d(256, 256, kernel_size=3, stride=1, padding=1)
22 |
23 |         self.conv6 = nn.Conv3d(256, 512, kernel_size=3, stride=2, padding=1)  # 2 * 2 * 2
24 |         self.conv6_1 = nn.Conv3d(512, 512, kernel_size=3, stride=1, padding=1)
25 |
26 |         self.conv7_W = nn.Conv3d(512, 9, kernel_size=2, stride=1, padding=0, bias=False)
27 |         self.conv7_b = nn.Conv3d(512, 3, kernel_size=2, stride=1, padding=0, bias=False)
28 |
29 |         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
30 |
31 |         # weight initialization: left to PyTorch defaults
32 |
33 |     def affine_flow(self, W, b, sd, sh, sw):
34 |         b = b.view([-1, 3, 1, 1, 1])
35 |
36 |         xr = torch.arange(-(sw - 1) / 2.0, sw / 2.0, 1.0)
37 |         xr = xr.view([1, 1, 1, 1, -1]).cuda()
38 |         yr = torch.arange(-(sh - 1) / 2.0, sh / 2.0, 1.0)
39 |         yr = yr.view([1, 1, 1, -1, 1]).cuda()
40 |         zr = torch.arange(-(sd - 1) / 2.0, sd / 2.0, 1.0)
41 |         zr = zr.view([1, 1, -1, 1, 1]).cuda()
42 |
43 |         wx = W[:, :, 0]
44 |         wx = wx.view([-1, 3, 1, 1, 1])
45 |         wy = W[:, :, 1]
46 |         wy = wy.view([-1, 3, 1, 1, 1])
47 |         wz = W[:, :, 2]
48 |         wz = wz.view([-1, 3, 1, 1, 1])
49 |
50 |         return xr * wx + yr * wy + zr * wz + b
51 |
52 |     def forward(self, image1, image2):
53 |
54 |         concatImgs = torch.cat([image1, image2], 1)  # B, C, D, H, W
55 |
56 |         x = self.lrelu(self.conv1(concatImgs))
57 |         x = self.lrelu(self.conv2(x))
58 |         x = self.lrelu(self.conv3(x))
59 |         x = self.lrelu(self.conv3_1(x))
60 |         x = self.lrelu(self.conv4(x))
61 |         x = self.lrelu(self.conv4_1(x))
62 |         x = self.lrelu(self.conv5(x))
63 |         x = self.lrelu(self.conv5_1(x))
64 |         x = self.lrelu(self.conv6(x))
65 |         x = self.lrelu(self.conv6_1(x))
66 |
67 |         I = torch.cuda.FloatTensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
68 |         W = self.conv7_W(x).view([-1, 3, 3]) * self.multiplier
69 |         b = self.conv7_b(x).view([-1, 3]) * self.multiplier
70 |
71 |         A = W + I
72 |
73 |         sd, sh, sw = image1.shape[2:5]
74 |         flow = self.affine_flow(W, b, sd, sh, sw)  # B, C, D, H, W (Displacement Field)
75 |
76 |         return {'flow': flow, 'A': A, 'W': W, 'b': b}
77 |
--------------------------------------------------------------------------------
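
`affine_flow` evaluates the displacement `W @ x_c + b` on a grid of voxel coordinates `x_c` centred at the volume midpoint; since `A = W + I`, a zero matrix `W` and zero offset `b` correspond to the identity transform. A CPU-only mirror of the method (the original calls `.cuda()` on the coordinate vectors), useful as a sanity check and not part of the repo:

```python
import torch

def affine_flow_cpu(W, b, sd, sh, sw):
    # same arithmetic as AffineNet.affine_flow, without the .cuda() calls
    b = b.view([-1, 3, 1, 1, 1])
    xr = torch.arange(-(sw - 1) / 2.0, sw / 2.0, 1.0).view([1, 1, 1, 1, -1])
    yr = torch.arange(-(sh - 1) / 2.0, sh / 2.0, 1.0).view([1, 1, 1, -1, 1])
    zr = torch.arange(-(sd - 1) / 2.0, sd / 2.0, 1.0).view([1, 1, -1, 1, 1])
    wx, wy, wz = (W[:, :, i].view([-1, 3, 1, 1, 1]) for i in range(3))
    return xr * wx + yr * wy + zr * wz + b

flow = affine_flow_cpu(torch.zeros(1, 3, 3), torch.zeros(1, 3), 8, 8, 8)
assert torch.all(flow == 0)   # W = 0, b = 0 (A = I) yields zero displacement
```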
/core/networks/sdhnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from ..utils import warp
6 |
7 |
8 | class ContextNet(nn.Module):
9 |     def __init__(self, outputc):
10 |         super(ContextNet, self).__init__()
11 |
12 |         self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1)
13 |         self.conv2 = nn.Conv3d(16, 16, kernel_size=3, stride=1, padding=1)
14 |         self.conv3 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
15 |         self.conv4 = nn.Conv3d(32, outputc, kernel_size=3, stride=1, padding=1)
16 |
17 |         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
18 |
19 |     def forward(self, x):
20 |         x = self.lrelu(self.conv1(x))
21 |         x = self.lrelu(self.conv2(x))
22 |         x = self.lrelu(self.conv3(x))
23 |         x = self.conv4(x)
24 |
25 |         return x
26 |
27 |
28 | class ExtractNet(nn.Module):
29 |     def __init__(self):
30 |         super(ExtractNet, self).__init__()
31 |
32 |         self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1)
33 |         self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
34 |         self.conv3 = nn.Conv3d(32, 48, kernel_size=3, stride=2, padding=1)
35 |         self.conv4 = nn.Conv3d(48, 64, kernel_size=3, stride=2, padding=1)
36 |
37 |         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
38 |
39 |     def forward(self, x):
40 |         x_2x = self.lrelu(self.conv1(x))
41 |         x_4x = self.lrelu(self.conv2(x_2x))
42 |         x_8x = self.lrelu(self.conv3(x_4x))
43 |         x_16x = self.lrelu(self.conv4(x_8x))
44 |
45 |         return {'1/4': x_4x, '1/8': x_8x, '1/16': x_16x}
46 |
47 |
48 | class ConvGRU(nn.Module):
49 |     def __init__(self, inputc, hidden_dim):
50 |         super(ConvGRU, self).__init__()
51 |         self.convz1 = nn.Conv3d(inputc, hidden_dim, 3, padding=1)
52 |         self.convr1 = nn.Conv3d(inputc, hidden_dim, 3, padding=1)
53 |         self.convq1 = nn.Conv3d(inputc, hidden_dim, 3, padding=1)
54 |
55 |         self.conv1 = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1)
56 |         self.conv2 = nn.Conv3d(hidden_dim, 3, kernel_size=3, stride=1, padding=1)
57 |         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
58 |
59 |     def forward(self, h, x):
60 |         # 1st round
61 |         hx = torch.cat([h, x], dim=1)
62 |
63 |         z = torch.sigmoid(self.convz1(hx))
64 |         r = torch.sigmoid(self.convr1(hx))
65 |         q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
66 |
67 |         h = (1-z) * h + z * q
68 |
69 |         # flow estimation
70 |         flow = self.conv2(self.lrelu(self.conv1(h)))
71 |
72 |         return flow, h
73 |
74 |
75 | class Fusion(nn.Module):
76 |     def __init__(self, inputc1, inputc2):
77 |         super(Fusion, self).__init__()
78 |
79 |         self.conv1 = nn.Conv3d((inputc1 + inputc2), 48, kernel_size=3, stride=1, padding=1)
80 |         self.conv2 = nn.Conv3d(48, 16, kernel_size=3, stride=1, padding=1)
81 |         self.conv3 = nn.Conv3d(16, inputc2, kernel_size=3, stride=1, padding=1)
82 |         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
83 |
84 |     def forward(self, cont, flows):
85 |         x = self.lrelu(self.conv1(torch.cat([cont, flows], 1)))
86 |         x = self.lrelu(self.conv2(x))
87 |         k = self.conv3(x)
88 |
89 |         b, _, d, h, w = flows.shape
90 |         flow = (flows.view(b, -1, 3, d, h, w) * F.softmax(k.view(b, -1, 3, d, h, w), 1)).sum(1)
91 |         return flow
92 |
93 |
94 | class SDHNet(nn.Module):
95 |     def __init__(self, hdim, flow_multiplier=1.):
96 |         super(SDHNet, self).__init__()
97 |         self.hdim = hdim
98 |         self.flow_multiplier = flow_multiplier
99 |
100 |         self.extraction = ExtractNet()
101 |
102 |         self.estimator_4x = ConvGRU(inputc=64 + self.hdim, hidden_dim=self.hdim)
103 |         self.estimator_8x = ConvGRU(inputc=96 + self.hdim, hidden_dim=self.hdim)
104 |         self.estimator_16x = ConvGRU(inputc=128 + self.hdim, hidden_dim=self.hdim)
105 |
106 |         self.fusion = Fusion(inputc1=32, inputc2=3*3)
107 |
108 |         self.reconstruction = warp.warp3D()
109 |
110 |     def forward(self, image1, image2, c_fea, h_fea):
111 |         f_fea = self.extraction(image1)
112 |         m_fea = self.extraction(image2)
113 |         hid_4x, hid_8x, hid_16x = h_fea
114 |
115 |         b, c, d_4x, h_4x, w_4x = f_fea['1/4'].shape
116 |
117 |         # estimate the multi-resolution flow
118 |         flow_4x, hid_4x = self.estimator_4x(hid_4x, torch.cat([f_fea['1/4'], m_fea['1/4']], 1))
119 |         flow_8x, hid_8x = self.estimator_8x(hid_8x, torch.cat([f_fea['1/8'], m_fea['1/8']], 1))
120 |         flow_16x, hid_16x = self.estimator_16x(hid_16x, torch.cat([f_fea['1/16'], m_fea['1/16']], 1))
121 |
122 |         hid = [hid_4x, hid_8x, hid_16x]
123 |
124 |         # flow fusion
125 |         flow = self.fusion(c_fea, torch.cat([flow_4x,
126 |                                              F.interpolate(flow_8x, (d_4x, h_4x, w_4x), mode='trilinear') * 2.0,
127 |                                              F.interpolate(flow_16x, (d_4x, h_4x, w_4x), mode='trilinear') * 4.0], 1))
128 |
129 |         b, c, d, h, w = image1.shape
130 |         final_flow = F.interpolate(flow, size=(d, h, w), mode='trilinear') * 4.0
131 |
132 |         return {'flow': final_flow * self.flow_multiplier}, hid
133 |
--------------------------------------------------------------------------------
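
`Fusion` splits its 9 input channels into three candidate 3-channel flows and weights them voxel-wise with a softmax over the candidate axis; the fused flow is the weighted sum. The mechanics in isolation (editorial demo; in the real module `k` comes from the conv layers, here it is set to equal logits):

```python
import torch
import torch.nn.functional as F

b, d, h, w = 1, 4, 4, 4
flows = torch.randn(b, 9, d, h, w)   # concatenated 4x/8x/16x candidate flows
k = torch.zeros(b, 9, d, h, w)       # equal logits -> uniform weights of 1/3

flow = (flows.view(b, -1, 3, d, h, w) * F.softmax(k.view(b, -1, 3, d, h, w), 1)).sum(1)
assert torch.allclose(flow, flows.view(b, 3, 3, d, h, w).mean(1))   # uniform weights = candidate mean
```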
/core/utils/aug_transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def get_coef(u):
7 |     return torch.stack([((1 - u) ** 3) / 6, (3 * (u ** 3) - 6 * (u ** 2) + 4) / 6,
8 |                         (-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6, (u ** 3) / 6], dim=1)
9 |
10 |
11 | def sample_power(lo, hi, k, size=None):
12 |     r = (hi - lo) / 2
13 |     center = (hi + lo) / 2
14 |     r = r ** (1 / k)
15 |     points = (torch.rand(size) - 0.5) * 2 * r
16 |     points = (torch.abs(points) ** k) * torch.sign(points)
17 |     return points + center
18 |
19 |
20 | def pad_3d(mat, pad):
21 |     return F.pad(mat, (pad, pad, pad, pad, pad, pad))
22 |
23 |
24 | def free_form_fields(shape, control_fields, padding='same'):
25 |     interpolate_range = 4
26 |
27 |     control_fields = torch.Tensor(control_fields)
28 |     _, _, n, m, t = control_fields.shape
29 |     if padding == 'same':
30 |         control_fields = pad_3d(control_fields, 1)
31 |     elif padding == 'valid':
32 |         n -= 2
33 |         m -= 2
34 |         t -= 2
35 |
36 |     control_fields = torch.reshape(control_fields.permute(2, 3, 4, 0, 1).contiguous(), [n + 2, m + 2, t + 2, -1])
37 |
38 |     assert shape[0] % (n - 1) == 0
39 |     s_x = shape[0] // (n - 1)
40 |     u_x = (torch.arange(0, s_x, dtype=torch.float32) + 0.5) / s_x  # s_x
41 |     coef_x = get_coef(u_x)  # (s_x, 4)
42 |
43 |     shape_cf = control_fields.shape
44 |     flow = torch.cat([torch.matmul(coef_x,
45 |                                    torch.reshape(control_fields[i: i + interpolate_range], [interpolate_range, -1]))
46 |                       for i in range(0, n - 1)], dim=0)
47 |
48 |     assert shape[1] % (m - 1) == 0
49 |     s_y = shape[1] // (m - 1)
50 |     u_y = (torch.arange(0, s_y, dtype=torch.float32) + 0.5) / s_y  # s_y
51 |     coef_y = get_coef(u_y)  # (s_y, 4)
52 |
53 |     flow_dims = np.arange(0, len(flow.shape))[::-1]
54 |     flow = torch.reshape(flow.permute(*flow_dims).contiguous(), [shape_cf[1], -1])
55 |     flow = torch.cat([torch.matmul(coef_y,
56 |                                    torch.reshape(flow[i: i + interpolate_range], [interpolate_range, -1]))
57 |                       for i in range(0, m - 1)], dim=0)
58 |
59 |     assert shape[2] % (t - 1) == 0
60 |     s_z = shape[2] // (t - 1)
61 |     u_z = (torch.arange(0, s_z, dtype=torch.float32) + 0.5) / s_z  # s_z
62 |     coef_z = get_coef(u_z)  # (s_z, 4)
63 |
64 |     flow_dims = np.arange(0, len(flow.shape))[::-1]
65 |     flow = torch.reshape(flow.permute(*flow_dims).contiguous(), [shape_cf[2], -1])
66 |     flow = torch.cat([torch.matmul(coef_z,
67 |                                    torch.reshape(flow[i: i + interpolate_range], [interpolate_range, -1]))
68 |                       for i in range(0, t - 1)], dim=0)
69 |
70 |     flow = torch.reshape(flow, [shape[2], -1, 3, shape[1], shape[0]])
71 |     flow = flow.permute(1, 2, 4, 3, 0).contiguous()
72 |
73 |     return flow
74 |
--------------------------------------------------------------------------------
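
`sample_power` and `free_form_fields` generate the random smooth B-spline deformations used for augmentation in `Framework.forward`; together with `warp3D` from the next file they produce a randomly deformed copy of a volume. A usage sketch (editorial example; CUDA is required by `warp3D`, and a cubic volume is assumed so the per-axis scale reduces to a scalar):

```python
import torch
from core.utils.aug_transform import sample_power, free_form_fields
from core.utils.warp import warp3D

shape = (128, 128, 128)  # D, H, W
# 5x5x5 control grid with power-law magnitudes, scaled to a quarter of the extent
control = (sample_power(-0.4, 0.4, 3, [1, 5, 5, 5, 3]) * (shape[0] // 4)).permute(0, 4, 1, 2, 3)
flow = free_form_fields(shape, control).cuda()   # dense (1, 3, 128, 128, 128) field

img = torch.rand(1, 1, *shape).cuda()
aug = warp3D()(img, flow)                        # randomly deformed copy of img
```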
/core/utils/warp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class warp3D:
7 |     def __init__(self, padding=False):
8 |         self.padding = padding
9 |
10 |     def __call__(self, I, flow):
11 |         return self._transform(I, flow[:, 0, :, :, :], flow[:, 1, :, :, :], flow[:, 2, :, :, :])
12 |
13 |     def _meshgrid(self, depth, height, width):
14 |         x_t = torch.matmul(torch.ones(height, 1),
15 |                            (torch.linspace(0.0, float(width) - 1.0, width)[:, np.newaxis].permute(1, 0).contiguous()))
16 |         x_t = x_t[np.newaxis].repeat(depth, 1, 1)
17 |
18 |         y_t = torch.matmul(torch.linspace(0.0, float(height) - 1.0, height)[:, np.newaxis], torch.ones(1, width))
19 |         y_t = y_t[np.newaxis].repeat(depth, 1, 1)
20 |
21 |         z_t = torch.linspace(0.0, float(depth) - 1.0, depth)[:, np.newaxis, np.newaxis].repeat(1, height, width)
22 |
23 |         return x_t, y_t, z_t
24 |
25 |     def _transform(self, I, dx, dy, dz):
26 |         batch_size = dx.shape[0]
27 |         depth = dx.shape[1]
28 |         height = dx.shape[2]
29 |         width = dx.shape[3]
30 |
31 |         # Convert dx and dy to absolute locations
32 |         x_mesh, y_mesh, z_mesh = self._meshgrid(depth, height, width)
33 |         x_mesh = x_mesh[np.newaxis]
34 |         y_mesh = y_mesh[np.newaxis]
35 |         z_mesh = z_mesh[np.newaxis]
36 |
37 |         x_mesh = x_mesh.repeat(batch_size, 1, 1, 1).cuda()
38 |         y_mesh = y_mesh.repeat(batch_size, 1, 1, 1).cuda()
39 |         z_mesh = z_mesh.repeat(batch_size, 1, 1, 1).cuda()
40 |         x_new = dx + x_mesh
41 |         y_new = dy + y_mesh
42 |         z_new = dz + z_mesh
43 |
44 |         return self._interpolate(I, x_new, y_new, z_new)
45 |
46 |     def _repeat(self, x, n_repeats):
47 |         rep = torch.ones(size=[n_repeats, ])[:, np.newaxis].permute(1, 0).contiguous().int()
48 |         x = torch.matmul(x.view([-1, 1]).int(), rep)
49 |         return x.view([-1])
50 |
51 |     def _interpolate(self, im, x, y, z):
52 |         if self.padding:
53 |             im = F.pad(im, (1, 1, 1, 1, 1, 1))
54 |
55 |         num_batch = im.shape[0]
56 |         channels = im.shape[1]
57 |         depth = im.shape[2]
58 |         height = im.shape[3]
59 |         width = im.shape[4]
60 |
61 |         out_depth = x.shape[1]
62 |         out_height = x.shape[2]
63 |         out_width = x.shape[3]
64 |
65 |         x = x.view([-1])
66 |         y = y.view([-1])
67 |         z = z.view([-1])
68 |
69 |         padding_constant = 1 if self.padding else 0
70 |         x = x.float() + padding_constant
71 |         y = y.float() + padding_constant
72 |         z = z.float() + padding_constant
73 |
74 |         max_x = int(width - 1)
75 |         max_y = int(height - 1)
76 |         max_z = int(depth - 1)
77 |
78 |         x0 = torch.floor(x).int()
79 |         x1 = x0 + 1
80 |         y0 = torch.floor(y).int()
81 |         y1 = y0 + 1
82 |         z0 = torch.floor(z).int()
83 |         z1 = z0 + 1
84 |
85 |         x0 = torch.clamp(x0, 0, max_x)
86 |         x1 = torch.clamp(x1, 0, max_x)
87 |         y0 = torch.clamp(y0, 0, max_y)
88 |         y1 = torch.clamp(y1, 0, max_y)
89 |         z0 = torch.clamp(z0, 0, max_z)
90 |         z1 = torch.clamp(z1, 0, max_z)
91 |
92 |         dim1 = width
93 |         dim2 = width * height
94 |         dim3 = width * height * depth
95 |
96 |         base = self._repeat(torch.arange(num_batch) * dim3,
97 |                             out_depth * out_height * out_width).cuda()
98 |
99 |         idx_a = (base + x0 + y0 * dim1 + z0 * dim2)[:, np.newaxis].repeat(1, channels)
100 |         idx_b = (base + x0 + y1 * dim1 + z0 * dim2)[:, np.newaxis].repeat(1, channels)
101 |         idx_c = (base + x1 + y0 * dim1 + z0 * dim2)[:, np.newaxis].repeat(1, channels)
102 |         idx_d = (base + x1 + y1 * dim1 + z0 * dim2)[:, np.newaxis].repeat(1, channels)
103 |         idx_e = (base + x0 + y0 * dim1 + z1 * dim2)[:, np.newaxis].repeat(1, channels)
104 |         idx_f = (base + x0 + y1 * dim1 + z1 * dim2)[:, np.newaxis].repeat(1, channels)
105 |         idx_g = (base + x1 + y0 * dim1 + z1 * dim2)[:, np.newaxis].repeat(1, channels)
106 |         idx_h = (base + x1 + y1 * dim1 + z1 * dim2)[:, np.newaxis].repeat(1, channels)
107 |
108 |         # use indices to lookup pixels in the flat image and restore
109 |         # channels dim
110 |         im_flat = im.permute(0, 2, 3, 4, 1).contiguous().view([-1, channels]).float()
111 |
112 |         Ia = torch.gather(im_flat, 0, idx_a.long())
113 |         Ib = torch.gather(im_flat, 0, idx_b.long())
114 |         Ic = torch.gather(im_flat, 0, idx_c.long())
115 |         Id = torch.gather(im_flat, 0, idx_d.long())
116 |         Ie = torch.gather(im_flat, 0, idx_e.long())
117 |         If = torch.gather(im_flat, 0, idx_f.long())
118 |         Ig = torch.gather(im_flat, 0, idx_g.long())
119 |         Ih = torch.gather(im_flat, 0, idx_h.long())
120 |
121 |         # and finally calculate interpolated values
122 |         x1_f = x1.float()
123 |         y1_f = y1.float()
124 |         z1_f = z1.float()
125 |
126 |         dx = x1_f - x
127 |         dy = y1_f - y
128 |         dz = z1_f - z
129 |
130 |         wa = (dz * dx * dy)[:, np.newaxis]
131 |         wb = (dz * dx * (1 - dy))[:, np.newaxis]
132 |         wc = (dz * (1 - dx) * dy)[:, np.newaxis]
133 |         wd = (dz * (1 - dx) * (1 - dy))[:, np.newaxis]
134 |         we = ((1 - dz) * dx * dy)[:, np.newaxis]
135 |         wf = ((1 - dz) * dx * (1 - dy))[:, np.newaxis]
136 |         wg = ((1 - dz) * (1 - dx) * dy)[:, np.newaxis]
137 |         wh = ((1 - 
dz) * (1 - dx) * (1 - dy))[:, np.newaxis] 138 | 139 | output = wa * Ia + wb * Ib + wc * Ic + wd * Id + we * Ie + wf * If + wg * Ig + wh * Ih 140 | output = output.view([-1, out_depth, out_height, out_width, channels]) 141 | 142 | return output.permute(0, 4, 1, 2, 3) 143 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from os.path import join 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | import core.datasets as datasets 10 | from core.utils.warp import warp3D 11 | from core.framework import Framework 12 | 13 | 14 | def mask_class(label, value): 15 | return (torch.abs(label - value) < 0.5).float() * 255.0 16 | 17 | 18 | def mask_metrics(seg1, seg2): 19 | sizes = np.prod(seg1.shape[1:]) 20 | seg1 = (seg1.view(-1, sizes) > 128).type(torch.float32) 21 | seg2 = (seg2.view(-1, sizes) > 128).type(torch.float32) 22 | dice_score = 2.0 * torch.sum(seg1 * seg2, 1) / (torch.sum(seg1, 1) + torch.sum(seg2, 1)) 23 | 24 | union = torch.sum(torch.max(seg1, seg2), 1) 25 | iden = (torch.ones(*union.shape) * 0.01).cuda() 26 | jacc_score = torch.sum(torch.min(seg1, seg2), 1) / torch.max(iden, union) 27 | 28 | return dice_score, jacc_score 29 | 30 | 31 | def jacobian_det(flow): 32 | bias_d = np.array([0, 0, 1]) 33 | bias_h = np.array([0, 1, 0]) 34 | bias_w = np.array([1, 0, 0]) 35 | 36 | volume_d = np.transpose(flow[:, 1:, :-1, :-1] - flow[:, :-1, :-1, :-1], (1, 2, 3, 0)) + bias_d 37 | volume_h = np.transpose(flow[:, :-1, 1:, :-1] - flow[:, :-1, :-1, :-1], (1, 2, 3, 0)) + bias_h 38 | volume_w = np.transpose(flow[:, :-1, :-1, 1:] - flow[:, :-1, :-1, :-1], (1, 2, 3, 0)) + bias_w 39 | 40 | jacobian_det_volume = np.linalg.det(np.stack([volume_w, volume_h, volume_d], -1)) 41 | jd = np.sum(jacobian_det_volume <= 0) 42 | return jd 43 | 44 | 45 | def evaluate_liver(args, model, steps): 46 | for datas in args.dataset_test: 47 | eval_path = join(args.eval_path, datas) 48 | if (args.local_rank == 0) and (not os.path.isdir(eval_path)): 49 | os.makedirs(eval_path) 50 | file_sum = join(eval_path, datas + '.txt') 51 | file = join(eval_path, datas + '_' + str(steps) + '.txt') 52 | f = open(file, 'a+') 53 | g = open(file_sum, 'a+') 54 | 55 | Dice, Jacc, Jacb = [], [], [] 56 | if 'lspig' in datas: 57 | eval_dataset = datasets.LspigTest(args, datas) 58 | else: 59 | eval_dataset = datasets.LiverTest(args, datas) 60 | if args.local_rank == 0: 61 | print('Dataset in evaluation: %s' % datas, file=f) 62 | print('Image pairs in evaluation: %d' % len(eval_dataset), file=f) 63 | print('Evaluation steps: %s' % steps, file=f) 64 | 65 | for i in range(len(eval_dataset)): 66 | image1, image2 = eval_dataset[i][0][np.newaxis].cuda(), eval_dataset[i][1][np.newaxis].cuda() 67 | label1, label2 = eval_dataset[i][2][np.newaxis].cuda(), eval_dataset[i][3][np.newaxis].cuda() 68 | 69 | with torch.no_grad(): 70 | _, _, _, agg_flow, _ = model.module(image1, image2, augment=False) 71 | label2_warped = warp3D()(label2, agg_flow) 72 | 73 | dice, jacc = mask_metrics(label1, label2_warped) 74 | 75 | dice = dice.cpu().numpy()[0] 76 | jacc = jacc.cpu().numpy()[0] 77 | jacb = jacobian_det(agg_flow.cpu().numpy()[0]) 78 | 79 | if args.local_rank == 0: 80 | print('Pair{:6d} dice:{:10.6f} jacc:{:10.6f} jacb:{:10.2f}'. 
81 | format(i, dice, jacc, jacb), 82 | file=f) 83 | 84 | Dice.append(dice) 85 | Jacc.append(jacc) 86 | Jacb.append(jacb) 87 | 88 | dice_mean, dice_std = np.mean(np.array(Dice)), np.std(np.array(Dice)) 89 | jacc_mean, jacc_std = np.mean(np.array(Jacc)), np.std(np.array(Jacc)) 90 | jacb_mean, jacb_std = np.mean(np.array(Jacb)), np.std(np.array(Jacb)) 91 | 92 | if args.local_rank == 0: 93 | print('Summary ---> ' 94 | 'Dice:{:10.6f}({:10.6f}) Jacc:{:10.6f}({:10.6f}) ' 95 | 'Jacb:{:10.2f}({:10.2f})' 96 | .format(dice_mean, dice_std, jacc_mean, jacc_std, jacb_mean, jacb_std), 97 | file=f) 98 | 99 | print('Step{:12d} ---> ' 100 | 'Dice:{:10.6f}({:10.6f}) Jacc:{:10.6f}({:10.6f}) ' 101 | 'Jacb:{:10.2f}({:10.2f})' 102 | .format(steps, dice_mean, dice_std, jacc_mean, jacc_std, jacb_mean, jacb_std), 103 | file=g) 104 | 105 | f.close() 106 | g.close() 107 | 108 | 109 | def evaluate_brain(args, model, steps): 110 | for datas in args.dataset_test: 111 | eval_path = join(args.eval_path, datas) 112 | if (args.local_rank == 0) and (not os.path.isdir(eval_path)): 113 | os.makedirs(eval_path) 114 | file_sum = join(eval_path, datas + '.txt') 115 | file = join(eval_path, datas + '_' + str(steps) + '.txt') 116 | f = open(file, 'a+') 117 | g = open(file_sum, 'a+') 118 | 119 | Dice, Jacc, Jacb = [], [], [] 120 | eval_dataset = datasets.BrainTest(args, datas) 121 | if args.local_rank == 0: 122 | print('Dataset in evaluation: %s' % datas, file=f) 123 | print('Image pairs in evaluation: %d' % len(eval_dataset), file=f) 124 | print('Evaluation steps: %s' % steps, file=f) 125 | 126 | for i in range(len(eval_dataset)): 127 | image1, image2 = eval_dataset[i][0][np.newaxis].cuda(), eval_dataset[i][1][np.newaxis].cuda() 128 | label1, label2 = eval_dataset[i][2][np.newaxis].cuda(), eval_dataset[i][3][np.newaxis].cuda() 129 | 130 | with torch.no_grad(): 131 | _, _, _, agg_flow, _ = model.module(image1, image2, augment=False) 132 | 133 | jaccs = [] 134 | dices = [] 135 | 136 | for v in eval_dataset.seg_values: 137 | label1_fixed = mask_class(label1, v) 138 | label2_warped = warp3D()(mask_class(label2, v), agg_flow) 139 | 140 | class_dice, class_jacc = mask_metrics(label1_fixed, label2_warped) 141 | 142 | dices.append(class_dice) 143 | jaccs.append(class_jacc) 144 | 145 | jacb = jacobian_det(agg_flow.cpu().numpy()[0]) 146 | 147 | dice = torch.mean(torch.cuda.FloatTensor(dices)).cpu().numpy() 148 | jacc = torch.mean(torch.cuda.FloatTensor(jaccs)).cpu().numpy() 149 | 150 | if args.local_rank == 0: 151 | print('Pair{:6d} dice:{:10.6f} jacc:{:10.6f} jacb:{:10.2f}'. 
152 | format(i, dice, jacc, jacb), 153 | file=f) 154 | 155 | Dice.append(dice) 156 | Jacc.append(jacc) 157 | Jacb.append(jacb) 158 | 159 | dice_mean, dice_std = np.mean(np.array(Dice)), np.std(np.array(Dice)) 160 | jacc_mean, jacc_std = np.mean(np.array(Jacc)), np.std(np.array(Jacc)) 161 | jacb_mean, jacb_std = np.mean(np.array(Jacb)), np.std(np.array(Jacb)) 162 | 163 | if args.local_rank == 0: 164 | print('Summary ---> ' 165 | 'Dice:{:10.6f}({:10.6f}) Jacc:{:10.6f}({:10.6f}) ' 166 | 'Jacb:{:10.2f}({:10.2f})' 167 | .format(dice_mean, dice_std, jacc_mean, jacc_std, jacb_mean, jacb_std), 168 | file=f) 169 | 170 | print('Step{:12d} ---> ' 171 | 'Dice:{:10.6f}({:10.6f}) Jacc:{:10.6f}({:10.6f}) ' 172 | 'Jacb:{:10.2f}({:10.2f})' 173 | .format(steps, dice_mean, dice_std, jacc_mean, jacc_std, jacb_mean, jacb_std), 174 | file=g) 175 | 176 | f.close() 177 | g.close() 178 | 179 | 180 | def evaluate(args, model, steps=100000): 181 | if args.dataset == 'liver': 182 | evaluate_liver(args, model, steps) 183 | elif args.dataset == 'brain': 184 | evaluate_brain(args, model, steps) 185 | 186 | 187 | if __name__ == '__main__': 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument('--name', type=str, default='SDHNet', help='evaluation experiment') 190 | parser.add_argument('--model', type=str, default='SDHNet_lpba', help='evaluation experiment') 191 | parser.add_argument('--dataset', type=str, default='brain', help='which dataset to use for evaluation') 192 | parser.add_argument('--dataset_test', nargs='+', default=['lpba'], help='specific dataset to use for evaluation') 193 | parser.add_argument('--data_path', type=str, default='E:/Registration/Code/TMI2022/Github/Data_MRIBrain/') 194 | parser.add_argument('--base_path', type=str, default='E:/Registration/Code/TMI2022/Github/') 195 | parser.add_argument('--iters', type=int, default=6) 196 | parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training') 197 | args = parser.parse_args() 198 | 199 | args.model_path = args.base_path + args.name + '/output/checkpoints_' + args.dataset 200 | args.eval_path = args.base_path + args.name + '/output/eval_' + args.dataset 201 | args.restore_ckpt = join(args.model_path, args.model + '.pth') 202 | 203 | model = Framework(args) 204 | model = nn.DataParallel(model) 205 | model.eval() 206 | model.cuda() 207 | 208 | model.load_state_dict(torch.load(args.restore_ckpt)) 209 | evaluate(args, model) 210 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import random 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | import torch.distributed as dist 11 | 12 | import core.datasets as datasets 13 | import core.losses as losses 14 | from core.utils.warp import warp3D 15 | from core.framework import Framework 16 | 17 | 18 | def fetch_optimizer(args, model): 19 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 20 | milestones = [args.round*3, args.round*4, args.round*5] # args.epoch == 5 21 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5) 22 | 23 | return optimizer, scheduler 24 | 25 | 26 | def fetch_loss(affines, deforms, agg_flow, image1, image2): 27 | # affine loss 28 | det = losses.det3x3(affines['A']) 29 | det_loss = torch.sum((det - 1.0) ** 2) / 2 30 | 31 | I = 
torch.cuda.FloatTensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]) 32 | eps = 1e-5 33 | epsI = torch.cuda.FloatTensor([[[eps * elem for elem in row] for row in Mat] for Mat in I]) 34 | C = torch.matmul(affines['A'].permute(0, 2, 1), affines['A']) + epsI 35 | s1, s2, s3 = losses.elem_sym_polys_of_eigen_values(C) 36 | ortho_loss = torch.sum(s1 + (1 + eps) * (1 + eps) * s2 / s3 - 3 * 2 * (1 + eps)) 37 | aff_loss = 0.1 * det_loss + 0.1 * ortho_loss 38 | 39 | # deform loss 40 | image2_warped = warp3D()(image2, agg_flow) 41 | sim_loss = losses.similarity_loss(image1, image2_warped) 42 | 43 | reg_loss = 0.0 44 | for i in range(len(deforms)): 45 | reg_loss = reg_loss + losses.regularize_loss(deforms[i]['flow']) 46 | 47 | whole_loss = aff_loss + sim_loss + 0.5 * reg_loss 48 | 49 | metrics = { 50 | 'aff_loss': aff_loss.item(), 51 | 'sim_loss': sim_loss.item(), 52 | 'reg_loss': reg_loss.item() 53 | } 54 | 55 | return whole_loss, metrics 56 | 57 | 58 | def fetch_dataloader(args): 59 | if args.dataset == 'liver': 60 | train_dataset = datasets.LiverTrain(args) 61 | elif args.dataset == 'brain': 62 | train_dataset = datasets.BrainTrain(args) 63 | else: 64 | print('Wrong Dataset') 65 | 66 | gpuargs = {'num_workers': 4, 'drop_last': True} 67 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 68 | train_loader = DataLoader(train_dataset, batch_size=args.batch, pin_memory=True, shuffle=False, sampler=train_sampler, **gpuargs) 69 | 70 | if args.local_rank == 0: 71 | print('Image pairs in training: %d' % len(train_dataset), file=args.files, flush=True) 72 | return train_loader 73 | 74 | 75 | class Logger: 76 | def __init__(self, model, scheduler, args): 77 | self.model = model 78 | self.scheduler = scheduler 79 | self.total_steps = 0 80 | self.running_loss = {} 81 | self.sum_freq = args.sum_freq 82 | 83 | def _print_training_status(self): 84 | metrics_data = ["{" + k + ":{:10.5f}".format(self.running_loss[k] / self.sum_freq) + "} " 85 | for k in self.running_loss.keys()] 86 | training_str = "[Steps:{:9d}, Lr:{:10.7f}] ".format(self.total_steps + 1, self.scheduler.get_lr()[0]) 87 | print(training_str + "".join(metrics_data), file=args.files, flush=True) 88 | print(training_str + "".join(metrics_data)) 89 | 90 | for key in self.running_loss: 91 | self.running_loss[key] = 0.0 92 | 93 | def push(self, metrics): 94 | self.total_steps = self.total_steps + 1 95 | 96 | for key in metrics: 97 | if key not in self.running_loss: 98 | self.running_loss[key] = 0.0 99 | 100 | self.running_loss[key] = self.running_loss[key] + metrics[key] 101 | 102 | if self.total_steps % self.sum_freq == self.sum_freq - 1: 103 | if args.local_rank == 0: 104 | self._print_training_status() 105 | self.running_loss = {} 106 | 107 | 108 | def train(args): 109 | model = Framework(args) 110 | model.cuda() 111 | 112 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) 113 | 114 | train_loader = fetch_dataloader(args) 115 | optimizer, scheduler = fetch_optimizer(args, model) 116 | 117 | total_steps = 0 118 | logger = Logger(model, scheduler, args) 119 | 120 | should_keep_training = True 121 | while should_keep_training: 122 | for i_batch, data_blob in enumerate(train_loader): 123 | model.train() 124 | image1, image2 = [x.cuda(non_blocking=True) for x in data_blob] 125 | 126 | optimizer.zero_grad() 127 | image2_aug, affines, deforms, agg_flow, agg_flows = model(image1, image2) 128 | 129 | loss, metrics = fetch_loss(affines, deforms, agg_flow, 
image1, image2_aug) 130 | loss.backward() 131 | 132 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 133 | optimizer.step() 134 | scheduler.step() 135 | total_steps = total_steps + 1 136 | 137 | logger.push(metrics) 138 | 139 | if total_steps % args.val_freq == args.val_freq - 1: 140 | PATH = args.model_path + '/%s_%d.pth' % (args.name, total_steps + 1) 141 | torch.save(model.state_dict(), PATH) 142 | 143 | if total_steps == args.num_steps: 144 | should_keep_training = False 145 | break 146 | 147 | PATH = args.model_path + '/%s.pth' % args.name 148 | torch.save(model.state_dict(), PATH) 149 | 150 | return PATH 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--name', type=str, default='SDHNet', help='name your experiment') 156 | parser.add_argument('--dataset', type=str, default='brain', help='which dataset to use for training') 157 | parser.add_argument('--epoch', type=int, default=5, help='number of epochs') 158 | parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate') 159 | parser.add_argument('--clip', type=float, default=1.0) 160 | parser.add_argument('--batch', type=int, default=1, help='number of image pairs per batch on single gpu') 161 | parser.add_argument('--sum_freq', type=int, default=1000) 162 | parser.add_argument('--val_freq', type=int, default=2000) 163 | parser.add_argument('--round', type=int, default=20000, help='number of batches per epoch') 164 | parser.add_argument('--data_path', type=str, default='E:/Registration/Code/TMI2022/Github/Data_MRIBrain/') 165 | parser.add_argument('--base_path', type=str, default='E:/Registration/Code/TMI2022/Github/') 166 | parser.add_argument('--iters', type=int, default=6) 167 | parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training') 168 | args = parser.parse_args() 169 | 170 | args.model_path = args.base_path + args.name + '/output/checkpoints_' + args.dataset 171 | args.eval_path = args.base_path + args.name + '/output/eval_' + args.dataset 172 | 173 | dist.init_process_group(backend='nccl') 174 | 175 | if args.local_rank == 0: 176 | os.makedirs(args.model_path, exist_ok=True) 177 | os.makedirs(args.eval_path, exist_ok=True) 178 | 179 | random.seed(123) 180 | np.random.seed(123) 181 | torch.manual_seed(123) 182 | torch.cuda.manual_seed_all(123) 183 | 184 | args.nums_gpu = torch.cuda.device_count() 185 | args.batch = args.batch 186 | args.num_steps = args.epoch * args.round 187 | args.files = open(args.base_path + args.name + '/output/train_' + args.dataset + '.txt', 'a+') 188 | 189 | if args.local_rank == 0: 190 | print('Dataset: %s' % args.dataset, file=args.files, flush=True) 191 | print('Batch size: %s' % args.batch, file=args.files, flush=True) 192 | print('Step: %s' % args.num_steps, file=args.files, flush=True) 193 | print('Parallel GPU: %s' % args.nums_gpu, file=args.files, flush=True) 194 | 195 | train(args) 196 | args.files.close() 197 | --------------------------------------------------------------------------------
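
For completeness, a minimal single-pair inference sketch distilled from `eval.py` (CUDA is required by the framework code; the checkpoint and TIFF paths are illustrative only, not files shipped with the repo):

```python
import argparse
import numpy as np
import tifffile as tif
import torch
import torch.nn as nn

from core.framework import Framework
from core.utils.warp import warp3D

args = argparse.Namespace(iters=6)
model = nn.DataParallel(Framework(args)).cuda().eval()
model.load_state_dict(torch.load('SDHNet_lpba.pth'))   # hypothetical checkpoint path

load = lambda p: torch.from_numpy(tif.imread(p)[np.newaxis, np.newaxis]).float().cuda() / 255.0
fixed, moving = load('fixed/volume.tif'), load('moving/volume.tif')

with torch.no_grad():
    _, _, _, agg_flow, _ = model.module(fixed, moving, augment=False)
    warped = warp3D()(moving, agg_flow)   # moving image resampled into the fixed image's space
```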