├── LICENSE
├── README.md
├── reports
│   ├── README.md
│   └── abstract.png
└── src
    ├── README.md
    ├── clinical_ts
    │   ├── a.md
    │   ├── config.py
    │   ├── data
    │   │   ├── a.md
    │   │   ├── collate_dataset.py
    │   │   ├── time_series_dataset.py
    │   │   ├── time_series_dataset_transforms.py
    │   │   └── time_series_dataset_utils.py
    │   ├── head
    │   │   ├── a.md
    │   │   └── multimodal.py
    │   ├── loss
    │   │   ├── a.md
    │   │   └── supervised.py
    │   ├── metric
    │   │   ├── a.md
    │   │   └── base.py
    │   ├── tabular
    │   │   ├── a.md
    │   │   └── base.py
    │   ├── task
    │   │   ├── a.md
    │   │   └── multimodal.py
    │   ├── template_model.py
    │   ├── template_modules.py
    │   ├── timeseries_utils.py
    │   ├── ts
    │   │   ├── a.md
    │   │   ├── base.py
    │   │   ├── basic_conv1d_modules
    │   │   │   ├── a.md
    │   │   │   └── basic_conv1d.py
    │   │   ├── encoder.py
    │   │   ├── head.py
    │   │   ├── s4.py
    │   │   ├── s4_modules
    │   │   │   ├── a.md
    │   │   │   ├── s42.py
    │   │   │   ├── s4_model.py
    │   │   │   └── s4_utils.py
    │   │   └── transformer_modules
    │   │       ├── a.md
    │   │       └── transformer.py
    │   └── utils
    │       ├── a.md
    │       ├── bootstrap_utils.py
    │       ├── callbacks.py
    │       ├── eval_utils_cafa.py
    │       └── schedulers.py
    ├── config
    │   ├── config_supervised_multimodal_labvalues_s4.yaml
    │   └── data
    │       ├── a.md
    │       └── multimodal_mdsed.yaml
    ├── data
    │   └── memmap
    │       └── a.md
    ├── ecg_utils.py
    ├── environment.yml
    ├── extensions
    │   └── cauchy
    │       ├── a.md
    │       ├── benchmark_cauchy.py
    │       ├── benchmark_cauchy_tune.py
    │       ├── cauchy.cpp
    │       ├── cauchy.py
    │       ├── cauchy_cuda.cu
    │       ├── map.h
    │       ├── setup.py
    │       ├── test_cauchy.py
    │       ├── tune_cauchy.py
    │       ├── tune_cauchy.sh
    │       ├── tuner.py
    │       └── tuning_setup.py
    ├── main_all.py
    ├── preprocessing.py
    └── timeseries_utils.py
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2024 AI4HealthUOL
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ## This is the official repository for CardioLab, a machine and deep learning framework for the estimation and monitoring of laboratory values through ECG data.
 2 | 
 3 | CardioLab has been proposed in two main manuscripts:
 4 | 
 5 | 1. **CardioLab: Laboratory Values Estimation from Electrocardiogram Features - An Exploratory Study.** Accepted at the Computing in Cardiology (CinC) 2024 conference. [arXiv](https://arxiv.org/abs/2407.18629)
 6 |    
 7 | 2. **CardioLab: Laboratory Values Estimation and Monitoring from Electrocardiogram Signals - A Deep-Multimodal Approach** [arXiv](https://arxiv.org/abs/2411.14886)
 8 | 
 9 | 
10 | In terms of ECG data and clinical settings, our CinC manuscript investigates only the estimation (current) of abnormalities from ECG tabular features, whereas our second manuscript investigates both the estimation (current) and the monitoring (future) of abnormalities, using raw ECG waveforms instead.
11 | 
12 | 
13 | ## Clinical Setting
14 | 
15 | 
16 | 
17 |  
18 | - A) Shows the overall predictive workflow of the study: ECG waveforms, demographics, biometrics, and vital signs are used as model inputs in a binary classification setting to predict abnormal laboratory values.
19 | 
20 | - B) Shows the estimation task: the features include the closest vital signs within 30 minutes of the ECG record, and the target is the closest laboratory value within 60 minutes.
21 | 
22 | - C) Shows the monitoring task: the features again include the closest vital signs within 30 minutes of the ECG record, and the target is the presence of any abnormal laboratory value within a future time horizon; we investigated horizons of 30, 60, and 120 minutes.
23 | 
24 | 
25 | 
26 | ## References
27 | 
28 | ```bibtex
29 | @misc{alcaraz2024cardiolablaboratoryvaluesestimation,
30 |       title={CardioLab: Laboratory Values Estimation from Electrocardiogram Features -- An Exploratory Study}, 
31 |       author={Juan Miguel Lopez Alcaraz and Nils Strodthoff},
32 |       year={2024},
33 |       eprint={2407.18629},
34 |       archivePrefix={arXiv},
35 |       primaryClass={eess.SP},
36 |       url={https://arxiv.org/abs/2407.18629}, 
37 | }
38 | ```
39 | 
40 | ```bibtex
41 | @misc{alcaraz2024cardiolablaboratoryvaluesmonitoring,
42 |       title={CardioLab: Laboratory Values Estimation and Monitoring from Electrocardiogram Signals -- A Multimodal Deep Learning Approach}, 
43 |       author={Juan Miguel Lopez Alcaraz and Nils Strodthoff},
44 |       year={2024},
45 |       eprint={2411.14886},
46 |       archivePrefix={arXiv},
47 |       primaryClass={eess.SP},
48 |       url={https://arxiv.org/abs/2411.14886}, 
49 | }
50 | ```
51 | 
52 | 
53 | 
54 | 
--------------------------------------------------------------------------------
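The target definitions in panels B) and C) of the README above can be sketched with the standard library as follows. This is an illustration under stated assumptions, not the logic in `preprocessing.py`: events are hypothetical `(timestamp, value)` pairs, lab events carry a precomputed boolean "abnormal" flag, "within 60 minutes" is read as an absolute time difference, and the monitoring horizon is read as the half-open window after the ECG.

```python
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

# Hypothetical event format: (timestamp, value). For lab events the value is a
# boolean "abnormal" flag; the real preprocessing works on MIMIC-IV tables instead.
Event = Tuple[datetime, object]


def closest_within(events: List[Event], t_ecg: datetime,
                   window: timedelta) -> Optional[Event]:
    """Event closest in time to t_ecg, considering only events within +/- window."""
    candidates = [e for e in events if abs(e[0] - t_ecg) <= window]
    if not candidates:
        return None
    return min(candidates, key=lambda e: abs(e[0] - t_ecg))


def estimation_target(labs: List[Event], t_ecg: datetime,
                      window: timedelta = timedelta(minutes=60)) -> Optional[bool]:
    """Estimation (B): abnormal flag of the closest lab value within 60 minutes."""
    hit = closest_within(labs, t_ecg, window)
    return None if hit is None else bool(hit[1])


def monitoring_target(labs: List[Event], t_ecg: datetime,
                      horizon: timedelta) -> bool:
    """Monitoring (C): any abnormal lab value in the window (t_ecg, t_ecg + horizon]."""
    return any(bool(flag) for t, flag in labs if t_ecg < t <= t_ecg + horizon)


t0 = datetime(2024, 1, 1, 12, 0)
vitals = [(t0 - timedelta(minutes=20), 88), (t0 + timedelta(minutes=45), 92)]
labs = [(t0 + timedelta(minutes=40), False), (t0 + timedelta(minutes=90), True)]
print(closest_within(vitals, t0, timedelta(minutes=30)))  # the reading 20 min before the ECG (88)
print(estimation_target(labs, t0))                         # False
for minutes in (30, 60, 120):
    print(minutes, monitoring_target(labs, t0, timedelta(minutes=minutes)))
# 30 False, 60 False, 120 True
```

The example shows why the monitoring label depends on the horizon: the abnormal value 90 minutes after the ECG only counts for the 120-minute horizon.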
/reports/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/reports/abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI4HealthUOL/CardioLab/7dffd82df50b3dd3cb9c8d561fb364bf58dd88f0/reports/abstract.png
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
 1 | # Replicating the Experiments
 2 | 
 3 | This repository provides scripts to preprocess ECG data, train models, and analyze results. Follow the steps below to replicate the experiments.
 4 | 
 5 | ## 1. Download Required Datasets
 6 | Before running the scripts, download the following files and place them in the root directory of this repository:
 7 | 
 8 | 
 9 | 
10 | - **patients.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/)
11 | - **d_labitems.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/)
12 | - **labevents.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/)
13 | - **omr.csv.gz** from [MIMIC-IV](https://physionet.org/content/mimiciv/3.1/)
14 | 
15 | - **edstays.csv.gz** from [MIMIC-IV-ED](https://physionet.org/content/mimic-iv-ed/2.2/)
16 | - **vitalsign.csv.gz** from [MIMIC-IV-ED](https://physionet.org/content/mimic-iv-ed/2.2/)
17 | 
18 | - **records_w_diag_icd10.csv** from [MIMIC-IV-ECG-ICD](https://www.physionet.org/content/mimic-iv-ecg-ext-icd-labels/1.0.1/)
19 | 
20 | - **record_list.csv** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/)
21 | - **machine_measurements.csv** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/)
22 | - **machine_measurements_data_dictionary.csv** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/)
23 | - **mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0.zip** from [MIMIC-IV-ECG](https://physionet.org/content/mimic-iv-ecg/1.0/)
24 |   
25 | 
26 | ## 2. Run Preprocessing
27 | First, preprocess the data by running:
28 | 
29 | ```bash
30 | python preprocessing.py
31 | ```
32 | 
33 | ## 3. Train the Model
34 | After preprocessing is complete, train and test the model by running:
35 | 
36 | ```bash
37 | python main_all.py --config config/config_supervised_multimodal_labvalues_s4.yaml
38 | ```
39 | 
40 | ## 4. Output and Results
41 | The results, including model performance metrics and analyses, will be saved automatically in the current directory.
42 | 
43 | ---
44 | 
45 | For further inquiries, please open an issue.
46 | 
--------------------------------------------------------------------------------
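Before running `python preprocessing.py`, a small check such as the following can confirm that the files listed in step 1 of `src/README.md` are in place. It is a hypothetical helper, not part of the repository, and it assumes the files sit directly in the directory passed as `root` (the README asks for the repository root).

```python
from pathlib import Path
from typing import List

# File names copied from step 1 of src/README.md.
REQUIRED_FILES = [
    "patients.csv.gz",
    "d_labitems.csv.gz",
    "labevents.csv.gz",
    "omr.csv.gz",
    "edstays.csv.gz",
    "vitalsign.csv.gz",
    "records_w_diag_icd10.csv",
    "record_list.csv",
    "machine_measurements.csv",
    "machine_measurements_data_dictionary.csv",
    "mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0.zip",
]


def missing_files(root: Path = Path(".")) -> List[str]:
    """Return the required file names that are not present directly under root."""
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]
```

For example, calling `missing_files(Path(".."))` from inside `src/` would list whatever is still missing from the repository root; an empty list means all downloads were found.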
/src/clinical_ts/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/config.py:
--------------------------------------------------------------------------------
  1 | from hydra.core.config_store import ConfigStore
  2 | from dataclasses import dataclass
  3 | 
  4 | #for base configs
  5 | from .template_modules import * 
  6 | 
  7 | #for specific configs
  8 | from .ts.encoder import *
  9 | from .ts.head import *
 10 | from .ts.base import *
 11 | 
 12 | ##################################################################
 13 | # try to import (optional) modules with special dependencies
 14 | S4_AVAILABLE = True
 15 | try:
 16 |     from .ts.s4 import *
 17 | except Exception as e:  # import or build errors, but not KeyboardInterrupt/SystemExit
 18 |     print(f"WARNING: Could not import s4 module ({e})")
 19 |     S4_AVAILABLE = False
 20 | 
 21 | from .tabular.base import *
 22 | 
 23 | from .head.multimodal import *
 24 | 
 25 | from .loss.supervised import *
 26 | 
 27 | from .metric.base import *
 28 | 
 29 | from .task.multimodal import *
 30 | 
 31 | ###########################################################################################################
 32 | # https://hydra.cc/docs/tutorials/structured_config/config_groups/
 33 | @dataclass
 34 | class FullConfig:
 35 | 
 36 |     base: BaseConfig
 37 |     data: BaseConfigData
 38 |     loss: LossConfig
 39 |     metric: MetricConfig
 40 |     trainer: TrainerConfig
 41 |     task: TaskConfig
 42 | 
 43 |     ts: TimeSeriesEncoderConfig
 44 |     static: EncoderStaticBaseConfig
 45 |     head: HeadBaseConfig
 46 |     
 47 | 
 48 | def create_default_config():
 49 |     cs = ConfigStore.instance()
 50 |     cs.store(name="config", node=FullConfig)
 51 | 
 52 |     ######################################################################
 53 |     # base
 54 |     ######################################################################
 55 |     cs.store(group="base", name="base", node=BaseConfig)
 56 | 
 57 |     ######################################################################
 58 |     # input data
 59 |     ######################################################################
 60 |     cs.store(group="data", name="base", node=BaseConfigData)
 61 |     
 62 |     ######################################################################
 63 |     # time series encoder
 64 |     ######################################################################
 65 |     cs.store(group="ts", name="tsenc",  node=TimeSeriesEncoderConfig)
 66 |     
 67 |     #ENCODER
 68 |     cs.store(group="ts/enc", name="none", node=NoEncoderConfig)
 69 |     
 70 |     #PREDICTOR
 71 |     cs.store(group="ts/pred", name="none", node=NoPredictorConfig)#no predictor
 72 |     if(S4_AVAILABLE):
 73 |         cs.store(group="ts/pred", name="s4", node=S4PredictorConfig)#S4 model
 74 |     
 75 |     #HEADS
 76 |     cs.store(group="ts/head", name="none", node=HeadBaseConfig)
 77 |     cs.store(group="ts/head", name="pool", node=PoolingHeadConfig)
 78 | 
 79 |     #SSL HEADS
 80 |     cs.store(group="ts/head_ssl", name="none", node=HeadBaseConfig)
 81 | 
 82 |     #QUANTIZER
 83 |     cs.store(group="ts/quant", name="none", node=QuantizerBaseConfig)
 84 | 
 85 |     #MASK
 86 |     cs.store(group="ts/mask", name="none", node=MaskingBaseConfig)
 87 | 
 88 |     #LOSS
 89 |     cs.store(group="ts/loss", name="none", node=SSLLossConfig)
 90 | 
 91 |     #PRE
 92 |     cs.store(group="ts/pre", name="none", node=PrePostBaseConfig)  
 93 |     
 94 |     #POST
 95 |     cs.store(group="ts/post", name="none", node=PrePostBaseConfig)
 96 |     
 97 |     ######################################################################
 98 |     # static encoder
 99 |     ######################################################################
100 |     for g in ["static", "ts/static"]:
101 |         cs.store(group=g, name="none", node=EncoderStaticBaseConfig)
102 |         cs.store(group=g, name="mlp", node=BasicEncoderStaticMLPConfig)
103 |         
104 |     ######################################################################
105 |     # optional multimodal head
106 |     ######################################################################
107 |     cs.store(group="head", name="none", node=HeadBaseConfig)
108 |     cs.store(group="head", name="concat", node=ConcatFusionHeadConfig)
109 | 
110 |     ######################################################################
111 |     # loss function
112 |     ######################################################################
113 |     #no global loss
114 |     cs.store(group="loss", name="none", node=LossConfig)
115 |     #supervised losses
116 |     cs.store(group="loss", name="bce", node=BCELossConfig)
117 |     cs.store(group="loss", name="bcef", node=BCEFLossConfig) 
118 |     
119 |     ######################################################################
120 |     # metrics
121 |     ######################################################################
122 |     cs.store(group="metric", name="none", node=MetricConfig)
123 |     cs.store(group="metric", name="auroc", node=MetricAUROCConfig)
124 |     cs.store(group="metric", name="aurocagg", node=MetricAUROCAggConfig)
125 |     
126 |     ######################################################################
127 |     # trainer
128 |     ######################################################################
129 |     cs.store(group="trainer", name="trainer", node=TrainerConfig)
130 |     
131 |     ######################################################################
132 |     # task
133 |     ######################################################################
134 |     cs.store(group="task", name="none", node=TaskConfig)
135 |     cs.store(group="task", name="multi", node=TaskConfigMultimodal)
136 |     
137 |     return cs
138 | 
--------------------------------------------------------------------------------
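The pattern in `create_default_config()` of `config.py` above registers one dataclass node per `(group, name)` pair and only registers optional nodes such as the S4 predictor when their import succeeded. The toy version below illustrates that pattern without Hydra. `ToyConfigStore`, `module_available`, `PredictorConfig`, and `S4LikeConfig` are hypothetical stand-ins, not Hydra's API or the repository's config classes; the real code uses `hydra.core.config_store.ConfigStore` as shown above.

```python
import importlib
from dataclasses import dataclass
from typing import Dict, List, Tuple


class ToyConfigStore:
    """Hypothetical stand-in for hydra's ConfigStore, keyed by (group, name)."""

    def __init__(self) -> None:
        self.repo: Dict[Tuple[str, str], type] = {}

    def store(self, group: str, name: str, node: type) -> None:
        self.repo[(group, name)] = node

    def options(self, group: str) -> List[str]:
        """Names registered for a group, e.g. the available choices for ts/pred."""
        return sorted(name for (g, name) in self.repo if g == group)

    def instantiate(self, group: str, name: str):
        """Return a default instance of the dataclass registered under group/name."""
        return self.repo[(group, name)]()


def module_available(module_name: str) -> bool:
    """Same idea as the S4_AVAILABLE flag: True only if the module imports cleanly."""
    try:
        importlib.import_module(module_name)
    except Exception:  # ImportError or errors raised while loading compiled extensions
        return False
    return True


@dataclass
class PredictorConfig:  # hypothetical placeholder, not the repository's NoPredictorConfig
    kind: str = "none"


@dataclass
class S4LikeConfig(PredictorConfig):  # hypothetical placeholder for S4PredictorConfig
    kind: str = "s4"
    d_state: int = 64


cs = ToyConfigStore()
cs.store(group="ts/pred", name="none", node=PredictorConfig)
# "math" stands in for the optional s4 module; any optional dependency works here.
if module_available("math"):
    cs.store(group="ts/pred", name="s4", node=S4LikeConfig)

print(cs.options("ts/pred"))  # ['none', 's4']
```

Guarding registration this way means a missing optional dependency only removes one choice from its group instead of breaking the whole configuration.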
/src/clinical_ts/data/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
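`CollateDataset` in `collate_dataset.py`, which follows, concatenates the item tuples of several datasets. Its lookup logic can be tried without PyTorch using the plain-Python mirror below; `ListCollate` and the toy data are hypothetical and only show how `sample_idx_from_first_dataset` changes which index is used for the remaining datasets.

```python
from typing import Any, Sequence, Tuple


class ListCollate:
    """Plain-Python mirror of CollateDataset's __getitem__/__len__ logic."""

    def __init__(self, datasets: Sequence[Sequence[tuple]],
                 sample_idx_from_first_dataset: bool = False) -> None:
        self.datasets = datasets
        self.sample_idx_from_first_dataset = sample_idx_from_first_dataset

    def __getitem__(self, idx: int) -> Tuple[Any, ...]:
        first = self.datasets[0][idx]
        if self.sample_idx_from_first_dataset:
            # first field is the row index into the other datasets; it is not returned
            sample_idx, res = first[0], tuple(first[1:])
        else:
            sample_idx, res = idx, tuple(first)
        for d in self.datasets[1:]:
            res += tuple(d[sample_idx])
        return res

    def __len__(self) -> int:
        return len(self.datasets[0])


# Toy example: crops whose first field points at a row of static features.
ecg = [(2, "crop_a", 1), (0, "crop_b", 0)]    # (sample_idx, signal, label)
static = [("age_0",), ("age_1",), ("age_2",)]  # one tuple per record
ds = ListCollate([ecg, static], sample_idx_from_first_dataset=True)
print(ds[0])  # ('crop_a', 1, 'age_2')
```

This is useful when the first dataset yields several crops per record: the stored record index keeps each crop aligned with its record's static features.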
/src/clinical_ts/data/collate_dataset.py:
--------------------------------------------------------------------------------
 1 | import torch.utils.data
 2 | 
 3 | class CollateDataset(torch.utils.data.Dataset):
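 4 |     r"""Dataset that concatenates, item by item, the tuples returned by several existing datasets.
 5 |     If sample_idx_from_first_dataset is True, the first element of each item of the first dataset is used as the lookup index into the remaining datasets and is dropped from the result."""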
 6 | 
 7 |     def __init__(self, datasets, sample_idx_from_first_dataset=False) -> None:
 8 |         super().__init__()
 9 |         self.datasets = datasets
10 |         self.sample_idx_from_first_dataset = sample_idx_from_first_dataset
11 | 
12 |     def __getitem__(self, idx):
13 |         if(self.sample_idx_from_first_dataset):
14 |             sample_idx = self.datasets[0][idx][0]
15 |             res = tuple(self.datasets[0][idx][1:])
16 |         else:
17 |             sample_idx = idx
18 |             res = tuple(self.datasets[0][idx])
19 |         
20 |         for d in self.datasets[1:]:
21 |             res += tuple(d[sample_idx])
22 |         return res
23 |             
24 |     def __len__(self):
25 |         return len(self.datasets[0])
--------------------------------------------------------------------------------
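`TimeSeriesDataset` in `time_series_dataset.py`, which follows, turns each dataframe row into one or more `(start, end)` chunks and then takes a crop of length `output_size` from each chunk. The framework-free sketch below reproduces that chunking arithmetic so the parameters `chunk_length`, `stride`, `min_chunk_length`, and `copies` can be explored in isolation. `chunk_bounds` and `crop_offset` are hypothetical helpers, the defaults are assumptions, and the random crop here samples the full inclusive range of valid start positions.

```python
import random
from typing import List, Optional, Tuple


def chunk_bounds(data_length: int, chunk_length: int, start_idx: int = 0,
                 stride: Optional[int] = None, min_chunk_length: int = 0,
                 copies: int = 0) -> List[Tuple[int, int]]:
    """(start, end) index pairs for one record, following TimeSeriesDataset.__init__."""
    if chunk_length == 0:  # do not split the record
        bounds = [(start_idx, data_length)]
    else:
        step = chunk_length if stride is None else stride
        bounds = [(s, min(s + chunk_length, data_length))
                  for s in range(start_idx, data_length, step)]
    # drop the first chunk shorter than min_chunk_length and every chunk after it
    for i, (s, e) in enumerate(bounds):
        if e - s < min_chunk_length:
            bounds = bounds[:i]
            break
    # each record contributes its chunks (copies + 1) times
    return bounds * (copies + 1)


def crop_offset(timesteps: int, output_size: int, random_crop: bool,
                rng: Optional[random.Random] = None) -> int:
    """Start of an output_size window inside a chunk of length timesteps."""
    if timesteps < output_size:
        raise ValueError("chunk is shorter than output_size")
    if random_crop:
        rng = rng if rng is not None else random.Random()
        # randint includes both ends, so every start in 0..timesteps-output_size can occur
        return rng.randint(0, timesteps - output_size)
    return (timesteps - output_size) // 2  # centre crop


print(chunk_bounds(1000, 250, stride=200, min_chunk_length=250))
# [(0, 250), (200, 450), (400, 650), (600, 850)]
print(crop_offset(300, 250, random_crop=False))  # 25
```

In the first call the trailing chunk `(800, 1000)` has only 200 samples, so it is discarded by `min_chunk_length=250`.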
/src/clinical_ts/data/time_series_dataset.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['tsdata_seq','tsdata_seq_static','tsdata_seq_idxs','tsdata_seq_static_idxs','ConcatTimeSeriesDataset','TimeSeriesDataset','TimeSeriesDatasetConfig']
  2 | 
  3 | import numpy as np
  4 | import torch
  5 | import torch.utils.data
  6 | from .time_series_dataset_transforms import Compose
  7 | 
  8 | #Note: due to issues with the numpy rng for multiprocessing (https://github.com/pytorch/pytorch/issues/5059) that could be fixed by a custom worker_init_fn, we use random throughout for convenience
  9 | import random
 10 | 
 11 | #Note: multiprocessing issues with python lists and dicts (https://github.com/pytorch/pytorch/issues/13246) and pandas dfs (https://github.com/pytorch/pytorch/issues/5902)
 12 | #import multiprocessing as mp
 13 | 
 14 | from dataclasses import dataclass
 15 | import pandas as pd
 16 | from typing import Union, Any
 17 | import pathlib
 18 | 
 19 | 
 20 | from collections import namedtuple
 21 | 
 22 | tsdata_seq = namedtuple("tsdata_seq",("seq","label"))
 23 | tsdata_seq_static = namedtuple("tsdata_seq_static",("seq","label","static"))
 24 | tsdata_seq_cat = namedtuple("tsdata_seq_cat",("seq","label","static_cat"))
 25 | tsdata_seq_static_cat = namedtuple("tsdata_seq_static_cat",("seq","label","static","static_cat"))
 26 | 
 27 | tsdata_seq_idxs = namedtuple("tsdata_seq_idxs",("seq","label","seq_idxs"))
 28 | tsdata_seq_cat_idxs = namedtuple("tsdata_seq_cat_idxs",("seq","label","static_cat","seq_idxs"))
 29 | tsdata_seq_static_idxs = namedtuple("tsdata_seq_static_idxs",("seq","label","static","seq_idxs"))
 30 | tsdata_seq_static_cat_idxs = namedtuple("tsdata_seq_static_cat_idxs",("seq","label","static","static_cat","seq_idxs"))
 31 | 
 32 | def arrays_equal_with_nans(arr1, arr2):
 33 |     '''helper function to compare arrays with nans'''
 34 |     return ((arr1 == arr2) | (np.isnan(arr1) & np.isnan(arr2))).all()
 35 | 
 36 | class ConcatTimeSeriesDataset(torch.utils.data.ConcatDataset):
 37 |     '''ConcatDataset that handles id mapping correctly (to allow to aggregate predictions)'''
 38 |     def __init__(self, datasets):
 39 |         super().__init__(datasets)
 40 |         idmaps = []
 41 |         for dataset_idx,ds in enumerate(self.datasets):
 42 |             idmap = ds.get_id_mapping()
 43 |             remap_dict = {x:j+(self.cumulative_sizes[dataset_idx-1] if dataset_idx>0 else 0) for j,x in enumerate(np.unique(idmap))}
 44 |             idmaps.append(np.array([remap_dict[x] for x in idmap]))
 45 |         self.df_idx_mapping = np.concatenate(idmaps)
 46 | 
 47 |     def get_id_mapping(self):
 48 |         return self.df_idx_mapping
 49 | 
 50 |     def aggregate_predictions(self, preds,targs,idmap=None,aggregate_fn = np.mean,verbose=False):
 51 |         return self.datasets[0].aggregate_predictions(preds, targs, self.df_idx_mapping if idmap is None else idmap, aggregate_fn, verbose)
 52 | 
 53 | 
 54 | class TimeSeriesDataset(torch.utils.data.Dataset):
 55 |     """timeseries dataset with partial crops."""
 56 | 
 57 |     def __init__(self, hparams):
 58 |         """
 59 |         accepts three kinds of input:
 60 |         1) filenames pointing to aligned numpy arrays [timesteps,channels,...] for data and either integer labels or filename pointing to numpy arrays[timesteps,...] e.g. for annotations
 61 |         2) memmap_filename pointing to a memmap file (same argument that was passed to reformat_as_memmap) for data [concatenated,...] and labels - data column in df corresponds to index in this memmap; memmap_label_filename can normally be kept as None (the memmap_label file in the same directory is used in this case)
 62 |         3) npy_data [samples,ts,...] (either path or np.array directly- also supporting variable length input) - data column in df corresponds to sampleid
 63 | 
 64 |         transforms: list of callables (transformations) or single instance e.g. from torchvision.transforms.Compose (applied in the specified order i.e. leftmost element first)
 65 |         
 66 |         col_lbl = None: return dummy label 0 (e.g. for unsupervised pretraining)
 67 |         cols_static: (optional) list of cols with extra static information (continuous-valued)
 68 |         cols_static_cat: (optional) list of cols with extra static information (categorical)
 69 |         fs_annotation_over_fs_data: ratio of sampling frequencies (annotation sampling frequency over data sampling frequency)
 70 |         return_idxs: returns sample_idx from the underlying dataframe and start_idx and end_idx within the sequence (for aligned sequences e.g. spectra or for certain contrastive approaches such as ts2vec)
 71 |         """
 72 |         super().__init__()
 73 |         assert not((hparams.memmap_filename is not None) and (hparams.npy_data is not None))
 74 |         # require integer entries if using memmap or npy
 75 |         assert (hparams.memmap_filename is None and hparams.npy_data is None) or (hparams.df[hparams.col_data].dtype==np.int64 or hparams.df[hparams.col_data].dtype==np.int32 or hparams.df[hparams.col_data].dtype==np.int16)
 76 |         # keys (in column data) have to be unique
 77 |         assert(hparams.allow_multiple_keys or len(hparams.df[hparams.col_data].unique())==len(hparams.df))
 78 | 
 79 |         self.timeseries_df_data = np.array(hparams.df[hparams.col_data])
 80 |         if(self.timeseries_df_data.dtype not in [np.int16, np.int32, np.int64]):
 81 |             assert(hparams.memmap_filename is None and hparams.npy_data is None) #only for filenames in mode files
 82 |             self.timeseries_df_data = np.array(hparams.df[hparams.col_data].astype(str)).astype(np.bytes_)  # np.string_ was removed in NumPy 2.0
 83 | 
 84 |         if(hparams.col_lbl is None):# use dummy labels
 85 |             self.timeseries_df_label = np.zeros(len(hparams.df))
 86 |         else: # use actual labels
 87 |             if(isinstance(hparams.df[hparams.col_lbl].iloc[0],list) or isinstance(hparams.df[hparams.col_lbl].iloc[0],np.ndarray)):#stack arrays/lists for proper batching
 88 |                 self.timeseries_df_label = np.stack(hparams.df[hparams.col_lbl])
 89 |             else: # single integers/floats
 90 |                 self.timeseries_df_label = np.array(hparams.df[hparams.col_lbl])
 91 | 
 92 |             if(not(hparams.annotation and hparams.memmap_filename is not None)):#skip if memmap and annotation        
 93 |                 if(self.timeseries_df_label.dtype not in [np.int16, np.int32, np.int64, np.float32, np.float64]): #everything else cannot be batched anyway mp.Manager().list(self.timeseries_df_label)
 94 |                     assert(hparams.annotation and hparams.memmap_filename is None and hparams.npy_data is None)#only for filenames in mode files
 95 |                     self.timeseries_df_label = np.array(hparams.df[hparams.col_lbl].apply(lambda x:str(x))).astype(np.bytes_)
 96 | 
 97 |         def concat_columns(row):
 98 |             return [item for col in row for item in (col if isinstance(col, np.ndarray) else [col])]
 99 | 
100 |         if(hparams.cols_static is not None):
101 |             self.timeseries_df_static = np.array(hparams.df[hparams.cols_static].apply(concat_columns, axis=1).to_list())
102 |             self.timeseries_df_static = np.squeeze(self.timeseries_df_static,axis=1)#remove unit axis
103 |             self.static = True
104 |         else:
105 |             self.static = False
106 | 
107 |         if(hparams.cols_static_cat is not None):
108 |             self.timeseries_df_static_cat = np.array(hparams.df[hparams.cols_static_cat].apply(concat_columns, axis=1).to_list())
109 |             self.timeseries_df_static_cat = np.squeeze(self.timeseries_df_static_cat,axis=1)#remove unit axis
110 |             self.static_cat = True
111 |         else:
112 |             self.static_cat = False
113 | 
114 |         self.output_size = hparams.output_size
115 |         self.data_folder = hparams.data_folder
116 |         self.transforms = Compose(hparams.transforms) if isinstance(hparams.transforms,list) else hparams.transforms
117 |         #if(isinstance(self.transforms,list) or isinstance(self.transforms,np.ndarray)):
118 |         #    print("Warning: the use of lists as arguments for transforms is discouraged")
119 |         self.annotation = hparams.annotation
120 |         self.col_lbl = hparams.col_lbl
121 | 
122 |         self.mode="files"
123 |         self.fs_annotation_over_fs_data = hparams.fs_annotation_over_fs_data
124 |         self.return_idxs = hparams.return_idxs
125 |         
126 |         if(hparams.memmap_filename is not None):
127 |             self.memmap_meta_filename = hparams.memmap_filename.parent/(hparams.memmap_filename.stem+"_meta.npz")
128 |             self.mode="memmap"
129 |             memmap_meta = np.load(self.memmap_meta_filename, allow_pickle=True)
130 |             self.memmap_start = memmap_meta["start"].astype(np.int64)# cast as integers to be on the safe side
131 |             self.memmap_shape = memmap_meta["shape"].astype(np.int64)
132 |             self.memmap_length = memmap_meta["length"].astype(np.int64)
133 |             self.memmap_file_idx = memmap_meta["file_idx"].astype(np.int64)
134 |             self.memmap_dtype = np.dtype(str(memmap_meta["dtype"]))
135 |             self.memmap_filenames = np.array(memmap_meta["filenames"]).astype(np.bytes_)#save as byte to avoid issue with mp
136 |             if(hparams.annotation):
137 |                 #by default use the memmap_label.npy in the same directory as the signal memmap file
138 |                 memmap_label_filename=hparams.memmap_label_filename if hparams.memmap_label_filename is not None else self.memmap_meta_filename.parent/("_".join(self.memmap_meta_filename.stem.split("_")[:-1])+"_label.npy")
139 |                 self.memmap_meta_filename_label = hparams.memmap_filename.parent/(memmap_label_filename.stem+"_meta.npz")
140 |                 memmap_meta_label =np.load(self.memmap_meta_filename_label, allow_pickle=True)
141 |                 self.memmap_start_label = memmap_meta_label["start"].astype(np.int64)
142 |                 self.memmap_shape_label = memmap_meta_label["shape"].astype(np.int64)
143 |                 self.memmap_length_label = memmap_meta_label["length"].astype(np.int64)
144 |                 self.memmap_file_idx_label = memmap_meta_label["file_idx"].astype(np.int64)
145 |                 self.memmap_dtype_label = np.dtype(str(memmap_meta_label["dtype"]))
146 |                 self.memmap_filenames_label = np.array(memmap_meta_label["filenames"]).astype(np.bytes_)
147 |         elif(hparams.npy_data is not None):
148 |             self.mode="npy"
149 |             if(isinstance(hparams.npy_data,np.ndarray) or isinstance(hparams.npy_data,list)):
150 |                 self.npy_data = np.array(hparams.npy_data)
151 |                 assert(hparams.annotation is False)
152 |             else:
153 |                 self.npy_data = np.load(hparams.npy_data, allow_pickle=True)
154 |             if(hparams.annotation):
155 |                 self.npy_data_label = np.load(hparams.npy_data.parent/(hparams.npy_data.stem+"_label.npy"), allow_pickle=True)
156 | 
157 |         self.random_crop = hparams.random_crop
158 |         self.sample_items_per_record = hparams.sample_items_per_record
159 | 
160 |         self.df_idx_mapping=[]
161 |         self.start_idx_mapping=[]
162 |         self.end_idx_mapping=[]
163 | 
164 |         for df_idx,(id,row) in enumerate(hparams.df.iterrows()):
165 |             if(self.mode=="files"):
166 |                 data_length = row["data_length"]
167 |             elif(self.mode=="memmap"):
168 |                 data_length= self.memmap_length[row[hparams.col_data]]
169 |             else: #npy
170 |                 data_length = len(self.npy_data[row[hparams.col_data]])
171 | 
172 |             if(hparams.chunk_length == 0):#do not split
173 |                 idx_start = [hparams.start_idx]
174 |                 idx_end = [data_length]
175 |             else:
176 |                 idx_start = list(range(hparams.start_idx,data_length,hparams.chunk_length if hparams.stride is None else hparams.stride))
177 |                 idx_end = [min(l+hparams.chunk_length, data_length) for l in idx_start]
178 | 
179 |             #remove final chunk(s) if too short
180 |             for i in range(len(idx_start)):
181 |                 if(idx_end[i]-idx_start[i]< hparams.min_chunk_length):
182 |                     del idx_start[i:]
183 |                     del idx_end[i:]
184 |                     break
185 |             #append to lists
186 |             for _ in range(hparams.copies+1):
187 |                 for i_s,i_e in zip(idx_start,idx_end):
188 |                     self.df_idx_mapping.append(df_idx)
189 |                     self.start_idx_mapping.append(i_s)
190 |                     self.end_idx_mapping.append(i_e)
191 |         #convert to np.array to avoid mp issues with python lists
192 |         self.df_idx_mapping = np.array(self.df_idx_mapping)
193 |         self.start_idx_mapping = np.array(self.start_idx_mapping)
194 |         self.end_idx_mapping = np.array(self.end_idx_mapping)
195 |             
196 |     def __len__(self):
197 |         return len(self.df_idx_mapping)
198 | 
199 |     @property
200 |     def is_empty(self):
201 |         return len(self.df_idx_mapping)==0
202 | 
203 |     def __getitem__(self, idx):
204 |         lst=[]
205 |         for _ in range(self.sample_items_per_record):
206 |             #determine crop idxs
207 |             timesteps= self.get_sample_length(idx)
208 | 
209 |             if(self.random_crop):#random crop
210 |                 if(timesteps==self.output_size):
211 |                     start_idx_rel = 0
212 |                 else:
213 |                     start_idx_rel = random.randint(0, timesteps - self.output_size -1)#np.random.randint(0, timesteps - self.output_size)
214 |             else:
215 |                 start_idx_rel =  (timesteps - self.output_size)//2
216 |             if(self.sample_items_per_record==1):
217 |                 return self._getitem(idx,start_idx_rel)
218 |             else:
219 |                 lst.append(self._getitem(idx,start_idx_rel))
220 |         return tuple(lst)
221 | 
222 |     def _getitem(self, idx,start_idx_rel):
223 |         #low-level function that actually fetches the data
224 |         df_idx = self.df_idx_mapping[idx]
225 |         start_idx = self.start_idx_mapping[idx]
226 |         end_idx = self.end_idx_mapping[idx]
227 |         #determine crop idxs
228 |         timesteps= end_idx - start_idx
229 |         assert(timesteps>=self.output_size)
230 |         start_idx_crop = start_idx + start_idx_rel
231 |         end_idx_crop = start_idx_crop+self.output_size
232 |         if(self.annotation):
233 |             start_idx_crop_label = int(np.round(start_idx_crop*self.fs_annotation_over_fs_data))
234 |             end_idx_crop_label = start_idx_crop_label+int(np.round(self.output_size*self.fs_annotation_over_fs_data))
235 | 
236 |         #print(idx,start_idx,end_idx,start_idx_crop,end_idx_crop)
237 |         #load the actual data
238 |         if(self.mode=="files"):#from separate files
239 |             data_filename = pathlib.Path(str(self.timeseries_df_data[df_idx],encoding='utf-8'))  # Path so that .stem below also works when data_folder is None
240 |             if self.data_folder is not None:
241 |                 data_filename = self.data_folder/data_filename
242 |             data = np.load(data_filename, allow_pickle=True)[start_idx_crop:end_idx_crop] #data type has to be adjusted when saving to npy
243 | 
244 |             ID = data_filename.stem
245 | 
246 |             if(self.annotation is True):
247 |                 label_filename = str(self.timeseries_df_label[df_idx],encoding='utf-8')
248 |                 if self.data_folder is not None:
249 |                     label_filename = self.data_folder/label_filename
250 |                 label = np.load(label_filename, allow_pickle=True)[start_idx_crop_label:end_idx_crop_label] #data type has to be adjusted when saving to npy
251 |             else:
252 |                 label = self.timeseries_df_label[df_idx] #input type has to be adjusted in the dataframe
253 |         elif(self.mode=="memmap"): #from one memmap file
254 |             memmap_idx = self.timeseries_df_data[df_idx] #grab the actual index (Note the df to create the ds might be a subset of the original df used to create the memmap)
255 |             memmap_file_idx = self.memmap_file_idx[memmap_idx]
256 |             idx_offset = self.memmap_start[memmap_idx]
257 | 
258 |             #wi = torch.utils.data.get_worker_info()
259 |             #pid = 0 if wi is None else wi.id#os.getpid()
260 |             #print("idx",idx,"ID",ID,"idx_offset",idx_offset,"start_idx_crop",start_idx_crop,"df_idx", self.df_idx_mapping[idx],"pid",pid)
261 |             mem_filename = str(self.memmap_filenames[memmap_file_idx],encoding='utf-8')
262 |             mem_file = np.memmap(self.memmap_meta_filename.parent/mem_filename, self.memmap_dtype, mode='r', shape=tuple(self.memmap_shape[memmap_file_idx]))
263 |             data = np.copy(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
264 |             del mem_file
265 |             #print(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
266 |             if(self.annotation):
267 |                 memmap_file_idx_label = self.memmap_file_idx_label[memmap_idx]
268 |                 idx_offset_label = self.memmap_start_label[memmap_idx]
269 | 
270 |                 mem_filename_label = str(self.memmap_filenames_label[memmap_file_idx_label],encoding='utf-8')
271 |                 mem_file_label = np.memmap(self.memmap_meta_filename_label.parent/mem_filename_label, self.memmap_dtype_label, mode='r', shape=tuple(self.memmap_shape_label[memmap_file_idx_label]))
272 |                 
273 |                 label = np.copy(mem_file_label[idx_offset_label + start_idx_crop_label: idx_offset_label + end_idx_crop_label])
274 |                 del mem_file_label
275 |             else:
276 |                 label = self.timeseries_df_label[df_idx]
277 |         else:#single npy array
278 |             ID = self.timeseries_df_data[df_idx]
279 | 
280 |             data = self.npy_data[ID][start_idx_crop:end_idx_crop]
281 | 
282 |             if(self.annotation):
283 |                 label = self.npy_data_label[ID][start_idx_crop_label:end_idx_crop_label]
284 |             else:
285 |                 label = self.timeseries_df_label[df_idx]
286 | 
287 |         sample = (data, label, self.timeseries_df_static[df_idx] if self.static else None, self.timeseries_df_static_cat[df_idx] if self.static_cat else None,np.array([df_idx,start_idx_crop,end_idx_crop]))
288 |         
289 |         # consistency check: make sure that data and annotation lengths match (check here because transforms might change the shape of the annotation)
290 |         assert(self.annotation is False or len(sample[1])==int(np.round(self.fs_annotation_over_fs_data*len(sample[0]))))
291 |         sample = self.transforms(sample)
292 | 
293 |         if(self.return_idxs):
294 |             if(self.static is True and self.static_cat is True):
295 |                 return tsdata_seq_static_cat_idxs(sample[0],sample[1], sample[2], sample[3], sample[4])
296 |             elif(self.static is True):
297 |                 return tsdata_seq_static_idxs(sample[0],sample[1], sample[2], sample[4])
298 |             elif(self.static_cat is True):
299 |                 return tsdata_seq_cat_idxs(sample[0],sample[1], sample[3], sample[4])
300 |             else:
301 |                 return tsdata_seq_idxs(sample[0], sample[1], sample[4])
302 |         else:
303 |             if(self.static is True and self.static_cat is True):
304 |                 return tsdata_seq_static_cat(sample[0],sample[1], sample[2], sample[3])
305 |             elif(self.static is True):
306 |                 return tsdata_seq_static(sample[0],sample[1], sample[2])
307 |             elif(self.static_cat is True):
308 |                 return tsdata_seq_cat(sample[0],sample[1], sample[3])
309 |             else:
310 |                 return tsdata_seq(sample[0], sample[1])
311 |         
312 | 
313 |     def get_sampling_weights(self, class_weight_dict,length_weighting=False, timeseries_df_group_by_col=None):
314 |         '''
315 |         class_weight_dict: dictionary of class weights
316 |         length_weighting: weigh samples by length
317 |         timeseries_df_group_by_col: column of the pandas df used to create the object'''
318 |         assert(self.annotation is False)
319 |         assert(length_weighting is False or timeseries_df_group_by_col is None)
320 |         weights = np.zeros(len(self.df_idx_mapping),dtype=np.float32)
321 |         length_per_class = {}
322 |         length_per_group = {}
323 |         for iw,(i,s,e) in enumerate(zip(self.df_idx_mapping,self.start_idx_mapping,self.end_idx_mapping)):
324 |             label = self.timeseries_df_label[i]
325 |             weight = class_weight_dict[label]
326 |             if(length_weighting):
327 |                 if label in length_per_class.keys():
328 |                     length_per_class[label] += e-s
329 |                 else:
330 |                     length_per_class[label] = e-s
331 |             if(timeseries_df_group_by_col is not None):
332 |                 group = timeseries_df_group_by_col[i]
333 |                 if group in length_per_group.keys():
334 |                     length_per_group[group] += e-s
335 |                 else:
336 |                     length_per_group[group] = e-s
337 |             weights[iw] = weight
338 | 
339 |         if(length_weighting):#need second pass to properly take into account the total length per class
340 |             for iw,(i,s,e) in enumerate(zip(self.df_idx_mapping,self.start_idx_mapping,self.end_idx_mapping)):
341 |                 label = self.timeseries_df_label[i]
342 |                 weights[iw]= (e-s)/length_per_class[label]*weights[iw]
343 |         if(timeseries_df_group_by_col is not None):
344 |             for iw,(i,s,e) in enumerate(zip(self.df_idx_mapping,self.start_idx_mapping,self.end_idx_mapping)):
345 |                 group = timeseries_df_group_by_col[i]
346 |                 weights[iw]= (e-s)/length_per_group[group]*weights[iw]
347 | 
348 |         weights = weights/np.min(weights)#normalize smallest weight to 1
349 |         return weights
350 | 
351 |     def get_id_mapping(self):
352 |         return self.df_idx_mapping
353 | 
354 |     def get_sample_id(self,idx):
355 |         return self.df_idx_mapping[idx]
356 | 
357 |     def get_sample_length(self,idx):
358 |         return self.end_idx_mapping[idx]-self.start_idx_mapping[idx]
359 | 
360 |     def get_sample_start(self,idx):
361 |         return self.start_idx_mapping[idx]
362 | 
363 |     def aggregate_predictions(self, preds,targs=None,idmap=None,aggregate_fn = np.mean,verbose=False):
364 |         '''
365 |         aggregates potentially multiple predictions per sample (can also pass targs for convenience)
366 |         idmap: idmap as returned by TimeSeriesCropsDataset's get_id_mapping (uses self.get_id_mapping by default)
367 |         preds: ordered predictions as returned by learn.get_preds()
368 |         aggregate_fn: function that is used to aggregate multiple predictions per sample (most commonly np.amax or np.mean)
369 |         '''
370 |         idmap = self.get_id_mapping() if idmap is None else idmap
371 |         if(idmap is not None and len(idmap)!=len(np.unique(idmap))):
372 |             if(verbose):
373 |                 print("aggregating predictions...")
374 |             preds_aggregated = []
375 |             targs_aggregated = []
376 |             for i in np.unique(idmap):
377 |                 preds_local = preds[np.where(idmap==i)[0]]
378 |                 preds_aggregated.append(aggregate_fn(preds_local,axis=0))
379 |                 if targs is not None:
380 |                     targs_local = targs[np.where(idmap==i)[0]]
381 |                     #assert(np.all(targs_local==targs_local[0])) #all labels have to agree
382 |                     assert(np.all([np.array_equal(t, targs_local[0], equal_nan=True) for t in targs_local])) #all labels have to agree (including nans)
383 |                     targs_aggregated.append(targs_local[0])
384 |             if(targs is None):
385 |                 return np.array(preds_aggregated)
386 |             else:
387 |                 return np.array(preds_aggregated),np.array(targs_aggregated)
388 |         else:
389 |             if(targs is None):
390 |                 return preds
391 |             else:
392 |                 return preds,targs
393 | 
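The grouping performed by aggregate_predictions can be sketched as a standalone function. This is a minimal sketch, assuming preds and idmap are NumPy arrays in the same order; aggregate_by_id is a hypothetical name, and unlike the method it always aggregates (also when every record has a single crop) and skips the check that all targets of a record agree.

```python
import numpy as np

def aggregate_by_id(preds, idmap, targs=None, aggregate_fn=np.mean):
    """Hypothetical helper: combine crop-level predictions into one row per record.

    preds: array of shape (n_crops, ...) ordered like idmap
    idmap: array of shape (n_crops,) holding the record index of every crop
    """
    idmap = np.asarray(idmap)
    ids = np.unique(idmap)  # sorted unique record indices
    # reduce all crops of one record along the crop axis
    preds_agg = np.stack([aggregate_fn(preds[idmap == i], axis=0) for i in ids])
    if targs is None:
        return preds_agg
    # all crops of a record share its target, so the first one is kept
    targs_agg = np.stack([targs[idmap == i][0] for i in ids])
    return preds_agg, targs_agg

preds = np.array([[0.2, 0.8], [0.4, 0.6], [0.9, 0.1]])
idmap = np.array([0, 0, 1])
per_record = aggregate_by_id(preds, idmap)  # approximately [[0.3, 0.7], [0.9, 0.1]]
```

Passing aggregate_fn=np.amax instead of np.mean keeps the per-class maximum over the crops of each record, the other option named in the docstring.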
394 | @dataclass
395 | class TimeSeriesDatasetConfig:
396 |     df:pd.DataFrame
397 |     output_size:int
398 |     chunk_length:int
399 |     min_chunk_length:int
400 |     memmap_filename:Union[str,pathlib.PosixPath,None]=None
401 |     memmap_label_filename:Union[str,pathlib.PosixPath,None]=None
402 |     npy_data:Union[np.ndarray,list,None]=None
403 |     random_crop:bool=True
404 |     data_folder:Union[str,pathlib.PosixPath,None]=None
405 |     copies:int=0
406 |     col_data:str="data"
407 |     col_lbl:Union[str,None]="label"
408 |     cols_static:Union[str,None]=None
409 |     cols_static_cat:Union[str,None]=None
410 |     stride:Union[int,None]=None
411 |     start_idx:int=0
412 |     annotation:bool=False
413 |     transforms:Any=None
414 |     sample_items_per_record:int=1
415 |     fs_annotation_over_fs_data:float=1.
416 |     return_idxs:bool=False
417 |     allow_multiple_keys:bool=False #in the df allow multiple rows with identical IDs
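get_sampling_weights above starts from the class weight of each crop's record and, with length_weighting=True, rescales it by the crop's share of the total length of its class before normalizing the smallest weight to 1. A minimal sketch of that arithmetic, assuming labels and crop lengths are already available as flat sequences; sampling_weights is a hypothetical name and the group-based weighting via timeseries_df_group_by_col is omitted.

```python
import numpy as np

def sampling_weights(labels, lengths, class_weight_dict, length_weighting=False):
    """Hypothetical helper mirroring the class/length weighting arithmetic.

    labels: label of the record each crop belongs to
    lengths: crop length in time steps (end_idx - start_idx)
    """
    labels = np.asarray(labels)
    lengths = np.asarray(lengths, dtype=np.float32)
    # first pass: class weight of every crop
    weights = np.array([class_weight_dict[l] for l in labels], dtype=np.float32)
    if length_weighting:
        # second pass: scale by the crop's share of its class's total length
        total = {l: lengths[labels == l].sum() for l in np.unique(labels)}
        weights = weights * lengths / np.array([total[l] for l in labels], dtype=np.float32)
    return weights / weights.min()  # normalize the smallest weight to 1

w = sampling_weights([0, 0, 1], [100, 300, 100], {0: 1.0, 1: 1.0}, length_weighting=True)
```

Weights of this kind are typically handed to a sampler such as torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights)); how the repository consumes them is not shown in this excerpt.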
--------------------------------------------------------------------------------
/src/clinical_ts/data/time_series_dataset_transforms.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['Compose', 'RandomCrop', 'CenterCrop', 'FixedCrop', 'GaussianNoise', 'Resample', 'ToSpectrogram', 'Flatten', 'ToTensor', 'Normalize', 'NormalizeBatch', 'ButterFilter', 'ChannelFilter', 'Transform', 'StaticTransform', 'TupleTransform', 'SequenceToSampleLabelTransform']
  2 | 
  3 | 
  4 | import numpy as np
  5 | import torch
  6 | import torch.utils.data
  7 | 
  8 | import math
  9 | import random
 10 | import resampy
 11 | 
 12 | #from skimage import transform
 13 | 
 14 | #import warnings
 15 | #warnings.filterwarnings("ignore", category=UserWarning)
 16 | 
 17 | from scipy.signal import butter, sosfilt, sosfiltfilt
 18 | from scipy import signal
 19 | 
 20 | #from scipy.interpolate import interp1d
 21 | 
 22 | 
 23 | 
 24 | #def nn_upsample(xin, yin, xout):
 25 | #    '''performs nearest neighbor upsampling of the integer array yin with values at xin for new datapoints at xout'''
 26 | #    f = interp1d(xin,yin, kind="nearest",bounds_error=False,fill_value="extrapolate")
 27 | #    return f(xout).astype(np.int64)
 28 | 
 29 | #def resample_labels(startpts, labels, startpts_to_mid, startpts_new, startpts_to_mid_new):
 30 | #    '''resamples integer labels labels at starpts+startpts_to_mid to new anchor points at startpts_new+startpts_to_mid_new'''
 31 | #    if(isinstance(startpts_to_mid,float) or isinstance(startpts_to_mid,int)):
 32 | #        startpts_to_mid = np.ones_like(startpts)*startpts_to_mid
 33 | #    if(isinstance(startpts_to_mid_new,float) or isinstance(startpts_to_mid_new,int)):
 34 | #        startpts_to_mid_new = np.ones_like(startpts_new)*startpts_to_mid_new
 35 | #    midpts = np.array(startpts)+startpts_to_mid
 36 | #    midpts_new = np.array(startpts_new)+startpts_to_mid_new
 37 | #    return nn_upsample(midpts, labels, midpts_new)
 38 | 
 39 | #https://stackoverflow.com/questions/12093594/how-to-implement-band-pass-butterworth-filter-with-scipy-signal-butter
 40 | def butter_filter(lowcut=10, highcut=20, fs=50, order=5, btype='band'):
 41 |     '''returns butterworth filter with given specifications'''
 42 |     nyq = 0.5 * fs
 43 |     low = lowcut / nyq
 44 |     high = highcut / nyq
 45 | 
 46 |     sos = butter(order, [low, high] if btype=="band" else (low if btype=="low" else high), analog=False, btype=btype, output='sos')
 47 |     return sos
 48 | 
 49 | #def butter_filter_frequency_response(filter):
 50 | #    '''returns frequency response of a given filter (result of call of butter_filter)'''
 51 | #    w, h = sosfreqz(filter)
 52 | #    #gain vs. freq(Hz)
 53 | #    #plt.plot((fs * 0.5 / np.pi) * w, abs(h))
 54 | #    return w,h
 55 | #
 56 | #def apply_butter_filter(data, filter, forwardbackward=True):
 57 | #    '''pass filter from call of butter_filter to data (assuming time axis at dimension 0)'''
 58 | #    if(forwardbackward):
 59 | #        return sosfiltfilt(filter, data, axis=0)
 60 | #    else:
 61 | #        return sosfilt(filter, data, axis=0)
 62 | 
 63 | class Compose:
 64 |     '''composes several transformations into a single one (as provided by torchvision.transforms.Compose)'''
 65 |     def __init__(self, transforms):
 66 |         self.transforms = transforms
 67 | 
 68 |     def __call__(self, inp):
 69 |         for t in self.transforms:
 70 |             inp = t(inp)
 71 |         return inp
 72 | 
 73 | 
 74 | class RandomCrop(object):
 75 |     """Crop randomly from a sample.
 76 |     """
 77 | 
 78 |     def __init__(self, output_size,annotation=False):
 79 |         self.output_size = output_size
 80 |         self.annotation = annotation
 81 | 
 82 |     def __call__(self, sample):
 83 |         data, label, static, static_cat, idxs = sample
 84 | 
 85 |         timesteps= len(data)
 86 |         assert(timesteps>=self.output_size)
 87 |         if(timesteps==self.output_size):
 88 |             start=0
 89 |         else:
 90 |             start = random.randint(0, timesteps - self.output_size) #upper bound is inclusive, so every valid start 0..timesteps-output_size can be drawn
 91 |             idxs[1]+=start
 92 |             idxs[2]=idxs[1]+self.output_size
 93 |         
 94 |         data = data[start: start + self.output_size]
 95 |         if(self.annotation):
 96 |             label = label[start: start + self.output_size]
 97 | 
 98 |         return (data, label, static, static_cat, idxs)
 99 | 
100 | 
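RandomCrop, CenterCrop and the crop logic in TimeSeriesCropsDataset all reduce to choosing a start index and, for annotated data, rescaling it to the label sampling rate. A minimal sketch under those assumptions; sample_crop is a hypothetical helper whose fs_ratio plays the role of fs_annotation_over_fs_data (the dataset applies the same rounding to the absolute start index). Python's random.randint(a, b) includes b, whereas np.random.randint(a, b) excludes it.

```python
import random
import numpy as np

def sample_crop(timesteps, output_size, random_crop=True, fs_ratio=1.0, rng=random):
    """Hypothetical helper: choose a data window and the matching label window.

    fs_ratio: label sampling rate divided by data sampling rate.
    Returns (start, end, start_label, end_label), relative to the record start.
    """
    if timesteps < output_size:
        raise ValueError("record is shorter than output_size")
    if random_crop:
        # random.randint includes both bounds: starts 0 .. timesteps - output_size
        start = rng.randint(0, timesteps - output_size)
    else:
        start = (timesteps - output_size) // 2  # center crop
    end = start + output_size
    # label indices are rescaled to the annotation sampling rate
    start_label = int(np.round(start * fs_ratio))
    end_label = start_label + int(np.round(output_size * fs_ratio))
    return start, end, start_label, end_label

center = sample_crop(10, 4, random_crop=False)  # (3, 7, 3, 7)
```

Passing rng=random.Random(seed) instead of the module-level generator makes the drawn crops reproducible.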
101 | class CenterCrop(object):
102 |     """Center crop from a sample.
103 |     """
104 | 
105 |     def __init__(self, output_size, annotation=False):
106 |         self.output_size = output_size
107 |         self.annotation = annotation
108 | 
109 |     def __call__(self, sample):
110 |         data, label, static, static_cat, idxs = sample
111 | 
112 |         timesteps= len(data)
113 |         start = (timesteps - self.output_size)//2
114 |         idxs[1] += start
115 |         idxs[2] = idxs[1] + self.output_size
116 | 
117 |         data = data[start: start + self.output_size]
118 |         if(self.annotation):
119 |             label = label[start: start + self.output_size]
120 | 
121 |         return (data, label, static, static_cat, idxs)
122 | 
123 | class FixedCrop(object):
124 |     """Take a fixed crop from a sample (for aligned data e.g. spectrograms). start_idx and end_idx are relative to the start of the respective sample
125 |     """
126 | 
127 |     def __init__(self, start_idx, end_idx, annotation=False):
128 |         self.start_idx = start_idx
129 |         self.end_idx = end_idx
130 |         self.output_size = self.end_idx-self.start_idx
131 |         self.annotation = annotation
132 | 
133 |     def __call__(self, sample):
134 |         data, label, static, static_cat, idxs = sample
135 |         assert(self.end_idx [x,y,]ch, seq
256 |                     return torch.from_numpy(np.moveaxis(data,0,-1))
257 |                 else:
258 |                     return torch.from_numpy(data)
259 |             else:#default_collate will take care of it
260 |                 return data
261 | 
262 |         data, label, static, static_cat, idxs = sample
263 |         if not isinstance(data,tuple):
264 |             data = _to_tensor(data,self.transpose_data)
265 |         else:
266 |             data = tuple(_to_tensor(x,self.transpose_data) for x in data)
267 | 
268 |         if not isinstance(label,tuple):
269 |             label = _to_tensor(label,self.transpose_label)
270 |         else:
271 |             label = tuple(_to_tensor(x,self.transpose_label) for x in label)
272 | 
273 |         if not isinstance(static,tuple):
274 |             static = _to_tensor(static)
275 |         else:
276 |             static = tuple(_to_tensor(x) for x in static)
277 | 
278 |         if not isinstance(static_cat,tuple):
279 |             static_cat = _to_tensor(static_cat)
280 |         else:
281 |             static_cat = tuple(_to_tensor(x) for x in static_cat)
282 |         
283 |         if not isinstance(idxs,tuple):
284 |             idxs = _to_tensor(idxs)
285 |         else:
286 |             idxs = tuple(_to_tensor(x) for x in idxs)
287 | 
288 |         return (data, label, static, static_cat, idxs) #returning as a tuple (potentially of lists)
289 | 
290 | 
291 | class Normalize(object):
292 |     """Normalize using given stats.
293 |     """
294 |     def __init__(self, stats_mean, stats_std, input=True, channels=[]):
295 |         self.stats_mean=stats_mean.astype(np.float32) if stats_mean is not None else None
296 |         self.stats_std=stats_std.astype(np.float32)+1e-8 if stats_std is not None else None
297 |         self.input = input
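298 |         if(len(channels)>0):#normalize only the listed channels, leave all others unchanged
299 |             for i in range(self.stats_mean.shape[-1]):#iterate over the channel (last) axis
300 |                 if(not(i in channels)):
301 |                     self.stats_mean[...,i]=0
302 |                     self.stats_std[...,i]=1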
303 | 
304 |     def __call__(self, sample):
305 |         datax, labelx, static, static_cat, idxs = sample
306 |         data = datax if self.input else labelx
307 |         #assuming channel last
308 |         if(self.stats_mean is not None):
309 |             data = data - self.stats_mean
310 |         if(self.stats_std is not None):
311 |             data = data/self.stats_std
312 | 
313 |         if(self.input):
314 |             return (data, labelx, static, static_cat, idxs)
315 |         else:
316 |             return (datax, data, static, static_cat, idxs)
317 | 
318 | 
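Normalize subtracts precomputed channel-last statistics and, if channels is non-empty, leaves every other channel untouched by setting its mean to 0 and its std to 1. A minimal sketch of that masking, assuming mean and std hold one entry per channel; normalize_channels is a hypothetical helper.

```python
import numpy as np

def normalize_channels(data, mean, std, channels=None, eps=1e-8):
    """Hypothetical helper: standardize channel-last data with given statistics.

    Only the channel indices in `channels` are normalized (all if None or empty);
    the remaining channels get mean 0 and std 1, i.e. pass through unchanged.
    """
    mean = np.array(mean, dtype=np.float32)  # np.array copies, caller's arrays stay intact
    std = np.array(std, dtype=np.float32) + eps
    if channels is not None and len(channels) > 0:
        keep = np.setdiff1d(np.arange(mean.shape[-1]), channels)  # unnormalized channels
        mean[..., keep] = 0
        std[..., keep] = 1
    return (data - mean) / std  # broadcasts over the leading (time) axis

x = np.array([[1.0, 10.0], [3.0, 30.0]])  # (time, channels)
out = normalize_channels(x, [2.0, 20.0], [1.0, 10.0], channels=[0])
```

Copying the statistics inside the function avoids modifying arrays shared with other transforms, which in-place masking of stored statistics could otherwise do.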
319 | class NormalizeBatch(object):
320 |     """Normalize using batch statistics.
321 |     axis: tuple of integers of axis numbers to be normalized over (by default everything but the last)
322 |     """
323 |     def __init__(self, input=True, channels=[],axis=None):
324 |         self.channels = channels
325 |         self.channels_keep = None
326 |         self.input = input
327 |         self.axis = axis
328 | 
329 |     def __call__(self, sample):
330 |         datax, labelx, static, static_cat, idxs = sample
331 |         data = datax if self.input else labelx
332 |         #assuming channel last
333 |         #batch_mean = np.mean(data,axis=tuple(range(0,len(data)-1)))
334 |         #batch_std = np.std(data,axis=tuple(range(0,len(data)-1)))+1e-8
335 |         batch_mean = np.mean(data,axis=self.axis if self.axis is not None else tuple(range(0,len(data.shape)-1)))
336 |         batch_std = np.std(data,axis=self.axis if self.axis is not None else tuple(range(0,len(data.shape)-1)))+1e-8
337 | 
338 |         if(len(self.channels)>0):
339 |             if(self.channels_keep is None):
340 |                 self.channels_keep = np.setdiff1d(np.arange(data.shape[-1]),self.channels) #channels that are left unnormalized
341 | 
342 |             batch_mean[self.channels_keep]=0
343 |             batch_std[self.channels_keep]=1
344 | 
345 |         data = (data - batch_mean)/batch_std
346 | 
347 |         if(self.input):
348 |             return (data, labelx, static, static_cat, idxs)
349 |         else:
350 |             return (datax, data, static, static_cat, idxs)
351 | 
352 | 
353 | class ButterFilter(object):
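354 |     """Apply a Butterworth filter along the time axis (axis 0): zero-phase sosfiltfilt if forwardbackward is True, else causal sosfilt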
355 |     """
356 | 
357 |     def __init__(self, lowcut=50, highcut=50, fs=100, order=5, btype='band', forwardbackward=True, input=True):
358 |         self.filter = butter_filter(lowcut,highcut,fs,order,btype)
359 |         self.input = input
360 |         self.forwardbackward = forwardbackward
361 | 
362 |     def __call__(self, sample):
363 |         datax, labelx, static, static_cat, idxs = sample
364 |         data = datax if self.input else labelx
365 | 
366 |         if(self.forwardbackward):
367 |             data = sosfiltfilt(self.filter, data, axis=0)
368 |         else:
369 |             data = sosfilt(self.filter, data, axis=0)
370 | 
371 |         if(self.input):
372 |             return (data, labelx, static, static_cat, idxs)
373 |         else:
374 |             return (datax, data, static, static_cat, idxs)
375 | 
376 | 
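butter_filter converts cutoffs in Hz into fractions of the Nyquist frequency fs/2, and ButterFilter runs the resulting second-order sections along axis 0, by default with sosfiltfilt (forward and backward, hence zero phase). A minimal sketch of the same steps on a synthetic two-channel signal; bandpass and the test signal are illustrative assumptions. With the ButterFilter defaults (lowcut=highcut=50, fs=100) both normalized cutoffs equal 1.0, which scipy.signal.butter rejects for digital filters (it requires 0 < Wn < 1), so explicit cutoffs below fs/2 have to be passed.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def bandpass(data, lowcut, highcut, fs, order=5):
    """Hypothetical helper: zero-phase Butterworth band-pass along axis 0.

    Cutoffs in Hz are divided by the Nyquist frequency fs/2, as in butter_filter.
    """
    nyq = 0.5 * fs
    sos = butter(order, [lowcut / nyq, highcut / nyq], btype="band", output="sos")
    return sosfiltfilt(sos, data, axis=0)

fs = 100.0
t = np.arange(1000) / fs                           # 10 s sampled at 100 Hz
slow = np.sin(2 * np.pi * 0.2 * t)                 # 0.2 Hz drift, below the pass band
fast = np.sin(2 * np.pi * 10.0 * t)                # 10 Hz component, inside the pass band
x = np.stack([slow + fast, slow + fast], axis=-1)  # (time, channels), channel last
y = bandpass(x, lowcut=1.0, highcut=30.0, fs=fs)   # drift removed, 10 Hz kept
```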
377 | class ChannelFilter(object):
378 |     """Select certain channels.
379 |     axis: axis index of the channel axis
380 |     """
381 | 
382 |     def __init__(self, channels=[0], axis=-1, input=True):
383 |         self.channels = channels
384 |         self.input = input
385 |         self.axis = axis
386 | 
387 |     def __call__(self, sample):
388 |         data, label, static, static_cat, idxs = sample
389 |         if(self.input):
390 |             return (np.take(data,self.channels,axis=self.axis), label, static, static_cat, idxs) #(data[...,self.channels], label, static)
391 |         else:
392 |             return (data, np.take(label,self.channels,axis=self.axis), static, static_cat, idxs)
393 | 
394 | 
395 | class Transform(object):
396 |     """Transforms data using a given function, i.e. data_new = func(data) if input is True, else label_new = func(label)
397 |     """
398 | 
399 |     def __init__(self, func, input=False):
400 |         self.func = func
401 |         self.input = input
402 | 
403 |     def __call__(self, sample):
404 |         data, label, static, static_cat, idxs = sample
405 |         if(self.input):
406 |             return (self.func(data), label, static, static_cat, idxs)
407 |         else:
408 |             return (data, self.func(label), static, static_cat, idxs)
409 | 
410 | class StaticTransform(object):
411 |     """Transforms static data using a given function, i.e. static_new, static_cat_new = func(static, static_cat)
412 |     """
413 |     def __init__(self, func):
414 |         self.func = func
415 |         
416 |     def __call__(self, sample):
417 |         data, label, static, static_cat, idxs = sample
418 |         static, static_cat = self.func(static, static_cat)
419 |         return (data, label, static, static_cat, idxs)
420 | 
421 | class TupleTransform(object):
422 |     """Transforms data and label jointly using a given function that returns a tuple, i.e. data_new, label_new = func(data, label, static, static_cat)
423 |     """
424 | 
425 |     def __init__(self, func):
426 |         self.func = func
427 | 
428 |     def __call__(self, sample):
429 |         data,label,static, static_cat,idxs = sample
430 |         data2, label2 = self.func(data,label,static,static_cat)
431 |         return  (data2, label2, static, static_cat, idxs)
432 | 
433 | class SequenceToSampleLabelTransform(object):
434 |     """Transforms per-time-step labels into segment-level (sample-level) labels
435 |     majority_vote: pick the most frequent label as segment label (i.e. suitable for single-label classification)
436 |     num_classes: number of output classes (used to one-hot encode integer labels)
437 |     binary: output class presence (0/1) instead of the fraction of time steps per class (ignored if majority_vote)
438 |     epoch_length: split the original sequence into ts//epoch_length segments (0: a single segment covering the whole sequence)
439 |     """
440 | 
441 |     def __init__(self, majority_vote=False, num_classes=2, binary=False,epoch_length=0):
442 |         self.majority_vote = majority_vote
443 |         self.num_classes = num_classes
444 |         self.binary = binary
445 |         self.epoch_length = epoch_length
446 | 
447 |     def __call__(self, sample):
448 |         data, label, static, static_cat, idxs = sample
449 |         
450 |         epoch_length = self.epoch_length if self.epoch_length>0 else len(label)
451 |         if(len(label.shape)==1):#each time step is single-labeled
452 |             label = np.eye(self.num_classes)[label] #now label has shape ts,num_classes
453 |         cnts = np.sum(label.reshape((-1,epoch_length,label.shape[-1])),axis=1)#segments,classes
454 |         if(self.majority_vote):
455 |             label = np.argmax(cnts,axis=-1)#segments
456 |         else:
457 |             if(self.binary):
458 |                 label=(cnts>0).astype(np.float32)
459 |             else:
460 |                 label = (cnts/epoch_length).astype(np.float32)
461 |         if(self.epoch_length>0):
462 |             return (data, label, static, static_cat, idxs)
463 |         else:#just one segment
464 |             return (data, label[0], static, static_cat, idxs)
465 | 
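SequenceToSampleLabelTransform one-hot encodes integer labels, sums them per segment of epoch_length time steps, and then reduces the counts by majority vote, presence, or fraction. A minimal standalone sketch of that reduction, assuming the sequence length is a multiple of epoch_length (reshape raises otherwise); sequence_to_epoch_labels is a hypothetical name.

```python
import numpy as np

def sequence_to_epoch_labels(label, num_classes, epoch_length=0, majority_vote=False, binary=False):
    """Hypothetical helper: collapse per-time-step labels into one label per segment.

    label: int array of shape (ts,) or multi-hot array of shape (ts, num_classes)
    epoch_length: segment length in time steps; 0 means one segment for the whole sequence
    """
    label = np.asarray(label)
    length = epoch_length if epoch_length > 0 else len(label)
    if label.ndim == 1:
        label = np.eye(num_classes)[label]  # (ts, num_classes) one-hot
    counts = label.reshape(-1, length, label.shape[-1]).sum(axis=1)  # (segments, classes)
    if majority_vote:
        out = counts.argmax(axis=-1)  # most frequent class per segment
    elif binary:
        out = (counts > 0).astype(np.float32)  # class present anywhere in the segment
    else:
        out = (counts / length).astype(np.float32)  # fraction of the segment per class
    return out if epoch_length > 0 else out[0]

lab = np.array([0, 0, 1, 2, 2, 2])
votes = sequence_to_epoch_labels(lab, 3, epoch_length=3, majority_vote=True)  # [0, 2]
```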
--------------------------------------------------------------------------------
/src/clinical_ts/data/time_series_dataset_utils.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['save_dataset', 'load_dataset', 'dataset_add_chunk_col', 'dataset_add_length_col', 'dataset_add_labels_col', 'dataset_add_mean_col', 'dataset_add_median_col', 'dataset_add_std_col', 'dataset_add_iqr_col', 'dataset_get_stats', 'append_to_memmap', 'append_to_df_memmap', 'npys_to_memmap_batched', 'npys_to_memmap', 'reformat_as_memmap']
  2 | 
  3 | import numpy as np
  4 | import pandas as pd
  5 | 
  6 | from pathlib import Path
  7 | from scipy.stats import iqr
  8 | 
  9 | 
 10 | #workaround for windows pickles
 11 | from sys import platform
 12 | import pathlib
 13 | if platform == "linux" or platform == "linux2":
 14 |     pathlib.WindowsPath = pathlib.PosixPath
 15 | 
 16 | try:
 17 |     import pickle5 as pickle
 18 | except ImportError as e:
 19 |     import pickle
 20 | 
 21 | from tqdm.auto import tqdm
 22 | 
 23 | 
 24 | 
 25 | 
 26 | def save_dataset(df,lbl_itos,mean,std,target_root,filename_postfix="",protocol=4):
 27 |     target_root = Path(target_root)
 28 |     df.to_pickle(target_root/("df"+filename_postfix+".pkl"), protocol=protocol)
 29 | 
 30 |     if(isinstance(lbl_itos,dict)):#dict as pickle
 31 |         outfile = open(target_root/("lbl_itos"+filename_postfix+".pkl"), "wb")
 32 |         pickle.dump(lbl_itos, outfile, protocol=protocol)
 33 |         outfile.close()
 34 |     else:#array
 35 |         np.save(target_root/("lbl_itos"+filename_postfix+".npy"),lbl_itos)
 36 | 
 37 |     np.save(target_root/("mean"+filename_postfix+".npy"),mean)
 38 |     np.save(target_root/("std"+filename_postfix+".npy"),std)
 39 | 
 40 | def load_dataset(target_root,filename_postfix="",df_mapped=True):
 41 |     target_root = Path(target_root)
 42 | 
 43 |     if(df_mapped):
 44 |         df = pd.read_pickle(target_root/("df_memmap"+filename_postfix+".pkl"))
 45 |     else:
 46 |         df = pd.read_pickle(target_root/("df"+filename_postfix+".pkl"))
 47 | 
 48 | 
 49 |     if((target_root/("lbl_itos"+filename_postfix+".pkl")).exists()):#dict as pickle
 50 |         infile = open(target_root/("lbl_itos"+filename_postfix+".pkl"), "rb")
 51 |         lbl_itos=pickle.load(infile)
 52 |         infile.close()
 53 |     else:#array
 54 |         lbl_itos = np.load(target_root/("lbl_itos"+filename_postfix+".npy"))
 55 | 
 56 | 
 57 |     mean = np.load(target_root/("mean"+filename_postfix+".npy"))
 58 |     std = np.load(target_root/("std"+filename_postfix+".npy"))
 59 |     return df, lbl_itos, mean, std
 60 | 
 61 | 
 62 | def dataset_add_chunk_col(df, col="data"):
 63 |     '''add a chunk column to the dataset df'''
 64 |     df["chunk"]=df.groupby(col).cumcount()
 65 | 
 66 | def dataset_add_length_col(df, col="data", data_folder=None):
 67 |     '''add a length column to the dataset df'''
 68 |     df[col+"_length"]=df[col].apply(lambda x: len(np.load(x if data_folder is None else data_folder/x, allow_pickle=True)))
 69 | 
 70 | def dataset_add_labels_col(df, col="label", data_folder=None):
 71 |     '''add a column with unique labels in column col'''
 72 |     df[col+"_labels"]=df[col].apply(lambda x: list(np.unique(np.load(x if data_folder is None else data_folder/x, allow_pickle=True))))
 73 | 
 74 | def dataset_add_mean_col(df, col="data", axis=(0), data_folder=None):
 75 |     '''adds a column with mean'''
 76 |     df[col+"_mean"]=df[col].apply(lambda x: np.mean(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis))
 77 | 
 78 | def dataset_add_median_col(df, col="data", axis=(0), data_folder=None):
 79 |     '''adds a column with median'''
 80 |     df[col+"_median"]=df[col].apply(lambda x: np.median(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis))
 81 | 
 82 | def dataset_add_std_col(df, col="data", axis=(0), data_folder=None):
 83 |     '''adds a column with the standard deviation'''
 84 |     df[col+"_std"]=df[col].apply(lambda x: np.std(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis))
 85 | 
 86 | def dataset_add_iqr_col(df, col="data", axis=(0), data_folder=None):
 87 |     '''adds a column with the interquartile range'''
 88 |     df[col+"_iqr"]=df[col].apply(lambda x: iqr(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis))
 89 | 
 90 | def dataset_get_stats(df, col="data", simple=True):
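 91 |     '''returns dataset mean and std: plain averages of the per-record mean and std cols if simple, otherwise length-weighted pooled statistics computed from the mean, std and length cols'''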
 92 |     if(simple):
 93 |         return df[col+"_mean"].mean(), df[col+"_std"].mean()
 94 |     else:
 95 |         #https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html
 96 |         #or https://gist.github.com/thomasbrandon/ad5b1218fc573c10ea4e1f0c63658469
 97 |         def combine_two_means_vars(x1,x2):
 98 |             (mean1,var1,n1) = x1
 99 |             (mean2,var2,n2) = x2
100 |             mean = mean1*n1/(n1+n2)+ mean2*n2/(n1+n2)
101 |             var = var1*n1/(n1+n2)+ var2*n2/(n1+n2)+n1*n2/(n1+n2)/(n1+n2)*np.power(mean1-mean2,2)
102 |             return (mean, var, (n1+n2))
103 | 
104 |         def combine_all_means_vars(means,vars,lengths):
105 |             inputs = list(zip(means,vars,lengths))
106 |             result = inputs[0]
107 | 
108 |             for inputs2 in inputs[1:]:
109 |                 result= combine_two_means_vars(result,inputs2)
110 |             return result
111 | 
112 |         means = list(df[col+"_mean"])
113 |         vars = np.power(list(df[col+"_std"]),2)
114 |         lengths = list(df[col+"_length"])
115 |         mean,var,length = combine_all_means_vars(means,vars,lengths)
116 |         return mean, np.sqrt(var)
117 | 
118 | 
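With simple=False, dataset_get_stats folds per-record (mean, variance, length) triples together pairwise; with simple=True it only averages the per-record means and stds, which ignores record lengths and the spread between record means. A minimal sketch that checks the pairwise formula against statistics of the concatenated data, assuming population variances (ddof=0, as np.std uses by default); combine_two and the random records are illustrative.

```python
import numpy as np
from functools import reduce

def combine_two(a, b):
    """Merge (mean, population variance, count) summaries of two disjoint chunks."""
    m1, v1, n1 = a
    m2, v2, n2 = b
    n = n1 + n2
    mean = m1 * n1 / n + m2 * n2 / n
    var = v1 * n1 / n + v2 * n2 / n + n1 * n2 / n**2 * (m1 - m2) ** 2
    return mean, var, n

# per-record summaries, analogous to the *_mean, squared *_std and *_length columns
rng = np.random.default_rng(0)
records = [rng.normal(loc=i, scale=1 + i, size=(50 * (i + 1), 2)) for i in range(3)]
summaries = [(r.mean(axis=0), r.var(axis=0), len(r)) for r in records]

mean, var, n = reduce(combine_two, summaries)
std = np.sqrt(var)
full = np.concatenate(records, axis=0)  # reference: statistics over all time steps at once
```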
119 | def npys_to_memmap_batched(npys, target_filename, max_len=0, delete_npys=True, batched_npy=False, batch_length=900000):
120 |     '''
121 |     analogous to npys_to_memmap but processes batches of files before flushing them into memmap for faster processing
122 |     '''
123 |     memmap = None
124 |     start = np.array([0])#start_idx in current memmap file (always already the next start- delete last token in the end)
125 |     length = []#length of segment
126 |     filenames= []#memmap files
127 |     file_idx=[]#corresponding memmap file for sample
128 |     shape=[]#shapes of all memmap files
129 | 
130 |     data = []
131 |     data_lengths=[]
132 |     dtype = None
133 | 
134 |     target_filename = Path(target_filename)
135 | 
136 |     for idx,npy in tqdm(enumerate(npys),total=len(npys)):
137 |         data_batched = np.load(npy, allow_pickle=True)
138 | 
139 |         for data_tmp in (tqdm(data_batched,leave=False) if batched_npy else [data_batched]):
140 |             data.append(data_tmp)
141 |             data_lengths.append(len(data[-1]))
142 |             if(idx==len(npys)-1 or np.sum(data_lengths)>batch_length):#flush
143 |                 data = np.concatenate(data,axis=0)#concatenate along time axis (still axis 0 at this stage)
144 |                 if(memmap is None or (max_len>0 and start[-1]>max_len)):#new memmap file has to be created
145 |                     if(max_len>0):
146 |                         filenames.append(target_filename.parent/(target_filename.stem+"_"+str(len(filenames))+".npy"))
147 |                     else:
148 |                         filenames.append(target_filename)
149 | 
150 |                     shape.append([np.sum(data_lengths)]+[l for l in data.shape[1:]])#insert present shape
151 | 
152 |                     if(memmap is not None):#an existing memmap exceeded max_len
153 |                         del memmap
154 |                     #create new memmap
155 |                     start[-1] = 0
156 |                     start = np.concatenate([start,np.cumsum(data_lengths)])
157 |                     length = np.concatenate([length,data_lengths])
158 | 
159 |                     memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape)
160 |                 else:
161 |                     #append to existing memmap
162 |                     start = np.concatenate([start,start[-1]+np.cumsum(data_lengths)])
163 |                     length = np.concatenate([length,data_lengths])
164 |                     shape[-1] = [start[-1]]+[l for l in data.shape[1:]]
165 |                     memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple(shape[-1]))
166 | 
167 |                 #store mapping memmap_id to memmap_file_id
168 |                 file_idx=np.concatenate([file_idx,[(len(filenames)-1)]*len(data_lengths)])
169 |                 #insert the actual data
170 |                 memmap[start[-len(data_lengths)-1]:start[-len(data_lengths)-1]+len(data)]=data[:]
171 |                 memmap.flush()
172 |                 dtype = data.dtype
173 |                 data = []#reset data storage
174 |                 data_lengths = []
175 | 
176 |     start= start[:-1]#remove the last element
177 |     #cleanup
178 |     for npy in npys:
179 |         if(delete_npys is True):
180 |             npy.unlink()
181 |     del memmap
182 | 
183 |     #convert everything to relative paths
184 |     filenames= [f.name for f in filenames]
185 |     #save metadata
186 |     np.savez(target_filename.parent/(target_filename.stem+"_meta.npz"),start=start,length=length,shape=shape,file_idx=file_idx,dtype=dtype,filenames=filenames)
187 | 
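The converters in this file store all samples back to back along the first (time) axis of a raw `np.memmap` file and keep `start`/`length` arrays (together with `file_idx`, `shape`, `dtype` and `filenames` in `*_meta.npz`) to locate them again. A minimal single-file sketch of that layout, ignoring `max_len` splitting and batching; `write_segments` and `read_segment` are hypothetical helpers for illustration only.

```python
import tempfile
from pathlib import Path
import numpy as np

def write_segments(arrays, filename):
    """Concatenate variable-length arrays along axis 0 into one memmap file
    and return the (start, length) index needed to slice them out again."""
    lengths = np.array([len(a) for a in arrays], dtype=np.int64)
    starts = np.concatenate(([0], np.cumsum(lengths)[:-1])).astype(np.int64)
    shape = (int(lengths.sum()),) + arrays[0].shape[1:]
    mm = np.memmap(filename, dtype=arrays[0].dtype, mode="w+", shape=shape)
    for a, s in zip(arrays, starts):
        mm[s:s + len(a)] = a
    mm.flush()
    del mm  # release the file handle
    return starts, lengths, shape

def read_segment(filename, idx, starts, lengths, shape, dtype):
    """Open the memmap read-only and return a copy of segment idx."""
    mm = np.memmap(filename, dtype=dtype, mode="r", shape=shape)
    return np.array(mm[starts[idx]:starts[idx] + lengths[idx]])

arrays = [np.arange(6, dtype=np.float32).reshape(3, 2), np.ones((5, 2), dtype=np.float32)]
with tempfile.TemporaryDirectory() as tmp:
    fn = Path(tmp) / "data.npy"
    starts, lengths, shape = write_segments(arrays, fn)
    seg0 = read_segment(fn, 0, starts, lengths, shape, np.float32)
    seg1 = read_segment(fn, 1, starts, lengths, shape, np.float32)
print(starts, lengths, shape)
```

Note that the resulting file is raw binary without an `.npy` header, so reading it back requires the stored `shape` and `dtype`.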
188 | def append_to_memmap(memmap1, memmap2, file_id1=0, file_id2=0):
189 |     '''
190 |     appends the contents of file file_id2 of memmap2 to file file_id1 of memmap1 (file ids only matter for memmaps split via max_len) and updates the metadata of memmap1
191 |     '''
192 |     memmap1 = Path(memmap1)
193 |     memmap2 = Path(memmap2)
194 | 
195 |     meta1 = np.load(memmap1.parent/(memmap1.stem+"_meta.npz"),allow_pickle=True)
196 |     meta2 = np.load(memmap2.parent/(memmap2.stem+"_meta.npz"),allow_pickle=True)
197 |     meta1_shape = meta1["shape"][file_id1]
198 |     meta2_shape = meta2["shape"][file_id2]
199 |     assert((len(meta1_shape)==1 and len(meta2_shape)==1) or np.array_equal(meta1_shape[1:],meta2_shape[1:]))#shapes have to match except along the time axis
200 |     assert(meta1["dtype"]==meta2["dtype"])#dtypes have to agree
201 |     mask = np.where(meta2["file_idx"]==file_id2)[0]
202 |     lengths2 = np.array(meta2["length"])[mask]
203 |     shape = np.concatenate(([meta1_shape[0]+np.sum(lengths2)],meta1_shape[1:])).astype(np.int64)
204 |     full_shape=[(shape if i==file_id1 else m) for i,m in enumerate(meta1["shape"])]
205 |     starts2 = meta1_shape[0]+np.concatenate(([0],np.cumsum(lengths2)[:-1]))
206 |     start = np.concatenate((meta1["start"],starts2))
207 |     length = np.concatenate((meta1["length"],lengths2))
208 |     file_idx= np.concatenate((meta1["file_idx"],np.array([file_id1]*len(mask))))
209 |     print("Appending",memmap2,"to",memmap1,"...")
210 |     memmap_extended = np.memmap(memmap1.parent/(meta1["filenames"][file_id1]), dtype=np.dtype(str(meta1["dtype"])), mode='r+', shape=tuple(shape))
211 |     memmap_source = np.memmap(memmap2.parent/(meta2["filenames"][file_id2]), dtype=np.dtype(str(meta2["dtype"])), mode="r", shape=tuple(meta2_shape))
212 |     memmap_extended[meta1_shape[0]:] = memmap_source
213 |     memmap_extended.flush()
214 | 
215 |     np.savez(memmap1.parent/(memmap1.stem+"_meta.npz"),start=start,length=length,shape=full_shape,file_idx=file_idx,dtype=meta1["dtype"],filenames=meta1["filenames"])
216 |     print("done.")
217 | 
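`append_to_memmap` places the new segments directly after the existing `meta1_shape[0]` time steps of the target file and extends `start`, `length` and `file_idx` accordingly. The same index arithmetic in isolation, without any file I/O (`appended_index` is a hypothetical helper name):

```python
import numpy as np

def appended_index(start1, length1, file_idx1, n1, lengths2, file_id1=0):
    """Index bookkeeping when appending segments with lengths2 to the end of
    memmap file file_id1, which currently holds n1 time steps."""
    lengths2 = np.asarray(lengths2, dtype=np.int64)
    # new segments start right after the existing data, back to back
    starts2 = n1 + np.concatenate(([0], np.cumsum(lengths2)[:-1]))
    start = np.concatenate((start1, starts2))
    length = np.concatenate((length1, lengths2))
    file_idx = np.concatenate((file_idx1, np.full(len(lengths2), file_id1)))
    return start, length, file_idx, n1 + int(lengths2.sum())

start, length, file_idx, n = appended_index(np.array([0, 4]), np.array([4, 6]), np.array([0, 0]), 10, [3, 2])
print(start, length, file_idx, n)
```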
218 | def append_to_df_memmap(path_df_memmap1,path_df_memmap2,path_memmap1,path_memmap2,file_id1=0,file_id2=0,col_data="data"):
219 |     df_memmap1 = pd.read_pickle(path_df_memmap1).sort_values(by=[col_data])
220 |     df_memmap2 = pd.read_pickle(path_df_memmap2).sort_values(by=[col_data])
221 |     path_memmap1 = Path(path_memmap1)
222 |     path_memmap2 = Path(path_memmap2)
223 |     
224 |     meta1 = np.load(path_memmap1.parent/(path_memmap1.stem+"_meta.npz"),allow_pickle=True)
225 |     meta2 = np.load(path_memmap2.parent/(path_memmap2.stem+"_meta.npz"),allow_pickle=True)
226 |     df_memmap2 = df_memmap2.iloc[np.where(meta2["file_idx"]==file_id2)[0]].copy()
227 |     file_idx1 = meta1["file_idx"]
228 |     assert(len(df_memmap1)==len(file_idx1))# apply append_to_df_memmap before append_to_memmap
229 |     data_idx_start = np.max(np.where(file_idx1<=file_id1)[0])+1
230 |     df_memmap1.loc[df_memmap1[col_data]>=data_idx_start,col_data]=df_memmap1.loc[df_memmap1[col_data]>=data_idx_start,col_data]+len(df_memmap2)
231 |     
232 |     df_memmap2[col_data]=df_memmap2[col_data]-df_memmap2[col_data].min()+data_idx_start
233 |     df_memmap1 = pd.concat((df_memmap1,df_memmap2))
234 |     return df_memmap1
235 | 
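`append_to_df_memmap` renumbers the memmap sample indices stored in `col_data`: samples of memmap1 located after the last sample of file `file_id1` are shifted back by the number of appended samples, and the appended samples take the freed slots. A NumPy sketch of just that renumbering on plain index arrays (`reindex_after_append` is a hypothetical helper; the dataframes are replaced by their `col_data` columns):

```python
import numpy as np

def reindex_after_append(data1, data2, file_idx1, file_id1):
    """data1/data2: memmap sample indices from the two dataframes' data column."""
    data1 = np.asarray(data1).copy()
    data2 = np.asarray(data2)
    # first index after the last sample that belongs to files <= file_id1
    data_idx_start = np.max(np.where(np.asarray(file_idx1) <= file_id1)[0]) + 1
    data1[data1 >= data_idx_start] += len(data2)
    new_data2 = data2 - data2.min() + data_idx_start
    return data1, new_data2

new1, new2 = reindex_after_append([0, 1, 2, 3], [5, 6], [0, 0, 1, 1], file_id1=0)
print(new1, new2)
```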
236 | def npys_to_memmap(npys, target_filename, max_len=0, delete_npys=True, batched_npy=False):
237 |     '''
238 |     converts list of filenames pointing to npy files into a memmap file with target_filename
239 |     max_len: restricts filesize per memmap file (0 no restriction)
240 |     delete_npys: deletes original npys after processing to save space
241 |     batched_npy: assumes first axis in the npy file enumerates samples (otherwise just a single sample per npy file)
242 |     '''
243 |     memmap = None
244 |     start = []#start_idx in current memmap file
245 |     length = []#length of segment
246 |     filenames= []#memmap files
247 |     file_idx=[]#corresponding memmap file for sample
248 |     shape=[]
249 | 
250 |     target_filename = Path(target_filename)
251 | 
252 |     for _,npy in tqdm(enumerate(npys),total=len(npys)):
253 |         data_batched = np.load(npy, allow_pickle=True)
254 |         for data in (tqdm(data_batched,leave=False) if batched_npy else [data_batched]):
255 |             if(memmap is None or (max_len>0 and start[-1]+length[-1]>max_len)):
256 |                 if(max_len>0):
257 |                     filenames.append(target_filename.parent/(target_filename.stem+"_"+str(len(filenames))+".npy"))
258 |                 else:
259 |                     filenames.append(target_filename)
260 | 
261 |                 if(memmap is not None):#an existing memmap exceeded max_len
262 |                     shape.append([start[-1]+length[-1]]+[l for l in data.shape[1:]])
263 |                     del memmap
264 |                 #create new memmap
265 |                 start.append(0)
266 |                 length.append(data.shape[0])
267 |                 memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape)
268 |             else:
269 |                 #append to existing memmap
270 |                 start.append(start[-1]+length[-1])
271 |                 length.append(data.shape[0])
272 |                 memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple([start[-1]+length[-1]]+[l for l in data.shape[1:]]))
273 | 
274 |             #store mapping memmap_id to memmap_file_id
275 |             file_idx.append(len(filenames)-1)
276 |             #insert the actual data
277 |             memmap[start[-1]:start[-1]+length[-1]]=data[:]
278 |             memmap.flush()
279 |         if(delete_npys is True):
280 |             npy.unlink()
281 |     del memmap
282 | 
283 |     #append final shape if necessary
284 |     if(len(shape)0 else hparams_input_shape.channels)+hparams_input_shape.static_dim+hparams_input_shape.static_dim_cat
16 |         self.linear = nn.Linear(input_size,target_dim)
17 |         self.output_shape = dataclasses.replace(hparams_input_shape)
18 |         self.output_shape.channels = target_dim
19 |         self.output_shape.length = 0
20 |         self.output_shape.static_dim = 0
21 |         self.output_shape.static_dim_cat = 0
22 |         
23 |     def forward(self, **kwargs):
24 |         static = kwargs["static"]
25 |         seq = kwargs["seq"]
26 |         return {"seq": self.linear(torch.cat((seq.view(seq.shape[0],-1),static),dim=1))}
27 |     
28 |     def get_output_shape(self):
29 |         return self.output_shape
30 | 
31 | @dataclass
32 | class ConcatFusionHeadConfig(HeadBaseConfig):
33 |     _target_:str = "clinical_ts.head.multimodal.ConcatFusionHead"
34 |     multi_prediction:bool=False
35 | 
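`ConcatFusionHead.forward` flattens the sequence representation per sample, appends the static features and applies a single linear layer. A NumPy sketch of the shape bookkeeping, assuming (as the truncated constructor line appears to compute) that the flattened sequence contributes `length*channels` features when `length>0`, and that any categorical features are already part of `static`; `concat_fusion` is a hypothetical name.

```python
import numpy as np

def concat_fusion(seq, static, weight, bias):
    """seq: (bs, length, channels); static: (bs, static_dim).
    Flatten seq per sample, append static features, apply a linear map."""
    bs = seq.shape[0]
    fused = np.concatenate((seq.reshape(bs, -1), static), axis=1)
    return fused @ weight.T + bias  # same weight layout as torch.nn.Linear

bs, length, channels, static_dim, target_dim = 4, 10, 8, 5, 3
input_size = length * channels + static_dim  # assumed length>0 branch
rng = np.random.default_rng(0)
weight = rng.normal(size=(target_dim, input_size))
bias = np.zeros(target_dim)
out = concat_fusion(rng.normal(size=(bs, length, channels)), rng.normal(size=(bs, static_dim)), weight, bias)
print(out.shape)  # (4, 3)
```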
--------------------------------------------------------------------------------
/src/clinical_ts/loss/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/loss/supervised.py:
--------------------------------------------------------------------------------
 1 | __all__ = ['SupervisedLossConfig', 'BCELossConfig', 'BinaryCrossEntropyFocalLoss', 'BCEFLossConfig']
 2 | 
 3 | import torch
 4 | from torch import nn
 5 | import torch.nn.functional as F
 6 | import numpy as np
 7 | 
 8 | from dataclasses import dataclass, field
 9 | from typing import List
10 | 
11 | from ..template_modules import LossConfig
12 | 
13 | ####################################################################################
14 | # BASIC supervised losses
15 | ###################################################################################
16 | @dataclass
17 | class SupervisedLossConfig(LossConfig):
18 |     _target_:str = "" #insert appropriate loss class
19 |     loss_type:str ="supervised"
20 |     supervised_type:str="classification_single" #alternatives: "classification_multi", "regression_quantile"
21 | 
22 | @dataclass
23 | class BCELossConfig(SupervisedLossConfig):
24 |     _target_:str= "clinical_ts.loss.supervised.BinaryCrossEntropyLoss"
25 |     loss_type:str="supervised"
26 |     supervised_type:str="classification_multi"
27 |     pos_weight:List[float]=field(default_factory=lambda: [])#class weights e.g. inverse class prevalences
28 |     ignore_nans:bool=False #ignore nan targets (computes a separate BCE for each label)
29 | 
30 | class BinaryCrossEntropyFocalLoss(nn.Module):
31 |     """
32 |     Focal BCE loss for binary/multi-label classification with targets in {0,1}; with ignore_nans, nan targets are masked out label by label
33 |     """
34 |     def __init__(self, hparams_loss):
35 |         super().__init__()
36 |         self.gamma = hparams_loss.gamma
37 |         
38 |         self.ignore_nans = hparams_loss.ignore_nans
39 |         self.pos_weight_set = len(hparams_loss.pos_weight)>0
40 | 
41 |         if(not self.ignore_nans):
42 |             self.bce = torch.nn.BCEWithLogitsLoss(reduction="none",pos_weight=torch.from_numpy(np.array(hparams_loss.pos_weight,dtype=np.float32)) if len(hparams_loss.pos_weight)>0 else None)
43 |         else:
44 |             if(self.pos_weight_set):
45 |                 self.bce = torch.nn.ModuleList([torch.nn.BCEWithLogitsLoss(reduction="none",pos_weight=torch.from_numpy(np.array([hparams_loss.pos_weight[i]],dtype=np.float32))) for i in range(len(hparams_loss.pos_weight))])
46 |             else:
47 |                 self.bce = torch.nn.BCEWithLogitsLoss(reduction="none")
48 | 
49 |     def forward(self, preds, targs):
50 |         if(not(self.ignore_nans)):
51 |             probs = torch.sigmoid(preds)
52 |             p_t = probs * targs + (1 - probs) * (1 - targs)
53 |             focal_modulation = torch.pow((1 - p_t), self.gamma)
54 |             # sum over labels, mean over the batch
55 |             return (focal_modulation * self.bce(input=preds, target=targs.float())).sum(-1).mean()
56 |         else:
57 |             losses = []
58 |             for i in range(preds.size(1)):
59 |                 predsi = preds[:,i]
60 |                 targsi = targs[:,i]
61 |                 maski = ~torch.isnan(targsi)
62 |                 predsi = predsi[maski]
63 |                 targsi = targsi[maski]
64 |                 if(len(predsi)>0):
65 |                     probsi = torch.sigmoid(predsi)
66 |                     p_ti = probsi * targsi + (1 - probsi) * (1 - targsi)
67 |                     focal_modulationi = torch.pow((1 - p_ti), self.gamma)
68 |                     if(self.pos_weight_set):
69 |                         losses.append(torch.mean(focal_modulationi*self.bce[i](predsi,targsi)))
70 |                     else:
71 |                         losses.append(torch.mean(focal_modulationi*self.bce(predsi,targsi)))
72 |                 
73 |             return torch.sum(torch.stack(losses)) if(len(losses)>0) else 0.
74 |         
75 | @dataclass
76 | class BCEFLossConfig(BCELossConfig):
77 |     _target_:str= "clinical_ts.loss.supervised.BinaryCrossEntropyFocalLoss"
78 |     gamma:float=2.
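The focal BCE above multiplies the elementwise BCE by `(1 - p_t)**gamma`, where `p_t` is the predicted probability of the true class. A NumPy sketch of both aggregation modes, omitting `pos_weight` for brevity (`focal_bce` is a hypothetical name; it mirrors the module rather than replacing it):

```python
import numpy as np

def focal_bce(logits, targets, gamma=2.0, ignore_nans=False):
    """ignore_nans=False: sum over labels, mean over batch.
    ignore_nans=True: per-label mean over non-nan targets, summed over labels."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)

    def elementwise(x, y):
        # numerically stable BCE with logits: max(x,0) - x*y + log(1+exp(-|x|))
        bce = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
        p = 1.0 / (1.0 + np.exp(-x))
        p_t = p * y + (1 - p) * (1 - y)
        return (1 - p_t) ** gamma * bce

    if not ignore_nans:
        return elementwise(logits, targets).sum(-1).mean()
    total = 0.0
    for i in range(logits.shape[1]):
        mask = ~np.isnan(targets[:, i])
        if mask.any():
            total += elementwise(logits[mask, i], targets[mask, i]).mean()
    return total

print(focal_bce([[0.0]], [[1.0]], gamma=0.0))  # plain BCE at logit 0: log(2)
```

With `gamma=0` the modulation factor is 1 and the loss reduces to ordinary BCE; larger `gamma` down-weights examples that are already classified confidently.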
--------------------------------------------------------------------------------
/src/clinical_ts/metric/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/metric/base.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['MetricConfig', 'MetricBase', 'MetricAUROC', 'MetricAUROCConfig', 'MetricAUROCAggConfig']
  2 | 
  3 | import numpy as np
  4 | 
  5 | from dataclasses import dataclass
  6 | 
  7 | from ..utils.eval_utils_cafa import multiclass_roc_curve
  8 | from ..utils.bootstrap_utils import empirical_bootstrap
  9 | 
 10 | import warnings
 11 | from sklearn.exceptions import UndefinedMetricWarning
 12 | 
 13 | # Filter out the warnings due to not enough positive/negative samples during bootstrapping
 14 | warnings.filterwarnings('ignore', category=UndefinedMetricWarning)
 15 | 
 16 | @dataclass
 17 | class MetricConfig:
 18 |     _target_:str = ""
 19 |     
 20 |     name:str = ""#name of the metric e.g. auroc
 21 | 
 22 |     aggregation:str = "" #"" means no aggregation across segments of the same sequence, other options: "mean", "max"
 23 |     
 24 |     key_summary_metric:str = "" #key into the output dict that serves as summary metric, e.g. for early stopping (given without key_prefix, key_postfix and aggregation suffix)
 25 |     mode_summary_metric:str ="max" #used to determine if key_summary_metric is supposed to be maximized or minimized
 26 |     
 27 |     verbose:str = "" # comma-separated list of keys to be printed after metric evaluation (without key_prefix and key_postfix and aggregation type)
 28 |     
 29 |     bootstrap_report_nans:bool = False #report nans during bootstrapping (due to not enough labels of a certain type in certain bootstrap iterations etc)
 30 |     bootstrap_iterations:int = 0 #0: no bootstrap
 31 |     bootstrap_alpha:float= 0.95 # confidence level of the bootstrap interval
 32 | 
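`MetricBase` below builds every output key as `key_prefix + key + aggregation suffix + key_postfix`, where `aggregation="mean"` maps to `_agg`, any other non-empty value `v` to `_agg` + `v`, and bootstrap results append `_low`, `_high` or `_nans`. A sketch of that naming, assuming `key_prefix` already ends with the metric name and an underscore as the partially visible constructor line suggests; `metric_key` is a hypothetical helper.

```python
def metric_key(key_prefix, key, aggregation="", key_postfix="0"):
    """Reconstruct an output key such as 'auroc_macro_agg0'."""
    if aggregation == "":
        aggregation_txt = ""
    elif aggregation == "mean":
        aggregation_txt = "_agg"
    else:
        aggregation_txt = "_agg" + aggregation
    return key_prefix + key + aggregation_txt + key_postfix

print(metric_key("auroc_", "macro", "mean"))
```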
 33 | def _reformat_lbl_itos(k):
 34 |     #    return re.sub(r'(?0 else "")+hparams_metric.name+"_"
 40 |         self.key_postfix = key_postfix
 41 |         self.aggregation = hparams_metric.aggregation
 42 |         self.aggregation_txt = ("_agg" if hparams_metric.aggregation=="mean" else "_agg"+hparams_metric.aggregation) if hparams_metric.aggregation!="" else ""
 43 |         self.key_summary_metric = self.key_prefix+hparams_metric.key_summary_metric+self.aggregation_txt+self.key_postfix #data loader id added by default
 44 |         self.mode_summary_metric = hparams_metric.mode_summary_metric
 45 |         self.verbose = [x for x in hparams_metric.verbose.split(",") if x!=""]
 46 | 
 47 |         self.bootstrap_iterations = hparams_metric.bootstrap_iterations if test else 0 #disable bootstrap during training
 48 |         self.bootstrap_alpha = hparams_metric.bootstrap_alpha
 49 |         self.bootstrap_report_nans = hparams_metric.bootstrap_report_nans
 50 | 
 51 |         self.lbl_itos = [_reformat_lbl_itos(l) for l in lbl_itos]
 52 |         self.keys = self.get_keys(self.lbl_itos)
 53 | 
 54 | 
 55 |     def get_keys(self, lbl_itos):
 56 |         '''returns metrics keys in the order they will later be returned by _eval'''
 57 |         raise NotImplementedError
 58 |     
 59 |     def __call__(self,targs,preds):
 60 |         
 61 |         if(self.bootstrap_iterations==0):
 62 |             point = self._eval(targs,preds)
 63 |         else:
 64 |             point,low,high,nans = empirical_bootstrap((targs,preds), self._eval, n_iterations=self.bootstrap_iterations , alpha=self.bootstrap_alpha,ignore_nans=True)#score_fn_kwargs={"classes":self.lbl_itos}
 65 |         res = {self.key_prefix+k+self.aggregation_txt+self.key_postfix:v for v,k in zip(point,self.keys)}
 66 |         if(self.bootstrap_iterations>0):
 67 |             res_low = {self.key_prefix+k+self.aggregation_txt+self.key_postfix+"_low":v for v,k in zip(low,self.keys)}
 68 |             res_high = {self.key_prefix+k+self.aggregation_txt+self.key_postfix+"_high":v for v,k in zip(high,self.keys)}
 69 |             res_nans = {self.key_prefix+k+self.aggregation_txt+self.key_postfix+"_nans":v for v,k in zip(nans,self.keys)}
 70 |             res.update(res_low)
 71 |             res.update(res_high)
 72 |             if(self.bootstrap_report_nans):
 73 |                 res.update(res_nans)
 74 | 
 75 |         if(len(self.verbose)>0):
 76 |             for k in self.verbose:
 77 |                 print("\n"+self.key_prefix+k+self.aggregation_txt+self.key_postfix+":"+str(res[self.key_prefix+k+self.aggregation_txt+self.key_postfix]))
 78 |         
 79 |         return res
 80 | 
 81 |     def _eval(self,targs,preds):
 82 |         # should return an array of results ordered according to the entries returned by get_keys()
 83 |         raise NotImplementedError
 84 | 
 85 | 
 86 |     
 87 | class MetricAUROC(MetricBase):
 88 |     '''provides class-wise+macro+micro AUROC/AUPR scores'''
 89 |     def __init__(self, hparams_metric, lbl_itos, key_prefix="", key_postfix="0", test=True):
 90 |         super().__init__(hparams_metric, lbl_itos=lbl_itos, key_prefix=key_prefix, key_postfix=key_postfix, test=test)
 91 |         self.precision_recall = hparams_metric.precision_recall
 92 | 
 93 |     def get_keys(self, lbl_itos):
 94 |         return list(lbl_itos)+["micro","macro"]
 95 |     
 96 |     def _eval(self,targs,preds):
 97 |         if(self.precision_recall):
 98 |             _,_,res = multiclass_roc_curve(targs,preds,classes=self.lbl_itos,precision_recall=True)
 99 |             return np.array(list(res.values()))
100 |         else:
101 |             _,_,res = multiclass_roc_curve(targs,preds,classes=self.lbl_itos)
102 |             return np.array(list(res.values()))
103 |         
104 | 
105 | @dataclass
106 | class MetricAUROCConfig(MetricConfig):
107 |     _target_:str = "clinical_ts.metric.base.MetricAUROC"
108 |     key_summary_metric:str = "macro"
109 |     verbose:str="macro" #by default print out macro auc
110 |     precision_recall:bool = False #calculate the area under the precision recall curve instead of the ROC curve
111 |     name:str = "auroc"
112 |     bootstrap_report_nans:bool = True #by default report number of bootstrap iterations where the score was nan (due to insufficient number of labels etc)
113 | 
114 | #shorthand for mean aggregation
115 | @dataclass
116 | class MetricAUROCAggConfig(MetricAUROCConfig):
117 |     aggregation:str="mean"
118 | 
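When `bootstrap_iterations>0`, `MetricBase.__call__` expects `empirical_bootstrap` to return point estimates, lower and upper bounds, and nan counts per key. Its implementation is not shown here; the following is a generic percentile bootstrap with the same return shape, assuming `alpha` is a confidence level. `percentile_bootstrap` is a hypothetical name and need not match `empirical_bootstrap` in detail.

```python
import numpy as np

def percentile_bootstrap(targs, preds, score_fn, n_iterations=1000, alpha=0.95, seed=0):
    """Resample (targs, preds) pairs with replacement, recompute score_fn (which
    returns one value per key), and report (1-alpha)/2 and (1+alpha)/2 percentiles,
    ignoring iterations where a score is nan."""
    rng = np.random.default_rng(seed)
    point = np.asarray(score_fn(targs, preds), dtype=np.float64)
    n = len(targs)
    scores = []
    for _ in range(n_iterations):
        idx = rng.integers(0, n, size=n)
        scores.append(score_fn(targs[idx], preds[idx]))
    scores = np.asarray(scores, dtype=np.float64)
    low = np.nanpercentile(scores, 100 * (1 - alpha) / 2, axis=0)
    high = np.nanpercentile(scores, 100 * (1 + alpha) / 2, axis=0)
    nans = np.isnan(scores).sum(axis=0)
    return point, low, high, nans

def accuracy(t, p):
    return np.array([np.mean((p > 0.5) == t)])

rng = np.random.default_rng(1)
targs = rng.integers(0, 2, size=200)
preds = np.clip(0.3 * targs + 0.7 * rng.random(200), 0.0, 1.0)
point, low, high, nans = percentile_bootstrap(targs, preds, accuracy, n_iterations=500)
print(point, low, high, nans)
```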
--------------------------------------------------------------------------------
/src/clinical_ts/tabular/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/tabular/base.py:
--------------------------------------------------------------------------------
 1 | __all__ = ['BasicEncoderStatic', 'BasicEncoderStaticConfig', 'BasicEncoderStaticMLP', 'BasicEncoderStaticMLPConfig']
 2 | 
 3 | import torch
 4 | from torch import nn
 5 | import numpy as np
 6 | 
 7 | import dataclasses
 8 | from dataclasses import dataclass, field
 9 | from typing import List
10 | 
11 | from ..template_modules import EncoderStaticBase, EncoderStaticBaseConfig
12 | from collections.abc import Iterable
13 | from ..ts.basic_conv1d_modules.basic_conv1d import bn_drop_lin
14 | 
15 | class BasicEncoderStatic(EncoderStaticBase):
16 |     def __init__(self, hparams_encoder_static, hparams_input_shape, target_dim=None):
17 |         super().__init__(hparams_encoder_static, hparams_input_shape, target_dim)
18 |         self.input_channels_cat = hparams_input_shape.static_dim_cat
19 |         self.input_channels_cont = hparams_input_shape.static_dim
20 |         assert(len(hparams_encoder_static.embedding_dims)==hparams_input_shape.static_dim_cat and len(hparams_encoder_static.vocab_sizes)==hparams_input_shape.static_dim_cat)
21 |         self.embeddings = nn.ModuleList() if hparams_input_shape.static_dim_cat is not None else None
22 |         for v,e in zip(hparams_encoder_static.vocab_sizes,hparams_encoder_static.embedding_dims):
23 |             self.embeddings.append(nn.Embedding(v,e))
24 |         self.input_dim = int(np.sum(hparams_encoder_static.embedding_dims) + hparams_input_shape.static_dim)
25 |         self.input_channels = hparams_input_shape.static_dim + hparams_input_shape.static_dim_cat
26 |         
27 | 
28 |     def embed(self, **kwargs):
29 |         static = kwargs["static"] if "static" in kwargs.keys() else None
30 |         static_cat = kwargs["static_cat"] if "static_cat" in kwargs.keys() else None
31 | 
32 |         res = []
33 |         if(static_cat is not None):
34 |             for i,e in enumerate(self.embeddings):
35 |                 res.append(e(static_cat[:,i].long()))
36 |             if(static is not None):
37 |                 res = torch.cat([torch.cat(res,dim=1),static],dim=1)
38 |             else:
39 |                 res = torch.cat(res,dim=1)
40 |         else:
41 |             res = static
42 |             
43 |         return res
44 |     
45 |     def forward(self, **kwargs):
46 |         raise NotImplementedError
47 |     
48 |     def get_output_shape(self):
49 |         raise NotImplementedError
50 | 
51 | @dataclass
52 | class BasicEncoderStaticConfig(EncoderStaticBaseConfig):
53 |     _target_:str = "clinical_ts.tabular.base.BasicEncoderStatic"
54 |     embedding_dims:List[int] = field(default_factory=lambda: []) #embedding dimension for each categorical feature
55 |     vocab_sizes:List[int] = field(default_factory=lambda: []) #vocabulary size for each categorical feature
56 | 
57 | class BasicEncoderStaticMLP(BasicEncoderStatic):
58 |     def __init__(self, hparams_encoder_static, hparams_input_shape, target_dim=None):
59 |         super().__init__(hparams_encoder_static, hparams_input_shape, target_dim)
60 |         
61 |         lin_ftrs = [self.input_dim] + list(hparams_encoder_static.lin_ftrs)
62 |         if(target_dim is not None and lin_ftrs[-1]!=target_dim):
63 |             lin_ftrs.append(target_dim)
64 |         ps = [hparams_encoder_static.dropout] if not isinstance(hparams_encoder_static.dropout, Iterable) else hparams_encoder_static.dropout
65 |         if len(ps)==1:
66 |             ps= [ps[0]/2] * (len(lin_ftrs)-2) + ps
67 |         actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]
68 |         layers = []
69 |         for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
70 |             layers+=bn_drop_lin(ni,no,hparams_encoder_static.batch_norm,p,actn,layer_norm=False)
71 |         self.layers=nn.Sequential(*layers)
72 | 
73 |         self.output_shape = dataclasses.replace(hparams_input_shape)
74 |         self.output_shape.static_dim = int(lin_ftrs[-1])
75 |         self.output_shape.static_dim_cat = 0
76 |     
77 |     def forward(self, **kwargs):
78 |         res = self.embed(**kwargs)
79 |         return {"static": self.layers(res)}
80 |     
81 |     def get_output_shape(self):
82 |         return self.output_shape
83 | 
84 | 
85 | @dataclass
86 | class BasicEncoderStaticMLPConfig(BasicEncoderStaticConfig):
87 |     _target_:str = "clinical_ts.tabular.base.BasicEncoderStaticMLP"
88 |     lin_ftrs:List[int] = field(default_factory=lambda: [512]) #list with MLP hidden layer sizes; last entry is the static encoder output dimension in case target_dim is not specified
89 |     dropout:float = 0.5
90 |     batch_norm:bool = True
91 | 
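`BasicEncoderStaticMLP` feeds `sum(embedding_dims) + static_dim` features into an MLP whose sizes come from `lin_ftrs` (plus `target_dim` if given). A scalar `dropout` is halved for the hidden layers and used in full for the last layer, and only the last layer has no activation. A plain-Python sketch of that schedule without building torch modules (`mlp_schedule` is a hypothetical name):

```python
def mlp_schedule(input_dim, lin_ftrs, dropout=0.5, target_dim=None):
    """Return (in_features, out_features, dropout, activation) per layer."""
    sizes = [input_dim] + list(lin_ftrs)
    if target_dim is not None and sizes[-1] != target_dim:
        sizes.append(target_dim)
    ps = [dropout] if not isinstance(dropout, (list, tuple)) else list(dropout)
    if len(ps) == 1:
        # half the dropout for hidden layers, the full value for the last layer
        ps = [ps[0] / 2] * (len(sizes) - 2) + ps
    actns = ["relu"] * (len(sizes) - 2) + [None]
    return list(zip(sizes[:-1], sizes[1:], ps, actns))

embedding_dims, static_dim = [4, 8], 10
input_dim = sum(embedding_dims) + static_dim
layers = mlp_schedule(input_dim, [512], dropout=0.5, target_dim=128)
print(layers)
```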
--------------------------------------------------------------------------------
/src/clinical_ts/task/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/task/multimodal.py:
--------------------------------------------------------------------------------
 1 | from ..template_model import SSLModel
 2 | from ..template_modules import TaskConfig
 3 | from ..data.time_series_dataset_utils import load_dataset
 4 | from ..data.time_series_dataset_transforms import Transform
 5 | from pathlib import Path
 6 | from dataclasses import dataclass
 7 | import numpy as np
 8 | import pandas as pd
 9 | import warnings
10 | 
11 | #disable performance warning
12 | warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
13 | 
14 | class MultimodalModel(SSLModel):
15 |     '''class for multimodal tasks'''
16 |    
17 |     def preprocess_dataset(self,dataset_kwargs):
18 |         df_mapped, lbl_itos, mean, std = load_dataset(Path(dataset_kwargs.path))
19 | 
20 |         if(self.hparams.loss.loss_type=="supervised" and dataset_kwargs.name.startswith("mimic_labvalues")):
21 |             
22 |             #convert the one-hot race columns into a single categorical column (-1 if not exactly one is set)
23 |             lbl_itos_ethnicity=['race_asian', 'race_black', 'race_hispanic', 'race_other', 'race_white']
24 | 
25 |             df_mapped["race"]=df_mapped.apply(lambda row: np.where([row[c] for c in lbl_itos_ethnicity])[0],axis=1)
26 |             df_mapped["race"]=df_mapped["race"].apply(lambda x:x[0] if len(x)==1 else -1)
27 |             df_mapped["race_nan"]=(df_mapped["race"]==-1)
28 |             if(self.hparams.task.impute_nans):
29 |                 df_mapped["race"] = df_mapped["race"].replace(-1, df_mapped[(df_mapped["race"] != -1)&(df_mapped.strat_fold<18)]["race"].median())#impute missing race with the median over the training folds (strat_fold<18)
30 |             df_mapped.drop(lbl_itos_ethnicity,axis=1,inplace=True)
31 | 
32 |             #prepare cat and cont features
33 |             cat_features = ['gender','race']
34 |             cat_features_m =['race_nan', 'temperature_nan', 'o2sat_nan', 'weight_nan', 'bmi_nan', 'sbp_nan', 'dbp_nan', 'resprate_nan', 'heartrate_nan', 'height_nan']
35 | 
36 |             cont_features = ['resprate', 'o2sat', 'anchor_age', 'dbp', 'heartrate', 'sbp', 'temperature',  'height', 'bmi', 'weight']#'rr_interval', 'qrs_onset', 'qrs_axis', 'p_axis', 't_axis', 'p_end', 'p_onset', 'qrs_end', 't_end'
37 | 
38 |             if(self.hparams.task.impute_nans):
39 |                 input_cols = cat_features + cont_features
40 |                 #grab training set medians and identify columns with nans
41 |                 df_train= df_mapped[df_mapped.strat_fold<18]
42 |                 train_medians= df_train[input_cols].median().to_dict()
43 |                 train_nans = [l for l,c in df_train[input_cols].isna().sum().to_dict().items() if c>0]
44 | 
45 |                 #impute nans with training-set medians; optionally add a *_nan indicator column
46 |                 for c in train_nans:
47 |                     if(self.hparams.task.introduce_nan_columns):
48 |                         df_mapped[c+"_nan"]=0
49 |                         df_mapped.loc[df_mapped[c].isna(),c+"_nan"]=1
50 |                     df_mapped.loc[df_mapped[c].isna(),c]=train_medians[c]
51 | 
52 |                 #defragment
53 |                 df_mapped = df_mapped.copy()
54 | 
55 |             if(self.hparams.task.introduce_nan_columns):
56 |                 cat_features += cat_features_m
57 |             
58 |             cat_features_dim = [len(df_mapped[c].unique()) for c in cat_features]
59 |             
60 |             df_mapped["cat_features"]=df_mapped[cat_features].values.tolist()
61 |             df_mapped.drop(cat_features,axis=1,inplace=True)
62 |             df_mapped["cont_features"]=df_mapped[cont_features].values.tolist()
63 |             df_mapped.drop(cont_features,axis=1,inplace=True)
64 |     
65 |             def replace_nan(arr):
66 |                 return np.where(np.array(arr)==-999.0, np.nan, np.array(arr))
67 |             df_mapped["label"]=df_mapped["aggregated_label"].apply(lambda arr:replace_nan(arr))
68 |             df_mapped.drop("aggregated_label",axis=1,inplace=True)
69 | 
70 |         return df_mapped, lbl_itos, mean, std
71 | 
72 | 
73 | @dataclass
74 | class TaskConfigMultimodal(TaskConfig):
75 |     mainclass:str= "clinical_ts.task.multimodal.MultimodalModel"
76 |     impute_nans:bool = True #impute nans or leave it to the model to handle it
77 |     introduce_nan_columns:bool = False #impute using train set median and introduce an additional column that states if imputation occurred 
78 |     nan_columns_as_cat:bool = False #treat nan columns as categorical variables (instead of continuous)
79 | 
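With `impute_nans=True`, `preprocess_dataset` fills nans using medians computed on the training folds only (`strat_fold<18`), restricted to columns that contain nans in the training rows, and with `introduce_nan_columns=True` it records where imputation happened. A NumPy sketch of the same idea on a plain array (`impute_with_train_medians` is a hypothetical helper; the original operates on pandas columns):

```python
import numpy as np

def impute_with_train_medians(X, train_mask, add_indicators=False):
    """Replace nans column-wise by the median of the training rows; only columns
    with nans in the training rows are imputed. Optionally append 0/1 indicators."""
    X = np.array(X, dtype=np.float64)  # work on a copy
    train = X[train_mask]
    medians = np.nanmedian(train, axis=0)
    cols = np.where(np.isnan(train).any(axis=0))[0]
    indicators = []
    for c in cols:
        missing = np.isnan(X[:, c])
        if add_indicators:
            indicators.append(missing.astype(np.float64))
        X[missing, c] = medians[c]
    if indicators:
        X = np.column_stack([X] + indicators)
    return X

X = np.array([[1.0, np.nan], [3.0, 4.0], [np.nan, 6.0], [10.0, np.nan]])
train_mask = np.array([True, True, True, False])
out = impute_with_train_medians(X, train_mask, add_indicators=True)
print(out)
```

Using training-fold statistics only keeps information from validation and test folds out of the preprocessing.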
--------------------------------------------------------------------------------
/src/clinical_ts/ts/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/ts/base.py:
--------------------------------------------------------------------------------
 1 | __all__ = ['NoPredictor', 'NoPredictorConfig', 'CNNPredictor', 'CNNPredictorConfig']
 2 | 
 3 | import dataclasses
 4 | from dataclasses import dataclass, field
 5 | from typing import List
 6 | 
 7 | import torch.nn as nn
 8 | import numpy as np
 9 | from .basic_conv1d_modules.basic_conv1d import _conv1d
10 | 
11 | from ..template_modules import PredictorBase, PredictorBaseConfig
12 | 
13 | class NoPredictor(PredictorBase):
14 |     def __init__(self, hparams, hparams_input_shape):
15 |         '''
16 |         no predictor e.g. for pretraining purposes
17 |         '''
18 |         super().__init__(hparams, hparams_input_shape)
19 |         
20 |     def forward(self, **kwargs):    
21 |         return {}
22 |     
23 | 
24 | @dataclass
25 | class NoPredictorConfig(PredictorBaseConfig):
26 |     _target_:str = "clinical_ts.ts.base.NoPredictor"
27 | 
28 | 
29 | class CNNPredictor(PredictorBase):
30 |     def __init__(self, hparams_encoder, hparams_input_shape):
31 |         '''this is a reduced version of the RNNEncoder'''
32 |         super().__init__(hparams_encoder, hparams_input_shape)
33 |         assert(not hparams_input_shape.sequence_last)
34 |         assert(len(hparams_encoder.strides)==len(hparams_encoder.kss) and len(hparams_encoder.strides)==len(hparams_encoder.features) and len(hparams_encoder.strides)==len(hparams_encoder.dilations))
35 |         lst = []
36 |         for i,(s,k,f,d) in enumerate(zip(hparams_encoder.strides,hparams_encoder.kss,hparams_encoder.features,hparams_encoder.dilations)):
37 |             lst.append(_conv1d(hparams_input_shape.channels if i==0 else hparams_encoder.features[i-1],f,kernel_size=k,stride=s,dilation=d,bn=hparams_encoder.normalization,layer_norm=hparams_encoder.layer_norm))
38 |            
39 |         self.layers = nn.Sequential(*lst)
40 |         self.downsampling_factor = np.prod(hparams_encoder.strides)
41 |         
42 |         self.output_dim = hparams_encoder.features[-1]
43 | 
44 |         self.output_shape = dataclasses.replace(hparams_input_shape)
45 |         self.output_shape.channels = self.output_dim
46 |         self.output_shape.length = int(hparams_input_shape.length//self.downsampling_factor+ (1 if hparams_input_shape.length%self.downsampling_factor>0 else 0))
47 | 
48 |     def get_output_shape(self):
49 |         return self.output_shape
50 | 
51 |     def forward(self, **kwargs):
52 |         seq =  kwargs["seq"].transpose(1,2)
53 |         return {"seq": self.layers(seq).transpose(1,2)}#bs,seq,feat
54 |     
55 | @dataclass
56 | class CNNPredictorConfig(PredictorBaseConfig):
57 |     _target_:str = "clinical_ts.ts.base.CNNPredictor"
58 | 
59 |     strides:List[int]=field(default_factory=lambda: [1,1,1,1]) #conv strides, one entry per layer
60 |     kss:List[int]=field(default_factory=lambda: [1,1,1,1]) #conv kernel sizes, one entry per layer
61 |     features:List[int]=field(default_factory=lambda: [512,512,512,512]) #output features, one entry per layer
62 |     dilations:List[int]=field(default_factory=lambda: [1,1,1,1]) #conv dilations, one entry per layer
63 |     normalization:bool=True #apply batch/layer normalization after each conv
64 |     layer_norm:bool=False #use layer norm instead of batch norm (if normalization is True)
65 | 
66 | 
67 |     
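`CNNPredictor` sets `output_shape.length` to the input length divided by the product of the strides, rounded up. The stdlib-only sketch below checks that shortcut against the `nn.Conv1d` output-length formula from the PyTorch documentation, applied layer by layer with the padding `(kernel_size-1)//2` that `_conv1d` uses. It assumes dilation 1 and odd kernel sizes; the helper names are hypothetical and not part of `clinical_ts`.

```python
import math
from typing import List

def conv1d_out_len(length: int, kernel_size: int, stride: int, dilation: int = 1) -> int:
    """nn.Conv1d output length (PyTorch docs formula) with _conv1d's padding."""
    padding = (kernel_size - 1) // 2
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def stacked_out_len(length: int, strides: List[int], kss: List[int]) -> int:
    """Apply the per-layer formula through the whole stack (dilation 1 assumed)."""
    for s, k in zip(strides, kss):
        length = conv1d_out_len(length, k, s)
    return length

def predictor_out_len(length: int, strides: List[int]) -> int:
    """The shortcut used by CNNPredictor: ceil(length / prod(strides))."""
    factor = math.prod(strides)
    return length // factor + (1 if length % factor > 0 else 0)

print(stacked_out_len(1001, [2, 2, 5], [3, 5, 7]), predictor_out_len(1001, [2, 2, 5]))  # 51 51
```

With an even kernel size the shortcut no longer holds: `stacked_out_len(10, [1], [2])` gives 9, while `predictor_out_len(10, [1])` gives 10.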
--------------------------------------------------------------------------------
/src/clinical_ts/ts/basic_conv1d_modules/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/ts/basic_conv1d_modules/basic_conv1d.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['AdaptiveConcatPool1d', 'SqueezeExcite1d',
  2 |            'weight_init', 'create_head1d', 'basic_conv1d', 'fcn', 'fcn_wang', 'schirrmeister', 'sen', 'basic1d']
  3 | 
  4 | # Cell
  5 | import torch
  6 | import torch.nn as nn
  7 | import torch.nn.functional as F
  8 | import math
  9 | 
 10 | from typing import Iterable
 11 | 
 12 | class Flatten(nn.Module):
 13 |     "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"
 14 |     def __init__(self, full:bool=False): 
 15 |         super().__init__()
 16 |         self.full = full
 17 |     def forward(self, x): return x.view(-1) if self.full else x.view(x.size(0), -1)
 18 | 
 19 | 
 20 | 
 21 | def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None, layer_norm=False, permute=False):
 22 |     '''
 23 |     Sequence of batch norm or layer norm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`.
 24 |     permute: set for inputs of the form B,Seq,Feat (only needed for batch norm, which expects B,Feat,Seq)
 25 |     '''
 26 |     layers=[]
 27 |     if(bn):
 28 |         if(layer_norm is False):
 29 |             #BatchNorm1d normalizes dim 1: permute B,Seq,Feat inputs before and after
 30 |             if(permute): layers.append(LambdaLayer(lambda x: x.permute(0,2,1)))
 31 |             layers.append(nn.BatchNorm1d(n_in))
 32 |             if(permute): layers.append(LambdaLayer(lambda x: x.permute(0,2,1)))
 33 |         else:
 34 |             #LayerNorm normalizes the last dim, so B,Seq,Feat inputs need no permute
 35 |             layers.append(nn.LayerNorm(n_in))
 36 |     if p != 0: layers.append(nn.Dropout(p))
 37 |     layers.append(nn.Linear(n_in, n_out))
 38 |     if actn is not None: layers.append(actn)
 39 |     return layers
 40 | 
 41 | # Cell
 42 | 
 43 | class LambdaLayer(nn.Module):
 44 |     def __init__(self, lambd):
 45 |         super(LambdaLayer, self).__init__()
 46 |         self.lambd = lambd
 47 |     def forward(self, x):
 48 |         return self.lambd(x)
 49 | 
 50 | def _conv1d(in_planes,out_planes,kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0, layer_norm=False):
 51 |     lst=[]
 52 |     if(drop_p>0):
 53 |         lst.append(nn.Dropout(drop_p))
 54 |     lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, dilation=dilation, bias=not(bn)))
 55 |     if(bn):
 56 |         if(layer_norm):
 57 |             lst.append(LambdaLayer(lambda x: x.transpose(1,2)))
 58 |             lst.append(nn.LayerNorm(out_planes))
 59 |             lst.append(LambdaLayer(lambda x: x.transpose(1,2)))
 60 |         else:
 61 |             lst.append(nn.BatchNorm1d(out_planes))
 62 |     if(act=="relu"):
 63 |         lst.append(nn.ReLU(True))
 64 |     if(act=="elu"):
 65 |         lst.append(nn.ELU(True))
 66 |     if(act=="prelu"):
 67 |         lst.append(nn.PReLU())
 68 |     if(act=="gelu"):
 69 |         lst.append(nn.GELU())
 70 |     return nn.Sequential(*lst)
 71 | 
 72 | def _fc(in_planes,out_planes, act="relu", bn=True):
 73 |     lst = [nn.Linear(in_planes, out_planes, bias=not(bn))]
 74 |     if(bn):
 75 |         lst.append(nn.BatchNorm1d(out_planes))
 76 |     if(act=="relu"):
 77 |         lst.append(nn.ReLU(True))
 78 |     if(act=="elu"):
 79 |         lst.append(nn.ELU(True))
 80 |     if(act=="prelu"):
 81 |         lst.append(nn.PReLU())
 82 |     return nn.Sequential(*lst)
 83 | 
 84 | class AdaptiveConcatPool1d(nn.Module):
 85 |     "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`."
 86 |     def __init__(self, sz=None):
 87 |         "Output will be 2*sz or 2 if sz is None"
 88 |         super().__init__()
 89 |         sz = sz or 1
 90 |         self.ap,self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
 91 |     def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
 92 | 
 93 | # Cell
 94 | class SqueezeExcite1d(nn.Module):
 95 |     '''squeeze excite block as used for example in LSTM FCN'''
 96 |     def __init__(self,channels,reduction=16):
 97 |         super().__init__()
 98 |         channels_reduced = channels//reduction
 99 |         self.w1 = torch.nn.Parameter(torch.randn(channels_reduced,channels).unsqueeze(0))
100 |         self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))
101 | 
102 |     def forward(self, x):
103 |         #input is bs,ch,seq
104 |         z=torch.mean(x,dim=2,keepdim=True)#bs,ch,1
105 |         intermed = F.relu(torch.matmul(self.w1,z))#(1,ch_red,ch) @ (bs,ch,1) = (bs,ch_red,1)
106 |         s=torch.sigmoid(torch.matmul(self.w2,intermed))#(1,ch,ch_red) @ (bs,ch_red,1) = (bs,ch,1)
107 |         return s*x #(bs,ch,1) * (bs,ch,seq) = (bs,ch,seq)
108 | 
109 | # Cell
110 | def weight_init(m):
111 |     '''call weight initialization for model n via n.apply(weight_init)'''
112 |     if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
113 |         nn.init.kaiming_normal_(m.weight)
114 |         if m.bias is not None:
115 |             nn.init.zeros_(m.bias)
116 |     if isinstance(m, nn.BatchNorm1d):
117 |         nn.init.constant_(m.weight,1)
118 |         nn.init.constant_(m.bias,0)
119 |     if isinstance(m,SqueezeExcite1d):
120 |         stdv1=math.sqrt(2./m.w1.size(-2))#w1 has shape (1,ch_red,ch)
121 |         nn.init.normal_(m.w1,0.,stdv1)
122 |         stdv2=math.sqrt(1./m.w2.size(-1))#w2 has shape (1,ch,ch_red)
123 |         nn.init.normal_(m.w2,0.,stdv2)
124 | 
125 | # Cell
126 | def create_head1d(nf, nc, lin_ftrs=None, ps=0.5, bn:bool=True, act="relu", concat_pooling=True, bn_final:bool=False):
127 |     "Model head that takes `nf` features, runs through `lin_ftrs`, and outputs `nc` classes; `bn_final` appends a batch norm after the last layer"
128 |     lin_ftrs = [2*nf if concat_pooling else nf, nc] if lin_ftrs is None else [2*nf if concat_pooling else nf] + lin_ftrs + [nc] #was [nf, 512,nc]
129 |     ps = [ps] if not isinstance(ps,Iterable) else ps
130 |     if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
131 |     actns = [nn.ReLU(inplace=True) if act=="relu" else nn.ELU(inplace=True)] * (len(lin_ftrs)-2) + [None]
132 |     layers = [AdaptiveConcatPool1d() if concat_pooling else nn.AdaptiveAvgPool1d(1), Flatten()]
133 |     for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
134 |         layers += bn_drop_lin(ni,no,bn,p,actn)
135 |     return nn.Sequential(*layers, nn.BatchNorm1d(lin_ftrs[-1])) if bn_final else nn.Sequential(*layers)
136 | 
137 | # Cell
138 | class basic_conv1d(nn.Sequential):
139 |     '''basic conv1d'''
140 |     def __init__(self, filters=[128,128,128,128],kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,split_first_layer=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
141 |         layers = []
142 |         if(isinstance(kernel_size,int)):
143 |             kernel_size = [kernel_size]*len(filters)
144 |         for i in range(len(filters)):
145 |             layers_tmp = []
146 | 
147 |             layers_tmp.append(_conv1d(input_channels if i==0 else filters[i-1],filters[i],kernel_size=kernel_size[i],stride=(1 if (split_first_layer is True and i==0) else stride),dilation=dilation,act="none" if ((headless is True and i==len(filters)-1) or (split_first_layer is True and i==0)) else act, bn=False if (headless is True and i==len(filters)-1) else bn,drop_p=(0. if i==0 else drop_p)))
148 |             if((split_first_layer is True and i==0)):
149 |                 layers_tmp.append(_conv1d(filters[0],filters[0],kernel_size=1,stride=1,act=act, bn=bn,drop_p=0.))
150 |                 #layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn)))
151 |                 #layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn))
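152 |             if(pool>0 and i<len(filters)-1):
153 |                 layers_tmp.append(nn.MaxPool1d(pool,stride=pool_stride,padding=(pool-1)//2))
154 |             if(squeeze_excite_reduction>0):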
155 |                 layers_tmp.append(SqueezeExcite1d(filters[i],squeeze_excite_reduction))
156 |             layers.append(nn.Sequential(*layers_tmp))
157 | 
158 |         #head
159 |         #layers.append(nn.AdaptiveAvgPool1d(1))
160 |         #layers.append(nn.Linear(filters[-1],num_classes))
161 |         #head #inplace=True leads to a runtime error see ReLU+ dropout https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
162 |         self.headless = headless
163 |         if(headless is True):
164 |             head = nn.Sequential(nn.AdaptiveAvgPool1d(1),Flatten())
165 |         else:
166 |             head=create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
167 |         layers.append(head)
168 | 
169 |         super().__init__(*layers)
170 | 
171 |     def get_layer_groups(self):
172 |         return (self[2],self[-1])
173 | 
174 |     def get_output_layer(self):
175 |         if self.headless is False:
176 |             return self[-1][-1]
177 |         else:
178 |             return None
179 | 
180 |     def set_output_layer(self,x):
181 |         if self.headless is False:
182 |             self[-1][-1] = x
183 | 
184 | 
185 | # Cell
186 | def fcn(filters=[128]*5,num_classes=2,input_channels=8,**kwargs):
187 |     filters_in = filters + [num_classes]
188 |     return basic_conv1d(filters=filters_in,kernel_size=3,stride=1,pool=2,pool_stride=2,input_channels=input_channels,act="relu",bn=True,headless=True)
189 | 
190 | def fcn_wang(num_classes=2,input_channels=8,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
191 |     return basic_conv1d(filters=[128,256,128],kernel_size=[8,5,3],stride=1,pool=0,pool_stride=2, num_classes=num_classes,input_channels=input_channels,act="relu",bn=True,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
192 | 
193 | def schirrmeister(num_classes=2,input_channels=8,kernel_size=10,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
194 |     return basic_conv1d(filters=[25,50,100,200],kernel_size=kernel_size, stride=3, pool=3, pool_stride=1, num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, headless=False,split_first_layer=True,drop_p=0.5,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
195 | 
196 | def sen(filters=[128]*5,num_classes=2,input_channels=8,kernel_size=3,squeeze_excite_reduction=16,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
197 |     return basic_conv1d(filters=filters,kernel_size=kernel_size,stride=2,pool=0,pool_stride=0,input_channels=input_channels,act="relu",bn=True,num_classes=num_classes,squeeze_excite_reduction=squeeze_excite_reduction,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
198 | 
199 | def basic1d(filters=[128]*5,kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
200 |     return basic_conv1d(filters=filters,kernel_size=kernel_size, stride=stride, dilation=dilation, pool=pool, pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction, num_classes=num_classes, input_channels=input_channels, act=act, bn=bn, headless=headless,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
201 | 
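`create_head1d` assembles the classifier head from parallel lists of layer widths, dropout probabilities and activations. The hypothetical helper `head_plan` below reproduces only that list bookkeeping in plain Python, so the resulting layer plan can be inspected without instantiating any modules; it is a sketch that leaves out the batch-norm flags.

```python
from collections.abc import Iterable
from typing import List, Optional, Sequence, Tuple, Union

def head_plan(nf: int, nc: int, lin_ftrs: Optional[List[int]] = None,
              ps: Union[float, Sequence[float]] = 0.5,
              concat_pooling: bool = True) -> List[Tuple[int, int, float, bool]]:
    """Return one (n_in, n_out, dropout_p, has_activation) tuple per linear block."""
    # AdaptiveConcatPool1d concatenates max and avg pooling, doubling the width
    n_in = 2 * nf if concat_pooling else nf
    dims = [n_in, nc] if lin_ftrs is None else [n_in] + list(lin_ftrs) + [nc]
    ps = [ps] if not isinstance(ps, Iterable) else list(ps)
    # a single value is used for the last block; hidden blocks get half of it
    if len(ps) == 1:
        ps = [ps[0] / 2] * (len(dims) - 2) + ps
    # every block except the output layer is followed by an activation
    acts = [True] * (len(dims) - 2) + [False]
    return list(zip(dims[:-1], dims[1:], ps, acts))

print(head_plan(128, 5, lin_ftrs=[512]))
# [(256, 512, 0.25, True), (512, 5, 0.5, False)]
```

Passing `ps` as a list with one entry per linear block bypasses the halving rule.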
--------------------------------------------------------------------------------
/src/clinical_ts/ts/encoder.py:
--------------------------------------------------------------------------------
 1 | __all__ = ['NoEncoder', 'NoEncoderConfig']
 2 | 
 3 | import torch
 4 | from torch import nn
 5 | import numpy as np
 6 | 
 7 | import dataclasses
 8 | from dataclasses import dataclass, field
 9 | from typing import List
10 | 
11 | from ..template_modules import EncoderBase, EncoderBaseConfig
12 | from .basic_conv1d_modules.basic_conv1d import _conv1d
13 | from .transformer_modules.transformer import TransformerConvStemTokenizer
14 | 
15 | class NoEncoder(EncoderBase):
16 |     def __init__(self, hparams_encoder, hparams_input_shape):
17 |         '''
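18 |         no encoder: flattens two-dimensional channel input (channels x channels2, e.g. spectrograms) into one feature dimension and optionally concatenates timesteps_per_token consecutive timesteps into one token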
19 |         '''
20 |         super().__init__(hparams_encoder, hparams_input_shape)
21 |         self.timesteps_per_token = hparams_encoder.timesteps_per_token
22 |         self.sequence_last = hparams_input_shape.sequence_last
23 |         self.input_channels = hparams_input_shape.channels if hparams_input_shape.channels2==0 else hparams_input_shape.channels*hparams_input_shape.channels2
24 |         
25 |         self.output_shape = dataclasses.replace(hparams_input_shape)
26 |         self.output_shape.channels = self.input_channels*self.timesteps_per_token
27 |         self.output_shape.channels2 = 0
28 |         self.output_shape.length = hparams_input_shape.length//self.timesteps_per_token
29 |         self.output_shape.sequence_last = False
30 |     
31 |     def forward(self, **kwargs):
32 |         seq = kwargs["seq"] #bs,channels(,freq),seq if sequence_last else bs,seq,channels(,freq)
33 |         if(not self.sequence_last):
34 |             seq = torch.movedim(seq,1,-1)
35 |         if(len(seq.size())==4):#spectrogram input
36 |             seq = seq.view(seq.size(0),-1,seq.size(-1))#flatten  
37 |           
38 |         if(self.timesteps_per_token==1):
39 |             return {"seq": seq.transpose(1,2)}
40 |         else:
41 |             assert(seq.size(-1)%self.timesteps_per_token==0)
42 |             size = seq.size()
43 |             return {"seq": seq.transpose(1,2).reshape(size[0],size[2]//self.timesteps_per_token,-1)}#bs,seq//timesteps_per_token,channels*timesteps_per_token
44 | 
45 |     def get_output_shape(self):
46 |         return self.output_shape
47 | 
48 | 
49 | @dataclass
50 | class NoEncoderConfig(EncoderBaseConfig):
51 |     _target_:str = "clinical_ts.ts.encoder.NoEncoder"
52 | 
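With `timesteps_per_token > 1`, `NoEncoder` concatenates consecutive timesteps into one token, giving the `(length//timesteps_per_token, channels*timesteps_per_token)` layout declared in `output_shape`. The plain-Python sketch below works on a single sample (nested lists instead of tensors; `tokenize` is a hypothetical name) and mirrors the row-major transpose-then-reshape step, which makes the feature order inside a token visible.

```python
from typing import List

def tokenize(seq: List[List[float]], timesteps_per_token: int) -> List[List[float]]:
    """seq: channels x length for one sample.
    Returns (length // timesteps_per_token) x (channels * timesteps_per_token)."""
    channels, length = len(seq), len(seq[0])
    if length % timesteps_per_token != 0:
        raise ValueError("length must be divisible by timesteps_per_token")
    # channels x length -> length x channels (the transpose(1,2) step)
    steps = [[seq[c][i] for c in range(channels)] for i in range(length)]
    # row-major reshape: each token concatenates t consecutive timesteps
    t = timesteps_per_token
    return [[v for step in steps[j * t:(j + 1) * t] for v in step] for j in range(length // t)]

print(tokenize([[1, 2, 3, 4], [10, 20, 30, 40]], 2))
# [[1, 10, 2, 20], [3, 30, 4, 40]]
```

Within each token, all channel values of the first timestep come first, followed by those of the next timestep.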
--------------------------------------------------------------------------------
/src/clinical_ts/ts/head.py:
--------------------------------------------------------------------------------
 1 | __all__ = ['PoolingHead', 'PoolingHeadConfig']
 2 | 
 3 | import torch
 4 | import torch.nn as nn
 5 | import numpy as np
 6 | 
 7 | from ..template_modules import HeadBase, HeadBaseConfig, _string_to_class
 8 | import dataclasses
 9 | from dataclasses import dataclass, field
10 | from typing import List
11 |     
12 | class PoolingHead(HeadBase):
13 |     def __init__(self, hparams_head, hparams_input_shape, target_dim):
14 |         super().__init__(hparams_head, hparams_input_shape, target_dim)
15 |         #assert(target_dim is None or hparams_head.output_layer is True)
16 |         if(target_dim is not None and hparams_head.output_layer is False):
17 |             print("Warning: target_dim",target_dim,"is passed to PoolingHead but output_layer is False. target_dim will be ignored.")
18 |         self.local_pool = hparams_head.multi_prediction
19 |         self.output_dim = hparams_input_shape.channels if not hparams_head.output_layer else target_dim
20 |         
21 |         if(self.local_pool):#local pool
22 |             self.local_pool_padding = (hparams_head.local_pool_kernel_size-1)//2
23 |             self.local_pool_kernel_size = hparams_head.local_pool_kernel_size
24 |             self.local_pool_stride = hparams_head.local_pool_kernel_size if hparams_head.local_pool_stride==0 else hparams_head.local_pool_stride
25 |             if(hparams_head.local_pool_max):
26 |                 self.pool = torch.nn.MaxPool1d(kernel_size=self.local_pool_kernel_size,stride=self.local_pool_stride,padding=self.local_pool_padding)
27 |             else:
28 |                 self.pool = torch.nn.AvgPool1d(kernel_size=self.local_pool_kernel_size,stride=self.local_pool_stride,padding=self.local_pool_padding)
29 |         else:#global pool
30 |             if(hparams_head.local_pool_max):
31 |                 self.pool = torch.nn.AdaptiveMaxPool1d(1)
32 |             else:
33 |                 self.pool = torch.nn.AdaptiveAvgPool1d(1)
34 |         self.linear = nn.Linear(hparams_input_shape.channels, target_dim) if hparams_head.output_layer else nn.Identity()
35 | 
36 |         self.output_shape = dataclasses.replace(hparams_input_shape)
37 |         self.output_shape.channels = self.output_dim
38 |         #assert(hparams.predictor._target_!="clinical_ts.ts.transformer.TransformerPredictor" or (hparams.predictor.cls_token is True or (hparams.predictor.cls_token is False and (hparams_head.head_pooling_type!="cls" and hparams_head.head_pooling_type!="meanmax-cls"))))
39 | 
40 |         self.output_shape.length = int(np.floor((hparams_input_shape.length + 2*self.local_pool_padding- self.local_pool_kernel_size)/self.local_pool_stride+1)) if self.local_pool else 0
41 | 
42 |     def forward(self, **kwargs):
43 |         seq = kwargs["seq"]
44 |         #input has shape B,S,E
45 |         seq = seq.transpose(1,2) 
46 |         seq = self.pool(seq)
47 |         return {"seq": self.linear(seq.transpose(1,2))}#return B,S,E
48 |     
49 |     def get_output_shape(self):
50 |         return self.output_shape
51 | 
52 | @dataclass
53 | class PoolingHeadConfig(HeadBaseConfig):
54 |     _target_:str = "clinical_ts.ts.head.PoolingHead"
55 |     
56 |     multi_prediction:bool = False #local pool vs. global pool
57 |     local_pool_max:bool = False #max pool vs. avg pool
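58 |     local_pool_kernel_size: int = 0 #kernel size for local pooling (must be >0 if multi_prediction is True)
59 |     local_pool_stride: int = 0 #stride for local pooling; 0 defaults to local_pool_kernel_size
60 |     #local_pool_padding=(kernel_size-1)//2
61 |     output_layer: bool = False #if True, append a linear layer mapping to target_dim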
62 | 
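For `multi_prediction=True`, `PoolingHead` computes `output_shape.length` with the pooling output-length formula `floor((L + 2*padding - kernel_size)/stride + 1)`, where `padding=(kernel_size-1)//2` and a stride of 0 falls back to the kernel size. The stdlib sketch below (hypothetical helper names) cross-checks that formula against a brute-force count of pooling windows over the zero-padded sequence.

```python
import math

def pooled_length(length: int, kernel_size: int, stride: int = 0) -> int:
    """Sequence length after PoolingHead's local pooling (mirrors output_shape.length)."""
    if kernel_size <= 0:
        raise ValueError("local_pool_kernel_size must be positive when multi_prediction is True")
    stride = kernel_size if stride == 0 else stride
    padding = (kernel_size - 1) // 2
    return int(math.floor((length + 2 * padding - kernel_size) / stride + 1))

def count_windows(length: int, kernel_size: int, stride: int = 0) -> int:
    """Brute force: count the pooling windows that fit into the padded sequence."""
    stride = kernel_size if stride == 0 else stride
    padded = length + 2 * ((kernel_size - 1) // 2)
    return sum(1 for start in range(0, padded, stride) if start + kernel_size <= padded)

print(pooled_length(1000, 5), count_windows(1000, 5))  # 200 200
```

Non-overlapping pooling (`local_pool_stride=0`) over 1000 steps with `local_pool_kernel_size=5` therefore yields 200 output positions.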
--------------------------------------------------------------------------------
/src/clinical_ts/ts/s4.py:
--------------------------------------------------------------------------------
 1 | __all__ = ['S4Predictor','S4PredictorConfig']
 2 | 
 3 | from .s4_modules.s4_model import S4Model
 4 | from ..template_modules import PredictorBase, PredictorBaseConfig
 5 | from typing import Any
 6 | from dataclasses import dataclass
 7 | 
 8 | class S4Predictor(PredictorBase):
 9 |     def __init__(self, hparams_predictor, hparams_input_shape):
10 |         super().__init__(hparams_predictor, hparams_input_shape)
11 |         self.predictor = S4Model(
12 |             d_input = hparams_input_shape.channels if hparams_input_shape.channels!=hparams_predictor.model_dim else None,#modified
13 |             d_output = None,
14 |             d_state = hparams_predictor.state_dim,
15 |             d_model = hparams_predictor.model_dim,
16 |             n_layers = hparams_predictor.layers,
17 |             dropout = hparams_predictor.dropout,
18 |             tie_dropout = hparams_predictor.tie_dropout,
19 |             prenorm = hparams_predictor.prenorm,
20 |             l_max = hparams_input_shape.length,
21 |             transposed_input = False,
22 |             bidirectional=not(hparams_predictor.causal),
23 |             layer_norm=not(hparams_predictor.batchnorm),
24 |             pooling = False,
25 |             backbone = hparams_predictor.backbone) #note: only apply linear layer before if feature dimensions do not match
26 | 
27 |     def forward(self, **kwargs):   
28 |         return {"seq": self.predictor(kwargs["seq"])}
29 | 
30 | @dataclass
31 | class S4PredictorConfig(PredictorBaseConfig):
32 |     _target_:str = "clinical_ts.ts.s4.S4Predictor"
33 |     model_dim:int = 512 
34 |     causal: bool = True #unidirectional S4 layers if True, bidirectional if False
35 |     state_dim:int = 64 #S4 state dimension N
36 |     layers:int = 4
37 |     dropout:float=0.2
38 |     tie_dropout:bool=True
39 |     prenorm:bool=False
40 |     batchnorm:bool=False
41 |     backbone:str="s42" #S4 implementation; S4Model currently only builds layers for "s42"
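`S4Predictor` derives several `S4Model` arguments from its config instead of passing them through unchanged. A small stdlib sketch of that mapping (the dataclass and function names are hypothetical stand-ins) makes the inversions explicit: `causal=True` gives a unidirectional model, `batchnorm=False` selects LayerNorm, and the input projection is skipped when the input width already equals `model_dim`.

```python
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class S4ConfigSketch:
    """Hypothetical stand-in holding the S4PredictorConfig fields used below."""
    model_dim: int = 512
    causal: bool = True
    batchnorm: bool = False

def s4_model_kwargs(cfg: S4ConfigSketch, channels: int, length: int) -> Dict[str, Any]:
    """Mirror how S4Predictor translates its config into S4Model arguments."""
    return {
        # a linear input projection is only created if the widths differ
        "d_input": channels if channels != cfg.model_dim else None,
        "d_model": cfg.model_dim,
        "l_max": length,
        # causal models use unidirectional S4 layers
        "bidirectional": not cfg.causal,
        # batchnorm=False selects LayerNorm
        "layer_norm": not cfg.batchnorm,
    }

print(s4_model_kwargs(S4ConfigSketch(), channels=12, length=1000)["d_input"])  # 12
```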
--------------------------------------------------------------------------------
/src/clinical_ts/ts/s4_modules/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/ts/s4_modules/s4_model.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['S4Model']
  2 | #adapted from https://github.com/HazyResearch/state-spaces/blob/main/example.py
  3 | 
  4 | import torch.nn as nn
  5 | 
  6 | from .s42 import S4 as S42
  7 | 
  8 | from .s4_utils import DropoutNd
  9 | 
 10 | class S4Model(nn.Module):
 11 | 
 12 |     def __init__(
 13 |         self, 
 14 |         d_input, # None to disable encoder
 15 |         d_output, # None to disable decoder
 16 |         d_state=64, #MODIFIED: N
 17 |         d_model=512, #MODIFIED: H
 18 |         n_layers=4, 
 19 |         dropout=0.2,
 20 |         tie_dropout=False, #MODIFIED
 21 |         prenorm=False,
 22 |         l_max=1024,
 23 |         transposed_input=True, # behaves like 1d CNN if True else like a RNN with batch_first=True
 24 |         bidirectional=True, #MODIFIED
 25 |         layer_norm = True, # MODIFIED
 26 |         pooling = True, # MODIFIED
 27 |         backbone= "s42" # MODIFIED
 28 |     ):
 29 |         super().__init__()
 30 | 
 31 |         self.prenorm = prenorm
 32 | 
 33 |         # Linear encoder projecting d_input input features to d_model
 34 |         self.transposed_input = transposed_input
 35 |         
 36 |         # MODIFIED TO ALLOW FOR MODELS WITHOUT ENCODER
 37 |         if(d_input is None):
 38 |             self.encoder = nn.Identity()
 39 |         else:
 40 |             self.encoder = nn.Conv1d(d_input, d_model, 1) if transposed_input else nn.Linear(d_input, d_model)
 41 | 
 42 |         # Stack S4 layers as residual blocks
 43 |         self.s4_layers = nn.ModuleList()
 44 |         self.norms = nn.ModuleList()
 45 |         self.dropouts = nn.ModuleList()
 46 |         for _ in range(n_layers):
 47 |             if(backbone == "s42"):
 48 |                 self.s4_layers.append(
 49 |                 S42(
 50 |                     d_state=d_state,
 51 |                     l_max=l_max,
 52 |                     d_model=d_model, 
 53 |                     bidirectional=bidirectional,
 54 |                     postact='glu',
 55 |                     dropout=dropout,
 56 |                     tie_dropout=tie_dropout, 
 57 |                     transposed=True,
 58 |                 ))
 59 |             
 60 |             #MODIFIED TO ALLOW BATCH NORM MODELS
 61 |             self.layer_norm = layer_norm
 62 |             if(layer_norm):
 63 |                 self.norms.append(nn.LayerNorm(d_model))
 64 |             else: #MODIFIED
 65 |                 self.norms.append(nn.BatchNorm1d(d_model))
 66 | 
 67 |             if(tie_dropout):
 68 |                 self.dropouts.append(DropoutNd(dropout,transposed=True))
 69 |             else:
 70 |                 self.dropouts.append(nn.Dropout(dropout))
 71 | 
 72 |         self.pooling = pooling
 73 |         # Linear decoder
 74 |         # MODIFIED TO ALLOW FOR MODELS WITHOUT DECODER
 75 |         if(d_output is None):
 76 |             self.decoder = None
 77 |         else:
 78 |             self.decoder = nn.Linear(d_model, d_output)
 79 | 
 80 |     #MODIFIED
 81 |     def forward(self, x, rate=1.0):
 82 |         """
 83 |         Input x is shape (B, d_input, L) if transposed_input else (B, L, d_input)
 84 |         """
 85 |         x = self.encoder(x)  # (B, d_input, L) -> (B, d_model, L) if transposed_input else (B, L, d_input) -> (B, L, d_model)
 86 | 
 87 |         if(self.transposed_input is False):
 88 |             x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)
 89 |         
 90 |         for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
 91 |             # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L)
 92 | 
 93 |             z = x
 94 |             if self.prenorm:
 95 |                 # Prenorm
 96 |                 # MODIFIED
 97 |                 z = norm(z.transpose(-1, -2)).transpose(-1, -2) if self.layer_norm else norm(z)
 98 |             
 99 |             # Apply S4 block: we ignore the state input and output
100 |             # MODIFIED
101 |             z, _ = layer(z, rate=rate)
102 | 
103 |             # Dropout on the output of the S4 block
104 |             z = dropout(z)
105 | 
106 |             # Residual connection
107 |             x = z + x
108 | 
109 |             if not self.prenorm:
110 |                 # Postnorm
111 |                 # MODIFIED
112 |                 x = norm(x.transpose(-1, -2)).transpose(-1, -2) if self.layer_norm else norm(x)
113 | 
114 |         x = x.transpose(-1, -2) # (B, d_model, L) -> (B, L, d_model)
115 | 
116 |         # MODIFIED ALLOW TO DISABLE POOLING
117 |         if(self.pooling):
118 |             # Pooling: average pooling over the sequence length
119 |             x = x.mean(dim=1)
120 | 
121 |         # Decode the outputs
122 |         if(self.decoder is not None):
123 |             x = self.decoder(x)  # (B, d_model) -> (B, d_output) if pooling else (B, L, d_model) -> (B, L, d_output)
124 |             
125 |         if(not self.pooling and self.transposed_input is True):
126 |             x = x.transpose(-1, -2) # (B, L, d_output) -> (B, d_output, L)
127 |         return x
128 | 
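In `S4Model.forward`, `prenorm` decides where the normalization sits in each residual block: prenorm normalizes the input to the S4 layer, while postnorm normalizes the residual sum `x + z`. The plain-Python sketch below uses a hand-written LayerNorm without affine parameters and an arbitrary function in place of the S4 layer (both stand-ins, not the repository's modules), and omits dropout, to show the two orderings.

```python
from typing import Callable, List

Vec = List[float]

def layer_norm(v: Vec, eps: float = 1e-5) -> Vec:
    """LayerNorm over one feature vector, without learnable affine parameters."""
    mu = sum(v) / len(v)
    var = sum((a - mu) ** 2 for a in v) / len(v)
    return [(a - mu) / (var + eps) ** 0.5 for a in v]

def residual_stack(x: Vec, blocks: List[Callable[[Vec], Vec]], prenorm: bool) -> Vec:
    """Residual loop in the style of S4Model.forward, dropout omitted."""
    for block in blocks:
        z = layer_norm(x) if prenorm else x  # prenorm: normalize the block input
        z = block(z)                          # stand-in for the S4 layer
        x = [a + b for a, b in zip(z, x)]     # residual connection
        if not prenorm:
            x = layer_norm(x)                 # postnorm: normalize the residual sum
    return x

def double(v: Vec) -> Vec:
    return [2 * a for a in v]

print(residual_stack([1.0, 2.0, 3.0], [double], prenorm=True)[1])  # 2.0
```

With postnorm every block output is re-normalized, whereas with prenorm the residual path carries the unnormalized input forward.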
--------------------------------------------------------------------------------
/src/clinical_ts/ts/s4_modules/s4_utils.py:
--------------------------------------------------------------------------------
 1 | # adapted from https://github.com/state-spaces/s4/blob/main/models/s4/s4.py
 2 | import torch
 3 | import torch.nn as nn
 4 | from einops import rearrange
 5 | 
 6 | class DropoutNd(nn.Module):
 7 |     def __init__(self, p: float = 0.5, tie=True, transposed=True):
 8 |         """
 9 |         tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
10 |         """
11 |         super().__init__()
12 |         if p < 0 or p >= 1:
13 |             raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
14 |         self.p = p
15 |         self.tie = tie
16 |         self.transposed = transposed
17 |         self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)
18 | 
19 |     def forward(self, X):
20 |         """X: (batch, dim, lengths...)."""
21 |         if self.training:
22 |             if not self.transposed: X = rearrange(X, 'b ... d -> b d ...')
23 |             mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape
24 |             mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p
25 |             X = X * mask * (1.0/(1-self.p))
26 |             if not self.transposed: X = rearrange(X, 'b d ... -> b ... d')
27 |             return X
28 |         return X
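With `tie=True`, `DropoutNd` draws its mask with shape `(batch, dim, 1, ...)`, so each channel is either kept or zeroed along the whole sequence, and kept values are rescaled by `1/(1-p)`. The plain-Python sketch below (hypothetical function name; one sample as nested lists instead of a tensor) reproduces that behaviour for a `(dim, length)` input.

```python
import random
from typing import List

def tied_dropout(x: List[List[float]], p: float, rng: random.Random) -> List[List[float]]:
    """Tied dropout for one sample of shape (dim, length): a single keep/drop
    decision per channel is shared by all positions along the sequence."""
    if p < 0 or p >= 1:
        raise ValueError(f"dropout probability has to be in [0, 1), but got {p}")
    scale = 1.0 / (1.0 - p)  # inverted dropout keeps the expected value unchanged
    out = []
    for row in x:
        keep = rng.random() < 1.0 - p  # one draw per channel
        out.append([v * scale if keep else 0.0 for v in row])
    return out

print(tied_dropout([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 0.0, random.Random(0)))
# [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```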
--------------------------------------------------------------------------------
/src/clinical_ts/ts/transformer_modules/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/utils/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/clinical_ts/utils/bootstrap_utils.py:
--------------------------------------------------------------------------------
 1 | __all__ = ['empirical_bootstrap']
 2 | 
 3 | import numpy as np
 4 | from sklearn.utils import resample
 5 | from multiprocessing import Pool
 6 | from functools import partial
 7 | from tqdm.auto import tqdm
 8 | 
 9 | def _eval(ids, input_tuple, score_fn, input_tuple2=None,score_fn_kwargs={}):
10 |     return score_fn(*[t[ids] for t in input_tuple],**score_fn_kwargs) if input_tuple2 is None else score_fn(*[t[ids] for t in input_tuple],**score_fn_kwargs)-score_fn(*[t[ids] for t in input_tuple2],**score_fn_kwargs)
11 | 
12 | def empirical_bootstrap(input_tuple, score_fn, ids=None, n_iterations=1000, alpha=0.95, score_fn_kwargs={},threads=None, input_tuple2=None, ignore_nans=False, chunksize=50):
13 |     '''
14 |         performs empirical bootstrap https://ocw.mit.edu/courses/mathematics/18-05-introduction-to-probability-and-statistics-spring-2014/readings/MIT18_05S14_Reading24.pdf
15 |         
16 |         input_tuple: tuple of inputs for the score function typically something like (labels,predictions)
17 |         score_fn: scoring function that takes the individual entries of input_tuple as arguments e.g. f1_score
18 |         ids: array of previously sampled ids with shape (n_iterations, n_samples) (if None new ids will be sampled)
19 |         n_iterations: number of bootstrap iterations
20 |         alpha: confidence level for the confidence intervals (e.g. 0.95)
21 |         score_fn_kwargs: additional (static) kwargs to be passed to the score_fn
22 |         threads: number of worker processes (None uses os.cpu_count()); 0 disables multiprocessing
23 |         input_tuple2: if not None, a second input of the same shape as input_tuple; the function then bootstraps the score difference between both inputs (a convenience option: the same could be achieved by passing a tuple of the form (label,preds1,preds2) and computing the difference in score_fn itself)
24 |         ignore_nans: ignore nans (e.g. no positives during AUC evaluation) for score evaluation
25 |         chunksize: number of bootstrap iterations handled per worker pool
26 |     '''
27 |     
28 |     if(not(isinstance(input_tuple,tuple))):
29 |         input_tuple = (input_tuple,)
30 |     if(input_tuple2 is not None and not(isinstance(input_tuple2,tuple))):
31 |         input_tuple2 = (input_tuple2,)
32 |         
33 |     score_point = score_fn(*input_tuple,**score_fn_kwargs) if input_tuple2 is None else score_fn(*input_tuple,**score_fn_kwargs)-score_fn(*input_tuple2,**score_fn_kwargs)
34 | 
35 |     if(n_iterations==0):
36 |         return score_point, np.zeros(np.shape(score_point)), np.zeros(np.shape(score_point)), []  # np.shape also handles scalar scores
37 |     
38 |     if(ids is None):
39 |         ids = []
40 |         for _ in range(n_iterations):
41 |             ids.append(resample(range(len(input_tuple[0])), n_samples=len(input_tuple[0])))
42 |         ids = np.array(ids)
43 | 
44 |     fn = partial(_eval,input_tuple=input_tuple,score_fn=score_fn,input_tuple2=input_tuple2,score_fn_kwargs=score_fn_kwargs)
45 | 
46 |     if(threads is not None and threads==0):
47 |         # no multiprocessing: evaluate the bootstrap samples one after another
48 |         results = [fn(i) for i in tqdm(ids)]
49 |     else:
50 |         results = []
51 |         with Pool(threads) as pool:  # create the worker pool once and reuse it for all chunks
52 |             for istart in tqdm(np.arange(0, len(ids), chunksize)):
53 |                 iend = min(len(ids), istart+chunksize)
54 |                 results += pool.map(fn, ids[istart:iend])
55 | 
56 |     # shape: bootstrap_iterations, number_of_evaluation_metrics
57 |     results = np.array(results).astype(np.float32)
58 |     
59 |     percentile_fn = np.nanpercentile if ignore_nans else np.percentile        
60 |     score_diff = np.array(results)- score_point
61 |     score_low = score_point + percentile_fn(score_diff, ((1.0-alpha)/2.0) * 100,axis=0)
62 |     score_high = score_point + percentile_fn(score_diff, (alpha+((1.0-alpha)/2.0)) * 100,axis=0)
63 | 
64 |     if(ignore_nans):#in this case return the number of nans in each score rather than the sampled ids (which could be different when evaluating several metrics at once)
65 |         return score_point, score_low, score_high, np.sum(np.isnan(score_diff),axis=0)
66 |     else:
67 |         return score_point, score_low, score_high, ids
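A note on the interval construction in `empirical_bootstrap`: percentiles are shift-equivariant, so `score_point + percentile(score_diff, q)` equals `percentile(results, q)`, i.e. the returned bounds coincide with the percentile bootstrap. The empirical bootstrap of the cited MIT 18.05 reading instead reflects the deviations and uses `[score_point - q_high, score_point - q_low]`, where `q_low` and `q_high` are the lower and upper quantiles of `score_diff`; both constructions agree when the deviations are symmetric around zero. The standalone sketch below implements the reading's variant for a single NumPy array. `empirical_bootstrap_ci` is a hypothetical helper, not part of the module, and it draws indices with `numpy.random.default_rng` instead of `sklearn.utils.resample`.

```python
import numpy as np

def empirical_bootstrap_ci(x, score_fn, n_iterations=1000, alpha=0.95, seed=0):
    """Empirical (basic) bootstrap CI for score_fn(x) as in the MIT 18.05 reading (hypothetical sketch).

    Returns (point estimate, lower bound, upper bound).
    """
    rng = np.random.default_rng(seed)
    n = len(x)
    score_point = score_fn(x)
    # each row holds the indices of one bootstrap resample, drawn with replacement
    ids = rng.integers(0, n, size=(n_iterations, n))
    scores = np.array([score_fn(x[i]) for i in ids])
    # deviations of the bootstrap scores from the point estimate
    diff = scores - score_point
    q_low = np.percentile(diff, (1.0 - alpha) / 2.0 * 100, axis=0)
    q_high = np.percentile(diff, (alpha + (1.0 - alpha) / 2.0) * 100, axis=0)
    # empirical bootstrap: the upper deviation quantile gives the lower bound and vice versa
    return score_point, score_point - q_high, score_point - q_low

x = np.random.default_rng(42).normal(loc=1.0, scale=1.0, size=200)
point, low, high = empirical_bootstrap_ci(x, np.mean)
print(f"mean={point:.3f}, 95% CI=[{low:.3f}, {high:.3f}]")
```

For a roughly symmetric statistic such as the mean of normal data, both constructions give nearly the same interval; they can differ noticeably for skewed bootstrap distributions, e.g. for scores close to a bound such as an AUC near 1.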
--------------------------------------------------------------------------------
/src/clinical_ts/utils/callbacks.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['ForwardHook','TriggerQuantizerHyperparameterUpdate','UnfreezingFinetuningCallback','LRMonitorCallback', 'cos_anneal', 'DecayLR',"freeze_bn_stats","sanity_check"]
  2 | 
  3 | import torch
  4 | from torch import nn
  5 | import lightning.pytorch as lp
  6 | import math
  7 | 
  8 | from lightning.pytorch.callbacks import Callback, BaseFinetuning
  9 | 
 10 | class ForwardHook:
 11 |     "Create a forward hook on module `m` "
 12 | 
 13 |     def __init__(self, m, store_output=True):
 14 |         self.store_output = store_output
 15 |         self.hook = m.register_forward_hook(self.hook_fn)
 16 |         self.stored, self.removed = None, False
 17 | 
 18 |     def hook_fn(self, module, input, output):
 19 |         "stores input/output"
 20 |         if self.store_output:
 21 |             self.stored = output
 22 |         else:
 23 |             self.stored = input
 24 | 
 25 |     def remove(self):
 26 |         "Remove the hook from the model."
 27 |         if not self.removed:
 28 |             self.hook.remove()
 29 |             self.removed = True
 30 | 
 31 |     def __enter__(self, *args):
 32 |         return self
 33 | 
 34 |     def __exit__(self, *args):
 35 |         self.remove()
 36 | 
 37 | class TriggerQuantizerHyperparameterUpdate(Callback):
 38 |     def __init__(self,quantizer_modules):
 39 |         super(TriggerQuantizerHyperparameterUpdate, self).__init__()
 40 |         self.modules = quantizer_modules
 41 |         
 42 |     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
 43 |         for m in self.modules:
 44 |             m.update_hyperparams(trainer.global_step)
 45 | 
 46 | class UnfreezingFinetuningCallback(BaseFinetuning):
 47 | 
 48 |     def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
 49 |         super().__init__()
 50 |         self.unfreeze_epoch = unfreeze_epoch
 51 |         self.train_bn = train_bn
 52 | 
 53 |     def freeze_before_training(self, pl_module: lp.LightningModule):
 54 |         modules = pl_module.get_params(modules=True)
 55 |         for mod in modules[1:]:
 56 |             self.freeze(mod["params"], train_bn=self.train_bn)
 57 |             
 58 |     def finetune_function(self, pl_module: lp.LightningModule, epoch: int, optimizer: torch.optim.Optimizer, opt_idx: int = 0):
 59 |         """Called at the start of every epoch (opt_idx defaults to 0 because Lightning 2.x no longer passes it)."""
 60 |         if epoch == self.unfreeze_epoch:
 61 |             modules = pl_module.get_params(modules=True)
 62 |             for mod in modules[1:]:
 63 |                 self.unfreeze_and_add_param_group(
 64 |                     mod["params"],
 65 |                     optimizer,
 66 |                     lr=mod["lr"]*optimizer.param_groups[0]["lr"]/modules[0]["lr"],
 67 |                     train_bn=self.train_bn,
 68 |                 )
 69 | 
 70 | class LRMonitorCallback(Callback):
 71 |     def __init__(self,interval="epoch",start=True,end=True):
 72 |         super().__init__()
 73 |         self.interval = interval
 74 |         self.start = start
 75 |         self.end = end
 76 | 
 77 |     def _print_lrs(self, trainer):  # print the LRs of all parameter groups of the first optimizer
 78 |         current_lrs = [d['lr'] for d in trainer.optimizers[0].param_groups]
 79 |         print(f'Epoch: {trainer.current_epoch} Step: {trainer.global_step} LRs:', current_lrs)
 80 | 
 81 |     def on_train_batch_start(self, trainer, *args, **kwargs):
 82 |         if(self.interval == "step" and self.start):
 83 |             self._print_lrs(trainer)
 84 | 
 85 |     def on_train_epoch_start(self, trainer, *args, **kwargs):
 86 |         if(self.interval == "epoch" and self.start):
 87 |             self._print_lrs(trainer)
 88 | 
 89 |     def on_train_batch_end(self, trainer, *args, **kwargs):
 90 |         if(self.interval == "step" and self.end):
 91 |             self._print_lrs(trainer)
 92 | 
 93 |     def on_train_epoch_end(self, trainer, *args, **kwargs):
 94 |         if(self.interval == "epoch" and self.end):
 95 |             self._print_lrs(trainer)
 96 | 
 97 | ############################################################################################################
 98 | def freeze_bn_stats(model, freeze=True):
 99 |     for m in model.modules():
100 |         if(isinstance(m,nn.BatchNorm1d)):
101 |             if(freeze):
102 |                 m.eval()
103 |             else:
104 |                 m.train()
105 | 
106 | ############################################################################################################
107 | def sanity_check(model, state_dict_pre):
108 |     """
109 |     Linear classifier should not change any weights other than the linear layer.
110 |     This sanity check asserts nothing wrong happens (e.g., BN stats updated).
111 |     """
112 |     print("=> loading state dict for sanity check")
113 |     state_dict = model.state_dict()
114 | 
115 |     for k in list(state_dict.keys()):
116 |         # only ignore fc layer
117 |         if 'head.1.weight' in k or 'head.1.bias' in k or 'head.4.weight' in k or 'head.4.bias' in k:
118 |             continue
119 | 
120 | 
121 |         assert ((state_dict[k].cpu() == state_dict_pre[k].cpu()).all()), \
122 |             '{} is changed in linear classifier training.'.format(k)
123 | 
124 |     print("=> sanity check passed.")
125 | 
126 | ############################################################################################################
127 | #from https://github.com/karpathy/deep-vector-quantization/blob/main/dvq/vqvae.py
128 | # -----------------------------------------------------------------------------
129 | def cos_anneal(e0, e1, t0, t1, e):
130 |     """ ramp from (e0, t0) -> (e1, t1) through a cosine schedule based on e in [e0, e1] """
131 |     alpha = max(0, min(1, (e - e0) / (e1 - e0))) # what fraction of the way through are we
132 |     alpha = 1.0 - math.cos(alpha * math.pi/2) # warp through cosine
133 |     t = alpha * t1 + (1 - alpha) * t0 # interpolate accordingly
134 |     return t
135 | 
136 | class DecayLR(Callback):
137 |     def __init__(self,num_steps=1200000,lrstart=3e-4,lrend=1.25e-6):
138 |         super(DecayLR, self).__init__()
139 |         self.num_steps = num_steps
140 |         self.lrstart = lrstart
141 |         self.lrend = lrend
142 | 
143 |     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
144 |         # The step size is annealed from 1e-4 to 1.25e-6 over 1,200,000 updates (default lrstart here: 3e-4)
145 |         t = cos_anneal(0, self.num_steps, self.lrstart, self.lrend, trainer.global_step)
146 |         for g in pl_module.model_cpc.optimizer.param_groups:
147 |             g['lr'] = t
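`DecayLR` sets the learning rate of every parameter group to `cos_anneal(0, num_steps, lrstart, lrend, global_step)`. The warp `1 - cos(alpha * pi/2)` has zero slope at alpha = 0 and its steepest slope at alpha = 1, so the learning rate stays close to `lrstart` early on and drops fastest towards the end of the schedule; beyond `num_steps` it is clipped at `lrend`. The sketch below reproduces `cos_anneal` so that it runs standalone (no Lightning involved) and evaluates it at a few global steps with the default arguments of `DecayLR`.

```python
import math

def cos_anneal(e0, e1, t0, t1, e):
    """Reproduction of cos_anneal above: ramp from (e0, t0) to (e1, t1) along a quarter cosine."""
    alpha = max(0, min(1, (e - e0) / (e1 - e0)))  # fraction of the way through, clipped to [0, 1]
    alpha = 1.0 - math.cos(alpha * math.pi / 2)   # warp: flat near e0, steepest near e1
    return alpha * t1 + (1 - alpha) * t0           # interpolate between t0 and t1

# learning rates DecayLR would set with its defaults (num_steps=1200000, lrstart=3e-4, lrend=1.25e-6)
num_steps, lrstart, lrend = 1_200_000, 3e-4, 1.25e-6
steps = [0, 300_000, 600_000, 900_000, 1_200_000, 1_500_000]
lrs = [cos_anneal(0, num_steps, lrstart, lrend, s) for s in steps]
for s, lr in zip(steps, lrs):
    print(f"global_step={s:>9}: lr={lr:.3e}")
```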
--------------------------------------------------------------------------------
/src/clinical_ts/utils/eval_utils_cafa.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['auc_prrc_uninterpolated', 'multiclass_roc_curve', 'single_eval_prrc', 'eval_prrc', 'eval_prrc_parallel',
  2 |            'eval_scores', 'eval_scores_bootstrap']
  3 | 
  4 | # Cell
  5 | import warnings
  6 | import numpy as np
  7 | import pandas as pd
  8 | from sklearn.metrics import roc_auc_score, auc
  9 | from scipy.interpolate import interp1d
 10 | 
 11 | from sklearn.metrics import roc_curve, precision_recall_curve
 12 | from sklearn.utils import resample
 13 | 
 14 | from tqdm import tqdm
 15 | 
 16 | # Cell
 17 | def auc_prrc_uninterpolated(recall,precision):
 18 |     '''uninterpolated auc as used by sklearn https://github.com/scikit-learn/scikit-learn/blob/1495f6924/sklearn/metrics/ranking.py see also the discussion at https://github.com/scikit-learn/scikit-learn/pull/9583'''
 19 |     #print(-np.sum(np.diff(recall) * np.array(precision)[:-1]),auc(recall,precision))
 20 |     return -np.sum(np.diff(recall) * np.array(precision)[:-1])
 21 | 
 22 | # Cell
 23 | #label-centric metrics
 24 | def multiclass_roc_curve(y_true, y_pred, classes=None, precision_recall=False):
 25 |     '''Compute ROC curve and ROC area for each class "0"..."n_classes - 1" (or classnames passed via classes), "micro", "macro"
 26 |     returns fpr,tpr,roc (dictionaries) for ROC
 27 |     returns recall,precision,average_precision for precision_recall
 28 |     '''
 29 | 
 30 |     fpr = dict()
 31 |     tpr = dict()
 32 |     roc_auc = dict()
 33 |     n_classes=len(y_pred[0])
 34 |     if(classes is None):
 35 |         classes = [str(i) for i in range(n_classes)]
 36 | 
 37 |     for i,c in enumerate(classes):
 38 |         y_truei = y_true[:, i]
 39 |         y_predi = y_pred[:, i]
 40 | 
 41 |         maski = ~np.isnan(y_truei)#mask out nan targets if available
 42 |         y_truei = y_truei[maski]
 43 |         y_predi = y_predi[maski]
 44 | 
 45 |         if(precision_recall):
 46 |             tpr[c], fpr[c], _ = precision_recall_curve(y_truei, y_predi)
 47 |             roc_auc[c] = auc_prrc_uninterpolated(fpr[c], tpr[c])
 48 |         else:
 49 |             fpr[c], tpr[c], _ = roc_curve(y_truei, y_predi)
 50 |             roc_auc[c] = auc(fpr[c], tpr[c])
 51 | 
 52 |     # Compute micro-average curve and area
 53 |     y_true_micro = y_true.ravel()
 54 |     y_pred_micro = y_pred.ravel()
 55 |     mask_micro = ~np.isnan(y_true_micro)
 56 |     y_true_micro = y_true_micro[mask_micro]
 57 |     y_pred_micro = y_pred_micro[mask_micro]
 58 | 
 59 |     if(precision_recall):
 60 |         tpr["micro"], fpr["micro"], _ = precision_recall_curve(y_true_micro, y_pred_micro)
 61 |         roc_auc["micro"] = auc_prrc_uninterpolated(fpr["micro"], tpr["micro"])
 62 |     else:
 63 |         fpr["micro"], tpr["micro"], _ = roc_curve(y_true_micro, y_pred_micro)
 64 |         roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
 65 | 
 66 |     # Compute macro-average curve and area (linear interpolation is incorrect for PRRC, therefore just for ROC)
 67 |     if(precision_recall is False):
 68 |         # 1. First aggregate all unique x values (false positive rates for ROC)
 69 |         all_fpr = np.unique(np.concatenate([fpr[c] for c in classes]))
 70 | 
 71 |     # 2. Then interpolate all curves at these points
 72 |         mean_tpr=None
 73 |         for c in classes:
 74 |             f = interp1d(fpr[c], tpr[c])
 75 |             if(mean_tpr is None):
 76 |                 mean_tpr = f(all_fpr)
 77 |             else:
 78 |                 mean_tpr += f(all_fpr)
 79 | 
 80 |         # 3. Finally average it and compute area
 81 |         mean_tpr /= n_classes
 82 | 
 83 |         fpr["macro"] = all_fpr
 84 |         tpr["macro"] = mean_tpr
 85 |         #macro2 differs slightly from macro due to interpolation effects
 86 |         #roc_auc["macro2"] = auc(fpr["macro"], tpr["macro"])
 87 | 
 94 |     #calculate macro auc directly by averaging the per-class aucs
 95 |     roc_auc_macro = 0
 96 |     macro_auc_nans = 0 #due to an insufficient amount of pos/neg labels
 97 |     for c in classes:
 98 |         if(np.isnan(roc_auc[c])):#conservative choice: replace auc by 0.5 if it could not be calculated
 99 |             roc_auc_macro += 0.5
100 |             macro_auc_nans += 1
101 |         else:
102 |             roc_auc_macro += roc_auc[c]
103 |     roc_auc["macro"]=roc_auc_macro/n_classes
104 |     #roc_auc["macro_nans"] = macro_auc_nans
105 | 
106 |     return fpr, tpr, roc_auc
107 | 
108 | # Cell
109 | def single_eval_prrc(y_true,y_pred,threshold):
110 |     '''evaluate instance-wise scores for a single sample and a single threshold'''
111 |     y_pred_bin = (y_pred >= threshold)
112 |     TP = np.sum(np.logical_and(y_true == y_pred_bin,y_true>0))
113 |     count = np.sum(y_pred_bin)#TP+FP
114 | 
115 |     # Find precision: TP / (TP + FP)
116 |     precision = TP / count if count > 0 else np.nan
117 |     # Find recall/TPR/sensitivity: TP / (TP + FN)
118 |     recall = TP/np.sum(y_true>0)
119 |     # Find FPR (= 1 - specificity): FP / (FP + TN) = FP / N; note that this value is returned as `specificity`
120 |     FP = np.sum(np.logical_and(y_true != y_pred_bin,y_pred_bin>0))
121 |     specificity = FP/ np.sum(y_true==0)
122 |     return precision, recall, specificity
123 | 
124 | # Cell
125 | def eval_prrc(y_true,y_pred,threshold):
126 |     '''eval instance-wise scores across all samples for a single threshold'''
127 |     # Initialize Variables
128 |     PR = 0.0
129 |     RC = 0.0
130 |     SP = 0.0
131 | 
132 |     counts_above_threshold = 0
133 | 
134 |     for i in range(len(y_true)):
135 |         pr,rc,sp = single_eval_prrc(y_true[i],y_pred[i],threshold)
136 |         if not np.isnan(pr):
137 |             PR += pr
138 |             counts_above_threshold += 1
139 |         RC += rc
140 |         SP += sp
141 | 
142 |     recall = RC/len(y_true)
143 |     specificity = SP/len(y_true)
144 | 
145 |     if counts_above_threshold > 0:
146 |         precision = PR/counts_above_threshold
147 |     else:
148 |         precision = np.nan
149 |         if(threshold<1.0):
150 |             print("No prediction is made above the %.2f threshold\n" % threshold)
151 |     return precision, recall, specificity, counts_above_threshold/len(y_true)
152 | 
153 | # Cell
154 | def eval_prrc_parallel(y_true,y_pred,thresholds):
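155 |     '''vectorized counterpart of eval_prrc: evaluates instance-wise scores across all samples for an array of thresholds at once'''
156 |     y_pred_bin = y_pred[None, :, :] >= np.asarray(thresholds)[:, None, None]#thresholds, samples, classes (via broadcasting)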
157 |     TP = np.sum(np.logical_and( y_true == True, y_pred_bin== True),axis=2)#threshold, samples
158 | 
159 |     with np.errstate(divide='ignore', invalid='ignore'):
160 |         den = np.sum(y_pred_bin,axis=2)>0
161 |         precision = TP/np.sum(y_pred_bin,axis=2)
162 |         precision[den==0] = np.nan
163 | 
164 |     recall = TP/np.sum(y_true==True, axis=1)#threshold,samples/samples=threshold,samples
165 | 
166 |     FP = np.sum(np.logical_and((y_true ==False),(y_pred_bin==True)),axis=2)
167 |     specificity = FP/np.sum(y_true==False, axis=1)
168 | 
169 |     with warnings.catch_warnings(): #for nan slices
170 |         warnings.simplefilter("ignore", category=RuntimeWarning)
171 |         av_precision = np.nanmean(precision,axis=1)
172 | 
173 |     av_recall = np.mean(recall,axis=1)
174 |     av_specificity = np.mean(specificity,axis=1)
175 |     av_coverage = np.mean(den,axis=1)
176 | 
177 |     return av_precision, av_recall, av_specificity, av_coverage
178 | 
179 | 
180 | # Cell
181 | def eval_scores(y_true,y_pred,classes=None,num_thresholds=100,full_output=False,parallel=True):
182 |     '''returns a dictionary of performance metrics:
183 |     sample-centric, cf. https://github.com/ashleyzhou972/CAFA_assessment_tool/blob/master/precrec/precRec.py
184 |     https://www.nature.com/articles/nmeth.2340 vs https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3694662/ and https://arxiv.org/pdf/1601.00891
185 |     * Fmax, sample AUC, sample Average Precision (as in sklearn); currently disabled (commented out below)
186 | 
187 |     label-centric: micro, macro and individual AUC (returned as "label_AUC"); label-centric Average Precision is currently disabled (commented out below)
188 |     '''
189 |     results = {}
190 | 
191 |     # thresholds = np.arange(0.00, 1.01, 1./num_thresholds, float)
192 |     # if(parallel is False):
193 |     #     PR = np.zeros(len(thresholds))
194 |     #     RC = np.zeros(len(thresholds))
195 |     #     SP = np.zeros(len(thresholds))
196 |     #     COV = np.zeros(len(thresholds))
197 | 
198 |     #     for i,t in enumerate(thresholds):
199 |     #         PR[i],RC[i],SP[i],COV[i] = eval_prrc(y_true,y_pred,t)
200 |     #     F =  (2*PR*RC)/(PR+RC)
201 |     # else:
202 |     #     PR,RC,SP,COV = eval_prrc_parallel(y_true,y_pred,thresholds)
203 |     #     F = (2*PR*RC)/(PR+RC)
204 | 
205 |     # if(full_output is True):
206 |     #     results["PR"] = PR
207 |     #     results["RC"] = RC
208 |     #     results["SP"] = SP
209 |     #     results["F"] = F
210 |     #     results["COV"] = COV
211 | 
212 |     # if np.isnan(F).sum() == len(F):
213 |     #     results["Fmax"] = 0
214 |     #     results["precision_at_Fmax"] = 0
215 |     #     results["recall_at_Fmax"] = 0
216 |     #     results["threshold_at_Fmax"] = 0
217 |     #     results["coverage_at_Fmax"]= 0
218 |     # else:
219 |     #     imax = np.nanargmax(F)
220 |     #     results["Fmax"] = F[imax]
221 |     #     results["precision_at_Fmax"] = PR[imax]
222 |     #     results["recall_at_Fmax"] = RC[imax]
223 |     #     results["threshold_at_Fmax"] = thresholds[imax]
224 |     #     results["coverage_at_Fmax"]=COV[imax]
225 | 
226 |     # results["sample_AUC"]=auc(1-SP,RC)
227 |     # #https://github.com/scikit-learn/scikit-learn/blob/1495f6924/sklearn/metrics/ranking.py set final PR value to 1
228 |     # PR[-1]=1
229 |     # results["sample_APR"]=auc_prrc_uninterpolated(RC,PR)#skip last point with undefined precision
230 |     ###########################################################
231 |     #label-centric
232 |     #"micro","macro",i=0...n_classes-1
233 |     fpr, tpr, roc_auc = multiclass_roc_curve(y_true, y_pred,classes=classes,precision_recall=False)
234 |     if(full_output is True):
235 |         results["fpr"]=fpr
236 |         results["tpr"]=tpr
237 |     results["label_AUC"]=roc_auc
238 | 
239 |     # rc, pr, prrc_auc = multiclass_roc_curve(y_true, y_pred,classes=classes,precision_recall=True)
240 |     # if(full_output is True):
241 |     #     results["pr"]=pr
242 |     #     results["rc"]=rc
243 |     # results["label_APR"]=prrc_auc
244 | 
245 |     return results
246 | 
247 | # Cell
248 | def eval_scores_bootstrap(y_true, y_pred,classes=None, n_iterations = 10000, alpha=0.95):
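249 |     '''bootstrap confidence intervals for the metrics of eval_scores, cf. https://ocw.mit.edu/courses/mathematics/18-05-introduction-to-probability-and-statistics-spring-2014/readings/MIT18_05S14_Reading24.pdf (empirical bootstrap rather than bootstrap percentiles); requires the "Fmax", "sample_AUC", "sample_APR" and "label_APR" entries, which are currently commented out in eval_scores, so calling it as is raises a KeyError'''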
250 |     Fmax_diff = []
251 |     sample_AUC_diff = []
252 |     sample_APR_diff = []
253 |     label_AUC_diff = []
254 |     label_APR_diff = []
255 |     label_AUC_keys = None
256 | 
257 |     #point estimate
258 |     res_point = eval_scores(y_true,y_pred,classes=classes)
259 |     Fmax_point = res_point["Fmax"]
260 |     sample_AUC_point = res_point["sample_AUC"]
261 |     sample_APR_point = res_point["sample_APR"]
262 |     label_AUC_point = np.array(list(res_point["label_AUC"].values()))
263 |     label_APR_point = np.array(list(res_point["label_APR"].values()))
264 | 
265 |     #bootstrap
266 |     for i in tqdm(range(n_iterations)):
267 |         ids = resample(range(len(y_true)), n_samples=len(y_true))
268 |         res = eval_scores(y_true[ids],y_pred[ids],classes=classes)
269 |         Fmax_diff.append(res["Fmax"]-Fmax_point)
270 |         sample_AUC_diff.append(res["sample_AUC"]-sample_AUC_point)
271 |         sample_APR_diff.append(res["sample_APR"]-sample_APR_point)
272 |         label_AUC_keys = list(res["label_AUC"].keys())
273 |         label_AUC_diff.append(np.array(list(res["label_AUC"].values()))-label_AUC_point)
274 |         label_APR_diff.append(np.array(list(res["label_APR"].values()))-label_APR_point)
275 | 
276 |     p = ((1.0-alpha)/2.0) * 100
277 |     Fmax_low = Fmax_point + np.percentile(Fmax_diff, p)
278 |     sample_AUC_low = sample_AUC_point + np.percentile(sample_AUC_diff, p)
279 |     sample_APR_low = sample_APR_point + np.percentile(sample_APR_diff, p)
280 |     label_AUC_low = label_AUC_point + np.percentile(label_AUC_diff,p,axis=0)
281 |     label_APR_low = label_APR_point + np.percentile(label_APR_diff,p,axis=0)
282 |     p = (alpha+((1.0-alpha)/2.0)) * 100
283 |     Fmax_high = Fmax_point + np.percentile(Fmax_diff, p)
284 |     sample_AUC_high = sample_AUC_point + np.percentile(sample_AUC_diff, p)
285 |     sample_APR_high = sample_APR_point + np.percentile(sample_APR_diff, p)
286 |     label_AUC_high = label_AUC_point + np.percentile(label_AUC_diff,p,axis=0)
287 |     label_APR_high = label_APR_point + np.percentile(label_APR_diff,p,axis=0)
288 | 
289 |     return {"Fmax":[Fmax_low,Fmax_point,Fmax_high], "sample_AUC":[sample_AUC_low,sample_AUC_point,sample_AUC_high], "sample_APR":[sample_APR_low,sample_APR_point,sample_APR_high], "label_AUC":{k:[v1,v2,v3] for k,v1,v2,v3 in zip(label_AUC_keys,label_AUC_low,label_AUC_point,label_AUC_high)}, "label_APR":{k:[v1,v2,v3] for k,v1,v2,v3 in zip(label_AUC_keys,label_APR_low,label_APR_point,label_APR_high)}}
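The sample-centric Fmax that the commented-out part of `eval_scores` computes builds on the outputs of `eval_prrc_parallel`: precision is averaged only over samples with at least one prediction above the threshold, recall over all samples, and Fmax is the maximum of F = 2·PR·RC/(PR+RC) over all thresholds. The sketch below condenses these steps into one NumPy function. `fmax_score` is a hypothetical name, not part of the module; it assumes every sample has at least one positive label (recall divides by the number of positives) and returns `(0.0, None)` when F is undefined at every threshold, mirroring the `Fmax = 0` fallback in the commented-out code.

```python
import numpy as np

def fmax_score(y_true, y_pred, thresholds):
    """Sample-centric Fmax following the logic of eval_prrc_parallel (hypothetical sketch).

    Returns (Fmax, threshold at Fmax), or (0.0, None) if F is undefined for every threshold.
    """
    y_true = np.asarray(y_true, dtype=bool)                 # samples, classes
    thresholds = np.asarray(thresholds, dtype=float)
    # binarized predictions for all thresholds at once: thresholds, samples, classes
    y_pred_bin = np.asarray(y_pred)[None, :, :] >= thresholds[:, None, None]
    tp = np.sum(y_pred_bin & y_true[None], axis=2)          # thresholds, samples
    n_pred = np.sum(y_pred_bin, axis=2)                     # TP + FP per sample
    covered = n_pred > 0
    n_covered = covered.sum(axis=1)
    # precision is averaged only over samples with at least one prediction (coverage)
    precision = np.where(covered, tp / np.maximum(n_pred, 1), 0.0)
    av_precision = np.where(n_covered > 0, precision.sum(axis=1) / np.maximum(n_covered, 1), np.nan)
    # recall is averaged over all samples
    av_recall = np.mean(tp / y_true.sum(axis=1), axis=1)
    with np.errstate(invalid="ignore"):                     # 0/0 when precision and recall are both 0
        f = 2 * av_precision * av_recall / (av_precision + av_recall)
    if np.all(np.isnan(f)):
        return 0.0, None
    i = np.nanargmax(f)
    return float(f[i]), float(thresholds[i])

y_true = [[1, 0, 0], [0, 1, 1], [1, 1, 0]]
y_pred = [[0.9, 0.2, 0.1], [0.3, 0.8, 0.6], [0.7, 0.4, 0.2]]
fmax, thr = fmax_score(y_true, y_pred, np.linspace(0, 1, 101))
print(fmax, thr)
```

In this toy example every threshold in (0.3, 0.4] selects exactly the positive labels of each sample, so Fmax is 1.0; at threshold 0.0 every label is predicted, which gives an average precision of 5/9, a recall of 1 and F = 5/7.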
--------------------------------------------------------------------------------
/src/clinical_ts/utils/schedulers.py:
--------------------------------------------------------------------------------
  1 | __all__ = ['get_constant_schedule','get_constant_schedule_with_warmup','get_linear_schedule_with_warmup','get_cosine_schedule_with_warmup', 'get_cosine_with_hard_restarts_schedule_with_warmup', 'get_polynomial_decay_schedule_with_warmup', 'get_invsqrt_decay_schedule_with_warmup']
  2 | 
  3 | #adapted from https://huggingface.co/transformers/_modules/transformers/optimization.htm
  4 | import math
  5 | from typing import Callable, Iterable, Optional, Tuple, Union
  6 | 
  7 | import torch
  8 | from torch.optim import Optimizer
  9 | from torch.optim.lr_scheduler import LambdaLR
 10 | 
 11 | def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
 12 |     """
 13 |     Create a schedule with a constant learning rate, using the learning rate set in optimizer.
 14 | 
 15 |     Args:
 16 |         optimizer (:class:`~torch.optim.Optimizer`):
 17 |             The optimizer for which to schedule the learning rate.
 18 |         last_epoch (:obj:`int`, `optional`, defaults to -1):
 19 |             The index of the last epoch when resuming training.
 20 | 
 21 |     Return:
 22 |         :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
 23 |     """
 24 |     return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
 25 | 
 26 | 
 27 | 
 28 | def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
 29 |     """
 30 |     Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
 31 |     increases linearly between 0 and the initial lr set in the optimizer.
 32 | 
 33 |     Args:
 34 |         optimizer (:class:`~torch.optim.Optimizer`):
 35 |             The optimizer for which to schedule the learning rate.
 36 |         num_warmup_steps (:obj:`int`):
 37 |             The number of steps for the warmup phase.
 38 |         last_epoch (:obj:`int`, `optional`, defaults to -1):
 39 |             The index of the last epoch when resuming training.
 40 | 
 41 |     Return:
 42 |         :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
 43 |     """
 44 | 
 45 |     def lr_lambda(current_step: int):
 46 |         if current_step < num_warmup_steps:
 47 |             return float(1+current_step) / float(max(1.0, 1+num_warmup_steps))
 48 |         return 1.0
 49 | 
 50 |     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
 51 | 
 52 | 
 53 | 
 54 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
 55 |     """
 56 |     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
 57 |     a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
 58 | 
 59 |     Args:
 60 |         optimizer (:class:`~torch.optim.Optimizer`):
 61 |             The optimizer for which to schedule the learning rate.
 62 |         num_warmup_steps (:obj:`int`):
 63 |             The number of steps for the warmup phase.
 64 |         num_training_steps (:obj:`int`):
 65 |             The total number of training steps.
 66 |         last_epoch (:obj:`int`, `optional`, defaults to -1):
 67 |             The index of the last epoch when resuming training.
 68 | 
 69 |     Return:
 70 |         :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
 71 |     """
 72 | 
 73 |     def lr_lambda(current_step: int):
 74 |         if current_step < num_warmup_steps:
 75 |             return float(1+current_step) / float(max(1, 1+num_warmup_steps))
 76 |         return max(
 77 |             0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
 78 |         )
 79 | 
 80 |     return LambdaLR(optimizer, lr_lambda, last_epoch)
 81 | 
 82 | 
 83 | 
 84 | def get_cosine_schedule_with_warmup(
 85 |     optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
 86 | ):
 87 |     """
 88 |     Create a schedule with a learning rate that decreases following the values of the cosine function between the
 89 |     initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
 90 |     initial lr set in the optimizer.
 91 | 
 92 |     Args:
 93 |         optimizer (:class:`~torch.optim.Optimizer`):
 94 |             The optimizer for which to schedule the learning rate.
 95 |         num_warmup_steps (:obj:`int`):
 96 |             The number of steps for the warmup phase.
 97 |         num_training_steps (:obj:`int`):
 98 |             The total number of training steps.
 99 |         num_cycles (:obj:`float`, `optional`, defaults to 0.5):
100 |             The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
101 |             following a half-cosine).
102 |         last_epoch (:obj:`int`, `optional`, defaults to -1):
103 |             The index of the last epoch when resuming training.
104 | 
105 |     Return:
106 |         :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
107 |     """
108 | 
109 |     def lr_lambda(current_step):
110 |         if current_step < num_warmup_steps:
111 |             return float(1+current_step) / float(max(1, 1+num_warmup_steps))
112 |         progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
113 |         return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
114 | 
115 |     return LambdaLR(optimizer, lr_lambda, last_epoch)
116 | 
117 | 
118 | 
119 | def get_cosine_with_hard_restarts_schedule_with_warmup(
120 |     optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
121 | ):
122 |     """
123 |     Create a schedule with a learning rate that decreases following the values of the cosine function between the
124 |     initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
125 |     linearly between 0 and the initial lr set in the optimizer.
126 | 
127 |     Args:
128 |         optimizer (:class:`~torch.optim.Optimizer`):
129 |             The optimizer for which to schedule the learning rate.
130 |         num_warmup_steps (:obj:`int`):
131 |             The number of steps for the warmup phase.
132 |         num_training_steps (:obj:`int`):
133 |             The total number of training steps.
134 |         num_cycles (:obj:`int`, `optional`, defaults to 1):
135 |             The number of hard restarts to use.
136 |         last_epoch (:obj:`int`, `optional`, defaults to -1):
137 |             The index of the last epoch when resuming training.
138 | 
139 |     Return:
140 |         :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
141 |     """
142 | 
143 |     def lr_lambda(current_step):
144 |         if current_step < num_warmup_steps:
145 |             return float(current_step) / float(max(1, num_warmup_steps))
146 |         progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
147 |         if progress >= 1.0:
148 |             return 0.0
149 |         return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
150 | 
151 |     return LambdaLR(optimizer, lr_lambda, last_epoch)
152 | 
153 | 
154 | 
155 | def get_polynomial_decay_schedule_with_warmup(
156 |     optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
157 | ):
158 |     """
159 |     Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
160 |     optimizer to end lr defined by `lr_end`, after a warmup period during which it increases linearly from 0 to the
161 |     initial lr set in the optimizer.
162 | 
163 |     Args:
164 |         optimizer (:class:`~torch.optim.Optimizer`):
165 |             The optimizer for which to schedule the learning rate.
166 |         num_warmup_steps (:obj:`int`):
167 |             The number of steps for the warmup phase.
168 |         num_training_steps (:obj:`int`):
169 |             The total number of training steps.
170 |         lr_end (:obj:`float`, `optional`, defaults to 1e-7):
171 |             The final learning rate, reached at `num_training_steps`.
172 |         power (:obj:`float`, `optional`, defaults to 1.0):
173 |             Exponent of the polynomial decay; 1.0 gives a linear decay.
174 |         last_epoch (:obj:`int`, `optional`, defaults to -1):
175 |             The index of the last epoch when resuming training.
176 | 
177 |     Note: `power` defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
178 |     implementation at
179 |     https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
180 | 
181 |     Return:
182 |         :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
183 | 
184 |     """
185 | 
186 |     lr_init = optimizer.defaults["lr"]
187 |     assert lr_init > lr_end, f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})"
188 | 
189 |     def lr_lambda(current_step: int):
190 |         if current_step < num_warmup_steps:
191 |             return float(1+current_step) / float(max(1, 1+num_warmup_steps))
192 |         elif current_step > num_training_steps:
193 |             return lr_end / lr_init  # as LambdaLR multiplies by lr_init
194 |         else:
195 |             lr_range = lr_init - lr_end
196 |             decay_steps = num_training_steps - num_warmup_steps
197 |             pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
198 |             decay = lr_range * pct_remaining ** power + lr_end
199 |             return decay / lr_init  # as LambdaLR multiplies by lr_init
200 | 
201 |     return LambdaLR(optimizer, lr_lambda, last_epoch)
202 | 
203 | def get_invsqrt_decay_schedule_with_warmup(
204 |     optimizer, num_warmup_steps, last_epoch=-1
205 | ):
206 |     """
207 |     Create a schedule with a learning rate that decreases following an inverse square-root law,
208 |     after a warmup period during which it increases linearly from 0 to the
209 |     initial lr set in the optimizer.
210 | 
211 |     Args:
212 |         optimizer (:class:`~torch.optim.Optimizer`):
213 |             The optimizer for which to schedule the learning rate.
214 |         num_warmup_steps (:obj:`int`):
215 |             The number of steps for the warmup phase.
216 |         last_epoch (:obj:`int`, `optional`, defaults to -1):
217 |             The index of the last epoch when resuming training.
218 | 
219 | 
220 |     Return:
221 |         :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
222 | 
223 |     """
224 |     # max(1, .) guards keep num_warmup_steps=0 from zeroing the lr or dividing by zero at step 0
225 |     decay_factor = max(1, num_warmup_steps) ** 0.5
226 |     lr_init = optimizer.defaults["lr"]
227 |     def lr_lambda(current_step: int):
228 |         if current_step < num_warmup_steps:
229 |             return float(1+current_step) / float(max(1, 1+num_warmup_steps))
230 |         else:
231 |             return decay_factor * max(1, current_step) ** -0.5
232 |             
233 |     return LambdaLR(optimizer, lr_lambda, last_epoch)
234 | 
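To see what these schedules look like without building an optimizer, the following is a standalone restatement of the three `lr_lambda` multipliers in plain Python. The helper names `cosine_hard_restarts`, `polynomial` and `inv_sqrt` are hypothetical, and `inv_sqrt` adds `max(1, .)` guards so that a zero warmup does not divide by zero. Treat it as a sketch for inspecting the curves, not as a replacement for the functions above; each value is the factor that `LambdaLR` multiplies the initial learning rate by.

```python
import math


# Hypothetical standalone restatement of the lr_lambda multipliers defined above.
# Each function returns the factor that LambdaLR multiplies the initial lr by.

def cosine_hard_restarts(step, warmup, total, cycles=1):
    if step < warmup:
        return step / max(1, warmup)  # linear warmup starting at 0
    progress = (step - warmup) / max(1, total - warmup)
    if progress >= 1.0:
        return 0.0
    # the cosine restarts from 1.0 each time cycles * progress crosses an integer
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((cycles * progress) % 1.0))))


def polynomial(step, warmup, total, lr_init, lr_end=1e-7, power=1.0):
    if step < warmup:
        return (1 + step) / max(1, 1 + warmup)
    if step > total:
        return lr_end / lr_init
    pct_remaining = 1 - (step - warmup) / (total - warmup)
    return ((lr_init - lr_end) * pct_remaining ** power + lr_end) / lr_init


def inv_sqrt(step, warmup):
    if step < warmup:
        return (1 + step) / max(1, 1 + warmup)
    # equals 1.0 at step == warmup, then decays as 1 / sqrt(step)
    return math.sqrt(max(1, warmup)) / math.sqrt(max(1, step))


if __name__ == "__main__":
    for s in (0, 50, 100, 550, 1000, 2000):
        print(s,
              round(cosine_hard_restarts(s, 100, 1000), 4),
              round(polynomial(s, 100, 1000, lr_init=1e-3), 4),
              round(inv_sqrt(s, 100), 4))
```

All three ramp up during warmup and reach a multiplier of 1.0 at the end of it; they differ only in how they decay afterwards.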
--------------------------------------------------------------------------------
/src/config/config_supervised_multimodal_labvalues_s4.yaml:
--------------------------------------------------------------------------------
 1 | defaults:
 2 |   - base: base
 3 |   - data@data0: mimic_labvalues
 4 |   - ts@ts0: tsenc
 5 |   - ts/enc@ts0.enc: none
 6 |   - ts/pred@ts0.pred: s4
 7 |   - ts/head@ts0.head: pool
 8 |   - ts/loss@ts0.loss: none
 9 |   - static: mlp
10 |   - head: concat
11 |   - loss: bcef
12 |   - trainer: trainer
13 |   - task: multi
14 |   - metric@metric0: aurocagg
15 |   - _self_
16 | 
17 | loss:
18 |   ignore_nans: True
19 | 
20 | base:
21 |   fs: 100.
22 |   input_size: 250
23 |   input_channels: 12
24 |   normalize: false
25 |   batch_size: 32
26 |   input_channels_cat: 2
27 |   input_channels_cont: 10
28 | 
29 | trainer:
30 |   gpus: 1
31 |   refresh_rate: 0
32 |   username: "nstrodt"
33 |   epochs: 40
34 |   precision: 32
35 | 
36 | ts0:
37 |   pred:
38 |     causal: False
39 |     state_dim: 8
40 |     model_dim: 512
41 |     backbone: "s42"
42 | 
43 | static:
44 |   vocab_sizes: [2,5]
45 |   embedding_dims: [16,32] 
46 |   lin_ftrs: [128,128,128] 
47 | 
48 | metric0:
49 |   bootstrap_iterations: 1000
50 | 
51 | task:
52 |   introduce_nan_columns: true
--------------------------------------------------------------------------------
/src/config/data/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/config/data/multimodal_mdsed.yaml:
--------------------------------------------------------------------------------
 1 | defaults:
 2 |   - base
 3 | 
 4 | name: "mimic_labvalues" 
 5 | path: "data/memmap"
 6 | fs: 100.
 7 | annotation: false
 8 | cols_static: ["cont_features"]
 9 | cols_static_cat: ["cat_features"]
10 | 
--------------------------------------------------------------------------------
/src/data/memmap/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/ecg_utils.py:
--------------------------------------------------------------------------------
  1 | import wfdb
  2 | 
  3 | import os
  4 | import shutil
  5 | import zipfile
  6 | 
  7 | import numpy as np
  8 | import pandas as pd
  9 | 
 10 | import resampy
 11 | from tqdm.auto import tqdm
 12 | from pathlib import Path
 13 | 
 14 | 
 15 | import datetime
 16 | 
 17 | from clinical_ts.timeseries_utils import *
 18 | 
 19 | from sklearn.model_selection import StratifiedKFold
 20 | 
 21 | channel_stoi_default = {"i": 0, "ii": 1, "v1":2, "v2":3, "v3":4, "v4":5, "v5":6, "v6":7, "iii":8, "avr":9, "avl":10, "avf":11, "vx":12, "vy":13, "vz":14}
 22 | 
 23 | def get_stratified_kfolds(labels,n_splits,random_state):
 24 |     skf = StratifiedKFold(n_splits=n_splits,shuffle=True,random_state=random_state)
 25 |     return skf.split(np.zeros(len(labels)),labels)
 26 | 
 27 | def resample_data(sigbufs, channel_labels, fs, target_fs, channels=12, channel_stoi=None):#,skimage_transform=True,interpolation_order=3):
 28 |     channel_labels = [c.lower() for c in channel_labels]
 29 |     #https://github.com/scipy/scipy/issues/7324 zoom issues
 30 |     factor = target_fs/fs
 31 |     timesteps_new = int(len(sigbufs)*factor)
 32 |     if(channel_stoi is not None):
 33 |         data = np.zeros((timesteps_new, channels), dtype=np.float32)
 34 |         for i,cl in enumerate(channel_labels):
 35 |             if(cl in channel_stoi.keys() and channel_stoi[cl]0 else tmp
 47 |     
 48 |     if(recreate_data):
 49 |         target_folder = Path(target_folder)
 50 |         target_folder.mkdir(parents=True, exist_ok=True)
 51 | 
 52 |         with zipfile.ZipFile(Path(data_path)/"mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0.zip", 'r') as archive:
 53 |             lst = archive.namelist()
 54 |             lst = [x for x in lst if x.endswith(".hea")]
 55 | 
 56 |             meta = []
 57 |             for l in tqdm(lst):
 58 |                 archive.extract(l, path="tmp_dir/")
 59 |                 archive.extract(l[:-3]+"dat", path="tmp_dir/")
 60 |                 filename = Path("tmp_dir")/l
 61 |                 sigbufs, header = wfdb.rdsamp(str(filename)[:-4])
 62 |             
 63 |                 tmp={}
 64 |                 tmp["data"]=filename.parent.parent.stem+"_"+filename.parent.stem+".npy" #patientid_study.npy
 65 |                 tmp["study_id"]=int(filename.stem)
 66 |                 tmp["subject_id"]=int(filename.parent.parent.stem[1:])
 67 |                 tmp['ecg_time']= datetime.datetime.combine(header["base_date"],header["base_time"])
 68 |                 tmp["nans"]= list(np.sum(np.isnan(sigbufs),axis=0))#save nans channel-dependent
 69 |                 if(np.sum(tmp["nans"])>0):#fix nans
 70 |                     fix_nans_and_clip(sigbufs,clip_amp=clip_amp)
 71 |                 elif(clip_amp>0):
 72 |                     sigbufs = np.clip(sigbufs,a_max=clip_amp,a_min=-clip_amp)
 73 | 
 74 |                 data = resample_data(sigbufs=sigbufs,channel_stoi=channel_stoi,channel_labels=header['sig_name'],fs=header['fs'],target_fs=target_fs,channels=channels)
 75 |                 
 76 |                 assert(target_fs<=header['fs'])
 77 |                 np.save(target_folder/tmp["data"],data)
 78 |                 meta.append(tmp)
 79 |                 
 80 |                 os.unlink("tmp_dir/"+l)
 81 |                 os.unlink("tmp_dir/"+l[:-3]+"dat")
 82 |                 shutil.rmtree("tmp_dir")
 83 | 
 84 |         df = pd.DataFrame(meta)
 85 | 
 86 |         #random split by patients
 87 |         #unique_patients = np.unique(df.subject_id)
 88 |         #splits_patients = get_stratified_kfolds(np.zeros(len(unique_patients)),n_splits=strat_folds,random_state=42)
 89 |         #df["fold"]=-1
 90 |         #for i,split in enumerate(splits_patients):
 91 |         #    df.loc[df.subject_id.isin(unique_patients[split[-1]]),"fold"]=i
 92 |         
 93 |         #add means and std
 94 |         dataset_add_mean_col(df,data_folder=target_folder)
 95 |         dataset_add_std_col(df,data_folder=target_folder)
 96 |         dataset_add_length_col(df,data_folder=target_folder)
 97 | 
 98 |         #save means and stds
 99 |         mean, std = dataset_get_stats(df)
100 | 
101 |         #save
102 |         lbl_itos=[]
103 |         save_dataset(df,lbl_itos,mean,std,target_folder)
104 |     else:
105 |         df, lbl_itos, mean, std = load_dataset(target_folder,df_mapped=False)
106 |     return df, lbl_itos, mean, std
107 | 
108 | 
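The listing of `ecg_utils.py` is truncated after the first lines of `resample_data`. The visible part shows the idea: lowercase the lead names, allocate a `(timesteps_new, channels)` float32 array, and copy each lead listed in `channel_stoi` into its fixed column. A minimal sketch of that channel mapping follows. It assumes linear interpolation via `np.interp` as a stand-in for `resampy` (it applies no anti-aliasing filter, so it is only illustrative for downsampling), and the bounds check `idx < channels` is a guess at the truncated condition; `map_and_resample` is a hypothetical name.

```python
import numpy as np


def map_and_resample(sigbufs, channel_labels, fs, target_fs, channel_stoi, channels=12):
    """Hypothetical sketch: copy each named lead into its fixed column, resampled to target_fs.

    np.interp (linear interpolation) stands in for resampy here and applies no
    anti-aliasing filter, so it is only illustrative when downsampling.
    """
    n_in = sigbufs.shape[0]
    n_out = int(n_in * target_fs / fs)
    t_in = np.arange(n_in) / fs           # sample times of the input, in seconds
    t_out = np.arange(n_out) / target_fs  # sample times of the output, in seconds
    data = np.zeros((n_out, channels), dtype=np.float32)
    for i, label in enumerate(c.lower() for c in channel_labels):
        idx = channel_stoi.get(label)
        if idx is not None and idx < channels:  # assumed bounds check
            data[:, idx] = np.interp(t_out, t_in, sigbufs[:, i])
    return data  # leads missing from the input stay all-zero


rng = np.random.default_rng(0)
sig = rng.standard_normal((500, 3))  # 1 s of a 3-lead signal at 500 Hz
stoi = {"i": 0, "ii": 1, "v1": 2}    # small subset of channel_stoi_default
out = map_and_resample(sig, ["I", "II", "aVR"], fs=500, target_fs=100,
                       channel_stoi=stoi, channels=3)
```

Here `aVR` has no entry in `stoi` and is dropped, while column 2 (`v1`) stays zero because the input has no V1 lead. A fixed lead-to-column layout means records with missing or reordered leads still share one array shape.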
--------------------------------------------------------------------------------
/src/environment.yml:
--------------------------------------------------------------------------------
  1 | name: lightning2
  2 | channels:
  3 |   - pytorch
  4 |   - nvidia
  5 |   - conda-forge
  6 |   - defaults
  7 | dependencies:
  8 |   - _libgcc_mutex=0.1=main
  9 |   - _openmp_mutex=5.1=1_gnu
 10 |   - blas=1.0=mkl
 11 |   - brotli-python=1.0.9=py312h6a678d5_7
 12 |   - bzip2=1.0.8=h7b6447c_0
 13 |   - ca-certificates=2023.11.17=hbcca054_0
 14 |   - certifi=2023.11.17=pyhd8ed1ab_0
 15 |   - cffi=1.16.0=py312h5eee18b_0
 16 |   - charset-normalizer=2.0.4=pyhd3eb1b0_0
 17 |   - colorama=0.4.6=pyhd8ed1ab_0
 18 |   - cryptography=41.0.7=py312hdda0065_0
 19 |   - cuda-cudart=12.1.105=0
 20 |   - cuda-cupti=12.1.105=0
 21 |   - cuda-libraries=12.1.0=0
 22 |   - cuda-nvrtc=12.1.105=0
 23 |   - cuda-nvtx=12.1.105=0
 24 |   - cuda-opencl=12.3.101=0
 25 |   - cuda-runtime=12.1.0=0
 26 |   - expat=2.5.0=h6a678d5_0
 27 |   - ffmpeg=4.3=hf484d3e_0
 28 |   - filelock=3.13.1=py312h06a4308_0
 29 |   - freetype=2.12.1=h4a9f257_0
 30 |   - fsspec=2023.12.2=pyhca7485f_0
 31 |   - giflib=5.2.1=h5eee18b_3
 32 |   - gmp=6.2.1=h295c915_3
 33 |   - gnutls=3.6.15=he1e5248_0
 34 |   - idna=3.4=py312h06a4308_0
 35 |   - intel-openmp=2023.1.0=hdb19cb5_46306
 36 |   - jinja2=3.1.2=py312h06a4308_0
 37 |   - jpeg=9e=h5eee18b_1
 38 |   - lame=3.100=h7b6447c_0
 39 |   - lcms2=2.12=h3be6417_0
 40 |   - ld_impl_linux-64=2.38=h1181459_1
 41 |   - lerc=3.0=h295c915_0
 42 |   - libcublas=12.1.0.26=0
 43 |   - libcufft=11.0.2.4=0
 44 |   - libcufile=1.8.1.2=0
 45 |   - libcurand=10.3.4.107=0
 46 |   - libcusolver=11.4.4.55=0
 47 |   - libcusparse=12.0.2.55=0
 48 |   - libdeflate=1.17=h5eee18b_1
 49 |   - libffi=3.4.4=h6a678d5_0
 50 |   - libgcc-ng=11.2.0=h1234567_1
 51 |   - libgomp=11.2.0=h1234567_1
 52 |   - libiconv=1.16=h7f8727e_2
 53 |   - libidn2=2.3.4=h5eee18b_0
 54 |   - libjpeg-turbo=2.0.0=h9bf148f_0
 55 |   - libnpp=12.0.2.50=0
 56 |   - libnvjitlink=12.1.105=0
 57 |   - libnvjpeg=12.1.1.14=0
 58 |   - libpng=1.6.39=h5eee18b_0
 59 |   - libstdcxx-ng=11.2.0=h1234567_1
 60 |   - libtasn1=4.19.0=h5eee18b_0
 61 |   - libtiff=4.5.1=h6a678d5_0
 62 |   - libunistring=0.9.10=h27cfd23_0
 63 |   - libuuid=1.41.5=h5eee18b_0
 64 |   - libwebp=1.3.2=h11a3e52_0
 65 |   - libwebp-base=1.3.2=h5eee18b_0
 66 |   - lightning=2.1.3=pyhd8ed1ab_1
 67 |   - lightning-utilities=0.10.1=pyhd8ed1ab_0
 68 |   - llvm-openmp=14.0.6=h9e868ea_0
 69 |   - lz4-c=1.9.4=h6a678d5_0
 70 |   - markupsafe=2.1.1=py312h5eee18b_0
 71 |   - mkl=2023.1.0=h213fc3f_46344
 72 |   - mkl-service=2.4.0=py312h5eee18b_1
 73 |   - mkl_fft=1.3.8=py312h5eee18b_0
 74 |   - mkl_random=1.2.4=py312hdb19cb5_0
 75 |   - mpmath=1.3.0=py312h06a4308_0
 76 |   - ncurses=6.4=h6a678d5_0
 77 |   - nettle=3.7.3=hbbd107a_1
 78 |   - networkx=3.1=py312h06a4308_0
 79 |   - numpy=1.26.3=py312hc5e2394_0
 80 |   - numpy-base=1.26.3=py312h0da6c21_0
 81 |   - openh264=2.1.1=h4ff587b_0
 82 |   - openjpeg=2.4.0=h3ad879b_0
 83 |   - openssl=3.0.12=h7f8727e_0
 84 |   - packaging=23.2=pyhd8ed1ab_0
 85 |   - pillow=10.0.1=py312ha6cbd5a_0
 86 |   - pip=23.3.1=py312h06a4308_0
 87 |   - pycparser=2.21=pyhd3eb1b0_0
 88 |   - pyopenssl=23.2.0=py312h06a4308_0
 89 |   - pysocks=1.7.1=py312h06a4308_0
 90 |   - python=3.12.1=h996f2a0_0
 91 |   - pytorch=2.2.0=py3.12_cuda12.1_cudnn8.9.2_0
 92 |   - pytorch-cuda=12.1=ha16c6d3_5
 93 |   - pytorch-lightning=2.1.3=pyhd8ed1ab_0
 94 |   - pytorch-mutex=1.0=cuda
 95 |   - pyyaml=6.0.1=py312h5eee18b_0
 96 |   - readline=8.2=h5eee18b_0
 97 |   - requests=2.31.0=py312h06a4308_0
 98 |   - setuptools=68.2.2=py312h06a4308_0
 99 |   - sqlite=3.41.2=h5eee18b_0
100 |   - sympy=1.12=py312h06a4308_0
101 |   - tbb=2021.8.0=hdb19cb5_0
102 |   - tk=8.6.12=h1ccaba5_0
103 |   - torchaudio=2.2.0=py312_cu121
104 |   - torchmetrics=1.2.1=pyhd8ed1ab_0
105 |   - torchvision=0.17.0=py312_cu121
106 |   - tqdm=4.66.1=pyhd8ed1ab_0
107 |   - typing-extensions=4.9.0=py312h06a4308_1
108 |   - typing_extensions=4.9.0=py312h06a4308_1
109 |   - urllib3=1.26.18=py312h06a4308_0
110 |   - wheel=0.41.2=py312h06a4308_0
111 |   - xz=5.4.5=h5eee18b_0
112 |   - yaml=0.2.5=h7b6447c_0
113 |   - zlib=1.2.13=h5eee18b_0
114 |   - zstd=1.5.5=hc292b87_0
115 |   - pip:
116 |       - absl-py==2.1.0
117 |       - alembic==1.13.1
118 |       - aniso8601==9.0.1
119 |       - antlr4-python3-runtime==4.9.3
120 |       - anyio==4.3.0
121 |       - argon2-cffi==23.1.0
122 |       - argon2-cffi-bindings==21.2.0
123 |       - arrow==1.3.0
124 |       - asttokens==2.4.1
125 |       - async-lru==2.0.4
126 |       - attrs==23.2.0
127 |       - babel==2.14.0
128 |       - beautifulsoup4==4.12.3
129 |       - bleach==6.1.0
130 |       - blinker==1.7.0
131 |       - cachetools==5.3.2
132 |       - catboost==1.2.5
133 |       - causal-conv1d==1.2.0.post2
134 |       - click==8.1.7
135 |       - cloudpickle==3.0.0
136 |       - comm==0.2.1
137 |       - contourpy==1.2.0
138 |       - cycler==0.12.1
139 |       - debugpy==1.8.1
140 |       - decorator==5.1.1
141 |       - defusedxml==0.7.1
142 |       - docker==7.0.0
143 |       - einops==0.7.0
144 |       - entrypoints==0.4
145 |       - et-xmlfile==1.1.0
146 |       - executing==2.0.1
147 |       - fastjsonschema==2.19.1
148 |       - flask==3.0.2
149 |       - fonttools==4.49.0
150 |       - fqdn==1.5.1
151 |       - ftfy==6.2.0
152 |       - gitdb==4.0.11
153 |       - gitpython==3.1.42
154 |       - google-auth==2.27.0
155 |       - google-auth-oauthlib==1.2.0
156 |       - graphene==3.3
157 |       - graphql-core==3.2.3
158 |       - graphql-relay==3.2.0
159 |       - greenlet==3.0.3
160 |       - grpcio==1.60.1
161 |       - gunicorn==21.2.0
162 |       - h11==0.14.0
163 |       - httpcore==1.0.4
164 |       - httpx==0.27.0
165 |       - huggingface-hub==0.22.1
166 |       - hydra-core==1.3.2
167 |       - imageio==2.33.1
168 |       - importlib-metadata==7.0.2
169 |       - ipykernel==6.29.3
170 |       - ipython==8.22.1
171 |       - ipywidgets==8.1.5
172 |       - isoduration==20.11.0
173 |       - itsdangerous==2.1.2
174 |       - jedi==0.19.1
175 |       - joblib==1.3.2
176 |       - json5==0.9.17
177 |       - jsonpointer==2.4
178 |       - jsonschema==4.21.1
179 |       - jsonschema-specifications==2023.12.1
180 |       - jupyter-client==8.6.0
181 |       - jupyter-core==5.7.1
182 |       - jupyter-events==0.9.0
183 |       - jupyter-lsp==2.2.3
184 |       - jupyter-server==2.12.5
185 |       - jupyter-server-terminals==0.5.2
186 |       - jupyterlab==4.1.2
187 |       - jupyterlab-pygments==0.3.0
188 |       - jupyterlab-server==2.25.3
189 |       - jupyterlab-widgets==3.0.13
190 |       - keopscore==2.2.1
191 |       - kiwisolver==1.4.5
192 |       - lazy-loader==0.3
193 |       - llvmlite==0.43.0
194 |       - mako==1.3.2
195 |       - mamba-ssm==1.2.0.post1
196 |       - markdown==3.5.2
197 |       - matplotlib==3.8.3
198 |       - matplotlib-inline==0.1.6
199 |       - mistune==3.0.2
200 |       - mlflow==2.11.1
201 |       - nbclient==0.9.0
202 |       - nbconvert==7.16.1
203 |       - nbformat==5.9.2
204 |       - nest-asyncio==1.6.0
205 |       - ninja==1.11.1.1
206 |       - notebook==7.1.1
207 |       - notebook-shim==0.2.4
208 |       - numba==0.60.0
209 |       - oauthlib==3.2.2
210 |       - omegaconf==2.3.0
211 |       - openpyxl==3.1.5
212 |       - opt-einsum==3.3.0
213 |       - overrides==7.7.0
214 |       - pandas==2.2.0
215 |       - pandocfilters==1.5.1
216 |       - parso==0.8.3
217 |       - pexpect==4.9.0
218 |       - platformdirs==4.2.0
219 |       - plotly==5.24.0
220 |       - prometheus-client==0.20.0
221 |       - prompt-toolkit==3.0.43
222 |       - protobuf==4.23.4
223 |       - psutil==5.9.8
224 |       - ptyprocess==0.7.0
225 |       - pure-eval==0.2.2
226 |       - pyarrow==15.0.1
227 |       - pyasn1==0.5.1
228 |       - pyasn1-modules==0.3.0
229 |       - pybind11==2.11.1
230 |       - pygments==2.17.2
231 |       - pykeops==2.2.1
232 |       - pyparsing==3.1.2
233 |       - python-dateutil==2.8.2
234 |       - python-graphviz==0.20.3
235 |       - python-json-logger==2.0.7
236 |       - pytz==2024.1
237 |       - pyzmq==25.1.2
238 |       - querystring-parser==1.2.4
239 |       - referencing==0.33.0
240 |       - regex==2023.12.25
241 |       - requests-oauthlib==1.3.1
242 |       - resampy==0.4.3
243 |       - rfc3339-validator==0.1.4
244 |       - rfc3986-validator==0.1.1
245 |       - rpds-py==0.18.0
246 |       - rsa==4.9
247 |       - safetensors==0.4.2
248 |       - scikit-image==0.22.0
249 |       - scikit-learn==1.4.0
250 |       - scipy==1.12.0
251 |       - send2trash==1.8.2
252 |       - shap==0.46.0
253 |       - six==1.16.0
254 |       - slicer==0.0.8
255 |       - smmap==5.0.1
256 |       - sniffio==1.3.1
257 |       - soupsieve==2.5
258 |       - sqlalchemy==2.0.28
259 |       - sqlparse==0.4.4
260 |       - stack-data==0.6.3
261 |       - structured-kernels==0.1.0
262 |       - tenacity==9.0.0
263 |       - tensorboard==2.15.1
264 |       - tensorboard-data-server==0.7.2
265 |       - terminado==0.18.0
266 |       - threadpoolctl==3.2.0
267 |       - tifffile==2024.1.30
268 |       - timm==1.0.7
269 |       - tinycss2==1.2.1
270 |       - tokenizers==0.15.2
271 |       - tornado==6.4
272 |       - traitlets==5.14.1
273 |       - transformers==4.39.1
274 |       - triton==2.2.0
275 |       - types-python-dateutil==2.8.19.20240106
276 |       - tzdata==2023.4
277 |       - uri-template==1.3.0
278 |       - wcwidth==0.2.13
279 |       - webcolors==1.13
280 |       - webencodings==0.5.1
281 |       - websocket-client==1.7.0
282 |       - werkzeug==3.0.1
283 |       - widgetsnbextension==4.0.13
284 |       - zipp==3.17.0
285 | prefix: /user/jael1674/anaconda3/envs/lightning2
286 | 
--------------------------------------------------------------------------------
/src/extensions/cauchy/a.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/src/extensions/cauchy/benchmark_cauchy.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | from functools import partial
 3 | 
 4 | import torch
 5 | 
 6 | from einops import rearrange
 7 | 
 8 | from .cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult
 9 | from benchmark.utils import benchmark_all, benchmark_combined, benchmark_forward, benchmark_backward
10 | 
11 | 
12 | def generate_data(batch_size, N, L, symmetric=True, device='cuda'):
13 |     if not symmetric:
14 |         v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
15 |         w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
16 |         z = torch.randn(L, dtype=torch.complex64, device=device)
17 |     else:
18 |         assert N % 2 == 0
19 |         v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
20 |         v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
21 |         w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
22 |         w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
23 |         z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
24 |     return v, z, w
25 | 
26 | 
27 | if __name__ == '__main__':
28 |     device = 'cuda'
29 |     bs = 1024
30 |     N = 64
31 |     L = 16384
32 | 
33 |     v, z, w = generate_data(bs, N, L, symmetric=True)
34 |     v_half = v[:, :N // 2].clone().detach().requires_grad_(True)
35 |     w_half = w[:, :N // 2].clone().detach().requires_grad_(True)
36 | 
37 |     repeat = 30
38 |     benchmark_all(repeat, cauchy_mult_keops, v, z, w, desc='Cauchy mult keops')
39 |     fn = partial(cauchy_mult, symmetric=False)
40 |     benchmark_all(repeat, fn, v, z, w, desc='Cauchy mult')
41 |     fn = partial(cauchy_mult, symmetric=True)
42 |     benchmark_all(repeat, fn, v_half, z, w_half, desc='Cauchy mult symmetric')
43 | 
--------------------------------------------------------------------------------
/src/extensions/cauchy/benchmark_cauchy_tune.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | import json
 3 | import argparse
 4 | 
 5 | import torch
 6 | 
 7 | from benchmark.utils import benchmark_forward
 8 | 
 9 | 
10 | def generate_data(batch_size, N, L, symmetric=True, device='cuda'):
11 |     if not symmetric:
12 |         v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
13 |         w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
14 |         z = torch.randn(L, dtype=torch.complex64, device=device)
15 |     else:
16 |         assert N % 2 == 0
17 |         v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
18 |         v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
19 |         w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
20 |         w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
21 |         z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
22 |     return v, z, w
23 | 
24 | 
25 | parser = argparse.ArgumentParser(description='Tuning Cauchy multiply')
26 | parser.add_argument('--name', default='cauchy_mult')
27 | parser.add_argument('--mode', default='forward', choices=['forward', 'backward'])
28 | parser.add_argument('-bs', '--batch-size', default=1024, type=int)
29 | parser.add_argument('-N', default=64, type=int)
30 | parser.add_argument('-L', default=2 ** 14, type=int)
31 | 
32 | 
33 | if __name__ == '__main__':
34 |     args = parser.parse_args()
35 |     device = 'cuda'
36 |     bs = args.batch_size
37 |     N = args.N
38 |     L = args.L
39 |     repeat = 30
40 |     v, z, w = generate_data(bs, N, L, symmetric=True)
41 |     v_half = v[:, :N // 2].clone().detach().requires_grad_(True)
42 |     w_half = w[:, :N // 2].clone().detach().requires_grad_(True)
43 | 
44 |     tuning_extension_name = args.name
45 |     # print('Extension name:', tuning_extension_name)
46 |     module = importlib.import_module(tuning_extension_name)
47 |     if args.mode == 'forward':
48 |         _, m = benchmark_forward(repeat, module.cauchy_mult_sym_fwd, v_half, z, w_half,
49 |                                  verbose=False, desc='Cauchy mult symmetric fwd')
50 |     else:
51 |         out = module.cauchy_mult_sym_fwd(v_half, z, w_half)
52 |         dout = torch.randn_like(out)
53 |         _, m = benchmark_forward(repeat, module.cauchy_mult_sym_bwd, v_half, z, w_half, dout,
54 |                                  verbose=False, desc='Cauchy mult symmetric bwd')
55 |     result_dict = dict(time_mean = m.mean, time_iqr = m.iqr)
56 |     print(json.dumps(result_dict))
57 | 
--------------------------------------------------------------------------------
/src/extensions/cauchy/cauchy.cpp:
--------------------------------------------------------------------------------
 1 | #include <vector>
 2 | #include <utility>
 3 | #include <cmath>
 4 | #include <torch/extension.h>
 5 | 
 6 | #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
 7 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 8 | 
 9 | torch::Tensor cauchy_mult_fwd_cuda(torch::Tensor v,
10 |                                    torch::Tensor z,
11 |                                    torch::Tensor w);
12 | std::tuple<torch::Tensor, torch::Tensor> cauchy_mult_bwd_cuda(torch::Tensor v,
13 |                                                               torch::Tensor z,
14 |                                                               torch::Tensor w,
15 |                                                               torch::Tensor dout);
16 | torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v,
17 |                                        torch::Tensor z,
18 |                                        torch::Tensor w);
19 | std::tuple<torch::Tensor, torch::Tensor> cauchy_mult_sym_bwd_cuda(torch::Tensor v,
20 |                                                                   torch::Tensor z,
21 |                                                                   torch::Tensor w,
22 |                                                                   torch::Tensor dout);
23 | 
24 | namespace cauchy {
25 | 
26 | torch::Tensor cauchy_mult_fwd(torch::Tensor v,
27 |                               torch::Tensor z,
28 |                               torch::Tensor w) {
29 |   CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w);
30 |   const auto batch_size = v.size(0);
31 |   const auto N = v.size(1);
32 |   const auto L = z.size(0);
33 |   CHECK_SHAPE(v, batch_size, N);
34 |   CHECK_SHAPE(z, L);
35 |   CHECK_SHAPE(w, batch_size, N);
36 |   return cauchy_mult_fwd_cuda(v, z, w);
37 | }
38 | 
39 | std::tuple<torch::Tensor, torch::Tensor>
40 | cauchy_mult_bwd(torch::Tensor v,
41 |                 torch::Tensor z,
42 |                 torch::Tensor w,
43 |                 torch::Tensor dout) {
44 |   CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); CHECK_DEVICE(dout);
45 |   const auto batch_size = v.size(0);
46 |   const auto N = v.size(1);
47 |   const auto L = z.size(0);
48 |   CHECK_SHAPE(v, batch_size, N);
49 |   CHECK_SHAPE(z, L);
50 |   CHECK_SHAPE(w, batch_size, N);
51 |   CHECK_SHAPE(dout, batch_size, L);
52 |   return cauchy_mult_bwd_cuda(v, z, w, dout);
53 | }
54 | 
55 | torch::Tensor cauchy_mult_sym_fwd(torch::Tensor v,
56 |                                   torch::Tensor z,
57 |                                   torch::Tensor w) {
58 |   CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w);
59 |   const auto batch_size = v.size(0);
60 |   const auto N = v.size(1);
61 |   const auto L = z.size(0);
62 |   CHECK_SHAPE(v, batch_size, N);
63 |   CHECK_SHAPE(z, L);
64 |   CHECK_SHAPE(w, batch_size, N);
65 |   return cauchy_mult_sym_fwd_cuda(v, z, w);
66 | }
67 | 
68 | std::tuple<torch::Tensor, torch::Tensor>
69 | cauchy_mult_sym_bwd(torch::Tensor v,
70 |                     torch::Tensor z,
71 |                     torch::Tensor w,
72 |                     torch::Tensor dout) {
73 |   CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); CHECK_DEVICE(dout);
74 |   const auto batch_size = v.size(0);
75 |   const auto N = v.size(1);
76 |   const auto L = z.size(0);
77 |   CHECK_SHAPE(v, batch_size, N);
78 |   CHECK_SHAPE(z, L);
79 |   CHECK_SHAPE(w, batch_size, N);
80 |   CHECK_SHAPE(dout, batch_size, L);
81 |   return cauchy_mult_sym_bwd_cuda(v, z, w, dout);
82 | }
83 | 
84 | }  // cauchy
85 | 
86 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
87 |   m.def("cauchy_mult_fwd", &cauchy::cauchy_mult_fwd,
88 |         "Cauchy multiply forward");
89 |   m.def("cauchy_mult_bwd", &cauchy::cauchy_mult_bwd,
90 |         "Cauchy multiply backward");
91 |   m.def("cauchy_mult_sym_fwd", &cauchy::cauchy_mult_sym_fwd,
92 |         "Cauchy multiply symmetric forward");
93 |   m.def("cauchy_mult_sym_bwd", &cauchy::cauchy_mult_sym_bwd,
94 |         "Cauchy multiply symmetric backward");
95 | }
96 | 
--------------------------------------------------------------------------------
/src/extensions/cauchy/cauchy.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | 
  3 | from einops import rearrange
  4 | 
  5 | from cauchy_mult import cauchy_mult_fwd, cauchy_mult_bwd, cauchy_mult_sym_fwd, cauchy_mult_sym_bwd
  6 | 
  7 | 
  8 | def cauchy_mult_torch(v: torch.Tensor, z: torch.Tensor, w: torch.Tensor,
  9 |                       symmetric=True) -> torch.Tensor:
 10 |     """
 11 |     v: (B, N)
 12 |     z: (L)
 13 |     w: (B, N)
 14 |     symmetric: whether to assume that v and w contain complex conjugate pairs, of the form
 15 |     [v_half, v_half.conj()] and [w_half, w_half.conj()]
 16 |     """
 17 |     if not symmetric:
 18 |         return (rearrange(v, 'b n -> b 1 n') / (rearrange(z, 'l -> l 1') - rearrange(w, 'b n -> b 1 n'))).sum(dim=-1)
 19 |     else:
 20 |         N = v.shape[-1]
 21 |         assert N % 2 == 0
 22 |         vv = rearrange(v[:, :N // 2], 'b n -> b 1 n')
 23 |         zz = rearrange(z, 'l -> l 1')
 24 |         ww = rearrange(w[:, :N // 2], 'b n -> b 1 n')
 25 |         return 2 * ((zz * vv.real - vv.real * ww.real - vv.imag * ww.imag)
 26 |                     / (zz * zz - 2 * zz * ww.real + ww.abs().square())).sum(dim=-1)
 27 | 
 28 | 
 29 | def cauchy_mult_keops(v, z, w):
 30 |     from pykeops.torch import LazyTensor
 31 |     v_l = LazyTensor(rearrange(v, 'b N -> b 1 N 1'))
 32 |     z_l = LazyTensor(rearrange(z, 'L -> 1 L 1 1'))
 33 |     w_l = LazyTensor(rearrange(w, 'b N -> b 1 N 1'))
 34 |     sub = z_l - w_l  # (b L N 1) after broadcasting; KeOps may not display the last dimension
 35 |     div = v_l / sub
 36 |     s = div.sum(dim=2, backend='GPU')
 37 |     return s.squeeze(-1)
 38 | 
 39 | 
 40 | def _cauchy_mult(v, z, w, symmetric=True):
 41 |     if not symmetric:
 42 |         return CauchyMultiply.apply(v, z, w)
 43 |     else:
 44 |         return CauchyMultiplySymmetric.apply(v, z, w)
 45 | 
 46 | def cauchy_mult(v, z, w, symmetric=True):
 47 |     """ Wrap the cuda method to deal with shapes """
 48 |     v, w = torch.broadcast_tensors(v, w)
 49 |     shape = v.shape
 50 |     # z_shape = z.shape
 51 |     z = z.squeeze()
 52 |     assert len(z.shape) == 1
 53 | 
 54 |     v = v.contiguous()
 55 |     w = w.contiguous()
 56 |     z = z.contiguous()
 57 | 
 58 |     N = v.size(-1)
 59 |     assert w.size(-1) == N
 60 |     y = _cauchy_mult(v.view(-1, N), z, w.view(-1, N), symmetric=symmetric)
 61 |     y = y.view(*shape[:-1], z.size(-1))
 62 |     # y = z.new_zeros(*shape[:-1], z.size(-1))
 63 |     return y
 64 | 
 65 | 
 66 | class CauchyMultiply(torch.autograd.Function):
 67 | 
 68 |     @staticmethod
 69 |     def forward(ctx, v, z, w):
 70 |         batch, N = v.shape
 71 |         # supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
 72 |         supported_N_values = [1 << log_n for log_n in [6]]
 73 |         L = z.shape[-1]
 74 |         if not N in supported_N_values:
 75 |             raise NotImplementedError(f'Only support N values in {supported_N_values}')
 76 |         if L % 32 != 0:
 77 |             raise NotImplementedError(f'Only support L values that are multiples of 32')
 78 |         if not (v.is_cuda and z.is_cuda and w.is_cuda):
 79 |             raise NotImplementedError(f'Only support CUDA tensors')
 80 |         ctx.save_for_backward(v, z, w)
 81 |         return cauchy_mult_fwd(v, z, w)
 82 | 
 83 |     @staticmethod
 84 |     def backward(ctx, dout):
 85 |         v, z, w = ctx.saved_tensors
 86 |         dv, dw = cauchy_mult_bwd(v, z, w, dout)
 87 |         return dv, None, dw
 88 | 
 89 | 
 90 | class CauchyMultiplySymmetric(torch.autograd.Function):
 91 | 
 92 |     @staticmethod
 93 |     def forward(ctx, v, z, w):
 94 |         batch, N = v.shape
 95 |         supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
 96 |         L = z.shape[-1]
 97 |         if N not in supported_N_values:
 98 |             raise NotImplementedError(f'Only support N values in {supported_N_values}')
 99 |         max_L_value = 32 * 1024 * 64 * 1024
100 |         if L > max_L_value:
101 |             raise NotImplementedError(f'Only support L values <= {max_L_value}')
102 |         if not (v.is_cuda and z.is_cuda and w.is_cuda):
103 |             raise NotImplementedError('Only support CUDA tensors')
104 |         ctx.save_for_backward(v, z, w)
105 |         return cauchy_mult_sym_fwd(v, z, w)
106 | 
107 |     @staticmethod
108 |     def backward(ctx, dout):
109 |         v, z, w = ctx.saved_tensors
110 |         dv, dw = cauchy_mult_sym_bwd(v, z, w, dout)
111 |         return dv, None, dw
112 | 
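For reference, the quantity these kernels compute (read off the KeOps path above) is y[b, l] = Σ_n v[b, n] / (z[l] - w[b, n]). In the symmetric variant, `v` and `w` hold only the first half of conjugate-symmetric vectors (as built by `generate_data` in test_cauchy.py), so the missing second half contributes the conjugate terms. The pure-Python sketch below (no PyTorch or CUDA; the function names are hypothetical) checks on a tiny input that the half-size formula matches the full sum. It is for reasoning about small cases, not a substitute for the kernels.

```python
import cmath


def cauchy_mult_ref(v, z, w):
    """y[b][l] = sum_n v[b][n] / (z[l] - w[b][n]) for nested lists of complex numbers."""
    return [[sum(vn / (zl - wn) for vn, wn in zip(vb, wb)) for zl in z]
            for vb, wb in zip(v, w)]


def cauchy_mult_sym_ref(v_half, z, w_half):
    """Same sum over [v_half, conj(v_half)] and [w_half, conj(w_half)], using only the first half."""
    return [[sum(vn / (zl - wn) + vn.conjugate() / (zl - wn.conjugate())
                 for vn, wn in zip(vb, wb))
             for zl in z]
            for vb, wb in zip(v_half, w_half)]


# Tiny example: batch of 1, full N = 4 (half size 2), L = 3 points on the unit circle
v_half = [[1 + 2j, -0.5 + 1j]]
w_half = [[-0.1 + 1j, -0.3 - 0.5j]]
z = [cmath.exp(1j * t) for t in (0.3, 1.1, 2.5)]

# Rebuild the full conjugate-symmetric vectors, like torch.cat([x, x.conj()], dim=-1)
v_full = [vb + [x.conjugate() for x in vb] for vb in v_half]
w_full = [wb + [x.conjugate() for x in wb] for wb in w_half]

full = cauchy_mult_ref(v_full, z, w_full)
half = cauchy_mult_sym_ref(v_half, z, w_half)
print(max(abs(a - b) for a, b in zip(full[0], half[0])))  # tiny, floating-point rounding only
```

Note that with `z` on the unit circle the two halves are not conjugates of each other, so the symmetric output is generally not twice a real part; only the sum itself is preserved.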
--------------------------------------------------------------------------------
/src/extensions/cauchy/map.h:
--------------------------------------------------------------------------------
 1 | // Downloaded from https://github.com/swansontec/map-macro
 2 | 
 3 | /*
 4 |  * Copyright (C) 2012 William Swanson
 5 |  *
 6 |  * Permission is hereby granted, free of charge, to any person
 7 |  * obtaining a copy of this software and associated documentation
 8 |  * files (the "Software"), to deal in the Software without
 9 |  * restriction, including without limitation the rights to use, copy,
10 |  * modify, merge, publish, distribute, sublicense, and/or sell copies
11 |  * of the Software, and to permit persons to whom the Software is
12 |  * furnished to do so, subject to the following conditions:
13 |  *
14 |  * The above copyright notice and this permission notice shall be
15 |  * included in all copies or substantial portions of the Software.
16 |  *
17 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 |  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19 |  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 |  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
21 |  * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
22 |  * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23 |  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24 |  *
25 |  * Except as contained in this notice, the names of the authors or
26 |  * their institutions shall not be used in advertising or otherwise to
27 |  * promote the sale, use or other dealings in this Software without
28 |  * prior written authorization from the authors.
29 |  */
30 | 
31 | #ifndef MAP_H_INCLUDED
32 | #define MAP_H_INCLUDED
33 | 
34 | #define EVAL0(...) __VA_ARGS__
35 | #define EVAL1(...) EVAL0(EVAL0(EVAL0(__VA_ARGS__)))
36 | #define EVAL2(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__)))
37 | #define EVAL3(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__)))
38 | #define EVAL4(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__)))
39 | #define EVAL(...)  EVAL4(EVAL4(EVAL4(__VA_ARGS__)))
40 | 
41 | #define MAP_END(...)
42 | #define MAP_OUT
43 | #define MAP_COMMA ,
44 | 
45 | #define MAP_GET_END2() 0, MAP_END
46 | #define MAP_GET_END1(...) MAP_GET_END2
47 | #define MAP_GET_END(...) MAP_GET_END1
48 | #define MAP_NEXT0(test, next, ...) next MAP_OUT
49 | #define MAP_NEXT1(test, next) MAP_NEXT0(test, next, 0)
50 | #define MAP_NEXT(test, next)  MAP_NEXT1(MAP_GET_END test, next)
51 | 
52 | #define MAP0(f, x, peek, ...) f(x) MAP_NEXT(peek, MAP1)(f, peek, __VA_ARGS__)
53 | #define MAP1(f, x, peek, ...) f(x) MAP_NEXT(peek, MAP0)(f, peek, __VA_ARGS__)
54 | 
55 | #define MAP_LIST_NEXT1(test, next) MAP_NEXT0(test, MAP_COMMA next, 0)
56 | #define MAP_LIST_NEXT(test, next)  MAP_LIST_NEXT1(MAP_GET_END test, next)
57 | 
58 | #define MAP_LIST0(f, x, peek, ...) f(x) MAP_LIST_NEXT(peek, MAP_LIST1)(f, peek, __VA_ARGS__)
59 | #define MAP_LIST1(f, x, peek, ...) f(x) MAP_LIST_NEXT(peek, MAP_LIST0)(f, peek, __VA_ARGS__)
60 | 
61 | /**
62 |  * Applies the function macro `f` to each of the remaining parameters.
63 |  */
64 | #define MAP(f, ...) EVAL(MAP1(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0))
65 | 
66 | /**
67 |  * Applies the function macro `f` to each of the remaining parameters and
68 |  * inserts commas between the results.
69 |  */
70 | #define MAP_LIST(f, ...) EVAL(MAP_LIST1(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0))
71 | 
72 | #endif
73 | 
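The mechanism above (repeated `EVAL` rescans, plus the `()()()` sentinels appended by `MAP`/`MAP_LIST` so that `MAP_GET_END` can detect the end of the argument list) is intricate, but the end result is simple: `MAP(f, a, b, c)` expands to `f(a) f(b) f(c)` and `MAP_LIST(f, a, b, c)` to `f(a), f(b), f(c)`, modulo whitespace. A Python sketch of that textual result only; it does not emulate the preprocessor, and the function names are illustrative.

```python
def map_expansion(f, *args):
    """Textual result of MAP(f, args...): f applied to each argument, space-separated."""
    return " ".join(f"{f}({a})" for a in args)


def map_list_expansion(f, *args):
    """Textual result of MAP_LIST(f, args...): same as MAP, but comma-separated."""
    return ", ".join(f"{f}({a})" for a in args)


print(map_expansion("DECLARE", "x", "y", "z"))  # DECLARE(x) DECLARE(y) DECLARE(z)
print(map_list_expansion("CASE", 4, 8, 16))     # CASE(4), CASE(8), CASE(16)
```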
--------------------------------------------------------------------------------
/src/extensions/cauchy/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup
 2 | import torch.cuda
 3 | from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
 4 | from torch.utils.cpp_extension import CUDA_HOME
 5 | 
 6 | ext_modules = []
 7 | if torch.cuda.is_available() and CUDA_HOME is not None:
 8 |     extension = CUDAExtension(
 9 |         'cauchy_mult', [
10 |             'cauchy.cpp',
11 |             'cauchy_cuda.cu',
12 |         ],
13 |         extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'],
14 |                             # 'nvcc': ['-O2', '-lineinfo']
15 |                             'nvcc': ['-O2', '-lineinfo', '--use_fast_math']
16 |                             }
17 |     )
18 |     ext_modules.append(extension)
19 | 
20 | setup(
21 |     name='cauchy_mult',
22 |     ext_modules=ext_modules,
23 |     # cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)})
24 |     cmdclass={'build_ext': BuildExtension})
25 | 
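Because `setup.py` leaves `ext_modules` empty when `torch.cuda.is_available()` is false or `CUDA_HOME` is unset, installation can succeed without producing a `cauchy_mult` module. A minimal sketch of checking for the compiled extension before choosing a code path; the helper and flag names are hypothetical, and `importlib.util.find_spec` only checks that the module can be located, it does not import it.

```python
import importlib.util


def extension_available(name: str = "cauchy_mult") -> bool:
    """Return True if a top-level module called `name` can be found on sys.path."""
    return importlib.util.find_spec(name) is not None


# Hypothetical flag a caller could use to fall back to a PyTorch or KeOps implementation
HAS_CAUCHY_EXT = extension_available()
print("cauchy_mult extension found:", HAS_CAUCHY_EXT)
```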
--------------------------------------------------------------------------------
/src/extensions/cauchy/test_cauchy.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import torch
  3 | 
  4 | import pytest
  5 | 
  6 | from einops import rearrange
  7 | 
  8 | from cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult
  9 | 
 10 | 
 11 | def generate_data(batch_size, N, L, symmetric=True, device='cuda'):
 12 |     if not symmetric:
 13 |         v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
 14 |         w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
 15 |         z = torch.randn(L, dtype=torch.complex64, device=device)
 16 |     else:
 17 |         assert N % 2 == 0
 18 |         v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
 19 |         v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
 20 |         w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
 21 |         w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
 22 |         z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
 23 |     return v, z, w
 24 | 
 25 | 
 26 | def grad_to_half_grad(dx):
 27 |     dx_half, dx_half_conj = dx.chunk(2, dim=-1)
 28 |     return dx_half + dx_half_conj.conj()
 29 | 
 30 | 
 31 | # @pytest.mark.parametrize('L', [1024])
 32 | # @pytest.mark.parametrize('N', [64])
 33 | # def test_cauchy_mult_nonsymmetric(N, L):
 34 | #     device = 'cuda'
 35 | #     batch_size = 4
 36 | #     torch.random.manual_seed(2357)
 37 | #     v, z, w = generate_data(batch_size, N, L, symmetric=False, device=device)
 38 | #     out_torch = cauchy_mult_torch(v, z, w, symmetric=False)
 39 | #     out_keops = cauchy_mult_keops(v, z, w)
 40 | #     out = cauchy_mult(v, z, w, symmetric=False)
 41 | #     assert torch.allclose(out, out_torch, rtol=1e-4, atol=1e-4)
 42 | #     assert torch.allclose(out, out_keops, rtol=1e-4, atol=1e-4)
 43 | #     dout = torch.randn_like(out)
 44 | #     dv_torch, dw_torch = torch.autograd.grad(out_torch, (v, w), dout, retain_graph=True)
 45 | #     dv_keops, dw_keops = torch.autograd.grad(out_keops, (v, w), dout, retain_graph=True)
 46 | #     dv, dw = torch.autograd.grad(out, (v, w), dout, retain_graph=True)
 47 | #     assert torch.allclose(dv, dv_torch, rtol=1e-4, atol=1e-4)
 48 | #     assert torch.allclose(dv, dv_keops, rtol=1e-4, atol=1e-4)
 49 | #     assert torch.allclose(dw, dw_torch, rtol=1e-4, atol=1e-4)
 50 | #     assert torch.allclose(dw, dw_keops, rtol=1e-4, atol=1e-4)
 51 | 
 52 | 
 53 | @pytest.mark.parametrize('L', [3, 17, 489, 2**10, 1047, 2**11, 2**12, 2**13, 2**14, 2**18])
 54 | @pytest.mark.parametrize('N', [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
 55 | def test_cauchy_mult_symmetric(N, L):
 56 |     # rtol, atol = (1e-4, 1e-4) if N <= 64 and L <= 1024 else(1e-3, 1e-3)
 57 |     atol = 1e-4
 58 |     tol_factor = 10.0  # Our error shouldn't be this much higher than Keops' error
 59 |     device = 'cuda'
 60 |     batch_size = 4
 61 |     torch.random.manual_seed(2357)
 62 |     v, z, w = generate_data(batch_size, N, L, symmetric=True, device=device)
 63 |     v_half = v[:, :N // 2].clone().detach().requires_grad_(True)
 64 |     w_half = w[:, :N // 2].clone().detach().requires_grad_(True)
 65 |     # out_torch = cauchy_mult_torch(v, z, w, symmetric=True)
 66 |     out_torch = cauchy_mult_torch(v.cdouble(), z.cdouble(), w.cdouble(), symmetric=True).cfloat()
 67 |     out_keops = cauchy_mult_keops(v, z, w)
 68 |     out = cauchy_mult(v_half, z, w_half, symmetric=True)
 69 |     relerr_out_keops = (out_keops - out_torch).abs() / out_torch.abs()
 70 |     relerr_out = (out - out_torch).abs() / out_torch.abs()
 71 | 
 72 |     dout = torch.randn_like(out)
 73 |     dv_torch, dw_torch = torch.autograd.grad(out_torch, (v, w), dout, retain_graph=True)
 74 |     dv_torch, dw_torch = dv_torch[:, :N // 2], dw_torch[:, :N // 2]
 75 |     dv_keops, dw_keops = torch.autograd.grad(out_keops, (v, w), dout, retain_graph=True)
 76 |     dv_keops, dw_keops = grad_to_half_grad(dv_keops), grad_to_half_grad(dw_keops)
 77 |     dv, dw = torch.autograd.grad(out, (v_half, w_half), dout, retain_graph=True)
 78 |     relerr_dv_keops = (dv_keops - dv_torch).abs() / dv_torch.abs()
 79 |     relerr_dv = (dv - dv_torch).abs() / dv_torch.abs()
 80 |     relerr_dw_keops = (dw_keops - dw_torch).abs() / dw_torch.abs()
 81 |     relerr_dw = (dw - dw_torch).abs() / dw_torch.abs()
 82 |     print(f'Keops out relative error: max {relerr_out_keops.amax().item():.6f}, mean {relerr_out_keops.mean().item():.6f}')
 83 |     print(f'out relative error: max {relerr_out.amax().item():.6f}, mean {relerr_out.mean().item():.6f}')
 84 |     print(f'Keops dv relative error: max {relerr_dv_keops.amax().item():.6f}, mean {relerr_dv_keops.mean().item():.6f}')
 85 |     print(f'dv relative error: max {relerr_dv.amax().item():.6f}, mean {relerr_dv.mean().item():.6f}')
 86 |     print(f'Keops dw relative error: max {relerr_dw_keops.amax().item():.6f}, mean {relerr_dw_keops.mean().item():.6f}')
 87 |     print(f'dw relative error: max {relerr_dw.amax().item():.6f}, mean {relerr_dw.mean().item():.6f}')
 88 |     assert (relerr_out.amax() <= relerr_out_keops.amax() * tol_factor + atol)
 89 |     assert (relerr_out.mean() <= relerr_out_keops.mean() * tol_factor + atol)
 90 |     # assert torch.allclose(out, out_torch, rtol=rtol, atol=atol)
 91 |     # assert torch.allclose(out, out_keops, rtol=rtol, atol=atol)
 92 |     assert (relerr_dv.amax() <= relerr_dv_keops.amax() * tol_factor + atol)
 93 |     assert (relerr_dv.mean() <= relerr_dv_keops.mean() * tol_factor + atol)
 94 |     assert (relerr_dw.amax() <= relerr_dw_keops.amax() * tol_factor + atol)
 95 |     assert (relerr_dw.mean() <= relerr_dw_keops.mean() * tol_factor + atol)
 96 |     # assert torch.allclose(dv, dv_torch, rtol=1e-4, atol=1e-4)
 97 |     # assert torch.allclose(dv, dv_keops, rtol=1e-4, atol=1e-4)
 98 |     # assert torch.allclose(dw, dw_torch, rtol=1e-4, atol=1e-4)
 99 |     # assert torch.allclose(dw, dw_keops, rtol=1e-4, atol=1e-4)
100 | 
101 | 
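The symmetric test does not use a fixed tolerance. Both the CUDA and KeOps results are compared against a float64 PyTorch reference, and the CUDA result passes if its max and mean relative errors stay within `tol_factor` times the KeOps errors plus `atol`. Gradients are compared on the half vectors: since the second half of `v` and `w` is the conjugate of the first, `grad_to_half_grad` adds the conjugated gradient of the second half onto the first (PyTorch propagates the gradient of `x.conj()` back as its conjugate). A pure-Python restatement of both rules, with hypothetical helper names and flat lists instead of tensors:

```python
def passes_relative_check(relerr, relerr_baseline, tol_factor=10.0, atol=1e-4):
    """Mirror of the test's assertions: max and mean relative error must stay within
    tol_factor times the baseline's (KeOps) max and mean, plus atol."""
    mean = sum(relerr) / len(relerr)
    mean_base = sum(relerr_baseline) / len(relerr_baseline)
    return (max(relerr) <= max(relerr_baseline) * tol_factor + atol
            and mean <= mean_base * tol_factor + atol)


def grad_to_half_grad(dx):
    """Fold the gradient w.r.t. [x_half, conj(x_half)] back onto x_half (1-D list version)."""
    n = len(dx) // 2
    return [a + b.conjugate() for a, b in zip(dx[:n], dx[n:])]


print(passes_relative_check([1e-5, 3e-5], [1e-6, 2e-6]))  # True
print(passes_relative_check([1e-3, 3e-3], [1e-6, 2e-6]))  # False
print(grad_to_half_grad([1 + 1j, 2, 3 - 1j, 4j]))         # [(4+2j), (2-4j)]
```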
--------------------------------------------------------------------------------
/src/extensions/cauchy/tune_cauchy.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import json
 3 | import argparse
 4 | import itertools
 5 | from pathlib import Path
 6 | 
 7 | from tuner import KernelTuner
 8 | 
 9 | 
10 | def forward_params_list(N):
11 |     blocksize_params = ('MAX_BLOCK_SIZE_VALUE', [64, 128, 256, 512, 1024])
12 |     thread_value_default = [2, 4, 8, 16, 32, 32, 32, 32, 32, 32]
13 |     thread_values_supported = [2, 4, 8, 16, 32, 64, 128]
14 |     log_N_half = int(math.log2(N)) - 1
15 |     thread_values = []
16 |     for val in thread_values_supported:
17 |         if val <= N // 2:
18 |             array = list(thread_value_default)
19 |             array[log_N_half - 1] = val
20 |             thread_values.append('{' + ', '.join(str(v) for v in array) + '}')
21 |     thread_params = ('ITEMS_PER_THREAD_SYM_FWD_VALUES', thread_values)
22 |     value_prod = itertools.product(thread_params[1], blocksize_params[1])
23 |     params_list = [{thread_params[0]: value[0], blocksize_params[0]: value[1]}
24 |                    for value in value_prod]
25 |     return params_list
26 | 
27 | 
28 | def backward_params_list(L):
29 |     thread_value_supported = [8, 16, 32, 64, 128]
30 |     thread_params = ('ITEMS_PER_THREAD_SYM_BWD_VALUE', [v for v in thread_value_supported
31 |                                                         if (L + v - 1) // v <= 1024])
32 |     params_list = [{thread_params[0]: value} for value in thread_params[1]]
33 |     return params_list
34 | 
35 | 
36 | parser = argparse.ArgumentParser(description='Tuning Cauchy multiply')
37 | parser.add_argument('--mode', default='forward', choices=['forward', 'backward'])
38 | parser.add_argument('-N', default=64, type=int)
39 | parser.add_argument('-L', default=2 ** 14, type=int)
40 | parser.add_argument('--filename', default='tuning_result.json')
41 | 
42 | 
43 | if __name__ == '__main__':
44 |     args = parser.parse_args()
45 | 
46 |     extension_dir = Path(__file__).absolute().parent
47 |     source_files = ['cauchy_cuda.cu']
48 |     if args.mode == 'forward':
49 |         params_list = forward_params_list(args.N)
50 |         tuner = KernelTuner(extension_dir, source_files, params_list,
51 |                             benchmark_script='benchmark_cauchy_tune.py',
52 |                             benchmark_args=['--mode', 'forward', '-N', str(args.N), '-L', '16384'],
53 |                             npool=16)
54 |     else:
55 |         params_list = backward_params_list(args.L)
56 |         tuner = KernelTuner(extension_dir, source_files, params_list,
57 |                             benchmark_script='benchmark_cauchy_tune.py',
58 |                             benchmark_args=['--mode', 'backward', '-N', '64', '-L', str(args.L)],
59 |                             npool=16)
60 | 
61 |     result = tuner.tune()
62 |     with open(args.filename, 'w') as f:
63 |         json.dump(result, f)
64 | 
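The size of each search space follows directly from the two builders. In the forward pass only one slot of `ITEMS_PER_THREAD_SYM_FWD_VALUES` varies (index `log2(N) - 2`), over the supported values not exceeding `N // 2`, crossed with five block sizes. In the backward pass a value `v` is kept only if `ceil(L / v) <= 1024`, which reads as keeping one thread per `v` items within CUDA's 1024-threads-per-block limit (an interpretation; the kernel is not shown here). A small worked check restating that logic, with hypothetical helper names:

```python
import math


def forward_thread_array(N, val):
    """C array literal forward_params_list(N) emits when the varying slot is set to val."""
    array = [2, 4, 8, 16, 32, 32, 32, 32, 32, 32]  # thread_value_default
    array[int(math.log2(N)) - 2] = val            # index log_N_half - 1
    return '{' + ', '.join(str(v) for v in array) + '}'


def n_forward_configs(N):
    """Number of (thread array, block size) pairs forward_params_list(N) produces."""
    n_thread_values = sum(1 for val in [2, 4, 8, 16, 32, 64, 128] if val <= N // 2)
    return n_thread_values * len([64, 128, 256, 512, 1024])


def backward_items_per_thread(L):
    """Values kept by backward_params_list(L): ceil(L / v) must not exceed 1024."""
    return [v for v in [8, 16, 32, 64, 128] if (L + v - 1) // v <= 1024]


print(forward_thread_array(64, 8))                  # {2, 4, 8, 16, 8, 32, 32, 32, 32, 32}
print(n_forward_configs(8), n_forward_configs(64))  # 10 25
print(backward_items_per_thread(16384))             # [16, 32, 64, 128]
```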
--------------------------------------------------------------------------------
/src/extensions/cauchy/tune_cauchy.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | python tune_cauchy.py --mode forward -N 4 --filename tuning_result_fwd_N_4.json
 3 | python tune_cauchy.py --mode forward -N 8 --filename tuning_result_fwd_N_8.json
 4 | python tune_cauchy.py --mode forward -N 16 --filename tuning_result_fwd_N_16.json
 5 | python tune_cauchy.py --mode forward -N 32 --filename tuning_result_fwd_N_32.json
 6 | python tune_cauchy.py --mode forward -N 64 --filename tuning_result_fwd_N_64.json
 7 | python tune_cauchy.py --mode forward -N 128 --filename tuning_result_fwd_N_128.json
 8 | python tune_cauchy.py --mode forward -N 256 --filename tuning_result_fwd_N_256.json
 9 | python tune_cauchy.py --mode forward -N 512 --filename tuning_result_fwd_N_512.json
10 | 
11 | python tune_cauchy.py --mode backward -L 1024 --filename tuning_result_bwd_L_1k.json
12 | python tune_cauchy.py --mode backward -L 2048 --filename tuning_result_bwd_L_2k.json
13 | python tune_cauchy.py --mode backward -L 4096 --filename tuning_result_bwd_L_4k.json
14 | python tune_cauchy.py --mode backward -L 8192 --filename tuning_result_bwd_L_8k.json
15 | python tune_cauchy.py --mode backward -L 16384 --filename tuning_result_bwd_L_16k.json
16 | 
--------------------------------------------------------------------------------
/src/extensions/cauchy/tuner.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import shutil
  3 | import subprocess
  4 | import sys
  5 | # import tempfile
  6 | # import importlib
  7 | import random
  8 | import string
  9 | import json
 10 | 
 11 | 
 12 | from functools import partial
 13 | from multiprocessing import Pipe, Pool, Process
 14 | from pathlib import Path
 15 | 
 16 | from tqdm import tqdm
 17 | 
 18 | import numpy as np
 19 | 
 20 | 
 21 | def read_file(filename):
 22 |     """ return the contents of the file named filename or None if file not found """
 23 |     if os.path.isfile(filename):
 24 |         with open(filename, 'r') as f:
 25 |             return f.read()
 26 | 
 27 | 
 28 | def write_file(filename, string):
 29 |     """dump the contents of string to a file called filename"""
 30 |     with open(filename, 'w', encoding="utf-8") as f:
 31 |         f.write(string)
 32 | 
 33 | 
 34 | def prepare_kernel_string(kernel_string, params):
 35 |     for k, v in params.items():
 36 |         kernel_string = "#define " + k + " " + str(v) + "\n" + kernel_string
 37 |     return kernel_string
 38 | 
 39 | 
 40 | def compile_extension(temp_dir, install=False, verbose=True):
 41 |     # Need to copy this process's environments, otherwise it can't find the compilers
 42 |     env = {**os.environ,
 43 |            'TUNING_SOURCE_DIR': str(temp_dir),
 44 |            'TUNING_EXTENSION_NAME': str(temp_dir.stem)}
 45 |     # https://stackoverflow.com/questions/53173314/how-to-change-distutils-output-directory
 46 |     # Need separate build directories for parallel compilation
 47 |     output = subprocess.run(
 48 |         # [sys.executable, "tuning_setup.py", 'build', f'--build-base={str(temp_dir)}',
 49 |         #  f'--build-lib={str(temp_dir)}'],
 50 |         [sys.executable, "tuning_setup.py", 'build' if not install else 'develop'],
 51 |         cwd=temp_dir,
 52 |         env=env,
 53 |         capture_output=True,
 54 |         # check=True
 55 |     )
 56 |     if verbose:
 57 |         print(output)
 58 |         print('Done compiling' if not install else 'Done installing')
 59 | 
 60 | 
 61 | def uninstall_extensions(tuning_extension_names, verbose=True):
 62 |     # Need to copy this process's environments, otherwise it can't find the compilers
 63 |     env = {**os.environ}
 64 |     output = subprocess.run(
 65 |         [sys.executable, '-m', 'pip', 'uninstall', '-y', *tuning_extension_names],
 66 |         env=env,
 67 |         capture_output=True,
 68 |         # check=True
 69 |     )
 70 |     if verbose:
 71 |         print(output)
 72 |         print('Done uninstalling')
 73 | 
 74 | 
 75 | def benchmark_extension(benchmark_script, *benchmark_args, verbose=True):
 76 |     # Pass this process's environment through to the benchmark subprocess
 77 |     env = dict(os.environ)
 78 |     # Run in a fresh interpreter so the benchmark can import the extension that was just installed.
 79 |     # Its stdout must contain only the JSON result, because it is parsed with json.loads below.
 80 |     process = subprocess.run(
 81 |         [sys.executable, benchmark_script, *benchmark_args],
 82 |         env=env,
 83 |         capture_output=True,
 84 |         # check=True
 85 |     )
 86 |     if verbose:
 87 |         print(process)
 88 |         print('Done benchmarking')
 89 |     return json.loads(process.stdout.decode(sys.stdout.encoding))
 90 | 
 91 | 
 92 | # def benchmark(connection, temp_dir):
 93 | #     import torch
 94 | #     # module = importlib.import_module(tuning_extension_name)
 95 | #     torch.ops.load_library(temp_dir / 'torch_butterfly_tuning.so')
 96 | #     batch_size = 1024
 97 | #     n = 32
 98 | #     twiddle = torch.randn(1, 1, 5, n // 2, 2, 2, device='cuda')
 99 | #     input = torch.randn(batch_size, 1, n, device=twiddle.device)
100 | #     output = torch.ops.torch_butterfly.butterfly_multiply_fw(twiddle, input, True)
101 | #     # https://medium.com/@auro_227/timing-your-pytorch-code-fragments-e1a556e81f2
102 | #     res = []
103 | #     for _ in range(32):
104 | #         start = torch.cuda.Event(enable_timing=True)
105 | #         end = torch.cuda.Event(enable_timing=True)
106 | #         start.record()
107 | #         output = torch.ops.torch_butterfly.butterfly_multiply_fw(twiddle, input, True)
108 | #         end.record()
109 | #         torch.cuda.synchronize()
110 | #         res.append(start.elapsed_time(end))
111 | #     print(output.shape)
112 | #     res = np.array(res)
113 | #     connection.send((np.mean(res), np.std(res)))
114 | 
115 | 
116 | def set_up_tuning_temp_dir(params: dict, source_files, extension_dir, verbose=True):
117 |     if verbose:
118 |         print('params: ', params)
119 |     # TD [2021-10-22]: tempfile.mkdtemp sometimes creates dir names with '_' in them, thus messing up
120 |     # the extension name.
121 |     # temp_dir = Path(tempfile.mkdtemp(prefix="temp_", dir=Path.cwd().parent)).absolute()
122 |     tuning_extension_name = 'temp_' + ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
123 |     temp_dir = (Path.cwd().parent / tuning_extension_name).absolute()
124 |     if temp_dir.exists():
125 |         shutil.rmtree(temp_dir)  # shutil.copytree doesn't want directory that already exists
126 |     shutil.copytree(extension_dir, temp_dir)
127 |     sources = [temp_dir / name for name in source_files]
128 |     for kernel_source in sources:
129 |         ks = read_file(kernel_source)
130 |         ks = prepare_kernel_string(ks, params)
131 |         write_file(kernel_source, ks)
132 |     return temp_dir
133 | 
134 | 
135 | class KernelTuner:
136 | 
137 |     def __init__(self, extension_dir, source_files, params_list, benchmark_script,
138 |                  benchmark_args, npool=8, verbose=True):
139 |         self.extension_dir = extension_dir
140 |         self.source_files = source_files
141 |         self.params_list = params_list
142 |         self.benchmark_script = benchmark_script
143 |         self.benchmark_args = benchmark_args
144 |         self.npool = npool
145 |         self.verbose = verbose
146 | 
147 |     def tune(self):
148 |         temp_dirs = [set_up_tuning_temp_dir(params, self.source_files, self.extension_dir,
149 |                                             verbose=self.verbose)
150 |                      for params in self.params_list]
151 |         # Compile in parallel (for speed), then install sequentially to ensure correctness
152 |         with Pool(self.npool) as p:
153 |             p.map(compile_extension, temp_dirs)
154 |         # with Pool(1) as p:
155 |         #     p.map(partial(compile_extension, install=True), [temp_dirs])
156 |         for temp_dir in tqdm(temp_dirs):
157 |             try:
158 |                 compile_extension(temp_dir, install=True)
159 |             except Exception:  # a failed install only drops this configuration
160 |                 pass
161 |         # # We benchmark on a separate process so that they can import the extension that just got compiled.
162 |         # for params, temp_dir in params_tempdir:
163 |         #     print('Benchmarking: ', params)
164 |         #     recv_conn, send_conn = Pipe(duplex=False)
165 |         #     benchmark_process = Process(target=benchmark_fwd, args=(send_conn, str(temp_dir.stem)))
166 |         #     benchmark_process.start()
167 |         #     result = recv_conn.recv()
168 |         #     benchmark_process.join()
169 |         #     print('result', result)
170 |         results = []
171 |         for params, temp_dir in tqdm(list(zip(self.params_list, temp_dirs))):
172 |             try:
173 |                 results.append((params,
174 |                                 benchmark_extension(self.benchmark_script,
175 |                                                     *['--name', temp_dir.stem] + self.benchmark_args)))
176 |             except Exception:  # skip configurations whose benchmark fails or prints non-JSON output
177 |                 pass
178 |         print(results)
179 |         uninstall_extensions([temp_dir.stem for temp_dir in temp_dirs])
180 |         for temp_dir in temp_dirs:
181 |             shutil.rmtree(temp_dir)
182 |         return results
183 | 
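`prepare_kernel_string` prepends one `#define` per parameter, so the macros end up at the top of the copied `.cu` file in reverse dict order, ahead of any code that uses them. A quick check with a toy kernel string (the kernel text and the shortened array value are made up; the parameter names are the ones tune_cauchy.py produces):

```python
def prepare_kernel_string(kernel_string, params):
    # Same logic as tuner.py: each #define is prepended, so later keys end up first
    for k, v in params.items():
        kernel_string = "#define " + k + " " + str(v) + "\n" + kernel_string
    return kernel_string


src = "// toy kernel body\n"
params = {"ITEMS_PER_THREAD_SYM_FWD_VALUES": "{2, 4, 8}", "MAX_BLOCK_SIZE_VALUE": 256}
print(prepare_kernel_string(src, params), end="")
# #define MAX_BLOCK_SIZE_VALUE 256
# #define ITEMS_PER_THREAD_SYM_FWD_VALUES {2, 4, 8}
# // toy kernel body
```

The reversed order is harmless for independent object-like macros, since they are only expanded where the kernel code uses them further down the file.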
--------------------------------------------------------------------------------
/src/extensions/cauchy/tuning_setup.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from setuptools import setup
 3 | from pathlib import Path
 4 | 
 5 | import torch.cuda
 6 | from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
 7 | from torch.utils.cpp_extension import CUDA_HOME
 8 | 
 9 | 
10 | extensions_dir = Path(os.getenv('TUNING_SOURCE_DIR')).absolute()
11 | assert extensions_dir.exists()
12 | source_files=[
13 |     'cauchy.cpp',
14 |     'cauchy_cuda.cu',
15 | ]
16 | sources = [str(extensions_dir / name) for name in source_files]
17 | 
18 | extension_name = os.getenv('TUNING_EXTENSION_NAME', default='cauchy_mult_tuning')
19 | ext_modules = []
20 | if torch.cuda.is_available() and CUDA_HOME is not None:
21 |     extension = CUDAExtension(
22 |         extension_name,
23 |         sources,
24 |         include_dirs=[extensions_dir],
25 |         extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'],
26 |                             # 'nvcc': ['-O2', '-lineinfo']
27 |                             'nvcc': ['-O2', '-lineinfo', '--use_fast_math']
28 |                             }
29 |     )
30 |     ext_modules.append(extension)
31 | 
32 | setup(
33 |     name=extension_name,
34 |     ext_modules=ext_modules,
35 |     # cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)})
36 |     cmdclass={'build_ext': BuildExtension})
37 | 
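`tuning_setup.py` is driven entirely by the two environment variables that `compile_extension` in tuner.py sets (`TUNING_SOURCE_DIR`, `TUNING_EXTENSION_NAME`). Run by hand without `TUNING_SOURCE_DIR`, `Path(os.getenv('TUNING_SOURCE_DIR'))` fails with a `TypeError` on `None`. A minimal sketch of reading the same variables with a clearer error; the helper name is hypothetical, and the default name is the one `tuning_setup.py` uses.

```python
import os
from pathlib import Path


def read_tuning_env(environ=os.environ):
    """Hypothetical helper: read the variables tuner.compile_extension sets, failing clearly."""
    source_dir = environ.get("TUNING_SOURCE_DIR")
    if source_dir is None:
        raise RuntimeError("TUNING_SOURCE_DIR is not set; run this via tuner.py or export it")
    name = environ.get("TUNING_EXTENSION_NAME", "cauchy_mult_tuning")
    return Path(source_dir).absolute(), name


print(read_tuning_env({"TUNING_SOURCE_DIR": "/tmp/temp_ABC123",
                       "TUNING_EXTENSION_NAME": "temp_ABC123"}))
```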
--------------------------------------------------------------------------------
/src/main_all.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import subprocess
  3 | import importlib
  4 | import shutil
  5 | 
  6 | from matplotlib import pyplot as plt
  7 | from pathlib import Path
  8 | 
  9 | import torch
 10 | import lightning.pytorch as lp
 11 | from lightning.pytorch.tuner import Tuner
 12 | from lightning.pytorch.loggers import TensorBoardLogger
 13 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
 14 | from clinical_ts.utils.callbacks import LRMonitorCallback, TriggerQuantizerHyperparameterUpdate, UnfreezingFinetuningCallback
 15 | 
 16 | import hydra
 17 | from hydra.core.hydra_config import HydraConfig
 18 | 
 19 | #################
 20 | #specific
 21 | from clinical_ts.config import *
 22 | from omegaconf import OmegaConf
 23 | 
 24 | 
 25 | #mlflow without autologging https://github.com/zjohn77/lightning-mlflow-hf/blob/74c30c784f719ea166941751bda24393946530b7/lightning_mlflow/train.py#L39
 26 | MLFLOW_AVAILABLE=True
 27 | try:
 28 |     import mlflow
 29 |     from lightning.pytorch.loggers import MLFlowLogger
 30 |     from omegaconf import DictConfig, ListConfig
 31 | 
 32 |     def log_params_from_omegaconf_dict(params):
 33 |         for param_name, element in params.items():
 34 |             _explore_recursive(param_name, element)
 35 | 
 36 |     def _explore_recursive(parent_name, element):
 37 |         if isinstance(element, DictConfig):
 38 |             for k, v in element.items():
 39 |                 if isinstance(v, DictConfig) or isinstance(v, ListConfig):
 40 |                     _explore_recursive(f'{parent_name}.{k}', v)
 41 |                 else:
 42 |                     if(k!="_target_" and v is not None):
 43 |                         mlflow.log_param(f'{parent_name}.{k}'," " if v=="" else v)
 44 |         elif isinstance(element, ListConfig):
 45 |             for i, v in enumerate(element):
 46 |                 mlflow.log_param(f'{parent_name}.{i}', " " if v=="" else v)
 47 |     
 48 | except ImportError:
 49 |     MLFLOW_AVAILABLE=False
 50 | 
 51 | def get_git_revision_short_hash():
 52 |     return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
 53 | 
 54 | def get_slurm_job_id():
 55 |     job_id = os.environ.get('SLURM_JOB_ID')
 56 | 
 57 |     if job_id:
 58 |         return str(job_id)
 59 |     else:
 60 |         return ""
 61 |     
 62 | def get_work_dir(directory, pattern="version_"):
 63 |     path = Path(directory)
 64 |     
 65 |     # Collect the suffixes of subdirectories whose names start with `pattern` (e.g. "version_X" -> "X")
 66 |     version_dirs = [d.name[len(pattern):] for d in path.iterdir() if d.is_dir() and d.name.startswith(pattern)]
 67 |     
 68 |     if not version_dirs:
 69 |         return pattern+"0"
 70 |     
 71 |     # Extract the version numbers and find the maximum
 72 |     version_numbers = [int(d.split('_')[-1]) for d in version_dirs]
 73 |     return pattern+str(max(version_numbers) + 1)
 74 | 
 75 | def _string_to_class(_target_):
 76 |     if(len(_target_.split("."))==1):#assume global namespace
 77 |         cls_ = globals()[_target_]
 78 |     else:
 79 |         mod_ = importlib.import_module(".".join(_target_.split(".")[:-1]))
 80 |         cls_ = getattr(mod_, _target_.split(".")[-1])
 81 |     return cls_
 82 |         
 83 | ###################################################################################################
 84 | #MAIN
 85 | ###################################################################################################
 86 | cs = create_default_config()
 87 | 
 88 | @hydra.main(version_base=None, config_path="conf",  config_name="config_supervised_ecg")
 89 | def run(hparams: FullConfig) -> None:
 90 |     hparams.trainer.executable = "main_all"
 91 |     hparams.trainer.revision = get_git_revision_short_hash()
 92 | 
 93 |     if not os.path.exists(hparams.trainer.output_path):
 94 |         os.makedirs(hparams.trainer.output_path)
 95 |     
 96 |     #determine version/output path
 97 |     slurm_job_id = get_slurm_job_id()
 98 |     work_dir = get_work_dir(hparams.trainer.output_path,pattern="run_"+slurm_job_id+"_")
 99 | 
100 |     hparams.trainer.output_path = Path(hparams.trainer.output_path)/(work_dir)
101 |     if not os.path.exists(hparams.trainer.output_path):
102 |         os.makedirs(hparams.trainer.output_path)
103 | 
104 |     logger = [TensorBoardLogger(
105 |         save_dir=Path(hparams.trainer.output_path).parent,
106 |         version=work_dir,#hparams.trainer.metadata.split(":")[0],
107 |         name="")]
108 |     
    print("FULL PARSED CONFIG:")
    print(OmegaConf.to_yaml(hparams))

    # locate the main config file via the Hydra runtime config
    hydra_cfg = HydraConfig.get()
    main_source = next(s for s in hydra_cfg.runtime.config_sources if s.provider == "main")
    config_file = Path(main_source.path) / hydra_cfg.job.config_name
    if(config_file.suffix != ".yaml"):  # job.config_name is usually given without the .yaml extension
        config_file = config_file.parent / (config_file.name + ".yaml")
    print("Output directory:", hparams.trainer.output_path)
    print("Main config:", config_file)
    print("Overrides:", OmegaConf.to_container(hydra_cfg.overrides.hydra))
    print("Runtime choices:", OmegaConf.to_container(hydra_cfg.runtime.choices))

    # copy the main config into the output directory
    shutil.copyfile(config_file, Path(hparams.trainer.output_path) / config_file.name)

    # create the model from the class path given in the config
    classname = _string_to_class(hparams.task.mainclass)
    model = classname(hparams)
    #model = torch.compile(model)

    if(MLFLOW_AVAILABLE):
        # tracking URI and credentials are taken from the MLFLOW_TRACKING_URI,
        # MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD environment variables if set
        mlflow.set_experiment(hparams.trainer.executable + "(" + hparams.task.mainclass.split(".")[-1] + ")")
        mlflow_run = mlflow.start_run(run_name=hparams.trainer.metadata)
        mlf_logger = MLFlowLogger(
            experiment_name=mlflow.get_experiment(mlflow_run.info.experiment_id).name,
            tracking_uri=mlflow.get_tracking_uri(),
            log_model=False,
        )
        # attach the Lightning logger to the MLflow run started above
        mlf_logger._run_id = mlflow_run.info.run_id
        mlf_logger.log_hyperparams = log_params_from_omegaconf_dict
        logger.append(mlf_logger)

    # monitor the summary metric of the first configured metric, otherwise fall back to val_loss
    if(len(model.metrics_train_val) > 0):
        key_summary_metric = model.metrics_train_val[0].key_summary_metric
        mode_summary_metric = model.metrics_train_val[0].mode_summary_metric
    else:
        key_summary_metric = 'val_loss'
        mode_summary_metric = 'min'

    checkpoint_callback = ModelCheckpoint(
        dirpath=logger[0].log_dir,
        filename="best_model",
        save_top_k=1,
        save_last=True,
        verbose=True,
        monitor=key_summary_metric,
        mode=mode_summary_metric)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    lr_monitor2 = LRMonitorCallback(start=False, end=True)

    callbacks = [checkpoint_callback, lr_monitor, lr_monitor2]
    if(hparams.trainer.refresh_rate > 0):
        callbacks.append(TQDMProgressBar(refresh_rate=hparams.trainer.refresh_rate))
    quantizers = [m for m in model.modules() if isinstance(m, QuantizerBase)]
    if(len(quantizers) > 0):
        print("Found", len(quantizers), "quantizer modules.")
        callbacks.append(TriggerQuantizerHyperparameterUpdate(quantizers))
    # supervised finetuning only: unfreeze the frozen model parts after trainer.frozen_epochs epochs
    if(hparams.loss.loss_type == "supervised" and hparams.trainer.frozen_epochs > 0):
        callbacks.append(UnfreezingFinetuningCallback(unfreeze_epoch=hparams.trainer.frozen_epochs))

    trainer = lp.Trainer(
        #overfit_batches=0.01,
        accumulate_grad_batches=hparams.trainer.accumulate,
        # no training epochs in eval-only mode
        max_epochs=hparams.trainer.epochs if hparams.trainer.eval_only == "" else 0,

        default_root_dir=hparams.trainer.output_path,

        #debugging flags for val and train
        num_sanity_val_steps=0,
        #overfit_batches=10,

        logger=logger,
        callbacks=callbacks,
        benchmark=True,

        accelerator="gpu" if hparams.trainer.gpus > 0 else "cpu",
        devices=hparams.trainer.gpus if hparams.trainer.gpus > 0 else 1,
        num_nodes=hparams.trainer.num_nodes,
        precision=hparams.trainer.precision,
        strategy=hparams.trainer.strategy,

        enable_progress_bar=hparams.trainer.refresh_rate > 0,
        )

    if(hparams.trainer.fp32_matmul_precision != "highest"):
        torch.set_float32_matmul_precision(hparams.trainer.fp32_matmul_precision)

    if(hparams.trainer.auto_batch_size):  # scale the batch size via binary search
        tuner = Tuner(trainer)
        tuner.scale_batch_size(model, mode="binsearch")

    if(hparams.trainer.lr_find):  # learning rate finder
        tuner = Tuner(trainer)

        # run the learning rate finder
        lr_finder = tuner.lr_find(model)

        # plot the loss curve with the suggested learning rate and save it to the run directory
        fig = lr_finder.plot(suggest=True)
        fig.show()
        fig.savefig(Path(hparams.trainer.output_path) / "lrfind.png")

        # take the suggested learning rate and update the model hyperparameters
        new_lr = lr_finder.suggestion()
        print("Suggested lr:", new_lr)
        model.hparams.base.lr = [new_lr]
        model.lr = new_lr

        # restoring the initial weights after the lr search is not reliable, so abort the run here
        if(MLFLOW_AVAILABLE):
            mlflow.end_run()
        return

    if(MLFLOW_AVAILABLE):
        mlflow.log_param("a_slurm_job_id", slurm_job_id)
        mlflow.log_param("a_work_dir", work_dir)

    if(hparams.trainer.epochs > 0 and hparams.trainer.eval_only == ""):
        trainer.fit(model, ckpt_path=None if hparams.trainer.resume == "" else hparams.trainer.resume)

    # test an explicitly given checkpoint (eval_only), the best checkpoint of this run,
    # or the current weights if no training took place (no "best" checkpoint exists then)
    if(hparams.trainer.eval_only != ""):
        test_ckpt = hparams.trainer.eval_only
    elif(hparams.trainer.epochs > 0):
        test_ckpt = "best"
    else:
        test_ckpt = None
    trainer.test(model, ckpt_path=test_ckpt)

    if(MLFLOW_AVAILABLE):
        mlflow.end_run()

    if(hparams.trainer.export_features != ""):
        model.export_features(Path(hparams.trainer.output_path) / "features", module=hparams.trainer.export_features)


if __name__ == "__main__":
    run()