├── .gitignore ├── CRL_Fetal_Brain_Atlas_2017v3 └── README ├── LICENSE ├── README.md ├── code ├── config.py ├── data.py ├── main.py ├── models.py ├── trajectory.py └── utils.py ├── img └── stress.gif └── trajectory ├── 1.mat ├── 10.mat ├── 2.mat ├── 3.mat ├── 4.mat ├── 5.mat ├── 6.mat ├── 7.mat ├── 8.mat └── 9.mat /.gitignore: -------------------------------------------------------------------------------- 1 | /CRL_Fetal_Brain_Atlas_2017v3/* 2 | !/CRL_Fetal_Brain_Atlas_2017v3/README 3 | /results/* 4 | /_results/* 5 | .vscode/ 6 | .mypy_cache/ 7 | *.nii 8 | *.nii.gz 9 | *.npy 10 | *.pt 11 | 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ -------------------------------------------------------------------------------- /CRL_Fetal_Brain_Atlas_2017v3/README: -------------------------------------------------------------------------------- 1 | The CRL fetal brain atlas can be downloaded from https://form.jotform.com/91364382958166 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Junshen Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fetalSR 2 | 3 | STRESS: Super-Resolution for Dynamic Fetal MRI using Self-Supervised Learning ([Springer](https://link.springer.com/chapter/10.1007/978-3-030-87234-2_19)|[arXiv](https://arxiv.org/abs/2106.12407)) 4 | 5 |

 6 | [Image: STRESS] 7 |

8 | 9 | ## Usage 10 | 11 | Run ```python main.py``` to train a model and test on the simulated dataset. 12 | 13 | You may create your own dataset following `EPIDataset` in `data.py`. 14 | 15 | ## Cite our work 16 | 17 | ``` 18 | @inproceedings{xu2021stress, 19 | title={STRESS: Super-Resolution for Dynamic Fetal MRI Using Self-supervised Learning}, 20 | author={Xu, Junshen and Abaci Turk, Esra and Grant, P Ellen and Golland, Polina and Adalsteinsson, Elfar}, 21 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 22 | pages={197--206}, 23 | year={2021}, 24 | organization={Springer} 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 4 | # hyperparameters 5 | use_sim = True 6 | is_denoise = False 7 | denoiser = None 8 | 9 | model_name = "sim_s4_k3_n3" 10 | batch_size = 16 if is_denoise else 64 11 | num_iter = 100000 12 | lr = 1e-4 13 | 14 | num_split = 4 15 | use_k = num_split // 2 16 | 17 | sigma = 0.03 18 | -------------------------------------------------------------------------------- /code/data.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage import rotate 2 | import torch.multiprocessing as mp 3 | import numpy as np 4 | import nibabel as nib 5 | import torch 6 | import os 7 | from trajectory import get_trajectory 8 | import traceback 9 | from scipy.spatial.transform import Rotation 10 | from scipy.ndimage import map_coordinates, affine_transform 11 | from scipy.stats import special_ortho_group 12 | from itertools import product 13 | from time import time 14 | from config import sigma as _sigma 15 | 16 | def read_nifti(nii_filename): 17 | data = nib.load(nii_filename) 18 | return 
np.squeeze(data.get_fdata().astype(np.float32)) 19 | 20 | def down_up(*imgs, start=0): 21 | res = [] 22 | num_split = (len(imgs) + 1) // 2 23 | for s, img in enumerate(imgs, start): 24 | ss = s % num_split 25 | X, Y, Z = np.meshgrid((np.arange(img.shape[0]) - ss) / num_split, np.arange(img.shape[1]), np.arange(img.shape[2]), indexing='ij') 26 | res.append(map_coordinates(img[ss::num_split], [X, Y, Z], order=3, mode='nearest')) 27 | return res 28 | 29 | class EPIDataset(torch.utils.data.Dataset): 30 | def __init__(self, num_split, stage, is_denoise=False, denoiser=None): 31 | assert stage in ['train', 'val', 'test'] 32 | self.data_dir = '/home/junshen/new' 33 | self.is_test = stage == 'test' 34 | self.folders = sorted(os.listdir(self.data_dir)) 35 | if stage == 'test': 36 | self.folders = [folder for i, folder in enumerate(self.folders) if i % 6 == 0] 37 | else: 38 | self.folders = [folder for i, folder in enumerate(self.folders) if i % 6 != 0] 39 | self.proc = [] 40 | self.res = [] 41 | num_p = 10 42 | self.queue = mp.Queue(1024) 43 | n_new = 6 44 | self.queue2 = mp.Queue(n_new) 45 | for _ in range(n_new): 46 | self.queue2.put(None) 47 | 48 | if denoiser is not None: 49 | denoise_queue = [{'in':mp.Queue(1),'out':mp.Queue(1)} for _ in range(num_p)] 50 | self.denoiser = mp.Process(target=denoise_fn, args=(denoiser, denoise_queue)) 51 | self.denoiser.daemon = True 52 | self.denoiser.start() 53 | self.denoise_queue = denoise_queue 54 | else: 55 | denoise_queue = [None] * num_p 56 | self.denoise_queue = denoise_queue 57 | 58 | # use multiple processes to fetch data 59 | for i in range(num_p): 60 | proc = mp.Process(target=prefetch_volumes_test if self.is_test else prefetch_volumes, 61 | args=(self.data_dir, self.folders[i::num_p], self.queue, self.queue2, num_split, is_denoise, self.denoise_queue[i])) 62 | proc.daemon = True 63 | proc.start() 64 | self.proc.append(proc) 65 | 66 | def load_data(self): 67 | if len(self.res) == 0: 68 | N = 0 69 | while True: 70 | res = 
self.queue.get() 71 | if res is None: 72 | N += 1 73 | if N == len(self.proc): 74 | break 75 | else: 76 | self.res.append(res) 77 | self.res = sorted(self.res, key=lambda x:x[-1]) 78 | 79 | if self.denoise_queue[0] is not None: 80 | self.denoise_queue[0]['in'].put(None) 81 | 82 | print("test set len: %d" % len(self.res)) 83 | 84 | def __len__(self): 85 | if self.is_test: 86 | self.load_data() 87 | return len(self.res) 88 | else: 89 | return int(1e8) 90 | 91 | def __getitem__(self, idx): 92 | if self.is_test: 93 | self.load_data() 94 | return self.res[idx][:3] 95 | else: 96 | return self.queue.get() 97 | 98 | 99 | def prefetch_volumes(data_dir, folders, queue, q2, num_split, is_denoise, denoiser): 100 | a = 32 101 | volumes = [None] * len(folders) 102 | files = [[]] * len(folders) 103 | starts = [None] * len(folders) 104 | start0s = [None] * len(folders) 105 | for i in range(len(folders)): 106 | files[i] = sorted(os.listdir(os.path.join(data_dir, folders[i]))) 107 | img = read_nifti(os.path.join(data_dir, folders[i], files[i][0])) 108 | err0 = np.mean(((img[:, :, 20] + img[:, :, 22]) / 2 - img[:, :, 21])**2) 109 | err1 = np.mean(((img[:, :, 21] + img[:, :, 23]) / 2 - img[:, :, 22])**2) 110 | start0s[i] = 0 if err0 < err1 else 1 111 | 112 | try: 113 | while(True): 114 | for i in range(len(volumes)): 115 | 116 | new_vol = False 117 | if volumes[i] is not None: 118 | try: 119 | _ = q2.get_nowait() 120 | new_vol = True 121 | except: 122 | pass 123 | 124 | if volumes[i] is None or new_vol: 125 | fid = np.random.choice(np.arange(num_split, len(files[i])-num_split)) 126 | 127 | angle = np.random.uniform(360) 128 | hrs = [] 129 | 130 | for dt in range(-num_split+1, num_split): 131 | t = fid + dt 132 | 133 | img = read_nifti(os.path.join(data_dir, folders[i], files[i][t])) 134 | img = (img - 70.0) / 100.0 135 | img = rotate(img, angle, axes=(0, 1), reshape=False) 136 | ss = (start0s[i] + t) % num_split 137 | 138 | if denoiser is not None: 139 | d = 128-img.shape[0] 140 | d1 
= d//2 141 | d2 = d - d1 142 | if d1 >= 0: 143 | frames = np.pad(img[..., ss::num_split], [(d1, d2),(d1, d2), (0, 0)], mode='constant') 144 | else: 145 | frames = img[-d1:d2, -d1:d2, ss::num_split] 146 | frames = torch.tensor(frames[None]).permute(3, 0, 1, 2) 147 | for n_slice in range(0, frames.shape[0], 16): 148 | denoiser['in'].put(frames[n_slice:n_slice+16]) 149 | frames[n_slice:n_slice+16] = denoiser['out'].get() 150 | if d1 >= 0: 151 | frames = frames.squeeze().permute(1,2,0)[d1:-d2,d1:-d2].numpy() 152 | else: 153 | frames = np.pad(frames.squeeze().permute(1,2,0).numpy(), [(-d1, -d2),(-d1, -d2), (0, 0)], mode='constant') 154 | img[..., ss::num_split] = frames 155 | 156 | if is_denoise: 157 | img = img[..., ss::num_split] 158 | else: 159 | X, Y, Z = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), (np.arange(img.shape[2]) - ss) / num_split, indexing='ij') 160 | img = map_coordinates(img[..., ss::num_split], [X, Y, Z], order=3, mode='nearest') 161 | 162 | hrs.append(img) 163 | 164 | if is_denoise: 165 | hrs = np.concatenate(hrs, -1)[None] 166 | volumes[i] = (hrs, hrs) 167 | else: 168 | volumes[i] = (np.stack(down_up(*hrs), 0), np.stack(hrs, 0)) 169 | starts[i] = (start0s[i] + fid) % num_split 170 | 171 | if new_vol: 172 | q2.put(None) 173 | 174 | lr, hr = volumes[i] 175 | if is_denoise: 176 | z = np.random.randint(lr.shape[3]) 177 | hr = hr[:, :, :, z] 178 | d = 128-hr.shape[1] 179 | d1 = d//2 180 | d2 = d - d1 181 | if d1 >= 0: 182 | hr = np.pad(hr, [(0,0),(d1, d2),(d1, d2)], mode='constant') 183 | else: 184 | hr = hr[:, -d1:d2, -d1:d2] 185 | lr = hr 186 | else: 187 | y = np.random.randint(lr.shape[1] - a) 188 | x = np.random.randint(lr.shape[2] - a) 189 | z = np.random.randint((lr.shape[3] - starts[i]) // num_split) 190 | lr = lr[:, y:y+a, x:x+a, starts[i] + z * num_split] 191 | hr = hr[:, y:y+a, x:x+a, starts[i] + z * num_split] 192 | axis = np.random.choice([None, 1, 2]) 193 | if axis is not None: 194 | lr = np.flip(lr, axis).copy() 195 | hr = 
np.flip(hr, axis).copy() 196 | lr = torch.tensor(lr, dtype=torch.float32) 197 | hr = torch.tensor(hr, dtype=torch.float32) 198 | queue.put((lr, hr)) 199 | except: 200 | traceback.print_exc() 201 | print("error: %s" % mp.current_process().name) 202 | 203 | def prefetch_volumes_test(data_dir, folders, queue, q2, num_split, is_denoise, denoiser): 204 | try: 205 | for folder in folders: 206 | files = sorted(os.listdir(os.path.join(data_dir, folder))) 207 | img = read_nifti(os.path.join(data_dir, folder, files[0])) 208 | err0 = np.mean(((img[:, :, 20] + img[:, :, 22]) / 2 - img[:, :, 21])**2) 209 | err1 = np.mean(((img[:, :, 21] + img[:, :, 23]) / 2 - img[:, :, 22])**2) 210 | start = 0 if err0 < err1 else 1 211 | 212 | for fid in list(range(len(files)))[::(len(files)//7)][1:-1]: 213 | imgs = [] 214 | combined = np.zeros_like(img) 215 | for dt in range(-num_split+1, num_split): 216 | t = fid + dt 217 | img = read_nifti(os.path.join(data_dir, folder, files[t])) 218 | img = (img - 70.0) / 100.0 219 | ss = (start + t) % num_split 220 | 221 | combined[..., ss::num_split] += img[..., ss::num_split] * (num_split - np.abs(dt)) / num_split 222 | 223 | if dt == 0: 224 | if num_split == 4: 225 | start_gt = (start+fid+2) % num_split 226 | gt = 0 * img - 1000 227 | gt[..., start_gt::num_split] = img[..., start_gt::num_split] 228 | else: 229 | gt = 0 230 | 231 | if denoiser is not None: 232 | d = 128-img.shape[0] 233 | d1 = d//2 234 | d2 = d - d1 235 | if d1 >= 0: 236 | frames = np.pad(img[..., ss::num_split], [(d1, d2),(d1, d2), (0, 0)], mode='constant') 237 | else: 238 | frames = img[-d1:d2, -d1:d2, ss::num_split] 239 | frames = torch.tensor(frames[None]).permute(3, 0, 1, 2) 240 | for n_slice in range(0, frames.shape[0], 16): 241 | denoiser['in'].put(frames[n_slice:n_slice+16]) 242 | frames[n_slice:n_slice+16] = denoiser['out'].get() 243 | if d1 >= 0: 244 | frames = frames.squeeze().permute(1,2,0)[d1:-d2,d1:-d2].numpy() 245 | else: 246 | frames = 
np.pad(frames.squeeze().permute(1,2,0).numpy(), [(-d1, -d2),(-d1, -d2), (0, 0)], mode='constant') 247 | img[..., ss::num_split] = frames 248 | 249 | if is_denoise: 250 | imgs.append(img[..., ss::num_split]) 251 | else: 252 | X, Y, Z = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), (np.arange(img.shape[2]) - ss) / num_split, indexing='ij') 253 | imgs.append(map_coordinates(img[..., ss::num_split], [X, Y, Z], order=3, mode='nearest')) 254 | 255 | if is_denoise: 256 | imgs = np.concatenate(imgs, -1) 257 | d = 128-img.shape[0] 258 | d1 = d//2 259 | d2 = d - d1 260 | if d1 >= 0: 261 | imgs = np.pad(imgs, [(d1, d2),(d1, d2), (0, 0)], mode='constant') 262 | else: 263 | imgs = imgs[-d1:d2, -d1:d2] 264 | else: 265 | imgs = np.stack(imgs, 0) 266 | queue.put((imgs, gt, combined, (start+fid) % num_split, os.path.join(folder, files[fid]))) 267 | except: 268 | traceback.print_exc() 269 | print("test error: %s" % mp.current_process().name) 270 | queue.put(None) 271 | return 272 | 273 | 274 | class SimDataset(torch.utils.data.Dataset): 275 | def __init__(self, num_split, stage, is_denoise=False, denoiser=None): 276 | assert stage in ['test', 'train'] 277 | self.is_test = stage == 'test' 278 | test_ga = ['25', '28', '30', '33', '35'] 279 | data_dir = '/home/junshen/fetalSR/CRL_Fetal_Brain_Atlas_2017v3/' 280 | files = [f for f in os.listdir(data_dir) if ('STA' in f) and ('_' not in f)] 281 | if self.is_test: 282 | files = [f for f in files if any(ga in f for ga in test_ga)] 283 | else: 284 | files = [f for f in files if all(ga not in f for ga in test_ga)] 285 | files = [os.path.join(data_dir, f) for f in files] 286 | trajs = get_trajectory() 287 | self.proc = [] 288 | 289 | num_p = 10 290 | self.queue = mp.Queue(1024) 291 | 292 | # use multiple processes to fetch data 293 | if self.is_test: 294 | imgs = [nib.load(f).get_fdata().astype(np.float32) / 1000.0 for f in files] 295 | # trajs = [trajs[t][0] for t in [0, 5, 10, 15, 20, 25, 30]] 296 | trajs = [trajs[t][0] for t 
in range(len(trajs))] 297 | imgs_trajs = list(product(imgs, trajs)) 298 | self.length = len(imgs_trajs) 299 | self.res = [] 300 | else: 301 | n_new = 6 302 | self.queue2 = mp.Queue(n_new) 303 | for _ in range(n_new): 304 | self.queue2.put(None) 305 | self.length = int(1e8) 306 | 307 | if denoiser is not None: 308 | denoise_queue = [{'in':mp.Queue(1),'out':mp.Queue(1)} for _ in range(num_p)] 309 | self.denoiser = mp.Process(target=denoise_fn, args=(denoiser, denoise_queue)) 310 | self.denoiser.daemon = True 311 | self.denoiser.start() 312 | self.denoise_queue = denoise_queue 313 | else: 314 | denoise_queue = [None] * num_p 315 | self.denoise_queue = denoise_queue 316 | 317 | for i in range(num_p): 318 | if self.is_test: 319 | proc = mp.Process(target=prefetch_sim_volumes_test, args=(imgs_trajs[i::num_p], self.queue, num_split, is_denoise, denoise_queue[i])) 320 | else: 321 | proc = mp.Process(target=prefetch_sim_volumes, args=(files[i::num_p], self.queue, trajs, self.queue2, num_split, is_denoise, denoise_queue[i])) 322 | proc.daemon = True 323 | proc.start() 324 | self.proc.append(proc) 325 | 326 | def __len__(self): 327 | return self.length 328 | 329 | def __getitem__(self, idx): 330 | if self.is_test: 331 | if len(self.res) == 0: 332 | for i in range(self.length): 333 | self.res.append(self.queue.get()) 334 | #print(i, self.length) 335 | if self.denoise_queue[0] is not None: 336 | self.denoise_queue[0]['in'].put(None) 337 | return self.res[idx] 338 | else: 339 | return self.queue.get() 340 | 341 | def prefetch_sim_volumes(files, queue, trajs, q2, num_split, is_denoise, denoiser): 342 | #denoiser = denoiser.cuda() 343 | a = 64 # 64 (0.031) 344 | starts = [None] * len(files) 345 | volumes = [None] * len(files) 346 | imgs = [nib.load(f).get_fdata().astype(np.float32) / 1000.0 for f in files] 347 | try: 348 | while True: 349 | for j in range(len(imgs)): 350 | new_vol = False 351 | if volumes[j] is not None: 352 | try: 353 | _ = q2.get_nowait() 354 | new_vol = True 
355 | except: 356 | pass 357 | if volumes[j] is None or new_vol: 358 | 359 | start = np.random.choice(num_split) 360 | traj, T = trajs[np.random.choice(len(trajs))] 361 | t0 = np.random.uniform(0, T) 362 | hr, gt, combined, starts[j] = sim_scan(imgs[j], num_split, traj, t0, 1.0 / imgs[j].shape[-1], start, np.eye(3,3), is_denoise, denoiser) 363 | 364 | lr = down_up(*hr) 365 | volumes[j] = (np.stack(lr, 0), np.stack(hr, 0)) 366 | if new_vol: 367 | q2.put(None) 368 | 369 | lr, hr = volumes[j] 370 | while True: 371 | y = np.random.randint(lr.shape[1] - a) 372 | x = np.random.randint(lr.shape[2] - a) 373 | z = np.random.randint((lr.shape[3] - starts[j]) // num_split) 374 | if is_denoise: 375 | lr_ = lr[lr.shape[0]//2:lr.shape[0]//2+1, :, :, starts[j] + z * num_split] 376 | hr_ = hr[lr.shape[0]//2:lr.shape[0]//2+1, :, :, starts[j] + z * num_split] 377 | hr_ = np.pad(hr_, [(0,0),(28, 29),(1, 2)], mode='constant') 378 | else: 379 | lr_ = lr[:, y:y+a, x:x+a, starts[j] + z * num_split] 380 | hr_ = hr[:, y:y+a, x:x+a, starts[j] + z * num_split] 381 | if np.max(hr_) > 1: 382 | lr, hr = lr_, hr_ 383 | break 384 | axis = np.random.choice([None, 1, 2]) 385 | if axis is not None: 386 | lr = np.flip(lr, axis).copy() 387 | hr = np.flip(hr, axis).copy() 388 | lr = torch.tensor(lr, dtype=torch.float32) 389 | hr = torch.tensor(hr, dtype=torch.float32) 390 | queue.put((lr, hr)) 391 | except: 392 | traceback.print_exc() 393 | print("error: %s" % mp.current_process().name) 394 | 395 | def prefetch_sim_volumes_test(imgs_trajs, queue, num_split, is_denoise, denoiser): 396 | #denoiser = denoiser.cuda() 397 | try: 398 | for img, traj in imgs_trajs: 399 | t0 = 9 400 | inputs, gt, combined, start = sim_scan(img, num_split, traj, t0, 1.0 / img.shape[-1], 0, np.eye(3,3), is_denoise, denoiser) 401 | if is_denoise: 402 | gt = np.stack(gt, -1) 403 | inputs = inputs[len(inputs)//2][..., start::num_split] 404 | gt = np.pad(gt, [(28, 29),(1, 2),(0,0)], mode='constant') 405 | inputs = np.pad(inputs, 
[(28, 29),(1, 2),(0,0)], mode='constant') 406 | else: 407 | inputs = np.stack(inputs, 0) 408 | queue.put((inputs, gt, combined)) 409 | except: 410 | traceback.print_exc() 411 | print("test error: %s" % mp.current_process().name) 412 | return 413 | 414 | def sim_scan(img, num_split, traj, t0, dt, start, rot0, is_denoise=False, model=None): 415 | #model = model.cuda 416 | t0 = t0 - dt * (img.shape[2] - img.shape[2] / 2 / num_split) 417 | idx = start 418 | i = 0 419 | gt = [] 420 | all_frames = [] 421 | combined = [0] * img.shape[2] 422 | frames = [] 423 | sigma = img.max() * _sigma 424 | while True: 425 | if idx >= img.shape[2]: 426 | if model is not None: 427 | frames = np.pad(np.stack(frames, 0), [(0,0), (28, 29),(1, 2)], mode='constant') 428 | frames = torch.tensor(np.stack(frames, 0)[:, None]) 429 | for n_slice in range(0, frames.shape[0], 16): 430 | model['in'].put(frames[n_slice:n_slice+16]) 431 | frames[n_slice:n_slice+16] = model['out'].get() 432 | frames = frames.squeeze().permute(1,2,0)[28:-29,1:-2].numpy() 433 | else: 434 | frames = np.stack(frames, -1) 435 | X, Y, Z = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), (np.arange(img.shape[2]) - start) / num_split, indexing='ij') 436 | all_frames.append(map_coordinates(frames, [X, Y, Z], order=3, mode='nearest')) 437 | if len(all_frames) == 2 * num_split - 1: 438 | return all_frames, gt, np.stack(combined, -1), (start + 1) % num_split 439 | start = (start + 1) % num_split 440 | idx = start 441 | frames = [] 442 | 443 | Rt = traj(t0 + i * dt) 444 | R = Rotation.from_euler('xyz', Rt[:3]).as_matrix() @ rot0 445 | t = Rt[3:] - [img.shape[0]/2, img.shape[1]/2, img.shape[2]/2] @ R.T + [img.shape[0]/2, img.shape[1]/2, img.shape[2]/2] 446 | frame = affine_transform(img, R, t, order=1) 447 | if (len(all_frames) == num_split - 1): 448 | if is_denoise: 449 | gt.append(frame[..., idx]) 450 | elif ((idx - start) // num_split == img.shape[2] // num_split // 2): 451 | gt = frame 452 | frames.append(frame[..., 
idx]) 453 | if sigma > 0: 454 | noise1 = np.random.normal(scale=sigma, size=frames[-1].shape).astype(np.float32) 455 | noise2 = np.random.normal(scale=sigma, size=frames[-1].shape).astype(np.float32) 456 | frames[-1] = np.sqrt((frames[-1] + noise1)**2 + noise2**2) 457 | #if num_split - 1 - num_split//2 <= len(all_frames) < 2*num_split - 1 - num_split//2: 458 | # combined[idx] = frames[-1] 459 | combined[idx] += frames[-1] * (num_split - np.abs(len(all_frames) - num_split + 1)) / num_split 460 | idx += num_split 461 | i += 1 462 | 463 | def denoise_fn(model, queues): 464 | model = model.cuda() 465 | while True: 466 | for q in queues: 467 | try: 468 | inputs = q['in'].get_nowait() 469 | if inputs is None: 470 | return 471 | with torch.no_grad(): 472 | q['out'].put(model(inputs.cuda()).cpu()) 473 | except: 474 | pass 475 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | from data import EPIDataset, SimDataset 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from models import EDSR, NoiseNetwork 7 | import nibabel as nib 8 | import numpy as np 9 | from time import time 10 | from utils import MovingAverage, rician_correct, mkdir, psnr, ssim 11 | import torch.multiprocessing as mp 12 | 13 | if __name__ == "__main__": 14 | # mp.set_start_method('spawn', force=True) 15 | 16 | assert 0 <= use_k < num_split and num_split % 2 == 0 17 | mkdir("../results/" + model_name + "/outputs") 18 | 19 | # model 20 | if is_denoise: 21 | model = NoiseNetwork(in_channels=1, out_channels=1, blindspot=True).cuda() 22 | else: 23 | model = EDSR(cin=2 * use_k + 1, n_resblocks=16, n_feats=64, res_scale=1).cuda() 24 | 25 | #model.load_state_dict( 26 | # torch.load("../results/" + model_name + "/" + model_name + ".pt") 27 | #) 28 | 29 | if denoiser is not None: 30 | denoiser_name = denoiser 31 | 
denoiser = NoiseNetwork(in_channels=1, out_channels=1, blindspot=True) 32 | denoiser.load_state_dict(torch.load(denoiser_name)) 33 | 34 | # dataset 35 | Dataset = SimDataset if use_sim else EPIDataset 36 | train_dataset = Dataset(num_split, "train", is_denoise, denoiser) 37 | train_dataloader = DataLoader( 38 | train_dataset, batch_size, shuffle=False, pin_memory=True 39 | ) 40 | dataiter = iter(train_dataloader) 41 | test_dataset = Dataset(num_split, "test", is_denoise, denoiser) 42 | test_dataloader = DataLoader(test_dataset, 1, shuffle=False, pin_memory=True) 43 | 44 | # optimizer 45 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 46 | 47 | average = MovingAverage(0.999) 48 | 49 | t_start = time() 50 | for i in range(1, num_iter + 1): 51 | 52 | lr, hr = next(dataiter) 53 | lr = lr.cuda() 54 | hr = hr.cuda() 55 | 56 | if not is_denoise: 57 | out = model(lr[:, num_split - 1 - use_k : num_split + use_k]) 58 | loss = torch.mean(torch.abs(hr[:, num_split - 1 : num_split] - out)) 59 | loss_b = torch.mean( 60 | torch.abs( 61 | hr[:, num_split - 1 : num_split] - lr[:, num_split - 1 : num_split] 62 | ) 63 | ) 64 | average("loss", loss.item()) 65 | average("loss_b", loss_b.item()) 66 | else: 67 | out = model(hr) 68 | loss = torch.mean((hr - out) ** 2) 69 | average("loss", loss.item()) 70 | 71 | optimizer.zero_grad() 72 | loss.backward() 73 | optimizer.step() 74 | 75 | if i % 100 == 0: 76 | print("i = %d, %s, time = %d" % (i, average, time() - t_start)) 77 | 78 | if i % 1000 == 0 or i == num_iter: 79 | torch.save( 80 | model.state_dict(), 81 | "../results/" + model_name + "/" + model_name + ".pt", 82 | ) 83 | 84 | average_test = MovingAverage(0) 85 | 86 | for j, data in enumerate(test_dataloader): 87 | 88 | if is_denoise: 89 | img = data[0].cuda().permute(3, 0, 1, 2) 90 | if use_sim: 91 | gt = data[1].cuda().permute(3, 0, 1, 2) 92 | else: 93 | img = data[0][0].cuda() 94 | img_mid = img[num_split - 1] 95 | img = img.permute(2, 0, 3, 1) 96 | combined = 
data[-1][0].cuda() 97 | gt = data[1][0].cuda() 98 | 99 | with torch.no_grad(): 100 | if is_denoise: 101 | out = model(img) 102 | else: 103 | out = ( 104 | model(img[:, num_split - 1 - use_k : num_split + use_k]) 105 | .squeeze() 106 | .permute(2, 0, 1) 107 | ) 108 | 109 | if is_denoise: 110 | if use_sim: 111 | average_test( 112 | "mse", ((out - gt) ** 2)[gt > 0.01].mean().item() 113 | ) 114 | np.save( 115 | "../results/" + model_name + "/outputs/gt_%d" % j, 116 | gt.cpu().numpy(), 117 | ) 118 | else: 119 | np.save( 120 | "../results/" + model_name + "/outputs/in_%d" % j, 121 | img.cpu().numpy(), 122 | ) 123 | np.save( 124 | "../results/" + model_name + "/outputs/out_%d" % j, 125 | out.cpu().numpy(), 126 | ) 127 | else: 128 | if use_sim: 129 | out = rician_correct( 130 | out, 131 | None if sigma and (denoiser is not None) else 0, 132 | gt < 0.01, 133 | ) 134 | 135 | average_test( 136 | "mse_cubic", 137 | torch.sqrt( 138 | ((gt - img_mid) ** 2).mean() / (gt**2).mean() 139 | ).item(), 140 | ) 141 | average_test( 142 | "mse_out", 143 | torch.sqrt( 144 | ((gt - out) ** 2).mean() / (gt**2).mean() 145 | ).item(), 146 | ) 147 | average_test( 148 | "mse_combined", 149 | torch.sqrt( 150 | ((gt - combined) ** 2).mean() / (gt**2).mean() 151 | ).item(), 152 | ) 153 | average_test( 154 | "mm_cubic", 155 | ((gt - img_mid) ** 2)[gt > 0.01].mean().item(), 156 | ) 157 | average_test( 158 | "mm_out", ((gt - out) ** 2)[gt > 0.01].mean().item() 159 | ) 160 | average_test( 161 | "mm_combined", 162 | ((gt - combined) ** 2)[gt > 0.01].mean().item(), 163 | ) 164 | 165 | nib.save( 166 | nib.Nifti1Image(out.cpu().numpy() * 1000, np.eye(4)), 167 | "../results/" 168 | + model_name 169 | + "/outputs/out_%d.nii.gz" % j, 170 | ) 171 | nib.save( 172 | nib.Nifti1Image(gt.cpu().numpy() * 1000, np.eye(4)), 173 | "../results/" 174 | + model_name 175 | + "/outputs/gt_%d.nii.gz" % j, 176 | ) 177 | nib.save( 178 | nib.Nifti1Image( 179 | img_mid.cpu().numpy() * 1000, np.eye(4) 180 | ), 181 | "../results/" 
182 | + model_name 183 | + "/outputs/in_%d.nii.gz" % j, 184 | ) 185 | nib.save( 186 | nib.Nifti1Image( 187 | combined.cpu().numpy() * 1000, np.eye(4) 188 | ), 189 | "../results/" 190 | + model_name 191 | + "/outputs/combined_%d.nii.gz" % j, 192 | ) 193 | 194 | else: 195 | out = out * 100 + 70 196 | out[out < 0] = 0 197 | img_mid = img_mid * 100 + 70 198 | combined = combined * 100 + 70 199 | gt = gt * 100 + 70 200 | 201 | out = out.cpu().numpy() 202 | img_mid = img_mid.cpu().numpy() 203 | combined = combined.cpu().numpy() 204 | gt = gt.cpu().numpy() 205 | sti = (img_mid + combined) / 2 206 | 207 | if num_split == 4: 208 | mask = gt > 0 209 | average_test("psnr_si", psnr(img_mid, gt, mask)) 210 | average_test("psnr_ti", psnr(combined, gt, mask)) 211 | average_test("psnr_sti", psnr(sti, gt, mask)) 212 | average_test("psnr_out", psnr(out, gt, mask)) 213 | 214 | nib.save( 215 | nib.Nifti1Image(gt, np.eye(4)), 216 | "../results/" 217 | + model_name 218 | + "/outputs/gt_%d.nii.gz" % j, 219 | ) 220 | 221 | nib.save( 222 | nib.Nifti1Image(out, np.eye(4)), 223 | "../results/" 224 | + model_name 225 | + "/outputs/out_%d.nii.gz" % j, 226 | ) 227 | nib.save( 228 | nib.Nifti1Image(combined, np.eye(4)), 229 | "../results/" 230 | + model_name 231 | + "/outputs/combined_%d.nii.gz" % j, 232 | ) 233 | nib.save( 234 | nib.Nifti1Image(img_mid, np.eye(4)), 235 | "../results/" 236 | + model_name 237 | + "/outputs/in_%d.nii.gz" % j, 238 | ) 239 | 240 | print("%d, %s" % (i // 1000, average_test)) 241 | -------------------------------------------------------------------------------- /code/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from typing import Tuple 7 | 8 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 9 | return nn.Conv2d( 10 | in_channels, out_channels, kernel_size, 11 | 
padding=(kernel_size//2), bias=bias) 12 | 13 | 14 | class ResBlock(nn.Module): 15 | def __init__( 16 | self, conv, n_feats, kernel_size, 17 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 18 | 19 | super(ResBlock, self).__init__() 20 | m = [] 21 | for i in range(2): 22 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 23 | if bn: 24 | m.append(nn.BatchNorm2d(n_feats)) 25 | if i == 0: 26 | m.append(act) 27 | 28 | self.body = nn.Sequential(*m) 29 | self.res_scale = res_scale 30 | 31 | def forward(self, x): 32 | 33 | res = x #F.dropout(x, 0.3) 34 | res = self.body(res).mul(self.res_scale) 35 | res += x 36 | 37 | return res 38 | 39 | 40 | class EDSR(nn.Module): 41 | def __init__(self, cin=1, n_resblocks=16, n_feats=64, res_scale=1): 42 | super(EDSR, self).__init__() 43 | 44 | conv = default_conv 45 | kernel_size = 3 46 | act = nn.LeakyReLU(0.1, True) #nn.ReLU(True) 47 | 48 | # define head module 49 | m_head = [conv(cin, n_feats, kernel_size)] 50 | 51 | # define body module 52 | m_body = [ 53 | ResBlock( 54 | conv, n_feats, kernel_size, act=act, res_scale=res_scale 55 | ) for _ in range(n_resblocks) 56 | ] 57 | m_body.append(conv(n_feats, n_feats, kernel_size)) 58 | 59 | # define tail module 60 | m_tail = [ 61 | conv(n_feats, 1, kernel_size) 62 | ] 63 | 64 | self.head = nn.Sequential(*m_head) 65 | self.body = nn.Sequential(*m_body) 66 | self.tail = nn.Sequential(*m_tail) 67 | 68 | def forward(self, x): 69 | x = self.head(x) 70 | 71 | res = self.body(x) 72 | res += x 73 | 74 | x = self.tail(res) 75 | 76 | return x 77 | 78 | 79 | class Crop2d(nn.Module): 80 | """Crop input using slicing. Assumes BCHW data. 81 | Args: 82 | crop (Tuple[int, int, int, int]): Amounts to crop from each side of the image. 
83 | Tuple is treated as [left, right, top, bottom]. 84 | """ 85 | 86 | def __init__(self, crop: Tuple[int, int, int, int]): 87 | super().__init__() 88 | self.crop = crop 89 | assert len(crop) == 4 90 | 91 | def forward(self, x: Tensor) -> Tensor: 92 | (left, right, top, bottom) = self.crop 93 | x0, x1 = left, x.shape[-1] - right 94 | y0, y1 = top, x.shape[-2] - bottom 95 | return x[:, :, y0:y1, x0:x1] 96 | 97 | 98 | class Shift2d(nn.Module): 99 | """Shift an image along either or both of the vertical and horizontal axes by first 100 | zero padding the side opposite to the shift direction and then 101 | cropping the side being shifted towards. 102 | Args: 103 | shift (Tuple[int, int]): Tuple of vertical and horizontal shift. Positive values 104 | shift towards right and bottom, negative values shift towards left and top. 105 | """ 106 | 107 | def __init__(self, shift: Tuple[int, int]): 108 | super().__init__() 109 | self.shift = shift 110 | vert, horz = self.shift 111 | y_a, y_b = abs(vert), 0 112 | x_a, x_b = abs(horz), 0 113 | if vert < 0: 114 | y_a, y_b = y_b, y_a 115 | if horz < 0: 116 | x_a, x_b = x_b, x_a 117 | # Order: Left, Right, Top, Bottom 118 | self.pad = nn.ZeroPad2d((x_a, x_b, y_a, y_b)) 119 | self.crop = Crop2d((x_b, x_a, y_b, y_a)) 120 | self.shift_block = nn.Sequential(self.pad, self.crop) 121 | 122 | def forward(self, x: Tensor) -> Tensor: 123 | return self.shift_block(x) 124 | 125 | 126 | def rotate(x: torch.Tensor, angle: int) -> torch.Tensor: 127 | """Rotate images by multiples of 90 degrees clockwise. Assumes BCHW data. 128 | Args: 129 | x (Tensor): Batch of images in BCHW layout (height is dim 2, width is dim 3). 130 | angle (int): Clockwise rotation angle in multiples of 90. 131 | Must be one of 0, 90, 180 or 270; other values raise 132 | NotImplementedError. 133 | Returns: 134 | Tensor: Copy of tensor with rotation applied.
135 | """ 136 | h_dim = 2 137 | w_dim = 3 138 | 139 | if angle == 0: 140 | return x 141 | elif angle == 90: 142 | return x.flip(w_dim).transpose(h_dim, w_dim) 143 | elif angle == 180: 144 | return x.flip(w_dim).flip(h_dim) 145 | elif angle == 270: 146 | return x.flip(h_dim).transpose(h_dim, w_dim) 147 | else: 148 | raise NotImplementedError("Must be rotation divisible by 90 degrees") 149 | 150 | 151 | class NoiseNetwork(nn.Module): 152 | """Custom U-Net architecture for Self Supervised Denoising (SSDN) and Noise2Noise (N2N). 153 | Base N2N implementation was made with reference to @joeylitalien's N2N implementation. 154 | Changes made are removal of weight sharing when blocks are reused, use of LeakyReLU 155 | over standard ReLU, and incorporation of blindspot functionality. 156 | Unlike other typical U-Net implementations, dropout is not used when the model is trained. 157 | When in blindspot mode the following behaviour changes occur: 158 | * Input batches are duplicated for rotations: 0, 90, 180, 270. This increases the 159 | batch size by 4x. After the encode-decode stage the rotations are undone and 160 | concatenated on the channel axis with the associated original image. This 4x 161 | increase in channel count is collapsed to the standard channel count in the 162 | first 1x1 kernel convolution. 163 | * To restrict the receptive field to the upward direction a shift is used for 164 | convolutions (see ShiftConv2d) and downsampling. Downsampling uses a single 165 | pixel shift prior to max pooling as dictated by Laine et al. This is equivalent 166 | to applying a shift on the upsample. 167 | Args: 168 | in_channels (int, optional): Number of input channels, this will typically be either 169 | 1 (Mono) or 3 (RGB) but can be more. Defaults to 3. 170 | out_channels (int, optional): Number of channels the final convolution should output. 171 | Defaults to 3. 172 | blindspot (bool, optional): Whether to enable the network blindspot.
This will 173 | add in rotation stages and shift stages while max pooling and during convolutions. 174 | A further shift will occur after upsample. Defaults to False. 175 | zero_output_weights (bool, optional): Whether to initialise the weights of 176 | `nin_c` to zero. This is not mentioned in literature but is done as part 177 | of the tensorflow implementation for the parameter estimation network. 178 | Defaults to False. 179 | """ 180 | 181 | def __init__( 182 | self, 183 | in_channels: int = 3, 184 | out_channels: int = 3, 185 | blindspot: bool = False, 186 | zero_output_weights: bool = False, 187 | ): 188 | super(NoiseNetwork, self).__init__() 189 | self._blindspot = blindspot 190 | self._zero_output_weights = zero_output_weights 191 | self.Conv2d = ShiftConv2d if self.blindspot else nn.Conv2d 192 | 193 | #################################### 194 | # Encode Blocks 195 | #################################### 196 | 197 | def _max_pool_block(max_pool: nn.Module) -> nn.Module: 198 | if blindspot: 199 | return nn.Sequential(Shift2d((1, 0)), max_pool) 200 | return max_pool 201 | 202 | # Layers: enc_conv0, enc_conv1, pool1 203 | self.encode_block_1 = nn.Sequential( 204 | self.Conv2d(in_channels, 48, 3, stride=1, padding=1), 205 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 206 | self.Conv2d(48, 48, 3, padding=1), 207 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 208 | _max_pool_block(nn.MaxPool2d(2)), 209 | ) 210 | 211 | # Layers: enc_conv(i), pool(i); i=2..5 212 | def _encode_block_2_3_4_5() -> nn.Module: 213 | return nn.Sequential( 214 | #nn.Dropout(p_drop), #### 215 | self.Conv2d(48, 48, 3, stride=1, padding=1), 216 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 217 | _max_pool_block(nn.MaxPool2d(2)), 218 | ) 219 | 220 | # Separate instances of same encode module definition created 221 | self.encode_block_2 = _encode_block_2_3_4_5() 222 | self.encode_block_3 = _encode_block_2_3_4_5() 223 | self.encode_block_4 = _encode_block_2_3_4_5() 224 |
self.encode_block_5 = _encode_block_2_3_4_5() 225 | 226 | # Layers: enc_conv6 227 | self.encode_block_6 = nn.Sequential( 228 | #nn.Dropout(p_drop), #### 229 | self.Conv2d(48, 48, 3, stride=1, padding=1), 230 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 231 | ) 232 | 233 | #################################### 234 | # Decode Blocks 235 | #################################### 236 | # Layers: upsample5 237 | self.decode_block_6 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest")) 238 | 239 | # Layers: dec_conv5a, dec_conv5b, upsample4 240 | self.decode_block_5 = nn.Sequential( 241 | #nn.Dropout(p_drop), #### 242 | self.Conv2d(96, 96, 3, stride=1, padding=1), 243 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 244 | #nn.Dropout(p_drop), #### 245 | self.Conv2d(96, 96, 3, stride=1, padding=1), 246 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 247 | nn.Upsample(scale_factor=2, mode="nearest"), 248 | ) 249 | 250 | # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2 251 | def _decode_block_4_3_2() -> nn.Module: 252 | return nn.Sequential( 253 | #nn.Dropout(p_drop), #### 254 | self.Conv2d(144, 96, 3, stride=1, padding=1), 255 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 256 | #nn.Dropout(p_drop), #### 257 | self.Conv2d(96, 96, 3, stride=1, padding=1), 258 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 259 | nn.Upsample(scale_factor=2, mode="nearest"), 260 | ) 261 | 262 | # Separate instances of same decode module definition created 263 | self.decode_block_4 = _decode_block_4_3_2() 264 | self.decode_block_3 = _decode_block_4_3_2() 265 | self.decode_block_2 = _decode_block_4_3_2() 266 | 267 | # Layers: dec_conv1a, dec_conv1b, dec_conv1c, 268 | self.decode_block_1 = nn.Sequential( 269 | #nn.Dropout(p_drop), #### 270 | self.Conv2d(96 + in_channels, 96, 3, stride=1, padding=1), 271 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 272 | #nn.Dropout(p_drop), #### 273 | self.Conv2d(96, 96, 3, stride=1, padding=1), 274 | 
nn.LeakyReLU(negative_slope=0.1, inplace=True), 275 | ) 276 | 277 | #################################### 278 | # Output Block 279 | #################################### 280 | 281 | if self.blindspot: 282 | # Shift 1 pixel down 283 | self.shift = Shift2d((1, 0)) 284 | # 4 x Channels due to batch rotations 285 | nin_a_io = 384 286 | else: 287 | nin_a_io = 96 288 | 289 | # nin_a,b,c, linear_act 290 | self.output_conv = self.Conv2d(96, out_channels, 1) 291 | self.output_block = nn.Sequential( 292 | #nn.Dropout(p_drop), #### 293 | self.Conv2d(nin_a_io, nin_a_io, 1), 294 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 295 | #nn.Dropout(p_drop), #### 296 | self.Conv2d(nin_a_io, 96, 1), 297 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 298 | self.output_conv, 299 | ) 300 | 301 | # Initialize weights 302 | #self.init_weights() 303 | 304 | @property 305 | def blindspot(self) -> bool: 306 | return self._blindspot 307 | 308 | def init_weights(self): 309 | """Initializes weights using Kaiming He et al. (2015). 310 | Only convolution layers have learnable weights. All convolutions use a leaky 311 | relu activation function (negative_slope = 0.1) except the last which is just 312 | a linear output. 
313 | """ 314 | with torch.no_grad(): 315 | self._init_weights() 316 | 317 | def _init_weights(self): 318 | for m in self.modules(): 319 | if isinstance(m, nn.Conv2d): 320 | nn.init.kaiming_normal_(m.weight.data, a=0.1) 321 | m.bias.data.zero_() 322 | # Initialise last output layer 323 | if self._zero_output_weights: 324 | self.output_conv.weight.zero_() 325 | else: 326 | nn.init.kaiming_normal_(self.output_conv.weight.data, nonlinearity="linear") 327 | 328 | def forward(self, x: Tensor) -> Tensor: 329 | if self.blindspot: 330 | rotated = [rotate(x, rot) for rot in (0, 90, 180, 270)] 331 | x = torch.cat((rotated), dim=0) 332 | 333 | # Encoder 334 | pool1 = self.encode_block_1(x) 335 | pool2 = self.encode_block_2(pool1) 336 | pool3 = self.encode_block_3(pool2) 337 | pool4 = self.encode_block_4(pool3) 338 | pool5 = self.encode_block_5(pool4) 339 | encoded = self.encode_block_6(pool5) 340 | 341 | # Decoder 342 | upsample5 = self.decode_block_6(encoded) 343 | concat5 = torch.cat((upsample5, pool4), dim=1) 344 | upsample4 = self.decode_block_5(concat5) 345 | concat4 = torch.cat((upsample4, pool3), dim=1) 346 | upsample3 = self.decode_block_4(concat4) 347 | concat3 = torch.cat((upsample3, pool2), dim=1) 348 | upsample2 = self.decode_block_3(concat3) 349 | concat2 = torch.cat((upsample2, pool1), dim=1) 350 | upsample1 = self.decode_block_2(concat2) 351 | concat1 = torch.cat((upsample1, x), dim=1) 352 | x = self.decode_block_1(concat1) 353 | 354 | # Output 355 | if self.blindspot: 356 | # Apply shift 357 | shifted = self.shift(x) 358 | # Unstack, rotate and combine 359 | rotated_batch = torch.chunk(shifted, 4, dim=0) 360 | aligned = [ 361 | rotate(rotated, rot) 362 | for rotated, rot in zip(rotated_batch, (0, 270, 180, 90)) 363 | ] 364 | x = torch.cat(aligned, dim=1) 365 | 366 | x = self.output_block(x) 367 | 368 | return x 369 | 370 | @staticmethod 371 | def input_wh_mul() -> int: 372 | """Multiple that both the width and height dimensions of an input must be to be 373 | 
processed by the network. This is devised from the number of pooling layers that 374 | reduce the input size. 375 | Returns: 376 | int: Dimension multiplier 377 | """ 378 | max_pool_layers = 5 379 | return 2 ** max_pool_layers 380 | 381 | 382 | class ShiftConv2d(nn.Conv2d): 383 | def __init__(self, *args, **kwargs): 384 | super().__init__(*args, **kwargs) 385 | self.shift_size = (self.kernel_size[0] // 2, 0) 386 | # Use individual layers of shift for wrapping conv with shift 387 | shift = Shift2d(self.shift_size) 388 | self.pad = shift.pad 389 | self.crop = shift.crop 390 | 391 | def forward(self, x: Tensor) -> Tensor: 392 | x = self.pad(x) 393 | x = super().forward(x) 394 | x = self.crop(x) 395 | return x -------------------------------------------------------------------------------- /code/trajectory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | import os 4 | from scipy.spatial.transform import Rotation 5 | from scipy.ndimage import gaussian_filter1d 6 | from scipy.interpolate import interp1d 7 | 8 | 9 | def get_trajectory(folder='../trajectory'): 10 | 11 | traj = [] 12 | 13 | for f in os.listdir(folder): 14 | joint_coord = sio.loadmat(os.path.join(folder, f))['joint_coord'].astype(np.float32) 15 | 16 | joint_coord = joint_coord[np.all(joint_coord > 0, (1, 2))] 17 | 18 | eye_l = joint_coord[..., 7] 19 | eye_r = joint_coord[..., 8] 20 | neck = (joint_coord[..., 11] + joint_coord[..., 12]) / 2 21 | 22 | origin = (eye_l + eye_r + neck) / 3 23 | 24 | x_vec = eye_l - eye_r 25 | x_vec = x_vec / np.linalg.norm(x_vec, ord=2, axis=-1, keepdims=True) 26 | 27 | neck_eye_l = neck - eye_l 28 | y_vec = np.cross(x_vec, neck_eye_l) 29 | y_vec = y_vec / np.linalg.norm(y_vec, ord=2, axis=-1, keepdims=True) 30 | 31 | z_vec = np.cross(x_vec, y_vec) 32 | z_vec = z_vec / np.linalg.norm(z_vec, ord=2, axis=-1, keepdims=True) 33 | 34 | R = np.stack([x_vec, y_vec, z_vec], -1) 35 | R = R @ 
R[0].T[None] 36 | R = Rotation.from_matrix(R).as_euler('xyz') 37 | t = origin - origin[[0]] 38 | Rt = np.concatenate([R, t], -1) 39 | Rt = Rt[::2] 40 | Rt = gaussian_filter1d(Rt, 0.5, 0) 41 | 42 | interp_func = interp1d(np.arange(Rt.shape[0]), Rt, kind='cubic', axis=0, fill_value="extrapolate", assume_sorted=True) 43 | 44 | traj.append((interp_func, Rt.shape[0]-1)) 45 | 46 | return traj 47 | 48 | if __name__ == '__main__': 49 | pass 50 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | from scipy.special import i0, i1 2 | from scipy.stats import trim_mean 3 | import torch 4 | import numpy as np 5 | import os 6 | from skimage.metrics import structural_similarity 7 | 8 | class MovingAverage: 9 | def __init__(self, alpha): 10 | assert 0 <= alpha < 1 11 | self.alpha = alpha 12 | self.value = dict() 13 | 14 | def __call__(self, key, value): 15 | if key not in self.value: 16 | self.value[key] = (0, 0) 17 | num, v = self.value[key] 18 | num += 1 19 | if self.alpha: 20 | v = v * self.alpha + value * (1 - self.alpha) 21 | else: 22 | v += value 23 | self.value[key] = (num, v) 24 | 25 | def __str__(self): 26 | s = '' 27 | for key in self.value: 28 | num, v = self.value[key] 29 | if self.alpha: 30 | s += "%s = %f\t" % (key, v / (1 - self.alpha**num)) 31 | else: 32 | s += "%s = %f\t" % (key, v / num) 33 | return s 34 | 35 | def rician_correct(out, sigma, background): 36 | if sigma == 0: 37 | out[out < 0] = 0 38 | return out 39 | elif sigma is None: 40 | sigma_pi = trim_mean(out[background].cpu().numpy(), 0.1, None) 41 | sigma = sigma_pi * np.sqrt(2/np.pi) 42 | else: 43 | sigma_pi = sigma * np.sqrt(np.pi/2) 44 | 45 | old_out = out 46 | out = out / sigma_pi 47 | out[out < 1] = 1 48 | curVal=0 49 | for coeff in [-0.02459419, 0.28790799, 0.27697441, 2.68069732]: 50 | curVal = (curVal+coeff)*out 51 | out = (curVal - 3.22092921) * (sigma**2) 52 | 
snr_mask = old_out/sigma > 3.5 53 | out[snr_mask] = old_out[snr_mask]**2 - sigma**2 54 | out = torch.sqrt(out) 55 | return out 56 | 57 | def fba(imgs, p): 58 | freqs = [np.fft.rfftn(img) for img in imgs] 59 | weights = [np.abs(freq) ** p for freq in freqs] 60 | return np.fft.irfftn(sum(freq * weight for freq, weight in zip(freqs, weights)) / sum(weights)).astype(np.float32) 61 | 62 | def mkdir(path): 63 | if not os.path.exists(path): 64 | os.makedirs(path) 65 | 66 | def psnr(x, y, mask=None): 67 | if mask is None: 68 | mse = np.mean((x - y) ** 2) 69 | else: 70 | mse = np.sum(((x - y) ** 2) * mask) / mask.sum() 71 | return 10 * np.log10(y.max()**2 / mse) 72 | 73 | def ssim(x, y, mask=None): 74 | mssim, S = structural_similarity(x, y, full=True) 75 | if mask is not None: 76 | return (S * mask).sum() / mask.sum() 77 | else: 78 | return mssim 79 | 80 | def ssim_slice(x, y, mask): 81 | mask = mask.sum((0,1)) > 0 82 | #print(np.nonzero(mask)) 83 | x = x[..., mask] 84 | y = y[..., mask] 85 | 86 | return structural_similarity(x, y) 87 | #ssims = [] 88 | #for i in range(x.shape[-1]): 89 | # ssims.append(structural_similarity(x[..., i], y[..., i])) 90 | #return np.mean(ssims) 91 | 92 | -------------------------------------------------------------------------------- /img/stress.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/img/stress.gif -------------------------------------------------------------------------------- /trajectory/1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/1.mat -------------------------------------------------------------------------------- /trajectory/10.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/10.mat -------------------------------------------------------------------------------- /trajectory/2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/2.mat -------------------------------------------------------------------------------- /trajectory/3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/3.mat -------------------------------------------------------------------------------- /trajectory/4.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/4.mat -------------------------------------------------------------------------------- /trajectory/5.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/5.mat -------------------------------------------------------------------------------- /trajectory/6.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/6.mat -------------------------------------------------------------------------------- /trajectory/7.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/7.mat -------------------------------------------------------------------------------- /trajectory/8.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/8.mat -------------------------------------------------------------------------------- /trajectory/9.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/9.mat --------------------------------------------------------------------------------
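
Note on /code/utils.py: `MovingAverage.__str__` prints `v / (1 - alpha**num)` rather than the raw running value `v`. The sketch below illustrates why that division is there, assuming a non-empty input stream and `0 < alpha < 1`. The function name `ema_bias_corrected` is hypothetical and is not part of the repository; it only mirrors the update rule used in `MovingAverage.__call__`.

```python
def ema_bias_corrected(values, alpha):
    """Exponential moving average started at 0, with the start-up bias removed.

    Hypothetical helper for illustration; assumes `values` is non-empty
    and 0 < alpha < 1.
    """
    v = 0.0
    num = 0
    for value in values:
        num += 1
        # Same update rule as MovingAverage.__call__ when alpha is non-zero.
        v = v * alpha + value * (1 - alpha)
    # After `num` updates the weights sum to (1 - alpha**num), not 1,
    # so dividing by that factor rescales them into a proper average.
    return v / (1 - alpha ** num)


# Raw (uncorrected) running value after three identical samples.
raw_after_3 = 0.0
for value in [5.0, 5.0, 5.0]:
    raw_after_3 = raw_after_3 * 0.9 + value * 0.1

# Corrected value for the same stream.
corrected = ema_bias_corrected([5.0, 5.0, 5.0], alpha=0.9)
```

With a constant input the raw value is still far below the true mean after a few steps because it starts from zero, while the corrected value recovers the constant. This is the same idea as the bias correction commonly applied to moving averages in optimizers.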