├── LICENSE ├── README.md ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc └── cnnSegmentShortAxis.cpython-36.pyc ├── cnnSegmentShortAxis.py ├── common ├── __pycache__ │ ├── cardiac_utils.cpython-36.pyc │ ├── deploy_network.cpython-36.pyc │ ├── deploy_network.cpython-37.pyc │ ├── image_utils.cpython-36.pyc │ ├── image_utils.cpython-37.pyc │ ├── network.cpython-36.pyc │ └── network.cpython-37.pyc ├── cardiac_utils.py ├── deploy_network.py ├── image_utils.py ├── network.py ├── squeezeNiis.py ├── train_network_UW.py └── train_network_UW_fine_tune.py ├── computeDiceAll.py ├── matlab_scripts ├── cnnSeg2Segment.m ├── computeLVFlowCompartments.m ├── computeRVFlowCompartments.m ├── computeVentricularKE.m ├── pcvipr2nii.m ├── resizeNii.m └── segment2nii.m ├── model └── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001 │ ├── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.data-00000-of-00001 │ ├── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.index │ ├── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.meta │ └── checkpoint ├── modelFT └── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001 │ ├── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.data-00000-of-00001 │ ├── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.index │ ├── FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.meta │ └── checkpoint ├── registration └── register_SA_mask_to_Flow_images ├── segmentAll.py ├── third_party ├── src │ ├── CMakeLists.txt │ └── average_3d_ffd.cc ├── ubuntu_16.04_bin │ └── average_3d_ffd └── ubuntu_18.04_bin │ └── average_3d_ffd └── ukbb_trained_model ├── FCN_sa.data-00000-of-00001 ├── FCN_sa.index └── FCN_sa.meta /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. 
Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2017, Wenjia Bai 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | This code is based on a toolbox developed by Wenjia Bai (https://github.com/baiwenjia/ukbb_cardiac) for processing cardiovascular magnetic resonance (CMR) images. 4 | 5 | I have modified the code for "fine-tuning": loading the weights obtained by training on a large data set (UK Biobank) and then training them further for a small number of iterations on data acquired at my local institution. There are also additional scripts for registering bSSFP images to 4D flow images and for computing kinetic energy and flow components from masked 4D flow MRI images. The registration scripts use the ANTs toolbox (http://stnava.github.io/ANTs/) and the analysis scripts run in MATLAB. 6 | 7 | 8 | ## Segmentation Toolbox Installation 9 | 10 | The segmentation toolbox is developed using Python 3.
11 | 12 | The toolbox depends on some external libraries which need to be installed, including: 13 | 14 | * tensorflow for deep learning; 15 | * numpy and scipy for numerical computation; 16 | * pandas and python-dateutil for handling spreadsheets; 17 | * pydicom and SimpleITK for handling DICOM images; 18 | * nibabel for reading and writing NIfTI images; 19 | * opencv-python for transforming images in data augmentation. 20 | 21 | The most convenient way to install these libraries is to use pip3 (or pip for Python 2) by running this command in the terminal: 22 | ``` 23 | pip3 install tensorflow-gpu numpy scipy pandas python-dateutil pydicom SimpleITK nibabel opencv-python 24 | ``` 25 | 26 | To use the toolbox, please add the github repository directory to your $PYTHONPATH environment variable, so that the ukbb_cardiac module can be imported and cross-referenced in its code. If you are using Linux, you can run this command: 27 | ``` 28 | export PYTHONPATH=YOUR_GIT_REPOSITORY_PATH:"${PYTHONPATH}" 29 | ``` 30 | 31 | There is one parameter in the script, *CUDA_VISIBLE_DEVICES*, which controls which GPU device to use on your machine. Currently, I set it to 0, which means the first GPU on your machine.
32 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/__init__.py -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/cnnSegmentShortAxis.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/__pycache__/cnnSegmentShortAxis.cpython-36.pyc -------------------------------------------------------------------------------- /cnnSegmentShortAxis.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import random 5 | import numpy as np 6 | import nibabel as nib 7 | import tensorflow as tf 8 | import math 9 | from scipy.ndimage import zoom 10 | sys.path.insert(1, '/export/home/pcorrado/CODE/') 11 | 12 | print(sys.path) 13 | from ukbb_cardiac.common.network import build_FCN 14 | from ukbb_cardiac.common.image_utils import tf_categorical_accuracy, tf_categorical_dice 15 | from 
ukbb_cardiac.common.image_utils import crop_image, rescale_intensity, data_augmenter 16 | from ukbb_cardiac.common import deploy_network 17 | 18 | if os.path.exists("sa.nii.gz"): 19 | nim = nib.load("sa.nii.gz") 20 | elif os.path.exists("sa.nii"): 21 | nim = nib.load("sa.nii") 22 | orig_image = nim.get_data() 23 | (X1,Y1,Z1,T1) = orig_image.shape 24 | 25 | if X1 != Y1: 26 | print("Image is not square, exiting.") 27 | else: 28 | print('Pre-zoom image shape') 29 | print(orig_image.shape) 30 | image = zoom(orig_image,(192.0/X1,192.0/Y1,1,1),order=1) 31 | print('Post-zoom image shape') 32 | print(image.shape) 33 | os.system('CUDA_VISIBLE_DEVICES=1') 34 | model_path = os.path.join(os.path.dirname(__file__), "modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt") 35 | with tf.Session() as sess: 36 | sess.run(tf.global_variables_initializer()) 37 | 38 | # Import the computation graph and restore the variable values 39 | saver = tf.train.import_meta_graph('{0}.meta'.format(model_path)) 40 | saver.restore(sess, '{0}'.format(model_path)) 41 | 42 | print('Start deployment on the data set ...') 43 | start_time = time.time() 44 | 45 | X, Y, Z, T = image.shape 46 | 47 | print(' Segmenting full sequence ...') 48 | start_seg_time = time.time() 49 | 50 | # Intensity rescaling 51 | image = rescale_intensity(image, (1, 99)) 52 | 53 | # Prediction (segmentation) 54 | pred = np.zeros(image.shape) 55 | 56 | # Pad the image size to be a factor of 16 so that the 57 | # downsample and upsample procedures in the network will 58 | # result in the same image size at each resolution level. 
59 | X2, Y2 = int(math.ceil(X / 16.0)) * 16, int(math.ceil(Y / 16.0)) * 16 60 | x_pre, y_pre = int((X2 - X) / 2), int((Y2 - Y) / 2) 61 | x_post, y_post = (X2 - X) - x_pre, (Y2 - Y) - y_pre 62 | image = np.pad(image, ((x_pre, x_post), (y_pre, y_post), (0, 0), (0, 0)), 'constant') 63 | 64 | # Process each time frame 65 | for t in range(T): 66 | # Transpose the shape to NXYC 67 | image_fr = image[:, :, :, t] 68 | image_fr = np.transpose(image_fr, axes=(2, 0, 1)).astype(np.float32) 69 | image_fr = np.expand_dims(image_fr, axis=-1) 70 | 71 | # Evaluate the network 72 | prob_fr, pred_fr = sess.run(['prob:0', 'pred:0'], 73 | feed_dict={'image:0': image_fr, 'training:0': False}) 74 | 75 | # Transpose and crop segmentation to recover the original size 76 | pred_fr = np.transpose(pred_fr, axes=(1, 2, 0)) 77 | pred_fr = pred_fr[x_pre:x_pre + X, y_pre:y_pre + Y] 78 | pred[:, :, :, t] = pred_fr 79 | 80 | seg_time = time.time() - start_seg_time 81 | print(' Segmentation time = {:3f}s'.format(seg_time)) 82 | 83 | pred = zoom(pred,(X1/X,Y1/Y,1,1),order=0) 84 | rvMask = pred==3 85 | pred[rvMask]=-1 86 | 87 | # Save the segmentation 88 | print(' Saving segmentation ...') 89 | nim2 = nib.Nifti1Image(pred, nim.affine) 90 | nim2.header['pixdim'] = nim.header['pixdim'] 91 | seg_name = './seg_sa.nii.gz' 92 | nib.save(nim2, seg_name) 93 | -------------------------------------------------------------------------------- /common/__pycache__/cardiac_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/common/__pycache__/cardiac_utils.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/deploy_network.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/common/__pycache__/deploy_network.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/deploy_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/common/__pycache__/deploy_network.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/image_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/common/__pycache__/image_utils.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/image_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/common/__pycache__/image_utils.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/common/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/common/__pycache__/network.cpython-37.pyc 
-------------------------------------------------------------------------------- /common/cardiac_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019, Wenjia Bai. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import os 16 | import math 17 | import numpy as np 18 | import nibabel as nib 19 | import cv2 20 | import vtk 21 | import pandas as pd 22 | import matplotlib.pyplot as plt 23 | from vtk.util import numpy_support 24 | from scipy import interpolate 25 | import skimage 26 | from ukbb_cardiac.common.image_utils import * 27 | 28 | 29 | def approximate_contour(contour, factor=4, smooth=0.05, periodic=False): 30 | """ Approximate a contour. 31 | 32 | contour: input contour 33 | factor: upsampling factor for the contour 34 | smooth: smoothing factor for controlling the number of spline knots. 35 | Number of knots will be increased until the smoothing 36 | condition is satisfied: 37 | sum((w[i] * (y[i]-spl(x[i])))**2, axis=0) <= s 38 | which means the larger s is, the fewer knots will be used, 39 | thus the contour will be smoother but will also deviate more 40 | from the input contour. 41 | periodic: set to True if this is a closed contour, otherwise False.
42 | 43 | return the upsampled and smoothed contour 44 | """ 45 | # The input contour 46 | N = len(contour) 47 | dt = 1.0 / N 48 | t = np.arange(N) * dt 49 | x = contour[:, 0] 50 | y = contour[:, 1] 51 | 52 | # Pad the contour before approximation to avoid underestimating 53 | # the values at the end points 54 | r = int(0.5 * N) 55 | t_pad = np.concatenate((np.arange(-r, 0) * dt, t, 1 + np.arange(0, r) * dt)) 56 | if periodic: 57 | x_pad = np.concatenate((x[-r:], x, x[:r])) 58 | y_pad = np.concatenate((y[-r:], y, y[:r])) 59 | else: 60 | x_pad = np.concatenate((np.repeat(x[0], repeats=r), x, np.repeat(x[-1], repeats=r))) 61 | y_pad = np.concatenate((np.repeat(y[0], repeats=r), y, np.repeat(y[-1], repeats=r))) 62 | 63 | # Fit the contour with splines with a smoothness constraint 64 | fx = interpolate.UnivariateSpline(t_pad, x_pad, s=smooth * len(t_pad)) 65 | fy = interpolate.UnivariateSpline(t_pad, y_pad, s=smooth * len(t_pad)) 66 | 67 | # Evaluate the new contour 68 | N2 = N * factor 69 | dt2 = 1.0 / N2 70 | t2 = np.arange(N2) * dt2 71 | x2, y2 = fx(t2), fy(t2) 72 | contour2 = np.stack((x2, y2), axis=1) 73 | return contour2 74 | 75 | 76 | def sa_pass_quality_control(seg_sa_name): 77 | """ Quality control for short-axis image segmentation """ 78 | nim = nib.load(seg_sa_name) 79 | seg_sa = nim.get_data() 80 | X, Y, Z = seg_sa.shape[:3] 81 | 82 | # Label class in the segmentation 83 | label = {'LV': 1, 'Myo': 2, 'RV': 3} 84 | 85 | # Criterion 1: every class exists and the area is above a threshold 86 | # Count number of pixels in 3D 87 | for l_name, l in label.items(): 88 | pixel_thres = 10 89 | if np.sum(seg_sa == l) < pixel_thres: 90 | print('{0}: The segmentation for class {1} is smaller than {2} pixels. 
' 91 | 'It does not pass the quality control.'.format(seg_sa_name, l_name, pixel_thres)) 92 | return False 93 | 94 | # Criterion 2: number of slices with LV segmentations is above a threshold 95 | # and there is no missing segmentation in between the slices 96 | z_pos = [] 97 | for z in range(Z): 98 | seg_z = seg_sa[:, :, z] 99 | endo = (seg_z == label['LV']).astype(np.uint8) 100 | myo = (seg_z == label['Myo']).astype(np.uint8) 101 | pixel_thres = 10 102 | if (np.sum(endo) < pixel_thres) or (np.sum(myo) < pixel_thres): 103 | continue 104 | z_pos += [z] 105 | n_slice = len(z_pos) 106 | slice_thres = 6 107 | if n_slice < slice_thres: 108 | print('{0}: The segmentation has less than {1} slices. ' 109 | 'It does not pass the quality control.'.format(seg_sa_name, slice_thres)) 110 | return False 111 | 112 | if n_slice != (np.max(z_pos) - np.min(z_pos) + 1): 113 | print('{0}: There is missing segmentation between the slices. ' 114 | 'It does not pass the quality control.'.format(seg_sa_name)) 115 | return False 116 | 117 | # Criterion 3: LV and RV exists on the mid-cavity slice 118 | _, _, cz = [np.mean(x) for x in np.nonzero(seg_sa == label['LV'])] 119 | z = int(round(cz)) 120 | seg_z = seg_sa[:, :, z] 121 | 122 | endo = (seg_z == label['LV']).astype(np.uint8) 123 | endo = get_largest_cc(endo).astype(np.uint8) 124 | myo = (seg_z == label['Myo']).astype(np.uint8) 125 | myo = remove_small_cc(myo).astype(np.uint8) 126 | epi = (endo | myo).astype(np.uint8) 127 | epi = get_largest_cc(epi).astype(np.uint8) 128 | rv = (seg_z == label['RV']).astype(np.uint8) 129 | rv = get_largest_cc(rv).astype(np.uint8) 130 | pixel_thres = 10 131 | if np.sum(epi) < pixel_thres or np.sum(rv) < pixel_thres: 132 | print('{0}: Can not find LV epi or RV to determine the AHA ' 133 | 'coordinate system.'.format(seg_sa_name)) 134 | return False 135 | return True 136 | 137 | 138 | def la_pass_quality_control(seg_la_name): 139 | """ Quality control for long-axis image segmentation """ 140 | nim = 
nib.load(seg_la_name) 141 | seg = nim.get_data() 142 | X, Y, Z = seg.shape[:3] 143 | seg_z = seg[:, :, 0] 144 | 145 | # Label class in the segmentation 146 | label = {'LV': 1, 'Myo': 2, 'RV': 3, 'LA': 4, 'RA': 5} 147 | 148 | for l_name, l in label.items(): 149 | # Criterion 1: every class exists and the area is above a threshold 150 | pixel_thres = 10 151 | if np.sum(seg_z == l) < pixel_thres: 152 | print('{0}: The segmentation for class {1} is smaller than {2} pixels. ' 153 | 'It does not pass the quality control.'.format(seg_la_name, l_name, pixel_thres)) 154 | return False 155 | 156 | # Criterion 2: the area is above a threshold after connected component analysis 157 | endo = (seg_z == label['LV']).astype(np.uint8) 158 | endo = get_largest_cc(endo).astype(np.uint8) 159 | myo = (seg_z == label['Myo']).astype(np.uint8) 160 | myo = remove_small_cc(myo).astype(np.uint8) 161 | epi = (endo | myo).astype(np.uint8) 162 | epi = get_largest_cc(epi).astype(np.uint8) 163 | pixel_thres = 10 164 | if np.sum(endo) < pixel_thres or np.sum(myo) < pixel_thres or np.sum(epi) < pixel_thres: 165 | print('{0}: Can not find LV endo, myo or epi to extract the long-axis ' 166 | 'myocardial contour.'.format(seg_la_name)) 167 | return False 168 | return True 169 | 170 | 171 | def determine_aha_coordinate_system(seg_sa, affine_sa): 172 | """ Determine the AHA coordinate system using the mid-cavity slice 173 | of the short-axis image segmentation. 
174 | """ 175 | # Label class in the segmentation 176 | label = {'BG': 0, 'LV': 1, 'Myo': 2, 'RV': 3} 177 | 178 | # Find the mid-cavity slice 179 | _, _, cz = [np.mean(x) for x in np.nonzero(seg_sa == label['LV'])] 180 | z = int(round(cz)) 181 | seg_z = seg_sa[:, :, z] 182 | 183 | endo = (seg_z == label['LV']).astype(np.uint8) 184 | endo = get_largest_cc(endo).astype(np.uint8) 185 | myo = (seg_z == label['Myo']).astype(np.uint8) 186 | myo = remove_small_cc(myo).astype(np.uint8) 187 | epi = (endo | myo).astype(np.uint8) 188 | epi = get_largest_cc(epi).astype(np.uint8) 189 | rv = (seg_z == label['RV']).astype(np.uint8) 190 | rv = get_largest_cc(rv).astype(np.uint8) 191 | 192 | # Extract epicardial contour 193 | _, contours, _ = cv2.findContours(cv2.inRange(epi, 1, 1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 194 | epi_contour = contours[0][:, 0, :] 195 | 196 | # Find the septum, which is the intersection between LV and RV 197 | septum = [] 198 | dilate_iter = 1 199 | while len(septum) == 0: 200 | # Dilate the RV till it intersects with LV epicardium. 201 | # Normally, this is fulfilled after just one iteration. 
def determine_aha_part(seg_sa, affine_sa, three_slices=False):
    """ Determine the AHA part for each slice.

    seg_sa is the short-axis segmentation indexed as [x, y, z] and affine_sa
    is its 4 x 4 voxel-to-world matrix. Only slices in which both the LV
    cavity and the myocardium cover at least 10 pixels are considered.

    Returns a dict which maps the slice index z to 'basal', 'mid' or 'apical'.
    If three_slices is True, only three representative slices are labelled,
    otherwise every valid slice is assigned to one of the three parts.
    An empty dict is returned if there is no valid slice.
    """
    # Label class in the segmentation
    label = {'BG': 0, 'LV': 1, 'Myo': 2, 'RV': 3}
    pixel_thres = 10

    # Sort the z-axis positions of the slices with both endo and epicardium
    # segmentations
    X, Y, Z = seg_sa.shape[:3]
    z_pos = []
    for z in range(Z):
        seg_z = seg_sa[:, :, z]
        n_endo = np.sum(seg_z == label['LV'])
        n_myo = np.sum(seg_z == label['Myo'])
        if (n_endo < pixel_thres) or (n_myo < pixel_thres):
            continue
        # World z coordinate of the centre of this slice
        z_world = np.dot(affine_sa, np.array([X / 2.0, Y / 2.0, z, 1]))[2]
        z_pos.append((z, z_world))

    # Descending world z: the code below treats the first entries as basal
    # and the last entries as apical.
    z_pos.sort(key=lambda x: -x[1])

    # Divide the slices into three parts: basal, mid-cavity and apical
    n_slice = len(z_pos)
    part_z = {}
    if n_slice == 0:
        # Nothing to label. Without this guard the three_slices branch would
        # index into an empty list.
        return part_z

    if three_slices:
        # Select three slices (basal, mid and apical) for strain analysis, inspired by:
        #
        # [1] Robin J. Taylor, et al. Myocardial strain measurement with
        #     feature-tracking cardiovascular magnetic resonance: normal values.
        #     European Heart Journal - Cardiovascular Imaging, (2015) 16, 871-881.
        #
        # [2] A. Schuster, et al. Cardiovascular magnetic resonance feature-
        #     tracking assessment of myocardial mechanics: Intervendor agreement
        #     and considerations regarding reproducibility. Clinical Radiology
        #     70 (2015), 989-998.
        #
        # 25%: avoid the first one or two basal slices, as the myocardium
        #      will move out of plane at ES due to longitudinal motion, which
        #      will be a problem for 2D in-plane motion tracking.
        # 50%: the central slice.
        # 75%: in the most apical slices, the myocardium looks blurry and
        #      may not be suitable for motion tracking.
        for fraction, name in ((0.25, 'basal'), (0.5, 'mid'), (0.75, 'apical')):
            k = int(round((n_slice - 1) * fraction))
            part_z[z_pos[k][0]] = name
    else:
        # Use all the slices
        i1 = int(math.ceil(n_slice / 3.0))
        i2 = int(math.ceil(2 * n_slice / 3.0))
        for i in range(n_slice):
            if i < i1:
                name = 'basal'
            elif i < i2:
                name = 'mid'
            else:
                name = 'apical'
            part_z[z_pos[i][0]] = name
    return part_z


def determine_aha_segment_id(point, lv_centre, aha_axis, part):
    """ Determine the AHA segment ID given a point,
    the LV cavity center and the coordinate system.

    point and lv_centre are 3D coordinates, aha_axis is a dict providing the
    'inf_to_ant' and 'lv_to_sep' direction vectors and part is one of
    'basal', 'mid', 'apical' or 'apex'.

    Returns the segment ID (1 to 17).
    Raises ValueError for an unknown part or an angle which does not fall into
    any sector (e.g. NaN coordinates).
    """
    d = point - lv_centre
    x = np.dot(d, aha_axis['inf_to_ant'])
    y = np.dot(d, aha_axis['lv_to_sep'])
    deg = math.degrees(math.atan2(y, x))

    if part == 'basal' or part == 'mid':
        # Six sectors of 60 degrees, the first one centred at 0 degree.
        if (deg >= -30) and (deg < 30):
            sector = 0
        elif (deg >= 30) and (deg < 90):
            sector = 1
        elif (deg >= 90) and (deg < 150):
            sector = 2
        elif (deg >= 150) or (deg < -150):
            sector = 3
        elif (deg >= -150) and (deg < -90):
            sector = 4
        elif (deg >= -90) and (deg < -30):
            sector = 5
        else:
            raise ValueError('Error: wrong degree {0}!'.format(deg))
        # Basal segments are 1 to 6, mid-cavity segments are 7 to 12.
        first_id = 1 if part == 'basal' else 7
        return first_id + sector

    if part == 'apical':
        # Four sectors of 90 degrees, segments 13 to 16.
        if (deg >= -45) and (deg < 45):
            return 13
        elif (deg >= 45) and (deg < 135):
            return 14
        elif (deg >= 135) or (deg < -135):
            return 15
        elif (deg >= -135) and (deg < -45):
            return 16
        raise ValueError('Error: wrong degree {0}!'.format(deg))

    if part == 'apex':
        return 17

    raise ValueError('Error: unknown part {0}!'.format(part))
401 | seg_z = seg[:, :, z] 402 | endo = (seg_z == label['LV']).astype(np.uint8) 403 | endo = get_largest_cc(endo).astype(np.uint8) 404 | myo = (seg_z == label['Myo']).astype(np.uint8) 405 | myo = remove_small_cc(myo).astype(np.uint8) 406 | epi = (endo | myo).astype(np.uint8) 407 | epi = get_largest_cc(epi).astype(np.uint8) 408 | pixel_thres = 10 409 | if (np.sum(endo) < pixel_thres) or (np.sum(myo) < pixel_thres): 410 | continue 411 | 412 | # Calculate the centre of the LV cavity 413 | # Get the largest component in case we have a bad segmentation 414 | cx, cy = [np.mean(x) for x in np.nonzero(endo)] 415 | lv_centre = np.dot(affine, np.array([cx, cy, z, 1]))[:3] 416 | 417 | # Extract endocardial contour 418 | # Note: cv2 considers an input image as a Y x X array, which is different 419 | # from nibabel which assumes a X x Y array. 420 | _, contours, _ = cv2.findContours(cv2.inRange(endo, 1, 1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 421 | endo_contour = contours[0][:, 0, :] 422 | 423 | # Extract epicardial contour 424 | _, contours, _ = cv2.findContours(cv2.inRange(epi, 1, 1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 425 | epi_contour = contours[0][:, 0, :] 426 | 427 | # Smooth the contours 428 | endo_contour = approximate_contour(endo_contour, periodic=True) 429 | epi_contour = approximate_contour(epi_contour, periodic=True) 430 | 431 | # A polydata representation of the epicardial contour 432 | epi_points_z = vtk.vtkPoints() 433 | for y, x in epi_contour: 434 | p = np.dot(affine, np.array([x, y, z, 1]))[:3] 435 | epi_points_z.InsertNextPoint(p) 436 | epi_poly_z = vtk.vtkPolyData() 437 | epi_poly_z.SetPoints(epi_points_z) 438 | 439 | # Point locator for the epicardial contour 440 | locator = vtk.vtkPointLocator() 441 | locator.SetDataSet(epi_poly_z) 442 | locator.BuildLocator() 443 | 444 | # For each point on endocardium, find the closest point on epicardium 445 | N = endo_contour.shape[0] 446 | for i in range(N): 447 | y, x = endo_contour[i] 448 | 449 | # The world 
coordinate of this point 450 | p = np.dot(affine, np.array([x, y, z, 1]))[:3] 451 | endo_points.InsertNextPoint(p) 452 | 453 | # The closest epicardial point 454 | q = np.array(epi_points_z.GetPoint(locator.FindClosestPoint(p))) 455 | 456 | # The distance from endo to epi 457 | dist_pq = np.linalg.norm(q - p) 458 | 459 | # Add the point data 460 | thickness.InsertNextTuple1(dist_pq) 461 | seg_id = determine_aha_segment_id(p, lv_centre, aha_axis, part_z[z]) 462 | points_aha.InsertNextTuple1(seg_id) 463 | 464 | # Record the first point of the current contour 465 | if i == 0: 466 | contour_start_id = point_id 467 | 468 | # Add the line 469 | if i == (N - 1): 470 | lines.InsertNextCell(2, [point_id, contour_start_id]) 471 | else: 472 | lines.InsertNextCell(2, [point_id, point_id + 1]) 473 | 474 | # Increment the point index 475 | point_id += 1 476 | 477 | if save_epi_contour: 478 | # For each point on epicardium 479 | N = epi_contour.shape[0] 480 | for i in range(N): 481 | y, x = epi_contour[i] 482 | 483 | # The world coordinate of this point 484 | p = np.dot(affine, np.array([x, y, z, 1]))[:3] 485 | epi_points.InsertNextPoint(p) 486 | seg_id = determine_aha_segment_id(p, lv_centre, aha_axis, part_z[z]) 487 | points_epi_aha.InsertNextTuple1(seg_id) 488 | 489 | # Record the first point of the current contour 490 | if i == 0: 491 | contour_start_id = point_epi_id 492 | 493 | # Add the line 494 | if i == (N - 1): 495 | lines_epi.InsertNextCell(2, [point_epi_id, contour_start_id]) 496 | else: 497 | lines_epi.InsertNextCell(2, [point_epi_id, point_epi_id + 1]) 498 | 499 | # Increment the point index 500 | point_epi_id += 1 501 | 502 | # Save to a vtk file 503 | endo_poly = vtk.vtkPolyData() 504 | endo_poly.SetPoints(endo_points) 505 | endo_poly.GetPointData().AddArray(thickness) 506 | endo_poly.GetPointData().AddArray(points_aha) 507 | endo_poly.SetLines(lines) 508 | 509 | writer = vtk.vtkPolyDataWriter() 510 | output_name = '{0}.vtk'.format(output_name_stem) 511 | 
def extract_myocardial_contour(seg_name, contour_name_stem, part=None, three_slices=False):
    """ Extract the myocardial contours, including both endo and epicardial contours.
        Determine the AHA segment ID for all the contour points.

        By default, part is None. This function will automatically determine the part
        for each slice (basal, mid or apical).
        If part is given, this function will use the given part for the image slice.

        For each analysed slice z, a polydata file named
        '<contour_name_stem><z as two digits>.vtk' is written. It contains the
        epicardial points followed by the endocardial points, the
        circumferential lines along both contours and radial lines which
        connect the endocardium to the epicardium.
    """
    # Read the segmentation image
    nim = nib.load(seg_name)
    Z = nim.header['dim'][3]
    affine = nim.affine
    # get_data() is deprecated in nibabel; this is its documented replacement
    seg = np.asanyarray(nim.dataobj)

    # Label class in the segmentation
    label = {'BG': 0, 'LV': 1, 'Myo': 2, 'RV': 3}

    # Determine the AHA coordinate system using the mid-cavity slice
    aha_axis = determine_aha_coordinate_system(seg, affine)

    # Determine the AHA part of each slice
    if not part:
        part_z = determine_aha_part(seg, affine, three_slices=three_slices)
    else:
        part_z = {z: part for z in range(Z)}

    pixel_thres = 10
    # Approximate number of radial lines per slice
    n_radial = 36
    # Number of closest epicardial points examined for each radial line
    n_ids = 10

    # For each slice
    for z in range(Z):
        # Check whether there is the endocardial segmentation
        seg_z = seg[:, :, z]
        endo = (seg_z == label['LV']).astype(np.uint8)
        endo = get_largest_cc(endo).astype(np.uint8)
        myo = (seg_z == label['Myo']).astype(np.uint8)
        myo = remove_small_cc(myo).astype(np.uint8)
        epi = (endo | myo).astype(np.uint8)
        epi = get_largest_cc(epi).astype(np.uint8)
        if (np.sum(endo) < pixel_thres) or (np.sum(myo) < pixel_thres):
            continue

        # Check whether this slice is going to be analysed
        if z not in part_z:
            continue

        # Construct the points set and data arrays to represent both endo and epicardial contours
        points = vtk.vtkPoints()
        points_radial = vtk.vtkFloatArray()
        points_radial.SetName('Direction_Radial')
        points_radial.SetNumberOfComponents(3)
        points_label = vtk.vtkIntArray()
        points_label.SetName('Label')
        points_aha = vtk.vtkIntArray()
        points_aha.SetName('Segment ID')
        point_id = 0

        lines = vtk.vtkCellArray()
        lines_aha = vtk.vtkIntArray()
        lines_aha.SetName('Segment ID')
        lines_dir = vtk.vtkIntArray()
        lines_dir.SetName('Direction ID')

        # Calculate the centre of the LV cavity
        # Get the largest component in case we have a bad segmentation
        cx, cy = [np.mean(x) for x in np.nonzero(endo)]
        lv_centre = np.dot(affine, np.array([cx, cy, z, 1]))[:3]

        # Extract epicardial contour
        # findContours returns (image, contours, hierarchy) in OpenCV 3.x but
        # (contours, hierarchy) in OpenCV 2.x / 4.x; [-2] is the contour list in both.
        contours = cv2.findContours(cv2.inRange(epi, 1, 1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2]
        epi_contour = contours[0][:, 0, :]
        epi_contour = approximate_contour(epi_contour, periodic=True)

        N = epi_contour.shape[0]
        for i in range(N):
            y, x = epi_contour[i]

            # The world coordinate of this point
            p = np.dot(affine, np.array([x, y, z, 1]))[:3]
            points.InsertNextPoint(p[0], p[1], p[2])

            # The radial direction from the cavity centre to this point
            d_rad = p - lv_centre
            d_rad = d_rad / np.linalg.norm(d_rad)
            points_radial.InsertNextTuple3(d_rad[0], d_rad[1], d_rad[2])

            # Record the type of the point (1 = endo, 2 = epi)
            points_label.InsertNextTuple1(2)

            # Record the AHA segment ID
            seg_id = determine_aha_segment_id(p, lv_centre, aha_axis, part_z[z])
            points_aha.InsertNextTuple1(seg_id)

            # Record the first point of the current contour
            if i == 0:
                contour_start_id = point_id

            # Add the circumferential line, closing the contour at the last point
            if i == (N - 1):
                lines.InsertNextCell(2, [point_id, contour_start_id])
            else:
                lines.InsertNextCell(2, [point_id, point_id + 1])

            lines_aha.InsertNextTuple1(seg_id)

            # Line direction (1 = radial, 2 = circumferential, 3 = longitudinal)
            lines_dir.InsertNextTuple1(2)

            # Increment the point index
            point_id += 1

        # Point locator
        # At this stage, points only contains the epicardial points.
        epi_points = vtk.vtkPoints()
        epi_points.DeepCopy(points)
        epi_poly = vtk.vtkPolyData()
        epi_poly.SetPoints(epi_points)
        locator = vtk.vtkPointLocator()
        locator.SetDataSet(epi_poly)
        locator.BuildLocator()

        # Extract endocardial contour
        # Note: cv2 considers an input image as a Y x X array, which is different
        # from nibabel which assumes a X x Y array.
        contours = cv2.findContours(cv2.inRange(endo, 1, 1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2]
        endo_contour = contours[0][:, 0, :]
        endo_contour = approximate_contour(endo_contour, periodic=True)

        N = endo_contour.shape[0]
        # Add a radial line for every M-th endocardial point. M must be at
        # least 1, otherwise a short contour (N < 18) would result in a
        # modulo by zero.
        M = max(1, int(round(N / float(n_radial))))
        for i in range(N):
            y, x = endo_contour[i]

            # The world coordinate of this point
            p = np.dot(affine, np.array([x, y, z, 1]))[:3]
            points.InsertNextPoint(p[0], p[1], p[2])

            # The radial direction from the cavity centre to this point
            d_rad = p - lv_centre
            d_rad = d_rad / np.linalg.norm(d_rad)
            points_radial.InsertNextTuple3(d_rad[0], d_rad[1], d_rad[2])

            # Record the type of the point (1 = endo, 2 = epi)
            points_label.InsertNextTuple1(1)

            # Record the AHA segment ID
            seg_id = determine_aha_segment_id(p, lv_centre, aha_axis, part_z[z])
            points_aha.InsertNextTuple1(seg_id)

            # Record the first point of the current contour
            if i == 0:
                contour_start_id = point_id

            # Add the circumferential line, closing the contour at the last point
            if i == (N - 1):
                lines.InsertNextCell(2, [point_id, contour_start_id])
            else:
                lines.InsertNextCell(2, [point_id, point_id + 1])

            lines_aha.InsertNextTuple1(seg_id)

            # Line direction (1 = radial, 2 = circumferential, 3 = longitudinal)
            lines_dir.InsertNextTuple1(2)

            # Add the radial line for every few points
            if i % M == 0:
                # The closest epicardial points
                ids = vtk.vtkIdList()
                locator.FindClosestNPoints(n_ids, p, ids)

                # Only read the ids which were actually filled in, so that
                # the list is never indexed beyond its length.
                n_found = ids.GetNumberOfIds()
                if n_found > 0:
                    # The point that aligns with the radial direction
                    val = []
                    for j in range(n_found):
                        q = np.array(epi_points.GetPoint(ids.GetId(j)))
                        d = (q - lv_centre) / np.linalg.norm(q - lv_centre)
                        val.append(np.dot(d, d_rad))
                    epi_point_id = ids.GetId(int(np.argmax(val)))

                    # Add the radial line
                    lines.InsertNextCell(2, [point_id, epi_point_id])
                    lines_aha.InsertNextTuple1(seg_id)

                    # Line direction (1 = radial, 2 = circumferential, 3 = longitudinal)
                    lines_dir.InsertNextTuple1(1)

            # Increment the point index
            point_id += 1

        # Save the contour for each slice
        poly = vtk.vtkPolyData()
        poly.SetPoints(points)
        poly.GetPointData().AddArray(points_label)
        poly.GetPointData().AddArray(points_aha)
        poly.GetPointData().AddArray(points_radial)
        poly.SetLines(lines)
        poly.GetCellData().AddArray(lines_aha)
        poly.GetCellData().AddArray(lines_dir)

        writer = vtk.vtkPolyDataWriter()
        contour_name = '{0}{1:02d}.vtk'.format(contour_name_stem, z)
        writer.SetFileName(contour_name)
        writer.SetInputData(poly)
        writer.Write()
        # Rewrite the version number in the first line of the file from 4.1 to 4.0
        # NOTE(review): presumably needed by the downstream tools which read
        # these files; confirm before removing.
        os.system('sed -i "1s/4.1/4.0/" {0}'.format(contour_name))
773 | d = np.linalg.norm(p1 - p2) 774 | seg_id[i] = lines_aha.GetValue(i) 775 | dir_id[i] = lines_dir.GetValue(i) 776 | length_ED[i] = d 777 | 778 | # For each time frame, calculate the strain, i.e. change of length 779 | table_strain = {} 780 | table_strain['radial'] = np.zeros((17, T)) 781 | table_strain['circum'] = np.zeros((17, T)) 782 | 783 | for fr in range(0, T): 784 | # Read the polydata 785 | reader = vtk.vtkPolyDataReader() 786 | filename = '{0}{1:02d}.vtk'.format(contour_name_stem, fr) 787 | reader.SetFileName(filename) 788 | reader.Update() 789 | poly = reader.GetOutput() 790 | points = poly.GetPoints() 791 | 792 | # Calculate the strain for each line 793 | lines = poly.GetLines() 794 | n_lines = lines.GetNumberOfCells() 795 | strain = np.zeros(n_lines) 796 | vtk_strain = vtk.vtkFloatArray() 797 | vtk_strain.SetName('Strain') 798 | lines.InitTraversal() 799 | for i in range(n_lines): 800 | ids = vtk.vtkIdList() 801 | lines.GetNextCell(ids) 802 | p1 = np.array(points.GetPoint(ids.GetId(0))) 803 | p2 = np.array(points.GetPoint(ids.GetId(1))) 804 | d = np.linalg.norm(p1 - p2) 805 | 806 | # Strain of this line (unit: %) 807 | strain[i] = (d - length_ED[i]) / length_ED[i] * 100 808 | vtk_strain.InsertNextTuple1(strain[i]) 809 | 810 | # Save the strain array to the vtk file 811 | poly.GetCellData().AddArray(vtk_strain) 812 | writer = vtk.vtkPolyDataWriter() 813 | writer.SetInputData(poly) 814 | writer.SetFileName(filename) 815 | writer.Write() 816 | os.system('sed -i "1s/4.1/4.0/" {0}'.format(filename)) 817 | 818 | # Calculate the segmental and global strains 819 | for i in range(0, 16): 820 | table_strain['radial'][i, fr] = np.mean(strain[(seg_id == (i + 1)) & (dir_id == 1)]) 821 | table_strain['circum'][i, fr] = np.mean(strain[(seg_id == (i + 1)) & (dir_id == 2)]) 822 | table_strain['radial'][-1, fr] = np.mean(strain[dir_id == 1]) 823 | table_strain['circum'][-1, fr] = np.mean(strain[dir_id == 2]) 824 | 825 | for c in ['radial', 'circum']: 826 | # Save into 
def cine_2d_sa_motion_and_strain_analysis(data_dir, par_dir, output_dir, output_name_stem):
    """ Perform motion tracking and strain analysis for cine MR images.

    data_dir provides sa.nii.gz, seg_sa.nii.gz and seg_sa_ED.nii.gz,
    par_dir provides the registration configuration ffd_cine_2d_motion.cfg,
    output_dir receives all the intermediate files and output_name_stem is
    passed on to evaluate_strain_by_length.

    The external commands (mirtk, average_3d_ffd, cp) are run via os.system
    and their exit status is not checked.
    """
    # Crop the image to save computation for image registration
    # Focus on the left ventricle so that motion tracking is less affected by
    # the movement of RV and LV outflow tract
    padding('{0}/seg_sa_ED.nii.gz'.format(data_dir),
            '{0}/seg_sa_ED.nii.gz'.format(data_dir),
            '{0}/seg_sa_lv_ED.nii.gz'.format(output_dir), 3, 0)
    auto_crop_image('{0}/seg_sa_lv_ED.nii.gz'.format(output_dir),
                    '{0}/seg_sa_lv_crop_ED.nii.gz'.format(output_dir), 20)
    os.system('mirtk transform-image {0}/sa.nii.gz {1}/sa_crop.nii.gz '
              '-target {1}/seg_sa_lv_crop_ED.nii.gz'.format(data_dir, output_dir))
    os.system('mirtk transform-image {0}/seg_sa.nii.gz {1}/seg_sa_crop.nii.gz '
              '-target {1}/seg_sa_lv_crop_ED.nii.gz'.format(data_dir, output_dir))

    # Extract the myocardial contours for three slices, respectively basal, mid-cavity and apical
    extract_myocardial_contour('{0}/seg_sa_ED.nii.gz'.format(data_dir),
                               '{0}/myo_contour_ED_z'.format(output_dir),
                               three_slices=True)

    # Split the volume into slices
    split_volume('{0}/sa_crop.nii.gz'.format(output_dir), '{0}/sa_crop_z'.format(output_dir))
    split_volume('{0}/seg_sa_crop.nii.gz'.format(output_dir), '{0}/seg_sa_crop_z'.format(output_dir))

    # Inter-frame motion estimation
    nim = nib.load('{0}/sa_crop.nii.gz'.format(output_dir))
    Z = int(nim.header['dim'][3])
    T = int(nim.header['dim'][4])
    dt = nim.header['pixdim'][4]
    # Index of the last time frame
    last_fr = T - 1
    par = '{0}/ffd_cine_2d_motion.cfg'.format(par_dir)

    # Set to True to also warp the segmentations and report the Dice metric
    eval_dice = False
    dice_lv_myo = []
    for z in range(Z):
        # Only the slices for which an ED contour was extracted are tracked
        if not os.path.exists('{0}/myo_contour_ED_z{1:02d}.vtk'.format(output_dir, z)):
            continue

        # Split the cine sequence for this slice
        split_sequence('{0}/sa_crop_z{1:02d}.nii.gz'.format(output_dir, z),
                       '{0}/sa_crop_z{1:02d}_fr'.format(output_dir, z))

        # Forward image registration
        for fr in range(1, T):
            target_fr = fr - 1
            source_fr = fr
            target = '{0}/sa_crop_z{1:02d}_fr{2:02d}.nii.gz'.format(output_dir, z, target_fr)
            source = '{0}/sa_crop_z{1:02d}_fr{2:02d}.nii.gz'.format(output_dir, z, source_fr)
            dof = '{0}/ffd_z{1:02d}_pair_{2:02d}_to_{3:02d}.dof.gz'.format(output_dir, z, target_fr, source_fr)
            os.system('mirtk register {0} {1} -parin {2} -dofout {3}'.format(target, source, par, dof))

        # Compose forward inter-frame transformation fields
        os.system('cp {0}/ffd_z{1:02d}_pair_00_to_01.dof.gz '
                  '{0}/ffd_z{1:02d}_forward_00_to_01.dof.gz'.format(output_dir, z))
        for fr in range(2, T):
            dofs = ''
            for k in range(1, fr + 1):
                dof = '{0}/ffd_z{1:02d}_pair_{2:02d}_to_{3:02d}.dof.gz'.format(output_dir, z, k - 1, k)
                dofs += dof + ' '
            dof_out = '{0}/ffd_z{1:02d}_forward_00_to_{2:02d}.dof.gz'.format(output_dir, z, fr)
            os.system('mirtk compose-dofs {0} {1} -approximate'.format(dofs, dof_out))

        # Backward image registration
        # The target of the last frame wraps around to frame 0.
        for fr in range(T - 1, 0, -1):
            target_fr = (fr + 1) % T
            source_fr = fr
            target = '{0}/sa_crop_z{1:02d}_fr{2:02d}.nii.gz'.format(output_dir, z, target_fr)
            source = '{0}/sa_crop_z{1:02d}_fr{2:02d}.nii.gz'.format(output_dir, z, source_fr)
            dof = '{0}/ffd_z{1:02d}_pair_{2:02d}_to_{3:02d}.dof.gz'.format(output_dir, z, target_fr, source_fr)
            os.system('mirtk register {0} {1} -parin {2} -dofout {3}'.format(target, source, par, dof))

        # Compose backward inter-frame transformation fields
        # The first backward field is the pair from frame 0 to the last frame.
        # The frame index is derived from T instead of being fixed to 49,
        # which was only correct for sequences of 50 frames.
        os.system('cp {0}/ffd_z{1:02d}_pair_00_to_{2:02d}.dof.gz '
                  '{0}/ffd_z{1:02d}_backward_00_to_{2:02d}.dof.gz'.format(output_dir, z, last_fr))
        for fr in range(T - 2, 0, -1):
            dofs = ''
            for k in range(T - 1, fr - 1, -1):
                dof = '{0}/ffd_z{1:02d}_pair_{2:02d}_to_{3:02d}.dof.gz'.format(output_dir, z,
                                                                              (k + 1) % T, k)
                dofs += dof + ' '
            dof_out = '{0}/ffd_z{1:02d}_backward_00_to_{2:02d}.dof.gz'.format(output_dir, z, fr)
            os.system('mirtk compose-dofs {0} {1} -approximate'.format(dofs, dof_out))

        # Average the forward and backward transformations
        # The forward weight decreases and the backward weight increases
        # linearly with the frame index.
        os.system('mirtk init-dof {0}/ffd_z{1:02d}_forward_00_to_00.dof.gz'.format(output_dir, z))
        os.system('mirtk init-dof {0}/ffd_z{1:02d}_backward_00_to_00.dof.gz'.format(output_dir, z))
        os.system('mirtk init-dof {0}/ffd_z{1:02d}_00_to_00.dof.gz'.format(output_dir, z))
        for fr in range(1, T):
            dof_forward = '{0}/ffd_z{1:02d}_forward_00_to_{2:02d}.dof.gz'.format(output_dir, z, fr)
            weight_forward = float(T - fr) / T
            dof_backward = '{0}/ffd_z{1:02d}_backward_00_to_{2:02d}.dof.gz'.format(output_dir, z, fr)
            weight_backward = float(fr) / T
            dof_combine = '{0}/ffd_z{1:02d}_00_to_{2:02d}.dof.gz'.format(output_dir, z, fr)
            os.system('average_3d_ffd 2 {0} {1} {2} {3} {4}'.format(dof_forward, weight_forward,
                                                                   dof_backward, weight_backward,
                                                                   dof_combine))

        # Transform the contours
        for fr in range(0, T):
            os.system('mirtk transform-points {0}/myo_contour_ED_z{1:02d}.vtk '
                      '{0}/myo_contour_z{1:02d}_fr{2:02d}.vtk '
                      '-dofin {0}/ffd_z{1:02d}_00_to_{2:02d}.dof.gz'.format(output_dir, z, fr))

        # Transform the segmentation and evaluate the Dice metric
        if eval_dice:
            split_sequence('{0}/seg_sa_crop_z{1:02d}.nii.gz'.format(output_dir, z),
                           '{0}/seg_sa_crop_z{1:02d}_fr'.format(output_dir, z))

            image_names = []
            for fr in range(0, T):
                os.system('mirtk transform-image {0}/seg_sa_crop_z{1:02d}_fr{2:02d}.nii.gz '
                          '{0}/seg_sa_crop_warp_ffd_z{1:02d}_fr{2:02d}.nii.gz '
                          '-dofin {0}/ffd_z{1:02d}_00_to_{2:02d}.dof.gz '
                          '-target {0}/seg_sa_crop_z{1:02d}_fr00.nii.gz'.format(output_dir, z, fr))
                warp_name = '{0}/seg_sa_crop_warp_ffd_z{1:02d}_fr{2:02d}.nii.gz'.format(output_dir, z, fr)
                image_A = np.asanyarray(
                    nib.load('{0}/seg_sa_crop_z{1:02d}_fr00.nii.gz'.format(output_dir, z)).dataobj)
                image_B = np.asanyarray(nib.load(warp_name).dataobj)
                dice_lv_myo += [[np_categorical_dice(image_A, image_B, 1),
                                 np_categorical_dice(image_A, image_B, 2)]]
                image_names += [warp_name]
            combine_name = '{0}/seg_sa_crop_warp_ffd_z{1:02d}.nii.gz'.format(output_dir, z)
            make_sequence(image_names, dt, combine_name)

    if eval_dice:
        print(np.mean(dice_lv_myo, axis=0))
        df_dice = pd.DataFrame(dice_lv_myo)
        df_dice.to_csv('{0}/dice_cine_warp_ffd.csv'.format(output_dir), index=None, header=None)

    # Merge the 2D tracked contours from all the slice
    for fr in range(0, T):
        append = vtk.vtkAppendPolyData()
        reader = {}
        for z in range(Z):
            name = '{0}/myo_contour_z{1:02d}_fr{2:02d}.vtk'.format(output_dir, z, fr)
            if not os.path.exists(name):
                continue
            reader[z] = vtk.vtkPolyDataReader()
            reader[z].SetFileName(name)
            reader[z].Update()
            append.AddInputData(reader[z].GetOutput())
        append.Update()
        writer = vtk.vtkPolyDataWriter()
        writer.SetFileName('{0}/myo_contour_fr{1:02d}.vtk'.format(output_dir, fr))
        writer.SetInputData(append.GetOutput())
        writer.Write()

    # Calculate the strain based on the line length
    evaluate_strain_by_length('{0}/myo_contour_fr'.format(output_dir), T, dt, output_name_stem)


def _trim_contour_at_mitral_plane(contour, mitral_plane):
    """ Rotate a closed contour so that it starts right after the mitral
        plane and cut it at the first point which lies on the plane.
    """
    N = contour.shape[0]

    # The first point which is off the plane while its predecessor is on it
    start_i = 0
    for i in range(N):
        y, x = contour[i]
        prev_y, prev_x = contour[(i - 1) % N]
        if not mitral_plane[x, y] and mitral_plane[prev_x, prev_y]:
            start_i = i
            break
    contour = np.concatenate((contour[start_i:], contour[:start_i]))

    # Keep the points up to the first one on the plane
    end_i = N
    for i in range(N):
        y, x = contour[i]
        if mitral_plane[x, y]:
            end_i = i
            break
    return contour[:end_i]


def remove_mitral_valve_points(endo_contour, epi_contour, mitral_plane):
    """ Remove the mitral valve points from the contours and
        start the contours from the point next to the mitral valve plane.
        So connecting the lines will be easier in the next step.

        The contours are arrays of (y, x) points, whereas mitral_plane is
        indexed as mitral_plane[x, y]. Returns the trimmed endocardial and
        epicardial contours.
    """
    endo_contour = _trim_contour_at_mitral_plane(endo_contour, mitral_plane)
    epi_contour = _trim_contour_at_mitral_plane(epi_contour, mitral_plane)
    return endo_contour, epi_contour


def determine_la_aha_part(seg_la, affine_la, affine_sa):
    """ Extract the mid-line of the left ventricle, record its index
        along the long-axis and determine the part for each index.

        seg_la is a 2D long-axis segmentation indexed as [x, y], affine_la and
        affine_sa are the 4 x 4 voxel-to-world matrices of the long-axis and
        short-axis images.

        Returns (part_z, mid_line). part_z maps each long-axis index to
        'basal', 'mid' or 'apical'. mid_line maps each long-axis index to the
        world coordinate of the LV cavity centre at that index. Both cover all
        the indices from the lowest to the highest LV or myocardium index.
    """
    # Label class in the segmentation
    label = {'BG': 0, 'LV': 1, 'Myo': 2, 'RV': 3, 'LA': 4, 'RA': 5}

    # The inverse does not depend on the pixel, so it is computed only once
    world_to_sa = np.linalg.inv(affine_sa)

    # Sort the left ventricle and myocardium points according to their long-axis locations
    # The long-axis index is twice the short-axis slice coordinate, rounded,
    # i.e. two indices per short-axis slice.
    lv_myo_points = []
    lv_points = []
    X, Y = seg_la.shape[:2]
    z = 0
    for y in range(Y):
        for x in range(X):
            is_lv = (seg_la[x, y] == label['LV'])
            if not (is_lv or seg_la[x, y] == label['Myo']):
                continue
            z_sa = np.dot(world_to_sa, np.dot(affine_la, np.array([x, y, z, 1])))[2]
            la_idx = int(round(z_sa * 2))
            lv_myo_points.append([x, y, la_idx])
            if is_lv:
                lv_points.append([x, y, la_idx])
    lv_myo_points = np.array(lv_myo_points)
    lv_myo_idx_min = np.min(lv_myo_points[:, 2])
    lv_myo_idx_max = np.max(lv_myo_points[:, 2])

    # Determine the AHA part according to the slice location along the long-axis
    # Both directions include the two end indices, so that part_z has an
    # entry for every index which mid_line has.
    if affine_sa[2, 2] > 0:
        la_idx = np.arange(lv_myo_idx_max, lv_myo_idx_min - 1, -1)
    else:
        la_idx = np.arange(lv_myo_idx_min, lv_myo_idx_max + 1, 1)

    n_la_idx = len(la_idx)
    i1 = int(math.ceil(n_la_idx / 3.0))
    i2 = int(math.ceil(2 * n_la_idx / 3.0))

    part_z = {}
    for i in range(n_la_idx):
        if i < i1:
            name = 'basal'
        elif i < i2:
            name = 'mid'
        else:
            name = 'apical'
        part_z[int(la_idx[i])] = name

    # Extract the mid-line of left ventricle endocardium.
    # Only use the endocardium points so that it would not be affected by
    # the myocardium points at the most basal slices.
    lv_points = np.array(lv_points)
    lv_idx_min = np.min(lv_points[:, 2])
    lv_idx_max = np.max(lv_points[:, 2])

    mid_line = {}
    for la_idx in range(lv_idx_min, lv_idx_max + 1):
        mx, my = np.mean(lv_points[lv_points[:, 2] == la_idx, :2], axis=0)
        mid_line[la_idx] = np.dot(affine_la, np.array([mx, my, z, 1]))[:3]

    # Where there is only myocardium, use the nearest cavity centre
    for la_idx in range(lv_myo_idx_min, lv_idx_min):
        mid_line[la_idx] = mid_line[lv_idx_min]

    for la_idx in range(lv_idx_max, lv_myo_idx_max + 1):
        mid_line[la_idx] = mid_line[lv_idx_max]
    return part_z, mid_line


def determine_la_aha_segment_id(point, la_idx, axis, mid_line, part_z):
    """ Determine the AHA segment ID given a point on long-axis images.

        point is a 3D coordinate, la_idx its long-axis index, axis a dict
        providing the 'lv_to_sep' direction, mid_line and part_z the outputs
        of determine_la_aha_part.

        Returns 1, 3 or 5 for basal, mid or apical septal points and
        2, 4 or 6 for basal, mid or apical lateral points.
        Raises ValueError if the part at la_idx is not basal, mid or apical.
    """
    # (septal, lateral) segment IDs for each part
    seg_ids = {'basal': (1, 2), 'mid': (3, 4), 'apical': (5, 6)}

    # The mid-point at this position
    mid_point = mid_line[la_idx]

    # The line from the mid-point to the contour point
    vec = point - mid_point

    # The point is on the septum if it is on the septal side of the mid-line,
    # otherwise it is lateral.
    is_septal = np.dot(vec, axis['lv_to_sep']) > 0

    part = part_z[la_idx]
    if part not in seg_ids:
        raise ValueError('Error: unknown part {0}!'.format(part))
    septal_id, lateral_id = seg_ids[part]
    return septal_id if is_septal else lateral_id
1135 | """ 1136 | # Read the segmentation image 1137 | nim = nib.load(seg_la_name) 1138 | X, Y, Z = nim.header['dim'][1:4] 1139 | affine = nim.affine 1140 | seg = nim.get_data() 1141 | 1142 | # Label class in the segmentation 1143 | label = {'BG': 0, 'LV': 1, 'Myo': 2, 'RV': 3, 'LA': 4, 'RA': 5} 1144 | 1145 | # Determine the AHA coordinate system using the mid-cavity slice of short-axis images 1146 | nim_sa = nib.load(seg_sa_name) 1147 | affine_sa = nim_sa.affine 1148 | seg_sa = nim_sa.get_data() 1149 | aha_axis = determine_aha_coordinate_system(seg_sa, affine_sa) 1150 | 1151 | # Construct the points set and data arrays to represent both endo and epicardial contours 1152 | points = vtk.vtkPoints() 1153 | points_radial = vtk.vtkFloatArray() 1154 | points_radial.SetName('Direction_Radial') 1155 | points_radial.SetNumberOfComponents(3) 1156 | points_label = vtk.vtkIntArray() 1157 | points_label.SetName('Label') 1158 | points_aha = vtk.vtkIntArray() 1159 | points_aha.SetName('Segment ID') 1160 | point_id = 0 1161 | lines = vtk.vtkCellArray() 1162 | lines_aha = vtk.vtkIntArray() 1163 | lines_aha.SetName('Segment ID') 1164 | lines_dir = vtk.vtkIntArray() 1165 | lines_dir.SetName('Direction ID') 1166 | 1167 | # Check whether there is the endocardial segmentation 1168 | # Only keep the largest connected component 1169 | z = 0 1170 | seg_z = seg[:, :, z] 1171 | endo = (seg_z == label['LV']).astype(np.uint8) 1172 | endo = get_largest_cc(endo).astype(np.uint8) 1173 | # The myocardium may be split to two parts due to the very thin apex. 1174 | # So we do not apply get_largest_cc() to it. However, we remove small pieces, which 1175 | # may cause problems in determining the contours. 
1176 | myo = (seg_z == label['Myo']).astype(np.uint8) 1177 | myo = remove_small_cc(myo).astype(np.uint8) 1178 | epi = (endo | myo).astype(np.uint8) 1179 | epi = get_largest_cc(epi).astype(np.uint8) 1180 | 1181 | # Extract endocardial contour 1182 | # Note: cv2 considers an input image as a Y x X array, which is different 1183 | # from nibabel which assumes a X x Y array. 1184 | _, contours, _ = cv2.findContours(cv2.inRange(endo, 1, 1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 1185 | endo_contour = contours[0][:, 0, :] 1186 | 1187 | # Extract epicardial contour 1188 | _, contours, _ = cv2.findContours(cv2.inRange(epi, 1, 1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 1189 | epi_contour = contours[0][:, 0, :] 1190 | 1191 | # Record the points located on the mitral valve plane. 1192 | mitral_plane = np.zeros(seg_z.shape) 1193 | N = epi_contour.shape[0] 1194 | for i in range(N): 1195 | y, x = epi_contour[i] 1196 | if endo[x, y]: 1197 | mitral_plane[x, y] = 1 1198 | 1199 | # Remove the mitral valve points from the contours and 1200 | # start the contours from the point next to the mitral valve plane. 1201 | # So connecting the lines will be easier in the next step. 1202 | if np.sum(mitral_plane) >= 1: 1203 | endo_contour, epi_contour = remove_mitral_valve_points(endo_contour, epi_contour, mitral_plane) 1204 | 1205 | # Note that remove_mitral_valve_points may fail if the endo or epi has more 1206 | # than one connected components. As a result, the endo_contour or epi_contour 1207 | # may only have zero or one points left, which cause problems for approximate_contour. 
1208 | 1209 | # Smooth the contours 1210 | if len(endo_contour) >= 2: 1211 | endo_contour = approximate_contour(endo_contour) 1212 | if len(epi_contour) >= 2: 1213 | epi_contour = approximate_contour(epi_contour) 1214 | 1215 | # Determine the aha part and extract the mid-line of the left ventricle 1216 | part_z, mid_line = determine_la_aha_part(seg_z, affine, affine_sa) 1217 | la_idx_min = np.array([x for x in part_z.keys()]).min() 1218 | la_idx_max = np.array([x for x in part_z.keys()]).max() 1219 | 1220 | # Go through the endo contour points 1221 | N = endo_contour.shape[0] 1222 | for i in range(N): 1223 | y, x = endo_contour[i] 1224 | 1225 | # The world coordinate of this point 1226 | p = np.dot(affine, np.array([x, y, z, 1]))[:3] 1227 | points.InsertNextPoint(p[0], p[1], p[2]) 1228 | 1229 | # The index along the long axis 1230 | z_sa = np.dot(np.linalg.inv(affine_sa), np.hstack([p, 1]))[2] 1231 | la_idx = int(round(z_sa * 2)) 1232 | la_idx = max(la_idx, la_idx_min) 1233 | la_idx = min(la_idx, la_idx_max) 1234 | 1235 | # The radial direction 1236 | mid_point = mid_line[la_idx] 1237 | d = p - mid_point 1238 | d = d / np.linalg.norm(d) 1239 | points_radial.InsertNextTuple3(d[0], d[1], d[2]) 1240 | 1241 | # Record the type of the point (1 = endo, 2 = epi) 1242 | points_label.InsertNextTuple1(1) 1243 | 1244 | # Record the segment ID 1245 | seg_id = determine_la_aha_segment_id(p, la_idx, aha_axis, mid_line, part_z) 1246 | points_aha.InsertNextTuple1(seg_id) 1247 | 1248 | # Add the line 1249 | if i < (N - 1): 1250 | lines.InsertNextCell(2, [point_id, point_id + 1]) 1251 | lines_aha.InsertNextTuple1(seg_id) 1252 | 1253 | # Line direction (1 = radial, 2 = circumferential, 3 = longitudinal) 1254 | lines_dir.InsertNextTuple1(3) 1255 | 1256 | # Increment the point index 1257 | point_id += 1 1258 | 1259 | # Go through the epi contour points 1260 | N = epi_contour.shape[0] 1261 | for i in range(N): 1262 | y, x = epi_contour[i] 1263 | 1264 | # The world coordinate of this 
point 1265 | p = np.dot(affine, np.array([x, y, z, 1]))[:3] 1266 | points.InsertNextPoint(p[0], p[1], p[2]) 1267 | 1268 | # The index along the long axis 1269 | z_sa = np.dot(np.linalg.inv(affine_sa), np.hstack([p, 1]))[2] 1270 | la_idx = int(round(z_sa * 2)) 1271 | la_idx = max(la_idx, la_idx_min) 1272 | la_idx = min(la_idx, la_idx_max) 1273 | 1274 | # The radial direction 1275 | mid_point = mid_line[la_idx] 1276 | d = p - mid_point 1277 | d = d / np.linalg.norm(d) 1278 | points_radial.InsertNextTuple3(d[0], d[1], d[2]) 1279 | 1280 | # Record the type of the point (1 = endo, 2 = epi) 1281 | points_label.InsertNextTuple1(2) 1282 | 1283 | # Record the segment ID 1284 | seg_id = determine_la_aha_segment_id(p, la_idx, aha_axis, mid_line, part_z) 1285 | points_aha.InsertNextTuple1(seg_id) 1286 | 1287 | # Add the line 1288 | if i < (N - 1): 1289 | lines.InsertNextCell(2, [point_id, point_id + 1]) 1290 | lines_aha.InsertNextTuple1(seg_id) 1291 | 1292 | # Line direction (1 = radial, 2 = circumferential, 3 = longitudinal) 1293 | lines_dir.InsertNextTuple1(3) 1294 | 1295 | # Increment the point index 1296 | point_id += 1 1297 | 1298 | # Save to a vtk file 1299 | poly = vtk.vtkPolyData() 1300 | poly.SetPoints(points) 1301 | poly.GetPointData().AddArray(points_label) 1302 | poly.GetPointData().AddArray(points_aha) 1303 | poly.GetPointData().AddArray(points_radial) 1304 | poly.SetLines(lines) 1305 | poly.GetCellData().AddArray(lines_aha) 1306 | poly.GetCellData().AddArray(lines_dir) 1307 | 1308 | writer = vtk.vtkPolyDataWriter() 1309 | writer.SetFileName(contour_name) 1310 | writer.SetInputData(poly) 1311 | writer.Write() 1312 | 1313 | # Change vtk file version to 4.0 to avoid the warning by MIRTK, which is 1314 | # developed using VTK 6.3, which does not know file version 4.1. 
def _vtk_line_lengths(poly):
    """ Return an array with the Euclidean length of each line cell in poly,
        measured between the first two points of the cell.
    """
    points = poly.GetPoints()
    lines = poly.GetLines()
    n_lines = lines.GetNumberOfCells()
    lengths = np.zeros(n_lines)
    lines.InitTraversal()
    for i in range(n_lines):
        ids = vtk.vtkIdList()
        lines.GetNextCell(ids)
        p1 = np.array(points.GetPoint(ids.GetId(0)))
        p2 = np.array(points.GetPoint(ids.GetId(1)))
        lengths[i] = np.linalg.norm(p1 - p2)
    return lengths


def evaluate_la_strain_by_length(contour_name_stem, T, dt, output_name_stem):
    """ Calculate the strain based on the line length

        contour_name_stem: the contour of frame fr is read from
            '<contour_name_stem><fr as two digits>.vtk'.
        T: number of time frames.
        dt: time step between frames; the csv columns are fr * dt * 1e3.
        output_name_stem: the table is written to '<output_name_stem>_longit.csv'.

        Each contour file is rewritten in place with an extra 'Strain' cell array.
    """
    # Read the polydata at the first time frame (ED frame)
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName('{0}{1:02d}.vtk'.format(contour_name_stem, 0))
    reader.Update()
    poly = reader.GetOutput()

    # Reference length, segment ID and direction ID of each line
    length_ED = _vtk_line_lengths(poly)
    n_lines = len(length_ED)
    lines_aha = poly.GetCellData().GetArray('Segment ID')
    lines_dir = poly.GetCellData().GetArray('Direction ID')
    seg_id = np.zeros(n_lines)
    dir_id = np.zeros(n_lines)
    for i in range(n_lines):
        seg_id[i] = lines_aha.GetValue(i)
        dir_id[i] = lines_dir.GetValue(i)

    # For each time frame, calculate the strain, i.e. change of length
    # Rows 0 to 5 are the six segments, the last row is the global strain.
    table_strain = {}
    table_strain['longit'] = np.zeros((7, T))

    for fr in range(0, T):
        # Read the polydata
        reader = vtk.vtkPolyDataReader()
        filename = '{0}{1:02d}.vtk'.format(contour_name_stem, fr)
        reader.SetFileName(filename)
        reader.Update()
        poly = reader.GetOutput()

        # Strain of each line (unit: %)
        strain = (_vtk_line_lengths(poly) - length_ED) / length_ED * 100
        vtk_strain = vtk.vtkFloatArray()
        vtk_strain.SetName('Strain')
        for value in strain:
            vtk_strain.InsertNextTuple1(value)

        # Save the strain array to the vtk file
        poly.GetCellData().AddArray(vtk_strain)
        writer = vtk.vtkPolyDataWriter()
        writer.SetInputData(poly)
        writer.SetFileName(filename)
        writer.Write()
        # Change the vtk file version from 4.1 to 4.0 on the first line
        os.system('sed -i "1s/4.1/4.0/" {0}'.format(filename))

        # Calculate the segmental and global strains
        # Only lines with direction ID 3 are used.
        for i in range(6):
            table_strain['longit'][i, fr] = np.mean(strain[(seg_id == (i + 1)) & (dir_id == 3)])
        table_strain['longit'][-1, fr] = np.mean(strain[dir_id == 3])

    for c in ['longit']:
        # Save into csv files
        index = [str(x) for x in np.arange(1, 7)] + ['Global']
        column = np.arange(0, T) * dt * 1e3
        df = pd.DataFrame(table_strain[c], index=index, columns=column)
        df.to_csv('{0}_{1}.csv'.format(output_name_stem, c))


def cine_2d_la_motion_and_strain_analysis(data_dir, par_dir, output_dir, output_name_stem):
    """ Perform motion tracking and strain analysis for cine MR images.

        data_dir: directory with la_4ch.nii.gz, seg4_la_4ch.nii.gz,
            seg4_la_4ch_ED.nii.gz and seg_sa_ED.nii.gz.
        par_dir: directory with ffd_cine_la_2d_motion.cfg.
        output_dir: directory for all intermediate files.
        output_name_stem: stem of the strain csv file.

        The work is done by external commands (mirtk, average_3d_ffd, cp) run
        through os.system; their exit status is not checked.
    """
    # Crop the image to save computation for image registration
    # Focus on the left ventricle so that motion tracking is less affected by
    # the movement of RV and LV outflow tract
    # NOTE(review): padding() and auto_crop_image() are defined elsewhere; the
    # four calls presumably relabel classes 2 to 5 so that only the LV is
    # left - confirm against their definitions.
    seg_ED = '{0}/seg4_la_4ch_ED.nii.gz'.format(data_dir)
    seg_lv_ED = '{0}/seg4_la_4ch_lv_ED.nii.gz'.format(output_dir)
    padding(seg_ED, seg_ED, seg_lv_ED, 2, 1)
    padding(seg_lv_ED, seg_lv_ED, seg_lv_ED, 3, 0)
    padding(seg_lv_ED, seg_lv_ED, seg_lv_ED, 4, 0)
    padding(seg_lv_ED, seg_lv_ED, seg_lv_ED, 5, 0)
    auto_crop_image(seg_lv_ED,
                    '{0}/seg4_la_4ch_lv_crop_ED.nii.gz'.format(output_dir), 20)
    os.system('mirtk transform-image {0}/la_4ch.nii.gz {1}/la_4ch_crop.nii.gz '
              '-target {1}/seg4_la_4ch_lv_crop_ED.nii.gz'.format(data_dir, output_dir))
    os.system('mirtk transform-image {0}/seg4_la_4ch.nii.gz {1}/seg4_la_4ch_crop.nii.gz '
              '-target {1}/seg4_la_4ch_lv_crop_ED.nii.gz'.format(data_dir, output_dir))

    # Extract the myocardial contour
    extract_la_myocardial_contour('{0}/seg4_la_4ch_ED.nii.gz'.format(data_dir),
                                  '{0}/seg_sa_ED.nii.gz'.format(data_dir),
                                  '{0}/la_4ch_myo_contour_ED.vtk'.format(output_dir))

    # Inter-frame motion estimation
    nim = nib.load('{0}/la_4ch_crop.nii.gz'.format(output_dir))
    T = int(nim.header['dim'][4])
    dt = nim.header['pixdim'][4]

    # Split the cine sequence
    split_sequence('{0}/la_4ch_crop.nii.gz'.format(output_dir),
                   '{0}/la_4ch_crop_fr'.format(output_dir))

    # Forward image registration
    for fr in range(1, T):
        target_fr = fr - 1
        source_fr = fr
        target = '{0}/la_4ch_crop_fr{1:02d}.nii.gz'.format(output_dir, target_fr)
        source = '{0}/la_4ch_crop_fr{1:02d}.nii.gz'.format(output_dir, source_fr)
        par = '{0}/ffd_cine_la_2d_motion.cfg'.format(par_dir)
        dof = '{0}/ffd_la_4ch_pair_{1:02d}_to_{2:02d}.dof.gz'.format(output_dir, target_fr, source_fr)
        os.system('mirtk register {0} {1} -parin {2} -dofout {3}'.format(target, source, par, dof))

    # Compose forward inter-frame transformation fields
    os.system('cp {0}/ffd_la_4ch_pair_00_to_01.dof.gz '
              '{0}/ffd_la_4ch_forward_00_to_01.dof.gz'.format(output_dir))
    for fr in range(2, T):
        dofs = ''
        for k in range(1, fr + 1):
            dof = '{0}/ffd_la_4ch_pair_{1:02d}_to_{2:02d}.dof.gz'.format(output_dir, k - 1, k)
            dofs += dof + ' '
        dof_out = '{0}/ffd_la_4ch_forward_00_to_{1:02d}.dof.gz'.format(output_dir, fr)
        os.system('mirtk compose-dofs {0} {1} -approximate'.format(dofs, dof_out))

    # Backward image registration
    # The target of the last frame wraps around to frame 0.
    for fr in range(T - 1, 0, -1):
        target_fr = (fr + 1) % T
        source_fr = fr
        target = '{0}/la_4ch_crop_fr{1:02d}.nii.gz'.format(output_dir, target_fr)
        source = '{0}/la_4ch_crop_fr{1:02d}.nii.gz'.format(output_dir, source_fr)
        par = '{0}/ffd_cine_la_2d_motion.cfg'.format(par_dir)
        dof = '{0}/ffd_la_4ch_pair_{1:02d}_to_{2:02d}.dof.gz'.format(output_dir, target_fr, source_fr)
        os.system('mirtk register {0} {1} -parin {2} -dofout {3}'.format(target, source, par, dof))

    # Compose backward inter-frame transformation fields
    # The first backward field is the pair from frame 0 to the last frame,
    # T - 1. The frame number must follow T; a fixed "49" is only valid for
    # sequences of exactly 50 frames.
    os.system('cp {0}/ffd_la_4ch_pair_00_to_{1:02d}.dof.gz '
              '{0}/ffd_la_4ch_backward_00_to_{1:02d}.dof.gz'.format(output_dir, T - 1))
    for fr in range(T - 2, 0, -1):
        dofs = ''
        for k in range(T - 1, fr - 1, -1):
            dof = '{0}/ffd_la_4ch_pair_{1:02d}_to_{2:02d}.dof.gz'.format(output_dir, (k + 1) % T, k)
            dofs += dof + ' '
        dof_out = '{0}/ffd_la_4ch_backward_00_to_{1:02d}.dof.gz'.format(output_dir, fr)
        os.system('mirtk compose-dofs {0} {1} -approximate'.format(dofs, dof_out))

    # Average the forward and backward transformations
    # The forward weight decreases and the backward weight increases with fr.
    os.system('mirtk init-dof {0}/ffd_la_4ch_forward_00_to_00.dof.gz'.format(output_dir))
    os.system('mirtk init-dof {0}/ffd_la_4ch_backward_00_to_00.dof.gz'.format(output_dir))
    os.system('mirtk init-dof {0}/ffd_la_4ch_00_to_00.dof.gz'.format(output_dir))
    for fr in range(1, T):
        dof_forward = '{0}/ffd_la_4ch_forward_00_to_{1:02d}.dof.gz'.format(output_dir, fr)
        weight_forward = float(T - fr) / T
        dof_backward = '{0}/ffd_la_4ch_backward_00_to_{1:02d}.dof.gz'.format(output_dir, fr)
        weight_backward = float(fr) / T
        dof_combine = '{0}/ffd_la_4ch_00_to_{1:02d}.dof.gz'.format(output_dir, fr)
        os.system('average_3d_ffd 2 {0} {1} {2} {3} {4}'.format(dof_forward, weight_forward,
                                                                dof_backward, weight_backward,
                                                                dof_combine))

    # Transform the contours and calculate the strain
    for fr in range(0, T):
        os.system('mirtk transform-points {0}/la_4ch_myo_contour_ED.vtk '
                  '{0}/la_4ch_myo_contour_fr{1:02d}.vtk '
                  '-dofin {0}/ffd_la_4ch_00_to_{1:02d}.dof.gz'.format(output_dir, fr))

    # Calculate the strain based on the line length
    evaluate_la_strain_by_length('{0}/la_4ch_myo_contour_fr'.format(output_dir),
                                 T, dt, output_name_stem)

    # Transform the segmentation and evaluate the Dice metric
    # Disabled by default; set eval_dice to True to run the evaluation.
    eval_dice = False
    if eval_dice:
        split_sequence('{0}/seg4_la_4ch_crop.nii.gz'.format(output_dir),
                       '{0}/seg4_la_4ch_crop_fr'.format(output_dir))
        dice_lv_myo = []

        image_names = []
        for fr in range(0, T):
            os.system('mirtk transform-image {0}/seg4_la_4ch_crop_fr{1:02d}.nii.gz '
                      '{0}/seg4_la_4ch_crop_warp_ffd_fr{1:02d}.nii.gz '
                      '-dofin {0}/ffd_la_4ch_00_to_{1:02d}.dof.gz '
                      '-target {0}/seg4_la_4ch_crop_fr00.nii.gz'.format(output_dir, fr))
            image_A = nib.load('{0}/seg4_la_4ch_crop_fr00.nii.gz'.format(output_dir)).get_data()
            image_B = nib.load('{0}/seg4_la_4ch_crop_warp_ffd_fr{1:02d}.nii.gz'.format(output_dir, fr)).get_data()
            dice_lv_myo += [[np_categorical_dice(image_A, image_B, 1),
                             np_categorical_dice(image_A, image_B, 2)]]
            image_names += ['{0}/seg4_la_4ch_crop_warp_ffd_fr{1:02d}.nii.gz'.format(output_dir, fr)]
        combine_name = '{0}/seg4_la_4ch_crop_warp_ffd.nii.gz'.format(output_dir)
        make_sequence(image_names, dt, combine_name)

        print(np.mean(dice_lv_myo, axis=0))
        df_dice = pd.DataFrame(dice_lv_myo)
        df_dice.to_csv('{0}/dice_cine_la_4ch_warp_ffd.csv'.format(output_dir), index=None, header=None)
1535 | data: values for 16 segments 1536 | """ 1537 | if len(data) != 16: 1538 | print('Error: len(data) != 16!') 1539 | exit(0) 1540 | 1541 | # The cartesian coordinate and the polar coordinate 1542 | x = np.linspace(-1, 1, 201) 1543 | y = np.linspace(-1, 1, 201) 1544 | xx, yy = np.meshgrid(x, y) 1545 | r = np.sqrt(xx * xx + yy * yy) 1546 | theta = np.degrees(np.arctan2(yy, xx)) 1547 | 1548 | # The radius and degree for each segment 1549 | R1, R2, R3, R4 = 1, 0.65, 0.3, 0.0 1550 | rad_deg = { 1551 | 1: (R1, R2, 60, 120), 1552 | 2: (R1, R2, 120, 180), 1553 | 3: (R1, R2, -180, -120), 1554 | 4: (R1, R2, -120, -60), 1555 | 5: (R1, R2, -60, 0), 1556 | 6: (R1, R2, 0, 60), 1557 | 7: (R2, R3, 60, 120), 1558 | 8: (R2, R3, 120, 180), 1559 | 9: (R2, R3, -180, -120), 1560 | 10: (R2, R3, -120, -60), 1561 | 11: (R2, R3, -60, 0), 1562 | 12: (R2, R3, 0, 60), 1563 | 13: (R3, R4, 45, 135), 1564 | 14: (R3, R4, 135, -135), 1565 | 15: (R3, R4, -135, -45), 1566 | 16: (R3, R4, -45, 45) 1567 | } 1568 | 1569 | # Plot the segments 1570 | canvas = np.zeros(xx.shape) 1571 | cx, cy = (np.array(xx.shape) - 1) / 2 1572 | sz = cx 1573 | 1574 | for i in range(1, 17): 1575 | val = data[i - 1] 1576 | r1, r2, theta1, theta2 = rad_deg[i] 1577 | if theta2 > theta1: 1578 | mask = ((r < r1) & (r >= r2)) & ((theta >= theta1) & (theta < theta2)) 1579 | else: 1580 | mask = ((r < r1) & (r >= r2)) & ((theta >= theta1) | (theta < theta2)) 1581 | canvas[mask] = val 1582 | plt.imshow(canvas, cmap=cmap, vmin=vmin, vmax=vmax) 1583 | plt.colorbar() 1584 | plt.axis('off') 1585 | plt.gca().invert_yaxis() 1586 | 1587 | # Plot the circles 1588 | for r in [R1, R2, R3]: 1589 | deg = np.linspace(0, 2 * np.pi, 201) 1590 | circle_x = cx + sz * r * np.cos(deg) 1591 | circle_y = cy + sz * r * np.sin(deg) 1592 | plt.plot(circle_x, circle_y, color=color_line) 1593 | 1594 | # Plot the lines between segments 1595 | for i in range(1, 17): 1596 | r1, r2, theta1, theta2 = rad_deg[i] 1597 | line_x = cx + sz * np.array([r1, r2]) * 
def atrium_pass_quality_control(label, label_dict):
    """ Quality control for atrial volume estimation

        label: 4D label map; only the first slice label[:, :, 0, t] is checked.
        label_dict: mapping from a label name to its integer value.

        Returns False (after printing the reason) as soon as one label has
        no pixel at some time frame, True otherwise.
    """
    for l_name, l in label_dict.items():
        # Criterion: the atrium does not disappear at any time point so that we can
        # measure the area and length.
        area = np.sum(label[:, :, 0] == l, axis=(0, 1))
        empty_frames = np.flatnonzero(area == 0)
        if len(empty_frames) > 0:
            print('The area of {0} is 0 at time frame {1}.'.format(l_name, int(empty_frames[0])))
            return False
    return True


def evaluate_atrial_area_length(label, nim, long_axis):
    """ Evaluate the atrial area and length from 2 chamber or 4 chamber view images.

        label: label map of one frame, indexed as [x, y]; 0 is ignored.
        nim: image object providing header['pixdim'] and affine.
        long_axis: vector onto which the points are projected to order them.

        Returns (A, L, landmarks): one area and one length per non-zero label
        value (sorted), and two landmark points per label.
        Returns (-1, -1, -1) if the major axis cannot be determined or if it
        does not intersect the label.
    """
    # Area per pixel
    pixdim = nim.header['pixdim'][1:4]
    area_per_pix = pixdim[0] * pixdim[1] * 1e-2  # Unit: cm^2

    # Go through the label class
    L = []
    A = []
    landmarks = []
    labs = np.sort(list(set(np.unique(label)) - set([0])))
    for i in labs:
        # The binary label map
        label_i = (label == i)

        # Get the largest component in case we have a bad segmentation
        label_i = get_largest_cc(label_i)

        # Go through all the points in the atrium, sort them by the distance along the long-axis.
        points_label = np.nonzero(label_i)
        points = []
        for j in range(len(points_label[0])):
            x = points_label[0][j]
            y = points_label[1][j]
            points += [[x, y,
                        np.dot(np.dot(nim.affine, np.array([x, y, 0, 1]))[:3], long_axis)]]
        points = np.array(points)
        points = points[points[:, 2].argsort()]

        # The centre at the top part of the atrium (top third)
        n_points = len(points)
        top_points = points[int(2 * n_points / 3):]
        cx, cy, _ = np.mean(top_points, axis=0)

        # The centre at the bottom part of the atrium (bottom third)
        bottom_points = points[:int(n_points / 3)]
        bx, by, _ = np.mean(bottom_points, axis=0)

        # Determine the major axis by connecting the top centre and the bottom centre
        major_axis = np.array([cx - bx, cy - by])
        major_axis = major_axis / np.linalg.norm(major_axis)

        # Get the intersection between the major axis and the atrium contour
        # The end points are placed 100 pixels on either side of the top centre.
        px = cx + major_axis[0] * 100
        py = cy + major_axis[1] * 100
        qx = cx - major_axis[0] * 100
        qy = cy - major_axis[1] * 100

        # NaN appears when a third is empty or both centres coincide
        if np.isnan(px) or np.isnan(py) or np.isnan(qx) or np.isnan(qy):
            return -1, -1, -1

        # Note the difference between nifti image index and cv2 image index
        # nifti image index: XY
        # cv2 image index: YX (height, width)
        image_line = np.zeros(label_i.shape)
        cv2.line(image_line, (int(qy), int(qx)), (int(py), int(px)), (1, 0, 0))
        image_line = label_i & (image_line > 0)

        # Sort the intersection points by the distance along long-axis
        # and calculate the length of the intersection
        points_line = np.nonzero(image_line)
        points = []
        for j in range(len(points_line[0])):
            x = points_line[0][j]
            y = points_line[1][j]
            # World coordinate
            point = np.dot(nim.affine, np.array([x, y, 0, 1]))[:3]
            # Distance along the long-axis
            points += [np.append(point, np.dot(point, long_axis))]
        points = np.array(points)
        if len(points) == 0:
            return -1, -1, -1
        points = points[points[:, 3].argsort(), :3]
        L += [np.linalg.norm(points[-1] - points[0]) * 1e-1]  # Unit: cm

        # Calculate the area
        A += [np.sum(label_i) * area_per_pix]

        # Landmarks of the intersection points
        landmarks += [points[0]]
        landmarks += [points[-1]]
    return A, L, landmarks


def aorta_pass_quality_control(image, seg):
    """ Quality control for aortic segmentation

        image: 4D intensity image, the last axis is time.
        seg: 4D label map of the same layout; 1 = AAo, 2 = DAo.

        Returns False (after printing the reason) at the first failed
        criterion, True if both labels pass all four criteria.
    """
    for l_name, l in [('AAo', 1), ('DAo', 2)]:
        # Criterion 1: the aorta does not disappear at some point.
        T = seg.shape[3]
        for t in range(T):
            seg_t = seg[:, :, :, t]
            area = np.sum(seg_t == l)
            if area == 0:
                print('The area of {0} is 0 at time frame {1}.'.format(l_name, t))
                return False

        # Criterion 2: no strong image noise, which affects the segmentation accuracy.
        # The maximum intensity inside the label at each frame is compared
        # with the mean intensity inside the label at the first frame.
        image_ED = image[:, :, :, 0]
        seg_ED = seg[:, :, :, 0]
        mean_intensity_ED = image_ED[seg_ED == l].mean()
        ratio_thres = 3
        for t in range(T):
            image_t = image[:, :, :, t]
            seg_t = seg[:, :, :, t]
            max_intensity_t = np.max(image_t[seg_t == l])
            ratio = max_intensity_t / mean_intensity_ED
            if ratio >= ratio_thres:
                print('The image becomes very noisy at time frame {0}.'.format(t))
                return False

        # Criterion 3: no fragmented segmentation
        pixel_thres = 10
        for t in range(T):
            seg_t = seg[:, :, :, t]
            # Full connectivity is what the former neighbors=8 argument
            # selected; that argument no longer exists in recent scikit-image.
            cc, n_cc = skimage.measure.label(seg_t == l, connectivity=seg_t.ndim,
                                             return_num=True)
            count_cc = 0
            for i in range(1, n_cc + 1):
                binary_cc = (cc == i)
                if np.sum(binary_cc) > pixel_thres:
                    # If this connected component has more than certain pixels, count it.
                    count_cc += 1
            if count_cc >= 2:
                print('The segmentation has at least two connected components with more than {0} pixels '
                      'at time frame {1}.'.format(pixel_thres, t))
                return False

        # Criterion 4: no abrupt change of area
        # Frame 0 is compared with the last frame because A[-1] wraps around.
        A = np.sum(seg == l, axis=(0, 1, 2))
        for t in range(T):
            ratio = A[t] / float(A[t-1])
            if ratio >= 2 or ratio <= 0.5:
                print('There is abrupt change of area at time frame {0}.'.format(t))
                return False
    return True
# ============================================================================
import os, sys, time, math
import numpy as np
import nibabel as nib
import tensorflow.compat.v1 as tf
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from image_utils import rescale_intensity


""" Deployment parameters """
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_dir', '/vol/bitbucket/wbai/own_work/ukbb_cardiac_demo',
                           'Path to the data set directory, under which images '
                           'are organised in subdirectories for each subject.')
tf.app.flags.DEFINE_string('model_path',
                           '',
                           'Path to the saved trained model.')
tf.app.flags.DEFINE_boolean('process_seq', True,
                            'Process a time sequence of images.')
tf.app.flags.DEFINE_boolean('save_seg', True,
                            'Save segmentation.')
# The help text is assembled from adjacent literals, so every piece needs
# its own separating space.
tf.app.flags.DEFINE_boolean('seg4', False,
                            'Segment all the 4 chambers in long-axis 4 chamber view. '
                            'This seg4 network is trained using 200 subjects from Application 18545. '
                            'By default, for all the other tasks (ventricular segmentation '
                            'on short-axis images and atrial segmentation on long-axis images), '
                            'the networks are trained using 3,975 subjects from Application 2964.')


if __name__ == '__main__':
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Import the computation graph and restore the variable values
        saver = tf.train.import_meta_graph('{0}.meta'.format(FLAGS.model_path))
        saver.restore(sess, '{0}'.format(FLAGS.model_path))

        print('Start deployment on the data set ...')
        start_time = time.time()

        # Process each subject subdirectory
        data_list = sorted(os.listdir(FLAGS.data_dir))
        processed_list = []
        table_time = []
        for data in data_list:
            print(data)
            data_dir = os.path.join(FLAGS.data_dir, data)

            if FLAGS.process_seq:
                # Process the temporal sequence
                image_name = '{0}/sa.nii'.format(data_dir)

                if not os.path.exists(image_name):
                    print(' Directory {0} does not contain an image with file '
                          'name {1}. Skip.'.format(data_dir, os.path.basename(image_name)))
                    continue

                # Read the image
                print(' Reading {} ...'.format(image_name))
                nim = nib.load(image_name)
                image = nim.get_data()
                X, Y, Z, T = image.shape
                orig_image = image

                print(' Segmenting full sequence ...')
                start_seg_time = time.time()

                # Intensity rescaling
                image = rescale_intensity(image, (1, 99))

                # Prediction (segmentation)
                # Allocated before padding so that it has the original size.
                pred = np.zeros(image.shape)

                # Pad the image size to be a factor of 16 so that the
                # downsample and upsample procedures in the network will
                # result in the same image size at each resolution level.
                X2, Y2 = int(math.ceil(X / 16.0)) * 16, int(math.ceil(Y / 16.0)) * 16
                x_pre, y_pre = int((X2 - X) / 2), int((Y2 - Y) / 2)
                x_post, y_post = (X2 - X) - x_pre, (Y2 - Y) - y_pre
                image = np.pad(image, ((x_pre, x_post), (y_pre, y_post), (0, 0), (0, 0)), 'constant')

                # Process each time frame
                for t in range(T):
                    # Transpose the shape to NXYC
                    image_fr = image[:, :, :, t]
                    image_fr = np.transpose(image_fr, axes=(2, 0, 1)).astype(np.float32)
                    image_fr = np.expand_dims(image_fr, axis=-1)

                    # Evaluate the network
                    prob_fr, pred_fr = sess.run(['prob:0', 'pred:0'],
                                                feed_dict={'image:0': image_fr, 'training:0': False})

                    # Transpose and crop segmentation to recover the original size
                    pred_fr = np.transpose(pred_fr, axes=(1, 2, 0))
                    pred_fr = pred_fr[x_pre:x_pre + X, y_pre:y_pre + Y]
                    pred[:, :, :, t] = pred_fr

                seg_time = time.time() - start_seg_time
                print(' Segmentation time = {:.3f}s'.format(seg_time))
                table_time += [seg_time]
                processed_list += [data]

                # ED frame defaults to be the first time frame.
                # Determine ES frame according to the minimum LV volume.
                k = {}
                k['ED'] = 0
                k['ES'] = np.argmin(np.sum(pred == 1, axis=(0, 1, 2)))

                print(' ED frame = {:d}, ES frame = {:d}'.format(k['ED'], k['ES']))

                # Save the segmentation
                if FLAGS.save_seg:
                    print(' Saving segmentation ...')
                    nim2 = nib.Nifti1Image(pred, nim.affine)
                    nim2.header['pixdim'] = nim.header['pixdim']
                    seg_name = '{0}/seg_sa.nii.gz'.format(data_dir)
                    nib.save(nim2, seg_name)

                    for fr in ['ED', 'ES']:
                        nib.save(nib.Nifti1Image(orig_image[:, :, :, k[fr]], nim.affine),
                                 '{0}/sa_{1}.nii.gz'.format(data_dir, fr))
                        seg_name = '{0}/seg_sa_{1}.nii.gz'.format(data_dir, fr)
                        nib.save(nib.Nifti1Image(pred[:, :, :, k[fr]], nim.affine), seg_name)

        # The averages below are undefined when nothing was processed
        # (e.g. no subject directory contains sa.nii).
        if processed_list:
            print('Average segmentation time = {:.3f}s per sequence'.format(np.mean(table_time)))
            process_time = time.time() - start_time
            print('Including image I/O, CUDA resource allocation, '
                  'it took {:.3f}s for processing {:d} subjects ({:.3f}s per subjects).'.format(
                      process_time, len(processed_list), process_time / len(processed_list)))
        else:
            print('No subject was processed.')
def tf_categorical_accuracy(pred, truth):
    """ Accuracy metric: the mean of the element-wise equality of pred and truth """
    return tf.reduce_mean(tf.cast(tf.equal(pred, truth), dtype=tf.float32))


def tf_categorical_dice(pred, truth, k):
    """ Dice overlap metric for label k """
    A = tf.cast(tf.equal(pred, k), dtype=tf.float32)
    B = tf.cast(tf.equal(truth, k), dtype=tf.float32)
    return 2 * tf.reduce_sum(tf.multiply(A, B)) / (tf.reduce_sum(A) + tf.reduce_sum(B))


def crop_image(image, cx, cy, size):
    """
    Crop an image using a bounding box centred at (cx, cy) with specified size.

    The box covers [cx - r, cx + r) x [cy - r, cy + r) with r = int(size / 2)
    on the first two axes, so the output is 2 * r wide (an odd size loses one
    pixel). Parts of the box that fall outside the image are zero-padded.
    Any trailing axes (slice, time, ...) are kept untouched; the image must
    have at least two dimensions.
    """
    X, Y = image.shape[:2]
    r = int(size / 2)
    x1, x2 = cx - r, cx + r
    y1, y2 = cy - r, cy + r
    x1_, x2_ = max(x1, 0), min(x2, X)
    y1_, y2_ = max(y1, 0), min(y2, Y)
    # Crop the image
    crop = image[x1_: x2_, y1_: y2_]
    # Pad the first two axes if the box is larger than the input image;
    # every remaining axis gets no padding, whatever the rank is.
    pad_width = [(x1_ - x1, x2 - x2_), (y1_ - y1, y2 - y2_)]
    pad_width += [(0, 0)] * (crop.ndim - 2)
    return np.pad(crop, pad_width, 'constant')


def normalise_intensity(image, thres_roi=10.0):
    """
    Normalise the image intensity by the mean and standard deviation.

    The statistics are computed over the voxels at or above the thres_roi
    percentile only; the normalisation is then applied to the whole image.
    """
    val_l = np.percentile(image, thres_roi)
    roi = (image >= val_l)
    mu, sigma = np.mean(image[roi]), np.std(image[roi])
    # eps keeps the division finite for a constant region
    eps = 1e-6
    image2 = (image - mu) / (sigma + eps)
    return image2
def rescale_intensity(image, thres=(1.0, 99.0)):
    """
    Rescale the image intensity to the range of [0, 1].

    Intensities are clipped to the percentiles given by thres (low, high)
    and mapped linearly onto [0, 1]. The input array is left untouched and
    a new float32 array of the same shape is returned. An image whose two
    percentiles coincide (e.g. a constant image) maps to all zeros.
    """
    # Work on a floating-point copy so that the caller's array is never
    # modified and integer images are not clipped with truncated thresholds.
    data = np.asarray(image, dtype=np.float64)
    val_l, val_h = np.percentile(data, thres)
    span = val_h - val_l
    if span <= 0:
        return np.zeros(data.shape, dtype=np.float32)
    image2 = (np.clip(data, val_l, val_h) - val_l) / span
    return image2.astype(np.float32)
def data_augmenter(image, label, shift, rotate, scale, intensity, flip):
    """
    Online data augmentation
    Perform affine transformation on image and label,
    which are 4D tensor of shape (N, H, W, C) and 3D tensor of shape (N, H, W).

    shift, rotate, scale and intensity set the magnitude of the random
    perturbation: each is multiplied by a standard normal draw clipped to
    [-3, 3]. A fresh set of parameters is drawn for every image of the batch.
    Returns the augmented (image, label) as float32 / int32 arrays.
    """
    image2 = np.zeros(image.shape, dtype=np.float32)
    label2 = np.zeros(label.shape, dtype=np.int32)
    row, col = image.shape[1:3]
    for i in range(image.shape[0]):
        # For each image slice, generate random affine transformation parameters
        # using the Gaussian distribution
        shift_val = [np.clip(np.random.normal(), -3, 3) * shift,
                     np.clip(np.random.normal(), -3, 3) * shift]
        rotate_val = np.clip(np.random.normal(), -3, 3) * rotate
        scale_val = 1 + np.clip(np.random.normal(), -3, 3) * scale
        intensity_val = 1 + np.clip(np.random.normal(), -3, 3) * intensity

        # Apply the affine transformation (rotation + scale + shift) to the image
        M = cv2.getRotationMatrix2D((row / 2, col / 2), rotate_val, 1.0 / scale_val)
        M[:, 2] += shift_val
        for c in range(image.shape[3]):
            # ndimage.affine_transform is the supported entry point; the
            # ndimage.interpolation namespace is deprecated.
            image2[i, :, :, c] = ndimage.affine_transform(image[i, :, :, c],
                                                          M[:, :2], M[:, 2], order=1)

        # Apply the affine transformation (rotation + scale + shift) to the label map
        # (order=0, nearest neighbour, so that no new label values appear)
        label2[i, :, :] = ndimage.affine_transform(label[i, :, :],
                                                   M[:, :2], M[:, 2], order=0)

        # Apply intensity variation
        image2[i] *= intensity_val

        # Apply random horizontal or vertical flipping
        if flip:
            if np.random.uniform() >= 0.5:
                image2[i] = image2[i, ::-1, :, :]
                label2[i] = label2[i, ::-1, :]
            else:
                image2[i] = image2[i, :, ::-1, :]
                label2[i] = label2[i, :, ::-1]
    return image2, label2


def aortic_data_augmenter(image, label, shift, rotate, scale, intensity, flip):
    """
    Online data augmentation
    Perform affine transformation on image and label,

    image: NXYC
    label: NXY

    Unlike data_augmenter, one set of affine and intensity parameters is
    drawn and shared by all N images.
    """
    image2 = np.zeros(image.shape, dtype=np.float32)
    label2 = np.zeros(label.shape, dtype=np.int32)

    # For the N images, which come from the same subject in the LSTM model,
    # generate the same random affine transformation parameters.
    shift_val = [np.clip(np.random.normal(), -3, 3) * shift,
                 np.clip(np.random.normal(), -3, 3) * shift]
    rotate_val = np.clip(np.random.normal(), -3, 3) * rotate
    scale_val = 1 + np.clip(np.random.normal(), -3, 3) * scale
    intensity_val = 1 + np.clip(np.random.normal(), -3, 3) * intensity

    # The affine transformation (rotation + scale + shift)
    row, col = image.shape[1:3]
    M = cv2.getRotationMatrix2D(
        (row / 2, col / 2), rotate_val, 1.0 / scale_val)
    M[:, 2] += shift_val

    # Apply the transformation to the image
    for i in range(image.shape[0]):
        for c in range(image.shape[3]):
            image2[i, :, :, c] = ndimage.affine_transform(
                image[i, :, :, c], M[:, :2], M[:, 2], order=1)

        label2[i, :, :] = ndimage.affine_transform(
            label[i, :, :], M[:, :2], M[:, 2], order=0)

        # Apply intensity variation
        image2[i] *= intensity_val

        # Apply random horizontal or vertical flipping
        # NOTE(review): the flip direction is drawn per image, unlike the
        # shared affine parameters above - confirm this is intended.
        if flip:
            if np.random.uniform() >= 0.5:
                image2[i] = image2[i, ::-1, :, :]
                label2[i] = label2[i, ::-1, :]
            else:
                image2[i] = image2[i, :, ::-1, :]
                label2[i] = label2[i, :, ::-1]
    return image2, label2
def np_categorical_dice(pred, truth, k):
    """
    Dice overlap metric for label k.

    The result is NaN (0 / 0) when label k is absent from both inputs.
    """
    A = (pred == k).astype(np.float32)
    B = (truth == k).astype(np.float32)
    return 2 * np.sum(A * B) / (np.sum(A) + np.sum(B))


def _contour_points(mask):
    """ Return the external contour points of (mask == 1) as an (n, 2) array, or None """
    found = cv2.findContours(cv2.inRange(mask, 1, 1),
                             cv2.RETR_EXTERNAL,
                             cv2.CHAIN_APPROX_NONE)
    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); the contours are the second-to-last item in both.
    contours = found[-2]
    if len(contours) == 0:
        return None
    return np.vstack(contours)[:, 0, :].astype(np.float64)


def distance_metric(seg_A, seg_B, dx):
    """
    Measure the distance errors between the contours of two segmentations.
    The manual contours are drawn on 2D slices.
    We calculate contour to contour distance for each slice.

    seg_A, seg_B: 3D masks of identical shape; voxels equal to 1 are foreground.
    dx: factor applied to the pixel distances (e.g. the pixel spacing).
    Returns (mean contour distance, Hausdorff distance), each averaged over
    the slices where both contours exist, or (None, None) if there is none.
    """
    table_md = []
    table_hd = []
    X, Y, Z = seg_A.shape
    for z in range(Z):
        # Binary mask at this slice
        slice_A = seg_A[:, :, z].astype(np.uint8)
        slice_B = seg_B[:, :, z].astype(np.uint8)

        # The distance is defined only when both contours exist on this slice
        if np.sum(slice_A) > 0 and np.sum(slice_B) > 0:
            # Find contours and retrieve all the points
            pts_A = _contour_points(slice_A)
            pts_B = _contour_points(slice_B)
            if pts_A is None or pts_B is None:
                # Non-zero slice that holds no voxel equal to 1
                continue

            # Distance matrix between point sets, M[i, j] = |A_i - B_j|
            M = np.linalg.norm(pts_A[:, np.newaxis, :] - pts_B[np.newaxis, :, :], axis=-1)

            # Mean distance and hausdorff distance
            md = 0.5 * (np.mean(np.min(M, axis=0)) + np.mean(np.min(M, axis=1))) * dx
            hd = np.max([np.max(np.min(M, axis=0)), np.max(np.min(M, axis=1))]) * dx
            table_md += [md]
            table_hd += [hd]

    # Return the mean distance and Hausdorff distance across 2D slices
    mean_md = np.mean(table_md) if table_md else None
    mean_hd = np.mean(table_hd) if table_hd else None
    return mean_md, mean_hd
def get_largest_cc(binary):
    """
    Get the largest connected component in the foreground.

    Returns a boolean mask of the same shape as binary; it is all False when
    binary has no foreground. Ties go to the component with the lowest label.
    """
    cc, n_cc = ndimage.label(binary)
    if n_cc == 0:
        return np.zeros(cc.shape, dtype=bool)
    # One pass over the label map instead of one full scan per component
    sizes = np.bincount(cc.ravel())
    sizes[0] = 0  # label 0 is the background
    return cc == sizes.argmax()


def remove_small_cc(binary, thres=10):
    """
    Remove small connected component in the foreground.

    Components with fewer than thres voxels are set to 0 in a copy of binary;
    the input itself is not modified.
    """
    cc, n_cc = ndimage.label(binary)
    binary2 = np.copy(binary)
    if n_cc > 0:
        sizes = np.bincount(cc.ravel(), minlength=n_cc + 1)
        small = sizes < thres
        small[0] = False  # never touch the background label
        binary2[small[cc]] = 0
    return binary2


def split_sequence(image_name, output_name):
    """
    Split an image sequence into a number of time frames.

    Frame t is written to '<output_name><t as two digits>.nii.gz' with the
    affine of the input image.
    """
    nim = nib.load(image_name)
    T = int(nim.header['dim'][4])
    affine = nim.affine
    # get_data() has been removed from recent nibabel releases
    image = np.asanyarray(nim.dataobj)

    for t in range(T):
        image_fr = image[:, :, :, t]
        nim2 = nib.Nifti1Image(image_fr, affine)
        nib.save(nim2, '{0}{1:02d}.nii.gz'.format(output_name, t))


def make_sequence(image_names, dt, output_name):
    """
    Combine a number of time frames into one image sequence.

    The spatial size and the affine are taken from the first frame; dt is
    stored as the temporal spacing (pixdim[4]) of the output image.
    """
    nim = nib.load(image_names[0])
    affine = nim.affine
    X, Y, Z = nim.header['dim'][1:4]
    T = len(image_names)
    image = np.zeros((X, Y, Z, T))

    for t in range(T):
        image[:, :, :, t] = np.asanyarray(nib.load(image_names[t]).dataobj)

    nim2 = nib.Nifti1Image(image, affine)
    nim2.header['pixdim'][4] = dt
    nib.save(nim2, output_name)
def split_volume(image_name, output_name):
    """
    Split an image volume into a number of slices.

    Slice z is written to '<output_name><z as two digits>.nii.gz' as a
    single-slice volume whose affine origin is moved along the z axis.
    """
    nim = nib.load(image_name)
    Z = int(nim.header['dim'][3])
    affine = nim.affine
    # get_data() has been removed from recent nibabel releases
    image = np.asanyarray(nim.dataobj)

    for z in range(Z):
        image_slice = image[:, :, z]
        image_slice = np.expand_dims(image_slice, axis=2)
        # Shift the origin by z voxels along the third axis of the affine
        affine2 = np.copy(affine)
        affine2[:3, 3] += z * affine2[:3, 2]
        nim2 = nib.Nifti1Image(image_slice, affine2)
        nib.save(nim2, '{0}{1:02d}.nii.gz'.format(output_name, z))


def image_apply_mask(input_name, output_name, mask_image, pad_value=-1):
    """ Assign the background voxels (mask == 0) with pad_value and save the result """
    nim = nib.load(input_name)
    # np.array() gives a private, writable copy of the image data
    image = np.array(nim.dataobj)
    image[mask_image == 0] = pad_value
    nim2 = nib.Nifti1Image(image, nim.affine)
    nib.save(nim2, output_name)


def padding(input_A_name, input_B_name, output_name, value_in_B, value_output):
    """
    Set image A to value_output wherever image B equals value_in_B.

    The result is saved to output_name with the affine of image A.
    """
    nim = nib.load(input_A_name)
    image_A = np.array(nim.dataobj)
    image_B = np.asanyarray(nib.load(input_B_name).dataobj)
    image_A[image_B == value_in_B] = value_output
    nim2 = nib.Nifti1Image(image_A, nim.affine)
    nib.save(nim2, output_name)
def auto_crop_image(input_name, output_name, reserve):
    """
    Crop an image to the bounding box of its foreground (voxels > 0).

    The box is enlarged by reserve voxels on every side, limited to the image
    extent, and the affine origin is moved to the new corner voxel.
    Raises ValueError if the image has no foreground voxel.
    """
    nim = nib.load(input_name)
    # get_data() has been removed from recent nibabel releases
    image = np.asanyarray(nim.dataobj)
    X, Y, Z = image.shape[:3]

    # Detect the bounding box of the foreground
    idx = np.nonzero(image > 0)
    if idx[0].size == 0:
        raise ValueError('{0} contains no foreground voxel to crop around.'.format(input_name))
    x1, x2 = idx[0].min() - reserve, idx[0].max() + reserve + 1
    y1, y2 = idx[1].min() - reserve, idx[1].max() + reserve + 1
    z1, z2 = idx[2].min() - reserve, idx[2].max() + reserve + 1
    x1, x2 = max(x1, 0), min(x2, X)
    y1, y2 = max(y1, 0), min(y2, Y)
    z1, z2 = max(z1, 0), min(z2, Z)
    print('Bounding box')
    print(' bottom-left corner = ({},{},{})'.format(x1, y1, z1))
    print(' top-right corner = ({},{},{})'.format(x2, y2, z2))

    # Crop the image
    image = image[x1:x2, y1:y2, z1:z2]

    # Update the affine matrix on a copy, so that the affine held by the
    # loaded image object is not modified behind its back.
    affine = np.copy(nim.affine)
    affine[:3, 3] = np.dot(affine, np.array([x1, y1, z1, 1]))[:3]
    nim2 = nib.Nifti1Image(image, affine)
    nib.save(nim2, output_name)


def conv2d_bn_relu(x, filters, training, kernel_size=3, strides=1, trainable=True):
    """
    Basic Conv + BN + ReLU unit.

    trainable is forwarded to the convolution only; the batch normalisation
    layer is created with its default setting.
    """
    x_conv = tf.layers.conv2d(x, filters=filters, kernel_size=kernel_size,
                              strides=strides, padding='same', use_bias=False, trainable=trainable)
    x_bn = tf.layers.batch_normalization(x_conv, training=training)
    x_relu = tf.nn.relu(x_bn)
    return x_relu


def linear_1d(sz):
    """
    1D linear interpolation kernel.

    Returns a float32 triangle of odd length sz that peaks at 1 in the centre,
    e.g. sz = 3 gives [0.5, 1, 0.5]. Raises NotImplementedError for even sz.
    """
    if sz % 2 == 0:
        raise NotImplementedError('`Linear kernel` requires odd filter size.')
    c = (sz + 1) // 2
    h = np.array(list(range(1, c + 1)) + list(range(c - 1, 0, -1)), dtype=np.float32)
    h /= float(c)
    return h
def linear_2d(sz):
    """ 2D linear interpolation kernel: the outer product of the 1D kernel with itself """
    kernel_1d = linear_1d(sz)
    return np.outer(kernel_1d, kernel_1d)


def transpose_upsample2d(x, factor, constant=True):
    """
    2D upsampling operator using transposed convolution.

    x is upsampled by factor along its second and third axes with a
    bilinear kernel applied to each channel separately. The kernel is a
    constant by default and a tf.Variable when constant is False.
    """
    x_shape = tf.shape(x)
    n_channel = x.shape[3].value
    output_shape = tf.stack([x_shape[0], x_shape[1] * factor, x_shape[2] * factor, n_channel])

    # The bilinear interpolation weight for the upsampling filter,
    # placed on the channel diagonal so that channels are not mixed.
    sz = factor * 2 - 1
    kernel = linear_2d(sz)
    filt_val = np.zeros((sz, sz, n_channel, n_channel), dtype=np.float32)
    diag = np.arange(n_channel)
    filt_val[:, :, diag, diag] = kernel[:, :, np.newaxis]

    # Currently, we simply use the fixed bilinear interpolation weights.
    # However, it is possible to set the filt to a trainable variable.
    make_filter = tf.constant if constant else tf.Variable
    filt = make_filter(filt_val, dtype=tf.float32)

    # Currently, if output_shape is an unknown shape, conv2d_transpose()
    # will output an unknown shape during graph construction. This will be
    # a problem for the next step tf.concat(), which requires a known shape.
    # A workaround is to reshape this tensor to the expected shape size.
    # Refer to https://github.com/tensorflow/tensorflow/issues/833#issuecomment-278016198
    x_up = tf.nn.conv2d_transpose(x, filter=filt, output_shape=output_shape,
                                  strides=[1, factor, factor, 1], padding='SAME')
    expected_shape = (x_shape[0], x_shape[1] * factor, x_shape[2] * factor, n_channel)
    return tf.reshape(x_up, expected_shape)
def build_FCN(image, n_class, n_level, n_filter, n_block, training, same_dim=32, fc=64, frozenLayers=0):
    """
    Build a fully convolutional network for segmenting an input image
    into n_class classes and return the logits map.

    n_filter[l] and n_block[l] give the number of filters and the number of
    convolutions at resolution level l. frozenLayers is the number of leading
    convolution layers created as non-trainable: layers are counted from 1
    through the resolution levels and then the two 1x1 layers of the output
    head, and layer k is trainable only if k > frozenLayers. The final
    logits convolution is always trainable.
    """
    net = {}
    x = image

    layer = 1
    # Learn fine-to-coarse features at each resolution level
    for lvl in range(0, n_level):
        with tf.name_scope('conv{0}'.format(lvl)):
            # If this is the first level (lvl = 0), keep the resolution.
            # Otherwise, convolve with a stride of 2, i.e. downsample
            # by a factor of 2.
            strides = 1 if lvl == 0 else 2
            # For each resolution level, perform n_block[lvl] times convolutions
            x = conv2d_bn_relu(x, filters=n_filter[lvl], training=training, kernel_size=3,
                               strides=strides, trainable=(layer > frozenLayers))
            layer += 1
            for i in range(1, n_block[lvl]):
                x = conv2d_bn_relu(x, filters=n_filter[lvl], training=training, kernel_size=3,
                                   trainable=(layer > frozenLayers))
                layer += 1
            net['conv{0}'.format(lvl)] = x

    # Before upsampling back to the original resolution level, map all the
    # feature maps to have same_dim dimensions. Otherwise, the upsampled
    # feature maps will have both a large size (e.g. 192 x 192) and a high
    # dimension (e.g. 256 features), which may exhaust the GPU memory (e.g.
    # 12 GB for Nvidia Titan K80).
    # Exemplar calculation:
    #   batch size 20 x image size 192 x 192 x feature dimension 256 x floating data type 4
    #   = 755 MB for a feature map
    # Apart from this, there is also associated memory of the same size
    # used for gradient calculation.
    #
    # layerSameDim is the index of the first convolution of each level, so a
    # same_dim projection is frozen exactly when that convolution is frozen.
    layerSameDim = 1
    with tf.name_scope('same_dim'):
        for lvl in range(0, n_level):
            net['conv{0}_same_dim'.format(lvl)] = conv2d_bn_relu(
                net['conv{0}'.format(lvl)], filters=same_dim,
                training=training, kernel_size=1, trainable=(layerSameDim > frozenLayers))
            layerSameDim += n_block[lvl]

    # Upsample the feature maps at each resolution level to the original resolution
    with tf.name_scope('up'):
        net['conv0_up'] = net['conv0_same_dim']
        for lvl in range(1, n_level):
            net['conv{0}_up'.format(lvl)] = transpose_upsample2d(
                net['conv{0}_same_dim'.format(lvl)], factor=int(pow(2, lvl)))

    # Concatenate the multi-level feature maps
    with tf.name_scope('concat'):
        list_up = []
        for lvl in range(0, n_level):
            list_up += [net['conv{0}_up'.format(lvl)]]
        net['concat'] = tf.concat(list_up, axis=-1)

    # Perform prediction using the multi-level feature maps
    with tf.name_scope('out'):
        # We only calculate logits, instead of softmax here because the loss
        # function tf.nn.softmax_cross_entropy() accepts the unscaled logits
        # and performs softmax internally for efficiency and numerical stability
        # reasons. Refer to https://github.com/tensorflow/tensorflow/issues/2462
        #
        # The two layers below use the same strict comparison as every other
        # layer, so that frozenLayers = k freezes exactly the first k layers.
        x = net['concat']
        x = conv2d_bn_relu(x, filters=fc, training=training, kernel_size=1,
                           trainable=(layer > frozenLayers))
        layer += 1
        x = conv2d_bn_relu(x, filters=fc, training=training, kernel_size=1,
                           trainable=(layer > frozenLayers))
        logits = tf.layers.conv2d(x, filters=n_class, kernel_size=1, padding='same')
    return logits


def squeezeNii(root, file):
    """
    Write a copy of the NIfTI image root/file with singleton axes removed.

    The copy is saved next to the input with '.nii' replaced by '_temp.nii'
    in the file name; the input file itself is not touched.
    """
    img = nib.load(os.path.join(root, file))
    newFile = file.replace('.nii', '_temp.nii')
    nib.save(nib.Nifti1Image(np.squeeze(img.dataobj), img.affine), os.path.join(root, newFile))


# Root directory that is searched recursively for '.nii' files to squeeze.
dir_path = '/data/data_mrcv/45_DATA_HUMANS/CHEST/STUDIES/2020_CARDIAC_DL_SEGMENTATION_CORRADO/test'

for root, dirs, files in os.walk(dir_path):
    print(root)
    for file in files:
        if file.endswith('.nii'):
            squeezeNii(root, str(file))
            # Move the squeezed copy over the original in a single step, so
            # the image never goes missing between a remove and a rename.
            os.replace(os.path.join(root, str(file).replace('.nii', '_temp.nii')),
                       os.path.join(root, str(file)))
imgSize = 192  # Image size after interpolating
train_batch_size = 20  # Number of images for each training batch
validation_batch_size = 20  # Number of images for each validation batch
train_iteration = 10000  # Number of training iterations
num_filter = 16  # Number of filters for the first convolution layer
num_level = 5  # Number of network levels
learning_rate = 1e-3  # Learning rate
dataset_dir = '/home/pcorrado/Cardiac-DL-Segmentation-Paper/'  # Path to the dataset directory
frozenLayers = 4  # number of layers to freeze for transfer learning
log_dir = '/home/pcorrado/Cardiac-DL-Segmentation-Paper/log_UW_layers_frozen'  # Directory for saving the log file
checkpoint_dir = '/home/pcorrado/Cardiac-DL-Segmentation-Paper/model_UW_layers_frozen'  # Directory for saving the trained model


def _load_subject(image_name, label_name, image_size):
    """ Load, crop and rescale one image / label pair; return None if it is unusable """
    if not (os.path.exists(image_name) and os.path.exists(label_name)):
        print('Skip {0} {1}: file not found'.format(image_name, label_name))
        return None
    print(' Select {0} {1}'.format(image_name, label_name))

    # Read image and label (get_data() has been removed from recent nibabel)
    image = np.asanyarray(nib.load(image_name).dataobj)
    label = np.asanyarray(nib.load(label_name).dataobj)

    # Handle exceptions
    if image.shape != label.shape:
        print('Error: mismatched size, image.shape = {0}, '
              'label.shape = {1}'.format(image.shape, label.shape))
        print('Skip {0}, {1}'.format(image_name, label_name))
        return None

    if image.ndim != 4:
        print('Error: expected a 4D image, image.shape = {0}'.format(image.shape))
        print('Skip {0} {1}'.format(image_name, label_name))
        return None

    if image.max() < 1e-6:
        print('Error: blank image, image.max = {0}'.format(image.max()))
        print('Skip {0} {1}'.format(image_name, label_name))
        return None

    # Normalise the image size
    X, Y = image.shape[:2]
    if X != image_size or Y != image_size:
        cx, cy = int(X / 2), int(Y / 2)
        image = crop_image(image, cx, cy, image_size)
        label = crop_image(label, cx, cy, image_size)

    # Intensity rescaling
    image = rescale_intensity(image, (1.0, 99.0))
    return image, label


def get_random_batch(filename_list, batch_size, image_size=192, data_augmentation=False,
                     shift=0.0, rotate=0.0, scale=0.0, intensity=0.0, flip=False):
    """
    Build a batch of batch_size random 2D slices from one random subject.

    filename_list holds (image_name, label_name) pairs of 4D NIfTI files.
    Subjects are tried in random order until one loads correctly, and the
    slices are then drawn (with replacement) from random slice positions and
    time frames of that subject.

    Returns (images, labels) with shapes (N, H, W, 1) float32 and (N, H, W)
    int32. Raises RuntimeError if no subject in filename_list is usable.
    """
    subject = None
    for index in random.sample(range(len(filename_list)), len(filename_list)):
        image_name, label_name = filename_list[index]
        subject = _load_subject(image_name, label_name, image_size)
        if subject is not None:
            break
    if subject is None:
        raise RuntimeError('No usable image / label pair among the {0} candidates.'.format(
            len(filename_list)))
    image, label = subject

    # Randomly select batch_size slices from the chosen subject
    Z, T = image.shape[2:4]
    images = []
    labels = []
    for _ in range(batch_size):
        randZ = random.randrange(Z)
        randT = random.randrange(T)
        images += [image[:, :, randZ, randT]]
        labels += [label[:, :, randZ, randT]]

    # Convert to a numpy array
    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)

    # Add the channel dimension
    # tensorflow by default assumes NHWC format
    images = np.expand_dims(images, axis=3)

    # Perform data augmentation
    if data_augmentation:
        images, labels = data_augmenter(images, labels,
                                        shift=shift, rotate=rotate,
                                        scale=scale,
                                        intensity=intensity, flip=flip)
    return images, labels
if __name__ == '__main__':
    import shutil

    # Go through each subset (training, test) under the data directory
    # and list the file names of the subjects
    data_list = {}
    for k in ['train', 'test']:
        subset_dir = os.path.join(dataset_dir, k)
        data_list[k] = []

        for data in sorted(os.listdir(subset_dir)):
            data_dir = os.path.join(subset_dir, data)
            image_name = '{0}/sa.nii'.format(data_dir)
            label_name = '{0}/label_sa.nii'.format(data_dir)
            if os.path.exists(image_name) and os.path.exists(label_name):
                data_list[k] += [[image_name, label_name]]

    if not data_list['train']:
        raise RuntimeError('No training subject found under {0}'.format(
            os.path.join(dataset_dir, 'train')))

    # Prepare tensors for the image and label map pairs
    # Use int32 for label_pl as tf.one_hot uses int32
    image_pl = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='image')
    label_pl = tf.placeholder(tf.int32, shape=[None, None, None], name='label')

    # Print out the placeholders' names, which will be useful when deploying the network
    print('Placeholder image_pl.name = ' + image_pl.name)
    print('Placeholder label_pl.name = ' + label_pl.name)

    # Placeholder for the training phase
    # This flag is important for the batch_normalization layer to function properly.
    training_pl = tf.placeholder(tf.bool, shape=[], name='training')
    print('Placeholder training_pl.name = ' + training_pl.name)

    n_class = 4

    # The number of resolution levels
    n_level = num_level

    # The number of filters at each resolution level
    # Follow the VGG philosophy, increasing the dimension
    # by a factor of 2 for each level
    n_filter = [num_filter * pow(2, i) for i in range(n_level)]
    print('Number of filters at each level =', n_filter)
    print('Note: The connection between neurons is proportional to '
          'n_filter * n_filter. Increasing n_filter by a factor of 2 '
          'will increase the number of parameters by a factor of 4. '
          'So it is better to start experiments with a small n_filter '
          'and increase it later.')

    # Build the neural network, which outputs the logits,
    # i.e. the unscaled values just before the softmax layer,
    # which will then normalise the logits into the probabilities.
    # NOTE(review): the module-level frozenLayers setting is not passed here,
    # so every layer is trainable - confirm this is intended for this script.
    n_block = [2, 2, 3, 3, 3]
    logits = build_FCN(image_pl, n_class, n_level=n_level,
                       n_filter=n_filter, n_block=n_block,
                       training=training_pl, same_dim=32, fc=64)

    # The softmax probability and the predicted segmentation
    prob = tf.nn.softmax(logits, name='prob')
    pred = tf.cast(tf.argmax(prob, axis=-1), dtype=tf.int32, name='pred')
    print('prob.name = ' + prob.name)
    print('pred.name = ' + pred.name)

    # Loss
    label_1hot = tf.one_hot(indices=label_pl, depth=n_class)
    label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=label_1hot, logits=logits)
    loss = tf.reduce_mean(label_loss)

    # Evaluation metrics
    accuracy = tf_categorical_accuracy(pred, label_pl)

    # Optimiser
    lr = learning_rate

    # We need to add the operators associated with batch_normalization
    # to the optimiser, according to
    # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        print('Using Adam optimizer.')
        train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    # Model name and directory.
    # The replacement fields are numbered from 0 to match the six arguments;
    # numbering them from 1 makes the last field raise IndexError.
    model_name = 'FCN_sa_level{0}_filter{1}_{2}_batch{3}_iter{4}_lr{5}'.format(
        n_level, n_filter[0], ''.join([str(x) for x in n_block]),
        train_batch_size, train_iteration, learning_rate)
    model_dir = os.path.join(checkpoint_dir, model_name)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # NOTE(review): the flattened source does not show whether the checkpoint
    # was written inside or after the training loop. Saving every
    # save_interval iterations and after the last one keeps the final model
    # in both readings - confirm the intended schedule.
    save_interval = 1000

    # Start the tensorflow session
    with tf.Session() as sess:
        print('Start training...')
        start_time = time.time()

        # Create a saver
        saver = tf.train.Saver(max_to_keep=20)

        # Summary writer; logs of a previous run with the same name are dropped
        summary_dir = os.path.join(log_dir, model_name)
        if os.path.exists(summary_dir):
            shutil.rmtree(summary_dir, ignore_errors=True)
        train_writer = tf.summary.FileWriter(os.path.join(summary_dir, 'train'), graph=sess.graph)

        # Initialise variables
        sess.run(tf.global_variables_initializer())

        # Iterate
        for iteration in range(1, 1 + train_iteration):
            # For each iteration, we randomly choose a batch of subjects
            print('Iteration {0}: training...'.format(iteration))
            start_time_iter = time.time()

            batch = get_random_batch(data_list['train'],
                                     train_batch_size,
                                     image_size=imgSize,
                                     data_augmentation=True,
                                     shift=0, rotate=90, scale=0.2,
                                     intensity=0, flip=False)
            if batch is None:
                # The selected subject could not be used; try another one
                continue
            images, labels = batch

            # Stochastic optimisation using this batch
            _, train_loss, train_acc = sess.run([train_op, loss, accuracy],
                                                {image_pl: images, label_pl: labels, training_pl: True})
            summary = tf.Summary()
            summary.value.add(tag='loss', simple_value=train_loss)
            summary.value.add(tag='accuracy', simple_value=train_acc)
            train_writer.add_summary(summary, iteration)

            # Print the results for this iteration
            print('Iteration {} of {} took {:.3f}s'.format(iteration, train_iteration,
                                                           time.time() - start_time_iter))
            print(' training loss:\t\t{:.6f}'.format(train_loss))
            print(' training accuracy:\t\t{:.2f}%'.format(train_acc * 100))

            # Save model
            if iteration % save_interval == 0 or iteration == train_iteration:
                saver.save(sess, save_path=os.path.join(model_dir, '{0}.ckpt'.format(model_name)),
                           global_step=iteration)

        # Close the summary writers
        train_writer.close()
        print('Training took {:.3f}s in total.\n'.format(time.time() - start_time))
def get_random_batch(filename_list, batch_size, image_size=192, data_augmentation=False,
                     shift=0.0, rotate=0.0, scale=0.0, intensity=0.0, flip=False):
    """Build one training batch of 2-D slices drawn from a single random subject.

    A subject (image/label file pair) is picked at random from
    ``filename_list``. Subjects whose files are missing, whose image and label
    shapes disagree, or whose image is blank are skipped and another subject
    is tried, so the function always returns a usable batch or raises.

    Args:
        filename_list: sequence of ``[image_path, label_path]`` pairs.
        batch_size: number of slices in the returned batch.
        image_size: in-plane size the volume is cropped to via ``crop_image``.
        data_augmentation: when true, the batch is passed through
            ``data_augmenter`` with the parameters below.
        shift, rotate, scale, intensity, flip: forwarded to ``data_augmenter``.

    Returns:
        ``(images, labels)`` where ``images`` is a float32 array with a
        trailing channel axis (NHWC) and ``labels`` is an int32 array.

    Raises:
        ValueError: if ``filename_list`` is empty.
        RuntimeError: if no subject in ``filename_list`` is usable.
    """
    if not filename_list:
        raise ValueError('filename_list is empty: no subjects to sample from.')

    # Visit the subjects in random order and keep the first usable one.
    # The original code returned a bare None on a bad subject, which the
    # caller cannot unpack into (images, labels).
    for image_name, label_name in random.sample(list(filename_list), len(filename_list)):
        if not (os.path.exists(image_name) and os.path.exists(label_name)):
            print('Skip {0} {1}: file not found'.format(image_name, label_name))
            continue
        print(' Select {0} {1}'.format(image_name, label_name))

        # Read image and label. np.asanyarray(img.dataobj) is the documented
        # replacement for the removed nibabel get_data().
        image = np.asanyarray(nib.load(image_name).dataobj)
        label = np.asanyarray(nib.load(label_name).dataobj)

        if image.shape != label.shape:
            print('Error: mismatched size, image.shape = {0}, '
                  'label.shape = {1}'.format(image.shape, label.shape))
            print('Skip {0}, {1}'.format(image_name, label_name))
            continue

        if image.max() < 1e-6:
            print('Error: blank image, image.max = {0}'.format(image.max()))
            print('Skip {0} {1}'.format(image_name, label_name))
            continue
        break
    else:
        raise RuntimeError('No usable image/label pair found in filename_list.')

    # Accept single-frame volumes by adding a singleton time axis
    if image.ndim == 3:
        image = image[..., np.newaxis]
        label = label[..., np.newaxis]

    # Normalise the image size
    X, Y, Z, T = image.shape
    if X != image_size or Y != image_size:
        cx, cy = X // 2, Y // 2
        image = crop_image(image, cx, cy, image_size)
        label = crop_image(label, cx, cy, image_size)

    # Intensity rescaling
    image = rescale_intensity(image, (1.0, 99.0))

    # Draw batch_size random (slice, frame) positions from this subject
    images = []
    labels = []
    for _ in range(batch_size):
        rand_z = random.randrange(Z)
        rand_t = random.randrange(T)
        images.append(image[:, :, rand_z, rand_t])
        labels.append(label[:, :, rand_z, rand_t])

    # Convert to numpy arrays
    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)

    # Add the channel dimension; tensorflow by default assumes NHWC format
    images = np.expand_dims(images, axis=3)

    # Perform data augmentation
    if data_augmentation:
        images, labels = data_augmenter(images, labels,
                                        shift=shift, rotate=rotate,
                                        scale=scale,
                                        intensity=intensity, flip=flip)
    return images, labels
if __name__ == '__main__':
    # Local import: only this script entry point needs it.
    import shutil

    # Go through each subset (training, test) under the data directory
    # and list the file names of the subjects
    data_list = {}
    for k in ['train', 'test']:
        subset_dir = os.path.join(dataset_dir, k)
        data_list[k] = []

        for data in sorted(os.listdir(subset_dir)):
            data_dir = os.path.join(subset_dir, data)
            image_name = '{0}/sa.nii'.format(data_dir)
            label_name = '{0}/label_sa.nii'.format(data_dir)
            if os.path.exists(image_name) and os.path.exists(label_name):
                data_list[k] += [[image_name, label_name]]

    # Prepare tensors for the image and label map pairs
    # Use int32 for label_pl as tf.one_hot uses int32
    image_pl = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='image')
    label_pl = tf.placeholder(tf.int32, shape=[None, None, None], name='label')

    # Print out the placeholders' names, which will be useful when deploying the network
    print('Placeholder image_pl.name = ' + image_pl.name)
    print('Placeholder label_pl.name = ' + label_pl.name)

    # Placeholder for the training phase
    # This flag is important for the batch_normalization layer to function properly.
    training_pl = tf.placeholder(tf.bool, shape=[], name='training')
    print('Placeholder training_pl.name = ' + training_pl.name)

    # Determine the number of label classes according to the manual annotation procedure
    # for each image sequence.
    n_class = 4

    # The number of resolution levels
    n_level = num_level

    # The number of filters at each resolution level
    # Follow the VGG philosophy, increasing the dimension
    # by a factor of 2 for each level
    n_filter = [num_filter * 2 ** i for i in range(n_level)]
    print('Number of filters at each level =', n_filter)
    print('Note: The connection between neurons is proportional to '
          'n_filter * n_filter. Increasing n_filter by a factor of 2 '
          'will increase the number of parameters by a factor of 4. '
          'So it is better to start experiments with a small n_filter '
          'and increase it later.')

    # Build the neural network, which outputs the logits,
    # i.e. the unscaled values just before the softmax layer,
    # which will then normalise the logits into the probabilities.
    n_block = [2, 2, 3, 3, 3]
    logits = build_FCN(image_pl, n_class, n_level=n_level,
                       n_filter=n_filter, n_block=n_block,
                       training=training_pl, same_dim=32, fc=64, frozenLayers=frozenLayers)

    # The softmax probability and the predicted segmentation
    prob = tf.nn.softmax(logits, name='prob')
    pred = tf.cast(tf.argmax(prob, axis=-1), dtype=tf.int32, name='pred')
    print('prob.name = ' + prob.name)
    print('pred.name = ' + pred.name)

    # Loss
    label_1hot = tf.one_hot(indices=label_pl, depth=n_class)
    label_loss = tf.nn.softmax_cross_entropy_with_logits(labels=label_1hot, logits=logits)
    loss = tf.reduce_mean(label_loss)

    # Evaluation metrics (the Dice ops are built but not evaluated below)
    accuracy = tf_categorical_accuracy(pred, label_pl)
    dice_lv = tf_categorical_dice(pred, label_pl, 1)
    dice_myo = tf_categorical_dice(pred, label_pl, 2)
    dice_rv = tf_categorical_dice(pred, label_pl, 3)

    # Optimiser
    lr = learning_rate

    # We need to add the operators associated with batch_normalization
    # to the optimiser, according to
    # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        print('Using Adam optimizer.')
        train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    # Model name and directory
    model_name = 'FCN_{0}_level{1}_filter{2}_{3}_batch{4}_iter{5}_lr{6}'.format(
        'sa', n_level, n_filter[0], ''.join([str(x) for x in n_block]),
        train_batch_size, train_iteration, learning_rate)
    model_dir = os.path.join(checkpoint_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Start the tensorflow session
    with tf.Session() as sess:
        print('Start training...')
        start_time = time.time()

        # Create a saver
        saver = tf.train.Saver(max_to_keep=20)

        # Summary writer. Any previous log for this model is removed first;
        # shutil.rmtree avoids building a shell command from a path.
        summary_dir = os.path.join(log_dir, model_name)
        if os.path.isdir(summary_dir):
            shutil.rmtree(summary_dir)
        elif os.path.exists(summary_dir):
            os.remove(summary_dir)
        train_writer = tf.summary.FileWriter(os.path.join(summary_dir, 'train'), graph=sess.graph)
        validation_writer = tf.summary.FileWriter(os.path.join(summary_dir, 'validation'), graph=sess.graph)

        try:
            # Initialise variables
            sess.run(tf.global_variables_initializer())

            # Import the computation graph and restore the variable values
            print('Loading pretrained model...')
            saverOld = tf.train.import_meta_graph('{0}.meta'.format(old_model_path))
            saverOld.restore(sess, '{0}'.format(old_model_path))

            # Iterate
            for iteration in range(1, 1 + train_iteration):
                # For each iteration, we randomly choose a batch of subjects
                print('Iteration {0}: training...'.format(iteration))
                start_time_iter = time.time()

                images, labels = get_random_batch(data_list['train'],
                                                  train_batch_size,
                                                  image_size=imgSize,
                                                  data_augmentation=True,
                                                  shift=0, rotate=90, scale=0.2,
                                                  intensity=0, flip=False)

                # Stochastic optimisation using this batch
                _, train_loss, train_acc = sess.run([train_op, loss, accuracy],
                                                    {image_pl: images, label_pl: labels, training_pl: True})

                summary = tf.Summary()
                summary.value.add(tag='loss', simple_value=train_loss)
                summary.value.add(tag='accuracy', simple_value=train_acc)
                train_writer.add_summary(summary, iteration)

                # Print the results for this iteration
                print('Iteration {} of {} took {:.3f}s'.format(iteration, train_iteration,
                                                               time.time() - start_time_iter))
                print(' training loss:\t\t{:.6f}'.format(train_loss))
                print(' training accuracy:\t\t{:.2f}%'.format(train_acc * 100))

                # Save model (a checkpoint is written on every iteration;
                # the saver keeps only the 20 most recent)
                saver.save(sess, save_path=os.path.join(model_dir, '{0}.ckpt'.format(model_name)),
                           global_step=iteration)
        finally:
            # Close the summary writers even if training is interrupted
            train_writer.close()
            validation_writer.close()
        print('Training took {:.3f}s in total.\n'.format(time.time() - start_time))
import os, sys, re
import nibabel as nib
import numpy as np

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from common.image_utils import np_categorical_dice


testDir = '/home/pcorrado/Cardiac-DL-Segmentation-Paper/test'
numLayers = [0,4,8,12,14,15]


def compareImages(img1Path, img2Path):
    """Return the Dice scores for labels 1 and 3 between two label volumes.

    Both files are loaded with nibabel and rounded to the nearest integer
    label before being passed to ``np_categorical_dice``. The main block
    unpacks the two returned scores as (LV, RV).
    """
    # np.asanyarray(img.dataobj) replaces the removed nibabel get_data()
    img1 = np.round(np.asanyarray(nib.load(img1Path).dataobj))
    img2 = np.round(np.asanyarray(nib.load(img2Path).dataobj))
    return (np_categorical_dice(img1, img2, 1), np_categorical_dice(img1, img2, 3))


if __name__ == '__main__':
    # Ground-truth files are named O<digit>_label_sa.nii; a raw string avoids
    # the invalid-escape warning and the dots are now matched literally.
    truthPattern = re.compile(r'O\d_label_sa\.nii')

    # diceDict[layers][subject] -> list of [diceLV, diceRV], one per truth file
    diceDict = {str(l): {} for l in numLayers}
    for data in sorted(os.listdir(testDir)):
        data_dir = os.path.join(testDir, data)
        print(data_dir)
        for truthImage in sorted(os.listdir(data_dir)):
            if not truthPattern.match(truthImage):
                continue
            print(os.path.join(data_dir, truthImage))
            for l in numLayers:
                cnnImage = 'sa_label_{}.nii'.format(l)
                cnnPath = os.path.join(data_dir, cnnImage)
                if not os.path.exists(cnnPath):
                    # Report and carry on instead of aborting the whole run
                    print('Warning: {0} not found, skipping.'.format(cnnPath))
                    continue
                print('Comparing {0} with {1}'.format(truthImage, cnnImage))
                diceLV, diceRV = compareImages(os.path.join(data_dir, truthImage), cnnPath)
                diceDict[str(l)].setdefault(data, []).append([diceLV, diceRV])

    # Average over the truth files of each subject first, then report the
    # mean and standard deviation across subjects for every configuration.
    meanDict = {}
    stdDict = {}
    for l in numLayers:
        print(l)
        subjectMeans = [np.mean(np.array(scores), axis=0)
                        for scores in diceDict[str(l)].values()]
        if not subjectMeans:
            print('No comparisons available for {} frozen layers.'.format(l))
            continue
        meanDict[str(l)] = np.mean(np.array(subjectMeans), axis=0)
        stdDict[str(l)] = np.std(np.array(subjectMeans), axis=0)
        print(meanDict[str(l)])
        print(stdDict[str(l)])
function cnnSeg2Segment(dirName)
%cnnSeg2Segment converts nifti format short axis bSSFP images and nifti
%format LV & RV segmentation to a format compatible with Medviso Segment software.
%   dirName must contain 'sa.nii.gz' (images) and 'seg_sa.nii.gz'
%   (segmentation). In the segmentation, voxels equal to 1 are traced as the
%   LV endocardium and voxels equal to -1 as the RV endocardium.
%   The result is written to 'cnn2Seg.mat' inside dirName.
if nargin<1 || (~ischar(dirName) && ~isstring(dirName))
    % An identifier plus a format string makes error() interpret the \n.
    error('cnnSeg2Segment:invalidInput', ...
          'Directory not found.\nUsage: cnnSeg2Segment(dirName);');
end

nPts = 80;  % number of points per resampled contour
outFile = fullfile(dirName,'cnn2Seg.mat');

setstruct.IM = niftiread(fullfile(dirName,'sa.nii.gz'));
setstruct.IM = single(permute(setstruct.IM,[2,1,4,3]))./single(max(setstruct.IM(:)));

seg = niftiread(fullfile(dirName,'seg_sa.nii.gz'));

% Preallocate the contour arrays (points x time x slice); NaN marks
% slices/frames without a usable contour.
setstruct.EndoY   = nan(nPts,size(seg,4),size(seg,3));
setstruct.EndoX   = nan(nPts,size(seg,4),size(seg,3));
setstruct.RVEndoY = nan(nPts,size(seg,4),size(seg,3));
setstruct.RVEndoX = nan(nPts,size(seg,4),size(seg,3));

for z=1:size(seg,3)
    for t=1:size(seg,4)
        [setstruct.EndoX(:,t,z), setstruct.EndoY(:,t,z)] = ...
            resampleLargestBoundary(seg(:,:,z,t)==1, nPts);
        [setstruct.RVEndoX(:,t,z), setstruct.RVEndoY(:,t,z)] = ...
            resampleLargestBoundary(seg(:,:,z,t)==-1, nPts);
    end
end

info = niftiinfo(fullfile(dirName,'sa.nii.gz'));

setstruct.ResolutionX = info.PixelDimensions(1);
setstruct.ResolutionY = info.PixelDimensions(2);
setstruct.SliceThickness = info.PixelDimensions(3);
setstruct.SliceGap = 0;
setstruct.TIncr = info.PixelDimensions(4)./1000;
setstruct.XSize = info.ImageSize(1);
setstruct.YSize = info.ImageSize(2);
setstruct.ZSize = info.ImageSize(3);
setstruct.TSize = info.ImageSize(4);

setstruct.ImagePosition = info.Transform.T(4,1:3).*[-1,-1,1];
R = info.Transform.T(1:3,1:3)./repmat([setstruct.ResolutionX;setstruct.ResolutionY;setstruct.SliceThickness],[1,3]);
R = R'.*repmat([-1;-1;1],[1,3]);
setstruct.ImageOrientation = R(1:6);
setstruct.StartSlice=1;
setstruct.EndSlice=1;
setstruct.CurrentTimeFrame=1;
setstruct.CurrentSlice=1;
setstruct.OrgXSize=setstruct.XSize;
setstruct.OrgYSize=setstruct.YSize;
setstruct.OrgZSize=setstruct.ZSize;
setstruct.OrgTSize=setstruct.TSize;
setstruct.TDelay=0;
setstruct.TimeVector=0:setstruct.TIncr:setstruct.TIncr*(setstruct.TSize-1);
setstruct.EchoTime=1;
setstruct.T2preptime=NaN;
setstruct.RepetitionTime=1;
setstruct.InversionTime=0;
setstruct.TriggerTime=zeros(setstruct.ZSize,setstruct.TSize);
setstruct.FlipAngle=1;
setstruct.AccessionNumber='X';
setstruct.StudyUID='';
setstruct.StudyID='';
setstruct.NumberOfAverages=1;
setstruct.VENC=0;
setstruct.GEVENCSCALE=0;
setstruct.Scanner='GE';
setstruct.Modality='MR';
setstruct.PathName=dirName;
setstruct.FileName=outFile;  % kept in sync with the file actually saved below
setstruct.OrigFileName='';
setstruct.PatientInfo=struct('Name','cnn2Seg',...
    'ID','cnn2Seg',...
    'BirthDate','19010101',...
    'Sex','M',...
    'Age',99,...
    'AcquisitionDate','19010101',...
    'Length',100,...
    'Weight',100,...
    'BSA',2,...
    'Institution','');
setstruct.XMin=1;
setstruct.YMin=1;
setstruct.Cyclic=1;
setstruct.Bitstored=info.BitsPerPixel;
setstruct.Rotated=0;
setstruct.SequenceName='';
setstruct.SeriesDescription='';
setstruct.SeriesNumber=7;
setstruct.AcquisitionTime=00001;
setstruct.DICOMImageType='';
setstruct.HeartRate=60./(setstruct.TIncr)./(setstruct.TSize);
setstruct.BeatTime=(setstruct.TIncr).*(setstruct.TSize);
setstruct.Scar=[];
setstruct.Flow=[];
setstruct.Report=[];
setstruct.Perfusion=[];
setstruct.PerfusionScoring=[];
setstruct.Strain=[];
setstruct.StrainTagging=[];
setstruct.Stress=[];
setstruct.CenterX=100;
setstruct.CenterY=100;
setstruct.RoiCurrent=[];
setstruct.RoiN=0;
setstruct.Roi=struct('X',[],'Y',[],'T',[],'Z',[],'Sign',[],'Name','','LineSpec','','Area',[],'Mean',[],'StD',[],'Flow',[]);
setstruct.EndoPinX=[];
setstruct.EndoPinY=[];
setstruct.EndoInterpX=[];
setstruct.EndoInterpY=[];
setstruct.EndoXView=zeros(1215,setstruct.TSize);
setstruct.EndoYView=zeros(1215,setstruct.TSize);
setstruct.EndoPinXView=[];
setstruct.EndoPinYView=[];
setstruct.EpiX=zeros(80,setstruct.TSize,setstruct.ZSize);
setstruct.EpiY=zeros(80,setstruct.TSize,setstruct.ZSize);
setstruct.EpiXView=zeros(1215,setstruct.TSize);
setstruct.EpiYView=zeros(1215,setstruct.TSize);
setstruct.EpiPinX=[];
setstruct.EpiPinY=[];
setstruct.EpiInterpX=[];
setstruct.EpiInterpY=[];
setstruct.EpiPinXView=[];
setstruct.EpiPinYView=[];
setstruct.EndoDraged=false(setstruct.TSize,setstruct.ZSize);
setstruct.EpiDraged=false(setstruct.TSize,setstruct.ZSize);
setstruct.RVEndoInterpX=cell(setstruct.TSize,setstruct.ZSize);
setstruct.RVEndoInterpY=cell(setstruct.TSize,setstruct.ZSize);
setstruct.RVEndoXView=zeros(1,setstruct.TSize);
setstruct.RVEndoYView=zeros(1,setstruct.TSize);
setstruct.RVEpiX=[];
setstruct.RVEpiY=[];
setstruct.RVEpiInterpX=[];
setstruct.RVEpiInterpY=[];
setstruct.RVEpiXView=NaN;
setstruct.RVEpiYView=NaN;
setstruct.RVEndoPinX=[];
setstruct.RVEndoPinY=[];
setstruct.RVEpiPinX=[];
setstruct.RVEpiPinY=[];
setstruct.RVEndoPinXView=[];
setstruct.RVEndoPinYView=[];
setstruct.RVEpiPinXView=[];
setstruct.RVEpiPinYView=[];
setstruct.SectorRotation=0;
setstruct.Mmode=zeros(1,setstruct.TSize);
setstruct.LVV=zeros(1,setstruct.TSize);
setstruct.EPV=zeros(1,setstruct.TSize);
setstruct.PV=zeros(1,setstruct.TSize);
setstruct.LVM = zeros(1,setstruct.TSize);
setstruct.ESV=0;
setstruct.EDV=0;
setstruct.EDT=1;
setstruct.EST=1;
setstruct.SV=0;
setstruct.EF=0;
setstruct.PFR=544.8117;
setstruct.PER=458.6786;
setstruct.PFRT=11;
setstruct.PERT=2;
setstruct.RVV=zeros(1,setstruct.TSize);
setstruct.RVEPV=nan(1,setstruct.TSize);
setstruct.RVM=nan(1,setstruct.TSize);
setstruct.RVESV=1;
setstruct.RVEDV=1;
setstruct.RVSV=0;
setstruct.RVEF=0;
setstruct.SpectSpecialTag=[];
setstruct.RVPFR=405.1299;
setstruct.RVPFRT=12;
setstruct.RVPER=569.7068;
setstruct.RVPERT=2;
setstruct.ImageType='Cine';
setstruct.ImageViewPlane='Short-axis';
setstruct.ImagingTechnique='MRSSFP';
setstruct.IntensityScaling=2644;
setstruct.IntensityOffset=0;
setstruct.StartAnalysis=1;
setstruct.EndAnalysis=setstruct.TSize;
setstruct.EndoCenter=1;
setstruct.NormalZoomState=[0.5;0.5+setstruct.XSize;0.5;setstruct.YSize+0.5];
setstruct.MontageZoomState=[];
setstruct.MontageRowZoomState=[];
setstruct.MontageFitZoomState=[];
setstruct.Measure=[];
setstruct.RV=struct();
setstruct.LevelSet=[];
setstruct.OrgRes=info.PixelDimensions./[1,1,1,1000];
setstruct.AutoLongaxis=0;
setstruct.Longaxis=1;
setstruct.Point=struct('X',[],'Y',[],'T',[],'Z',[]);
setstruct.Point.Label={};  % assigned separately: a cell passed to struct() would create a struct array
setstruct.IntensityMapping=struct('Brightness',0.5779,'Contrast',2.2695,'Compression',[]);
setstruct.Colormap=[];
setstruct.View=struct('ViewPanels',1,...
    'ViewMatrix',[1 1],'ThisFrameOnly',1,...
    'CurrentPanel',1,'CurrentTheme','rv','CurrentTool','select');
setstruct.View.ViewPanelsType={'one'};
setstruct.View.ViewPanelsMatrix={[4 4]};
setstruct.RotationCenter=[];
setstruct.Fusion=[];
setstruct.ProgramVersion='2.2R6435';
setstruct.PapillaryIM=[];
setstruct.MaR=[];
setstruct.T2=[];
setstruct.Children=[];
setstruct.Parent=[];
setstruct.Linked=1;
setstruct.Overlay=[];
setstruct.Intersection=[];
setstruct.Comment=[];
setstruct.AtrialScar=[];
setstruct.SAX3=[];
setstruct.HLA=[];
setstruct.VLA=[];
setstruct.GLA=[];
setstruct.CT=[];
setstruct.PapillaryThreshold=0;
setstruct.ECV=[];
setstruct.Developer=[];
setstruct.EndoInterpXView=[];
setstruct.EndoInterpYView=[];
setstruct.EpiInterpXView=[];
setstruct.EpiInterpYView=[];
setstruct.RVEndoInterpXView=cell(setstruct.TSize,1);
setstruct.RVEndoInterpYView=cell(setstruct.TSize,1);
setstruct.RVEpiInterpXView=[];
setstruct.RVEpiInterpYView=[];

% The remaining variables are only consumed by save() below.
im = []; %#ok
preview = repmat(uint8(setstruct.IM(:,:,1,1).*255),[1,1,3]); %#ok
info = struct('Name','cnn2Seg',...
    'ID','cnn2Seg',...
    'BirthDate','19010101',...
    'Sex','M',...
    'Age',99,...
    'AcquisitionDate','19010101',...
    'Length',100,...
    'Weight',100,...
    'BSA',2,...
    'Institution','',...
    'NFrames',0,...
    'NumSlices',0,...
    'ResolutionX',0,...
    'ResolutionY',0,...
    'SliceThickness',0,...
    'SliceGap',0,...
    'TIncr',0,...
    'EchoTime',0,...
    'FlipAngle',0,...
    'AccessionNumber','',...
    'StudyUID','',...
    'StudyID','',...
    'NumberOfAverages',0,...
    'RepetitionTime',0,...
    'InversionTime',0,...
    'TDelay',0,...
    'VENC',0,...
    'Scanner','',...
    'ImagingTechnique','',...
    'ImageType','',...
    'ImageViewPlane','',...
    'IntensityScaling',1,...
    'IntensityOffset',0,...
    'MultiDataSet',true,...
    'Modality','MR'); %#ok
save(outFile,'preview','info','im','setstruct');
end

function [cx, cy] = resampleLargestBoundary(bw, nPts)
%resampleLargestBoundary trace the boundary with the most points in the
%binary image bw and resample it to nPts points with interparc. cy is taken
%from the first boundary column and cx from the second; both are NaN when
%no boundary with more than two points exists.
cx = nan(nPts,1);
cy = nan(nPts,1);
B = bwboundaries(bw);
if isempty(B)
    return;
end
[~,ind] = max(cellfun(@numel,B));
if size(B{ind},1)>2
    pt = interparc(linspace(0,1,nPts),B{ind}(:,1),B{ind}(:,2),'csape');
    cy = pt(:,1);
    cx = pt(:,2);
end
end
function data = computeLVFlowCompartments(inputDir,tShift, maskFile)
%computeLVFlowCompartments compute the amount of intra-LV flow in each of
%   the 4 main flow compartments (see Eriksson J, et al. 2010.
%   https://doi.org/10.1186/1532-429X-12-9)
%   data = computeLVFlowCompartments(inputDir) calculates flow compartments
%   based on the 4D flow data and mask located in directory inputDir
%   (default: the current directory). The directory should contain 4D flow
%   data in nifti format (MAG.nii, CD.nii, VELX.nii, VELY.nii, VELZ.nii,
%   registeredMask.nii); '.nii.gz' versions are used when the '.nii' files
%   are absent.
%   data = computeLVFlowCompartments(inputDir,tShift) circularly shifts the
%   mask by tShift frames along time before use (default 0).
%   data = computeLVFlowCompartments(inputDir,tShift,maskFile) reads the
%   mask from maskFile when that file exists.
%   LV voxels are those with mask>0. The returned struct holds the pathline
%   indices of each compartment plus the pathlines themselves (data.paths).

% Parse Input 1
curDir = pwd;
if nargin<1
    fprintf('No input directory, using %s.\n', pwd);
    inputDir = pwd;
else
    fprintf('Computing pathlines for %s.\n', inputDir);
end

cd(inputDir);
% Return to the caller's directory on normal exit and on error alike.
restoreDir = onCleanup(@() cd(curDir)); %#ok<NASGU>

% Parse Input 2
if nargin<2 || isempty(tShift) || ~isnumeric(tShift)
    tShift = 0;
end

fprintf('Reading images...\n');
if nargin>2 && exist(maskFile,'file')
    mask = circshift(round(squeeze(niftiread(maskFile))),[0,0,0,tShift]);
elseif exist('registeredMask.nii','file')
    mask = circshift(squeeze(round(niftiread('registeredMask.nii'))),[0,0,0,tShift]);
else
    mask = circshift(squeeze(round(niftiread('registeredMask.nii.gz'))),[0,0,0,tShift]);
end

if exist('VELX.nii','file')
    ext = '.nii';
else
    ext = '.nii.gz';
end

info = niftiinfo(['VELX',ext]);
vol = prod(info.PixelDimensions(1:3)/1000); % pixel volume in m^3

lvVolume = squeeze(sum(sum(sum(mask>0,1),2),3).*vol.*1e6); % in mL

staticMask = all(lvVolume==lvVolume(1));

[~,edvTime] = max(lvVolume);
if staticMask
    edvTime = 7-tShift;
end

% Rotate time so that frame 1 is end-diastole
mask = circshift(mask,[0,0,0,-(edvTime-1)]);

vx = double(circshift(-niftiread(['VELX',ext]),[0,0,0,-(edvTime-1)]));
vy = double(circshift(-niftiread(['VELY',ext]),[0,0,0,-(edvTime-1)]));
vz = double(circshift(-niftiread(['VELZ',ext]),[0,0,0,-(edvTime-1)]));

lvVolume = squeeze(sum(sum(sum(mask>0,1),2),3).*vol.*1e6); % in mL
[~,esvTime] = min(lvVolume);
if staticMask % Special case to handle static mask (i.e. phantom project)
    esvTime = 13;
end
% Forward through time mask and velocity
maskF = mask(:,:,:,1:esvTime);
vxF = vx(:,:,:,1:esvTime);
vyF = vy(:,:,:,1:esvTime);
vzF = vz(:,:,:,1:esvTime);
% Reverse through time mask and velocity (velocity sign flipped)
revFrames = [1,size(mask,4):-1:esvTime];
maskR = mask(:,:,:,revFrames);
vxR = -vx(:,:,:,revFrames);
vyR = -vy(:,:,:,revFrames);
vzR = -vz(:,:,:,revFrames);

fprintf('Done reading images.\n');

voxelSize = info.PixelDimensions;
voxelSize(4) = voxelSize(4)./1000; % from milliseconds to seconds

fprintf('Computing forward pathlines...\n');
fpaths = computePaths(maskF, vxF, vyF, vzF, voxelSize);

fprintf('Computing backward pathlines...\n');
rpaths = computePaths(maskR,vxR,vyR,vzR, voxelSize);
% Map the backward integration time onto the (cyclic) acquisition time axis
rpaths(:,4,:) = mod(-rpaths(:,4,:),voxelSize(4).*size(mask,4));

paths = [flip(rpaths(2:end,:,:),1);fpaths]; % Combine forward and backward pathlines
data = classifyPaths(mask, paths, fpaths, rpaths, voxelSize);
data.paths = paths;

% Fractions are relative to the pathlines that passed the QC check
nPaths = size(data.paths,3);
nValid = nPaths-numel(data.errant);
direct = numel(data.direct)/nValid;
delayed = numel(data.delayed)/nValid;
retained = numel(data.retained)/nValid;
residual = numel(data.residual)/nValid;
passing = 1-numel(data.errant)/nPaths;

fprintf('\nPercentage passing QC check: %%%3.1f\n',passing*100);
fprintf('\nPercentage direct flow: %%%3.1f\n',direct*100);
fprintf('Percentage delayed ejection: %%%3.1f\n',delayed*100);
fprintf('Percentage retained inflow: %%%3.1f\n',retained*100);
fprintf('Percentage residual volume: %%%3.1f\n',residual*100);
end
function data = classifyPaths(mask, paths, fPaths, rPaths, voxSize)
%classifyPaths sort pathlines into flow compartments.
%   mask    - 4D label volume; voxels >0 count as inside, label 3 marks the
%             region through which a pathline may enter or leave.
%   paths   - nT x 4 x nPath array of [x,y,z,t] samples per pathline.
%   fPaths, rPaths - forward / backward pathlines; only their first sample
%             is used, to report seed positions in voxel units.
%   voxSize - [dx,dy,dz,dt] grid spacing.
%   data holds the pathline indices of each class (residual, direct,
%   retained, delayed, errant) and the matching seed positions.
fprintf('Classifying pathlines...\n');

% Repeat the first frame at the end so the last time interval can be sampled
mask = cat(4,mask,mask(:,:,:,1));

[x,y,z,t] = ndgrid( (1:size(mask,1)).*voxSize(1),...
                    (1:size(mask,2)).*voxSize(2),...
                    (1:size(mask,3)).*voxSize(3),...
                    (0:(size(mask,4)-1)).*voxSize(4) );

F = griddedInterpolant(x,y,z,t,double(mask),'nearest', 'nearest');

px = reshape(paths(:,1,:),[],1);
py = reshape(paths(:,2,:),[],1);
pz = reshape(paths(:,3,:),[],1);
pt = reshape(paths(:,4,:),[],1);

nT = size(paths,1);
nPath = size(paths,3);
locations = reshape(F([px,py,pz,pt]), nT, nPath);

% Preallocate. A pathline that is never inside the mask keeps NaN/false
% and is therefore reported as errant instead of raising an error.
firstIn = nan(1,nPath);
lastIn = nan(1,nPath);
qcCheck = false(1,nPath);
for ii = 1:nPath
    inside = find(locations(:,ii)>0);
    if isempty(inside)
        continue;
    end
    firstIn(ii) = inside(1);
    lastIn(ii) = inside(end);
    % Check that pathline only enters or leaves LV through the base,
    % not the mid-ventricle or the apex.
    qcCheck(ii) = (firstIn(ii)==1 || locations(firstIn(ii),ii)==3) && ...
                  ( lastIn(ii)==nT || locations(lastIn(ii),ii)==3);
end
startsIn = firstIn==1;
enters = firstIn~=1;
endsIn = lastIn==nT;
leaves = lastIn~=nT;

data.residual = find(startsIn & endsIn & qcCheck);
data.direct = find(enters & leaves & qcCheck);
data.retained = find(enters & endsIn & qcCheck);
data.delayed = find(startsIn & leaves & qcCheck);
data.errant = find(~qcCheck);

% Seed position (first sample) of the selected pathlines, in voxel units
seedVox = @(p,idx) squeeze(p(1,1:3,idx)./repmat(voxSize(1:3),[1,1,numel(idx)]));

data.directStartF = seedVox(fPaths,data.direct);
data.retainedStartF = seedVox(fPaths,data.retained);
data.delayedStartF = seedVox(fPaths,data.delayed);
data.residualStartF = seedVox(fPaths,data.residual);

data.directStartR = seedVox(rPaths,data.direct);
data.retainedStartR = seedVox(rPaths,data.retained);
data.delayedStartR = seedVox(rPaths,data.delayed);
data.residualStartR = seedVox(rPaths,data.residual);
end
function paths = computePaths(mask, vx, vy, vz, voxSize)
%computePaths integrate pathlines through a time-resolved velocity field.
%   One pathline is seeded at every voxel with mask>0 in the first frame and
%   advanced with a 4th-order Runge-Kutta scheme using 100 sub-steps per
%   frame. Velocities come from linear interpolation of vx, vy, vz, sampled
%   at the start time of the current frame plus the sub-step offset.
%   voxSize is [dx,dy,dz,dt]. The result is nT x 4 x nVox, holding
%   [x,y,z,t] once per frame for each pathline.
fprintf('Computing pathlines...\n');

stepsPerFrame = 100;

[gx,gy,gz,gt] = ndgrid( (1:size(vx,1)    ).*voxSize(1),...
                        (1:size(vx,2)    ).*voxSize(2),...
                        (1:size(vx,3)    ).*voxSize(3),...
                        (0:(size(vx,4)-1)).*voxSize(4) );
frameTimes = (0:(size(vx,4)-1)).*voxSize(4);
nFrames = numel(frameTimes);

seedIdx = find(mask(:,:,:,1)>0);
nSeeds = numel(seedIdx);
dt = (voxSize(4))/stepsPerFrame; % step size

fprintf('Setting up interpolation grids.\n');
if (nFrames>1)
    interpX = griddedInterpolant(gx,gy,gz,gt,vx,'linear', 'none');
    interpY = griddedInterpolant(gx,gy,gz,gt,vy,'linear', 'none');
    interpZ = griddedInterpolant(gx,gy,gz,gt,vz,'linear', 'none');

    velocityAt = @(tq,rq) [interpX([rq,tq]), interpY([rq,tq]), interpZ([rq,tq])];
else
    % Single frame: purely spatial interpolation, the query time is ignored
    interpX = griddedInterpolant(gx,gy,gz,vx,'linear', 'none');
    interpY = griddedInterpolant(gx,gy,gz,vy,'linear', 'none');
    interpZ = griddedInterpolant(gx,gy,gz,vz,'linear', 'none');

    velocityAt = @(tq,rq) [interpX(rq), interpY(rq), interpZ(rq)];
end

fprintf('Found %i mask voxels. Making %i pathlines.\n',nSeeds,nSeeds);
nSteps = nFrames*stepsPerFrame;
paths = zeros(nSeeds,4,nSteps);

paths(:,:,1) = [gx(seedIdx),gy(seedIdx),gz(seedIdx),gt(seedIdx)];
for step=1:(nSteps-1) % calculation loop
    % Start time of the frame this sub-step belongs to, one row per seed
    tFrame = repmat(frameTimes(ceil(step/stepsPerFrame)),[nSeeds,1]);
    pos = paths(:,1:3,step);

    k1 = velocityAt(tFrame,          pos);
    k2 = velocityAt(tFrame + 0.5*dt, pos + 0.5*dt*k1);
    k3 = velocityAt(tFrame + 0.5*dt, pos + 0.5*dt*k2);
    if step<(nSteps-1)
        k4 = velocityAt(tFrame + 1.0*dt, pos + 1.0*dt*k3);
    else
        % Last step: sample at the final frame time so the query stays on the grid
        k4 = velocityAt(repmat(frameTimes(nFrames),[nSeeds,1]), pos + 1.0*dt*k3);
    end
    paths(:,:,step+1) = paths(:,:,step) + [(1/6)*(k1+2*k2+2*k3+k4)*dt, repmat(dt,[nSeeds,1])]; % main equation
end
% Keep one sample per frame
paths = paths(:,:,1:stepsPerFrame:end);
fprintf('Done computing pathlines.\n');

paths = permute(paths,[3,2,1]);
end
function data = computeRVFlowCompartments(inputDir,tShift, maskFile)
%computeRVFlowCompartments compute the amount of intra-RV flow in each of
%   the 4 main flow compartments (see Eriksson J, et al. 2010.
%   https://doi.org/10.1186/1532-429X-12-9)
%   data = computeRVFlowCompartments(inputDir) calculates flow compartments
%   based on the 4D flow data and mask located in directory inputDir
%   (default: the current directory). The directory should contain 4D flow
%   data in nifti format (MAG.nii, CD.nii, VELX.nii, VELY.nii, VELZ.nii,
%   registeredMask.nii); '.nii.gz' versions are used when the '.nii' files
%   are absent.
%   data = computeRVFlowCompartments(inputDir,tShift) circularly shifts the
%   mask by tShift frames along time before use (default 0).
%   data = computeRVFlowCompartments(inputDir,tShift,maskFile) reads the
%   mask from maskFile when that file exists.
%   RV voxels are those with mask<0. The returned struct holds the pathline
%   indices of each compartment plus the pathlines themselves (data.paths).

curDir = pwd;
if nargin<1
    fprintf('No input directory, using %s.\n', pwd);
    inputDir = pwd;
else
    fprintf('Computing pathlines for %s.\n', inputDir);
end

cd(inputDir);
% Return to the caller's directory on normal exit and on error alike.
restoreDir = onCleanup(@() cd(curDir)); %#ok<NASGU>

if nargin<2 || isempty(tShift) || ~isnumeric(tShift)
    tShift = 0;
end

fprintf('Reading images...\n');
if nargin>2 && exist(maskFile,'file')
    mask = circshift(squeeze(round(niftiread(maskFile))),[0,0,0,tShift]);
elseif exist('registeredMask.nii','file')
    mask = circshift(squeeze(round(niftiread('registeredMask.nii'))),[0,0,0,tShift]);
else
    mask = circshift(squeeze(round(niftiread('registeredMask.nii.gz'))),[0,0,0,tShift]);
end

if exist('VELX.nii','file')
    ext = '.nii';
else
    ext = '.nii.gz';
end

info = niftiinfo(['VELX',ext]);
vol = prod(info.PixelDimensions(1:3)/1000); % pixel volume in m^3

rvVolume = squeeze(sum(sum(sum(mask<0,1),2),3).*vol.*1e6); % in mL

staticMask = all(rvVolume==rvVolume(1));

[~,edvTime] = max(rvVolume);
if staticMask
    edvTime = 7-tShift;
end

% Rotate time so that frame 1 is end-diastole
mask = circshift(mask,[0,0,0,-(edvTime-1)]);

vx = double(circshift(-niftiread(['VELX',ext]),[0,0,0,-(edvTime-1)]));
vy = double(circshift(-niftiread(['VELY',ext]),[0,0,0,-(edvTime-1)]));
vz = double(circshift(-niftiread(['VELZ',ext]),[0,0,0,-(edvTime-1)]));

% RV voxels are the negative labels. The previous mask>0 test measured the
% LV volume here, so end-systole was picked from the wrong ventricle.
rvVolume = squeeze(sum(sum(sum(mask<0,1),2),3).*vol.*1e6); % in mL
[~,esvTime] = min(rvVolume);
if staticMask % Special case to handle static mask (i.e. phantom project)
    esvTime = 13;
end
% Forward through time mask and velocity
maskF = mask(:,:,:,1:esvTime);
vxF = vx(:,:,:,1:esvTime);
vyF = vy(:,:,:,1:esvTime);
vzF = vz(:,:,:,1:esvTime);
% Reverse through time mask and velocity (velocity sign flipped)
revFrames = [1,size(mask,4):-1:esvTime];
maskR = mask(:,:,:,revFrames);
vxR = -vx(:,:,:,revFrames);
vyR = -vy(:,:,:,revFrames);
vzR = -vz(:,:,:,revFrames);

fprintf('Done reading images.\n');

voxelSize = info.PixelDimensions;
voxelSize(4) = voxelSize(4)./1000; % from milliseconds to seconds

fprintf('Computing forward pathlines...\n');
fpaths = computePaths(maskF, vxF, vyF, vzF, voxelSize);

fprintf('Computing backward pathlines...\n');
rpaths = computePaths(maskR,vxR,vyR,vzR, voxelSize);
% Map the backward integration time onto the (cyclic) acquisition time axis
rpaths(:,4,:) = mod(-rpaths(:,4,:),voxelSize(4).*size(mask,4));

paths = [flip(rpaths(2:end,:,:),1);fpaths]; % Combine forward and backward pathlines
data = classifyPaths(mask, paths, fpaths, rpaths, voxelSize);
data.paths = paths;

% Fractions are relative to the pathlines that passed the QC check
nPaths = size(data.paths,3);
nValid = nPaths-numel(data.errant);
direct = numel(data.direct)/nValid;
delayed = numel(data.delayed)/nValid;
retained = numel(data.retained)/nValid;
residual = numel(data.residual)/nValid;
passing = 1-numel(data.errant)/nPaths;

fprintf('\nPercentage passing QC check: %%%3.1f\n',passing*100);
fprintf('\nPercentage direct flow: %%%3.1f\n',direct*100);
fprintf('Percentage delayed ejection: %%%3.1f\n',delayed*100);
fprintf('Percentage retained inflow: %%%3.1f\n',retained*100);
fprintf('Percentage residual volume: %%%3.1f\n',residual*100);
end
109 | (1:size(mask,2)).*voxSize(2),... 110 | (1:size(mask,3)).*voxSize(3),... 111 | (0:(size(mask,4)-1)).*voxSize(4) ); 112 | 113 | F = griddedInterpolant(x,y,z,t,double(mask),'nearest', 'nearest'); 114 | 115 | px = reshape(paths(:,1,:),[],1); 116 | py = reshape(paths(:,2,:),[],1); 117 | pz = reshape(paths(:,3,:),[],1); 118 | pt = reshape(paths(:,4,:),[],1); 119 | 120 | nT = size(paths,1); 121 | nPath = size(paths,3); 122 | locations = reshape(F([px,py,pz,pt]), nT, nPath); 123 | 124 | for ii = 1:nPath 125 | firstIn(ii) = find(locations(:,ii)<0,1,'first'); %#ok 126 | lastIn(ii) = find(locations(:,ii)<0,1,'last'); %#ok 127 | qcCheck(ii) = (firstIn(ii)==1 || locations(firstIn(ii),ii)==-3) && ... 128 | ( lastIn(ii)==nT || locations(lastIn(ii),ii)==-3); %#ok 129 | end 130 | startsIn = firstIn==1; 131 | enters = firstIn~=1; 132 | endsIn = lastIn==nT; 133 | leaves = lastIn~=nT; 134 | 135 | data.residual = find(startsIn .* endsIn .* qcCheck); 136 | data.direct = find(enters .* leaves .* qcCheck); 137 | data.retained = find(enters .* endsIn .* qcCheck); 138 | data.delayed = find(startsIn .* leaves .* qcCheck); 139 | data.errant = find(~qcCheck); 140 | 141 | data.directStartF = squeeze(fPaths(1,1:3,data.direct)./repmat(voxSize(1:3),[1,1,numel(data.direct)])); 142 | data.retainedStartF = squeeze(fPaths(1,1:3,data.retained)./repmat(voxSize(1:3),[1,1,numel(data.retained)])); 143 | data.delayedStartF = squeeze(fPaths(1,1:3,data.delayed)./repmat(voxSize(1:3),[1,1,numel(data.delayed)])); 144 | data.residualStartF = squeeze(fPaths(1,1:3,data.residual)./repmat(voxSize(1:3),[1,1,numel(data.residual)])); 145 | 146 | data.directStartR = squeeze(rPaths(1,1:3,data.direct)./repmat(voxSize(1:3),[1,1,numel(data.direct)])); 147 | data.retainedStartR = squeeze(rPaths(1,1:3,data.retained)./repmat(voxSize(1:3),[1,1,numel(data.retained)])); 148 | data.delayedStartR = squeeze(rPaths(1,1:3,data.delayed)./repmat(voxSize(1:3),[1,1,numel(data.delayed)])); 149 | data.residualStartR = 
squeeze(rPaths(1,1:3,data.residual)./repmat(voxSize(1:3),[1,1,numel(data.residual)])); 150 | end 151 | 152 | function paths = computePaths(mask, vx, vy, vz, voxSize) 153 | fprintf('Computing pathlines...\n'); 154 | 155 | stepSizeFraction = 100; 156 | 157 | [x,y,z,t] = ndgrid( (1:size(vx,1) ).*voxSize(1),... 158 | (1:size(vx,2) ).*voxSize(2),... 159 | (1:size(vx,3) ).*voxSize(3),... 160 | (0:(size(vx,4)-1)).*voxSize(4) ); 161 | tt = (0:(size(vx,4)-1)).*voxSize(4); 162 | nT = numel(tt); 163 | 164 | ind = find(mask(:,:,:,1)<0); 165 | nVox = numel(ind); 166 | h = (voxSize(4))/stepSizeFraction; % step size 167 | 168 | fprintf('Setting up interpolation grids.\n'); 169 | if (nT>1) 170 | fx = griddedInterpolant(x,y,z,t,vx,'linear', 'none'); 171 | fy = griddedInterpolant(x,y,z,t,vy,'linear', 'none'); 172 | fz = griddedInterpolant(x,y,z,t,vz,'linear', 'none'); 173 | 174 | F = @(tq,rq) [fx([rq,tq]), fy([rq,tq]), fz([rq,tq])]; 175 | else 176 | fx = griddedInterpolant(x,y,z,vx,'linear', 'none'); 177 | fy = griddedInterpolant(x,y,z,vy,'linear', 'none'); 178 | fz = griddedInterpolant(x,y,z,vz,'linear', 'none'); 179 | 180 | F = @(tq,rq) [fx(rq), fy(rq), fz(rq)]; 181 | end 182 | 183 | 184 | ttt = repmat(reshape(tt,1,1,[]),[nVox,1,1]); 185 | fprintf('Found %i mask voxels. 
Making %i pathlines.\n',nVox,nVox); 186 | paths = zeros(nVox,4,nT*stepSizeFraction); 187 | 188 | paths(:,:,1) = [x(ind),y(ind),z(ind),t(ind)]; 189 | for ii=1:(nT*stepSizeFraction-1) % calculation loop 190 | k1 = F( ttt(:,1,ceil(ii/stepSizeFraction)) + 0.0*h, paths(:,1:3,ii) + 0.0*h ); 191 | k2 = F( ttt(:,1,ceil(ii/stepSizeFraction)) + 0.5*h, paths(:,1:3,ii) + 0.5*h*k1); 192 | k3 = F( ttt(:,1,ceil(ii/stepSizeFraction)) + 0.5*h, paths(:,1:3,ii) + 0.5*h*k2); 193 | if ii<(nT*stepSizeFraction-1) 194 | k4 = F( ttt(:,1,ceil(ii/stepSizeFraction)) + 1.0*h, paths(:,1:3,ii) + 1.0*h*k3); 195 | else 196 | k4 = F( ttt(:,1,nT) , paths(:,1:3,ii) + 1.0*h*k3); 197 | end 198 | paths(:,:,ii+1) = paths(:,:,ii) + [(1/6)*(k1+2*k2+2*k3+k4)*h, repmat(h,[nVox,1])]; % main equation 199 | end 200 | paths = paths(:,:,1:stepSizeFraction:end); 201 | fprintf('Done computing pathlines.\n'); 202 | 203 | paths = permute(paths,[3,2,1]); 204 | end -------------------------------------------------------------------------------- /matlab_scripts/computeVentricularKE.m: -------------------------------------------------------------------------------- 1 | function data = computeVentricularKE(srcDir,tshift,maskFile) 2 | %computeVentricularKE compute ventricular kinetic energy 3 | % data = computeVentricularKE(srcDir) calculate KE-time curve based on 4D 4 | % flow data and mask located in directory 'srcDir' 5 | % The directory should contain 4D flow data in nifti format (MAG.nii, 6 | % CD.nii, VELX.nii, VELY.nii, VELZ.nii, registeredMask.nii) 7 | 8 | 9 | % Parse Input 1 10 | curDir = pwd; 11 | if nargin<1 || (~ischar(srcDir) && ~isstring(srcDir)) || ~exist(srcDir, 'dir') 12 | srcDir = curDir; 13 | fprintf('Setting source directory to %s\n', srcDir); 14 | end 15 | srcDir = char(srcDir); 16 | 17 | % Parse Input 2: tshift is a 2-element vector; tshift(1) is the circular frame shift applied to the velocity volumes, tshift(2) the one applied to the mask 18 | if nargin<2 19 | PULSEGATED = false; 20 | if PULSEGATED; tshift=[4,0]; else tshift=[0,0]; end %#ok 21 | end 22 | 23 | rho = 1060; % blood density in kg/m^3 24 | 25 | % Check whether niftis are
compressed or not 26 | if exist(fullfile(srcDir,'VELX.nii'),'file') 27 | ext = '.nii'; 28 | else 29 | ext = '.nii.gz'; 30 | end 31 | 32 | % Read 4D flow data 33 | info = niftiinfo(fullfile(srcDir,['VELX',ext])); 34 | velX = circshift(niftiread(fullfile(srcDir,['VELX',ext])),[0,0,0,tshift(1)]); 35 | velY = circshift(niftiread(fullfile(srcDir,['VELY',ext])),[0,0,0,tshift(1)]); 36 | velZ = circshift(niftiread(fullfile(srcDir,['VELZ',ext])),[0,0,0,tshift(1)]); 37 | vol = prod(info.PixelDimensions(1:3)/1000); % pixel volume in m^3 38 | tt = linspace(0,(info.ImageSize(4)-1)*info.PixelDimensions(4),info.ImageSize(4)); 39 | 40 | % Read LV RV mask 41 | if nargin>2 && exist(maskFile,'file') 42 | mask = circshift(squeeze(niftiread(maskFile)),[0,0,0,tshift(2)]); 43 | elseif exist(fullfile(srcDir,'registeredMask.nii'),'file') 44 | mask = circshift(squeeze(round(niftiread(fullfile(srcDir,'registeredMask.nii')))),[0,0,0,tshift(2)]); 45 | else 46 | mask = circshift(squeeze(round(niftiread(fullfile(srcDir,'registeredMask.nii.gz')))),[0,0,0,tshift(2)]); 47 | end 48 | 49 | % Loop through time and compute KE at every time point 50 | for t=1:size(velZ,4) 51 | if size(mask,4)>1 52 | m = mask(:,:,:,t); 53 | else 54 | m = mask; 55 | end 56 | 57 | vX = double(velX(:,:,:,t))./1000; % in m/s 58 | vY = double(velY(:,:,:,t))./1000; % in m/s 59 | vZ = double(velZ(:,:,:,t))./1000; % in m/s 60 | vMag = sqrt(vX.^2 + vY.^2 + vZ.^2); % in m/s 61 | 62 | rvInd = m<0; 63 | lvInd = m>0; 64 | 65 | % Compute KE 66 | data.rv.KE(t) = (1/2)*rho*vol*sum(vMag(rvInd).^2).*1e6; % in uJ 67 | data.lv.KE(t) = (1/2)*rho*vol*sum(vMag(lvInd).^2).*1e6; % in uJ 68 | 69 | % Compute ventricular volume 70 | data.lv.vol(t) = sum(lvInd(:)).*vol.*1e6; % in mL 71 | data.rv.vol(t) = sum(rvInd(:)).*vol.*1e6; % in mL 72 | end 73 | 74 | % Estimate peaks 75 | data.lv = analyzeVentricleKE(data.lv, tt, 'LV'); 76 | data.rv = analyzeVentricleKE(data.rv, tt, 'RV'); 77 | 78 | end 79 | 80 | % Estimate KE peaks from KE vs. 
time curve 81 | function data = analyzeVentricleKE(data, t, ventricle) 82 | KE = data.KE; 83 | vol = data.vol; 84 | 85 | dVdt = (circshift(vol,1)-circshift(vol,-1)); % Rate of change of ventricular volume 86 | [ESV,EST] = min(vol); 87 | [EDV,~] = max(vol); 88 | SV = EDV-ESV; 89 | 90 | sysInd=1:EST; 91 | diasInd = setdiff(1:numel(vol),sysInd); 92 | 93 | eWaveInd = diasInd(1:round(numel(diasInd)*7/10)); 94 | [~,iEWave] = min(dVdt(eWaveInd)); 95 | iEWave = eWaveInd(iEWave); % index of e-wave peak 96 | 97 | aWaveInd = diasInd((round(numel(diasInd)*7/10+1):end)); 98 | [~,iAWave] = min(dVdt(aWaveInd)); 99 | iAWave = aWaveInd(iAWave); % index of a-wave peak 100 | 101 | [~,iDiastasis] = min(abs(dVdt(iEWave:iAWave))); 102 | iDiastasis = iDiastasis+iEWave-1; 103 | 104 | eWaveInd = diasInd(1:find(diasInd==iDiastasis)); 105 | aWaveInd = diasInd((find(diasInd==iDiastasis)+1):end); 106 | 107 | data.KE_EDV = KE./EDV; % in uJ/mL 108 | data.KE_SV = KE./SV; % in uJ/mL 109 | data.KE = KE./1000; % in mJ 110 | 111 | data.sysKE = mean(data.KE(sysInd)); % mean systolic KE 112 | data.sysKE_EDV = mean(data.KE_EDV(sysInd)); 113 | data.sysKE_SV = mean(data.KE_SV(sysInd)); 114 | 115 | data.diasKE = mean(data.KE(diasInd)); % mean diastolic KE 116 | data.diasKE_EDV = mean(data.KE_EDV(diasInd)); 117 | data.diasKE_SV = mean(data.KE_SV(diasInd)); 118 | 119 | data.aveKE = mean(data.KE); % average KE 120 | data.aveKE_EDV = mean(data.KE_EDV); 121 | data.aveKE_SV = mean(data.KE_SV); 122 | 123 | [data.minKE,imin] = min(data.KE); % minimum KE 124 | data.minKE_EDV = min(data.KE_EDV); 125 | data.minKE_SV = min(data.KE_SV); 126 | 127 | [data.maxSysKE,imaxS] = max(data.KE(sysInd));% peak systolic KE 128 | imaxS = sysInd(imaxS); 129 | data.maxSysKE_EDV = max(data.KE_EDV(sysInd)); 130 | data.maxSysKE_SV = max(data.KE_SV(sysInd)); 131 | 132 | [data.maxEWKE,imaxE] = max(data.KE(eWaveInd));% E-wave KE 133 | imaxE = eWaveInd(imaxE); 134 | data.maxEWKE_EDV = max(data.KE_EDV(eWaveInd)); 135 | data.maxEWKE_SV = 
max(data.KE_SV(eWaveInd)); 136 | 137 | [data.maxAWKE,imaxA] = max(data.KE(aWaveInd)); % A-Wave KE 138 | imaxA = aWaveInd(imaxA); 139 | data.maxAWKE_EDV = max(data.KE_EDV(aWaveInd)); 140 | data.maxAWKE_SV = max(data.KE_SV(aWaveInd)); 141 | 142 | % Plot ventricular volume vs. time curve 143 | figure(); 144 | plot(t,vol); 145 | ax=gca; 146 | yl = ax.YLim; 147 | hold on; 148 | line([t(iDiastasis),t(iDiastasis)],[0,yl(2)],'LineStyle','--'); 149 | hold off; 150 | ylim([0,yl(2)]); 151 | title(sprintf('%s Volume vs. Time Curve',ventricle)); 152 | xlabel('Time (ms)'); 153 | ylabel('Volume (mL)'); 154 | legend('Volume','Diastasis'); 155 | 156 | % Plot KE-time curve with peaks 157 | figure(); 158 | plot(t,data.KE_EDV,'-r*'); 159 | ax=gca; 160 | yl = ax.YLim; 161 | hold on; 162 | line([t(imaxS),t(imaxS)],[0,yl(2)],'Color','blue','LineStyle','--'); 163 | line([t(imin),t(imin)],[0,yl(2)],'Color','red','LineStyle','--'); 164 | line([t(imaxE),t(imaxE)],[0,yl(2)],'Color','green','LineStyle','--'); 165 | line([t(imaxA),t(imaxA)],[0,yl(2)],'Color','magenta','LineStyle','--'); 166 | hold off; 167 | ylim([0,yl(2)]); 168 | title(sprintf('%s Kinetic Energy Time Curve',ventricle)); 169 | xlabel('Time (ms)'); 170 | ylabel('Kinetic Energy/EDV (uJ/mL)'); 171 | legend(sprintf('%s KE-Time Curve',ventricle),'max systolic KE','minimum KE','max E-Wave Ke','max A-wave KE'); 172 | end -------------------------------------------------------------------------------- /matlab_scripts/pcvipr2nii.m: -------------------------------------------------------------------------------- 1 | function pcvipr2nii(srcDir) 2 | %pcvipr2nii converts PC VIPR 4D flow data (*.dat format) to nifti format 3 | 4 | curDir = pwd; 5 | if nargin<1 || (~ischar(srcDir) && ~isstring(srcDir)) || ~exist(srcDir, 'dir') 6 | srcDir = curDir; 7 | fprintf('Setting source directory to %s\n', srcDir); 8 | end 9 | srcDir = char(srcDir); 10 | 11 | pcVIPRHeaderFilePath = fullfile(srcDir,'pcvipr_header.txt'); 12 | if exist(pcVIPRHeaderFilePath, 
'file')==2 13 | fid = fopen(pcVIPRHeaderFilePath); 14 | if fid<0 15 | error('Could not open pcvipr_header.txt file.'); 16 | else 17 | C = textscan(fid,'%s %s'); 18 | field = C{1}; 19 | value = C{2}; 20 | fclose(fid); 21 | 22 | fov = lookup(field,value,'fovx',3); 23 | 24 | xSize = lookup(field,value,'matrixx',1); 25 | ySize = lookup(field,value,'matrixy',1); 26 | zSize = lookup(field,value,'matrixz',1); 27 | nT = lookup(field,value,'frames',1); 28 | dT = lookup(field,value,'timeres',1); 29 | 30 | p = lookup(field,value,'sx',3)'; 31 | R = reshape(lookup(field,value,'ix',9),[3,3])'; 32 | spacing = sqrt(diag(R*R')); 33 | 34 | cd = zeros(xSize,ySize,zSize,nT, 'int16'); 35 | mag = zeros(xSize,ySize,zSize,nT, 'int16'); 36 | velX = zeros(xSize,ySize,zSize,nT, 'int16'); 37 | velY = zeros(xSize,ySize,zSize,nT, 'int16'); 38 | velZ = zeros(xSize,ySize,zSize,nT, 'int16'); 39 | avgCd = zeros(xSize,ySize,zSize, 'int16'); %#ok 40 | avgMag = zeros(xSize,ySize,zSize, 'int16'); %#ok 41 | avgVelX = zeros(xSize,ySize,zSize, 'int16'); %#ok 42 | avgVelY = zeros(xSize,ySize,zSize, 'int16'); %#ok 43 | avgVelZ = zeros(xSize,ySize,zSize, 'int16'); %#ok 44 | 45 | name = fullfile(srcDir,'CD.dat'); 46 | if ~exist(name, 'file') 47 | error('Could not find CD.dat file.'); 48 | end 49 | m = memmapfile(name,'Format','int16'); 50 | avgCd = reshape(m.Data,[xSize,ySize,zSize]); 51 | 52 | name = fullfile(srcDir,'MAG.dat'); 53 | if ~exist(name, 'file') 54 | error('Could not find MAG.dat file.'); 55 | end 56 | m = memmapfile(name,'Format','int16'); 57 | avgMag = reshape(m.Data,[xSize,ySize,zSize]); 58 | 59 | name = fullfile(srcDir,'comp_vd_1.dat'); 60 | if ~exist(name, 'file') 61 | error('Could not find comp_vd_1.dat file.'); 62 | end 63 | m = memmapfile(name,'Format','int16'); 64 | avgVelX = reshape(m.Data,[xSize,ySize,zSize]); 65 | 66 | name = fullfile(srcDir,'comp_vd_2.dat'); 67 | if ~exist(name, 'file') 68 | error('Could not find comp_vd_2.dat file.'); 69 | end 70 | m = 
memmapfile(name,'Format','int16'); 71 | avgVelY = reshape(m.Data,[xSize,ySize,zSize]); 72 | 73 | name = fullfile(srcDir,'comp_vd_3.dat'); 74 | if ~exist(name, 'file') 75 | error('Could not find comp_vd_3.dat file.'); 76 | end 77 | m = memmapfile(name,'Format','int16'); 78 | avgVelZ = reshape(m.Data,[xSize,ySize,zSize]); 79 | 80 | 81 | if nT==1 82 | mag = avgMag; cd = avgCd; velX = avgVelX; velY = avgVelY; velZ = avgVelZ; 83 | else 84 | for ii=1:nT 85 | name = fullfile(srcDir,sprintf('ph_%03d_cd.dat',ii-1)); 86 | if ~exist(name, 'file') 87 | error('Could not find ph_%03d_cd.dat file.',ii-1); 88 | end 89 | m = memmapfile(name,'Format','int16'); 90 | cd(:,:,:,ii) = reshape(m.Data,[xSize,ySize,zSize]); 91 | 92 | name = fullfile(srcDir,sprintf('ph_%03d_mag.dat',ii-1)); 93 | if ~exist(name, 'file') 94 | error('Could not find ph_%03d_mag.dat file.',ii-1); 95 | end 96 | m = memmapfile(name,'Format','int16'); 97 | mag(:,:,:,ii) = reshape(m.Data,[xSize,ySize,zSize]); 98 | 99 | name = fullfile(srcDir,sprintf('ph_%03d_vd_1.dat',ii-1)); 100 | if ~exist(name, 'file') 101 | error('Could not find file "ph_%03d_vd_1.dat".',ii-1); 102 | end 103 | m = memmapfile(name,'Format','int16'); 104 | velX(:,:,:,ii) = reshape(m.Data,[xSize,ySize,zSize]); 105 | 106 | name = fullfile(srcDir,sprintf('ph_%03d_vd_2.dat',ii-1)); 107 | 108 | if ~exist(name, 'file') 109 | error('Could not find file "ph_%03d_vd_2.dat".',ii-1); 110 | end 111 | m = memmapfile(name,'Format','int16'); 112 | velY(:,:,:,ii) = reshape(m.Data,[xSize,ySize,zSize]); % memmapfile property is 'Data' (property names are case-sensitive) 113 | 114 | name = fullfile(srcDir,sprintf('ph_%03d_vd_3.dat',ii-1)); 115 | if ~exist(name, 'file') 116 | error('Could not find file "ph_%03d_vd_3.dat".',ii-1); 117 | end 118 | m = memmapfile(name,'Format','int16'); 119 | velZ(:,:,:,ii) = reshape(m.Data,[xSize,ySize,zSize]); 120 | end 121 | end 122 | end 123 | else 124 | error('Could not find pcvipr_header.txt file.'); 125 | end 126 | 127 | info.Filename = ''; 128 | info.Filemoddate = ''; 129 | info.Filesize = 0; 130
| info.Description = ''; 131 | info.Datatype = 'int16'; 132 | info.BitsPerPixel = 16; 133 | info.SpaceUnits = 'Millimeter'; 134 | info.AdditiveOffset = 0; 135 | info.MultiplicativeScaling = 0; 136 | info.TimeOffset = 0; 137 | info.SliceCode = 'Unknown'; 138 | info.FrequencyDimension = 0; 139 | info.PhaseDimension = 0; 140 | info.SpatialDimension = 0; 141 | info.DisplayIntensityRange = [0 0]; 142 | info.TransformName = 'Sform'; 143 | info.Qfactor = 1; 144 | R = R.*repmat([-1,-1,1],3,1); 145 | p = p.*[-1,-1,1]; 146 | % if det(R)>0 147 | info.Qfactor = 1; 148 | % else 149 | % info.Qfactor = -1; 150 | % R(:,3) = -R(:,3); 151 | % end 152 | info.Transform = affine3d([[R;p],[0;0;0;1]]); 153 | info.AuxiliaryFile = 'none'; 154 | 155 | if nT==1 156 | info.ImageSize = [xSize, ySize, zSize]; 157 | info.PixelDimensions = spacing'; 158 | info.TimeUnits = 'None'; 159 | niftiwrite(mag, fullfile(srcDir,"MAG.nii"), info); 160 | niftiwrite(cd, fullfile(srcDir,"CD.nii"), info); 161 | niftiwrite(velX, fullfile(srcDir,"VELX.nii"), info); 162 | niftiwrite(velY, fullfile(srcDir,"VELY.nii"), info); 163 | niftiwrite(velZ, fullfile(srcDir,"VELZ.nii"), info); 164 | 165 | elseif nT>1 166 | info.ImageSize = [xSize, ySize, zSize, nT]; 167 | info.PixelDimensions = [spacing', dT]; 168 | info.TimeUnits = 'Millisecond'; 169 | niftiwrite(mag, fullfile(srcDir,"MAG.nii"), info); 170 | niftiwrite(cd, fullfile(srcDir,"CD.nii"), info); % same output name as the single-frame branch 171 | niftiwrite(velX, fullfile(srcDir,"VELX.nii"), info); 172 | niftiwrite(velY, fullfile(srcDir,"VELY.nii"), info); 173 | niftiwrite(velZ, fullfile(srcDir,"VELZ.nii"), info); 174 | 175 | info.ImageSize = info.ImageSize(1:3); 176 | info.PixelDimensions = info.PixelDimensions(1:3); 177 | info.raw.dim(1) = 3; 178 | info.raw.dim(5) = 1; 179 | info.TimeUnits = 'None'; 180 | niftiwrite(avgMag, fullfile(srcDir,"AVG_MAG.nii"), info); 181 | niftiwrite(avgCd, fullfile(srcDir,"AVG_CD.nii"), info); 182 | niftiwrite(avgVelX, fullfile(srcDir,"AVG_VELX.nii"), info); 183 |
niftiwrite(avgVelY, fullfile(srcDir,"AVG_VELY.nii"), info); 184 | niftiwrite(avgVelZ, fullfile(srcDir,"AVG_VELZ.nii"), info); 185 | end 186 | end 187 | 188 | function value = lookup(fields,values,field, length) 189 | index = find(cellfun(@(s) strcmp(field, s), fields)); 190 | value = cellfun(@str2num,values(index:(index+length-1))); 191 | end -------------------------------------------------------------------------------- /matlab_scripts/resizeNii.m: -------------------------------------------------------------------------------- 1 | function resizeNii(fileName, factor, method) 2 | % resizeNii resizes a nifti image in x- and y-directions by a set scale 3 | % factor to achieve a desired resolution. 4 | if nargin<3; method='bilinear'; end 5 | img = niftiread(fileName); 6 | if ~isscalar(factor); error('resizeNii:factor','factor must be a scalar scale factor.'); end % previously an unterminated if silently skipped the whole body for non-scalar input 7 | info = niftiinfo(fileName); 8 | for z=1:size(img,3) 9 | for t = 1:size(img,4) 10 | img2(:,:,z,t) = imresize(img(:,:,z,t), factor, method); %#ok 11 | end 12 | end 13 | info.ImageSize(1:2) = [size(img2,1), size(img2,2)]; % take the size from the resized data so header and image always agree 14 | info.PixelDimensions(1:2) = info.PixelDimensions(1:2)./factor; 15 | info.Transform.T(1:2,:) = info.Transform.T(1:2,:)./factor; 16 | niftiwrite(img2,fileName,info); 17 | end 18 | 19 | -------------------------------------------------------------------------------- /matlab_scripts/segment2nii.m: -------------------------------------------------------------------------------- 1 | function segment2nii(matFile) 2 | % segment2nii converts a saved session file from Medviso Segment software 3 | % into nifti files for the short-axis image stack and for the LV/RV 4 | % segmentation.
5 | 6 | if nargin<1 || (~ischar(matFile) && ~isstring(matFile)) || ~exist(matFile, 'file') 7 | error("Segment .mat file not found.\n Usage: segment2nii('segment_file.mat');"); 8 | end 9 | [path,~,~] = fileparts(matFile); 10 | 11 | segmentData = load(matFile); 12 | for ii=1:numel(segmentData.setstruct) 13 | if strcmp(segmentData.setstruct(ii).ImageViewPlane, 'Short-axis') && segmentData.setstruct(ii).ZSize>1 14 | dat = segmentData.setstruct(ii); 15 | end 16 | end 17 | if ~exist('dat','var'); error('Could not find short-axis dataset.'); end 18 | 19 | endoX = dat.EndoX; 20 | endoY = dat.EndoY; 21 | rvEndoX = dat.RVEndoX; 22 | rvEndoY = dat.RVEndoY; 23 | 24 | p = dat.ImagePosition; 25 | R = dat.ImageOrientation; 26 | R(7:9) = -cross(R(1:3),R(4:6)); 27 | R = reshape(R, [3,3]); 28 | spacing = [dat.ResolutionX,dat.ResolutionY,dat.SliceThickness+dat.SliceGap]; 29 | img = permute(dat.IM,[2,1,4,3]); 30 | img = int16(img.*(2^15)./max(img(:))); 31 | xSize = dat.XSize; 32 | ySize = dat.YSize; 33 | zSize = dat.ZSize; 34 | nT = dat.TSize; 35 | dT = dat.TIncr*1000; 36 | 37 | info.Filename = ''; 38 | info.Filemoddate = ''; 39 | info.Filesize = 0; 40 | info.Description = ''; 41 | info.Datatype = 'int16'; 42 | info.BitsPerPixel = 16; 43 | info.SpaceUnits = 'Millimeter'; 44 | info.AdditiveOffset = 0; 45 | info.MultiplicativeScaling = 0; 46 | info.TimeOffset = 0; 47 | info.SliceCode = 'Unknown'; 48 | info.FrequencyDimension = 0; 49 | info.PhaseDimension = 0; 50 | info.SpatialDimension = 0; 51 | info.DisplayIntensityRange = [0 0]; 52 | info.TransformName = 'Sform'; 53 | info.Qfactor = 1; 54 | R = R'*diag([-1,-1,1]); 55 | p = p.*[-1,-1,1]; 56 | 57 | info.Transform = affine3d([[diag(spacing)*R;p],[0;0;0;1]]); 58 | info.AuxiliaryFile = 'none'; 59 | 60 | info.ImageSize = [ySize, xSize, zSize, nT]; 61 | info.PixelDimensions = [spacing, dT]; 62 | info.TimeUnits = 'Millisecond'; 63 | 64 | [y,x,z,t] = ndgrid(1:ySize,1:xSize,1:zSize,1:nT); 65 | index = 
findIndexInBoundary(x(:),y(:),z(:),t(:),endoX,endoY); 66 | indexRV = findIndexInBoundary(x(:),y(:),z(:),t(:),rvEndoX,rvEndoY); 67 | 68 | [zBase,zMid,zApex] = divideSlices(t(index),z(index), info.Transform.T(3,2)); 69 | segImg = int16(zeros(ySize,xSize,zSize,nT)); 70 | 71 | 72 | for tt = 1:numel(zBase) 73 | ind = index(t(index)==tt); 74 | segImg(ind(ismember(z(ind),zBase{tt})))=int16(1); 75 | segImg(ind(ismember(z(ind),zMid{tt})))=int16(2); 76 | segImg(ind(ismember(z(ind),zApex{tt})))=int16(3); 77 | indRV = indexRV(t(indexRV)==tt); 78 | segImg(indRV(ismember(z(indRV),zBase{tt})))=int16(-1); 79 | segImg(indRV(ismember(z(indRV),zMid{tt})))=int16(-2); 80 | segImg(indRV(ismember(z(indRV),zApex{tt})))=int16(-3); 81 | end 82 | 83 | niftiwrite(img, fullfile(path,'saSegment.nii'), info); 84 | niftiwrite(segImg, fullfile(path,'saSegmentMask.nii'), info); 85 | 86 | info.ImageSize = info.ImageSize(1:3); 87 | info.PixelDimensions = info.PixelDimensions(1:3); 88 | info.TimeUnits = 'None'; 89 | niftiwrite(int16(mean(img,4)), fullfile(path,'saSegmentAvg.nii'), info); 90 | end 91 | 92 | function index = findIndexInBoundary(x,y,z,t,bX,bY) 93 | index = []; 94 | for tInd = 1:size(bX,2) 95 | for zInd = 1:size(bX,3) 96 | if all(isnan(bX(:,tInd,zInd))) 97 | continue; 98 | end 99 | ind = find((round(t)==tInd).*round(z)==zInd); 100 | subInd = inpolygon(x(ind),y(ind),bX(:,tInd,zInd),bY(:,tInd,zInd)); 101 | if ~isempty(ind) 102 | index = [index;ind(subInd)]; %#ok 103 | end 104 | end 105 | end 106 | end 107 | function [zBase,zMid,zApex] = divideSlices(t,z,ky) 108 | a = unique([t,z],'rows'); 109 | for tt=min(a(:,1)):max(a(:,1)) 110 | ind = find(a(:,1)==tt); 111 | zBase{tt} = min(a(ind,2)):round(min(a(ind,2)) + (max(a(ind,2))-min(a(ind,2))-2)/3);%#ok 112 | zApex{tt} = round(max(a(ind,2)) - (max(a(ind,2))-min(a(ind,2))-2)/3):max(a(ind,2)); %#ok 113 | zMid{tt} = max(zBase{tt}+1):min(zApex{tt}-1); %#ok 114 | end 115 | if ky<0 116 | temp = zBase; 117 | zBase = zApex; 118 | zApex = temp; 119 | end 
120 | end -------------------------------------------------------------------------------- /model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.index -------------------------------------------------------------------------------- /model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.meta -------------------------------------------------------------------------------- /model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: 
"/export/home/pcorrado/CODE/ukbb_cardiac/model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt" 2 | all_model_checkpoint_paths: "/export/home/pcorrado/CODE/ukbb_cardiac/model/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt" 3 | -------------------------------------------------------------------------------- /modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.index -------------------------------------------------------------------------------- /modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt.meta 
-------------------------------------------------------------------------------- /modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "/export/home/pcorrado/CODE/ukbb_cardiac/modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt" 2 | all_model_checkpoint_paths: "/export/home/pcorrado/CODE/ukbb_cardiac/modelFT/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001/FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001.ckpt" 3 | -------------------------------------------------------------------------------- /registration/register_SA_mask_to_Flow_images: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NUMPARAMS=$# 4 | 5 | if [ $NUMPARAMS -lt 3 ] 6 | then 7 | echo " USAGE :: " 8 | echo " antsaffine_LVFlow AVG_MAG.nii.gz sa_avg.nii.gz seg_sa.nii.gz" 9 | exit 10 | fi 11 | 12 | AVG_MAG=${1} 13 | SA_AVG=${2} 14 | SEG_SA=${3} 15 | 16 | echo "Arg 1: $1" 17 | echo "Arg 2: $2" 18 | echo "Arg 3: $3" 19 | 20 | echo "AVG_MAG: $AVG_MAG" 21 | echo "SA_AVG: $SA_AVG" 22 | echo "SEG_SA: $SEG_SA" 23 | 24 | 25 | antsRegistration --dimensionality 3 --output SA2Flow --transform Rigid[0.01] --interpolation Linear --verbose 1 --convergence [30, 1e-3, 10] --shrink-factors 2 --smoothing-sigmas 2 --metric CC[$AVG_MAG, $SA_AVG] --winsorize-image-intensities [0.01, 0.6] 26 | 27 | 28 | antsApplyTransforms -d 3 -e 0 -i $SA_AVG -t SA2Flow0GenericAffine.mat -n Linear -r $AVG_MAG -o registeredSA.nii.gz -v 1 29 | 30 | 31 | antsApplyTransforms -d 3 -e 3 -i $SEG_SA -t SA2Flow0GenericAffine.mat -n NearestNeighbor -r $AVG_MAG -o registeredMask.nii.gz -v 1 32 | -------------------------------------------------------------------------------- /segmentAll.py: -------------------------------------------------------------------------------- 1 | # Copyright 
2017, Wenjia Bai. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the 'License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an 'AS IS' BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import os 17 | 18 | testDir = '/home/pcorrado/Cardiac-DL-Segmentation-Paper/test' 19 | modelBasePath = '/home/pcorrado/Cardiac-DL-Segmentation-Paper/Cardiac-Segmentation-4D-Flow/TrainedModels' 20 | modelPaths = ['model_{}_layers_frozen'.format(l) for l in [4,8,12,14,15]] 21 | modelPaths.append('modelUnfrozen') 22 | modelName = 'FCN_sa_level5_filter16_22333_batch20_iter10000_lr0.001' 23 | numLayers = [4,8,12,14,15,0] 24 | 25 | if __name__ == '__main__': 26 | 27 | for ii in range(len(modelPaths)): 28 | os.system('python3 common/deploy_network.py --data_dir {0} ' 29 | '--model_path {1}/{2}/{3}/{3}.ckpt-10000'.format(testDir, modelBasePath, modelPaths[ii], modelName)) 30 | for data in sorted(os.listdir(testDir)): 31 | data_dir = os.path.join(testDir, data) 32 | os.system('mv {0}/seg_sa.nii.gz {0}/sa_label_{1}.nii.gz'.format(data_dir, numLayers[ii])) 33 | 34 | 35 | -------------------------------------------------------------------------------- /third_party/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | 3 | # The directory for compiled executable files 4 | SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin) 5 | 6 | # MIRTK 7 | set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} -std=c++11") 8 | set(MIRTK_SOURCE_DIR /vol/biomedic2/wbai/git/MIRTK) 9 | set(MIRTK_BINARY_DIR /vol/biomedic2/wbai/git/MIRTK_bin) 10 | 11 | include_directories(${MIRTK_SOURCE_DIR}/Modules/Common/include) 12 | include_directories(${MIRTK_SOURCE_DIR}/Modules/Image/include) 13 | include_directories(${MIRTK_SOURCE_DIR}/Modules/ImageIO/include) 14 | include_directories(${MIRTK_SOURCE_DIR}/Modules/Numerics/include) 15 | include_directories(${MIRTK_SOURCE_DIR}/Modules/PointSet/include) 16 | include_directories(${MIRTK_SOURCE_DIR}/Modules/Registration/include) 17 | include_directories(${MIRTK_SOURCE_DIR}/Modules/Transformation/include) 18 | include_directories(${MIRTK_BINARY_DIR}/include) 19 | 20 | link_directories(${MIRTK_BINARY_DIR}/lib) 21 | link_libraries(MIRTKCommon MIRTKNumerics MIRTKImage MIRTKIO MIRTKPointSet MIRTKRegistration MIRTKTransformation) 22 | 23 | # VTK 24 | find_package(VTK REQUIRED) 25 | include(${VTK_USE_FILE}) 26 | link_libraries(${VTK_LIBRARIES}) 27 | 28 | # Executables 29 | ADD_EXECUTABLE(average_3d_ffd average_3d_ffd.cc) 30 | -------------------------------------------------------------------------------- /third_party/src/average_3d_ffd.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | using namespace mirtk; 8 | 9 | // ============================================================================= 10 | // Help 11 | // ============================================================================= 12 | 13 | // ----------------------------------------------------------------------------- 14 | // Print help screen 15 | void PrintHelp(const char *name) 16 | { 17 | cout << endl; 18 | cout << "Usage: " << name << " n ... 
" << endl; 19 | cout << endl; 20 | 21 | cout << "Description:" << endl; 22 | cout << " Computes the average T of the given input free-form deformations such that" << endl; 23 | cout << " T_out(x) = T1(x) * w1 + T2(x) * w2 ... + Tn(x) * wn " << endl; 24 | cout << endl; 25 | 26 | cout << " can be one or more of the following:" << endl; 27 | cerr << "<-verbose> Display information." << endl; 28 | exit(1); 29 | } 30 | 31 | // ============================================================================= 32 | // Main 33 | // ============================================================================= 34 | 35 | // ----------------------------------------------------------------------------- 36 | int main(int argc, char **argv) 37 | { 38 | // Parse arguments 39 | char *command = argv[0]; 40 | argc--; argv++; 41 | if (argc < 3) { 42 | PrintHelp(command); 43 | } 44 | 45 | int n_trans = atoi(argv[0]); 46 | argc--; argv++; 47 | char **input_name = new char *[n_trans]; 48 | double *w = new double[n_trans]; 49 | for (int i = 0; i < n_trans; i++) { 50 | input_name[i] = argv[0]; 51 | argc--; argv++; 52 | w[i] = atof(argv[0]); 53 | argc--; argv++; 54 | } 55 | 56 | char *output_name = argv[0]; 57 | argc--; argv++; 58 | 59 | bool verbose = false; 60 | while(argc > 0){ 61 | bool ok = false; 62 | if((ok == false) && (strcmp(argv[0], "-verbose") == 0)){ 63 | argc--; argv++; 64 | verbose = true; 65 | ok = true; 66 | } 67 | if(ok == false){ 68 | cerr << "Can not parse argument " << argv[0] << endl; 69 | PrintHelp(command); 70 | } 71 | } 72 | 73 | // Read the input transformations 74 | MultiLevelTransformation **T = new MultiLevelTransformation *[n_trans]; 75 | for (int i = 0; i < n_trans; i++) { 76 | if (verbose) { 77 | cout << "Reading transformation " << i << " from " << input_name[i] << " (weight = " << w[i] << ") ..." 
<< endl; 78 | } 79 | Transformation *ptr = Transformation::New(input_name[i]); 80 | T[i] = dynamic_cast(ptr); 81 | if (T[i] == NULL) { 82 | cout << "Error: error in reading the transformation file." << endl; 83 | exit(0); 84 | } 85 | if (T[i]->NumberOfLevels() > 1) { 86 | cout << "Error: the transformation has more than one local FFDs." << endl; 87 | exit(0); 88 | } 89 | } 90 | 91 | // Get information from the first ffd 92 | BSplineFreeFormTransformation3D *T0_ffd = dynamic_cast(T[0]->GetLocalTransformation(0)); 93 | ImageAttributes attr = T0_ffd->Attributes(); 94 | int X = attr._x; 95 | int Y = attr._y; 96 | int Z = attr._z; 97 | 98 | // Allocate the output ffd 99 | BSplineFreeFormTransformation3D *out_ffd = new BSplineFreeFormTransformation3D(attr); 100 | 101 | // For each control point 102 | for (int k = 0; k < Z; k++) { 103 | for (int j = 0; j < Y; j++) { 104 | for (int i = 0; i < X; i++) { 105 | double out_dx = 0; 106 | double out_dy = 0; 107 | double out_dz = 0; 108 | 109 | for (int n = 0; n < n_trans; n++) { 110 | double dx, dy, dz; 111 | T[n]->GetLocalTransformation(0)->Get(i, j, k, dx, dy, dz); 112 | out_dx += w[n] * dx; 113 | out_dy += w[n] * dy; 114 | out_dz += w[n] * dz; 115 | } 116 | 117 | out_ffd->Put(i, j, k, out_dx, out_dy, out_dz); 118 | } 119 | } 120 | } 121 | 122 | // The output transformation 123 | MultiLevelFreeFormTransformation *T_out = new MultiLevelFreeFormTransformation; 124 | T_out->PushLocalTransformation(out_ffd); 125 | T_out->Write(output_name); 126 | if (verbose) { 127 | cout << "Writing the average transformation to " << output_name << endl; 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /third_party/ubuntu_16.04_bin/average_3d_ffd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/third_party/ubuntu_16.04_bin/average_3d_ffd 
-------------------------------------------------------------------------------- /third_party/ubuntu_18.04_bin/average_3d_ffd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/third_party/ubuntu_18.04_bin/average_3d_ffd -------------------------------------------------------------------------------- /ukbb_trained_model/FCN_sa.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/ukbb_trained_model/FCN_sa.data-00000-of-00001 -------------------------------------------------------------------------------- /ukbb_trained_model/FCN_sa.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/ukbb_trained_model/FCN_sa.index -------------------------------------------------------------------------------- /ukbb_trained_model/FCN_sa.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcorrado/Cardiac-Segmentation-4D-Flow/23cfa4dcd17dd8195490018879fec104e362c5f5/ukbb_trained_model/FCN_sa.meta --------------------------------------------------------------------------------