├── modules
├── __init__.py
└── Hawkes
│ ├── __init__.py
│ ├── tools
│ ├── __init__.py
│ ├── BasisFunction.py
│ └── Quasi_Newton.py
│ ├── Hawkes_C.pyx
│ └── model.py
├── transformer_helpers
├── Constants.py
├── Modules.py
├── Layers.py
└── SubLayers.py
├── .gitignore
├── DualTPP_diagram.png
├── saved_models
├── .DS_Store
├── training_wgan_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_count_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_count_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_rmtpp_mse_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_seq2seq_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_wgan_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_count_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_rmtpp_mse_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_rmtpp_nll_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_seq2seq_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_seq2seq_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_transformer_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_transformer_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_wgan_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_count_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_seq2seq_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_transformer_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_wgan_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_rmtpp_nll_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_transformer_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_count_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_wgan_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_nll_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_seq2seq_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_transformer_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
└── training_rmtpp_mse_var_comp_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── saved_models_15thaug_epochs10
├── .DS_Store
├── training_count_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_seq2seq_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_wgan_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_wgan_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_count_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_count_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_rmtpp_mse_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_rmtpp_mse_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_rmtpp_nll_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_seq2seq_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_transformer_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_wgan_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_count_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_seq2seq_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_transformer_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_transformer_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_wgan_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_sin
│ ├── checkpoint
│ ├── cp_sin.ckpt.index
│ └── cp_sin.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_seq2seq_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_transformer_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_count_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_comp_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_seq2seq_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_wgan_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_var_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_transformer_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_rmtpp_mse_comp_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
└── training_rmtpp_mse_var_comp_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── requirements.txt
├── saved_models_16thaug_baselines
├── training_seq2seq_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_wgan_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_wgan_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_seq2seq_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_transformer_taxi
│ ├── checkpoint
│ ├── cp_taxi.ckpt.index
│ └── cp_taxi.ckpt.data-00000-of-00001
├── training_wgan_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_seq2seq_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_transformer_Trump
│ ├── checkpoint
│ ├── cp_Trump.ckpt.index
│ └── cp_Trump.ckpt.data-00000-of-00001
├── training_transformer_911_ems
│ ├── checkpoint
│ ├── cp_911_ems.ckpt.index
│ └── cp_911_ems.ckpt.data-00000-of-00001
├── training_seq2seq_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── training_wgan_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
└── training_transformer_911_traffic
│ ├── checkpoint
│ ├── cp_911_traffic.ckpt.index
│ └── cp_911_traffic.ckpt.data-00000-of-00001
├── script.sh
├── .idea
└── vcs.xml
├── README.md
├── transformer_utils.py
├── generator.py
└── main.py
/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/transformer_helpers/Constants.py:
--------------------------------------------------------------------------------
1 | PAD = 0.
2 |
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | pp_seq2seq/nmt/*.pyc
3 | pp_seq2seq/nmt/data/*
4 |
--------------------------------------------------------------------------------
/modules/Hawkes/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import simulator
2 | from .model import estimator
3 |
--------------------------------------------------------------------------------
/DualTPP_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/DualTPP_diagram.png
--------------------------------------------------------------------------------
/saved_models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/.DS_Store
--------------------------------------------------------------------------------
/saved_models/training_wgan_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_wgan_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_transformer_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_transformer_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_wgan_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/.DS_Store
--------------------------------------------------------------------------------
/saved_models/training_count_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_transformer_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_wgan_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_transformer_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk==3.2.4
2 | tensorflow-addons==0.9.1
3 | tensorflow==2.1.0
4 | torch==1.7.1+cu101
5 | properscoring
6 | cvxpy
7 | pandas==0.25.3
8 |
--------------------------------------------------------------------------------
/saved_models/training_count_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_wgan_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_wgan_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_count_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_wgan_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_transformer_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_wgan_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_sin.ckpt"
2 | all_model_checkpoint_paths: "cp_sin.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_taxi.ckpt"
2 | all_model_checkpoint_paths: "cp_taxi.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_transformer_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_wgan_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_Trump.ckpt"
2 | all_model_checkpoint_paths: "cp_Trump.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_transformer_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_ems.ckpt"
2 | all_model_checkpoint_paths: "cp_911_ems.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_transformer_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_transformer_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "cp_911_traffic.ckpt"
2 | all_model_checkpoint_paths: "cp_911_traffic.ckpt"
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python main.py 911_traffic all --output_dir Outputs_ignore_post --saved_models saved_models_15thaug_epochs10 --out_bin_sz 3 --no_rescale_rmtpp_params
4 |
--------------------------------------------------------------------------------
/modules/Hawkes/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .Quasi_Newton import Quasi_Newton,merge_stg
2 | from .BasisFunction import linear_COS, loglinear_COS, linear_CBS, loglinear_CBS, plinear, linear_SSM
3 |
--------------------------------------------------------------------------------
/saved_models/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_nll_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_nll_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_sin/cp_sin.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_count_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.index
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_taxi/cp_taxi.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_wgan_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_Trump/cp_Trump.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_ems/cp_911_ems.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_seq2seq_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_16thaug_baselines/training_transformer_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pratham16cse/DualTPP/HEAD/saved_models_15thaug_epochs10/training_rmtpp_mse_var_comp_911_traffic/cp_911_traffic.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/transformer_helpers/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import tensorflow as tf
5 |
6 |
7 | class ScaledDotProductAttention(tf.keras.Model):
8 | """ Scaled Dot-Product Attention """
9 |
10 | def __init__(
11 | self, temperature, attn_dropout=0.2,
12 | name='ScaledDotProductAttention', **kwargs):
13 | super(ScaledDotProductAttention, self).__init__(name=name, **kwargs)
14 |
15 | self.temperature = temperature
16 | #self.dropout = nn.Dropout(attn_dropout)
17 | self.dropout = tf.keras.layers.Dropout(attn_dropout)
18 |
19 | def call(self, q, k, v, mask=None):
20 | #attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
21 | attn = tf.matmul(q / self.temperature, tf.transpose(k, perm=[0, 1, 3, 2]))
22 |
23 | if mask is not None:
24 | #attn = attn.masked_fill(mask, -1e9)
25 | attn = tf.where(mask, attn, -1e9)
26 |
27 | attn = self.dropout(tf.nn.softmax(attn, axis=-1))
28 | #output = torch.matmul(attn, v)
29 | output = tf.matmul(attn, v)
30 |
31 | return output, attn
32 |
--------------------------------------------------------------------------------
/transformer_helpers/Layers.py:
--------------------------------------------------------------------------------
1 | #import torch.nn as nn
2 | import tensorflow as tf
3 | from transformer_helpers.SubLayers import MultiHeadAttention, PositionwiseFeedForward
4 |
5 |
6 | class EncoderLayer(tf.keras.Model):
7 | """ Compose with two layers """
8 |
9 | def __init__(
10 | self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True,
11 | name='EncoderLayer', **kwargs):
12 | super(EncoderLayer, self).__init__(name=name, **kwargs)
13 | self.slf_attn = MultiHeadAttention(
14 | n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before)
15 | self.pos_ffn = PositionwiseFeedForward(
16 | d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
17 |
18 | def call(self, enc_input, feats, non_pad_mask=None, slf_attn_mask=None):
19 | enc_output, enc_slf_attn = self.slf_attn(
20 | enc_input, enc_input, enc_input, mask=slf_attn_mask)
21 | non_pad_mask = tf.cast(tf.expand_dims(non_pad_mask, axis=-1), tf.float32)
22 |
23 | #enc_output = tf.concat([enc_output, feats], axis=-1)
24 | #enc_output *= non_pad_mask
25 |
26 | enc_output = self.pos_ffn(enc_output)
27 | #enc_output *= non_pad_mask
28 |
29 | return enc_output, enc_slf_attn
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Long Horizon Forecasting With Temporal Point Processes
2 |
3 | 
4 |
5 | This is the code produced as part of the paper _Long Horizon Forecasting With Temporal Point Processes_
6 |
7 | > "Long Horizon Forecasting With Temporal Point Processes"
8 | > Prathamesh Deshpande, Kamlesh Marathe, Abir De, Sunita Sarawagi. WSDM 2021. [arXiv:2101.02815](https://arxiv.org/abs/2101.02815)
9 |
10 | ## Packages needed
11 | Specified in [requirements](requirements.txt).
12 |
13 | ## Dataset Download
14 | We have provided all the datasets used in our experiments [here](https://drive.google.com/drive/folders/1b1KUwkeIqIViPZoRZzbPAzKeNn7P1OD-?usp=sharing).
15 |
16 | Please download the `data/` folder add place it in the [DualTPP](https://github.com/pratham16cse/DualTPP) directory.
17 |
18 | ## Experiment execution
19 | To run the code to reproduce the results, please use this [script](script.sh) \[ Under development, more datasets will be soon added to the script\].
20 |
21 | ## Output
22 | All the outputs will be stored in the `` directory.
23 |
24 | The numbers reported in Table 2 of the [paper](https://arxiv.org/abs/2101.02815) will be stored in `output_dir/results_.json` and `output_dir/results_.txt` files.
25 |
26 | ## Parameters Description
27 | Under Development
28 |
29 | ## Contact
30 | For any queries related to library versions, datasets, script, and results please contact us here:
31 |
32 | Email: prathameshsdeshpande@gmail.com
33 |
34 | Whatsapp: +91 9043751980
35 |
--------------------------------------------------------------------------------
/transformer_helpers/SubLayers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import transformer_helpers.Constants as Constants
7 | from transformer_helpers.Modules import ScaledDotProductAttention
8 |
9 | import tensorflow as tf
10 | from tensorflow import keras
11 | from tensorflow.keras import layers
12 | import tensorflow_probability as tfp
13 | import tensorflow_addons as tfa
14 |
15 |
16 | class MultiHeadAttention(tf.keras.Model):
17 | """ Multi-Head Attention module """
18 |
19 | def __init__(
20 | self, n_head, d_model, d_k, d_v, dropout=0.1, normalize_before=True,
21 | name='MultiHeadAttention', **kwargs):
22 | super(MultiHeadAttention, self).__init__(name=name, **kwargs)
23 |
24 | self.normalize_before = normalize_before
25 | self.n_head = n_head
26 | self.d_k = d_k
27 | self.d_v = d_v
28 |
29 | self.w_qs = layers.Dense(n_head * d_k, use_bias=False)
30 | self.w_ks = layers.Dense(n_head * d_k, use_bias=False)
31 | self.w_vs = layers.Dense(n_head * d_v, use_bias=False)
32 | # Note: layers.Dense uses xavier_uniform_ initialization by default
33 | #nn.init.xavier_uniform_(self.w_qs.weight)
34 | #nn.init.xavier_uniform_(self.w_ks.weight)
35 | #nn.init.xavier_uniform_(self.w_vs.weight)
36 |
37 | self.fc = layers.Dense(d_model)
38 | #nn.init.xavier_uniform_(self.fc.weight)
39 |
40 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
41 |
42 | #self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
43 | self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
44 | #self.dropout = nn.Dropout(dropout)
45 | self.dropout = tf.keras.layers.Dropout(dropout)
46 |
47 | def call(self, q, k, v, mask=None):
48 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
49 | sz_b, len_q, len_k, len_v = q.shape[0], q.shape[1], k.shape[1], v.shape[1]
50 |
51 | residual = q
52 | if self.normalize_before:
53 | q = self.layer_norm(q)
54 |
55 | # Pass through the pre-attention projection: b x lq x (n*dv)
56 | # Separate different heads: b x lq x n x dv
57 | q = tf.reshape(self.w_qs(q), [sz_b, len_q, n_head, d_k])
58 | k = tf.reshape(self.w_ks(k), [sz_b, len_k, n_head, d_k])
59 | v = tf.reshape(self.w_vs(v), [sz_b, len_v, n_head, d_v])
60 |
61 | # Transpose for attention dot product: b x n x lq x dv
62 | #q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
63 | q = tf.transpose(q, perm=[0, 2, 1, 3])
64 | k = tf.transpose(k, perm=[0, 2, 1, 3])
65 | v = tf.transpose(v, perm=[0, 2, 1, 3])
66 |
67 | if mask is not None:
68 | mask = tf.expand_dims(mask, axis=1) # For head axis broadcasting.
69 |
70 | output, attn = self.attention(q, k, v, mask=mask)
71 |
72 | # Transpose to move the head dimension back: b x lq x n x dv
73 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
74 | #output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
75 | output = tf.reshape(tf.transpose(output, perm=[0, 2, 1, 3]), [sz_b, len_q, -1])
76 | output = self.dropout(self.fc(output))
77 | output += residual
78 |
79 | if not self.normalize_before:
80 | output = self.layer_norm(output)
81 | return output, attn
82 |
83 |
84 | class PositionwiseFeedForward(tf.keras.Model):
85 | """ Two-layer position-wise feed-forward neural network. """
86 |
87 | def __init__(
88 | self, d_in, d_hid, dropout=0.1, normalize_before=True,
89 | name='PositionwiseFeedForward', **kwargs):
90 | super(PositionwiseFeedForward, self).__init__(name=name, **kwargs)
91 |
92 | self.normalize_before = normalize_before
93 |
94 | self.w_1 = layers.Dense(d_hid) # position-wise
95 | self.w_2 = layers.Dense(d_in) # position-wise
96 |
97 | #self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
98 | self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
99 | #self.dropout = nn.Dropout(dropout)
100 | self.dropout = tf.keras.layers.Dropout(dropout)
101 |
102 |
103 | def call(self, x):
104 | residual = x
105 | if self.normalize_before:
106 | x = self.layer_norm(x)
107 |
108 | x = tfa.activations.gelu(self.w_1(x))
109 | x = self.dropout(x)
110 | x = self.w_2(x)
111 | x = self.dropout(x)
112 | x = x + residual
113 |
114 | if not self.normalize_before:
115 | x = self.layer_norm(x)
116 | return x
117 |
--------------------------------------------------------------------------------
/modules/Hawkes/Hawkes_C.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import gamma,digamma
3 | cimport numpy as np
4 |
5 | def LG_kernel_SUM_exp_cython(np.ndarray[np.float64_t,ndim=1] T, np.ndarray[np.float64_t,ndim=1] alpha, np.ndarray[np.float64_t,ndim=1] beta):
6 | cdef int m = len(alpha)
7 | cdef int n = T.shape[0]
8 | cdef np.ndarray[np.float64_t,ndim=1] l = np.zeros(n, dtype=np.float64)
9 | cdef np.ndarray[np.float64_t,ndim=1] l_i = np.zeros(n, dtype=np.float64)
10 | cdef np.ndarray[np.float64_t,ndim=1] dl_a = np.zeros(n, dtype=np.float64)
11 | cdef np.ndarray[np.float64_t,ndim=1] dl_b = np.zeros(n, dtype=np.float64)
12 | cdef dict dl = {}
13 |
14 | for i in range(m):
15 | l_i,dl_a,dl_b = LG_kernel_SUM_exp_i_cython(T,alpha[i],beta[i])
16 | l = l + l_i
17 | dl.update({('alpha',i):dl_a,('beta',i):dl_b})
18 |
19 | return [l,dl]
20 |
21 |
def LG_kernel_SUM_exp_i_cython(np.ndarray[np.float64_t,ndim=1] T, double alpha, double beta):
    """Excitation from one exponential kernel and its parameter derivatives.

    For the kernel alpha*beta*exp(-beta*t) this computes, for every event i,
        l[i] = sum_{j<i} alpha*beta*exp(-beta*(T[i]-T[j]))
    plus dl/dalpha and dl/dbeta. An O(n) recursion carries the running sums
    from one event to the next, so l[0] = 0.

    Returns [l, dl_a, dl_b], three 1-D float64 arrays of length len(T).
    """
    cdef int n = T.shape[0]

    cdef np.ndarray[np.float64_t,ndim=1] l = np.zeros(n, dtype=np.float64)
    cdef np.ndarray[np.float64_t,ndim=1] dl_a = np.zeros(n, dtype=np.float64)
    cdef np.ndarray[np.float64_t,ndim=1] dl_b = np.zeros(n, dtype=np.float64)

    cdef np.ndarray[np.float64_t,ndim=1] gap = T[1:] - T[:-1]
    cdef np.ndarray[np.float64_t,ndim=1] decay = np.exp(-beta*gap)

    cdef double acc = 0.0      # excitation at the current event
    cdef double acc_a = 0.0    # its derivative w.r.t. alpha
    cdef double acc_b = 0.0    # its derivative w.r.t. beta
    cdef int i

    for i in range(1, n):
        # add the previous event's jump, then decay across one gap
        acc = (acc + alpha*beta) * decay[i-1]
        acc_a = (acc_a + beta) * decay[i-1]
        # product rule on (acc_prev + alpha*beta)*exp(-beta*gap); uses the new acc
        acc_b = (acc_b + alpha) * decay[i-1] - acc*gap[i-1]

        l[i] = acc
        dl_a[i] = acc_a
        dl_b[i] = acc_b

    return [l, dl_a, dl_b]
49 |
50 |
def LG_kernel_SUM_pow_cython(np.ndarray[np.float64_t,ndim=1] T,double k, double p, double c):
    """Excitation of a power-law kernel via a quadrature mixture of exponentials.

    Because integral_0^inf phi**(p-1) * exp(-(t+c)*phi) dphi = Gamma(p)/(t+c)**p,
    the weights H below represent the kernel k*(t+c)**(-p) as a sum over decay
    rates phi. The rates come from the substitution phi = exp(s - exp(-s)) on
    an equally spaced s-grid over [-9, 9] with spacing 1/16. Each rate reuses
    the O(n) exponential recursion.

    Returns [l, dl_k, dl_p, dl_c]: the excitation at each event and its
    derivatives w.r.t. k, p and c (1-D float64 arrays of length len(T)).
    """
    cdef int n = T.shape[0]

    cdef np.ndarray[np.float64_t,ndim=1] l = np.zeros(n, dtype=np.float64)
    cdef np.ndarray[np.float64_t,ndim=1] dl_p = np.zeros(n, dtype=np.float64)
    cdef np.ndarray[np.float64_t,ndim=1] dl_k = np.zeros(n, dtype=np.float64)
    cdef np.ndarray[np.float64_t,ndim=1] dl_c = np.zeros(n, dtype=np.float64)

    # quadrature nodes: phi(s) = exp(s - exp(-s)), dphi/ds = phi*(1 + exp(-s))
    cdef int num_div = 16
    cdef double delta = 1.0/num_div
    cdef np.ndarray[np.float64_t,ndim=1] s = np.linspace(-9,9,num_div*18+1)
    cdef np.ndarray[np.float64_t,ndim=1] log_phi = s-np.exp(-s)
    cdef np.ndarray[np.float64_t,ndim=1] log_dphi = log_phi + np.log(1+np.exp(-s))
    cdef np.ndarray[np.float64_t,ndim=1] phi = np.exp(log_phi)

    # quadrature weights of the kernel and of its p- and c-derivatives
    cdef np.ndarray[np.float64_t,ndim=1] H = delta * k * np.exp( log_dphi + (p-1)*log_phi - c*phi ) / gamma(p)
    cdef np.ndarray[np.float64_t,ndim=1] H_p = H * (log_phi-digamma(p))
    cdef np.ndarray[np.float64_t,ndim=1] H_c = delta * k * np.exp( log_dphi + p*log_phi - c*phi ) / gamma(p) * (-1)

    # g[q] = sum over past events of exp(-phi[q] * lag)
    cdef np.ndarray[np.float64_t,ndim=1] g = np.zeros_like(s)
    cdef int i

    for i in range(1, n):
        g = (g+1)*np.exp( - phi*(T[i]-T[i-1]) )
        l[i] = g.dot(H)
        dl_k[i] = l[i]/k        # l is linear in k
        dl_p[i] = g.dot(H_p)
        dl_c[i] = g.dot(H_c)

    return [l,dl_k,dl_p,dl_c]
84 |
def preprocess_data_nonpara_cython(np.ndarray[np.float64_t,ndim=1] T, np.ndarray[np.float64_t,ndim=1] bin_edge, double en):
    """Precompute statistics for a histogram (piecewise-constant) kernel.

    Parameters
    ----------
    T : event times. The backward scan stops at the first lag >= support,
        so T is assumed to be sorted ascending.
    bin_edge : the m+1 kernel bin edges. The width of bin 0 is used for every
        fully covered bin, so bins are assumed to be equally spaced.
    en : end of the observation window.

    Returns
    -------
    [dl, dInt]
        dl   : (m, n) array. dl[b, i] counts earlier events j whose lag
               T[i] - T[j] (< bin_edge[-1]) falls in bin b.
        dInt : (m,) array. For each bin, the length of that bin covered by
               [bin_edge[0], en - T[i]], summed over events i.
    """
    cdef double support = bin_edge[-1]
    cdef double bin_width = bin_edge[1] - bin_edge[0]
    cdef int i,j
    cdef int n = T.shape[0]
    cdef int m = bin_edge.shape[0] - 1 # the number of bins

    ###### dl: (target i, trigger j) pairs whose lag lies inside the support
    cdef list index_tgt_list = []
    cdef list index_trg_list = []

    for i in range(n):
        for j in range(i-1,-1,-1):
            if T[i] - T[j] < support:
                index_tgt_list.append(i)
                index_trg_list.append(j)
            else:
                break

    # Force int64 explicitly. np.array([]) is float64 when there are no pairs,
    # and default ints / searchsorted results are 32-bit on some platforms.
    # Either would fail the int64 buffer type check.
    cdef np.ndarray[np.int64_t,ndim=1] index_tgt = np.array(index_tgt_list, dtype=np.int64)
    cdef np.ndarray[np.int64_t,ndim=1] index_trg = np.array(index_trg_list, dtype=np.int64)
    cdef np.ndarray[np.int64_t,ndim=1] index_bin = (np.searchsorted(bin_edge,T[index_tgt]-T[index_trg],side='right') - 1).astype(np.int64)

    cdef np.ndarray[np.float64_t,ndim=2] dl = np.zeros((m,n))

    for i in range(index_tgt.shape[0]):
        dl[index_bin[i],index_tgt[i]] += 1.0

    ###### dInt: coverage of each bin by [bin_edge[0], en - T[i]]
    cdef np.ndarray[np.float64_t,ndim=2] dInt = np.zeros((m,n))
    cdef np.ndarray[np.int64_t,ndim=1] index = (np.searchsorted(bin_edge,en-T,side='right') - 1).astype(np.int64)
    cdef np.ndarray[np.float64_t,ndim=1] d_from_left = en - T - bin_edge[index]
    cdef int index_i

    for i in range(n):
        index_i = index[i]
        if index_i < 0:
            # en - T[i] lies left of bin_edge[0] (e.g. an event after en when
            # bin_edge[0] == 0), so the event covers no bin. Previously index -1
            # wrapped around and wrote into the last bin.
            continue
        if index_i < m:
            # partially covered bin containing en - T[i]
            dInt[index_i,i] = d_from_left[i]
        # bins entirely left of en - T[i] are fully covered
        for j in range(index_i):
            dInt[j,i] = bin_width

    return [dl,dInt.sum(axis=-1)]
130 |
--------------------------------------------------------------------------------
/transformer_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import tensorflow as tf
7 | from tensorflow import keras
8 |
9 | #from transformer_helpers.Models import get_non_pad_mask
10 | from models import get_non_pad_mask
11 | import transformer_helpers.Constants as Constants
12 |
13 |
def compute_event(event, non_pad_mask):
    """ Log-likelihood of events: log of the intensity at each event.

    Positions where ``non_pad_mask == Constants.PAD`` are treated as padding.
    Their intensity is replaced by 1 before the log, so they contribute 0.

    Args:
        event: intensity of the observed event type at each position.
        non_pad_mask: mask broadcastable against ``event``. Presumably 1.0 at
            real events and Constants.PAD (0) at padding - TODO confirm.

    Returns:
        Tensor of log intensities, 0 at padding positions.
    """
    is_pad = tf.cast(non_pad_mask == Constants.PAD, tf.bool)
    # add 1e-9 in case some events have 0 likelihood
    event += math.pow(10, -9)
    # Mask by position, like the PyTorch `masked_fill_(~non_pad_mask, 1.0)`.
    # The previous `tf.where(event == mask, ...)` compared intensity *values*
    # with the mask, so padding was not masked.
    event = tf.where(is_pad, 1., event)

    return tf.math.log(event)
25 |
26 |
def compute_integral_biased(all_lambda, time, non_pad_mask):
    """ Log-likelihood of non-events, using linear (trapezoidal) interpolation.

    For each step k -> k+1 along axis 1 this returns
    0.5 * (lambda_k + lambda_{k+1}) * (t_{k+1} - t_k), zeroed where
    ``non_pad_mask[:, k+1]`` is 0. Inputs are 2-D (presumably batch x
    sequence); the result has one fewer column.
    """
    step_mask = non_pad_mask[:, 1:]
    dt = (time[:, 1:] - time[:, :-1]) * step_mask
    lambda_sum = (all_lambda[:, 1:] + all_lambda[:, :-1]) * step_mask
    return 0.5 * (lambda_sum * dt)
36 |
37 |
def compute_integral_unbiased(model, data, time, non_pad_mask, type_mask):
    """ Log-likelihood of non-events, using Monte Carlo integration.

    At every position this draws ``num_samples`` uniform fractions u of the
    masked interval ``time * non_pad_mask``. It averages
    softplus(h + model.alpha * u*time / (cumsum(time) + 1)) over the samples,
    where h is ``model.linear(data)`` restricted to the event type by
    ``type_mask``, and multiplies the average by the interval length.
    """
    num_samples = 100

    masked_time = time * non_pad_mask
    sample_time = tf.expand_dims(masked_time, axis=2) * \
        tf.random.uniform([*(masked_time.shape.as_list()), num_samples])
    # NOTE(review): the PyTorch original divided by (time[:, :-1] + 1); this
    # version divides by 1 + cumulative time, presumably because `time` holds
    # gaps here - confirm.
    sample_time /= tf.expand_dims((tf.cumsum(time, axis=-1) + 1), axis=2)

    hidden = tf.reduce_sum(model.linear(data) * type_mask, axis=2, keepdims=True)

    # tf.nn.softplus has no `threshold` argument (unlike F.softplus); not needed.
    sampled_lambda = tf.nn.softplus(hidden + model.alpha * sample_time)
    mean_lambda = tf.reduce_sum(sampled_lambda, axis=2) / num_samples

    return mean_lambda * masked_time
65 |
66 |
def log_likelihood(model, data, time, types):
    """ Log-likelihood of sequence.

    Returns ``(event_ll, non_event_ll)``, each summed over the last axis:
    - event_ll: compute_event applied to the softplus intensity of the
      observed type at each position.
    - non_event_ll: Monte Carlo intensity integral from
      compute_integral_unbiased.

    Types are 1-based and compared against float32 ids, so ``types`` is
    presumably a float32 tensor - TODO confirm.
    """
    non_pad_mask = get_non_pad_mask(types)

    # one-hot over event types 1..num_types (last axis), shape (1, 1, K) ids
    type_ids = tf.reshape(tf.range(1, model.num_types + 1, dtype=tf.float32), [1, 1, -1])
    type_mask = tf.cast((tf.expand_dims(types, axis=-1) == type_ids), tf.float32)

    all_lambda = tf.nn.softplus(model.linear(data))
    type_lambda = tf.reduce_sum(all_lambda * type_mask, axis=2)

    event_ll = tf.reduce_sum(compute_event(type_lambda, non_pad_mask), axis=-1)

    # non-event term: MC integration (compute_integral_biased is the
    # linear-interpolation alternative)
    non_event_ll = tf.reduce_sum(
        compute_integral_unbiased(model, data, time, non_pad_mask, type_mask),
        axis=-1)

    return event_ll, non_event_ll
99 |
100 |
def type_loss(prediction, types, loss_func):
    """ Event prediction loss, cross entropy or label smoothing.

    The prediction at position k is scored against the type at k+1. Types are
    shifted from 1-based to 0-based, which also turns padding into -1.

    Returns ``(summed loss, number of correct argmax predictions)``.
    """
    truth = types[:, 1:] - 1
    prediction = prediction[:, :-1, :]

    pred_type = tf.cast(tf.argmax(prediction, axis=-1), tf.float32)
    correct_num = tf.reduce_sum(tf.cast(pred_type == truth, tf.float32))

    # LabelSmoothingLoss.call takes (output, target); Keras losses take
    # (y_true, y_pred).
    if isinstance(loss_func, LabelSmoothingLoss):
        args = (prediction, truth)
    else:
        args = (truth, prediction)

    return tf.reduce_sum(loss_func(*args)), correct_num
123 |
124 |
def time_loss(prediction, event_time):
    """ Time prediction loss: sum of squared errors.

    ``event_time`` already holds the target gaps, so it is compared with
    ``prediction`` element-wise without shifting.
    """
    err = prediction - event_time
    return tf.reduce_sum(err * err)
140 |
141 |
class LabelSmoothingLoss(tf.keras.losses.Loss):
    """
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.

    The target distribution puts (1 - eps) + eps/num_classes on the true
    class and eps/num_classes elsewhere. Positions whose target equals
    ``ignore_index`` get zero loss. Keras ``Loss.__call__`` then applies
    ``reduction`` to the per-position loss.
    """
    def __init__(
            self, label_smoothing, tgt_vocab_size, ignore_index=-100,
            reduction=keras.losses.Reduction.AUTO,
            name='LabelSmoothingLoss'):

        assert 0.0 < label_smoothing <= 1.0
        super().__init__(reduction=reduction, name=name)

        self.eps = label_smoothing
        self.num_classes = tgt_vocab_size
        self.ignore_index = ignore_index

    def call(self, output, target):
        """
        output (FloatTensor): (batch_size) x n_classes
        target: batch_size class indices (float or integer dtype)

        Returns the per-position smoothed cross entropy, 0 where target is
        ``ignore_index``.
        """
        target = tf.convert_to_tensor(target)
        # compare in target's own dtype so integer targets work as well
        is_ignored = tf.equal(target, tf.cast(self.ignore_index, target.dtype))
        non_pad_mask = tf.cast(tf.logical_not(is_ignored), tf.float32)

        # zeros_like keeps both tf.where branches the same dtype; a bare 0
        # (int32) against a float target fails inside tf.function.
        target = tf.where(is_ignored, tf.zeros_like(target), target)
        one_hot = tf.one_hot(tf.cast(target, tf.int64), depth=self.num_classes, dtype=tf.float32)
        one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / self.num_classes

        log_prb = tf.nn.log_softmax(output, axis=-1)
        loss = tf.reduce_sum(-(one_hot * log_prb), axis=-1)
        return loss * non_pad_mask
181 |
--------------------------------------------------------------------------------
/modules/Hawkes/tools/BasisFunction.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import sys,time,datetime,copy,subprocess,itertools,pickle,warnings,numbers
5 |
6 | import numpy as np
7 | import scipy as sp
8 | import pandas as pd
9 | from matplotlib import pyplot as plt
10 | import matplotlib as mpl
11 |
12 | from scipy.interpolate import BSpline
13 | import scipy.sparse as sparse
14 | from scipy.sparse import linalg as spla
15 |
16 | ###################################################################################### core class
class BasisFunctionExpansion_1D:
    """Base class for a 1-D function represented by a basis-function expansion.

    Holds the coefficient vector ``coef``, an evaluation grid ``x`` with its
    dense and sparse design matrices, and quadrature weights for integrating
    over the grid. Subclasses provide ``Matrix_BasisFunction``,
    ``d_Matrix_BasisFunction`` and the ``get_y`` / ``get_dy`` / ``get_y_at``
    evaluators; the base versions return None.
    """

    def __init__(self, itv=None, num_basis=10):
        self.itv = itv                    # [start, end] of the domain, read by set_x
        self.num_basis = num_basis
        self.coef = np.zeros(num_basis)

    def set_coef(self, coef):
        """Replace the coefficient vector; returns self for chaining."""
        self.coef = coef
        return self

    def Matrix_BasisFunction(self, x):
        """Design matrix of the basis evaluated at ``x`` (subclass hook)."""

    def d_Matrix_BasisFunction(self, x):
        """Matrix D whose D^T D forms the penalty in set_bayes (subclass hook)."""

    def set_x(self, x):
        """Fix the evaluation grid ``x``; cache design matrices and weights.

        Each grid point is weighted by the length of its cell. Cells are
        bounded by the midpoints between neighbouring points and by the two
        ends of ``self.itv``. Returns self.
        """
        start, end = self.itv
        self.x = x
        self.A = self.Matrix_BasisFunction(x)
        self.A_t = self.A.transpose()
        self.A_sp = sparse.csc_matrix(self.A)
        self.A_t_sp = sparse.csc_matrix(self.A_t)
        cell_edges = np.hstack([start, (x[:-1] + x[1:]) / 2, end])
        self.weight = np.diff(cell_edges)
        return self

    def get_y(self):
        """Function values on the grid (subclass hook)."""

    def get_dy(self):
        """Derivative of the grid values w.r.t. ``coef`` (subclass hook)."""

    def get_y_at(self, x):
        """Function values at arbitrary points ``x`` (subclass hook)."""

    def get_int(self):
        """Weighted-sum approximation of the integral of y over the grid."""
        return self.weight.dot(self.get_y())

    def get_dint(self):
        """Gradient of get_int() with respect to ``coef``."""
        return self.weight.dot(self.get_dy())

    def set_V(self, V):
        """Set the prior scale V (prior precision is W / V); returns self."""
        self.V = V
        return self

    def set_bayes(self):
        """Build the sparse penalty W = D^T D, D = d_Matrix_BasisFunction(x)."""
        D = self.d_Matrix_BasisFunction(self.x)
        self.W = sparse.csc_matrix(D.transpose().dot(D))
        return self

    def LGH(self):
        """Log-density, gradient and Hessian of a Gaussian prior on ``coef``.

        The prior is zero-mean with precision P = W / V, plus 1e-3 on P[0, 0].
        ``logdet_sp`` is a helper defined elsewhere in this module, presumably
        the log-determinant of a sparse matrix - TODO confirm.
        Returns [L, G, H].
        """
        coef, V, W = self.coef, self.V, self.W
        P = W / V
        # small ridge on the first coefficient; presumably keeps P non-singular
        # when the penalty leaves one direction unpenalised - TODO confirm
        P[0, 0] += 1e-3
        log_const = logdet_sp(P) / 2 - P.shape[0] * np.log(2 * np.pi) / 2
        L = log_const - coef.dot(P.dot(coef)) / 2.0
        return [L, -P.dot(coef), -P]

    def GH_transform(self, G, H):
        """Pull gradient/Hessian w.r.t. grid values back to coefficients: A^T G, A^T H A."""
        A, A_t = self.A_sp, self.A_t_sp
        return [A_t.dot(G), A_t.dot(H.dot(A))]
92 |
93 |
94 | ###################################################################################### Base class
class linear_1D(BasisFunctionExpansion_1D):
    """Expansion linear in the coefficients: y(x) = A(x) @ coef."""

    def get_y(self):
        """Values on the grid fixed by set_x."""
        return self.A.dot(self.coef)

    def get_dy(self):
        """Jacobian of the grid values w.r.t. coef, i.e. the design matrix A."""
        return self.A

    def get_y_at(self, x):
        """Evaluate the expansion at arbitrary points ``x``."""
        return self.Matrix_BasisFunction(x).dot(self.coef)
110 |
class loglinear_1D(BasisFunctionExpansion_1D):
    """Expansion linear in log space: y(x) = exp(A(x) @ coef)."""

    def get_y(self):
        """Values on the grid; also cached as ``self.y`` for get_dy."""
        self.y = np.exp(self.A.dot(self.coef))
        return self.y

    def get_dy(self):
        """Jacobian dy/dcoef = diag(y) A.

        Uses ``self.y`` from the most recent get_y() call, so get_y must run first.
        """
        return self.y.reshape(-1, 1) * self.A

    def get_y_at(self, x):
        """Evaluate the expansion at arbitrary points ``x``."""
        log_y = self.Matrix_BasisFunction(x).dot(self.coef)
        return np.exp(log_y)
128 |
129 | ###################################################################################### Bump function
130 | def bump_cos(x):
131 | y = np.zeros_like(x)
132 | index = (-21 else index_ini })
27 | self._hash_table.update({ (para,i): index_ini+i for i in range(length) })
28 | index_ini += length
29 |
30 | ### hash table
def idx(self, key):
    """Return the position(s) registered for ``key`` (int, slice or int array).

    Raises KeyError for an unregistered key.
    """
    table = self._hash_table
    return table[key]
33 |
def _index(self, key_list):
    """Concatenate the positions of every key in ``key_list`` into one int array.

    Returns ``slice(0)`` (an empty selection) when ``key_list`` is empty, since
    np.hstack raises ValueError on an empty list. It also returns slice(0) when
    a key cannot be resolved. Only these lookup/concatenation errors are
    caught; the former bare ``except:`` also swallowed KeyboardInterrupt and
    SystemExit.
    """
    try:
        positions = np.arange(self._length, dtype='i8')
        index = np.hstack([positions[self._hash_table[key]] for key in key_list])
    except (KeyError, IndexError, TypeError, ValueError):
        # NOTE(review): unknown keys silently select nothing; kept for
        # backward compatibility - consider raising instead.
        index = slice(0)
    return index
40 |
def add_key(self, key_list, key_name):
    """Register ``key_name`` as an alias for the combined positions of ``key_list``.

    Returns self so calls can be chained.
    """
    self._hash_table[key_name] = self._index(key_list)
    return self
44 |
45 | ### I/O
def from_dict(self, dic):
    """Pack a {key: value} parameter dict into a flat float array.

    Each value is written to the positions ``self.idx(key)``. If the value
    does not fit the slot (e.g. a sequence given for a scalar parameter), its
    first element is used instead. Unknown keys raise KeyError. Only
    assignment errors trigger the fallback; the former bare ``except:`` also
    caught KeyboardInterrupt and SystemExit.
    """
    x = np.zeros(self._length)
    for key, value in dic.items():
        try:
            x[self.idx(key)] = value
        except (ValueError, TypeError):
            x[self.idx(key)] = value[0]
    return x
54 |
def to_dict(self, ndarray):
    """Unpack a flat parameter array into {parameter name: value(s)}."""
    result = {}
    for para in self._para_list:
        result[para] = ndarray[self.idx(para)]
    return result
57 |
58 | ##################################
59 | ## Quasi Newton
60 | ##################################
def Quasi_Newton(model, prior=None, merge=None, opt=None):
    """Maximise ``model``'s penalised log-likelihood with a BFGS quasi-Newton method.

    Parameters flagged in ``model.stg["para_exp"]`` are optimised on a log
    scale (multiplicative steps); the others take additive steps.

    Parameters
    ----------
    model : object with a ``stg`` settings dict (keys "para_list",
        "para_length", "para_exp", "para_ini", "para_step_Q") and an ``LG``
        method. ``model.stg['para_label']`` is overwritten with the label map
        built here.
    prior : list of dicts with "name", "index", "type", "mu" (and "sigma").
        Entries of type "f" fix that parameter at "mu"; the others are passed
        to Penalized_LG as log-prior terms. Default: no priors.
    merge : list of key lists; parameters in one list are tied to a single
        shared value, initialised to their mean. Default: no merging.
    opt : dict of options. Recognised keys:
        'para_ini' (initial values), 'print' (trace each iteration),
        'stop' (iteration cap), 'ste' (standard errors), 'check' (profile plots).

    Returns
    -------
    [para_dict, L, ste, norm_of_gradient, n_iterations]

    Notes
    -----
    Without opt['stop'] the loop only ends when the gradient norm drops
    below 1e-5; there is no other iteration limit.
    Defaults were mutable lists; ``None`` is now used instead (behaviour for
    explicit arguments is unchanged).
    """
    prior = [] if prior is None else prior
    merge = [] if merge is None else merge
    opt = {} if opt is None else opt

    ## parameter setting
    para_list = model.stg["para_list"]
    para_length = [model.stg["para_length"][key] for key in para_list]
    param = array_label(para_list, para_length)
    param.add_key([pr for pr in para_list if model.stg["para_exp"][pr]], "para_exp")
    param.add_key([pr for pr in para_list if not model.stg["para_exp"][pr]], "para_ord")
    model.stg['para_label'] = param

    if 'para_ini' not in opt:
        para = param.from_dict(model.stg['para_ini'])
    else:
        para = param.from_dict(opt['para_ini'])

    step_Q = param.from_dict(model.stg["para_step_Q"])
    m = len(para)

    ## prior setting: type "f" entries become fixed parameters
    if prior:
        para_fix_index = [(prior_i["name"], prior_i["index"]) for prior_i in prior if prior_i["type"] == "f"]
        para_fix_value = [prior_i["mu"] for prior_i in prior if prior_i["type"] == "f"]
        param.add_key(para_fix_index, "fix")
        para[param.idx('fix')] = para_fix_value
        prior = [prior_i for prior_i in prior if prior_i["type"] != "f"]
    else:
        param.add_key([], "fix")

    ## merge setting: M_merge maps full gradients to the reduced (tied) space
    if merge:
        d = len(merge)
        index_merge = np.zeros(m, dtype='i8')

        for i in range(d):
            key = 'merge%d' % i
            param.add_key(merge[i], key)
            para[param.idx(key)] = para[param.idx(key)].mean()
            index_merge[param.idx(key)] = i + 1

        M_merge_z = np.eye(m)[index_merge == 0]
        M_merge_nz = np.vstack([np.eye(m)[index_merge == i + 1].sum(axis=0) for i in range(d)])
        M_merge = np.vstack([M_merge_z, M_merge_nz])
        M_merge_T = np.transpose(M_merge)
        m_reduced = M_merge.shape[0]

    else:
        M_merge = 1
        M_merge_T = 1
        m_reduced = m

    # Likelihood and gradient at the initial state. Log-scale parameters use
    # the chain rule dL/dlog(x) = x * dL/dx.
    [L1, G1] = Penalized_LG(model, para, prior)
    G1[param.idx("para_exp")] *= para[param.idx("para_exp")]
    G1 = np.dot(M_merge, G1)

    # main loop; H approximates the inverse of the negative Hessian
    H = np.eye(m_reduced)
    i_loop = 0

    while True:

        if 'print' in opt:
            print(i_loop)
            print(param.to_dict(para))
            print("L = %.3f, norm(G) = %e\n" % (L1, np.linalg.norm(G1)))

        if 'stop' in opt and i_loop == opt['stop']:
            break

        # break rule
        if np.linalg.norm(G1) < 1e-5:
            break

        # ascent direction, scaled so no parameter moves more than its step_Q
        s = H.dot(G1)
        s_extended = np.dot(M_merge_T, s)
        gamma = 1 / np.max([np.max(np.abs(s_extended) / step_Q), 1])
        s = s * gamma
        s_extended = s_extended * gamma

        # move to the new point
        para[param.idx("para_ord")] += s_extended[param.idx("para_ord")]
        para[param.idx("para_exp")] *= np.exp(s_extended[param.idx("para_exp")])

        # likelihood and gradient at the new point
        [L2, G2] = Penalized_LG(model, para, prior)
        G2[param.idx("para_exp")] *= para[param.idx("para_exp")]
        G2 = np.dot(M_merge, G2)

        # BFGS update of the inverse Hessian; reset to identity when the
        # curvature condition y's > 0 fails
        y = (G1 - G2).reshape(-1, 1)
        s = s.reshape(-1, 1)
        ys = y.T.dot(s)

        if ys > 0:
            H = H + (ys + y.T.dot(H).dot(y)) * (s * s.T) / ys ** 2 - (H.dot(y) * s.T + (s * y.T).dot(H)) / ys
        else:
            H = np.eye(m_reduced)

        L1 = L2
        G1 = G2

        i_loop += 1

    ### OPTION: estimation error
    if 'ste' in opt:
        ste = EstimationError(model, para, prior)
    else:
        ste = []

    ### OPTION: check the MAP solution
    if 'check' in opt:
        Check_QN(model, para, prior)

    return [param.to_dict(para), L1, ste, np.linalg.norm(G1), i_loop]
180 |
def Check_QN(model, para, prior):
    """Plot the penalised log-likelihood around ``para`` along each coordinate.

    For every scalar parameter, one figure shows 21 points spanning +/-1
    approximate standard error (from EstimationError_approx). The centre
    point is drawn red. Fixed parameters get zero width.
    """
    param = model.stg['para_label']
    ste = EstimationError_approx(model, para, prior)
    ste[param.idx('fix')] = 0
    offsets = np.linspace(-1, 1, 21)

    for key in model.stg['para_list']:
        for index in range(model.stg['para_length'][key]):
            pos = param.idx((key, index))

            plt.figure()
            plt.title(key + '-' + str(index))

            for k, offset in enumerate(offsets):
                trial = para.copy()
                trial[pos] += offset * ste[pos]
                L = Penalized_LG(model, trial, prior)[0]
                plt.plot(trial[pos], L, "ko")

                if k == 10:  # centre of the 21-point grid
                    plt.plot(trial[pos], L, "ro")
201 |
202 | #################################
203 | ## Basic funnctions
204 | #################################
def G_NUMERICAL(model, para):
    """Numerical gradient of ``model.LG`` via a fourth-order central difference.

    dL/dx_i ~ (L(x-2h) - 8 L(x-h) + 8 L(x+h) - L(x+2h)) / (12 h), where h comes
    from ``model.stg['para_step_diff']``. For log-scale parameters h is scaled
    by the current value.
    """
    param = model.stg['para_label']
    steps = param.from_dict(model.stg['para_step_diff'])
    steps[param.idx('para_exp')] *= para[param.idx('para_exp')]

    def L_at(i, offset):
        # likelihood with coordinate i shifted by `offset`
        shifted = para.copy()
        shifted[i] += offset
        return model.LG(param.to_dict(shifted))[0]

    G = np.zeros(len(para))
    for i in range(len(para)):
        h = steps[i]
        L1 = L_at(i, -2*h)
        L2 = L_at(i, -1*h)
        L3 = L_at(i, 1*h)
        L4 = L_at(i, 2*h)
        G[i] = (L1-8*L2+8*L3-L4)/12/h

    return G
229 |
230 |
def Hessian(model, para, prior):
    """Numerical Hessian from central differences of Penalized_LG's gradient.

    Row i is (G(x + h_i e_i) - G(x - h_i e_i)) / (2 h_i). Rows of fixed
    parameters are zeroed and their diagonal set to -1e20.
    """
    param = model.stg['para_label']
    steps = param.from_dict(model.stg['para_step_diff'])
    steps[param.idx('para_exp')] *= para[param.idx('para_exp')]
    n = len(para)
    H = np.zeros((n, n))

    for i in range(n):
        h = steps[i]
        down = para.copy()
        down[i] -= h
        G_down = Penalized_LG(model, down, prior)[1]
        up = para.copy()
        up[i] += h
        G_up = Penalized_LG(model, up, prior)[1]
        H[i] = (G_up - G_down) / 2 / h

    fixed = param.idx('fix')
    H[fixed] = 0
    H[fixed, fixed] = -1e+20

    return H
249 |
def EstimationError(model, para, prior):
    """Standard errors: sqrt of the diagonal of inv(-H), H the numerical Hessian."""
    covariance = np.linalg.inv(-Hessian(model, para, prior))
    return np.sqrt(np.diag(covariance))
254 |
def EstimationError_approx(model, para, prior):
    """Cheap standard errors: 1/sqrt(-H_ii), ignoring off-diagonal Hessian terms."""
    neg_diag = np.diag(-Hessian(model, para, prior))
    return 1.0 / np.sqrt(neg_diag)
259 |
260 |
def Penalized_LG(model, para, prior, only_L=False):
    """Return [L, G]: the log-likelihood at ``para`` plus log-prior terms, and its gradient.

    ``model.LG`` receives the parameters as a dict. If it returns the gradient
    as a string, the gradient is computed numerically with G_NUMERICAL.
    Unless ``only_L`` is true, gradient entries of fixed parameters (label
    'fix') are zeroed and prior gradients are added.

    Prior types: 'n' normal, 'ln' log-normal, 'b' barrier -mu/x,
    'b2' log barrier mu*log10(x). Other types are ignored. Every prior entry
    must provide "name", "index", "type", "mu" and "sigma".
    """
    param = model.stg['para_label']

    L, G = model.LG(param.to_dict(para), only_L)
    G = G_NUMERICAL(model, para) if isinstance(G, str) else param.from_dict(G)

    ## fix
    if not only_L:
        G[param.idx("fix")] = 0

    ## prior
    for entry in (prior or []):
        name = entry["name"]
        position = entry["index"]
        kind = entry["type"]
        mu = entry["mu"]
        sigma = entry["sigma"]
        index = param.idx((name, position))
        x = para[index]

        if kind == 'n':
            L += - np.log(2*np.pi*sigma**2)/2 - (x-mu)**2/2/sigma**2
            if not only_L:
                G[index] += - (x-mu)/sigma**2
        elif kind == 'ln':
            L += - np.log(2*np.pi*sigma**2)/2 - np.log(x) - (np.log(x)-mu)**2/2/sigma**2
            if not only_L:
                G[index] += - 1/x - (np.log(x)-mu)/sigma**2/x
        elif kind == "b":
            L += - mu/x
            if not only_L:
                G[index] += mu/x**2
        elif kind == "b2":
            L += mu *np.log10(np.e)*np.log(x)
            if not only_L:
                G[index] += mu * np.log10(np.e)/x

    return [L, G]
306 |
307 | #################################
308 | ## para_stg
309 | #################################
def merge_stg(para_stgs):
    """Combine per-component parameter settings into one model ``stg`` dict.

    Each input provides 'list', 'length', 'exp', 'ini', 'step_Q' and
    'step_diff'. The lists are concatenated and the dicts merged, with later
    components winning on duplicate keys.
    """
    sections = {
        'para_length': 'length',
        'para_exp': 'exp',
        'para_ini': 'ini',
        'para_step_Q': 'step_Q',
        'para_step_diff': 'step_diff',
    }
    stg = {'para_list': []}
    stg.update({target: {} for target in sections})

    for para_stg in para_stgs:
        stg['para_list'].extend(para_stg['list'])
        for target, source in sections.items():
            stg[target].update(para_stg[source])

    return stg
329 |
--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 |
6 | from collections import OrderedDict
7 | from collections import Counter
8 | from operator import itemgetter
9 |
10 | # Hawkes model is from https://omitakahiro.github.io/Hawkes/index.html
11 | from modules import Hawkes as hk
12 |
# Hawkes parameters shared by the simulators below.
para = {'mu': 0.1, 'alpha': 0.3, 'beta': 0.6}


def mu_t(x):
    """Sinusoidal baseline (period 100) used for the sin + Hawkes overlay."""
    return (1.0 + 0.8 * np.sin(2 * np.pi * x / 100)) * 0.2


itv = [0, 360000]
demo_itv = [0, 360]

np.random.seed(42)
# Per-dataset down-sampling stride; datasets not listed are kept whole.
downsampling = {'Trump': 10}
21 |
def downsampling_dataset(timestamps, dataset_name):
    """Keep every k-th element of ``timestamps``, k = downsampling[dataset_name]."""
    stride = downsampling[dataset_name]
    print('Down-sampling', dataset_name, 'dataset by', stride)
    return timestamps[::stride]
25 |
def purge_duplicate_events(timestamps, types):
    """Drop events equal in both time and type to the event just before them.

    Parameters
    ----------
    timestamps, types : numpy arrays, indexed in parallel.

    Returns
    -------
    (timestamps, types) as Python lists with the duplicates removed.

    Runs in O(n). The previous version built an unused ``np.where`` result
    and deleted list items one by one (O(n) per deletion).
    """
    timestamps = timestamps.tolist()
    types = types.tolist()

    duplicate = {
        i for i in range(1, len(timestamps))
        if timestamps[i] == timestamps[i-1] and types[i] == types[i-1]
    }

    kept_timestamps = [ts for i, ts in enumerate(timestamps) if i not in duplicate]
    kept_types = [ty for i, ty in enumerate(types) if i not in duplicate]
    return kept_timestamps, kept_types
42 |
def keep_top_k_types(types, keep_classes=10):
    """Relabel event types by frequency rank.

    The most frequent type becomes 1, the next 2, and so on up to
    ``keep_classes``. Every rarer type shares the label ``keep_classes + 1``.
    Ties keep first-seen order. Returns a numpy array.
    """
    bucket = keep_classes + 1
    ranked = Counter(types).most_common()
    type2supertype = {type_: min(rank + 1, bucket) for rank, (type_, _) in enumerate(ranked)}
    return np.array([type2supertype[ty] for ty in types])
54 |
55 |
def hawkes_demo():
    """Simulate a constant-baseline, exponential-kernel Hawkes process over
    ``demo_itv`` and save its intensity and event-count plots as PNGs."""
    simulator = hk.simulator().set_kernel('exp').set_baseline('const').set_parameter(para)
    simulator.simulate(demo_itv)
    for plot_fn, png in ((simulator.plot_l, 'hawkes_intensity.png'),
                         (simulator.plot_N, 'hawkes_event_counts.png')):
        plot_fn()
        plt.savefig(png)
        plt.close()
65 |
def sin_hawkes_overlay_demo():
    """Simulate an exponential-kernel Hawkes process with the sinusoidal
    baseline ``mu_t`` over ``demo_itv``; save intensity and count plots."""
    simulator = (hk.simulator()
                 .set_kernel('exp')
                 .set_baseline('custom', l_custom=mu_t)
                 .set_parameter(para))
    simulator.simulate(demo_itv)
    for plot_fn, png in ((simulator.plot_l, 'sin_hawkes_overlay_intensity.png'),
                         (simulator.plot_N, 'sin_hawkes_overlay_event_counts.png')):
        plot_fn()
        plt.savefig(png)
        plt.close()
75 |
# NOTE(review): this function is damaged in this copy. Original lines 87-89 are
# missing, and the comparisons on lines 86 and 95 look truncated (text between
# '<' and '>' appears to have been stripped). Restore it from version control
# before relying on it.
# Purpose, as far as the intact lines show: sample y = 10*sin(omega*x) + 11 on a
# uniform grid, use the samples as inter-event gaps (always in [1, 21]), label
# events by the sign/slope of the sine (types 1..4 appear below), save a
# scatter of the first 25 samples to data/sin.png, and return
# (gaps, timestamps, types).
76 | def create_sin_data():
77 | omega = 1.0
78 | points = 10000
# NOTE(review): num_marks is not used in the visible lines.
79 | num_marks = 7
80 | x = np.linspace(0, points, 3*points)
81 | y_ = 10*np.sin(omega*x)
82 | y = y_ + 11
83 | gaps=y
84 | timestamp = np.cumsum(gaps)
85 | types = []
# NOTE(review): the next line is garbled (a comparison and lines 87-89 are lost).
86 | if y_[0]=0.:
90 | # types.append(0)
91 | #else:
92 | # types.append(1)
93 | if y_[i]>=0. and y_[i]>y_[i-1]:
94 | types.append(1)
# NOTE(review): the next line is garbled; the branches for types 2 and 3 are lost.
95 | if y_[i]>=0. and y_[i]y_[i-1]:
100 | types.append(4)
101 |
102 | types = np.array(types)
103 |
104 | plt.plot(x[:25], y[:25], 'o', color='black');
105 | plt.savefig('data/sin.png')
106 | plt.close()
107 | return gaps, timestamp, types
108 |
def create_hawkes_data():
    """Return (gaps, timestamps) of a constant-baseline Hawkes process on [0, 360000].

    Writes the demo plots via hawkes_demo() first.
    """
    hawkes_demo()
    simulator = hk.simulator().set_kernel('exp').set_baseline('const').set_parameter(para)
    timestamp = simulator.simulate([0, 360000])
    gaps = timestamp[1:] - timestamp[:-1]
    return gaps, timestamp
115 |
def create_sin_hawkes_overlay_data():
    """Return (gaps, timestamps) of a Hawkes process with sinusoidal baseline on [0, 360000].

    Writes the demo plots via sin_hawkes_overlay_demo() first.
    """
    sin_hawkes_overlay_demo()
    simulator = (hk.simulator()
                 .set_kernel('exp')
                 .set_baseline('custom', l_custom=mu_t)
                 .set_parameter(para))
    timestamp = simulator.simulate([0, 360000])
    gaps = timestamp[1:] - timestamp[:-1]
    return gaps, timestamp
122 |
def create_taxi_data():
    """Build the NYC yellow-taxi event sequence from the Jan/Feb 2019 CSVs in ./data.

    Keeps pickups at PULocationID 237 in Jan-Feb 2019, sorted by time. The
    event type is the drop-off location, relabelled by keep_top_k_types.
    Timestamps are seconds relative to the first pickup. The sequence is
    optionally down-sampled, a plot of the first 100 gaps is saved, and
    (gaps, timestamps, types) is returned.
    """
    # https://s3.amazonaws.com/nyc-tlc/trip+data/yellow_tripdata_2019-01.csv
    # https://s3.amazonaws.com/nyc-tlc/trip+data/yellow_tripdata_2019-02.csv
    usecols = ["tpep_pickup_datetime", "PULocationID", "DOLocationID"]
    taxi_df_jan = pd.read_csv('./data/yellow_tripdata_2019-01.csv', usecols=usecols)
    taxi_df_feb = pd.read_csv('./data/yellow_tripdata_2019-02.csv', usecols=usecols)
    # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
    taxi_df = pd.concat([taxi_df_jan, taxi_df_feb])
    # .copy(): the filtered frame is modified next, so make it independent
    # (avoids SettingWithCopyWarning).
    taxi_df = taxi_df[taxi_df.PULocationID == 237].copy()
    taxi_df['tpep_pickup_datetime'] = pd.to_datetime(taxi_df['tpep_pickup_datetime'], errors='coerce')
    # unparseable dates became NaT, which fails the year test and is dropped
    taxi_df = taxi_df[(taxi_df['tpep_pickup_datetime'].dt.year == 2019)]
    taxi_df = taxi_df[(taxi_df['tpep_pickup_datetime'].dt.month < 3)]
    taxi_df = taxi_df.sort_values('tpep_pickup_datetime')
    taxi_types = taxi_df['DOLocationID'].values
    # int64 nanoseconds since epoch -> seconds
    taxi_timestamps = pd.DatetimeIndex(taxi_df['tpep_pickup_datetime']).astype(np.int64)/1000000000
    taxi_timestamps = np.array(taxi_timestamps)
    taxi_timestamps -= taxi_timestamps[0]
    # NOTE(review): the last event is dropped; the reason is not visible here.
    taxi_timestamps = taxi_timestamps[:-1]
    taxi_types = taxi_types[:-1]
    taxi_types = keep_top_k_types(taxi_types)
    dataset_name = 'taxi'
    if dataset_name in downsampling:
        taxi_timestamps = downsampling_dataset(taxi_timestamps, dataset_name)
        taxi_types = downsampling_dataset(taxi_types, dataset_name)
    taxi_gaps = taxi_timestamps[1:] - taxi_timestamps[:-1]
    plt.plot(taxi_gaps[:100])
    plt.ylabel('Gaps')
    plt.savefig('data/taxi_gaps.png')
    plt.close()
    return taxi_gaps, taxi_timestamps, taxi_types
156 |
def create_911_traffic_data():
    """Build the 911 'Traffic' call sequence from ./data/911.csv.

    The event type is the caller zip code, relabelled by keep_top_k_types.
    Timestamps are seconds relative to the first call. Prints a summary of
    emergency types and subtypes, optionally down-samples, saves a plot of
    the first 100 gaps, and returns (gaps, timestamps, types).
    """
    call_df = pd.read_csv('./data/911.csv')
    # Ignore calls with NaN zip codes; .copy() because a column is added below.
    call_df = call_df[call_df['zip'].notna()].copy()
    print('Types of Emergencies')
    # the emergency category is the part of 'title' before the first ':'
    categories = call_df.title.apply(lambda x: x.split(':')[0])
    print(categories.value_counts())
    call_df['type'] = categories
    print('Subtypes')
    for each in call_df['type'].unique():
        # reuse the precomputed category column instead of re-splitting titles
        subtype_count = call_df[call_df['type'] == each].title.value_counts()
        print('For', each, 'type of Emergency, we have ', subtype_count.count(), 'subtypes')
        print(subtype_count[subtype_count>100])
    print('Out of 3 types taking Traffic type considering only Traffic')
    call_data = call_df[call_df['type']=='Traffic'].copy()
    call_data['timeStamp'] = pd.to_datetime(call_data['timeStamp'], errors='coerce')
    print("We have timeline from", call_data['timeStamp'].min(), "to", call_data['timeStamp'].max())
    call_data = call_data.sort_values('timeStamp')

    # int64 nanoseconds since epoch -> seconds, relative to the first call
    call_timestamps = pd.DatetimeIndex(call_data['timeStamp']).astype(np.int64)/1000000000
    call_timestamps = np.array(call_timestamps)
    call_timestamps -= call_timestamps[0]
    call_types = call_data['zip'].values
    call_types = keep_top_k_types(call_types)
    dataset_name = 'call'
    if dataset_name in downsampling:
        call_timestamps = downsampling_dataset(call_timestamps, dataset_name)
        call_types = downsampling_dataset(call_types, dataset_name)
    call_gaps = call_timestamps[1:] - call_timestamps[:-1]
    plt.plot(call_gaps[:100])
    plt.ylabel('Gaps')
    plt.xlabel('timeline')
    plt.savefig('data/call_traffic_gaps.png')
    plt.close()
    return call_gaps, call_timestamps, call_types
191 |
def create_911_ems_data():
    """Build the 911 'EMS' call sequence from ./data/911.csv.

    The event type is the caller zip code, relabelled by keep_top_k_types.
    Timestamps are seconds relative to the first call. Optionally
    down-samples, saves a plot of the first 100 gaps, and returns
    (gaps, timestamps, types).
    """
    call_df = pd.read_csv('./data/911.csv')
    # Ignore calls with NaN zip codes; .copy() because a column is added below.
    call_df = call_df[call_df['zip'].notna()].copy()
    # the emergency category is the part of 'title' before the first ':'
    call_df['type'] = call_df.title.apply(lambda x: x.split(':')[0])
    print('Out of 3 types taking EMS type considering only EMS')
    call_data = call_df[call_df['type']=='EMS'].copy()
    call_data['timeStamp'] = pd.to_datetime(call_data['timeStamp'], errors='coerce')
    print("We have timeline from", call_data['timeStamp'].min(), "to", call_data['timeStamp'].max())
    call_data = call_data.sort_values('timeStamp')

    # int64 nanoseconds since epoch -> seconds, relative to the first call
    call_timestamps = pd.DatetimeIndex(call_data['timeStamp']).astype(np.int64)/1000000000
    call_timestamps = np.array(call_timestamps)
    call_timestamps -= call_timestamps[0]
    call_types = call_data['zip'].values
    call_types = keep_top_k_types(call_types)
    dataset_name = 'call'
    if dataset_name in downsampling:
        call_timestamps = downsampling_dataset(call_timestamps, dataset_name)
        call_types = downsampling_dataset(call_types, dataset_name)
    call_gaps = call_timestamps[1:] - call_timestamps[:-1]
    plt.plot(call_gaps[:100])
    plt.ylabel('Gaps')
    plt.xlabel('timeline')
    plt.savefig('data/call_ems_gaps.png')
    plt.close()
    return call_gaps, call_timestamps, call_types
219 |
def _processed_files_exist(name):
    """Return True if both ``data/<name>.txt`` and ``data/<name>_types.txt`` exist."""
    return (os.path.isfile(os.path.join('data', name + '.txt'))
            and os.path.isfile(os.path.join('data', name + '_types.txt')))


def generate_dataset():
    """Generate the processed sin, 911 (traffic/EMS) and taxi datasets.

    Each dataset's de-duplicated timestamps and marks are saved to
    ``data/<name>.txt`` and ``data/<name>_types.txt``. A dataset is only
    regenerated when one of those two files is missing. The previous version
    tested for the files in the working directory while writing them under
    ``data/``, so every run regenerated everything.
    """
    os.makedirs('./data', exist_ok=True)
    if not _processed_files_exist('sin'):
        print('Generating sin data')
        _, timestamps, types = create_sin_data()
        timestamps, types = purge_duplicate_events(timestamps, types)
        np.savetxt('data/sin.txt', timestamps)
        np.savetxt('data/sin_types.txt', types)
    if not _processed_files_exist('911_traffic'):
        print('Generating 911 data')
        _, timestamps, types = create_911_traffic_data()
        timestamps, types = purge_duplicate_events(timestamps, types)
        np.savetxt('data/911_traffic.txt', timestamps)
        np.savetxt('data/911_traffic_types.txt', types)
    if not _processed_files_exist('911_ems'):
        print('Generating 911 data')
        _, timestamps, types = create_911_ems_data()
        timestamps, types = purge_duplicate_events(timestamps, types)
        np.savetxt('data/911_ems.txt', timestamps)
        np.savetxt('data/911_ems_types.txt', types)
    if not _processed_files_exist('taxi'):
        print('Generating taxi data')
        _, timestamps, types = create_taxi_data()
        timestamps = np.array(timestamps).astype(np.float32)
        types = np.array(types).astype(np.float32)
        timestamps, types = purge_duplicate_events(timestamps, types)
        np.savetxt('data/taxi.txt', timestamps)
        np.savetxt('data/taxi_types.txt', types)
260 |
def create_twitter_data(dataset_name, keep_classes=10):
    """Build an event sequence from the raw Twitter dump ``./data/<name>.txt``.

    The file is read without a header (tab-delimited for 'Movie', 'Delhi',
    'Verdict' and 'Fight', space-delimited otherwise). Column 1 holds event
    times and column 0 the raw mark. Rows are reversed before use (presumably
    the dump is newest-first -- TODO confirm), and times are shifted so the
    first event is at 0.

    Marks are re-coded by frequency: the ``keep_classes`` most frequent raw
    marks get ids ``1..keep_classes``; every rarer mark shares id
    ``keep_classes + 1``.

    If ``dataset_name`` is configured for downsampling, a plot of all gaps
    before downsampling is saved and both sequences are downsampled. Plots of
    the (post-downsampling) gaps are saved to the working directory.

    Returns:
        (gaps, timestamps, types), where ``gaps`` are the inter-event times of
        the returned ``timestamps``.
    """
    delimiter = '\t' if dataset_name in ['Movie', 'Delhi', 'Verdict', 'Fight'] else ' '
    twitter_df = pd.read_csv('./data/'+dataset_name+'.txt', delimiter=delimiter, header=None)
    twitter_df = twitter_df.values[::-1]
    timestamps = twitter_df[:, 1]
    timestamps = timestamps - timestamps[0]
    types = twitter_df[:, 0]

    # most_common() orders by descending count (ties in first-seen order);
    # ranks >= keep_classes all collapse onto the shared id keep_classes + 1.
    ranked_types = [type_ for type_, _ in Counter(types).most_common()]
    type2supertype = {type_: min(rank, keep_classes) + 1
                      for rank, type_ in enumerate(ranked_types)}
    types = [type2supertype[ty] for ty in types]

    if dataset_name in downsampling:
        ts = np.asarray(timestamps)
        plt.plot(ts[1:] - ts[:-1])
        plt.ylabel('all_Gaps_before_downsample')
        plt.savefig(dataset_name+'_all_gaps_before_downsample.png')
        plt.close()
        timestamps = downsampling_dataset(timestamps, dataset_name)
        types = downsampling_dataset(types, dataset_name)

    # Computed after downsampling so the gaps describe the returned timestamps
    # (previously they were left over from the full sequence).
    ts = np.asarray(timestamps)
    gaps = ts[1:] - ts[:-1]

    plt.plot(gaps[:100])
    plt.ylabel('Gaps')
    plt.savefig(dataset_name+'_gaps.png')
    plt.close()
    plt.plot(gaps)
    plt.ylabel('all_Gaps')
    plt.savefig(dataset_name+'_all_gaps.png')
    plt.close()
    return gaps, timestamps, types
299 |
def generate_twitter_dataset(twitter_dataset_names):
    """Process every listed Twitter dataset that has not been processed yet.

    For each name whose ``<name>.txt`` is absent from the working directory,
    the raw dump is converted by ``create_twitter_data``, duplicate events are
    purged, and the result is written to ``<name>.txt`` and
    ``<name>_types.txt`` in the working directory.
    """
    os.makedirs('./data', exist_ok=True)
    for name in twitter_dataset_names:
        if os.path.isfile(name + '.txt'):
            continue  # already processed on an earlier run (or earlier in this loop)
        print('Generating', name, 'data')
        _, raw_times, raw_types = create_twitter_data(name)
        clean_times, clean_types = purge_duplicate_events(np.array(raw_times), np.array(raw_types))
        np.savetxt(name + '.txt', clean_times)
        np.savetxt(name + '_types.txt', clean_types)
311 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys, os, json
3 | import numpy as np
4 | from itertools import product
5 | from argparse import Namespace
6 | import multiprocessing as MP
7 | from operator import itemgetter
8 | import datetime
9 | from collections import OrderedDict
10 | from generator import generate_dataset, generate_twitter_dataset
11 | import json
12 | import time
13 |
14 | import run
15 | import utils
16 |
17 | # import os
18 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
19 |
# Command-line interface. The two positional arguments select the dataset and
# the model ('all' expands to a predefined list further down in this script).
parser = argparse.ArgumentParser()
parser.add_argument('dataset_name', type=str, help='dataset_name')
parser.add_argument('model_name', type=str, help='model_name')

parser.add_argument('--num_types', type=int, default=0,
    help='Number of marker types. If markers not required, \
num_types=0')

parser.add_argument('--epochs', type=int, default=0,
    help='number of training epochs')
parser.add_argument('--patience', type=int, default=2,
    help='Number of epochs to wait for \
before beginning cross-validation')

# NOTE(review): these three options use nargs='+' with a scalar default, so the
# parsed value is a scalar when omitted but a list when given -- confirm that
# the consumers accept both forms.
parser.add_argument('--learning_rate', type=float, default=1e-3, nargs='+',
    help='Learning rate for the training algorithm')
parser.add_argument('-hls', '--hidden_layer_size', type=int, default=32, nargs='+',
    help='Number of units in RNN')
parser.add_argument('-embds', '--embed_size', type=int, default=8, nargs='+',
    help='Embedding dimension of marks/types')

parser.add_argument('--output_dir', type=str,
    help='Path to store all raw outputs, checkpoints, \
summaries, and plots', default='Outputs')
parser.add_argument('--saved_models', type=str,
    help='Path to store model checkpoints', default='saved_models')

parser.add_argument('--seed', type=int,
    help='Seed for parameter initialization',
    default=42)

# Bin size T_i - T_(i-1) in seconds (0 means "choose automatically per dataset")
parser.add_argument('--bin_size', type=int, default=0,
    help='Number of seconds in a bin')

# F(T_(i-1), T_(i-2) ..... , T_(i-r)) -> T(i)
# r_feature_sz = 20
parser.add_argument('--in_bin_sz', type=int,
    help='Input count of bins r_feature_sz',
    default=20)

# dec_len = 8 # For All Models
parser.add_argument('--out_bin_sz', type=int,
    help='Output count of bin',
    default=1)

parser.add_argument('--cnt_net_type', type=str, default='ff',
    help='Count model network type (ff or rnn)')

# enc_len = 80 # For RMTPP
parser.add_argument('--enc_len', type=int, default=80,
    help='Input length for rnn of rmtpp')

# comp_enc_len = 40 # For Compound RMTPP
parser.add_argument('--comp_enc_len', type=int, default=40,
    help='Input length for rnn of compound rmtpp')

# comp_bin_sz = 10 # For Compound RMTPP
parser.add_argument('--comp_bin_sz', type=int, default=10,
    help='events inside one bin of compound rmtpp')

# wgan_enc_len = 60 # For WGAN
parser.add_argument('--wgan_enc_len', type=int, default=60,
    help='Input length for rnn of WGAN')
parser.add_argument('--use_wgan_d', action='store_true', default=False,
    help='Whether to use WGAN discriminator or not')

# Seq2Seq / CWE parameters
parser.add_argument('--use_cwe_d', action='store_true', default=False,
    help='Whether to use CWE/Seq2Seq discriminator or not')

# interval_size = 360 # For RMTPP
parser.add_argument('--interval_size', type=int, default=360,
    help='Interval size for threshold query')

parser.add_argument('--batch_size', type=int, default=32,
    help='Input batch size')
parser.add_argument('--query', type=int, default=1,
    help='Query number')
parser.add_argument('--stride_len', type=int, default=1,
    help='Stride len for RMTPP number')
parser.add_argument('--normalization', type=str, default='average',
    help='gap normalization method')

parser.add_argument('--generate_plots', action='store_true', default=False,
    help='Generate dev and test plots, both per epochs \
and after training')
parser.add_argument('--parallel_hparam', action='store_true', default=False,
    help='Parallel execution of hyperparameters')

# Flags for RMTPP calibration
parser.add_argument('--calibrate_rmtpp', action='store_true', default=False,
    help='Whether to calibrate RMTPP')
parser.add_argument('--extra_var_model', action='store_true', default=False,
    help='Use a separate model to train the variance of RMTPP')

# Flags for optimizer
parser.add_argument('--opt_num_counts', type=int, default=5,
    help='Number of counts to try before and after mean for optimizer')
parser.add_argument('--no_rescale_rmtpp_params', action='store_true', default=False,
    help='Do not rescale RMTPP intensities for optimizer')
parser.add_argument('--use_ratio_constraints', action='store_true', default=False,
    help='Maintain Ratios of adjacent RMTPP event predictions')
parser.add_argument('--search', type=int, default=0,
    help='Search algorithm over counts 0:binary, 1:linear')

# Parameters for extra_var_model
parser.add_argument('--num_grps', type=int, default=10,
    help='Number of groups in each bin in forecast horizon')
parser.add_argument('--num_pos', type=int, default=40,
    help='Number of positions in each group in forecast horizon')

# Time-feature parameters
parser.add_argument('--no_count_model_feats', action='store_true', default=False,
    help='Do not use time-features for count model')
parser.add_argument('--no_rmtpp_model_feats', action='store_true', default=False,
    help='Do not use time-features for rmtpp model')


# Transformer parameters (single-dash long options; argparse stores them as
# args.d_model, args.n_head, ...). Trailing comments show alternative values.
parser.add_argument('-d_model', type=int, default=32) #64
parser.add_argument('-d_rnn', type=int, default=8) #256
parser.add_argument('-d_inner_hid', type=int, default=32) #128
parser.add_argument('-d_k', type=int, default=8) #16
parser.add_argument('-d_v', type=int, default=8) #16

parser.add_argument('-n_head', type=int, default=2) #4
parser.add_argument('-n_layers', type=int, default=1) #4

parser.add_argument('-dropout', type=float, default=0.1)
parser.add_argument('-lr', type=float, default=1e-4)
parser.add_argument('-smooth', type=float, default=0.)

args = parser.parse_args()
154 |
# Resolve the requested dataset(s). 'all' expands to the default suite
# ('hawkes' and 'sin_hawkes_overlay' are currently disabled).
if args.dataset_name == 'all':
    dataset_names = ['sin', 'taxi', '911_traffic', '911_ems', 'twitter']
else:
    dataset_names = [args.dataset_name]

print(dataset_names)

# 'twitter' is an umbrella name: swap it for the concrete Twitter datasets,
# which are appended at the end ('Verdict' and 'Delhi' are currently disabled).
twitter_dataset_names = []
if 'twitter' in dataset_names:
    dataset_names.remove('twitter')
    twitter_dataset_names = ['Trump']
dataset_names.extend(twitter_dataset_names)

args.dataset_name = dataset_names
180 |
# Resolve the requested model(s). 'all' expands to the default model suite;
# 'hawkes_model', 'wgan', 'seq2seq', 'transformer', 'hierarchical',
# 'rmtpp_nll_comp' and the 'pure_hierarchical_*' variants are currently disabled.
if args.model_name == 'all':
    model_names = [
        'count_model',
        'rmtpp_nll',
        'rmtpp_mse',
        'rmtpp_mse_var',
        'rmtpp_mse_comp',
        'rmtpp_mse_var_comp',
        'inference_models',
    ]
else:
    model_names = [args.model_name]
args.model_name = model_names
201 |
def _build_run_model_flags(selected_models):
    """Return the ordered run-mode flags enabled by ``selected_models``.

    Each selected model contributes its run modes, in the fixed table order
    below. A mode's value is either a dict of RMTPP loss settings or ``True``.
    """
    flag_table = (
        ('rmtpp_nll', [
            ('rmtpp_nll_opt', {'rmtpp_type':'nll'}),
            ('rmtpp_nll_simu', {'rmtpp_type':'nll'}),
        ]),
        ('rmtpp_mse', [
            ('rmtpp_mse_opt', {'rmtpp_type':'mse'}),
            ('rmtpp_mse_simu', {'rmtpp_type':'mse'}),
        ]),
        ('rmtpp_mse_var', [
            ('rmtpp_mse_var_opt', {'rmtpp_type':'mse_var'}),
            ('rmtpp_mse_var_simu', {'rmtpp_type':'mse_var'}),
        ]),
        ('rmtpp_nll_comp', [
            ('rmtpp_nll_opt_comp', {'rmtpp_type':'nll', 'rmtpp_type_comp':'nll'}),
        ]),
        ('rmtpp_mse_comp', [
            ('rmtpp_mse_opt_comp', {'rmtpp_type':'mse', 'rmtpp_type_comp':'mse'}),
        ]),
        ('rmtpp_mse_var_comp', [
            ('rmtpp_mse_var_opt_comp', {'rmtpp_type':'mse_var', 'rmtpp_type_comp':'mse_var'}),
        ]),
        ('pure_hierarchical_nll', [('run_pure_hierarchical_infer_nll', True)]),
        ('pure_hierarchical_mse', [('run_pure_hierarchical_infer_mse', True)]),
        ('count_model', [('count_only', True)]),
        ('wgan', [('wgan_simu', True)]),
        ('seq2seq', [('seq2seq_simu', True)]),
        ('transformer', [('transformer_simu', True)]),
        ('hawkes_model', [('hawkes_simu', True)]),
    )
    flags = OrderedDict()
    for model_key, mode_entries in flag_table:
        if model_key in selected_models:
            flags.update(mode_entries)
    return flags


run_model_flags = _build_run_model_flags(model_names)
257 |
# A bin size of 0 on the command line means "pick the bin size per dataset".
automate_bin_sz = (args.bin_size == 0)

# A patience that is not smaller than the epoch budget is disabled (set to 0).
if args.patience >= args.epochs:
    args.patience = 0
264 |
id_process = os.getpid()
time_current = datetime.datetime.now().isoformat()

print('args', args)

# Summary banner of this run's configuration.
star_rule = "********************************************************************"
print(star_rule)
for label, value in (
    ("PID", id_process),
    ("Time", time_current),
    ("epochs", args.epochs),
    ("learning_rate", args.learning_rate),
    ("seed", args.seed),
    ("Models", model_names),
    ("Datasets", dataset_names),
):
    print("%s: %s" % (label, str(value)))
print(star_rule)

# Seed numpy, make sure the output directory exists, then build any missing datasets.
hash_rule = "####################################################################"
print(hash_rule)
np.random.seed(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
print("Generating Datasets\n")
generate_dataset()
generate_twitter_dataset(twitter_dataset_names)
print(hash_rule)
287 |
# Layout of the LaTeX-style tables written to results_<dataset>.txt:
# (section title, header row, metric keys in column order).
_RESULT_TABLES = (
    ('\n\nResults in random interval:',
     '\nModel Name & Count MAE & Wass dist & bleu_score',
     ('count_mae_rh', 'wass_dist_rh', 'bleu_score_rh')),
    ('\n\nResults in Forecast Horizon:',
     '\nModel Name & Count MAE & Wass Dist & bleu_score',
     ('count_mae_fh', 'wass_dist_fh', 'bleu_score_fh')),
    ('\n\nQuery 2 Results',
     '\nModel Name & Query_2_Metric',
     ('more_metric',)),
    ('\n\nQuery 3 Results',
     '\nModel Name & Query_3_Metric',
     ('less_metric',)),
    ('\n\nAll metrics in random interval:',
     '\nModel Name & Count MAE & Wass dist & bleu_score & opt_loss & cont_loss & count_loss',
     ('count_mae_rh', 'wass_dist_rh', 'bleu_score_rh', 'opt_loss', 'cont_loss', 'count_loss')),
    ('\n\nAll metrics in Forecast Horizon:',
     '\nModel Name & Count MAE & Wass Dist & bleu score',
     ('count_mae_fh', 'wass_dist_fh', 'bleu_score_fh')),
)


def _write_results_tables(fp, results):
    """Write every table in ``_RESULT_TABLES`` plus a per-model metric dump to ``fp``.

    ``results`` maps model name -> metrics dict. Every metric value must be
    numeric, because each one is formatted with ``{:.3f}``.
    """
    for title, header, keys in _RESULT_TABLES:
        fp.write(title)
        fp.write(header)
        # e.g. '\n & name & 0.123 & 4.560 \\' -- one LaTeX row per model.
        row_fmt = '\n & {}' + ' & {:.3f}' * len(keys) + ' \\\\'
        for model_name, metrics_dict in results.items():
            fp.write(row_fmt.format(model_name, *[metrics_dict[key] for key in keys]))

    fp.write('\n')
    for model_name, metrics_dict in results.items():
        fp.write('\n {}'.format(model_name))
        for metric, val in metrics_dict.items():
            fp.write('\n {}: {:.3f}'.format(metric, val))
        fp.write('\n')


def _results_as_strings(results):
    """Return a JSON-ready copy of ``results`` with every metric value stringified.

    ``results`` itself is left untouched. It is reused for later datasets, and
    stringified values would break the numeric formatting above.
    """
    return {model_name: {metric: str(val) for metric, val in metrics_dict.items()}
            for model_name, metrics_dict in results.items()}


results = dict()
# The Trump dataset overrides comp_enc_len; remember the command-line value so
# the override does not leak into datasets processed after it.
default_comp_enc_len = args.comp_enc_len
for dataset_name in dataset_names:
    print("Processing", dataset_name, "Datasets\n")
    args.current_dataset = dataset_name
    args.comp_enc_len = 25 if dataset_name == 'Trump' else default_comp_enc_len
    if automate_bin_sz:
        if dataset_name in ['Trump', 'sin']:
            args.bin_size = utils.get_optimal_bin_size(dataset_name)
        else:
            args.bin_size = utils.find_best_bin_size(dataset_name)
        print('New bin size is', args.bin_size, 'sec')
    dataset = utils.get_processed_data(dataset_name, args)

    # Ground-truth test counts; kept for the (currently disabled) count plots.
    per_model_count = {'true': dataset['count_test_out_counts']}
    # Trained models so far on this dataset, passed to run.run_model as prev_models.
    per_model_save = {
        'wgan': None,
        'seq2seq': None,
        'transformer': None,
        'count_model': None,
        'hierarchical': None,
        'rmtpp_mse': None,
        'rmtpp_nll': None,
        'rmtpp_mse_var': None,
        'inference_models': None,
    }
    for model_name in model_names:
        print("--------------------------------------------------------------------")
        args.current_model = model_name
        print("Running", model_name, "Model\n")

        model, count_dist_params, rmtpp_var_model, results \
            = run.run_model(dataset_name,
                            model_name,
                            dataset,
                            args,
                            results,
                            prev_models=per_model_save,
                            run_model_flags=run_model_flags)

        per_model_save[model_name] = model

        print('Got result', 'for model', model_name, 'on dataset', dataset_name)

    # TODO: Generate count prediction plots (utils.generate_plots).
    print("####################################################################")

    with open(os.path.join(args.output_dir, 'results_'+dataset_name+'.txt'), 'w') as fp:
        _write_results_tables(fp, results)

    with open(os.path.join(args.output_dir, 'results_'+dataset_name+'.json'), 'w') as fp:
        json.dump(_results_as_strings(results), fp)
461 |
--------------------------------------------------------------------------------
/modules/Hawkes/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import sys,time,datetime,copy,subprocess,itertools,pickle,warnings,numbers
5 |
6 | import numpy as np
7 | import scipy as sp
8 | import pandas as pd
9 | from matplotlib import pyplot as plt
10 | import matplotlib as mpl
11 | import matplotlib.gridspec as gridspec
12 |
13 | from scipy.special import gamma,digamma
14 |
15 | from .tools import Quasi_Newton,merge_stg,loglinear_COS,plinear
16 |
# Try to compile/import the Cython kernels; fall back to pure Python otherwise.
try:
    import pyximport
    pyximport.install(setup_args={'include_dirs': np.get_include()},language_level=2)
    from .Hawkes_C import LG_kernel_SUM_exp_cython, LG_kernel_SUM_pow_cython, preprocess_data_nonpara_cython, LG_kernel_SUM_nonpara_cython
    cython_import = True
except Exception:
    # Build or import of the extension failed. `except Exception` (rather than a
    # bare except) lets KeyboardInterrupt/SystemExit during a slow compile propagate.
    cython_import = False
25 |
26 |
27 | ##########################################################################################################
28 | ## class
29 | ##########################################################################################################
30 | class base_class:
31 |
32 | ### initialize
33 | def set_kernel(self,type,**kwargs):
34 | kernel_class = {'exp':kernel_exp, 'pow':kernel_pow, 'nonpara':kernel_nonpara}
35 | self.kernel = kernel_class[type](**kwargs)
36 | return self
37 |
38 | def set_baseline(self,type,**kwargs):
39 | baseline_class = {'const':baseline_const,'loglinear':baseline_loglinear,'plinear':baseline_plinear,'custom':baseline_custom}
40 | self.baseline = baseline_class[type](**kwargs)
41 | return self
42 |
43 | def set_parameter(self,para):
44 | self.para = para
45 | self.baseline.set_parameter(para)
46 | self.kernel.set_parameter(para)
47 | return self
48 |
49 | def set_data(self,Data,itv):
50 | st,en = itv
51 | T = Data['T']
52 | T = T[ (sten) or (i==N_MAX):
133 | break
134 |
135 | if np.random.rand() < l1/l0: ## Fire
136 | T[i] = x
137 | i += 1
138 | l_kernel_sequential.event()
139 |
140 | l0 = l_baseline(x) + l_kernel_sequential.l
141 |
142 | T = T[:i]
143 |
144 | return T
145 |
146 | class estimator(base_class):
147 |
148 | def fit(self,T,itv,prior=[],opt=[],merge=[]):
149 | T = np.array(T); T = T[(itv[0]1 else np.inf
727 | return br
728 |
729 | def sequential(self,mode='simulation'):
730 | para = self.para
731 | return kernel_sequential_pow(para,mode=mode)
732 |
class kernel_sequential_pow(kernel_sequential):
    """Event-by-event evaluator of the power-law kernel contribution.

    The kernel is represented as a finite mixture of exponentials: with rates
    ``phi`` and weights ``H`` (built in ``__init__``), the response at lag t is
    ``sum_j H[j] * exp(-phi[j] * t)``. The weights discretise
    ``k / gamma(p) * integral_0^inf phi**(p-1) * exp(-c*phi) * exp(-phi*t) dphi``
    (which equals ``k * (t + c)**(-p)``) on the grid ``phi = exp(s - exp(-s))``,
    with ``s`` uniformly spaced by 1/16 on [-9, 9].

    State:
        g:   per-rate sum over past events of exp(-phi[j] * lag).
        l:   current kernel contribution, ``g.dot(H)``.
        Int: integral of the kernel contribution over the last ``step_forward``.
        dl:  (mode == 'estimation' only) gradients of ``l`` w.r.t. k, p, c.
    """

    def __init__(self,para,mode='simulation'):
        """Build the quadrature from ``para['k']``, ``para['p']``, ``para['c']``.

        ``mode`` is 'simulation' or 'estimation'; the latter also prepares the
        derivative weights needed for ``dl``.
        """
        self.mode = mode
        k = para['k']; p = para['p']; c = para['c'];
        # Quadrature grid: s on [-9, 9] with step delta, mapped to rates phi.
        num_div = 16
        delta = 1.0/num_div
        s = np.linspace(-9,9,num_div*18+1)
        log_phi = s-np.exp(-s)
        # log(dphi/ds) = log(phi) + log(1 + exp(-s))
        log_dphi = log_phi + np.log(1+np.exp(-s))
        phi = np.exp(log_phi) # phi = np.exp(s-np.exp(-s))
        # Weights delta * k * phi**(p-1) * exp(-c*phi) * dphi/ds / gamma(p), in log space.
        H = delta * k * np.exp( log_dphi + (p-1)*log_phi - c*phi ) / gamma(p)
        g = np.zeros_like(s)
        self.g = g
        self.l = 0
        self.Int = 0
        self.phi = phi
        self.H = H

        if self.mode == 'estimation':
            # Partial derivatives of H: dH/dk = H/k, dH/dp = H*(log phi - digamma(p)),
            # dH/dc = -phi*H.
            H_k = delta * np.exp( log_dphi + (p-1)*log_phi - c*phi ) / gamma(p)
            H_p = delta * k * np.exp( log_dphi + (p-1)*log_phi - c*phi ) / gamma(p) * (log_phi-digamma(p))
            H_c = delta * k * np.exp( log_dphi + p*log_phi - c*phi ) / gamma(p) * (-1)
            self.H_k = H_k
            self.H_p = H_p
            self.H_c = H_c
            self.dl = {'k':0, 'p':0, 'c':0}

    def step_forward(self,step):
        """Advance the state by ``step`` time units with no event in between.

        Sets ``Int`` to the integral of the kernel contribution over the step,
        decays ``g`` by exp(-phi*step) and recomputes ``l`` (and ``dl`` in
        estimation mode). Returns ``self``.
        """
        g = self.g; phi = self.phi; H = self.H;
        # v = (1 - exp(-phi*step)) / phi, replaced by its small-argument limit
        # `step` where phi*step < 1e-6. np.divide is given no `out=`, so the
        # masked entries are left unset; np.where discards them.
        index = phi*step<1e-6
        v = np.where(index,step,np.divide((1-np.exp(-phi*step)),phi,where=~index))
        Int = (g*v).dot(H) #(g*(1-np.exp(-phi*step))/phi).dot(H)
        g = g*np.exp(-phi*step)
        l = g.dot(H)
        self.g = g
        self.l = l
        self.Int = Int

        if self.mode == 'estimation':
            H_k = self.H_k; H_p = self.H_p; H_c = self.H_c;
            dl = {'k':g.dot(H_k),'p':g.dot(H_p),'c':g.dot(H_c)}
            self.dl = dl

        return self

    def event(self):
        """Register an event at the current time: every g[j] increases by 1.

        Updates ``l`` only; in estimation mode ``dl`` is not refreshed here and
        is recomputed by the next ``step_forward``. Returns ``self``.
        """
        g = self.g; H = self.H;
        g = g+1.0
        l = g.dot(H)
        self.g = g
        self.l = l

        return self
787 |
788 | ###############################
class kernel_nonpara(base_component_kernel_class):
    """Non-parametric Hawkes kernel, piecewise constant in the lag.

    The lag range ``[0, support)`` is split into ``num_bin`` equal bins with
    edges ``bin_edge``. The kernel equals ``para['g'][j]`` on bin j and is zero
    for negative lags and for lags >= ``support`` (see ``func``).
    """

    def __init__(self,support,num_bin):
        """Create the kernel with ``num_bin`` equal-width bins on ``[0, support]``."""
        self.type = 'nonpara'
        self.support = support
        self.num_bin = num_bin
        self.bin_edge = np.linspace(0,support,num_bin+1)
        self.bin_width = support/num_bin
        self.para_list = ['g']
        self.has_sequential = False

    def set_data(self,Data,itv):
        """Attach event data and precompute the per-bin statistics.

        ``Data['T']`` holds the event times and ``itv[1]`` is passed as the end
        of the observation window. The statistics come from the Cython helper
        ``preprocess_data_nonpara_cython``. Returns ``self``.

        Raises:
            ImportError: if the compiled Hawkes_C extension is unavailable.
        """
        if not cython_import:
            raise ImportError(
                "kernel_nonpara requires the compiled Hawkes_C extension "
                "(preprocess_data_nonpara_cython), which failed to import")
        self.Data = Data
        self.itv = itv
        dl,dInt = preprocess_data_nonpara_cython(Data['T'],self.bin_edge,itv[1])
        self.dl = dl
        self.dInt = dInt
        return self

    def prep_fit(self):
        """Return the fitting specification for the bin heights ``g``.

        The heights are optimised in exponential form ('exp': True), start at
        0.5/support and use the given per-bin step sizes.
        """
        num_bin = self.num_bin
        support = self.support
        para_names = ['g']
        length = {'g': num_bin }
        exp = {'g': True }
        ini = {'g': np.ones(num_bin)*0.5/support}
        step_Q = {'g': np.ones(num_bin)*0.2}
        step_diff = {'g': np.ones(num_bin)*0.01}
        return {"list":para_names,'length':length,'exp':exp,'ini':ini,'step_Q':step_Q,'step_diff':step_diff}

    def LG_SUM(self):
        """Return ``[l, {'g': dl}]``: the kernel sums from the precomputed stats and their gradient."""
        g = self.para['g']
        l = np.dot(g,self.dl)
        dl = {'g':self.dl}
        return [l,dl]

    def LG_INT(self):
        """Return ``[Int, {'g': dInt}]``: the kernel integral from the precomputed stats and its gradient."""
        g = self.para['g']
        Int = g.dot(self.dInt)
        dInt = {'g':self.dInt}
        return [Int,dInt]

    def func(self,x):
        """Kernel value at lag(s) ``x``; 0 outside ``[0, support)``."""
        bin_edge = self.bin_edge
        g = self.para['g']
        # Index -1 (x < 0) and index num_bin (x >= support) both hit the padded 0.
        g_ext = np.hstack([g,0])
        l = g_ext[ np.searchsorted(bin_edge,x,side='right') - 1 ]
        return l

    def d_func(self,x):
        """Gradient of ``func(x)`` w.r.t. ``g``; shape ``(num_bin, len(x))``."""
        bin_edge = self.bin_edge
        dl = np.zeros((bin_edge.shape[0],x.shape[0]))
        dl[np.searchsorted(bin_edge,x,side='right')-1,np.arange(x.shape[0])] = 1.0
        dl = {'g': dl[:-1] }
        return dl

    def _cum_int(self,x):
        """Integral of the kernel over lags ``[0, x]`` (elementwise).

        ``x`` is clipped to ``[0, support]`` first, matching ``func``: negative
        lags contribute nothing and lags beyond ``support`` add nothing more.
        Without the clip, a negative ``x`` indexed the last cumulative entry and
        returned the full kernel mass.
        """
        bin_edge = self.bin_edge
        g = self.para['g']
        g_ext = np.hstack([g,0])
        cum = np.hstack([0,g.cumsum()*self.bin_width])
        x = np.clip(x,0,self.support)
        index = np.searchsorted(bin_edge,x,side='right') - 1
        return cum[index] + (x-bin_edge[index]) * g_ext[index]

    def int(self,x1,x2):
        """Integral of the kernel over lags ``[x1, x2]`` (elementwise)."""
        return self._cum_int(x2) - self._cum_int(x1)

    def _d_cum_int(self,x):
        """Gradient of ``_cum_int(x)`` w.r.t. ``g``; shape ``(num_bin, len(x))``.

        Column i holds ``bin_width`` for bins entirely below ``x[i]``, the
        partial width ``x[i] - bin_edge[j]`` for the bin j containing ``x[i]``,
        and 0 for bins above it. ``x`` is clipped to ``[0, support]`` as in
        ``_cum_int``.
        """
        bin_edge = self.bin_edge
        x = np.clip(np.asarray(x,dtype=float),0,self.support)
        index = np.searchsorted(bin_edge,x,side='right') - 1
        rows = np.arange(self.num_bin)[:,None]
        partial = x - bin_edge[index]
        return np.where(rows < index, self.bin_width, np.where(rows == index, partial, 0.0))

    def d_int(self,x1,x2):
        """Gradient of ``int(x1, x2)`` w.r.t. ``g``, as ``{'g': array (num_bin, len(x))}``."""
        return {'g': self._d_cum_int(x2) - self._d_cum_int(x1)}

    def branching_ratio(self):
        """Total kernel mass: sum of bin heights times the bin width."""
        br = self.para['g'].sum()*self.bin_width
        return br

    def plot(self):
        """Draw the step-function kernel on the current matplotlib axes."""
        bin_edge = self.bin_edge
        x = np.vstack([bin_edge[:-1],bin_edge[1:]]).transpose().flatten()
        y = np.repeat(self.para['g'],2)
        plt.plot(x,y,'k-')
876 |
877 | ###########################################################################################
878 | ###########################################################################################
879 | ## graph routine
880 | ###########################################################################################
881 | ###########################################################################################
def plot_N(T,itv):
    """Plot an event raster (top) and the counting function N(0,t) (bottom).

    Parameters:
        T: event times, presumably sorted within ``itv`` (not checked).
        itv: ``[st, en]`` time window shown on the x axis.

    Creates a new figure; it is neither saved nor shown here.
    """
    gs = gridspec.GridSpec(100,1)

    plt.figure(figsize=(4,5), dpi=100)
    mpl.rc('font', size=12, family='DejaVu Sans')
    mpl.rc('axes',titlesize=12)
    mpl.rc('pdf',fonttype=42)

    [st,en] = itv
    T = np.asarray(T,dtype=float)
    n = len(T)
    # Step curve: each event time appears twice so N jumps by one at that time.
    x = np.hstack([st,np.repeat(T,2),en])
    y = np.repeat(np.arange(n+1),2)

    # Raster: one vertical tick per event; NaN breaks the line between ticks.
    # Vectorised form also works for an empty T (np.hstack of [] would raise),
    # and uses np.nan because the np.NaN alias was removed in NumPy 2.0.
    tick_x = np.column_stack([T,T,np.full(n,np.nan)]).ravel()
    tick_y = np.tile([0.0,1.0,np.nan],n)
    plt.subplot(gs[0:10,0])
    plt.plot(tick_x,tick_y,'k-',linewidth=0.5)
    plt.xticks([])
    plt.xlim(itv)
    plt.ylim([0,1])
    plt.yticks([])
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)

    plt.subplot(gs[15:100,0])
    plt.plot(x,y,'k-',clip_on=False)
    plt.xlim(itv)
    plt.ylim([0,max(n,1)])  # avoid identical y-limits when there are no events
    plt.xlabel('time')
    plt.ylabel(r'$N(0,t)$')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
914 |
def plot_l(T, x, l, l_baseline):
    """Plot an event raster above an intensity curve.

    Opens a new 4x5 inch figure with two stacked panels:

    * top: one short vertical tick per event time in ``T``;
    * bottom: ``l`` (solid) and ``l_baseline`` (dotted) against ``x``, with
      the y-axis labelled as the conditional intensity lambda(t|H_t).

    Side effect: calls ``mpl.rc``, which changes the global font, the axes
    title size and the PDF font type.

    Args:
        T: 1-D sequence of event times. May be empty.
        x: 1-D grid of time points. It is presumably ascending; ``x[0]`` and
            ``x[-1]`` set the x-range.
        l: array of values at ``x``. Its maximum sets the top of the y-range.
        l_baseline: values at ``x`` drawn as a dotted reference line.
    """
    gs = gridspec.GridSpec(100, 1)

    plt.figure(figsize=(4, 5), dpi=100)
    mpl.rc('font', size=12, family='DejaVu Sans')
    mpl.rc('axes', titlesize=12)
    mpl.rc('pdf', fonttype=42)

    l_max = l.max()
    T = np.asarray(T, dtype=float)
    n = len(T)

    # Raster: segments (t,0)-(t,1), each followed by NaN so matplotlib breaks
    # the line between ticks. Uses np.nan because NumPy 2.0 removed np.NaN.
    # Builds the array without np.hstack over a list, which fails for an
    # empty T.
    tick_x = np.repeat(T, 3)
    tick_x[2::3] = np.nan
    tick_y = np.tile([0.0, 1.0, np.nan], n)

    plt.subplot(gs[0:10, 0])
    plt.plot(tick_x, tick_y, 'k-', linewidth=0.5)
    plt.xticks([])
    plt.xlim([x[0], x[-1]])
    plt.ylim([0, 1])
    plt.yticks([])
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)

    plt.subplot(gs[15:100, 0])
    plt.plot(x, l, 'k-', lw=1)
    plt.plot(x, l_baseline, 'k:', lw=1)
    plt.xlim([x[0], x[-1]])
    plt.ylim([0, l_max])
    plt.xlabel('time')
    plt.ylabel(r'$\lambda(t|H_t)$')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
946 |
def plot_N_pred(T, T_pred, itv, en_f):
    """Plot an observed counting process followed by predicted continuations.

    Opens a new 4x5 inch figure with two stacked panels over ``[st, en_f]``:

    * top: one short vertical tick per observed event time in ``T``;
    * bottom: the observed step function N(0, t) on ``[st, en]``, a dashed
      vertical line at ``en``, and one light-blue step path per entry of
      ``T_pred``. Each path starts at height ``len(T)`` at time ``en`` and
      runs to ``en_f``.

    Side effect: calls ``mpl.rc``, which changes the global font, the axes
    title size and the PDF font type.

    Args:
        T: observed event times. They are presumably sorted and inside
            ``itv``. May be empty.
        T_pred: sequence of event-time sequences, one per predicted path.
            May be empty.
        itv: ``[st, en]`` observation interval.
        en_f: right end of the plotted range.
    """
    gs = gridspec.GridSpec(100, 1)

    plt.figure(figsize=(4, 5), dpi=100)
    mpl.rc('font', size=12, family='DejaVu Sans')
    mpl.rc('axes', titlesize=12)
    mpl.rc('pdf', fonttype=42)

    [st, en] = itv
    T = np.asarray(T, dtype=float)
    n = len(T)
    x = np.hstack([st, np.repeat(T, 2), en])
    y = np.repeat(np.arange(n + 1), 2)
    # default=0 so an empty T_pred does not raise (np.max([]) did).
    n_pred_max = max((len(T_i) for T_i in T_pred), default=0)

    # Raster: segments (t,0)-(t,1), each followed by NaN so matplotlib breaks
    # the line between ticks. Uses np.nan because NumPy 2.0 removed np.NaN.
    # Builds the array without np.hstack over a list, which fails for an
    # empty T.
    tick_x = np.repeat(T, 3)
    tick_x[2::3] = np.nan
    tick_y = np.tile([0.0, 1.0, np.nan], n)

    plt.subplot(gs[0:10, 0])
    plt.plot(tick_x, tick_y, 'k-', linewidth=0.5)
    plt.xticks([])
    plt.xlim([st, en_f])
    plt.ylim([0, 1])
    plt.yticks([])
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)

    plt.subplot(gs[15:, 0])
    plt.plot(x, y, 'k-')
    plt.plot([en, en], [0, n + n_pred_max], 'k--')

    for T_i in T_pred:
        n_pred = len(T_i)
        # Each predicted path continues the observed count n from time en.
        x_pred = np.hstack([en, np.repeat(T_i, 2), en_f])
        y_pred = np.repeat(np.arange(n_pred + 1), 2) + n
        plt.plot(x_pred, y_pred, '-', color=[0.7, 0.7, 1.0], lw=0.5)

    plt.xlim([st, en_f])
    # max(..., 1) keeps the y-range non-degenerate when nothing is plotted.
    plt.ylim([0, max(n + n_pred_max, 1)])
    plt.xlabel('time')
    plt.ylabel(r'$N(0,t)$')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
988 |
def plot_KS(T_trans, itv_trans):
    """Kolmogorov-Smirnov plot of transformed event times.

    Opens a new 4x4 inch figure showing:

    * the empirical cumulative distribution of ``T_trans`` (count / n);
    * a shaded band of half-width ``1.36 / sqrt(n)`` around the diagonal from
      (0, 0) to (n, 1). 1.36 is the asymptotic 95% KS critical-value
      coefficient, matching the '95% interval' label.

    The title shows the p-value of ``scipy.stats.kstest`` comparing
    ``T_trans`` rescaled to [0, 1] against the uniform distribution.

    Side effect: calls ``mpl.rc``, which changes the global font, the axes
    title size and the PDF font type.

    Args:
        T_trans: 1-D sequence of transformed event times inside ``itv_trans``.
        itv_trans: ``[st, en]`` interval of the transformed time axis.

    Raises:
        ValueError: if ``T_trans`` is empty, because the band width and the
            normalised CDF both divide by ``n``.
    """
    from scipy.stats import kstest

    T_trans = np.asarray(T_trans, dtype=float)
    n = len(T_trans)
    if n == 0:
        raise ValueError('plot_KS requires at least one event')
    [st, en] = itv_trans

    plt.figure(figsize=(4, 4), dpi=100)
    mpl.rc('font', size=12, family='DejaVu Sans')
    mpl.rc('axes', titlesize=12)
    mpl.rc('pdf', fonttype=42)

    x = np.hstack([st, np.repeat(T_trans, 2), en])
    y = np.repeat(np.arange(n + 1), 2) / n
    w = 1.36 / np.sqrt(n)
    # Map [st, en] onto [0, 1] before testing against U(0, 1). For st == 0
    # this is exactly the previous T_trans / en.
    [_, pvalue] = kstest((T_trans - st) / (en - st), 'uniform')

    plt.plot(x, y, "k-", label='Data')
    # Band = diagonal x/n +/- w, clipped to [0, 1] (corner points only).
    # NOTE(review): the band and xlim are centred on x/n, whereas kstest uses
    # (x - st)/(en - st). They agree only when itv_trans is close to [0, n];
    # confirm this is intended.
    plt.fill_between([0, n * w, n * (1 - w), n], [0, 0, 1 - 2 * w, 1 - w],
                     [w, 2 * w, 1, 1], color="#dddddd", label='95% interval')
    plt.xlim([0, n])
    plt.ylim([0, 1])
    plt.ylabel("cumulative distribution function")
    plt.xlabel("transformed time")
    plt.title("p-value = %.3f" % pvalue)
    plt.legend(loc="upper left")
1012 |
--------------------------------------------------------------------------------