├── github_image.png ├── rocaseg ├── repro │ ├── __init__.py │ └── _seed.py ├── __init__.py ├── models │ ├── __init__.py │ ├── discr_a.py │ ├── unet_lext.py │ └── unet_lext_aux.py ├── datasets │ ├── __init__.py │ ├── constants.py │ ├── meta_oai.py │ ├── dataset_maknee.py │ ├── prepare_dataset_maknee.py │ ├── sources.py │ ├── dataset_okoa.py │ ├── prepare_dataset_okoa.py │ ├── dataset_oai_imo.py │ └── prepare_dataset_oai_imo.py ├── preproc │ ├── __init__.py │ ├── custom.py │ └── transforms.py ├── components │ ├── __init__.py │ ├── mixup.py │ ├── losses.py │ ├── checkpoint.py │ ├── formats.py │ └── metrics.py ├── describe.py ├── resample.py ├── analyze_predictions_multi.py ├── evaluate.py ├── train_baseline.py └── analyze_predictions_single.py ├── MANIFEST.in ├── environment.yml ├── setup.py ├── README.md ├── notebooks └── Statistical_tests.ipynb └── scripts └── runner.sh /github_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIPT-Oulu/RobustCartilageSegmentation/HEAD/github_image.png -------------------------------------------------------------------------------- /rocaseg/repro/__init__.py: -------------------------------------------------------------------------------- 1 | from ._seed import set_ultimate_seed 2 | 3 | 4 | __all__ = [ 5 | 'set_ultimate_seed', 6 | ] 7 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | 3 | recursive-include tests * 4 | recursive-exclude * __pycache__ 5 | recursive-exclude * *.py[co] 6 | -------------------------------------------------------------------------------- /rocaseg/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for rocaseg.""" 4 | 5 | __author__ = """Egor Panfilov""" 6 | __email__ = 'egor.v.panfilov@gmail.com' 7 | __version__ = '0.1.0' 8 | -------------------------------------------------------------------------------- /rocaseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_lext import UNetLext 2 | from .unet_lext_aux import UNetLextAux 3 | 4 | from .discr_a import DiscriminatorA 5 | 6 | 7 | dict_models = { 8 | 'unet_lext': UNetLext, 9 | 'unet_lext_aux': UNetLextAux, 10 | 11 | 'discriminator_a': DiscriminatorA, 12 | } 13 | -------------------------------------------------------------------------------- /rocaseg/repro/_seed.py: -------------------------------------------------------------------------------- 1 | def set_ultimate_seed(base_seed=777): 2 | import os 3 | import random 4 | os.environ['PYTHONHASHSEED'] = str(base_seed) 5 | random.seed(base_seed) 6 | 7 | try: 8 | import numpy as np 9 | np.random.seed(base_seed) 10 | except ModuleNotFoundError: 11 | print('Module `numpy` has not been found') 12 | try: 13 | import torch 14 | torch.manual_seed(base_seed + 1) 15 | torch.cuda.manual_seed_all(base_seed + 2) 16 | torch.backends.cudnn.deterministic = True 17 | except ModuleNotFoundError: 18 | print('Module `torch` has not been found') 19 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: env_rocaseg 2 | channels: 3 | - pytorch 4 | - simpleitk 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - python=3.7.3 9 | 
- pip=19.1.1 10 | - joblib 11 | - cython 12 | - numpy 13 | - scipy 14 | - pandas=0.23.4 15 | - scikit-image=0.15.0 16 | - pywavelets 17 | - scikit-learn 18 | - matplotlib 19 | - seaborn 20 | - opencv 21 | - jupyter 22 | - xlrd 23 | - pytorch=1.2.0 24 | - torchvision=0.4.0 25 | # - cudatoolkit=10.0 26 | - simpleitk 27 | - pip: 28 | - pydicom==1.2.2 29 | - dicom2nifti==2.1.5 30 | - tifffile 31 | - sas7bdat 32 | - nibabel 33 | - tqdm 34 | - click 35 | - tensorboard>=1.14.0 36 | - tensorflow>=1.14.0 37 | -------------------------------------------------------------------------------- /rocaseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_oai_imo import (DatasetOAIiMoSagittal2d, 2 | index_from_path_oai_imo) 3 | from .dataset_okoa import (DatasetOKOASagittal2d, 4 | index_from_path_okoa) 5 | from .dataset_maknee import (DatasetMAKNEESagittal2d, 6 | index_from_path_maknee) 7 | from . import meta_oai 8 | from . import constants 9 | from .sources import sources_from_path 10 | 11 | 12 | __all__ = [ 13 | 'index_from_path_oai_imo', 14 | 'index_from_path_okoa', 15 | 'index_from_path_maknee', 16 | 'DatasetOAIiMoSagittal2d', 17 | 'DatasetOKOASagittal2d', 18 | 'DatasetMAKNEESagittal2d', 19 | 'meta_oai', 20 | 'constants', 21 | 'sources_from_path', 22 | ] 23 | -------------------------------------------------------------------------------- /rocaseg/preproc/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom import Normalize, UnNormalize, PercentileClippingAndToFloat 2 | from .transforms import (DualCompose, OneOf, OneOrOther, ImageOnly, NoTransform, 3 | ToTensor, VerticalFlip, HorizontalFlip, Flip, Scale, 4 | Crop, CenterCrop, Pad, GammaCorrection, BilateralFilter) 5 | 6 | 7 | __all__ = [ 8 | 'Normalize', 9 | 'UnNormalize', 10 | 'PercentileClippingAndToFloat', 11 | 'DualCompose', 12 | 'OneOf', 13 | 'OneOrOther', 14 | 'ImageOnly', 15 | 'NoTransform', 16 | 'ToTensor', 17 | 'VerticalFlip', 18 | 'HorizontalFlip', 19 | 'Flip', 20 | 'Scale', 21 | 'Crop', 22 | 'CenterCrop', 23 | 'Pad', 24 | 'GammaCorrection', 25 | 'BilateralFilter', 26 | ] 27 | -------------------------------------------------------------------------------- /rocaseg/components/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import optim 3 | from rocaseg.components.losses import CrossEntropyLoss 4 | from rocaseg.components.metrics import (confusion_matrix, dice_score, 5 | dice_score_from_cm) 6 | from rocaseg.components.checkpoint import CheckpointHandler 7 | 8 | 9 | dict_losses = { 10 | 'bce_loss': nn.BCEWithLogitsLoss, 11 | 'multi_ce_loss': CrossEntropyLoss, 12 | } 13 | 14 | 15 | dict_metrics = { 16 | 'confusion_matrix': confusion_matrix, 17 | 'dice_score': dice_score, 18 | 'bce_loss': nn.BCELoss(), 19 | } 20 | 21 | 22 | dict_optimizers = { 23 | 'sgd': optim.SGD, 24 | 'adam': optim.Adam, 25 | } 26 | 27 | 28 | __all__ = [ 29 | 'dict_losses', 30 | 'dict_metrics', 31 | 'dict_optimizers', 32 | 'confusion_matrix', 33 | 'dice_score', 34 | 'dice_score_from_cm', 35 | 'CheckpointHandler', 36 | ] 37 | -------------------------------------------------------------------------------- /rocaseg/datasets/constants.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | """ 5 | OAI iMorphics reference classes (DICOM attribute names). 
Cartilage tissues 6 | """ 7 | locations_mh53 = OrderedDict([ 8 | ('Background', 0), 9 | ('FemoralCartilage', 1), 10 | ('LateralTibialCartilage', 2), 11 | ('MedialTibialCartilage', 3), 12 | ('PatellarCartilage', 4), 13 | ('LateralMeniscus', 5), 14 | ('MedialMeniscus', 6), 15 | ]) 16 | 17 | 18 | """ 19 | Segmentation predictions. Major cartilage tissues. Joined L and M 20 | """ 21 | locations_f43h = OrderedDict([ 22 | ('_background', 0), 23 | ('femoral', 1), 24 | ('tibial', 2), 25 | ]) 26 | 27 | 28 | """ 29 | Segmentation predictions. Cartilage tissues. Joined L and M 30 | """ 31 | locations_zp3n = OrderedDict([ 32 | ('_background', 0), 33 | ('femoral', 1), 34 | ('tibial', 2), 35 | ('patellar', 3), 36 | ('menisci', 4), 37 | ]) 38 | 39 | 40 | atlas_to_locations = { 41 | 'imo': locations_mh53, 42 | 'segm': locations_zp3n, 43 | 'okoa': locations_f43h, 44 | } 45 | -------------------------------------------------------------------------------- /rocaseg/datasets/meta_oai.py: -------------------------------------------------------------------------------- 1 | release_to_prefix_var = { 2 | '0.C.2': 'V00', '0.E.1': 'V00', 3 | '1.C.2': 'V01', '1.E.1': 'V01', 4 | '2.D.2': 'V02', 5 | '3.C.2': 'V03', '3.E.1': 'V03', 6 | '4.G.1': 'V04', 7 | '5.C.1': 'V05', '5.E.1': 'V05', 8 | '6.C.1': 'V06', '6.E.1': 'V06', 9 | '8.C.1': 'V08', '8.E.1': 'V08', 10 | '10.C.1': 'V10', '10.E.1': 'V10', 11 | } 12 | 13 | prefix_var_to_visit_month = { 14 | 'V00': '000m', 15 | 'V01': '012m', 16 | 'V02': '018m', 17 | 'V03': '024m', 18 | 'V04': '030m', 19 | 'V05': '036m', 20 | 'V06': '048m', 21 | 'V07': '060m', 22 | 'V08': '072m', 23 | 'V09': '084m', 24 | 'V10': '096m', 25 | 'V11': '108m', 26 | } 27 | 28 | release_to_visit_month = { 29 | '0.C.2': '000m', '0.E.1': '000m', 30 | '1.C.2': '012m', '1.E.1': '012m', 31 | '2.D.2': '018m', 32 | '3.C.2': '024m', '3.E.1': '024m', 33 | '4.G.1': '030m', 34 | '5.C.1': '036m', '5.E.1': '036m', 35 | '6.C.1': '048m', '6.E.1': '048m', 36 | '8.C.1': '072m', '8.E.1': '072m', 37 | '10.C.1': '096m', '10.E.1': '096m', 38 | 39 | } 40 | 41 | side_code_to_str = {1: 'RIGHT', 2: 'LEFT'} 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """The setup script.""" 5 | 6 | from setuptools import setup, find_packages 7 | 8 | with open('README.md') as readme_file: 9 | readme = readme_file.read() 10 | 11 | requirements = [] 12 | 13 | setup_requirements = [] 14 | 15 | test_requirements = [] 16 | 17 | setup( 18 | author="Egor Panfilov", 19 | author_email='egor.v.panfilov@gmail.com', 20 | classifiers=[ 21 | 'Development Status :: 4 - Beta', 22 | 'Intended Audience :: Science/Research', 23 | 'License :: OSI Approved :: MIT License', 24 | 'Natural Language :: English', 25 | 'Programming Language :: Python :: 3', 26 | 'Programming Language :: Python :: 3.6', 27 | 'Programming Language :: Python :: 3.7', 28 | ], 29 | description="Framework for knee cartilage and menisci segmentation from MRI", 30 | install_requires=requirements, 31 | license="MIT license", 32 | long_description=readme, 33 | include_package_data=True, 34 | name='rocaseg', 35 | packages=find_packages(include=['rocaseg']), 36 | setup_requires=setup_requirements, 37 | tests_require=test_requirements, 38 | url='https://github.com/MIPT-Oulu/RobustCartilageSegmentation', 39 | version='0.1.0', 40 | zip_safe=False, 41 | ) 42 | 
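A minimal usage sketch (an editorial illustration, not a file from this repository): once the Conda environment from `environment.yml` is created and the package above is installed (e.g., with `pip install -e .`), the seeding helper and the model registry defined earlier in this dump can be combined as follows. Only `set_ultimate_seed`, `dict_models`, and the `'unet_lext'` key are taken from the repository code; the bare constructor call is an assumption, since the actual training scripts likely pass config-driven keyword arguments.

```python
# Hedged sketch: seed every RNG, then build a segmentation model by name.
from rocaseg.repro import set_ultimate_seed
from rocaseg.models import dict_models

set_ultimate_seed(base_seed=777)      # seeds Python, NumPy, and PyTorch RNGs

model_cls = dict_models['unet_lext']  # registry lookup: name -> model class
model = model_cls()                   # assumption: real code passes kwargs here
```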
-------------------------------------------------------------------------------- /rocaseg/components/mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | """ 5 | Example usage: 6 | 7 | # Regular segmentation loss: 8 | ys_pred_oai = self.models['segm'](xs_oai) 9 | loss_segm = self.losses['segm'](input_=ys_pred_oai, 10 | target=ys_true_arg_oai) 11 | 12 | # Mixup 13 | xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data( 14 | x=xs_oai, y=ys_true_arg_oai, 15 | alpha=self.config['mixup_alpha'], device=maybe_gpu) 16 | ys_pred_oai = self.models['segm'](xs_mixup) 17 | loss_segm = mixup_criterion(criterion=self.losses['segm'], 18 | pred=ys_pred_oai, 19 | y_a=ys_mixup_a, 20 | y_b=ys_mixup_b, 21 | lam=lambda_mixup) 22 | """ 23 | 24 | 25 | def mixup_data(x, y, alpha=1.0, device='cpu'): 26 | """Returns mixed inputs, pairs of targets, and lambda""" 27 | if alpha > 0: 28 | lam = np.random.beta(alpha, alpha) 29 | else: 30 | lam = 1 31 | 32 | batch_size = x.size()[0] 33 | index = torch.randperm(batch_size).to(device) 34 | 35 | mixed_x = lam * x + (1 - lam) * x[index, :] 36 | y_a, y_b = y, y[index] 37 | return mixed_x, y_a, y_b, lam 38 | 39 | 40 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 41 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 42 | -------------------------------------------------------------------------------- /rocaseg/components/losses.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch import nn 4 | 5 | 6 | logging.basicConfig() 7 | logger = logging.getLogger('losses') 8 | logger.setLevel(logging.DEBUG) 9 | 10 | 11 | class CrossEntropyLoss(nn.Module): 12 | def __init__(self, num_classes, batch_avg=True, batch_weight=None, 13 | class_avg=True, class_weight=None, **kwargs): 14 | """ 15 | 16 | Parameters 17 | ---------- 18 | batch_avg: 19 | Whether to average over the batch dimension. 20 | batch_weight: 21 | Batch samples importance coefficients. 22 | class_avg: 23 | Whether to average over the class dimension. 24 | class_weight: 25 | Classes importance coefficients. 
26 | """ 27 | super().__init__() 28 | self.num_classes = num_classes 29 | self.batch_avg = batch_avg 30 | self.class_avg = class_avg 31 | self.batch_weight = batch_weight 32 | self.class_weight = class_weight 33 | logger.warning('Redundant loss function arguments:\n{}' 34 | .format(repr(kwargs))) 35 | self.ce = nn.CrossEntropyLoss(weight=class_weight) 36 | 37 | def forward(self, input_, target, **kwargs): 38 | """ 39 | 40 | Parameters 41 | ---------- 42 | input_: (b, ch, d0, d1) tensor 43 | target: (b, d0, d1) tensor 44 | 45 | Returns 46 | ------- 47 | out: float tensor 48 | """ 49 | return self.ce(input_, target) 50 | -------------------------------------------------------------------------------- /rocaseg/preproc/custom.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | 5 | logging.basicConfig() 6 | logger = logging.getLogger('preprocessing_custom') 7 | logger.setLevel(logging.DEBUG) 8 | 9 | 10 | class Normalize: 11 | def __init__(self, mean, std): 12 | self.mean = mean 13 | self.std = std 14 | 15 | def __call__(self, img, mask=None): 16 | img = img.astype(np.float32) 17 | img = (img - self.mean) / self.std 18 | 19 | if mask is not None: 20 | mask = mask.astype(np.float32) 21 | return img, mask 22 | 23 | 24 | class UnNormalize: 25 | def __init__(self, mean, std): 26 | self.mean = mean 27 | self.std = std 28 | 29 | def __call__(self, *args): 30 | return [(a * self.std + self.mean) for a in args] 31 | 32 | 33 | class PercentileClippingAndToFloat: 34 | """Change the histogram of image by doing global contrast normalization.""" 35 | def __init__(self, cut_min=0.5, cut_max=99.5): 36 | """ 37 | cut_min - lowest percentile which is used to cut the image histogram 38 | cut_max - highest percentile 39 | """ 40 | self.cut_min = cut_min 41 | self.cut_max = cut_max 42 | 43 | def __call__(self, img, mask=None): 44 | img = img.astype(np.float32) 45 | lim_low, lim_high = np.percentile(img, [self.cut_min, self.cut_max]) 46 | img = np.clip(img, lim_low, lim_high) 47 | 48 | img -= lim_low 49 | img /= img.max() 50 | 51 | img = img.astype(np.float32) 52 | if mask is not None: 53 | mask = mask.astype(np.float32) 54 | 55 | return img, mask 56 | -------------------------------------------------------------------------------- /rocaseg/components/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import logging 4 | 5 | import torch 6 | 7 | 8 | logging.basicConfig() 9 | logger = logging.getLogger('handler') 10 | logger.setLevel(logging.DEBUG) 11 | 12 | 13 | class CheckpointHandler(object): 14 | def __init__(self, path_root, 15 | fname_pattern=('{model_name}__' 16 | 'fold_{fold_idx}__' 17 | 'epoch_{epoch_idx:>03d}.pth'), 18 | num_saved=2): 19 | self.path_root = path_root 20 | self.fname_pattern = fname_pattern 21 | self.num_saved = num_saved 22 | 23 | ext = self.fname_pattern.split('.')[-1] 24 | 25 | if not os.path.exists(path_root): 26 | raise ValueError(f'Path {path_root} does not exist') 27 | 28 | full_pattern = os.path.join(self.path_root, '*.' 
+ ext) 29 | self._all_ckpts = sorted(glob(full_pattern, recursive=False)) 30 | logger.info(f'Checkpoints found: {len(self._all_ckpts)}') 31 | 32 | self._remove_excessive_ckpts() 33 | 34 | def _remove_excessive_ckpts(self): 35 | while len(self._all_ckpts) > self.num_saved: 36 | try: 37 | os.remove(self._all_ckpts[0]) 38 | logger.info(f'Removed ckpt: {self._all_ckpts[0]}') 39 | self._all_ckpts = self._all_ckpts[1:] 40 | except OSError: 41 | logger.error(f'Cannot remove {self._all_ckpts[0]}') 42 | 43 | def get_last_ckpt(self): 44 | if len(self._all_ckpts) == 0: 45 | logger.warning(f'No checkpoints are available in {self.path_root}') 46 | return None 47 | else: 48 | fname_ckpt_sel = self._all_ckpts[-1] 49 | return fname_ckpt_sel 50 | 51 | def save_new_ckpt(self, model, model_name, fold_idx, epoch_idx): 52 | fname = self.fname_pattern.format(model_name=model_name, 53 | fold_idx=fold_idx, 54 | epoch_idx=epoch_idx) 55 | path_full = os.path.join(self.path_root, fname) 56 | try: 57 | torch.save(model.module.state_dict(), path_full) 58 | except AttributeError: 59 | torch.save(model.state_dict(), path_full) 60 | 61 | self._all_ckpts.append(path_full) 62 | self._remove_excessive_ckpts() 63 | -------------------------------------------------------------------------------- /rocaseg/models/discr_a.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | logging.basicConfig() 10 | logger = logging.getLogger('models') 11 | logger.setLevel(logging.DEBUG) 12 | 13 | 14 | class DiscriminatorA(nn.Module): 15 | def __init__(self, basic_width=64, input_channels=5, output_channels=1, 16 | restore_weights=False, path_weights=None, **kwargs): 17 | super().__init__() 18 | logger.warning('Redundant model init arguments:\n{}' 19 | .format(repr(kwargs))) 20 | 21 | # Preparing the modules dict 22 | modules = OrderedDict() 23 | 24 | modules['conv1'] = \ 25 | nn.Sequential(*[ 26 | nn.Conv2d(input_channels, basic_width, 27 | kernel_size=4, stride=2, padding=1), 28 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 29 | ]) 30 | modules['conv2'] = \ 31 | nn.Sequential(*[ 32 | nn.Conv2d(basic_width, basic_width*2, 33 | kernel_size=4, stride=2, padding=1), 34 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 35 | ]) 36 | modules['conv3'] = \ 37 | nn.Sequential(*[ 38 | nn.Conv2d(basic_width*2, basic_width*4, 39 | kernel_size=4, stride=2, padding=1), 40 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 41 | ]) 42 | modules['conv4'] = \ 43 | nn.Sequential(*[ 44 | nn.Conv2d(basic_width*4, basic_width*8, 45 | kernel_size=4, stride=2, padding=1), 46 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 47 | ]) 48 | modules['output'] = nn.Conv2d(basic_width*8, output_channels, 49 | kernel_size=4, stride=2, padding=1) 50 | 51 | self.__dict__['_modules'] = modules 52 | if restore_weights: 53 | self.load_state_dict(torch.load(path_weights)) 54 | 55 | def forward(self, x): 56 | tmp = x 57 | 58 | for name in self.__dict__['_modules']: 59 | layer = self.__dict__['_modules'][name] 60 | tmp = layer(tmp) 61 | 62 | out = F.interpolate(tmp, size=x.size()[-2:], 63 | mode='bilinear', align_corners=True) 64 | return out 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rocaseg - Robust Cartilage Segmentation from MRI 2 | 3 | Source code for 
Panfilov et al. "Improving Robustness of Deep Learning Based Knee MRI Segmentation: Mixup and Adversarial Domain Adaptation", https://arxiv.org/abs/1908.04126v3. 4 | 5 |

6 | <p align="center"><img src="github_image.png" alt="Overview"/></p> 7 |

8 | 9 | ### Important! 10 | 11 | The camera-ready version contained a bug in Dice score computation for tibial cartilage on Dataset C. Please, refer to the arXiv version for the corrected values - https://arxiv.org/abs/1908.04126v3. 12 | 13 | ### Description 14 | 15 | 1. To reproduce the experiments from the article one needs to have access to 16 | OAI iMorphics, OKOA, and MAKNEE datasets. 17 | 18 | 2. Download code from this repository. 19 | 20 | 3. Create a fresh Conda environment using `environment.yml`. Install the downloaded 21 | code as a Python module. 22 | 23 | 4. `datasets/prepare_dataset_...` files show how the raw data is converted into the 24 | format supported by the training and the inference pipelines. 25 | 26 | 5. The structure of the project has to be as follows: 27 | ``` 28 | ./project/ 29 | | ./data_raw/ # raw scans and annotations 30 | | ./OAI_iMorphics_scans/ 31 | | ./OAI_iMorphics_annotations/ 32 | | ./OKOA/ 33 | | ./MAKNEE/ 34 | | ./data/ # preprocessed scans and annotations 35 | | ./src/ (this repository) 36 | | ./results/ # models' weights, intermediate and final results 37 | | ./0_baseline/ 38 | | ./weights/ 39 | | ... 40 | | ./1_mixup/ 41 | | ./2_mixup_nowd/ 42 | | ./3_uda1/ 43 | | ./4_uda2/ 44 | | ./5_uda1_mixup_nowd/ 45 | ``` 46 | 47 | 6. File `scripts/runner.sh` contains the complete description of the workflow. 48 | 49 | 7. Statistical testing is implemented in `notebooks/Statistical_tests.ipynb`. 50 | 51 | 8. Pretrained models are available at https://drive.google.com/open?id=1f-gZ2wCf55OVjgA8oXd7xttGVW5DUUcU . 52 | 53 | ### Legal aspects 54 | 55 | This code is freely available only for research purposes. 56 | 57 | The software has not been certified as a medical device and, therefore, must not be used 58 | for diagnostic purposes. 59 | 60 | Commercial use of the provided code and the pre-trained models is strictly prohibited, 61 | since they were developed using the medical datasets under restrictive licenses. 62 | 63 | ### Cite this work 64 | 65 | ``` 66 | @InProceedings{Panfilov_2019_ICCV_Workshops, 67 | author = {Panfilov, Egor and Tiulpin, Aleksei and Klein, Stefan and Nieminen, Miika T. and Saarakkala, Simo}, 68 | title = {Improving Robustness of Deep Learning Based Knee MRI Segmentation: Mixup and Adversarial Domain Adaptation}, 69 | booktitle = {The IEEE International Conference on Computer Vision (ICCV) Workshops}, 70 | month = {Oct}, 71 | year = {2019} 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /rocaseg/components/formats.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import logging 3 | 4 | import numpy as np 5 | import cv2 6 | import nibabel as nib 7 | 8 | 9 | logging.basicConfig() 10 | logger = logging.getLogger('formats') 11 | logger.setLevel(logging.DEBUG) 12 | 13 | 14 | def png_to_numpy(pattern_fname_in, reverse=False): 15 | """ 16 | 17 | Args: 18 | pattern_fname_in: str 19 | String or regexp compatible with `glob`. 20 | reverse: bool 21 | Whether to use reverse slice order. 
22 | 23 | Returns: 24 | stack: [R, C, P] ndarray 25 | """ 26 | fnames_in = sorted(glob(pattern_fname_in)) 27 | 28 | stack = [cv2.imread(fn, cv2.IMREAD_GRAYSCALE) for fn in fnames_in] 29 | stack = np.stack(stack, axis=2) 30 | if reverse: 31 | stack = stack[..., ::-1] 32 | return stack 33 | 34 | 35 | def png_to_nifti(pattern_fname_in, fname_out, spacings=None, reverse=False, 36 | rcp_to_ras=False): 37 | """ 38 | 39 | Args: 40 | pattern_fname_in: str 41 | String or regexp compatible with `glob`. 42 | fname_out: str 43 | Full path to the output file. 44 | spacings: 3-tuple of float 45 | (pixel spacing in r, pixel spacing in c, slice thickness). 46 | reverse: bool 47 | Whether to use reverse slice order. 48 | rcp_to_ras: bool 49 | Whether to convert from row-column-plane to RAS+ coordinates. 50 | 51 | """ 52 | fnames_in = sorted(glob(pattern_fname_in)) 53 | 54 | stack = [cv2.imread(fn, cv2.IMREAD_GRAYSCALE) for fn in fnames_in] 55 | stack = np.stack(stack, axis=2) 56 | if reverse: 57 | stack = stack[..., ::-1] 58 | 59 | numpy_to_nifti(stack=stack, fname_out=fname_out, spacings=spacings, 60 | rcp_to_ras=rcp_to_ras) 61 | 62 | 63 | def nifti_to_png(fname_in, pattern_fname_out, reverse=False, ras_to_rcp=False): 64 | """ 65 | 66 | Args: 67 | fname_in: str 68 | Full path to the input file. 69 | pattern_fname_out: str 70 | Must include `{i}`, which is to be substituted with the running index. 71 | reverse: bool 72 | Whether to use reverse slice order. 73 | ras_to_rcp: bool 74 | Whether to convert from RAS+ to row-column-plane coordinates. 75 | """ 76 | stack, spacings = nifti_to_numpy(fname_in=fname_in, ras_to_rcp=ras_to_rcp) 77 | 78 | if reverse: 79 | stack = stack[..., ::-1] 80 | 81 | for i in range(stack.shape[-1]): 82 | fn = pattern_fname_out.format(i=i) 83 | cv2.imwrite(fn, stack[..., i]) 84 | 85 | 86 | def nifti_to_numpy(fname_in, ras_to_rcp=False): 87 | """ 88 | 89 | Args: 90 | fname_in: str 91 | Full path to the input file. 92 | ras_to_rcp: bool 93 | Whether to convert from RAS+ to row-column-plane coordinates. 94 | 95 | Returns: 96 | stack: [R, C, P] ndarray 97 | spacings: 3-tuple of float 98 | (pixel spacing in r, pixel spacing in c, slice thickness). 99 | 100 | """ 101 | scan = nib.load(fname_in) 102 | stack = scan.get_fdata() 103 | spacings = [scan.affine[i, i] for i in range(3)] 104 | 105 | if ras_to_rcp: 106 | stack = np.moveaxis(stack, [2, 1, 0], [0, 1, 2]) 107 | spacings = [-s for s in spacings[::-1]] 108 | 109 | return stack, spacings 110 | 111 | 112 | def numpy_to_nifti(stack, fname_out, spacings=None, rcp_to_ras=False): 113 | """ 114 | 115 | Args: 116 | stack: (r, c, p) ndarray 117 | Data array. 118 | fname_out: 119 | Full path to the output file. 120 | spacings: 3-tuple of float 121 | (pixel spacing in r, pixel spacing in c, slice thickness). 122 | rcp_to_ras: bool 123 | Whether to convert from row-column-plane to RAS+ coordinates. 
124 | """ 125 | if not rcp_to_ras: 126 | affine = np.eye(4, dtype=np.float) 127 | if spacings is not None: 128 | affine[0, 0] = spacings[0] 129 | affine[1, 1] = spacings[1] 130 | affine[2, 2] = spacings[2] 131 | else: 132 | stack = np.moveaxis(stack, [0, 1, 2], [2, 1, 0]) 133 | affine = np.diag([-1., -1., -1., 1.]).astype(np.float) 134 | if spacings is not None: 135 | affine[0, 0] = -spacings[2] 136 | affine[1, 1] = -spacings[1] 137 | affine[2, 2] = -spacings[0] 138 | 139 | scan = nib.Nifti1Image(stack, affine=affine) 140 | nib.save(scan, fname_out) 141 | -------------------------------------------------------------------------------- /rocaseg/describe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from collections import defaultdict 4 | 5 | import click 6 | 7 | import torch 8 | from torch.utils.data.dataloader import DataLoader 9 | 10 | from rocaseg.datasets import sources_from_path 11 | from rocaseg.preproc import * 12 | from rocaseg.repro import set_ultimate_seed 13 | 14 | 15 | logging.basicConfig() 16 | logger = logging.getLogger('train') 17 | logger.setLevel(logging.DEBUG) 18 | 19 | set_ultimate_seed() 20 | 21 | if torch.cuda.is_available(): 22 | maybe_gpu = 'cuda' 23 | else: 24 | maybe_gpu = 'cpu' 25 | 26 | 27 | class Describer: 28 | def __init__(self, config): 29 | self.config = config 30 | 31 | def run(self, loader): 32 | metrics_avg = defaultdict(float) 33 | 34 | for i, data_batch in enumerate(loader): 35 | metrics_curr = dict() 36 | xs, ys_true = data_batch 37 | # xs, ys_true = xs.to(maybe_gpu), ys_true.to(maybe_gpu) 38 | 39 | # Calculate metrics 40 | with torch.no_grad(): 41 | e = self.config['metrics_skip_edge'] 42 | if e != 0: 43 | metrics_curr['mean'] = xs[:, :, e:-e, e:-e].mean() 44 | metrics_curr['std'] = xs[:, :, e:-e, e:-e].std() 45 | metrics_curr['var'] = xs[:, :, e:-e, e:-e].var() 46 | else: 47 | metrics_curr['mean'] = xs.mean() 48 | metrics_curr['std'] = xs.std() 49 | metrics_curr['var'] = xs.var() 50 | for k, v in metrics_curr.items(): 51 | metrics_avg[k] += v 52 | 53 | # Add metrics logging 54 | logger.info('Metrics:') 55 | metrics_avg = {k: v / len(loader) 56 | for k, v in metrics_avg.items()} 57 | for k, v in metrics_avg.items(): 58 | logger.info(f'{k}: {v}') 59 | 60 | 61 | @click.command() 62 | @click.option('--path_data_root', default='../../data') 63 | @click.option('--path_experiment_root', default='../../results/temporary') 64 | @click.option('--dataset', type=click.Choice( 65 | ['oai_imo', 'okoa', 'maknee'])) 66 | @click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str) 67 | @click.option('--sample_mode', default='x_y', type=str) 68 | @click.option('--batch_size', default=64, type=int) 69 | @click.option('--num_workers', default=1, type=int) 70 | @click.option('--seed_trainval_test', default=0, type=int) 71 | @click.option('--metrics_skip_edge', default=0, type=int) 72 | def main(**config): 73 | config['path_logs'] = os.path.join( 74 | config['path_experiment_root'], f"logs_{config['dataset']}_describe") 75 | 76 | os.makedirs(config['path_logs'], exist_ok=True) 77 | 78 | logging_fh = logging.FileHandler( 79 | os.path.join(config['path_logs'], 'main.log')) 80 | logging_fh.setLevel(logging.DEBUG) 81 | logger.addHandler(logging_fh) 82 | 83 | # Collect the available and specified sources 84 | sources = sources_from_path(path_data_root=config['path_data_root'], 85 | selection=config['dataset'], 86 | with_folds=False, 87 | seed_trainval_test=config['seed_trainval_test']) 
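# (annotation, not in the original file) `sources` maps each dataset name to
# a dict of metadata dataframes; the `trainval_df`/`test_df` entries read
# below are built in `sources_from_path` (rocaseg/datasets/sources.py further
# down in this dump) -- note that for `oai_imo` and `okoa` that split is only
# populated when `with_folds=True`.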
88 | 89 | if config['dataset'] == 'oai_imo': 90 | from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d 91 | elif config['dataset'] == 'okoa': 92 | from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d 93 | elif config['dataset'] == 'maknee': 94 | from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d 95 | else: 96 | raise ValueError('Unknown dataset') 97 | 98 | for subset in ('trainval', 'test'): 99 | name = subset 100 | df = sources[config['dataset']][f"{subset}_df"] 101 | 102 | dataset = DatasetSagittal2d( 103 | df_meta=df, mask_mode=config['mask_mode'], name=name, 104 | sample_mode=config['sample_mode'], 105 | transforms=[ 106 | PercentileClippingAndToFloat(cut_min=10, cut_max=99), 107 | ToTensor() 108 | ]) 109 | 110 | loader = DataLoader(dataset, 111 | batch_size=config['batch_size'], 112 | shuffle=False, 113 | num_workers=config['num_workers'], 114 | pin_memory=True, 115 | drop_last=False) 116 | describer = Describer(config=config) 117 | 118 | describer.run(loader) 119 | loader.dataset.describe() 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /notebooks/Statistical_tests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "from glob import glob\n", 11 | "import pickle\n", 12 | "\n", 13 | "import numpy as np\n", 14 | "import scipy\n", 15 | "import scipy.stats" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "# TODO: uncomment only two experiments\n", 25 | "EXPERIMENTS = (\n", 26 | "# '0_baseline',\n", 27 | "# '1_mixup',\n", 28 | "# '2_mixup_nowd',\n", 29 | "# '3_uda1',\n", 30 | "# '4_uda2',\n", 31 | "# '5_uda1_mixup_nowd'\n", 32 | ")\n", 33 | "\n", 34 | "# TODO: uncomment the section related to the selected evaluation dataset\n", 35 | "# if True:\n", 36 | "# DATASET = 'oai_imo'\n", 37 | "# CLASSES = list(range(1, 5))\n", 38 | "# ATLAS = 'segm'\n", 39 | "# if True:\n", 40 | "# DATASET = 'okoa'\n", 41 | "# CLASSES = (1, 2)\n", 42 | "# ATLAS = 'okoa'\n", 43 | "\n", 44 | "# TODO: specify path to your `results` directory\n", 45 | "path_results_root = ''\n", 46 | "\n", 47 | "path_base = os.path.join(path_results_root,\n", 48 | " (f'{EXPERIMENTS[0]}/logs_{DATASET}_test'\n", 49 | " f'/cache_{DATASET}_test_{ATLAS}_volumew_paired.pkl'))\n", 50 | "path_eval = os.path.join(path_results_root,\n", 51 | " (f'{EXPERIMENTS[1]}/logs_{DATASET}_test'\n", 52 | " f'/cache_{DATASET}_test_{ATLAS}_volumew_paired.pkl'))\n", 53 | "\n", 54 | "with open(path_base, 'rb') as f:\n", 55 | " dict_base = pickle.load(f)\n", 56 | "\n", 57 | "with open(path_eval, 'rb') as f:\n", 58 | " dict_eval = pickle.load(f)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "kl_ranges = {\n", 68 | " '0': (0, ),\n", 69 | " '1': (1, ),\n", 70 | " '2': (2, ),\n", 71 | " '3': (3, ),\n", 72 | " '4': (4, ),\n", 73 | " 'all': (0, 1, 2, 3, 4),\n", 74 | "}" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "dsc_base = np.asarray(dict_base['dice_score'])\n", 84 | "dsc_eval = np.asarray(dict_eval['dice_score'])\n", 85 | "\n", 86 | 
"for kl_name, kl_values in kl_ranges.items():\n", 87 | " sel_base = np.isin(np.asarray(dict_base['KL']), kl_values)\n", 88 | " sel_eval = np.isin(np.asarray(dict_eval['KL']), kl_values)\n", 89 | " assert np.all(np.equal(sel_base, sel_eval))\n", 90 | " sel = sel_base\n", 91 | " if not np.sum(sel):\n", 92 | " continue\n", 93 | " \n", 94 | " print(f'------ KL: {kl_name} ------')\n", 95 | "\n", 96 | " print('--- Wilcoxon signed-rank, two-sided ---')\n", 97 | " res_2side = [scipy.stats.wilcoxon(dsc_base[sel, c],\n", 98 | " dsc_eval[sel, c])\n", 99 | " for c in CLASSES]\n", 100 | " print(*res_2side, sep='\\n')\n", 101 | "\n", 102 | " print('--- Wilcoxon signed-rank, one-sided, less ---')\n", 103 | " res_1side = [scipy.stats.wilcoxon(dsc_base[sel, c],\n", 104 | " dsc_eval[sel, c],\n", 105 | " alternative='less')\n", 106 | " for c in CLASSES]\n", 107 | " print(*res_1side, sep='\\n')\n", 108 | "\n", 109 | " print('--- Wilcoxon signed-rank, one-sided, greater ---')\n", 110 | " res_1side = [scipy.stats.wilcoxon(dsc_base[sel, c],\n", 111 | " dsc_eval[sel, c],\n", 112 | " alternative='greater')\n", 113 | " for c in CLASSES]\n", 114 | " print(*res_1side, sep='\\n')" 115 | ] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.7.3" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 2 139 | } 140 | -------------------------------------------------------------------------------- /rocaseg/components/metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | logging.basicConfig() 8 | logger = logging.getLogger('metrics') 9 | logger.setLevel(logging.DEBUG) 10 | 11 | 12 | def confusion_matrix(input_, target, num_classes): 13 | """ 14 | https://github.com/ternaus/robot-surgery-segmentation/blob/master/validation.py 15 | 16 | Args: 17 | input_: (d0, ..., dn) ndarray or tensor 18 | target: (d0, ..., dn) ndarray or tensor 19 | num_classes: int 20 | Total number of classes. 21 | 22 | Returns: 23 | out: (num_classes, num_classes) ndarray 24 | Confusion matrix. 25 | """ 26 | if torch.is_tensor(input_): 27 | input_ = input_.detach().to('cpu').numpy() 28 | if torch.is_tensor(target): 29 | target = target.detach().to('cpu').numpy() 30 | 31 | replace_indices = np.vstack(( 32 | target.flatten(), 33 | input_.flatten()) 34 | ).T 35 | cm, _ = np.histogramdd( 36 | replace_indices, 37 | bins=(num_classes, num_classes), 38 | range=[(0, num_classes-1), (0, num_classes-1)] 39 | ) 40 | return cm.astype(np.uint32) 41 | 42 | 43 | def dice_score_from_cm(cm): 44 | """ 45 | https://github.com/ternaus/robot-surgery-segmentation/blob/master/validation.py 46 | 47 | Args: 48 | cm: (d, d) ndarray 49 | Confusion matrix. 50 | 51 | Returns: 52 | out: (d, ) list 53 | List of class Dice scores. 
54 | """ 55 | scores = [] 56 | for index in range(cm.shape[0]): 57 | true_positives = cm[index, index] 58 | false_positives = cm[:, index].sum() - true_positives 59 | false_negatives = cm[index, :].sum() - true_positives 60 | denom = 2 * true_positives + false_positives + false_negatives 61 | if denom == 0: 62 | score = 0 63 | else: 64 | score = 2 * float(true_positives) / denom 65 | scores.append(score) 66 | return scores 67 | 68 | 69 | # ---------------------------------------------------------------------------- 70 | 71 | 72 | def _template_score(func_score_from_cm, input_, target, num_classes, 73 | batch_avg, batch_weight, class_avg, class_weight): 74 | """ 75 | 76 | Args: 77 | input_: (b, d0, ..., dn) ndarray or tensor 78 | target: (b, d0, ..., dn) ndarray or tensor 79 | num_classes: int 80 | Total number of classes. 81 | batch_avg: bool 82 | Whether to average over the batch dimension. 83 | batch_weight: (b,) iterable 84 | Batch samples importance coefficients. 85 | class_avg: bool 86 | Whether to average over the class dimension. 87 | class_weight: (c,) iterable 88 | Classes importance coefficients. Ignored when `class_avg` is False. 89 | 90 | Returns: 91 | out: scalar if `class_avg` is True, (num_classes,) list otherwise 92 | """ 93 | if torch.is_tensor(input_): 94 | num_samples = tuple(input_.size())[0] 95 | else: 96 | num_samples = input_.shape[0] 97 | 98 | scores = np.zeros((num_samples, num_classes)) 99 | for sample_idx in range(num_samples): 100 | cm = confusion_matrix(input_=input_[sample_idx], 101 | target=target[sample_idx], 102 | num_classes=num_classes) 103 | scores[sample_idx, :] = func_score_from_cm(cm) 104 | 105 | if batch_avg: 106 | scores = np.mean(scores, axis=0, keepdims=True) 107 | if class_avg: 108 | if class_weight is not None: 109 | scores = scores * np.reshape(class_weight, (1, -1)) 110 | scores = np.mean(scores, axis=1, keepdims=True) 111 | return np.squeeze(scores) 112 | 113 | 114 | def dice_score(input_, target, num_classes, 115 | batch_avg=True, batch_weight=None, 116 | class_avg=False, class_weight=None): 117 | """ 118 | 119 | Args: 120 | input_: (b, d0, ..., dn) ndarray or tensor 121 | target: (b, d0, ..., dn) ndarray or tensor 122 | num_classes: int 123 | Total number of classes. 124 | batch_avg: bool 125 | Whether to average over the batch dimension. 126 | batch_weight: (b,) iterable 127 | Batch samples importance coefficients. 128 | class_avg: bool 129 | Whether to average over the class dimension. 130 | class_weight: (c,) iterable 131 | Classes importance coefficients. Ignored when `class_avg` is False. 
132 | 133 | Returns: 134 | out: scalar if `class_avg` is True, (num_classes,) list otherwise 135 | """ 136 | return _template_score( 137 | dice_score_from_cm, input_, target, num_classes, 138 | batch_avg, batch_weight, class_avg, class_weight) 139 | -------------------------------------------------------------------------------- /rocaseg/datasets/dataset_maknee.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import logging 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | import cv2 10 | from torch.utils.data.dataset import Dataset 11 | 12 | 13 | logging.basicConfig() 14 | logger = logging.getLogger('dataset') 15 | logger.setLevel(logging.DEBUG) 16 | 17 | 18 | def index_from_path_maknee(path_root, force=False): 19 | fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv') 20 | fname_meta_base = os.path.join(path_root, 'meta_base.csv') 21 | 22 | if not os.path.exists(fname_meta_dyn) or force: 23 | fnames_image = glob.glob( 24 | os.path.join(path_root, '**', 'images', '*.png'), recursive=True) 25 | logger.info('{} images found'.format(len(fnames_image))) 26 | fnames_mask = glob.glob( 27 | os.path.join(path_root, '**', 'masks', '*.png'), recursive=True) 28 | logger.info('{} masks found'.format(len(fnames_mask))) 29 | 30 | df_meta = pd.read_csv(fname_meta_base, 31 | dtype={'patient': str, 32 | 'release': str, 33 | 'sequence': str, 34 | 'side': str, 35 | 'slice_idx': int, 36 | 'pixel_spacing_0': float, 37 | 'pixel_spacing_1': float, 38 | 'slice_thickness': float, 39 | 'KL': int}, 40 | index_col=False) 41 | 42 | if len(fnames_image) != len(df_meta): 43 | raise ValueError("Number of images doesn't match with the metadata") 44 | 45 | df_meta['path_image'] = [os.path.join(path_root, e) 46 | for e in df_meta['path_rel_image']] 47 | 48 | # Sort the records 49 | df_meta_sorted = (df_meta 50 | .sort_values(['patient', 'sequence', 'slice_idx']) 51 | .reset_index() 52 | .drop('index', axis=1)) 53 | 54 | df_meta_sorted.to_csv(fname_meta_dyn, index=False) 55 | else: 56 | df_meta_sorted = pd.read_csv(fname_meta_dyn, 57 | dtype={'patient': str, 58 | 'release': str, 59 | 'sequence': str, 60 | 'side': str, 61 | 'slice_idx': int, 62 | 'pixel_spacing_0': float, 63 | 'pixel_spacing_1': float, 64 | 'slice_thickness': float, 65 | 'KL': int}, 66 | index_col=False) 67 | 68 | return df_meta_sorted 69 | 70 | 71 | def read_image(path_file): 72 | image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE) 73 | return image.reshape((1, *image.shape)) 74 | 75 | 76 | class DatasetMAKNEESagittal2d(Dataset): 77 | def __init__(self, df_meta, mask_mode=None, name=None, transforms=None, 78 | sample_mode='x_y', **kwargs): 79 | logger.warning('Redundant dataset init arguments:\n{}' 80 | .format(repr(kwargs))) 81 | 82 | self.df_meta = df_meta 83 | self.mask_mode = mask_mode 84 | self.name = name 85 | self.transforms = transforms 86 | self.sample_mode = sample_mode 87 | 88 | def __len__(self): 89 | return len(self.df_meta) 90 | 91 | def _getitem_x_y(self, idx): 92 | image = read_image(self.df_meta['path_image'].iloc[idx]) 93 | mask = np.zeros_like(image) 94 | 95 | # Apply transformations 96 | if self.transforms is not None: 97 | for t in self.transforms: 98 | if hasattr(t, 'randomize'): 99 | t.randomize() 100 | image, mask = t(image, mask) 101 | 102 | tmp = dict(self.df_meta.iloc[idx]) 103 | tmp['image'] = image 104 | tmp['mask'] = mask 105 | 106 | tmp['xs'] = tmp['image'] 107 | tmp['ys'] = tmp['mask'] 108 | return tmp 
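# (annotation, not in the original file) each sample is the metadata row
# converted to a dict and extended with the `image`/`mask` arrays, which are
# additionally aliased as `xs`/`ys` for the downstream training and
# description pipelines.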
109 | 110 | def __getitem__(self, idx): 111 | if self.sample_mode == 'x_y': 112 | return self._getitem_x_y(idx) 113 | else: 114 | raise ValueError('Invalid `sample_mode`') 115 | 116 | def describe(self): 117 | summary = defaultdict(float) 118 | for i in range(len(self)): 119 | # `_getitem_x_y` returns a dict (not a tuple), so fetch the 120 | # mask by key regardless of `sample_mode` 121 | mask = self.__getitem__(i)['mask'] 122 | 123 | summary['num_class_pixels'] += mask.numpy().sum(axis=(1, 2)) 124 | summary['class_importance'] = \ 125 | np.sum(summary['num_class_pixels']) / summary['num_class_pixels'] 126 | summary['class_importance'] /= np.sum(summary['class_importance']) 127 | logger.info('Dataset statistics:') 128 | logger.info(sorted(summary.items())) 129 | -------------------------------------------------------------------------------- /rocaseg/resample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import click 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import cv2 8 | 9 | 10 | @click.command() 11 | @click.option('--path_root_in', help='E.g. data/31_OKOA_full_meta') 12 | @click.option('--spacing_in', nargs=2, default=(0.5859375, 0.5859375)) 13 | @click.option('--path_root_out', help='E.g. data/32_OKOA_full_meta_rescaled') 14 | @click.option('--spacing_out', nargs=2, default=(0.36458333, 0.36458333)) 15 | @click.option('--dirname_images', default='images') 16 | @click.option('--dirname_masks', default='masks') 17 | @click.option('--num_threads', default=12, type=click.IntRange(-1, 12)) 18 | @click.option('--margin', default=0, type=int) 19 | @click.option('--update_meta', is_flag=True) 20 | def main(**config): 21 | # Get the index of image files and the corresponding metadata 22 | path_meta = os.path.join(config['path_root_in'], 'meta_base.csv') 23 | if not os.path.exists(path_meta): 24 | # Fall back to the dynamic index if the base one is absent 25 | path_meta = os.path.join(config['path_root_in'], 26 | 'meta_dynamic.csv') 27 | 28 | df_meta = pd.read_csv(path_meta, 29 | dtype={'patient': str, 30 | 'release': str, 31 | 'prefix_var': str, 32 | 'sequence': str, 33 | 'side': str, 34 | 'slice_idx': int, 35 | 'pixel_spacing_0': float, 36 | 'pixel_spacing_1': float, 37 | 'slice_thickness': float, 38 | 'KL': int, 39 | 'has_mask': int}, 40 | index_col=False) 41 | 42 | df_in = df_meta.sort_values(['patient', 'release', 'sequence', 'side', 'slice_idx']) 43 | 44 | ratio = (np.asarray(config['spacing_in']) / 45 | np.asarray(config['spacing_out'])) 46 | 47 | groupers_stack = ['patient', 'release', 'sequence', 'side', 'slice_idx'] 48 | 49 | # Resample images 50 | if config['dirname_images'] is not None: 51 | for name_gb, df_gb in tqdm(df_in.groupby(groupers_stack), desc='Resample images'): 52 | patient, release, sequence, side, slice_idx = name_gb 53 | 54 | fn_base = f'{slice_idx:03d}.png' 55 | dir_in = os.path.join(config['path_root_in'], 56 | patient, release, sequence, 57 | config['dirname_images']) 58 | dir_out = os.path.join(config['path_root_out'], 59 | patient, release, sequence, 60 | config['dirname_images']) 61 | os.makedirs(dir_out, exist_ok=True) 62 | 63 | path_in = os.path.join(dir_in, fn_base) 64 | path_out = os.path.join(dir_out, fn_base) 65 | 66 | img_in = cv2.imread(path_in, cv2.IMREAD_GRAYSCALE) 67 | 68 | if config['margin'] == 0: 69 | tmp = img_in 70 | else: 71 | tmp = img_in[config['margin']:-config['margin'], 72 | config['margin']:-config['margin']] 73 | 74 | shape_out = tuple(np.floor(tmp.shape * ratio).astype(np.int))[::-1] 75 | tmp = cv2.resize(tmp, shape_out) 76 |
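# (annotation, not in the original file) `ratio` = spacing_in / spacing_out,
# so moving to a finer target pixel spacing upsamples the slice; cv2.resize
# expects the output size as (width, height), hence the [::-1] reversal of
# the (rows, cols) shape above.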
img_out = tmp 77 | 78 | cv2.imwrite(path_out, img_out) 79 | 80 | # Resample masks 81 | if config['dirname_masks'] is not None: 82 | for name_gb, df_gb in tqdm(df_in.groupby(groupers_stack), desc='Resample masks'): 83 | patient, release, sequence, side, slice_idx = name_gb 84 | 85 | fn_base = f'{slice_idx:03d}.png' 86 | dir_in = os.path.join(config['path_root_in'], 87 | patient, release, sequence, 88 | config['dirname_masks']) 89 | dir_out = os.path.join(config['path_root_out'], 90 | patient, release, sequence, 91 | config['dirname_masks']) 92 | os.makedirs(dir_out, exist_ok=True) 93 | 94 | path_in = os.path.join(dir_in, fn_base) 95 | if not os.path.exists(path_in): 96 | print(f'No mask found for {name_gb}') 97 | continue 98 | path_out = os.path.join(dir_out, fn_base) 99 | 100 | mask_in = cv2.imread(path_in, cv2.IMREAD_GRAYSCALE) 101 | 102 | if config['margin'] == 0: 103 | tmp = mask_in 104 | else: 105 | tmp = mask_in[config['margin']:-config['margin'], 106 | config['margin']:-config['margin']] 107 | 108 | shape_out = tuple(np.floor(tmp.shape * ratio).astype(np.int))[::-1] 109 | tmp = cv2.resize(tmp, shape_out, interpolation=cv2.INTER_NEAREST) 110 | mask_out = tmp 111 | 112 | cv2.imwrite(path_out, mask_out) 113 | 114 | if config['update_meta']: 115 | df_out = (df_in.assign(pixel_spacing_0=config['spacing_out'][0]) 116 | .assign(pixel_spacing_1=config['spacing_out'][1])) 117 | df_out.to_csv(os.path.join(config['path_root_out'], 'meta_base.csv'), index=False) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /rocaseg/analyze_predictions_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | 5 | import click 6 | import logging 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | import pandas as pd 12 | 13 | from rocaseg.datasets.constants import atlas_to_locations 14 | 15 | logging.basicConfig() 16 | logger = logging.getLogger('analyze') 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | def slicew_dsc_distr_vs_slice_idcs(results, config, num_classes, class_names): 21 | """Distribution of planar DSCs VS slice indices 22 | """ 23 | path_vis = config['path_results_root'] 24 | 25 | metric_names = ['dice_score', ] 26 | # colors = ["darkorange", "mediumorchid", "deepskyblue"] 27 | # colors = ["salmon", "dodgerblue", "orangered"] 28 | labels = ['Baseline', '+ mixup - WD', '+ UDA2'] 29 | colors = ["lightsalmon", "dodgerblue", 'grey'] 30 | # labels = ['Reference', '+ mixup - WD'] 31 | 32 | # Average and visualize the metrics 33 | for metric_name in metric_names: 34 | print(metric_name) 35 | 36 | for class_idx in range(1, num_classes): 37 | 38 | fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 3.6)) 39 | plt.title(metric_name) 40 | 41 | for acc_idx, (_, acc_slicew) in enumerate(results.items()): 42 | unique_scans = list(set(zip( 43 | acc_slicew['patient'], 44 | acc_slicew['release'] 45 | ))) 46 | num_scans = len(unique_scans) 47 | num_slices = 160 48 | 49 | y = np.full((num_slices, num_scans), np.nan) 50 | for slice_idx in range(num_slices): 51 | sel_by_slice = np.asarray(acc_slicew['slice_idx_proc']) == slice_idx 52 | scores = np.asarray(acc_slicew['dice_score'])[sel_by_slice] 53 | scores = scores[:, class_idx] 54 | 55 | y[slice_idx, :] = scores 56 | 57 | x = np.ones_like(y) * np.arange(0, num_slices)[:, None] 58 | y = y.ravel() 59 | x = x.ravel() 60 
| 61 | sel = (y != 0) 62 | y = y[sel] 63 | x = x[sel] 64 | 65 | sns.lineplot(x=x, y=y, err_style='band', ax=axes, 66 | color=colors[acc_idx], label=labels[acc_idx]) 67 | 68 | fontsize = 14 69 | axes.set_ylim((0.45, 0.95)) 70 | # axes.set_ylim((0.5, 1.0)) 71 | axes.set_xlim((10, 150)) 72 | axes.set_xlabel('slice index', size=fontsize) 73 | axes.set_ylabel('DSC', size=fontsize) 74 | tmp_title = class_names[class_idx] 75 | tmp_title = tmp_title.replace('femoral', 'Femoral cartilage') 76 | tmp_title = tmp_title.replace('tibial', 'Tibial cartilage') 77 | tmp_title = tmp_title.replace('patellar', 'Patellar cartilage') 78 | tmp_title = tmp_title.replace('menisci', 'Menisci') 79 | 80 | axes.set_title(tmp_title, size=fontsize) 81 | 82 | for label in (axes.get_xticklabels() + axes.get_yticklabels()): 83 | label.set_fontsize(fontsize) 84 | 85 | leg = axes.legend( 86 | # loc='lower left', 87 | loc='lower right', 88 | prop={'size': fontsize}, ncol=1, 89 | framealpha=1.0, 90 | ) 91 | for line in leg.get_lines(): 92 | line.set_linewidth(3.0) 93 | 94 | # axes[class_idx_axes].get_legend().set_visible(False) 95 | 96 | plt.grid(linestyle=':') 97 | 98 | fname_vis = os.path.join(path_vis, 99 | f"metrics_{config['dataset']}_" 100 | f"test_slicew_confid_" 101 | f"{metric_name}_{class_idx}.pdf") 102 | plt.savefig(fname_vis, bbox_inches='tight') 103 | logger.info(f"Saved to {fname_vis}") 104 | if config['interactive']: 105 | plt.show() 106 | else: 107 | plt.close() 108 | 109 | 110 | @click.command() 111 | @click.option('--path_results_root', default='../../results') 112 | @click.option('--experiment_id', multiple=True) 113 | @click.option('--dataset', required=True, type=click.Choice( 114 | ['oai_imo', 'okoa', 'maknee'])) 115 | @click.option('--atlas', required=True, type=click.Choice( 116 | ['imo', 'segm', 'okoa'])) 117 | @click.option('--interactive', is_flag=True) 118 | @click.option('--num_workers', default=1, type=int) 119 | def main(**config): 120 | results_slicew = OrderedDict() 121 | results_volumew = OrderedDict() 122 | 123 | for exp_id in config['experiment_id']: 124 | path_experiment_root = os.path.join( 125 | config['path_results_root'], exp_id) 126 | path_logs = os.path.join( 127 | path_experiment_root, f"logs_{config['dataset']}_test") 128 | 129 | # Get the information on object classes 130 | locations = atlas_to_locations[config['atlas']] 131 | class_names = [k for k in locations] 132 | num_classes = max(locations.values()) + 1 133 | 134 | # Load precomputed planar scores 135 | fname_slicew = os.path.join( 136 | path_logs, 137 | f"cache_{config['dataset']}_test_{config['atlas']}_slicew_paired.pkl" 138 | ) 139 | if os.path.exists(fname_slicew): 140 | with open(fname_slicew, 'rb') as f: 141 | acc_slicew = pickle.load(f) 142 | else: 143 | raise IOError(f'File {fname_slicew} does not exist') 144 | 145 | fname_volumew = os.path.join( 146 | path_logs, 147 | f"cache_{config['dataset']}_test_{config['atlas']}_volumew_paired.pkl" 148 | ) 149 | if os.path.exists(fname_volumew): 150 | with open(fname_volumew, 'rb') as f: 151 | acc_volumew = pickle.load(f) 152 | else: 153 | raise IOError(f'File {fname_volumew} does not exist') 154 | 155 | results_slicew.update({exp_id: acc_slicew}) 156 | results_volumew.update({exp_id: acc_volumew}) 157 | 158 | # ------------------------------ Visualization -------------------------------------- 159 | 160 | slicew_dsc_distr_vs_slice_idcs(results=results_slicew, 161 | config=config, 162 | num_classes=num_classes, 163 | class_names=class_names) 164 | 165 | 166 | if __name__ == 
'__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /rocaseg/datasets/prepare_dataset_maknee.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from glob import glob 4 | 5 | import click 6 | from joblib import Parallel, delayed 7 | from tqdm import tqdm 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import pydicom 13 | import cv2 14 | 15 | 16 | cv2.ocl.setUseOpenCL(False) 17 | 18 | 19 | def read_dicom(fname, only_data=False): 20 | data = pydicom.read_file(fname) 21 | if len(data.PixelData) == 131072: 22 | dtype = np.uint16 23 | else: 24 | dtype = np.uint8 25 | image = np.frombuffer(data.PixelData, dtype=dtype).astype(float) 26 | 27 | if data.PhotometricInterpretation == 'MONOCHROME1': 28 | image = image.max() - image 29 | 30 | image = image.reshape((data.Rows, data.Columns)) 31 | 32 | if only_data: 33 | return image 34 | else: 35 | if hasattr(data, 'ImagerPixelSpacing'): 36 | spacing = [float(e) for e in data.ImagerPixelSpacing[:2]] 37 | slice_thickness = float(data.SliceThickness) 38 | elif hasattr(data, 'PixelSpacing'): 39 | spacing = [float(e) for e in data.PixelSpacing[:2]] 40 | slice_thickness = float(data.SliceThickness) 41 | else: 42 | msg = f'DICOM {fname} does not contain spacing info' 43 | print(msg) 44 | spacing = (0.0, 0.0) 45 | slice_thickness = 0.0 46 | 47 | if data.Laterality == 'R': 48 | side = 'RIGHT' 49 | elif data.Laterality == 'L': 50 | side = 'LEFT' 51 | else: 52 | msg = 'DICOM {fname} does not contain side info' 53 | raise AttributeError(msg) 54 | return image, spacing[0], spacing[1], slice_thickness, side 55 | 56 | 57 | @click.command() 58 | @click.argument('path_root_maknee') 59 | @click.argument('path_root_output') 60 | @click.option('--num_threads', default=12, type=click.IntRange(0, 16)) 61 | @click.option('--margin', default=0, type=int) 62 | @click.option('--meta_only', is_flag=True) 63 | def main(**config): 64 | config['path_root_maknee'] = os.path.abspath(config['path_root_maknee']) 65 | config['path_root_output'] = os.path.abspath(config['path_root_output']) 66 | 67 | # ------------------------------------------------------------------------- 68 | def worker(path_root_output, row, margin): 69 | meta = defaultdict(list) 70 | 71 | patient = row['patient'] 72 | slice_idx = row['slice_idx'] 73 | release = 'initial' 74 | sequence = 't2_de3d_we_sag_iso' 75 | 76 | image, *dicom_meta = read_dicom(row['fname_full_image']) 77 | 78 | side = dicom_meta[3] 79 | 80 | if margin != 0: 81 | image = image[margin:-margin, margin:-margin] 82 | 83 | fname_pattern = '{slice_idx:>03}.{ext}' 84 | 85 | # Save image and mask data 86 | dir_rel_image = os.path.join(patient, release, sequence, 'images') 87 | dir_rel_mask = os.path.join(patient, release, sequence, 'masks') 88 | dir_abs_image = os.path.join(path_root_output, dir_rel_image) 89 | dir_abs_mask = os.path.join(path_root_output, dir_rel_mask) 90 | for d in (dir_abs_image, dir_abs_mask): 91 | if not os.path.exists(d): 92 | os.makedirs(d) 93 | 94 | fname_image = fname_pattern.format(slice_idx=slice_idx, ext='png') 95 | path_abs_image = os.path.join(dir_abs_image, fname_image) 96 | if not config['meta_only']: 97 | cv2.imwrite(path_abs_image, image) 98 | 99 | path_rel_image = os.path.join(dir_rel_image, fname_image) 100 | 101 | meta['patient'].append(patient) 102 | meta['release'].append(release) 103 | meta['sequence'].append(sequence) 104 | meta['side'].append(side) 
105 | meta['slice_idx'].append(slice_idx) 106 | meta['pixel_spacing_0'].append(dicom_meta[0]) 107 | meta['pixel_spacing_1'].append(dicom_meta[1]) 108 | meta['slice_thickness'].append(dicom_meta[2]) 109 | meta['path_rel_image'].append(path_rel_image) 110 | return meta 111 | 112 | # ------------------------------------------------------------------------- 113 | 114 | # Get list of images files 115 | fnames_dicom = glob(os.path.join(config['path_root_maknee'], 116 | 'MRI', 'Scans', '**', 117 | 't2_de3d_we_sag_iso*', 'IMG*'), 118 | recursive=True) 119 | fnames_dicom = list(sorted(fnames_dicom)) 120 | 121 | def meta_from_fname(fn): 122 | # root / MRI / Scans / 001 / t2_de3d_we_sag_iso / IMG00000 123 | tmp = fn.split('/') 124 | meta = { 125 | 'fname_full_image': fn, 126 | 'slice_idx': os.path.splitext(tmp[-1])[0][-3:], 127 | 'patient': 'P{:>03}'.format(tmp[-3])} 128 | return meta 129 | 130 | dict_meta = { 131 | 'fname_full_image': [], 132 | 'slice_idx': [], 133 | 'patient': []} 134 | 135 | for e in fnames_dicom: 136 | tmp_meta = meta_from_fname(e) 137 | for k, v in tmp_meta.items(): 138 | dict_meta[k].append(v) 139 | 140 | df_meta = pd.DataFrame.from_dict(dict_meta) 141 | 142 | metas = Parallel(config['num_threads'])(delayed(worker)( 143 | *[config['path_root_output'], row, config['margin']] 144 | ) for _, row in tqdm(df_meta.iterrows(), total=len(df_meta))) 145 | 146 | # Merge meta information from different stacks 147 | tmp = defaultdict(list) 148 | for d in metas: 149 | for k, v in d.items(): 150 | tmp[k].extend(v) 151 | df_meta = pd.DataFrame.from_dict(tmp) 152 | 153 | # Add grading data to the meta-info df 154 | path_file_exp = os.path.join(config['path_root_maknee'], 'MAKnee_KL_subjects.xlsx') 155 | df_kl = pd.read_excel(path_file_exp) 156 | 157 | df_meta_uniq = df_meta.loc[:, ['patient', 'side']].drop_duplicates() 158 | df_kl.loc[:, 'ID'] = ['P{:>03}'.format(e) for e in df_kl['ID']] 159 | df_kl = df_kl.set_index(df_kl['ID']) 160 | tmp_kl = [] 161 | 162 | for _, row in df_meta_uniq.iterrows(): 163 | tmp_patient = row['patient'] 164 | tmp_side = row['side'] 165 | 166 | if tmp_side == 'RIGHT': 167 | tmp_kl.append(int(df_kl.loc[tmp_patient, 'KL right'])) 168 | elif tmp_side == 'LEFT': 169 | tmp_kl.append(int(df_kl.loc[tmp_patient, 'KL left'])) 170 | else: 171 | msg = f'Unexpected side value {tmp_side}' 172 | raise ValueError(msg) 173 | 174 | df_meta_uniq['KL'] = tmp_kl 175 | df_meta = pd.merge(df_meta, df_meta_uniq, on=['patient', 'side'], how='left') 176 | 177 | df_out = df_meta 178 | 179 | path_output_meta = os.path.join(config['path_root_output'], 'meta_base.csv') 180 | df_out.to_csv(path_output_meta, index=False) 181 | 182 | 183 | if __name__ == '__main__': 184 | main() 185 | -------------------------------------------------------------------------------- /rocaseg/datasets/sources.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from sklearn.model_selection import GroupShuffleSplit, GroupKFold 5 | 6 | from rocaseg.datasets import (index_from_path_oai_imo, 7 | index_from_path_okoa, 8 | index_from_path_maknee) 9 | 10 | 11 | logging.basicConfig() 12 | logger = logging.getLogger('datasets') 13 | logger.setLevel(logging.DEBUG) 14 | 15 | 16 | def sources_from_path(path_data_root, 17 | selection=None, 18 | with_folds=False, 19 | fold_num=5, 20 | seed_trainval_test=0): 21 | """ 22 | 23 | Args: 24 | path_data_root: str 25 | 26 | selection: iterable or str or None 27 | 28 | with_folds: bool 29 | Whether to split trainval 
subset into the folds. 30 | fold_num: int 31 | Number of folds. 32 | seed_trainval_test: int 33 | Random state for the trainval/test splitting. 34 | 35 | Returns: 36 | 37 | """ 38 | if selection is None: 39 | selection = ('oai_imo', 'okoa', 'maknee') 40 | elif isinstance(selection, str): 41 | selection = (selection, ) 42 | 43 | sources = dict() 44 | 45 | for name in selection: 46 | if name == 'oai_imo': 47 | logger.info('--- OAI iMorphics dataset ---') 48 | tmp = dict() 49 | tmp['path_root'] = os.path.join(path_data_root, 50 | '91_OAI_iMorphics_full_meta') 51 | 52 | if not os.path.exists(tmp['path_root']): 53 | logger.warning(f"Dataset {name} is not found in {tmp['path_root']}") 54 | continue 55 | 56 | tmp['full_df'] = index_from_path_oai_imo(tmp['path_root']) 57 | logger.info(f"Total number of samples: " 58 | f"{len(tmp['full_df'])}") 59 | 60 | # Select the specific subset 61 | # Remove two series from the dataset as they are completely missing 62 | # information on patellar cartilage: 63 | # /0.C.2/9674570/20040913/10699609/ 64 | # /1.C.2/9674570/20050829/10488714/ 65 | tmp['sel_df'] = tmp['full_df'][tmp['full_df']['patient'] != '9674570'] 66 | logger.info(f"Selected number of samples: " 67 | f"{len(tmp['sel_df'])}") 68 | 69 | if with_folds: 70 | # Get trainval/test split 71 | tmp_groups = tmp['sel_df'].loc[:, 'patient'].values 72 | tmp_grades = tmp['sel_df'].loc[:, 'KL'].values 73 | 74 | tmp_gss = GroupShuffleSplit(n_splits=1, test_size=0.2, 75 | random_state=seed_trainval_test) 76 | tmp_idcs_trainval, tmp_idcs_test = next(tmp_gss.split(X=tmp['sel_df'], 77 | y=tmp_grades, 78 | groups=tmp_groups)) 79 | tmp['trainval_df'] = tmp['sel_df'].iloc[tmp_idcs_trainval] 80 | tmp['test_df'] = tmp['sel_df'].iloc[tmp_idcs_test] 81 | logger.info(f"Made trainval-test split, number of samples: " 82 | f"{len(tmp['trainval_df'])}, " 83 | f"{len(tmp['test_df'])}") 84 | 85 | # Make k folds 86 | tmp_gkf = GroupKFold(n_splits=fold_num) 87 | tmp_groups = tmp['trainval_df'].loc[:, 'patient'].values 88 | tmp_grades = tmp['trainval_df'].loc[:, 'KL'].values 89 | 90 | tmp['trainval_folds'] = tmp_gkf.split(X=tmp['trainval_df'], 91 | y=tmp_grades, groups=tmp_groups) 92 | sources['oai_imo'] = tmp 93 | 94 | elif name == 'okoa': 95 | logger.info('--- OKOA dataset ---') 96 | tmp = dict() 97 | tmp['path_root'] = os.path.join(path_data_root, 98 | '32_OKOA_full_meta_rescaled') 99 | 100 | if not os.path.exists(tmp['path_root']): 101 | logger.warning(f"Dataset {name} is not found in {tmp['path_root']}") 102 | continue 103 | 104 | tmp['full_df'] = index_from_path_okoa(tmp['path_root']) 105 | logger.info(f"Total number of samples: " 106 | f"{len(tmp['full_df'])}") 107 | 108 | # Select the specific subset 109 | tmp['sel_df'] = tmp['full_df'] 110 | logger.info(f"Selected number of samples: " 111 | f"{len(tmp['sel_df'])}") 112 | 113 | if with_folds: 114 | # Get trainval/test split 115 | tmp['trainval_df'] = tmp['sel_df'][tmp['sel_df']['subset'] == 'training'] 116 | tmp['test_df'] = tmp['sel_df'][tmp['sel_df']['subset'] == 'evaluation'] 117 | logger.info(f"Made trainval-test split, number of samples: " 118 | f"{len(tmp['trainval_df'])}, " 119 | f"{len(tmp['test_df'])}") 120 | 121 | # Make k folds 122 | tmp_gkf = GroupKFold(n_splits=fold_num) 123 | tmp_groups = tmp['trainval_df'].loc[:, 'patient'].values 124 | 125 | tmp['trainval_folds'] = tmp_gkf.split(X=tmp['trainval_df'], 126 | groups=tmp_groups) 127 | sources['okoa'] = tmp 128 | 129 | elif name == 'maknee': 130 | logger.info('--- MAKNEE dataset ---') 131 | tmp = dict() 132 | 
tmp['path_root'] = os.path.join(path_data_root, 133 | '42_MAKNEE_full_meta_rescaled') 134 | 135 | if not os.path.exists(tmp['path_root']): 136 | logger.warning(f"Dataset {name} is not found in {tmp['path_root']}") 137 | continue 138 | 139 | tmp['full_df'] = index_from_path_maknee(tmp['path_root']) 140 | logger.info(f"Total number of samples: " 141 | f"{len(tmp['full_df'])}") 142 | 143 | # Select the specific subset 144 | tmp['sel_df'] = tmp['full_df'] 145 | logger.info(f"Selected number of samples: " 146 | f"{len(tmp['sel_df'])}") 147 | 148 | # Get trainval/test split 149 | tmp_groups = tmp['sel_df'].loc[:, 'patient'].values 150 | 151 | tmp_gss = GroupShuffleSplit(n_splits=1, test_size=0.2, 152 | random_state=seed_trainval_test) 153 | tmp_idcs_trainval, tmp_idcs_test = next(tmp_gss.split(X=tmp['sel_df'], 154 | groups=tmp_groups)) 155 | tmp['trainval_df'] = tmp['sel_df'].iloc[tmp_idcs_trainval] 156 | tmp['test_df'] = tmp['sel_df'].iloc[tmp_idcs_test] 157 | logger.info(f"Made trainval-test split, number of samples: " 158 | f"{len(tmp['trainval_df'])}, " 159 | f"{len(tmp['test_df'])}") 160 | 161 | if with_folds: 162 | # Make k folds 163 | tmp_gkf = GroupKFold(n_splits=fold_num) 164 | tmp_groups = tmp['trainval_df'].loc[:, 'patient'].values 165 | 166 | tmp['trainval_folds'] = tmp_gkf.split(X=tmp['trainval_df'], 167 | groups=tmp_groups) 168 | sources['maknee'] = tmp 169 | 170 | else: 171 | raise ValueError(f'Unknown dataset `{name}`') 172 | 173 | return sources 174 | -------------------------------------------------------------------------------- /rocaseg/datasets/dataset_okoa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import logging 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from joblib import Parallel, delayed 9 | 10 | from tqdm import tqdm 11 | import cv2 12 | from torch.utils.data.dataset import Dataset 13 | 14 | from rocaseg.datasets.constants import locations_f43h 15 | 16 | 17 | logging.basicConfig() 18 | logger = logging.getLogger('dataset') 19 | logger.setLevel(logging.DEBUG) 20 | 21 | 22 | def index_from_path_okoa(path_root, force=False): 23 | fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv') 24 | fname_meta_base = os.path.join(path_root, 'meta_base.csv') 25 | 26 | if not os.path.exists(fname_meta_dyn) or force: 27 | fnames_image = glob.glob( 28 | os.path.join(path_root, '**', 'images', '*.png'), recursive=True) 29 | logger.info('{} images found'.format(len(fnames_image))) 30 | fnames_mask = glob.glob( 31 | os.path.join(path_root, '**', 'masks', '*.png'), recursive=True) 32 | logger.info('{} masks found'.format(len(fnames_mask))) 33 | 34 | df_meta = pd.read_csv(fname_meta_base, 35 | dtype={'subset': str, 36 | 'patient': str, 37 | 'release': str, 38 | 'sequence': str, 39 | 'side': str, 40 | 'slice_idx': int, 41 | 'pixel_spacing_0': float, 42 | 'pixel_spacing_1': float, 43 | 'slice_thickness': float, 44 | 'KL': int}, 45 | index_col=False) 46 | 47 | if len(fnames_image) != len(df_meta): 48 | raise ValueError("Number of images doesn't match with the metadata") 49 | if len(fnames_mask) != len(df_meta): 50 | raise ValueError("Number of masks doesn't match with the metadata") 51 | 52 | df_meta['path_image'] = [os.path.join(path_root, e) 53 | for e in df_meta['path_rel_image']] 54 | df_meta['path_mask'] = [os.path.join(path_root, e) 55 | for e in df_meta['path_rel_mask']] 56 | 57 | # Check for in-slice mask presence 58 | def worker_74n(fname): 59 | mask = 
read_mask(path_file=fname, mask_mode='raw') 60 | if np.any(mask[1:] > 0): 61 | return 1 62 | else: 63 | return 0 64 | 65 | logger.info('Exploring the annotations') 66 | tmp = Parallel(n_jobs=-1)( 67 | delayed(worker_74n)(row['path_mask']) 68 | for _, row in tqdm(df_meta.iterrows(), total=len(df_meta))) 69 | df_meta['has_mask'] = tmp 70 | 71 | # Sort the records 72 | df_meta_sorted = (df_meta 73 | .sort_values(['patient', 'sequence', 'slice_idx']) 74 | .reset_index() 75 | .drop('index', axis=1)) 76 | 77 | df_meta_sorted.to_csv(fname_meta_dyn, index=False) 78 | else: 79 | df_meta_sorted = pd.read_csv(fname_meta_dyn, 80 | dtype={'subset': str, 81 | 'patient': str, 82 | 'release': str, 83 | 'sequence': str, 84 | 'side': str, 85 | 'slice_idx': int, 86 | 'pixel_spacing_0': float, 87 | 'pixel_spacing_1': float, 88 | 'slice_thickness': float, 89 | 'KL': int, 90 | 'has_mask': int}, 91 | index_col=False) 92 | return df_meta_sorted 93 | 94 | 95 | def read_image(path_file): 96 | image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE) 97 | return image.reshape((1, *image.shape)) 98 | 99 | 100 | def read_mask(path_file, mask_mode): 101 | """Read mask from the file, and pre-process it. 102 | 103 | IMPORTANT: currently, we handle the inter-class collisions by assigning 104 | the joint pixels to a class with a lower index. 105 | 106 | Parameters 107 | ---------- 108 | path_file: str 109 | Full path to mask file. 110 | mask_mode: str 111 | Specifies which channels of mask to use. 112 | 113 | Returns 114 | ------- 115 | out : (ch, d0, d1) uint8 ndarray 116 | 117 | """ 118 | mask = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE) 119 | 120 | locations = { 121 | '_background': (locations_f43h['_background'],), 122 | 'femoral': (locations_f43h['femoral'],), 123 | 'tibial': (locations_f43h['tibial'], ), 124 | } 125 | 126 | if mask_mode == 'raw': 127 | return mask 128 | elif mask_mode == 'background_femoral_unitibial': 129 | ret = np.empty((3, *mask.shape), dtype=mask.dtype) 130 | ret[0, :, :] = np.isin(mask, locations['_background']).astype(np.uint8) 131 | ret[1, :, :] = np.isin(mask, locations['femoral']).astype(np.uint8) 132 | ret[2, :, :] = np.isin(mask, locations['tibial']).astype(np.uint8) 133 | return ret 134 | else: 135 | raise ValueError('Invalid `mask_mode`') 136 | 137 | 138 | class DatasetOKOASagittal2d(Dataset): 139 | def __init__(self, df_meta, mask_mode=None, name=None, transforms=None, 140 | sample_mode='x_y', **kwargs): 141 | logger.warning('Redundant dataset init arguments:\n{}' 142 | .format(repr(kwargs))) 143 | 144 | self.df_meta = df_meta 145 | self.mask_mode = mask_mode 146 | self.name = name 147 | self.transforms = transforms 148 | self.sample_mode = sample_mode 149 | 150 | def __len__(self): 151 | return len(self.df_meta) 152 | 153 | def _getitem_x_y(self, idx): 154 | image = read_image(self.df_meta['path_image'].iloc[idx]) 155 | mask = read_mask(self.df_meta['path_mask'].iloc[idx], self.mask_mode) 156 | 157 | # Apply transformations 158 | if self.transforms is not None: 159 | for t in self.transforms: 160 | if hasattr(t, 'randomize'): 161 | t.randomize() 162 | image, mask = t(image, mask) 163 | 164 | tmp = dict(self.df_meta.iloc[idx]) 165 | tmp['image'] = image 166 | tmp['mask'] = mask 167 | 168 | tmp['xs'] = tmp['image'] 169 | tmp['ys'] = tmp['mask'] 170 | return tmp 171 | 172 | def __getitem__(self, idx): 173 | if self.sample_mode == 'x_y': 174 | return self._getitem_x_y(idx) 175 | else: 176 | raise ValueError('Invalid `sample_mode`') 177 | 178 | def read_image(self, path_file): 179 | return 
read_image(path_file=path_file) 180 | 181 | def read_mask(self, path_file): 182 | return read_mask(path_file=path_file, mask_mode=self.mask_mode) 183 | 184 | def describe(self): 185 | summary = defaultdict(float) 186 | for i in range(len(self)): 187 | # `_getitem_x_y` returns a dict, so the mask is fetched by key 188 | mask = self.__getitem__(i)['mask'] 189 | summary['num_class_pixels'] += mask.numpy().sum(axis=(1, 2)) 190 | summary['class_importance'] = \ 191 | np.sum(summary['num_class_pixels']) / summary['num_class_pixels'] 192 | summary['class_importance'] /= np.sum(summary['class_importance']) 193 | logger.info('Dataset statistics:') 194 | logger.info(sorted(summary.items())) 195 | -------------------------------------------------------------------------------- /rocaseg/datasets/prepare_dataset_okoa.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import click 4 | 5 | from joblib import Parallel, delayed 6 | from glob import glob 7 | from tqdm import tqdm 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import pydicom 13 | import cv2 14 | 15 | 16 | cv2.ocl.setUseOpenCL(False) 17 | 18 | 19 | def read_dicom(fname, only_data=False): 20 | data = pydicom.read_file(fname) 21 | if len(data.PixelData) == 131072: 22 | dtype = np.uint16 23 | else: 24 | dtype = np.uint8 25 | image = np.frombuffer(data.PixelData, dtype=dtype).astype(float) 26 | 27 | if data.PhotometricInterpretation == 'MONOCHROME1': 28 | image = image.max() - image 29 | 30 | image = image.reshape((data.Rows, data.Columns)) 31 | 32 | if only_data: 33 | return image 34 | else: 35 | if hasattr(data, 'ImagerPixelSpacing'): 36 | spacing = [float(e) for e in data.ImagerPixelSpacing] 37 | slice_thickness = float(data.SliceThickness) 38 | elif hasattr(data, 'PixelSpacing'): 39 | spacing = [float(e) for e in data.PixelSpacing] 40 | slice_thickness = float(data.SliceThickness) 41 | else: 42 | msg = f'DICOM {fname} does not have the required attributes' 43 | print(msg) 44 | spacing = (0.0, 0.0) 45 | slice_thickness = 0.0 46 | return image, spacing[0], spacing[1], slice_thickness 47 | 48 | 49 | @click.command() 50 | @click.argument('path_root_okoa') 51 | @click.argument('path_root_output') 52 | @click.option('--num_threads', default=12, type=click.IntRange(0, 16)) 53 | @click.option('--margin', default=0, type=int) 54 | @click.option('--meta_only', is_flag=True) 55 | def main(**config): 56 | config['path_root_okoa'] = os.path.abspath(config['path_root_okoa']) 57 | config['path_root_output'] = os.path.abspath(config['path_root_output']) 58 | 59 | # ------------------------------------------------------------------------- 60 | def worker_s5g(path_root_output, row, margin): 61 | meta = defaultdict(list) 62 | 63 | patient = row['patient'].values[0] 64 | slice_idx = row['slice_idx'].values[0] 65 | side = row['side'].values[0] 66 | release = 'initial' 67 | sequence = 't2_de3d_we_sag_iso' 68 | 69 | image, *dicom_meta = read_dicom(row[('fname_full', 'Images')]) 70 | 71 | # Set the default values for the voxel spacings 72 | pixel_spacing_0 = dicom_meta[0] or '0.5859375' 73 | pixel_spacing_1 = dicom_meta[1] or '0.5859375' 74 | slice_thickness = dicom_meta[2] or '0.60000002384186' 75 | 76 | mask_femur = read_dicom(row[('fname_full', 'Femur')], only_data=True) 77 | mask_tibia = read_dicom(row[('fname_full', 'Tibia')], only_data=True) 78 | mask_full = np.zeros(mask_femur.shape, dtype=np.uint8) 79 | # Use inverse order to prioritize
femoral tissues in collision handling 80 | mask_full[mask_tibia > 0] = 2 81 | mask_full[mask_femur > 0] = 1 82 | 83 | if margin: 84 | image = image[margin:-margin, margin:-margin] 85 | mask_full = mask_full[margin:-margin, margin:-margin] 86 | 87 | fname_pattern = '{slice_idx:>03}.{ext}' 88 | 89 | # Save image and mask data 90 | dir_rel_image = os.path.join(patient, release, sequence, 'images') 91 | dir_rel_mask = os.path.join(patient, release, sequence, 'masks') 92 | dir_abs_image = os.path.join(path_root_output, dir_rel_image) 93 | dir_abs_mask = os.path.join(path_root_output, dir_rel_mask) 94 | for d in (dir_abs_image, dir_abs_mask): 95 | if not os.path.exists(d): 96 | os.makedirs(d) 97 | 98 | fname_image = fname_pattern.format(slice_idx=slice_idx, ext='png') 99 | path_abs_image = os.path.join(dir_abs_image, fname_image) 100 | if not config['meta_only']: 101 | cv2.imwrite(path_abs_image, image) 102 | 103 | fname_mask = fname_pattern.format(slice_idx=slice_idx, ext='png') 104 | path_abs_mask = os.path.join(dir_abs_mask, fname_mask) 105 | if not config['meta_only']: 106 | cv2.imwrite(path_abs_mask, mask_full) 107 | 108 | path_rel_image = os.path.join(dir_rel_image, fname_image) 109 | path_rel_mask = os.path.join(dir_rel_mask, fname_mask) 110 | 111 | meta['subset'].append(row['subset'].values[0]) 112 | meta['patient'].append(patient) 113 | meta['release'].append(release) 114 | meta['sequence'].append(sequence) 115 | meta['side'].append(side) 116 | meta['KL'].append(row['KL'].values[0]) 117 | meta['slice_idx'].append(slice_idx) 118 | meta['pixel_spacing_0'].append(pixel_spacing_0) 119 | meta['pixel_spacing_1'].append(pixel_spacing_1) 120 | meta['slice_thickness'].append(slice_thickness) 121 | meta['path_rel_image'].append(path_rel_image) 122 | meta['path_rel_mask'].append(path_rel_mask) 123 | return meta 124 | 125 | # ------------------------------------------------------------------------- 126 | 127 | # Get the list of image files 128 | paths_fnames_dicom = glob(os.path.join(config['path_root_okoa'], '**', '*.IMA'), 129 | recursive=True) 130 | 131 | # root / training|evaluation / P36 / Images|Femur|Tibia / (1-160).IMA 132 | def meta_from_fname(fn): 133 | tmp = fn.split('/') 134 | slice_idx = int(os.path.splitext(tmp[-1])[0]) - 1 135 | slice_idx = '{:>03}'.format(slice_idx) 136 | meta = { 137 | 'fname_full': fn, 138 | 'slice_idx': slice_idx, 139 | 'kind': tmp[-2], 140 | 'patient': tmp[-3], 141 | 'subset': tmp[-4] 142 | } 143 | return meta 144 | 145 | dict_meta = { 146 | 'fname_full': [], 147 | 'slice_idx': [], 148 | 'kind': [], 149 | 'patient': [], 150 | 'subset': [] 151 | } 152 | 153 | for e in paths_fnames_dicom: 154 | tmp_meta = meta_from_fname(e) 155 | for k, v in tmp_meta.items(): 156 | dict_meta[k].append(v) 157 | 158 | df_meta = pd.DataFrame.from_dict(dict_meta) 159 | df_meta = (df_meta 160 | .set_index(['subset', 'patient', 'slice_idx', 'kind']) 161 | .unstack('kind') 162 | .reset_index()) 163 | 164 | # Add info on scan side and KL grades 165 | path_file_side = os.path.join(config['path_root_okoa'], 'sides.csv') 166 | df_side = pd.read_csv(path_file_side) 167 | path_file_kl = os.path.join(config['path_root_okoa'], 'KL_grades.csv') 168 | df_kl = pd.read_csv(path_file_kl) 169 | 170 | df_extra = pd.merge(df_side, df_kl, on='patient', how='left', sort=True) 171 | df_extra.loc[:, 'KL'] = -1 172 | for r_idx, r in df_extra.iterrows(): 173 | if r['side'] == 'LEFT': 174 | df_extra.loc[r_idx, 'KL'] = r['KL left'] 175 | elif r['side'] == 'RIGHT': 176 | df_extra.loc[r_idx, 'KL'] = r['KL right']
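# A vectorized sketch of the loop above (illustrative; assumes the same
# 'side', 'KL left', and 'KL right' columns):
#   df_extra['KL'] = np.where(df_extra['side'] == 'LEFT',
#                             df_extra['KL left'], df_extra['KL right'])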
177 | 178 | # Keep only the fields of interest 179 | df_extra = df_extra.loc[:, ['patient', 'side', 'KL', 'age']] 180 | 181 | # Make the same multi-index so that pd.merge doesn't create an extra column 182 | df_extra.columns = pd.MultiIndex.from_tuples([(c, '') for c in df_extra.columns]) 183 | 184 | # Merge the metadata into a single df 185 | df_meta = pd.merge(df_meta, df_extra, on='patient', how='left', sort=True) 186 | 187 | # Process the raw data 188 | metas = Parallel(config['num_threads'])(delayed(worker_s5g)( 189 | *[config['path_root_output'], row, config['margin']] 190 | ) for _, row in tqdm(df_meta.iterrows(), total=len(df_meta))) 191 | 192 | # Merge meta information from different stacks 193 | tmp = defaultdict(list) 194 | for d in metas: 195 | for k, v in d.items(): 196 | tmp[k].extend(v) 197 | df_out = pd.DataFrame.from_dict(tmp) 198 | 199 | path_output_meta = os.path.join(config['path_root_output'], 'meta_base.csv') 200 | df_out.to_csv(path_output_meta, index=False) 201 | 202 | 203 | if __name__ == '__main__': 204 | main() 205 | -------------------------------------------------------------------------------- /rocaseg/datasets/dataset_oai_imo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import logging 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from joblib import Parallel, delayed 9 | from tqdm import tqdm 10 | 11 | import cv2 12 | from torch.utils.data.dataset import Dataset 13 | 14 | from rocaseg.datasets.constants import locations_mh53 15 | 16 | 17 | logging.basicConfig() 18 | logger = logging.getLogger('dataset') 19 | logger.setLevel(logging.DEBUG) 20 | 21 | 22 | def index_from_path_oai_imo(path_root, force=False): 23 | fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv') 24 | fname_meta_base = os.path.join(path_root, 'meta_base.csv') 25 | 26 | if not os.path.exists(fname_meta_dyn) or force: 27 | fnames_image = glob.glob( 28 | os.path.join(path_root, '**', 'images', '*.png'), recursive=True) 29 | logger.info('{} images found'.format(len(fnames_image))) 30 | fnames_mask = glob.glob( 31 | os.path.join(path_root, '**', 'masks', '*.png'), recursive=True) 32 | logger.info('{} masks found'.format(len(fnames_mask))) 33 | 34 | df_meta = pd.read_csv(fname_meta_base, 35 | dtype={'patient': str, 36 | 'release': str, 37 | 'prefix_var': str, 38 | 'sequence': str, 39 | 'side': str, 40 | 'slice_idx': int, 41 | 'pixel_spacing_0': float, 42 | 'pixel_spacing_1': float, 43 | 'slice_thickness': float, 44 | 'KL': int}, 45 | index_col=False) 46 | 47 | if len(fnames_image) != len(df_meta): 48 | raise ValueError("Number of images doesn't match with the metadata") 49 | if len(fnames_mask) != len(df_meta): 50 | raise ValueError("Number of masks doesn't match with the metadata") 51 | 52 | df_meta['path_image'] = [os.path.join(path_root, e) 53 | for e in df_meta['path_rel_image']] 54 | df_meta['path_mask'] = [os.path.join(path_root, e) 55 | for e in df_meta['path_rel_mask']] 56 | 57 | # Check for in-slice mask presence 58 | def worker_nv7(fname): 59 | mask = read_mask(path_file=fname, mask_mode='raw') 60 | if np.any(mask[1:] > 0): 61 | return 1 62 | else: 63 | return 0 64 | 65 | logger.info('Exploring the annotations') 66 | tmp = Parallel(n_jobs=-1)( 67 | delayed(worker_nv7)(row['path_mask']) 68 | for _, row in tqdm(df_meta.iterrows(), total=len(df_meta))) 69 | df_meta['has_mask'] = tmp 70 | 71 | # Sort the records 72 | df_meta_sorted = (df_meta 73 |
.sort_values(['patient', 'prefix_var', 74 | 'sequence', 'slice_idx']) 75 | .reset_index() 76 | .drop('index', axis=1)) 77 | 78 | df_meta_sorted.to_csv(fname_meta_dyn, index=False) 79 | else: 80 | df_meta_sorted = pd.read_csv(fname_meta_dyn, 81 | dtype={'patient': str, 82 | 'release': str, 83 | 'prefix_var': str, 84 | 'sequence': str, 85 | 'side': str, 86 | 'slice_idx': int, 87 | 'pixel_spacing_0': float, 88 | 'pixel_spacing_1': float, 89 | 'slice_thickness': float, 90 | 'KL': int, 91 | 'has_mask': int}, 92 | index_col=False) 93 | return df_meta_sorted 94 | 95 | 96 | def read_image(path_file): 97 | image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE) 98 | return image.reshape((1, *image.shape)) 99 | 100 | 101 | def read_mask(path_file, mask_mode): 102 | """Read mask from the file, and pre-process it. 103 | 104 | IMPORTANT: currently, we handle the inter-class collisions by assigning 105 | the joint pixels to a class with a lower index. 106 | 107 | Parameters 108 | ---------- 109 | path_file: str 110 | Full path to mask file. 111 | mask_mode: str 112 | Specifies which channels of mask to use. 113 | 114 | Returns 115 | ------- 116 | out : (ch, d0, d1) uint8 ndarray 117 | 118 | """ 119 | mask = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE) 120 | 121 | locations = { 122 | '_background': (locations_mh53['Background'], ), 123 | 'femoral': (locations_mh53['FemoralCartilage'], ), 124 | 'tibial': (locations_mh53['LateralTibialCartilage'], 125 | locations_mh53['MedialTibialCartilage']), 126 | 'patellar': (locations_mh53['PatellarCartilage'], ), 127 | 'menisci': (locations_mh53['LateralMeniscus'], 128 | locations_mh53['MedialMeniscus']), 129 | } 130 | 131 | if mask_mode == 'raw': 132 | return mask 133 | elif mask_mode == 'all_unitibial_unimeniscus': 134 | ret = np.empty((5, *mask.shape), dtype=mask.dtype) 135 | ret[0, :, :] = np.isin(mask, locations['_background']).astype(np.uint8) 136 | ret[1, :, :] = np.isin(mask, locations['femoral']).astype(np.uint8) 137 | ret[2, :, :] = np.isin(mask, locations['tibial']).astype(np.uint8) 138 | ret[3, :, :] = np.isin(mask, locations['patellar']).astype(np.uint8) 139 | ret[4, :, :] = np.isin(mask, locations['menisci']).astype(np.uint8) 140 | return ret 141 | else: 142 | raise ValueError('Invalid `mask_mode`') 143 | 144 | 145 | class DatasetOAIiMoSagittal2d(Dataset): 146 | def __init__(self, df_meta, mask_mode=None, name=None, transforms=None, 147 | sample_mode='x_y', **kwargs): 148 | logger.warning('Redundant dataset init arguments:\n{}' 149 | .format(repr(kwargs))) 150 | 151 | self.df_meta = df_meta 152 | self.mask_mode = mask_mode 153 | self.name = name 154 | self.transforms = transforms 155 | self.sample_mode = sample_mode 156 | 157 | def __len__(self): 158 | return len(self.df_meta) 159 | 160 | def _getitem_x_y(self, idx): 161 | image = read_image(self.df_meta['path_image'].iloc[idx]) 162 | mask = read_mask(self.df_meta['path_mask'].iloc[idx], self.mask_mode) 163 | 164 | # Apply transformations 165 | if self.transforms is not None: 166 | for t in self.transforms: 167 | if hasattr(t, 'randomize'): 168 | t.randomize() 169 | image, mask = t(image, mask) 170 | 171 | tmp = dict(self.df_meta.iloc[idx]) 172 | tmp['image'] = image 173 | tmp['mask'] = mask 174 | 175 | tmp['xs'] = tmp['image'] 176 | tmp['ys'] = tmp['mask'] 177 | return tmp 178 | 179 | def __getitem__(self, idx): 180 | if self.sample_mode == 'x_y': 181 | return self._getitem_x_y(idx) 182 | else: 183 | raise ValueError('Invalid `sample_mode`') 184 | 185 | def read_image(self, path_file): 186 | return 
read_image(path_file=path_file) 187 | 188 | def read_mask(self, path_file): 189 | return read_mask(path_file=path_file, mask_mode=self.mask_mode) 190 | 191 | def describe(self): 192 | summary = defaultdict(float) 193 | for i in range(len(self)): 194 | # `_getitem_x_y` returns a dict, so the mask is fetched by key 195 | mask = self.__getitem__(i)['mask'] 196 | summary['num_class_pixels'] += mask.numpy().sum(axis=(1, 2)) 197 | summary['class_importance'] = \ 198 | np.sum(summary['num_class_pixels']) / summary['num_class_pixels'] 199 | summary['class_importance'] /= np.sum(summary['class_importance']) 200 | logger.info('Dataset statistics:') 201 | logger.info(sorted(summary.items())) 202 | -------------------------------------------------------------------------------- /rocaseg/models/unet_lext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamically created UNet with variable Width, Depth and activation 3 | 4 | Aleksei Tiulpin, University of Oulu, 2017 (c). 5 | 6 | """ 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from collections import OrderedDict 13 | 14 | 15 | logging.basicConfig() 16 | logger = logging.getLogger('models') 17 | logger.setLevel(logging.DEBUG) 18 | 19 | 20 | def depthwise_separable_conv(input_channels, output_channels): 21 | """ 22 | """ 23 | depthwise = nn.Conv2d(input_channels, input_channels, 24 | kernel_size=3, padding=1, groups=input_channels) 25 | pointwise = nn.Conv2d(input_channels, output_channels, 26 | kernel_size=1) 27 | return nn.Sequential(depthwise, pointwise) 28 | 29 | 30 | def block_conv_bn_act(input_channels, output_channels, 31 | convolution, activation): 32 | """ 33 | """ 34 | if convolution == 'regular': 35 | layer_conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1) 36 | elif convolution == 'depthwise_separable': 37 | layer_conv = depthwise_separable_conv(input_channels, output_channels) 38 | else: 39 | raise ValueError(f'Wrong `convolution`: {convolution}') 40 | 41 | if activation == 'relu': 42 | layer_act = nn.ReLU(inplace=True) 43 | elif activation == 'selu': 44 | layer_act = nn.SELU(inplace=True) 45 | elif activation == 'elu': 46 | layer_act = nn.ELU(1, inplace=True) 47 | else: 48 | raise ValueError(f'Wrong `activation`: {activation}') 49 | 50 | block = list() 51 | block.append(layer_conv) 52 | if activation == 'relu': 53 | block.append(nn.BatchNorm2d(output_channels)) 54 | block.append(layer_act) 55 | 56 | return nn.Sequential(*block) 57 | 58 | 59 | class Encoder(nn.Module): 60 | """Encoder class for encoder-decoder architecture. 61 | 62 | """ 63 | def __init__(self, input_channels, output_channels, 64 | depth=2, convolution='regular', activation='relu'): 65 | super().__init__() 66 | self.layers = nn.Sequential() 67 | for i in range(depth): 68 | tmp = [] 69 | if i == 0: 70 | tmp.append(block_conv_bn_act(input_channels, output_channels, 71 | convolution=convolution, 72 | activation=activation)) 73 | else: 74 | tmp.append(block_conv_bn_act(output_channels, output_channels, 75 | convolution=convolution, 76 | activation=activation)) 77 | 78 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp)) 79 | 80 | def forward(self, x): 81 | processed = self.layers(x) 82 | pooled = F.max_pool2d(processed, 2, 2) 83 | return processed, pooled 84 | 85 | 86 | class Decoder(nn.Module): 87 | """Decoder class for encoder-decoder architecture.
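    Example
    -------
    A minimal shape sketch (illustrative; `torch` is imported at the top of
    this module):

    >>> dec = Decoder(input_channels=48, output_channels=24)
    >>> x_big = torch.rand(1, 24, 64, 64)  # skip connection from the encoder
    >>> x = torch.rand(1, 24, 32, 32)      # output of the previous stage
    >>> dec(x_big, x).shape                # `x` is upsampled to match `x_big`
    torch.Size([1, 24, 64, 64])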
88 | 89 | """ 90 | def __init__(self, input_channels, output_channels, depth=2, mode='bilinear', 91 | convolution='regular', activation='relu'): 92 | super().__init__() 93 | self.layers = nn.Sequential() 94 | self.ups_mode = mode 95 | for i in range(depth): 96 | tmp = [] 97 | if i == 0: 98 | tmp.append(block_conv_bn_act(input_channels, output_channels, 99 | convolution=convolution, 100 | activation=activation)) 101 | else: 102 | tmp.append(block_conv_bn_act(output_channels, output_channels, 103 | convolution=convolution, 104 | activation=activation)) 105 | 106 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp)) 107 | 108 | def forward(self, x_big, x): 109 | x_ups = F.interpolate(x, size=x_big.size()[-2:], mode=self.ups_mode, 110 | align_corners=True) 111 | y_cat = torch.cat([x_ups, x_big], 1) 112 | y = self.layers(y_cat) 113 | return y 114 | 115 | 116 | class UNetLext(nn.Module): 117 | """UNet architecture with 3x3 convolutions. Created dynamically based on depth and width. 118 | 119 | """ 120 | def __init__(self, basic_width=24, depth=6, center_depth=2, 121 | input_channels=3, output_channels=1, 122 | convolution='regular', activation='relu', 123 | pretrained=False, path_pretrained=None, 124 | restore_weights=False, path_weights=None, **kwargs): 125 | """ 126 | 127 | Parameters 128 | ---------- 129 | basic_width: 130 | Basic width of the network, which is doubled at each layer. 131 | depth: 132 | Number of layers. 133 | center_depth: 134 | Depth of the central block in UNet. 135 | input_channels: 136 | Number of input channels. 137 | output_channels: 138 | Number of output channels (/classes). 139 | convolution: {'regular', 'depthwise_separable'} 140 | activation: {'relu', 'selu', 'elu'} 141 | Activation function. 142 | restore_weights: bool 143 | Whether to restore the model weights from `path_weights`. 144 | path_weights: str 145 | Path to the file with the model weights. 146 | kwargs: 147 | Catches redundant arguments and issues a warning.
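    Example
    -------
    A minimal shape check (illustrative):

    >>> model = UNetLext(basic_width=8, depth=3, input_channels=1,
    ...                  output_channels=5)
    >>> model(torch.rand(1, 1, 64, 64)).shape
    torch.Size([1, 5, 64, 64])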
148 | """ 149 | assert depth >= 2 150 | super().__init__() 151 | logger.warning('Redundant model init arguments:\n{}' 152 | .format(repr(kwargs))) 153 | 154 | # Preparing the modules dict 155 | modules = OrderedDict() 156 | 157 | modules['down1'] = Encoder(input_channels, basic_width, activation=activation) 158 | 159 | # Automatically creating the Encoder based on the depth and width 160 | for level in range(2, depth + 1): 161 | mul_in = 2 ** (level - 2) 162 | mul_out = 2 ** (level - 1) 163 | layer = Encoder(basic_width * mul_in, basic_width * mul_out, 164 | convolution=convolution, activation=activation) 165 | modules['down' + str(level)] = layer 166 | 167 | # Creating the center 168 | modules['center'] = nn.Sequential( 169 | *[block_conv_bn_act(basic_width * mul_out, basic_width * mul_out, 170 | convolution=convolution, activation=activation) 171 | for _ in range(center_depth)] 172 | ) 173 | 174 | # Automatically creating the decoder 175 | for level in reversed(range(2, depth + 1)): 176 | mul_in = 2 ** (level - 1) 177 | layer = Decoder(2 * basic_width * mul_in, basic_width * mul_in // 2, 178 | convolution=convolution, activation=activation) 179 | modules['up' + str(level)] = layer 180 | 181 | modules['up1'] = Decoder(basic_width * 2, basic_width * 2, 182 | convolution=convolution, activation=activation) 183 | 184 | modules['mixer'] = nn.Conv2d(basic_width * 2, output_channels, 185 | kernel_size=1, padding=0, stride=1, 186 | bias=True) 187 | 188 | self.__dict__['_modules'] = modules 189 | if pretrained: 190 | self.load_state_dict(torch.load(path_pretrained)) 191 | if restore_weights: 192 | self.load_state_dict(torch.load(path_weights)) 193 | 194 | def forward(self, x): 195 | encoded_results = {} 196 | 197 | out = x 198 | for name in self.__dict__['_modules']: 199 | if name.startswith('down'): 200 | layer = self.__dict__['_modules'][name] 201 | convolved, pooled = layer(out) 202 | encoded_results[name] = convolved 203 | out = pooled 204 | 205 | out = self.center(out) 206 | 207 | for name in self.__dict__['_modules']: 208 | if name.startswith('up'): 209 | layer = self.__dict__['_modules'][name] 210 | out = layer(encoded_results['down' + name[-1]], out) 211 | return self.mixer(out) 212 | -------------------------------------------------------------------------------- /rocaseg/models/unet_lext_aux.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamically created UNet with variable Width, Depth and activation 3 | 4 | Aleksei Tiulpin, University of Oulu, 2017 (c). 5 | 6 | """ 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from collections import OrderedDict 13 | 14 | 15 | logging.basicConfig() 16 | logger = logging.getLogger('models') 17 | logger.setLevel(logging.DEBUG) 18 | 19 | 20 | def ConvBlock3(inp, out, activation): 21 | """3x3 ConvNet building block with support for different activations. 22 | 23 | Aleksei Tiulpin, University of Oulu, 2017 (c). 24 | """ 25 | if activation == 'relu': 26 | return nn.Sequential( 27 | nn.Conv2d(inp, out, kernel_size=3, padding=1), 28 | nn.BatchNorm2d(out), 29 | nn.ReLU(inplace=True) 30 | ) 31 | elif activation == 'selu': 32 | return nn.Sequential( 33 | nn.Conv2d(inp, out, kernel_size=3, padding=1), 34 | nn.SELU(inplace=True) 35 | ) 36 | elif activation == 'elu': 37 | return nn.Sequential( 38 | nn.Conv2d(inp, out, kernel_size=3, padding=1), 39 | nn.ELU(1, inplace=True) 40 | ) 41 | 42 | 43 | class Encoder(nn.Module): 44 | """Encoder class
for encoder-decoder architecture. 45 | 46 | Aleksei Tiulpin, University of Oulu, 2017 (c). 47 | """ 48 | def __init__(self, input_channels, output_channels, depth=2, activation='relu'): 49 | super().__init__() 50 | self.layers = nn.Sequential() 51 | for i in range(depth): 52 | tmp = [] 53 | if i == 0: 54 | tmp.append(ConvBlock3(input_channels, output_channels, activation)) 55 | else: 56 | tmp.append(ConvBlock3(output_channels, output_channels, activation)) 57 | 58 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp)) 59 | 60 | def forward(self, x): 61 | processed = self.layers(x) 62 | pooled = F.max_pool2d(processed, 2, 2) 63 | return processed, pooled 64 | 65 | 66 | class Decoder(nn.Module): 67 | """Decoder class for encoder-decoder architecture. 68 | 69 | Aleksei Tiulpin, University of Oulu, 2017 (c). 70 | """ 71 | def __init__(self, input_channels, output_channels, depth=2, mode='bilinear', 72 | activation='relu'): 73 | super().__init__() 74 | self.layers = nn.Sequential() 75 | self.ups_mode = mode 76 | for i in range(depth): 77 | tmp = [] 78 | if i == 0: 79 | tmp.append(ConvBlock3(input_channels, output_channels, activation)) 80 | else: 81 | tmp.append(ConvBlock3(output_channels, output_channels, activation)) 82 | 83 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp)) 84 | 85 | def forward(self, x_big, x): 86 | x_ups = F.interpolate(x, size=x_big.size()[-2:], mode=self.ups_mode, 87 | align_corners=True) 88 | y_cat = torch.cat([x_ups, x_big], 1) 89 | y = self.layers(y_cat) 90 | return y 91 | 92 | 93 | class AuxModule(nn.Module): 94 | def __init__(self, input_channels, output_channels, dilation_series, padding_series): 95 | super(AuxModule, self).__init__() 96 | self.layers = nn.ModuleList() 97 | for dilation, padding in zip(dilation_series, padding_series): 98 | self.layers.append( 99 | nn.Conv2d(input_channels, output_channels, 100 | kernel_size=3, stride=1, padding=padding, 101 | dilation=dilation, bias=True)) 102 | 103 | for m in self.layers: 104 | m.weight.data.normal_(0, 0.01) 105 | 106 | def forward(self, x): 107 | out = self.layers[0](x) 108 | for i in range(len(self.layers) - 1): 109 | out += self.layers[i + 1](x) 110 | return out 111 | 112 | 113 | class UNetLextAux(nn.Module): 114 | """UNet architecture with 3x3 convolutions. Created dynamically based on depth and width. 115 | 116 | Aleksei Tiulpin, 2017 (c) 117 | """ 118 | def __init__(self, basic_width=24, depth=6, center_depth=2, 119 | input_channels=3, output_channels=1, activation='relu', 120 | pretrained=False, path_pretrained=None, 121 | restore_weights=False, path_weights=None, 122 | with_aux=True, **kwargs): 123 | """ 124 | 125 | Aleksei Tiulpin, University of Oulu, 2017 (c). 126 | 127 | Parameters 128 | ---------- 129 | basic_width: 130 | Basic width of the network, which is doubled at each layer. 131 | depth: 132 | Number of layers. 133 | center_depth: 134 | Depth of the central block in UNet. 135 | input_channels: 136 | Number of input channels. 137 | output_channels: 138 | Number of output channels (/classes). 139 | activation: {'relu', 'selu', 'elu'} 140 | Activation function. 141 | restore_weights: bool 142 | Whether to restore the model weights from `path_weights`. 143 | path_weights: str 144 | Path to the file with the model weights.
145 | kwargs: 146 | """ 147 | assert depth >= 2 148 | super().__init__() 149 | logger.warning('Redundant model init arguments:\n{}' 150 | .format(repr(kwargs))) 151 | self._with_aux = with_aux 152 | 153 | # Preparing the modules dict 154 | modules = OrderedDict() 155 | 156 | modules['down1'] = Encoder(input_channels, basic_width, activation=activation) 157 | 158 | # Automatically creating the Encoder based on the depth and width 159 | for level in range(2, depth + 1): 160 | mul_in = 2 ** (level - 2) 161 | mul_out = 2 ** (level - 1) 162 | layer = Encoder(basic_width * mul_in, basic_width * mul_out, 163 | activation=activation) 164 | modules['down' + str(level)] = layer 165 | 166 | # Creating the center 167 | modules['center'] = nn.Sequential( 168 | *[ConvBlock3(basic_width * mul_out, basic_width * mul_out, 169 | activation=activation) 170 | for i in range(center_depth)] 171 | ) 172 | 173 | # Automatically creating the decoder 174 | for level in reversed(range(2, depth + 1)): 175 | mul_in = 2 ** (level - 1) 176 | layer = Decoder(2 * basic_width * mul_in, basic_width * mul_in // 2, 177 | activation=activation) 178 | modules['up' + str(level)] = layer 179 | 180 | if self._with_aux and (level == 2): 181 | modules['aux'] = AuxModule( 182 | input_channels=basic_width * mul_in // 2, 183 | output_channels=output_channels, 184 | dilation_series=[6, 12, 18, 24], 185 | padding_series=[6, 12, 18, 24]) 186 | 187 | modules['up1'] = Decoder(basic_width * 2, basic_width * 2, 188 | activation=activation) 189 | 190 | modules['mixer'] = nn.Conv2d(basic_width * 2, output_channels, 191 | kernel_size=1, padding=0, stride=1, 192 | bias=True) 193 | 194 | self.__dict__['_modules'] = modules 195 | if pretrained: 196 | self.load_state_dict(torch.load(path_pretrained)) 197 | if restore_weights: 198 | self.load_state_dict(torch.load(path_weights)) 199 | 200 | def forward(self, x): 201 | encoded_results = {} 202 | 203 | out = x 204 | for name in self.__dict__['_modules']: 205 | if name.startswith('down'): 206 | layer = self.__dict__['_modules'][name] 207 | convolved, pooled = layer(out) 208 | encoded_results[name] = convolved 209 | out = pooled 210 | 211 | out = self.center(out) 212 | 213 | for name in self.__dict__['_modules']: 214 | if name.startswith('up'): 215 | layer = self.__dict__['_modules'][name] 216 | out = layer(encoded_results['down' + name[-1]], out) 217 | 218 | if name == 'up2': 219 | out_aux = out 220 | 221 | if self._with_aux: 222 | out_aux = self.aux(out_aux) 223 | out_aux = F.interpolate(out_aux, size=x.size()[-2:], 224 | mode='bilinear', align_corners=True) 225 | out_main = self.mixer(out) 226 | return out_main, out_aux 227 | else: 228 | return self.mixer(out) 229 | -------------------------------------------------------------------------------- /rocaseg/datasets/prepare_dataset_oai_imo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from collections import defaultdict 4 | 5 | import click 6 | from joblib import Parallel, delayed 7 | from tqdm import tqdm 8 | 9 | import numpy as np 10 | from sas7bdat import SAS7BDAT 11 | from scipy import io 12 | import pandas as pd 13 | 14 | import pydicom 15 | import cv2 16 | 17 | from rocaseg.datasets.constants import locations_mh53 18 | from rocaseg.datasets.meta_oai import side_code_to_str, release_to_prefix_var 19 | 20 | 21 | cv2.ocl.setUseOpenCL(False) 22 | 23 | 24 | def read_dicom(fname): 25 | data = pydicom.read_file(fname) 26 | image = np.frombuffer(data.PixelData, 
dtype=np.uint16).astype(float) 27 | 28 | if data.PhotometricInterpretation == 'MONOCHROME1': 29 | image = image.max() - image 30 | image = image.reshape((data.Rows, data.Columns)) 31 | 32 | if 'RIGHT' in data.SeriesDescription: 33 | side = 'RIGHT' 34 | elif 'LEFT' in data.SeriesDescription: 35 | side = 'LEFT' 36 | else: 37 | print(data) 38 | msg = f'DICOM {fname} does not contain side info' 39 | raise ValueError(msg) 40 | 41 | if hasattr(data, 'ImagerPixelSpacing'): 42 | spacing = [float(e) for e in data.ImagerPixelSpacing[:2]] 43 | elif hasattr(data, 'PixelSpacing'): 44 | spacing = [float(e) for e in data.PixelSpacing[:2]] 45 | else: 46 | msg = f'DICOM {fname} does not contain spacing info' 47 | raise AttributeError(msg) 48 | 49 | return (image, 50 | spacing[0], 51 | spacing[1], 52 | float(data.SliceThickness), 53 | side) 54 | 55 | 56 | def mask_from_mat(masks_mat, mask_shape, slice_idx, attr_name): 57 | mask = np.zeros(mask_shape, dtype=np.uint8) 58 | data = getattr(masks_mat[0][slice_idx], attr_name) 59 | if len(data.shape) > 0: 60 | for comp in range(data.shape[1]): 61 | cnt = data[0, comp][:, :2].copy() 62 | cnt[:, 1] = mask_shape[0] - cnt[:, 1] 63 | cntf = cnt.astype(np.int) 64 | cv2.drawContours(mask, [cntf], -1, (255, 255, 255), -1) 65 | mask = (mask > 0).astype(np.uint8) 66 | return mask 67 | 68 | 69 | @click.command() 70 | @click.argument('path_root_oai_mri') 71 | @click.argument('path_root_imo') 72 | @click.argument('path_root_output') 73 | @click.option('--num_threads', default=12, type=click.IntRange(0, 16)) 74 | @click.option('--margin', default=0, type=int) 75 | @click.option('--meta_only', is_flag=True) 76 | def main(**config): 77 | config['path_root_oai_mri'] = os.path.abspath(config['path_root_oai_mri']) 78 | config['path_root_imo'] = os.path.abspath(config['path_root_imo']) 79 | config['path_root_output'] = os.path.abspath(config['path_root_output']) 80 | 81 | # ------------------------------------------------------------------------- 82 | def worker_xz9(path_root_output, path_stack, margin): 83 | meta = defaultdict(list) 84 | 85 | release, patient = path_stack.split('/')[-4:-2] 86 | prefix_var = release_to_prefix_var[release] 87 | sequence = 'sag_3d_dess_we' 88 | 89 | path_annot = os.path.join(config['path_root_imo'], patient, prefix_var) 90 | fnames_annot = glob(os.path.join(path_annot, '*.mat')) 91 | if len(fnames_annot) != 1: 92 | raise ValueError(f'Unexpected annotations for patient: {patient}') 93 | fname_annot = fnames_annot[0] 94 | 95 | file_mat = io.loadmat(os.path.join(path_annot, fname_annot), 96 | struct_as_record=False) 97 | masks_mat = file_mat['datastruct'] 98 | num_slices = masks_mat.shape[1] 99 | 100 | for slice_idx in range(num_slices): 101 | # Indexing of slices in OAI dataset starts with 001 102 | fname_src = os.path.join(path_stack, '{:>03}'.format(slice_idx+1)) 103 | image, *dicom_meta = read_dicom(fname_src) 104 | 105 | side = dicom_meta[3] 106 | 107 | mask_proc = np.zeros_like(image) 108 | 109 | # NOTICE: Reference masks have some collisions. We solve them 110 | # by prioritising the tissues which are earlier in the list. 
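# (A sketch of the idea: iterating over `reversed(locations_mh53.items())`
# paints the lowest-priority tissues first, so any later, higher-priority
# tissue overwrites the overlapping pixels.)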
111 | for part_name, part_value in reversed(locations_mh53.items()): 112 | # Skip the background as it is not presented in the source data 113 | if part_name == 'Background': 114 | continue 115 | try: 116 | mask_temp = mask_from_mat(masks_mat, image.shape, 117 | slice_idx, part_name) 118 | mask_proc[mask_temp > 0] = part_value 119 | except AttributeError: 120 | print(f'Error accessing {part_name} in {fname_src}') 121 | 122 | if margin != 0: 123 | image = image[margin:-margin, margin:-margin] 124 | mask_proc = mask_proc[margin:-margin, margin:-margin] 125 | 126 | fname_pattern = '{slice_idx:>03}.{ext}' 127 | 128 | # Save image and mask data 129 | dir_rel_image = os.path.join(patient, release, sequence, 'images') 130 | dir_rel_mask = os.path.join(patient, release, sequence, 'masks') 131 | dir_abs_image = os.path.join(path_root_output, dir_rel_image) 132 | dir_abs_mask = os.path.join(path_root_output, dir_rel_mask) 133 | for d in (dir_abs_image, dir_abs_mask): 134 | if not os.path.exists(d): 135 | os.makedirs(d) 136 | 137 | fname_image = fname_pattern.format(slice_idx=slice_idx, ext='png') 138 | path_abs_image = os.path.join(dir_abs_image, fname_image) 139 | if not config['meta_only']: 140 | cv2.imwrite(path_abs_image, image) 141 | 142 | fname_mask = fname_pattern.format(slice_idx=slice_idx, ext='png') 143 | path_abs_mask = os.path.join(dir_abs_mask, fname_mask) 144 | if not config['meta_only']: 145 | cv2.imwrite(path_abs_mask, mask_proc) 146 | 147 | path_rel_image = os.path.join(dir_rel_image, fname_image) 148 | path_rel_mask = os.path.join(dir_rel_mask, fname_mask) 149 | 150 | meta['patient'].append(patient) 151 | meta['release'].append(release) 152 | meta['prefix_var'].append(prefix_var) 153 | meta['sequence'].append(sequence) 154 | meta['side'].append(side) 155 | meta['slice_idx'].append(slice_idx) 156 | meta['pixel_spacing_0'].append(dicom_meta[0]) 157 | meta['pixel_spacing_1'].append(dicom_meta[1]) 158 | meta['slice_thickness'].append(dicom_meta[2]) 159 | meta['path_rel_image'].append(path_rel_image) 160 | meta['path_rel_mask'].append(path_rel_mask) 161 | return meta 162 | # ------------------------------------------------------------------------- 163 | 164 | # OAI data path structure: 165 | # root / examination / release / patient / date / barcode (/ slices) 166 | paths_stacks = glob(os.path.join(config['path_root_oai_mri'], '**/**/**/**/**')) 167 | paths_stacks.sort(key=lambda x: int(x.split('/')[-3])) 168 | 169 | metas = Parallel(config['num_threads'])(delayed(worker_xz9)( 170 | *[config['path_root_output'], path_stack, config['margin']] 171 | ) for path_stack in tqdm(paths_stacks)) 172 | 173 | # Merge meta information from different stacks 174 | tmp = defaultdict(list) 175 | for d in metas: 176 | for k, v in d.items(): 177 | tmp[k].extend(v) 178 | df_out = pd.DataFrame.from_dict(tmp) 179 | 180 | # Find the grading data 181 | fnames_sas = glob(os.path.join(config['path_root_oai_mri'], 182 | '*', '*.sas7bdat'), recursive=True) 183 | 184 | # Read semi-quantitative data 185 | dfs = dict() 186 | for fn in fnames_sas: 187 | with SAS7BDAT(fn) as f: 188 | raw = [r for r in f] 189 | tmp = pd.DataFrame(raw[1:], columns=raw[0]) 190 | 191 | prefix_var = [c for c in tmp.columns if c.endswith('XRKL')][0][:3] 192 | 193 | tmp = tmp.rename(lambda x: x.upper(), axis=1) 194 | tmp = tmp.rename({'VERSION': f'{prefix_var}VERSION', 195 | 'ID': 'patient', 196 | 'SIDE': 'side'}, axis=1) 197 | 198 | tmp['side'] = tmp['side'].apply(lambda s: side_code_to_str[s]) 199 | dfs.update({prefix_var: tmp}) 200 | 201 | 
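# Sketch of the merge below (illustrative): each `dfs[prefix]` gets indexed by
# (patient, side, READPRJ), so `pd.concat(dfs.values(), axis=1)` aligns the
# per-release grading columns (e.g. V00XRKL, V01XRKL) on that shared index.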
# Set the index to join on 202 | for k, tmp in dfs.items(): 203 | dfs[k] = tmp.set_index(['patient', 'side', 'READPRJ']) 204 | 205 | df = pd.concat(dfs.values(), axis=1) 206 | df = df.reset_index() 207 | 208 | # Remove unnecessary columns and reformat the grading info 209 | df_sel = df[['patient', 'side', 'V00XRKL', 'V01XRKL']] 210 | df_sel = (df_sel 211 | .set_index(['patient', 'side']) 212 | .rename({'V00XRKL': 'V00', 'V01XRKL': 'V01'}, axis=1) 213 | .stack() 214 | .reset_index() 215 | .rename({'level_2': 'prefix_var', 0: 'KL'}, axis=1)) 216 | 217 | # Select the subset for which the annotations are available 218 | indexers = ['patient', 'side', 'prefix_var'] 219 | sel = df_out.set_index(indexers).index.unique() 220 | df_sel = (df_sel 221 | .drop_duplicates(subset=indexers) # There are ~5 duplicates 222 | .set_index(indexers) 223 | .loc[sel, :] 224 | .reset_index()) 225 | 226 | df_out = pd.merge(df_out, df_sel, on=indexers, how='left') 227 | 228 | path_output_meta = os.path.join(config['path_root_output'], 'meta_base.csv') 229 | df_out.to_csv(path_output_meta, index=False) 230 | 231 | 232 | if __name__ == '__main__': 233 | main() 234 | -------------------------------------------------------------------------------- /scripts/runner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # -------------------------------- Prepare datasets ------------------------------------ 4 | (cd ../rocaseg/datasets && 5 | echo "prepare_dataset" && 6 | python prepare_dataset_oai_imo.py \ 7 | ../../../data_raw/OAI_iMorphics_scans \ 8 | ../../../data_raw/OAI_iMorphics_annotations \ 9 | ../../../data/91_OAI_iMorphics_full_meta \ 10 | --margin 20 \ 11 | ) 12 | 13 | (cd ../rocaseg/datasets && 14 | echo "prepare_dataset" && 15 | python prepare_dataset_okoa.py \ 16 | ../../../data_raw/OKOA \ 17 | ../../../data/31_OKOA_full_meta \ 18 | --margin 0 \ 19 | ) 20 | 21 | (cd ../rocaseg/datasets && 22 | echo "prepare_dataset" && 23 | python prepare_dataset_maknee.py \ 24 | ../../../data_raw/MAKNEE \ 25 | ../../../data/41_MAKNEE_full_meta \ 26 | --margin 0 \ 27 | ) 28 | # -------------------------------------------------------------------------------------- 29 | 30 | # -------------------------------- Resample datasets ----------------------------------- 31 | (cd ../rocaseg/ && 32 | echo "resample" && 33 | python resample.py \ 34 | --path_root_in ../../data/31_OKOA_full_meta \ 35 | --spacing_in 0.5859375 0.5859375 \ 36 | --path_root_out ../../data/32_OKOA_full_meta_rescaled \ 37 | --spacing_out 0.36458333 0.36458333 \ 38 | --num_threads 12 \ 39 | --margin 0 \ 40 | ) 41 | 42 | (cd ../rocaseg/ && 43 | echo "resample" && 44 | python resample.py \ 45 | --path_root_in ../../data/41_MAKNEE_full_meta \ 46 | --spacing_in 0.5859375 0.5859375 \ 47 | --path_root_out ../../data/42_MAKNEE_full_meta_rescaled \ 48 | --spacing_out 0.36458333 0.36458333 \ 49 | --num_threads 12 \ 50 | --margin 0 \ 51 | ) 52 | # -------------------------------------------------------------------------------------- 53 | 54 | # --------------------------------- Train models --------------------------------------- 55 | (cd ../rocaseg/ && 56 | echo "train" && 57 | python train_baseline.py \ 58 | --path_data_root ../../data \ 59 | --path_experiment_root ../../results/0_baseline \ 60 | --model_segm unet_lext \ 61 | --input_channels 1 \ 62 | --output_channels 5 \ 63 | --center_depth 1 \ 64 | --lr_segm 0.001 \ 65 | --batch_size 32 \ 66 | --epoch_num 50 \ 67 | --fold_num 5 \ 68 | --fold_idx -1 \ 69 | 
--num_workers 12 \ 70 | ) 71 | 72 | (cd ../rocaseg/ && 73 | echo "train" && 74 | python train_baseline.py \ 75 | --path_data_root ../../data \ 76 | --path_experiment_root ../../results/1_mixup \ 77 | --model_segm unet_lext \ 78 | --input_channels 1 \ 79 | --output_channels 5 \ 80 | --center_depth 1 \ 81 | --lr_segm 0.001 \ 82 | --batch_size 32 \ 83 | --epoch_num 50 \ 84 | --fold_num 5 \ 85 | --fold_idx -1 \ 86 | --with_mixup \ 87 | --mixup_alpha 0.7 \ 88 | --num_workers 12 \ 89 | ) 90 | 91 | (cd ../rocaseg/ && 92 | echo "train" && 93 | python train_baseline.py \ 94 | --path_data_root ../../data \ 95 | --path_experiment_root ../../results/2_mixup_nowd \ 96 | --model_segm unet_lext \ 97 | --input_channels 1 \ 98 | --output_channels 5 \ 99 | --center_depth 1 \ 100 | --lr_segm 0.001 \ 101 | --wd_segm 0.0 \ 102 | --batch_size 32 \ 103 | --epoch_num 50 \ 104 | --fold_num 5 \ 105 | --fold_idx -1 \ 106 | --with_mixup \ 107 | --mixup_alpha 0.7 \ 108 | --num_workers 12 \ 109 | ) 110 | 111 | (cd ../rocaseg/ && 112 | echo "train" && 113 | python train_uda1.py \ 114 | --path_data_root ../../data \ 115 | --path_experiment_root ../../results/3_uda1 \ 116 | --model_segm unet_lext \ 117 | --center_depth 1 \ 118 | --model_discr discriminator_a \ 119 | --input_channels 1 \ 120 | --output_channels 5 \ 121 | --mask_mode all_unitibial_unimeniscus \ 122 | --loss_segm multi_ce_loss \ 123 | --lr_segm 0.0001 \ 124 | --lr_discr 0.00004 \ 125 | --batch_size 32 \ 126 | --epoch_num 30 \ 127 | --fold_num 5 \ 128 | --fold_idx 0 \ 129 | --num_workers 12 \ 130 | ) 131 | 132 | (cd ../rocaseg/ && 133 | echo "train" && 134 | python train_uda1.py \ 135 | --path_data_root ../../data \ 136 | --path_experiment_root ../../results/4_uda2 \ 137 | --model_segm unet_lext_aux \ 138 | --center_depth 1 \ 139 | --model_discr_out discriminator_a \ 140 | --model_discr_aux discriminator_a \ 141 | --input_channels 1 \ 142 | --output_channels 5 \ 143 | --mask_mode all_unitibial_unimeniscus \ 144 | --loss_segm multi_ce_loss \ 145 | --lr_segm 0.0001 \ 146 | --lr_discr 0.00004 \ 147 | --batch_size 32 \ 148 | --epoch_num 30 \ 149 | --fold_num 5 \ 150 | --fold_idx 0 \ 151 | --num_workers 12 \ 152 | ) 153 | 154 | (cd ../rocaseg/ && 155 | echo "train" && 156 | python train_uda1.py \ 157 | --path_data_root ../../data \ 158 | --path_experiment_root ../../results/5_uda1_mixup_nowd \ 159 | --model_segm unet_lext \ 160 | --center_depth 1 \ 161 | --model_discr discriminator_a \ 162 | --input_channels 1 \ 163 | --output_channels 5 \ 164 | --mask_mode all_unitibial_unimeniscus \ 165 | --loss_segm multi_ce_loss \ 166 | --lr_segm 0.0001 \ 167 | --lr_discr 0.00004 \ 168 | --wd_segm 0.0 \ 169 | --batch_size 32 \ 170 | --epoch_num 30 \ 171 | --fold_num 5 \ 172 | --fold_idx 0 \ 173 | --num_workers 12 \ 174 | --with_mixup \ 175 | --mixup_alpha 0.7 \ 176 | ) 177 | # -------------------------------------------------------------------------------------- 178 | 179 | # ------------------------------- Run model inference ---------------------------------- 180 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 5_uda1_mixup_nowd 181 | do 182 | (cd ../rocaseg/ && 183 | echo "evaluate" && 184 | python evaluate.py \ 185 | --path_data_root ../../data \ 186 | --path_experiment_root ../../results/${EXP} \ 187 | --model_segm unet_lext \ 188 | --center_depth 1 \ 189 | --restore_weights \ 190 | --output_channels 5 \ 191 | --dataset oai_imo \ 192 | --subset test \ 193 | --mask_mode all_unitibial_unimeniscus \ 194 | --batch_size 64 \ 195 | --fold_num 5 \ 196 | --fold_idx -1 \ 197 | 
--num_workers 12 \ 198 | --predict_folds \ 199 | --merge_predictions \ 200 | ) 201 | done 202 | 203 | for EXP in 4_uda2 204 | do 205 | (cd ../rocaseg/ && 206 | echo "evaluate" && 207 | python evaluate.py \ 208 | --path_data_root ../../data \ 209 | --path_experiment_root ../../results/${EXP} \ 210 | --model_segm unet_lext_aux \ 211 | --center_depth 1 \ 212 | --restore_weights \ 213 | --output_channels 5 \ 214 | --dataset oai_imo \ 215 | --subset test \ 216 | --mask_mode all_unitibial_unimeniscus \ 217 | --batch_size 64 \ 218 | --fold_num 5 \ 219 | --fold_idx -1 \ 220 | --num_workers 12 \ 221 | --predict_folds \ 222 | --merge_predictions \ 223 | ) 224 | done 225 | 226 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 5_uda1_mixup_nowd 227 | do 228 | (cd ../rocaseg/ && 229 | echo "evaluate" && 230 | python evaluate.py \ 231 | --path_data_root ../../data \ 232 | --path_experiment_root ../../results/${EXP} \ 233 | --model_segm unet_lext \ 234 | --center_depth 1 \ 235 | --restore_weights \ 236 | --output_channels 5 \ 237 | --dataset okoa \ 238 | --subset all \ 239 | --mask_mode background_femoral_unitibial \ 240 | --batch_size 64 \ 241 | --fold_num 5 \ 242 | --fold_idx -1 \ 243 | --num_workers 12 \ 244 | --predict_folds \ 245 | --merge_predictions \ 246 | ) 247 | done 248 | 249 | for EXP in 4_uda2 250 | do 251 | (cd ../rocaseg/ && 252 | echo "evaluate" && 253 | python evaluate.py \ 254 | --path_data_root ../../data \ 255 | --path_experiment_root ../../results/${EXP} \ 256 | --model_segm unet_lext_aux \ 257 | --center_depth 1 \ 258 | --restore_weights \ 259 | --output_channels 5 \ 260 | --dataset okoa \ 261 | --subset all \ 262 | --mask_mode background_femoral_unitibial \ 263 | --batch_size 64 \ 264 | --fold_num 5 \ 265 | --fold_idx -1 \ 266 | --num_workers 12 \ 267 | --predict_folds \ 268 | --merge_predictions \ 269 | ) 270 | done 271 | # -------------------------------------------------------------------------------------- 272 | 273 | # ------------------------------- Analyze model predictions ---------------------------- 274 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 4_uda2 5_uda1_mixup_nowd 275 | do 276 | (cd ../rocaseg/ && 277 | echo "analyze_predictions" && 278 | python analyze_predictions_single.py \ 279 | --path_experiment_root ../../results/${EXP} \ 280 | --dirname_pred mask_foldavg \ 281 | --dirname_true mask_prep \ 282 | --dataset oai_imo \ 283 | --atlas segm \ 284 | --ignore_cache \ 285 | --num_workers 12 \ 286 | ) 287 | done 288 | 289 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 4_uda2 5_uda1_mixup_nowd 290 | do 291 | (cd ../rocaseg/ && 292 | echo "analyze_predictions" && 293 | python analyze_predictions_single.py \ 294 | --path_experiment_root ../../results/${EXP} \ 295 | --dirname_pred mask_foldavg \ 296 | --dirname_true mask_prep \ 297 | --dataset okoa \ 298 | --atlas okoa \ 299 | --ignore_cache \ 300 | --num_workers 12 \ 301 | ) 302 | done 303 | # -------------------------------------------------------------------------------------- 304 | 305 | # --------------------------------- Compare models ------------------------------------- 306 | (cd ../rocaseg/ && 307 | echo "analyze_predictions_multi" && 308 | python analyze_predictions_multi.py \ 309 | --path_results_root ../../results \ 310 | --experiment_id 0_baseline \ 311 | --experiment_id 2_mixup_nowd \ 312 | --experiment_id 4_uda2 \ 313 | --dataset oai_imo \ 314 | --atlas segm \ 315 | --num_workers 12 \ 316 | ) 317 | 318 | (cd ../rocaseg/ && 319 | echo "analyze_predictions_multi" && 320 | python 
analyze_predictions_multi.py \ 321 | --path_results_root ../../results \ 322 | --experiment_id 0_baseline \ 323 | --experiment_id 2_mixup_nowd \ 324 | --experiment_id 4_uda2 \ 325 | --dataset okoa \ 326 | --atlas okoa \ 327 | --num_workers 12 \ 328 | ) 329 | # -------------------------------------------------------------------------------------- 330 | -------------------------------------------------------------------------------- /rocaseg/preproc/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import cv2 3 | import numpy as np 4 | import math 5 | import torch 6 | 7 | 8 | class DualCompose(object): 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | 12 | def __call__(self, img, mask=None): 13 | for t in self.transforms: 14 | img, mask = t(img, mask) 15 | return img, mask 16 | 17 | 18 | class OneOf(object): 19 | def __init__(self, transforms, prob=.5): 20 | self.transforms = transforms 21 | self.prob = prob 22 | 23 | self.state = dict() 24 | self.randomize() 25 | 26 | def __call__(self, img, mask=None): 27 | if self.state['p'] < self.prob: 28 | img, mask = self.state['t'](img, mask) 29 | return img, mask 30 | 31 | def randomize(self): 32 | self.state['p'] = random.random() 33 | self.state['t'] = random.choice(self.transforms) 34 | self.state['t'].prob = 1. 35 | 36 | 37 | class OneOrOther(object): 38 | def __init__(self, first, second, prob=.5): 39 | self.first = first 40 | first.prob = 1. 41 | self.second = second 42 | second.prob = 1. 43 | self.prob = prob 44 | 45 | self.state = dict() 46 | self.randomize() 47 | 48 | def __call__(self, x, mask=None): 49 | if self.state['p'] < self.prob: 50 | x, mask = self.first(x, mask) 51 | else: 52 | x, mask = self.second(x, mask) 53 | return x, mask 54 | 55 | def randomize(self): 56 | self.state['p'] = random.random() 57 | 58 | 59 | class ImageOnly(object): 60 | def __init__(self, transform): 61 | self.transform = transform 62 | 63 | def __call__(self, img, mask=None): 64 | return self.transform(img, None)[0], mask 65 | 66 | 67 | class NoTransform(object): 68 | def __call__(self, *args): 69 | return args 70 | 71 | 72 | class ToTensor(object): 73 | def __call__(self, *args): 74 | return [torch.from_numpy(e) for e in args] 75 | 76 | 77 | class VerticalFlip(object): 78 | def __init__(self, prob=.5): 79 | self.prob = prob 80 | 81 | self.state = dict() 82 | self.randomize() 83 | 84 | def __call__(self, img, mask=None): 85 | """ 86 | 87 | Parameters 88 | ---------- 89 | img: (ch, d0, d1) ndarray 90 | mask: (ch, d0, d1) ndarray 91 | """ 92 | if self.state['p'] < self.prob: 93 | img = np.flip(img, axis=1) 94 | if mask is not None: 95 | mask = np.flip(mask, axis=1) 96 | return img, mask 97 | 98 | def randomize(self): 99 | self.state['p'] = random.random() 100 | 101 | 102 | class HorizontalFlip(object): 103 | def __init__(self, prob=.5): 104 | self.prob = prob 105 | 106 | self.state = dict() 107 | self.randomize() 108 | 109 | def __call__(self, img, mask=None): 110 | """ 111 | 112 | Parameters 113 | ---------- 114 | img: (ch, d0, d1) ndarray 115 | mask: (ch, d0, d1) ndarray 116 | """ 117 | if self.state['p'] < self.prob: 118 | img = np.flip(img, axis=2) 119 | if mask is not None: 120 | mask = np.flip(mask, axis=2) 121 | return img, mask 122 | 123 | def randomize(self): 124 | self.state['p'] = random.random() 125 | 126 | 127 | class Flip(object): 128 | def __init__(self, prob=.5): 129 | self.prob = prob 130 | 131 | self.state = dict() 132 | self.randomize() 133 | 134 | 
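    # NOTE (editor): every random transform in this module follows the same
    # pattern: `randomize()` pre-samples the random state into `self.state`,
    # and `__call__` then applies the *same* sampled geometry to both `img`
    # and `mask`. A minimal, hypothetical usage sketch (the variables `img`
    # and `mask` are placeholders, not names from this file):
    #
    #     pipeline = DualCompose([HorizontalFlip(prob=.5),
    #                             Crop(output_size=300)])
    #     for t in pipeline.transforms:
    #         if hasattr(t, 'randomize'):
    #             t.randomize()  # re-draw the randomness once per sample
    #     img_aug, mask_aug = pipeline(img, mask)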
def __call__(self, img, mask=None):
135 |         """
136 | 
137 |         Parameters
138 |         ----------
139 |         img: (ch, d0, d1) ndarray
140 |         mask: (ch, d0, d1) ndarray
141 |         """
142 |         if self.state['p'] < self.prob:
143 |             if self.state['d'] in (-1, 0):
144 |                 img = np.flip(img, axis=1)
145 |             if self.state['d'] in (-1, 1):
146 |                 img = np.flip(img, axis=2)
147 |             if mask is not None:
148 |                 if self.state['d'] in (-1, 0):
149 |                     mask = np.flip(mask, axis=1)
150 |                 if self.state['d'] in (-1, 1):
151 |                     mask = np.flip(mask, axis=2)
152 |         return img, mask
153 | 
154 |     def randomize(self):
155 |         self.state['p'] = random.random()
156 |         self.state['d'] = random.randint(-1, 1)
157 | 
158 | 
159 | class Scale(object):
160 |     def __init__(self, ratio_range=(0.7, 1.2), prob=.5):
161 |         self.ratio_range = ratio_range
162 |         self.prob = prob
163 | 
164 |         self.state = dict()
165 |         self.randomize()
166 | 
167 |     def __call__(self, img, mask=None):
168 |         """
169 | 
170 |         Parameters
171 |         ----------
172 |         img: (ch, d0, d1) ndarray
173 |         mask: (ch, d0, d1) ndarray
174 |         """
175 |         if self.state['p'] < self.prob:
176 |             ch, d0_i, d1_i = img.shape
177 |             d0_o = math.floor(d0_i * self.state['r'])
178 |             d0_o = d0_o + d0_o % 2
179 |             d1_o = math.floor(d1_i * self.state['r'])
180 |             d1_o = d1_o + d1_o % 2
181 | 
182 |             # img1 = cv2.copyMakeBorder(img, limit, limit, limit, limit,
183 |             #                           borderType=cv2.BORDER_REFLECT_101)
184 |             img = np.squeeze(img)
185 |             img = cv2.resize(img, (d1_o, d0_o), interpolation=cv2.INTER_LINEAR)
186 |             img = img[None, ...]
187 | 
188 |             if mask is not None:
189 |                 # msk1 = cv2.copyMakeBorder(mask, limit, limit, limit, limit,
190 |                 #                           borderType=cv2.BORDER_REFLECT_101)
191 |                 tmp = np.empty((mask.shape[0], d0_o, d1_o), dtype=mask.dtype)
192 |                 for idx_ch, mask_ch in enumerate(mask):
193 |                     tmp[idx_ch] = cv2.resize(mask_ch, (d1_o, d0_o),
194 |                                              interpolation=cv2.INTER_NEAREST)
195 |                 mask = tmp
196 |         return img, mask
197 | 
198 |     def randomize(self):
199 |         self.state['p'] = random.random()
200 |         self.state['r'] = round(random.uniform(*self.ratio_range), 2)
201 | 
202 | 
203 | class Crop(object):
204 |     def __init__(self, output_size):
205 |         assert isinstance(output_size, (int, tuple))
206 |         if isinstance(output_size, int):
207 |             self.output_size = (output_size, output_size)
208 |         elif isinstance(output_size, tuple):
209 |             self.output_size = output_size
210 |         else:
211 |             raise ValueError('Incorrect `output_size` value')
212 |         # self.keep_size = keep_size
213 |         # self.prob = prob
214 | 
215 |         self.state = dict()
216 |         self.randomize()
217 | 
218 |     def __call__(self, img, mask=None):
219 |         rows_in, cols_in = img.shape[1:]
220 |         rows_out, cols_out = self.output_size
221 |         rows_out = min(rows_in, rows_out)
222 |         cols_out = min(cols_in, cols_out)
223 | 
224 |         r0 = math.floor(self.state['r0f'] * (rows_in - rows_out))
225 |         c0 = math.floor(self.state['c0f'] * (cols_in - cols_out))
226 |         r1 = r0 + rows_out
227 |         c1 = c0 + cols_out
228 | 
229 |         img = np.ascontiguousarray(img[:, r0:r1, c0:c1])
230 |         if mask is not None:
231 |             mask = np.ascontiguousarray(mask[:, r0:r1, c0:c1])
232 |         return img, mask
233 | 
234 |     def randomize(self):
235 |         # self.state['p'] = random.random()
236 |         self.state['r0f'] = random.random()
237 |         self.state['c0f'] = random.random()
238 | 
239 | 
240 | class CenterCrop(object):
241 |     def __init__(self, height, width):
242 |         self.height = height
243 |         self.width = width
244 | 
245 |     def __call__(self, img, mask=None):
246 |         """
247 | 
248 |         Parameters
249 |         ----------
250 |         img: (ch, d0, d1) ndarray
251 |         mask: (ch, d0, d1) ndarray
252 |         """
253 |         c, h, w
= img.shape 254 | dy = (h - self.height) // 2 255 | dx = (w - self.width) // 2 256 | 257 | y1 = dy 258 | y2 = y1 + self.height 259 | x1 = dx 260 | x2 = x1 + self.width 261 | img = np.ascontiguousarray(img[:, y1:y2, x1:x2]) 262 | if mask is not None: 263 | mask = np.ascontiguousarray(mask[:, y1:y2, x1:x2]) 264 | 265 | return img, mask 266 | 267 | 268 | class Pad(object): 269 | def __init__(self, dr, dc, **kwargs): 270 | self.dr = dr 271 | self.dc = dc 272 | self.kwargs = kwargs 273 | 274 | def __call__(self, img, mask=None): 275 | """ 276 | 277 | Parameters 278 | ---------- 279 | img: (ch, d0, d1) ndarray 280 | mask: (ch, d0, d1) ndarray 281 | """ 282 | pad_width = (0, 0), (self.dr,) * 2, (self.dc,) * 2 283 | img = np.pad(img, pad_width, **self.kwargs) 284 | if mask is not None: 285 | mask = np.pad(mask, pad_width, **self.kwargs) 286 | return img, mask 287 | 288 | 289 | class GammaCorrection(object): 290 | def __init__(self, gamma_range=(0.5, 2), prob=0.5): 291 | self.gamma_range = gamma_range 292 | self.prob = prob 293 | 294 | self.state = dict() 295 | self.randomize() 296 | 297 | def __call__(self, image, mask=None): 298 | """ 299 | 300 | Parameters 301 | ---------- 302 | img: (ch, d0, d1) ndarray 303 | mask: (ch, d0, d1) ndarray 304 | """ 305 | if self.state['p'] < self.prob: 306 | image = image ** (1 / self.state['gamma']) 307 | # TODO: implement also for integers 308 | image = np.clip(image, 0, 1) 309 | return image, mask 310 | 311 | def randomize(self): 312 | self.state['p'] = random.random() 313 | self.state['gamma'] = random.uniform(*self.gamma_range) 314 | 315 | 316 | class BilateralFilter(object): 317 | def __init__(self, d, sigma_color, sigma_space, prob=.5): 318 | self.d = d 319 | self.sigma_color = sigma_color 320 | self.sigma_space = sigma_space 321 | self.prob = prob 322 | 323 | self.state = dict() 324 | self.randomize() 325 | 326 | def __call__(self, img, mask=None): 327 | if self.state['p'] < self.prob: 328 | img = np.squeeze(img) 329 | img = cv2.bilateralFilter(img, self.d, self.sigma_color, self.sigma_space) 330 | img = img[None, ...] 
331 | return img, mask 332 | 333 | def randomize(self): 334 | self.state['p'] = random.random() 335 | -------------------------------------------------------------------------------- /rocaseg/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from glob import glob 4 | 5 | import numpy as np 6 | from skimage.color import label2rgb 7 | from skimage import img_as_ubyte 8 | from tqdm import tqdm 9 | import click 10 | 11 | import cv2 12 | import tifffile 13 | import torch 14 | import torch.nn as nn 15 | from torch.utils.data.dataloader import DataLoader 16 | 17 | from rocaseg.datasets import sources_from_path 18 | from rocaseg.components import CheckpointHandler 19 | from rocaseg.components.formats import numpy_to_nifti, png_to_numpy 20 | from rocaseg.models import dict_models 21 | from rocaseg.preproc import * 22 | from rocaseg.repro import set_ultimate_seed 23 | 24 | 25 | # The fix is a workaround to PyTorch multiprocessing issue: 26 | # "RuntimeError: received 0 items of ancdata" 27 | torch.multiprocessing.set_sharing_strategy('file_system') 28 | 29 | cv2.ocl.setUseOpenCL(False) 30 | cv2.setNumThreads(0) 31 | 32 | logging.basicConfig() 33 | logger = logging.getLogger('eval') 34 | logger.setLevel(logging.INFO) 35 | 36 | set_ultimate_seed() 37 | 38 | if torch.cuda.is_available(): 39 | maybe_gpu = 'cuda' 40 | else: 41 | maybe_gpu = 'cpu' 42 | 43 | 44 | def predict_folds(config, loader, fold_idcs): 45 | """Evaluate the model versus each fold 46 | """ 47 | for fold_idx in fold_idcs: 48 | paths_weights_fold = dict() 49 | paths_weights_fold['segm'] = \ 50 | os.path.join(config['path_weights'], 'segm', f'fold_{fold_idx}') 51 | 52 | handlers_ckpt = dict() 53 | handlers_ckpt['segm'] = CheckpointHandler(paths_weights_fold['segm']) 54 | 55 | paths_ckpt_sel = dict() 56 | paths_ckpt_sel['segm'] = handlers_ckpt['segm'].get_last_ckpt() 57 | 58 | # Initialize and configure the model 59 | model = (dict_models[config['model_segm']] 60 | (input_channels=config['input_channels'], 61 | output_channels=config['output_channels'], 62 | center_depth=config['center_depth'], 63 | pretrained=config['pretrained'], 64 | restore_weights=config['restore_weights'], 65 | path_weights=paths_ckpt_sel['segm'])) 66 | model = nn.DataParallel(model).to(maybe_gpu) 67 | model.eval() 68 | 69 | with tqdm(total=len(loader), desc=f'Eval, fold {fold_idx}') as prog_bar: 70 | for i, data_batch in enumerate(loader): 71 | xs, ys_true = data_batch['xs'], data_batch['ys'] 72 | xs, ys_true = xs.to(maybe_gpu), ys_true.to(maybe_gpu) 73 | 74 | if config['model_segm'] == 'unet_lext': 75 | ys_pred = model(xs) 76 | elif config['model_segm'] == 'unet_lext_aux': 77 | ys_pred, _ = model(xs) 78 | else: 79 | msg = f"Unknown model {config['model_segm']}" 80 | raise ValueError(msg) 81 | 82 | ys_pred_softmax = nn.Softmax(dim=1)(ys_pred) 83 | ys_pred_softmax_np = ys_pred_softmax.detach().to('cpu').numpy() 84 | 85 | data_batch['pred_softmax'] = ys_pred_softmax_np 86 | 87 | # Rearrange the batch 88 | data_dicts = [{k: v[n] for k, v in data_batch.items()} 89 | for n in range(len(data_batch['image']))] 90 | 91 | for k, data_dict in enumerate(data_dicts): 92 | dir_base = os.path.join( 93 | config['path_predicts'], 94 | data_dict['patient'], data_dict['release'], data_dict['sequence']) 95 | fname_base = os.path.splitext( 96 | os.path.basename(data_dict['path_rel_image']))[0] 97 | 98 | # Save the predictions 99 | dir_predicts = os.path.join(dir_base, 'mask_folds') 100 | if not 
os.path.exists(dir_predicts): 101 | os.makedirs(dir_predicts) 102 | 103 | fname_full = os.path.join( 104 | dir_predicts, 105 | f'{fname_base}_fold_{fold_idx}.tiff') 106 | 107 | tmp = (data_dict['pred_softmax'] * 255).astype(np.uint8, casting='unsafe') 108 | tifffile.imsave(fname_full, tmp, compress=9) 109 | 110 | prog_bar.update(1) 111 | 112 | 113 | def merge_predictions(config, source, loader, dict_fns, 114 | save_plots=False, remove_foldw=False, convert_to_nifti=True): 115 | """Merge the predictions over all folds 116 | """ 117 | dir_source_root = source['path_root'] 118 | df_meta = loader.dataset.df_meta 119 | 120 | with tqdm(total=len(df_meta), desc='Merge') as prog_bar: 121 | for i, row in df_meta.iterrows(): 122 | dir_scan_predicts = os.path.join( 123 | config['path_predicts'], 124 | row['patient'], row['release'], row['sequence']) 125 | dir_image_prep = os.path.join(dir_scan_predicts, 'image_prep') 126 | dir_mask_prep = os.path.join(dir_scan_predicts, 'mask_prep') 127 | dir_mask_folds = os.path.join(dir_scan_predicts, 'mask_folds') 128 | dir_mask_foldavg = os.path.join(dir_scan_predicts, 'mask_foldavg') 129 | dir_vis_foldavg = os.path.join(dir_scan_predicts, 'vis_foldavg') 130 | 131 | for p in (dir_image_prep, dir_mask_prep, dir_mask_folds, dir_mask_foldavg, 132 | dir_vis_foldavg): 133 | if not os.path.exists(p): 134 | os.makedirs(p) 135 | 136 | # Find the corresponding prediction files 137 | fname_base = os.path.splitext(os.path.basename(row['path_rel_image']))[0] 138 | 139 | fnames_pred = glob(os.path.join(dir_mask_folds, f'{fname_base}_fold_*.*')) 140 | 141 | # Read the reference data 142 | image = cv2.imread( 143 | os.path.join(dir_source_root, row['path_rel_image']), 144 | cv2.IMREAD_GRAYSCALE) 145 | image = dict_fns['crop'](image[None, ])[0] 146 | image = np.squeeze(image) 147 | if 'path_rel_mask' in row.index: 148 | ys_true = loader.dataset.read_mask( 149 | os.path.join(dir_source_root, row['path_rel_mask'])) 150 | if ys_true is not None: 151 | ys_true = dict_fns['crop'](ys_true)[0] 152 | else: 153 | ys_true = None 154 | 155 | # Read the fold-wise predictions 156 | yss_pred = [tifffile.imread(f) for f in fnames_pred] 157 | ys_pred = np.stack(yss_pred, axis=0).astype(np.float32) / 255 158 | ys_pred = torch.from_numpy(ys_pred).unsqueeze(dim=0) 159 | 160 | # Average the fold predictions 161 | ys_pred = torch.mean(ys_pred, dim=1, keepdim=False) 162 | ys_pred_softmax = ys_pred / torch.sum(ys_pred, dim=1, keepdim=True) 163 | ys_pred_softmax_np = ys_pred_softmax.squeeze().numpy() 164 | 165 | ys_pred_arg_np = ys_pred_softmax_np.argmax(axis=0) 166 | 167 | # Save preprocessed input data 168 | fname_full = os.path.join(dir_image_prep, f'{fname_base}.png') 169 | cv2.imwrite(fname_full, image) # image 170 | 171 | if ys_true is not None: 172 | ys_true = ys_true.astype(np.float32) 173 | ys_true = torch.from_numpy(ys_true).unsqueeze(dim=0) 174 | ys_true_arg_np = ys_true.numpy().squeeze().argmax(axis=0) 175 | fname_full = os.path.join(dir_mask_prep, f'{fname_base}.png') 176 | cv2.imwrite(fname_full, ys_true_arg_np) # mask 177 | 178 | fname_meta = os.path.join(config['path_predicts'], 'meta_dynamic.csv') 179 | if not os.path.exists(fname_meta): 180 | df_meta.to_csv(fname_meta, index=False) # metainfo 181 | 182 | # Save ensemble prediction 183 | fname_full = os.path.join(dir_mask_foldavg, f'{fname_base}.png') 184 | cv2.imwrite(fname_full, ys_pred_arg_np) 185 | 186 | # Save ensemble visualizations 187 | if save_plots: 188 | if ys_true is not None: 189 | fname_full = os.path.join( 190 | 
                        dir_vis_foldavg, f"{fname_base}_overlay_mask.png")
191 |                     save_vis_overlay(image=image,
192 |                                      mask=ys_true_arg_np,
193 |                                      num_classes=config['output_channels'],
194 |                                      fname=fname_full)
195 | 
196 |                 fname_full = os.path.join(
197 |                     dir_vis_foldavg, f"{fname_base}_overlay_pred.png")
198 |                 save_vis_overlay(image=image,
199 |                                  mask=ys_pred_arg_np,
200 |                                  num_classes=config['output_channels'],
201 |                                  fname=fname_full)
202 | 
203 |                 if ys_true is not None:
204 |                     fname_full = os.path.join(
205 |                         dir_vis_foldavg, f"{fname_base}_overlay_diff.png")
206 |                     save_vis_mask_diff(image=image,
207 |                                        mask_true=ys_true_arg_np,
208 |                                        mask_pred=ys_pred_arg_np,
209 |                                        fname=fname_full)
210 | 
211 |             # Remove the fold predictions
212 |             if remove_foldw:
213 |                 for f in fnames_pred:
214 |                     try:
215 |                         os.remove(f)
216 |                     except OSError:
217 |                         logger.error(f'Cannot remove {f}')
218 |             prog_bar.update(1)
219 | 
220 |     # Convert the results to 3D NIfTI images
221 |     if convert_to_nifti:
222 |         df_meta = df_meta.sort_values(by=["patient", "release", "sequence", "side"])
223 | 
224 |         for gb_name, gb_df in tqdm(
225 |                 df_meta.groupby(["patient", "release", "sequence", "side"]),
226 |                 desc="Convert to NIfTI"):
227 | 
228 |             patient, release, sequence, side = gb_name
229 |             spacings = (gb_df['pixel_spacing_0'].iloc[0],
230 |                         gb_df['pixel_spacing_1'].iloc[0],
231 |                         gb_df['slice_thickness'].iloc[0])
232 | 
233 |             dir_scan_predicts = os.path.join(config['path_predicts'],
234 |                                              patient, release, sequence)
235 |             for result in ("image_prep", "mask_prep", "mask_foldavg"):
236 |                 pattern = os.path.join(dir_scan_predicts, result, '*.png')
237 |                 path_nii = os.path.join(dir_scan_predicts, f"{result}.nii")
238 | 
239 |                 # Read and compose 3D image
240 |                 img = png_to_numpy(pattern_fname_in=pattern, reverse=False)
241 | 
242 |                 # Save to NIfTI
243 |                 numpy_to_nifti(stack=img, fname_out=path_nii,
244 |                                spacings=spacings, rcp_to_ras=True)
245 | 
246 | 
247 | def save_vis_overlay(image, mask, num_classes, fname):
248 |     # Add a sample of each class to have consistent class colors
249 |     mask[0, :num_classes] = list(range(num_classes))
250 |     overlay = label2rgb(label=mask, image=image, bg_label=0,
251 |                         colors=['orangered', 'gold', 'lime', 'fuchsia'])
252 |     # Convert to uint8 to save space
253 |     overlay = img_as_ubyte(overlay)
254 |     # Save to file
255 |     if overlay.ndim == 3:
256 |         overlay = overlay[:, :, ::-1]
257 |     cv2.imwrite(fname, overlay)
258 | 
259 | 
260 | def save_vis_mask_diff(image, mask_true, mask_pred, fname):
261 |     diff = np.empty_like(mask_true)
262 |     diff[(mask_true == mask_pred) & (mask_pred == 0)] = 0  # TN
263 |     diff[(mask_true == mask_pred) & (mask_pred != 0)] = 0  # TP (not highlighted)
264 |     diff[(mask_true != mask_pred) & (mask_pred == 0)] = 2  # FN: missed foreground
265 |     diff[(mask_true != mask_pred) & (mask_pred != 0)] = 3  # FP or wrong class
266 |     diff_colors = ('green', 'red', 'yellow')
267 |     diff[0, :4] = [0, 1, 2, 3]
268 |     overlay = label2rgb(label=diff, image=image, bg_label=0,
269 |                         colors=diff_colors)
270 |     # Convert to uint8 to save space
271 |     overlay = img_as_ubyte(overlay)
272 |     # Save to file
273 |     if overlay.ndim == 3:
274 |         overlay = overlay[:, :, ::-1]
275 |     cv2.imwrite(fname, overlay)
276 | 
277 | 
278 | @click.command()
279 | @click.option('--path_data_root', default='../../data')
280 | @click.option('--path_experiment_root', default='../../results/temporary')
281 | @click.option('--model_segm', default='unet_lext')
282 | @click.option('--center_depth', default=1, type=int)
283 | @click.option('--pretrained', is_flag=True)
284 | @click.option('--restore_weights', is_flag=True)
285 | 
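# NOTE (editor): an illustrative NumPy equivalent of the fold merging done in
# `merge_predictions` above; `per_fold_probs` is a hypothetical list of
# per-fold class-probability maps, not a name from this file:
#
#     probs = np.stack(per_fold_probs, axis=0)  # (n_folds, cls, H, W), in [0, 1]
#     mean = probs.mean(axis=0)                 # ensemble probability map
#     mean /= mean.sum(axis=0, keepdims=True)   # renormalize over the classes
#     labels = mean.argmax(axis=0)              # final per-pixel class labels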
@click.option('--input_channels', default=1, type=int)
286 | @click.option('--output_channels', default=1, type=int)
287 | @click.option('--dataset', type=click.Choice(
288 |     ['oai_imo', 'okoa', 'maknee']))
289 | @click.option('--subset', type=click.Choice(
290 |     ['test', 'all']))
291 | @click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
292 | @click.option('--sample_mode', default='x_y', type=str)
293 | @click.option('--batch_size', default=64, type=int)
294 | @click.option('--fold_num', default=5, type=int)
295 | @click.option('--fold_idx', default=-1, type=int)
296 | @click.option('--fold_idx_ignore', multiple=True, type=int)
297 | @click.option('--num_workers', default=1, type=int)
298 | @click.option('--seed_trainval_test', default=0, type=int)
299 | @click.option('--predict_folds', is_flag=True)
300 | @click.option('--merge_predictions', is_flag=True)
301 | @click.option('--save_plots', is_flag=True)
302 | def main(**config):
303 |     config['path_data_root'] = os.path.abspath(config['path_data_root'])
304 |     config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])
305 | 
306 |     config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
307 |     if not os.path.exists(config['path_weights']):
308 |         raise ValueError('{} does not exist'.format(config['path_weights']))
309 | 
310 |     config['path_predicts'] = os.path.join(
311 |         config['path_experiment_root'], f"predicts_{config['dataset']}_test")
312 |     config['path_logs'] = os.path.join(
313 |         config['path_experiment_root'], f"logs_{config['dataset']}_test")
314 | 
315 |     os.makedirs(config['path_predicts'], exist_ok=True)
316 |     os.makedirs(config['path_logs'], exist_ok=True)
317 | 
318 |     logging_fh = logging.FileHandler(
319 |         os.path.join(config['path_logs'], 'main.log'))
320 |     logging_fh.setLevel(logging.DEBUG)
321 |     logger.addHandler(logging_fh)
322 | 
323 |     # Collect the available and specified sources
324 |     sources = sources_from_path(path_data_root=config['path_data_root'],
325 |                                 selection=config['dataset'],
326 |                                 with_folds=True,
327 |                                 seed_trainval_test=config['seed_trainval_test'])
328 | 
329 |     # Select the subset for evaluation
330 |     if config['subset'] == 'test':
331 |         logger.warning('Using the regular trainval-test split')
332 |     elif config['subset'] == 'all':
333 |         logger.warning('Using data selection: full dataset')
334 |         for s in sources:
335 |             sources[s]['test_df'] = sources[s]['sel_df']
336 |             logger.info(f"Selected number of samples: {len(sources[s]['test_df'])}")
337 |     else:
338 |         raise ValueError(f"Unknown subset: {config['subset']}")
339 | 
340 |     if config['dataset'] == 'oai_imo':
341 |         from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
342 |     elif config['dataset'] == 'okoa':
343 |         from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
344 |     elif config['dataset'] == 'maknee':
345 |         from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
346 |     else:
347 |         raise ValueError(f"Unknown dataset: {config['dataset']}")
348 | 
349 |     # Configure dataset-dependent transforms
350 |     fn_crop = CenterCrop(height=300, width=300)
351 |     if config['dataset'] == 'oai_imo':
352 |         fn_norm = Normalize(mean=0.252699, std=0.251142)
353 |         fn_unnorm = UnNormalize(mean=0.252699, std=0.251142)
354 |     elif config['dataset'] == 'okoa':
355 |         fn_norm = Normalize(mean=0.232454, std=0.236259)
356 |         fn_unnorm = UnNormalize(mean=0.232454, std=0.236259)
357 |     else:
358 |         msg = f"No transforms defined for dataset: {config['dataset']}"
359 | 
raise NotImplementedError(msg) 360 | dict_fns = {'crop': fn_crop, 'norm': fn_norm, 'unnorm': fn_unnorm} 361 | 362 | dataset_test = DatasetSagittal2d( 363 | df_meta=sources[config['dataset']]['test_df'], mask_mode=config['mask_mode'], 364 | name=config['dataset'], sample_mode=config['sample_mode'], 365 | transforms=[ 366 | PercentileClippingAndToFloat(cut_min=10, cut_max=99), 367 | fn_crop, 368 | fn_norm, 369 | ToTensor() 370 | ]) 371 | loader_test = DataLoader(dataset_test, 372 | batch_size=config['batch_size'], 373 | shuffle=False, 374 | num_workers=config['num_workers'], 375 | drop_last=False) 376 | 377 | # Build a list of folds to run on 378 | if config['fold_idx'] == -1: 379 | fold_idcs = list(range(config['fold_num'])) 380 | else: 381 | fold_idcs = [config['fold_idx'], ] 382 | for g in config['fold_idx_ignore']: 383 | fold_idcs = [i for i in fold_idcs if i != g] 384 | 385 | # Execute 386 | with torch.no_grad(): 387 | if config['predict_folds']: 388 | predict_folds(config=config, loader=loader_test, fold_idcs=fold_idcs) 389 | 390 | if config['merge_predictions']: 391 | merge_predictions(config=config, source=sources[config['dataset']], 392 | loader=loader_test, dict_fns=dict_fns, 393 | save_plots=config['save_plots'], remove_foldw=False, 394 | convert_to_nifti=True) 395 | 396 | 397 | if __name__ == '__main__': 398 | main() 399 | -------------------------------------------------------------------------------- /rocaseg/train_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from collections import defaultdict 4 | import click 5 | 6 | import numpy as np 7 | import cv2 8 | 9 | import torch 10 | from torch import nn 11 | from torch.utils.data.dataloader import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from tqdm import tqdm 14 | 15 | from rocaseg.datasets import DatasetOAIiMoSagittal2d, sources_from_path 16 | from rocaseg.models import dict_models 17 | from rocaseg.components import (dict_losses, confusion_matrix, dice_score_from_cm, 18 | dict_optimizers, CheckpointHandler) 19 | from rocaseg.preproc import * 20 | from rocaseg.repro import set_ultimate_seed 21 | from rocaseg.components.mixup import mixup_criterion, mixup_data 22 | 23 | 24 | cv2.ocl.setUseOpenCL(False) 25 | cv2.setNumThreads(0) 26 | 27 | logging.basicConfig() 28 | logger = logging.getLogger('train') 29 | logger.setLevel(logging.DEBUG) 30 | 31 | set_ultimate_seed() 32 | 33 | if torch.cuda.is_available(): 34 | maybe_gpu = 'cuda' 35 | else: 36 | maybe_gpu = 'cpu' 37 | 38 | 39 | class ModelTrainer: 40 | def __init__(self, config, fold_idx=None): 41 | self.config = config 42 | self.fold_idx = fold_idx 43 | 44 | self.paths_weights_fold = dict() 45 | self.paths_weights_fold['segm'] = \ 46 | os.path.join(config['path_weights'], 'segm', f'fold_{self.fold_idx}') 47 | os.makedirs(self.paths_weights_fold['segm'], exist_ok=True) 48 | 49 | self.path_logs_fold = \ 50 | os.path.join(config['path_logs'], f'fold_{self.fold_idx}') 51 | os.makedirs(self.path_logs_fold, exist_ok=True) 52 | 53 | self.handlers_ckpt = dict() 54 | self.handlers_ckpt['segm'] = CheckpointHandler(self.paths_weights_fold['segm']) 55 | 56 | paths_ckpt_sel = dict() 57 | paths_ckpt_sel['segm'] = self.handlers_ckpt['segm'].get_last_ckpt() 58 | 59 | # Initialize and configure the models 60 | self.models = dict() 61 | self.models['segm'] = (dict_models[config['model_segm']] 62 | (input_channels=self.config['input_channels'], 63 | 
output_channels=self.config['output_channels'], 64 | center_depth=self.config['center_depth'], 65 | pretrained=self.config['pretrained'], 66 | path_pretrained=self.config['path_pretrained_segm'], 67 | restore_weights=self.config['restore_weights'], 68 | path_weights=paths_ckpt_sel['segm'])) 69 | self.models['segm'] = nn.DataParallel(self.models['segm']) 70 | self.models['segm'] = self.models['segm'].to(maybe_gpu) 71 | 72 | # Configure the training 73 | self.optimizers = dict() 74 | self.optimizers['segm'] = (dict_optimizers['adam']( 75 | self.models['segm'].parameters(), 76 | lr=self.config['lr_segm'], 77 | weight_decay=self.config['wd_segm'])) 78 | 79 | self.lr_update_rule = {30: 0.1} 80 | 81 | self.losses = dict() 82 | self.losses['segm'] = dict_losses[self.config['loss_segm']]( 83 | num_classes=self.config['output_channels'], 84 | ) 85 | 86 | self.losses['segm'] = self.losses['segm'].to(maybe_gpu) 87 | 88 | self.tensorboard = SummaryWriter(self.path_logs_fold) 89 | 90 | def run_one_epoch(self, epoch_idx, loaders): 91 | name_ds = list(loaders.keys())[0] 92 | 93 | fnames_acc = defaultdict(list) 94 | metrics_acc = dict() 95 | metrics_acc['samplew'] = defaultdict(list) 96 | metrics_acc['batchw'] = defaultdict(list) 97 | metrics_acc['datasetw'] = defaultdict(list) 98 | metrics_acc['datasetw'][f'{name_ds}__cm'] = \ 99 | np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32) 100 | 101 | prog_bar_params = {'postfix': {'epoch': epoch_idx}, } 102 | 103 | if self.models['segm'].training: 104 | # ------------------------ Training regime ------------------------ 105 | loader_ds = loaders[name_ds]['train'] 106 | 107 | steps_ds = len(loader_ds) 108 | prog_bar_params.update({'total': steps_ds, 109 | 'desc': f'Train, epoch {epoch_idx}'}) 110 | 111 | loader_ds_iter = iter(loader_ds) 112 | 113 | with tqdm(**prog_bar_params) as prog_bar: 114 | for step_idx in range(steps_ds): 115 | self.optimizers['segm'].zero_grad() 116 | 117 | data_batch_ds = next(loader_ds_iter) 118 | 119 | xs_ds, ys_true_ds = data_batch_ds['xs'], data_batch_ds['ys'] 120 | fnames_acc['oai'].extend(data_batch_ds['path_image']) 121 | 122 | ys_true_arg_ds = torch.argmax(ys_true_ds.long(), dim=1) 123 | xs_ds = xs_ds.to(maybe_gpu) 124 | ys_true_arg_ds = ys_true_arg_ds.to(maybe_gpu) 125 | 126 | if not self.config['with_mixup']: 127 | ys_pred_ds = self.models['segm'](xs_ds) 128 | 129 | loss_segm = self.losses['segm'](input_=ys_pred_ds, 130 | target=ys_true_arg_ds) 131 | else: 132 | xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data( 133 | x=xs_ds, y=ys_true_arg_ds, 134 | alpha=self.config['mixup_alpha'], device=maybe_gpu) 135 | ys_pred_ds = self.models['segm'](xs_mixup) 136 | loss_segm = mixup_criterion(criterion=self.losses['segm'], 137 | pred=ys_pred_ds, 138 | y_a=ys_mixup_a, 139 | y_b=ys_mixup_b, 140 | lam=lambda_mixup) 141 | 142 | metrics_acc['batchw']['loss'].append(loss_segm.item()) 143 | 144 | loss_segm.backward() 145 | self.optimizers['segm'].step() 146 | 147 | prog_bar.update(1) 148 | else: 149 | # ----------------------- Validation regime ----------------------- 150 | loader_ds = loaders[name_ds]['val'] 151 | 152 | steps_ds = len(loader_ds) 153 | prog_bar_params.update({'total': steps_ds, 154 | 'desc': f'Validate, epoch {epoch_idx}'}) 155 | 156 | loader_ds_iter = iter(loader_ds) 157 | 158 | with torch.no_grad(), tqdm(**prog_bar_params) as prog_bar: 159 | for step_idx in range(steps_ds): 160 | data_batch_ds = next(loader_ds_iter) 161 | 162 | xs_ds, ys_true_ds = data_batch_ds['xs'], data_batch_ds['ys'] 163 | 
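                    # NOTE (editor): sketch of the mixup step used in this loop,
                    # assuming the common formulation of Zhang et al.; the exact
                    # behavior is defined by `rocaseg.components.mixup`:
                    #
                    #     lam = np.random.beta(alpha, alpha)     # mixing coefficient
                    #     perm = torch.randperm(x.size(0))       # shuffled batch indices
                    #     x_mix = lam * x + (1 - lam) * x[perm]  # blended inputs
                    #     # `mixup_criterion` blends the loss the same way:
                    #     loss = (lam * criterion(pred, y)
                    #             + (1 - lam) * criterion(pred, y[perm]))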
fnames_acc['oai'].extend(data_batch_ds['path_image']) 164 | 165 | ys_true_arg_ds = torch.argmax(ys_true_ds.long(), dim=1) 166 | xs_ds = xs_ds.to(maybe_gpu) 167 | ys_true_arg_ds = ys_true_arg_ds.to(maybe_gpu) 168 | 169 | if not self.config['with_mixup']: 170 | ys_pred_ds = self.models['segm'](xs_ds) 171 | loss_segm = self.losses['segm'](input_=ys_pred_ds, 172 | target=ys_true_arg_ds) 173 | else: 174 | xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data( 175 | x=xs_ds, y=ys_true_arg_ds, 176 | alpha=self.config['mixup_alpha'], device=maybe_gpu) 177 | 178 | ys_pred_ds = self.models['segm'](xs_mixup) 179 | loss_segm = mixup_criterion(criterion=self.losses['segm'], 180 | pred=ys_pred_ds, 181 | y_a=ys_mixup_a, 182 | y_b=ys_mixup_b, 183 | lam=lambda_mixup) 184 | 185 | metrics_acc['batchw']['loss'].append(loss_segm.item()) 186 | 187 | # Calculate metrics 188 | ys_pred_softmax_ds = nn.Softmax(dim=1)(ys_pred_ds) 189 | ys_pred_softmax_np_ds = ys_pred_softmax_ds.to('cpu').numpy() 190 | 191 | ys_pred_arg_np_ds = ys_pred_softmax_np_ds.argmax(axis=1) 192 | ys_true_arg_np_ds = ys_true_arg_ds.to('cpu').numpy() 193 | 194 | metrics_acc['datasetw'][f'{name_ds}__cm'] += confusion_matrix( 195 | ys_pred_arg_np_ds, ys_true_arg_np_ds, 196 | self.config['output_channels']) 197 | 198 | prog_bar.update(1) 199 | 200 | for k, v in metrics_acc['samplew'].items(): 201 | metrics_acc['samplew'][k] = np.asarray(v) 202 | metrics_acc['datasetw'][f'{name_ds}__dice_score'] = np.asarray( 203 | dice_score_from_cm(metrics_acc['datasetw'][f'{name_ds}__cm'])) 204 | return metrics_acc, fnames_acc 205 | 206 | def fit(self, loaders): 207 | epoch_idx_best = -1 208 | loss_best = float('inf') 209 | metrics_train_best = dict() 210 | fnames_train_best = [] 211 | metrics_val_best = dict() 212 | fnames_val_best = [] 213 | 214 | for epoch_idx in range(self.config['epoch_num']): 215 | self.models = {n: m.train() for n, m in self.models.items()} 216 | metrics_train, fnames_train = \ 217 | self.run_one_epoch(epoch_idx, loaders) 218 | 219 | # Process the accumulated metrics 220 | for k, v in metrics_train['batchw'].items(): 221 | if k.startswith('loss'): 222 | metrics_train['datasetw'][k] = np.mean(np.asarray(v)) 223 | else: 224 | logger.warning(f'Non-processed batch-wise entry: {k}') 225 | 226 | self.models = {n: m.eval() for n, m in self.models.items()} 227 | metrics_val, fnames_val = \ 228 | self.run_one_epoch(epoch_idx, loaders) 229 | 230 | # Process the accumulated metrics 231 | for k, v in metrics_val['batchw'].items(): 232 | if k.startswith('loss'): 233 | metrics_val['datasetw'][k] = np.mean(np.asarray(v)) 234 | else: 235 | logger.warning(f'Non-processed batch-wise entry: {k}') 236 | 237 | # Learning rate update 238 | for s, m in self.lr_update_rule.items(): 239 | if epoch_idx == s: 240 | for name, optim in self.optimizers.items(): 241 | for param_group in optim.param_groups: 242 | param_group['lr'] *= m 243 | 244 | # Add console logging 245 | logger.info(f'Epoch: {epoch_idx}') 246 | for subset, metrics in (('train', metrics_train), 247 | ('val', metrics_val)): 248 | logger.info(f'{subset} metrics:') 249 | for k, v in metrics['datasetw'].items(): 250 | logger.info(f'{k}: \n{v}') 251 | 252 | # Add TensorBoard logging 253 | for subset, metrics in (('train', metrics_train), 254 | ('val', metrics_val)): 255 | # Log only dataset-reduced metrics 256 | for k, v in metrics['datasetw'].items(): 257 | if isinstance(v, np.ndarray): 258 | self.tensorboard.add_scalars( 259 | f'fold_{self.fold_idx}/{k}_{subset}', 260 | {f'class{i}': e for i, e in 
enumerate(v.ravel().tolist())}, 261 | global_step=epoch_idx) 262 | elif isinstance(v, (str, int, float)): 263 | self.tensorboard.add_scalar( 264 | f'fold_{self.fold_idx}/{k}_{subset}', 265 | float(v), 266 | global_step=epoch_idx) 267 | else: 268 | logger.warning(f'{k} is of unsupported dtype {v}') 269 | for name, optim in self.optimizers.items(): 270 | for param_group in optim.param_groups: 271 | self.tensorboard.add_scalar( 272 | f'fold_{self.fold_idx}/learning_rate/{name}', 273 | param_group['lr'], 274 | global_step=epoch_idx) 275 | 276 | # Save the model 277 | loss_curr = metrics_val['datasetw']['loss'] 278 | if loss_curr < loss_best: 279 | loss_best = loss_curr 280 | epoch_idx_best = epoch_idx 281 | metrics_train_best = metrics_train 282 | metrics_val_best = metrics_val 283 | fnames_train_best = fnames_train 284 | fnames_val_best = fnames_val 285 | 286 | self.handlers_ckpt['segm'].save_new_ckpt( 287 | model=self.models['segm'], 288 | model_name=self.config['model_segm'], 289 | fold_idx=self.fold_idx, 290 | epoch_idx=epoch_idx) 291 | 292 | msg = (f'Finished fold {self.fold_idx} ' 293 | f'with the best loss {loss_best:.5f} ' 294 | f'on epoch {epoch_idx_best}, ' 295 | f'weights: ({self.paths_weights_fold})') 296 | logger.info(msg) 297 | return (metrics_train_best, fnames_train_best, 298 | metrics_val_best, fnames_val_best) 299 | 300 | 301 | @click.command() 302 | @click.option('--path_data_root', default='../../data') 303 | @click.option('--path_experiment_root', default='../../results/temporary') 304 | @click.option('--model_segm', default='unet_lext') 305 | @click.option('--center_depth', default=1, type=int) 306 | @click.option('--pretrained', is_flag=True) 307 | @click.option('--path_pretrained_segm', type=str, help='Path to .pth file') 308 | @click.option('--restore_weights', is_flag=True) 309 | @click.option('--input_channels', default=1, type=int) 310 | @click.option('--output_channels', default=1, type=int) 311 | @click.option('--dataset', type=click.Choice( 312 | ['oai_imo', 'okoa', 'maknee']), default='oai_imo') 313 | @click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str) 314 | @click.option('--sample_mode', default='x_y', type=str) 315 | @click.option('--loss_segm', default='multi_ce_loss') 316 | @click.option('--lr_segm', default=0.0001, type=float) 317 | @click.option('--wd_segm', default=5e-5, type=float) 318 | @click.option('--optimizer_segm', default='adam') 319 | @click.option('--batch_size', default=64, type=int) 320 | @click.option('--epoch_size', default=1.0, type=float) 321 | @click.option('--epoch_num', default=2, type=int) 322 | @click.option('--fold_num', default=5, type=int) 323 | @click.option('--fold_idx', default=-1, type=int) 324 | @click.option('--fold_idx_ignore', multiple=True, type=int) 325 | @click.option('--num_workers', default=1, type=int) 326 | @click.option('--seed_trainval_test', default=0, type=int) 327 | @click.option('--with_mixup', is_flag=True) 328 | @click.option('--mixup_alpha', default=1, type=float) 329 | def main(**config): 330 | config['path_data_root'] = os.path.abspath(config['path_data_root']) 331 | config['path_experiment_root'] = os.path.abspath(config['path_experiment_root']) 332 | 333 | config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights') 334 | config['path_logs'] = os.path.join(config['path_experiment_root'], 'logs_train') 335 | os.makedirs(config['path_weights'], exist_ok=True) 336 | os.makedirs(config['path_logs'], exist_ok=True) 337 | 338 | logging_fh = logging.FileHandler( 339 | 
os.path.join(config['path_logs'], 'main_{}.log'.format(config['fold_idx']))) 340 | logging_fh.setLevel(logging.DEBUG) 341 | logger.addHandler(logging_fh) 342 | 343 | # Collect the available and specified sources 344 | sources = sources_from_path(path_data_root=config['path_data_root'], 345 | selection=config['dataset'], 346 | with_folds=True, 347 | fold_num=config['fold_num'], 348 | seed_trainval_test=config['seed_trainval_test']) 349 | 350 | # Build a list of folds to run on 351 | if config['fold_idx'] == -1: 352 | fold_idcs = list(range(config['fold_num'])) 353 | else: 354 | fold_idcs = [config['fold_idx'], ] 355 | for g in config['fold_idx_ignore']: 356 | fold_idcs = [i for i in fold_idcs if i != g] 357 | 358 | # Train each fold separately 359 | fold_scores = dict() 360 | 361 | # Use straightforward fold allocation strategy 362 | folds = list(sources[config['dataset']]['trainval_folds']) 363 | 364 | for fold_idx, idcs_subsets in enumerate(folds): 365 | if fold_idx not in fold_idcs: 366 | continue 367 | logger.info(f'Training fold {fold_idx}') 368 | 369 | name_ds = config['dataset'] 370 | 371 | (sources[name_ds]['train_idcs'], sources[name_ds]['val_idcs']) = idcs_subsets 372 | 373 | sources[name_ds]['train_df'] = \ 374 | sources[name_ds]['trainval_df'].iloc[sources[name_ds]['train_idcs']] 375 | sources[name_ds]['val_df'] = \ 376 | sources[name_ds]['trainval_df'].iloc[sources[name_ds]['val_idcs']] 377 | 378 | for n, s in sources.items(): 379 | logger.info('Made {} train-val split, number of samples: {}, {}' 380 | .format(n, len(s['train_df']), len(s['val_df']))) 381 | 382 | datasets = defaultdict(dict) 383 | 384 | datasets[name_ds]['train'] = DatasetOAIiMoSagittal2d( 385 | df_meta=sources[name_ds]['train_df'], 386 | mask_mode=config['mask_mode'], 387 | sample_mode=config['sample_mode'], 388 | transforms=[ 389 | PercentileClippingAndToFloat(cut_min=10, cut_max=99), 390 | CenterCrop(height=300, width=300), 391 | HorizontalFlip(prob=.5), 392 | GammaCorrection(gamma_range=(0.5, 1.5), prob=.5), 393 | OneOf([ 394 | DualCompose([ 395 | Scale(ratio_range=(0.7, 0.8), prob=1.), 396 | Scale(ratio_range=(1.5, 1.6), prob=1.), 397 | ]), 398 | NoTransform() 399 | ]), 400 | Crop(output_size=(300, 300)), 401 | BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3), 402 | Normalize(mean=0.252699, std=0.251142), 403 | ToTensor(), 404 | ]) 405 | datasets[name_ds]['val'] = DatasetOAIiMoSagittal2d( 406 | df_meta=sources[name_ds]['val_df'], 407 | mask_mode=config['mask_mode'], 408 | sample_mode=config['sample_mode'], 409 | transforms=[ 410 | PercentileClippingAndToFloat(cut_min=10, cut_max=99), 411 | CenterCrop(height=300, width=300), 412 | Normalize(mean=0.252699, std=0.251142), 413 | ToTensor() 414 | ]) 415 | 416 | loaders = defaultdict(dict) 417 | 418 | loaders[name_ds]['train'] = DataLoader( 419 | datasets[name_ds]['train'], 420 | batch_size=config['batch_size'], 421 | shuffle=True, 422 | num_workers=config['num_workers'], 423 | drop_last=True) 424 | loaders[name_ds]['val'] = DataLoader( 425 | datasets[name_ds]['val'], 426 | batch_size=config['batch_size'], 427 | shuffle=False, 428 | num_workers=config['num_workers'], 429 | drop_last=True) 430 | 431 | trainer = ModelTrainer(config=config, fold_idx=fold_idx) 432 | 433 | # INFO: run once before the training to compute the dataset statistics 434 | # dataset_train.describe() 435 | 436 | tmp = trainer.fit(loaders=loaders) 437 | metrics_train, fnames_train, metrics_val, fnames_val = tmp 438 | 439 | fold_scores[fold_idx] = 
(metrics_val['datasetw'][f'{name_ds}__dice_score'], ) 440 | 441 | trainer.tensorboard.close() 442 | logger.info(f'Fold scores:\n{repr(fold_scores)}') 443 | 444 | 445 | if __name__ == '__main__': 446 | main() 447 | -------------------------------------------------------------------------------- /rocaseg/analyze_predictions_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | import click 5 | import logging 6 | from joblib import Parallel, delayed 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | 13 | from rocaseg.components import dice_score 14 | from rocaseg.components.formats import png_to_numpy 15 | from rocaseg.datasets.constants import atlas_to_locations 16 | 17 | 18 | logging.basicConfig() 19 | logger = logging.getLogger('analyze') 20 | logger.setLevel(logging.INFO) 21 | 22 | 23 | def metrics_paired_slicew(path_pred, dirname_true, dirname_pred, 24 | df, num_classes, num_workers=1): 25 | def _process_3d_pair(path_root, name_true, name_pred, meta): 26 | # Read the data 27 | if meta['side'] == 'RIGHT': 28 | reverse = True 29 | else: 30 | reverse = False 31 | 32 | patt_true = os.path.join(path_root, meta['patient'], meta['release'], 33 | meta['sequence'], name_true, '*.png') 34 | stack_true = png_to_numpy(patt_true, reverse=reverse) 35 | 36 | patt_pred = os.path.join(path_root, meta['patient'], meta['release'], 37 | meta['sequence'], name_pred, '*.png') 38 | stack_pred = png_to_numpy(patt_pred, reverse=reverse) 39 | if stack_true.shape != stack_pred.shape: 40 | msg = (f'Reference and predictions samples are of different shape: ' 41 | f'{name_true}: {stack_true.shape}, {name_pred}: {stack_pred.shape}') 42 | raise ValueError(msg) 43 | num_slices = stack_true.shape[-1] 44 | 45 | # Add batch dimension 46 | stack_true = stack_true[None, ...] 47 | stack_pred = stack_pred[None, ...] 48 | 49 | # Compute the metrics 50 | res = [] 51 | for slice_idx in range(num_slices): 52 | tmp = { 53 | 'dice_score': dice_score(stack_pred[..., slice_idx], 54 | stack_true[..., slice_idx], 55 | num_classes=num_classes), 56 | 'slice_idx_proc': slice_idx, 57 | **meta, 58 | } 59 | res.append(tmp) 60 | return res 61 | 62 | acc_ls = [] 63 | groupers_stack = ['patient', 'release', 'sequence', 'side'] 64 | 65 | for name_gb, df_gb in tqdm(df.groupby(groupers_stack)): 66 | patient, release, sequence, side = name_gb 67 | 68 | acc_ls.append(delayed(_process_3d_pair)( 69 | path_root=path_pred, 70 | name_true=dirname_true, 71 | name_pred=dirname_pred, 72 | meta={ 73 | 'patient': patient, 74 | 'release': release, 75 | 'sequence': sequence, 76 | 'side': side, 77 | **df_gb.to_dict("records")[0], 78 | } 79 | )) 80 | 81 | acc_ls = Parallel(n_jobs=num_workers, verbose=1)(acc_ls) 82 | acc_l = [] 83 | for acc in acc_ls: 84 | acc_l.extend(acc) 85 | 86 | # Convert from list of dicts to dict of lists 87 | acc_d = {k: [d[k] for d in acc_l] for k in acc_l[0]} 88 | return acc_d 89 | 90 | 91 | def plot_metrics_paired_slicew_xvd7(acc, path_root_out, num_classes, class_names, 92 | metric_names, config): 93 | """Average and visualize the metrics. 
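    Plots a per-class histogram of the slice-wise scores (bins over [0.3, 1.0]) and logs the per-class mean (std) values.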
94 | 95 | Args: 96 | acc: dict of lists 97 | path_root_out: str 98 | num_classes: int 99 | class_names: dict 100 | metric_names: iterable of str 101 | config: dict 102 | 103 | """ 104 | for metric_name in metric_names: 105 | if metric_name not in acc: 106 | logger.error(f'`{metric_name}` is not presented in `acc`') 107 | continue 108 | metric_values = acc[metric_name] 109 | 110 | ncols = 2 111 | nrows = (num_classes - 1) // ncols + 1 112 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 12)) 113 | axes = axes.ravel() 114 | 115 | tmp = np.stack(metric_values, axis=-1) # shape: (class_idx, sample_idx) 116 | tmp_means, tmp_stds = [], [] 117 | 118 | for class_idx, class_scores in enumerate(tmp[1:], start=1): 119 | class_idx_axes = class_idx - 1 120 | class_scores = class_scores[class_scores != 0] 121 | 122 | if np.any(np.isnan(class_scores)): 123 | logger.warning('NaN score') 124 | 125 | tmp_means.append(np.mean(class_scores)) 126 | tmp_stds.append(np.std(class_scores)) 127 | 128 | axes[class_idx_axes].hist(class_scores, bins=40, 129 | range=(0.3, 1.0), density=True) 130 | axes[class_idx_axes].set_title(class_names[class_idx]) 131 | axes[class_idx_axes].set_ylim((0, 10)) 132 | 133 | tmp_values = ', '.join( 134 | [f'{m:.03f}({s:.03f})' for m, s in zip(tmp_means, tmp_stds)]) 135 | logger.info(f"{metric_name}:\n{tmp_values}\n") 136 | 137 | plt.tight_layout() 138 | fname_vis = f"metrics_{config['dataset']}_test_slicew_{metric_name}.png" 139 | path_vis = os.path.join(path_root_out, fname_vis) 140 | plt.savefig(path_vis) 141 | if config['interactive']: 142 | plt.show() 143 | else: 144 | plt.close() 145 | 146 | 147 | def print_metrics_paired_slicew_pua5(acc, metric_names): 148 | """Average and print the metrics with respect to KL. 149 | 150 | Args: 151 | acc: dict of lists 152 | metric_names: iterable of str 153 | 154 | """ 155 | kl_vec = np.asarray(acc['KL']) 156 | 157 | for kl_value in np.unique(kl_vec): 158 | kl_sel = kl_vec == kl_value 159 | 160 | for metric_name in metric_names: 161 | if metric_name not in acc: 162 | logger.error(f'`{metric_name}` is not presented in `acc`') 163 | continue 164 | metric_values = acc[metric_name] 165 | tmp = np.stack(metric_values, axis=-1)[..., kl_sel] 166 | # shape: (class_idx, sample_idx) 167 | 168 | tmp_means, tmp_stds = [], [] 169 | for class_idx, class_scores in enumerate(tmp[1:], start=1): 170 | class_scores = class_scores[class_scores != 0] 171 | 172 | tmp_means.append(np.mean(class_scores)) 173 | tmp_stds.append(np.std(class_scores)) 174 | 175 | tmp_values = ', '.join( 176 | [f'{m:.03f}({s:.03f})' for m, s in zip(tmp_means, tmp_stds)]) 177 | logger.info(f"KL{kl_value}, {metric_name}:\n{tmp_values}\n") 178 | 179 | 180 | def plot_metrics_paired_slicew_va49(acc, path_root_out, num_classes, class_names, 181 | metric_names, config): 182 | """Average and visualize the metrics. 
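    For each class, plots the score against the slice index with a confidence band (seaborn lineplot), aggregated over all scans.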
183 | 
184 |     Args:
185 |         acc: dict of lists
186 |         path_root_out: str
187 |         num_classes: int
188 |         class_names: dict
189 |         metric_names: iterable of str
190 |         config: dict
191 | 
192 |     """
193 |     groupers_stack = ['patient', 'release', 'sequence', 'side']
194 | 
195 |     acc_df = pd.DataFrame.from_dict(acc)
196 |     acc_df = acc_df.sort_values([*groupers_stack, 'slice_idx_proc'])
197 | 
198 |     for metric_name in metric_names:
199 |         if metric_name not in acc:
200 |             logger.error(f'`{metric_name}` is not present in `acc`')
201 |             continue
202 | 
203 |         metric_values = []
204 |         for gb_name, gb_df in acc_df.groupby(groupers_stack):
205 |             tmp = np.asarray([np.asarray(e) for e in gb_df[metric_name]])
206 |             metric_values.append(tmp)
207 |         metric_values = np.stack(metric_values, axis=0)
208 |         # Axes order: (scan, slice_idx, class_idx)
209 | 
210 |         nrows = np.floor(np.sqrt(num_classes)).astype(int)
211 |         ncols = np.ceil(float(num_classes) / nrows).astype(int)
212 | 
213 |         fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 10))
214 |         ax = axes.ravel()
215 |         fig.suptitle(metric_name)
216 | 
217 |         for class_idx in range(1, num_classes):
218 |             class_idx_axes = class_idx - 1
219 | 
220 |             y = metric_values[..., class_idx]
221 |             x = np.tile(np.arange(0, y.shape[1]), reps=(y.shape[0], 1))
222 |             y = y.ravel()
223 |             x = x.ravel()
224 | 
225 |             sel = (y != 0)
226 |             y = y[sel]
227 |             x = x[sel]
228 | 
229 |             sns.lineplot(x=x, y=y, err_style='band', ax=ax[class_idx_axes],
230 |                          color="salmon")
231 |             ax[class_idx_axes].set_ylim((0.4, 1))
232 |             ax[class_idx_axes].set_xlim((10, 150))
233 | 
234 |             ax[class_idx_axes].set_xlabel('Slice index')
235 |             ax[class_idx_axes].set_ylabel('DSC')
236 |             ax[class_idx_axes].set_title(class_names[class_idx])
237 | 
238 |             # for label in (ax[class_idx_axes].get_xticklabels() +
239 |             #               ax[class_idx_axes].get_yticklabels()):
240 |             #     label.set_fontsize(15)
241 | 
242 |         plt.tight_layout()
243 | 
244 |         fname_vis = f"metrics_{config['dataset']}_test_slicew_confid_{metric_name}.png"
245 |         path_vis = os.path.join(path_root_out, fname_vis)
246 |         plt.savefig(path_vis)
247 |         if config['interactive']:
248 |             plt.show()
249 |         else:
250 |             plt.close()
251 | 
252 | 
253 | def _scan_to_metrics_paired(path_root, name_true, name_pred, num_classes, meta):
254 |     # Read the data
255 |     patt_true = os.path.join(path_root, meta['patient'], meta['release'],
256 |                              meta['sequence'], name_true, '*.png')
257 |     stack_true = png_to_numpy(patt_true)
258 | 
259 |     patt_pred = os.path.join(path_root, meta['patient'], meta['release'],
260 |                              meta['sequence'], name_pred, '*.png')
261 |     stack_pred = png_to_numpy(patt_pred)
262 | 
263 |     # Add batch dimension
264 |     stack_true = stack_true[None, ...]
265 |     stack_pred = stack_pred[None, ...]
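    # NOTE (editor): the Dice similarity coefficient computed below is, per
    # class c, DSC_c = 2 * |P_c & T_c| / (|P_c| + |T_c|), where P_c and T_c are
    # the predicted and reference voxel sets. An illustrative NumPy version
    # (`rocaseg.components.dice_score` is the actual implementation):
    #
    #     inter = np.logical_and(pred == c, true == c).sum()
    #     dsc_c = 2.0 * inter / ((pred == c).sum() + (true == c).sum())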
266 | 267 | # Compute the metrics 268 | res = meta 269 | res['dice_score'] = dice_score(stack_pred, stack_true, 270 | num_classes=num_classes) 271 | return res 272 | 273 | 274 | def metrics_paired_volumew(*, path_pred, dirname_true, dirname_pred, df, num_classes, 275 | num_workers=1): 276 | acc_l = [] 277 | groupers_stack = ['patient', 'release', 'sequence', 'side'] 278 | 279 | for name_gb, df_gb in tqdm(df.groupby(groupers_stack)): 280 | patient, release, sequence, side = name_gb 281 | 282 | acc_l.append(delayed(_scan_to_metrics_paired)( 283 | path_root=path_pred, 284 | name_true=dirname_true, 285 | name_pred=dirname_pred, 286 | num_classes=num_classes, 287 | meta={ 288 | 'patient': patient, 289 | 'release': release, 290 | 'sequence': sequence, 291 | 'side': side, 292 | **df_gb.to_dict("records")[0], 293 | } 294 | )) 295 | 296 | acc_l = Parallel(n_jobs=num_workers, verbose=10)( 297 | t for t in tqdm(acc_l, total=len(acc_l))) 298 | # Convert from list of dicts to dict of lists 299 | acc_d = {k: [d[k] for d in acc_l] for k in acc_l[0]} 300 | return acc_d 301 | 302 | 303 | def plot_metrics_paired_volumew_n5a9(acc, path_root_out, num_classes, class_names, 304 | metric_names, config): 305 | """Average and visualize the metrics. 306 | 307 | Args: 308 | acc: dict of lists 309 | path_root_out: str 310 | num_classes: int 311 | class_names: dict 312 | metric_names: iterable of str 313 | config: dict 314 | 315 | """ 316 | groupers_stack = ['patient', 'release', 'sequence', 'side'] 317 | 318 | acc_df = pd.DataFrame.from_dict(acc) 319 | acc_df = acc_df.sort_values(groupers_stack) 320 | 321 | for metric_name in metric_names: 322 | if metric_name not in acc: 323 | logger.error(f'`{metric_name}` is not presented in `acc`') 324 | continue 325 | 326 | metric_values = [] 327 | for gb_name, gb_df in acc_df.groupby(groupers_stack): 328 | tmp = np.asarray([np.asarray(e) for e in gb_df[metric_name]]) 329 | metric_values.append(tmp) 330 | metric_values = np.stack(metric_values, axis=0) 331 | # Axes order: (scan, class_idx[, sub_metric]) 332 | 333 | nrows = np.floor(np.sqrt(num_classes)).astype(int) 334 | ncols = np.ceil(float(num_classes) / nrows).astype(int) 335 | 336 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 10)) 337 | ax = axes.ravel() 338 | 339 | tmp_means, tmp_stds = [], [] 340 | 341 | xlims = { 342 | 'dice_score': (0.5, 1.0), 343 | } 344 | 345 | for class_idx, class_scores in enumerate(metric_values.T[1:], start=1): 346 | class_idx_axes = class_idx - 1 347 | 348 | if metric_name == 'dice_score': 349 | class_scores = class_scores[class_scores != 0] 350 | class_scores = np.squeeze(class_scores) 351 | 352 | tmp_means.append(np.mean(class_scores)) 353 | tmp_stds.append(np.std(class_scores)) 354 | 355 | if metric_name in xlims: 356 | ax[class_idx_axes].hist(class_scores, bins=50, 357 | range=xlims[metric_name], 358 | density=True) 359 | else: 360 | ax[class_idx_axes].hist(class_scores, bins=50, 361 | density=True) 362 | ax[class_idx_axes].set_title(class_names[class_idx]) 363 | 364 | tmp_values = ', '.join( 365 | [f'{m:.03f}({s:.03f})' for m, s in zip(tmp_means, tmp_stds)]) 366 | logger.info(f"{metric_name}:\n{tmp_values}\n") 367 | 368 | plt.tight_layout() 369 | fname_vis = f"metrics_{config['dataset']}_test_volumew_{metric_name}.png" 370 | path_vis = os.path.join(path_root_out, fname_vis) 371 | plt.savefig(path_vis) 372 | if config['interactive']: 373 | plt.show() 374 | else: 375 | plt.close() 376 | 377 | 378 | def print_metrics_paired_volumew_b46e(acc, metric_names): 379 | """Average 
and print the metrics with respect to KL grades. 380 | 381 | Args: 382 | acc: dict of lists 383 | metric_names: tuple of str 384 | 385 | """ 386 | kl_vec = np.asarray(acc['KL']) 387 | 388 | for kl_value in np.unique(kl_vec): 389 | kl_sel = kl_vec == kl_value 390 | 391 | for metric_name in metric_names: 392 | if metric_name not in acc: 393 | logger.error(f'`{metric_name}` is not presented in `acc`') 394 | continue 395 | metric_values = acc[metric_name] 396 | tmp = np.asarray(metric_values)[kl_sel] 397 | 398 | tmp_means, tmp_stds = [], [] 399 | for class_idx, class_scores in enumerate( 400 | np.moveaxis(tmp[..., 1:], -1, 0), start=1): 401 | class_scores = class_scores[class_scores != 0] 402 | 403 | tmp_means.append(np.mean(class_scores)) 404 | tmp_stds.append(np.std(class_scores)) 405 | 406 | tmp_values = ', '.join( 407 | [f'{m:.03f}({s:.03f})' for m, s in zip(tmp_means, tmp_stds)]) 408 | logger.info(f"KL{kl_value}, {metric_name}\n{tmp_values}\n") 409 | 410 | 411 | @click.command() 412 | @click.option('--path_experiment_root', default='../../results/temporary') 413 | @click.option('--dirname_pred', required=True) 414 | @click.option('--dirname_true', required=True) 415 | @click.option('--dataset', required=True, type=click.Choice( 416 | ['oai_imo', 'okoa', 'maknee'])) 417 | @click.option('--atlas', required=True, type=click.Choice( 418 | ['imo', 'segm', 'okoa'])) 419 | @click.option('--ignore_cache', is_flag=True) 420 | @click.option('--interactive', is_flag=True) 421 | @click.option('--num_workers', default=1, type=int) 422 | def main(**config): 423 | 424 | path_pred = os.path.join(config['path_experiment_root'], 425 | f"predicts_{config['dataset']}_test") 426 | path_logs = os.path.join(config['path_experiment_root'], 427 | f"logs_{config['dataset']}_test") 428 | 429 | # Get the information on object classes 430 | locations = atlas_to_locations[config['atlas']] 431 | class_names = [k for k in locations] 432 | num_classes = max(locations.values()) + 1 433 | 434 | # Get the index of image files and the corresponding metadata 435 | path_meta = os.path.join(path_pred, 'meta_dynamic.csv') 436 | df_meta = pd.read_csv(path_meta, 437 | dtype={'patient': str, 438 | 'release': str, 439 | 'prefix_var': str, 440 | 'sequence': str, 441 | 'side': str, 442 | 'slice_idx': int, 443 | 'pixel_spacing_0': float, 444 | 'pixel_spacing_1': float, 445 | 'slice_thickness': float, 446 | 'KL': int, 447 | 'has_mask': int}, 448 | index_col=False) 449 | 450 | df_sel = df_meta.sort_values(['patient', 'release', 'sequence', 'side', 'slice_idx']) 451 | 452 | # -------------------------------- Planar ------------------------------------------ 453 | fname_pkl = os.path.join(path_logs, 454 | f"cache_{config['dataset']}_test_" 455 | f"{config['atlas']}_" 456 | f"slicew_paired.pkl") 457 | 458 | logger.info('Planar scores') 459 | if os.path.exists(fname_pkl) and not config['ignore_cache']: 460 | logger.info('Loading from the cache') 461 | with open(fname_pkl, 'rb') as f: 462 | acc_slicew = pickle.load(f) 463 | else: 464 | logger.info('Computing') 465 | acc_slicew = metrics_paired_slicew(path_pred=path_pred, 466 | dirname_true=config['dirname_true'], 467 | dirname_pred=config['dirname_pred'], 468 | df=df_sel, 469 | num_classes=num_classes, 470 | num_workers=config['num_workers']) 471 | logger.info('Caching the results into file') 472 | os.makedirs(path_logs, exist_ok=True) 473 | with open(fname_pkl, 'wb') as f: 474 | pickle.dump(acc_slicew, f) 475 | 476 | plot_metrics_paired_slicew_xvd7( 477 | acc=acc_slicew, 478 | 
path_root_out=path_logs, 479 | num_classes=num_classes, 480 | class_names=class_names, 481 | metric_names=('dice_score', ), 482 | config=config, 483 | ) 484 | 485 | print_metrics_paired_slicew_pua5( 486 | acc=acc_slicew, 487 | metric_names=('dice_score',), 488 | ) 489 | 490 | plot_metrics_paired_slicew_va49( 491 | acc=acc_slicew, 492 | path_root_out=path_logs, 493 | num_classes=num_classes, 494 | class_names=class_names, 495 | metric_names=('dice_score',), 496 | config=config, 497 | ) 498 | 499 | # ------------------------------- Volumetric --------------------------------------- 500 | fname_pkl = os.path.join(path_logs, 501 | f"cache_{config['dataset']}_test_" 502 | f"{config['atlas']}_" 503 | f"volumew_paired.pkl") 504 | 505 | logger.info('Volumetric scores') 506 | if os.path.exists(fname_pkl) and not config['ignore_cache']: 507 | logger.info('Loading from the cache') 508 | with open(fname_pkl, 'rb') as f: 509 | acc_volumew = pickle.load(f) 510 | else: 511 | logger.info('Computing') 512 | acc_volumew = metrics_paired_volumew( 513 | path_pred=path_pred, 514 | dirname_true=config['dirname_true'], 515 | dirname_pred=config['dirname_pred'], 516 | df=df_sel, 517 | num_classes=num_classes, 518 | num_workers=config['num_workers'] 519 | ) 520 | logger.info('Caching the results into file') 521 | os.makedirs(path_logs, exist_ok=True) 522 | with open(fname_pkl, 'wb') as f: 523 | pickle.dump(acc_volumew, f) 524 | 525 | plot_metrics_paired_volumew_n5a9( 526 | acc=acc_volumew, 527 | path_root_out=path_logs, 528 | num_classes=num_classes, 529 | class_names=class_names, 530 | metric_names=('dice_score', ), 531 | config=config, 532 | ) 533 | 534 | print_metrics_paired_volumew_b46e( 535 | acc=acc_volumew, 536 | metric_names=('dice_score', ), 537 | ) 538 | 539 | 540 | if __name__ == '__main__': 541 | main() 542 | --------------------------------------------------------------------------------
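Editor's note: train_baseline.py above accumulates a validation-set confusion matrix and derives per-class Dice scores from it via `confusion_matrix` and `dice_score_from_cm` (defined in rocaseg/components/metrics.py, not shown here). A minimal sketch of that reduction, assuming the standard definition of Dice; `dice_from_cm` is a hypothetical name, not the library's API:

import numpy as np

def dice_from_cm(cm):
    """Per-class Dice from a confusion matrix cm[i, j] = #(true i, pred j)."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as the class, but belonging to another
    fn = cm.sum(axis=1) - tp  # belonging to the class, but predicted as another
    return 2 * tp / (2 * tp + fp + fn)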