├── tests ├── __init__.py ├── test_datasets │ ├── __init__.py │ ├── test_sequence_dataset.py │ ├── test_clickhouse_sequence.py │ ├── datamodels.py │ ├── test_pandas_dataset.py │ └── test_sliceable_dataset.py ├── test_interface │ ├── __init__.py │ └── test_run_model.py ├── test_modules │ ├── __init__.py │ ├── test_configure_optimizer.py │ ├── test_combined_embedder.py │ ├── test_variational_event_model.py │ └── test_ib_event_model.py ├── conftest.py ├── test_utils │ ├── test_data │ │ ├── test_target_creator.py │ │ ├── test_tokenizer.py │ │ └── test_encoder.py │ └── test_date_arithmetic.py └── test_losses │ ├── test_mutual_information.py │ └── test_weight_scheduler.py ├── examples ├── __init__.py ├── Images │ ├── p_dists.PNG │ ├── card_users.PNG │ ├── BTYD_timeline.PNG │ ├── churn_val_loss.PNG │ ├── total_val_loss.PNG │ ├── clustering_both.png │ ├── low_p_last_trans.PNG │ ├── product_val_loss.PNG │ ├── balanced_last_trans.PNG │ ├── bimodal_cont_dist.PNG │ ├── high_p_last_trans.PNG │ ├── transaction_lengths.PNG │ └── long_intervals_last_trans.PNG ├── discrete_values.pkl ├── eventsprofiles_datamodel.py ├── train_btyd.py └── train_btyd_information_bottleneck.py ├── neural_lifetimes ├── data │ ├── __init__.py │ ├── datamodules │ │ └── __init__.py │ ├── dataloaders │ │ └── __init__.py │ ├── datasets │ │ └── __init__.py │ └── utils.py ├── models │ ├── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── event_model.py │ │ └── configure_optimizers.py │ └── nets │ │ ├── __init__.py │ │ ├── mlp.py │ │ ├── event_model.py │ │ ├── encoder_decoder.py │ │ ├── embedder.py │ │ └── heads.py ├── __init__.py ├── utils │ ├── aws │ │ ├── __init__.py │ │ ├── s3_utils.py │ │ └── utils.py │ ├── __init__.py │ ├── clickhouse │ │ ├── __init__.py │ │ ├── clickhouse_ingest.py │ │ └── schema.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── get_tensorboard_logger.py │ │ ├── projection_monitor.py │ │ ├── churn_monitor.py │ │ └── git.py │ ├── data │ │ ├── __init__.py │ │ ├── target_creator.py │ │ ├── tokenizer.py │ │ ├── encoder_with_unknown.py │ │ └── feature_encoder.py │ ├── plots │ │ └── __init__.py │ ├── date_arithmetic.py │ ├── scheduler.py │ └── score_estimators.py ├── metrics │ ├── __init__.py │ ├── sinkhorn.py │ └── kullback_leibler.py ├── losses │ ├── __init__.py │ ├── elbo.py │ └── mutual_information.py └── run_model.py ├── pages ├── _config.yml └── index.md ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── QUESTION.md │ ├── FEATURE_REQUEST.md │ └── BUG_REPORT.md ├── dependabot.yml ├── ISSUE_TEMPLATE.md ├── workflows │ ├── lint.yml │ ├── format.yml │ ├── test.yml │ ├── publish.yml │ └── docs.yml └── PULL_REQUEST_TEMPLATE.md ├── docs ├── _notebook_use_own_dataset.nblink ├── _notebook_BTYD_visualisation.nblink ├── Images │ ├── p_dists.PNG │ ├── card_users.PNG │ ├── BTYD_timeline.PNG │ ├── churn_val_loss.PNG │ ├── clustering_both.png │ ├── total_val_loss.PNG │ ├── bimodal_cont_dist.PNG │ ├── high_p_last_trans.PNG │ ├── low_p_last_trans.PNG │ ├── product_val_loss.PNG │ ├── balanced_last_trans.PNG │ ├── transaction_lengths.PNG │ └── long_intervals_last_trans.PNG ├── _static │ ├── model.png │ └── workflow.png ├── modules.rst ├── neural_lifetimes.utils.plots.rst ├── neural_lifetimes.losses.rst ├── neural_lifetimes.models.modules.rst ├── neural_lifetimes.data.dataloaders.rst ├── neural_lifetimes.data.datamodules.rst ├── neural_lifetimes.models.rst ├── index.rst ├── Makefile ├── neural_lifetimes.data.rst ├── neural_lifetimes.utils.clickhouse.rst ├── neural_lifetimes.rst ├── make.bat ├── neural_lifetimes.utils.aws.rst 
├── neural_lifetimes.utils.rst ├── high_level_overview.rst ├── neural_lifetimes.utils.callbacks.rst ├── README.md ├── neural_lifetimes.data.datasets.rst ├── neural_lifetimes.models.nets.rst ├── intro.rst └── conf.py ├── .flake8 ├── requirements-dev.txt ├── pyproject.toml ├── .editorconfig ├── requirements.txt ├── LICENSE ├── setup.py ├── Makefile ├── .gitignore ├── CONTRIBUTING.md └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neural_lifetimes/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_interface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neural_lifetimes/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pages/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false -------------------------------------------------------------------------------- /tests/test_datasets/test_sequence_dataset.py: -------------------------------------------------------------------------------- 1 | # TODO implement tests for sequence dataset 2 | -------------------------------------------------------------------------------- /docs/_notebook_use_own_dataset.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../examples/use_own_dataset.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /neural_lifetimes/__init__.py: -------------------------------------------------------------------------------- 1 | from .run_model import run_model 2 | 3 | __all__ = ['run_model'] 4 | -------------------------------------------------------------------------------- /docs/_notebook_BTYD_visualisation.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../examples/BTYD_visualisation.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /docs/Images/p_dists.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/p_dists.PNG 
-------------------------------------------------------------------------------- /docs/_static/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/_static/model.png -------------------------------------------------------------------------------- /neural_lifetimes/utils/aws/__init__.py: -------------------------------------------------------------------------------- 1 | from .query import caching_query 2 | 3 | __all__ = [caching_query] 4 | -------------------------------------------------------------------------------- /docs/Images/card_users.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/card_users.PNG -------------------------------------------------------------------------------- /docs/_static/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/_static/workflow.png -------------------------------------------------------------------------------- /docs/Images/BTYD_timeline.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/BTYD_timeline.PNG -------------------------------------------------------------------------------- /examples/Images/p_dists.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/p_dists.PNG -------------------------------------------------------------------------------- /examples/discrete_values.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/discrete_values.pkl -------------------------------------------------------------------------------- /docs/Images/churn_val_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/churn_val_loss.PNG -------------------------------------------------------------------------------- /docs/Images/clustering_both.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/clustering_both.png -------------------------------------------------------------------------------- /docs/Images/total_val_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/total_val_loss.PNG -------------------------------------------------------------------------------- /examples/Images/card_users.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/card_users.PNG -------------------------------------------------------------------------------- /docs/Images/bimodal_cont_dist.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/bimodal_cont_dist.PNG 
-------------------------------------------------------------------------------- /docs/Images/high_p_last_trans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/high_p_last_trans.PNG -------------------------------------------------------------------------------- /docs/Images/low_p_last_trans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/low_p_last_trans.PNG -------------------------------------------------------------------------------- /docs/Images/product_val_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/product_val_loss.PNG -------------------------------------------------------------------------------- /examples/Images/BTYD_timeline.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/BTYD_timeline.PNG -------------------------------------------------------------------------------- /examples/Images/churn_val_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/churn_val_loss.PNG -------------------------------------------------------------------------------- /examples/Images/total_val_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/total_val_loss.PNG -------------------------------------------------------------------------------- /docs/Images/balanced_last_trans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/balanced_last_trans.PNG -------------------------------------------------------------------------------- /docs/Images/transaction_lengths.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/transaction_lengths.PNG -------------------------------------------------------------------------------- /examples/Images/clustering_both.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/clustering_both.png -------------------------------------------------------------------------------- /examples/Images/low_p_last_trans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/low_p_last_trans.PNG -------------------------------------------------------------------------------- /examples/Images/product_val_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/product_val_loss.PNG -------------------------------------------------------------------------------- /examples/Images/balanced_last_trans.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/balanced_last_trans.PNG -------------------------------------------------------------------------------- /examples/Images/bimodal_cont_dist.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/bimodal_cont_dist.PNG -------------------------------------------------------------------------------- /examples/Images/high_p_last_trans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/high_p_last_trans.PNG -------------------------------------------------------------------------------- /examples/Images/transaction_lengths.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/transaction_lengths.PNG -------------------------------------------------------------------------------- /neural_lifetimes/data/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequence_datamodule import SequenceDataModule 2 | 3 | __all__ = [SequenceDataModule] 4 | -------------------------------------------------------------------------------- /docs/Images/long_intervals_last_trans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/docs/Images/long_intervals_last_trans.PNG -------------------------------------------------------------------------------- /neural_lifetimes/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .date_arithmetic import datetime2float, float2datetime 2 | 3 | __all__ = [datetime2float, float2datetime] 4 | -------------------------------------------------------------------------------- /examples/Images/long_intervals_last_trans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transferwise/neural-lifetimes/HEAD/examples/Images/long_intervals_last_trans.PNG -------------------------------------------------------------------------------- /neural_lifetimes/data/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequence_loader import SequenceLoader, get_last, trim_last 2 | 3 | __all__ = [SequenceLoader, get_last, trim_last] 4 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/clickhouse/__init__.py: -------------------------------------------------------------------------------- 1 | from .clickhouse_ingest import clickhouse_ingest, clickhouse_ranges 2 | 3 | __all__ = [clickhouse_ingest, clickhouse_ranges] 4 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | neural_lifetimes package 2 | ============================================================================= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | neural_lifetimes 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/QUESTION.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask anything about this project 4 | title: '' 5 | labels: help wanted 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Your question** -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = D100,D104,D105,D200,D401,E203,W503,D101,D102,D103,D107 3 | max-line-length=120 4 | exclude = 5 | .git 6 | __pycache__ 7 | docs 8 | build 9 | dist 10 | venv 11 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black==22.3.0 2 | click==8.0.4 3 | flake8==4.0.1 4 | flake8-docstrings==1.6.0 5 | nbsphinx==0.8.8 6 | nbsphinx-link==1.3.0 7 | pytest==6.2.5 8 | pytest-cov==3.0.0 9 | python-snappy==0.6.1 10 | sphinx==4.3.2 11 | sphinx-rtd-theme==1.0.0 12 | -------------------------------------------------------------------------------- /neural_lifetimes/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .sinkhorn import WassersteinMetric 2 | from .kullback_leibler import KullbackLeiblerDivergence, ParametricKullbackLeiblerDivergence 3 | 4 | __all__ = [KullbackLeiblerDivergence, WassersteinMetric, ParametricKullbackLeiblerDivergence] 5 | -------------------------------------------------------------------------------- /neural_lifetimes/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .event_model import EventModel 2 | from .variational_event_model import VariationalEventModel 3 | from .information_bottleneck_event_model import InformationBottleneckEventModel 4 | 5 | __all__ = [EventModel, InformationBottleneckEventModel, VariationalEventModel] 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | 4 | [tool.pytest.ini_options] 5 | addopts = "--strict-markers -vv" 6 | markers = [ 7 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 8 | ] 9 | testpaths = [ 10 | "tests", 11 | ] 12 | 13 | [tool.black] 14 | line-length = 120 15 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .churn_monitor import MonitorChurn 2 | from .distribution_monitor import DistributionMonitor 3 | 4 | from .projection_monitor import MonitorProjection 5 | from .git import GitInformationLogger 6 | 7 | __all__ = [DistributionMonitor, GitInformationLogger, MonitorChurn, MonitorProjection] 8 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | trim_trailing_whitespace = true 7 | insert_final_newline = true 8 | charset = utf-8 9 | end_of_line = lf 10 | 11 | [Makefile] 12 | indent_style = tab 13 | 14 | 
[LICENSE] 15 | insert_final_newline = false 16 | 17 | [*.{diff,patch}] 18 | trim_trailing_whitespace = false 19 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # This is an automatically generated base configuration 2 | # For further configuration options and tuning: 3 | # https://docs.github.com/en/free-pro-team@latest/github/administering-a-repository/configuration-options-for-dependency-updates 4 | 5 | version: 2 6 | updates: 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "weekly" 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | boto3~=1.17 2 | botocore~=1.20 3 | clickhouse_driver~=0.2.4 4 | cloudpickle~=2.0 5 | ctgan~=0.5.1 6 | deepecho~=0.3.0.post1 7 | matplotlib~=3.5 8 | numpy~=1.21 9 | pandas~=1.1 10 | pyarrow~=7.0 11 | torch~=1.11 12 | pytorch_lightning~=1.6 13 | rdt~=0.6.3 14 | scikit_learn~=1.0 15 | scipy~=1.6 16 | sdv~=0.13.1 17 | SQLAlchemy~=1.3.24 18 | torchmetrics~=0.8 19 | tqdm~=4.64 20 | vaex~=4.7 21 | plotly~=5.7 -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytorch_lightning.utilities import seed 3 | 4 | 5 | @pytest.fixture(autouse=True) 6 | def random_seed(): 7 | """ 8 | For reproducibility, set the random seed everywhere. 9 | 10 | PyTorch Lightning's seed.seed_everything sets the random seed in torch, numpy and 11 | python.random. Pandas uses the random state from numpy. 
12 | """ 13 | seed.seed_everything(0) 14 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder_with_unknown import OrdinalEncoder, OrdinalEncoderWithUnknown 2 | from .feature_encoder import FeatureDictionaryEncoder 3 | from .target_creator import TargetCreator, DummyTransform 4 | from .tokenizer import Tokenizer 5 | 6 | __all__ = [ 7 | OrdinalEncoder, 8 | OrdinalEncoderWithUnknown, 9 | FeatureDictionaryEncoder, 10 | TargetCreator, 11 | DummyTransform, 12 | Tokenizer, 13 | ] 14 | -------------------------------------------------------------------------------- /neural_lifetimes/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .btyd import BTYD 2 | from .clickhouse_sequence import ClickhouseSequenceDataset 3 | from .pandas_dataset import PandasSequenceDataset 4 | from .sequence_dataset import SequenceDataset, SequenceSubset, SliceableDataset, SliceableSubset 5 | 6 | __all__ = [ 7 | BTYD, 8 | ClickhouseSequenceDataset, 9 | SequenceDataset, 10 | SequenceSubset, 11 | SliceableDataset, 12 | SliceableSubset, 13 | PandasSequenceDataset, 14 | ] 15 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/plots/__init__.py: -------------------------------------------------------------------------------- 1 | from .plots import ( 2 | plot_cont_feature, 3 | plot_cont_features_pd, 4 | plot_discr_feature, 5 | plot_freq_transactions, 6 | plot_num_transactions, 7 | plot_timeline, 8 | plot_transactions, 9 | ) 10 | 11 | __all__ = [ 12 | plot_cont_feature, 13 | plot_cont_features_pd, 14 | plot_discr_feature, 15 | plot_freq_transactions, 16 | plot_num_transactions, 17 | plot_timeline, 18 | plot_transactions, 19 | ] 20 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.utils.plots.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.utils.plots package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.utils.plots.plots module 8 | ------------------------------------------ 9 | 10 | .. automodule:: neural_lifetimes.utils.plots.plots 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: neural_lifetimes.utils.plots 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Subject of the issue 2 | Describe your issue here. 3 | 4 | ### Your environment 5 | * Version of neural-lifetimes, e.g branch/commit #/release/tag 6 | * Source type and setup 7 | * Target type and setup 8 | 9 | ### Steps to reproduce 10 | Tell us how to reproduce this issue. 
11 | 12 | ### Expected behaviour 13 | Tell us what should happen 14 | 15 | ### Actual behaviour 16 | Tell us what happens instead 17 | 18 | ### Further notes 19 | Tell us why you think this might be happening, or stuff you tried that didn't work 20 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.losses.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.losses package 2 | ================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.losses.distributional\_losses module 8 | ------------------------------------------------------ 9 | 10 | .. automodule:: neural_lifetimes.losses.distributional_losses 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: neural_lifetimes.losses 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /neural_lifetimes/models/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom_par import DeepEchoModel, PAREdit, PARModel, PARNet 2 | from .custom_tvae import DataTransformer, Decoder, TVAESynthesizer 3 | from .embedder import CombinedEmbedder 4 | from .encoder_decoder import VariationalEncoderDecoder 5 | from .event_model import EventEncoder 6 | 7 | __all__ = [ 8 | DeepEchoModel, 9 | PAREdit, 10 | PARModel, 11 | PARNet, 12 | DataTransformer, 13 | Decoder, 14 | TVAESynthesizer, 15 | CombinedEmbedder, 16 | VariationalEncoderDecoder, 17 | EventEncoder, 18 | ] 19 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Linter 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | check: 13 | runs-on: ubuntu-20.04 14 | 15 | steps: 16 | - name: Checking out repo 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.9 23 | 24 | - name: Install dependencies 25 | run: make venv 26 | 27 | - name: Linting 28 | run: make lint 29 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.models.modules.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.models.modules package 2 | ======================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.models.modules.classic\_model module 8 | ------------------------------------------------------ 9 | 10 | .. automodule:: neural_lifetimes.models.modules.classic_model 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: neural_lifetimes.models.modules 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /.github/workflows/format.yml: -------------------------------------------------------------------------------- 1 | name: Format 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | check: 13 | runs-on: ubuntu-20.04 14 | 15 | steps: 16 | - name: Checking out repo 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.9 23 | 24 | - name: Install dependencies 25 | run: make venv 26 | 27 | - name: Formatting 28 | run: make checkformat 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 TransferWise Ltd. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.data.dataloaders.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.data.dataloaders package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.data.dataloaders.sequence\_loader module 8 | ---------------------------------------------------------- 9 | 10 | .. automodule:: neural_lifetimes.data.dataloaders.sequence_loader 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: neural_lifetimes.data.dataloaders 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.data.datamodules.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.data.datamodules package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.data.datamodules.sequence\_datamodule module 8 | -------------------------------------------------------------- 9 | 10 | .. automodule:: neural_lifetimes.data.datamodules.sequence_datamodule 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: neural_lifetimes.data.datamodules 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /neural_lifetimes/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributional_losses import ( 2 | CategoricalLoss, 3 | ChurnLoss, 4 | CompositeLoss, 5 | ExponentialLoss, 6 | LogNormalLoss, 7 | NormalLoss, 8 | SumLoss, 9 | TauLoss, 10 | ) 11 | 12 | from .mutual_information import InformationBottleneckLoss 13 | from .elbo import VariationalEncoderDecoderLoss 14 | 15 | __all__ = [ 16 | CategoricalLoss, 17 | ChurnLoss, 18 | CompositeLoss, 19 | ExponentialLoss, 20 | InformationBottleneckLoss, 21 | LogNormalLoss, 22 | NormalLoss, 23 | SumLoss, 24 | TauLoss, 25 | VariationalEncoderDecoderLoss, 26 | ] 27 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.models.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.models package 2 | ================================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | neural_lifetimes.models.modules 11 | neural_lifetimes.models.nets 12 | 13 | Submodules 14 | ---------- 15 | 16 | neural\_lifetimes.models.targets module 17 | --------------------------------------- 18 | 19 | .. automodule:: neural_lifetimes.models.targets 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: neural_lifetimes.models 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Wise Model documentation master file, created by 2 | sphinx-quickstart on Sun Nov 7 16:35:07 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Neural Lifetimes' documentation! 7 | ====================================== 8 | 9 | .. include:: intro.rst 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Contents: 14 | 15 | high_level_overview 16 | _notebook_BTYD_visualisation 17 | _notebook_use_own_dataset 18 | modules 19 | 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/callbacks/get_tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 3 | 4 | 5 | def _get_tensorboard_logger(trainer: pl.Trainer) -> TensorBoardLogger: 6 | for logger in trainer.loggers: 7 | if isinstance(logger, TensorBoardLogger): 8 | break 9 | else: 10 | logger_types = [type(logger).__name__ for logger in trainer.loggers] 11 | raise ValueError( 12 | f"No Tensorboard logger found in the lightning Trainer's loggers. Got {logger_types} instead." 13 | + "This callback requires a Tensorboard logger." 
14 | ) 15 | 16 | return logger 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.data.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.data package 2 | ============================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | neural_lifetimes.data.dataloaders 11 | neural_lifetimes.data.datamodules 12 | neural_lifetimes.data.datasets 13 | 14 | Submodules 15 | ---------- 16 | 17 | neural\_lifetimes.data.data\_processing module 18 | ---------------------------------------------- 19 | 20 | .. automodule:: neural_lifetimes.data.data_processing 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | Module contents 26 | --------------- 27 | 28 | .. 
automodule:: neural_lifetimes.data 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | check: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: [ubuntu-latest, macos-latest, windows-latest] 17 | python-version: [3.7, 3.8, 3.9] 18 | 19 | steps: 20 | - name: Checking out repo 21 | uses: actions/checkout@v2 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install dependencies 29 | run: make venv 30 | 31 | - name: Testing and coverage 32 | run: make cov 33 | -------------------------------------------------------------------------------- /neural_lifetimes/models/modules/event_model.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from typing import Any, Dict, Tuple 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from abc import ABC, abstractmethod 7 | 8 | 9 | # TODO implement base _EventModel 10 | class EventModel(pl.LightningModule, ABC): # TODO Add better docstring 11 | @abstractmethod 12 | def configure_criterion(self): 13 | pass 14 | 15 | @abstractmethod 16 | def build_parameter_dict(self) -> Dict[str, Any]: 17 | pass 18 | 19 | @abstractmethod 20 | def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 21 | pass 22 | 23 | @abstractmethod 24 | def encode(self, x: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor]: 25 | pass 26 | 27 | 28 | __all__ = [EventModel] 29 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.utils.clickhouse.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.utils.clickhouse package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.utils.clickhouse.clickhouse\_ingest module 8 | ------------------------------------------------------------ 9 | 10 | .. automodule:: neural_lifetimes.utils.clickhouse.clickhouse_ingest 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | neural\_lifetimes.utils.clickhouse.schema module 16 | ------------------------------------------------ 17 | 18 | .. automodule:: neural_lifetimes.utils.clickhouse.schema 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: neural_lifetimes.utils.clickhouse 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-20.04 11 | 12 | steps: 13 | - name: Checkout repo 14 | uses: actions/checkout@v2.3.4 15 | 16 | - name: Install dependencies 17 | run: python -m pip install --upgrade pip setuptools wheel twine 18 | 19 | - name: Build distributable package 20 | run: python setup.py bdist_wheel sdist 21 | 22 | - name: Check distributable 23 | run: twine check --strict dist/*.whl 24 | 25 | - name: Publish package 26 | if: ${{ github.event_name == 'release' && github.event.action == 'published' }} 27 | uses: pypa/gh-action-pypi-publish@release/v1 28 | with: 29 | user: __token__ 30 | password: ${{ secrets.PYPI_API_TOKEN }} 31 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes package 2 | ========================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | neural_lifetimes.data 11 | neural_lifetimes.losses 12 | neural_lifetimes.models 13 | neural_lifetimes.utils 14 | 15 | Submodules 16 | ---------- 17 | 18 | neural\_lifetimes.inference module 19 | ---------------------------------- 20 | 21 | .. automodule:: neural_lifetimes.inference 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | neural\_lifetimes.run\_model module 27 | ----------------------------------- 28 | 29 | .. automodule:: neural_lifetimes.run_model 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | 34 | Module contents 35 | --------------- 36 | 37 | .. automodule:: neural_lifetimes 38 | :members: 39 | :undoc-members: 40 | :show-inheritance: 41 | -------------------------------------------------------------------------------- /examples/eventsprofiles_datamodel.py: -------------------------------------------------------------------------------- 1 | cont_feat_events = [ 2 | "INVOICE_VALUE_GBP_log", 3 | "FEE_INVOICE_ratio", 4 | "MARGIN_GBP", 5 | ] 6 | cont_feat_profiles = ["AGE_YEARS"] 7 | 8 | discr_feat_events = [ 9 | "ACTION_STATE", 10 | "ACTION_TYPE", 11 | "BALANCE_STEP_TYPE", 12 | "BALANCE_TRANSACTION_TYPE", 13 | "PRODUCT_TYPE", 14 | "SENDER_TYPE", 15 | "SOURCE_CURRENCY", 16 | "SUCCESSFUL_ACTION", 17 | "TARGET_CURRENCY", 18 | ] 19 | 20 | discr_feat_profiles = [ 21 | "ADDRESS_MARKET", 22 | "ADDR_COUNTRY", 23 | "BEST_GUESS_COUNTRY", 24 | "CREATION_PLATFORM", 25 | "CUSTOMER_TYPE", 26 | "FIRST_CCY_SEND", 27 | "FIRST_CCY_TARGET", 28 | "FIRST_VISIT_IP_COUNTRY", 29 | "LANGUAGE", 30 | ] 31 | 32 | target_cols = ["dt"] 33 | cont_feat = cont_feat_events + cont_feat_profiles 34 | discr_feat = discr_feat_events + discr_feat_profiles 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/BUG_REPORT.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Prepare the data as '...' 16 | 2. Run '....' 17 | 3. See error 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Your environment** 26 | - Version of neural-lifetimes, e.g. branch/commit #/release/tag 27 | - Source type and setup 28 | - Target type and setup 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 32 | -------------------------------------------------------------------------------- /neural_lifetimes/models/nets/mlp.py: -------------------------------------------------------------------------------- 1 | # # don't think this module is used at all 2 | # from typing import List 3 | 4 | # from torch import nn as nn 5 | 6 | 7 | # class MLP(nn.Module): 8 | # def __init__(self, layers: List[int], drop_rate=0.0): 9 | # super().__init__() 10 | # self.layers = nn.ModuleList() 11 | # for in_dim, out_dim in zip(layers[:-2], layers[1:-1]): 12 | # self.layers.append(nn.Linear(in_dim, out_dim)) 13 | # self.layers.append(nn.Dropout(drop_rate)) 14 | # self.layers.append(nn.ReLU()) 15 | # self.layers.append(nn.Linear(layers[-2], layers[-1])) 16 | 17 | # def forward(self, x): 18 | # batch_size = x["target"].shape[0] 19 | # x_in = x["data"].reshape(batch_size, -1) 20 | # for i, layer in enumerate(self.layers): 21 | # x_in = layer(x_in) 22 | 23 | # x["output"] = x_in 24 | # return x 25 | 26 | # TODO delete this file or uncomment it 27 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.utils.aws.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.utils.aws package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.utils.aws.query module 8 | ---------------------------------------- 9 | 10 | .. 
automodule:: neural_lifetimes.utils.aws.query 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | neural\_lifetimes.utils.aws.s3\_utils module 16 | -------------------------------------------- 17 | 18 | .. automodule:: neural_lifetimes.utils.aws.s3_utils 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | neural\_lifetimes.utils.aws.utils module 24 | ---------------------------------------- 25 | 26 | .. automodule:: neural_lifetimes.utils.aws.utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: neural_lifetimes.utils.aws 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /tests/test_utils/test_data/test_target_creator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from neural_lifetimes.utils.data import TargetCreator 4 | 5 | 6 | def test_constructor(): 7 | TargetCreator(cols=["CF1", "DF1"]) 8 | 9 | 10 | def test_parameter_dict(): 11 | target_creator = TargetCreator(cols=["CF1", "DF1"]) 12 | assert target_creator.build_parameter_dict() == {"columns": "['CF1', 'DF1']"} 13 | 14 | 15 | def test_call(): 16 | target_creator = TargetCreator(cols=["CF1", "DF1"]) 17 | data = { 18 | "t": np.array([1, 3, 4]), 19 | "dt": np.array([0, 2, 1]), 20 | "CF1": np.array([1, 2, 3]), 21 | "DF1": np.array([1, 2, 3]), 22 | } 23 | expected = { 24 | "next_dt": np.array([2, 1]), 25 | "next_CF1": np.array([2, 3]), 26 | "next_DF1": np.array([2, 3]), 27 | **data, 28 | } 29 | 30 | transformed_data = target_creator(data) 31 | 32 | for k, v in transformed_data.items(): 33 | assert np.all(expected[k] == v) 34 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published] 7 | pull_request: 8 | branches: [main] 9 | push: 10 | branches: [main] 11 | 12 | jobs: 13 | build-docs: 14 | runs-on: ubuntu-20.04 15 | steps: 16 | 17 | - name: Checkout 18 | uses: actions/checkout@master 19 | with: 20 | fetch-depth: 0 21 | 22 | - uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.9 25 | 26 | - name: Install dependencies 27 | run: | 28 | sudo apt update -y 29 | sudo apt install -y pandoc 30 | make venv 31 | 32 | - name: Build documentation 33 | run: | 34 | make docs 35 | 36 | - name: Publish documentation 37 | if: ${{ github.ref == 'refs/heads/main' || contains('refs/tags/', github.ref) }} 38 | uses: peaceiris/actions-gh-pages@v3 39 | with: 40 | github_token: ${{ secrets.GITHUB_TOKEN }} 41 | publish_dir: ./docs/build 42 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.utils.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.utils package 2 | =============================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | neural_lifetimes.utils.aws 11 | neural_lifetimes.utils.callbacks 12 | neural_lifetimes.utils.clickhouse 13 | neural_lifetimes.utils.plots 14 | 15 | Submodules 16 | ---------- 17 | 18 | neural\_lifetimes.utils.date\_arithmetic module 19 | ----------------------------------------------- 20 | 21 | .. 
automodule:: neural_lifetimes.utils.date_arithmetic 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | neural\_lifetimes.utils.encoder\_with\_unknown module 27 | ----------------------------------------------------- 28 | 29 | .. automodule:: neural_lifetimes.utils.encoder_with_unknown 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | 34 | Module contents 35 | --------------- 36 | 37 | .. automodule:: neural_lifetimes.utils 38 | :members: 39 | :undoc-members: 40 | :show-inheritance: 41 | -------------------------------------------------------------------------------- /docs/high_level_overview.rst: -------------------------------------------------------------------------------- 1 | High Level Overview 2 | ~~~~~~~~~~~~~~~~~~~~~~ 3 | 4 | This page provides high-level documentation of how the model is structured. UML diagrams are used when necessary. 5 | 6 | The overall workflow of using Neural Lifetimes is represented by the following diagram: 7 | 8 | .. image:: _static/workflow.png 9 | :align: center 10 | :width: 600px 11 | 12 | Library Functionalities 13 | ------------------------ 14 | 15 | This library contains several parts, each with different functionalities: 16 | 17 | - neural_lifetimes: This package contains the source code, including the model, model settings, trainers, and data handlers. 18 | - ml_utils: This package contains the utilities for the event model, including loss functions, encoders, and decoders. 19 | - clickhouse_utils: This package contains the utilities for the ClickHouse database. 20 | 21 | Details on how to use the user interface can be found in the Quickstart. 22 | 23 | Model 24 | ------ 25 | 26 | The model is structured as follows: 27 | 28 | .. image:: _static/model.png 29 | :align: center 30 | :width: 600px 31 | 32 | 33 | -------------------------------------------------------------------------------- /tests/test_datasets/test_clickhouse_sequence.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import pytest 3 | 4 | from neural_lifetimes.data.datasets import ClickhouseSequenceDataset 5 | 6 | TIME_COL = "transaction_time" 7 | UID_COL = "uid" 8 | 9 | 10 | @pytest.fixture 11 | def start_date(): 12 | return datetime(2020, 1, 1) 13 | 14 | 15 | @pytest.fixture 16 | def uids(): 17 | return list(range(1, 11)) 18 | 19 | 20 | def _construct_dataset(): 21 | return ClickhouseSequenceDataset() 22 | 23 | 24 | class TestConstruction: 25 | @pytest.mark.xfail 26 | def test_sets_static_data(self): 27 | assert False 28 | 29 | @pytest.mark.xfail 30 | def test_filters_by_min_items_per_uid(self): 31 | assert False 32 | 33 | @pytest.mark.xfail 34 | def test_filters_by_as_of_time(self): 35 | assert False 36 | 37 | @pytest.mark.xfail 38 | def test_len(self): 39 | assert False 40 | 41 | 42 | class TestGetItem: 43 | @pytest.mark.xfail 44 | def test_item_getter_int(self): 45 | assert False 46 | 47 | @pytest.mark.xfail 48 | def test_item_getter_list(self): 49 | assert False 50 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.utils.callbacks.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.utils.callbacks package 2 | ========================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.utils.callbacks.churn\_monitor module 8 | ------------------------------------------------------- 9 | 10 | .. 
automodule:: neural_lifetimes.utils.callbacks.churn_monitor 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | neural\_lifetimes.utils.callbacks.distribution\_monitor module 16 | -------------------------------------------------------------- 17 | 18 | .. automodule:: neural_lifetimes.utils.callbacks.distribution_monitor 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | neural\_lifetimes.utils.callbacks.projection\_monitor module 24 | ------------------------------------------------------------ 25 | 26 | .. automodule:: neural_lifetimes.utils.callbacks.projection_monitor 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: neural_lifetimes.utils.callbacks 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | 6 | def _read_requirements_file(path: str): 7 | with open(path) as f: 8 | return list( 9 | map( 10 | lambda req: req.strip(), 11 | f.readlines(), 12 | ) 13 | ) 14 | 15 | 16 | with open("README.md") as f: 17 | long_description = f.read() 18 | 19 | # TODO Mark: Update before release 20 | setup( 21 | name="neural-lifetimes", 22 | version="0.1.0", 23 | description="User behavior prediction from event data.", 24 | long_description=long_description, 25 | long_description_content_type="text/markdown", 26 | author="Wise", 27 | url="https://github.com/transferwise/neural_lifetimes", 28 | classifiers=[ 29 | "License :: OSI Approved :: Apache Software License", 30 | "Programming Language :: Python :: 3 :: Only", 31 | "Programming Language :: Python :: 3.8", 32 | "Programming Language :: Python :: 3.9", 33 | ], 34 | py_modules=["neural_lifetimes"], 35 | install_requires=_read_requirements_file("requirements.txt"), 36 | extras_require={ 37 | "test": _read_requirements_file("requirements-dev.txt"), 38 | }, 39 | packages=find_packages(exclude=["tests*"]), 40 | package_data={"neural_lifetimes.datasets": ["data/*.pkl"]}, 41 | ) 42 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Problem 2 | 3 | _Describe the problem your PR is trying to solve_ 4 | 5 | ## Proposed changes 6 | 7 | _Describe the big picture of your changes here to communicate to the maintainers why we should accept this pull request. 8 | If it fixes a bug or resolves a feature request, be sure to link to that issue._ 9 | 10 | 11 | ## Types of changes 12 | 13 | What types of changes does your code introduce to neural-lifetimes? 
14 | _Put an `x` in the boxes that apply_ 15 | 16 | - [ ] Bugfix (non-breaking change which fixes an issue) 17 | - [ ] New feature (non-breaking change which adds functionality) 18 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 19 | - [ ] Documentation Update (if none of the other choices apply) 20 | 21 | 22 | ## Checklist 23 | 24 | - [ ] I have read the [CONTRIBUTING](../CONTRIBUTING.md) doc 25 | - [ ] Description above provides context of the change 26 | - [ ] I have added tests that prove my fix is effective or that my feature works 27 | - [ ] Unit tests for changes (not needed for documentation changes) 28 | - [ ] CI checks pass with my changes 29 | - [ ] Bumping version in `setup.py` is an individual PR and not mixed with feature or bugfix PRs 30 | - [ ] Commits follow "[How to write a good git commit message](http://chris.beams.io/posts/git-commit/)" 31 | - [ ] Relevant documentation is updated including usage instructions 32 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/aws/s3_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple 3 | 4 | import boto3 5 | from botocore.exceptions import ClientError 6 | 7 | 8 | def decompose_s3_url(save_loc: str) -> Tuple[str, str]: 9 | assert save_loc[:5] == "s3://" 10 | save_loc = save_loc[5:] 11 | bucket = save_loc.split("/")[0] 12 | s3_file_name = save_loc[(len(bucket) + 1) :] 13 | return bucket, s3_file_name 14 | 15 | 16 | def file_exists_in_s3(s3_url): 17 | bucket, s3_file_name = decompose_s3_url(s3_url) 18 | s3_resource = boto3.resource("s3") 19 | 20 | bucket_con = s3_resource.Bucket(bucket) 21 | obj = list(bucket_con.objects.filter(Prefix=s3_file_name)) 22 | return len(obj) > 0 23 | 24 | 25 | def save_file_to_s3(local_file_name: str, s3_url: str): 26 | bucket, s3_file_name = decompose_s3_url(s3_url) 27 | s3_client = boto3.client("s3") 28 | try: 29 | return s3_client.upload_file(local_file_name, bucket, s3_file_name) 30 | except ClientError as e: 31 | logging.error(e) 32 | raise 33 | 34 | 35 | def get_file_from_s3(s3_url: str, local_file_name: str): 36 | bucket, s3_file_name = decompose_s3_url(s3_url) 37 | s3 = boto3.resource("s3") 38 | try: 39 | return s3.Bucket(bucket).download_file(s3_file_name, local_file_name) 40 | except ClientError as e: 41 | if e.response["Error"]["Code"] == "404": 42 | print("The object does not exist.") 43 | else: 44 | raise 45 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Generating documentation 2 | 3 | Neural Lifetimes uses Sphinx to auto-generate its own documentation. This document provides a guide to updating our documentation with your modifications. 4 | 5 | You can also check out [Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/quickstart.html) to see how to [write the rst files](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html), and [document code](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#module-sphinx.ext.autodoc). 6 | 7 | 8 | ## Incorporating notebooks 9 | 10 | Sphinx on its own doesn't allow us to incorporate notebooks into the documentation easily.
Here is how to do so with the [nbsphinx](https://github.com/spatialaudio/nbsphinx) and [nbsphinx-link](https://github.com/vidartf/nbsphinx-link) libraries (which are included in this package): 11 | 12 | 1. Edit `index.rst` and add the names of your Jupyter notebooks (without ".ipynb") to the toctree. 13 | 2. For notebooks which are not in the docs folder, instead create a `.nblink` file that links to it like so: 14 | ``` 15 | { 16 | "path": "/relative/path/to/notebook.ipynb", 17 | "extra-media": [ 18 | "/relative/path/to/images/folder", 19 | "/relative/path/to/specific/image.png" 20 | ] 21 | } 22 | ``` 23 | where the path is relative to the `docs` folder. 24 | 3. Run `make html` to build the HTML files. 25 | 26 | To link to the documentation of other functions, use ``:mod:`DisplayName` ``. 27 | -------------------------------------------------------------------------------- /pages/index.md: -------------------------------------------------------------------------------- 1 | ## Welcome to GitHub Pages 2 | 3 | You can use the [editor on GitHub](https://github.com/transferwise/neural-lifetimes/edit/gh-pages/index.md) to maintain and preview the content for your website in Markdown files. 4 | 5 | Whenever you commit to this repository, GitHub Pages will run [Jekyll](https://jekyllrb.com/) to rebuild the pages in your site, from the content in your Markdown files. 6 | 7 | ### Markdown 8 | 9 | Markdown is a lightweight and easy-to-use syntax for styling your writing. It includes conventions for 10 | 11 | ```markdown 12 | Syntax highlighted code block 13 | 14 | # Header 1 15 | ## Header 2 16 | ### Header 3 17 | 18 | - Bulleted 19 | - List 20 | 21 | 1. Numbered 22 | 2. List 23 | 24 | **Bold** and _Italic_ and `Code` text 25 | 26 | [Link](url) and ![Image](src) 27 | ``` 28 | 29 | For more details see [Basic writing and formatting syntax](https://docs.github.com/en/github/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax). 30 | 31 | ### Jekyll Themes 32 | 33 | Your Pages site will use the layout and styles from the Jekyll theme you have selected in your [repository settings](https://github.com/transferwise/neural-lifetimes/settings/pages). The name of this theme is saved in the Jekyll `_config.yml` configuration file. 34 | 35 | ### Support or Contact 36 | 37 | Having trouble with Pages? Check out our [documentation](https://docs.github.com/categories/github-pages-basics/) or [contact support](https://support.github.com/contact) and we’ll help you sort it out. 38 | -------------------------------------------------------------------------------- /neural_lifetimes/metrics/sinkhorn.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Any 2 | 3 | import torch 4 | 5 | from scipy.stats import wasserstein_distance 6 | import torchmetrics 7 | 8 | # This implementation is based on scipy.stats.wasserstein_distance.
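# A minimal usage sketch of the metric defined below (illustrative only; shapes
# and values are arbitrary):
#
#     metric = WassersteinMetric()
#     metric.update(torch.randn(100), torch.randn(100))
#     distance = metric.compute()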
9 | # There are solvers for differentiable approximations with GPU acceleration 10 | # the geomloss package: https://www.kernel-operations.io/geomloss/api/pytorch-api.html#geomloss.SamplesLoss 11 | 12 | 13 | class WassersteinMetric(torchmetrics.Metric): 14 | def __init__(self, compute_on_step: Optional[bool] = None, **kwargs: Dict[str, Any]) -> None: 15 | super().__init__() 16 | # intiialise states 17 | self.add_state("preds", default=[], dist_reduce_fx="cat") 18 | self.add_state("target", default=[], dist_reduce_fx="cat") 19 | 20 | def update(self, preds: torch.Tensor, target: torch.Tensor): 21 | assert preds.shape == target.shape, ( 22 | "``preds`` and ``target`` need to have the same shape. " 23 | + f"Got ``{preds.shape}`` and ``{target.shape}`` instead." 24 | ) 25 | preds, target = preds.flatten(), target.flatten() 26 | assert preds.dim() == 1, f"The input tensors need to be one-dimensional. Got {preds.dim()} dimensions." 27 | self.preds.append(preds) 28 | self.target.append(target) 29 | 30 | def compute(self): 31 | preds = torch.cat(self.preds) 32 | target = torch.cat(self.target) 33 | metric = wasserstein_distance(preds.cpu().numpy(), target.cpu().numpy()) 34 | return metric 35 | -------------------------------------------------------------------------------- /docs/neural_lifetimes.data.datasets.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.data.datasets package 2 | ======================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.data.datasets.btyd module 8 | ------------------------------------------- 9 | 10 | .. automodule:: neural_lifetimes.data.datasets.btyd 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | neural\_lifetimes.data.datasets.clickhouse\_sequence module 16 | ----------------------------------------------------------- 17 | 18 | .. automodule:: neural_lifetimes.data.datasets.clickhouse_sequence 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | neural\_lifetimes.data.datasets.pandas\_dataset module 24 | ------------------------------------------------------ 25 | 26 | .. automodule:: neural_lifetimes.data.datasets.pandas_dataset 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | neural\_lifetimes.data.datasets.sequence\_dataset module 32 | -------------------------------------------------------- 33 | 34 | .. automodule:: neural_lifetimes.data.datasets.sequence_dataset 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | neural\_lifetimes.data.datasets.sequence\_sampling\_dataset module 40 | ------------------------------------------------------------------ 41 | 42 | .. automodule:: neural_lifetimes.data.datasets.sequence_sampling_dataset 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: neural_lifetimes.data.datasets 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/data/target_creator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Sequence, Union 2 | 3 | import numpy as np 4 | 5 | 6 | class DummyTransform: 7 | def __call__(self, x, *args, **kwargs): 8 | return x 9 | 10 | def output_len(self, input_len: int): 11 | return input_len 12 | 13 | 14 | class TargetCreator: 15 | """ 16 | A class to create targets for a sequence of events. 
17 | 18 | Args: 19 | cols (List[str]): The list of columns to use as features. 20 | 21 | Attributes: 22 | cols (List[str]): The list of columns to use as features. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | cols: List[str], 28 | ): 29 | super().__init__() 30 | self.cols = cols 31 | 32 | def build_parameter_dict(self) -> Dict[str, str]: 33 | """Return a dictionary of parameters. 34 | 35 | Returns: 36 | Dict[str, Any]: A dictionary of the target transform parameters 37 | """ 38 | return { 39 | "columns": str(self.cols), 40 | } 41 | 42 | def __call__(self, x: Dict[str, np.ndarray]) -> Dict[str, Union[np.ndarray, Sequence[str]]]: 43 | """Appends the data dict ``x`` with right-shifted copies for all keys specified in ``cols`` and ``dt``. 44 | 45 | Args: 46 | x (Dict[str, np.ndarray]): data dictionary. 47 | 48 | Returns: 49 | Dict[str, Union[np.ndarray, Sequence[str]]]: The appended data dictionary. 50 | """ 51 | for c in self.cols + ["dt"]: 52 | x[f"next_{c}"] = x[c][1:] 53 | assert len(x[f"next_{c}"]) == len(x["t"]) - 1 54 | 55 | return x 56 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/aws/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from typing import Callable 3 | 4 | import pandas as pd 5 | 6 | from .query import caching_query 7 | 8 | 9 | def jsons_files_to_iterable(pattern: str): 10 | """ 11 | Iterates over all files in a directory and yields a json / dictionary for each file. 12 | 13 | Args: 14 | pattern (str): The pattern to match files against. 15 | 16 | Yields: 17 | dict: A dictionary representing the json file. 18 | """ 19 | return files_to_df_iterable(pattern, lambda x: pd.read_json(x, lines=True)) 20 | 21 | 22 | def csv_files_to_iterable(pattern: str): 23 | """ 24 | Iterates over all files in a directory and yields a dataframe for each file. 25 | 26 | Args: 27 | pattern (str): The pattern to match files against. 28 | 29 | Yields: 30 | pd.DataFrame: A dataframe representing the csv file. 31 | """ 32 | return files_to_df_iterable(pattern, lambda x: pd.read_csv(x)) 33 | 34 | 35 | def files_to_df_iterable(pattern: str, loader: Callable = lambda x: caching_query(x, None), sort_files=False): 36 | """ 37 | Iterates over all files in a directory and yields the loaded data for each file. 38 | 39 | Args: 40 | pattern (str): The pattern to match files against. 41 | loader (Callable, optional): The function to load the data. Defaults to lambda x: caching_query(x, None). 42 | sort_files (bool, optional): Whether to sort the files. Defaults to False. 43 | 44 | Yields: 45 | The loaded data for each file. 
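
    A minimal, illustrative example (the glob pattern and loader are arbitrary)::

        import pandas as pd

        for df in files_to_df_iterable("data/*.csv", loader=pd.read_csv, sort_files=True):
            print(len(df))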
46 | """ 47 | files = glob.glob(pattern) 48 | if sort_files: 49 | files = sorted(files) 50 | for file in files: 51 | # print(file) 52 | yield (loader(file)) 53 | -------------------------------------------------------------------------------- /tests/test_losses/test_mutual_information.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from neural_lifetimes.losses.mutual_information import MutualInformationGradientEstimator 3 | from neural_lifetimes.losses import InformationBottleneckLoss 4 | 5 | import pytest 6 | 7 | from neural_lifetimes.utils.scheduler import WeightScheduler 8 | 9 | 10 | def identity(x): 11 | return x 12 | 13 | 14 | class TestMutualInformation: 15 | @staticmethod 16 | def get_mi_obj(n_eigen=None, n_eigen_threshold=None): 17 | return MutualInformationGradientEstimator(n_eigen, n_eigen_threshold) 18 | 19 | @pytest.mark.parametrize(("n_eigen", "n_eigen_threshold"), ((None, None), (10, None), (None, 1e-4))) 20 | def test_constructor(self, n_eigen, n_eigen_threshold): 21 | self.get_mi_obj(n_eigen, n_eigen_threshold) 22 | 23 | @pytest.mark.parametrize(("n_eigen", "n_eigen_threshold"), ((None, None), (10, None), (None, 1e-4))) 24 | def test_forward_isnotnan(self, n_eigen, n_eigen_threshold): 25 | fn = self.get_mi_obj(n_eigen, n_eigen_threshold) 26 | assert torch.isfinite(fn(torch.randn((10, 5)), torch.randn((10, 2)))) 27 | 28 | 29 | class TestInformationBottleneckLoss: 30 | @staticmethod 31 | def get_loss_fn(n_eigen: int = None, n_eigen_threshold: float = None): 32 | fit_loss = identity 33 | weight_scheduler = WeightScheduler() 34 | return InformationBottleneckLoss(fit_loss, weight_scheduler, n_eigen, n_eigen_threshold) 35 | 36 | @pytest.mark.xfail 37 | def test_constructor(self): 38 | pass 39 | 40 | @pytest.mark.xfail 41 | def test_reg_weight(self): 42 | assert self.get_loss_fn().reg_weight == 1 43 | 44 | @pytest.mark.xfail 45 | def test_forward(self): 46 | pass 47 | -------------------------------------------------------------------------------- /tests/test_datasets/datamodels.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class DataModel: 6 | pass 7 | 8 | 9 | @dataclass 10 | class EventprofilesDataModel(DataModel): 11 | cont_feat_events = [ 12 | "INVOICE_VALUE_GBP_log", 13 | "FEE_INVOICE_ratio", 14 | "MARGIN_GBP", 15 | ] 16 | cont_feat_profiles = ["AGE_YEARS"] 17 | 18 | discr_feat_events = [ 19 | "ACTION_STATE", 20 | "ACTION_TYPE", 21 | "BALANCE_STEP_TYPE", 22 | "BALANCE_TRANSACTION_TYPE", 23 | "PRODUCT_TYPE", 24 | "SENDER_TYPE", 25 | "SOURCE_CURRENCY", 26 | "SUCCESSFUL_ACTION", 27 | "TARGET_CURRENCY", 28 | ] 29 | 30 | discr_feat_profiles = [ 31 | "ADDRESS_MARKET", 32 | "ADDR_COUNTRY", 33 | "BEST_GUESS_COUNTRY", 34 | "CREATION_PLATFORM", 35 | "CUSTOMER_TYPE", 36 | "FIRST_CCY_SEND", 37 | "FIRST_CCY_TARGET", 38 | "FIRST_VISIT_IP_COUNTRY", 39 | "LANGUAGE", 40 | ] 41 | 42 | target_cols = ["dt"] 43 | cont_feat = cont_feat_events + cont_feat_profiles 44 | discr_feat = discr_feat_events + discr_feat_profiles 45 | 46 | 47 | @dataclass 48 | class EventsOnlyDataModel(DataModel): 49 | cont_feat = [ 50 | "INVOICE_VALUE_GBP", 51 | "FEE_VALUE_GBP", 52 | "MARGIN_GBP", 53 | ] 54 | cont_feat_profiles = [] 55 | discr_feat = [ 56 | "ACTION_STATE", 57 | "ACTION_TYPE", 58 | "BALANCE_STEP_TYPE", 59 | "BALANCE_TRANSACTION_TYPE", 60 | "PRODUCT_TYPE", 61 | "SENDER_TYPE", 62 | "SOURCE_CURRENCY", 63 | "SUCCESSFUL_ACTION", 64 | "TARGET_CURRENCY", 
65 | ] 66 | discr_feat_profiles = [] 67 | target_cols = ["dt"] 68 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | venv_name = venv 2 | venv_activate_path := ./$(venv_name)/bin/activate 3 | package_name = neural_lifetimes 4 | cov_args := --cov $(package_name) --cov-report=term-missing 5 | not_slow = -m "not slow" 6 | 7 | .PHONY: clean venv lint test slowtest cov slowcov docs 8 | 9 | clean: 10 | rm -rf ./$(venv_name) 11 | 12 | venv: 13 | python3 -m venv $(venv_name) ;\ 14 | . $(venv_activate_path) ;\ 15 | pip install --upgrade pip setuptools wheel ;\ 16 | pip install --upgrade -r requirements-dev.txt ;\ 17 | pip install --upgrade -r requirements.txt 18 | 19 | update: 20 | . $(venv_activate_path) ;\ 21 | pip install --upgrade pip setuptools wheel ;\ 22 | pip install --upgrade -r requirements-dev.txt ;\ 23 | pip install --upgrade -r requirements.txt 24 | 25 | lint: 26 | . $(venv_activate_path) ;\ 27 | flake8 $(package_name)/ ;\ 28 | flake8 tests/ 29 | 30 | test: 31 | . $(venv_activate_path) ;\ 32 | py.test $(not_slow) --disable-warnings 33 | 34 | slowtest: 35 | . $(venv_activate_path) ;\ 36 | py.test 37 | 38 | cov: 39 | . $(venv_activate_path) ;\ 40 | py.test $(cov_args) $(not_slow) 41 | 42 | slowcov: 43 | . $(venv_activate_path) ;\ 44 | py.test $(cov_args) 45 | 46 | format: 47 | . $(venv_activate_path) ;\ 48 | isort -rc . ;\ 49 | autoflake -r --in-place --remove-unused-variables $(package_name)/ ;\ 50 | autoflake -r --in-place --remove-unused-variables tests/ ;\ 51 | black $(package_name)/ --skip-string-normalization ;\ 52 | black tests/ --skip-string-normalization 53 | 54 | checkformat: 55 | . $(venv_activate_path) ;\ 56 | black $(package_name)/ --skip-string-normalization --check ;\ 57 | black tests/ --skip-string-normalization --check 58 | 59 | docs: 60 | . $(venv_activate_path) ;\ 61 | cd docs/ ;\ 62 | sphinx-apidoc -o . ../neural_lifetimes ;\ 63 | sphinx-build -b html . 
build 64 | -------------------------------------------------------------------------------- /tests/test_losses/test_weight_scheduler.py: -------------------------------------------------------------------------------- 1 | from neural_lifetimes.utils.scheduler import LinearWarmupScheduler, ExponentialWarmupScheduler 2 | 3 | 4 | class Test_LinearWarmupScheduler: 5 | @staticmethod 6 | def scheduler(n_cold_steps: int = 2, n_warmup_steps: int = 5, target_weight: float = 2): 7 | return LinearWarmupScheduler( 8 | n_cold_steps=n_cold_steps, 9 | n_warmup_steps=n_warmup_steps, 10 | target_weight=target_weight, 11 | ) 12 | 13 | def test_constructor(self): 14 | self.scheduler() 15 | 16 | def test_scheduling(self): 17 | scheduler = self.scheduler() 18 | 19 | scheduler_weights = [] 20 | for _ in range(10): 21 | scheduler_weights.append(scheduler.weight) 22 | scheduler.step() 23 | 24 | expected_weights = [0, 0, 0.4, 0.8, 1.2, 1.6, 2, 2, 2, 2] 25 | assert all([s == e for s, e in zip(scheduler_weights, expected_weights)]) 26 | 27 | 28 | class Test_ExponentialWarmupScheduler: 29 | @staticmethod 30 | def scheduler(n_cold_steps: int = 2, n_warmup_steps: int = 5, target_weight: float = 2, gamma: float = 4): 31 | return ExponentialWarmupScheduler( 32 | n_cold_steps=n_cold_steps, n_warmup_steps=n_warmup_steps, target_weight=target_weight, gamma=gamma 33 | ) 34 | 35 | def test_constructor(self): 36 | self.scheduler() 37 | 38 | def test_scheduling(self): 39 | scheduler = self.scheduler() 40 | 41 | scheduler_weights = [] 42 | for _ in range(10): 43 | scheduler_weights.append(scheduler.weight) 44 | scheduler.step() 45 | 46 | expected_weights = [0, 0, 0.0078125, 0.03125, 0.125, 0.5, 2, 2, 2, 2] 47 | assert all([s == e for s, e in zip(scheduler_weights, expected_weights)]) 48 | -------------------------------------------------------------------------------- /neural_lifetimes/models/nets/event_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from .embedder import CombinedEmbedder 8 | 9 | 10 | # TODO rename 11 | class EventEncoder(nn.Module): 12 | def __init__(self, emb: CombinedEmbedder, rnn_dim: int, drop_rate: float = 0.0, num_layers: int = 1): 13 | super().__init__() 14 | self.emb = emb 15 | 16 | if num_layers == 1: 17 | drop_rate = 0.0 18 | print("Dropout for RNN was set to 0, because num_layers=1.") 19 | 20 | self.rnn = nn.GRU( 21 | input_size=emb.output_shape[1], 22 | hidden_size=rnn_dim, 23 | num_layers=1, 24 | dropout=drop_rate, 25 | batch_first=True, 26 | ) 27 | self.linear = nn.Linear(rnn_dim, rnn_dim) 28 | self.output_shape = [None, rnn_dim] 29 | 30 | def forward(self, x: Dict[str, torch.Tensor]): 31 | # TODO: Eventually, to include initial state features, 32 | # to be fed into the initial RNN state 33 | 34 | # stacked_seq x emb_dim 35 | x_emb = self.emb(x) 36 | 37 | # seq_inds = zip(x["offsets"][:-1], x["offsets"][1:]) 38 | x_stacked = nn.utils.rnn.pack_sequence( 39 | [x_emb[s:e] for s, e in zip(x["offsets"][:-1], x["offsets"][1:])], 40 | enforce_sorted=False, 41 | ) 42 | 43 | # stacked_seq x rnn_dim 44 | x_proc, _ = self.rnn(x_stacked) 45 | padded, lens = nn.utils.rnn.pad_packed_sequence(x_proc) 46 | seq = torch.cat([padded[:seqlen, i] for i, seqlen in enumerate(lens)]) 47 | assert not torch.isnan(seq.data.mean()), "NaN value in rnn output" 48 | 49 | x_out = self.linear(F.relu(F.dropout(seq))) 50 | 51 | return x_out 52 | 
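# Note (illustrative): ``x["offsets"]`` is expected to hold the boundaries of each
# customer sequence in the stacked batch. For example, ``offsets = [0, 3, 5]`` describes
# two sequences, ``x_emb[0:3]`` and ``x_emb[3:5]``, which ``forward`` packs with
# ``nn.utils.rnn.pack_sequence`` before running them through the GRU.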
-------------------------------------------------------------------------------- /docs/neural_lifetimes.models.nets.rst: -------------------------------------------------------------------------------- 1 | neural\_lifetimes.models.nets package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | neural\_lifetimes.models.nets.custom\_par module 8 | ------------------------------------------------ 9 | 10 | .. automodule:: neural_lifetimes.models.nets.custom_par 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | neural\_lifetimes.models.nets.custom\_tvae module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: neural_lifetimes.models.nets.custom_tvae 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | neural\_lifetimes.models.nets.embedder module 24 | --------------------------------------------- 25 | 26 | .. automodule:: neural_lifetimes.models.nets.embedder 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | neural\_lifetimes.models.nets.encoder\_decoder module 32 | ----------------------------------------------------- 33 | 34 | .. automodule:: neural_lifetimes.models.nets.encoder_decoder 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | neural\_lifetimes.models.nets.event\_model module 40 | ------------------------------------------------- 41 | 42 | .. automodule:: neural_lifetimes.models.nets.event_model 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | neural\_lifetimes.models.nets.heads module 48 | ------------------------------------------ 49 | 50 | .. automodule:: neural_lifetimes.models.nets.heads 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | neural\_lifetimes.models.nets.mlp module 56 | ---------------------------------------- 57 | 58 | .. automodule:: neural_lifetimes.models.nets.mlp 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. 
automodule:: neural_lifetimes.models.nets 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /tests/test_utils/test_data/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Dict 3 | import numpy as np 4 | 5 | import pytest 6 | 7 | from neural_lifetimes.utils.data import Tokenizer 8 | 9 | 10 | @pytest.fixture 11 | def data(): 12 | return { 13 | "CF1": np.array([1, 2, 3, 4]), 14 | "CF2": np.array([4, 3, 2, 1]), 15 | "Other_CF": np.array([0, 1, 0, 1]), 16 | "DF1": np.array(["level_1", "level_1", "level_2", "level_3"]), 17 | "Other_DF": np.array(["level_1", "level_1", "level_2", "level_3"]), 18 | "dates": np.array( 19 | [datetime(2022, 1, 1), datetime(2022, 2, 1), datetime(2022, 3, 1), datetime(2022, 1, 1)], 20 | dtype=np.datetime64, 21 | ), 22 | } 23 | 24 | 25 | @pytest.fixture 26 | def tokenizer(): 27 | return Tokenizer( 28 | continuous_features=["CF1", "CF2"], 29 | discrete_features=["DF1"], 30 | max_item_len=2, 31 | start_token_continuous=float("inf"), 32 | start_token_discrete="BigBang", 33 | start_token_other=-1, 34 | start_token_timestamp=datetime(1970, 1, 1, 0, 0, 0, 0), 35 | ) 36 | 37 | 38 | def test_constructor(tokenizer): 39 | assert tokenizer 40 | 41 | 42 | def test_call(data: Dict[str, np.ndarray], tokenizer: Tokenizer): 43 | tokenized = tokenizer(data) 44 | 45 | expected = { 46 | "CF1": np.array([np.inf, 3, 4]), 47 | "CF2": np.array([np.inf, 2, 1]), 48 | "Other_CF": np.array([-1, 0, 1]), 49 | "DF1": np.array(["BigBang", "level_2", "level_3"]), 50 | "Other_DF": np.array([-1, "level_2", "level_3"]), 51 | "dates": np.array( 52 | [datetime(1970, 1, 1, 0, 0, 0, 0), datetime(2022, 3, 1), datetime(2022, 1, 1)], 53 | dtype=np.datetime64, 54 | ), 55 | } 56 | 57 | for k in tokenized.keys(): 58 | assert np.all(tokenized[k] == expected[k]) 59 | 60 | 61 | def test_features(tokenizer: Tokenizer): 62 | feat = tokenizer.features 63 | assert feat == ["CF1", "CF2", "DF1"] 64 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/callbacks/projection_monitor.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | 4 | from .get_tensorboard_logger import _get_tensorboard_logger 5 | 6 | 7 | class MonitorProjection(pl.Callback): 8 | def __init__(self, max_batches: int = 100, mod: int = 5): 9 | super().__init__() 10 | self.max_batches = max_batches 11 | self.mod = mod 12 | 13 | def on_train_epoch_end(self, trainer, pl_module): 14 | if trainer.current_epoch % self.mod == 0: 15 | loader = trainer.datamodule.val_dataloader() 16 | logger = _get_tensorboard_logger(trainer) 17 | 18 | encoded_data = [] 19 | labels = [] 20 | for i, mini_batch in enumerate(loader): 21 | for k, v in mini_batch.items(): 22 | if isinstance(v, torch.Tensor): 23 | mini_batch[k] = v.to(pl_module.device) 24 | batch_offsets = mini_batch["offsets"][1:] - 1 25 | embedded, _ = pl_module.encode(mini_batch) 26 | # embedded = pl_module.net.fc_mu(pl_module.net.encoder(mini_batch)) 27 | emb_lastevent = embedded[batch_offsets] 28 | encoded_data.append(emb_lastevent) 29 | if "btyd_mode" in mini_batch.keys(): 30 | mode = mini_batch["btyd_mode"][batch_offsets].to(torch.uint8) 31 | labels.append(mode) 32 | 33 | if len(encoded_data) > self.max_batches: 34 | break 35 | 36 | data = torch.cat(encoded_data) 37 | if labels 
!= []: 38 | labels = torch.cat(labels) 39 | logger.experiment.add_embedding( 40 | data, 41 | metadata=labels, 42 | global_step=trainer.global_step, 43 | tag=f"Embedding Epoch {trainer.current_epoch}", 44 | ) 45 | return 46 | 47 | logger.experiment.add_embedding( 48 | data, 49 | global_step=trainer.global_step, 50 | tag=f"Embedding Epoch {trainer.current_epoch}", 51 | ) 52 | -------------------------------------------------------------------------------- /tests/test_modules/test_configure_optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | import torch.nn as nn 5 | 6 | from pytorch_lightning import LightningModule 7 | 8 | from neural_lifetimes.models.modules.configure_optimizers import configure_optimizers 9 | 10 | 11 | @pytest.fixture 12 | def model(): 13 | class TestModel(LightningModule): 14 | def __init__(self, *args: Any, **kwargs: Any) -> None: 15 | super().__init__(*args, **kwargs) 16 | self.net = nn.Linear(10, 10) 17 | 18 | def forward(self, x): 19 | return self.net(x) 20 | 21 | return TestModel() 22 | 23 | 24 | class Test_Configure_Optimizer: 25 | def test_default_args(self, model: LightningModule): 26 | config_dict = configure_optimizers(model.parameters(), lr=0.1) 27 | assert set(config_dict.keys()) == {"optimizer", "lr_scheduler"} 28 | 29 | def test_custom_args(self, model: LightningModule): 30 | config_dict = configure_optimizers( 31 | model.parameters(), 32 | lr=0.1, 33 | optimizer="SGD", 34 | optimizer_kwargs={"momentum": 0.5}, 35 | scheduler="MultiStepLR", 36 | scheduler_kwargs={"milestones": [1, 2, 3]}, 37 | lightning_scheduler_config={"frequency": 5}, 38 | ) 39 | assert config_dict["optimizer"].param_groups[0]["momentum"] == 0.5 40 | assert config_dict["lr_scheduler"]["frequency"] == 5 41 | assert list(config_dict["lr_scheduler"]["scheduler"].milestones) == [1, 2, 3] 42 | assert set(config_dict.keys()) == {"optimizer", "lr_scheduler"} 43 | 44 | def test_raises_optimizer_error(self, model: LightningModule): 45 | with pytest.raises(NotImplementedError) as excinfo: 46 | configure_optimizers(model.parameters(), lr=0.1, optimizer="Newton-Raphson") 47 | (msg,) = excinfo.value.args 48 | assert msg == 'Optimizer "Newton-Raphson" not implemented.' 49 | 50 | def test_raises_multistep_error(self, model: LightningModule): 51 | with pytest.raises(AssertionError) as excinfo: 52 | configure_optimizers(model.parameters(), lr=0.1, scheduler="MultiStepLR") 53 | (msg,) = excinfo.value.args 54 | assert msg == "MultiStepLR requires you to set `milestones` manually." 55 | -------------------------------------------------------------------------------- /docs/intro.rst: -------------------------------------------------------------------------------- 1 | Neural Lifetimes 2 | ^^^^^^^^^^^^^^^^ 3 | 4 | Introduction 5 | ------------ 6 | One of the most important problems a firm faces is the question of consumer 7 | value. Over the years, there have been many attempts to address this issue, with 8 | one of the most successful being the "Buy-Till-You-Die" class of RFM models 9 | that the `lifetimes `_ package is based on. A major pitfall of these models is their 10 | rigid assumptions about the distributions of hyperparameters, as well as their lack of 11 | granularity of analysis. Neural Lifetimes aims to address these issues with a novel 12 | implementation of recurrent neural networks.
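
In practice, using the library boils down to wiring a dataset into a data module, constructing a model, and calling ``run_model``. The sketch below is illustrative only and mirrors ``examples/train_btyd.py``; the ``dataset``, ``tokenizer``, ``encoder`` and ``target_transform`` objects are assumed to have been built as in that example::

    from neural_lifetimes import run_model
    from neural_lifetimes.data.datamodules import SequenceDataModule
    from neural_lifetimes.models.modules import VariationalEventModel

    datamodule = SequenceDataModule(
        dataset=dataset,                    # e.g. a BTYD dataset
        tokenizer=tokenizer,                # Tokenizer
        transform=encoder,                  # FeatureDictionaryEncoder
        target_transform=target_transform,  # TargetCreator
        test_size=0.2,
        batch_points=1024,
        min_points=1,
    )

    model = VariationalEventModel(
        feature_encoder_config=encoder.config_dict(),
        rnn_dim=256,
        emb_dim=256,
        drop_rate=0.5,
        bottleneck_dim=32,
        lr=0.001,
        target_cols=target_transform.cols,
        vae_sample_z=True,
        vae_sampling_scaler=1.0,
        vae_KL_weight=0.01,
    )

    run_model(datamodule, model, log_dir="logs", num_epochs=50)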
13 | 14 | Neural Lifetimes provides an easy way to train a neural network on your data (for more 15 | information on the neural net architecture, see `an overview of the model `_). Once you 16 | have trained a model, you can use the `Inference package `_ to predict customer 17 | actions, extend existing customer sequences, or simulate entirely new sequences. 18 | 19 | Neural Lifetimes is based on a few assumptions: 20 | 21 | 1. Customers interact with the firm at each timestep while they are “alive”. 22 | 2. At each timestep, there is a probability of the customer "dying". This probability is sampled from the latent space. 23 | 24 | Applications 25 | ~~~~~~~~~~~~ 26 | 27 | Common applications include: 28 | 29 | - Predicting customers' transactions (alive = actively purchasing, 30 | dead = not buying anymore). 31 | - Clustering your customers based on their demographic information and 32 | their behaviour. 33 | - Predicting the churn probability of your customers given their purchasing behaviour. 34 | 35 | 36 | Specific Examples 37 | ~~~~~~~~~~~~~~~~~ 38 | 39 | For some examples of what this library can do, see :doc:`Quickstart <_notebook_BTYD_visualisation>`. 40 | 41 | Installation 42 | ------------ 43 | 44 | :: 45 | 46 | pip install neural-lifetimes 47 | 48 | Documentation and tutorials 49 | --------------------------- 50 | 51 | :doc:`Official documentation ` 52 | 53 | Questions? Comments? Requests? 54 | ------------------------------ 55 | 56 | Please create an issue in `this 57 | repository `__. 58 | 59 | .. Use the actual url 60 | -------------------------------------------------------------------------------- /neural_lifetimes/losses/elbo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class VariationalEncoderDecoderLoss(nn.Module): 8 | # Adds a variational term to any kind of final loss 9 | # Assumes final loss is a log likelihood 10 | 11 | def __init__(self, fit_loss: nn.Module, reg_weight=1): 12 | super().__init__() 13 | self.final_loss = fit_loss 14 | self.reg_weight = reg_weight 15 | 16 | def forward(self, model_out, target_x) -> torch.Tensor: # changed order to follow pytorch convention 17 | 18 | fit_loss, losses_dict = self.final_loss(model_out, target_x) 19 | mu = model_out["mu"] 20 | std = model_out["std"] 21 | 22 | if self.reg_weight is not None: 23 | # see Appendix B from VAE paper: 24 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 25 | # https://arxiv.org/abs/1312.6114 26 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 27 | KLD_element = 1 + torch.log(std * std) - mu * mu - std * std 28 | KLD = -0.5 * torch.mean(KLD_element) 29 | my_loss = fit_loss + self.reg_weight * KLD 30 | else: # the case when we do no sampling 31 | my_loss = fit_loss 32 | KLD = 0.0 33 | 34 | losses_dict["kl_div"] = KLD 35 | losses_dict["model_fit"] = fit_loss 36 | losses_dict["total"] = my_loss 37 | losses_dict = {f"loss/{name}": loss for name, loss in losses_dict.items()} 38 | 39 | assert my_loss not in [-np.inf, np.inf], "Loss not finite!" 40 | assert not torch.isnan(my_loss), "Got a NaN loss" 41 | 42 | assert sum(losses_dict.values()) not in [-np.inf, np.inf], "Loss not finite!"
43 | assert not torch.isnan(sum(losses_dict.values())), "Got a NaN loss" 44 | 45 | return my_loss, losses_dict 46 | 47 | 48 | # class ELBOLoss(nn.Module): 49 | # def __init__(self, fit_loss: nn.Module, reg_weight) -> None: 50 | # super().__init__() 51 | # self.fit_loss = fit_loss 52 | # self.reg_weight = reg_weight 53 | 54 | # def forward(self, pred: torch.Tensor, target: torch.Tensor, latent: torch.Tensor) -> torch.Tensor: 55 | # fit_loss, losses_dict = self.final_loss(model_out, target_x) 56 | # mu = model_out["mu"] 57 | # std = model_out["std"] 58 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/clickhouse/clickhouse_ingest.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Callable, Dict, Sequence 3 | 4 | import numpy as np 5 | from clickhouse_driver import Client 6 | 7 | from neural_lifetimes.data.utils import normalize_types 8 | from neural_lifetimes.utils.aws import caching_query 9 | 10 | from .schema import make_clickhouse_schema 11 | 12 | 13 | def clickhouse_ingest( 14 | db_io, 15 | client: Client, 16 | insert_fn: Callable, 17 | daily_fn: Callable, 18 | start_date: datetime.date, 19 | end_date: datetime.date, 20 | data_dir: str, 21 | table_name: str, 22 | uid_name: str, 23 | time_col: str, 24 | high_granularity: Sequence = (), 25 | flush_table: bool = False, 26 | verbose: bool = False, 27 | ) -> None: 28 | 29 | # iterates over a range of dates and dumps the results into clickhouse 30 | 31 | assert start_date <= end_date, "Start date can not be after end date!" 32 | initialized = False 33 | this_date = start_date 34 | while this_date < end_date: 35 | fn = data_dir + f"events_{daily_fn.__name__}_{this_date}.h5" 36 | if verbose: 37 | print(fn) 38 | this_df = caching_query(fn, lambda: daily_fn(db_io, this_date)) 39 | if this_df is not None: 40 | 41 | if not initialized: 42 | client.execute("CREATE DATABASE IF NOT EXISTS events") 43 | if flush_table: 44 | client.execute("DROP TABLE IF EXISTS events.ras_slice") 45 | dtypes = this_df.dtypes 46 | schema = make_clickhouse_schema( 47 | dtypes, 48 | table_name, 49 | (uid_name, time_col), 50 | high_granularity=high_granularity, 51 | ) 52 | client.execute(schema) 53 | initialized = True 54 | 55 | this_df = normalize_types(this_df) 56 | this_df[time_col] = this_df[time_col].dt.strftime("%Y-%m-%d %H:%M:%S") 57 | insert_fn(this_df) 58 | 59 | this_date += datetime.timedelta(days=1) 60 | 61 | return dtypes 62 | 63 | 64 | def clickhouse_ranges(client: Client, discr_feat: Sequence[str], table_name: str) -> Dict[str, np.ndarray]: 65 | out = {f: np.array(client.execute(f"SELECT DISTINCT {f} from {table_name}")) for f in discr_feat} 66 | return out 67 | -------------------------------------------------------------------------------- /neural_lifetimes/models/nets/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.distributions.normal import Normal 4 | 5 | 6 | class VariationalEncoderDecoder(nn.Module): 7 | """ 8 | An implementation of variational encoder and decoder. 9 | 10 | Args: 11 | encoder(nn.Module): A model mapping batches of source domain to batches of vectors (batch x input_size) 12 | decoder(nn.Module): Model mapping latent z (batch x z_size) to target domain 13 | sample_z(bool): Whether to sample z = N(mu, std) or just take z=mu. Defaults to ``True``. 
14 | epsilon_std(float): Scaling factor for sampling, low values help convergence. Defaults to ``1.0``. 15 | 16 | Note: 17 | See https://github.com/mkusner/grammarVAE/issues/7 18 | """ 19 | 20 | def __init__( 21 | self, 22 | encoder: nn.Module, 23 | decoder: nn.Module, 24 | sample_z: bool = True, 25 | epsilon_std: float = 1.0, 26 | ): 27 | super().__init__() 28 | self.encoder = encoder 29 | self.decoder = decoder 30 | self.z_size = self.decoder.input_shape[1] 31 | self.sample_z = sample_z 32 | self.epsilon_std = epsilon_std 33 | 34 | self.fc_mu = nn.Linear(self.encoder.output_shape[-1], self.z_size) 35 | self.fc_log_var = nn.Linear(self.encoder.output_shape[-1], self.z_size) 36 | 37 | def forward(self, x): 38 | """ 39 | Encoder-decoder forward pass. 40 | 41 | Args: 42 | x: A batch of source domain data (batch x input_size) 43 | 44 | Returns: 45 | The data after being passed through the encoder and decoder. 46 | """ 47 | enc_out = self.encoder(x) 48 | mu = self.fc_mu(enc_out) 49 | log_var = self.fc_log_var(enc_out) 50 | 51 | eps = 0.0001 52 | std = torch.square(log_var) + eps # no longer log, why? 53 | 54 | if self.sample_z: 55 | z = Normal(mu, self.epsilon_std * std).rsample() 56 | else: 57 | z = mu 58 | 59 | output = self.decoder(z) 60 | output["mu"] = mu 61 | output["std"] = std 62 | output["sampled_z"] = z 63 | 64 | return output 65 | 66 | def load(self, weights_file): 67 | print("Trying to load model parameters from ", weights_file) 68 | self.load_state_dict(torch.load(weights_file)) 69 | self.eval() 70 | print("Success!") 71 | -------------------------------------------------------------------------------- /neural_lifetimes/metrics/kullback_leibler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Any, Tuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.distributions.kl import kl_divergence 6 | 7 | import torchmetrics 8 | 9 | 10 | class KullbackLeiblerDivergence(torchmetrics.Metric): 11 | def __init__(self, compute_on_step: Optional[bool] = None, **kwargs: Dict[str, Any]) -> None: 12 | super().__init__() 13 | self.add_state("preds", default=[], dist_reduce_fx="cat") 14 | self.add_state("target", default=[], dist_reduce_fx="cat") 15 | 16 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 17 | # expects both preds ant target to be log scale 18 | self.preds.append(preds) 19 | self.target.append(target) 20 | 21 | def compute(self) -> torch.Tensor: 22 | preds = torch.cat(self.preds) 23 | target = torch.cat(self.target) 24 | 25 | preds_logs_prob, target_log_prob = _histogram(preds, target) 26 | return F.kl_div(preds_logs_prob, target_log_prob, log_target=True, reduction="none").mean() 27 | 28 | 29 | class ParametricKullbackLeiblerDivergence(torchmetrics.Metric): 30 | def __init__( 31 | self, 32 | distribution: torch.distributions.Distribution, 33 | compute_on_step: Optional[bool] = None, 34 | **kwargs: Dict[str, Any] 35 | ) -> None: 36 | super().__init__() 37 | self.add_state("preds", default=[], dist_reduce_fx="cat") 38 | self.add_state("target", default=[], dist_reduce_fx="cat") 39 | self.distribution = distribution 40 | 41 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 42 | preds, target = preds.flatten(), target.flatten() 43 | 44 | assert preds.shape == target.shape 45 | assert preds.dim() == 1 46 | 47 | self.preds.append(preds) 48 | self.target.append(target) 49 | 50 | def compute(self) -> torch.Tensor: 51 | preds = 
torch.cat(self.preds) 52 | target = torch.cat(self.target) 53 | 54 | p = self.distribution(target) 55 | q = self.distribution(preds) 56 | kl_div = kl_divergence(p, q) 57 | 58 | return kl_div.mean() 59 | 60 | 61 | def _histogram(preds: torch.Tensor, target: torch.Tensor, nbins: int = 50, eps: float = 1.0e-6) -> Tuple[torch.Tensor]: 62 | assert len(preds) > 0 and len(target) > 0 63 | 64 | min_ = min(preds.min(), target.min()) 65 | max_ = max(preds.max(), target.max()) 66 | pred_log_prob = torch.log((torch.histc(preds, min=min_, max=max_) / len(preds)) + eps) 67 | target_log_prob = torch.log((torch.histc(target, min=min_, max=max_) / len(target)) + eps) 68 | return pred_log_prob, target_log_prob 69 | -------------------------------------------------------------------------------- /tests/test_modules/test_combined_embedder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from neural_lifetimes.models.nets import CombinedEmbedder 5 | from neural_lifetimes.utils.data import FeatureDictionaryEncoder 6 | 7 | 8 | @pytest.fixture 9 | def embed_dim(): 10 | return 64 11 | 12 | 13 | @pytest.fixture 14 | def drop_rate(): 15 | return 0.1 16 | 17 | 18 | class TestEmbedderConstruction: 19 | @staticmethod 20 | def _get_embedder(continuous_features, discrete_features, *args, **kwargs) -> CombinedEmbedder: 21 | encoder = FeatureDictionaryEncoder(continuous_features, discrete_features) 22 | return CombinedEmbedder(encoder, *args, **kwargs) 23 | 24 | # TODO The embedder should be able to handle only continuous or only discrete features. 25 | # One currently leads to a warning. 26 | @pytest.mark.parametrize("continuous_features", [[], ["FEAT_1", "FEAT_2"]]) 27 | @pytest.mark.parametrize("category_dict", [{}, {"CAT_1": [1, 4], "CAT_2": ["A", "B"]}]) 28 | def test_constructor(self, continuous_features, category_dict, embed_dim, drop_rate): 29 | self._get_embedder(continuous_features, category_dict, embed_dim, drop_rate) 30 | 31 | # TODO: empty category_dicts should not be allowed. e.g. {'CAT_1': []} 32 | @pytest.mark.xfail 33 | @pytest.mark.parametrize("category_dict", [{"CAT_1": []}]) 34 | def test_constructor_empty_set_cat_dict( 35 | self, continuous_features, category_dict, embed_dim, drop_rate, pre_encoded=True 36 | ): 37 | with pytest.raises(Exception): 38 | self._get_embedder(continuous_features, category_dict, embed_dim, drop_rate, pre_encoded) 39 | 40 | def test_parameter_dict(self, embed_dim, drop_rate): 41 | emb = self._get_embedder([], {}, embed_dim, drop_rate) 42 | pars = emb.build_parameter_dict() 43 | expected = { 44 | "embed_dim": embed_dim, 45 | "embedder_drop_rate": drop_rate, 46 | } 47 | assert pars == expected 48 | 49 | # TODO this should work. Somehow it doesnt. 
50 | @pytest.mark.xfail 51 | @pytest.mark.parametrize("continuous_features", [[], ["FEAT_1", "FEAT_2"]]) 52 | @pytest.mark.parametrize("category_dict", [[], {"CAT_1": [1, 4], "CAT_2": ["A", "B"]}]) 53 | @pytest.mark.parametrize("pre_encoded", [True, False]) 54 | def test_forward(self, continuous_features, category_dict, embed_dim, drop_rate, pre_encoded): 55 | emb = self._get_embedder(continuous_features, category_dict, embed_dim, drop_rate, pre_encoded) 56 | x = { 57 | "FEAT_1": torch.tensor([0.5, 0.1]), 58 | "FEAT_2": torch.tensor([3.5, 1]), 59 | "CAT_1": [1, 4], 60 | "CAT_2": ["a", "whatever"], 61 | } 62 | emb(x) 63 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "Neural Lifetimes" 22 | copyright = "2021, Wise" 23 | author = "Wise" 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "sphinx.ext.autodoc", 33 | "sphinx.ext.doctest", 34 | "sphinx.ext.napoleon", 35 | "nbsphinx", 36 | "sphinx.ext.mathjax", 37 | "nbsphinx_link", 38 | "sphinx_rtd_theme", 39 | ] 40 | 41 | # Have autodoc import torch but not actually run it? 42 | autodoc_mock_imports = ["torch", "pytorch_lightning"] 43 | nbsphinx_execute = 'never' 44 | nbsphinx_allow_errors = True 45 | 46 | # Napoleon settings 47 | napoleon_google_docstring = True 48 | napoleon_numpy_docstring = False 49 | napoleon_use_ivar = True 50 | napoleon_use_param = True 51 | napoleon_use_rtype = False 52 | napoleon_use_keyword = True 53 | napoleon_include_private_with_doc = False 54 | napoleon_include_special_with_doc = True 55 | 56 | # Add any paths that contain templates here, relative to this directory. 57 | templates_path = ["_templates"] 58 | 59 | # List of patterns, relative to source directory, that match files and 60 | # directories to ignore when looking for source files. 61 | # This pattern also affects html_static_path and html_extra_path. 62 | exclude_patterns = ["_build", "**.ipynb_checkpoints", "Thumbs.db", ".DS_Store", "venv"] 63 | 64 | 65 | # -- Options for HTML output ------------------------------------------------- 66 | 67 | # The theme to use for HTML and HTML Help pages. See the documentation for 68 | # a list of builtin themes. 69 | # 70 | html_theme = "sphinx_rtd_theme" 71 | 72 | # Add any paths that contain custom static files (such as style sheets) here, 73 | # relative to this directory. 
They are copied after the builtin static files, 74 | # so a file named "default.css" will overwrite the builtin "default.css". 75 | html_static_path = ["_static"] 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # example data 2 | examples/data/ 3 | 4 | # logs 5 | logs/ 6 | examples/logs/ 7 | 8 | # Editor configs 9 | .vscode/ 10 | .vs/* 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | .vscode/ 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | _build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | # docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | /storage/ 143 | 144 | 145 | *.DS_Store 146 | 147 | # saved models 148 | logs/* 149 | *.DS_Store 150 | 151 | .idea 152 | .envrc 153 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/date_arithmetic.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Union 3 | import numpy as np 4 | 5 | 6 | def date2time(date: datetime.date) -> datetime.datetime: 7 | """ 8 | Convert a date to a datetime by adding the minimum time. 
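    For example (illustrative values), ``date2time(datetime.date(2021, 5, 17))`` returns ``datetime.datetime(2021, 5, 17, 0, 0)``.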
9 | 10 | Args: 11 | date (datetime.date): The date to convert. 12 | 13 | Returns: 14 | datetime.datetime: The converted date. 15 | """ 16 | return datetime.datetime.combine(date, datetime.datetime.min.time()) 17 | 18 | 19 | DatetimeLike = Union[np.array, datetime.datetime] 20 | TimestampLike = Union[np.array, float, int] 21 | 22 | _conversion_factors = { 23 | "d": 1e6 * 60 * 60 * 24, 24 | "h": 1e6 * 60 * 60, 25 | "min": 1e6 * 60, 26 | "s": 1e6, 27 | "ms": 1000, 28 | } 29 | 30 | 31 | def datetime2float(t: DatetimeLike, unit: str = "h") -> TimestampLike: 32 | """Converts datetime-like objects to timestamps given a certain unit. 33 | 34 | Args: 35 | t (DatetimeLike): The datetime-like object, e.g. datetime.datetime or np.array["datetime64[us]"] 36 | unit (str, optional): Units for the time scale. supported are "d", "h", "min", "s", "ms". Defaults to "h". 37 | 38 | Raises: 39 | ValueError: When "unit" not supported. 40 | TypeError: when type of "t" not supported. 41 | 42 | Returns: 43 | TimestampLike: float or np.array[np.float32] with timestamps. 44 | """ 45 | if unit not in _conversion_factors: 46 | raise ValueError(f"Unit parameter '{unit}' not supported.") 47 | if isinstance(t, datetime.datetime): 48 | t = t.timestamp() * 1e6 49 | elif isinstance(t, np.ndarray): 50 | t = t.astype("datetime64[us]").astype(np.int64).astype(np.float32) 51 | else: 52 | raise TypeError(f"Datetime must be of type 'DatetimeLike'. Got '{type(t).__name__}'") 53 | # time is in microseconds 54 | t = t / _conversion_factors[unit] 55 | # time is in hours 56 | return t 57 | 58 | 59 | def float2datetime(t: TimestampLike, unit: str = "h") -> DatetimeLike: 60 | """Converts float values to datetime objects. 61 | 62 | Args: 63 | t (TimestampLike): numeric value with timestamp. 64 | unit (str, optional): Units for the time scale. supported are "d", "h", "min", "s", "ms". Defaults to "h". 65 | 66 | Raises: 67 | ValueError: When "unit" not supported. 68 | TypeError: when type of "t" not supported. 69 | 70 | Returns: 71 | DatetimeLike: The datetimes. 72 | """ 73 | if unit not in _conversion_factors: 74 | raise ValueError(f"Unit parameter '{unit}' not supported.") 75 | t = t * _conversion_factors[unit] 76 | if isinstance(t, (float, int)): 77 | t = datetime.datetime.fromtimestamp(t / 1e6) 78 | elif isinstance(t, np.ndarray): 79 | t = t.astype("datetime64[us]") 80 | else: 81 | raise TypeError(f"Datetime must be of type 'TimestampLike'. Got '{type(t).__name__}'") 82 | 83 | return t 84 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to neural-lifetimes! 2 | 3 | When contributing to this repository, please first discuss the change you wish to make via issue, or any other method with the owners of this repository before making a change. 4 | 5 | We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's: 6 | 7 | - Reporting a bug 8 | - Discussing the current state of the code 9 | - Submitting a fix 10 | - Proposing new features 11 | - Becoming a maintainer 12 | 13 | Please keep all your communication respectful. 14 | 15 | 16 | ## We Develop with GitHub 17 | We use GitHub to host code, to track public issues and feature requests from the community, as well as accept pull requests. 
18 | 19 | 20 | ## We Use [GitHub Flow](https://guides.github.com/introduction/flow/index.html), So All Code Changes Happen Through Pull Requests 21 | Pull requests are the best way to propose changes to the codebase (we use [GitHub Flow](https://guides.github.com/introduction/flow/index.html)). We actively welcome your pull requests: 22 | 23 | 1. Fork the repo and create your branch from `main`. 24 | 2. If you've added code that should be tested, add tests: unit and End-2-End if possible. 25 | 3. If you've added a new feature or changed the behavior of existing one, update the tests and the relevant documentation in [README.md](./README.md) and [online documentation code](./docs). 26 | 4. Ensure the test suite passes. 27 | 5. Make sure your code lints. 28 | 6. Issue that pull request! 29 | 30 | 31 | ## Any contributions you make will be under the Apache License Version 2.0. 32 | In short, when you submit code changes, your submissions are understood to be under the same [Apache License Version 2.0](./LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern. 33 | 34 | ## Report bugs using GitHub's [issues](https://github.com/transferwise/neural-lifetimes/issues) 35 | We use GitHub issues to track public bugs. Report a bug by [opening a new issue](https://github.com/transferwise/neural-lifetimes/issues/new); it's that easy! 36 | 37 | ## Write bug reports with detail, background and setup 38 | Here's [a great example from Craig Hockenberry](http://www.openradar.me/11905408) 39 | 40 | **Great Bug Reports** tend to have: 41 | 42 | - A quick summary and/or background 43 | - Steps to reproduce 44 | - Be specific! 45 | - Describe your source/target setup. 46 | - What you expected would happen 47 | - What actually happens 48 | - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) 49 | 50 | People *love* thorough bug reports. I'm not even kidding. 51 | 52 | ## Use a Consistent Coding Style 53 | 54 | * Tabs for indentation 55 | * Google docstring format for Python documentation. 56 | * Single quotes for string literals 57 | * We've started using SonarLint PyCharm plugin to detect code complexity among other issues to improve code quality. 58 | * You can try running `black` for autoformatting. You can do this with `make format` 59 | 60 | ## Versioning 61 | We use [Semantic versioning](https://semver.org/). 
62 | 63 | ## License 64 | By contributing, you agree that your contributions will be licensed under its Apache License Version 2.0 65 | -------------------------------------------------------------------------------- /tests/test_modules/test_variational_event_model.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | 5 | import numpy as np 6 | 7 | import pytest 8 | 9 | from neural_lifetimes import run_model 10 | from neural_lifetimes.data.datamodules.sequence_datamodule import SequenceDataModule 11 | from neural_lifetimes.data.datasets.btyd import BTYD, GenMode 12 | from neural_lifetimes.models.modules import VariationalEventModel 13 | from neural_lifetimes.utils.data import FeatureDictionaryEncoder, TargetCreator, Tokenizer 14 | 15 | from ..test_datasets.datamodels import EventprofilesDataModel 16 | 17 | DISCRETE_START_TOKEN = "StartToken" 18 | 19 | 20 | @pytest.fixture 21 | def log_dir() -> str: 22 | with TemporaryDirectory() as f: 23 | yield f 24 | 25 | 26 | @pytest.fixture 27 | def data_dir() -> str: 28 | return str(Path(__file__).parents[2] / "examples") 29 | 30 | 31 | @pytest.mark.slow 32 | class TestVariationalEventModel: 33 | @pytest.mark.parametrize("vae_sample_z", ("True", "False")) 34 | def test_btyd(self, vae_sample_z: bool, log_dir: str, data_dir: str) -> None: 35 | # create btyd data and dependent modules 36 | data_model = EventprofilesDataModel() 37 | 38 | dataset = BTYD.from_modes( 39 | modes=[GenMode(a=1, b=3, r=1, alpha=15)], 40 | num_customers=100, 41 | seq_gen_dynamic=True, 42 | start_date=datetime.datetime(2019, 1, 1, 0, 0, 0), 43 | start_limit_date=datetime.datetime(2019, 6, 15, 0, 0, 0), 44 | end_date=datetime.datetime(2021, 1, 1, 0, 0, 0), 45 | data_dir=data_dir, 46 | continuous_features=data_model.cont_feat, 47 | discrete_features=data_model.discr_feat, 48 | ) 49 | 50 | discr_values = dataset.get_discrete_feature_values("") 51 | 52 | tokenizer = Tokenizer( 53 | data_model.cont_feat, 54 | discr_values, 55 | 100, 56 | np.nan, 57 | "", 58 | datetime.datetime(1970, 1, 1, 0, 0, 0), 59 | np.nan, 60 | ) 61 | transform = FeatureDictionaryEncoder(data_model.cont_feat, discr_values) 62 | 63 | target_transform = TargetCreator(cols=(data_model.target_cols + data_model.cont_feat + data_model.discr_feat)) 64 | datamodule = SequenceDataModule( 65 | dataset=dataset, 66 | tokenizer=tokenizer, 67 | transform=transform, 68 | target_transform=target_transform, 69 | test_size=0.2, 70 | batch_points=256, 71 | min_points=1, 72 | ) 73 | 74 | # create model 75 | model = VariationalEventModel( 76 | feature_encoder_config=transform.config_dict(), 77 | rnn_dim=256, 78 | emb_dim=256, 79 | drop_rate=0.5, 80 | bottleneck_dim=32, 81 | lr=0.001, 82 | target_cols=target_transform.cols, 83 | vae_sample_z=vae_sample_z, 84 | vae_sampling_scaler=1.0, 85 | vae_KL_weight=0.01, 86 | ) 87 | 88 | run_model( 89 | datamodule, 90 | model, 91 | log_dir=log_dir, 92 | num_epochs=2, 93 | val_check_interval=2, 94 | limit_val_batches=2, 95 | ) 96 | 97 | # TODO add tests for: different rnn_dims, drop_rates, bottleneck_dims. 
98 | # TODO add failure tests for negative learning rate, negative sacler and negative KL weights 99 | -------------------------------------------------------------------------------- /neural_lifetimes/data/utils.py: -------------------------------------------------------------------------------- 1 | """A series of functions to convert dataframes to pytorch tensors.""" 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | 8 | from pandas.api.types import is_float_dtype, is_integer_dtype, is_numeric_dtype 9 | 10 | 11 | def torchify(x: Dict[str, np.ndarray]) -> Dict[str, Union[np.ndarray, torch.Tensor]]: 12 | out = {} 13 | for key, val in x.items(): 14 | try: 15 | out[key] = torch.from_numpy(val) 16 | except TypeError: 17 | out[key] = val 18 | return out 19 | 20 | 21 | # TODO: remove if no longer needed 22 | # torchify: copies dtype and references data. this function: forces dtype and copies data. 23 | def torchify_old(x: Dict[str, np.ndarray]) -> Dict[str, Union[np.ndarray, torch.Tensor]]: 24 | """ 25 | Cast all numerical elements to tensors, forcing float64 to float32. 26 | 27 | Args: 28 | x (Dict[str, np.ndarray]): the dictionary to cast 29 | device (torch.device): the device to cast to 30 | 31 | Returns: 32 | Dict[str, Union[np.ndarray, torch.Tensor]]: the dictionary with all numerical elements cast to tensors 33 | """ 34 | out = {} 35 | for k, v in x.items(): 36 | if is_numeric_dtype(v.dtype): 37 | nice_type = normalize_types(v.dtype) 38 | out[k] = torch.tensor(v.astype(nice_type)) 39 | else: 40 | out[k] = v 41 | return out 42 | 43 | 44 | # TODO: remove if no longer needed 45 | def normalize_types(x: np.dtype): 46 | """ 47 | Find the type to be cast to for a given dtype: np.float32 for floats, np.int64 for ints. 48 | 49 | Note the conversion of ints is required for pytorch loss functions. 50 | 51 | Args: 52 | x (np.dtype): the dtype for which to determine cast type. 53 | """ 54 | if is_float_dtype(x): 55 | return np.float32 56 | elif is_integer_dtype(x): 57 | return np.int64 # torch loss functions require that 58 | else: 59 | return x 60 | 61 | 62 | # TODO: remove if no longer needed 63 | def remove_isolated_transactions(df: pd.DataFrame, uid_name: str): 64 | """ 65 | Delete users which only have 1 transaction. 66 | 67 | Args: 68 | df (pd.DataFrame): The dataframe to clean. 69 | uid_name (str): The name of the user id column. 70 | 71 | Returns: 72 | pd.DataFrame: The cleaned dataframe. 73 | """ 74 | counts = df[uid_name].value_counts() 75 | 76 | out = df[df[uid_name].isin(set(counts[counts > 1].index))] 77 | return out 78 | 79 | 80 | def detorch(x: Union[torch.Tensor, np.ndarray]) -> np.ndarray: 81 | """ 82 | Convert a torch tensor to a numpy array. 83 | 84 | Args: 85 | x (Union[torch.Tensor, np.ndarray]): The tensor to convert. 86 | 87 | Returns: 88 | np.ndarray: The numpy array. 89 | """ 90 | if isinstance(x, torch.Tensor): 91 | return x.cpu().detach().numpy() 92 | else: 93 | return x 94 | 95 | 96 | # TODO: remove if no longer needed 97 | def batch_to_dataframe(x: Dict[str, Union[torch.Tensor, np.ndarray]]) -> pd.DataFrame: 98 | """ 99 | Convert a dictionary of tensors or numpy arrays to a pandas dataframe. Drops the "offsets" column. 100 | 101 | Args: 102 | x (Dict[str, Union[torch.Tensor, np.ndarray]]): The dictionary to convert. 103 | 104 | Returns: 105 | pd.DataFrame: The dataframe. 
106 | """ 107 | return pd.DataFrame.from_dict({k: detorch(v) for k, v in x.items() if k != "offsets"}) 108 | -------------------------------------------------------------------------------- /examples/train_btyd.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | import pytorch_lightning as pl 7 | 8 | from neural_lifetimes import run_model 9 | from neural_lifetimes.data.datamodules import SequenceDataModule 10 | from neural_lifetimes.data.datasets.btyd import BTYD, GenMode 11 | from neural_lifetimes.models.modules import VariationalEventModel 12 | from neural_lifetimes.utils.data import FeatureDictionaryEncoder, Tokenizer, TargetCreator 13 | from examples import eventsprofiles_datamodel 14 | 15 | 16 | LOG_DIR = str(Path(__file__).parent / 'logs') 17 | data_dir = str(Path(__file__).parent.absolute()) 18 | 19 | START_TOKEN_DISCR = "" 20 | COLS = eventsprofiles_datamodel.target_cols + eventsprofiles_datamodel.cont_feat + eventsprofiles_datamodel.discr_feat 21 | 22 | if __name__ == "__main__": 23 | pl.seed_everything(9473) 24 | 25 | btyd_dataset = BTYD.from_modes( 26 | modes=[ 27 | GenMode(a=2, b=10, r=5, alpha=10), 28 | GenMode(a=2, b=10, r=10, alpha=600), 29 | ], 30 | num_customers=1000, 31 | mode_ratios=[2.5, 1], # generate equal number of transactions from each mode 32 | seq_gen_dynamic=False, 33 | start_date=datetime.datetime(2019, 1, 1, 0, 0, 0), 34 | start_limit_date=datetime.datetime(2019, 6, 15, 0, 0, 0), 35 | end_date=datetime.datetime(2021, 1, 1, 0, 0, 0), 36 | data_dir=data_dir, 37 | continuous_features=eventsprofiles_datamodel.cont_feat, 38 | discrete_features=eventsprofiles_datamodel.discr_feat, 39 | track_statistics=True, 40 | ) 41 | 42 | btyd_dataset[:] 43 | print(f"Expected Num Transactions per mode: {btyd_dataset.expected_num_transactions_from_priors()}") 44 | print(f"Expected p churn per mode: {btyd_dataset.expected_p_churn_from_priors()}") 45 | print(f"Expected time interval per mode: {btyd_dataset.expected_time_interval_from_priors()}") 46 | print(f"Truncated sequences: {btyd_dataset.truncated_sequences}") 47 | 48 | btyd_dataset.plot_tracked_statistics().show() 49 | 50 | discrete_values = btyd_dataset.get_discrete_feature_values( 51 | start_token=START_TOKEN_DISCR, 52 | ) 53 | 54 | encoder = FeatureDictionaryEncoder(eventsprofiles_datamodel.cont_feat, discrete_values) 55 | 56 | tokenizer = Tokenizer( 57 | continuous_features=eventsprofiles_datamodel.cont_feat, 58 | discrete_features=discrete_values, 59 | start_token_continuous=np.nan, 60 | start_token_discrete=START_TOKEN_DISCR, 61 | start_token_other=np.nan, 62 | max_item_len=100, 63 | start_token_timestamp=datetime.datetime(1970, 1, 1, 1, 0), 64 | ) 65 | 66 | target_transform = TargetCreator(cols=COLS) 67 | 68 | datamodule = SequenceDataModule( 69 | dataset=btyd_dataset, 70 | tokenizer=tokenizer, 71 | transform=encoder, 72 | target_transform=target_transform, 73 | test_size=0.2, 74 | batch_points=1024, 75 | min_points=1, 76 | ) 77 | 78 | net = VariationalEventModel( 79 | feature_encoder_config=encoder.config_dict(), 80 | rnn_dim=256, 81 | emb_dim=256, 82 | drop_rate=0.5, 83 | bottleneck_dim=32, 84 | lr=0.001, 85 | target_cols=COLS, 86 | vae_sample_z=True, 87 | vae_sampling_scaler=1.0, 88 | vae_KL_weight=0.01, 89 | ) 90 | 91 | run_model( 92 | datamodule, 93 | net, 94 | log_dir=LOG_DIR, 95 | num_epochs=50, 96 | val_check_interval=10, 97 | limit_val_batches=20, 98 | gradient_clipping=0.0000001, 99 | ) 100 | 
-------------------------------------------------------------------------------- /tests/test_modules/test_ib_event_model.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | 5 | import numpy as np 6 | 7 | import pytest 8 | 9 | from neural_lifetimes import run_model 10 | from neural_lifetimes.data.datamodules.sequence_datamodule import SequenceDataModule 11 | from neural_lifetimes.data.datasets.btyd import BTYD, GenMode 12 | from neural_lifetimes.models.modules import InformationBottleneckEventModel 13 | from neural_lifetimes.utils.data import FeatureDictionaryEncoder, TargetCreator, Tokenizer 14 | 15 | from ..test_datasets.datamodels import EventprofilesDataModel 16 | 17 | DISCRETE_START_TOKEN = "StartToken" 18 | 19 | 20 | @pytest.fixture 21 | def log_dir() -> str: 22 | with TemporaryDirectory() as f: 23 | yield f 24 | 25 | 26 | @pytest.fixture 27 | def data_dir() -> str: 28 | return str(Path(__file__).parents[2] / "examples") 29 | 30 | 31 | @pytest.mark.slow 32 | class TestInformationBottleneckEventModel: 33 | @pytest.mark.parametrize("vae_sample_z", ("True", "False")) 34 | def test_btyd(self, vae_sample_z: bool, log_dir: str, data_dir: str) -> None: 35 | # create btyd data and dependent modules 36 | data_model = EventprofilesDataModel() 37 | 38 | dataset = BTYD.from_modes( 39 | modes=[GenMode(a=1, b=3, r=1, alpha=15)], 40 | num_customers=100, 41 | seq_gen_dynamic=True, 42 | start_date=datetime.datetime(2019, 1, 1, 0, 0, 0), 43 | start_limit_date=datetime.datetime(2019, 6, 15, 0, 0, 0), 44 | end_date=datetime.datetime(2021, 1, 1, 0, 0, 0), 45 | data_dir=data_dir, 46 | continuous_features=data_model.cont_feat, 47 | discrete_features=data_model.discr_feat, 48 | ) 49 | 50 | discr_values = dataset.get_discrete_feature_values("") 51 | 52 | tokenizer = Tokenizer( 53 | data_model.cont_feat, 54 | discr_values, 55 | 100, 56 | np.nan, 57 | "", 58 | datetime.datetime(1970, 1, 1, 0, 0, 0), 59 | np.nan, 60 | ) 61 | transform = FeatureDictionaryEncoder(data_model.cont_feat, discr_values) 62 | 63 | target_transform = TargetCreator(cols=(data_model.target_cols + data_model.cont_feat + data_model.discr_feat)) 64 | datamodule = SequenceDataModule( 65 | dataset=dataset, 66 | tokenizer=tokenizer, 67 | transform=transform, 68 | target_transform=target_transform, 69 | test_size=0.2, 70 | batch_points=256, 71 | min_points=1, 72 | ) 73 | 74 | loss_cfg = { 75 | "n_cold_steps": 100, 76 | "n_warmup_steps": 100, 77 | "n_target_weight": 0.001, 78 | "n_eigen": None, 79 | "n_eigen_threshold": None, 80 | } 81 | 82 | # create model 83 | model = InformationBottleneckEventModel( 84 | feature_encoder_config=transform.config_dict(), 85 | rnn_dim=256, 86 | emb_dim=256, 87 | drop_rate=0.5, 88 | bottleneck_dim=32, 89 | lr=0.001, 90 | target_cols=target_transform.cols, 91 | loss_cfg=loss_cfg, 92 | ) 93 | 94 | run_model( 95 | datamodule, 96 | model, 97 | log_dir=log_dir, 98 | num_epochs=2, 99 | val_check_interval=2, 100 | limit_val_batches=2, 101 | ) 102 | 103 | # TODO add tests for: different rnn_dims, drop_rates, bottleneck_dims. 
104 | # TODO add failure tests for negative learning rate, negative scaler and negative KL weights 105 | -------------------------------------------------------------------------------- /neural_lifetimes/models/nets/embedder.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from neural_lifetimes.utils.data import FeatureDictionaryEncoder 8 | 9 | 10 | class CombinedEmbedder(nn.Module): 11 | """ 12 | An embedder for continuous and discrete features. Is an ``nn.Module``. 13 | 14 | Args: 15 | feature_encoder (FeatureDictionaryEncoder): the feature encoder providing the continuous feature names and 16 | the discrete feature levels (their sizes are read via ``feature_size``). 17 | embed_dim (int): embedding dimension 18 | drop_rate (float): dropout rate. Defaults to ``0.0``. 19 | 20 | """ 21 | 22 | def __init__( 23 | self, 24 | feature_encoder: FeatureDictionaryEncoder, 25 | embed_dim: int, 26 | drop_rate: float = 0.0, 27 | ): 28 | super().__init__() 29 | self.embed_dim = embed_dim 30 | self.drop_rate = drop_rate 31 | self.encoder = feature_encoder 32 | 33 | # create the continuous feature encoding, with one hidden layer for good measure 34 | num_cf = len(self.continuous_features) 35 | self.c1 = nn.Linear(num_cf, 2 * num_cf) 36 | self.c2 = nn.Linear(2 * num_cf, embed_dim) 37 | self.combine = nn.Linear(len(self.discrete_features) + 1, 1) 38 | self.layer_norm = nn.LayerNorm(normalized_shape=embed_dim) 39 | 40 | # create the discrete feature encoding 41 | self.enc = {} 42 | self.emb = nn.ModuleDict() 43 | for name in self.discrete_features: 44 | self.emb[name] = nn.Embedding(self.encoder.feature_size(name), embed_dim) 45 | 46 | self.output_shape = [None, embed_dim] 47 | 48 | @property 49 | def continuous_features(self): 50 | return self.encoder.continuous_features 51 | 52 | @property 53 | def discrete_features(self): 54 | return self.encoder.discrete_features 55 | 56 | def build_parameter_dict(self) -> Dict[str, Any]: 57 | """Return a dict of parameters. 58 | 59 | Returns: 60 | Dict[str, Any]: Parameters of the embedder 61 | """ 62 | return { 63 | "embed_dim": self.embed_dim, 64 | "embedder_drop_rate": self.drop_rate, 65 | } 66 | 67 | def forward(self, x: Dict[str, torch.Tensor]): 68 | combined_emb = [] 69 | 70 | # batch x num_cont_features 71 | cf = torch.stack([x[f] for f in self.continuous_features], dim=1) 72 | cf[cf.isnan()] = 0 # TODO do not do if nan is start token 73 | 74 | cf_emb = F.dropout(F.relu(self.c1(cf)), self.drop_rate, self.training) 75 | # cf_emb = torch.clip(cf_emb, -65000, 65000) 76 | assert not cf_emb.isnan().any(), "First Linear Layer for continuous features contains `NaN` values." 77 | 78 | # batch x embed_dim 79 | cf_emb = F.dropout(F.relu(self.c2(cf_emb)), self.drop_rate, self.training) 80 | assert not cf_emb.isnan().any(), "Second Linear Layer for continuous features contains `NaN` values." 81 | combined_emb.append(cf_emb) 82 | 83 | # out = torch.clip(out, -65000, 65000) 84 | for name in self.discrete_features: 85 | disc_emb = F.dropout(self.emb[name](x[name]), self.drop_rate, self.training) 86 | assert not disc_emb.isnan().any(), f"Embedding for discrete feature '{name}' contains `NaN` values."
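            # each discrete feature contributes a (batch, embed_dim) embedding; after this loop, these are stacked
            # with the continuous embedding along a trailing dimension and mixed back into a single
            # (batch, embed_dim) tensor by `self.combine`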
87 | combined_emb.append(disc_emb) 88 | 89 | combined_emb = torch.stack(combined_emb, dim=-1) 90 | out = self.combine(combined_emb).squeeze() 91 | assert not out.isnan().any(), "Combined Embeddings for all features contain `NaN` values." 92 | 93 | # out = self.layer_norm(out) # TODO try removing this again once all features are properly normalized 94 | # assert not out.isnan().any(), "Normalized Embeddings contain `NaN` values." 95 | return out 96 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Dict 3 | 4 | from dataclasses import dataclass 5 | 6 | import numpy as np 7 | 8 | 9 | @dataclass 10 | class Tokenizer: 11 | """A callable class to tokenize your dictionary-based batches. 12 | 13 | The Tokenizer left-truncates sequences and appends start tokens. 14 | 15 | Args: 16 | continuous_features (List[str]): A list containing the names of the continuous features. 17 | discrete_features (List[str]): A list containing the names of the discrete features. 18 | max_item_len (int): The maximum length to which the sequence should be truncated. 19 | The tokenizer performs left-truncation. The length of returned sequences will be ``max_item_len + 1``. 20 | start_token_continuous (np.float32): The start token for variables specified in ``continuous_features``. 21 | start_token_discrete (str): The start token for variables specified in ``discrete_features``. 22 | start_token_timestamp (datetime.datetime): The start token for variables with data type ``np.datetime64``. 23 | start_token_other (np.float32): The start token for variables not specified in ``continuous_features``, 24 | `discrete_features`` or of type ``np.datetime64``. 25 | 26 | Attributes: 27 | continuous_features (List[str]): A list containing the names of the continuous features. 28 | discrete_features (List[str]): A list containing the names of the discrete features. 29 | features (List[str]): A list containing the names of both continuous and discrete features. 30 | max_item_len (int): The maximum length to which the sequence should be truncated. 31 | The tokenizer performs left-truncation. The length of returned sequences will be ``max_item_len + 1``. 32 | start_token_continuous (np.float32): The start token for variables specified in ``continuous_features``. 33 | start_token_discrete (str): The start token for variables specified in ``discrete_features``. 34 | start_token_timestamp (datetime.datetime): The start token for variables with numpy dtype of kind ``datetime``, 35 | i.e. ``dtype.kind == 'M'``. 36 | start_token_other (np.float32): The start token for variables not specified in ``continuous_features``, 37 | `discrete_features`` or of type ``np.datetime64``. 38 | """ 39 | 40 | continuous_features: List[str] 41 | discrete_features: List[str] 42 | max_item_len: int 43 | start_token_continuous: np.float32 44 | start_token_discrete: str 45 | start_token_timestamp: datetime 46 | start_token_other: np.float32 47 | 48 | def __call__(self, x: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 49 | """Tokenizes and truncates the data ``x``. 50 | 51 | Args: 52 | x (Dict[str, np.ndarray]): The raw data dictionary. 53 | 54 | Returns: 55 | Dict[str, np.ndarray]: The transformed data. 
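        Example (an illustrative sketch; the feature names, values and the ``"<start>"`` token are hypothetical):

            >>> tok = Tokenizer(
            ...     continuous_features=["amount"],
            ...     discrete_features=["currency"],
            ...     max_item_len=100,
            ...     start_token_continuous=np.nan,
            ...     start_token_discrete="<start>",
            ...     start_token_timestamp=datetime(1970, 1, 1),
            ...     start_token_other=np.nan,
            ... )
            >>> out = tok({"amount": np.arange(150.0), "currency": np.array(["card"] * 150)})
            >>> len(out["amount"])  # left-truncated to the last 100 events, then one start token prepended
            101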
56 | """ 57 | # truncate sequence 58 | x = {k: v[-(self.max_item_len) :] for k, v in x.items()} 59 | 60 | # add start tokens 61 | for k, v in x.items(): 62 | if k in self.features or k in ["t", "dt"]: 63 | if k in self.discrete_features: 64 | x[k] = np.append([self.start_token_discrete], v) 65 | else: 66 | x[k] = np.append([self.start_token_continuous], v) 67 | else: 68 | # numpy dtype kind "M" is any datetime object 69 | if v.dtype.kind == "M": 70 | x[k] = np.append(np.array([self.start_token_timestamp], dtype=np.datetime64), v) 71 | else: 72 | x[k] = np.append([self.start_token_other], v) 73 | return x 74 | 75 | @property 76 | def features(self) -> List[str]: 77 | return self.continuous_features + list(self.discrete_features) 78 | -------------------------------------------------------------------------------- /examples/train_btyd_information_bottleneck.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | import pytorch_lightning as pl 7 | 8 | from neural_lifetimes import run_model 9 | from neural_lifetimes.data.datamodules import SequenceDataModule 10 | from neural_lifetimes.data.datasets.btyd import BTYD, GenMode 11 | from neural_lifetimes.models.modules import InformationBottleneckEventModel 12 | from neural_lifetimes.utils.data import FeatureDictionaryEncoder, Tokenizer, TargetCreator 13 | from examples import eventsprofiles_datamodel 14 | 15 | 16 | LOG_DIR = str(Path(__file__).parent / "logs") 17 | data_dir = str(Path(__file__).parent.absolute()) 18 | 19 | START_TOKEN_DISCR = "" 20 | COLS = eventsprofiles_datamodel.target_cols + eventsprofiles_datamodel.cont_feat + eventsprofiles_datamodel.discr_feat 21 | 22 | if __name__ == "__main__": 23 | pl.seed_everything(9473) 24 | 25 | btyd_dataset = BTYD.from_modes( 26 | modes=[ 27 | GenMode(a=2, b=10, r=5, alpha=10), 28 | GenMode(a=2, b=10, r=10, alpha=600), 29 | ], 30 | num_customers=1000, 31 | mode_ratios=[2.5, 1], # generate equal number of transactions from each mode 32 | seq_gen_dynamic=False, 33 | start_date=datetime.datetime(2019, 1, 1, 0, 0, 0), 34 | start_limit_date=datetime.datetime(2019, 6, 15, 0, 0, 0), 35 | end_date=datetime.datetime(2021, 1, 1, 0, 0, 0), 36 | data_dir=data_dir, 37 | continuous_features=eventsprofiles_datamodel.cont_feat, 38 | discrete_features=eventsprofiles_datamodel.discr_feat, 39 | track_statistics=True, 40 | ) 41 | 42 | btyd_dataset[:] 43 | print(f"Expected Num Transactions per mode: {btyd_dataset.expected_num_transactions_from_priors()}") 44 | print(f"Expected p churn per mode: {btyd_dataset.expected_p_churn_from_priors()}") 45 | print(f"Expected time interval per mode: {btyd_dataset.expected_time_interval_from_priors()}") 46 | print(f"Truncated sequences: {btyd_dataset.truncated_sequences}") 47 | 48 | btyd_dataset.plot_tracked_statistics().show() 49 | 50 | discrete_values = btyd_dataset.get_discrete_feature_values( 51 | start_token=START_TOKEN_DISCR, 52 | ) 53 | 54 | encoder = FeatureDictionaryEncoder(eventsprofiles_datamodel.cont_feat, discrete_values) 55 | 56 | tokenizer = Tokenizer( 57 | continuous_features=eventsprofiles_datamodel.cont_feat, 58 | discrete_features=discrete_values, 59 | start_token_continuous=np.nan, 60 | start_token_discrete=START_TOKEN_DISCR, 61 | start_token_other=np.nan, 62 | max_item_len=100, 63 | start_token_timestamp=datetime.datetime(1970, 1, 1, 1, 0), 64 | ) 65 | 66 | target_transform = TargetCreator(cols=COLS) 67 | 68 | datamodule = SequenceDataModule( 69 | 
dataset=btyd_dataset, 70 | tokenizer=tokenizer, 71 | transform=encoder, 72 | target_transform=target_transform, 73 | test_size=0.2, 74 | batch_points=1024, 75 | min_points=1, 76 | ) 77 | 78 | loss_cfg = { 79 | "n_cold_steps": 100, 80 | "n_warmup_steps": 100, 81 | "target_weight": 0.001, 82 | "n_eigen": None, 83 | "n_eigen_threshold": None, 84 | } 85 | 86 | net = InformationBottleneckEventModel( 87 | feature_encoder_config=encoder.config_dict(), 88 | rnn_dim=256, 89 | emb_dim=256, 90 | drop_rate=0.5, 91 | bottleneck_dim=32, 92 | lr=0.001, 93 | target_cols=COLS, 94 | encoder_noise=1e-6, 95 | loss_cfg=loss_cfg, 96 | ) 97 | 98 | run_model( 99 | datamodule, 100 | net, 101 | log_dir=LOG_DIR, 102 | num_epochs=50, 103 | val_check_interval=10, 104 | limit_val_batches=20, 105 | gradient_clipping=0.0000001, 106 | ) 107 | -------------------------------------------------------------------------------- /neural_lifetimes/models/modules/configure_optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | from torch import nn 4 | from torch.optim import Adam, SGD 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR 6 | 7 | 8 | def configure_optimizers( 9 | parameters: nn.Module, 10 | lr: float, 11 | optimizer: str = "Adam", 12 | optimizer_kwargs: Optional[Dict[str, Any]] = None, 13 | scheduler: str = "ReduceLROnPlateau", 14 | scheduler_kwargs: Optional[Dict[str, Any]] = None, 15 | lightning_scheduler_config: Optional[Dict[str, Any]] = None, 16 | ) -> Dict[str, Any]: 17 | """Configures optimizers and LR schedulers for any ``LightningModule``. 18 | 19 | Args: 20 | parameters (nn.Module): The model parameters to optimize, e.g. ``nn.Module.parameters()``. 21 | lr (float): The initial learning rate at which to start training. 22 | optimizer (str, optional): The optimizer to use. Defaults to "Adam". 23 | optimizer_kwargs (Optional[Dict[str, Any]], optional): Additional arguments to initialise optimizer. 24 | Defaults to None. 25 | scheduler (str, optional): The scheduler to use. Defaults to "ReduceLROnPlateau". 26 | scheduler_kwargs (Optional[Dict[str, Any]], optional): Additional arguments to initialise scheduler. 27 | Defaults to None. 28 | lightning_scheduler_config (Optional[Dict[str, Any]], optional): Arguments to overwrite the default 29 | ``lightning_scheduler_config``. Defaults to None. 30 | 31 | Raises: 32 | NotImplementedError: The specified ``optimizer`` is not implemented. 33 | AssertionError: When using ``MultiStepLR``, a list of milestones is required. 34 | 35 | Returns: 36 | Dict[str, Any]: The PyTorch-Lightning dictionary specifying the optimizer configuration. It has two elements: 37 | ``optimizer`` sets the optimizer and ``lr_scheduler`` contains a lightning scheduler configuration dictionary.
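    Example (an illustrative sketch; ``model`` stands for any ``nn.Module``, and the returned dictionary is in the
        format expected by ``LightningModule.configure_optimizers``)::

            from neural_lifetimes.models.modules.configure_optimizers import configure_optimizers

            opt_config = configure_optimizers(
                model.parameters(),
                lr=1e-3,
                optimizer="Adam",
                scheduler="MultiStepLR",
                scheduler_kwargs={"milestones": [10, 20]},  # MultiStepLR requires milestones
            )
            # opt_config["optimizer"] is the Adam instance,
            # opt_config["lr_scheduler"] is the lightning scheduler configuration dictionary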
38 | """ 39 | # check inputs 40 | optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs 41 | scheduler_kwargs = {} if scheduler_kwargs is None else scheduler_kwargs 42 | lightning_scheduler_config = {} if lightning_scheduler_config is None else lightning_scheduler_config 43 | 44 | # configure optimizer 45 | if optimizer == "Adam": 46 | opt = Adam(parameters, lr=lr, **optimizer_kwargs) 47 | elif optimizer == "SGD": 48 | opt = SGD(parameters, lr=lr, **optimizer_kwargs) 49 | else: 50 | raise NotImplementedError(f'Optimizer "{optimizer}" not implemented.') 51 | 52 | # configure scheduler 53 | lightning_scheduler_config = dict( 54 | { 55 | "interval": "epoch", 56 | "frequency": 200, 57 | "monitor": "val_loss/total", 58 | "strict": True, 59 | "name": f"lr-{optimizer}-{scheduler}", 60 | }, 61 | **lightning_scheduler_config, 62 | ) 63 | 64 | if scheduler == "ReduceLROnPlateau": 65 | scheduler_kwargs = dict( 66 | { 67 | "optimizer": opt, 68 | "mode": "min", 69 | "patience": 2, 70 | "verbose": True, 71 | }, 72 | **scheduler_kwargs, 73 | ) 74 | scheduler = ReduceLROnPlateau(**scheduler_kwargs) 75 | elif scheduler == "MultiStepLR": 76 | assert "milestones" in scheduler_kwargs, "MultiStepLR requires you to set `milestones` manually." 77 | scheduler_kwargs = dict( 78 | { 79 | "optimizer": opt, 80 | "verbose": True, 81 | }, 82 | **scheduler_kwargs, 83 | ) 84 | scheduler = MultiStepLR(**scheduler_kwargs) 85 | elif scheduler == "None": 86 | scheduler = None 87 | else: 88 | raise NotImplementedError(f'Scheduler "{scheduler}" not implemented.') 89 | 90 | lightning_scheduler_config["scheduler"] = scheduler 91 | return {"optimizer": opt, "lr_scheduler": lightning_scheduler_config} 92 | -------------------------------------------------------------------------------- /tests/test_utils/test_date_arithmetic.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | import numpy as np 4 | 5 | import pytest 6 | from neural_lifetimes.utils import datetime2float, float2datetime 7 | from neural_lifetimes.utils.date_arithmetic import _conversion_factors 8 | 9 | 10 | @pytest.fixture 11 | def datetimes(): 12 | return [datetime(2000 + i, i, 2 * i) for i in range(1, 13)] 13 | 14 | 15 | @pytest.fixture 16 | def timestamps(): 17 | return [ 18 | 271776.0, 19 | 281328.0, 20 | 290808.0, 21 | 300383.0, 22 | 309911.0, 23 | 319463.0, 24 | 328991.0, 25 | 338567.0, 26 | 348119.0, 27 | 357647.0, 28 | 367200.0, 29 | 376752.0, 30 | ] 31 | 32 | 33 | # tolerance in hours 34 | @pytest.fixture 35 | def error_tolerance(): 36 | return 2 37 | 38 | 39 | class Test_Base: 40 | @staticmethod 41 | def test_datetime2float(datetimes, timestamps, error_tolerance): 42 | res = [datetime2float(dt) for dt in datetimes] 43 | assert all([(r - t) < error_tolerance for r, t in zip(res, timestamps)]) 44 | 45 | @staticmethod 46 | def test_float2datetime(datetimes, timestamps, error_tolerance): 47 | res = [float2datetime(ts) for ts in timestamps] 48 | assert all([(r - t) < timedelta(hours=error_tolerance) for r, t in zip(res, datetimes)]) 49 | 50 | @staticmethod 51 | def test_both(datetimes, error_tolerance): 52 | res = [float2datetime(datetime2float(dt)) for dt in datetimes] 53 | assert all([(r - t) < timedelta(hours=error_tolerance) for r, t in zip(res, datetimes)]) 54 | 55 | 56 | class Test_Numpy: 57 | @staticmethod 58 | def test_datetime2float(datetimes, timestamps, error_tolerance): 59 | datetimes = np.array(datetimes, dtype="datetime64[us]") 60 | timestamps = 
np.array(timestamps) 61 | res = datetime2float(datetimes) 62 | assert np.all(np.isclose(res, timestamps, rtol=0, atol=error_tolerance)) 63 | 64 | @staticmethod 65 | def test_float2datetime(datetimes, timestamps, error_tolerance): 66 | datetimes = np.array(datetimes, dtype="datetime64[us]") 67 | timestamps = np.array(timestamps) 68 | res = float2datetime(timestamps) 69 | assert np.all((res - datetimes) < timedelta(hours=error_tolerance)) 70 | 71 | @staticmethod 72 | def test_both(datetimes, timestamps, error_tolerance): 73 | datetimes = np.array(datetimes, dtype="datetime64[us]") 74 | timestamps = np.array(timestamps) 75 | res = float2datetime(datetime2float(datetimes)) 76 | assert np.all((res - datetimes) < timedelta(hours=error_tolerance)) 77 | 78 | 79 | class Test_AcrossTypes: 80 | @staticmethod 81 | def test_numpy_base(datetimes, error_tolerance): 82 | res_base = np.array([datetime2float(dt) for dt in datetimes]) 83 | np_datetimes = np.array(datetimes, dtype="datetime64[us]") 84 | res_np = datetime2float(np_datetimes) 85 | assert np.all(np.isclose(res_base, res_np, rtol=0, atol=error_tolerance)) 86 | 87 | 88 | class Test_Units: 89 | @staticmethod 90 | def test_tofloat(datetimes, error_tolerance): 91 | cf = _conversion_factors 92 | base = np.array([datetime2float(dt) for dt in datetimes]) 93 | res = {unit: np.array([datetime2float(dt, unit=unit) for dt in datetimes]) for unit in cf.keys()} 94 | for unit, arr in res.items(): 95 | assert np.all(np.isclose(arr * cf[unit] / cf["h"], base, rtol=0, atol=error_tolerance)) 96 | 97 | @staticmethod 98 | def test_todatetime(datetimes, timestamps, error_tolerance): 99 | cf = _conversion_factors 100 | timestamps = {unit: [ts / cf[unit] * cf["h"] for ts in timestamps] for unit in cf.keys()} 101 | res = {unit: [float2datetime(ts, unit=unit) for ts in tstamp] for unit, tstamp in timestamps.items()} 102 | for arr in res.values(): 103 | assert all([(r - t) < timedelta(hours=error_tolerance) for r, t in zip(arr, datetimes)]) 104 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class WeightScheduler(ABC): 5 | """Abstract baseclass for the weight schedulers. 6 | 7 | Classes of this template require a weight attribute returning the current weight and a step method. 8 | """ 9 | 10 | @property 11 | @abstractmethod 12 | def weight(self) -> float: 13 | pass 14 | 15 | @abstractmethod 16 | def step(self) -> float: 17 | pass 18 | 19 | 20 | class LinearWarmupScheduler(WeightScheduler): 21 | def __init__(self, n_cold_steps: int, n_warmup_steps: int, target_weight: float) -> None: 22 | """The LinearWarmupScheduler implements a linear increase in the weight to start up training. 23 | 24 | Often model training benefit from starting training with no penalty and slowing increasing it after a few 25 | first steps. This is called warmup. In particular, the information bottleneck requires this. The linear 26 | scheduler returns penalty weight 0 until step ``n_cold_steps`` is reached and then linearly interpolates for 27 | ``n_warmup_steps`` until the ``target_weight`` is reached. 28 | 29 | Args: 30 | n_cold_steps (int): Number of steps to conduct with 0 weight. 31 | n_warmup_steps (int): Number of steps to warmup weight. 32 | target_weight (float): Target weight. 
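        Example:
            With ``n_cold_steps=2``, ``n_warmup_steps=4`` and ``target_weight=1.0``, ``weight`` takes the values
            0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, ... over successive steps.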
33 | """ 34 | super().__init__() 35 | self.n_cold_steps = n_cold_steps 36 | self.n_warmup_steps = n_warmup_steps 37 | self.target_weight = target_weight 38 | self._step = 1 39 | 40 | @property 41 | def weight(self) -> float: 42 | """Get the current weight. 43 | 44 | Returns: 45 | float: Weight. 46 | """ 47 | if self._step <= self.n_cold_steps: 48 | return 0.0 49 | elif self._step < self.n_warmup_steps + self.n_cold_steps: 50 | return (self._step - self.n_cold_steps) / self.n_warmup_steps * self.target_weight 51 | else: 52 | return self.target_weight 53 | 54 | def step(self) -> float: 55 | """Increases the step.""" 56 | self._step += 1 57 | return self.weight 58 | 59 | 60 | class ExponentialWarmupScheduler(WeightScheduler): 61 | def __init__(self, n_cold_steps: int, n_warmup_steps: int, target_weight: float, gamma: float) -> None: 62 | """The ExpoentialWarmupScheduler implements an exponential increase in the weight to start up training. 63 | 64 | Often model training benefit from starting training with no penalty and slowing increasing it after a few 65 | first steps. This is called warmup. In particular, the information bottleneck requires this. The exponential 66 | scheduler returns penalty weight 0 until step ``n_cold_steps`` is reached and then increases the weight every 67 | step by factor ``gamma`` for ``n_warmup_steps`` steps, when ``target_weight`` is reached. 68 | 69 | Args: 70 | n_cold_steps (int): Number of cold start steps with 0 weight. 71 | n_warmup_steps (int): Number of steps to warmup weight. 72 | target_weight (float): Target weight. 73 | gamma (float): Factor to increase learning rate. 74 | """ 75 | super().__init__() 76 | self.n_cold_steps = n_cold_steps 77 | self.n_warmup_steps = n_warmup_steps 78 | self.target_weight = target_weight 79 | self.gamma = gamma 80 | self._step = 1 81 | 82 | @property 83 | def weight(self) -> float: 84 | """Get the current weight. 85 | 86 | Returns: 87 | float: Weight. 88 | """ 89 | if self._step <= self.n_cold_steps: 90 | return 0.0 91 | elif self._step < self.n_warmup_steps + self.n_cold_steps: 92 | return self.target_weight / self.gamma ** (self.n_warmup_steps + self.n_cold_steps - self._step) 93 | else: 94 | return self.target_weight 95 | 96 | def step(self) -> float: 97 | """Increases the step.""" 98 | self._step += 1 99 | return self.weight 100 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/data/encoder_with_unknown.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | import numpy as np 3 | from sklearn.preprocessing import OrdinalEncoder 4 | 5 | import torch 6 | 7 | 8 | # TODO What does this function actually do? it doesn't normalize the data 9 | def normalize(x): 10 | """ 11 | Normalize the data. Only 1 encoding is handled at a time. 12 | 13 | Args: 14 | x: The data to be normalized. 15 | 16 | Returns: 17 | np.array: The normalized data. 18 | 19 | Note: 20 | Since we are using np.array, it may lead to errors with GPUs. 21 | """ 22 | try: 23 | if isinstance(x, torch.Tensor): 24 | x = x.detach().cpu().numpy() 25 | except Exception: 26 | pass 27 | 28 | x = np.array(x) # TODO Why copy the data? 29 | if len(x.shape) == 1: 30 | x = x[:, None] # TODO is this the same as np.expand_dims() ? 31 | 32 | assert x.shape[1] == 1 # only handle one encoding at a time 33 | return x 34 | 35 | 36 | # TODO The encoder truncates the "" token when the original dtype is shorted. This could be better. 
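# Usage sketch (hypothetical values):
#   enc = OrdinalEncoderWithUnknown()
#   enc.fit(np.array(["a", "b", "c"]))
#   enc.transform(np.array(["b", "z"]))       # -> [[2], [0]]; the unseen "z" encodes to 0
#   enc.inverse_transform(np.array([2, 0]))   # -> "b" plus the unknown token (possibly truncated, see TODO above)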
37 | class OrdinalEncoderWithUnknown(OrdinalEncoder): 38 | """An ordinal encoder that encodes unknown values as 0. 39 | 40 | The OrdinalEncoderWithUnknown works with unknown values. If an unknown value is passed into ``transform()``, 41 | it will be encoded as ``0``. The ``inverse_transform`` maps ``0`` to ````. 42 | The encoder acts as a surjective mapping. 43 | 44 | Attributes: 45 | levels (np.ndarray): The raw levels that can be decoded to. Includes the ```` token. 46 | 47 | Basis: 48 | ``sklearn.preprocessing.OrdinalEncoder`` 49 | """ 50 | 51 | # uses 0 to encode unknown values 52 | def transform(self, x: np.ndarray) -> np.ndarray: 53 | """Transforms the data into encoded format. 54 | 55 | Args: 56 | x (np.ndarray): The raw data. 57 | 58 | Returns: 59 | np.ndarray: The encoded data with dtype ``int64``. 60 | """ 61 | x = normalize(x) 62 | out = np.zeros(x.shape).astype(int) 63 | # The below was the old implementation 64 | # known = np.array([xx[0] in self.categories_[0] for xx in x]) 65 | # this should give identical results but faster 66 | known = np.isin(x, self.categories_[0]).reshape(-1) 67 | if any(known): 68 | out[known] = super(OrdinalEncoderWithUnknown, self).transform(np.array(x)[known]) + 1 69 | return out 70 | 71 | def fit(self, x: np.ndarray) -> None: 72 | """Fits the encoder. 73 | 74 | Args: 75 | x (np.ndarray): The raw data array. 76 | 77 | Returns: 78 | _type_: The encoded data array. 79 | """ 80 | x = normalize(x) 81 | return super().fit(x) 82 | 83 | def __len__(self) -> int: 84 | return len(self.categories_[0]) + 1 85 | 86 | def inverse_transform(self, x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: 87 | """Transforms the data into the decoded format. 88 | 89 | Unknown values will be decoded as "". 90 | 91 | Args: 92 | x (Union[np.ndarray, torch.Tensor]): The encoded data. 93 | 94 | Returns: 95 | np.ndarray: The decoded data. The dtype will match the dtype of the array past into the ``fit`` method. 96 | 97 | Note: 98 | If the string dtype passed into ``fit`` too short for "", this token will be truncated. 99 | """ 100 | if isinstance(x, torch.Tensor): 101 | x = x.detach().cpu().numpy() 102 | out = np.full_like(x, "", dtype=self.categories_[0].dtype) 103 | known = x > 0 104 | if any(known): 105 | out[known] = ( 106 | super(OrdinalEncoderWithUnknown, self) 107 | .inverse_transform(np.expand_dims(x[known], axis=-1) - 1) 108 | .reshape(-1) 109 | ) 110 | return out 111 | 112 | @property 113 | def levels(self): 114 | return np.concatenate((np.array([""]).astype(self.categories_[0].dtype), self.categories_[0])) 115 | 116 | def to_dict(self) -> Dict[str, int]: 117 | """Converts the encoder into a dictionary structure mapping raw to encoded values. Includes unknown token. 118 | 119 | Returns: 120 | Dict[str, int]: Dictionary of form ``raw: encoded``. 
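        Example:
            For an encoder fitted on ``["a", "b"]`` this returns a mapping of the form
            ``{<unknown token>: 0, "a": 1, "b": 2}``.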
121 | """ 122 | return {level: self.transform(np.array([level])).item() for level in self.levels} 123 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/score_estimators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | 4 | # code taken from https://github.com/zhouyiji/MIGE/blob/master/toy/mige_toy.ipynb 5 | 6 | 7 | class ScoreEstimator(ABC): 8 | @abstractmethod 9 | def __init__(self): 10 | pass 11 | 12 | def rbf_kernel(self, x1, x2, kernel_width): 13 | return torch.exp( 14 | -torch.sum(torch.mul((x1 - x2), (x1 - x2)), dim=-1) / (2 * torch.mul(kernel_width, kernel_width)) 15 | ) 16 | 17 | def gram(self, x1, x2, kernel_width): 18 | x_row = torch.unsqueeze(x1, -2) 19 | x_col = torch.unsqueeze(x2, -3) 20 | kernel_width = kernel_width[..., None, None] 21 | return self.rbf_kernel(x_row, x_col, kernel_width) 22 | 23 | def grad_gram(self, x1, x2, kernel_width): 24 | x_row = torch.unsqueeze(x1, -2) 25 | x_col = torch.unsqueeze(x2, -3) 26 | kernel_width = kernel_width[..., None, None] 27 | G = self.rbf_kernel(x_row, x_col, kernel_width) 28 | diff = (x_row - x_col) / (kernel_width[..., None] ** 2) 29 | G_expand = torch.unsqueeze(G, -1) 30 | grad_x2 = G_expand * diff 31 | grad_x1 = G_expand * (-diff) 32 | return G, grad_x1, grad_x2 33 | 34 | def heuristic_kernel_width(self, x_samples, x_basis): 35 | n_samples = x_samples.size()[-2] 36 | n_basis = x_basis.size()[-2] 37 | x_samples_expand = torch.unsqueeze(x_samples, -2) 38 | x_basis_expand = torch.unsqueeze(x_basis, -3) 39 | pairwise_dist = torch.sqrt( 40 | torch.sum(torch.mul(x_samples_expand - x_basis_expand, x_samples_expand - x_basis_expand), dim=-1) 41 | ) 42 | k = n_samples * n_basis // 2 43 | top_k_values = torch.topk(torch.reshape(pairwise_dist, [-1, n_samples * n_basis]), k=k)[0] 44 | kernel_width = torch.reshape(top_k_values[:, -1], x_samples.size()[:-2]) 45 | return kernel_width.detach() 46 | 47 | @abstractmethod 48 | def compute_gradients(self, samples, x=None): 49 | pass 50 | 51 | 52 | class SpectralScoreEstimator(ScoreEstimator): 53 | def __init__(self, n_eigen=None, eta=None, n_eigen_threshold=None): 54 | self._n_eigen = n_eigen 55 | self._eta = eta 56 | self._n_eigen_threshold = n_eigen_threshold 57 | super().__init__() 58 | 59 | def nystrom_ext(self, samples, x, eigen_vectors, eigen_values, kernel_width): 60 | M = torch.tensor(samples.size()[-2]).to(samples.device) 61 | Kxq = self.gram(x, samples, kernel_width) 62 | ret = torch.sqrt(M.float()) * torch.matmul(Kxq, eigen_vectors) 63 | ret *= 1.0 / torch.unsqueeze(eigen_values, dim=-2) 64 | return ret 65 | 66 | def compute_gradients(self, samples, x=None): 67 | if x is None: 68 | kernel_width = self.heuristic_kernel_width(samples, samples) 69 | x = samples 70 | else: 71 | _samples = torch.cat([samples, x], dim=-2) 72 | kernel_width = self.heuristic_kernel_width(_samples, _samples) 73 | 74 | M = samples.size()[-2] 75 | Kq, grad_K1, grad_K2 = self.grad_gram(samples, samples, kernel_width) 76 | if self._eta is not None: 77 | Kq += self._eta * torch.eye(M) 78 | 79 | # eigen_values, eigen_vectors = torch.symeig(Kq, eigenvectors=True, upper=True) torch==1.3 80 | eigen_values, eigen_vectors = torch.linalg.eigh(Kq, UPLO="U") 81 | if (self._n_eigen is None) and (self._n_eigen_threshold is not None): 82 | eigen_arr = torch.mean(torch.reshape(eigen_values, [-1, M]), dim=0) 83 | 84 | eigen_arr = torch.flip(eigen_arr, [-1]) 85 | eigen_arr /= 
torch.sum(eigen_arr) 86 | eigen_cum = torch.cumsum(eigen_arr, dim=-1) 87 | eigen_lt = torch.lt(eigen_cum, self._n_eigen_threshold) 88 | self._n_eigen = torch.sum(eigen_lt) 89 | if self._n_eigen is not None: 90 | eigen_values = eigen_values[..., -self._n_eigen :] 91 | eigen_vectors = eigen_vectors[..., -self._n_eigen :] 92 | eigen_ext = self.nystrom_ext(samples, x, eigen_vectors, eigen_values, kernel_width) 93 | grad_K1_avg = torch.mean(grad_K1, dim=-3) 94 | M = torch.tensor(M).to(samples.device) 95 | beta = ( 96 | -torch.sqrt(M.float()) 97 | * torch.matmul(torch.transpose(eigen_vectors, -1, -2), grad_K1_avg) 98 | / torch.unsqueeze(eigen_values, -1) 99 | ) 100 | grads = torch.matmul(eigen_ext, beta) 101 | self._n_eigen = None 102 | return grads 103 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/clickhouse/schema.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import Sequence, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from clickhouse_driver import Client 7 | from sqlalchemy import create_engine 8 | 9 | type_dict = { 10 | np.dtype("int64"): "Int64", 11 | np.dtype("int32"): "Int32", 12 | np.dtype("O"): "String", 13 | np.dtype("float64"): "Float64", 14 | np.dtype("float32"): "Float32", 15 | np.dtype(" Dict[str, torch.Tensor]: 93 | 94 | final = [i.item() - 1 for i in batch["offsets"][1:]] 95 | out = {k: v[final] for k, v in model_out.items()} 96 | 97 | return batch, out 98 | 99 | 100 | def remove_initial_event(batch: Dict[str, torch.Tensor], model_out: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 101 | 102 | first = [i.item() for i in batch["offsets"][:-1]] # first elements #batch["offsets"][:-1].tolist() 103 | inds = [k for k in range(batch["offsets"][-1]) if k not in first] 104 | out = {k: v[inds] for k, v in model_out.items()} 105 | 106 | return batch, out 107 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/callbacks/git.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from typing import Any, Dict 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.loggers import TensorBoardLogger 7 | 8 | 9 | class GitInformationLogger(pl.Callback): 10 | def __init__(self, prefix="git/", verbose: bool = True) -> None: 11 | """Collects information on the state of the git repository for reproducibility. 12 | 13 | Args: 14 | prefix (str, optional): The prefix used for logging the contained information. Defaults to "git/". 15 | verbose (bool, optional): When true, warns when running from a dirty repository. Defaults to True. 16 | """ 17 | super().__init__() 18 | self.prefix = prefix 19 | self.verbose = verbose 20 | 21 | # check whether run started in git repo: 22 | out = subprocess.run(["git", "status", "-s"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 23 | if out.stderr.startswith(b"fatal: not a git repository"): 24 | raise SystemError( 25 | "The entrypoint for the script is not inside a git repository. Consider running `git init` " 26 | + f"in you shell or remove the {self.__class__.__name__} callback." 
27 | ) 28 | 29 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 30 | self._log(trainer, pl_module) 31 | return super().on_fit_start(trainer, pl_module) 32 | 33 | def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 34 | self._log(trainer, pl_module) 35 | return super().on_test_start(trainer, pl_module) 36 | 37 | def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 38 | self._log(trainer, pl_module) 39 | return super().on_predict_start(trainer, pl_module) 40 | 41 | def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 42 | self._log(trainer, pl_module) 43 | return super().on_validation_start(trainer, pl_module) 44 | 45 | def _log(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 46 | git_information = self.data_dict() 47 | if isinstance(pl_module.logger, TensorBoardLogger): 48 | raise AssertionError( 49 | "This callback does not support the ``TensorBoardLogger``, due to its inability to process multiple " 50 | + "calls to log hyperparameters." 51 | ) 52 | else: 53 | pl_module.logger.log_hyperparams(git_information) 54 | 55 | def data_dict(self) -> Dict[str, Any]: 56 | """Collects git information including git hash, commit date, short hash, status, branch and hostname. 57 | 58 | Returns: 59 | Dict[str, Any]: Git repository information. 60 | """ 61 | git_information = { 62 | "short_hash": get_git_short_hash(), 63 | "hash": get_git_hash(), 64 | "commit_date": get_git_commit_date(), 65 | "status": get_git_repository_status(), 66 | "branch": get_git_branch(), 67 | "host": get_hostname(), 68 | } 69 | git_information = {f"{self.prefix}{key}": _decode(value) for key, value in git_information.items()} 70 | self._print_git_status_warning(git_information[f"{self.prefix}status"]) 71 | return git_information 72 | 73 | def _print_git_status_warning(self, git_status: str): 74 | if not self.verbose: 75 | return 76 | if git_status == "dirty": 77 | print("\nWARNING: The git repository contains untracked files or uncommited changes.\n") 78 | if git_status == "unknown": 79 | print( 80 | "\nWARNING: The git repository status could not be determined. 
The git repository may contain " 81 | + "untracked files or uncommited changes.\n" 82 | ) 83 | 84 | 85 | def get_git_hash(): 86 | return subprocess.check_output(["git", "log", "-n", "1", "--pretty=tformat:%H"]).strip() 87 | 88 | 89 | def get_git_short_hash(): 90 | return subprocess.check_output(["git", "log", "-n", "1", "--pretty=tformat:%h"]).strip() 91 | 92 | 93 | def get_git_commit_date(): 94 | return subprocess.check_output(["git", "log", "-n", "1", "--pretty=tformat:%ci", "--date=short"]).strip() 95 | 96 | 97 | def get_git_repository_status(): 98 | out = subprocess.run(["git", "status", "-s"], stdout=subprocess.PIPE) 99 | return "dirty" if out.stdout else "clean" 100 | 101 | 102 | def get_git_branch(): 103 | return subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).strip() 104 | 105 | 106 | def get_hostname(): 107 | return os.uname()[1] 108 | 109 | 110 | def _decode(x): 111 | try: 112 | x = x.decode("utf-8") 113 | except AttributeError: 114 | pass 115 | return x 116 | -------------------------------------------------------------------------------- /tests/test_datasets/test_pandas_dataset.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from enum import Enum 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pytest 8 | import vaex 9 | 10 | from neural_lifetimes.data.datasets import PandasSequenceDataset 11 | 12 | TIME_COL = "transaction_time" 13 | UID_COL = "uid" 14 | 15 | 16 | # TODO xfail until PandasSequenceDataset fixes 17 | pytestmark = pytest.mark.xfail 18 | 19 | 20 | class DataframeType(Enum): 21 | pandas = "pandas" 22 | vaex = "vaex" 23 | 24 | 25 | @pytest.fixture 26 | def start_date(): 27 | return datetime(2020, 1, 1) 28 | 29 | 30 | @pytest.fixture 31 | def uids(): 32 | return list(range(1, 11)) 33 | 34 | 35 | @pytest.fixture 36 | def df(start_date, uids): 37 | rows = [] 38 | for uid in uids: 39 | # Each user will get the same number of transactions as their uid int 40 | for i in range(uid): 41 | rows.append( 42 | { 43 | TIME_COL: start_date + timedelta(days=i), 44 | UID_COL: uid, 45 | } 46 | ) 47 | return pd.DataFrame(rows).sort_values(by=[TIME_COL]) 48 | 49 | 50 | def _construct_dataset( 51 | df: pd.DataFrame, 52 | asof_time: Optional[datetime] = None, 53 | min_items_per_uid: int = 1, 54 | df_type: DataframeType = DataframeType.pandas, 55 | ) -> PandasSequenceDataset: 56 | if df_type == DataframeType.vaex: 57 | df = vaex.from_pandas(df) 58 | return PandasSequenceDataset( 59 | df, 60 | UID_COL, 61 | TIME_COL, 62 | asof_time=asof_time, 63 | min_items_per_uid=min_items_per_uid, 64 | ) 65 | 66 | 67 | class TestConstruction: 68 | @pytest.mark.parametrize("df_type", DataframeType) 69 | def test_sets_static_data(self, df, df_type): 70 | dataset = _construct_dataset(df, df_type=df_type) 71 | 72 | assert dataset.id_column == UID_COL 73 | assert dataset.time_col == TIME_COL 74 | 75 | @pytest.mark.parametrize("df_type", DataframeType) 76 | def test_filters_by_min_items_per_uid(self, df, uids, df_type): 77 | # We should lose uids 1 and 2 78 | dataset = _construct_dataset(df, min_items_per_uid=3, df_type=df_type) 79 | 80 | unique_uids = dataset.df[UID_COL].unique() 81 | expected = set(uids) - {1, 2} 82 | 83 | assert set(unique_uids) == expected 84 | assert len(dataset.ids) == len(expected) 85 | assert set(dataset.ids) == expected 86 | 87 | @pytest.mark.parametrize("df_type", DataframeType) 88 | def test_filters_by_as_of_time(self, df, 
start_date, uids, df_type): 89 | dataset = _construct_dataset(df, asof_time=start_date, df_type=df_type) 90 | 91 | unique_dates = dataset.df[TIME_COL].unique() 92 | unique_uids = dataset.df[UID_COL].unique() 93 | 94 | assert len(dataset.df) == len(uids) 95 | 96 | assert len(unique_dates) == 1 97 | assert unique_dates[0] == np.datetime64(start_date) 98 | 99 | # Should now have one transaction per uid 100 | assert len(unique_uids) == len(uids) 101 | assert set(unique_uids) == set(uids) 102 | 103 | @pytest.mark.parametrize("df_type", DataframeType) 104 | def test_len(self, df, uids, df_type): 105 | dataset = _construct_dataset(df, df_type=df_type) 106 | result = len(dataset) 107 | 108 | assert result == len(set(uids)) 109 | 110 | 111 | class TestGetItem: 112 | @pytest.mark.parametrize("df_type", DataframeType) 113 | def test_returns_length_of_sequence(self, df, df_type): 114 | uid = 5 115 | sequence_length = uid # As defined in the fixtures 116 | dataset = _construct_dataset(df, df_type=df_type) 117 | 118 | result = dataset.get_seq_len(uid - 1) 119 | 120 | assert result == sequence_length 121 | 122 | 123 | class TestGetBulk: 124 | @pytest.mark.parametrize("df_type", DataframeType) 125 | def test_gets_transactions_for_uid_sequence(self, df, df_type): 126 | uids = np.array([4, 8]) 127 | indices = uids - np.ones_like(uids) 128 | dataset = _construct_dataset(df, df_type=df_type) 129 | 130 | result = dataset[list(indices)] 131 | 132 | assert len(result) == len(uids) 133 | 134 | for uid, sequence in zip(uids, result): 135 | assert len(sequence["t"]) == uid # Sequence length == uids from the fixtures 136 | 137 | @pytest.mark.parametrize("df_type", DataframeType) 138 | def test_works_for_a_single_transaction(self, df, df_type): 139 | uid = 1 140 | index = uid - 1 141 | dataset = _construct_dataset(df, df_type=df_type) 142 | 143 | result = dataset[index] 144 | 145 | assert len(result) == 1 146 | 147 | sequence = result[0] 148 | assert len(sequence["t"]) == 1 # Sequence length == uids from the fixtures 149 | -------------------------------------------------------------------------------- /tests/test_datasets/test_sliceable_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Sequence 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from neural_lifetimes.data.datasets.sequence_dataset import SliceableDataset 7 | 8 | 9 | class ExampleSliceableData(SliceableDataset): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | # length = 10 13 | # num_var = 3 14 | # self.data = np.arange(0, length*num_var).reshape((length, num_var)) 15 | 16 | def __len__(self): 17 | # return self.data.shape[0] 18 | return 10 19 | 20 | def _load_batch(self, s: Sequence[int]) -> Sequence[Dict[str, np.ndarray]]: 21 | return [{"s": np.array(s)}] # odd construct to maintain type safety 22 | # data = self.data[s,:] 23 | # return {str(i): data[:,i] for i in range(data.shape[1])} 24 | 25 | 26 | class IncompleteSliceableData(SliceableDataset): 27 | def __init__(self) -> None: 28 | super().__init__() 29 | length = 10 30 | num_var = 3 31 | self.data = np.arange(0, length * num_var).reshape((length, num_var)) 32 | 33 | 34 | class TestSliceableData: 35 | @staticmethod 36 | def _get_expected_slice_result(s: slice) -> List[int]: 37 | return list(range(10))[s] 38 | 39 | @staticmethod 40 | def _output_to_list(o: Sequence[Dict[str, np.ndarray]]) -> List[int]: 41 | return list(o[0]["s"]) 42 | 43 | @staticmethod 44 | def get_dataset(well_defined: bool) -> 
SliceableDataset: 45 | if well_defined: 46 | return ExampleSliceableData() 47 | else: 48 | return IncompleteSliceableData() 49 | 50 | @staticmethod 51 | def get_test_slices(name): 52 | # slices to test negative slicing 53 | # if name == 'closed': 54 | # return [slice(0,1), slice(-2,-1), slice(0,4), slice(-9,-3)] 55 | # if name == 'left-open': 56 | # return [slice(None,4), slice(None,-3)] 57 | # if name == 'right-open': 58 | # return [slice(4, None), slice(-3, None)] 59 | # if name == 'open': 60 | # return [slice(None, None), slice(None, None, -1)] 61 | # if name == 'inverse': 62 | # return [slice(4, 0, -1), slice(4, 0, 1), slice(5,2,2), slice(5,2,-2)] 63 | # if name == 'empty': 64 | # return [slice(0,0), slice(-2, -2, 1), slice(0,2,-2)] 65 | 66 | if name == "closed": 67 | return [slice(0, 1), slice(0, 4)] 68 | if name == "left-open": 69 | return [slice(None, 4)] 70 | if name == "right-open": 71 | return [slice(4, None)] 72 | if name == "open": 73 | return [slice(None, None)] 74 | if name == "empty": 75 | return [slice(0, 0), slice(11, None, None)] 76 | 77 | def test_integer_getter(self): 78 | dataset = self.get_dataset(True) 79 | items = [0, 9, -1] 80 | for item in items: 81 | assert dataset[item] == [{"s": np.array([item])}] 82 | 83 | @pytest.mark.parametrize( 84 | "intervals", ("closed", "left-open", "right-open", "open", "empty") 85 | ) # TODO add 'inverse' here 86 | def test_slice_getter(self, intervals): 87 | # get correct dataset for test 88 | dataset = self.get_dataset(True) 89 | slices = self.get_test_slices(intervals) 90 | 91 | for s in slices: 92 | assert self._output_to_list(dataset[s]) == self._get_expected_slice_result(s) 93 | 94 | def test_seq_getter(self): 95 | dataset = self.get_dataset(True) 96 | items = [[], [0], [0, 2], [-2, -4]] 97 | for item in items: 98 | assert self._output_to_list(dataset[item]) == item 99 | 100 | @pytest.mark.parametrize( 101 | ("slice", "expected"), 102 | ( 103 | ( 104 | "A", 105 | "ExampleSliceableData indices must be integers, slices or iterable over integers, not str.", 106 | ), 107 | ( 108 | {}, 109 | "ExampleSliceableData indices must be integers, slices or iterable over integers, not dict.", 110 | ), 111 | ( 112 | 0.3, 113 | "ExampleSliceableData indices must be integers, slices or iterable over integers, not float.", 114 | ), 115 | ), 116 | ) 117 | def test_getter_type_error(self, slice, expected): 118 | dataset = self.get_dataset(True) 119 | with pytest.raises(TypeError) as excinfo: 120 | dataset[slice] 121 | (msg,) = excinfo.value.args 122 | assert msg == expected 123 | 124 | def test_len_not_implemented_error(self): 125 | with pytest.raises(TypeError) as excinfo: 126 | self.get_dataset(False) 127 | (msg,) = excinfo.value.args 128 | assert msg.startswith("Can't instantiate abstract class IncompleteSliceableData") 129 | 130 | def test_get_bulk_not_implemented_error(self): 131 | with pytest.raises(TypeError) as excinfo: 132 | self.get_dataset(False) 133 | (msg,) = excinfo.value.args 134 | assert msg.startswith("Can't instantiate abstract class IncompleteSliceableData") 135 | 136 | 137 | # d = IncompleteSliceableData() 138 | # d._get_bulk([5]) 139 | 140 | # t = TestSliceableData() 141 | # t.test_slice_getter('open') 142 | 143 | # d = ExampleSliceableData() 144 | # d[5] 145 | -------------------------------------------------------------------------------- /neural_lifetimes/models/nets/heads.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Dict 3 | 4 | import 
torch 5 | import torch.distributions as d 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from neural_lifetimes.losses import CategoricalLoss, CompositeLoss, ExponentialLoss, NormalLoss 10 | 11 | 12 | class BasicHead(nn.Module): 13 | def __init__( 14 | self, 15 | input_dim: int, 16 | drop_rate: float, 17 | output_dim: int, 18 | transform: Callable, 19 | init_norm=False, 20 | ): 21 | super().__init__() 22 | self.drop_rate = drop_rate 23 | hidden_dim = int(math.sqrt(input_dim * output_dim)) 24 | self.fc1 = nn.Linear(input_dim, hidden_dim) 25 | self.fc2 = nn.Linear(hidden_dim, output_dim) 26 | self.transform = transform 27 | self.input_shape = [None, input_dim] 28 | self.output_shape = [None, output_dim] 29 | 30 | self.init_norm = init_norm 31 | self.init_bias = None 32 | self.init_std = None 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | eps = 0.00001 36 | out = F.dropout(F.relu(self.fc1(x)), self.drop_rate, self.training) 37 | tmp = self.fc2(out) 38 | 39 | if self.init_norm: 40 | if self.init_bias is None: 41 | self.init_bias = -tmp.mean(dim=0, keepdims=True).detach() 42 | self.init_std = tmp.std(dim=0, keepdims=True).detach() + eps 43 | tmp = (tmp + self.init_bias) / self.init_std 44 | out = self.transform(tmp) 45 | return out 46 | 47 | 48 | class ExponentialHead(BasicHead): 49 | """ 50 | Output is the scale for a normalized distribution, typically a time delta. 51 | """ 52 | 53 | def __init__(self, input_dim: int, drop_rate: float): 54 | eps = 0.00001 55 | 56 | def normalize(x: torch.Tensor): 57 | x = torch.exp(torch.clamp(x, min=eps, max=60)) 58 | return x 59 | 60 | super().__init__(input_dim, drop_rate, 1, normalize) 61 | 62 | def distribution(self, v: torch.Tensor) -> d.Distribution: 63 | return d.exponential.Exponential(1 / v) 64 | 65 | def loss_function(self): 66 | return ExponentialLoss() 67 | 68 | 69 | class ExponentialHeadNoLoss(ExponentialHead): 70 | def loss_function(self): 71 | return None 72 | 73 | 74 | class ProbabilityHead(BasicHead): 75 | """ 76 | Output is the scale for a normalized distribution, typically a time delta. 77 | """ 78 | 79 | def __init__(self, input_dim: int, drop_rate: float): 80 | super().__init__( 81 | input_dim, 82 | drop_rate, 83 | 1, 84 | lambda x: torch.sigmoid(x), # was 0.5*x 85 | init_norm=False, 86 | ) 87 | 88 | def distribution(self, v: torch.Tensor) -> d.Distribution: 89 | return d.bernoulli.Bernoulli(v) 90 | 91 | def loss_function(self): 92 | # TODO: insert the right thing here 93 | return NotImplementedError 94 | 95 | 96 | class ChurnProbabilityHead(ProbabilityHead): 97 | def distribution(self, v: torch.Tensor) -> d.Distribution: 98 | return d.bernoulli.Bernoulli(v) 99 | 100 | def loss_function( 101 | self, 102 | ): # this one needs a special treatment because we don't need to clip last value in each seq 103 | return None 104 | 105 | 106 | class NormalHead(BasicHead): 107 | """ 108 | Output is the mean and std for a normal distribution. 109 | 110 | Could also be used for a LogNormal distribution. 
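    Note: the second output is treated as an unconstrained log-scale; ``distribution`` clamps and exponentiates it
    to obtain the standard deviation.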
111 | """ 112 | 113 | def __init__(self, input_dim: int, drop_rate: float): 114 | def normalize(x: torch.Tensor): 115 | # x[:, 1] = torch.exp(x[:, 1]/10) 116 | # x[:, 1] = x[:, 1]**2 117 | return x 118 | 119 | super().__init__(input_dim, drop_rate, 2, normalize) 120 | 121 | def distribution(self, v: torch.Tensor) -> d.Distribution: 122 | return d.normal.Normal(v[:, 0], torch.exp(torch.clamp(v[:, 1], min=-30, max=50))) 123 | 124 | def loss_function(self): 125 | return NormalLoss() 126 | 127 | 128 | class CategoricalHead(BasicHead): 129 | """ 130 | Output is the scale for a normalized distribution, typically a time delta. 131 | """ 132 | 133 | def __init__(self, input_dim: int, num_categories: int, drop_rate: float): 134 | super().__init__(input_dim, drop_rate, num_categories, lambda x: F.log_softmax(x, dim=-1)) 135 | 136 | def distribution(self, v: torch.Tensor) -> d.Distribution: 137 | return d.categorical.Categorical(v) 138 | 139 | def loss_function(self): 140 | return CategoricalLoss() 141 | 142 | 143 | class CompositeDistribution: 144 | def __init__(self, distrs: Dict[str, d.Distribution]): 145 | self.distrs = distrs 146 | 147 | def sample(self, shape) -> Dict[str, torch.Tensor]: 148 | return {k: d.sample(shape) for k, d in self.distrs} 149 | 150 | 151 | class CompositeHead(nn.Module): 152 | def __init__(self, d: Dict[str, nn.Module]): 153 | super().__init__() 154 | self.heads = nn.ModuleDict(d) 155 | self.input_shape = next(iter(d.values())).input_shape 156 | 157 | def forward(self, x): 158 | return {key: h(x) for key, h in self.heads.items()} 159 | 160 | def loss_function(self, preprocess: Callable) -> CompositeLoss: 161 | return CompositeLoss( 162 | {k: h.loss_function() for k, h in self.heads.items() if h.loss_function() is not None}, 163 | preprocess, 164 | ) 165 | 166 | def distribution(self, params: Dict[str, torch.Tensor]) -> CompositeDistribution: 167 | gens = {k: h.distribution(params[k]) for k, h in self.heads.items()} 168 | return CompositeDistribution(gens) 169 | -------------------------------------------------------------------------------- /neural_lifetimes/utils/data/feature_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .encoder_with_unknown import OrdinalEncoderWithUnknown 7 | 8 | 9 | class FeatureDictionaryEncoder: 10 | def __init__( 11 | self, 12 | continuous_features: List[str], 13 | discrete_features: Dict[str, np.ndarray], 14 | pre_encoded: bool = False, 15 | start_token_discrete: Optional[str] = None, 16 | ) -> None: 17 | """A Combined encoder for dictionary-based batches. 18 | 19 | This Encoder applies an ``utils.data.OrdinalEncoderWithUnknown`` discrete features and 20 | the identity function to continuous features in a dictionary. It further, converts values to normalised types 21 | for Pytorch training, i.e. `int64` for discrete and `float32` for continuous features. 22 | 23 | Args: 24 | continuous_features (List[str]): The names of items to be treated as continuous features. 25 | discrete_features (Dict[str, np.ndarray]): A dictionary of discrete features with their names as keys and 26 | levels as value. 27 | pre_encoded (bool, optional): If features are loaded pre-encoded, set this to ``True``. 28 | This will skip encoding while still enabling decoding. Defaults to False. 
29 | start_token_discrete (Optional[str], optional): This encoder assumes that the start token of discrete 30 | features is part of the levels as passed into ``discrete_features``. If it isn't, specify the token here 31 | to append it manually. Defaults to None. 32 | """ 33 | self.continuous_features = continuous_features 34 | self.discrete_features = discrete_features 35 | self.pre_encoded = pre_encoded 36 | self.start_token_discrete = start_token_discrete 37 | 38 | self.enc = {} 39 | for name, values in discrete_features.items(): 40 | if start_token_discrete is not None: 41 | values = np.concatenate(([self.start_token_discrete], values)).astype(values.dtype) 42 | self.enc[name] = OrdinalEncoderWithUnknown() 43 | self.enc[name].fit(values) 44 | 45 | def feature_size(self, name: str) -> int: 46 | if name in self.discrete_features: 47 | return len(self.enc[name]) 48 | elif name in self.continuous_features: 49 | return 1 50 | else: 51 | raise KeyError(f"'{name}' unknown.") 52 | 53 | # encode single item 54 | def encode(self, name: str, x: np.ndarray) -> np.ndarray: 55 | """Encode a single feature. 56 | 57 | Args: 58 | name (str): feature name. 59 | x (np.ndarray): raw feature values. 60 | 61 | Returns: 62 | np.ndarray: encoded feature values. 63 | """ 64 | # if discrete features are not pre-encoded, encode them 65 | if not self.pre_encoded and name in self.discrete_features: 66 | x = self.enc[name].transform(x).reshape(-1) 67 | 68 | # change types for continuous and discrete features for smooth torch conversion 69 | if name in self.continuous_features: 70 | encoded = x.astype(np.float32) 71 | elif name in self.discrete_features: 72 | encoded = x.astype(np.int64) 73 | else: 74 | encoded = x # for non-numeric data, e.g. user profile ID 75 | 76 | return encoded 77 | 78 | # decode single item 79 | def decode(self, name: str, x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: 80 | """Decode a single feature. 81 | 82 | Args: 83 | name (str): feature name. 84 | x (Union[np.ndarray, torch.Tensor]): encoded feature values. 85 | 86 | Returns: 87 | np.ndarray: decoded feature values. 88 | """ 89 | if self.pre_encoded or name not in self.discrete_features: 90 | decoded = x 91 | else: 92 | decoded = self.enc[name].inverse_transform(x).reshape(-1) 93 | return decoded 94 | 95 | def __call__(self, x: Dict[str, Any], mode: str = "encode") -> Dict[str, Any]: 96 | """Transform a dictionary of features. This call can encode and decode. 97 | 98 | Args: 99 | x (Dict[str, Any]): The input features. 100 | mode (str, optional): The transform mode can be ``encode`` or ``decode``. Defaults to "encode". 101 | 102 | Returns: 103 | Dict[str, Any]: The transformed features. 104 | """ 105 | assert mode in ["encode", "decode"] 106 | if mode == "encode": 107 | return {k: self.encode(k, v) for k, v in x.items()} 108 | else: 109 | return {k: self.decode(k, v) for k, v in x.items()} 110 | 111 | @property 112 | def features(self) -> List[str]: 113 | return self.continuous_features + list(self.discrete_features.keys()) 114 | 115 | def config_dict(self) -> Dict[str, Any]: 116 | """Dump the encoder into a dictionary. It can be used to re-initialise the object using ``.from_dict()``. 117 | 118 | Returns: 119 | Dict[str, Any]: The Encoder config dictionary. A dictionary containing all arguments required to initialise 120 | the object.
121 | """ 122 | return { 123 | "continuous_features": self.continuous_features, 124 | "discrete_features": self.discrete_features, 125 | "pre_encoded": self.pre_encoded, 126 | "start_token_discrete": self.start_token_discrete, 127 | } 128 | 129 | @classmethod 130 | def from_dict(cls, dictionary: Dict[str, Any]): 131 | """Initialise encoder from dictionary. 132 | 133 | Args: 134 | dictionary (Dict[str, Any]): A configuration dict as dumped by ``.config_dict()``. 135 | """ 136 | assert "continuous_features" in dictionary 137 | assert "discrete_features" in dictionary 138 | assert "pre_encoded" in dictionary 139 | 140 | return cls( 141 | continuous_features=dictionary["continuous_features"], 142 | discrete_features=dictionary["discrete_features"], 143 | pre_encoded=dictionary["pre_encoded"], 144 | start_token_discrete=dictionary["start_token_discrete"], 145 | ) 146 | -------------------------------------------------------------------------------- /tests/test_interface/test_run_model.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | from typing import Tuple 5 | import numpy as np 6 | 7 | import pytest 8 | from pytorch_lightning import LightningDataModule, LightningModule 9 | from pytorch_lightning.loggers import CSVLogger 10 | 11 | from neural_lifetimes import run_model 12 | from neural_lifetimes.data.datamodules.sequence_datamodule import SequenceDataModule 13 | from neural_lifetimes.data.datasets.btyd import BTYD, GenMode 14 | from neural_lifetimes.models.modules import VariationalEventModel, EventModel 15 | from neural_lifetimes.utils.data import Tokenizer, FeatureDictionaryEncoder, TargetCreator 16 | 17 | from ..test_datasets.datamodels import EventprofilesDataModel 18 | 19 | DISCRETE_START_TOKEN = "StartToken" 20 | 21 | 22 | @pytest.fixture 23 | def log_dir() -> str: 24 | with TemporaryDirectory() as f: 25 | yield f 26 | 27 | 28 | # @pytest.fixture 29 | # def data_dir() -> str: 30 | # return str(Path(__file__).parents[2] / "examples") 31 | 32 | 33 | @pytest.fixture 34 | def data_and_model() -> Tuple[SequenceDataModule, EventModel]: 35 | # create btyd data and dependent modules 36 | data_model = EventprofilesDataModel() 37 | 38 | dataset = BTYD.from_modes( 39 | modes=[GenMode(a=1, b=3, r=1, alpha=15)], 40 | num_customers=100, 41 | seq_gen_dynamic=True, 42 | start_date=datetime.datetime(2019, 1, 1, 0, 0, 0), 43 | start_limit_date=datetime.datetime(2019, 6, 15, 0, 0, 0), 44 | end_date=datetime.datetime(2021, 1, 1, 0, 0, 0), 45 | data_dir=str(Path(__file__).parents[2] / "examples"), 46 | continuous_features=data_model.cont_feat, 47 | discrete_features=data_model.discr_feat, 48 | ) 49 | 50 | discr_values = dataset.get_discrete_feature_values("") 51 | 52 | transform = FeatureDictionaryEncoder(data_model.cont_feat, discr_values) 53 | target_transform = TargetCreator(cols=data_model.target_cols + data_model.cont_feat + data_model.discr_feat) 54 | tokenizer = Tokenizer( 55 | data_model.cont_feat, discr_values, 100, np.nan, "", datetime.datetime(1970, 1, 1), np.nan 56 | ) 57 | 58 | datamodule = SequenceDataModule( 59 | dataset=dataset, 60 | transform=transform, 61 | target_transform=target_transform, 62 | tokenizer=tokenizer, 63 | test_size=0.2, 64 | batch_points=256, 65 | min_points=1, 66 | ) 67 | 68 | # create model 69 | model = VariationalEventModel( 70 | feature_encoder_config=transform.config_dict(), 71 | rnn_dim=256, 72 | emb_dim=256, 73 | drop_rate=0.5, 74 | 
bottleneck_dim=32, 75 | lr=0.001, 76 | target_cols=target_transform.cols, 77 | vae_sample_z=True, 78 | vae_sampling_scaler=1.0, 79 | vae_KL_weight=0.01, 80 | ) 81 | return datamodule, model 82 | 83 | 84 | @pytest.mark.slow 85 | class TestRunModel: 86 | @pytest.mark.parametrize("run_mode", ("train", "test", "none")) 87 | def test_defaults_with_run( 88 | self, 89 | log_dir: str, 90 | data_and_model: Tuple[LightningDataModule, LightningModule], 91 | run_mode: str, 92 | ) -> None: 93 | datamodule, model = data_and_model 94 | run_model( 95 | datamodule, 96 | model, 97 | log_dir=log_dir, 98 | num_epochs=2, 99 | run_mode=run_mode, 100 | ) 101 | 102 | def test_custom_trainer_kwargs( 103 | self, 104 | log_dir: str, 105 | data_and_model: Tuple[LightningDataModule, LightningModule], 106 | ) -> None: 107 | datamodule, model = data_and_model 108 | trainer = run_model( 109 | datamodule, 110 | model, 111 | log_dir=log_dir, 112 | checkpoint_path=None, 113 | run_mode="none", 114 | trainer_kwargs={"accumulate_grad_batches": 2, "callbacks": []}, 115 | ) 116 | assert len(trainer.callbacks) <= 4 117 | # progressbar, model summary and model checkpoints are turned on by default. 118 | # Gradient Accummulation also enabled => max 4 callbacks expected 119 | assert trainer.accumulate_grad_batches == 2 120 | 121 | def test_custom_logger_class( 122 | self, 123 | log_dir: str, 124 | data_and_model: Tuple[LightningDataModule, LightningModule], 125 | ) -> None: 126 | datamodule, model = data_and_model 127 | loggers = [CSVLogger(save_dir=log_dir)] 128 | trainer = run_model( 129 | datamodule, 130 | model, 131 | log_dir=log_dir, 132 | checkpoint_path=None, 133 | run_mode="none", 134 | loggers=loggers, 135 | ) 136 | assert isinstance(trainer.logger, CSVLogger) 137 | 138 | def test_custom_logger_kwargs( 139 | self, 140 | log_dir: str, 141 | data_and_model: Tuple[LightningDataModule, LightningModule], 142 | ) -> None: 143 | datamodule, model = data_and_model 144 | logger_kwargs = {"save_dir": "directory123", "version": 42} 145 | trainer = run_model( 146 | datamodule, 147 | model, 148 | log_dir=log_dir, 149 | checkpoint_path=None, 150 | run_mode="none", 151 | logger_kwargs=logger_kwargs, 152 | ) 153 | assert trainer.logger.version == 42 154 | assert trainer.logger.save_dir == "directory123" 155 | 156 | def test_custom_logger_raises_error( 157 | self, 158 | log_dir: str, 159 | data_and_model: Tuple[LightningDataModule, LightningModule], 160 | ) -> None: 161 | datamodule, model = data_and_model 162 | loggers = [CSVLogger(save_dir=log_dir)] 163 | logger_kwargs = {"version": 42} 164 | with pytest.raises(AssertionError) as excinfo: 165 | _ = run_model( 166 | datamodule, 167 | model, 168 | log_dir=log_dir, 169 | checkpoint_path=None, 170 | run_mode="none", 171 | loggers=loggers, 172 | logger_kwargs=logger_kwargs, 173 | ) 174 | (msg,) = excinfo.value.args 175 | assert msg == "If custom logger is supplied, 'logger_kwargs' argument must be 'None'" 176 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Lifetimes 2 | 3 | #TODO Insert Logo 4 | 5 | ![test](https://github.com/transferwise/neural-lifetimes/actions/workflows/test.yml/badge.svg) 6 | ![lint](https://github.com/transferwise/neural-lifetimes/actions/workflows/lint.yml/badge.svg) 7 | ![format](https://github.com/transferwise/neural-lifetimes/actions/workflows/format.yml/badge.svg) 8 | 
![docs](https://github.com/transferwise/neural-lifetimes/actions/workflows/docs.yml/badge.svg) 9 | ![pypi](https://img.shields.io/pypi/v/neural-lifetimes) 10 | 11 | # Introduction 12 | 13 | The Neural Lifetimes package is an open-source lightweight framework based on [PyTorch](https://pytorch.org/) and [PyTorch-Lightning](https://www.pytorchlightning.ai/) to conduct modern lifetimes analysis based on neural network models. This package provides both flexibility and simplicity: 14 | 15 | - Users can use the simple interface to load their own data and train good models _out-of-the-box_ with very few lines of code. 16 | - The modular design of this package enables users to selectively pick individual tools. 17 | 18 | Possible uses of Neural Lifetimes include 19 | 20 | - Predicting customer transactions 21 | - Calculating Expected Customer Lifetime Values 22 | - Obtaining Customer Embeddings 23 | - TODO add more 24 | 25 | # Features 26 | 27 | ## Simple Interface 28 | 29 | You can train a model on your own dataset with a few lines of code; see the minimal example at the end of this README. 30 | 31 | ## Data 32 | 33 | We introduce a set of tools to 34 | 35 | - Load data in batches from a database 36 | - Handle sequential data 37 | - Load data from interfaces such as Pandas, Clickhouse, Postgres, VAEX and more 38 | 39 | We further provide a simulated dataset based on the `BTYD` model for exploring this package, and we provide tutorials to understand the mechanics of this model. 40 | 41 | ## Models 42 | 43 | We provide a simple `GRU`-based model that embeds any data and predicts sequences of transactions. 44 | 45 | ## Model Inference 46 | 47 | The class `inference.ModelInference` allows you to simulate sequences from scratch or extend sequences from a model artifact. 48 | A sequence is simulated/extended iteratively by adding one event at a time to the end of the sequence. 49 | To simulate an event, the current sequence is used as the model input and the distributions outputted by the model are 50 | used to sample the next event. The sampled event is added to the sequence and the resulting sequence is used as an input 51 | in the following iteration. The process ends if a sequence reaches the `end_date` or if the 52 | customer churns. 53 | 54 | To initialize the `ModelInference` class, you need to give the filepath of a trained model artifact: 55 | 56 | ``` 57 | inference = ModelInference( 58 | model_filename = "/logs/artifacts/version_1/epoch=0-step=1-val_loss_total=1.0.ckpt" 59 | ) 60 | ``` 61 | 62 | `ModelInference` has two main methods: 63 | 64 | - `simulate_sequences`: simulates `n` sequences from scratch. The sequences start with an event randomly sampled between 65 | `start_date` and `start_date_limit`. The sequences of events are built by sampling 66 | from the model distribution outputs. The sequence is initialized with a Starting Token event. 67 | A sequence ends either when the user churns or when an event happens after the 68 | `end_date`. 69 | 70 | ``` 71 | simulate_sequences = inference.simulate_sequences( 72 | n = 10, 73 | start_date = datetime.datetime(2021, 1, 1, 0, 0, 0), 74 | start_date_limit = datetime.datetime(2021, 2, 1, 0, 0, 0), 75 | end_date = datetime.datetime(2021, 4, 1, 0, 0, 0), 76 | start_token_discr = 'StartToken', 77 | start_token_cont = 0 78 | ) 79 | ``` 80 | 81 | - `extend_sequence`: takes a `ml_utils.torch.sequence_loader.SequenceLoader` loader and the start and end date of the 82 | simulation. The method processes the loader in batches. The `start_date` must be after any event in any sequence.
Customers might have already churned after their last event, 83 | so we first need to infer the churn status of the customers. To infer the churn status, we input a sequence into the model 84 | and sample from the output distributions. If the churn status after the last event is True, or the next event would have 85 | happened before `start_date`, we infer that the customer has churned. 86 | For all the customer sequences that haven't churned, we extend the sequences as in `simulate_sequences`. 87 | 88 | ``` 89 | raw_data, extended_seq = inference.extend_sequence( 90 | loader, 91 | start_date = datetime.datetime(2021, 1, 1, 0, 0, 0), 92 | end_date = datetime.datetime(2021, 4, 1, 0, 0, 0), 93 | return_input = True 94 | ) 95 | ``` 96 | 97 | The `extend_sequence` method can also return the original sequences if `return_input = True`. 98 | `extended_seq` contains a list of dicts where each dict is a processed batch. Each dict has two keys: 'extended_sequences' and 'inferred_churn'. 99 | 'extended_sequences' contains the extended sequences that were inferred NOT to have churned. 100 | 'inferred_churn' contains the sequences that were inferred to have churned. 101 | 102 | # Documentation 103 | 104 | The documentation for this repository is available at 105 | 106 | [TODO Add Link]() 107 | 108 | # Install 109 | 110 | You may install the package from [PyPI](https://pypi.org/project/neural-lifetimes/): 111 | 112 | ```bash 113 | pip install neural-lifetimes 114 | ``` 115 | 116 | Alternatively, you may install from git to get access to the latest commits: 117 | 118 | ```bash 119 | pip install git+https://github.com/transferwise/neural-lifetimes 120 | ``` 121 | 122 | # Getting started 123 | 124 | In the documentation there is a tutorial on getting started. 125 | 126 | [TODO add link]() 127 | 128 | #TODO add google colab notebook to start 129 | 130 | # Useful Resources 131 | 132 | - GitHub: [Lifetimes Package](https://github.com/CamDavidsonPilon/lifetimes) 133 | - Documentation: [PyTorch](https://pytorch.org/docs/stable/index.html/) 134 | - Documentation: [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) 135 | - Paper: [Fader et al. (2005), "Counting Your Customers" the Easy Way: An Alternative to the Pareto/NBD Model](http://brucehardie.com/papers/018/fader_et_al_mksc_05.pdf) 136 | 137 | # Contribute 138 | 139 | We welcome all contributions to this repository. Please read the [Contributing Guide](https://github.com/transferwise/neural_lifetimes/blob/update-readme/CONTRIBUTING.md). 140 | 141 | If you have any questions or comments please raise a GitHub issue.
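
# Minimal Example

The sketch below mirrors the setup used in `tests/test_interface/test_run_model.py`: it trains the default `VariationalEventModel` on the simulated `BTYD` dataset via `run_model`. The feature names, target columns, data directory and log directory (`cont_feat`, `discr_feat`, `target_cols`, `"./data"`, `"./logs"`) are placeholders for your own schema and paths, and the hyperparameters are illustrative rather than recommended values.

```python
import datetime

import numpy as np

from neural_lifetimes import run_model
from neural_lifetimes.data.datamodules.sequence_datamodule import SequenceDataModule
from neural_lifetimes.data.datasets.btyd import BTYD, GenMode
from neural_lifetimes.models.modules import VariationalEventModel
from neural_lifetimes.utils.data import FeatureDictionaryEncoder, TargetCreator, Tokenizer

# Placeholder feature schema -- replace with the columns of your own event data.
cont_feat = ["AMOUNT"]
discr_feat = ["PRODUCT"]
target_cols = ["dt"]

# Simulated BTYD transactions; swap in your own sequence dataset to use real data.
dataset = BTYD.from_modes(
    modes=[GenMode(a=1, b=3, r=1, alpha=15)],
    num_customers=1000,
    seq_gen_dynamic=True,
    start_date=datetime.datetime(2019, 1, 1, 0, 0, 0),
    start_limit_date=datetime.datetime(2019, 6, 15, 0, 0, 0),
    end_date=datetime.datetime(2021, 1, 1, 0, 0, 0),
    data_dir="./data",
    continuous_features=cont_feat,
    discrete_features=discr_feat,
)
discr_values = dataset.get_discrete_feature_values("")

# Encoding, target creation and tokenization, as in the test suite.
encoder = FeatureDictionaryEncoder(cont_feat, discr_values)
target_transform = TargetCreator(cols=target_cols + cont_feat + discr_feat)
tokenizer = Tokenizer(cont_feat, discr_values, 100, np.nan, "", datetime.datetime(1970, 1, 1), np.nan)

datamodule = SequenceDataModule(
    dataset=dataset,
    transform=encoder,
    target_transform=target_transform,
    tokenizer=tokenizer,
    test_size=0.2,
    batch_points=256,
    min_points=1,
)

model = VariationalEventModel(
    feature_encoder_config=encoder.config_dict(),
    rnn_dim=256,
    emb_dim=256,
    drop_rate=0.5,
    bottleneck_dim=32,
    lr=0.001,
    target_cols=target_transform.cols,
    vae_sample_z=True,
    vae_sampling_scaler=1.0,
    vae_KL_weight=0.01,
)

# Trains for a few epochs and writes checkpoints plus TensorBoard logs to ./logs.
run_model(datamodule, model, log_dir="./logs", num_epochs=10)
```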
142 | -------------------------------------------------------------------------------- /neural_lifetimes/run_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 6 | from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger 7 | 8 | from .data.datamodules.sequence_datamodule import SequenceDataModule 9 | from .utils.callbacks import MonitorChurn, DistributionMonitor, MonitorProjection 10 | 11 | 12 | def run_model( 13 | datamodule: SequenceDataModule, 14 | model: pl.LightningModule, 15 | log_dir: str, 16 | num_epochs: int = 100, 17 | checkpoint_path: Optional[str] = None, 18 | run_mode: str = "train", 19 | val_check_interval: Union[int, float] = 1.0, 20 | limit_val_batches: Union[int, float] = 1.0, 21 | gradient_clipping: float = 0.1, 22 | trainer_kwargs: Optional[Dict[str, Any]] = None, 23 | loggers: Optional[List[LightningLoggerBase]] = None, 24 | logger_kwargs: Optional[Dict[str, Any]] = None, 25 | ) -> pl.Trainer: 26 | """ 27 | Run the model for training or testing. 28 | 29 | Note: 30 | This function will run a model using the 31 | `Trainer `. Some key arguments for you to set are 32 | given explicitly. You can overwrite any other argument of the Trainer or Logger manually. 33 | 34 | Args: 35 | datamodule (PLDataModule): The data on which the model is run. 36 | model (pytorch_lightning.LightningModule): The LightningModule containing the Embedder, TargetCreator, network 37 | and instructions for forward passes. 38 | log_dir (str): Path into which model checkpoints and tensorboard logs will be written. 39 | num_epochs (int): Number of epochs to train for. Default is 100. 40 | checkpoint_path (Optional[str]): Path to a model to load. If it is ``None`` (default), we will train a new model. 41 | run_mode (str): 'train': Train the model. 'test': Run an inference pass. 'none': Do neither. 42 | Default is ``train``. 43 | val_check_interval (int | float): Sets the frequency of running the model on the validation set. 44 | See: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#val-check-interval. 45 | Default is ``1.0``. 46 | limit_val_batches (int | float): Limits the number of batches of the validation set that is run by the Trainer. 47 | See: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#limit-val-batches. 48 | Default is ``1.0``. 49 | gradient_clipping (float): Sets the threshold L2 norm for clipping the gradients. 50 | See: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#gradient-clip-val. 51 | Default is ``0.1``. 52 | trainer_kwargs (Dict(str, Any)): Forward any keyword argument to the `Trainer `. 53 | Any argument passed here will be set in the Trainer constructor. Default is ``None``. 54 | loggers (Optional[List[LightningLoggerBase]]): This function uses the TensorBoardLogger by default. If you wish 55 | to use another logger, you can pass a list of any other `Logger ` in 56 | here. Default is ``None``. 57 | logger_kwargs (Dict(str, Any)): Forward any keyword argument to the 58 | `Logger `. Any argument passed here will be set in the Logger 59 | constructor. Default is ``None``. 60 | 61 | Returns: 62 | pl.Trainer: The trainer object. 63 | """ 64 | if checkpoint_path: 65 | model.load_state_from_checkpoint(checkpoint_path) 66 | 67 | # configure loggers to be used by pytorch.
68 | # https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html 69 | if loggers is None: 70 | # process user arguments 71 | if logger_kwargs is None: 72 | logger_kwargs = {} 73 | # ensure that the user can overwrite the default Tb arguments 74 | logger_kwargs = dict( 75 | { 76 | "save_dir": log_dir, 77 | "default_hp_metric": True, 78 | "log_graph": False, 79 | "name": "logs", 80 | }, 81 | **logger_kwargs, 82 | ) 83 | loggers = [TensorBoardLogger(**logger_kwargs)] 84 | else: 85 | assert isinstance(loggers, List) 86 | assert [isinstance(logger, LightningLoggerBase) for logger in loggers] 87 | assert logger_kwargs is None, "If custom logger is supplied, 'logger_kwargs' argument must be 'None'" 88 | 89 | # configure callbacks. That is actions taken on certain events, e.g. epoch end. 90 | callbacks = [ 91 | ModelCheckpoint( 92 | dirpath=os.path.join(log_dir, f"version_{loggers[0].version}"), 93 | filename="{epoch}-{step}-{val_loss/total:.2f}", 94 | monitor="val_loss/total", 95 | mode="min", 96 | save_last=True, 97 | save_top_k=3, 98 | verbose=True, 99 | ), 100 | DistributionMonitor(), 101 | MonitorProjection(), 102 | MonitorChurn(), 103 | ] 104 | 105 | # process user arguments for Trainer 106 | if trainer_kwargs is None: 107 | trainer_kwargs = {} 108 | # ensure the user can overwrite anything they want 109 | trainer_kwargs = dict( 110 | { 111 | "callbacks": callbacks, 112 | "logger": loggers, 113 | "max_epochs": num_epochs, 114 | "val_check_interval": val_check_interval, 115 | "limit_val_batches": limit_val_batches, 116 | "amp_backend": "native", 117 | "precision": 16, 118 | "track_grad_norm": 2, 119 | "gradient_clip_val": gradient_clipping, 120 | }, 121 | **trainer_kwargs, 122 | ) 123 | 124 | trainer = pl.Trainer( 125 | **trainer_kwargs, 126 | ) 127 | 128 | # if model weights are supplied we do inference. Otherwise we train. 129 | if run_mode == "test": 130 | trainer.test(model, datamodule=datamodule) 131 | elif run_mode == "train": 132 | # start training 133 | trainer.fit(model, datamodule=datamodule) 134 | elif run_mode == "none": 135 | pass 136 | else: 137 | raise ValueError(f"`run_mode` must be`train`, `test` or `none`. Not {run_mode}.") 138 | return trainer 139 | 140 | 141 | # TODO add train, test and inference functions for users to find in the docs and being redirected to run_model 142 | -------------------------------------------------------------------------------- /neural_lifetimes/losses/mutual_information.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | from neural_lifetimes.utils.score_estimators import SpectralScoreEstimator 6 | from neural_lifetimes.utils.scheduler import WeightScheduler 7 | 8 | # code heavily relies on https://github.com/zhouyiji/MIGE 9 | 10 | 11 | class MutualInformationGradientEstimator(nn.Module): 12 | def __init__(self, n_eigen: int = None, n_eigen_threshold: float = None) -> None: 13 | """The `MutualInformationGradientEstimator` provides a Monte Carlo estimator of Mutual Information. 14 | 15 | This class allows to train a model using mutual information based on empirical data distributions as an 16 | objective function. The Mutual Information Gradients are estimated using the MIGE method, Wen et al, 2020. The 17 | method is fully compatable with stochastic optimization techniques. This class can be used like any loss in 18 | PyTorch. 
Beware that this is the mutual information, so gradient descent has to be applied to the 19 | negative mutual information. 20 | 21 | See: https://openreview.net/forum?id=ByxaUgrFvH&noteId=rkePFjKYor 22 | 23 | Args: 24 | n_eigen (int, optional): Sets the number of eigenvalues to be used for the Nyström approximation. 25 | If ``None``, all values will be used. Defaults to None. 26 | n_eigen_threshold (float, optional): Sets the threshold for eigenvalues to be used for the Nyström 27 | approximation. If ``None``, all values will be used. Defaults to None. 28 | """ 29 | super().__init__() 30 | self.spectral_j = SpectralScoreEstimator(n_eigen=n_eigen, n_eigen_threshold=n_eigen_threshold) 31 | self.spectral_m = SpectralScoreEstimator(n_eigen=n_eigen, n_eigen_threshold=n_eigen_threshold) 32 | 33 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 34 | """Stores the mutual information gradient in the autograd system. 35 | 36 | Args: 37 | x (torch.Tensor): The input data. 38 | y (torch.Tensor): The output data. 39 | 40 | Returns: 41 | torch.Tensor: A scalar surrogate whose gradient equals the gradient of the MI. 42 | """ 43 | loss = self._entropy_surrogate(self.spectral_j, torch.cat([x, y], dim=-1)) - self._entropy_surrogate( 44 | self.spectral_m, y 45 | ) 46 | return loss 47 | 48 | @staticmethod 49 | def _entropy_surrogate(estimator, samples): 50 | dlog_q = estimator.compute_gradients(samples.detach(), None) 51 | surrogate_cost = torch.mean(torch.sum(dlog_q.detach() * samples, -1)) 52 | return surrogate_cost 53 | 54 | 55 | class InformationBottleneckLoss(nn.Module): 56 | def __init__( 57 | self, 58 | fit_loss: nn.Module, 59 | weight_scheduler: WeightScheduler, 60 | n_eigen: int = None, 61 | n_eigen_threshold: float = None, 62 | ) -> None: 63 | """Implements an information bottleneck loss. 64 | 65 | The information bottleneck is a penalty added to any other loss. The information bottleneck between A and B 66 | minimises the information about A contained in B, i.e. the mutual information I(A,B). The objective function is: 67 | ``L = fit_loss + reg_weight * I(A,B)`` 68 | 69 | The NN predicts: X -> A -> B -> Y, where X is the data, A is the RNN output, B is the bottleneck, and Y is the 70 | prediction. The information bottleneck is enforced between A and B. If ``reg_weight`` is small enough, the fit 71 | loss forces information that is relevant to Y to flow from A -> B; however, the I(A,B) penalty penalises all 72 | information flowing through, meaning that information in X irrelevant to Y will be suppressed. 73 | 74 | Args: 75 | fit_loss (nn.Module): The loss function for the model fit. 76 | weight_scheduler (WeightScheduler): A weight scheduler for the weight of the mutual information in the total 77 | loss. 78 | n_eigen (int, optional): Sets the number of eigenvalues to be used for the Nyström approximation. 79 | If ``None``, all values will be used. Defaults to None. 80 | n_eigen_threshold (float, optional): Sets the threshold for eigenvalues to be used for the Nyström 81 | approximation. If ``None``, all values will be used. Defaults to None. 82 | """ 83 | super().__init__() 84 | self.fit_loss = fit_loss 85 | self.weight_scheduler = weight_scheduler 86 | self.mi = MutualInformationGradientEstimator(n_eigen=n_eigen, n_eigen_threshold=n_eigen_threshold) 87 | 88 | def forward( 89 | self, pred: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] 90 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 91 | """Calculates the information bottleneck loss.
92 | 93 | Args: 94 | pred (Dict[str, torch.Tensor]): the NN prediction dictionary. Beyond keys required for ``fit_loss``, 95 | this should contain "event_encoding" and "bottleneck" between which the penalty is imposed. 96 | target (Dict[str, torch.Tensor]): The target data. 97 | 98 | Returns: 99 | Tuple[torch.Tensor, Dict[str, torch.Tensor]]: The loss. 100 | """ 101 | bottleneck = pred.pop("bottleneck") 102 | event_encoding = pred.pop("event_encoding") 103 | fit_loss, losses_dict = self.fit_loss(pred, target) 104 | latent_loss = self.mi(event_encoding, bottleneck) 105 | loss = fit_loss + self.reg_weight * latent_loss 106 | 107 | losses_dict["model_fit"] = fit_loss 108 | losses_dict["total"] = loss 109 | losses_dict["mutual_information"] = latent_loss 110 | 111 | losses_dict = {f"loss/{name}": loss for name, loss in losses_dict.items()} 112 | 113 | assert loss not in [-torch.inf, torch.inf], "Loss not finite!" 114 | assert not torch.isnan(loss), "Got a NaN loss" 115 | 116 | # if the sum is finite, this should be redundant?! 117 | assert sum(losses_dict.values()) not in [-torch.inf, torch.inf], "Loss not finite!" 118 | assert not torch.isnan(sum(losses_dict.values())), "Got a NaN loss" 119 | 120 | return loss, {k: v.detach() for k, v in losses_dict.items()} 121 | 122 | @property 123 | def reg_weight(self) -> float: 124 | """The weight currently used for the mutual information. 125 | 126 | This is an alias and determined through the 127 | ``WeightScheduler``. 128 | 129 | Returns: 130 | float: The current weight 131 | """ 132 | return self.weight_scheduler.weight 133 | -------------------------------------------------------------------------------- /tests/test_utils/test_data/test_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import pytest 5 | 6 | from neural_lifetimes.utils.data import FeatureDictionaryEncoder, OrdinalEncoderWithUnknown 7 | 8 | 9 | @pytest.fixture 10 | def data(): 11 | return { 12 | "CF1": np.array([1, 2, 3, 4], dtype=np.float32), 13 | "CF2": np.array([4, 3, 2, 1], dtype=np.float32), 14 | "DF1": np.array(["level_1", "level_1", "level_2", "level_3"]), 15 | "DF2": np.array(["l1", "l1", "l2", "l3"]), 16 | } 17 | 18 | 19 | @pytest.fixture 20 | def discrete_values(): 21 | return { 22 | "DF1": np.array(["level_1", "level_2", "level_3"]), 23 | "DF2": np.array(["l1", "l2", "l3"]), 24 | } 25 | 26 | 27 | class Test_OrdinalEncoderWithUnkown: 28 | def _construct_and_fit(self): 29 | encoder = OrdinalEncoderWithUnknown() 30 | encoder.fit(np.array(["level_1", "level_2", "level_3"])) 31 | return encoder 32 | 33 | def test_fit(self): 34 | self._construct_and_fit() 35 | 36 | def test_levels(self): 37 | encoder = self._construct_and_fit() 38 | assert np.all(encoder.levels == np.array(["