├── github_image.png
├── rocaseg
│   ├── repro
│   │   ├── __init__.py
│   │   └── _seed.py
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── discr_a.py
│   │   ├── unet_lext.py
│   │   └── unet_lext_aux.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── meta_oai.py
│   │   ├── dataset_maknee.py
│   │   ├── prepare_dataset_maknee.py
│   │   ├── sources.py
│   │   ├── dataset_okoa.py
│   │   ├── prepare_dataset_okoa.py
│   │   ├── dataset_oai_imo.py
│   │   └── prepare_dataset_oai_imo.py
│   ├── preproc
│   │   ├── __init__.py
│   │   ├── custom.py
│   │   └── transforms.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── mixup.py
│   │   ├── losses.py
│   │   ├── checkpoint.py
│   │   ├── formats.py
│   │   └── metrics.py
│   ├── describe.py
│   ├── resample.py
│   ├── analyze_predictions_multi.py
│   ├── evaluate.py
│   ├── train_baseline.py
│   └── analyze_predictions_single.py
├── MANIFEST.in
├── environment.yml
├── setup.py
├── README.md
├── notebooks
│   └── Statistical_tests.ipynb
└── scripts
    └── runner.sh
/github_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIPT-Oulu/RobustCartilageSegmentation/HEAD/github_image.png
--------------------------------------------------------------------------------
/rocaseg/repro/__init__.py:
--------------------------------------------------------------------------------
1 | from ._seed import set_ultimate_seed
2 |
3 |
4 | __all__ = [
5 | 'set_ultimate_seed',
6 | ]
7 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 |
3 | recursive-include tests *
4 | recursive-exclude * __pycache__
5 | recursive-exclude * *.py[co]
6 |
--------------------------------------------------------------------------------
/rocaseg/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Top-level package for rocaseg."""
4 |
5 | __author__ = """Egor Panfilov"""
6 | __email__ = 'egor.v.panfilov@gmail.com'
7 | __version__ = '0.1.0'
8 |
--------------------------------------------------------------------------------
/rocaseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_lext import UNetLext
2 | from .unet_lext_aux import UNetLextAux
3 |
4 | from .discr_a import DiscriminatorA
5 |
6 |
7 | dict_models = {
8 | 'unet_lext': UNetLext,
9 | 'unet_lext_aux': UNetLextAux,
10 |
11 | 'discriminator_a': DiscriminatorA,
12 | }
13 |
--------------------------------------------------------------------------------
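
A minimal instantiation sketch (illustrative, not a repository file); the `discriminator_a` arguments shown are the defaults from `discr_a.py`, while the U-Net constructor signatures live in their respective modules:

```python
from rocaseg.models import dict_models

# Look up a model class by its registry key and instantiate it.
model_cls = dict_models['discriminator_a']
model = model_cls(basic_width=64, input_channels=5, output_channels=1)
```
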
/rocaseg/repro/_seed.py:
--------------------------------------------------------------------------------
1 | def set_ultimate_seed(base_seed=777):
2 | import os
3 | import random
4 | os.environ['PYTHONHASHSEED'] = str(base_seed)
5 | random.seed(base_seed)
6 |
7 | try:
8 | import numpy as np
9 | np.random.seed(base_seed)
10 | except ModuleNotFoundError:
11 | print('Module `numpy` has not been found')
12 | try:
13 | import torch
14 | torch.manual_seed(base_seed + 1)
15 | torch.cuda.manual_seed_all(base_seed + 2)
16 | torch.backends.cudnn.deterministic = True
17 | except ModuleNotFoundError:
18 | print('Module `torch` has not been found')
19 |
--------------------------------------------------------------------------------
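
A usage sketch (illustrative): a single call at program start, before any models or dataloaders are built:

```python
from rocaseg.repro import set_ultimate_seed

# Seeds `random`, PYTHONHASHSEED, NumPy, and PyTorch (CPU and CUDA) at once,
# and makes cuDNN deterministic.
set_ultimate_seed(base_seed=777)
```
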
/environment.yml:
--------------------------------------------------------------------------------
1 | name: env_rocaseg
2 | channels:
3 | - pytorch
4 | - simpleitk
5 | - defaults
6 | - conda-forge
7 | dependencies:
8 | - python=3.7.3
9 | - pip=19.1.1
10 | - joblib
11 | - cython
12 | - numpy
13 | - scipy
14 | - pandas=0.23.4
15 | - scikit-image=0.15.0
16 | - pywavelets
17 | - scikit-learn
18 | - matplotlib
19 | - seaborn
20 | - opencv
21 | - jupyter
22 | - xlrd
23 | - pytorch=1.2.0
24 | - torchvision=0.4.0
25 | # - cudatoolkit=10.0
26 | - simpleitk
27 | - pip:
28 | - pydicom==1.2.2
29 | - dicom2nifti==2.1.5
30 | - tifffile
31 | - sas7bdat
32 | - nibabel
33 | - tqdm
34 | - click
35 | - tensorboard>=1.14.0
36 | - tensorflow>=1.14.0
37 |
--------------------------------------------------------------------------------
/rocaseg/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset_oai_imo import (DatasetOAIiMoSagittal2d,
2 | index_from_path_oai_imo)
3 | from .dataset_okoa import (DatasetOKOASagittal2d,
4 | index_from_path_okoa)
5 | from .dataset_maknee import (DatasetMAKNEESagittal2d,
6 | index_from_path_maknee)
7 | from . import meta_oai
8 | from . import constants
9 | from .sources import sources_from_path
10 |
11 |
12 | __all__ = [
13 | 'index_from_path_oai_imo',
14 | 'index_from_path_okoa',
15 | 'index_from_path_maknee',
16 | 'DatasetOAIiMoSagittal2d',
17 | 'DatasetOKOASagittal2d',
18 | 'DatasetMAKNEESagittal2d',
19 | 'meta_oai',
20 | 'constants',
21 | 'sources_from_path',
22 | ]
23 |
--------------------------------------------------------------------------------
/rocaseg/preproc/__init__.py:
--------------------------------------------------------------------------------
1 | from .custom import Normalize, UnNormalize, PercentileClippingAndToFloat
2 | from .transforms import (DualCompose, OneOf, OneOrOther, ImageOnly, NoTransform,
3 | ToTensor, VerticalFlip, HorizontalFlip, Flip, Scale,
4 | Crop, CenterCrop, Pad, GammaCorrection, BilateralFilter)
5 |
6 |
7 | __all__ = [
8 | 'Normalize',
9 | 'UnNormalize',
10 | 'PercentileClippingAndToFloat',
11 | 'DualCompose',
12 | 'OneOf',
13 | 'OneOrOther',
14 | 'ImageOnly',
15 | 'NoTransform',
16 | 'ToTensor',
17 | 'VerticalFlip',
18 | 'HorizontalFlip',
19 | 'Flip',
20 | 'Scale',
21 | 'Crop',
22 | 'CenterCrop',
23 | 'Pad',
24 | 'GammaCorrection',
25 | 'BilateralFilter',
26 | ]
27 |
--------------------------------------------------------------------------------
/rocaseg/components/__init__.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch import optim
3 | from rocaseg.components.losses import CrossEntropyLoss
4 | from rocaseg.components.metrics import (confusion_matrix, dice_score,
5 | dice_score_from_cm)
6 | from rocaseg.components.checkpoint import CheckpointHandler
7 |
8 |
9 | dict_losses = {
10 | 'bce_loss': nn.BCEWithLogitsLoss,
11 | 'multi_ce_loss': CrossEntropyLoss,
12 | }
13 |
14 |
15 | dict_metrics = {
16 | 'confusion_matrix': confusion_matrix,
17 | 'dice_score': dice_score,
18 | 'bce_loss': nn.BCELoss(),
19 | }
20 |
21 |
22 | dict_optimizers = {
23 | 'sgd': optim.SGD,
24 | 'adam': optim.Adam,
25 | }
26 |
27 |
28 | __all__ = [
29 | 'dict_losses',
30 | 'dict_metrics',
31 | 'dict_optimizers',
32 | 'confusion_matrix',
33 | 'dice_score',
34 | 'dice_score_from_cm',
35 | 'CheckpointHandler',
36 | ]
37 |
--------------------------------------------------------------------------------
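
A sketch of consuming the registries by their configuration names; the convolution here is a stand-in model, not part of the repository:

```python
import torch
from rocaseg.components import dict_losses, dict_optimizers

# Resolve training components by their registry keys
criterion = dict_losses['multi_ce_loss'](num_classes=5)
model = torch.nn.Conv2d(1, 5, kernel_size=3, padding=1)  # stand-in model
optimizer = dict_optimizers['adam'](model.parameters(), lr=1e-4)
```
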
/rocaseg/datasets/constants.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 |
4 | """
5 | OAI iMorphics reference classes (DICOM attribute names). Cartilage tissues
6 | """
7 | locations_mh53 = OrderedDict([
8 | ('Background', 0),
9 | ('FemoralCartilage', 1),
10 | ('LateralTibialCartilage', 2),
11 | ('MedialTibialCartilage', 3),
12 | ('PatellarCartilage', 4),
13 | ('LateralMeniscus', 5),
14 | ('MedialMeniscus', 6),
15 | ])
16 |
17 |
18 | """
19 | Segmentation predictions. Major cartilage tissues. Joined L and M
20 | """
21 | locations_f43h = OrderedDict([
22 | ('_background', 0),
23 | ('femoral', 1),
24 | ('tibial', 2),
25 | ])
26 |
27 |
28 | """
29 | Segmentation predictions. Cartilage tissues. Joined L and M
30 | """
31 | locations_zp3n = OrderedDict([
32 | ('_background', 0),
33 | ('femoral', 1),
34 | ('tibial', 2),
35 | ('patellar', 3),
36 | ('menisci', 4),
37 | ])
38 |
39 |
40 | atlas_to_locations = {
41 | 'imo': locations_mh53,
42 | 'segm': locations_zp3n,
43 | 'okoa': locations_f43h,
44 | }
45 |
--------------------------------------------------------------------------------
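
A lookup sketch (illustrative):

```python
from rocaseg.datasets.constants import atlas_to_locations

locations = atlas_to_locations['segm']     # -> locations_zp3n
num_classes = max(locations.values()) + 1  # 5, incl. background
print(locations['femoral'])                # 1
```
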
/rocaseg/datasets/meta_oai.py:
--------------------------------------------------------------------------------
1 | release_to_prefix_var = {
2 | '0.C.2': 'V00', '0.E.1': 'V00',
3 | '1.C.2': 'V01', '1.E.1': 'V01',
4 | '2.D.2': 'V02',
5 | '3.C.2': 'V03', '3.E.1': 'V03',
6 | '4.G.1': 'V04',
7 | '5.C.1': 'V05', '5.E.1': 'V05',
8 | '6.C.1': 'V06', '6.E.1': 'V06',
9 | '8.C.1': 'V08', '8.E.1': 'V08',
10 | '10.C.1': 'V10', '10.E.1': 'V10',
11 | }
12 |
13 | prefix_var_to_visit_month = {
14 | 'V00': '000m',
15 | 'V01': '012m',
16 | 'V02': '018m',
17 | 'V03': '024m',
18 | 'V04': '030m',
19 | 'V05': '036m',
20 | 'V06': '048m',
21 | 'V07': '060m',
22 | 'V08': '072m',
23 | 'V09': '084m',
24 | 'V10': '096m',
25 | 'V11': '108m',
26 | }
27 |
28 | release_to_visit_month = {
29 | '0.C.2': '000m', '0.E.1': '000m',
30 | '1.C.2': '012m', '1.E.1': '012m',
31 | '2.D.2': '018m',
32 | '3.C.2': '024m', '3.E.1': '024m',
33 | '4.G.1': '030m',
34 | '5.C.1': '036m', '5.E.1': '036m',
35 | '6.C.1': '048m', '6.E.1': '048m',
36 | '8.C.1': '072m', '8.E.1': '072m',
37 | '10.C.1': '096m', '10.E.1': '096m',
38 |
39 | }
40 |
41 | side_code_to_str = {1: 'RIGHT', 2: 'LEFT'}
42 |
--------------------------------------------------------------------------------
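
A sketch of chaining the lookup tables (illustrative):

```python
from rocaseg.datasets.meta_oai import (release_to_prefix_var,
                                       prefix_var_to_visit_month,
                                       side_code_to_str)

prefix = release_to_prefix_var['1.C.2']    # 'V01'
month = prefix_var_to_visit_month[prefix]  # '012m'
side = side_code_to_str[1]                 # 'RIGHT'
```
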
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """The setup script."""
5 |
6 | from setuptools import setup, find_packages
7 |
8 | with open('README.md') as readme_file:
9 | readme = readme_file.read()
10 |
11 | requirements = []
12 |
13 | setup_requirements = []
14 |
15 | test_requirements = []
16 |
17 | setup(
18 | author="Egor Panfilov",
19 | author_email='egor.v.panfilov@gmail.com',
20 | classifiers=[
21 | 'Development Status :: 4 - Beta',
22 | 'Intended Audience :: Science/Research',
23 | 'License :: OSI Approved :: MIT License',
24 | 'Natural Language :: English',
25 | 'Programming Language :: Python :: 3',
26 | 'Programming Language :: Python :: 3.6',
27 | 'Programming Language :: Python :: 3.7',
28 | ],
29 | description="Framework for knee cartilage and menisci segmentation from MRI",
30 | install_requires=requirements,
31 | license="MIT license",
32 | long_description=readme,
33 | include_package_data=True,
34 | name='rocaseg',
35 |     packages=find_packages(include=['rocaseg', 'rocaseg.*']),
36 | setup_requires=setup_requirements,
37 | tests_require=test_requirements,
38 | url='https://github.com/MIPT-Oulu/RobustCartilageSegmentation',
39 | version='0.1.0',
40 | zip_safe=False,
41 | )
42 |
--------------------------------------------------------------------------------
/rocaseg/components/mixup.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | """
5 | Example usage:
6 |
7 | # Regular segmentation loss:
8 | ys_pred_oai = self.models['segm'](xs_oai)
9 | loss_segm = self.losses['segm'](input_=ys_pred_oai,
10 | target=ys_true_arg_oai)
11 |
12 | # Mixup
13 | xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
14 | x=xs_oai, y=ys_true_arg_oai,
15 | alpha=self.config['mixup_alpha'], device=maybe_gpu)
16 | ys_pred_oai = self.models['segm'](xs_mixup)
17 | loss_segm = mixup_criterion(criterion=self.losses['segm'],
18 | pred=ys_pred_oai,
19 | y_a=ys_mixup_a,
20 | y_b=ys_mixup_b,
21 | lam=lambda_mixup)
22 | """
23 |
24 |
25 | def mixup_data(x, y, alpha=1.0, device='cpu'):
26 | """Returns mixed inputs, pairs of targets, and lambda"""
27 | if alpha > 0:
28 | lam = np.random.beta(alpha, alpha)
29 | else:
30 | lam = 1
31 |
32 | batch_size = x.size()[0]
33 | index = torch.randperm(batch_size).to(device)
34 |
35 | mixed_x = lam * x + (1 - lam) * x[index, :]
36 | y_a, y_b = y, y[index]
37 | return mixed_x, y_a, y_b, lam
38 |
39 |
40 | def mixup_criterion(criterion, pred, y_a, y_b, lam):
41 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
42 |
--------------------------------------------------------------------------------
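
A self-contained sketch of the two functions on dummy tensors; the logits are stand-ins for real model output:

```python
import torch
from rocaseg.components.mixup import mixup_data, mixup_criterion

# Dummy batch: 4 single-channel 32x32 slices with per-pixel class labels
xs = torch.randn(4, 1, 32, 32)
ys = torch.randint(0, 5, (4, 32, 32))

xs_mix, ys_a, ys_b, lam = mixup_data(x=xs, y=ys, alpha=1.0, device='cpu')

criterion = torch.nn.CrossEntropyLoss()
logits = torch.randn(4, 5, 32, 32, requires_grad=True)  # stand-in predictions
loss = mixup_criterion(criterion, logits, ys_a, ys_b, lam)
loss.backward()
```
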
/rocaseg/components/losses.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from torch import nn
4 |
5 |
6 | logging.basicConfig()
7 | logger = logging.getLogger('losses')
8 | logger.setLevel(logging.DEBUG)
9 |
10 |
11 | class CrossEntropyLoss(nn.Module):
12 | def __init__(self, num_classes, batch_avg=True, batch_weight=None,
13 | class_avg=True, class_weight=None, **kwargs):
14 | """
15 |
16 | Parameters
17 | ----------
18 | batch_avg:
19 | Whether to average over the batch dimension.
20 | batch_weight:
21 | Batch samples importance coefficients.
22 | class_avg:
23 | Whether to average over the class dimension.
24 | class_weight:
25 | Classes importance coefficients.
26 | """
27 | super().__init__()
28 | self.num_classes = num_classes
29 | self.batch_avg = batch_avg
30 | self.class_avg = class_avg
31 | self.batch_weight = batch_weight
32 | self.class_weight = class_weight
33 | logger.warning('Redundant loss function arguments:\n{}'
34 | .format(repr(kwargs)))
35 | self.ce = nn.CrossEntropyLoss(weight=class_weight)
36 |
37 | def forward(self, input_, target, **kwargs):
38 | """
39 |
40 | Parameters
41 | ----------
42 | input_: (b, ch, d0, d1) tensor
43 | target: (b, d0, d1) tensor
44 |
45 | Returns
46 | -------
47 | out: float tensor
48 | """
49 | return self.ce(input_, target)
50 |
--------------------------------------------------------------------------------
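
A sketch with dummy tensors matching the documented shapes (illustrative):

```python
import torch
from rocaseg.components.losses import CrossEntropyLoss

criterion = CrossEntropyLoss(num_classes=5)
logits = torch.randn(2, 5, 64, 64)          # (b, ch, d0, d1)
target = torch.randint(0, 5, (2, 64, 64))   # (b, d0, d1)
loss = criterion(input_=logits, target=target)
```
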
/rocaseg/preproc/custom.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 |
4 |
5 | logging.basicConfig()
6 | logger = logging.getLogger('preprocessing_custom')
7 | logger.setLevel(logging.DEBUG)
8 |
9 |
10 | class Normalize:
11 | def __init__(self, mean, std):
12 | self.mean = mean
13 | self.std = std
14 |
15 | def __call__(self, img, mask=None):
16 | img = img.astype(np.float32)
17 | img = (img - self.mean) / self.std
18 |
19 | if mask is not None:
20 | mask = mask.astype(np.float32)
21 | return img, mask
22 |
23 |
24 | class UnNormalize:
25 | def __init__(self, mean, std):
26 | self.mean = mean
27 | self.std = std
28 |
29 | def __call__(self, *args):
30 | return [(a * self.std + self.mean) for a in args]
31 |
32 |
33 | class PercentileClippingAndToFloat:
34 | """Change the histogram of image by doing global contrast normalization."""
35 | def __init__(self, cut_min=0.5, cut_max=99.5):
36 | """
37 | cut_min - lowest percentile which is used to cut the image histogram
38 | cut_max - highest percentile
39 | """
40 | self.cut_min = cut_min
41 | self.cut_max = cut_max
42 |
43 | def __call__(self, img, mask=None):
44 | img = img.astype(np.float32)
45 | lim_low, lim_high = np.percentile(img, [self.cut_min, self.cut_max])
46 | img = np.clip(img, lim_low, lim_high)
47 |
48 | img -= lim_low
49 | img /= img.max()
50 |
51 | img = img.astype(np.float32)
52 | if mask is not None:
53 | mask = mask.astype(np.float32)
54 |
55 | return img, mask
56 |
--------------------------------------------------------------------------------
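
A sketch chaining the two transforms on synthetic data; the normalization statistics here are illustrative, not dataset values:

```python
import numpy as np
from rocaseg.preproc import PercentileClippingAndToFloat, Normalize

img = np.random.randint(0, 4096, size=(1, 64, 64)).astype(np.uint16)
mask = np.zeros_like(img)

clip = PercentileClippingAndToFloat(cut_min=0.5, cut_max=99.5)
norm = Normalize(mean=0.5, std=0.25)  # illustrative statistics

img, mask = clip(img, mask)
img, mask = norm(img, mask)  # img is now float32, roughly zero-centered
```
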
/rocaseg/components/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import logging
4 |
5 | import torch
6 |
7 |
8 | logging.basicConfig()
9 | logger = logging.getLogger('handler')
10 | logger.setLevel(logging.DEBUG)
11 |
12 |
13 | class CheckpointHandler(object):
14 | def __init__(self, path_root,
15 | fname_pattern=('{model_name}__'
16 | 'fold_{fold_idx}__'
17 | 'epoch_{epoch_idx:>03d}.pth'),
18 | num_saved=2):
19 | self.path_root = path_root
20 | self.fname_pattern = fname_pattern
21 | self.num_saved = num_saved
22 |
23 | ext = self.fname_pattern.split('.')[-1]
24 |
25 | if not os.path.exists(path_root):
26 | raise ValueError(f'Path {path_root} does not exist')
27 |
28 | full_pattern = os.path.join(self.path_root, '*.' + ext)
29 | self._all_ckpts = sorted(glob(full_pattern, recursive=False))
30 | logger.info(f'Checkpoints found: {len(self._all_ckpts)}')
31 |
32 | self._remove_excessive_ckpts()
33 |
34 | def _remove_excessive_ckpts(self):
35 | while len(self._all_ckpts) > self.num_saved:
36 | try:
37 |                 os.remove(self._all_ckpts[0])
38 |                 logger.info(f'Removed ckpt: {self._all_ckpts[0]}')
39 |             except OSError:
40 |                 logger.error(f'Cannot remove {self._all_ckpts[0]}')
41 |             self._all_ckpts = self._all_ckpts[1:]  # drop either way, or the loop never ends
42 |
43 | def get_last_ckpt(self):
44 | if len(self._all_ckpts) == 0:
45 | logger.warning(f'No checkpoints are available in {self.path_root}')
46 | return None
47 | else:
48 | fname_ckpt_sel = self._all_ckpts[-1]
49 | return fname_ckpt_sel
50 |
51 | def save_new_ckpt(self, model, model_name, fold_idx, epoch_idx):
52 | fname = self.fname_pattern.format(model_name=model_name,
53 | fold_idx=fold_idx,
54 | epoch_idx=epoch_idx)
55 | path_full = os.path.join(self.path_root, fname)
56 | try:
57 | torch.save(model.module.state_dict(), path_full)
58 | except AttributeError:
59 | torch.save(model.state_dict(), path_full)
60 |
61 | self._all_ckpts.append(path_full)
62 | self._remove_excessive_ckpts()
63 |
--------------------------------------------------------------------------------
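
A sketch of the handler life cycle; the `./weights` path and the linear stand-in model are hypothetical:

```python
import os
import torch
from rocaseg.components import CheckpointHandler

path_weights = './weights'  # hypothetical; the handler requires it to exist
os.makedirs(path_weights, exist_ok=True)

handler = CheckpointHandler(path_root=path_weights, num_saved=2)

model = torch.nn.Linear(4, 2)  # stand-in model
handler.save_new_ckpt(model=model, model_name='unet_lext',
                      fold_idx=0, epoch_idx=3)
print(handler.get_last_ckpt())  # .../unet_lext__fold_0__epoch_003.pth
```
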
/rocaseg/models/discr_a.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | logging.basicConfig()
10 | logger = logging.getLogger('models')
11 | logger.setLevel(logging.DEBUG)
12 |
13 |
14 | class DiscriminatorA(nn.Module):
15 | def __init__(self, basic_width=64, input_channels=5, output_channels=1,
16 | restore_weights=False, path_weights=None, **kwargs):
17 | super().__init__()
18 | logger.warning('Redundant model init arguments:\n{}'
19 | .format(repr(kwargs)))
20 |
21 | # Preparing the modules dict
22 | modules = OrderedDict()
23 |
24 | modules['conv1'] = \
25 | nn.Sequential(*[
26 | nn.Conv2d(input_channels, basic_width,
27 | kernel_size=4, stride=2, padding=1),
28 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
29 | ])
30 | modules['conv2'] = \
31 | nn.Sequential(*[
32 | nn.Conv2d(basic_width, basic_width*2,
33 | kernel_size=4, stride=2, padding=1),
34 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
35 | ])
36 | modules['conv3'] = \
37 | nn.Sequential(*[
38 | nn.Conv2d(basic_width*2, basic_width*4,
39 | kernel_size=4, stride=2, padding=1),
40 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
41 | ])
42 | modules['conv4'] = \
43 | nn.Sequential(*[
44 | nn.Conv2d(basic_width*4, basic_width*8,
45 | kernel_size=4, stride=2, padding=1),
46 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
47 | ])
48 | modules['output'] = nn.Conv2d(basic_width*8, output_channels,
49 | kernel_size=4, stride=2, padding=1)
50 |
51 |         self._modules.update(modules)  # register the blocks as submodules
52 | if restore_weights:
53 | self.load_state_dict(torch.load(path_weights))
54 |
55 | def forward(self, x):
56 | tmp = x
57 |
58 |         for layer in self.children():
59 |             # blocks were registered in order, so apply them sequentially
60 |             tmp = layer(tmp)
61 |
62 | out = F.interpolate(tmp, size=x.size()[-2:],
63 | mode='bilinear', align_corners=True)
64 | return out
65 |
--------------------------------------------------------------------------------
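
A forward-pass sketch (illustrative); note that the output is bilinearly upsampled back to the input's spatial size, giving a per-pixel decision map:

```python
import torch
from rocaseg.models import DiscriminatorA

discr = DiscriminatorA(input_channels=5, output_channels=1)

probs = torch.rand(2, 5, 64, 64)  # e.g. softmax maps from the segmentor
out = discr(probs)
print(out.shape)  # torch.Size([2, 1, 64, 64])
```
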
/README.md:
--------------------------------------------------------------------------------
1 | # rocaseg - Robust Cartilage Segmentation from MRI
2 |
3 | Source code for Panfilov et al. "Improving Robustness of Deep Learning Based Knee MRI Segmentation: Mixup and Adversarial Domain Adaptation", https://arxiv.org/abs/1908.04126v3.
4 |
5 |
6 |
7 |
8 |
9 | ### Important!
10 |
11 | The camera-ready version contained a bug in the Dice score computation for tibial cartilage on Dataset C. Please refer to the arXiv version for the corrected values: https://arxiv.org/abs/1908.04126v3.
12 |
13 | ### Description
14 |
15 | 1. To reproduce the experiments from the article, one needs access to the
16 | OAI iMorphics, OKOA, and MAKNEE datasets.
17 |
18 | 2. Download the code from this repository.
19 |
20 | 3. Create a fresh Conda environment using `environment.yml`. Install the downloaded
21 | code as a Python module.
22 |
23 | 4. The `datasets/prepare_dataset_...` files show how the raw data is converted into the
24 | format supported by the training and inference pipelines.
25 |
26 | 5. The structure of the project has to be as follows:
27 | ```
28 | ./project/
29 | | ./data_raw/ # raw scans and annotations
30 | | ./OAI_iMorphics_scans/
31 | | ./OAI_iMorphics_annotations/
32 | | ./OKOA/
33 | | ./MAKNEE/
34 | | ./data/ # preprocessed scans and annotations
35 | | ./src/ (this repository)
36 | | ./results/ # models' weights, intermediate and final results
37 | | ./0_baseline/
38 | | ./weights/
39 | | ...
40 | | ./1_mixup/
41 | | ./2_mixup_nowd/
42 | | ./3_uda1/
43 | | ./4_uda2/
44 | | ./5_uda1_mixup_nowd/
45 | ```
46 |
47 | 6. File `scripts/runner.sh` contains the complete description of the workflow.
48 |
49 | 7. Statistical testing is implemented in `notebooks/Statistical_tests.ipynb`.
50 |
51 | 8. Pretrained models are available at https://drive.google.com/open?id=1f-gZ2wCf55OVjgA8oXd7xttGVW5DUUcU .
52 |
53 | ### Legal aspects
54 |
55 | This code is freely available only for research purposes.
56 |
57 | The software has not been certified as a medical device and, therefore, must not be used
58 | for diagnostic purposes.
59 |
60 | Commercial use of the provided code and the pre-trained models is strictly prohibited,
61 | since they were developed using medical datasets under restrictive licenses.
62 |
63 | ### Cite this work
64 |
65 | ```
66 | @InProceedings{Panfilov_2019_ICCV_Workshops,
67 | author = {Panfilov, Egor and Tiulpin, Aleksei and Klein, Stefan and Nieminen, Miika T. and Saarakkala, Simo},
68 | title = {Improving Robustness of Deep Learning Based Knee MRI Segmentation: Mixup and Adversarial Domain Adaptation},
69 | booktitle = {The IEEE International Conference on Computer Vision (ICCV) Workshops},
70 | month = {Oct},
71 | year = {2019}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/rocaseg/components/formats.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import logging
3 |
4 | import numpy as np
5 | import cv2
6 | import nibabel as nib
7 |
8 |
9 | logging.basicConfig()
10 | logger = logging.getLogger('formats')
11 | logger.setLevel(logging.DEBUG)
12 |
13 |
14 | def png_to_numpy(pattern_fname_in, reverse=False):
15 | """
16 |
17 | Args:
18 | pattern_fname_in: str
19 | String or regexp compatible with `glob`.
20 | reverse: bool
21 | Whether to use reverse slice order.
22 |
23 | Returns:
24 | stack: [R, C, P] ndarray
25 | """
26 | fnames_in = sorted(glob(pattern_fname_in))
27 |
28 | stack = [cv2.imread(fn, cv2.IMREAD_GRAYSCALE) for fn in fnames_in]
29 | stack = np.stack(stack, axis=2)
30 | if reverse:
31 | stack = stack[..., ::-1]
32 | return stack
33 |
34 |
35 | def png_to_nifti(pattern_fname_in, fname_out, spacings=None, reverse=False,
36 | rcp_to_ras=False):
37 | """
38 |
39 | Args:
40 | pattern_fname_in: str
41 | String or regexp compatible with `glob`.
42 | fname_out: str
43 | Full path to the output file.
44 | spacings: 3-tuple of float
45 | (pixel spacing in r, pixel spacing in c, slice thickness).
46 | reverse: bool
47 | Whether to use reverse slice order.
48 | rcp_to_ras: bool
49 | Whether to convert from row-column-plane to RAS+ coordinates.
50 |
51 | """
52 | fnames_in = sorted(glob(pattern_fname_in))
53 |
54 | stack = [cv2.imread(fn, cv2.IMREAD_GRAYSCALE) for fn in fnames_in]
55 | stack = np.stack(stack, axis=2)
56 | if reverse:
57 | stack = stack[..., ::-1]
58 |
59 | numpy_to_nifti(stack=stack, fname_out=fname_out, spacings=spacings,
60 | rcp_to_ras=rcp_to_ras)
61 |
62 |
63 | def nifti_to_png(fname_in, pattern_fname_out, reverse=False, ras_to_rcp=False):
64 | """
65 |
66 | Args:
67 | fname_in: str
68 | Full path to the input file.
69 | pattern_fname_out: str
70 | Must include `{i}`, which is to be substituted with the running index.
71 | reverse: bool
72 | Whether to use reverse slice order.
73 | ras_to_rcp: bool
74 | Whether to convert from RAS+ to row-column-plane coordinates.
75 | """
76 | stack, spacings = nifti_to_numpy(fname_in=fname_in, ras_to_rcp=ras_to_rcp)
77 |
78 | if reverse:
79 | stack = stack[..., ::-1]
80 |
81 | for i in range(stack.shape[-1]):
82 | fn = pattern_fname_out.format(i=i)
83 | cv2.imwrite(fn, stack[..., i])
84 |
85 |
86 | def nifti_to_numpy(fname_in, ras_to_rcp=False):
87 | """
88 |
89 | Args:
90 | fname_in: str
91 | Full path to the input file.
92 | ras_to_rcp: bool
93 | Whether to convert from RAS+ to row-column-plane coordinates.
94 |
95 | Returns:
96 | stack: [R, C, P] ndarray
97 | spacings: 3-tuple of float
98 | (pixel spacing in r, pixel spacing in c, slice thickness).
99 |
100 | """
101 | scan = nib.load(fname_in)
102 | stack = scan.get_fdata()
103 | spacings = [scan.affine[i, i] for i in range(3)]
104 |
105 | if ras_to_rcp:
106 | stack = np.moveaxis(stack, [2, 1, 0], [0, 1, 2])
107 | spacings = [-s for s in spacings[::-1]]
108 |
109 | return stack, spacings
110 |
111 |
112 | def numpy_to_nifti(stack, fname_out, spacings=None, rcp_to_ras=False):
113 | """
114 |
115 | Args:
116 | stack: (r, c, p) ndarray
117 | Data array.
118 | fname_out:
119 | Full path to the output file.
120 | spacings: 3-tuple of float
121 | (pixel spacing in r, pixel spacing in c, slice thickness).
122 | rcp_to_ras: bool
123 | Whether to convert from row-column-plane to RAS+ coordinates.
124 | """
125 | if not rcp_to_ras:
126 |         affine = np.eye(4, dtype=float)
127 | if spacings is not None:
128 | affine[0, 0] = spacings[0]
129 | affine[1, 1] = spacings[1]
130 | affine[2, 2] = spacings[2]
131 | else:
132 | stack = np.moveaxis(stack, [0, 1, 2], [2, 1, 0])
133 |         affine = np.diag([-1., -1., -1., 1.]).astype(float)
134 | if spacings is not None:
135 | affine[0, 0] = -spacings[2]
136 | affine[1, 1] = -spacings[1]
137 | affine[2, 2] = -spacings[0]
138 |
139 | scan = nib.Nifti1Image(stack, affine=affine)
140 | nib.save(scan, fname_out)
141 |
--------------------------------------------------------------------------------
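
A round-trip sketch on a synthetic stack; the `/tmp/stack.nii.gz` path is hypothetical:

```python
import numpy as np
from rocaseg.components.formats import numpy_to_nifti, nifti_to_numpy

stack = np.random.rand(64, 64, 16).astype(np.float32)  # (rows, cols, planes)
numpy_to_nifti(stack=stack, fname_out='/tmp/stack.nii.gz',
               spacings=(0.36, 0.36, 1.5))
stack_back, spacings = nifti_to_numpy('/tmp/stack.nii.gz')
assert stack_back.shape == stack.shape
```
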
/rocaseg/describe.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from collections import defaultdict
4 |
5 | import click
6 |
7 | import torch
8 | from torch.utils.data.dataloader import DataLoader
9 |
10 | from rocaseg.datasets import sources_from_path
11 | from rocaseg.preproc import *
12 | from rocaseg.repro import set_ultimate_seed
13 |
14 |
15 | logging.basicConfig()
16 | logger = logging.getLogger('describe')
17 | logger.setLevel(logging.DEBUG)
18 |
19 | set_ultimate_seed()
20 |
21 | if torch.cuda.is_available():
22 | maybe_gpu = 'cuda'
23 | else:
24 | maybe_gpu = 'cpu'
25 |
26 |
27 | class Describer:
28 | def __init__(self, config):
29 | self.config = config
30 |
31 | def run(self, loader):
32 | metrics_avg = defaultdict(float)
33 |
34 | for i, data_batch in enumerate(loader):
35 | metrics_curr = dict()
36 | xs, ys_true = data_batch
37 | # xs, ys_true = xs.to(maybe_gpu), ys_true.to(maybe_gpu)
38 |
39 | # Calculate metrics
40 | with torch.no_grad():
41 | e = self.config['metrics_skip_edge']
42 | if e != 0:
43 | metrics_curr['mean'] = xs[:, :, e:-e, e:-e].mean()
44 | metrics_curr['std'] = xs[:, :, e:-e, e:-e].std()
45 | metrics_curr['var'] = xs[:, :, e:-e, e:-e].var()
46 | else:
47 | metrics_curr['mean'] = xs.mean()
48 | metrics_curr['std'] = xs.std()
49 | metrics_curr['var'] = xs.var()
50 | for k, v in metrics_curr.items():
51 | metrics_avg[k] += v
52 |
53 | # Add metrics logging
54 | logger.info('Metrics:')
55 | metrics_avg = {k: v / len(loader)
56 | for k, v in metrics_avg.items()}
57 | for k, v in metrics_avg.items():
58 | logger.info(f'{k}: {v}')
59 |
60 |
61 | @click.command()
62 | @click.option('--path_data_root', default='../../data')
63 | @click.option('--path_experiment_root', default='../../results/temporary')
64 | @click.option('--dataset', type=click.Choice(
65 | ['oai_imo', 'okoa', 'maknee']))
66 | @click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
67 | @click.option('--sample_mode', default='x_y', type=str)
68 | @click.option('--batch_size', default=64, type=int)
69 | @click.option('--num_workers', default=1, type=int)
70 | @click.option('--seed_trainval_test', default=0, type=int)
71 | @click.option('--metrics_skip_edge', default=0, type=int)
72 | def main(**config):
73 | config['path_logs'] = os.path.join(
74 | config['path_experiment_root'], f"logs_{config['dataset']}_describe")
75 |
76 | os.makedirs(config['path_logs'], exist_ok=True)
77 |
78 | logging_fh = logging.FileHandler(
79 | os.path.join(config['path_logs'], 'main.log'))
80 | logging_fh.setLevel(logging.DEBUG)
81 | logger.addHandler(logging_fh)
82 |
83 | # Collect the available and specified sources
84 | sources = sources_from_path(path_data_root=config['path_data_root'],
85 | selection=config['dataset'],
86 | with_folds=False,
87 | seed_trainval_test=config['seed_trainval_test'])
88 |
89 | if config['dataset'] == 'oai_imo':
90 | from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
91 | elif config['dataset'] == 'okoa':
92 | from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
93 | elif config['dataset'] == 'maknee':
94 | from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
95 | else:
96 | raise ValueError('Unknown dataset')
97 |
98 | for subset in ('trainval', 'test'):
99 | name = subset
100 | df = sources[config['dataset']][f"{subset}_df"]
101 |
102 | dataset = DatasetSagittal2d(
103 | df_meta=df, mask_mode=config['mask_mode'], name=name,
104 | sample_mode=config['sample_mode'],
105 | transforms=[
106 | PercentileClippingAndToFloat(cut_min=10, cut_max=99),
107 | ToTensor()
108 | ])
109 |
110 | loader = DataLoader(dataset,
111 | batch_size=config['batch_size'],
112 | shuffle=False,
113 | num_workers=config['num_workers'],
114 | pin_memory=True,
115 | drop_last=False)
116 | describer = Describer(config=config)
117 |
118 | describer.run(loader)
119 | loader.dataset.describe()
120 |
121 |
122 | if __name__ == '__main__':
123 | main()
124 |
--------------------------------------------------------------------------------
/notebooks/Statistical_tests.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "from glob import glob\n",
11 | "import pickle\n",
12 | "\n",
13 | "import numpy as np\n",
14 | "import scipy\n",
15 | "import scipy.stats"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "# TODO: uncomment only two experiments\n",
25 | "EXPERIMENTS = (\n",
26 | "# '0_baseline',\n",
27 | "# '1_mixup',\n",
28 | "# '2_mixup_nowd',\n",
29 | "# '3_uda1',\n",
30 | "# '4_uda2',\n",
31 | "# '5_uda1_mixup_nowd'\n",
32 | ")\n",
33 | "\n",
34 | "# TODO: uncomment the section related to the selected evaluation dataset\n",
35 | "# if True:\n",
36 | "# DATASET = 'oai_imo'\n",
37 | "# CLASSES = list(range(1, 5))\n",
38 | "# ATLAS = 'segm'\n",
39 | "# if True:\n",
40 | "# DATASET = 'okoa'\n",
41 | "# CLASSES = (1, 2)\n",
42 | "# ATLAS = 'okoa'\n",
43 | "\n",
44 | "# TODO: specify path to your `results` directory\n",
45 | "path_results_root = ''\n",
46 | "\n",
47 | "path_base = os.path.join(path_results_root,\n",
48 | " (f'{EXPERIMENTS[0]}/logs_{DATASET}_test'\n",
49 | " f'/cache_{DATASET}_test_{ATLAS}_volumew_paired.pkl'))\n",
50 | "path_eval = os.path.join(path_results_root,\n",
51 | " (f'{EXPERIMENTS[1]}/logs_{DATASET}_test'\n",
52 | " f'/cache_{DATASET}_test_{ATLAS}_volumew_paired.pkl'))\n",
53 | "\n",
54 | "with open(path_base, 'rb') as f:\n",
55 | " dict_base = pickle.load(f)\n",
56 | "\n",
57 | "with open(path_eval, 'rb') as f:\n",
58 | " dict_eval = pickle.load(f)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "kl_ranges = {\n",
68 | " '0': (0, ),\n",
69 | " '1': (1, ),\n",
70 | " '2': (2, ),\n",
71 | " '3': (3, ),\n",
72 | " '4': (4, ),\n",
73 | " 'all': (0, 1, 2, 3, 4),\n",
74 | "}"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "dsc_base = np.asarray(dict_base['dice_score'])\n",
84 | "dsc_eval = np.asarray(dict_eval['dice_score'])\n",
85 | "\n",
86 | "for kl_name, kl_values in kl_ranges.items():\n",
87 | " sel_base = np.isin(np.asarray(dict_base['KL']), kl_values)\n",
88 | " sel_eval = np.isin(np.asarray(dict_eval['KL']), kl_values)\n",
89 | " assert np.all(np.equal(sel_base, sel_eval))\n",
90 | " sel = sel_base\n",
91 | " if not np.sum(sel):\n",
92 | " continue\n",
93 | " \n",
94 | " print(f'------ KL: {kl_name} ------')\n",
95 | "\n",
96 | " print('--- Wilcoxon signed-rank, two-sided ---')\n",
97 | " res_2side = [scipy.stats.wilcoxon(dsc_base[sel, c],\n",
98 | " dsc_eval[sel, c])\n",
99 | " for c in CLASSES]\n",
100 | " print(*res_2side, sep='\\n')\n",
101 | "\n",
102 | " print('--- Wilcoxon signed-rank, one-sided, less ---')\n",
103 | " res_1side = [scipy.stats.wilcoxon(dsc_base[sel, c],\n",
104 | " dsc_eval[sel, c],\n",
105 | " alternative='less')\n",
106 | " for c in CLASSES]\n",
107 | " print(*res_1side, sep='\\n')\n",
108 | "\n",
109 | " print('--- Wilcoxon signed-rank, one-sided, greater ---')\n",
110 | " res_1side = [scipy.stats.wilcoxon(dsc_base[sel, c],\n",
111 | " dsc_eval[sel, c],\n",
112 | " alternative='greater')\n",
113 | " for c in CLASSES]\n",
114 | " print(*res_1side, sep='\\n')"
115 | ]
116 | }
117 | ],
118 | "metadata": {
119 | "kernelspec": {
120 | "display_name": "Python 3",
121 | "language": "python",
122 | "name": "python3"
123 | },
124 | "language_info": {
125 | "codemirror_mode": {
126 | "name": "ipython",
127 | "version": 3
128 | },
129 | "file_extension": ".py",
130 | "mimetype": "text/x-python",
131 | "name": "python",
132 | "nbconvert_exporter": "python",
133 | "pygments_lexer": "ipython3",
134 | "version": "3.7.3"
135 | }
136 | },
137 | "nbformat": 4,
138 | "nbformat_minor": 2
139 | }
140 |
--------------------------------------------------------------------------------
/rocaseg/components/metrics.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | logging.basicConfig()
8 | logger = logging.getLogger('metrics')
9 | logger.setLevel(logging.DEBUG)
10 |
11 |
12 | def confusion_matrix(input_, target, num_classes):
13 | """
14 | https://github.com/ternaus/robot-surgery-segmentation/blob/master/validation.py
15 |
16 | Args:
17 | input_: (d0, ..., dn) ndarray or tensor
18 | target: (d0, ..., dn) ndarray or tensor
19 | num_classes: int
20 | Total number of classes.
21 |
22 | Returns:
23 | out: (num_classes, num_classes) ndarray
24 | Confusion matrix.
25 | """
26 | if torch.is_tensor(input_):
27 | input_ = input_.detach().to('cpu').numpy()
28 | if torch.is_tensor(target):
29 | target = target.detach().to('cpu').numpy()
30 |
31 | replace_indices = np.vstack((
32 | target.flatten(),
33 | input_.flatten())
34 | ).T
35 | cm, _ = np.histogramdd(
36 | replace_indices,
37 | bins=(num_classes, num_classes),
38 | range=[(0, num_classes-1), (0, num_classes-1)]
39 | )
40 | return cm.astype(np.uint32)
41 |
42 |
43 | def dice_score_from_cm(cm):
44 | """
45 | https://github.com/ternaus/robot-surgery-segmentation/blob/master/validation.py
46 |
47 | Args:
48 | cm: (d, d) ndarray
49 | Confusion matrix.
50 |
51 | Returns:
52 | out: (d, ) list
53 | List of class Dice scores.
54 | """
55 | scores = []
56 | for index in range(cm.shape[0]):
57 | true_positives = cm[index, index]
58 | false_positives = cm[:, index].sum() - true_positives
59 | false_negatives = cm[index, :].sum() - true_positives
60 | denom = 2 * true_positives + false_positives + false_negatives
61 | if denom == 0:
62 | score = 0
63 | else:
64 | score = 2 * float(true_positives) / denom
65 | scores.append(score)
66 | return scores
67 |
68 |
69 | # ----------------------------------------------------------------------------
70 |
71 |
72 | def _template_score(func_score_from_cm, input_, target, num_classes,
73 | batch_avg, batch_weight, class_avg, class_weight):
74 | """
75 |
76 | Args:
77 | input_: (b, d0, ..., dn) ndarray or tensor
78 | target: (b, d0, ..., dn) ndarray or tensor
79 | num_classes: int
80 | Total number of classes.
81 | batch_avg: bool
82 | Whether to average over the batch dimension.
83 | batch_weight: (b,) iterable
84 |             Batch samples importance coefficients (currently unused).
85 | class_avg: bool
86 | Whether to average over the class dimension.
87 | class_weight: (c,) iterable
88 | Classes importance coefficients. Ignored when `class_avg` is False.
89 |
90 | Returns:
91 | out: scalar if `class_avg` is True, (num_classes,) list otherwise
92 | """
93 | if torch.is_tensor(input_):
94 | num_samples = tuple(input_.size())[0]
95 | else:
96 | num_samples = input_.shape[0]
97 |
98 | scores = np.zeros((num_samples, num_classes))
99 | for sample_idx in range(num_samples):
100 | cm = confusion_matrix(input_=input_[sample_idx],
101 | target=target[sample_idx],
102 | num_classes=num_classes)
103 | scores[sample_idx, :] = func_score_from_cm(cm)
104 |
105 | if batch_avg:
106 | scores = np.mean(scores, axis=0, keepdims=True)
107 | if class_avg:
108 | if class_weight is not None:
109 | scores = scores * np.reshape(class_weight, (1, -1))
110 | scores = np.mean(scores, axis=1, keepdims=True)
111 | return np.squeeze(scores)
112 |
113 |
114 | def dice_score(input_, target, num_classes,
115 | batch_avg=True, batch_weight=None,
116 | class_avg=False, class_weight=None):
117 | """
118 |
119 | Args:
120 | input_: (b, d0, ..., dn) ndarray or tensor
121 | target: (b, d0, ..., dn) ndarray or tensor
122 | num_classes: int
123 | Total number of classes.
124 | batch_avg: bool
125 | Whether to average over the batch dimension.
126 | batch_weight: (b,) iterable
127 |             Batch samples importance coefficients (currently unused).
128 | class_avg: bool
129 | Whether to average over the class dimension.
130 | class_weight: (c,) iterable
131 | Classes importance coefficients. Ignored when `class_avg` is False.
132 |
133 | Returns:
134 | out: scalar if `class_avg` is True, (num_classes,) list otherwise
135 | """
136 | return _template_score(
137 | dice_score_from_cm, input_, target, num_classes,
138 | batch_avg, batch_weight, class_avg, class_weight)
139 |
--------------------------------------------------------------------------------
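
A sketch on a toy 1D labelling (illustrative; class 0 is treated as background):

```python
import numpy as np
from rocaseg.components.metrics import confusion_matrix, dice_score_from_cm

y_pred = np.array([0, 1, 1, 2, 2, 2])
y_true = np.array([0, 1, 2, 2, 2, 1])

cm = confusion_matrix(input_=y_pred, target=y_true, num_classes=3)
print(dice_score_from_cm(cm))  # [1.0, 0.5, 0.666...]
```
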
/rocaseg/datasets/dataset_maknee.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import logging
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 | import pandas as pd
8 |
9 | import cv2
10 | from torch.utils.data.dataset import Dataset
11 |
12 |
13 | logging.basicConfig()
14 | logger = logging.getLogger('dataset')
15 | logger.setLevel(logging.DEBUG)
16 |
17 |
18 | def index_from_path_maknee(path_root, force=False):
19 | fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv')
20 | fname_meta_base = os.path.join(path_root, 'meta_base.csv')
21 |
22 | if not os.path.exists(fname_meta_dyn) or force:
23 | fnames_image = glob.glob(
24 | os.path.join(path_root, '**', 'images', '*.png'), recursive=True)
25 | logger.info('{} images found'.format(len(fnames_image)))
26 | fnames_mask = glob.glob(
27 | os.path.join(path_root, '**', 'masks', '*.png'), recursive=True)
28 | logger.info('{} masks found'.format(len(fnames_mask)))
29 |
30 | df_meta = pd.read_csv(fname_meta_base,
31 | dtype={'patient': str,
32 | 'release': str,
33 | 'sequence': str,
34 | 'side': str,
35 | 'slice_idx': int,
36 | 'pixel_spacing_0': float,
37 | 'pixel_spacing_1': float,
38 | 'slice_thickness': float,
39 | 'KL': int},
40 | index_col=False)
41 |
42 | if len(fnames_image) != len(df_meta):
43 |             raise ValueError("Number of images doesn't match the metadata")
44 |
45 | df_meta['path_image'] = [os.path.join(path_root, e)
46 | for e in df_meta['path_rel_image']]
47 |
48 | # Sort the records
49 | df_meta_sorted = (df_meta
50 | .sort_values(['patient', 'sequence', 'slice_idx'])
51 | .reset_index()
52 | .drop('index', axis=1))
53 |
54 | df_meta_sorted.to_csv(fname_meta_dyn, index=False)
55 | else:
56 | df_meta_sorted = pd.read_csv(fname_meta_dyn,
57 | dtype={'patient': str,
58 | 'release': str,
59 | 'sequence': str,
60 | 'side': str,
61 | 'slice_idx': int,
62 | 'pixel_spacing_0': float,
63 | 'pixel_spacing_1': float,
64 | 'slice_thickness': float,
65 | 'KL': int},
66 | index_col=False)
67 |
68 | return df_meta_sorted
69 |
70 |
71 | def read_image(path_file):
72 | image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE)
73 | return image.reshape((1, *image.shape))
74 |
75 |
76 | class DatasetMAKNEESagittal2d(Dataset):
77 | def __init__(self, df_meta, mask_mode=None, name=None, transforms=None,
78 | sample_mode='x_y', **kwargs):
79 | logger.warning('Redundant dataset init arguments:\n{}'
80 | .format(repr(kwargs)))
81 |
82 | self.df_meta = df_meta
83 | self.mask_mode = mask_mode
84 | self.name = name
85 | self.transforms = transforms
86 | self.sample_mode = sample_mode
87 |
88 | def __len__(self):
89 | return len(self.df_meta)
90 |
91 | def _getitem_x_y(self, idx):
92 | image = read_image(self.df_meta['path_image'].iloc[idx])
93 | mask = np.zeros_like(image)
94 |
95 | # Apply transformations
96 | if self.transforms is not None:
97 | for t in self.transforms:
98 | if hasattr(t, 'randomize'):
99 | t.randomize()
100 | image, mask = t(image, mask)
101 |
102 | tmp = dict(self.df_meta.iloc[idx])
103 | tmp['image'] = image
104 | tmp['mask'] = mask
105 |
106 | tmp['xs'] = tmp['image']
107 | tmp['ys'] = tmp['mask']
108 | return tmp
109 |
110 | def __getitem__(self, idx):
111 | if self.sample_mode == 'x_y':
112 | return self._getitem_x_y(idx)
113 | else:
114 | raise ValueError('Invalid `sample_mode`')
115 |
116 | def describe(self):
117 | summary = defaultdict(float)
118 | for i in range(len(self)):
119 |             # `_getitem_x_y` also returns a sample dict, so the mask is
120 |             # always fetched by key; tuple unpacking would fail here
121 |             sample = self.__getitem__(i)
122 |             mask = sample['mask']
123 | summary['num_class_pixels'] += mask.numpy().sum(axis=(1, 2))
124 | summary['class_importance'] = \
125 | np.sum(summary['num_class_pixels']) / summary['num_class_pixels']
126 | summary['class_importance'] /= np.sum(summary['class_importance'])
127 | logger.info('Dataset statistics:')
128 | logger.info(sorted(summary.items()))
129 |
--------------------------------------------------------------------------------
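
A sketch of building the dataset and a loader; the data-root path is hypothetical (see `prepare_dataset_maknee.py` for producing the expected layout):

```python
from torch.utils.data import DataLoader
from rocaseg.datasets import DatasetMAKNEESagittal2d, index_from_path_maknee
from rocaseg.preproc import PercentileClippingAndToFloat, ToTensor

df_meta = index_from_path_maknee('../../data/MAKNEE_meta')  # hypothetical path
dataset = DatasetMAKNEESagittal2d(
    df_meta=df_meta, mask_mode=None, name='maknee',
    transforms=[PercentileClippingAndToFloat(cut_min=10, cut_max=99),
                ToTensor()])
loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=1)
```
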
/rocaseg/resample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import click
3 | from tqdm import tqdm
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import cv2
8 |
9 |
10 | @click.command()
11 | @click.option('--path_root_in', help='E.g. data/31_OKOA_full_meta')
12 | @click.option('--spacing_in', nargs=2, default=(0.5859375, 0.5859375))
13 | @click.option('--path_root_out', help='E.g. data/32_OKOA_full_meta_rescaled')
14 | @click.option('--spacing_out', nargs=2, default=(0.36458333, 0.36458333))
15 | @click.option('--dirname_images', default='images')
16 | @click.option('--dirname_masks', default='masks')
17 | @click.option('--num_threads', default=12, type=click.IntRange(-1, 12))
18 | @click.option('--margin', default=0, type=int)
19 | @click.option('--update_meta', is_flag=True)
20 | def main(**config):
21 | # Get the index of image files and the corresponding metadata
22 | path_meta = os.path.join(config['path_root_in'], 'meta_base.csv')
23 |     # Fall back to the dynamic metadata if the base file is absent
24 |     if not os.path.exists(path_meta):
25 |         path_meta = os.path.join(config['path_root_in'],
26 |                                  'meta_dynamic.csv')
27 |
28 | df_meta = pd.read_csv(path_meta,
29 | dtype={'patient': str,
30 | 'release': str,
31 | 'prefix_var': str,
32 | 'sequence': str,
33 | 'side': str,
34 | 'slice_idx': int,
35 | 'pixel_spacing_0': float,
36 | 'pixel_spacing_1': float,
37 | 'slice_thickness': float,
38 | 'KL': int,
39 | 'has_mask': int},
40 | index_col=False)
41 |
42 | df_in = df_meta.sort_values(['patient', 'release', 'sequence', 'side', 'slice_idx'])
43 |
44 | ratio = (np.asarray(config['spacing_in']) /
45 | np.asarray(config['spacing_out']))
46 |
47 | groupers_stack = ['patient', 'release', 'sequence', 'side', 'slice_idx']
48 |
49 | # Resample images
50 | if config['dirname_images'] is not None:
51 | for name_gb, df_gb in tqdm(df_in.groupby(groupers_stack), desc='Resample images'):
52 | patient, release, sequence, side, slice_idx = name_gb
53 |
54 | fn_base = f'{slice_idx:03d}.png'
55 | dir_in = os.path.join(config['path_root_in'],
56 | patient, release, sequence,
57 | config['dirname_images'])
58 | dir_out = os.path.join(config['path_root_out'],
59 | patient, release, sequence,
60 | config['dirname_images'])
61 | os.makedirs(dir_out, exist_ok=True)
62 |
63 | path_in = os.path.join(dir_in, fn_base)
64 | path_out = os.path.join(dir_out, fn_base)
65 |
66 | img_in = cv2.imread(path_in, cv2.IMREAD_GRAYSCALE)
67 |
68 | if config['margin'] == 0:
69 | tmp = img_in
70 | else:
71 | tmp = img_in[config['margin']:-config['margin'],
72 | config['margin']:-config['margin']]
73 |
74 |                 shape_out = tuple(np.floor(tmp.shape * ratio).astype(int))[::-1]
75 | tmp = cv2.resize(tmp, shape_out)
76 | img_out = tmp
77 |
78 | cv2.imwrite(path_out, img_out)
79 |
80 | # Resample masks
81 | if config['dirname_masks'] is not None:
82 | for name_gb, df_gb in tqdm(df_in.groupby(groupers_stack), desc='Resample masks'):
83 | patient, release, sequence, side, slice_idx = name_gb
84 |
85 | fn_base = f'{slice_idx:03d}.png'
86 | dir_in = os.path.join(config['path_root_in'],
87 | patient, release, sequence,
88 | config['dirname_masks'])
89 | dir_out = os.path.join(config['path_root_out'],
90 | patient, release, sequence,
91 | config['dirname_masks'])
92 | os.makedirs(dir_out, exist_ok=True)
93 |
94 | path_in = os.path.join(dir_in, fn_base)
95 | if not os.path.exists(path_in):
96 | print(f'No mask found for {name_gb}')
97 | continue
98 | path_out = os.path.join(dir_out, fn_base)
99 |
100 | mask_in = cv2.imread(path_in, cv2.IMREAD_GRAYSCALE)
101 |
102 | if config['margin'] == 0:
103 | tmp = mask_in
104 | else:
105 | tmp = mask_in[config['margin']:-config['margin'],
106 | config['margin']:-config['margin']]
107 |
108 |             shape_out = tuple(np.floor(tmp.shape * ratio).astype(int))[::-1]
109 | tmp = cv2.resize(tmp, shape_out, interpolation=cv2.INTER_NEAREST)
110 | mask_out = tmp
111 |
112 | cv2.imwrite(path_out, mask_out)
113 |
114 | if config['update_meta']:
115 | df_out = (df_in.assign(pixel_spacing_0=config['spacing_out'][0])
116 | .assign(pixel_spacing_1=config['spacing_out'][1]))
117 | df_out.to_csv(os.path.join(config['path_root_out'], 'meta_base.csv'), index=False)
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
--------------------------------------------------------------------------------
/rocaseg/analyze_predictions_multi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from collections import OrderedDict
4 |
5 | import click
6 | import logging
7 |
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 | import pandas as pd
12 |
13 | from rocaseg.datasets.constants import atlas_to_locations
14 |
15 | logging.basicConfig()
16 | logger = logging.getLogger('analyze')
17 | logger.setLevel(logging.INFO)
18 |
19 |
20 | def slicew_dsc_distr_vs_slice_idcs(results, config, num_classes, class_names):
21 |     """Distribution of planar (slice-wise) DSCs vs. slice indices.
22 |     """
23 | path_vis = config['path_results_root']
24 |
25 | metric_names = ['dice_score', ]
26 | # colors = ["darkorange", "mediumorchid", "deepskyblue"]
27 | # colors = ["salmon", "dodgerblue", "orangered"]
28 | labels = ['Baseline', '+ mixup - WD', '+ UDA2']
29 | colors = ["lightsalmon", "dodgerblue", 'grey']
30 | # labels = ['Reference', '+ mixup - WD']
31 |
32 | # Average and visualize the metrics
33 | for metric_name in metric_names:
34 | print(metric_name)
35 |
36 | for class_idx in range(1, num_classes):
37 |
38 | fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 3.6))
39 | plt.title(metric_name)
40 |
41 | for acc_idx, (_, acc_slicew) in enumerate(results.items()):
42 | unique_scans = list(set(zip(
43 | acc_slicew['patient'],
44 | acc_slicew['release']
45 | )))
46 | num_scans = len(unique_scans)
47 | num_slices = 160
48 |
49 | y = np.full((num_slices, num_scans), np.nan)
50 | for slice_idx in range(num_slices):
51 | sel_by_slice = np.asarray(acc_slicew['slice_idx_proc']) == slice_idx
52 | scores = np.asarray(acc_slicew['dice_score'])[sel_by_slice]
53 | scores = scores[:, class_idx]
54 |
55 | y[slice_idx, :] = scores
56 |
57 | x = np.ones_like(y) * np.arange(0, num_slices)[:, None]
58 | y = y.ravel()
59 | x = x.ravel()
60 |
61 | sel = (y != 0)
62 | y = y[sel]
63 | x = x[sel]
64 |
65 | sns.lineplot(x=x, y=y, err_style='band', ax=axes,
66 | color=colors[acc_idx], label=labels[acc_idx])
67 |
68 | fontsize = 14
69 | axes.set_ylim((0.45, 0.95))
70 | # axes.set_ylim((0.5, 1.0))
71 | axes.set_xlim((10, 150))
72 | axes.set_xlabel('slice index', size=fontsize)
73 | axes.set_ylabel('DSC', size=fontsize)
74 | tmp_title = class_names[class_idx]
75 | tmp_title = tmp_title.replace('femoral', 'Femoral cartilage')
76 | tmp_title = tmp_title.replace('tibial', 'Tibial cartilage')
77 | tmp_title = tmp_title.replace('patellar', 'Patellar cartilage')
78 | tmp_title = tmp_title.replace('menisci', 'Menisci')
79 |
80 | axes.set_title(tmp_title, size=fontsize)
81 |
82 | for label in (axes.get_xticklabels() + axes.get_yticklabels()):
83 | label.set_fontsize(fontsize)
84 |
85 | leg = axes.legend(
86 | # loc='lower left',
87 | loc='lower right',
88 | prop={'size': fontsize}, ncol=1,
89 | framealpha=1.0,
90 | )
91 | for line in leg.get_lines():
92 | line.set_linewidth(3.0)
93 |
94 | # axes[class_idx_axes].get_legend().set_visible(False)
95 |
96 | plt.grid(linestyle=':')
97 |
98 | fname_vis = os.path.join(path_vis,
99 | f"metrics_{config['dataset']}_"
100 | f"test_slicew_confid_"
101 | f"{metric_name}_{class_idx}.pdf")
102 | plt.savefig(fname_vis, bbox_inches='tight')
103 | logger.info(f"Saved to {fname_vis}")
104 | if config['interactive']:
105 | plt.show()
106 | else:
107 | plt.close()
108 |
109 |
110 | @click.command()
111 | @click.option('--path_results_root', default='../../results')
112 | @click.option('--experiment_id', multiple=True)
113 | @click.option('--dataset', required=True, type=click.Choice(
114 | ['oai_imo', 'okoa', 'maknee']))
115 | @click.option('--atlas', required=True, type=click.Choice(
116 | ['imo', 'segm', 'okoa']))
117 | @click.option('--interactive', is_flag=True)
118 | @click.option('--num_workers', default=1, type=int)
119 | def main(**config):
120 | results_slicew = OrderedDict()
121 | results_volumew = OrderedDict()
122 |
123 | for exp_id in config['experiment_id']:
124 | path_experiment_root = os.path.join(
125 | config['path_results_root'], exp_id)
126 | path_logs = os.path.join(
127 | path_experiment_root, f"logs_{config['dataset']}_test")
128 |
129 | # Get the information on object classes
130 | locations = atlas_to_locations[config['atlas']]
131 | class_names = [k for k in locations]
132 | num_classes = max(locations.values()) + 1
133 |
134 | # Load precomputed planar scores
135 | fname_slicew = os.path.join(
136 | path_logs,
137 | f"cache_{config['dataset']}_test_{config['atlas']}_slicew_paired.pkl"
138 | )
139 | if os.path.exists(fname_slicew):
140 | with open(fname_slicew, 'rb') as f:
141 | acc_slicew = pickle.load(f)
142 | else:
143 | raise IOError(f'File {fname_slicew} does not exist')
144 |
145 | fname_volumew = os.path.join(
146 | path_logs,
147 | f"cache_{config['dataset']}_test_{config['atlas']}_volumew_paired.pkl"
148 | )
149 | if os.path.exists(fname_volumew):
150 | with open(fname_volumew, 'rb') as f:
151 | acc_volumew = pickle.load(f)
152 | else:
153 | raise IOError(f'File {fname_volumew} does not exist')
154 |
155 | results_slicew.update({exp_id: acc_slicew})
156 | results_volumew.update({exp_id: acc_volumew})
157 |
158 | # ------------------------------ Visualization --------------------------------------
159 |
160 | slicew_dsc_distr_vs_slice_idcs(results=results_slicew,
161 | config=config,
162 | num_classes=num_classes,
163 | class_names=class_names)
164 |
165 |
166 | if __name__ == '__main__':
167 | main()
168 |
--------------------------------------------------------------------------------
/rocaseg/datasets/prepare_dataset_maknee.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import defaultdict
3 | from glob import glob
4 |
5 | import click
6 | from joblib import Parallel, delayed
7 | from tqdm import tqdm
8 |
9 | import numpy as np
10 | import pandas as pd
11 |
12 | import pydicom
13 | import cv2
14 |
15 |
16 | cv2.ocl.setUseOpenCL(False)
17 |
18 |
19 | def read_dicom(fname, only_data=False):
20 | data = pydicom.read_file(fname)
21 |     if len(data.PixelData) == 131072:  # 256*256 px at 2 bytes/px -> uint16
22 | dtype = np.uint16
23 | else:
24 | dtype = np.uint8
25 | image = np.frombuffer(data.PixelData, dtype=dtype).astype(float)
26 |
27 | if data.PhotometricInterpretation == 'MONOCHROME1':
28 | image = image.max() - image
29 |
30 | image = image.reshape((data.Rows, data.Columns))
31 |
32 | if only_data:
33 | return image
34 | else:
35 | if hasattr(data, 'ImagerPixelSpacing'):
36 | spacing = [float(e) for e in data.ImagerPixelSpacing[:2]]
37 | slice_thickness = float(data.SliceThickness)
38 | elif hasattr(data, 'PixelSpacing'):
39 | spacing = [float(e) for e in data.PixelSpacing[:2]]
40 | slice_thickness = float(data.SliceThickness)
41 | else:
42 | msg = f'DICOM {fname} does not contain spacing info'
43 | print(msg)
44 | spacing = (0.0, 0.0)
45 | slice_thickness = 0.0
46 |
47 | if data.Laterality == 'R':
48 | side = 'RIGHT'
49 | elif data.Laterality == 'L':
50 | side = 'LEFT'
51 | else:
52 |             msg = f'DICOM {fname} does not contain side info'
53 | raise AttributeError(msg)
54 | return image, spacing[0], spacing[1], slice_thickness, side
55 |
56 |
57 | @click.command()
58 | @click.argument('path_root_maknee')
59 | @click.argument('path_root_output')
60 | @click.option('--num_threads', default=12, type=click.IntRange(0, 16))
61 | @click.option('--margin', default=0, type=int)
62 | @click.option('--meta_only', is_flag=True)
63 | def main(**config):
64 | config['path_root_maknee'] = os.path.abspath(config['path_root_maknee'])
65 | config['path_root_output'] = os.path.abspath(config['path_root_output'])
66 |
67 | # -------------------------------------------------------------------------
68 | def worker(path_root_output, row, margin):
69 | meta = defaultdict(list)
70 |
71 | patient = row['patient']
72 | slice_idx = row['slice_idx']
73 | release = 'initial'
74 | sequence = 't2_de3d_we_sag_iso'
75 |
76 | image, *dicom_meta = read_dicom(row['fname_full_image'])
77 |
78 | side = dicom_meta[3]
79 |
80 | if margin != 0:
81 | image = image[margin:-margin, margin:-margin]
82 |
83 | fname_pattern = '{slice_idx:>03}.{ext}'
84 |
85 | # Save image and mask data
86 | dir_rel_image = os.path.join(patient, release, sequence, 'images')
87 | dir_rel_mask = os.path.join(patient, release, sequence, 'masks')
88 | dir_abs_image = os.path.join(path_root_output, dir_rel_image)
89 | dir_abs_mask = os.path.join(path_root_output, dir_rel_mask)
90 | for d in (dir_abs_image, dir_abs_mask):
91 | if not os.path.exists(d):
92 | os.makedirs(d)
93 |
94 | fname_image = fname_pattern.format(slice_idx=slice_idx, ext='png')
95 | path_abs_image = os.path.join(dir_abs_image, fname_image)
96 | if not config['meta_only']:
97 | cv2.imwrite(path_abs_image, image)
98 |
99 | path_rel_image = os.path.join(dir_rel_image, fname_image)
100 |
101 | meta['patient'].append(patient)
102 | meta['release'].append(release)
103 | meta['sequence'].append(sequence)
104 | meta['side'].append(side)
105 | meta['slice_idx'].append(slice_idx)
106 | meta['pixel_spacing_0'].append(dicom_meta[0])
107 | meta['pixel_spacing_1'].append(dicom_meta[1])
108 | meta['slice_thickness'].append(dicom_meta[2])
109 | meta['path_rel_image'].append(path_rel_image)
110 | return meta
111 |
112 | # -------------------------------------------------------------------------
113 |
114 | # Get list of images files
115 | fnames_dicom = glob(os.path.join(config['path_root_maknee'],
116 | 'MRI', 'Scans', '**',
117 | 't2_de3d_we_sag_iso*', 'IMG*'),
118 | recursive=True)
119 | fnames_dicom = list(sorted(fnames_dicom))
120 |
121 | def meta_from_fname(fn):
122 | # root / MRI / Scans / 001 / t2_de3d_we_sag_iso / IMG00000
123 | tmp = fn.split('/')
124 | meta = {
125 | 'fname_full_image': fn,
126 | 'slice_idx': os.path.splitext(tmp[-1])[0][-3:],
127 | 'patient': 'P{:>03}'.format(tmp[-3])}
128 | return meta
129 |
130 | dict_meta = {
131 | 'fname_full_image': [],
132 | 'slice_idx': [],
133 | 'patient': []}
134 |
135 | for e in fnames_dicom:
136 | tmp_meta = meta_from_fname(e)
137 | for k, v in tmp_meta.items():
138 | dict_meta[k].append(v)
139 |
140 | df_meta = pd.DataFrame.from_dict(dict_meta)
141 |
142 | metas = Parallel(config['num_threads'])(delayed(worker)(
143 | *[config['path_root_output'], row, config['margin']]
144 | ) for _, row in tqdm(df_meta.iterrows(), total=len(df_meta)))
145 |
146 | # Merge meta information from different stacks
147 | tmp = defaultdict(list)
148 | for d in metas:
149 | for k, v in d.items():
150 | tmp[k].extend(v)
151 | df_meta = pd.DataFrame.from_dict(tmp)
152 |
153 | # Add grading data to the meta-info df
154 | path_file_exp = os.path.join(config['path_root_maknee'], 'MAKnee_KL_subjects.xlsx')
155 | df_kl = pd.read_excel(path_file_exp)
156 |
157 | df_meta_uniq = df_meta.loc[:, ['patient', 'side']].drop_duplicates()
158 | df_kl.loc[:, 'ID'] = ['P{:>03}'.format(e) for e in df_kl['ID']]
159 | df_kl = df_kl.set_index(df_kl['ID'])
160 | tmp_kl = []
161 |
162 | for _, row in df_meta_uniq.iterrows():
163 | tmp_patient = row['patient']
164 | tmp_side = row['side']
165 |
166 | if tmp_side == 'RIGHT':
167 | tmp_kl.append(int(df_kl.loc[tmp_patient, 'KL right']))
168 | elif tmp_side == 'LEFT':
169 | tmp_kl.append(int(df_kl.loc[tmp_patient, 'KL left']))
170 | else:
171 | msg = f'Unexpected side value {tmp_side}'
172 | raise ValueError(msg)
173 |
174 | df_meta_uniq['KL'] = tmp_kl
175 | df_meta = pd.merge(df_meta, df_meta_uniq, on=['patient', 'side'], how='left')
176 |
177 | df_out = df_meta
178 |
179 | path_output_meta = os.path.join(config['path_root_output'], 'meta_base.csv')
180 | df_out.to_csv(path_output_meta, index=False)
181 |
182 |
183 | if __name__ == '__main__':
184 | main()
185 |
--------------------------------------------------------------------------------
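
A minimal usage sketch for the `read_dicom` helper above (the path and the
printed values are hypothetical):

    from rocaseg.datasets.prepare_dataset_maknee import read_dicom

    # One sagittal slice; `side` comes from the DICOM `Laterality` attribute
    image, sp_0, sp_1, thickness, side = read_dicom(
        '/data_raw/MAKNEE/MRI/Scans/001/t2_de3d_we_sag_iso/IMG00000')
    print(image.shape, (sp_0, sp_1), thickness, side)
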
/rocaseg/datasets/sources.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | from sklearn.model_selection import GroupShuffleSplit, GroupKFold
5 |
6 | from rocaseg.datasets import (index_from_path_oai_imo,
7 | index_from_path_okoa,
8 | index_from_path_maknee)
9 |
10 |
11 | logging.basicConfig()
12 | logger = logging.getLogger('datasets')
13 | logger.setLevel(logging.DEBUG)
14 |
15 |
16 | def sources_from_path(path_data_root,
17 | selection=None,
18 | with_folds=False,
19 | fold_num=5,
20 | seed_trainval_test=0):
21 | """
22 |
23 | Args:
24 | path_data_root: str
25 |
26 | selection: iterable or str or None
27 |
28 | with_folds: bool
29 | Whether to split trainval subset into the folds.
30 | fold_num: int
31 | Number of folds.
32 | seed_trainval_test: int
33 | Random state for the trainval/test splitting.
34 |
35 | Returns:
36 |
37 | """
38 | if selection is None:
39 | selection = ('oai_imo', 'okoa', 'maknee')
40 | elif isinstance(selection, str):
41 | selection = (selection, )
42 |
43 | sources = dict()
44 |
45 | for name in selection:
46 | if name == 'oai_imo':
47 | logger.info('--- OAI iMorphics dataset ---')
48 | tmp = dict()
49 | tmp['path_root'] = os.path.join(path_data_root,
50 | '91_OAI_iMorphics_full_meta')
51 |
52 | if not os.path.exists(tmp['path_root']):
53 | logger.warning(f"Dataset {name} is not found in {tmp['path_root']}")
54 | continue
55 |
56 | tmp['full_df'] = index_from_path_oai_imo(tmp['path_root'])
57 | logger.info(f"Total number of samples: "
58 | f"{len(tmp['full_df'])}")
59 |
60 | # Select the specific subset
61 | # Remove two series from the dataset as they are completely missing
62 | # information on patellar cartilage:
63 | # /0.C.2/9674570/20040913/10699609/
64 | # /1.C.2/9674570/20050829/10488714/
65 | tmp['sel_df'] = tmp['full_df'][tmp['full_df']['patient'] != '9674570']
66 | logger.info(f"Selected number of samples: "
67 | f"{len(tmp['sel_df'])}")
68 |
69 | if with_folds:
70 | # Get trainval/test split
71 | tmp_groups = tmp['sel_df'].loc[:, 'patient'].values
72 | tmp_grades = tmp['sel_df'].loc[:, 'KL'].values
73 |
74 | tmp_gss = GroupShuffleSplit(n_splits=1, test_size=0.2,
75 | random_state=seed_trainval_test)
76 | tmp_idcs_trainval, tmp_idcs_test = next(tmp_gss.split(X=tmp['sel_df'],
77 | y=tmp_grades,
78 | groups=tmp_groups))
79 | tmp['trainval_df'] = tmp['sel_df'].iloc[tmp_idcs_trainval]
80 | tmp['test_df'] = tmp['sel_df'].iloc[tmp_idcs_test]
81 | logger.info(f"Made trainval-test split, number of samples: "
82 | f"{len(tmp['trainval_df'])}, "
83 | f"{len(tmp['test_df'])}")
84 |
85 | # Make k folds
86 | tmp_gkf = GroupKFold(n_splits=fold_num)
87 | tmp_groups = tmp['trainval_df'].loc[:, 'patient'].values
88 | tmp_grades = tmp['trainval_df'].loc[:, 'KL'].values
89 |
90 | tmp['trainval_folds'] = tmp_gkf.split(X=tmp['trainval_df'],
91 | y=tmp_grades, groups=tmp_groups)
92 | sources['oai_imo'] = tmp
93 |
94 | elif name == 'okoa':
95 | logger.info('--- OKOA dataset ---')
96 | tmp = dict()
97 | tmp['path_root'] = os.path.join(path_data_root,
98 | '32_OKOA_full_meta_rescaled')
99 |
100 | if not os.path.exists(tmp['path_root']):
101 | logger.warning(f"Dataset {name} is not found in {tmp['path_root']}")
102 | continue
103 |
104 | tmp['full_df'] = index_from_path_okoa(tmp['path_root'])
105 | logger.info(f"Total number of samples: "
106 | f"{len(tmp['full_df'])}")
107 |
108 | # Select the specific subset
109 | tmp['sel_df'] = tmp['full_df']
110 | logger.info(f"Selected number of samples: "
111 | f"{len(tmp['sel_df'])}")
112 |
113 | if with_folds:
114 | # Get trainval/test split
115 | tmp['trainval_df'] = tmp['sel_df'][tmp['sel_df']['subset'] == 'training']
116 | tmp['test_df'] = tmp['sel_df'][tmp['sel_df']['subset'] == 'evaluation']
117 | logger.info(f"Made trainval-test split, number of samples: "
118 | f"{len(tmp['trainval_df'])}, "
119 | f"{len(tmp['test_df'])}")
120 |
121 | # Make k folds
122 | tmp_gkf = GroupKFold(n_splits=fold_num)
123 | tmp_groups = tmp['trainval_df'].loc[:, 'patient'].values
124 |
125 | tmp['trainval_folds'] = tmp_gkf.split(X=tmp['trainval_df'],
126 | groups=tmp_groups)
127 | sources['okoa'] = tmp
128 |
129 | elif name == 'maknee':
130 | logger.info('--- MAKNEE dataset ---')
131 | tmp = dict()
132 | tmp['path_root'] = os.path.join(path_data_root,
133 | '42_MAKNEE_full_meta_rescaled')
134 |
135 | if not os.path.exists(tmp['path_root']):
136 | logger.warning(f"Dataset {name} is not found in {tmp['path_root']}")
137 | continue
138 |
139 | tmp['full_df'] = index_from_path_maknee(tmp['path_root'])
140 | logger.info(f"Total number of samples: "
141 | f"{len(tmp['full_df'])}")
142 |
143 | # Select the specific subset
144 | tmp['sel_df'] = tmp['full_df']
145 | logger.info(f"Selected number of samples: "
146 | f"{len(tmp['sel_df'])}")
147 |
148 | # Get trainval/test split
149 | tmp_groups = tmp['sel_df'].loc[:, 'patient'].values
150 |
151 | tmp_gss = GroupShuffleSplit(n_splits=1, test_size=0.2,
152 | random_state=seed_trainval_test)
153 | tmp_idcs_trainval, tmp_idcs_test = next(tmp_gss.split(X=tmp['sel_df'],
154 | groups=tmp_groups))
155 | tmp['trainval_df'] = tmp['sel_df'].iloc[tmp_idcs_trainval]
156 | tmp['test_df'] = tmp['sel_df'].iloc[tmp_idcs_test]
157 | logger.info(f"Made trainval-test split, number of samples: "
158 | f"{len(tmp['trainval_df'])}, "
159 | f"{len(tmp['test_df'])}")
160 |
161 | if with_folds:
162 | # Make k folds
163 | tmp_gkf = GroupKFold(n_splits=fold_num)
164 | tmp_groups = tmp['trainval_df'].loc[:, 'patient'].values
165 |
166 | tmp['trainval_folds'] = tmp_gkf.split(X=tmp['trainval_df'],
167 | groups=tmp_groups)
168 | sources['maknee'] = tmp
169 |
170 | else:
171 | raise ValueError(f'Unknown dataset `{name}`')
172 |
173 | return sources
174 |
--------------------------------------------------------------------------------
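
A sketch of how `sources_from_path` is typically consumed (the data path is
hypothetical; `trainval_folds` is a generator of (train, val) index arrays):

    from rocaseg.datasets.sources import sources_from_path

    sources = sources_from_path('../../data', selection='oai_imo',
                                with_folds=True, fold_num=5)
    src = sources['oai_imo']
    df_test = src['test_df']
    for idcs_train, idcs_val in src['trainval_folds']:
        df_train = src['trainval_df'].iloc[idcs_train]
        df_val = src['trainval_df'].iloc[idcs_val]
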
/rocaseg/datasets/dataset_okoa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import logging
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from joblib import Parallel, delayed
9 |
10 | from tqdm import tqdm
11 | import cv2
12 | from torch.utils.data.dataset import Dataset
13 |
14 | from rocaseg.datasets.constants import locations_f43h
15 |
16 |
17 | logging.basicConfig()
18 | logger = logging.getLogger('dataset')
19 | logger.setLevel(logging.DEBUG)
20 |
21 |
22 | def index_from_path_okoa(path_root, force=False):
23 | fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv')
24 | fname_meta_base = os.path.join(path_root, 'meta_base.csv')
25 |
26 | if not os.path.exists(fname_meta_dyn) or force:
27 | fnames_image = glob.glob(
28 | os.path.join(path_root, '**', 'images', '*.png'), recursive=True)
29 | logger.info('{} images found'.format(len(fnames_image)))
30 | fnames_mask = glob.glob(
31 | os.path.join(path_root, '**', 'masks', '*.png'), recursive=True)
32 | logger.info('{} masks found'.format(len(fnames_mask)))
33 |
34 | df_meta = pd.read_csv(fname_meta_base,
35 | dtype={'subset': str,
36 | 'patient': str,
37 | 'release': str,
38 | 'sequence': str,
39 | 'side': str,
40 | 'slice_idx': int,
41 | 'pixel_spacing_0': float,
42 | 'pixel_spacing_1': float,
43 | 'slice_thickness': float,
44 | 'KL': int},
45 | index_col=False)
46 |
47 | if len(fnames_image) != len(df_meta):
48 | raise ValueError("Number of images doesn't match with the metadata")
49 | if len(fnames_mask) != len(df_meta):
50 | raise ValueError("Number of masks doesn't match with the metadata")
51 |
52 | df_meta['path_image'] = [os.path.join(path_root, e)
53 | for e in df_meta['path_rel_image']]
54 | df_meta['path_mask'] = [os.path.join(path_root, e)
55 | for e in df_meta['path_rel_mask']]
56 |
57 | # Check for in-slice mask presence
58 | def worker_74n(fname):
59 | mask = read_mask(path_file=fname, mask_mode='raw')
60 |             if np.any(mask > 0):  # raw mode: 2D label map, background == 0
61 | return 1
62 | else:
63 | return 0
64 |
65 | logger.info('Exploring the annotations')
66 | tmp = Parallel(n_jobs=-1)(
67 | delayed(worker_74n)(row['path_mask'])
68 | for _, row in tqdm(df_meta.iterrows(), total=len(df_meta)))
69 | df_meta['has_mask'] = tmp
70 |
71 | # Sort the records
72 | df_meta_sorted = (df_meta
73 | .sort_values(['patient', 'sequence', 'slice_idx'])
74 | .reset_index()
75 | .drop('index', axis=1))
76 |
77 | df_meta_sorted.to_csv(fname_meta_dyn, index=False)
78 | else:
79 | df_meta_sorted = pd.read_csv(fname_meta_dyn,
80 | dtype={'subset': str,
81 | 'patient': str,
82 | 'release': str,
83 | 'sequence': str,
84 | 'side': str,
85 | 'slice_idx': int,
86 | 'pixel_spacing_0': float,
87 | 'pixel_spacing_1': float,
88 | 'slice_thickness': float,
89 | 'KL': int,
90 | 'has_mask': int},
91 | index_col=False)
92 | return df_meta_sorted
93 |
94 |
95 | def read_image(path_file):
96 | image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE)
97 | return image.reshape((1, *image.shape))
98 |
99 |
100 | def read_mask(path_file, mask_mode):
101 | """Read mask from the file, and pre-process it.
102 |
103 | IMPORTANT: currently, we handle the inter-class collisions by assigning
104 | the joint pixels to a class with a lower index.
105 |
106 | Parameters
107 | ----------
108 | path_file: str
109 | Full path to mask file.
110 | mask_mode: str
111 | Specifies which channels of mask to use.
112 |
113 | Returns
114 | -------
115 | out : (ch, d0, d1) uint8 ndarray
116 |
117 | """
118 | mask = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE)
119 |
120 | locations = {
121 | '_background': (locations_f43h['_background'],),
122 | 'femoral': (locations_f43h['femoral'],),
123 | 'tibial': (locations_f43h['tibial'], ),
124 | }
125 |
126 | if mask_mode == 'raw':
127 | return mask
128 | elif mask_mode == 'background_femoral_unitibial':
129 | ret = np.empty((3, *mask.shape), dtype=mask.dtype)
130 | ret[0, :, :] = np.isin(mask, locations['_background']).astype(np.uint8)
131 | ret[1, :, :] = np.isin(mask, locations['femoral']).astype(np.uint8)
132 | ret[2, :, :] = np.isin(mask, locations['tibial']).astype(np.uint8)
133 | return ret
134 | else:
135 | raise ValueError('Invalid `mask_mode`')
136 |
137 |
138 | class DatasetOKOASagittal2d(Dataset):
139 | def __init__(self, df_meta, mask_mode=None, name=None, transforms=None,
140 | sample_mode='x_y', **kwargs):
141 |         if kwargs:
142 |             logger.warning(f'Redundant dataset init arguments: {kwargs!r}')
143 |
144 | self.df_meta = df_meta
145 | self.mask_mode = mask_mode
146 | self.name = name
147 | self.transforms = transforms
148 | self.sample_mode = sample_mode
149 |
150 | def __len__(self):
151 | return len(self.df_meta)
152 |
153 | def _getitem_x_y(self, idx):
154 | image = read_image(self.df_meta['path_image'].iloc[idx])
155 | mask = read_mask(self.df_meta['path_mask'].iloc[idx], self.mask_mode)
156 |
157 | # Apply transformations
158 | if self.transforms is not None:
159 | for t in self.transforms:
160 | if hasattr(t, 'randomize'):
161 | t.randomize()
162 | image, mask = t(image, mask)
163 |
164 | tmp = dict(self.df_meta.iloc[idx])
165 | tmp['image'] = image
166 | tmp['mask'] = mask
167 |
168 | tmp['xs'] = tmp['image']
169 | tmp['ys'] = tmp['mask']
170 | return tmp
171 |
172 | def __getitem__(self, idx):
173 | if self.sample_mode == 'x_y':
174 | return self._getitem_x_y(idx)
175 | else:
176 | raise ValueError('Invalid `sample_mode`')
177 |
178 | def read_image(self, path_file):
179 | return read_image(path_file=path_file)
180 |
181 | def read_mask(self, path_file):
182 | return read_mask(path_file=path_file, mask_mode=self.mask_mode)
183 |
184 | def describe(self):
185 | summary = defaultdict(float)
186 | for i in range(len(self)):
187 |             # `__getitem__` returns a dict in every supported sample mode
188 |             mask = self.__getitem__(i)['mask']
189 |             if hasattr(mask, 'numpy'):  # torch.Tensor after tensor transforms
190 |                 mask = mask.numpy()
191 |             summary['num_class_pixels'] += mask.sum(axis=(1, 2))
192 | summary['class_importance'] = \
193 | np.sum(summary['num_class_pixels']) / summary['num_class_pixels']
194 | summary['class_importance'] /= np.sum(summary['class_importance'])
195 | logger.info('Dataset statistics:')
196 | logger.info(sorted(summary.items()))
197 |
--------------------------------------------------------------------------------
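
A usage sketch for the dataset class above, assuming the data has already been
prepared and rescaled (the path is hypothetical):

    from rocaseg.datasets.dataset_okoa import (index_from_path_okoa,
                                               DatasetOKOASagittal2d)

    df = index_from_path_okoa('/data/32_OKOA_full_meta_rescaled')
    ds = DatasetOKOASagittal2d(df_meta=df, name='okoa',
                               mask_mode='background_femoral_unitibial')
    sample = ds[0]
    # `xs` is a (1, H, W) grayscale image; `ys` is a (3, H, W) one-hot mask
    # with the channels: background, femoral cartilage, tibial cartilage
    image, mask = sample['xs'], sample['ys']
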
/rocaseg/datasets/prepare_dataset_okoa.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import defaultdict
3 | import click
4 |
5 | from joblib import Parallel, delayed
6 | from glob import glob
7 | from tqdm import tqdm
8 |
9 | import numpy as np
10 | import pandas as pd
11 |
12 | import pydicom
13 | import cv2
14 |
15 |
16 | cv2.ocl.setUseOpenCL(False)
17 |
18 |
19 | def read_dicom(fname, only_data=False):
20 |     data = pydicom.read_file(fname)
21 |     if len(data.PixelData) == 131072:  # 131072 == 256 * 256 * 2 -> 16-bit pixels
22 |         dtype = np.uint16
23 |     else:
24 |         dtype = np.uint8
25 |     image = np.frombuffer(data.PixelData, dtype=dtype).astype(float)
26 |
27 | if data.PhotometricInterpretation == 'MONOCHROME1':
28 | image = image.max() - image
29 |
30 | image = image.reshape((data.Rows, data.Columns))
31 |
32 |     if only_data:
33 |         return image
34 | else:
35 | if hasattr(data, 'ImagerPixelSpacing'):
36 | spacing = [float(e) for e in data.ImagerPixelSpacing]
37 | slice_thickness = float(data.SliceThickness)
38 | elif hasattr(data, 'PixelSpacing'):
39 | spacing = [float(e) for e in data.PixelSpacing]
40 | slice_thickness = float(data.SliceThickness)
41 | else:
42 | msg = f'DICOM {fname} does not have the required attributes'
43 | print(msg)
44 | spacing = (0.0, 0.0)
45 | slice_thickness = 0.0
46 | return image, spacing[0], spacing[1], slice_thickness
47 |
48 |
49 | @click.command()
50 | @click.argument('path_root_okoa')
51 | @click.argument('path_root_output')
52 | @click.option('--num_threads', default=12, type=click.IntRange(0, 16))
53 | @click.option('--margin', default=0, type=int)
54 | @click.option('--meta_only', is_flag=True)
55 | def main(**config):
56 | config['path_root_okoa'] = os.path.abspath(config['path_root_okoa'])
57 | config['path_root_output'] = os.path.abspath(config['path_root_output'])
58 |
59 | # -------------------------------------------------------------------------
60 | def worker_s5g(path_root_output, row, margin):
61 | meta = defaultdict(list)
62 |
63 | patient = row['patient'].values[0]
64 | slice_idx = row['slice_idx'].values[0]
65 | side = row['side'].values[0]
66 | release = 'initial'
67 | sequence = 't2_de3d_we_sag_iso'
68 |
69 | image, *dicom_meta = read_dicom(row[('fname_full', 'Images')])
70 |
71 | # Set the default values for the voxel spacings
72 | pixel_spacing_0 = dicom_meta[0] or '0.5859375'
73 | pixel_spacing_1 = dicom_meta[1] or '0.5859375'
74 | slice_thickness = dicom_meta[2] or '0.60000002384186'
75 |
76 | mask_femur = read_dicom(row[('fname_full', 'Femur')], only_data=True)
77 | mask_tibia = read_dicom(row[('fname_full', 'Tibia')], only_data=True)
78 | mask_full = np.zeros(mask_femur.shape, dtype=np.uint8)
79 |         # Write in reverse order so that femoral tissue wins in collisions
80 | mask_full[mask_tibia > 0] = 2
81 | mask_full[mask_femur > 0] = 1
82 |
83 | if margin:
84 | image = image[margin:-margin, margin:-margin]
85 | mask_full = mask_full[margin:-margin, margin:-margin]
86 |
87 | fname_pattern = '{slice_idx:>03}.{ext}'
88 |
89 | # Save image and mask data
90 | dir_rel_image = os.path.join(patient, release, sequence, 'images')
91 | dir_rel_mask = os.path.join(patient, release, sequence, 'masks')
92 | dir_abs_image = os.path.join(path_root_output, dir_rel_image)
93 | dir_abs_mask = os.path.join(path_root_output, dir_rel_mask)
94 | for d in (dir_abs_image, dir_abs_mask):
95 | if not os.path.exists(d):
96 | os.makedirs(d)
97 |
98 | fname_image = fname_pattern.format(slice_idx=slice_idx, ext='png')
99 | path_abs_image = os.path.join(dir_abs_image, fname_image)
100 | if not config['meta_only']:
101 | cv2.imwrite(path_abs_image, image)
102 |
103 | fname_mask = fname_pattern.format(slice_idx=slice_idx, ext='png')
104 | path_abs_mask = os.path.join(dir_abs_mask, fname_mask)
105 | if not config['meta_only']:
106 | cv2.imwrite(path_abs_mask, mask_full)
107 |
108 | path_rel_image = os.path.join(dir_rel_image, fname_image)
109 | path_rel_mask = os.path.join(dir_rel_mask, fname_mask)
110 |
111 | meta['subset'].append(row['subset'].values[0])
112 | meta['patient'].append(patient)
113 | meta['release'].append(release)
114 | meta['sequence'].append(sequence)
115 | meta['side'].append(side)
116 | meta['KL'].append(row['KL'].values[0])
117 | meta['slice_idx'].append(slice_idx)
118 | meta['pixel_spacing_0'].append(pixel_spacing_0)
119 | meta['pixel_spacing_1'].append(pixel_spacing_1)
120 | meta['slice_thickness'].append(slice_thickness)
121 | meta['path_rel_image'].append(path_rel_image)
122 | meta['path_rel_mask'].append(path_rel_mask)
123 | return meta
124 |
125 | # -------------------------------------------------------------------------
126 |
127 |     # Get the list of image files
128 | paths_fnames_dicom = glob(os.path.join(config['path_root_okoa'], '**', '*.IMA'),
129 | recursive=True)
130 |
131 | # root / training|evaluation / P36 / Images|Femur|Tibia / (1-160).IMA
132 | def meta_from_fname(fn):
133 | tmp = fn.split('/')
134 | slice_idx = int(os.path.splitext(tmp[-1])[0]) - 1
135 | slice_idx = '{:>03}'.format(slice_idx)
136 | meta = {
137 | 'fname_full': fn,
138 | 'slice_idx': slice_idx,
139 | 'kind': tmp[-2],
140 | 'patient': tmp[-3],
141 | 'subset': tmp[-4]
142 | }
143 | return meta
144 |
145 | dict_meta = {
146 | 'fname_full': [],
147 | 'slice_idx': [],
148 | 'kind': [],
149 | 'patient': [],
150 | 'subset': []
151 | }
152 |
153 | for e in paths_fnames_dicom:
154 | tmp_meta = meta_from_fname(e)
155 | for k, v in tmp_meta.items():
156 | dict_meta[k].append(v)
157 |
158 | df_meta = pd.DataFrame.from_dict(dict_meta)
159 | df_meta = (df_meta
160 | .set_index(['subset', 'patient', 'slice_idx', 'kind'])
161 | .unstack('kind')
162 | .reset_index())
163 |
164 | # Add info on scan side and KL grades
165 | path_file_side = os.path.join(config['path_root_okoa'], 'sides.csv')
166 | df_side = pd.read_csv(path_file_side)
167 | path_file_kl = os.path.join(config['path_root_okoa'], 'KL_grades.csv')
168 | df_kl = pd.read_csv(path_file_kl)
169 |
170 | df_extra = pd.merge(df_side, df_kl, on='patient', how='left', sort=True)
171 | df_extra.loc[:, 'KL'] = -1
172 | for r_idx, r in df_extra.iterrows():
173 | if r['side'] == 'LEFT':
174 | df_extra.loc[r_idx, 'KL'] = r['KL left']
175 | elif r['side'] == 'RIGHT':
176 | df_extra.loc[r_idx, 'KL'] = r['KL right']
177 |
178 | # Keep only the fields of interest
179 | df_extra = df_extra.loc[:, ['patient', 'side', 'KL', 'age']]
180 |
181 |     # Use the same multi-index so that pd.merge doesn't create an extra column
182 | df_extra.columns = pd.MultiIndex.from_tuples([(c, '') for c in df_extra.columns])
183 |
184 |     # Merge the metadata into a single df
185 | df_meta = pd.merge(df_meta, df_extra, on='patient', how='left', sort=True)
186 |
187 | # Process the raw data
188 | metas = Parallel(config['num_threads'])(delayed(worker_s5g)(
189 | *[config['path_root_output'], row, config['margin']]
190 | ) for _, row in tqdm(df_meta.iterrows(), total=len(df_meta)))
191 |
192 | # Merge meta information from different stacks
193 | tmp = defaultdict(list)
194 | for d in metas:
195 | for k, v in d.items():
196 | tmp[k].extend(v)
197 | df_out = pd.DataFrame.from_dict(tmp)
198 |
199 | path_output_meta = os.path.join(config['path_root_output'], 'meta_base.csv')
200 | df_out.to_csv(path_output_meta, index=False)
201 |
202 |
203 | if __name__ == '__main__':
204 | main()
205 |
--------------------------------------------------------------------------------
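
The write order in `worker_s5g` above is what resolves mask collisions: the
tibial mask is written first, so femoral labels overwrite any overlapping
pixels. A self-contained toy example of the same trick:

    import numpy as np

    mask_femur = np.array([[1, 1, 0]], dtype=np.uint8)
    mask_tibia = np.array([[0, 1, 1]], dtype=np.uint8)
    mask_full = np.zeros((1, 3), dtype=np.uint8)
    mask_full[mask_tibia > 0] = 2  # tibia first ...
    mask_full[mask_femur > 0] = 1  # ... then femur overwrites collisions
    print(mask_full)               # [[1 1 2]] -- the shared pixel is femoral
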
/rocaseg/datasets/dataset_oai_imo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import logging
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from joblib import Parallel, delayed
9 | from tqdm import tqdm
10 |
11 | import cv2
12 | from torch.utils.data.dataset import Dataset
13 |
14 | from rocaseg.datasets.constants import locations_mh53
15 |
16 |
17 | logging.basicConfig()
18 | logger = logging.getLogger('dataset')
19 | logger.setLevel(logging.DEBUG)
20 |
21 |
22 | def index_from_path_oai_imo(path_root, force=False):
23 | fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv')
24 | fname_meta_base = os.path.join(path_root, 'meta_base.csv')
25 |
26 | if not os.path.exists(fname_meta_dyn) or force:
27 | fnames_image = glob.glob(
28 | os.path.join(path_root, '**', 'images', '*.png'), recursive=True)
29 | logger.info('{} images found'.format(len(fnames_image)))
30 | fnames_mask = glob.glob(
31 | os.path.join(path_root, '**', 'masks', '*.png'), recursive=True)
32 | logger.info('{} masks found'.format(len(fnames_mask)))
33 |
34 | df_meta = pd.read_csv(fname_meta_base,
35 | dtype={'patient': str,
36 | 'release': str,
37 | 'prefix_var': str,
38 | 'sequence': str,
39 | 'side': str,
40 | 'slice_idx': int,
41 | 'pixel_spacing_0': float,
42 | 'pixel_spacing_1': float,
43 | 'slice_thickness': float,
44 | 'KL': int},
45 | index_col=False)
46 |
47 | if len(fnames_image) != len(df_meta):
48 | raise ValueError("Number of images doesn't match with the metadata")
49 | if len(fnames_mask) != len(df_meta):
50 | raise ValueError("Number of masks doesn't match with the metadata")
51 |
52 | df_meta['path_image'] = [os.path.join(path_root, e)
53 | for e in df_meta['path_rel_image']]
54 | df_meta['path_mask'] = [os.path.join(path_root, e)
55 | for e in df_meta['path_rel_mask']]
56 |
57 | # Check for in-slice mask presence
58 | def worker_nv7(fname):
59 | mask = read_mask(path_file=fname, mask_mode='raw')
60 |             if np.any(mask > 0):  # raw mode: 2D label map, background == 0
61 | return 1
62 | else:
63 | return 0
64 |
65 | logger.info('Exploring the annotations')
66 | tmp = Parallel(n_jobs=-1)(
67 | delayed(worker_nv7)(row['path_mask'])
68 | for _, row in tqdm(df_meta.iterrows(), total=len(df_meta)))
69 | df_meta['has_mask'] = tmp
70 |
71 | # Sort the records
72 | df_meta_sorted = (df_meta
73 | .sort_values(['patient', 'prefix_var',
74 | 'sequence', 'slice_idx'])
75 | .reset_index()
76 | .drop('index', axis=1))
77 |
78 | df_meta_sorted.to_csv(fname_meta_dyn, index=False)
79 | else:
80 | df_meta_sorted = pd.read_csv(fname_meta_dyn,
81 | dtype={'patient': str,
82 | 'release': str,
83 | 'prefix_var': str,
84 | 'sequence': str,
85 | 'side': str,
86 | 'slice_idx': int,
87 | 'pixel_spacing_0': float,
88 | 'pixel_spacing_1': float,
89 | 'slice_thickness': float,
90 | 'KL': int,
91 | 'has_mask': int},
92 | index_col=False)
93 | return df_meta_sorted
94 |
95 |
96 | def read_image(path_file):
97 | image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE)
98 | return image.reshape((1, *image.shape))
99 |
100 |
101 | def read_mask(path_file, mask_mode):
102 | """Read mask from the file, and pre-process it.
103 |
104 | IMPORTANT: currently, we handle the inter-class collisions by assigning
105 | the joint pixels to a class with a lower index.
106 |
107 | Parameters
108 | ----------
109 | path_file: str
110 | Full path to mask file.
111 | mask_mode: str
112 | Specifies which channels of mask to use.
113 |
114 | Returns
115 | -------
116 | out : (ch, d0, d1) uint8 ndarray
117 |
118 | """
119 | mask = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE)
120 |
121 | locations = {
122 | '_background': (locations_mh53['Background'], ),
123 | 'femoral': (locations_mh53['FemoralCartilage'], ),
124 | 'tibial': (locations_mh53['LateralTibialCartilage'],
125 | locations_mh53['MedialTibialCartilage']),
126 | 'patellar': (locations_mh53['PatellarCartilage'], ),
127 | 'menisci': (locations_mh53['LateralMeniscus'],
128 | locations_mh53['MedialMeniscus']),
129 | }
130 |
131 | if mask_mode == 'raw':
132 | return mask
133 | elif mask_mode == 'all_unitibial_unimeniscus':
134 | ret = np.empty((5, *mask.shape), dtype=mask.dtype)
135 | ret[0, :, :] = np.isin(mask, locations['_background']).astype(np.uint8)
136 | ret[1, :, :] = np.isin(mask, locations['femoral']).astype(np.uint8)
137 | ret[2, :, :] = np.isin(mask, locations['tibial']).astype(np.uint8)
138 | ret[3, :, :] = np.isin(mask, locations['patellar']).astype(np.uint8)
139 | ret[4, :, :] = np.isin(mask, locations['menisci']).astype(np.uint8)
140 | return ret
141 | else:
142 | raise ValueError('Invalid `mask_mode`')
143 |
144 |
145 | class DatasetOAIiMoSagittal2d(Dataset):
146 | def __init__(self, df_meta, mask_mode=None, name=None, transforms=None,
147 | sample_mode='x_y', **kwargs):
148 |         if kwargs:
149 |             logger.warning(f'Redundant dataset init arguments: {kwargs!r}')
150 |
151 | self.df_meta = df_meta
152 | self.mask_mode = mask_mode
153 | self.name = name
154 | self.transforms = transforms
155 | self.sample_mode = sample_mode
156 |
157 | def __len__(self):
158 | return len(self.df_meta)
159 |
160 | def _getitem_x_y(self, idx):
161 | image = read_image(self.df_meta['path_image'].iloc[idx])
162 | mask = read_mask(self.df_meta['path_mask'].iloc[idx], self.mask_mode)
163 |
164 | # Apply transformations
165 | if self.transforms is not None:
166 | for t in self.transforms:
167 | if hasattr(t, 'randomize'):
168 | t.randomize()
169 | image, mask = t(image, mask)
170 |
171 | tmp = dict(self.df_meta.iloc[idx])
172 | tmp['image'] = image
173 | tmp['mask'] = mask
174 |
175 | tmp['xs'] = tmp['image']
176 | tmp['ys'] = tmp['mask']
177 | return tmp
178 |
179 | def __getitem__(self, idx):
180 | if self.sample_mode == 'x_y':
181 | return self._getitem_x_y(idx)
182 | else:
183 | raise ValueError('Invalid `sample_mode`')
184 |
185 | def read_image(self, path_file):
186 | return read_image(path_file=path_file)
187 |
188 | def read_mask(self, path_file):
189 | return read_mask(path_file=path_file, mask_mode=self.mask_mode)
190 |
191 | def describe(self):
192 | summary = defaultdict(float)
193 | for i in range(len(self)):
194 |             # `__getitem__` returns a dict in every supported sample mode
195 |             mask = self.__getitem__(i)['mask']
196 |             if hasattr(mask, 'numpy'):  # torch.Tensor after tensor transforms
197 |                 mask = mask.numpy()
198 |             summary['num_class_pixels'] += mask.sum(axis=(1, 2))
199 | summary['class_importance'] = \
200 | np.sum(summary['num_class_pixels']) / summary['num_class_pixels']
201 | summary['class_importance'] /= np.sum(summary['class_importance'])
202 | logger.info('Dataset statistics:')
203 | logger.info(sorted(summary.items()))
204 |
--------------------------------------------------------------------------------
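
A sketch of the one-hot decoding performed by `read_mask` above (the mask path
is hypothetical):

    from rocaseg.datasets.dataset_oai_imo import read_mask

    mask = read_mask('masks/000.png', mask_mode='all_unitibial_unimeniscus')
    # mask.shape == (5, H, W); the channels are: background, femoral
    # cartilage, tibial cartilage (lateral + medial merged), patellar
    # cartilage, and menisci (lateral + medial merged)
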
/rocaseg/models/unet_lext.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamically created UNet with variable width, depth, and activation.
3 | 
4 | Aleksei Tiulpin, University of Oulu, 2017 (c).
5 |
6 | """
7 | import logging
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from collections import OrderedDict
13 |
14 |
15 | logging.basicConfig()
16 | logger = logging.getLogger('models')
17 | logger.setLevel(logging.DEBUG)
18 |
19 |
20 | def depthwise_separable_conv(input_channels, output_channels):
21 | """
22 | """
23 | depthwise = nn.Conv2d(input_channels, input_channels,
24 | kernel_size=3, padding=1, groups=input_channels)
25 | pointwise = nn.Conv2d(input_channels, output_channels,
26 | kernel_size=1)
27 | return nn.Sequential(depthwise, pointwise)
28 |
29 |
30 | def block_conv_bn_act(input_channels, output_channels,
31 | convolution, activation):
32 | """
33 | """
34 | if convolution == 'regular':
35 | layer_conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
36 | elif convolution == 'depthwise_separable':
37 | layer_conv = depthwise_separable_conv(input_channels, output_channels)
38 | else:
39 | raise ValueError(f'Wrong `convolution`: {convolution}')
40 |
41 | if activation == 'relu':
42 | layer_act = nn.ReLU(inplace=True)
43 | elif activation == 'selu':
44 | layer_act = nn.SELU(inplace=True)
45 | elif activation == 'elu':
46 | layer_act = nn.ELU(1, inplace=True)
47 | else:
48 | raise ValueError(f'Wrong `activation`: {activation}')
49 |
50 | block = list()
51 | block.append(layer_conv)
52 | if activation == 'relu':
53 |         block.append(nn.BatchNorm2d(output_channels))  # BN is paired with ReLU only
54 | block.append(layer_act)
55 |
56 | return nn.Sequential(*block)
57 |
58 |
59 | class Encoder(nn.Module):
60 | """Encoder class. for encoder-decoder architecture.
61 |
62 | """
63 | def __init__(self, input_channels, output_channels,
64 | depth=2, convolution='regular', activation='relu'):
65 | super().__init__()
66 | self.layers = nn.Sequential()
67 | for i in range(depth):
68 | tmp = []
69 | if i == 0:
70 | tmp.append(block_conv_bn_act(input_channels, output_channels,
71 | convolution=convolution,
72 | activation=activation))
73 | else:
74 | tmp.append(block_conv_bn_act(output_channels, output_channels,
75 | convolution=convolution,
76 | activation=activation))
77 |
78 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp))
79 |
80 | def forward(self, x):
81 | processed = self.layers(x)
82 | pooled = F.max_pool2d(processed, 2, 2)
83 | return processed, pooled
84 |
85 |
86 | class Decoder(nn.Module):
87 | """Decoder class. for encoder-decoder architecture.
88 |
89 | """
90 | def __init__(self, input_channels, output_channels, depth=2, mode='bilinear',
91 | convolution='regular', activation='relu'):
92 | super().__init__()
93 | self.layers = nn.Sequential()
94 | self.ups_mode = mode
95 | for i in range(depth):
96 | tmp = []
97 | if i == 0:
98 | tmp.append(block_conv_bn_act(input_channels, output_channels,
99 | convolution=convolution,
100 | activation=activation))
101 | else:
102 | tmp.append(block_conv_bn_act(output_channels, output_channels,
103 | convolution=convolution,
104 | activation=activation))
105 |
106 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp))
107 |
108 | def forward(self, x_big, x):
109 | x_ups = F.interpolate(x, size=x_big.size()[-2:], mode=self.ups_mode,
110 | align_corners=True)
111 | y_cat = torch.cat([x_ups, x_big], 1)
112 | y = self.layers(y_cat)
113 | return y
114 |
115 |
116 | class UNetLext(nn.Module):
117 | """UNet architecture with 3x3 convolutions. Created dynamically based on depth and width.
118 |
119 | """
120 | def __init__(self, basic_width=24, depth=6, center_depth=2,
121 | input_channels=3, output_channels=1,
122 | convolution='regular', activation='relu',
123 | pretrained=False, path_pretrained=None,
124 | restore_weights=False, path_weights=None, **kwargs):
125 | """
126 |
127 | Parameters
128 | ----------
129 | basic_width:
130 | Basic width of the network, which is doubled at each layer.
131 | depth:
132 | Number of layers.
133 | center_depth:
134 | Depth of the central block in UNet.
135 | input_channels:
136 | Number of input channels.
137 | output_channels:
138 | Number of output channels (/classes).
139 | convolution: {'regular', 'depthwise_separable'}
140 |         activation: {'relu', 'selu', 'elu'}
141 |             Activation function.
142 |         restore_weights: bool
143 |             Whether to load the model weights from `path_weights`.
144 |         path_weights: str
145 |             Path to the file with the model state dict.
146 | kwargs:
147 | Catches redundant arguments and issues a warning.
148 | """
149 | assert depth >= 2
150 | super().__init__()
151 |         if kwargs:
152 |             logger.warning(f'Redundant model init arguments: {kwargs!r}')
153 |
154 | # Preparing the modules dict
155 | modules = OrderedDict()
156 |
157 | modules['down1'] = Encoder(input_channels, basic_width, activation=activation)
158 |
159 | # Automatically creating the Encoder based on the depth and width
160 | for level in range(2, depth + 1):
161 | mul_in = 2 ** (level - 2)
162 | mul_out = 2 ** (level - 1)
163 | layer = Encoder(basic_width * mul_in, basic_width * mul_out,
164 | convolution=convolution, activation=activation)
165 | modules['down' + str(level)] = layer
166 |
167 | # Creating the center
168 | modules['center'] = nn.Sequential(
169 | *[block_conv_bn_act(basic_width * mul_out, basic_width * mul_out,
170 | convolution=convolution, activation=activation)
171 | for _ in range(center_depth)]
172 | )
173 |
174 | # Automatically creating the decoder
175 | for level in reversed(range(2, depth + 1)):
176 | mul_in = 2 ** (level - 1)
177 | layer = Decoder(2 * basic_width * mul_in, basic_width * mul_in // 2,
178 | convolution=convolution, activation=activation)
179 | modules['up' + str(level)] = layer
180 |
181 | modules['up1'] = Decoder(basic_width * 2, basic_width * 2,
182 | convolution=convolution, activation=activation)
183 |
184 | modules['mixer'] = nn.Conv2d(basic_width * 2, output_channels,
185 | kernel_size=1, padding=0, stride=1,
186 | bias=True)
187 |
188 |         self.__dict__['_modules'] = modules  # bulk-register the submodules, preserving order
189 | if pretrained:
190 | self.load_state_dict(torch.load(path_pretrained))
191 | if restore_weights:
192 | self.load_state_dict(torch.load(path_weights))
193 |
194 | def forward(self, x):
195 | encoded_results = {}
196 |
197 | out = x
198 | for name in self.__dict__['_modules']:
199 | if name.startswith('down'):
200 | layer = self.__dict__['_modules'][name]
201 | convolved, pooled = layer(out)
202 | encoded_results[name] = convolved
203 | out = pooled
204 |
205 | out = self.center(out)
206 |
207 | for name in self.__dict__['_modules']:
208 | if name.startswith('up'):
209 | layer = self.__dict__['_modules'][name]
210 | out = layer(encoded_results['down' + name[-1]], out)
211 | return self.mixer(out)
212 |
--------------------------------------------------------------------------------
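
A minimal smoke test for the model above; the input size is arbitrary because
the decoders interpolate back to the matching encoder resolutions:

    import torch
    from rocaseg.models.unet_lext import UNetLext

    model = UNetLext(basic_width=24, depth=6, center_depth=1,
                     input_channels=1, output_channels=5)
    with torch.no_grad():
        logits = model(torch.randn(2, 1, 320, 320))
    print(logits.shape)  # torch.Size([2, 5, 320, 320])
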
/rocaseg/models/unet_lext_aux.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamically created UNet with variable width, depth, and activation.
3 | 
4 | Aleksei Tiulpin, University of Oulu, 2017 (c).
5 |
6 | """
7 | import logging
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from collections import OrderedDict
13 |
14 |
15 | logging.basicConfig()
16 | logger = logging.getLogger('models')
17 | logger.setLevel(logging.DEBUG)
18 |
19 |
20 | def ConvBlock3(inp, out, activation):
21 | """3x3 ConvNet building block with different activations support.
22 |
23 | Aleksei Tiulpin, Unversity of Oulu, 2017 (c).
24 | """
25 | if activation == 'relu':
26 | return nn.Sequential(
27 | nn.Conv2d(inp, out, kernel_size=3, padding=1),
28 | nn.BatchNorm2d(out),
29 | nn.ReLU(inplace=True)
30 | )
31 | elif activation == 'selu':
32 | return nn.Sequential(
33 | nn.Conv2d(inp, out, kernel_size=3, padding=1),
34 | nn.SELU(inplace=True)
35 | )
36 | elif activation == 'elu':
37 | return nn.Sequential(
38 | nn.Conv2d(inp, out, kernel_size=3, padding=1),
39 | nn.ELU(1, inplace=True)
40 |         )
41 |     raise ValueError(f'Wrong `activation`: {activation}')
42 |
43 | class Encoder(nn.Module):
44 | """Encoder class. for encoder-decoder architecture.
45 |
46 | Aleksei Tiulpin, Unversity of Oulu, 2017 (c).
47 | """
48 | def __init__(self, input_channels, output_channels, depth=2, activation='relu'):
49 | super().__init__()
50 | self.layers = nn.Sequential()
51 | for i in range(depth):
52 | tmp = []
53 | if i == 0:
54 | tmp.append(ConvBlock3(input_channels, output_channels, activation))
55 | else:
56 | tmp.append(ConvBlock3(output_channels, output_channels, activation))
57 |
58 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp))
59 |
60 | def forward(self, x):
61 | processed = self.layers(x)
62 | pooled = F.max_pool2d(processed, 2, 2)
63 | return processed, pooled
64 |
65 |
66 | class Decoder(nn.Module):
67 | """Decoder class. for encoder-decoder architecture.
68 |
69 | Aleksei Tiulpin, Unversity of Oulu, 2017 (c).
70 | """
71 | def __init__(self, input_channels, output_channels, depth=2, mode='bilinear',
72 | activation='relu'):
73 | super().__init__()
74 | self.layers = nn.Sequential()
75 | self.ups_mode = mode
76 | for i in range(depth):
77 | tmp = []
78 | if i == 0:
79 | tmp.append(ConvBlock3(input_channels, output_channels, activation))
80 | else:
81 | tmp.append(ConvBlock3(output_channels, output_channels, activation))
82 |
83 | self.layers.add_module('conv_3x3_{}'.format(i), nn.Sequential(*tmp))
84 |
85 | def forward(self, x_big, x):
86 | x_ups = F.interpolate(x, size=x_big.size()[-2:], mode=self.ups_mode,
87 | align_corners=True)
88 | y_cat = torch.cat([x_ups, x_big], 1)
89 | y = self.layers(y_cat)
90 | return y
91 |
92 |
93 | class AuxModule(nn.Module):
94 | def __init__(self, input_channels, output_channels, dilation_series, padding_series):
95 | super(AuxModule, self).__init__()
96 | self.layers = nn.ModuleList()
97 | for dilation, padding in zip(dilation_series, padding_series):
98 | self.layers.append(
99 | nn.Conv2d(input_channels, output_channels,
100 | kernel_size=3, stride=1, padding=padding,
101 | dilation=dilation, bias=True))
102 |
103 | for m in self.layers:
104 | m.weight.data.normal_(0, 0.01)
105 |
106 | def forward(self, x):
107 | out = self.layers[0](x)
108 | for i in range(len(self.layers) - 1):
109 | out += self.layers[i + 1](x)
110 | return out
111 |
112 |
113 | class UNetLextAux(nn.Module):
114 | """UNet architecture with 3x3 convolutions. Created dynamically based on depth and width.
115 |
116 | Aleksei Tiulpin, 2017 (c)
117 | """
118 | def __init__(self, basic_width=24, depth=6, center_depth=2,
119 | input_channels=3, output_channels=1, activation='relu',
120 | pretrained=False, path_pretrained=None,
121 | restore_weights=False, path_weights=None,
122 | with_aux=True, **kwargs):
123 | """
124 |
125 |         Aleksei Tiulpin, University of Oulu, 2017 (c).
126 |
127 | Parameters
128 | ----------
129 | basic_width:
130 | Basic width of the network, which is doubled at each layer.
131 | depth:
132 | Number of layers.
133 | center_depth:
134 | Depth of the central block in UNet.
135 | input_channels:
136 | Number of input channels.
137 | output_channels:
138 | Number of output channels (/classes).
139 |         activation: {'relu', 'selu', 'elu'}
140 |             Activation function.
141 |         restore_weights: bool
142 |             Whether to load the model weights from `path_weights`.
143 |         path_weights: str
144 |             Path to the file with the model state dict.
145 |         kwargs: Catches redundant arguments and issues a warning.
146 | """
147 | assert depth >= 2
148 | super().__init__()
149 |         if kwargs:
150 |             logger.warning(f'Redundant model init arguments: {kwargs!r}')
151 | self._with_aux = with_aux
152 |
153 | # Preparing the modules dict
154 | modules = OrderedDict()
155 |
156 | modules['down1'] = Encoder(input_channels, basic_width, activation=activation)
157 |
158 | # Automatically creating the Encoder based on the depth and width
159 | for level in range(2, depth + 1):
160 | mul_in = 2 ** (level - 2)
161 | mul_out = 2 ** (level - 1)
162 | layer = Encoder(basic_width * mul_in, basic_width * mul_out,
163 | activation=activation)
164 | modules['down' + str(level)] = layer
165 |
166 | # Creating the center
167 | modules['center'] = nn.Sequential(
168 | *[ConvBlock3(basic_width * mul_out, basic_width * mul_out,
169 | activation=activation)
170 |           for _ in range(center_depth)]
171 | )
172 |
173 | # Automatically creating the decoder
174 | for level in reversed(range(2, depth + 1)):
175 | mul_in = 2 ** (level - 1)
176 | layer = Decoder(2 * basic_width * mul_in, basic_width * mul_in // 2,
177 | activation=activation)
178 | modules['up' + str(level)] = layer
179 |
180 | if self._with_aux and (level == 2):
181 | modules['aux'] = AuxModule(
182 | input_channels=basic_width * mul_in // 2,
183 | output_channels=output_channels,
184 | dilation_series=[6, 12, 18, 24],
185 | padding_series=[6, 12, 18, 24])
186 |
187 | modules['up1'] = Decoder(basic_width * 2, basic_width * 2,
188 | activation=activation)
189 |
190 | modules['mixer'] = nn.Conv2d(basic_width * 2, output_channels,
191 | kernel_size=1, padding=0, stride=1,
192 | bias=True)
193 |
194 |         self.__dict__['_modules'] = modules  # bulk-register the submodules, preserving order
195 | if pretrained:
196 | self.load_state_dict(torch.load(path_pretrained))
197 | if restore_weights:
198 | self.load_state_dict(torch.load(path_weights))
199 |
200 | def forward(self, x):
201 | encoded_results = {}
202 |
203 | out = x
204 | for name in self.__dict__['_modules']:
205 | if name.startswith('down'):
206 | layer = self.__dict__['_modules'][name]
207 | convolved, pooled = layer(out)
208 | encoded_results[name] = convolved
209 | out = pooled
210 |
211 | out = self.center(out)
212 |
213 | for name in self.__dict__['_modules']:
214 | if name.startswith('up'):
215 | layer = self.__dict__['_modules'][name]
216 | out = layer(encoded_results['down' + name[-1]], out)
217 |
218 | if name == 'up2':
219 | out_aux = out
220 |
221 | if self._with_aux:
222 | out_aux = self.aux(out_aux)
223 | out_aux = F.interpolate(out_aux, size=x.size()[-2:],
224 | mode='bilinear', align_corners=True)
225 | out_main = self.mixer(out)
226 | return out_main, out_aux
227 | else:
228 | return self.mixer(out)
229 |
--------------------------------------------------------------------------------
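
With `with_aux=True` the model above returns two heads: the main mixer output
and the auxiliary multi-dilation output, both at the input resolution. A
minimal sketch:

    import torch
    from rocaseg.models.unet_lext_aux import UNetLextAux

    model = UNetLextAux(basic_width=24, depth=6, center_depth=1,
                        input_channels=1, output_channels=5, with_aux=True)
    with torch.no_grad():
        out_main, out_aux = model(torch.randn(1, 1, 320, 320))
    print(out_main.shape, out_aux.shape)  # both torch.Size([1, 5, 320, 320])
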
/rocaseg/datasets/prepare_dataset_oai_imo.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from collections import defaultdict
4 |
5 | import click
6 | from joblib import Parallel, delayed
7 | from tqdm import tqdm
8 |
9 | import numpy as np
10 | from sas7bdat import SAS7BDAT
11 | from scipy import io
12 | import pandas as pd
13 |
14 | import pydicom
15 | import cv2
16 |
17 | from rocaseg.datasets.constants import locations_mh53
18 | from rocaseg.datasets.meta_oai import side_code_to_str, release_to_prefix_var
19 |
20 |
21 | cv2.ocl.setUseOpenCL(False)
22 |
23 |
24 | def read_dicom(fname):
25 | data = pydicom.read_file(fname)
26 | image = np.frombuffer(data.PixelData, dtype=np.uint16).astype(float)
27 |
28 | if data.PhotometricInterpretation == 'MONOCHROME1':
29 | image = image.max() - image
30 | image = image.reshape((data.Rows, data.Columns))
31 |
32 | if 'RIGHT' in data.SeriesDescription:
33 | side = 'RIGHT'
34 | elif 'LEFT' in data.SeriesDescription:
35 | side = 'LEFT'
36 | else:
37 | print(data)
38 | msg = f'DICOM {fname} does not contain side info'
39 | raise ValueError(msg)
40 |
41 | if hasattr(data, 'ImagerPixelSpacing'):
42 | spacing = [float(e) for e in data.ImagerPixelSpacing[:2]]
43 | elif hasattr(data, 'PixelSpacing'):
44 | spacing = [float(e) for e in data.PixelSpacing[:2]]
45 | else:
46 | msg = f'DICOM {fname} does not contain spacing info'
47 | raise AttributeError(msg)
48 |
49 | return (image,
50 | spacing[0],
51 | spacing[1],
52 | float(data.SliceThickness),
53 | side)
54 |
55 |
56 | def mask_from_mat(masks_mat, mask_shape, slice_idx, attr_name):
57 | mask = np.zeros(mask_shape, dtype=np.uint8)
58 | data = getattr(masks_mat[0][slice_idx], attr_name)
59 | if len(data.shape) > 0:
60 | for comp in range(data.shape[1]):
61 | cnt = data[0, comp][:, :2].copy()
62 | cnt[:, 1] = mask_shape[0] - cnt[:, 1]
63 |                 cntf = cnt.astype(np.int32)  # OpenCV expects integer contour points
64 | cv2.drawContours(mask, [cntf], -1, (255, 255, 255), -1)
65 | mask = (mask > 0).astype(np.uint8)
66 | return mask
67 |
68 |
69 | @click.command()
70 | @click.argument('path_root_oai_mri')
71 | @click.argument('path_root_imo')
72 | @click.argument('path_root_output')
73 | @click.option('--num_threads', default=12, type=click.IntRange(0, 16))
74 | @click.option('--margin', default=0, type=int)
75 | @click.option('--meta_only', is_flag=True)
76 | def main(**config):
77 | config['path_root_oai_mri'] = os.path.abspath(config['path_root_oai_mri'])
78 | config['path_root_imo'] = os.path.abspath(config['path_root_imo'])
79 | config['path_root_output'] = os.path.abspath(config['path_root_output'])
80 |
81 | # -------------------------------------------------------------------------
82 | def worker_xz9(path_root_output, path_stack, margin):
83 | meta = defaultdict(list)
84 |
85 | release, patient = path_stack.split('/')[-4:-2]
86 | prefix_var = release_to_prefix_var[release]
87 | sequence = 'sag_3d_dess_we'
88 |
89 | path_annot = os.path.join(config['path_root_imo'], patient, prefix_var)
90 | fnames_annot = glob(os.path.join(path_annot, '*.mat'))
91 | if len(fnames_annot) != 1:
92 | raise ValueError(f'Unexpected annotations for patient: {patient}')
93 | fname_annot = fnames_annot[0]
94 |
95 | file_mat = io.loadmat(os.path.join(path_annot, fname_annot),
96 | struct_as_record=False)
97 | masks_mat = file_mat['datastruct']
98 | num_slices = masks_mat.shape[1]
99 |
100 | for slice_idx in range(num_slices):
101 | # Indexing of slices in OAI dataset starts with 001
102 | fname_src = os.path.join(path_stack, '{:>03}'.format(slice_idx+1))
103 | image, *dicom_meta = read_dicom(fname_src)
104 |
105 | side = dicom_meta[3]
106 |
107 | mask_proc = np.zeros_like(image)
108 |
109 | # NOTICE: Reference masks have some collisions. We solve them
110 | # by prioritising the tissues which are earlier in the list.
111 |             for part_name, part_value in reversed(list(locations_mh53.items())):
112 | # Skip the background as it is not presented in the source data
113 | if part_name == 'Background':
114 | continue
115 | try:
116 | mask_temp = mask_from_mat(masks_mat, image.shape,
117 | slice_idx, part_name)
118 | mask_proc[mask_temp > 0] = part_value
119 | except AttributeError:
120 | print(f'Error accessing {part_name} in {fname_src}')
121 |
122 | if margin != 0:
123 | image = image[margin:-margin, margin:-margin]
124 | mask_proc = mask_proc[margin:-margin, margin:-margin]
125 |
126 | fname_pattern = '{slice_idx:>03}.{ext}'
127 |
128 | # Save image and mask data
129 | dir_rel_image = os.path.join(patient, release, sequence, 'images')
130 | dir_rel_mask = os.path.join(patient, release, sequence, 'masks')
131 | dir_abs_image = os.path.join(path_root_output, dir_rel_image)
132 | dir_abs_mask = os.path.join(path_root_output, dir_rel_mask)
133 | for d in (dir_abs_image, dir_abs_mask):
134 | if not os.path.exists(d):
135 | os.makedirs(d)
136 |
137 | fname_image = fname_pattern.format(slice_idx=slice_idx, ext='png')
138 | path_abs_image = os.path.join(dir_abs_image, fname_image)
139 | if not config['meta_only']:
140 | cv2.imwrite(path_abs_image, image)
141 |
142 | fname_mask = fname_pattern.format(slice_idx=slice_idx, ext='png')
143 | path_abs_mask = os.path.join(dir_abs_mask, fname_mask)
144 | if not config['meta_only']:
145 | cv2.imwrite(path_abs_mask, mask_proc)
146 |
147 | path_rel_image = os.path.join(dir_rel_image, fname_image)
148 | path_rel_mask = os.path.join(dir_rel_mask, fname_mask)
149 |
150 | meta['patient'].append(patient)
151 | meta['release'].append(release)
152 | meta['prefix_var'].append(prefix_var)
153 | meta['sequence'].append(sequence)
154 | meta['side'].append(side)
155 | meta['slice_idx'].append(slice_idx)
156 | meta['pixel_spacing_0'].append(dicom_meta[0])
157 | meta['pixel_spacing_1'].append(dicom_meta[1])
158 | meta['slice_thickness'].append(dicom_meta[2])
159 | meta['path_rel_image'].append(path_rel_image)
160 | meta['path_rel_mask'].append(path_rel_mask)
161 | return meta
162 | # -------------------------------------------------------------------------
163 |
164 | # OAI data path structure:
165 | # root / examination / release / patient / date / barcode (/ slices)
166 |     paths_stacks = glob(os.path.join(config['path_root_oai_mri'], '*', '*', '*', '*', '*'))
167 | paths_stacks.sort(key=lambda x: int(x.split('/')[-3]))
168 |
169 | metas = Parallel(config['num_threads'])(delayed(worker_xz9)(
170 | *[config['path_root_output'], path_stack, config['margin']]
171 | ) for path_stack in tqdm(paths_stacks))
172 |
173 | # Merge meta information from different stacks
174 | tmp = defaultdict(list)
175 | for d in metas:
176 | for k, v in d.items():
177 | tmp[k].extend(v)
178 | df_out = pd.DataFrame.from_dict(tmp)
179 |
180 | # Find the grading data
181 | fnames_sas = glob(os.path.join(config['path_root_oai_mri'],
182 | '*', '*.sas7bdat'), recursive=True)
183 |
184 | # Read semi-quantitative data
185 | dfs = dict()
186 | for fn in fnames_sas:
187 | with SAS7BDAT(fn) as f:
188 | raw = [r for r in f]
189 | tmp = pd.DataFrame(raw[1:], columns=raw[0])
190 |
191 | prefix_var = [c for c in tmp.columns if c.endswith('XRKL')][0][:3]
192 |
193 | tmp = tmp.rename(lambda x: x.upper(), axis=1)
194 | tmp = tmp.rename({'VERSION': f'{prefix_var}VERSION',
195 | 'ID': 'patient',
196 | 'SIDE': 'side'}, axis=1)
197 |
198 | tmp['side'] = tmp['side'].apply(lambda s: side_code_to_str[s])
199 | dfs.update({prefix_var: tmp})
200 |
201 | # Set the index to join on
202 | for k, tmp in dfs.items():
203 | dfs[k] = tmp.set_index(['patient', 'side', 'READPRJ'])
204 |
205 | df = pd.concat(dfs.values(), axis=1)
206 | df = df.reset_index()
207 |
208 | # Remove unnecessary columns and reformat the grading info
209 | df_sel = df[['patient', 'side', 'V00XRKL', 'V01XRKL']]
210 | df_sel = (df_sel
211 | .set_index(['patient', 'side'])
212 | .rename({'V00XRKL': 'V00', 'V01XRKL': 'V01'}, axis=1)
213 | .stack()
214 | .reset_index()
215 | .rename({'level_2': 'prefix_var', 0: 'KL'}, axis=1))
216 |
217 | # Select the subset for which the annotations are available
218 | indexers = ['patient', 'side', 'prefix_var']
219 | sel = df_out.set_index(indexers).index.unique()
220 | df_sel = (df_sel
221 | .drop_duplicates(subset=indexers) # There are ~5 duplicates
222 | .set_index(indexers)
223 | .loc[sel, :]
224 | .reset_index())
225 |
226 | df_out = pd.merge(df_out, df_sel, on=indexers, how='left')
227 |
228 | path_output_meta = os.path.join(config['path_root_output'], 'meta_base.csv')
229 | df_out.to_csv(path_output_meta, index=False)
230 |
231 |
232 | if __name__ == '__main__':
233 | main()
234 |
--------------------------------------------------------------------------------
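
The contour-filling step in `mask_from_mat` above relies on cv2.drawContours
with thickness=-1 to rasterize each closed contour. A self-contained toy
version (the square contour is made up):

    import numpy as np
    import cv2

    mask = np.zeros((8, 8), dtype=np.uint8)
    cnt = np.array([[1, 1], [6, 1], [6, 6], [1, 6]], dtype=np.int32)
    cv2.drawContours(mask, [cnt], -1, (255, 255, 255), -1)  # -1 == fill
    mask = (mask > 0).astype(np.uint8)
    print(mask.sum())  # 36 pixels inside the filled 6x6 square
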
/scripts/runner.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # -------------------------------- Prepare datasets ------------------------------------
4 | (cd ../rocaseg/datasets &&
5 | echo "prepare_dataset" &&
6 | python prepare_dataset_oai_imo.py \
7 | ../../../data_raw/OAI_iMorphics_scans \
8 | ../../../data_raw/OAI_iMorphics_annotations \
9 | ../../../data/91_OAI_iMorphics_full_meta \
10 | --margin 20 \
11 | )
12 |
13 | (cd ../rocaseg/datasets &&
14 | echo "prepare_dataset" &&
15 | python prepare_dataset_okoa.py \
16 | ../../../data_raw/OKOA \
17 | ../../../data/31_OKOA_full_meta \
18 | --margin 0 \
19 | )
20 |
21 | (cd ../rocaseg/datasets &&
22 | echo "prepare_dataset" &&
23 | python prepare_dataset_maknee.py \
24 | ../../../data_raw/MAKNEE \
25 | ../../../data/41_MAKNEE_full_meta \
26 | --margin 0 \
27 | )
28 | # --------------------------------------------------------------------------------------
29 |
30 | # -------------------------------- Resample datasets -----------------------------------
31 | (cd ../rocaseg/ &&
32 | echo "resample" &&
33 | python resample.py \
34 | --path_root_in ../../data/31_OKOA_full_meta \
35 | --spacing_in 0.5859375 0.5859375 \
36 | --path_root_out ../../data/32_OKOA_full_meta_rescaled \
37 | --spacing_out 0.36458333 0.36458333 \
38 | --num_threads 12 \
39 | --margin 0 \
40 | )
41 |
42 | (cd ../rocaseg/ &&
43 | echo "resample" &&
44 | python resample.py \
45 | --path_root_in ../../data/41_MAKNEE_full_meta \
46 | --spacing_in 0.5859375 0.5859375 \
47 | --path_root_out ../../data/42_MAKNEE_full_meta_rescaled \
48 | --spacing_out 0.36458333 0.36458333 \
49 | --num_threads 12 \
50 | --margin 0 \
51 | )
52 | # --------------------------------------------------------------------------------------
53 |
54 | # --------------------------------- Train models ---------------------------------------
55 | (cd ../rocaseg/ &&
56 | echo "train" &&
57 | python train_baseline.py \
58 | --path_data_root ../../data \
59 | --path_experiment_root ../../results/0_baseline \
60 | --model_segm unet_lext \
61 | --input_channels 1 \
62 | --output_channels 5 \
63 | --center_depth 1 \
64 | --lr_segm 0.001 \
65 | --batch_size 32 \
66 | --epoch_num 50 \
67 | --fold_num 5 \
68 | --fold_idx -1 \
69 | --num_workers 12 \
70 | )
71 |
72 | (cd ../rocaseg/ &&
73 | echo "train" &&
74 | python train_baseline.py \
75 | --path_data_root ../../data \
76 | --path_experiment_root ../../results/1_mixup \
77 | --model_segm unet_lext \
78 | --input_channels 1 \
79 | --output_channels 5 \
80 | --center_depth 1 \
81 | --lr_segm 0.001 \
82 | --batch_size 32 \
83 | --epoch_num 50 \
84 | --fold_num 5 \
85 | --fold_idx -1 \
86 | --with_mixup \
87 | --mixup_alpha 0.7 \
88 | --num_workers 12 \
89 | )
90 |
91 | (cd ../rocaseg/ &&
92 | echo "train" &&
93 | python train_baseline.py \
94 | --path_data_root ../../data \
95 | --path_experiment_root ../../results/2_mixup_nowd \
96 | --model_segm unet_lext \
97 | --input_channels 1 \
98 | --output_channels 5 \
99 | --center_depth 1 \
100 | --lr_segm 0.001 \
101 | --wd_segm 0.0 \
102 | --batch_size 32 \
103 | --epoch_num 50 \
104 | --fold_num 5 \
105 | --fold_idx -1 \
106 | --with_mixup \
107 | --mixup_alpha 0.7 \
108 | --num_workers 12 \
109 | )
110 |
111 | (cd ../rocaseg/ &&
112 | echo "train" &&
113 | python train_uda1.py \
114 | --path_data_root ../../data \
115 | --path_experiment_root ../../results/3_uda1 \
116 | --model_segm unet_lext \
117 | --center_depth 1 \
118 | --model_discr discriminator_a \
119 | --input_channels 1 \
120 | --output_channels 5 \
121 | --mask_mode all_unitibial_unimeniscus \
122 | --loss_segm multi_ce_loss \
123 | --lr_segm 0.0001 \
124 | --lr_discr 0.00004 \
125 | --batch_size 32 \
126 | --epoch_num 30 \
127 | --fold_num 5 \
128 | --fold_idx 0 \
129 | --num_workers 12 \
130 | )
131 |
132 | (cd ../rocaseg/ &&
133 | echo "train" &&
134 | python train_uda1.py \
135 | --path_data_root ../../data \
136 | --path_experiment_root ../../results/4_uda2 \
137 | --model_segm unet_lext_aux \
138 | --center_depth 1 \
139 | --model_discr_out discriminator_a \
140 | --model_discr_aux discriminator_a \
141 | --input_channels 1 \
142 | --output_channels 5 \
143 | --mask_mode all_unitibial_unimeniscus \
144 | --loss_segm multi_ce_loss \
145 | --lr_segm 0.0001 \
146 | --lr_discr 0.00004 \
147 | --batch_size 32 \
148 | --epoch_num 30 \
149 | --fold_num 5 \
150 | --fold_idx 0 \
151 | --num_workers 12 \
152 | )
153 |
154 | (cd ../rocaseg/ &&
155 | echo "train" &&
156 | python train_uda1.py \
157 | --path_data_root ../../data \
158 | --path_experiment_root ../../results/5_uda1_mixup_nowd \
159 | --model_segm unet_lext \
160 | --center_depth 1 \
161 | --model_discr discriminator_a \
162 | --input_channels 1 \
163 | --output_channels 5 \
164 | --mask_mode all_unitibial_unimeniscus \
165 | --loss_segm multi_ce_loss \
166 | --lr_segm 0.0001 \
167 | --lr_discr 0.00004 \
168 | --wd_segm 0.0 \
169 | --batch_size 32 \
170 | --epoch_num 30 \
171 | --fold_num 5 \
172 | --fold_idx 0 \
173 | --num_workers 12 \
174 | --with_mixup \
175 | --mixup_alpha 0.7 \
176 | )
177 | # --------------------------------------------------------------------------------------
178 |
179 | # ------------------------------- Run model inference ----------------------------------
180 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 5_uda1_mixup_nowd
181 | do
182 | (cd ../rocaseg/ &&
183 | echo "evaluate" &&
184 | python evaluate.py \
185 | --path_data_root ../../data \
186 | --path_experiment_root ../../results/${EXP} \
187 | --model_segm unet_lext \
188 | --center_depth 1 \
189 | --restore_weights \
190 | --output_channels 5 \
191 | --dataset oai_imo \
192 | --subset test \
193 | --mask_mode all_unitibial_unimeniscus \
194 | --batch_size 64 \
195 | --fold_num 5 \
196 | --fold_idx -1 \
197 | --num_workers 12 \
198 | --predict_folds \
199 | --merge_predictions \
200 | )
201 | done
202 |
203 | for EXP in 4_uda2
204 | do
205 | (cd ../rocaseg/ &&
206 | echo "evaluate" &&
207 | python evaluate.py \
208 | --path_data_root ../../data \
209 | --path_experiment_root ../../results/${EXP} \
210 | --model_segm unet_lext_aux \
211 | --center_depth 1 \
212 | --restore_weights \
213 | --output_channels 5 \
214 | --dataset oai_imo \
215 | --subset test \
216 | --mask_mode all_unitibial_unimeniscus \
217 | --batch_size 64 \
218 | --fold_num 5 \
219 | --fold_idx -1 \
220 | --num_workers 12 \
221 | --predict_folds \
222 | --merge_predictions \
223 | )
224 | done
225 |
226 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 5_uda1_mixup_nowd
227 | do
228 | (cd ../rocaseg/ &&
229 | echo "evaluate" &&
230 | python evaluate.py \
231 | --path_data_root ../../data \
232 | --path_experiment_root ../../results/${EXP} \
233 | --model_segm unet_lext \
234 | --center_depth 1 \
235 | --restore_weights \
236 | --output_channels 5 \
237 | --dataset okoa \
238 | --subset all \
239 | --mask_mode background_femoral_unitibial \
240 | --batch_size 64 \
241 | --fold_num 5 \
242 | --fold_idx -1 \
243 | --num_workers 12 \
244 | --predict_folds \
245 | --merge_predictions \
246 | )
247 | done
248 |
249 | for EXP in 4_uda2
250 | do
251 | (cd ../rocaseg/ &&
252 | echo "evaluate" &&
253 | python evaluate.py \
254 | --path_data_root ../../data \
255 | --path_experiment_root ../../results/${EXP} \
256 | --model_segm unet_lext_aux \
257 | --center_depth 1 \
258 | --restore_weights \
259 | --output_channels 5 \
260 | --dataset okoa \
261 | --subset all \
262 | --mask_mode background_femoral_unitibial \
263 | --batch_size 64 \
264 | --fold_num 5 \
265 | --fold_idx -1 \
266 | --num_workers 12 \
267 | --predict_folds \
268 | --merge_predictions \
269 | )
270 | done
271 | # --------------------------------------------------------------------------------------
272 |
273 | # ------------------------------- Analyze model predictions ----------------------------
274 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 4_uda2 5_uda1_mixup_nowd
275 | do
276 | (cd ../rocaseg/ &&
277 | echo "analyze_predictions" &&
278 | python analyze_predictions_single.py \
279 | --path_experiment_root ../../results/${EXP} \
280 | --dirname_pred mask_foldavg \
281 | --dirname_true mask_prep \
282 | --dataset oai_imo \
283 | --atlas segm \
284 | --ignore_cache \
285 | --num_workers 12 \
286 | )
287 | done
288 |
289 | for EXP in 0_baseline 1_mixup 2_mixup_nowd 3_uda1 4_uda2 5_uda1_mixup_nowd
290 | do
291 | (cd ../rocaseg/ &&
292 | echo "analyze_predictions" &&
293 | python analyze_predictions_single.py \
294 | --path_experiment_root ../../results/${EXP} \
295 | --dirname_pred mask_foldavg \
296 | --dirname_true mask_prep \
297 | --dataset okoa \
298 | --atlas okoa \
299 | --ignore_cache \
300 | --num_workers 12 \
301 | )
302 | done
303 | # --------------------------------------------------------------------------------------
304 |
305 | # --------------------------------- Compare models -------------------------------------
306 | (cd ../rocaseg/ &&
307 | echo "analyze_predictions_multi" &&
308 | python analyze_predictions_multi.py \
309 | --path_results_root ../../results \
310 | --experiment_id 0_baseline \
311 | --experiment_id 2_mixup_nowd \
312 | --experiment_id 4_uda2 \
313 | --dataset oai_imo \
314 | --atlas segm \
315 | --num_workers 12 \
316 | )
317 |
318 | (cd ../rocaseg/ &&
319 | echo "analyze_predictions_multi" &&
320 | python analyze_predictions_multi.py \
321 | --path_results_root ../../results \
322 | --experiment_id 0_baseline \
323 | --experiment_id 2_mixup_nowd \
324 | --experiment_id 4_uda2 \
325 | --dataset okoa \
326 | --atlas okoa \
327 | --num_workers 12 \
328 | )
329 | # --------------------------------------------------------------------------------------
330 |
--------------------------------------------------------------------------------
/rocaseg/preproc/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import cv2
3 | import numpy as np
4 | import math
5 | import torch
6 |
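   | # Convention used throughout this module: transforms that involve
   | # randomness draw their parameters in `randomize()` and cache them in
   | # `self.state`; `__call__` only applies the cached state. The caller is
   | # presumably expected to invoke `randomize()` before each sample so that
   | # an image and its mask receive an identical transform.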
7 |
8 | class DualCompose(object):
9 | def __init__(self, transforms):
10 | self.transforms = transforms
11 |
12 | def __call__(self, img, mask=None):
13 | for t in self.transforms:
14 | img, mask = t(img, mask)
15 | return img, mask
16 |
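   | # Minimal usage sketch (hypothetical values; see train_baseline.py for
   | # the actual pipelines):
   | #   transform = DualCompose([CenterCrop(height=300, width=300),
   | #                            HorizontalFlip(prob=0.5)])
   | #   img_t, mask_t = transform(img, mask)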
17 |
18 | class OneOf(object):
19 | def __init__(self, transforms, prob=.5):
20 | self.transforms = transforms
21 | self.prob = prob
22 |
23 | self.state = dict()
24 | self.randomize()
25 |
26 | def __call__(self, img, mask=None):
27 | if self.state['p'] < self.prob:
28 | img, mask = self.state['t'](img, mask)
29 | return img, mask
30 |
31 | def randomize(self):
32 | self.state['p'] = random.random()
33 | self.state['t'] = random.choice(self.transforms)
34 | self.state['t'].prob = 1.
35 |
36 |
37 | class OneOrOther(object):
38 | def __init__(self, first, second, prob=.5):
39 | self.first = first
40 | first.prob = 1.
41 | self.second = second
42 | second.prob = 1.
43 | self.prob = prob
44 |
45 | self.state = dict()
46 | self.randomize()
47 |
48 | def __call__(self, x, mask=None):
49 | if self.state['p'] < self.prob:
50 | x, mask = self.first(x, mask)
51 | else:
52 | x, mask = self.second(x, mask)
53 | return x, mask
54 |
55 | def randomize(self):
56 | self.state['p'] = random.random()
57 |
58 |
59 | class ImageOnly(object):
60 | def __init__(self, transform):
61 | self.transform = transform
62 |
63 | def __call__(self, img, mask=None):
64 | return self.transform(img, None)[0], mask
65 |
66 |
67 | class NoTransform(object):
68 | def __call__(self, *args):
69 | return args
70 |
71 |
72 | class ToTensor(object):
73 | def __call__(self, *args):
74 | return [torch.from_numpy(e) for e in args]
75 |
76 |
77 | class VerticalFlip(object):
78 | def __init__(self, prob=.5):
79 | self.prob = prob
80 |
81 | self.state = dict()
82 | self.randomize()
83 |
84 | def __call__(self, img, mask=None):
85 | """
86 |
87 | Parameters
88 | ----------
89 | img: (ch, d0, d1) ndarray
90 | mask: (ch, d0, d1) ndarray
91 | """
92 | if self.state['p'] < self.prob:
93 | img = np.flip(img, axis=1)
94 | if mask is not None:
95 | mask = np.flip(mask, axis=1)
96 | return img, mask
97 |
98 | def randomize(self):
99 | self.state['p'] = random.random()
100 |
101 |
102 | class HorizontalFlip(object):
103 | def __init__(self, prob=.5):
104 | self.prob = prob
105 |
106 | self.state = dict()
107 | self.randomize()
108 |
109 | def __call__(self, img, mask=None):
110 | """
111 |
112 | Parameters
113 | ----------
114 | img: (ch, d0, d1) ndarray
115 | mask: (ch, d0, d1) ndarray
116 | """
117 | if self.state['p'] < self.prob:
118 | img = np.flip(img, axis=2)
119 | if mask is not None:
120 | mask = np.flip(mask, axis=2)
121 | return img, mask
122 |
123 | def randomize(self):
124 | self.state['p'] = random.random()
125 |
126 |
127 | class Flip(object):
128 | def __init__(self, prob=.5):
129 | self.prob = prob
130 |
131 | self.state = dict()
132 | self.randomize()
133 |
134 | def __call__(self, img, mask=None):
135 | """
136 |
137 | Parameters
138 | ----------
139 | img: (ch, d0, d1) ndarray
140 | mask: (ch, d0, d1) ndarray
141 | """
142 | if self.state['p'] < self.prob:
143 | if self.state['d'] in (-1, 0):
144 | img = np.flip(img, axis=1)
145 | if self.state['d'] in (-1, 1):
146 | img = np.flip(img, axis=2)
147 | if mask is not None:
148 | if self.state['d'] in (-1, 0):
149 | mask = np.flip(mask, axis=1)
150 | if self.state['d'] in (-1, 1):
151 | mask = np.flip(mask, axis=2)
152 | return img, mask
153 |
154 | def randomize(self):
155 | self.state['p'] = random.random()
156 | self.state['d'] = random.randint(-1, 1)
157 |
158 |
159 | class Scale(object):
160 | def __init__(self, ratio_range=(0.7, 1.2), prob=.5):
161 | self.ratio_range = ratio_range
162 | self.prob = prob
163 |
164 | self.state = dict()
165 | self.randomize()
166 |
167 | def __call__(self, img, mask=None):
168 | """
169 |
170 | Parameters
171 | ----------
172 | img: (ch, d0, d1) ndarray
173 | mask: (ch, d0, d1) ndarray
174 | """
175 | if self.state['p'] < self.prob:
176 | ch, d0_i, d1_i = img.shape
177 |             d0_o = math.floor(d0_i * self.state['r'])
178 |             d0_o = d0_o + d0_o % 2  # force even output size
179 |             d1_o = math.floor(d1_i * self.state['r'])
180 |             d1_o = d1_o + d1_o % 2
181 |
182 | # img1 = cv2.copyMakeBorder(img, limit, limit, limit, limit,
183 | # borderType=cv2.BORDER_REFLECT_101)
184 | img = np.squeeze(img)
185 | img = cv2.resize(img, (d1_o, d0_o), interpolation=cv2.INTER_LINEAR)
186 | img = img[None, ...]
187 |
188 | if mask is not None:
189 | # msk1 = cv2.copyMakeBorder(mask, limit, limit, limit, limit,
190 | # borderType=cv2.BORDER_REFLECT_101)
191 |                 tmp = np.empty((mask.shape[0], d0_o, d1_o), dtype=mask.dtype)
192 | for idx_ch, mask_ch in enumerate(mask):
193 | tmp[idx_ch] = cv2.resize(mask_ch, (d1_o, d0_o),
194 | interpolation=cv2.INTER_NEAREST)
195 | mask = tmp
196 | return img, mask
197 |
198 | def randomize(self):
199 | self.state['p'] = random.random()
200 | self.state['r'] = round(random.uniform(*self.ratio_range), 2)
201 |
202 |
203 | class Crop(object):
204 | def __init__(self, output_size):
205 |         if isinstance(output_size, int):
206 |             self.output_size = (output_size, output_size)
207 |         elif isinstance(output_size, tuple):
208 |             self.output_size = output_size
209 |         else:
210 |             raise ValueError(
211 |                 '`output_size` must be an int or a tuple')
212 | # self.keep_size = keep_size
213 | # self.prob = prob
214 |
215 | self.state = dict()
216 | self.randomize()
217 |
218 | def __call__(self, img, mask=None):
219 | rows_in, cols_in = img.shape[1:]
220 | rows_out, cols_out = self.output_size
221 | rows_out = min(rows_in, rows_out)
222 | cols_out = min(cols_in, cols_out)
223 |
224 | r0 = math.floor(self.state['r0f'] * (rows_in - rows_out))
225 | c0 = math.floor(self.state['c0f'] * (cols_in - cols_out))
226 | r1 = r0 + rows_out
227 | c1 = c0 + cols_out
228 |
229 | img = np.ascontiguousarray(img[:, r0:r1, c0:c1])
230 | if mask is not None:
231 | mask = np.ascontiguousarray(mask[:, r0:r1, c0:c1])
232 | return img, mask
233 |
234 | def randomize(self):
235 | # self.state['p'] = random.random()
236 | self.state['r0f'] = random.random()
237 | self.state['c0f'] = random.random()
238 |
239 |
240 | class CenterCrop(object):
241 | def __init__(self, height, width):
242 | self.height = height
243 | self.width = width
244 |
245 | def __call__(self, img, mask=None):
246 | """
247 |
248 | Parameters
249 | ----------
250 | img: (ch, d0, d1) ndarray
251 | mask: (ch, d0, d1) ndarray
252 | """
253 | c, h, w = img.shape
254 | dy = (h - self.height) // 2
255 | dx = (w - self.width) // 2
256 |
257 | y1 = dy
258 | y2 = y1 + self.height
259 | x1 = dx
260 | x2 = x1 + self.width
261 | img = np.ascontiguousarray(img[:, y1:y2, x1:x2])
262 | if mask is not None:
263 | mask = np.ascontiguousarray(mask[:, y1:y2, x1:x2])
264 |
265 | return img, mask
266 |
267 |
268 | class Pad(object):
269 | def __init__(self, dr, dc, **kwargs):
270 | self.dr = dr
271 | self.dc = dc
272 | self.kwargs = kwargs
273 |
274 | def __call__(self, img, mask=None):
275 | """
276 |
277 | Parameters
278 | ----------
279 | img: (ch, d0, d1) ndarray
280 | mask: (ch, d0, d1) ndarray
281 | """
282 | pad_width = (0, 0), (self.dr,) * 2, (self.dc,) * 2
283 | img = np.pad(img, pad_width, **self.kwargs)
284 | if mask is not None:
285 | mask = np.pad(mask, pad_width, **self.kwargs)
286 | return img, mask
287 |
288 |
289 | class GammaCorrection(object):
290 | def __init__(self, gamma_range=(0.5, 2), prob=0.5):
291 | self.gamma_range = gamma_range
292 | self.prob = prob
293 |
294 | self.state = dict()
295 | self.randomize()
296 |
297 | def __call__(self, image, mask=None):
298 | """
299 |
300 | Parameters
301 | ----------
302 | img: (ch, d0, d1) ndarray
303 | mask: (ch, d0, d1) ndarray
304 | """
305 | if self.state['p'] < self.prob:
306 | image = image ** (1 / self.state['gamma'])
307 | # TODO: implement also for integers
308 | image = np.clip(image, 0, 1)
309 | return image, mask
310 |
311 | def randomize(self):
312 | self.state['p'] = random.random()
313 | self.state['gamma'] = random.uniform(*self.gamma_range)
314 |
315 |
316 | class BilateralFilter(object):
317 | def __init__(self, d, sigma_color, sigma_space, prob=.5):
318 | self.d = d
319 | self.sigma_color = sigma_color
320 | self.sigma_space = sigma_space
321 | self.prob = prob
322 |
323 | self.state = dict()
324 | self.randomize()
325 |
326 | def __call__(self, img, mask=None):
327 | if self.state['p'] < self.prob:
328 | img = np.squeeze(img)
329 | img = cv2.bilateralFilter(img, self.d, self.sigma_color, self.sigma_space)
330 | img = img[None, ...]
331 | return img, mask
332 |
333 | def randomize(self):
334 | self.state['p'] = random.random()
335 |
--------------------------------------------------------------------------------
/rocaseg/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from glob import glob
4 |
5 | import numpy as np
6 | from skimage.color import label2rgb
7 | from skimage import img_as_ubyte
8 | from tqdm import tqdm
9 | import click
10 |
11 | import cv2
12 | import tifffile
13 | import torch
14 | import torch.nn as nn
15 | from torch.utils.data.dataloader import DataLoader
16 |
17 | from rocaseg.datasets import sources_from_path
18 | from rocaseg.components import CheckpointHandler
19 | from rocaseg.components.formats import numpy_to_nifti, png_to_numpy
20 | from rocaseg.models import dict_models
21 | from rocaseg.preproc import *
22 | from rocaseg.repro import set_ultimate_seed
23 |
24 |
25 | # Workaround for the PyTorch multiprocessing issue:
26 | # "RuntimeError: received 0 items of ancdata"
27 | torch.multiprocessing.set_sharing_strategy('file_system')
28 |
29 | cv2.ocl.setUseOpenCL(False)
30 | cv2.setNumThreads(0)
31 |
32 | logging.basicConfig()
33 | logger = logging.getLogger('eval')
34 | logger.setLevel(logging.INFO)
35 |
36 | set_ultimate_seed()
37 |
38 | if torch.cuda.is_available():
39 | maybe_gpu = 'cuda'
40 | else:
41 | maybe_gpu = 'cpu'
42 |
43 |
44 | def predict_folds(config, loader, fold_idcs):
45 | """Evaluate the model versus each fold
46 | """
47 | for fold_idx in fold_idcs:
48 | paths_weights_fold = dict()
49 | paths_weights_fold['segm'] = \
50 | os.path.join(config['path_weights'], 'segm', f'fold_{fold_idx}')
51 |
52 | handlers_ckpt = dict()
53 | handlers_ckpt['segm'] = CheckpointHandler(paths_weights_fold['segm'])
54 |
55 | paths_ckpt_sel = dict()
56 | paths_ckpt_sel['segm'] = handlers_ckpt['segm'].get_last_ckpt()
57 |
58 | # Initialize and configure the model
59 | model = (dict_models[config['model_segm']]
60 | (input_channels=config['input_channels'],
61 | output_channels=config['output_channels'],
62 | center_depth=config['center_depth'],
63 | pretrained=config['pretrained'],
64 | restore_weights=config['restore_weights'],
65 | path_weights=paths_ckpt_sel['segm']))
66 | model = nn.DataParallel(model).to(maybe_gpu)
67 | model.eval()
68 |
69 | with tqdm(total=len(loader), desc=f'Eval, fold {fold_idx}') as prog_bar:
70 | for i, data_batch in enumerate(loader):
71 | xs, ys_true = data_batch['xs'], data_batch['ys']
72 | xs, ys_true = xs.to(maybe_gpu), ys_true.to(maybe_gpu)
73 |
74 | if config['model_segm'] == 'unet_lext':
75 | ys_pred = model(xs)
76 | elif config['model_segm'] == 'unet_lext_aux':
77 | ys_pred, _ = model(xs)
78 | else:
79 | msg = f"Unknown model {config['model_segm']}"
80 | raise ValueError(msg)
81 |
82 | ys_pred_softmax = nn.Softmax(dim=1)(ys_pred)
83 | ys_pred_softmax_np = ys_pred_softmax.detach().to('cpu').numpy()
84 |
85 | data_batch['pred_softmax'] = ys_pred_softmax_np
86 |
87 | # Rearrange the batch
88 | data_dicts = [{k: v[n] for k, v in data_batch.items()}
89 | for n in range(len(data_batch['image']))]
90 |
91 | for k, data_dict in enumerate(data_dicts):
92 | dir_base = os.path.join(
93 | config['path_predicts'],
94 | data_dict['patient'], data_dict['release'], data_dict['sequence'])
95 | fname_base = os.path.splitext(
96 | os.path.basename(data_dict['path_rel_image']))[0]
97 |
98 | # Save the predictions
99 | dir_predicts = os.path.join(dir_base, 'mask_folds')
100 | if not os.path.exists(dir_predicts):
101 | os.makedirs(dir_predicts)
102 |
103 | fname_full = os.path.join(
104 | dir_predicts,
105 | f'{fname_base}_fold_{fold_idx}.tiff')
106 |
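   |                     # Quantize the softmax maps to uint8 to save disk
   |                     # space; `merge_predictions` rescales them by 1/255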
107 | tmp = (data_dict['pred_softmax'] * 255).astype(np.uint8, casting='unsafe')
108 | tifffile.imsave(fname_full, tmp, compress=9)
109 |
110 | prog_bar.update(1)
111 |
112 |
113 | def merge_predictions(config, source, loader, dict_fns,
114 | save_plots=False, remove_foldw=False, convert_to_nifti=True):
115 | """Merge the predictions over all folds
116 | """
117 | dir_source_root = source['path_root']
118 | df_meta = loader.dataset.df_meta
119 |
120 | with tqdm(total=len(df_meta), desc='Merge') as prog_bar:
121 | for i, row in df_meta.iterrows():
122 | dir_scan_predicts = os.path.join(
123 | config['path_predicts'],
124 | row['patient'], row['release'], row['sequence'])
125 | dir_image_prep = os.path.join(dir_scan_predicts, 'image_prep')
126 | dir_mask_prep = os.path.join(dir_scan_predicts, 'mask_prep')
127 | dir_mask_folds = os.path.join(dir_scan_predicts, 'mask_folds')
128 | dir_mask_foldavg = os.path.join(dir_scan_predicts, 'mask_foldavg')
129 | dir_vis_foldavg = os.path.join(dir_scan_predicts, 'vis_foldavg')
130 |
131 | for p in (dir_image_prep, dir_mask_prep, dir_mask_folds, dir_mask_foldavg,
132 | dir_vis_foldavg):
133 | if not os.path.exists(p):
134 | os.makedirs(p)
135 |
136 | # Find the corresponding prediction files
137 | fname_base = os.path.splitext(os.path.basename(row['path_rel_image']))[0]
138 |
139 | fnames_pred = glob(os.path.join(dir_mask_folds, f'{fname_base}_fold_*.*'))
140 |
141 | # Read the reference data
142 | image = cv2.imread(
143 | os.path.join(dir_source_root, row['path_rel_image']),
144 | cv2.IMREAD_GRAYSCALE)
145 | image = dict_fns['crop'](image[None, ])[0]
146 | image = np.squeeze(image)
147 | if 'path_rel_mask' in row.index:
148 | ys_true = loader.dataset.read_mask(
149 | os.path.join(dir_source_root, row['path_rel_mask']))
150 | if ys_true is not None:
151 | ys_true = dict_fns['crop'](ys_true)[0]
152 | else:
153 | ys_true = None
154 |
155 | # Read the fold-wise predictions
156 | yss_pred = [tifffile.imread(f) for f in fnames_pred]
157 | ys_pred = np.stack(yss_pred, axis=0).astype(np.float32) / 255
158 | ys_pred = torch.from_numpy(ys_pred).unsqueeze(dim=0)
159 |
160 | # Average the fold predictions
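   |             # (element-wise mean of the per-fold softmax volumes,
   |             # renormalized so that the class probabilities sum to one)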
161 | ys_pred = torch.mean(ys_pred, dim=1, keepdim=False)
162 | ys_pred_softmax = ys_pred / torch.sum(ys_pred, dim=1, keepdim=True)
163 | ys_pred_softmax_np = ys_pred_softmax.squeeze().numpy()
164 |
165 | ys_pred_arg_np = ys_pred_softmax_np.argmax(axis=0)
166 |
167 | # Save preprocessed input data
168 | fname_full = os.path.join(dir_image_prep, f'{fname_base}.png')
169 | cv2.imwrite(fname_full, image) # image
170 |
171 | if ys_true is not None:
172 | ys_true = ys_true.astype(np.float32)
173 | ys_true = torch.from_numpy(ys_true).unsqueeze(dim=0)
174 | ys_true_arg_np = ys_true.numpy().squeeze().argmax(axis=0)
175 | fname_full = os.path.join(dir_mask_prep, f'{fname_base}.png')
176 | cv2.imwrite(fname_full, ys_true_arg_np) # mask
177 |
178 | fname_meta = os.path.join(config['path_predicts'], 'meta_dynamic.csv')
179 | if not os.path.exists(fname_meta):
180 | df_meta.to_csv(fname_meta, index=False) # metainfo
181 |
182 | # Save ensemble prediction
183 | fname_full = os.path.join(dir_mask_foldavg, f'{fname_base}.png')
184 | cv2.imwrite(fname_full, ys_pred_arg_np)
185 |
186 | # Save ensemble visualizations
187 | if save_plots:
188 | if ys_true is not None:
189 | fname_full = os.path.join(
190 | dir_vis_foldavg, f"{fname_base}_overlay_mask.png")
191 | save_vis_overlay(image=image,
192 | mask=ys_true_arg_np,
193 | num_classes=config['output_channels'],
194 | fname=fname_full)
195 |
196 | fname_full = os.path.join(
197 | dir_vis_foldavg, f"{fname_base}_overlay_pred.png")
198 | save_vis_overlay(image=image,
199 | mask=ys_pred_arg_np,
200 | num_classes=config['output_channels'],
201 | fname=fname_full)
202 |
203 | if ys_true is not None:
204 | fname_full = os.path.join(
205 | dir_vis_foldavg, f"{fname_base}_overlay_diff.png")
206 | save_vis_mask_diff(image=image,
207 | mask_true=ys_true_arg_np,
208 | mask_pred=ys_pred_arg_np,
209 | fname=fname_full)
210 |
211 | # Remove the fold predictions
212 | if remove_foldw:
213 | for f in fnames_pred:
214 | try:
215 | os.remove(f)
216 | except OSError:
217 | logger.error(f'Cannot remove {f}')
218 | prog_bar.update(1)
219 |
220 | # Convert the results to 3D NIfTI images
221 | if convert_to_nifti:
222 | df_meta = df_meta.sort_values(by=["patient", "release", "sequence", "side"])
223 |
224 | for gb_name, gb_df in tqdm(
225 | df_meta.groupby(["patient", "release", "sequence", "side"]),
226 | desc="Convert to NIfTI"):
227 |
228 | patient, release, sequence, side = gb_name
229 | spacings = (gb_df['pixel_spacing_0'].iloc[0],
230 | gb_df['pixel_spacing_1'].iloc[0],
231 | gb_df['slice_thickness'].iloc[0])
232 |
233 | dir_scan_predicts = os.path.join(config['path_predicts'],
234 | patient, release, sequence)
235 | for result in ("image_prep", "mask_prep", "mask_foldavg"):
236 | pattern = os.path.join(dir_scan_predicts, result, '*.png')
237 | path_nii = os.path.join(dir_scan_predicts, f"{result}.nii")
238 |
239 | # Read and compose 3D image
240 | img = png_to_numpy(pattern_fname_in=pattern, reverse=False)
241 |
242 | # Save to NIfTI
243 | numpy_to_nifti(stack=img, fname_out=path_nii,
244 | spacings=spacings, rcp_to_ras=True)
245 |
246 |
247 | def save_vis_overlay(image, mask, num_classes, fname):
248 | # Add a sample of each class to have consistent class colors
249 | mask[0, :num_classes] = list(range(num_classes))
250 | overlay = label2rgb(label=mask, image=image, bg_label=0,
251 | colors=['orangered', 'gold', 'lime', 'fuchsia'])
252 | # Convert to uint8 to save space
253 | overlay = img_as_ubyte(overlay)
254 | # Save to file
255 | if overlay.ndim == 3:
256 | overlay = overlay[:, :, ::-1]
257 | cv2.imwrite(fname, overlay)
258 |
259 |
260 | def save_vis_mask_diff(image, mask_true, mask_pred, fname):
261 | diff = np.empty_like(mask_true)
262 | diff[(mask_true == mask_pred) & (mask_pred == 0)] = 0 # TN
263 | diff[(mask_true == mask_pred) & (mask_pred != 0)] = 0 # TP
264 |     diff[(mask_true != mask_pred) & (mask_pred == 0)] = 2  # FN (missed regions)
265 |     diff[(mask_true != mask_pred) & (mask_pred != 0)] = 3  # FP or wrong class
266 | diff_colors = ('green', 'red', 'yellow')
267 | diff[0, :4] = [0, 1, 2, 3]
268 | overlay = label2rgb(label=diff, image=image, bg_label=0,
269 | colors=diff_colors)
270 | # Convert to uint8 to save space
271 | overlay = img_as_ubyte(overlay)
272 | # Save to file
273 | if overlay.ndim == 3:
274 | overlay = overlay[:, :, ::-1]
275 | cv2.imwrite(fname, overlay)
276 |
277 |
278 | @click.command()
279 | @click.option('--path_data_root', default='../../data')
280 | @click.option('--path_experiment_root', default='../../results/temporary')
281 | @click.option('--model_segm', default='unet_lext')
282 | @click.option('--center_depth', default=1, type=int)
283 | @click.option('--pretrained', is_flag=True)
284 | @click.option('--restore_weights', is_flag=True)
285 | @click.option('--input_channels', default=1, type=int)
286 | @click.option('--output_channels', default=1, type=int)
287 | @click.option('--dataset', type=click.Choice(
288 | ['oai_imo', 'okoa', 'maknee']))
289 | @click.option('--subset', type=click.Choice(
290 | ['test', 'all']))
291 | @click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
292 | @click.option('--sample_mode', default='x_y', type=str)
293 | @click.option('--batch_size', default=64, type=int)
294 | @click.option('--fold_num', default=5, type=int)
295 | @click.option('--fold_idx', default=-1, type=int)
296 | @click.option('--fold_idx_ignore', multiple=True, type=int)
297 | @click.option('--num_workers', default=1, type=int)
298 | @click.option('--seed_trainval_test', default=0, type=int)
299 | @click.option('--predict_folds', is_flag=True)
300 | @click.option('--merge_predictions', is_flag=True)
301 | @click.option('--save_plots', is_flag=True)
302 | def main(**config):
303 | config['path_data_root'] = os.path.abspath(config['path_data_root'])
304 | config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])
305 |
306 | config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
307 | if not os.path.exists(config['path_weights']):
308 | raise ValueError('{} does not exist'.format(config['path_weights']))
309 |
310 | config['path_predicts'] = os.path.join(
311 | config['path_experiment_root'], f"predicts_{config['dataset']}_test")
312 | config['path_logs'] = os.path.join(
313 | config['path_experiment_root'], f"logs_{config['dataset']}_test")
314 |
315 | os.makedirs(config['path_predicts'], exist_ok=True)
316 | os.makedirs(config['path_logs'], exist_ok=True)
317 |
318 | logging_fh = logging.FileHandler(
319 | os.path.join(config['path_logs'], 'main.log'))
320 | logging_fh.setLevel(logging.DEBUG)
321 | logger.addHandler(logging_fh)
322 |
323 | # Collect the available and specified sources
324 | sources = sources_from_path(path_data_root=config['path_data_root'],
325 | selection=config['dataset'],
326 | with_folds=True,
327 | seed_trainval_test=config['seed_trainval_test'])
328 |
329 | # Select the subset for evaluation
330 | if config['subset'] == 'test':
331 |         logger.warning('Using the regular trainval-test split')
332 |     elif config['subset'] == 'all':
333 |         logger.warning('Using data selection: full dataset')
334 | for s in sources:
335 | sources[s]['test_df'] = sources[s]['sel_df']
336 | logger.info(f"Selected number of samples: {len(sources[s]['test_df'])}")
337 | else:
338 |         raise ValueError(f"Unknown subset: {config['subset']}")
339 |
340 | if config['dataset'] == 'oai_imo':
341 | from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
342 | elif config['dataset'] == 'okoa':
343 | from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
344 | elif config['dataset'] == 'maknee':
345 | from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
346 | else:
347 | raise ValueError(f"Unknown dataset: {config['dataset']}")
348 |
349 | # Configure dataset-dependent transforms
350 | fn_crop = CenterCrop(height=300, width=300)
351 | if config['dataset'] == 'oai_imo':
352 | fn_norm = Normalize(mean=0.252699, std=0.251142)
353 | fn_unnorm = UnNormalize(mean=0.252699, std=0.251142)
354 | elif config['dataset'] == 'okoa':
355 | fn_norm = Normalize(mean=0.232454, std=0.236259)
356 | fn_unnorm = UnNormalize(mean=0.232454, std=0.236259)
357 | else:
358 | msg = f"No transforms defined for dataset: {config['dataset']}"
359 | raise NotImplementedError(msg)
360 | dict_fns = {'crop': fn_crop, 'norm': fn_norm, 'unnorm': fn_unnorm}
361 |
362 | dataset_test = DatasetSagittal2d(
363 | df_meta=sources[config['dataset']]['test_df'], mask_mode=config['mask_mode'],
364 | name=config['dataset'], sample_mode=config['sample_mode'],
365 | transforms=[
366 | PercentileClippingAndToFloat(cut_min=10, cut_max=99),
367 | fn_crop,
368 | fn_norm,
369 | ToTensor()
370 | ])
371 | loader_test = DataLoader(dataset_test,
372 | batch_size=config['batch_size'],
373 | shuffle=False,
374 | num_workers=config['num_workers'],
375 | drop_last=False)
376 |
377 | # Build a list of folds to run on
378 | if config['fold_idx'] == -1:
379 | fold_idcs = list(range(config['fold_num']))
380 | else:
381 | fold_idcs = [config['fold_idx'], ]
382 | for g in config['fold_idx_ignore']:
383 | fold_idcs = [i for i in fold_idcs if i != g]
384 |
385 | # Execute
386 | with torch.no_grad():
387 | if config['predict_folds']:
388 | predict_folds(config=config, loader=loader_test, fold_idcs=fold_idcs)
389 |
390 | if config['merge_predictions']:
391 | merge_predictions(config=config, source=sources[config['dataset']],
392 | loader=loader_test, dict_fns=dict_fns,
393 | save_plots=config['save_plots'], remove_foldw=False,
394 | convert_to_nifti=True)
395 |
396 |
397 | if __name__ == '__main__':
398 | main()
399 |
--------------------------------------------------------------------------------
/rocaseg/train_baseline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from collections import defaultdict
4 | import click
5 |
6 | import numpy as np
7 | import cv2
8 |
9 | import torch
10 | from torch import nn
11 | from torch.utils.data.dataloader import DataLoader
12 | from torch.utils.tensorboard import SummaryWriter
13 | from tqdm import tqdm
14 |
15 | from rocaseg.datasets import DatasetOAIiMoSagittal2d, sources_from_path
16 | from rocaseg.models import dict_models
17 | from rocaseg.components import (dict_losses, confusion_matrix, dice_score_from_cm,
18 | dict_optimizers, CheckpointHandler)
19 | from rocaseg.preproc import *
20 | from rocaseg.repro import set_ultimate_seed
21 | from rocaseg.components.mixup import mixup_criterion, mixup_data
22 |
23 |
24 | cv2.ocl.setUseOpenCL(False)
25 | cv2.setNumThreads(0)
26 |
27 | logging.basicConfig()
28 | logger = logging.getLogger('train')
29 | logger.setLevel(logging.DEBUG)
30 |
31 | set_ultimate_seed()
32 |
33 | if torch.cuda.is_available():
34 | maybe_gpu = 'cuda'
35 | else:
36 | maybe_gpu = 'cpu'
37 |
38 |
39 | class ModelTrainer:
40 | def __init__(self, config, fold_idx=None):
41 | self.config = config
42 | self.fold_idx = fold_idx
43 |
44 | self.paths_weights_fold = dict()
45 | self.paths_weights_fold['segm'] = \
46 | os.path.join(config['path_weights'], 'segm', f'fold_{self.fold_idx}')
47 | os.makedirs(self.paths_weights_fold['segm'], exist_ok=True)
48 |
49 | self.path_logs_fold = \
50 | os.path.join(config['path_logs'], f'fold_{self.fold_idx}')
51 | os.makedirs(self.path_logs_fold, exist_ok=True)
52 |
53 | self.handlers_ckpt = dict()
54 | self.handlers_ckpt['segm'] = CheckpointHandler(self.paths_weights_fold['segm'])
55 |
56 | paths_ckpt_sel = dict()
57 | paths_ckpt_sel['segm'] = self.handlers_ckpt['segm'].get_last_ckpt()
58 |
59 | # Initialize and configure the models
60 | self.models = dict()
61 | self.models['segm'] = (dict_models[config['model_segm']]
62 | (input_channels=self.config['input_channels'],
63 | output_channels=self.config['output_channels'],
64 | center_depth=self.config['center_depth'],
65 | pretrained=self.config['pretrained'],
66 | path_pretrained=self.config['path_pretrained_segm'],
67 | restore_weights=self.config['restore_weights'],
68 | path_weights=paths_ckpt_sel['segm']))
69 | self.models['segm'] = nn.DataParallel(self.models['segm'])
70 | self.models['segm'] = self.models['segm'].to(maybe_gpu)
71 |
72 | # Configure the training
73 | self.optimizers = dict()
74 |         self.optimizers['segm'] = (dict_optimizers[self.config['optimizer_segm']](
75 | self.models['segm'].parameters(),
76 | lr=self.config['lr_segm'],
77 | weight_decay=self.config['wd_segm']))
78 |
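   |         # LR schedule: at the epoch given by each key, multiply every
   |         # optimizer's learning rate by the corresponding factor (applied
   |         # in `fit`)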
79 | self.lr_update_rule = {30: 0.1}
80 |
81 | self.losses = dict()
82 | self.losses['segm'] = dict_losses[self.config['loss_segm']](
83 | num_classes=self.config['output_channels'],
84 | )
85 |
86 | self.losses['segm'] = self.losses['segm'].to(maybe_gpu)
87 |
88 | self.tensorboard = SummaryWriter(self.path_logs_fold)
89 |
90 | def run_one_epoch(self, epoch_idx, loaders):
91 | name_ds = list(loaders.keys())[0]
92 |
93 | fnames_acc = defaultdict(list)
94 | metrics_acc = dict()
95 | metrics_acc['samplew'] = defaultdict(list)
96 | metrics_acc['batchw'] = defaultdict(list)
97 | metrics_acc['datasetw'] = defaultdict(list)
98 | metrics_acc['datasetw'][f'{name_ds}__cm'] = \
99 | np.zeros((self.config['output_channels'],) * 2, dtype=np.uint32)
100 |
101 | prog_bar_params = {'postfix': {'epoch': epoch_idx}, }
102 |
103 | if self.models['segm'].training:
104 | # ------------------------ Training regime ------------------------
105 | loader_ds = loaders[name_ds]['train']
106 |
107 | steps_ds = len(loader_ds)
108 | prog_bar_params.update({'total': steps_ds,
109 | 'desc': f'Train, epoch {epoch_idx}'})
110 |
111 | loader_ds_iter = iter(loader_ds)
112 |
113 | with tqdm(**prog_bar_params) as prog_bar:
114 | for step_idx in range(steps_ds):
115 | self.optimizers['segm'].zero_grad()
116 |
117 | data_batch_ds = next(loader_ds_iter)
118 |
119 | xs_ds, ys_true_ds = data_batch_ds['xs'], data_batch_ds['ys']
120 | fnames_acc['oai'].extend(data_batch_ds['path_image'])
121 |
122 | ys_true_arg_ds = torch.argmax(ys_true_ds.long(), dim=1)
123 | xs_ds = xs_ds.to(maybe_gpu)
124 | ys_true_arg_ds = ys_true_arg_ds.to(maybe_gpu)
125 |
126 | if not self.config['with_mixup']:
127 | ys_pred_ds = self.models['segm'](xs_ds)
128 |
129 | loss_segm = self.losses['segm'](input_=ys_pred_ds,
130 | target=ys_true_arg_ds)
131 | else:
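   |                         # Mixup: train on convex combinations of sample
   |                         # pairs. Assuming the standard formulation,
   |                         # `mixup_data` blends the inputs with a weight
   |                         # lam ~ Beta(alpha, alpha) and `mixup_criterion`
   |                         # returns lam * loss(pred, y_a) +
   |                         # (1 - lam) * loss(pred, y_b).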
132 | xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
133 | x=xs_ds, y=ys_true_arg_ds,
134 | alpha=self.config['mixup_alpha'], device=maybe_gpu)
135 | ys_pred_ds = self.models['segm'](xs_mixup)
136 | loss_segm = mixup_criterion(criterion=self.losses['segm'],
137 | pred=ys_pred_ds,
138 | y_a=ys_mixup_a,
139 | y_b=ys_mixup_b,
140 | lam=lambda_mixup)
141 |
142 | metrics_acc['batchw']['loss'].append(loss_segm.item())
143 |
144 | loss_segm.backward()
145 | self.optimizers['segm'].step()
146 |
147 | prog_bar.update(1)
148 | else:
149 | # ----------------------- Validation regime -----------------------
150 | loader_ds = loaders[name_ds]['val']
151 |
152 | steps_ds = len(loader_ds)
153 | prog_bar_params.update({'total': steps_ds,
154 | 'desc': f'Validate, epoch {epoch_idx}'})
155 |
156 | loader_ds_iter = iter(loader_ds)
157 |
158 | with torch.no_grad(), tqdm(**prog_bar_params) as prog_bar:
159 | for step_idx in range(steps_ds):
160 | data_batch_ds = next(loader_ds_iter)
161 |
162 | xs_ds, ys_true_ds = data_batch_ds['xs'], data_batch_ds['ys']
163 | fnames_acc['oai'].extend(data_batch_ds['path_image'])
164 |
165 | ys_true_arg_ds = torch.argmax(ys_true_ds.long(), dim=1)
166 | xs_ds = xs_ds.to(maybe_gpu)
167 | ys_true_arg_ds = ys_true_arg_ds.to(maybe_gpu)
168 |
169 | if not self.config['with_mixup']:
170 | ys_pred_ds = self.models['segm'](xs_ds)
171 | loss_segm = self.losses['segm'](input_=ys_pred_ds,
172 | target=ys_true_arg_ds)
173 | else:
174 | xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
175 | x=xs_ds, y=ys_true_arg_ds,
176 | alpha=self.config['mixup_alpha'], device=maybe_gpu)
177 |
178 | ys_pred_ds = self.models['segm'](xs_mixup)
179 | loss_segm = mixup_criterion(criterion=self.losses['segm'],
180 | pred=ys_pred_ds,
181 | y_a=ys_mixup_a,
182 | y_b=ys_mixup_b,
183 | lam=lambda_mixup)
184 |
185 | metrics_acc['batchw']['loss'].append(loss_segm.item())
186 |
187 | # Calculate metrics
188 | ys_pred_softmax_ds = nn.Softmax(dim=1)(ys_pred_ds)
189 | ys_pred_softmax_np_ds = ys_pred_softmax_ds.to('cpu').numpy()
190 |
191 | ys_pred_arg_np_ds = ys_pred_softmax_np_ds.argmax(axis=1)
192 | ys_true_arg_np_ds = ys_true_arg_ds.to('cpu').numpy()
193 |
194 | metrics_acc['datasetw'][f'{name_ds}__cm'] += confusion_matrix(
195 | ys_pred_arg_np_ds, ys_true_arg_np_ds,
196 | self.config['output_channels'])
197 |
198 | prog_bar.update(1)
199 |
200 | for k, v in metrics_acc['samplew'].items():
201 | metrics_acc['samplew'][k] = np.asarray(v)
202 | metrics_acc['datasetw'][f'{name_ds}__dice_score'] = np.asarray(
203 | dice_score_from_cm(metrics_acc['datasetw'][f'{name_ds}__cm']))
204 | return metrics_acc, fnames_acc
205 |
206 | def fit(self, loaders):
207 | epoch_idx_best = -1
208 | loss_best = float('inf')
209 | metrics_train_best = dict()
210 | fnames_train_best = []
211 | metrics_val_best = dict()
212 | fnames_val_best = []
213 |
214 | for epoch_idx in range(self.config['epoch_num']):
215 | self.models = {n: m.train() for n, m in self.models.items()}
216 | metrics_train, fnames_train = \
217 | self.run_one_epoch(epoch_idx, loaders)
218 |
219 | # Process the accumulated metrics
220 | for k, v in metrics_train['batchw'].items():
221 | if k.startswith('loss'):
222 | metrics_train['datasetw'][k] = np.mean(np.asarray(v))
223 | else:
224 | logger.warning(f'Non-processed batch-wise entry: {k}')
225 |
226 | self.models = {n: m.eval() for n, m in self.models.items()}
227 | metrics_val, fnames_val = \
228 | self.run_one_epoch(epoch_idx, loaders)
229 |
230 | # Process the accumulated metrics
231 | for k, v in metrics_val['batchw'].items():
232 | if k.startswith('loss'):
233 | metrics_val['datasetw'][k] = np.mean(np.asarray(v))
234 | else:
235 | logger.warning(f'Non-processed batch-wise entry: {k}')
236 |
237 | # Learning rate update
238 | for s, m in self.lr_update_rule.items():
239 | if epoch_idx == s:
240 | for name, optim in self.optimizers.items():
241 | for param_group in optim.param_groups:
242 | param_group['lr'] *= m
243 |
244 | # Add console logging
245 | logger.info(f'Epoch: {epoch_idx}')
246 | for subset, metrics in (('train', metrics_train),
247 | ('val', metrics_val)):
248 | logger.info(f'{subset} metrics:')
249 | for k, v in metrics['datasetw'].items():
250 | logger.info(f'{k}: \n{v}')
251 |
252 | # Add TensorBoard logging
253 | for subset, metrics in (('train', metrics_train),
254 | ('val', metrics_val)):
255 | # Log only dataset-reduced metrics
256 | for k, v in metrics['datasetw'].items():
257 | if isinstance(v, np.ndarray):
258 | self.tensorboard.add_scalars(
259 | f'fold_{self.fold_idx}/{k}_{subset}',
260 | {f'class{i}': e for i, e in enumerate(v.ravel().tolist())},
261 | global_step=epoch_idx)
262 | elif isinstance(v, (str, int, float)):
263 | self.tensorboard.add_scalar(
264 | f'fold_{self.fold_idx}/{k}_{subset}',
265 | float(v),
266 | global_step=epoch_idx)
267 | else:
268 |                     logger.warning(f'{k} has unsupported type {type(v).__name__}')
269 | for name, optim in self.optimizers.items():
270 | for param_group in optim.param_groups:
271 | self.tensorboard.add_scalar(
272 | f'fold_{self.fold_idx}/learning_rate/{name}',
273 | param_group['lr'],
274 | global_step=epoch_idx)
275 |
276 | # Save the model
277 | loss_curr = metrics_val['datasetw']['loss']
278 | if loss_curr < loss_best:
279 | loss_best = loss_curr
280 | epoch_idx_best = epoch_idx
281 | metrics_train_best = metrics_train
282 | metrics_val_best = metrics_val
283 | fnames_train_best = fnames_train
284 | fnames_val_best = fnames_val
285 |
286 | self.handlers_ckpt['segm'].save_new_ckpt(
287 | model=self.models['segm'],
288 | model_name=self.config['model_segm'],
289 | fold_idx=self.fold_idx,
290 | epoch_idx=epoch_idx)
291 |
292 | msg = (f'Finished fold {self.fold_idx} '
293 | f'with the best loss {loss_best:.5f} '
294 | f'on epoch {epoch_idx_best}, '
295 | f'weights: ({self.paths_weights_fold})')
296 | logger.info(msg)
297 | return (metrics_train_best, fnames_train_best,
298 | metrics_val_best, fnames_val_best)
299 |
300 |
301 | @click.command()
302 | @click.option('--path_data_root', default='../../data')
303 | @click.option('--path_experiment_root', default='../../results/temporary')
304 | @click.option('--model_segm', default='unet_lext')
305 | @click.option('--center_depth', default=1, type=int)
306 | @click.option('--pretrained', is_flag=True)
307 | @click.option('--path_pretrained_segm', type=str, help='Path to .pth file')
308 | @click.option('--restore_weights', is_flag=True)
309 | @click.option('--input_channels', default=1, type=int)
310 | @click.option('--output_channels', default=1, type=int)
311 | @click.option('--dataset', type=click.Choice(
312 | ['oai_imo', 'okoa', 'maknee']), default='oai_imo')
313 | @click.option('--mask_mode', default='all_unitibial_unimeniscus', type=str)
314 | @click.option('--sample_mode', default='x_y', type=str)
315 | @click.option('--loss_segm', default='multi_ce_loss')
316 | @click.option('--lr_segm', default=0.0001, type=float)
317 | @click.option('--wd_segm', default=5e-5, type=float)
318 | @click.option('--optimizer_segm', default='adam')
319 | @click.option('--batch_size', default=64, type=int)
320 | @click.option('--epoch_size', default=1.0, type=float)
321 | @click.option('--epoch_num', default=2, type=int)
322 | @click.option('--fold_num', default=5, type=int)
323 | @click.option('--fold_idx', default=-1, type=int)
324 | @click.option('--fold_idx_ignore', multiple=True, type=int)
325 | @click.option('--num_workers', default=1, type=int)
326 | @click.option('--seed_trainval_test', default=0, type=int)
327 | @click.option('--with_mixup', is_flag=True)
328 | @click.option('--mixup_alpha', default=1, type=float)
329 | def main(**config):
330 | config['path_data_root'] = os.path.abspath(config['path_data_root'])
331 | config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])
332 |
333 | config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
334 | config['path_logs'] = os.path.join(config['path_experiment_root'], 'logs_train')
335 | os.makedirs(config['path_weights'], exist_ok=True)
336 | os.makedirs(config['path_logs'], exist_ok=True)
337 |
338 | logging_fh = logging.FileHandler(
339 | os.path.join(config['path_logs'], 'main_{}.log'.format(config['fold_idx'])))
340 | logging_fh.setLevel(logging.DEBUG)
341 | logger.addHandler(logging_fh)
342 |
343 | # Collect the available and specified sources
344 | sources = sources_from_path(path_data_root=config['path_data_root'],
345 | selection=config['dataset'],
346 | with_folds=True,
347 | fold_num=config['fold_num'],
348 | seed_trainval_test=config['seed_trainval_test'])
349 |
350 | # Build a list of folds to run on
351 | if config['fold_idx'] == -1:
352 | fold_idcs = list(range(config['fold_num']))
353 | else:
354 | fold_idcs = [config['fold_idx'], ]
355 | for g in config['fold_idx_ignore']:
356 | fold_idcs = [i for i in fold_idcs if i != g]
357 |
358 | # Train each fold separately
359 | fold_scores = dict()
360 |
361 | # Use straightforward fold allocation strategy
362 | folds = list(sources[config['dataset']]['trainval_folds'])
363 |
364 | for fold_idx, idcs_subsets in enumerate(folds):
365 | if fold_idx not in fold_idcs:
366 | continue
367 | logger.info(f'Training fold {fold_idx}')
368 |
369 | name_ds = config['dataset']
370 |
371 | (sources[name_ds]['train_idcs'], sources[name_ds]['val_idcs']) = idcs_subsets
372 |
373 | sources[name_ds]['train_df'] = \
374 | sources[name_ds]['trainval_df'].iloc[sources[name_ds]['train_idcs']]
375 | sources[name_ds]['val_df'] = \
376 | sources[name_ds]['trainval_df'].iloc[sources[name_ds]['val_idcs']]
377 |
378 | for n, s in sources.items():
379 | logger.info('Made {} train-val split, number of samples: {}, {}'
380 | .format(n, len(s['train_df']), len(s['val_df'])))
381 |
382 | datasets = defaultdict(dict)
383 |
384 | datasets[name_ds]['train'] = DatasetOAIiMoSagittal2d(
385 | df_meta=sources[name_ds]['train_df'],
386 | mask_mode=config['mask_mode'],
387 | sample_mode=config['sample_mode'],
388 | transforms=[
389 | PercentileClippingAndToFloat(cut_min=10, cut_max=99),
390 | CenterCrop(height=300, width=300),
391 | HorizontalFlip(prob=.5),
392 | GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
393 | OneOf([
394 | DualCompose([
395 | Scale(ratio_range=(0.7, 0.8), prob=1.),
396 | Scale(ratio_range=(1.5, 1.6), prob=1.),
397 | ]),
398 | NoTransform()
399 | ]),
400 | Crop(output_size=(300, 300)),
401 | BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
402 | Normalize(mean=0.252699, std=0.251142),
403 | ToTensor(),
404 | ])
405 | datasets[name_ds]['val'] = DatasetOAIiMoSagittal2d(
406 | df_meta=sources[name_ds]['val_df'],
407 | mask_mode=config['mask_mode'],
408 | sample_mode=config['sample_mode'],
409 | transforms=[
410 | PercentileClippingAndToFloat(cut_min=10, cut_max=99),
411 | CenterCrop(height=300, width=300),
412 | Normalize(mean=0.252699, std=0.251142),
413 | ToTensor()
414 | ])
415 |
416 | loaders = defaultdict(dict)
417 |
418 | loaders[name_ds]['train'] = DataLoader(
419 | datasets[name_ds]['train'],
420 | batch_size=config['batch_size'],
421 | shuffle=True,
422 | num_workers=config['num_workers'],
423 | drop_last=True)
424 | loaders[name_ds]['val'] = DataLoader(
425 | datasets[name_ds]['val'],
426 | batch_size=config['batch_size'],
427 | shuffle=False,
428 | num_workers=config['num_workers'],
429 | drop_last=True)
430 |
431 | trainer = ModelTrainer(config=config, fold_idx=fold_idx)
432 |
433 | # INFO: run once before the training to compute the dataset statistics
434 | # dataset_train.describe()
435 |
436 | tmp = trainer.fit(loaders=loaders)
437 | metrics_train, fnames_train, metrics_val, fnames_val = tmp
438 |
439 | fold_scores[fold_idx] = (metrics_val['datasetw'][f'{name_ds}__dice_score'], )
440 |
441 | trainer.tensorboard.close()
442 | logger.info(f'Fold scores:\n{repr(fold_scores)}')
443 |
444 |
445 | if __name__ == '__main__':
446 | main()
447 |
--------------------------------------------------------------------------------
/rocaseg/analyze_predictions_single.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from tqdm import tqdm
4 | import click
5 | import logging
6 | from joblib import Parallel, delayed
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import matplotlib.pyplot as plt
11 | import seaborn as sns
12 |
13 | from rocaseg.components import dice_score
14 | from rocaseg.components.formats import png_to_numpy
15 | from rocaseg.datasets.constants import atlas_to_locations
16 |
17 |
18 | logging.basicConfig()
19 | logger = logging.getLogger('analyze')
20 | logger.setLevel(logging.INFO)
21 |
22 |
23 | def metrics_paired_slicew(path_pred, dirname_true, dirname_pred,
24 | df, num_classes, num_workers=1):
25 | def _process_3d_pair(path_root, name_true, name_pred, meta):
26 | # Read the data
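   |         # Reverse the slice order for RIGHT knees, presumably so left and
   |         # right scans share a consistent anatomical slice ordering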
27 | if meta['side'] == 'RIGHT':
28 | reverse = True
29 | else:
30 | reverse = False
31 |
32 | patt_true = os.path.join(path_root, meta['patient'], meta['release'],
33 | meta['sequence'], name_true, '*.png')
34 | stack_true = png_to_numpy(patt_true, reverse=reverse)
35 |
36 | patt_pred = os.path.join(path_root, meta['patient'], meta['release'],
37 | meta['sequence'], name_pred, '*.png')
38 | stack_pred = png_to_numpy(patt_pred, reverse=reverse)
39 | if stack_true.shape != stack_pred.shape:
40 |             msg = (f'Reference and prediction stacks have different shapes: '
41 |                    f'{name_true}: {stack_true.shape}, {name_pred}: {stack_pred.shape}')
42 | raise ValueError(msg)
43 | num_slices = stack_true.shape[-1]
44 |
45 | # Add batch dimension
46 | stack_true = stack_true[None, ...]
47 | stack_pred = stack_pred[None, ...]
48 |
49 | # Compute the metrics
50 | res = []
51 | for slice_idx in range(num_slices):
52 | tmp = {
53 | 'dice_score': dice_score(stack_pred[..., slice_idx],
54 | stack_true[..., slice_idx],
55 | num_classes=num_classes),
56 | 'slice_idx_proc': slice_idx,
57 | **meta,
58 | }
59 | res.append(tmp)
60 | return res
61 |
62 | acc_ls = []
63 | groupers_stack = ['patient', 'release', 'sequence', 'side']
64 |
65 | for name_gb, df_gb in tqdm(df.groupby(groupers_stack)):
66 | patient, release, sequence, side = name_gb
67 |
68 | acc_ls.append(delayed(_process_3d_pair)(
69 | path_root=path_pred,
70 | name_true=dirname_true,
71 | name_pred=dirname_pred,
72 | meta={
73 | 'patient': patient,
74 | 'release': release,
75 | 'sequence': sequence,
76 | 'side': side,
77 | **df_gb.to_dict("records")[0],
78 | }
79 | ))
80 |
81 | acc_ls = Parallel(n_jobs=num_workers, verbose=1)(acc_ls)
82 | acc_l = []
83 | for acc in acc_ls:
84 | acc_l.extend(acc)
85 |
86 | # Convert from list of dicts to dict of lists
87 | acc_d = {k: [d[k] for d in acc_l] for k in acc_l[0]}
88 | return acc_d
89 |
90 |
91 | def plot_metrics_paired_slicew_xvd7(acc, path_root_out, num_classes, class_names,
92 | metric_names, config):
93 | """Average and visualize the metrics.
94 |
95 | Args:
96 | acc: dict of lists
97 | path_root_out: str
98 | num_classes: int
99 | class_names: dict
100 | metric_names: iterable of str
101 | config: dict
102 |
103 | """
104 | for metric_name in metric_names:
105 | if metric_name not in acc:
106 |             logger.error(f'`{metric_name}` is not present in `acc`')
107 | continue
108 | metric_values = acc[metric_name]
109 |
110 | ncols = 2
111 | nrows = (num_classes - 1) // ncols + 1
112 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 12))
113 | axes = axes.ravel()
114 |
115 | tmp = np.stack(metric_values, axis=-1) # shape: (class_idx, sample_idx)
116 | tmp_means, tmp_stds = [], []
117 |
118 | for class_idx, class_scores in enumerate(tmp[1:], start=1):
119 | class_idx_axes = class_idx - 1
120 | class_scores = class_scores[class_scores != 0]
121 |
122 | if np.any(np.isnan(class_scores)):
123 | logger.warning('NaN score')
124 |
125 | tmp_means.append(np.mean(class_scores))
126 | tmp_stds.append(np.std(class_scores))
127 |
128 | axes[class_idx_axes].hist(class_scores, bins=40,
129 | range=(0.3, 1.0), density=True)
130 | axes[class_idx_axes].set_title(class_names[class_idx])
131 | axes[class_idx_axes].set_ylim((0, 10))
132 |
133 | tmp_values = ', '.join(
134 | [f'{m:.03f}({s:.03f})' for m, s in zip(tmp_means, tmp_stds)])
135 | logger.info(f"{metric_name}:\n{tmp_values}\n")
136 |
137 | plt.tight_layout()
138 | fname_vis = f"metrics_{config['dataset']}_test_slicew_{metric_name}.png"
139 | path_vis = os.path.join(path_root_out, fname_vis)
140 | plt.savefig(path_vis)
141 | if config['interactive']:
142 | plt.show()
143 | else:
144 | plt.close()
145 |
146 |
147 | def print_metrics_paired_slicew_pua5(acc, metric_names):
148 | """Average and print the metrics with respect to KL.
149 |
150 | Args:
151 | acc: dict of lists
152 | metric_names: iterable of str
153 |
154 | """
155 | kl_vec = np.asarray(acc['KL'])
156 |
157 | for kl_value in np.unique(kl_vec):
158 | kl_sel = kl_vec == kl_value
159 |
160 | for metric_name in metric_names:
161 | if metric_name not in acc:
162 |                 logger.error(f'`{metric_name}` is not present in `acc`')
163 | continue
164 | metric_values = acc[metric_name]
165 | tmp = np.stack(metric_values, axis=-1)[..., kl_sel]
166 | # shape: (class_idx, sample_idx)
167 |
168 | tmp_means, tmp_stds = [], []
169 | for class_idx, class_scores in enumerate(tmp[1:], start=1):
170 | class_scores = class_scores[class_scores != 0]
171 |
172 | tmp_means.append(np.mean(class_scores))
173 | tmp_stds.append(np.std(class_scores))
174 |
175 | tmp_values = ', '.join(
176 | [f'{m:.03f}({s:.03f})' for m, s in zip(tmp_means, tmp_stds)])
177 | logger.info(f"KL{kl_value}, {metric_name}:\n{tmp_values}\n")
178 |
179 |
180 | def plot_metrics_paired_slicew_va49(acc, path_root_out, num_classes, class_names,
181 | metric_names, config):
182 | """Average and visualize the metrics.
183 |
184 | Args:
185 |         acc: dict of lists
186 |         path_root_out: str
187 |         num_classes: int
188 |         class_names: dict
189 |         metric_names: iterable of str
190 |         config: dict
191 |
192 | """
193 | groupers_stack = ['patient', 'release', 'sequence', 'side']
194 |
195 | acc_df = pd.DataFrame.from_dict(acc)
196 | acc_df = acc_df.sort_values([*groupers_stack, 'slice_idx_proc'])
197 |
198 | for metric_name in metric_names:
199 | if metric_name not in acc:
200 |             logger.error(f'`{metric_name}` is not present in `acc`')
201 | continue
202 |
203 | metric_values = []
204 | for gb_name, gb_df in acc_df.groupby(groupers_stack):
205 | tmp = np.asarray([np.asarray(e) for e in gb_df[metric_name]])
206 | metric_values.append(tmp)
207 | metric_values = np.stack(metric_values, axis=0)
208 | # Axes order: (scan, slice_idx, class_idx)
209 |
210 | nrows = np.floor(np.sqrt(num_classes)).astype(int)
211 | ncols = np.ceil(float(num_classes) / nrows).astype(int)
212 |
213 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 10))
214 | ax = axes.ravel()
215 | plt.title(metric_name)
216 |
217 | for class_idx in range(1, num_classes):
218 | class_idx_axes = class_idx - 1
219 |
220 | y = metric_values[..., class_idx]
221 | x = np.tile(np.arange(0, y.shape[1]), reps=(y.shape[0], 1))
222 | y = y.ravel()
223 | x = x.ravel()
224 |
225 | sel = (y != 0)
226 | y = y[sel]
227 | x = x[sel]
228 |
229 | sns.lineplot(x=x, y=y, err_style='band', ax=ax[class_idx_axes],
230 | color="salmon")
231 |             # class_idx starts at 1, so the background class is never plotted
232 |             ax[class_idx_axes].set_ylim((0.4, 1))
233 |             ax[class_idx_axes].set_xlim((10, 150))
234 | ax[class_idx_axes].set_xlabel('Slice index') #, size=16)
235 | ax[class_idx_axes].set_ylabel('DSC') #, size=16)
236 | ax[class_idx_axes].set_title(class_names[class_idx]) #, size=16)
237 |
238 | # for label in (ax[class_idx_axes].get_xticklabels() +
239 | # ax[class_idx_axes].get_yticklabels()):
240 | # label.set_fontsize(15)
241 |
242 | plt.tight_layout()
243 |
244 | fname_vis = f"metrics_{config['dataset']}_test_slicew_confid_{metric_name}.png"
245 | path_vis = os.path.join(path_root_out, fname_vis)
246 | plt.savefig(path_vis)
247 | if config['interactive']:
248 | plt.show()
249 | else:
250 | plt.close()
251 |
252 |
253 | def _scan_to_metrics_paired(path_root, name_true, name_pred, num_classes, meta):
254 | # Read the data
255 | patt_true = os.path.join(path_root, meta['patient'], meta['release'],
256 | meta['sequence'], name_true, '*.png')
257 | stack_true = png_to_numpy(patt_true)
258 |
259 | patt_pred = os.path.join(path_root, meta['patient'], meta['release'],
260 | meta['sequence'], name_pred, '*.png')
261 | stack_pred = png_to_numpy(patt_pred)
262 |
263 | # Add batch dimension
264 | stack_true = stack_true[None, ...]
265 | stack_pred = stack_pred[None, ...]
266 |
267 | # Compute the metrics
268 | res = meta
269 | res['dice_score'] = dice_score(stack_pred, stack_true,
270 | num_classes=num_classes)
271 | return res
272 |
273 |
274 | def metrics_paired_volumew(*, path_pred, dirname_true, dirname_pred, df, num_classes,
275 | num_workers=1):
276 | acc_l = []
277 | groupers_stack = ['patient', 'release', 'sequence', 'side']
278 |
279 | for name_gb, df_gb in tqdm(df.groupby(groupers_stack)):
280 | patient, release, sequence, side = name_gb
281 |
282 | acc_l.append(delayed(_scan_to_metrics_paired)(
283 | path_root=path_pred,
284 | name_true=dirname_true,
285 | name_pred=dirname_pred,
286 | num_classes=num_classes,
287 | meta={
288 | 'patient': patient,
289 | 'release': release,
290 | 'sequence': sequence,
291 | 'side': side,
292 | **df_gb.to_dict("records")[0],
293 | }
294 | ))
295 |
296 | acc_l = Parallel(n_jobs=num_workers, verbose=10)(
297 | t for t in tqdm(acc_l, total=len(acc_l)))
298 | # Convert from list of dicts to dict of lists
299 | acc_d = {k: [d[k] for d in acc_l] for k in acc_l[0]}
300 | return acc_d
301 |
302 |
303 | def plot_metrics_paired_volumew_n5a9(acc, path_root_out, num_classes, class_names,
304 | metric_names, config):
305 | """Average and visualize the metrics.
306 |
307 | Args:
308 | acc: dict of lists
309 | path_root_out: str
310 | num_classes: int
311 | class_names: dict
312 | metric_names: iterable of str
313 | config: dict
314 |
315 | """
316 | groupers_stack = ['patient', 'release', 'sequence', 'side']
317 |
318 | acc_df = pd.DataFrame.from_dict(acc)
319 | acc_df = acc_df.sort_values(groupers_stack)
320 |
321 | for metric_name in metric_names:
322 | if metric_name not in acc:
323 |             logger.error(f'`{metric_name}` is not present in `acc`')
324 | continue
325 |
326 | metric_values = []
327 | for gb_name, gb_df in acc_df.groupby(groupers_stack):
328 | tmp = np.asarray([np.asarray(e) for e in gb_df[metric_name]])
329 | metric_values.append(tmp)
330 | metric_values = np.stack(metric_values, axis=0)
331 | # Axes order: (scan, class_idx[, sub_metric])
332 |
333 | nrows = np.floor(np.sqrt(num_classes)).astype(int)
334 | ncols = np.ceil(float(num_classes) / nrows).astype(int)
335 |
336 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 10))
337 | ax = axes.ravel()
338 |
339 | tmp_means, tmp_stds = [], []
340 |
341 | xlims = {
342 | 'dice_score': (0.5, 1.0),
343 | }
344 |
345 | for class_idx, class_scores in enumerate(metric_values.T[1:], start=1):
346 | class_idx_axes = class_idx - 1
347 |
348 | if metric_name == 'dice_score':
349 | class_scores = class_scores[class_scores != 0]
350 | class_scores = np.squeeze(class_scores)
351 |
352 | tmp_means.append(np.mean(class_scores))
353 | tmp_stds.append(np.std(class_scores))
354 |
355 | if metric_name in xlims:
356 | ax[class_idx_axes].hist(class_scores, bins=50,
357 | range=xlims[metric_name],
358 | density=True)
359 | else:
360 | ax[class_idx_axes].hist(class_scores, bins=50,
361 | density=True)
362 | ax[class_idx_axes].set_title(class_names[class_idx])
363 |
364 | tmp_values = ', '.join(
365 | [f'{m:.03f}({s:.03f})' for m, s in zip(tmp_means, tmp_stds)])
366 | logger.info(f"{metric_name}:\n{tmp_values}\n")
367 |
368 | plt.tight_layout()
369 | fname_vis = f"metrics_{config['dataset']}_test_volumew_{metric_name}.png"
370 | path_vis = os.path.join(path_root_out, fname_vis)
371 | plt.savefig(path_vis)
372 | if config['interactive']:
373 | plt.show()
374 | else:
375 | plt.close()
376 |
377 |
378 | def print_metrics_paired_volumew_b46e(acc, metric_names):
379 |     """Average and print the metrics stratified by KL (Kellgren-Lawrence) grade.
380 |
381 |     Args:
382 |         acc: dict of lists, keyed by metadata field and metric name
383 |         metric_names: tuple of str, names of the metrics to report
384 |
385 | """
386 | kl_vec = np.asarray(acc['KL'])
387 |
388 | for kl_value in np.unique(kl_vec):
389 | kl_sel = kl_vec == kl_value
390 |
391 | for metric_name in metric_names:
392 | if metric_name not in acc:
393 |                 logger.error(f'`{metric_name}` is not present in `acc`')
394 | continue
395 | metric_values = acc[metric_name]
396 | tmp = np.asarray(metric_values)[kl_sel]
397 |
398 | tmp_means, tmp_stds = [], []
399 | for class_idx, class_scores in enumerate(
400 | np.moveaxis(tmp[..., 1:], -1, 0), start=1):
401 |                 class_scores = class_scores[class_scores != 0]  # exclude zero scores
402 |
403 | tmp_means.append(np.mean(class_scores))
404 | tmp_stds.append(np.std(class_scores))
405 |
406 | tmp_values = ', '.join(
407 |                 [f'{m:.3f}({s:.3f})' for m, s in zip(tmp_means, tmp_stds)])
408 | logger.info(f"KL{kl_value}, {metric_name}\n{tmp_values}\n")
409 |
410 |
411 | @click.command()
412 | @click.option('--path_experiment_root', default='../../results/temporary')
413 | @click.option('--dirname_pred', required=True)
414 | @click.option('--dirname_true', required=True)
415 | @click.option('--dataset', required=True, type=click.Choice(
416 | ['oai_imo', 'okoa', 'maknee']))
417 | @click.option('--atlas', required=True, type=click.Choice(
418 | ['imo', 'segm', 'okoa']))
419 | @click.option('--ignore_cache', is_flag=True)
420 | @click.option('--interactive', is_flag=True)
421 | @click.option('--num_workers', default=1, type=int)
422 | def main(**config):
423 |
424 | path_pred = os.path.join(config['path_experiment_root'],
425 | f"predicts_{config['dataset']}_test")
426 | path_logs = os.path.join(config['path_experiment_root'],
427 | f"logs_{config['dataset']}_test")
428 |
429 | # Get the information on object classes
430 | locations = atlas_to_locations[config['atlas']]
431 |     class_names = list(locations)
432 | num_classes = max(locations.values()) + 1
433 |
434 | # Get the index of image files and the corresponding metadata
435 | path_meta = os.path.join(path_pred, 'meta_dynamic.csv')
436 | df_meta = pd.read_csv(path_meta,
437 | dtype={'patient': str,
438 | 'release': str,
439 | 'prefix_var': str,
440 | 'sequence': str,
441 | 'side': str,
442 | 'slice_idx': int,
443 | 'pixel_spacing_0': float,
444 | 'pixel_spacing_1': float,
445 | 'slice_thickness': float,
446 | 'KL': int,
447 | 'has_mask': int},
448 | index_col=False)
449 |
450 | df_sel = df_meta.sort_values(['patient', 'release', 'sequence', 'side', 'slice_idx'])
451 |
452 | # -------------------------------- Planar ------------------------------------------
453 | fname_pkl = os.path.join(path_logs,
454 | f"cache_{config['dataset']}_test_"
455 | f"{config['atlas']}_"
456 | f"slicew_paired.pkl")
457 |
458 | logger.info('Planar scores')
459 | if os.path.exists(fname_pkl) and not config['ignore_cache']:
460 | logger.info('Loading from the cache')
461 | with open(fname_pkl, 'rb') as f:
462 | acc_slicew = pickle.load(f)
463 | else:
464 | logger.info('Computing')
465 | acc_slicew = metrics_paired_slicew(path_pred=path_pred,
466 | dirname_true=config['dirname_true'],
467 | dirname_pred=config['dirname_pred'],
468 | df=df_sel,
469 | num_classes=num_classes,
470 | num_workers=config['num_workers'])
471 |         logger.info('Caching the results to a file')
472 | os.makedirs(path_logs, exist_ok=True)
473 | with open(fname_pkl, 'wb') as f:
474 | pickle.dump(acc_slicew, f)
475 |
476 | plot_metrics_paired_slicew_xvd7(
477 | acc=acc_slicew,
478 | path_root_out=path_logs,
479 | num_classes=num_classes,
480 | class_names=class_names,
481 | metric_names=('dice_score', ),
482 | config=config,
483 | )
484 |
485 | print_metrics_paired_slicew_pua5(
486 | acc=acc_slicew,
487 | metric_names=('dice_score',),
488 | )
489 |
490 | plot_metrics_paired_slicew_va49(
491 | acc=acc_slicew,
492 | path_root_out=path_logs,
493 | num_classes=num_classes,
494 | class_names=class_names,
495 | metric_names=('dice_score',),
496 | config=config,
497 | )
498 |
499 | # ------------------------------- Volumetric ---------------------------------------
500 | fname_pkl = os.path.join(path_logs,
501 | f"cache_{config['dataset']}_test_"
502 | f"{config['atlas']}_"
503 | f"volumew_paired.pkl")
504 |
505 | logger.info('Volumetric scores')
506 | if os.path.exists(fname_pkl) and not config['ignore_cache']:
507 | logger.info('Loading from the cache')
508 | with open(fname_pkl, 'rb') as f:
509 | acc_volumew = pickle.load(f)
510 | else:
511 | logger.info('Computing')
512 | acc_volumew = metrics_paired_volumew(
513 | path_pred=path_pred,
514 | dirname_true=config['dirname_true'],
515 | dirname_pred=config['dirname_pred'],
516 | df=df_sel,
517 | num_classes=num_classes,
518 | num_workers=config['num_workers']
519 | )
520 |         logger.info('Caching the results to a file')
521 | os.makedirs(path_logs, exist_ok=True)
522 | with open(fname_pkl, 'wb') as f:
523 | pickle.dump(acc_volumew, f)
524 |
525 | plot_metrics_paired_volumew_n5a9(
526 | acc=acc_volumew,
527 | path_root_out=path_logs,
528 | num_classes=num_classes,
529 | class_names=class_names,
530 | metric_names=('dice_score', ),
531 | config=config,
532 | )
533 |
534 | print_metrics_paired_volumew_b46e(
535 | acc=acc_volumew,
536 | metric_names=('dice_score', ),
537 | )
538 |
539 |
540 | if __name__ == '__main__':
541 | main()
542 |
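543 | # Example invocation (a sketch: the script path, experiment directory, and the
544 | # mask directory names below are placeholders, not values prescribed by this
545 | # codebase):
546 | #
547 | #   python <path_to_this_script>.py \
548 | #       --path_experiment_root ../../results/my_experiment \
549 | #       --dirname_pred mask_foldavg \
550 | #       --dirname_true mask_true \
551 | #       --dataset okoa \
552 | #       --atlas okoa \
553 | #       --num_workers 12
554 | #
555 | # Slice-wise and volume-wise Dice scores are cached as pickles under
556 | # `logs_<dataset>_test/`; pass `--ignore_cache` to force recomputation.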
--------------------------------------------------------------------------------