├── LICENSE ├── README.md ├── reports ├── README.md └── abstract.png └── src ├── README.md ├── clinical_ts ├── a.md ├── config.py ├── data │ ├── a.md │ ├── collate_dataset.py │ ├── time_series_dataset.py │ ├── time_series_dataset_transforms.py │ └── time_series_dataset_utils.py ├── head │ ├── a.md │ └── multimodal.py ├── loss │ ├── a.md │ └── supervised.py ├── metric │ ├── a.md │ └── base.py ├── tabular │ ├── a.md │ └── base.py ├── task │ ├── a.md │ └── multimodal.py ├── template_model.py ├── template_modules.py ├── timeseries_utils.py ├── ts │ ├── a.md │ ├── base.py │ ├── basic_conv1d_modules │ │ ├── a.md │ │ └── basic_conv1d.py │ ├── encoder.py │ ├── head.py │ ├── s4.py │ ├── s4_modules │ │ ├── a.md │ │ ├── s42.py │ │ ├── s4_model.py │ │ └── s4_utils.py │ └── transformer_modules │ │ ├── a.md │ │ └── transformer.py └── utils │ ├── a.md │ ├── bootstrap_utils.py │ ├── callbacks.py │ ├── eval_utils_cafa.py │ └── schedulers.py ├── config ├── config_supervised_multimodal_labvalues_s4.yaml └── data │ ├── a.md │ └── multimodal_mdsed.yaml ├── data └── memmap │ └── a.md ├── ecg_utils.py ├── environment.yml ├── extensions └── cauchy │ ├── a.md │ ├── benchmark_cauchy.py │ ├── benchmark_cauchy_tune.py │ ├── cauchy.cpp │ ├── cauchy.py │ ├── cauchy_cuda.cu │ ├── map.h │ ├── setup.py │ ├── test_cauchy.py │ ├── tune_cauchy.py │ ├── tune_cauchy.sh │ ├── tuner.py │ └── tuning_setup.py ├── main_all.py ├── preprocessing.py └── timeseries_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AI4HealthUOL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 
10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## This is the official repository for CardioLab, a machine and deep learning framework for the estimation and monitoring of laboratory values through ECG data. 2 | 3 | CardioLab has been proposed in two main manuscripts: 4 | 5 | 1. **CardioLab: Laboratory Values Estimation from Electrocardiogram Features - An Exploratory Study.** Accepted at Computing in Cardiology (CinC) 2024. [![arXiv](https://img.shields.io/badge/arXiv-2407.18629-b31b1b.svg)](https://arxiv.org/abs/2407.18629) 6 | 7 | 2. **CardioLab: Laboratory Values Estimation and Monitoring from Electrocardiogram Signals - A Deep-Multimodal Approach** [![arXiv](https://img.shields.io/badge/arXiv-2411.14886-b31b1b.svg)](https://arxiv.org/abs/2411.14886) 8 | 9 | 10 | In terms of ECG data and clinical settings, our CinC manuscript investigates only the abnormality estimation (current) task using tabular ECG features, whereas our second manuscript investigates abnormality estimation (current) as well as abnormality monitoring (future) using raw ECG waveforms instead.
11 | 12 | 13 | ## Clinical Setting 14 | 15 | ![alt text](https://github.com/AI4HealthUOL/KardioLab/blob/main/reports/abstract.png?style=centerme) 16 | 17 | 18 | - A) Demonstrates the overall predictive workflow used in the study, where for model inputs we use ECG waveforms, demographics, biometrics, and vital signs, in a binary classification setting to predict abnormal laboratory values. 19 | 20 | - B) Demonstrates the estimation task, where for the feature space we sample the closest vital signs within 30 minutes of the ECG record, and the target is the closest laboratory value within 60 minutes. 21 | 22 | - C) Demonstrates the monitoring task, where the feature space also includes the closest vital signs within 30 minutes of the ECG record, and the target is the presence of any abnormal laboratory value within a defined future time horizon, for which we investigated 30, 60, and 120 minutes. 23 | 24 | 25 | 26 | ## References 27 | 28 | ```bibtex 29 | @misc{alcaraz2024cardiolablaboratoryvaluesestimation, 30 | title={CardioLab: Laboratory Values Estimation from Electrocardiogram Features -- An Exploratory Study}, 31 | author={Juan Miguel Lopez Alcaraz and Nils Strodthoff}, 32 | year={2024}, 33 | eprint={2407.18629}, 34 | archivePrefix={arXiv}, 35 | primaryClass={eess.SP}, 36 | url={https://arxiv.org/abs/2407.18629}, 37 | } 38 | ``` 39 | 40 | ```bibtex 41 | @misc{alcaraz2024cardiolablaboratoryvaluesmonitoring, 42 | title={CardioLab: Laboratory Values Estimation and Monitoring from Electrocardiogram Signals -- A Multimodal Deep Learning Approach}, 43 | author={Juan Miguel Lopez Alcaraz and Nils Strodthoff}, 44 | year={2024}, 45 | eprint={2411.14886}, 46 | archivePrefix={arXiv}, 47 | primaryClass={eess.SP}, 48 | url={https://arxiv.org/abs/2411.14886}, 49 | } 50 | ``` 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /reports/README.md: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /reports/abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4HealthUOL/CardioLab/7dffd82df50b3dd3cb9c8d561fb364bf58dd88f0/reports/abstract.png -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # Replicating the Experiments 2 | 3 | This repository provides scripts to preprocess ECG data, train models, and analyze results. Follow the steps below to replicate the experiments. 4 | 5 | ## 1. Download Required Datasets 6 | Before running the scripts, download the following files and place them in the root directory of this repository: 7 | 8 | 9 | 10 | - **patients.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/) 11 | - **d_labitems.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/) 12 | - **labevents.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/) 13 | - **omr.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/) 14 | 15 | - **edstays.csv.gz** from [MIMIC-IV-ED](https://physionet.org/content/mimic-iv-ed/2.2/) 16 | - **vitalsign.csv.gz** from [MIMIC-IV-ED](https://physionet.org/content/mimic-iv-ed/2.2/) 17 | 18 | - **records_w_diag_icd10.csv** from [MIMIC-IV-ECG-ICD](https://www.physionet.org/content/mimic-iv-ecg-ext-icd-labels/1.0.1/) 19 | 20 | - **record_list.csv** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/) 21 | - **machine_measurements.csv** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/) 22 | - **machine_measurements_data_dictionary.csv** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/) 23 | - **mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0.zip** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/) 24 | 25 | 26 
# ==============================================================================
# /src/README.md (continued) -- Markdown, kept as comments
# ==============================================================================
# ## 2. Run Preprocessing
# First, preprocess the data by running:
#
# ```bash
# python preprocessing.py
# ```
#
# ## 3. Train the Model
# After preprocessing is complete, train and test the model by running:
#
# ```bash
# python main_all.py --config config/config_supervised_multimodal_labvalues_s4.yaml
# ```
#
# ## 4. Output and Results
# The results, including model performance metrics and analyses, will be saved automatically in the current directory.
#
# ---
#
# For further inquiries, please open an issue.

# ==============================================================================
# /src/clinical_ts/a.md -- empty
# ==============================================================================

# ==============================================================================
# /src/clinical_ts/config.py
# ==============================================================================
from hydra.core.config_store import ConfigStore
from dataclasses import dataclass

#for base configs
from .template_modules import *

#for specific configs
from .ts.encoder import *
from .ts.head import *
from .ts.base import *

##################################################################
# try to import (optional) modules with special dependencies
S4_AVAILABLE = True
try:
    from .ts.s4 import *
except Exception as _s4_import_error:
    # Degrade gracefully, but report the cause. A bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit and hide why the import failed.
    print("WARNING: Could not import s4 module")
    print(f"  reason: {type(_s4_import_error).__name__}: {_s4_import_error}")
    S4_AVAILABLE = False

from .tabular.base import *

from .head.multimodal import *

from .loss.supervised import *

from .metric.base import *

from .task.multimodal import *

###########################################################################################################
# https://hydra.cc/docs/tutorials/structured_config/config_groups/
@dataclass
class FullConfig:
    """Top-level structured config.

    Each field corresponds to one config group whose options are registered in
    `create_default_config`; the config classes are star-imported above.
    """

    base: BaseConfig
    data: BaseConfigData
    loss: LossConfig
    metric: MetricConfig
    trainer: TrainerConfig
    task: TaskConfig

    ts: TimeSeriesEncoderConfig
    static: EncoderStaticBaseConfig
    head: HeadBaseConfig


def create_default_config():
    """Register the structured configs with Hydra's ``ConfigStore``.

    Stores ``FullConfig`` under the name ``"config"`` together with the
    selectable options of each config group (base, data, ts/*, static,
    head, loss, metric, trainer, task).

    Returns:
        The ``ConfigStore`` singleton the configs were stored in.
    """
    cs = ConfigStore.instance()
    cs.store(name="config", node=FullConfig)

    ######################################################################
    # base
    ######################################################################
    cs.store(group="base", name="base", node=BaseConfig)

    ######################################################################
    # input data
    ######################################################################
    cs.store(group="data", name="base", node=BaseConfigData)

    ######################################################################
    # time series encoder
    ######################################################################
    cs.store(group="ts", name="tsenc", node=TimeSeriesEncoderConfig)

    #ENCODER
    cs.store(group="ts/enc", name="none", node=NoEncoderConfig)

    #PREDICTOR
    cs.store(group="ts/pred", name="none", node=NoPredictorConfig)#no predictor
    if(S4_AVAILABLE):
        cs.store(group="ts/pred", name="s4", node=S4PredictorConfig)#S4 model

    #HEADS
    cs.store(group="ts/head", name="none", node=HeadBaseConfig)
    cs.store(group="ts/head", name="pool", node=PoolingHeadConfig)

    #SSL HEADS
    cs.store(group="ts/head_ssl", name="none", node=HeadBaseConfig)

    #QUANTIZER
    cs.store(group="ts/quant", name="none", node=QuantizerBaseConfig)

    #MASK
    cs.store(group="ts/mask", name="none", node=MaskingBaseConfig)

    #LOSS
    cs.store(group="ts/loss", name="none", node=SSLLossConfig)

    #PRE
    cs.store(group="ts/pre", name="none", node=PrePostBaseConfig)

    #POST (was a copy-paste duplicate of the "ts/pre" registration above)
    cs.store(group="ts/post", name="none", node=PrePostBaseConfig)

    ######################################################################
    # static encoder
    ######################################################################
    for g in ["static", "ts/static"]:
        cs.store(group=g, name="none", node=EncoderStaticBaseConfig)
        cs.store(group=g, name="mlp", node=BasicEncoderStaticMLPConfig)

    ######################################################################
    # optional multimodal head
    ######################################################################
    cs.store(group="head", name="none", node=HeadBaseConfig)
    cs.store(group="head", name="concat", node=ConcatFusionHeadConfig)

    ######################################################################
    # loss function
    ######################################################################
    #no global loss
    cs.store(group="loss", name="none", node=LossConfig)
    #supervised losses
    cs.store(group="loss", name="bce", node=BCELossConfig)
    cs.store(group="loss", name="bcef", node=BCEFLossConfig)

    ######################################################################
    # metrics
    ######################################################################
    cs.store(group="metric", name="none", node=MetricConfig)
    cs.store(group="metric", name="auroc", node=MetricAUROCConfig)
    cs.store(group="metric", name="aurocagg", node=MetricAUROCAggConfig)

    ######################################################################
    # trainer
    ######################################################################
    cs.store(group="trainer", name="trainer", node=TrainerConfig)

    ######################################################################
    # task
    ######################################################################
    cs.store(group="task", name="none", node=TaskConfig)
    cs.store(group="task", name="multi", node=TaskConfigMultimodal)

    return cs

# ==============================================================================
# /src/clinical_ts/data/a.md -- empty
# ==============================================================================

# ==============================================================================
# /src/clinical_ts/data/collate_dataset.py
# ==============================================================================
import torch.utils.data

class CollateDataset(torch.utils.data.Dataset):
    r"""Dataset for collating several existing datasets.

    Item ``idx`` is the tuple concatenation of item ``idx`` of the first
    dataset and item ``sample_idx`` of every further dataset. If
    ``sample_idx_from_first_dataset`` is True, ``sample_idx`` is the first
    element of the first dataset's item (and that element is dropped from
    the output); otherwise ``sample_idx == idx``. The length is the length
    of the first dataset.
    """

    def __init__(self, datasets, sample_idx_from_first_dataset=False) -> None:
        super().__init__()
        self.datasets = datasets
        self.sample_idx_from_first_dataset = sample_idx_from_first_dataset

    def __getitem__(self, idx):
        # Fetch the first dataset's item only once. Indexing it twice doubled
        # the I/O and, with random crops, drew two independent samples.
        first = tuple(self.datasets[0][idx])
        if(self.sample_idx_from_first_dataset):
            sample_idx = first[0]
            res = first[1:]
        else:
            sample_idx = idx
            res = first

        for d in self.datasets[1:]:
            res += tuple(d[sample_idx])
        return res

    def __len__(self):
        return len(self.datasets[0])

# ==============================================================================
# /src/clinical_ts/data/time_series_dataset.py
# ==============================================================================
__all__ = ['tsdata_seq','tsdata_seq_static','tsdata_seq_idxs','tsdata_seq_static_idxs','ConcatTimeSeriesDataset','TimeSeriesDataset','TimeSeriesDatasetConfig']

import numpy as np
import torch
import torch.utils.data
from .time_series_dataset_transforms import Compose

#Note: due to issues with the numpy rng for multiprocessing (https://github.com/pytorch/pytorch/issues/5059) that could be fixed by a custom worker_init_fn we use random throughout for convenience
import random

#Note: multiprocessing issues with python lists and dicts (https://github.com/pytorch/pytorch/issues/13246) and pandas dfs (https://github.com/pytorch/pytorch/issues/5902)
#import multiprocessing as mp

from dataclasses import dataclass
import pandas as pd
from typing import Union, Any
import pathlib


from collections import namedtuple

tsdata_seq = namedtuple("tsdata_seq",("seq","label"))
tsdata_seq_static = namedtuple("tsdata_seq_static",("seq","label","static"))
tsdata_seq_cat = namedtuple("tsdata_seq_cat",("seq","label","static_cat"))
tsdata_seq_static_cat = namedtuple("tsdata_seq_static_cat",("seq","label","static","static_cat"))

tsdata_seq_idxs = namedtuple("tsdata_seq_idxs",("seq","label","seq_idxs"))
tsdata_seq_cat_idxs = namedtuple("tsdata_seq_cat_idxs",("seq","label","static_cat","seq_idxs"))
tsdata_seq_static_idxs = namedtuple("tsdata_seq_static_idxs",("seq","label","static","seq_idxs"))
tsdata_seq_static_cat_idxs = namedtuple("tsdata_seq_static_cat_idxs",("seq","label","static","static_cat","seq_idxs"))

# dtypes accepted as integer keys into memmap/npy storage
_INT_DTYPES = (np.int16, np.int32, np.int64)
# label dtypes that can be batched directly
_NUMERIC_LABEL_DTYPES = (np.int16, np.int32, np.int64, np.float32, np.float64)


def arrays_equal_with_nans(arr1, arr2):
    '''helper function to compare arrays with nans (nan positions must coincide)'''
    return ((arr1 == arr2) | (np.isnan(arr1) & np.isnan(arr2))).all()


def _concat_columns(row):
    '''flatten one dataframe row into a list, expanding np.ndarray cells along their first axis'''
    return [item for col in row for item in (col if isinstance(col, np.ndarray) else [col])]


def _squeeze_unit_axis(arr):
    '''drop axis 1 if it has length one (e.g. a single array-valued column); otherwise return arr unchanged
    (an unconditional np.squeeze(axis=1) raised for multi-column static features)'''
    if arr.ndim > 1 and arr.shape[1] == 1:
        return np.squeeze(arr, axis=1)
    return arr


def _as_record_array(records):
    '''np.array(records), falling back to a 1-D object array for ragged (variable-length) records,
    which NumPy >= 1.24 refuses to convert implicitly'''
    try:
        return np.array(records)
    except ValueError:
        out = np.empty(len(records), dtype=object)
        for i, rec in enumerate(records):
            out[i] = rec
        return out


class ConcatTimeSeriesDataset(torch.utils.data.ConcatDataset):
    '''ConcatDataset that handles id mapping correctly (to allow to aggregate predictions)'''
    def __init__(self, datasets):
        super().__init__(datasets)
        idmaps = []
        for dataset_idx,ds in enumerate(self.datasets):
            idmap = ds.get_id_mapping()
            # shift ids by the number of items in the preceding datasets so ids never collide
            offset = self.cumulative_sizes[dataset_idx-1] if dataset_idx>0 else 0
            remap_dict = {x:j+offset for j,x in enumerate(np.unique(idmap))}
            idmaps.append(np.array([remap_dict[x] for x in idmap]))
        self.df_idx_mapping = np.concatenate(idmaps)

    def get_id_mapping(self):
        return self.df_idx_mapping

    def aggregate_predictions(self, preds,targs=None,idmap=None,aggregate_fn = np.mean,verbose=False):
        '''delegates to the first dataset's aggregate_predictions using the concatenated id mapping
        (targs is optional, as in TimeSeriesDataset.aggregate_predictions)'''
        return self.datasets[0].aggregate_predictions(preds, targs, self.df_idx_mapping if idmap is None else idmap, aggregate_fn, verbose)


class TimeSeriesDataset(torch.utils.data.Dataset):
    """timeseries dataset with partial crops."""

    def __init__(self, hparams):
        """
        accepts three kinds of input:
        1) filenames pointing to aligned numpy arrays [timesteps,channels,...] for data and either integer labels or filename pointing to numpy arrays[timesteps,...] e.g. for annotations
        2) memmap_filename to memmap file (same argument that was passed to reformat_as_memmap) for data [concatenated,...] and labels- data column in df corresponds to index in this memmap; memmap_label_filename can normally be kept as None (will use the memmap_label file in the same directory in this case)
        3) npy_data [samples,ts,...] (either path or np.array directly- also supporting variable length input) - data column in df corresponds to sampleid

        transforms: list of callables (transformations) or single instance e.g. from torchvision.transforms.Compose (applied in the specified order i.e. leftmost element first); None applies no transform

        col_lbl = None: return dummy label 0 (e.g. for unsupervised pretraining)
        cols_static: (optional) list of cols with extra static information (continuous-valued)
        cols_static_cat: (optional) list of cols with extra static information (categorical)
        fs_annotation_over_fs_data: ratio of sampling frequencies (annotation over data)
        return_idxs: returns sample_idx from the underlying dataframe and start_idx and end_idx within the sequence (for aligned sequences e.g. spectra or for certain contrastive approaches such as ts2vec)
        """
        super().__init__()
        assert not((hparams.memmap_filename is not None) and (hparams.npy_data is not None))
        # require integer entries if using memmap or npy
        assert (hparams.memmap_filename is None and hparams.npy_data is None) or (hparams.df[hparams.col_data].dtype in _INT_DTYPES)
        # keys (in column data) have to be unique
        assert(hparams.allow_multiple_keys or len(hparams.df[hparams.col_data].unique())==len(hparams.df))

        self.timeseries_df_data = np.array(hparams.df[hparams.col_data])
        if(self.timeseries_df_data.dtype not in _INT_DTYPES):
            assert(hparams.memmap_filename is None and hparams.npy_data is None) #only for filenames in mode files
            # stored as bytes to avoid mp issues; np.bytes_ replaces np.string_, which NumPy 2.0 removed
            self.timeseries_df_data = np.array(hparams.df[hparams.col_data].astype(str)).astype(np.bytes_)

        if(hparams.col_lbl is None):# use dummy labels
            self.timeseries_df_label = np.zeros(len(hparams.df))
        else: # use actual labels
            if(isinstance(hparams.df[hparams.col_lbl].iloc[0],(list,np.ndarray))):#stack arrays/lists for proper batching
                self.timeseries_df_label = np.stack(hparams.df[hparams.col_lbl])
            else: # single integers/floats
                self.timeseries_df_label = np.array(hparams.df[hparams.col_lbl])

        if(not(hparams.annotation and hparams.memmap_filename is not None)):#skip if memmap and annotation
            if(self.timeseries_df_label.dtype not in _NUMERIC_LABEL_DTYPES): #everything else cannot be batched anyway
                assert(hparams.annotation and hparams.memmap_filename is None and hparams.npy_data is None)#only for filenames in mode files
                self.timeseries_df_label = np.array(hparams.df[hparams.col_lbl].apply(lambda x:str(x))).astype(np.bytes_)

        if(hparams.cols_static is not None):
            self.timeseries_df_static = _squeeze_unit_axis(np.array(hparams.df[hparams.cols_static].apply(_concat_columns, axis=1).to_list()))
            self.static = True
        else:
            self.static = False

        if(hparams.cols_static_cat is not None):
            self.timeseries_df_static_cat = _squeeze_unit_axis(np.array(hparams.df[hparams.cols_static_cat].apply(_concat_columns, axis=1).to_list()))
            self.static_cat = True
        else:
            self.static_cat = False

        self.output_size = hparams.output_size
        self.data_folder = hparams.data_folder
        self.transforms = Compose(hparams.transforms) if isinstance(hparams.transforms,list) else hparams.transforms
        self.annotation = hparams.annotation
        self.col_lbl = hparams.col_lbl

        self.mode="files"
        self.fs_annotation_over_fs_data = hparams.fs_annotation_over_fs_data
        self.return_idxs = hparams.return_idxs

        if(hparams.memmap_filename is not None):
            memmap_filename = pathlib.Path(hparams.memmap_filename)  # also accept str paths
            self.memmap_meta_filename = memmap_filename.parent/(memmap_filename.stem+"_meta.npz")
            self.mode="memmap"
            memmap_meta = np.load(self.memmap_meta_filename, allow_pickle=True)
            self.memmap_start = memmap_meta["start"].astype(np.int64)# cast as integers to be on the safe side
            self.memmap_shape = memmap_meta["shape"].astype(np.int64)
            self.memmap_length = memmap_meta["length"].astype(np.int64)
            self.memmap_file_idx = memmap_meta["file_idx"].astype(np.int64)
            self.memmap_dtype = np.dtype(str(memmap_meta["dtype"]))
            self.memmap_filenames = np.array(memmap_meta["filenames"]).astype(np.bytes_)#save as byte to avoid issue with mp
            if(hparams.annotation):
                #by default use the memmap_label.npy in the same directory as the signal memmap file
                if(hparams.memmap_label_filename is not None):
                    memmap_label_filename = pathlib.Path(hparams.memmap_label_filename)
                else:
                    memmap_label_filename = self.memmap_meta_filename.parent/("_".join(self.memmap_meta_filename.stem.split("_")[:-1])+"_label.npy")
                self.memmap_meta_filename_label = memmap_filename.parent/(memmap_label_filename.stem+"_meta.npz")
                memmap_meta_label = np.load(self.memmap_meta_filename_label, allow_pickle=True)
                self.memmap_start_label = memmap_meta_label["start"].astype(np.int64)
                self.memmap_shape_label = memmap_meta_label["shape"].astype(np.int64)
                self.memmap_length_label = memmap_meta_label["length"].astype(np.int64)
                self.memmap_file_idx_label = memmap_meta_label["file_idx"].astype(np.int64)
                self.memmap_dtype_label = np.dtype(str(memmap_meta_label["dtype"]))
                self.memmap_filenames_label = np.array(memmap_meta_label["filenames"]).astype(np.bytes_)
        elif(hparams.npy_data is not None):
            self.mode="npy"
            if(isinstance(hparams.npy_data,(np.ndarray,list))):
                self.npy_data = _as_record_array(hparams.npy_data)
                assert(hparams.annotation is False)
            else:
                npy_path = pathlib.Path(hparams.npy_data)  # also accept str paths
                self.npy_data = np.load(npy_path, allow_pickle=True)
                if(hparams.annotation):
                    self.npy_data_label = np.load(npy_path.parent/(npy_path.stem+"_label.npy"), allow_pickle=True)

        self.random_crop = hparams.random_crop
        self.sample_items_per_record = hparams.sample_items_per_record

        self.df_idx_mapping=[]
        self.start_idx_mapping=[]
        self.end_idx_mapping=[]

        for df_idx,(_,row) in enumerate(hparams.df.iterrows()):
            if(self.mode=="files"):
                data_length = row["data_length"]
            elif(self.mode=="memmap"):
                data_length= self.memmap_length[row[hparams.col_data]]
            else: #npy
                data_length = len(self.npy_data[row[hparams.col_data]])

            if(hparams.chunk_length == 0):#do not split
                idx_start = [hparams.start_idx]
                idx_end = [data_length]
            else:
                step = hparams.chunk_length if hparams.stride is None else hparams.stride
                idx_start = list(range(hparams.start_idx,data_length,step))
                idx_end = [min(l+hparams.chunk_length, data_length) for l in idx_start]

            #remove final chunk(s) if too short (everything from the first too-short chunk on)
            for i in range(len(idx_start)):
                if(idx_end[i]-idx_start[i]< hparams.min_chunk_length):
                    del idx_start[i:]
                    del idx_end[i:]
                    break
            #append to lists
            for _ in range(hparams.copies+1):
                for i_s,i_e in zip(idx_start,idx_end):
                    self.df_idx_mapping.append(df_idx)
                    self.start_idx_mapping.append(i_s)
                    self.end_idx_mapping.append(i_e)
        #convert to np.array to avoid mp issues with python lists
        self.df_idx_mapping = np.array(self.df_idx_mapping)
        self.start_idx_mapping = np.array(self.start_idx_mapping)
        self.end_idx_mapping = np.array(self.end_idx_mapping)

    def __len__(self):
        return len(self.df_idx_mapping)

    @property
    def is_empty(self):
        return len(self.df_idx_mapping)==0

    def __getitem__(self, idx):
        '''return one cropped item (or a tuple of sample_items_per_record items) for chunk idx'''
        lst=[]
        for _ in range(self.sample_items_per_record):
            #determine crop idxs
            timesteps= self.get_sample_length(idx)

            if(self.random_crop):#random crop
                # random.randint is inclusive on both ends: valid starts are 0..timesteps-output_size
                # (the former upper bound timesteps-output_size-1 never drew the last position)
                start_idx_rel = random.randint(0, timesteps - self.output_size)
            else:
                start_idx_rel = (timesteps - self.output_size)//2
            if(self.sample_items_per_record==1):
                return self._getitem(idx,start_idx_rel)
            else:
                lst.append(self._getitem(idx,start_idx_rel))
        return tuple(lst)

    def _getitem(self, idx,start_idx_rel):
        '''low-level function that actually fetches the data for chunk idx, cropped at start_idx_rel (relative to the chunk start)'''
        df_idx = self.df_idx_mapping[idx]
        start_idx = self.start_idx_mapping[idx]
        end_idx = self.end_idx_mapping[idx]
        #determine crop idxs
        timesteps= end_idx - start_idx
        assert(timesteps>=self.output_size)
        start_idx_crop = start_idx + start_idx_rel
        end_idx_crop = start_idx_crop+self.output_size
        if(self.annotation):
            # annotations may be sampled at a different rate than the data
            start_idx_crop_label = int(np.round(start_idx_crop*self.fs_annotation_over_fs_data))
            end_idx_crop_label = start_idx_crop_label+int(np.round(self.output_size*self.fs_annotation_over_fs_data))

        #load the actual data
        if(self.mode=="files"):#from separate files
            # Path() so this works with data_folder=None (str has no .stem/...) and str data folders
            data_filename = pathlib.Path(str(self.timeseries_df_data[df_idx],encoding='utf-8'))
            if self.data_folder is not None:
                data_filename = pathlib.Path(self.data_folder)/data_filename
            data = np.load(data_filename, allow_pickle=True)[start_idx_crop:end_idx_crop] #data type has to be adjusted when saving to npy

            if(self.annotation is True):
                label_filename = pathlib.Path(str(self.timeseries_df_label[df_idx],encoding='utf-8'))
                if self.data_folder is not None:
                    label_filename = pathlib.Path(self.data_folder)/label_filename
                label = np.load(label_filename, allow_pickle=True)[start_idx_crop_label:end_idx_crop_label] #data type has to be adjusted when saving to npy
            else:
                label = self.timeseries_df_label[df_idx] #input type has to be adjusted in the dataframe
        elif(self.mode=="memmap"): #from one memmap file
            memmap_idx = self.timeseries_df_data[df_idx] #grab the actual index (Note the df to create the ds might be a subset of the original df used to create the memmap)
            memmap_file_idx = self.memmap_file_idx[memmap_idx]
            idx_offset = self.memmap_start[memmap_idx]

            mem_filename = str(self.memmap_filenames[memmap_file_idx],encoding='utf-8')
            mem_file = np.memmap(self.memmap_meta_filename.parent/mem_filename, self.memmap_dtype, mode='r', shape=tuple(self.memmap_shape[memmap_file_idx]))
            data = np.copy(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
            del mem_file
            if(self.annotation):
                memmap_file_idx_label = self.memmap_file_idx_label[memmap_idx]
                idx_offset_label = self.memmap_start_label[memmap_idx]

                mem_filename_label = str(self.memmap_filenames_label[memmap_file_idx_label],encoding='utf-8')
                # shape must come from the label file index (was indexed with the data file index)
                mem_file_label = np.memmap(self.memmap_meta_filename_label.parent/mem_filename_label, self.memmap_dtype_label, mode='r', shape=tuple(self.memmap_shape_label[memmap_file_idx_label]))

                label = np.copy(mem_file_label[idx_offset_label + start_idx_crop_label: idx_offset_label + end_idx_crop_label])
                del mem_file_label
            else:
                label = self.timeseries_df_label[df_idx]
        else:#single npy array
            record_idx = self.timeseries_df_data[df_idx]

            data = self.npy_data[record_idx][start_idx_crop:end_idx_crop]

            if(self.annotation):
                # crop labels with the label indices (was: data indices, wrong when fs ratio != 1)
                label = self.npy_data_label[record_idx][start_idx_crop_label:end_idx_crop_label]
            else:
                label = self.timeseries_df_label[df_idx]

        sample = (data, label, self.timeseries_df_static[df_idx] if self.static else None, self.timeseries_df_static_cat[df_idx] if self.static_cat else None,np.array([df_idx,start_idx_crop,end_idx_crop]))

        # consistency check: make sure that data and annotation lengths match (check here because transforms might change the shape of the annotation)
        assert(self.annotation is False or len(sample[1])==int(np.round(self.fs_annotation_over_fs_data*len(sample[0]))))
        if(self.transforms is not None):# transforms=None (the config default) means no transform
            sample = self.transforms(sample)

        if(self.return_idxs):
            if(self.static is True and self.static_cat is True):
                return tsdata_seq_static_cat_idxs(sample[0],sample[1], sample[2], sample[3], sample[4])
            elif(self.static is True):
                return tsdata_seq_static_idxs(sample[0],sample[1], sample[2], sample[4])
            elif(self.static_cat is True):
                return tsdata_seq_cat_idxs(sample[0],sample[1], sample[3], sample[4])
            else:
                return tsdata_seq_idxs(sample[0], sample[1], sample[4])
        else:
            if(self.static is True and self.static_cat is True):
                return tsdata_seq_static_cat(sample[0],sample[1], sample[2], sample[3])
            elif(self.static is True):
                return tsdata_seq_static(sample[0],sample[1], sample[2])
            elif(self.static_cat is True):
                return tsdata_seq_cat(sample[0],sample[1], sample[3])
            else:
                return tsdata_seq(sample[0], sample[1])


    def get_sampling_weights(self, class_weight_dict,length_weighting=False, timeseries_df_group_by_col=None):
        '''
        class_weight_dict: dictionary of class weights
        length_weighting: weigh samples by length
        timeseries_df_group_by_col: column of the pandas df used to create the object
        returns one float32 weight per item, normalized so that the smallest weight is 1'''
        assert(self.annotation is False)
        assert(length_weighting is False or timeseries_df_group_by_col is None)
        weights = np.zeros(len(self.df_idx_mapping),dtype=np.float32)
        length_per_class = {}
        length_per_group = {}
        spans = list(zip(self.df_idx_mapping,self.start_idx_mapping,self.end_idx_mapping))
        for iw,(i,s,e) in enumerate(spans):
            label = self.timeseries_df_label[i]
            weights[iw] = class_weight_dict[label]
            if(length_weighting):
                length_per_class[label] = length_per_class.get(label, 0) + (e-s)
            if(timeseries_df_group_by_col is not None):
                group = timeseries_df_group_by_col[i]
                length_per_group[group] = length_per_group.get(group, 0) + (e-s)

        if(length_weighting):#need second pass to properly take into account the total length per class
            for iw,(i,s,e) in enumerate(spans):
                label = self.timeseries_df_label[i]
                weights[iw]= (e-s)/length_per_class[label]*weights[iw]
        if(timeseries_df_group_by_col is not None):
            for iw,(i,s,e) in enumerate(spans):
                group = timeseries_df_group_by_col[i]
                weights[iw]= (e-s)/length_per_group[group]*weights[iw]

        weights = weights/np.min(weights)#normalize smallest weight to 1
        return weights

    def get_id_mapping(self):
        return self.df_idx_mapping

    def get_sample_id(self,idx):
        return self.df_idx_mapping[idx]

    def get_sample_length(self,idx):
        return self.end_idx_mapping[idx]-self.start_idx_mapping[idx]

    def get_sample_start(self,idx):
        return self.start_idx_mapping[idx]

    def aggregate_predictions(self, preds,targs=None,idmap=None,aggregate_fn = np.mean,verbose=False):
        '''
        aggregates potentially multiple predictions per sample (can also pass targs for convenience)
        idmap: idmap as returned by get_id_mapping (uses self.get_id_mapping() by default); lists are accepted
        preds: ordered predictions as returned by learn.get_preds(), one row per item
        aggregate_fn: function that is used to aggregate multiple predictions per sample (most commonly np.amax or np.mean); called with axis=0
        returns aggregated preds (ordered by ascending sample id), plus aggregated targs if targs was passed;
        if every id occurs only once, the inputs are returned unchanged
        '''
        idmap = self.get_id_mapping() if idmap is None else np.asarray(idmap)
        if(idmap is None or len(idmap)==len(np.unique(idmap))):
            # nothing to aggregate
            return preds if targs is None else (preds,targs)

        if(verbose):
            print("aggregating predictions...")
        # group item indices by id in one O(n log n) pass instead of one np.where scan per id;
        # the stable sort keeps the indices within each group in ascending order (as np.where did)
        order = np.argsort(idmap, kind="stable")
        _, group_starts = np.unique(idmap[order], return_index=True)
        preds_aggregated = []
        targs_aggregated = []
        for group in np.split(order, group_starts[1:]):
            preds_aggregated.append(aggregate_fn(preds[group],axis=0))
            if targs is not None:
                targs_local = targs[group]
                assert(np.all([np.array_equal(t, targs_local[0], equal_nan=True) for t in targs_local])) #all labels have to agree (including nans)
                targs_aggregated.append(targs_local[0])
        if(targs is None):
            return np.array(preds_aggregated)
        return np.array(preds_aggregated),np.array(targs_aggregated)

@dataclass
class TimeSeriesDatasetConfig:
    df:pd.DataFrame
    output_size:int
    chunk_length:int
    min_chunk_length:int
    memmap_filename:Union[str,pathlib.PosixPath,None]=None
    memmap_label_filename:Union[str,pathlib.PosixPath,None]=None
    npy_data:Union[np.ndarray,list,None]=None
    random_crop:bool=True
    data_folder:Union[str,pathlib.PosixPath,None]=None
    copies:int=0
    col_data:str="data"
    col_lbl:Union[str,None]="label"
    cols_static:Union[str,None]=None
    cols_static_cat:Union[str,None]=None
    stride:Union[int,None]=None
    start_idx:int=0
    annotation:bool=False
    transforms:Any=None
    sample_items_per_record:int=1
    fs_annotation_over_fs_data:float=1.
416 | return_idxs:bool=False 417 | allow_multiple_keys:bool=False #in the df allow multiple rows with identical IDs -------------------------------------------------------------------------------- /src/clinical_ts/data/time_series_dataset_transforms.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Compose', 'RandomCrop', 'CenterCrop', 'FixedCrop', 'GaussianNoise', 'Resample', 'ToSpectrogram', 'Flatten', 'ToTensor', 'Normalize', 'NormalizeBatch', 'ButterFilter', 'ChannelFilter', 'Transform', 'StaticTransform', 'TupleTransform', 'SequenceToSampleLabelTransform'] 2 | 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | 8 | import math 9 | import random 10 | import resampy 11 | 12 | #from skimage import transform 13 | 14 | #import warnings 15 | #warnings.filterwarnings("ignore", category=UserWarning) 16 | 17 | from scipy.signal import butter, sosfilt, sosfiltfilt 18 | from scipy import signal 19 | 20 | #from scipy.interpolate import interp1d 21 | 22 | 23 | 24 | #def nn_upsample(xin, yin, xout): 25 | # '''performs nearest neighbor upsampling of the integer array yin with values at xin for new datapoints at xout''' 26 | # f = interp1d(xin,yin, kind="nearest",bounds_error=False,fill_value="extrapolate") 27 | # return f(xout).astype(np.int64) 28 | 29 | #def resample_labels(startpts, labels, startpts_to_mid, startpts_new, startpts_to_mid_new): 30 | # '''resamples integer labels labels at starpts+startpts_to_mid to new anchor points at startpts_new+startpts_to_mid_new''' 31 | # if(isinstance(startpts_to_mid,float) or isinstance(startpts_to_mid,int)): 32 | # startpts_to_mid = np.ones_like(startpts)*startpts_to_mid 33 | # if(isinstance(startpts_to_mid_new,float) or isinstance(startpts_to_mid_new,int)): 34 | # startpts_to_mid_new = np.ones_like(startpts_new)*startpts_to_mid_new 35 | # midpts = np.array(startpts)+startpts_to_mid 36 | # midpts_new = np.array(startpts_new)+startpts_to_mid_new 37 | # return 
nn_upsample(midpts, labels, midpts_new) 38 | 39 | #https://stackoverflow.com/questions/12093594/how-to-implement-band-pass-butterworth-filter-with-scipy-signal-butter 40 | def butter_filter(lowcut=10, highcut=20, fs=50, order=5, btype='band'): 41 | '''returns butterworth filter with given specifications''' 42 | nyq = 0.5 * fs 43 | low = lowcut / nyq 44 | high = highcut / nyq 45 | 46 | sos = butter(order, [low, high] if btype=="band" else (low if btype=="low" else high), analog=False, btype=btype, output='sos') 47 | return sos 48 | 49 | #def butter_filter_frequency_response(filter): 50 | # '''returns frequency response of a given filter (result of call of butter_filter)''' 51 | # w, h = sosfreqz(filter) 52 | # #gain vs. freq(Hz) 53 | # #plt.plot((fs * 0.5 / np.pi) * w, abs(h)) 54 | # return w,h 55 | # 56 | #def apply_butter_filter(data, filter, forwardbackward=True): 57 | # '''pass filter from call of butter_filter to data (assuming time axis at dimension 0)''' 58 | # if(forwardbackward): 59 | # return sosfiltfilt(filter, data, axis=0) 60 | # else: 61 | # data = sosfilt(filter, data, axis=0) 62 | 63 | class Compose: 64 | '''composes several transformations into a single one (as provided by torchvision.transforms.Compose)''' 65 | def __init__(self, transforms): 66 | self.transforms = transforms 67 | 68 | def __call__(self, inp): 69 | for t in self.transforms: 70 | inp = t(inp) 71 | return inp 72 | 73 | 74 | class RandomCrop(object): 75 | """Crop randomly from a sample. 
76 | """ 77 | 78 | def __init__(self, output_size,annotation=False): 79 | self.output_size = output_size 80 | self.annotation = annotation 81 | 82 | def __call__(self, sample): 83 | data, label, static, static_cat, idxs = sample 84 | 85 | timesteps= len(data) 86 | assert(timesteps>=self.output_size) 87 | if(timesteps==self.output_size): 88 | start=0 89 | else: 90 | start = random.randint(0, timesteps - self.output_size-1) #np.random.randint(0, timesteps - self.output_size) 91 | idxs[1]+=start 92 | idxs[2]=idxs[1]+self.output_size 93 | 94 | data = data[start: start + self.output_size] 95 | if(self.annotation): 96 | label = label[start: start + self.output_size] 97 | 98 | return (data, label, static, static_cat, idxs) 99 | 100 | 101 | class CenterCrop(object): 102 | """Center crop from a sample. 103 | """ 104 | 105 | def __init__(self, output_size, annotation=False): 106 | self.output_size = output_size 107 | self.annotation = annotation 108 | 109 | def __call__(self, sample): 110 | data, label, static, static_cat, idxs = sample 111 | 112 | timesteps= len(data) 113 | start = (timesteps - self.output_size)//2 114 | idxs[1] += start 115 | idxs[2] = idxs[1] + self.output_size 116 | 117 | data = data[start: start + self.output_size] 118 | if(self.annotation): 119 | label = label[start: start + self.output_size] 120 | 121 | return (data, label, static, static_cat, idxs) 122 | 123 | class FixedCrop(object): 124 | """Take a fixed crop from a sample (for aligned data e.g. spectrograms). 
start_idx and end_idx are relative to the start of the respective sample 125 | """ 126 | 127 | def __init__(self, start_idx, end_idx, annotation=False): 128 | self.start_idx = start_idx 129 | self.end_idx = end_idx 130 | self.output_size = self.end_idx-self.start_idx 131 | self.annotation = annotation 132 | 133 | def __call__(self, sample): 134 | data, label, static, static_cat, idxs = sample 135 | assert(self.end_idx [x,y,]ch, seq 256 | return torch.from_numpy(np.moveaxis(data,0,-1)) 257 | else: 258 | return torch.from_numpy(data) 259 | else:#default_collate will take care of it 260 | return data 261 | 262 | data, label, static, static_cat, idxs = sample 263 | if not isinstance(data,tuple): 264 | data = _to_tensor(data,self.transpose_data) 265 | else: 266 | data = tuple(_to_tensor(x,self.transpose_data) for x in data) 267 | 268 | if not isinstance(label,tuple): 269 | label = _to_tensor(label,self.transpose_label) 270 | else: 271 | label = tuple(_to_tensor(x,self.transpose_label) for x in label) 272 | 273 | if not isinstance(static,tuple): 274 | static = _to_tensor(static) 275 | else: 276 | static = tuple(_to_tensor(x) for x in static) 277 | 278 | if not isinstance(static_cat,tuple): 279 | static_cat = _to_tensor(static_cat) 280 | else: 281 | static_cat = tuple(_to_tensor(x) for x in static_cat) 282 | 283 | if not isinstance(idxs,tuple): 284 | idxs = _to_tensor(idxs) 285 | else: 286 | idxs = tuple(_to_tensor(x) for x in idxs) 287 | 288 | return (data, label, static, static_cat, idxs) #returning as a tuple (potentially of lists) 289 | 290 | 291 | class Normalize(object): 292 | """Normalize using given stats. 
293 | """ 294 | def __init__(self, stats_mean, stats_std, input=True, channels=[]): 295 | self.stats_mean=stats_mean.astype(np.float32) if stats_mean is not None else None 296 | self.stats_std=stats_std.astype(np.float32)+1e-8 if stats_std is not None else None 297 | self.input = input 298 | if(len(channels)>0): 299 | for i in range(len(stats_mean)): 300 | if(not(i in channels)): 301 | self.stats_mean[:,i]=0 302 | self.stats_std[:,i]=1 303 | 304 | def __call__(self, sample): 305 | datax, labelx, static, static_cat, idxs = sample 306 | data = datax if self.input else labelx 307 | #assuming channel last 308 | if(self.stats_mean is not None): 309 | data = data - self.stats_mean 310 | if(self.stats_std is not None): 311 | data = data/self.stats_std 312 | 313 | if(self.input): 314 | return (data, labelx, static, static_cat, idxs) 315 | else: 316 | return (datax, data, static, static_cat, idxs) 317 | 318 | 319 | class NormalizeBatch(object): 320 | """Normalize using batch statistics. 321 | axis: tuple of integers of axis numbers to be normalized over (by default everything but the last) 322 | """ 323 | def __init__(self, input=True, channels=[],axis=None): 324 | self.channels = channels 325 | self.channels_keep = None 326 | self.input = input 327 | self.axis = axis 328 | 329 | def __call__(self, sample): 330 | datax, labelx, static, static_cat, idxs = sample 331 | data = datax if self.input else labelx 332 | #assuming channel last 333 | #batch_mean = np.mean(data,axis=tuple(range(0,len(data)-1))) 334 | #batch_std = np.std(data,axis=tuple(range(0,len(data)-1)))+1e-8 335 | batch_mean = np.mean(data,axis=self.axis if self.axis is not None else tuple(range(0,len(data.shape)-1))) 336 | batch_std = np.std(data,axis=self.axis if self.axis is not None else tuple(range(0,len(data.shape)-1)))+1e-8 337 | 338 | if(len(self.channels)>0): 339 | if(self.channels_keep is None): 340 | self.channels_keep = np.setdiff(range(data.shape[-1]),self.channels) 341 | 342 | 
batch_mean[self.channels_keep]=0 343 | batch_std[self.channels_keep]=1 344 | 345 | data = (data - batch_mean)/batch_std 346 | 347 | if(self.input): 348 | return (data, labelx, static, static_cat, idxs) 349 | else: 350 | return (datax, data, static, static_cat, idxs) 351 | 352 | 353 | class ButterFilter(object): 354 | """Apply filter 355 | """ 356 | 357 | def __init__(self, lowcut=50, highcut=50, fs=100, order=5, btype='band', forwardbackward=True, input=True): 358 | self.filter = butter_filter(lowcut,highcut,fs,order,btype) 359 | self.input = input 360 | self.forwardbackward = forwardbackward 361 | 362 | def __call__(self, sample): 363 | datax, labelx, static, static_cat, idxs = sample 364 | data = datax if self.input else labelx 365 | 366 | if(self.forwardbackward): 367 | data = sosfiltfilt(self.filter, data, axis=0) 368 | else: 369 | data = sosfilt(self.filter, data, axis=0) 370 | 371 | if(self.input): 372 | return (data, labelx, static, static_cat, idxs) 373 | else: 374 | return (datax, data, static, static_cat, idxs) 375 | 376 | 377 | class ChannelFilter(object): 378 | """Select certain channels. 379 | axis: axis index of the channel axis 380 | """ 381 | 382 | def __init__(self, channels=[0], axis=-1, input=True): 383 | self.channels = channels 384 | self.input = input 385 | self.axis = axis 386 | 387 | def __call__(self, sample): 388 | data, label, static, static_cat, idxs = sample 389 | if(self.input): 390 | return (np.take(data,self.channels,axis=self.axis), label, static, static_cat, idxs) #(data[...,self.channels], label, static) 391 | else: 392 | return (data, np.take(label,self.channels,axis=self.axis), static, static_cat, idxs) 393 | 394 | 395 | class Transform(object): 396 | """Transforms data using a given function i.e. 
data_new = func(data) for input is True else label_new = func(label) 397 | """ 398 | 399 | def __init__(self, func, input=False): 400 | self.func = func 401 | self.input = input 402 | 403 | def __call__(self, sample): 404 | data, label, static, static_cat, idxs = sample 405 | if(self.input): 406 | return (self.func(data), label, static, static_cat, idxs) 407 | else: 408 | return (data, self.func(label), static, static_cat, idxs) 409 | 410 | class StaticTransform(object): 411 | """Transforms static data using a given function i.e. data_new = func(data) for input is True else label_new = func(label) 412 | """ 413 | def __init__(self, func): 414 | self.func = func 415 | 416 | def __call__(self, sample): 417 | data, label, static, static_cat, idxs = sample 418 | static, static_cat = self.func(static, static_cat) 419 | return (data, label, static, static_cat, idxs) 420 | 421 | class TupleTransform(object): 422 | """Transforms data using a given function (operating on both data and label and return a tuple) i.e. data_new, label_new = func(data_old, label_old) 423 | """ 424 | 425 | def __init__(self, func): 426 | self.func = func 427 | 428 | def __call__(self, sample): 429 | data,label,static, static_cat,idxs = sample 430 | data2, label2 = self.func(data,label,static,static_cat) 431 | return (data2, label2, static, static_cat, idxs) 432 | 433 | class SequenceToSampleLabelTransform(object): 434 | """Transforms sequence-level to sample-level labels 435 | majority vote: pick the most frequent label as segment label (i.e. 
suitable for single-label classification) 436 | num_classes: number of output classes 437 | binary: binary instead of floating point outputs (where the latter represent fractions) 438 | epoch_length: split the original sequence in ts//epoch_length fragments 439 | """ 440 | 441 | def __init__(self, majority_vote=False, num_classes=2, binary=False,epoch_length=0): 442 | self.majority_vote = majority_vote 443 | self.num_classes = num_classes 444 | self.binary = binary 445 | self.epoch_length = epoch_length 446 | 447 | def __call__(self, sample): 448 | data, label, static, static_cat, idxs = sample 449 | 450 | epoch_length = self.epoch_length if self.epoch_length>0 else len(label) 451 | if(len(label.shape)==1):#each time step is single-labeled 452 | label = np.eye(self.num_classes)[label] #now label has shape ts,num_classes 453 | cnts = np.sum(label.reshape((-1,epoch_length,label.shape[-1])),axis=1)#segments,classes 454 | if(self.majority_vote): 455 | label = np.argmax(cnts,axis=-1)#segments 456 | else: 457 | if(self.binary): 458 | label=(cnts>0).astype(np.float32) 459 | else: 460 | label = (cnts/epoch_length).astype(np.float32) 461 | if(self.epoch_length>0): 462 | return (data, label, static, static_cat, idxs) 463 | else:#just one segment 464 | return (data, label[0], static, static_cat, idxs) 465 | -------------------------------------------------------------------------------- /src/clinical_ts/data/time_series_dataset_utils.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 'save_dataset', 'load_dataset','dataset_add_chunk_col', 'dataset_add_length_col', 'dataset_add_labels_col', 'dataset_add_mean_col', 'dataset_add_median_col', 'dataset_add_std_col', 'dataset_add_iqr_col', 'dataset_get_stats', 'append_to_memmap', 'append_to_df_memmap', 'npys_to_memmap_batched', 'npys_to_memmap', 'reformat_as_memmap'] 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from pathlib import Path 7 | from scipy.stats import iqr 8 
| 9 | 10 | #workaround for windows pickles 11 | from sys import platform 12 | import pathlib 13 | if platform == "linux" or platform == "linux2": 14 | pathlib.WindowsPath = pathlib.PosixPath 15 | 16 | try: 17 | import pickle5 as pickle 18 | except ImportError as e: 19 | import pickle 20 | 21 | from tqdm.auto import tqdm 22 | 23 | 24 | 25 | 26 | def save_dataset(df,lbl_itos,mean,std,target_root,filename_postfix="",protocol=4): 27 | target_root = Path(target_root) 28 | df.to_pickle(target_root/("df"+filename_postfix+".pkl"), protocol=protocol) 29 | 30 | if(isinstance(lbl_itos,dict)):#dict as pickle 31 | outfile = open(target_root/("lbl_itos"+filename_postfix+".pkl"), "wb") 32 | pickle.dump(lbl_itos, outfile, protocol=protocol) 33 | outfile.close() 34 | else:#array 35 | np.save(target_root/("lbl_itos"+filename_postfix+".npy"),lbl_itos) 36 | 37 | np.save(target_root/("mean"+filename_postfix+".npy"),mean) 38 | np.save(target_root/("std"+filename_postfix+".npy"),std) 39 | 40 | def load_dataset(target_root,filename_postfix="",df_mapped=True): 41 | target_root = Path(target_root) 42 | 43 | if(df_mapped): 44 | df = pd.read_pickle(target_root/("df_memmap"+filename_postfix+".pkl")) 45 | else: 46 | df = pd.read_pickle(target_root/("df"+filename_postfix+".pkl")) 47 | 48 | 49 | if((target_root/("lbl_itos"+filename_postfix+".pkl")).exists()):#dict as pickle 50 | infile = open(target_root/("lbl_itos"+filename_postfix+".pkl"), "rb") 51 | lbl_itos=pickle.load(infile) 52 | infile.close() 53 | else:#array 54 | lbl_itos = np.load(target_root/("lbl_itos"+filename_postfix+".npy")) 55 | 56 | 57 | mean = np.load(target_root/("mean"+filename_postfix+".npy")) 58 | std = np.load(target_root/("std"+filename_postfix+".npy")) 59 | return df, lbl_itos, mean, std 60 | 61 | 62 | def dataset_add_chunk_col(df, col="data"): 63 | '''add a chunk column to the dataset df''' 64 | df["chunk"]=df.groupby(col).cumcount() 65 | 66 | def dataset_add_length_col(df, col="data", data_folder=None): 67 | '''add a 
length column to the dataset df''' 68 | df[col+"_length"]=df[col].apply(lambda x: len(np.load(x if data_folder is None else data_folder/x, allow_pickle=True))) 69 | 70 | def dataset_add_labels_col(df, col="label", data_folder=None): 71 | '''add a column with unique labels in column col''' 72 | df[col+"_labels"]=df[col].apply(lambda x: list(np.unique(np.load(x if data_folder is None else data_folder/x, allow_pickle=True)))) 73 | 74 | def dataset_add_mean_col(df, col="data", axis=(0), data_folder=None): 75 | '''adds a column with mean''' 76 | df[col+"_mean"]=df[col].apply(lambda x: np.mean(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) 77 | 78 | def dataset_add_median_col(df, col="data", axis=(0), data_folder=None): 79 | '''adds a column with median''' 80 | df[col+"_median"]=df[col].apply(lambda x: np.median(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) 81 | 82 | def dataset_add_std_col(df, col="data", axis=(0), data_folder=None): 83 | '''adds a column with mean''' 84 | df[col+"_std"]=df[col].apply(lambda x: np.std(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) 85 | 86 | def dataset_add_iqr_col(df, col="data", axis=(0), data_folder=None): 87 | '''adds a column with mean''' 88 | df[col+"_iqr"]=df[col].apply(lambda x: iqr(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) 89 | 90 | def dataset_get_stats(df, col="data", simple=True): 91 | '''creates (weighted) means and stds from mean, std and length cols of the df''' 92 | if(simple): 93 | return df[col+"_mean"].mean(), df[col+"_std"].mean() 94 | else: 95 | #https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html 96 | #or https://gist.github.com/thomasbrandon/ad5b1218fc573c10ea4e1f0c63658469 97 | def combine_two_means_vars(x1,x2): 98 | (mean1,var1,n1) = x1 99 | (mean2,var2,n2) = x2 100 | mean = mean1*n1/(n1+n2)+ mean2*n2/(n1+n2) 101 | var = var1*n1/(n1+n2)+ 
var2*n2/(n1+n2)+n1*n2/(n1+n2)/(n1+n2)*np.power(mean1-mean2,2) 102 | return (mean, var, (n1+n2)) 103 | 104 | def combine_all_means_vars(means,vars,lengths): 105 | inputs = list(zip(means,vars,lengths)) 106 | result = inputs[0] 107 | 108 | for inputs2 in inputs[1:]: 109 | result= combine_two_means_vars(result,inputs2) 110 | return result 111 | 112 | means = list(df[col+"_mean"]) 113 | vars = np.power(list(df[col+"_std"]),2) 114 | lengths = list(df[col+"_length"]) 115 | mean,var,length = combine_all_means_vars(means,vars,lengths) 116 | return mean, np.sqrt(var) 117 | 118 | 119 | def npys_to_memmap_batched(npys, target_filename, max_len=0, delete_npys=True, batched_npy=False, batch_length=900000): 120 | ''' 121 | analogous to npys_to_memmap but processes batches of files before flushing them into memmap for faster processing 122 | ''' 123 | memmap = None 124 | start = np.array([0])#start_idx in current memmap file (always already the next start- delete last token in the end) 125 | length = []#length of segment 126 | filenames= []#memmap files 127 | file_idx=[]#corresponding memmap file for sample 128 | shape=[]#shapes of all memmap files 129 | 130 | data = [] 131 | data_lengths=[] 132 | dtype = None 133 | 134 | target_filename = Path(target_filename) 135 | 136 | for idx,npy in tqdm(enumerate(npys),total=len(npys)): 137 | data_batched = np.load(npy, allow_pickle=True) 138 | 139 | for data_tmp in (tqdm(data_batched,leave=False) if batched_npy else [data_batched]): 140 | data.append(data_tmp) 141 | data_lengths.append(len(data[-1])) 142 | if(idx==len(npys)-1 or np.sum(data_lengths)>batch_length):#flush 143 | data = np.concatenate(data,axis=0)#concatenate along time axis (still axis 0 at this stage) 144 | if(memmap is None or (max_len>0 and start[-1]>max_len)):#new memmap file has to be created 145 | if(max_len>0): 146 | filenames.append(target_filename.parent/(target_filename.stem+"_"+str(len(filenames))+".npy")) 147 | else: 148 | filenames.append(target_filename) 149 | 
150 | shape.append([np.sum(data_lengths)]+[l for l in data.shape[1:]])#insert present shape 151 | 152 | if(memmap is not None):#an existing memmap exceeded max_len 153 | del memmap 154 | #create new memmap 155 | start[-1] = 0 156 | start = np.concatenate([start,np.cumsum(data_lengths)]) 157 | length = np.concatenate([length,data_lengths]) 158 | 159 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape) 160 | else: 161 | #append to existing memmap 162 | start = np.concatenate([start,start[-1]+np.cumsum(data_lengths)]) 163 | length = np.concatenate([length,data_lengths]) 164 | shape[-1] = [start[-1]]+[l for l in data.shape[1:]] 165 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple(shape[-1])) 166 | 167 | #store mapping memmap_id to memmap_file_id 168 | file_idx=np.concatenate([file_idx,[(len(filenames)-1)]*len(data_lengths)]) 169 | #insert the actual data 170 | memmap[start[-len(data_lengths)-1]:start[-len(data_lengths)-1]+len(data)]=data[:] 171 | memmap.flush() 172 | dtype = data.dtype 173 | data = []#reset data storage 174 | data_lengths = [] 175 | 176 | start= start[:-1]#remove the last element 177 | #cleanup 178 | for npy in npys: 179 | if(delete_npys is True): 180 | npy.unlink() 181 | del memmap 182 | 183 | #convert everything to relative paths 184 | filenames= [f.name for f in filenames] 185 | #save metadata 186 | np.savez(target_filename.parent/(target_filename.stem+"_meta.npz"),start=start,length=length,shape=shape,file_idx=file_idx,dtype=dtype,filenames=filenames) 187 | 188 | def append_to_memmap(memmap1, memmap2, file_id1=0, file_id2=0): 189 | ''' 190 | appends the contents of memmap2(file_id2) to memmap1(file_id1 in case of split files) 191 | ''' 192 | memmap1 = Path(memmap1) 193 | memmap2 = Path(memmap2) 194 | 195 | meta1 = np.load(memmap1.parent/(memmap1.stem+"_meta.npz"),allow_pickle=True) 196 | meta2 = np.load(memmap2.parent/(memmap2.stem+"_meta.npz"),allow_pickle=True) 197 | meta1_shape = 
meta1["shape"][file_id1] 198 | meta2_shape = meta2["shape"][file_id2] 199 | assert((len(meta1_shape)==1 and len(meta2_shape)==1) or meta1_shape[1:]==meta2_shape[1:])#shapes have to match up to length 200 | assert(meta1["dtype"]==meta2["dtype"])#dtypes have to agree 201 | mask = np.where(meta2["file_idx"]==file_id2)[0] 202 | lengths2 = np.array(meta2["length"])[mask] 203 | shape = np.concatenate(([meta1_shape[0]+np.sum(lengths2)],meta1_shape[1:])).astype(np.int64) 204 | full_shape=[(shape if i==file_id1 else m) for i,m in enumerate(meta1["shape"])] 205 | starts2 = meta1_shape[0]+np.concatenate(([0],np.cumsum(lengths2)[:-1])) 206 | start = np.concatenate((meta1["start"],starts2)) 207 | length = np.concatenate((meta1["length"],lengths2)) 208 | file_idx= np.concatenate((meta1["file_idx"],np.array([file_id1]*len(mask)))) 209 | print("Appending",memmap2,"to",memmap1,"...") 210 | memmap_extended = np.memmap(memmap1.parent/(meta1["filenames"][file_id1]), dtype=np.dtype(str(meta1["dtype"])), mode='r+', shape=tuple(shape)) 211 | memmap_source = np.memmap(memmap2.parent/(meta2["filenames"][file_id2]), dtype=np.dtype(str(meta2["dtype"])), mode="r", shape=tuple(meta2_shape)) 212 | memmap_extended[meta1_shape[0]:] = memmap_source 213 | memmap_extended.flush() 214 | 215 | np.savez(memmap1.parent/(memmap1.stem+"_meta.npz"),start=start,length=length,shape=full_shape,file_idx=file_idx,dtype=meta1["dtype"],filenames=meta1["filenames"]) 216 | print("done.") 217 | 218 | def append_to_df_memmap(path_df_memmap1,path_df_memmap2,path_memmap1,path_memmap2,file_id1=0,file_id2=0,col_data="data"): 219 | df_memmap1 = pd.read_pickle(path_df_memmap1).sort_values(by=[col_data]) 220 | df_memmap2 = pd.read_pickle(path_df_memmap2).sort_values(by=[col_data]) 221 | path_memmap1 = Path(path_memmap1) 222 | path_memmap2 = Path(path_memmap2) 223 | 224 | meta1 = np.load(path_memmap1.parent/(path_memmap1.stem+"_meta.npz"),allow_pickle=True) 225 | meta2 = 
np.load(path_memmap2.parent/(path_memmap2.stem+"_meta.npz"),allow_pickle=True) 226 | df_memmap2 = df_memmap2.iloc[np.where(meta2["file_idx"]==file_id2)].copy() 227 | file_idx1 = meta1["file_idx"] 228 | assert(len(df_memmap1)==len(file_idx1))# apply append_to_df_memmap before append_to_memmap 229 | data_idx_start = np.max(np.where(file_idx1<=file_id1)[0])+1 230 | df_memmap1.loc[df_memmap1[col_data]>=data_idx_start,col_data]=df_memmap1.loc[df_memmap1[col_data]>=data_idx_start,col_data]+len(df_memmap2) 231 | 232 | df_memmap2[col_data]=df_memmap2[col_data]-df_memmap2[col_data].min()+data_idx_start 233 | df_memmap1 = pd.concat((df_memmap1,df_memmap2)) 234 | return df_memmap1 235 | 236 | def npys_to_memmap(npys, target_filename, max_len=0, delete_npys=True, batched_npy=False): 237 | ''' 238 | converts list of filenames pointing to npy files into a memmap file with target_filename 239 | max_len: restricts filesize per memmap file (0 no restriction) 240 | delete_npys: deletes original npys after processing to save space 241 | batched_npy: assumes first axis in the npy file enumerates samples (otherwise just a single sample per npy file) 242 | ''' 243 | memmap = None 244 | start = []#start_idx in current memmap file 245 | length = []#length of segment 246 | filenames= []#memmap files 247 | file_idx=[]#corresponding memmap file for sample 248 | shape=[] 249 | 250 | target_filename = Path(target_filename) 251 | 252 | for _,npy in tqdm(enumerate(npys),total=len(npys)): 253 | data_batched = np.load(npy, allow_pickle=True) 254 | for data in tqdm(data_batched,leave=False) if batched_npy else[data_batched]: 255 | if(memmap is None or (max_len>0 and start[-1]+length[-1]>max_len)): 256 | if(max_len>0): 257 | filenames.append(target_filename.parent/(target_filename.stem+"_"+str(len(filenames))+".npy")) 258 | else: 259 | filenames.append(target_filename) 260 | 261 | if(memmap is not None):#an existing memmap exceeded max_len 262 | shape.append([start[-1]+length[-1]]+[l for l in 
data.shape[1:]]) 263 | del memmap 264 | #create new memmap 265 | start.append(0) 266 | length.append(data.shape[0]) 267 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape) 268 | else: 269 | #append to existing memmap 270 | start.append(start[-1]+length[-1]) 271 | length.append(data.shape[0]) 272 | memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple([start[-1]+length[-1]]+[l for l in data.shape[1:]])) 273 | 274 | #store mapping memmap_id to memmap_file_id 275 | file_idx.append(len(filenames)-1) 276 | #insert the actual data 277 | memmap[start[-1]:start[-1]+length[-1]]=data[:] 278 | memmap.flush() 279 | if(delete_npys is True): 280 | npy.unlink() 281 | del memmap 282 | 283 | #append final shape if necessary 284 | if(len(shape)0 else hparams_input_shape.channels)+hparams_input_shape.static_dim+hparams_input_shape.static_dim_cat 16 | self.linear = nn.Linear(input_size,target_dim) 17 | self.output_shape = dataclasses.replace(hparams_input_shape) 18 | self.output_shape.channels = target_dim 19 | self.output_shape.length = 0 20 | self.output_shape.static_dim = 0 21 | self.output_shape.static_dim_cat = 0 22 | 23 | def forward(self, **kwargs): 24 | static = kwargs["static"] 25 | seq = kwargs["seq"] 26 | return {"seq": self.linear(torch.cat((seq.view(seq.shape[0],-1),static),dim=1))} 27 | 28 | def get_output_shape(self): 29 | return self.output_shape 30 | 31 | @dataclass 32 | class ConcatFusionHeadConfig(HeadBaseConfig): 33 | _target_:str = "clinical_ts.head.multimodal.ConcatFusionHead" 34 | multi_prediction:bool=False 35 | -------------------------------------------------------------------------------- /src/clinical_ts/loss/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/loss/supervised.py: -------------------------------------------------------------------------------- 1 | 
__all__ = ['SupervisedLossConfig', 'BCELossConfig', 'BinaryCrossEntropyLoss', 'BinaryCrossEntropyFocalLoss', 'BCEFLossConfig']

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from dataclasses import dataclass, field
from typing import List

from ..template_modules import LossConfig

####################################################################################
# BASIC supervised losses
###################################################################################
@dataclass
class SupervisedLossConfig(LossConfig):
    _target_:str = "" #insert appropriate loss class
    loss_type:str ="supervised"
    supervised_type:str="classification_single"#"classification_multi","regression_quantile"

@dataclass
class BCELossConfig(SupervisedLossConfig):
    _target_:str= "clinical_ts.loss.supervised.BinaryCrossEntropyLoss"
    loss_type:str="supervised"
    supervised_type:str="classification_multi"
    pos_weight:List[float]=field(default_factory=list)#per-class positive weights e.g. inverse class prevalences (empty: no weighting)
    ignore_nans:bool=False #ignore nan targets - requires a separate BCE for each label

class BinaryCrossEntropyFocalLoss(nn.Module):
    """
    Focal binary cross-entropy loss for (multi-label) binary classification with targets 0/1.

    The element-wise BCE-with-logits loss is multiplied by (1-p_t)**gamma, where p_t is the
    predicted probability of the target value. With gamma=0 this reduces to plain BCE.

    hparams_loss attributes read here:
        gamma: focusing parameter (defaults to 0 if the config has no gamma field)
        pos_weight: list of per-class positive weights; an empty list disables weighting
        ignore_nans: if True, NaN entries in the targets are excluded from the loss
    """
    def __init__(self, hparams_loss):
        super().__init__()
        self.gamma = getattr(hparams_loss, "gamma", 0.)

        self.ignore_nans = hparams_loss.ignore_nans
        self.pos_weight_set = len(hparams_loss.pos_weight)>0

        if(not self.ignore_nans):
            pos_weight = torch.from_numpy(np.array(hparams_loss.pos_weight,dtype=np.float32)) if self.pos_weight_set else None
            self.bce = torch.nn.BCEWithLogitsLoss(reduction="none",pos_weight=pos_weight)
        elif(self.pos_weight_set):
            # one loss module per class, so that each class can be masked independently;
            # iterate over the weights themselves (pos_weight_set is only a bool)
            self.bce = torch.nn.ModuleList([torch.nn.BCEWithLogitsLoss(reduction="none",pos_weight=torch.from_numpy(np.array([w],dtype=np.float32))) for w in hparams_loss.pos_weight])
        else:
            self.bce = torch.nn.BCEWithLogitsLoss(reduction="none")

    def _focal_modulation(self, preds, targs):
        '''returns (1-p_t)**gamma with p_t the sigmoid probability assigned to the target value'''
        probs = torch.sigmoid(preds)
        p_t = probs * targs + (1 - probs) * (1 - targs)
        return torch.pow((1 - p_t), self.gamma)

    def forward(self, preds, targs):
        '''
        preds: logits; targs: targets of the same shape
        (the ignore_nans branch indexes preds[:,i], i.e. assumes (batch, classes) - TODO confirm for other shapes)
        Without ignore_nans: loss summed over the last axis and averaged over the rest.
        With ignore_nans: per class, mean over the non-NaN targets, then summed over classes.
        '''
        if(not self.ignore_nans):
            focal_modulation = self._focal_modulation(preds, targs)
            return (focal_modulation * self.bce(input=preds, target=targs.float())).sum(-1).mean()

        losses = []
        for i in range(preds.size(1)):
            maski = ~torch.isnan(targs[:,i])
            predsi = preds[:,i][maski]
            targsi = targs[:,i][maski].float()
            if(len(predsi)==0):#no valid target for this class in the batch
                continue
            bcei = self.bce[i] if self.pos_weight_set else self.bce
            losses.append(torch.mean(self._focal_modulation(predsi,targsi)*bcei(predsi,targsi)))

        if(len(losses)>0):
            return torch.sum(torch.stack(losses))
        # no valid target at all: return a zero tensor that stays attached to the graph
        # (a bare python float cannot be backpropagated or handled like the regular loss tensor)
        return preds.sum()*0.

class BinaryCrossEntropyLoss(BinaryCrossEntropyFocalLoss):
    """
    Plain (multi-label) binary cross-entropy loss, the target class of BCELossConfig.

    Implemented as the gamma=0 special case of BinaryCrossEntropyFocalLoss, where the focal
    modulation (1-p_t)**0 equals 1. Supports the same pos_weight/ignore_nans options and the
    same reduction; a gamma entry in the config, if present, is ignored.
    """
    def __init__(self, hparams_loss):
        super().__init__(hparams_loss)
        self.gamma = 0.
74 | 75 | @dataclass 76 | class BCEFLossConfig(BCELossConfig): 77 | _target_:str= "clinical_ts.loss.supervised.BinaryCrossEntropyFocalLoss" 78 | gamma:float=2. -------------------------------------------------------------------------------- /src/clinical_ts/metric/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/metric/base.py: -------------------------------------------------------------------------------- 1 | __all__ = ['MetricConfig', 'MetricBase', 'MetricAUROC', 'MetricAUROCConfig', 'MetricAUROCAggConfig'] 2 | 3 | import numpy as np 4 | 5 | from dataclasses import dataclass 6 | 7 | from ..utils.eval_utils_cafa import multiclass_roc_curve 8 | from ..utils.bootstrap_utils import empirical_bootstrap 9 | 10 | import warnings 11 | from sklearn.exceptions import UndefinedMetricWarning 12 | 13 | # Filter out the warnings due to not enough positive/negative samples during bootstrapping 14 | warnings.filterwarnings('ignore', category=UndefinedMetricWarning) 15 | 16 | @dataclass 17 | class MetricConfig: 18 | _target_:str = "" 19 | 20 | name:str = ""#name of the metric e.g. auroc 21 | 22 | aggregation:str = "" #"" means no aggregation across segments of the same sequence, other options: "mean", "max" 23 | 24 | key_summary_metric:str = "" #key into the output dict that can serve as summary metric for early stopping etc e.g. 
(without key_prefix and key_postfix and aggregation type) 25 | mode_summary_metric:str ="max" #used to determine if key_summary_metric is supposed to be maximized or minimized 26 | 27 | verbose:str = "" # comma-separated list of keys to be printed after metric evaluation (without key_prefix and key_postfix and aggregation type) 28 | 29 | bootstrap_report_nans:bool = False #report nans during bootstrapping (due to not enough labels of a certain type in certain bootstrap iterations etc) 30 | bootstrap_iterations:int = 0 #0: no bootstrap 31 | bootstrap_alpha:float= 0.95 # bootstrap alpha 32 | 33 | def _reformat_lbl_itos(k): 34 | # return re.sub(r'(?0 else "")+hparams_metric.name+"_" 40 | self.key_postfix = key_postfix 41 | self.aggregation = hparams_metric.aggregation 42 | self.aggregation_txt = ("_agg" if hparams_metric.aggregation=="mean" else "_agg"+hparams_metric.aggregation) if hparams_metric.aggregation!="" else "" 43 | self.key_summary_metric = self.key_prefix+hparams_metric.key_summary_metric+self.aggregation_txt+self.key_postfix #data loader id added by default 44 | self.mode_summary_metric = hparams_metric.mode_summary_metric 45 | self.verbose = [x for x in hparams_metric.verbose.split(",") if x!=""] 46 | 47 | self.bootstrap_iterations = hparams_metric.bootstrap_iterations if test else 0 #disable bootstrap during training 48 | self.bootstrap_alpha = hparams_metric.bootstrap_alpha 49 | self.bootstrap_report_nans = hparams_metric.bootstrap_report_nans 50 | 51 | self.lbl_itos = [_reformat_lbl_itos(l) for l in lbl_itos] 52 | self.keys = self.get_keys(self.lbl_itos) 53 | 54 | 55 | def get_keys(self, lbl_itos): 56 | '''returns metrics keys in the order they will later be returned by _eval''' 57 | raise NotImplementedError 58 | 59 | def __call__(self,targs,preds): 60 | 61 | if(self.bootstrap_iterations==0): 62 | point = self._eval(targs,preds) 63 | else: 64 | point,low,high,nans = empirical_bootstrap((targs,preds), self._eval, n_iterations=self.bootstrap_iterations 
, alpha=self.bootstrap_alpha,ignore_nans=True)#score_fn_kwargs={"classes":self.lbl_itos} 65 | res = {self.key_prefix+k+self.aggregation_txt+self.key_postfix:v for v,k in zip(point,self.keys)} 66 | if(self.bootstrap_iterations>0): 67 | res_low = {self.key_prefix+k+self.aggregation_txt+self.key_postfix+"_low":v for v,k in zip(low,self.keys)} 68 | res_high = {self.key_prefix+k+self.aggregation_txt+self.key_postfix+"_high":v for v,k in zip(high,self.keys)} 69 | res_nans = {self.key_prefix+k+self.aggregation_txt+self.key_postfix+"_nans":v for v,k in zip(nans,self.keys)} 70 | res.update(res_low) 71 | res.update(res_high) 72 | if(self.bootstrap_report_nans): 73 | res.update(res_nans) 74 | 75 | if(len(self.verbose)>0): 76 | for k in self.verbose: 77 | print("\n"+self.key_prefix+k+self.aggregation_txt+self.key_postfix+":"+str(res[self.key_prefix+k+self.aggregation_txt+self.key_postfix])) 78 | 79 | return res 80 | 81 | def _eval(self,targs,preds): 82 | # should return an array of results ordered according to the entries returned by get_keys() 83 | raise NotImplementedError 84 | 85 | 86 | 87 | class MetricAUROC(MetricBase): 88 | '''provides class-wise+macro+micro AUROC/AUPR scores''' 89 | def __init__(self, hparams_metric, lbl_itos, key_prefix="", key_postfix="0", test=True): 90 | super().__init__(hparams_metric, lbl_itos=lbl_itos, key_prefix=key_prefix, key_postfix=key_postfix, test=test) 91 | self.precision_recall = hparams_metric.precision_recall 92 | 93 | def get_keys(self, lbl_itos): 94 | return list(lbl_itos)+["micro","macro"] 95 | 96 | def _eval(self,targs,preds): 97 | if(self.precision_recall): 98 | _,_,res = multiclass_roc_curve(targs,preds,classes=self.lbl_itos,precision_recall=True) 99 | return np.array(list(res.values())) 100 | else: 101 | _,_,res = multiclass_roc_curve(targs,preds,classes=self.lbl_itos) 102 | return np.array(list(res.values())) 103 | 104 | 105 | @dataclass 106 | class MetricAUROCConfig(MetricConfig): 107 | _target_:str = 
"clinical_ts.metric.base.MetricAUROC" 108 | key_summary_metric:str = "macro" 109 | verbose:str="macro" #by default print out macro auc 110 | precision_recall:bool = False #calculate the area under the precision recall curve instead of the ROC curve 111 | name:str = "auroc" 112 | bootstrap_report_nans:bool = True #by default report number of bootstrap iterations where the score was nan (due to insufficient number of labels etc) 113 | 114 | #shorthand for mean aggregation 115 | @dataclass 116 | class MetricAUROCAggConfig(MetricAUROCConfig): 117 | aggregation:str="mean" 118 | -------------------------------------------------------------------------------- /src/clinical_ts/tabular/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/tabular/base.py: -------------------------------------------------------------------------------- 1 | __all__ = ['BasicEncoderStatic', 'BasicEncoderStaticConfig', 'BasicEncoderStaticMLP', 'BasicEncoderStaticMLPConfig'] 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | 7 | import dataclasses 8 | from dataclasses import dataclass, field 9 | from typing import List 10 | 11 | from ..template_modules import EncoderStaticBase, EncoderStaticBaseConfig 12 | from collections.abc import Iterable 13 | from ..ts.basic_conv1d_modules.basic_conv1d import bn_drop_lin 14 | 15 | class BasicEncoderStatic(EncoderStaticBase): 16 | def __init__(self, hparams_encoder_static, hparams_input_shape, target_dim=None): 17 | super().__init__(hparams_encoder_static, hparams_input_shape, target_dim) 18 | self.input_channels_cat = hparams_input_shape.static_dim_cat 19 | self.input_channels_cont = hparams_input_shape.static_dim 20 | assert(len(hparams_encoder_static.embedding_dims)==hparams_input_shape.static_dim_cat and len(hparams_encoder_static.vocab_sizes)==hparams_input_shape.static_dim_cat) 21 | 
self.embeddings = nn.ModuleList() if hparams_input_shape.static_dim_cat is not None else None 22 | for v,e in zip(hparams_encoder_static.vocab_sizes,hparams_encoder_static.embedding_dims): 23 | self.embeddings.append(nn.Embedding(v,e)) 24 | self.input_dim = int(np.sum(hparams_encoder_static.embedding_dims) + hparams_input_shape.static_dim) 25 | self.input_channels = hparams_input_shape.static_dim + hparams_input_shape.static_dim_cat 26 | 27 | 28 | def embed(self, **kwargs): 29 | static = kwargs["static"] if "static" in kwargs.keys() else None 30 | static_cat = kwargs["static_cat"] if "static_cat" in kwargs.keys() else None 31 | 32 | res = [] 33 | if(static_cat is not None): 34 | for i,e in enumerate(self.embeddings): 35 | res.append(e(static_cat[:,i].long())) 36 | if(static is not None and static_cat is not None): 37 | res = torch.cat([torch.cat(res,dim=1),static],dim=1) 38 | else: 39 | res = torch.cat(res,dim=1) 40 | else: 41 | res = static 42 | 43 | return res 44 | 45 | def forward(self, **kwargs): 46 | raise NotImplementedError 47 | 48 | def get_output_shape(self): 49 | raise NotImplementedError 50 | 51 | @dataclass 52 | class BasicEncoderStaticConfig(EncoderStaticBaseConfig): 53 | _target_:str = "clinical_ts.tabular.base.BasicEncoderStatic" 54 | embedding_dims:List[int] = field(default_factory=lambda: []) #list with embedding dimensions 55 | vocab_sizes:List[int] = field(default_factory=lambda: []) #list with vocab sizes (space-separated) 56 | 57 | class BasicEncoderStaticMLP(BasicEncoderStatic): 58 | def __init__(self, hparams_encoder_static, hparams_input_shape, target_dim=None): 59 | super().__init__(hparams_encoder_static, hparams_input_shape, target_dim) 60 | 61 | lin_ftrs = [self.input_dim] + list(hparams_encoder_static.lin_ftrs) 62 | if(target_dim is not None and lin_ftrs[-1]!=target_dim): 63 | lin_ftrs.append(target_dim) 64 | ps = [hparams_encoder_static.dropout] if not isinstance(hparams_encoder_static.dropout, Iterable) else 
hparams_encoder_static.dropout 65 | if len(ps)==1: 66 | ps= [ps[0]/2] * (len(lin_ftrs)-2) + ps 67 | actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None] 68 | layers = [] 69 | for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns): 70 | layers+=bn_drop_lin(ni,no,hparams_encoder_static.batch_norm,p,actn,layer_norm=False) 71 | self.layers=nn.Sequential(*layers) 72 | 73 | self.output_shape = dataclasses.replace(hparams_input_shape) 74 | self.output_shape.static_dim = int(lin_ftrs[-1]) 75 | self.output_shape.static_dim_cat = 0 76 | 77 | def forward(self, **kwargs): 78 | res = self.embed(**kwargs) 79 | return {"static": self.layers(res)} 80 | 81 | def get_output_shape(self): 82 | return self.output_shape 83 | 84 | 85 | @dataclass 86 | class BasicEncoderStaticMLPConfig(BasicEncoderStaticConfig): 87 | _target_:str = "clinical_ts.tabular.base.BasicEncoderStaticMLP" 88 | lin_ftrs:List[int] = field(default_factory=lambda: [512]) #list with MLP hidden layer sizes; last entry is the static encoder output dimension in case target_dim is not specified 89 | dropout:float = 0.5 90 | batch_norm:bool = True 91 | -------------------------------------------------------------------------------- /src/clinical_ts/task/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/task/multimodal.py: -------------------------------------------------------------------------------- 1 | from ..template_model import SSLModel 2 | from ..template_modules import TaskConfig 3 | from ..data.time_series_dataset_utils import load_dataset 4 | from ..data.time_series_dataset_transforms import Transform 5 | from pathlib import Path 6 | from dataclasses import dataclass 7 | import numpy as np 8 | import pandas as pd 9 | import warnings 10 | 11 | #disable performance warning 12 | warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning) 13 | 14 
| class MultimodalModel(SSLModel): 15 | '''class for multimodal tasks''' 16 | 17 | def preprocess_dataset(self,dataset_kwargs): 18 | df_mapped, lbl_itos, mean, std = load_dataset(Path(dataset_kwargs.path)) 19 | 20 | if(self.hparams.loss.loss_type=="supervised" and dataset_kwargs.name.startswith("mimic_labvalues")): 21 | 22 | #reformat race 23 | lbl_itos_ethnicity=['race_asian', 'race_black', 'race_hispanic', 'race_other', 'race_white'] 24 | 25 | df_mapped["race"]=df_mapped.apply(lambda row: np.where([row[c] for c in lbl_itos_ethnicity])[0],axis=1) 26 | df_mapped["race"]=df_mapped["race"].apply(lambda x:x[0] if len(x)==1 else -1) 27 | df_mapped["race_nan"]=(df_mapped["race"]==-1) 28 | if(self.hparams.task.impute_nans): 29 | df_mapped["race"] = df_mapped["race"].replace(-1, df_mapped[(df_mapped["race"] != -1)&(df_mapped.strat_fold<18)]["race"].median())# median imput missing race 30 | df_mapped.drop(lbl_itos_ethnicity,axis=1,inplace=True) 31 | 32 | #prepare cat and cont features 33 | cat_features = ['gender','race'] 34 | cat_features_m =['race_nan', 'temperature_nan', 'o2sat_nan', 'weight_nan', 'bmi_nan', 'sbp_nan', 'dbp_nan', 'resprate_nan', 'heartrate_nan', 'height_nan'] 35 | 36 | cont_features = ['resprate', 'o2sat', 'anchor_age', 'dbp', 'heartrate', 'sbp', 'temperature', 'height', 'bmi', 'weight']#'rr_interval', 'qrs_onset', 'qrs_axis', 'p_axis', 't_axis', 'p_end', 'p_onset', 'qrs_end', 't_end' 37 | 38 | if(self.hparams.task.impute_nans): 39 | input_cols = cat_features + cont_features 40 | #grab training set medians and identify columns with nans 41 | df_train= df_mapped[df_mapped.strat_fold<18] 42 | train_medians= df_train[input_cols].median().to_dict() 43 | train_nans = [l for l,c in df_train[input_cols].isna().sum().to_dict().items() if c>0] 44 | 45 | #impute nans through medians introduce additional column _nan (if desired) 46 | for c in train_nans: 47 | if(self.hparams.task.introduce_nan_columns): 48 | df_mapped[c+"_nan"]=0 49 | 
df_mapped.loc[df_mapped[c].isna(),c+"_nan"]=1 50 | df_mapped.loc[df_mapped[c].isna(),c]=train_medians[c] 51 | 52 | #defragment 53 | df_mapped = df_mapped.copy() 54 | 55 | if(self.hparams.task.introduce_nan_columns): 56 | cat_features += cat_features_m 57 | 58 | cat_features_dim = [len(df_mapped[c].unique()) for c in cat_features] 59 | 60 | df_mapped["cat_features"]=df_mapped[cat_features].values.tolist() 61 | df_mapped.drop(cat_features,axis=1,inplace=True) 62 | df_mapped["cont_features"]=df_mapped[cont_features].values.tolist() 63 | df_mapped.drop(cont_features,axis=1,inplace=True) 64 | 65 | def replace_nan(arr): 66 | return np.where(np.array(arr)==-999.0, np.nan, np.array(arr)) 67 | df_mapped["label"]=df_mapped["aggregated_label"].apply(lambda arr:replace_nan(arr)) 68 | df_mapped.drop("aggregated_label",axis=1,inplace=True) 69 | 70 | return df_mapped, lbl_itos, mean, std 71 | 72 | 73 | @dataclass 74 | class TaskConfigMultimodal(TaskConfig): 75 | mainclass:str= "clinical_ts.task.multimodal.MultimodalModel" 76 | impute_nans:bool = True #impute nans or leave it to the model to handle it 77 | introduce_nan_columns:bool = False #impute using train set median and introduce an additional column that states if imputation occurred 78 | nan_columns_as_cat:bool = False #treat nan columns as categorical variables (instead of continuous) 79 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/base.py: -------------------------------------------------------------------------------- 1 | __all__ = ['NoPredictor', 'NoPredictorConfig', 'CNNPredictor', 'CNNPredictorConfig'] 2 | 3 | import dataclasses 4 | from dataclasses import dataclass, field 5 | from typing import List 6 | 7 | import torch.nn as nn 8 | import numpy as np 9 | from 
.basic_conv1d_modules.basic_conv1d import _conv1d 10 | 11 | from ..template_modules import PredictorBase, PredictorBaseConfig 12 | 13 | class NoPredictor(PredictorBase): 14 | def __init__(self, hparams, hparams_input_shape): 15 | ''' 16 | no predictor e.g. for pretraining purposes 17 | ''' 18 | super().__init__(hparams, hparams_input_shape) 19 | 20 | def forward(self, **kwargs): 21 | return {} 22 | 23 | 24 | @dataclass 25 | class NoPredictorConfig(PredictorBaseConfig): 26 | _target_:str = "clinical_ts.ts.base.NoPredictor" 27 | 28 | 29 | class CNNPredictor(PredictorBase): 30 | def __init__(self, hparams_encoder, hparams_input_shape): 31 | '''this is a reduced version of the RNNEncoder''' 32 | super().__init__(hparams_encoder, hparams_input_shape) 33 | assert(not hparams_input_shape.sequence_last) 34 | assert(len(hparams_encoder.strides)==len(hparams_encoder.kss) and len(hparams_encoder.strides)==len(hparams_encoder.features) and len(hparams_encoder.strides)==len(hparams_encoder.dilations)) 35 | lst = [] 36 | for i,(s,k,f,d) in enumerate(zip(hparams_encoder.strides,hparams_encoder.kss,hparams_encoder.features,hparams_encoder.dilations)): 37 | lst.append(_conv1d(hparams_input_shape.channels if i==0 else hparams_encoder.features[i-1],f,kernel_size=k,stride=s,dilation=d,bn=hparams_encoder.normalization,layer_norm=hparams_encoder.layer_norm)) 38 | 39 | self.layers = nn.Sequential(*lst) 40 | self.downsampling_factor = np.prod(hparams_encoder.strides) 41 | 42 | self.output_dim = hparams_encoder.features[-1] 43 | 44 | self.output_shape = dataclasses.replace(hparams_input_shape) 45 | self.output_shape.channels = self.output_dim 46 | self.output_shape.length = int(hparams_input_shape.length//self.downsampling_factor+ (1 if hparams_input_shape.length%self.downsampling_factor>0 else 0)) 47 | 48 | def get_output_shape(self): 49 | return self.output_shape 50 | 51 | def forward(self, **kwargs): 52 | seq = kwargs["seq"].transpose(1,2) 53 | return {"seq": 
self.layers(seq).transpose(1,2)}#bs,seq,feat 54 | 55 | @dataclass 56 | class CNNPredictorConfig(PredictorBaseConfig): 57 | _target_:str = "clinical_ts.ts.base.CNNPredictor" 58 | 59 | strides:List[int]=field(default_factory=lambda: [1,1,1,1]) #help="encoder strides (space-separated)") 60 | kss:List[int]=field(default_factory=lambda: [1,1,1,1]) #help="encoder kernel sizes (space-separated)") 61 | features:List[int]=field(default_factory=lambda: [512,512,512,512]) #help="encoder features (space-separated)") 62 | dilations:List[int]=field(default_factory=lambda: [1,1,1,1]) #help="encoder dilations (space-separated)") 63 | normalization:bool=True #help="disable encoder batch/layer normalization") 64 | layer_norm:bool=False#", action="store_true", help="encoder layer normalization") 65 | 66 | 67 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/basic_conv1d_modules/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/basic_conv1d_modules/basic_conv1d.py: -------------------------------------------------------------------------------- 1 | __all__ = ['AdaptiveConcatPool1d', 'SqueezeExcite1d', 2 | 'weight_init', 'create_head1d', 'basic_conv1d', 'fcn', 'fcn_wang', 'schirrmeister', 'sen', 'basic1d'] 3 | 4 | # Cell 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import math 9 | 10 | from typing import Iterable 11 | 12 | class Flatten(nn.Module): 13 | "Flatten `x` to a single dimension, often used at the end of a model. 
`full` for rank-1 tensor" 14 | def __init__(self, full:bool=False): 15 | super().__init__() 16 | self.full = full 17 | def forward(self, x): return x.view(-1) if self.full else x.view(x.size(0), -1) 18 | 19 | 20 | 21 | def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None, layer_norm=False, permute=False): 22 | ''' 23 | Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`. 24 | permute for input of the form B,Seq,Feat" 25 | ''' 26 | layers=[] 27 | if(permute): 28 | layers.append(LambdaLayer(lambda x: x.permute(0,2,1))) 29 | if(bn): 30 | if(layer_norm is False): 31 | layers.append(nn.BatchNorm1d(n_in)) 32 | else: 33 | layers.append(nn.LayerNorm(n_in)) 34 | if(permute): 35 | layers.append(LambdaLayer(lambda x: x.permute(0,2,1))) 36 | if p != 0: layers.append(nn.Dropout(p)) 37 | layers.append(nn.Linear(n_in, n_out)) 38 | if actn is not None: layers.append(actn) 39 | return layers 40 | 41 | # Cell 42 | 43 | class LambdaLayer(nn.Module): 44 | def __init__(self, lambd): 45 | super(LambdaLayer, self).__init__() 46 | self.lambd = lambd 47 | def forward(self, x): 48 | return self.lambd(x) 49 | 50 | def _conv1d(in_planes,out_planes,kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0, layer_norm=False): 51 | lst=[] 52 | if(drop_p>0): 53 | lst.append(nn.Dropout(drop_p)) 54 | lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, dilation=dilation, bias=not(bn))) 55 | if(bn): 56 | if(layer_norm): 57 | lst.append(LambdaLayer(lambda x: x.transpose(1,2))) 58 | lst.append(nn.LayerNorm(out_planes)) 59 | lst.append(LambdaLayer(lambda x: x.transpose(1,2))) 60 | else: 61 | lst.append(nn.BatchNorm1d(out_planes)) 62 | if(act=="relu"): 63 | lst.append(nn.ReLU(True)) 64 | if(act=="elu"): 65 | lst.append(nn.ELU(True)) 66 | if(act=="prelu"): 67 | lst.append(nn.PReLU(True)) 68 | if(act=="gelu"): 69 | lst.append(nn.GELU()) 70 | return nn.Sequential(*lst) 71 | 72 | def 
_fc(in_planes,out_planes, act="relu", bn=True): 73 | lst = [nn.Linear(in_planes, out_planes, bias=not(bn))] 74 | if(bn): 75 | lst.append(nn.BatchNorm1d(out_planes)) 76 | if(act=="relu"): 77 | lst.append(nn.ReLU(True)) 78 | if(act=="elu"): 79 | lst.append(nn.ELU(True)) 80 | if(act=="prelu"): 81 | lst.append(nn.PReLU(True)) 82 | return nn.Sequential(*lst) 83 | 84 | class AdaptiveConcatPool1d(nn.Module): 85 | "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`." 86 | def __init__(self, sz=None): 87 | "Output will be 2*sz or 2 if sz is None" 88 | super().__init__() 89 | sz = sz or 1 90 | self.ap,self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz) 91 | def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) 92 | 93 | # Cell 94 | class SqueezeExcite1d(nn.Module): 95 | '''squeeze excite block as used for example in LSTM FCN''' 96 | def __init__(self,channels,reduction=16): 97 | super().__init__() 98 | channels_reduced = channels//reduction 99 | self.w1 = torch.nn.Parameter(torch.randn(channels_reduced,channels).unsqueeze(0)) 100 | self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0)) 101 | 102 | def forward(self, x): 103 | #input is bs,ch,seq 104 | z=torch.mean(x,dim=2,keepdim=True)#bs,ch 105 | intermed = F.relu(torch.matmul(self.w1,z))#(1,ch_red,ch * bs,ch,1) = (bs, ch_red, 1) 106 | s=F.sigmoid(torch.matmul(self.w2,intermed))#(1,ch,ch_red * bs, ch_red, 1=bs, ch, 1 107 | return s*x #bs,ch,seq * bs, ch,1 = bs,ch,seq 108 | 109 | # Cell 110 | def weight_init(m): 111 | '''call weight initialization for model n via n.appy(weight_init)''' 112 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): 113 | nn.init.kaiming_normal_(m.weight) 114 | if m.bias is not None: 115 | nn.init.zeros_(m.bias) 116 | if isinstance(m, nn.BatchNorm1d): 117 | nn.init.constant_(m.weight,1) 118 | nn.init.constant_(m.bias,0) 119 | if isinstance(m,SqueezeExcite1d): 120 | stdv1=math.sqrt(2./m.w1.size[0]) 121 | 
nn.init.normal_(m.w1,0.,stdv1) 122 | stdv2=math.sqrt(1./m.w2.size[1]) 123 | nn.init.normal_(m.w2,0.,stdv2) 124 | 125 | # Cell 126 | def create_head1d(nf, nc, lin_ftrs=None, ps=0.5, bn:bool=True, act="relu", concat_pooling=True): 127 | "Model head that takes `nf` features, runs through `lin_ftrs`, and about `nc` classes; added bn and act here" 128 | lin_ftrs = [2*nf if concat_pooling else nf, nc] if lin_ftrs is None else [2*nf if concat_pooling else nf] + lin_ftrs + [nc] #was [nf, 512,nc] 129 | ps = [ps] if not isinstance(ps,Iterable) else ps 130 | if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps 131 | actns = [nn.ReLU(inplace=True) if act=="relu" else nn.ELU(inplace=True)] * (len(lin_ftrs)-2) + [None] 132 | layers = [AdaptiveConcatPool1d() if concat_pooling else nn.AdaptiveAvgPool1d(1), Flatten()] 133 | for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns): 134 | layers += bn_drop_lin(ni,no,bn,p,actn) 135 | return nn.Sequential(*layers) 136 | 137 | # Cell 138 | class basic_conv1d(nn.Sequential): 139 | '''basic conv1d''' 140 | def __init__(self, filters=[128,128,128,128],kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,split_first_layer=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): 141 | layers = [] 142 | if(isinstance(kernel_size,int)): 143 | kernel_size = [kernel_size]*len(filters) 144 | for i in range(len(filters)): 145 | layers_tmp = [] 146 | 147 | layers_tmp.append(_conv1d(input_channels if i==0 else filters[i-1],filters[i],kernel_size=kernel_size[i],stride=(1 if (split_first_layer is True and i==0) else stride),dilation=dilation,act="none" if ((headless is True and i==len(filters)-1) or (split_first_layer is True and i==0)) else act, bn=False if (headless is True and i==len(filters)-1) else bn,drop_p=(0. 
if i==0 else drop_p))) 148 | if((split_first_layer is True and i==0)): 149 | layers_tmp.append(_conv1d(filters[0],filters[0],kernel_size=1,stride=1,act=act, bn=bn,drop_p=0.)) 150 | #layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn))) 151 | #layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn)) 152 | if(pool>0 and i0): 155 | layers_tmp.append(SqueezeExcite1d(filters[i],squeeze_excite_reduction)) 156 | layers.append(nn.Sequential(*layers_tmp)) 157 | 158 | #head 159 | #layers.append(nn.AdaptiveAvgPool1d(1)) 160 | #layers.append(nn.Linear(filters[-1],num_classes)) 161 | #head #inplace=True leads to a runtime error see ReLU+ dropout https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5 162 | self.headless = headless 163 | if(headless is True): 164 | head = nn.Sequential(nn.AdaptiveAvgPool1d(1),Flatten()) 165 | else: 166 | head=create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling) 167 | layers.append(head) 168 | 169 | super().__init__(*layers) 170 | 171 | def get_layer_groups(self): 172 | return (self[2],self[-1]) 173 | 174 | def get_output_layer(self): 175 | if self.headless is False: 176 | return self[-1][-1] 177 | else: 178 | return None 179 | 180 | def set_output_layer(self,x): 181 | if self.headless is False: 182 | self[-1][-1] = x 183 | 184 | 185 | # Cell 186 | def fcn(filters=[128]*5,num_classes=2,input_channels=8,**kwargs): 187 | filters_in = filters + [num_classes] 188 | return basic_conv1d(filters=filters_in,kernel_size=3,stride=1,pool=2,pool_stride=2,input_channels=input_channels,act="relu",bn=True,headless=True) 189 | 190 | def fcn_wang(num_classes=2,input_channels=8,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs): 191 | return basic_conv1d(filters=[128,256,128],kernel_size=[8,5,3],stride=1,pool=0,pool_stride=2, 
num_classes=num_classes,input_channels=input_channels,act="relu",bn=True,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling) 192 | 193 | def schirrmeister(num_classes=2,input_channels=8,kernel_size=10,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs): 194 | return basic_conv1d(filters=[25,50,100,200],kernel_size=kernel_size, stride=3, pool=3, pool_stride=1, num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, headless=False,split_first_layer=True,drop_p=0.5,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling) 195 | 196 | def sen(filters=[128]*5,num_classes=2,input_channels=8,kernel_size=3,squeeze_excite_reduction=16,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs): 197 | return basic_conv1d(filters=filters,kernel_size=kernel_size,stride=2,pool=0,pool_stride=0,input_channels=input_channels,act="relu",bn=True,num_classes=num_classes,squeeze_excite_reduction=squeeze_excite_reduction,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling) 198 | 199 | def basic1d(filters=[128]*5,kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs): 200 | return basic_conv1d(filters=filters,kernel_size=kernel_size, stride=stride, dilation=dilation, pool=pool, pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction, num_classes=num_classes, input_channels=input_channels, act=act, bn=bn, 
headless=headless,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling) 201 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/encoder.py: -------------------------------------------------------------------------------- 1 | __all__ = ['NoEncoder', 'NoEncoderConfig'] 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | 7 | import dataclasses 8 | from dataclasses import dataclass, field 9 | from typing import List 10 | 11 | from ..template_modules import EncoderBase, EncoderBaseConfig 12 | from .basic_conv1d_modules.basic_conv1d import _conv1d 13 | from .transformer_modules.transformer import TransformerConvStemTokenizer 14 | 15 | class NoEncoder(EncoderBase): 16 | def __init__(self, hparams_encoder, hparams_input_shape): 17 | ''' 18 | no encoder- flattens by default if multiple channels are passed 19 | ''' 20 | super().__init__(hparams_encoder, hparams_input_shape) 21 | self.timesteps_per_token = hparams_encoder.timesteps_per_token 22 | self.sequence_last = hparams_input_shape.sequence_last 23 | self.input_channels = hparams_input_shape.channels if hparams_input_shape.channels2==0 else hparams_input_shape.channels*hparams_input_shape.channels2 24 | 25 | self.output_shape = dataclasses.replace(hparams_input_shape) 26 | self.output_shape.channels = self.input_channels*self.timesteps_per_token 27 | self.output_shape.channels2 = 0 28 | self.output_shape.length = hparams_input_shape.length//self.timesteps_per_token 29 | self.output_shape.sequence_last = False 30 | 31 | def forward(self, **kwargs): 32 | seq = kwargs["seq"] #bs,channels,freq,seq 33 | if(not self.sequence_last): 34 | seq = torch.movedim(seq,1,-1) 35 | if(len(seq.size())==4):#spectrogram input 36 | seq = seq.view(seq.size(0),-1,seq.size(-1))#flatten 37 | 38 | if(self.timesteps_per_token==1): 39 | return {"seq": seq.transpose(1,2)} 40 | 
else: 41 | assert(seq.size(-1)%self.timesteps_per_token==0) 42 | size = seq.size() 43 | return {"seq": seq.view(size[0],-1,seq.shape[-1]).transpose(1,2).reshape(size[0],size[2]//self.timesteps_per_token,-1).transpose(1,2)} 44 | 45 | def get_output_shape(self): 46 | return self.output_shape 47 | 48 | 49 | @dataclass 50 | class NoEncoderConfig(EncoderBaseConfig): 51 | _target_:str = "clinical_ts.ts.encoder.NoEncoder" 52 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/head.py: -------------------------------------------------------------------------------- 1 | __all__ = ['PoolingHead', 'PoolingHeadConfig'] 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from ..template_modules import HeadBase, HeadBaseConfig, _string_to_class 8 | import dataclasses 9 | from dataclasses import dataclass, field 10 | from typing import List 11 | 12 | class PoolingHead(HeadBase): 13 | def __init__(self, hparams_head, hparams_input_shape, target_dim): 14 | super().__init__(hparams_head, hparams_input_shape, target_dim) 15 | #assert(target_dim is None or hparams_head.output_layer is True) 16 | if(target_dim is not None and hparams_head.output_layer is False): 17 | print("Warning: target_dim",target_dim,"is passed to PoolingHead but output_layer is False. 
target_dim will be ignored.") 18 | self.local_pool = hparams_head.multi_prediction 19 | self.output_dim = hparams_input_shape.channels if not hparams_head.output_layer else target_dim 20 | 21 | if(self.local_pool):#local pool 22 | self.local_pool_padding = (hparams_head.local_pool_kernel_size-1)//2 23 | self.local_pool_kernel_size = hparams_head.local_pool_kernel_size 24 | self.local_pool_stride = hparams_head.local_pool_kernel_size if hparams_head.local_pool_stride==0 else hparams_head.local_pool_stride 25 | if(hparams_head.local_pool_max): 26 | self.pool = torch.nn.MaxPool1d(kernel_size=hparams_head.local_pool_kernel_size,stride=hparams_head.local_pool_stride if hparams_head.local_pool_stride!=0 else hparams_head.local_pool_kernel_size,padding=(hparams_head.local_pool_kernel_size-1)//2) 27 | else: 28 | self.pool = torch.nn.AvgPool1d(kernel_size=hparams_head.local_pool_kernel_size,stride=hparams_head.local_pool_stride if hparams_head.local_pool_stride!=0 else hparams_head.local_pool_kernel_size,padding=(hparams_head.local_pool_kernel_size-1)//2) 29 | else:#global pool 30 | if(hparams_head.local_pool_max): 31 | self.pool = torch.nn.AdaptiveMaxPool1d(1) 32 | else: 33 | self.pool = torch.nn.AdaptiveAvgPool1d(1) 34 | self.linear = nn.Linear(hparams_input_shape.channels, target_dim) if hparams_head.output_layer else nn.Identity() 35 | 36 | self.output_shape = dataclasses.replace(hparams_input_shape) 37 | self.output_shape.channels = self.output_dim 38 | #assert(hparams.predictor._target_!="clinical_ts.ts.transformer.TransformerPredictor" or (hparams.predictor.cls_token is True or (hparams.predictor.cls_token is False and (hparams_head.head_pooling_type!="cls" and hparams_head.head_pooling_type!="meanmax-cls")))) 39 | 40 | self.output_shape.length = int(np.floor((hparams_input_shape.length + 2*self.local_pool_padding- self.local_pool_kernel_size)/self.local_pool_stride+1)) if self.local_pool else 0 41 | 42 | def forward(self, **kwargs): 43 | seq = kwargs["seq"] 44 | 
#input has shape B,S,E 45 | seq = seq.transpose(1,2) 46 | seq = self.pool(seq) 47 | return {"seq": self.linear(seq.transpose(1,2))}#return B,S,E 48 | 49 | def get_output_shape(self): 50 | return self.output_shape 51 | 52 | @dataclass 53 | class PoolingHeadConfig(HeadBaseConfig): 54 | _target_:str = "clinical_ts.ts.head.PoolingHead" 55 | 56 | multi_prediction:bool = False #local pool vs. global pool 57 | local_pool_max:bool = False #max pool vs. avg pool 58 | local_pool_kernel_size: int = 0 #kernel size for local pooling 59 | local_pool_stride: int = 0 #kernel_size if 0 60 | #local_pool_padding=(kernel_size-1)//2 61 | output_layer: bool = False 62 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/s4.py: -------------------------------------------------------------------------------- 1 | __all__ = ['S4Predictor','S4PredictorConfig'] 2 | 3 | from .s4_modules.s4_model import S4Model 4 | from ..template_modules import PredictorBase, PredictorBaseConfig 5 | from typing import Any 6 | from dataclasses import dataclass 7 | 8 | class S4Predictor(PredictorBase): 9 | def __init__(self, hparams_predictor, hparams_input_shape): 10 | super().__init__(hparams_predictor, hparams_input_shape) 11 | self.predictor = S4Model( 12 | d_input = hparams_input_shape.channels if hparams_input_shape.channels!=hparams_predictor.model_dim else None,#modified 13 | d_output = None, 14 | d_state = hparams_predictor.state_dim, 15 | d_model = hparams_predictor.model_dim, 16 | n_layers = hparams_predictor.layers, 17 | dropout = hparams_predictor.dropout, 18 | tie_dropout = hparams_predictor.tie_dropout, 19 | prenorm = hparams_predictor.prenorm, 20 | l_max = hparams_input_shape.length, 21 | transposed_input = False, 22 | bidirectional=not(hparams_predictor.causal), 23 | layer_norm=not(hparams_predictor.batchnorm), 24 | pooling = False, 25 | backbone = hparams_predictor.backbone) #note: only apply linear layer before if feature dimensions do not 
match 26 | 27 | def forward(self, **kwargs): 28 | return {"seq": self.predictor(kwargs["seq"])} 29 | 30 | @dataclass 31 | class S4PredictorConfig(PredictorBaseConfig): 32 | _target_:str = "clinical_ts.ts.s4.S4Predictor" 33 | model_dim:int = 512 34 | causal: bool = True #use bidirectional predictor 35 | state_dim:int = 64 #help="S4: N") 36 | layers:int = 4 37 | dropout:float=0.2 38 | tie_dropout:bool=True 39 | prenorm:bool=False 40 | batchnorm:bool=False 41 | backbone:str="s42" #help="s4original/s4new/s4d") -------------------------------------------------------------------------------- /src/clinical_ts/ts/s4_modules/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/s4_modules/s4_model.py: -------------------------------------------------------------------------------- 1 | __all__ = ['S4Model'] 2 | #adapted from https://github.com/HazyResearch/state-spaces/blob/main/example.py 3 | 4 | import torch.nn as nn 5 | 6 | from .s42 import S4 as S42 7 | 8 | from .s4_utils import DropoutNd 9 | 10 | class S4Model(nn.Module): 11 | 12 | def __init__( 13 | self, 14 | d_input, # None to disable encoder 15 | d_output, # None to disable decoder 16 | d_state=64, #MODIFIED: N 17 | d_model=512, #MODIFIED: H 18 | n_layers=4, 19 | dropout=0.2, 20 | tie_dropout=False, #MODIFIED 21 | prenorm=False, 22 | l_max=1024, 23 | transposed_input=True, # behaves like 1d CNN if True else like a RNN with batch_first=True 24 | bidirectional=True, #MODIFIED 25 | layer_norm = True, # MODIFIED 26 | pooling = True, # MODIFIED 27 | backbone= "s42" # MODIFIED 28 | ): 29 | super().__init__() 30 | 31 | self.prenorm = prenorm 32 | 33 | # Linear encoder (d_input = 1 for grayscale and 3 for RGB) 34 | self.transposed_input = transposed_input 35 | 36 | # MODIFIED TO ALLOW FOR MODELS WITHOUT ENCODER 37 | if(d_input is None): 38 | self.encoder = nn.Identity() 39 | else: 40 
| self.encoder = nn.Conv1d(d_input, d_model, 1) if transposed_input else nn.Linear(d_input, d_model) 41 | 42 | # Stack S4 layers as residual blocks 43 | self.s4_layers = nn.ModuleList() 44 | self.norms = nn.ModuleList() 45 | self.dropouts = nn.ModuleList() 46 | for _ in range(n_layers): 47 | if(backbone == "s42"): 48 | self.s4_layers.append( 49 | S42( 50 | d_state=d_state, 51 | l_max=l_max, 52 | d_model=d_model, 53 | bidirectional=bidirectional, 54 | postact='glu', 55 | dropout=dropout, 56 | tie_dropout=tie_dropout, 57 | transposed=True, 58 | )) 59 | 60 | #MODIFIED TO ALLOW BATCH NORM MODELS 61 | self.layer_norm = layer_norm 62 | if(layer_norm): 63 | self.norms.append(nn.LayerNorm(d_model)) 64 | else: #MODIFIED 65 | self.norms.append(nn.BatchNorm1d(d_model)) 66 | 67 | if(tie_dropout): 68 | self.dropouts.append(DropoutNd(dropout,transposed=True)) 69 | else: 70 | self.dropouts.append(nn.Dropout(dropout)) 71 | 72 | self.pooling = pooling 73 | # Linear decoder 74 | # MODIFIED TO ALLOW FOR MODELS WITHOUT DECODER 75 | if(d_output is None): 76 | self.decoder = None 77 | else: 78 | self.decoder = nn.Linear(d_model, d_output) 79 | 80 | #MODIFIED 81 | def forward(self, x, rate=1.0): 82 | """ 83 | Input x is shape (B, d_input, L) if transposed_input else (B, L, d_input) 84 | """ 85 | x = self.encoder(x) # (B, d_input, L) -> (B, d_model, L) if transposed_input else (B, L, d_input) -> (B, L, d_model) 86 | 87 | if(self.transposed_input is False): 88 | x = x.transpose(-1, -2) # (B, L, d_model) -> (B, d_model, L) 89 | 90 | for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts): 91 | # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L) 92 | 93 | z = x 94 | if self.prenorm: 95 | # Prenorm 96 | # MODIFIED 97 | z = norm(z.transpose(-1, -2)).transpose(-1, -2) if self.layer_norm else norm(z) 98 | 99 | # Apply S4 block: we ignore the state input and output 100 | # MODIFIED 101 | z, _ = layer(z, rate=rate) 102 | 103 | # Dropout on the output of the 
S4 block 104 | z = dropout(z) 105 | 106 | # Residual connection 107 | x = z + x 108 | 109 | if not self.prenorm: 110 | # Postnorm 111 | # MODIFIED 112 | x = norm(x.transpose(-1, -2)).transpose(-1, -2) if self.layer_norm else norm(z) 113 | 114 | x = x.transpose(-1, -2) # (B, d_model, L) -> (B, L, d_model) 115 | 116 | # MODIFIED ALLOW TO DISABLE POOLING 117 | if(self.pooling): 118 | # Pooling: average pooling over the sequence length 119 | x = x.mean(dim=1) 120 | 121 | # Decode the outputs 122 | if(self.decoder is not None): 123 | x = self.decoder(x) # (B, d_model) -> (B, d_output) if pooling else (B, L, d_model) -> (B, L, d_output) 124 | 125 | if(not self.pooling and self.transposed_input is True): 126 | x = x.transpose(-1, -2) # (B, L, d_output) -> (B, d_output, L) 127 | return x 128 | -------------------------------------------------------------------------------- /src/clinical_ts/ts/s4_modules/s4_utils.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/state-spaces/s4/blob/main/models/s4/s4.py 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | 6 | class DropoutNd(nn.Module): 7 | def __init__(self, p: float = 0.5, tie=True, transposed=True): 8 | """ 9 | tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) 10 | """ 11 | super().__init__() 12 | if p < 0 or p >= 1: 13 | raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) 14 | self.p = p 15 | self.tie = tie 16 | self.transposed = transposed 17 | self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) 18 | 19 | def forward(self, X): 20 | """X: (batch, dim, lengths...).""" 21 | if self.training: 22 | if not self.transposed: X = rearrange(X, 'b ... 
d -> b d ...') 23 | mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape 24 | mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p 25 | X = X * mask * (1.0/(1-self.p)) 26 | if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') 27 | return X 28 | return X -------------------------------------------------------------------------------- /src/clinical_ts/ts/transformer_modules/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/utils/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/clinical_ts/utils/bootstrap_utils.py: -------------------------------------------------------------------------------- 1 | __all__ = ['empirical_bootstrap'] 2 | 3 | import numpy as np 4 | from sklearn.utils import resample 5 | from multiprocessing import Pool 6 | from functools import partial 7 | from tqdm.auto import tqdm 8 | 9 | def _eval(ids, input_tuple, score_fn, input_tuple2=None,score_fn_kwargs={}): 10 | return score_fn(*[t[ids] for t in input_tuple],**score_fn_kwargs) if input_tuple2 is None else score_fn(*[t[ids] for t in input_tuple],**score_fn_kwargs)-score_fn(*[t[ids] for t in input_tuple2],**score_fn_kwargs) 11 | 12 | def empirical_bootstrap(input_tuple, score_fn, ids=None, n_iterations=1000, alpha=0.95, score_fn_kwargs={},threads=None, input_tuple2=None, ignore_nans=False, chunksize=50): 13 | ''' 14 | performs empirical bootstrap https://ocw.mit.edu/courses/mathematics/18-05-introduction-to-probability-and-statistics-spring-2014/readings/MIT18_05S14_Reading24.pdf 15 | 16 | input_tuple: tuple of inputs for the score function typically something like (labels,predictions) 17 | score_function: scoring function that takes the individual entries of input tuple as argument e.g. 
f1_score 18 | id: list of previously sampled ids (if None new ids will be sampled) 19 | n_iterations: number of bootstrap iterations 20 | alpha: alpha-level for the confidence intervals 21 | score_fn_kwargs: additional (static) kwargs to be passed to the score_fn 22 | threads: number of threads (None uses os.cpu_count()); 0 no multithreading 23 | input_tuple2: if not None this is a second input of the same shape as input_tuple- in that case the function bootstraps the score difference between both inputs (this is just a convenience function- the same could be achieved by passing a tuple of the form (label,preds1,preds2) and computing the difference in the score_function itself) 24 | ignore_nans: ignore nans (e.g. no positives during during AUC evaluation) for score evaluation 25 | chunksize: process in chunks of size chunksize 26 | ''' 27 | 28 | if(not(isinstance(input_tuple,tuple))): 29 | input_tuple = (input_tuple,) 30 | if(input_tuple2 is not None and not(isinstance(input_tuple2,tuple))): 31 | input_tuple2 = (input_tuple2,) 32 | 33 | score_point = score_fn(*input_tuple,**score_fn_kwargs) if input_tuple2 is None else score_fn(*input_tuple,**score_fn_kwargs)-score_fn(*input_tuple2,**score_fn_kwargs) 34 | 35 | if(n_iterations==0): 36 | return score_point,np.zeros(score_point.shape),np.zeros(score_point.shape),[] 37 | 38 | if(ids is None): 39 | ids = [] 40 | for _ in range(n_iterations): 41 | ids.append(resample(range(len(input_tuple[0])), n_samples=len(input_tuple[0]))) 42 | ids = np.array(ids) 43 | 44 | fn = partial(_eval,input_tuple=input_tuple,score_fn=score_fn,input_tuple2=input_tuple2,score_fn_kwargs=score_fn_kwargs) 45 | 46 | if(threads is not None and threads==0): 47 | results= np.array(fn(ids)).astype(np.float32)#shape: bootstrap_iterations, number_of_evaluation_metrics 48 | else: 49 | results=[] 50 | for istart in tqdm(np.arange(0,n_iterations,chunksize)): 51 | iend = min(n_iterations,istart+chunksize) 52 | pool = Pool(threads) 53 | 
results.append(np.array(pool.map(fn, ids[istart:iend])).astype(np.float32)) 54 | pool.close() 55 | pool.join() 56 | 57 | results = np.concatenate(results,axis=0) 58 | 59 | percentile_fn = np.nanpercentile if ignore_nans else np.percentile 60 | score_diff = np.array(results)- score_point 61 | score_low = score_point + percentile_fn(score_diff, ((1.0-alpha)/2.0) * 100,axis=0) 62 | score_high = score_point + percentile_fn(score_diff, (alpha+((1.0-alpha)/2.0)) * 100,axis=0) 63 | 64 | if(ignore_nans):#in this case return the number of nans in each score rather than the sampled ids (which could be different when evaluating several metrics at once) 65 | return score_point, score_low, score_high, np.sum(np.isnan(score_diff),axis=0) 66 | else: 67 | return score_point, score_low, score_high, ids -------------------------------------------------------------------------------- /src/clinical_ts/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | __all__ = ['ForwardHook','TriggerQuantizerHyperparameterUpdate','UnfreezingFinetuningCallback','LRMonitorCallback', 'cos_anneal', 'DecayLR',"freeze_bn_stats","sanity_check"] 2 | 3 | import torch 4 | from torch import nn 5 | import lightning.pytorch as lp 6 | import math 7 | 8 | from lightning.pytorch.callbacks import Callback, BaseFinetuning 9 | 10 | class ForwardHook: 11 | "Create a forward hook on module `m` " 12 | 13 | def __init__(self, m, store_output=True): 14 | self.store_output = store_output 15 | self.hook = m.register_forward_hook(self.hook_fn) 16 | self.stored, self.removed = None, False 17 | 18 | def hook_fn(self, module, input, output): 19 | "stores input/output" 20 | if self.store_output: 21 | self.stored = output 22 | else: 23 | self.stored = input 24 | 25 | def remove(self): 26 | "Remove the hook from the model." 
27 | if not self.removed: 28 | self.hook.remove() 29 | self.removed = True 30 | 31 | def __enter__(self, *args): 32 | return self 33 | 34 | def __exit__(self, *args): 35 | self.remove() 36 | 37 | class TriggerQuantizerHyperparameterUpdate(Callback): 38 | def __init__(self,quantizer_modules): 39 | super(TriggerQuantizerHyperparameterUpdate, self).__init__() 40 | self.modules = quantizer_modules 41 | 42 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): 43 | for m in self.modules: 44 | m.update_hyperparams(trainer.global_step) 45 | 46 | class UnfreezingFinetuningCallback(BaseFinetuning): 47 | 48 | def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True): 49 | super().__init__() 50 | self.unfreeze_epoch = unfreeze_epoch 51 | self.train_bn = train_bn 52 | 53 | def freeze_before_training(self, pl_module: lp.LightningModule): 54 | modules = pl_module.get_params(modules=True) 55 | for mod in modules[1:]: 56 | self.freeze(mod["params"], train_bn=self.train_bn) 57 | 58 | def finetune_function(self, pl_module: lp.LightningModule, epoch: int, optimizer: torch.optim.Optimizer, opt_idx: int): 59 | """Called on every epoch starts.""" 60 | if epoch == self.unfreeze_epoch: 61 | modules = pl_module.get_params(modules=True) 62 | for mod in modules[1:]: 63 | self.unfreeze_and_add_param_group( 64 | mod["params"], 65 | optimizer, 66 | lr=mod["lr"]*optimizer.param_groups[0]["lr"]/modules[0]["lr"], 67 | train_bn=self.train_bn, 68 | ) 69 | 70 | class LRMonitorCallback(Callback): 71 | def __init__(self,interval="epoch",start=True,end=True): 72 | super().__init__() 73 | self.interval = interval 74 | self.start = start 75 | self.end = end 76 | 77 | def on_train_batch_start(self, trainer, *args, **kwargs): 78 | if(self.interval == "step" and self.start): 79 | current_lrs = [d['lr'] for d in trainer.optimizers[0].param_groups] 80 | print(f'Epoch: {trainer.current_epoch} Step: {trainer.global_step} LRs:',current_lrs) 81 | 82 | def on_train_epoch_start(self, 
trainer, *args, **kwargs): 83 | if(self.interval == "epoch" and self.start): 84 | current_lrs = [d['lr'] for d in trainer.optimizers[0].param_groups] 85 | print(f'Epoch: {trainer.current_epoch} Step: {trainer.global_step} LRs:',current_lrs) 86 | 87 | def on_train_batch_end(self, trainer, *args, **kwargs): 88 | if(self.interval == "step" and self.end): 89 | current_lrs = [d['lr'] for d in trainer.optimizers[0].param_groups] 90 | print(f'Epoch: {trainer.current_epoch} Step: {trainer.global_step} LRs:',current_lrs) 91 | 92 | def on_train_epoch_end(self, trainer, *args, **kwargs): 93 | if(self.interval == "epoch" and self.end): 94 | current_lrs = [d['lr'] for d in trainer.optimizers[0].param_groups] 95 | print(f'Epoch: {trainer.current_epoch} Step: {trainer.global_step} LRs:',current_lrs) 96 | 97 | ############################################################################################################ 98 | def freeze_bn_stats(model, freeze=True): 99 | for m in model.modules(): 100 | if(isinstance(m,nn.BatchNorm1d)): 101 | if(freeze): 102 | m.eval() 103 | else: 104 | m.train() 105 | 106 | ############################################################################################################ 107 | def sanity_check(model, state_dict_pre): 108 | """ 109 | Linear classifier should not change any weights other than the linear layer. 110 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 
111 | """ 112 | print("=> loading state dict for sanity check") 113 | state_dict = model.state_dict() 114 | 115 | for k in list(state_dict.keys()): 116 | # only ignore fc layer 117 | if 'head.1.weight' in k or 'head.1.bias' in k or 'head.4.weight' in k or 'head.4.bias' in k: 118 | continue 119 | 120 | 121 | assert ((state_dict[k].cpu() == state_dict_pre[k].cpu()).all()), \ 122 | '{} is changed in linear classifier training.'.format(k) 123 | 124 | print("=> sanity check passed.") 125 | 126 | ############################################################################################################ 127 | #from https://github.com/karpathy/deep-vector-quantization/blob/main/dvq/vqvae.py 128 | # ----------------------------------------------------------------------------- 129 | def cos_anneal(e0, e1, t0, t1, e): 130 | """ ramp from (e0, t0) -> (e1, t1) through a cosine schedule based on e in [e0, e1] """ 131 | alpha = max(0, min(1, (e - e0) / (e1 - e0))) # what fraction of the way through are we 132 | alpha = 1.0 - math.cos(alpha * math.pi/2) # warp through cosine 133 | t = alpha * t1 + (1 - alpha) * t0 # interpolate accordingly 134 | return t 135 | 136 | class DecayLR(Callback): 137 | def __init__(self,num_steps=1200000,lrstart=3e-4,lrend=1.25e-6): 138 | super(DecayLR, self).__init__() 139 | self.num_steps = num_steps 140 | self.lrstart = lrstart 141 | self.lrend = lrend 142 | 143 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): 144 | # The step size is annealed from 1e10−4 to 1.25e10−6 over 1,200,000 updates. 
I use 3e-4 145 | t = cos_anneal(0, self.num_steps, self.lrstart, self.lrend, trainer.global_step) 146 | for g in pl_module.model_cpc.optimizer.param_groups: 147 | g['lr'] = t -------------------------------------------------------------------------------- /src/clinical_ts/utils/eval_utils_cafa.py: -------------------------------------------------------------------------------- 1 | __all__ = ['auc_prrc_uninterpolated', 'multiclass_roc_curve', 'single_eval_prrc', 'eval_prrc', 'eval_prrc_parallel', 2 | 'eval_scores', 'eval_scores_bootstrap'] 3 | 4 | # Cell 5 | import warnings 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.metrics import roc_auc_score, auc 9 | from scipy.interpolate import interp1d 10 | 11 | from sklearn.metrics import roc_curve, precision_recall_curve 12 | from sklearn.utils import resample 13 | 14 | from tqdm import tqdm 15 | 16 | # Cell 17 | def auc_prrc_uninterpolated(recall,precision): 18 | '''uninterpolated auc as used by sklearn https://github.com/scikit-learn/scikit-learn/blob/1495f6924/sklearn/metrics/ranking.py see also the discussion at https://github.com/scikit-learn/scikit-learn/pull/9583''' 19 | #print(-np.sum(np.diff(recall) * np.array(precision)[:-1]),auc(recall,precision)) 20 | return -np.sum(np.diff(recall) * np.array(precision)[:-1]) 21 | 22 | # Cell 23 | #label-centric metrics 24 | def multiclass_roc_curve(y_true, y_pred, classes=None, precision_recall=False): 25 | '''Compute ROC curve and ROC area for each class "0"..."n_classes - 1" (or classnames passed via classes), "micro", "macro" 26 | returns fpr,tpr,roc (dictionaries) for ROC 27 | returns recall,precision,average_precision for precision_recall 28 | ''' 29 | 30 | fpr = dict() 31 | tpr = dict() 32 | roc_auc = dict() 33 | n_classes=len(y_pred[0]) 34 | if(classes is None): 35 | classes = [str(i) for i in range(n_classes)] 36 | 37 | for i,c in enumerate(classes): 38 | y_truei = y_true[:, i] 39 | y_predi = y_pred[:, i] 40 | 41 | maski = ~np.isnan(y_truei)#mask our 
nan targets if available 42 | y_truei = y_truei[maski] 43 | y_predi = y_predi[maski] 44 | 45 | if(precision_recall): 46 | tpr[c], fpr[c], _ = precision_recall_curve(y_truei, y_predi) 47 | roc_auc[c] = auc_prrc_uninterpolated(fpr[c], tpr[c]) 48 | else: 49 | fpr[c], tpr[c], _ = roc_curve(y_truei, y_predi) 50 | roc_auc[c] = auc(fpr[c], tpr[c]) 51 | 52 | # Compute micro-average curve and area 53 | y_true_micro = y_true.ravel() 54 | y_pred_micro = y_pred.ravel() 55 | mask_micro = ~np.isnan(y_true_micro) 56 | y_true_micro = y_true_micro[mask_micro] 57 | y_pred_micro = y_pred_micro[mask_micro] 58 | 59 | if(precision_recall): 60 | tpr["micro"], fpr["micro"], _ = precision_recall_curve(y_true_micro, y_pred_micro) 61 | roc_auc["micro"] = auc_prrc_uninterpolated(fpr["micro"], tpr["micro"]) 62 | else: 63 | fpr["micro"], tpr["micro"], _ = roc_curve(y_true_micro, y_pred_micro) 64 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 65 | 66 | # Compute macro-average curve and area (linear interpolation is incorrect for PRRC- therefore just for ROC) 67 | if(precision_recall is False): 68 | # 1. First aggregate all unique x values (false positive rates for ROC) 69 | all_fpr = np.unique(np.concatenate([fpr[c] for c in classes])) 70 | 71 | # 2. Then interpolate all curves at this points 72 | mean_tpr=None 73 | for c in classes: 74 | f = interp1d(fpr[c], tpr[c]) 75 | if(mean_tpr is None): 76 | mean_tpr = f(all_fpr) 77 | else: 78 | mean_tpr += f(all_fpr) 79 | 80 | # 3. 
Finally average it and compute area 81 | mean_tpr /= n_classes 82 | 83 | fpr["macro"] = all_fpr 84 | tpr["macro"] = mean_tpr 85 | #macro2 differs slightly from macro due to interpolation effects 86 | #roc_auc["macro2"] = auc(fpr["macro"], tpr["macro"]) 87 | 88 | #calculate macro auc directly by summing 89 | roc_auc_macro = 0 90 | for c in classes: 91 | roc_auc_macro += roc_auc[c] 92 | roc_auc["macro"]=roc_auc_macro/n_classes 93 | 94 | #calculate macro auc directly by summing 95 | roc_auc_macro = 0 96 | macro_auc_nans = 0 #due to an insufficient amount of pos/neg labels 97 | for c in classes: 98 | if(np.isnan(roc_auc[c])):#conservative choice: replace auc by 0.5 if it could not be calculated 99 | roc_auc_macro += 0.5 100 | macro_auc_nans += 1 101 | else: 102 | roc_auc_macro += roc_auc[c] 103 | roc_auc["macro"]=roc_auc_macro/n_classes 104 | #roc_auc["macro_nans"] = macro_auc_nans 105 | 106 | return fpr, tpr, roc_auc 107 | 108 | # Cell 109 | def single_eval_prrc(y_true,y_pred,threshold): 110 | '''evaluate instance-wise scores for a single sample and a single threshold''' 111 | y_pred_bin = (y_pred >= threshold) 112 | TP = np.sum(np.logical_and(y_true == y_pred_bin,y_true>0)) 113 | count = np.sum(y_pred_bin)#TP+FP 114 | 115 | # Find precision: TP / (TP + FP) 116 | precision = TP / count if count > 0 else np.nan 117 | # Find recall/TPR/sensitivity: TP / (TP + FN) 118 | recall = TP/np.sum(y_true>0) 119 | # Find FPR/specificity: FP/ (FP + TN)=FP/N 120 | FP = np.sum(np.logical_and(y_true != y_pred_bin,y_pred_bin>0)) 121 | specificity = FP/ np.sum(y_true==0) 122 | return precision, recall, specificity 123 | 124 | # Cell 125 | def eval_prrc(y_true,y_pred,threshold): 126 | '''eval instance-wise scores across all samples for a single threshold''' 127 | # Initialize Variables 128 | PR = 0.0 129 | RC = 0.0 130 | SP = 0.0 131 | 132 | counts_above_threshold = 0 133 | 134 | for i in range(len(y_true)): 135 | pr,rc,sp = single_eval_prrc(y_true[i],y_pred[i],threshold) 136 | if pr is 
not np.nan: 137 | PR += pr 138 | counts_above_threshold += 1 139 | RC += rc 140 | SP += sp 141 | 142 | recall = RC/len(y_true) 143 | specificity = SP/len(y_true) 144 | 145 | if counts_above_threshold > 0: 146 | precision = PR/counts_above_threshold 147 | else: 148 | precision = np.nan 149 | if(threshold<1.0): 150 | print("No prediction is made above the %.2f threshold\n" % threshold) 151 | return precision, recall, specificity, counts_above_threshold/len(y_true) 152 | 153 | # Cell 154 | def eval_prrc_parallel(y_true,y_pred,thresholds): 155 | 156 | y_pred_bin = np.repeat(y_pred[None, :, :], len(thresholds), axis=0)>=thresholds[:,None,None]#thresholds, samples, classes 157 | TP = np.sum(np.logical_and( y_true == True, y_pred_bin== True),axis=2)#threshold, samples 158 | 159 | with np.errstate(divide='ignore', invalid='ignore'): 160 | den = np.sum(y_pred_bin,axis=2)>0 161 | precision = TP/np.sum(y_pred_bin,axis=2) 162 | precision[den==0] = np.nan 163 | 164 | recall = TP/np.sum(y_true==True, axis=1)#threshold,samples/samples=threshold,samples 165 | 166 | FP = np.sum(np.logical_and((y_true ==False),(y_pred_bin==True)),axis=2) 167 | specificity = FP/np.sum(y_true==False, axis=1) 168 | 169 | with warnings.catch_warnings(): #for nan slices 170 | warnings.simplefilter("ignore", category=RuntimeWarning) 171 | av_precision = np.nanmean(precision,axis=1) 172 | 173 | av_recall = np.mean(recall,axis=1) 174 | av_specificity = np.mean(specificity,axis=1) 175 | av_coverage = np.mean(den,axis=1) 176 | 177 | return av_precision, av_recall, av_specificity, av_coverage 178 | 179 | 180 | # Cell 181 | def eval_scores(y_true,y_pred,classes=None,num_thresholds=100,full_output=False,parallel=True): 182 | '''returns a dictionary of performance metrics: 183 | sample centric c.f. 
https://github.com/ashleyzhou972/CAFA_assessment_tool/blob/master/precrec/precRec.py 184 | https://www.nature.com/articles/nmeth.2340 vs https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3694662/ and https://arxiv.org/pdf/1601.00891 185 | * Fmax, sample AUC, sample Average Precision (as in sklearn) 186 | 187 | label-centric: micro,macro,individual AUC and Average Precision 188 | ''' 189 | results = {} 190 | 191 | # thresholds = np.arange(0.00, 1.01, 1./num_thresholds, float) 192 | # if(parallel is False): 193 | # PR = np.zeros(len(thresholds)) 194 | # RC = np.zeros(len(thresholds)) 195 | # SP = np.zeros(len(thresholds)) 196 | # COV = np.zeros(len(thresholds)) 197 | 198 | # for i,t in enumerate(thresholds): 199 | # PR[i],RC[i],SP[i],COV[i] = eval_prrc(y_true,y_pred,t) 200 | # F = (2*PR*RC)/(PR+RC) 201 | # else: 202 | # PR,RC,SP,COV = eval_prrc_parallel(y_true,y_pred,thresholds) 203 | # F = (2*PR*RC)/(PR+RC) 204 | 205 | # if(full_output is True): 206 | # results["PR"] = PR 207 | # results["RC"] = RC 208 | # results["SP"] = SP 209 | # results["F"] = F 210 | # results["COV"] = COV 211 | 212 | # if np.isnan(F).sum() == len(F): 213 | # results["Fmax"] = 0 214 | # results["precision_at_Fmax"] = 0 215 | # results["recall_at_Fmax"] = 0 216 | # results["threshold_at_Fmax"] = 0 217 | # results["coverage_at_Fmax"]= 0 218 | # else: 219 | # imax = np.nanargmax(F) 220 | # results["Fmax"] = F[imax] 221 | # results["precision_at_Fmax"] = PR[imax] 222 | # results["recall_at_Fmax"] = RC[imax] 223 | # results["threshold_at_Fmax"] = thresholds[imax] 224 | # results["coverage_at_Fmax"]=COV[imax] 225 | 226 | # results["sample_AUC"]=auc(1-SP,RC) 227 | # #https://github.com/scikit-learn/scikit-learn/blob/1495f6924/sklearn/metrics/ranking.py set final PR value to 1 228 | # PR[-1]=1 229 | # results["sample_APR"]=auc_prrc_uninterpolated(RC,PR)#skip last point with undefined precision 230 | ########################################################### 231 | #label-centric 232 | 
#"micro","macro",i=0...n_classes-1 233 | fpr, tpr, roc_auc = multiclass_roc_curve(y_true, y_pred,classes=classes,precision_recall=False) 234 | if(full_output is True): 235 | results["fpr"]=fpr 236 | results["tpr"]=tpr 237 | results["label_AUC"]=roc_auc 238 | 239 | # rc, pr, prrc_auc = multiclass_roc_curve(y_true, y_pred,classes=classes,precision_recall=True) 240 | # if(full_output is True): 241 | # results["pr"]=pr 242 | # results["rc"]=rc 243 | # results["label_APR"]=prrc_auc 244 | 245 | return results 246 | 247 | # Cell 248 | def eval_scores_bootstrap(y_true, y_pred,classes=None, n_iterations = 10000, alpha=0.95): 249 | #https://ocw.mit.edu/courses/mathematics/18-05-introduction-to-probability-and-statistics-spring-2014/readings/MIT18_05S14_Reading24.pdf empirical bootstrap rather than bootstrap percentiles 250 | Fmax_diff = [] 251 | sample_AUC_diff = [] 252 | sample_APR_diff = [] 253 | label_AUC_diff = [] 254 | label_APR_diff = [] 255 | label_AUC_keys = None 256 | 257 | #point estimate 258 | res_point = eval_scores(y_true,y_pred,classes=classes) 259 | Fmax_point = res_point["Fmax"] 260 | sample_AUC_point = res_point["sample_AUC"] 261 | sample_APR_point = res_point["sample_APR"] 262 | label_AUC_point = np.array(list(res_point["label_AUC"].values())) 263 | label_APR_point = np.array(list(res_point["label_APR"].values())) 264 | 265 | #bootstrap 266 | for i in tqdm(range(n_iterations)): 267 | ids = resample(range(len(y_true)), n_samples=len(y_true)) 268 | res = eval_scores(y_true[ids],y_pred[ids],classes=classes) 269 | Fmax_diff.append(res["Fmax"]-Fmax_point) 270 | sample_AUC_diff.append(res["sample_AUC"]-sample_AUC_point) 271 | sample_APR_diff.append(res["sample_APR"]-sample_APR_point) 272 | label_AUC_keys = list(res["label_AUC"].keys()) 273 | label_AUC_diff.append(np.array(list(res["label_AUC"].values()))-label_AUC_point) 274 | label_APR_diff.append(np.array(list(res["label_APR"].values()))-label_APR_point) 275 | 276 | p = ((1.0-alpha)/2.0) * 100 277 | Fmax_low = 
Fmax_point + np.percentile(Fmax_diff, p) 278 | sample_AUC_low = sample_AUC_point + np.percentile(sample_AUC_diff, p) 279 | sample_APR_low = sample_APR_point + np.percentile(sample_APR_diff, p) 280 | label_AUC_low = label_AUC_point + np.percentile(label_AUC_diff,p,axis=0) 281 | label_APR_low = label_APR_point + np.percentile(label_APR_diff,p,axis=0) 282 | p = (alpha+((1.0-alpha)/2.0)) * 100 283 | Fmax_high = Fmax_point + np.percentile(Fmax_diff, p) 284 | sample_AUC_high = sample_AUC_point + np.percentile(sample_AUC_diff, p) 285 | sample_APR_high = sample_APR_point + np.percentile(sample_APR_diff, p) 286 | label_AUC_high = label_AUC_point + np.percentile(label_AUC_diff,p,axis=0) 287 | label_APR_high = label_APR_point + np.percentile(label_APR_diff,p,axis=0) 288 | 289 | return {"Fmax":[Fmax_low,Fmax_point,Fmax_high], "sample_AUC":[sample_AUC_low,sample_AUC_point,sample_AUC_high], "sample_APR":[sample_APR_low,sample_APR_point,sample_APR_high], "label_AUC":{k:[v1,v2,v3] for k,v1,v2,v3 in zip(label_AUC_keys,label_AUC_low,label_AUC_point,label_AUC_high)}, "label_APR":{k:[v1,v2,v3] for k,v1,v2,v3 in zip(label_AUC_keys,label_APR_low,label_APR_point,label_APR_high)}} -------------------------------------------------------------------------------- /src/clinical_ts/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | __all__ = ['get_constant_schedule','get_constant_schedule_with_warmup','get_linear_schedule_with_warmup','get_cosine_schedule_with_warmup', 'get_cosine_with_hard_restarts_schedule_with_warmup', 'get_polynomial_decay_schedule_with_warmup', 'get_invsqrt_decay_schedule_with_warmup'] 2 | 3 | #adapted from https://huggingface.co/transformers/_modules/transformers/optimization.htm 4 | import math 5 | from typing import Callable, Iterable, Optional, Tuple, Union 6 | 7 | import torch 8 | from torch.optim import Optimizer 9 | from torch.optim.lr_scheduler import LambdaLR 10 | 11 | def get_constant_schedule(optimizer: 
Optimizer, last_epoch: int = -1): 12 | """ 13 | Create a schedule with a constant learning rate, using the learning rate set in optimizer. 14 | 15 | Args: 16 | optimizer (:class:`~torch.optim.Optimizer`): 17 | The optimizer for which to schedule the learning rate. 18 | last_epoch (:obj:`int`, `optional`, defaults to -1): 19 | The index of the last epoch when resuming training. 20 | 21 | Return: 22 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 23 | """ 24 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 25 | 26 | 27 | 28 | def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): 29 | """ 30 | Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate 31 | increases linearly between 0 and the initial lr set in the optimizer. 32 | 33 | Args: 34 | optimizer (:class:`~torch.optim.Optimizer`): 35 | The optimizer for which to schedule the learning rate. 36 | num_warmup_steps (:obj:`int`): 37 | The number of steps for the warmup phase. 38 | last_epoch (:obj:`int`, `optional`, defaults to -1): 39 | The index of the last epoch when resuming training. 40 | 41 | Return: 42 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 43 | """ 44 | 45 | def lr_lambda(current_step: int): 46 | if current_step < num_warmup_steps: 47 | return float(1+current_step) / float(max(1.0, 1+num_warmup_steps)) 48 | return 1.0 49 | 50 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 51 | 52 | 53 | 54 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 55 | """ 56 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 57 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 
58 | 59 | Args: 60 | optimizer (:class:`~torch.optim.Optimizer`): 61 | The optimizer for which to schedule the learning rate. 62 | num_warmup_steps (:obj:`int`): 63 | The number of steps for the warmup phase. 64 | num_training_steps (:obj:`int`): 65 | The total number of training steps. 66 | last_epoch (:obj:`int`, `optional`, defaults to -1): 67 | The index of the last epoch when resuming training. 68 | 69 | Return: 70 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 71 | """ 72 | 73 | def lr_lambda(current_step: int): 74 | if current_step < num_warmup_steps: 75 | return float(1+current_step) / float(max(1, 1+num_warmup_steps)) 76 | return max( 77 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 78 | ) 79 | 80 | return LambdaLR(optimizer, lr_lambda, last_epoch) 81 | 82 | 83 | 84 | def get_cosine_schedule_with_warmup( 85 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 86 | ): 87 | """ 88 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 89 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 90 | initial lr set in the optimizer. 91 | 92 | Args: 93 | optimizer (:class:`~torch.optim.Optimizer`): 94 | The optimizer for which to schedule the learning rate. 95 | num_warmup_steps (:obj:`int`): 96 | The number of steps for the warmup phase. 97 | num_training_steps (:obj:`int`): 98 | The total number of training steps. 99 | num_cycles (:obj:`float`, `optional`, defaults to 0.5): 100 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 101 | following a half-cosine). 102 | last_epoch (:obj:`int`, `optional`, defaults to -1): 103 | The index of the last epoch when resuming training. 
104 | 105 | Return: 106 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 107 | """ 108 | 109 | def lr_lambda(current_step): 110 | if current_step < num_warmup_steps: 111 | return float(1+current_step) / float(max(1, 1+num_warmup_steps)) 112 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 113 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 114 | 115 | return LambdaLR(optimizer, lr_lambda, last_epoch) 116 | 117 | 118 | 119 | def get_cosine_with_hard_restarts_schedule_with_warmup( 120 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 121 | ): 122 | """ 123 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 124 | initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases 125 | linearly between 0 and the initial lr set in the optimizer. 126 | 127 | Args: 128 | optimizer (:class:`~torch.optim.Optimizer`): 129 | The optimizer for which to schedule the learning rate. 130 | num_warmup_steps (:obj:`int`): 131 | The number of steps for the warmup phase. 132 | num_training_steps (:obj:`int`): 133 | The total number of training steps. 134 | num_cycles (:obj:`int`, `optional`, defaults to 1): 135 | The number of hard restarts to use. 136 | last_epoch (:obj:`int`, `optional`, defaults to -1): 137 | The index of the last epoch when resuming training. 138 | 139 | Return: 140 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
141 | """ 142 | 143 | def lr_lambda(current_step): 144 | if current_step < num_warmup_steps: 145 | return float(current_step) / float(max(1, num_warmup_steps)) 146 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 147 | if progress >= 1.0: 148 | return 0.0 149 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) 150 | 151 | return LambdaLR(optimizer, lr_lambda, last_epoch) 152 | 153 | 154 | 155 | def get_polynomial_decay_schedule_with_warmup( 156 | optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 157 | ): 158 | """ 159 | Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the 160 | optimizer to end lr defined by `lr_end`, after a warmup period during which it increases linearly from 0 to the 161 | initial lr set in the optimizer. 162 | 163 | Args: 164 | optimizer (:class:`~torch.optim.Optimizer`): 165 | The optimizer for which to schedule the learning rate. 166 | num_warmup_steps (:obj:`int`): 167 | The number of steps for the warmup phase. 168 | num_training_steps (:obj:`int`): 169 | The total number of training steps. 170 | lr_end (:obj:`float`, `optional`, defaults to 1e-7): 171 | The end LR. 172 | power (:obj:`float`, `optional`, defaults to 1.0): 173 | Power factor. 174 | last_epoch (:obj:`int`, `optional`, defaults to -1): 175 | The index of the last epoch when resuming training. 176 | 177 | Note: `power` defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT 178 | implementation at 179 | https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 180 | 181 | Return: 182 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
183 | 184 | """ 185 | 186 | lr_init = optimizer.defaults["lr"] 187 | assert lr_init > lr_end, f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})" 188 | 189 | def lr_lambda(current_step: int): 190 | if current_step < num_warmup_steps: 191 | return float(1+current_step) / float(max(1, 1+num_warmup_steps)) 192 | elif current_step > num_training_steps: 193 | return lr_end / lr_init # as LambdaLR multiplies by lr_init 194 | else: 195 | lr_range = lr_init - lr_end 196 | decay_steps = num_training_steps - num_warmup_steps 197 | pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps 198 | decay = lr_range * pct_remaining ** power + lr_end 199 | return decay / lr_init # as LambdaLR multiplies by lr_init 200 | 201 | return LambdaLR(optimizer, lr_lambda, last_epoch) 202 | 203 | def get_invsqrt_decay_schedule_with_warmup( 204 | optimizer, num_warmup_steps, last_epoch=-1 205 | ): 206 | """ 207 | Create a schedule with a learning rate that decreases as a with an inverse sqrt law, 208 | after a warmup period during which it increases linearly from 0 to the 209 | initial lr set in the optimizer. 210 | 211 | Args: 212 | optimizer (:class:`~torch.optim.Optimizer`): 213 | The optimizer for which to schedule the learning rate. 214 | num_warmup_steps (:obj:`int`): 215 | The number of steps for the warmup phase. 216 | last_epoch (:obj:`int`, `optional`, defaults to -1): 217 | The index of the last epoch when resuming training. 218 | 219 | 220 | Return: 221 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
222 | 223 | """ 224 | decay_factor = num_warmup_steps ** 0.5 225 | lr_init = optimizer.defaults["lr"] 226 | 227 | def lr_lambda(current_step: int): 228 | if current_step < num_warmup_steps: 229 | return float(1+current_step) / float(max(1, 1+num_warmup_steps)) 230 | else: 231 | return decay_factor* current_step ** -0.5 232 | 233 | return LambdaLR(optimizer, lr_lambda, last_epoch) 234 | -------------------------------------------------------------------------------- /src/config/config_supervised_multimodal_labvalues_s4.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base: base 3 | - data@data0: mimic_labvalues 4 | - ts@ts0: tsenc 5 | - ts/enc@ts0.enc: none 6 | - ts/pred@ts0.pred: s4 7 | - ts/head@ts0.head: pool 8 | - ts/loss@ts0.loss: none 9 | - static: mlp 10 | - head: concat 11 | - loss: bcef 12 | - trainer: trainer 13 | - task: multi 14 | - metric@metric0: aurocagg 15 | - _self_ 16 | 17 | loss: 18 | ignore_nans: True 19 | 20 | base: 21 | fs: 100. 
22 | input_size: 250 23 | input_channels: 12 24 | normalize: false 25 | batch_size: 32 26 | input_channels_cat: 2 27 | input_channels_cont: 10 28 | 29 | trainer: 30 | gpus: 1 31 | refresh_rate: 0 32 | username: "nstrodt" 33 | epochs: 40 34 | precision: 32 35 | 36 | ts0: 37 | pred: 38 | causal: False 39 | state_dim: 8 40 | model_dim: 512 41 | backbone: "s42" 42 | 43 | static: 44 | vocab_sizes: [2,5] 45 | embedding_dims: [16,32] 46 | lin_ftrs: [128,128,128] 47 | 48 | metric0: 49 | bootstrap_iterations: 1000 50 | 51 | task: 52 | introduce_nan_columns: true -------------------------------------------------------------------------------- /src/config/data/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/config/data/multimodal_mdsed.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | name: "mimic_labvalues" 5 | path: "data/memmap" 6 | fs: 100. 
7 | annotation: false 8 | cols_static: ["cont_features"] 9 | cols_static_cat: ["cat_features"] 10 | -------------------------------------------------------------------------------- /src/data/memmap/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/ecg_utils.py: -------------------------------------------------------------------------------- 1 | import wfdb 2 | 3 | import os 4 | import shutil 5 | import zipfile 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | import resampy 11 | from tqdm.auto import tqdm 12 | from pathlib import Path 13 | 14 | 15 | import datetime 16 | 17 | from clinical_ts.timeseries_utils import * 18 | 19 | from sklearn.model_selection import StratifiedKFold 20 | 21 | channel_stoi_default = {"i": 0, "ii": 1, "v1":2, "v2":3, "v3":4, "v4":5, "v5":6, "v6":7, "iii":8, "avr":9, "avl":10, "avf":11, "vx":12, "vy":13, "vz":14} 22 | 23 | def get_stratified_kfolds(labels,n_splits,random_state): 24 | skf = StratifiedKFold(n_splits=n_splits,shuffle=True,random_state=random_state) 25 | return skf.split(np.zeros(len(labels)),labels) 26 | 27 | def resample_data(sigbufs, channel_labels, fs, target_fs, channels=12, channel_stoi=None):#,skimage_transform=True,interpolation_order=3): 28 | channel_labels = [c.lower() for c in channel_labels] 29 | #https://github.com/scipy/scipy/issues/7324 zoom issues 30 | factor = target_fs/fs 31 | timesteps_new = int(len(sigbufs)*factor) 32 | if(channel_stoi is not None): 33 | data = np.zeros((timesteps_new, channels), dtype=np.float32) 34 | for i,cl in enumerate(channel_labels): 35 | if(cl in channel_stoi.keys() and channel_stoi[cl]0 else tmp 47 | 48 | if(recreate_data): 49 | target_folder = Path(target_folder) 50 | target_folder.mkdir(parents=True, exist_ok=True) 51 | 52 | with zipfile.ZipFile(Path(data_path)/"mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0.zip", 'r') as 
archive: 53 | lst = archive.namelist() 54 | lst = [x for x in lst if x.endswith(".hea")] 55 | 56 | meta = [] 57 | for l in tqdm(lst): 58 | archive.extract(l, path="tmp_dir/") 59 | archive.extract(l[:-3]+"dat", path="tmp_dir/") 60 | filename = Path("tmp_dir")/l 61 | sigbufs, header = wfdb.rdsamp(str(filename)[:-4]) 62 | 63 | tmp={} 64 | tmp["data"]=filename.parent.parent.stem+"_"+filename.parent.stem+".npy" #patientid_study.npy 65 | tmp["study_id"]=int(filename.stem) 66 | tmp["subject_id"]=int(filename.parent.parent.stem[1:]) 67 | tmp['ecg_time']= datetime.datetime.combine(header["base_date"],header["base_time"]) 68 | tmp["nans"]= list(np.sum(np.isnan(sigbufs),axis=0))#save nans channel-dependent 69 | if(np.sum(tmp["nans"])>0):#fix nans 70 | fix_nans_and_clip(sigbufs,clip_amp=clip_amp) 71 | elif(clip_amp>0): 72 | sigbufs = np.clip(sigbufs,a_max=clip_amp,a_min=-clip_amp) 73 | 74 | data = resample_data(sigbufs=sigbufs,channel_stoi=channel_stoi,channel_labels=header['sig_name'],fs=header['fs'],target_fs=target_fs,channels=channels) 75 | 76 | assert(target_fs<=header['fs']) 77 | np.save(target_folder/tmp["data"],data) 78 | meta.append(tmp) 79 | 80 | os.unlink("tmp_dir/"+l) 81 | os.unlink("tmp_dir/"+l[:-3]+"dat") 82 | shutil.rmtree("tmp_dir") 83 | 84 | df = pd.DataFrame(meta) 85 | 86 | #random split by patients 87 | #unique_patients = np.unique(df.subject_id) 88 | #splits_patients = get_stratified_kfolds(np.zeros(len(unique_patients)),n_splits=strat_folds,random_state=42) 89 | #df["fold"]=-1 90 | #for i,split in enumerate(splits_patients): 91 | # df.loc[df.subject_id.isin(unique_patients[split[-1]]),"fold"]=i 92 | 93 | #add means and std 94 | dataset_add_mean_col(df,data_folder=target_folder) 95 | dataset_add_std_col(df,data_folder=target_folder) 96 | dataset_add_length_col(df,data_folder=target_folder) 97 | 98 | #save means and stds 99 | mean, std = dataset_get_stats(df) 100 | 101 | #save 102 | lbl_itos=[] 103 | save_dataset(df,lbl_itos,mean,std,target_folder) 104 | 
else: 105 | df, lbl_itos, mean, std = load_dataset(target_folder,df_mapped=False) 106 | return df, lbl_itos, mean, std 107 | 108 | -------------------------------------------------------------------------------- /src/environment.yml: -------------------------------------------------------------------------------- 1 | name: lightning2 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - brotli-python=1.0.9=py312h6a678d5_7 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2023.11.17=hbcca054_0 14 | - certifi=2023.11.17=pyhd8ed1ab_0 15 | - cffi=1.16.0=py312h5eee18b_0 16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 17 | - colorama=0.4.6=pyhd8ed1ab_0 18 | - cryptography=41.0.7=py312hdda0065_0 19 | - cuda-cudart=12.1.105=0 20 | - cuda-cupti=12.1.105=0 21 | - cuda-libraries=12.1.0=0 22 | - cuda-nvrtc=12.1.105=0 23 | - cuda-nvtx=12.1.105=0 24 | - cuda-opencl=12.3.101=0 25 | - cuda-runtime=12.1.0=0 26 | - expat=2.5.0=h6a678d5_0 27 | - ffmpeg=4.3=hf484d3e_0 28 | - filelock=3.13.1=py312h06a4308_0 29 | - freetype=2.12.1=h4a9f257_0 30 | - fsspec=2023.12.2=pyhca7485f_0 31 | - giflib=5.2.1=h5eee18b_3 32 | - gmp=6.2.1=h295c915_3 33 | - gnutls=3.6.15=he1e5248_0 34 | - idna=3.4=py312h06a4308_0 35 | - intel-openmp=2023.1.0=hdb19cb5_46306 36 | - jinja2=3.1.2=py312h06a4308_0 37 | - jpeg=9e=h5eee18b_1 38 | - lame=3.100=h7b6447c_0 39 | - lcms2=2.12=h3be6417_0 40 | - ld_impl_linux-64=2.38=h1181459_1 41 | - lerc=3.0=h295c915_0 42 | - libcublas=12.1.0.26=0 43 | - libcufft=11.0.2.4=0 44 | - libcufile=1.8.1.2=0 45 | - libcurand=10.3.4.107=0 46 | - libcusolver=11.4.4.55=0 47 | - libcusparse=12.0.2.55=0 48 | - libdeflate=1.17=h5eee18b_1 49 | - libffi=3.4.4=h6a678d5_0 50 | - libgcc-ng=11.2.0=h1234567_1 51 | - libgomp=11.2.0=h1234567_1 52 | - libiconv=1.16=h7f8727e_2 53 | - libidn2=2.3.4=h5eee18b_0 54 | - libjpeg-turbo=2.0.0=h9bf148f_0 55 | - libnpp=12.0.2.50=0 56 | - 
libnvjitlink=12.1.105=0 57 | - libnvjpeg=12.1.1.14=0 58 | - libpng=1.6.39=h5eee18b_0 59 | - libstdcxx-ng=11.2.0=h1234567_1 60 | - libtasn1=4.19.0=h5eee18b_0 61 | - libtiff=4.5.1=h6a678d5_0 62 | - libunistring=0.9.10=h27cfd23_0 63 | - libuuid=1.41.5=h5eee18b_0 64 | - libwebp=1.3.2=h11a3e52_0 65 | - libwebp-base=1.3.2=h5eee18b_0 66 | - lightning=2.1.3=pyhd8ed1ab_1 67 | - lightning-utilities=0.10.1=pyhd8ed1ab_0 68 | - llvm-openmp=14.0.6=h9e868ea_0 69 | - lz4-c=1.9.4=h6a678d5_0 70 | - markupsafe=2.1.1=py312h5eee18b_0 71 | - mkl=2023.1.0=h213fc3f_46344 72 | - mkl-service=2.4.0=py312h5eee18b_1 73 | - mkl_fft=1.3.8=py312h5eee18b_0 74 | - mkl_random=1.2.4=py312hdb19cb5_0 75 | - mpmath=1.3.0=py312h06a4308_0 76 | - ncurses=6.4=h6a678d5_0 77 | - nettle=3.7.3=hbbd107a_1 78 | - networkx=3.1=py312h06a4308_0 79 | - numpy=1.26.3=py312hc5e2394_0 80 | - numpy-base=1.26.3=py312h0da6c21_0 81 | - openh264=2.1.1=h4ff587b_0 82 | - openjpeg=2.4.0=h3ad879b_0 83 | - openssl=3.0.12=h7f8727e_0 84 | - packaging=23.2=pyhd8ed1ab_0 85 | - pillow=10.0.1=py312ha6cbd5a_0 86 | - pip=23.3.1=py312h06a4308_0 87 | - pycparser=2.21=pyhd3eb1b0_0 88 | - pyopenssl=23.2.0=py312h06a4308_0 89 | - pysocks=1.7.1=py312h06a4308_0 90 | - python=3.12.1=h996f2a0_0 91 | - pytorch=2.2.0=py3.12_cuda12.1_cudnn8.9.2_0 92 | - pytorch-cuda=12.1=ha16c6d3_5 93 | - pytorch-lightning=2.1.3=pyhd8ed1ab_0 94 | - pytorch-mutex=1.0=cuda 95 | - pyyaml=6.0.1=py312h5eee18b_0 96 | - readline=8.2=h5eee18b_0 97 | - requests=2.31.0=py312h06a4308_0 98 | - setuptools=68.2.2=py312h06a4308_0 99 | - sqlite=3.41.2=h5eee18b_0 100 | - sympy=1.12=py312h06a4308_0 101 | - tbb=2021.8.0=hdb19cb5_0 102 | - tk=8.6.12=h1ccaba5_0 103 | - torchaudio=2.2.0=py312_cu121 104 | - torchmetrics=1.2.1=pyhd8ed1ab_0 105 | - torchvision=0.17.0=py312_cu121 106 | - tqdm=4.66.1=pyhd8ed1ab_0 107 | - typing-extensions=4.9.0=py312h06a4308_1 108 | - typing_extensions=4.9.0=py312h06a4308_1 109 | - urllib3=1.26.18=py312h06a4308_0 110 | - wheel=0.41.2=py312h06a4308_0 111 | - 
xz=5.4.5=h5eee18b_0 112 | - yaml=0.2.5=h7b6447c_0 113 | - zlib=1.2.13=h5eee18b_0 114 | - zstd=1.5.5=hc292b87_0 115 | - pip: 116 | - absl-py==2.1.0 117 | - alembic==1.13.1 118 | - aniso8601==9.0.1 119 | - antlr4-python3-runtime==4.9.3 120 | - anyio==4.3.0 121 | - argon2-cffi==23.1.0 122 | - argon2-cffi-bindings==21.2.0 123 | - arrow==1.3.0 124 | - asttokens==2.4.1 125 | - async-lru==2.0.4 126 | - attrs==23.2.0 127 | - babel==2.14.0 128 | - beautifulsoup4==4.12.3 129 | - bleach==6.1.0 130 | - blinker==1.7.0 131 | - cachetools==5.3.2 132 | - catboost==1.2.5 133 | - causal-conv1d==1.2.0.post2 134 | - click==8.1.7 135 | - cloudpickle==3.0.0 136 | - comm==0.2.1 137 | - contourpy==1.2.0 138 | - cycler==0.12.1 139 | - debugpy==1.8.1 140 | - decorator==5.1.1 141 | - defusedxml==0.7.1 142 | - docker==7.0.0 143 | - einops==0.7.0 144 | - entrypoints==0.4 145 | - et-xmlfile==1.1.0 146 | - executing==2.0.1 147 | - fastjsonschema==2.19.1 148 | - flask==3.0.2 149 | - fonttools==4.49.0 150 | - fqdn==1.5.1 151 | - ftfy==6.2.0 152 | - gitdb==4.0.11 153 | - gitpython==3.1.42 154 | - google-auth==2.27.0 155 | - google-auth-oauthlib==1.2.0 156 | - graphene==3.3 157 | - graphql-core==3.2.3 158 | - graphql-relay==3.2.0 159 | - greenlet==3.0.3 160 | - grpcio==1.60.1 161 | - gunicorn==21.2.0 162 | - h11==0.14.0 163 | - httpcore==1.0.4 164 | - httpx==0.27.0 165 | - huggingface-hub==0.22.1 166 | - hydra-core==1.3.2 167 | - imageio==2.33.1 168 | - importlib-metadata==7.0.2 169 | - ipykernel==6.29.3 170 | - ipython==8.22.1 171 | - ipywidgets==8.1.5 172 | - isoduration==20.11.0 173 | - itsdangerous==2.1.2 174 | - jedi==0.19.1 175 | - joblib==1.3.2 176 | - json5==0.9.17 177 | - jsonpointer==2.4 178 | - jsonschema==4.21.1 179 | - jsonschema-specifications==2023.12.1 180 | - jupyter-client==8.6.0 181 | - jupyter-core==5.7.1 182 | - jupyter-events==0.9.0 183 | - jupyter-lsp==2.2.3 184 | - jupyter-server==2.12.5 185 | - jupyter-server-terminals==0.5.2 186 | - jupyterlab==4.1.2 187 | - 
jupyterlab-pygments==0.3.0 188 | - jupyterlab-server==2.25.3 189 | - jupyterlab-widgets==3.0.13 190 | - keopscore==2.2.1 191 | - kiwisolver==1.4.5 192 | - lazy-loader==0.3 193 | - llvmlite==0.43.0 194 | - mako==1.3.2 195 | - mamba-ssm==1.2.0.post1 196 | - markdown==3.5.2 197 | - matplotlib==3.8.3 198 | - matplotlib-inline==0.1.6 199 | - mistune==3.0.2 200 | - mlflow==2.11.1 201 | - nbclient==0.9.0 202 | - nbconvert==7.16.1 203 | - nbformat==5.9.2 204 | - nest-asyncio==1.6.0 205 | - ninja==1.11.1.1 206 | - notebook==7.1.1 207 | - notebook-shim==0.2.4 208 | - numba==0.60.0 209 | - oauthlib==3.2.2 210 | - omegaconf==2.3.0 211 | - openpyxl==3.1.5 212 | - opt-einsum==3.3.0 213 | - overrides==7.7.0 214 | - pandas==2.2.0 215 | - pandocfilters==1.5.1 216 | - parso==0.8.3 217 | - pexpect==4.9.0 218 | - platformdirs==4.2.0 219 | - plotly==5.24.0 220 | - prometheus-client==0.20.0 221 | - prompt-toolkit==3.0.43 222 | - protobuf==4.23.4 223 | - psutil==5.9.8 224 | - ptyprocess==0.7.0 225 | - pure-eval==0.2.2 226 | - pyarrow==15.0.1 227 | - pyasn1==0.5.1 228 | - pyasn1-modules==0.3.0 229 | - pybind11==2.11.1 230 | - pygments==2.17.2 231 | - pykeops==2.2.1 232 | - pyparsing==3.1.2 233 | - python-dateutil==2.8.2 234 | - python-graphviz==0.20.3 235 | - python-json-logger==2.0.7 236 | - pytz==2024.1 237 | - pyzmq==25.1.2 238 | - querystring-parser==1.2.4 239 | - referencing==0.33.0 240 | - regex==2023.12.25 241 | - requests-oauthlib==1.3.1 242 | - resampy==0.4.3 243 | - rfc3339-validator==0.1.4 244 | - rfc3986-validator==0.1.1 245 | - rpds-py==0.18.0 246 | - rsa==4.9 247 | - safetensors==0.4.2 248 | - scikit-image==0.22.0 249 | - scikit-learn==1.4.0 250 | - scipy==1.12.0 251 | - send2trash==1.8.2 252 | - shap==0.46.0 253 | - six==1.16.0 254 | - slicer==0.0.8 255 | - smmap==5.0.1 256 | - sniffio==1.3.1 257 | - soupsieve==2.5 258 | - sqlalchemy==2.0.28 259 | - sqlparse==0.4.4 260 | - stack-data==0.6.3 261 | - structured-kernels==0.1.0 262 | - tenacity==9.0.0 263 | - 
tensorboard==2.15.1 264 | - tensorboard-data-server==0.7.2 265 | - terminado==0.18.0 266 | - threadpoolctl==3.2.0 267 | - tifffile==2024.1.30 268 | - timm==1.0.7 269 | - tinycss2==1.2.1 270 | - tokenizers==0.15.2 271 | - tornado==6.4 272 | - traitlets==5.14.1 273 | - transformers==4.39.1 274 | - triton==2.2.0 275 | - types-python-dateutil==2.8.19.20240106 276 | - tzdata==2023.4 277 | - uri-template==1.3.0 278 | - wcwidth==0.2.13 279 | - webcolors==1.13 280 | - webencodings==0.5.1 281 | - websocket-client==1.7.0 282 | - werkzeug==3.0.1 283 | - widgetsnbextension==4.0.13 284 | - zipp==3.17.0 285 | prefix: /user/jael1674/anaconda3/envs/lightning2 286 | -------------------------------------------------------------------------------- /src/extensions/cauchy/a.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/extensions/cauchy/benchmark_cauchy.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | 6 | from einops import rearrange 7 | 8 | from .cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult 9 | from benchmark.utils import benchmark_all, benchmark_combined, benchmark_forward, benchmark_backward 10 | 11 | 12 | def generate_data(batch_size, N, L, symmetric=True, device='cuda'): 13 | if not symmetric: 14 | v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True) 15 | w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True) 16 | z = torch.randn(L, dtype=torch.complex64, device=device) 17 | else: 18 | assert N % 2 == 0 19 | v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device) 20 | v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True) 21 | w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device) 22 | w = 
torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True) 23 | z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device)) 24 | return v, z, w 25 | 26 | 27 | if __name__ == '__main__': 28 | device = 'cuda' 29 | bs = 1024 30 | N = 64 31 | L = 16384 32 | 33 | v, z, w = generate_data(bs, N, L, symmetric=True) 34 | v_half = v[:, :N // 2].clone().detach().requires_grad_(True) 35 | w_half = w[:, :N // 2].clone().detach().requires_grad_(True) 36 | 37 | repeat = 30 38 | benchmark_all(repeat, cauchy_mult_keops, v, z, w, desc='Cauchy mult keops') 39 | fn = partial(cauchy_mult, symmetric=False) 40 | benchmark_all(repeat, fn, v, z, w, desc='Cauchy mult') 41 | fn = partial(cauchy_mult, symmetric=True) 42 | benchmark_all(repeat, fn, v_half, z, w_half, desc='Cauchy mult symmetric') 43 | -------------------------------------------------------------------------------- /src/extensions/cauchy/benchmark_cauchy_tune.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import json 3 | import argparse 4 | 5 | import torch 6 | 7 | from benchmark.utils import benchmark_forward 8 | 9 | 10 | def generate_data(batch_size, N, L, symmetric=True, device='cuda'): 11 | if not symmetric: 12 | v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True) 13 | w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True) 14 | z = torch.randn(L, dtype=torch.complex64, device=device) 15 | else: 16 | assert N % 2 == 0 17 | v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device) 18 | v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True) 19 | w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device) 20 | w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True) 21 | z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device)) 22 | return v, z, w 23 | 24 | 25 | parser = 
argparse.ArgumentParser(description='Tuning Cauchy multiply') 26 | parser.add_argument('--name', default='cauchy_mult') 27 | parser.add_argument('--mode', default='forward', choices=['forward', 'backward']) 28 | parser.add_argument('-bs', '--batch-size', default=1024, type=int) 29 | parser.add_argument('-N', default=64, type=int) 30 | parser.add_argument('-L', default=2 ** 14, type=int) 31 | 32 | 33 | if __name__ == '__main__': 34 | args = parser.parse_args() 35 | device = 'cuda' 36 | bs = args.batch_size 37 | N = args.N 38 | L = args.L 39 | repeat = 30 40 | v, z, w = generate_data(bs, N, L, symmetric=True) 41 | v_half = v[:, :N // 2].clone().detach().requires_grad_(True) 42 | w_half = w[:, :N // 2].clone().detach().requires_grad_(True) 43 | 44 | tuning_extension_name = args.name 45 | # print('Extension name:', tuning_extension_name) 46 | module = importlib.import_module(tuning_extension_name) 47 | if args.mode == 'forward': 48 | _, m = benchmark_forward(repeat, module.cauchy_mult_sym_fwd, v_half, z, w_half, 49 | verbose=False, desc='Cauchy mult symmetric fwd') 50 | else: 51 | out = module.cauchy_mult_sym_fwd(v_half, z, w_half) 52 | dout = torch.randn_like(out) 53 | _, m = benchmark_forward(repeat, module.cauchy_mult_sym_bwd, v_half, z, w_half, dout, 54 | verbose=False, desc='Cauchy mult symmetric bwd') 55 | result_dict = dict(time_mean = m.mean, time_iqr = m.iqr) 56 | print(json.dumps(result_dict)) 57 | -------------------------------------------------------------------------------- /src/extensions/cauchy/cauchy.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") 7 | #define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 8 | 9 | torch::Tensor cauchy_mult_fwd_cuda(torch::Tensor v, 10 | torch::Tensor z, 11 | torch::Tensor w); 12 | std::tuple cauchy_mult_bwd_cuda(torch::Tensor v, 13 | torch::Tensor z, 14 | torch::Tensor w, 15 | torch::Tensor dout); 16 | torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v, 17 | torch::Tensor z, 18 | torch::Tensor w); 19 | std::tuple cauchy_mult_sym_bwd_cuda(torch::Tensor v, 20 | torch::Tensor z, 21 | torch::Tensor w, 22 | torch::Tensor dout); 23 | 24 | namespace cauchy { 25 | 26 | torch::Tensor cauchy_mult_fwd(torch::Tensor v, 27 | torch::Tensor z, 28 | torch::Tensor w) { 29 | CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); 30 | const auto batch_size = v.size(0); 31 | const auto N = v.size(1); 32 | const auto L = z.size(0); 33 | CHECK_SHAPE(v, batch_size, N); 34 | CHECK_SHAPE(z, L); 35 | CHECK_SHAPE(w, batch_size, N); 36 | return cauchy_mult_fwd_cuda(v, z, w); 37 | } 38 | 39 | std::tuple 40 | cauchy_mult_bwd(torch::Tensor v, 41 | torch::Tensor z, 42 | torch::Tensor w, 43 | torch::Tensor dout) { 44 | CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); CHECK_DEVICE(dout); 45 | const auto batch_size = v.size(0); 46 | const auto N = v.size(1); 47 | const auto L = z.size(0); 48 | CHECK_SHAPE(v, batch_size, N); 49 | CHECK_SHAPE(z, L); 50 | CHECK_SHAPE(w, batch_size, N); 51 | CHECK_SHAPE(dout, batch_size, L); 52 | return cauchy_mult_bwd_cuda(v, z, w, dout); 53 | } 54 | 55 | torch::Tensor cauchy_mult_sym_fwd(torch::Tensor v, 56 | torch::Tensor z, 57 | torch::Tensor w) { 58 | CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); 59 | const auto batch_size = v.size(0); 60 | const auto N = v.size(1); 61 | const auto L = z.size(0); 62 | CHECK_SHAPE(v, batch_size, N); 63 | CHECK_SHAPE(z, L); 64 | CHECK_SHAPE(w, batch_size, N); 65 | return cauchy_mult_sym_fwd_cuda(v, z, w); 66 | } 67 | 68 | std::tuple 69 | cauchy_mult_sym_bwd(torch::Tensor v, 70 | torch::Tensor z, 
71 | torch::Tensor w, 72 | torch::Tensor dout) { 73 | CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); CHECK_DEVICE(dout); 74 | const auto batch_size = v.size(0); 75 | const auto N = v.size(1); 76 | const auto L = z.size(0); 77 | CHECK_SHAPE(v, batch_size, N); 78 | CHECK_SHAPE(z, L); 79 | CHECK_SHAPE(w, batch_size, N); 80 | CHECK_SHAPE(dout, batch_size, L); 81 | return cauchy_mult_sym_bwd_cuda(v, z, w, dout); 82 | } 83 | 84 | } // cauchy 85 | 86 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 87 | m.def("cauchy_mult_fwd", &cauchy::cauchy_mult_fwd, 88 | "Cauchy multiply forward"); 89 | m.def("cauchy_mult_bwd", &cauchy::cauchy_mult_bwd, 90 | "Cauchy multiply backward"); 91 | m.def("cauchy_mult_sym_fwd", &cauchy::cauchy_mult_sym_fwd, 92 | "Cauchy multiply symmetric forward"); 93 | m.def("cauchy_mult_sym_bwd", &cauchy::cauchy_mult_sym_bwd, 94 | "Cauchy multiply symmetric backward"); 95 | } 96 | -------------------------------------------------------------------------------- /src/extensions/cauchy/cauchy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from einops import rearrange 4 | 5 | from cauchy_mult import cauchy_mult_fwd, cauchy_mult_bwd, cauchy_mult_sym_fwd, cauchy_mult_sym_bwd 6 | 7 | 8 | def cauchy_mult_torch(v: torch.Tensor, z: torch.Tensor, w: torch.Tensor, 9 | symmetric=True) -> torch.Tensor: 10 | """ 11 | v: (B, N) 12 | z: (L) 13 | w: (B, N) 14 | symmetric: whether to assume that v and w contain complex conjugate pairs, of the form 15 | [v_half, v_half.conj()] and [w_half, w_half.conj()] 16 | """ 17 | if not symmetric: 18 | return (rearrange(v, 'b n -> b 1 n') / (rearrange(z, 'l -> l 1') - rearrange(w, 'b n -> b 1 n'))).sum(dim=-1) 19 | else: 20 | N = v.shape[-1] 21 | assert N % 2 == 0 22 | vv = rearrange(v[:, :N // 2], 'b n -> b 1 n') 23 | zz = rearrange(z, 'l -> l 1') 24 | ww = rearrange(w[:, :N // 2], 'b n -> b 1 n') 25 | return 2 * ((zz * vv.real - vv.real * ww.real - vv.imag * ww.imag) 26 | / 
(zz * zz - 2 * zz * ww.real + ww.abs().square())).sum(dim=-1) 27 | 28 | 29 | def cauchy_mult_keops(v, z, w): 30 | from pykeops.torch import LazyTensor 31 | v_l = LazyTensor(rearrange(v, 'b N -> b 1 N 1')) 32 | z_l = LazyTensor(rearrange(z, 'L -> 1 L 1 1')) 33 | w_l = LazyTensor(rearrange(w, 'b N -> b 1 N 1')) 34 | sub = z_l - w_l # (b N L 1), for some reason it doesn't display the last dimension 35 | div = v_l / sub 36 | s = div.sum(dim=2, backend='GPU') 37 | return s.squeeze(-1) 38 | 39 | 40 | def _cauchy_mult(v, z, w, symmetric=True): 41 | if not symmetric: 42 | return CauchyMultiply.apply(v, z, w) 43 | else: 44 | return CauchyMultiplySymmetric.apply(v, z, w) 45 | 46 | def cauchy_mult(v, z, w, symmetric=True): 47 | """ Wrap the cuda method to deal with shapes """ 48 | v, w = torch.broadcast_tensors(v, w) 49 | shape = v.shape 50 | # z_shape = z.shape 51 | z = z.squeeze() 52 | assert len(z.shape) == 1 53 | 54 | v = v.contiguous() 55 | w = w.contiguous() 56 | z = z.contiguous() 57 | 58 | N = v.size(-1) 59 | assert w.size(-1) == N 60 | y = _cauchy_mult(v.view(-1, N), z, w.view(-1, N), symmetric=symmetric) 61 | y = y.view(*shape[:-1], z.size(-1)) 62 | # y = z.new_zeros(*shape[:-1], z.size(-1)) 63 | return y 64 | 65 | 66 | class CauchyMultiply(torch.autograd.Function): 67 | 68 | @staticmethod 69 | def forward(ctx, v, z, w): 70 | batch, N = v.shape 71 | # supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] 72 | supported_N_values = [1 << log_n for log_n in [6]] 73 | L = z.shape[-1] 74 | if not N in supported_N_values: 75 | raise NotImplementedError(f'Only support N values in {supported_N_values}') 76 | if L % 32 != 0: 77 | raise NotImplementedError(f'Only support L values that are multiples of 32') 78 | if not v.is_cuda and z.is_cuda and w.is_cuda: 79 | raise NotImplementedError(f'Only support CUDA tensors') 80 | ctx.save_for_backward(v, z, w) 81 | return cauchy_mult_fwd(v, z, w) 82 | 83 | @staticmethod 84 | def backward(ctx, dout): 85 | v, z, 
w = ctx.saved_tensors 86 | dv, dw = cauchy_mult_bwd(v, z, w, dout) 87 | return dv, None, dw 88 | 89 | 90 | class CauchyMultiplySymmetric(torch.autograd.Function): 91 | 92 | @staticmethod 93 | def forward(ctx, v, z, w): 94 | batch, N = v.shape 95 | supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] 96 | L = z.shape[-1] 97 | if not N in supported_N_values: 98 | raise NotImplementedError(f'Only support N values in {supported_N_values}') 99 | max_L_value = 32 * 1024 * 64 * 1024 100 | if L > max_L_value: 101 | raise NotImplementedError(f'Only support L values <= {max_L_value}') 102 | if not v.is_cuda and z.is_cuda and w.is_cuda: 103 | raise NotImplementedError(f'Only support CUDA tensors') 104 | ctx.save_for_backward(v, z, w) 105 | return cauchy_mult_sym_fwd(v, z, w) 106 | 107 | @staticmethod 108 | def backward(ctx, dout): 109 | v, z, w = ctx.saved_tensors 110 | dv, dw = cauchy_mult_sym_bwd(v, z, w, dout) 111 | return dv, None, dw 112 | -------------------------------------------------------------------------------- /src/extensions/cauchy/map.h: -------------------------------------------------------------------------------- 1 | // Downloaded from https://github.com/swansontec/map-macro 2 | 3 | /* 4 | * Copyright (C) 2012 William Swanson 5 | * 6 | * Permission is hereby granted, free of charge, to any person 7 | * obtaining a copy of this software and associated documentation 8 | * files (the "Software"), to deal in the Software without 9 | * restriction, including without limitation the rights to use, copy, 10 | * modify, merge, publish, distribute, sublicense, and/or sell copies 11 | * of the Software, and to permit persons to whom the Software is 12 | * furnished to do so, subject to the following conditions: 13 | * 14 | * The above copyright notice and this permission notice shall be 15 | * included in all copies or substantial portions of the Software. 
16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY 21 | * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 22 | * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 23 | * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | * 25 | * Except as contained in this notice, the names of the authors or 26 | * their institutions shall not be used in advertising or otherwise to 27 | * promote the sale, use or other dealings in this Software without 28 | * prior written authorization from the authors. 29 | */ 30 | 31 | #ifndef MAP_H_INCLUDED 32 | #define MAP_H_INCLUDED 33 | 34 | #define EVAL0(...) __VA_ARGS__ 35 | #define EVAL1(...) EVAL0(EVAL0(EVAL0(__VA_ARGS__))) 36 | #define EVAL2(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) 37 | #define EVAL3(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__))) 38 | #define EVAL4(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__))) 39 | #define EVAL(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__))) 40 | 41 | #define MAP_END(...) 42 | #define MAP_OUT 43 | #define MAP_COMMA , 44 | 45 | #define MAP_GET_END2() 0, MAP_END 46 | #define MAP_GET_END1(...) MAP_GET_END2 47 | #define MAP_GET_END(...) MAP_GET_END1 48 | #define MAP_NEXT0(test, next, ...) next MAP_OUT 49 | #define MAP_NEXT1(test, next) MAP_NEXT0(test, next, 0) 50 | #define MAP_NEXT(test, next) MAP_NEXT1(MAP_GET_END test, next) 51 | 52 | #define MAP0(f, x, peek, ...) f(x) MAP_NEXT(peek, MAP1)(f, peek, __VA_ARGS__) 53 | #define MAP1(f, x, peek, ...) f(x) MAP_NEXT(peek, MAP0)(f, peek, __VA_ARGS__) 54 | 55 | #define MAP_LIST_NEXT1(test, next) MAP_NEXT0(test, MAP_COMMA next, 0) 56 | #define MAP_LIST_NEXT(test, next) MAP_LIST_NEXT1(MAP_GET_END test, next) 57 | 58 | #define MAP_LIST0(f, x, peek, ...) 
f(x) MAP_LIST_NEXT(peek, MAP_LIST1)(f, peek, __VA_ARGS__) 59 | #define MAP_LIST1(f, x, peek, ...) f(x) MAP_LIST_NEXT(peek, MAP_LIST0)(f, peek, __VA_ARGS__) 60 | 61 | /** 62 | * Applies the function macro `f` to each of the remaining parameters. 63 | */ 64 | #define MAP(f, ...) EVAL(MAP1(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) 65 | 66 | /** 67 | * Applies the function macro `f` to each of the remaining parameters and 68 | * inserts commas between the results. 69 | */ 70 | #define MAP_LIST(f, ...) EVAL(MAP_LIST1(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) 71 | 72 | #endif 73 | -------------------------------------------------------------------------------- /src/extensions/cauchy/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import torch.cuda 3 | from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension 4 | from torch.utils.cpp_extension import CUDA_HOME 5 | 6 | ext_modules = [] 7 | if torch.cuda.is_available() and CUDA_HOME is not None: 8 | extension = CUDAExtension( 9 | 'cauchy_mult', [ 10 | 'cauchy.cpp', 11 | 'cauchy_cuda.cu', 12 | ], 13 | extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], 14 | # 'nvcc': ['-O2', '-lineinfo'] 15 | 'nvcc': ['-O2', '-lineinfo', '--use_fast_math'] 16 | } 17 | ) 18 | ext_modules.append(extension) 19 | 20 | setup( 21 | name='cauchy_mult', 22 | ext_modules=ext_modules, 23 | # cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)}) 24 | cmdclass={'build_ext': BuildExtension}) 25 | -------------------------------------------------------------------------------- /src/extensions/cauchy/test_cauchy.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | import pytest 5 | 6 | from einops import rearrange 7 | 8 | from cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult 9 | 10 | 11 | def generate_data(batch_size, N, L, 
symmetric=True, device='cuda'): 12 | if not symmetric: 13 | v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True) 14 | w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True) 15 | z = torch.randn(L, dtype=torch.complex64, device=device) 16 | else: 17 | assert N % 2 == 0 18 | v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device) 19 | v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True) 20 | w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device) 21 | w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True) 22 | z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device)) 23 | return v, z, w 24 | 25 | 26 | def grad_to_half_grad(dx): 27 | dx_half, dx_half_conj = dx.chunk(2, dim=-1) 28 | return dx_half + dx_half_conj.conj() 29 | 30 | 31 | # @pytest.mark.parametrize('L', [1024]) 32 | # @pytest.mark.parametrize('N', [64]) 33 | # def test_cauchy_mult_nonsymmetric(N, L): 34 | # device = 'cuda' 35 | # batch_size = 4 36 | # torch.random.manual_seed(2357) 37 | # v, z, w = generate_data(batch_size, N, L, symmetric=False, device=device) 38 | # out_torch = cauchy_mult_torch(v, z, w, symmetric=False) 39 | # out_keops = cauchy_mult_keops(v, z, w) 40 | # out = cauchy_mult(v, z, w, symmetric=False) 41 | # assert torch.allclose(out, out_torch, rtol=1e-4, atol=1e-4) 42 | # assert torch.allclose(out, out_keops, rtol=1e-4, atol=1e-4) 43 | # dout = torch.randn_like(out) 44 | # dv_torch, dw_torch = torch.autograd.grad(out_torch, (v, w), dout, retain_graph=True) 45 | # dv_keops, dw_keops = torch.autograd.grad(out_keops, (v, w), dout, retain_graph=True) 46 | # dv, dw = torch.autograd.grad(out, (v, w), dout, retain_graph=True) 47 | # assert torch.allclose(dv, dv_torch, rtol=1e-4, atol=1e-4) 48 | # assert torch.allclose(dv, dv_keops, rtol=1e-4, atol=1e-4) 49 | # assert torch.allclose(dw, dw_torch, rtol=1e-4, atol=1e-4) 50 | # assert 
torch.allclose(dw, dw_keops, rtol=1e-4, atol=1e-4) 51 | 52 | 53 | @pytest.mark.parametrize('L', [3, 17, 489, 2**10, 1047, 2**11, 2**12, 2**13, 2**14, 2**18]) 54 | @pytest.mark.parametrize('N', [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]) 55 | def test_cauchy_mult_symmetric(N, L): 56 | # rtol, atol = (1e-4, 1e-4) if N <= 64 and L <= 1024 else(1e-3, 1e-3) 57 | atol = 1e-4 58 | tol_factor = 10.0 # Our error shouldn't be this much higher than Keops' error 59 | device = 'cuda' 60 | batch_size = 4 61 | torch.random.manual_seed(2357) 62 | v, z, w = generate_data(batch_size, N, L, symmetric=True, device=device) 63 | v_half = v[:, :N // 2].clone().detach().requires_grad_(True) 64 | w_half = w[:, :N // 2].clone().detach().requires_grad_(True) 65 | # out_torch = cauchy_mult_torch(v, z, w, symmetric=True) 66 | out_torch = cauchy_mult_torch(v.cdouble(), z.cdouble(), w.cdouble(), symmetric=True).cfloat() 67 | out_keops = cauchy_mult_keops(v, z, w) 68 | out = cauchy_mult(v_half, z, w_half, symmetric=True) 69 | relerr_out_keops = (out_keops - out_torch).abs() / out_torch.abs() 70 | relerr_out = (out - out_torch).abs() / out_torch.abs() 71 | 72 | dout = torch.randn_like(out) 73 | dv_torch, dw_torch = torch.autograd.grad(out_torch, (v, w), dout, retain_graph=True) 74 | dv_torch, dw_torch = dv_torch[:, :N // 2], dw_torch[:, :N // 2] 75 | dv_keops, dw_keops = torch.autograd.grad(out_keops, (v, w), dout, retain_graph=True) 76 | dv_keops, dw_keops = grad_to_half_grad(dv_keops), grad_to_half_grad(dw_keops) 77 | dv, dw = torch.autograd.grad(out, (v_half, w_half), dout, retain_graph=True) 78 | relerr_dv_keops = (dv_keops - dv_torch).abs() / dv_torch.abs() 79 | relerr_dv = (dv - dv_torch).abs() / dv_torch.abs() 80 | relerr_dw_keops = (dw_keops - dw_torch).abs() / dw_torch.abs() 81 | relerr_dw = (dw - dw_torch).abs() / dw_torch.abs() 82 | print(f'Keops out relative error: max {relerr_out_keops.amax().item():.6f}, mean {relerr_out_keops.mean().item():6f}') 83 | print(f'out relative error: 
max {relerr_out.amax().item():.6f}, mean {relerr_out.mean().item():.6f}') 84 | print(f'Keops dv relative error: max {relerr_dv_keops.amax().item():.6f}, mean {relerr_dv_keops.mean().item():6f}') 85 | print(f'dv relative error: max {relerr_dv.amax().item():.6f}, mean {relerr_dv.mean().item():.6f}') 86 | print(f'Keops dw relative error: max {relerr_dw_keops.amax().item():.6f}, mean {relerr_dw_keops.mean().item():6f}') 87 | print(f'dw relative error: max {relerr_dw.amax().item():.6f}, mean {relerr_dw.mean().item():.6f}') 88 | assert (relerr_out.amax() <= relerr_out_keops.amax() * tol_factor + atol) 89 | assert (relerr_out.mean() <= relerr_out_keops.mean() * tol_factor + atol) 90 | # assert torch.allclose(out, out_torch, rtol=rtol, atol=atol) 91 | # assert torch.allclose(out, out_keops, rtol=rtol, atol=atol) 92 | assert (relerr_dv.amax() <= relerr_dv_keops.amax() * tol_factor + atol) 93 | assert (relerr_dv.mean() <= relerr_dv_keops.mean() * tol_factor + atol) 94 | assert (relerr_dw.amax() <= relerr_dw_keops.amax() * tol_factor + atol) 95 | assert (relerr_dw.mean() <= relerr_dw_keops.mean() * tol_factor + atol) 96 | # assert torch.allclose(dv, dv_torch, rtol=1e-4, atol=1e-4) 97 | # assert torch.allclose(dv, dv_keops, rtol=1e-4, atol=1e-4) 98 | # assert torch.allclose(dw, dw_torch, rtol=1e-4, atol=1e-4) 99 | # assert torch.allclose(dw, dw_keops, rtol=1e-4, atol=1e-4) 100 | 101 | -------------------------------------------------------------------------------- /src/extensions/cauchy/tune_cauchy.py: -------------------------------------------------------------------------------- 1 | import math 2 | import json 3 | import argparse 4 | import itertools 5 | from pathlib import Path 6 | 7 | from tuner import KernelTuner 8 | 9 | 10 | def forward_params_list(N): 11 | blocksize_params = ('MAX_BLOCK_SIZE_VALUE', [64, 128, 256, 512, 1024]) 12 | thread_value_default = [2, 4, 8, 16, 32, 32, 32, 32, 32, 32] 13 | thread_values_supported = [2, 4, 8, 16, 32, 64, 128] 14 | log_N_half = 
int(math.log2(N)) - 1 15 | thread_values = [] 16 | for val in thread_values_supported: 17 | if val <= N // 2: 18 | array = list(thread_value_default) 19 | array[log_N_half - 1] = val 20 | thread_values.append('{' + ', '.join(str(v) for v in array) + '}') 21 | thread_params = ('ITEMS_PER_THREAD_SYM_FWD_VALUES', thread_values) 22 | value_prod = itertools.product(thread_params[1], blocksize_params[1]) 23 | params_list = [{thread_params[0]: value[0], blocksize_params[0]: value[1]} 24 | for value in value_prod] 25 | return params_list 26 | 27 | 28 | def backward_params_list(L): 29 | thread_value_supported = [8, 16, 32, 64, 128] 30 | thread_params = ('ITEMS_PER_THREAD_SYM_BWD_VALUE', [v for v in thread_value_supported 31 | if (L + v - 1) // v <= 1024]) 32 | params_list = [{thread_params[0]: value} for value in thread_params[1]] 33 | return params_list 34 | 35 | 36 | parser = argparse.ArgumentParser(description='Tuning Cauchy multiply') 37 | parser.add_argument('--mode', default='forward', choices=['forward', 'backward']) 38 | parser.add_argument('-N', default=64, type=int) 39 | parser.add_argument('-L', default=2 ** 14, type=int) 40 | parser.add_argument('--filename', default='tuning_result.json') 41 | 42 | 43 | if __name__ == '__main__': 44 | args = parser.parse_args() 45 | 46 | extension_dir = Path(__file__).absolute().parent 47 | source_files = ['cauchy_cuda.cu'] 48 | if args.mode == 'forward': 49 | params_list = forward_params_list(args.N) 50 | tuner = KernelTuner(extension_dir, source_files, params_list, 51 | benchmark_script='benchmark_cauchy_tune.py', 52 | benchmark_args=['--mode', 'forward', '-N', str(args.N), '-L', '16384'], 53 | npool=16) 54 | else: 55 | params_list = backward_params_list(args.L) 56 | tuner = KernelTuner(extension_dir, source_files, params_list, 57 | benchmark_script='benchmark_cauchy_tune.py', 58 | benchmark_args=['--mode', 'backward', '-N', '64', '-L', str(args.L)], 59 | npool=16) 60 | 61 | result = tuner.tune() 62 | with open(args.filename, 
'w') as f: 63 | json.dump(result, f) 64 | -------------------------------------------------------------------------------- /src/extensions/cauchy/tune_cauchy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python tune_cauchy.py --mode forward -N 4 --filename tuning_result_fwd_N_4.json 3 | python tune_cauchy.py --mode forward -N 8 --filename tuning_result_fwd_N_8.json 4 | python tune_cauchy.py --mode forward -N 16 --filename tuning_result_fwd_N_16.json 5 | python tune_cauchy.py --mode forward -N 32 --filename tuning_result_fwd_N_32.json 6 | python tune_cauchy.py --mode forward -N 64 --filename tuning_result_fwd_N_64.json 7 | python tune_cauchy.py --mode forward -N 128 --filename tuning_result_fwd_N_128.json 8 | python tune_cauchy.py --mode forward -N 256 --filename tuning_result_fwd_N_256.json 9 | python tune_cauchy.py --mode forward -N 512 --filename tuning_result_fwd_N_512.json 10 | 11 | python tune_cauchy.py --mode backward -L 1024 --filename tuning_result_bwd_L_1k.json 12 | python tune_cauchy.py --mode backward -L 2048 --filename tuning_result_bwd_L_2k.json 13 | python tune_cauchy.py --mode backward -L 4096 --filename tuning_result_bwd_L_4k.json 14 | python tune_cauchy.py --mode backward -L 8192 --filename tuning_result_bwd_L_8k.json 15 | python tune_cauchy.py --mode backward -L 16384 --filename tuning_result_bwd_L_16k.json 16 | -------------------------------------------------------------------------------- /src/extensions/cauchy/tuner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import sys 5 | # import tempfile 6 | # import importlib 7 | import random 8 | import string 9 | import json 10 | 11 | 12 | from functools import partial 13 | from multiprocessing import Pipe, Pool, Process 14 | from pathlib import Path 15 | 16 | from tqdm import tqdm 17 | 18 | import numpy as np 19 | 20 | 21 | def 
read_file(filename): 22 | """ return the contents of the file named filename or None if file not found """ 23 | if os.path.isfile(filename): 24 | with open(filename, 'r') as f: 25 | return f.read() 26 | 27 | 28 | def write_file(filename, string): 29 | """dump the contents of string to a file called filename""" 30 | with open(filename, 'w', encoding="utf-8") as f: 31 | f.write(string) 32 | 33 | 34 | def prepare_kernel_string(kernel_string, params): 35 | for k, v in params.items(): 36 | kernel_string = "#define " + k + " " + str(v) + "\n" + kernel_string 37 | return kernel_string 38 | 39 | 40 | def compile_extension(temp_dir, install=False, verbose=True): 41 | # Need to copy this process's environments, otherwise it can't find the compilers 42 | env = {**os.environ, 43 | 'TUNING_SOURCE_DIR': str(temp_dir), 44 | 'TUNING_EXTENSION_NAME': str(temp_dir.stem)} 45 | # https://stackoverflow.com/questions/53173314/how-to-change-distutils-output-directory 46 | # Need separate build directories for parallel compilation 47 | output = subprocess.run( 48 | # [sys.executable, "tuning_setup.py", 'build', f'--build-base={str(temp_dir)}', 49 | # f'--build-lib={str(temp_dir)}'], 50 | [sys.executable, "tuning_setup.py", 'build' if not install else 'develop'], 51 | cwd=temp_dir, 52 | env=env, 53 | capture_output=True, 54 | # check=True 55 | ) 56 | if verbose: 57 | print(output) 58 | print('Done compiling' if not install else 'Done installing') 59 | 60 | 61 | def uninstall_extensions(tuning_extension_names, verbose=True): 62 | # Need to copy this process's environments, otherwise it can't find the compilers 63 | env = {**os.environ} 64 | output = subprocess.run( 65 | [sys.executable, '-m', 'pip', 'uninstall', '-y', *tuning_extension_names], 66 | env=env, 67 | capture_output=True, 68 | # check=True 69 | ) 70 | if verbose: 71 | print(output) 72 | print('Done uninstalling') 73 | 74 | 75 | def benchmark_extension(benchmark_script, *benchmark_args, verbose=True): 76 | # Need to copy this 
process's environments, otherwise it can't find the compilers 77 | env = os.environ 78 | # https://stackoverflow.com/questions/53173314/how-to-change-distutils-output-directory 79 | # Need separate build directories for parallel compilation 80 | process = subprocess.run( 81 | [sys.executable, benchmark_script, *benchmark_args], 82 | env=os.environ, 83 | capture_output=True, 84 | # check=True 85 | ) 86 | if verbose: 87 | print(process) 88 | print('Done benchmarking') 89 | return json.loads(process.stdout.decode(sys.stdout.encoding)) 90 | 91 | 92 | # def benchmark(connection, temp_dir): 93 | # import torch 94 | # # module = importlib.import_module(tuning_extension_name) 95 | # torch.ops.load_library(temp_dir / 'torch_butterfly_tuning.so') 96 | # batch_size = 1024 97 | # n = 32 98 | # twiddle = torch.randn(1, 1, 5, n // 2, 2, 2, device='cuda') 99 | # input = torch.randn(batch_size, 1, n, device=twiddle.device) 100 | # output = torch.ops.torch_butterfly.butterfly_multiply_fw(twiddle, input, True) 101 | # # https://medium.com/@auro_227/timing-your-pytorch-code-fragments-e1a556e81f2 102 | # res = [] 103 | # for _ in range(32): 104 | # start = torch.cuda.Event(enable_timing=True) 105 | # end = torch.cuda.Event(enable_timing=True) 106 | # start.record() 107 | # output = torch.ops.torch_butterfly.butterfly_multiply_fw(twiddle, input, True) 108 | # end.record() 109 | # torch.cuda.synchronize() 110 | # res.append(start.elapsed_time(end)) 111 | # print(output.shape) 112 | # res = np.array(res) 113 | # connection.send((np.mean(res), np.std(res))) 114 | 115 | 116 | def set_up_tuning_temp_dir(params: dict, source_files, extension_dir, verbose=True): 117 | if verbose: 118 | print('params: ', params) 119 | # TD [2021-10-22]: tempfile.mkdtemp sometimes create dir name with '_' in it, thus messing up 120 | # the extension name. 
121 | # temp_dir = Path(tempfile.mkdtemp(prefix="temp_", dir=Path.cwd().parent)).absolute() 122 | tuning_extension_name = 'temp_' + ''.join(random.choices(string.ascii_uppercase + string.digits, k=10)) 123 | temp_dir = (Path.cwd().parent / tuning_extension_name).absolute() 124 | if temp_dir.exists(): 125 | shutil.rmtree(temp_dir) # shutil.copytree doesn't want directory that already exists 126 | shutil.copytree(extension_dir, temp_dir) 127 | sources = [temp_dir / name for name in source_files] 128 | for kernel_source in sources: 129 | ks = read_file(kernel_source) 130 | ks = prepare_kernel_string(ks, params) 131 | write_file(kernel_source, ks) 132 | return temp_dir 133 | 134 | 135 | class KernelTuner: 136 | 137 | def __init__(self, extension_dir, source_files, params_list, benchmark_script, 138 | benchmark_args, npool=8, verbose=True): 139 | self.extension_dir = extension_dir 140 | self.source_files = source_files 141 | self.params_list = params_list 142 | self.benchmark_script = benchmark_script 143 | self.benchmark_args = benchmark_args 144 | self.npool = npool 145 | self.verbose = verbose 146 | 147 | def tune(self): 148 | temp_dirs = [set_up_tuning_temp_dir(params, self.source_files, self.extension_dir, 149 | verbose=self.verbose) 150 | for params in self.params_list] 151 | # Compile in parallel (for speed), then install sequentially to ensure correctness 152 | with Pool(self.npool) as p: 153 | p.map(compile_extension, temp_dirs) 154 | # with Pool(1) as p: 155 | # p.map(partial(compile_extension, install=True), [temp_dirs]) 156 | for temp_dir in tqdm(temp_dirs): 157 | try: 158 | compile_extension(temp_dir, install=True) 159 | except: 160 | pass 161 | # # We benchmark on a separate process so that they can import the extension that just got compiled. 
162 | # for params, temp_dir in params_tempdir: 163 | # print('Benchmarking: ', params) 164 | # recv_conn, send_conn = Pipe(duplex=False) 165 | # benchmark_process = Process(target=benchmark_fwd, args=(send_conn, str(temp_dir.stem))) 166 | # benchmark_process.start() 167 | # result = recv_conn.recv() 168 | # benchmark_process.join() 169 | # print('result', result) 170 | results = [] 171 | for params, temp_dir in tqdm(list(zip(self.params_list, temp_dirs))): 172 | try: 173 | results.append((params, 174 | benchmark_extension(self.benchmark_script, 175 | *['--name', temp_dir.stem] + self.benchmark_args))) 176 | except: 177 | pass 178 | print(results) 179 | uninstall_extensions([temp_dir.stem for temp_dir in temp_dirs]) 180 | for temp_dir in temp_dirs: 181 | shutil.rmtree(temp_dir) 182 | return results 183 | -------------------------------------------------------------------------------- /src/extensions/cauchy/tuning_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from pathlib import Path 4 | 5 | import torch.cuda 6 | from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension 7 | from torch.utils.cpp_extension import CUDA_HOME 8 | 9 | 10 | extensions_dir = Path(os.getenv('TUNING_SOURCE_DIR')).absolute() 11 | assert extensions_dir.exists() 12 | source_files=[ 13 | 'cauchy.cpp', 14 | 'cauchy_cuda.cu', 15 | ] 16 | sources = [str(extensions_dir / name) for name in source_files] 17 | 18 | extension_name = os.getenv('TUNING_EXTENSION_NAME', default='cauchy_mult_tuning') 19 | ext_modules = [] 20 | if torch.cuda.is_available() and CUDA_HOME is not None: 21 | extension = CUDAExtension( 22 | extension_name, 23 | sources, 24 | include_dirs=[extensions_dir], 25 | extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], 26 | # 'nvcc': ['-O2', '-lineinfo'] 27 | 'nvcc': ['-O2', '-lineinfo', '--use_fast_math'] 28 | } 29 | ) 30 | 
ext_modules.append(extension) 31 | 32 | setup( 33 | name=extension_name, 34 | ext_modules=ext_modules, 35 | # cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)}) 36 | cmdclass={'build_ext': BuildExtension}) 37 | -------------------------------------------------------------------------------- /src/main_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import importlib 4 | import shutil 5 | 6 | from matplotlib import pyplot as plt 7 | from pathlib import Path 8 | 9 | import torch 10 | import lightning.pytorch as lp 11 | from lightning.pytorch.tuner import Tuner 12 | from lightning.pytorch.loggers import TensorBoardLogger 13 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar 14 | from clinical_ts.utils.callbacks import LRMonitorCallback, TriggerQuantizerHyperparameterUpdate, UnfreezingFinetuningCallback 15 | 16 | import hydra 17 | from hydra.core.hydra_config import HydraConfig 18 | 19 | ################# 20 | #specific 21 | from clinical_ts.config import * 22 | from omegaconf import OmegaConf 23 | 24 | 25 | #mlflow without autologging https://github.com/zjohn77/lightning-mlflow-hf/blob/74c30c784f719ea166941751bda24393946530b7/lightning_mlflow/train.py#L39 26 | MLFLOW_AVAILABLE=True 27 | try: 28 | import mlflow 29 | from lightning.pytorch.loggers import MLFlowLogger 30 | from omegaconf import DictConfig, ListConfig 31 | 32 | def log_params_from_omegaconf_dict(params): 33 | for param_name, element in params.items(): 34 | _explore_recursive(param_name, element) 35 | 36 | def _explore_recursive(parent_name, element): 37 | if isinstance(element, DictConfig): 38 | for k, v in element.items(): 39 | if isinstance(v, DictConfig) or isinstance(v, ListConfig): 40 | _explore_recursive(f'{parent_name}.{k}', v) 41 | else: 42 | if(k!="_target_" and v is not None): 43 | mlflow.log_param(f'{parent_name}.{k}'," " if v=="" else v) 44 | elif 
isinstance(element, ListConfig): 45 | for i, v in enumerate(element): 46 | mlflow.log_param(f'{parent_name}.{i}', " " if v=="" else v) 47 | 48 | except ImportError: 49 | MLFLOW_AVAILABLE=False 50 | 51 | def get_git_revision_short_hash(): 52 | return str(subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip()) 53 | 54 | def get_slurm_job_id(): 55 | job_id = os.environ.get('SLURM_JOB_ID') 56 | 57 | if job_id: 58 | return str(job_id) 59 | else: 60 | return "" 61 | 62 | def get_work_dir(directory, pattern="version_"): 63 | path = Path(directory) 64 | 65 | # Find all subdirectories matching the pattern "version_X" 66 | version_dirs = [d.name[len(pattern):] for d in path.iterdir() if d.is_dir() and d.name.startswith(pattern)] 67 | 68 | if not version_dirs: 69 | return pattern+"0" 70 | 71 | # Extract the version numbers and find the maximum 72 | version_numbers = [int(d.split('_')[-1]) for d in version_dirs] 73 | return pattern+str(max(version_numbers) + 1) 74 | 75 | def _string_to_class(_target_): 76 | if(len(_target_.split("."))==1):#assume global namespace 77 | cls_ = globals()[_target_] 78 | else: 79 | mod_ = importlib.import_module(".".join(_target_.split(".")[:-1])) 80 | cls_ = getattr(mod_, _target_.split(".")[-1]) 81 | return cls_ 82 | 83 | ################################################################################################### 84 | #MAIN 85 | ################################################################################################### 86 | cs = create_default_config() 87 | 88 | @hydra.main(version_base=None, config_path="conf", config_name="config_supervised_ecg") 89 | def run(hparams: FullConfig) -> None: 90 | hparams.trainer.executable = "main_all" 91 | hparams.trainer.revision = get_git_revision_short_hash() 92 | 93 | if not os.path.exists(hparams.trainer.output_path): 94 | os.makedirs(hparams.trainer.output_path) 95 | 96 | #determine version/output path 97 | slurm_job_id = get_slurm_job_id() 98 | work_dir = 
get_work_dir(hparams.trainer.output_path,pattern="run_"+slurm_job_id+"_") 99 | 100 | hparams.trainer.output_path = Path(hparams.trainer.output_path)/(work_dir) 101 | if not os.path.exists(hparams.trainer.output_path): 102 | os.makedirs(hparams.trainer.output_path) 103 | 104 | logger = [TensorBoardLogger( 105 | save_dir=Path(hparams.trainer.output_path).parent, 106 | version=work_dir,#hparams.trainer.metadata.split(":")[0], 107 | name="")] 108 | 109 | print("FULL PARSED CONFIG:") 110 | print(OmegaConf.to_yaml(hparams)) 111 | 112 | #get hydra configs 113 | hydra_cfg = HydraConfig.get() 114 | config_file = Path(hydra_cfg.runtime.config_sources[1]["path"])/hydra_cfg.job.config_name 115 | print("Output directory:",hparams.trainer.output_path) 116 | print("Main config:",config_file) 117 | print("Overrides:",OmegaConf.to_container(hydra_cfg.overrides.hydra)) 118 | print("Runtime choices:",OmegaConf.to_container(hydra_cfg.runtime.choices)) 119 | #print("Full config:",OmegaConf.to_yaml(hparams)) 120 | 121 | #copy main config into output dir 122 | shutil.copyfile(config_file, Path(hparams.trainer.output_path)/(config_file.stem)) 123 | 124 | #create the model 125 | classname = _string_to_class(hparams.task.mainclass) 126 | model = classname(hparams) 127 | #model = torch.compile(model) 128 | 129 | if(MLFLOW_AVAILABLE): 130 | #os.environ['MLFLOW_TRACKING_USERNAME'] = "ai4h" 131 | #os.environ['MLFLOW_TRACKING_PASSWORD'] = "mlf22!" 
132 | #os.environ['MLFLOW_TRACKING_URI'] = "https://ai4hmlflow.nsupdate.info/" 133 | mlflow.set_experiment(hparams.trainer.executable+"("+hparams.task.mainclass.split(".")[-1]+")") 134 | run = mlflow.start_run(run_name=hparams.trainer.metadata) 135 | mlf_logger = MLFlowLogger( 136 | experiment_name=mlflow.get_experiment(run.info.experiment_id).name, 137 | tracking_uri=mlflow.get_tracking_uri(), 138 | log_model=False, 139 | ) 140 | mlf_logger._run_id = run.info.run_id 141 | mlf_logger.log_hyperparams = log_params_from_omegaconf_dict 142 | logger.append(mlf_logger) 143 | 144 | key_summary_metric = model.metrics_train_val[0].key_summary_metric if len(model.metrics_train_val)>0 else 'val_loss'#use the key_summary_metric of the first metric otherwise val_loss 145 | mode_summary_metric = model.metrics_train_val[0].mode_summary_metric if len(model.metrics_train_val)>0 else 'min' 146 | 147 | checkpoint_callback = ModelCheckpoint( 148 | dirpath=logger[0].log_dir, 149 | filename="best_model", 150 | save_top_k=1, 151 | save_last=True, 152 | verbose=True, 153 | monitor=key_summary_metric, 154 | mode=mode_summary_metric) 155 | 156 | lr_monitor = LearningRateMonitor(logging_interval="step") 157 | lr_monitor2 = LRMonitorCallback(start=False,end=True)#interval="step") 158 | 159 | callbacks = [checkpoint_callback,lr_monitor,lr_monitor2] 160 | if(hparams.trainer.refresh_rate>0): 161 | callbacks.append(TQDMProgressBar(refresh_rate=hparams.trainer.refresh_rate)) 162 | quantizers = [m for m in model.modules() if isinstance(m,QuantizerBase)] 163 | if(len(quantizers)>0): 164 | print("Found",len(quantizers),"quantizer modules.") 165 | callbacks.append(TriggerQuantizerHyperparameterUpdate(quantizers)) 166 | if(hparams.loss.loss_type=="supervised" and hparams.trainer.frozen_epochs>0): 167 | callbacks.append(UnfreezingFinetuningCallback(unfreeze_epoch=hparams.trainer.frozen_epochs)) 168 | 169 | trainer = lp.Trainer( 170 | #overfit_batches=0.01, 171 | 
accumulate_grad_batches=hparams.trainer.accumulate, 172 | max_epochs=hparams.trainer.epochs if hparams.trainer.eval_only=="" else 0, 173 | 174 | default_root_dir=hparams.trainer.output_path, 175 | 176 | #debugging flags for val and train 177 | num_sanity_val_steps=0, 178 | #overfit_batches=10, 179 | 180 | logger=logger, 181 | callbacks = callbacks, 182 | benchmark=True, 183 | 184 | accelerator="gpu" if hparams.trainer.gpus>0 else "cpu", 185 | devices=hparams.trainer.gpus if hparams.trainer.gpus>0 else 1, 186 | num_nodes=hparams.trainer.num_nodes, 187 | precision=hparams.trainer.precision, 188 | strategy=hparams.trainer.strategy, 189 | 190 | enable_progress_bar=hparams.trainer.refresh_rate>0, 191 | #weights_summary='top', 192 | ) 193 | 194 | if(hparams.trainer.fp32_matmul_precision!="highest"): 195 | torch.set_float32_matmul_precision(hparams.trainer.fp32_matmul_precision) 196 | 197 | if(hparams.trainer.auto_batch_size):#batch size 198 | tuner=Tuner(trainer) 199 | tuner.scale_batch_size(model, mode="binsearch") 200 | 201 | if(hparams.trainer.lr_find):# lr find 202 | tuner=Tuner(trainer) 203 | 204 | #torch.save(model.state_dict(), Path(hparams.trainer.output_path)/(logger.log_dir+"initial_weights.ckpt")) 205 | # Run learning rate finder 206 | lr_finder = tuner.lr_find(model) 207 | 208 | # Plot lr find plot 209 | fig = lr_finder.plot(suggest=True) 210 | fig.show() 211 | plt.savefig(Path(hparams.trainer.output_path)/("lrfind.png")) 212 | 213 | # Pick point based on plot, or get suggestion 214 | new_lr = lr_finder.suggestion() 215 | print("Suggested lr:",new_lr) 216 | # update hparams of the model 217 | model.hparams.base.lr = [new_lr] 218 | model.lr = new_lr 219 | 220 | # there is still some issue with the restored model- therefore just abort the run 221 | #model.load_state_dict(torch.load(Path(hparams.trainer.output_path)/(logger.log_dir+"initial_weights.ckpt"))) 222 | return 223 | 224 | if(MLFLOW_AVAILABLE): 225 | #version_number = 
hparams.trainer.output_path.split('/')[-1].replace('version_', '') # tw: extract version index 226 | mlflow.log_param("a_slurm_job_id", slurm_job_id) 227 | mlflow.log_param("a_work_dir", work_dir) 228 | 229 | if(hparams.trainer.epochs>0 and hparams.trainer.eval_only==""): 230 | trainer.fit(model,ckpt_path= None if hparams.trainer.resume=="" else hparams.trainer.resume) 231 | trainer.test(model,ckpt_path="best" if hparams.trainer.eval_only=="" else hparams.trainer.eval_only) 232 | 233 | if(MLFLOW_AVAILABLE): 234 | mlflow.end_run() 235 | 236 | if(hparams.trainer.export_features!=""): 237 | model.export_features(Path(hparams.trainer.output_path)/"features",module=hparams.trainer.export_features) 238 | 239 | if __name__ == "__main__": 240 | run() 241 | --------------------------------------------------------------------------------