├── tests ├── __init__.py └── ecgmentations │ ├── __init__.py │ ├── core │ ├── __init__.py │ ├── test_transformation.py │ ├── test_modification.py │ └── test_composition.py │ └── augmentations │ ├── __init__.py │ ├── time │ ├── __init__.py │ ├── test_transformations.py │ └── test_functional.py │ ├── test_functional.py │ ├── test_misc.py │ └── test_transformations.py ├── pytest.ini ├── images └── preview.png ├── ecgmentations ├── __version__.py ├── __init__.py ├── augmentations │ ├── pulse │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transformations.py │ ├── filter │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transformations.py │ ├── time │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transformations.py │ ├── __init__.py │ ├── functional.py │ ├── misc.py │ └── transformations.py └── core │ ├── __init__.py │ ├── enum.py │ ├── constants.py │ ├── utils.py │ ├── transformation.py │ ├── modification.py │ ├── augmentation.py │ └── composition.py ├── .pre-commit-config.yaml ├── LICENSE ├── setup.py ├── README.md └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/ecgmentations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/ecgmentations/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/ecgmentations/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/ecgmentations/augmentations/time/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | core: Run tests for core subpackage 4 | -------------------------------------------------------------------------------- /images/preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rostepifanov/ecgmentations/HEAD/images/preview.png -------------------------------------------------------------------------------- /ecgmentations/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 0, 9)  # version components, joined into __version__ below 2 | 3 | __version__ = '.'.join(map(str, VERSION))  # dotted string, e.g. '0.0.9'; setup.py's get_version() exec()s this file to read it 4 | -------------------------------------------------------------------------------- /ecgmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from ecgmentations.__version__ import __version__ 2 | from ecgmentations.core import * 3 | from ecgmentations.augmentations import * -------------------------------------------------------------------------------- /ecgmentations/augmentations/pulse/__init__.py: -------------------------------------------------------------------------------- 1 | from ecgmentations.augmentations.pulse.transformations import ( 2 | SinePulse, 3 | PowerlineNoise, 4 | RespirationNoise, 5 | SquarePulse, 6 | ) -------------------------------------------------------------------------------- /ecgmentations/augmentations/filter/__init__.py: -------------------------------------------------------------------------------- 1 | from ecgmentations.augmentations.filter.transformations import ( 2 | LowPassFilter, 3 | HighPassFilter, 4 | BandPassFilter, 5 | SigmoidCompression, 6 | ) -------------------------------------------------------------------------------- /ecgmentations/core/__init__.py:
-------------------------------------------------------------------------------- 1 | from ecgmentations.core.augmentation import EcgOnlyAugmentation, DualAugmentation, Identity 2 | from ecgmentations.core.composition import Sequential, NonSequential, OneOf 3 | from ecgmentations.core.modification import ToChannels 4 | from ecgmentations.core.enum import BorderType, PositionType, ReductionType -------------------------------------------------------------------------------- /ecgmentations/augmentations/time/__init__.py: -------------------------------------------------------------------------------- 1 | from ecgmentations.augmentations.time.transformations import ( 2 | TimeReverse, 3 | TimeShift, 4 | TimeSegmentShuffle, 5 | RandomTimeWrap, 6 | TimeCutout, 7 | TimeCrop, 8 | CenterTimeCrop, 9 | RandomTimeCrop, 10 | TimePadIfNeeded, 11 | Pooling, 12 | Blur, 13 | ) -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-json 7 | - id: check-toml 8 | - id: check-xml 9 | - id: check-yaml 10 | - id: end-of-file-fixer 11 | - id: name-tests-test 12 | args: ['--pytest-test-first']  # tests in this repo are named test_*.py; the hook's default (--pytest) expects *_test.py 13 | - id: trailing-whitespace 14 | -------------------------------------------------------------------------------- /ecgmentations/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from ecgmentations.augmentations.time.transformations import * 2 | from ecgmentations.augmentations.pulse.transformations import * 3 | from ecgmentations.augmentations.filter.transformations import * 4 | from ecgmentations.augmentations.transformations import ( 5 | AmplitudeInvert, 6 | ChannelShuffle, 7 | ChannelDropout, 8 | GaussNoise, 9 | GaussBlur, 10 | AmplitudeScale, 11 | )
-------------------------------------------------------------------------------- /ecgmentations/core/enum.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | class BorderType(enum.Enum): 4 | CONSTANT = 'constant' 5 | REPLICATE = 'replicate' 6 | REFLECT_1001 = 'reflect' 7 | REFLECT_101 = 'reflect' 8 | WRAP = 'wrap' 9 | DEFAULT = CONSTANT 10 | 11 | class PositionType(enum.Enum): 12 | CENTER = 'center' 13 | LEFT = 'left' 14 | RIGHT = 'right' 15 | RANDOM = 'random' 16 | 17 | class ReductionType(enum.Enum): 18 | MIN = 'min' 19 | MEAN = 'mean' 20 | MAX = 'max' 21 | MEDIAN = 'median' 22 | -------------------------------------------------------------------------------- /tests/ecgmentations/core/test_transformation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import ecgmentations as E 5 | 6 | @pytest.mark.core 7 | def test_Identity_CASE_repr(): 8 | instance = E.Identity(always_apply=True) 9 | 10 | repr = str(instance) 11 | 12 | assert 'Identity' in repr 13 | assert 'always_apply' in repr 14 | assert 'p' in repr 15 | 16 | @pytest.mark.core 17 | def test_Identity_CASE_call(): 18 | input = np.random.randn(5000, 12) 19 | 20 | instance = E.Identity(always_apply=True) 21 | 22 | output = instance(ecg=input)['ecg'] 23 | expected = input 24 | 25 | assert np.allclose(output, expected) 26 | -------------------------------------------------------------------------------- /ecgmentations/core/constants.py: -------------------------------------------------------------------------------- 1 | import ecgmentations.core.enum as E 2 | 3 | SPATIAL_DIM = 0 4 | CHANNEL_DIM = 1 5 | 6 | NUM_SPATIAL_DIMENSIONS = 1 7 | NUM_MONO_CHANNEL_DIMENSIONS = 1 8 | NUM_MULTI_CHANNEL_DIMENSIONS = 2 9 | 10 | MAP_BORDER_TYPE_TO_NUMPY = { 11 | E.BorderType.CONSTANT: 'constant', 12 | E.BorderType.REPLICATE: 'edge', 13 | E.BorderType.REFLECT_1001: 'symmetric', 14 | 
E.BorderType.REFLECT_101: 'reflect', 15 | E.BorderType.WRAP: 'wrap', 16 | } 17 | 18 | MAP_BORDER_TYPE_TO_SC = { 19 | E.BorderType.CONSTANT: 'constant', 20 | E.BorderType.REPLICATE: 'nearest', 21 | E.BorderType.REFLECT_1001: 'reflect', 22 | E.BorderType.REFLECT_101: 'mirror', 23 | E.BorderType.WRAP: 'wrap', 24 | } 25 | -------------------------------------------------------------------------------- /ecgmentations/core/utils.py: -------------------------------------------------------------------------------- 1 | def format_args(args_dict): 2 | formatted_args = [] 3 | 4 | for k, v in args_dict.items(): 5 | if isinstance(v, str): 6 | v = f"'{v}'" 7 | formatted_args.append(f'{k}={v}') 8 | 9 | return ', '.join(formatted_args) 10 | 11 | def shorten_class_name(class_fullname): 12 | splitted = class_fullname.split('.') 13 | 14 | if len(splitted) == 1: 15 | return class_fullname 16 | 17 | top_module, *_, class_name = splitted 18 | 19 | if top_module == 'ecgmentations': 20 | return class_name 21 | 22 | return class_fullname 23 | 24 | def get_shortest_class_fullname(cls): 25 | class_fullname = '{cls.__module__}.{cls.__name__}'.format(cls=cls) 26 | return shorten_class_name(class_fullname) 27 | -------------------------------------------------------------------------------- /ecgmentations/augmentations/filter/functional.py: -------------------------------------------------------------------------------- 1 | import scipy as sp 2 | 3 | def lowpass_filter(ecg, ecg_frequency, cutoff_frequency): 4 | params = sp.signal.butter(3, cutoff_frequency, 'low', analog=False, fs=ecg_frequency) 5 | ecg = sp.signal.lfilter(*params, ecg).astype(ecg.dtype) 6 | 7 | return ecg 8 | 9 | def highpass_filter(ecg, ecg_frequency, cutoff_frequency): 10 | params = sp.signal.butter(3, cutoff_frequency, 'high', analog=False, fs=ecg_frequency) 11 | ecg = sp.signal.lfilter(*params, ecg).astype(ecg.dtype) 12 | 13 | return ecg 14 | 15 | def bandpass_filter(ecg, ecg_frequency, cutoff_frequencies): 16 | params = 
sp.signal.butter(3, cutoff_frequencies, 'bandpass', analog=False, fs=ecg_frequency) 17 | ecg = sp.signal.lfilter(*params, ecg).astype(ecg.dtype) 18 | 19 | return ecg 20 | 21 | def sigmoid_compression(ecg): 22 | ecg = sp.special.expit(ecg) 23 | 24 | return ecg 25 | -------------------------------------------------------------------------------- /ecgmentations/augmentations/pulse/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy as sp 3 | 4 | import ecgmentations.core.constants as C 5 | 6 | def add_sine_pulse(ecg, ecg_frequency, amplitude, frequency, phase): 7 | length = ecg.shape[C.SPATIAL_DIM] 8 | 9 | t = np.linspace(0, length / ecg_frequency, length) 10 | pulse = amplitude * np.sin(2 * np.pi * frequency * t + phase) 11 | 12 | if len(ecg.shape) == C.NUM_MULTI_CHANNEL_DIMENSIONS: 13 | pulse = np.expand_dims(pulse, axis=C.CHANNEL_DIM) 14 | 15 | return np.add(ecg, pulse, dtype=ecg.dtype) 16 | 17 | def add_square_pulse(ecg, ecg_frequency, amplitude, frequency, phase): 18 | length = ecg.shape[C.SPATIAL_DIM] 19 | 20 | t = np.linspace(0, length / ecg_frequency, length) 21 | pulse = amplitude * sp.signal.square(2 * np.pi * frequency * t + phase) 22 | 23 | if len(ecg.shape) == C.NUM_MULTI_CHANNEL_DIMENSIONS: 24 | pulse = np.expand_dims(pulse, axis=C.CHANNEL_DIM) 25 | 26 | return np.add(ecg, pulse, dtype=ecg.dtype) 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 rostepifanov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, 
and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/ecgmentations/augmentations/test_functional.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import ecgmentations.augmentations.functional as F 5 | 6 | def test_amplitude_invert_CASE_default(): 7 | input = np.array([[1, 2, 3, 4, 5, 6], ]).T 8 | 9 | output = F.amplitude_invert(input) 10 | expected = np.array([[-1, -2, -3, -4, -5, -6], ]).T 11 | 12 | assert np.allclose(output, expected) 13 | 14 | def test_channel_shuffle_CASE_direct_order(): 15 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 16 | 17 | output = F.channel_shuffle(input, (0, 1)) 18 | expected = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 19 | 20 | assert np.allclose(output, expected) 21 | 22 | def test_channel_shuffle_CASE_inverse_order(): 23 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 24 | 25 | output = F.channel_shuffle(input, (1, 0)) 26 | expected = np.array([[6, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 6]]).T 27 | 28 | assert np.allclose(output, expected) 29 | 30 | def test_channel_dropout_CASE_default(): 31 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 
4, 3, 2, 1]]).T 32 | 33 | channels_to_drop = (0, ) 34 | fill_value = 0 35 | 36 | output = F.channel_dropout(input, channels_to_drop, fill_value) 37 | expected = np.array([[0, 0, 0, 0, 0, 0], [6, 5, 4, 3, 2, 1]]).T 38 | 39 | assert np.allclose(output, expected) 40 | -------------------------------------------------------------------------------- /tests/ecgmentations/core/test_modification.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import ecgmentations as E 5 | 6 | @pytest.mark.core 7 | def test_ToChannels_CASE_create_AND_type_error(): 8 | with pytest.raises(RuntimeError, match=r'transform is type of <.+> that is not subtype of Transformation'): 9 | instance = E.ToChannels( 10 | object() 11 | , always_apply=True) 12 | 13 | @pytest.mark.core 14 | def test_ToChannels_CASE_create_AND_channels_error(): 15 | with pytest.raises(RuntimeError, match=r'object at \d+ position is not subtype of int'): 16 | instance = E.ToChannels( 17 | E.TimeReverse(always_apply=True) 18 | , channels=[0, 1.], always_apply=True) 19 | 20 | @pytest.mark.core 21 | def test_ToChannels_CASE_call_AND_zero_channel(): 22 | ecg = np.random.randn(5000, 12) 23 | mask = np.arange(1000)[: None] 24 | 25 | channels = [0, ] 26 | exchannels = [ ch for ch in np.arange(12) if ch not in channels] 27 | 28 | instance = E.ToChannels( 29 | E.TimeReverse(always_apply=True) 30 | , channels=channels, always_apply=True) 31 | 32 | transformed = instance(ecg=ecg, mask=mask) 33 | tecg, tmask = transformed['ecg'], transformed['mask'] 34 | 35 | assert tecg.shape == ecg.shape 36 | assert not np.allclose(tecg[:, channels], ecg[:, channels]) 37 | assert np.allclose(tecg[:, exchannels], ecg[:, exchannels]) 38 | 39 | assert tmask.shape == mask.shape 40 | assert not np.array_equal(tmask, mask) 41 | -------------------------------------------------------------------------------- /ecgmentations/core/transformation.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ecgmentations.core.utils import get_shortest_class_fullname 4 | 5 | class Transformation(object): 6 | """Root class for single and compound augmentations 7 | """ 8 | 9 | REPR_INDENT_STEP=2 10 | 11 | def __init__(self, always_apply, p): 12 | """ 13 | :args: 14 | always_apply: bool 15 | the flag of force application 16 | p: float 17 | the probability of application 18 | """ 19 | self.always_apply = always_apply 20 | self.p = p 21 | 22 | def get_base_init_args(self): 23 | """ 24 | :return: 25 | output: dict 26 | initialization parameters 27 | """ 28 | return {'always_apply': self.always_apply, 'p': self.p} 29 | 30 | def whether_apply(self, force_apply): 31 | return force_apply or self.always_apply or (np.random.random() < self.p) 32 | 33 | def __call__(self, *args, force_apply=False, **data): 34 | raise NotImplementedError 35 | 36 | def get_class_name(self): 37 | """ 38 | :return: 39 | output: str 40 | the name of class 41 | """ 42 | return self.__class__.__name__ 43 | 44 | @classmethod 45 | def get_class_fullname(cls): 46 | """ 47 | :return: 48 | output: str 49 | the full name of class 50 | """ 51 | return get_shortest_class_fullname(cls) 52 | -------------------------------------------------------------------------------- /ecgmentations/augmentations/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import ecgmentations.core.constants as C 4 | import ecgmentations.augmentations.time.functional as TF 5 | 6 | def apply_along_dim(data, func, dim): 7 | """Apply the same transformation along the dim 8 | """ 9 | data = np.moveaxis(data, dim, 0) 10 | 11 | data = np.stack([*map( 12 | func, 13 | data, 14 | )], axis=dim) 15 | 16 | return data 17 | 18 | def amplitude_invert(ecg): 19 | return np.negative(ecg) 20 | 21 | def channel_shuffle(ecg, channel_order): 22 | ecg = ecg[:, 
channel_order] 23 | 24 | return np.require(ecg, requirements=['C_CONTIGUOUS']) 25 | 26 | def channel_dropout(ecg, channels_to_drop, fill_value): 27 | ecg = np.copy(ecg) 28 | ecg[:, channels_to_drop] = fill_value 29 | 30 | return ecg 31 | 32 | def add(ecg, other): 33 | if len(ecg.shape) > len(other.shape): 34 | other = np.expand_dims(other, axis=C.CHANNEL_DIM) 35 | 36 | return np.add(ecg, other, dtype=ecg.dtype) 37 | 38 | def conv(ecg, kernel, border_mode, fill_value): 39 | pad_width = kernel.size // 2 40 | 41 | ecg = TF.pad(ecg, pad_width, pad_width, border_mode, fill_value) 42 | 43 | func = lambda arr: np.correlate(arr, kernel, mode='valid').astype(arr.dtype) 44 | 45 | if len(ecg.shape) == C.NUM_MULTI_CHANNEL_DIMENSIONS: 46 | ecg = apply_along_dim(ecg, func, C.CHANNEL_DIM) 47 | else: 48 | ecg = func(ecg) 49 | 50 | return np.require(ecg, requirements=['C_CONTIGUOUS']) 51 | 52 | def multiply(ecg, factor): 53 | return np.multiply(ecg, factor, dtype=ecg.dtype) 54 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | from pathlib import Path 4 | 5 | INSTALL_REQUIRES = ['numpy>=1.20.0', 'scipy>=1.10.1', 'opencv-python-headless>=4.1.1'] 6 | 7 | def get_version(): 8 | locals_ = dict() 9 | 10 | with open(Path(__file__).parent / 'ecgmentations' / '__version__.py') as f: 11 | exec(f.read(), globals(), locals_) 12 | return locals_['__version__'] 13 | 14 | def get_long_description(): 15 | with open(Path(__file__).parent / 'README.md', encoding='utf-8') as f: 16 | return f.read() 17 | 18 | setuptools.setup( 19 | name='ecgmentations', 20 | version=get_version(), 21 | description='Ecg augmentation library and easy to use wrapper around other libraries', 22 | long_description=get_long_description(), 23 | long_description_content_type='text/markdown', 24 | author='Rostislav Epifanov', 25 | author_email='rostepifanov@gmail.com', 26 | 
license='MIT', 27 | url='https://github.com/rostepifanov/ecgmentations', 28 | packages=setuptools.find_packages(exclude=['tests']), 29 | python_requires='>=3.7', 30 | install_requires=INSTALL_REQUIRES, 31 | extras_require={'tests': ['pytest']}, 32 | classifiers=[ 33 | 'Development Status :: 4 - Beta', 34 | 'License :: OSI Approved :: MIT License', 35 | 'Intended Audience :: Developers', 36 | 'Intended Audience :: Science/Research', 37 | 'Operating System :: OS Independent', 38 | 'Programming Language :: Python', 39 | 'Programming Language :: Python :: 3', 40 | 'Programming Language :: Python :: 3.7', 41 | 'Programming Language :: Python :: 3.8', 42 | 'Programming Language :: Python :: 3.9', 43 | 'Programming Language :: Python :: 3.10', 44 | 'Programming Language :: Python :: 3.11', 45 | 'Topic :: Scientific/Engineering', 46 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 47 | 'Topic :: Software Development', 48 | 'Topic :: Software Development :: Libraries', 49 | 'Topic :: Software Development :: Libraries :: Python Modules', 50 | ] 51 | ) 52 | -------------------------------------------------------------------------------- /tests/ecgmentations/augmentations/time/test_transformations.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import ecgmentations as E 5 | import ecgmentations.core.constants as C 6 | 7 | def test_TimeCrop_CASE_left_crop(): 8 | length = 5000 9 | expected_length = 4000 10 | 11 | ecg = np.random.randn(length, 12) 12 | mask = np.zeros((length, 1)) 13 | 14 | instance = E.TimeCrop(length=expected_length, position=E.PositionType.LEFT, always_apply=True) 15 | transformed = instance(ecg=ecg, mask=mask) 16 | 17 | tecg, tmask = transformed['ecg'], transformed['mask'] 18 | 19 | assert tecg.shape == (expected_length, ecg.shape[C.CHANNEL_DIM]) 20 | assert tmask.shape == (expected_length, mask.shape[C.CHANNEL_DIM]) 21 | 22 | assert np.allclose(tecg, 
ecg[:expected_length]) 23 | 24 | def test_TimePadIfNeeded_CASE_left_padding(): 25 | length = 4000 26 | expected_length = 5000 27 | 28 | ecg = np.random.randn(length, 12) 29 | mask = np.zeros((length, 1)) 30 | 31 | instance = E.TimePadIfNeeded(min_length=expected_length, position=E.PositionType.LEFT, always_apply=True) 32 | transformed = instance(ecg=ecg, mask=mask) 33 | 34 | tecg, tmask = transformed['ecg'], transformed['mask'] 35 | 36 | assert tecg.shape == (expected_length, ecg.shape[C.CHANNEL_DIM]) 37 | assert tmask.shape == (expected_length, mask.shape[C.CHANNEL_DIM]) 38 | 39 | assert np.allclose(tecg[:length], ecg) 40 | assert np.allclose(tecg[length:], 0.) 41 | 42 | def test_TimeCutout_CASE_mask_fill_value(): 43 | ecg = np.random.randn(5000, 12) 44 | mask = np.zeros((5000, 1)) 45 | 46 | mask_fill_value = 0 47 | 48 | instance = E.TimeCutout(mask_fill_value=mask_fill_value, always_apply=True) 49 | transformed = instance(ecg=ecg, mask=mask) 50 | 51 | tmask = transformed['mask'] 52 | 53 | assert np.all(tmask[mask != tmask], mask_fill_value) 54 | 55 | @pytest.mark.parametrize('reduction', list(map(lambda t: t.value, E.ReductionType))) 56 | def test_Pooling_CASE_reduction(reduction): 57 | input = np.random.randn(5000, 12) 58 | 59 | instance = E.Pooling(reduction, always_apply=True) 60 | output = instance(ecg=input)['ecg'] 61 | 62 | assert output.shape == input.shape 63 | assert not np.allclose(output, input) 64 | -------------------------------------------------------------------------------- /tests/ecgmentations/core/test_composition.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import ecgmentations as E 5 | 6 | @pytest.mark.core 7 | def test_Sequential_CASE_create_AND_list_error(): 8 | with pytest.raises(RuntimeError, match=r'transforms is type of <.+> that is not list'): 9 | instance = E.Sequential( 10 | E.TimeReverse(always_apply=True), 11 | ) 12 | 13 | @pytest.mark.core 14 
| def test_Sequential_CASE_create_AND_subtype_error(): 15 | with pytest.raises(RuntimeError, match=r'object at \d+ position is not subtype of Transformation'): 16 | instance = E.Sequential([ 17 | E.TimeReverse(always_apply=True), 18 | object(), 19 | ], always_apply=True) 20 | 21 | @pytest.mark.core 22 | def test_Sequential_CASE_call_AND_no_transfroms(): 23 | input = np.random.randn(5000, 12) 24 | 25 | instance = E.Sequential([ 26 | ], always_apply=True) 27 | 28 | output = instance(ecg=input)['ecg'] 29 | 30 | assert np.allclose(output, input) 31 | 32 | @pytest.mark.core 33 | def test_Sequential_CASE_call_AND_one_flip(): 34 | input = np.random.randn(5000, 12) 35 | 36 | instance = E.Sequential([ 37 | E.TimeReverse(always_apply=True), 38 | ], always_apply=True) 39 | 40 | output = instance(ecg=input)['ecg'] 41 | 42 | assert not np.allclose(output, input) 43 | 44 | @pytest.mark.core 45 | def test_Sequential_CASE_call_AND_double_flip(): 46 | input = np.random.randn(5000, 12) 47 | 48 | instance = E.Sequential([ 49 | E.TimeReverse(always_apply=True), 50 | E.TimeReverse(always_apply=True) 51 | ], always_apply=True) 52 | 53 | output = instance(ecg=input)['ecg'] 54 | 55 | assert np.allclose(output, input) 56 | 57 | @pytest.mark.core 58 | def test_OneOf_CASE_call_AND_no_transfroms(): 59 | input = np.random.randn(5000, 12) 60 | 61 | instance = E.OneOf([ 62 | ], always_apply=True) 63 | 64 | output = instance(ecg=input)['ecg'] 65 | 66 | assert np.allclose(output, input) 67 | 68 | @pytest.mark.core 69 | def test_OneOf_CASE_call_AND_check_application(): 70 | input = np.random.randn(5000, 12) 71 | 72 | instance = E.Sequential([ 73 | E.TimeReverse(always_apply=True), 74 | E.OneOf([ 75 | E.TimeReverse(), 76 | E.TimeReverse(), 77 | ], always_apply=True), 78 | ], always_apply=True) 79 | 80 | output = instance(ecg=input)['ecg'] 81 | 82 | assert np.allclose(output, input) 83 | -------------------------------------------------------------------------------- 
/tests/ecgmentations/augmentations/test_misc.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import ecgmentations.augmentations.misc as M 4 | 5 | def test_prepare_float_CASE_wrong_type(): 6 | with pytest.raises(ValueError): 7 | M.prepare_float(None, '') 8 | 9 | def test_prepare_float_CASE_int(): 10 | output = M.prepare_float(1, '') 11 | 12 | assert pytest.approx(output) == 1. 13 | 14 | def test_prepare_int_asymrange_CASE_wrong_type(): 15 | with pytest.raises(ValueError): 16 | M.prepare_int_asymrange(None, '', 0) 17 | 18 | def test_prepare_int_asymrange_CASE_float(): 19 | with pytest.raises(ValueError): 20 | M.prepare_int_asymrange(1., '', 0) 21 | 22 | def test_prepare_int_asymrange_CASE_low(): 23 | with pytest.raises(ValueError): 24 | M.prepare_int_asymrange(-1, '', 0) 25 | 26 | def test_prepare_int_asymrange_CASE_int(): 27 | output = M.prepare_int_asymrange(1, '', 0) 28 | expected = (0, 1) 29 | 30 | assert output == expected 31 | 32 | def test_prepare_int_asymrange_CASE_tuple_int(): 33 | output = M.prepare_int_asymrange((1, 2), '', 0) 34 | expected = (1, 2) 35 | 36 | assert output == expected 37 | 38 | def test_prepare_int_asymrange_CASE_tuple_AND_wrong_type(): 39 | with pytest.raises(ValueError): 40 | M.prepare_int_asymrange((None, None), '', 0) 41 | 42 | def test_prepare_int_asymrange_CASE_tuple_int_AND_low(): 43 | with pytest.raises(ValueError): 44 | M.prepare_int_asymrange((-1, 2), '', 0) 45 | 46 | def test_prepare_int_asymrange_CASE_tuple_int_AND_long(): 47 | with pytest.raises(ValueError): 48 | M.prepare_int_asymrange((0, 1, 2), '', 0) 49 | 50 | def test_prepare_float_asymrange_CASE_wrong_type(): 51 | with pytest.raises(ValueError): 52 | M.prepare_float_asymrange(None, '', 0.) 53 | 54 | def test_prepare_float_asymrange_CASE_float(): 55 | with pytest.raises(ValueError): 56 | M.prepare_float_asymrange(1, '', 0.) 
57 | 58 | def test_prepare_float_asymrange_CASE_low(): 59 | with pytest.raises(ValueError): 60 | M.prepare_float_asymrange(-1., '', 0.) 61 | 62 | def test_prepare_float_asymrange_CASE_float(): 63 | output = M.prepare_float_asymrange(1., '', 0.) 64 | expected = (0., 1.) 65 | 66 | assert pytest.approx(output) == expected 67 | 68 | def test_prepare_float_asymrange_CASE_tuple_float(): 69 | output = M.prepare_float_asymrange((1., 2.), '', 0.) 70 | expected = (1., 2.) 71 | 72 | assert pytest.approx(output) == expected 73 | 74 | def test_prepare_float_asymrange_CASE_tuple_AND_wrong_type(): 75 | with pytest.raises(ValueError): 76 | M.prepare_float_asymrange((None, None), '', 0.) 77 | 78 | def test_prepare_float_asymrange_CASE_tuple_float_AND_low(): 79 | with pytest.raises(ValueError): 80 | M.prepare_float_asymrange((-1., 2.), '', 0.) 81 | 82 | def test_prepare_int_asymrange_CASE_tuple_float_AND_long(): 83 | with pytest.raises(ValueError): 84 | M.prepare_float_asymrange((0., 1., 2.), '', 0.) -------------------------------------------------------------------------------- /tests/ecgmentations/augmentations/test_transformations.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import ecgmentations as E 5 | 6 | SHAPE_PRESERVED_TRANSFORMS = [ 7 | E.AmplitudeInvert, 8 | E.ChannelShuffle, 9 | E.ChannelDropout, 10 | E.GaussNoise, 11 | E.GaussBlur, 12 | E.AmplitudeScale, 13 | E.TimeReverse, 14 | E.TimeShift, 15 | E.TimeSegmentShuffle, 16 | E.RandomTimeWrap, 17 | E.TimeCutout, 18 | E.Pooling, 19 | E.Blur, 20 | E.PowerlineNoise, 21 | E.SinePulse, 22 | E.SquarePulse, 23 | E.RespirationNoise, 24 | E.LowPassFilter, 25 | E.HighPassFilter, 26 | E.BandPassFilter, 27 | E.SigmoidCompression, 28 | ] 29 | 30 | SHAPE_UNPRESERVED_TRANSFORMS = [ 31 | E.TimeCrop, 32 | E.RandomTimeCrop, 33 | E.CenterTimeCrop, 34 | E.TimePadIfNeeded, 35 | ] 36 | 37 | @pytest.mark.parametrize('transform', SHAPE_PRESERVED_TRANSFORMS + 
SHAPE_UNPRESERVED_TRANSFORMS) 38 | def test_Transform_CASE_repr(transform): 39 | instance = transform(always_apply=True) 40 | repr = str(instance) 41 | 42 | assert 'always_apply' in repr 43 | assert 'p' in repr 44 | 45 | @pytest.mark.parametrize('transform', SHAPE_PRESERVED_TRANSFORMS) 46 | def test_Transform_CASE_call_AND_mono_channel(transform): 47 | if transform == E.ChannelShuffle: 48 | return 49 | elif transform == E.ChannelDropout: 50 | return 51 | 52 | ecg = np.random.randn(5000).astype(np.float32) 53 | mask = np.zeros((5000, ), dtype=np.uint8) 54 | 55 | tecg = np.copy(ecg) 56 | tmask = np.copy(mask) 57 | 58 | instance = transform(always_apply=True) 59 | transformed = instance(ecg=tecg, mask=tmask) 60 | 61 | tecg, tmask = transformed['ecg'], transformed['mask'] 62 | 63 | assert tecg.flags['C_CONTIGUOUS'] == True 64 | assert tecg.dtype == ecg.dtype 65 | assert tecg.shape == ecg.shape 66 | assert not np.allclose(tecg, ecg) 67 | 68 | assert tmask.flags['C_CONTIGUOUS'] == True 69 | assert tmask.dtype == mask.dtype 70 | assert tmask.shape == mask.shape 71 | 72 | if isinstance(transform, E.EcgOnlyAugmentation): 73 | assert np.all(tmask == mask) 74 | 75 | @pytest.mark.parametrize('transform', SHAPE_PRESERVED_TRANSFORMS) 76 | def test_Transform_CASE_call_AND_multi_channel(transform): 77 | ecg = np.random.randn(5000, 12).astype(np.float32) 78 | mask = np.zeros((5000, 1), dtype=np.uint8) 79 | 80 | tecg = np.copy(ecg) 81 | tmask = np.copy(mask) 82 | 83 | instance = transform(always_apply=True) 84 | transformed = instance(ecg=tecg, mask=tmask) 85 | 86 | tecg, tmask = transformed['ecg'], transformed['mask'] 87 | 88 | assert tecg.flags['C_CONTIGUOUS'] == True 89 | assert tecg.dtype == ecg.dtype 90 | assert tecg.shape == ecg.shape 91 | assert not np.allclose(tecg, ecg) 92 | 93 | assert tmask.flags['C_CONTIGUOUS'] == True 94 | assert tmask.dtype == mask.dtype 95 | assert tmask.shape == mask.shape 96 | 97 | if isinstance(transform, E.EcgOnlyAugmentation): 98 | assert 
np.all(tmask == mask) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ecgmentations 2 | 3 | ![Python version support](https://img.shields.io/pypi/pyversions/ecgmentations) 4 | [![PyPI version](https://badge.fury.io/py/ecgmentations.svg)](https://badge.fury.io/py/ecgmentations) 5 | [![Downloads](https://pepy.tech/badge/ecgmentations/month)](https://pepy.tech/project/ecgmentations?versions=0.0.*) 6 | 7 | Ecgmentations is a Python library for ecg augmentation. Ecg augmentation is used in deep learning to increase the quality of trained models. The purpose of ecg augmentation is to create new training samples from the existing data. 8 | 9 | Here is an example of how you can apply some augmentations from Ecgmentations to create new ecgs from the original one: 10 | 11 | ![preview](https://raw.githubusercontent.com/rostepifanov/ecgmentations/main/images/preview.png) 12 | 13 | ## Table of contents 14 | - [Authors](#authors) 15 | - [Installation](#installation) 16 | - [A simple example](#a-simple-example) 17 | - [List of augmentations](#list-of-augmentations) 18 | - [Citing](#citing) 19 | 20 | ## Authors 21 | [**Rostislav Epifanov** — Researcher in Novosibirsk]() 22 | 23 | ## Installation 24 | Installation from PyPI: 25 | 26 | ``` 27 | pip install ecgmentations 28 | ``` 29 | 30 | Installation from GitHub: 31 | 32 | ``` 33 | pip install git+https://github.com/rostepifanov/ecgmentations 34 | ``` 35 | 36 | ## A simple example 37 | ```python 38 | import numpy as np 39 | import ecgmentations as E 40 | 41 | # Declare an augmentation pipeline 42 | transform = E.Sequential([ 43 | E.TimeReverse(p=0.5), 44 | E.ChannelShuffle(p=0.06), 45 | ]) 46 | 47 | # Create example ecg (length, nchannels) 48 | ecg = np.ones((5000, 12)) 49 | 50 | # Augment an ecg 51 | transformed = transform(ecg=ecg) 52 | transformed_ecg = transformed['ecg'] 53 | ``` 54 | 55 | ## List 
of augmentations 56 | 57 | The list of time axis transforms: 58 | 59 | - [TimeReverse]() 60 | - [TimeShift]() 61 | - [TimeSegmentShuffle]() 62 | - [RandomTimeWrap]() 63 | - [TimeCutout]() 64 | - [TimeCrop]() 65 | - [CenterTimeCrop]() 66 | - [RandomTimeCrop]() 67 | - [TimePadIfNeeded]() 68 | - [Pooling]() 69 | - [Blur]() 70 | 71 | The list of pulse transforms: 72 | 73 | - [SinePulse]() 74 | - [PowerlineNoise]() 75 | - [RespirationNoise]() 76 | - [SquarePulse]() 77 | 78 | The list of filter transforms: 79 | 80 | - [LowPassFilter]() 81 | - [HighPassFilter]() 82 | - [BandPassFilter]() 83 | - [SigmoidCompression]() 84 | 85 | The list of other transforms: 86 | 87 | - [AmplitudeInvert]() 88 | - [ChannelShuffle]() 89 | - [ChannelDropout]() 90 | - [GaussNoise]() 91 | - [GaussBlur]() 92 | - [AmplitudeScale]() 93 | 94 | 95 | ## Citing 96 | 97 | If you find this library useful for your research, please consider citing: 98 | 99 | ``` 100 | @misc{epifanov2023ecgmentations, 101 | Author = {Rostislav Epifanov}, 102 | Title = {Ecgmentations}, 103 | Year = {2023}, 104 | Publisher = {GitHub}, 105 | Journal = {GitHub repository}, 106 | Howpublished = {\url{https://github.com/rostepifanov/ecgmentations}} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the 
exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | ### JetBrains template 108 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 109 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 110 | 111 | # User-specific stuff: 112 | .idea/**/workspace.xml 113 | .idea/**/tasks.xml 114 | .idea/dictionaries 115 | 116 | # Sensitive or high-churn files: 117 | .idea/**/dataSources/ 118 | .idea/**/dataSources.ids 119 | .idea/**/dataSources.local.xml 120 | .idea/**/sqlDataSources.xml 121 | .idea/**/dynamic.xml 122 | .idea/**/uiDesigner.xml 123 | 124 | # Gradle: 125 | .idea/**/gradle.xml 126 | .idea/**/libraries 127 | 128 | # CMake 129 | cmake-build-debug/ 130 | cmake-build-release/ 131 | 132 | # Mongo Explorer plugin: 133 | .idea/**/mongoSettings.xml 134 | 135 | ## File-based 
project format: 136 | *.iws 137 | 138 | ## Plugin-specific files: 139 | 140 | # IntelliJ 141 | out/ 142 | 143 | # mpeltonen/sbt-idea plugin 144 | .idea_modules/ 145 | 146 | # JIRA plugin 147 | atlassian-ide-plugin.xml 148 | 149 | # Cursive Clojure plugin 150 | .idea/replstate.xml 151 | 152 | # Crashlytics plugin (for Android Studio and IntelliJ) 153 | com_crashlytics_export_strings.xml 154 | crashlytics.properties 155 | crashlytics-build.properties 156 | fabric.properties 157 | 158 | .idea 159 | 160 | conda_build/ -------------------------------------------------------------------------------- /ecgmentations/core/modification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ecgmentations.core.transformation import Transformation 4 | from ecgmentations.core.utils import format_args 5 | 6 | class Modification(Transformation): 7 | def __init__(self, transform, always_apply, p): 8 | """ 9 | :args: 10 | transform: list of Transformation 11 | list of operations to apply with modification 12 | always_apply: bool 13 | the flag of force application 14 | p: float 15 | the probability of application 16 | """ 17 | super(Modification, self).__init__(always_apply, p) 18 | 19 | if not isinstance(transform, Transformation): 20 | raise RuntimeError( 21 | 'transform is type of {} that is not subtype of Transformation'.format(type(transform)) 22 | ) 23 | 24 | self.transform = transform 25 | 26 | def __repr__(self): 27 | return self.repr() 28 | 29 | def repr(self, indent=Transformation.REPR_INDENT_STEP): 30 | args = self.get_base_init_args() 31 | 32 | repr_string = self.get_class_name() + '(' 33 | 34 | repr_string += '\n' 35 | 36 | if hasattr(self.transform, 'repr'): 37 | t_repr = self.transform.repr(indent + self.REPR_INDENT_STEP) 38 | else: 39 | t_repr = repr(self.transform) 40 | 41 | repr_string += ' ' * indent + t_repr + ',' 42 | 43 | repr_string += '\n' + ' ' * (indent - self.REPR_INDENT_STEP) + ', 
{args})'.format(args=format_args(args)) 44 | 45 | return repr_string 46 | 47 | class ToChannels(Modification): 48 | """Transformation transforms to selected channels 49 | """ 50 | def __init__(self, transform, channels=[0, ], always_apply=False, p=0.5): 51 | """ 52 | :args: 53 | transform: list of Transformation 54 | list of operations to apply with modification 55 | channels: list of int 56 | selected channels to apply transform 57 | always_apply: bool 58 | the flag of force application 59 | p: float 60 | the probability of application 61 | """ 62 | super(ToChannels, self).__init__(transform, always_apply, p) 63 | 64 | if not isinstance(channels, list): 65 | raise RuntimeError( 66 | 'channels is type of {} that is not list'.format(type(channels)) 67 | ) 68 | elif not all(isinstance(ch, int) for ch in channels): 69 | for idx, ch in enumerate(channels): 70 | if not isinstance(ch, int): 71 | raise RuntimeError( 72 | 'object at {} position is not subtype of int'.format(idx) 73 | ) 74 | 75 | self.channels = channels 76 | 77 | def __call__(self, *args, force_apply=False, **data): 78 | if self.whether_apply(force_apply): 79 | ecg = np.copy(data['ecg']) 80 | 81 | data['ecg'] = data['ecg'][:, self.channels] 82 | data = self.transform(**data) 83 | 84 | ecg[:, self.channels] = data['ecg'] 85 | data['ecg'] = ecg 86 | else: 87 | data = self.transform(**data) 88 | 89 | return data 90 | -------------------------------------------------------------------------------- /tests/ecgmentations/augmentations/time/test_functional.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import ecgmentations as E 5 | import ecgmentations.augmentations.time.functional as F 6 | 7 | def test_time_reverse_CASE_default(): 8 | input = np.array([[1, 2, 3, 4, 5, 6], ]).T 9 | 10 | output = F.time_reverse(input) 11 | expected = np.array([[6, 5, 4, 3, 2, 1], ]).T 12 | 13 | assert np.allclose(output, expected) 14 | 15 | def 
test_time_shift_CASE_zero_shift(): 16 | input = np.array([[1, 2, 3, 4, 5, 6], ]).T 17 | 18 | output = F.time_shift(input, 0., E.BorderType.CONSTANT, 0.) 19 | expected = np.array([[1, 2, 3, 4, 5, 6], ]).T 20 | 21 | assert np.allclose(output, expected) 22 | 23 | def test_time_shift_CASE_positive_shift(): 24 | input = np.array([[1, 2, 3, 4, 5, 6], ]).T 25 | 26 | output = F.time_shift(input, 0.167, E.BorderType.CONSTANT, 0.) 27 | expected = np.array([[0, 1, 2, 3, 4, 5], ]).T 28 | 29 | assert np.allclose(output, expected) 30 | 31 | def test_time_shift_CASE_negative_shift(): 32 | input = np.array([[1, 2, 3, 4, 5, 6], ]).T 33 | 34 | output = F.time_shift(input, -0.167, E.BorderType.CONSTANT, 0.) 35 | expected = np.array([[2, 3, 4, 5, 6, 0], ]).T 36 | 37 | assert np.allclose(output, expected) 38 | 39 | def time_segment_swap_CASE_default(): 40 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 41 | 42 | segment_order = [2, 0, 1] 43 | 44 | output = F.time_segment_swap(input, segment_order) 45 | expected = np.array([[5, 6, 1, 2, 3, 4], [2, 1, 6, 5, 4, 3]]).T 46 | 47 | assert np.allclose(output, expected) 48 | 49 | def test_time_cutout_CASE_default(): 50 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 51 | 52 | cutouts = [(0, 2)] 53 | fill_value = 0 54 | 55 | output = F.time_cutout(input, cutouts, fill_value) 56 | expected = np.array([[0, 0, 3, 4, 5, 6], [0, 0, 4, 3, 2, 1]]).T 57 | 58 | assert np.allclose(output, expected) 59 | 60 | def test_time_crop_CASE_default(): 61 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 62 | 63 | left_bound = 1 / 4 64 | crop_length = 2 65 | 66 | output = F.time_crop(input, left_bound, crop_length) 67 | expected = np.array([[2, 3], [5, 4]]).T 68 | 69 | assert np.allclose(output, expected) 70 | 71 | def test_time_crop_CASE_equal_legth(): 72 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 73 | 74 | left_bound = 1 / 6 75 | crop_length = 6 76 | 77 | output = F.time_crop(input, left_bound, 
crop_length) 78 | expected = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 79 | 80 | assert np.allclose(output, expected) 81 | 82 | def test_time_crop_CASE_large_length(): 83 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 84 | 85 | left_bound = 1 / 6 86 | crop_length = 8 87 | 88 | with pytest.raises(ValueError): 89 | F.time_crop(input, left_bound, crop_length) 90 | 91 | def test_pad_CASE_border_constant(): 92 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 93 | 94 | left_pad = 2 95 | rigth_pad = 2 96 | 97 | output = F.pad(input, left_pad, rigth_pad, E.BorderType.CONSTANT, 0.) 98 | expected = np.array([[0, 0, 1, 2, 3, 4, 5, 6, 0, 0], [0, 0, 6, 5, 4, 3, 2, 1, 0, 0]]).T 99 | 100 | assert np.allclose(output, expected) 101 | 102 | def test_pad_CASE_border_replicate(): 103 | input = np.array([[1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]]).T 104 | 105 | left_pad = 2 106 | rigth_pad = 2 107 | 108 | output = F.pad(input, left_pad, rigth_pad, E.BorderType.REPLICATE, None) 109 | expected = np.array([[1, 1, 1, 2, 3, 4, 5, 6, 6, 6], [6, 6, 6, 5, 4, 3, 2, 1, 1, 1]]).T 110 | 111 | assert np.allclose(output, expected) 112 | -------------------------------------------------------------------------------- /ecgmentations/augmentations/time/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy as sp 3 | 4 | from itertools import tee 5 | 6 | import ecgmentations.core.enum as E 7 | import ecgmentations.core.constants as C 8 | import ecgmentations.augmentations.functional as F 9 | 10 | def time_reverse(ecg): 11 | """Reverse spatial dim 12 | """ 13 | ecg = np.flip(ecg, axis=C.SPATIAL_DIM) 14 | 15 | return np.require(ecg, requirements=['C_CONTIGUOUS']) 16 | 17 | def time_shift(ecg, shift, border_mode, fill_value): 18 | length = ecg.shape[C.SPATIAL_DIM] 19 | 20 | pad_ = int(length*shift) 21 | 22 | if pad_ > 0: 23 | ecg = pad(ecg, pad_, 0, border_mode, fill_value) 24 | ecg = 
ecg[:-pad_] 25 | elif pad_ < 0: 26 | ecg = pad(ecg, 0, -pad_, border_mode, fill_value) 27 | ecg = ecg[-pad_:] 28 | 29 | return np.require(ecg, requirements=['C_CONTIGUOUS']) 30 | 31 | def time_segment_swap(ecg, segment_order): 32 | shape = ecg.shape 33 | length = shape[C.SPATIAL_DIM] 34 | 35 | time_point_order = np.arange(length) 36 | 37 | num_segments = len(segment_order) 38 | time_point_order = np.array( 39 | np.array_split(time_point_order, num_segments) 40 | )[segment_order] 41 | 42 | ecg = ecg[time_point_order] 43 | 44 | if len(shape) == C.NUM_MULTI_CHANNEL_DIMENSIONS: 45 | ecg.shape = (length, -1) 46 | else: 47 | ecg.shape = (length, ) 48 | 49 | return ecg 50 | 51 | def pairwise(iterable): 52 | first, second = tee(iterable) 53 | next(second, None) 54 | yield from zip(first, second) 55 | 56 | def time_wrap(ecg, cells, ncells): 57 | length = ecg.shape[C.SPATIAL_DIM] 58 | 59 | bounds = (length * cells).astype(np.int32) 60 | nbounds = (length * ncells).astype(np.int32) 61 | 62 | necg = np.zeros_like(ecg) 63 | 64 | for (left_bound, rigth_bound), (left_nbound, rigth_nbound) in zip(pairwise(bounds), pairwise(nbounds)): 65 | necg[left_nbound: rigth_nbound] = np.apply_along_axis( 66 | lambda ecg: np.interp( 67 | np.linspace(0, 1, rigth_nbound - left_nbound), 68 | np.linspace(0, 1, rigth_bound - left_bound), 69 | ecg 70 | ), 71 | axis=C.SPATIAL_DIM, 72 | arr=ecg[left_bound: rigth_bound] 73 | ) 74 | 75 | return necg 76 | 77 | def time_cutout(ecg, cutouts, fill_value): 78 | ecg = np.copy(ecg) 79 | 80 | for cutout_start, cutout_length in cutouts: 81 | ecg[cutout_start: cutout_start+cutout_length] = fill_value 82 | 83 | return ecg 84 | 85 | def time_crop(ecg, left_bound, crop_length): 86 | length = ecg.shape[C.SPATIAL_DIM] 87 | 88 | if length < crop_length: 89 | raise ValueError( 90 | 'Requested crop length {crop_length} is ' 91 | 'larger than the ecg length {length}'.format( 92 | crop_length=crop_length, length=length 93 | ) 94 | ) 95 | 96 | t1 = int((length - 
crop_length) * left_bound) 97 | t2 = t1 + crop_length 98 | 99 | return ecg[t1:t2] 100 | 101 | def pad(ecg, left_pad, rigth_pad, border_mode, fill_value): 102 | kwargs = dict() 103 | 104 | if border_mode == E.BorderType.CONSTANT: 105 | kwargs['constant_values'] = fill_value 106 | 107 | func = lambda arr: np.pad( 108 | arr, 109 | pad_width=(left_pad, rigth_pad), 110 | mode=C.MAP_BORDER_TYPE_TO_NUMPY[border_mode], 111 | **kwargs 112 | ) 113 | 114 | if len(ecg.shape) == C.NUM_MULTI_CHANNEL_DIMENSIONS: 115 | ecg = F.apply_along_dim(ecg, func, C.CHANNEL_DIM) 116 | else: 117 | ecg = func(ecg) 118 | 119 | return ecg 120 | 121 | def pooling(ecg, reduction, kernel_size, border_mode, fill_value): 122 | if reduction == E.ReductionType.MIN: 123 | filter = sp.ndimage.minimum_filter 124 | elif reduction == E.ReductionType.MEAN: 125 | filter = sp.ndimage.uniform_filter 126 | elif reduction == E.ReductionType.MAX: 127 | filter = sp.ndimage.maximum_filter 128 | elif reduction == E.ReductionType.MEDIAN: 129 | filter = sp.ndimage.median_filter 130 | else: 131 | raise ValueError('Get invalide reduction: {}'.format(reduction)) 132 | 133 | pad_width = kernel_size // 2 134 | 135 | ecg = pad(ecg, pad_width, pad_width, border_mode, fill_value) 136 | ecg = filter(ecg, size=kernel_size) 137 | 138 | ecg = ecg[pad_width:-pad_width] 139 | 140 | return ecg 141 | -------------------------------------------------------------------------------- /ecgmentations/augmentations/filter/transformations.py: -------------------------------------------------------------------------------- 1 | import ecgmentations.augmentations.misc as M 2 | import ecgmentations.augmentations.filter.functional as F 3 | 4 | from ecgmentations.core.augmentation import EcgOnlyAugmentation 5 | 6 | class LowPassFilter(EcgOnlyAugmentation): 7 | """Apply low-pass filter to the input ecg. 
8 | """ 9 | def __init__( 10 | self, 11 | ecg_frequency=500., 12 | cutoff_frequency=47., 13 | always_apply=False, 14 | p=1.0, 15 | ): 16 | """ 17 | :NOTE: 18 | cutoff_frequency: 19 | 47 Hz (default) "Effective Data Augmentation, Filters, and Automation Techniques for Automatic 12-Lead ECG Classification Using Deep Residual" 20 | 21 | :args: 22 | ecg_frequency: float 23 | frequency of the input ecg 24 | cutoff_frequency: float 25 | cutoff frequency for filter 26 | """ 27 | super(LowPassFilter, self).__init__(always_apply, p) 28 | 29 | self.ecg_frequency = M.prepare_non_negative_float(ecg_frequency, 'ecg_frequency') 30 | self.cutoff_frequency = M.prepare_non_negative_float(cutoff_frequency, 'cutoff_frequency') 31 | 32 | def apply(self, ecg, **params): 33 | return F.lowpass_filter(ecg, self.ecg_frequency, self.cutoff_frequency) 34 | 35 | def get_transform_init_args_names(self): 36 | return ('ecg_frequency', 'cutoff_frequency') 37 | 38 | class HighPassFilter(EcgOnlyAugmentation): 39 | """Apply high-pass filter to the input ecg. 
40 | """ 41 | def __init__( 42 | self, 43 | ecg_frequency=500., 44 | cutoff_frequency=0.5, 45 | always_apply=False, 46 | p=1.0, 47 | ): 48 | """ 49 | :NOTE: 50 | cutoff_frequency: 51 | 0.5 Hz (default) "Effective Data Augmentation, Filters, and Automation Techniques for Automatic 12-Lead ECG Classification Using Deep Residual" 52 | 53 | :args: 54 | ecg_frequency: float 55 | frequency of the input ecg 56 | cutoff_frequency: float 57 | cutoff frequency for filter 58 | """ 59 | super(HighPassFilter, self).__init__(always_apply, p) 60 | 61 | self.ecg_frequency = M.prepare_non_negative_float(ecg_frequency, 'ecg_frequency') 62 | self.cutoff_frequency = M.prepare_non_negative_float(cutoff_frequency, 'cutoff_frequency') 63 | 64 | def apply(self, ecg, **params): 65 | return F.highpass_filter(ecg, self.ecg_frequency, self.cutoff_frequency) 66 | 67 | def get_transform_init_args_names(self): 68 | return ('ecg_frequency', 'cutoff_frequency') 69 | 70 | class BandPassFilter(EcgOnlyAugmentation): 71 | """Apply band-pass filter to the input ecg. 72 | """ 73 | def __init__( 74 | self, 75 | ecg_frequency=500., 76 | cutoff_frequencies=(0.5, 47.), 77 | always_apply=False, 78 | p=1.0, 79 | ): 80 | """ 81 | :NOTE: 82 | cutoff_frequencies: 83 | see params of LowPassFilter and HighPassFilter 84 | 85 | :args: 86 | ecg_frequency: float 87 | frequency of the input ecg 88 | cutoff_frequencies: tuple of float 89 | cutoff frequencies for filter 90 | """ 91 | super(BandPassFilter, self).__init__(always_apply, p) 92 | 93 | self.ecg_frequency = M.prepare_non_negative_float(ecg_frequency, 'ecg_frequency') 94 | self.cutoff_frequencies = M.prepare_float_asymrange(cutoff_frequencies, 'cutoff_frequencies', low=0.) 
95 | 96 | def apply(self, ecg, **params): 97 | return F.bandpass_filter(ecg, self.ecg_frequency, self.cutoff_frequencies) 98 | 99 | def get_transform_init_args_names(self): 100 | return ('ecg_frequency', 'cutoff_frequencies') 101 | 102 | class SigmoidCompression(EcgOnlyAugmentation): 103 | """Apply sigmoid compression to the input ecg. 104 | """ 105 | def __init__( 106 | self, 107 | always_apply=False, 108 | p=1.0, 109 | ): 110 | super(SigmoidCompression, self).__init__(always_apply, p) 111 | 112 | def apply(self, ecg, **params): 113 | return F.sigmoid_compression(ecg) 114 | 115 | def get_transform_init_args_names(self): 116 | return tuple() 117 | -------------------------------------------------------------------------------- /ecgmentations/core/augmentation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from ecgmentations.core.transformation import Transformation 5 | from ecgmentations.core.utils import format_args 6 | 7 | class Augmentation(Transformation): 8 | """Root class for single augmentations 9 | """ 10 | def __init__(self, always_apply=False, p=0.5): 11 | """ 12 | :args: 13 | always_apply: bool 14 | the flag of force application 15 | p: float 16 | the probability of application 17 | """ 18 | super(Augmentation, self).__init__(always_apply, p) 19 | 20 | def __call__(self, *args, force_apply=False, **data): 21 | """ 22 | :args: 23 | force_apply: bool 24 | the flag of force application 25 | data: dict 26 | the data to make a transformation 27 | 28 | :return: 29 | dict of transformed data 30 | """ 31 | if args: 32 | raise KeyError('You have to pass data to augmentations as named arguments, for example: aug(ecg=ecg)') 33 | 34 | if self.whether_apply(force_apply): 35 | params = self.get_params() 36 | 37 | if self.targets_as_params: 38 | assert all(name in data for name in self.targets_as_params), '{} requires {}'.format( 39 | self.get_class_name(), self.targets_as_params 40 | ) 
41 | 42 | targets_as_params = {name: data[name] for name in self.targets_as_params} 43 | 44 | params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params) 45 | params.update(params_dependent_on_targets) 46 | 47 | return self.apply_with_params(params, **data) 48 | 49 | return data 50 | 51 | def apply_with_params(self, params, **data): 52 | if params is None: 53 | return data 54 | 55 | pdata = {} 56 | 57 | for name, datum in data.items(): 58 | if datum is not None: 59 | target_function = self._get_target_function(name) 60 | pdata[name] = target_function(datum, **params) 61 | else: 62 | pdata[name] = None 63 | 64 | return pdata 65 | 66 | def __repr__(self): 67 | state = self.get_base_init_args() 68 | state.update(self.get_transform_init_args()) 69 | 70 | name = self.get_class_name() 71 | args=format_args(state) 72 | 73 | return '{}({})'.format(name, args) 74 | 75 | def _get_target_function(self, name): 76 | target_function = self.targets.get(name, lambda x, **p: x) 77 | return target_function 78 | 79 | def get_params(self): 80 | return {} 81 | 82 | @property 83 | def targets(self): 84 | """ 85 | :NOTE: 86 | you must specify targets in subclass 87 | 88 | for example: ('ecg', ) or ('ecg', 'mask') 89 | """ 90 | raise NotImplementedError 91 | 92 | @property 93 | def targets_as_params(self): 94 | return [] 95 | 96 | def get_params_dependent_on_targets(self, params): 97 | raise NotImplementedError( 98 | 'Method get_params_dependent_on_targets is not implemented in class {}'.format(self.get_class_name()) 99 | ) 100 | 101 | def get_transform_init_args_names(self): 102 | raise NotImplementedError( 103 | 'Class {} is not serializable because the `get_transform_init_args_names` method is not ' 104 | 'implemented'.format(self.get_class_name()) 105 | ) 106 | 107 | def get_transform_init_args(self): 108 | return {k: getattr(self, k) for k in self.get_transform_init_args_names()} 109 | 110 | class EcgOnlyAugmentation(Augmentation): 111 | """Augmentation 
applied to ecg only 112 | """ 113 | def apply(self, ecg, **params): 114 | raise NotImplementedError 115 | 116 | @property 117 | def targets(self): 118 | return { 'ecg': self.apply } 119 | 120 | class DualAugmentation(Augmentation): 121 | """Augmentation for segmentation task 122 | """ 123 | @property 124 | def targets(self): 125 | return { 126 | 'ecg': self.apply, 127 | 'mask': self.apply_to_mask, 128 | } 129 | 130 | def apply_to_mask(self, mask, **params): 131 | return self.apply(mask, **{k: cv2.INTER_NEAREST if k == 'interpolation' else v for k, v in params.items()}) 132 | 133 | class Identity(DualAugmentation): 134 | """Identity transform 135 | """ 136 | def apply(self, ecg, **params): 137 | return ecg 138 | 139 | def get_transform_init_args_names(self): 140 | return tuple() 141 | -------------------------------------------------------------------------------- /ecgmentations/core/composition.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ecgmentations.core.transformation import Transformation 4 | from ecgmentations.core.utils import format_args, get_shortest_class_fullname 5 | 6 | class Composition(Transformation): 7 | def __init__(self, transforms, always_apply, p): 8 | """ 9 | :args: 10 | transforms: list of Transformation 11 | list of operations to compose 12 | always_apply: bool 13 | the flag of force application 14 | p: float 15 | the probability of application 16 | """ 17 | super(Composition, self).__init__(always_apply, p) 18 | 19 | if not isinstance(transforms, list): 20 | raise RuntimeError( 21 | 'transforms is type of {} that is not list'.format(type(transforms)) 22 | ) 23 | elif not all(isinstance(t, Transformation) for t in transforms): 24 | for idx, t in enumerate(transforms): 25 | if not isinstance(t, Transformation): 26 | raise RuntimeError( 27 | 'object at {} position is not subtype of Transformation'.format(idx) 28 | ) 29 | 30 | self.transformations = transforms 31 | 32 | 
def __len__(self): 33 | return len(self.transformations) 34 | 35 | def __getitem__(self, idx): 36 | return self.transformations[idx] 37 | 38 | def __repr__(self): 39 | return self.repr() 40 | 41 | def repr(self, indent=Transformation.REPR_INDENT_STEP): 42 | args = self.get_base_init_args() 43 | 44 | repr_string = self.get_class_name() + '([' 45 | 46 | for t in self.transformations: 47 | repr_string += '\n' 48 | 49 | if hasattr(t, 'repr'): 50 | t_repr = t.repr(indent + self.REPR_INDENT_STEP) 51 | else: 52 | t_repr = repr(t) 53 | 54 | repr_string += ' ' * indent + t_repr + ',' 55 | 56 | repr_string += '\n' + ' ' * (indent - self.REPR_INDENT_STEP) + '], {args})'.format(args=format_args(args)) 57 | 58 | return repr_string 59 | 60 | @classmethod 61 | def get_class_fullname(cls): 62 | return get_shortest_class_fullname(cls) 63 | 64 | class Sequential(Composition): 65 | """Compose transforms to apply sequentially. 66 | """ 67 | def __init__(self, transforms, always_apply=False, p=1.0): 68 | """ 69 | :args: 70 | transforms: list of Apply 71 | list of operations to apply sequentially 72 | always_apply: bool 73 | the flag of force application 74 | p: float 75 | the probability of application 76 | """ 77 | super(Sequential, self).__init__(transforms, always_apply, p) 78 | 79 | def __call__(self, *args, force_apply=False, **data): 80 | if self.whether_apply(force_apply): 81 | for transform in self.transformations: 82 | data = transform(**data) 83 | 84 | return data 85 | 86 | class NonSequential(Sequential): 87 | """Compose transformations to apply sequentially in random order. 88 | """ 89 | def __call__(self, *args, force_apply=False, **data): 90 | if self.whether_apply(force_apply): 91 | np.random.shuffle(self.transformations) 92 | 93 | for transform in self.transformations: 94 | data = transform(**data) 95 | 96 | return data 97 | 98 | class OneOf(Composition): 99 | """Select one of transforms to apply. 
100 | """ 101 | def __init__(self, transforms, always_apply=False, p=0.5): 102 | """ 103 | :NOTE: 104 | transform probabilities will be normalized to one 1, so in this case transforms probabilities works as weights. 105 | 106 | :args: 107 | transforms: list of Apply 108 | list of operations to select one to apply 109 | always_apply: bool 110 | the flag of force application 111 | p: float 112 | the probability of application 113 | """ 114 | super(OneOf, self).__init__(transforms, always_apply, p) 115 | 116 | transforms_ps = [t.p for t in self.transformations] 117 | s = sum(transforms_ps) 118 | 119 | self.transformations_ps = [t / s for t in transforms_ps] 120 | 121 | def __call__(self, *args, force_apply = False, **data): 122 | if self.transformations_ps and self.whether_apply(force_apply): 123 | transform = np.random.choice(self.transformations, p=self.transformations_ps) 124 | data = transform(force_apply=True, **data) 125 | 126 | return data 127 | -------------------------------------------------------------------------------- /ecgmentations/augmentations/misc.py: -------------------------------------------------------------------------------- 1 | def _non_negative_number(param, name): 2 | if param < 0: 3 | raise ValueError('{} should be non negative.'.format(name)) 4 | 5 | return param 6 | 7 | def _prepare_int(param, name, check): 8 | if not isinstance(param, int): 9 | raise ValueError( 10 | '{} must be scalar (int).'.format( 11 | name 12 | ) 13 | ) 14 | 15 | param = int(param) 16 | 17 | return check(param, name) 18 | 19 | prepare_int = lambda param, name: _prepare_int(param, name, lambda x, _: x) 20 | prepare_non_negative_int = lambda param, name: _prepare_int(param, name, _non_negative_number) 21 | 22 | def _prepare_float(param, name, check): 23 | if not isinstance(param, (float, int)): 24 | raise ValueError( 25 | '{} must be scalar (float).'.format( 26 | name 27 | ) 28 | ) 29 | 30 | param = float(param) 31 | 32 | return check(param, name) 33 | 34 | 
def prepare_float(param, name):
    """Validate the scalar parameter ``param`` with no extra value restriction.

    Thin wrapper over ``_prepare_float`` (defined earlier in this module); the
    value check passed to it returns the value unchanged.
    """
    return _prepare_float(param, name, lambda x, _: x)

def prepare_non_negative_float(param, name):
    """Validate the scalar parameter ``param`` and require it to be non-negative.

    Thin wrapper over ``_prepare_float`` using the ``_non_negative_number`` check
    (both defined earlier in this module).
    """
    return _prepare_float(param, name, _non_negative_number)

def _prepare_asymrange(param, name, low, kind):
    """Normalize ``param`` into an ordered ``(min, max)`` tuple bounded below by ``low``.

    Shared implementation of ``prepare_int_asymrange`` and ``prepare_float_asymrange``.

    :args:
        param: kind or (kind, kind)
            a scalar ``x`` is expanded to ``(low, x)``; a pair is sorted ascending
        name: str
            parameter name used in error messages
        low: kind
            inclusive lower bound of the range
        kind: type
            ``int`` or ``float``; scalars are checked with ``isinstance``,
            pair elements with an exact type match

    :returns:
        tuple (min, max)

    :raises:
        ValueError: wrong type, pair of wrong length, or range starting below ``low``
    """
    type_name = kind.__name__

    if isinstance(param, kind):
        if param < low:
            raise ValueError(
                'Invalid value of {}. Got {} that less than {}.'.format(
                    name, param, low
                )
            )

        return (low, param)

    if not isinstance(param, (tuple, list)):
        raise ValueError(
            '{} must be either scalar ({}) or tuple ({}, {}).'.format(
                name, type_name, type_name, type_name
            )
        )

    if len(param) != 2:
        raise ValueError(
            'Invalid value of {}. Got {}'.format(
                name, param
            )
        )

    # exact type match on purpose: e.g. an int inside a float range is rejected
    if [type(value) for value in param] != [kind, kind]:
        raise ValueError(
            '{} must be tuple ({}, {}).'.format(
                name, type_name, type_name
            )
        )

    if param[0] > param[1]:
        param = (param[1], param[0])

    if param[0] < low:
        raise ValueError(
            'Invalid value of {}. Got: {} that less than {}.'.format(
                name, param, low
            )
        )

    return tuple(param)

def prepare_int_asymrange(param, name, low):
    """Normalize an int or (int, int) parameter into a sorted range starting at or above ``low``.

    A scalar ``x`` becomes ``(low, x)``. Raises ValueError on invalid input.
    """
    return _prepare_asymrange(param, name, low, int)

def prepare_float_symrange(param, name):
    """Normalize a float or (float, float) parameter into a sorted range.

    A scalar ``x`` becomes the symmetric range ``(-|x|, |x|)``; a pair is sorted
    ascending (its element types are not checked).

    :raises:
        ValueError: pair of wrong length, or neither a float nor a tuple/list
    """
    if isinstance(param, float):
        if param >= 0:
            param = (-param, param)
        else:
            param = (param, -param)
    elif isinstance(param, (tuple, list)):
        if len(param) != 2:
            raise ValueError(
                'Invalid value of {}. Got: {}.'.format(
                    name, param
                )
            )

        if param[0] > param[1]:
            param = (param[1], param[0])
    else:
        raise ValueError(
            '{} must be either scalar (float) or tuple (float, float).'.format(
                name
            )
        )

    return tuple(param)

def prepare_float_asymrange(param, name, low):
    """Normalize a float or (float, float) parameter into a sorted range starting at or above ``low``.

    A scalar ``x`` becomes ``(low, x)``. Raises ValueError on invalid input.
    """
    return _prepare_asymrange(param, name, low, float)

# --------------------------------------------------------------------------------
# /ecgmentations/augmentations/transformations.py:
# --------------------------------------------------------------------------------
import cv2
import numpy as np

import ecgmentations.core.enum as E
import ecgmentations.core.constants as C
import ecgmentations.augmentations.misc as M
import ecgmentations.augmentations.functional as F

from ecgmentations.core.augmentation import EcgOnlyAugmentation, DualAugmentation

class AmplitudeInvert(EcgOnlyAugmentation):
    """Invert the input ecg.
    """
    def apply(self, ecg, **params):
        return F.amplitude_invert(ecg)

    def get_transform_init_args_names(self):
        return tuple()

class ChannelShuffle(EcgOnlyAugmentation):
    """Randomly rearrange channels of the input ecg.
    """
    def apply(self, ecg, channel_order, **params):
        return F.channel_shuffle(ecg, channel_order)

    @property
    def targets_as_params(self):
        return ['ecg']

    def get_params_dependent_on_targets(self, params):
        """Sample a random permutation of the channel indices.

        :raises:
            RuntimeError: the ecg has no explicit channel dimension
        """
        if len(params['ecg'].shape) == C.NUM_MONO_CHANNEL_DIMENSIONS:
            raise RuntimeError('Ecg has implicit channel. ChannelShuffle is not defined.')

        channel_order = np.arange(params['ecg'].shape[C.CHANNEL_DIM])
        np.random.shuffle(channel_order)

        return {'channel_order': channel_order}

    def get_transform_init_args_names(self):
        return ()

class ChannelDropout(EcgOnlyAugmentation):
    """Randomly drop channels in the input ecg.
    """
    def __init__(
        self,
        channel_drop_range=(1, 1),
        fill_value=0,
        always_apply=False,
        p=0.5
    ):
        """
        :args:
            channel_drop_range: int or (int, int)
                range for select the number of dropping channels; a scalar n means (1, n)
            fill_value: int or float
                fill value for dropped channels
        """
        super(ChannelDropout, self).__init__(always_apply, p)

        self.channel_drop_range = M.prepare_int_asymrange(channel_drop_range, 'channel_drop_range', 1)

        # bounds come from the normalized range: the raw argument may be a scalar
        # or an unsorted pair
        self.min_drop_channels, self.max_drop_channels = self.channel_drop_range

        self.fill_value = M.prepare_float(fill_value, 'fill_value')

    def apply(self, ecg, channels_to_drop, **params):
        return F.channel_dropout(ecg, channels_to_drop, self.fill_value)

    @property
    def targets_as_params(self):
        return ['ecg']

    def get_params_dependent_on_targets(self, params):
        """Choose which distinct channels to drop.

        :raises:
            RuntimeError: the ecg has no explicit channel dimension
            NotImplementedError: the ecg has a single channel
            ValueError: the configured maximum would drop every channel
        """
        if len(params['ecg'].shape) == C.NUM_MONO_CHANNEL_DIMENSIONS:
            raise RuntimeError('Ecg has implicit channel. ChannelDropout is not defined.')

        num_channels = params['ecg'].shape[C.CHANNEL_DIM]

        if num_channels == 1:
            raise NotImplementedError('Ecg has one channel. ChannelDropout is not defined.')

        if not ( self.max_drop_channels < num_channels ):
            raise ValueError('Can not drop all channels in ChannelDropout.')

        num_drop_channels = np.random.randint(low=self.min_drop_channels, high=self.max_drop_channels + 1)
        # replace=False: indices must be distinct, otherwise fewer channels than
        # num_drop_channels would actually be dropped
        channels_to_drop = np.random.choice(num_channels, size=num_drop_channels, replace=False)

        return {'channels_to_drop': channels_to_drop}

    def get_transform_init_args_names(self):
        return ('channel_drop_range', 'fill_value')

class GaussNoise(EcgOnlyAugmentation):
    """Randomly add gaussian noise to the input ecg.
    """
    def __init__(
        self,
        mean=0.,
        variance=0.01,
        per_channel=True,
        always_apply=False,
        p=0.5
    ):
        """
        :NOTE:
            variance:
                0.01 mV (default) "Self-supervised representation learning from 12-lead ECG data"
                0.02 mV "RandECG: Data Augmentation for Deep Neural Network Based ECG Classification"

        :args:
            mean: float
                mean of gaussian noise
            variance: float
                variance of gaussian noise
            per_channel: bool
                if set to True, noise will be sampled for each channel independently
        """
        super(GaussNoise, self).__init__(always_apply, p)

        self.mean = M.prepare_float(mean, 'mean')
        self.variance = M.prepare_non_negative_float(variance, 'variance')
        self.per_channel = per_channel

    def apply(self, ecg, gauss, **params):
        return F.add(ecg, gauss)

    @property
    def targets_as_params(self):
        return ['ecg']

    def get_params_dependent_on_targets(self, params):
        """Sample noise with std ``sqrt(variance)``.

        Uses the full ecg shape for per-channel noise on a multi-channel ecg,
        otherwise only the leading ``C.NUM_SPATIAL_DIMENSIONS`` axes (presumably
        shared across channels by broadcasting in ``F.add`` -- TODO confirm).
        """
        if self.per_channel and len(params['ecg'].shape) == C.NUM_MULTI_CHANNEL_DIMENSIONS:
            shape = params['ecg'].shape
        else:
            shape = params['ecg'].shape[:C.NUM_SPATIAL_DIMENSIONS]

        gauss = np.random.normal(self.mean, self.variance**0.5, shape)

        return {'gauss': gauss}

    def get_transform_init_args_names(self):
        return ('mean', 'variance', 'per_channel')

class GaussBlur(EcgOnlyAugmentation):
    """Blur by gaussian the input ecg.
    """
    def __init__(
        self,
        variance=1.,
        kernel_size_range=(5, 5),
        always_apply=False,
        p=0.5
    ):
        """
        :NOTE:
            transformation is similar to gaussian blur in paper "Self-supervised representation learning from
            12-lead ECG data" with variance equals one and kernel size equals five

            code kernel is about of (0.05, 0.25, 0.40, 0.25, 0.05)
            paper kernel is (0.10, 0.20, 0.40, 0.20, 0.10)

        :args:
            variance: float
                variance of gaussian kernel; zero gives an identity (delta) kernel
            kernel_size_range: (int, int)
                range for select kernel size of blur filter; both borders must be odd
        """
        super(GaussBlur, self).__init__(always_apply, p)

        self.variance = M.prepare_non_negative_float(variance, 'variance')
        self.kernel_size_range = M.prepare_int_asymrange(kernel_size_range, 'kernel_size_range', 0)

        # bounds come from the normalized range, not the raw argument
        self.min_kernel_size, self.max_kernel_size = self.kernel_size_range

        if self.min_kernel_size % 2 == 0 or self.max_kernel_size % 2 == 0:
            raise ValueError('Invalid range borders. Must be odd, but got: {}.'.format(kernel_size_range))

    def apply(self, ecg, kernel, **params):
        return F.conv(ecg, kernel, E.BorderType.CONSTANT, 0)

    def get_params(self):
        """Sample an odd kernel size and build a normalized gaussian kernel of exactly that many taps."""
        kernel_size = 2 * np.random.randint(self.min_kernel_size // 2, self.max_kernel_size // 2 + 1) + 1

        half_size = kernel_size // 2
        offsets = np.arange(-half_size, half_size + 1)

        if self.variance > 0:
            kernel = np.exp(-0.5 * np.square(offsets) / self.variance)
            kernel = kernel / np.sum(kernel)
        else:
            # the gaussian degenerates to a delta; the formula above would give 0/0 = NaN
            kernel = (offsets == 0).astype(float)

        return {'kernel': kernel}

    def get_transform_init_args_names(self):
        return ('variance', 'kernel_size_range')

class AmplitudeScale(EcgOnlyAugmentation):
    """Scale amplitude of the input ecg.
    """
    def __init__(
        self,
        scaling_range=(-0.05, 0.05),
        always_apply=False,
        p=0.5
    ):
        """
        :args:
            scaling_range: float or (float, float)
                range for selecting scaling factor offset; the applied factor is
                1 + uniform(scaling_range)
        """
        super(AmplitudeScale, self).__init__(always_apply, p)

        self.scaling_range = M.prepare_float_symrange(scaling_range, 'scaling_range')

        self.min_scaling_range = self.scaling_range[0]
        self.max_scaling_range = self.scaling_range[1]

    def apply(self, ecg, scaling_factor, **params):
        return F.multiply(ecg, scaling_factor)

    def get_params(self):
        scaling_factor = 1 + np.random.uniform(self.min_scaling_range, self.max_scaling_range)

        return {'scaling_factor': scaling_factor}

    def get_transform_init_args_names(self):
        return ('scaling_range', )

# --------------------------------------------------------------------------------
# /ecgmentations/augmentations/pulse/transformations.py:
# --------------------------------------------------------------------------------
import numpy as np

import ecgmentations.augmentations.misc as M
import ecgmentations.augmentations.pulse.functional as F

from ecgmentations.core.augmentation import EcgOnlyAugmentation

class SinePulse(EcgOnlyAugmentation):
    """Add sine pulse to the input ecg.
    """
    def __init__(
        self,
        ecg_frequency=500.,
        pulse_frequency_range=(0., 1.),
        amplitude_limit=1.,
        always_apply=False,
        p=0.5
    ):
        """
        :NOTE:
            amplitude_limit:
                1 mV (default) "RandECG: Data Augmentation for Deep Neural Network Based ECG Classification"

            pulse_frequency_range:
                0 - 1 Hz (default) "RandECG: Data Augmentation for Deep Neural Network Based ECG Classification"

        :args:
            ecg_frequency: float
                frequency of the input ecg
            pulse_frequency_range: float or (float, float)
                range of pulse frequency
            amplitude_limit: float
                limit of pulse amplitude
        """
        super(SinePulse, self).__init__(always_apply, p)

        self.ecg_frequency = M.prepare_non_negative_float(ecg_frequency, 'ecg_frequency')
        self.pulse_frequency_range = M.prepare_float_asymrange(pulse_frequency_range, 'pulse_frequency_range', 0.)

        self.pulse_frequency_min = self.pulse_frequency_range[0]
        self.pulse_frequency_max = self.pulse_frequency_range[1]
        self.pulse_frequency_delta = self.pulse_frequency_max - self.pulse_frequency_min

        self.amplitude_limit = M.prepare_non_negative_float(amplitude_limit, 'amplitude_limit')

    def apply(self, ecg, amplitude, frequency, phase, **params):
        return F.add_sine_pulse(ecg, self.ecg_frequency, amplitude, frequency, phase)

    def get_params(self):
        """Sample amplitude in [0, limit), frequency in [min, max) and phase in [0, 2*pi)."""
        amplitude = np.random.random() * self.amplitude_limit
        frequency = np.random.random() * self.pulse_frequency_delta + self.pulse_frequency_min
        phase = np.random.random() * 2 * np.pi

        return {'amplitude': amplitude, 'frequency': frequency, 'phase': phase}

    def get_transform_init_args_names(self):
        return ('ecg_frequency', 'pulse_frequency_range', 'amplitude_limit')

class PowerlineNoise(SinePulse):
    """Add powerline noise to the input ecg.
    """
    def __init__(
        self,
        ecg_frequency=500.,
        powerline_frequency=50.,
        amplitude_limit=0.3,
        always_apply=False,
        p=0.5
    ):
        """
        :NOTE:
            powerline frequency:
                50 Hz is for Europe
                60 Hz is for USA or Asia

            amplitude_limit:
                0.3 mV (default) "IIR digital filter design for powerline noise cancellation of ECG signal using arduino platform"
                0.333 mV "A Comparison of the Noise Sensitivity of Nine QRS Detection Algorithms"
                0.25 mV (low noise) "Self-supervised representation learning from 12-lead ECG data"
                2. mV (high noise) "Self-supervised representation learning from 12-lead ECG data"

        :args:
            ecg_frequency: float
                frequency of the input ecg
            powerline_frequency: float
                frequency of powerline
            amplitude_limit: float
                limit of noise amplitude
        """
        self.powerline_frequency = M.prepare_non_negative_float(powerline_frequency, 'powerline_frequency')
        # a degenerate range pins the sine frequency to the powerline frequency
        powerline_frequency_range = (self.powerline_frequency, self.powerline_frequency)

        super(PowerlineNoise, self).__init__(ecg_frequency, powerline_frequency_range, amplitude_limit, always_apply, p)

    def get_transform_init_args_names(self):
        return ('ecg_frequency', 'powerline_frequency', 'amplitude_limit')

class RespirationNoise(SinePulse):
    """Add respiration noise to the input ecg
    """
    def __init__(
        self,
        ecg_frequency=500.,
        breathing_rate_range=(12, 18),
        amplitude_limit=1.,
        always_apply=False,
        p=0.5
    ):
        """
        :NOTE:
            breathing_rate:
                0.333 Hz "A Comparison of the Noise Sensitivity of Nine QRS Detection Algorithms"
                0.333 equals to 20 bpm

            amplitude_limit:
                1 mV (default) "A Comparison of the Noise Sensitivity of Nine QRS Detection Algorithms"

        :args:
            ecg_frequency: float
                frequency of the input ecg
            breathing_rate_range: int or (int, int)
                breathing rate range in bpm; a scalar n means (0, n)
            amplitude_limit: float
                limit of noise amplitude
        """
        self.breathing_rate_range = M.prepare_int_asymrange(breathing_rate_range, 'breathing_rate_range', 0)
        # convert the normalized bpm range to Hz (the raw argument may be a scalar)
        breathing_frequency_range = (self.breathing_rate_range[0] / 60, self.breathing_rate_range[1] / 60)

        super(RespirationNoise, self).__init__(ecg_frequency, breathing_frequency_range, amplitude_limit, always_apply, p)

    def get_transform_init_args_names(self):
        return ('ecg_frequency', 'breathing_rate_range', 'amplitude_limit')

# NOTE(review): BaselineWander below is a disabled, unfinished draft (get_params is incomplete).
# class BaselineWander(EcgOnlyAugmentation):
#     """Add baseline wander to the input ecg.
#     """
#     def __init__(
#         self,
#         ecg_frequency=500.,
#         amplitude_limit=0.05,
#         always_apply=False,
#         p=0.5
#     ):
#         """
#         :NOTE:
#             amplitude_limit:
#                 0.05 mV (default, low noise) "Self-supervised representation learning from 12-lead ECG data"
#                 0.3 mV - (high noise) "Self-supervised representation learning from 12-lead ECG data"

#         :args:
#             ecg_frequency: float
#                 frequency of the input ecg
#             amplitude_limit: float
#                 limit of pulse amplitude
#         """
#         super(BaselineWander, self).__init__(always_apply, p)

#         self.ecg_frequency = M.prepare_non_negative_float(ecg_frequency, 'ecg_frequency')
#         self.amplitude_limit = M.prepare_non_negative_float(amplitude_limit, 'amplitude_limit')
#         self.frequencies = (0.01, 0.02, 0.03, 0.04, 0.05)

#     # def apply(self, ecg, amplitude, frequency, phase, **params):
#     #     return F.add_sine_pulse(ecg, self.ecg_frequency, amplitude, frequency, phase)

#     def get_params(self):
#         params = [ for _ in self.frequencies ]
#         # amplitude = np.random.random() * self.amplitude_limit
#         # phase = np.random.random() * 2 * np.pi

#         # return {'amplitude': amplitude, 'frequency': frequency, 'phase': phase}

#     def get_transform_init_args_names(self):
#         return ('ecg_frequency', 'amplitude_limit')

class SquarePulse(EcgOnlyAugmentation):
    """Add square pulse to the input ecg.
    """
    def __init__(
        self,
        ecg_frequency=500.,
        pulse_frequency_range=(0., 5.),
        amplitude_limit=0.02,
        always_apply=False,
        p=0.5
    ):
        """
        :NOTE:
            amplitude_limit:
                0.02 mV (default) "RandECG: Data Augmentation for Deep Neural Network Based ECG Classification"

            pulse_frequency_range:
                0 - 5 Hz (default) "RandECG: Data Augmentation for Deep Neural Network Based ECG Classification"

        :args:
            ecg_frequency: float
                frequency of the input ecg
            pulse_frequency_range: float or (float, float)
                range of pulse frequency
            amplitude_limit: float
                limit of pulse amplitude
        """
        super(SquarePulse, self).__init__(always_apply, p)

        self.ecg_frequency = M.prepare_non_negative_float(ecg_frequency, 'ecg_frequency')
        self.pulse_frequency_range = M.prepare_float_asymrange(pulse_frequency_range, 'pulse_frequency_range', 0.)

        self.pulse_frequency_min = self.pulse_frequency_range[0]
        self.pulse_frequency_max = self.pulse_frequency_range[1]
        self.pulse_frequency_delta = self.pulse_frequency_max - self.pulse_frequency_min

        self.amplitude_limit = M.prepare_non_negative_float(amplitude_limit, 'amplitude_limit')

    def apply(self, ecg, amplitude, frequency, phase, **params):
        return F.add_square_pulse(ecg, self.ecg_frequency, amplitude, frequency, phase)

    def get_params(self):
        """Sample amplitude in [0, limit), frequency in [min, max) and phase in [0, 2*pi)."""
        amplitude = np.random.random() * self.amplitude_limit
        frequency = np.random.random() * self.pulse_frequency_delta + self.pulse_frequency_min
        phase = np.random.random() * 2 * np.pi

        return {'amplitude': amplitude, 'frequency': frequency, 'phase': phase}

    def get_transform_init_args_names(self):
        return ('ecg_frequency', 'pulse_frequency_range', 'amplitude_limit')

# --------------------------------------------------------------------------------
# /ecgmentations/augmentations/time/transformations.py:
# --------------------------------------------------------------------------------
import numpy as np

import ecgmentations.core.enum as E
import ecgmentations.core.constants as C
import ecgmentations.augmentations.misc as M
import ecgmentations.augmentations.time.functional as F

from ecgmentations.core.augmentation import EcgOnlyAugmentation, DualAugmentation

class TimeReverse(DualAugmentation):
    """Reverse the input ecg.
    """
    def apply(self, ecg, **params):
        return F.time_reverse(ecg)

    def get_transform_init_args_names(self):
        return tuple()

class TimeShift(DualAugmentation):
    """Shift the input ecg along time axis.
    """
    def __init__(
        self,
        shift_limit=0.05,
        border_mode=E.BorderType.DEFAULT,
        fill_value=0.,
        mask_fill_value=0,
        always_apply=False,
        p=0.5,
    ):
        """
        :args:
            shift_limit: float
                limit of shifting; the shift is drawn uniformly from [-shift_limit, shift_limit)
            border_mode: OpenCV flag
                OpenCV border mode
            fill_value: int or float or None
                padding value if border_mode is E.BorderType.CONSTANT
            mask_fill_value: int or None
                padding value for mask if border_mode is E.BorderType.CONSTANT
        """
        super(TimeShift, self).__init__(always_apply, p)

        self.shift_limit = M.prepare_non_negative_float(shift_limit, 'shift_limit')

        self.border_mode = border_mode
        self.fill_value = M.prepare_float(fill_value, 'fill_value')
        self.mask_fill_value = M.prepare_int(mask_fill_value, 'mask_fill_value')

    def apply(self, ecg, shift, **params):
        return F.time_shift(ecg, shift, self.border_mode, self.fill_value)

    def apply_to_mask(self, mask, shift, **params):
        return F.time_shift(mask, shift, self.border_mode, self.mask_fill_value)

    def get_params(self):
        shift = (2 * np.random.random() - 1) * self.shift_limit

        return {'shift': shift}

    def get_transform_init_args_names(self):
        return ('shift_limit', 'border_mode', 'fill_value', 'mask_fill_value')

class TimeSegmentShuffle(DualAugmentation):
    """Randomly shuffle of the input ecg segments
    """
    def __init__(
        self,
        num_segments=5,
        always_apply=False,
        p=0.5,
    ):
        """
        :args:
            num_segments: int
                count of grid cells on the ecg
        """
        super(TimeSegmentShuffle, self).__init__(always_apply, p)

        self.num_segments = M.prepare_non_negative_int(num_segments, 'num_segments')

    def apply(self, ecg, segment_order, **params):
        return F.time_segment_swap(ecg, segment_order)

    def get_params(self):
        """Sample a random permutation of the segment indices."""
        segment_order = np.arange(self.num_segments)
        np.random.shuffle(segment_order)

        return {'segment_order': segment_order}

    def get_transform_init_args_names(self):
        return ('num_segments', )

class RandomTimeWrap(DualAugmentation):
    """Randomly stretch and squeeze contiguous segments of the input ecg
    """
    def __init__(
        self,
        num_steps=5,
        wrap_limit=0.05,
        always_apply=False,
        p=0.5,
    ):
        """
        :args:
            num_steps: int
                count of grid cells on the ecg
            wrap_limit: float
                limit of stretching or squeezing
        """
        super(RandomTimeWrap, self).__init__(always_apply, p)

        self.num_steps = M.prepare_non_negative_int(num_steps, 'num_steps')
        self.wrap_limit = M.prepare_non_negative_float(wrap_limit, 'wrap_limit')

    def apply(self, ecg, cells, ncells, **params):
        return F.time_wrap(ecg, cells, ncells)

    def get_params(self):
        """Build a uniform grid on [0, 1] and a copy whose interior knots are randomly displaced.

        The end knots 0 and 1 are never moved; each interior knot moves by at most
        ``wrap_limit * 0.5 / (num_steps + 1)`` in a random direction.
        """
        cells = np.linspace(0, 1, self.num_steps + 1)
        ncells = np.linspace(0, 1, self.num_steps + 1)

        if self.num_steps > 1:
            directions = np.random.choice([-1, 1], size=self.num_steps - 1)
            shifts = np.random.random(size=self.num_steps-1) * self.wrap_limit * 0.5

            ncells[1:-1] += shifts * directions / (self.num_steps + 1)

        return {'cells': cells, 'ncells': ncells}

    def get_transform_init_args_names(self):
        return ('num_steps', 'wrap_limit')

class TimeCutout(DualAugmentation):
    """Randomly cutout time ranges in the input ecg.
    """
    def __init__(
        self,
        num_ranges=(1, 5),
        length_range=(0, 50),
        fill_value=0.,
        mask_fill_value=None,
        always_apply=False,
        p=0.5,
    ):
        """
        :args:
            num_ranges: int or (int, int)
                number of cutout ranges; a scalar n means (0, n)
            length_range: int or (int, int)
                range for selecting cutout length; a scalar n means (0, n)
            fill_value: float
                value to fill cutouted ranges in the input ecg
            mask_fill_value: int or None
                value to fill cutouted ranges in the mask. if value is None, mask is not affected
        """
        super(TimeCutout, self).__init__(always_apply, p)

        self.num_ranges = M.prepare_int_asymrange(num_ranges, 'num_ranges', 0)

        # bounds come from the normalized ranges, not the raw arguments
        self.min_num_ranges, self.max_num_ranges = self.num_ranges

        self.length_range = M.prepare_int_asymrange(length_range, 'length_range', 0)

        self.min_length_range, self.max_length_range = self.length_range

        self.fill_value = M.prepare_float(fill_value, 'fill_value')
        self.mask_fill_value = mask_fill_value

    def apply(self, ecg, cutouts, **params):
        return F.time_cutout(ecg, cutouts, self.fill_value)

    def apply_to_mask(self, mask, cutouts, **params):
        if self.mask_fill_value is None:
            return mask
        else:
            return F.time_cutout(mask, cutouts, self.mask_fill_value)

    @property
    def targets_as_params(self):
        return ['ecg']

    def get_params_dependent_on_targets(self, params):
        """Sample a list of ``(start, length)`` cutouts that lie inside the ecg."""
        length = params['ecg'].shape[C.SPATIAL_DIM]

        cutouts = []

        for _ in range(np.random.randint(self.min_num_ranges, self.max_num_ranges + 1)):
            cutout_length = np.random.randint(self.min_length_range, self.max_length_range + 1)
            # a cutout can not be longer than the ecg; without the clamp the start
            # interval below would be empty and randint would raise
            cutout_length = min(cutout_length, length)
            cutout_start = np.random.randint(0, length - cutout_length + 1)

            cutouts.append((cutout_start, cutout_length))

        return {'cutouts': cutouts}

    def get_transform_init_args_names(self):
        return ('num_ranges', 'length_range', 'fill_value', 'mask_fill_value')

class TimeCrop(DualAugmentation):
    """Crop time segment from the input ecg.
    """
    def __init__(
        self,
        length=5000,
        position=E.PositionType.RANDOM,
        always_apply=False,
        p=1.0,
    ):
        """
        :args:
            length: int
                the length of cropped time segment
            position: PositionType or str
                position of cropped time segment
        """
        super(TimeCrop, self).__init__(always_apply, p)

        self.length = M.prepare_int(length, 'length')
        self.position = E.PositionType(position)

    def apply(self, ecg, left_bound, **params):
        return F.time_crop(ecg, left_bound, self.length)

    def get_params(self):
        """Map ``position`` to a value in [0, 1] passed to ``F.time_crop`` as ``left_bound``.

        Presumably a relative crop offset (0 = left edge, 1 = right edge) -- TODO confirm in F.time_crop.
        """
        if self.position == E.PositionType.LEFT:
            left_bound = 0.0
        elif self.position == E.PositionType.CENTER:
            left_bound = 0.5
        elif self.position == E.PositionType.RIGHT:
            left_bound = 1.0
        else:
            left_bound = np.random.random()

        return {'left_bound': left_bound}

    def get_transform_init_args_names(self):
        return ('length', 'position')

class CenterTimeCrop(TimeCrop):
    """Crop time segment to the center of the input ecg.
    """
    def __init__(
        self,
        length=5000,
        always_apply=False,
        p=1.0,
    ):
        """
        :args:
            length: int
                the length of cropped region
        """
        super(CenterTimeCrop, self).__init__(length, E.PositionType.CENTER, always_apply, p)

    def get_transform_init_args_names(self):
        return ('length', )

class RandomTimeCrop(TimeCrop):
    """Crop a random time segment of the input ecg.
    """
    def __init__(
        self,
        length=5000,
        always_apply=False,
        p=1.0,
    ):
        """
        :args:
            length: int
                the length of cropped region
        """
        super(RandomTimeCrop, self).__init__(length, E.PositionType.RANDOM, always_apply, p)

    def get_transform_init_args_names(self):
        return ('length', )

class TimePadIfNeeded(DualAugmentation):
    """Pad length of the ecg to the minimal length.
    """
    def __init__(
        self,
        min_length=5000,
        position=E.PositionType.CENTER,
        border_mode=E.BorderType.CONSTANT,
        fill_value=0.,
        mask_fill_value=0,
        always_apply=False,
        p=1.0,
    ):
        """
        :args:
            min_length: int
                minimal length to fill with padding
            position: PositionType or str
                position of ecg
            border_mode: OpenCV flag
                OpenCV border mode
            fill_value: int or float or None
                padding value if border_mode is E.BorderType.CONSTANT
            mask_fill_value: int or None
                padding value for mask if border_mode is E.BorderType.CONSTANT
        """
        super(TimePadIfNeeded, self).__init__(always_apply, p)

        self.min_length = M.prepare_non_negative_int(min_length, 'min_length')
        self.position = E.PositionType(position)

        self.border_mode = border_mode
        self.fill_value = M.prepare_float(fill_value, 'fill_value')
        self.mask_fill_value = M.prepare_int(mask_fill_value, 'mask_fill_value')

    # NOTE: 'rigth_pad' (sic) is kept: it is the key shared by get_params and apply*
    def apply(self, ecg, left_pad, rigth_pad, **params):
        return F.pad(ecg, left_pad, rigth_pad, self.border_mode, self.fill_value)

    def apply_to_mask(self, mask, left_pad, rigth_pad, **params):
        return F.pad(mask, left_pad, rigth_pad, self.border_mode, self.mask_fill_value)

    @property
    def targets_as_params(self):
        return ['ecg']

    def get_params_dependent_on_targets(self, params):
        """Split the missing length (if any) into left/right padding according to ``position``."""
        length = params['ecg'].shape[C.SPATIAL_DIM]

        pad_length = max(0, self.min_length - length)

        if self.position == E.PositionType.LEFT:
            left_pad = 0
            rigth_pad = pad_length
        elif self.position == E.PositionType.CENTER:
            left_pad = pad_length // 2
            rigth_pad = pad_length - left_pad
        elif self.position == E.PositionType.RIGHT:
            left_pad = pad_length
            rigth_pad = 0
        else:
            left_pad = np.random.randint(0, pad_length + 1)
            rigth_pad = pad_length - left_pad

        return {'left_pad': left_pad, 'rigth_pad': rigth_pad}

    def get_transform_init_args_names(self):
        return ('min_length', 'position', 'border_mode', 'fill_value', 'mask_fill_value')

class Pooling(EcgOnlyAugmentation):
    """Reduce resolution of time axis of the input ecg
    """
    def __init__(
        self,
        reduction=E.ReductionType.MEAN,
        kernel_size_range=(3, 5),
        always_apply=False,
        p=0.5
    ):
        """
        :args:
            reduction: ReductionType or str
                reduction type (MIN, MEAN, MAX)
            kernel_size_range: (int, int)
                range for select kernel size of pooling; both borders must be odd
        """
        super(Pooling, self).__init__(always_apply, p)

        self.reduction = E.ReductionType(reduction)

        self.kernel_size_range = M.prepare_int_asymrange(kernel_size_range, 'kernel_size_range', 0)

        # bounds come from the normalized range, not the raw argument
        self.min_kernel_size, self.max_kernel_size = self.kernel_size_range

        if self.min_kernel_size % 2 == 0 or self.max_kernel_size % 2 == 0:
            raise ValueError('Invalid range borders. Must be odd, but got: {}.'.format(kernel_size_range))

    def apply(self, ecg, kernel_size, **params):
        return F.pooling(ecg, self.reduction, kernel_size, E.BorderType.CONSTANT, 0)

    def get_params(self):
        """Sample an odd kernel size within the configured range."""
        kernel_size = 2 * np.random.randint(self.min_kernel_size // 2, self.max_kernel_size // 2 + 1) + 1

        return {'kernel_size': kernel_size}

    def get_transform_init_args_names(self):
        return ('reduction', 'kernel_size_range')

class Blur(Pooling):
    """Blur the input ecg.
    """
    def __init__(
        self,
        kernel_size_range=(3, 5),
        always_apply=False,
        p=0.5
    ):
        """
        :args:
            kernel_size_range: (int, int)
                range for select kernel size of blur filter
        """
        super(Blur, self).__init__(E.ReductionType.MEAN, kernel_size_range, always_apply, p)

    def get_transform_init_args_names(self):
        # Blur.__init__ has no 'reduction' argument (it is fixed to MEAN), so it
        # must not be reported as an init arg
        return ('kernel_size_range', )
# --------------------------------------------------------------------------------