├── .gitignore
├── CRL_Fetal_Brain_Atlas_2017v3
│   └── README
├── LICENSE
├── README.md
├── code
│   ├── config.py
│   ├── data.py
│   ├── main.py
│   ├── models.py
│   ├── trajectory.py
│   └── utils.py
├── img
│   └── stress.gif
└── trajectory
    ├── 1.mat
    ├── 10.mat
    ├── 2.mat
    ├── 3.mat
    ├── 4.mat
    ├── 5.mat
    ├── 6.mat
    ├── 7.mat
    ├── 8.mat
    └── 9.mat
/.gitignore:
--------------------------------------------------------------------------------
1 | /CRL_Fetal_Brain_Atlas_2017v3/*
2 | !/CRL_Fetal_Brain_Atlas_2017v3/README
3 | /results/*
4 | /_results/*
5 | .vscode/
6 | .mypy_cache/
7 | *.nii
8 | *.nii.gz
9 | *.npy
10 | *.pt
11 |
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 |
96 | # pyenv
97 | .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
--------------------------------------------------------------------------------
/CRL_Fetal_Brain_Atlas_2017v3/README:
--------------------------------------------------------------------------------
1 | The CRL fetal brain atlas can be downloaded from https://form.jotform.com/91364382958166
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Junshen Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fetalSR
2 |
3 | STRESS: Super-Resolution for Dynamic Fetal MRI using Self-Supervised Learning ([Springer](https://link.springer.com/chapter/10.1007/978-3-030-87234-2_19)|[arXiv](https://arxiv.org/abs/2106.12407))
4 |
5 | ![stress](img/stress.gif)
6 |
7 |
8 |
9 | ## Usage
10 |
11 | Run `python main.py` from the `code` directory to train a model and test it on the simulated dataset.
12 |
13 | You may create your own dataset following `EPIDataset` in `data.py` (a minimal sketch of the expected interface follows this file).
14 |
15 | ## Cite our work
16 |
17 | ```bibtex
18 | @inproceedings{xu2021stress,
19 | title={STRESS: Super-Resolution for Dynamic Fetal MRI Using Self-supervised Learning},
20 | author={Xu, Junshen and Abaci Turk, Esra and Grant, P Ellen and Golland, Polina and Adalsteinsson, Elfar},
21 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
22 | pages={197--206},
23 | year={2021},
24 | organization={Springer}
25 | }
26 | ```
27 |
--------------------------------------------------------------------------------
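The "Usage" note above says a custom dataset should follow `EPIDataset`. A minimal sketch of the interface the training loop in `code/main.py` actually consumes — items are `(lr, hr)` pairs of float32 tensors shaped `(2 * num_split - 1, patch, patch)` — with the data source left as a placeholder (the class name and the random frames below are purely illustrative):

```python
import numpy as np
import torch


class MyDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for EPIDataset: yields (lr, hr) patch pairs."""

    def __init__(self, num_split, patch=32, length=10**8):
        self.num_split = num_split
        self.patch = patch
        self.length = length  # training treats the dataset as effectively infinite

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        c = 2 * self.num_split - 1
        # Replace this with real frames; lr should be hr degraded as in
        # data.down_up (interleaved subsampling + cubic re-interpolation).
        hr = np.random.rand(c, self.patch, self.patch).astype(np.float32)
        lr = hr.copy()
        return torch.from_numpy(lr), torch.from_numpy(hr)
```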
/code/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
4 | # hyperparameters
5 | use_sim = True
6 | is_denoise = False
7 | denoiser = None
8 |
9 | model_name = "sim_s4_k3_n3"
10 | batch_size = 16 if is_denoise else 64
11 | num_iter = 100000
12 | lr = 1e-4
13 |
14 | num_split = 4
15 | use_k = num_split // 2
16 |
17 | sigma = 0.03
18 |
--------------------------------------------------------------------------------
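For context, `num_split` is the slice-interleave factor and `use_k` controls how many degraded neighbours on each side of the centre frame the super-resolution model sees, so `EDSR` is built with `cin = 2 * use_k + 1` input channels in `code/main.py`. A small sketch of that indexing, using the values above:

```python
import numpy as np

num_split = 4
use_k = num_split // 2  # 2
cin = 2 * use_k + 1     # 5 input channels for EDSR

# A training item stacks 2 * num_split - 1 = 7 frames; the model is fed the
# use_k neighbours on each side of the centre frame (index num_split - 1).
frames = np.arange(2 * num_split - 1)
print(frames[num_split - 1 - use_k : num_split + use_k])  # [1 2 3 4 5]
```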
/code/data.py:
--------------------------------------------------------------------------------
1 | from scipy.ndimage import rotate
2 | import torch.multiprocessing as mp
3 | import numpy as np
4 | import nibabel as nib
5 | import torch
6 | import os
7 | from trajectory import get_trajectory
8 | import traceback
9 | from scipy.spatial.transform import Rotation
10 | from scipy.ndimage import map_coordinates, affine_transform
11 | from scipy.stats import special_ortho_group
12 | from itertools import product
13 | from queue import Empty  # raised by Queue.get_nowait when the queue is empty
14 | from config import sigma as _sigma
15 |
16 | def read_nifti(nii_filename):
17 | data = nib.load(nii_filename)
18 |     return np.squeeze(data.get_fdata().astype(np.float32))  # get_fdata(): get_data() was removed in nibabel >= 5
19 |
20 | def down_up(*imgs, start=0):
21 | res = []
22 | num_split = (len(imgs) + 1) // 2
23 | for s, img in enumerate(imgs, start):
24 | ss = s % num_split
25 | X, Y, Z = np.meshgrid((np.arange(img.shape[0]) - ss) / num_split, np.arange(img.shape[1]), np.arange(img.shape[2]), indexing='ij')
26 | res.append(map_coordinates(img[ss::num_split], [X, Y, Z], order=3, mode='nearest'))
27 | return res
28 |
29 | class EPIDataset(torch.utils.data.Dataset):
30 | def __init__(self, num_split, stage, is_denoise=False, denoiser=None):
31 | assert stage in ['train', 'val', 'test']
32 | self.data_dir = '/home/junshen/new'
33 | self.is_test = stage == 'test'
34 | self.folders = sorted(os.listdir(self.data_dir))
35 | if stage == 'test':
36 | self.folders = [folder for i, folder in enumerate(self.folders) if i % 6 == 0]
37 | else:
38 | self.folders = [folder for i, folder in enumerate(self.folders) if i % 6 != 0]
39 | self.proc = []
40 | self.res = []
41 | num_p = 10
42 | self.queue = mp.Queue(1024)
43 | n_new = 6
44 | self.queue2 = mp.Queue(n_new)
45 | for _ in range(n_new):
46 | self.queue2.put(None)
47 |
48 | if denoiser is not None:
49 | denoise_queue = [{'in':mp.Queue(1),'out':mp.Queue(1)} for _ in range(num_p)]
50 | self.denoiser = mp.Process(target=denoise_fn, args=(denoiser, denoise_queue))
51 | self.denoiser.daemon = True
52 | self.denoiser.start()
53 | self.denoise_queue = denoise_queue
54 | else:
55 | denoise_queue = [None] * num_p
56 | self.denoise_queue = denoise_queue
57 |
58 | # use multiple processes to fetch data
59 | for i in range(num_p):
60 | proc = mp.Process(target=prefetch_volumes_test if self.is_test else prefetch_volumes,
61 | args=(self.data_dir, self.folders[i::num_p], self.queue, self.queue2, num_split, is_denoise, self.denoise_queue[i]))
62 | proc.daemon = True
63 | proc.start()
64 | self.proc.append(proc)
65 |
66 | def load_data(self):
67 | if len(self.res) == 0:
68 | N = 0
69 | while True:
70 | res = self.queue.get()
71 | if res is None:
72 | N += 1
73 | if N == len(self.proc):
74 | break
75 | else:
76 | self.res.append(res)
77 | self.res = sorted(self.res, key=lambda x:x[-1])
78 |
79 | if self.denoise_queue[0] is not None:
80 | self.denoise_queue[0]['in'].put(None)
81 |
82 | print("test set len: %d" % len(self.res))
83 |
84 | def __len__(self):
85 | if self.is_test:
86 | self.load_data()
87 | return len(self.res)
88 | else:
89 | return int(1e8)
90 |
91 | def __getitem__(self, idx):
92 | if self.is_test:
93 | self.load_data()
94 | return self.res[idx][:3]
95 | else:
96 | return self.queue.get()
97 |
98 |
99 | def prefetch_volumes(data_dir, folders, queue, q2, num_split, is_denoise, denoiser):
100 | a = 32
101 | volumes = [None] * len(folders)
102 |     files = [None] * len(folders)  # avoid [[]] * n, which aliases a single list
103 | starts = [None] * len(folders)
104 | start0s = [None] * len(folders)
105 | for i in range(len(folders)):
106 | files[i] = sorted(os.listdir(os.path.join(data_dir, folders[i])))
107 | img = read_nifti(os.path.join(data_dir, folders[i], files[i][0]))
108 | err0 = np.mean(((img[:, :, 20] + img[:, :, 22]) / 2 - img[:, :, 21])**2)
109 | err1 = np.mean(((img[:, :, 21] + img[:, :, 23]) / 2 - img[:, :, 22])**2)
110 |         start0s[i] = 0 if err0 < err1 else 1  # pick the interleave parity with the smaller interpolation error
111 |
112 | try:
113 |         while True:
114 | for i in range(len(volumes)):
115 |
116 | new_vol = False
117 | if volumes[i] is not None:
118 | try:
119 | _ = q2.get_nowait()
120 | new_vol = True
121 |                     except Empty:
122 | pass
123 |
124 | if volumes[i] is None or new_vol:
125 | fid = np.random.choice(np.arange(num_split, len(files[i])-num_split))
126 |
127 |                     angle = np.random.uniform(0, 360)
128 | hrs = []
129 |
130 | for dt in range(-num_split+1, num_split):
131 | t = fid + dt
132 |
133 | img = read_nifti(os.path.join(data_dir, folders[i], files[i][t]))
134 | img = (img - 70.0) / 100.0
135 | img = rotate(img, angle, axes=(0, 1), reshape=False)
136 | ss = (start0s[i] + t) % num_split
137 |
138 | if denoiser is not None:
139 | d = 128-img.shape[0]
140 | d1 = d//2
141 | d2 = d - d1
142 | if d1 >= 0:
143 | frames = np.pad(img[..., ss::num_split], [(d1, d2),(d1, d2), (0, 0)], mode='constant')
144 | else:
145 | frames = img[-d1:d2, -d1:d2, ss::num_split]
146 | frames = torch.tensor(frames[None]).permute(3, 0, 1, 2)
147 | for n_slice in range(0, frames.shape[0], 16):
148 | denoiser['in'].put(frames[n_slice:n_slice+16])
149 | frames[n_slice:n_slice+16] = denoiser['out'].get()
150 | if d1 >= 0:
151 | frames = frames.squeeze().permute(1,2,0)[d1:-d2,d1:-d2].numpy()
152 | else:
153 | frames = np.pad(frames.squeeze().permute(1,2,0).numpy(), [(-d1, -d2),(-d1, -d2), (0, 0)], mode='constant')
154 | img[..., ss::num_split] = frames
155 |
156 | if is_denoise:
157 | img = img[..., ss::num_split]
158 | else:
159 | X, Y, Z = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), (np.arange(img.shape[2]) - ss) / num_split, indexing='ij')
160 | img = map_coordinates(img[..., ss::num_split], [X, Y, Z], order=3, mode='nearest')
161 |
162 | hrs.append(img)
163 |
164 | if is_denoise:
165 | hrs = np.concatenate(hrs, -1)[None]
166 | volumes[i] = (hrs, hrs)
167 | else:
168 | volumes[i] = (np.stack(down_up(*hrs), 0), np.stack(hrs, 0))
169 | starts[i] = (start0s[i] + fid) % num_split
170 |
171 | if new_vol:
172 | q2.put(None)
173 |
174 | lr, hr = volumes[i]
175 | if is_denoise:
176 | z = np.random.randint(lr.shape[3])
177 | hr = hr[:, :, :, z]
178 | d = 128-hr.shape[1]
179 | d1 = d//2
180 | d2 = d - d1
181 | if d1 >= 0:
182 | hr = np.pad(hr, [(0,0),(d1, d2),(d1, d2)], mode='constant')
183 | else:
184 | hr = hr[:, -d1:d2, -d1:d2]
185 | lr = hr
186 | else:
187 | y = np.random.randint(lr.shape[1] - a)
188 | x = np.random.randint(lr.shape[2] - a)
189 | z = np.random.randint((lr.shape[3] - starts[i]) // num_split)
190 | lr = lr[:, y:y+a, x:x+a, starts[i] + z * num_split]
191 | hr = hr[:, y:y+a, x:x+a, starts[i] + z * num_split]
192 | axis = np.random.choice([None, 1, 2])
193 | if axis is not None:
194 | lr = np.flip(lr, axis).copy()
195 | hr = np.flip(hr, axis).copy()
196 | lr = torch.tensor(lr, dtype=torch.float32)
197 | hr = torch.tensor(hr, dtype=torch.float32)
198 | queue.put((lr, hr))
199 |     except Exception:
200 | traceback.print_exc()
201 | print("error: %s" % mp.current_process().name)
202 |
203 | def prefetch_volumes_test(data_dir, folders, queue, q2, num_split, is_denoise, denoiser):
204 | try:
205 | for folder in folders:
206 | files = sorted(os.listdir(os.path.join(data_dir, folder)))
207 | img = read_nifti(os.path.join(data_dir, folder, files[0]))
208 | err0 = np.mean(((img[:, :, 20] + img[:, :, 22]) / 2 - img[:, :, 21])**2)
209 | err1 = np.mean(((img[:, :, 21] + img[:, :, 23]) / 2 - img[:, :, 22])**2)
210 | start = 0 if err0 < err1 else 1
211 |
212 | for fid in list(range(len(files)))[::(len(files)//7)][1:-1]:
213 | imgs = []
214 | combined = np.zeros_like(img)
215 | for dt in range(-num_split+1, num_split):
216 | t = fid + dt
217 | img = read_nifti(os.path.join(data_dir, folder, files[t]))
218 | img = (img - 70.0) / 100.0
219 | ss = (start + t) % num_split
220 |
221 | combined[..., ss::num_split] += img[..., ss::num_split] * (num_split - np.abs(dt)) / num_split
222 |
223 | if dt == 0:
224 | if num_split == 4:
225 | start_gt = (start+fid+2) % num_split
226 | gt = 0 * img - 1000
227 | gt[..., start_gt::num_split] = img[..., start_gt::num_split]
228 | else:
229 | gt = 0
230 |
231 | if denoiser is not None:
232 | d = 128-img.shape[0]
233 | d1 = d//2
234 | d2 = d - d1
235 | if d1 >= 0:
236 | frames = np.pad(img[..., ss::num_split], [(d1, d2),(d1, d2), (0, 0)], mode='constant')
237 | else:
238 | frames = img[-d1:d2, -d1:d2, ss::num_split]
239 | frames = torch.tensor(frames[None]).permute(3, 0, 1, 2)
240 | for n_slice in range(0, frames.shape[0], 16):
241 | denoiser['in'].put(frames[n_slice:n_slice+16])
242 | frames[n_slice:n_slice+16] = denoiser['out'].get()
243 | if d1 >= 0:
244 | frames = frames.squeeze().permute(1,2,0)[d1:-d2,d1:-d2].numpy()
245 | else:
246 | frames = np.pad(frames.squeeze().permute(1,2,0).numpy(), [(-d1, -d2),(-d1, -d2), (0, 0)], mode='constant')
247 | img[..., ss::num_split] = frames
248 |
249 | if is_denoise:
250 | imgs.append(img[..., ss::num_split])
251 | else:
252 | X, Y, Z = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), (np.arange(img.shape[2]) - ss) / num_split, indexing='ij')
253 | imgs.append(map_coordinates(img[..., ss::num_split], [X, Y, Z], order=3, mode='nearest'))
254 |
255 | if is_denoise:
256 | imgs = np.concatenate(imgs, -1)
257 | d = 128-img.shape[0]
258 | d1 = d//2
259 | d2 = d - d1
260 | if d1 >= 0:
261 | imgs = np.pad(imgs, [(d1, d2),(d1, d2), (0, 0)], mode='constant')
262 | else:
263 | imgs = imgs[-d1:d2, -d1:d2]
264 | else:
265 | imgs = np.stack(imgs, 0)
266 | queue.put((imgs, gt, combined, (start+fid) % num_split, os.path.join(folder, files[fid])))
267 |     except Exception:
268 | traceback.print_exc()
269 | print("test error: %s" % mp.current_process().name)
270 | queue.put(None)
271 | return
272 |
273 |
274 | class SimDataset(torch.utils.data.Dataset):
275 | def __init__(self, num_split, stage, is_denoise=False, denoiser=None):
276 | assert stage in ['test', 'train']
277 | self.is_test = stage == 'test'
278 | test_ga = ['25', '28', '30', '33', '35']
279 | data_dir = '/home/junshen/fetalSR/CRL_Fetal_Brain_Atlas_2017v3/'
280 | files = [f for f in os.listdir(data_dir) if ('STA' in f) and ('_' not in f)]
281 | if self.is_test:
282 | files = [f for f in files if any(ga in f for ga in test_ga)]
283 | else:
284 | files = [f for f in files if all(ga not in f for ga in test_ga)]
285 | files = [os.path.join(data_dir, f) for f in files]
286 | trajs = get_trajectory()
287 | self.proc = []
288 |
289 | num_p = 10
290 | self.queue = mp.Queue(1024)
291 |
292 | # use multiple processes to fetch data
293 | if self.is_test:
294 | imgs = [nib.load(f).get_fdata().astype(np.float32) / 1000.0 for f in files]
295 | # trajs = [trajs[t][0] for t in [0, 5, 10, 15, 20, 25, 30]]
296 | trajs = [trajs[t][0] for t in range(len(trajs))]
297 | imgs_trajs = list(product(imgs, trajs))
298 | self.length = len(imgs_trajs)
299 | self.res = []
300 | else:
301 | n_new = 6
302 | self.queue2 = mp.Queue(n_new)
303 | for _ in range(n_new):
304 | self.queue2.put(None)
305 | self.length = int(1e8)
306 |
307 | if denoiser is not None:
308 | denoise_queue = [{'in':mp.Queue(1),'out':mp.Queue(1)} for _ in range(num_p)]
309 | self.denoiser = mp.Process(target=denoise_fn, args=(denoiser, denoise_queue))
310 | self.denoiser.daemon = True
311 | self.denoiser.start()
312 | self.denoise_queue = denoise_queue
313 | else:
314 | denoise_queue = [None] * num_p
315 | self.denoise_queue = denoise_queue
316 |
317 | for i in range(num_p):
318 | if self.is_test:
319 | proc = mp.Process(target=prefetch_sim_volumes_test, args=(imgs_trajs[i::num_p], self.queue, num_split, is_denoise, denoise_queue[i]))
320 | else:
321 | proc = mp.Process(target=prefetch_sim_volumes, args=(files[i::num_p], self.queue, trajs, self.queue2, num_split, is_denoise, denoise_queue[i]))
322 | proc.daemon = True
323 | proc.start()
324 | self.proc.append(proc)
325 |
326 | def __len__(self):
327 | return self.length
328 |
329 | def __getitem__(self, idx):
330 | if self.is_test:
331 | if len(self.res) == 0:
332 | for i in range(self.length):
333 | self.res.append(self.queue.get())
334 | #print(i, self.length)
335 | if self.denoise_queue[0] is not None:
336 | self.denoise_queue[0]['in'].put(None)
337 | return self.res[idx]
338 | else:
339 | return self.queue.get()
340 |
341 | def prefetch_sim_volumes(files, queue, trajs, q2, num_split, is_denoise, denoiser):
342 | #denoiser = denoiser.cuda()
343 | a = 64 # 64 (0.031)
344 | starts = [None] * len(files)
345 | volumes = [None] * len(files)
346 | imgs = [nib.load(f).get_fdata().astype(np.float32) / 1000.0 for f in files]
347 | try:
348 | while True:
349 | for j in range(len(imgs)):
350 | new_vol = False
351 | if volumes[j] is not None:
352 | try:
353 | _ = q2.get_nowait()
354 | new_vol = True
355 |                     except Empty:
356 | pass
357 | if volumes[j] is None or new_vol:
358 |
359 | start = np.random.choice(num_split)
360 | traj, T = trajs[np.random.choice(len(trajs))]
361 | t0 = np.random.uniform(0, T)
362 | hr, gt, combined, starts[j] = sim_scan(imgs[j], num_split, traj, t0, 1.0 / imgs[j].shape[-1], start, np.eye(3,3), is_denoise, denoiser)
363 |
364 | lr = down_up(*hr)
365 | volumes[j] = (np.stack(lr, 0), np.stack(hr, 0))
366 | if new_vol:
367 | q2.put(None)
368 |
369 | lr, hr = volumes[j]
370 | while True:
371 | y = np.random.randint(lr.shape[1] - a)
372 | x = np.random.randint(lr.shape[2] - a)
373 | z = np.random.randint((lr.shape[3] - starts[j]) // num_split)
374 | if is_denoise:
375 | lr_ = lr[lr.shape[0]//2:lr.shape[0]//2+1, :, :, starts[j] + z * num_split]
376 | hr_ = hr[lr.shape[0]//2:lr.shape[0]//2+1, :, :, starts[j] + z * num_split]
377 | hr_ = np.pad(hr_, [(0,0),(28, 29),(1, 2)], mode='constant')
378 | else:
379 | lr_ = lr[:, y:y+a, x:x+a, starts[j] + z * num_split]
380 | hr_ = hr[:, y:y+a, x:x+a, starts[j] + z * num_split]
381 | if np.max(hr_) > 1:
382 | lr, hr = lr_, hr_
383 | break
384 | axis = np.random.choice([None, 1, 2])
385 | if axis is not None:
386 | lr = np.flip(lr, axis).copy()
387 | hr = np.flip(hr, axis).copy()
388 | lr = torch.tensor(lr, dtype=torch.float32)
389 | hr = torch.tensor(hr, dtype=torch.float32)
390 | queue.put((lr, hr))
391 |     except Exception:
392 | traceback.print_exc()
393 | print("error: %s" % mp.current_process().name)
394 |
395 | def prefetch_sim_volumes_test(imgs_trajs, queue, num_split, is_denoise, denoiser):
396 | #denoiser = denoiser.cuda()
397 | try:
398 | for img, traj in imgs_trajs:
399 | t0 = 9
400 | inputs, gt, combined, start = sim_scan(img, num_split, traj, t0, 1.0 / img.shape[-1], 0, np.eye(3,3), is_denoise, denoiser)
401 | if is_denoise:
402 | gt = np.stack(gt, -1)
403 | inputs = inputs[len(inputs)//2][..., start::num_split]
404 | gt = np.pad(gt, [(28, 29),(1, 2),(0,0)], mode='constant')
405 | inputs = np.pad(inputs, [(28, 29),(1, 2),(0,0)], mode='constant')
406 | else:
407 | inputs = np.stack(inputs, 0)
408 | queue.put((inputs, gt, combined))
409 |     except Exception:
410 | traceback.print_exc()
411 | print("test error: %s" % mp.current_process().name)
412 | return
413 |
414 | def sim_scan(img, num_split, traj, t0, dt, start, rot0, is_denoise=False, model=None):
415 | #model = model.cuda
416 | t0 = t0 - dt * (img.shape[2] - img.shape[2] / 2 / num_split)
417 | idx = start
418 | i = 0
419 | gt = []
420 | all_frames = []
421 | combined = [0] * img.shape[2]
422 | frames = []
423 | sigma = img.max() * _sigma
424 | while True:
425 | if idx >= img.shape[2]:
426 | if model is not None:
427 | frames = np.pad(np.stack(frames, 0), [(0,0), (28, 29),(1, 2)], mode='constant')
428 |                 frames = torch.tensor(frames[:, None])  # frames is already a stacked array
429 | for n_slice in range(0, frames.shape[0], 16):
430 | model['in'].put(frames[n_slice:n_slice+16])
431 | frames[n_slice:n_slice+16] = model['out'].get()
432 | frames = frames.squeeze().permute(1,2,0)[28:-29,1:-2].numpy()
433 | else:
434 | frames = np.stack(frames, -1)
435 | X, Y, Z = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), (np.arange(img.shape[2]) - start) / num_split, indexing='ij')
436 | all_frames.append(map_coordinates(frames, [X, Y, Z], order=3, mode='nearest'))
437 | if len(all_frames) == 2 * num_split - 1:
438 | return all_frames, gt, np.stack(combined, -1), (start + 1) % num_split
439 | start = (start + 1) % num_split
440 | idx = start
441 | frames = []
442 |
443 | Rt = traj(t0 + i * dt)
444 | R = Rotation.from_euler('xyz', Rt[:3]).as_matrix() @ rot0
445 | t = Rt[3:] - [img.shape[0]/2, img.shape[1]/2, img.shape[2]/2] @ R.T + [img.shape[0]/2, img.shape[1]/2, img.shape[2]/2]
446 | frame = affine_transform(img, R, t, order=1)
447 | if (len(all_frames) == num_split - 1):
448 | if is_denoise:
449 | gt.append(frame[..., idx])
450 | elif ((idx - start) // num_split == img.shape[2] // num_split // 2):
451 | gt = frame
452 | frames.append(frame[..., idx])
453 | if sigma > 0:
454 | noise1 = np.random.normal(scale=sigma, size=frames[-1].shape).astype(np.float32)
455 | noise2 = np.random.normal(scale=sigma, size=frames[-1].shape).astype(np.float32)
456 | frames[-1] = np.sqrt((frames[-1] + noise1)**2 + noise2**2)
457 | #if num_split - 1 - num_split//2 <= len(all_frames) < 2*num_split - 1 - num_split//2:
458 | # combined[idx] = frames[-1]
459 | combined[idx] += frames[-1] * (num_split - np.abs(len(all_frames) - num_split + 1)) / num_split
460 | idx += num_split
461 | i += 1
462 |
463 | def denoise_fn(model, queues):
464 | model = model.cuda()
465 | while True:
466 | for q in queues:
467 | try:
468 | inputs = q['in'].get_nowait()
469 | if inputs is None:
470 | return
471 | with torch.no_grad():
472 | q['out'].put(model(inputs.cuda()).cpu())
473 |             except Empty:
474 | pass
475 |
--------------------------------------------------------------------------------
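`down_up` above implements the self-supervised degradation used for training: keep every `num_split`-th slice along the first axis, then re-interpolate back onto the full grid with cubic splines. A standalone sketch of the same operation on a toy volume (shapes are arbitrary):

```python
import numpy as np
from scipy.ndimage import map_coordinates


def down_up_one(img, num_split, ss):
    """Subsample axis 0 by num_split starting at offset ss, then cubically
    re-interpolate to the original grid (cf. data.down_up)."""
    X, Y, Z = np.meshgrid(
        (np.arange(img.shape[0]) - ss) / num_split,
        np.arange(img.shape[1]),
        np.arange(img.shape[2]),
        indexing="ij",
    )
    return map_coordinates(img[ss::num_split], [X, Y, Z], order=3, mode="nearest")


vol = np.random.rand(32, 16, 16).astype(np.float32)
lr = down_up_one(vol, num_split=4, ss=0)
print(lr.shape)                               # (32, 16, 16): same grid, smoothed along axis 0
print(np.allclose(lr[0], vol[0], atol=1e-4))  # kept slices pass through (almost) exactly
```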
/code/main.py:
--------------------------------------------------------------------------------
1 | from config import *
2 | from data import EPIDataset, SimDataset
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | from models import EDSR, NoiseNetwork
7 | import nibabel as nib
8 | import numpy as np
9 | from time import time
10 | from utils import MovingAverage, rician_correct, mkdir, psnr, ssim
11 | import torch.multiprocessing as mp
12 |
13 | if __name__ == "__main__":
14 | # mp.set_start_method('spawn', force=True)
15 |
16 | assert 0 <= use_k < num_split and num_split % 2 == 0
17 | mkdir("../results/" + model_name + "/outputs")
18 |
19 | # model
20 | if is_denoise:
21 | model = NoiseNetwork(in_channels=1, out_channels=1, blindspot=True).cuda()
22 | else:
23 | model = EDSR(cin=2 * use_k + 1, n_resblocks=16, n_feats=64, res_scale=1).cuda()
24 |
25 | #model.load_state_dict(
26 | # torch.load("../results/" + model_name + "/" + model_name + ".pt")
27 | #)
28 |
29 | if denoiser is not None:
30 | denoiser_name = denoiser
31 | denoiser = NoiseNetwork(in_channels=1, out_channels=1, blindspot=True)
32 | denoiser.load_state_dict(torch.load(denoiser_name))
33 |
34 | # dataset
35 | Dataset = SimDataset if use_sim else EPIDataset
36 | train_dataset = Dataset(num_split, "train", is_denoise, denoiser)
37 | train_dataloader = DataLoader(
38 | train_dataset, batch_size, shuffle=False, pin_memory=True
39 | )
40 | dataiter = iter(train_dataloader)
41 | test_dataset = Dataset(num_split, "test", is_denoise, denoiser)
42 | test_dataloader = DataLoader(test_dataset, 1, shuffle=False, pin_memory=True)
43 |
44 | # optimizer
45 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
46 |
47 | average = MovingAverage(0.999)
48 |
49 | t_start = time()
50 | for i in range(1, num_iter + 1):
51 |
52 |         lr, hr = next(dataiter)  # note: shadows the learning rate "lr" from config; the optimizer is already built
53 | lr = lr.cuda()
54 | hr = hr.cuda()
55 |
56 | if not is_denoise:
57 | out = model(lr[:, num_split - 1 - use_k : num_split + use_k])
58 | loss = torch.mean(torch.abs(hr[:, num_split - 1 : num_split] - out))
59 | loss_b = torch.mean(
60 | torch.abs(
61 | hr[:, num_split - 1 : num_split] - lr[:, num_split - 1 : num_split]
62 | )
63 | )
64 | average("loss", loss.item())
65 | average("loss_b", loss_b.item())
66 | else:
67 | out = model(hr)
68 | loss = torch.mean((hr - out) ** 2)
69 | average("loss", loss.item())
70 |
71 | optimizer.zero_grad()
72 | loss.backward()
73 | optimizer.step()
74 |
75 | if i % 100 == 0:
76 | print("i = %d, %s, time = %d" % (i, average, time() - t_start))
77 |
78 | if i % 1000 == 0 or i == num_iter:
79 | torch.save(
80 | model.state_dict(),
81 | "../results/" + model_name + "/" + model_name + ".pt",
82 | )
83 |
84 | average_test = MovingAverage(0)
85 |
86 | for j, data in enumerate(test_dataloader):
87 |
88 | if is_denoise:
89 | img = data[0].cuda().permute(3, 0, 1, 2)
90 | if use_sim:
91 | gt = data[1].cuda().permute(3, 0, 1, 2)
92 | else:
93 | img = data[0][0].cuda()
94 | img_mid = img[num_split - 1]
95 | img = img.permute(2, 0, 3, 1)
96 | combined = data[-1][0].cuda()
97 | gt = data[1][0].cuda()
98 |
99 | with torch.no_grad():
100 | if is_denoise:
101 | out = model(img)
102 | else:
103 | out = (
104 | model(img[:, num_split - 1 - use_k : num_split + use_k])
105 | .squeeze()
106 | .permute(2, 0, 1)
107 | )
108 |
109 | if is_denoise:
110 | if use_sim:
111 | average_test(
112 | "mse", ((out - gt) ** 2)[gt > 0.01].mean().item()
113 | )
114 | np.save(
115 | "../results/" + model_name + "/outputs/gt_%d" % j,
116 | gt.cpu().numpy(),
117 | )
118 | else:
119 | np.save(
120 | "../results/" + model_name + "/outputs/in_%d" % j,
121 | img.cpu().numpy(),
122 | )
123 | np.save(
124 | "../results/" + model_name + "/outputs/out_%d" % j,
125 | out.cpu().numpy(),
126 | )
127 | else:
128 | if use_sim:
129 | out = rician_correct(
130 | out,
131 | None if sigma and (denoiser is not None) else 0,
132 | gt < 0.01,
133 | )
134 |
135 | average_test(
136 | "mse_cubic",
137 | torch.sqrt(
138 | ((gt - img_mid) ** 2).mean() / (gt**2).mean()
139 | ).item(),
140 | )
141 | average_test(
142 | "mse_out",
143 | torch.sqrt(
144 | ((gt - out) ** 2).mean() / (gt**2).mean()
145 | ).item(),
146 | )
147 | average_test(
148 | "mse_combined",
149 | torch.sqrt(
150 | ((gt - combined) ** 2).mean() / (gt**2).mean()
151 | ).item(),
152 | )
153 | average_test(
154 | "mm_cubic",
155 | ((gt - img_mid) ** 2)[gt > 0.01].mean().item(),
156 | )
157 | average_test(
158 | "mm_out", ((gt - out) ** 2)[gt > 0.01].mean().item()
159 | )
160 | average_test(
161 | "mm_combined",
162 | ((gt - combined) ** 2)[gt > 0.01].mean().item(),
163 | )
164 |
165 | nib.save(
166 | nib.Nifti1Image(out.cpu().numpy() * 1000, np.eye(4)),
167 | "../results/"
168 | + model_name
169 | + "/outputs/out_%d.nii.gz" % j,
170 | )
171 | nib.save(
172 | nib.Nifti1Image(gt.cpu().numpy() * 1000, np.eye(4)),
173 | "../results/"
174 | + model_name
175 | + "/outputs/gt_%d.nii.gz" % j,
176 | )
177 | nib.save(
178 | nib.Nifti1Image(
179 | img_mid.cpu().numpy() * 1000, np.eye(4)
180 | ),
181 | "../results/"
182 | + model_name
183 | + "/outputs/in_%d.nii.gz" % j,
184 | )
185 | nib.save(
186 | nib.Nifti1Image(
187 | combined.cpu().numpy() * 1000, np.eye(4)
188 | ),
189 | "../results/"
190 | + model_name
191 | + "/outputs/combined_%d.nii.gz" % j,
192 | )
193 |
194 | else:
195 | out = out * 100 + 70
196 | out[out < 0] = 0
197 | img_mid = img_mid * 100 + 70
198 | combined = combined * 100 + 70
199 | gt = gt * 100 + 70
200 |
201 | out = out.cpu().numpy()
202 | img_mid = img_mid.cpu().numpy()
203 | combined = combined.cpu().numpy()
204 | gt = gt.cpu().numpy()
205 | sti = (img_mid + combined) / 2
206 |
207 | if num_split == 4:
208 | mask = gt > 0
209 | average_test("psnr_si", psnr(img_mid, gt, mask))
210 | average_test("psnr_ti", psnr(combined, gt, mask))
211 | average_test("psnr_sti", psnr(sti, gt, mask))
212 | average_test("psnr_out", psnr(out, gt, mask))
213 |
214 | nib.save(
215 | nib.Nifti1Image(gt, np.eye(4)),
216 | "../results/"
217 | + model_name
218 | + "/outputs/gt_%d.nii.gz" % j,
219 | )
220 |
221 | nib.save(
222 | nib.Nifti1Image(out, np.eye(4)),
223 | "../results/"
224 | + model_name
225 | + "/outputs/out_%d.nii.gz" % j,
226 | )
227 | nib.save(
228 | nib.Nifti1Image(combined, np.eye(4)),
229 | "../results/"
230 | + model_name
231 | + "/outputs/combined_%d.nii.gz" % j,
232 | )
233 | nib.save(
234 | nib.Nifti1Image(img_mid, np.eye(4)),
235 | "../results/"
236 | + model_name
237 | + "/outputs/in_%d.nii.gz" % j,
238 | )
239 |
240 | print("%d, %s" % (i // 1000, average_test))
241 |
--------------------------------------------------------------------------------
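After training, the checkpoint saved above can be reloaded for standalone inference. A minimal sketch, assuming the same config values and an existing checkpoint; the input tensor here is a placeholder for a real stack of neighbouring degraded frames:

```python
import torch

from config import model_name, use_k
from models import EDSR

model = EDSR(cin=2 * use_k + 1, n_resblocks=16, n_feats=64, res_scale=1)
state = torch.load("../results/%s/%s.pt" % (model_name, model_name), map_location="cpu")
model.load_state_dict(state)
model.eval()

# One batch of 2*use_k + 1 neighbouring frames (placeholder data).
x = torch.randn(1, 2 * use_k + 1, 64, 64)
with torch.no_grad():
    sr = model(x)  # (1, 1, 64, 64): super-resolved centre frame
print(sr.shape)
```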
/code/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 | from typing import Tuple
7 |
8 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
9 | return nn.Conv2d(
10 | in_channels, out_channels, kernel_size,
11 | padding=(kernel_size//2), bias=bias)
12 |
13 |
14 | class ResBlock(nn.Module):
15 | def __init__(
16 | self, conv, n_feats, kernel_size,
17 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
18 |
19 | super(ResBlock, self).__init__()
20 | m = []
21 | for i in range(2):
22 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
23 | if bn:
24 | m.append(nn.BatchNorm2d(n_feats))
25 | if i == 0:
26 | m.append(act)
27 |
28 | self.body = nn.Sequential(*m)
29 | self.res_scale = res_scale
30 |
31 | def forward(self, x):
32 |
33 | res = x #F.dropout(x, 0.3)
34 | res = self.body(res).mul(self.res_scale)
35 | res += x
36 |
37 | return res
38 |
39 |
40 | class EDSR(nn.Module):
41 | def __init__(self, cin=1, n_resblocks=16, n_feats=64, res_scale=1):
42 | super(EDSR, self).__init__()
43 |
44 | conv = default_conv
45 | kernel_size = 3
46 | act = nn.LeakyReLU(0.1, True) #nn.ReLU(True)
47 |
48 | # define head module
49 | m_head = [conv(cin, n_feats, kernel_size)]
50 |
51 | # define body module
52 | m_body = [
53 | ResBlock(
54 | conv, n_feats, kernel_size, act=act, res_scale=res_scale
55 | ) for _ in range(n_resblocks)
56 | ]
57 | m_body.append(conv(n_feats, n_feats, kernel_size))
58 |
59 | # define tail module
60 | m_tail = [
61 | conv(n_feats, 1, kernel_size)
62 | ]
63 |
64 | self.head = nn.Sequential(*m_head)
65 | self.body = nn.Sequential(*m_body)
66 | self.tail = nn.Sequential(*m_tail)
67 |
68 | def forward(self, x):
69 | x = self.head(x)
70 |
71 | res = self.body(x)
72 | res += x
73 |
74 | x = self.tail(res)
75 |
76 | return x
77 |
78 |
79 | class Crop2d(nn.Module):
80 | """Crop input using slicing. Assumes BCHW data.
81 | Args:
82 | crop (Tuple[int, int, int, int]): Amounts to crop from each side of the image.
83 |         Tuple is treated as [left, right, top, bottom].
84 | """
85 |
86 | def __init__(self, crop: Tuple[int, int, int, int]):
87 | super().__init__()
88 | self.crop = crop
89 | assert len(crop) == 4
90 |
91 | def forward(self, x: Tensor) -> Tensor:
92 | (left, right, top, bottom) = self.crop
93 | x0, x1 = left, x.shape[-1] - right
94 | y0, y1 = top, x.shape[-2] - bottom
95 | return x[:, :, y0:y1, x0:x1]
96 |
97 |
98 | class Shift2d(nn.Module):
99 | """Shift an image in either or both of the vertical and horizontal axis by first
100 | zero padding on the opposite side that the image is shifting towards before
101 | cropping the side being shifted towards.
102 | Args:
103 | shift (Tuple[int, int]): Tuple of vertical and horizontal shift. Positive values
104 | shift towards right and bottom, negative values shift towards left and top.
105 | """
106 |
107 | def __init__(self, shift: Tuple[int, int]):
108 | super().__init__()
109 | self.shift = shift
110 | vert, horz = self.shift
111 | y_a, y_b = abs(vert), 0
112 | x_a, x_b = abs(horz), 0
113 | if vert < 0:
114 | y_a, y_b = y_b, y_a
115 | if horz < 0:
116 | x_a, x_b = x_b, x_a
117 |         # Order: left, right, top, bottom
118 | self.pad = nn.ZeroPad2d((x_a, x_b, y_a, y_b))
119 | self.crop = Crop2d((x_b, x_a, y_b, y_a))
120 | self.shift_block = nn.Sequential(self.pad, self.crop)
121 |
122 | def forward(self, x: Tensor) -> Tensor:
123 | return self.shift_block(x)
124 |
125 |
126 | def rotate(x: torch.Tensor, angle: int) -> torch.Tensor:
127 | """Rotate images by 90 degrees clockwise. Can handle any 2D data format.
128 | Args:
129 | x (Tensor): Image or batch of images.
130 | angle (int): Clockwise rotation angle in multiples of 90.
131 | data_format (str, optional): Format of input image data, e.g. BCHW,
132 | HWC. Defaults to BCHW.
133 | Returns:
134 | Tensor: Copy of tensor with rotation applied.
135 | """
136 | h_dim = 2
137 | w_dim = 3
138 |
139 | if angle == 0:
140 | return x
141 | elif angle == 90:
142 | return x.flip(w_dim).transpose(h_dim, w_dim)
143 | elif angle == 180:
144 | return x.flip(w_dim).flip(h_dim)
145 | elif angle == 270:
146 | return x.flip(h_dim).transpose(h_dim, w_dim)
147 | else:
148 | raise NotImplementedError("Must be rotation divisible by 90 degrees")
149 |
150 |
151 | class NoiseNetwork(nn.Module):
152 | """Custom U-Net architecture for Self Supervised Denoising (SSDN) and Noise2Noise (N2N).
153 | Base N2N implementation was made with reference to @joeylitalien's N2N implementation.
154 |     Changes made are removal of weight sharing when blocks are reused, usage of
155 |     LeakyReLU over standard ReLU, and incorporation of blindspot functionality.
156 |     Unlike other typical U-Net implementations, dropout is not used when the model is trained.
157 | When in blindspot mode the following behaviour changes occur:
158 | * Input batches are duplicated for rotations: 0, 90, 180, 270. This increases the
159 | batch size by 4x. After the encode-decode stage the rotations are undone and
160 | concatenated on the channel axis with the associated original image. This 4x
161 | increase in channel count is collapsed to the standard channel count in the
162 | first 1x1 kernel convolution.
163 | * To restrict the receptive field into the upward direction a shift is used for
164 | convolutions (see ShiftConv2d) and downsampling. Downsampling uses a single
165 | pixel shift prior to max pooling as dictated by Laine et al. This is equivalent
166 | to applying a shift on the upsample.
167 | Args:
168 | in_channels (int, optional): Number of input channels, this will typically be either
169 | 1 (Mono) or 3 (RGB) but can be more. Defaults to 3.
170 | out_channels (int, optional): Number of channels the final convolution should output.
171 | Defaults to 3.
172 | blindspot (bool, optional): Whether to enable the network blindspot. This will
173 | add in rotation stages and shift stages while max pooling and during convolutions.
174 |             A further shift will occur after upsample. Defaults to False.
175 | zero_output_weights (bool, optional): Whether to initialise the weights of
176 | `nin_c` to zero. This is not mentioned in literature but is done as part
177 | of the tensorflow implementation for the parameter estimation network.
178 | Defaults to False.
179 | """
180 |
181 | def __init__(
182 | self,
183 | in_channels: int = 3,
184 | out_channels: int = 3,
185 | blindspot: bool = False,
186 | zero_output_weights: bool = False,
187 | ):
188 | super(NoiseNetwork, self).__init__()
189 | self._blindspot = blindspot
190 | self._zero_output_weights = zero_output_weights
191 | self.Conv2d = ShiftConv2d if self.blindspot else nn.Conv2d
192 |
193 | ####################################
194 | # Encode Blocks
195 | ####################################
196 |
197 | def _max_pool_block(max_pool: nn.Module) -> nn.Module:
198 | if blindspot:
199 | return nn.Sequential(Shift2d((1, 0)), max_pool)
200 | return max_pool
201 |
202 | # Layers: enc_conv0, enc_conv1, pool1
203 | self.encode_block_1 = nn.Sequential(
204 | self.Conv2d(in_channels, 48, 3, stride=1, padding=1),
205 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
206 | self.Conv2d(48, 48, 3, padding=1),
207 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
208 | _max_pool_block(nn.MaxPool2d(2)),
209 | )
210 |
211 | # Layers: enc_conv(i), pool(i); i=2..5
212 | def _encode_block_2_3_4_5() -> nn.Module:
213 | return nn.Sequential(
214 | #nn.Dropout(p_drop), ####
215 | self.Conv2d(48, 48, 3, stride=1, padding=1),
216 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
217 | _max_pool_block(nn.MaxPool2d(2)),
218 | )
219 |
220 | # Separate instances of same encode module definition created
221 | self.encode_block_2 = _encode_block_2_3_4_5()
222 | self.encode_block_3 = _encode_block_2_3_4_5()
223 | self.encode_block_4 = _encode_block_2_3_4_5()
224 | self.encode_block_5 = _encode_block_2_3_4_5()
225 |
226 | # Layers: enc_conv6
227 | self.encode_block_6 = nn.Sequential(
228 | #nn.Dropout(p_drop), ####
229 | self.Conv2d(48, 48, 3, stride=1, padding=1),
230 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
231 | )
232 |
233 | ####################################
234 | # Decode Blocks
235 | ####################################
236 | # Layers: upsample5
237 | self.decode_block_6 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"))
238 |
239 | # Layers: dec_conv5a, dec_conv5b, upsample4
240 | self.decode_block_5 = nn.Sequential(
241 | #nn.Dropout(p_drop), ####
242 | self.Conv2d(96, 96, 3, stride=1, padding=1),
243 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
244 | #nn.Dropout(p_drop), ####
245 | self.Conv2d(96, 96, 3, stride=1, padding=1),
246 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
247 | nn.Upsample(scale_factor=2, mode="nearest"),
248 | )
249 |
250 | # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2
251 | def _decode_block_4_3_2() -> nn.Module:
252 | return nn.Sequential(
253 | #nn.Dropout(p_drop), ####
254 | self.Conv2d(144, 96, 3, stride=1, padding=1),
255 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
256 | #nn.Dropout(p_drop), ####
257 | self.Conv2d(96, 96, 3, stride=1, padding=1),
258 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
259 | nn.Upsample(scale_factor=2, mode="nearest"),
260 | )
261 |
262 | # Separate instances of same decode module definition created
263 | self.decode_block_4 = _decode_block_4_3_2()
264 | self.decode_block_3 = _decode_block_4_3_2()
265 | self.decode_block_2 = _decode_block_4_3_2()
266 |
267 | # Layers: dec_conv1a, dec_conv1b, dec_conv1c,
268 | self.decode_block_1 = nn.Sequential(
269 | #nn.Dropout(p_drop), ####
270 | self.Conv2d(96 + in_channels, 96, 3, stride=1, padding=1),
271 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
272 | #nn.Dropout(p_drop), ####
273 | self.Conv2d(96, 96, 3, stride=1, padding=1),
274 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
275 | )
276 |
277 | ####################################
278 | # Output Block
279 | ####################################
280 |
281 | if self.blindspot:
282 | # Shift 1 pixel down
283 | self.shift = Shift2d((1, 0))
284 | # 4 x Channels due to batch rotations
285 | nin_a_io = 384
286 | else:
287 | nin_a_io = 96
288 |
289 | # nin_a,b,c, linear_act
290 | self.output_conv = self.Conv2d(96, out_channels, 1)
291 | self.output_block = nn.Sequential(
292 | #nn.Dropout(p_drop), ####
293 | self.Conv2d(nin_a_io, nin_a_io, 1),
294 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
295 | #nn.Dropout(p_drop), ####
296 | self.Conv2d(nin_a_io, 96, 1),
297 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
298 | self.output_conv,
299 | )
300 |
301 | # Initialize weights
302 | #self.init_weights()
303 |
304 | @property
305 | def blindspot(self) -> bool:
306 | return self._blindspot
307 |
308 | def init_weights(self):
309 | """Initializes weights using Kaiming He et al. (2015).
310 | Only convolution layers have learnable weights. All convolutions use a leaky
311 | relu activation function (negative_slope = 0.1) except the last which is just
312 | a linear output.
313 | """
314 | with torch.no_grad():
315 | self._init_weights()
316 |
317 | def _init_weights(self):
318 | for m in self.modules():
319 | if isinstance(m, nn.Conv2d):
320 | nn.init.kaiming_normal_(m.weight.data, a=0.1)
321 | m.bias.data.zero_()
322 | # Initialise last output layer
323 | if self._zero_output_weights:
324 | self.output_conv.weight.zero_()
325 | else:
326 | nn.init.kaiming_normal_(self.output_conv.weight.data, nonlinearity="linear")
327 |
328 | def forward(self, x: Tensor) -> Tensor:
329 | if self.blindspot:
330 | rotated = [rotate(x, rot) for rot in (0, 90, 180, 270)]
331 |             x = torch.cat(rotated, dim=0)
332 |
333 | # Encoder
334 | pool1 = self.encode_block_1(x)
335 | pool2 = self.encode_block_2(pool1)
336 | pool3 = self.encode_block_3(pool2)
337 | pool4 = self.encode_block_4(pool3)
338 | pool5 = self.encode_block_5(pool4)
339 | encoded = self.encode_block_6(pool5)
340 |
341 | # Decoder
342 | upsample5 = self.decode_block_6(encoded)
343 | concat5 = torch.cat((upsample5, pool4), dim=1)
344 | upsample4 = self.decode_block_5(concat5)
345 | concat4 = torch.cat((upsample4, pool3), dim=1)
346 | upsample3 = self.decode_block_4(concat4)
347 | concat3 = torch.cat((upsample3, pool2), dim=1)
348 | upsample2 = self.decode_block_3(concat3)
349 | concat2 = torch.cat((upsample2, pool1), dim=1)
350 | upsample1 = self.decode_block_2(concat2)
351 | concat1 = torch.cat((upsample1, x), dim=1)
352 | x = self.decode_block_1(concat1)
353 |
354 | # Output
355 | if self.blindspot:
356 | # Apply shift
357 | shifted = self.shift(x)
358 | # Unstack, rotate and combine
359 | rotated_batch = torch.chunk(shifted, 4, dim=0)
360 | aligned = [
361 | rotate(rotated, rot)
362 | for rotated, rot in zip(rotated_batch, (0, 270, 180, 90))
363 | ]
364 | x = torch.cat(aligned, dim=1)
365 |
366 | x = self.output_block(x)
367 |
368 | return x
369 |
370 | @staticmethod
371 | def input_wh_mul() -> int:
372 |         """Multiple that both the width and height of an input must be divisible by
373 |         to be processed by the network. This is derived from the number of pooling
374 |         layers that halve the input size.
375 | Returns:
376 | int: Dimension multiplier
377 | """
378 | max_pool_layers = 5
379 | return 2 ** max_pool_layers
380 |
381 |
382 | class ShiftConv2d(nn.Conv2d):
383 | def __init__(self, *args, **kwargs):
384 | super().__init__(*args, **kwargs)
385 | self.shift_size = (self.kernel_size[0] // 2, 0)
386 | # Use individual layers of shift for wrapping conv with shift
387 | shift = Shift2d(self.shift_size)
388 | self.pad = shift.pad
389 | self.crop = shift.crop
390 |
391 | def forward(self, x: Tensor) -> Tensor:
392 | x = self.pad(x)
393 | x = super().forward(x)
394 | x = self.crop(x)
395 | return x
--------------------------------------------------------------------------------
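A quick shape sanity check for the two networks above. `EDSR` preserves spatial size and outputs one channel; `NoiseNetwork` with `blindspot=True` rotates the batch internally, so inputs should be multiples of `input_wh_mul()` (= 32) per side. The sizes here are arbitrary examples:

```python
import torch

from models import EDSR, NoiseNetwork

sr = EDSR(cin=5)
x = torch.randn(2, 5, 48, 48)
print(sr(x).shape)  # torch.Size([2, 1, 48, 48])

dn = NoiseNetwork(in_channels=1, out_channels=1, blindspot=True)
side = 2 * NoiseNetwork.input_wh_mul()  # 64
y = torch.randn(2, 1, side, side)
print(dn(y).shape)  # torch.Size([2, 1, 64, 64])
```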
/code/trajectory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import os
4 | from scipy.spatial.transform import Rotation
5 | from scipy.ndimage import gaussian_filter1d
6 | from scipy.interpolate import interp1d
7 |
8 |
9 | def get_trajectory(folder='../trajectory'):
10 |     # Build smoothed rigid head-motion trajectories from fetal pose keypoints (.mat files).
11 | traj = []
12 |
13 | for f in os.listdir(folder):
14 | joint_coord = sio.loadmat(os.path.join(folder, f))['joint_coord'].astype(np.float32)
15 |
16 | joint_coord = joint_coord[np.all(joint_coord > 0, (1, 2))]
17 |
18 | eye_l = joint_coord[..., 7]
19 | eye_r = joint_coord[..., 8]
20 | neck = (joint_coord[..., 11] + joint_coord[..., 12]) / 2
21 |
22 | origin = (eye_l + eye_r + neck) / 3
23 |
24 | x_vec = eye_l - eye_r
25 | x_vec = x_vec / np.linalg.norm(x_vec, ord=2, axis=-1, keepdims=True)
26 |
27 | neck_eye_l = neck - eye_l
28 | y_vec = np.cross(x_vec, neck_eye_l)
29 | y_vec = y_vec / np.linalg.norm(y_vec, ord=2, axis=-1, keepdims=True)
30 |
31 | z_vec = np.cross(x_vec, y_vec)
32 | z_vec = z_vec / np.linalg.norm(z_vec, ord=2, axis=-1, keepdims=True)
33 |
34 | R = np.stack([x_vec, y_vec, z_vec], -1)
35 | R = R @ R[0].T[None]
36 | R = Rotation.from_matrix(R).as_euler('xyz')
37 | t = origin - origin[[0]]
38 | Rt = np.concatenate([R, t], -1)
39 | Rt = Rt[::2]
40 | Rt = gaussian_filter1d(Rt, 0.5, 0)
41 |
42 | interp_func = interp1d(np.arange(Rt.shape[0]), Rt, kind='cubic', axis=0, fill_value="extrapolate", assume_sorted=True)
43 |
44 | traj.append((interp_func, Rt.shape[0]-1))
45 |
46 | return traj
47 |
48 | if __name__ == '__main__':
49 | pass
50 |
--------------------------------------------------------------------------------
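Each entry returned by `get_trajectory` is a pair `(interp_func, T)`: a cubic interpolant over the smoothed pose track and its duration in frames. Evaluating it at a time `t` yields a 6-vector of XYZ Euler angles (radians) plus a translation, which `data.sim_scan` converts into a rigid transform. A small usage sketch (assumes the `.mat` files are present in `../trajectory`):

```python
import numpy as np
from scipy.spatial.transform import Rotation

from trajectory import get_trajectory

trajs = get_trajectory()     # list of (interp_func, T) pairs
traj, T = trajs[0]

t = np.random.uniform(0, T)  # any time within the trajectory
Rt = traj(t)                 # shape (6,): [rx, ry, rz, tx, ty, tz]
R = Rotation.from_euler("xyz", Rt[:3]).as_matrix()
print(R.shape, Rt[3:])       # (3, 3) rotation matrix, 3-vector translation
```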
/code/utils.py:
--------------------------------------------------------------------------------
1 | from scipy.special import i0, i1
2 | from scipy.stats import trim_mean
3 | import torch
4 | import numpy as np
5 | import os
6 | from skimage.metrics import structural_similarity
7 |
8 | class MovingAverage:
9 | def __init__(self, alpha):
10 | assert 0 <= alpha < 1
11 | self.alpha = alpha
12 | self.value = dict()
13 |
14 | def __call__(self, key, value):
15 | if key not in self.value:
16 | self.value[key] = (0, 0)
17 | num, v = self.value[key]
18 | num += 1
19 | if self.alpha:
20 | v = v * self.alpha + value * (1 - self.alpha)
21 | else:
22 | v += value
23 | self.value[key] = (num, v)
24 |
25 | def __str__(self):
26 | s = ''
27 | for key in self.value:
28 | num, v = self.value[key]
29 | if self.alpha:
30 | s += "%s = %f\t" % (key, v / (1 - self.alpha**num))
31 | else:
32 | s += "%s = %f\t" % (key, v / num)
33 | return s
34 |
35 | def rician_correct(out, sigma, background):
36 | if sigma == 0:
37 | out[out < 0] = 0
38 | return out
39 | elif sigma is None:
40 | sigma_pi = trim_mean(out[background].cpu().numpy(), 0.1, None)
41 | sigma = sigma_pi * np.sqrt(2/np.pi)
42 | else:
43 | sigma_pi = sigma * np.sqrt(np.pi/2)
44 |
45 | old_out = out
46 | out = out / sigma_pi
47 | out[out < 1] = 1
48 |     cur_val = 0  # Horner evaluation of the bias-correction polynomial
49 |     for coeff in [-0.02459419, 0.28790799, 0.27697441, 2.68069732]:
50 |         cur_val = (cur_val + coeff) * out
51 |     out = (cur_val - 3.22092921) * (sigma**2)
52 | snr_mask = old_out/sigma > 3.5
53 | out[snr_mask] = old_out[snr_mask]**2 - sigma**2
54 | out = torch.sqrt(out)
55 | return out
56 |
57 | def fba(imgs, p):
58 | freqs = [np.fft.rfftn(img) for img in imgs]
59 | weights = [np.abs(freq) ** p for freq in freqs]
60 | return np.fft.irfftn(sum(freq * weight for freq, weight in zip(freqs, weights)) / sum(weights)).astype(np.float32)
61 |
62 | def mkdir(path):
63 | if not os.path.exists(path):
64 | os.makedirs(path)
65 |
66 | def psnr(x, y, mask=None):
67 | if mask is None:
68 | mse = np.mean((x - y) ** 2)
69 | else:
70 | mse = np.sum(((x - y) ** 2) * mask) / mask.sum()
71 | return 10 * np.log10(y.max()**2 / mse)
72 |
73 | def ssim(x, y, mask=None):
74 | mssim, S = structural_similarity(x, y, full=True)
75 | if mask is not None:
76 | return (S * mask).sum() / mask.sum()
77 | else:
78 | return mssim
79 |
80 | def ssim_slice(x, y, mask):
81 | mask = mask.sum((0,1)) > 0
82 | #print(np.nonzero(mask))
83 | x = x[..., mask]
84 | y = y[..., mask]
85 |
86 | return structural_similarity(x, y)
87 | #ssims = []
88 | #for i in range(x.shape[-1]):
89 | # ssims.append(structural_similarity(x[..., i], y[..., i]))
90 | #return np.mean(ssims)
91 |
92 |
--------------------------------------------------------------------------------
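`rician_correct` above compensates the noise-floor bias of magnitude images; `data.sim_scan` injects exactly this kind of noise via `np.sqrt((frame + noise1)**2 + noise2**2)`. A standalone sketch of the bias and a simple second-moment correction (not the exact polynomial branch used above; values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, x = 0.05, 1.0  # noise level and true signal intensity

# Magnitude of a complex signal with i.i.d. Gaussian noise on both channels
# is Rician -- the same corruption model used in data.sim_scan.
n1 = rng.normal(scale=sigma, size=100_000)
n2 = rng.normal(scale=sigma, size=100_000)
m = np.sqrt((x + n1) ** 2 + n2 ** 2)

print(m.mean() - x)  # > 0: the Rician noise floor biases magnitudes upward
print(np.sqrt((m**2).mean() - 2 * sigma**2))  # ~x, since E[m^2] = x^2 + 2*sigma^2
```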
/img/stress.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/img/stress.gif
--------------------------------------------------------------------------------
/trajectory/1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/1.mat
--------------------------------------------------------------------------------
/trajectory/10.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/10.mat
--------------------------------------------------------------------------------
/trajectory/2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/2.mat
--------------------------------------------------------------------------------
/trajectory/3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/3.mat
--------------------------------------------------------------------------------
/trajectory/4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/4.mat
--------------------------------------------------------------------------------
/trajectory/5.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/5.mat
--------------------------------------------------------------------------------
/trajectory/6.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/6.mat
--------------------------------------------------------------------------------
/trajectory/7.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/7.mat
--------------------------------------------------------------------------------
/trajectory/8.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/8.mat
--------------------------------------------------------------------------------
/trajectory/9.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddmc/STRESS/5074271b2a8f9e252257c9ee5cd565a224bb2c2c/trajectory/9.mat
--------------------------------------------------------------------------------