├── .gitignore
├── README.md
├── datasets
│   ├── FreqShape
│   │   ├── split=1.pt
│   │   ├── split=2.pt
│   │   ├── split=3.pt
│   │   ├── split=4.pt
│   │   └── split=5.pt
│   └── SeqCombMV2
│       └── README.md
├── experiments
│   ├── Boiler
│   │   ├── bc_model_ptype.py
│   │   └── train_transformer.py
│   ├── PAM
│   │   ├── bc_model_ptype.py
│   │   └── train_transformer.py
│   ├── epilepsy
│   │   ├── bc_model_ptype.py
│   │   └── train_transformer.py
│   ├── evaluation
│   │   ├── evaluate_predictor_only.py
│   │   ├── get_model_outputs.py
│   │   ├── occlusion_exp.py
│   │   ├── results
│   │   │   ├── boiler_dyna_sp=1_occlusion_results.csv
│   │   │   ├── boiler_dyna_sp=1_occlusion_results_zero.csv
│   │   │   ├── boiler_dyna_sp=2_occlusion_results.csv
│   │   │   ├── boiler_dyna_sp=2_occlusion_results_zero.csv
│   │   │   ├── boiler_dyna_sp=3_occlusion_results.csv
│   │   │   ├── boiler_dyna_sp=3_occlusion_results_zero.csv
│   │   │   ├── boiler_dyna_sp=4_occlusion_results.csv
│   │   │   ├── boiler_dyna_sp=4_occlusion_results_zero.csv
│   │   │   ├── boiler_dyna_sp=5_occlusion_results.csv
│   │   │   ├── boiler_dyna_sp=5_occlusion_results_zero.csv
│   │   │   ├── boiler_ours_sp=1_occlusion_results.csv
│   │   │   ├── boiler_ours_sp=2_occlusion_results.csv
│   │   │   ├── boiler_ours_sp=3_occlusion_results.csv
│   │   │   ├── boiler_ours_sp=4_occlusion_results.csv
│   │   │   ├── boiler_ours_sp=5_occlusion_results.csv
│   │   │   ├── boiler_random_sp=1_occlusion_results.csv
│   │   │   ├── boiler_random_sp=1_occlusion_results_zero.csv
│   │   │   ├── boiler_random_sp=2_occlusion_results.csv
│   │   │   ├── boiler_random_sp=2_occlusion_results_zero.csv
│   │   │   ├── boiler_random_sp=3_occlusion_results.csv
│   │   │   ├── boiler_random_sp=3_occlusion_results_zero.csv
│   │   │   ├── boiler_random_sp=4_occlusion_results.csv
│   │   │   ├── boiler_random_sp=4_occlusion_results_zero.csv
│   │   │   ├── boiler_random_sp=5_occlusion_results.csv
│   │   │   ├── boiler_random_sp=5_occlusion_results_zero.csv
│   │   │   ├── boiler_timex_sp=1_occlusion_results.csv
│   │   │   ├── boiler_timex_sp=2_occlusion_results.csv
│   │   │   ├── boiler_timex_sp=3_occlusion_results.csv
│   │   │   ├── boiler_timex_sp=4_occlusion_results.csv
│   │   │   ├── boiler_timex_sp=5_occlusion_results.csv
│   │   │   ├── epilepsy_dyna_sp=1_occlusion_results.csv
│   │   │   ├── epilepsy_dyna_sp=1_occlusion_results_zero.csv
│   │   │   ├── epilepsy_dyna_sp=2_occlusion_results.csv
│   │   │   ├── epilepsy_dyna_sp=2_occlusion_results_zero.csv
│   │   │   ├── epilepsy_dyna_sp=3_occlusion_results.csv
│   │   │   ├── epilepsy_dyna_sp=3_occlusion_results_zero.csv
│   │   │   ├── epilepsy_dyna_sp=4_occlusion_results.csv
│   │   │   ├── epilepsy_dyna_sp=4_occlusion_results_zero.csv
│   │   │   ├── epilepsy_dyna_sp=5_occlusion_results.csv
│   │   │   ├── epilepsy_dyna_sp=5_occlusion_results_zero.csv
│   │   │   ├── epilepsy_ours_sp=1_occlusion_results.csv
│   │   │   ├── epilepsy_ours_sp=1_occlusion_results_zero.csv
│   │   │   ├── epilepsy_ours_sp=2_occlusion_results.csv
│   │   │   ├── epilepsy_ours_sp=2_occlusion_results_zero.csv
│   │   │   ├── epilepsy_ours_sp=3_occlusion_results.csv
│   │   │   ├── epilepsy_ours_sp=3_occlusion_results_zero.csv
│   │   │   ├── epilepsy_ours_sp=4_occlusion_results.csv
│   │   │   ├── epilepsy_ours_sp=4_occlusion_results_zero.csv
│   │   │   ├── epilepsy_ours_sp=5_occlusion_results.csv
│   │   │   ├── epilepsy_ours_sp=5_occlusion_results_zero.csv
│   │   │   ├── epilepsy_random_sp=1_occlusion_results.csv
│   │   │   ├── epilepsy_random_sp=1_occlusion_results_zero.csv
│   │   │   ├── epilepsy_random_sp=2_occlusion_results.csv
│   │   │   ├── epilepsy_random_sp=2_occlusion_results_zero.csv
│   │   │   ├── epilepsy_random_sp=3_occlusion_results.csv
│   │   │   ├── epilepsy_random_sp=3_occlusion_results_zero.csv
│   │   │   ├── epilepsy_random_sp=4_occlusion_results.csv
│   │   │   ├── epilepsy_random_sp=4_occlusion_results_zero.csv
│   │   │   ├── epilepsy_random_sp=5_occlusion_results.csv
│   │   │   ├── epilepsy_random_sp=5_occlusion_results_zero.csv
│   │   │   ├── epilepsy_timex_sp=1_occlusion_results.csv
│   │   │   ├── epilepsy_timex_sp=1_occlusion_results_zero.csv
│   │   │   ├── epilepsy_timex_sp=2_occlusion_results.csv
│   │   │   ├── epilepsy_timex_sp=2_occlusion_results_zero.csv
│   │   │   ├── epilepsy_timex_sp=3_occlusion_results.csv
│   │   │   ├── epilepsy_timex_sp=3_occlusion_results_zero.csv
│   │   │   ├── epilepsy_timex_sp=4_occlusion_results.csv
│   │   │   ├── epilepsy_timex_sp=4_occlusion_results_zero.csv
│   │   │   ├── epilepsy_timex_sp=5_occlusion_results.csv
│   │   │   ├── epilepsy_timex_sp=5_occlusion_results_zero.csv
│   │   │   ├── pam_dyna_sp=1_occlusion_results.csv
│   │   │   ├── pam_dyna_sp=1_occlusion_results_zero.csv
│   │   │   ├── pam_dyna_sp=2_occlusion_results.csv
│   │   │   ├── pam_dyna_sp=2_occlusion_results_zero.csv
│   │   │   ├── pam_dyna_sp=3_occlusion_results.csv
│   │   │   ├── pam_dyna_sp=3_occlusion_results_zero.csv
│   │   │   ├── pam_dyna_sp=4_occlusion_results.csv
│   │   │   ├── pam_dyna_sp=4_occlusion_results_zero.csv
│   │   │   ├── pam_dyna_sp=5_occlusion_results.csv
│   │   │   ├── pam_dyna_sp=5_occlusion_results_zero.csv
│   │   │   ├── pam_ours_sp=1_occlusion_results.csv
│   │   │   ├── pam_ours_sp=2_occlusion_results.csv
│   │   │   ├── pam_ours_sp=3_occlusion_results.csv
│   │   │   ├── pam_ours_sp=4_occlusion_results.csv
│   │   │   ├── pam_ours_sp=5_occlusion_results.csv
│   │   │   ├── pam_random_sp=1_occlusion_results.csv
│   │   │   ├── pam_random_sp=1_occlusion_results_zero.csv
│   │   │   ├── pam_random_sp=2_occlusion_results.csv
│   │   │   ├── pam_random_sp=2_occlusion_results_zero.csv
│   │   │   ├── pam_random_sp=3_occlusion_results.csv
│   │   │   ├── pam_random_sp=3_occlusion_results_zero.csv
│   │   │   ├── pam_random_sp=4_occlusion_results.csv
│   │   │   ├── pam_random_sp=4_occlusion_results_zero.csv
│   │   │   ├── pam_random_sp=5_occlusion_results.csv
│   │   │   ├── pam_random_sp=5_occlusion_results_zero.csv
│   │   │   ├── pam_timex_sp=1_occlusion_results.csv
│   │   │   ├── pam_timex_sp=1_occlusion_results_zero.csv
│   │   │   ├── pam_timex_sp=2_occlusion_results.csv
│   │   │   ├── pam_timex_sp=2_occlusion_results_zero.csv
│   │   │   ├── pam_timex_sp=3_occlusion_results.csv
│   │   │   ├── pam_timex_sp=3_occlusion_results_zero.csv
│   │   │   ├── pam_timex_sp=4_occlusion_results.csv
│   │   │   ├── pam_timex_sp=4_occlusion_results_zero.csv
│   │   │   ├── pam_timex_sp=5_occlusion_results.csv
│   │   │   └── pam_timex_sp=5_occlusion_results_zero.csv
│   │   ├── saliency_exp_synth.py
│   │   ├── vis_occlusion.py
│   │   └── winit_wrapper.py
│   ├── freqshape
│   │   ├── bc_model_ptype.py
│   │   ├── models
│   │   │   ├── Scomb_transformer_split=1.pt
│   │   │   ├── Scomb_transformer_split=2.pt
│   │   │   ├── Scomb_transformer_split=3.pt
│   │   │   ├── Scomb_transformer_split=4.pt
│   │   │   ├── Scomb_transformer_split=5.pt
│   │   │   ├── our_bc_full_split=1.pt
│   │   │   ├── our_bc_full_split=2.pt
│   │   │   ├── our_bc_full_split=3.pt
│   │   │   ├── our_bc_full_split=4.pt
│   │   │   └── our_bc_full_split=5.pt
│   │   ├── train_cnn.py
│   │   ├── train_lstm.py
│   │   └── train_transformer.py
│   ├── lowvardetect
│   │   ├── bc_model_ptype.py
│   │   ├── conv_t_cpu.py
│   │   └── train_transformer.py
│   ├── mitecg_hard
│   │   ├── bc_model_ptype.py
│   │   ├── conv_t_cpu.py
│   │   ├── train_cnn.py
│   │   ├── train_lstm.py
│   │   └── train_transformer.py
│   ├── other_baselines
│   │   ├── contrast_generator.py
│   │   ├── cortx_exp.py
│   │   ├── infonce.py
│   │   └── train_SGT.py
│   ├── scs_better
│   │   ├── bc_model_ptype.py
│   │   ├── train_cnn.py
│   │   ├── train_lstm.py
│   │   └── train_transformer.py
│   ├── seqcomb_mv
│   │   ├── bc_model_ptype.py
│   │   ├── models
│   │   │   ├── bc_full_LC_split=1.pt
│   │   │   ├── bc_full_LC_split=2.pt
│   │   │   ├── bc_full_LC_split=3.pt
│   │   │   ├── bc_full_LC_split=4.pt
│   │   │   ├── bc_full_LC_split=5.pt
│   │   │   ├── our_bc_full_LC_split=1.pt
│   │   │   ├── our_bc_full_LC_split=2.pt
│   │   │   ├── our_bc_full_LC_split=3.pt
│   │   │   ├── our_bc_full_LC_split=4.pt
│   │   │   ├── our_bc_full_LC_split=5.pt
│   │   │   ├── transformer_split=1.pt
│   │   │   ├── transformer_split=2.pt
│   │   │   ├── transformer_split=3.pt
│   │   │   ├── transformer_split=4.pt
│   │   │   └── transformer_split=5.pt
│   │   ├── train_cnn.py
│   │   ├── train_lstm.py
│   │   └── train_transformer.py
│   ├── viz
│   │   └── vis_other_explainer.py
│   └── water
│       ├── bc_model_ptype.py
│       ├── train_lstm.py
│       └── train_transformer.py
├── pic
│   ├── model.pdf
│   └── model.png
├── requirements.txt
├── setup.py
├── timesynth-0.2.4
│   ├── PKG-INFO
│   ├── README.md
│   ├── build
│   │   └── lib
│   │       └── timesynth
│   │           ├── __init__.py
│   │           ├── noise
│   │           │   ├── __init__.py
│   │           │   ├── base_noise.py
│   │           │   ├── gaussian_noise.py
│   │           │   └── red_noise.py
│   │           ├── signals
│   │           │   ├── __init__.py
│   │           │   ├── ar.py
│   │           │   ├── base_signal.py
│   │           │   ├── car.py
│   │           │   ├── dde.py
│   │           │   ├── gaussian_process.py
│   │           │   ├── narma.py
│   │           │   ├── ode.py
│   │           │   ├── pseudoperiodic.py
│   │           │   └── sinusoidal.py
│   │           ├── timesampler
│   │           │   ├── __init__.py
│   │           │   └── timesampler.py
│   │           └── timeseries.py
│   ├── setup.cfg
│   ├── setup.py
│   ├── timesynth.egg-info
│   │   ├── PKG-INFO
│   │   ├── SOURCES.txt
│   │   ├── dependency_links.txt
│   │   ├── requires.txt
│   │   └── top_level.txt
│   └── timesynth
│       ├── __init__.py
│       ├── noise
│       │   ├── __init__.py
│       │   ├── base_noise.py
│       │   ├── gaussian_noise.py
│       │   └── red_noise.py
│       ├── signals
│       │   ├── __init__.py
│       │   ├── ar.py
│       │   ├── base_signal.py
│       │   ├── car.py
│       │   ├── dde.py
│       │   ├── gaussian_process.py
│       │   ├── narma.py
│       │   ├── ode.py
│       │   ├── pseudoperiodic.py
│       │   └── sinusoidal.py
│       ├── timesampler
│       │   ├── __init__.py
│       │   └── timesampler.py
│       └── timeseries.py
└── txai
    ├── __init__.py
    ├── baselines
    │   ├── Dynamask
    │   │   ├── __init__.py
    │   │   ├── attribution
    │   │   │   ├── __init__.py
    │   │   │   ├── mask.py
    │   │   │   ├── mask_group.py
    │   │   │   └── perturbation.py
    │   │   └── utils
    │   │       ├── __init__.py
    │   │       ├── losses.py
    │   │       ├── metrics.py
    │   │       └── tensor_manipulation.py
    │   ├── FIT
    │   │   ├── .gitignore
    │   │   ├── TSX
    │   │   │   ├── __init__.py
    │   │   │   ├── experiments.py
    │   │   │   ├── explainers.py
    │   │   │   ├── generator.py
    │   │   │   ├── main.py
    │   │   │   ├── models.py
    │   │   │   ├── select_sub_groups.py
    │   │   │   ├── temperature_scaling.py
    │   │   │   └── utils.py
    │   │   ├── config.json
    │   │   ├── data_generator
    │   │   │   ├── data
    │   │   │   │   └── clean_state_data.py
    │   │   │   ├── data_preprocess.py
    │   │   │   ├── hmm_forward.py
    │   │   │   ├── icu_mortality.py
    │   │   │   ├── preprocess_real.ipynb
    │   │   │   ├── simulated_data_l2x.py
    │   │   │   ├── simulated_l2x_switchstate.py
    │   │   │   ├── simulations_metrics.py
    │   │   │   ├── simulations_metrics_time.py
    │   │   │   ├── simulations_threshold_spikes.py
    │   │   │   ├── state_data.py
    │   │   │   └── true_generator_state_data.py
    │   │   ├── environment.yml
    │   │   └── evaluation
    │   │       ├── accordance.py
    │   │       ├── baseline_results.py
    │   │       ├── baselines.py
    │   │       ├── cv_mimic.sh
    │   │       ├── cv_simulation.sh
    │   │       ├── cv_simulation_attention.sh
    │   │       ├── generator_baselines.py
    │   │       ├── global_importance.py
    │   │       ├── interventions.py
    │   │       ├── main_global_importance.py
    │   │       ├── performance_drop_test.py
    │   │       ├── performance_scores.py
    │   │       ├── plot_baselines.py
    │   │       └── sanity_checks.py
    │   ├── SGT
    │   │   ├── Helper.py
    │   │   ├── __init__.py
    │   │   ├── cnn.py
    │   │   ├── interpretable.py
    │   │   ├── maskedAcc_MNIST.py
    │   │   ├── regular.py
    │   │   ├── train_MNIST.py
    │   │   └── utils.py
    │   └── __init__.py
    ├── models
    │   ├── __init__.py
    │   ├── bc_model.py
    │   ├── bc_model4.py
    │   ├── encoders
    │   │   ├── positional_enc.py
    │   │   ├── simple.py
    │   │   └── transformer_simple.py
    │   ├── layers.py
    │   ├── mask_generators
    │   │   ├── base_adv_model.py
    │   │   ├── base_mask_model.py
    │   │   ├── gumbel.py
    │   │   ├── gumbelmask_model.py
    │   │   ├── maskgen.py
    │   │   └── unstructured_maskgen.py
    │   └── run_model_utils.py
    ├── prototypes
    │   ├── posthoc.py
    │   └── tune_ptypes.py
    ├── smoother.py
    ├── synth_data
    │   ├── __init__.py
    │   ├── freq_shapes.py
    │   ├── generate_spikes.py
    │   ├── hmm.py
    │   ├── lowvardetect.py
    │   ├── lowvarmatch.py
    │   ├── motif_seq.py
    │   ├── motifseq.py
    │   ├── redundant_spike.py
    │   ├── seq_comb_better.py
    │   ├── seq_comb_mv.py
    │   ├── simple_spike.py
    │   ├── synth_data_base.py
    │   └── trigtrack.py
    ├── trainers
    │   ├── __init__.py
    │   ├── eliminates
    │   │   └── train_mv6_consistency_idexp.py
    │   ├── train_mv4_consistency.py
    │   ├── train_mv6_consistency.py
    │   └── train_transformer.py
    ├── utils
    │   ├── attention.py
    │   ├── baseline_comp
    │   │   ├── run_FIT.py
    │   │   ├── run_WinIT.py
    │   │   ├── run_dynamask.py
    │   │   ├── run_random.py
    │   │   └── screen.py
    │   ├── cl.py
    │   ├── cl_metrics.py
    │   ├── concepts.py
    │   ├── constants.py
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── datasets.py
    │   │   ├── preprocess.py
    │   │   ├── synth.py
    │   │   └── utils_phy12.py
    │   ├── evaluation.py
    │   ├── experimental.py
    │   ├── functional.py
    │   ├── masking.py
    │   ├── predictors
    │   │   ├── __init__.py
    │   │   ├── eval.py
    │   │   ├── loss.py
    │   │   ├── loss_cl.py
    │   │   ├── loss_smoother_stats.py
    │   │   └── select_models.py
    │   └── shapebank
    │       └── v1.py
    └── vis
        ├── vis_saliency.py
        └── visualize_mv6.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
txai.egg-info
**/__pycache__/
*.pkl
.vscode
.ipynb_checkpoints
*.err
*.out
experiments/*/models/
!experiments/freqshape/models
--------------------------------------------------------------------------------
/datasets/FreqShape/split=1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/datasets/FreqShape/split=1.pt
--------------------------------------------------------------------------------
/datasets/FreqShape/split=2.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/datasets/FreqShape/split=2.pt
--------------------------------------------------------------------------------
/datasets/FreqShape/split=3.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/datasets/FreqShape/split=3.pt
--------------------------------------------------------------------------------
/datasets/FreqShape/split=4.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/datasets/FreqShape/split=4.pt
--------------------------------------------------------------------------------
/datasets/FreqShape/split=5.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/datasets/FreqShape/split=5.pt
--------------------------------------------------------------------------------
/datasets/SeqCombMV2/README.md:
--------------------------------------------------------------------------------
please download the SeqCombMV2 dataset here:

https://drive.google.com/file/d/1k3487vs9iEw3y1neOWMRwuUGmAHDvIof/view?usp=drive_link
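
A quick way to inspect one of the FreqShape splits above is to open it directly with torch. This is a minimal sketch (not part of the repository), assuming each split=N.pt is a standard torch.save archive; the training scripts below go through the repo's own loaders instead (process_Synth / process_Boiler_OLD / process_PAM / process_Epilepsy), and unpickling may require repo classes such as SpikeTrainDataset to be importable, which is likely why the scripts import it even though it looks unused.

import torch

# map_location='cpu' avoids needing a GPU just to inspect the file.
split = torch.load('datasets/FreqShape/split=1.pt', map_location='cpu')
print(type(split))
if isinstance(split, dict):
    # Print each entry's tensor shape (or its type, for non-tensor entries).
    for key, val in split.items():
        print(key, getattr(val, 'shape', type(val)))
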
--------------------------------------------------------------------------------
/experiments/Boiler/train_transformer.py:
--------------------------------------------------------------------------------
import torch

from txai.utils.predictors.loss import Poly1CrossEntropyLoss
from txai.trainers.train_transformer import train
from txai.models.encoders.transformer_simple import TransformerMVTS
from txai.utils.data import process_Synth
from txai.utils.predictors import eval_mvts_transformer
from txai.synth_data.simple_spike import SpikeTrainDataset
from txai.utils.data.preprocess import process_Boiler_OLD, RWDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clf_criterion = Poly1CrossEntropyLoss(
    num_classes = 2,
    epsilon = 1.0,
    weight = None,  # torch.tensor([1.0,4.0])
    reduction = 'mean'
)

for i in range(4, 6):
    torch.cuda.empty_cache()
    trainB, val, test = process_Boiler_OLD(split_no = i, device = device, base_path = '/TimeX/datasets/Boiler/')
    train_dataset = RWDataset(*trainB)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 64, shuffle = True)

    model = TransformerMVTS(
        d_inp = val[0].shape[-1],
        max_len = val[0].shape[0],
        n_classes = 2,
        nlayers = 1,
        trans_dim_feedforward = 32,
        trans_dropout = 0.25,
        d_pe = 16,
        norm_embedding = True,
        stronger_clf_head = False,
    )

    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-3, weight_decay = 0.001)

    spath = 'models/transformer_split={}.pt'.format(i)

    model, loss, auc = train(
        model,
        train_loader,
        val_tuple = val,
        n_classes = 2,
        num_epochs = 1000,
        save_path = spath,
        optimizer = optimizer,
        show_sizes = False,
        validate_by_step = 32,
        use_scheduler = False
    )

    model_sdict_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    torch.save(model_sdict_cpu, 'models/transformer_split={}_cpu.pt'.format(i))

    f1 = eval_mvts_transformer(test, model, batch_size = 32)
    print('Test F1: {:.4f}'.format(f1))
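
The script above also writes a CPU copy of each state dict (transformer_split={i}_cpu.pt). A minimal sketch (not part of the repository) of reloading one for evaluation: the TransformerMVTS keyword arguments must match the training-time configuration mirrored here, and D_INP / MAX_LEN are hypothetical placeholders — in practice they come from the data, e.g. d_inp = val[0].shape[-1] and max_len = val[0].shape[0].

import torch
from txai.models.encoders.transformer_simple import TransformerMVTS

D_INP, MAX_LEN = 20, 36  # hypothetical Boiler dimensions; read from the data in practice

model = TransformerMVTS(
    d_inp = D_INP,
    max_len = MAX_LEN,
    n_classes = 2,
    nlayers = 1,
    trans_dim_feedforward = 32,
    trans_dropout = 0.25,
    d_pe = 16,
    norm_embedding = True,
    stronger_clf_head = False,
)
# Load the CPU state dict saved by the training loop above.
model.load_state_dict(torch.load('models/transformer_split=4_cpu.pt', map_location='cpu'))
model.eval()
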
--------------------------------------------------------------------------------
/experiments/PAM/train_transformer.py:
--------------------------------------------------------------------------------
import argparse
import torch
import time
import numpy as np

import sys
sys.path.append('..')
sys.path.append('../..')

from txai.models.encoders.transformer_simple import TransformerMVTS
from txai.trainers.train_transformer import train
from txai.utils.data.preprocess import process_PAM
from txai.utils.predictors import eval_mvts_transformer


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PAMDataset(torch.utils.data.Dataset):
    def __init__(self, X, times, y):
        self.X = X          # Shape: (T, N, d)
        self.times = times  # Shape: (T, N)
        self.y = y          # Shape: (N,)

    def __len__(self):
        # Samples are indexed along dim 1 of X (shape (T, N, d)),
        # so the dataset length is N, not T.
        return self.X.shape[1]

    def __getitem__(self, idx):
        x = self.X[:, idx, :]
        T = self.times[:, idx]
        y = self.y[idx]
        return x, T, y

test_f1 = []
elapsed = []

for i in range(1, 6):
    start = time.time()
    print(f'\n------------------ Split {i} ------------------')
    trainPAM, val, test = process_PAM(split_no = i, device = device, base_path = '/TimeX/datasets/PAM/', gethalf = True)
    # Output of above are chunks
    train_dataset = PAMDataset(trainPAM.X, trainPAM.time, trainPAM.y)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 32, shuffle = True)

    model = TransformerMVTS(
        d_inp = trainPAM.X.shape[2],
        max_len = trainPAM.X.shape[0],
        n_classes = 8,
    )

    # Convert to GPU:
    model.to(device)

    spath = f'models/transformer_split={i}.pt'

    # Train model:
    model, loss, auc = train(model, train_loader,
        val_tuple = (val.X, val.time, val.y), n_classes = 8, num_epochs = 100,
        save_path = spath, validate_by_step = None)

    elapsed.append(time.time() - start)

    # Get test result:
    f1 = eval_mvts_transformer((test.X, test.time, test.y), model, batch_size = 64)
    test_f1.append(f1)
    print('Test F1: {:.4f} \t Time: {}'.format(f1, elapsed[-1]))

print('='*50)
print('Testing Scores:', test_f1)
print('Avg: {:.4f}'.format(np.mean(test_f1)))
print('Avg elapsed: {:.4f}'.format(np.mean(elapsed)))
--------------------------------------------------------------------------------
/experiments/epilepsy/train_transformer.py:
--------------------------------------------------------------------------------
import torch

from txai.utils.predictors.loss import Poly1CrossEntropyLoss
from txai.trainers.train_transformer import train
from txai.models.encoders.transformer_simple import TransformerMVTS
from txai.utils.data import process_Synth
from txai.utils.predictors import eval_mvts_transformer
from txai.synth_data.simple_spike import SpikeTrainDataset
from txai.utils.data import EpiDataset
from txai.utils.data.preprocess import process_Epilepsy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clf_criterion = Poly1CrossEntropyLoss(
    num_classes = 2,
    epsilon = 1.0,
    weight = None,
    reduction = 'mean'
)

for i in range(1, 6):
    torch.cuda.empty_cache()
    trainEpi, val, test = process_Epilepsy(split_no = i, device = device, base_path = '/TimeX/datasets/Epilepsy/')
    train_dataset = EpiDataset(trainEpi.X, trainEpi.time, trainEpi.y)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 32, shuffle = True)

    print('X shape')
    print(trainEpi.X.shape)
    print('y shape', trainEpi.y.shape)

    val = (val.X, val.time, val.y)
    test = (test.X, test.time, test.y)

    model = TransformerMVTS(
        d_inp = val[0].shape[-1],
        max_len = val[0].shape[0],
        n_classes = 2,
        nlayers = 1,
        trans_dim_feedforward = 16,
        trans_dropout = 0.1,
        d_pe = 16,
        norm_embedding = False,
    )

    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4, weight_decay = 0.001)

    spath = 'models/transformer_split={}.pt'.format(i)

    model, loss, auc = train(
        model,
        train_loader,
        val_tuple = val,
        n_classes = 2,
        num_epochs = 300,
        save_path = spath,
        optimizer = optimizer,
        show_sizes = False,
        validate_by_step = 32,
        use_scheduler = False
    )

    model_sdict_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    torch.save(model_sdict_cpu, 'models/transformer_split={}_cpu.pt'.format(i))

    f1 = eval_mvts_transformer(test, model, batch_size = 32)
    print('Test F1: {:.4f}'.format(f1))
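
The result files that follow are per-split occlusion metrics named {dataset}_{method}_sp={i}_occlusion_results[_zero].csv, each holding AUPRC/AUROC at eight thresholds. A minimal sketch (not part of the repository, assuming pandas is installed) for averaging them across the five splits; mean_over_splits is a hypothetical helper name. Note the column order differs between files (the *_random_* files put thresh last), so columns are selected by name, and not every dataset/method pair has the _zero variant (e.g., boiler_ours does not).

import pandas as pd

def mean_over_splits(dataset, method, zero=False):
    # Build the five per-split file paths and average AUPRC/AUROC per threshold.
    suffix = '_occlusion_results_zero.csv' if zero else '_occlusion_results.csv'
    frames = [
        pd.read_csv(f'experiments/evaluation/results/{dataset}_{method}_sp={i}{suffix}')
        for i in range(1, 6)
    ]
    return pd.concat(frames).groupby('thresh')[['auprc', 'auroc']].mean()

print(mean_over_splits('boiler', 'ours'))
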
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.5073822903892518,0.6248137282077577
0.8,0.5020304127788527,0.599791234264644
0.85,0.4966682637892979,0.57886837810641103
0.9,0.4915097135298933,0.55694102929694214
0.925,0.4844746993352032,0.4436040261425775
0.95,0.4693966035084718,0.33583365638688
0.975,0.471857563030245,0.32783796550502753
0.99,0.44132792477141536,0.12606970238634213
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=1_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.999795022758392,0.9999630916052066
0.8,0.999698844531427,0.9999442170388487
0.85,0.9978928451234581,0.9996458706778609
0.9,0.9597696263835016,0.9955222508437436
0.925,0.9308579822980892,0.9826424401308425
0.95,0.7752799451931278,0.9358943114242915
0.975,0.5195698120268752,0.4537390368825841
0.99,0.45991099696226684,0.13498270139687762
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7065278987156802,0.734497124802416
0.8,0.7081407092761507,0.737991525866185
0.85,0.6922325274680965,0.7495240750768875
0.9,0.7013733584198931,0.759714764526543
0.925,0.652051116987636,0.7378137097955972
0.95,0.6393926919365268,0.7352547452719632
0.975,0.6773133480248019,0.636275306526741
0.99,0.5490989588419755,0.5993847542939217
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=2_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9738548261364137,0.9884780966332397
0.8,0.9798557141729712,0.9924070015645292
0.85,0.9807430786407693,0.9930467926142519
0.9,0.9787434524460921,0.9919659294319898
0.925,0.9783504605910875,0.9916594384223789
0.95,0.9761028893077276,0.98979933729674
0.975,0.9734006084887703,0.9896822225041497
0.99,0.8413630170037241,0.9594031513211397
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.4906421506491996,0.65482097200569716
0.8,0.48516488208887976,0.5250721227026731
0.85,0.4808668274867983,0.50458507355574047
0.9,0.4809995823134234,0.49850279709303094
0.925,0.4749161520639152,0.4699974517233572
0.95,0.46537442584207894,0.31129925284461574
0.975,0.4645755635089214,0.2770664119062614
0.99,0.479671709504129,0.29860635919031225
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=3_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.5519465568987727,0.7071404943690317
0.8,0.5419382267873841,0.6551519410999189
0.85,0.550743708994749,0.682624124655203
0.9,0.5149302254345571,0.5445667606023147
0.925,0.5304441346066131,0.636585568245934
0.95,0.47105228750473965,0.29376340398978906
0.975,0.4546241093412603,0.18390074315347177
0.99,0.4533540155886774,0.13557308858443526
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.5280585039370573,0.5873570104469253
0.8,0.5167616447645633,0.5600557068743692
0.85,0.515032422731203,0.555643819025109
0.9,0.5253051157012121,0.5875934995152725
0.925,0.5267867179912206,0.5996791281682263
0.95,0.5313534357617828,0.633569599606131
0.975,0.5501183953825477,0.6129115296528686
0.99,0.5876971885423512,0.6138176938003903
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=4_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.577578611489794,0.7481961337321457
0.8,0.5803857840734931,0.7616905869216658
0.85,0.5834713063723771,0.7812814094433197
0.9,0.5833985962338678,0.8028336315010376
0.925,0.5805805605586823,0.8096166420036197
0.95,0.5902273101500836,0.8231655434828392
0.975,0.6064145539152981,0.8263679037200382
0.99,0.6789103702456536,0.9061424144462988
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.542393994959463,0.6461163309726252
0.8,0.5431300829309631,0.638773641235101
0.85,0.5226260383032244,0.5793097508027576
0.9,0.5056681558408079,0.5134853627107341
0.925,0.4888301350952001,0.439354568019187
0.95,0.47617448925499756,0.33981936367405163
0.975,0.47669733223418703,0.30038315369706514
0.99,0.4450812517852677,0.09596381270612267
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_dyna_sp=5_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.544183892954087,0.6666004900829497
0.8,0.5342890213086356,0.6543956447925996
0.85,0.524024359648682,0.6039165523893646
0.9,0.520941664507005,0.5477818790374862
0.925,0.5164241598050221,0.5105917950032086
0.95,0.52697286281602,0.5820779855045012
0.975,0.540196070982646,0.6482651940469366
0.99,0.4632119136636496,0.08268637499282219
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_ours_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7497001703818182,0.8896017680886655
0.8,0.7519411782885509,0.8857210479495536
0.85,0.6565783734953727,0.863120930145356
0.9,0.5998054343373785,0.813007548229141
0.925,0.600702500512778,0.8174366081504666
0.95,0.6007077981806284,0.8182634633879233
0.975,0.5870849792055665,0.7885855093102057
0.99,0.565024258825125,0.7622745529733411
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_ours_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7328403365052861,0.9097046000050613
0.8,0.6633116937035892,0.8736829420076637
0.85,0.6842157518696863,0.8884463945080308
0.9,0.6627244376497405,0.8666804568704388
0.925,0.5696569223088087,0.7785847951033406
0.95,0.5899093591768221,0.8059088931828177
0.975,0.5316163400242712,0.6087896480444057
0.99,0.5084951642659239,0.5473230883532323
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_ours_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7133862741330987,0.8855096338897421
0.8,0.6941348630386439,0.8801151020660042
0.85,0.7133107427802347,0.883820717514328
0.9,0.6672668777341334,0.8634164600343037
0.925,0.5742566770489772,0.7804890034422332
0.95,0.5802669268011874,0.7647549295196788
0.975,0.49988097773132684,0.474820508759102
0.99,0.5297668767460492,0.5923268855460588
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_ours_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7511093099878282,0.7991160460465434
0.8,0.7068772951306277,0.7590636592462296
0.85,0.5514590523124819,0.46608031291929175
0.9,0.47726736302287776,0.28266980404837166
0.925,0.4776022206699438,0.3188876692584095
0.95,0.5178926075907001,0.47309486264570305
0.975,0.5820708532764667,0.7543135004938915
0.99,0.5559347807218421,0.6707159232315478
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_ours_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7077186331328493,0.8714963508190847
0.8,0.63842977579086,0.8623555959053709
0.85,0.6736923721750847,0.873144375803798
0.9,0.6624755766108583,0.8623452653381494
0.925,0.6145179045724812,0.8245285688210828
0.95,0.6035460359494863,0.824146138158631
0.975,0.5665543137910233,0.7180182873949928
0.99,0.5089994869550545,0.561254483550248
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.5484526742433096,0.6755445396468043,0.75
0.5375583305500266,0.6422700600698839,0.8
0.5276449703391143,0.6084819600341776,0.85
0.5182983066301473,0.5745739350393755,0.9
0.5140231460523138,0.5578317322740708,0.925
0.5100441917718254,0.5418165050810417,0.95
0.5042451830317675,0.5163645979385275,0.975
0.5021191802339133,0.5076218346952868,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=1_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.4842733023199772,0.35697307297193204,0.75
0.4846970144945225,0.3639603039704773,0.8
0.48738422086407274,0.39172983656138866,0.85
0.493407920202715,0.44444829537443653,0.9
0.4960512688792639,0.4656159460407675,0.925
0.49810435317221974,0.48102124729692247,0.95
0.49726049564355507,0.47808735674771635,0.975
0.49618651008530323,0.4767848750491096,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.5266185114652683,0.5875165089400287,0.75
0.520588057436753,0.5713604570940751,0.8
0.5144585839458,0.5525084876696955,0.85
0.5084128633335652,0.5310964229291762,0.9
0.5051602839305205,0.5203564688855926,0.925
0.5031955204447977,0.5122866606503461,0.95
0.5014841248151598,0.5053820048286941,0.975
0.5017708699601068,0.5063175401555349,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=2_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.5636557016201504,0.7689021951078637,0.75
0.562191300369907,0.7626475610655278,0.8
0.5591372821089273,0.7482152888947853,0.85
0.549637142297472,0.7087253252878077,0.9
0.5412371625755495,0.6743033036039996,0.925
0.5296811289887453,0.6277698699370027,0.95
0.5156618756895441,0.5700271465867763,0.975
0.5083498128157121,0.5363383660746106,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.5489813759342876,0.6535299621760415,0.75
0.5361469876290883,0.6188256732440116,0.8
0.5238369861959302,0.5833329833761731,0.85
0.5124666536997774,0.5458365691234122,0.9
0.5084805623723969,0.5300128914549331,0.925
0.5049616953517785,0.5189595933367483,0.95
0.502992847974529,0.5116479173702592,0.975
0.5015778487392585,0.5088863623313091,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=3_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.6332461350808415,0.8402982038559099,0.75
0.6198327404477166,0.8282367810986646,0.8
0.605340362806037,0.808168721051817,0.85
0.5825818207003794,0.7612519003829814,0.9
0.5687333025241912,0.7262040029129887,0.925
0.5532593675502195,0.683087458477005,0.95
0.5323950580657383,0.6121751146955692,0.975
0.5186755107274494,0.5602181048623825,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.5123609175328977,0.5301353617480892,0.75
0.5092954282530743,0.5262357996212139,0.8
0.5068331639668603,0.5226083917162765,0.85
0.5049240697642532,0.5181861028879936,0.9
0.5037668421269592,0.5136064226660253,0.925
0.5020133729624201,0.5062894468978578,0.95
0.5001887475178587,0.4984806793801374,0.975
0.4994718728708662,0.49542605045789523,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=4_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.43607480718479874,0.15122430988658192,0.75
0.4399774765780153,0.16980463058269948,0.8
0.44602347932927844,0.203519026277516,0.85
0.45703044672777626,0.27428448894021873,0.9
0.46512956983746434,0.3271218333501061,0.925
0.47554989765977906,0.3875403585730425,0.95
0.48601908208036837,0.44098320138372016,0.975
0.49265456541256325,0.47000207578193043,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.5434487745532761,0.6412279254800172,0.75
0.532476553258624,0.6118428351011109,0.8
0.5229952624919562,0.583705674089794,0.85
0.5140331927972154,0.5553284004296843,0.9
0.5102566477234994,0.5419926890691376,0.925
0.506075657797496,0.5259353282951272,0.95
0.5024269438913781,0.5115684098358602,0.975
0.50246205793512,0.510162920927001,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_random_sp=5_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.5338112543752822,0.6389281867309192,0.75
0.5234644913865159,0.5949152033833311,0.8
0.5153981418527536,0.5600437021877598,0.85
0.5079928628210064,0.5327522110776806,0.9
0.5055026591145968,0.522542161535517,0.925
0.5036342968892644,0.5142601804744839,0.95
0.5026495879299556,0.510180006823429,0.975
0.5028915939844731,0.509941309767113,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_timex_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.5437427056790309,0.6706349706788446
0.8,0.5413672998006205,0.6603247598242958
0.85,0.5422242208062803,0.6616889621956308
0.9,0.5600128068680306,0.7244700261377012
0.925,0.5717125332686581,0.7639232820760644
0.95,0.5551499305015345,0.7272124396760031
0.975,0.4489505363997698,0.19083584314621618
0.99,0.43855624355899203,0.1621761277847657
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_timex_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.6709463714583829,0.8533088144729333
0.8,0.6098496046553401,0.7632155488929878
0.85,0.5453071300700374,0.619926619553395
0.9,0.5035407880861319,0.4685415521686709
0.925,0.4857952414017798,0.3599911676277563
0.95,0.4836276856682646,0.3522373360277302
0.975,0.5375249509871131,0.6301683207764921
0.99,0.5833744447559595,0.6760728005055863
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_timex_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7045884739794994,0.8783307304087256
0.8,0.6639068062246125,0.8672260327267325
0.85,0.6414010238288649,0.8634780335778957
0.9,0.6513219905763974,0.8708551725517901
0.925,0.6239325922188033,0.8509687991227405
0.95,0.596632759876253,0.8121806509547882
0.975,0.5613476040638865,0.7148516690287199
0.99,0.51919534744326,0.5982257655402208
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_timex_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.7535619192727501,0.8340881163123923
0.8,0.7547288086794338,0.8384624231765214
0.85,0.7882137347571472,0.8693056189962379
0.9,0.6306588125559458,0.8452615127918696
0.925,0.6271884853731646,0.8428993439974213
0.95,0.6402379451751636,0.8540792982326931
0.975,0.6726383664220059,0.8566569796843246
0.99,0.532755250658538,0.5757131370260729
--------------------------------------------------------------------------------
/experiments/evaluation/results/boiler_timex_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.704213527349532,0.8790343439759825
0.8,0.6633766077113757,0.8548196100102629
0.85,0.5985419569318197,0.7977867763358841
0.9,0.5655429746967109,0.7652897229121178
0.925,0.5592033447205661,0.7549000938515742
0.95,0.5616732686713325,0.7562826243096175
0.975,0.5685903276711435,0.7125143230214306
0.99,0.5293914286372051,0.6347577514565154
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9726204088344977,0.9852315689981097
0.8,0.9726473116542793,0.9752528355387523
0.85,0.9741326849416169,0.9854336011342155
0.9,0.9750819231876233,0.9851807655954632
0.925,0.9754137964991171,0.9851795841209829
0.95,0.9735338397635247,0.9835869565217392
0.975,0.9667890688492673,0.9773667887523629
0.99,0.9037487979905816,0.9030703568052929
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=1_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9308452334584287,0.9666115311909262
0.8,0.9331695711762336,0.9675330812854441
0.85,0.9293368747599057,0.9670049621928167
0.9,0.929532765612363,0.9686991965973535
0.925,0.9283094760169656,0.9679253308128545
0.95,0.9302164439148142,0.9709227315689981
0.975,0.9287384989725312,0.9690666351606805
0.99,0.9188840809420651,0.9597616375236294
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9492591254253879,0.9718363067107749
0.8,0.9480917175651344,0.9712783553875235
0.85,0.9451140713414025,0.9701202150283554
0.9,0.9472745153594113,0.9701662925330812
0.925,0.9484507386068366,0.969906663516068
0.95,0.9479076406589076,0.9703139768431002
0.975,0.9493735911031405,0.9715554111531192
0.99,0.9349945960768968,0.9596863185255198
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=2_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9517487933714286,0.973718100189036
0.8,0.9526419325981752,0.9714641422495274
0.85,0.952779016044128,0.9696431947069943
0.9,0.951291646892418,0.9666942344045368
0.925,0.9493346025306623,0.964521502835539
0.95,0.9464535718704525,0.9619624291115312
0.975,0.9379324286421671,0.9530021266540643
0.99,0.8642961341448604,0.8802463374291116
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9866791465291753,0.9944089673913044
0.8,0.9858502471986903,0.9932313327032136
0.85,0.9847632126410042,0.9927670132325142
0.9,0.9832258401985134,0.9921042060491494
0.925,0.9785135780101495,0.99093661389414
0.95,0.9796869283807297,0.9907384215500945
0.975,0.9806547832492816,0.9906518785444235
0.99,0.973672394189787,0.982507974952741
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=3_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9734652765043705,0.9785326086956522
0.8,0.9720591198126262,0.9756486294896031
0.85,0.9695303720812122,0.9714916115311909
0.9,0.9652944087610676,0.9666319116257089
0.925,0.9624433962425005,0.9656131852551983
0.95,0.9602423836122383,0.9644491375236295
0.975,0.9580698175391607,0.9647004962192818
0.99,0.9481276349349945,0.9568611176748583
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.952965088975404,0.9675484404536862
0.8,0.9528114949442194,0.9672176275992439
0.85,0.9493094820045211,0.9663787807183364
0.9,0.9492643395720652,0.9657334002835538
0.925,0.9483810862719719,0.9641528827977315
0.95,0.9481035332462268,0.964491965973535
0.975,0.9422975205656896,0.9570982986767487
0.99,0.9178511232128017,0.9279114484877127
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=4_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.973863568770395,0.9721219281663516
0.8,0.970934030032341,0.9671857277882798
0.85,0.9695616468176735,0.9654052457466918
0.9,0.9689117521006996,0.9651417769376182
0.925,0.9687929188345448,0.9654525047258979
0.95,0.9692897364256242,0.9674273393194708
0.975,0.9717597632464114,0.9746313799621928
0.99,0.9727193092613013,0.9807966091682419
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.8650600901238026,0.8634132797731568
0.8,0.8606333096808503,0.8616505198487714
0.85,0.853568979382561,0.8605174858223064
0.9,0.8435314731061212,0.8581805293005671
0.925,0.8394829577511596,0.8540397566162571
0.95,0.8337256517056144,0.8519766068052931
0.975,0.8121404713679924,0.8349415170132325
0.99,0.7549923507801681,0.7757717982041588
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_dyna_sp=5_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9594762715258567,0.9829052457466918
0.8,0.9588271830770572,0.980203213610586
0.85,0.9593141505553011,0.9778615311909262
0.9,0.9569639420863554,0.9747294423440453
0.925,0.9631113485241709,0.974327150283554
0.95,0.9632163954706268,0.9737358223062382
0.975,0.9664120669518724,0.9732927693761815
0.99,0.91356859090931,0.9150735467863895
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9619895806288483,0.9754465973534971
0.8,0.955027977829543,0.9743454631379962
0.85,0.9447399120245021,0.972476370510397
0.9,0.9452758690486291,0.9719305293005671
0.925,0.9467061892376631,0.9719423440453685
0.95,0.9495775133637652,0.9720723062381853
0.975,0.9544109244593892,0.9718185845935727
0.99,0.9542223849030205,0.9711950614366729
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=1_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9627941974596637,0.9755505671077505
0.8,0.9554159386125108,0.9744742438563327
0.85,0.9456389605521056,0.9725602551984878
0.9,0.9443025865519153,0.9717013232514178
0.925,0.9453728468274865,0.9715040170132324
0.95,0.9472783802059761,0.9717639413988657
0.975,0.9548562669027927,0.9718631852551984
0.99,0.9555459453220119,0.971414815689981
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.968199833424332,0.9801902173913044
0.8,0.968043697261174,0.9807195179584121
0.85,0.9659851421868,0.9796124763705103
0.9,0.9647423140578784,0.9802989130434783
0.925,0.9623321807103304,0.9803757088846882
0.95,0.9602341858103116,0.9805679938563328
0.975,0.9554853388402041,0.9797232396030247
0.99,0.9554839868980565,0.9785459002835539
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=2_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9684309144576737,0.9801772211720227
0.8,0.9682033131205801,0.9804823369565218
0.85,0.9662141588088877,0.9795696479206049
0.9,0.9652170120860457,0.9803331758034026
0.925,0.9629380634078568,0.9805103969754254
0.95,0.9600845652771406,0.9806885042533082
0.975,0.9547481076377412,0.9799391540642721
0.99,0.9545507955960422,0.9788306356332703
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9354621117777822,0.968234877126654
0.8,0.9367430619841542,0.9697802457466919
0.85,0.9380817802342584,0.9701535916824198
0.9,0.9380296077331314,0.9706970699432893
0.925,0.9429866190623252,0.9726287807183365
0.95,0.9461570442993025,0.9737384806238185
0.975,0.9575018957999542,0.9768436909262761
0.99,0.967076483022544,0.9781952977315689
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=3_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.934537552414306,0.967718572778828
0.8,0.9362668347169969,0.9695758506616257
0.85,0.9368654205038406,0.9698676748582231
0.9,0.9357517766686803,0.9702540170132326
0.925,0.9403199504298827,0.9720212074669188
0.95,0.9427516750188276,0.9729708175803403
0.975,0.9520869953837657,0.9762361176748582
0.99,0.9645330829618508,0.9777720344990548
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.968199833424332,0.9801902173913044
0.8,0.968043697261174,0.9807195179584121
0.85,0.9659851421868,0.9796124763705103
0.9,0.9647423140578784,0.9802989130434783
0.925,0.9623321807103304,0.9803757088846882
0.95,0.9602341858103116,0.9805679938563328
0.975,0.9554853388402041,0.9797232396030247
0.99,0.9554839868980565,0.9785459002835539
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=4_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9611467043843142,0.9787783553875236
0.8,0.9614139045270933,0.9786613894139887
0.85,0.9578106287418546,0.9763823251417769
0.9,0.9583622602495684,0.9755142367674857
0.925,0.9593989107889417,0.9749858223062382
0.95,0.9607901030114592,0.9746668241965972
0.975,0.9608742328742775,0.972663634215501
0.99,0.9600195461503664,0.9701187381852553
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9335527948996579,0.9762603379017014
0.8,0.9349811188617343,0.9756651701323251
0.85,0.9315155113564172,0.9739600070888468
0.9,0.933625885638729,0.9739047731568997
0.925,0.9402038194942126,0.9748546786389414
0.95,0.9420345457316481,0.9743501890359169
0.975,0.9567128959771032,0.9735408790170132
0.99,0.954896090843619,0.9718918360113422
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_ours_sp=5_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9209529717113999,0.9735574196597353
0.8,0.9193155662182833,0.9724810964083176
0.85,0.9186637532470048,0.9715146502835539
0.9,0.9171577753217959,0.9708615902646502
0.925,0.9189159701994271,0.9709120982986768
0.95,0.9271899845308809,0.9710213846880906
0.975,0.9499431037202767,0.9726931710775047
0.99,0.9548521242740542,0.9715728379017012
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.928070340995353,0.9406221349243857,0.75
0.9233305385507201,0.9348431592627598,0.8
0.9166901031559037,0.9286363126181474,0.85
0.903211312137641,0.9132399279300566,0.9
0.8908270998081903,0.8997701736767485,0.925
0.8697939424146913,0.8758864012287335,0.95
0.7982298625325848,0.7919963078922495,0.975
0.723462642343999,0.7050022152646502,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=1_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9282202085102422,0.9389639059546313,0.75
0.9231060193341124,0.9337709711720226,0.8
0.916585712979557,0.9281506379962192,0.85
0.9044782309223613,0.9147561732041588,0.9
0.8890030973756067,0.8978488598771266,0.925
0.8646077771406514,0.8692839969281664,0.95
0.7951754145157067,0.7882757266068052,0.975
0.7138757344994704,0.6943636873818526,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9632867141085919,0.9759115075614367,0.75
0.9618656923807837,0.973992320415879,0.8
0.9582264906986527,0.9699524161153119,0.85
0.9504397536570485,0.9616083116729678,0.9
0.9428029992595048,0.9532445947542533,0.925
0.928098777854486,0.9387952800094519,0.95
0.8676191857204618,0.8791890063799622,0.975
0.7890281131180782,0.8016896561909264,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=2_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9646365381598606,0.9789193052930056,0.75
0.9622545480279598,0.9762059014650284,0.8
0.9590429668280148,0.9728393194706995,0.85
0.9509733452329527,0.9639562854442344,0.9
0.94094626186448,0.9528466446124764,0.925
0.9243790378563601,0.9352619328922496,0.95
0.8688502276117305,0.8803601134215502,0.975
0.7925048185686518,0.8048823251417769,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9439872312344271,0.9588318466446125,0.75
0.9450627536954916,0.9584729442344045,0.8
0.9429202342362786,0.9556204808601134,0.85
0.940080841128619,0.9512402232986767,0.9
0.9331318660900927,0.9447668064744802,0.925
0.9214182568945093,0.9333966800567108,0.95
0.8749382780665504,0.8859930883742912,0.975
0.8185744385864322,0.8292989130434781,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=3_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9444030654088451,0.9614565808128545,0.75
0.9457652774822478,0.9608176098771267,0.8
0.9436369648774068,0.9584253012759925,0.85
0.9380441932755985,0.9526808246691871,0.9
0.9308893811895066,0.9457864484877128,0.925
0.9196315400674147,0.9335328745274103,0.95
0.8697785260954383,0.8827008506616257,0.975
0.8071028487944713,0.8203801098771267,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9554078449688838,0.9679639059546314,0.75
0.953035851620632,0.9658760633270322,0.8
0.9492522187582765,0.9622703213610586,0.85
0.9395049487662526,0.951937706758034,0.9
0.9274923929729588,0.939281722589792,0.925
0.9053170773868482,0.9160519258034027,0.95
0.8279105572223158,0.8300184014650283,0.975
0.7359205074505832,0.7270333766540642,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=4_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9580742443844468,0.9716741788752363,0.75
0.9564986989871965,0.9708069175330813,0.8
0.9512076472992052,0.9652684605387524,0.85
0.9397106913251896,0.9531773688563326,0.9
0.9283255629946833,0.9416799681001891,0.925
0.906734636713382,0.9185712429111531,0.95
0.8225856613759245,0.8264785562381853,0.975
0.7376653670728005,0.7339433187618147,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9568938628162809,0.9733696833648393,0.75
0.9553554831461624,0.9711000708884688,0.8
0.9528028189186916,0.9677047495274103,0.85
0.9470885276587386,0.9601151051512288,0.9
0.9376223405233232,0.9488987476370511,0.925
0.9195261782584708,0.9317308896502835,0.95
0.849766306701466,0.8641978674385633,0.975
0.7682307217682773,0.7824651465028356,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_random_sp=5_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
auprc,auroc,thresh
0.9553916182396993,0.9783802575614367,0.75
0.9543330529740196,0.9756199196597353,0.8
0.9552954319751003,0.972172495274102,0.85
0.9455537737301535,0.9626512582703214,0.9
0.937857299076906,0.9519544541587901,0.925
0.9185641434257154,0.9310025401701324,0.95
0.8484365684573174,0.8594701677693761,0.975
0.7690078757231951,0.7797279655009451,0.99
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9483506578816343,0.9569278709829868
0.8,0.938883059203548,0.9496112948960302
0.85,0.9278470773165113,0.9454194234404538
0.9,0.926606483286396,0.9432679584120982
0.925,0.9237560267929734,0.9395687618147448
0.95,0.9244845301952764,0.9405092155009451
0.975,0.9299055674439254,0.9448186436672967
0.99,0.929761172854085,0.9452292060491494
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=1_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9506547806928851,0.9582608695652175
0.8,0.9401293210153852,0.9505056710775046
0.85,0.928537793633523,0.9456013705103969
0.9,0.9254947158803113,0.9431202741020794
0.925,0.9230378591822611,0.9395185491493384
0.95,0.923681581241671,0.94039106805293
0.975,0.9284820837615664,0.9450729560491493
0.99,0.9306481404976077,0.9456279536862003
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9705572774044817,0.984102079395085
0.8,0.9706765185458115,0.984352551984877
0.85,0.9705889094375356,0.9844009924385633
0.9,0.9700393333649231,0.9839992911153119
0.925,0.9694056582322693,0.983523156899811
0.95,0.9681478908735714,0.962218513705104
0.975,0.9661905450965054,0.9808580458412097
0.99,0.9646084781981762,0.9602368856332704
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=2_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9702695750396955,0.9836826559546314
0.8,0.9706237057802289,0.9842202268431002
0.85,0.9706041207640148,0.9842899338374291
0.9,0.9700632456612518,0.9840489130434782
0.925,0.9695267475203527,0.98367084120983
0.95,0.9683485469242029,0.9825590737240075
0.975,0.9664121328897877,0.9811870864839319
0.99,0.9643925082995208,0.9803479442344045
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=3_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9384155366021241,0.9640843572778828
0.8,0.9448814414468552,0.9654347826086958
0.85,0.9459257693301024,0.966984877126654
0.9,0.9468716968546738,0.9697510042533082
0.925,0.9519850093703481,0.9717143194706994
0.95,0.9543062196905001,0.9722137878071834
0.975,0.9601035529147848,0.9733586365784499
0.99,0.9607986875326124,0.9624976370510399
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=3_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9355093154500134,0.962139650283554
0.8,0.9431511556556702,0.9647019730623819
0.85,0.9437568071353983,0.9669553402646502
0.9,0.9436563492623087,0.969117438563327
0.925,0.9496120898932373,0.9710775047258979
0.95,0.9518459923994633,0.9715749054820415
0.975,0.9595474598835081,0.9734634924385632
0.99,0.9610220994411323,0.9727876890359168
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=4_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9631469525804806,0.979695179584121
0.8,0.963646811513176,0.9795534026465029
0.85,0.9614252506628991,0.9778003898865784
0.9,0.9620033848256696,0.9770206167296787
0.925,0.9627964518722989,0.97648422731569
0.95,0.9628950361133267,0.9755118738185256
0.975,0.961607293716048,0.9629011105860113
0.99,0.9602333993762998,0.9501473889413987
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=4_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9611467043843142,0.9787783553875236
0.8,0.9614139045270933,0.9786613894139887
0.85,0.9578106287418546,0.9763823251417769
0.9,0.9583622602495684,0.9755142367674857
0.925,0.9593989107889417,0.9749858223062382
0.95,0.9607901030114592,0.9746668241965972
0.975,0.9608742328742775,0.972663634215501
0.99,0.9600195461503664,0.9701187381852553
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=5_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9338922938678077,0.9777504725897921
0.8,0.9343933638111157,0.9779974007561437
0.85,0.9305383629573024,0.9767592155009451
0.9,0.9317712400601732,0.9759129844045369
0.925,0.9401776962156274,0.9766233459357279
0.95,0.949910989729049,0.977015595463138
0.975,0.962551646595752,0.9656506970699432
0.99,0.9670848086065489,0.9464044777882799
--------------------------------------------------------------------------------
/experiments/evaluation/results/epilepsy_timex_sp=5_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.9192857753207121,0.9745959357277882
0.8,0.9138023626836721,0.9736389413988658
0.85,0.9206220050165796,0.9752681947069943
0.9,0.9187793107427881,0.9735952268431002
0.925,0.9167266587630827,0.97335952268431
0.95,0.9281321023097735,0.9731758034026464
0.975,0.9571317087839928,0.9753833884688091
0.99,0.9666861778320438,0.9761040879017013
--------------------------------------------------------------------------------
/experiments/evaluation/results/pam_dyna_sp=1_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.8647125174838393,0.9390913302322046
0.8,0.8515304285432562,0.9341297212737945
0.85,0.8304721140949025,0.9175726307267447
0.9,0.7686578033469917,0.9005786095595832
0.925,0.7026468109744722,0.8644670166855431
0.95,0.5971646232441137,0.8078966192530464
0.975,0.48690781829671764,0.7590125746576589
0.99,0.44434022173227805,0.7326004802425553
--------------------------------------------------------------------------------
/experiments/evaluation/results/pam_dyna_sp=1_occlusion_results_zero.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.8615174198522577,0.9590717233733619
0.8,0.8510707374043716,0.9555449063682698
0.85,0.8330866167527277,0.9501800383478408
0.9,0.777713352248724,0.9350345203603945
0.925,0.7351736186952617,0.9199861672188185
0.95,0.657656242692305,0.8973456306309961
0.975,0.5280023611847845,0.8709790455086411
0.99,0.475033668482006,0.8368225947149475
--------------------------------------------------------------------------------
/experiments/evaluation/results/pam_dyna_sp=2_occlusion_results.csv:
--------------------------------------------------------------------------------
thresh,auprc,auroc
0.75,0.8823649079371847,0.9249236681650478
0.8,0.8709974656567131,0.9118397578273983
0.85,0.8552565094753309,0.9071838900001521 5 | 0.9,0.8307843389234486,0.8885982954488476 6 | 0.925,0.8023256064895024,0.8686701718766654 7 | 0.95,0.7353417704069665,0.8400739580984889 8 | 0.975,0.570448885687366,0.79238600314329 9 | 0.99,0.4230202668457375,0.7292292624534021 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_dyna_sp=2_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.9076581983635843,0.9799310362337901 3 | 0.8,0.9018084527370192,0.97840453682847 4 | 0.85,0.8903998360180905,0.9752679321869493 5 | 0.9,0.8642948506959193,0.967838159437469 6 | 0.925,0.8317847912311669,0.9581028761754475 7 | 0.95,0.7751154750188312,0.9406971733822788 8 | 0.975,0.6232308563186895,0.9031266663246721 9 | 0.99,0.4811124870038694,0.8567092405323851 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_dyna_sp=3_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8647125174838393,0.9390913302322046 3 | 0.8,0.8515304285432562,0.9241297212737945 4 | 0.85,0.8304721140949025,0.9175726307267447 5 | 0.9,0.7686578033469917,0.8905786095595832 6 | 0.925,0.7026468109744722,0.8744670166855431 7 | 0.95,0.5971646232441137,0.8578966192530464 8 | 0.975,0.48690781829671764,0.7590125746576589 9 | 0.99,0.44434022173227805,0.7226004802425553 -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_dyna_sp=3_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.7701988219874111,0.9593733197534988 3 | 0.8,0.7573306323724439,0.9564234054138914 4 | 0.85,0.7364008145959144,0.9521524214042064 5 | 0.9,0.716240374233695,0.9474485067732443 6 | 0.925,0.7070078822720385,0.9448109136878162 7 | 0.95,0.7076203479896552,0.9409819204829504 8 | 0.975,0.6743555078631768,0.9268139147786119 9 | 0.99,0.5575986390378173,0.883588300405207 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_dyna_sp=4_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.897305970718056,0.944665943771487 3 | 0.8,0.8717255880234251,0.9398840463551557 4 | 0.85,0.8328797006125186,0.9313280247506227 5 | 0.9,0.7813660232764095,0.9155015833698798 6 | 0.925,0.7373189078107967,0.9020815770426168 7 | 0.95,0.6710946848139067,0.86486735879082 8 | 0.975,0.5779849238172223,0.8185548661590747 9 | 0.99,0.42271030277987043,0.7506563332767672 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_dyna_sp=4_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.9011069908448994,0.9761260534126748 3 | 0.8,0.8852254436610192,0.9719882546222687 4 | 0.85,0.8399528050412147,0.9631575392953058 5 | 0.9,0.7847660904428257,0.9485189110959292 6 | 0.925,0.7443203096231732,0.9317154416566193 7 | 0.95,0.7092770066726084,0.9140797134816014 8 | 0.975,0.6040267364848833,0.8789154135924135 9 | 0.99,0.44772929658432936,0.8222125705459206 10 | -------------------------------------------------------------------------------- 
/experiments/evaluation/results/pam_dyna_sp=5_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8059189415309609,0.9534258583646458 3 | 0.8,0.7769229053961487,0.9448277227053746 4 | 0.85,0.744578597848158,0.9342673196910962 5 | 0.9,0.7040259064490654,0.9143477995211176 6 | 0.925,0.6645763303537094,0.8876700992554234 7 | 0.95,0.5570740405216337,0.8048528036342726 8 | 0.975,0.43898859868423357,0.784163498432783 9 | 0.99,0.35076178197479607,0.7151945057997157 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_dyna_sp=5_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8444952563759718,0.9603003446254632 3 | 0.8,0.8273912580281579,0.9566901623090506 4 | 0.85,0.8102713528382326,0.9515290120761519 5 | 0.9,0.7957033740457712,0.937019175778896 6 | 0.925,0.7752864256502285,0.9249395074736706 7 | 0.95,0.6931102157930866,0.8992784824982292 8 | 0.975,0.518604486732204,0.8474387946532471 9 | 0.99,0.3888871732058109,0.8029705491875869 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_ours_sp=1_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.7880257114368723,0.9433880970623606 3 | 0.8,0.7723705538472947,0.9386165294969775 4 | 0.85,0.757430096536775,0.9331253153832735 5 | 0.9,0.7446068735063893,0.9232458138219346 6 | 0.925,0.7387568286307833,0.9171950608725763 7 | 0.95,0.6780107322367006,0.89441540584142 8 | 0.975,0.5596431599189239,0.8492145221348352 9 | 0.99,0.519493838596673,0.8378989998814939 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_ours_sp=2_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.7998914354836859,0.9497159448885077 3 | 0.8,0.7783444986351193,0.9435410469943191 4 | 0.85,0.7315588500321236,0.9302117799684992 5 | 0.9,0.7315367942678922,0.9204870847054651 6 | 0.925,0.707855698627865,0.9071823292862147 7 | 0.95,0.6603430834187399,0.8889442339160927 8 | 0.975,0.5323869429218702,0.8479446537361636 9 | 0.99,0.30339985284138793,0.7408031960661732 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_ours_sp=3_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.6996162029384698,0.9295999340035845 3 | 0.8,0.6795479834394202,0.9222297704721208 4 | 0.85,0.6391254103413129,0.9084195689176342 5 | 0.9,0.6153106798594652,0.9027161752065663 6 | 0.925,0.6117946842138073,0.8997968241348502 7 | 0.95,0.6060123384628564,0.8948149704837329 8 | 0.975,0.5686463705281969,0.8732768906614342 9 | 0.99,0.4676359708174619,0.8180749467179467 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_ours_sp=4_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8502617897721692,0.9608373507710237 3 | 0.8,0.8377911405815164,0.9580850273667212 4 | 0.85,0.832315626816783,0.9570727034894859 5 | 0.9,0.8023729867842333,0.9508099470965967 6 | 
0.925,0.7650548747404577,0.9375795333981133 7 | 0.95,0.6130825606630395,0.8710453939461016 8 | 0.975,0.5457998357806932,0.8332931835358832 9 | 0.99,0.47021906657281903,0.806775876355663 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_ours_sp=5_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.7337032771175518,0.9227372633526045 3 | 0.8,0.6911155545709119,0.9100806705018041 4 | 0.85,0.6632486764442372,0.8979009154002457 5 | 0.9,0.6302694057716269,0.8834772644415146 6 | 0.925,0.5683974908211105,0.8616070485341856 7 | 0.95,0.46938666664254464,0.8167587902454857 8 | 0.975,0.41385810934350625,0.7687352244729512 9 | 0.99,0.382067926621131,0.7370019836602031 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=1_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6243175624698509,0.8944806072377295,0.75 3 | 0.5468163296195719,0.8643126340235,0.8 4 | 0.4660761193151422,0.8254748141534431,0.85 5 | 0.3765036556681825,0.768976293203776,0.9 6 | 0.3305584291584018,0.730704425382297,0.925 7 | 0.27239030930331276,0.6838460430549543,0.95 8 | 0.21451431424261508,0.6261023883569116,0.975 9 | 0.17329992131428032,0.5706261063551771,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=1_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6849506114358814,0.925303145837488,0.75 3 | 0.5877740287256166,0.8953809475188882,0.8 4 | 0.4941525756745621,0.8555098155814825,0.85 5 | 0.37838886652765835,0.7941619513911904,0.9 6 | 0.33746688191182583,0.7587577978806196,0.925 7 | 0.3032421881702755,0.7319281153106156,0.95 8 | 0.29254637930501354,0.710936698822587,0.975 9 | 0.24810058108134941,0.6873442299310628,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=2_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6579832865163592,0.9035950548635154,0.75 3 | 0.581198104223193,0.8729549136128402,0.8 4 | 0.5018920106679938,0.8338885057006641,0.85 5 | 0.4085029122656846,0.7841737159258659,0.9 6 | 0.3512303015870989,0.7504161770105383,0.925 7 | 0.2881604983025109,0.7057795241609895,0.95 8 | 0.2201272408941774,0.6373157428486657,0.975 9 | 0.1748853945532707,0.574430870431051,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=2_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6628435883061753,0.9061655283503075,0.75 3 | 0.5935257663563152,0.8823866110394528,0.8 4 | 0.5319450655437059,0.8574016614670009,0.85 5 | 0.47685895142647344,0.8264853079286949,0.9 6 | 0.4491762584901619,0.8097759301294255,0.925 7 | 0.4120612497551006,0.7919073253992428,0.95 8 | 0.33444774833625907,0.7519118192199417,0.975 9 | 0.2433069268772304,0.6641900460604081,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=3_occlusion_results.csv: 
-------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6306689597372594,0.9119474145841677,0.75 3 | 0.5636082151673678,0.8862876235661451,0.8 4 | 0.4864741170506566,0.8507655729592463,0.85 5 | 0.3971032220514545,0.7975112339747171,0.9 6 | 0.34460894569606915,0.7593908682883639,0.925 7 | 0.2880401235996496,0.7105797901680273,0.95 8 | 0.2264476975064841,0.647200097657601,0.975 9 | 0.17984229942147167,0.5959209762366557,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=3_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6196489738818959,0.9220619672442902,0.75 3 | 0.5677252719332474,0.8997467327859067,0.8 4 | 0.5332310983617279,0.8779427337894322,0.85 5 | 0.49473544107223166,0.8604226232804046,0.9 6 | 0.47392500344258287,0.8528619834425271,0.925 7 | 0.4380207949780166,0.8312460635421622,0.95 8 | 0.36436823798222717,0.78400557016059,0.975 9 | 0.294084126166099,0.7215488992826791,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=4_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6510384265225133,0.9036674831851513,0.75 3 | 0.5866964383612864,0.8812861725723617,0.8 4 | 0.5101564027556771,0.850336386385863,0.85 5 | 0.42629055711997355,0.8055531167999106,0.9 6 | 0.36999603886980376,0.7709742800567139,0.925 7 | 0.3076252972140127,0.7276694932595953,0.95 8 | 0.2396965713232647,0.6628070784044688,0.975 9 | 0.18536215116785704,0.6009638502763031,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=4_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6346654125168086,0.9152245306187228,0.75 3 | 0.5997683240775665,0.8926936486221095,0.8 4 | 0.5725218653138555,0.8795370402194076,0.85 5 | 0.5276423410646356,0.8670120263274992,0.9 6 | 0.4865906405217153,0.8547425858758565,0.925 7 | 0.41953936124634855,0.8190469093246694,0.95 8 | 0.3101430766118969,0.7255436888978604,0.975 9 | 0.20403214875165596,0.6102391668624677,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=5_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.45252698320414053,0.843984490431834,0.75 3 | 0.410550921474084,0.8215929270164999,0.8 4 | 0.35676780188963153,0.7910473778091041,0.85 5 | 0.3022443268876998,0.7471398336894484,0.9 6 | 0.2827540624629441,0.7245526460026314,0.925 7 | 0.2552273624312921,0.6949619481716109,0.95 8 | 0.2086363544334157,0.6263257377913332,0.975 9 | 0.17393756369605692,0.5739626811460445,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_random_sp=5_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | auprc,auroc,thresh 2 | 0.6398437672217054,0.9118524687576197,0.75 3 | 0.5522278795683419,0.8797225766749458,0.8 4 | 0.4546555979704651,0.838377575906702,0.85 5 | 0.3896404838803428,0.7988448658429232,0.9 6 | 0.3588462952449195,0.783725357572086,0.925 7 | 
0.33453385293611687,0.7635274088408783,0.95 8 | 0.30224880078594624,0.7332794951852599,0.975 9 | 0.25454657983185763,0.6732607116670708,0.99 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=1_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8749957483338546,0.9460678980200609 3 | 0.8,0.8568912197942121,0.9311579874433327 4 | 0.85,0.8345641263628285,0.9313244092901858 5 | 0.9,0.777974036942304,0.9118347854236588 6 | 0.925,0.7084456914866025,0.9022886576529885 7 | 0.95,0.6124972121828816,0.8611883759953225 8 | 0.975,0.4940584587153083,0.7979517635181812 9 | 0.99,0.4442204021328776,0.7496105341283219 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=1_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.9019144472572358,0.9736850227700725 3 | 0.8,0.8883809301115094,0.9694731599968763 4 | 0.85,0.8717500634041633,0.9619123689088003 5 | 0.9,0.8239104016020898,0.9446257415628216 6 | 0.925,0.7658046742905984,0.9267404919052233 7 | 0.95,0.6465091500850224,0.8962631405132164 8 | 0.975,0.5363980866572248,0.8466420593476082 9 | 0.99,0.49511891343515146,0.8288620110570069 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=2_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8985846936754498,0.9422827218471152 3 | 0.8,0.8931764838553655,0.9407740623516136 4 | 0.85,0.8756372763642634,0.9359101967544194 5 | 0.9,0.8387863549557144,0.916751143037228 6 | 0.925,0.7920152428270602,0.8839587655792816 7 | 0.95,0.7260711888428191,0.8399180785102115 8 | 0.975,0.6234895113176533,0.7992350414069103 9 | 0.99,0.45683865801023354,0.7382629287453247 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=2_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.9007181507563709,0.9734540078837148 3 | 0.8,0.8878392111269547,0.9706804312931043 4 | 0.85,0.8683478652867976,0.9656053974876113 5 | 0.9,0.8334780001496029,0.9552150055993307 6 | 0.925,0.7980822995460874,0.9434140413084691 7 | 0.95,0.7381325965139602,0.9222919425231624 8 | 0.975,0.6478050205804369,0.8943400754595793 9 | 0.99,0.4774661075726392,0.8216969206001683 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=3_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.7762036158269223,0.952588536377776 3 | 0.8,0.7631552965527236,0.9491020582968475 4 | 0.85,0.7310047194507772,0.9423812267800327 5 | 0.9,0.6783047423443693,0.9282212890095586 6 | 0.925,0.6583180748078481,0.9107447346840366 7 | 0.95,0.6334513062882238,0.8889453695408365 8 | 0.975,0.5763842740812346,0.8199616075131759 9 | 0.99,0.4786601882295928,0.734092790588474 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=3_occlusion_results_zero.csv: 
-------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.7884153871260706,0.9500487446639573 3 | 0.8,0.7798089690680725,0.9466616378944752 4 | 0.85,0.7614176175147416,0.9413346476274345 5 | 0.9,0.7243005004940626,0.9330316189751036 6 | 0.925,0.7052301691449974,0.9279995666215559 7 | 0.95,0.6918333247771871,0.9204178855920444 8 | 0.975,0.6244063095542642,0.8915070540603852 9 | 0.99,0.525415520690039,0.8440949097581414 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=4_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.903432515121722,0.9508877976043613 3 | 0.8,0.8904535995121526,0.9481223470838842 4 | 0.85,0.8659293753528655,0.9421205627017893 5 | 0.9,0.8000592757697149,0.9178649610563692 6 | 0.925,0.7698678308765258,0.8972446024057711 7 | 0.95,0.7213128751752558,0.8565569823867838 8 | 0.975,0.6111572722989054,0.8415051607120674 9 | 0.99,0.48034397233394965,0.8091106111685915 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=4_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.903226192496092,0.9718918415934336 3 | 0.8,0.8904137267474049,0.969765526164206 4 | 0.85,0.8579317584601727,0.9650120238445767 5 | 0.9,0.8106451673156063,0.9468509514840004 6 | 0.925,0.8028294951007725,0.933106228455866 7 | 0.95,0.7472802908497161,0.9129549160876743 8 | 0.975,0.608647186641021,0.87247224363485 9 | 0.99,0.4834814664658279,0.8195949317989827 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=5_occlusion_results.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8388433932302428,0.9505948152808941 3 | 0.8,0.8068421823222345,0.939881755327618 4 | 0.85,0.7415148429973213,0.921731528025093 5 | 0.9,0.6769748900221814,0.8970696038002213 6 | 0.925,0.6251764096052815,0.87027143236333 7 | 0.95,0.550866902664938,0.8252770021551442 8 | 0.975,0.4859809516173781,0.7912677389350928 9 | 0.99,0.42317444570783963,0.7355294965519601 10 | -------------------------------------------------------------------------------- /experiments/evaluation/results/pam_timex_sp=5_occlusion_results_zero.csv: -------------------------------------------------------------------------------- 1 | thresh,auprc,auroc 2 | 0.75,0.8814928006334966,0.9656952609755267 3 | 0.8,0.8611639646897584,0.9595513460268581 4 | 0.85,0.8082126142642543,0.9484076792046441 5 | 0.9,0.7655591336881931,0.9339273857821176 6 | 0.925,0.7372229199204969,0.9170502977903173 7 | 0.95,0.6507788262108735,0.8809822011277599 8 | 0.975,0.5270506403395412,0.8227208995979989 9 | 0.99,0.4429363899985723,0.7896289695144857 10 | -------------------------------------------------------------------------------- /experiments/evaluation/vis_occlusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib import cm, colors 4 | import seaborn as sns 5 | from sklearn import metrics 6 | import pickle as pkl 7 | import torch 8 | import pandas as pd 9 | from sklearn.metrics import auc, precision_recall_curve, roc_auc_score 10 | import warnings 11 | 
warnings.simplefilter(action='ignore', category=FutureWarning) 12 | 13 | import matplotlib.ticker as ticker 14 | plt.rcParams.update({'font.size': 11}) 15 | 16 | 17 | pd.set_option('display.max_columns', None) 18 | pd.set_option('display.max_rows', None) 19 | pd.set_option('display.width', 5000) 20 | pd.set_option('display.max_colwidth', 500) 21 | 22 | datanames = ["pam", 'epilepsy', 'boiler'] 23 | 24 | alg_name = ["random", "dyna", "timex", "ours"] 25 | base_path = "/TimeX/experiments/evaluation/results/" 26 | 27 | tlist = [0.75, 0.8, 0.85, 0.9, 0.925, 0.95, 0.975, 0.99] 28 | 29 | fig, axes = plt.subplots(1, 3, figsize=(16, 4)) 30 | 31 | data_dict = { 32 | "pam":"PAM", 33 | "epilepsy":"Epilepsy", 34 | "boiler":"Boiler" 35 | } 36 | for kk, dataname in enumerate(datanames): 37 | for alg in alg_name: 38 | csv_files = ['{}_{}_sp={}_occlusion_results.csv'.format(dataname, alg, i) for i in range(1,6)] 39 | 40 | if dataname=="boiler": 41 | csv_files = ['{}_{}_sp={}_occlusion_results.csv'.format(dataname, alg, i) for i in [1,2,3,5]] 42 | 43 | all_data = pd.DataFrame() 44 | for file in csv_files: 45 | df = pd.read_csv(base_path+file) 46 | all_data = pd.concat([all_data, df]) 47 | 48 | auroc_means = all_data.groupby('thresh')['auroc'].mean() 49 | auroc_std = all_data.groupby('thresh')['auroc'].std() 50 | auroc_se = auroc_std / np.sqrt(len(csv_files)) # standard error across splits 51 | 52 | 53 | # Plot the mean AUROC curve for this algorithm 54 | axes[kk].plot(range(len(tlist)), auroc_means, '-o', label = alg) 55 | 56 | # Shade the standard-error band with fill_between 57 | axes[kk].fill_between(range(len(tlist)), auroc_means - auroc_se, auroc_means + auroc_se, alpha=0.25) 58 | 59 | axes[kk].set_title(data_dict[dataname]) 60 | axes[kk].set_xlabel('Bottom Proportion Perturbed') 61 | axes[kk].set_ylabel('Prediction AUROC') 62 | axes[kk].set_xticks(range(len(tlist))) 63 | axes[kk].set_xticklabels([str(num) for num in tlist]) # Set the tick labels as strings 64 | 65 | if dataname!="boiler": 66 | axes[kk].yaxis.set_major_locator(ticker.MultipleLocator(0.05)) 67 | else: 68 | axes[kk].yaxis.set_major_locator(ticker.MultipleLocator(0.1)) 69 | fig.tight_layout() 70 | h, lab = axes[0].get_legend_handles_labels() 71 | lab = ["random", "dyna", "timex", "ours"] 72 | lab_name = { 73 | "random":"Random", 74 | "dyna":"Dynamask", 75 | "timex":"Timex", 76 | "ours":"Ours", 77 | } 78 | order = [0,1,2,3] 79 | 80 | plt.legend([h[i] for i in order], [lab_name[lab[i]] for i in order], ncol = len(lab), bbox_to_anchor=(-0.1,1.25), fontsize=12) 81 | 82 | # Save before plt.show(): show() can clear the figure on some backends, leaving a blank PDF 83 | plt.savefig("./vis_results/{}.pdf".format("mean_perturbated"), bbox_inches="tight") 84 | 85 | plt.show() 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /experiments/freqshape/models/Scomb_transformer_split=1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/Scomb_transformer_split=1.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/Scomb_transformer_split=2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/Scomb_transformer_split=2.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/Scomb_transformer_split=3.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/Scomb_transformer_split=3.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/Scomb_transformer_split=4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/Scomb_transformer_split=4.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/Scomb_transformer_split=5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/Scomb_transformer_split=5.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/our_bc_full_split=1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/our_bc_full_split=1.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/our_bc_full_split=2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/our_bc_full_split=2.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/our_bc_full_split=3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/our_bc_full_split=3.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/our_bc_full_split=4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/our_bc_full_split=4.pt -------------------------------------------------------------------------------- /experiments/freqshape/models/our_bc_full_split=5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/freqshape/models/our_bc_full_split=5.pt -------------------------------------------------------------------------------- /experiments/freqshape/train_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import CNN 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 
14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in range(1, 6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/FreqShape/') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = CNN( 26 | d_inp = val[0].shape[-1], 27 | n_classes = 4, 28 | ) 29 | 30 | model.to(device) 31 | 32 | optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 0.1) 33 | 34 | spath = 'models/Freqshape_cnn_split={}.pt'.format(i) 35 | 36 | model, loss, auc = train( 37 | model, 38 | train_loader, 39 | val_tuple = val, 40 | n_classes = 4, 41 | num_epochs = 50, 42 | save_path = spath, 43 | optimizer = optimizer, 44 | show_sizes = False, 45 | use_scheduler = False, 46 | validate_by_step= None, 47 | ) 48 | 49 | f1 = eval_mvts_transformer(test, model) 50 | print('Test F1: {:.4f}'.format(f1)) 51 | 52 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 53 | torch.save(model_sdict_cpu, spath) -------------------------------------------------------------------------------- /experiments/freqshape/train_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import LSTM 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in range(1, 6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/FreqShape/') 21 | print(D['train_loader'].X.shape, D['val'][0].shape) 22 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 23 | 24 | val, test = D['val'], D['test'] 25 | 26 | model = LSTM( 27 | d_inp = val[0].shape[-1], 28 | n_classes = 4, 29 | ) 30 | 31 | model.to(device) 32 | 33 | optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) 34 | 35 | spath = 'models/Freqshape_lstm_split={}.pt'.format(i) 36 | 37 | model, loss, auc = train( 38 | model, 39 | train_loader, 40 | val_tuple = val, 41 | n_classes = 4, 42 | num_epochs = 100, 43 | save_path = spath, 44 | optimizer = optimizer, 45 | show_sizes = False, 46 | use_scheduler = False, 47 | validate_by_step= None, 48 | ) 49 | 50 | f1 = eval_mvts_transformer(test, model) 51 | print('Test F1: {:.4f}'.format(f1)) 52 | 53 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 54 | torch.save(model_sdict_cpu, spath) -------------------------------------------------------------------------------- /experiments/freqshape/train_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.transformer_simple import TransformerMVTS 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in range(1, 6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/FreqShape/') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = TransformerMVTS( 26 | d_inp = val[0].shape[-1], 27 | max_len = val[0].shape[0], 28 | n_classes = 4, 29 | trans_dim_feedforward = 16, 30 | trans_dropout = 0.1, 31 | d_pe = 16, 32 | # aggreg = 'mean', 33 | # norm_embedding = True 34 | ) 35 | 36 | model.to(device) 37 | 38 | optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 0.1) 39 | 40 | spath = 'models/Scomb_transformer_split={}.pt'.format(i) 41 | 42 | model, loss, auc = train( 43 | model, 44 | train_loader, 45 | val_tuple = val, 46 | n_classes = 4, 47 | num_epochs = 100, 48 | save_path = spath, 49 | optimizer = optimizer, 50 | show_sizes = False, 51 | use_scheduler = False, 52 | ) 53 | 54 | f1 = eval_mvts_transformer(test, model) 55 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/lowvardetect/conv_t_cpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.transformer_simple import TransformerMVTS 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | for i in range(1, 6): 14 | #D = process_Synth(split_no = i, device = device, base_path = '/n/data1/hms/dbmi/zitnik/lab/users/owq978/TimeSeriesCBM/datasets/FreqShape') 15 | #train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 16 | 17 | #val, test = D['val'], D['test'] 18 | 19 | model = TransformerMVTS( 20 | d_inp = 2, 21 | max_len = 200, 22 | n_classes = 4, 23 | nlayers = 1, 24 | trans_dim_feedforward = 32, 25 | trans_dropout = 0.25, 26 | d_pe = 16, 27 | # aggreg = 'mean', 28 | norm_embedding = True 29 | ) 30 | 31 | spath = 'models/transformer_new2_split={}.pt'.format(i) 32 | print('re-save {}'.format(spath)) 33 | sd = torch.load(spath) 34 | 35 | model_sdict_cpu = {k:v.cpu() for k, v in sd.items()} 36 | torch.save(model_sdict_cpu, 'models/transformer_split={}_cpu.pt'.format(i)) 37 | 38 | -------------------------------------------------------------------------------- /experiments/lowvardetect/train_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.transformer_simple import TransformerMVTS 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 
18 | 19 | for i in range(1, 6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/LowVarDetect') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = TransformerMVTS( 26 | d_inp = val[0].shape[-1], 27 | max_len = val[0].shape[0], 28 | n_classes = 4, 29 | nlayers = 1, 30 | nhead = 1, 31 | trans_dim_feedforward = 32, 32 | trans_dropout = 0.1, 33 | d_pe = 16, 34 | # aggreg = 'mean', 35 | norm_embedding = True 36 | ) 37 | 38 | model.to(device) 39 | 40 | optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 0.01) 41 | 42 | spath = 'models/transformer_new2_split={}.pt'.format(i) 43 | 44 | model, loss, auc = train( 45 | model, 46 | train_loader, 47 | val_tuple = val, 48 | n_classes = 4, 49 | num_epochs = 120, 50 | save_path = spath, 51 | optimizer = optimizer, 52 | show_sizes = False, 53 | use_scheduler = False, 54 | ) 55 | 56 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 57 | torch.save(model_sdict_cpu, 'models/transformer_split={}_cpu.pt'.format(i)) 58 | 59 | f1 = eval_mvts_transformer(test, model) 60 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/mitecg_hard/conv_t_cpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.transformer_simple import TransformerMVTS 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | for i in range(1, 6): 14 | #D = process_Synth(split_no = i, device = device, base_path = '/n/data1/hms/dbmi/zitnik/lab/users/owq978/TimeSeriesCBM/datasets/FreqShape') 15 | #train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 16 | 17 | #val, test = D['val'], D['test'] 18 | 19 | model = TransformerMVTS( 20 | d_inp = 1, 21 | max_len = 360, 22 | n_classes = 2, 23 | nlayers = 1, 24 | nhead = 1, 25 | trans_dim_feedforward = 64, 26 | trans_dropout = 0.1, 27 | d_pe = 16, 28 | norm_embedding = True 29 | ) 30 | 31 | spath = 'models/transformer_exc_split={}.pt'.format(i) 32 | print('re-save {}'.format(spath)) 33 | sd = torch.load(spath) 34 | 35 | model_sdict_cpu = {k:v.cpu() for k, v in sd.items()} 36 | torch.save(model_sdict_cpu, 'models/transformer_split={}_cpu.pt'.format(i)) 37 | 38 | -------------------------------------------------------------------------------- /experiments/mitecg_hard/train_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import CNN 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | from txai.utils.data import EpiDataset 10 | from txai.utils.data.preprocess import process_MITECG 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | clf_criterion = Poly1CrossEntropyLoss( 15 | num_classes = 
2, 16 | epsilon = 1.0, 17 | weight = None,#torch.tensor([1.0, 3.0]), 18 | #weight =None, 19 | ) 20 | 21 | for i in range(1, 6): 22 | torch.cuda.empty_cache() 23 | trainEpi, val, test, _ = process_MITECG(split_no = i, device = device, hard_split = True, normalize = False, 24 | balance_classes = False, div_time = False, need_binarize = True, exclude_pac_pvc = True, 25 | base_path = '/TimeX/datasets/MITECG/') 26 | train_dataset = EpiDataset(trainEpi.X, trainEpi.time, trainEpi.y) 27 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 16, shuffle = True) 28 | 29 | print(trainEpi.y) 30 | 31 | print('X shape') 32 | print(trainEpi.X.shape) 33 | print('y shape', trainEpi.y.shape) 34 | 35 | 36 | val = (val.X, val.time, val.y) 37 | test = (test.X, test.time, test.y) 38 | 39 | model = CNN( 40 | d_inp = val[0].shape[-1], 41 | n_classes = 2, 42 | ) 43 | 44 | model.to(device) 45 | 46 | optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4) 47 | 48 | spath = 'models/MITECG-Hard_cnn_split={}.pt'.format(i) 49 | print('Saving at {}'.format(spath)) 50 | 51 | model, loss, auc = train( 52 | model, 53 | train_loader, 54 | val_tuple = val, 55 | n_classes = 2, 56 | num_epochs = 200, 57 | save_path = spath, 58 | optimizer = optimizer, 59 | show_sizes = False, 60 | validate_by_step = None, 61 | use_scheduler = False, 62 | print_freq = 1 63 | ) 64 | 65 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 66 | torch.save(model_sdict_cpu, spath) 67 | 68 | f1 = eval_mvts_transformer(test, model) 69 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/mitecg_hard/train_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import LSTM 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | from txai.utils.data import EpiDataset 10 | from txai.utils.data.preprocess import process_MITECG 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | clf_criterion = Poly1CrossEntropyLoss( 15 | num_classes = 2, 16 | epsilon = 1.0, 17 | weight = None,#torch.tensor([1.0, 3.0]), 18 | #weight =None, 19 | ) 20 | 21 | for i in range(1, 6): 22 | # for i in [3]: 23 | torch.cuda.empty_cache() 24 | trainEpi, val, test, _ = process_MITECG(split_no = i, device = device, hard_split = True, normalize = False, 25 | balance_classes = False, div_time = False, need_binarize = True, exclude_pac_pvc = True, 26 | base_path = '/TimeX/datasets/MITECG/') 27 | train_dataset = EpiDataset(trainEpi.X, trainEpi.time, trainEpi.y) 28 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 16, shuffle = True) 29 | 30 | print(trainEpi.y) 31 | 32 | print('X shape') 33 | print(trainEpi.X.shape) 34 | print('y shape', trainEpi.y.shape) 35 | 36 | 37 | val = (val.X, val.time, val.y) 38 | test = (test.X, test.time, test.y) 39 | 40 | # print((test[-1] == 0).sum()) 41 | # print((test[-1] == 1).sum()) 42 | # print((test[-1] == 0).sum()) 43 | # print((test[-1] == 1).sum()) 44 | # print((trainEpi.y == 0).sum()) 45 | # print((trainEpi.y == 1).sum()) 46 | # exit() 47 | 48 | model = LSTM( 49 | d_inp = val[0].shape[-1], 50 | n_classes = 2, 51 | ) 52 | 53 | model.to(device) 54 | 55 | 
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) 56 | 57 | spath = 'models/MITECG-Hard_lstm_split={}.pt'.format(i) 58 | print('Saving at {}'.format(spath)) 59 | 60 | model, loss, auc = train( 61 | model, 62 | train_loader, 63 | val_tuple = val, 64 | n_classes = 2, 65 | num_epochs = 200, 66 | save_path = spath, 67 | optimizer = optimizer, 68 | show_sizes = False, 69 | validate_by_step = 32, 70 | use_scheduler = False, 71 | print_freq = 1, 72 | clip_grad=1.0 73 | ) 74 | 75 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 76 | torch.save(model_sdict_cpu, spath) 77 | 78 | f1 = eval_mvts_transformer(test, model, batch_size=32) 79 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/other_baselines/contrast_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | # from utils import drop_feature 7 | 8 | 9 | class contrast_generator(nn.Module): 10 | def __init__(self, predict_model): 11 | super(contrast_generator, self).__init__() 12 | self.predict_model = predict_model 13 | self.device = 'cuda' 14 | self.drop_prob = 0.5 15 | 16 | def forward(self, model, X, pos_num, neg_num, times): 17 | # A Bernoulli mask was drawn here in the original but immediately overwritten below; only the uniform-threshold mask is kept 18 | pos_mask = torch.rand(X.shape, device=X.device) > 0.8  # device-agnostic equivalent of torch.cuda.FloatTensor(X.shape).uniform_() > 0.8 19 | pos_tensor = X.mul(pos_mask) 20 | _, tar_exp,_ = model(X, times, get_agg_embed=True) 21 | _, pos_exp,_ = model(pos_tensor, times, get_agg_embed=True) 22 | return tar_exp, pos_exp 23 | 24 | 25 | def pos_drop_sampling_mask(self, data, S): 26 | with torch.no_grad(): 27 | sample_num = S.shape[1] 28 | pos_tensor = data.mul(S) 29 | ref_score = self.predict_model.predict(data) 30 | pred_score = self.predict_model.predict(pos_tensor) 31 | # pos_gap is never defined before this point, so the original bare try/except around 32 | # np.concatenate always fell through to this assignment; keep only the assignment 33 | pos_gap = pred_score 34 | score_gap = np.absolute(pos_gap - ref_score) 35 | rank_pos_list = np.argmin(score_gap, axis=1) 36 | # the original iterated over an undefined name `train`; `data` appears to be intended (assumption) 37 | best_pos = np.stack([x[rank_pos_list[idx]] for idx, x in enumerate(data)]) 38 | return best_pos -------------------------------------------------------------------------------- /experiments/scs_better/train_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import CNN 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in range(4, 6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/n/data1/hms/dbmi/zitnik/lab/users/owq978/TimeSeriesCBM/datasets/SeqCombSingleBetter') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = CNN( 26 | d_inp = val[0].shape[-1], 27 | n_classes = 4, 28 | ) 29 | 30 | model.to(device) 31 | 32 | optimizer = torch.optim.AdamW(model.parameters(), lr
= 1e-3, weight_decay = 0.01) 33 | 34 | spath = 'models/Scomb_cnn_split={}.pt'.format(i) 35 | 36 | model, loss, auc = train( 37 | model, 38 | train_loader, 39 | val_tuple = val, 40 | n_classes = 4, 41 | num_epochs = 10, 42 | save_path = spath, 43 | optimizer = optimizer, 44 | show_sizes = False, 45 | use_scheduler = False, 46 | validate_by_step = None, 47 | ) 48 | 49 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 50 | torch.save(model_sdict_cpu, spath) 51 | 52 | f1 = eval_mvts_transformer(test, model) 53 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/scs_better/train_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import LSTM 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in range(1, 6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/SeqCombSingle/') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = LSTM( 26 | d_inp = val[0].shape[-1], 27 | n_classes = 4, 28 | ) 29 | 30 | model.to(device) 31 | 32 | optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) 33 | 34 | spath = 'models/Scomb_lstm_split={}.pt'.format(i) 35 | 36 | model, loss, auc = train( 37 | model, 38 | train_loader, 39 | val_tuple = val, 40 | n_classes = 4, 41 | num_epochs = 200, 42 | save_path = spath, 43 | optimizer = optimizer, 44 | show_sizes = False, 45 | use_scheduler = False, 46 | validate_by_step = None, 47 | ) 48 | 49 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 50 | torch.save(model_sdict_cpu, spath) 51 | 52 | f1 = eval_mvts_transformer(test, model) 53 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/scs_better/train_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.transformer_simple import TransformerMVTS 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in [3]:#range(1, 6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/SeqCombSingle/') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = TransformerMVTS( 26 | d_inp = val[0].shape[-1], 27 | max_len = val[0].shape[0], 28 | n_classes = 4, 
29 | nlayers = 2, 30 | nhead = 1, 31 | trans_dim_feedforward = 64, 32 | trans_dropout = 0.25, 33 | d_pe = 16, 34 | # aggreg = 'mean', 35 | # norm_embedding = True 36 | ) 37 | 38 | model.to(device) 39 | 40 | optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 0.01) 41 | 42 | spath = 'models/Scomb_transformer_split={}.pt'.format(i) 43 | 44 | model, loss, auc = train( 45 | model, 46 | train_loader, 47 | val_tuple = val, 48 | n_classes = 4, 49 | num_epochs = 200, 50 | save_path = spath, 51 | optimizer = optimizer, 52 | show_sizes = False, 53 | use_scheduler = False, 54 | ) 55 | 56 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 57 | torch.save(model_sdict_cpu, 'models/Scomb_transformer_split={}_cpu.pt'.format(i)) 58 | 59 | f1 = eval_mvts_transformer(test, model) 60 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/bc_full_LC_split=1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/bc_full_LC_split=1.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/bc_full_LC_split=2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/bc_full_LC_split=2.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/bc_full_LC_split=3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/bc_full_LC_split=3.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/bc_full_LC_split=4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/bc_full_LC_split=4.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/bc_full_LC_split=5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/bc_full_LC_split=5.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/our_bc_full_LC_split=1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/our_bc_full_LC_split=1.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/our_bc_full_LC_split=2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/our_bc_full_LC_split=2.pt -------------------------------------------------------------------------------- 
/experiments/seqcomb_mv/models/our_bc_full_LC_split=3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/our_bc_full_LC_split=3.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/our_bc_full_LC_split=4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/our_bc_full_LC_split=4.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/our_bc_full_LC_split=5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/our_bc_full_LC_split=5.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/transformer_split=1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/transformer_split=1.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/transformer_split=2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/transformer_split=2.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/transformer_split=3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/transformer_split=3.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/transformer_split=4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/transformer_split=4.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/models/transformer_split=5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/experiments/seqcomb_mv/models/transformer_split=5.pt -------------------------------------------------------------------------------- /experiments/seqcomb_mv/train_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import CNN 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 
clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in range(1,6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/SeqCombMV2') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = CNN( 26 | d_inp = val[0].shape[-1], 27 | n_classes = 4, 28 | ) 29 | 30 | model.to(device) 31 | 32 | optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-4, weight_decay = 0.001) 33 | 34 | spath = 'models/ScombMV_cnn_split={}.pt'.format(i) 35 | 36 | model, loss, auc = train( 37 | model, 38 | train_loader, 39 | val_tuple = val, 40 | n_classes = 4, 41 | num_epochs = 50, 42 | save_path = spath, 43 | optimizer = optimizer, 44 | show_sizes = False, 45 | use_scheduler = False, 46 | validate_by_step = None, 47 | ) 48 | 49 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 50 | torch.save(model_sdict_cpu, spath) 51 | 52 | f1 = eval_mvts_transformer(test, model) 53 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/seqcomb_mv/train_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import LSTM 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | clf_criterion = Poly1CrossEntropyLoss( 13 | num_classes = 4, 14 | epsilon = 1.0, 15 | weight = None, 16 | reduction = 'mean' 17 | ) 18 | 19 | for i in range(1,6): 20 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/SeqCombMV2') 21 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 22 | 23 | val, test = D['val'], D['test'] 24 | 25 | model = LSTM( 26 | d_inp = val[0].shape[-1], 27 | n_classes = 4, 28 | ) 29 | 30 | model.to(device) 31 | 32 | optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4) 33 | 34 | spath = 'models/ScombMV_lstm_split={}.pt'.format(i) 35 | 36 | model, loss, auc = train( 37 | model, 38 | train_loader, 39 | val_tuple = val, 40 | n_classes = 4, 41 | num_epochs = 200, 42 | save_path = spath, 43 | optimizer = optimizer, 44 | show_sizes = False, 45 | use_scheduler = False, 46 | validate_by_step = None, 47 | ) 48 | 49 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 50 | torch.save(model_sdict_cpu, spath) 51 | 52 | f1 = eval_mvts_transformer(test, model) 53 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/seqcomb_mv/train_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.transformer_simple import TransformerMVTS 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | 10 | import random 11 | import numpy as 
np 12 | 13 | my_seed=42 14 | random.seed(my_seed) 15 | np.random.seed(my_seed) 16 | torch.manual_seed(my_seed) 17 | torch.cuda.manual_seed_all(my_seed) 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | clf_criterion = Poly1CrossEntropyLoss( 22 | num_classes = 4, 23 | epsilon = 1.0, 24 | weight = None, 25 | reduction = 'mean' 26 | ) 27 | 28 | for i in range(1, 6): 29 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/SeqCombMV2') 30 | train_loader = torch.utils.data.DataLoader(D['train_loader'], batch_size = 64, shuffle = True) 31 | 32 | val, test = D['val'], D['test'] 33 | 34 | model = TransformerMVTS( 35 | d_inp = val[0].shape[-1], 36 | max_len = val[0].shape[0], 37 | n_classes = 4, 38 | trans_dim_feedforward = 128, 39 | nlayers = 2, 40 | trans_dropout = 0.25, 41 | d_pe = 16, 42 | # aggreg = 'mean', 43 | # norm_embedding = True 44 | ) 45 | 46 | model.to(device) 47 | 48 | optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-4, weight_decay = 0.001) 49 | 50 | spath = 'models/transformer_split={}.pt'.format(i) 51 | 52 | model, loss, auc = train( 53 | model, 54 | train_loader, 55 | val_tuple = val, 56 | n_classes = 4, 57 | num_epochs = 1000, 58 | save_path = spath, 59 | optimizer = optimizer, 60 | show_sizes = False, 61 | use_scheduler = False, 62 | ) 63 | 64 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 65 | torch.save(model_sdict_cpu, 'models/transformer_split={}_cpu.pt'.format(i)) 66 | 67 | f1, auprc, auroc = eval_mvts_transformer(test, model, auprc=True, auroc=True) 68 | print('Test F1: {:.4f}, AUPRC: {:.4f}, AUROC: {:.4f}'.format(f1, auprc, auroc)) 69 | 70 | -------------------------------------------------------------------------------- /experiments/water/train_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 4 | from txai.trainers.train_transformer import train 5 | from txai.models.encoders.simple import LSTM 6 | from txai.utils.data import process_Synth 7 | from txai.utils.predictors import eval_mvts_transformer 8 | from txai.synth_data.simple_spike import SpikeTrainDataset 9 | from torch.utils.data import Dataset 10 | 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | 13 | clf_criterion = Poly1CrossEntropyLoss( 14 | num_classes = 2, 15 | epsilon = 1.0, 16 | weight = None, 17 | reduction = 'mean' 18 | ) 19 | 20 | class AttrDict(dict): 21 | def __getattr__(self, attr): 22 | try: 23 | return self[attr] 24 | except KeyError: 25 | raise AttributeError(f"Attribute '{attr}' not found.") 26 | 27 | def __setattr__(self, attr, value): 28 | self[attr] = value 29 | 30 | class CustomDataset(Dataset): # wraps time-first tensors: X is (T, N, d), times is (T, N) 31 | def __init__(self, X, times, y): 32 | self.X = X 33 | self.times = times 34 | self.y = y 35 | 36 | def __len__(self): 37 | return len(self.y) 38 | 39 | def __getitem__(self, idx): 40 | return self.X[:, idx, :], self.times[:, idx], self.y[idx] 41 | 42 | for i in range(1,6): 43 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/water') 44 | dataset = CustomDataset(D['train_loader'].X, D['train_loader'].times, D['train_loader'].y) 45 | train_loader = torch.utils.data.DataLoader(dataset, batch_size = 64, shuffle = True) 46 | 47 | 48 | val, test = D['val'], D['test'] 49 | 50 | model = LSTM( 51 | d_inp = val[0].shape[-1], 52 | n_classes = 2, 53 | ) 54 | 55 | model.to(device) 56 | 57 | optimizer =
torch.optim.Adam(model.parameters(), lr = 1e-4) 58 | 59 | spath = 'models/Water_lstm_split={}.pt'.format(i) 60 | 61 | model, loss, auc = train( 62 | model, 63 | train_loader, 64 | val_tuple = val, 65 | n_classes = 2, 66 | num_epochs = 500, 67 | save_path = spath, 68 | optimizer = optimizer, 69 | show_sizes = False, 70 | use_scheduler = False, 71 | validate_by_step = None, 72 | ) 73 | 74 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 75 | torch.save(model_sdict_cpu, spath) 76 | 77 | f1 = eval_mvts_transformer(test, model) 78 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /experiments/water/train_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | from txai.utils.predictors.loss import Poly1CrossEntropyLoss 5 | from txai.trainers.train_transformer import train 6 | from txai.models.encoders.transformer_simple import TransformerMVTS 7 | from txai.utils.data import process_Synth 8 | from txai.utils.predictors import eval_mvts_transformer 9 | from txai.synth_data.simple_spike import SpikeTrainDataset 10 | 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | 13 | clf_criterion = Poly1CrossEntropyLoss( 14 | num_classes = 2, 15 | epsilon = 1.0, 16 | weight = None, 17 | reduction = 'mean' 18 | ) 19 | 20 | class AttrDict(dict): 21 | def __getattr__(self, attr): 22 | try: 23 | return self[attr] 24 | except KeyError: 25 | raise AttributeError(f"Attribute '{attr}' not found.") 26 | 27 | def __setattr__(self, attr, value): 28 | self[attr] = value 29 | 30 | class CustomDataset(Dataset): 31 | def __init__(self, X, times, y): 32 | self.X = X 33 | self.times = times 34 | self.y = y 35 | 36 | def __len__(self): 37 | return len(self.y) 38 | 39 | def __getitem__(self, idx): 40 | return self.X[:, idx, :], self.times[:, idx], self.y[idx] 41 | 42 | 43 | for i in range(1, 6): 44 | D = process_Synth(split_no = i, device = device, base_path = '/TimeX/datasets/water') 45 | dataset = CustomDataset(D['train_loader'].X, D['train_loader'].times, D['train_loader'].y) 46 | train_loader = torch.utils.data.DataLoader(dataset, batch_size = 64, shuffle = True) 47 | 48 | val, test = D['val'], D['test'] 49 | 50 | model = TransformerMVTS( 51 | d_inp = val[0].shape[-1], 52 | max_len = val[0].shape[0], 53 | n_classes = 2, 54 | nlayers = 1, 55 | trans_dim_feedforward = 64, 56 | trans_dropout = 0.1, 57 | d_pe = 16, 58 | # aggreg = 'mean', 59 | stronger_clf_head = False, 60 | pre_agg_transform = False, 61 | norm_embedding = True 62 | ) 63 | 64 | model.to(device) 65 | 66 | optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-3, weight_decay = 0.001) 67 | 68 | spath = 'models/transformer_split={}.pt'.format(i) 69 | 70 | model, loss, auc = train( 71 | model, 72 | train_loader, 73 | val_tuple = val, 74 | n_classes = 2, 75 | num_epochs = 500, 76 | save_path = spath, 77 | optimizer = optimizer, 78 | show_sizes = False, 79 | use_scheduler = False, 80 | ) 81 | 82 | model_sdict_cpu = {k:v.cpu() for k, v in model.state_dict().items()} 83 | torch.save(model_sdict_cpu, 'models/transformer_split={}_cpu.pt'.format(i)) 84 | 85 | f1 = eval_mvts_transformer(test, model) 86 | print('Test F1: {:.4f}'.format(f1)) -------------------------------------------------------------------------------- /pic/model.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/pic/model.pdf -------------------------------------------------------------------------------- /pic/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/pic/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.3 2 | aiohttp==3.10.10 3 | aiosignal==1.3.1 4 | asttokens==2.4.1 5 | attrs==24.2.0 6 | axial_positional_embedding==0.2.1 7 | captum==0.7.0 8 | CoLT5-attention==0.11.1 9 | contourpy==1.3.0 10 | cycler==0.12.1 11 | decorator==5.1.1 12 | einops==0.8.0 13 | executing==2.1.0 14 | filelock==3.16.1 15 | fonttools==4.54.1 16 | frozenlist==1.4.1 17 | fsspec==2024.9.0 18 | idna==3.10 19 | ipdb==0.13.13 20 | ipython==8.28.0 21 | jedi==0.19.1 22 | Jinja2==3.1.4 23 | jitcdde==1.4.0 24 | jitcxde_common==1.4.1 25 | joblib==1.4.2 26 | kiwisolver==1.4.7 27 | lightning-utilities==0.11.8 28 | llvmlite==0.43.0 29 | local-attention==1.9.15 30 | MarkupSafe==3.0.1 31 | matplotlib==3.9.2 32 | matplotlib-inline==0.1.7 33 | mpmath==1.3.0 34 | multidict==6.1.0 35 | networkx==3.4.1 36 | numba==0.60.0 37 | numpy==2.0.2 38 | nvidia-cublas-cu12==12.1.3.1 39 | nvidia-cuda-cupti-cu12==12.1.105 40 | nvidia-cuda-nvrtc-cu12==12.1.105 41 | nvidia-cuda-runtime-cu12==12.1.105 42 | nvidia-cudnn-cu12==9.1.0.70 43 | nvidia-cufft-cu12==11.0.2.54 44 | nvidia-curand-cu12==10.3.2.106 45 | nvidia-cusolver-cu12==11.4.5.107 46 | nvidia-cusparse-cu12==12.1.0.106 47 | nvidia-nccl-cu12==2.20.5 48 | nvidia-nvjitlink-cu12==12.6.77 49 | nvidia-nvtx-cu12==12.1.105 50 | packaging==24.1 51 | pandas==2.2.3 52 | parso==0.8.4 53 | pexpect==4.9.0 54 | pillow==11.0.0 55 | product_key_memory==0.2.11 56 | prompt_toolkit==3.0.48 57 | propcache==0.2.0 58 | ptyprocess==0.7.0 59 | pure_eval==0.2.3 60 | Pygments==2.18.0 61 | pyparsing==3.2.0 62 | python-dateutil==2.9.0.post0 63 | pytorch-lightning==2.4.0 64 | pytz==2024.2 65 | PyYAML==6.0.2 66 | reformer-pytorch==1.4.4 67 | scikit-learn==1.5.2 68 | scipy==1.14.1 69 | setuptools==75.2.0 70 | six==1.16.0 71 | stack-data==0.6.3 72 | symengine==0.13.0 73 | sympy==1.13.3 74 | threadpoolctl==3.5.0 75 | time_interpret==0.3.0 76 | torch==2.4.1 77 | torchmetrics==1.4.3 78 | tqdm==4.66.5 79 | traitlets==5.14.3 80 | triton==3.0.0 81 | tslearn==0.6.3 82 | typing_extensions==4.12.2 83 | tzdata==2024.2 84 | wcwidth==0.2.13 85 | yarl==1.15.4 86 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='txai', 4 | version='0.1', 5 | description='timex++', 6 | packages=['txai'], 7 | author = 'timex++', 8 | zip_safe=False) 9 | -------------------------------------------------------------------------------- /timesynth-0.2.4/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: timesynth 3 | Version: 0.2.4 4 | Summary: Library for creating synthetic time series 5 | Home-page: https://github.com/TimeSynth/TimeSynth 6 | Author: Abhishek Malali, Reinier Maat, Pavlos Protopapas 7 | Author-email: anon@anon.com 8 | License: MIT 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | 
-------------------------------------------------------------------------------- /timesynth-0.2.4/README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/TimeSynth/TimeSynth.svg?branch=master)](https://travis-ci.org/TimeSynth/TimeSynth) [![codecov](https://codecov.io/gh/TimeSynth/TimeSynth/branch/master/graph/badge.svg)](https://codecov.io/gh/TimeSynth/TimeSynth) 2 | 3 | # TimeSynth 4 | _Multipurpose Library for Synthetic Time Series_ 5 | 6 | **TimeSynth** is an open source library for generating synthetic time series for 7 | model testing. The library can generate regular and irregular time series. The architecture 8 | allows the user to match different signals with different architectures allowing 9 | a vast array of signals to be generated. The available signals and noise types are 10 | listed below. 11 | 12 | N.B. We only support Python 3.5+ at this time. 13 | 14 | #### Signal Types 15 | * Harmonic functions (sin, cos, or custom functions) 16 | * Gaussian processes with different kernels 17 | * Constant 18 | * Squared exponential 19 | * Exponential 20 | * Rational quadratic 21 | * Linear 22 | * Matern 23 | * Periodic 24 | * Pseudoperiodic signals 25 | * Autoregressive(p) process 26 | * Continuous autoregressive process (CAR) 27 | * Nonlinear Autoregressive Moving Average model (NARMA) 28 | 29 | #### Noise Types 30 | * White noise 31 | * Red noise 32 | 33 | ### Installation 34 | To install the package via github, 35 | ```{bash} 36 | git clone https://github.com/TimeSynth/TimeSynth.git 37 | cd TimeSynth 38 | python setup.py install 39 | ``` 40 | 41 | ### Using TimeSynth 42 | ```shell 43 | $ python 44 | ``` 45 | The code snippet demonstrates creating an irregular sinusoidal signal with white noise. 46 | ```python 47 | >>> import timesynth as ts 48 | >>> # Initializing TimeSampler 49 | >>> time_sampler = ts.TimeSampler(stop_time=20) 50 | >>> # Sampling irregular time samples 51 | >>> irregular_time_samples = time_sampler.sample_irregular_time(num_points=500, keep_percentage=50) 52 | >>> # Initializing Sinusoidal signal 53 | >>> sinusoid = ts.signals.Sinusoidal(frequency=0.25) 54 | >>> # Initializing Gaussian noise 55 | >>> white_noise = ts.noise.GaussianNoise(std=0.3) 56 | >>> # Initializing TimeSeries class with the signal and noise objects 57 | >>> timeseries = ts.TimeSeries(sinusoid, noise_generator=white_noise) 58 | >>> # Sampling using the irregular time samples 59 | >>> samples, signals, errors = timeseries.sample(irregular_time_samples) 60 | ``` 61 | --------------------------------------------------------------------------------
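The snippet above exercises the vectorized sampling path. Signals such as the autoregressive process are not vectorizable, so `TimeSeries.sample` falls back to point-by-point generation (see `timeseries.py` later in this dump). A second sketch along the same lines pairs an AR(1) signal with correlated red noise; it assumes the regular-sampling helper `sample_regular_time(num_points=...)` from the upstream TimeSynth `TimeSampler` API, which is not shown in this dump.

```python
>>> import timesynth as ts
>>> # Regularly spaced timestamps; AR signals assume regular sampling
>>> time_sampler = ts.TimeSampler(stop_time=20)
>>> regular_time_samples = time_sampler.sample_regular_time(num_points=500)
>>> # AR(1) signal; ar_param lists [phi_1, ..., phi_p]
>>> ar_signal = ts.signals.AutoRegressive(ar_param=[0.9], sigma=0.5)
>>> # Correlated (red) noise with correlation time constant tau
>>> red_noise = ts.noise.RedNoise(std=0.3, tau=0.5)
>>> timeseries = ts.TimeSeries(ar_signal, noise_generator=red_noise)
>>> # Neither generator is vectorizable, so sampling proceeds iteratively
>>> samples, signals, errors = timeseries.sample(regular_time_samples)
```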
-------------------------------------------------------------------------------- /timesynth-0.2.4/setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test = pytest 3 | 4 | [egg_info] 5 | tag_build = 6 | tag_date = 0 7 | 8 | -------------------------------------------------------------------------------- /timesynth-0.2.4/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | setup(name='timesynth', 5 | version='0.2.4', 6 | description='Library for creating synthetic time series', 7 | url='https://github.com/TimeSynth/TimeSynth', 8 | author='Abhishek Malali, Reinier Maat, Pavlos Protopapas', 9 | author_email='anon@anon.com', 10 | license='MIT', 11 | include_package_data=True, 12 | packages=find_packages(), 13 | install_requires=['numpy', 'scipy', 'sympy', 'symengine', 'jitcdde==1.4', 'jitcxde_common==1.4.1'], 14 | tests_require=['pytest'], 15 | setup_requires=["pytest-runner"]) 16 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: timesynth 3 | Version: 0.2.4 4 | Summary: Library for creating synthetic time series 5 | Home-page: https://github.com/TimeSynth/TimeSynth 6 | Author: Abhishek Malali, Reinier Maat, Pavlos Protopapas 7 | Author-email: anon@anon.com 8 | License: MIT 9 | Requires-Dist: numpy 10 | Requires-Dist: scipy 11 | Requires-Dist: sympy 12 | Requires-Dist: symengine 13 | Requires-Dist: jitcdde==1.4 14 | Requires-Dist: jitcxde_common==1.4.1 15 | --------------------------------------------------------------------------------
/timesynth-0.2.4/timesynth.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.cfg 3 | setup.py 4 | timesynth/__init__.py 5 | timesynth/timeseries.py 6 | timesynth.egg-info/PKG-INFO 7 | timesynth.egg-info/SOURCES.txt 8 | timesynth.egg-info/dependency_links.txt 9 | timesynth.egg-info/requires.txt 10 | timesynth.egg-info/top_level.txt 11 | timesynth/noise/__init__.py 12 | timesynth/noise/base_noise.py 13 | timesynth/noise/gaussian_noise.py 14 | timesynth/noise/red_noise.py 15 | timesynth/signals/__init__.py 16 | timesynth/signals/ar.py 17 | timesynth/signals/base_signal.py 18 | timesynth/signals/car.py 19 | timesynth/signals/dde.py 20 | timesynth/signals/gaussian_process.py 21 | timesynth/signals/narma.py 22 | timesynth/signals/ode.py 23 | timesynth/signals/pseudoperiodic.py 24 | timesynth/signals/sinusoidal.py 25 | timesynth/timesampler/__init__.py 26 | timesynth/timesampler/timesampler.py -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | sympy 4 | symengine 5 | jitcdde==1.4 6 | jitcxde_common==1.4.1 7 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | timesynth 2 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/__init__.py: -------------------------------------------------------------------------------- 1 | from .timeseries import TimeSeries 2 | from . import signals 3 | from . import noise 4 | from .timesampler import TimeSampler 5 | 6 | name = "timesynth" 7 | 8 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/noise/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname as _dirname, basename as _basename, isfile as _isfile 2 | import glob as _glob 3 | 4 | exec('\n'.join(map(lambda name: "from ." + name + " import *", 5 | [_basename(f)[:-3] for f in _glob.glob(_dirname(__file__) + "/*.py") \ 6 | if _isfile(f) and not _basename(f).startswith('_')]))) 7 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/noise/base_noise.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | 3 | 4 | class BaseNoise: 5 | """BaseNoise class 6 | 7 | Signature for all noise classes. 
8 | 9 | """ 10 | 11 | def __init__(self): 12 | raise NotImplementedError 13 | 14 | def sample_next(self, t, samples, errors): # We provide t for irregularly sampled timeseries 15 | """Samples next point based on history of samples and errors 16 | 17 | Parameters 18 | ---------- 19 | t : int 20 | time 21 | samples : array-like 22 | all samples taken so far 23 | errors : array-like 24 | all errors sampled so far 25 | 26 | Returns 27 | ------- 28 | float 29 | sampled error for time t 30 | 31 | """ 32 | raise NotImplementedError 33 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/noise/gaussian_noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .base_noise import BaseNoise 3 | 4 | 5 | __all__ = ['GaussianNoise'] 6 | 7 | 8 | class GaussianNoise(BaseNoise): 9 | """Gaussian noise generator. 10 | This class adds uncorrelated, additive white noise to your signal. 11 | 12 | Attributes 13 | ---------- 14 | mean : float 15 | mean for the noise 16 | std : float 17 | standard deviation for the noise 18 | 19 | """ 20 | 21 | def __init__(self, mean=0, std=1.): 22 | self.vectorizable = True 23 | self.mean = mean 24 | self.std = std 25 | 26 | def sample_next(self, t, samples, errors): 27 | return np.random.normal(loc=self.mean, scale=self.std, size=1) 28 | 29 | def sample_vectorized(self, time_vector): 30 | n_samples = len(time_vector) 31 | return np.random.normal(loc=self.mean, scale=self.std, size=n_samples) 32 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/noise/red_noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .base_noise import BaseNoise 3 | 4 | 5 | __all__ = ['RedNoise'] 6 | 7 | 8 | class RedNoise(BaseNoise): 9 | """Red noise generator. 10 | This class adds correlated (red) noise to your signal. 11 | 12 | Attributes 13 | ---------- 14 | mean : float 15 | mean for the noise 16 | std : float 17 | standard deviation for the noise 18 | tau : float 19 | correlation time constant of the noise 20 | start_value : float 21 | noise value used at the first sampled time point 22 | 23 | """ 24 | 25 | def __init__(self, mean=0, std=1., tau=0.2, start_value=0): 26 | self.vectorizable = False 27 | self.mean = mean 28 | self.std = std 29 | self.start_value = start_value 30 | self.tau = tau 31 | self.previous_value = None 32 | self.previous_time = None 33 | 34 | def sample_next(self, t, samples, errors): 35 | if self.previous_time is None: 36 | red_noise = self.start_value 37 | else: 38 | time_diff = t - self.previous_time 39 | wnoise = np.random.normal(loc=self.mean, scale=self.std, size=1) 40 | red_noise = ((self.tau/(self.tau + time_diff)) * 41 | (time_diff*wnoise + self.previous_value)) 42 | self.previous_time = t 43 | self.previous_value = red_noise 44 | return red_noise 45 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/signals/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname as _dirname, basename as _basename, isfile as _isfile 2 | import glob as _glob 3 | 4 | exec('\n'.join(map(lambda name: "from ."
+ name + " import *", 5 | [_basename(f)[:-3] for f in _glob.glob(_dirname(__file__) + "/*.py") \ 6 | if _isfile(f) and not _basename(f).startswith('_')]))) 7 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/signals/ar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .base_signal import BaseSignal 3 | 4 | __all__ = ['AutoRegressive'] 5 | 6 | 7 | class AutoRegressive(BaseSignal): 8 | """Sample generator for autoregressive (AR) signals. 9 | 10 | Generates time series with an autoregressive lag defined by the number of parameters in ar_param. 11 | NOTE: Only use this for regularly sampled signals 12 | 13 | Parameters 14 | ---------- 15 | ar_param : list (default [None]) 16 | Parameters of the AR(p) process 17 | [phi_1, phi_2, phi_3, .... phi_p] 18 | sigma : float (default 0.5) 19 | Standard deviation of the signal 20 | start_value : list (default [None]) 21 | Starting value of the AR(p) process 22 | 23 | """ 24 | 25 | def __init__(self, ar_param=[None], sigma=0.5, start_value=[None]): 26 | self.vectorizable = False 27 | ar_param.reverse() 28 | self.ar_param = ar_param 29 | self.sigma = sigma 30 | if start_value[0] is None: 31 | self.start_value = [0 for i in range(len(ar_param))] 32 | else: 33 | if len(start_value) != len(ar_param): 34 | raise ValueError("AR parameters do not match starting value") 35 | else: 36 | self.start_value = start_value 37 | self.previous_value = self.start_value 38 | 39 | def sample_next(self, time, samples, errors): 40 | """Sample a single time point 41 | 42 | Parameters 43 | ---------- 44 | time : number 45 | Time at which a sample was required 46 | 47 | Returns 48 | ------- 49 | ar_value : float 50 | sampled signal for time t 51 | """ 52 | ar_value = [self.previous_value[i] * self.ar_param[i] for i in range(len(self.ar_param))] 53 | noise = np.random.normal(loc=0.0, scale=self.sigma, size=1) 54 | ar_value = np.sum(ar_value) + noise 55 | self.previous_value = self.previous_value[1:]+[ar_value] 56 | return ar_value 57 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/signals/base_signal.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | 3 | 4 | class BaseSignal: 5 | """BaseSignal class 6 | 7 | Signature for all signal classes.
8 | 9 | """ 10 | 11 | def __init__(self): 12 | raise NotImplementedError 13 | 14 | def sample_next(self, time, samples, errors): 15 | """Samples next point based on history of samples and errors 16 | 17 | Parameters 18 | ---------- 19 | time : int 20 | time 21 | samples : array-like 22 | all samples taken so far 23 | errors : array-like 24 | all errors sampled so far 25 | 26 | Returns 27 | ------- 28 | float 29 | sampled signal for time t 30 | 31 | """ 32 | raise NotImplementedError 33 | 34 | def sample_vectorized(self, time_vector): 35 | """Samples for all time points in input 36 | 37 | Parameters 38 | ---------- 39 | time_vector : array like 40 | all time stamps to be sampled 41 | 42 | Returns 43 | ------- 44 | float 45 | sampled signal for time t 46 | 47 | """ 48 | raise NotImplementedError 49 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/signals/car.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .base_signal import BaseSignal 3 | 4 | __all__ = ['CAR'] 5 | 6 | 7 | class CAR(BaseSignal): 8 | """Signal generator for continuously autoregressive (CAR) signals. 9 | 10 | Parameters 11 | ---------- 12 | ar_param : number (default 1.0) 13 | Parameter of the AR(1) process 14 | sigma : number (default 0.5) 15 | Standard deviation of the signal 16 | start_value : number (default 0.01) 17 | Starting value of the AR process 18 | 19 | """ 20 | 21 | def __init__(self, ar_param=1.0, sigma=0.5, start_value=0.01): 22 | self.vectorizable = False 23 | self.ar_param = ar_param 24 | self.sigma = sigma 25 | self.start_value = start_value 26 | self.previous_value = None 27 | self.previous_time = None 28 | 29 | def sample_next(self, time, samples, errors): 30 | """Sample a single time point 31 | 32 | Parameters 33 | ---------- 34 | time : number 35 | Time at which a sample was required 36 | 37 | Returns 38 | ------- 39 | float 40 | sampled signal for time t 41 | 42 | """ 43 | if self.previous_value is None: 44 | output = self.start_value 45 | else: 46 | time_diff = time - self.previous_time 47 | noise = np.random.normal(loc=0.0, scale=1.0, size=1) 48 | output = (np.power(self.ar_param, time_diff))*self.previous_value+\ 49 | self.sigma*np.sqrt(1-np.power(self.ar_param, time_diff))*noise 50 | self.previous_time = time 51 | self.previous_value = output 52 | return output 53 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/signals/ode.py: -------------------------------------------------------------------------------- 1 | # Stub for Ordinary Differential Equations 2 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/signals/pseudoperiodic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .base_signal import BaseSignal 3 | 4 | __all__ = ['PseudoPeriodic'] 5 | 6 | 7 | class PseudoPeriodic(BaseSignal): 8 | """Signal generator for pseudoperiodic waves. 9 | 10 | The wave's amplitude and frequency have some stochasticity that 11 | can be set manually.
12 | 13 | Parameters 14 | ---------- 15 | amplitude : number (default 1.0) 16 | Amplitude of the harmonic series 17 | frequency : number (default 100) 18 | Frequency of the harmonic series 19 | ampSD : number (default 0.1) 20 | Amplitude standard deviation 21 | freqSD : number (default 0.4) 22 | Frequency standard deviation 23 | ftype : function (default np.sin) 24 | Harmonic function 25 | 26 | """ 27 | 28 | def __init__(self, amplitude=1.0, frequency=100, ampSD=0.1, freqSD=0.4, 29 | ftype=np.sin): 30 | self.vectorizable = True 31 | self.amplitude = amplitude 32 | self.frequency = frequency 33 | self.freqSD = freqSD 34 | self.ampSD = ampSD 35 | self.ftype = ftype 36 | 37 | def sample_next(self, time, samples, errors): 38 | """Sample a single time point 39 | 40 | Parameters 41 | ---------- 42 | time : number 43 | Time at which a sample was required 44 | 45 | Returns 46 | ------- 47 | float 48 | sampled signal for time t 49 | 50 | """ 51 | freq_val = np.random.normal(loc=self.frequency, scale=self.freqSD, size=1) 52 | amplitude_val = np.random.normal(loc=self.amplitude, scale=self.ampSD, size=1) 53 | return float(amplitude_val * self.ftype(freq_val * time)) # use the configured harmonic function, not hard-coded np.sin 54 | 55 | def sample_vectorized(self, time_vector): 56 | """Sample entire series based on the time vector 57 | 58 | Parameters 59 | ---------- 60 | time_vector : array-like 61 | Timestamps for signal generation 62 | 63 | Returns 64 | ------- 65 | array-like 66 | sampled signal for time vector 67 | 68 | """ 69 | n_samples = len(time_vector) 70 | freq_arr = np.random.normal(loc=self.frequency, scale=self.freqSD, size=n_samples) 71 | amp_arr = np.random.normal(loc=self.amplitude, scale=self.ampSD, size=n_samples) 72 | signal = np.multiply(amp_arr, self.ftype(np.multiply(freq_arr, time_vector))) 73 | return signal 74 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/signals/sinusoidal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .base_signal import BaseSignal 3 | 4 | 5 | __all__ = ['Sinusoidal'] 6 | 7 | 8 | class Sinusoidal(BaseSignal): 9 | """Signal generator for harmonic (sinusoidal) waves.
10 | 11 | Parameters 12 | ---------- 13 | amplitude : number (default 1.0) 14 | Amplitude of the harmonic series 15 | frequency : number (default 1.0) 16 | Frequency of the harmonic series 17 | ftype : function (default np.sin) 18 | Harmonic function 19 | 20 | """ 21 | 22 | def __init__(self, amplitude=1.0, frequency=1.0, ftype=np.sin): 23 | self.vectorizable = True 24 | self.amplitude = amplitude 25 | self.ftype = ftype 26 | self.frequency = frequency 27 | 28 | def sample_next(self, time, samples, errors): 29 | """Sample a single time point 30 | 31 | Parameters 32 | ---------- 33 | time : number 34 | Time at which a sample was required 35 | 36 | Returns 37 | ------- 38 | float 39 | sampled signal for time t 40 | 41 | """ 42 | return self.amplitude * self.ftype(2*np.pi*self.frequency*time) 43 | 44 | def sample_vectorized(self, time_vector): 45 | """Sample entire series based off of time vector 46 | 47 | Parameters 48 | ---------- 49 | time_vector : array-like 50 | Timestamps for signal generation 51 | 52 | Returns 53 | ------- 54 | array-like 55 | sampled signal for time vector 56 | 57 | """ 58 | if self.vectorizable is True: 59 | signal = self.amplitude * self.ftype(2*np.pi*self.frequency * 60 | time_vector) 61 | return signal 62 | else: 63 | raise ValueError("Signal type not vectorizable") 64 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/timesampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .timesampler import * 2 | -------------------------------------------------------------------------------- /timesynth-0.2.4/timesynth/timeseries.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | __all__ = ['TimeSeries'] 4 | 5 | 6 | class TimeSeries: 7 | """A TimeSeries object is the main interface from which to sample time series. 8 | You have to provide at least a signal generator; a noise generator is optional. 9 | It is recommended to set the sampling frequency. 10 | 11 | Parameters 12 | ---------- 13 | signal_generator : Signal object 14 | signal object for time series 15 | noise_generator : Noise object 16 | noise object for time series 17 | 18 | """ 19 | def __init__(self, signal_generator, noise_generator=None): 20 | self.signal_generator = signal_generator 21 | self.noise_generator = noise_generator 22 | 23 | 24 | def sample(self, time_vector): 25 | """Samples from the specified TimeSeries. 
26 | 27 | Parameters 28 | ---------- 29 | time_vector : numpy array 30 | Times at which to generate a sample 31 | 32 | Returns 33 | ------- 34 | samples, signals, errors : tuple (array, array, array) 35 | Returns samples, and the signals and errors they were constructed from 36 | """ 37 | 38 | # Vectorize if possible 39 | if self.signal_generator.vectorizable and self.noise_generator is not None and self.noise_generator.vectorizable: 40 | signals = self.signal_generator.sample_vectorized(time_vector) 41 | errors = self.noise_generator.sample_vectorized(time_vector) 42 | samples = signals + errors 43 | elif self.signal_generator.vectorizable and self.noise_generator is None: 44 | signals = self.signal_generator.sample_vectorized(time_vector) 45 | errors = np.zeros(len(time_vector)) 46 | samples = signals 47 | else: 48 | n_samples = len(time_vector) 49 | samples = np.zeros(n_samples) # Signal and errors combined 50 | signals = np.zeros(n_samples) # Signal samples 51 | errors = np.zeros(n_samples) # Handle errors separately 52 | times = np.arange(n_samples) 53 | 54 | # Sample iteratively, while providing access to all previously sampled steps 55 | for i in range(n_samples): 56 | # Get time 57 | t = time_vector[i] 58 | # Sample error 59 | if self.noise_generator is not None: 60 | errors[i] = self.noise_generator.sample_next(t, samples[:i], errors[:i]) 61 | 62 | # Sample signal 63 | signal = self.signal_generator.sample_next(t, samples[:i], errors[:i]) 64 | signals[i] = signal 65 | 66 | # Compound signal and noise 67 | samples[i] = signals[i] + errors[i] 68 | 69 | # Return both times and samples, as well as signals and errors 70 | return samples, signals, errors 71 | -------------------------------------------------------------------------------- /txai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/txai/__init__.py -------------------------------------------------------------------------------- /txai/baselines/Dynamask/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/txai/baselines/Dynamask/__init__.py -------------------------------------------------------------------------------- /txai/baselines/Dynamask/attribution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/txai/baselines/Dynamask/attribution/__init__.py -------------------------------------------------------------------------------- /txai/baselines/Dynamask/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/txai/baselines/Dynamask/utils/__init__.py -------------------------------------------------------------------------------- /txai/baselines/Dynamask/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def cross_entropy(proba_pred, proba_target): 5 | """Computes the cross entropy between the two probabilities torch tensors.""" 6 | return -(proba_target * torch.log(proba_pred)).mean() 7 | 8 | 9 | def log_loss(proba_pred, proba_target): 10 | """Computes the log loss
/txai/baselines/Dynamask/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def cross_entropy(proba_pred, proba_target): 5 | """Computes the cross entropy between the two probabilities torch tensors.""" 6 | return -(proba_target * torch.log(proba_pred)).mean() 7 | 8 | 9 | def log_loss(proba_pred, proba_target): 10 | """Computes the log loss between the two probabilities torch tensors.""" 11 | label_target = torch.argmax(proba_target, dim=-1, keepdim=True) 12 | proba_select = torch.gather(proba_pred, -1, label_target) 13 | return -(torch.log(proba_select)).mean() 14 | 15 | 16 | def log_loss_target(proba_pred, target): 17 | """Computes log loss between the target and the predicted probabilities expressed as torch tensors. 18 | 19 | The target is a one dimensional tensor whose dimension matches the first dimension of proba_pred. 20 | It contains integers that represent the true class for each instance. 21 | """ 22 | proba_select = torch.gather(proba_pred, -1, target) 23 | return -(torch.log(proba_select)).mean() 24 | 25 | 26 | def mse(Y, Y_target): 27 | """Computes the mean squared error between Y and Y_target.""" 28 | return torch.mean((Y - Y_target) ** 2) 29 | -------------------------------------------------------------------------------- /txai/baselines/Dynamask/utils/tensor_manipulation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def normalize(tensor, eps=1.0e-7): # Min-max normalizes the input tensor in place to roughly [0, 1] 5 | tensor -= tensor.min() 6 | tensor /= tensor.max() + eps 7 | return tensor 8 | 9 | 10 | def extract_subtensor(tensor: torch.Tensor, ids_time, ids_feature): 11 | """This method extracts a subtensor specified with the indices. 12 | 13 | Args: 14 | tensor: The (T, N_features) tensor from which the data should be extracted. 15 | ids_time: List of the times that should be extracted. 16 | ids_feature: List of the features that should be extracted. 17 | 18 | Returns: 19 | torch.Tensor: Subtensor extracted based on the indices. 20 | """ 21 | T, N_features = tensor.shape 22 | # If no identifiers have been specified, we use the whole data 23 | if ids_time is None: 24 | ids_time = [k for k in range(T)] 25 | if ids_feature is None: 26 | ids_feature = [k for k in range(N_features)] 27 | # Extract the relevant data in the mask 28 | subtensor = tensor.clone().detach() 29 | subtensor = subtensor[ids_time, :] 30 | subtensor = subtensor[:, ids_feature] 31 | return subtensor 32 | -------------------------------------------------------------------------------- /txai/baselines/FIT/.gitignore: -------------------------------------------------------------------------------- 1 | #ignore pdf and jpg for now 2 | ._*.pdf 3 | ._*.py 4 | .nfs* 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | #IDE/OS specific files 11 | .DS_Store 12 | ._.DS_Store 13 | .idea 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # celery beat schedule file 101 | celerybeat-schedule 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /txai/baselines/FIT/TSX/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichuan-liu/TimeXplusplus/f2ac77c5c675a4e7f48413a4852b10242cc32bf0/txai/baselines/FIT/TSX/__init__.py -------------------------------------------------------------------------------- /txai/baselines/FIT/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "simulation": { 3 | "feature_generator_explainer": { 4 | "encoding_size" : 10, 5 | "batch_size" : 100, 6 | "n_epochs" : 300, 7 | "historical": 1, 8 | "rnn_type" : "GRU", 9 | "non_linearity":"torch.nn.ReLU()" 10 | }, 11 | "risk_predictor":{ 12 | "encoding_size" : 200, 13 | "batch_size" : 100, 14 | "n_epochs" : 140, 15 | "historical": 1, 16 | "rnn_type" : "GRU" 17 | }, 18 | "lime_explainer":{ 19 | "batch_size" : 100, 20 | "n_epochs" : 20 21 | } 22 | }, 23 | "simulation_spike": { 24 | "feature_generator_explainer": { 25 | "encoding_size" : 50, 26 | "batch_size" : 200, 27 | "n_epochs" : 250, 28 | "historical": 1, 29 | "rnn_type" : "GRU", 30 | "non_linearity":"torch.nn.TanH()" 31 | }, 32 | "risk_predictor":{ 33 | "encoding_size" : 50, 34 | "batch_size" : 100, 35 | "n_epochs" : 80, 36 | "historical": 1, 37 | "rnn_type" : "GRU" 38 | }, 39 | "lime_explainer":{ 40 | "batch_size" : 100, 41 | "n_epochs" : 20 42 | } 43 | }, 44 | "mimic": { 45 | "feature_generator_explainer": { 46 | "encoding_size" : 80, 47 | "batch_size" : 100, 48 | "n_epochs" : 200, 49 | "historical": 1, 50 | "rnn_type" : "GRU" 51 | }, 52 | "risk_predictor":{ 53 | "encoding_size" : 150, 54 | "batch_size" : 100, 55 | "n_epochs" : 80, 56 | "historical": 1, 57 | "rnn_type" : "GRU" 58 | }, 59 | "lime_explainer":{ 60 | "batch_size" : 100, 61 | 
"n_epochs" : 20 62 | } 63 | }, 64 | "ghg": { 65 | "feature_generator_explainer": { 66 | "encoding_size" : 100, 67 | "batch_size" : 100, 68 | "n_epochs" : 130, 69 | "historical": 1, 70 | "rnn_type" : "GRU" 71 | }, 72 | "risk_predictor":{ 73 | "encoding_size" : 500, 74 | "batch_size" : 1000, 75 | "n_epochs" : 200, 76 | "historical": 0, 77 | "rnn_type" : "LSTM" 78 | }, 79 | "lime_explainer":{ 80 | "batch_size" : 100, 81 | "n_epochs" : 20 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /txai/baselines/FIT/data_generator/data/clean_state_data.py: -------------------------------------------------------------------------------- 1 | import os, argparse, pickle 2 | import torch 3 | 4 | # Process: 5 | # 1. Load from numpy 6 | # 2. Convert to Torch 7 | # 3. Break off validation split 8 | # 4. Make times 9 | # 5. Make training dataset object 10 | # 6. Package into dictionary 11 | # 7. Save 12 | 13 | class StateTrainDataset(torch.utils.data.Dataset): 14 | def __init__(self, X, times, y): 15 | self.X, self.times, self.y = X, times, y 16 | 17 | def __len__(self): 18 | return self.X.shape[1] 19 | 20 | def __getitem__(self, idx): 21 | x = self.X[:,idx,:] 22 | T = self.times[:,idx] 23 | y = self.y[idx] 24 | return x, T, y 25 | 26 | def print_tuple(t): 27 | print('X', t[0].shape) 28 | print('time', t[1].shape) 29 | print('y', t[2].shape) 30 | 31 | def main(loc, split_no): 32 | # Load all data: 33 | for s in ['train', 'test']: 34 | # Importance: 35 | imp = pickle.load(open(os.path.join(loc, 'state_dataset_importance_{}.pkl'.format(s)), 'rb')) 36 | # imp already stored in matrix form, no need to convert 37 | log = pickle.load(open(os.path.join(loc, 'state_dataset_logits_{}.pkl'.format(s)), 'rb')) 38 | x = pickle.load(open(os.path.join(loc, 'state_dataset_x_{}.pkl'.format(s)), 'rb')) 39 | 40 | # Step 2: Convert to Torch 41 | xt = torch.from_numpy(x).permute(2, 0, 1) 42 | impt = torch.from_numpy(imp).permute(2, 0, 1) 43 | yt = torch.from_numpy((log > 0.5).astype(int)) # Convert logits to static 44 | yt = yt.sum(dim=-1) 45 | 46 | if s == 'train': 47 | Xtrain, ytrain = xt, yt 48 | 49 | # Step 3: break off validation set 50 | whole_inds = torch.randperm(Xtrain.shape[1]) 51 | val_inds = whole_inds[:100] 52 | 53 | Xval = Xtrain[:,val_inds,:] 54 | yval = ytrain[val_inds] 55 | # Step 4: make times 56 | timeval = torch.arange(1,Xval.shape[0]+1).unsqueeze(-1).repeat(1,Xval.shape[1]) 57 | 58 | Xtrain = Xtrain[:,whole_inds[100:],:] 59 | ytrain = ytrain[whole_inds[100:]] 60 | timetrain = torch.arange(1,Xtrain.shape[0]+1).unsqueeze(-1).repeat(1,Xtrain.shape[1]) 61 | 62 | elif s == 'test': 63 | Xtest, ytest = xt, yt 64 | timetest = torch.arange(1,Xtest.shape[0]+1).unsqueeze(-1).repeat(1,Xtest.shape[1]) 65 | gt_exps = impt # Only keep GT explanations for test split 66 | 67 | # Step 5: Make training dataset object 68 | train_dataset = StateTrainDataset(Xtrain, timetrain, ytrain) 69 | 70 | # Step 6: package into dictionary 71 | dataset = { 72 | 'train_loader': train_dataset, 73 | 'val': (Xval, timeval, yval), 74 | 'test': (Xtest, timetest, ytest), 75 | 'gt_exps': gt_exps, 76 | } 77 | 78 | print('\nTrain') 79 | print_tuple((Xtrain, timetrain, ytrain)) 80 | 81 | print('\nVal') 82 | print_tuple((Xval, timeval, yval)) 83 | 84 | print('\nTest') 85 | print_tuple((Xtest, timetest, ytest)) 86 | 87 | 88 | # Step 7: save 89 | torch.save(dataset, '/home/owq978/TimeSeriesXAI/datasets/StateTrans/split={}.pt'.format(split_no)) 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = 
argparse.ArgumentParser() 94 | parser.add_argument('--split', type = int, required = True) 95 | args = parser.parse_args() 96 | 97 | main(loc = 'simulated_data', split_no = args.split) 98 | -------------------------------------------------------------------------------- /txai/baselines/FIT/evaluation/cv_mimic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for cv in 0 1 2 4 | do 5 | python -u -m TSX.main --data mimic --model feature_generator_explainer --generator joint_RNN_generator --train --cv $cv 6 | python -u -m TSX.main --data mimic --model risk_predictor --train --cv $cv 7 | python -u -m TSX.main --data mimic --model risk_predictor --predictor attention --train --cv $cv 8 | python -u -m TSX.main --data mimic --model feature_generator_explainer --generator joint_RNN_generator --cv $cv --all_samples 9 | done 10 | -------------------------------------------------------------------------------- /txai/baselines/FIT/evaluation/cv_simulation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for cv in 0 1 2 4 | do 5 | python -u -m TSX.main --data simulation --model feature_generator_explainer --generator joint_RNN_generator --train --cv $cv 6 | python -u -m TSX.main --data simulation --model risk_predictor --train --cv $cv 7 | python -u -m TSX.main --data simulation --model risk_predictor --predictor attention --train --cv $cv 8 | python -u -m TSX.main --data simulation --model feature_generator_explainer --generator joint_RNN_generator --cv $cv --all_samples 9 | done 10 | 11 | for cv in 0 1 2 3 12 | do 13 | python -u -m TSX.main --data simulation_spike --model feature_generator_explainer --generator joint_RNN_generator --train --cv $cv 14 | python -u -m TSX.main --data simulation_spike --model risk_predictor --train --cv $cv 15 | python -u -m TSX.main --data simulation_spike --model risk_predictor --predictor attention --train --cv $cv 16 | python -u -m TSX.main --data simulation_spike --model feature_generator_explainer --generator joint_RNN_generator --cv $cv --all_samples 17 | done 18 | -------------------------------------------------------------------------------- /txai/baselines/FIT/evaluation/cv_simulation_attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for cv in 0 1 2 4 | do 5 | python -u -m TSX.main --data simulation --model feature_generator_explainer --generator joint_RNN_generator --train --cv $cv 6 | python -u -m TSX.main --data simulation --model risk_predictor --train --cv $cv 7 | python -u -m TSX.main --data simulation --model risk_predictor --predictor attention --train --cv $cv 8 | python -u -m TSX.main --data simulation --model feature_generator_explainer --generator joint_RNN_generator --predictor attention --cv $cv --all_samples 9 | done 10 | 11 | for cv in 0 1 2 3 12 | do 13 | python -u -m TSX.main --data simulation_spike --model feature_generator_explainer --generator joint_RNN_generator --train --cv $cv 14 | python -u -m TSX.main --data simulation_spike --model risk_predictor --train --cv $cv 15 | python -u -m TSX.main --data simulation_spike --model risk_predictor --predictor attention --train --cv $cv 16 | python -u -m TSX.main --data simulation_spike --model feature_generator_explainer --generator joint_RNN_generator --predictor attention --cv $cv --all_samples 17 | done 18 | --------------------------------------------------------------------------------
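The FIT config.json shown earlier groups hyperparameters first by dataset name and then by model role (feature_generator_explainer, risk_predictor, lime_explainer). A minimal loader sketch; the eval of the non_linearity string is an assumption about how it is consumed, not confirmed by this snippet. Note also that the simulation_spike section spells the activation "torch.nn.TanH()", while the PyTorch class is torch.nn.Tanh, so evaluating that entry as written would fail:

import json
import torch  # needed only if activation strings such as "torch.nn.ReLU()" are eval'd

def load_model_config(path, dataset, role):
    """Return the hyperparameter dict for one dataset / model-role pair."""
    with open(path) as f:
        cfg = json.load(f)
    params = dict(cfg[dataset][role])  # copy, so callers can mutate freely
    if 'non_linearity' in params:
        # Hypothetical handling of the stringified activation
        params['non_linearity'] = eval(params['non_linearity'])
    return params

# Example: rp = load_model_config('txai/baselines/FIT/config.json', 'simulation', 'risk_predictor')
# rp['encoding_size'], rp['rnn_type']  ->  200, 'GRU'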
/txai/baselines/FIT/evaluation/performance_scores.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle as pkl 4 | from sklearn import metrics 5 | import argparse 6 | 7 | 8 | def performance_metric(score, g_truth): 9 | n = len(score) 10 | Temp_TPR = np.zeros([n, ]) 11 | Temp_FDR = np.zeros([n, ]) 12 | score = np.clip(score, 0, 1) 13 | 14 | for i in range(n): 15 | # TPR 16 | # print(score[i, :10], g_truth[i, :10]) 17 | TPR_Nom = np.dot(score[i, :], (g_truth[i, :]))#np.sum(score[i, :] * g_truth[i, :]) 18 | TPR_Den = np.sum(g_truth[i, :]) 19 | # print(float(TPR_Nom), float(TPR_Den + 1e-18)) 20 | Temp_TPR[i] = float(TPR_Nom) / float(TPR_Den + 1e-18) 21 | 22 | # FDR 23 | FDR_Nom = np.dot(score[i, :], (1 - g_truth[i, :]))#np.sum(score[i, :] * (1 - g_truth[i, :])) 24 | FDR_Den = np.sum(score[i, :]) 25 | Temp_FDR[i] = float(FDR_Nom) / float(FDR_Den + 1e-18) 26 | 27 | return np.mean(Temp_TPR), np.mean(Temp_FDR), np.std(Temp_TPR), np.std(Temp_FDR) 28 | 29 | def main(args): 30 | if args.data == 'simulation': 31 | feature_size = 3 32 | data_path = './data/simulated_data' 33 | data_type = 'state' 34 | elif args.data == 'simulation_l2x': 35 | feature_size = 3 36 | data_path = './data/simulated_data_l2x' 37 | data_type = 'state' 38 | elif args.data == 'simulation_spike': 39 | feature_size = 3 40 | data_path = './data/simulated_spike_data' 41 | data_type = 'spike' 42 | 43 | score_path = '/scratch/gobi1/shalmali/TSX_results/new_results/%s' %(args.data) 44 | if data_type == 'state': 45 | with open(os.path.join(data_path, 'state_dataset_importance_test.pkl'), 'rb') as f: 46 | gt_importance_test = pkl.load(f) 47 | elif data_type == 'spike': 48 | with open(os.path.join(data_path, 'gt_test.pkl'), 'rb') as f: 49 | gt_importance_test = pkl.load(f) 50 | 51 | auc, aupr, fdr, tpr = [], [], [], [] 52 | for cv in [0, 1, 2, 3, 4]: 53 | with open(os.path.join(score_path, '%s_test_importance_scores_%s.pkl' %(args.explainer, str(cv))), 'rb') as f: 54 | importance_scores = pkl.load(f) 55 | 56 | gt_importance_test = gt_importance_test.astype(int) # astype returns a copy, so assign it back for the cast to take effect 57 | gt_score = gt_importance_test.flatten() 58 | explainer_score = importance_scores.flatten() 59 | n = len(gt_importance_test) 60 | if (args.explainer == 'deep_lift' or args.explainer == 'integrated_gradient' or args.explainer == 'gradient_shap'): 61 | explainer_score = np.abs(explainer_score) 62 | auc_score = metrics.roc_auc_score(gt_score, explainer_score) 63 | aupr_score = metrics.average_precision_score(gt_score, explainer_score) 64 | p_metric = performance_metric(importance_scores.reshape(n,-1), gt_importance_test.reshape(n,-1)) 65 | tpr.append(p_metric[0]) 66 | fdr.append(p_metric[1]) 67 | auc.append(auc_score) 68 | aupr.append(aupr_score) 69 | print(args.explainer, ' auc: %.3f +- %.3f'%(np.mean(auc), np.std(auc)), ' aupr: %.3f +- %.3f'%(np.mean(aupr), np.std(aupr))) 70 | print(args.explainer, ' tpr: %.3f +- %.3f' % (np.mean(tpr), np.std(tpr)), ' fdr: %.3f +- %.3f' % (np.mean(fdr), np.std(fdr))) 71 | 72 | 73 | if __name__ == '__main__': 74 | np.random.seed(1234) 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--explainer', type=str, default='fit', help='Explainer model') 77 | parser.add_argument('--data', type=str, default='simulation') 78 | parser.add_argument('--generator_type', type=str, default='history') 79 | args = parser.parse_args() 80 | main(args) -------------------------------------------------------------------------------- /txai/baselines/SGT/__init__.py:
-------------------------------------------------------------------------------- 1 | from .interpretable import train as SGT_train -------------------------------------------------------------------------------- /txai/baselines/SGT/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Net(nn.Module): 7 | def __init__(self): 8 | super(Net, self).__init__() 9 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 10 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 11 | self.dropout1 = nn.Dropout(0.25) 12 | self.dropout2 = nn.Dropout(0.5) 13 | self.fc1 = nn.Linear(9216, 128) 14 | self.fc2 = nn.Linear(128, 10) 15 | 16 | def forward(self, x): 17 | x = self.conv1(x) 18 | x = F.relu(x) 19 | x = self.conv2(x) 20 | x = F.relu(x) 21 | x = F.max_pool2d(x, 2) 22 | x = self.dropout1(x) 23 | x = torch.flatten(x, 1) 24 | x = self.fc1(x) 25 | x = F.relu(x) 26 | x = self.dropout2(x) 27 | x = self.fc2(x) 28 | output = F.log_softmax(x, dim=1) 29 | return output 30 | -------------------------------------------------------------------------------- /txai/baselines/SGT/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | 14 | 15 | term_width=10 16 | TOTAL_BAR_LENGTH = 65. 17 | last_time = time.time() 18 | begin_time = last_time 19 | def progress_bar(current, total, msg=None): 20 | global last_time, begin_time 21 | if current == 0: 22 | begin_time = time.time() # Reset for new bar. 23 | 24 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 25 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 26 | 27 | sys.stdout.write(' [') 28 | for i in range(cur_len): 29 | sys.stdout.write('=') 30 | sys.stdout.write('>') 31 | for i in range(rest_len): 32 | sys.stdout.write('.') 33 | sys.stdout.write(']') 34 | 35 | cur_time = time.time() 36 | step_time = cur_time - last_time 37 | last_time = cur_time 38 | tot_time = cur_time - begin_time 39 | 40 | L = [] 41 | L.append(' Step: %s' % format_time(step_time)) 42 | L.append(' | Tot: %s' % format_time(tot_time)) 43 | if msg: 44 | L.append(' | ' + msg) 45 | 46 | msg = ''.join(L) 47 | sys.stdout.write(msg) 48 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 49 | sys.stdout.write(' ') 50 | 51 | # Go back to the center of the bar. 
52 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 53 | sys.stdout.write('\b') 54 | sys.stdout.write(' %d/%d ' % (current+1, total)) 55 | 56 | if current < total-1: 57 | sys.stdout.write('\r') 58 | else: 59 | sys.stdout.write('\n') 60 | sys.stdout.flush() 61 | 62 | def format_time(seconds): 63 | days = int(seconds / 3600/24) 64 | seconds = seconds - days*3600*24 65 | hours = int(seconds / 3600) 66 | seconds = seconds - hours*3600 67 | minutes = int(seconds / 60) 68 | seconds = seconds - minutes*60 69 | secondsf = int(seconds) 70 | seconds = seconds - secondsf 71 | millis = int(seconds*1000) 72 | 73 | f = '' 74 | i = 1 75 | if days > 0: 76 | f += str(days) + 'D' 77 | i += 1 78 | if hours > 0 and i <= 2: 79 | f += str(hours) + 'h' 80 | i += 1 81 | if minutes > 0 and i <= 2: 82 | f += str(minutes) + 'm' 83 | i += 1 84 | if secondsf > 0 and i <= 2: 85 | f += str(secondsf) + 's' 86 | i += 1 87 | if millis > 0 and i <= 2: 88 | f += str(millis) + 'ms' 89 | i += 1 90 | if f == '': 91 | f = '0ms' 92 | return f -------------------------------------------------------------------------------- /txai/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE: commenting out because this code hasn't been commited 2 | # from .TSR_repo.run_TSR import TSR 3 | TSR = lambda *args, **kwargs: NotImplementedError -------------------------------------------------------------------------------- /txai/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.transformer_simple import TransformerMVTS -------------------------------------------------------------------------------- /txai/models/encoders/positional_enc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 7 | 8 | class PositionalEncodingTF(nn.Module): 9 | def __init__(self, d_model, max_len=500, MAX=10000,): 10 | super(PositionalEncodingTF, self).__init__() 11 | self.max_len = max_len 12 | self.d_model = d_model 13 | self.MAX = MAX 14 | self._num_timescales = d_model // 2 15 | 16 | def getPE(self, P_time): 17 | B = P_time.shape[1] # Number of batches 18 | 19 | P_time = P_time.float() 20 | 21 | # timescales = self.max_len ** torch.linspace(0, 1, self._num_timescales).to(device) this was numpy 22 | timescales = self.max_len ** torch.linspace(0, 1, self._num_timescales).to(device) 23 | 24 | #times = torch.Tensor(P_time.cpu()).unsqueeze(2) 25 | times = P_time.unsqueeze(2) 26 | 27 | scaled_time = times / torch.Tensor(timescales[None, None, :]) 28 | # Use a 32-D embedding to represent a single time point 29 | pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], axis=-1) # T x B x d_model 30 | #pe = pe.type(torch.FloatTensor) 31 | 32 | return pe 33 | 34 | def forward(self, P_time): 35 | pe = self.getPE(P_time) 36 | #pe = pe.to(device) 37 | return pe -------------------------------------------------------------------------------- /txai/models/encoders/simple.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | class CNN(nn.Module): 5 | def __init__(self, d_inp, n_classes, dim=128): 6 | super().__init__() 7 | self.encoder = nn.Sequential( 8 | nn.Conv1d(d_inp, out_channels=dim, kernel_size=7, padding=3), 9 | nn.ReLU(), 10 
| nn.MaxPool1d(kernel_size=2, stride=2), 11 | nn.Conv1d(dim, dim, kernel_size=3, padding=1), 12 | nn.ReLU(), 13 | nn.MaxPool1d(kernel_size=2, stride=2), 14 | nn.Conv1d(dim, dim, kernel_size=3, padding=1), 15 | nn.ReLU(), 16 | nn.MaxPool1d(kernel_size=2, stride=2), 17 | nn.AdaptiveAvgPool1d(1), 18 | nn.Flatten(), 19 | ) 20 | 21 | self.mlp = nn.Sequential( 22 | nn.Linear(dim, dim), 23 | nn.ReLU(), 24 | nn.Linear(dim, n_classes), 25 | ) 26 | 27 | def forward(self, x, _times, get_embedding=False, captum_input=False, show_sizes=False): 28 | if captum_input: 29 | if len(x.shape) == 2: 30 | x = x.unsqueeze(0) 31 | # batch, time, channels -> batch, channels, time 32 | x = x.permute(0, 2, 1) 33 | else: 34 | if len(x.shape) == 2: 35 | x = x.unsqueeze(1) 36 | # time, batch, channels -> batch, channels, time 37 | x = x.permute(1, 2, 0) 38 | 39 | if x.shape[-1] < 8: 40 | # pad sequence to at least 8 so two max pools don't fail 41 | # necessary for when WinIT uses a small window 42 | x = F.pad(x, (0, 8 - x.shape[-1]), mode="constant", value=0) 43 | 44 | embedding = self.encoder(x) 45 | out = self.mlp(embedding) 46 | 47 | if get_embedding: 48 | return out, embedding 49 | else: 50 | return out 51 | 52 | 53 | class LSTM(nn.Module): 54 | def __init__(self, d_inp, n_classes, dim=128): 55 | super().__init__() 56 | self.encoder = nn.LSTM( 57 | d_inp, 58 | dim // 2, # half for bidirectional 59 | num_layers=3, 60 | batch_first=True, 61 | bidirectional=True, 62 | ) 63 | 64 | self.mlp = nn.Sequential( 65 | nn.Linear(dim, dim), 66 | nn.ReLU(), 67 | nn.Linear(dim, n_classes), 68 | ) 69 | 70 | def forward(self, x, _times, get_embedding=False, captum_input=False, show_sizes=False): 71 | if not captum_input: 72 | if len(x.shape) == 2: 73 | x = x.unsqueeze(1) 74 | # time, batch, channels -> batch, time, channels 75 | x = x.permute(1, 0, 2) 76 | elif len(x.shape) == 2: 77 | x = x.unsqueeze(0) 78 | 79 | embedding, _ = self.encoder(x) 80 | embedding = embedding.mean(dim=1) # mean over time 81 | out = self.mlp(embedding) 82 | 83 | if get_embedding: 84 | return out, embedding 85 | else: 86 | return out 87 | -------------------------------------------------------------------------------- /txai/models/run_model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def concat_all_dicts(dlist, org_v = False, ours = False): 4 | # Marries together all dictionaries 5 | # Will change based on output from model 6 | 7 | mother_dict = {k:[] for k in dlist[0].keys()} 8 | 9 | is_tensor_list = [] 10 | 11 | for d in dlist: 12 | for k in d.keys(): 13 | if k == 'smooth_src' and org_v: 14 | mother_dict[k].append(torch.stack(d[k], dim = -1)) 15 | else: 16 | mother_dict[k].append(d[k]) 17 | 18 | mother_dict['pred'] = torch.cat(mother_dict['pred'], dim = 0).cpu() 19 | mother_dict['pred_mask'] = torch.cat(mother_dict['pred_mask'], dim = 0).cpu() 20 | mother_dict['mask_logits'] = torch.cat(mother_dict['mask_logits'], dim = 0).cpu() 21 | if org_v: 22 | mother_dict['concept_scores'] = torch.cat(mother_dict['concept_scores'], dim = 0).cpu() 23 | mother_dict['ste_mask'] = torch.cat(mother_dict['ste_mask'], dim = 0).cpu() 24 | # [[(), ()], ... 
24] 25 | mother_dict['smooth_src'] = torch.cat(mother_dict['smooth_src'], dim = 1).cpu() # Will be (T, B, d, ne) 26 | 27 | L = len(mother_dict['all_z']) 28 | if ours: 29 | mother_dict['all_z'] = ( 30 | torch.cat([mother_dict['all_z'][i][0] for i in range(L)], dim = 0).cpu(), 31 | torch.cat([mother_dict['all_z'][i][1] for i in range(L)], dim = 0).cpu() 32 | ) 33 | 34 | mother_dict['z_mask_list'] = torch.cat(mother_dict['z_mask_list'], dim = 0).cpu() 35 | 36 | return mother_dict 37 | 38 | def batch_forwards(model, X, times, batch_size = 64, org_v = False, ours=False): 39 | ''' 40 | Runs the model in batches for large datasets. Used to get lots of embeddings, outputs, etc. 41 | - Need to use this bc there's a specialized dictionary notation for output of the forward method (see concat_all_dicts) 42 | ''' 43 | 44 | iters = torch.arange(0, X.shape[1], step = batch_size) 45 | out_list = [] 46 | 47 | for i in range(len(iters)): 48 | if i == (len(iters) - 1): 49 | batch_X = X[:,iters[i]:,:] 50 | batch_times = times[:,iters[i]:] 51 | else: 52 | batch_X = X[:,iters[i]:iters[i+1],:] 53 | batch_times = times[:,iters[i]:iters[i+1]] 54 | 55 | with torch.no_grad(): 56 | out = model(batch_X, batch_times, captum_input = False) 57 | 58 | out_list.append(out) 59 | 60 | out_full = concat_all_dicts(out_list, org_v = org_v, ours=ours) 61 | 62 | return out_full 63 | 64 | def batch_forwards_TransformerMVTS(model, X, times, batch_size = 64): 65 | 66 | iters = torch.arange(0, X.shape[1], step = batch_size) 67 | out_list = [] 68 | z_list = [] 69 | 70 | for i in range(len(iters)): 71 | if i == (len(iters) - 1): 72 | batch_X = X[:,iters[i]:,:] 73 | batch_times = times[:,iters[i]:] 74 | else: 75 | batch_X = X[:,iters[i]:iters[i+1],:] 76 | batch_times = times[:,iters[i]:iters[i+1]] 77 | 78 | with torch.no_grad(): 79 | out, z, _ = model(batch_X, batch_times, captum_input = False, get_agg_embed = True) 80 | 81 | out_list.append(out) 82 | z_list.append(z) 83 | 84 | ztotal = torch.cat(z_list, dim = 0) 85 | outtotal = torch.cat(out_list, dim = 0) 86 | 87 | return outtotal, ztotal 88 | 89 | -------------------------------------------------------------------------------- /txai/prototypes/tune_ptypes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def tune_ptypes( 5 | model, 6 | ptype_optimizer, 7 | train_loader, 8 | num_epochs, 9 | sim_criterion, 10 | ): 11 | 12 | for epoch in range(num_epochs): 13 | 14 | sloss = [] 15 | 16 | for X, times, y, ids in train_loader: 17 | 18 | ptype_optimizer.zero_grad() 19 | 20 | out_dict = model(X, times, captum_input = True) 21 | 22 | # Just get ptype outs and compare to full outs: 23 | ptype_z = out_dict['ptypes'] 24 | full_z = out_dict['all_z'][0].detach() # Gradient stoppage 25 | 26 | sim_loss = sim_criterion(ptype_z, full_z) 27 | 28 | sim_loss.backward() 29 | ptype_optimizer.step() 30 | 31 | sloss.append(sim_loss.detach().clone().item()) 32 | 33 | print(f'Epoch: {epoch}: Loss = {np.mean(sloss):.4f}') 34 | 35 | # No validation for now -------------------------------------------------------------------------------- /txai/synth_data/__init__.py: -------------------------------------------------------------------------------- 1 | from .generate_spikes import * -------------------------------------------------------------------------------- /txai/synth_data/lowvardetect.py: -------------------------------------------------------------------------------- 1 | import timesynth as ts 2 | import matplotlib.pyplot as 
plt 3 | import numpy as np 4 | from scipy.special import expit 5 | from scipy.signal import butter, lfilter, freqz 6 | import pickle as pkl 7 | import os, math, random 8 | from sklearn.preprocessing import OneHotEncoder 9 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 10 | import random 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | from tqdm import trange, tqdm 15 | 16 | from txai.synth_data.synth_data_base import GenerateSynth, print_tuple, visualize_some, plot_vis_mv 17 | 18 | class LowVarDetect(GenerateSynth): 19 | 20 | def __init__(self, T, D): 21 | ''' 22 | D: dimension of samples (# sensors) 23 | T: length of time for each sample 24 | ''' 25 | super(LowVarDetect, self).__init__(T, D, 4) 26 | 27 | 28 | def generate_seq(self, class_num = 0): 29 | 30 | ''' 31 | class_num must be in [0,1,2,3] 32 | ''' 33 | 34 | assert class_num in [0,1,2,3], 'class_num must be in [0,1,2,3]' 35 | 36 | # Sample: 37 | samp = np.zeros((self.T, self.D)) 38 | 39 | for di in range(self.D): 40 | noise = ts.noise.GaussianNoise(std=1) 41 | x = ts.signals.NARMA(order=2,seed=random.seed()) 42 | x_ts = ts.TimeSeries(x, noise_generator=noise) 43 | x_sample, signals, errors = x_ts.sample(np.array(range(self.T))) 44 | samp[:,di] = x_sample 45 | 46 | # if class_num == 0: # Null class - make no modifications 47 | # return samp, [None] 48 | 49 | # Sample sequence length through random uniform 50 | #imp_sensors = np.random.choice(np.arange(self.D), size = (2,), replace = False) 51 | 52 | i = 0 if class_num in [0,1] else 1 53 | 54 | seqlen = np.random.randint(low = 10, high = 20, size = 1)[0] 55 | imp_time = np.random.randint(low = 20, high = self.T - 40, size = 1)[0] 56 | loc = -1.5 if class_num in [0,2] else 1.5 57 | samp[imp_time:(imp_time + seqlen),i] = np.random.normal(loc=loc, scale = 0.1, size = (seqlen,)) 58 | 59 | # Make coordinates: 60 | 61 | # Pick out coordinates: 62 | coords = list(zip(list(range(imp_time, imp_time+seqlen)), [i] * seqlen)) 63 | 64 | return samp, coords 65 | 66 | if __name__ == '__main__': 67 | 68 | gen = LowVarDetect(T = 200, D = 2) 69 | 70 | for i in range(5): 71 | train, val, test, gt_exps = gen.get_all_loaders(Ntrain=5000, Nval=100, Ntest=1000) 72 | 73 | dataset = { 74 | 'train_loader': train, 75 | 'val': val, 76 | 'test': test, 77 | 'gt_exps': gt_exps 78 | } 79 | 80 | plot_vis_mv(dataset) 81 | plt.savefig(f'lvd_example_split={i}.png') 82 | #exit() 83 | 84 | torch.save(dataset, '/datasets/LowVarDetect/split={}.pt'.format(i + 1)) 85 | 86 | print('Split {} -------------------------------'.format(i+1)) 87 | print('Val ' + '-'*20) 88 | print_tuple(val) 89 | print('\nTest' + '-'*20) 90 | print_tuple(test) 91 | 92 | print('GT EXP') 93 | print(gt_exps.shape) 94 | 95 | print('Visualizing') 96 | visualize_some(dataset, save_prefix = 'red_spike') -------------------------------------------------------------------------------- /txai/synth_data/redundant_spike.py: -------------------------------------------------------------------------------- 1 | import timesynth as ts 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from scipy.special import expit 5 | from scipy.signal import butter, lfilter, freqz 6 | import pickle as pkl 7 | import os, math 8 | from sklearn.preprocessing import OneHotEncoder 9 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 10 | import random 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | from tqdm import trange, tqdm 15 | 16 | from txai.synth_data.synth_data_base import GenerateSynth, print_tuple,
visualize_some 17 | 18 | class RedundantSpike(GenerateSynth): 19 | 20 | def __init__(self, T, D, number_spikes = 4): 21 | ''' 22 | D: dimension of samples (# sensors) 23 | T: length of time for each sample 24 | ''' 25 | super(RedundantSpike, self).__init__(T, D, 3) 26 | self.number_spikes = number_spikes 27 | 28 | 29 | def generate_seq(self, class_num = 0): 30 | 31 | ''' 32 | class_num must be in [0,1,2] 33 | ''' 34 | 35 | assert class_num in [0,1,2], 'class_num must be in [0,1,2]' 36 | 37 | # Sample: 38 | samp = np.zeros((self.T, self.D)) 39 | 40 | for di in range(self.D): 41 | noise = ts.noise.GaussianNoise(std=0.001) 42 | x = ts.signals.NARMA(order=2,seed=random.seed()) 43 | x_ts = ts.TimeSeries(x, noise_generator=noise) 44 | x_sample, signals, errors = x_ts.sample(np.array(range(self.T))) 45 | samp[:,di] = x_sample 46 | 47 | if class_num == 0: # Null class - make no modifications 48 | return samp, [None] 49 | 50 | prev_imp_sensors = [] 51 | prev_imp_time = [] 52 | num_spikes = np.random.binomial(n=self.number_spikes*2, p=0.5) 53 | num_spikes = max(num_spikes, 1) 54 | coords = [] 55 | fill_choice = np.max(np.abs(samp)) 56 | for _ in range(num_spikes): 57 | 58 | imp_sensor = np.random.choice(np.arange(self.D)) 59 | imp_time = np.random.choice(np.arange(self.T)) 60 | 61 | if class_num == 1: 62 | samp[imp_time, imp_sensor] = fill_choice * -5.0 63 | elif class_num == 2: 64 | samp[imp_time, imp_sensor] = fill_choice * 5.0 65 | 66 | coords.append((imp_time, imp_sensor)) 67 | 68 | return samp, coords 69 | 70 | if __name__ == '__main__': 71 | 72 | gen = RedundantSpike(T = 50, D = 4) 73 | 74 | for i in range(5): 75 | train, val, test, gt_exps = gen.get_all_loaders(Ntrain=5000, Nval=100, Ntest=1000) 76 | 77 | dataset = { 78 | 'train_loader': train, 79 | 'val': val, 80 | 'test': test, 81 | 'gt_exps': gt_exps 82 | } 83 | 84 | torch.save(dataset, '/datasets/RedundantSpike/split={}.pt'.format(i + 1)) 85 | 86 | print('Split {} -------------------------------'.format(i+1)) 87 | print('Val ' + '-'*20) 88 | print_tuple(val) 89 | print('\nTest' + '-'*20) 90 | print_tuple(test) 91 | 92 | print('GT EXP') 93 | print(gt_exps.shape) 94 | 95 | print('Visualizing') 96 | visualize_some(dataset, save_prefix = 'red_spike') -------------------------------------------------------------------------------- /txai/synth_data/trigtrack.py: -------------------------------------------------------------------------------- 1 | import timesynth as ts 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from scipy.special import expit 5 | from scipy.signal import butter, lfilter, freqz 6 | import pickle as pkl 7 | import os, math 8 | from sklearn.preprocessing import OneHotEncoder 9 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 10 | import random 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | from tqdm import trange, tqdm 15 | 16 | from txai.synth_data.synth_data_base import GenerateSynth, print_tuple, visualize_some 17 | 18 | class TrigTrack(GenerateSynth): 19 | 20 | class_wavelen_map = { 21 | 0: 0.25, 22 | 1: 0.5, 23 | 2: 1.0, 24 | 3: 2.0 25 | } 26 | 27 | def __init__(self, T, D, noise = None): 28 | ''' 29 | D: dimension of samples (# sensors) 30 | T: length of time for each sample 31 | ''' 32 | super(TrigTrack, self).__init__(T, D, 4) 33 | 34 | self.important_sensor = np.random.choice([0, 1, 2, 3]) 35 | self.noise = noise 36 | 37 | 38 | def generate_seq(self, class_num = 0): 39 | 40 | ''' 41 | class_num must be in [0,1,2] 42 | ''' 43 | 44 | assert class_num in [0,1,2,3], 
'class_num must be in [0,1,2]' 45 | 46 | # Sample: 47 | samp = np.zeros((self.T, self.D)) 48 | 49 | for i in range(self.D): 50 | if i == self.important_sensor: 51 | wave_len = self.class_wavelen_map[class_num] 52 | 53 | else: 54 | wave_len = np.random.choice(np.linspace(0.25, 2, num=50)) 55 | 56 | # Amplitude randomly sampled: 57 | amp = np.random.choice(np.linspace(-5.0, 5.0, num=50)) 58 | 59 | signal = amp * np.sin(wave_len * np.arange(self.T * 2)) 60 | 61 | # Multiply by noise: 62 | if self.noise is not None: 63 | signal = signal + np.random.normal(loc=0.0, scale = self.noise, size = signal.shape) 64 | 65 | # Choose random starting/ending point for signal: 66 | start = np.random.choice(np.arange(self.T - 2)) 67 | end = start + self.T 68 | 69 | samp[:,i] = signal[start:end] 70 | 71 | return samp, [(i, self.important_sensor) for i in range(self.T)] 72 | 73 | if __name__ == '__main__': 74 | 75 | gen = TrigTrack(T = 50, D = 4, noise = 0.25) 76 | print('noise', gen.noise) 77 | 78 | for i in range(5): 79 | train, val, test, gt_exps = gen.get_all_loaders(Ntrain=5000, Nval=100, Ntest=1000) 80 | 81 | dataset = { 82 | 'train_loader': train, 83 | 'val': val, 84 | 'test': test, 85 | 'gt_exps': gt_exps 86 | } 87 | 88 | torch.save(dataset, '/TimeSeriesCBM/datasets/TrigTrackNoise/split={}.pt'.format(i + 1)) 89 | 90 | print('Split {} -------------------------------'.format(i+1)) 91 | print('Val ' + '-'*20) 92 | print_tuple(val) 93 | print('\nTest' + '-'*20) 94 | print_tuple(test) 95 | 96 | print('GT EXP') 97 | print(gt_exps.shape) 98 | 99 | print('Visualizing') 100 | visualize_some(dataset, save_prefix = 'trig_fold{}'.format(i)) -------------------------------------------------------------------------------- /txai/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_transformer import train as train_simple -------------------------------------------------------------------------------- /txai/utils/baseline_comp/run_FIT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import sys, os 4 | from tqdm import tqdm 5 | from txai.baselines.FIT.TSX.explainers import FITExplainer 6 | from txai.baselines.FIT.TSX.generator import JointFeatureGenerator 7 | 8 | def run_FIT( 9 | model, 10 | X, 11 | time, 12 | FIT_obj, # FIT object with trained generator 13 | y = None, 14 | ): 15 | 16 | if y is None: 17 | model.eval() 18 | with torch.no_grad(): 19 | y = model(X, time) 20 | 21 | score = FIT_obj.attribute(x = X, y = y, times = time) 22 | 23 | return score 24 | 25 | def screen_FIT( 26 | model, 27 | test_tups, 28 | n_classes, 29 | generator = None, 30 | train_loader = None, 31 | val_loader = None, 32 | feature_size = 34, 33 | generator_epochs = 50, 34 | skip_eval = False, 35 | ): 36 | ''' 37 | 38 | ''' 39 | 40 | FIT = FITExplainer(model = model, n_classes = n_classes) 41 | 42 | if generator is None: 43 | # Train generator 44 | generator_model = JointFeatureGenerator( 45 | feature_size = feature_size, 46 | prediction_size = 1, 47 | data = 'custom', 48 | ) 49 | 50 | # Fit generator: 51 | FIT.fit_generator( 52 | generator_model, 53 | train_loader = train_loader, 54 | test_loader = val_loader, 55 | n_epochs = generator_epochs, 56 | ) 57 | 58 | else: 59 | FIT.generator = generator 60 | 61 | all_exp = [] 62 | 63 | if not skip_eval: 64 | for X, time, y in tqdm(test_tups): 65 | 66 | score = run_FIT(model, X, time, FIT, y = y) 67 | all_exp.append(score) 68 | 69 | return all_exp, FIT 70 | 
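screen_FIT above trains a JointFeatureGenerator when none is passed in, then attributes every (X, time, y) tuple. A minimal usage sketch under assumed inputs; model, train_loader, val_loader, and test_tups are placeholders for a trained black-box classifier and prepared data, not objects defined in this file:

from txai.utils.baseline_comp.run_FIT import screen_FIT

# Hypothetical setup: `model` is a trained classifier; loaders yield (X, time, y).
all_exp, fit = screen_FIT(
    model,
    test_tups=test_tups,          # list of (X, time, y) tuples
    n_classes=2,
    train_loader=train_loader,    # used only to fit the feature generator
    val_loader=val_loader,
    feature_size=34,              # number of sensors in X
    generator_epochs=50,
)
# all_exp holds one FIT importance tensor per test tuple; fit.generator can be
# reused on later calls via the `generator` argument to skip retraining.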
-------------------------------------------------------------------------------- /txai/utils/baseline_comp/run_WinIT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def run_winit(): 5 | pass 6 | 7 | 8 | def screen_winit(): 9 | pass -------------------------------------------------------------------------------- /txai/utils/baseline_comp/run_random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def run_random( 4 | model = None, 5 | X = None, 6 | time_input = None, 7 | y = None, 8 | device = None,): 9 | 10 | # Really basic - just get a random explanation 11 | return torch.randn_like(X).squeeze() if X is not None else 0 12 | 13 | 14 | def screen_random( 15 | model, 16 | test_tuples, 17 | only_correct = True, 18 | device = None): 19 | ''' 20 | Screens over an entire test set to produce explanations for the random explainer 21 | 22 | - Assumes all input tensors are on same device 23 | 24 | test_tuples: list of tuples 25 | - [(X_0, time_0, y_0), ..., (X_N, time_N, y_N)] 26 | ''' 27 | 28 | out_exp = [] 29 | 30 | model.eval() 31 | for X, time, y in test_tuples: 32 | 33 | exp = run_random(X = X) # Pass by keyword; positionally, X would land in the unused model slot and 0 would be returned 34 | 35 | out_exp.append(exp) 36 | 37 | return out_exp -------------------------------------------------------------------------------- /txai/utils/baseline_comp/screen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def screen_explainer(): 5 | 6 | pass -------------------------------------------------------------------------------- /txai/utils/cl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | def basic_negative_sampling(batch, batch_ids, dataX, num_negatives): 6 | ''' 7 | batch: (B, T, d) 8 | batch_ids: () 9 | dataX: (T, Nx, d) 10 | num_negatives: int 11 | 12 | output: (B, num_negatives) - gives ints 13 | ''' 14 | 15 | mask = torch.randn(batch.shape[0], dataX.shape[1]) # Size (B, Nx) 16 | inds = torch.empty(batch.shape[0], num_negatives).long() 17 | for i, bid in enumerate(batch_ids): 18 | mask[i, bid] = -1e9 # Effectively ignoring 19 | inds[i,:] = mask[i,:].topk(k=num_negatives)[1] # Get indices 20 | 21 | # # randn, get top-k 22 | 23 | # #possible = mask.nonzero(as_tuple=True)[0].numpy() 24 | # inds = torch.from_numpy(np.random.choice(possible, size = (num_negatives,))) 25 | 26 | # mask = torch.zeros(dataX.shape[1]).bool(); mask[inds] = 1 27 | 28 | return inds 29 | 30 | 31 | @torch.no_grad() # No grad so gradients aren't carried into similarity computations on batch 32 | def in_batch_triplet_sampling(z_main, num_triplets_per_sample = 1): 33 | ''' 34 | Samples triplets from the batch and separates into anchors, positives, and negatives based 35 | on reference embedding similarity 36 | ''' 37 | 38 | z_main_cpu = z_main.detach().clone().cpu() 39 | 40 | # Get two rows of unique indices: 41 | B, d = z_main_cpu.shape 42 | anchor_inds = torch.arange(B) 43 | 44 | pmat = (np.ones((B, B)) - np.eye(B)) / (B - 1) 45 | 46 | all_samps_tensors = [] 47 | all_anchor_inds = [] 48 | 49 | for i in range(num_triplets_per_sample): 50 | 51 | samps = [np.random.choice(B, size = (2,), replace = True, p = pmat[i,:]) for i in range(B)] 52 | 53 | samps_mat = np.stack(samps, axis = 0) 54 | samps_tensor = torch.from_numpy(samps_mat).long() 55 | all_samps_tensors.append(samps_tensor) 56 | all_anchor_inds.append(anchor_inds.clone())
57 | 58 | samps_tensor = torch.cat(all_samps_tensors, dim = 0) 59 | anchor_inds = torch.cat(all_anchor_inds).flatten() 60 | 61 | # Calculate similarities and get masks 62 | # Use euclidean distance bc this is default in triplet loss pytorch for now 63 | leftside = (z_main_cpu[anchor_inds] - z_main_cpu[samps_tensor[:,0]]).norm(p=2, dim = 1) 64 | rightside = (z_main_cpu[anchor_inds] - z_main_cpu[samps_tensor[:,1]]).norm(p=2, dim = 1) 65 | 66 | left_larger = (leftside > rightside) 67 | 68 | # Assign respective indices to each side 69 | positives = torch.where(~left_larger, samps_tensor[:,0], samps_tensor[:,1]) 70 | negatives = torch.where(left_larger, samps_tensor[:,0], samps_tensor[:,1]) # Larger distance is negative 71 | 72 | return anchor_inds, positives, negatives 73 | -------------------------------------------------------------------------------- /txai/utils/cl_metrics.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | def sim_mat(full_z, mask_z): 8 | ''' 9 | Calculates similarity matrix bw all samples in embeddings 10 | 11 | NOTE: Very inefficient, don't use if you have a lot of embeddings 12 | ''' 13 | out_norm = F.normalize(full_z, dim = -1) 14 | out_masked_norm = F.normalize(mask_z, dim = -1) 15 | 16 | mat = np.zeros((out_norm.shape[0], out_masked_norm.shape[0])) 17 | for i in trange(out_norm.shape[0]): 18 | for j in range(i, out_masked_norm.shape[0]): 19 | mat[i,j] = torch.dot(out_norm[i,:], out_masked_norm[j,:]) 20 | 21 | return mat -------------------------------------------------------------------------------- /txai/utils/constants.py: -------------------------------------------------------------------------------- 1 | model_types = ['tsimple'] 2 | exp_methods = ['fit', 'dyna', 'winit', 'tsr', 'sgt+grad'] -------------------------------------------------------------------------------- /txai/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess import zip_x_time_y, process_PAM 2 | from .preprocess import process_Epilepsy, process_ECG 3 | from .preprocess import EpiDataset, decomposition_statistics 4 | from .synth import process_Synth -------------------------------------------------------------------------------- /txai/utils/data/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DatasetwInds(torch.utils.data.Dataset): 4 | def __init__(self, X, times, y): 5 | self.X = X 6 | self.times = times 7 | self.y = y 8 | 9 | def __len__(self): 10 | return self.X.shape[1] 11 | 12 | def __getitem__(self, idx): 13 | x = self.X[:,idx,:] 14 | T = self.times[:,idx] 15 | y = self.y[idx] 16 | return x, T, y, torch.tensor(idx).long().to(x.device) -------------------------------------------------------------------------------- /txai/utils/data/synth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from txai.synth_data.generate_spikes import SpikeTrainDataset 4 | from txai.baselines.FIT.data_generator.data.clean_state_data import StateTrainDataset 5 | 6 | spike_path = '/home/owq978/TimeSeriesXAI/datasets/Spike/' 7 | def process_Synth(split_no = 1, device = None, base_path = spike_path, regression = False, 8 | label_noise = None): 9 | 10 | split_path = os.path.join(base_path, 'split={}.pt'.format(split_no)) 11 | print("split_path:", 
split_path) 12 | 13 | D = torch.load(split_path) 14 | 15 | D['train_loader'].X = D['train_loader'].X.float().to(device) 16 | D['train_loader'].times = D['train_loader'].times.float().to(device) 17 | if regression: 18 | D['train_loader'].y = D['train_loader'].y.float().to(device) 19 | else: 20 | D['train_loader'].y = D['train_loader'].y.long().to(device) 21 | 22 | val = [] 23 | val.append(D['val'][0].float().to(device)) 24 | val.append(D['val'][1].float().to(device)) 25 | val.append(D['val'][2].long().to(device)) 26 | if regression: 27 | val[-1] = val[-1].float() 28 | D['val'] = tuple(val) 29 | 30 | test = [] 31 | test.append(D['test'][0].float().to(device)) 32 | test.append(D['test'][1].float().to(device)) 33 | test.append(D['test'][2].long().to(device)) 34 | if regression: 35 | test[-1] = test[-1].float() 36 | D['test'] = tuple(test) 37 | 38 | if label_noise is not None: 39 | # Find some samples in training to switch labels: 40 | 41 | to_flip = int(label_noise * D['train_loader'].y.shape[0]) 42 | to_flip = to_flip + 1 if (to_flip % 2 == 1) else to_flip # Add one if it isn't even 43 | 44 | flips = torch.randperm(D['train_loader'].y.shape[0])[:to_flip] 45 | 46 | max_label = D['train_loader'].y.max() 47 | 48 | for i in flips: 49 | D['train_loader'].y[i] = (D['train_loader'].y[i] + 1) % max_label 50 | 51 | return D -------------------------------------------------------------------------------- /txai/utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def gaussian_time_samples(X, perturbation_freq = 0.2): 4 | ''' 5 | Samples out time points according to some frequency and perturbs every sensor 6 | at those chosen time points (with Gaussian noise). 7 | ''' 8 | 9 | squeeze_back = False 10 | if len(X.shape) > 2: 11 | X = X.squeeze() # Should squeeze out batch dimension 12 | squeeze_back = True 13 | 14 | # Choose N time points 15 | 16 | time_samp = torch.rand((X.shape[0],)) 17 | 18 | bool_samp = (time_samp < perturbation_freq) 19 | 20 | # At selected times, add Gaussian noise equivalent to feature-wide statistics: 21 | 22 | mu = torch.mean(X, dim=-1) 23 | std = torch.std(X, dim=-1) 24 | 25 | # Sample and mask out those time points that we don't perturb 26 | noise = [torch.normal(mu, std) for i in range(X.shape[0]) if bool_samp[i]] 27 | noise = torch.stack(noise) # Stack together normal noises 28 | 29 | Xpert = torch.where(bool_samp, noise, X) 30 | # Place noise where bool_samp is true, X where it's not 31 | 32 | if squeeze_back: 33 | Xpert = Xpert.unsqueeze(dim=1) 34 | 35 | return Xpert 36 | 37 | def random_time_mask(rate, size): 38 | ''' 39 | Assumes the size is (T,d) 40 | ''' 41 | 42 | size = tuple(size) 43 | mask = torch.zeros(size[0]) 44 | n = int(size[0] * rate) 45 | inds = torch.randperm(size[0])[:n] 46 | mask[inds] = 1 47 | mask = mask.unsqueeze(-1).repeat(1,size[1]) # Repeat along time dimensions 48 | 49 | return mask 50 | 51 | def dyna_norm_mask(Xtrain): 52 | # Returns a function that, when called, gives a dynamic normal mask application 53 | 54 | # Compute mean, std: 55 | std = Xtrain.std(unbiased = True, dim = 0) 56 | mu = Xtrain.mean(dim=0) 57 | 58 | def apply_mask(X, mask): 59 | to_replace = (mu + torch.randn_like(std) * std).unsqueeze(0).repeat(X.shape[0], 1, 1) 60 | return (mask * X) + (1 - mask) * to_replace 61 | 62 | return apply_mask 63 | 64 | -------------------------------------------------------------------------------- /txai/utils/predictors/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .loss import Poly1CrossEntropyLoss, SATLoss, SATGiniLoss, GiniLoss, GSATLoss 2 | from .eval import eval_on_tuple, eval_and_select, eval_mvts_transformer 3 | from .select_models import lower_bound_performance -------------------------------------------------------------------------------- /txai/utils/predictors/select_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from txai.utils.predictors.loss_cl import LabelConsistencyLoss 5 | 6 | # Old selection criteria, here for reference: ----------------- 7 | def lower_bound_performance(lower_bound): 8 | def func(metric, sparsity): 9 | if metric >= lower_bound: 10 | return (1 - sparsity) 11 | return 0 12 | 13 | return func 14 | 15 | def best_metric(): 16 | def func(metric, sparsity): 17 | return metric 18 | return func 19 | # ------------------------------------------------------------- 20 | 21 | def cosine_sim(out_dict, val = None): 22 | full_z, mask_z = out_dict['all_z'] 23 | sim = F.cosine_similarity(full_z, mask_z, dim = -1) 24 | return sim.mean().detach().cpu().item() 25 | 26 | def small_mask(out_dict, val = None): 27 | mask = out_dict['ste_mask'] 28 | return -1.0 * mask.float().detach().cpu().mean().item() 29 | 30 | 31 | def sim_small_mask(out_dict, val = None): 32 | return cosine_sim(out_dict) + small_mask(out_dict) 33 | 34 | def simloss_on_val_wboth(sim_criterion, lam = 1.0): 35 | # Early stopping for sim loss 36 | 37 | def f(out_dict, val = None): 38 | org_z, con_z = out_dict['all_z'] 39 | mlab, flab = out_dict['pred_mask'], out_dict['pred'] 40 | L = sim_criterion[0](org_z, con_z) + lam * sim_criterion[1](mlab, flab) 41 | return -1.0 * L # Need maximum, so return negative 42 | 43 | return f 44 | 45 | def simloss_on_val_laonly(sim_criterion): 46 | # Early stopping for sim loss - Label Alignment only 47 | def f(out_dict, val = None): 48 | mlab, flab = out_dict['pred_mask'], out_dict['pred'] 49 | L = sim_criterion(mlab, flab) 50 | return -1.0 * L # Need maximum, so return negative 51 | return f 52 | 53 | def simloss_on_val_cononly(sim_criterion): 54 | # Early stopping for sim loss - MBC only 55 | def f(out_dict, val = None): 56 | org_z, con_z = out_dict['all_z'] 57 | L = sim_criterion(org_z, con_z) 58 | return -1.0 * L # Need maximum, so return negative 59 | return f 60 | 61 | def cosine_sim_for_simclr(org_z, con_z): 62 | sim = -1.0 * F.cosine_similarity(org_z, con_z, dim = -1).mean() 63 | return sim -------------------------------------------------------------------------------- /txai/utils/shapebank/v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def gen_sample(template, increase = True): 5 | 6 | length = np.random.choice(np.arange(start=5, stop=45)) 7 | if increase == True: 8 | seq = np.linspace(-2, 2, num = int(length)) 9 | else: 10 | seq = np.linspace(2, -2, num = int(length)) 11 | 12 | seq *= np.random.normal(1.0, scale = 0.01, size = seq.shape) 13 | 14 | # Get mask w/sampled location: 15 | loc = np.random.choice(np.arange(start=0, stop=int(template.shape[0]-length))) 16 | 17 | a = torch.randn_like(template) 18 | a[loc:(loc+length),0,0] = torch.from_numpy(seq) 19 | 20 | return a 21 | 22 | def gen_sample_zero(template, increase = True): 23 | 24 | length = np.random.choice(np.arange(start=5, stop=45)) 25 | amp = np.random.normal(1.0, scale = 0.25) 26 | if 
increase == True: 27 | seq = np.linspace(-2, 2, num = int(length)) 28 | else: 29 | seq = np.linspace(2, -2, num = int(length)) 30 | 31 | seq *= np.random.normal(1.0, scale = 0.05, size = seq.shape) 32 | 33 | # Get mask w/sampled location: 34 | loc = np.random.choice(np.arange(start=0, stop=int(template.shape[0]-length))) 35 | 36 | a = torch.zeros_like(template) 37 | a[loc:(loc+length),0,0] = torch.from_numpy(seq) 38 | 39 | return a 40 | 41 | def gen_dataset(template, samps = 1000, device = None): 42 | inc = torch.cat([gen_sample(template, increase = True) for _ in range(samps)], dim = 1).to(device) 43 | dec = torch.cat([gen_sample(template, increase = False) for _ in range(samps)], dim = 1).to(device) 44 | 45 | times = torch.arange(inc.shape[0]).unsqueeze(-1).repeat(1, samps * 2).to(device) 46 | whole = torch.cat([inc, dec], dim=1).to(device) 47 | batch_id = torch.cat([torch.zeros(inc.shape[1]), torch.ones(dec.shape[1])]).to(device).long() 48 | return whole, times, batch_id 49 | 50 | def gen_dataset_zero(template, samps = 1000, device = None): 51 | inc = torch.cat([gen_sample_zero(template, increase = True) for _ in range(samps)], dim = 1).to(device) 52 | dec = torch.cat([gen_sample_zero(template, increase = False) for _ in range(samps)], dim = 1).to(device) 53 | 54 | times = torch.arange(inc.shape[0]).unsqueeze(-1).repeat(1, samps * 2).to(device) 55 | whole = torch.cat([inc, dec], dim=1) 56 | batch_id = torch.cat([torch.zeros(inc.shape[1]), torch.ones(dec.shape[1])]).to(device).long() 57 | return whole, times, batch_id -------------------------------------------------------------------------------- /txai/vis/vis_saliency.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | def vis_one_saliency(X, exp, ax, fig, col_num): 6 | 7 | Xnp = X.detach().clone().cpu().numpy() 8 | enp = exp.detach().clone().cpu().numpy() 9 | T, d = Xnp.shape 10 | 11 | x_range = np.arange(T) 12 | 13 | for i in range(d): 14 | # Assumes heatmap: 15 | px, py = np.meshgrid(np.linspace(min(x_range), max(x_range), len(x_range) + 1), [min(Xnp[:,i]), max(Xnp[:,i])]) 16 | ax[i,col_num].plot(x_range, Xnp[:,i], color = 'black') 17 | cmap = ax[i,col_num].pcolormesh(px, py, np.expand_dims(enp[:,i], 0), alpha = 0.5, cmap = 'Greens') 18 | fig.colorbar(cmap, ax = ax[i][col_num]) 19 | 20 | def vis_one_saliency_univariate(X, exp, ax, fig): 21 | 22 | Xnp = X.detach().clone().cpu().numpy() 23 | enp = exp.detach().clone().cpu().numpy() 24 | T, d = Xnp.shape 25 | 26 | assert d == 1, 'vis_one_saliency_univariate is only for univariate inputs' 27 | 28 | x_range = np.arange(T) 29 | 30 | print('enp', enp.shape) 31 | 32 | # Assumes heatmap: 33 | px, py = np.meshgrid(np.linspace(min(x_range), max(x_range), len(x_range) + 1), [min(Xnp[:,0]), max(Xnp[:,0])]) 34 | ax.plot(x_range, Xnp[:,0], color = 'black') 35 | cmap = ax.pcolormesh(px, py, enp, alpha = 0.5, cmap = 'Greens') 36 | fig.colorbar(cmap, ax = ax) 37 | --------------------------------------------------------------------------------
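vis_one_saliency above overlays a per-channel saliency heatmap (via pcolormesh) on the raw series, one row of axes per channel and one column per explainer. A self-contained usage sketch with random data, just to illustrate the expected (T, d) shapes and the axes grid it indexes into:

import torch
import matplotlib.pyplot as plt
from txai.vis.vis_saliency import vis_one_saliency

T, d = 50, 4                      # time steps, channels
X = torch.randn(T, d)             # one multivariate sample
exp = torch.rand(T, d)            # saliency scores for that sample

# d rows (one per channel), two columns (e.g., two explainers side by side)
fig, ax = plt.subplots(d, 2, sharex=True)
vis_one_saliency(X, exp, ax, fig, col_num=0)
vis_one_saliency(X, torch.rand(T, d), ax, fig, col_num=1)
plt.savefig('saliency_example.png')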