├── haca3
│   ├── modules
│   │   ├── __init__.py
│   │   ├── _static_version.py
│   │   ├── dataset.py
│   │   ├── utils.py
│   │   ├── _version.py
│   │   ├── fusion_model.py
│   │   ├── network.py
│   │   └── model.py
│   ├── __init__.py
│   ├── _static_version.py
│   ├── train_fusion.py
│   ├── train.py
│   ├── encode.py
│   ├── _version.py
│   └── test.py
├── figures
│   ├── GA.png
│   ├── multi_site.png
│   └── longitudinal.png
├── Dockerfile
├── requirements.txt
├── .gitlab-ci.yml
├── setup.py
├── .gitignore
└── README.md
/haca3/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figures/GA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lianruizuo/haca3/HEAD/figures/GA.png
--------------------------------------------------------------------------------
/figures/multi_site.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lianruizuo/haca3/HEAD/figures/multi_site.png
--------------------------------------------------------------------------------
/figures/longitudinal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lianruizuo/haca3/HEAD/figures/longitudinal.png
--------------------------------------------------------------------------------
/haca3/__init__.py:
--------------------------------------------------------------------------------
1 | from ._version import __version__
2 | from .modules.model import HACA3
3 | from .modules.utils import *
4 | from .test import background_removal2d
5 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
2 |
3 | RUN apt-get update && apt-get install -y --no-install-recommends git
4 |
5 | COPY . /tmp/haca3
6 |
7 | RUN pip install /tmp/haca3 && rm -rf /tmp/haca3
8 |
9 | ENTRYPOINT ["haca3-test"]
10 |
--------------------------------------------------------------------------------
/haca3/_static_version.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # This file is part of 'miniver': https://github.com/jbweston/miniver
3 | #
4 | # This file will be overwritten by setup.py when a source or binary
5 | # distribution is made. The magic value "__use_git__" is interpreted by
6 | # version.py.
7 |
8 | version = "__use_git__"
9 |
10 | # These values are only set if the distribution was created with 'git archive'
11 | refnames = "$Format:%D$"
12 | git_hash = "$Format:%h$"
13 |
--------------------------------------------------------------------------------
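The `__use_git__` value above makes `_version.py` compute the version from `git describe --long --always`. The sketch below reproduces only the string handling of `get_version_from_git()` and `pep440_format()` on a sample describe string. It does not run git. `describe_to_pep440` and its `dirty` flag are hypothetical names for illustration, and the sketch skips the `-dev` suffix special case in `pep440_format()`.

```python
def describe_to_pep440(description: str, dirty: bool = False) -> str:
    """Turn `git describe --long --always` output into a PEP 440 version.

    Hypothetical helper mirroring the string handling in get_version_from_git()
    and pep440_format(); it never calls git itself.
    """
    # Drop a leading 'v' from the tag and split off "<commits>-g<hash>".
    parts = description.strip("v").rstrip("\n").rsplit("-", 2)
    try:
        release, dev, git = parts
    except ValueError:  # no tags: describe returned only the abbreviated hash
        release, dev, git = "unknown", None, "g{}".format(*parts)

    labels = []
    if dev == "0":
        dev = None  # HEAD sits exactly on the tag
    else:
        labels.append(git)
    if dirty:
        labels.append("dirty")  # _version.py adds this when `git diff --quiet` exits with 1

    version = release
    if dev:
        version += ".dev{}".format(dev)
    if labels:
        version += "+" + ".".join(labels)
    return version


print(describe_to_pep440("v1.2.3-4-gabc1234\n"))  # 1.2.3.dev4+gabc1234
print(describe_to_pep440("v1.2.3-0-gabc1234"))    # 1.2.3
print(describe_to_pep440("abc1234"))              # unknown+gabc1234
```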
/haca3/modules/_static_version.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # This file is part of 'miniver': https://github.com/jbweston/miniver
3 | #
4 | # This file will be overwritten by setup.py when a source or binary
5 | # distribution is made. The magic value "__use_git__" is interpreted by
6 | # version.py.
7 |
8 | version = "__use_git__"
9 |
10 | # These values are only set if the distribution was created with 'git archive'
11 | refnames = "$Format:%D$"
12 | git_hash = "$Format:%h$"
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu118
2 | certifi==2022.12.7
3 | charset-normalizer==2.1.1
4 | click==8.1.7
5 | cmake==3.25.0
6 | colorama==0.4.6
7 | Deprecated==1.2.14
8 | filelock==3.9.0
9 | humanize==4.8.0
10 | idna==3.4
11 | importlib-resources==6.0.1
12 | Jinja2==3.1.2
13 | lit==15.0.7
14 | markdown-it-py==3.0.0
15 | MarkupSafe==2.1.2
16 | mdurl==0.1.2
17 | mpmath==1.2.1
18 | networkx==3.0
19 | nibabel==5.1.0
20 | numpy==1.24.4
21 | packaging==23.1
22 | Pillow==9.3.0
23 | Pygments==2.16.1
24 | requests==2.28.1
25 | rich==13.5.2
26 | scipy==1.10.1
27 | shellingham==1.5.3
28 | SimpleITK==2.2.1
29 | sympy==1.11.1
30 | torch==2.0.1+cu118
31 | torchaudio==2.0.2+cu118
32 | torchio==0.19.1
33 | torchvision==0.15.2+cu118
34 | tqdm==4.66.1
35 | triton==2.0.0
36 | typer==0.9.0
37 | typing_extensions==4.4.0
38 | urllib3==1.26.13
39 | wrapt==1.15.0
40 | zipp==3.16.2
41 |
--------------------------------------------------------------------------------
/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | .build_image:
2 |   image: docker:latest
3 |   services:
4 |     - name: docker:dind
5 |       command: [ "--experimental" ]
6 |   variables:
7 |     DOCKER_BUILDKIT: 1
8 |   before_script:
9 |     - docker login -u gitlab-ci-token -p $CI_JOB_TOKEN $CI_REGISTRY
10 |     - apk add curl bash
11 |     - mkdir -vp ~/.docker/cli-plugins/
12 |     - curl --silent -L "https://gitlab.com/neurobuilds/setup-dind/-/raw/main/setup_dind.sh" > ~/setup_dind.sh
13 |     - bash ~/setup_dind.sh
14 |     - docker context create tls-environment
15 |     - docker buildx create --use tls-environment
16 |     - docker buildx inspect --bootstrap
17 |   script:
18 |     - docker buildx build --build-arg CI_JOB_TOKEN=$CI_JOB_TOKEN --push -t $TAG .
19 |
20 | build_tag:
21 |   extends: .build_image
22 |   variables:
23 |     TAG: $CI_REGISTRY_IMAGE:$CI_COMMIT_TAG
24 |   only:
25 |     - tags
26 |
27 | build_branch:
28 |   extends: .build_image
29 |   variables:
30 |     TAG: $CI_REGISTRY_IMAGE:$CI_COMMIT_BRANCH
31 |   only:
32 |     - main
33 |     - support/*
--------------------------------------------------------------------------------
/haca3/train_fusion.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from modules.fusion_model import FusionNet
4 |
5 | def main(args=None):
6 |     args = sys.argv[1:] if args is None else args
7 |     parser = argparse.ArgumentParser(description='Unsupervised harmonization via disentanglement.')
8 |     parser.add_argument('--dataset-dirs', nargs='+', type=str, required=True)
9 |     parser.add_argument('--out-dir', type=str, default='../')
10 |     parser.add_argument('--pretrained-model', type=str, default=None)
11 |     parser.add_argument('--lr', type=float, default=5e-4)
12 |     parser.add_argument('--batch-size', type=int, default=8)
13 |     parser.add_argument('--epochs', type=int, default=4)
14 |     parser.add_argument('--gpu', type=int, default=0)
15 |     args = parser.parse_args(args)
16 |
17 |     # initialize model
18 |     trainer = FusionNet(pretrained_model=args.pretrained_model, gpu=args.gpu)
19 |
20 |     trainer.load_dataset(dataset_dirs=args.dataset_dirs, batch_size=args.batch_size)
21 |
22 |     trainer.initialize_training(out_dir=args.out_dir, lr=args.lr)
23 |
24 |     trainer.train(epochs=args.epochs)
25 |
26 | if __name__ == '__main__':
27 |     main()
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | __package_name__ = "haca3"
4 |
5 |
6 | def get_version_and_cmdclass(pkg_path):
7 |     """Load version.py module without importing the whole package.
8 |
9 |     Template code from miniver
10 |     """
11 |     import os
12 |     from importlib.util import module_from_spec, spec_from_file_location
13 |
14 |     spec = spec_from_file_location("version", os.path.join(pkg_path, "_version.py"))
15 |     module = module_from_spec(spec)
16 |     spec.loader.exec_module(module)
17 |     return module.__version__, module.get_cmdclass(pkg_path)
18 |
19 |
20 | __version__, cmdclass = get_version_and_cmdclass(__package_name__)
21 |
22 |
23 | # noinspection PyTypeChecker
24 | setup(
25 |     name=__package_name__,
26 |     version=__version__,
27 |     description="HACA3: A unified approach for multi-site MR image harmonization",
28 |     long_description="HACA3: A unified approach for multi-site MR image harmonization",
29 |     author="Lianrui Zuo",
30 |     author_email="lr_zuo@jhu.edu",
31 |     url="https://gitlab.com/lr_zuo/haca3",
32 |     license="Apache License, 2.0",
33 |     classifiers=[
34 |         "Development Status :: 3 - Alpha",
35 |         "Environment :: Console",
36 |         "Intended Audience :: Science/Research",
37 |         "License :: OSI Approved :: Apache Software License",
38 |         "Programming Language :: Python :: 3.8",
39 |         "Topic :: Scientific/Engineering",
40 |     ],
41 |     packages=find_packages(),
42 |     keywords="mri harmonization",
43 |     entry_points={
44 |         "console_scripts": [
45 |             "haca3-train=haca3.train:main",
46 |             "haca3-test=haca3.test:main",
47 |         ]
48 |     },
49 |     install_requires=[
50 |         "nibabel",
51 |         "numpy",
52 |         "scipy",
53 |         "torch",
54 |         "torchvision",
55 |         "tqdm",
56 |         "torchio",
57 |         "scikit-image",
58 |         "tensorboard"
59 |     ],
60 |     cmdclass=cmdclass,
61 | )
62 |
--------------------------------------------------------------------------------
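`get_version_and_cmdclass()` loads `_version.py` by file path instead of importing `haca3`. It avoids the import because `haca3/__init__.py` imports the model, and therefore torch, which may not be installed yet when `setup.py` runs.

Below is a minimal, self-contained sketch of that importlib pattern. It assumes a throwaway `_version.py` written to a temporary directory as a stand-in for the real file. `load_module_from_path` is a hypothetical helper name.

```python
import os
import tempfile
from importlib.util import module_from_spec, spec_from_file_location


def load_module_from_path(name, path):
    """Execute a single .py file as a module without importing its package."""
    spec = spec_from_file_location(name, path)
    module = module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module


with tempfile.TemporaryDirectory() as pkg_dir:
    # Stand-in for haca3/_version.py; the real file derives __version__ from git.
    version_path = os.path.join(pkg_dir, "_version.py")
    with open(version_path, "w") as f:
        f.write('__version__ = "0.0.dev1"\n')
    version_module = load_module_from_path("version", version_path)

print(version_module.__version__)  # 0.0.dev1
```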
/haca3/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from .modules.model import HACA3
4 |
5 |
6 | def main(args=None):
7 |     args = sys.argv[1:] if args is None else args
8 |     parser = argparse.ArgumentParser(description='Harmonization with HACA3.')
9 |     parser.add_argument('--dataset-dirs', type=str, nargs='+', required=True)
10 |     parser.add_argument('--contrasts', type=str, nargs='+', required=True)
11 |     parser.add_argument('--orientations', type=str, nargs='+', default=['axial', 'coronal', 'sagittal'])
12 |     parser.add_argument('--out-dir', type=str, default='.')
13 |     parser.add_argument('--beta-dim', type=int, default=5)
14 |     parser.add_argument('--theta-dim', type=int, default=2)
15 |     parser.add_argument('--eta-dim', type=int, default=2)
16 |     parser.add_argument('--normalization-method', type=str, default='01')
17 |     parser.add_argument('--pretrained-haca3', type=str, default=None)
18 |     parser.add_argument('--pretrained-eta-encoder', type=str, default=None)
19 |     parser.add_argument('--lr', type=float, default=5e-4)
20 |     parser.add_argument('--batch-size', type=int, default=8)
21 |     parser.add_argument('--epochs', type=int, default=8)
22 |     parser.add_argument('--gpu-id', type=int, default=0)
23 |     args = parser.parse_args(args)
24 |
25 |     text_div = '=' * 10
26 |     print(f'{text_div} BEGIN HACA3 TRAINING {text_div}')
27 |
28 |     # ====== 1. INITIALIZE MODEL ======
29 |     haca3 = HACA3(beta_dim=args.beta_dim, theta_dim=args.theta_dim, eta_dim=args.eta_dim,
30 |                   pretrained_haca3=args.pretrained_haca3, pretrained_eta_encoder=args.pretrained_eta_encoder,
31 |                   gpu_id=args.gpu_id)
32 |
33 |     # ====== 2. LOAD DATASETS ======
34 |     haca3.load_dataset(dataset_dirs=args.dataset_dirs, contrasts=args.contrasts, orientations=args.orientations,
35 |                        batch_size=args.batch_size, normalization_method=args.normalization_method)
36 |
37 |     # ====== 3. INITIALIZE TRAINING ======
38 |     haca3.initialize_training(out_dir=args.out_dir, lr=args.lr)
39 |
40 |     # ====== 4. BEGIN TRAINING ======
41 |     haca3.train(epochs=args.epochs)
42 |
--------------------------------------------------------------------------------
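Because `main()` accepts an explicit argument list, `haca3-train` can also be driven from Python. The sketch below copies a subset of the options from `train.py` into a standalone parser, so it runs without torch. It shows how argparse maps them:

- dashed flags become underscored attributes;
- `nargs='+'` collects lists;
- omitted options take their defaults.

The dataset paths are hypothetical. `HACA3Dataset` assigns site IDs by the position of each directory in `--dataset-dirs`, so the order of that list matters.

```python
import argparse


def build_parser():
    # Subset of the options defined in haca3/train.py, copied for illustration.
    parser = argparse.ArgumentParser(description='Harmonization with HACA3.')
    parser.add_argument('--dataset-dirs', type=str, nargs='+', required=True)
    parser.add_argument('--contrasts', type=str, nargs='+', required=True)
    parser.add_argument('--orientations', type=str, nargs='+', default=['axial', 'coronal', 'sagittal'])
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--gpu-id', type=int, default=0)
    return parser


# Same as: haca3-train --dataset-dirs /data/site_a /data/site_b --contrasts T1PRE T2 --batch-size 4
argv = ['--dataset-dirs', '/data/site_a', '/data/site_b',
        '--contrasts', 'T1PRE', 'T2', '--batch-size', '4']
args = build_parser().parse_args(argv)
print(args.dataset_dirs)  # ['/data/site_a', '/data/site_b']
print(args.orientations)  # ['axial', 'coronal', 'sagittal']
print(args.batch_size)    # 4
```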
/haca3/encode.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from modules.model import HACA3
4 | import torch
5 | import nibabel as nib
6 | from PIL import Image
7 | import os
8 | import numpy as np
9 | from torchvision.transforms import ToTensor, CenterCrop, Compose, ToPILImage
10 |
11 |
12 | def obtain_single_image(img_path, normalization_val=1000.0):
13 |     img_file = nib.load(img_path)
14 |     img_vol = np.array(img_file.get_fdata().astype(np.float32))
15 |     img_vol = img_vol / normalization_val * 0.25
16 |     n_row, n_col, n_slc = img_vol.shape
17 |     # get images with proper zero padding
18 |     img_padded = np.zeros((288, 288, 288)).astype(np.float32)
19 |     img_padded[144 - n_row // 2:144 + n_row // 2 + n_row % 2,
20 |                144 - n_col // 2:144 + n_col // 2 + n_col % 2,
21 |                144 - n_slc // 2:144 + n_slc // 2 + n_slc % 2] = img_vol
22 |
23 |     return ToTensor()(img_padded), img_file.header, img_file.affine
24 |
25 |
26 | def main(args=None):
27 |     args = sys.argv[1:] if args is None else args
28 |     parser = argparse.ArgumentParser(description='Learn anatomy, artifact, and contrast with HACA3')
29 |     parser.add_argument('--image', type=str, required=True)
30 |     parser.add_argument('--out-dir', type=str, required=True)
31 |     parser.add_argument('--prefix', type=str, default='subject1')
32 |     parser.add_argument('--pretrained-harmonization', type=str, default=None)
33 |     parser.add_argument('--beta-dim', type=int, default=5)
34 |     parser.add_argument('--eta-dim', type=int, default=2)
35 |     parser.add_argument('--theta-dim', type=int, default=2)
36 |     parser.add_argument('--gpu', type=int, default=0)
37 |     parser.add_argument('--norm', default=1000.0, type=float)
38 |     args = parser.parse_args(args)
39 |
40 |     harmonization_model = HACA3(beta_dim=args.beta_dim,
41 |                                 theta_dim=args.theta_dim,
42 |                                 eta_dim=args.eta_dim,
43 |                                 pretrained_haca3=args.pretrained_harmonization,  # keyword names as used in train.py
44 |                                 gpu_id=args.gpu)
45 |
46 |     # load image
47 |     img_vol, img_header, img_affine = obtain_single_image(args.image, args.norm)
48 |
49 |     # Encoding
50 |     harmonization_model.encode(img=img_vol.float().permute(2, 1, 0).permute(2, 0, 1),
51 |                                out_dir=args.out_dir,
52 |                                prefix=args.prefix,
53 |                                affine=img_affine,
54 |                                header=img_header)
55 |
56 |
57 | if __name__ == '__main__':
58 |     main()
59 |
--------------------------------------------------------------------------------
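The two padding helpers write the same interval in different forms:

- `obtain_single_image()` centers each volume in a 288³ canvas with bounds `144 - n // 2` to `144 + n // 2 + n % 2`.
- `zero_pad()` and `crop()` in `utils.py` use `center - n // 2` to `center + n - n // 2`.

The plain-Python sketch below checks that both forms give identical bounds spanning exactly `n` voxels for every size up to 288. `encode_bounds` and `utils_bounds` are hypothetical names. `utils_bounds` uses `image_dim=288` to match `encode.py`; the `zero_pad()` default is 256.

```python
def encode_bounds(n, center=144):
    # Slice bounds as written in obtain_single_image() (encode.py).
    return center - n // 2, center + n // 2 + n % 2


def utils_bounds(n, image_dim=288):
    # Slice bounds as written in zero_pad() / crop() (utils.py).
    center = image_dim // 2
    return center - n // 2, center + n - n // 2


for n in range(1, 289):
    start, end = encode_bounds(n)
    assert (start, end) == utils_bounds(n)  # n // 2 + n % 2 == n - n // 2
    assert end - start == n                 # the slice holds exactly n voxels
    assert 0 <= start and end <= 288        # and fits inside the canvas

print(encode_bounds(181))  # (54, 235)
```

A dimension of 289 or more produces bounds outside the canvas. The NumPy assignment in `obtain_single_image()` would then fail with a shape mismatch, so such inputs need cropping first.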
/haca3/modules/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import torch
4 | from torch.utils.data.dataset import Dataset
5 | import numpy as np
6 | from torchvision.transforms import Compose, Pad, CenterCrop, ToTensor, ToPILImage
7 | import torchio as tio
8 | import nibabel as nib
9 |
10 | default_transform = Compose([ToPILImage(), Pad(40), CenterCrop([224, 224])])
11 | transform_dict = {
12 |     tio.RandomMotion(degrees=(15, 30), translation=(10, 20)): 0.25,
13 |     tio.RandomNoise(std=(0.01, 0.1)): 0.25,
14 |     tio.RandomGhosting(num_ghosts=(4, 10)): 0.25,
15 |     tio.RandomBiasField(): 0.25
16 | }
17 | degradation_transform = tio.OneOf(transform_dict)
18 | contrast_names = ['T1PRE', 'T2', 'PD', 'FLAIR']
19 |
20 |
21 | def get_tensor_from_fpath(fpath, normalization_method):
22 |     if os.path.exists(fpath):
23 |         image = np.squeeze(nib.load(fpath).get_fdata().astype(np.float32)).transpose([1, 0])
24 |
25 |         if normalization_method == 'wm':
26 |             image = image / 2.0
27 |         else:
28 |             p95 = np.percentile(image.flatten(), 95)  # 95th percentile, same threshold as utils.normalize_intensity
29 |             image = image / (p95 + 1e-5)
30 |             image = np.clip(image, a_min=0.0, a_max=5.0)
31 |
32 |         image = np.array(default_transform(image))
33 |         image = ToTensor()(image)
34 |     else:
35 |         image = torch.ones([1, 224, 224])
36 |     return image
37 |
38 |
39 | def background_removal(image_dicts):
40 |     num_contrasts = len(contrast_names)
41 |     mask = torch.ones((1, 224, 224))
42 |     for image_dict in image_dicts:
43 |         mask = mask * image_dict['image'].ge(1e-8)
44 |     for i in range(num_contrasts):
45 |         image_dicts[i]['image'] = image_dicts[i]['image'] * mask
46 |         image_dicts[i]['image_degrade'] = image_dicts[i]['image_degrade'] * mask
47 |         image_dicts[i]['mask'] = mask.bool()
48 |     return image_dicts
49 |
50 |
51 | class HACA3Dataset(Dataset):
52 |     def __init__(self, dataset_dirs, contrasts, orientations, mode='train', normalization_method='01'):
53 |         self.mode = mode
54 |         self.dataset_dirs = dataset_dirs
55 |         self.contrasts = contrasts
56 |         self.orientations = orientations
57 |         self.t1_paths, self.site_ids = self._get_file_paths()
58 |         self.normalization_method = normalization_method
59 |
60 |     def _get_file_paths(self):
61 |         fpaths, site_ids = [], []
62 |         for site_id, dataset_dir in enumerate(self.dataset_dirs):
63 |             for orientation in self.orientations:
64 |                 t1_paths = sorted(glob(os.path.join(dataset_dir, self.mode, f'*T1PRE*{orientation.upper()}*nii.gz')))
65 |                 fpaths += t1_paths
66 |                 site_ids += [site_id] * len(t1_paths)
67 |         return fpaths, site_ids
68 |
69 |     def __len__(self):
70 |         return len(self.t1_paths)
71 |
72 |     def __getitem__(self, idx: int):
73 |         image_dicts = []
74 |         for contrast_id, contrast_name in enumerate(contrast_names):
75 |             image_path = self.t1_paths[idx].replace('T1PRE', contrast_name)
76 |             image = get_tensor_from_fpath(image_path, self.normalization_method)
77 |             image_degrade = degradation_transform(image.unsqueeze(1)).squeeze(1)
78 |             site_id = self.site_ids[idx]
79 |             image_dict = {'image': image,
80 |                           'image_degrade': image_degrade,
81 |                           'site_id': site_id,
82 |                           'contrast_id': contrast_id,
83 |                           'exists': 0 if image[0, 0, 0] > 0.9999 else 1}
84 |             image_dicts.append(image_dict)
85 |         return background_removal(image_dicts)
86 |
--------------------------------------------------------------------------------
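`HACA3Dataset` only globs T1PRE files. The rest of the sample is built from that path:

1. Every other contrast is located by substituting the contrast name into the T1PRE path.
2. A missing file becomes an all-ones placeholder flagged `'exists': 0`.
3. `background_removal()` keeps a pixel only if it is at least `1e-8` in every contrast.

The plain-Python sketch below mirrors that bookkeeping on lists instead of tensors. The helper names and the file name are hypothetical; the file name was chosen only to match the `*T1PRE*AXIAL*nii.gz` glob.

```python
# Contrast order used by dataset.py.
contrast_names = ['T1PRE', 'T2', 'PD', 'FLAIR']


def paired_paths(t1_path):
    """Derive every contrast's path from the T1PRE path, as __getitem__ does."""
    return {name: t1_path.replace('T1PRE', name) for name in contrast_names}


def availability(t1_path, existing_files):
    """1 if the derived file exists, else 0 (hypothetical helper).

    dataset.py loads an all-ones placeholder for a missing file and
    reports it with 'exists': 0.
    """
    return [1 if path in existing_files else 0
            for path in paired_paths(t1_path).values()]


def foreground_mask(images, eps=1e-8):
    """Pixel-wise AND of (intensity >= eps) across contrasts.

    Mirrors background_removal() on flat lists; an all-ones placeholder
    passes everywhere, so a missing contrast never shrinks the mask.
    """
    mask = [True] * len(images[0])
    for image in images:
        mask = [m and v >= eps for m, v in zip(mask, image)]
    return mask


t1_path = 'site_a/train/sub01_T1PRE_AXIAL_100.nii.gz'  # hypothetical file name
paths = paired_paths(t1_path)
print(paths['FLAIR'])  # site_a/train/sub01_FLAIR_AXIAL_100.nii.gz
print(availability(t1_path, {paths['T1PRE'], paths['T2']}))  # [1, 1, 0, 0]
print(foreground_mask([[0.0, 0.5, 0.9], [0.2, 0.0, 0.7], [1.0, 1.0, 1.0]]))  # [False, False, True]
```

`str.replace` rewrites every occurrence of `T1PRE`. A directory whose name contains `T1PRE` would therefore be rewritten too, so the token should appear only in the file name.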
/.gitignore:
--------------------------------------------------------------------------------
1 | ### JetBrains template
2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
4 |
5 | # User-specific stuff
6 | .idea/
7 |
8 | ### macOS template
9 | # General
10 | .DS_Store
11 | .AppleDouble
12 | .LSOverride
13 |
14 | # Icon must end with two \r
15 | Icon
16 |
17 | # Thumbnails
18 | ._*
19 |
20 | # Files that might appear in the root of a volume
21 | .DocumentRevisions-V100
22 | .fseventsd
23 | .Spotlight-V100
24 | .TemporaryItems
25 | .Trashes
26 | .VolumeIcon.icns
27 | .com.apple.timemachine.donotpresent
28 |
29 | # Directories potentially created on remote AFP share
30 | .AppleDB
31 | .AppleDesktop
32 | Network Trash Folder
33 | Temporary Items
34 | .apdisk
35 |
36 | ### Python template
37 | # Byte-compiled / optimized / DLL files
38 | __pycache__/
39 | *.py[cod]
40 | *$py.class
41 |
42 | # C extensions
43 | *.so
44 |
45 | # Distribution / packaging
46 | .Python
47 | build/
48 | develop-eggs/
49 | dist/
50 | downloads/
51 | eggs/
52 | .eggs/
53 | lib/
54 | lib64/
55 | parts/
56 | sdist/
57 | var/
58 | wheels/
59 | share/python-wheels/
60 | *.egg-info/
61 | .installed.cfg
62 | *.egg
63 | MANIFEST
64 |
65 | # PyInstaller
66 | # Usually these files are written by a python script from a template
67 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
68 | *.manifest
69 | *.spec
70 |
71 | # Installer logs
72 | pip-log.txt
73 | pip-delete-this-directory.txt
74 |
75 | # Unit test / coverage reports
76 | htmlcov/
77 | .tox/
78 | .nox/
79 | .coverage
80 | .coverage.*
81 | .cache
82 | nosetests.xml
83 | coverage.xml
84 | *.cover
85 | *.py,cover
86 | .hypothesis/
87 | .pytest_cache/
88 | cover/
89 |
90 | # Translations
91 | *.mo
92 | *.pot
93 |
94 | # Django stuff:
95 | *.log
96 | local_settings.py
97 | db.sqlite3
98 | db.sqlite3-journal
99 |
100 | # Flask stuff:
101 | instance/
102 | .webassets-cache
103 |
104 | # Scrapy stuff:
105 | .scrapy
106 |
107 | # Sphinx documentation
108 | docs/_build/
109 |
110 | # PyBuilder
111 | .pybuilder/
112 | target/
113 |
114 | # Jupyter Notebook
115 | .ipynb_checkpoints
116 |
117 | # IPython
118 | profile_default/
119 | ipython_config.py
120 |
121 | # pyenv
122 | # For a library or package, you might want to ignore these files since the code is
123 | # intended to run in multiple environments; otherwise, check them in:
124 | # .python-version
125 |
126 | # pipenv
127 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
128 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
129 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
130 | # install all needed dependencies.
131 | #Pipfile.lock
132 |
133 | # poetry
134 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
135 | # This is especially recommended for binary packages to ensure reproducibility, and is more
136 | # commonly ignored for libraries.
137 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
138 | #poetry.lock
139 |
140 | # pdm
141 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
142 | #pdm.lock
143 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
144 | # in version control.
145 | # https://pdm.fming.dev/#use-with-ide
146 | .pdm.toml
147 |
148 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
149 | __pypackages__/
150 |
151 | # Celery stuff
152 | celerybeat-schedule
153 | celerybeat.pid
154 |
155 | # SageMath parsed files
156 | *.sage.py
157 |
158 | # Environments
159 | .env
160 | .venv
161 | env/
162 | venv/
163 | ENV/
164 | env.bak/
165 | venv.bak/
166 |
167 | # Spyder project settings
168 | .spyderproject
169 | .spyproject
170 |
171 | # Rope project settings
172 | .ropeproject
173 |
174 | # mkdocs documentation
175 | /site
176 |
177 | # mypy
178 | .mypy_cache/
179 | .dmypy.json
180 | dmypy.json
181 |
182 | # Pyre type checker
183 | .pyre/
184 |
185 | # pytype static type analyzer
186 | .pytype/
187 |
188 | # Cython debug symbols
189 | cython_debug/
190 |
191 | # PyCharm
192 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
193 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
194 | # and can be added to the global gitignore or merged into this file. For a more nuclear
195 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
196 | #.idea/
197 |
198 |
--------------------------------------------------------------------------------
/haca3/modules/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | import errno
6 | import nibabel as nib
7 | from torchvision import utils
8 | import torchvision.models as models
9 | import numpy as np
10 |
11 |
12 | def mkdir_p(path):
13 |     try:
14 |         os.makedirs(path)
15 |     except OSError as exc:
16 |         if exc.errno == errno.EEXIST and os.path.isdir(path):
17 |             pass
18 |         else:
19 |             raise
20 |
21 |
22 | def reparameterize_logit(logit):
23 |     import warnings
24 |     warnings.filterwarnings('ignore', message='.*Mixed memory format inputs detected.*')
25 |     beta = F.gumbel_softmax(logit, tau=1.0, dim=1, hard=True)
26 |     return beta
27 |
28 |
29 | def save_image(images, file_name):
30 |     image_save = torch.cat([image[:4, [0], ...].cpu() for image in images], dim=0)
31 |     image_save = utils.make_grid(tensor=image_save, nrow=4, normalize=False, value_range=(0, 1)).detach().numpy()[0, ...]
32 |     image_save = nib.Nifti1Image(image_save.transpose(1, 0), np.eye(4))
33 |     nib.save(image_save, file_name)
34 |
35 |
36 | def dropout_contrasts(available_contrast_id, contrast_id_to_drop=None):
37 |     """
38 |     Randomly drop out contrasts for HACA3 training.
39 |
40 |     ==INPUTS==
41 |     * available_contrast_id: torch.Tensor (batch_size, num_contrasts)
42 |         Indicates the availability of each MR contrast. 1: if available, 0: if unavailable.
43 |
44 |     * contrast_id_to_drop: torch.Tensor (batch_size, num_contrasts)
45 |         If provided, indicates the contrast indexes forced to drop. Default: None
46 |
47 |     ==OUTPUTS==
48 |     * contrast_id_after_dropout: torch.Tensor (batch_size, num_contrasts)
49 |         Some available contrasts will be randomly dropped out (as if they are unavailable).
50 |         However, each sample will have at least one contrast available.
51 |     """
52 |     batch_size = available_contrast_id.shape[0]
53 |     if contrast_id_to_drop is not None:
54 |         available_contrast_id = available_contrast_id - contrast_id_to_drop
55 |     contrast_id_after_dropout = available_contrast_id.clone()
56 |     for i in range(batch_size):
57 |         available_contrast_ids_per_subject = (available_contrast_id[i] == 1).nonzero(as_tuple=False).squeeze(1)
58 |         num_available_contrasts = available_contrast_ids_per_subject.numel()
59 |         if num_available_contrasts > 1:
60 |             num_contrast_to_drop = torch.randperm(num_available_contrasts - 1)[0]
61 |             contrast_ids_to_drop = torch.randperm(num_available_contrasts)[:num_contrast_to_drop]
62 |             contrast_ids_to_drop = available_contrast_ids_per_subject[contrast_ids_to_drop]
63 |             contrast_id_after_dropout[i, contrast_ids_to_drop] = 0.0
64 |     return contrast_id_after_dropout
65 |
66 |
67 | class PerceptualLoss(nn.Module):
68 |     def __init__(self, vgg_model):
69 |         super().__init__()
70 |         for param in vgg_model.parameters():
71 |             param.requires_grad = False
72 |         self.vgg = nn.Sequential(*list(vgg_model.children())[:13]).eval()
73 |
74 |     def forward(self, x, y):
75 |         if x.shape[1] == 1:
76 |             x = x.repeat(1, 3, 1, 1)
77 |         if y.shape[1] == 1:
78 |             y = y.repeat(1, 3, 1, 1)
79 |         return F.l1_loss(self.vgg(x), self.vgg(y))
80 |
81 |
82 | class PatchNCELoss(nn.Module):
83 |     def __init__(self, temperature=0.1):
84 |         super().__init__()
85 |         self.ce_loss = nn.CrossEntropyLoss(reduction='none')
86 |         self.temperature = temperature
87 |
88 |     def forward(self, query_feature, positive_feature, negative_feature):
89 |         B, C, N = query_feature.shape
90 |
91 |         l_positive = (query_feature * positive_feature).sum(dim=1)[:, :, None]
92 |         l_negative = torch.bmm(query_feature.permute(0, 2, 1), negative_feature)
93 |
94 |         logits = torch.cat((l_positive, l_negative), dim=2) / self.temperature
95 |
96 |         predictions = logits.flatten(0, 1)
97 |         targets = torch.zeros(B * N, dtype=torch.long).to(query_feature.device)
98 |         return self.ce_loss(predictions, targets).mean()
99 |
100 |
101 | class KLDivergenceLoss(nn.Module):
102 |     def __init__(self):
103 |         super().__init__()
104 |
105 |     def forward(self, mu, logvar):
106 |         kld_loss = -0.5 * logvar + 0.5 * (torch.exp(logvar) + torch.pow(mu, 2)) - 0.5
107 |         return kld_loss
108 |
109 |
110 | def divide_into_batches(in_tensor, num_batches):
111 |     batch_size = in_tensor.shape[0] // num_batches
112 |     remainder = in_tensor.shape[0] % num_batches
113 |     batches = []
114 |
115 |     current_start = 0
116 |     for i in range(num_batches):
117 |         current_end = current_start + batch_size
118 |         if remainder:
119 |             current_end += 1
120 |             remainder -= 1
121 |         batches.append(in_tensor[current_start:current_end, ...])
122 |         current_start = current_end
123 |     return batches
124 |
125 |
126 | def normalize_intensity(image):
127 |     thresh = np.percentile(image.flatten(), 95)
128 |     image = image / (thresh + 1e-5)
129 |     image = np.clip(image, a_min=0.0, a_max=5.0)
130 |     return image, thresh
131 |
132 |
133 | def zero_pad(image, image_dim=256):
134 |     [n_row, n_col, n_slc] = image.shape
135 |     image_padded = np.zeros((image_dim, image_dim, image_dim))
136 |     center_loc = image_dim // 2
137 |     image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
138 |                  center_loc - n_col // 2: center_loc + n_col - n_col // 2,
139 |                  center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2] = image
140 |     return image_padded
141 |
142 | def zero_pad2d(image, image_dim=256):
143 |     [n_row, n_col] = image.shape
144 |     image_padded = np.zeros((image_dim, image_dim))
145 |     center_loc = image_dim // 2
146 |     image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
147 |                  center_loc - n_col // 2: center_loc + n_col - n_col // 2] = image
148 |     return image_padded
149 |
150 |
151 | def crop(image, n_row, n_col, n_slc):
152 |     image_dim = image.shape[0]
153 |     center_loc = image_dim // 2
154 |     return image[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
155 |                  center_loc - n_col // 2: center_loc + n_col - n_col // 2,
156 |                  center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2]
157 |
158 | def crop2d(image, n_row, n_col):
159 |     image_dim = image.shape[0]
160 |     center_loc = image_dim // 2
161 |     return image[center_loc - n_row // 2: center_loc + n_row - n_row // 2,
162 |                  center_loc - n_col // 2: center_loc + n_col - n_col // 2]
163 |
--------------------------------------------------------------------------------
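In `dropout_contrasts()`, `n` is the number of contrasts available for a sample. `torch.randperm(n - 1)[0]` draws the number of contrasts to drop uniformly from 0 to `n - 2`. This has two consequences as written:

- At least two contrasts survive whenever two or more are available. This is stronger than the docstring's "at least one".
- A sample with exactly two available contrasts is never altered.

The source does not say whether that is intended. The sketch below mirrors the per-sample loop with Python lists and the `random` module instead of tensors. `dropout_one_subject` is a hypothetical name.

```python
import random


def dropout_one_subject(available, rng=random):
    """Randomly zero out some available contrasts for one subject.

    Hypothetical list-based mirror of the per-sample loop in
    dropout_contrasts(); `available` is a 0/1 list over contrasts.
    """
    kept = list(available)
    ids = [i for i, a in enumerate(available) if a == 1]
    n = len(ids)
    if n > 1:
        # torch.randperm(n - 1)[0] is uniform over {0, ..., n - 2};
        # random.randrange(n - 1) draws from the same range.
        num_to_drop = rng.randrange(n - 1)
        for i in rng.sample(ids, num_to_drop):
            kept[i] = 0
    return kept


rng = random.Random(0)
for _ in range(1000):
    out = dropout_one_subject([1, 1, 0, 1], rng)
    assert out[2] == 0    # an unavailable contrast is never switched on
    assert sum(out) >= 2  # with two or more available, at least two survive
print(dropout_one_subject([1, 1, 0, 0], rng))  # [1, 1, 0, 0]
```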
/haca3/_version.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # This file is part of 'miniver': https://github.com/jbweston/miniver
3 | #
4 | from collections import namedtuple
5 | import os
6 | import subprocess
7 |
8 | from setuptools.command.build_py import build_py as build_py_orig
9 | from setuptools.command.sdist import sdist as sdist_orig
10 |
11 | Version = namedtuple("Version", ("release", "dev", "labels"))
12 |
13 | # No public API
14 | __all__ = []
15 |
16 | package_root = os.path.dirname(os.path.realpath(__file__))
17 | package_name = os.path.basename(package_root)
18 |
19 | STATIC_VERSION_FILE = "_static_version.py"
20 |
21 |
22 | def get_version(version_file=STATIC_VERSION_FILE):
23 |     version_info = get_static_version_info(version_file)
24 |     version = version_info["version"]
25 |     if version == "__use_git__":
26 |         version = get_version_from_git()
27 |         if not version:
28 |             version = get_version_from_git_archive(version_info)
29 |         if not version:
30 |             version = Version("unknown", None, None)
31 |         return pep440_format(version)
32 |     else:
33 |         return version
34 |
35 |
36 | def get_static_version_info(version_file=STATIC_VERSION_FILE):
37 |     version_info = {}
38 |     with open(os.path.join(package_root, version_file), "rb") as f:
39 |         exec(f.read(), {}, version_info)
40 |     return version_info
41 |
42 |
43 | def version_is_from_git(version_file=STATIC_VERSION_FILE):
44 |     return get_static_version_info(version_file)["version"] == "__use_git__"
45 |
46 |
47 | def pep440_format(version_info):
48 |     release, dev, labels = version_info
49 |
50 |     version_parts = [release]
51 |     if dev:
52 |         if release.endswith("-dev") or release.endswith(".dev"):
53 |             version_parts.append(dev)
54 |         else:  # prefer PEP440 over strict adhesion to semver
55 |             version_parts.append(".dev{}".format(dev))
56 |
57 |     if labels:
58 |         version_parts.append("+")
59 |         version_parts.append(".".join(labels))
60 |
61 |     return "".join(version_parts)
62 |
63 |
64 | def get_version_from_git():
65 |     # git describe --first-parent does not take into account tags from branches
66 |     # that were merged-in. The '--long' flag gets us the 'dev' version and
67 |     # git hash, '--always' returns the git hash even if there are no tags.
68 |     for opts in [["--first-parent"], []]:
69 |         try:
70 |             p = subprocess.Popen(
71 |                 ["git", "describe", "--long", "--always"] + opts,
72 |                 cwd=package_root,
73 |                 stdout=subprocess.PIPE,
74 |                 stderr=subprocess.PIPE,
75 |             )
76 |         except OSError:
77 |             return
78 |         if p.wait() == 0:
79 |             break
80 |     else:
81 |         return
82 |
83 |     description = (
84 |         p.communicate()[0]
85 |         .decode()
86 |         .strip("v")  # Tags can have a leading 'v', but the version should not
87 |         .rstrip("\n")
88 |         .rsplit("-", 2)  # Split the latest tag, commits since tag, and hash
89 |     )
90 |
91 |     try:
92 |         release, dev, git = description
93 |     except ValueError:  # No tags, only the git hash
94 |         # prepend 'g' to match with format returned by 'git describe'
95 |         git = "g{}".format(*description)
96 |         release = "unknown"
97 |         dev = None
98 |
99 |     labels = []
100 |     if dev == "0":
101 |         dev = None
102 |     else:
103 |         labels.append(git)
104 |
105 |     try:
106 |         p = subprocess.Popen(["git", "diff", "--quiet"], cwd=package_root)
107 |     except OSError:
108 |         labels.append("confused")  # This should never happen.
109 |     else:
110 |         if p.wait() == 1:
111 |             labels.append("dirty")
112 |
113 |     return Version(release, dev, labels)
114 |
115 |
116 | # TODO: change this logic when there is a git pretty-format
117 | # that gives the same output as 'git describe'.
118 | # Currently we can only tell the tag the current commit is
119 | # pointing to, or its hash (with no version info)
120 | # if it is not tagged.
121 | def get_version_from_git_archive(version_info):
122 |     try:
123 |         refnames = version_info["refnames"]
124 |         git_hash = version_info["git_hash"]
125 |     except KeyError:
126 |         # These fields are not present if we are running from an sdist.
127 |         # Execution should never reach here, though
128 |         return None
129 |
130 |     if git_hash.startswith("$Format") or refnames.startswith("$Format"):
131 |         # variables not expanded during 'git archive'
132 |         return None
133 |
134 |     VTAG = "tag: v"
135 |     refs = set(r.strip() for r in refnames.split(","))
136 |     version_tags = set(r[len(VTAG) :] for r in refs if r.startswith(VTAG))
137 |     if version_tags:
138 |         release, *_ = sorted(version_tags)  # prefer e.g. "2.0" over "2.0rc1"
139 |         return Version(release, dev=None, labels=None)
140 |     else:
141 |         return Version("unknown", dev=None, labels=["g{}".format(git_hash)])
142 |
143 |
144 | __version__ = get_version()
145 |
146 |
147 | # The following section defines a 'get_cmdclass' function
148 | # that can be used from setup.py. The '__version__' module
149 | # global is used (but not modified).
150 |
151 |
152 | def _write_version(fname):
153 |     # This could be a hard link, so try to delete it first. Is there any way
154 |     # to do this atomically together with opening?
155 |     try:
156 |         os.remove(fname)
157 |     except OSError:
158 |         pass
159 |     with open(fname, "w") as f:
160 |         f.write(
161 |             "# This file has been created by setup.py.\n"
162 |             "version = '{}'\n".format(__version__)
163 |         )
164 |
165 |
166 | def get_cmdclass(pkg_source_path):
167 |     class _build_py(build_py_orig):
168 |         def run(self):
169 |             super().run()
170 |
171 |             src_marker = "".join(["src", os.path.sep])
172 |
173 |             if pkg_source_path.startswith(src_marker):
174 |                 path = pkg_source_path[len(src_marker):]
175 |             else:
176 |                 path = pkg_source_path
177 |             _write_version(
178 |                 os.path.join(
179 |                     self.build_lib, path, STATIC_VERSION_FILE
180 |                 )
181 |             )
182 |
183 |     class _sdist(sdist_orig):
184 |         def make_release_tree(self, base_dir, files):
185 |             super().make_release_tree(base_dir, files)
186 |             _write_version(
187 |                 os.path.join(base_dir, pkg_source_path, STATIC_VERSION_FILE)
188 | )
189 |
190 | return dict(sdist=_sdist, build_py=_build_py)
191 |
--------------------------------------------------------------------------------
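The version string built above comes from parsing `git describe --long --always` output and then applying PEP 440 formatting. Below is a minimal standalone sketch of that parsing and formatting, so the rules can be checked without a git checkout. `parse_describe` and `to_pep440` are hypothetical names that reproduce the logic of `get_version_from_git` and `pep440_format`. The `dirty` flag stands in for the `git diff --quiet` check, and the input is assumed to be either `v<tag>-<commits>-g<hash>` or a bare hash.

```python
from collections import namedtuple

Version = namedtuple("Version", ("release", "dev", "labels"))


def parse_describe(description, dirty=False):
    """Mirror the parsing in get_version_from_git for a given describe string."""
    # Remove the 'v' tag prefix (strip() would also drop a trailing 'v',
    # which a hex hash never has), then split into tag, commit count, hash.
    parts = description.strip("v").rstrip("\n").rsplit("-", 2)
    try:
        release, dev, git = parts
    except ValueError:  # no tags: describe returned only the abbreviated hash
        git = "g{}".format(*parts)
        release, dev = "unknown", None
    labels = []
    if dev == "0":  # HEAD is exactly on the tag: no dev segment, no hash label
        dev = None
    else:
        labels.append(git)
    if dirty:  # stands in for a non-zero 'git diff --quiet' exit status
        labels.append("dirty")
    return Version(release, dev, labels)


def to_pep440(version):
    """Same rules as pep440_format."""
    release, dev, labels = version
    parts = [release]
    if dev:
        if release.endswith("-dev") or release.endswith(".dev"):
            parts.append(dev)
        else:
            parts.append(".dev{}".format(dev))
    if labels:
        parts.append("+")
        parts.append(".".join(labels))
    return "".join(parts)


print(to_pep440(parse_describe("v1.2.0-3-gabc1234")))    # 1.2.0.dev3+gabc1234
print(to_pep440(parse_describe("v1.2.0-0-gabc1234")))    # 1.2.0
print(to_pep440(parse_describe("abc1234", dirty=True)))  # unknown+gabc1234.dirty
```

When there are zero commits since the tag, both the dev segment and the hash label are dropped. An untagged repository falls back to `unknown` with the hash as a local version label.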
/haca3/modules/_version.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # This file is part of 'miniver': https://github.com/jbweston/miniver
3 | #
4 | from collections import namedtuple
5 | import os
6 | import subprocess
7 |
8 | from setuptools.command.build_py import build_py as build_py_orig
9 | from setuptools.command.sdist import sdist as sdist_orig
10 |
11 | Version = namedtuple("Version", ("release", "dev", "labels"))
12 |
13 | # No public API
14 | __all__ = []
15 |
16 | package_root = os.path.dirname(os.path.realpath(__file__))
17 | package_name = os.path.basename(package_root)
18 |
19 | STATIC_VERSION_FILE = "_static_version.py"
20 |
21 |
22 | def get_version(version_file=STATIC_VERSION_FILE):
23 | version_info = get_static_version_info(version_file)
24 | version = version_info["version"]
25 | if version == "__use_git__":
26 | version = get_version_from_git()
27 | if not version:
28 | version = get_version_from_git_archive(version_info)
29 | if not version:
30 | version = Version("unknown", None, None)
31 | return pep440_format(version)
32 | else:
33 | return version
34 |
35 |
36 | def get_static_version_info(version_file=STATIC_VERSION_FILE):
37 | version_info = {}
38 | with open(os.path.join(package_root, version_file), "rb") as f:
39 | exec(f.read(), {}, version_info)
40 | return version_info
41 |
42 |
43 | def version_is_from_git(version_file=STATIC_VERSION_FILE):
44 | return get_static_version_info(version_file)["version"] == "__use_git__"
45 |
46 |
47 | def pep440_format(version_info):
48 | release, dev, labels = version_info
49 |
50 | version_parts = [release]
51 | if dev:
52 | if release.endswith("-dev") or release.endswith(".dev"):
53 | version_parts.append(dev)
54 | else: # prefer PEP440 over strict adherence to semver
55 | version_parts.append(".dev{}".format(dev))
56 |
57 | if labels:
58 | version_parts.append("+")
59 | version_parts.append(".".join(labels))
60 |
61 | return "".join(version_parts)
62 |
63 |
64 | def get_version_from_git():
65 | # git describe --first-parent does not take into account tags from branches
66 | # that were merged-in. The '--long' flag gets us the 'dev' version and
67 | # git hash, '--always' returns the git hash even if there are no tags.
68 | for opts in [["--first-parent"], []]:
69 | try:
70 | p = subprocess.Popen(
71 | ["git", "describe", "--long", "--always"] + opts,
72 | cwd=package_root,
73 | stdout=subprocess.PIPE,
74 | stderr=subprocess.PIPE,
75 | )
76 | except OSError:
77 | return
78 | if p.wait() == 0:
79 | break
80 | else:
81 | return
82 |
83 | description = (
84 | p.communicate()[0]
85 | .decode()
86 | .strip("v") # Tags can have a leading 'v', but the version should not
87 | .rstrip("\n")
88 | .rsplit("-", 2) # Split the latest tag, commits since tag, and hash
89 | )
90 |
91 | try:
92 | release, dev, git = description
93 | except ValueError: # No tags, only the git hash
94 | # prepend 'g' to match with format returned by 'git describe'
95 | git = "g{}".format(*description)
96 | release = "unknown"
97 | dev = None
98 |
99 | labels = []
100 | if dev == "0":
101 | dev = None
102 | else:
103 | labels.append(git)
104 |
105 | try:
106 | p = subprocess.Popen(["git", "diff", "--quiet"], cwd=package_root)
107 | except OSError:
108 | labels.append("confused") # This should never happen.
109 | else:
110 | if p.wait() == 1:
111 | labels.append("dirty")
112 |
113 | return Version(release, dev, labels)
114 |
115 |
116 | # TODO: change this logic when there is a git pretty-format
117 | # that gives the same output as 'git describe'.
118 | # Currently we can only tell the tag the current commit is
119 | # pointing to, or its hash (with no version info)
120 | # if it is not tagged.
121 | def get_version_from_git_archive(version_info):
122 | try:
123 | refnames = version_info["refnames"]
124 | git_hash = version_info["git_hash"]
125 | except KeyError:
126 | # These fields are not present if we are running from an sdist.
127 | # Execution should never reach here, though
128 | return None
129 |
130 | if git_hash.startswith("$Format") or refnames.startswith("$Format"):
131 | # variables not expanded during 'git archive'
132 | return None
133 |
134 | VTAG = "tag: v"
135 | refs = set(r.strip() for r in refnames.split(","))
136 | version_tags = set(r[len(VTAG) :] for r in refs if r.startswith(VTAG))
137 | if version_tags:
138 | release, *_ = sorted(version_tags) # prefer e.g. "2.0" over "2.0rc1"
139 | return Version(release, dev=None, labels=None)
140 | else:
141 | return Version("unknown", dev=None, labels=["g{}".format(git_hash)])
142 |
143 |
144 | __version__ = get_version()
145 |
146 |
147 | # The following section defines a 'get_cmdclass' function
148 | # that can be used from setup.py. The '__version__' module
149 | # global is used (but not modified).
150 |
151 |
152 | def _write_version(fname):
153 | # This could be a hard link, so try to delete it first. Is there any way
154 | # to do this atomically together with opening?
155 | try:
156 | os.remove(fname)
157 | except OSError:
158 | pass
159 | with open(fname, "w") as f:
160 | f.write(
161 | "# This file has been created by setup.py.\n"
162 | "version = '{}'\n".format(__version__)
163 | )
164 |
165 |
166 | def get_cmdclass(pkg_source_path):
167 | class _build_py(build_py_orig):
168 | def run(self):
169 | super().run()
170 |
171 | src_marker = "".join(["src", os.path.sep])
172 |
173 | if pkg_source_path.startswith(src_marker):
174 | path = pkg_source_path[len(src_marker):]
175 | else:
176 | path = pkg_source_path
177 | _write_version(
178 | os.path.join(
179 | self.build_lib, path, STATIC_VERSION_FILE
180 | )
181 | )
182 |
183 | class _sdist(sdist_orig):
184 | def make_release_tree(self, base_dir, files):
185 | super().make_release_tree(base_dir, files)
186 | _write_version(
187 | os.path.join(base_dir, pkg_source_path, STATIC_VERSION_FILE)
188 | )
189 |
190 | return dict(sdist=_sdist, build_py=_build_py)
191 |
--------------------------------------------------------------------------------
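If the package is built from a `git archive` tarball rather than from a checkout, `get_version_from_git_archive` reads the substituted `refnames` and `git_hash` fields instead. The sketch below reproduces the same selection logic as a standalone function that takes the two fields as plain strings. The function name `version_from_refnames` and the sample ref lists are illustrative assumptions.

```python
from collections import namedtuple

Version = namedtuple("Version", ("release", "dev", "labels"))

VTAG = "tag: v"


def version_from_refnames(refnames, git_hash):
    """Mirror get_version_from_git_archive for already-read field values."""
    if git_hash.startswith("$Format") or refnames.startswith("$Format"):
        return None  # placeholders were never substituted by 'git archive'
    refs = set(r.strip() for r in refnames.split(","))
    version_tags = set(r[len(VTAG):] for r in refs if r.startswith(VTAG))
    if version_tags:
        # Plain string sort: "2.0" sorts before "2.0rc1" because it is a prefix.
        release, *_ = sorted(version_tags)
        return Version(release, dev=None, labels=None)
    # Untagged commit: only the hash is known.
    return Version("unknown", dev=None, labels=["g{}".format(git_hash)])


print(version_from_refnames("HEAD -> main, tag: v2.0rc1, tag: v2.0", "abc1234"))
# Version(release='2.0', dev=None, labels=None)
print(version_from_refnames("HEAD -> main, origin/main", "abc1234"))
# Version(release='unknown', dev=None, labels=['gabc1234'])
print(version_from_refnames("$Format:placeholder$", "$Format:placeholder$"))
# None
```

`sorted` compares strings character by character, so the preference for "2.0" over "2.0rc1" only works because the final release is a prefix of the pre-release. Two tags such as `1.9` and `1.10` on the same commit would be ordered lexicographically ("1.10" before "1.9"), not by version number.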
/haca3/test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from pathlib import Path
4 |
5 | import nibabel as nib
6 | import numpy as np
7 | import torch
8 | from torchvision.transforms import ToTensor
9 |
10 | from skimage.filters import threshold_otsu
11 | from skimage.morphology import isotropic_closing
12 |
13 | from .modules.model import HACA3
14 | from .modules.utils import *
15 |
16 |
17 | def background_removal(image_vol):
18 | [n_row, n_col, n_slc] = image_vol.shape
19 | thresh = threshold_otsu(image_vol)
20 | mask = (image_vol >= thresh) * 1.0
21 | mask = zero_pad(mask, 256)
22 | mask = isotropic_closing(mask, radius=20)
23 | mask = crop(mask, n_row, n_col, n_slc)
24 | image_vol[mask < 1e-4] = 0.0
25 | return image_vol
26 |
27 | def background_removal2d(image_vol):
28 | [n_row, n_col] = image_vol.shape
29 | thresh = threshold_otsu(image_vol)
30 | mask = (image_vol >= thresh) * 1.0
31 | mask = zero_pad2d(mask, 256)
32 | mask = isotropic_closing(mask, radius=20)
33 | mask = crop2d(mask, n_row, n_col)
34 | image_vol[mask < 1e-4] = 0.0
35 | return image_vol
36 |
37 | def obtain_single_image(image_path, bg_removal=True):
38 | image_obj = nib.Nifti1Image.from_filename(image_path)
39 | image_vol = np.array(image_obj.get_fdata().astype(np.float32))
40 | thresh = np.percentile(image_vol.flatten(), 95)
41 | image_vol = image_vol / (thresh + 1e-5)
42 | image_vol = np.clip(image_vol, a_min=0.0, a_max=5.0)
43 | if bg_removal:
44 | image_vol = background_removal(image_vol)
45 |
46 | n_row, n_col, n_slc = image_vol.shape
47 | # zero padding
48 | image_padded = np.zeros((224, 224, 224)).astype(np.float32)
49 | image_padded[112 - n_row // 2:112 + n_row // 2 + n_row % 2,
50 | 112 - n_col // 2:112 + n_col // 2 + n_col % 2,
51 | 112 - n_slc // 2:112 + n_slc // 2 + n_slc % 2] = image_vol
52 | return ToTensor()(image_padded), image_obj.header, thresh
53 |
54 |
55 | def load_source_images(image_paths, bg_removal=True):
56 | source_images = []
57 | image_header = None
58 | for image_path in image_paths:
59 | image_vol, image_header, _ = obtain_single_image(image_path, bg_removal)
60 | source_images.append(image_vol.float().permute(2, 1, 0))
61 | return source_images, image_header
62 |
63 |
64 | def main(args=None):
65 | args = sys.argv[1:] if args is None else args
66 | parser = argparse.ArgumentParser(description='Harmonization with HACA3.')
67 | parser.add_argument('--in-path', type=Path, action='append', required=True)
68 | parser.add_argument('--target-image', type=Path, action='append', default=[])
69 | parser.add_argument('--target-theta', type=float, nargs=2, action='append', default=[])
70 | parser.add_argument('--target-eta', type=float, nargs=2, action='append', default=[])
71 | parser.add_argument('--norm-val', type=float, action='append', default=[])
72 | parser.add_argument('--out-path', type=Path, action='append', required=True)
73 | parser.add_argument('--harmonization-model', type=Path, required=True)
74 | parser.add_argument('--fusion-model', type=Path)
75 | parser.add_argument('--beta-dim', type=int, default=5)
76 | parser.add_argument('--theta-dim', type=int, default=2)
77 | parser.add_argument('--save-intermediate', action='store_true', default=False)
78 | parser.add_argument('--intermediate-out-dir', type=Path, default=Path.cwd())
79 | parser.add_argument('--no-bg-removal', dest='bg_removal', action='store_false', default=True)
80 | parser.add_argument('--gpu-id', type=int, default=0)
81 | parser.add_argument('--num-batches', type=int, default=4)
82 |
83 | args = parser.parse_args(args)
84 |
85 | text_div = '=' * 10
86 | print(f'{text_div} BEGIN HACA3 HARMONIZATION {text_div}')
87 |
88 | # ==== GET ABSOLUTE PATHS ====
89 | for argname in ['in_path', 'target_image', 'out_path', 'harmonization_model',
90 | 'fusion_model', 'intermediate_out_dir']:
91 | if isinstance(getattr(args, argname), list):
92 | setattr(args, argname, [path.resolve() for path in getattr(args, argname)])
93 | elif getattr(args, argname) is not None: # '--fusion-model' has no default and may be None
94 | setattr(args, argname, getattr(args, argname).resolve())
95 |
96 | # ==== SET DEFAULT FOR NORM/ETA ====
97 | if len(args.target_theta) > 0 and len(args.target_eta) == 0:
98 | args.target_eta = [[0.3, 0.5]]
99 | if len(args.target_theta) > 0 and len(args.norm_val) == 0:
100 | args.norm_val = [1000]
101 |
102 | # ==== CHECK CONDITIONS OF INPUT ARGUMENTS ====
103 | if not ((len(args.target_image) > 0) ^ (len(args.target_theta) > 0)):
104 | parser.error("Exactly one of '--target-image' or '--target-theta' must be provided.")
105 |
106 | if 0 < len(args.target_image) != len(args.out_path):
107 | parser.error("Number of '--out-path' and '--target-image' options should be the same.")
108 |
109 | if len(args.target_theta) == 1 and len(args.target_eta) > 1:
110 | args.target_theta = args.target_theta * len(args.target_eta)
111 |
112 | if len(args.target_theta) > 1 and len(args.target_eta) == 1:
113 | args.target_eta = args.target_eta * len(args.target_theta)
114 |
115 | if len(args.target_theta) > 1 and len(args.norm_val) == 1:
116 | args.norm_val = args.norm_val * len(args.target_theta)
117 |
118 | if 0 < len(args.target_theta) != len(args.target_eta):
119 | parser.error("Number of '--target-theta' and '--target-eta' options should be the same.")
120 |
121 | if 0 < len(args.target_theta) != len(args.norm_val):
122 | parser.error("Number of '--target-theta' and '--norm-val' options should be the same.")
123 |
124 | if 0 < len(args.target_theta) != len(args.out_path):
125 | parser.error("Number of '--target-theta' and '--out-path' options should be the same.")
126 |
127 | if args.save_intermediate:
128 | mkdir_p(args.intermediate_out_dir)
129 |
130 | # ==== INITIALIZE MODEL ====
131 | haca3 = HACA3(beta_dim=args.beta_dim,
132 | theta_dim=args.theta_dim,
133 | eta_dim=2,
134 | pretrained_haca3=args.harmonization_model,
135 | gpu_id=args.gpu_id)
136 |
137 | # ==== LOAD SOURCE IMAGES ====
138 | source_images, image_header = load_source_images(args.in_path, args.bg_removal)
139 |
140 | # ==== LOAD TARGET IMAGES IF PROVIDED ====
141 | if len(args.target_image) > 0:
142 | target_images, norm_vals = [], []
143 | for target_image_path, out_path in zip(args.target_image, args.out_path):
144 | target_image_tmp, tmp_header, norm_val = obtain_single_image(target_image_path, args.bg_removal)
145 | target_images.append(target_image_tmp.permute(2, 1, 0).permute(0, 2, 1).flip(1)[100:120, ...])
146 | norm_vals.append(norm_val)
147 | if args.save_intermediate:
148 | out_prefix = out_path.name.replace('.nii.gz', '')
149 | save_img = target_image_tmp.permute(1, 2, 0).numpy()[112 - 96:112 + 96, :, 112 - 96:112 + 96]
150 | target_obj = nib.Nifti1Image(save_img * norm_val, None, tmp_header)
151 | target_obj.to_filename(args.intermediate_out_dir / f'{out_prefix}_target.nii.gz')
152 | if args.save_intermediate:
153 | out_prefix = args.out_path[0].name.replace('.nii.gz', '')
154 | with open(args.intermediate_out_dir / f'{out_prefix}_targetnorms.txt', 'w') as fp:
155 | fp.write('image,norm_val\n')
156 | for i, norm_val in enumerate(norm_vals):
157 | fp.write(f'{i},{norm_val:.6f}\n')
159 | target_theta = None
160 | target_eta = None
161 | else:
162 | target_images = None
163 | target_theta = torch.as_tensor(args.target_theta, dtype=torch.float32)
164 | target_eta = torch.as_tensor(args.target_eta, dtype=torch.float32)
165 | norm_vals = args.norm_val
166 |
167 | # ===== BEGIN HARMONIZATION WITH HACA3 =====
168 | haca3.harmonize(
169 | source_images=[image.permute(2, 0, 1) for image in source_images],
170 | target_images=target_images,
171 | target_theta=target_theta,
172 | target_eta=target_eta,
173 | out_paths=args.out_path,
174 | header=image_header,
175 | recon_orientation='axial',
176 | norm_vals=norm_vals,
177 | num_batches=args.num_batches,
178 | save_intermediate=args.save_intermediate,
179 | intermediate_out_dir=args.intermediate_out_dir,
180 | )
181 |
182 | haca3.harmonize(
183 | source_images=[image.permute(0, 2, 1).flip(1) for image in source_images],
184 | target_images=target_images,
185 | target_theta=target_theta,
186 | target_eta=target_eta,
187 | out_paths=args.out_path,
188 | header=image_header,
189 | recon_orientation='coronal',
190 | norm_vals=norm_vals,
191 | num_batches=args.num_batches,
192 | save_intermediate=args.save_intermediate,
193 | intermediate_out_dir=args.intermediate_out_dir,
194 | )
195 |
196 | haca3.harmonize(
197 | source_images=[image.permute(1, 2, 0).flip(1) for image in source_images],
198 | target_images=target_images,
199 | target_theta=target_theta,
200 | target_eta=target_eta,
201 | out_paths=args.out_path,
202 | header=image_header,
203 | recon_orientation='sagittal',
204 | norm_vals=norm_vals,
205 | num_batches=args.num_batches,
206 | save_intermediate=args.save_intermediate,
207 | intermediate_out_dir=args.intermediate_out_dir,
208 | )
209 |
210 | print(f'{text_div} START FUSION {text_div}')
211 | for out_path, norm_val in zip(args.out_path, norm_vals):
212 | prefix = out_path.name.replace('.nii.gz', '')
213 | decode_img_paths = [out_path.parent / f'{prefix}_harmonized_{orient}.nii.gz'
214 | for orient in ['axial', 'coronal', 'sagittal']]
215 | haca3.combine_images(decode_img_paths, out_path, norm_val, args.fusion_model)
216 |
--------------------------------------------------------------------------------
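Before passing each volume to the model, `obtain_single_image` in `test.py` divides it by its 95th-percentile intensity, clips the result to [0, 5], and centers it in a 224 x 224 x 224 zero canvas. The sketch below covers those two steps with numpy alone; background removal and the `ToTensor` conversion are left out. `normalize_and_pad` is a hypothetical helper. It assumes every axis is at most 224 voxels and raises an error otherwise, whereas the script itself does not check this.

```python
import numpy as np


def normalize_and_pad(image_vol, size=224):
    """Percentile-normalize a 3D volume and center it in a size**3 zero canvas."""
    image_vol = image_vol.astype(np.float32)
    thresh = np.percentile(image_vol.flatten(), 95)
    image_vol = np.clip(image_vol / (thresh + 1e-5), a_min=0.0, a_max=5.0)

    n_row, n_col, n_slc = image_vol.shape
    if max(n_row, n_col, n_slc) > size:
        raise ValueError(f"volume {image_vol.shape} does not fit in {size}^3")
    c = size // 2  # 112 for the default 224 canvas
    padded = np.zeros((size, size, size), dtype=np.float32)
    # Same slicing as the script: for odd lengths the extra voxel lands after the center.
    padded[c - n_row // 2:c + n_row // 2 + n_row % 2,
           c - n_col // 2:c + n_col // 2 + n_col % 2,
           c - n_slc // 2:c + n_slc // 2 + n_slc % 2] = image_vol
    return padded, thresh


vol = np.random.default_rng(0).uniform(0, 1000, size=(181, 217, 181))
padded, thresh = normalize_and_pad(vol)
print(padded.shape)  # (224, 224, 224)
```

The returned `thresh` is the per-image normalization value. The script multiplies saved intermediate targets by it (`save_img * norm_val`) to return to the original intensity scale, except for intensities above 5 x `thresh`, which the clipping step has already discarded.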
/haca3/modules/fusion_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | from glob import glob
5 |
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import DataLoader
9 | from torchvision.transforms import ToTensor
10 | from torch.utils.data.dataset import Dataset
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torchvision import utils
13 | import torch.nn.functional as F
14 | import nibabel as nib
15 | from .utils import mkdir_p
16 |
17 | class Net(nn.Module):
18 | def __init__(self, in_ch=3, out_ch=1):
19 | super().__init__()
20 | self.conv1 = nn.Sequential(
21 | nn.Conv3d(in_ch, 8, 3, 1, 1),
22 | nn.InstanceNorm3d(8),
23 | nn.LeakyReLU(),
24 | nn.Conv3d(8, 16, 3, 1, 1),
25 | nn.InstanceNorm3d(16),
26 | nn.LeakyReLU())
27 | self.conv2 = nn.Sequential(
28 | nn.Conv3d(in_ch+16, 16, 3, 1, 1),
29 | #nn.InstanceNorm3d(16),
30 | nn.LeakyReLU(),
31 | nn.Conv3d(16, out_ch, 3, 1, 1),
32 | nn.ReLU())
33 |
34 | def forward(self, x):
35 | #return self.conv2(x + self.conv1(x))
36 | return self.conv2(torch.cat([x, self.conv1(x)], dim=1))
37 |
38 | class MultiOrientationDataset(Dataset):
39 | def __init__(self, dataset_dirs):
40 | self.dataset_dirs = dataset_dirs
41 | self.imgs = self._get_files()
42 |
43 | def _get_files(self):
44 | img_paths = []
45 | contrast_names = ['T1', 'T2', 'PD', 'FLAIR']
46 | for dataset_dir in self.dataset_dirs:
47 | for contrast_name in contrast_names:
48 | img_path_tmp = os.path.join(dataset_dir, f'*harmonized_to_{contrast_name}_ori.nii.gz')
49 | img_path_tmp = sorted(glob(img_path_tmp))
50 | img_paths = img_paths + img_path_tmp
51 | return img_paths
52 |
53 | def __len__(self):
54 | return len(self.imgs)
55 |
56 | def get_tensor_from_path(self, img_path, if_norm_val=False):
57 | img = nib.load(img_path).get_fdata().astype(np.float32)
58 | img = ToTensor()(img)
59 | img = img.permute(2,1,0).permute(2,0,1).unsqueeze(0)
60 | img, norm_val = self.normalize_intensity(img)
61 | if if_norm_val:
62 | return img, norm_val
63 | else:
64 | return img
65 |
66 | def normalize_intensity(self, image):
67 | p99 = np.percentile(image.flatten(), 99)
68 | image = np.clip(image, a_min=0.0, a_max=p99)
69 | image = image / p99
70 | return image, p99
71 |
72 | def __getitem__(self, idx:int):
73 | img_path = self.imgs[idx]
74 | str_id = img_path.find('_ori')
75 | axial_img_path = img_path[:str_id] + '_axial.nii.gz'
76 | coronal_img_path = img_path[:str_id] + '_coronal.nii.gz'
77 | sagittal_img_path = img_path[:str_id] + '_sagittal.nii.gz'
78 | ori_image, norm_val = self.get_tensor_from_path(img_path, if_norm_val=True)
79 | img_dict = {'ori_img' : ori_image,
80 | 'axial_img' : self.get_tensor_from_path(axial_img_path),
81 | 'coronal_img' : self.get_tensor_from_path(coronal_img_path),
82 | 'sagittal_img' : self.get_tensor_from_path(sagittal_img_path),
83 | 'norm_val' : norm_val
84 | }
85 |
86 | return img_dict
87 |
88 |
89 | class FusionNet:
90 | def __init__(self, pretrained_model=None, gpu=0):
91 | self.device = torch.device(f'cuda:{gpu}')
92 |
93 | # define networks
94 | self.fusion_net = Net(in_ch=3, out_ch=1)
95 |
96 | # initialize training variables
97 | self.train_loader, self.valid_loader = None, None
98 | self.out_dir = None
99 | self.optim_fusion_net = None
100 |
101 | # load pretrained models
102 | self.checkpoint = None
103 | if pretrained_model is not None:
104 | self.checkpoint = torch.load(pretrained_model, map_location=self.device)
105 | self.fusion_net.load_state_dict(self.checkpoint['fusion_net'])
106 | self.fusion_net.to(self.device)
107 | self.start_epoch = 0
108 |
109 | def load_dataset(self, dataset_dirs, batch_size):
110 | all_dataset = MultiOrientationDataset(dataset_dirs)
111 | num_instances = len(all_dataset)
112 | num_train = int(0.8 * num_instances)
113 | num_valid = num_instances - num_train
114 | train_dataset, valid_dataset = torch.utils.data.random_split(all_dataset,
115 | [num_train, num_valid])
116 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
117 | self.valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
118 |
119 | def initialize_training(self, out_dir, lr):
120 | self.out_dir = out_dir
121 | mkdir_p(out_dir)
122 | mkdir_p(os.path.join(out_dir, 'results'))
123 | mkdir_p(os.path.join(out_dir, 'models'))
124 |
125 | # define loss
126 | self.l1_loss = nn.L1Loss(reduction='none')
127 |
128 | self.optim_fusion_net = torch.optim.Adam(self.fusion_net.parameters(), lr=lr)
129 |
130 | if self.checkpoint is not None:
131 | self.start_epoch = self.checkpoint['epoch']
132 | self.optim_fusion_net.load_state_dict(self.checkpoint['optim_fusion_net'])
133 | self.start_epoch += 1
134 |
135 | def train(self, epochs):
136 | for epoch in range(self.start_epoch, epochs+1):
137 | train_loader = tqdm(self.train_loader) # wrap locally so progress bars do not nest across epochs
138 | self.fusion_net.train()
139 | train_loss_sum = 0.0
140 | num_train_imgs = 0
141 | for batch_id, img_dict in enumerate(train_loader):
142 | syn_img = torch.cat([
143 | img_dict['axial_img'],
144 | img_dict['coronal_img'],
145 | img_dict['sagittal_img']
146 | ], dim=1).to(self.device)
147 | ori_img = img_dict['ori_img'].to(self.device)
148 | batch_size = ori_img.shape[0]
149 |
150 | ori_img = ori_img * (syn_img[:,[0],:,:,:] > 1e-8).detach()
151 |
152 | fusion_img = self.fusion_net(syn_img)
153 |
154 | rec_loss = self.l1_loss(fusion_img, ori_img).mean()
155 | self.optim_fusion_net.zero_grad()
156 | rec_loss.backward()
157 | self.optim_fusion_net.step()
158 |
159 | train_loss_sum += rec_loss.item() * batch_size
160 | num_train_imgs += batch_size
161 | train_loader.set_description((f'epoch: {epoch}; '
162 | f'rec: {rec_loss.item():.3f}; '
163 | f'avg_trn: {train_loss_sum / num_train_imgs:.3f}; '))
164 | if batch_id % 50 == 1:
165 | img_affine = [[-1, 0, 0, 96], [0, -1, 0, 96], [0, 0, 1, -78], [0, 0, 0, 1]]
166 | img_save = np.array(fusion_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2))
167 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine)
168 | file_name = os.path.join(self.out_dir, 'results', f'train_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_syn.nii.gz')
169 | nib.save(img_save, file_name)
170 |
171 | img_save = np.array(ori_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2))
172 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine)
173 | file_name = os.path.join(self.out_dir, 'results', f'train_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_ori.nii.gz')
174 | nib.save(img_save, file_name)
175 |
176 | # save models
177 | if batch_id % 100 == 0:
178 | file_name = os.path.join(self.out_dir, 'models',
179 | f'epoch{str(epoch).zfill(3)}_batch{str(batch_id).zfill(4)}.pt')
180 | self.save_model(file_name, epoch)
181 | # VALIDATION
182 | valid_loader = tqdm(self.valid_loader) # wrap locally so progress bars do not nest across epochs
183 | valid_loss_sum = 0.0
184 | num_valid_imgs = 0
185 | self.fusion_net.eval()
186 | with torch.set_grad_enabled(False):
187 | for batch_id, img_dict in enumerate(valid_loader):
188 | syn_img = torch.cat([
189 | img_dict['axial_img'],
190 | img_dict['coronal_img'],
191 | img_dict['sagittal_img']
192 | ], dim=1).to(self.device)
193 | ori_img = img_dict['ori_img'].to(self.device)
194 | batch_size = ori_img.shape[0]
195 |
196 | #mask = syn_img[:,[0],:,:,:] > 1e-8
197 | ori_img = ori_img * (syn_img[:,[0],:,:,:] > 1e-8).detach()
198 |
199 | fusion_img = self.fusion_net(syn_img)
200 |
201 | rec_loss = self.l1_loss(fusion_img, ori_img).mean()
202 |
203 | valid_loss_sum += rec_loss.item() * batch_size
204 | num_valid_imgs += batch_size
205 | valid_loader.set_description((f'epoch: {epoch}; '
206 | f'rec: {rec_loss.item():.3f}; '
207 | f'avg_val: {valid_loss_sum / num_valid_imgs:.3f}; '))
208 | if batch_id % 50 == 1:
209 | img_affine = [[-1, 0, 0, 96], [0, -1, 0, 96], [0, 0, 1, -78], [0, 0, 0, 1]]
210 | img_save = np.array(fusion_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2))
211 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine)
212 | file_name = os.path.join(self.out_dir, 'results', f'valid_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_syn.nii.gz')
213 | nib.save(img_save, file_name)
214 |
215 | img_save = np.array(ori_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2))
216 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine)
217 | file_name = os.path.join(self.out_dir, 'results', f'valid_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_ori.nii.gz')
218 | nib.save(img_save, file_name)
219 |
220 |
221 | def save_model(self, file_name, epoch):
222 | state = {'epoch': epoch,
223 | 'fusion_net': self.fusion_net.state_dict(),
224 | 'optim_fusion_net': self.optim_fusion_net.state_dict()}
225 | torch.save(obj=state, f=file_name)
226 |
227 |
228 |
229 |
--------------------------------------------------------------------------------
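`fusion_model.py` trains the fusion network with a masked L1 objective. The reference volume is zeroed wherever the axial reconstruction (channel 0 of the stacked input) is at or below 1e-8, and the absolute error is then averaged over all voxels. The sketch below applies that loss to toy arrays using numpy. `masked_l1` is a hypothetical helper. The channel mean used as the "prediction" is only a stand-in for the output of `Net`.

```python
import numpy as np


def masked_l1(fusion_img, ori_img, syn_img, eps=1e-8):
    """Masked L1 as in FusionNet.train: zero the reference wherever the axial
    reconstruction (channel 0 of syn_img) is background, then average |error|
    over every voxel, masked ones included."""
    mask = syn_img[:, [0]] > eps  # shape (B, 1, D, H, W)
    target = ori_img * mask       # background voxels become 0
    return np.abs(fusion_img - target).mean()


rng = np.random.default_rng(0)
B, D, H, W = 2, 4, 4, 4
axial, coronal, sagittal = (rng.uniform(0, 1, (B, 1, D, H, W)) for _ in range(3))
axial[..., :2] = 0.0  # pretend half the field of view is background
syn_img = np.concatenate([axial, coronal, sagittal], axis=1)  # (B, 3, D, H, W)
ori_img = rng.uniform(0, 1, (B, 1, D, H, W))
print(masked_l1(syn_img.mean(axis=1, keepdims=True), ori_img, syn_img))
```

The mean is taken over the masked voxels too, where the target is zero. As a result, any non-zero output in the background still adds to the loss.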
/haca3/modules/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | class FusionNet(nn.Module):
8 | def __init__(self, in_ch=3, out_ch=1):
9 | super().__init__()
10 | self.conv1 = nn.Sequential(
11 | nn.Conv3d(in_ch, 8, 3, 1, 1),
12 | nn.InstanceNorm3d(8),
13 | nn.LeakyReLU(),
14 | nn.Conv3d(8, 16, 3, 1, 1),
15 | nn.InstanceNorm3d(16),
16 | nn.LeakyReLU())
17 | self.conv2 = nn.Sequential(
18 | nn.Conv3d(in_ch + 16, 16, 3, 1, 1),
19 | nn.LeakyReLU(),
20 | nn.Conv3d(16, out_ch, 3, 1, 1),
21 | nn.ReLU())
22 |
23 | def forward(self, x):
24 | # return self.conv2(x + self.conv1(x))
25 | return self.conv2(torch.cat([x, self.conv1(x)], dim=1))
26 |
27 | class UNet(nn.Module):
28 | def __init__(self, in_ch, out_ch, conditional_ch=0, num_lvs=4, base_ch=16, final_act='noact'):
29 | super().__init__()
30 | self.final_act = final_act
31 | self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1)
32 |
33 | self.down_convs = nn.ModuleList()
34 | self.down_samples = nn.ModuleList()
35 | self.up_samples = nn.ModuleList()
36 | self.up_convs = nn.ModuleList()
37 | for lv in range(num_lvs):
38 | ch = base_ch * (2 ** lv)
39 | self.down_convs.append(ConvBlock2d(ch + conditional_ch, ch * 2, ch * 2))
40 | self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2))
41 | self.up_samples.append(Upsample(ch * 4))
42 | self.up_convs.append(ConvBlock2d(ch * 4, ch * 2, ch * 2))
43 | bottleneck_ch = base_ch * (2 ** num_lvs)
44 | self.bottleneck_conv = ConvBlock2d(bottleneck_ch, bottleneck_ch * 2, bottleneck_ch * 2)
45 | self.out_conv = nn.Sequential(nn.Conv2d(base_ch * 2, base_ch, 3, 1, 1),
46 | nn.LeakyReLU(0.1),
47 | nn.Conv2d(base_ch, out_ch, 3, 1, 1))
48 |
49 | def forward(self, in_tensor, condition=None):
50 | encoded_features = []
51 | x = self.in_conv(in_tensor)
52 | for down_conv, down_sample in zip(self.down_convs, self.down_samples):
53 | if condition is not None:
54 | feature_dim = x.shape[-1]
55 | down_conv_out = down_conv(torch.cat([x, condition.repeat(1, 1, feature_dim, feature_dim)], dim=1))
56 | else:
57 | down_conv_out = down_conv(x)
58 | x = down_sample(down_conv_out)
59 | encoded_features.append(down_conv_out)
60 | x = self.bottleneck_conv(x)
61 | for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features),
62 | reversed(self.up_convs),
63 | reversed(self.up_samples)):
64 | x = up_sample(x, encoded_feature)
65 | x = up_conv(x)
66 | x = self.out_conv(x)
67 | if self.final_act == 'sigmoid':
68 | x = torch.sigmoid(x)
69 | elif self.final_act == "relu":
70 | x = torch.relu(x)
71 | elif self.final_act == 'tanh':
72 | x = torch.tanh(x)
73 | else:
74 | x = x
75 | return x
76 |
77 |
78 | class ConvBlock2d(nn.Module):
79 | def __init__(self, in_ch, mid_ch, out_ch):
80 | super().__init__()
81 | self.conv = nn.Sequential(
82 | nn.Conv2d(in_ch, mid_ch, 3, 1, 1),
83 | nn.InstanceNorm2d(mid_ch),
84 | nn.LeakyReLU(0.1),
85 | nn.Conv2d(mid_ch, out_ch, 3, 1, 1),
86 | nn.InstanceNorm2d(out_ch),
87 | nn.LeakyReLU(0.1)
88 | )
89 |
90 | def forward(self, in_tensor):
91 | return self.conv(in_tensor)
92 |
93 |
94 | class Upsample(nn.Module):
95 | def __init__(self, in_ch):
96 | super().__init__()
97 | out_ch = in_ch // 2
98 | self.conv = nn.Sequential(
99 | nn.Conv2d(in_ch, out_ch, 3, 1, 1),
100 | nn.InstanceNorm2d(out_ch),
101 | nn.LeakyReLU(0.1)
102 | )
103 |
104 | def forward(self, in_tensor, encoded_feature):
105 | up_sampled_tensor = F.interpolate(in_tensor, size=None, scale_factor=2, mode='bilinear', align_corners=False)
106 | up_sampled_tensor = self.conv(up_sampled_tensor)
107 | return torch.cat([encoded_feature, up_sampled_tensor], dim=1)
108 |
109 |
110 | class EtaEncoder(nn.Module):
111 | def __init__(self, in_ch=1, out_ch=2):
112 | super().__init__()
113 | self.in_conv = nn.Sequential(
114 | nn.Conv2d(in_ch, 16, 5, 1, 2), # (*, 16, 224, 224)
115 | nn.InstanceNorm2d(16),
116 | nn.LeakyReLU(0.1),
117 | nn.Conv2d(16, 64, 3, 1, 1), # (*, 64, 224, 224)
118 | nn.InstanceNorm2d(64),
119 | nn.LeakyReLU(0.1)
120 | )
121 | self.seq = nn.Sequential(
122 | nn.Conv2d(64 + in_ch, 32, 32, 32, 0), # (*, 32, 7, 7)
123 | nn.InstanceNorm2d(32),
124 | nn.LeakyReLU(0.1),
125 | nn.Conv2d(32, out_ch, 7, 7, 0))
126 |
127 | def forward(self, x):
128 | return self.seq(torch.cat([self.in_conv(x), x], dim=1))
129 |
130 |
131 | class Patchifier(nn.Module):
132 | def __init__(self, in_ch, out_ch=1):
133 | super().__init__()
134 | self.conv = nn.Sequential(
135 | nn.Conv2d(in_ch, 64, 32, 32, 0), # (*, in_ch, 224, 224) --> (*, 64, 7, 7)
136 | nn.LeakyReLU(0.1),
137 | nn.Conv2d(64, out_ch, 1, 1, 0))
138 |
139 | def forward(self, x):
140 | return self.conv(x)
141 |
142 |
143 | class ThetaEncoder(nn.Module):
144 | def __init__(self, in_ch, out_ch):
145 | super().__init__()
146 | self.conv = nn.Sequential(
147 | nn.Conv2d(in_ch, 32, 17, 9, 4),
148 | nn.InstanceNorm2d(32),
149 | nn.LeakyReLU(0.1), # (*, 32, 24, 24)
150 | nn.Conv2d(32, 64, 4, 2, 1),
151 | nn.InstanceNorm2d(64),
152 | nn.LeakyReLU(0.1), # (*, 64, 12, 12)
153 | nn.Conv2d(64, 64, 4, 2, 1),
154 | nn.InstanceNorm2d(64),
155 | nn.LeakyReLU(0.1)) # (*, 64, 6, 6)
156 | self.mean_conv = nn.Sequential(
157 | nn.Conv2d(64, 32, 3, 1, 1),
158 | nn.InstanceNorm2d(32),
159 | nn.LeakyReLU(0.1),
160 | nn.Conv2d(32, out_ch, 6, 6, 0))
161 | self.logvar_conv = nn.Sequential(
162 | nn.Conv2d(64, 32, 3, 1, 1),
163 | nn.InstanceNorm2d(32),
164 | nn.LeakyReLU(0.1),
165 | nn.Conv2d(32, out_ch, 6, 6, 0))
166 |
167 | def forward(self, x):
168 | M = self.conv(x)
169 | mu = self.mean_conv(M)
170 | logvar = self.logvar_conv(M)
171 | return mu, logvar
172 |
173 | # class ThetaEncoder(nn.Module):
174 | #     def __init__(self, in_ch, out_ch):
175 | #         super().__init__()
176 | #         self.conv = nn.Sequential(
177 | #             nn.Conv2d(in_ch, 32, 32, 32, 0),  # (*, in_ch, 224, 224) --> (*, 32, 7, 7)
178 | #             nn.InstanceNorm2d(32),
179 | #             nn.LeakyReLU(0.1),
180 | #             nn.Conv2d(32, 64, 1, 1, 0),
181 | #             # nn.InstanceNorm2d(64),
182 | #             nn.LeakyReLU(0.1))
183 | #         self.mu_conv = nn.Sequential(
184 | #             nn.Conv2d(64, 64, 3, 1, 1),
185 | #             nn.InstanceNorm2d(64),
186 | #             nn.LeakyReLU(0.1),
187 | #             nn.Conv2d(64, out_ch, 7, 7, 0))
188 | #         self.logvar_conv = nn.Sequential(
189 | #             nn.Conv2d(64, 64, 3, 1, 1),
190 | #             nn.InstanceNorm2d(64),
191 | #             nn.LeakyReLU(0.1),
192 | #             nn.Conv2d(64, out_ch, 7, 7, 0))
193 | #
194 | #     def forward(self, x, patch_shuffle=False):
195 | #         m = self.conv(x)
196 | #         if patch_shuffle:
197 | #             batch_size = m.shape[0]
198 | #             num_features = m.shape[1]
199 | #             num_patches_per_dim = m.shape[-1]
200 | #             m = m.view(batch_size, num_features, -1)[:, :, torch.randperm(num_patches_per_dim ** 2)]
201 | #             m = m.view(batch_size, num_features, num_patches_per_dim, num_patches_per_dim)
202 | #         mu = self.mu_conv(m)
203 | #         logvar = self.logvar_conv(m)
204 | #         return mu, logvar
205 |
206 |
207 | class AttentionModule(nn.Module):
208 |     def __init__(self, dim, v_ch=5):
209 |         super().__init__()
210 |         self.dim = dim
211 |         self.v_ch = v_ch
212 |         self.q_fc = nn.Sequential(
213 |             nn.Linear(dim, 128),
214 |             nn.LeakyReLU(0.1),
215 |             nn.Linear(128, 16),
216 |             nn.LayerNorm(16))
217 |         self.k_fc = nn.Sequential(
218 |             nn.Linear(dim, 128),
219 |             nn.LeakyReLU(0.1),
220 |             nn.Linear(128, 16),
221 |             nn.LayerNorm(16))
222 |
223 |         self.scale = self.dim ** (-0.5)
224 |
225 |     def forward(self, q, k, v, modality_dropout=None, temperature=10.0):
226 |         r"""
227 |         Attention module for optimal anatomy fusion.
228 |
229 |         ===INPUTS===
230 |         * q: torch.Tensor (batch_size, feature_dim_q, num_q_patches=1)
231 |             Query variable. In HACA3, the query is the concatenation of target \theta and target \eta.
232 |         * k: torch.Tensor (batch_size, feature_dim_k, num_k_patches=1, num_contrasts=4)
233 |             Key variable. In HACA3, the keys are the \theta and \eta of the source images.
234 |         * v: torch.Tensor (batch_size, self.v_ch=5, num_v_patches=224*224, num_contrasts=4)
235 |             Value variable. In HACA3, the values are the multi-channel logits of the source images;
236 |             self.v_ch is the number of \beta channels.
237 |         * modality_dropout: torch.Tensor (batch_size, num_contrasts=4)
238 |             Marks which contrasts are missing: 1 if dropped out, 0 if available.
239 |         """
240 |         batch_size, feature_dim_q, num_q_patches = q.shape
241 |         _, feature_dim_k, _, num_contrasts = k.shape
242 |         num_v_patches = v.shape[2]
243 |         assert (
244 |             feature_dim_q == feature_dim_k == self.dim
245 |         ), 'Feature dimensions of q and k must both equal self.dim.'
246 |
247 |         # q.shape: (batch_size, num_q_patches=1, 1, feature_dim_q)
248 |         q = q.reshape(batch_size, feature_dim_q, num_q_patches, 1).permute(0, 2, 3, 1)
249 |         # k.shape: (batch_size, num_k_patches=1, num_contrasts=4, feature_dim_k)
250 |         k = k.permute(0, 2, 3, 1)
251 |         # v.shape: (batch_size, num_v_patches=224*224, num_contrasts=4, v_ch=5)
252 |         v = v.permute(0, 2, 3, 1)
253 |         q = self.q_fc(q)  # q.shape: (batch_size, num_q_patches=1, 1, 16)
254 |         # k.shape: (batch_size, num_k_patches=1, 16, num_contrasts=4)
255 |         k = self.k_fc(k).permute(0, 1, 3, 2)
256 |
257 |         # dot_prod.shape: (batch_size, num_q_patches=1, 1, num_contrasts=4)
258 |         dot_prod = (q @ k) * self.scale
259 |         interpolation_factor = int(math.sqrt(num_v_patches // num_q_patches))
260 |
261 |         q_spatial_dim = int(math.sqrt(num_q_patches))
262 |         dot_prod = dot_prod.view(batch_size, q_spatial_dim, q_spatial_dim, num_contrasts)
263 |
264 |         image_dim = int(math.sqrt(num_v_patches))
265 |         # dot_prod_interp.shape: (batch_size, image_dim, image_dim, num_contrasts); each patch score fills its own pixel block
266 |         dot_prod_interp = dot_prod.repeat_interleave(interpolation_factor, dim=1).repeat_interleave(interpolation_factor, dim=2)
267 |         if modality_dropout is not None:
268 |             modality_dropout = modality_dropout.view(batch_size, num_contrasts, 1, 1).permute(0, 2, 3, 1)
269 |             dot_prod_interp = dot_prod_interp - (modality_dropout.repeat(1, image_dim, image_dim, 1).detach() * 1e5)
270 |
271 |         attention = (dot_prod_interp / temperature).softmax(dim=-1)
272 |         v = attention.view(batch_size, num_v_patches, 1, num_contrasts) @ v
273 |         v = v.view(batch_size, image_dim, image_dim, self.v_ch).permute(0, 3, 1, 2)
274 |         attention = attention.view(batch_size, image_dim, image_dim, num_contrasts).permute(0, 3, 1, 2)
275 |         return v, attention
276 |
--------------------------------------------------------------------------------
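The shape bookkeeping in `ThetaEncoder` and the per-pixel fusion in `AttentionModule.forward` can be sketched without PyTorch. The helpers below (`conv_out_size`, `theta_encoder_sizes`, `fuse_contrasts`) are hypothetical and not part of the `haca3` package; they assume a 224x224 input and `num_q_patches=1`, so each pixel sees one score per contrast. `conv_out_size` applies the standard `nn.Conv2d` output-size formula (dilation 1), and `fuse_contrasts` mimics the 1e5 penalty on dropped contrasts, the temperature-scaled softmax over contrasts, and the weighted sum of value vectors. It is a simplified sketch, not the module's implementation.

```python
import math
from typing import List, Sequence, Tuple


def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a square nn.Conv2d with dilation 1: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def theta_encoder_sizes(size: int = 224) -> List[int]:
    """Trace the spatial size through ThetaEncoder (hypothetical helper; 224x224 input assumed)."""
    # (kernel, stride, padding) of the three trunk convs, then the last conv of mean_conv/logvar_conv.
    # The 3x3, padding-1 conv before that last conv preserves the size, so it is skipped here.
    layers = [(17, 9, 4), (4, 2, 1), (4, 2, 1), (6, 6, 0)]
    sizes = []
    for kernel, stride, padding in layers:
        size = conv_out_size(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


def fuse_contrasts(
    scores: Sequence[float],
    values: Sequence[Sequence[float]],
    dropped: Sequence[int],
    temperature: float = 10.0,
) -> Tuple[List[float], List[float]]:
    """Fuse per-contrast value vectors at one pixel (hypothetical helper).

    scores:  scaled q.k similarity for each contrast
    values:  one value vector per contrast (e.g., the beta logits at this pixel)
    dropped: 1 if the contrast is missing, 0 if it is available
    """
    # Push missing contrasts toward zero weight, mirroring the 1e5 penalty in AttentionModule.
    logits = [(s - d * 1e5) / temperature for s, d in zip(scores, dropped)]
    # Numerically stable softmax over the contrast axis.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the value vectors, channel by channel.
    num_channels = len(values[0])
    fused = [sum(w * vec[c] for w, vec in zip(weights, values)) for c in range(num_channels)]
    return weights, fused


if __name__ == "__main__":
    print(theta_encoder_sizes(224))  # [24, 12, 6, 1]
    weights, fused = fuse_contrasts([1.0, 2.0, 3.0, 4.0], [[1.0], [2.0], [3.0], [4.0]], [0, 0, 0, 1])
    print(weights[3])  # 0.0: the dropped contrast receives no weight
```

With all scores equal and nothing dropped, the weights are uniform (0.25 each for four contrasts) and the fused value is the plain mean; a larger `temperature` shrinks the differences between logits and pushes the weights toward that uniform case.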
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # HACA3: A unified approach for multi-site MR image harmonization | [Paper](https://www.sciencedirect.com/science/article/pii/S0895611123001039)
4 |
5 | HACA3 is an advanced approach for multi-site MRI harmonization. This page provides a gentle introduction to HACA3 inference and training.
6 |
7 | - Publication: [Zuo et al. HACA3: A unified approach for multi-site MR image harmonization. *Computerized Medical Imaging
8 | and Graphics, 2023.*](https://www.sciencedirect.com/science/article/pii/S0895611123001039)
9 |
10 | - Citation:
11 | ```bibtex
12 | @article{ZUO2023102285,
13 | title = {HACA3: A unified approach for multi-site MR image harmonization},
14 | journal = {Computerized Medical Imaging and Graphics},
15 | volume = {109},
16 | pages = {102285},
17 | year = {2023},
18 | issn = {0895-6111},
19 | doi = {https://doi.org/10.1016/j.compmedimag.2023.102285},
20 | author = {Lianrui Zuo and Yihao Liu and Yuan Xue and Blake E. Dewey and
21 | Samuel W. Remedios and Savannah P. Hays and Murat Bilgel and
22 | Ellen M. Mowry and Scott D. Newsome and Peter A. Calabresi and
23 | Susan M. Resnick and Jerry L. Prince and Aaron Carass}
24 | }
25 | ```
26 |
27 | ### Recent Updates
28 | - August 11, 2024 - **An [interactive demo](https://colab.research.google.com/drive/1PeBuqOAGupLQ2gXWVneX1Kn31ISh4oFB?usp=share_link) for HACA3 is now available.** You can explore HACA3 harmonization and imputation in real time.
29 |
30 | ## 1. Introduction and motivation
31 | ### 1.1 The double-edged sword of MRI: flexibility and variability
32 | Magnetic resonance imaging (MRI) is a powerful imaging technique, offering flexibility in capturing various tissue
33 | contrasts in a single imaging session. For example, T1-weighted, T2-weighted, and FLAIR images can be acquired in the same
34 | session to provide comprehensive insights into different tissue properties. However, this flexibility comes at
35 | a cost: ***lack of standardization and consistency*** across imaging studies. Several factors contribute to this
36 | variability, including but not limited to
37 | - Pulse sequences, e.g., MPRAGE, SPGR
38 | - Imaging parameters, e.g., flip angle, echo time
39 | - Scanner manufacturers, e.g., Siemens, GE
40 | - Technician and site preferences
41 |
42 | ### 1.2 Why should we harmonize MR images?
43 | Contrast variations in MR images may sometimes be subtle but are often significant enough to impact the quality and
44 | reliability of ***multi-site*** and ***longitudinal*** studies.
45 |
46 | - ***Example #1***: Multi-site inconsistency. In this example, two images were acquired at different sites using
47 | distinct imaging parameters. This led to ***different image contrast for the same subject***. As a result, an automatic
48 | segmentation algorithm produced inconsistent outcomes due to these contrast differences. Harmonization effectively
49 | alleviates this issue.
50 |
52 |
62 |