├── dataloaders
│   ├── __init__.py
│   ├── ixi.py
│   ├── torchiowrap.py
│   ├── ixi_torchiowrap.py
│   └── mood.py
├── requirements.txt
├── engine.py
├── .gitignore
├── ceVae
│   ├── skipae.py
│   ├── helpers.py
│   ├── ce_noise.py
│   ├── aes.py
│   └── ae_bases.py
├── misc
│   ├── skipae.py
│   ├── cevae.py
│   └── ssae.py
├── ccevae.py
├── train.py
├── README.md
├── helpers.py
├── ce_noise.py
├── evaluate.py
├── Pipeline.ipynb
├── ae_bases.py
└── DEMO.md

/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .mood import MoodTrainSet, MoodValSet 2 | from .torchiowrap import H5DSImage 3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # StRegA Requirements 2 | # Core dependencies 3 | torch>=1.9.0 4 | torchvision>=0.10.0 5 | numpy>=1.19.0 6 | scipy>=1.7.0 7 | 8 | # Medical imaging 9 | nibabel>=3.2.0 10 | h5py>=3.0.0 11 | torchio>=0.18.0 12 | 13 | # Image processing 14 | scikit-image>=0.18.0 15 | 16 | # Utilities 17 | tqdm>=4.60.0 18 | pandas>=1.3.0 19 | matplotlib>=3.4.0 20 | seaborn>=0.11.0 21 | 22 | # HuggingFace (for pretrained model) 23 | transformers>=4.20.0 24 | 25 | # Optional: Experiment tracking 26 | # wandb>=0.12.0 27 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | from ccevae import VAE # the compact cceVAE model is defined in ccevae.py 2 | from torch.optim import Adam 3 | import torch 4 | import torch.nn as nn 5 | from torch.cuda.amp import GradScaler 6 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 7 | input_dim = (256, 256) # Slice shapes 8 | input_size = (1,256,256) 9 | z_dim = 1024 10 | model_feature_map_sizes=(16, 64, 256, 1024) # Compact vae 11 | 12 | if len(input_dim) == 2: 13 | conv = nn.Conv2d 14 | convt = nn.ConvTranspose2d 15 | d = 2 16 | else: 17 | conv = nn.Conv3d 18 | convt = nn.ConvTranspose3d 19 | d = 3 20 | 21 | model = VAE(input_size=input_size, z_dim=z_dim, fmap_sizes=model_feature_map_sizes, 22 | conv_op=conv, 23 | tconv_op=convt, 24 | activation_op=torch.nn.PReLU) 25 | 26 | model.d = d 27 | model.to(device) 28 | 29 | lr = 1e-4 30 | optimizer = Adam(model.parameters(), lr=lr) 31 | scaler = GradScaler() 32 | 33 | # Call train function in train.py with created dataloaders
--------------------------------------------------------------------------------
/dataloaders/ixi.py:
--------------------------------------------------------------------------------
1 | import h5py as h5 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchio.data.subject import Subject 7 | from .ixi_torchiowrap import IXI_H5DSImage 8 | 9 | class IXITrainSet(Dataset): 10 | def __init__(self, indices=None, data_path='Ixi_with_skull.h5', torchiosub=True, lazypatch=True, preload=False): 11 | self.h5 = h5.File(data_path, 'r', swmr=True) 12 | self.samples = [] 13 | if indices: 14 | self.samples = [self.h5[str(i).zfill(5)] for i in indices] 15 | # self.samples2 = [self.h5[region][str(i).zfill(5)][:] for i in indices] 16 | else: 17 | self.samples = [self.h5[i] for i in list(self.h5)] 18 | if preload: 19 | print('Preloading IXITrainSet') 20 | for i in range(len(self.samples)): 21 | self.samples[i] = self.samples[i][:] 22 | self.torchiosub = torchiosub 23 | self.lazypatch = lazypatch 24 | 25 | def __len__(self): 26 | return len(self.samples) 27 | 28 | def __getitem__(self, item): 29 | if self.torchiosub: 30 | return
Subject({'img':IXI_H5DSImage(self.samples[item], lazypatch=self.lazypatch)}) 31 | else: 32 | return torch.from_numpy(self.samples[item][()]).unsqueeze(0) 33 | 34 | -------------------------------------------------------------------------------- /dataloaders/torchiowrap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import nibabel as nib 4 | 5 | from torchio.data.image import Image 6 | import torchio 7 | 8 | class H5DSImage(Image): 9 | def __init__(self, h5DS=None, lazypatch=True, imtype=torchio.INTENSITY, **kwargs): 10 | kwargs['path'] = '' 11 | kwargs['type'] = imtype 12 | super().__init__(**kwargs) 13 | self.h5DS = h5DS 14 | self.lazypatch = lazypatch 15 | if not self.lazypatch: 16 | self.load() 17 | 18 | def load(self) -> None: 19 | if self._loaded: 20 | return 21 | if self.lazypatch: 22 | tensor, affine = self.h5DS, np.eye(4) 23 | else: 24 | tensor, affine = self.read_and_check_h5(self.h5DS) 25 | self[torchio.DATA] = tensor 26 | self[torchio.AFFINE] = affine 27 | self._loaded = True 28 | 29 | @property 30 | def spatial_shape(self): 31 | if self.lazypatch: 32 | return self.shape 33 | else: 34 | return self.shape[1:] 35 | 36 | def crop(self, index_ini, index_fin): 37 | new_origin = nib.affines.apply_affine(self.affine, index_ini) 38 | new_affine = self.affine.copy() 39 | new_affine[:3, 3] = new_origin 40 | i0, j0, k0 = index_ini 41 | i1, j1, k1 = index_fin 42 | if len(self.data.shape) == 4: 43 | patch = self.data[:, i0:i1, j0:j1, k0:k1] 44 | else: 45 | patch = np.expand_dims(self.data[i0:i1, j0:j1, k0:k1], 0) 46 | if not isinstance(self.data, torch.Tensor): 47 | patch = torch.from_numpy(patch) 48 | kwargs = dict( 49 | tensor=patch, 50 | affine=new_affine, 51 | type=self.type, 52 | path=self.path, 53 | h5DS=self.h5DS 54 | ) 55 | for key, value in self.items(): 56 | if key in torchio.data.image.PROTECTED_KEYS: continue 57 | kwargs[key] = value 58 | return self.__class__(**kwargs) 59 | 60 | def read_and_check_h5(self, h5DS): 61 | tensor, affine = torch.from_numpy(h5DS[()]).unsqueeze(0), np.eye(4) 62 | tensor = super().parse_tensor_shape(tensor) 63 | if self.channels_last: 64 | tensor = tensor.permute(3, 0, 1, 2) 65 | if self.check_nans and torch.isnan(tensor).any(): 66 | warnings.warn(f'NaNs found in file "{path}"') 67 | return tensor, affine -------------------------------------------------------------------------------- /dataloaders/ixi_torchiowrap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import nibabel as nib 4 | 5 | from torchio.data.image import Image 6 | import torchio 7 | 8 | class IXI_H5DSImage(Image): 9 | def __init__(self, h5DS=None, lazypatch=True, imtype=torchio.INTENSITY, **kwargs): 10 | kwargs['path'] = '' 11 | kwargs['type'] = imtype 12 | super().__init__(**kwargs) 13 | self.h5DS = h5DS 14 | self.lazypatch = lazypatch 15 | 16 | if not self.lazypatch: 17 | self.load() 18 | 19 | def load(self) -> None: 20 | if self._loaded: 21 | return 22 | if self.lazypatch: 23 | tensor, affine = self.h5DS, np.eye(4) 24 | else: 25 | tensor, affine = self.read_and_check_h5(self.h5DS) 26 | self[torchio.DATA] = tensor 27 | self[torchio.AFFINE] = affine 28 | self._loaded = True 29 | 30 | @property 31 | def spatial_shape(self): 32 | if self.lazypatch: 33 | return self.shape 34 | else: 35 | return self.shape[1:] 36 | 37 | def crop(self, index_ini, index_fin): 38 | new_origin = nib.affines.apply_affine(self.affine, index_ini) 39 | 
new_affine = self.affine.copy() 40 | new_affine[:3, 3] = new_origin 41 | i0, j0, k0 = index_ini 42 | i1, j1, k1 = index_fin 43 | if len(self.data.shape) == 4: 44 | patch = self.data[:, i0:i1, j0:j1, k0:k1] 45 | else: 46 | patch = np.expand_dims(self.data[i0:i1, j0:j1, k0:k1], 0) 47 | if not isinstance(self.data, torch.Tensor): 48 | patch = torch.from_numpy(patch) 49 | kwargs = dict( 50 | tensor=patch, 51 | affine=new_affine, 52 | type=self.type, 53 | path=self.path, 54 | h5DS=self.h5DS 55 | ) 56 | for key, value in self.items(): 57 | if key in torchio.data.image.PROTECTED_KEYS: continue 58 | kwargs[key] = value 59 | return self.__class__(**kwargs) 60 | 61 | def read_and_check_h5(self, h5DS): 62 | tensor, affine = torch.from_numpy(h5DS[()]).unsqueeze(0), np.eye(4) 63 | tensor = super().parse_tensor_shape(tensor) 64 | if self.channels_last: 65 | tensor = tensor.permute(3, 0, 1, 2) 66 | if self.check_nans and torch.isnan(tensor).any(): 67 | warnings.warn(f'NaNs found in file "{path}"') 68 | return tensor, affine -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | *.ptrh 131 | *.swp 132 | /HuggingFace 133 | /Xmodels 134 | -------------------------------------------------------------------------------- /ceVae/skipae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributions as dist 4 | import torch.nn as nn 5 | 6 | # Model 7 | def down_conv(in_channels, out_channels, kernel_size, padding, stride): 8 | block = nn.Sequential( 9 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride), 10 | nn.BatchNorm2d(out_channels), 11 | nn.ReLU(inplace=True) 12 | ) 13 | return block 14 | 15 | def up_conv(in_channels, out_channels, kernel_size, padding, stride, output_padding): 16 | block = nn.Sequential( 17 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride, output_padding=output_padding), 18 | nn.BatchNorm2d(out_channels), 19 | nn.ReLU(inplace=True) 20 | ) 21 | return block 22 | 23 | class Skip_AE(nn.Module): 24 | def __init__(self): 25 | super(Skip_AE, self).__init__() 26 | 27 | self.down_conv1 = down_conv(in_channels = 1, out_channels = 64, kernel_size = 5, padding = 2, stride = 1) 28 | self.down_conv2 = down_conv(in_channels = 64, out_channels = 128, kernel_size = 5, padding = 2, stride = 2) 29 | self.down_conv3 = down_conv(in_channels = 128, out_channels = 256, kernel_size = 5, padding = 2, stride = 2) 30 | self.down_conv4 = down_conv(in_channels = 256, out_channels = 512, kernel_size = 5, padding = 2, stride = 2) 31 | self.down_conv5 = down_conv(in_channels = 512, out_channels = 64, kernel_size = 5, padding = 2, stride = 2) 32 | 33 | self.up_conv1 = up_conv(in_channels=64, out_channels=512, kernel_size=5, stride=2, padding=2, output_padding = 1) 34 | self.up_conv2 = up_conv(in_channels=512, out_channels=256, kernel_size=5, stride=2, padding=2, output_padding = 1) 35 | self.up_conv3 = up_conv(in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, output_padding = 1) 36 | self.up_conv4 = up_conv(in_channels=128, out_channels=64, kernel_size=5, stride=2, padding=2, output_padding = 1) 37 | self.up_conv5 = up_conv(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, output_padding=0) 38 | 39 | def forward(self, image): 40 | #encoder 41 | x1 = self.down_conv1(image) 42 | x2 = self.down_conv2(x1) 43 | x3 = self.down_conv3(x2) 44 | x4 = self.down_conv4(x3) 45 | x5 = self.down_conv5(x4) 46 | 47 | #decoder 48 | y1 = self.up_conv1(x5) 49 | y1 = y1 + x4 50 | y2 = self.up_conv2(y1) 51 | y2 = nn.Dropout(0.1)(y2 + x3) 52 | y3 = self.up_conv3(y2) 53 | y3 = nn.Dropout(0.1)(y3 + x2) 54 | y4 = self.up_conv4(y3) 55 | y4 = nn.Dropout(0.1)(y4 + x1) 56 | y5 = self.up_conv5(y4) 57 | out = torch.tanh(y5) 58 | return out -------------------------------------------------------------------------------- /misc/skipae.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributions as dist 4 | import torch.nn as nn 5 | 6 | # Model 7 | def down_conv(in_channels, out_channels, kernel_size, padding, stride): 8 | block = nn.Sequential( 9 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride), 10 | nn.BatchNorm2d(out_channels), 11 | nn.ReLU(inplace=True) 12 | ) 13 | return block 14 | 15 | def up_conv(in_channels, out_channels, kernel_size, padding, stride, output_padding): 16 | block = nn.Sequential( 17 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride, output_padding=output_padding), 18 | nn.BatchNorm2d(out_channels), 19 | nn.ReLU(inplace=True) 20 | ) 21 | return block 22 | 23 | class Skip_AE(nn.Module): 24 | def __init__(self): 25 | super(Skip_AE, self).__init__() 26 | 27 | self.down_conv1 = down_conv(in_channels = 1, out_channels = 64, kernel_size = 5, padding = 2, stride = 1) 28 | self.down_conv2 = down_conv(in_channels = 64, out_channels = 128, kernel_size = 5, padding = 2, stride = 2) 29 | self.down_conv3 = down_conv(in_channels = 128, out_channels = 256, kernel_size = 5, padding = 2, stride = 2) 30 | self.down_conv4 = down_conv(in_channels = 256, out_channels = 512, kernel_size = 5, padding = 2, stride = 2) 31 | self.down_conv5 = down_conv(in_channels = 512, out_channels = 64, kernel_size = 5, padding = 2, stride = 2) 32 | 33 | self.up_conv1 = up_conv(in_channels=64, out_channels=512, kernel_size=5, stride=2, padding=2, output_padding = 1) 34 | self.up_conv2 = up_conv(in_channels=512, out_channels=256, kernel_size=5, stride=2, padding=2, output_padding = 1) 35 | self.up_conv3 = up_conv(in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, output_padding = 1) 36 | self.up_conv4 = up_conv(in_channels=128, out_channels=64, kernel_size=5, stride=2, padding=2, output_padding = 1) 37 | self.up_conv5 = up_conv(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, output_padding=0) 38 | 39 | def forward(self, image): 40 | #encoder 41 | x1 = self.down_conv1(image) 42 | x2 = self.down_conv2(x1) 43 | x3 = self.down_conv3(x2) 44 | x4 = self.down_conv4(x3) 45 | x5 = self.down_conv5(x4) 46 | 47 | #decoder 48 | y1 = self.up_conv1(x5) 49 | y1 = y1 + x4 50 | y2 = self.up_conv2(y1) 51 | y2 = nn.Dropout(0.1)(y2 + x3) 52 | y3 = self.up_conv3(y2) 53 | y3 = nn.Dropout(0.1)(y3 + x2) 54 | y4 = self.up_conv4(y3) 55 | y4 = nn.Dropout(0.1)(y4 + x1) 56 | y5 = self.up_conv5(y4) 57 | out = torch.tanh(y5) 58 | return out -------------------------------------------------------------------------------- /dataloaders/mood.py: -------------------------------------------------------------------------------- 1 | import h5py as h5 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchio.data.subject import Subject 7 | 8 | from .torchiowrap import H5DSImage 9 | 10 | class MoodTrainSet(Dataset): 11 | def __init__(self, indices=None, region='brain', data_path='MOOD_train.h5', torchiosub=True, lazypatch=True, preload=False): 12 | self.h5 = h5.File(data_path, 'r', swmr=True) 13 | self.samples = [] 14 | if indices: 15 | self.samples = [self.h5[region][str(i).zfill(5)]for i in indices] 16 | # self.samples2 = [self.h5[region][str(i).zfill(5)][:] for i in indices] 17 | else: 18 | self.samples = [self.h5[region][i] for i in list(self.h5[region])] 19 | if preload: 20 | print('Preloading 
MoodTrainSet') 21 | for i in range(len(self.samples)): 22 | self.samples[i] = self.samples[i][:] 23 | self.torchiosub = torchiosub 24 | self.lazypatch = lazypatch 25 | 26 | def __len__(self): 27 | return len(self.samples) 28 | 29 | def __getitem__(self, item): 30 | if self.torchiosub: 31 | return Subject({'img':H5DSImage(self.samples[item], lazypatch=self.lazypatch)}) 32 | else: 33 | return torch.from_numpy(self.samples[item][()]).unsqueeze(0) 34 | 35 | class MoodValSet(Dataset): 36 | def __init__(self, load_abnormal=True, load_normal=True, loadASTrain=False, data_path='MOOD_val.h5', torchiosub=True, lazypatch=True, preload=False): 37 | self.h5 = h5.File(data_path, 'r', swmr=True) 38 | self.samples = [] 39 | if load_abnormal: 40 | self.samples+=[(self.h5['abnormal'][i], self.h5['abnormal_mask'][i]) for i in list(self.h5['abnormal'])] 41 | if load_normal: 42 | self.samples+=[self.h5['normal'][i] for i in list(self.h5['normal'])] 43 | if preload: 44 | print('Preloading MoodValSet') 45 | for i in range(len(self.samples)): 46 | if len(self.samples[i]) == 2: 47 | self.samples[i] = (self.samples[i][0][:], self.samples[i][1][:]) 48 | else: 49 | self.samples[i] = self.samples[i][:] 50 | self.loadASTrain = loadASTrain 51 | self.torchiosub = torchiosub 52 | self.lazypatch = lazypatch 53 | 54 | def __len__(self): 55 | return len(self.samples) 56 | 57 | def __getitem__(self, item): 58 | if self.loadASTrain: 59 | if self.torchiosub: 60 | return Subject({'img':H5DSImage(self.samples[item][0], lazypatch=self.lazypatch)}) 61 | else: 62 | return torch.from_numpy(self.samples[item][0][()]).unsqueeze(0) 63 | else: 64 | if self.torchiosub: 65 | if len(self.samples[item]) == 2: 66 | return Subject({'img':H5DSImage(self.samples[item][0], lazypatch=self.lazypatch), 67 | 'gt':H5DSImage(self.samples[item][1], lazypatch=self.lazypatch)}) 68 | else: 69 | return Subject({'img':H5DSImage(self.samples[item], lazypatch=self.lazypatch), 70 | 'gt':H5DSImage(self.samples[item], lazypatch=self.lazypatch)}) #this is dirty. 
TODO 71 | 72 | else: 73 | return (torch.from_numpy(self.samples[item][0][()]).unsqueeze(0), torch.from_numpy(self.samples[item][1][()]).unsqueeze(0)) -------------------------------------------------------------------------------- /ccevae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributions as dist 4 | 5 | from ae_bases import BasicEncoder, BasicGenerator 6 | 7 | class VAE(torch.nn.Module): 8 | def __init__( 9 | self, 10 | input_size, 11 | z_dim=256, 12 | fmap_sizes=(16, 64, 256, 1024), 13 | to_1x1=True, 14 | conv_op=torch.nn.Conv2d, 15 | conv_params=None, 16 | tconv_op=torch.nn.ConvTranspose2d, 17 | tconv_params=None, 18 | normalization_op=None, 19 | normalization_params=None, 20 | activation_op=torch.nn.LeakyReLU, 21 | activation_params=None, 22 | block_op=None, 23 | block_params=None, 24 | *args, 25 | **kwargs 26 | ): 27 | super(VAE, self).__init__() 28 | 29 | input_size_enc = list(input_size) 30 | input_size_dec = list(input_size) 31 | 32 | self.enc = BasicEncoder( 33 | input_size=input_size_enc, 34 | fmap_sizes=fmap_sizes, 35 | z_dim=z_dim * 2, 36 | conv_op=conv_op, 37 | conv_params=conv_params, 38 | normalization_op=normalization_op, 39 | normalization_params=normalization_params, 40 | activation_op=activation_op, 41 | activation_params=activation_params, 42 | block_op=block_op, 43 | block_params=block_params, 44 | to_1x1=to_1x1, 45 | ) 46 | self.dec = BasicGenerator( 47 | input_size=input_size_dec, 48 | fmap_sizes=fmap_sizes[::-1], 49 | z_dim=z_dim, 50 | upsample_op=tconv_op, 51 | conv_params=tconv_params, 52 | normalization_op=normalization_op, 53 | normalization_params=normalization_params, 54 | activation_op=activation_op, 55 | activation_params=activation_params, 56 | block_op=block_op, 57 | block_params=block_params, 58 | to_1x1=to_1x1, 59 | ) 60 | 61 | self.hidden_size = self.enc.output_size 62 | 63 | def forward(self, inpt, sample=True, no_dist=False, **kwargs): 64 | y1 = self.enc(inpt, **kwargs) 65 | 66 | mu, log_std = torch.chunk(y1, 2, dim=1) 67 | std = torch.exp(log_std) 68 | z_dist = dist.Normal(mu, std) 69 | if sample: 70 | z_sample = z_dist.rsample() 71 | else: 72 | z_sample = mu 73 | 74 | x_rec = self.dec(z_sample) 75 | 76 | if no_dist: 77 | return x_rec 78 | else: 79 | return x_rec, z_dist 80 | 81 | def encode(self, inpt, **kwargs): 82 | enc = self.enc(inpt, **kwargs) 83 | mu, log_std = torch.chunk(enc, 2, dim=1) 84 | std = torch.exp(log_std) 85 | return mu, std 86 | 87 | def decode(self, inpt, **kwargs): 88 | x_rec = self.dec(inpt, **kwargs) 89 | return x_rec 90 | 91 | class AE(torch.nn.Module): 92 | def __init__( 93 | self, 94 | input_size, 95 | z_dim=1024, 96 | fmap_sizes=(16, 64, 256, 1024), 97 | to_1x1=True, 98 | conv_op=torch.nn.Conv2d, 99 | conv_params=None, 100 | tconv_op=torch.nn.ConvTranspose2d, 101 | tconv_params=None, 102 | normalization_op=None, 103 | normalization_params=None, 104 | activation_op=torch.nn.LeakyReLU, 105 | activation_params=None, 106 | block_op=None, 107 | block_params=None, 108 | *args, 109 | **kwargs 110 | ): 111 | super(AE, self).__init__() 112 | 113 | input_size_enc = list(input_size) 114 | input_size_dec = list(input_size) 115 | 116 | self.enc = BasicEncoder( 117 | input_size=input_size_enc, 118 | fmap_sizes=fmap_sizes, 119 | z_dim=z_dim, 120 | conv_op=conv_op, 121 | conv_params=conv_params, 122 | normalization_op=normalization_op, 123 | normalization_params=normalization_params, 124 | activation_op=activation_op, 125 | 
activation_params=activation_params, 126 | block_op=block_op, 127 | block_params=block_params, 128 | to_1x1=to_1x1, 129 | ) 130 | self.dec = BasicGenerator( 131 | input_size=input_size_dec, 132 | fmap_sizes=fmap_sizes[::-1], 133 | z_dim=z_dim, 134 | upsample_op=tconv_op, 135 | conv_params=tconv_params, 136 | normalization_op=normalization_op, 137 | normalization_params=normalization_params, 138 | activation_op=activation_op, 139 | activation_params=activation_params, 140 | block_op=block_op, 141 | block_params=block_params, 142 | to_1x1=to_1x1, 143 | ) 144 | 145 | self.hidden_size = self.enc.output_size 146 | 147 | def forward(self, inpt, **kwargs): 148 | 149 | y1 = self.enc(inpt, **kwargs) 150 | 151 | x_rec = self.dec(y1) 152 | 153 | return x_rec 154 | 155 | def encode(self, inpt, **kwargs): 156 | enc = self.enc(inpt, **kwargs) 157 | return enc 158 | 159 | def decode(self, inpt, **kwargs): 160 | rec = self.dec(inpt, **kwargs) 161 | return rec -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import wandb 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.utils.data import Dataset, DataLoader 9 | import torchio 10 | import torchio as tio 11 | from torch.optim import Adam 12 | from torch import nn, optim 13 | from torch.cuda.amp import autocast, GradScaler 14 | from torchvision import transforms 15 | from helpers import kl_loss_fn, rec_loss_fn, geco_beta_update, get_ema, get_square_mask 16 | 17 | def train(model, train_loader, validation_loader, num_epochs, optimizer, scaler, device): 18 | for epoch in range(num_epochs): 19 | model.train() 20 | print('Epoch ' + str(epoch) + ': Train') 21 | for i, data in enumerate(train_loader): 22 | img = data 23 | 24 | tmp = img.view(img.shape[0], 1, -1) 25 | min_vals = tmp.min(2, keepdim=True).values 26 | max_vals = tmp.max(2, keepdim=True).values 27 | tmp = (tmp - min_vals) / max_vals 28 | x = tmp.view(img.size()) 29 | 30 | shape = x.shape 31 | tensor_reshaped = x.reshape(shape[0],-1) 32 | tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)] 33 | tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:]).to(device) 34 | 35 | images = Variable(tensor).to(device) 36 | optimizer.zero_grad() 37 | 38 | ### VAE Part 39 | with autocast(): 40 | loss_vae = 0 41 | if ce_factor < 1: 42 | x_r, z_dist = model(images) 43 | 44 | kl_loss = 0 45 | if model.d == 3: 46 | kl_loss = kl_loss_fn(z_dist, sumdim=(1,)) * beta 47 | rec_loss_vae = rec_loss_fn(x_r, images, sumdim=(1,2,3,4)) 48 | else: 49 | kl_loss = kl_loss_fn(z_dist, sumdim=(1,2,3)) * beta 50 | rec_loss_vae = rec_loss_fn(x_r, images, sumdim=(1,2,3)) 51 | loss_vae = kl_loss + rec_loss_vae * theta 52 | 53 | ### CE Part 54 | loss_ce = 0 55 | if ce_factor > 0: 56 | ce_tensor = get_square_mask( 57 | tensor.shape, 58 | square_size=(0, np.max(input_size[1:]) // 2), 59 | noise_val=(torch.min(tensor).item(), torch.max(tensor).item()), 60 | n_squares=(0, 3), 61 | ) 62 | 63 | ce_tensor = torch.from_numpy(ce_tensor).float().to(device) 64 | inpt_noisy = torch.where(ce_tensor != 0, ce_tensor, tensor) 65 | 66 | inpt_noisy = inpt_noisy.to(device) 67 | 68 | with autocast(): 69 | x_rec_ce, _ = model(inpt_noisy) 70 | if model.d == 3: 71 | rec_loss_ce = rec_loss_fn(x_rec_ce, images, sumdim=(1,2,3,4)) 72 | else: 73 | rec_loss_ce = rec_loss_fn(x_rec_ce, images, sumdim=(1,2,3)) 74 
| loss_ce = rec_loss_ce 75 | loss = (1.0 - ce_factor) * loss_vae + ce_factor * loss_ce 76 | 77 | if use_geco and ce_factor < 1: 78 | g_goal = 0.1 79 | g_lr = 1e-4 80 | vae_loss_ema = (1.0 - 0.9) * rec_loss_vae + 0.9 * vae_loss_ema 81 | theta = geco_beta_update(theta, vae_loss_ema, g_goal, g_lr, speedup=2) 82 | 83 | 84 | scaler.scale(loss).backward() 85 | scaler.step(optimizer) 86 | scaler.update() 87 | 88 | loss = round(loss.item(),4) 89 | print("Epoch: ", epoch, ", i: ", i, ", Training loss: ", loss) 90 | 91 | checkpoint = { 92 | 'state_dict': model.state_dict(), 93 | 'optimizer': optimizer.state_dict(), 94 | 'AMPScaler': scaler.state_dict() 95 | } 96 | if epoch%4==0: 97 | torch.save(checkpoint, os.path.join(save_path, trainID + '-epoch-' + str(epoch) + ".pth.tar")) 98 | mood_val_loss = 0 99 | model.eval() 100 | with torch.no_grad(): 101 | print('Epoch '+ str(epoch)+ ': Val') 102 | for i, data in enumerate(validation_loader): 103 | img = data 104 | tmp = img.view(img.shape[0], 1, -1) 105 | min_vals = tmp.min(2, keepdim=True).values 106 | max_vals = tmp.max(2, keepdim=True).values 107 | tmp = (tmp - min_vals) / max_vals 108 | x = tmp.view(img.size()) 109 | 110 | shape = x.shape 111 | tensor_reshaped = x.reshape(shape[0],-1) 112 | tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)] 113 | tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:]) 114 | 115 | images = Variable(tensor).to(device) 116 | x_r, z_dist = model(images) 117 | kl_loss = 0 118 | kl_loss = kl_loss_fn(z_dist, sumdim=(1,2,3)) * beta 119 | rec_loss_vae = rec_loss_fn(x_r, images, sumdim=(1,2,3)) 120 | loss_vae = kl_loss + rec_loss_vae * theta 121 | mood_val_loss += loss_vae.item() 122 | print("Validation loss: " + str(mood_val_loss)) 123 | 124 | return model
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **S**egmen**t**ation **Reg**ularised **A**nomaly (StRegA) 2 | 3 | Official code of the paper "StRegA: Unsupervised Anomaly Detection in Brain MRIs using a Compact Context-encoding Variational Autoencoder" (https://doi.org/10.1016/j.compbiomed.2022.106093 and https://arxiv.org/abs/2201.13271). 4 | 5 | This was first presented at ISMRM-ESMRMB 2022, London. 6 | Abstract available on RG: https://www.researchgate.net/publication/358357400_StRegA_Unsupervised_Anomaly_Detection_in_Brain_MRIs_using_Compact_Context-encoding_Variational_Autoencoder 7 | 8 | The name "StRegA" is inspired by the Italian herbal liqueur with saffron - Strega (following the tradition of naming MR-related products after alcoholic drinks or liqueurs).
9 | 10 | ## Demo & Documentation 11 | 12 | 📖 **See [DEMO.md](DEMO.md) for comprehensive training and inference instructions**, including: 13 | - Complete training workflow with all parameters 14 | - Testing with locally trained models 15 | - Testing with HuggingFace pre-trained model 16 | - Paper experiment configurations 17 | - Evaluation metrics 18 | 19 | ### Quick Start - Inference 20 | 21 | ```bash 22 | # Using HuggingFace pre-trained model 23 | python inference.py --input /path/to/brain_mri.nii.gz --output /path/to/results/ 24 | 25 | # Using local checkpoint 26 | python inference.py --input /path/to/brain_mri.nii.gz --output /path/to/results/ \ 27 | --checkpoint /path/to/model.pth.tar 28 | ``` 29 | 30 | ### Quick Start - Evaluation 31 | 32 | ```bash 33 | python evaluate.py --predictions /path/to/predictions/ --ground_truth /path/to/gt/ \ 34 | --output results.csv --verbose 35 | ``` 36 | 37 | ## Information regarding this repo 38 | 39 | ### Code structure 40 | 41 | - `DEMO.md` - Comprehensive documentation for training and inference 42 | - `inference.py` - Standalone inference script for anomaly detection 43 | - `evaluate.py` - Evaluation script for computing metrics (Dice, F1, etc.) 44 | - `engine.py` and `train.py` are used to train new models with a custom data loader expected to iterate over slices of FSL segmented data on the 2D model 45 | - `ccevae.py` contains the model code and uses parts from `ae_bases.py`, `ce_noise.py` and `helpers.py` 46 | - `Pipeline.ipynb` shows the entire StRegA pipeline including post-processing. 47 | - The `dataloaders` folder has some examples of the dataloader that was used during training and validation 48 | 49 | ### Checkpoint 50 | 51 | A "master" checkpoint was created (not part of the manuscript) by training on MOOD (T1) + IXI T1 + IXI T2 + IXI PD MRIs segmented with FSL. This can be found on Huggingface: [https://huggingface.co/soumickmj/StRegA_cceVAE2D_Brain_MOOD_IXIT1_IXIT2_IXIPD](https://huggingface.co/soumickmj/StRegA_cceVAE2D_Brain_MOOD_IXIT1_IXIT2_IXIPD). This checkpoint can directly be used (example provided in Pipeline.ipynb notebook) or can be saved as a checkpoint file to make it like a locally-trained model. 52 | Here is an example: 53 | ```python 54 | from transformers import AutoModel 55 | modelHF = AutoModel.from_pretrained("soumickmj/StRegA_cceVAE2D_Brain_MOOD_IXIT1_IXIT2_IXIPD", trust_remote_code=True) 56 | torch.save(modelHF.model, "/path/to/checkpoint/brain.ptrh") 57 | ``` 58 | 59 | ## Contacts 60 | 61 | Please feel free to contact me for any questions or feedback: 62 | 63 | [soumick.chatterjee@ovgu.de](mailto:soumick.chatterjee@ovgu.de) 64 | 65 | [contact@soumick.com](mailto:contact@soumick.com) 66 | 67 | ## Credits 68 | 69 | If you like this repository, please click on Star! 
70 | 71 | If you use this approach in your research or use codes from this repository, please cite either or both of the following in your publications: 72 | 73 | > [Soumick Chatterjee, Alessandro Sciarra, Max Dünnwald, Pavan Tummala, Shubham Kumar Agrawal, Aishwarya Jauhari, Aman Kalra, Steffen Oeltze-Jafra, Oliver Speck, Andreas Nürnberger: StRegA: Unsupervised Anomaly Detection in Brain MRIs using a Compact Context-encoding Variational Autoencoder (Computers in Biology and Medicine, Oct 2022)](https://doi.org/10.1016/j.compbiomed.2022.106093) 74 | 75 | BibTeX entry: 76 | 77 | ```bibtex 78 | @article{chatterjee2022strega, 79 | title={StRegA: Unsupervised Anomaly Detection in Brain MRIs using a Compact Context-encoding Variational Autoencoder}, 80 | author={Chatterjee, Soumick and Sciarra, Alessandro and D{\"u}nnwald, Max and Tummala, Pavan and Agrawal, Shubham Kumar and Jauhari, Aishwarya and Kalra, Aman and Oeltze-Jafra, Steffen and Speck, Oliver and N{\"u}rnberger, Andreas}, 81 | journal={Computers in Biology and Medicine}, 82 | pages={106093}, 83 | year={2022}, 84 | publisher={Elsevier}, 85 | doi={10.1016/j.compbiomed.2022.106093} 86 | } 87 | 88 | ``` 89 | 90 | The complete manuscript is also on ArXiv:- 91 | > [Soumick Chatterjee, Alessandro Sciarra, Max Dünnwald, Pavan Tummala, Shubham Kumar Agrawal, Aishwarya Jauhari, Aman Kalra, Steffen Oeltze-Jafra, Oliver Speck, Andreas Nürnberger: StRegA: Unsupervised Anomaly Detection in Brain MRIs using a Compact Context-encoding Variational Autoencoder (arXiv:2201.13271, Jan 2022)](https://arxiv.org/abs/2201.13271) 92 | 93 | 94 | The ISMRM-ESMRMB 2022 abstract:- 95 | 96 | > [Soumick Chatterjee, Alessandro Sciarra, Max Dünnwald, Pavan Tummala, Shubham Kumar Agrawal, Aishwarya Jauhari, Aman Kalra, Steffen Oeltze-Jafra, Oliver Speck, Andreas Nürnberger: StRegA: Unsupervised Anomaly Detection in Brain MRIs using Compact Context-encoding Variational Autoencoder (ISMRM-ESMRMB 2022, May 2022)](https://www.researchgate.net/publication/358357668_Multi-scale_UNet_with_Self-Constructing_Graph_Latent_for_Deformable_Image_Registration) 97 | 98 | BibTeX entry: 99 | 100 | 101 | ```bibtex 102 | @inproceedings{mickISMRM22strega, 103 | author = {Chatterjee, Soumick and Sciarra, Alessandro and D{\"u}nnwald, Max and Tummala, Pavan and Agrawal, Shubham Kumar and Jauhari, Aishwarya and Kalra, Aman and Oeltze-Jafra, Steffen and Speck, Oliver and N{\"u}rnberger, Andreas}, 104 | year = {2022}, 105 | month = {05}, 106 | pages = {0172}, 107 | title = {StRegA: Unsupervised Anomaly Detection in Brain MRIs using Compact Context-encoding Variational Autoencoder}, 108 | booktitle={ISMRM-ESMRMB 2022} 109 | } 110 | ``` 111 | Thank you so much for your support.
112 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import random 3 | import re 4 | import os 5 | from typing import List 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | import torch.distributions as dist 10 | #from trixi.util.pytorchutils import get_smooth_image_gradient 11 | from ce_noise import smooth_tensor 12 | 13 | def kl_loss_fn(z_post, sum_samples=True, correct=False, sumdim=(1,2,3)): 14 | z_prior = dist.Normal(0, 1.0) 15 | kl_div = dist.kl_divergence(z_post, z_prior) 16 | if correct: 17 | kl_div = torch.sum(kl_div, dim=sumdim) 18 | else: 19 | kl_div = torch.mean(kl_div, dim=sumdim) 20 | if sum_samples: 21 | return torch.mean(kl_div) 22 | else: 23 | return kl_div 24 | 25 | def rec_loss_fn(recon_x, x, sum_samples=True, correct=False, sumdim=(1,2,3)): 26 | if correct: 27 | x_dist = dist.Laplace(recon_x, 1.0) 28 | log_p_x_z = x_dist.log_prob(x) 29 | log_p_x_z = torch.sum(log_p_x_z, dim=sumdim) 30 | else: 31 | log_p_x_z = -torch.abs(recon_x - x) 32 | log_p_x_z = torch.mean(log_p_x_z, dim=sumdim) 33 | if sum_samples: 34 | return -torch.mean(log_p_x_z) 35 | else: 36 | return -log_p_x_z 37 | 38 | def geco_beta_update(beta, error_ema, goal, step_size, min_clamp=1e-10, max_clamp=1e4, speedup=None): 39 | constraint = (error_ema - goal).detach() 40 | if speedup is not None and constraint > 0.0: 41 | beta = beta * torch.exp(speedup * step_size * constraint) 42 | else: 43 | beta = beta * torch.exp(step_size * constraint) 44 | if min_clamp is not None: 45 | beta = np.max((beta.item(), min_clamp)) 46 | if max_clamp is not None: 47 | beta = np.min((beta.item(), max_clamp)) 48 | return beta 49 | 50 | def get_ema(new, old, alpha): 51 | if old is None: 52 | return new 53 | return (1.0 - alpha) * new + alpha * old 54 | 55 | 56 | def get_range_val(value, rnd_type="uniform"): 57 | if isinstance(value, (list, tuple, np.ndarray)): 58 | if len(value) == 2: 59 | if value[0] == value[1]: 60 | n_val = value[0] 61 | else: 62 | orig_type = type(value[0]) 63 | if rnd_type == "uniform": 64 | n_val = random.uniform(value[0], value[1]) 65 | elif rnd_type == "normal": 66 | n_val = random.normalvariate(value[0], value[1]) 67 | n_val = orig_type(n_val) 68 | elif len(value) == 1: 69 | n_val = value[0] 70 | else: 71 | raise RuntimeError("value must be either a single vlaue or a list/tuple of len 2") 72 | return n_val 73 | else: 74 | return value 75 | 76 | def get_square_mask(data_shape, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False, square_pos=None): 77 | """Returns a 'mask' with the same size as the data, where random squares are != 0 78 | 79 | Args: 80 | data_shape ([tensor]): [data_shape to determine the shape of the returned tensor] 81 | square_size ([tuple]): [int/ int tuple (min_size, max_size), determining the min and max squear size] 82 | n_squares ([type]): [int/ int tuple (min_number, max_number), determining the min and max number of squares] 83 | noise_val (tuple, optional): [int/ int tuple (min_val, max_val), determining the min and max value given in the 84 | squares, which habe the value != 0 ]. Defaults to (0, 0). 85 | channel_wise_n_val (bool, optional): [Use a different value for each channel]. Defaults to False. 86 | square_pos ([type], optional): [Square position]. Defaults to None. 
87 | """ 88 | 89 | def mask_random_square(img_shape, square_size, n_val, channel_wise_n_val=False, square_pos=None): 90 | """Masks (sets = 0) a random square in an image""" 91 | 92 | img_h = img_shape[-2] 93 | img_w = img_shape[-1] 94 | 95 | img = np.zeros(img_shape) 96 | 97 | if square_pos is None: 98 | w_start = np.random.randint(0, img_w - square_size) 99 | h_start = np.random.randint(0, img_h - square_size) 100 | else: 101 | pos_wh = square_pos[np.random.randint(0, len(square_pos))] 102 | w_start = pos_wh[0] 103 | h_start = pos_wh[1] 104 | 105 | if img.ndim == 2: 106 | rnd_n_val = get_range_val(n_val) 107 | img[h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 108 | elif img.ndim == 3: 109 | if channel_wise_n_val: 110 | for i in range(img.shape[0]): 111 | rnd_n_val = get_range_val(n_val) 112 | img[i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 113 | else: 114 | rnd_n_val = get_range_val(n_val) 115 | img[:, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 116 | elif img.ndim == 4: 117 | if channel_wise_n_val: 118 | for i in range(img.shape[0]): 119 | rnd_n_val = get_range_val(n_val) 120 | img[:, i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 121 | else: 122 | rnd_n_val = get_range_val(n_val) 123 | img[:, :, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 124 | 125 | return img 126 | 127 | def mask_random_squares(img_shape, square_size, n_squares, n_val, channel_wise_n_val=False, square_pos=None): 128 | """Masks a given number of squares in an image""" 129 | img = np.zeros(img_shape) 130 | for i in range(n_squares): 131 | img = mask_random_square( 132 | img_shape, square_size, n_val, channel_wise_n_val=channel_wise_n_val, square_pos=square_pos 133 | ) 134 | return img 135 | 136 | ret_data = np.zeros(data_shape) 137 | for sample_idx in range(data_shape[0]): 138 | # rnd_n_val = get_range_val(noise_val) 139 | rnd_square_size = get_range_val(square_size) 140 | rnd_n_squares = get_range_val(n_squares) 141 | 142 | ret_data[sample_idx] = mask_random_squares( 143 | data_shape[1:], 144 | square_size=rnd_square_size, 145 | n_squares=rnd_n_squares, 146 | n_val=noise_val, 147 | channel_wise_n_val=channel_wise_n_val, 148 | square_pos=square_pos, 149 | ) 150 | 151 | return ret_data 152 | 153 | 154 | -------------------------------------------------------------------------------- /ceVae/helpers.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import random 3 | import re 4 | import os 5 | from typing import List 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | import torch.distributions as dist 10 | #from trixi.util.pytorchutils import get_smooth_image_gradient 11 | from .ce_noise import smooth_tensor 12 | 13 | def kl_loss_fn(z_post, sum_samples=True, correct=False, sumdim=(1,2,3)): 14 | z_prior = dist.Normal(0, 1.0) 15 | kl_div = dist.kl_divergence(z_post, z_prior) 16 | if correct: 17 | kl_div = torch.sum(kl_div, dim=sumdim) 18 | else: 19 | kl_div = torch.mean(kl_div, dim=sumdim) 20 | if sum_samples: 21 | return torch.mean(kl_div) 22 | else: 23 | return kl_div 24 | 25 | def rec_loss_fn(recon_x, x, sum_samples=True, correct=False, sumdim=(1,2,3)): 26 | if correct: 27 | x_dist = dist.Laplace(recon_x, 1.0) 28 | log_p_x_z = x_dist.log_prob(x) 29 | log_p_x_z = torch.sum(log_p_x_z, dim=sumdim) 30 | else: 31 | log_p_x_z = -torch.abs(recon_x - x) 32 | 
log_p_x_z = torch.mean(log_p_x_z, dim=sumdim) 33 | if sum_samples: 34 | return -torch.mean(log_p_x_z) 35 | else: 36 | return -log_p_x_z 37 | 38 | def geco_beta_update(beta, error_ema, goal, step_size, min_clamp=1e-10, max_clamp=1e4, speedup=None): 39 | constraint = (error_ema - goal).detach() 40 | if speedup is not None and constraint > 0.0: 41 | beta = beta * torch.exp(speedup * step_size * constraint) 42 | else: 43 | beta = beta * torch.exp(step_size * constraint) 44 | if min_clamp is not None: 45 | beta = np.max((beta.item(), min_clamp)) 46 | if max_clamp is not None: 47 | beta = np.min((beta.item(), max_clamp)) 48 | return beta 49 | 50 | def get_ema(new, old, alpha): 51 | if old is None: 52 | return new 53 | return (1.0 - alpha) * new + alpha * old 54 | 55 | 56 | def get_range_val(value, rnd_type="uniform"): 57 | if isinstance(value, (list, tuple, np.ndarray)): 58 | if len(value) == 2: 59 | if value[0] == value[1]: 60 | n_val = value[0] 61 | else: 62 | orig_type = type(value[0]) 63 | if rnd_type == "uniform": 64 | n_val = random.uniform(value[0], value[1]) 65 | elif rnd_type == "normal": 66 | n_val = random.normalvariate(value[0], value[1]) 67 | n_val = orig_type(n_val) 68 | elif len(value) == 1: 69 | n_val = value[0] 70 | else: 71 | raise RuntimeError("value must be either a single vlaue or a list/tuple of len 2") 72 | return n_val 73 | else: 74 | return value 75 | 76 | def get_square_mask(data_shape, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False, square_pos=None): 77 | """Returns a 'mask' with the same size as the data, where random squares are != 0 78 | 79 | Args: 80 | data_shape ([tensor]): [data_shape to determine the shape of the returned tensor] 81 | square_size ([tuple]): [int/ int tuple (min_size, max_size), determining the min and max squear size] 82 | n_squares ([type]): [int/ int tuple (min_number, max_number), determining the min and max number of squares] 83 | noise_val (tuple, optional): [int/ int tuple (min_val, max_val), determining the min and max value given in the 84 | squares, which habe the value != 0 ]. Defaults to (0, 0). 85 | channel_wise_n_val (bool, optional): [Use a different value for each channel]. Defaults to False. 86 | square_pos ([type], optional): [Square position]. Defaults to None. 
87 | """ 88 | 89 | def mask_random_square(img_shape, square_size, n_val, channel_wise_n_val=False, square_pos=None): 90 | """Masks (sets = 0) a random square in an image""" 91 | 92 | img_h = img_shape[-2] 93 | img_w = img_shape[-1] 94 | 95 | img = np.zeros(img_shape) 96 | 97 | if square_pos is None: 98 | w_start = np.random.randint(0, img_w - square_size) 99 | h_start = np.random.randint(0, img_h - square_size) 100 | else: 101 | pos_wh = square_pos[np.random.randint(0, len(square_pos))] 102 | w_start = pos_wh[0] 103 | h_start = pos_wh[1] 104 | 105 | if img.ndim == 2: 106 | rnd_n_val = get_range_val(n_val) 107 | img[h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 108 | elif img.ndim == 3: 109 | if channel_wise_n_val: 110 | for i in range(img.shape[0]): 111 | rnd_n_val = get_range_val(n_val) 112 | img[i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 113 | else: 114 | rnd_n_val = get_range_val(n_val) 115 | img[:, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 116 | elif img.ndim == 4: 117 | if channel_wise_n_val: 118 | for i in range(img.shape[0]): 119 | rnd_n_val = get_range_val(n_val) 120 | img[:, i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 121 | else: 122 | rnd_n_val = get_range_val(n_val) 123 | img[:, :, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 124 | 125 | return img 126 | 127 | def mask_random_squares(img_shape, square_size, n_squares, n_val, channel_wise_n_val=False, square_pos=None): 128 | """Masks a given number of squares in an image""" 129 | img = np.zeros(img_shape) 130 | for i in range(n_squares): 131 | img = mask_random_square( 132 | img_shape, square_size, n_val, channel_wise_n_val=channel_wise_n_val, square_pos=square_pos 133 | ) 134 | return img 135 | 136 | ret_data = np.zeros(data_shape) 137 | for sample_idx in range(data_shape[0]): 138 | # rnd_n_val = get_range_val(noise_val) 139 | rnd_square_size = get_range_val(square_size) 140 | rnd_n_squares = get_range_val(n_squares) 141 | 142 | ret_data[sample_idx] = mask_random_squares( 143 | data_shape[1:], 144 | square_size=rnd_square_size, 145 | n_squares=rnd_n_squares, 146 | n_val=noise_val, 147 | channel_wise_n_val=channel_wise_n_val, 148 | square_pos=square_pos, 149 | ) 150 | 151 | return ret_data 152 | 153 | 154 | -------------------------------------------------------------------------------- /ceVae/ce_noise.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def get_range_val(value, rnd_type="uniform"): 8 | if isinstance(value, (list, tuple, np.ndarray)): 9 | if len(value) == 2: 10 | if value[0] == value[1]: 11 | n_val = value[0] 12 | else: 13 | orig_type = type(value[0]) 14 | if rnd_type == "uniform": 15 | n_val = random.uniform(value[0], value[1]) 16 | elif rnd_type == "normal": 17 | n_val = random.normalvariate(value[0], value[1]) 18 | n_val = orig_type(n_val) 19 | elif len(value) == 1: 20 | n_val = value[0] 21 | else: 22 | raise RuntimeError("value must be either a single vlaue or a list/tuple of len 2") 23 | return n_val 24 | else: 25 | return value 26 | 27 | 28 | def get_square_mask(data_shape, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False, square_pos=None): 29 | """Returns a 'mask' with the same size as the data, where random squares are != 0 30 | 31 | Args: 32 | data_shape ([tensor]): [data_shape to determine 
the shape of the returned tensor] 33 | square_size ([tuple]): [int/ int tuple (min_size, max_size), determining the min and max squear size] 34 | n_squares ([type]): [int/ int tuple (min_number, max_number), determining the min and max number of squares] 35 | noise_val (tuple, optional): [int/ int tuple (min_val, max_val), determining the min and max value given in the 36 | squares, which habe the value != 0 ]. Defaults to (0, 0). 37 | channel_wise_n_val (bool, optional): [Use a different value for each channel]. Defaults to False. 38 | square_pos ([type], optional): [Square position]. Defaults to None. 39 | """ 40 | 41 | def mask_random_square(img_shape, square_size, n_val, channel_wise_n_val=False, square_pos=None): 42 | """Masks (sets = 0) a random square in an image""" 43 | 44 | img_h = img_shape[-2] 45 | img_w = img_shape[-1] 46 | 47 | img = np.zeros(img_shape) 48 | 49 | if square_pos is None: 50 | w_start = np.random.randint(0, img_w - square_size) 51 | h_start = np.random.randint(0, img_h - square_size) 52 | else: 53 | pos_wh = square_pos[np.random.randint(0, len(square_pos))] 54 | w_start = pos_wh[0] 55 | h_start = pos_wh[1] 56 | 57 | if img.ndim == 2: 58 | rnd_n_val = get_range_val(n_val) 59 | img[h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 60 | elif img.ndim == 3: 61 | if channel_wise_n_val: 62 | for i in range(img.shape[0]): 63 | rnd_n_val = get_range_val(n_val) 64 | img[i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 65 | else: 66 | rnd_n_val = get_range_val(n_val) 67 | img[:, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 68 | elif img.ndim == 4: 69 | if channel_wise_n_val: 70 | for i in range(img.shape[0]): 71 | rnd_n_val = get_range_val(n_val) 72 | img[:, i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 73 | else: 74 | rnd_n_val = get_range_val(n_val) 75 | img[:, :, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 76 | 77 | return img 78 | 79 | def mask_random_squares(img_shape, square_size, n_squares, n_val, channel_wise_n_val=False, square_pos=None): 80 | """Masks a given number of squares in an image""" 81 | img = np.zeros(img_shape) 82 | for i in range(n_squares): 83 | img = mask_random_square( 84 | img_shape, square_size, n_val, channel_wise_n_val=channel_wise_n_val, square_pos=square_pos 85 | ) 86 | return img 87 | 88 | ret_data = np.zeros(data_shape) 89 | for sample_idx in range(data_shape[0]): 90 | # rnd_n_val = get_range_val(noise_val) 91 | rnd_square_size = get_range_val(square_size) 92 | rnd_n_squares = get_range_val(n_squares) 93 | 94 | ret_data[sample_idx] = mask_random_squares( 95 | data_shape[1:], 96 | square_size=rnd_square_size, 97 | n_squares=rnd_n_squares, 98 | n_val=noise_val, 99 | channel_wise_n_val=channel_wise_n_val, 100 | square_pos=square_pos, 101 | ) 102 | 103 | return ret_data 104 | 105 | 106 | def smooth_tensor2D(tensor, kernel_size=8, sigma=3, channels=1): 107 | 108 | # Set these to whatever you want for your gaussian filter 109 | 110 | if kernel_size % 2 == 0: 111 | kernel_size -= 1 112 | 113 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 114 | x_cord = torch.arange(kernel_size) 115 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) 116 | y_grid = x_grid.t() 117 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 118 | 119 | mean = (kernel_size - 1) / 2.0 120 | variance = sigma ** 2.0 121 | 122 | # Calculate the 
2-dimensional gaussian kernel which is 123 | # the product of two gaussian distributions for two different 124 | # variables (in this case called x and y) 125 | import math 126 | 127 | gaussian_kernel = (1.0 / (2.0 * math.pi * variance)) * torch.exp( 128 | -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2.0 * variance) 129 | ) 130 | # Make sure sum of values in gaussian kernel equals 1. 131 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 132 | 133 | # Reshape to 2d depthwise convolutional weight 134 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 135 | gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1) 136 | 137 | gaussian_filter = torch.nn.Conv2d( 138 | in_channels=channels, 139 | out_channels=channels, 140 | kernel_size=kernel_size, 141 | groups=channels, 142 | bias=False, 143 | padding=kernel_size // 2, 144 | ) 145 | 146 | gaussian_filter.weight.data = gaussian_kernel 147 | gaussian_filter.weight.requires_grad = False 148 | 149 | gaussian_filter.to(tensor.device) 150 | 151 | return gaussian_filter(tensor) 152 | 153 | def smooth_tensor3D(tensor, kernel_size=8, sigma=3, channels=1): 154 | 155 | # Set these to whatever you want for your gaussian filter 156 | 157 | if kernel_size % 2 == 0: 158 | kernel_size -= 1 159 | 160 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, kernel_size, 3) 161 | cord = torch.arange(kernel_size) 162 | x_grid, y_grid, z_grid = torch.meshgrid(cord,cord,cord) 163 | xyz_grid = torch.stack([x_grid, y_grid, z_grid], dim=-1).float() 164 | 165 | mean = (kernel_size - 1) / 2.0 166 | variance = sigma ** 2.0 167 | 168 | # Calculate the 3-dimensional gaussian kernel which is 169 | # the product of two gaussian distributions for two different 170 | # variables (in this case called x and y) 171 | import math 172 | 173 | gaussian_kernel = (1.0 / (2.0 * math.pi * variance)) * torch.exp( 174 | -torch.sum((xyz_grid - mean) ** 2.0, dim=-1) / (2.0 * variance) 175 | ) 176 | # Make sure sum of values in gaussian kernel equals 1. 
177 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 178 | 179 | # Reshape to 3d depthwise convolutional weight 180 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size, kernel_size) 181 | gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1, 1) 182 | 183 | gaussian_filter = torch.nn.Conv3d( 184 | in_channels=channels, 185 | out_channels=channels, 186 | kernel_size=kernel_size, 187 | groups=channels, 188 | bias=False, 189 | padding=kernel_size // 2, 190 | ) 191 | 192 | gaussian_filter.weight.data = gaussian_kernel 193 | gaussian_filter.weight.requires_grad = False 194 | 195 | gaussian_filter.to(tensor.device) 196 | 197 | return gaussian_filter(tensor) 198 | 199 | def smooth_tensor(tensor, kernel_size=8, sigma=3, channels=1): 200 | if len(tensor.shape) == 5: 201 | return smooth_tensor3D(tensor, kernel_size, sigma, channels) 202 | else: 203 | return smooth_tensor2D(tensor, kernel_size, sigma, channels) 204 | 205 | def normalize(tensor): 206 | 207 | tens_deta = tensor.detach().cpu() 208 | tens_deta -= float(np.min(tens_deta.numpy())) 209 | tens_deta /= float(np.max(tens_deta.numpy())) 210 | 211 | return tens_deta 212 | -------------------------------------------------------------------------------- /ce_noise.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def get_range_val(value, rnd_type="uniform"): 8 | if isinstance(value, (list, tuple, np.ndarray)): 9 | if len(value) == 2: 10 | if value[0] == value[1]: 11 | n_val = value[0] 12 | else: 13 | orig_type = type(value[0]) 14 | if rnd_type == "uniform": 15 | n_val = random.uniform(value[0], value[1]) 16 | elif rnd_type == "normal": 17 | n_val = random.normalvariate(value[0], value[1]) 18 | n_val = orig_type(n_val) 19 | elif len(value) == 1: 20 | n_val = value[0] 21 | else: 22 | raise RuntimeError("value must be either a single vlaue or a list/tuple of len 2") 23 | return n_val 24 | else: 25 | return value 26 | 27 | 28 | def get_square_mask(data_shape, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False, square_pos=None): 29 | """Returns a 'mask' with the same size as the data, where random squares are != 0 30 | 31 | Args: 32 | data_shape ([tensor]): [data_shape to determine the shape of the returned tensor] 33 | square_size ([tuple]): [int/ int tuple (min_size, max_size), determining the min and max squear size] 34 | n_squares ([type]): [int/ int tuple (min_number, max_number), determining the min and max number of squares] 35 | noise_val (tuple, optional): [int/ int tuple (min_val, max_val), determining the min and max value given in the 36 | squares, which habe the value != 0 ]. Defaults to (0, 0). 37 | channel_wise_n_val (bool, optional): [Use a different value for each channel]. Defaults to False. 38 | square_pos ([type], optional): [Square position]. Defaults to None. 
39 | """ 40 | 41 | def mask_random_square(img_shape, square_size, n_val, channel_wise_n_val=False, square_pos=None): 42 | """Masks (sets = 0) a random square in an image""" 43 | 44 | img_h = img_shape[-2] 45 | img_w = img_shape[-1] 46 | 47 | img = np.zeros(img_shape) 48 | 49 | if square_pos is None: 50 | w_start = np.random.randint(0, img_w - square_size) 51 | h_start = np.random.randint(0, img_h - square_size) 52 | else: 53 | pos_wh = square_pos[np.random.randint(0, len(square_pos))] 54 | w_start = pos_wh[0] 55 | h_start = pos_wh[1] 56 | 57 | if img.ndim == 2: 58 | rnd_n_val = get_range_val(n_val) 59 | img[h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 60 | elif img.ndim == 3: 61 | if channel_wise_n_val: 62 | for i in range(img.shape[0]): 63 | rnd_n_val = get_range_val(n_val) 64 | img[i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 65 | else: 66 | rnd_n_val = get_range_val(n_val) 67 | img[:, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 68 | elif img.ndim == 4: 69 | if channel_wise_n_val: 70 | for i in range(img.shape[0]): 71 | rnd_n_val = get_range_val(n_val) 72 | img[:, i, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 73 | else: 74 | rnd_n_val = get_range_val(n_val) 75 | img[:, :, h_start : (h_start + square_size), w_start : (w_start + square_size)] = rnd_n_val 76 | 77 | return img 78 | 79 | def mask_random_squares(img_shape, square_size, n_squares, n_val, channel_wise_n_val=False, square_pos=None): 80 | """Masks a given number of squares in an image""" 81 | img = np.zeros(img_shape) 82 | for i in range(n_squares): 83 | img = mask_random_square( 84 | img_shape, square_size, n_val, channel_wise_n_val=channel_wise_n_val, square_pos=square_pos 85 | ) 86 | return img 87 | 88 | ret_data = np.zeros(data_shape) 89 | for sample_idx in range(data_shape[0]): 90 | # rnd_n_val = get_range_val(noise_val) 91 | rnd_square_size = get_range_val(square_size) 92 | rnd_n_squares = get_range_val(n_squares) 93 | 94 | ret_data[sample_idx] = mask_random_squares( 95 | data_shape[1:], 96 | square_size=rnd_square_size, 97 | n_squares=rnd_n_squares, 98 | n_val=noise_val, 99 | channel_wise_n_val=channel_wise_n_val, 100 | square_pos=square_pos, 101 | ) 102 | 103 | return ret_data 104 | 105 | 106 | def smooth_tensor2D(tensor, kernel_size=8, sigma=3, channels=1): 107 | 108 | # Set these to whatever you want for your gaussian filter 109 | 110 | if kernel_size % 2 == 0: 111 | kernel_size -= 1 112 | 113 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 114 | x_cord = torch.arange(kernel_size) 115 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) 116 | y_grid = x_grid.t() 117 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 118 | 119 | mean = (kernel_size - 1) / 2.0 120 | variance = sigma ** 2.0 121 | 122 | # Calculate the 2-dimensional gaussian kernel which is 123 | # the product of two gaussian distributions for two different 124 | # variables (in this case called x and y) 125 | import math 126 | 127 | gaussian_kernel = (1.0 / (2.0 * math.pi * variance)) * torch.exp( 128 | -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2.0 * variance) 129 | ) 130 | # Make sure sum of values in gaussian kernel equals 1. 
131 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 132 | 133 | # Reshape to 2d depthwise convolutional weight 134 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 135 | gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1) 136 | 137 | gaussian_filter = torch.nn.Conv2d( 138 | in_channels=channels, 139 | out_channels=channels, 140 | kernel_size=kernel_size, 141 | groups=channels, 142 | bias=False, 143 | padding=kernel_size // 2, 144 | ) 145 | 146 | gaussian_filter.weight.data = gaussian_kernel 147 | gaussian_filter.weight.requires_grad = False 148 | 149 | gaussian_filter.to(tensor.device) 150 | 151 | return gaussian_filter(tensor) 152 | 153 | def smooth_tensor3D(tensor, kernel_size=8, sigma=3, channels=1): 154 | 155 | # Set these to whatever you want for your gaussian filter 156 | 157 | if kernel_size % 2 == 0: 158 | kernel_size -= 1 159 | 160 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, kernel_size, 3) 161 | cord = torch.arange(kernel_size) 162 | x_grid, y_grid, z_grid = torch.meshgrid(cord,cord,cord) 163 | xyz_grid = torch.stack([x_grid, y_grid, z_grid], dim=-1).float() 164 | 165 | mean = (kernel_size - 1) / 2.0 166 | variance = sigma ** 2.0 167 | 168 | # Calculate the 3-dimensional gaussian kernel which is 169 | # the product of two gaussian distributions for two different 170 | # variables (in this case called x and y) 171 | import math 172 | 173 | gaussian_kernel = (1.0 / (2.0 * math.pi * variance)) * torch.exp( 174 | -torch.sum((xyz_grid - mean) ** 2.0, dim=-1) / (2.0 * variance) 175 | ) 176 | # Make sure sum of values in gaussian kernel equals 1. 177 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 178 | 179 | # Reshape to 3d depthwise convolutional weight 180 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size, kernel_size) 181 | gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1, 1) 182 | 183 | gaussian_filter = torch.nn.Conv3d( 184 | in_channels=channels, 185 | out_channels=channels, 186 | kernel_size=kernel_size, 187 | groups=channels, 188 | bias=False, 189 | padding=kernel_size // 2, 190 | ) 191 | 192 | gaussian_filter.weight.data = gaussian_kernel 193 | gaussian_filter.weight.requires_grad = False 194 | 195 | gaussian_filter.to(tensor.device) 196 | 197 | return gaussian_filter(tensor) 198 | 199 | def smooth_tensor(tensor, kernel_size=8, sigma=3, channels=1): 200 | if len(tensor.shape) == 5: 201 | return smooth_tensor3D(tensor, kernel_size, sigma, channels) 202 | else: 203 | return smooth_tensor2D(tensor, kernel_size, sigma, channels) 204 | 205 | def normalize(tensor): 206 | 207 | tens_deta = tensor.detach().cpu() 208 | tens_deta -= float(np.min(tens_deta.numpy())) 209 | tens_deta /= float(np.max(tens_deta.numpy())) 210 | 211 | return tens_deta 212 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | StRegA Evaluation Script 4 | 5 | This script evaluates StRegA anomaly detection results against ground truth masks 6 | and computes metrics including Dice coefficient, precision, recall, and F1 score. 
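The metrics follow their standard definitions (as implemented below):

    Dice        = 2 * |GT ∩ Pred| / (|GT| + |Pred|)
    Precision   = TP / (TP + FP)
    Recall      = TP / (TP + FN)
    F1          = 2 * Precision * Recall / (Precision + Recall)
    Specificity = TN / (TN + FP)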
7 | 8 | Usage: 9 | python evaluate.py --predictions /path/to/predictions/ --ground_truth /path/to/gt/ \ 10 | --output results.csv 11 | """ 12 | 13 | import argparse 14 | import os 15 | import numpy as np 16 | import nibabel as nib 17 | import pandas as pd 18 | from glob import glob 19 | 20 | 21 | def dice_coefficient(true_mask, pred_mask, non_seg_score=1.0): 22 | """ 23 | Compute Dice coefficient between two binary masks. 24 | 25 | Args: 26 | true_mask: Ground truth binary mask 27 | pred_mask: Predicted binary mask 28 | non_seg_score: Score to return when both masks are empty 29 | 30 | Returns: 31 | Dice coefficient (0-1) 32 | """ 33 | assert true_mask.shape == pred_mask.shape, \ 34 | f"Shape mismatch: {true_mask.shape} vs {pred_mask.shape}" 35 | 36 | true_mask = np.asarray(true_mask).astype(bool) 37 | pred_mask = np.asarray(pred_mask).astype(bool) 38 | 39 | im_sum = true_mask.sum() + pred_mask.sum() 40 | if im_sum == 0: 41 | return non_seg_score 42 | 43 | intersection = np.logical_and(true_mask, pred_mask) 44 | return 2. * intersection.sum() / im_sum 45 | 46 | 47 | def precision_score(true_mask, pred_mask): 48 | """ 49 | Compute precision: TP / (TP + FP) 50 | 51 | Args: 52 | true_mask: Ground truth binary mask 53 | pred_mask: Predicted binary mask 54 | 55 | Returns: 56 | Precision score (0-1) 57 | 58 | Edge cases: 59 | - If no predictions are made and no positives exist: returns 1.0 (perfect precision) 60 | - If no predictions are made but positives exist: returns 0.0 (no true positives found) 61 | """ 62 | true_mask = np.asarray(true_mask).astype(bool) 63 | pred_mask = np.asarray(pred_mask).astype(bool) 64 | 65 | true_positives = np.sum(np.logical_and(true_mask, pred_mask)) 66 | predicted_positives = np.sum(pred_mask) 67 | 68 | if predicted_positives == 0: 69 | # No predictions made - return 1.0 if nothing to detect, 0.0 otherwise 70 | return 1.0 if np.sum(true_mask) == 0 else 0.0 71 | 72 | return true_positives / predicted_positives 73 | 74 | 75 | def recall_score(true_mask, pred_mask): 76 | """ 77 | Compute recall (sensitivity): TP / (TP + FN) 78 | 79 | Args: 80 | true_mask: Ground truth binary mask 81 | pred_mask: Predicted binary mask 82 | 83 | Returns: 84 | Recall score (0-1) 85 | 86 | Edge cases: 87 | - If no actual positives exist and no predictions made: returns 1.0 (perfect recall) 88 | - If no actual positives exist but predictions made: returns 0.0 (false positives) 89 | """ 90 | true_mask = np.asarray(true_mask).astype(bool) 91 | pred_mask = np.asarray(pred_mask).astype(bool) 92 | 93 | true_positives = np.sum(np.logical_and(true_mask, pred_mask)) 94 | actual_positives = np.sum(true_mask) 95 | 96 | if actual_positives == 0: 97 | # No actual positives - return 1.0 if no predictions, 0.0 otherwise (false positives) 98 | return 1.0 if np.sum(pred_mask) == 0 else 0.0 99 | 100 | return true_positives / actual_positives 101 | 102 | 103 | def f1_score(true_mask, pred_mask): 104 | """ 105 | Compute F1 score: 2 * (precision * recall) / (precision + recall) 106 | 107 | Args: 108 | true_mask: Ground truth binary mask 109 | pred_mask: Predicted binary mask 110 | 111 | Returns: 112 | F1 score (0-1) 113 | """ 114 | prec = precision_score(true_mask, pred_mask) 115 | rec = recall_score(true_mask, pred_mask) 116 | 117 | if prec + rec == 0: 118 | return 0.0 119 | 120 | return 2 * (prec * rec) / (prec + rec) 121 | 122 | 123 | def specificity_score(true_mask, pred_mask): 124 | """ 125 | Compute specificity: TN / (TN + FP) 126 | 127 | Args: 128 | true_mask: Ground truth binary mask 129 | 
pred_mask: Predicted binary mask 130 | 131 | Returns: 132 | Specificity score (0-1) 133 | """ 134 | true_mask = np.asarray(true_mask).astype(bool) 135 | pred_mask = np.asarray(pred_mask).astype(bool) 136 | 137 | true_negatives = np.sum(np.logical_and(~true_mask, ~pred_mask)) 138 | actual_negatives = np.sum(~true_mask) 139 | 140 | if actual_negatives == 0: 141 | return 1.0 142 | 143 | return true_negatives / actual_negatives 144 | 145 | 146 | def compute_all_metrics(true_mask, pred_mask): 147 | """ 148 | Compute all evaluation metrics. 149 | 150 | Args: 151 | true_mask: Ground truth binary mask 152 | pred_mask: Predicted binary mask 153 | 154 | Returns: 155 | Dictionary of metrics 156 | """ 157 | return { 158 | 'dice': dice_coefficient(true_mask, pred_mask), 159 | 'precision': precision_score(true_mask, pred_mask), 160 | 'recall': recall_score(true_mask, pred_mask), 161 | 'f1': f1_score(true_mask, pred_mask), 162 | 'specificity': specificity_score(true_mask, pred_mask), 163 | 'gt_volume': np.sum(true_mask.astype(bool)), 164 | 'pred_volume': np.sum(pred_mask.astype(bool)) 165 | } 166 | 167 | 168 | def load_nifti_mask(path, binarize=True, threshold=0): 169 | """ 170 | Load a NIfTI mask file. 171 | 172 | Args: 173 | path: Path to NIfTI file 174 | binarize: Whether to binarize the mask 175 | threshold: Threshold for binarization 176 | 177 | Returns: 178 | NumPy array of the mask 179 | """ 180 | nii = nib.load(path) 181 | mask = nii.get_fdata() 182 | 183 | if binarize: 184 | mask = (mask > threshold).astype(np.float32) 185 | 186 | return mask 187 | 188 | 189 | def find_matching_files(pred_dir, gt_dir, pred_suffix='_anomaly_mask.nii.gz', 190 | gt_suffix='.nii.gz'): 191 | """ 192 | Find matching prediction and ground truth file pairs. 193 | 194 | Args: 195 | pred_dir: Directory containing predictions 196 | gt_dir: Directory containing ground truth 197 | pred_suffix: Suffix for prediction files 198 | gt_suffix: Suffix for ground truth files 199 | 200 | Returns: 201 | List of (pred_path, gt_path, subject_id) tuples 202 | """ 203 | pred_files = glob(os.path.join(pred_dir, f'*{pred_suffix}')) 204 | matches = [] 205 | 206 | for pred_path in pred_files: 207 | # Extract subject ID by removing suffix and directory 208 | base_name = os.path.basename(pred_path) 209 | subject_id = base_name.replace(pred_suffix, '') 210 | 211 | # Look for matching ground truth 212 | gt_path = os.path.join(gt_dir, f'{subject_id}{gt_suffix}') 213 | 214 | if os.path.exists(gt_path): 215 | matches.append((pred_path, gt_path, subject_id)) 216 | else: 217 | print(f"Warning: No ground truth found for {subject_id}") 218 | 219 | return matches 220 | 221 | 222 | def main(): 223 | parser = argparse.ArgumentParser( 224 | description='StRegA Evaluation Script' 225 | ) 226 | parser.add_argument( 227 | '--predictions', '-p', 228 | type=str, 229 | required=True, 230 | help='Directory containing prediction masks' 231 | ) 232 | parser.add_argument( 233 | '--ground_truth', '-g', 234 | type=str, 235 | required=True, 236 | help='Directory containing ground truth masks' 237 | ) 238 | parser.add_argument( 239 | '--output', '-o', 240 | type=str, 241 | default='evaluation_results.csv', 242 | help='Output CSV file path (default: evaluation_results.csv)' 243 | ) 244 | parser.add_argument( 245 | '--pred_suffix', 246 | type=str, 247 | default='_anomaly_mask.nii.gz', 248 | help='Suffix for prediction files (default: _anomaly_mask.nii.gz)' 249 | ) 250 | parser.add_argument( 251 | '--gt_suffix', 252 | type=str, 253 | default='.nii.gz', 254 | help='Suffix 
for ground truth files (default: .nii.gz)' 255 | ) 256 | parser.add_argument( 257 | '--binarize_gt', 258 | action='store_true', 259 | help='Binarize ground truth masks (for multi-class labels)' 260 | ) 261 | parser.add_argument( 262 | '--verbose', '-v', 263 | action='store_true', 264 | help='Print per-subject results' 265 | ) 266 | 267 | args = parser.parse_args() 268 | 269 | # Find matching files 270 | matches = find_matching_files( 271 | args.predictions, 272 | args.ground_truth, 273 | args.pred_suffix, 274 | args.gt_suffix 275 | ) 276 | 277 | if len(matches) == 0: 278 | print("Error: No matching file pairs found!") 279 | return 280 | 281 | print(f"Found {len(matches)} matching file pairs") 282 | 283 | # Evaluate each pair 284 | results = [] 285 | 286 | for pred_path, gt_path, subject_id in matches: 287 | if args.verbose: 288 | print(f"\nEvaluating: {subject_id}") 289 | 290 | # Load masks 291 | pred_mask = load_nifti_mask(pred_path, binarize=True) 292 | gt_mask = load_nifti_mask(gt_path, binarize=args.binarize_gt) 293 | 294 | # Handle shape mismatches 295 | if pred_mask.shape != gt_mask.shape: 296 | print(f" Warning: Shape mismatch for {subject_id}") 297 | print(f" Prediction: {pred_mask.shape}") 298 | print(f" Ground truth: {gt_mask.shape}") 299 | continue 300 | 301 | # Compute metrics 302 | metrics = compute_all_metrics(gt_mask, pred_mask) 303 | metrics['subject_id'] = subject_id 304 | results.append(metrics) 305 | 306 | if args.verbose: 307 | print(f" Dice: {metrics['dice']:.4f}") 308 | print(f" Precision: {metrics['precision']:.4f}") 309 | print(f" Recall: {metrics['recall']:.4f}") 310 | print(f" F1: {metrics['f1']:.4f}") 311 | 312 | # Create results DataFrame 313 | df = pd.DataFrame(results) 314 | 315 | # Reorder columns 316 | cols = ['subject_id', 'dice', 'precision', 'recall', 'f1', 317 | 'specificity', 'gt_volume', 'pred_volume'] 318 | df = df[cols] 319 | 320 | # Save results 321 | df.to_csv(args.output, index=False) 322 | print(f"\nResults saved to: {args.output}") 323 | 324 | # Print summary statistics 325 | print("\n" + "="*50) 326 | print("Summary Statistics") 327 | print("="*50) 328 | 329 | for metric in ['dice', 'precision', 'recall', 'f1', 'specificity']: 330 | mean_val = df[metric].mean() 331 | std_val = df[metric].std() 332 | min_val = df[metric].min() 333 | max_val = df[metric].max() 334 | print(f"{metric.capitalize():12s}: {mean_val:.4f} ± {std_val:.4f} " 335 | f"(min: {min_val:.4f}, max: {max_val:.4f})") 336 | 337 | print("="*50) 338 | print(f"Number of subjects evaluated: {len(df)}") 339 | 340 | 341 | if __name__ == '__main__': 342 | main() 343 | -------------------------------------------------------------------------------- /ceVae/aes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributions as dist 4 | 5 | from .ae_bases import BasicEncoder, BasicGenerator 6 | 7 | 8 | class VAE(torch.nn.Module): 9 | def __init__( 10 | self, 11 | input_size, 12 | z_dim=256, 13 | fmap_sizes=(16, 64, 256, 1024), 14 | to_1x1=True, 15 | conv_op=torch.nn.Conv2d, 16 | conv_params=None, 17 | tconv_op=torch.nn.ConvTranspose2d, 18 | tconv_params=None, 19 | normalization_op=None, 20 | normalization_params=None, 21 | activation_op=torch.nn.LeakyReLU, 22 | activation_params=None, 23 | block_op=None, 24 | block_params=None, 25 | *args, 26 | **kwargs 27 | ): 28 | """Basic VAE build up of a symetric BasicEncoder (Encoder) and BasicGenerator (Decoder) 29 | 30 | Args: 31 | input_size ((int, int, int): 
Size of the input in format CxHxW): 32 | z_dim (int, optional): [Dimension of the latent / input dimension (C channel-dim)]. Defaults to 256 33 | fmap_sizes (tuple, optional): [Defines the upsampling levels of the generator, list/ tuple of ints, where each 34 | int defines the number of feature maps in the layer]. Defaults to (16, 64, 256, 1024). 35 | to_1x1 (bool, optional): [If True, the last conv layer reduces the latent to a z_dim x 1 x 1 vector (similar to fully connected); 36 | if False, the latent keeps a spatial resolution (z_dim x H x W, using the conv kernel size given in conv_params)]. 37 | Defaults to True. 38 | conv_op ([torch.nn.Module], optional): [Convolution operation used in the encoder to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d. 39 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). 40 | tconv_op ([torch.nn.Module], optional): [Upsampling/ transposed conv operation used in the decoder to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d. 41 | tconv_params ([dict], optional): [Init parameters for the transposed conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). 42 | normalization_op ([torch.nn.Module], optional): [Normalization operation (e.g. BatchNorm, InstanceNorm, ...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 43 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 44 | activation_op ([torch.nn.Module], optional): [Activation operation/ non-linearity (e.g. ReLU, Sigmoid, ...) -> see ConvModule]. Defaults to nn.LeakyReLU. 45 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 46 | block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op, e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. 47 | block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None.
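        Example (a minimal sketch; the sizes are illustrative, assume single-channel 2D inputs
        with the default 2D conv/tconv ops, and rely on the default conv-parameter handling):

            >>> vae = VAE(input_size=(1, 256, 256), z_dim=1024,
            ...           fmap_sizes=(16, 64, 256, 1024))
            >>> x = torch.randn(2, 1, 256, 256)
            >>> x_rec, z_dist = vae(x)        # reconstruction and the latent Normal distribution
            >>> mu, std = vae.encode(x)       # parameters of the approximate posterior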
48 | """ 49 | 50 | super(VAE, self).__init__() 51 | 52 | input_size_enc = list(input_size) 53 | input_size_dec = list(input_size) 54 | 55 | self.enc = BasicEncoder( 56 | input_size=input_size_enc, 57 | fmap_sizes=fmap_sizes, 58 | z_dim=z_dim * 2, 59 | conv_op=conv_op, 60 | conv_params=conv_params, 61 | normalization_op=normalization_op, 62 | normalization_params=normalization_params, 63 | activation_op=activation_op, 64 | activation_params=activation_params, 65 | block_op=block_op, 66 | block_params=block_params, 67 | to_1x1=to_1x1, 68 | ) 69 | self.dec = BasicGenerator( 70 | input_size=input_size_dec, 71 | fmap_sizes=fmap_sizes[::-1], 72 | z_dim=z_dim, 73 | upsample_op=tconv_op, 74 | conv_params=tconv_params, 75 | normalization_op=normalization_op, 76 | normalization_params=normalization_params, 77 | activation_op=activation_op, 78 | activation_params=activation_params, 79 | block_op=block_op, 80 | block_params=block_params, 81 | to_1x1=to_1x1, 82 | ) 83 | 84 | self.hidden_size = self.enc.output_size 85 | 86 | def forward(self, inpt, sample=True, no_dist=False, **kwargs): 87 | y1 = self.enc(inpt, **kwargs) 88 | 89 | mu, log_std = torch.chunk(y1, 2, dim=1) 90 | std = torch.exp(log_std) 91 | z_dist = dist.Normal(mu, std) 92 | if sample: 93 | z_sample = z_dist.rsample() 94 | else: 95 | z_sample = mu 96 | 97 | x_rec = self.dec(z_sample) 98 | 99 | if no_dist: 100 | return x_rec 101 | else: 102 | return x_rec, z_dist 103 | 104 | def encode(self, inpt, **kwargs): 105 | """Encodes a sample and returns the paramters for the approx inference dist. (Normal) 106 | 107 | Args: 108 | inpt ([tensor]): The input to encode 109 | 110 | Returns: 111 | mu : The mean used to parameterized a Normal distribution 112 | std: The standard deviation used to parameterized a Normal distribution 113 | """ 114 | enc = self.enc(inpt, **kwargs) 115 | mu, log_std = torch.chunk(enc, 2, dim=1) 116 | std = torch.exp(log_std) 117 | return mu, std 118 | 119 | def decode(self, inpt, **kwargs): 120 | """Decodes a latent space sample, used the generative model (decode = mu_{gen}(z) as used in p(x|z) = N(x | mu_{gen}(z), 1) ). 121 | 122 | Args: 123 | inpt ([type]): A sample from the latent space to decode 124 | 125 | Returns: 126 | [type]: [description] 127 | """ 128 | x_rec = self.dec(inpt, **kwargs) 129 | return x_rec 130 | 131 | 132 | class AE(torch.nn.Module): 133 | def __init__( 134 | self, 135 | input_size, 136 | z_dim=1024, 137 | fmap_sizes=(16, 64, 256, 1024), 138 | to_1x1=True, 139 | conv_op=torch.nn.Conv2d, 140 | conv_params=None, 141 | tconv_op=torch.nn.ConvTranspose2d, 142 | tconv_params=None, 143 | normalization_op=None, 144 | normalization_params=None, 145 | activation_op=torch.nn.LeakyReLU, 146 | activation_params=None, 147 | block_op=None, 148 | block_params=None, 149 | *args, 150 | **kwargs 151 | ): 152 | """Basic AE build up of a symetric BasicEncoder (Encoder) and BasicGenerator (Decoder) 153 | 154 | Args: 155 | input_size ((int, int, int): Size of the input in format CxHxW): 156 | z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). Defaults to 256 157 | fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each 158 | int defines the number of feature maps in the layer]. Defaults to (16, 64, 256, 1024). 
159 | to_1x1 (bool, optional): [If True, then the last conv layer goes to a latent dimesion is a z_dim x 1 x 1 vector (similar to fully connected) 160 | or if False allows spatial resolution not to be 1x1 (z_dim x H x W, uses the in the conv_params given conv-kernel-size) ]. 161 | Defaults to True. 162 | conv_op ([torch.nn.Module], optional): [Convolutioon operation used in the encoder to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d. 163 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). 164 | tconv_op ([torch.nn.Module], optional): [Upsampling/ Transposed Conv operation used in the decoder to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d. 165 | tconv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). 166 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 167 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 168 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 169 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 170 | block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. 171 | block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. 172 | """ 173 | super(AE, self).__init__() 174 | 175 | input_size_enc = list(input_size) 176 | input_size_dec = list(input_size) 177 | 178 | self.enc = BasicEncoder( 179 | input_size=input_size_enc, 180 | fmap_sizes=fmap_sizes, 181 | z_dim=z_dim, 182 | conv_op=conv_op, 183 | conv_params=conv_params, 184 | normalization_op=normalization_op, 185 | normalization_params=normalization_params, 186 | activation_op=activation_op, 187 | activation_params=activation_params, 188 | block_op=block_op, 189 | block_params=block_params, 190 | to_1x1=to_1x1, 191 | ) 192 | self.dec = BasicGenerator( 193 | input_size=input_size_dec, 194 | fmap_sizes=fmap_sizes[::-1], 195 | z_dim=z_dim, 196 | upsample_op=tconv_op, 197 | conv_params=tconv_params, 198 | normalization_op=normalization_op, 199 | normalization_params=normalization_params, 200 | activation_op=activation_op, 201 | activation_params=activation_params, 202 | block_op=block_op, 203 | block_params=block_params, 204 | to_1x1=to_1x1, 205 | ) 206 | 207 | self.hidden_size = self.enc.output_size 208 | 209 | def forward(self, inpt, **kwargs): 210 | 211 | y1 = self.enc(inpt, **kwargs) 212 | 213 | x_rec = self.dec(y1) 214 | 215 | return x_rec 216 | 217 | def encode(self, inpt, **kwargs): 218 | """Encodes a input sample to a latent space sample 219 | 220 | Args: 221 | inpt ([tensor]): Input sample 222 | 223 | Returns: 224 | enc: Encoded input sample in the latent space 225 | """ 226 | enc = self.enc(inpt, **kwargs) 227 | return enc 228 | 229 | def decode(self, inpt, **kwargs): 230 | """Decodes a latent space sample back to the input space 231 | 232 | Args: 233 | inpt ([tensor]): [Latent space sample] 234 | 235 | Returns: 236 | [rec]: [Encoded latent sample back in the input space] 237 | """ 238 | rec = self.dec(inpt, **kwargs) 239 | return 
rec 240 | -------------------------------------------------------------------------------- /misc/cevae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributions as dist 5 | 6 | class NoOp(nn.Module): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super(NoOp, self).__init__() 10 | 11 | def forward(self, x, *args, **kwargs): 12 | return x 13 | 14 | 15 | class ConvModule(nn.Module): 16 | def __init__(self, in_channels, out_channels, conv_op=nn.Conv2d, conv_params=None, 17 | normalization_op=nn.BatchNorm2d, normalization_params=None, 18 | activation_op=nn.LeakyReLU, activation_params=None): 19 | 20 | super(ConvModule, self).__init__() 21 | 22 | self.conv_params = conv_params 23 | if self.conv_params is None: 24 | self.conv_params = {} 25 | self.activation_params = activation_params 26 | if self.activation_params is None: 27 | self.activation_params = {} 28 | self.normalization_params = normalization_params 29 | if self.normalization_params is None: 30 | self.normalization_params = {} 31 | 32 | self.conv = None 33 | if conv_op is not None and not isinstance(conv_op, str): 34 | self.conv = conv_op(in_channels, out_channels, **self.conv_params) 35 | 36 | self.normalization = None 37 | if normalization_op is not None and not isinstance(normalization_op, str): 38 | self.normalization = normalization_op(out_channels, **self.normalization_params) 39 | 40 | self.activation = None 41 | if activation_op is not None and not isinstance(activation_op, str): 42 | self.activation = activation_op(**self.activation_params) 43 | 44 | def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None): 45 | 46 | x = input 47 | 48 | if self.conv is not None: 49 | if conv_add_input is None: 50 | x = self.conv(x) 51 | else: 52 | x = self.conv(x, **conv_add_input) 53 | 54 | if self.normalization is not None: 55 | if normalization_add_input is None: 56 | x = self.normalization(x) 57 | else: 58 | x = self.normalization(x, **normalization_add_input) 59 | 60 | if self.activation is not None: 61 | if activation_add_input is None: 62 | x = self.activation(x) 63 | else: 64 | x = self.activation(x, **activation_add_input) 65 | 66 | return x 67 | 68 | 69 | # Basic Generator 70 | class Generator(nn.Module): 71 | def __init__(self, image_size, z_dim=256, h_size=(256, 128, 64), 72 | upsample_op=nn.ConvTranspose2d, normalization_op=nn.InstanceNorm2d, activation_op=nn.LeakyReLU, 73 | conv_params=None, activation_params=None, block_op=None, block_params=None, to_1x1=True): 74 | 75 | super(Generator, self).__init__() 76 | 77 | if conv_params is None: 78 | conv_params = {} 79 | 80 | n_channels = image_size[0] 81 | img_size = np.array([image_size[1], image_size[2]]) 82 | 83 | if not isinstance(h_size, list) and not isinstance(h_size, tuple): 84 | raise AttributeError("h_size has to be either a list or tuple or an int") 85 | elif len(h_size) < 2: 86 | raise AttributeError("h_size has to contain at least three elements") 87 | else: 88 | h_size_bot = h_size[0] 89 | 90 | # We need to know how many layers we will use at the beginning 91 | img_size_new = img_size // (2 ** len(h_size)) 92 | if np.min(img_size_new) < 2 and z_dim is not None: 93 | raise AttributeError("h_size to long, one image dimension has already perished") 94 | 95 | ### Start block 96 | start_block = [] 97 | 98 | # Z_size random numbers 99 | 100 | if not to_1x1: 101 | kernel_size_start = [min(4, i) for i in 
img_size_new] 102 | else: 103 | kernel_size_start = img_size_new.tolist() 104 | 105 | if z_dim is not None: 106 | self.start = ConvModule(z_dim, h_size_bot, 107 | conv_op=upsample_op, 108 | conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False, 109 | **conv_params), 110 | normalization_op=normalization_op, 111 | normalization_params={}, 112 | activation_op=activation_op, 113 | activation_params=activation_params 114 | ) 115 | 116 | img_size_new = img_size_new * 2 117 | else: 118 | self.start = NoOp() 119 | 120 | ### Middle block (Done until we reach ? x image_size/2 x image_size/2) 121 | self.middle_blocks = nn.ModuleList() 122 | 123 | for h_size_top in h_size[1:]: 124 | 125 | if block_op is not None and not isinstance(block_op, str): 126 | self.middle_blocks.append( 127 | block_op(h_size_bot, **block_params) 128 | ) 129 | 130 | self.middle_blocks.append( 131 | ConvModule(h_size_bot, h_size_top, 132 | conv_op=upsample_op, 133 | conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params), 134 | normalization_op=normalization_op, 135 | normalization_params={}, 136 | activation_op=activation_op, 137 | activation_params=activation_params 138 | ) 139 | ) 140 | 141 | h_size_bot = h_size_top 142 | img_size_new = img_size_new * 2 143 | 144 | ### End block 145 | self.end = ConvModule(h_size_bot, n_channels, 146 | conv_op=upsample_op, 147 | conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params), 148 | normalization_op=None, 149 | activation_op=None) 150 | 151 | def forward(self, inpt, **kwargs): 152 | output = self.start(inpt, **kwargs) 153 | for middle in self.middle_blocks: 154 | output = middle(output, **kwargs) 155 | output = self.end(output, **kwargs) 156 | return output 157 | 158 | 159 | # Basic Encoder 160 | class Encoder(nn.Module): 161 | def __init__(self, image_size, z_dim=256, h_size=(64, 128, 256), 162 | conv_op=nn.Conv2d, normalization_op=nn.InstanceNorm2d, activation_op=nn.LeakyReLU, 163 | conv_params=None, activation_params=None, 164 | block_op=None, block_params=None, 165 | to_1x1=True): 166 | super(Encoder, self).__init__() 167 | 168 | if conv_params is None: 169 | conv_params = {} 170 | 171 | n_channels = image_size[0] 172 | img_size_new = np.array([image_size[1], image_size[2]]) 173 | 174 | if not isinstance(h_size, list) and not isinstance(h_size, tuple): 175 | raise AttributeError("h_size has to be either a list or tuple or an int") 176 | # elif len(h_size) < 2: 177 | # raise AttributeError("h_size has to contain at least three elements") 178 | else: 179 | h_size_bot = h_size[0] 180 | 181 | ### Start block 182 | self.start = ConvModule(n_channels, h_size_bot, 183 | conv_op=conv_op, 184 | conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params), 185 | normalization_op=normalization_op, 186 | normalization_params={}, 187 | activation_op=activation_op, 188 | activation_params=activation_params 189 | ) 190 | img_size_new = img_size_new // 2 191 | 192 | ### Middle block (Done until we reach ? 
x 4 x 4) 193 | self.middle_blocks = nn.ModuleList() 194 | 195 | for h_size_top in h_size[1:]: 196 | 197 | if block_op is not None and not isinstance(block_op, str): 198 | self.middle_blocks.append( 199 | block_op(h_size_bot, **block_params) 200 | ) 201 | 202 | self.middle_blocks.append( 203 | ConvModule(h_size_bot, h_size_top, 204 | conv_op=conv_op, 205 | conv_params=dict(kernel_size=4, stride=2, padding=1, bias=False, **conv_params), 206 | normalization_op=normalization_op, 207 | normalization_params={}, 208 | activation_op=activation_op, 209 | activation_params=activation_params 210 | ) 211 | ) 212 | 213 | h_size_bot = h_size_top 214 | img_size_new = img_size_new // 2 215 | 216 | if np.min(img_size_new) < 2 and z_dim is not None: 217 | raise ("h_size to long, one image dimension has already perished") 218 | 219 | ### End block 220 | if not to_1x1: 221 | kernel_size_end = [min(4, i) for i in img_size_new] 222 | else: 223 | kernel_size_end = img_size_new.tolist() 224 | 225 | if z_dim is not None: 226 | self.end = ConvModule(h_size_bot, z_dim, 227 | conv_op=conv_op, 228 | conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False, 229 | **conv_params), 230 | normalization_op=None, 231 | activation_op=None, 232 | ) 233 | 234 | if to_1x1: 235 | self.output_size = (z_dim, 1, 1) 236 | else: 237 | self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(img_size_new, kernel_size_end)]) 238 | else: 239 | self.end = NoOp() 240 | self.output_size = img_size_new 241 | 242 | def forward(self, inpt, **kwargs): 243 | output = self.start(inpt, **kwargs) 244 | for middle in self.middle_blocks: 245 | output = middle(output, **kwargs) 246 | output = self.end(output, **kwargs) 247 | return output 248 | 249 | class VAE(torch.nn.Module): 250 | def __init__(self, input_size, h_size, z_dim, to_1x1=True, conv_op=torch.nn.Conv2d, 251 | upsample_op=torch.nn.ConvTranspose2d, normalization_op=None, activation_op=torch.nn.LeakyReLU, 252 | conv_params=None, activation_params=None, block_op=None, block_params=None, output_channels=None, 253 | additional_input_slices=None, 254 | *args, **kwargs): 255 | 256 | super(VAE, self).__init__() 257 | 258 | input_size_enc = list(input_size) 259 | input_size_dec = list(input_size) 260 | if output_channels is not None: 261 | input_size_dec[0] = output_channels 262 | if additional_input_slices is not None: 263 | input_size_enc[0] += additional_input_slices * 2 264 | 265 | self.encoder = Encoder(image_size=input_size_enc, h_size=h_size, z_dim=z_dim * 2, 266 | normalization_op=normalization_op, to_1x1=to_1x1, conv_op=conv_op, 267 | conv_params=conv_params, 268 | activation_op=activation_op, activation_params=activation_params, block_op=block_op, 269 | block_params=block_params) 270 | self.decoder = Generator(image_size=input_size_dec, h_size=h_size[::-1], z_dim=z_dim, 271 | normalization_op=normalization_op, to_1x1=to_1x1, upsample_op=upsample_op, 272 | conv_params=conv_params, activation_op=activation_op, 273 | activation_params=activation_params, block_op=block_op, 274 | block_params=block_params) 275 | 276 | self.hidden_size = self.encoder.output_size 277 | 278 | def forward(self, inpt, sample=None, **kwargs): 279 | enc = self.encoder(inpt, **kwargs) 280 | 281 | mu, log_std = torch.chunk(enc, 2, dim=1) 282 | std = torch.exp(log_std) 283 | z_dist = dist.Normal(mu, std) 284 | 285 | if sample or self.training: 286 | z = z_dist.rsample() 287 | else: 288 | z = mu 289 | 290 | x_rec = self.decoder(z, **kwargs) 291 | 292 | return x_rec, mu, std 293 | 294 | def 
encode(self, inpt, **kwargs): 295 | enc = self.encoder(inpt, **kwargs) 296 | mu, log_std = torch.chunk(enc, 2, dim=1) 297 | return mu, log_std 298 | 299 | def decode(self, inpt, **kwargs): 300 | x_rec = self.decoder(inpt, **kwargs) 301 | return x_rec -------------------------------------------------------------------------------- /Pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a1b2c3d4", 6 | "metadata": {}, 7 | "source": [ 8 | "# StRegA: Unsupervised Anomaly Detection Pipeline\n", 9 | "\n", 10 | "This notebook demonstrates the complete StRegA inference pipeline for detecting anomalies in brain MRIs.\n", 11 | "\n", 12 | "The workflow is aligned with the `inference.py` script to ensure consistent results." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "88243c52", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "import torch\n", 24 | "import numpy as np\n", 25 | "import nibabel as nib\n", 26 | "from scipy import ndimage\n", 27 | "from skimage import morphology, filters\n", 28 | "from torchio import transforms\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "from torch.cuda.amp import autocast" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "b8add4cc", 36 | "metadata": {}, 37 | "source": [ 38 | "## 1. Setup Device and Load Model\n", 39 | "\n", 40 | "We load the pre-trained model from HuggingFace. This is the recommended approach for inference." 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "device_setup", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# Setup device\n", 51 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 52 | "print(f\"Using device: {device}\")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "5fc305fb", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# Load model from HuggingFace\n", 63 | "from transformers import AutoModel\n", 64 | "\n", 65 | "print(\"Loading model from HuggingFace...\")\n", 66 | "modelHF = AutoModel.from_pretrained(\n", 67 | " \"soumickmj/StRegA_cceVAE2D_Brain_MOOD_IXIT1_IXIT2_IXIPD\", \n", 68 | " trust_remote_code=True\n", 69 | ")\n", 70 | "model = modelHF.model.to(device)\n", 71 | "model.eval()\n", 72 | "print(\"Model loaded successfully!\")" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "alt_loading", 78 | "metadata": {}, 79 | "source": [ 80 | "### Alternative: Load from Local Checkpoint\n", 81 | "\n", 82 | "If you have a locally saved model checkpoint, you can use this cell instead:" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "id": "62b4dc12", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# Alternative: Load from local checkpoint (uncomment to use)\n", 93 | "# model = torch.load('checkpoint/brain.ptrh', map_location=device)\n", 94 | "# model.eval()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "id": "57e99c1f", 100 | "metadata": {}, 101 | "source": [ 102 | "## 2. Load and Preprocess NIfTI Volume\n", 103 | "\n", 104 | "The preprocessing matches the training pipeline:\n", 105 | "1. Load the NIfTI file\n", 106 | "2. Move slices to the first dimension\n", 107 | "3. Crop or pad to 256x256\n", 108 | "4. 
Per-slice normalization using `(x - min) / max`" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "preprocess_funcs", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def preprocess_volume(nifti_path, target_size=(256, 256)):\n", 119 | " \"\"\"\n", 120 | " Load and preprocess a NIfTI volume for inference.\n", 121 | " \n", 122 | " Args:\n", 123 | " nifti_path: Path to the NIfTI file\n", 124 | " target_size: Target spatial dimensions (H, W)\n", 125 | " \n", 126 | " Returns:\n", 127 | " Preprocessed tensor of shape (N_slices, 1, H, W)\n", 128 | " affine: NIfTI affine matrix\n", 129 | " original_shape: Original volume shape\n", 130 | " \"\"\"\n", 131 | " # Load volume\n", 132 | " nii = nib.load(nifti_path)\n", 133 | " vol = nii.get_fdata()\n", 134 | " affine = nii.affine\n", 135 | " original_shape = vol.shape\n", 136 | " \n", 137 | " # Move slices to first dimension (from H, W, D to D, H, W)\n", 138 | " vol = np.moveaxis(vol, 2, 0)\n", 139 | " \n", 140 | " # Convert to tensor and add channel dimension\n", 141 | " data_item = torch.tensor(vol).unsqueeze(dim=0).float()\n", 142 | " \n", 143 | " # Crop or pad to target size\n", 144 | " target_shape = (vol.shape[0], target_size[0], target_size[1])\n", 145 | " out = transforms.CropOrPad(target_shape)(data_item)\n", 146 | " out = out.squeeze(dim=0).unsqueeze(dim=1)\n", 147 | " \n", 148 | " return out, affine, original_shape\n", 149 | "\n", 150 | "\n", 151 | "def normalize_volume(volume, device):\n", 152 | " \"\"\"\n", 153 | " Normalize volume per-slice to match training pipeline.\n", 154 | " \n", 155 | " Training uses: (x - min) / max for each slice independently.\n", 156 | " This shifts minimum to 0 and scales by max value.\n", 157 | " \n", 158 | " Args:\n", 159 | " volume: Tensor of shape (N_slices, 1, H, W)\n", 160 | " device: torch device\n", 161 | " \n", 162 | " Returns:\n", 163 | " Normalized volume tensor\n", 164 | " \"\"\"\n", 165 | " volume = volume.to(device)\n", 166 | " \n", 167 | " # Flatten spatial dims for per-slice normalization\n", 168 | " volume_flat = volume.view(volume.shape[0], 1, -1) # (N, 1, H*W)\n", 169 | " min_vals = volume_flat.min(dim=2, keepdim=True).values # (N, 1, 1)\n", 170 | " max_vals = volume_flat.max(dim=2, keepdim=True).values # (N, 1, 1)\n", 171 | " \n", 172 | " # Division by zero protection\n", 173 | " valid_max = max_vals != 0\n", 174 | " volume_flat_norm = torch.where(\n", 175 | " valid_max,\n", 176 | " (volume_flat - min_vals) / torch.where(valid_max, max_vals, torch.ones_like(max_vals)),\n", 177 | " volume_flat\n", 178 | " )\n", 179 | " volume_norm = volume_flat_norm.view(volume.shape)\n", 180 | " \n", 181 | " # Handle NaN values\n", 182 | " volume_norm = torch.nan_to_num(volume_norm, nan=0.0)\n", 183 | " \n", 184 | " return volume_norm" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "load_data", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# Load a single NIfTI volume\n", 195 | "# Replace with your own file path\n", 196 | "nifti_path = 'path/to/your/brain_mri.nii.gz'\n", 197 | "\n", 198 | "# For BraTS data structure (as example):\n", 199 | "# nifti_path = './brats/seg/example_volume.nii.gz'\n", 200 | "\n", 201 | "volume, affine, original_shape = preprocess_volume(nifti_path)\n", 202 | "print(f\"Original shape: {original_shape}\")\n", 203 | "print(f\"Preprocessed shape: {volume.shape}\")" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | 
"id": "normalize", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "# Normalize the volume (per-slice, matching training)\n", 214 | "img = normalize_volume(volume, device)\n", 215 | "print(f\"Normalized volume shape: {img.shape}\")\n", 216 | "print(f\"Value range: [{img.min():.4f}, {img.max():.4f}]\")" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "26e4b3ff", 222 | "metadata": {}, 223 | "source": [ 224 | "## 3. Run Inference\n", 225 | "\n", 226 | "Pass the normalized volume through the model to get the reconstruction." 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "inference", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "# Run inference\n", 237 | "with torch.no_grad():\n", 238 | " with autocast():\n", 239 | " reconstruction, _ = model(img)\n", 240 | "\n", 241 | "reconstruction = reconstruction.float()\n", 242 | "print(f\"Reconstruction shape: {reconstruction.shape}\")" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "id": "postprocess_header", 248 | "metadata": {}, 249 | "source": [ 250 | "## 4. Post-Processing Pipeline\n", 251 | "\n", 252 | "The anomaly detection post-processing consists of:\n", 253 | "1. Calculate difference (reconstruction error)\n", 254 | "2. Keep only positive differences\n", 255 | "3. Apply manual thresholding\n", 256 | "4. Apply Otsu thresholding\n", 257 | "5. Morphological opening to remove small false positives\n", 258 | "6. Mask out background regions" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "id": "postprocess", 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# Post-processing parameters (matching inference.py defaults)\n", 269 | "AREA_THRESHOLD = 256 # Minimum connected component area to keep (in pixels)\n", 270 | "ANOMALY_THRESHOLD = 0.2 # Initial threshold for anomaly detection\n", 271 | "\n", 272 | "# Calculate difference (reconstruction error)\n", 273 | "diff_mask = (reconstruction.cpu().numpy() - img.cpu().numpy())\n", 274 | "\n", 275 | "# Keep only positive differences (anomalies are under-reconstructed)\n", 276 | "m_diff_mask = diff_mask.copy()\n", 277 | "m_diff_mask[m_diff_mask < 0] = 0\n", 278 | "\n", 279 | "# Manual thresholding\n", 280 | "m_diff_mask[m_diff_mask > ANOMALY_THRESHOLD] = 1\n", 281 | "\n", 282 | "# Otsu thresholding for adaptive binarization\n", 283 | "val = filters.threshold_otsu(m_diff_mask)\n", 284 | "thr = m_diff_mask > val\n", 285 | "thr[thr < 0] = 0\n", 286 | "\n", 287 | "# Morphological opening to remove small false positives\n", 288 | "final = np.zeros_like(thr)\n", 289 | "for i in range(thr.shape[0]):\n", 290 | " final[i, 0] = morphology.area_opening(thr[i, 0], area_threshold=AREA_THRESHOLD)\n", 291 | "\n", 292 | "# Remove detections outside brain mask\n", 293 | "final[img.cpu().numpy() == 0] = 0\n", 294 | "\n", 295 | "print(f\"Anomaly mask shape: {final.shape}\")\n", 296 | "print(f\"Total anomalous voxels: {final.sum()}\")" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "id": "viz_header", 302 | "metadata": {}, 303 | "source": [ 304 | "## 5. Visualization\n", 305 | "\n", 306 | "Visualize the results for a selected slice." 
307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "id": "830b3201", 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "def dice(true_mask, pred_mask, non_seg_score=1.0):\n", 317 | " \"\"\"\n", 318 | " Computes the Dice coefficient.\n", 319 | " \n", 320 | " Args:\n", 321 | " true_mask: Ground truth binary mask\n", 322 | " pred_mask: Predicted binary mask\n", 323 | " non_seg_score: Score to return when both masks are empty\n", 324 | " \n", 325 | " Returns:\n", 326 | " Dice coefficient (0-1)\n", 327 | " \"\"\"\n", 328 | " assert true_mask.shape == pred_mask.shape\n", 329 | "\n", 330 | " true_mask = np.asarray(true_mask).astype(bool)\n", 331 | " pred_mask = np.asarray(pred_mask).astype(bool)\n", 332 | "\n", 333 | " im_sum = true_mask.sum() + pred_mask.sum()\n", 334 | " if im_sum == 0:\n", 335 | " return non_seg_score\n", 336 | "\n", 337 | " intersection = np.logical_and(true_mask, pred_mask)\n", 338 | " return 2. * intersection.sum() / im_sum" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "id": "visualization", 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "# Select a slice to visualize\n", 349 | "s_index = img.shape[0] // 2 # Middle slice\n", 350 | "\n", 351 | "fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n", 352 | "\n", 353 | "# Input image\n", 354 | "ax = axes[0]\n", 355 | "rotated_img = ndimage.rotate(img.cpu().numpy()[s_index, 0], -90)\n", 356 | "ax.imshow(rotated_img, cmap='gray')\n", 357 | "ax.set_title('Input (Normalized)')\n", 358 | "ax.axis('off')\n", 359 | "\n", 360 | "# Reconstruction\n", 361 | "ax = axes[1]\n", 362 | "rotated_img = ndimage.rotate(reconstruction.cpu().numpy()[s_index, 0], -90)\n", 363 | "ax.imshow(rotated_img, cmap='gray')\n", 364 | "ax.set_title('Reconstruction')\n", 365 | "ax.axis('off')\n", 366 | "\n", 367 | "# Difference (thresholded)\n", 368 | "ax = axes[2]\n", 369 | "rotated_img = ndimage.rotate(thr[s_index, 0], -90)\n", 370 | "ax.imshow(rotated_img, cmap='gray')\n", 371 | "ax.set_title('Otsu Threshold')\n", 372 | "ax.axis('off')\n", 373 | "\n", 374 | "# Final anomaly mask\n", 375 | "ax = axes[3]\n", 376 | "rotated_img = ndimage.rotate(final[s_index, 0], -90)\n", 377 | "ax.imshow(rotated_img, cmap='gray')\n", 378 | "ax.set_title('Final Anomaly Mask')\n", 379 | "ax.axis('off')\n", 380 | "\n", 381 | "plt.tight_layout()\n", 382 | "plt.show()" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "id": "gt_comparison", 388 | "metadata": {}, 389 | "source": [ 390 | "## 6. Evaluation with Ground Truth (Optional)\n", 391 | "\n", 392 | "If you have a ground truth mask, you can compute the Dice coefficient." 
393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "id": "evaluation", 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "# Load ground truth mask (if available)\n", 403 | "# gt_path = './brats/mask/example_volume.nii.gz'\n", 404 | "# gt_nii = nib.load(gt_path)\n", 405 | "# gt_mask = gt_nii.get_fdata()\n", 406 | "# gt_mask = np.moveaxis(gt_mask, 2, 0)\n", 407 | "# gt_mask[gt_mask > 0] = 1\n", 408 | "\n", 409 | "# # Crop/pad to match\n", 410 | "# gt_item = torch.tensor(gt_mask).unsqueeze(dim=0)\n", 411 | "# gt_out = transforms.CropOrPad((gt_mask.shape[0], 256, 256))(gt_item)\n", 412 | "# gt_out = gt_out.squeeze(dim=0).unsqueeze(dim=1).numpy()\n", 413 | "\n", 414 | "# # Compute Dice\n", 415 | "# dice_score = dice(gt_out, final)\n", 416 | "# print(f\"Dice coefficient: {dice_score:.4f}\")" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "id": "save_results", 422 | "metadata": {}, 423 | "source": [ 424 | "## 7. Save Results\n", 425 | "\n", 426 | "Save the anomaly mask as a NIfTI file." 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "id": "save", 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "# Save anomaly mask as NIfTI\n", 437 | "# Reformat to original orientation\n", 438 | "anomaly_mask = final.squeeze(1) # Remove channel dim\n", 439 | "anomaly_mask = np.moveaxis(anomaly_mask, 0, 2) # Move slices back to last dim\n", 440 | "\n", 441 | "# Create NIfTI image\n", 442 | "output_nii = nib.Nifti1Image(anomaly_mask.astype(np.float32), affine)\n", 443 | "\n", 444 | "# Save\n", 445 | "# output_path = 'anomaly_mask.nii.gz'\n", 446 | "# nib.save(output_nii, output_path)\n", 447 | "# print(f\"Saved anomaly mask to: {output_path}\")" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "id": "notes", 453 | "metadata": {}, 454 | "source": [ 455 | "## Notes\n", 456 | "\n", 457 | "### Preprocessing Alignment\n", 458 | "The preprocessing in this notebook matches the training pipeline:\n", 459 | "- **Per-slice normalization**: `(x - min) / max` computed independently for each slice\n", 460 | "- **Important**: This is NOT standard min-max normalization `(x - min) / (max - min)`. The training code (train.py lines 24-27) specifically uses `(x - min) / max`.\n", 461 | "- **This matches**: `train.py` lines 24-27 and `inference.py`\n", 462 | "\n", 463 | "### Post-processing Steps\n", 464 | "1. **Difference calculation**: `reconstruction - input`\n", 465 | "2. **Positive thresholding**: Keep values > 0 (anomalies are under-reconstructed)\n", 466 | "3. **Manual threshold**: Values > ANOMALY_THRESHOLD (default 0.2) are set to 1\n", 467 | "4. **Otsu threshold**: Adaptive binarization\n", 468 | "5. **Morphological opening**: Remove small artifacts (area < AREA_THRESHOLD pixels, default 256)\n", 469 | "6. 
**Brain masking**: Remove detections outside the brain\n", 470 | "\n", 471 | "### Using the CLI\n", 472 | "For batch processing, use the `inference.py` script:\n", 473 | "```bash\n", 474 | "python inference.py -i /path/to/volume.nii.gz -o /path/to/output/\n", 475 | "```" 476 | ] 477 | } 478 | ], 479 | "metadata": { 480 | "kernelspec": { 481 | "display_name": "Python 3", 482 | "language": "python", 483 | "name": "python3" 484 | }, 485 | "language_info": { 486 | "codemirror_mode": { 487 | "name": "ipython", 488 | "version": 3 489 | }, 490 | "file_extension": ".py", 491 | "mimetype": "text/x-python", 492 | "name": "python", 493 | "nbconvert_exporter": "python", 494 | "pygments_lexer": "ipython3", 495 | "version": "3.11.3" 496 | } 497 | }, 498 | "nbformat": 4, 499 | "nbformat_minor": 5 500 | } 501 | -------------------------------------------------------------------------------- /ceVae/ae_bases.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class NoOp(nn.Module): 9 | def __init__(self, *args, **kwargs): 10 | """NoOp Pytorch Module. 11 | Forwards the given input as is. 12 | """ 13 | super(NoOp, self).__init__() 14 | 15 | def forward(self, x, *args, **kwargs): 16 | return x 17 | 18 | 19 | class ConvModule(nn.Module): 20 | def __init__( 21 | self, 22 | in_channels: int, 23 | out_channels: int, 24 | conv_op=nn.Conv2d, 25 | conv_params=None, 26 | normalization_op=None, 27 | normalization_params=None, 28 | activation_op=nn.LeakyReLU, 29 | activation_params=None, 30 | ): 31 | """Basic Conv Pytorch Conv Module 32 | Has can have a Conv Op, a Normlization Op and a Non Linearity: 33 | x = conv(x) 34 | x = some_norm(x) 35 | x = nonlin(x) 36 | 37 | Args: 38 | in_channels ([int]): [Number on input channels/ feature maps] 39 | out_channels ([int]): [Number of ouput channels/ feature maps] 40 | conv_op ([torch.nn.Module], optional): [Conv operation]. Defaults to nn.Conv2d. 41 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. 42 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...)]. Defaults to None. 43 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 44 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...)]. Defaults to nn.LeakyReLU. 45 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 
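        Example (a minimal sketch with illustrative channel counts and conv parameters):

            >>> m = ConvModule(1, 16,
            ...                conv_op=nn.Conv2d,
            ...                conv_params=dict(kernel_size=3, stride=2, padding=1, bias=False),
            ...                normalization_op=nn.BatchNorm2d,
            ...                activation_op=nn.LeakyReLU)
            >>> y = m(torch.randn(2, 1, 64, 64))    # -> shape (2, 16, 32, 32)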
46 | """ 47 | 48 | super(ConvModule, self).__init__() 49 | 50 | self.conv_params = conv_params 51 | if self.conv_params is None: 52 | self.conv_params = {} 53 | self.activation_params = activation_params 54 | if self.activation_params is None: 55 | self.activation_params = {} 56 | self.normalization_params = normalization_params 57 | if self.normalization_params is None: 58 | self.normalization_params = {} 59 | 60 | self.conv = None 61 | if conv_op is not None and not isinstance(conv_op, str): 62 | self.conv = conv_op(in_channels, out_channels, **self.conv_params) 63 | 64 | self.normalization = None 65 | if normalization_op is not None and not isinstance(normalization_op, str): 66 | self.normalization = normalization_op(out_channels, **self.normalization_params) 67 | 68 | self.activation = None 69 | if activation_op is not None and not isinstance(activation_op, str): 70 | self.activation = activation_op(**self.activation_params) 71 | 72 | def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None): 73 | 74 | x = input 75 | 76 | if self.conv is not None: 77 | if conv_add_input is None: 78 | x = self.conv(x) 79 | else: 80 | x = self.conv(x, **conv_add_input) 81 | 82 | if self.normalization is not None: 83 | if normalization_add_input is None: 84 | x = self.normalization(x) 85 | else: 86 | x = self.normalization(x, **normalization_add_input) 87 | 88 | if self.activation is not None: 89 | if activation_add_input is None: 90 | x = self.activation(x) 91 | else: 92 | x = self.activation(x, **activation_add_input) 93 | 94 | # nn.functional.dropout(x, p=0.95, training=True) 95 | 96 | return x 97 | 98 | 99 | class ConvBlock(nn.Module): 100 | def __init__( 101 | self, 102 | n_convs: int, 103 | n_featmaps: int, 104 | conv_op=nn.Conv2d, 105 | conv_params=None, 106 | normalization_op=nn.BatchNorm2d, 107 | normalization_params=None, 108 | activation_op=nn.LeakyReLU, 109 | activation_params=None, 110 | ): 111 | """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) 112 | 113 | Args: 114 | n_convs ([type]): [Number of convolutions] 115 | n_featmaps ([type]): [Feature map size of the conv] 116 | conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. 117 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. 118 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 119 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 120 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 121 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 
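        Example (a minimal sketch; note that conv_params should provide at least a
        kernel_size for the default nn.Conv2d):

            >>> block = ConvBlock(n_convs=2, n_featmaps=64,
            ...                   conv_params=dict(kernel_size=3, stride=1, padding=1, bias=False))
            >>> y = block(torch.randn(2, 64, 32, 32))    # shape preserved: (2, 64, 32, 32)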
122 | """ 123 | 124 | super(ConvBlock, self).__init__() 125 | 126 | self.n_featmaps = n_featmaps 127 | self.n_convs = n_convs 128 | self.conv_params = conv_params 129 | if self.conv_params is None: 130 | self.conv_params = {} 131 | 132 | self.conv_list = nn.ModuleList() 133 | 134 | for i in range(self.n_convs): 135 | conv_layer = ConvModule( 136 | n_featmaps, 137 | n_featmaps, 138 | conv_op=conv_op, 139 | conv_params=conv_params, 140 | normalization_op=normalization_op, 141 | normalization_params=normalization_params, 142 | activation_op=activation_op, 143 | activation_params=activation_params, 144 | ) 145 | self.conv_list.append(conv_layer) 146 | 147 | def forward(self, input, **frwd_params): 148 | x = input 149 | for conv_layer in self.conv_list: 150 | x = conv_layer(x) 151 | 152 | return x 153 | 154 | 155 | class ResBlock(nn.Module): 156 | def __init__( 157 | self, 158 | n_convs, 159 | n_featmaps, 160 | conv_op=nn.Conv2d, 161 | conv_params=None, 162 | normalization_op=nn.BatchNorm2d, 163 | normalization_params=None, 164 | activation_op=nn.LeakyReLU, 165 | activation_params=None, 166 | ): 167 | """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) and a skip/ residual connection: 168 | x = input 169 | x = conv_block(x) 170 | out = x + input 171 | 172 | Args: 173 | n_convs ([type]): [Number of convolutions in the conv block] 174 | n_featmaps ([type]): [Feature map size of the conv block] 175 | conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. 176 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. 177 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 178 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 179 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 180 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 181 | """ 182 | super(ResBlock, self).__init__() 183 | 184 | self.n_featmaps = n_featmaps 185 | self.n_convs = n_convs 186 | self.conv_params = conv_params 187 | if self.conv_params is None: 188 | self.conv_params = {} 189 | 190 | self.conv_block = ConvBlock( 191 | n_featmaps, 192 | n_convs, 193 | conv_op=conv_op, 194 | conv_params=conv_params, 195 | normalization_op=normalization_op, 196 | normalization_params=normalization_params, 197 | activation_op=activation_op, 198 | activation_params=activation_params, 199 | ) 200 | 201 | def forward(self, input, **frwd_params): 202 | x = input 203 | x = self.conv_block(x) 204 | 205 | out = x + input 206 | 207 | return out 208 | 209 | 210 | # Basic Generator 211 | class BasicGenerator(nn.Module): 212 | def __init__( 213 | self, 214 | input_size, 215 | z_dim=256, 216 | fmap_sizes=(256, 128, 64), 217 | upsample_op=nn.ConvTranspose2d, 218 | conv_params=None, 219 | normalization_op=NoOp, 220 | normalization_params=None, 221 | activation_op=nn.LeakyReLU, 222 | activation_params=None, 223 | block_op=NoOp, 224 | block_params=None, 225 | to_1x1=True, 226 | ): 227 | """Basic configureable Generator/ Decoder. 228 | Allows for mutilple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. 
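        For instance (sizes given purely for illustration), with input_size=(1, 256, 256), z_dim=256 and the default fmap_sizes=(256, 128, 64), a 256 x 1 x 1 latent is first expanded to a 256-channel 32 x 32 map and then doubled in resolution at each level (128 channels at 64 x 64, 64 channels at 128 x 128) before the final layer produces the 1 x 256 x 256 output.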
229 | 230 | Args: 231 | input_size ((int, int, int): Size of the input in format CxHxW): 232 | z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). 233 | fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each 234 | int defines the number of feature maps in the layer]. Defaults to (256, 128, 64). 235 | upsample_op ([torch.nn.Module], optional): [Upsampling operation used, to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d. 236 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). 237 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 238 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 239 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 240 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 241 | block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. 242 | block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. 243 | to_1x1 (bool, optional): [If Latent dimesion is a z_dim x 1 x 1 vector (True) or if allows spatial resolution not to be 1x1 (z_dim x H x W) (False) ]. Defaults to True. 244 | """ 245 | 246 | super(BasicGenerator, self).__init__() 247 | 248 | if conv_params is None: 249 | conv_params = dict(kernel_size=4, stride=2, padding=1, bias=False) 250 | if block_op is None: 251 | block_op = NoOp 252 | if block_params is None: 253 | block_params = {} 254 | 255 | n_channels = input_size[0] 256 | input_size_ = np.array(input_size[1:]) 257 | 258 | if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): 259 | raise AttributeError("fmap_sizes has to be either a list or tuple or an int") 260 | elif len(fmap_sizes) < 2: 261 | raise AttributeError("fmap_sizes has to contain at least three elements") 262 | else: 263 | h_size_bot = fmap_sizes[0] 264 | 265 | # We need to know how many layers we will use at the beginning 266 | input_size_new = input_size_ // (2 ** len(fmap_sizes)) 267 | if np.min(input_size_new) < 2 and z_dim is not None: 268 | raise AttributeError("fmap_sizes to long, one image dimension has already perished") 269 | 270 | ### Start block 271 | start_block = [] 272 | 273 | if not to_1x1: 274 | kernel_size_start = [min(conv_params["kernel_size"], i) for i in input_size_new] 275 | else: 276 | kernel_size_start = input_size_new.tolist() 277 | 278 | if z_dim is not None: 279 | self.start = ConvModule( 280 | z_dim, 281 | h_size_bot, 282 | conv_op=upsample_op, 283 | conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False), 284 | normalization_op=normalization_op, 285 | normalization_params=normalization_params, 286 | activation_op=activation_op, 287 | activation_params=activation_params, 288 | ) 289 | 290 | input_size_new = input_size_new * 2 291 | else: 292 | self.start = NoOp() 293 | 294 | ### Middle block (Done until we reach ? 
x input_size/2 x input_size/2) 295 | self.middle_blocks = nn.ModuleList() 296 | 297 | for h_size_top in fmap_sizes[1:]: 298 | 299 | self.middle_blocks.append(block_op(h_size_bot, **block_params)) 300 | 301 | self.middle_blocks.append( 302 | ConvModule( 303 | h_size_bot, 304 | h_size_top, 305 | conv_op=upsample_op, 306 | conv_params=conv_params, 307 | normalization_op=normalization_op, 308 | normalization_params={}, 309 | activation_op=activation_op, 310 | activation_params=activation_params, 311 | ) 312 | ) 313 | 314 | h_size_bot = h_size_top 315 | input_size_new = input_size_new * 2 316 | 317 | ### End block 318 | self.end = ConvModule( 319 | h_size_bot, 320 | n_channels, 321 | conv_op=upsample_op, 322 | conv_params=conv_params, 323 | normalization_op=None, 324 | activation_op=None, 325 | ) 326 | 327 | def forward(self, inpt, **kwargs): 328 | output = self.start(inpt, **kwargs) 329 | for middle in self.middle_blocks: 330 | output = middle(output, **kwargs) 331 | output = self.end(output, **kwargs) 332 | return output 333 | 334 | 335 | # Basic Encoder 336 | class BasicEncoder(nn.Module): 337 | def __init__( 338 | self, 339 | input_size, 340 | z_dim=256, 341 | fmap_sizes=(64, 128, 256), 342 | conv_op=nn.Conv2d, 343 | conv_params=None, 344 | normalization_op=NoOp, 345 | normalization_params=None, 346 | activation_op=nn.LeakyReLU, 347 | activation_params=None, 348 | block_op=NoOp, 349 | block_params=None, 350 | to_1x1=True, 351 | ): 352 | """Basic configureable Encoder. 353 | Allows for mutilple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. 354 | 355 | Args: 356 | z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). 357 | fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each 358 | int defines the number of feature maps in the layer]. Defaults to (64, 128, 256). 359 | conv_op ([torch.nn.Module], optional): [Convolutioon operation used to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d. 360 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). 361 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 362 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 363 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 364 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 365 | block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. 366 | block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. 367 | to_1x1 (bool, optional): [If True, then the last conv layer goes to a latent dimesion is a z_dim x 1 x 1 vector (similar to fully connected) or if False allows spatial resolution not to be 1x1 (z_dim x H x W, uses the in the conv_params given conv-kernel-size) ]. Defaults to True. 
368 | """ 369 | super(BasicEncoder, self).__init__() 370 | 371 | if conv_params is None: 372 | conv_params = dict(kernel_size=3, stride=2, padding=1, bias=False) 373 | if block_op is None: 374 | block_op = NoOp 375 | if block_params is None: 376 | block_params = {} 377 | 378 | n_channels = input_size[0] 379 | input_size_new = np.array(input_size[1:]) 380 | 381 | if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): 382 | raise AttributeError("fmap_sizes has to be either a list or tuple or an int") 383 | # elif len(fmap_sizes) < 2: 384 | # raise AttributeError("fmap_sizes has to contain at least three elements") 385 | else: 386 | h_size_bot = fmap_sizes[0] 387 | 388 | ### Start block 389 | self.start = ConvModule( 390 | n_channels, 391 | h_size_bot, 392 | conv_op=conv_op, 393 | conv_params=conv_params, 394 | normalization_op=normalization_op, 395 | normalization_params={}, 396 | activation_op=activation_op, 397 | activation_params=activation_params, 398 | ) 399 | input_size_new = input_size_new // 2 400 | 401 | ### Middle block (Done until we reach ? x 4 x 4) 402 | self.middle_blocks = nn.ModuleList() 403 | 404 | for h_size_top in fmap_sizes[1:]: 405 | 406 | self.middle_blocks.append(block_op(h_size_bot, **block_params)) 407 | 408 | self.middle_blocks.append( 409 | ConvModule( 410 | h_size_bot, 411 | h_size_top, 412 | conv_op=conv_op, 413 | conv_params=conv_params, 414 | normalization_op=normalization_op, 415 | normalization_params={}, 416 | activation_op=activation_op, 417 | activation_params=activation_params, 418 | ) 419 | ) 420 | 421 | h_size_bot = h_size_top 422 | input_size_new = input_size_new // 2 423 | 424 | if np.min(input_size_new) < 2 and z_dim is not None: 425 | raise ("fmap_sizes to long, one image dimension has already perished") 426 | 427 | ### End block 428 | if not to_1x1: 429 | kernel_size_end = [min(conv_params["kernel_size"], i) for i in input_size_new] 430 | else: 431 | kernel_size_end = input_size_new.tolist() 432 | 433 | if z_dim is not None: 434 | self.end = ConvModule( 435 | h_size_bot, 436 | z_dim, 437 | conv_op=conv_op, 438 | conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False), 439 | normalization_op=None, 440 | activation_op=None, 441 | ) 442 | 443 | if to_1x1: 444 | self.output_size = (z_dim, 1, 1) 445 | else: 446 | self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(input_size_new, kernel_size_end)]) 447 | else: 448 | self.end = NoOp() 449 | self.output_size = input_size_new 450 | 451 | def forward(self, inpt, **kwargs): 452 | output = self.start(inpt, **kwargs) 453 | for middle in self.middle_blocks: 454 | output = middle(output, **kwargs) 455 | output = self.end(output, **kwargs) 456 | return output 457 | -------------------------------------------------------------------------------- /ae_bases.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class NoOp(nn.Module): 9 | def __init__(self, *args, **kwargs): 10 | """NoOp Pytorch Module. 11 | Forwards the given input as is. 
12 | """ 13 | super(NoOp, self).__init__() 14 | 15 | def forward(self, x, *args, **kwargs): 16 | return x 17 | 18 | 19 | class ConvModule(nn.Module): 20 | def __init__( 21 | self, 22 | in_channels: int, 23 | out_channels: int, 24 | conv_op=nn.Conv2d, 25 | conv_params=None, 26 | normalization_op=None, 27 | normalization_params=None, 28 | activation_op=nn.LeakyReLU, 29 | activation_params=None, 30 | ): 31 | """Basic Conv Pytorch Conv Module 32 | Has can have a Conv Op, a Normlization Op and a Non Linearity: 33 | x = conv(x) 34 | x = some_norm(x) 35 | x = nonlin(x) 36 | 37 | Args: 38 | in_channels ([int]): [Number on input channels/ feature maps] 39 | out_channels ([int]): [Number of ouput channels/ feature maps] 40 | conv_op ([torch.nn.Module], optional): [Conv operation]. Defaults to nn.Conv2d. 41 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. 42 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...)]. Defaults to None. 43 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 44 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...)]. Defaults to nn.LeakyReLU. 45 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 46 | """ 47 | 48 | super(ConvModule, self).__init__() 49 | 50 | self.conv_params = conv_params 51 | if self.conv_params is None: 52 | self.conv_params = {} 53 | self.activation_params = activation_params 54 | if self.activation_params is None: 55 | self.activation_params = {} 56 | self.normalization_params = normalization_params 57 | if self.normalization_params is None: 58 | self.normalization_params = {} 59 | 60 | self.conv = None 61 | if conv_op is not None and not isinstance(conv_op, str): 62 | self.conv = conv_op(in_channels, out_channels, **self.conv_params) 63 | 64 | self.normalization = None 65 | if normalization_op is not None and not isinstance(normalization_op, str): 66 | self.normalization = normalization_op(out_channels, **self.normalization_params) 67 | 68 | self.activation = None 69 | if activation_op is not None and not isinstance(activation_op, str): 70 | self.activation = activation_op(**self.activation_params) 71 | 72 | def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None): 73 | 74 | x = input 75 | 76 | if self.conv is not None: 77 | if conv_add_input is None: 78 | x = self.conv(x) 79 | else: 80 | x = self.conv(x, **conv_add_input) 81 | 82 | if self.normalization is not None: 83 | if normalization_add_input is None: 84 | x = self.normalization(x) 85 | else: 86 | x = self.normalization(x, **normalization_add_input) 87 | 88 | if self.activation is not None: 89 | if activation_add_input is None: 90 | x = self.activation(x) 91 | else: 92 | x = self.activation(x, **activation_add_input) 93 | 94 | # nn.functional.dropout(x, p=0.95, training=True) 95 | 96 | return x 97 | 98 | 99 | class ConvBlock(nn.Module): 100 | def __init__( 101 | self, 102 | n_convs: int, 103 | n_featmaps: int, 104 | conv_op=nn.Conv2d, 105 | conv_params=None, 106 | normalization_op=nn.BatchNorm2d, 107 | normalization_params=None, 108 | activation_op=nn.LeakyReLU, 109 | activation_params=None, 110 | ): 111 | """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) 112 | 113 | Args: 114 | n_convs ([type]): [Number of convolutions] 115 
| n_featmaps ([type]): [Feature map size of the conv] 116 | conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. 117 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. 118 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 119 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 120 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 121 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 122 | """ 123 | 124 | super(ConvBlock, self).__init__() 125 | 126 | self.n_featmaps = n_featmaps 127 | self.n_convs = n_convs 128 | self.conv_params = conv_params 129 | if self.conv_params is None: 130 | self.conv_params = {} 131 | 132 | self.conv_list = nn.ModuleList() 133 | 134 | for i in range(self.n_convs): 135 | conv_layer = ConvModule( 136 | n_featmaps, 137 | n_featmaps, 138 | conv_op=conv_op, 139 | conv_params=conv_params, 140 | normalization_op=normalization_op, 141 | normalization_params=normalization_params, 142 | activation_op=activation_op, 143 | activation_params=activation_params, 144 | ) 145 | self.conv_list.append(conv_layer) 146 | 147 | def forward(self, input, **frwd_params): 148 | x = input 149 | for conv_layer in self.conv_list: 150 | x = conv_layer(x) 151 | 152 | return x 153 | 154 | 155 | class ResBlock(nn.Module): 156 | def __init__( 157 | self, 158 | n_convs, 159 | n_featmaps, 160 | conv_op=nn.Conv2d, 161 | conv_params=None, 162 | normalization_op=nn.BatchNorm2d, 163 | normalization_params=None, 164 | activation_op=nn.LeakyReLU, 165 | activation_params=None, 166 | ): 167 | """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) and a skip/ residual connection: 168 | x = input 169 | x = conv_block(x) 170 | out = x + input 171 | 172 | Args: 173 | n_convs ([type]): [Number of convolutions in the conv block] 174 | n_featmaps ([type]): [Feature map size of the conv block] 175 | conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. 176 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. 177 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 178 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 179 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 180 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 
181 | """ 182 | super(ResBlock, self).__init__() 183 | 184 | self.n_featmaps = n_featmaps 185 | self.n_convs = n_convs 186 | self.conv_params = conv_params 187 | if self.conv_params is None: 188 | self.conv_params = {} 189 | 190 | self.conv_block = ConvBlock( 191 | n_convs=n_convs, 192 | n_featmaps=n_featmaps, 193 | conv_op=conv_op, 194 | conv_params=conv_params, 195 | normalization_op=normalization_op, 196 | normalization_params=normalization_params, 197 | activation_op=activation_op, 198 | activation_params=activation_params, 199 | ) 200 | 201 | def forward(self, input, **frwd_params): 202 | x = input 203 | x = self.conv_block(x) 204 | 205 | out = x + input 206 | 207 | return out 208 | 209 | 210 | # Basic Generator 211 | class BasicGenerator(nn.Module): 212 | def __init__( 213 | self, 214 | input_size, 215 | z_dim=256, 216 | fmap_sizes=(256, 128, 64), 217 | upsample_op=nn.ConvTranspose2d, 218 | conv_params=None, 219 | normalization_op=NoOp, 220 | normalization_params=None, 221 | activation_op=nn.LeakyReLU, 222 | activation_params=None, 223 | block_op=NoOp, 224 | block_params=None, 225 | to_1x1=True, 226 | ): 227 | """Basic configurable Generator/ Decoder. 228 | Allows for multiple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. 229 | 230 | Args: 231 | input_size ((int, int, int)): [Size of the input in format CxHxW] 232 | z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). 233 | fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each 234 | int defines the number of feature maps in the layer]. Defaults to (256, 128, 64). 235 | upsample_op ([torch.nn.Module], optional): [Upsampling operation used, to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d. 236 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=4, stride=2, padding=1, bias=False). 237 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to NoOp. 238 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 239 | activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 240 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 241 | block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op, e.g. ConvBlock/ ResBlock]. Defaults to NoOp. 242 | block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. 243 | to_1x1 (bool, optional): [If True, the latent is a z_dim x 1 x 1 vector; if False, the spatial resolution may be larger than 1x1 (z_dim x H x W)]. Defaults to True.
244 | """ 245 | 246 | super(BasicGenerator, self).__init__() 247 | 248 | if conv_params is None: 249 | conv_params = dict(kernel_size=4, stride=2, padding=1, bias=False) 250 | if block_op is None: 251 | block_op = NoOp 252 | if block_params is None: 253 | block_params = {} 254 | 255 | n_channels = input_size[0] 256 | input_size_ = np.array(input_size[1:]) 257 | 258 | if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): 259 | raise AttributeError("fmap_sizes has to be either a list or tuple or an int") 260 | elif len(fmap_sizes) < 2: 261 | raise AttributeError("fmap_sizes has to contain at least three elements") 262 | else: 263 | h_size_bot = fmap_sizes[0] 264 | 265 | # We need to know how many layers we will use at the beginning 266 | input_size_new = input_size_ // (2 ** len(fmap_sizes)) 267 | if np.min(input_size_new) < 2 and z_dim is not None: 268 | raise AttributeError("fmap_sizes to long, one image dimension has already perished") 269 | 270 | ### Start block 271 | start_block = [] 272 | 273 | if not to_1x1: 274 | kernel_size_start = [min(conv_params["kernel_size"], i) for i in input_size_new] 275 | else: 276 | kernel_size_start = input_size_new.tolist() 277 | 278 | if z_dim is not None: 279 | self.start = ConvModule( 280 | z_dim, 281 | h_size_bot, 282 | conv_op=upsample_op, 283 | conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False), 284 | normalization_op=normalization_op, 285 | normalization_params=normalization_params, 286 | activation_op=activation_op, 287 | activation_params=activation_params, 288 | ) 289 | 290 | input_size_new = input_size_new * 2 291 | else: 292 | self.start = NoOp() 293 | 294 | ### Middle block (Done until we reach ? x input_size/2 x input_size/2) 295 | self.middle_blocks = nn.ModuleList() 296 | 297 | for h_size_top in fmap_sizes[1:]: 298 | 299 | self.middle_blocks.append(block_op(h_size_bot, **block_params)) 300 | 301 | self.middle_blocks.append( 302 | ConvModule( 303 | h_size_bot, 304 | h_size_top, 305 | conv_op=upsample_op, 306 | conv_params=conv_params, 307 | normalization_op=normalization_op, 308 | normalization_params={}, 309 | activation_op=activation_op, 310 | activation_params=activation_params, 311 | ) 312 | ) 313 | 314 | h_size_bot = h_size_top 315 | input_size_new = input_size_new * 2 316 | 317 | ### End block 318 | self.end = ConvModule( 319 | h_size_bot, 320 | n_channels, 321 | conv_op=upsample_op, 322 | conv_params=conv_params, 323 | normalization_op=None, 324 | activation_op=None, 325 | ) 326 | 327 | def forward(self, inpt, **kwargs): 328 | output = self.start(inpt, **kwargs) 329 | for middle in self.middle_blocks: 330 | output = middle(output, **kwargs) 331 | output = self.end(output, **kwargs) 332 | return output 333 | 334 | 335 | # Basic Encoder 336 | class BasicEncoder(nn.Module): 337 | def __init__( 338 | self, 339 | input_size, 340 | z_dim=256, 341 | fmap_sizes=(64, 128, 256), 342 | conv_op=nn.Conv2d, 343 | conv_params=None, 344 | normalization_op=NoOp, 345 | normalization_params=None, 346 | activation_op=nn.LeakyReLU, 347 | activation_params=None, 348 | block_op=NoOp, 349 | block_params=None, 350 | to_1x1=True, 351 | ): 352 | """Basic configureable Encoder. 353 | Allows for mutilple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. 354 | 355 | Args: 356 | z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). 
357 | fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each 358 | int defines the number of feature maps in the layer]. Defaults to (64, 128, 256). 359 | conv_op ([torch.nn.Module], optional): [Convolutioon operation used to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d. 360 | conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). 361 | normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. 362 | normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. 363 | activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. 364 | activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. 365 | block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. 366 | block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. 367 | to_1x1 (bool, optional): [If True, then the last conv layer goes to a latent dimesion is a z_dim x 1 x 1 vector (similar to fully connected) or if False allows spatial resolution not to be 1x1 (z_dim x H x W, uses the in the conv_params given conv-kernel-size) ]. Defaults to True. 368 | """ 369 | super(BasicEncoder, self).__init__() 370 | 371 | if conv_params is None: 372 | conv_params = dict(kernel_size=3, stride=2, padding=1, bias=False) 373 | if block_op is None: 374 | block_op = NoOp 375 | if block_params is None: 376 | block_params = {} 377 | 378 | n_channels = input_size[0] 379 | input_size_new = np.array(input_size[1:]) 380 | 381 | if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): 382 | raise AttributeError("fmap_sizes has to be either a list or tuple or an int") 383 | # elif len(fmap_sizes) < 2: 384 | # raise AttributeError("fmap_sizes has to contain at least three elements") 385 | else: 386 | h_size_bot = fmap_sizes[0] 387 | 388 | ### Start block 389 | self.start = ConvModule( 390 | n_channels, 391 | h_size_bot, 392 | conv_op=conv_op, 393 | conv_params=conv_params, 394 | normalization_op=normalization_op, 395 | normalization_params={}, 396 | activation_op=activation_op, 397 | activation_params=activation_params, 398 | ) 399 | input_size_new = input_size_new // 2 400 | 401 | ### Middle block (Done until we reach ? 
x 4 x 4) 402 | self.middle_blocks = nn.ModuleList() 403 | 404 | for h_size_top in fmap_sizes[1:]: 405 | 406 | self.middle_blocks.append(block_op(h_size_bot, **block_params)) 407 | 408 | self.middle_blocks.append( 409 | ConvModule( 410 | h_size_bot, 411 | h_size_top, 412 | conv_op=conv_op, 413 | conv_params=conv_params, 414 | normalization_op=normalization_op, 415 | normalization_params={}, 416 | activation_op=activation_op, 417 | activation_params=activation_params, 418 | ) 419 | ) 420 | 421 | h_size_bot = h_size_top 422 | input_size_new = input_size_new // 2 423 | 424 | if np.min(input_size_new) < 2 and z_dim is not None: 425 | raise ("fmap_sizes to long, one image dimension has already perished") 426 | 427 | ### End block 428 | if not to_1x1: 429 | kernel_size_end = [min(conv_params["kernel_size"], i) for i in input_size_new] 430 | else: 431 | kernel_size_end = input_size_new.tolist() 432 | 433 | if z_dim is not None: 434 | self.end = ConvModule( 435 | h_size_bot, 436 | z_dim, 437 | conv_op=conv_op, 438 | conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False), 439 | normalization_op=None, 440 | activation_op=None, 441 | ) 442 | 443 | if to_1x1: 444 | self.output_size = (z_dim, 1, 1) 445 | else: 446 | self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(input_size_new, kernel_size_end)]) 447 | else: 448 | self.end = NoOp() 449 | self.output_size = input_size_new 450 | 451 | def forward(self, inpt, **kwargs): 452 | output = self.start(inpt, **kwargs) 453 | for middle in self.middle_blocks: 454 | output = middle(output, **kwargs) 455 | output = self.end(output, **kwargs) 456 | return output 457 | -------------------------------------------------------------------------------- /misc/ssae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import pickle 4 | import json 5 | import random 6 | import logging 7 | import numpy as np 8 | from itertools import chain 9 | import torch 10 | import torch.nn as nn 11 | import torchvision 12 | from torch.autograd import Variable 13 | from torch.utils.data import Dataset, DataLoader 14 | import torchio 15 | import torchio as tio 16 | from tqdm import tqdm 17 | import sys 18 | import wandb 19 | import matplotlib.pyplot as plt 20 | from torch.utils.tensorboard import SummaryWriter 21 | from torch.optim import Adam 22 | from torch import nn, optim 23 | from torch.optim.lr_scheduler import StepLR 24 | import torch.distributions as dist 25 | import math 26 | from torch.cuda.amp import autocast, GradScaler 27 | 28 | from datasets import MoodTrainSet, MoodValSet 29 | from datasets.ixi import IXITrainSet 30 | from models.ceVae.aes import VAE 31 | from models.ceVae.helpers import kl_loss_fn, rec_loss_fn, geco_beta_update, get_ema, get_square_mask 32 | import torch.nn.functional as F 33 | 34 | #Initial params 35 | gpuID="0" 36 | seed = 1701 37 | num_workers=0 38 | batch_size = 16 39 | log_freq = 10 40 | checkpoint2load = False 41 | checkpoint = r'/scratch/sagrawal/cevae_seg_ixi_t1_t2_pd_mood_noaug/cevae_seg_ixi_t1_t2_pd_mood_noaug-epoch-44.pth.tar' 42 | useCuda = True 43 | 44 | ixi_t1_indices = list(range(100,350)) 45 | ixi_t2_indices = list(range(100,350)) 46 | ixi_proton_indices = list(range(100,350)) 47 | mood_t1_indices = list(range(100,350)) 48 | 49 | ixi_t1_eval_indices = list(range(30,50)) 50 | ixi_t2_eval_indices = list(range(30,50)) 51 | ixi_proton_eval_indices = list(range(30,50)) 52 | mood_t1_eval_indices = list(range(30,50)) 53 | 54 | ixi_t1_train = 
r"/project/ptummala/ixi/ixi_t1_segmented_581_3D.hdf5" 55 | ixi_t2_train = r"/project/ptummala/ixi/ixi_t2_segmented_578_3D.hdf5" 56 | ixi_proton_train = r"/project/ptummala/ixi/ixi_pd_segmented_578_3D.hdf5" 57 | mood_t1_train = r"/project/ptummala/moods/mood_t1_seg.h5" 58 | preload_h5 = False 59 | 60 | save_path = r'/scratch/ptummala/ssvae' 61 | 62 | #Training params 63 | trainID="SSCVAE-1" #ceVAE2D_seg_ixi_mood 64 | num_epochs = 250 65 | lr = 1e-4 66 | patch_size=(256,256,1) #Set it to None if not desired 67 | patchQ_len = 512 68 | patches_per_volume = 256 69 | log_freq = 10 70 | preload_h5 = False 71 | 72 | #Network Params 73 | IsVAE=True 74 | input_shape=(256,256,256) 75 | input_dim = (256,256) 76 | input_size = (1,256,256) 77 | z_dim=1024 78 | model_feature_map_sizes=(16, 64, 256, 1024) 79 | n_channels=1 80 | ce_factor=0.5 81 | beta=0.01 82 | vae_loss_ema = 1 83 | theta = 1 84 | use_geco=False 85 | 86 | os.environ["CUDA_VISIBLE_DEVICES"] = gpuID 87 | random.seed(seed) 88 | os.environ['PYTHONHASHSEED'] = str(seed) 89 | np.random.seed(seed) 90 | torch.manual_seed(seed) 91 | torch.cuda.manual_seed(seed) 92 | torch.backends.cudnn.deterministic = True 93 | torch.backends.cudnn.benchmark = False 94 | 95 | if __name__ == "__main__" : 96 | wandb.init(project='anomaly', entity='s9chroma') 97 | wandb.watch(model) 98 | config = wandb.config 99 | config.learning_rate = learning_rate 100 | wandb.run.name = 'SSVAE1' 101 | 102 | device = torch.device("cuda:0" if torch.cuda.is_available() and useCuda else "cpu") 103 | #device = 'cpu' 104 | 105 | ixi_t1_trainset = IXITrainSet(indices=ixi_t1_indices, data_path=ixi_t1_train, lazypatch=True if patch_size else False, preload=preload_h5) 106 | ixi_t2_trainset = IXITrainSet(indices=ixi_t2_indices, data_path=ixi_t2_train, lazypatch=True if patch_size else False, preload=preload_h5) 107 | ixi_pd_trainset = IXITrainSet(indices=ixi_proton_indices, data_path=ixi_proton_train, lazypatch=True if patch_size else False, preload=preload_h5) 108 | mood_trainset = IXITrainSet(indices=mood_t1_indices, data_path=mood_t1_train, lazypatch=True if patch_size else False, preload=preload_h5) 109 | 110 | tot_trainset = ixi_t1_trainset + ixi_t2_trainset + ixi_pd_trainset + mood_trainset 111 | 112 | ixi_t1_eval = IXITrainSet(indices=ixi_t1_eval_indices, data_path=ixi_t1_train, lazypatch=True if patch_size else False, preload=preload_h5) 113 | ixi_t2_eval = IXITrainSet(indices=ixi_t2_eval_indices, data_path=ixi_t2_train, lazypatch=True if patch_size else False, preload=preload_h5) 114 | ixi_pd_eval = IXITrainSet(indices=ixi_proton_eval_indices, data_path=ixi_proton_train, lazypatch=True if patch_size else False, preload=preload_h5) 115 | mood_t1_eval = IXITrainSet(indices=mood_t1_eval_indices, data_path=mood_t1_train, lazypatch=True if patch_size else False, preload=preload_h5) 116 | 117 | ixi_evalset = ixi_t1_eval + ixi_t2_eval + ixi_pd_eval 118 | 119 | if patch_size: 120 | input_shape = tuple(x for x in patch_size if x!=1) 121 | trainset = torchio.data.Queue( 122 | subjects_dataset = tot_trainset, 123 | max_length = patchQ_len, 124 | samples_per_volume = patches_per_volume, 125 | sampler = torchio.data.UniformSampler(patch_size=patch_size), 126 | # num_workers = num_workers 127 | ) 128 | 129 | ixi_trainset = torchio.data.Queue( 130 | subjects_dataset = ixi_evalset, 131 | max_length = patchQ_len, 132 | samples_per_volume = patches_per_volume, 133 | sampler = torchio.data.UniformSampler(patch_size=patch_size), 134 | # num_workers = num_workers 135 | ) 136 | 137 | mood_trainset = 
torchio.data.Queue( 138 | subjects_dataset = mood_t1_eval, 139 | max_length = patchQ_len, 140 | samples_per_volume = patches_per_volume, 141 | sampler = torchio.data.UniformSampler(patch_size=patch_size), 142 | # num_workers = num_workers 143 | ) 144 | 145 | train_loader = DataLoader(dataset=trainset,batch_size=batch_size,shuffle=True, num_workers=num_workers) 146 | ixi_eval_loader = DataLoader(dataset=ixi_trainset,batch_size=batch_size,shuffle=True, num_workers=num_workers) 147 | mood_eval_loader = DataLoader(dataset=mood_trainset,batch_size=batch_size,shuffle=True, num_workers=num_workers) 148 | 149 | 150 | if len(input_dim) == 2: 151 | conv = nn.Conv2d 152 | convt = nn.ConvTranspose2d 153 | d = 2 154 | else: 155 | conv = nn.Conv3d 156 | convt = nn.ConvTranspose3d 157 | d = 3 158 | 159 | model = VAE(input_size=input_size, z_dim=z_dim, fmap_sizes=model_feature_map_sizes, 160 | conv_op=conv, 161 | tconv_op=convt, 162 | activation_op=torch.nn.PReLU) 163 | model.d = d 164 | model.to(device) 165 | wandb.watch(model) 166 | optimizer = Adam(model.parameters(), lr=lr) 167 | scaler = GradScaler() 168 | 169 | if checkpoint2load: 170 | chk = torch.load(checkpoint,map_location=device) 171 | #model = chk['state_dict'] 172 | #optimizer = chk['optimizer'] 173 | #scaler = chk['AMPScaler'] 174 | start_epoch = 45 175 | 176 | model_dict = model.state_dict() 177 | pretrained_dict = chk['state_dict'] 178 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 179 | model_dict.update(pretrained_dict) 180 | model.load_state_dict(pretrained_dict) 181 | model.to(device) 182 | scaler.load_state_dict(chk['AMPScaler']) 183 | optimizer.load_state_dict(chk['optimizer']) 184 | del chk 185 | else: 186 | start_epoch = 0 187 | best_loss = float('inf') 188 | 189 | for epoch in range(start_epoch, num_epochs): 190 | model.train() 191 | runningLoss = 0.0 192 | runningLossCounter = 0.0 193 | train_loss = 0.0 194 | kl_loss_tot = 0.0 195 | loss_vae_tot = 0.0 196 | loss_ce_tot = 0.0 197 | print('Epoch '+ str(epoch)+ ': Train') 198 | with tqdm(total=len(train_loader)) as pbar: 199 | for i, data in enumerate(train_loader): 200 | 201 | img = data['img']['data'].squeeze(-1) 202 | 203 | tmp = img.view(img.shape[0], 1, -1) 204 | min_vals = tmp.min(2, keepdim=True).values 205 | max_vals = tmp.max(2, keepdim=True).values 206 | tmp = (tmp - min_vals) / max_vals 207 | x = tmp.view(img.size()) 208 | 209 | shape = x.shape 210 | tensor_reshaped = x.reshape(shape[0],-1) 211 | tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)] 212 | images = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:]) 213 | 214 | I1 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 215 | for i, im in enumerate(images): 216 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 217 | scale_percent = 50 # percent of original size 218 | width = int(blur.shape[1] * scale_percent / 100) 219 | height = int(blur.shape[0] * scale_percent / 100) 220 | dim = (width, height) 221 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 222 | I1[i][0] = torch.tensor(downsampled) 223 | U1 = torch.zeros_like(images) 224 | for i, im in enumerate(U1): 225 | dim = (shape[2], shape[3]) 226 | U1[i][0] = torch.tensor(cv2.resize(I1[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 227 | 228 | shape = I1.shape 229 | I2 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 230 | for i, im in enumerate(I1): 231 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 232 | scale_percent = 50 # percent of 
original size 233 | width = int(blur.shape[1] * scale_percent / 100) 234 | height = int(blur.shape[0] * scale_percent / 100) 235 | dim = (width, height) 236 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 237 | I2[i][0] = torch.tensor(downsampled) 238 | U2 = torch.zeros_like(I1) 239 | for i, im in enumerate(U2): 240 | dim = (shape[2], shape[3]) 241 | U2[i][0] = torch.tensor(cv2.resize(I2[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 242 | 243 | shape = I2.shape 244 | I3 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 245 | for i, im in enumerate(I2): 246 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 247 | scale_percent = 50 # percent of original size 248 | width = int(blur.shape[1] * scale_percent / 100) 249 | height = int(blur.shape[0] * scale_percent / 100) 250 | dim = (width, height) 251 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 252 | I3[i][0] = torch.tensor(downsampled) 253 | U3 = torch.zeros_like(I2) 254 | for i, im in enumerate(U3): 255 | dim = (shape[2], shape[3]) 256 | U3[i][0] = torch.tensor(cv2.resize(I3[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 257 | 258 | L1 = images - U1 259 | L2 = I1 - U2 260 | L3 = I2 - U3 261 | 262 | images = Variable(L1).to(device) 263 | optimizer.zero_grad() 264 | 265 | ### VAE Part 266 | with autocast(): 267 | loss_vae = 0 268 | if ce_factor < 1: 269 | x_r, z_dist = model(images) 270 | 271 | kl_loss = 0 272 | if model.d == 3: 273 | kl_loss = kl_loss_fn(z_dist, sumdim=(1,)) * beta #check TODO 274 | rec_loss_vae = rec_loss_fn(x_r, images, sumdim=(1,2,3,4)) 275 | else: 276 | kl_loss = kl_loss_fn(z_dist, sumdim=(1,2,3)) * beta 277 | rec_loss_vae = rec_loss_fn(x_r, images, sumdim=(1,2,3)) 278 | loss_vae = kl_loss + rec_loss_vae * theta 279 | 280 | 281 | 282 | ### CE Part 283 | loss_ce = 0 284 | if ce_factor > 0: 285 | 286 | ce_tensor = get_square_mask( 287 | tensor.shape, 288 | square_size=(0, np.max(input_size[1:]) // 2), 289 | noise_val=(torch.min(tensor).item(), torch.max(tensor).item()), 290 | n_squares=(0, 3), 291 | ) 292 | 293 | ce_tensor = torch.from_numpy(ce_tensor).float().to(device) 294 | inpt_noisy = torch.where(ce_tensor != 0, ce_tensor, tensor) 295 | 296 | inpt_noisy = inpt_noisy.to(device) 297 | 298 | with autocast(): 299 | x_rec_ce, _ = model(inpt_noisy) 300 | if model.d == 3: 301 | rec_loss_ce = rec_loss_fn(x_rec_ce, images, sumdim=(1,2,3,4)) 302 | else: 303 | rec_loss_ce = rec_loss_fn(x_rec_ce, images, sumdim=(1,2,3)) 304 | loss_ce = rec_loss_ce 305 | loss = (1.0 - ce_factor) * loss_vae + ce_factor * loss_ce 306 | 307 | if use_geco and ce_factor < 1: 308 | g_goal = 0.1 309 | g_lr = 1e-4 310 | vae_loss_ema = (1.0 - 0.9) * rec_loss_vae + 0.9 * vae_loss_ema 311 | theta = geco_beta_update(theta, vae_loss_ema, g_goal, g_lr, speedup=2) 312 | 313 | 314 | scaler.scale(loss).backward() 315 | scaler.step(optimizer) 316 | scaler.update() 317 | loss = round(loss.item(),4) 318 | train_loss += loss 319 | runningLoss += loss 320 | kl_loss_tot += round(kl_loss.item(),4) 321 | loss_vae_tot += round(loss_vae.item(),4) 322 | loss_ce_tot += round(loss_ce.item(),4) 323 | runningLossCounter += 1 324 | print("Epoch: ", epoch, ", i: ", i, ", Training loss: ", loss) 325 | logging.info('[%d/%d][%d/%d] Train Loss: %.4f' % ((epoch+1), num_epochs, i, len(train_loader), loss)) 326 | #For tensorboard 327 | if i % log_freq == 0: 328 | niter = epoch*len(train_loader)+i 329 | tb_writer.add_scalar('Train/Loss', runningLoss/runningLossCounter, niter) 330 | runningLoss = 0.0 331 | 
runningLossCounter = 0.0 332 | pbar.update(1) 333 | 334 | wandb.log({"train_loss": train_loss}) 335 | wandb.log({"kl_loss": kl_loss_tot}) 336 | wandb.log({"vae_loss": loss_vae_tot}) 337 | wandb.log({"ce_loss": loss_ce_tot}) 338 | 339 | checkpoint = { 340 | 'state_dict': model.state_dict(), 341 | 'optimizer': optimizer.state_dict(), 342 | 'AMPScaler': scaler.state_dict() 343 | } 344 | if epoch%4==0: 345 | torch.save(checkpoint, os.path.join(save_path, trainID + '-epoch-' + str(epoch) + ".pth.tar")) 346 | #torch.save(checkpoint, os.path.join(save_path, trainID+".pth.tar")) 347 | tb_writer.add_scalar('Train/AvgLossEpoch', train_loss/len(train_loader), epoch) 348 | 349 | 350 | 351 | model.eval() 352 | mood_val_loss = 0 353 | with torch.no_grad(): 354 | print('Epoch '+ str(epoch)+ ': Ixi Val') 355 | with tqdm(total=len(ixi_eval_loader)) as pbar: 356 | for i, data in enumerate(ixi_eval_loader): 357 | img = data['img']['data'].squeeze(-1) 358 | tmp = img.view(img.shape[0], 1, -1) 359 | min_vals = tmp.min(2, keepdim=True).values 360 | max_vals = tmp.max(2, keepdim=True).values 361 | tmp = (tmp - min_vals) / max_vals 362 | x = tmp.view(img.size()) 363 | 364 | shape = x.shape 365 | tensor_reshaped = x.reshape(shape[0],-1) 366 | tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)] 367 | images = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:]) 368 | 369 | I1 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 370 | for i, im in enumerate(images): 371 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 372 | scale_percent = 50 # percent of original size 373 | width = int(blur.shape[1] * scale_percent / 100) 374 | height = int(blur.shape[0] * scale_percent / 100) 375 | dim = (width, height) 376 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 377 | I1[i][0] = torch.tensor(downsampled) 378 | U1 = torch.zeros_like(images) 379 | for i, im in enumerate(U1): 380 | dim = (shape[2], shape[3]) 381 | U1[i][0] = torch.tensor(cv2.resize(I1[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 382 | 383 | shape = I1.shape 384 | I2 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 385 | for i, im in enumerate(I1): 386 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 387 | scale_percent = 50 # percent of original size 388 | width = int(blur.shape[1] * scale_percent / 100) 389 | height = int(blur.shape[0] * scale_percent / 100) 390 | dim = (width, height) 391 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 392 | I2[i][0] = torch.tensor(downsampled) 393 | U2 = torch.zeros_like(I1) 394 | for i, im in enumerate(U2): 395 | dim = (shape[2], shape[3]) 396 | U2[i][0] = torch.tensor(cv2.resize(I2[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 397 | 398 | shape = I2.shape 399 | I3 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 400 | for i, im in enumerate(I2): 401 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 402 | scale_percent = 50 # percent of original size 403 | width = int(blur.shape[1] * scale_percent / 100) 404 | height = int(blur.shape[0] * scale_percent / 100) 405 | dim = (width, height) 406 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 407 | I3[i][0] = torch.tensor(downsampled) 408 | U3 = torch.zeros_like(I2) 409 | for i, im in enumerate(U3): 410 | dim = (shape[2], shape[3]) 411 | U3[i][0] = torch.tensor(cv2.resize(I3[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 412 | 413 | L1 = images - U1 414 | L2 = I1 - U2 415 | L3 = I2 - U3 416 | 417 | images = 
Variable(L1).to(device) 418 | 419 | x_r, z_dist = model(images) 420 | kl_loss = 0 421 | kl_loss = kl_loss_fn(z_dist, sumdim=(1,2,3)) * beta 422 | rec_loss_vae = rec_loss_fn(x_r, images, sumdim=(1,2,3)) 423 | loss_vae = kl_loss + rec_loss_vae * theta 424 | mood_val_loss += loss_vae.item() 425 | pbar.update(1) 426 | wandb.log({"Val_loss_mood": mood_val_loss}) 427 | 428 | 429 | #model.eval() 430 | val_loss = 0 431 | with torch.no_grad(): 432 | print('Epoch '+ str(epoch)+ ': Mood Val') 433 | with tqdm(total=len(mood_eval_loader)) as pbar: 434 | for i, data in enumerate(mood_eval_loader): 435 | img = data['img']['data'].squeeze(-1) 436 | tmp = img.view(img.shape[0], 1, -1) 437 | min_vals = tmp.min(2, keepdim=True).values 438 | max_vals = tmp.max(2, keepdim=True).values 439 | tmp = (tmp - min_vals) / max_vals 440 | x = tmp.view(img.size()) 441 | 442 | shape = x.shape 443 | tensor_reshaped = x.reshape(shape[0],-1) 444 | tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)] 445 | images = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:]) 446 | 447 | I1 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 448 | for i, im in enumerate(images): 449 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 450 | scale_percent = 50 # percent of original size 451 | width = int(blur.shape[1] * scale_percent / 100) 452 | height = int(blur.shape[0] * scale_percent / 100) 453 | dim = (width, height) 454 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 455 | I1[i][0] = torch.tensor(downsampled) 456 | U1 = torch.zeros_like(images) 457 | for i, im in enumerate(U1): 458 | dim = (shape[2], shape[3]) 459 | U1[i][0] = torch.tensor(cv2.resize(I1[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 460 | 461 | shape = I1.shape 462 | I2 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 463 | for i, im in enumerate(I1): 464 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 465 | scale_percent = 50 # percent of original size 466 | width = int(blur.shape[1] * scale_percent / 100) 467 | height = int(blur.shape[0] * scale_percent / 100) 468 | dim = (width, height) 469 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 470 | I2[i][0] = torch.tensor(downsampled) 471 | U2 = torch.zeros_like(I1) 472 | for i, im in enumerate(U2): 473 | dim = (shape[2], shape[3]) 474 | U2[i][0] = torch.tensor(cv2.resize(I2[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 475 | 476 | shape = I2.shape 477 | I3 = torch.zeros(shape[0], shape[1], shape[2]//2, shape[3]//2) 478 | for i, im in enumerate(I2): 479 | blur = cv2.GaussianBlur(im[0].numpy(), (5, 5), 5) 480 | scale_percent = 50 # percent of original size 481 | width = int(blur.shape[1] * scale_percent / 100) 482 | height = int(blur.shape[0] * scale_percent / 100) 483 | dim = (width, height) 484 | downsampled = cv2.resize(blur, dim, interpolation = cv2.INTER_AREA) 485 | I3[i][0] = torch.tensor(downsampled) 486 | U3 = torch.zeros_like(I2) 487 | for i, im in enumerate(U3): 488 | dim = (shape[2], shape[3]) 489 | U3[i][0] = torch.tensor(cv2.resize(I3[i][0].numpy(), dim, interpolation = cv2.INTER_LINEAR)) 490 | 491 | L1 = images - U1 492 | L2 = I1 - U2 493 | L3 = I2 - U3 494 | 495 | images = Variable(L1).to(device) 496 | 497 | x_r, z_dist = model(images) 498 | kl_loss = 0 499 | kl_loss = kl_loss_fn(z_dist, sumdim=(1,2,3)) * beta 500 | rec_loss_vae = rec_loss_fn(x_r, images, sumdim=(1,2,3)) 501 | loss_vae = kl_loss + rec_loss_vae * theta 502 | val_loss += loss_vae.item() 503 | pbar.update(1) 504 | 
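        # Note on this evaluation block: as in the training loop, the model is validated on the
        # first Laplacian-pyramid band (L1 = images - U1); L2 and L3 are computed above but left
        # unused here, and val_loss is the running sum of per-batch VAE losses over the eval queue.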
wandb.log({"Val_loss_oasis": val_loss}) -------------------------------------------------------------------------------- /DEMO.md: -------------------------------------------------------------------------------- 1 | # StRegA Demo Workflow 2 | 3 | This document provides comprehensive instructions for training and testing StRegA (Segmentation Regularised Anomaly) models for unsupervised anomaly detection in Brain MRIs using Compact Context-encoding Variational Autoencoder. 4 | 5 | Based on the paper: ["StRegA: Unsupervised Anomaly Detection in Brain MRIs using a Compact Context-encoding Variational Autoencoder"](https://doi.org/10.1016/j.compbiomed.2022.106093) 6 | 7 | ## Table of Contents 8 | 9 | - [Prerequisites](#prerequisites) 10 | - [Installation](#installation) 11 | - [Data Preparation](#data-preparation) 12 | - [Training](#training) 13 | - [Testing/Inference](#testinginference) 14 | - [Paper Experiments](#paper-experiments) 15 | - [Additional Features](#additional-features-in-repo) 16 | - [Missing Features](#features-from-paper-not-in-repo) 17 | 18 | --- 19 | 20 | ## Prerequisites 21 | 22 | ### Hardware Requirements 23 | - NVIDIA GPU with CUDA support (recommended: 8GB+ VRAM) 24 | - Minimum 16GB RAM 25 | - 50GB+ storage for datasets 26 | 27 | ### Software Requirements 28 | - Python 3.8+ 29 | - PyTorch 1.9+ 30 | - CUDA 11.0+ (for GPU support) 31 | 32 | --- 33 | 34 | ## Installation 35 | 36 | ### 1. Clone the Repository 37 | 38 | ```bash 39 | git clone https://github.com/soumickmj/StRegA.git 40 | cd StRegA 41 | ``` 42 | 43 | ### 2. Install Dependencies 44 | 45 | ```bash 46 | pip install torch torchvision torchaudio 47 | pip install torchio nibabel h5py numpy scipy scikit-image matplotlib 48 | pip install transformers # For HuggingFace model 49 | pip install wandb # Optional: for experiment tracking 50 | pip install tqdm pandas seaborn 51 | ``` 52 | 53 | ### 3. Verify Installation 54 | 55 | ```python 56 | import torch 57 | import torchio 58 | from ccevae import VAE 59 | print(f"PyTorch version: {torch.__version__}") 60 | print(f"CUDA available: {torch.cuda.is_available()}") 61 | ``` 62 | 63 | --- 64 | 65 | ## Data Preparation 66 | 67 | ### Required Data Format 68 | 69 | StRegA expects **FSL-segmented** brain MRI data. The preprocessing pipeline involves: 70 | 71 | 1. **Skull stripping** using FSL BET 72 | 2. **Brain tissue segmentation** using FSL FAST 73 | 3. **Resampling** to 256×256 slices 74 | 75 | ### Supported Datasets 76 | 77 | The paper uses the following datasets: 78 | 79 | | Dataset | Type | Description | 80 | |---------|------|-------------| 81 | | IXI | T1, T2, PD | Normal brain MRIs | 82 | | MOOD | T1 | Medical Out-of-Distribution Detection challenge | 83 | | BraTS | T1, T1ce, T2, FLAIR | Brain tumor segmentation (for testing) | 84 | 85 | ### Data Preprocessing Script 86 | 87 | Create segmented data using FSL: 88 | 89 | ```bash 90 | # Example FSL preprocessing 91 | # 1. Skull stripping 92 | bet input.nii.gz brain.nii.gz -f 0.5 -g 0 93 | 94 | # 2. 
Tissue segmentation 95 | fast -t 1 -n 3 -H 0.1 -I 4 -l 20.0 -o segmented brain.nii.gz 96 | ``` 97 | 98 | ### HDF5 Data Format 99 | 100 | Training data should be stored in HDF5 format with the following structure: 101 | 102 | ```python 103 | import h5py 104 | import nibabel as nib 105 | import numpy as np 106 | 107 | # Example: Creating training HDF5 file 108 | with h5py.File('training_data.h5', 'w') as f: 109 | for i, nifti_path in enumerate(nifti_files): 110 | data = nib.load(nifti_path).get_fdata() 111 | f.create_dataset(f'{i:05d}', data=data) 112 | ``` 113 | 114 | --- 115 | 116 | ## Training 117 | 118 | ### Model Architecture 119 | 120 | The cceVAE (Compact Context-encoding VAE) architecture: 121 | 122 | - **Input size**: 256×256 (2D slices) 123 | - **Latent dimension**: 1024 124 | - **Feature map sizes**: (16, 64, 256, 1024) 125 | - **Activation**: PReLU 126 | 127 | ### Training Parameters 128 | 129 | | Parameter | Default Value | Description | 130 | |-----------|---------------|-------------| 131 | | `batch_size` | 16 | Training batch size | 132 | | `num_epochs` | 100 | Number of training epochs | 133 | | `lr` | 1e-4 | Learning rate | 134 | | `z_dim` | 1024 | Latent space dimension | 135 | | `ce_factor` | 0.5 | Context encoding loss weight | 136 | | `beta` | 0.01 | KL divergence weight | 137 | | `patch_size` | (256, 256, 1) | 2D slice size | 138 | | `patches_per_volume` | 256 | Patches sampled per volume | 139 | 140 | ### Training Script 141 | 142 | Create a training script `run_training.py`: 143 | 144 | ```python 145 | import os 146 | import torch 147 | import torch.nn as nn 148 | import numpy as np 149 | from torch.utils.data import DataLoader 150 | from torch.optim import Adam 151 | from torch.cuda.amp import autocast, GradScaler 152 | import torchio as tio 153 | 154 | from ccevae import VAE 155 | from helpers import kl_loss_fn, rec_loss_fn, geco_beta_update, get_square_mask 156 | from dataloaders.ixi import IXITrainSet 157 | 158 | # Configuration 159 | config = { 160 | 'gpu_id': "0", 161 | 'seed': 1701, 162 | 'batch_size': 16, 163 | 'num_epochs': 100, 164 | 'lr': 1e-4, 165 | 'z_dim': 1024, 166 | 'model_feature_map_sizes': (16, 64, 256, 1024), 167 | 'ce_factor': 0.5, 168 | 'beta': 0.01, 169 | 'theta': 1.0, 170 | 'use_geco': False, 171 | 'patch_size': (256, 256, 1), 172 | 'patches_per_volume': 256, 173 | 'patchQ_len': 512, 174 | 'save_path': './checkpoints', 175 | 'train_id': 'ceVAE2D_brain' 176 | } 177 | 178 | # Set random seed 179 | os.environ["CUDA_VISIBLE_DEVICES"] = config['gpu_id'] 180 | torch.manual_seed(config['seed']) 181 | np.random.seed(config['seed']) 182 | 183 | # Setup device 184 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 185 | print(f"Using device: {device}") 186 | 187 | # Create model 188 | input_size = (1, 256, 256) 189 | model = VAE( 190 | input_size=input_size, 191 | z_dim=config['z_dim'], 192 | fmap_sizes=config['model_feature_map_sizes'], 193 | conv_op=nn.Conv2d, 194 | tconv_op=nn.ConvTranspose2d, 195 | activation_op=torch.nn.PReLU 196 | ) 197 | model.d = 2 # 2D model 198 | model.to(device) 199 | 200 | # Optimizer and scaler 201 | optimizer = Adam(model.parameters(), lr=config['lr']) 202 | scaler = GradScaler() 203 | 204 | # Load datasets (modify paths as needed) 205 | train_data_paths = { 206 | 'ixi_t1': 'path/to/ixi_t1_segmented.hdf5', 207 | 'ixi_t2': 'path/to/ixi_t2_segmented.hdf5', 208 | 'ixi_pd': 'path/to/ixi_pd_segmented.hdf5', 209 | 'mood_t1': 'path/to/mood_t1_seg.h5' 210 | } 211 | 212 | # Create dataset (example with IXI 
T1) 213 | train_indices = list(range(100, 350)) 214 | trainset = IXITrainSet( 215 | indices=train_indices, 216 | data_path=train_data_paths['ixi_t1'], 217 | lazypatch=True 218 | ) 219 | 220 | # Use TorchIO Queue for patch-based training 221 | patch_queue = tio.data.Queue( 222 | subjects_dataset=trainset, 223 | max_length=config['patchQ_len'], 224 | samples_per_volume=config['patches_per_volume'], 225 | sampler=tio.data.UniformSampler(patch_size=config['patch_size']), 226 | ) 227 | 228 | train_loader = DataLoader( 229 | dataset=patch_queue, 230 | batch_size=config['batch_size'], 231 | shuffle=True, 232 | num_workers=0 233 | ) 234 | 235 | # Training loop 236 | vae_loss_ema = 1.0 237 | theta = config['theta'] 238 | 239 | for epoch in range(config['num_epochs']): 240 | model.train() 241 | epoch_loss = 0.0 242 | 243 | for i, data in enumerate(train_loader): 244 | img = data['img']['data'].squeeze(-1) 245 | 246 | # Normalize 247 | tmp = img.view(img.shape[0], 1, -1) 248 | min_vals = tmp.min(2, keepdim=True).values 249 | max_vals = tmp.max(2, keepdim=True).values 250 | tmp = (tmp - min_vals) / max_vals 251 | x = tmp.view(img.size()) 252 | 253 | # Remove NaN samples 254 | shape = x.shape 255 | tensor_reshaped = x.reshape(shape[0], -1) 256 | tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(), dim=1)] 257 | tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0], *shape[1:]).to(device) 258 | 259 | if tensor.shape[0] == 0: 260 | continue 261 | 262 | optimizer.zero_grad() 263 | 264 | # VAE forward pass 265 | with autocast(): 266 | loss_vae = 0 267 | if config['ce_factor'] < 1: 268 | x_r, z_dist = model(tensor) 269 | kl_loss = kl_loss_fn(z_dist, sumdim=(1, 2, 3)) * config['beta'] 270 | rec_loss_vae = rec_loss_fn(x_r, tensor, sumdim=(1, 2, 3)) 271 | loss_vae = kl_loss + rec_loss_vae * theta 272 | 273 | # Context Encoding (CE) Part 274 | loss_ce = 0 275 | if config['ce_factor'] > 0: 276 | ce_tensor = get_square_mask( 277 | tensor.shape, 278 | square_size=(0, np.max(input_size[1:]) // 2), 279 | noise_val=(torch.min(tensor).item(), torch.max(tensor).item()), 280 | n_squares=(0, 3), 281 | ) 282 | ce_tensor = torch.from_numpy(ce_tensor).float().to(device) 283 | inpt_noisy = torch.where(ce_tensor != 0, ce_tensor, tensor) 284 | 285 | with autocast(): 286 | x_rec_ce, _ = model(inpt_noisy) 287 | rec_loss_ce = rec_loss_fn(x_rec_ce, tensor, sumdim=(1, 2, 3)) 288 | loss_ce = rec_loss_ce 289 | loss = (1.0 - config['ce_factor']) * loss_vae + config['ce_factor'] * loss_ce 290 | else: 291 | loss = loss_vae 292 | 293 | # GECO update (optional) 294 | if config['use_geco'] and config['ce_factor'] < 1: 295 | g_goal = 0.1 296 | g_lr = 1e-4 297 | vae_loss_ema = (1.0 - 0.9) * rec_loss_vae + 0.9 * vae_loss_ema 298 | theta = geco_beta_update(theta, vae_loss_ema, g_goal, g_lr, speedup=2) 299 | 300 | # Backward pass 301 | scaler.scale(loss).backward() 302 | scaler.step(optimizer) 303 | scaler.update() 304 | 305 | epoch_loss += loss.item() 306 | 307 | if i % 10 == 0: 308 | print(f"Epoch [{epoch}/{config['num_epochs']}] Step [{i}] Loss: {loss.item():.4f}") 309 | 310 | # Save checkpoint 311 | if epoch % 4 == 0: 312 | checkpoint = { 313 | 'state_dict': model.state_dict(), 314 | 'optimizer': optimizer.state_dict(), 315 | 'AMPScaler': scaler.state_dict() 316 | } 317 | os.makedirs(config['save_path'], exist_ok=True) 318 | torch.save( 319 | checkpoint, 320 | os.path.join(config['save_path'], f"{config['train_id']}-epoch-{epoch}.pth.tar") 321 | ) 322 | print(f"Checkpoint saved at epoch {epoch}") 323 | 324 | 
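    # Illustrative addition (a minimal sketch, not part of the original script): report the
    # mean training loss for the finished epoch using the epoch_loss accumulator from above.
    avg_epoch_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch [{epoch}/{config['num_epochs']}] mean training loss: {avg_epoch_loss:.4f}")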
print("Training complete!") 325 | ``` 326 | 327 | ### Running Training 328 | 329 | ```bash 330 | python run_training.py 331 | ``` 332 | 333 | --- 334 | 335 | ## Testing/Inference 336 | 337 | ### Quick Start with CLI 338 | 339 | The `inference.py` script supports three input formats: 340 | 341 | ```bash 342 | # 1. Single NIfTI file 343 | python inference.py -i /path/to/brain.nii.gz -o /path/to/output/ 344 | 345 | # 2. Directory of NIfTI files (processes all .nii and .nii.gz files) 346 | python inference.py -i /path/to/nifti_folder/ -o /path/to/output/ 347 | 348 | # 3. HDF5 file (same format as training dataset) 349 | python inference.py -i /path/to/data.h5 -o /path/to/output/ --input_format h5 350 | 351 | # HDF5 with region key (for MOOD format) 352 | python inference.py -i /path/to/mood.h5 -o /path/to/output/ --input_format h5 --h5_region brain 353 | 354 | # HDF5 with specific indices 355 | python inference.py -i /path/to/data.h5 -o /path/to/output/ --input_format h5 --h5_indices 0-100 356 | 357 | # Output to HDF5 file instead of NIfTI 358 | python inference.py -i /path/to/data.h5 -o /path/to/results.h5 --input_format h5 359 | 360 | # Using local checkpoint 361 | python inference.py -i /path/to/input/ -o /path/to/output/ -c /path/to/checkpoint.pth.tar 362 | ``` 363 | 364 | ### CLI Options 365 | 366 | | Option | Description | 367 | |--------|-------------| 368 | | `--input, -i` | Input path: NIfTI file, directory of NIfTI files, or HDF5 file | 369 | | `--output, -o` | Output directory or HDF5 file path | 370 | | `--input_format` | Input format: `auto` (default), `nifti`, or `h5` | 371 | | `--h5_region` | Region key for HDF5 (e.g., "brain" for MOOD format) | 372 | | `--h5_indices` | Indices to process (e.g., "0,1,2" or "0-100") | 373 | | `--checkpoint, -c` | Path to model checkpoint (uses HuggingFace if not provided) | 374 | | `--checkpoint_format` | Checkpoint format: `pth.tar` or `ptrh` | 375 | | `--device, -d` | Device for inference (default: cuda:0) | 376 | | `--area_threshold` | Minimum area for morphological opening (default: 256) | 377 | | `--anomaly_threshold` | Initial anomaly detection threshold (default: 0.2) | 378 | 379 | ### Option 1: Using Locally Trained Model 380 | 381 | ```python 382 | import torch 383 | import numpy as np 384 | import nibabel as nib 385 | from scipy import ndimage 386 | from skimage import morphology, filters 387 | from torchio import transforms 388 | from torch.cuda.amp import autocast 389 | 390 | from ccevae import VAE 391 | import torch.nn as nn 392 | 393 | # Load model from checkpoint 394 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 395 | 396 | # Initialize model architecture 397 | input_size = (1, 256, 256) 398 | z_dim = 1024 399 | model_feature_map_sizes = (16, 64, 256, 1024) 400 | 401 | model = VAE( 402 | input_size=input_size, 403 | z_dim=z_dim, 404 | fmap_sizes=model_feature_map_sizes, 405 | conv_op=nn.Conv2d, 406 | tconv_op=nn.ConvTranspose2d, 407 | activation_op=torch.nn.PReLU 408 | ) 409 | model.d = 2 410 | 411 | # Load checkpoint 412 | checkpoint = torch.load('checkpoints/ceVAE2D_brain-epoch-96.pth.tar', map_location=device) 413 | model.load_state_dict(checkpoint['state_dict']) 414 | model.to(device) 415 | model.eval() 416 | 417 | # Load and preprocess test volume 418 | def preprocess_volume(nifti_path): 419 | """Load and preprocess a NIfTI volume for testing.""" 420 | vol = nib.load(nifti_path).get_fdata() 421 | vol = np.moveaxis(vol, 2, 0) # Move slices to first dimension 422 | 423 | # Pad/crop to 256x256 424 | 
### Option 1: Using Locally Trained Model

```python
import torch
import torch.nn as nn
import numpy as np
import nibabel as nib
from scipy import ndimage
from skimage import morphology, filters
from torchio import transforms
from torch.cuda.amp import autocast

from ccevae import VAE

# Load model from checkpoint
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Initialize model architecture
input_size = (1, 256, 256)
z_dim = 1024
model_feature_map_sizes = (16, 64, 256, 1024)

model = VAE(
    input_size=input_size,
    z_dim=z_dim,
    fmap_sizes=model_feature_map_sizes,
    conv_op=nn.Conv2d,
    tconv_op=nn.ConvTranspose2d,
    activation_op=torch.nn.PReLU
)
model.d = 2

# Load checkpoint
checkpoint = torch.load('checkpoints/ceVAE2D_brain-epoch-96.pth.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()

# Load and preprocess test volume
def preprocess_volume(nifti_path):
    """Load and preprocess a NIfTI volume for testing."""
    vol = nib.load(nifti_path).get_fdata()
    vol = np.moveaxis(vol, 2, 0)  # Move slices to the first dimension

    # Pad/crop each slice to 256x256
    data_item = torch.tensor(vol).unsqueeze(dim=0)
    out = transforms.CropOrPad((vol.shape[0], 256, 256))(data_item)
    out = out.squeeze(dim=0).unsqueeze(dim=1)

    return out.float()

# Anomaly detection pipeline
def detect_anomalies(model, volume, device):
    """
    Run StRegA anomaly detection on a volume.

    Returns:
        anomaly_mask: Binary mask of detected anomalies
        diff_map: Continuous difference map
    """
    volume = volume.to(device)

    # Normalize to [0, 1]
    volume = (volume - torch.min(volume)) / (torch.max(volume) - torch.min(volume))
    volume = torch.nan_to_num(volume, nan=0.0)

    with torch.no_grad():
        with autocast():
            reconstruction, _ = model(volume)

    reconstruction = reconstruction.float()

    # Calculate difference (reconstruction error)
    diff_mask = (reconstruction.cpu().numpy() - volume.cpu().numpy())

    # Post-processing
    # 1. Keep only positive differences (where the reconstruction is brighter than the input)
    m_diff_mask = diff_mask.copy()
    m_diff_mask[m_diff_mask < 0] = 0

    # 2. Manual thresholding (optional initial threshold)
    m_diff_mask[m_diff_mask > 0.2] = 1

    # 3. Otsu thresholding
    val = filters.threshold_otsu(m_diff_mask)
    thr = m_diff_mask > val

    # 4. Morphological opening to remove small false positives
    final = np.zeros_like(thr)
    for i in range(thr.shape[0]):
        final[i, 0] = morphology.area_opening(thr[i, 0], area_threshold=256)

    # Remove detections outside the brain mask
    final[volume.cpu().numpy() == 0] = 0

    return final, diff_mask

# Example usage
volume = preprocess_volume('path/to/test_volume_segmented.nii.gz')
anomaly_mask, diff_map = detect_anomalies(model, volume, device)

print(f"Volume shape: {volume.shape}")
print(f"Anomaly mask shape: {anomaly_mask.shape}")
print(f"Number of anomalous voxels: {np.sum(anomaly_mask)}")
```
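The outputs can also be written back as NIfTI volumes for visual inspection. A minimal sketch, assuming the `anomaly_mask` and `diff_map` from the example above; the slice axis is moved back to last position, and the identity affine used here is only a placeholder (the input volume's affine could be reused instead):

```python
# Save the detected mask and the raw difference map as NIfTI volumes (sketch)
mask_vol = np.moveaxis(anomaly_mask.squeeze(1), 0, 2).astype(np.uint8)   # (slices, 1, H, W) -> (H, W, slices)
diff_vol = np.moveaxis(diff_map.squeeze(1), 0, 2).astype(np.float32)

nib.save(nib.Nifti1Image(mask_vol, affine=np.eye(4)), 'anomaly_mask.nii.gz')
nib.save(nib.Nifti1Image(diff_vol, affine=np.eye(4)), 'diff_map.nii.gz')
```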
### Option 2: Using HuggingFace Pre-trained Model

```python
import torch
import numpy as np
import nibabel as nib
from scipy import ndimage
from skimage import morphology, filters
from torchio import transforms
from torch.cuda.amp import autocast
from transformers import AutoModel

# Load model from HuggingFace
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

modelHF = AutoModel.from_pretrained(
    "soumickmj/StRegA_cceVAE2D_Brain_MOOD_IXIT1_IXIT2_IXIPD",
    trust_remote_code=True
)
model = modelHF.model.to(device)
model.eval()

# Optional: save the HuggingFace model as a local checkpoint
# torch.save(model, "checkpoints/brain_huggingface.ptrh")

# Load and preprocess test volume
def preprocess_volume(nifti_path):
    """Load and preprocess a NIfTI volume for testing."""
    vol = nib.load(nifti_path).get_fdata()
    vol = np.moveaxis(vol, 2, 0)

    data_item = torch.tensor(vol).unsqueeze(dim=0)
    out = transforms.CropOrPad((vol.shape[0], 256, 256))(data_item)
    out = out.squeeze(dim=0).unsqueeze(dim=1)

    return out.float()

# Anomaly detection function (same as above)
def detect_anomalies(model, volume, device):
    """Run StRegA anomaly detection."""
    volume = volume.to(device)
    volume = (volume - torch.min(volume)) / (torch.max(volume) - torch.min(volume))
    volume = torch.nan_to_num(volume, nan=0.0)

    with torch.no_grad():
        with autocast():
            reconstruction, _ = model(volume)

    reconstruction = reconstruction.float()
    diff_mask = (reconstruction.cpu().numpy() - volume.cpu().numpy())

    m_diff_mask = diff_mask.copy()
    m_diff_mask[m_diff_mask < 0] = 0
    m_diff_mask[m_diff_mask > 0.2] = 1

    val = filters.threshold_otsu(m_diff_mask)
    thr = m_diff_mask > val

    final = np.zeros_like(thr)
    for i in range(thr.shape[0]):
        final[i, 0] = morphology.area_opening(thr[i, 0], area_threshold=256)
    final[volume.cpu().numpy() == 0] = 0

    return final, diff_mask

# Example usage
volume = preprocess_volume('path/to/test_volume_segmented.nii.gz')
anomaly_mask, diff_map = detect_anomalies(model, volume, device)

print("Anomaly detection complete using the HuggingFace model!")
print(f"Detected {np.sum(anomaly_mask)} anomalous voxels")
```

### Evaluation Metrics

```python
def dice_coefficient(true_mask, pred_mask, non_seg_score=1.0):
    """
    Compute the Dice coefficient between two binary masks.

    Args:
        true_mask: Ground-truth binary mask
        pred_mask: Predicted binary mask
        non_seg_score: Score to return when both masks are empty

    Returns:
        Dice coefficient (0-1)
    """
    assert true_mask.shape == pred_mask.shape

    true_mask = np.asarray(true_mask).astype(bool)
    pred_mask = np.asarray(pred_mask).astype(bool)

    im_sum = true_mask.sum() + pred_mask.sum()
    if im_sum == 0:
        return non_seg_score

    intersection = np.logical_and(true_mask, pred_mask)
    return 2. * intersection.sum() / im_sum

# Example evaluation
# dice = dice_coefficient(ground_truth_mask, anomaly_mask)
# print(f"Dice Score: {dice:.4f}")
```
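A quick toy check of `dice_coefficient` on small arrays (values worked out by hand) illustrates both the empty-mask behaviour and an ordinary overlap:

```python
import numpy as np

true = np.zeros((4, 4), dtype=bool)
pred = np.zeros((4, 4), dtype=bool)
print(dice_coefficient(true, pred))  # 1.0 -> both masks empty, returns non_seg_score

true[1:3, 1:3] = True                # 4 ground-truth voxels
pred[1:3, 1:2] = True                # 2 predicted voxels, all inside the ground truth
print(dice_coefficient(true, pred))  # 2*2 / (4+2) = 0.667
```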
---

## Paper Experiments

### Experiment 1: IXI Dataset Training

Training on the IXI dataset with T1, T2, and PD-weighted images:

```python
# Dataset configuration
ixi_config = {
    't1_path': 'path/to/ixi_t1_segmented.hdf5',
    't2_path': 'path/to/ixi_t2_segmented.hdf5',
    'pd_path': 'path/to/ixi_pd_segmented.hdf5',
    'train_indices': list(range(100, 350)),  # 250 subjects
    'val_indices': list(range(30, 50)),      # 20 subjects
}
```

### Experiment 2: MOOD Challenge Training

Training on MOOD (Medical Out-of-Distribution) T1 brain data:

```python
mood_config = {
    'data_path': 'path/to/mood_t1_seg.h5',
    'train_indices': list(range(100, 350)),
}
```

### Experiment 3: Combined Training (Master Model)

The HuggingFace "master" checkpoint was trained on combined data:

- MOOD T1
- IXI T1
- IXI T2
- IXI PD

```python
# Combine datasets
from torch.utils.data import ConcatDataset

combined_trainset = ConcatDataset([
    ixi_t1_trainset,
    ixi_t2_trainset,
    ixi_pd_trainset,
    mood_trainset
])
```

### Experiment 4: BraTS Evaluation

Testing on the BraTS dataset for tumor detection:

```python
# BraTS test data structure
brats_paths = {
    'non_seg': 'path/to/brats/non_seg/',  # Original images
    'seg': 'path/to/brats/seg/',          # FSL-segmented images
    'mask': 'path/to/brats/mask/'         # Ground-truth tumor masks
}

# Run evaluation on BraTS
import os

file_list = sorted(os.listdir(brats_paths['seg']))
dice_scores = []

for file in file_list:
    # Preprocess the segmented image
    volume = preprocess_volume(os.path.join(brats_paths['seg'], file))

    # Detect anomalies
    anomaly_mask, _ = detect_anomalies(model, volume, device)

    # Preprocess the ground-truth mask the same way so its shape matches the prediction
    gt_mask = preprocess_volume(os.path.join(brats_paths['mask'], file)).numpy()
    gt_mask[gt_mask > 0] = 1

    # Calculate Dice
    dice = dice_coefficient(gt_mask, anomaly_mask)
    dice_scores.append(dice)
    print(f"{file}: Dice = {dice:.4f}")

print(f"\nMean Dice: {np.mean(dice_scores):.4f} ± {np.std(dice_scores):.4f}")
```

---

## Additional Features in Repo

The following features are implemented in the repository but not detailed in the paper:

### 1. GMVAE (Gaussian Mixture VAE)
Located in `misc/gmvae.py`, implements a Gaussian Mixture VAE for clustering-based anomaly detection.

```python
from misc.gmvae import GMVAE, GMVAE_Trainer
# See misc/gmvae.py for implementation details
```

### 2. Skip-connection Autoencoder
Located in `misc/skipae.py` and `ceVae/skipae.py`, implements an autoencoder with skip connections.

```python
from misc.skipae import Skip_AE
# Alternative architecture with skip connections
```

### 3. Scale-space VAE (SSVAE)
Located in `misc/ssae.py`, implements scale-space decomposition for anomaly detection.
### 4. GECO (Generalized ELBO with Constrained Optimization)
Automatic balancing of the reconstruction and KL losses:

```python
# Enable GECO in training
config['use_geco'] = True

# Parameters
g_goal = 0.1  # Target reconstruction error
g_lr = 1e-4   # GECO learning rate
```

### 5. 3D Model Support
The architecture supports both 2D and 3D models:

```python
# For a 3D model, swap in the 3D convolution ops and pass a 3D input_size, e.g. (1, D, H, W)
conv = nn.Conv3d
convt = nn.ConvTranspose3d
model.d = 3
```

### 6. Multiple Dataloader Variants
- `dataloaders/mood.py`: MOOD dataset loader
- `dataloaders/ixi.py`: IXI dataset loader
- `dataloaders/torchiowrap.py`: TorchIO wrapper for HDF5 data

---

## Features from Paper Not in Repo

The following aspects mentioned in the paper may require additional implementation:

### 1. Quantitative Evaluation Metrics
The paper reports multiple metrics that should be computed:

```python
def compute_metrics(pred_mask, true_mask):
    """Compute evaluation metrics from the paper."""
    from sklearn.metrics import precision_score, recall_score, f1_score

    pred_flat = pred_mask.flatten().astype(bool)
    true_flat = true_mask.flatten().astype(bool)

    dice = dice_coefficient(true_mask, pred_mask)
    precision = precision_score(true_flat, pred_flat)
    recall = recall_score(true_flat, pred_flat)
    f1 = f1_score(true_flat, pred_flat)

    return {
        'dice': dice,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
```

### 2. Ablation Study Configurations
The paper includes ablation studies with different CE factors:

```python
# Ablation: CE factor variations
ce_factors_to_test = [0.0, 0.25, 0.5, 0.75, 1.0]

for ce_factor in ce_factors_to_test:
    config['ce_factor'] = ce_factor
    # Run training and evaluation
```

### 3. Cross-validation Setup
For robust evaluation:

```python
from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(subject_ids)):
    # Train and evaluate on each fold
    pass
```

---

## Troubleshooting

### Common Issues

1. **CUDA Out of Memory**
   - Reduce `batch_size`
   - Use gradient checkpointing
   - Reduce `patches_per_volume`

2. **NaN Loss Values**
   - Check data normalization
   - Lower the learning rate
   - Add gradient clipping (see the sketch at the end of this section)

3. **Poor Anomaly Detection**
   - Ensure the FSL segmentation is correct
   - Adjust the Otsu threshold
   - Tune the morphological opening `area_threshold`

### Performance Tips

1. Use mixed precision training (enabled by default with GradScaler)
2. Increase `num_workers` for faster data loading
3. Use SSD storage for faster HDF5 access
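For the gradient-clipping suggestion under **NaN Loss Values**, note that with AMP the gradients must be unscaled before clipping. A minimal sketch of how the backward pass in `run_training.py` could be adapted (`max_norm=1.0` is only an example value):

```python
# Backward pass with gradient clipping under AMP (sketch)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients must be unscaled before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # example max_norm
scaler.step(optimizer)
scaler.update()
```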
---

## Citation

If you use this code, please cite:

```bibtex
@article{chatterjee2022strega,
  title={StRegA: Unsupervised Anomaly Detection in Brain MRIs using a Compact Context-encoding Variational Autoencoder},
  author={Chatterjee, Soumick and Sciarra, Alessandro and D{\"u}nnwald, Max and Tummala, Pavan and Agrawal, Shubham Kumar and Jauhari, Aishwarya and Kalra, Aman and Oeltze-Jafra, Steffen and Speck, Oliver and N{\"u}rnberger, Andreas},
  journal={Computers in Biology and Medicine},
  pages={106093},
  year={2022},
  publisher={Elsevier},
  doi={10.1016/j.compbiomed.2022.106093}
}
```

---

## Contact

For questions or issues:
- Email: soumick.chatterjee@ovgu.de or contact@soumick.com
- GitHub Issues: https://github.com/soumickmj/StRegA/issues
--------------------------------------------------------------------------------