├── haca3
│   ├── modules
│   │   ├── __init__.py
│   │   ├── _static_version.py
│   │   ├── dataset.py
│   │   ├── utils.py
│   │   ├── _version.py
│   │   ├── fusion_model.py
│   │   ├── network.py
│   │   └── model.py
│   ├── __init__.py
│   ├── _static_version.py
│   ├── train_fusion.py
│   ├── train.py
│   ├── encode.py
│   ├── _version.py
│   └── test.py
├── figures
│   ├── GA.png
│   ├── multi_site.png
│   └── longitudinal.png
├── Dockerfile
├── requirements.txt
├── .gitlab-ci.yml
├── setup.py
├── .gitignore
└── README.md
/haca3/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figures/GA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lianruizuo/haca3/HEAD/figures/GA.png -------------------------------------------------------------------------------- /figures/multi_site.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lianruizuo/haca3/HEAD/figures/multi_site.png -------------------------------------------------------------------------------- /figures/longitudinal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lianruizuo/haca3/HEAD/figures/longitudinal.png -------------------------------------------------------------------------------- /haca3/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | from .modules.model import HACA3 3 | from .modules.utils import * 4 | from .test import background_removal2d 5 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends git 4 | 5 | COPY . /tmp/haca3 6 | 7 | RUN pip install /tmp/haca3 && rm -rf /tmp/haca3 8 | 9 | ENTRYPOINT ["haca3-test"] 10 | -------------------------------------------------------------------------------- /haca3/_static_version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # This file is part of 'miniver': https://github.com/jbweston/miniver 3 | # 4 | # This file will be overwritten by setup.py when a source or binary 5 | # distribution is made. The magic value "__use_git__" is interpreted by 6 | # version.py. 7 | 8 | version = "__use_git__" 9 | 10 | # These values are only set if the distribution was created with 'git archive' 11 | refnames = "$Format:%D$" 12 | git_hash = "$Format:%h$" 13 | -------------------------------------------------------------------------------- /haca3/modules/_static_version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # This file is part of 'miniver': https://github.com/jbweston/miniver 3 | # 4 | # This file will be overwritten by setup.py when a source or binary 5 | # distribution is made. The magic value "__use_git__" is interpreted by 6 | # version.py.
7 | 8 | version = "__use_git__" 9 | 10 | # These values are only set if the distribution was created with 'git archive' 11 | refnames = "$Format:%D$" 12 | git_hash = "$Format:%h$" 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2022.12.7 2 | charset-normalizer==2.1.1 3 | click==8.1.7 4 | cmake==3.25.0 5 | colorama==0.4.6 6 | Deprecated==1.2.14 7 | filelock==3.9.0 8 | humanize==4.8.0 9 | idna==3.4 10 | importlib-resources==6.0.1 11 | Jinja2==3.1.2 12 | lit==15.0.7 13 | markdown-it-py==3.0.0 14 | MarkupSafe==2.1.2 15 | mdurl==0.1.2 16 | mpmath==1.2.1 17 | networkx==3.0 18 | nibabel==5.1.0 19 | numpy==1.24.4 20 | packaging==23.1 21 | Pillow==9.3.0 22 | Pygments==2.16.1 23 | requests==2.28.1 24 | rich==13.5.2 25 | scipy==1.10.1 26 | shellingham==1.5.3 27 | SimpleITK==2.2.1 28 | sympy==1.11.1 29 | torch==2.0.1+cu118 30 | torchaudio==2.0.2+cu118 31 | torchio==0.19.1 32 | torchvision==0.15.2+cu118 33 | tqdm==4.66.1 34 | triton==2.0.0 35 | typer==0.9.0 36 | typing_extensions==4.4.0 37 | urllib3==1.26.13 38 | wrapt==1.15.0 39 | zipp==3.16.2 40 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | .build_image: 2 | image: docker:latest 3 | services: 4 | - name: docker:dind 5 | command: [ "--experimental" ] 6 | variables: 7 | DOCKER_BUILDKIT: 1 8 | before_script: 9 | - docker login -u gitlab-ci-token -p $CI_JOB_TOKEN $CI_REGISTRY 10 | - apk add curl bash 11 | - mkdir -vp ~/.docker/cli-plugins/ 12 | - curl --silent -L "https://gitlab.com/neurobuilds/setup-dind/-/raw/main/setup_dind.sh" > ~/setup_dind.sh 13 | - bash ~/setup_dind.sh 14 | - docker context create tls-environment 15 | - docker buildx create --use tls-environment 16 | - docker buildx inspect --bootstrap 17 | script: 18 | - docker buildx build --build-arg CI_JOB_TOKEN=$CI_JOB_TOKEN --push -t $TAG . 
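# NOTE: '.build_image' is a hidden template job (the leading dot keeps it from
# running on its own); 'build_tag' and 'build_branch' below extend it, setting
# only the TAG variable and the 'only:' rules that decide when each job runs.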
19 | 20 | build_tag: 21 | extends: .build_image 22 | variables: 23 | TAG: $CI_REGISTRY_IMAGE:$CI_COMMIT_TAG 24 | only: 25 | - tags 26 | 27 | build_branch: 28 | extends: .build_image 29 | variables: 30 | TAG: $CI_REGISTRY_IMAGE:$CI_COMMIT_BRANCH 31 | only: 32 | - main 33 | - support/* -------------------------------------------------------------------------------- /haca3/train_fusion.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from modules.fusion_model import FusionNet 4 | 5 | def main(args=None): 6 | args = sys.argv[1:] if args is None else args 7 | parser = argparse.ArgumentParser(description='Unsupervised harmonization via disentanglement.') 8 | parser.add_argument('--dataset-dirs', nargs='+', type=str, required=True) 9 | parser.add_argument('--out-dir', type=str, default='../') 10 | parser.add_argument('--pretrained-model', type=str, default=None) 11 | parser.add_argument('--lr', type=float, default=5e-4) 12 | parser.add_argument('--batch-size', type=int, default=8) 13 | parser.add_argument('--epochs', type=int, default=4) 14 | parser.add_argument('--gpu', type=int, default=0) 15 | args = parser.parse_args(args) 16 | 17 | # initialize model 18 | trainer = FusionNet(pretrained_model=args.pretrained_model, gpu=args.gpu) 19 | 20 | trainer.load_dataset(dataset_dirs=args.dataset_dirs, batch_size=args.batch_size) 21 | 22 | trainer.initialize_training(out_dir=args.out_dir, lr=args.lr) 23 | 24 | trainer.train(epochs=args.epochs) 25 | 26 | if __name__ == '__main__': 27 | main() 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __package_name__ = "haca3" 4 | 5 | 6 | def get_version_and_cmdclass(pkg_path): 7 | """Load version.py module without importing the whole package. 
8 | 9 | Template code from miniver 10 | """ 11 | import os 12 | from importlib.util import module_from_spec, spec_from_file_location 13 | 14 | spec = spec_from_file_location("version", os.path.join(pkg_path, "_version.py")) 15 | module = module_from_spec(spec) 16 | spec.loader.exec_module(module) 17 | return module.__version__, module.get_cmdclass(pkg_path) 18 | 19 | 20 | __version__, cmdclass = get_version_and_cmdclass(__package_name__) 21 | 22 | 23 | # noinspection PyTypeChecker 24 | setup( 25 | name=__package_name__, 26 | version=__version__, 27 | description="HACA3: A unified approach for multi-site MR image harmonization", 28 | long_description="HACA3: A unified approach for multi-site MR image harmonization", 29 | author="Lianrui Zuo", 30 | author_email="lr_zuo@jhu.edu", 31 | url="https://gitlab.com/lr_zuo/haca3", 32 | license="Apache License, 2.0", 33 | classifiers=[ 34 | "Development Status :: 3 - Alpha", 35 | "Environment :: Console", 36 | "Intended Audience :: Science/Research", 37 | "License :: OSI Approved :: Apache Software License", 38 | "Programming Language :: Python :: 3.8", 39 | "Topic :: Scientific/Engineering", 40 | ], 41 | packages=find_packages(), 42 | keywords="mri harmonization", 43 | entry_points={ 44 | "console_scripts": [ 45 | "haca3-train=haca3.train:main", 46 | "haca3-test=haca3.test:main", 47 | ] 48 | }, 49 | install_requires=[ 50 | "nibabel", 51 | "numpy", 52 | "scipy", 53 | "torch", 54 | "torchvision", 55 | "tqdm", 56 | "torchio", 57 | "scikit-image", 58 | "tensorboard" 59 | ], 60 | cmdclass=cmdclass, 61 | ) 62 | -------------------------------------------------------------------------------- /haca3/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from .modules.model import HACA3 4 | 5 | 6 | def main(args=None): 7 | args = sys.argv[1:] if args is None else args 8 | parser = argparse.ArgumentParser(description='Harmonization with HACA3.') 9 | parser.add_argument('--dataset-dirs', type=str, nargs='+', required=True) 10 | parser.add_argument('--contrasts', type=str, nargs='+', required=True) 11 | parser.add_argument('--orientations', type=str, nargs='+', default=['axial', 'coronal', 'sagittal']) 12 | parser.add_argument('--out-dir', type=str, default='.') 13 | parser.add_argument('--beta-dim', type=int, default=5) 14 | parser.add_argument('--theta-dim', type=int, default=2) 15 | parser.add_argument('--eta-dim', type=int, default=2) 16 | parser.add_argument('--normalization-method', type=str, default='01') 17 | parser.add_argument('--pretrained-haca3', type=str, default=None) 18 | parser.add_argument('--pretrained-eta-encoder', type=str, default=None) 19 | parser.add_argument('--lr', type=float, default=5e-4) 20 | parser.add_argument('--batch-size', type=int, default=8) 21 | parser.add_argument('--epochs', type=int, default=8) 22 | parser.add_argument('--gpu-id', type=int, default=0) 23 | args = parser.parse_args(args) 24 | 25 | text_div = '=' * 10 26 | print(f'{text_div} BEGIN HACA3 TRAINING {text_div}') 27 | 28 | # ====== 1. INITIALIZE MODEL ====== 29 | haca3 = HACA3(beta_dim=args.beta_dim, theta_dim=args.theta_dim, eta_dim=args.eta_dim, 30 | pretrained_haca3=args.pretrained_haca3, pretrained_eta_encoder=args.pretrained_eta_encoder, 31 | gpu_id=args.gpu_id) 32 | 33 | # ====== 2. 
LOAD DATASETS ====== 34 | haca3.load_dataset(dataset_dirs=args.dataset_dirs, contrasts=args.contrasts, orientations=args.orientations, 35 | batch_size=args.batch_size, normalization_method=args.normalization_method) 36 | 37 | # ====== 3. INITIALIZE TRAINING ====== 38 | haca3.initialize_training(out_dir=args.out_dir, lr=args.lr) 39 | 40 | # ====== 4. BEGIN TRAINING ====== 41 | haca3.train(epochs=args.epochs) 42 | -------------------------------------------------------------------------------- /haca3/encode.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from modules.model import HACA3 4 | import torch 5 | import nibabel as nib 6 | from PIL import Image 7 | import os 8 | import numpy as np 9 | from torchvision.transforms import ToTensor, CenterCrop, Compose, ToPILImage 10 | 11 | 12 | def obtain_single_image(img_path, normalization_val=1000.0): 13 | img_file = nib.load(img_path) 14 | img_vol = np.array(img_file.get_fdata().astype(np.float32)) 15 | img_vol = img_vol / normalization_val * 0.25 16 | n_row, n_col, n_slc = img_vol.shape 17 | # get images with proper zero padding 18 | img_padded = np.zeros((288, 288, 288)).astype(np.float32) 19 | img_padded[144 - n_row // 2:144 + n_row // 2 + n_row % 2, 20 | 144 - n_col // 2:144 + n_col // 2 + n_col % 2, 21 | 144 - n_slc // 2:144 + n_slc // 2 + n_slc % 2] = img_vol 22 | 23 | return ToTensor()(img_padded), img_file.header, img_file.affine 24 | 25 | 26 | def main(args=None): 27 | args = sys.argv[1:] if args is None else args 28 | parser = argparse.ArgumentParser(description='Learn anatomy, artifact, and contrast with HACA3') 29 | parser.add_argument('--image', type=str, required=True) 30 | parser.add_argument('--out-dir', type=str, required=True) 31 | parser.add_argument('--prefix', type=str, default='subject1') 32 | parser.add_argument('--pretrained-harmonization', type=str, default=None) 33 | parser.add_argument('--beta-dim', type=int, default=5) 34 | parser.add_argument('--eta-dim', type=int, default=2) 35 | parser.add_argument('--theta-dim', type=int, default=2) 36 | parser.add_argument('--gpu', type=int, default=0) 37 | parser.add_argument('--norm', default=1000.0, type=float) 38 | args = parser.parse_args(args) 39 | 40 | harmonization_model = HACA3(beta_dim=args.beta_dim, 41 | theta_dim=args.theta_dim, 42 | eta_dim=args.eta_dim, 43 | pretrained_harmonization=args.pretrained_harmonization, 44 | gpu=args.gpu) 45 | 46 | # load image 47 | img_vol, img_header, img_affine = obtain_single_image(args.image, args.norm) 48 | 49 | # Encoding 50 | harmonization_model.encode(img=img_vol.float().permute(2, 1, 0).permute(2, 0, 1), 51 | out_dir=args.out_dir, 52 | prefix=args.prefix, 53 | affine=img_affine, 54 | header=img_header) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /haca3/modules/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import torch 4 | from torch.utils.data.dataset import Dataset 5 | import numpy as np 6 | from torchvision.transforms import Compose, Pad, CenterCrop, ToTensor, ToPILImage 7 | import torchio as tio 8 | import nibabel as nib 9 | 10 | default_transform = Compose([ToPILImage(), Pad(40), CenterCrop([224, 224])]) 11 | transform_dict = { 12 | tio.RandomMotion(degrees=(15, 30), translation=(10, 20)): 0.25, 13 | tio.RandomNoise(std=(0.01, 0.1)): 0.25, 14 | 
tio.RandomGhosting(num_ghosts=(4, 10)): 0.25, 15 | tio.RandomBiasField(): 0.25 16 | } 17 | degradation_transform = tio.OneOf(transform_dict) 18 | contrast_names = ['T1PRE', 'T2', 'PD', 'FLAIR'] 19 | 20 | 21 | def get_tensor_from_fpath(fpath, normalization_method): 22 | if os.path.exists(fpath): 23 | image = np.squeeze(nib.load(fpath).get_fdata().astype(np.float32)).transpose([1, 0]) 24 | 25 | if normalization_method == 'wm': 26 | image = image / 2.0 27 | else: 28 | p99 = np.percentile(image.flatten(), 95) 29 | image = image / (p99 + 1e-5) 30 | image = np.clip(image, a_min=0.0, a_max=5.0) 31 | 32 | image = np.array(default_transform(image)) 33 | image = ToTensor()(image) 34 | else: 35 | image = torch.ones([1, 224, 224]) 36 | return image 37 | 38 | 39 | def background_removal(image_dicts): 40 | num_contrasts = len(contrast_names) 41 | mask = torch.ones((1, 224, 224)) 42 | for image_dict in image_dicts: 43 | mask = mask * image_dict['image'].ge(1e-8) 44 | for i in range(num_contrasts): 45 | image_dicts[i]['image'] = image_dicts[i]['image'] * mask 46 | image_dicts[i]['image_degrade'] = image_dicts[i]['image_degrade'] * mask 47 | image_dicts[i]['mask'] = mask.bool() 48 | return image_dicts 49 | 50 | 51 | class HACA3Dataset(Dataset): 52 | def __init__(self, dataset_dirs, contrasts, orientations, mode='train', normalization_method='01'): 53 | self.mode = mode 54 | self.dataset_dirs = dataset_dirs 55 | self.contrasts = contrasts 56 | self.orientations = orientations 57 | self.t1_paths, self.site_ids = self._get_file_paths() 58 | self.normalization_method = normalization_method 59 | 60 | def _get_file_paths(self): 61 | fpaths, site_ids = [], [] 62 | for site_id, dataset_dir in enumerate(self.dataset_dirs): 63 | for orientation in self.orientations: 64 | t1_paths = sorted(glob(os.path.join(dataset_dir, self.mode, f'*T1PRE*{orientation.upper()}*nii.gz'))) 65 | fpaths += t1_paths 66 | site_ids += [site_id] * len(t1_paths) 67 | return fpaths, site_ids 68 | 69 | def __len__(self): 70 | return len(self.t1_paths) 71 | 72 | def __getitem__(self, idx: int): 73 | image_dicts = [] 74 | for contrast_id, contrast_name in enumerate(contrast_names): 75 | image_path = self.t1_paths[idx].replace('T1PRE', contrast_name) 76 | image = get_tensor_from_fpath(image_path, self.normalization_method) 77 | image_degrade = degradation_transform(image.unsqueeze(1)).squeeze(1) 78 | site_id = self.site_ids[idx] 79 | image_dict = {'image': image, 80 | 'image_degrade': image_degrade, 81 | 'site_id': site_id, 82 | 'contrast_id': contrast_id, 83 | 'exists': 0 if image[0, 0, 0] > 0.9999 else 1} 84 | image_dicts.append(image_dict) 85 | return background_removal(image_dicts) 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### JetBrains template 2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 4 | 5 | # User-specific stuff 6 | .idea/ 7 | 8 | ### macOS template 9 | # General 10 | .DS_Store 11 | .AppleDouble 12 | .LSOverride 13 | 14 | # Icon must end with two \r 15 | Icon 16 | 17 | # Thumbnails 18 | ._* 19 | 20 | # Files that might appear in the root of a volume 21 | .DocumentRevisions-V100 22 | .fseventsd 23 | .Spotlight-V100 24 | .TemporaryItems 25 | .Trashes 26 | .VolumeIcon.icns 27 | .com.apple.timemachine.donotpresent 28 | 29 | # Directories 
potentially created on remote AFP share 30 | .AppleDB 31 | .AppleDesktop 32 | Network Trash Folder 33 | Temporary Items 34 | .apdisk 35 | 36 | ### Python template 37 | # Byte-compiled / optimized / DLL files 38 | __pycache__/ 39 | *.py[cod] 40 | *$py.class 41 | 42 | # C extensions 43 | *.so 44 | 45 | # Distribution / packaging 46 | .Python 47 | build/ 48 | develop-eggs/ 49 | dist/ 50 | downloads/ 51 | eggs/ 52 | .eggs/ 53 | lib/ 54 | lib64/ 55 | parts/ 56 | sdist/ 57 | var/ 58 | wheels/ 59 | share/python-wheels/ 60 | *.egg-info/ 61 | .installed.cfg 62 | *.egg 63 | MANIFEST 64 | 65 | # PyInstaller 66 | # Usually these files are written by a python script from a template 67 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .nox/ 79 | .coverage 80 | .coverage.* 81 | .cache 82 | nosetests.xml 83 | coverage.xml 84 | *.cover 85 | *.py,cover 86 | .hypothesis/ 87 | .pytest_cache/ 88 | cover/ 89 | 90 | # Translations 91 | *.mo 92 | *.pot 93 | 94 | # Django stuff: 95 | *.log 96 | local_settings.py 97 | db.sqlite3 98 | db.sqlite3-journal 99 | 100 | # Flask stuff: 101 | instance/ 102 | .webassets-cache 103 | 104 | # Scrapy stuff: 105 | .scrapy 106 | 107 | # Sphinx documentation 108 | docs/_build/ 109 | 110 | # PyBuilder 111 | .pybuilder/ 112 | target/ 113 | 114 | # Jupyter Notebook 115 | .ipynb_checkpoints 116 | 117 | # IPython 118 | profile_default/ 119 | ipython_config.py 120 | 121 | # pyenv 122 | # For a library or package, you might want to ignore these files since the code is 123 | # intended to run in multiple environments; otherwise, check them in: 124 | # .python-version 125 | 126 | # pipenv 127 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 128 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 129 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 130 | # install all needed dependencies. 131 | #Pipfile.lock 132 | 133 | # poetry 134 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 135 | # This is especially recommended for binary packages to ensure reproducibility, and is more 136 | # commonly ignored for libraries. 137 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 138 | #poetry.lock 139 | 140 | # pdm 141 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 142 | #pdm.lock 143 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 144 | # in version control. 145 | # https://pdm.fming.dev/#use-with-ide 146 | .pdm.toml 147 | 148 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 149 | __pypackages__/ 150 | 151 | # Celery stuff 152 | celerybeat-schedule 153 | celerybeat.pid 154 | 155 | # SageMath parsed files 156 | *.sage.py 157 | 158 | # Environments 159 | .env 160 | .venv 161 | env/ 162 | venv/ 163 | ENV/ 164 | env.bak/ 165 | venv.bak/ 166 | 167 | # Spyder project settings 168 | .spyderproject 169 | .spyproject 170 | 171 | # Rope project settings 172 | .ropeproject 173 | 174 | # mkdocs documentation 175 | /site 176 | 177 | # mypy 178 | .mypy_cache/ 179 | .dmypy.json 180 | dmypy.json 181 | 182 | # Pyre type checker 183 | .pyre/ 184 | 185 | # pytype static type analyzer 186 | .pytype/ 187 | 188 | # Cython debug symbols 189 | cython_debug/ 190 | 191 | # PyCharm 192 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 193 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 194 | # and can be added to the global gitignore or merged into this file. For a more nuclear 195 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 196 | #.idea/ 197 | 198 | -------------------------------------------------------------------------------- /haca3/modules/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import errno 6 | import nibabel as nib 7 | from torchvision import utils 8 | import torchvision.models as models 9 | import numpy as np 10 | 11 | 12 | def mkdir_p(path): 13 | try: 14 | os.makedirs(path) 15 | except OSError as exc: 16 | if exc.errno == errno.EEXIST and os.path.isdir(path): 17 | pass 18 | else: 19 | raise 20 | 21 | 22 | def reparameterize_logit(logit): 23 | import warnings 24 | warnings.filterwarnings('ignore', message='.*Mixed memory format inputs detected.*') 25 | beta = F.gumbel_softmax(logit, tau=1.0, dim=1, hard=True) 26 | return beta 27 | 28 | 29 | def save_image(images, file_name): 30 | image_save = torch.cat([image[:4, [0], ...].cpu() for image in images], dim=0) 31 | image_save = utils.make_grid(tensor=image_save, nrow=4, normalize=False, range=(0, 1)).detach().numpy()[0, ...] 32 | image_save = nib.Nifti1Image(image_save.transpose(1, 0), np.eye(4)) 33 | nib.save(image_save, file_name) 34 | 35 | 36 | def dropout_contrasts(available_contrast_id, contrast_id_to_drop=None): 37 | """ 38 | Randomly dropout contrasts for HACA3 training. 39 | 40 | ==INPUTS==j 41 | * available_contrast_id: torch.Tensor (batch_size, num_contrasts) 42 | Indicates the availability of each MR contrast. 1: if available, 0: if unavailable. 43 | 44 | * contrast_id_to_drop: torch.Tensor (batch_size, num_contrasts) 45 | If provided, indicates the contrast indexes forced to drop. Default: None 46 | 47 | ==OUTPUTS== 48 | * contrast_id_after_dropout: torch.Tensor (batch_size, num_contrasts) 49 | Some available contrasts will be randomly dropped out (as if they are unavailable). 50 | However, each sample will have at least one contrast available. 
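    ==EXAMPLE==
    Illustrative only (values are hypothetical):
        available = torch.tensor([[1., 1., 0., 1.],
                                  [1., 0., 0., 0.]])
        kept = dropout_contrasts(available)
    Each row of 'kept' keeps at least one of its available (==1) contrasts; the
    second row is returned unchanged because only one contrast is available.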
51 | """ 52 | batch_size = available_contrast_id.shape[0] 53 | if contrast_id_to_drop is not None: 54 | available_contrast_id = available_contrast_id - contrast_id_to_drop 55 | contrast_id_after_dropout = available_contrast_id.clone() 56 | for i in range(batch_size): 57 | available_contrast_ids_per_subject = (available_contrast_id[i] == 1).nonzero(as_tuple=False).squeeze(1) 58 | num_available_contrasts = available_contrast_ids_per_subject.numel() 59 | if num_available_contrasts > 1: 60 | num_contrast_to_drop = torch.randperm(num_available_contrasts - 1)[0] 61 | contrast_ids_to_drop = torch.randperm(num_available_contrasts)[:num_contrast_to_drop] 62 | contrast_ids_to_drop = available_contrast_ids_per_subject[contrast_ids_to_drop] 63 | contrast_id_after_dropout[i, contrast_ids_to_drop] = 0.0 64 | return contrast_id_after_dropout 65 | 66 | 67 | class PerceptualLoss(nn.Module): 68 | def __init__(self, vgg_model): 69 | super().__init__() 70 | for param in vgg_model.parameters(): 71 | param.requires_grad = False 72 | self.vgg = nn.Sequential(*list(vgg_model.children())[:13]).eval() 73 | 74 | def forward(self, x, y): 75 | if x.shape[1] == 1: 76 | x = x.repeat(1, 3, 1, 1) 77 | if y.shape[1] == 1: 78 | y = y.repeat(1, 3, 1, 1) 79 | return F.l1_loss(self.vgg(x), self.vgg(y)) 80 | 81 | 82 | class PatchNCELoss(nn.Module): 83 | def __init__(self, temperature=0.1): 84 | super().__init__() 85 | self.ce_loss = nn.CrossEntropyLoss(reduction='none') 86 | self.temperature = temperature 87 | 88 | def forward(self, query_feature, positive_feature, negative_feature): 89 | B, C, N = query_feature.shape 90 | 91 | l_positive = (query_feature * positive_feature).sum(dim=1)[:, :, None] 92 | l_negative = torch.bmm(query_feature.permute(0, 2, 1), negative_feature) 93 | 94 | logits = torch.cat((l_positive, l_negative), dim=2) / self.temperature 95 | 96 | predictions = logits.flatten(0, 1) 97 | targets = torch.zeros(B * N, dtype=torch.long).to(query_feature.device) 98 | return self.ce_loss(predictions, targets).mean() 99 | 100 | 101 | class KLDivergenceLoss(nn.Module): 102 | def __init__(self): 103 | super().__init__() 104 | 105 | def forward(self, mu, logvar): 106 | kld_loss = -0.5 * logvar + 0.5 * (torch.exp(logvar) + torch.pow(mu, 2)) - 0.5 107 | return kld_loss 108 | 109 | 110 | def divide_into_batches(in_tensor, num_batches): 111 | batch_size = in_tensor.shape[0] // num_batches 112 | remainder = in_tensor.shape[0] % num_batches 113 | batches = [] 114 | 115 | current_start = 0 116 | for i in range(num_batches): 117 | current_end = current_start + batch_size 118 | if remainder: 119 | current_end += 1 120 | remainder -= 1 121 | batches.append(in_tensor[current_start:current_end, ...]) 122 | current_start = current_end 123 | return batches 124 | 125 | 126 | def normalize_intensity(image): 127 | thresh = np.percentile(image.flatten(), 95) 128 | image = image / (thresh + 1e-5) 129 | image = np.clip(image, a_min=0.0, a_max=5.0) 130 | return image, thresh 131 | 132 | 133 | def zero_pad(image, image_dim=256): 134 | [n_row, n_col, n_slc] = image.shape 135 | image_padded = np.zeros((image_dim, image_dim, image_dim)) 136 | center_loc = image_dim // 2 137 | image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2, 138 | center_loc - n_col // 2: center_loc + n_col - n_col // 2, 139 | center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2] = image 140 | return image_padded 141 | 142 | def zero_pad2d(image, image_dim=256): 143 | [n_row, n_col] = image.shape 144 | image_padded = np.zeros((image_dim, image_dim)) 145 | 
center_loc = image_dim // 2 146 | image_padded[center_loc - n_row // 2: center_loc + n_row - n_row // 2, 147 | center_loc - n_col // 2: center_loc + n_col - n_col // 2] = image 148 | return image_padded 149 | 150 | 151 | def crop(image, n_row, n_col, n_slc): 152 | image_dim = image.shape[0] 153 | center_loc = image_dim // 2 154 | return image[center_loc - n_row // 2: center_loc + n_row - n_row // 2, 155 | center_loc - n_col // 2: center_loc + n_col - n_col // 2, 156 | center_loc - n_slc // 2: center_loc + n_slc - n_slc // 2] 157 | 158 | def crop2d(image, n_row, n_col): 159 | image_dim = image.shape[0] 160 | center_loc = image_dim // 2 161 | return image[center_loc - n_row // 2: center_loc + n_row - n_row // 2, 162 | center_loc - n_col // 2: center_loc + n_col - n_col // 2] 163 | -------------------------------------------------------------------------------- /haca3/_version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # This file is part of 'miniver': https://github.com/jbweston/miniver 3 | # 4 | from collections import namedtuple 5 | import os 6 | import subprocess 7 | 8 | from setuptools.command.build_py import build_py as build_py_orig 9 | from setuptools.command.sdist import sdist as sdist_orig 10 | 11 | Version = namedtuple("Version", ("release", "dev", "labels")) 12 | 13 | # No public API 14 | __all__ = [] 15 | 16 | package_root = os.path.dirname(os.path.realpath(__file__)) 17 | package_name = os.path.basename(package_root) 18 | 19 | STATIC_VERSION_FILE = "_static_version.py" 20 | 21 | 22 | def get_version(version_file=STATIC_VERSION_FILE): 23 | version_info = get_static_version_info(version_file) 24 | version = version_info["version"] 25 | if version == "__use_git__": 26 | version = get_version_from_git() 27 | if not version: 28 | version = get_version_from_git_archive(version_info) 29 | if not version: 30 | version = Version("unknown", None, None) 31 | return pep440_format(version) 32 | else: 33 | return version 34 | 35 | 36 | def get_static_version_info(version_file=STATIC_VERSION_FILE): 37 | version_info = {} 38 | with open(os.path.join(package_root, version_file), "rb") as f: 39 | exec(f.read(), {}, version_info) 40 | return version_info 41 | 42 | 43 | def version_is_from_git(version_file=STATIC_VERSION_FILE): 44 | return get_static_version_info(version_file)["version"] == "__use_git__" 45 | 46 | 47 | def pep440_format(version_info): 48 | release, dev, labels = version_info 49 | 50 | version_parts = [release] 51 | if dev: 52 | if release.endswith("-dev") or release.endswith(".dev"): 53 | version_parts.append(dev) 54 | else: # prefer PEP440 over strict adhesion to semver 55 | version_parts.append(".dev{}".format(dev)) 56 | 57 | if labels: 58 | version_parts.append("+") 59 | version_parts.append(".".join(labels)) 60 | 61 | return "".join(version_parts) 62 | 63 | 64 | def get_version_from_git(): 65 | # git describe --first-parent does not take into account tags from branches 66 | # that were merged-in. The '--long' flag gets us the 'dev' version and 67 | # git hash, '--always' returns the git hash even if there are no tags. 
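    # For example (hypothetical output): 'v1.2.0-3-gabc1234' is parsed below into
    # release='1.2.0', dev='3' (commits since the tag), and git='gabc1234'.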
68 | for opts in [["--first-parent"], []]: 69 | try: 70 | p = subprocess.Popen( 71 | ["git", "describe", "--long", "--always"] + opts, 72 | cwd=package_root, 73 | stdout=subprocess.PIPE, 74 | stderr=subprocess.PIPE, 75 | ) 76 | except OSError: 77 | return 78 | if p.wait() == 0: 79 | break 80 | else: 81 | return 82 | 83 | description = ( 84 | p.communicate()[0] 85 | .decode() 86 | .strip("v") # Tags can have a leading 'v', but the version should not 87 | .rstrip("\n") 88 | .rsplit("-", 2) # Split the latest tag, commits since tag, and hash 89 | ) 90 | 91 | try: 92 | release, dev, git = description 93 | except ValueError: # No tags, only the git hash 94 | # prepend 'g' to match with format returned by 'git describe' 95 | git = "g{}".format(*description) 96 | release = "unknown" 97 | dev = None 98 | 99 | labels = [] 100 | if dev == "0": 101 | dev = None 102 | else: 103 | labels.append(git) 104 | 105 | try: 106 | p = subprocess.Popen(["git", "diff", "--quiet"], cwd=package_root) 107 | except OSError: 108 | labels.append("confused") # This should never happen. 109 | else: 110 | if p.wait() == 1: 111 | labels.append("dirty") 112 | 113 | return Version(release, dev, labels) 114 | 115 | 116 | # TODO: change this logic when there is a git pretty-format 117 | # that gives the same output as 'git describe'. 118 | # Currently we can only tell the tag the current commit is 119 | # pointing to, or its hash (with no version info) 120 | # if it is not tagged. 121 | def get_version_from_git_archive(version_info): 122 | try: 123 | refnames = version_info["refnames"] 124 | git_hash = version_info["git_hash"] 125 | except KeyError: 126 | # These fields are not present if we are running from an sdist. 127 | # Execution should never reach here, though 128 | return None 129 | 130 | if git_hash.startswith("$Format") or refnames.startswith("$Format"): 131 | # variables not expanded during 'git archive' 132 | return None 133 | 134 | VTAG = "tag: v" 135 | refs = set(r.strip() for r in refnames.split(",")) 136 | version_tags = set(r[len(VTAG) :] for r in refs if r.startswith(VTAG)) 137 | if version_tags: 138 | release, *_ = sorted(version_tags) # prefer e.g. "2.0" over "2.0rc1" 139 | return Version(release, dev=None, labels=None) 140 | else: 141 | return Version("unknown", dev=None, labels=["g{}".format(git_hash)]) 142 | 143 | 144 | __version__ = get_version() 145 | 146 | 147 | # The following section defines a 'get_cmdclass' function 148 | # that can be used from setup.py. The '__version__' module 149 | # global is used (but not modified). 150 | 151 | 152 | def _write_version(fname): 153 | # This could be a hard link, so try to delete it first. Is there any way 154 | # to do this atomically together with opening? 
155 | try: 156 | os.remove(fname) 157 | except OSError: 158 | pass 159 | with open(fname, "w") as f: 160 | f.write( 161 | "# This file has been created by setup.py.\n" 162 | "version = '{}'\n".format(__version__) 163 | ) 164 | 165 | 166 | def get_cmdclass(pkg_source_path): 167 | class _build_py(build_py_orig): 168 | def run(self): 169 | super().run() 170 | 171 | src_marker = "".join(["src", os.path.sep]) 172 | 173 | if pkg_source_path.startswith(src_marker): 174 | path = pkg_source_path[len(src_marker):] 175 | else: 176 | path = pkg_source_path 177 | _write_version( 178 | os.path.join( 179 | self.build_lib, path, STATIC_VERSION_FILE 180 | ) 181 | ) 182 | 183 | class _sdist(sdist_orig): 184 | def make_release_tree(self, base_dir, files): 185 | super().make_release_tree(base_dir, files) 186 | _write_version( 187 | os.path.join(base_dir, pkg_source_path, STATIC_VERSION_FILE) 188 | ) 189 | 190 | return dict(sdist=_sdist, build_py=_build_py) 191 | -------------------------------------------------------------------------------- /haca3/modules/_version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # This file is part of 'miniver': https://github.com/jbweston/miniver 3 | # 4 | from collections import namedtuple 5 | import os 6 | import subprocess 7 | 8 | from setuptools.command.build_py import build_py as build_py_orig 9 | from setuptools.command.sdist import sdist as sdist_orig 10 | 11 | Version = namedtuple("Version", ("release", "dev", "labels")) 12 | 13 | # No public API 14 | __all__ = [] 15 | 16 | package_root = os.path.dirname(os.path.realpath(__file__)) 17 | package_name = os.path.basename(package_root) 18 | 19 | STATIC_VERSION_FILE = "_static_version.py" 20 | 21 | 22 | def get_version(version_file=STATIC_VERSION_FILE): 23 | version_info = get_static_version_info(version_file) 24 | version = version_info["version"] 25 | if version == "__use_git__": 26 | version = get_version_from_git() 27 | if not version: 28 | version = get_version_from_git_archive(version_info) 29 | if not version: 30 | version = Version("unknown", None, None) 31 | return pep440_format(version) 32 | else: 33 | return version 34 | 35 | 36 | def get_static_version_info(version_file=STATIC_VERSION_FILE): 37 | version_info = {} 38 | with open(os.path.join(package_root, version_file), "rb") as f: 39 | exec(f.read(), {}, version_info) 40 | return version_info 41 | 42 | 43 | def version_is_from_git(version_file=STATIC_VERSION_FILE): 44 | return get_static_version_info(version_file)["version"] == "__use_git__" 45 | 46 | 47 | def pep440_format(version_info): 48 | release, dev, labels = version_info 49 | 50 | version_parts = [release] 51 | if dev: 52 | if release.endswith("-dev") or release.endswith(".dev"): 53 | version_parts.append(dev) 54 | else: # prefer PEP440 over strict adhesion to semver 55 | version_parts.append(".dev{}".format(dev)) 56 | 57 | if labels: 58 | version_parts.append("+") 59 | version_parts.append(".".join(labels)) 60 | 61 | return "".join(version_parts) 62 | 63 | 64 | def get_version_from_git(): 65 | # git describe --first-parent does not take into account tags from branches 66 | # that were merged-in. The '--long' flag gets us the 'dev' version and 67 | # git hash, '--always' returns the git hash even if there are no tags. 
68 | for opts in [["--first-parent"], []]: 69 | try: 70 | p = subprocess.Popen( 71 | ["git", "describe", "--long", "--always"] + opts, 72 | cwd=package_root, 73 | stdout=subprocess.PIPE, 74 | stderr=subprocess.PIPE, 75 | ) 76 | except OSError: 77 | return 78 | if p.wait() == 0: 79 | break 80 | else: 81 | return 82 | 83 | description = ( 84 | p.communicate()[0] 85 | .decode() 86 | .strip("v") # Tags can have a leading 'v', but the version should not 87 | .rstrip("\n") 88 | .rsplit("-", 2) # Split the latest tag, commits since tag, and hash 89 | ) 90 | 91 | try: 92 | release, dev, git = description 93 | except ValueError: # No tags, only the git hash 94 | # prepend 'g' to match with format returned by 'git describe' 95 | git = "g{}".format(*description) 96 | release = "unknown" 97 | dev = None 98 | 99 | labels = [] 100 | if dev == "0": 101 | dev = None 102 | else: 103 | labels.append(git) 104 | 105 | try: 106 | p = subprocess.Popen(["git", "diff", "--quiet"], cwd=package_root) 107 | except OSError: 108 | labels.append("confused") # This should never happen. 109 | else: 110 | if p.wait() == 1: 111 | labels.append("dirty") 112 | 113 | return Version(release, dev, labels) 114 | 115 | 116 | # TODO: change this logic when there is a git pretty-format 117 | # that gives the same output as 'git describe'. 118 | # Currently we can only tell the tag the current commit is 119 | # pointing to, or its hash (with no version info) 120 | # if it is not tagged. 121 | def get_version_from_git_archive(version_info): 122 | try: 123 | refnames = version_info["refnames"] 124 | git_hash = version_info["git_hash"] 125 | except KeyError: 126 | # These fields are not present if we are running from an sdist. 127 | # Execution should never reach here, though 128 | return None 129 | 130 | if git_hash.startswith("$Format") or refnames.startswith("$Format"): 131 | # variables not expanded during 'git archive' 132 | return None 133 | 134 | VTAG = "tag: v" 135 | refs = set(r.strip() for r in refnames.split(",")) 136 | version_tags = set(r[len(VTAG) :] for r in refs if r.startswith(VTAG)) 137 | if version_tags: 138 | release, *_ = sorted(version_tags) # prefer e.g. "2.0" over "2.0rc1" 139 | return Version(release, dev=None, labels=None) 140 | else: 141 | return Version("unknown", dev=None, labels=["g{}".format(git_hash)]) 142 | 143 | 144 | __version__ = get_version() 145 | 146 | 147 | # The following section defines a 'get_cmdclass' function 148 | # that can be used from setup.py. The '__version__' module 149 | # global is used (but not modified). 150 | 151 | 152 | def _write_version(fname): 153 | # This could be a hard link, so try to delete it first. Is there any way 154 | # to do this atomically together with opening? 
155 | try: 156 | os.remove(fname) 157 | except OSError: 158 | pass 159 | with open(fname, "w") as f: 160 | f.write( 161 | "# This file has been created by setup.py.\n" 162 | "version = '{}'\n".format(__version__) 163 | ) 164 | 165 | 166 | def get_cmdclass(pkg_source_path): 167 | class _build_py(build_py_orig): 168 | def run(self): 169 | super().run() 170 | 171 | src_marker = "".join(["src", os.path.sep]) 172 | 173 | if pkg_source_path.startswith(src_marker): 174 | path = pkg_source_path[len(src_marker):] 175 | else: 176 | path = pkg_source_path 177 | _write_version( 178 | os.path.join( 179 | self.build_lib, path, STATIC_VERSION_FILE 180 | ) 181 | ) 182 | 183 | class _sdist(sdist_orig): 184 | def make_release_tree(self, base_dir, files): 185 | super().make_release_tree(base_dir, files) 186 | _write_version( 187 | os.path.join(base_dir, pkg_source_path, STATIC_VERSION_FILE) 188 | ) 189 | 190 | return dict(sdist=_sdist, build_py=_build_py) 191 | -------------------------------------------------------------------------------- /haca3/test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from pathlib import Path 4 | 5 | import nibabel as nib 6 | import numpy as np 7 | import torch 8 | from torchvision.transforms import ToTensor 9 | 10 | from skimage.filters import threshold_otsu 11 | from skimage.morphology import isotropic_closing 12 | 13 | from .modules.model import HACA3 14 | from .modules.utils import * 15 | 16 | 17 | def background_removal(image_vol): 18 | [n_row, n_col, n_slc] = image_vol.shape 19 | thresh = threshold_otsu(image_vol) 20 | mask = (image_vol >= thresh) * 1.0 21 | mask = zero_pad(mask, 256) 22 | mask = isotropic_closing(mask, radius=20) 23 | mask = crop(mask, n_row, n_col, n_slc) 24 | image_vol[mask < 1e-4] = 0.0 25 | return image_vol 26 | 27 | def background_removal2d(image_vol): 28 | [n_row, n_col] = image_vol.shape 29 | thresh = threshold_otsu(image_vol) 30 | mask = (image_vol >= thresh) * 1.0 31 | mask = zero_pad2d(mask, 256) 32 | mask = isotropic_closing(mask, radius=20) 33 | mask = crop2d(mask, n_row, n_col) 34 | image_vol[mask < 1e-4] = 0.0 35 | return image_vol 36 | 37 | def obtain_single_image(image_path, bg_removal=True): 38 | image_obj = nib.Nifti1Image.from_filename(image_path) 39 | image_vol = np.array(image_obj.get_fdata().astype(np.float32)) 40 | thresh = np.percentile(image_vol.flatten(), 95) 41 | image_vol = image_vol / (thresh + 1e-5) 42 | image_vol = np.clip(image_vol, a_min=0.0, a_max=5.0) 43 | if bg_removal: 44 | image_vol = background_removal(image_vol) 45 | 46 | n_row, n_col, n_slc = image_vol.shape 47 | # zero padding 48 | image_padded = np.zeros((224, 224, 224)).astype(np.float32) 49 | image_padded[112 - n_row // 2:112 + n_row // 2 + n_row % 2, 50 | 112 - n_col // 2:112 + n_col // 2 + n_col % 2, 51 | 112 - n_slc // 2:112 + n_slc // 2 + n_slc % 2] = image_vol 52 | return ToTensor()(image_padded), image_obj.header, thresh 53 | 54 | 55 | def load_source_images(image_paths, bg_removal=True): 56 | source_images = [] 57 | image_header = None 58 | for image_path in image_paths: 59 | image_vol, image_header, _ = obtain_single_image(image_path, bg_removal) 60 | source_images.append(image_vol.float().permute(2, 1, 0)) 61 | return source_images, image_header 62 | 63 | 64 | def main(args=None): 65 | args = sys.argv[1:] if args is None else args 66 | parser = argparse.ArgumentParser(description='Harmonization with HACA3.') 67 | parser.add_argument('--in-path', type=Path, action='append', 
required=True) 68 | parser.add_argument('--target-image', type=Path, action='append', default=[]) 69 | parser.add_argument('--target-theta', type=float, nargs=2, action='append', default=[]) 70 | parser.add_argument('--target-eta', type=float, nargs=2, action='append', default=[]) 71 | parser.add_argument('--norm-val', type=float, action='append', default=[]) 72 | parser.add_argument('--out-path', type=Path, action='append', required=True) 73 | parser.add_argument('--harmonization-model', type=Path, required=True) 74 | parser.add_argument('--fusion-model', type=Path) 75 | parser.add_argument('--beta-dim', type=int, default=5) 76 | parser.add_argument('--theta-dim', type=int, default=2) 77 | parser.add_argument('--save-intermediate', action='store_true', default=False) 78 | parser.add_argument('--intermediate-out-dir', type=Path, default=Path.cwd()) 79 | parser.add_argument('--no-bg-removal', dest='bg_removal', action='store_false', default=True) 80 | parser.add_argument('--gpu-id', type=int, default=0) 81 | parser.add_argument('--num-batches', type=int, default=4) 82 | 83 | args = parser.parse_args(args) 84 | 85 | text_div = '=' * 10 86 | print(f'{text_div} BEGIN HACA3 HARMONIZATION {text_div}') 87 | 88 | # ==== GET ABSOLUTE PATHS ==== 89 | for argname in ['in_path', 'target_image', 'out_path', 'harmonization_model', 90 | 'fusion_model', 'intermediate_out_dir']: 91 | if isinstance(getattr(args, argname), list): 92 | setattr(args, argname, [path.resolve() for path in getattr(args, argname)]) 93 | else: 94 | setattr(args, argname, getattr(args, argname).resolve()) 95 | 96 | # ==== SET DEFAULT FOR NORM/ETA ==== 97 | if len(args.target_theta) > 0 and len(args.target_eta) == 0: 98 | args.target_eta = [[0.3, 0.5]] 99 | if len(args.target_theta) > 0 and len(args.norm_val) == 0: 100 | args.norm_val = [1000] 101 | 102 | # ==== CHECK CONDITIONS OF INPUT ARGUMENTS ==== 103 | if not ((len(args.target_image) > 0) ^ (len(args.target_theta) > 0)): 104 | parser.error("'--target-image' or '--target-theta' value should be provided.") 105 | 106 | if 0 < len(args.target_image) != len(args.out_path): 107 | parser.error("Number of '--out-path' and '--target-image' options should be the same.") 108 | 109 | if len(args.target_theta) == 1 and len(args.target_eta) > 1: 110 | args.target_theta = args.target_theta * len(args.target_eta) 111 | 112 | if len(args.target_theta) > 1 and len(args.target_eta) == 1: 113 | args.target_eta = args.target_eta * len(args.target_theta) 114 | 115 | if len(args.target_theta) > 1 and len(args.norm_val) == 1: 116 | args.norm_val = args.norm_val * len(args.target_theta) 117 | 118 | if 0 < len(args.target_theta) != len(args.target_eta): 119 | parser.error("Number of '--target-theta' and '--target-eta' options should be the same.") 120 | 121 | if 0 < len(args.target_theta) != len(args.norm_val): 122 | parser.error("Number of '--target-theta' and '--norm-val' options should be the same.") 123 | 124 | if 0 < len(args.target_theta) != len(args.out_path): 125 | parser.error("Number of '--target-theta' and '--out-path' options should be the same.") 126 | 127 | if args.save_intermediate: 128 | mkdir_p(args.intermediate_out_dir) 129 | 130 | # ==== INITIALIZE MODEL ==== 131 | haca3 = HACA3(beta_dim=args.beta_dim, 132 | theta_dim=args.theta_dim, 133 | eta_dim=2, 134 | pretrained_haca3=args.harmonization_model, 135 | gpu_id=args.gpu_id) 136 | 137 | # ==== LOAD SOURCE IMAGES ==== 138 | source_images, image_header = load_source_images(args.in_path, args.bg_removal) 139 | 140 | # ==== LOAD TARGET 
IMAGES IF PROVIDED ==== 141 | if len(args.target_image) > 0: 142 | target_images, norm_vals = [], [] 143 | for target_image_path, out_path in zip(args.target_image, args.out_path): 144 | target_image_tmp, tmp_header, norm_val = obtain_single_image(target_image_path, args.bg_removal) 145 | target_images.append(target_image_tmp.permute(2, 1, 0).permute(0, 2, 1).flip(1)[100:120, ...]) 146 | norm_vals.append(norm_val) 147 | if args.save_intermediate: 148 | out_prefix = out_path.name.replace('.nii.gz', '') 149 | save_img = target_image_tmp.permute(1, 2, 0).numpy()[112 - 96:112 + 96, :, 112 - 96:112 + 96] 150 | target_obj = nib.Nifti1Image(save_img * norm_val, None, tmp_header) 151 | target_obj.to_filename(args.intermediate_out_dir / f'{out_prefix}_target.nii.gz') 152 | if args.save_intermediate: 153 | out_prefix = args.out_path[0].name.replace('.nii.gz', '') 154 | with open(args.intermediate_out_dir / f'{out_prefix}_targetnorms.txt', 'w') as fp: 155 | fp.write('image,norm_val\n') 156 | for i, norm_val in enumerate(norm_vals): 157 | fp.write(f'{i},{norm_val:.6f}\n') 158 | np.savetxt(args.intermediate_out_dir / f'{out_prefix}_targetnorms.txt', norm_vals) 159 | target_theta = None 160 | target_eta = None 161 | else: 162 | target_images = None 163 | target_theta = torch.as_tensor(args.target_theta, dtype=torch.float32) 164 | target_eta = torch.as_tensor(args.target_eta, dtype=torch.float32) 165 | norm_vals = args.norm_val 166 | 167 | # ===== BEGIN HARMONIZATION WITH HACA3 ===== 168 | haca3.harmonize( 169 | source_images=[image.permute(2, 0, 1) for image in source_images], 170 | target_images=target_images, 171 | target_theta=target_theta, 172 | target_eta=target_eta, 173 | out_paths=args.out_path, 174 | header=image_header, 175 | recon_orientation='axial', 176 | norm_vals=norm_vals, 177 | num_batches=args.num_batches, 178 | save_intermediate=args.save_intermediate, 179 | intermediate_out_dir=args.intermediate_out_dir, 180 | ) 181 | 182 | haca3.harmonize( 183 | source_images=[image.permute(0, 2, 1).flip(1) for image in source_images], 184 | target_images=target_images, 185 | target_theta=target_theta, 186 | target_eta=target_eta, 187 | out_paths=args.out_path, 188 | header=image_header, 189 | recon_orientation='coronal', 190 | norm_vals=norm_vals, 191 | num_batches=args.num_batches, 192 | save_intermediate=args.save_intermediate, 193 | intermediate_out_dir=args.intermediate_out_dir, 194 | ) 195 | 196 | haca3.harmonize( 197 | source_images=[image.permute(1, 2, 0).flip(1) for image in source_images], 198 | target_images=target_images, 199 | target_theta=target_theta, 200 | target_eta=target_eta, 201 | out_paths=args.out_path, 202 | header=image_header, 203 | recon_orientation='sagittal', 204 | norm_vals=norm_vals, 205 | num_batches=args.num_batches, 206 | save_intermediate=args.save_intermediate, 207 | intermediate_out_dir=args.intermediate_out_dir, 208 | ) 209 | 210 | print(f'{text_div} START FUSION {text_div}') 211 | for out_path, norm_val in zip(args.out_path, norm_vals): 212 | prefix = out_path.name.replace('.nii.gz', '') 213 | decode_img_paths = [out_path.parent / f'{prefix}_harmonized_{orient}.nii.gz' 214 | for orient in ['axial', 'coronal', 'sagittal']] 215 | haca3.combine_images(decode_img_paths, out_path, norm_val, args.fusion_model) 216 | -------------------------------------------------------------------------------- /haca3/modules/fusion_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 
4 | from glob import glob 5 | 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from torchvision.transforms import ToTensor 10 | from torch.utils.data.dataset import Dataset 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torchvision import utils 13 | import torch.nn.functional as F 14 | import nibabel as nib 15 | from .utils import mkdir_p 16 | 17 | class Net(nn.Module): 18 | def __init__(self, in_ch=3, out_ch=1): 19 | super().__init__() 20 | self.conv1 = nn.Sequential( 21 | nn.Conv3d(in_ch, 8, 3, 1, 1), 22 | nn.InstanceNorm3d(8), 23 | nn.LeakyReLU(), 24 | nn.Conv3d(8, 16, 3, 1, 1), 25 | nn.InstanceNorm3d(16), 26 | nn.LeakyReLU()) 27 | self.conv2 = nn.Sequential( 28 | nn.Conv3d(in_ch+16, 16, 3, 1, 1), 29 | #nn.InstanceNorm3d(16), 30 | nn.LeakyReLU(), 31 | nn.Conv3d(16, out_ch, 3, 1, 1), 32 | nn.ReLU()) 33 | 34 | def forward(self, x): 35 | #return self.conv2(x + self.conv1(x)) 36 | return self.conv2(torch.cat([x, self.conv1(x)], dim=1)) 37 | 38 | class MultiOrientationDataset(Dataset): 39 | def __init__(self, dataset_dirs): 40 | self.dataset_dirs = dataset_dirs 41 | self.imgs = self._get_files() 42 | 43 | def _get_files(self): 44 | img_paths = [] 45 | contrast_names = ['T1', 'T2', 'PD', 'FLAIR'] 46 | for dataset_dir in self.dataset_dirs: 47 | for contrast_name in contrast_names: 48 | img_path_tmp = os.path.join(dataset_dir, f'*harmonized_to_{contrast_name}_ori.nii.gz') 49 | img_path_tmp = sorted(glob(img_path_tmp)) 50 | img_paths = img_paths + img_path_tmp 51 | return img_paths 52 | 53 | def __len__(self): 54 | return len(self.imgs) 55 | 56 | def get_tensor_from_path(self, img_path, if_norm_val=False): 57 | img = nib.load(img_path).get_fdata().astype(np.float32) 58 | img = ToTensor()(img) 59 | img = img.permute(2,1,0).permute(2,0,1).unsqueeze(0) 60 | img, norm_val = self.normalize_intensity(img) 61 | if if_norm_val: 62 | return img, norm_val 63 | else: 64 | return img 65 | 66 | def normalize_intensity(self, image): 67 | p99 = np.percentile(image.flatten(), 99) 68 | image = np.clip(image, a_min=0.0, a_max=p99) 69 | image = image / p99 70 | return image, p99 71 | 72 | def __getitem__(self, idx:int): 73 | img_path = self.imgs[idx] 74 | str_id = img_path.find('_ori') 75 | axial_img_path = img_path[:str_id] + '_axial.nii.gz' 76 | coronal_img_path = img_path[:str_id] + '_coronal.nii.gz' 77 | sagittal_img_path = img_path[:str_id] + '_sagittal.nii.gz' 78 | ori_image, norm_val = self.get_tensor_from_path(img_path, if_norm_val=True) 79 | img_dict = {'ori_img' : ori_image, 80 | 'axial_img' : self.get_tensor_from_path(axial_img_path), 81 | 'coronal_img' : self.get_tensor_from_path(coronal_img_path), 82 | 'sagittal_img' : self.get_tensor_from_path(sagittal_img_path), 83 | 'norm_val' : norm_val 84 | } 85 | 86 | return img_dict 87 | 88 | 89 | class FusionNet: 90 | def __init__(self, pretrained_model=None, gpu=0): 91 | self.device = torch.device('cuda:0' if gpu==0 else'cuda:1') 92 | 93 | # define networks 94 | self.fusion_net = Net(in_ch=3, out_ch=1) 95 | 96 | # initialize training variables 97 | self.train_loader, self.valid_loader = None, None 98 | self.out_dir = None 99 | self.optim_fusion_net = None 100 | 101 | # load pretrained models 102 | self.checkpoint = None 103 | if pretrained_model is not None: 104 | self.checkpoint = torch.load(pretrained_model, map_location=self.device) 105 | self.fusion_net.load_state_dict(self.checkpoint['fusion_net']) 106 | self.fusion_net.to(self.device) 107 | self.start_epoch = 0 108 | 109 | def load_dataset(self, 
dataset_dirs, batch_size): 110 | all_dataset = MultiOrientationDataset(dataset_dirs) 111 | num_instances = all_dataset.__len__() 112 | num_train = int(0.8 * num_instances) 113 | num_valid = num_instances - num_train 114 | train_dataset, valid_dataset = torch.utils.data.random_split(all_dataset, 115 | [num_train, num_valid]) 116 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 117 | self.valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 118 | 119 | def initialize_training(self, out_dir, lr): 120 | self.out_dir = out_dir 121 | mkdir_p(out_dir) 122 | mkdir_p(os.path.join(out_dir, 'results')) 123 | mkdir_p(os.path.join(out_dir, 'models')) 124 | 125 | # define loss 126 | self.l1_loss = nn.L1Loss(reduction='none') 127 | 128 | self.optim_fusion_net = torch.optim.Adam(self.fusion_net.parameters(), lr=lr) 129 | 130 | if self.checkpoint is not None: 131 | self.start_epoch = self.checkpoint['epoch'] 132 | self.optim_fusion_net.load_state_dict(self.checkpoint['optim_fusion_net']) 133 | self.start_epoch += 1 134 | 135 | def train(self, epochs): 136 | for epoch in range(self.start_epoch, epochs+1): 137 | self.train_loader = tqdm(self.train_loader) 138 | self.fusion_net.train() 139 | train_loss_sum = 0.0 140 | num_train_imgs = 0 141 | for batch_id, img_dict in enumerate(self.train_loader): 142 | syn_img = torch.cat([ 143 | img_dict['axial_img'], 144 | img_dict['coronal_img'], 145 | img_dict['sagittal_img'] 146 | ], dim=1).to(self.device) 147 | ori_img = img_dict['ori_img'].to(self.device) 148 | batch_size = ori_img.shape[0] 149 | 150 | ori_img = ori_img * (syn_img[:,[0],:,:,:] > 1e-8).detach() 151 | 152 | fusion_img = self.fusion_net(syn_img) 153 | 154 | rec_loss = self.l1_loss(fusion_img, ori_img).mean() 155 | self.optim_fusion_net.zero_grad() 156 | rec_loss.backward() 157 | self.optim_fusion_net.step() 158 | 159 | train_loss_sum += rec_loss.item() * batch_size 160 | num_train_imgs += batch_size 161 | self.train_loader.set_description((f'epoch: {epoch}; ' 162 | f'rec: {rec_loss.item():.3f}; ' 163 | f'avg_trn: {train_loss_sum / num_train_imgs:.3f}; ')) 164 | if batch_id % 50 - 1 == 0: 165 | img_affine = [[-1, 0, 0, 96], [0, -1, 0, 96], [0, 0, 1, -78], [0, 0, 0, 1]] 166 | img_save = np.array(fusion_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2)) 167 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine) 168 | file_name = os.path.join(self.out_dir, 'results', f'train_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_syn.nii.gz') 169 | nib.save(img_save, file_name) 170 | 171 | img_save = np.array(ori_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2)) 172 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine) 173 | file_name = os.path.join(self.out_dir, 'results', f'train_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_ori.nii.gz') 174 | nib.save(img_save, file_name) 175 | 176 | # save models 177 | if batch_id % 100 == 0: 178 | file_name = os.path.join(self.out_dir, 'models', 179 | f'epoch{str(epoch).zfill(3)}_batch{str(batch_id).zfill(4)}.pt') 180 | self.save_model(file_name, epoch) 181 | # VALIDATION 182 | self.valid_loader = tqdm(self.valid_loader) 183 | valid_loss_sum = 0.0 184 | num_valid_imgs = 0 185 | self.fusion_net.eval() 186 | with torch.set_grad_enabled(False): 187 | for batch_id, img_dict in enumerate(self.valid_loader): 188 | syn_img = torch.cat([ 189 | img_dict['axial_img'], 190 | img_dict['coronal_img'], 191 
| img_dict['sagittal_img'] 192 | ], dim=1).to(self.device) 193 | ori_img = img_dict['ori_img'].to(self.device) 194 | batch_size = ori_img.shape[0] 195 | 196 | #mask = syn_img[:,[0],:,:,:] > 1e-8 197 | ori_img = ori_img * (syn_img[:,[0],:,:,:] > 1e-8).detach() 198 | 199 | fusion_img = self.fusion_net(syn_img) 200 | 201 | rec_loss = self.l1_loss(fusion_img, ori_img).mean() 202 | 203 | valid_loss_sum += rec_loss.item() * batch_size 204 | num_valid_imgs += batch_size 205 | self.valid_loader.set_description((f'epoch: {epoch}; ' 206 | f'rec: {rec_loss.item():.3f}; ' 207 | f'avg_trn: {valid_loss_sum / num_valid_imgs:.3f}; ')) 208 | if batch_id % 50 - 1 == 0: 209 | img_affine = [[-1, 0, 0, 96], [0, -1, 0, 96], [0, 0, 1, -78], [0, 0, 0, 1]] 210 | img_save = np.array(fusion_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2)) 211 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine) 212 | file_name = os.path.join(self.out_dir, 'results', f'valid_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_syn.nii.gz') 213 | nib.save(img_save, file_name) 214 | 215 | img_save = np.array(ori_img.detach().cpu().squeeze().permute(1,2,0).permute(1,0,2)) 216 | img_save = nib.Nifti1Image(img_save * np.array(img_dict['norm_val']), img_affine) 217 | file_name = os.path.join(self.out_dir, 'results', f'valid_epoch{str(epoch).zfill(2)}_batch{str(batch_id).zfill(3)}_ori.nii.gz') 218 | nib.save(img_save, file_name) 219 | 220 | 221 | def save_model(self, file_name, epoch): 222 | state = {'epoch': epoch, 223 | 'fusion_net': self.fusion_net.state_dict(), 224 | 'optim_fusion_net': self.optim_fusion_net.state_dict()} 225 | torch.save(obj=state, f=file_name) 226 | 227 | 228 | 229 | -------------------------------------------------------------------------------- /haca3/modules/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class FusionNet(nn.Module): 8 | def __init__(self, in_ch=3, out_ch=1): 9 | super().__init__() 10 | self.conv1 = nn.Sequential( 11 | nn.Conv3d(in_ch, 8, 3, 1, 1), 12 | nn.InstanceNorm3d(8), 13 | nn.LeakyReLU(), 14 | nn.Conv3d(8, 16, 3, 1, 1), 15 | nn.InstanceNorm3d(16), 16 | nn.LeakyReLU()) 17 | self.conv2 = nn.Sequential( 18 | nn.Conv3d(in_ch + 16, 16, 3, 1, 1), 19 | nn.LeakyReLU(), 20 | nn.Conv3d(16, out_ch, 3, 1, 1), 21 | nn.ReLU()) 22 | 23 | def forward(self, x): 24 | # return self.conv2(x + self.conv1(x)) 25 | return self.conv2(torch.cat([x, self.conv1(x)], dim=1)) 26 | 27 | class UNet(nn.Module): 28 | def __init__(self, in_ch, out_ch, conditional_ch=0, num_lvs=4, base_ch=16, final_act='noact'): 29 | super().__init__() 30 | self.final_act = final_act 31 | self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1) 32 | 33 | self.down_convs = nn.ModuleList() 34 | self.down_samples = nn.ModuleList() 35 | self.up_samples = nn.ModuleList() 36 | self.up_convs = nn.ModuleList() 37 | for lv in range(num_lvs): 38 | ch = base_ch * (2 ** lv) 39 | self.down_convs.append(ConvBlock2d(ch + conditional_ch, ch * 2, ch * 2)) 40 | self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2)) 41 | self.up_samples.append(Upsample(ch * 4)) 42 | self.up_convs.append(ConvBlock2d(ch * 4, ch * 2, ch * 2)) 43 | bottleneck_ch = base_ch * (2 ** num_lvs) 44 | self.bottleneck_conv = ConvBlock2d(bottleneck_ch, bottleneck_ch * 2, bottleneck_ch * 2) 45 | self.out_conv = nn.Sequential(nn.Conv2d(base_ch * 2, base_ch, 3, 1, 1), 46 | nn.LeakyReLU(0.1), 47 | 
nn.Conv2d(base_ch, out_ch, 3, 1, 1)) 48 | 49 | def forward(self, in_tensor, condition=None): 50 | encoded_features = [] 51 | x = self.in_conv(in_tensor) 52 | for down_conv, down_sample in zip(self.down_convs, self.down_samples): 53 | if condition is not None: 54 | feature_dim = x.shape[-1] 55 | down_conv_out = down_conv(torch.cat([x, condition.repeat(1, 1, feature_dim, feature_dim)], dim=1)) 56 | else: 57 | down_conv_out = down_conv(x) 58 | x = down_sample(down_conv_out) 59 | encoded_features.append(down_conv_out) 60 | x = self.bottleneck_conv(x) 61 | for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features), 62 | reversed(self.up_convs), 63 | reversed(self.up_samples)): 64 | x = up_sample(x, encoded_feature) 65 | x = up_conv(x) 66 | x = self.out_conv(x) 67 | if self.final_act == 'sigmoid': 68 | x = torch.sigmoid(x) 69 | elif self.final_act == "relu": 70 | x = torch.relu(x) 71 | elif self.final_act == 'tanh': 72 | x = torch.tanh(x) 73 | else: 74 | x = x 75 | return x 76 | 77 | 78 | class ConvBlock2d(nn.Module): 79 | def __init__(self, in_ch, mid_ch, out_ch): 80 | super().__init__() 81 | self.conv = nn.Sequential( 82 | nn.Conv2d(in_ch, mid_ch, 3, 1, 1), 83 | nn.InstanceNorm2d(mid_ch), 84 | nn.LeakyReLU(0.1), 85 | nn.Conv2d(mid_ch, out_ch, 3, 1, 1), 86 | nn.InstanceNorm2d(out_ch), 87 | nn.LeakyReLU(0.1) 88 | ) 89 | 90 | def forward(self, in_tensor): 91 | return self.conv(in_tensor) 92 | 93 | 94 | class Upsample(nn.Module): 95 | def __init__(self, in_ch): 96 | super().__init__() 97 | out_ch = in_ch // 2 98 | self.conv = nn.Sequential( 99 | nn.Conv2d(in_ch, out_ch, 3, 1, 1), 100 | nn.InstanceNorm2d(out_ch), 101 | nn.LeakyReLU(0.1) 102 | ) 103 | 104 | def forward(self, in_tensor, encoded_feature): 105 | up_sampled_tensor = F.interpolate(in_tensor, size=None, scale_factor=2, mode='bilinear', align_corners=False) 106 | up_sampled_tensor = self.conv(up_sampled_tensor) 107 | return torch.cat([encoded_feature, up_sampled_tensor], dim=1) 108 | 109 | 110 | class EtaEncoder(nn.Module): 111 | def __init__(self, in_ch=1, out_ch=2): 112 | super().__init__() 113 | self.in_conv = nn.Sequential( 114 | nn.Conv2d(in_ch, 16, 5, 1, 2), # (*, 16, 224, 224) 115 | nn.InstanceNorm2d(16), 116 | nn.LeakyReLU(0.1), 117 | nn.Conv2d(16, 64, 3, 1, 1), # (*, 64, 224, 224) 118 | nn.InstanceNorm2d(64), 119 | nn.LeakyReLU(0.1) 120 | ) 121 | self.seq = nn.Sequential( 122 | nn.Conv2d(64 + in_ch, 32, 32, 32, 0), # (*, 32, 7, 7) 123 | nn.InstanceNorm2d(32), 124 | nn.LeakyReLU(0.1), 125 | nn.Conv2d(32, out_ch, 7, 7, 0)) 126 | 127 | def forward(self, x): 128 | return self.seq(torch.cat([self.in_conv(x), x], dim=1)) 129 | 130 | 131 | class Patchifier(nn.Module): 132 | def __init__(self, in_ch, out_ch=1): 133 | super().__init__() 134 | self.conv = nn.Sequential( 135 | nn.Conv2d(in_ch, 64, 32, 32, 0), # (*, in_ch, 224, 224) --> (*, 64, 7, 7) 136 | nn.LeakyReLU(0.1), 137 | nn.Conv2d(64, out_ch, 1, 1, 0)) 138 | 139 | def forward(self, x): 140 | return self.conv(x) 141 | 142 | 143 | class ThetaEncoder(nn.Module): 144 | def __init__(self, in_ch, out_ch): 145 | super().__init__() 146 | self.conv = nn.Sequential( 147 | nn.Conv2d(in_ch, 32, 17, 9, 4), 148 | nn.InstanceNorm2d(32), 149 | nn.LeakyReLU(0.1), # (*, 32, 28, 28) 150 | nn.Conv2d(32, 64, 4, 2, 1), 151 | nn.InstanceNorm2d(64), 152 | nn.LeakyReLU(0.1), # (*, 64, 14, 14) 153 | nn.Conv2d(64, 64, 4, 2, 1), 154 | nn.InstanceNorm2d(64), 155 | nn.LeakyReLU(0.1)) # (* 64, 7, 7) 156 | self.mean_conv = nn.Sequential( 157 | nn.Conv2d(64, 32, 3, 1, 1), 158 | nn.InstanceNorm2d(32), 
159 | nn.LeakyReLU(0.1), 160 | nn.Conv2d(32, out_ch, 6, 6, 0)) 161 | self.logvar_conv = nn.Sequential( 162 | nn.Conv2d(64, 32, 3, 1, 1), 163 | nn.InstanceNorm2d(32), 164 | nn.LeakyReLU(0.1), 165 | nn.Conv2d(32, out_ch, 6, 6, 0)) 166 | 167 | def forward(self, x): 168 | M = self.conv(x) 169 | mu = self.mean_conv(M) 170 | logvar = self.logvar_conv(M) 171 | return mu, logvar 172 | 173 | # class ThetaEncoder(nn.Module): 174 | # def __init__(self, in_ch, out_ch): 175 | # super().__init__() 176 | # self.conv = nn.Sequential( 177 | # nn.Conv2d(in_ch, 32, 32, 32, 0), # (*, in_ch, 224, 244) --> (*, 32, 7, 7) 178 | # nn.InstanceNorm2d(32), 179 | # nn.LeakyReLU(0.1), 180 | # nn.Conv2d(32, 64, 1, 1, 0), 181 | # # nn.InstanceNorm2d(64), 182 | # nn.LeakyReLU(0.1)) 183 | # self.mu_conv = nn.Sequential( 184 | # nn.Conv2d(64, 64, 3, 1, 1), 185 | # nn.InstanceNorm2d(64), 186 | # nn.LeakyReLU(0.1), 187 | # nn.Conv2d(64, out_ch, 7, 7, 0)) 188 | # self.logvar_conv = nn.Sequential( 189 | # nn.Conv2d(64, 64, 3, 1, 1), 190 | # nn.InstanceNorm2d(64), 191 | # nn.LeakyReLU(0.1), 192 | # nn.Conv2d(64, out_ch, 7, 7, 0)) 193 | # 194 | # def forward(self, x, patch_shuffle=False): 195 | # m = self.conv(x) 196 | # if patch_shuffle: 197 | # batch_size = m.shape[0] 198 | # num_features = m.shape[1] 199 | # num_patches_per_dim = m.shape[-1] 200 | # m = m.view(batch_size, num_features, -1)[:, :, torch.randperm(num_patches_per_dim ** 2)] 201 | # m = m.view(batch_size, num_features, num_patches_per_dim, num_patches_per_dim) 202 | # mu = self.mu_conv(m) 203 | # logvar = self.logvar_conv(m) 204 | # return mu, logvar 205 | 206 | 207 | class AttentionModule(nn.Module): 208 | def __init__(self, dim, v_ch=5): 209 | super().__init__() 210 | self.dim = dim 211 | self.v_ch = v_ch 212 | self.q_fc = nn.Sequential( 213 | nn.Linear(dim, 128), 214 | nn.LeakyReLU(0.1), 215 | nn.Linear(128, 16), 216 | nn.LayerNorm(16)) 217 | self.k_fc = nn.Sequential( 218 | nn.Linear(dim, 128), 219 | nn.LeakyReLU(0.1), 220 | nn.Linear(128, 16), 221 | nn.LayerNorm(16)) 222 | 223 | self.scale = self.dim ** (-0.5) 224 | 225 | def forward(self, q, k, v, modality_dropout=None, temperature=10.0): 226 | """ 227 | Attention module for optimal anatomy fusion. 228 | 229 | ===INPUTS=== 230 | * q: torch.Tensor (batch_size, feature_dim_q, num_q_patches=1) 231 | Query variable. In HACA3, query is the concatenation of target \theta and target \eta. 232 | * k: torch.Tensor (batch_size, feature_dim_k, num_k_patches=1, num_contrasts=4) 233 | Key variable. In HACA3, keys are \theta and \eta's of source images. 234 | * v: torch.Tensor (batch_size, self.v_ch=5, num_v_patches=224*224, num_contrasts=4) 235 | Value variable. In HACA3, values are multi-channel logits of source images. 236 | self.v_ch is the number of \beta channels. 237 | * modality_dropout: torch.Tensor (batch_size, num_contrasts=4) 238 | Indicates which contrast indexes have been dropped out. 1: if dropped out, 0: if exists. 239 | """ 240 | batch_size, feature_dim_q, num_q_patches = q.shape 241 | _, feature_dim_k, _, num_contrasts = k.shape 242 | num_v_patches = v.shape[2] 243 | assert ( 244 | feature_dim_k == feature_dim_q or feature_dim_q == self.feature_dim 245 | ), 'Feature dimensions do not match.' 
246 | 247 | # q.shape: (batch_size, num_q_patches=1, 1, feature_dim_q) 248 | q = q.reshape(batch_size, feature_dim_q, num_q_patches, 1).permute(0, 2, 3, 1) 249 | # k.shape: (batch_size, num_k_patches=1, num_contrasts=4, feature_dim_k) 250 | k = k.permute(0, 2, 3, 1) 251 | # v.shape: (batch_size, num_v_patches=224*224, num_contrasts=4, v_ch=5) 252 | v = v.permute(0, 2, 3, 1) 253 | q = self.q_fc(q) 254 | # k.shape: (batch_size, num_k_patches=1, feature_dim_k, num_contrasts=4) 255 | k = self.k_fc(k).permute(0, 1, 3, 2) 256 | 257 | # dot_prod.shape: (batch_size, num_q_patches=1, 1, num_contrasts=4) 258 | dot_prod = (q @ k) * self.scale 259 | interpolation_factor = int(math.sqrt(num_v_patches // num_q_patches)) 260 | 261 | q_spatial_dim = int(math.sqrt(num_q_patches)) 262 | dot_prod = dot_prod.view(batch_size, q_spatial_dim, q_spatial_dim, num_contrasts) 263 | 264 | image_dim = int(math.sqrt(num_v_patches)) 265 | # dot_prod_interp.shape: (batch_size, image_dim, image_dim, num_contrasts) 266 | dot_prod_interp = dot_prod.repeat(1, interpolation_factor, interpolation_factor, 1) 267 | if modality_dropout is not None: 268 | modality_dropout = modality_dropout.view(batch_size, num_contrasts, 1, 1).permute(0, 2, 3, 1) 269 | dot_prod_interp = dot_prod_interp - (modality_dropout.repeat(1, image_dim, image_dim, 1).detach() * 1e5) 270 | 271 | attention = (dot_prod_interp / temperature).softmax(dim=-1) 272 | v = attention.view(batch_size, num_v_patches, 1, num_contrasts) @ v 273 | v = v.view(batch_size, image_dim, image_dim, self.v_ch).permute(0, 3, 1, 2) 274 | attention = attention.view(batch_size, image_dim, image_dim, num_contrasts).permute(0, 3, 1, 2) 275 | return v, attention 276 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![HACA3 features](figures/GA.png) 2 | 3 | # HACA3: A unified approach for multi-site MR image harmonization | [Paper](https://www.sciencedirect.com/science/article/pii/S0895611123001039) 4 | 5 | HACA3 is an advanced approach for multi-site MRI harmonization. This page provides a gentle introduction to HACA3 inference and training. 6 | 7 | - Publication: [Zuo et al. HACA3: A unified approach for multi-site MR image harmonization. *Computerized Medical Imaging 8 | and Graphics, 2023.*](https://www.sciencedirect.com/science/article/pii/S0895611123001039) 9 | 10 | - Citation: 11 | ```bibtex 12 | @article{ZUO2023102285, 13 | title = {HACA3: A unified approach for multi-site MR image harmonization}, 14 | journal = {Computerized Medical Imaging and Graphics}, 15 | volume = {109}, 16 | pages = {102285}, 17 | year = {2023}, 18 | issn = {0895-6111}, 19 | doi = {https://doi.org/10.1016/j.compmedimag.2023.102285}, 20 | author = {Lianrui Zuo and Yihao Liu and Yuan Xue and Blake E. Dewey and 21 | Samuel W. Remedios and Savannah P. Hays and Murat Bilgel and 22 | Ellen M. Mowry and Scott D. Newsome and Peter A. Calabresi and 23 | Susan M. Resnick and Jerry L. Prince and Aaron Carass} 24 | } 25 | ``` 26 | 27 | ### Recent Updates 28 | - August 11, 2024 - **An [interactive demo](https://colab.research.google.com/drive/1PeBuqOAGupLQ2gXWVneX1Kn31ISh4oFB?usp=share_link) for HACA3 now available.** You can explore HACA3 harmonization and imputation real time. 29 | 30 | ## 1. 
Introduction and motivation 31 | ### 1.1 The double-edged sword of MRI: flexibility and variability 32 | Magnetic resonance imaging (MRI) is a powerful imaging technique, offering flexibility in capturing various tissue 33 | contrasts in a single imaging session. For example, T1-weighted, T2-weighted, and FLAIR images can be acquired in one 34 | session to provide comprehensive insights into different tissue properties. However, this flexibility comes at 35 | a cost: ***lack of standardization and consistency*** across imaging studies. Several factors contribute to this 36 | variability, including but not limited to 37 | - Pulse sequences, e.g., MPRAGE, SPGR 38 | - Imaging parameters, e.g., flip angle, echo time 39 | - Scanner manufacturers, e.g., Siemens, GE 40 | - Technician and site preferences. 41 | 42 | ### 1.2 Why should we harmonize MR images? 43 | Contrast variations in MR images may sometimes be subtle but are often significant enough to impact the quality and 44 | reliability of ***multi-site*** and ***longitudinal*** studies. 45 | 46 | - ***Example #1***: Multi-site inconsistency. In this example, two images were acquired at different sites using 47 | distinct imaging parameters. This led to ***different image contrasts for the same subject***. As a result, an automatic 48 | segmentation algorithm produced inconsistent outcomes due to these contrast differences. Harmonization effectively 49 | alleviates this issue. 50 |
51 | ![multi-site](figures/multi_site.png) 52 |
53 | 54 | - ***Example #2***: Longitudinal study. In this example, longitudinal images were acquired during four different visits. 55 | During Visit #2, the imaging parameters were altered (for unexpected reasons), causing a noticeable jump in the 56 | estimated volumes of cortical gray matter (GM). Given the cortical GM volume at Visit #3, this jump is unlikely to be 57 | a result of actual biological changes. Harmonization makes the longitudinal trend more biologically plausible. See my 58 | [CMSC2023 talk](https://www.youtube.com/watch?v=TpdB55wxgs4&t=2s) to learn more about how harmonization helps 59 | longitudinal studies. 60 |
61 | ![longitudinal](figures/longitudinal.png) 62 |
63 | 64 | 65 | ## 2. Prerequisites 66 | Standard neuroimage preprocessing steps are needed before running HACA3. These preprocessing steps include: 67 | - Inhomogeneity correction 68 | - Super-resolution for 2D acquired scans. This step is optional, but recommended for optimal performance. 69 | See [SMORE](https://github.com/volcanofly/SMORE-Super-resolution-for-3D-medical-images-MRI) for more details. 70 | - Registration to MNI space (1mm isotropic resolution). HACA3 assumes a spatial dimension of 192x224x192. 71 | 72 | ## 3. Installation and pretrained weights 73 | 74 | ### 3.1 Option 1 (recommended): Run HACA3 through singularity image 75 | In general, no installation of HACA3 is required with this option. 76 | Singularity image of HACA3 model can be directly downloaded [**here**](https://iacl.ece.jhu.edu/~lianrui/haca3/haca3_v1.0.9.sif). 77 | 78 | 79 | ### 3.2 Option 2: Install from source using `pip` 80 | 1. Clone the repository: 81 | ```bash 82 | git clone https://github.com/lianruizuo/haca3.git 83 | ``` 84 | 2. Navigate to the directory: 85 | ```bash 86 | cd haca3 87 | ``` 88 | 3. Install dependencies: 89 | ```bash 90 | pip install . 91 | ``` 92 | Package requirements are automatically handled. To see a list of requirements, see `setup.py` L50-60. 93 | This installs the `haca3` package and creates two CLI aliases `haca3-train` and `haca3-test`. 94 | 95 | 96 | ### 3.3 Pretrained weights 97 | Pretrained weights of HACA3 can be downloaded [**here**](https://iacl.ece.jhu.edu/~lianrui/haca3/harmonization_public.pt). 98 | This model was trained on public datasets including the structural MR images from [IXI](https://brain-development.org/ixi-dataset/), 99 | [OASIS3](https://www.oasis-brains.org), and [BLSA](https://www.nia.nih.gov/research/labs/blsa) dataset. 100 | HACA3 uses a 3D convolutional network to combine multi-orientation 2D slices into a single 3D volume. 101 | Pretrained fusion model can be downloaded [**here**](https://iacl.ece.jhu.edu/~lianrui/haca3/fusion.pt). 102 | 103 | ## 4. Usage: Inference 104 | 105 | ### 4.1 Option 1 (recommended): Run HACA3 through singularity image 106 | ```bash 107 | singularity exec --nv -e haca3.sif haca3-test \ 108 | --in-path [PATH-TO-INPUT-SOURCE-IMAGE-1] \ 109 | --in-path [PATH-TO-INPUT-SOURCE-IMAGE-2, IF THERE ARE MULTIPLE SOURCE IMAGES] \ 110 | --target-image [TARGET-IMAGE] \ 111 | --harmonization-model [PRETRAINED-HACA3-MODEL] \ 112 | --fusion-model [PRETRAINED-FUSION-MODEL] \ 113 | --out-path [PATH-TO-HARMONIZED-IMAGE] \ 114 | --intermediate-out-dir [DIRECTORY SAVES INTERMEDIATE RESULTS] 115 | ``` 116 | 117 | - ***Example #3:*** 118 | Suppose the task is to harmonize MR images from `Site A` to match the contrast of a pre-selected T1w image of 119 | `Site B`. As a source site, `Site A` has T1w, T2w, and FLAIR images. The files are saved like this: 120 | ``` 121 | ├──data_directory 122 | ├──site_A_t1w.nii.gz 123 | ├──site_A_t2w.nii.gz 124 | ├──site_A_flair.nii.gz 125 | └──site_B_t1w.nii.gz 126 | ``` 127 | You can always retrain HACA3 using your own datasets. In this example, we choose to use the pretrained HACA3 weights 128 | `harmonization.pt` and fusion model weights `fusion.pt` (see [3.3 Pretrained weights](#33-pretrained-weights) for 129 | how to download these weights). 
The singularity command to run HACA3 is: 130 | ```bash 131 | singularity exec --nv -e haca3.sif haca3-test \ 132 | --in-path data_directory/site_A_t1w.nii.gz \ 133 | --in-path data_directory/site_A_t2w.nii.gz \ 134 | --in-path data_directory/site_A_flair.nii.gz \ 135 | --target-image data_directory/site_B_t1w.nii.gz \ 136 | --harmonization-model harmonization.pt \ 137 | --fusion-model fusion.pt \ 138 | --out-path output_directory/site_A_harmonized_to_site_B_t1w.nii.gz \ 139 | --intermediate-out-dir output_directory 140 | ``` 141 | The harmonized image and intermediate results will be saved at `output_directory`. 142 | 143 | 144 | ### 4.2 Option 2: Run HACA3 from source after installation 145 | ```bash 146 | haca3-test \ 147 | --in-path [PATH-TO-INPUT-SOURCE-IMAGE-1] \ 148 | --in-path [PATH-TO-INPUT-SOURCE-IMAGE-2, IF THERE ARE MULTIPLE SOURCE IMAGES] \ 149 | --target-image [TARGET-IMAGE] \ 150 | --harmonization-model [PRETRAINED-HACA3-MODEL] \ 151 | --fusion-model [PRETRAINED-FUSION-MODEL] \ 152 | --out-path [PATH-TO-HARMONIZED-IMAGE] \ 153 | --intermediate-out-dir [DIRECTORY-THAT-SAVES-INTERMEDIATE-RESULTS] 154 | ``` 155 | 156 | 157 | ### 4.3 All options for inference 158 | - ```--in-path```: file path to an input source image. Multiple ```--in-path``` arguments may be provided if there are multiple 159 | source images. See the above example for more details. 160 | - ```--target-image```: file path to the target image. HACA3 will match the contrast of the source images to this target image. 161 | - ```--target-theta```: In HACA3, ```theta``` 162 | is a two-dimensional representation of image contrast. The target image contrast can be specified directly by providing 163 | a ```theta``` value, e.g., ```--target-theta 0.5 0.5``` (see the sketch at the end of this README). Note: either ```--target-image``` or ```--target-theta``` must 164 | be provided during inference. If both are provided, only ```--target-theta``` will be used. 165 | - ```--norm-val```: normalization value used to rescale the intensities of the harmonized output. 166 | - ```--out-path```: file path to the harmonized image. 167 | - ```--harmonization-model```: pretrained HACA3 weights. Pretrained model weights on IXI, OASIS and HCP data can 168 | be downloaded [here](https://iacl.ece.jhu.edu/~lianrui/haca3/harmonization_public.pt). 169 | - ```--fusion-model```: pretrained fusion model weights. HACA3 uses a 3D convolutional network to combine multi-orientation 170 | 2D slices into a single 3D volume. The pretrained fusion model can be downloaded [here](https://iacl.ece.jhu.edu/~lianrui/haca3/fusion.pt). 171 | - ```--save-intermediate```: if specified, intermediate results will be saved. Default: ```False```. Action: ```store_true```. 172 | - ```--intermediate-out-dir```: directory to save intermediate results. 173 | - ```--gpu-id```: integer specifying which GPU to run HACA3 on. 174 | - ```--num-batches```: During inference, HACA3 takes entire 3D MRI volumes as input. This may consume a considerable amount 175 | of GPU memory. To reduce GPU memory consumption, the source images may be divided into smaller batches. 176 | However, this may slightly increase the inference time. 177 | 178 | ## 5. Go further with harmonization 179 | - ***Application #1: Identifying the optimal operating contrast.*** With the ability to synthesize arbitrary 180 | contrasts of the same underlying anatomy, we use harmonization to identify the optimal operating contrast (OOC) for various 181 | downstream tasks, e.g., different segmentation algorithms. 182 | - Publications: 183 | [Hays et al. Evaluating the Impact of MR Image Contrast on Whole Brain Segmentation.
SPIE 2022.](https://drive.google.com/file/d/1ZxLqJCFORPqhwZCQVM_7r7TwZcn5bbzy/view) 184 | [Hays et al. Exploring the Optimal Operating MR Contrast for Brain Ventricle Parcellation. MIDL 2023.](https://openreview.net/pdf?id=3ndjE9eawkr) 185 | [Hays et al. Optimal operating MR contrast for brain ventricle parcellation. ISBI 2023.](https://arxiv.org/pdf/2304.02056) 186 | 187 | - ***Application #2: Automatic quality assurance.*** Since HACA3 can identify images with high artifact 188 | levels, we use the HACA3 artifact encoder for automatic quality assurance. 189 | - Publication: 190 | [Zuo et al. A latent space for unsupervised MR image quality control via artifact assessment. SPIE 2023.](https://arxiv.org/pdf/2302.00528) 191 | 192 | - ***Application #3: Consistent longitudinal analysis.*** We have found that inconsistent acquisition can cause 193 | significant issues in longitudinal volumetric analysis, and harmonization is one way to alleviate this inconsistency. 194 | - Publication: 195 | [Zuo et al. Inconsistent MR Acquisition in Longitudinal Volumetric Analysis: Impacts and Solutions. CMSC 2023.](https://cmsc.confex.com/cmsc/2023/meetingapp.cgi/Paper/8967) 196 | - Video presentation on [YouTube](https://www.youtube.com/watch?v=TpdB55wxgs4&t=2s) 197 | 198 | - ***Application #4: Quantifying scanner differences from images.*** In many cases, scanner and acquisition information is 199 | not immediately available from NIFTI files. The contrast encoder in HACA3 and our previous harmonization model 200 | [CALAMITI](https://www.sciencedirect.com/science/article/pii/S1053811921008429) provide a way to capture these acquisition differences 201 | from the MR images themselves. This information can be used to inform downstream tasks about the level of data heterogeneity. 202 | - Publication: 203 | [Hays et al. Quantifying Contrast Differences Among Magnetic Resonance Images Used in Clinical Studies. CMSC 2023.](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=pMxz1VYAAAAJ&citation_for_view=pMxz1VYAAAAJ:qjMakFHDy7sC) 204 | 205 | 206 | ## 6. Acknowledgements 207 | Special thanks to Samuel Remedios, Blake Dewey, and Yihao Liu for their feedback on the HACA3 code release and this GitHub page. 208 | 209 | The authors thank BLSA participants, as well as colleagues of the Laboratory of Behavioral Neuroscience (LBN) of NIA and 210 | the Image Analysis and Communications Laboratory (IACL) of JHU. 211 | This work was supported in part by the Intramural Research Program of the National Institutes of Health, 212 | National Institute on Aging, 213 | in part by the TREAT-MS study funded by the Patient-Centered Outcomes Research Institute (PCORI) grant MS-1610-37115 214 | (Co-PIs: Drs. S.D. Newsome and E.M. Mowry), 215 | in part by the National Science Foundation Graduate Research Fellowship under Grant No. DGE-1746891, 216 | in part by the NIH grant (R01NS082347, PI: P. Calabresi), National Multiple Sclerosis Society grant (RG-1907-34570, PI: D. Pham), 217 | and the DOD/Congressionally Directed Medical Research Programs (CDMRP) grant (MS190131, PI: J. Prince).
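As referenced in Section 4.3, the target contrast can also be specified directly in `theta` space instead of with a target image. A minimal sketch of such a command (the `theta` values and file paths below are placeholders) might look like:

```bash
haca3-test \
    --in-path data_directory/site_A_t1w.nii.gz \
    --in-path data_directory/site_A_t2w.nii.gz \
    --target-theta 0.5 0.5 \
    --harmonization-model harmonization.pt \
    --fusion-model fusion.pt \
    --out-path output_directory/site_A_harmonized_to_theta.nii.gz
```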
218 | -------------------------------------------------------------------------------- /haca3/modules/model.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import random 4 | import torch 5 | from torch import nn 6 | from torch.optim import Adam 7 | from torch.optim.lr_scheduler import CyclicLR 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torch.utils.data import DataLoader 10 | import torchvision.models as models 11 | from torchvision.transforms import ToTensor 12 | from datetime import datetime 13 | import nibabel as nib 14 | from torch.cuda.amp import autocast 15 | 16 | from .utils import * 17 | from .dataset import HACA3Dataset 18 | from .network import UNet, ThetaEncoder, EtaEncoder, Patchifier, AttentionModule, FusionNet 19 | 20 | 21 | class HACA3: 22 | def __init__(self, beta_dim, theta_dim, eta_dim, pretrained_haca3=None, pretrained_eta_encoder=None, gpu_id=0): 23 | self.beta_dim = beta_dim 24 | self.theta_dim = theta_dim 25 | self.eta_dim = eta_dim 26 | self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') 27 | self.timestr = datetime.now().strftime("%Y%m%d-%H%M%S") 28 | 29 | self.train_loader, self.valid_loader = None, None 30 | self.out_dir = None 31 | self.optimizer = None 32 | self.scheduler = None 33 | self.writer, self.writer_path = None, None 34 | self.checkpoint = None 35 | 36 | self.l1_loss, self.kld_loss, self.contrastive_loss, self.perceptual_loss = None, None, None, None 37 | 38 | # define networks 39 | self.beta_encoder = UNet(in_ch=1, out_ch=self.beta_dim, base_ch=8, final_act='none') 40 | self.theta_encoder = ThetaEncoder(in_ch=1, out_ch=self.theta_dim) 41 | self.eta_encoder = EtaEncoder(in_ch=1, out_ch=self.eta_dim) 42 | self.attention_module = AttentionModule(self.theta_dim + self.eta_dim, v_ch=self.beta_dim) 43 | self.decoder = UNet(in_ch=1 + self.theta_dim, out_ch=1, base_ch=16, final_act='relu') 44 | self.patchifier = Patchifier(in_ch=1, out_ch=128) 45 | 46 | if pretrained_eta_encoder is not None: 47 | checkpoint_eta_encoder = torch.load(pretrained_eta_encoder, map_location=self.device) 48 | self.eta_encoder.load_state_dict(checkpoint_eta_encoder['eta_encoder']) 49 | if pretrained_haca3 is not None: 50 | self.checkpoint = torch.load(pretrained_haca3, map_location=self.device) 51 | self.beta_encoder.load_state_dict(self.checkpoint['beta_encoder']) 52 | self.theta_encoder.load_state_dict(self.checkpoint['theta_encoder']) 53 | self.eta_encoder.load_state_dict(self.checkpoint['eta_encoder']) 54 | self.decoder.load_state_dict(self.checkpoint['decoder']) 55 | self.attention_module.load_state_dict(self.checkpoint['attention_module']) 56 | self.patchifier.load_state_dict(self.checkpoint['patchifier']) 57 | self.beta_encoder.to(self.device) 58 | self.theta_encoder.to(self.device) 59 | self.eta_encoder.to(self.device) 60 | self.decoder.to(self.device) 61 | self.attention_module.to(self.device) 62 | self.patchifier.to(self.device) 63 | self.start_epoch = 0 64 | 65 | def initialize_training(self, out_dir, lr): 66 | # define loss functions 67 | self.l1_loss = nn.L1Loss(reduction='none') 68 | self.kld_loss = KLDivergenceLoss() 69 | vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(self.device) 70 | self.perceptual_loss = PerceptualLoss(vgg) 71 | self.contrastive_loss = PatchNCELoss() 72 | 73 | # define optimizer and learning rate scheduler 74 | self.optimizer = Adam(list(self.beta_encoder.parameters()) + 75 | 
list(self.theta_encoder.parameters()) + 76 | list(self.decoder.parameters()) + 77 | list(self.attention_module.parameters()) + 78 | list(self.patchifier.parameters()), lr=lr) 79 | self.scheduler = CyclicLR(self.optimizer, base_lr=4e-4, max_lr=7e-4, cycle_momentum=False) 80 | if self.checkpoint is not None: 81 | self.start_epoch = self.checkpoint['epoch'] 82 | self.optimizer.load_state_dict(self.checkpoint['optimizer']) 83 | self.scheduler.load_state_dict(self.checkpoint['scheduler']) 84 | if 'timestr' in self.checkpoint: 85 | self.timestr = self.checkpoint['timestr'] 86 | self.start_epoch = self.start_epoch + 1 87 | 88 | self.out_dir = out_dir 89 | mkdir_p(self.out_dir) 90 | mkdir_p(os.path.join(self.out_dir, f'training_results_{self.timestr}')) 91 | mkdir_p(os.path.join(self.out_dir, f'training_models_{self.timestr}')) 92 | 93 | self.writer_path = os.path.join(self.out_dir, self.timestr) 94 | self.writer = SummaryWriter(self.writer_path) 95 | 96 | def load_dataset(self, dataset_dirs, contrasts, orientations, batch_size, normalization_method='01'): 97 | train_dataset = HACA3Dataset(dataset_dirs, contrasts, orientations, 'train', normalization_method) 98 | valid_dataset = HACA3Dataset(dataset_dirs, contrasts, orientations, 'valid', normalization_method) 99 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8) 100 | self.valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=8) 101 | 102 | def calculate_theta(self, images): 103 | if isinstance(images, list): 104 | thetas, mus, logvars = [], [], [] 105 | for image in images: 106 | mu, logvar = self.theta_encoder(image) 107 | theta = torch.randn(mu.size()).to(self.device) * torch.sqrt(torch.exp(logvar)) + mu 108 | thetas.append(theta) 109 | mus.append(mu) 110 | logvars.append(logvar) 111 | else: 112 | mus, logvars = self.theta_encoder(images) 113 | thetas = torch.randn(mus.size()).to(self.device) * torch.sqrt(torch.exp(logvars)) + mus 114 | return thetas, mus, logvars 115 | 116 | def calculate_beta(self, images): 117 | logits, betas = [], [] 118 | for image in images: 119 | logit = self.beta_encoder(image) 120 | beta = self.channel_aggregation(reparameterize_logit(logit)) 121 | logits.append(logit) 122 | betas.append(beta) 123 | return logits, betas 124 | 125 | def calculate_eta(self, images): 126 | if isinstance(images, list): 127 | etas = [] 128 | for image in images: 129 | eta = self.eta_encoder(image) 130 | etas.append(eta) 131 | else: 132 | etas = self.eta_encoder(images) 133 | return etas 134 | 135 | def prepare_source_images(self, image_dicts): 136 | num_contrasts = len(image_dicts) 137 | num_contrasts_with_degradation = np.random.permutation(num_contrasts)[0] 138 | degradation_ids = sorted(np.random.choice(range(num_contrasts), 139 | num_contrasts_with_degradation, 140 | replace=False)) 141 | source_images = [] 142 | for i in range(num_contrasts): 143 | if i in degradation_ids: 144 | source_images.append(image_dicts[i]['image_degrade'].to(self.device)) 145 | else: 146 | source_images.append(image_dicts[i]['image'].to(self.device)) 147 | return source_images 148 | 149 | def channel_aggregation(self, beta_onehot_encode): 150 | """ 151 | Combine multi-channel one-hot encoded beta into one channel (label-encoding). 152 | 153 | ===INPUTS=== 154 | * beta_onehot_encode: torch.Tensor (batch_size, self.beta_dim, image_dim, image_dim) 155 | One-hot encoded beta variable. 
At each pixel location, only one channel will take value of 1, 156 | and other channels will be 0. 157 | ===OUTPUTS=== 158 | * beta_label_encode: torch.Tensor (batch_size, 1, image_dim, image_dim) 159 | The intensity value of each pixel will be determined by the channel index with value of 1. 160 | """ 161 | batch_size = beta_onehot_encode.shape[0] 162 | image_dim = beta_onehot_encode.shape[3] 163 | value_tensor = (torch.arange(0, self.beta_dim) * 1.0).to(self.device) 164 | value_tensor = value_tensor.view(1, self.beta_dim, 1, 1).repeat(batch_size, 1, image_dim, image_dim) 165 | beta_label_encode = beta_onehot_encode * value_tensor.detach() 166 | return beta_label_encode.sum(1, keepdim=True) / self.beta_dim 167 | 168 | def select_available_contrasts(self, image_dicts): 169 | """ 170 | Select available contrasts as target. 171 | 172 | ===INPUTS=== 173 | * image_dicts: list (num_contrasts, ) 174 | List of dictionaries. Each element is a dictionary received from dataloader. See dataset.py for details. 175 | 176 | ===OUTPUTS=== 177 | * target_image: torch.Tensor (batch_size, 1, image_dim=224, image_dim=224) 178 | Images as target for I2I. 179 | * selected_contrast_id: torch.Tensor (batch_size, num_contrasts) 180 | Indicates which contrast has been selected as target image. 181 | """ 182 | target_image_combined = torch.cat([d['image'] for d in image_dicts], dim=1) 183 | # (batch_size, num_contrasts) 184 | available_contrasts = torch.stack([d['exists'] for d in image_dicts], dim=-1) 185 | subject_ids = available_contrasts.nonzero(as_tuple=True)[0] 186 | contrast_ids = available_contrasts.nonzero(as_tuple=True)[1] 187 | unique_subject_ids = list(torch.unique(subject_ids)) 188 | selected_contrast_ids = [] 189 | for i in unique_subject_ids: 190 | selected_contrast_ids.append(random.choice(contrast_ids[subject_ids == i])) 191 | target_image = target_image_combined[unique_subject_ids, selected_contrast_ids, ...].unsqueeze(1).to( 192 | self.device) 193 | selected_contrast_id = torch.zeros_like(available_contrasts).to(self.device) 194 | selected_contrast_id[unique_subject_ids, selected_contrast_ids, ...] = 1.0 195 | return target_image, selected_contrast_id 196 | 197 | def decode(self, logits, target_theta, query, keys, available_contrast_id, mask, contrast_dropout=False, 198 | contrast_id_to_drop=None): 199 | """ 200 | HACA3 decoding. 201 | 202 | ===INPUTS=== 203 | * logits: list (num_contrasts, ) 204 | Encoded logit of each source image. 205 | Each element has shape (batch_size, self.beta_dim, image_dim, image_dim). 206 | * target_theta: torch.Tensor (batch_size, self.theta_dim, 1, 1) 207 | theta values of target images used for decoding. 208 | * query: torch.Tensor (batch_size, self.theta_dim+self.eta_dim, 1, 1) 209 | query variable. Concatenation of "target_theta" and "target_eta". 210 | * keys: list (num_contrasts, ) 211 | keys variable. Each element has shape (batch_size, self.theta_dim+self.eta_dim) 212 | * available_contrast_id: torch.Tensor (batch_size, num_contrasts) 213 | Indicates which contrasts are available. 1: if available, 0: if unavailable. 214 | * contrast_dropout: bool 215 | Indicates if available contrasts will be randomly dropped out. 216 | 217 | ===OUTPUTS=== 218 | * rec_image: torch.Tensor (batch_size, 1, image_dim, image_dim) 219 | Synthetic image after decoding. 220 | * attention: torch.Tensor (batch_size, num_contrasts) 221 | Learned attention of each source image contrast. 
222 | * logit_fusion: torch.Tensor (batch_size, self.beta_dim, image_dim, image_dim) 223 | Optimal logit after fusion. 224 | * beta_fusion: torch.Tensor (batch_size, self.beta_dim, image_dim, image_dim) 225 | Optimal beta after fusion. beta_fusion = reparameterize_logit(logit_fusion). 226 | """ 227 | num_contrasts = len(logits) 228 | batch_size = logits[0].shape[0] 229 | image_dim = logits[0].shape[-1] 230 | 231 | # logits_combined: (batch_size, self.beta_dim, num_contrasts, image_dim * image_dim) 232 | logits_combined = torch.stack(logits, dim=-1).permute(0, 1, 4, 2, 3) 233 | logits_combined = logits_combined.view(batch_size, self.beta_dim, num_contrasts, image_dim * image_dim) 234 | 235 | # value: (batch_size, self.beta_dim, image_dim*image_dim, num_contrasts) 236 | v = logits_combined.permute(0, 1, 3, 2) 237 | # key: (batch_size, self.theta_dim+self.eta_dim, 1, num_contrasts) 238 | k = torch.cat(keys, dim=-1) 239 | # query: (batch_size, self.theta_dim+self.eta_dim, 1) 240 | q = query.view(batch_size, self.theta_dim + self.eta_dim, 1) 241 | 242 | if contrast_dropout: 243 | available_contrast_id = dropout_contrasts(available_contrast_id, contrast_id_to_drop) 244 | logit_fusion, attention = self.attention_module(q, k, v, modality_dropout=1 - available_contrast_id, 245 | temperature=10.0) 246 | beta_fusion = self.channel_aggregation(reparameterize_logit(logit_fusion)) 247 | combined_map = torch.cat([beta_fusion, target_theta.repeat(1, 1, image_dim, image_dim)], dim=1) 248 | rec_image = self.decoder(combined_map) * mask 249 | return rec_image, attention, logit_fusion, beta_fusion 250 | 251 | def calculate_features_for_contrastive_loss(self, betas, source_images, available_contrast_id): 252 | """ 253 | Prepare query, positive, and negative examples for calculating contrastive loss. 254 | 255 | ===INPUTS=== 256 | * betas: list (num_contrasts, ) 257 | Each element: torch.Tensor, (batch_size, self.beta_dim, 224, 224) 258 | * source_images: list(num_contrasts, ) 259 | Each element: torch.Tensor, (batch_size, 1, 224, 224) 260 | * available_contrast_id: torch.Tensor (batch_size, num_contrasts) 261 | Indicates which contrasts are available. 1: if available, 0: if unavailable. 262 | 263 | ===OUTPUTS=== 264 | * query_features: torch.Tensor (batch_size, 128, num_query_patches=49) 265 | Also called anchor features. Number of feature dimension (128) and 266 | number of patches (49) are determined by self.patchifier. 267 | * positive_features: torch.Tensor (batch_size, 128, num_positive_patches=49) 268 | Positive features are encouraged to be as close to query features as possible. 269 | Number of positive patches should be equal to the number of query patches. 270 | * negative_features: torch.Tensor (batch_size, 128, num_negative_patches) 271 | Negative features served as negative examples. They are pushed away from query features during training. 272 | Number of negative patches does not necessarily equal to "num_query_patches" or "num_positive_patches". 
273 | """ 274 | batch_size = betas[0].shape[0] 275 | betas_stack = torch.stack(betas, dim=-1) 276 | source_images_stack = torch.stack(source_images, dim=-1) 277 | query_contrast_ids, positive_contrast_ids = [], [] 278 | for subject_id in range(batch_size): 279 | contrast_id_tmp = random.sample(set(available_contrast_id[subject_id].nonzero(as_tuple=True)[0]), 2) 280 | query_contrast_ids.append(contrast_id_tmp[0]) 281 | positive_contrast_ids.append(contrast_id_tmp[1]) 282 | query_example = torch.cat([betas_stack[[subject_id], :, :, :, query_contrast_ids[subject_id]] 283 | for subject_id in range(batch_size)], dim=0) 284 | query_feature = self.patchifier(query_example).view(batch_size, 128, -1) 285 | positive_example = torch.cat([betas_stack[[subject_id], :, :, :, positive_contrast_ids[subject_id]] 286 | for subject_id in range(batch_size)], dim=0) 287 | positive_feature = self.patchifier(positive_example).view(batch_size, 128, -1) 288 | num_positive_patches = positive_feature.shape[-1] 289 | negative_feature = torch.cat([ 290 | self.patchifier(torch.cat([source_images_stack[[subject_id], :, :, :, query_contrast_ids[subject_id]] 291 | for subject_id in range(batch_size)], dim=0)).view(batch_size, 128, -1), 292 | self.patchifier(torch.cat([source_images_stack[[subject_id], :, :, :, positive_contrast_ids[subject_id]] 293 | for subject_id in range(batch_size)], dim=0)).view(batch_size, 128, -1), 294 | self.patchifier(torch.cat([betas_stack[[subject_id], :, :, :, query_contrast_ids[subject_id]] 295 | for subject_id in range(batch_size)], dim=0)).view(batch_size, 128, -1)[:, :, 296 | torch.randperm(num_positive_patches)], 297 | self.patchifier(torch.cat([betas_stack[[subject_id], :, :, :, query_contrast_ids[subject_id]] 298 | for subject_id in range(batch_size)], dim=0)).view(batch_size, 128, -1)[ 299 | torch.randperm(batch_size), :, :], 300 | self.patchifier(torch.cat([betas_stack[[subject_id], :, :, :, positive_contrast_ids[subject_id]] 301 | for subject_id in range(batch_size)], dim=0)).view(batch_size, 128, -1)[:, :, 302 | torch.randperm(num_positive_patches)], 303 | self.patchifier(torch.cat([betas_stack[[subject_id], :, :, :, positive_contrast_ids[subject_id]] 304 | for subject_id in range(batch_size)], dim=0)).view(batch_size, 128, -1)[ 305 | torch.randperm(batch_size), :, :] 306 | ], dim=-1) 307 | return query_feature, positive_feature, negative_feature 308 | 309 | def calculate_loss(self, rec_image, ref_image, mask, mu, logvar, betas, source_images, available_contrast_id, 310 | is_train=True): 311 | """ 312 | Calculate losses for HACA3 training and validation. 313 | 314 | """ 315 | # 1. reconstruction loss 316 | rec_loss = self.l1_loss(rec_image[mask], ref_image[mask]).mean() 317 | perceptual_loss = self.perceptual_loss(rec_image, ref_image).mean() 318 | 319 | # 2. KLD loss 320 | kld_loss = self.kld_loss(mu, logvar).mean() 321 | 322 | # 3. 
beta contrastive loss 323 | query_feature, \ 324 | positive_feature, \ 325 | negative_feature = self.calculate_features_for_contrastive_loss(betas, source_images, available_contrast_id) 326 | beta_loss = self.contrastive_loss(query_feature, positive_feature.detach(), negative_feature.detach()) 327 | 328 | # COMBINE LOSSES 329 | total_loss = 10 * rec_loss + 5e-1 * perceptual_loss + 1e-5 * kld_loss + 5e-1 * beta_loss 330 | if is_train: 331 | self.optimizer.zero_grad() 332 | total_loss.backward() 333 | self.optimizer.step() 334 | self.scheduler.step() 335 | loss = {'rec_loss': rec_loss.item(), 336 | 'per_loss': perceptual_loss.item(), 337 | 'kld_loss': kld_loss.item(), 338 | 'beta_loss': beta_loss.item(), 339 | 'total_loss': total_loss.item()} 340 | return loss 341 | 342 | def calculate_cycle_consistency_loss(self, theta_rec, theta_ref, eta_rec, eta_ref, beta_rec, beta_ref, 343 | is_train=True): 344 | theta_loss = self.l1_loss(theta_rec, theta_ref).mean() 345 | eta_loss = self.l1_loss(eta_rec, eta_ref).mean() 346 | beta_loss = self.l1_loss(beta_rec, beta_ref).mean() 347 | 348 | cycle_loss = theta_loss + eta_loss + 5e-2 * beta_loss 349 | if is_train: 350 | self.optimizer.zero_grad() 351 | (5e-2 * cycle_loss).backward() 352 | self.optimizer.step() 353 | self.scheduler.step() 354 | loss = {'theta_cyc': theta_loss.item(), 355 | 'eta_cyc': eta_loss.item(), 356 | 'beta_cyc': beta_loss.item()} 357 | return loss 358 | 359 | def write_tensorboard(self, loss, epoch, batch_id, train_or_valid='train', cycle_loss=None): 360 | if train_or_valid == 'train': 361 | curr_iteration = (epoch - 1) * len(self.train_loader) + batch_id 362 | self.writer.add_scalar(f'{train_or_valid}/learning rate', self.scheduler.get_last_lr()[0], curr_iteration) 363 | else: 364 | curr_iteration = (epoch - 1) * len(self.valid_loader) + batch_id 365 | self.writer.add_scalar(f'{train_or_valid}/reconstruction loss', loss['rec_loss'], curr_iteration) 366 | self.writer.add_scalar(f'{train_or_valid}/perceptual loss', loss['per_loss'], curr_iteration) 367 | self.writer.add_scalar(f'{train_or_valid}/kld loss', loss['kld_loss'], curr_iteration) 368 | self.writer.add_scalar(f'{train_or_valid}/beta loss', loss['beta_loss'], curr_iteration) 369 | self.writer.add_scalar(f'{train_or_valid}/total loss', loss['total_loss'], curr_iteration) 370 | if cycle_loss is not None: 371 | self.writer.add_scalar(f'{train_or_valid}/theta cycle loss', cycle_loss['theta_cyc'], curr_iteration) 372 | self.writer.add_scalar(f'{train_or_valid}/eta cycle loss', cycle_loss['eta_cyc'], curr_iteration) 373 | self.writer.add_scalar(f'{train_or_valid}/beta cycle loss', cycle_loss['beta_cyc'], curr_iteration) 374 | 375 | def save_model(self, epoch, file_name): 376 | state = {'epoch': epoch, 377 | 'timestr': self.timestr, 378 | 'beta_encoder': self.beta_encoder.state_dict(), 379 | 'theta_encoder': self.theta_encoder.state_dict(), 380 | 'eta_encoder': self.eta_encoder.state_dict(), 381 | 'decoder': self.decoder.state_dict(), 382 | 'attention_module': self.attention_module.state_dict(), 383 | 'patchifier': self.patchifier.state_dict(), 384 | 'optimizer': self.optimizer.state_dict(), 385 | 'scheduler': self.scheduler.state_dict()} 386 | torch.save(obj=state, f=file_name) 387 | 388 | def image_to_image_translation(self, batch_id, epoch, image_dicts, train_or_valid): 389 | if train_or_valid == 'train': 390 | contrast_dropout = True 391 | is_train = True 392 | else: 393 | contrast_dropout = False 394 | is_train = False 395 | 396 | source_images = 
self.prepare_source_images(image_dicts) 397 | mask = image_dicts[0]['mask'].to(self.device) 398 | target_image, contrast_id_for_decoding = self.select_available_contrasts(image_dicts) 399 | # available_contrast_id: (batch_size, num_contrasts). 1: if available, 0: otherwise. 400 | available_contrast_id = torch.stack([d['exists'] for d in image_dicts], dim=-1).to(self.device) 401 | batch_size = source_images[0].shape[0] 402 | 403 | # ====== 1. INTRA-SITE IMAGE-TO-IMAGE TRANSLATION ====== 404 | logits, betas = self.calculate_beta(source_images) 405 | thetas_source, _, _ = self.calculate_theta(source_images) 406 | etas_source = self.calculate_eta(source_images) 407 | theta_target, mu_target, logvar_target = self.calculate_theta(target_image) 408 | eta_target = self.calculate_eta(target_image) 409 | query = torch.cat([theta_target, eta_target], dim=1) 410 | keys = [torch.cat([theta, eta], dim=1) for (theta, eta) in zip(thetas_source, etas_source)] 411 | if torch.rand((1,)) > 0.2: 412 | contrast_id_to_drop = contrast_id_for_decoding 413 | else: 414 | contrast_id_to_drop = None 415 | rec_image, attention, logit_fusion, beta_fusion = self.decode(logits, theta_target, query, keys, 416 | available_contrast_id, 417 | mask, 418 | contrast_dropout=contrast_dropout, 419 | contrast_id_to_drop=contrast_id_to_drop) 420 | loss = self.calculate_loss(rec_image, target_image, mask, mu_target, logvar_target, 421 | betas, source_images, available_contrast_id, is_train=is_train) 422 | 423 | # ====== 2. SAVE IMAGES OF INTRA-SITE I2I ====== 424 | if batch_id % 100 == 1: 425 | file_name = os.path.join(self.out_dir, f'training_results_{self.timestr}', 426 | f'{train_or_valid}_epoch{str(epoch).zfill(3)}_batch{str(batch_id).zfill(4)}' 427 | '_intra-site.nii.gz') 428 | save_image(source_images + [rec_image] + [target_image] + betas + [beta_fusion], file_name) 429 | 430 | # ====== 3. INTER-SITE IMAGE-TO-IMAGE TRANSLATION ====== 431 | if epoch > 1: 432 | random_index = torch.randperm(batch_size) 433 | target_image_shuffled = target_image[random_index, ...] 434 | logits, betas = self.calculate_beta(source_images) 435 | thetas_source, _, _ = self.calculate_theta(source_images) 436 | etas_source = self.calculate_eta(source_images) 437 | theta_target, mu_target, logvar_target = self.calculate_theta(target_image_shuffled) 438 | eta_target = self.calculate_eta(target_image_shuffled) 439 | query = torch.cat([theta_target, eta_target], dim=1) 440 | keys = [torch.cat([theta, eta], dim=1) for (theta, eta) in zip(thetas_source, etas_source)] 441 | rec_image, attention, logit_fusion, beta_fusion = self.decode(logits, theta_target, query, keys, 442 | available_contrast_id, mask, 443 | contrast_dropout=True) 444 | theta_recon, _ = self.theta_encoder(rec_image) 445 | eta_recon = self.eta_encoder(rec_image) 446 | beta_recon = self.channel_aggregation(reparameterize_logit(self.beta_encoder(rec_image))) 447 | cycle_loss = self.calculate_cycle_consistency_loss(theta_recon, theta_target.detach(), 448 | eta_recon, eta_target.detach(), 449 | beta_recon, beta_fusion.detach(), 450 | is_train=is_train) 451 | 452 | # ====== 4. SAVE IMAGES FOR INTER-SITE I2I ====== 453 | if epoch > 1 and batch_id % 100 == 1: 454 | file_name = os.path.join(self.out_dir, f'training_results_{self.timestr}', 455 | f'{train_or_valid}_epoch{str(epoch).zfill(3)}_batch{str(batch_id).zfill(4)}' 456 | '_inter-site.nii.gz') 457 | save_image(source_images + [rec_image] + [target_image_shuffled] + betas + [beta_fusion], file_name) 458 | 459 | # ====== 5. 
VISUALIZE LOSSES FOR INTRA- AND INTER-SITE I2I ====== 460 | if epoch > 1: 461 | if is_train: 462 | self.train_loader.set_description((f'epoch: {epoch}; ' 463 | f'rec: {loss["rec_loss"]:.3f}; ' 464 | f'per: {loss["per_loss"]:.3f}; ' 465 | f'kld: {loss["kld_loss"]:.3f}; ' 466 | f'beta: {loss["beta_loss"]:.3f}; ' 467 | f'theta_c: {cycle_loss["theta_cyc"]:.3f}; ' 468 | f'eta_c: {cycle_loss["eta_cyc"]:.3f}; ' 469 | f'beta_c: {cycle_loss["beta_cyc"]:.3f}; ')) 470 | else: 471 | self.valid_loader.set_description((f'epoch: {epoch}; ' 472 | f'rec: {loss["rec_loss"]:.3f}; ' 473 | f'per: {loss["per_loss"]:.3f}; ' 474 | f'kld: {loss["kld_loss"]:.3f}; ' 475 | f'beta: {loss["beta_loss"]:.3f}; ' 476 | f'theta_c: {cycle_loss["theta_cyc"]:.3f}; ' 477 | f'eta_c: {cycle_loss["eta_cyc"]:.3f}; ' 478 | f'beta_c: {cycle_loss["beta_cyc"]:.3f}; ')) 479 | self.write_tensorboard(loss, epoch, batch_id, train_or_valid, cycle_loss) 480 | else: 481 | if is_train: 482 | self.train_loader.set_description((f'epoch: {epoch}; ' 483 | f'rec: {loss["rec_loss"]:.3f}; ' 484 | f'per: {loss["per_loss"]:.3f}; ' 485 | f'kld: {loss["kld_loss"]:.3f}; ' 486 | f'beta: {loss["beta_loss"]:.3f}; ')) 487 | else: 488 | self.valid_loader.set_description((f'epoch: {epoch}; ' 489 | f'rec: {loss["rec_loss"]:.3f}; ' 490 | f'per: {loss["per_loss"]:.3f}; ' 491 | f'kld: {loss["kld_loss"]:.3f}; ' 492 | f'beta: {loss["beta_loss"]:.3f}; ')) 493 | self.write_tensorboard(loss, epoch, batch_id, train_or_valid) 494 | 495 | # ====== 6. SAVE TRAINED MODELS ====== 496 | if batch_id % 2000 == 0 and is_train: 497 | file_name = os.path.join(self.out_dir, f'training_models_{self.timestr}', 498 | f'epoch{str(epoch).zfill(3)}_batch{str(batch_id).zfill(4)}_model.pt') 499 | self.save_model(epoch, file_name) 500 | 501 | def train(self, epochs): 502 | for epoch in range(self.start_epoch, epochs + 1): 503 | # ====== 1. TRAINING ====== 504 | self.train_loader = tqdm(self.train_loader) 505 | self.eta_encoder.eval() 506 | self.theta_encoder.train() 507 | self.beta_encoder.train() 508 | self.decoder.train() 509 | self.attention_module.train() 510 | self.patchifier.train() 511 | for batch_id, image_dicts in enumerate(self.train_loader): 512 | self.image_to_image_translation(batch_id, epoch, image_dicts, train_or_valid='train') 513 | 514 | # ====== 2. VALIDATION ====== 515 | self.valid_loader = tqdm(self.valid_loader) 516 | self.beta_encoder.eval() 517 | self.eta_encoder.eval() 518 | self.theta_encoder.eval() 519 | self.decoder.eval() 520 | self.patchifier.eval() 521 | self.attention_module.eval() 522 | with torch.set_grad_enabled(False): 523 | for batch_id, image_dicts in enumerate(self.valid_loader): 524 | self.image_to_image_translation(batch_id, epoch, image_dicts, train_or_valid='valid') 525 | 526 | def harmonize(self, source_images, target_images, target_theta, target_eta, out_paths, 527 | recon_orientation, norm_vals, header=None, num_batches=4, save_intermediate=False, intermediate_out_dir=None): 528 | if out_paths is not None: 529 | for out_path in out_paths: 530 | mkdir_p(out_path.parent) 531 | if save_intermediate: 532 | mkdir_p(intermediate_out_dir) 533 | if out_paths is not None: 534 | prefix = out_paths[0].name.replace('.nii.gz', '') 535 | with torch.set_grad_enabled(False): 536 | self.beta_encoder.eval() 537 | self.theta_encoder.eval() 538 | self.eta_encoder.eval() 539 | self.decoder.eval() 540 | 541 | # === 1. 
CALCULATE BETA, THETA, ETA FROM SOURCE IMAGES === 542 | logits, betas, keys, masks = [], [], [], [] 543 | for source_image in source_images: 544 | source_image = source_image.unsqueeze(1) 545 | source_image_batches = divide_into_batches(source_image, num_batches) 546 | mask_tmp, logit_tmp, beta_tmp, key_tmp = [], [], [], [] 547 | for source_image_batch in source_image_batches: 548 | batch_size = source_image_batch.shape[0] 549 | source_image_batch = source_image_batch.to(self.device) 550 | mask = (source_image_batch > 1e-6) * 1.0 551 | logit = self.beta_encoder(source_image_batch) 552 | beta = self.channel_aggregation(reparameterize_logit(logit)) 553 | theta_source, _ = self.theta_encoder(source_image_batch) 554 | eta_source = self.eta_encoder(source_image_batch).view(batch_size, self.eta_dim, 1, 1) 555 | mask_tmp.append(mask) 556 | logit_tmp.append(logit) 557 | beta_tmp.append(beta) 558 | key_tmp.append(torch.cat([theta_source, eta_source], dim=1)) 559 | masks.append(torch.cat(mask_tmp, dim=0)) 560 | logits.append(torch.cat(logit_tmp, dim=0)) 561 | betas.append(torch.cat(beta_tmp, dim=0)) 562 | keys.append(torch.cat(key_tmp, dim=0)) 563 | 564 | # === 2. CALCULATE THETA, ETA FOR TARGET IMAGES (IF NEEDED) === 565 | if target_theta is None: 566 | queries, thetas_target = [], [] 567 | for target_image in target_images: 568 | target_image = target_image.to(self.device).unsqueeze(1) 569 | theta_target, _ = self.theta_encoder(target_image) 570 | theta_target = theta_target.mean(dim=0, keepdim=True) 571 | eta_target = self.eta_encoder(target_image).mean(dim=0, keepdim=True).view(1, self.eta_dim, 1, 1) 572 | thetas_target.append(theta_target) 573 | queries.append( 574 | torch.cat([theta_target, eta_target], dim=1).view(1, self.theta_dim + self.eta_dim, 1)) 575 | if save_intermediate: 576 | # save theta and eta of target images 577 | with open(intermediate_out_dir / f'{prefix}_targets.txt', 'w') as fp: 578 | fp.write(','.join(['img'] + [f'theta{i}' for i in range(self.theta_dim)] + 579 | [f'eta{i}' for i in range(self.eta_dim)]) + '\n') 580 | for i, img_query in enumerate([query.squeeze().cpu().numpy().tolist() for query in queries]): 581 | fp.write(','.join([f'target{i}'] + ['%.6f' % val for val in img_query]) + '\n') 582 | else: 583 | queries, thetas_target = [], [] 584 | for target_theta_tmp, target_eta_tmp in zip(target_theta, target_eta): 585 | thetas_target.append(target_theta_tmp.view(1, self.theta_dim, 1, 1).to(self.device)) 586 | queries.append(torch.cat([target_theta_tmp.view(1, self.theta_dim, 1).to(self.device), 587 | target_eta_tmp.view(1, self.eta_dim, 1).to(self.device)], dim=1)) 588 | 589 | # === 3. SAVE ENCODED VARIABLES (IF REQUESTED) === 590 | if save_intermediate and header is not None: 591 | if recon_orientation == 'axial': 592 | # 3a. source images 593 | for i, source_img in enumerate(source_images): 594 | img_save = source_img.squeeze().permute(1, 2, 0).permute(1, 0, 2).cpu().numpy() 595 | img_save = img_save[112 - 96:112 + 96, :, 112 - 96:112 + 96] 596 | nib.Nifti1Image(img_save, None, header).to_filename( 597 | intermediate_out_dir / f'{prefix}_source{i}.nii.gz' 598 | ) 599 | # 3b. beta images 600 | beta = torch.stack(betas, dim=-1) 601 | if len(beta.shape) > 4: 602 | beta = beta.squeeze() 603 | beta = beta.permute(1, 2, 0, 3).permute(1, 0, 2, 3).cpu().numpy() 604 | img_save = nib.Nifti1Image(beta[112 - 96:112 + 96, :, 112 - 96:112 + 96, :], None, header) 605 | file_name = intermediate_out_dir / f'{prefix}_source_betas.nii.gz' 606 | nib.save(img_save, file_name) 607 | # 3c. 
theta/eta values 608 | with open(intermediate_out_dir / f'{prefix}_sources.txt', 'w') as fp: 609 | fp.write(','.join(['img', 'slice'] + [f'theta{i}' for i in range(self.theta_dim)] + 610 | [f'eta{i}' for i in range(self.eta_dim)]) + '\n') 611 | for i, img_key in enumerate([key.squeeze().cpu().numpy().tolist() for key in keys]): 612 | for j, slice_key in enumerate(img_key): 613 | fp.write(','.join([f'source{i}', f'slice{j:03d}'] + 614 | ['%.6f' % val for val in slice_key]) + '\n') 615 | 616 | # ===4. DECODING=== 617 | for tid, (theta_target, query, norm_val) in enumerate(zip(thetas_target, queries, norm_vals)): 618 | if out_paths is not None: 619 | out_prefix = out_paths[tid].name.replace('.nii.gz', '') 620 | rec_image, beta_fusion, logit_fusion, attention = [], [], [], [] 621 | for batch_id in range(num_batches): 622 | keys_tmp = [divide_into_batches(ks, num_batches)[batch_id] for ks in keys] 623 | logits_tmp = [divide_into_batches(ls, num_batches)[batch_id] for ls in logits] 624 | masks_tmp = [divide_into_batches(ms, num_batches)[batch_id] for ms in masks] 625 | batch_size = keys_tmp[0].shape[0] 626 | query_tmp = query.view(1, self.theta_dim + self.eta_dim, 1).repeat(batch_size, 1, 1) 627 | k = torch.cat(keys_tmp, dim=-1).view(batch_size, self.theta_dim + self.eta_dim, 1, len(source_images)) 628 | v = torch.stack(logits_tmp, dim=-1).view(batch_size, self.beta_dim, 224 * 224, len(source_images)) 629 | logit_fusion_tmp, attention_tmp = self.attention_module(query_tmp, k, v, None, 5.0) 630 | beta_fusion_tmp = self.channel_aggregation(reparameterize_logit(logit_fusion_tmp)) 631 | combined_map = torch.cat([beta_fusion_tmp, theta_target.repeat(batch_size, 1, 224, 224)], dim=1) 632 | rec_image_tmp = self.decoder(combined_map) * masks_tmp[0] 633 | 634 | rec_image.append(rec_image_tmp) 635 | beta_fusion.append(beta_fusion_tmp) 636 | logit_fusion.append(logit_fusion_tmp) 637 | attention.append(attention_tmp) 638 | 639 | rec_image = torch.cat(rec_image, dim=0) 640 | beta_fusion = torch.cat(beta_fusion, dim=0) 641 | logit_fusion = torch.cat(logit_fusion, dim=0) 642 | attention = torch.cat(attention, dim=0) 643 | 644 | # ===5. SAVE INTERMEDIATE RESULTS (IF REQUESTED)=== 645 | # harmonized image 646 | if header is not None: 647 | if recon_orientation == "axial": 648 | img_save = np.array(rec_image.cpu().squeeze().permute(1, 2, 0).permute(1, 0, 2)) 649 | elif recon_orientation == "coronal": 650 | img_save = np.array(rec_image.cpu().squeeze().permute(0, 2, 1).flip(2).permute(1, 0, 2)) 651 | else: 652 | img_save = np.array(rec_image.cpu().squeeze().permute(2, 0, 1).flip(2).permute(1, 0, 2)) 653 | img_save = nib.Nifti1Image((img_save[112 - 96:112 + 96, :, 112 - 96:112 + 96]) * norm_val, None, 654 | header) 655 | file_name = out_path.parent / f'{out_prefix}_harmonized_{recon_orientation}.nii.gz' 656 | nib.save(img_save, file_name) 657 | 658 | if save_intermediate and header is not None: 659 | # 5a. beta fusion 660 | if recon_orientation == 'axial': 661 | img_save = beta_fusion.squeeze().permute(1, 2, 0).permute(1, 0, 2).cpu().numpy() 662 | img_save = nib.Nifti1Image(img_save[112 - 96:112 + 96, :, 112 - 96:112 + 96], None, header) 663 | file_name = intermediate_out_dir / f'{out_prefix}_beta_fusion.nii.gz' 664 | nib.save(img_save, file_name) 665 | # 5b. 
logit fusion 666 | if recon_orientation == 'axial': 667 | img_save = logit_fusion.permute(2, 3, 0, 1).permute(1, 0, 2, 3).cpu().numpy() 668 | img_save = nib.Nifti1Image(img_save[112 - 96:112 + 96, :, 112 - 96:112 + 96, :], None, header) 669 | file_name = intermediate_out_dir / f'{out_prefix}_logit_fusion.nii.gz' 670 | nib.save(img_save, file_name) 671 | # 5c. attention 672 | if recon_orientation == 'axial': 673 | img_save = attention.permute(2, 3, 0, 1).permute(1, 0, 2, 3).cpu().numpy() 674 | img_save = nib.Nifti1Image(img_save[112 - 96:112 + 96, :, 112 - 96:112 + 96], None, header) 675 | file_name = intermediate_out_dir / f'{out_prefix}_attention.nii.gz' 676 | nib.save(img_save, file_name) 677 | if header is None: 678 | return rec_image.cpu().squeeze() 679 | 680 | def combine_images(self, image_paths, out_path, norm_val, pretrained_fusion=None): 681 | # obtain images 682 | images = [] 683 | for image_path in image_paths: 684 | image_pad = torch.zeros((224, 224, 224)) 685 | image_obj = nib.load(image_path) 686 | image_vol, _ = normalize_intensity(torch.from_numpy(image_obj.get_fdata().astype(np.float32))) 687 | image_pad[112 - 96:112 + 96, :, 112 - 96:112 + 96] = image_vol 688 | image_header = image_obj.header 689 | images.append(image_pad.numpy()) 690 | 691 | if pretrained_fusion is not None: 692 | checkpoint = torch.load(pretrained_fusion, map_location=self.device) 693 | fusion_net = FusionNet(in_ch=3, out_ch=1) 694 | fusion_net.load_state_dict(checkpoint['fusion_net']) 695 | fusion_net.to(self.device) 696 | fusion_net.eval() 697 | with autocast(): 698 | image = torch.cat( 699 | [ToTensor()(im).permute(2, 1, 0).permute(2, 0, 1).unsqueeze(0).unsqueeze(0) for im in images], 700 | dim=1).to(self.device) 701 | image_fusion = fusion_net(image).squeeze().detach().permute(1, 2, 0).permute(1, 0, 2).cpu().numpy() 702 | else: 703 | # calculate median 704 | image_cat = np.stack(images, axis=-1) 705 | image_fusion = np.median(image_cat, axis=-1) 706 | 707 | # save fusion_image 708 | img_save = image_fusion[112 - 96:112 + 96, :, 112 - 96:112 + 96] * norm_val 709 | img_save = nib.Nifti1Image(img_save, None, image_header) 710 | prefix = out_path.name.replace('.nii.gz', '') 711 | file_name = out_path.parent / f'{prefix}_harmonized_fusion.nii.gz' 712 | nib.save(img_save, file_name) 713 | --------------------------------------------------------------------------------
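The `combine_images` method above fuses the axial, coronal, and sagittal harmonized volumes into a single 3D output. Below is a minimal, hypothetical sketch of calling it from Python; the latent dimensions, file names, and normalization value are placeholders rather than the released configuration, and the per-orientation inputs are assumed to be the files written by `harmonize`.

```python
# Hypothetical usage sketch; dimensions, paths, and norm_val are placeholders.
from pathlib import Path
from haca3.modules.model import HACA3

# beta_dim/theta_dim/eta_dim assumed from the defaults in network.py and the README description of theta.
model = HACA3(beta_dim=5, theta_dim=2, eta_dim=2, gpu_id=0)

orientations = ('axial', 'coronal', 'sagittal')
model.combine_images(
    image_paths=[Path(f'output_directory/subject_harmonized_{o}.nii.gz') for o in orientations],
    out_path=Path('output_directory/subject.nii.gz'),
    norm_val=1000.0,                 # placeholder intensity scale applied to the fused output
    pretrained_fusion='fusion.pt')   # pretrained fusion weights (see README Section 3.3)
```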