├── .github └── workflows │ ├── docs.yaml │ └── python-publish.yml ├── .gitignore ├── LICENCE ├── MANIFEST.in ├── NOTICE ├── README.md ├── docs ├── Makefile ├── README.md ├── images │ └── thinning_algo.jpg ├── make.bat └── source │ ├── advanced │ ├── implementation.rst │ ├── performance_valid.rst │ ├── tensorboard.rst │ └── thinning_algo.rst │ ├── conf.py │ ├── dev_guide │ └── model_custom.rst │ ├── get_started │ ├── install.rst │ ├── introduction.rst │ └── quick_start.rst │ ├── index.rst │ ├── ref │ ├── config.rst │ ├── hpo.rst │ ├── models.rst │ ├── preprocess.rst │ ├── runner.rst │ ├── utils.rst │ └── wrapper.rst │ └── user_guide │ ├── dataset.rst │ ├── run_eval.rst │ └── run_train_pipeline.rst ├── easy_tpp ├── __init__.py ├── config_factory │ ├── __init__.py │ ├── config.py │ ├── data_config.py │ ├── hpo_config.py │ ├── model_config.py │ └── runner_config.py ├── default_registers │ ├── __init__.py │ ├── register_metrics.py │ └── register_optuna_trials.py ├── hpo │ ├── __init__.py │ ├── base_hpo.py │ └── optuna_hpo.py ├── model │ ├── __init__.py │ ├── tf_model │ │ ├── __init__.py │ │ ├── tf_anhn.py │ │ ├── tf_attnhp.py │ │ ├── tf_baselayer.py │ │ ├── tf_basemodel.py │ │ ├── tf_fullynn.py │ │ ├── tf_intensity_free.py │ │ ├── tf_nhp.py │ │ ├── tf_ode_tpp.py │ │ ├── tf_rmtpp.py │ │ ├── tf_sahp.py │ │ ├── tf_thinning.py │ │ └── tf_thp.py │ └── torch_model │ │ ├── __init__.py │ │ ├── torch_anhn.py │ │ ├── torch_attnhp.py │ │ ├── torch_baselayer.py │ │ ├── torch_basemodel.py │ │ ├── torch_fullynn.py │ │ ├── torch_intensity_free.py │ │ ├── torch_nhp.py │ │ ├── torch_ode_tpp.py │ │ ├── torch_rmtpp.py │ │ ├── torch_sahp.py │ │ ├── torch_thinning.py │ │ └── torch_thp.py ├── preprocess │ ├── __init__.py │ ├── data_collator.py │ ├── data_loader.py │ ├── dataset.py │ └── event_tokenizer.py ├── runner │ ├── __init__.py │ ├── base_runner.py │ └── tpp_runner.py ├── tf_wrapper.py ├── torch_wrapper.py └── utils │ ├── __init__.py │ ├── const.py │ ├── gen_utils.py │ ├── generic.py │ ├── import_utils.py │ ├── log_utils.py │ ├── metrics.py │ ├── misc.py │ ├── multiprocess_utils.py │ ├── ode_utils.py │ ├── registrable.py │ ├── tf_utils.py │ └── torch_utils.py ├── examples ├── configs │ ├── experiment_config.yaml │ └── hpo_config.yaml ├── data │ └── .gitkeep ├── data_inspection │ ├── config.yaml │ └── data_inspection.py ├── data_loader.py ├── event_tokenizer.py ├── gen_synthetic_data.py ├── hf_data_loader.py ├── script_data_processing │ ├── earthquake.py │ ├── make_hf_dataset.py │ ├── taobao.py │ ├── taxi.py │ └── volcano.py ├── train_experiment │ ├── retweet_config.yaml │ └── run_retweet.py ├── train_nhp.py ├── train_nhp_hpo.py ├── train_nhp_omegaconf.py └── train_nhp_with_features.py ├── notebooks ├── easytpp_1_dataset.ipynb ├── easytpp_2_tfb_wb.ipynb └── easytpp_3_train_eval.ipynb ├── requirements-doc.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── synthetic_data.json ├── test_data_loader.py └── test_nhp.py └── version.py /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | release: 9 | types: [ published ] 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | with: 19 | fetch-depth: 0 20 | - name: Set up Python 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: '3.8' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip 
install --upgrade pip setuptools wheel 27 | sudo apt-get update 28 | sudo apt-get install openjdk-11-jdk 29 | sudo apt-get install pandoc 30 | - name: Build Sphinx docs 31 | run: | 32 | pip install tensorflow==2.2.0 33 | pip install torch 34 | pip install pandas 35 | pip install numpy 36 | pip install -r requirements-doc.txt 37 | cd docs 38 | make html 39 | # Publish built docs to gh-pages branch. 40 | # =============================== 41 | - name: Commit documentation changes 42 | run: | 43 | git clone https://github.com/ant-research/EasyTemporalPointProcess.git --branch gh-pages --single-branch gh-pages 44 | cp -r docs/build/html/* gh-pages/ 45 | cd gh-pages 46 | touch .nojekyll 47 | git config --local user.email "action@github.com" 48 | git config --local user.name "GitHub Action" 49 | git add . 50 | git commit -m "Update documentation" -a || true 51 | # The above command will fail if no changes were present, so we ignore 52 | # that. 53 | - name: Push changes 54 | uses: ad-m/github-push-action@master 55 | with: 56 | branch: gh-pages 57 | directory: gh-pages 58 | github_token: ${{ secrets.GITHUB_TOKEN }} 59 | # =============================== 60 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up Python 26 | uses: actions/setup-python@v2 27 | with: 28 | python-version: '3.9' 29 | - name: Install dependencies 30 | run: | 31 | pip install -r requirements.txt 32 | pip install wheel 33 | - name: Build package 34 | run: python setup.py sdist bdist_wheel 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python build 2 | build 3 | dist 4 | easy_tpp.egg-info 5 | 6 | # python temp 7 | *.pyc 8 | 9 | # proto 10 | protoc 11 | protoc-3.4.0.tar.gz 12 | *_pb2.py 13 | 14 | # misc 15 | experiments 16 | log 17 | *.swp 18 | *.swo 19 | .vscode 20 | 21 | 22 | 23 | 24 | # idea files 25 | .idea 26 | 27 | examples/checkpoints/* 28 | 29 | notebooks/checkpoints/* -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include version.py -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | ============================================================= 2 | EasyTPP is a open source tool developed by Machine Intelligence Team 3 | Copyright (c) 2020-2022, Ant Group Holding Limited. 4 | Licensed under the Apache License, Version 2.0 5 | 6 | ============================================================= 7 | This toolkit contains various third-party components under 8 | different open source licenses 9 | 10 | ----------------------------- 11 | Training evaluation pipeline 12 | Apache License, Version 2.0 13 | FuxiCTR authors 14 | 15 | ---------------------------- 16 | Training evaluation pipeline 17 | Apache License, Version 2.0 18 | EasyNLP, Alibaba Inc. 19 | 20 | ---------------------------- 21 | Tokenizer and DataLoader 22 | Apache License, Version 2.0 23 | The HuggingFace Inc. team -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation for EasyTPP 2 | 3 | This contains the full documentation of EasyTPP, which is hosted at github and can be updated manually (for releases) 4 | by pushing to the gh-pages branch. 5 | 6 | 7 | To generate the documentation locally, type 8 | 9 | ``` 10 | pip install -r requirements-doc.txt 11 | cd docs 12 | make html 13 | ``` -------------------------------------------------------------------------------- /docs/images/thinning_algo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/EasyTemporalPointProcess/7e2b7a001a293c506bd595e8ddb72d83967c2cb2/docs/images/thinning_algo.jpg -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/advanced/implementation.rst: -------------------------------------------------------------------------------- 1 | =================================== 2 | Model Implementation Details 3 | =================================== 4 | 5 | Basic structure 6 | =================================== 7 | 8 | In the model folder, `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**) implements functionalities of computing loglikelihood and sampling procedures that are common 9 | to all the TPP models. In the inherited class, models with specific structures are defined, explained in below sections. 10 | 11 | 12 | Computing the loglikelihood of non-pad event sequence 13 | ------------------------------------------------------ 14 | 15 | The loglikelihood computation, following the definition in Equation 8 of `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process `_, is shared by all the TPP models. 16 | 17 | it takes `time_delta_seqs`, `lambda_at_event`, `lambdas_loss_samples`, `seq_mask`, 18 | `lambda_type_mask` as the input and output the loglikelihood items, please see `torch_basemodel` (**/model/torch_model/torch_basemodel.py**) / `tf_basemodel` (**/model/tf_model/tf_basemodel.py**) 19 | for details. 20 | 21 | It is noted that: 22 | 23 | 1. 
Sequential prediction: because we performance sequential prediction, i.e., predict next one given previous, we do not consider the last one as it has no labels. To implement the `forward` function, we take input of `time_seqs[:, :-1]` 24 | and `type_seqs[:, :-1]`. For `time_delta_seqs` it is different; please see the next point. 25 | 26 | 27 | 28 | 2. Continuous-time evolution: recall the definition in [dataset](./dataset.rst), assume we have a sequence of 4 events and 1 pad event 29 | at the end, i.e., 30 | 31 | .. code-block:: bash 32 | 33 | index: 0, 1, 2, 3, 4 34 | dtimes: 0, t_1-t_0, t_2-t_1, t_3-t_2, pad 35 | types: e_0, e_1, e_2, e_3, pad 36 | non_pad_mask: True, True, True, True, False 37 | 38 | For the i-th event, i-th dtime denotes the time evolution (e.g., decay in NHP) to the current event and 39 | (i+1)-th dtime denotes the time evolution to the next event. To compute the non-event loglikelihood, 40 | we should consider the time evolution after the event happens. Therefore we should use `type_delta_seqs[:, 1:]` with masks specified in the below step. 41 | 42 | 3. Masking: suppose we have predictions of 0,1,2,3-th event and their labels are 1,2,3,4-th events 43 | where $4$-th event needed to be masked. So we should set the sequence mask as `True, True, True, False`, i.e., `seq_mask=batch_non_pad_mask[:, 1:]`. 44 | The same logic applies to the attention mask and event type mask. 45 | 46 | Therefore the following code is a typical example of calling the loglikelihood computation: 47 | 48 | 49 | .. code-block:: python 50 | 51 | event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, # seq_len = max_len - 1 52 | lambdas_loss_samples=lambda_t_sample, # seq_len = max_len - 1 53 | time_delta_seq=time_delta_seq[:, 1:], 54 | seq_mask=batch_non_pad_mask[:, 1:], 55 | lambda_type_mask=type_mask[:, 1:]) 56 | 57 | 58 | 59 | Computing the integral inside the loglikelihood 60 | ----------------------------------------------- 61 | 62 | 63 | The loglikelihood of the parameters is the sum of the log-intensities of the events that happened, at the times they happened, 64 | minus an integral of the total intensities over the observation interval over [0,T]: 65 | 66 | .. math:: 67 | 68 | \sum_{t_i}\log \lambda_{k_i}(t_i) - \int_0^T \lambda(t) dt 69 | 70 | The first term refers to event loglikelihood and the second term (including the negative sign) refers to the non-event loglikelihood. 71 | 72 | 73 | 74 | 75 | 76 | 77 | Neural Hawkes Process (NHP) 78 | =================================== 79 | 80 | We implement NHP based on author's official pytorch code `Github:nce-mpp `_. 81 | 82 | 1. A continuous-time LSTM is introduced, with the code mainly come from `Github:nce-mpp `_. 83 | 2. A `forward` function in NHP class that recursively update the states: we compute the event embedding, pass to the LSTM cell and then decay afterwards. Noted that for i-th event, we should use (i+1)-th dt for the decay. So we do not consider the last event as it has no decay time. 84 | 85 | Attentive Neural Hawkes Process (AttNHP) 86 | ======================================== 87 | 88 | 89 | We implement AttNHP based on the authors' official pytorch code `Github:anhp-andtt `_ 90 | and similar to NHP, we factorize it into based model and inherited model. 91 | 92 | The forward functions is implemented faithfully to that of the author's repo. 
93 | 94 | 95 | Transformer Hawkes Process (THP) 96 | ======================================== 97 | 98 | We implement THP based on a fixed version of pytorch code `Github:anhp-andtt/thp `_ 99 | and we factorize it into based model and inherited model. 100 | 101 | 102 | Self-Attentive Hawkes Process (SAHP) 103 | ======================================== 104 | 105 | We implement SAHP based on a fixed version of pytorch code `Github:anhp-andtt/sahp `_ 106 | and we factorize it into based model and inherited model. 107 | 108 | `SAHP` basically shares very similar structure to that of `THP`. 109 | 110 | 111 | 112 | Recurrent Marked Temporal Point Processes (RMTPP) 113 | ==================================================== 114 | 115 | We implement RMTPP faithfully to the author's paper. 116 | 117 | 118 | Intensity Free Learning of Temporal Point Process (IntensityFree) 119 | ================================================================== 120 | 121 | We implement the model based on the author's torch code `Github:ifl-tpp `_. 122 | 123 | A small difference between our implementation and the author's is we ignore the `context_init` (the initial state of the RNN) because in our data setup, we do not need a learnable initial RNN state. This modification generally makes little impact on the learning process. 124 | 125 | It is worth noting that the thinning algorithm can not be applied to this model because it is intensity-free. When comparing the performance of the model, we only look at its log-likelihood learning curve. 126 | 127 | 128 | Fully Neural Network based Model for General Temporal Point Processes (FullyNN) 129 | =============================================================================== 130 | 131 | We implement the model based on the author's keras code `Github:NeuralNetworkPointProcess `_. 132 | 133 | 134 | ODE-based Temporal Point Process (ODETPP) 135 | ========================================= 136 | 137 | We implement a TPP with Neural ODE state evolution, which is a simplified version of `Neural Spatio-Temporal Point Processes `_. The ODE implementation uses the code from the `blog `_ 138 | 139 | 140 | Attentive Neural Hawkes Network (ANHN) 141 | ====================================== 142 | 143 | We implement the model based on the author's paper: the attentive model without the graph regularizer is named ANHN. 144 | -------------------------------------------------------------------------------- /docs/source/advanced/performance_valid.rst: -------------------------------------------------------------------------------- 1 | ========================================= 2 | Performance validation of EasyTPP models 3 | ========================================= 4 | 5 | We run the experiments on various dataset to validate the implementations: each model is trained with a max number of epochs and 6 | the best model is selected based on the performance on the valid set, then we report the results on the test set. 
7 | 8 | 9 | Simulated dataset 10 | --------------------------- 11 | Conttime 12 | ********************** 13 | 14 | 15 | 16 | +--------------+----------+----------+----------+--------------------+ 17 | | Models | Loglike | RMSE | Acc | Num Training Epochs| 18 | +==============+==========+==========+==========+====================+ 19 | | Torch_NHP | -0.93504 | 0.34000 | 0.38656 | 200 | 20 | +--------------+----------+----------+----------+--------------------+ 21 | | Tf_NHP | -0.85774 | 0.34014 | 0.38806 | 200 | 22 | +--------------+----------+----------+----------+--------------------+ 23 | | Torch_AttNHP | -1.02001 | 0.33678 | 0.36782 | 200 | 24 | +--------------+----------+----------+----------+--------------------+ 25 | | Tf_AttNHP | -1.02315 | 0.33816 | 0.19456 | 200 | 26 | +--------------+----------+----------+----------+--------------------+ 27 | | Torch_AttNHP | -1.00593 | 0.33685 | 0.37723 | 500 | 28 | +--------------+----------+----------+----------+--------------------+ 29 | | Tf_AttNHP | -0.99827 | 0.33717 | 0.36498 | 500 | 30 | +--------------+----------+----------+----------+--------------------+ 31 | | Torch_THP | -0.99827 | 0.33717 | 0.36498 | 500 | 32 | +--------------+----------+----------+----------+--------------------+ 33 | | Tf_THP | -1.01898 | 0.33677 | 0.37875 | 500 | 34 | +--------------+----------+----------+----------+--------------------+ 35 | 36 | 37 | 38 | ## Real dataset 39 | ### Taxi 40 | 41 | 42 | -------------------------------------------------------------------------------- /docs/source/advanced/tensorboard.rst: -------------------------------------------------------------------------------- 1 | =================================== 2 | Launching the Tensorboard 3 | =================================== 4 | 5 | 6 | Here we present how to launch the tensorboard within the ``EasyTPP`` framework. 7 | 8 | Step 1: Activate the usage of tensorboard in Config file 9 | ======================================================== 10 | 11 | 12 | As shown in `Training Pipeline <../get_started/run_train_pipeline.html>`_, we need to firstly initialize the 'model_config.yaml' file to setup the running config before training or evaluating the model. 13 | 14 | In the ``model config`` (`modeling` attribute of the config), one needs to set ``use_tfb`` to ``True`` in `trainer`. Then before the running process, summary writers tracking the performance on training and valid sets are both initialized. 15 | 16 | .. code-block:: yaml 17 | 18 | NHP_train: 19 | base_config: 20 | stage: train 21 | backend: torch 22 | dataset_id: taxi 23 | runner_id: std_tpp 24 | model_id: NHP # model name 25 | base_dir: './checkpoints/' 26 | trainer_config: 27 | batch_size: 256 28 | max_epoch: 200 29 | shuffle: False 30 | optimizer: adam 31 | learning_rate: 1.e-3 32 | valid_freq: 1 33 | use_tfb: True # Activate the tensorboard 34 | metrics: [ 'acc', 'rmse' ] 35 | seed: 2019 36 | gpu: -1 37 | model_config: 38 | hidden_size: 64 39 | loss_integral_num_sample_per_step: 20 40 | # pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model 41 | thinning: 42 | num_seq: 10 43 | num_sample: 1 44 | num_exp: 500 # number of i.i.d. 
Exp(intensity_bound) draws at one time in thinning algorithm 45 | look_ahead_time: 10 46 | patience_counter: 5 # the maximum iteration used in adaptive thinning 47 | over_sample_rate: 5 48 | num_samples_boundary: 5 49 | dtime_max: 5 50 | num_step_gen: 1 51 | 52 | 53 | 54 | Step 2: Launching the tensorboard 55 | ======================================================== 56 | 57 | 58 | We simply go to the output file of the training runner (its directory is specified in `base_dir` of ``base_config``), find out the tensorboard file address and launch it. 59 | 60 | A complete example of using tensorboard can be seen at *examples/run_tensorboard.py*. 61 | 62 | 63 | .. code-block:: python 64 | 65 | import os 66 | 67 | def main(): 68 | # one can find this dir in the config out file 69 | log_dir = './checkpoints/NHP_train_taxi_20220527-20:18:30/tfb_train' 70 | os.system('tensorboard --logdir={}'.format(log_dir)) 71 | return 72 | 73 | 74 | if __name__ == '__main__': 75 | main() -------------------------------------------------------------------------------- /docs/source/advanced/thinning_algo.rst: -------------------------------------------------------------------------------- 1 | ============================================== 2 | Thinning Algorithm for Sampling Event Sequence 3 | ============================================== 4 | 5 | In ``EasyTPP`` we use ``Thinning algorithm`` depicted in Algorithm 2 6 | in `The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process `_ 7 | for event sampling. 8 | 9 | The implementation of the algorithm 10 | ==================================== 11 | 12 | 13 | We implement the algorithm both in PyTorch and Tensorflow, as seen in *./model/torch_thinning.py* and 14 | *./model/tf_thinning.py*, which basically follow the same procedure. 15 | 16 | The corresponding code is in function ``draw_next_time_one_step``, which consists of the following steps: 17 | 18 | 1. Compute the upper bound of the intensity at each event timestamp in function ``compute_intensity_upper_bound``, where we sample some timestamps inside event intervals and output a upper bound intensity matrix [batch_size, seq_len], denoting the upper bound of prediced intensity (for next time interval) for each sequence at each timestamp. 19 | 2. Sample the exponential distribution with the intensity computed in Step 1 in function ``sample_exp_distribution``, where we simply divide the standard exponential number with the intensity, which is equivalent to sampling with exp(sample_rate), according to `the properties of Exponential Distribution `_. The exponential random variables have size [batch_size, seq_len, num_sample, num_exp], where num_sample refers to the number of event times sampled in every interval and num_exp refers to number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm. 20 | 3. Compute the intensities at the sample times proposed in Step 2, with final size `[batch_size, seq_len, num_sample, num_exp]`. 21 | 4. Sample the standard uniform distribution with size `[batch_size, seq_len, num_sample, num_exp]`. 22 | 5. Perform the acceptance sampling with certain probability in function ``sample_accept``. 23 | 6. The earliest sampling dtimes are accepted. For unaccepted sampling dtimes, use boundary/maxsampletime for that draw. 24 | 7. The final predicted dtimes has size `[batch_size, seq_len, num_sample]`, which refers to the sampling dtimes for each sequence at each timestamp, along with an equal weight vector. 25 | 8. 
The product of the predicted dtimes and the weight is the final predicted dtimes, with size `[batch_size, seq_len]`. 26 | 27 | 28 | .. image:: ../../images/thinning_algo.jpg 29 | :alt: thinning_algo 30 | 31 | 32 | 33 | One-step prediction 34 | ==================================== 35 | By default, once given the parameters of thinning algo (defining a ``thinning`` config as part of ``model_config``), we perform the one-step prediction in model evaluation, i.e., predict the next event given the prefix. The implementation is in function ``prediction_event_one_step`` in BaseModel (i.e., TorchBaseModel or TfBaseModel). 36 | 37 | 38 | Multi-step prediction 39 | ==================================== 40 | The recursive multi-step prediction is activated by setting `num_step_gen` to a number bigger than 1 in the ``thinning`` config. 41 | 42 | Be noted that, we generate the multi-step events after the last non-pad event of each sequence. The implementation is in function `predict_multi_step_since_last_event` in BaseModel (i.e., TorchBaseModel or TfBaseModel). 43 | 44 | 45 | .. code-block:: yaml 46 | 47 | thinning: 48 | num_seq: 10 49 | num_sample: 1 50 | num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm 51 | look_ahead_time: 10 52 | patience_counter: 5 # the maximum iteration used in adaptive thinning 53 | over_sample_rate: 5 54 | num_samples_boundary: 5 55 | dtime_max: 5 56 | num_step_gen: 5 # by default it is single step, i.e., 1 -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | 7 | # -- Autodoc information ----------------------------------------------------- 8 | # https://sphinx-rtd-tutorial.readthedocs.io/en/latest/sphinx-config.html 9 | 10 | 11 | import os 12 | import sys 13 | 14 | sys.path.insert(0, os.path.abspath('../../easy_tpp/')) 15 | 16 | sys.path.insert(0, os.path.abspath('../..')) 17 | 18 | # -- Project information ----------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 20 | 21 | project = 'EasyTPP' 22 | copyright = '2022, Machine Intelligence, Alipay' 23 | author = 'Machine Intelligence, Alipay' 24 | release = '0.0.2' 25 | 26 | # -- General configuration --------------------------------------------------- 27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 28 | 29 | extensions = [ 30 | "sphinx.ext.autodoc", 31 | 'sphinx.ext.viewcode', 32 | "sphinx.ext.todo", 33 | "sphinx.ext.mathjax", 34 | "sphinx.ext.napoleon", 35 | 'sphinx.ext.autosummary' 36 | ] 37 | 38 | napoleon_google_docstring = True 39 | napoleon_numpy_docstring = False 40 | 41 | templates_path = ['_templates'] 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 
44 | # This patterns also effect to html_static_path and html_extra_path 45 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 49 | 50 | html_theme = 'sphinx_rtd_theme' 51 | html_static_path = ['_static'] 52 | 53 | autodoc_member_order = "bysource" 54 | autodoc_default_flags = ["members"] 55 | autodoc_default_options = { 56 | "members": True, 57 | "member-order": "bysource", 58 | "special-members": "__init__", 59 | } 60 | -------------------------------------------------------------------------------- /docs/source/dev_guide/model_custom.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Customize a Model 3 | ================== 4 | 5 | 6 | Here we introduce how to customize a TPP model with the support of ``EasyTPP``. 7 | 8 | 9 | 10 | Create a new TPP Model Class 11 | ============================= 12 | 13 | Assume we are building a PyTorch model. We need to initialize the model by inheriting class `EasyTPP.model.torch_model.TorchBaseModel <../ref/models.html>`_. 14 | 15 | .. code-block:: python 16 | 17 | from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel 18 | 19 | # Custom Torch TPP implementations need to 20 | # inherit from the TorchBaseModel interface 21 | class NewModel(TorchBaseModel): 22 | def __init__(self, model_config): 23 | super(NewModel, self).__init__(model_config) 24 | 25 | # Forward along the sequence, output the states / intensities at the event times 26 | def forward(self, batch): 27 | ... 28 | return states 29 | 30 | # Compute the loglikelihood loss 31 | def loglike_loss(self, batch): 32 | .... 33 | return loglike 34 | 35 | # Compute the intensities at given sampling times 36 | # Used in the Thinning sampler 37 | def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs): 38 | ... 39 | return intensities 40 | 41 | 42 | If we are building a Tensorflow model, we start with the following code 43 | 44 | .. code-block:: python 45 | 46 | from easy_tpp.model.torch_model.tf_basemodel import TfBaseModel 47 | 48 | # Custom Tf TPP implementations need to 49 | # inherit from the TorchBaseModel interface 50 | class NewModel(TfBaseModel): 51 | def __init__(self, model_config): 52 | super(NewModel, self).__init__(model_config) 53 | 54 | # Forward along the sequence, output the states / intensities at the event times 55 | def forward(self, batch): 56 | ... 57 | return states 58 | 59 | 60 | # Compute the loglikelihood loss 61 | def loglike_loss(self, batch): 62 | .... 63 | return loglike 64 | 65 | # Compute the intensities at given sampling times 66 | # Used in the Thinning sampler 67 | def compute_intensities_at_sample_times(self, batch, sample_times, **kwargs): 68 | ... 69 | return intensities 70 | 71 | Rewrite Relevant Methods 72 | ============================== 73 | 74 | There are three important functions needed to be implemented: 75 | 76 | - `forward`: the input is the batch data and the output is states at each step. 77 | - `loglike_loss`: it computes the loglikihood loss given the batch data. 78 | - `compute_intensities_at_sample_times`: it computes the intensities at each sampling steps. 
79 | -------------------------------------------------------------------------------- /docs/source/get_started/install.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Installation 3 | ================== 4 | 5 | 6 | ``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction. 7 | 8 | 9 | Requirements 10 | ============= 11 | 12 | .. code-block:: bash 13 | 14 | PyTorch version >= 1.8.0 15 | Python version >= 3.7 16 | Tensorflow version >= 1.13.1 (only needed when using Tensorflow backend) 17 | 18 | 19 | 20 | First, we need a python environment whose version is at least greater than 3.7.0. If you don’t have one, please refer to the `Documentation `_ to install and configure the Anaconda environment. 21 | 22 | .. code-block:: bash 23 | 24 | conda create -n easytpp python=3.8 25 | conda activate easytpp 26 | 27 | Then, install Pytorch and keep the version at least greater than 1.8.0. 28 | 29 | .. code-block:: bash 30 | 31 | pip install torch 32 | 33 | By default, we assume to use PyTorch. If one wants to use Tensorflow backend, please install tensorflow additionally. Both Tensorflow 1.13.1 and 2.x are supported. 34 | 35 | .. code-block:: bash 36 | 37 | pip install tensorflow 38 | 39 | 40 | 41 | Install 42 | ===================== 43 | 44 | 45 | Install with pip 46 | -------------------------- 47 | 48 | 49 | .. code-block:: bash 50 | 51 | pip install easy-tpp 52 | 53 | 54 | Install with the source 55 | -------------------------- 56 | 57 | Setup from the source: 58 | 59 | .. code-block:: bash 60 | 61 | git clone https://github.com/ant-research/EasyTemporalPointProcess.git 62 | cd EasyTemporalPointProcess 63 | python setup.py install 64 | 65 | -------------------------------------------------------------------------------- /docs/source/get_started/introduction.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Introduction 3 | ================== 4 | 5 | 6 | ``EasyTPP`` provides an open-source library for `Neural TPP`, with a fully automated pipeline for model training and prediction. 7 | 8 | 9 | Framework 10 | ========= 11 | 12 | 13 | ``EasyTPP`` supports both Tensorflow and PyTorch: each model has two equivalent versions implemented in Tensorflow 1.13 and Pytorch 1.8 respectively. The data processing and model training / prediction pipeline are compatible with both Tensorflow and Pytorch as well. 14 | 15 | 16 | At the module level, ``EasyTPP`` is a package that consists of the following components, which are designed as loose-coupled modules that provide flexibility for users to develop customized functionalities. 17 | 18 | 19 | 20 | ======================== ============================================================================== 21 | Name Description 22 | ======================== ============================================================================== 23 | `Preprocess` module Provides data batch-wise padding, inter-time processing and other related work for raw sequence. 24 | 25 | `Model` module Implements a list of SOTA TPP models. Please refer to `Model Validation <../advanced/performance_valid.html>`_ for more details. 26 | 27 | `Config` module Encapsulate the construction of the configuration needed to run the pipeline. 28 | 29 | `Runner` module Controls the training and prediction pipeline. 
30 | ======================== ============================================================================== 31 | 32 | 33 | 34 | Install 35 | ========= 36 | 37 | ``EasyTPP`` can be installed either by pip or the source. By default it is built based on PyTorch. If one wants to run with the Tensorflow backend, one needs to install Tensorflow additionally. 38 | 39 | Please see `Installation <./install.html>`_ for details of requirement and installation. 40 | 41 | 42 | Prepare Data 43 | ============ 44 | 45 | By default, we use the data in Gatech format, i.e., each dataset is a dict containing the keys such as `time_since_last_event`, `time_since_start` and `type_event`. `Preprocess <../ref/preprocess.html>`_ module 46 | will preprocess the data and feed it into the model. 47 | 48 | 49 | An example of building a pseudo dataloader can be found at `examples `_. Please refer to `Datatset <../user_guide/dataset.html>`_ for more explanations of the `TPP` dataset iterator. 50 | 51 | 52 | Model Training and Prediction 53 | ============================== 54 | 55 | The training and prediction pipeline consists of two steps: 56 | 57 | 1. Setup the config file, which specifies the dataset dir, model params and pipeline settings. 58 | 2. Launch the python script to run the whole pipeline. 59 | 60 | Please see `Training Pipeline <../user_guide/run_train_pipeline.html>`_ and `Evaluation Pipeline <../user_guide/run_eval.html>`_ for more details. -------------------------------------------------------------------------------- /docs/source/get_started/quick_start.rst: -------------------------------------------------------------------------------- 1 | ==================== 2 | Quick Start 3 | ==================== 4 | 5 | 6 | We use the [Taxi]_ dataset as an example to show how to use ``EasyTPP`` to train a model. More details and results are provided in `Training Pipeline <../user_guide/run_train_pipeline.html>`_. 7 | 8 | 9 | Download Dataset 10 | =================== 11 | 12 | 13 | 14 | The Taxi dataset we used is preprocessed by `HYPRO `_ . You can either download the dataset (in pickle) from Google Drive `here `_ or the dataset (in json) from `HuggingFace `_. 15 | 16 | 17 | Note that if the data sources are pickle files, we need to write the data config (in `Example Config `_) in the following way 18 | 19 | .. code-block:: yaml 20 | 21 | data: 22 | taxi: 23 | data_format: pickle 24 | train_dir: ./data/taxi/train.pkl 25 | valid_dir: ./data/taxi/dev.pkl 26 | test_dir: ./data/taxi/test.pkl 27 | 28 | If we choose to directly load from HuggingFace, we can put it this way: 29 | 30 | .. code-block:: yaml 31 | 32 | data: 33 | taxi: 34 | data_format: json 35 | train_dir: easytpp/taxi 36 | valid_dir: easytpp/taxi 37 | test_dir: easytpp/taxi 38 | 39 | 40 | Meanwhile, it is also feasible to put the local directory of json files downloaded from HuggingFace in the config: 41 | 42 | .. code-block:: yaml 43 | 44 | data: 45 | taxi: 46 | data_format: json 47 | train_dir: ./data/taxi/train.json 48 | valid_dir: ./data/taxi/dev.json 49 | test_dir: ./data/taxi/test.json 50 | 51 | 52 | 53 | 54 | Setup the configuration file 55 | ============================== 56 | 57 | We provide a preset config file in `Example Config `_. The details of the configuration can be found in `Training Pipeline <../user_guide/run_train_pipeline.html>`_. 58 | 59 | 60 | 61 | 62 | Train the Model 63 | ========================= 64 | 65 | At this stage we need to write a script to run the training pipeline. 
There is a preset script `train_nhp.py `_ and one can simply copy it. 66 | 67 | Taking the pickle data source for example, after the setup of data, config and running script, the directory structure is as follows: 68 | 69 | .. code-block:: bash 70 | 71 | data 72 | |______taxi 73 | |____ train.pkl 74 | |____ dev.pkl 75 | |____ test.pkl 76 | 77 | configs 78 | |______experiment_config.yaml 79 | 80 | train_nhp.py 81 | 82 | 83 | 84 | The one can simply run the following command. 85 | 86 | 87 | .. code-block:: bash 88 | 89 | python train_nhp.py 90 | 91 | 92 | 93 | Reference 94 | ---------- 95 | 96 | .. [Taxi] 97 | 98 | .. code-block:: bash 99 | 100 | @misc{whong-14-taxi, 101 | title = {F{OIL}ing {NYC}’s Taxi Trip Data}, 102 | author={Whong, Chris}, 103 | year = {2014}, 104 | url = {https://chriswhong.com/open-data/foil_nyc_taxi/} 105 | } 106 | 107 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | =================================== 2 | ``EasyTPP`` Documentation 3 | =================================== 4 | 5 | 6 | ``EasyTPP`` is an easy-to-use development and application toolkit for `Neural Temporal Point Process `_ (*Neural TPP*), with key features in configurability, compatibility and reproducibility. We hope this project could benefit both researchers and practitioners with the goal of easily customized development and open benchmarking. 7 | 8 | 9 | 10 | .. toctree:: 11 | :hidden: 12 | 13 | .. toctree:: 14 | :maxdepth: 2 15 | :caption: GETTING STARTED 16 | 17 | Introduction 18 | Installation 19 | Quick Start 20 | 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | :caption: USER GUIDE 25 | 26 | Dataset 27 | Model Training 28 | Model Prediction 29 | 30 | .. toctree:: 31 | :maxdepth: 2 32 | :caption: DEVELOPER GUIDE 33 | 34 | Model Customization 35 | 36 | 37 | .. toctree:: 38 | :maxdepth: 2 39 | :caption: ADVANCED TOPICS 40 | 41 | Thinning Algorithm 42 | Tensorboard 43 | Performance Benchmarks 44 | Implementation Details 45 | 46 | .. toctree:: 47 | :maxdepth: 2 48 | :caption: API REFERENCE 49 | 50 | Config 51 | Preprocess 52 | Model 53 | Runner 54 | Hyper-parameter Optimization 55 | Tf and Torch Wrapper 56 | Utilities -------------------------------------------------------------------------------- /docs/source/ref/config.rst: -------------------------------------------------------------------------------- 1 | .. _api-config: 2 | 3 | EasyTPP Config Modules 4 | ============================ 5 | 6 | 7 | .. automodule:: config_factory 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/ref/hpo.rst: -------------------------------------------------------------------------------- 1 | .. _api-config: 2 | 3 | EasyTPP Config Modules 4 | ============================ 5 | 6 | 7 | .. automodule:: hpo 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/ref/models.rst: -------------------------------------------------------------------------------- 1 | .. _api-model: 2 | 3 | EasyTPP Models 4 | ==================== 5 | 6 | 7 | 8 | .. _api-tf_model: 9 | 10 | model.tf_model module 11 | ------------------------------ 12 | 13 | .. automodule:: easy_tpp.model.tf_model 14 | .. 
autosummary:: 15 | :toctree: ../generated/ 16 | 17 | tf_baselayer 18 | tf_basemodel 19 | tf_nhp 20 | tf_fullynn 21 | tf_intensity_free 22 | tf_ode_tpp 23 | tf_rmtpp 24 | tf_sahp 25 | tf_thp 26 | tf_attnhp 27 | tf_thinning 28 | 29 | 30 | .. _api-torch_model: 31 | 32 | model.torch_model module 33 | ------------------------------ 34 | 35 | .. automodule:: easy_tpp.model.torch_model 36 | .. autosummary:: 37 | :toctree: ../generated/ 38 | 39 | torch_baselayer 40 | torch_basemodel 41 | torch_nhp 42 | torch_fullynn 43 | torch_intensity_free 44 | torch_ode_tpp 45 | torch_rmtpp 46 | torch_sahp 47 | torch_thp 48 | torch_attnhp 49 | torch_thinning 50 | 51 | -------------------------------------------------------------------------------- /docs/source/ref/preprocess.rst: -------------------------------------------------------------------------------- 1 | .. _api-preprocess: 2 | 3 | EasyTPP Preprocess Modules 4 | ========================== 5 | 6 | 7 | .. automodule:: preprocess 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/ref/runner.rst: -------------------------------------------------------------------------------- 1 | .. _api-modelrunner: 2 | 3 | EasyTPP Model Runner Modules 4 | ============================ 5 | 6 | 7 | .. automodule:: runner 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/ref/utils.rst: -------------------------------------------------------------------------------- 1 | .. _api-util: 2 | 3 | EasyTPP Utilities Modules 4 | ========================== 5 | 6 | 7 | .. automodule:: utils 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/ref/wrapper.rst: -------------------------------------------------------------------------------- 1 | .. _api-wrapper: 2 | 3 | EasyTPP Tf and Torch Wrapper Modules 4 | ==================================== 5 | 6 | 7 | .. automodule:: tf_wrapper 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | 13 | 14 | .. automodule:: torch_wrapper 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/user_guide/run_eval.rst: -------------------------------------------------------------------------------- 1 | ================================ 2 | Evaluate a Model 3 | ================================ 4 | 5 | Step 1: Setup the config file 6 | =============================================== 7 | 8 | Same as in the training pipeline, firstly we need to initialize the task configuration in the config file. 9 | 10 | Similar to the setup in `Training Pipeline <./run_train_pipeline.html>`_, we set the `stage` to `eval` and pass the `pretrained_model_dir` to ``the model_config`` 11 | 12 | Note that the *pretrained_model_dir* can be found in the log of the training process. 13 | 14 | .. 
code-block:: yaml 15 | 16 | NHP_eval: 17 | base_config: 18 | stage: eval 19 | backend: torch 20 | dataset_id: taxi 21 | runner_id: std_tpp 22 | base_dir: './checkpoints/' 23 | model_id: NHP 24 | trainer_config: 25 | batch_size: 256 26 | max_epoch: 1 27 | model_config: 28 | hidden_size: 64 29 | use_ln: False 30 | seed: 2019 31 | gpu: 0 32 | pretrained_model_dir: ./checkpoints/26507_4380788096_231111-101848/models/saved_model # must provide this dir 33 | thinning: 34 | num_seq: 10 35 | num_sample: 1 36 | num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm 37 | look_ahead_time: 10 38 | patience_counter: 5 # the maximum iteration used in adaptive thinning 39 | over_sample_rate: 5 40 | num_samples_boundary: 5 41 | dtime_max: 5 42 | 43 | 44 | 45 | 46 | A complete example of these files can be seen at `examples/example_config.yaml `_ . 47 | 48 | 49 | Step 2: Run the evaluation script 50 | ================================= 51 | 52 | Same as in the training pipeline, we need to initialize a ``ModelRunner`` object to do the evaluation. 53 | 54 | The following code is an example, which is a copy from `examples/train_nhp.py `_ . 55 | 56 | 57 | .. code-block:: python 58 | 59 | import argparse 60 | 61 | from easy_tpp.config_factory import RunnerConfig 62 | from easy_tpp.runner import Runner 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser() 67 | 68 | parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml', 69 | help='Dir of configuration yaml to train and evaluate the model.') 70 | 71 | parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_eval', 72 | help='Experiment id in the config file.') 73 | 74 | args = parser.parse_args() 75 | 76 | config = RunnerConfig.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) 77 | 78 | model_runner = Runner.build_from_config(config) 79 | 80 | model_runner.run() 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | 86 | 87 | 88 | 89 | Checkout the output 90 | ==================== 91 | 92 | The evaluation result will be print in the console and saved in the logs whose directory is specified in the 93 | out config file, i.e.: 94 | 95 | .. 
code-block:: bash 96 | 97 | 'output_config_dir': './checkpoints/NHP_test_conttime_20221002-13:19:23/NHP_test_output.yaml' 98 | -------------------------------------------------------------------------------- /easy_tpp/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' -------------------------------------------------------------------------------- /easy_tpp/config_factory/__init__.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.config_factory.config import Config 2 | from easy_tpp.config_factory.data_config import DataConfig, DataSpecConfig 3 | from easy_tpp.config_factory.hpo_config import HPOConfig, HPORunnerConfig 4 | from easy_tpp.config_factory.runner_config import RunnerConfig, ModelConfig, BaseConfig 5 | 6 | __all__ = ['Config', 7 | 'DataConfig', 8 | 'DataSpecConfig', 9 | 'ModelConfig', 10 | 'BaseConfig', 11 | 'RunnerConfig', 12 | 'HPOConfig', 13 | 'HPORunnerConfig'] 14 | -------------------------------------------------------------------------------- /easy_tpp/config_factory/config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any 3 | from omegaconf import OmegaConf 4 | 5 | from easy_tpp.utils import save_yaml_config, Registrable, logger 6 | 7 | 8 | class Config(Registrable): 9 | 10 | def save_to_yaml_file(self, config_dir): 11 | """Save the config into the yaml file 'config_dir'. 12 | 13 | Args: 14 | config_dir (str): Target filename. 15 | 16 | Returns: 17 | """ 18 | yaml_config = self.get_yaml_config() 19 | OmegaConf.save(yaml_config, config_dir) 20 | 21 | @staticmethod 22 | def build_from_yaml_file(yaml_dir, **kwargs): 23 | """Load yaml config file from disk. 24 | 25 | Args: 26 | yaml_dir (str): Path of the yaml config file. 27 | 28 | Returns: 29 | EasyTPP.Config: Config object corresponding to cls. 30 | """ 31 | config = OmegaConf.load(yaml_dir) 32 | pipeline_config = config.get('pipeline_config_id') 33 | config_cls = Config.by_name(pipeline_config.lower()) 34 | logger.critical(f'Load pipeline config class {config_cls.__name__}') 35 | return config_cls.parse_from_yaml_config(config, **kwargs) 36 | 37 | @abstractmethod 38 | def get_yaml_config(self): 39 | """Get the yaml format config from self. 40 | 41 | Returns: 42 | """ 43 | pass 44 | 45 | @staticmethod 46 | @abstractmethod 47 | def parse_from_yaml_config(yaml_config): 48 | """Parse from the yaml to generate the config object. 49 | 50 | Args: 51 | yaml_config (dict): configs from yaml file. 52 | 53 | Returns: 54 | EasyTPP.Config: Config class for data. 55 | """ 56 | pass 57 | 58 | @abstractmethod 59 | def copy(self): 60 | """Get a same and freely modifiable copy of self. 61 | 62 | Returns: 63 | """ 64 | pass 65 | 66 | def __str__(self): 67 | """Str representation of the config. 68 | 69 | Returns: 70 | str: str representation of the dict format of the config. 71 | """ 72 | return str(self.get_yaml_config()) 73 | 74 | def update(self, config): 75 | """Update the config. 76 | 77 | Args: 78 | config (dict): config dict. 79 | 80 | Returns: 81 | EasyTPP.Config: Config class for data. 82 | """ 83 | logger.critical(f'Update config class {self.__class__.__name__}') 84 | return self.parse_from_yaml_config(config) 85 | 86 | def pop(self, key: str, default_var: Any): 87 | """pop out the key-value item from the config. 88 | 89 | Args: 90 | key (str): key name. 91 | default_var (Any): default value to pop. 
92 | 93 | Returns: 94 | Any: value to pop. 95 | """ 96 | return vars(self).pop(key) or default_var 97 | 98 | def get(self, key: str, default_var: Any): 99 | """Retrieve the key-value item from the config. 100 | 101 | Args: 102 | key (str): key name. 103 | default_var (Any): default value to pop. 104 | 105 | Returns: 106 | Any: value to get. 107 | """ 108 | return vars(self)[key] or default_var 109 | 110 | def set(self, key: str, var_to_set: Any): 111 | """Set the key-value item from the config. 112 | 113 | Args: 114 | key (str): key name. 115 | var_to_set (Any): default value to pop. 116 | 117 | Returns: 118 | Any: value to get. 119 | """ 120 | vars(self)[key] = var_to_set 121 | -------------------------------------------------------------------------------- /easy_tpp/config_factory/data_config.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.config_factory.config import Config 2 | 3 | 4 | class DataSpecConfig(Config): 5 | def __init__(self, **kwargs): 6 | """Initialize the Config class. 7 | """ 8 | self.num_event_types = kwargs.get('num_event_types') 9 | self.pad_token_id = kwargs.get('pad_token_id') 10 | self.padding_side = kwargs.get('padding_side') 11 | self.truncation_side = kwargs.get('truncation_side') 12 | self.padding_strategy = kwargs.get('padding_strategy') 13 | self.max_len = kwargs.get('max_len') 14 | self.truncation_strategy = kwargs.get('truncation_strategy') 15 | self.num_event_types_pad = self.num_event_types + 1 16 | self.model_input_names = kwargs.get('model_input_names') 17 | 18 | if self.padding_side is not None and self.padding_side not in ["right", "left"]: 19 | raise ValueError( 20 | f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" 21 | ) 22 | 23 | if self.truncation_side is not None and self.truncation_side not in ["right", "left"]: 24 | raise ValueError( 25 | f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}" 26 | ) 27 | 28 | def get_yaml_config(self): 29 | """Return the config in dict (yaml compatible) format. 30 | 31 | Returns: 32 | dict: config of the data specs in dict format. 33 | """ 34 | return { 35 | 'num_event_types': self.num_event_types, 36 | 'pad_token_id': self.pad_token_id, 37 | 'padding_side': self.padding_side, 38 | 'truncation_side': self.truncation_side, 39 | 'padding_strategy': self.padding_strategy, 40 | 'truncation_strategy': self.truncation_strategy, 41 | 'max_len': self.max_len 42 | } 43 | 44 | @staticmethod 45 | def parse_from_yaml_config(yaml_config): 46 | """Parse from the yaml to generate the config object. 47 | 48 | Args: 49 | yaml_config (dict): configs from yaml file. 50 | 51 | Returns: 52 | DataSpecConfig: Config class for data specs. 53 | """ 54 | return DataSpecConfig(**yaml_config) 55 | 56 | def copy(self): 57 | """Copy the config. 58 | 59 | Returns: 60 | DataSpecConfig: a copy of current config. 61 | """ 62 | return DataSpecConfig(num_event_types_pad=self.num_event_types_pad, 63 | num_event_types=self.num_event_types, 64 | event_pad_index=self.pad_token_id, 65 | padding_side=self.padding_side, 66 | truncation_side=self.truncation_side, 67 | padding_strategy=self.padding_strategy, 68 | truncation_strategy=self.truncation_strategy, 69 | max_len=self.max_len) 70 | 71 | 72 | @Config.register('data_config') 73 | class DataConfig(Config): 74 | def __init__(self, train_dir, valid_dir, test_dir, data_format, specs=None): 75 | """Initialize the DataConfig object. 
76 | 77 | Args: 78 | train_dir (str): dir of tran set. 79 | valid_dir (str): dir of valid set. 80 | test_dir (str): dir of test set. 81 | specs (dict, optional): specs of dataset. Defaults to None. 82 | """ 83 | self.train_dir = train_dir 84 | self.valid_dir = valid_dir 85 | self.test_dir = test_dir 86 | self.data_specs = specs or DataSpecConfig() 87 | self.data_format = train_dir.split('.')[-1] if data_format is None else data_format 88 | 89 | def get_yaml_config(self): 90 | """Return the config in dict (yaml compatible) format. 91 | 92 | Returns: 93 | dict: config of the data in dict format. 94 | """ 95 | return { 96 | 'train_dir': self.train_dir, 97 | 'valid_dir': self.valid_dir, 98 | 'test_dir': self.test_dir, 99 | 'data_format': self.data_format, 100 | 'data_specs': self.data_specs.get_yaml_config(), 101 | } 102 | 103 | @staticmethod 104 | def parse_from_yaml_config(yaml_config): 105 | """Parse from the yaml to generate the config object. 106 | 107 | Args: 108 | yaml_config (dict): configs from yaml file. 109 | 110 | Returns: 111 | EasyTPP.DataConfig: Config class for data. 112 | """ 113 | return DataConfig( 114 | train_dir=yaml_config.get('train_dir'), 115 | valid_dir=yaml_config.get('valid_dir'), 116 | test_dir=yaml_config.get('test_dir'), 117 | data_format=yaml_config.get('data_format'), 118 | specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs')) 119 | ) 120 | 121 | def copy(self): 122 | """Copy the config. 123 | 124 | Returns: 125 | EasyTPP.DataConfig: a copy of current config. 126 | """ 127 | return DataConfig(train_dir=self.train_dir, 128 | valid_dir=self.valid_dir, 129 | test_dir=self.test_dir, 130 | specs=self.data_specs) 131 | 132 | def get_data_dir(self, split): 133 | """Get the dir of the source raw data. 134 | 135 | Args: 136 | split (str): dataset split notation, 'train', 'dev' or 'valid', 'test'. 137 | 138 | Returns: 139 | str: dir of the source raw data file. 140 | """ 141 | split = split.lower() 142 | if split == 'train': 143 | return self.train_dir 144 | elif split in ['dev', 'valid']: 145 | return self.valid_dir 146 | else: 147 | return self.test_dir 148 | -------------------------------------------------------------------------------- /easy_tpp/config_factory/hpo_config.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.config_factory.config import Config 2 | from easy_tpp.config_factory.runner_config import RunnerConfig 3 | from easy_tpp.utils import parse_uri_to_protocol_and_path, py_assert 4 | 5 | 6 | class HPOConfig(Config): 7 | def __init__(self, framework_id, storage_uri, is_continuous, num_trials, num_jobs): 8 | """Initialize the HPO Config 9 | 10 | Args: 11 | framework_id (str): hpo framework id. 12 | storage_uri (str): result storage dir. 13 | is_continuous (bool): whether to continuously do the optimization. 14 | num_trials (int): num of trails used in optimization. 15 | num_jobs (int): num of the jobs. 16 | """ 17 | self.framework_id = framework_id or 'optuna' 18 | self.is_continuous = is_continuous if is_continuous is not None else True 19 | self.num_trials = num_trials or 50 20 | self.storage_uri = storage_uri 21 | self.num_jobs = num_jobs if num_jobs is not None else 1 22 | 23 | @property 24 | def storage_protocol(self): 25 | """Get the storage protocol 26 | 27 | Returns: 28 | str: the dir of the storage protocol. 
29 | """ 30 | storage_protocol, _ = parse_uri_to_protocol_and_path(self.storage_uri) 31 | return storage_protocol 32 | 33 | @property 34 | def storage_path(self): 35 | """Get the storage protocol 36 | 37 | Returns: 38 | str: the dir of the hpo data storage. 39 | """ 40 | _, storage_path = parse_uri_to_protocol_and_path(self.storage_uri) 41 | return storage_path 42 | 43 | def get_yaml_config(self): 44 | """Return the config in dict (yaml compatible) format. 45 | 46 | Returns: 47 | dict: config of the HPO specs in dict format. 48 | """ 49 | return { 50 | 'framework_id': self.framework_id, 51 | 'storage_uri': self.storage_uri, 52 | 'is_continuous': self.is_continuous, 53 | 'num_trials': self.num_trials, 54 | 'num_jobs': self.num_jobs 55 | } 56 | 57 | @staticmethod 58 | def parse_from_yaml_config(yaml_config, **kwargs): 59 | """Parse from the yaml to generate the config object. 60 | 61 | Args: 62 | yaml_config (dict): configs from yaml file. 63 | 64 | Returns: 65 | EasyTPP.HPOConfig: Config class for HPO specs. 66 | """ 67 | if yaml_config is None: 68 | return None 69 | else: 70 | return HPOConfig( 71 | framework_id=yaml_config.get('framework_id'), 72 | storage_uri=yaml_config.get('storage_uri'), 73 | is_continuous=yaml_config.get('is_continuous'), 74 | num_trials=yaml_config.get('num_trials'), 75 | num_jobs=yaml_config.get('num_jobs'), 76 | ) 77 | 78 | def copy(self): 79 | """Copy the config. 80 | 81 | Returns: 82 | EasyTPP.HPOConfig: a copy of current config. 83 | """ 84 | return HPOConfig( 85 | framework_id=self.framework_id, 86 | storage_uri=self.storage_uri, 87 | is_continuous=self.is_continuous, 88 | num_trials=self.num_trials, 89 | num_jobs=self.num_jobs 90 | ) 91 | 92 | 93 | @Config.register('hpo_runner_config') 94 | class HPORunnerConfig(Config): 95 | def __init__(self, hpo_config, runner_config): 96 | """Initialize the config class 97 | 98 | Args: 99 | hpo_config (EasyTPP.HPOConfig): hpo config class. 100 | runner_config (EasyTPP.RunnerConfig): runner config class. 101 | """ 102 | self.hpo_config = hpo_config 103 | self.runner_config = runner_config 104 | 105 | @staticmethod 106 | def parse_from_yaml_config(yaml_config, **kwargs): 107 | """Parse from the yaml to generate the config object. 108 | 109 | Args: 110 | yaml_config (dict): configs from yaml file. 111 | 112 | Returns: 113 | EasyTPP.HPORunnerConfig: Config class for HPO specs. 114 | """ 115 | runner_config = RunnerConfig.parse_from_yaml_config(yaml_config, **kwargs) 116 | hpo_config = HPOConfig.parse_from_yaml_config(yaml_config.get('hpo'), **kwargs) 117 | py_assert(hpo_config is not None, ValueError, 'No hpo configs is provided for HyperTuner') 118 | return HPORunnerConfig( 119 | hpo_config=hpo_config, 120 | runner_config=runner_config 121 | ) 122 | 123 | def copy(self): 124 | """Copy the config. 125 | 126 | Returns: 127 | EasyTPP.HPORunnerConfig: a copy of current config. 
128 | """ 129 | return HPORunnerConfig( 130 | hpo_config=self.hpo_config, 131 | runner_config=self.runner_config 132 | ) 133 | -------------------------------------------------------------------------------- /easy_tpp/config_factory/runner_config.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | from easy_tpp.config_factory.config import Config 5 | from easy_tpp.config_factory.data_config import DataConfig 6 | from easy_tpp.config_factory.model_config import TrainerConfig, ModelConfig, BaseConfig 7 | from easy_tpp.utils import create_folder, logger, get_unique_id, get_stage, RunnerPhase, \ 8 | MetricsHelper, DefaultRunnerConfig, py_assert, is_torch_available, is_tf_available, is_tf_gpu_available, \ 9 | is_torch_gpu_available 10 | from easy_tpp.utils.const import Backend 11 | 12 | 13 | @Config.register('runner_config') 14 | class RunnerConfig(Config): 15 | def __init__(self, base_config, model_config, data_config, trainer_config): 16 | """Initialize the Config class. 17 | 18 | Args: 19 | base_config (EasyTPP.BaseConfig): BaseConfig object. 20 | model_config (EasyTPP.ModelConfig): ModelConfig object. 21 | data_config (EasyTPP.DataConfig): DataConfig object. 22 | trainer_config (EasyTPP.TrainerConfig): TrainerConfig object 23 | """ 24 | self.data_config = data_config 25 | self.model_config = model_config 26 | self.base_config = base_config 27 | self.trainer_config = trainer_config 28 | 29 | self.update_config() 30 | 31 | # save the complete config 32 | save_dir = self.base_config.specs['output_config_dir'] 33 | self.save_to_yaml_file(save_dir) 34 | 35 | logger.info(f'Save the config to {save_dir}') 36 | 37 | def get_yaml_config(self): 38 | """Return the config in dict (yaml compatible) format. 39 | 40 | Returns: 41 | dict: config of the runner config in dict format. 42 | """ 43 | return {'data_config': self.data_config.get_yaml_config(), 44 | 'base_config': self.base_config.get_yaml_config(), 45 | 'model_config': self.model_config.get_yaml_config(), 46 | 'trainer_config': self.trainer_config.get_yaml_config()} 47 | 48 | @staticmethod 49 | def parse_from_yaml_config(yaml_config, **kwargs): 50 | """Parse from the yaml to generate the config object. 51 | 52 | Args: 53 | yaml_config (dict): configs from yaml file. 54 | 55 | Returns: 56 | RunnerConfig: Config class for trainer specs. 
57 | """ 58 | direct_parse = kwargs.get('direct_parse', False) 59 | if not direct_parse: 60 | exp_id = kwargs.get('experiment_id') 61 | yaml_exp_config = yaml_config[exp_id] 62 | dataset_id = yaml_exp_config.get('base_config').get('dataset_id') 63 | if dataset_id is None: 64 | dataset_id = DefaultRunnerConfig.DEFAULT_DATASET_ID 65 | try: 66 | yaml_data_config = yaml_config['data'][dataset_id] 67 | except KeyError: 68 | raise RuntimeError('dataset_id={} is not found in config.'.format(dataset_id)) 69 | 70 | data_config = DataConfig.parse_from_yaml_config(yaml_data_config) 71 | # add exp id to base config 72 | yaml_exp_config.get('base_config').update(exp_id=exp_id) 73 | 74 | else: 75 | yaml_exp_config = yaml_config 76 | data_config = DataConfig.parse_from_yaml_config(yaml_exp_config.get('data_config')) 77 | 78 | base_config = BaseConfig.parse_from_yaml_config(yaml_exp_config.get('base_config')) 79 | model_config = ModelConfig.parse_from_yaml_config(yaml_exp_config.get('model_config')) 80 | trainer_config = TrainerConfig.parse_from_yaml_config(yaml_exp_config.get('trainer_config')) 81 | 82 | return RunnerConfig( 83 | data_config=data_config, 84 | base_config=base_config, 85 | model_config=model_config, 86 | trainer_config=trainer_config 87 | ) 88 | 89 | def update_config(self): 90 | """Updated config dict. 91 | """ 92 | model_folder_name = get_unique_id() 93 | 94 | log_folder = create_folder(self.base_config.base_dir, model_folder_name) 95 | model_folder = create_folder(log_folder, 'models') 96 | 97 | self.base_config.specs['log_folder'] = log_folder 98 | self.base_config.specs['saved_model_dir'] = os.path.join(model_folder, 'saved_model') 99 | self.base_config.specs['saved_log_dir'] = os.path.join(log_folder, 'log') 100 | self.base_config.specs['output_config_dir'] = os.path.join(log_folder, 101 | f'{self.base_config.exp_id}_output.yaml') 102 | 103 | if self.trainer_config.use_tfb: 104 | self.base_config.specs['tfb_train_dir'] = create_folder(log_folder, 'tfb_train') 105 | self.base_config.specs['tfb_valid_dir'] = create_folder(log_folder, 'tfb_valid') 106 | 107 | current_stage = get_stage(self.base_config.stage) 108 | is_training = current_stage == RunnerPhase.TRAIN 109 | self.model_config.is_training = is_training 110 | self.model_config.gpu = self.trainer_config.gpu 111 | 112 | # update the dataset config => model config 113 | self.model_config.num_event_types_pad = self.data_config.data_specs.num_event_types_pad 114 | self.model_config.num_event_types = self.data_config.data_specs.num_event_types 115 | self.model_config.pad_token_id = self.data_config.data_specs.pad_token_id 116 | self.model_config.max_len = self.data_config.data_specs.max_len 117 | 118 | # update base config => model config 119 | model_id = self.base_config.model_id 120 | self.model_config.model_id = model_id 121 | 122 | if self.base_config.model_id == 'ODETPP' and self.base_config.backend == Backend.TF: 123 | py_assert(self.data_config.data_specs.padding_strategy == 'max_length', 124 | ValueError, 125 | 'For ODETPP in TensorFlow, we must pad all sequence to ' 126 | 'the same length (max len of the sequences)!') 127 | 128 | run = current_stage 129 | use_torch = self.base_config.backend == Backend.Torch 130 | device = 'GPU' if self.trainer_config.gpu >= 0 else 'CPU' 131 | 132 | py_assert(is_torch_available() if use_torch else is_tf_available(), ValueError, 133 | f'Backend {self.base_config.backend} is not supported in the current environment yet !') 134 | 135 | if use_torch and device == 'GPU': 136 | 
py_assert(is_torch_gpu_available(), 137 | ValueError, 138 | f'Torch cuda is not supported in the current environment yet!') 139 | 140 | if not use_torch and device == 'GPU': 141 | py_assert(is_tf_gpu_available(), 142 | ValueError, 143 | f'Tensorflow GPU is not supported in the current environment yet!') 144 | 145 | critical_msg = '{run} model {model_name} using {device} ' \ 146 | 'with {tf_torch} backend'.format(run=run, 147 | model_name=model_id, 148 | device=device, 149 | tf_torch=self.base_config.backend) 150 | 151 | logger.critical(critical_msg) 152 | 153 | return 154 | 155 | def get_metric_functions(self): 156 | return MetricsHelper.get_metrics_callback_from_names(self.trainer_config.metrics) 157 | 158 | def get_metric_direction(self, metric_name='rmse'): 159 | return MetricsHelper.get_metric_direction(metric_name) 160 | 161 | def copy(self): 162 | """Copy the config. 163 | 164 | Returns: 165 | RunnerConfig: a copy of current config. 166 | """ 167 | return RunnerConfig( 168 | base_config=copy.deepcopy(self.base_config), 169 | model_config=copy.deepcopy(self.model_config), 170 | data_config=copy.deepcopy(self.data_config), 171 | trainer_config=copy.deepcopy(self.trainer_config) 172 | ) 173 | -------------------------------------------------------------------------------- /easy_tpp/default_registers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/EasyTemporalPointProcess/7e2b7a001a293c506bd595e8ddb72d83967c2cb2/easy_tpp/default_registers/__init__.py -------------------------------------------------------------------------------- /easy_tpp/default_registers/register_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from easy_tpp.utils.const import PredOutputIndex 4 | from easy_tpp.utils.metrics import MetricsHelper 5 | 6 | 7 | @MetricsHelper.register(name='rmse', direction=MetricsHelper.MINIMIZE, overwrite=False) 8 | def rmse_metric_function(predictions, labels, **kwargs): 9 | """Compute rmse metrics of the time predictions. 10 | 11 | Args: 12 | predictions (np.array): model predictions. 13 | labels (np.array): ground truth. 14 | 15 | Returns: 16 | float: average rmse of the time predictions. 17 | """ 18 | seq_mask = kwargs.get('seq_mask') 19 | if seq_mask is None or len(seq_mask) == 0: 20 | # If mask is empty or None, use all predictions 21 | pred = predictions[PredOutputIndex.TimePredIndex] 22 | label = labels[PredOutputIndex.TimePredIndex] 23 | else: 24 | pred = predictions[PredOutputIndex.TimePredIndex][seq_mask] 25 | label = labels[PredOutputIndex.TimePredIndex][seq_mask] 26 | 27 | pred = np.reshape(pred, [-1]) 28 | label = np.reshape(label, [-1]) 29 | return np.sqrt(np.mean((pred - label) ** 2)) 30 | 31 | 32 | @MetricsHelper.register(name='acc', direction=MetricsHelper.MAXIMIZE, overwrite=False) 33 | def acc_metric_function(predictions, labels, **kwargs): 34 | """Compute accuracy ratio metrics of the type predictions. 35 | 36 | Args: 37 | predictions (np.array): model predictions. 38 | labels (np.array): ground truth. 39 | 40 | Returns: 41 | float: accuracy ratio of the type predictions. 
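        Example:
            Additional metrics can be registered the same way; a minimal sketch of
            a mean-absolute-error metric mirroring the rmse function above (the
            name 'mae' is illustrative and not part of the default registry)::

                @MetricsHelper.register(name='mae', direction=MetricsHelper.MINIMIZE, overwrite=False)
                def mae_metric_function(predictions, labels, **kwargs):
                    seq_mask = kwargs.get('seq_mask')
                    pred = predictions[PredOutputIndex.TimePredIndex]
                    label = labels[PredOutputIndex.TimePredIndex]
                    if seq_mask is not None and len(seq_mask) > 0:
                        pred, label = pred[seq_mask], label[seq_mask]
                    # flatten and average the absolute time-prediction errors
                    return np.mean(np.abs(np.reshape(pred, [-1]) - np.reshape(label, [-1])))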
42 | """ 43 | seq_mask = kwargs.get('seq_mask') 44 | if seq_mask is None or len(seq_mask) == 0: 45 | # If mask is empty or None, use all predictions 46 | pred = predictions[PredOutputIndex.TypePredIndex] 47 | label = labels[PredOutputIndex.TypePredIndex] 48 | else: 49 | pred = predictions[PredOutputIndex.TypePredIndex][seq_mask] 50 | label = labels[PredOutputIndex.TypePredIndex][seq_mask] 51 | pred = np.reshape(pred, [-1]) 52 | label = np.reshape(label, [-1]) 53 | return np.mean(pred == label) 54 | -------------------------------------------------------------------------------- /easy_tpp/default_registers/register_optuna_trials.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.hpo.optuna_hpo import OptunaTuner 2 | 3 | 4 | @OptunaTuner.register_trial_func(model_id='default', overwrite=False) 5 | def default_trial(trial, **kwargs): 6 | setting = { 7 | "trainer_config": {"max_epoch": "suggest_int(40, 100, log=True)", 8 | "batch_size": 256, 9 | "optimizer": "adam", 10 | "learning_rate": "suggest_float(5e-4, 1e-2, log=True)"}, 11 | "model_config": {"hidden_size": "suggest_int(16, 32)"} 12 | } 13 | return setting 14 | -------------------------------------------------------------------------------- /easy_tpp/hpo/__init__.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.hpo.base_hpo import HyperTuner 2 | from easy_tpp.hpo.optuna_hpo import OptunaTuner 3 | from easy_tpp.default_registers.register_optuna_trials import * 4 | 5 | __all__ = ['HyperTuner', 6 | 'OptunaTuner'] -------------------------------------------------------------------------------- /easy_tpp/hpo/base_hpo.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from collections import defaultdict 3 | from typing import List 4 | 5 | from easy_tpp.utils import logger, Registrable 6 | 7 | 8 | class HyperTuner(Registrable): 9 | _trial_register_center = defaultdict(dict) 10 | 11 | def __init__(self, config, trial_end_callbacks: List[callable] = None): 12 | """Initialize the tuner 13 | 14 | Args: 15 | config (EasyTPP.Config): config class 16 | trial_end_callbacks (List[callable]): List of callback functions to be executed after each trial. 17 | """ 18 | self.config = config 19 | self.trial_end_callbacks = trial_end_callbacks or [] 20 | logger.info(f'Storage of hpo framework: {self.config.hpo_config.storage_uri}') 21 | 22 | @abstractmethod 23 | def get_all_best_runner_configs(self): 24 | pass 25 | 26 | @abstractmethod 27 | def get_best_runner_config_by_name(self, runner_id): 28 | """ 29 | 30 | Args: 31 | runner_id (str): 32 | 33 | Returns: 34 | 35 | """ 36 | pass 37 | 38 | @abstractmethod 39 | def get_num_remain_trials_by_name(self, runner_id): 40 | pass 41 | 42 | @staticmethod 43 | def build_from_config(config, trial_end_callbacks: List[callable] = None): 44 | """Load yaml config file from disk. 45 | 46 | Args: 47 | config (EasyTPP.Config): config class 48 | trial_end_callbacks (List[callable]): List of callback functions to be executed after each trial. 49 | 50 | Returns: 51 | EasyTPP.Config: Config object corresponding to cls. 
52 | """ 53 | runner_cls = HyperTuner.by_name(config.hpo_config.framework_id) 54 | return runner_cls(config, trial_end_callbacks) 55 | 56 | # ---------------------- Trail Register and Get Functions --------------------- 57 | 58 | @classmethod 59 | def register_trial_func(cls, model_id, overwrite=True): 60 | """Register the trial functions in HPO 61 | 62 | Args: 63 | model_id (str): id of the models. 64 | overwrite (bool, optional): whether to overwrite the trial function. Defaults to True. 65 | 66 | Returns: 67 | dict: the registered trial function 68 | """ 69 | register_center = HyperTuner._trial_register_center 70 | 71 | def _register_trial(func): 72 | if model_id in register_center[cls]: 73 | if overwrite: 74 | register_center[cls][model_id] = func 75 | logger.info(f'The trial for {model_id} is already registered, but overwrite it.') 76 | else: 77 | logger.warn(f'The trial for {model_id} is already registered, and cannot be overwritten!') 78 | else: 79 | register_center[cls][model_id] = func 80 | logger.info(f'Trial register: {cls.get_registered_name()} - {model_id}') 81 | return func 82 | 83 | return _register_trial 84 | 85 | @classmethod 86 | def retrieve_trial_func_by_model_name(cls, name): 87 | """Retrieve the trail function by the model id 88 | 89 | Args: 90 | name (str): model id. 91 | 92 | Raises: 93 | RuntimeError: non registered error for the hpo framework. 94 | 95 | Returns: 96 | dict: registered trial center 97 | """ 98 | cls_trial_rc = HyperTuner._trial_register_center[cls] 99 | if name not in cls_trial_rc: 100 | if 'default' in cls_trial_rc: 101 | logger.warn( 102 | f'Trial for {name} in {cls.get_registered_name()} is not existed, and use default trial!' 103 | ) 104 | name = 'default' 105 | else: 106 | raise RuntimeError(f'This HPO Framework is not registered!') 107 | return cls_trial_rc[name] 108 | 109 | @classmethod 110 | def get_registered_name(cls): 111 | """Get the name of the registered hpo class. 112 | 113 | Returns: 114 | str: the name of the registered hpo class. 115 | """ 116 | hpo_rc = HyperTuner.registry_dict() 117 | for registered_name, hpo_cls in hpo_rc.items(): 118 | if cls in hpo_cls: 119 | return registered_name 120 | 121 | logger.warn(f'The hpo framework is not registered: {cls}') 122 | return None 123 | 124 | @abstractmethod 125 | def run(self): 126 | """Run the process. 127 | 128 | Raises: 129 | NotImplementedError: error raised in base class. 
130 | """ 131 | raise NotImplementedError 132 | -------------------------------------------------------------------------------- /easy_tpp/model/__init__.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.model.torch_model.torch_anhn import ANHN as TorchANHN 2 | from easy_tpp.model.torch_model.torch_attnhp import AttNHP as TorchAttNHP 3 | from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel 4 | from easy_tpp.model.torch_model.torch_fullynn import FullyNN as TorchFullyNN 5 | from easy_tpp.model.torch_model.torch_intensity_free import IntensityFree as TorchIntensityFree 6 | from easy_tpp.model.torch_model.torch_nhp import NHP as TorchNHP 7 | from easy_tpp.model.torch_model.torch_ode_tpp import ODETPP as TorchODETPP 8 | from easy_tpp.model.torch_model.torch_rmtpp import RMTPP as TorchRMTPP 9 | from easy_tpp.model.torch_model.torch_sahp import SAHP as TorchSAHP 10 | from easy_tpp.model.torch_model.torch_thp import THP as TorchTHP 11 | 12 | # by default, we use torch and do not install tf, therefore we ignore the import error 13 | try: 14 | from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel 15 | from easy_tpp.model.tf_model.tf_nhp import NHP as TfNHP 16 | from easy_tpp.model.tf_model.tf_ode_tpp import ODETPP as TfODETPP 17 | from easy_tpp.model.tf_model.tf_thp import THP as TfTHP 18 | from easy_tpp.model.tf_model.tf_sahp import SAHP as TfSAHP 19 | from easy_tpp.model.tf_model.tf_rmtpp import RMTPP as TfRMTPP 20 | from easy_tpp.model.tf_model.tf_attnhp import AttNHP as TfAttNHP 21 | from easy_tpp.model.tf_model.tf_anhn import ANHN as TfANHN 22 | from easy_tpp.model.tf_model.tf_fullynn import FullyNN as TfFullyNN 23 | from easy_tpp.model.tf_model.tf_intensity_free import IntensityFree as TfIntensityFree 24 | except ImportError: 25 | pass 26 | 27 | __all__ = ['TorchBaseModel', 28 | 'TorchNHP', 29 | 'TorchAttNHP', 30 | 'TorchTHP', 31 | 'TorchSAHP', 32 | 'TorchFullyNN', 33 | 'TorchIntensityFree', 34 | 'TorchODETPP', 35 | 'TfBaseModel', 36 | 'TfNHP', 37 | 'TfAttNHP', 38 | 'TfTHP', 39 | 'TfSAHP', 40 | 'TfANHN', 41 | 'TfFullyNN', 42 | 'TfIntensityFree', 43 | 'TfODETPP'] 44 | -------------------------------------------------------------------------------- /easy_tpp/model/tf_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/EasyTemporalPointProcess/7e2b7a001a293c506bd595e8ddb72d83967c2cb2/easy_tpp/model/tf_model/__init__.py -------------------------------------------------------------------------------- /easy_tpp/model/tf_model/tf_intensity_free.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | from tensorflow.keras import layers 4 | 5 | from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel 6 | from easy_tpp.utils.tf_utils import get_shape_list 7 | 8 | tfd = tfp.distributions 9 | tfb = tfp.bijectors 10 | 11 | if tf.__version__ >= '2.0': 12 | tf = tf.compat.v1 13 | tf.disable_v2_behavior() 14 | 15 | 16 | class MixtureSameFamily(tfd.MixtureSameFamily): 17 | """Mixture (same-family) distribution, redefined `log_cdf` and `log_survival_function`. 
18 | """ 19 | 20 | def log_cdf(self, x): 21 | x = x[..., None] 22 | log_cdf_x = self.components_distribution.log_cdf(x) 23 | mix_logits = self.mixture_distribution.logits 24 | return tf.reduce_logsumexp(log_cdf_x + mix_logits, axis=-1) 25 | 26 | def log_survival_function(self, x): 27 | x = x[..., None] 28 | log_sf_x = self.components_distribution.log_survival_function(x) 29 | mix_logits = self.mixture_distribution.logits 30 | return tf.reduce_logsumexp(log_sf_x + mix_logits, axis=-1) 31 | 32 | 33 | class LogNormalMixtureDistribution: 34 | """ 35 | Mixture of log-normal distributions. 36 | 37 | Args: 38 | locs (tensor): [batch_size, seq_len, num_mix_components]. 39 | log_scales (tensor): [batch_size, seq_len, num_mix_components]. 40 | log_weights (tensor): [batch_size, seq_len, num_mix_components]. 41 | mean_log_inter_time (float): Average log-inter-event-time. 42 | std_log_inter_time (float): Std of log-inter-event-times. 43 | """ 44 | 45 | def __init__(self, locs, log_scales, log_weights, mean_log_inter_time, std_log_inter_time, validate_args=None): 46 | mixture_dist = tfd.Categorical(logits=log_weights) 47 | component_dist = tfd.Normal(loc=locs, scale=tf.exp(log_scales)) 48 | self.GMM = MixtureSameFamily(mixture_dist, component_dist) 49 | self.mean_log_inter_time = mean_log_inter_time 50 | self.std_log_inter_time = std_log_inter_time 51 | 52 | self.transformed_distribution = tfd.TransformedDistribution(self.GMM, 53 | bijector=tfb.Exp(), 54 | validate_args=validate_args) 55 | 56 | def log_prob(self, x): 57 | return self.transformed_distribution.log_prob(x) 58 | 59 | def log_survival_function(self, x): 60 | return self.transformed_distribution.log_survival_function(x) 61 | 62 | 63 | class IntensityFree(TfBaseModel): 64 | """Tensorflow implementation of Intensity-Free Learning of Temporal Point Processes, ICLR 2020. 65 | https://openreview.net/pdf?id=HygOjhEYDH 66 | 67 | reference: https://github.com/shchur/ifl-tpp 68 | """ 69 | 70 | def __init__(self, model_config): 71 | """Initialize the model 72 | 73 | Args: 74 | model_config (EasyTPP.ModelConfig): config of model specs. 75 | """ 76 | super(IntensityFree, self).__init__(model_config) 77 | 78 | self.num_mix_components = model_config.data_specs['num_mix_components'] 79 | self.num_features = 1 + self.hidden_size 80 | 81 | def build_graph(self): 82 | """Build up the network 83 | """ 84 | with tf.variable_scope('IntensityFree'): 85 | self.build_input_graph() 86 | 87 | self.layer_rnn = layers.GRU(self.hidden_size, 88 | return_state=False, 89 | return_sequences=True) 90 | # activation='tanh') 91 | 92 | self.context_init = tf.zeros(self.hidden_size)[None, None, :] 93 | self.mark_linear = layers.Dense(self.num_event_types_pad) 94 | self.linear = layers.Dense(3 * self.num_mix_components) 95 | 96 | self.loss, self.num_event = self.loglike_loss() 97 | 98 | # Make predictions 99 | if self.event_sampler and self.gen_config.num_step_gen == 1: 100 | self.dtime_predict_one_step, self.type_predict_one_step = \ 101 | self.predict_one_step_at_every_event(self.time_seqs, 102 | self.time_delta_seqs, 103 | self.type_seqs) 104 | 105 | if self.event_sampler and self.gen_config.num_step_gen > 1: 106 | # make generations 107 | self.dtime_generation, self.type_generation = \ 108 | self.predict_multi_step_since_last_event(self.time_seqs, 109 | self.time_delta_seqs, 110 | self.type_seqs, 111 | num_step=self.gen_config.num_step_gen) 112 | 113 | def forward(self, time_delta_seqs, type_seqs): 114 | """Call the model. 
115 | 116 | Args: 117 | time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs. 118 | type_seqs (tensor): [batch_size, seq_len], event type seqs. 119 | 120 | Returns: 121 | tensor: hidden states, [batch_size, seq_len, hidden_dim], states right before the event happens. 122 | """ 123 | # [batch_size, seq_len, 1] 124 | temporal_seqs = tf.log(time_delta_seqs + self.eps)[..., None] 125 | 126 | # [batch_size, seq_len, hidden_size] 127 | type_emb = self.layer_type_emb(type_seqs) 128 | 129 | # [batch_size, seq_len, hidden_size + 1] 130 | features = tf.concat([temporal_seqs, type_emb], axis=-1) 131 | 132 | # [batch_size, seq_len, hidden_size] 133 | context = self.layer_rnn(features) 134 | 135 | batch_size, seq_len, hidden_size = get_shape_list(context) 136 | 137 | # (batch_size, 1, hidden_size) 138 | context_init = tf.tile(self.context_init, [batch_size, 1, 1]) 139 | 140 | # (batch_size, seq_len + 1, hidden_size) 141 | context = tf.concat([context_init, context], axis=1) 142 | 143 | return context 144 | 145 | def loglike_loss(self): 146 | """Compute the loglike loss. 147 | 148 | Returns: 149 | tuple: loglikelihood loss and num of events. 150 | 151 | """ 152 | 153 | time_delta_seqs = self.time_delta_seqs 154 | type_seqs = self.type_seqs 155 | batch_non_pad_mask = self.batch_non_pad_mask 156 | 157 | mean_log_inter_time = tf.reduce_mean(tf.log(time_delta_seqs)) 158 | std_log_inter_time = tf.math.reduce_std(tf.log(time_delta_seqs)) 159 | 160 | # [batch_size, seq_len, hidden_size] 161 | # seq_len = time_delta_seqs[:, 1:].size()[1] 162 | context = self.forward(time_delta_seqs[:, 1:], type_seqs[:, :-1]) 163 | 164 | # (batch_size, seq_len, 3 * num_mix_components) 165 | raw_params = self.linear(context) 166 | locs = raw_params[..., :self.num_mix_components] 167 | log_scales = raw_params[..., self.num_mix_components: (2 * self.num_mix_components)] 168 | log_weights = raw_params[..., (2 * self.num_mix_components):] 169 | 170 | log_weights = tf.nn.log_softmax(log_weights, dim=-1) 171 | inter_time_dist = LogNormalMixtureDistribution( 172 | locs=locs, 173 | log_scales=log_scales, 174 | log_weights=log_weights, 175 | mean_log_inter_time=mean_log_inter_time, 176 | std_log_inter_time=std_log_inter_time 177 | ) 178 | 179 | inter_times = tf.clip_by_value(time_delta_seqs, 1e-10, 1e10) 180 | # (batch_size, seq_len) 181 | log_p = inter_time_dist.log_prob(inter_times) 182 | 183 | # (batch_size, 1) 184 | # last_event_idx = tf.cast(tf.reduce_sum(batch_non_pad_mask, axis=-1, keepdims=True), 185 | # tf.int32) - 1 186 | 187 | log_surv_all = inter_time_dist.log_survival_function(inter_times) 188 | 189 | self.inter_times = log_surv_all 190 | 191 | # 192 | # # (batch_size,) 193 | # log_surv_last = tf.gather(log_surv_all, axis=-1, indices=last_event_idx)[..., None] 194 | 195 | # (batch_size, seq_len, num_marks) 196 | mark_logits = tf.nn.log_softmax(self.mark_linear(context), dim=-1) 197 | mark_dist = tfd.Categorical(logits=mark_logits) 198 | log_p += mark_dist.log_prob(type_seqs) 199 | 200 | # (batch_size, seq_len) 201 | log_p = tf.boolean_mask(log_p, batch_non_pad_mask) + self.eps 202 | # (batch_size,) 203 | loss = -tf.reduce_sum(log_p) 204 | 205 | num_events = get_shape_list(log_p)[0] 206 | 207 | return loss, num_events 208 | -------------------------------------------------------------------------------- /easy_tpp/model/torch_model/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ant-research/EasyTemporalPointProcess/7e2b7a001a293c506bd595e8ddb72d83967c2cb2/easy_tpp/model/torch_model/__init__.py -------------------------------------------------------------------------------- /easy_tpp/model/torch_model/torch_rmtpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel 6 | 7 | class RMTPP(TorchBaseModel): 8 | """Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016. 9 | https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf 10 | """ 11 | 12 | def __init__(self, model_config): 13 | """Initialize the model 14 | 15 | Args: 16 | model_config (EasyTPP.ModelConfig): config of model specs. 17 | """ 18 | super(RMTPP, self).__init__(model_config) 19 | 20 | self.layer_temporal_emb = nn.Linear(1, self.hidden_size) 21 | self.layer_rnn = nn.RNN(input_size=self.hidden_size, hidden_size=self.hidden_size, 22 | num_layers=1, batch_first=True) 23 | 24 | self.hidden_to_intensity_logits = nn.Linear(self.hidden_size, self.num_event_types) 25 | self.b_t = nn.Parameter(torch.zeros(1, self.num_event_types)) 26 | self.w_t = nn.Parameter(torch.zeros(1, self.num_event_types)) 27 | nn.init.xavier_normal_(self.b_t) 28 | nn.init.xavier_normal_(self.w_t) 29 | 30 | def evolve_and_get_intentsity(self, right_hiddens_BNH, dts_BNG): 31 | """ 32 | Eq.11 that computes intensity. 33 | """ 34 | past_influence_BNGM = self.hidden_to_intensity_logits(right_hiddens_BNH[..., None, :]) 35 | intensity_BNGM = (past_influence_BNGM + self.w_t[None, None, :] * dts_BNG[..., None] 36 | + self.b_t[None, None, :]).clamp(max=math.log(1e5)).exp() 37 | return intensity_BNGM 38 | 39 | def forward(self, batch): 40 | """ 41 | Suppose we have inputs with original sequence length N+1 42 | ts: [t0, t1, ..., t_N] 43 | dts: [0, t1 - t0, t2 - t1, ..., t_N - t_{N-1}] 44 | marks: [k0, k1, ..., k_N] (k0 and kN could be padded marks if t0 and tN correspond to left and right windows) 45 | 46 | Return: 47 | left limits of *intensity* at [t_1, ..., t_N] of shape: (batch_size, seq_len - 1, hidden_dim) 48 | right limits of *hidden states* [t_0, ..., t_{N-1}, t_N] of shape: (batch_size, seq_len, hidden_dim) 49 | We need the right limit of t_N to sample continuation. 50 | """ 51 | 52 | t_BN, dt_BN, marks_BN, _, _ = batch 53 | mark_emb_BNH = self.layer_type_emb(marks_BN) 54 | time_emb_BNH = self.layer_temporal_emb(t_BN[..., None]) 55 | right_hiddens_BNH, _ = self.layer_rnn(mark_emb_BNH + time_emb_BNH) 56 | left_intensity_B_Nm1_M = self.evolve_and_get_intentsity(right_hiddens_BNH[:, :-1, :], dt_BN[:, 1:][...,None]).squeeze(-2) 57 | return left_intensity_B_Nm1_M, right_hiddens_BNH 58 | 59 | 60 | def loglike_loss(self, batch): 61 | """Compute the log-likelihood loss. 62 | 63 | Args: 64 | batch (list): batch input. 65 | 66 | Returns: 67 | tuple: loglikelihood loss and num of events. 
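        Note:
            For reference, the intensity produced by evolve_and_get_intentsity
            follows Eq. (11) of the paper:

                lambda_k(t) = exp( v_k^T h_j + w_k * (t - t_j) + b_k ),

            where h_j is the RNN hidden state right after event j, v_k is the k-th
            row of hidden_to_intensity_logits, and w_k, b_k are the learnable w_t,
            b_t parameters; the exponent is clamped at log(1e5) before
            exponentiation for numerical stability.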
68 | """ 69 | ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch 70 | 71 | # compute left intensity and hidden states at event time 72 | # left limits of intensity at [t_1, ..., t_N] 73 | # right limits of hidden states at [t_0, ..., t_{N-1}, t_N] 74 | left_intensity_B_Nm1_M, right_hiddens_BNH = self.forward((ts_BN, dts_BN, marks_BN, None, None)) 75 | right_hiddens_B_Nm1_H = right_hiddens_BNH[..., :-1, :] # discard right limit at t_N for logL 76 | 77 | dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:]) 78 | intensity_dts_B_Nm1_G_M = self.evolve_and_get_intentsity(right_hiddens_B_Nm1_H, dts_sample_B_Nm1_G) 79 | 80 | event_ll, non_event_ll, num_events = self.compute_loglikelihood( 81 | lambda_at_event=left_intensity_B_Nm1_M, 82 | lambdas_loss_samples=intensity_dts_B_Nm1_G_M, 83 | time_delta_seq=dts_BN[:, 1:], 84 | seq_mask=batch_non_pad_mask[:, 1:], 85 | type_seq=marks_BN[:, 1:] 86 | ) 87 | 88 | # compute loss to minimize 89 | loss = - (event_ll - non_event_ll).sum() 90 | return loss, num_events 91 | 92 | 93 | 94 | def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs): 95 | """Compute the intensity at sampled times, not only event times. 96 | 97 | Args: 98 | time_seq (tensor): [batch_size, seq_len], times seqs. 99 | time_delta_seq (tensor): [batch_size, seq_len], time delta seqs. 100 | event_seq (tensor): [batch_size, seq_len], event type seqs. 101 | sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps. 102 | 103 | Returns: 104 | tensor: [batch_size, num_times, num_mc_sample, num_event_types], 105 | intensity at each timestamp for each event type. 106 | """ 107 | 108 | compute_last_step_only = kwargs.get('compute_last_step_only', False) 109 | 110 | _input = time_seqs, time_delta_seqs, type_seqs, None, None 111 | _, right_hiddens_BNH = self.forward(_input) 112 | 113 | if compute_last_step_only: 114 | sampled_intensities = self.evolve_and_get_intentsity(right_hiddens_BNH[:, -1:, :], sample_dtimes[:, -1:, :]) 115 | else: 116 | sampled_intensities = self.evolve_and_get_intentsity(right_hiddens_BNH, sample_dtimes) # shape: [B, N, G, M] 117 | return sampled_intensities 118 | -------------------------------------------------------------------------------- /easy_tpp/model/torch_model/torch_thp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, TimePositionalEncoding, ScaledSoftplus 5 | from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel 6 | 7 | 8 | class THP(TorchBaseModel): 9 | """Torch implementation of Transformer Hawkes Process, ICML 2020, https://arxiv.org/abs/2002.09291. 10 | Note: Part of the code is collected from https://github.com/yangalan123/anhp-andtt/tree/master/thp. 11 | """ 12 | 13 | def __init__(self, model_config): 14 | """Initialize the model 15 | 16 | Args: 17 | model_config (EasyTPP.ModelConfig): config of model specs. 
18 | """ 19 | super(THP, self).__init__(model_config) 20 | self.d_model = model_config.hidden_size 21 | self.d_time = model_config.time_emb_size 22 | self.use_norm = model_config.use_ln 23 | 24 | self.n_layers = model_config.num_layers 25 | self.n_head = model_config.num_heads 26 | self.dropout = model_config.dropout_rate 27 | 28 | self.layer_temporal_encoding = TimePositionalEncoding(self.d_model, device=self.device) 29 | 30 | self.factor_intensity_base = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device)) 31 | self.factor_intensity_decay = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device)) 32 | nn.init.xavier_normal_(self.factor_intensity_base) 33 | nn.init.xavier_normal_(self.factor_intensity_decay) 34 | 35 | # convert hidden vectors into event-type-sized vector 36 | self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types) 37 | self.softplus = ScaledSoftplus(self.num_event_types) # learnable mark-specific beta 38 | 39 | # Add MLP layer 40 | # Equation (5) 41 | self.feed_forward = nn.Sequential( 42 | nn.Linear(self.d_model, self.d_model * 2), 43 | nn.ReLU(), 44 | nn.Linear(self.d_model * 2, self.d_model) 45 | ) 46 | 47 | self.stack_layers = nn.ModuleList( 48 | [EncoderLayer( 49 | self.d_model, 50 | MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout, 51 | output_linear=False), 52 | use_residual=False, 53 | feed_forward=self.feed_forward, 54 | dropout=self.dropout 55 | ) for _ in range(self.n_layers)]) 56 | 57 | def forward(self, time_seqs, type_seqs, attention_mask): 58 | """Call the model 59 | 60 | Args: 61 | time_seqs (tensor): [batch_size, seq_len], timestamp seqs. 62 | type_seqs (tensor): [batch_size, seq_len], event type seqs. 63 | attention_mask (tensor): [batch_size, seq_len, hidden_size], attention masks. 64 | 65 | Returns: 66 | tensor: hidden states at event times. 67 | """ 68 | # [batch_size, seq_len, hidden_size] 69 | tem_enc = self.layer_temporal_encoding(time_seqs) 70 | enc_output = self.layer_type_emb(type_seqs) 71 | 72 | # [batch_size, seq_len, hidden_size] 73 | for enc_layer in self.stack_layers: 74 | enc_output += tem_enc 75 | enc_output = enc_layer( 76 | enc_output, 77 | mask=attention_mask) 78 | 79 | return enc_output 80 | 81 | def loglike_loss(self, batch): 82 | """Compute the loglike loss. 83 | 84 | Args: 85 | batch (tuple, list): batch input. 86 | 87 | Returns: 88 | tuple: loglike loss, num events. 89 | """ 90 | time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask = batch 91 | 92 | # 1. compute event-loglik 93 | # [batch_size, seq_len, hidden_size] 94 | enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, :-1, :-1]) 95 | 96 | # [batch_size, seq_len, num_event_types] 97 | # update time decay based on Equation (6) 98 | # [1, 1, num_event_types] 99 | factor_intensity_decay = self.factor_intensity_decay[None, ...] 100 | factor_intensity_base = self.factor_intensity_base[None, ...] 101 | 102 | # update time decay based on Equation (6) 103 | # [batch_size, seq_len, num_event_types] 104 | intensity_states = factor_intensity_decay * time_delta_seqs[:, 1:, None] + self.layer_intensity_hidden( 105 | enc_out) + factor_intensity_base 106 | 107 | lambda_at_event = self.softplus(intensity_states) 108 | 109 | # 2. 
compute non-event-loglik (using MC sampling to compute integral) 110 | # 2.1 sample dtimes 111 | # [batch_size, seq_len, num_sample] 112 | sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:]) 113 | 114 | # 2.2 compute intensities at sampled times 115 | # [batch_size, num_times = max_len - 1, num_sample, event_num] 116 | state_t_sample = self.compute_states_at_sample_times(event_states=enc_out, 117 | sample_dtimes=sample_dtimes) 118 | lambda_t_sample = self.softplus(state_t_sample) 119 | 120 | event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event, 121 | lambdas_loss_samples=lambda_t_sample, 122 | time_delta_seq=time_delta_seqs[:, 1:], 123 | seq_mask=batch_non_pad_mask[:, 1:], 124 | type_seq=type_seqs[:, 1:]) 125 | 126 | # compute loss to minimize 127 | loss = - (event_ll - non_event_ll).sum() 128 | return loss, num_events 129 | 130 | def compute_states_at_sample_times(self, event_states, sample_dtimes): 131 | """Compute the hidden states at sampled times. 132 | 133 | Args: 134 | event_states (tensor): [batch_size, seq_len, hidden_size]. 135 | sample_dtimes (tensor): [batch_size, seq_len, num_samples]. 136 | 137 | Returns: 138 | tensor: hidden state at each sampled time. 139 | """ 140 | # [batch_size, seq_len, 1, hidden_size] 141 | event_states = event_states[:, :, None, :] 142 | 143 | # [batch_size, seq_len, num_samples, 1] 144 | sample_dtimes = sample_dtimes[..., None] 145 | 146 | # [1, 1, 1, num_event_types] 147 | factor_intensity_decay = self.factor_intensity_decay[None, None, ...] 148 | factor_intensity_base = self.factor_intensity_base[None, None, ...] 149 | 150 | # update time decay based on Equation (6) 151 | # [batch_size, seq_len, num_samples, num_event_types] 152 | intensity_states = factor_intensity_decay * sample_dtimes + self.layer_intensity_hidden( 153 | event_states) + factor_intensity_base 154 | 155 | return intensity_states 156 | 157 | def compute_intensities_at_sample_times(self, 158 | time_seqs, 159 | time_delta_seqs, 160 | type_seqs, 161 | sample_dtimes, 162 | **kwargs): 163 | """Compute hidden states at sampled times. 164 | 165 | Args: 166 | time_seqs (tensor): [batch_size, seq_len], times seqs. 167 | time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs. 168 | type_seqs (tensor): [batch_size, seq_len], event type seqs. 169 | sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps. 170 | 171 | Returns: 172 | tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times. 
173 | """ 174 | 175 | attention_mask = kwargs.get('attention_mask', None) 176 | compute_last_step_only = kwargs.get('compute_last_step_only', False) 177 | 178 | if attention_mask is None: 179 | batch_size, seq_len = time_seqs.size() 180 | attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=self.device), diagonal=1).unsqueeze(0) 181 | attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool) 182 | 183 | # [batch_size, seq_len, num_samples] 184 | enc_out = self.forward(time_seqs, type_seqs, attention_mask) 185 | 186 | # [batch_size, seq_len, num_samples, hidden_size] 187 | encoder_output = self.compute_states_at_sample_times(enc_out, sample_dtimes) 188 | 189 | if compute_last_step_only: 190 | lambdas = self.softplus(encoder_output[:, -1:, :, :]) 191 | else: 192 | # [batch_size, seq_len, num_samples, num_event_types] 193 | lambdas = self.softplus(encoder_output) 194 | return lambdas 195 | -------------------------------------------------------------------------------- /easy_tpp/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.preprocess.data_loader import TPPDataLoader, EventTokenizer, TPPDataset, get_data_loader 2 | 3 | __all__ = ['TPPDataLoader', 4 | 'EventTokenizer', 5 | 'TPPDataset', 6 | 'get_data_loader'] 7 | -------------------------------------------------------------------------------- /easy_tpp/preprocess/data_collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union, Optional 3 | 4 | from easy_tpp.preprocess.event_tokenizer import EventTokenizer 5 | from easy_tpp.utils import PaddingStrategy, TruncationStrategy 6 | 7 | 8 | @dataclass 9 | class TPPDataCollator: 10 | """ 11 | Data collator that will dynamically pad the inputs of event sequences. 12 | 13 | Args: 14 | tokenizer ([`EventTokenizer`]): 15 | The tokenizer used for encoding the data. 16 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 17 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 18 | among: 19 | 20 | - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single 21 | sequence is provided). 22 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 23 | acceptable input length for the model if that argument is not provided. 24 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 25 | max_length (`int`, *optional*): 26 | Maximum length of the returned list and optionally padding length (see above). 27 | return_tensors (`str`): 28 | The type of Tensor to return. Allowable values are "np", "pt" and "tf". 
29 | """ 30 | 31 | tokenizer: EventTokenizer 32 | padding: Union[bool, str, PaddingStrategy] = True 33 | max_length: Optional[int] = None 34 | truncation: Union[bool, str, TruncationStrategy] = False 35 | return_tensors: str = "pt" 36 | 37 | def __call__(self, features, return_tensors=None): 38 | if return_tensors is None: 39 | return_tensors = self.return_tensors 40 | 41 | batch = self.tokenizer.pad( 42 | features, 43 | padding=self.padding, 44 | max_length=self.max_length, 45 | truncation=self.truncation, 46 | return_tensors=return_tensors, 47 | ) 48 | 49 | return batch 50 | -------------------------------------------------------------------------------- /easy_tpp/preprocess/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | from easy_tpp.preprocess.data_collator import TPPDataCollator 8 | from easy_tpp.preprocess.event_tokenizer import EventTokenizer 9 | from easy_tpp.utils import py_assert, is_tf_available 10 | 11 | 12 | class TPPDataset(Dataset): 13 | def __init__(self, data: Dict): 14 | self.data_dict = data 15 | self.time_seqs = self.data_dict['time_seqs'] 16 | self.time_delta_seqs = self.data_dict['time_delta_seqs'] 17 | self.type_seqs = self.data_dict['type_seqs'] 18 | 19 | def __len__(self): 20 | """ 21 | 22 | Returns: length of the dataset 23 | 24 | """ 25 | 26 | py_assert(len(self.time_seqs) == len(self.type_seqs) and len(self.time_delta_seqs) == len(self.type_seqs), 27 | ValueError, 28 | f"Inconsistent lengths for data! time_seq_len:{len(self.time_seqs)}, event_len: " 29 | f"{len(self.type_seqs)}, time_delta_seq_len: {len(self.time_delta_seqs)}") 30 | 31 | return len(self.time_seqs) 32 | 33 | def __getitem__(self, idx): 34 | """ 35 | 36 | Args: 37 | idx: iteration index 38 | 39 | Returns: 40 | dict: a dict of time_seqs, time_delta_seqs and type_seqs element 41 | 42 | """ 43 | return dict({'time_seqs': self.time_seqs[idx], 'time_delta_seqs': self.time_delta_seqs[idx], 44 | 'type_seqs': self.type_seqs[idx]}) 45 | 46 | def to_tf_dataset(self, data_collator: TPPDataCollator, **kwargs): 47 | """Generate a dataset to use in Tensorflow 48 | 49 | Args: 50 | data_collator (TPPDataCollator): collator to tokenize the event data. 51 | 52 | Raises: 53 | ImportError: Tensorflow is not installed. 54 | 55 | Returns: 56 | tf.keras.utils.Sequence: tf Dataset object for TPP data. 57 | """ 58 | if is_tf_available(): 59 | import tensorflow as tf 60 | 61 | if tf.__version__ >= '2.0': 62 | tf = tf.compat.v1 63 | tf.disable_v2_behavior() 64 | else: 65 | raise ImportError("Called a Tensorflow-specific function but Tensorflow is not installed.") 66 | 67 | class TfTPPDataset(tf.keras.utils.Sequence): 68 | def __init__(self, time_seqs, time_delta_seqs, type_seqs, **kwargs): 69 | """Initialize the class. 70 | 71 | Args: 72 | batch_size (int): size of batch. 73 | shuffle (bool): whether to shuffle the data in each batch. 
74 | 75 | """ 76 | self.time_seqs = time_seqs 77 | self.time_delta_seqs = time_delta_seqs 78 | self.type_seqs = type_seqs 79 | self.data_len = len(self.time_delta_seqs) 80 | self.batch_size = kwargs.pop('batch_size') 81 | self.shuffle = kwargs.pop('shuffle', False) 82 | self.idx = np.arange(self.data_len) 83 | self.kwargs = kwargs 84 | 85 | def __getitem__(self, index): 86 | # get batch indexes from shuffled indexes 87 | batch_idx = self.idx[index * self.batch_size:(index + 1) * self.batch_size] 88 | batch = dict({'time_seqs': [self.time_seqs[i] for i in batch_idx], 89 | 'time_delta_seqs': [self.time_delta_seqs[i] for i in batch_idx], 90 | 'type_seqs': [self.type_seqs[i] for i in batch_idx]}) 91 | 92 | batch = data_collator(batch, **self.kwargs) 93 | return batch 94 | 95 | def __len__(self): 96 | # Denotes the number of batches per epoch 97 | return math.ceil(self.data_len / self.batch_size) 98 | 99 | def on_epoch_end(self): 100 | # Updates indexes after each epoch 101 | self.idx = np.arange(self.data_len) 102 | if self.shuffle: 103 | np.random.shuffle(self.idx) 104 | 105 | return TfTPPDataset(self.time_seqs, self.time_delta_seqs, self.type_seqs, **kwargs) 106 | 107 | def get_dt_stats(self): 108 | x_bar, s_2_x, n = 0., 0., 0 109 | min_dt, max_dt = np.inf, -np.inf 110 | 111 | for dts, marks in zip(self.time_delta_seqs, self.type_seqs): 112 | dts = np.array(dts[1:-1 if marks[-1] == -1 else None]) 113 | min_dt = min(min_dt, dts.min()) 114 | max_dt = max(max_dt, dts.max()) 115 | y_bar = dts.mean() 116 | s_2_y = dts.var() 117 | m = dts.shape[0] 118 | n += m 119 | # Formula taken from https://math.stackexchange.com/questions/3604607/can-i-work-out-the-variance-in-batches 120 | s_2_x = (((n - 1) * s_2_x + (m - 1) * s_2_y) / (n + m - 1)) + ( 121 | (n * m * ((x_bar - y_bar) ** 2)) / ((n + m) * (n + m - 1))) 122 | x_bar = (n * x_bar + m * y_bar) / (n + m) 123 | 124 | print(x_bar, (s_2_x ** 0.5)) 125 | print(f'min_dt: {min_dt}') 126 | print(f'max_dt: {max_dt}') 127 | return x_bar, (s_2_x ** 0.5), min_dt, max_dt 128 | 129 | 130 | def get_data_loader(dataset: TPPDataset, backend: str, tokenizer: EventTokenizer, **kwargs): 131 | use_torch = backend == 'torch' 132 | 133 | padding = True if tokenizer.padding_strategy is None else tokenizer.padding_strategy 134 | truncation = False if tokenizer.truncation_strategy is None else tokenizer.truncation_strategy 135 | 136 | if use_torch: 137 | data_collator = TPPDataCollator(tokenizer=tokenizer, 138 | return_tensors='pt', 139 | max_length=tokenizer.model_max_length, 140 | padding=padding, 141 | truncation=truncation) 142 | 143 | return DataLoader(dataset, 144 | collate_fn=data_collator, 145 | **kwargs) 146 | else: 147 | # we pass to placeholders 148 | data_collator = TPPDataCollator(tokenizer=tokenizer, 149 | return_tensors='np', 150 | max_length=tokenizer.model_max_length, 151 | padding=padding, 152 | truncation=truncation) 153 | 154 | return dataset.to_tf_dataset(data_collator, **kwargs) 155 | -------------------------------------------------------------------------------- /easy_tpp/runner/__init__.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.runner.base_runner import Runner 2 | from easy_tpp.runner.tpp_runner import TPPRunner 3 | # for register all necessary contents 4 | from easy_tpp.default_registers.register_metrics import * 5 | 6 | __all__ = ['Runner', 7 | 'TPPRunner'] -------------------------------------------------------------------------------- /easy_tpp/runner/base_runner.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from abc import abstractmethod 3 | 4 | from easy_tpp.preprocess import TPPDataLoader 5 | from easy_tpp.utils import Registrable, Timer, logger, get_unique_id, LogConst, get_stage, RunnerPhase 6 | 7 | 8 | class Runner(Registrable): 9 | """Registrable Base Runner class. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | runner_config, 15 | unique_model_dir=False, 16 | **kwargs): 17 | """Initialize the base runner. 18 | 19 | Args: 20 | runner_config (RunnerConfig): config for the runner. 21 | unique_model_dir (bool, optional): whether to give unique dir to save the model. Defaults to False. 22 | """ 23 | self.runner_config = runner_config 24 | # re-assign the model_dir 25 | if unique_model_dir: 26 | runner_config.model_dir = runner_config.base_config.specs['saved_model_dir'] + '_' + get_unique_id() 27 | 28 | self.save_log() 29 | 30 | skip_data_loader = kwargs.get('skip_data_loader', False) 31 | if not skip_data_loader: 32 | # build data reader 33 | data_config = self.runner_config.data_config 34 | backend = self.runner_config.base_config.backend 35 | kwargs = self.runner_config.trainer_config.get_yaml_config() 36 | self._data_loader = TPPDataLoader( 37 | data_config=data_config, 38 | backend=backend, 39 | **kwargs 40 | ) 41 | 42 | # Needed for Intensity Free model 43 | mean_log_inter_time, std_log_inter_time, min_dt, max_dt = ( 44 | self._data_loader.train_loader().dataset.get_dt_stats()) 45 | runner_config.model_config.set("mean_log_inter_time", mean_log_inter_time) 46 | runner_config.model_config.set("std_log_inter_time", std_log_inter_time) 47 | self.timer = Timer() 48 | 49 | @staticmethod 50 | def build_from_config(runner_config, unique_model_dir=False, **kwargs): 51 | """Build up the runner from runner config. 52 | 53 | Args: 54 | runner_config (RunnerConfig): config for the runner. 55 | unique_model_dir (bool, optional): whether to give unique dir to save the model. Defaults to False. 56 | 57 | Returns: 58 | Runner: the corresponding runner class. 59 | """ 60 | runner_cls = Runner.by_name(runner_config.base_config.runner_id) 61 | return runner_cls(runner_config, unique_model_dir=unique_model_dir, **kwargs) 62 | 63 | def get_config(self): 64 | return self.runner_config 65 | 66 | def set_model_dir(self, model_dir): 67 | self.runner_config.base_config.specs['saved_model_dir'] = model_dir 68 | 69 | def get_model_dir(self): 70 | return self.runner_config.base_config.specs['saved_model_dir'] 71 | 72 | def train( 73 | self, 74 | train_loader=None, 75 | valid_loader=None, 76 | test_loader=None, 77 | **kwargs 78 | ): 79 | """Train the model. 80 | 81 | Args: 82 | train_loader (EasyTPP.DataLoader, optional): data loader for train set. Defaults to None. 83 | valid_loader (EasyTPP.DataLoader, optional): data loader for valid set. Defaults to None. 84 | test_loader (EasyTPP.DataLoader, optional): data loader for test set. Defaults to None. 
85 | 86 | Returns: 87 | EasyTPP.BaseModel: the trained model. 88 | """ 89 | # no train and valid loader from outside 90 | if train_loader is None and valid_loader is None: 91 | train_loader = self._data_loader.train_loader() 92 | valid_loader = self._data_loader.valid_loader() 93 | 94 | # no test loader from outside and there indeed exists test data in the config 95 | if test_loader is None and self.runner_config.data_config.test_dir is not None: 96 | test_loader = self._data_loader.test_loader() 97 | 98 | logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...') 99 | 100 | timer = self.timer 101 | timer.start() 102 | model_id = self.runner_config.base_config.model_id 103 | logger.info(f'Start {model_id} training...') 104 | model = self._train_model( 105 | train_loader, 106 | valid_loader, 107 | test_loader=test_loader, 108 | **kwargs 109 | ) 110 | logger.info(f'End {model_id} training! Cost time: {timer.end()}') 111 | return model 112 | 113 | def evaluate(self, valid_loader=None, **kwargs): 114 | if valid_loader is None: 115 | valid_loader = self._data_loader.valid_loader() 116 | 117 | logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...') 118 | 119 | timer = self.timer 120 | timer.start() 121 | model_id = self.runner_config.base_config.model_id 122 | logger.info(f'Start {model_id} evaluation...') 123 | 124 | metric = self._evaluate_model( 125 | valid_loader, 126 | **kwargs 127 | ) 128 | logger.info(f'End {model_id} evaluation! Cost time: {timer.end()}') 129 | return metric['rmse'] # return a scalar for HPO to use 130 | 131 | def gen(self, gen_loader=None, **kwargs): 132 | if gen_loader is None: 133 | gen_loader = self._data_loader.test_loader() 134 | 135 | logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...') 136 | 137 | timer = self.timer 138 | timer.start() 139 | model_name = self.runner_config.base_config.model_id 140 | logger.info(f'Start {model_name} generation...') 141 | 142 | model = self._gen_model( 143 | gen_loader, 144 | **kwargs 145 | ) 146 | logger.info(f'End {model_name} generation! Cost time: {timer.end()}') 147 | return model 148 | 149 | @abstractmethod 150 | def _train_model(self, train_loader, valid_loader, **kwargs): 151 | pass 152 | 153 | @abstractmethod 154 | def _evaluate_model(self, data_loader, **kwargs): 155 | pass 156 | 157 | @abstractmethod 158 | def _gen_model(self, data_loader, **kwargs): 159 | pass 160 | 161 | @abstractmethod 162 | def _save_model(self, model_dir, **kwargs): 163 | pass 164 | 165 | @abstractmethod 166 | def _load_model(self, model_dir, **kwargs): 167 | pass 168 | 169 | def save_log(self): 170 | """Save the log to local files. 171 | """ 172 | log_dir = self.runner_config.base_config.specs['saved_log_dir'] 173 | fh = logging.FileHandler(log_dir) 174 | fh.setFormatter(logging.Formatter(LogConst.DEFAULT_FORMAT_LONG)) 175 | logger.addHandler(fh) 176 | logger.info(f'Save the log to {log_dir}') 177 | return 178 | 179 | def save( 180 | self, 181 | model_dir=None, 182 | **kwargs 183 | ): 184 | return self._save_model(model_dir, **kwargs) 185 | 186 | def run(self, **kwargs): 187 | """Start the runner. 188 | 189 | Args: 190 | **kwargs (dict): optional params. 191 | 192 | Returns: 193 | EasyTPP.BaseModel, dict: the results of the process.
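        Example:
            A minimal end-to-end sketch; the config path and experiment id are
            illustrative, and we assume the Config factory's build_from_yaml_file
            entry point::

                from easy_tpp.config_factory import Config
                from easy_tpp.runner import Runner

                config = Config.build_from_yaml_file('configs/experiment_config.yaml',
                                                     experiment_id='NHP_train')
                runner = Runner.build_from_config(config)
                runner.run()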
194 | """ 195 | current_stage = get_stage(self.runner_config.base_config.stage) 196 | if current_stage == RunnerPhase.TRAIN: 197 | return self.train(**kwargs) 198 | elif current_stage == RunnerPhase.VALIDATE: 199 | return self.evaluate(**kwargs) 200 | else: 201 | return self.gen(**kwargs) 202 | -------------------------------------------------------------------------------- /easy_tpp/tf_wrapper.py: -------------------------------------------------------------------------------- 1 | """ Initialize a Tf model wrapper that feed into Model Runner """ 2 | 3 | import tensorflow as tf 4 | 5 | from easy_tpp.utils import RunnerPhase 6 | from easy_tpp.utils.tf_utils import set_device, set_optimizer 7 | 8 | if tf.__version__ >= '2.0': 9 | tf = tf.compat.v1 10 | tf.disable_v2_behavior() 11 | 12 | 13 | class TfModelWrapper: 14 | def __init__(self, model, base_config, model_config, trainer_config): 15 | """A wrapper class for Tensorflow backends. 16 | 17 | Args: 18 | model (BaseModel): a TPP model. 19 | base_config (EasyTPP.Config): basic configs. 20 | model_config (EasyTPP.Config): model spec configs. 21 | trainer_config (EasyTPP.Config): trainer spec configs. 22 | """ 23 | self.model = model 24 | self.base_config = base_config 25 | self.model_config = model_config 26 | self.trainer_config = trainer_config 27 | set_device(self.trainer_config.gpu) 28 | 29 | # init session and build model 30 | tf.reset_default_graph() 31 | self.sess = tf.Session() 32 | self.model.build_graph() 33 | if self.model_config.is_training: 34 | # set up optimizer 35 | optimizer = self.trainer_config.optimizer 36 | self.learning_rate = self.trainer_config.learning_rate 37 | self.opt = set_optimizer(optimizer, self.learning_rate) 38 | self.train_op = self.opt.minimize(self.model.loss) 39 | 40 | # set up tensorboard 41 | self.use_tfb = self.trainer_config.use_tfb 42 | self.train_summary_writer, self.valid_summary_writer = None, None 43 | if self.use_tfb: 44 | self.train_summary_writer = tf.summary.FileWriter(self.base_config.spec['tfb_train_dir']) 45 | self.valid_summary_writer = tf.summary.FileWriter(self.base_config.spec['tfb_valid_dir']) 46 | 47 | # init variable and saver 48 | self.sess.run(tf.global_variables_initializer()) 49 | self.saver = tf.train.Saver() 50 | 51 | def restore(self, ckpt_dir): 52 | """Load the checkpoint to restore the model. 53 | 54 | Args: 55 | ckpt_dir (str): path for the checkpoint. 56 | """ 57 | self.saver.restore(self.sess, ckpt_dir) 58 | 59 | def save(self, ckpt_dir): 60 | """Save the checkpoint for the model. 61 | 62 | Args: 63 | ckpt_dir (str): path for the checkpoint. 64 | """ 65 | self.saver.save(self.sess, ckpt_dir) 66 | 67 | def write_summary(self, epoch, kv_pairs, phase): 68 | """Write the kv_paris into the tensorboard. 69 | 70 | Args: 71 | epoch (int): epoch index in the training. 72 | kv_pairs (dict): metrics dict. 73 | phase (RunnerPhase): a const that defines the stage of model runner. 
74 | """ 75 | if self.use_tfb: 76 | summary_writer = None 77 | if phase == RunnerPhase.TRAIN: 78 | summary_writer = self.train_summary_writer 79 | elif phase == RunnerPhase.VALIDATE: 80 | summary_writer = self.valid_summary_writer 81 | elif phase == RunnerPhase.PREDICT: 82 | pass 83 | 84 | metric_summary = tf.Summary() 85 | if summary_writer is not None: 86 | for k, v in kv_pairs.items(): 87 | if k != 'num_events': 88 | metric_summary.value.add(tag=k, simple_value=v) 89 | summary_writer.add_summary(metric_summary, epoch) 90 | 91 | summary_writer.flush() 92 | return 93 | 94 | def close_summary(self): 95 | """Close the tensorboard summary writer. 96 | """ 97 | if self.train_summary_writer is not None: 98 | self.train_summary_writer.close() 99 | 100 | if self.valid_summary_writer is not None: 101 | self.valid_summary_writer.close() 102 | return 103 | 104 | def run_batch(self, batch, phase): 105 | """Run one batch. 106 | 107 | Args: 108 | batch (EasyTPP.BatchEncoding): preprocessed batch data that go into the model. 109 | phase (RunnerPhase): a const that defines the stage of model runner. 110 | 111 | Returns: 112 | tuple: for training and validation we return loss, prediction and labels; 113 | for prediction we return prediction. 114 | """ 115 | model = self.model 116 | sess = self.sess 117 | 118 | time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch.data.values() 119 | 120 | # set mode to train 121 | is_training = (phase == RunnerPhase.TRAIN) 122 | 123 | fd = { 124 | model.time_seqs: time_seqs, 125 | model.time_delta_seqs: time_delta_seqs, 126 | model.type_seqs: type_seqs, 127 | model.batch_non_pad_mask: batch_non_pad_mask, 128 | model.attention_mask: attention_mask, 129 | model.type_mask: type_mask, 130 | model.is_training: is_training 131 | 132 | } 133 | 134 | # Assume we dont do prediction on train set 135 | pred_dtime, pred_type = None, None 136 | label_dtime, label_type = time_delta_seqs[:, 1:], type_seqs[:, 1:] 137 | 138 | mask = batch_non_pad_mask[:, 1:] 139 | 140 | if phase in (RunnerPhase.TRAIN, RunnerPhase.VALIDATE): 141 | # set mode to train 142 | if is_training: 143 | _, loss, num_event = sess.run([self.train_op, 144 | model.loss, 145 | model.num_event], 146 | feed_dict=fd) 147 | else: 148 | loss, num_event = sess.run([model.loss, 149 | model.num_event], 150 | feed_dict=fd) 151 | 152 | if self.model.event_sampler: 153 | pred_dtime, pred_type = sess.run([model.dtime_predict_one_step, 154 | model.type_predict_one_step], 155 | feed_dict=fd) 156 | return loss, num_event, (pred_dtime, pred_type), (label_dtime, label_type), (mask,) 157 | else: 158 | pred_dtime, pred_type = sess.run([model.dtime_generation, 159 | model.type_generation], 160 | feed_dict=fd) 161 | num_steps = pred_dtime.shape[-1] 162 | label_dtime = time_delta_seqs[:, -num_steps:] 163 | label_type = type_seqs[:, -num_steps:] 164 | return (pred_dtime, pred_type), (label_dtime, label_type) 165 | -------------------------------------------------------------------------------- /easy_tpp/torch_wrapper.py: -------------------------------------------------------------------------------- 1 | """ Initialize a Pytorch model wrapper that feed into Model Runner """ 2 | 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | from easy_tpp.utils import RunnerPhase, set_optimizer, set_device 7 | 8 | 9 | class TorchModelWrapper: 10 | def __init__(self, model, base_config, model_config, trainer_config): 11 | """A wrapper class for Torch backends. 
12 | 13 | Args: 14 | model (BaseModel): a TPP model. 15 | base_config (EasyTPP.Config): basic configs. 16 | model_config (EasyTPP.ModelConfig): model spec configs. 17 | trainer_config (EasyTPP.TrainerConfig): trainer spec configs. 18 | """ 19 | self.model = model 20 | self.base_config = base_config 21 | self.model_config = model_config 22 | self.trainer_config = trainer_config 23 | 24 | self.model_id = self.base_config.model_id 25 | self.device = set_device(self.trainer_config.gpu) 26 | 27 | self.model.to(self.device) 28 | 29 | if self.model_config.is_training: 30 | # set up optimizer 31 | optimizer = self.trainer_config.optimizer 32 | self.learning_rate = self.trainer_config.learning_rate 33 | self.opt = set_optimizer(optimizer, self.model.parameters(), self.learning_rate) 34 | 35 | # set up tensorboard 36 | self.train_summary_writer, self.valid_summary_writer = None, None 37 | if self.trainer_config.use_tfb: 38 | self.train_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_train_dir']) 39 | self.valid_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_valid_dir']) 40 | 41 | def restore(self, ckpt_dir): 42 | """Load the checkpoint to restore the model. 43 | 44 | Args: 45 | ckpt_dir (str): path for the checkpoint. 46 | """ 47 | 48 | self.model.load_state_dict(torch.load(ckpt_dir), strict=False) 49 | 50 | def save(self, ckpt_dir): 51 | """Save the checkpoint for the model. 52 | 53 | Args: 54 | ckpt_dir (str): path for the checkpoint. 55 | """ 56 | torch.save(self.model.state_dict(), ckpt_dir) 57 | 58 | def write_summary(self, epoch, kv_pairs, phase): 59 | """Write the kv_pairs into the tensorboard. 60 | 61 | Args: 62 | epoch (int): epoch index in the training. 63 | kv_pairs (dict): metrics dict. 64 | phase (RunnerPhase): a const that defines the stage of the model runner. 65 | """ 66 | if self.trainer_config.use_tfb: 67 | summary_writer = None 68 | if phase == RunnerPhase.TRAIN: 69 | summary_writer = self.train_summary_writer 70 | elif phase == RunnerPhase.VALIDATE: 71 | summary_writer = self.valid_summary_writer 72 | elif phase == RunnerPhase.PREDICT: 73 | pass 74 | 75 | if summary_writer is not None: 76 | for k, v in kv_pairs.items(): 77 | if k != 'num_events': 78 | summary_writer.add_scalar(k, v, epoch) 79 | 80 | summary_writer.flush() 81 | return 82 | 83 | def close_summary(self): 84 | """Close the tensorboard summary writer. 85 | """ 86 | if self.train_summary_writer is not None: 87 | self.train_summary_writer.close() 88 | 89 | if self.valid_summary_writer is not None: 90 | self.valid_summary_writer.close() 91 | return 92 | 93 | def run_batch(self, batch, phase): 94 | """Run one batch. 95 | 96 | Args: 97 | batch (EasyTPP.BatchEncoding): preprocessed batch data that go into the model. 98 | phase (RunnerPhase): a const that defines the stage of the model runner. 99 | 100 | Returns: 101 | tuple: for training and validation we return loss, prediction and labels; 102 | for prediction we return prediction.
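        Example: a sketch of how the model runner might drive this method
        (``wrapper`` and ``valid_loader`` are assumed names; in the real
        pipeline the model runner owns this loop)::

            for batch in valid_loader:
                loss, num_event, preds, labels, mask = wrapper.run_batch(
                    batch, phase=RunnerPhase.VALIDATE)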
103 | """ 104 | 105 | batch = batch.to(self.device).values() 106 | if phase in (RunnerPhase.TRAIN, RunnerPhase.VALIDATE): 107 | # set mode to train 108 | is_training = (phase == RunnerPhase.TRAIN) 109 | self.model.train(is_training) 110 | 111 | # FullyRNN needs grad event in validation stage 112 | grad_flag = is_training if not self.model_id == 'FullyNN' else True 113 | # run model 114 | with torch.set_grad_enabled(grad_flag): 115 | loss, num_event = self.model.loglike_loss(batch) 116 | 117 | # Assume we dont do prediction on train set 118 | pred_dtime, pred_type, label_dtime, label_type, mask = None, None, None, None, None 119 | 120 | # update grad 121 | if is_training: 122 | self.opt.zero_grad() 123 | (loss / num_event).backward() 124 | self.opt.step() 125 | else: # by default we do not do evaluation on train set which may take a long time 126 | if self.model.event_sampler: 127 | self.model.eval() 128 | with torch.no_grad(): 129 | if batch[1] is not None and batch[2] is not None: 130 | label_dtime, label_type = batch[1][:, 1:].cpu().numpy(), batch[2][:, 1:].cpu().numpy() 131 | if batch[3] is not None: 132 | mask = batch[3][:, 1:].cpu().numpy() 133 | pred_dtime, pred_type = self.model.predict_one_step_at_every_event(batch=batch) 134 | pred_dtime = pred_dtime.detach().cpu().numpy() 135 | pred_type = pred_type.detach().cpu().numpy() 136 | return loss.item(), num_event, (pred_dtime, pred_type), (label_dtime, label_type), (mask,) 137 | else: 138 | pred_dtime, pred_type, label_dtime, label_type = self.model.predict_multi_step_since_last_event(batch=batch) 139 | pred_dtime = pred_dtime.detach().cpu().numpy() 140 | pred_type = pred_type.detach().cpu().numpy() 141 | label_dtime = label_dtime.detach().cpu().numpy() 142 | label_type = label_type.detach().cpu().numpy() 143 | return (pred_dtime, pred_type), (label_dtime, label_type) 144 | -------------------------------------------------------------------------------- /easy_tpp/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.utils.const import RunnerPhase, LogConst, DefaultRunnerConfig, PaddingStrategy, TensorType, ExplicitEnum, \ 2 | TruncationStrategy 3 | from easy_tpp.utils.import_utils import is_tf_available, is_tensorflow_probability_available, is_torchvision_available, \ 4 | is_torch_cuda_available, is_torch_available, requires_backends, is_tf_gpu_available, is_torch_gpu_available 5 | from easy_tpp.utils.log_utils import default_logger as logger, DEFAULT_FORMATTER 6 | from easy_tpp.utils.metrics import MetricsHelper, MetricsTracker 7 | from easy_tpp.utils.misc import py_assert, make_config_string, create_folder, save_yaml_config, load_yaml_config, \ 8 | load_pickle, has_key, array_pad_cols, save_pickle, concat_element, get_stage, to_dict, \ 9 | dict_deep_update, save_json, load_json 10 | from easy_tpp.utils.multiprocess_utils import get_unique_id, Timer, parse_uri_to_protocol_and_path, is_master_process, \ 11 | is_local_master_process 12 | from easy_tpp.utils.ode_utils import rk4_step_method 13 | from easy_tpp.utils.registrable import Registrable 14 | from easy_tpp.utils.torch_utils import set_device, set_optimizer, set_seed, count_model_params 15 | from easy_tpp.utils.generic import is_torch_device, is_numpy_array 16 | from easy_tpp.utils.gen_utils import generate_and_save_json 17 | 18 | __all__ = ['py_assert', 19 | 'make_config_string', 20 | 'create_folder', 21 | 'save_yaml_config', 22 | 'load_yaml_config', 23 | 'RunnerPhase', 24 | 'LogConst', 25 | 'load_pickle', 26 | 
'has_key', 27 | 'array_pad_cols', 28 | 'MetricsHelper', 29 | 'MetricsTracker', 30 | 'set_device', 31 | 'set_optimizer', 32 | 'set_seed', 33 | 'save_pickle', 34 | 'count_model_params', 35 | 'Registrable', 36 | 'logger', 37 | 'get_unique_id', 38 | 'Timer', 39 | 'concat_element', 40 | 'get_stage', 41 | 'to_dict', 42 | 'DEFAULT_FORMATTER', 43 | 'parse_uri_to_protocol_and_path', 44 | 'is_master_process', 45 | 'is_local_master_process', 46 | 'dict_deep_update', 47 | 'DefaultRunnerConfig', 48 | 'rk4_step_method', 49 | 'is_tf_available', 50 | 'is_tensorflow_probability_available', 51 | 'is_torchvision_available', 52 | 'is_torch_cuda_available', 53 | 'is_tf_gpu_available', 54 | 'is_torch_gpu_available', 55 | 'is_torch_available', 56 | 'requires_backends', 57 | 'PaddingStrategy', 58 | 'ExplicitEnum', 59 | 'TruncationStrategy', 60 | 'TensorType', 61 | 'is_torch_device', 62 | 'is_numpy_array', 63 | 'save_json', 64 | 'load_json', 65 | 'generate_and_save_json'] 66 | -------------------------------------------------------------------------------- /easy_tpp/utils/const.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ExplicitEnum(str, Enum): 5 | """ 6 | Enum with more explicit error message for missing values. 7 | """ 8 | 9 | def __str__(self): 10 | return str(self.value) 11 | 12 | @classmethod 13 | def _missing_(cls, value): 14 | raise ValueError( 15 | f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" 16 | ) 17 | 18 | 19 | class PaddingStrategy(ExplicitEnum): 20 | """ 21 | Possible values for the `padding` argument in [`EventTokenizer.__call__`]. Useful for tab-completion in an 22 | IDE. 23 | """ 24 | 25 | LONGEST = "longest" 26 | MAX_LENGTH = "max_length" 27 | DO_NOT_PAD = "do_not_pad" 28 | 29 | 30 | class TensorType(ExplicitEnum): 31 | """ 32 | Possible values for the `return_tensors` argument in [`EventTokenizerBase.__call__`]. Useful for 33 | tab-completion in an IDE. 34 | """ 35 | 36 | PYTORCH = "pt" 37 | TENSORFLOW = "tf" 38 | NUMPY = "np" 39 | 40 | 41 | class RunnerPhase(ExplicitEnum): 42 | """Model runner phase enum. 43 | """ 44 | TRAIN = 'train' 45 | VALIDATE = 'validate' 46 | PREDICT = 'predict' 47 | 48 | 49 | class LossFunction(ExplicitEnum): 50 | """Loss function for neural TPP model. 51 | """ 52 | LOGLIKE = 'loglike' 53 | PARTIAL_TIME_LOSS = 'rmse' 54 | PARTIAL_EVENT_LOSS = 'accuracy' 55 | 56 | 57 | class LogConst: 58 | """Format for log handler. 59 | """ 60 | DEFAULT_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s' 61 | DEFAULT_FORMAT_LONG = '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s]' \ 62 | ' - %(levelname)s: %(message)s' 63 | 64 | 65 | class PredOutputIndex: 66 | """Positional index for the output tuple in ModelRunner. 67 | """ 68 | TimePredIndex = 0 69 | TypePredIndex = 1 70 | 71 | 72 | class DefaultRunnerConfig: 73 | DEFAULT_DATASET_ID = 'conttime' 74 | 75 | 76 | class TruncationStrategy(ExplicitEnum): 77 | """ 78 | Possible values for the `truncation` argument in [`EventTokenizer.__call__`]. Useful for tab-completion in 79 | an IDE. 80 | """ 81 | 82 | LONGEST_FIRST = "longest_first" 83 | DO_NOT_TRUNCATE = "do_not_truncate" 84 | 85 | 86 | class Backend(ExplicitEnum): 87 | """ 88 | Possible values for the `backend` option in the base config (e.g. `backend: torch`). Useful for tab-completion in 89 | an IDE.
90 | """ 91 | 92 | Torch = 'torch' 93 | TF = 'tensorflow' 94 | -------------------------------------------------------------------------------- /easy_tpp/utils/gen_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from easy_tpp.utils.misc import save_json 3 | 4 | def generate_synthetic_data(n_nodes=3, end_time=1000, baseline=0.1, adjacency=0.5, decay=1.0): 5 | """ 6 | Generates synthetic data using a multivariate Hawkes process with exponential kernels. 7 | 8 | Args: 9 | n_nodes (int): Number of nodes (or dimensions) in the Hawkes process. 10 | end_time (float): The time until which the process is simulated. 11 | baseline (float): Baseline intensity for each node. 12 | adjacency (float): Adjacency matrix value for the influence between nodes. 13 | decay (float): Decay parameter for the exponential kernel. 14 | 15 | Returns: 16 | list: A list of lists, where each sublist contains dictionaries representing events for a node. 17 | """ 18 | baseline_vector = np.full(n_nodes, baseline) 19 | adjacency_matrix = np.full((n_nodes, n_nodes), adjacency) 20 | events = [[] for _ in range(n_nodes)] 21 | current_time = 0 22 | 23 | while current_time < end_time: 24 | # Calculate the intensity for each node 25 | intensities = baseline_vector.copy() 26 | for i in range(n_nodes): 27 | for j in range(n_nodes): 28 | if events[j]: 29 | last_event_time = events[j][-1]['time_since_start'] 30 | intensities[i] += adjacency_matrix[i, j] * np.exp(-decay * (current_time - last_event_time)) 31 | 32 | # Determine the next event time 33 | total_intensity = np.sum(intensities) 34 | if total_intensity == 0: 35 | break 36 | time_to_next_event = np.random.exponential(1 / total_intensity) 37 | current_time += time_to_next_event 38 | 39 | if current_time >= end_time: 40 | break 41 | 42 | # Determine which node the event occurs in 43 | probabilities = intensities / total_intensity 44 | node = np.random.choice(n_nodes, p=probabilities) 45 | 46 | # Record the event as a dictionary 47 | if events[node]: 48 | last_event_time = events[node][-1]['time_since_start'] 49 | else: 50 | last_event_time = 0 51 | 52 | event = { 53 | 'time_since_start': current_time, 54 | 'time_since_last_event': current_time - last_event_time, 55 | 'type_event': node 56 | } 57 | events[node].append(event) 58 | 59 | return events 60 | 61 | def format_tick_data_to_hf(events, dim_process, max_seq_len): 62 | """ 63 | Formats the synthetic data from a multivariate Hawkes process to the Hugging Face dataset format. 64 | 65 | Args: 66 | events (list): A list of lists, where each sublist contains dictionaries representing events for a node. 67 | dim_process (int): Number of nodes (or dimensions) in the Hawkes process. 68 | max_seq_len (int): Maximum sequence length. 69 | 70 | Returns: 71 | list: A list of dictionaries, where each dictionary represents a sequence. 
72 | """ 73 | # Flatten all events into a single list 74 | all_events = [event for node_events in events for event in node_events] 75 | 76 | # Sort events by time_since_start 77 | all_events.sort(key=lambda x: x['time_since_start']) 78 | 79 | # Split into multiple sequences based on max_seq_len 80 | formatted_data = [] 81 | for seq_idx in range(0, len(all_events), max_seq_len): 82 | seq_events = all_events[seq_idx:seq_idx + max_seq_len] 83 | 84 | # Adjust time_since_start to have zero start timestamps 85 | start_time = seq_events[0]['time_since_start'] 86 | time_since_start = [event['time_since_start'] - start_time for event in seq_events] 87 | time_since_last_event = [event['time_since_last_event'] for event in seq_events] 88 | type_event = [event['type_event'] for event in seq_events] 89 | 90 | temp_dict = { 91 | 'dim_process': dim_process, 92 | 'seq_idx': seq_idx // max_seq_len, 93 | 'seq_len': len(seq_events), 94 | 'time_since_start': time_since_start, 95 | 'time_since_last_event': time_since_last_event, 96 | 'type_event': type_event 97 | } 98 | formatted_data.append(temp_dict) 99 | 100 | return formatted_data 101 | 102 | def generate_and_save_json(n_nodes, end_time, baseline, adjacency, decay, max_seq_len, target_file): 103 | """ 104 | Generates synthetic data, formats it, and saves it to a file in Hugging Face format. 105 | 106 | Args: 107 | n_nodes (int): Number of nodes (or dimensions) in the Hawkes process. 108 | end_time (float): The time until which the process is simulated. 109 | baseline (float): Baseline intensity for each node. 110 | adjacency (float): Adjacency matrix value for the influence between nodes. 111 | decay (float): Decay parameter for the exponential kernel. 112 | max_seq_len (int): Maximum sequence length. 113 | target_file (str): Path to the file where the formatted data will be saved. 114 | 115 | Raises: 116 | IOError: If the file cannot be opened or written to. 117 | """ 118 | events = generate_synthetic_data(n_nodes, end_time, baseline, adjacency, decay) 119 | formatted_data = format_tick_data_to_hf(events, dim_process=n_nodes, max_seq_len=max_seq_len) 120 | save_json(formatted_data, target_file) -------------------------------------------------------------------------------- /easy_tpp/utils/generic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from easy_tpp.utils import is_torch_available, is_tf_available 4 | 5 | 6 | def is_tensor(x): 7 | """ 8 | Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`. 9 | """ 10 | if is_torch_available(): 11 | import torch 12 | 13 | if isinstance(x, torch.Tensor): 14 | return True 15 | if is_tf_available(): 16 | import tensorflow as tf 17 | 18 | if isinstance(x, tf.Tensor): 19 | return True 20 | 21 | return isinstance(x, np.ndarray) 22 | 23 | 24 | def _is_numpy(x): 25 | return isinstance(x, np.ndarray) 26 | 27 | 28 | def is_numpy_array(x): 29 | """ 30 | Tests if `x` is a numpy array or not. 31 | """ 32 | return _is_numpy(x) 33 | 34 | 35 | def _is_torch(x): 36 | import torch 37 | 38 | return isinstance(x, torch.Tensor) 39 | 40 | 41 | def is_torch_tensor(x): 42 | """ 43 | Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed. 
44 | """ 45 | return False if not is_torch_available() else _is_torch(x) 46 | 47 | 48 | def _is_torch_device(x): 49 | import torch 50 | 51 | return isinstance(x, torch.device) 52 | 53 | 54 | def is_torch_device(x): 55 | """ 56 | Tests if `x` is a torch device or not. Safe to call even if torch is not installed. 57 | """ 58 | return False if not is_torch_available() else _is_torch_device(x) 59 | 60 | 61 | def _is_torch_dtype(x): 62 | import torch 63 | 64 | if isinstance(x, str): 65 | if hasattr(torch, x): 66 | x = getattr(torch, x) 67 | else: 68 | return False 69 | return isinstance(x, torch.dtype) 70 | 71 | 72 | def is_torch_dtype(x): 73 | """ 74 | Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed. 75 | """ 76 | return False if not is_torch_available() else _is_torch_dtype(x) 77 | 78 | 79 | def _is_tensorflow(x): 80 | import tensorflow as tf 81 | 82 | return isinstance(x, tf.Tensor) 83 | 84 | 85 | def is_tf_tensor(x): 86 | """ 87 | Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed. 88 | """ 89 | return False if not is_tf_available() else _is_tensorflow(x) 90 | 91 | 92 | def _is_tf_symbolic_tensor(x): 93 | import tensorflow as tf 94 | 95 | # the `is_symbolic_tensor` predicate is only available starting with TF 2.14 96 | if hasattr(tf, "is_symbolic_tensor"): 97 | return tf.is_symbolic_tensor(x) 98 | return type(x) == tf.Tensor 99 | 100 | 101 | def is_tf_symbolic_tensor(x): 102 | """ 103 | Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not 104 | installed. 105 | """ 106 | return False if not is_tf_available() else _is_tf_symbolic_tensor(x) 107 | -------------------------------------------------------------------------------- /easy_tpp/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import sys 3 | from collections import OrderedDict 4 | from typing import Union, Tuple 5 | 6 | from easy_tpp.utils.log_utils import default_logger as logger 7 | 8 | if sys.version_info < (3, 8): 9 | import importlib_metadata 10 | else: 11 | import importlib.metadata as importlib_metadata 12 | 13 | 14 | def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: 15 | # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version 16 | package_exists = importlib.util.find_spec(pkg_name) is not None 17 | package_version = "N/A" 18 | if package_exists: 19 | try: 20 | package_version = importlib_metadata.version(pkg_name) 21 | except importlib_metadata.PackageNotFoundError: 22 | pass 23 | logger.debug(f"Detected {pkg_name} version {package_version}") 24 | if return_version: 25 | return package_exists, package_version 26 | else: 27 | return package_exists 28 | 29 | 30 | _tf_available = _is_package_available("tensorflow") 31 | if _tf_available: 32 | candidates = ( 33 | "tensorflow", 34 | "tensorflow-cpu", 35 | "tensorflow-gpu", 36 | "tf-nightly", 37 | "tf-nightly-cpu", 38 | "tf-nightly-gpu", 39 | "intel-tensorflow", 40 | "intel-tensorflow-avx512", 41 | "tensorflow-rocm", 42 | "tensorflow-macos", 43 | "tensorflow-aarch64", 44 | ) 45 | _tf_version = None 46 | # For the metadata, we have to look for both tensorflow and tensorflow-cpu 47 | for pkg in candidates: 48 | try: 49 | _tf_version = importlib_metadata.version(pkg) 50 | break 51 | except importlib_metadata.PackageNotFoundError: 52 | pass 53 | _tf_available = 
_tf_version is not None 54 | 55 | _tensorflow_probability_available = _is_package_available("tensorflow_probability") 56 | _torchdistx_available = _is_package_available("torchdistx") 57 | _torchvision_available = _is_package_available("torchvision") 58 | 59 | _torch_available, _torch_version = _is_package_available("torch", return_version=True) 60 | 61 | 62 | def is_torch_available(): 63 | return _torch_available 64 | 65 | 66 | def get_torch_version(): 67 | return _torch_version 68 | 69 | 70 | def is_torchvision_available(): 71 | return _torchvision_available 72 | 73 | 74 | def is_torch_cuda_available(): 75 | if is_torch_available(): 76 | import torch 77 | 78 | return torch.cuda.is_available() 79 | else: 80 | return False 81 | 82 | 83 | def is_tf_available(): 84 | return _tf_available 85 | 86 | 87 | def is_tf_gpu_available(): 88 | if is_tf_available(): 89 | import tensorflow as tf 90 | if tf.__version__ >= '2.0': 91 | return bool(tf.config.list_physical_devices("GPU")) 92 | else: 93 | from tensorflow.python.client import device_lib 94 | local_device_protos = device_lib.list_local_devices() 95 | for device in local_device_protos: 96 | if device.device_type == 'GPU': 97 | return True 98 | else: 99 | return False 100 | 101 | 102 | def is_torch_mps_available(): 103 | if is_torch_available(): 104 | try: 105 | import torch 106 | torch.device('mps') 107 | return True 108 | except RuntimeError: 109 | return False 110 | else: 111 | return False 112 | 113 | 114 | def is_torch_gpu_available(): 115 | is_cuda_available = is_torch_cuda_available() 116 | 117 | is_mps_available = is_torch_mps_available() 118 | 119 | return is_cuda_available | is_mps_available 120 | 121 | 122 | def is_tensorflow_probability_available(): 123 | return _tensorflow_probability_available 124 | 125 | 126 | def torch_only_method(fn): 127 | def wrapper(*args, **kwargs): 128 | if not _torch_available: 129 | raise ImportError( 130 | "You need to install pytorch to use this method or class, " 131 | "or activate it with environment variables USE_TORCH=1 and USE_TF=0." 132 | ) 133 | else: 134 | return fn(*args, **kwargs) 135 | 136 | return wrapper 137 | 138 | 139 | # docstyle-ignore 140 | PYTORCH_IMPORT_ERROR = """ 141 | {0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the 142 | installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. 143 | Please note that you may need to restart your runtime after installation. 144 | """ 145 | 146 | # docstyle-ignore 147 | TORCHVISION_IMPORT_ERROR = """ 148 | {0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the 149 | installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. 150 | Please note that you may need to restart your runtime after installation. 151 | """ 152 | 153 | # docstyle-ignore 154 | PYTORCH_IMPORT_ERROR_WITH_TF = """ 155 | {0} requires the PyTorch library but it was not found in your environment. 156 | However, we were able to find a TensorFlow installation. TensorFlow classes begin 157 | with "TF", but are otherwise identically named to our PyTorch classes. This 158 | means that the TF equivalent of the class you tried to import would be "TF{0}". 159 | If you want to use TensorFlow, please use TF classes instead! 
160 | 161 | If you really do want to use PyTorch please go to 162 | https://pytorch.org/get-started/locally/ and follow the instructions that 163 | match your environment. 164 | """ 165 | 166 | # docstyle-ignore 167 | TF_IMPORT_ERROR_WITH_PYTORCH = """ 168 | {0} requires the TensorFlow library but it was not found in your environment. 169 | However, we were able to find a PyTorch installation. PyTorch classes do not begin 170 | with "TF", but are otherwise identically named to our TF classes. 171 | If you want to use PyTorch, please use those classes instead! 172 | 173 | If you really do want to use TensorFlow, please follow the instructions on the 174 | installation page https://www.tensorflow.org/install that match your environment. 175 | """ 176 | 177 | # docstyle-ignore 178 | TENSORFLOW_IMPORT_ERROR = """ 179 | {0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the 180 | installation page: https://www.tensorflow.org/install and follow the ones that match your environment. 181 | Please note that you may need to restart your runtime after installation. 182 | """ 183 | 184 | # docstyle-ignore 185 | TENSORFLOW_PROBABILITY_IMPORT_ERROR = """ 186 | {0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as 187 | explained here: https://github.com/tensorflow/probability. 188 | Please note that you may need to restart your runtime after installation. 189 | """ 190 | 191 | BACKENDS_MAPPING = OrderedDict( 192 | [ 193 | ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)), 194 | ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), 195 | ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), 196 | ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)) 197 | ] 198 | ) 199 | 200 | 201 | def requires_backends(obj, backends): 202 | if not isinstance(backends, (list, tuple)): 203 | backends = [backends] 204 | 205 | name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ 206 | 207 | # Raise an error for users who might not realize that classes without "TF" are torch-only 208 | if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): 209 | raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name)) 210 | 211 | # Raise the inverse error for PyTorch users trying to load TF classes 212 | if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): 213 | raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name)) 214 | 215 | checks = (BACKENDS_MAPPING[backend] for backend in backends) 216 | failed = [msg.format(name) for available, msg in checks if not available()] 217 | if failed: 218 | raise ImportError("".join(failed)) 219 | -------------------------------------------------------------------------------- /easy_tpp/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import typing 4 | 5 | from easy_tpp.utils.const import LogConst 6 | 7 | # -------- log setting --------- 8 | DEFAULT_LOGGER = "easytpp.logger" 9 | 10 | 11 | class CustomFormatter(logging.Formatter): 12 | grey = "\x1b[38;20m" 13 | yellow = "\x1b[33;20m" 14 | red = "\x1b[31;20m" 15 | bold_red = "\x1b[31;1m" 16 | reset = "\x1b[0m" 17 | format = LogConst.DEFAULT_FORMAT_LONG 18 | 19 | FORMATS = { 20 | logging.DEBUG: grey + format + reset, 21 | logging.INFO: grey + format + reset, 22 
| logging.WARNING: yellow + format + reset, 23 | logging.ERROR: red + format + reset, 24 | logging.CRITICAL: bold_red + format + reset 25 | } 26 | 27 | def format(self, record): 28 | log_fmt = self.FORMATS.get(record.levelno) 29 | formatter = logging.Formatter(log_fmt) 30 | return formatter.format(record) 31 | 32 | 33 | DEFAULT_FORMATTER = CustomFormatter() 34 | 35 | _ch = logging.StreamHandler(stream=sys.stdout) 36 | _ch.setFormatter(DEFAULT_FORMATTER) 37 | 38 | _DEFAULT_HANDLERS = [_ch] 39 | 40 | _LOGGER_CACHE = {}  # type: typing.Dict[str, logging.Logger] 41 | 42 | 43 | def get_logger(name, level="INFO", handlers=None, update=False): 44 | if name in _LOGGER_CACHE and not update: 45 | return _LOGGER_CACHE[name] 46 | logger = logging.getLogger(name) 47 | logger.setLevel(level) 48 | logger.handlers = handlers or _DEFAULT_HANDLERS 49 | logger.propagate = False 50 | return logger 51 | 52 | 53 | # -------------------------- Singleton Object -------------------------- 54 | default_logger = get_logger(DEFAULT_LOGGER) 55 | -------------------------------------------------------------------------------- /easy_tpp/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | from easy_tpp.utils.log_utils import default_logger as logger 6 | 7 | 8 | class MetricsHelper: 9 | MAXIMIZE = 'maximize' 10 | MINIMIZE = 'minimize' 11 | _registry_center = defaultdict(tuple) 12 | 13 | @staticmethod 14 | def get_metric_function(name): 15 | if name in MetricsHelper._registry_center: 16 | return MetricsHelper._registry_center[name][0] 17 | else: 18 | logger.warning(f'Metric is not found: {name}') 19 | return None 20 | 21 | @staticmethod 22 | def get_metric_direction(name): 23 | if name in MetricsHelper._registry_center: 24 | return MetricsHelper._registry_center[name][1] 25 | else: 26 | return None 27 | 28 | @staticmethod 29 | def get_all_registered_metric(): 30 | return MetricsHelper._registry_center.values() 31 | 32 | @staticmethod 33 | def register(name, direction, overwrite=True): 34 | registry_center = MetricsHelper._registry_center 35 | 36 | def _add_metric_to_registry(func): 37 | if name in registry_center: 38 | if overwrite: 39 | registry_center[name] = (func, direction) 40 | else: 41 | logger.warning(f'The metric {name} is already registered and cannot be overwritten!') 42 | else: 43 | registry_center[name] = (func, direction) 44 | return func 45 | 46 | return _add_metric_to_registry 47 | 48 | @staticmethod 49 | def metrics_dict_to_str(metrics_dict): 50 | """ Convert metrics to a string to show in console """ 51 | eval_info = '' 52 | for k, v in metrics_dict.items(): 53 | eval_info += '{0} is {1}, '.format(k, v) 54 | 55 | return eval_info[:-2] 56 | 57 | @staticmethod 58 | def get_metrics_callback_from_names(metric_names): 59 | """ Metrics function callbacks """ 60 | metric_functions = [] 61 | metric_names_ = [] 62 | for name in metric_names: 63 | metric = MetricsHelper.get_metric_function(name) 64 | if metric is not None: 65 | metric_functions.append(metric) 66 | metric_names_.append(name) 67 | 68 | def metrics(preds, labels, **kwargs): 69 | """ call metrics functions """ 70 | res = dict() 71 | for metric_name, metric_func in zip(metric_names_, metric_functions): 72 | res[metric_name.lower()] = metric_func(preds, labels, **kwargs) 73 | return res 74 | 75 | return metrics 76 | 77 | 78 | class MetricsTracker: 79 | """Track and record the metrics.
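    A short usage sketch (epoch indices and values are illustrative; only
    the 'loglike' key is currently supported by ``update_best``)::

        tracker = MetricsTracker()
        for epoch, ll in enumerate([-2.0, -1.5, -1.7]):
            if tracker.update_best('loglike', ll, epoch):
                print(f'new best loglike {ll} at epoch {epoch}')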
80 | """ 81 | 82 | def __init__(self): 83 | self.current_best = { 84 | 'loglike': np.finfo(float).min, 85 | 'distance': np.finfo(float).max 86 | } 87 | self.episode_best = 'NeverUpdated' 88 | 89 | def update_best(self, key, value, epoch): 90 | """Update the recorder for the best metrics. 91 | 92 | Args: 93 | key (str): metrics key. 94 | value (float): metrics value. 95 | epoch (int): num of epoch. 96 | 97 | Raises: 98 | NotImplementedError: for keys other than 'loglike'. 99 | 100 | Returns: 101 | bool: whether the recorder has been updated. 102 | """ 103 | updated = False 104 | if key == 'loglike': 105 | if value > self.current_best[key]: 106 | updated = True 107 | self.current_best[key] = value 108 | self.episode_best = epoch 109 | else: 110 | raise NotImplementedError 111 | 112 | return updated 113 | -------------------------------------------------------------------------------- /easy_tpp/utils/multiprocess_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | 5 | def is_master_process(): 6 | """ Check if the process is the master process in all machines. 7 | 8 | Returns: 9 | bool 10 | """ 11 | rank = 0 if os.getenv('RANK') is None else int(os.getenv('RANK')) 12 | if rank == 0: 13 | return True 14 | else: 15 | return False 16 | 17 | 18 | def is_local_master_process(): 19 | """ Check if the process is the master process in the local machine. 20 | 21 | Returns: 22 | bool 23 | """ 24 | rank = 0 if os.getenv('RANK') is None else int(os.getenv('RANK')) 25 | local_world_size = 1 if os.getenv('LOCAL_WORLD_SIZE') is None else int(os.getenv('LOCAL_WORLD_SIZE')) 26 | if local_world_size == 0 or rank % local_world_size == 0: 27 | return True 28 | else: 29 | return False 30 | 31 | 32 | def get_now_timestamp_id(): 33 | """ Get the current timestamp string. 34 | 35 | Returns: 36 | A string like yyMMdd_hhmmss 37 | """ 38 | import datetime 39 | return datetime.datetime.now().strftime('%y%m%d-%H%M%S') 40 | 41 | 42 | def get_unique_id(): 43 | """ Generate a unique id string based on process id (pid), thread id and timestamp. 44 | 45 | Returns: 46 | Unique id: str 47 | """ 48 | import os 49 | import threading 50 | pid = os.getpid() 51 | tid = threading.currentThread().ident 52 | ts_id = get_now_timestamp_id() 53 | 54 | return '{}_{}_{}'.format(pid, tid, ts_id) 55 | 56 | 57 | def parse_uri_to_protocol_and_path(uri): 58 | """ Parse a uri into two parts, protocol and path. Set 'file' as default protocol when lack protocol. 59 | 60 | Args: 61 | uri: str 62 | The uri to identify a resource, whose format is like 'protocol://uri'. 63 | 64 | Returns: 65 | Protocol: str. The method to access the resource. 66 | URI: str. The location of the resource. 67 | """ 68 | 69 | if uri is None: 70 | return None, None 71 | tokens = uri.split('://') 72 | if len(tokens) == 2: 73 | protocol = tokens[0] 74 | path = tokens[1] 75 | elif len(tokens) == 1: 76 | protocol = 'file' 77 | path = tokens[0] 78 | else: 79 | raise RuntimeError(f'Wrong url format: {uri}') 80 | 81 | return protocol, path 82 | 83 | 84 | class Timer: 85 | """Count the elapsing time between start and end. 
86 | """ 87 | 88 | def __init__(self, unit='m'): 89 | unit = unit.lower() 90 | if unit == 's': 91 | self._unit = 1 92 | elif unit == 'm': 93 | self._unit = 60 94 | elif unit == 'h': 95 | self._unit = 1440 96 | else: 97 | raise RuntimeError('Unknown unit:', unit) 98 | 99 | self.unit = unit 100 | # default start time is set to the time the object initialized 101 | self._start_time = time.time() 102 | 103 | def start(self): 104 | self._start_time = time.time() 105 | 106 | def end(self): 107 | end_time = time.time() 108 | cost = (end_time - self._start_time) / self._unit 109 | # reset the start time using the end time 110 | self._start_time = end_time 111 | return '%.3f%s' % (cost, self.unit) 112 | 113 | 114 | # -------------------------- Singleton Object -------------------------- 115 | default_timer = Timer() 116 | -------------------------------------------------------------------------------- /easy_tpp/utils/ode_utils.py: -------------------------------------------------------------------------------- 1 | def ode_update_op(z0, dz, dt): 2 | """ 3 | General update operation for solving ODEs. 4 | 5 | Args: 6 | z0: Tensor or a list for Tensor whose shape is [..., dim] 7 | State at t0. 8 | dz: Tensor or a list for Tensor whose shape is [..., dim] 9 | Differentiation of state. 10 | dt: Tensor with shape [..., 1] 11 | Equal to t1 - t0. 12 | 13 | Returns: 14 | 15 | """ 16 | if isinstance(z0, list) or isinstance(z0, tuple): 17 | return [item_z + dt * item_dz for item_z, item_dz in zip(z0, dz)] 18 | else: 19 | return z0 + dt * dz 20 | 21 | 22 | def euler_step_method(diff_func, dt, z0): 23 | """ 24 | Euler method for solving ODEs. 25 | 26 | Args: 27 | diff_func: function(state) 28 | Differential equation. 29 | dt: Tensor with shape [..., 1] 30 | Equal to t1 - t0. 31 | z0: Tensor or a list for Tensor whose shape is [..., dim] 32 | State at t0. 33 | 34 | Returns: 35 | Tensor or a list for Tensor whose shape is [..., dim], which is updated state. 36 | """ 37 | dz = diff_func(z0) 38 | return ode_update_op(z0, dz, dt) 39 | 40 | 41 | def rk2_step_method(diff_func, dt, z0): 42 | """ 43 | Second order Runge-Kutta method for solving ODEs. 44 | 45 | Args: 46 | diff_func: function(dt, state) 47 | Differential equation. 48 | dt: Tensor with shape [..., 1] 49 | Equal to t1 - t0. 50 | z0: Tensor or a list for Tensor whose shape is [..., dim] 51 | State at t0. 52 | 53 | Returns: 54 | Tensor or a list for Tensor whose shape is [..., dim] 55 | """ 56 | # shape -> [..., dim] 57 | k1 = diff_func(z0) 58 | k2 = diff_func(ode_update_op(z0, k1, dt)) 59 | 60 | if isinstance(z0, list) or isinstance(z0, tuple): 61 | return [item_z + (item_k1 + item_k2) * dt * 0.5 for item_z, item_k1, item_k2 in zip(z0, k1, k2)] 62 | else: 63 | return z0 + dt * (k1 + k2) * 0.5 64 | 65 | 66 | def rk4_step_method(diff_func, dt, z0): 67 | """ 68 | Fourth order Runge-Kutta method for solving ODEs. 69 | 70 | Args: 71 | diff_func: function(dt, state) 72 | Differential equation. 73 | dt: Tensor with shape [..., 1] 74 | Equal to t1 - t0. 75 | z0: Tensor with shape [..., dim] 76 | State at t0. 77 | 78 | Returns: 79 | Tensor with shape [..., dim], which is updated state. 
80 | """ 81 | # shape -> [..., dim] 82 | k1 = diff_func(z0) 83 | k2 = diff_func(ode_update_op(z0, k1, dt / 2.0)) 84 | k3 = diff_func(ode_update_op(z0, k2, dt / 2.0)) 85 | k4 = diff_func(ode_update_op(z0, k3, dt)) 86 | 87 | if isinstance(z0, list) or isinstance(z0, tuple): 88 | return [item_z + (item_k1 + 2.0 * item_k2 + 2.0 * item_k3 + item_k4) * dt / 6.0 89 | for item_z, item_k1, item_k2, item_k3, item_k4 in zip(z0, k1, k2, k3, k4)] 90 | else: 91 | return z0 + dt * (k1 + k2 * 2.0 + k3 * 2.0 + k4) / 6.0 92 | -------------------------------------------------------------------------------- /easy_tpp/utils/registrable.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from .log_utils import default_logger as logger 4 | 5 | 6 | class Registrable: 7 | """Any class that inherits from ``Registrable`` gains access to a named registry for its subclasses. To register them, just decorate them with the classmethod ``@BaseClass.register(name)``. 8 | 9 | After which you can call ``BaseClass.list_available()`` to get the keys for the registered subclasses, and ``BaseClass.by_name(name)`` to get the corresponding subclass. 10 | 11 | Note that the registry stores the subclasses themselves; not class instances. In most cases you would then call ``from_params(params)`` on the returned subclass. 12 | """ 13 | 14 | _registry = defaultdict(dict) 15 | _default_impl = None 16 | 17 | @classmethod 18 | def register(cls, name, constructor=None, overwrite=False): 19 | """Register a class under a particular name. 20 | Args: 21 | name (str): The name to register the class under. 22 | constructor (str): optional (default=None) 23 | The name of the method to use on the class to construct the object. If this is given, 24 | we will use this method (which must be a ``classmethod``) instead of the default 25 | constructor. 26 | overwrite (bool) : optional (default=False) 27 | If True, overwrites any existing models registered under ``name``. Else, 28 | throws an error if a model is already registered under ``name``. 29 | 30 | # Examples 31 | To use this class, you would typically have a base class that inherits from ``Registrable``: 32 | ```python 33 | class Transform(Registrable): 34 | ... 35 | ``` 36 | Then, if you want to register a subclass, you decorate it like this: 37 | ```python 38 | @Transform.register("shift-transform") 39 | class ShiftTransform(Transform): 40 | def __init__(self, param1: int, param2: str): 41 | ... 42 | ``` 43 | Registering a class like this will let you instantiate a class from a config file, where you 44 | give ``"type": "shift-transform"``, and keys corresponding to the parameters of the ``__init__`` 45 | method (note that for this to work, those parameters must have type annotations). 46 | If you want to have the instantiation from a config file call a method other than the 47 | constructor, either because you have several different construction paths that could be 48 | taken for the same object (as we do in ``Transform``) or because you have logic you want to 49 | happen before you get to the constructor, you can register a specific ``@classmethod`` as the constructor to use. 50 | """ 51 | registry = Registrable._registry[cls] 52 | 53 | def add_subclass_to_registry(subclass): 54 | # Add to registry, raise an error if key has already been used. 
55 | if name in registry: 56 | if overwrite: 57 | message = ( 58 | f"{name} has already been registered as {registry[name][0].__name__}, but " 59 | f"overwrite=True, so overwriting with {cls.__name__}" 60 | ) 61 | logger.info(message) 62 | else: 63 | message = ( 64 | f"Cannot register {name} as {cls.__name__}; " 65 | f"name already in use for {registry[name][0].__name__}" 66 | ) 67 | raise RuntimeError(message) 68 | registry[name] = (subclass, constructor) 69 | return subclass 70 | 71 | return add_subclass_to_registry 72 | 73 | @classmethod 74 | def by_name(cls, name): 75 | """ 76 | Returns a callable function that constructs an argument of the registered class. Because 77 | you can register particular functions as constructors for specific names, this isn't 78 | necessarily the ``__init__`` method of some class. 79 | """ 80 | logger.debug(f"instantiating registered subclass {name} of {cls}") 81 | subclass, constructor = cls.resolve_class_name(name) 82 | if not constructor: 83 | return subclass 84 | else: 85 | return getattr(subclass, constructor) 86 | 87 | @classmethod 88 | def resolve_class_name(cls, name): 89 | """ 90 | Returns the subclass that corresponds to the given ``name``, along with the name of the 91 | method that was registered as a constructor for that ``name``, if any. 92 | This method also allows ``name`` to be a fully-specified module name, instead of a name that 93 | was already added to the ``Registry``. In that case, you cannot use a separate function as 94 | a constructor (as you need to call ``cls.register()`` in order to tell us what separate 95 | function to use). 96 | """ 97 | if name in Registrable._registry[cls]: 98 | subclass, constructor = Registrable._registry[cls].get(name) 99 | return subclass, constructor 100 | else: 101 | for base_cls, v in Registrable._registry.items(): 102 | if name in v: 103 | subclass, constructor = Registrable._registry[base_cls].get(name) 104 | return subclass, constructor 105 | 106 | if "." in name: 107 | # This might be a fully qualified class name, so we'll try importing its "module" 108 | # and finding it there. 109 | parts = name.split(".") 110 | submodule = ".".join(parts[:-1]) 111 | class_name = parts[-1] 112 | import importlib 113 | try: 114 | module = importlib.import_module(submodule) 115 | except ModuleNotFoundError: 116 | raise RuntimeError( 117 | f"tried to interpret {name} as a path to a class " 118 | f"but unable to import module {submodule}" 119 | ) 120 | 121 | try: 122 | subclass = getattr(module, class_name) 123 | constructor = None 124 | return subclass, constructor 125 | except AttributeError: 126 | raise RuntimeError( 127 | f"tried to interpret {name} as a path to a class " 128 | f"but unable to find class {class_name} in {submodule}" 129 | ) 130 | 131 | else: 132 | # is not a qualified class name 133 | raise RuntimeError( 134 | f"{name} is not a registered name for {cls.__name__}. " 135 | "You probably need to use the --include-package flag " 136 | "to load your custom code. Alternatively, you can specify your choices " 137 | """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """ 138 | "in which case they will be automatically imported correctly." 
139 | ) 140 | 141 | @classmethod 142 | def list_available(cls): 143 | """List default first if it exists""" 144 | keys = list(Registrable._registry[cls].keys()) 145 | default = cls._default_impl 146 | 147 | if default is None: 148 | return keys 149 | elif default not in keys: 150 | raise RuntimeError(f"Default implementation {default} is not registered") 151 | else: 152 | return [default] + [k for k in keys if k != default] 153 | 154 | @classmethod 155 | def registry_dict(cls): 156 | return Registrable._registry[cls] 157 | -------------------------------------------------------------------------------- /easy_tpp/utils/tf_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from easy_tpp.utils.import_utils import is_tf_gpu_available 8 | 9 | if tf.__version__ >= '2.0': 10 | tf = tf.compat.v1 11 | tf.disable_v2_behavior() 12 | 13 | 14 | def set_seed(seed=1029): 15 | """Setup random seed. 16 | 17 | Args: 18 | seed (int, optional): random seed. Defaults to 1029. 19 | """ 20 | random.seed(seed) 21 | os.environ["PYTHONHASHSEED"] = str(seed) 22 | np.random.seed(seed) 23 | tf.random.set_random_seed(seed) 24 | 25 | 26 | def set_device(gpu=-1): 27 | """Setup the device. 28 | 29 | Args: 30 | gpu (int, optional): index of the GPU to use. Defaults to -1 (use CPU). 31 | """ 32 | if gpu >= 0 and is_tf_gpu_available(): 33 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 34 | else: 35 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 36 | return 37 | 38 | 39 | def set_optimizer(optimizer, lr): 40 | """Setup the optimizer. 41 | 42 | Args: 43 | optimizer (str): name of the optimizer. 44 | lr (float): learning rate. 45 | 46 | Raises: 47 | NotImplementedError: if the optimizer's name is wrong or the optimizer is not supported, 48 | we raise an error. 49 | 50 | Returns: 51 | tf.train.Optimizer: tf optimizer. 52 | """ 53 | optimizer = optimizer.capitalize() + 'Optimizer' 54 | try: 55 | optimizer = getattr(tf.train, optimizer)(learning_rate=lr) 56 | except Exception: 57 | raise NotImplementedError("optimizer={} is not supported.".format(optimizer)) 58 | 59 | return optimizer 60 | 61 | 62 | def get_shape_list(x): 63 | """Deal with dynamic shape in tensorflow cleanly. 64 | 65 | Args: 66 | x (tensor): input tensor. 67 | 68 | Returns: 69 | list: shape list of the tensor. 70 | """ 71 | static = x.shape.as_list() 72 | dynamic = tf.shape(x) 73 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 74 | 75 | 76 | def tensordot(tensor_a, tensor_b): 77 | """ Tensor dot function. The last dimension of tensor_a and the first dimension of tensor_b must be the same. 78 | 79 | Args: 80 | tensor_a (tensor): input tensor. 81 | tensor_b (tensor): input tensor. 82 | 83 | Returns: 84 | tensor: the result of tensor_a tensor dot tensor_b. 85 | """ 86 | last_idx_a = len(tensor_a.get_shape().as_list()) - 1 87 | return tf.tensordot(tensor_a, tensor_b, [[last_idx_a], [0]]) 88 | 89 | 90 | def swap_axes(tensor, axis1, axis2): 91 | """Interchange two axes of a tensor. 92 | :param tensor: input tensor. 93 | :param axis1: First axis. 94 | :param axis2: Second axis. 95 | :return: the tensor with axis1 and axis2 interchanged. 96 | """ 97 | tensor_perm = list(range(len(tensor.shape.as_list()))) 98 | tensor_perm[axis1] = axis2 99 | tensor_perm[axis2] = axis1 100 | 101 | return tf.transpose(tensor, perm=tensor_perm) 102 | 103 | 104 | def create_tensor(shape, value): 105 | """Creates a tensor with all elements set to be the value.
106 | 107 | Args: 108 | shape (list): the shape of the target tensor to be created. 109 | value (float): value to fill the tensor. 110 | 111 | Returns: 112 | tensor: created tensor with target value filled. 113 | """ 114 | tensor_shape = tf.stack(shape) 115 | return tf.fill(tensor_shape, value) 116 | 117 | 118 | def count_model_params(): 119 | """Count the number of params of the model. 120 | 121 | Counts over the trainable variables in the current default graph, 122 | so no model argument is needed. 123 | 124 | Returns: 125 | int: total num of the parameters. 126 | """ 127 | return np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 128 | -------------------------------------------------------------------------------- /easy_tpp/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from easy_tpp.utils.import_utils import is_torch_mps_available 8 | 9 | 10 | def set_seed(seed=1029): 11 | """Setup random seed. 12 | 13 | Args: 14 | seed (int, optional): random seed. Defaults to 1029. 15 | """ 16 | random.seed(seed) 17 | os.environ["PYTHONHASHSEED"] = str(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | def set_device(gpu=-1): 25 | """Setup the device. 26 | 27 | Args: 28 | gpu (int, optional): index of the GPU to use. Defaults to -1 (do not use GPU, i.e., use CPU). 29 | """ 30 | if gpu >= 0: 31 | if torch.cuda.is_available(): 32 | device = torch.device("cuda:" + str(gpu)) 33 | elif is_torch_mps_available(): 34 | device = torch.device("mps") 35 | else: 36 | device = torch.device("cpu") 37 | return device 38 | 39 | 40 | def set_optimizer(optimizer, params, lr): 41 | """Setup the optimizer. 42 | 43 | Args: 44 | optimizer (str): name of the optimizer. 45 | params (iterable): iterable of parameters to optimize (e.g. model.parameters()). 46 | lr (float): learning rate. 47 | 48 | Raises: 49 | NotImplementedError: if the optimizer's name is wrong or the optimizer is not supported, 50 | we raise an error. 51 | 52 | Returns: 53 | torch.optim: torch optimizer. 54 | """ 55 | if isinstance(optimizer, str): 56 | if optimizer.lower() == "adam": 57 | optimizer = "Adam" 58 | try: 59 | optimizer = getattr(torch.optim, optimizer)(params, lr=lr) 60 | except Exception: 61 | raise NotImplementedError("optimizer={} is not supported.".format(optimizer)) 62 | return optimizer 63 | 64 | 65 | def count_model_params(model): 66 | """Count the number of params of the model. 67 | 68 | Args: 69 | model (torch.nn.Module): a torch model. 70 | 71 | Returns: 72 | int: total num of the parameters.
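    Example (doctest-style; a Linear(10, 5) layer has 10 * 5 weights plus
    5 biases)::

        >>> import torch
        >>> count_model_params(torch.nn.Linear(10, 5))
        55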
73 | """ 74 | return sum(p.numel() for p in model.parameters()) 75 | -------------------------------------------------------------------------------- /examples/configs/hpo_config.yaml: -------------------------------------------------------------------------------- 1 | pipeline_config_id: hpo_runner_config 2 | 3 | data: 4 | taxi: 5 | data_format: pkl 6 | train_dir: ./data/taxi/train.pkl 7 | valid_dir: ./data/taxi/dev.pkl 8 | test_dir: ./data/taxi/test.pkl 9 | data_specs: 10 | num_event_types: 10 11 | pad_token_id: 10 12 | padding_side: right 13 | truncation_side: right 14 | 15 | hpo: 16 | storage_uri: sqlite://hpo_test.db 17 | is_continuous: False 18 | framework_id: optuna # the framework of hpo 19 | n_trials: 10 20 | 21 | 22 | NHP_train: 23 | base_config: 24 | stage: train 25 | backend: torch 26 | dataset_id: taxi 27 | runner_id: std_tpp 28 | model_id: NHP # model name 29 | base_dir: './checkpoints/' 30 | trainer_config: 31 | batch_size: 256 32 | max_epoch: 200 33 | shuffle: False 34 | optimizer: adam 35 | learning_rate: 1.e-3 36 | valid_freq: 1 37 | use_tfb: False 38 | metrics: [ 'acc', 'rmse' ] 39 | seed: 2019 40 | gpu: -1 41 | model_config: 42 | hidden_size: 64 43 | loss_integral_num_sample_per_step: 20 44 | # pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model 45 | thinning: 46 | num_seq: 10 47 | num_sample: 1 48 | num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm 49 | look_ahead_time: 10 50 | patience_counter: 5 # the maximum iteration used in adaptive thinning 51 | over_sample_rate: 5 52 | num_samples_boundary: 5 53 | dtime_max: 5 54 | num_step_gen: 1 55 | 56 | -------------------------------------------------------------------------------- /examples/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/EasyTemporalPointProcess/7e2b7a001a293c506bd595e8ddb72d83967c2cb2/examples/data/.gitkeep -------------------------------------------------------------------------------- /examples/data_inspection/config.yaml: -------------------------------------------------------------------------------- 1 | pipeline_config_id: data_config 2 | 3 | data_format: json 4 | train_dir: easytpp/taxi # ./data/taxi/train.json 5 | valid_dir: easytpp/taxi # ./data/taxi/dev.json 6 | test_dir: easytpp/taxi # ./data/taxi/test.json 7 | data_specs: 8 | num_event_types: 10 9 | pad_token_id: 10 10 | padding_side: right -------------------------------------------------------------------------------- /examples/data_inspection/data_inspection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # Get the directory of the current file 4 | current_file_path = os.path.abspath(__file__) 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))) 6 | 7 | from easy_tpp.config_factory import Config 8 | from easy_tpp.preprocess.data_loader import TPPDataLoader 9 | 10 | 11 | def main(): 12 | config = Config.build_from_yaml_file('./config.yaml') 13 | tpp_loader = TPPDataLoader(config) 14 | stats = tpp_loader.get_statistics(split='train') 15 | print(stats) 16 | tpp_loader.plot_event_type_distribution() 17 | tpp_loader.plot_event_delta_times_distribution() 18 | 19 | if __name__ == '__main__': 20 | main() -------------------------------------------------------------------------------- /examples/data_loader.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | from easy_tpp.config_factory import DataSpecConfig 4 | from easy_tpp.preprocess import EventTokenizer 5 | from easy_tpp.preprocess.dataset import TPPDataset, get_data_loader 6 | 7 | 8 | def make_raw_data(): 9 | data = [ 10 | [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 0}], 11 | [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], 12 | [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], 13 | ] 14 | for i, j in enumerate([2, 5, 3]): 15 | start_time = 0 16 | for k in range(j): 17 | delta_t = random.random() 18 | start_time += delta_t 19 | data[i].append({"time_since_last_event": delta_t, 20 | "time_since_start": start_time, 21 | "type_event": random.randint(0, 10) 22 | }) 23 | 24 | return data 25 | 26 | 27 | def main(): 28 | source_data = make_raw_data() 29 | 30 | time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data] 31 | type_seqs = [[x["type_event"] for x in seq] for seq in source_data] 32 | time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data] 33 | 34 | input_data = {'time_seqs': time_seqs, 35 | 'type_seqs': type_seqs, 36 | 'time_delta_seqs': time_delta_seqs} 37 | 38 | config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 11, 'batch_size': 1, 39 | 'pad_token_id': 11}) 40 | 41 | dataset = TPPDataset(input_data) 42 | 43 | tokenizer = EventTokenizer(config) 44 | 45 | loader = get_data_loader(dataset, 'torch', tokenizer) 46 | 47 | for batch in loader: 48 | print(batch) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /examples/event_tokenizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from easy_tpp.preprocess.event_tokenizer import EventTokenizer 4 | from easy_tpp.config_factory import DataSpecConfig 5 | 6 | def make_raw_data(): 7 | data = [ 8 | [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 0}], 9 | [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], 10 | [{"time_since_last_event": 0, "time_since_start": 0, "type_event": 1}], 11 | ] 12 | for i, j in enumerate([2, 5, 3]): 13 | start_time = 0 14 | for k in range(j): 15 | delta_t = random.random() 16 | start_time += delta_t 17 | data[i].append({"time_since_last_event": delta_t, 18 | "time_since_start": start_time, 19 | "type_event": random.randint(0, 10) 20 | }) 21 | 22 | return data 23 | 24 | 25 | def main(): 26 | source_data = make_raw_data() 27 | 28 | time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data] 29 | type_seqs = [[x["type_event"] for x in seq] for seq in source_data] 30 | time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data] 31 | 32 | input_data = {'time_seqs': time_seqs, 33 | 'type_seqs': type_seqs, 34 | 'time_delta_seqs': time_delta_seqs} 35 | 36 | config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 11, 'pad_token_id': 11}) 37 | 38 | tokenizer = EventTokenizer(config) 39 | 40 | output = tokenizer.pad(input_data, return_tensors='tf') 41 | 42 | print(output) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /examples/gen_synthetic_data.py: -------------------------------------------------------------------------------- 1 | from easy_tpp.utils.gen_utils 
import generate_and_save_json 2 | 3 | if __name__ == "__main__": 4 | generate_and_save_json(n_nodes=3, 5 | end_time=100, 6 | baseline=1, 7 | adjacency=0.5, 8 | decay=0.1, 9 | max_seq_len=40, 10 | target_file='synthetic_data.json') 11 | -------------------------------------------------------------------------------- /examples/hf_data_loader.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | def load_data_from_hf(hf_dir=None, local_dir=None): 4 | if hf_dir: 5 | ds = load_dataset(hf_dir) 6 | else: 7 | ds = load_dataset('json', data_files=local_dir) 8 | print(ds) 9 | print('dim process: ' + str(ds['validation'].data['dim_process'][0].as_py())) 10 | print('num seqs: ' + str(ds['validation'].data['num_seqs'][0].as_py())) 11 | print('avg seq len: ' + str(ds['validation'].data['avg_seq_len'][0].as_py())) 12 | print('min seq len: ' + str(ds['validation'].data['min_seq_len'][0].as_py())) 13 | print('max seq len: ' + str(ds['validation'].data['max_seq_len'][0].as_py())) 14 | return 15 | 16 | 17 | if __name__ == '__main__': 18 | # in case one fails to load from hf directly 19 | # one can load the json data file locally 20 | # load_data_from_hf(hf_dir=None, local_dir={'validation':'dev.json'}) 21 | load_data_from_hf(hf_dir='easytpp/taxi') -------------------------------------------------------------------------------- /examples/script_data_processing/earthquake.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | warnings.filterwarnings('ignore') 9 | 10 | 11 | # data source: https://earthquake.usgs.gov/earthquakes/search/ 12 | 13 | def event_type_map(mag): 14 | if mag < 2.75: 15 | return 0 16 | elif mag < 3.0: 17 | return 1 18 | elif mag < 3.5: 19 | return 2 20 | elif mag < 4.0: 21 | return 3 22 | elif mag < 4.5: 23 | return 4 24 | elif mag < 5.0: 25 | return 5 26 | else: 27 | return 6 28 | 29 | 30 | def clean_csv(source_dir): 31 | df = pd.read_csv(source_dir, header=0) 32 | 33 | df.drop_duplicates(inplace=True) 34 | 35 | df.sort_values(by=['time'], inplace=True) 36 | print(len(df)) 37 | df = df[['time', 'mag']] 38 | df['event_type'] = df['mag'].apply(lambda x: event_type_map(x)) 39 | 40 | df.to_csv('earthquake.csv', index=False, header=True) 41 | return 42 | 43 | 44 | def make_seq(df): 45 | seq = [] 46 | df['time_diff'] = df['event_time'].diff() 47 | df.index = np.arange(len(df)) 48 | for index, row in df.iterrows(): 49 | if index == 0: 50 | event_dict = {"time_since_last_event": 0.0, 51 | "time_since_start": 0.0, 52 | "type_event": row['event_type'] 53 | } 54 | start_event_time = row['event_time'] 55 | else: 56 | event_dict = {"time_since_last_event": row['time_diff'], 57 | "time_since_start": row['event_time'] - start_event_time, 58 | "type_event": row['event_type'] 59 | } 60 | seq.append(event_dict) 61 | 62 | return seq 63 | 64 | 65 | def make_pkl(target_dir, dim_process, split, seqs): 66 | with open(target_dir, "wb") as f_out: 67 | pickle.dump( 68 | { 69 | "dim_process": dim_process, 70 | split: seqs 71 | }, f_out 72 | ) 73 | return 74 | 75 | 76 | def make_dataset(source_dir): 77 | df = pd.read_csv(source_dir, header=0) 78 | df['time'] = pd.to_datetime(df['time']) 79 | 80 | norm_const = 10000 81 | df['event_time'] = df['time'].apply(lambda x: datetime.timestamp(x)) / norm_const 82 | seq_len = np.random.randint(15, 19, 4300) 83 | print(np.sum(seq_len)) 84 | 85 
| seq_start_idx = [0] + list(np.cumsum(seq_len)[:-1] - 1) 86 | seq_end_idx = np.cumsum(seq_len) - 1 87 | 88 | total_seq = [make_seq(df.iloc[start_idx:end_idx, :]) for (start_idx, end_idx) in 89 | zip(seq_start_idx, seq_end_idx)] 90 | 91 | print(len(total_seq)) 92 | make_pkl('train.pkl', 7, 'train', total_seq[:3000]) 93 | print(np.sum(seq_len[:3000])) 94 | make_pkl('dev.pkl', 7, 'dev', total_seq[3000:3400]) 95 | print(np.sum(seq_len[3000:3400])) 96 | make_pkl('test.pkl', 7, 'test', total_seq[3400:]) 97 | print(np.sum(seq_len[3400:])) 98 | 99 | # 70794 100 | # 4300 101 | # 49364 102 | # 6612 103 | # 14818 104 | 105 | return 106 | 107 | 108 | if __name__ == '__main__': 109 | # clean_csv() 110 | make_dataset('earthquake.csv') 111 | -------------------------------------------------------------------------------- /examples/script_data_processing/make_hf_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import accumulate 3 | 4 | import numpy as np 5 | 6 | from easy_tpp.utils import load_pickle 7 | 8 | 9 | def make_json_serializable(input_dict): 10 | for k, v in input_dict.items(): 11 | if isinstance(v, np.float32): 12 | input_dict[k] = float(v) 13 | elif isinstance(v, np.int32): 14 | input_dict[k] = int(v) 15 | 16 | return input_dict 17 | 18 | 19 | def make_hf_dataset(source_dir, target_dir, split='test'): 20 | data_pkl = load_pickle(source_dir) 21 | 22 | dim_process = int(data_pkl['dim_process']) 23 | 24 | data_json = [] 25 | for idx, seq in enumerate(data_pkl[split]): 26 | seq_len = len(seq) 27 | time_since_start, time_since_last_event, type_event = [], [], [] 28 | for idx_event, event in enumerate(data_pkl[split][idx]): 29 | # if idx_event == 0 and event['time_since_start'] > 0: 30 | # start_timestamp = event['time_since_start'] 31 | # else: 32 | # start_timestamp = 0 33 | if idx_event == 0 and event['time_since_last_event'] > 0: 34 | event['time_since_last_event'] = 0 35 | 36 | # event['time_since_start'] -= start_timestamp 37 | 38 | event = make_json_serializable(event) 39 | time_since_start.append(event['time_since_start']) 40 | time_since_last_event.append(event['time_since_last_event']) 41 | type_event.append(event['type_event']) 42 | 43 | # re-calculate time_since_start by accumulating the inter-event times 44 | time_since_start = list(accumulate(time_since_last_event)) 45 | 46 | temp_dict = {'dim_process': dim_process, 47 | 'seq_idx': idx, 48 | 'seq_len': seq_len, 49 | 'time_since_start': time_since_start, 50 | 'time_since_last_event': time_since_last_event, 51 | 'type_event': type_event} 52 | data_json.append(temp_dict) 53 | 54 | with open(target_dir, "w") as outfile: 55 | json.dump(data_json, outfile) 56 | 57 | return 58 | 59 | 60 | if __name__ == '__main__': 61 | test_data_dir = ['amazon/test.pkl', 'amazon/test.json'] 62 | dev_data_dir = ['amazon/dev.pkl', 'amazon/dev.json'] 63 | train_data_dir = ['amazon/train.pkl', 'amazon/train.json'] 64 | make_hf_dataset(source_dir=test_data_dir[0], target_dir=test_data_dir[1]) 65 | make_hf_dataset(source_dir=dev_data_dir[0], target_dir=dev_data_dir[1], split='dev') 66 | make_hf_dataset(source_dir=train_data_dir[0], target_dir=train_data_dir[1], split='train') 67 | -------------------------------------------------------------------------------- /examples/script_data_processing/taobao.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | warnings.filterwarnings('ignore') 8 | 9 | 10 | # 
source data: https://tianchi.aliyun.com/dataset/dataDetail?dataId=649 11 | 12 | def check_dominate_event_type(event_type_seq, threshold=0.7): 13 | event_type = np.unique(event_type_seq) 14 | total_len = len(event_type_seq) 15 | type_ratio = [len(event_type_seq[event_type_seq == event_type_i]) / total_len for event_type_i in event_type] 16 | 17 | return max(type_ratio) > threshold 18 | 19 | 20 | def cate_map(cate_id, cate_event_map_df): 21 | res = cate_event_map_df[cate_event_map_df['cate'] == cate_id]['event_id'].to_list()[0] 22 | return res 23 | 24 | 25 | def read_data_step_3(source_dir, cate_dir, target_dir): 26 | train_df = pd.read_csv(source_dir, header=0) 27 | 28 | cate_event_map_df = pd.read_csv(cate_dir, header=0) 29 | 30 | train_df['event_type'] = train_df['cate_id'].apply(lambda x: cate_map(x, cate_event_map_df)) 31 | print(train_df['event_type'].value_counts(normalize=True)) 32 | unique_user_id = np.unique(train_df['user_id']) 33 | 34 | for idx, user_id in enumerate(unique_user_id): 35 | user_df = train_df[train_df['user_id'] == user_id] 36 | prev_time = user_df.iloc[0, 4] # column 4 is event_time 37 | event_dtime = user_df['event_dtime'].values 38 | event_time = user_df['event_time'].values 39 | event_dtime[0] = 0.0 40 | 41 | for i in range(1, len(event_time)): 42 | if event_dtime[i] > 50.0: # too large interval 43 | rand_dt = np.random.random() + 0.1 44 | event_time[i] = prev_time + rand_dt 45 | event_dtime[i] = rand_dt 46 | else: 47 | event_time[i] = event_time[i - 1] + event_dtime[i] 48 | prev_time = event_time[i] 49 | 50 | user_df['event_dtime'] = event_dtime 51 | user_df['event_time'] = event_time 52 | 53 | print(min(event_dtime[1:]), max(event_dtime)) 54 | 55 | assert abs(np.mean(user_df['event_time'].diff().values[1:]) - np.mean(event_dtime[1:])) < 0.0001 56 | 57 | train_df.to_csv(target_dir) 58 | return 59 | 60 | 61 | def read_data_step_2(source_dir): 62 | train_df = pd.read_csv(source_dir, header=None) 63 | train_df.columns = ['user_id', 'item_id', 'cate_id', 'event_type_raw', 'event_time'] 64 | count = train_df['cate_id'].value_counts(normalize=True) 65 | pd.DataFrame(count).to_csv('taobao_map.csv', header=True) 66 | 67 | return 68 | 69 | 70 | def read_data_step_1(source_dir, target_dir): 71 | train_df = pd.read_csv(source_dir, header=None) 72 | train_df.columns = ['user_id', 'item_id', 'cate_id', 'event_type_raw', 'event_time'] 73 | train_df['event_time'] /= 10000 74 | unique_user_id = np.unique(train_df['user_id']) 75 | 76 | train_df = train_df[train_df['event_type_raw'] == 'pv'] 77 | 78 | res = pd.DataFrame() 79 | total_seq = 0 80 | 81 | for idx, user_id in enumerate(unique_user_id): 82 | print(f'user {idx}') 83 | user_df = train_df[train_df['user_id'] == user_id] 84 | 85 | # drop consecutive duplicates on pv 86 | user_df = user_df.loc[user_df['cate_id'].shift() != user_df['cate_id']] 87 | user_df.fillna(0.0, inplace=True) 88 | 89 | user_df.sort_values(by=['event_time'], inplace=True) 90 | user_df['event_dtime'] = user_df['event_time'].diff() 91 | 92 | user_df.fillna(0.0, inplace=True) 93 | 94 | # keep only events with dtime > 0.1 95 | user_df = user_df[user_df['event_dtime'] > 0.1] 96 | 97 | if len(user_df) < 40: 98 | print('user seq is too short, skip it') 99 | continue 100 | 101 | total_seq += 1 102 | print(f'{total_seq} users have been recorded') 103 | res = pd.concat([res, user_df]) 104 | if total_seq > 2000: 105 | break 106 | 107 | res.to_csv(target_dir, header=True, index=False) 108 | 109 | return 110 | 111 | 112 | def save_data(source_dir): 113 | df = pd.read_csv(source_dir, 
header=0) 114 | unique_user_id = np.unique(df['user_id']) 115 | res = [] 116 | print(np.unique(df['event_type'])) 117 | for idx, user_id in enumerate(unique_user_id): 118 | print(f'user {idx}') 119 | user_seq = [] 120 | user_df = df[df['user_id'] == user_id] 121 | length = 0 122 | for idx_row, row in user_df.iterrows(): 123 | event_dtime = 0 if length == 0 else row['event_dtime'] 124 | user_seq.append({"time_since_last_event": event_dtime, 125 | "time_since_start": row['event_time'], 126 | "type_event": row['event_type'] 127 | }) 128 | length += 1 129 | 130 | res.append(user_seq) 131 | 132 | with open('../data/taobao/train.pkl', "wb") as f_out: 133 | pickle.dump( 134 | { 135 | "dim_process": 17, 136 | 'train': res[:1300] 137 | }, f_out 138 | ) 139 | 140 | with open('../data/taobao/dev.pkl', "wb") as f_out: 141 | pickle.dump( 142 | { 143 | "dim_process": 17, 144 | 'dev': res[1300:1500] 145 | }, f_out 146 | ) 147 | 148 | with open('../data/taobao/test.pkl', "wb") as f_out: 149 | pickle.dump( 150 | { 151 | "dim_process": 17, 152 | 'test': res[1500:] 153 | }, f_out 154 | ) 155 | 156 | return 157 | -------------------------------------------------------------------------------- /examples/script_data_processing/taxi.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | 4 | warnings.filterwarnings('ignore') 5 | 6 | def read_data_step_1(): 7 | 8 | def read_pkl(file_dir): 9 | res = [] 10 | taxi = pickle.load(open(file_dir, "rb")) 11 | count = 0 12 | for seq in taxi['seqs']: 13 | if len(seq) > 34: 14 | count += 1 15 | res.append(seq) 16 | # print(np.max(seq['time_since_last_event'])) 17 | print(count) 18 | return res 19 | 20 | # from Mei et al.'s paper on event imputation 21 | train_res = read_pkl('pilottaxi/big/train.pkl') 22 | dev_res = read_pkl('pilottaxi/big/dev.pkl') 23 | test_res = read_pkl('pilottaxi/big/test1.pkl') 24 | 25 | with open('../data/taxi/train.pkl', "wb") as f_out: 26 | pickle.dump( 27 | { 28 | "dim_process": 10, 29 | 'train': train_res[:1500] 30 | }, f_out 31 | ) 32 | 33 | with open('../data/taxi/dev.pkl', "wb") as f_out: 34 | pickle.dump( 35 | { 36 | "dim_process": 10, 37 | 'dev': dev_res[:200] 38 | }, f_out 39 | ) 40 | 41 | with open('../data/taxi/test.pkl', "wb") as f_out: 42 | pickle.dump( 43 | { 44 | "dim_process": 10, 45 | 'test': test_res[:400] 46 | }, f_out 47 | ) 48 | 49 | return 50 | 51 | if __name__ == '__main__': 52 | read_data_step_1() -------------------------------------------------------------------------------- /examples/script_data_processing/volcano.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pickle 3 | import warnings 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | warnings.filterwarnings('ignore') 9 | 10 | 11 | def make_datetime(year, month, day): 12 | try: 13 | date = datetime.datetime(int(year), int(month), int(day)) 14 | except ValueError as e: 15 | if e.args[0] != 'day is out of range for month': raise # re-raise unexpected errors so `date` is never unbound 16 | date = datetime.datetime(int(year), int(month), int(day) - 1) # clamp an out-of-range day 17 | return datetime.datetime.timestamp(date) + 61851630000 # make sure the timestamp is positive 18 | 19 | 20 | def clean_csv(): 21 | source_dir = 'events.csv' 22 | 23 | df = pd.read_csv(source_dir, header=0) 24 | 25 | df = df[~df['event_date_year'].isna()] 26 | df = df[df['event_date_year'] > 0] 27 | df['event_date_month'].fillna(1, inplace=True) 28 | df['event_date_day'].fillna(1, inplace=True) 29 | df.drop_duplicates(inplace=True) 30 | 
norm_const = 1000000 31 | df['event_timestamp'] = df.apply( 32 | lambda x: make_datetime(x['event_date_year'], x['event_date_month'], x['event_date_day']), 33 | axis=1)/norm_const 34 | df.sort_values(by=['event_date_year', 'event_date_month', 'event_date_day'], inplace=True) 35 | df['event_type'] = [0] * len(df) 36 | 37 | df.to_csv('volcano.csv', index=False, header=True) 38 | return 39 | 40 | 41 | def make_seq(df): 42 | seq = [] 43 | df['time_diff'] = df['event_timestamp'].diff() 44 | df.index = np.arange(len(df)) 45 | for index, row in df.iterrows(): 46 | if index == 0: 47 | event_dict = {"time_since_last_event": 0.0, 48 | "time_since_start": 0.0, 49 | "type_event": row['event_type'] 50 | } 51 | start_event_time = row['event_timestamp'] 52 | else: 53 | event_dict = {"time_since_last_event": row['time_diff'], 54 | "time_since_start": row['event_timestamp'] - start_event_time, 55 | "type_event": row['event_type'] 56 | } 57 | seq.append(event_dict) 58 | 59 | return seq 60 | 61 | 62 | def make_pkl(target_dir, dim_process, split, seqs): 63 | with open(target_dir, "wb") as f_out: 64 | pickle.dump( 65 | { 66 | "dim_process": dim_process, 67 | split: seqs 68 | }, f_out 69 | ) 70 | return 71 | 72 | 73 | def make_dataset(source_dir): 74 | df = pd.read_csv(source_dir, header=0) 75 | 76 | vols = np.unique(df['volcano_name']) 77 | total_seq = [] 78 | for vol in vols: 79 | df_ = df[df['volcano_name'] == vol] 80 | df_.sort_values('event_timestamp', inplace=True) 81 | total_seq.append(make_seq(df_)) 82 | 83 | 84 | print(len(total_seq)) 85 | make_pkl('train.pkl', 1, 'train', total_seq[:400]) 86 | count_seq(total_seq[:400]) 87 | make_pkl('dev.pkl', 1, 'dev', total_seq[400:450]) 88 | count_seq(total_seq[400:450]) 89 | make_pkl('test.pkl', 1, 'test', total_seq[450:]) 90 | count_seq(total_seq[450:]) 91 | 92 | 93 | return 94 | 95 | 96 | def count_seq(seqs): 97 | total_len = [len(seq) for seq in seqs] 98 | print(np.mean(total_len)) 99 | print(np.sum(total_len)) 100 | 101 | return 102 | 103 | if __name__ == '__main__': 104 | # clean_csv() 105 | make_dataset('volcano.csv') 106 | -------------------------------------------------------------------------------- /examples/train_experiment/retweet_config.yaml: -------------------------------------------------------------------------------- 1 | pipeline_config_id: runner_config 2 | 3 | data: 4 | retweet: 5 | data_format: json 6 | train_dir: easytpp/retweet 7 | valid_dir: easytpp/retweet 8 | test_dir: easytpp/retweet 9 | data_specs: 10 | num_event_types: 3 11 | pad_token_id: 3 12 | padding_side: right 13 | truncation_side: right 14 | 15 | NHP_train: 16 | base_config: 17 | stage: train 18 | backend: torch 19 | dataset_id: retweet 20 | runner_id: std_tpp 21 | model_id: NHP # model name 22 | base_dir: './checkpoints/' 23 | trainer_config: 24 | batch_size: 256 25 | max_epoch: 20 26 | shuffle: False 27 | optimizer: adam 28 | learning_rate: 1.e-3 29 | valid_freq: 1 30 | use_tfb: False 31 | metrics: [ 'acc', 'rmse' ] 32 | seed: 2019 33 | gpu: -1 34 | model_config: 35 | hidden_size: 64 36 | loss_integral_num_sample_per_step: 20 37 | thinning: 38 | num_seq: 10 39 | num_sample: 1 40 | num_exp: 500 # number of i.i.d. 
Exp(intensity_bound) draws at one time in the thinning algorithm 41 | look_ahead_time: 10 42 | patience_counter: 5 # maximum number of iterations used in adaptive thinning 43 | over_sample_rate: 5 44 | num_samples_boundary: 5 45 | dtime_max: 5 46 | num_step_gen: 1 47 | 48 | 49 | 50 | SAHP_train: 51 | base_config: 52 | stage: train 53 | backend: torch 54 | dataset_id: retweet # must match a dataset key defined under data: 55 | runner_id: std_tpp 56 | model_id: SAHP # model name 57 | base_dir: './checkpoints/' 58 | trainer_config: 59 | batch_size: 256 60 | max_epoch: 20 61 | shuffle: False 62 | optimizer: adam 63 | learning_rate: 1.e-3 64 | valid_freq: 1 65 | use_tfb: False 66 | metrics: [ 'acc', 'rmse' ] 67 | seed: 2019 68 | gpu: 0 69 | model_config: 70 | hidden_size: 32 71 | time_emb_size: 16 72 | num_layers: 2 73 | num_heads: 2 74 | loss_integral_num_sample_per_step: 20 75 | use_ln: False 76 | thinning: 77 | num_seq: 10 78 | num_sample: 1 79 | num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in the thinning algorithm 80 | look_ahead_time: 10 81 | patience_counter: 5 # maximum number of iterations used in adaptive thinning 82 | over_sample_rate: 5 83 | num_samples_boundary: 5 84 | dtime_max: 5 85 | num_step_gen: 1 86 | 87 | 88 | 89 | SAHP_gen: 90 | base_config: 91 | stage: gen 92 | backend: torch 93 | dataset_id: retweet 94 | runner_id: std_tpp 95 | model_id: SAHP # model name 96 | base_dir: './checkpoints/' 97 | trainer_config: 98 | batch_size: 256 99 | max_epoch: 1 100 | model_config: 101 | hidden_size: 16 102 | time_emb_size: 4 103 | num_layers: 2 104 | num_heads: 2 105 | loss_integral_num_sample_per_step: 20 106 | use_ln: False 107 | thinning: 108 | num_seq: 10 109 | num_sample: 1 110 | num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in the thinning algorithm 111 | look_ahead_time: 10 112 | patience_counter: 5 # maximum number of iterations used in adaptive thinning 113 | over_sample_rate: 5 114 | num_samples_boundary: 5 115 | dtime_max: 5 116 | num_step_gen: 10 117 | 118 | THP_train: 119 | base_config: 120 | stage: train 121 | backend: torch 122 | dataset_id: retweet 123 | runner_id: std_tpp 124 | model_id: THP # model name 125 | base_dir: './checkpoints/' 126 | trainer_config: 127 | batch_size: 256 128 | max_epoch: 30 129 | shuffle: False 130 | optimizer: adam 131 | learning_rate: 1.e-3 132 | valid_freq: 1 133 | use_tfb: False 134 | metrics: [ 'acc', 'rmse' ] 135 | seed: 2019 136 | gpu: -1 137 | model_config: 138 | hidden_size: 32 139 | time_emb_size: 16 140 | num_layers: 2 141 | num_heads: 2 142 | mc_num_sample_per_step: 20 143 | loss_integral_num_sample_per_step: 20 144 | use_ln: False 145 | thinning: 146 | num_seq: 10 147 | num_sample: 1 148 | num_exp: 500 # number of i.i.d. 
Exp(intensity_bound) draws at one time in the thinning algorithm 149 | look_ahead_time: 10 150 | patience_counter: 5 # maximum number of iterations used in adaptive thinning 151 | over_sample_rate: 5 152 | num_samples_boundary: 5 153 | dtime_max: 5 154 | num_step_gen: 1 155 | 156 | 157 | THP_gen: 158 | base_config: 159 | stage: gen 160 | backend: torch 161 | dataset_id: retweet 162 | runner_id: std_tpp 163 | model_id: THP # model name 164 | base_dir: './checkpoints/' 165 | trainer_config: 166 | batch_size: 256 167 | max_epoch: 1 168 | model_config: 169 | hidden_size: 32 170 | time_emb_size: 16 171 | num_layers: 2 172 | num_heads: 2 173 | mc_num_sample_per_step: 20 174 | loss_integral_num_sample_per_step: 20 175 | use_ln: False 176 | # pretrained_model_dir: ./checkpoints/2694_4384867712_230603-160544/models/saved_model 177 | thinning: 178 | num_seq: 10 179 | num_sample: 1 180 | num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in the thinning algorithm 181 | look_ahead_time: 10 182 | patience_counter: 5 # maximum number of iterations used in adaptive thinning 183 | over_sample_rate: 5 184 | num_samples_boundary: 5 185 | dtime_max: 5 186 | num_step_gen: 10 187 | 188 | AttNHP_train: 189 | base_config: 190 | stage: train 191 | backend: torch 192 | dataset_id: retweet 193 | runner_id: std_tpp 194 | model_id: AttNHP # model name 195 | base_dir: './checkpoints/' 196 | trainer_config: 197 | batch_size: 256 198 | max_epoch: 200 199 | shuffle: False 200 | optimizer: adam 201 | learning_rate: 1.e-3 202 | valid_freq: 1 203 | use_tfb: False 204 | metrics: [ 'acc', 'rmse' ] 205 | seed: 2019 206 | gpu: -1 207 | model_config: 208 | hidden_size: 16 209 | time_emb_size: 4 210 | num_layers: 2 211 | num_heads: 2 212 | loss_integral_num_sample_per_step: 10 213 | use_ln: False 214 | thinning: 215 | num_seq: 2 216 | num_sample: 1 217 | num_exp: 50 # number of i.i.d. Exp(intensity_bound) draws at one time in the thinning algorithm 218 | look_ahead_time: 10 219 | patience_counter: 5 # maximum number of iterations used in adaptive thinning 220 | over_sample_rate: 5 221 | num_samples_boundary: 5 222 | dtime_max: 5 223 | num_step_gen: 1 224 | 225 | 226 | AttNHP_gen: 227 | base_config: 228 | stage: gen 229 | backend: torch 230 | dataset_id: retweet 231 | runner_id: std_tpp 232 | model_id: AttNHP # model name 233 | base_dir: './checkpoints/' 234 | trainer_config: 235 | batch_size: 256 236 | max_epoch: 1 237 | model_config: 238 | hidden_size: 16 239 | time_emb_size: 4 240 | num_layers: 2 241 | num_heads: 2 242 | mc_num_sample_per_step: 20 243 | loss_integral_num_sample_per_step: 20 244 | use_ln: False 245 | # pretrained_model_dir: ./checkpoints/6934_4375315840_230603-222826/models/saved_model 246 | thinning: 247 | num_seq: 10 248 | num_sample: 1 249 | num_exp: 50 # number of i.i.d. 
Exp(intensity_bound) draws at one time in the thinning algorithm 250 | look_ahead_time: 10 251 | patience_counter: 5 # maximum number of iterations used in adaptive thinning 252 | over_sample_rate: 5 253 | num_samples_boundary: 5 254 | dtime_max: 5 255 | num_step_gen: 10 -------------------------------------------------------------------------------- /examples/train_experiment/run_retweet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from easy_tpp.config_factory import Config 4 | from easy_tpp.runner import Runner 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--config_dir', type=str, required=False, default='retweet_config.yaml', 11 | help='Path of the configuration yaml used to train and evaluate the model.') 12 | 13 | parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train', 14 | help='Experiment id in the config file.') 15 | 16 | args = parser.parse_args() 17 | 18 | config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) 19 | 20 | model_runner = Runner.build_from_config(config) 21 | 22 | model_runner.run() 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /examples/train_nhp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from easy_tpp.config_factory import Config 4 | from easy_tpp.runner import Runner 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--config_dir', type=str, required=False, default='configs/experiment_config.yaml', 11 | help='Path of the configuration yaml used to train and evaluate the model.') 12 | 13 | parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train', 14 | help='Experiment id in the config file.') 15 | 16 | args = parser.parse_args() 17 | 18 | config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) 19 | 20 | model_runner = Runner.build_from_config(config) 21 | 22 | model_runner.run() 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /examples/train_nhp_hpo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from easy_tpp.config_factory import Config 4 | from easy_tpp.hpo import HyperTuner 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--config_dir', type=str, required=False, default='configs/hpo_config.yaml', 11 | help='Path of the configuration yaml used to train and evaluate the model.') 12 | 13 | parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train', 14 | help='Experiment id in the config file.') 15 | 16 | args = parser.parse_args() 17 | 18 | config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) 19 | 20 | tuner = HyperTuner.build_from_config(config) 21 | 22 | tuner.run() 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /examples/train_nhp_omegaconf.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | from easy_tpp.config_factory import ModelConfig 4 | from easy_tpp.model.torch_model.torch_nhp import NHP 5 | 6 | 7 | def main(): 8 | config_omegaconf = 
OmegaConf.load('configs/experiment_config.yaml') 9 | 10 | model_config_dict = config_omegaconf.get('NHP_train').get('model_config') 11 | model_config_dict['num_event_types'] = 10 12 | model_config_dict['num_event_types_pad'] = 11 13 | model_config_dict['event_pad_index'] = 10 14 | 15 | model_config = ModelConfig.parse_from_yaml_config(model_config_dict) 16 | 17 | nhp_model = NHP(model_config) 18 | 19 | print(nhp_model.__dict__) 20 | 21 | # config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id) 22 | # 23 | # model_runner = Runner.build_from_config(config) 24 | # 25 | # model_runner.run() 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /requirements-doc.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme 3 | myst_parser 4 | nbsphinx 5 | nbsphinx_link 6 | sphinx_fontawesome 7 | sphinx-autobuild 8 | recommonmark 9 | sphinx_markdown_tables -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | PyYAML>=5.1 2 | numpy 3 | pandas 4 | torch 5 | tensorboard 6 | packaging 7 | datasets 8 | omegaconf 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description_file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import re 4 | from setuptools import find_packages 5 | from setuptools import setup 6 | 7 | 8 | def readme(): 9 | with codecs.open('README.md', encoding='utf-8') as f: 10 | content = f.read() 11 | return content 12 | 13 | 14 | def get_version(): 15 | version_file = os.path.join(os.path.dirname(__file__), "version.py") 16 | version_regex = r"__version__ = ['\"]([^'\"]*)['\"]" 17 | with open(version_file, "r") as f: 18 | version = re.search(version_regex, f.read(), re.M).group(1) 19 | return version 20 | 21 | 22 | def parse_requirements(fname='requirements.txt'): 23 | """Parse the package dependencies listed in a requirements file.""" 24 | 25 | def parse_line(line): 26 | """Parse information from a line in a requirements text file.""" 27 | if line.startswith('-r '): 28 | # Allow specifying requirements in other files 29 | target = line.split(' ')[1] 30 | for line in parse_require_file(target): 31 | yield line 32 | else: 33 | yield line 34 | 35 | def parse_require_file(fpath): 36 | with codecs.open(fpath, 'r') as f: 37 | for line in f.readlines(): 38 | line = line.strip() 39 | if line and not line.startswith('#'): 40 | for ll in parse_line(line): 41 | yield ll 42 | 43 | packages = list(parse_require_file(fname)) 44 | return packages 45 | 46 | 47 | setup( 48 | name='easy_tpp', 49 | version=get_version(), 50 | description='An easy and flexible tool for neural temporal point processes', 51 | url='https://github.com/ant-research/EasyTemporalPointProcess/', 52 | # long_description = 'Our EasyTPP makes several unique contributions to this area: a unified interface of using existing datasets and adding new datasets; a wide range of evaluation programs that are easy to use and extend as well as facilitate reproducible research; implementations of popular neural TPPs, together 
with a rich library of modules from which one can quickly compose complex models. ', 53 | # long_description=open('README.md').read(), 54 | # long_description_content_type='text/markdown', 55 | author='Alipay', 56 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 57 | include_package_data=True, 58 | classifiers=[ 59 | 'Programming Language :: Python :: 3' 60 | ], 61 | install_requires=parse_requirements('requirements.txt')) 62 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/EasyTemporalPointProcess/7e2b7a001a293c506bd595e8ddb72d83967c2cb2/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_data_loader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from easy_tpp.config_factory import DataSpecConfig 4 | from easy_tpp.utils import load_json 5 | from easy_tpp.preprocess.dataset import TPPDataset, EventTokenizer, get_data_loader 6 | 7 | 8 | class TestDataLoader(unittest.TestCase): 9 | def setUp(self): 10 | # Assuming the data is already generated and saved in 'synthetic_data.json' 11 | self.data_file = 'synthetic_data.json' 12 | self.batch_size = 4 13 | self.input_data = self._make_json_2_dict(self.data_file) 14 | self.dataset = TPPDataset(self.input_data) 15 | 16 | config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 3, 17 | 'batch_size': self.batch_size, 18 | 'pad_token_id': 3}) 19 | 20 | self.tokenizer = EventTokenizer(config) 21 | 22 | self.data_loader = get_data_loader(self.dataset, 'torch', self.tokenizer, batch_size=self.batch_size) 23 | 24 | def _make_json_2_dict(self, json_dir): 25 | json_data = load_json(json_dir) 26 | res = dict() 27 | res['time_seqs'] = [x['time_since_start'] for x in json_data] 28 | res['time_delta_seqs'] = [x['time_since_last_event'] for x in json_data] 29 | res['type_seqs'] = [x['type_event'] for x in json_data] 30 | return res 31 | 32 | def test_data_loading(self): 33 | """Test if data is loaded correctly.""" 34 | self.assertIsNotNone(self.input_data) 35 | self.assertIsInstance(self.input_data, dict) 36 | self.assertGreater(len(self.input_data), 0) 37 | 38 | def test_batch_generation(self): 39 | """Test if batches are generated correctly.""" 40 | self.assertGreater(len(self.data_loader), 0) 41 | for batch in self.data_loader: 42 | self.assertLessEqual(batch['time_seqs'].shape[0], self.batch_size) 43 | self.assertIn('time_seqs', batch) 44 | self.assertIn('time_delta_seqs', batch) 45 | self.assertIn('type_seqs', batch) 46 | 47 | def test_batch_content(self): 48 | """Test if batch content is correct.""" 49 | for batch in self.data_loader: 50 | self.assertEqual(len(batch['time_seqs']), len(batch['time_delta_seqs'])) 51 | self.assertEqual(len(batch['time_seqs']), len(batch['type_seqs'])) 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /tests/test_nhp.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | import os 6 | import sys 7 | 8 | # Get the directory of the current file 9 | current_file_path = os.path.abspath(__file__) 10 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path))) 11 | 12 | from 
easy_tpp.model import TorchNHP as NHP 13 | from easy_tpp.preprocess.dataset import get_data_loader 14 | from easy_tpp.config_factory import DataSpecConfig, ModelConfig 15 | from easy_tpp.utils import load_json 16 | from easy_tpp.preprocess.dataset import TPPDataset, EventTokenizer 17 | 18 | 19 | class TestNeuralHawkesProcess(unittest.TestCase): 20 | def setUp(self): 21 | """Set up the test environment.""" 22 | # Assuming the data is already generated and saved in 'synthetic_data.json' 23 | self.data_file = 'synthetic_data.json' 24 | self.batch_size = 4 25 | self.input_data = self._make_json_2_dict(self.data_file) 26 | self.dataset = TPPDataset(self.input_data) 27 | 28 | config = DataSpecConfig.parse_from_yaml_config({'num_event_types': 3, 29 | 'batch_size': self.batch_size, 30 | 'pad_token_id': 3}) 31 | 32 | self.tokenizer = EventTokenizer(config) 33 | 34 | self.data_loader = get_data_loader(self.dataset, 'torch', self.tokenizer, batch_size=self.batch_size) 35 | 36 | model_config = ModelConfig.parse_from_yaml_config({'hidden_size': 32, 37 | 'loss_integral_num_sample_per_step': 20, 38 | 'num_event_types': 3, 39 | 'num_event_types_pad': 4, 40 | 'event_pad_index': 3}) 41 | self.model = NHP(model_config) 42 | 43 | def _make_json_2_dict(self, json_dir): 44 | json_data = load_json(json_dir) 45 | res = dict() 46 | res['time_seqs'] = [x['time_since_start'] for x in json_data] 47 | res['time_delta_seqs'] = [np.array(x['time_since_last_event'], dtype=np.float32) for x in json_data] 48 | res['type_seqs'] = [x['type_event'] for x in json_data] 49 | return res 50 | 51 | def test_model_initialization(self): 52 | """Test if the model is initialized correctly.""" 53 | self.assertIsInstance(self.model, NHP) 54 | self.assertEqual(self.model.hidden_size, 32) 55 | 56 | def test_forward_pass(self): 57 | """Test the forward pass of the model.""" 58 | batch = next(iter(self.data_loader)).values() 59 | output = self.model(batch) 60 | self.assertIsInstance(output[0], torch.Tensor) 61 | self.assertIsInstance(output[1], torch.Tensor) 62 | 63 | def test_loss_computation(self): 64 | """Test if the model computes loss correctly.""" 65 | batch = next(iter(self.data_loader)).values() 66 | loss = self.model.loglike_loss(batch) 67 | self.assertGreater(loss[0].item(), 0) # the negative log-likelihood should be positive on this data 68 | 69 | def test_backward_pass(self): 70 | """Test if the model can perform a backward pass.""" 71 | batch = next(iter(self.data_loader)).values() 72 | loss = self.model.loglike_loss(batch) 73 | loss[0].backward() 74 | for param in self.model.parameters(): 75 | self.assertIsNotNone(param.grad) # Ensure gradients are computed 76 | 77 | def test_training_step(self): 78 | """Test a single training step.""" 79 | optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) 80 | self.model.train() 81 | for batch in self.data_loader: 82 | optimizer.zero_grad() 83 | loss = self.model.loglike_loss(batch.values()) 84 | loss[0].backward() 85 | optimizer.step() 86 | self.assertIsNotNone(loss[0]) # Ensure loss is computed 87 | break # Only run one step for the test 88 | 89 | 90 | if __name__ == '__main__': 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.2' 2 | --------------------------------------------------------------------------------
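The hpo block in examples/configs/hpo_config.yaml (storage_uri, framework_id, n_trials) is consumed by the Optuna-backed tuner in easy_tpp/hpo/optuna_hpo.py. As a rough, hypothetical sketch of what those fields correspond to at the Optuna level -- the objective body, the study_name, and the sampled hyperparameter below are illustrative assumptions, not code from this repository:

import optuna

def objective(trial: optuna.Trial) -> float:
    # sample one hyperparameter; a real objective would train a model here
    # and return the validation metric
    hidden_size = trial.suggest_categorical('hidden_size', [16, 32, 64])
    return float(hidden_size)  # placeholder score

study = optuna.create_study(
    study_name='nhp_taxi_hpo',        # assumed name, not taken from the config
    storage='sqlite:///hpo_test.db',  # maps to storage_uri (note the three slashes)
    direction='minimize',
    load_if_exists=True,              # roughly what is_continuous suggests
)
study.optimize(objective, n_trials=10)  # maps to n_trials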
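make_hf_dataset.py zeroes the first inter-event time and then rebuilds time_since_start by accumulating time_since_last_event. A minimal, self-contained illustration of that invariant (integer deltas chosen so the equality is exact):

from itertools import accumulate

time_since_last_event = [0, 2, 1, 4]  # first delta zeroed, as the script enforces
time_since_start = list(accumulate(time_since_last_event))
assert time_since_start == [0, 2, 3, 7]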
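earthquake.py carves one long event stream into sequences via seq_start_idx and seq_end_idx; as written, the slicing appears to make the first sequence one event short and to drop the final event. A hedged sketch of consecutive, non-overlapping slicing built from the same cumsum idea (variable values are illustrative, not the script's):

import numpy as np

seq_len = np.array([3, 2, 4])
seq_end_idx = np.cumsum(seq_len)       # exclusive end indices: [3, 5, 9]
seq_start_idx = seq_end_idx - seq_len  # inclusive start indices: [0, 3, 5]
slices = [(int(s), int(e)) for s, e in zip(seq_start_idx, seq_end_idx)]
assert slices == [(0, 3), (3, 5), (5, 9)]  # every event used exactly once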
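The thinning blocks that recur in the YAML configs (num_exp, dtime_max, over_sample_rate, ...) parameterize the thinning-based sampler implemented in easy_tpp/model/torch_model/torch_thinning.py and its TensorFlow counterpart. The sketch below shows only the classical Ogata-style acceptance step that the num_exp and dtime_max comments refer to; it is an illustration of the idea, assuming intensity_fn is vectorized and bounded above by intensity_bound, and is not the library's implementation:

import numpy as np

def sample_next_dtime(intensity_fn, t_last, intensity_bound,
                      num_exp=500, dtime_max=5.0, rng=None):
    """Draw one inter-event time by thinning a Poisson(intensity_bound) proposal."""
    rng = rng or np.random.default_rng()
    # num_exp i.i.d. Exp(intensity_bound) gaps drawn at once, as the config comments describe
    gaps = rng.exponential(scale=1.0 / intensity_bound, size=num_exp)
    candidate_times = t_last + np.cumsum(gaps)
    # accept each candidate with probability intensity(t) / intensity_bound
    accepted = rng.random(num_exp) < intensity_fn(candidate_times) / intensity_bound
    if accepted.any():
        return float(candidate_times[np.argmax(accepted)] - t_last)  # first accepted candidate
    return dtime_max  # no acceptance within this draw: fall back to the capped dtime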