├── modules ├── __init__.py └── Hawkes │ ├── __init__.py │ ├── tools │ ├── __init__.py │ ├── BasisFunction.py │ └── Quasi_Newton.py │ ├── Hawkes_C.pyx │ └── model.py ├── transformer_helpers ├── Constants.py ├── Modules.py ├── Layers.py └── SubLayers.py ├── .gitignore ├── DualTPP_diagram.png ├── saved_models ├── .DS_Store ├── training_wgan_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_count_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_count_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_seq2seq_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_wgan_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_count_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_rmtpp_nll_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_seq2seq_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_seq2seq_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_transformer_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_transformer_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_wgan_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ 
└── cp_Trump.ckpt.data-00000-of-00001 ├── training_count_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_seq2seq_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_transformer_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_wgan_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_rmtpp_nll_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_transformer_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_count_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── 
cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_wgan_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_nll_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_seq2seq_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_transformer_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 └── training_rmtpp_mse_var_comp_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── saved_models_15thaug_epochs10 ├── .DS_Store ├── training_count_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_seq2seq_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_wgan_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_wgan_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_count_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_count_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── 
cp_taxi.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_rmtpp_nll_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_seq2seq_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_transformer_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_wgan_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_count_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_seq2seq_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_transformer_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_transformer_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_wgan_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ 
└── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_sin │ ├── checkpoint │ ├── cp_sin.ckpt.index │ └── cp_sin.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_seq2seq_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_transformer_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_count_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_comp_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_seq2seq_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_wgan_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_var_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_transformer_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_rmtpp_mse_comp_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── 
cp_911_traffic.ckpt.data-00000-of-00001 └── training_rmtpp_mse_var_comp_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── requirements.txt ├── saved_models_16thaug_baselines ├── training_seq2seq_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_wgan_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_wgan_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_seq2seq_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_transformer_taxi │ ├── checkpoint │ ├── cp_taxi.ckpt.index │ └── cp_taxi.ckpt.data-00000-of-00001 ├── training_wgan_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_seq2seq_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_transformer_Trump │ ├── checkpoint │ ├── cp_Trump.ckpt.index │ └── cp_Trump.ckpt.data-00000-of-00001 ├── training_transformer_911_ems │ ├── checkpoint │ ├── cp_911_ems.ckpt.index │ └── cp_911_ems.ckpt.data-00000-of-00001 ├── training_seq2seq_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── training_wgan_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 └── training_transformer_911_traffic │ ├── checkpoint │ ├── cp_911_traffic.ckpt.index │ └── cp_911_traffic.ckpt.data-00000-of-00001 ├── script.sh ├── .idea └── vcs.xml ├── README.md ├── transformer_utils.py ├── generator.py └── main.py /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformer_helpers/Constants.py: 
-------------------------------------------------------------------------------- 1 | PAD = 0. 2 | 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | pp_seq2seq/nmt/*.pyc 3 | pp_seq2seq/nmt/data/* 4 | -------------------------------------------------------------------------------- /modules/Hawkes/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import simulator 2 | from .model import estimator 3 | -------------------------------------------------------------------------------- /DualTPP_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/DualTPP_diagram.png -------------------------------------------------------------------------------- /saved_models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/.DS_Store -------------------------------------------------------------------------------- /saved_models/training_wgan_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | 
all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_seq2seq_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_wgan_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_sin/checkpoint: 
-------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_seq2seq_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_seq2seq_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_transformer_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_transformer_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_wgan_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | 
-------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/.DS_Store -------------------------------------------------------------------------------- /saved_models/training_count_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_taxi/checkpoint: 
-------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_seq2seq_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_transformer_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_wgan_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 
| all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_transformer_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | 
-------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.2.4 2 | tensorflow-addons==0.9.1 3 | tensorflow==2.1.0 4 | torch==1.7.1+cu101 5 | properscoring 6 | cvxpy 7 | pandas==0.25.3 8 | -------------------------------------------------------------------------------- /saved_models/training_count_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_wgan_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- 
/saved_models/training_wgan_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_taxi/checkpoint: 
-------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_count_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_seq2seq_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_seq2seq_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_wgan_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: 
"cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: 
"cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | 
-------------------------------------------------------------------------------- /saved_models/training_seq2seq_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_transformer_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_wgan_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | 
-------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_sin.ckpt" 2 | all_model_checkpoint_paths: "cp_sin.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_taxi.ckpt" 2 | all_model_checkpoint_paths: "cp_taxi.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.index 
-------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_seq2seq_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_transformer_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_wgan_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: 
"cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_Trump.ckpt" 2 | all_model_checkpoint_paths: "cp_Trump.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_transformer_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_ems.ckpt" 2 | all_model_checkpoint_paths: "cp_911_ems.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_911_traffic/checkpoint: 
-------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- 
/saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_transformer_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | 
-------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_transformer_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "cp_911_traffic.ckpt" 2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt" 3 | -------------------------------------------------------------------------------- /saved_models/training_count_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- 
/saved_models/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /script.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py 911_traffic all --output_dir Outputs_ignore_post --saved_models saved_models_15thaug_epochs10 --out_bin_sz 3 --no_rescale_rmtpp_params 4 | -------------------------------------------------------------------------------- /modules/Hawkes/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .Quasi_Newton import Quasi_Newton,merge_stg 2 | from .BasisFunction import linear_COS, loglinear_COS, linear_CBS, loglinear_CBS, plinear, linear_SSM 3 | -------------------------------------------------------------------------------- /saved_models/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001 
-------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.index: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.index 
-------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- 
/saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.index: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- 
/saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- 
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001 
-------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- 
/saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- 
/saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
class ScaledDotProductAttention(tf.keras.Model):
    """Scaled dot-product attention implemented as a Keras model.

    Attention logits are ``(q / temperature) @ k^T`` taken over the last two
    axes of rank-4 inputs (the transpose below uses ``perm=[0, 1, 3, 2]``).
    The logits are soft-maxed along the last axis, passed through dropout and
    used to weight ``v``.

    Args:
        temperature: Divisor applied to ``q`` before the dot product.
        attn_dropout: Rate given to the dropout layer applied to the
            attention weights.
        name: Keras model name.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
            self, temperature, attn_dropout=0.2,
            name='ScaledDotProductAttention', **kwargs):
        super().__init__(name=name, **kwargs)

        self.temperature = temperature
        self.dropout = tf.keras.layers.Dropout(attn_dropout)

    def call(self, q, k, v, mask=None):
        """Return ``(output, attn)`` for queries ``q``, keys ``k``, values ``v``.

        ``mask`` is optional; when given it is used as the condition of
        ``tf.where`` over the attention logits.
        """
        # Swap the two trailing axes of the keys so the matmul contracts
        # over the feature dimension.
        keys_t = tf.transpose(k, perm=[0, 1, 3, 2])
        scores = tf.matmul(q / self.temperature, keys_t)

        if mask is not None:
            # NOTE(review): positions where `mask` is True KEEP their score and
            # the others are pushed to -1e9. The torch code this was ported
            # from used `attn.masked_fill(mask, -1e9)`, which has the opposite
            # polarity -- confirm against the code that builds the mask.
            scores = tf.where(mask, scores, -1e9)

        attn = self.dropout(tf.nn.softmax(scores, axis=-1))
        return tf.matmul(attn, v), attn
class EncoderLayer(tf.keras.Model):
    """Encoder layer: multi-head self-attention followed by a position-wise FFN.

    Args:
        d_model: Model width passed to both sub-layers.
        d_inner: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        d_k: Per-head key/query size.
        d_v: Per-head value size.
        dropout: Dropout rate forwarded to both sub-layers.
        normalize_before: Forwarded to both sub-layers (pre- vs post-norm).
        name: Keras model name.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
            self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True,
            name='EncoderLayer', **kwargs):
        super(EncoderLayer, self).__init__(name=name, **kwargs)
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, normalize_before=normalize_before)

    def call(self, enc_input, feats, non_pad_mask=None, slf_attn_mask=None):
        """Run self-attention and the feed-forward block on ``enc_input``.

        Args:
            enc_input: Used as query, key and value of the self-attention.
            feats: Accepted for interface compatibility; currently unused
                because the feature concatenation is disabled.
            non_pad_mask: Optional padding mask. It is converted to float32
                with a trailing axis, but the masking multiplications are
                currently disabled, so it does not affect the output.
            slf_attn_mask: Optional mask forwarded to the self-attention.

        Returns:
            ``(enc_output, enc_slf_attn)``.
        """
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)

        # The default `non_pad_mask=None` used to be passed straight to
        # tf.expand_dims, which cannot convert None to a tensor. Only convert
        # when a mask was actually supplied.
        if non_pad_mask is not None:
            non_pad_mask = tf.cast(tf.expand_dims(non_pad_mask, axis=-1), tf.float32)

        # Disabled in the original port (kept for reference):
        #enc_output = tf.concat([enc_output, feats], axis=-1)
        #enc_output *= non_pad_mask

        enc_output = self.pos_ffn(enc_output)
        #enc_output *= non_pad_mask

        return enc_output, enc_slf_attn
12 | 13 | ## Dataset Download 14 | We have provided all the datasets used in our experiments [here](https://drive.google.com/drive/folders/1b1KUwkeIqIViPZoRZzbPAzKeNn7P1OD-?usp=sharing). 15 | 16 | Please download the `data/` folder and place it in the [DualTPP](https://github.com/pratham16cse/DualTPP) directory. 17 | 18 | ## Experiment execution 19 | To run the code to reproduce the results, please use this [script](script.sh) \[ Under development, more datasets will soon be added to the script\]. 20 | 21 | ## Output 22 | All the outputs will be stored in the `output_dir/` directory. 23 | 24 | The numbers reported in Table 2 of the [paper](https://arxiv.org/abs/2101.02815) will be stored in `output_dir/results_.json` and `output_dir/results_.txt` files. 25 | 26 | ## Parameters Description 27 | Under Development 28 | 29 | ## Contact 30 | For any queries related to library versions, datasets, script, and results please contact us here: 31 | 32 | Email: prathameshsdeshpande@gmail.com 33 | 34 | WhatsApp: +91 9043751980 35 | -------------------------------------------------------------------------------- /transformer_helpers/SubLayers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import transformer_helpers.Constants as Constants 7 | from transformer_helpers.Modules import ScaledDotProductAttention 8 | 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | from tensorflow.keras import layers 12 | import tensorflow_probability as tfp 13 | import tensorflow_addons as tfa 14 | 15 | 16 | class MultiHeadAttention(tf.keras.Model): 17 | """ Multi-Head Attention module """ 18 | 19 | def __init__( 20 | self, n_head, d_model, d_k, d_v, dropout=0.1, normalize_before=True, 21 | name='MultiHeadAttention', **kwargs): 22 | super(MultiHeadAttention, self).__init__(name=name, **kwargs) 23 | 24 | self.normalize_before = normalize_before 25 |
self.n_head = n_head 26 | self.d_k = d_k 27 | self.d_v = d_v 28 | 29 | self.w_qs = layers.Dense(n_head * d_k, use_bias=False) 30 | self.w_ks = layers.Dense(n_head * d_k, use_bias=False) 31 | self.w_vs = layers.Dense(n_head * d_v, use_bias=False) 32 | # Note: layers.Dense uses xavier_uniform_ initialization by default 33 | #nn.init.xavier_uniform_(self.w_qs.weight) 34 | #nn.init.xavier_uniform_(self.w_ks.weight) 35 | #nn.init.xavier_uniform_(self.w_vs.weight) 36 | 37 | self.fc = layers.Dense(d_model) 38 | #nn.init.xavier_uniform_(self.fc.weight) 39 | 40 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 41 | 42 | #self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 43 | self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6) 44 | #self.dropout = nn.Dropout(dropout) 45 | self.dropout = tf.keras.layers.Dropout(dropout) 46 | 47 | def call(self, q, k, v, mask=None): 48 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 49 | sz_b, len_q, len_k, len_v = q.shape[0], q.shape[1], k.shape[1], v.shape[1] 50 | 51 | residual = q 52 | if self.normalize_before: 53 | q = self.layer_norm(q) 54 | 55 | # Pass through the pre-attention projection: b x lq x (n*dv) 56 | # Separate different heads: b x lq x n x dv 57 | q = tf.reshape(self.w_qs(q), [sz_b, len_q, n_head, d_k]) 58 | k = tf.reshape(self.w_ks(k), [sz_b, len_k, n_head, d_k]) 59 | v = tf.reshape(self.w_vs(v), [sz_b, len_v, n_head, d_v]) 60 | 61 | # Transpose for attention dot product: b x n x lq x dv 62 | #q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 63 | q = tf.transpose(q, perm=[0, 2, 1, 3]) 64 | k = tf.transpose(k, perm=[0, 2, 1, 3]) 65 | v = tf.transpose(v, perm=[0, 2, 1, 3]) 66 | 67 | if mask is not None: 68 | mask = tf.expand_dims(mask, axis=1) # For head axis broadcasting. 
class PositionwiseFeedForward(tf.keras.Model):
    """Two-layer feed-forward block with dropout and a residual connection.

    The input goes through ``Dense(d_hid)`` + GELU + dropout, then
    ``Dense(d_in)`` + dropout, and is added back to the original input.
    Layer normalisation is applied to the input when ``normalize_before`` is
    true, otherwise to the residual sum.

    Args:
        d_in: Output width of the second projection.
        d_hid: Width of the hidden projection.
        dropout: Rate of the dropout layer (the same layer is applied twice).
        normalize_before: Selects pre-norm (True) or post-norm (False).
        name: Keras model name.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
            self, d_in, d_hid, dropout=0.1, normalize_before=True,
            name='PositionwiseFeedForward', **kwargs):
        super().__init__(name=name, **kwargs)

        self.normalize_before = normalize_before

        self.w_1 = layers.Dense(d_hid)  # position-wise
        self.w_2 = layers.Dense(d_in)  # position-wise

        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x):
        """Apply the feed-forward block to ``x`` and return the result."""
        shortcut = x

        # Pre-norm variant normalises before the projections.
        hidden = self.layer_norm(x) if self.normalize_before else x

        hidden = self.dropout(tfa.activations.gelu(self.w_1(hidden)))
        hidden = self.dropout(self.w_2(hidden))
        summed = hidden + shortcut

        if self.normalize_before:
            return summed
        # Post-norm variant normalises the residual sum instead.
        return self.layer_norm(summed)
def LG_kernel_SUM_exp_i_cython(np.ndarray[np.float64_t,ndim=1] T, double alpha, double beta):
    """Recursively evaluate one exponential kernel term at the points of T.

    For i >= 0 the recursion below yields
        l[i+1] = sum_{j<=i} alpha*beta*exp(-beta*(T[i+1]-T[j]))
    together with its partial derivatives with respect to alpha (dl_a) and
    beta (dl_b). Index 0 of every output stays 0.

    Parameters
    ----------
    T : 1-D float64 array of time points (the recursion uses consecutive
        differences T[i+1]-T[i]).
    alpha, beta : kernel parameters.

    Returns
    -------
    list [l, dl_a, dl_b], each a float64 array with the same length as T.
    """

    cdef int n = T.shape[0]

    # Outputs; entry 0 is never written and remains zero.
    cdef np.ndarray[np.float64_t,ndim=1] l = np.zeros(n, dtype=np.float64)
    cdef np.ndarray[np.float64_t,ndim=1] dl_a = np.zeros(n, dtype=np.float64)
    cdef np.ndarray[np.float64_t,ndim=1] dl_b = np.zeros(n, dtype=np.float64)

    # Consecutive gaps and the per-gap decay factor exp(-beta*dt).
    cdef np.ndarray[np.float64_t,ndim=1] dt = T[1:] - T[:-1]
    cdef np.ndarray[np.float64_t,ndim=1] r = np.exp(-beta*dt)

    # Running value of the sum and of its alpha-/beta-derivatives.
    cdef double x = 0.0
    cdef double x_a = 0.0
    cdef double x_b = 0.0

    cdef int i;

    for i in range(n-1):
        # Add the contribution of point i, then decay across the gap to i+1.
        x = ( x + alpha*beta ) * r[i]
        x_a = ( x_a + beta ) * r[i]
        # Uses the already-updated x: the -x*dt[i] term is the derivative of
        # the decay factor with respect to beta. Do not reorder.
        x_b = ( x_b + alpha ) * r[i] - x*dt[i]

        l[i+1] = x
        dl_a[i+1] = x_a
        dl_b[i+1] = x_b

    return [l,dl_a,dl_b]
= 16 61 | cdef double delta = 1.0/num_div 62 | cdef np.ndarray[np.float64_t,ndim=1] s = np.linspace(-9,9,num_div*18+1) 63 | cdef np.ndarray[np.float64_t,ndim=1] log_phi = s-np.exp(-s) 64 | cdef np.ndarray[np.float64_t,ndim=1] log_dphi = log_phi + np.log(1+np.exp(-s)) 65 | cdef np.ndarray[np.float64_t,ndim=1] phi = np.exp(log_phi) # phi = np.exp(s-np.exp(-s)) 66 | cdef np.ndarray[np.float64_t,ndim=1] dphi = np.exp(log_dphi) # dphi = phi*(1+np.exp(-s)) 67 | 68 | cdef np.ndarray[np.float64_t,ndim=1] H = delta * k * np.exp( log_dphi + (p-1)*log_phi - c*phi ) / gamma(p) 69 | cdef np.ndarray[np.float64_t,ndim=1] H_p = delta * k * np.exp( log_dphi + (p-1)*log_phi - c*phi ) / gamma(p) * (log_phi-digamma(p)) 70 | cdef np.ndarray[np.float64_t,ndim=1] H_c = delta * k * np.exp( log_dphi + p*log_phi - c*phi ) / gamma(p) * (-1) 71 | 72 | cdef np.ndarray[np.float64_t,ndim=1] g = np.zeros_like(s) 73 | 74 | cdef int i 75 | 76 | for i in range(n-1): 77 | g = (g+1)*np.exp( - phi*(T[i+1]-T[i]) ) 78 | l[i+1] = g.dot(H) 79 | dl_k[i+1] = l[i+1]/k 80 | dl_p[i+1] = g.dot(H_p) 81 | dl_c[i+1] = g.dot(H_c) 82 | 83 | return [l,dl_k,dl_p,dl_c] 84 | 85 | def preprocess_data_nonpara_cython(np.ndarray[np.float64_t,ndim=1] T, np.ndarray[np.float64_t,ndim=1] bin_edge, double en): 86 | 87 | cdef double support = bin_edge[-1] 88 | cdef double bin_width = bin_edge[1] - bin_edge[0] 89 | cdef int i,j 90 | cdef int n = T.shape[0] 91 | cdef int m = bin_edge.shape[0] - 1 # the number of bins 92 | 93 | ###### dl 94 | cdef list index_tgt_list = [] 95 | cdef list index_trg_list = [] 96 | 97 | for i in range(n): 98 | for j in range(i-1,-1,-1): 99 | if T[i] - T[j] < support: 100 | index_tgt_list.append(i) 101 | index_trg_list.append(j) 102 | else: 103 | break 104 | 105 | cdef np.ndarray[np.int64_t,ndim=1] index_tgt = np.array(index_tgt_list) 106 | cdef np.ndarray[np.int64_t,ndim=1] index_trg = np.array(index_trg_list) 107 | cdef np.ndarray[np.int64_t,ndim=1] index_bin = 
def compute_event(event, non_pad_mask):
    """Log-likelihood of events.

    Args:
        event: Tensor of per-position event intensities.
        non_pad_mask: Tensor broadcastable against ``event``; positions equal
            to ``Constants.PAD`` are treated as padding.

    Returns:
        ``log(event + 1e-9)`` at real positions and ``log(1) = 0`` at padded
        positions, so padding does not contribute to a later sum.
    """
    # Boolean selector for padded positions.
    is_pad = (non_pad_mask == Constants.PAD)

    # add 1e-9 in case some events have 0 likelihood
    event += math.pow(10, -9)

    # Port of torch `event.masked_fill_(~non_pad_mask.bool(), 1.0)`.
    # The previous condition compared `event` with the float mask by value
    # (`event == non_pad_mask_not`), which does not select padded positions;
    # select on the padding mask itself.
    event = tf.where(is_pad, 1., event)

    result = tf.math.log(event)
    return result
""" 29 | 30 | diff_time = (time[:, 1:] - time[:, :-1]) * non_pad_mask[:, 1:] 31 | diff_lambda = (all_lambda[:, 1:] + all_lambda[:, :-1]) * non_pad_mask[:, 1:] 32 | 33 | biased_integral = diff_lambda * diff_time 34 | result = 0.5 * biased_integral 35 | return result 36 | 37 | 38 | def compute_integral_unbiased(model, data, time, non_pad_mask, type_mask): 39 | """ Log-likelihood of non-events, using Monte Carlo integration. """ 40 | 41 | num_samples = 100 42 | 43 | #diff_time = (time[:, 1:] - time[:, :-1]) * non_pad_mask[:, 1:] 44 | diff_time = (time) * non_pad_mask 45 | #temp_time = diff_time.unsqueeze(2) * \ 46 | # torch.rand([*diff_time.size(), num_samples], device=data.device) 47 | temp_time = tf.expand_dims(diff_time, axis=2) * \ 48 | tf.random.uniform([*(diff_time.shape.as_list()), num_samples]) 49 | #temp_time /= (time[:, :-1] + 1).unsqueeze(2) 50 | #temp_time /= tf.expand_dims((time[:, :-1] + 1), axis=2) 51 | temp_time /= tf.expand_dims((tf.cumsum(time, axis=-1) + 1), axis=2) 52 | 53 | temp_hid = model.linear(data) 54 | #temp_hid = torch.sum(temp_hid * type_mask[:, 1:, :], dim=2, keepdim=True) 55 | temp_hid = tf.reduce_sum(temp_hid * type_mask, axis=2, keepdims=True) 56 | 57 | #all_lambda = F.softplus(temp_hid + model.alpha * temp_time, threshold=10) 58 | # No threshold parameter for tf.nn.softplus is available, (or not required). 59 | all_lambda = tf.nn.softplus(temp_hid + model.alpha * temp_time) 60 | #all_lambda = torch.sum(all_lambda, dim=2) / num_samples 61 | all_lambda = tf.reduce_sum(all_lambda, axis=2) / num_samples 62 | 63 | unbiased_integral = all_lambda * diff_time 64 | return unbiased_integral 65 | 66 | 67 | def log_likelihood(model, data, time, types): 68 | """ Log-likelihood of sequence. 
""" 69 | 70 | #non_pad_mask = get_non_pad_mask(types).squeeze(2) 71 | non_pad_mask = get_non_pad_mask(types) 72 | 73 | #type_mask = torch.zeros([*types.size(), model.num_types], device=data.device) 74 | #type_mask = tf.zeros([*(types.shape.as_list()), model.num_types]) 75 | #for i in range(model.num_types): 76 | # #type_mask[:, :, i] = (types == i + 1).bool().to(data.device) 77 | # type_mask[:, :, i] = tf.cast((types == i + 1), tf.bool) 78 | type_ids = tf.expand_dims(tf.expand_dims(tf.range(1, model.num_types+1, dtype=tf.float32), axis=0), axis=1) 79 | type_mask = tf.cast((tf.expand_dims(types, axis=-1) == type_ids), tf.float32) 80 | 81 | all_hid = model.linear(data) 82 | #all_lambda = F.softplus(all_hid, threshold=10) 83 | all_lambda = tf.nn.softplus(all_hid) 84 | #type_lambda = torch.sum(all_lambda * type_mask, dim=2) 85 | type_lambda = tf.reduce_sum(all_lambda * type_mask, axis=2) 86 | 87 | # event log-likelihood 88 | event_ll = compute_event(type_lambda, non_pad_mask) 89 | #event_ll = torch.sum(event_ll, dim=-1) 90 | event_ll = tf.reduce_sum(event_ll, axis=-1) 91 | 92 | # non-event log-likelihood, either numerical integration or MC integration 93 | # non_event_ll = compute_integral_biased(type_lambda, time, non_pad_mask) 94 | non_event_ll = compute_integral_unbiased(model, data, time, non_pad_mask, type_mask) 95 | #non_event_ll = torch.sum(non_event_ll, dim=-1) 96 | non_event_ll = tf.reduce_sum(non_event_ll, axis=-1) 97 | 98 | return event_ll, non_event_ll 99 | 100 | 101 | def type_loss(prediction, types, loss_func): 102 | """ Event prediction loss, cross entropy or label smoothing. 
""" 103 | 104 | # convert [1,2,3] based types to [0,1,2]; also convert padding events to -1 105 | truth = types[:, 1:] - 1 106 | prediction = prediction[:, :-1, :] 107 | 108 | #pred_type = torch.max(prediction, dim=-1)[1] 109 | pred_type = tf.cast(tf.argmax(prediction, axis=-1), tf.float32) 110 | #correct_num = torch.sum(pred_type == truth) 111 | correct_num = tf.reduce_sum(tf.cast(pred_type == truth, tf.float32)) 112 | 113 | # compute cross entropy loss 114 | if isinstance(loss_func, LabelSmoothingLoss): 115 | loss = loss_func(prediction, truth) 116 | else: 117 | #loss = loss_func(prediction.transpose(1, 2), truth) 118 | loss = loss_func(truth, prediction) 119 | 120 | #loss = torch.sum(loss) 121 | loss = tf.reduce_sum(loss) 122 | return loss, correct_num 123 | 124 | 125 | def time_loss(prediction, event_time): 126 | """ Time prediction loss. """ 127 | 128 | #prediction.squeeze_(-1) 129 | #tf.squeeze(prediction, axis=-1) 130 | 131 | #true = event_time[:, 1:] - event_time[:, :-1] 132 | true = event_time 133 | #prediction = prediction[:, :-1] 134 | 135 | # event time gap prediction 136 | diff = prediction - true 137 | #se = torch.sum(diff * diff) 138 | se = tf.reduce_sum(diff * diff) 139 | return se 140 | 141 | 142 | class LabelSmoothingLoss(tf.keras.losses.Loss): 143 | """ 144 | With label smoothing, 145 | KL-divergence between q_{smoothed ground truth prob.}(w) 146 | and p_{prob. computed by model}(w) is minimized. 
147 | """ 148 | def __init__( 149 | self, label_smoothing, tgt_vocab_size, ignore_index=-100, 150 | reduction=keras.losses.Reduction.AUTO, 151 | name='LabelSmoothingLoss'): 152 | 153 | assert 0.0 < label_smoothing <= 1.0 154 | super(LabelSmoothingLoss, self).__init__(reduction=reduction, name=name) 155 | 156 | self.eps = label_smoothing 157 | self.num_classes = tgt_vocab_size 158 | self.ignore_index = ignore_index 159 | 160 | def call(self, output, target): 161 | """ 162 | output (FloatTensor): (batch_size) x n_classes 163 | target (LongTensor): batch_size 164 | """ 165 | 166 | non_pad_mask = tf.cast((target != tf.cast((self.ignore_index), tf.float32)), tf.float32) 167 | 168 | #target[target.eq(self.ignore_index)] = 0 169 | #target[target == (self.ignore_index)] = 0 170 | target = tf.where(target==self.ignore_index, 0, target) 171 | #one_hot = F.one_hot(target, num_classes=self.num_classes).float() 172 | one_hot = tf.cast(tf.one_hot(tf.cast(target, tf.int64), depth=self.num_classes), tf.float32) 173 | one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / self.num_classes 174 | 175 | #log_prb = F.log_softmax(output, dim=-1) 176 | log_prb = tf.nn.log_softmax(output, axis=-1) 177 | #loss = -(one_hot * log_prb).sum(dim=-1) 178 | loss = tf.reduce_sum(-(one_hot * log_prb), axis=-1) 179 | loss = loss * non_pad_mask 180 | return loss 181 | -------------------------------------------------------------------------------- /modules/Hawkes/tools/BasisFunction.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import sys,time,datetime,copy,subprocess,itertools,pickle,warnings,numbers 5 | 6 | import numpy as np 7 | import scipy as sp 8 | import pandas as pd 9 | from matplotlib import pyplot as plt 10 | import matplotlib as mpl 11 | 12 | from scipy.interpolate import BSpline 13 | import scipy.sparse as sparse 14 | from scipy.sparse import linalg as spla 
15 | 16 | ###################################################################################### core class 17 | class BasisFunctionExpansion_1D: 18 | 19 | def __init__(self,itv=None,num_basis=10): 20 | self.itv = itv 21 | self.num_basis = num_basis 22 | self.coef = np.zeros(num_basis) 23 | 24 | def set_coef(self,coef): 25 | self.coef = coef 26 | return self 27 | 28 | def Matrix_BasisFunction(self,x): 29 | pass 30 | 31 | def d_Matrix_BasisFunction(self,x): 32 | pass 33 | 34 | def set_x(self,x): 35 | [st,en] = self.itv 36 | self.x = x 37 | self.A = self.Matrix_BasisFunction(x) 38 | self.A_t = self.A.transpose() 39 | self.A_sp = sparse.csc_matrix(self.A) 40 | self.A_t_sp = sparse.csc_matrix(self.A_t) 41 | bin_edge = np.hstack([st,(x[:-1]+x[1:])/2,en]) 42 | self.weight = bin_edge[1:] - bin_edge[:-1] 43 | return self 44 | 45 | def get_y(self): 46 | pass 47 | 48 | def get_dy(self): 49 | pass 50 | 51 | def get_y_at(self,x): 52 | pass 53 | 54 | def get_int(self): 55 | weight = self.weight 56 | y = self.get_y() 57 | Int = weight.dot(y) 58 | return Int 59 | 60 | def get_dint(self): 61 | weight = self.weight 62 | dy = self.get_dy() 63 | dInt = weight.dot(dy) 64 | return dInt 65 | 66 | def set_V(self,V): 67 | self.V = V 68 | return self 69 | 70 | def set_bayes(self): 71 | x = self.x 72 | d_A = self.d_Matrix_BasisFunction(x) 73 | W = sparse.csc_matrix(d_A.transpose().dot(d_A)) 74 | self.W = W 75 | return self 76 | 77 | def LGH(self): 78 | x = self.coef; V = self.V; W = self.W; 79 | P = W/V 80 | P[0,0] += 1e-3 81 | log_const = logdet_sp(P)/2 - P.shape[0]*np.log(2*np.pi)/2 82 | L = log_const - x.dot(P.dot(x))/2.0 83 | G = - P.dot(x) 84 | H = - P 85 | return [L,G,H] 86 | 87 | def GH_transform(self,G,H): 88 | A = self.A_sp; A_t = self.A_t_sp; 89 | G = A_t.dot(G) 90 | H = A_t.dot(H.dot(A)) 91 | return [G,H] 92 | 93 | 94 | ###################################################################################### Base class 95 | class linear_1D(BasisFunctionExpansion_1D): 96 | 97 | def 
get_y(self): 98 | A = self.A; coef = self.coef 99 | y = A.dot(coef) 100 | return y 101 | 102 | def get_dy(self): 103 | A = self.A; 104 | return A 105 | 106 | def get_y_at(self,x): 107 | coef = self.coef 108 | A = self.Matrix_BasisFunction(x) 109 | return A.dot(coef) 110 | 111 | class loglinear_1D(BasisFunctionExpansion_1D): 112 | 113 | def get_y(self): 114 | A = self.A; coef = self.coef 115 | y = np.exp( A.dot(coef) ) 116 | self.y = y 117 | return y 118 | 119 | def get_dy(self): 120 | A = self.A; y = self.y 121 | dy = y.reshape(-1,1) * A 122 | return dy 123 | 124 | def get_y_at(self,x): 125 | coef = self.coef 126 | A = self.Matrix_BasisFunction(x) 127 | return np.exp( A.dot(coef) ) 128 | 129 | ###################################################################################### Bump function 130 | def bump_cos(x): 131 | y = np.zeros_like(x) 132 | index = (-21 else index_ini }) 27 | self._hash_table.update({ (para,i): index_ini+i for i in range(length) }) 28 | index_ini += length 29 | 30 | ### hash table 31 | def idx(self,key): 32 | return self._hash_table[key] 33 | 34 | def _index(self,key_list): 35 | try: 36 | index = np.hstack([ np.arange(self._length,dtype='i8')[self._hash_table[key]] for key in key_list ]) 37 | except: 38 | index = slice(0) 39 | return index 40 | 41 | def add_key(self,key_list,key_name): 42 | self._hash_table.update({ key_name: self._index(key_list) }) 43 | return self 44 | 45 | ### I/O 46 | def from_dict(self,dic): 47 | x = np.zeros(self._length) 48 | for key in dic: 49 | try: 50 | x[self.idx(key)] = dic[key] 51 | except: 52 | x[self.idx(key)] = dic[key][0] 53 | return x 54 | 55 | def to_dict(self,ndarray): 56 | return { para: ndarray[self.idx(para)] for para in self._para_list } 57 | 58 | ################################## 59 | ## Quasi Newton 60 | ################################## 61 | def Quasi_Newton(model,prior=[],merge=[],opt=[]): 62 | 63 | ## parameter setting 64 | para_list = model.stg["para_list"] 65 | para_length = [ 
model.stg["para_length"][key] for key in para_list ] 66 | param = array_label(para_list,para_length) 67 | param.add_key( [ pr for pr in para_list if model.stg["para_exp"][pr] ], "para_exp") 68 | param.add_key( [ pr for pr in para_list if not model.stg["para_exp"][pr] ], "para_ord") 69 | model.stg['para_label'] = param 70 | 71 | if 'para_ini' not in opt: 72 | para = param.from_dict(model.stg['para_ini']) 73 | else: 74 | para = param.from_dict(opt['para_ini']) 75 | 76 | step_Q = param.from_dict(model.stg["para_step_Q"]) 77 | m = len(para) 78 | 79 | ## prior setting 80 | if prior: 81 | ##fix check 82 | para_fix_index = [ (prior_i["name"],prior_i["index"]) for prior_i in prior if prior_i["type"] == "f" ] 83 | para_fix_value = [ prior_i["mu"] for prior_i in prior if prior_i["type"] == "f" ] 84 | param.add_key(para_fix_index,"fix") 85 | para[param.idx('fix')] = para_fix_value 86 | prior = [ prior_i for prior_i in prior if prior_i["type"] != "f" ] 87 | else: 88 | param.add_key([],"fix") 89 | 90 | ## merge setting 91 | if merge: 92 | d = len(merge) 93 | index_merge = np.zeros(m,dtype='i8') 94 | 95 | for i in range(d): 96 | key = 'merge%d' % i 97 | param.add_key(merge[i],key) 98 | para[param.idx(key)] = para[param.idx(key)].mean() 99 | index_merge[param.idx(key)] = i+1 100 | 101 | M_merge_z = np.eye(m)[ index_merge == 0 ] 102 | M_merge_nz = np.vstack( [ np.eye(m)[index_merge == i+1].sum(axis=0) for i in range(d) ] ) 103 | M_merge = np.vstack([M_merge_z,M_merge_nz]) 104 | M_merge_T = np.transpose(M_merge) 105 | m_reduced = M_merge.shape[0] 106 | 107 | else: 108 | M_merge = 1 109 | M_merge_T = 1 110 | m_reduced = m 111 | 112 | # calculate Likelihood and Gradient at the initial state 113 | [L1,G1] = Penalized_LG(model,para,prior) 114 | G1[param.idx("para_exp")] *= para[param.idx("para_exp")] 115 | G1 = np.dot(M_merge,G1) 116 | 117 | # main 118 | H = np.eye(m_reduced) 119 | i_loop = 0 120 | 121 | while 1: 122 | 123 | if 'print' in opt: 124 | print(i_loop) 125 | 
print(param.to_dict(para)) 126 | #print(G1) 127 | print( "L = %.3f, norm(G) = %e\n" % (L1,np.linalg.norm(G1)) ) 128 | #sys.exit() 129 | 130 | if 'stop' in opt: 131 | if i_loop == opt['stop']: 132 | break 133 | 134 | #break rule 135 | if np.linalg.norm(G1) < 1e-5 : 136 | break 137 | 138 | #calculate direction 139 | s = H.dot(G1); 140 | s_extended = np.dot(M_merge_T,s) 141 | gamma = 1/np.max([np.max(np.abs(s_extended)/step_Q),1]) 142 | s = s * gamma 143 | s_extended = s_extended * gamma 144 | 145 | #move to new point 146 | para[param.idx("para_ord")] += s_extended[param.idx("para_ord")] 147 | para[param.idx("para_exp")] *= np.exp( s_extended[param.idx("para_exp")] ) 148 | 149 | #calculate Likelihood and Gradient at the new point 150 | [L2,G2] = Penalized_LG(model,para,prior) 151 | G2[param.idx("para_exp")] *= para[param.idx("para_exp")] 152 | G2 = np.dot(M_merge,G2) 153 | 154 | #update hessian matrix 155 | y = (G1-G2).reshape(-1,1) 156 | s = s.reshape(-1,1) 157 | 158 | if y.T.dot(s) > 0: 159 | H = H + (y.T.dot(s)+y.T.dot(H).dot(y))*(s*s.T)/(y.T.dot(s))**2 - (H.dot(y)*s.T+(s*y.T).dot(H))/(y.T.dot(s)) 160 | else: 161 | H = np.eye(m_reduced) 162 | 163 | #update Gradients 164 | L1 = L2 165 | G1 = G2 166 | 167 | i_loop += 1 168 | 169 | ###OPTION: Estimation Error 170 | if 'ste' in opt: 171 | ste = EstimationError(model,para,prior) 172 | else: 173 | ste = [] 174 | 175 | ###OPTION: Check map solution 176 | if 'check' in opt: 177 | Check_QN(model,para,prior) 178 | 179 | return [param.to_dict(para),L1,ste,np.linalg.norm(G1),i_loop] 180 | 181 | def Check_QN(model,para,prior): 182 | param = model.stg['para_label'] 183 | ste = EstimationError_approx(model,para,prior) 184 | ste[param.idx('fix')] = 0 185 | a = np.linspace(-1,1,21) 186 | 187 | for key in model.stg['para_list']: 188 | for index in range(model.stg['para_length'][key]): 189 | 190 | plt.figure() 191 | plt.title(key + '-' + str(index)) 192 | 193 | for i in range(len(a)): 194 | para_tmp = para.copy() 195 | 
para_tmp[param.idx((key,index))] += a[i] * ste[param.idx((key,index))] 196 | L = Penalized_LG(model,para_tmp,prior)[0] 197 | plt.plot(para_tmp[param.idx((key,index))],L,"ko") 198 | 199 | if i==10: 200 | plt.plot(para_tmp[param.idx((key,index))],L,"ro") 201 | 202 | ################################# 203 | ## Basic funnctions 204 | ################################# 205 | def G_NUMERICAL(model,para): 206 | 207 | m = len(para) 208 | param = model.stg['para_label'] 209 | step_diff = param.from_dict(model.stg['para_step_diff']) 210 | step_diff[param.idx('para_exp')] *= para[param.idx('para_exp')] 211 | G = np.zeros(m) 212 | 213 | for i in range(m): 214 | step = step_diff[i] 215 | 216 | """ 217 | para_tmp = para.copy(); para_tmp[i] -= step; L1 = model.LG(param.to_dict(para_tmp))[0] 218 | para_tmp = para.copy(); para_tmp[i] += step; L2 = model.LG(param.to_dict(para_tmp))[0] 219 | G[i]= (L2-L1)/2/step 220 | """ 221 | 222 | para_tmp = para.copy(); para_tmp[i] -= 2*step; L1 = model.LG(param.to_dict(para_tmp))[0] 223 | para_tmp = para.copy(); para_tmp[i] -= 1*step; L2 = model.LG(param.to_dict(para_tmp))[0] 224 | para_tmp = para.copy(); para_tmp[i] += 1*step; L3 = model.LG(param.to_dict(para_tmp))[0] 225 | para_tmp = para.copy(); para_tmp[i] += 2*step; L4 = model.LG(param.to_dict(para_tmp))[0] 226 | G[i]= (L1-8*L2+8*L3-L4)/12/step 227 | 228 | return G 229 | 230 | 231 | def Hessian(model,para,prior): 232 | 233 | m = len(para) 234 | param = model.stg['para_label'] 235 | step_diff = param.from_dict(model.stg['para_step_diff']) 236 | step_diff[param.idx('para_exp')] *= para[param.idx('para_exp')] 237 | H = np.zeros((m,m)) 238 | 239 | for i in range(m): 240 | step = step_diff[i] 241 | para_tmp = para.copy(); para_tmp[i] -= step; G1 = Penalized_LG(model,para_tmp,prior)[1] 242 | para_tmp = para.copy(); para_tmp[i] += step; G2 = Penalized_LG(model,para_tmp,prior)[1] 243 | H[i] = (G2-G1)/2/step 244 | 245 | H[param.idx('fix')] = 0 246 | H[param.idx('fix'),param.idx('fix')] = -1e+20 247 | 
248 | return H 249 | 250 | def EstimationError(model,para,prior): 251 | H = Hessian(model,para,prior) 252 | ste = np.sqrt(np.diag(np.linalg.inv(-H))) 253 | return ste 254 | 255 | def EstimationError_approx(model,para,prior): 256 | H = Hessian(model,para,prior) 257 | ste = 1.0/np.sqrt(np.diag(-H)) 258 | return ste 259 | 260 | 261 | def Penalized_LG(model,para,prior,only_L=False): 262 | 263 | param = model.stg['para_label'] 264 | 265 | [L,G] = model.LG(param.to_dict(para),only_L) 266 | 267 | if isinstance(G,str): 268 | G = G_NUMERICAL(model,para) 269 | else: 270 | G = param.from_dict(G) 271 | 272 | ## fix 273 | if not only_L: 274 | G[param.idx("fix")] = 0 275 | 276 | ## prior 277 | if prior: 278 | 279 | for prior_i in prior: 280 | para_key = prior_i["name"] 281 | para_index = prior_i["index"] 282 | prior_type = prior_i["type"] 283 | mu = prior_i["mu"] 284 | sigma = prior_i["sigma"] 285 | index = param.idx((para_key,para_index)) 286 | x = para[index] 287 | 288 | if prior_type == 'n': #prior: normal distribution 289 | L += - np.log(2*np.pi*sigma**2)/2 - (x-mu)**2/2/sigma**2 290 | if not only_L: 291 | G[index] += - (x-mu)/sigma**2 292 | elif prior_type == 'ln': #prior: log-normal distribution 293 | L += - np.log(2*np.pi*sigma**2)/2 - np.log(x) - (np.log(x)-mu)**2/2/sigma**2 294 | if not only_L: 295 | G[index] += - 1/x - (np.log(x)-mu)/sigma**2/x 296 | elif prior_type == "b": #prior: barrier function 297 | L += - mu/x 298 | if not only_L: 299 | G[index] += mu/x**2 300 | elif prior_type == "b2": #prior: barrier function 301 | L += mu *np.log10(np.e)*np.log(x) 302 | if not only_L: 303 | G[index] += mu * np.log10(np.e)/x 304 | 305 | return [L,G] 306 | 307 | ################################# 308 | ## para_stg 309 | ################################# 310 | def merge_stg(para_stgs): 311 | 312 | stg = {} 313 | stg['para_list'] = [] 314 | stg['para_length'] = {} 315 | stg['para_exp'] = {} 316 | stg['para_ini'] = {} 317 | stg['para_step_Q'] = {} 318 | stg['para_step_diff'] = {} 
319 | 320 | for para_stg in para_stgs: 321 | stg['para_list'].extend(para_stg['list']) 322 | stg['para_length'].update(para_stg['length']) 323 | stg['para_exp'].update(para_stg['exp']) 324 | stg['para_ini'].update(para_stg['ini']) 325 | stg['para_step_Q'].update(para_stg['step_Q']) 326 | stg['para_step_diff'].update(para_stg['step_diff']) 327 | 328 | return stg 329 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | 6 | from collections import OrderedDict 7 | from collections import Counter 8 | from operator import itemgetter 9 | 10 | # Hawkes model is from https://omitakahiro.github.io/Hawkes/index.html 11 | from modules import Hawkes as hk 12 | 13 | para = {'mu':0.1, 'alpha':0.3, 'beta':0.6} 14 | mu_t = lambda x: (1.0 + 0.8*np.sin(2*np.pi*x/100)) * 0.2 # baseline function for overlay 15 | itv = [0,360000] 16 | demo_itv = [0,360] 17 | 18 | np.random.seed(42) 19 | downsampling = {'Trump': 10} 20 | #downsampling = {'taxi': 20, 'Trump': 20} 21 | 22 | def downsampling_dataset(timestamps, dataset_name): 23 | print('Down-sampling', dataset_name, 'dataset by', downsampling[dataset_name]) 24 | return timestamps[::downsampling[dataset_name]] 25 | 26 | def purge_duplicate_events(timestamps, types): 27 | timestamps = timestamps.tolist() 28 | types = types.tolist() 29 | del_indices = [] 30 | events = [(ts, ty) for ts, ty in zip(timestamps, types)] 31 | events_next = events[1:] 32 | np.where([e[0]==en[0] and e[1]==en[1] for e, en in zip(events[:-1], events_next)]) 33 | for i in range(1, len(timestamps)): 34 | if timestamps[i]==timestamps[i-1] and types[i]==types[i-1]: 35 | del_indices.append(i) 36 | 37 | for ind in sorted(del_indices, reverse=True): 38 | del timestamps[ind] 39 | del types[ind] 40 | 41 | return timestamps, types 42 | 43 | def 
keep_top_k_types(types, keep_classes=10): 44 | types_counter = OrderedDict(sorted(Counter(types).items(), key=itemgetter(1), reverse=True)) 45 | type2supertype = OrderedDict() 46 | for i, (type_, _) in enumerate(types_counter.items()): 47 | if i > keep_classes: 48 | type2supertype[type_] = keep_classes + 1 49 | else: 50 | type2supertype[type_] = i + 1 51 | 52 | types_new = [type2supertype[ty] for ty in types] 53 | return np.array(types_new) 54 | 55 | 56 | def hawkes_demo(): 57 | hk_model = hk.simulator().set_kernel('exp').set_baseline('const').set_parameter(para) 58 | T = hk_model.simulate(demo_itv) 59 | hk_model.plot_l() 60 | plt.savefig('hawkes_intensity.png') 61 | plt.close() 62 | hk_model.plot_N() 63 | plt.savefig('hawkes_event_counts.png') 64 | plt.close() 65 | 66 | def sin_hawkes_overlay_demo(): 67 | hk_model = hk.simulator().set_kernel('exp').set_baseline('custom',l_custom=mu_t).set_parameter(para) 68 | T = hk_model.simulate(demo_itv) 69 | hk_model.plot_l() 70 | plt.savefig('sin_hawkes_overlay_intensity.png') 71 | plt.close() 72 | hk_model.plot_N() 73 | plt.savefig('sin_hawkes_overlay_event_counts.png') 74 | plt.close() 75 | 76 | def create_sin_data(): 77 | omega = 1.0 78 | points = 10000 79 | num_marks = 7 80 | x = np.linspace(0, points, 3*points) 81 | y_ = 10*np.sin(omega*x) 82 | y = y_ + 11 83 | gaps=y 84 | timestamp = np.cumsum(gaps) 85 | types = [] 86 | if y_[0]=0.: 90 | # types.append(0) 91 | #else: 92 | # types.append(1) 93 | if y_[i]>=0. and y_[i]>y_[i-1]: 94 | types.append(1) 95 | if y_[i]>=0. 
and y_[i]y_[i-1]: 100 | types.append(4) 101 | 102 | types = np.array(types) 103 | 104 | plt.plot(x[:25], y[:25], 'o', color='black'); 105 | plt.savefig('data/sin.png') 106 | plt.close() 107 | return gaps, timestamp, types 108 | 109 | def create_hawkes_data(): 110 | hawkes_demo() 111 | hk_model = hk.simulator().set_kernel('exp').set_baseline('const').set_parameter(para) 112 | timestamp = hk_model.simulate([0, 360000]) 113 | gaps = timestamp[1:] - timestamp[:-1] 114 | return gaps, timestamp 115 | 116 | def create_sin_hawkes_overlay_data(): 117 | sin_hawkes_overlay_demo() 118 | hk_model = hk.simulator().set_kernel('exp').set_baseline('custom',l_custom=mu_t).set_parameter(para) 119 | timestamp = hk_model.simulate([0, 360000]) 120 | gaps = timestamp[1:] - timestamp[:-1] 121 | return gaps, timestamp 122 | 123 | def create_taxi_data(): 124 | # https://s3.amazonaws.com/nyc-tlc/trip+data/yellow_tripdata_2019-01.csv 125 | # https://s3.amazonaws.com/nyc-tlc/trip+data/yellow_tripdata_2019-02.csv 126 | taxi_df_jan = pd.read_csv( 127 | './data/yellow_tripdata_2019-01.csv', 128 | usecols=["tpep_pickup_datetime", "PULocationID", "DOLocationID"]) 129 | taxi_df_feb = pd.read_csv( 130 | './data/yellow_tripdata_2019-02.csv', 131 | usecols=["tpep_pickup_datetime", "PULocationID", "DOLocationID"]) 132 | taxi_df = taxi_df_jan.append(taxi_df_feb) 133 | taxi_df = taxi_df[taxi_df.PULocationID == 237] 134 | taxi_df['tpep_pickup_datetime'] = pd.to_datetime(taxi_df['tpep_pickup_datetime'], errors='coerce') 135 | taxi_df = taxi_df[(taxi_df['tpep_pickup_datetime'].dt.year == 2019)] 136 | taxi_df = taxi_df[(taxi_df['tpep_pickup_datetime'].dt.month < 3)] 137 | taxi_df = taxi_df.sort_values('tpep_pickup_datetime') 138 | taxi_types = taxi_df['DOLocationID'].values 139 | #taxi_timestamps = taxi_timestamps.sort_values().astype(np.int64) 140 | taxi_timestamps = pd.DatetimeIndex(taxi_df['tpep_pickup_datetime']).astype(np.int64)/1000000000 141 | taxi_timestamps = np.array(taxi_timestamps) 142 | 
taxi_timestamps -= taxi_timestamps[0] 143 | taxi_timestamps = taxi_timestamps[:-1] 144 | taxi_types = taxi_types[:-1] 145 | taxi_types = keep_top_k_types(taxi_types) 146 | dataset_name = 'taxi' 147 | if dataset_name in downsampling: 148 | taxi_timestamps = downsampling_dataset(taxi_timestamps, dataset_name) 149 | taxi_types = downsampling_dataset(taxi_types, dataset_name) 150 | taxi_gaps = taxi_timestamps[1:] - taxi_timestamps[:-1] 151 | plt.plot(taxi_gaps[:100]) 152 | plt.ylabel('Gaps') 153 | plt.savefig('data/taxi_gaps.png') 154 | plt.close() 155 | return taxi_gaps, taxi_timestamps, taxi_types 156 | 157 | def create_911_traffic_data(): 158 | call_df = pd.read_csv('./data/911.csv') 159 | call_df = call_df[call_df['zip'].isnull()==False] # Ignore calls with NaN zip codes 160 | print('Types of Emergencies') 161 | print(call_df.title.apply(lambda x: x.split(':')[0]).value_counts()) 162 | call_df['type'] = call_df.title.apply(lambda x: x.split(':')[0]) 163 | print('Subtypes') 164 | for each in call_df.type.unique(): 165 | subtype_count = call_df[call_df.title.apply(lambda x: x.split(':')[0]==each)].title.value_counts() 166 | print('For', each, 'type of Emergency, we have ', subtype_count.count(), 'subtypes') 167 | print(subtype_count[subtype_count>100]) 168 | print('Out of 3 types taking Traffic type considering only Traffic') 169 | call_data = call_df[call_df['type']=='Traffic'] 170 | call_data['timeStamp'] = pd.to_datetime(call_data['timeStamp'], errors='coerce') 171 | print("We have timeline from", call_data['timeStamp'].min(), "to", call_data['timeStamp'].max()) 172 | call_data = call_data.sort_values('timeStamp') 173 | 174 | call_timestamps = pd.DatetimeIndex(call_data['timeStamp']).astype(np.int64)/1000000000 175 | #call_timestamps = call_data.sort_values().astype(np.int64) 176 | call_timestamps = np.array(call_timestamps) 177 | call_timestamps -= call_timestamps[0] 178 | call_types = call_data['zip'].values 179 | call_types = keep_top_k_types(call_types) 180 | 
dataset_name = 'call' 181 | if dataset_name in downsampling: 182 | call_timestamps = downsampling_dataset(call_timestamps, dataset_name) 183 | call_types = downsampling_dataset(call_types, dataset_name) 184 | call_gaps = call_timestamps[1:] - call_timestamps[:-1] 185 | plt.plot(call_gaps[:100]) 186 | plt.ylabel('Gaps') 187 | plt.xlabel('timeline') 188 | plt.savefig('data/call_traffic_gaps.png') 189 | plt.close() 190 | return call_gaps, call_timestamps, call_types 191 | 192 | def create_911_ems_data(): 193 | call_df = pd.read_csv('./data/911.csv') 194 | call_df = call_df[call_df['zip'].isnull()==False] # Ignore calls with NaN zip codes 195 | call_df['type'] = call_df.title.apply(lambda x: x.split(':')[0]) 196 | print('Out of 3 types taking EMS type considering only EMS') 197 | call_data = call_df[call_df['type']=='EMS'] 198 | call_data['timeStamp'] = pd.to_datetime(call_data['timeStamp'], errors='coerce') 199 | print("We have timeline from", call_data['timeStamp'].min(), "to", call_data['timeStamp'].max()) 200 | call_data = call_data.sort_values('timeStamp') 201 | 202 | call_timestamps = pd.DatetimeIndex(call_data['timeStamp']).astype(np.int64)/1000000000 203 | #call_timestamps = call_data.sort_values().astype(np.int64) 204 | call_timestamps = np.array(call_timestamps) 205 | call_timestamps -= call_timestamps[0] 206 | call_types = call_data['zip'].values 207 | call_types = keep_top_k_types(call_types) 208 | dataset_name = 'call' 209 | if dataset_name in downsampling: 210 | call_timestamps = downsampling_dataset(call_timestamps, dataset_name) 211 | call_types = downsampling_dataset(call_types, dataset_name) 212 | call_gaps = call_timestamps[1:] - call_timestamps[:-1] 213 | plt.plot(call_gaps[:100]) 214 | plt.ylabel('Gaps') 215 | plt.xlabel('timeline') 216 | plt.savefig('data/call_ems_gaps.png') 217 | plt.close() 218 | return call_gaps, call_timestamps, call_types 219 | 220 | def generate_dataset(): 221 | os.makedirs('./data', exist_ok=True) 222 | #os.chdir('./data') 
223 | if not os.path.isfile("sin.txt"): 224 | print('Generating sin data') 225 | gaps, timestamps, types = create_sin_data() 226 | timestamps, types = purge_duplicate_events(timestamps, types) 227 | np.savetxt('data/sin.txt', timestamps) 228 | np.savetxt('data/sin_types.txt', types) 229 | # if not os.path.isfile("hawkes.txt"): 230 | # print('Generating hawkes data') 231 | # gaps, timestamps = create_hawkes_data() 232 | # timestamps, types = purge_duplicate_events(timestamps, types) 233 | # np.savetxt('hawkes.txt', timestamps) 234 | # if not os.path.isfile("sin_hawkes_overlay.txt"): 235 | # print('Generating sin_hawkes_overlay data') 236 | # gaps, timestamps = create_sin_hawkes_overlay_data() 237 | # timestamps, types = purge_duplicate_events(timestamps, types) 238 | # np.savetxt('sin_hawkes_overlay.txt', timestamps) 239 | if not os.path.isfile("911_traffic.txt"): 240 | print('Generating 911 data') 241 | gaps, timestamps, types = create_911_traffic_data() 242 | timestamps, types = purge_duplicate_events(timestamps, types) 243 | np.savetxt('data/911_traffic.txt', timestamps) 244 | np.savetxt('data/911_traffic_types.txt', types) 245 | if not os.path.isfile("911_ems.txt"): 246 | print('Generating 911 data') 247 | gaps, timestamps, types = create_911_ems_data() 248 | timestamps, types = purge_duplicate_events(timestamps, types) 249 | np.savetxt('data/911_ems.txt', timestamps) 250 | np.savetxt('data/911_ems_types.txt', types) 251 | if not os.path.isfile("taxi.txt"): 252 | print('Generating taxi data') 253 | gaps, timestamps, types = create_taxi_data() 254 | timestamps = np.array(timestamps).astype(np.float32) 255 | types = np.array(types).astype(np.float32) 256 | timestamps, types = purge_duplicate_events(timestamps, types) 257 | np.savetxt('data/taxi.txt', timestamps) 258 | np.savetxt('data/taxi_types.txt', types) 259 | #os.chdir('../') 260 | 261 | def create_twitter_data(dataset_name, keep_classes=10): 262 | delimiter=' ' 263 | if dataset_name in ['Movie', 'Delhi', 
'Verdict', 'Fight']: 264 | delimiter='\t' 265 | twitter_df = pd.read_csv('./data/'+dataset_name+'.txt', delimiter=delimiter, header=None) 266 | twitter_df = twitter_df.values[::-1] 267 | #twitter_df = twitter_df[1] 268 | timestamps = twitter_df[:, 1] 269 | timestamps -= timestamps[0] 270 | gaps = timestamps[1:] - timestamps[:-1] 271 | types = twitter_df[:, 0] 272 | types_counter = OrderedDict(sorted(Counter(types).items(), key=itemgetter(1), reverse=True)) 273 | type2supertype = OrderedDict() 274 | for i, (type_, _) in enumerate(types_counter.items()): 275 | if i > keep_classes: 276 | type2supertype[type_] = keep_classes + 1 277 | else: 278 | type2supertype[type_] = i + 1 279 | 280 | types_new = [type2supertype[ty] for ty in types] 281 | types = types_new 282 | if dataset_name in downsampling: 283 | plt.plot(gaps) 284 | plt.ylabel('all_Gaps_before_downsample') 285 | plt.savefig(dataset_name+'_all_gaps_before_downsample.png') 286 | plt.close() 287 | timestamps = downsampling_dataset(timestamps, dataset_name) 288 | types = downsampling_dataset(types, dataset_name) 289 | 290 | plt.plot(gaps[:100]) 291 | plt.ylabel('Gaps') 292 | plt.savefig(dataset_name+'_gaps.png') 293 | plt.close() 294 | plt.plot(gaps) 295 | plt.ylabel('all_Gaps') 296 | plt.savefig(dataset_name+'_all_gaps.png') 297 | plt.close() 298 | return gaps, timestamps, types 299 | 300 | def generate_twitter_dataset(twitter_dataset_names): 301 | os.makedirs('./data', exist_ok=True) 302 | #os.chdir('./data') 303 | for dataset_name in twitter_dataset_names: 304 | if not os.path.isfile(dataset_name+'.txt'): 305 | print('Generating', dataset_name, 'data') 306 | gaps, timestamps, types = create_twitter_data(dataset_name) 307 | timestamps, types = purge_duplicate_events(np.array(timestamps), np.array(types)) 308 | np.savetxt(dataset_name+'.txt', timestamps) 309 | np.savetxt(dataset_name+'_types.txt', types) 310 | #os.chdir('../') 311 | -------------------------------------------------------------------------------- 
"""Command-line driver: generate datasets, run the selected models on each
dataset and write per-dataset result reports (text and JSON)."""
import argparse
import datetime
import json
import multiprocessing as MP
import os
import sys
import time
from argparse import Namespace
from collections import OrderedDict
from itertools import product
from operator import itemgetter

import numpy as np

from generator import generate_dataset, generate_twitter_dataset
import run
import utils


# Metric keys written in the report tables.
_RH_KEYS = ['count_mae_rh', 'wass_dist_rh', 'bleu_score_rh']
_FH_KEYS = ['count_mae_fh', 'wass_dist_fh', 'bleu_score_fh']
_RH_ALL_KEYS = _RH_KEYS + ['opt_loss', 'cont_loss', 'count_loss']


def _write_metric_table(fp, title, header, results, metric_keys):
    """Write one LaTeX-style table: title, header, then one row per model."""
    fp.write(title)
    fp.write(header)
    row_format = '\n & {} ' + '& {:.3f} ' * len(metric_keys) + '\\\\'
    for model_name, metrics_dict in results.items():
        values = [metrics_dict[key] for key in metric_keys]
        fp.write(row_format.format(model_name, *values))


def _write_text_report(path, results):
    """Write all result tables plus a full per-model metric listing."""
    with open(path, 'w') as fp:
        # The two "random interval" headers used to list loss columns that
        # did not match the values written in the rows.
        _write_metric_table(
            fp, '\n\nResults in random interval:',
            '\nModel Name & Count MAE & Wass dist & bleu_score',
            results, _RH_KEYS)
        _write_metric_table(
            fp, '\n\nResults in Forecast Horizon:',
            '\nModel Name & Count MAE & Wass Dist & bleu_score',
            results, _FH_KEYS)
        _write_metric_table(
            fp, '\n\nQuery 2 Results',
            '\nModel Name & Query_2_Metric',
            results, ['more_metric'])
        _write_metric_table(
            fp, '\n\nQuery 3 Results',
            '\nModel Name & Query_3_Metric',
            results, ['less_metric'])
        _write_metric_table(
            fp, '\n\nAll metrics in random interval:',
            '\nModel Name & Count MAE & Wass dist & bleu_score'
            ' & opt_loss & cont_loss & count_loss',
            results, _RH_ALL_KEYS)
        _write_metric_table(
            fp, '\n\nAll metrics in Forecast Horizon:',
            '\nModel Name & Count MAE & Wass Dist & bleu score',
            results, _FH_KEYS)

        fp.write('\n')
        for model_name, metrics_dict in results.items():
            fp.write('\n {}'.format(model_name))
            for metric, val in metrics_dict.items():
                fp.write('\n {}: {:.3f}'.format(metric, val))
            # NOTE(review): assumed to be emitted once per model; the
            # original indentation of this write is not recoverable.
            fp.write('\n')


def _write_json_report(path, results):
    """Dump the metrics as strings without modifying ``results``.

    Converting in place would leave strings in the dictionary that is handed
    back to ``run.run_model`` for the next dataset and would break the
    ``{:.3f}`` formatting of any entry that is not recomputed.
    """
    printable = {
        model_name: {metric: str(val) for metric, val in metrics_dict.items()}
        for model_name, metrics_dict in results.items()
    }
    with open(path, 'w') as fp:
        json.dump(printable, fp)


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('dataset_name', type=str, help='dataset_name')
parser.add_argument('model_name', type=str, help='model_name')

parser.add_argument('--num_types', type=int, default=0,
                    help='Number of marker types. If markers not required, '
                         'num_types=0')

parser.add_argument('--epochs', type=int, default=0,
                    help='number of training epochs')
parser.add_argument('--patience', type=int, default=2,
                    help='Number of epochs to wait for '
                         'before beginning cross-validation')

# NOTE: with nargs='+' these are lists when given on the command line but
# plain scalars when the default is used.
parser.add_argument('--learning_rate', type=float, default=1e-3, nargs='+',
                    help='Learning rate for the training algorithm')
parser.add_argument('-hls', '--hidden_layer_size', type=int, default=32, nargs='+',
                    help='Number of units in RNN')
parser.add_argument('-embds', '--embed_size', type=int, default=8, nargs='+',
                    help='Embedding dimension of marks/types')

parser.add_argument('--output_dir', type=str,
                    help='Path to store all raw outputs, checkpoints, '
                         'summaries, and plots', default='Outputs')
parser.add_argument('--saved_models', type=str,
                    help='Path to store model checkpoints', default='saved_models')

parser.add_argument('--seed', type=int,
                    help='Seed for parameter initialization',
                    default=42)

# Bin size T_i - T_(i-1) in seconds
parser.add_argument('--bin_size', type=int, default=0,
                    help='Number of seconds in a bin')

# F(T_(i-1), T_(i-2) ..... , T_(i-r)) -> T(i)
parser.add_argument('--in_bin_sz', type=int,
                    help='Input count of bins r_feature_sz',
                    default=20)

# Decoder length, for all models
parser.add_argument('--out_bin_sz', type=int,
                    help='Output count of bin',
                    default=1)

parser.add_argument('--cnt_net_type', type=str, default='ff',
                    help='Count model network type (ff or rnn)')

# For RMTPP
parser.add_argument('--enc_len', type=int, default=80,
                    help='Input length for rnn of rmtpp')

# For Compound RMTPP
parser.add_argument('--comp_enc_len', type=int, default=40,
                    help='Input length for rnn of compound rmtpp')
parser.add_argument('--comp_bin_sz', type=int, default=10,
                    help='events inside one bin of compound rmtpp')

# For WGAN
parser.add_argument('--wgan_enc_len', type=int, default=60,
                    help='Input length for rnn of WGAN')
parser.add_argument('--use_wgan_d', action='store_true', default=False,
                    help='Whether to use WGAN discriminator or not')

# Seq2Seq / CWE parameters
parser.add_argument('--use_cwe_d', action='store_true', default=False,
                    help='Whether to use CWE/Seq2Seq discriminator or not')

# For RMTPP
parser.add_argument('--interval_size', type=int, default=360,
                    help='Interval size for threshold query')

parser.add_argument('--batch_size', type=int, default=32,
                    help='Input batch size')
parser.add_argument('--query', type=int, default=1,
                    help='Query number')
parser.add_argument('--stride_len', type=int, default=1,
                    help='Stride len for RMTPP number')
parser.add_argument('--normalization', type=str, default='average',
                    help='gap normalization method')

parser.add_argument('--generate_plots', action='store_true', default=False,
                    help='Generate dev and test plots, both per epochs '
                         'and after training')
parser.add_argument('--parallel_hparam', action='store_true', default=False,
                    help='Parallel execution of hyperparameters')

# Flags for RMTPP calibration
parser.add_argument('--calibrate_rmtpp', action='store_true', default=False,
                    help='Whether to calibrate RMTPP')
parser.add_argument('--extra_var_model', action='store_true', default=False,
                    help='Use a separate model to train the variance of RMTPP')

# Flags for optimizer
parser.add_argument('--opt_num_counts', type=int, default=5,
                    help='Number of counts to try before and after mean for optimizer')
parser.add_argument('--no_rescale_rmtpp_params', action='store_true', default=False,
                    help='Do not rescale RMTPP intensities for optimizer')
parser.add_argument('--use_ratio_constraints', action='store_true', default=False,
                    help='Maintain Ratios of adjacent RMTPP event predictions')
parser.add_argument('--search', type=int, default=0,
                    help='Search algorithm over counts 0:binary, 1:linear')

# Parameters for extra_var_model
parser.add_argument('--num_grps', type=int, default=10,
                    help='Number of groups in each bin in forecast horizon')
parser.add_argument('--num_pos', type=int, default=40,
                    help='Number of positions in each group in forecast horizon')

# Time-feature parameters
parser.add_argument('--no_count_model_feats', action='store_true', default=False,
                    help='Do not use time-features for count model')
parser.add_argument('--no_rmtpp_model_feats', action='store_true', default=False,
                    help='Do not use time-features for rmtpp model')

# Transformer parameters
parser.add_argument('-d_model', type=int, default=32)
parser.add_argument('-d_rnn', type=int, default=8)
parser.add_argument('-d_inner_hid', type=int, default=32)
parser.add_argument('-d_k', type=int, default=8)
parser.add_argument('-d_v', type=int, default=8)

parser.add_argument('-n_head', type=int, default=2)
parser.add_argument('-n_layers', type=int, default=1)

parser.add_argument('-dropout', type=float, default=0.1)
parser.add_argument('-lr', type=float, default=1e-4)
parser.add_argument('-smooth', type=float, default=0.)

args = parser.parse_args()

# ---------------------------------------------------------------------------
# Dataset and model selection
# ---------------------------------------------------------------------------
if args.dataset_name == 'all':
    dataset_names = ['sin', 'taxi', '911_traffic', '911_ems', 'twitter']
else:
    dataset_names = [args.dataset_name]

print(dataset_names)

# The pseudo-dataset 'twitter' expands to the individual twitter datasets.
twitter_dataset_names = list()
if 'twitter' in dataset_names:
    dataset_names.remove('twitter')
    twitter_dataset_names.append('Trump')
dataset_names.extend(twitter_dataset_names)

args.dataset_name = dataset_names

if args.model_name == 'all':
    model_names = [
        'count_model',
        'rmtpp_nll',
        'rmtpp_mse',
        'rmtpp_mse_var',
        'rmtpp_mse_comp',
        'rmtpp_mse_var_comp',
        'inference_models',
    ]
else:
    model_names = [args.model_name]
args.model_name = model_names

# Flags enabled by each selected model, in the order they are registered.
_MODEL_FLAGS = [
    ('rmtpp_nll', [('rmtpp_nll_opt', {'rmtpp_type': 'nll'}),
                   ('rmtpp_nll_simu', {'rmtpp_type': 'nll'})]),
    ('rmtpp_mse', [('rmtpp_mse_opt', {'rmtpp_type': 'mse'}),
                   ('rmtpp_mse_simu', {'rmtpp_type': 'mse'})]),
    ('rmtpp_mse_var', [('rmtpp_mse_var_opt', {'rmtpp_type': 'mse_var'}),
                       ('rmtpp_mse_var_simu', {'rmtpp_type': 'mse_var'})]),
    ('rmtpp_nll_comp', [('rmtpp_nll_opt_comp',
                         {'rmtpp_type': 'nll', 'rmtpp_type_comp': 'nll'})]),
    ('rmtpp_mse_comp', [('rmtpp_mse_opt_comp',
                         {'rmtpp_type': 'mse', 'rmtpp_type_comp': 'mse'})]),
    ('rmtpp_mse_var_comp', [('rmtpp_mse_var_opt_comp',
                             {'rmtpp_type': 'mse_var',
                              'rmtpp_type_comp': 'mse_var'})]),
    ('pure_hierarchical_nll', [('run_pure_hierarchical_infer_nll', True)]),
    ('pure_hierarchical_mse', [('run_pure_hierarchical_infer_mse', True)]),
    ('count_model', [('count_only', True)]),
    ('wgan', [('wgan_simu', True)]),
    ('seq2seq', [('seq2seq_simu', True)]),
    ('transformer', [('transformer_simu', True)]),
    ('hawkes_model', [('hawkes_simu', True)]),
]

run_model_flags = OrderedDict()
for selected_model, flags in _MODEL_FLAGS:
    if selected_model in model_names:
        run_model_flags.update(flags)

# A bin size of 0 means "choose one per dataset".
automate_bin_sz = args.bin_size == 0

if args.patience >= args.epochs:
    args.patience = 0

id_process = os.getpid()
time_current = datetime.datetime.now().isoformat()

print('args', args)

print("********************************************************************")
print("PID: %s" % str(id_process))
print("Time: %s" % time_current)
print("epochs: %s" % str(args.epochs))
print("learning_rate: %s" % str(args.learning_rate))
print("seed: %s" % str(args.seed))
print("Models: %s" % str(model_names))
print("Datasets: %s" % str(dataset_names))
print("********************************************************************")

print("####################################################################")
np.random.seed(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
print("Generating Datasets\n")
generate_dataset()
generate_twitter_dataset(twitter_dataset_names)
print("####################################################################")

# ---------------------------------------------------------------------------
# Run every selected model on every selected dataset
# ---------------------------------------------------------------------------
results = dict()
for dataset_name in dataset_names:
    print("Processing", dataset_name, "Datasets\n")
    args.current_dataset = dataset_name
    if dataset_name == 'Trump':
        args.comp_enc_len = 25
    if automate_bin_sz:
        if dataset_name in ['Trump', 'sin']:
            args.bin_size = utils.get_optimal_bin_size(dataset_name)
        else:
            args.bin_size = utils.find_best_bin_size(dataset_name)
        print('New bin size is', args.bin_size, 'sec')
    dataset = utils.get_processed_data(dataset_name, args)

    # True counts; kept for the pending count-prediction plots (TODO).
    per_model_count = {'true': dataset['count_test_out_counts']}

    # Trained models, handed to later models through ``prev_models``.
    per_model_save = dict.fromkeys([
        'wgan',
        'seq2seq',
        'transformer',
        'count_model',
        'hierarchical',
        'rmtpp_mse',
        'rmtpp_nll',
        'rmtpp_mse_var',
        'inference_models',
    ])

    for model_name in model_names:
        print("--------------------------------------------------------------------")
        args.current_model = model_name
        print("Running", model_name, "Model\n")

        model, count_dist_params, rmtpp_var_model, results = run.run_model(
            dataset_name,
            model_name,
            dataset,
            args,
            results,
            prev_models=per_model_save,
            run_model_flags=run_model_flags,
        )
        per_model_save[model_name] = model

        print('Got result', 'for model', model_name, 'on dataset', dataset_name)

    print("####################################################################")

    # NOTE(review): reports are written once per dataset because the file
    # names embed dataset_name; the original indentation of this section is
    # not recoverable, so confirm that placement.
    _write_text_report(
        os.path.join(args.output_dir, 'results_'+dataset_name+'.txt'),
        results)
    _write_json_report(
        os.path.join(args.output_dir, 'results_'+dataset_name+'.json'),
        results)
class kernel_sequential_pow(kernel_sequential):
    """Sequential (event-by-event) evaluation of the power-law kernel.

    The kernel is represented as a weighted sum of exponentials with decay
    rates ``phi`` and weights ``H`` taken on a fixed grid.  ``g`` holds one
    state per decay rate; ``l`` is the current kernel intensity ``g.dot(H)``
    and ``Int`` its integral over the most recent step.
    In ``'estimation'`` mode the derivatives of ``l`` with respect to the
    parameters ``k``, ``p`` and ``c`` are tracked in ``dl`` as well.
    """

    def __init__(self, para, mode='simulation'):
        self.mode = mode
        k, p, c = para['k'], para['p'], para['c']

        # Grid s in [-9, 9] with spacing delta, mapped to decay rates
        # phi = exp(s - exp(-s)); log_dphi is log(d phi / d s).
        num_div = 16
        delta = 1.0/num_div
        s = np.linspace(-9, 9, num_div*18+1)
        log_phi = s - np.exp(-s)
        log_dphi = log_phi + np.log(1 + np.exp(-s))
        phi = np.exp(log_phi)
        H = delta * k * np.exp(log_dphi + (p-1)*log_phi - c*phi) / gamma(p)

        self.g = np.zeros_like(s)
        self.l = 0
        self.Int = 0
        self.phi = phi
        self.H = H

        if self.mode == 'estimation':
            # Derivatives of the weights H with respect to k, p and c.
            self.H_k = delta * np.exp(log_dphi + (p-1)*log_phi - c*phi) / gamma(p)
            self.H_p = (delta * k * np.exp(log_dphi + (p-1)*log_phi - c*phi)
                        / gamma(p) * (log_phi - digamma(p)))
            self.H_c = (delta * k * np.exp(log_dphi + p*log_phi - c*phi)
                        / gamma(p) * (-1))
            self.dl = {'k': 0, 'p': 0, 'c': 0}

    def step_forward(self, step):
        """Advance time by ``step`` with no event; update ``Int``, ``g``, ``l``."""
        g, phi, H = self.g, self.phi, self.H

        # (1 - exp(-phi*step)) / phi, replaced by its limit ``step`` where
        # phi*step is tiny.  np.divide leaves the masked entries unset, but
        # np.where never selects them.
        index = phi*step < 1e-6
        v = np.where(index, step,
                     np.divide((1 - np.exp(-phi*step)), phi, where=~index))
        Int = (g*v).dot(H)

        g = g*np.exp(-phi*step)
        self.g = g
        self.l = g.dot(H)
        self.Int = Int

        if self.mode == 'estimation':
            self.dl = {'k': g.dot(self.H_k),
                       'p': g.dot(self.H_p),
                       'c': g.dot(self.H_c)}

        return self

    def event(self):
        """Register an event at the current time: every state jumps by 1."""
        g = self.g + 1.0
        self.g = g
        self.l = g.dot(self.H)
        return self


###############################
class kernel_nonpara(base_component_kernel_class):
    """Piecewise-constant (histogram) kernel on ``[0, support]``.

    The only parameter is ``g``, one value per bin; the kernel is 0 at and
    beyond ``support``.
    """

    def __init__(self, support, num_bin):
        self.type = 'nonpara'
        self.support = support
        self.num_bin = num_bin
        self.bin_edge = np.linspace(0, support, num_bin+1)
        self.bin_width = support/num_bin
        self.para_list = ['g']
        self.has_sequential = False

    def set_data(self, Data, itv):
        """Store the data and precompute the per-bin terms ``dl`` and ``dInt``."""
        self.Data = Data
        self.itv = itv
        # NOTE(review): this name is only bound when the Cython import at the
        # top of the file succeeds -- confirm a fallback exists otherwise.
        dl, dInt = preprocess_data_nonpara_cython(Data['T'], self.bin_edge, itv[1])
        self.dl = dl
        self.dInt = dInt
        return self

    def prep_fit(self):
        """Return the optimizer settings for the parameter ``g``."""
        num_bin = self.num_bin
        support = self.support
        para_names = ['g']  # was a local called ``list``, shadowing the builtin
        length = {'g': num_bin}
        exp = {'g': True}
        ini = {'g': np.ones(num_bin)*0.5/support}
        step_Q = {'g': np.ones(num_bin)*0.2}
        step_diff = {'g': np.ones(num_bin)*0.01}
        return {"list": para_names, 'length': length, 'exp': exp, 'ini': ini,
                'step_Q': step_Q, 'step_diff': step_diff}

    def LG_SUM(self):
        """Return ``[l, dl]``: ``g . self.dl`` and its gradient in ``g``."""
        g = self.para['g']
        l = np.dot(g, self.dl)
        return [l, {'g': self.dl}]

    def LG_INT(self):
        """Return ``[Int, dInt]``: ``g . self.dInt`` and its gradient in ``g``."""
        g = self.para['g']
        Int = g.dot(self.dInt)
        return [Int, {'g': self.dInt}]

    def func(self, x):
        """Kernel value at ``x`` (scalar or array)."""
        # The extra trailing 0 is the value returned for x >= support.
        g_ext = np.hstack([self.para['g'], 0])
        return g_ext[np.searchsorted(self.bin_edge, x, side='right') - 1]

    def d_func(self, x):
        """Gradient of ``func`` in ``g``: one-hot bin indicator per entry of ``x``."""
        bin_edge = self.bin_edge
        dl = np.zeros((bin_edge.shape[0], x.shape[0]))
        dl[np.searchsorted(bin_edge, x, side='right')-1, np.arange(x.shape[0])] = 1.0
        # The last row is the "beyond support" slot, which has no parameter.
        return {'g': dl[:-1]}

    def int(self, x1, x2):
        """Integral of the kernel from ``x1`` to ``x2``."""
        bin_edge = self.bin_edge
        bin_width = self.bin_width
        g = self.para['g']
        g_ext = np.hstack([g, 0])
        # cum[i] is the integral from 0 to bin_edge[i].
        cum = np.hstack([0, g.cumsum()*bin_width])
        index1 = np.searchsorted(bin_edge, x1, side='right') - 1
        index2 = np.searchsorted(bin_edge, x2, side='right') - 1
        int1 = cum[index1] + (x1-bin_edge[index1]) * g_ext[index1]
        int2 = cum[index2] + (x2-bin_edge[index2]) * g_ext[index2]
        return int2-int1

    def d_int(self, x1, x2):
        """Gradient of ``int`` in ``g`` for arrays ``x1`` and ``x2``."""
        bin_edge = self.bin_edge
        bin_width = self.bin_width
        num_edge = bin_edge.shape[0]

        def cumulative_gradient(x):
            # Column j: full bins below x[j] contribute bin_width, the bin
            # containing x[j] contributes the covered part, the rest 0.
            indices = np.searchsorted(bin_edge, x, side='right') - 1
            rows = [
                np.hstack([np.ones(index)*bin_width,
                           x[i]-bin_edge[index],
                           np.zeros(num_edge-index-1)])
                for i, index in enumerate(indices)
            ]
            return np.vstack(rows).transpose()

        dInt1 = cumulative_gradient(x1)
        dInt2 = cumulative_gradient(x2)

        return {'g': (dInt2 - dInt1)[:-1]}

    def branching_ratio(self):
        """Total mass of the kernel: ``sum(g) * bin_width``."""
        return self.para['g'].sum()*self.bin_width

    def plot(self):
        """Draw the kernel as a step function on the current axes."""
        bin_edge = self.bin_edge
        x = np.vstack([bin_edge[:-1], bin_edge[1:]]).transpose().flatten()
        y = np.repeat(self.para['g'], 2)
        plt.plot(x, y, 'k-')


###########################################################################################
###########################################################################################
## graph routine
###########################################################################################
###########################################################################################
def _start_figure(figsize):
    """Open a new figure and apply the font settings shared by all plots."""
    plt.figure(figsize=figsize, dpi=100)
    mpl.rc('font', size=12, family='DejaVu Sans')
    mpl.rc('axes', titlesize=12)
    mpl.rc('pdf', fonttype=42)


def _hide_spines(*sides):
    """Hide the given spines of the current axes."""
    for side in sides:
        plt.gca().spines[side].set_visible(False)


def _plot_event_raster(T, xlim):
    """Draw one vertical tick per event time on the current axes."""
    times = np.asarray(T, dtype=float).ravel()
    n = times.shape[0]
    # (t, 0) -> (t, 1) -> NaN break for every event.  Built without
    # np.hstack over a Python list so an empty T works, and with np.nan
    # because the np.NaN alias was removed in NumPy 2.0.
    x = np.column_stack([times, times, np.full(n, np.nan)]).ravel()
    y = np.tile([0.0, 1.0, np.nan], n)
    plt.plot(x, y, 'k-', linewidth=0.5)
    plt.xticks([])
    plt.xlim(xlim)
    plt.ylim([0, 1])
    plt.yticks([])
    _hide_spines('top', 'right', 'left')


def plot_N(T, itv):
    """Plot the event raster and the counting process N over ``itv``."""
    gs = gridspec.GridSpec(100, 1)
    _start_figure((4, 5))

    [st, en] = itv
    n = len(T)
    x = np.hstack([st, np.repeat(T, 2), en])
    y = np.repeat(np.arange(n+1), 2)

    plt.subplot(gs[0:10, 0])
    _plot_event_raster(T, itv)

    plt.subplot(gs[15:100, 0])
    plt.plot(x, y, 'k-', clip_on=False)
    plt.xlim(itv)
    plt.ylim([0, n])
    plt.xlabel('time')
    plt.ylabel(r'$N(0,t)$')
    _hide_spines('top', 'right')


def plot_l(T, x, l, l_baseline):
    """Plot the event raster with the intensity ``l`` and baseline over ``x``."""
    gs = gridspec.GridSpec(100, 1)
    _start_figure((4, 5))

    l_max = l.max()

    plt.subplot(gs[0:10, 0])
    _plot_event_raster(T, [x[0], x[-1]])

    plt.subplot(gs[15:100, 0])
    plt.plot(x, l, 'k-', lw=1)
    plt.plot(x, l_baseline, 'k:', lw=1)
    plt.xlim([x[0], x[-1]])
    plt.ylim([0, l_max])
    plt.xlabel('time')
    plt.ylabel(r'$\lambda(t|H_t)$')
    _hide_spines('top', 'right')


def plot_N_pred(T, T_pred, itv, en_f):
    """Plot observed counts on ``itv`` and predicted paths up to ``en_f``.

    ``T_pred`` is a sequence of predicted event-time sequences; each one is
    drawn as a separate counting path starting at the observed count.
    """
    gs = gridspec.GridSpec(100, 1)
    _start_figure((4, 5))

    [st, en] = itv
    n = len(T)
    x = np.hstack([st, np.repeat(T, 2), en])
    y = np.repeat(np.arange(n+1), 2)
    # default=0 keeps an empty T_pred from raising.
    n_pred_max = max((len(T_i) for T_i in T_pred), default=0)

    plt.subplot(gs[0:10, 0])
    _plot_event_raster(T, [itv[0], en_f])

    plt.subplot(gs[15:, 0])
    plt.plot(x, y, 'k-')
    plt.plot([en, en], [0, n+n_pred_max], 'k--')

    for T_i in T_pred:
        n_pred = len(T_i)
        x = np.hstack([en, np.repeat(T_i, 2), en_f])
        y = np.repeat(np.arange(n_pred+1), 2) + n
        plt.plot(x, y, '-', color=[0.7, 0.7, 1.0], lw=0.5)

    plt.xlim([st, en_f])
    plt.ylim([0, n+n_pred_max])
    plt.xlabel('time')
    plt.ylabel(r'$N(0,t)$')
    _hide_spines('top', 'right')


def plot_KS(T_trans, itv_trans):
    """Plot the empirical CDF of the transformed times with a KS band.

    The title shows the p-value of a Kolmogorov-Smirnov test of
    ``T_trans / itv_trans[1]`` against the uniform distribution.
    """
    from scipy.stats import kstest

    _start_figure((4, 4))

    # Accept plain sequences as well as arrays for the division below.
    T_trans = np.asarray(T_trans, dtype=float)
    n = len(T_trans)
    [st, en] = itv_trans
    x = np.hstack([st, np.repeat(T_trans, 2), en])
    y = np.repeat(np.arange(n+1), 2)/n
    w = 1.36/np.sqrt(n)  # half-width of the shaded band
    [_, pvalue] = kstest(T_trans/itv_trans[1], 'uniform')

    plt.plot(x, y, "k-", label='Data')
    plt.fill_between([0, n*w, n*(1-w), n], [0, 0, 1-2*w, 1-w], [w, 2*w, 1, 1],
                     color="#dddddd", label='95% interval')
    plt.xlim([0, n])
    plt.ylim([0, 1])
    plt.ylabel("cumulative distribution function")
    plt.xlabel("transfunced time")
    plt.title("p-value = %.3f" % pvalue)
    plt.legend(loc="upper left")