├── medmnistc
├── __init__.py
├── corruptions
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── base.cpython-311.pyc
│ │ ├── noise.cpython-311.pyc
│ │ ├── __init__.cpython-311.pyc
│ │ └── registry.cpython-311.pyc
│ ├── base.py
│ ├── compression.py
│ ├── noise.py
│ ├── enhance.py
│ ├── filter.py
│ ├── microscopy.py
│ └── registry.py
├── assets
│ └── inks.npz
├── __pycache__
│ └── __init__.cpython-311.pyc
├── utils
│ ├── utils.py
│ └── baselines.py
├── augmentation.py
├── dataset.py
├── dataset_manager.py
├── eval.py
└── visualizer.py
├── .gitignore
├── assets
├── images
│ └── wallpaper.gif
└── examples
│ ├── create_dataset.ipynb
│ ├── evaluation.ipynb
│ └── augment.ipynb
├── requirements.txt
├── pyproject.toml
├── README.md
└── LICENSE
/medmnistc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/medmnistc/corruptions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
3 | build/
4 | dist/
--------------------------------------------------------------------------------
/medmnistc/assets/inks.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/francescodisalvo05/medmnistc-api/HEAD/medmnistc/assets/inks.npz
--------------------------------------------------------------------------------
/assets/images/wallpaper.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/francescodisalvo05/medmnistc-api/HEAD/assets/images/wallpaper.gif
--------------------------------------------------------------------------------
/medmnistc/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/francescodisalvo05/medmnistc-api/HEAD/medmnistc/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | medmnist == 3.0.1
2 | scikit-image == 0.23.2
3 | scikit-learn > 1.2.2
4 | numpy
5 | torch
6 | torchvision
7 | opencv-python
8 | scipy
9 | wand > 0.6.10
--------------------------------------------------------------------------------
/medmnistc/corruptions/__pycache__/base.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/francescodisalvo05/medmnistc-api/HEAD/medmnistc/corruptions/__pycache__/base.cpython-311.pyc
--------------------------------------------------------------------------------
/medmnistc/corruptions/__pycache__/noise.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/francescodisalvo05/medmnistc-api/HEAD/medmnistc/corruptions/__pycache__/noise.cpython-311.pyc
--------------------------------------------------------------------------------
/medmnistc/corruptions/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/francescodisalvo05/medmnistc-api/HEAD/medmnistc/corruptions/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/medmnistc/corruptions/__pycache__/registry.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/francescodisalvo05/medmnistc-api/HEAD/medmnistc/corruptions/__pycache__/registry.cpython-311.pyc
--------------------------------------------------------------------------------
/medmnistc/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | import os
5 |
6 |
def seed_everything(seed: int):
    """Seed every random source the library relies on and return a NumPy Generator.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    CUDA), and forces cuDNN into deterministic, non-benchmarking mode.

    NOTE: exporting PYTHONHASHSEED at runtime only influences child processes;
    the hash seed of the already-running interpreter is fixed at start-up.

    :param seed: Seed shared by all generators.
    :return: ``np.random.default_rng(seed)``, meant to be passed as ``rng`` to
             ``skimage.util.random_noise``.
    """
    # Python / NumPy global state.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators and cuDNN determinism switches.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Dedicated Generator for skimage.util.random_noise.
    return np.random.default_rng(seed)
--------------------------------------------------------------------------------
/medmnistc/corruptions/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 |
class BaseCorruption:
    """Common base class of all MedMNIST-C corruptions.

    Subclasses implement :meth:`apply`. Callers use it in two ways:
    ``apply(img, severity)`` for a fixed severity level (an index into
    ``severity_params``), or ``apply(img, augmentation=True)`` to sample an
    intensity between the first and last entries of ``severity_params``.
    """

    def __init__(self, severity_params):
        """
        :param severity_params: Per-severity hyperparameters of the corruption.
            Their type (scalar, tuple, array, ...) depends on the subclass.
        """
        self.severity_params = severity_params
        # Not used by this base class itself; presumably consumed by subclasses
        # that render text onto images (TODO confirm).
        self.font = cv2.FONT_HERSHEY_DUPLEX
        # Cache for the lazily loaded `assets/inks.npz` archive (see `inks`).
        self._inks = None

    @property
    def inks(self):
        """Return the packaged ``assets/inks.npz`` archive, loading it on first access."""
        if self._inks is None:
            # `assets/` is a sibling of the `corruptions/` package containing this file.
            package_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
            inks_path = os.path.join(package_dir, 'assets', 'inks.npz')
            # allow_pickle=True is only acceptable because this file ships with the
            # package; never point this loader at untrusted data.
            self._inks = np.load(inks_path, allow_pickle=True)
        return self._inks

    def apply(self, img, severity=-1, augmentation=False):
        """
        Corrupt `img`. Must be overridden by subclasses.

        The signature mirrors the one callers use, so that a subclass that forgot
        to override it raises NotImplementedError rather than a TypeError.

        :param img: Image to corrupt.
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the intensity instead of using `severity`.
        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")
22 |
23 |
24 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "medmnistc"
7 | authors = [
8 | {name = "Francesco Di Salvo", email = "francesco.di-salvo@uni-bamberg.de"}
9 | ]
10 | description = "This Python library aims to evaluate model robustness under corrupted test sets and to enhance domain generalization through domain-specific augmentations."
11 | version = "0.1.0"
12 | readme = "README.md"
13 | requires-python = ">=3.7"
14 | dependencies = [
15 | "medmnist == 3.0.1",
16 | "scikit-image == 0.23.2",
17 | "scikit-learn > 1.2.2",
18 | "numpy",
19 | "torch",
20 | "torchvision",
21 | "opencv-python",
22 | "scipy",
23 | "wand > 0.6.10"
24 | ]
25 |
26 | [project.urls]
27 | "Homepage" = "https://github.com/francescodisalvo05/medmnistc-api"
28 | "Issue Tracker" = "https://github.com/francescodisalvo05/medmnistc-api/issues"
--------------------------------------------------------------------------------
/medmnistc/augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
class AugMedMNISTC:
    def __init__(self,
                 train_corruptions: dict = None,
                 verbose: bool = False):
        """
        Augmentation class based on the designed image corruptions.
        Each call picks, with equal probability, either *one* of the given
        corruptions or `identity` (i.e., no augmentation). The chosen corruption
        is applied with `augmentation=True`. The built-in corruptions then sample
        their intensity between their first and last severity parameters.

        :param train_corruptions: Non-empty mapping {name: corruption} of the
            corruptions to use during training (e.g. `CORRUPTIONS_DS[dataset]`).
        :param verbose: If True, print the name of the selected corruption.
        :raises ValueError: If `train_corruptions` is None or empty.
        """
        # A `None` default avoids a shared mutable `{}`. An explicit raise is used
        # because `assert` is stripped under `python -O`.
        if not train_corruptions:
            raise ValueError("You need to define some corruptions first.")

        self.verbose = verbose
        self.train_corruptions = train_corruptions
        # 'identity' is an extra option meaning "return the image unchanged".
        self.train_corruptions_keys = list(self.train_corruptions.keys()) + ['identity']

    def __call__(self, img):
        """
        Apply one randomly chosen corruption (or none) to `img`.

        :param img: Input image, in whatever format the configured corruptions
            accept (the built-in ones operate on PIL images -- TODO confirm for
            custom corruptions).
        :return: `img` itself for `identity`, otherwise the corruption's output.
        """
        corr = np.random.choice(self.train_corruptions_keys)

        if self.verbose:
            print(corr)

        if corr == 'identity':
            return img

        return self.train_corruptions[corr].apply(img, augmentation=True)
34 |
35 |
36 |
--------------------------------------------------------------------------------
/medmnistc/corruptions/compression.py:
--------------------------------------------------------------------------------
1 | from .base import BaseCorruption
2 | from PIL import Image
3 | from io import BytesIO
4 |
5 | import numpy as np
6 |
7 |
class Pixelate(BaseCorruption):
    """Pixelation: box-downsample by `resize_factor`, then box-upsample back."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: PIL image to corrupt.
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the resize factor uniformly between the
            first and last severity parameters.
        :return: Pixelated image as a NumPy array of the original size.
        """
        if augmentation:
            range_min, range_max = self.severity_params[0], self.severity_params[-1]
            resize_factor = np.random.uniform(low=range_min, high=range_max, size=None)
        else:
            resize_factor = self.severity_params[severity]

        width, height = img.size
        # Clamp to at least one pixel per side. A small factor on a small image
        # would otherwise truncate to 0, which PIL rejects as a resize target.
        reduced_size = (max(1, int(width * resize_factor)),
                        max(1, int(height * resize_factor)))
        img = img.resize(reduced_size, Image.BOX)
        img = img.resize((width, height), Image.BOX)
        return np.array(img)
22 |
23 |
class JPEGCompression(BaseCorruption):
    """JPEG artifacts obtained by an in-memory encode/decode round trip."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: PIL image to corrupt.
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample an integer quality between the first
            and last severity parameters.
        :return: Decoded JPEG image as a NumPy array.
        """
        if not augmentation:
            quality = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            quality = int(np.random.uniform(low=lowest, high=highest, size=None))

        encoded = BytesIO()
        img.save(encoded, 'JPEG', quality=quality)
        decoded = Image.open(encoded)
        return np.array(decoded)
--------------------------------------------------------------------------------
/medmnistc/corruptions/noise.py:
--------------------------------------------------------------------------------
1 | from .base import BaseCorruption
2 |
3 | import skimage as sk
4 | import numpy as np
5 |
6 |
class GaussianNoise(BaseCorruption):
    """Additive zero-mean Gaussian noise applied to the [0, 1]-scaled image."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: Image convertible with `np.array` (8-bit range expected).
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the noise std uniformly between the
            first and last severity parameters.
        :return: Noisy image as a `uint8` array.
        """
        if augmentation:
            sigma = np.random.uniform(low=self.severity_params[0],
                                      high=self.severity_params[-1],
                                      size=None)
        else:
            sigma = self.severity_params[severity]

        scaled = np.array(img) / 255.
        noise = np.random.normal(size=scaled.shape, scale=sigma)
        noisy = np.clip(scaled + noise, 0, 1)
        return (noisy * 255).astype(np.uint8)
18 |
19 |
class ImpulseNoise(BaseCorruption):
    """Salt-and-pepper (impulse) noise via `skimage.util.random_noise`."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: Image convertible with `np.array` (8-bit range expected).
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the noise amount uniformly between the
            first and last severity parameters.
        :return: Noisy image as a `uint8` array.

        When not augmenting, `self.rng` (a NumPy Generator, set e.g. by
        DatasetManager) drives the noise. If it was never set, a generator
        derived from NumPy's global state is used instead of failing.
        """
        if augmentation:
            range_min, range_max = self.severity_params[0], self.severity_params[-1]
            amount = np.random.uniform(low=range_min, high=range_max, size=None)
            # Previously a constant seed (99999) was used here, which gave every
            # augmented image the same noise pattern. Deriving the seed from the
            # global state varies the pattern per call while keeping it
            # reproducible under np.random.seed / seed_everything.
            rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))
        else:
            amount = self.severity_params[severity]
            rng = getattr(self, 'rng', None)
            if rng is None:
                rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))

        noisy_image = sk.util.random_noise(np.array(img) / 255., mode='s&p', amount=amount, rng=rng)
        noisy_image = np.clip(noisy_image, 0, 1)
        return (noisy_image * 255).astype(np.uint8)
31 |
32 |
class SpeckleNoise(BaseCorruption):
    """Multiplicative (speckle) Gaussian noise on the [0, 1]-scaled image."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: Image convertible with `np.array` (8-bit range expected).
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the noise std uniformly between the
            first and last severity parameters.
        :return: Noisy image as a `uint8` array.
        """
        if not augmentation:
            sigma = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            sigma = np.random.uniform(low=lowest, high=highest, size=None)

        scaled = np.array(img) / 255.
        noise = np.random.normal(size=scaled.shape, scale=sigma)
        # Each pixel is perturbed proportionally to its own intensity.
        speckled = np.clip(scaled + scaled * noise, 0, 1)
        return (speckled * 255).astype(np.uint8)
44 |
45 |
class ShotNoise(BaseCorruption):
    """Poisson (shot) noise: intensities are scaled, Poisson-sampled and rescaled."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: Image convertible with `np.array` (8-bit range expected).
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the multiplier uniformly between the
            first and last severity parameters.
        :return: Noisy image as a `uint8` array.
        """
        if not augmentation:
            mult = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            mult = np.random.uniform(low=lowest, high=highest, size=None)

        scaled = np.array(img) / 255.
        counts = np.random.poisson(scaled * mult)
        rescaled = np.clip(counts / mult, 0, 1)
        return (rescaled * 255).astype(np.uint8)
56 |
57 |
58 |
--------------------------------------------------------------------------------
/medmnistc/corruptions/enhance.py:
--------------------------------------------------------------------------------
1 | from .base import BaseCorruption
2 | from PIL import ImageEnhance
3 | import skimage as sk
4 | import numpy as np
5 |
6 | import torchvision.transforms.functional as TF
7 |
8 |
class Brightness(BaseCorruption):
    """Brightness change through PIL's `ImageEnhance.Brightness`."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: PIL image to corrupt.
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the factor uniformly between the first
            and last severity parameters.
        :return: Corrupted image as a `uint8` array.
        """
        if not augmentation:
            factor = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            factor = np.random.uniform(low=lowest, high=highest, size=None)

        result = ImageEnhance.Brightness(img).enhance(factor)
        return np.array(result).astype(np.uint8)
19 |
20 |
class Contrast(BaseCorruption):
    """Contrast change through PIL's `ImageEnhance.Contrast`."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: PIL image to corrupt.
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the factor uniformly between the first
            and last severity parameters.
        :return: Corrupted image as a `uint8` array.
        """
        if not augmentation:
            factor = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            factor = np.random.uniform(low=lowest, high=highest, size=None)

        result = ImageEnhance.Contrast(img).enhance(factor)
        return np.array(result).astype(np.uint8)
31 |
32 |
class GammaCorrection(BaseCorruption):
    """Gamma correction via `torchvision.transforms.functional.adjust_gamma` (gain 1)."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: Image accepted by `TF.adjust_gamma` (a PIL image in this package).
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample gamma uniformly between the first and
            last severity parameters.
        :return: Corrected image as a `uint8` array.
        """
        if not augmentation:
            gamma = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            gamma = np.random.uniform(low=lowest, high=highest, size=None)

        corrected = TF.adjust_gamma(img, gamma, gain=1)
        return np.array(corrected).astype(np.uint8)
43 |
44 |
class Saturate(BaseCorruption):
    """Additive shift of the HSV saturation channel."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: RGB image convertible with `np.array` (8-bit range, 3 channels).
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample the shift uniformly between the first
            and last severity parameters.
        :return: Re-saturated image as a `uint8` array.
        """
        if not augmentation:
            shift = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            shift = np.random.uniform(low=lowest, high=highest, size=None)

        hsv = sk.color.rgb2hsv(np.array(img) / 255.)
        # Channel 1 of HSV is saturation; keep it inside [0, 1].
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] + shift, 0, 1)
        rgb = np.clip(sk.color.hsv2rgb(hsv), 0, 1) * 255
        return rgb.astype(np.uint8)
60 |
61 |
--------------------------------------------------------------------------------
/medmnistc/dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from torch.utils.data import Dataset
3 | from PIL import Image
4 |
5 | import numpy as np
6 | import os
7 |
8 |
class CorruptedMedMNIST(Dataset):
    def __init__(self,
                 dataset_name: str,
                 corruption: str,
                 norm_mean: list = None,
                 norm_std: list = None,
                 root: str = None,
                 as_rgb: bool = True,
                 mmap_mode: str = None):
        """
        Dataset class of CorruptedMedMNIST.

        Loads `{root}/{dataset_name}/{corruption}.npz`, which must contain the
        arrays `test_images` and `test_labels`.

        :param dataset_name: Name of the reference medmnist dataset.
        :param corruption: Name of the desired corruption.
        :param norm_mean: Normalization mean (default: [0.5]).
        :param norm_std: Normalization standard deviation (default: [0.5]).
        :param root: Root path of the generated corrupted data.
        :param as_rgb: If True, convert images to RGB; otherwise keep them as stored.
        :param mmap_mode: Memory mapping of the file: {None, 'r+', 'r', 'w+', 'c'},
                          forwarded to `numpy.load`.
                          src: https://numpy.org/doc/stable/reference/generated/numpy.load.html
                          NOTE(review): numpy appears to ignore `mmap_mode` for `.npz`
                          archives (arrays are read into memory) -- confirm.
        :raises RuntimeError: If `root` is missing or the `.npz` file does not exist.

        This dataset class was greatly inspired from the MedMNIST APIs:
        https://github.com/MedMNIST/MedMNIST
        """
        super().__init__()

        # `None` defaults avoid sharing one mutable list across instances.
        if norm_mean is None:
            norm_mean = [0.5]
        if norm_std is None:
            norm_std = [0.5]

        if root is None or not os.path.exists(root):
            raise RuntimeError(
                "Failed to setup the default `root` directory. "
                + "Please specify and create the `root` directory manually."
            )

        self.dataset_name = dataset_name
        self.corruption = corruption
        self.root = root
        self.as_rgb = as_rgb

        npz_path = os.path.join(self.root, self.dataset_name, f"{corruption}.npz")
        if not os.path.exists(npz_path):
            # Report the offending path in the error rather than printing it separately.
            raise RuntimeError(f"Dataset not found: {npz_path}")

        # The context manager closes the archive's file handle once both arrays
        # have been read.
        with np.load(npz_path, mmap_mode=mmap_mode) as npz_file:
            self.imgs = npz_file["test_images"]
            self.labels = npz_file["test_labels"]

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std)
        ])

    def __len__(self):
        """Return the number of stored (corrupted) test images."""
        return self.imgs.shape[0]

    def __getitem__(self, index):
        """
        :param index: Position of the sample.
        :return: (transformed image, integer label array).
        """
        img, target = self.imgs[index], self.labels[index].astype(int)
        img = Image.fromarray(img)

        if self.as_rgb:
            img = img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, target
--------------------------------------------------------------------------------
/medmnistc/corruptions/filter.py:
--------------------------------------------------------------------------------
1 | from .base import BaseCorruption
2 |
3 | from scipy.ndimage import zoom as scizoom
4 | from wand.image import Image as WandImage
5 | from wand.api import library as wandlibrary
6 | from io import BytesIO
7 |
8 | import torchvision.transforms.functional as TF
9 | import numpy as np
10 | import cv2
11 |
12 |
def disk(radius, alias_blur=0.1, dtype=np.float32):
    """Build a normalised, slightly Gaussian-blurred disk kernel (defocus PSF).

    :param radius: Disk radius in pixels (may be fractional).
    :param alias_blur: Gaussian sigma used to soften the disk's aliased edge.
    :param dtype: Floating dtype of the kernel.
    :return: 2-D kernel; before the final blur its entries sum to 1.
    """
    is_small = radius <= 8
    # Small radii share a fixed 17x17 support; larger ones grow with the radius.
    half_extent = 8 if is_small else radius
    blur_ksize = (3, 3) if is_small else (5, 5)

    coords = np.arange(-half_extent, half_extent + 1)
    grid_x, grid_y = np.meshgrid(coords, coords)
    kernel = np.array((grid_x ** 2 + grid_y ** 2) <= radius ** 2, dtype=dtype)
    kernel /= np.sum(kernel)
    return cv2.GaussianBlur(kernel, ksize=blur_ksize, sigmaX=alias_blur)
24 |
25 |
def clipped_zoom(img, zoom_factor):
    """Zoom into the centre of `img` while keeping its spatial size.

    The central region of size ceil(H / zoom) x ceil(W / zoom) is cropped,
    upsampled with bilinear interpolation (order=1), and centre-trimmed back to
    H x W. Works for 2-D (H, W) and channel-last (H, W, C) arrays. For square
    inputs the result is identical to the previous square-only implementation.

    :param img: Array of shape (H, W) or (H, W, C).
    :param zoom_factor: Zoom factor (>= 1 for a zoom-in).
    :return: Array with the same spatial size as `img`.
    """
    h, w = img.shape[:2]
    ch = int(np.ceil(h / zoom_factor))
    cw = int(np.ceil(w / zoom_factor))

    top = (h - ch) // 2
    left = (w - cw) // 2
    # Only the two spatial axes are zoomed; any trailing channel axes keep factor 1.
    zoom_tuple = (zoom_factor, zoom_factor) + (1,) * (img.ndim - 2)
    img = scizoom(img[top:top + ch, left:left + cw], zoom_tuple, order=1)

    trim_top = (img.shape[0] - h) // 2
    trim_left = (img.shape[1] - w) // 2
    return img[trim_top:trim_top + h, trim_left:trim_left + w]
35 |
36 |
37 |
class MotionImage(WandImage):
    """Wand image whose `motion_blur` calls ImageMagick's C API directly."""

    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
        """Blur this image in place by forwarding the arguments to `MagickMotionBlurImage`.

        :param radius: Blur radius, passed through unchanged.
        :param sigma: Blur sigma, passed through unchanged.
        :param angle: Blur direction, passed through unchanged.
        """
        magick_motion_blur = wandlibrary.MagickMotionBlurImage
        magick_motion_blur(self.wand, radius, sigma, angle)
41 |
42 |
class GaussianBlur(BaseCorruption):
    """Gaussian blur whose strength is set by the kernel size."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: Image accepted by `TF.gaussian_blur` (a PIL image in this package).
        :param severity: Index into `severity_params` (used when not augmenting).
        :param augmentation: If True, sample a kernel size between the first and last
            severity parameters and round it down to an odd integer.
        :return: Blurred image as a `uint8` array.
        """
        if not augmentation:
            kernel_size = self.severity_params[severity]
        else:
            smallest, largest = self.severity_params[0], self.severity_params[-1]
            kernel_size = int(np.random.uniform(low=smallest, high=largest, size=None))
            # The kernel size must be odd: step even values down by one.
            if not kernel_size % 2:
                kernel_size -= 1

        blurred = TF.gaussian_blur(img, kernel_size=kernel_size)
        return np.array(blurred).astype(np.uint8)
54 |
55 |
class MotionBlur(BaseCorruption):
    """Motion blur rendered by ImageMagick (through Wand) at a random angle in [-45, 45)."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: PIL image to corrupt.
        :param severity: Index into `severity_params`, whose entries are
            (radius, sigma) pairs (used when not augmenting).
        :param augmentation: If True, sample radius and sigma uniformly between the
            first and last severity parameters.
        :return: Blurred RGB image as a `uint8` array of shape (H, W, 3).
        """
        if augmentation:
            radius_min, radius_max = self.severity_params[0][0], self.severity_params[-1][0]
            sigma_min, sigma_max = self.severity_params[0][1], self.severity_params[-1][1]
            radius = np.random.uniform(low=radius_min, high=radius_max, size=None)
            sigma = np.random.uniform(low=sigma_min, high=sigma_max, size=None)
        else:
            radius, sigma = self.severity_params[severity]

        # Hand the image to ImageMagick as a PNG blob.
        with BytesIO() as output:
            img.save(output, format='PNG')
            png_blob = output.getvalue()

        # The context manager releases the native ImageMagick resources.
        with MotionImage(blob=png_blob) as wand_img:
            wand_img.motion_blur(radius=radius, sigma=sigma, angle=np.random.uniform(-45, 45))
            # np.frombuffer replaces the deprecated binary mode of np.fromstring.
            blurred = cv2.imdecode(np.frombuffer(wand_img.make_blob(), np.uint8),
                                   cv2.IMREAD_UNCHANGED)

        if blurred.ndim == 2:
            # Greyscale decode: replicate to 3 channels. This check no longer relies
            # on a hard-coded 224x224 size.
            blurred = np.stack([blurred, blurred, blurred], axis=-1)
        else:
            # OpenCV decodes colour images as BGR(A); reorder to RGB (drops alpha).
            blurred = blurred[..., [2, 1, 0]]

        return np.clip(blurred, 0, 255).astype(np.uint8)
81 |
82 |
class DefocusBlur(BaseCorruption):
    """Defocus blur: per-channel convolution with an anti-aliased disk kernel."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: RGB image convertible with `np.array` (8-bit range, 3 channels).
        :param severity: Index into `severity_params`, whose entries are
            (radius, alias_blur) pairs (used when not augmenting).
        :param augmentation: If True, sample radius and alias blur uniformly between
            the first and last severity parameters.
        :return: Blurred image as a `uint8` array.
        """
        if not augmentation:
            radius, alias = self.severity_params[severity]
        else:
            mildest, harshest = self.severity_params[0], self.severity_params[-1]
            radius = np.random.uniform(low=mildest[0], high=harshest[0], size=None)
            alias = np.random.uniform(low=mildest[1], high=harshest[1], size=None)

        scaled = np.array(img) / 255.
        kernel = disk(radius=radius, alias_blur=alias)

        # Filter each of the three channels independently, then restack as H x W x 3.
        filtered = np.stack(
            [cv2.filter2D(scaled[:, :, channel], -1, kernel) for channel in range(3)],
            axis=-1,
        )
        result = np.clip(filtered, 0, 1) * 255
        return result.astype(np.uint8)
106 |
107 |
class ZoomBlur(BaseCorruption):
    """Zoom blur: average of the image with several centre-zoomed copies of itself."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        :param img: Image convertible with `np.array` (8-bit range, channel-last).
        :param severity: Index into `severity_params`, whose entries are sequences
            of zoom factors (used when not augmenting).
        :param augmentation: If True, build the zoom factors as
            arange(1.0, stop, step). `stop` is sampled in [1, last factor of the
            harshest level). `step` is sampled between the spacing of the mildest
            and the harshest levels.
        :return: Blurred image as a `uint8` array.
        """
        if not augmentation:
            zoom_factors = self.severity_params[severity]
        else:
            mildest, harshest = self.severity_params[0], self.severity_params[-1]
            start, stop_cap = 1.0, harshest[-1]
            step_low = mildest[1] - mildest[0]
            step_high = harshest[1] - harshest[0]
            stop = np.random.uniform(low=start, high=stop_cap, size=None)
            step = np.random.uniform(low=step_low, high=step_high, size=None)
            zoom_factors = np.arange(start, stop, step)

        base = (np.array(img) / 255.).astype(np.float32)
        accumulated = np.zeros_like(base)
        for factor in zoom_factors:
            accumulated += clipped_zoom(base, factor)

        # The original image counts as one extra term of the average.
        averaged = (base + accumulated) / (len(zoom_factors) + 1)
        return (np.clip(averaged, 0, 1) * 255).astype(np.uint8)
--------------------------------------------------------------------------------
/medmnistc/dataset_manager.py:
--------------------------------------------------------------------------------
1 | from medmnistc.corruptions.registry import CORRUPTIONS_DS, DATASET_RGB
2 | from medmnistc.utils.utils import seed_everything
3 | from medmnist import INFO
4 | from PIL import Image
5 | import numpy as np
6 | import os
7 |
8 | from tqdm import tqdm
9 |
10 |
class DatasetManager:
    """Build the MedMNIST-C corrupted test sets from the clean MedMNIST+ 224x224 data."""

    def __init__(self,
                 medmnist_path: str,
                 output_path: str,
                 random_seed: int = 0):
        """
        Class used to create the corrupted test sets.
        Specifically, it creates one `npz` file for each designed dataset-corruption pair.
        :param medmnist_path: Path to the medmnist datasets. Pre-download the 224 version.
                              https://medmnist.com/
        :param output_path: Path to the output folder of the `medmnistc` dataset.
                            Path convention: {output_folder} / {dataset} / {corruption}.npz
        :param random_seed: Control stochastic process and ensure reproducibility.
        """
        self.medmnist_path = medmnist_path
        self.output_path = output_path

        self.supported_datasets = [
            'bloodmnist', 'breastmnist', 'chestmnist', 'dermamnist',
            'octmnist', 'organamnist', 'organcmnist', 'organsmnist',
            'pathmnist', 'pneumoniamnist', 'retinamnist', 'tissuemnist'
        ]

        self.random_seed = random_seed

    def create_dataset(self, dataset_name: str):
        """
        Create the corrupted dataset(s).

        :param dataset_name: Name of the dataset to corrupt (case-insensitive).
                             Options: {'all',
                                      'bloodmnist', 'breastmnist', 'chestmnist', 'dermamnist',
                                      'octmnist', 'organamnist', 'organcmnist', 'organsmnist',
                                      'pathmnist', 'pneumoniamnist', 'retinamnist', 'tissuemnist'}
                             If `all` is set, it creates all the corrupted datasets.
        :raises ValueError: If `dataset_name` is neither `all` nor a supported dataset.
        """
        # Normalise first, so that e.g. "All" and "BreastMNIST" are both accepted.
        dataset_name = dataset_name.lower()

        # Create all the corrupted datasets
        if dataset_name == 'all':
            for ds in self.supported_datasets:
                self.create_single_dataset(dataset_name=ds)
            return

        # Create only the chosen corrupted dataset. An explicit raise is used
        # because `assert` is stripped under `python -O`.
        if dataset_name not in self.supported_datasets:
            raise ValueError(f"Dataset not found. Please choose one among : {self.supported_datasets}")
        self.create_single_dataset(dataset_name=dataset_name)

    def create_single_dataset(self, dataset_name: str):
        """
        Generate the corrupted versions of one dataset.
        One .npz file (all 5 severities stacked) is stored per designed corruption.

        :param dataset_name: Name of the dataset to corrupt.
        """
        print(f"=========== {dataset_name} ===========")

        # Re-seed for every dataset so that each one is reproducible on its own.
        rng = seed_everything(self.random_seed)

        # Get the designed corruptions
        corruptions = CORRUPTIONS_DS[dataset_name]

        dataset = self._load_clean_test_split(dataset_name)

        dataset_path = os.path.join(self.output_path, dataset_name)
        os.makedirs(dataset_path, exist_ok=True)

        # NOTE: This could be computationally heavy (RAM-wise) for large datasets
        # (e.g. TissueMNIST), as it multiplies the test set 5 times.
        for corruption, corruptor in corruptions.items():
            print(f'Starting {corruption}...')

            if corruption == "impulse_noise":
                # ImpulseNoise reads `self.rng` (passed to skimage.util.random_noise).
                corruptor.rng = rng

            dataset_c, labels = self._corrupt_all_severities(dataset, dataset_name, corruptor)

            filepath = os.path.join(dataset_path, f'{corruption}.npz')
            np.savez_compressed(filepath, test_images=dataset_c, test_labels=labels)

    def _load_clean_test_split(self, dataset_name: str):
        """Instantiate the clean 224x224 MedMNIST test split (must be pre-downloaded)."""
        info = INFO[dataset_name]
        DatasetClass = getattr(__import__('medmnist', fromlist=[info['python_class']]), info['python_class'])
        return DatasetClass(split="test",
                            as_rgb=True,
                            download=False,
                            transform=None,
                            size=224,
                            root=self.medmnist_path)

    def _corrupt_all_severities(self, dataset, dataset_name: str, corruptor):
        """Apply `corruptor` at each of the 5 severity levels to every test image.

        :return: (list of corrupted uint8 images, list of matching labels),
                 ordered severity-major.
        """
        is_rgb = DATASET_RGB[dataset_name]
        dataset_c, labels = [], []

        # By design, we have 5 intensity levels
        for severity in range(5):
            desc = f'Severity {str(severity + 1).zfill(2)}'
            for img, label in tqdm(zip(dataset.imgs, dataset.labels), desc, total=len(dataset.imgs)):
                # The defined corruptions support RGB images
                rgb_img = Image.fromarray(img).convert('RGB')
                corrupted_img = corruptor.apply(rgb_img, severity)

                # Convert to greyscale, if required
                if not is_rgb:
                    corrupted_img = Image.fromarray(corrupted_img).convert('L')

                np_corrupted = np.array(corrupted_img)
                # A uint8 dtype already guarantees values in [0, 255]. The previous
                # range assert combined its bounds with `or` and could never fail.
                assert np_corrupted.dtype == np.uint8, f"{np_corrupted.dtype}"

                dataset_c.append(np_corrupted)
                labels.append(label)

        return dataset_c, labels
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🏥 MedMNIST-C
2 |
3 | We introduce MedMNIST-C [[preprint](https://arxiv.org/pdf/2406.17536)], a `benchmark dataset` based on the MedMNIST+ collection covering `12 2D datasets and 9 imaging modalities`. We simulate task and modality-specific image corruptions of varying severity to comprehensively evaluate the robustness of established algorithms against `real-world artifacts` and `distribution shifts`. We further show that our simple-to-use artificial corruptions allow for highly performant, lightweight `data augmentation` to enhance model robustness.
4 |
5 |
6 |
7 |
8 |
9 | > You can download the corrupted datasets from [Zenodo](https://zenodo.org/records/11471504).
10 | Due to space constraints, we have uploaded all datasets except TissueMNIST-C. However, you can still reproduce it using our APIs.
11 |
12 | ## Installation and Requirements
13 |
14 | ```
15 | pip install medmnistc
16 | ```
17 |
18 | We require [Wand](https://docs.wand-py.org/en/latest/guide/install.html), a Python binding for [ImageMagick](https://imagemagick.org/index.php), for image manipulation. Wand needs the ImageMagick libraries to be installed; on Ubuntu, run:
19 |
20 | ```
21 | sudo apt-get install libmagickwand-dev
22 | ```
23 |
24 | On other systems, please check the [installation tutorial](https://docs.wand-py.org/en/0.2.4/guide/install.html).
25 |
26 | ## Main components
27 |
28 | * `medmnistc/corruptions/registry.py`: List of all the corruptions and respective intensity hyperparameters.
29 | * `medmnistc/dataset_manager.py`: Dataset class responsible for the creation of the corrupted datasets.
30 | * `medmnistc/visualizer.py`: Class used to visualize and store the defined corruptions.
31 | * `medmnistc/augmentation.py`: Augmentation class based on the defined corruptions.
32 | * `medmnistc/dataset.py`: Dataset class used for the corrupted datasets.
33 | * `medmnistc/eval.py`: PyTorch class used for model evaluation under corrupted datasets.
34 | * `medmnistc/assets/baseline/*`: Normalization baselines used for model evaluation under corrupted datasets.
35 |
36 | ## Basic usage
37 |
38 | ### Create the corrupted datasets
39 | ```python
40 | from medmnistc.dataset_manager import DatasetManager
41 |
42 | medmnist_path = ... # PATH TO THE CLEAN IMAGES
43 | medmnistc_path = ... # PATH TO THE CORRUPTED IMAGES
44 |
45 | ds_manager = DatasetManager(medmnist_path = medmnist_path, output_path = medmnistc_path)
46 | ds_manager.create_dataset(dataset_name = "breastmnist") # create a single corrupted test set
47 | ds_manager.create_dataset(dataset_name = "all") # create all
48 | ```
49 |
50 | ### Augmentations
51 | ```python
52 | from medmnistc.augmentation import AugMedMNISTC
53 | from medmnistc.corruptions.registry import CORRUPTIONS_DS
54 | import torchvision.transforms as transforms
55 |
56 | dataset = "breastmnist" # select dataset
57 | train_corruptions = CORRUPTIONS_DS[dataset] # load the designed corruptions for this dataset
58 | images = ... # load images
59 |
60 | # Augment with AugMedMNISTC
61 | augment = AugMedMNISTC(train_corruptions)
62 | augmented_img = augment(images[0])
63 |
64 | # Integrate into transforms.Compose
65 | aug_compose = transforms.Compose([
66 | AugMedMNISTC(train_corruptions),
67 | transforms.ToTensor(),
68 | transforms.Normalize(mean=..., std=...)
69 | ])
70 |
71 | augmented_img = aug_compose(images[0])
72 | ```
73 |
74 | ### Notebooks
75 |
76 | * [Create the dataset](assets/examples/create_dataset.ipynb)
77 | * [Visualize the corruptions](assets/examples/visualize.ipynb)
78 | * [Evaluate the corruptions](assets/examples/evaluation.ipynb)
79 | * [Use the designed augmentations](assets/examples/augment.ipynb)
80 |
81 | ## Papers using MedMNIST-C
82 |
83 | | **Authors** | **Paper** | **Venue** |
84 | | ------------- | ------------- | ------------- |
85 | | Kuhn et al. | An autonomous agent for auditing and improving the reliability of clinical AI models | [ArXiv'25](https://arxiv.org/pdf/2507.05755) |
86 | | Manzari et al. | Medical image classification with kan-integrated transformers and dilated neighborhood attention | [ArXiv'25](https://arxiv.org/abs/2502.13693) |
87 | | Imam et al. | On the Robustness of Medical Vision-Language Models: Are they Truly Generalizable? | [MIUA'25](https://arxiv.org/abs/2505.15425) |
88 | | Zeevi et al. | Rate-In: Information-Driven Adaptive Dropout Rates for Improved Inference-Time Uncertainty Estimation | [CVPR'25](https://openaccess.thecvf.com/content/CVPR2025/papers/Zeevi_Rate-In_Information-Driven_Adaptive_Dropout_Rates_for_Improved_Inference-Time_Uncertainty_Estimation_CVPR_2025_paper.pdf) |
89 | | Hekler et al. | Beyond Overconfidence: Foundation Models Redefine Calibration in Deep Neural Networks | [ArXiv'25](https://www.arxiv.org/abs/2506.09593) |
90 | | Abhishek et al. | Investigating the Quality of DermaMNIST and Fitzpatrick17k Dermatological Image Datasets | [Scientific Data'25](https://www.nature.com/articles/s41597-025-04382-5) |
91 | | Singh et al. | Dynamic Filter Application in Graph Convolutional Networks for Enhanced Spectral Feature Analysis and Class Discrimination in Medical Imaging | [IEEE Access'24](https://ieeexplore.ieee.org/document/10637462) |
92 |
93 | ## License
94 |
95 | The code is under [Apache-2.0 License](./LICENSE).
96 |
97 | The MedMNIST-C dataset is licensed under Creative Commons Attribution 4.0 International ([CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)), except DermaMNIST-C under Creative Commons Attribution-NonCommercial 4.0 International ([CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/)).
98 |
99 | ## Citation
100 |
101 | If you find this work useful, please consider citing us:
102 | ```
103 | @misc{disalvo2024medmnistc,
104 | title={MedMNIST-C: Comprehensive benchmark and improved classifier robustness by simulating realistic image corruptions},
105 | author={Francesco Di Salvo and Sebastian Doerrich and Christian Ledig},
106 | year={2024},
107 | eprint={2406.17536},
108 | archivePrefix={arXiv},
109 | primaryClass={eess.IV},
110 | url={https://arxiv.org/abs/2406.17536},
111 | }
112 | ```
113 |
114 | `DISCLAIMER`: This repository is inspired by MedMNIST APIs and the ImageNet-C repository. Thus, please also consider citing [MedMNIST](https://www.nature.com/articles/s41597-022-01721-8), the respective source datasets (described [here](https://medmnist.com/)) and [ImageNet-C](https://arxiv.org/abs/1903.12261).
115 |
116 | ## Release versions
117 |
118 | * `v0.1.0`: MedMNIST-C beta release.
119 |
--------------------------------------------------------------------------------
/medmnistc/corruptions/microscopy.py:
--------------------------------------------------------------------------------
1 | from .base import BaseCorruption
2 | from PIL import Image, ImageDraw
3 |
4 | import numpy as np
5 | import random
6 | import cv2
7 | import string
8 | import random
9 |
10 |
class StainDeposit(BaseCorruption):
    """Darken random patches of an image using ink masks from `self.inks`."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        Stamp a random number of ink masks onto the image.

        :param img: Input image. It is converted with `np.array`, and channels 0-2 of
                    the last axis are darkened, so a 3-D array with at least 3 channels
                    is expected.
        :param severity: Index into `self.severity_params` (-1 selects the last entry).
                         It is also used to pick the ink set (see note below).
        :param augmentation: If True, the maximum number of marks is sampled uniformly
                             between the first and last severity parameters.
        :return: uint8 numpy array with the stains applied.
        """
        params = self.severity_params
        if not augmentation:
            upper_marks = params[severity]
        else:
            low, high = params[0], params[-1]
            upper_marks = int(np.random.uniform(low=low, high=high, size=None))

        canvas = np.array(img)
        height, width = canvas.shape[0], canvas.shape[1]

        # Between 1 and `upper_marks` stains; a cap of 1 or less is used verbatim.
        mark_count = random.randint(1, upper_marks) if upper_marks > 1 else upper_marks

        # NOTE(review): the ink-set key is max(3, severity), so every severity below 3
        # (including the -1 default used for augmentation) shares set "3" — confirm intent.
        ink_pool = self.inks[str(max(3, severity))]

        for _ in range(mark_count):
            stain = ink_pool[random.randint(0, len(ink_pool) - 1)]
            stain_h, stain_w = stain.shape
            left = random.randint(10, width - stain_w - 10)
            top = random.randint(10, height - stain_h - 10)

            # Multiplying by (1 - ink) darkens the pixels where the ink mask is set.
            keep = 1 - stain
            rows = slice(top, top + stain_h)
            cols = slice(left, left + stain_w)
            for channel in range(3):
                canvas[rows, cols, channel] *= keep

        return np.clip(canvas, 0, 255).astype(np.uint8)
42 |
43 |
class Bubble(BaseCorruption):
    """Overlay semi-transparent white circles ("bubbles") on a PIL image."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        Draw a random number of translucent white discs, each with a slightly more
        opaque rim, and composite them onto a copy of `img`.

        :param img: PIL image. Its `.size` is unpacked as (width, height) and it is
                    used with `.copy()` and `.paste(...)`.
        :param severity: Index into `self.severity_params`. Each entry is
                         (max_radius, max_bubbles), and -1 selects the last entry.
        :param augmentation: If True, max_radius and max_bubbles are sampled uniformly
                             between the first and last severity parameters.
        :return: numpy array of the image with the bubbles composited in.
        """
        if augmentation:
            range_min_rad, range_max_rad = self.severity_params[0][0], self.severity_params[-1][0]
            range_min_bub, range_max_bub = self.severity_params[0][1], self.severity_params[-1][1]
            max_radius = int(np.random.uniform(low=range_min_rad, high=range_max_rad, size=None))
            max_bubbles = int(np.random.uniform(low=range_min_bub, high=range_max_bub, size=None))
        else:
            max_radius, max_bubbles = self.severity_params[severity]

        # PIL's Image.size is (width, height). Unpacking it as (height, width)
        # places bubbles along the wrong axes on non-square images.
        width, height = img.size
        output_image = img.copy()

        # Transparent RGBA layer, same size as the input, that receives the bubbles.
        bubbles_image = Image.new('RGBA', output_image.size, (255, 255, 255, 0))
        draw = ImageDraw.Draw(bubbles_image)
        border = 2   # rim thickness in pixels
        alpha = 100  # fill transparency; the rim uses alpha + 30

        # Clamp the lower bounds so that caps below 7 bubbles / radius 3
        # no longer make randint raise ValueError. Larger caps behave as before.
        num_bubbles = random.randint(min(7, max_bubbles), max_bubbles)

        for _ in range(num_bubbles):
            radius = random.randint(min(3, max_radius), max_radius)
            x = random.randint(radius, width - radius)
            y = random.randint(radius, height - radius)
            outer = radius + border
            draw.ellipse((x - outer, y - outer, x + outer, y + outer), fill=(255, 255, 255, alpha + 30))
            draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=(255, 255, 255, alpha))

        # The RGBA layer doubles as its own alpha mask.
        output_image.paste(bubbles_image, (0, 0), bubbles_image)
        return np.array(output_image)
80 |
81 |
class BlackCorner(BaseCorruption):
    """Zero every pixel outside a filled circle centred on the image."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        Black out the region outside a centred disc whose radius is a multiple of
        the largest circle that fits around the centre.

        :param img: Image whose `.size` is unpacked as (width, height); it is
                    converted with `np.array`.
        :param severity: Index into `self.severity_params` (the radius multiplier);
                         -1 selects the last entry.
        :param augmentation: If True, the multiplier is sampled uniformly between
                             the first and last severity parameters.
        :return: numpy array with pixels outside the disc set to 0.
        """
        if not augmentation:
            multiplier = self.severity_params[severity]
        else:
            lowest, highest = self.severity_params[0], self.severity_params[-1]
            multiplier = np.random.uniform(low=lowest, high=highest, size=None)

        w, h = img.size
        pixels = np.array(img)

        cx, cy = w // 2, h // 2
        # Distance from the centre to the nearest image edge.
        base_radius = min(cx, cy, w - cx, h - cy)

        disc = np.zeros((h, w), np.uint8)
        cv2.circle(disc, (cx, cy), int(base_radius * multiplier), (255), thickness=-1)
        pixels[disc != 255] = 0

        return pixels
104 |
105 |
class Characters(BaseCorruption):
    """Render random lowercase strings in random colours onto an image."""

    def apply(self, img, severity=-1, augmentation=False):
        """
        Draw between 1 and `max_words` random strings at random positions.

        :param img: Input image; converted once with `np.array`.
        :param severity: Index into `self.severity_params`. Each entry is
                         (max_words, max_letters, max_font_scale), and -1 selects
                         the last entry.
        :param augmentation: If True, the three limits are sampled uniformly between
                             the first and last severity parameters.
        :return: numpy array with the text drawn via `cv2.putText` using `self.font`.
        """
        if augmentation:
            range_min_w, range_max_w = self.severity_params[0][0], self.severity_params[-1][0]
            range_min_l, range_max_l = self.severity_params[0][1], self.severity_params[-1][1]
            range_min_fs, range_max_fs = self.severity_params[0][2], self.severity_params[-1][2]
            max_words = int(np.random.uniform(low=range_min_w, high=range_max_w, size=None))
            max_letters = int(np.random.uniform(low=range_min_l, high=range_max_l, size=None))
            max_font_scale = np.random.uniform(low=range_min_fs, high=range_max_fs, size=None)
        else:
            max_words, max_letters, max_font_scale = self.severity_params[severity]

        # Convert once, outside the loop: np.array on an ndarray makes a full copy,
        # so converting per word needlessly copied the whole image every iteration.
        img = np.array(img)
        height, width = img.shape[0], img.shape[1]
        letters = string.ascii_lowercase
        thickness = 1

        num_words = random.randint(1, max_words)
        for _ in range(num_words):
            num_letters = random.randint(3, max_letters)
            # Font scale is drawn in 0.01 steps from 0.14 up to max_font_scale.
            font_scale = random.randint(14, int(max_font_scale * 100)) / 100.
            random_str = ''.join(random.choice(letters) for _ in range(num_letters))

            # org = (x, y) is the bottom-left corner of the text. About 8 px per
            # letter is reserved on the right, presumably to keep the string in frame.
            rand_x = random.randint(10, width - (8 * num_letters))
            rand_y = random.randint(10, height - 10)

            # Each word gets a random RGB colour (not necessarily black).
            color = self.random_color()

            img = cv2.putText(img, random_str, (rand_x, rand_y), self.font,
                              font_scale, color, thickness, cv2.LINE_AA)

        return img

    def random_color(self):
        """Return a random (red, green, blue) tuple, each component in [0, 255]."""
        return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
--------------------------------------------------------------------------------
/assets/examples/create_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from medmnistc.dataset_manager import DatasetManager\n",
10 | "\n",
11 | "import numpy as np"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "medmnist_path = \"/mnt/data/datasets/medmnist\" # PATH TO THE CLEAN IMAGES\n",
21 | "output_path = \"/mnt/data/datasets/medmnistc-tmp\" # PATH TO THE CORRUPTED IMAGES"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "### Create one dataset"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 4,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "name": "stdout",
38 | "output_type": "stream",
39 | "text": [
40 | "=========== breastmnist ===========\n",
41 | "Starting pixelate...\n"
42 | ]
43 | },
44 | {
45 | "name": "stderr",
46 | "output_type": "stream",
47 | "text": [
48 | "Severity 01: 100%|██████████| 156/156 [00:00<00:00, 3096.10it/s]\n",
49 | "Severity 02: 100%|██████████| 156/156 [00:00<00:00, 3244.26it/s]\n",
50 | "Severity 03: 100%|██████████| 156/156 [00:00<00:00, 3382.45it/s]\n",
51 | "Severity 04: 100%|██████████| 156/156 [00:00<00:00, 3501.48it/s]\n",
52 | "Severity 05: 100%|██████████| 156/156 [00:00<00:00, 3726.08it/s]\n"
53 | ]
54 | },
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "Starting jpeg_compression...\n"
60 | ]
61 | },
62 | {
63 | "name": "stderr",
64 | "output_type": "stream",
65 | "text": [
66 | "Severity 01: 100%|██████████| 156/156 [00:00<00:00, 1731.67it/s]\n",
67 | "Severity 02: 100%|██████████| 156/156 [00:00<00:00, 2852.85it/s]\n",
68 | "Severity 03: 100%|██████████| 156/156 [00:00<00:00, 3198.72it/s]\n",
69 | "Severity 04: 100%|██████████| 156/156 [00:00<00:00, 3380.25it/s]\n",
70 | "Severity 05: 100%|██████████| 156/156 [00:00<00:00, 3404.72it/s]\n"
71 | ]
72 | },
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "Starting speckle_noise...\n"
78 | ]
79 | },
80 | {
81 | "name": "stderr",
82 | "output_type": "stream",
83 | "text": [
84 | "Severity 01: 100%|██████████| 156/156 [00:00<00:00, 450.27it/s]\n",
85 | "Severity 02: 100%|██████████| 156/156 [00:00<00:00, 517.73it/s]\n",
86 | "Severity 03: 100%|██████████| 156/156 [00:00<00:00, 521.73it/s]\n",
87 | "Severity 04: 100%|██████████| 156/156 [00:00<00:00, 516.61it/s]\n",
88 | "Severity 05: 100%|██████████| 156/156 [00:00<00:00, 520.55it/s]\n"
89 | ]
90 | },
91 | {
92 | "name": "stdout",
93 | "output_type": "stream",
94 | "text": [
95 | "Starting motion_blur...\n"
96 | ]
97 | },
98 | {
99 | "name": "stderr",
100 | "output_type": "stream",
101 | "text": [
102 | "Severity 01: 100%|██████████| 156/156 [00:07<00:00, 21.04it/s]\n",
103 | "Severity 02: 100%|██████████| 156/156 [00:10<00:00, 15.30it/s]\n",
104 | "Severity 03: 100%|██████████| 156/156 [00:10<00:00, 15.28it/s]\n",
105 | "Severity 04: 100%|██████████| 156/156 [00:13<00:00, 11.80it/s]\n",
106 | "Severity 05: 100%|██████████| 156/156 [00:16<00:00, 9.53it/s]\n"
107 | ]
108 | },
109 | {
110 | "name": "stdout",
111 | "output_type": "stream",
112 | "text": [
113 | "Starting brightness_up...\n"
114 | ]
115 | },
116 | {
117 | "name": "stderr",
118 | "output_type": "stream",
119 | "text": [
120 | "Severity 01: 100%|██████████| 156/156 [00:00<00:00, 3433.98it/s]\n",
121 | "Severity 02: 100%|██████████| 156/156 [00:00<00:00, 2850.21it/s]\n",
122 | "Severity 03: 100%|██████████| 156/156 [00:00<00:00, 3442.84it/s]\n",
123 | "Severity 04: 100%|██████████| 156/156 [00:00<00:00, 3390.16it/s]\n",
124 | "Severity 05: 100%|██████████| 156/156 [00:00<00:00, 3406.70it/s]\n"
125 | ]
126 | },
127 | {
128 | "name": "stdout",
129 | "output_type": "stream",
130 | "text": [
131 | "Starting brightness_down...\n"
132 | ]
133 | },
134 | {
135 | "name": "stderr",
136 | "output_type": "stream",
137 | "text": [
138 | "Severity 01: 100%|██████████| 156/156 [00:00<00:00, 4325.34it/s]\n",
139 | "Severity 02: 100%|██████████| 156/156 [00:00<00:00, 4256.60it/s]\n",
140 | "Severity 03: 100%|██████████| 156/156 [00:00<00:00, 4285.20it/s]\n",
141 | "Severity 04: 100%|██████████| 156/156 [00:00<00:00, 4243.43it/s]\n",
142 | "Severity 05: 100%|██████████| 156/156 [00:00<00:00, 4163.05it/s]\n"
143 | ]
144 | },
145 | {
146 | "name": "stdout",
147 | "output_type": "stream",
148 | "text": [
149 | "Starting contrast_down...\n"
150 | ]
151 | },
152 | {
153 | "name": "stderr",
154 | "output_type": "stream",
155 | "text": [
156 | "Severity 01: 100%|██████████| 156/156 [00:00<00:00, 3279.54it/s]\n",
157 | "Severity 02: 100%|██████████| 156/156 [00:00<00:00, 3282.44it/s]\n",
158 | "Severity 03: 100%|██████████| 156/156 [00:00<00:00, 3289.52it/s]\n",
159 | "Severity 04: 100%|██████████| 156/156 [00:00<00:00, 3302.50it/s]\n",
160 | "Severity 05: 100%|██████████| 156/156 [00:00<00:00, 3310.47it/s]\n"
161 | ]
162 | }
163 | ],
164 | "source": [
165 | "dataset = \"breastmnist\"\n",
166 | "\n",
167 | "ds_manager = DatasetManager(medmnist_path = medmnist_path, output_path=output_path)\n",
168 | "ds_manager.create_dataset(dataset_name = dataset)"
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "metadata": {},
174 | "source": [
175 | "### Create all datasets"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "ds_manager = DatasetManager(medmnist_path = medmnist_path, output_path=output_path)\n",
185 | "ds_manager.create_dataset(dataset = \"all\")"
186 | ]
187 | }
188 | ],
189 | "metadata": {
190 | "kernelspec": {
191 | "display_name": "medmnistc",
192 | "language": "python",
193 | "name": "python3"
194 | },
195 | "language_info": {
196 | "codemirror_mode": {
197 | "name": "ipython",
198 | "version": 3
199 | },
200 | "file_extension": ".py",
201 | "mimetype": "text/x-python",
202 | "name": "python",
203 | "nbconvert_exporter": "python",
204 | "pygments_lexer": "ipython3",
205 | "version": "3.11.7"
206 | }
207 | },
208 | "nbformat": 4,
209 | "nbformat_minor": 2
210 | }
211 |
--------------------------------------------------------------------------------
/medmnistc/utils/baselines.py:
--------------------------------------------------------------------------------
# Reference (baseline) errors per MedMNIST dataset.
# medmnistc/eval.py loads `BASELINES[dataset_name]` as `corruption_errors_alexnet`,
# i.e. the AlexNet reference used to normalize a model's errors:
#   - "clean_score": balanced error (1 - balanced accuracy, as computed by
#     Evaluator) on the clean test set.
#   - "raw_scores": balanced error per corruption, already averaged over the
#     5 severities.
# Each dataset lists only its own corruptions. NOTE(review): presumably these keys
# match CORRUPTIONS_DS in corruptions/registry.py; Evaluator looks up every
# evaluated corruption here, so keep them in sync.
BASELINES = {

    "bloodmnist" : {
        "clean_score": 0.018369282094266692,
        "raw_scores": {
            "pixelate": 0.05975618949678589,
            "jpeg_compression": 0.13976187053705183,
            "defocus_blur": 0.22399795597780647,
            "motion_blur": 0.1845192137218041,
            "brightness_up": 0.06946294328845175,
            "brightness_down": 0.0790077546902872,
            "contrast_up": 0.05215362284240639,
            "contrast_down": 0.07966068499560328,
            "saturate": 0.03478558170020538,
            "stain_deposit": 0.18536119379858546,
            "bubble": 0.12898833684644578
        }
    },

    "breastmnist" : {
        "clean_score": 0.14536340852130336,
        "raw_scores": {
            "pixelate": 0.46127819548872184,
            "jpeg_compression": 0.2986215538847118,
            "speckle_noise": 0.38709273182957393,
            "motion_blur": 0.25902255639097743,
            "brightness_up": 0.3197994987468672,
            "brightness_down": 0.21127819548872181,
            "contrast_down": 0.15864661654135334
        }
    },

    "chestmnist" : {
        "clean_score": 0.2839064387513035,
        "raw_scores": {
            "pixelate": 0.44895070087295996,
            "jpeg_compression": 0.34575690815860566,
            "gaussian_noise": 0.47688317292098203,
            "speckle_noise": 0.45687894795025785,
            "impulse_noise": 0.48553631547337917,
            "shot_noise": 0.4926966815396587,
            "gaussian_blur": 0.3176801263607164,
            "brightness_up": 0.2993721457134663,
            "brightness_down": 0.29711530406534037,
            "contrast_up": 0.28751004051276824,
            "contrast_down": 0.2954742427570406,
            "gamma_corr_up": 0.287462117104012,
            "gamma_corr_down": 0.2891642074561852
        }
    },

    "dermamnist" : {
        "clean_score": 0.3985775597242611,
        "raw_scores": {
            "pixelate": 0.5105118313063095,
            "jpeg_compression": 0.47716296678659287,
            "gaussian_noise": 0.770161396521756,
            "speckle_noise": 0.7688784860204994,
            "impulse_noise": 0.7995552265467947,
            "shot_noise": 0.8391552104792783,
            "defocus_blur": 0.7008318467374395,
            "motion_blur": 0.6266414682428412,
            "zoom_blur": 0.5277668583085878,
            "brightness_up": 0.4473990148869178,
            "brightness_down": 0.5772968080812169,
            "contrast_up": 0.430415384126964,
            "contrast_down": 0.5398210493891896,
            "black_corner": 0.7591457853547467,
            "characters": 0.4461917314736283
        }
    },

    "octmnist" : {
        "clean_score": 0.20199999999999996,
        "raw_scores": {
            "pixelate": 0.3124,
            "jpeg_compression": 0.29640000000000005,
            "speckle_noise": 0.386,
            "defocus_blur": 0.3084,
            "motion_blur": 0.4892,
            "contrast_down": 0.4196
        }
    },

    "organamnist" : {
        "clean_score": 0.050053710403475504,
        "raw_scores": {
            "pixelate": 0.07589728744538755,
            "jpeg_compression": 0.10233182799848468,
            "gaussian_noise": 0.558867512509986,
            "speckle_noise": 0.5008977196696448,
            "impulse_noise": 0.6005122792556206,
            "shot_noise": 0.5704549211261293,
            "gaussian_blur": 0.23340661727240267,
            "brightness_up": 0.14763547790322412,
            "brightness_down": 0.2646445055204496,
            "contrast_up": 0.10977873034814234,
            "contrast_down": 0.1268424696095805,
            "gamma_corr_up": 0.10360009998047674,
            "gamma_corr_down": 0.10577092296366888
        }
    },

    "organcmnist" : {
        "clean_score": 0.07596783210891289,
        "raw_scores": {
            "pixelate": 0.12161270665759313,
            "jpeg_compression": 0.10754481922990186,
            "gaussian_noise": 0.5736544084235357,
            "speckle_noise": 0.5130960298987614,
            "impulse_noise": 0.5997780092763433,
            "shot_noise": 0.5492477253691926,
            "gaussian_blur": 0.316895537101905,
            "brightness_up": 0.16011849294420644,
            "brightness_down": 0.2324271692594097,
            "contrast_up": 0.09991836613387597,
            "contrast_down": 0.15327395964015147,
            "gamma_corr_up": 0.1426418229218433,
            "gamma_corr_down": 0.12390902424472901
        }
    },

    "organsmnist" : {
        "clean_score": 0.24300485476777922,
        "raw_scores": {
            "pixelate": 0.31259415418134306,
            "jpeg_compression": 0.26996994867350865,
            "gaussian_noise": 0.5889124203738291,
            "speckle_noise": 0.5450100158501043,
            "impulse_noise": 0.6010108858901404,
            "shot_noise": 0.5519125643755777,
            "gaussian_blur": 0.4466645657200248,
            "brightness_up": 0.3598501969979232,
            "brightness_down": 0.3592866403894267,
            "contrast_up": 0.2871702451645464,
            "contrast_down": 0.3128243336438499,
            "gamma_corr_up": 0.30253196514691594,
            "gamma_corr_down": 0.3116016263827458
        }
    },

    "pathmnist" : {
        "clean_score": 0.08839299601334816,
        "raw_scores": {
            "pixelate": 0.20934201922262957,
            "jpeg_compression": 0.24709267033930377,
            "defocus_blur": 0.4621259742146214,
            "motion_blur": 0.3689076534699618,
            "brightness_up": 0.41551581022730166,
            "brightness_down": 0.2811688318161528,
            "contrast_up": 0.2338925659881773,
            "contrast_down": 0.13718687818913905,
            "saturate": 0.3341143518079194,
            "stain_deposit": 0.24930831751137275,
            "bubble": 0.10818301335634566
        }
    },

    "pneumoniamnist" : {
        "clean_score": 0.10085470085470083,
        "raw_scores": {
            "pixelate": 0.22102564102564104,
            "jpeg_compression": 0.2798290598290598,
            "gaussian_noise": 0.4414529914529915,
            "speckle_noise": 0.3457264957264957,
            "impulse_noise": 0.4581196581196581,
            "shot_noise": 0.4106837606837607,
            "gaussian_blur": 0.2730769230769231,
            "brightness_up": 0.14649572649572645,
            "brightness_down": 0.2782051282051282,
            "contrast_up": 0.09529914529914527,
            "contrast_down": 0.23743589743589744,
            "gamma_corr_up": 0.09606837606837607,
            "gamma_corr_down": 0.16641025641025642
        }
    },

    "retinamnist" : {
        "clean_score": 0.5582064849927977,
        "raw_scores": {
            "pixelate": 0.6017003263074344,
            "jpeg_compression": 0.6795223564688244,
            "gaussian_noise": 0.7040832524914012,
            "speckle_noise": 0.6135258841167651,
            "defocus_blur": 0.6882905018079196,
            "motion_blur": 0.7075760943057883,
            "brightness_down": 0.5844791133844842,
            "contrast_down": 0.6092882382338243
        }
    },

    "tissuemnist" : {
        "clean_score": 0.3917787199385978,
        "raw_scores": {
            "pixelate": 0.49272908707597907,
            "jpeg_compression": 0.508070087070269,
            "impulse_noise": 0.7970869222746931,
            "gaussian_blur": 0.4805143361974326,
            "brightness_up": 0.4618274173802247,
            "brightness_down": 0.51981943234386,
            "contrast_up": 0.5026233641102424,
            "contrast_down": 0.5369342652194591
        }
    },
}
--------------------------------------------------------------------------------
/medmnistc/eval.py:
--------------------------------------------------------------------------------
1 | from medmnistc.utils.baselines import BASELINES
2 |
3 | from sklearn.metrics import balanced_accuracy_score
4 | import numpy as np
5 | import json
6 | import os
7 |
8 |
class Evaluator:
    def __init__(self,
                 dataset_name: str,
                 true_labels: list,
                 corruption_types: list,
                 output_folder: str,
                 architecture: str,
                 task: str,
                 suffix_log: str = ''):
        """
        Evaluates the robustness of a given model on a set of pre-defined corruptions.

        :param dataset_name: Name of the dataset. It must be a key of `BASELINES` and is
                             also used for logging.
        :param true_labels: True labels of the current (clean) test set, one entry per sample.
        :param corruption_types: Corruptions used for the current experiment.
        :param output_folder: Where to store the output logs (json file).
        :param architecture: Name of the architecture (logging purposes).
        :param task: Classification task: "binary-class", "multi-class",
                     "ordinal-regression" or "multi-label, binary-class".
        :param suffix_log: Suffix of the logging file (e.g. seed of the current experiment).
        """
        self.dataset_name = dataset_name
        self.len_dataset = len(true_labels)
        self.corruption_types = corruption_types
        self.output_folder = output_folder
        self.architecture = architecture
        self.task = task
        self.suffix_log = suffix_log

        self.initialize_evaluation()

        self.true_labels = np.array(true_labels)
        # Column-vector labels (N, 1) (multi-class / binary) are flattened to (N,).
        # Already-flat labels and multi-label (N, L) labels are left untouched.
        if self.true_labels.ndim == 2 and self.true_labels.shape[1] == 1:
            self.true_labels = self.true_labels.reshape(-1,)


    def initialize_evaluation(self):
        """
        Reset the per-corruption error lists, load the reference (AlexNet) errors for
        `self.dataset_name` from `BASELINES`, and select the evaluation function for
        the current task.
        """
        self.corruption_errors = {corruption: [] for corruption in self.corruption_types}
        self.clean_score = None  # set by evaluate_clean(...)

        assert self.dataset_name in BASELINES.keys(), f"{self.dataset_name} has no pre-defined baselines in /utils/baselines.py"
        self.corruption_errors_alexnet = BASELINES[self.dataset_name]

        self.evaluation_metric = self.get_eval_metric()


    def get_eval_metric(self):
        """
        Return the error function `f(y_true, y_score, threshold) -> float` for the
        current task. Every variant returns 1 - balanced accuracy.

        :raises ValueError: if `self.task` is not a supported task.
        """
        # Average balanced accuracy over labels, using one operating point per label.
        if self.task == "multi-label, binary-class":
            def multilabel_error(y_true, y_score, threshold):
                num_labels = y_true.shape[1]
                # A scalar threshold (e.g. the default 0.5) applies to every label.
                # A per-label sequence is used as-is.
                if np.ndim(threshold) == 0:
                    threshold = [threshold] * num_labels
                return 1.0 - np.mean(
                    [balanced_accuracy_score(y_true[:, i], y_score[:, i] > threshold[i]) for i in range(num_labels)]
                )
            return multilabel_error

        # Balanced accuracy with the chosen operating point applied to the last column of y_score.
        elif self.task == "binary-class":
            return lambda y_true, y_score, threshold: 1.0 - balanced_accuracy_score(y_true, y_score[:, -1] > threshold)

        # Balanced accuracy of the argmax prediction; the threshold argument is ignored.
        elif self.task == "multi-class" or self.task == "ordinal-regression":
            return lambda y_true, y_score, _: 1.0 - balanced_accuracy_score(y_true, np.argmax(y_score, axis=-1))

        else:
            raise ValueError(f"Unknown task type {self.task}")


    def evaluate(self, predicted_probabilities, corruption_type, threshold=0.5):
        """
        Evaluate the predictions of the current model on one corruption.

        :param predicted_probabilities: Raw predictions (i.e., probabilities). The dataset
                                        is "repeated" 5 times, once per increasing severity,
                                        so slice k covers samples
                                        [k * len_dataset, (k + 1) * len_dataset).
        :param corruption_type: Name of the current corruption.
        :param threshold: Operating point(s) based on the given task.
                          float if task == "binary-class"
                          float or list[float] if task == "multi-label, binary-class"
                              (a list allows label-specific operating points; a float
                              is applied to every label)
                          ignored if task == "multi-class" or task == "ordinal-regression"
        """
        for severity in range(5):
            # Predictions belonging to the current severity.
            index_range = slice(self.len_dataset * severity, self.len_dataset * (severity + 1))
            curr_prob = predicted_probabilities[index_range]
            score = self.evaluation_metric(self.true_labels, curr_prob, threshold)
            self.corruption_errors[corruption_type].append(score)


    def evaluate_clean(self, predicted_probabilities, threshold=0.5):
        """
        Evaluate the clean test set; its error is needed for the relative corruption error.

        :param predicted_probabilities: Raw predictions (i.e., probabilities) on the clean
                                        test set, one per sample (not repeated).
        :param threshold: Operating point(s), with the same semantics as in `evaluate`.
        """
        score = self.evaluation_metric(self.true_labels, predicted_probabilities, threshold)
        self.clean_score = score


    def dump_summary(self):
        """
        Store a json file containing the aggregated and raw results.
        The output folder is created if it does not exist.
        """
        self.output_log = {}
        self.populate_summary()

        if self.output_folder:
            os.makedirs(self.output_folder, exist_ok=True)

        full_output_path = os.path.join(self.output_folder, f'{self.dataset_name}_{self.architecture}_{self.suffix_log}.json')
        with open(full_output_path, 'w') as f:
            json.dump(self.output_log, f, indent=4)

        print(f'Logs stored at `{full_output_path}`')


    def populate_summary(self):
        """
        Calculate the error metrics according to the formulas reported in the paper.

        BE  = mean_error / alexnet_error
        RBE = (mean_error - clean_error) / (alexnet_error - alexnet_clean_error)

        Corruptions without a usable reference error in `BASELINES` are left out of
        the BE/RBE scores, but they still appear in `raw_scores`.

        :raises AssertionError: if `evaluate_clean` has not been called yet.
        """
        # `is None` rather than truthiness: a clean error of exactly 0.0 is valid.
        # This is an explicit raise so the check survives `python -O`; the exception
        # type is unchanged.
        if self.clean_score is None:
            raise AssertionError("You first need to compute the clean error via self.evaluate_clean(...)")

        self.output_log['metrics'] = {'clean_score': self.clean_score}
        self.output_log['be_scores'] = {}
        self.output_log['rbe_scores'] = {}

        alexnet_raw_scores = self.corruption_errors_alexnet['raw_scores']
        alexnet_clean_score = self.corruption_errors_alexnet['clean_score']

        for corruption, errors in self.corruption_errors.items():
            # Normalize against AlexNet. `.get` avoids a KeyError for corruptions
            # that have no baseline entry.
            alexnet_error = alexnet_raw_scores.get(corruption)

            # Skip when there is nothing to normalize against (missing or zero error).
            if not alexnet_error or alexnet_clean_score is None:
                continue

            mean_error = np.mean(errors)
            # Normalized Balanced Error (BE); alexnet_error is already averaged across severities.
            be = mean_error / alexnet_error
            # Relative Balanced Error (RBE). Subtracting the clean score from each severity
            # and then averaging equals subtracting it once from the mean.
            rbe = (mean_error - self.clean_score) / (alexnet_error - alexnet_clean_score)

            self.output_log['be_scores'][corruption] = be
            self.output_log['rbe_scores'][corruption] = rbe

        # Overall corrupted score and relative corrupted error.
        self.output_log['metrics']['be'] = np.mean(list(self.output_log['be_scores'].values()))
        self.output_log['metrics']['rbe'] = np.mean(list(self.output_log['rbe_scores'].values()))

        # Raw scores for completeness (only the average over severities is stored).
        self.output_log['raw_scores'] = {k: np.mean(v) for k, v in self.corruption_errors.items()}
172 |
--------------------------------------------------------------------------------
/assets/examples/evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Evaluation"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "### Import"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 49,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "%%capture\n",
24 | "\n",
25 | "from medmnistc.dataset import CorruptedMedMNIST\n",
26 | "from medmnistc.eval import Evaluator\n",
27 | "from medmnistc.corruptions.registry import CORRUPTIONS_DS\n",
28 | "\n",
29 | "from torch.utils.data import DataLoader\n",
30 | "from medmnist import INFO\n",
31 | "from copy import deepcopy\n",
32 | "from tqdm import tqdm\n",
33 | "\n",
34 | "import torchvision.transforms as transforms\n",
35 | "import medmnist\n",
36 | "import torch.nn as nn\n",
37 | "import torch\n",
38 | "import timm"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
45 | "### Setup experiment"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 50,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "config = {\n",
55 | " 'dataset' : 'breastmnist',\n",
56 | " 'architecture' : 'resnet18.tv_in1k', # timm-equivalent name\n",
57 | " 'medmnist_path' : '/mnt/data/datasets/medmnist',\n",
58 | " 'medmnistc_path' : '/mnt/data/datasets/medmnistc', \n",
59 | " 'logs_path' : './',\n",
60 | " 'seed' : 42, # training seed (if any) - here it is used in `Evaluator` as id for the output logs\n",
61 | "}\n",
62 | "\n",
63 | "info = INFO[config['dataset']]\n",
64 | "\n",
65 | "config.update({\n",
66 | " 'task': info['task'],\n",
67 | " 'in_channel': info['n_channels'],\n",
68 | " 'num_classes': len(info['label'])\n",
69 | "})\n",
70 | "\n",
71 | "# Define model - we are further training in this example\n",
72 | "# Define model - pretrained weights are evaluated as-is here; load your own trained weights in practice\n",
73 | "model = model.eval()\n",
74 | "\n",
75 | "mean, std = model.default_cfg['mean'], model.default_cfg['std']\n",
76 | "\n",
77 | "# Load clean dataset\n",
78 | "DataClass = getattr(medmnist, info['python_class'])\n",
79 | "\n",
80 | "data_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])\n",
81 | "test_dataset_clean = DataClass(split='test', transform=data_transform, download=False, as_rgb=True, size=224, root=config['medmnist_path']) \n",
82 | "test_loader_clean = DataLoader(test_dataset_clean, batch_size=128, shuffle=False, num_workers=4, persistent_workers=True)\n",
83 | "\n",
84 | "# Init the Evaluator class\n",
85 | "corruptions = CORRUPTIONS_DS[config['dataset']]\n",
86 | "evaluator = Evaluator(dataset_name=config['dataset'],\n",
87 | " true_labels=test_dataset_clean.labels,\n",
88 | " corruption_types=corruptions.keys(),\n",
89 | " output_folder=config['logs_path'],\n",
90 | " architecture=config['architecture'],\n",
91 | " task=config['task'],\n",
92 | " suffix_log=f\"s{config['seed']}\")"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "### Inference "
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 51,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "def evaluate(model, dataloader, task, device = 'cuda:0'):\n",
109 | " \"\"\"\n",
110 | " Evaluate a model on the given test set (clean or corrupted).\n",
111 | "\n",
112 | " :param model: Model to evaluate; it is moved to `device` before inference.\n",
113 | " :param dataloader: DataLoader for the test set.\n",
114 | " :param task: Classification task ('multi-label, binary-class','multi-class', and so on..).\n",
115 | " :param device: Running device (cuda or cpu).\n",
116 | " :return: Predictions (raw probabilities).\n",
117 | " \"\"\"\n",
118 | " \n",
119 | " # Load model and prediction function\n",
120 | " if task == \"multi-label, binary-class\":\n",
121 | " prediction = nn.Sigmoid()\n",
122 | " else:\n",
123 | " prediction = nn.Softmax(dim=1)\n",
124 | "\n",
125 | " model = model.to(device)\n",
126 | "\n",
127 | " # Run the Evaluation\n",
128 | " y_pred = torch.tensor([]).to(device)\n",
129 | "\n",
130 | " with torch.no_grad():\n",
131 | " for images, labels in tqdm(dataloader):\n",
132 | " # Map the data to the available device\n",
133 | " images, labels = images.to(device), labels.to(torch.float32).to(device)\n",
134 | " outputs = model(images)\n",
135 | " outputs = prediction(outputs)\n",
136 | " # Store the predictions\n",
137 | " y_pred = torch.cat((y_pred, deepcopy(outputs)), 0)\n",
138 | "\n",
139 | " return y_pred"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 52,
145 | "metadata": {},
146 | "outputs": [
147 | {
148 | "name": "stderr",
149 | "output_type": "stream",
150 | "text": [
151 | "100%|██████████| 2/2 [00:00<00:00, 11.78it/s]\n"
152 | ]
153 | },
154 | {
155 | "name": "stdout",
156 | "output_type": "stream",
157 | "text": [
158 | "pixelate\n"
159 | ]
160 | },
161 | {
162 | "name": "stderr",
163 | "output_type": "stream",
164 | "text": [
165 | "100%|██████████| 7/7 [00:00<00:00, 20.49it/s]\n"
166 | ]
167 | },
168 | {
169 | "name": "stdout",
170 | "output_type": "stream",
171 | "text": [
172 | "jpeg_compression\n"
173 | ]
174 | },
175 | {
176 | "name": "stderr",
177 | "output_type": "stream",
178 | "text": [
179 | "100%|██████████| 7/7 [00:00<00:00, 20.31it/s]\n"
180 | ]
181 | },
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "speckle_noise\n"
187 | ]
188 | },
189 | {
190 | "name": "stderr",
191 | "output_type": "stream",
192 | "text": [
193 | "100%|██████████| 7/7 [00:00<00:00, 20.65it/s]\n"
194 | ]
195 | },
196 | {
197 | "name": "stdout",
198 | "output_type": "stream",
199 | "text": [
200 | "motion_blur\n"
201 | ]
202 | },
203 | {
204 | "name": "stderr",
205 | "output_type": "stream",
206 | "text": [
207 | "100%|██████████| 7/7 [00:00<00:00, 19.97it/s]\n"
208 | ]
209 | },
210 | {
211 | "name": "stdout",
212 | "output_type": "stream",
213 | "text": [
214 | "brightness_up\n"
215 | ]
216 | },
217 | {
218 | "name": "stderr",
219 | "output_type": "stream",
220 | "text": [
221 | "100%|██████████| 7/7 [00:00<00:00, 20.03it/s]\n"
222 | ]
223 | },
224 | {
225 | "name": "stdout",
226 | "output_type": "stream",
227 | "text": [
228 | "brightness_down\n"
229 | ]
230 | },
231 | {
232 | "name": "stderr",
233 | "output_type": "stream",
234 | "text": [
235 | "100%|██████████| 7/7 [00:00<00:00, 20.29it/s]\n"
236 | ]
237 | },
238 | {
239 | "name": "stdout",
240 | "output_type": "stream",
241 | "text": [
242 | "contrast_down\n"
243 | ]
244 | },
245 | {
246 | "name": "stderr",
247 | "output_type": "stream",
248 | "text": [
249 | "100%|██████████| 7/7 [00:00<00:00, 20.43it/s]"
250 | ]
251 | },
252 | {
253 | "name": "stdout",
254 | "output_type": "stream",
255 | "text": [
256 | "Logs stored at `./breastmnist_resnet18.tv_in1k_s42.json`\n"
257 | ]
258 | },
259 | {
260 | "name": "stderr",
261 | "output_type": "stream",
262 | "text": [
263 | "\n"
264 | ]
265 | }
266 | ],
267 | "source": [
268 | "# Evaluate clean performance\n",
269 | "y_pred = evaluate(model, test_loader_clean, config['task'])\n",
270 | "evaluator.evaluate_clean(y_pred.cpu().numpy())\n",
271 | "\n",
272 | "# Iterate over the designed corruptions.\n",
273 | "for corruption in corruptions.keys():\n",
274 | "\n",
275 | " print(corruption)\n",
276 | " \n",
277 | " # Load the corrupted test set, according to the selected corruption\n",
278 | " corrupted_test_test = CorruptedMedMNIST(\n",
279 | " dataset_name = config['dataset'], \n",
280 | " corruption = corruption,\n",
281 | " root = config['medmnistc_path'],\n",
282 | " as_rgb = test_dataset_clean.as_rgb,\n",
283 | " mmap_mode='r',\n",
284 | " norm_mean = mean,\n",
285 | " norm_std = std\n",
286 | " )\n",
287 | " \n",
288 | " # Get dataloader\n",
289 | " test_loader = DataLoader(corrupted_test_test, batch_size=128, shuffle=False, num_workers=4, persistent_workers=True)\n",
290 | "\n",
291 | " # Evaluate\n",
292 | " y_pred = evaluate(model, test_loader, config['task']) \n",
293 | "\n",
294 | " # Calculate the error\n",
295 | " evaluator.evaluate(y_pred.cpu().numpy(), corruption)\n",
296 | "\n",
297 | "# Create a json file containing the results\n",
298 | "evaluator.dump_summary()"
299 | ]
300 | }
301 | ],
302 | "metadata": {
303 | "kernelspec": {
304 | "display_name": "medmnistc",
305 | "language": "python",
306 | "name": "python3"
307 | },
308 | "language_info": {
309 | "codemirror_mode": {
310 | "name": "ipython",
311 | "version": 3
312 | },
313 | "file_extension": ".py",
314 | "mimetype": "text/x-python",
315 | "name": "python",
316 | "nbconvert_exporter": "python",
317 | "pygments_lexer": "ipython3",
318 | "version": "3.11.7"
319 | }
320 | },
321 | "nbformat": 4,
322 | "nbformat_minor": 2
323 | }
324 |
--------------------------------------------------------------------------------
/medmnistc/visualizer.py:
--------------------------------------------------------------------------------
1 | from medmnistc.corruptions.registry import DATASET_RGB, CORRUPTIONS_DS
2 |
3 | from skimage.util import montage as skimage_montage
4 | from PIL import Image
5 |
6 | import numpy as np
7 | import random
8 | import cv2
9 | import os
10 |
11 |
class Visualizer:
    """Plot examples of the MedMNIST-C corruptions for a given dataset."""

    # Number of severity levels stored per corruption. Corrupted test sets are
    # indexed as `idx_image + sev_idx * num_clean_images`, i.e. the severities
    # are stacked one after the other.
    _NUM_SEVERITIES = 5

    def __init__(self,
                 medmnistc_path : str,
                 medmnist_path : str,
                 output_path : str):
        """
        Class used to plot examples of the selected corruptions.

        :param medmnistc_path: Root path of the corrupted datasets.
        :param medmnist_path: Root path of the clean datasets.
        :param output_path: Root path of the generated visualizations (created if missing).
        """
        self.medmnist_path = medmnist_path
        self.medmnistc_path = medmnistc_path
        self.output_path = output_path

        self.supported_datasets = [
            'bloodmnist', 'breastmnist', 'chestmnist', 'dermamnist',
            'octmnist', 'organamnist', 'organcmnist', 'organsmnist',
            'pathmnist', 'pneumoniamnist', 'retinamnist', 'tissuemnist'
        ]

        # Annotation hyperparameters
        self.font = cv2.FONT_HERSHEY_SIMPLEX
        self.font_scale = 0.5
        self.thickness = 1
        self.text_offset_x = 30
        self.text_offset_y = 30
        self.rect_offset = 5

        # Create folder
        os.makedirs(self.output_path, exist_ok=True)

    def plot_extended(self,
                      dataset_name : str = None,
                      idx_image : int = None):
        """
        Plot an image grid (N, 6) where:
            - N is the number of the designed corruptions (one row each)
            - the first column is the clean image annotated with the corruption name
            - the remaining 5 columns are the 5 severity levels

        :param dataset_name: Name of the dataset to corrupt.
            Options: {'bloodmnist', 'breastmnist', 'chestmnist', 'dermamnist',
                      'octmnist', 'organamnist', 'organcmnist', 'organsmnist',
                      'pathmnist', 'pneumoniamnist', 'retinamnist', 'tissuemnist'}
        :param idx_image: Index of the selected image to corrupt and visualize.
            If None, a random index will be chosen.
        :raises IndexError: If `idx_image` is outside the clean test set.
        """
        assert dataset_name in self.supported_datasets, f"Dataset not found. Please choose one among : {self.supported_datasets}"

        output_folder = os.path.join(self.output_path,'extended')
        os.makedirs(output_folder, exist_ok=True)

        # Load clean images
        clean_test_images = self._load_test_images(
            os.path.join(self.medmnist_path, f'{dataset_name}_224.npz'))
        num_images = len(clean_test_images)

        # Select a random image only when none was requested (0 is a valid index)
        idx_image = self._resolve_index(idx_image, num_images)

        # Check whether the dataset is RGB or not
        n_channels = 3 if DATASET_RGB[dataset_name] else 1

        # The grid is sized from the same registry we iterate over, so stray
        # files in the dataset folder cannot desynchronize rows and images.
        corruptions = list(CORRUPTIONS_DS[dataset_name].keys())
        num_rows, num_cols = len(corruptions), 1 + self._NUM_SEVERITIES

        images = []
        # Iterate over corruptions (ROWS)
        for corruption in corruptions:
            test_images = self._load_test_images(
                os.path.join(self.medmnistc_path, dataset_name, f'{corruption}.npz'))

            # First column: clean image annotated with the corruption name
            corruption_name = corruption.replace("_", " ")
            images.append(self._annotate_img(clean_test_images[idx_image], corruption_name))

            # Remaining columns: one per severity level
            for sev_idx in range(self._NUM_SEVERITIES):
                images.append(test_images[idx_image + sev_idx * num_images])

        filename = f'{dataset_name}_id{idx_image}.png'
        self._save_montage(images, (num_rows, num_cols), n_channels,
                           os.path.join(output_folder, filename))

    def plot_one_severity(self,
                          dataset_name: str = None,
                          idx_image:int = None,
                          severity: int = 3,
                          max_per_row: int = -1):
        """
        Plot an image along with all its corruptions in a row, with a user-specified severity.

        :param dataset_name: Name of the dataset to corrupt.
            Options: {'bloodmnist', 'breastmnist', 'chestmnist', 'dermamnist',
                      'octmnist', 'organamnist', 'organcmnist', 'organsmnist',
                      'pathmnist', 'pneumoniamnist', 'retinamnist', 'tissuemnist'}
        :param idx_image: Index of the selected image to corrupt and visualize.
            If None, a random index will be chosen.
        :param severity: Severity of the corruptions (1-5). This will be applied to all the corruptions.
        :param max_per_row: Maximum number of images (clean one included) to show in a row.
            If the clean image plus all corruptions exceed `max_per_row`, multiple images
            are stored (`_pt1`, `_pt2`, ...), each starting with the clean image.
            A value <= 0 puts everything in a single row.
        :raises ValueError: If `severity` is out of range or `max_per_row == 1`.
        :raises IndexError: If `idx_image` is outside the clean test set.
        """
        assert dataset_name in self.supported_datasets, f"Dataset not found. Please choose one among : {self.supported_datasets}"

        # A severity < 1 would silently wrap around via negative indexing.
        if not 1 <= severity <= self._NUM_SEVERITIES:
            raise ValueError(f"severity must be in [1, {self._NUM_SEVERITIES}], got {severity}")
        # One slot per row is always taken by the clean image.
        if max_per_row == 1:
            raise ValueError("max_per_row must be >= 2, or <= 0 to disable splitting")

        # Init output folder
        output_folder = os.path.join(self.output_path,'one_severity')
        os.makedirs(output_folder, exist_ok=True)

        # Select the number of channels
        n_channels = 3 if DATASET_RGB[dataset_name] else 1

        # Load the clean test set
        clean_test_images = self._load_test_images(
            os.path.join(self.medmnist_path, f'{dataset_name}_224.npz'))
        num_images = len(clean_test_images)

        idx_image = self._resolve_index(idx_image, num_images)
        clean_image = clean_test_images[idx_image]

        # Collect the annotated corrupted versions at the requested severity
        corrupted_images = []
        for corruption in CORRUPTIONS_DS[dataset_name]:
            test_images = self._load_test_images(
                os.path.join(self.medmnistc_path, dataset_name, f'{corruption}.npz'))
            idx_corr = idx_image + (severity - 1) * num_images
            corrupted_images.append(
                self._annotate_img(test_images[idx_corr], corruption.replace("_", " ")))

        base_name = f'{dataset_name}_id{idx_image}_sev{severity}'
        total = 1 + len(corrupted_images)  # +1 because of the clean one

        if max_per_row <= 0:
            # Everything in a single row
            self._save_montage([clean_image] + corrupted_images, (1, total), n_channels,
                               os.path.join(output_folder, f'{base_name}.png'))
        elif total <= max_per_row:
            # Fits in one row; keep the requested width (blank cells are filled)
            self._save_montage([clean_image] + corrupted_images, (1, max_per_row), n_channels,
                               os.path.join(output_folder, f'{base_name}.png'))
        else:
            # Split into parts, each starting with the clean image
            for pt, chunk in enumerate(self._chunk(corrupted_images, max_per_row - 1)):
                self._save_montage([clean_image] + chunk, (1, max_per_row), n_channels,
                                   os.path.join(output_folder, f'{base_name}_pt{pt+1}.png'))

    @staticmethod
    def _load_test_images(path):
        """Return the `test_images` array of the `.npz` at `path`, closing the file afterwards."""
        with np.load(path) as data:
            return data['test_images']

    @staticmethod
    def _resolve_index(idx_image, num_images):
        """Return `idx_image`, or a random valid index when it is None.

        :raises IndexError: If `idx_image` is not in [0, num_images - 1].
        """
        if idx_image is None:
            return random.randint(0, num_images - 1)
        if not 0 <= idx_image < num_images:
            raise IndexError(f"idx_image must be in [0, {num_images - 1}], got {idx_image}")
        return idx_image

    @staticmethod
    def _chunk(items, size):
        """Split `items` into consecutive lists of at most `size` elements (no empty trailing chunk)."""
        return [items[i:i + size] for i in range(0, len(items), size)]

    def _save_montage(self, images, grid_shape, n_channels, img_path):
        """Tile `images` into a `grid_shape` montage (white fill) and save it to `img_path`."""
        montage_arr = skimage_montage(
            images, channel_axis=3 if n_channels == 3 else None,
            grid_shape=grid_shape,
            fill=(255,255,255)
        )

        print(f'Image stored at : {img_path}')

        Image.fromarray(montage_arr).save(img_path)

    def _annotate_img(self, img, text):
        """Return a copy of `img` annotated with the name of the corruption.

        A white text over a filled black box is drawn with the text origin at
        (text_offset_x, text_offset_y); the box is padded by `rect_offset`
        pixels around the measured text size. The input array is not modified.

        :param img: Image to annotate (np.uint8).
        :param text: Text to add on the image.
        :return: Annotated copy of the image.
        """
        # Work on a C-contiguous copy so the caller's array is left untouched.
        img = np.array(img, copy=True, order='C')

        # Get width and height of the text box
        (text_width, text_height), _ = cv2.getTextSize(text, self.font, self.font_scale, self.thickness)

        # Define box coordinates
        box_coords = (
            (self.text_offset_x - self.rect_offset, self.text_offset_y + self.rect_offset),
            (self.text_offset_x + text_width + self.rect_offset, self.text_offset_y - text_height - self.rect_offset)
        )

        # Black background & white text
        cv2.rectangle(img, box_coords[0], box_coords[1], (0, 0, 0), cv2.FILLED)
        cv2.putText(img,
                    text,
                    (self.text_offset_x, self.text_offset_y),
                    self.font,
                    self.font_scale,
                    (255, 255, 255),
                    self.thickness,
                    cv2.LINE_AA)

        return img
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2020-2023 MedMNIST Team
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/medmnistc/corruptions/registry.py:
--------------------------------------------------------------------------------
1 | from .noise import GaussianNoise, ImpulseNoise, SpeckleNoise, ShotNoise
2 | from .compression import JPEGCompression, Pixelate
3 | from .filter import MotionBlur, DefocusBlur, ZoomBlur, GaussianBlur
4 | from .enhance import Brightness, Contrast, Saturate, GammaCorrection
5 | from .microscopy import Bubble, StainDeposit, BlackCorner, Characters
6 |
7 | import numpy as np
8 |
9 |
10 | CORRUPTIONS_DS = {
11 |
12 | 'pathmnist' : {
13 | 'pixelate': Pixelate(severity_params=[0.8, 0.6, 0.40, 0.30, 0.25]),
14 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
15 | 'defocus_blur' : DefocusBlur(severity_params=[(3, 0.1), (4, 0.1), (5, 0.2), (6,0.2), (7, 0.3)]),
16 | 'motion_blur' : MotionBlur(severity_params=[(5,5), (10, 5), (15, 5), (15, 8), (15, 12)]),
17 | 'brightness_up' : Brightness(severity_params=[1.1, 1.15, 1.2, 1.22, 1.25]),
18 | 'brightness_down' : Brightness(severity_params=[0.85, 0.80, 0.75, 0.72, 0.70]),
19 | 'contrast_up' : Contrast(severity_params=[1.1, 1.2, 1.3, 1.4, 1.6]),
20 | 'contrast_down' : Contrast(severity_params=[0.8, 0.7, 0.6, 0.55, 0.5]),
21 | 'saturate' : Saturate(severity_params=[0.05, 0.10, 0.15, 0.20, 0.25]),
22 | 'stain_deposit' : StainDeposit(severity_params=[1,2,3,4,5]),
23 | 'bubble' : Bubble(severity_params=[(7,15),(10,15),(12,15),(15,20),(17,25)])
24 | },
25 |
26 |
27 | 'bloodmnist' : {
28 | 'pixelate': Pixelate(severity_params=[0.6, 0.5, 0.40, 0.30, 0.25]),
29 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
30 | 'defocus_blur' : DefocusBlur(severity_params=[(2, 0.01), (3, 0.1), (4,0.1), (5,0.1), (6, 0.1)]),
31 | 'motion_blur' : MotionBlur(severity_params=[(3,3), (5,5), (10, 5), (10,7), (10, 9)]),
32 | 'brightness_up' : Brightness(severity_params=[1.1, 1.2, 1.3, 1.35, 1.4]),
33 | 'brightness_down' : Brightness(severity_params=[0.9, 0.8, 0.7, 0.6, 0.5]),
34 | 'contrast_up' : Contrast(severity_params=[1.1, 1.15, 1.2, 1.25, 1.3]),
35 | 'contrast_down' : Contrast(severity_params=[0.9, 0.8, 0.7, 0.6, 0.5]),
36 | 'saturate' : Saturate(severity_params=[0.05, 0.10, 0.15, 0.17, 0.20]),
37 | 'stain_deposit' : StainDeposit(severity_params=[1,2,3,3,3]),
38 | 'bubble' : Bubble(severity_params=[(5,10),(7,10),(10,10),(12,12),(15,12)])
39 | },
40 |
41 |
42 | 'dermamnist' : {
43 | 'pixelate': Pixelate(severity_params=[0.7, 0.5, 0.40, 0.30, 0.25]),
44 | 'jpeg_compression' : JPEGCompression(severity_params=[30, 20, 15, 10, 7]),
45 | 'gaussian_noise' : GaussianNoise(severity_params=[0.04, .08, .12, 0.18, 0.26]),
46 | 'speckle_noise' : SpeckleNoise(severity_params=[0.05, 0.15, 0.2, 0.35, 0.45]),
47 | 'impulse_noise' : ImpulseNoise(severity_params=[0.01, 0.03, 0.06, 0.09, 0.17]),
48 | 'shot_noise' : ShotNoise(severity_params=[60, 25, 18, 10, 5]),
49 | 'defocus_blur' : DefocusBlur(severity_params=[(4, 0.1), (5, 0.2), (6, 0.3), (7, 0.4), (8,0.5)]),
50 | 'motion_blur' : MotionBlur(severity_params=[(10, 5), (15, 5), (15, 8), (15, 12), (20, 15)]),
51 | 'zoom_blur' : ZoomBlur(severity_params=[
52 | np.arange(1, 1.11, 0.01),
53 | np.arange(1, 1.16, 0.01),
54 | np.arange(1, 1.21, 0.02),
55 | np.arange(1, 1.26, 0.02),
56 | np.arange(1, 1.31, 0.03)
57 | ]),
58 | 'brightness_up' : Brightness(severity_params=[1.1, 1.2, 1.3, 1.4, 1.5]),
59 | 'brightness_down' : Brightness(severity_params=[0.9, 0.8, 0.7, 0.6, 0.5]),
60 | 'contrast_up' : Contrast(severity_params=[1.1, 1.2, 1.3, 1.4, 1.6]),
61 | 'contrast_down' : Contrast(severity_params=[0.8, 0.7, 0.6, 0.5, 0.4]),
62 | 'black_corner' : BlackCorner(severity_params=[1.10, 1.05, 1.00, 0.90, 0.95]),
63 | 'characters' : Characters(severity_params=[(1,6,0.14),(2,7,0.15),(3,8,0.16),(4,9,0.17),(6,10,0.18)])
64 | },
65 |
66 |
67 | 'retinamnist' : {
68 | 'pixelate': Pixelate(severity_params=[0.8, 0.60, 0.50, 0.40, 0.35]),
69 | 'jpeg_compression' : JPEGCompression(severity_params=[30, 25, 20, 10, 5]),
70 | 'gaussian_noise' : GaussianNoise(severity_params=[0.04, 0.08, 0.12, 0.16, 0.20]),
71 | 'speckle_noise' : SpeckleNoise(severity_params=[0.10, 0.15, 0.20, 0.25, 0.30]),
72 | 'defocus_blur' : DefocusBlur(severity_params=[(4, 0.1), (5, 0.2), (6, 0.3), (7, 0.4), (8,0.5), (9,0.6)]),
73 | 'motion_blur' : MotionBlur(severity_params=[(8, 5), (15, 5), (15, 8), (15, 12), (20, 15)]),
74 | 'brightness_down' : Brightness(severity_params=[0.9, 0.8, 0.7, 0.6, 0.5]),
75 | 'contrast_down' : Contrast(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
76 | },
77 |
78 |
79 | 'tissuemnist' : {
80 | 'pixelate': Pixelate(severity_params=[0.40, 0.30, 0.20, 0.15, 0.10]),
81 | 'jpeg_compression' : JPEGCompression(severity_params=[25, 20, 15, 10, 7]),
82 | 'impulse_noise' : ImpulseNoise(severity_params=[0.01, 0.015, 0.02, 0.025, 0.03]),
83 | 'gaussian_blur' : GaussianBlur(severity_params=[13, 15, 17, 21, 25]),
84 | 'brightness_up' : Brightness(severity_params=[1.3, 1.4, 1.5, 1.6, 1.7]),
85 | 'brightness_down' : Brightness(severity_params=[0.8, 0.7, 0.6, 0.5, 0.4]),
86 | 'contrast_up' : Contrast(severity_params=[1.1, 1.2, 1.3, 1.4, 1.6]),
87 | 'contrast_down' : Contrast(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
88 | },
89 |
90 |
91 | 'octmnist' : {
92 | 'pixelate': Pixelate(severity_params=[0.30, 0.25, 0.20, 0.15, 0.10]), #
93 | 'jpeg_compression' : JPEGCompression(severity_params=[30, 15, 10, 7, 5]),
94 | 'speckle_noise' : SpeckleNoise(severity_params=[0.15, 0.30, 0.40, 0.50, 0.60]),
95 | 'defocus_blur' : DefocusBlur(severity_params=[(0.5, 0.6), (1, 0.5), (1.5, 0.1), (2.0,0.5), (2.5,0.1)]),
96 | 'motion_blur' : MotionBlur(severity_params=[(10, 3), (15, 5), (15, 8), (15, 12), (20, 15)]),
97 | 'contrast_down' : Contrast(severity_params=[0.6, 0.4, 0.3, 0.2, 0.15])
98 | },
99 |
100 |
101 | 'breastmnist' : {
102 | 'pixelate': Pixelate(severity_params=[0.30, 0.25, 0.20, 0.15, 0.10]),
103 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
104 | 'speckle_noise' : SpeckleNoise(severity_params=[0.10, 0.15, 0.20, 0.25, 0.30]),
105 | 'motion_blur' : MotionBlur(severity_params=[(5,5), (9, 7), (9,10), (13, 10), (17, 12)]),
106 | 'brightness_up' : Brightness(severity_params=[1.4, 1.5, 1.6, 1.8, 2.0]),
107 | 'brightness_down' : Brightness(severity_params=[0.55, 0.5, 0.45, 0.4, 0.3]),
108 | 'contrast_down' : Contrast(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
109 | },
110 |
111 |
112 | 'chestmnist' : {
113 | 'pixelate': Pixelate(severity_params=[0.30, 0.25, 0.20, 0.15, 0.10]),
114 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
115 | 'gaussian_noise' : GaussianNoise(severity_params=[0.04, .08, .12, 0.18, 0.26]),
116 | 'speckle_noise' : SpeckleNoise(severity_params=[0.05, 0.15, 0.2, 0.35, 0.45]),
117 | 'impulse_noise' : ImpulseNoise(severity_params=[0.01, 0.03, 0.06, 0.09, 0.17]),
118 | 'shot_noise' : ShotNoise(severity_params=[60, 25, 18, 10, 5]),
119 | 'gaussian_blur' : GaussianBlur(severity_params=[3, 5, 7, 9, 11, 13]),
120 | 'brightness_up' : Brightness(severity_params=[1.1, 1.2, 1.3, 1.4, 1.5]),
121 | 'brightness_down' : Brightness(severity_params=[0.9, 0.8, 0.7, 0.6, 0.5]),
122 | 'contrast_up' : Contrast(severity_params=[1.1, 1.2, 1.3, 1.4, 1.6]),
123 | 'contrast_down' : Contrast(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
124 | 'gamma_corr_up' : GammaCorrection(severity_params=[1.1, 1.2, 1.3, 1.4, 1.6]),
125 | 'gamma_corr_down' : GammaCorrection(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
126 | },
127 |
128 |
129 | 'pneumoniamnist' : {
130 | 'pixelate': Pixelate(severity_params=[0.8, 0.7, 0.6, 0.5, 0.40]),
131 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
132 | 'gaussian_noise' : GaussianNoise(severity_params=[0.04, 0.05, 0.06, 0.07, 0.08]),
133 | 'speckle_noise' : SpeckleNoise(severity_params=[0.05, 0.07, 0.10, 0.15, 0.20]),
134 | 'impulse_noise' : ImpulseNoise(severity_params=[0.005, 0.01, 0.013, 0.017, 0.02]),
135 | 'shot_noise' : ShotNoise(severity_params=[300, 200, 150, 100, 80]),
136 | 'gaussian_blur' : GaussianBlur(severity_params=[3, 5, 7, 9, 11, 13]),
137 | 'brightness_up' : Brightness(severity_params=[1.1, 1.2, 1.3, 1.4, 1.5]),
138 | 'brightness_down' : Brightness(severity_params=[0.9, 0.8, 0.7, 0.6, 0.5]),
139 | 'contrast_up' : Contrast(severity_params=[1.1, 1.2, 1.3, 1.4, 1.6]),
140 | 'contrast_down' : Contrast(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
141 | 'gamma_corr_up' : GammaCorrection(severity_params=[1.1, 1.2, 1.3, 1.4, 1.6]),
142 | 'gamma_corr_down' : GammaCorrection(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
143 | },
144 |
145 |
146 | 'organamnist' : {
147 | 'pixelate': Pixelate(severity_params=[0.7, 0.6, 0.5, 0.40, 0.35]),
148 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
149 | 'gaussian_noise' : GaussianNoise(severity_params=[0.04, 0.08, 0.12, 0.16, 0.20]),
150 | 'speckle_noise' : SpeckleNoise(severity_params=[0.05, 0.10, 0.20, 0.30, 0.40]),
151 | 'impulse_noise' : ImpulseNoise(severity_params=[0.01, 0.02, 0.03, 0.05, 0.08]),
152 | 'shot_noise' : ShotNoise(severity_params=[200, 100, 50, 25, 15]),
153 | 'gaussian_blur' : GaussianBlur(severity_params=[11, 13, 15, 17, 21]),
154 | 'brightness_up' : Brightness(severity_params=[1.2, 1.3, 1.4, 1.5, 1.6]),
155 | 'brightness_down' : Brightness(severity_params=[0.8, 0.75, 0.7, 0.65, 0.60]),
156 | 'contrast_up' : Contrast(severity_params=[1.3, 1.4, 1.6, 1.7, 1.8]),
157 | 'contrast_down' : Contrast(severity_params=[0.8, 0.7, 0.6, 0.55, 0.5]),
158 | 'gamma_corr_up' : GammaCorrection(severity_params=[1.3, 1.4, 1.6, 1.8, 2.0]),
159 | 'gamma_corr_down' : GammaCorrection(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
160 | },
161 |
162 |
163 | 'organcmnist' : {
164 | 'pixelate': Pixelate(severity_params=[0.7, 0.6, 0.5, 0.40, 0.35]),
165 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
166 | 'gaussian_noise' : GaussianNoise(severity_params=[0.04, 0.08, 0.12, 0.16, 0.20]),
167 | 'speckle_noise' : SpeckleNoise(severity_params=[0.05, 0.10, 0.20, 0.30, 0.40]),
168 | 'impulse_noise' : ImpulseNoise(severity_params=[0.01, 0.02, 0.03, 0.05, 0.08]),
169 | 'shot_noise' : ShotNoise(severity_params=[200, 100, 50, 25, 15]),
170 | 'gaussian_blur' : GaussianBlur(severity_params=[11, 13, 15, 17, 21]),
171 | 'brightness_up' : Brightness(severity_params=[1.2, 1.3, 1.4, 1.5, 1.6]),
172 | 'brightness_down' : Brightness(severity_params=[0.8, 0.75, 0.7, 0.65, 0.60]),
173 | 'contrast_up' : Contrast(severity_params=[1.3, 1.4, 1.6, 1.7, 1.8]),
174 | 'contrast_down' : Contrast(severity_params=[0.8, 0.7, 0.6, 0.55, 0.5]),
175 | 'gamma_corr_up' : GammaCorrection(severity_params=[1.3, 1.4, 1.6, 1.8, 2.0]),
176 | 'gamma_corr_down' : GammaCorrection(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
177 | },
178 |
179 |
180 | 'organsmnist' : {
181 | 'pixelate': Pixelate(severity_params=[0.7, 0.6, 0.5, 0.40, 0.35]),
182 | 'jpeg_compression' : JPEGCompression(severity_params=[50, 30, 15, 10, 7]),
183 | 'gaussian_noise' : GaussianNoise(severity_params=[0.04, 0.08, 0.12, 0.16, 0.20]),
184 | 'speckle_noise' : SpeckleNoise(severity_params=[0.05, 0.10, 0.20, 0.30, 0.40]),
185 | 'impulse_noise' : ImpulseNoise(severity_params=[0.01, 0.02, 0.03, 0.05, 0.08]),
186 | 'shot_noise' : ShotNoise(severity_params=[200, 100, 50, 25, 15]),
187 | 'gaussian_blur' : GaussianBlur(severity_params=[11, 13, 15, 17, 21]),
188 | 'brightness_up' : Brightness(severity_params=[1.2, 1.3, 1.4, 1.5, 1.6]),
189 | 'brightness_down' : Brightness(severity_params=[0.8, 0.75, 0.7, 0.65, 0.60]),
190 | 'contrast_up' : Contrast(severity_params=[1.3, 1.4, 1.6, 1.7, 1.8]),
191 | 'contrast_down' : Contrast(severity_params=[0.8, 0.7, 0.6, 0.55, 0.5]),
192 | 'gamma_corr_up' : GammaCorrection(severity_params=[1.3, 1.4, 1.6, 1.8, 2.0]),
193 | 'gamma_corr_down' : GammaCorrection(severity_params=[0.9, 0.8, 0.7, 0.6, 0.4]),
194 | },
195 | }
196 |
197 |
# Corruption families. Each key names a family and maps to the list of
# corruption keys (the same names used inside CORRUPTIONS_DS) that belong
# to it. Not every dataset defines every corruption listed here — e.g.
# 'octmnist' has no 'impulse_noise' entry — so look names up defensively.
CORRUPTIONS_DS_FOLDS = {
    'digital': [
        'pixelate',
        'jpeg_compression',
    ],
    'noise': [
        'gaussian_noise',
        'speckle_noise',
        'impulse_noise',
        'shot_noise',
    ],
    'blur': [
        'defocus_blur',
        'motion_blur',
        'zoom_blur',
        'gaussian_blur',
    ],
    'color': [
        'brightness_up',
        'brightness_down',
        'contrast_up',
        'contrast_down',
        'saturate',
    ],
    'task-specific': [
        'stain_deposit',
        'bubble',
        'black_corner',
        'characters',
        'gamma_corr_up',
        'gamma_corr_down',
    ],
}
205 |
206 |
# Per-dataset colour flag: True for the datasets treated as RGB, False for
# the rest (presumably single-channel grayscale — confirm against the
# loading code). Keys keep the original alphabetical insertion order.
DATASET_RGB = {
    dataset_name: dataset_name in ('bloodmnist', 'dermamnist', 'pathmnist', 'retinamnist')
    for dataset_name in (
        'bloodmnist',
        'breastmnist',
        'chestmnist',
        'dermamnist',
        'octmnist',
        'organamnist',
        'organcmnist',
        'organsmnist',
        'pathmnist',
        'pneumoniamnist',
        'retinamnist',
        'tissuemnist',
    )
}
--------------------------------------------------------------------------------
/assets/examples/augment.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from medmnistc.augmentation import AugMedMNISTC\n",
10 | "from medmnistc.corruptions.registry import CORRUPTIONS_DS\n",
11 | "\n",
12 | "from PIL import Image\n",
13 | "\n",
14 | "import torchvision.transforms as transforms\n",
15 | "import numpy as np\n",
16 | "import os"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "### Augmentation"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 9,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "name": "stdout",
33 | "output_type": "stream",
34 | "text": [
35 | "speckle_noise\n"
36 | ]
37 | },
38 | {
39 | "data": {
40 | "image/jpeg": "",
41 | "image/png": "",
42 | "text/plain": [
43 | ""
44 | ]
45 | },
46 | "execution_count": 9,
47 | "metadata": {},
48 | "output_type": "execute_result"
49 | }
50 | ],
51 | "source": [
52 | "dataset = \"breastmnist\"\n",
53 | "medmnist_path = \"/mnt/data/datasets/medmnist\" # PATH TO THE CLEAN IMAGES\n",
54 | "\n",
55 | "train_corruptions = CORRUPTIONS_DS[dataset]\n",
56 | "path = os.path.join(medmnist_path,f'{dataset}_224.npz')\n",
57 | "train_images = np.load(path)['train_images']\n",
58 | "\n",
59 | "aug = AugMedMNISTC(train_corruptions, verbose=True) \n",
60 | "\n",
61 | "img = Image.fromarray(train_images[0])\n",
62 | "Image.fromarray(aug(img))"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "## Integrate into transforms.Compose"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 24,
75 | "metadata": {},
76 | "outputs": [
77 | {
78 | "data": {
79 | "text/plain": [
80 | "tensor([[[ 2.2489, 2.1119, 1.5810, ..., 1.1872, 2.2489, 2.1119],\n",
81 | " [ 0.6563, 1.8550, 2.2489, ..., 1.8550, 2.2489, 1.4440],\n",
82 | " [ 1.4440, 1.3242, 1.7180, ..., 1.4440, 2.2489, 2.2489],\n",
83 | " ...,\n",
84 | " [ 1.4440, 1.4440, 0.6563, ..., 1.3242, 0.7933, -0.0116],\n",
85 | " [ 1.3242, 1.3242, 0.3994, ..., 1.1872, 1.1872, 1.1872],\n",
86 | " [ 2.2489, 0.5193, 1.4440, ..., 2.2489, 2.1119, 1.3242]],\n",
87 | "\n",
88 | " [[ 1.4832, 1.7458, -0.5476, ..., 0.5378, -0.0049, 1.4832],\n",
89 | " [ 0.2577, 0.2577, 1.7458, ..., -0.2850, -0.1450, -0.5476],\n",
90 | " [ 1.0805, 2.0259, 0.6604, ..., 0.3978, 0.2577, 0.5378],\n",
91 | " ...,\n",
92 | " [ 0.6604, 0.9405, 0.2577, ..., 1.2031, 0.6604, 1.2031],\n",
93 | " [-0.2850, 0.2577, 2.0259, ..., 1.7458, -0.2850, 0.6604],\n",
94 | " [ 0.8004, 0.8004, 0.5378, ..., -0.5476, 0.5378, -0.2850]],\n",
95 | "\n",
96 | " [[ 0.3393, 1.5594, 0.7576, ..., 0.7576, 1.6988, 2.3611],\n",
97 | " [ 2.5006, 1.1585, 1.5594, ..., 0.7576, 2.6400, 1.1585],\n",
98 | " [ 1.0191, 1.0191, 1.0191, ..., 0.3393, 1.4200, 0.8797],\n",
99 | " ...,\n",
100 | " [ 0.8797, 1.1585, 0.3393, ..., 1.4200, 0.4788, 0.3393],\n",
101 | " [ 1.0191, 0.2173, 0.7576, ..., 0.0779, -0.2010, 1.1585],\n",
102 | " [ 2.2391, 0.3393, 0.3393, ..., -0.0615, -0.4624, 0.6182]]])"
103 | ]
104 | },
105 | "execution_count": 24,
106 | "metadata": {},
107 | "output_type": "execute_result"
108 | }
109 | ],
110 | "source": [
111 | "dataset = \"dermamnist\"\n",
112 | "medmnist_path = \"/mnt/data/datasets/medmnist\" # PATH TO THE CLEAN IMAGES\n",
113 | "\n",
114 | "train_corruptions = CORRUPTIONS_DS[dataset]\n",
115 | "path = os.path.join(medmnist_path,f'{dataset}_224.npz')\n",
116 | "train_images = np.load(path)['train_images']\n",
117 | "\n",
118 | "MEAN = [0.485, 0.456, 0.406]\n",
119 | "STD = [0.229, 0.224, 0.225]\n",
120 | "\n",
121 | "aug_compose = transforms.Compose([\n",
122 | " AugMedMNISTC(train_corruptions),\n",
123 | " transforms.ToTensor(),\n",
124 | " transforms.Normalize(mean=MEAN, std=STD)\n",
125 | "])\n",
126 | "\n",
127 | "img = Image.fromarray(train_images[0])\n",
128 | "aug_compose(img) # As before, but we have a normalized tensor now."
129 | ]
130 | }
131 | ],
132 | "metadata": {
133 | "kernelspec": {
134 | "display_name": "medmnistc",
135 | "language": "python",
136 | "name": "python3"
137 | },
138 | "language_info": {
139 | "codemirror_mode": {
140 | "name": "ipython",
141 | "version": 3
142 | },
143 | "file_extension": ".py",
144 | "mimetype": "text/x-python",
145 | "name": "python",
146 | "nbconvert_exporter": "python",
147 | "pygments_lexer": "ipython3",
148 | "version": "3.11.7"
149 | }
150 | },
151 | "nbformat": 4,
152 | "nbformat_minor": 2
153 | }
154 |
--------------------------------------------------------------------------------