├── .github
│   └── workflows
│       ├── ci.yaml
│       ├── manual-python-publish.yaml
│       └── python_publish.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── .gitignore
│   ├── Makefile
│   ├── api.rst
│   ├── conf.py
│   ├── custom_priors.ipynb
│   ├── faq.md
│   ├── index.rst
│   └── models.ipynb
├── examples
│   ├── end_to_end_demo_with_multiple_geos.ipynb
│   └── simple_end_to_end_demo.ipynb
├── images
│   ├── flowchart.png
│   ├── lightweight_mmm_logo_colored_1000.png
│   ├── lightweight_mmm_logo_colored_250.png
│   ├── lightweight_mmm_logo_colored_500.png
│   └── lightweight_mmm_logo_colored_750.png
├── lightweight_mmm
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── baseline
│   │   │   ├── __init__.py
│   │   │   ├── intercept.py
│   │   │   └── intercept_test.py
│   │   ├── core_utils.py
│   │   ├── core_utils_test.py
│   │   ├── priors.py
│   │   ├── time
│   │   │   ├── __init__.py
│   │   │   ├── seasonality.py
│   │   │   ├── seasonality_test.py
│   │   │   ├── trend.py
│   │   │   └── trend_test.py
│   │   └── transformations
│   │       ├── __init__.py
│   │       ├── identity.py
│   │       ├── lagging.py
│   │       ├── lagging_test.py
│   │       ├── saturation.py
│   │       └── saturation_test.py
│   ├── lightweight_mmm.py
│   ├── lightweight_mmm_test.py
│   ├── media_transforms.py
│   ├── media_transforms_test.py
│   ├── models.py
│   ├── models_test.py
│   ├── optimize_media.py
│   ├── optimize_media_test.py
│   ├── plot.py
│   ├── plot_test.py
│   ├── preprocessing.py
│   ├── preprocessing_test.py
│   ├── utils.py
│   └── utils_test.py
├── readthedocs.yaml
├── requirements
│   ├── requirements.txt
│   ├── requirements_docs.txt
│   └── requirements_tests.txt
├── setup.py
└── test.sh
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: ci
16 |
17 | on:
18 | push:
19 | branches: ["main"]
20 | pull_request:
21 | branches: ["main"]
22 |
23 | jobs:
24 | build-and-test:
25 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
26 | runs-on: "${{ matrix.os }}"
27 |
28 | strategy:
29 | matrix:
30 | python-version: ["3.8", "3.9", "3.10"]
31 | os: [ubuntu-latest]
32 |
33 | steps:
34 | - uses: "actions/checkout@v2"
35 | - uses: "actions/setup-python@v1"
36 | with:
37 | python-version: "${{ matrix.python-version }}"
38 | - name: Run CI tests
39 | run: bash test.sh
40 | shell: bash
41 |
--------------------------------------------------------------------------------
/.github/workflows/manual-python-publish.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This workflow will upload a Python Package using Twine when a release is created
16 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
17 |
18 | # This workflow uses actions that are not certified by GitHub.
19 | # They are provided by a third-party and are governed by
20 | # separate terms of service, privacy policy, and support
21 | # documentation.
22 |
23 | name: Upload Python Package
24 |
25 | on:
26 | release:
27 | types: [published]
28 | workflow_dispatch:
29 |
30 |
31 | permissions:
32 | contents: read
33 |
34 | jobs:
35 | deploy:
36 |
37 | runs-on: ubuntu-latest
38 |
39 | steps:
40 | - uses: actions/checkout@v3
41 | - name: Set up Python
42 | uses: actions/setup-python@v3
43 | with:
44 | python-version: '3.x'
45 | - name: Install dependencies
46 | run: |
47 | python -m pip install --upgrade pip
48 | pip install setuptools wheel twine
49 | - name: Build and publish
50 | env:
51 | TWINE_USERNAME: __token__
52 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
53 | run: |
54 | python setup.py sdist bdist_wheel
55 | twine upload dist/*
--------------------------------------------------------------------------------
/.github/workflows/python_publish.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: Upload Python Package
16 |
17 | on:
18 | release:
19 | types: [created]
20 |
21 | jobs:
22 | deploy:
23 | runs-on: ubuntu-latest
24 |
25 | steps:
26 | - uses: actions/checkout@v2
27 | - uses: actions/setup-python@v2
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install setuptools wheel twine
32 | - name: Build and publish
33 | env:
34 | TWINE_USERNAME: __token__
35 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
36 | run: |
37 | python setup.py sdist bdist_wheel
38 | twine upload dist/*
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file
13 | or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | _build
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # Minimal makefile for Sphinx documentation
17 | #
18 | # You can set these variables from the command line.
19 | SPHINXOPTS =
20 | SPHINXBUILD = sphinx-build
21 | SOURCEDIR = .
22 | BUILDDIR = _build
23 | # Put it first so that "make" without argument is like "make help".
24 | help:
25 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
26 | .PHONY: help Makefile
27 | # Catch-all target: route all unknown targets to Sphinx using the new
28 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
29 | %: Makefile
30 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | LightweightMMM
2 | ===============
3 | .. currentmodule:: lightweight_mmm
4 |
5 |
6 | LightweightMMM object
7 | ======================
8 | .. currentmodule:: lightweight_mmm.lightweight_mmm
9 | .. autosummary::
10 | LightweightMMM
11 |
12 | .. autoclass:: LightweightMMM
13 | :members:
14 |
15 |
16 | Preprocessing / Scaling
17 | ========================
18 | .. currentmodule:: lightweight_mmm.preprocessing
19 | .. autosummary::
20 | CustomScaler
21 |
22 | .. autoclass:: CustomScaler
23 | :members:
24 |
25 |
26 | Optimize Media
27 | ===============
28 | .. currentmodule:: lightweight_mmm.optimize_media
29 | .. autosummary::
30 | find_optimal_budgets
31 |
32 | .. autofunction:: find_optimal_budgets
33 |
34 | Plot
35 | =====
36 | .. currentmodule:: lightweight_mmm.plot
37 | .. autosummary::
38 | plot_response_curves
39 | plot_cross_correlate
40 | plot_var_cost
41 | plot_model_fit
42 | plot_out_of_sample_model_fit
43 | plot_media_channel_posteriors
44 | plot_prior_and_posterior
45 | plot_bars_media_metrics
46 | plot_pre_post_budget_allocation_comparison
47 | plot_media_baseline_contribution_area_plot
48 | create_media_baseline_contribution_df
49 |
50 | .. autofunction:: plot_response_curves
51 | .. autofunction:: plot_cross_correlate
52 | .. autofunction:: plot_var_cost
53 | .. autofunction:: plot_model_fit
54 | .. autofunction:: plot_out_of_sample_model_fit
55 | .. autofunction:: plot_media_channel_posteriors
56 | .. autofunction:: plot_prior_and_posterior
57 | .. autofunction:: plot_bars_media_metrics
58 | .. autofunction:: plot_pre_post_budget_allocation_comparison
59 | .. autofunction:: plot_media_baseline_contribution_area_plot
60 | .. autofunction:: create_media_baseline_contribution_df
61 |
62 |
63 | Models
64 | =======
65 | .. currentmodule:: lightweight_mmm.models
66 | .. autosummary::
67 | transform_adstock
68 | transform_hill_adstock
69 | transform_carryover
70 | media_mix_model
71 |
72 | .. autofunction:: transform_adstock
73 | .. autofunction:: transform_hill_adstock
74 | .. autofunction:: transform_carryover
75 | .. autofunction:: media_mix_model
76 |
77 | Media Transforms
78 | =================
79 | .. currentmodule:: lightweight_mmm.media_transforms
80 | .. autosummary::
81 | calculate_seasonality
82 | adstock
83 | hill
84 | carryover
85 | apply_exponent_safe
86 |
87 | .. autofunction:: calculate_seasonality
88 | .. autofunction:: adstock
89 | .. autofunction:: hill
90 | .. autofunction:: carryover
91 | .. autofunction:: apply_exponent_safe
92 |
93 | Utils
94 | ======
95 | .. currentmodule:: lightweight_mmm.utils
96 | .. autosummary::
97 | save_model
98 | load_model
99 | simulate_dummy_data
100 | get_halfnormal_mean_from_scale
101 | get_halfnormal_scale_from_mean
102 | get_beta_params_from_mu_sigma
103 | distance_pior_posterior
104 | interpolate_outliers
105 | dataframe_to_jax
106 |
107 | .. autofunction:: save_model
108 | .. autofunction:: load_model
109 | .. autofunction:: simulate_dummy_data
110 | .. autofunction:: get_halfnormal_mean_from_scale
111 | .. autofunction:: get_halfnormal_scale_from_mean
112 | .. autofunction:: get_beta_params_from_mu_sigma
113 | .. autofunction:: distance_pior_posterior
114 | .. autofunction:: interpolate_outliers
115 | .. autofunction:: dataframe_to_jax
116 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Configuration file for the Sphinx documentation builder."""
16 |
17 |
18 | # This file only contains a selection of the most common options. For a full
19 | # list see the documentation:
20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
21 |
22 | # -- Path setup --------------------------------------------------------------
23 |
24 | # If extensions (or modules to document with autodoc) are in another directory,
25 | # add these directories to sys.path here. If the directory is relative to the
26 | # documentation root, use os.path.abspath to make it absolute, like shown here.
27 | #
28 | # import os
29 | # import sys
30 | # sys.path.insert(0, os.path.abspath('.'))
31 |
32 | import os
33 | import sys
34 | sys.path.insert(0, os.path.abspath('..'))
35 |
36 | # -- Project information -----------------------------------------------------
37 |
38 | project = 'LightweightMMM'
39 | copyright = '2022, The LightweightMMM authors'
40 | author = 'The LightweightMMM authors'
41 |
42 |
43 | # -- General configuration ---------------------------------------------------
44 |
45 | # Add any Sphinx extension module names here, as strings. They can be
46 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
47 | # ones.
48 | extensions = [
49 | 'sphinx.ext.autodoc',
50 | 'sphinx.ext.autosummary',
51 | 'sphinx.ext.autosectionlabel',
52 | 'sphinx.ext.doctest',
53 | 'sphinx.ext.intersphinx',
54 | 'sphinx.ext.mathjax',
55 | 'sphinx.ext.napoleon',
56 | 'sphinx.ext.viewcode',
57 | 'nbsphinx',
58 | 'myst_nb', # This is used for the .ipynb notebooks
59 | ]
60 |
61 | # Add any paths that contain templates here, relative to this directory.
62 | templates_path = ['_templates']
63 |
64 | # List of patterns, relative to source directory, that match files and
65 | # directories to ignore when looking for source files.
66 | # This pattern also affects html_static_path and html_extra_path.
67 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
68 |
69 | source_suffix = ['.rst', '.md', '.ipynb']
70 |
71 | autosummary_generate = True
72 |
73 | master_doc = 'index'
74 |
75 | # -- Options for HTML output -------------------------------------------------
76 |
77 | # The theme to use for HTML and HTML Help pages. See the documentation for
78 | # a list of builtin themes.
79 | #
80 | html_theme = 'sphinx_rtd_theme'
81 | html_static_path = []
82 |
83 | # Add any paths that contain custom static files (such as style sheets) here,
84 | # relative to this directory. They are copied after the builtin static files,
85 | # so a file named "default.css" will overwrite the builtin "default.css".
86 | html_static_path = ['_static']
87 |
88 | nbsphinx_codecell_lexer = 'ipython3'
89 | nbsphinx_execute = 'never'
90 | jupyter_execute_notebooks = 'off'
91 |
92 | # -- Extension configuration -------------------------------------------------
93 |
94 | # Tell sphinx-autodoc-typehints to generate stub parameter annotations including
95 | # types, even if the parameters aren't explicitly documented.
96 | always_document_param_types = True
97 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | ## Importing
2 |
3 | *What do I do if I get a numpy ImportError when importing lightweight_mmm?*
4 | If you get an error like:
5 |
6 | ```{python}
7 | ImportError: numpy.core.multiarray failed to import
8 | ```
9 |
10 | Then this probably means you have the wrong version of numpy.
11 | The correct version should be installed by pip when you install lightweight_mmm, but often this error happens when you are using Google Colab or a Jupyter notebook and installing directly from a notebook cell.
12 | To resolve this, make sure you have restarted your Google Colab / Jupyter session after running `!pip install lightweight_mmm`.
13 | This makes sure you have the newly installed version of numpy imported.
14 |
15 | ## Input
16 | *Which media channel metrics can be used as input?*
17 |
18 | You can use impressions, clicks or cost; cost is especially useful for non-digital data. For TV you could, for example, use TV rating points or cost. The model only takes the variation within a channel into account.
19 |
20 | *Can I run MMM at campaign level?*
21 |
22 | We generally don't recommend this. MMM is a macro tool that works well at the channel level. If you use distinct campaigns that have hard starts and stops, you risk losing the memory of the adstock effect.
23 | If you are interested in more granular insights, we recommend data-driven multi-touch attribution for your digital channels. You can find an example open-source package [here](https://github.com/google/fractribution/tree/master/py).
24 |
25 | *What is the ideal ratio between %train and %test data for LMMM?*
26 |
27 | Remember that we do not treat LMMM as a forecasting model, so test data is not always needed. When opting for a train / test split, you can use at least 13 weeks for the test set. However, we recommend refraining from a too-long testing period; consider instead running a separate model that is specifically built for forecasting.
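
As a rough sketch of such a split (the array shapes here are illustrative assumptions; the only requirement is that time is the leading axis):

```python
import jax.numpy as jnp

# Illustrative placeholders: 104 weeks of data for 3 media channels.
media_data = jnp.ones((104, 3))
target = jnp.ones(104)

# Hold out the last 13 weeks for testing.
split_point = media_data.shape[0] - 13
media_data_train, media_data_test = media_data[:split_point], media_data[split_point:]
target_train, target_test = target[:split_point], target[split_point:]
```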
28 |
29 | *What are best practices for lead generating businesses, with long sales cycles?*
30 |
31 | It really depends on your target variable, i.e. what outcome you would like to measure. If generating a lead takes multiple months, you can instead use more immediate KPIs such as the number of conversions, site visits or form entries.
32 |
33 |
34 |
35 | ## Modelling
36 |
37 |
38 | *What can I do if the baseline is too low and total media contribution is too high?*
39 |
40 | You can try various things:
41 | 1) You can include non-media variables.
42 | 2) You can lower the prior for the beta (in front of the transformed media).
43 | 3) You can set a bigger prior for the intercept.
44 |
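For points 2 and 3, a minimal sketch using the `custom_priors` argument of `LightweightMMM.fit` might look as follows (the input shapes and prior values are illustrative assumptions, not recommendations):

```python
import jax.numpy as jnp
import numpyro.distributions as dist

from lightweight_mmm import lightweight_mmm

# Illustrative placeholders: 104 weeks x 3 channels.
media_data_train = jnp.ones((104, 3))
target_train = jnp.ones(104)
costs = jnp.array([1_000., 2_000., 3_000.])

mmm = lightweight_mmm.LightweightMMM(model_name="carryover")
mmm.fit(
    media=media_data_train,
    # (2) The priors on the media betas are informed by these values,
    # so scaling them down lowers the prior on the media coefficients.
    media_prior=costs * 0.5,
    target=target_train,
    # (3) A wider HalfNormal prior on the intercept gives the baseline
    # more room; the scale of 5. is an assumed value for illustration.
    custom_priors={"intercept": dist.HalfNormal(scale=5.)},
)
```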
45 |
46 | *What are the different ways we can inform the media priors?*
47 |
48 | By default, the media priors are informed by the costs: channels with more spend get bigger priors.
49 | You can also base media priors on (geo) experiments or use a heuristic like "the percentage of weeks a channel was used". The intuition behind this is that the more a channel is used, the higher the marketer believes its contribution to be.
50 | Outputs from multi-touch attribution (MTA) can also be used as priors for an MMM.
51 | Think with Google has recently published an [article](https://www.thinkwithgoogle.com/_qs/documents/13385/TwGxOP_Unified_Marketing_Measurement.pdf) on combining results from MTA and MMM.
52 |
53 |
54 | *How should I refresh my model and how often?*
55 |
56 | This depends on the data frequency (daily, weekly) but also on the time frame in which the marketer makes decisions. If decisions are quarterly, we'd recommend running the model each quarter.
57 | The data window can be expanded each time, so that older data still has an influence on the most recent estimate. Alternatively, old data can also be discarded, for instance if media effectiveness and strategies have changed more drastically over time. Note however that you can always use the posteriors from a previous modeling cycle as priors when you refresh the model.
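
As a hedged sketch of that last point, reusing the fitted `mmm` from the sketch above (it assumes `mmm.trace` holds the posterior samples, and approximating the intercept posterior with a Normal is a simplifying assumption):

```python
import numpyro.distributions as dist

# Approximate last cycle's intercept posterior with a Normal distribution
# and pass it as the prior when fitting the refreshed model.
intercept_samples = mmm.trace["intercept"]
custom_priors = {
    "intercept": dist.Normal(loc=intercept_samples.mean(),
                             scale=intercept_samples.std())
}
# refreshed_mmm.fit(..., custom_priors=custom_priors)
```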
58 |
59 | *Why is your model additive?*
60 |
61 | We might make the model multiplicative in future versions, but to keep things simple and lightweight we have opted for the additive model for now.
62 |
63 | *How does MCMC sampling work in LMMM?*
64 |
65 | LMMM uses the [NUTS algorithm](https://mc-stan.org/docs/2_18/stan-users-guide/sampling-difficulties-with-problematic-priors.html) to sample the posterior distribution of the model parameters. NUTS only cares about the priors and the posterior for each parameter and uses all of the data.
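
For intuition, here is a minimal NumPyro sketch of that sampling pattern, mirroring the NUTS/MCMC usage in the library's own tests (the toy model is a stand-in for LMMM's actual media mix model):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def toy_model(data):
  # Prior on the parameter of interest.
  mu = numpyro.sample("mu", dist.Normal(loc=0., scale=1.))
  # Likelihood: every observation informs the posterior.
  numpyro.sample("obs", dist.Normal(loc=mu, scale=1.), obs=data)


kernel = numpyro.infer.NUTS(model=toy_model)
mcmc = numpyro.infer.MCMC(
    sampler=kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(0), data=jnp.ones(50))
posterior_mu = mcmc.get_samples()["mu"]  # draws from the posterior of mu
```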
66 |
67 | ## Evaluation
68 |
69 | *How important is OOS predictive performance?*
70 |
71 | Remember that MMM is not a forecasting model but a contribution and optimisation tool. Test performance should be looked at, but contribution is more important.
72 |
73 | *Which metric is recommended for evaluating goodness of fit on test data?*
74 |
75 | We recommend looking at MAPE or median APE instead of the R-squared metric, as those are more interpretable from a business perspective and less influenced by outliers.
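
Both metrics are straightforward to compute on a hold-out set (a sketch with placeholder arrays; in practice the point predictions would come from the model's out-of-sample forecasts):

```python
import jax.numpy as jnp

# Placeholder hold-out values and point predictions.
target_test = jnp.array([100., 120., 90., 110.])
predictions = jnp.array([95., 130., 85., 105.])

ape = jnp.abs((target_test - predictions) / target_test)
mape = 100 * ape.mean()             # mean absolute percentage error
median_ape = 100 * jnp.median(ape)  # median absolute percentage error
```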
76 |
77 | *How is media effectiveness defined?*
78 |
79 | Media effectiveness shows you the percentage that each media channel contributes to the target variable (e.g. y := sum of sales).
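
In other words (a toy numerical example with made-up numbers):

```python
import jax.numpy as jnp

# Assumed per-channel contributions to the KPI and the KPI total.
channel_contribution = jnp.array([20., 30., 10.])  # summed per channel
total_kpi = 200.                                   # e.g. sum of sales
media_effectiveness = 100 * channel_contribution / total_kpi
# -> [10., 15., 5.] percent of the target per channel
```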
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/google/lightweight_mmm/tree/main/docs
2 |
3 | LightweightMMM Documentation
4 | ============================
5 |
6 | LightweightMMM 🦇 is a lightweight Bayesian Marketing Mix Modeling (MMM)
7 | library that allows users to easily train MMMs and obtain channel attribution
8 | information. The library also includes capabilities for optimizing media
9 | allocation as well as plotting common graphs in the field.
10 |
11 | It is built in Python 3 and makes use of NumPyro and JAX.
12 |
13 | Installation
14 | ------------
15 |
16 | We have kept JAX as part of the dependencies to install when installing
17 | LightweightMMM; however, if you wish to install a different version of JAX or
18 | jaxlib for specific CUDA/CuDNN versions, see
19 | https://github.com/jax-ml/jax?tab=readme-ov-file#instructions for instructions on installing
20 | JAX. Otherwise our installation assumes a CPU setup.
21 |
22 | The recommended way of installing lightweight_mmm is through PyPI:
23 |
24 | ``pip install --upgrade pip``
25 | ``pip install lightweight_mmm``
26 |
27 | If you want to use the most recent and slightly less stable version you can install it from github:
28 |
29 | ``pip install --upgrade git+https://github.com/google/lightweight_mmm.git``
30 |
31 |
32 | .. toctree::
33 | :caption: Model Documentation
34 | :maxdepth: 1
35 |
36 | models
37 |
38 | .. toctree::
39 | :caption: Custom priors
40 | :maxdepth: 1
41 |
42 | custom_priors
43 |
44 | .. toctree::
45 | :caption: API Documentation
46 | :maxdepth: 2
47 |
48 | api
49 |
50 | .. toctree::
51 | :caption: FAQ
52 | :maxdepth: 2
53 |
54 | faq
55 |
56 | Contribute
57 | ----------
58 |
59 | - Issue tracker: https://github.com/google/lightweight_mmm/issues
60 | - Source code: https://github.com/google/lightweight_mmm/tree/main
61 |
62 | Support
63 | -------
64 |
65 | If you are having issues, please let us know by filing an issue on our
66 | `issue tracker <https://github.com/google/lightweight_mmm/issues>`_.
67 |
68 | License
69 | -------
70 |
71 | LightweightMMM is licensed under the Apache 2.0 License.
72 |
73 | Indices and tables
74 | ==================
75 |
76 | * :ref:`genindex`
--------------------------------------------------------------------------------
/docs/models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "YTW94aX0fWCQ"
7 | },
8 | "source": [
9 | "# LightweightMMM Models.\n",
10 | "\n",
11 | "The LightweightMMM can either be run using data aggregated at the national level (standard approach) or using data aggregated at a geo level (sub-national hierarchical approach). These models are documented below.\n",
12 | "\n",
13 | "## National level (standard approach)\n",
14 | "\n",
15 | "All the parameters in our Bayesian model have priors which have been set based on simulated studies that produce stable results. We also set out our three different approaches to saturation and media lagging effects: carryover (with exponent), adstock (with exponent) and hill adstock. Please see Jin, Y. et al., (2017) for more details on these models and choice of priors.\n",
16 | "\n",
17 | " $$kpi_{t} = \\alpha + trend_{t} + seasonality_{t} + media\\ channels_{t} + other\\ factors_{t} $$\n",
18 | "\n",
19 | "*Intercept*\u003cbr\u003e\n",
20 | "- $\\alpha \\sim HalfNormal(2)$\n",
21 | "\n",
22 | "*Trend*\u003cbr\u003e\n",
23 | "- $trend_{t} = \\mu t^{\\kappa}$\u003cbr\u003e\n",
24 | "- $\\mu \\sim Normal(0,1)$\u003cbr\u003e\n",
25 | "- $\\kappa \\sim Uniform(0.5,1.5)$\u003cbr\u003e\n",
26 | "- Where $t$ is a linear trend input\n",
27 | "\n",
28 | "*Seasonality (for models using* **weekly observations**)\u003cbr\u003e\n",
29 | "- $seasonality_{t} = \\displaystyle\\sum_{d=1}^{2} (\\gamma_{1,d} cos(\\frac{2 \\pi dt}{52}) + \\gamma_{2,d} sin(\\frac{2 \\pi dt}{52}))$\u003cbr\u003e\n",
30 | "- $\\gamma_{1,d}, \\gamma_{2,d} \\sim Normal(0,1)$\u003cbr\u003e\n",
31 | "\n",
32 | "*Seasonality (for models using* **daily observations**)\u003cbr\u003e\n",
33 | "- $seasonality_{t} = \\displaystyle\\sum_{d=1}^{2} (\\gamma_{1,d} cos(\\frac{2 \\pi dt}{365}) + \\gamma_{2,d} sin(\\frac{2 \\pi dt}{365})) + \\delta_{t \\bmod 7}$\u003cbr\u003e\n",
34 | "- $\\gamma_{1,d}, \\gamma_{2,d} \\sim Normal(0,1)$\u003cbr\u003e\n",
35 | "- $\\delta_{i} \\sim Normal(0,0.5)$\u003cbr\u003e\n",
36 | "\n",
37 | "*Other factors*\u003cbr\u003e\n",
38 | "- $other\\ factors_{t} = \\displaystyle\\sum_{i=1}^{N} \\lambda_{i}Z_{it}$\u003cbr\u003e\n",
39 | "- $\\lambda_{i} \\sim Normal(0,1)$\u003cbr\u003e\n",
40 | "- Where $Z_{i}$ are other factors and $N$ is the number of other factors.\n",
41 | "\n",
42 | "*Media Effect*\u003cbr\u003e\n",
43 | "- $\\beta_{m} \\sim HalfNormal(v_{m})$\u003cbr\u003e\n",
44 | "- Where $v_{m}$ is a scalar equal to the sum of the total cost of media channel $m$.\n",
45 | "\n",
46 | "*Media Channels (for the* **carryover** *model)*\u003cbr\u003e\n",
47 | "- $media\\ channels_{t} = x_{t,m}^{*\\rho_{m}}$\u003cbr\u003e\n",
48 | "- $x_{t,m}^{*} = \\frac{\\displaystyle\\sum_{l=0}^{L} \\tau_{m}^{(l-\\theta_{m})^2}x_{t-l,m}}{\\displaystyle\\sum_{l=0}^{L}\\tau_{m}^{(l-\\theta_{m})^2}}$ where $L=12$\u003cbr\u003e\n",
49 | "- $\\tau_{m} \\sim Beta(1,1)$\u003cbr\u003e\n",
50 | "- $\\theta_{m} \\sim HalfNormal(2)$\u003cbr\u003e\n",
51 | "- $\\rho_{m} \\sim Beta(9,1)$\n",
52 | "- Where $x_{t,m}$ is the media spend or impressions in week $t$ from media channel $m$\u003cbr\u003e\n",
53 | "\n",
54 | "*Media Channels (for the* **adstock** *model)*\u003cbr\u003e\n",
55 | "- $media\\ channels_{t} = x_{t,m,s}^{*\\rho_{m}}$\u003cbr\u003e\n",
56 | "- $x_{t,m,s}^{*} = \\frac{x_{t,m}^{*}}{1/(1-\\lambda_{m})}$\u003cbr\u003e\n",
57 | "- $x_{t,m}^{*} = x_{t,m} + \\lambda_{m} x_{t-1,m}^{*}$ where $t=2,..,N$\u003cbr\u003e\n",
58 | "- $x_{1,m}^{*} = x_{1,m}$\u003cbr\u003e\n",
59 | "- $\\lambda_{m} \\sim Beta(2,1)$\u003cbr\u003e\n",
60 | "- $\\rho_{m} \\sim Beta(9,1)$\n",
61 | "- Where $x_{t,m}$ is the media spend or impressions in week $t$ from media channel $m$\u003cbr\u003e\n",
62 | "\n",
63 | "*Media Channels (for the* **hill_adstock** *model)*\u003cbr\u003e\n",
64 | "- $media\\ channels_{t} = \\frac{1}{1+(x_{t,m}^{*} / K_{m})^{-S_{m}}}$\u003cbr\u003e\n",
65 | "- $x_{t,m}^{*} = x_{t,m} + \\lambda_{m} x_{t-1,m}^{*}$ where $t=2,..,N$\u003cbr\u003e\n",
66 | "- $x_{1,m}^{*} = x_{1,m}$\u003cbr\u003e\n",
67 | "- $K_{m} \\sim Gamma(1,1)$\u003cbr\u003e\n",
68 | "- $S_{m} \\sim Gamma(1,1)$\u003cbr\u003e\n",
69 | "- $\\lambda_{m} \\sim Beta(2,1)$\n",
70 | "- Where $x_{t,m}$ is the media spend or impressions in week $t$ from media channel $m$\u003cbr\u003e\n",
71 | "\n",
72 | "## Geo level (sub-national hierarchical approach)\n",
73 | "\n",
74 | "The hierarchical model is analogous to the standard model except there is an additional dimension of region. In the geo level model seasonality is learned at the sub-national level and at the national level. For more details on this model, please see Sun, Y et al., (2017)."
75 | ]
76 | }
77 | ],
78 | "metadata": {
79 | "colab": {
80 | "collapsed_sections": [],
81 | "name": "models.ipynb",
82 | "private_outputs": true,
83 | "provenance": []
84 | }
85 | },
86 | "nbformat": 4,
87 | "nbformat_minor": 0
88 | }
89 |
--------------------------------------------------------------------------------
/images/flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/flowchart.png
--------------------------------------------------------------------------------
/images/lightweight_mmm_logo_colored_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_1000.png
--------------------------------------------------------------------------------
/images/lightweight_mmm_logo_colored_250.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_250.png
--------------------------------------------------------------------------------
/images/lightweight_mmm_logo_colored_500.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_500.png
--------------------------------------------------------------------------------
/images/lightweight_mmm_logo_colored_750.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_750.png
--------------------------------------------------------------------------------
/lightweight_mmm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """LightweightMMM library.
16 |
17 | Detailed documentation and examples can be found in the
18 | [Github repository](https://github.com/google/lightweight_mmm).
19 | """
20 | __version__ = "0.1.9"
21 |
--------------------------------------------------------------------------------
/lightweight_mmm/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Parse absl flags for when it is run by github actions by pytest."""
16 |
17 | from absl import flags
18 |
19 |
20 | def pytest_configure(config):
21 | flags.FLAGS.mark_as_parsed()
22 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/baseline/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/baseline/intercept.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module for modeling the intercept."""
16 |
17 | from typing import Mapping
18 |
19 | import immutabledict
20 | import jax.numpy as jnp
21 | import numpyro
22 | from numpyro import distributions as dist
23 |
24 | from lightweight_mmm.core import core_utils
25 | from lightweight_mmm.core import priors
26 |
27 |
28 | def simple_intercept(
29 | data: jnp.ndarray,
30 | custom_priors: Mapping[str,
31 | dist.Distribution] = immutabledict.immutabledict(),
32 | ) -> jnp.ndarray:
33 |   """Calculates a national or geo intercept.
34 | Note that this intercept is constant over time.
35 |
36 | Args:
37 | data: Media input data. Media data must have either 2 dims for national
38 | model or 3 for geo models.
39 | custom_priors: The custom priors we want the model to take instead of the
40 | default ones. Refer to the full documentation on custom priors for
41 | details.
42 |
43 | Returns:
44 | The values of the intercept.
45 | """
46 | default_priors = priors.get_default_priors()
47 | n_geos = core_utils.get_number_geos(data=data)
48 |
49 | with numpyro.plate(name=f"{priors.INTERCEPT}_plate", size=n_geos):
50 | intercept = numpyro.sample(
51 | name=priors.INTERCEPT,
52 | fn=custom_priors.get(priors.INTERCEPT,
53 | default_priors[priors.INTERCEPT]),
54 | )
55 | return jnp.asarray(intercept)
56 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/baseline/intercept_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for intercept."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpyro
22 | from numpyro import handlers
23 | import numpyro.distributions as dist
24 |
25 | from lightweight_mmm.core import core_utils
26 | from lightweight_mmm.core import priors
27 | from lightweight_mmm.core.baseline import intercept
28 |
29 |
30 | class InterceptTest(parameterized.TestCase):
31 |
32 | @parameterized.named_parameters(
33 | dict(
34 | testcase_name="national",
35 | data_shape=(150, 3),
36 | ),
37 | dict(
38 | testcase_name="geo",
39 | data_shape=(150, 3, 5),
40 | ),
41 | )
42 | def test_simple_intercept_produces_output_correct_shape(self, data_shape):
43 |
44 | def mock_model_function(data):
45 | numpyro.deterministic(
46 | "intercept_values",
47 | intercept.simple_intercept(data=data, custom_priors={}))
48 |
49 | num_samples = 10
50 | data = jnp.ones(data_shape)
51 | n_geos = core_utils.get_number_geos(data=data)
52 | kernel = numpyro.infer.NUTS(model=mock_model_function)
53 | mcmc = numpyro.infer.MCMC(
54 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
55 | rng_key = jax.random.PRNGKey(0)
56 |
57 | mcmc.run(rng_key, data=data)
58 | intercept_values = mcmc.get_samples()["intercept_values"]
59 |
60 | self.assertEqual(intercept_values.shape, (num_samples, n_geos))
61 |
62 | @parameterized.named_parameters(
63 | dict(
64 | testcase_name="national",
65 | data_shape=(150, 3),
66 | ),
67 | dict(
68 | testcase_name="geo",
69 | data_shape=(150, 3, 5),
70 | ),
71 | )
72 | def test_simple_intercept_takes_custom_priors_correctly(self, data_shape):
73 | prior_name = priors.INTERCEPT
74 | expected_value1, expected_value2 = 5.2, 7.56
75 | custom_priors = {
76 | prior_name:
77 | dist.Kumaraswamy(
78 | concentration1=expected_value1, concentration0=expected_value2)
79 | }
80 | media = jnp.ones(data_shape)
81 |
82 | trace_handler = handlers.trace(
83 | handlers.seed(intercept.simple_intercept, rng_seed=0))
84 | trace = trace_handler.get_trace(data=media, custom_priors=custom_priors)
85 | values_and_dists = {
86 | name: site["fn"] for name, site in trace.items() if "fn" in site
87 | }
88 |
89 | used_distribution = values_and_dists[prior_name].base_dist
90 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
91 | self.assertEqual(used_distribution.concentration0, expected_value2)
92 | self.assertEqual(used_distribution.concentration1, expected_value1)
93 |
94 |
95 | if __name__ == "__main__":
96 | absltest.main()
97 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/core_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Sets of utilities used across the core components of LightweightMMM."""
16 |
17 | import sys
18 | from typing import Any, Mapping, Tuple, Union
19 |
20 | import jax.numpy as jnp
21 |
22 | from numpyro import distributions as dist
23 |
24 | # pylint: disable=g-import-not-at-top
25 | if sys.version_info >= (3, 8):
26 | from typing import Protocol
27 | else:
28 | from typing_extensions import Protocol
29 |
30 |
31 | class TransformFunction(Protocol):
32 |
33 | def __call__(
34 | self,
35 | data: jnp.ndarray,
36 | custom_priors: Mapping[str, dist.Distribution],
37 | prefix: str,
38 | **kwargs: Any,
39 | ) -> jnp.ndarray:
40 | ...
41 |
42 |
43 | class Module(Protocol):
44 |
45 | def __call__(
46 | self,
47 | *args: Any,
48 | **kwargs: Any,
49 | ) -> jnp.ndarray:
50 | ...
51 |
52 |
53 | def get_number_geos(data: jnp.ndarray) -> int:
54 | return data.shape[2] if data.ndim == 3 else 1
55 |
56 |
57 | def get_geo_shape(data: jnp.ndarray) -> Union[Tuple[int], Tuple[()]]:
58 | return (data.shape[2],) if data.ndim == 3 else ()
59 |
60 |
61 | def apply_exponent_safe(data: jnp.ndarray,
62 | exponent: jnp.ndarray) -> jnp.ndarray:
63 | """Applies an exponent to given data in a gradient safe way.
64 |
65 | More info on the double jnp.where can be found:
66 | https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
67 |
68 | Args:
69 | data: Input data to use.
70 | exponent: Exponent required for the operations.
71 |
72 | Returns:
73 | The result of the exponent operation with the inputs provided.
74 | """
75 | exponent_safe = jnp.where(data == 0, 1, data) ** exponent
76 | return jnp.where(data == 0, 0, exponent_safe)
77 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/core_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for core_utils."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from lightweight_mmm.core import core_utils
22 |
23 |
24 | class CoreUtilsTest(absltest.TestCase):
25 |
26 | def test_apply_exponent_safe_produces_same_exponent_results(self):
27 | data = jnp.arange(50).reshape((10, 5))
28 | exponent = jnp.full(5, 0.5)
29 |
30 | output = core_utils.apply_exponent_safe(data=data, exponent=exponent)
31 |
32 | np.testing.assert_array_equal(output, data**exponent)
33 |
34 | def test_apply_exponent_safe_produces_correct_shape(self):
35 | data = jnp.ones((10, 5))
36 | exponent = jnp.full(5, 0.5)
37 |
38 | output = core_utils.apply_exponent_safe(data=data, exponent=exponent)
39 |
40 | self.assertEqual(output.shape, data.shape)
41 |
42 | def test_apply_exponent_safe_produces_non_nan_or_inf_grads(self):
43 |
44 | def f_safe(data, exponent):
45 | x = core_utils.apply_exponent_safe(data=data, exponent=exponent)
46 | return x.sum()
47 |
48 | data = jnp.ones((10, 5))
49 | data = data.at[0, 0].set(0.)
50 | exponent = jnp.full(5, 0.5)
51 |
52 | grads = jax.grad(f_safe)(data, exponent)
53 |
54 | self.assertFalse(np.isnan(grads).any())
55 | self.assertFalse(np.isinf(grads).any())
56 |
57 |
58 | if __name__ == '__main__':
59 | absltest.main()
60 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/priors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Sets priors and prior-related constants for LMMM."""
16 | from typing import Mapping
17 |
18 | import immutabledict
19 | from numpyro import distributions as dist
20 |
21 | # Core model priors
22 | INTERCEPT = "intercept"
23 | COEF_TREND = "coef_trend"
24 | EXPO_TREND = "expo_trend"
25 | SIGMA = "sigma"
26 | GAMMA_SEASONALITY = "gamma_seasonality"
27 | WEEKDAY = "weekday"
28 | COEF_EXTRA_FEATURES = "coef_extra_features"
29 | COEF_SEASONALITY = "coef_seasonality"
30 |
31 | # Lagging priors
32 | LAG_WEIGHT = "lag_weight"
33 | AD_EFFECT_RETENTION_RATE = "ad_effect_retention_rate"
34 | PEAK_EFFECT_DELAY = "peak_effect_delay"
35 |
36 | # Saturation priors
37 | EXPONENT = "exponent"
38 | HALF_MAX_EFFECTIVE_CONCENTRATION = "half_max_effective_concentration"
39 | SLOPE = "slope"
40 |
41 | # Dynamic trend priors
42 | DYNAMIC_TREND_INITIAL_LEVEL = "dynamic_trend_initial_level"
43 | DYNAMIC_TREND_INITIAL_SLOPE = "dynamic_trend_initial_slope"
44 | DYNAMIC_TREND_LEVEL_VARIANCE = "dynamic_trend_level_variance"
45 | DYNAMIC_TREND_SLOPE_VARIANCE = "dynamic_trend_slope_variance"
46 |
47 | MODEL_PRIORS_NAMES = frozenset((
48 | INTERCEPT,
49 | COEF_TREND,
50 | EXPO_TREND,
51 | SIGMA,
52 | GAMMA_SEASONALITY,
53 | WEEKDAY,
54 | COEF_EXTRA_FEATURES,
55 | COEF_SEASONALITY,
56 | LAG_WEIGHT,
57 | AD_EFFECT_RETENTION_RATE,
58 | PEAK_EFFECT_DELAY,
59 | EXPONENT,
60 | HALF_MAX_EFFECTIVE_CONCENTRATION,
61 | SLOPE,
62 | DYNAMIC_TREND_INITIAL_LEVEL,
63 | DYNAMIC_TREND_INITIAL_SLOPE,
64 | DYNAMIC_TREND_LEVEL_VARIANCE,
65 | DYNAMIC_TREND_SLOPE_VARIANCE,
66 | ))
67 |
68 | GEO_ONLY_PRIORS = frozenset((COEF_SEASONALITY,))
69 |
70 |
71 | def get_default_priors() -> Mapping[str, dist.Distribution]:
72 |   # Since JAX cannot be called before absl.app.run() in tests, the default
73 |   # priors are built inside a function rather than at module import time.
74 | return immutabledict.immutabledict({
75 | INTERCEPT: dist.HalfNormal(scale=2.),
76 | COEF_TREND: dist.Normal(loc=0., scale=1.),
77 | EXPO_TREND: dist.Uniform(low=0.5, high=1.5),
78 | SIGMA: dist.Gamma(concentration=1., rate=1.),
79 | GAMMA_SEASONALITY: dist.Normal(loc=0., scale=1.),
80 | WEEKDAY: dist.Normal(loc=0., scale=.5),
81 | COEF_EXTRA_FEATURES: dist.Normal(loc=0., scale=1.),
82 | COEF_SEASONALITY: dist.HalfNormal(scale=.5),
83 | AD_EFFECT_RETENTION_RATE: dist.Beta(concentration1=1., concentration0=1.),
84 | PEAK_EFFECT_DELAY: dist.HalfNormal(scale=2.),
85 | EXPONENT: dist.Beta(concentration1=9., concentration0=1.),
86 | LAG_WEIGHT: dist.Beta(concentration1=2., concentration0=1.),
87 | HALF_MAX_EFFECTIVE_CONCENTRATION: dist.Gamma(concentration=1., rate=1.),
88 | SLOPE: dist.Gamma(concentration=1., rate=1.),
89 | DYNAMIC_TREND_INITIAL_LEVEL: dist.Normal(loc=.5, scale=2.5),
90 | DYNAMIC_TREND_INITIAL_SLOPE: dist.Normal(loc=0., scale=.2),
91 | DYNAMIC_TREND_LEVEL_VARIANCE: dist.Uniform(low=0., high=.1),
92 | DYNAMIC_TREND_SLOPE_VARIANCE: dist.Uniform(low=0., high=.01),
93 | })
94 |
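A usage sketch of the override pattern these constants support: callers may pass any subset of MODEL_PRIORS_NAMES in custom_priors, and each sampling site falls back to the defaults via dict.get, exactly as in the core modules.

from numpyro import distributions as dist
from lightweight_mmm.core import priors

custom_priors = {priors.INTERCEPT: dist.HalfNormal(scale=1.)}
default_priors = priors.get_default_priors()
# The same fallback lookup used at every sampling site:
intercept_prior = custom_priors.get(
    priors.INTERCEPT, default_priors[priors.INTERCEPT])  # custom HalfNormal(1.)
sigma_prior = custom_priors.get(
    priors.SIGMA, default_priors[priors.SIGMA])          # default Gamma(1., 1.)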
--------------------------------------------------------------------------------
/lightweight_mmm/core/time/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/time/seasonality.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Core and modelling functions for seasonality."""
16 |
17 | from typing import Mapping
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | import numpyro
22 | from numpyro import distributions as dist
23 |
24 | from lightweight_mmm.core import core_utils
25 | from lightweight_mmm.core import priors
26 |
27 |
28 | @jax.jit
29 | def _sinusoidal_seasonality(
30 | seasonality_arange: jnp.ndarray,
31 | degrees_arange: jnp.ndarray,
32 | gamma_seasonality: jnp.ndarray,
33 | frequency: int,
34 | ) -> jnp.ndarray:
35 | """Core calculation of cyclic variation seasonality.
36 |
37 | Args:
38 | seasonality_arange: Array with range [0, N - 1] where N is the size of the
39 | data for which the seasonality is modelled.
40 |     degrees_arange: Array with range [0, D - 1] where D is the number of
41 |       degrees to use. Must be greater than or equal to 1.
42 |     gamma_seasonality: Factor multiplied into each degree's calculation. Its
43 |       shape must be aligned with the number of degrees.
44 |     frequency: Frequency of the seasonality being computed.
45 |
46 | Returns:
47 | An array with the seasonality values.
48 | """
49 | inner_value = seasonality_arange * 2 * jnp.pi * degrees_arange / frequency
50 | season_matrix_sin = jnp.sin(inner_value)
51 | season_matrix_cos = jnp.cos(inner_value)
52 | season_matrix = jnp.concatenate([
53 | jnp.expand_dims(a=season_matrix_sin, axis=-1),
54 | jnp.expand_dims(a=season_matrix_cos, axis=-1)
55 | ],
56 | axis=-1)
57 | return jnp.einsum("tds, ds -> t", season_matrix, gamma_seasonality)
58 |
59 |
60 | def sinusoidal_seasonality(
61 | data: jnp.ndarray,
62 | custom_priors: Mapping[str, dist.Distribution],
63 | *,
64 | degrees_seasonality: int = 2,
65 | frequency: int = 52,
66 | ) -> jnp.ndarray:
67 | """Calculates cyclic variation seasonality.
68 |
69 | For detailed info check:
70 | https://en.wikipedia.org/wiki/Seasonality#Modeling
71 |
72 | Args:
73 |     data: Data for which the seasonality will be modelled. It is used to
74 | obtain the length of the time dimension, axis 0.
75 | custom_priors: The custom priors we want the model to take instead of
76 | default ones.
77 |     degrees_seasonality: Number of degrees to use. Must be greater than or
78 |       equal to 1.
79 |     frequency: Frequency of the seasonality being computed. The default is 52
80 |       for weekly data (52 weeks in a year).
81 |
82 | Returns:
83 | An array with the seasonality values.
84 | """
85 | number_periods = data.shape[0]
86 | default_priors = priors.get_default_priors()
87 | n_geos = core_utils.get_number_geos(data=data)
88 | with numpyro.plate(name=f"{priors.GAMMA_SEASONALITY}_sin_cos_plate", size=2):
89 | with numpyro.plate(
90 | name=f"{priors.GAMMA_SEASONALITY}_plate", size=degrees_seasonality):
91 | gamma_seasonality = numpyro.sample(
92 | name=priors.GAMMA_SEASONALITY,
93 | fn=custom_priors.get(priors.GAMMA_SEASONALITY,
94 | default_priors[priors.GAMMA_SEASONALITY]))
95 | seasonality_arange = jnp.expand_dims(a=jnp.arange(number_periods), axis=-1)
96 | degrees_arange = jnp.arange(degrees_seasonality)
97 | seasonality_values = _sinusoidal_seasonality(
98 | seasonality_arange=seasonality_arange,
99 | degrees_arange=degrees_arange,
100 | frequency=frequency,
101 | gamma_seasonality=gamma_seasonality,
102 | )
103 | if n_geos > 1:
104 | seasonality_values = jnp.expand_dims(seasonality_values, axis=-1)
105 | return seasonality_values
106 |
107 |
108 | def _intra_week_seasonality(
109 | data: jnp.ndarray,
110 | weekday: jnp.ndarray,
111 | ) -> jnp.ndarray:
112 | data_size = data.shape[0]
113 | return weekday[jnp.arange(data_size) % 7]
114 |
115 |
116 | def intra_week_seasonality(
117 | data: jnp.ndarray,
118 | custom_priors: Mapping[str, dist.Distribution],
119 | ) -> jnp.ndarray:
120 |   """Models intra-week seasonality.
121 |
122 |     data: Data for which the seasonality will be modelled. It is used to
123 | data: Data for which the seasonality will be modelled for. It is used to
124 | obtain the length of the time dimension, axis 0.
125 | custom_priors: The custom priors we want the model to take instead of
126 | default ones.
127 |
128 | Returns:
129 | The contribution of the weekday seasonality.
130 | """
131 | default_priors = priors.get_default_priors()
132 | with numpyro.plate(name=f"{priors.WEEKDAY}_plate", size=7):
133 | weekday = numpyro.sample(
134 | name=priors.WEEKDAY,
135 | fn=custom_priors.get(priors.WEEKDAY, default_priors[priors.WEEKDAY]))
136 |
137 | weekday_series = _intra_week_seasonality(data=data, weekday=weekday)
138 |
139 | if data.ndim == 3: # For geo model's case
140 | weekday_series = jnp.expand_dims(weekday_series, axis=-1)
141 |
142 | return weekday_series
143 |
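A minimal sketch of the core computation, calling the private helper directly with the same shapes sinusoidal_seasonality prepares (illustration only; the public function samples gamma_seasonality instead of taking it as input).

import jax.numpy as jnp
from lightweight_mmm.core.time import seasonality

t = jnp.expand_dims(jnp.arange(104), axis=-1)  # two years of weekly periods
degrees = jnp.arange(2)                        # D = 2 degrees
gamma = jnp.ones((2, 2))                       # one (sin, cos) weight per degree
values = seasonality._sinusoidal_seasonality(
    seasonality_arange=t,
    degrees_arange=degrees,
    gamma_seasonality=gamma,
    frequency=52,
)
assert values.shape == (104,)  # one seasonality value per time period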
--------------------------------------------------------------------------------
/lightweight_mmm/core/time/seasonality_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for seasonality."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpyro
22 | from numpyro import distributions as dist
23 | from numpyro import handlers
24 |
25 | from lightweight_mmm.core import priors
26 | from lightweight_mmm.core.time import seasonality
27 |
28 |
29 | class SeasonalityTest(parameterized.TestCase):
30 |
31 | @parameterized.named_parameters([
32 | dict(
33 | testcase_name="2_degrees",
34 | seasonality_arange_value=150,
35 | degrees_arange_shape=5,
36 | gamma_seasonality_shape=(5, 2),
37 | ),
38 | dict(
39 |           testcase_name="10_degrees",
40 | seasonality_arange_value=150,
41 | degrees_arange_shape=10,
42 | gamma_seasonality_shape=(10, 2),
43 | ),
44 | dict(
45 | testcase_name="1_degree",
46 | seasonality_arange_value=200,
47 | degrees_arange_shape=1,
48 | gamma_seasonality_shape=(1, 2),
49 | ),
50 | ])
51 | def test_core_sinusoidal_seasonality_produces_correct_shape(
52 | self, seasonality_arange_value, degrees_arange_shape,
53 | gamma_seasonality_shape):
54 | seasonality_arange = jnp.expand_dims(
55 | jnp.arange(seasonality_arange_value), axis=-1)
56 | degrees_arange = jnp.arange(degrees_arange_shape)
57 | gamma_seasonality = jnp.ones(gamma_seasonality_shape)
58 |
59 | seasonality_values = seasonality._sinusoidal_seasonality(
60 | seasonality_arange=seasonality_arange,
61 | degrees_arange=degrees_arange,
62 | gamma_seasonality=gamma_seasonality,
63 | frequency=52,
64 | )
65 | self.assertEqual(seasonality_values.shape, (seasonality_arange_value,))
66 |
67 | @parameterized.named_parameters(
68 | dict(
69 | testcase_name="ten_degrees_national",
70 | data_shape=(500, 5),
71 | degrees_seasonality=10,
72 | expected_shape=(10, 500),
73 | ),
74 | dict(
75 | testcase_name="ten_degrees_geo",
76 | data_shape=(500, 5, 5),
77 | degrees_seasonality=10,
78 | expected_shape=(10, 500, 1),
79 | ),
80 | dict(
81 |           testcase_name="one_degree_national",
82 | data_shape=(500, 5),
83 | degrees_seasonality=1,
84 | expected_shape=(10, 500),
85 | ),
86 | dict(
87 |           testcase_name="one_degree_geo",
88 | data_shape=(500, 5, 5),
89 | degrees_seasonality=1,
90 | expected_shape=(10, 500, 1),
91 | ),
92 | )
93 | def test_model_sinusoidal_seasonality_produces_correct_shape(
94 | self, data_shape, degrees_seasonality, expected_shape):
95 |
96 | def mock_model_function(data, degrees_seasonality, frequency):
97 | numpyro.deterministic(
98 | "seasonality",
99 | seasonality.sinusoidal_seasonality(
100 | data=data,
101 | degrees_seasonality=degrees_seasonality,
102 | custom_priors={},
103 | frequency=frequency))
104 |
105 | num_samples = 10
106 | data = jnp.ones(data_shape)
107 | kernel = numpyro.infer.NUTS(model=mock_model_function)
108 | mcmc = numpyro.infer.MCMC(
109 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
110 | rng_key = jax.random.PRNGKey(0)
111 |
112 | mcmc.run(
113 | rng_key,
114 | data=data,
115 | degrees_seasonality=degrees_seasonality,
116 | frequency=52,
117 | )
118 | seasonality_values = mcmc.get_samples()["seasonality"]
119 |
120 | self.assertEqual(seasonality_values.shape, expected_shape)
121 |
122 | def test_sinusoidal_seasonality_custom_priors_are_taken_correctly(self):
123 | prior_name = priors.GAMMA_SEASONALITY
124 | expected_value1, expected_value2 = 5.2, 7.56
125 | custom_priors = {
126 | prior_name:
127 | dist.Kumaraswamy(
128 | concentration1=expected_value1, concentration0=expected_value2)
129 | }
130 | media = jnp.ones((10, 5, 5))
131 | degrees_seasonality = 3
132 | frequency = 365
133 |
134 | trace_handler = handlers.trace(
135 | handlers.seed(seasonality.sinusoidal_seasonality, rng_seed=0))
136 | trace = trace_handler.get_trace(
137 | data=media,
138 | custom_priors=custom_priors,
139 | degrees_seasonality=degrees_seasonality,
140 | frequency=frequency,
141 | )
142 | values_and_dists = {
143 | name: site["fn"] for name, site in trace.items() if "fn" in site
144 | }
145 |
146 | used_distribution = values_and_dists[prior_name]
147 | if isinstance(used_distribution, dist.ExpandedDistribution):
148 | used_distribution = used_distribution.base_dist
149 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
150 | self.assertEqual(used_distribution.concentration0, expected_value2)
151 | self.assertEqual(used_distribution.concentration1, expected_value1)
152 |
153 | @parameterized.named_parameters(
154 | dict(
155 |           testcase_name="national",
156 | data_shape=(500, 3),
157 | expected_shape=(10, 500),
158 | ),
159 | dict(
160 |           testcase_name="geo",
161 | data_shape=(500, 3, 5),
162 | expected_shape=(10, 500, 1),
163 | ),
164 | )
165 | def test_intra_week_seasonality_produces_correct_shape(
166 | self, data_shape, expected_shape):
167 |
168 | def mock_model_function(data):
169 | numpyro.deterministic(
170 | "intra_week",
171 | seasonality.intra_week_seasonality(
172 | data=data,
173 | custom_priors={},
174 | ))
175 |
176 | num_samples = 10
177 | data = jnp.ones(data_shape)
178 | kernel = numpyro.infer.NUTS(model=mock_model_function)
179 | mcmc = numpyro.infer.MCMC(
180 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
181 | rng_key = jax.random.PRNGKey(0)
182 |
183 | mcmc.run(rng_key, data=data)
184 | seasonality_values = mcmc.get_samples()["intra_week"]
185 |
186 | self.assertEqual(seasonality_values.shape, expected_shape)
187 |
188 | def test_intra_week_seasonality_custom_priors_are_taken_correctly(self):
189 | prior_name = priors.WEEKDAY
190 | expected_value1, expected_value2 = 5.2, 7.56
191 | custom_priors = {
192 | prior_name:
193 | dist.Kumaraswamy(
194 | concentration1=expected_value1, concentration0=expected_value2)
195 | }
196 | media = jnp.ones((10, 5, 5))
197 |
198 | trace_handler = handlers.trace(
199 | handlers.seed(seasonality.intra_week_seasonality, rng_seed=0))
200 | trace = trace_handler.get_trace(
201 | data=media,
202 | custom_priors=custom_priors,
203 | )
204 | values_and_dists = {
205 | name: site["fn"] for name, site in trace.items() if "fn" in site
206 | }
207 |
208 | used_distribution = values_and_dists[prior_name]
209 | if isinstance(used_distribution, dist.ExpandedDistribution):
210 | used_distribution = used_distribution.base_dist
211 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
212 | self.assertEqual(used_distribution.concentration0, expected_value2)
213 | self.assertEqual(used_distribution.concentration1, expected_value1)
214 |
215 |
216 | if __name__ == "__main__":
217 | absltest.main()
218 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/time/trend.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Core and modelling functions for trend."""
16 |
17 | import functools
18 | from typing import Mapping
19 |
20 | import jax
21 | import jax.numpy as jnp
22 | import numpyro
23 | from numpyro import distributions as dist
24 |
25 | from lightweight_mmm.core import core_utils
26 | from lightweight_mmm.core import priors
27 |
28 |
29 | @jax.jit
30 | def _trend_with_exponent(coef_trend: jnp.ndarray, trend: jnp.ndarray,
31 | expo_trend: jnp.ndarray) -> jnp.ndarray:
32 | """Applies the coefficient and exponent to the trend to obtain trend values.
33 |
34 | Args:
35 | coef_trend: Coefficient to be multiplied by the trend.
36 | trend: Initial trend values.
37 | expo_trend: Exponent to be applied to the trend.
38 |
39 | Returns:
40 | The trend values generated.
41 | """
42 | return coef_trend * trend**expo_trend
43 |
44 |
45 | def trend_with_exponent(
46 | data: jnp.ndarray,
47 | custom_priors: Mapping[str, dist.Distribution],
48 | ) -> jnp.ndarray:
49 | """Trend with exponent for curvature.
50 |
51 | Args:
52 | data: Data for which trend will be created.
53 | custom_priors: The custom priors we want the model to take instead of the
54 | default ones. See our custom_priors documentation for details about the
55 | API and possible options.
56 |
57 | Returns:
58 | The values of the trend.
59 | """
60 | default_priors = priors.get_default_priors()
61 | n_geos = core_utils.get_number_geos(data=data)
62 | # TODO(): Force all geos to have the same trend sign.
63 | with numpyro.plate(name=f"{priors.COEF_TREND}_plate", size=n_geos):
64 | coef_trend = numpyro.sample(
65 | name=priors.COEF_TREND,
66 | fn=custom_priors.get(priors.COEF_TREND,
67 | default_priors[priors.COEF_TREND]))
68 |
69 | expo_trend = numpyro.sample(
70 | name=priors.EXPO_TREND,
71 | fn=custom_priors.get(priors.EXPO_TREND,
72 | default_priors[priors.EXPO_TREND]))
73 | linear_trend = jnp.arange(data.shape[0])
74 | if n_geos > 1: # For geo model's case
75 | linear_trend = jnp.expand_dims(linear_trend, axis=-1)
76 | return _trend_with_exponent(
77 | coef_trend=coef_trend, trend=linear_trend, expo_trend=expo_trend)
78 |
79 |
80 | @functools.partial(jax.jit, static_argnames=("number_periods",))
81 | def _dynamic_trend(
82 | number_periods: int,
83 | random_walk_level: jnp.ndarray,
84 | random_walk_slope: jnp.ndarray,
85 | initial_level: jnp.ndarray,
86 | initial_slope: jnp.ndarray,
87 | variance_level: jnp.ndarray,
88 | variance_slope: jnp.ndarray,
89 | ) -> jnp.ndarray:
90 | """Calculates dynamic trend using local linear trend method.
91 |
92 | More details about this function can be found in:
93 | https://storage.googleapis.com/pub-tools-public-publication-data/pdf/41854.pdf
94 |
95 | Args:
96 | number_periods: Number of time periods in the data.
97 |     random_walk_level: Sampled random walk innovations for the level.
98 |     random_walk_slope: Sampled random walk innovations for the slope.
99 | initial_level: The initial value for level in local linear trend model.
100 | initial_slope: The initial value for slope in local linear trend model.
101 |     variance_level: Variance of the expected change in level between periods.
102 |     variance_slope: Variance of the expected change in slope between periods.
103 |
104 | Returns:
105 | The dynamic trend values for the given data with the given parameters.
106 | """
107 | # Simulate gaussian random walk of level with initial level.
108 | random_level = variance_level * random_walk_level
109 | random_level_with_initial_level = jnp.concatenate(
110 | [jnp.array([random_level[0] + initial_level]), random_level[1:]])
111 | level_trend_t = jnp.cumsum(random_level_with_initial_level, axis=0)
112 | # Simulate gaussian random walk of slope with initial slope.
113 | random_slope = variance_slope * random_walk_slope
114 | random_slope_with_initial_slope = jnp.concatenate(
115 | [jnp.array([random_slope[0] + initial_slope]), random_slope[1:]])
116 | slope_trend_t = jnp.cumsum(random_slope_with_initial_slope, axis=0)
117 | # Accumulate sum of slope series to address latent variable slope in function
118 | # level_t = level_t-1 + slope_t-1.
119 | initial_zero_shape = [(1, 0)] if slope_trend_t.ndim == 1 else [(1, 0), (0, 0)]
120 | slope_trend_cumsum = jnp.pad(
121 | jnp.cumsum(slope_trend_t, axis=0)[:number_periods - 1],
122 | initial_zero_shape, mode="constant", constant_values=0)
123 | return level_trend_t + slope_trend_cumsum
124 |
125 |
126 | def dynamic_trend(
127 | geo_size: int,
128 | data_size: int,
129 | is_trend_prediction: bool,
130 | custom_priors: Mapping[str, dist.Distribution],
131 | ) -> jnp.ndarray:
132 |   """Generates the dynamic trend that captures the baseline of the KPI.
133 |
134 | Args:
135 | geo_size: Number of geos in the model.
136 | data_size: Number of time samples in the model.
137 | is_trend_prediction: Whether it is used for prediction or fitting.
138 | custom_priors: The custom priors we want the model to take instead of the
139 | default ones. See our custom_priors documentation for details about the
140 | API and possible options.
141 |
142 | Returns:
143 |     JAX array with the trend value for each time period.
144 | """
145 | default_priors = priors.get_default_priors()
146 | if not is_trend_prediction:
147 | random_walk_level = numpyro.sample("random_walk_level",
148 | fn=dist.Normal(),
149 | sample_shape=(data_size, 1))
150 | random_walk_slope = numpyro.sample("random_walk_slope",
151 | fn=dist.Normal(),
152 | sample_shape=(data_size, 1))
153 | else:
154 | random_walk_level = numpyro.sample("random_walk_level_prediction",
155 | fn=dist.Normal(),
156 | sample_shape=(data_size, 1))
157 |
158 | random_walk_slope = numpyro.sample("random_walk_slope_prediction",
159 | fn=dist.Normal(),
160 | sample_shape=(data_size, 1))
161 |
162 | with numpyro.plate(
163 | name=f"{priors.DYNAMIC_TREND_INITIAL_LEVEL}_plate", size=geo_size):
164 | trend_initial_level = numpyro.sample(
165 | name=priors.DYNAMIC_TREND_INITIAL_LEVEL,
166 | fn=custom_priors.get(
167 | priors.DYNAMIC_TREND_INITIAL_LEVEL,
168 | default_priors[priors.DYNAMIC_TREND_INITIAL_LEVEL]))
169 |
170 | with numpyro.plate(
171 | name=f"{priors.DYNAMIC_TREND_INITIAL_SLOPE}_plate", size=geo_size):
172 | trend_initial_slope = numpyro.sample(
173 | name=priors.DYNAMIC_TREND_INITIAL_SLOPE,
174 | fn=custom_priors.get(
175 | priors.DYNAMIC_TREND_INITIAL_SLOPE,
176 | default_priors[priors.DYNAMIC_TREND_INITIAL_SLOPE]))
177 |
178 | with numpyro.plate(
179 | name=f"{priors.DYNAMIC_TREND_LEVEL_VARIANCE}_plate", size=geo_size):
180 | trend_level_variance = numpyro.sample(
181 | name=priors.DYNAMIC_TREND_LEVEL_VARIANCE,
182 | fn=custom_priors.get(
183 | priors.DYNAMIC_TREND_LEVEL_VARIANCE,
184 | default_priors[priors.DYNAMIC_TREND_LEVEL_VARIANCE]))
185 |
186 | with numpyro.plate(
187 | name=f"{priors.DYNAMIC_TREND_SLOPE_VARIANCE}_plate", size=geo_size):
188 | trend_slope_variance = numpyro.sample(
189 | name=priors.DYNAMIC_TREND_SLOPE_VARIANCE,
190 | fn=custom_priors.get(
191 | priors.DYNAMIC_TREND_SLOPE_VARIANCE,
192 | default_priors[priors.DYNAMIC_TREND_SLOPE_VARIANCE]))
193 |
194 | if geo_size == 1: # National level model case.
195 | random_walk_level = jnp.squeeze(random_walk_level)
196 | random_walk_slope = jnp.squeeze(random_walk_slope)
197 | trend_initial_level = jnp.squeeze(trend_initial_level)
198 | trend_initial_slope = jnp.squeeze(trend_initial_slope)
199 | trend_level_variance = jnp.squeeze(trend_level_variance)
200 | trend_slope_variance = jnp.squeeze(trend_slope_variance)
201 |
202 | return _dynamic_trend(
203 | number_periods=data_size,
204 | random_walk_level=random_walk_level,
205 | random_walk_slope=random_walk_slope,
206 | initial_level=trend_initial_level,
207 | initial_slope=trend_initial_slope,
208 | variance_level=trend_level_variance,
209 | variance_slope=trend_slope_variance)
210 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/time/trend_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for trend."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | import numpyro
23 | from numpyro import distributions as dist
24 | from numpyro import handlers
25 |
26 | from lightweight_mmm.core import core_utils
27 | from lightweight_mmm.core import priors
28 | from lightweight_mmm.core.time import trend
29 |
30 |
31 | class TrendTest(parameterized.TestCase):
32 |
33 | @parameterized.named_parameters([
34 | dict(
35 | testcase_name="national",
36 | coef_trend_shape=(),
37 | trend_length=150,
38 | expo_trend_shape=(),
39 | ),
40 | dict(
41 | testcase_name="geo",
42 | coef_trend_shape=(5,),
43 | trend_length=150,
44 | expo_trend_shape=(),
45 | ),
46 | ])
47 | def test_core_trend_with_exponent_produces_correct_shape(
48 | self, coef_trend_shape, trend_length, expo_trend_shape):
49 | coef_trend = jnp.ones(coef_trend_shape)
50 | linear_trend = jnp.arange(trend_length)
51 | if coef_trend.ndim == 1: # For geo model's case
52 | linear_trend = jnp.expand_dims(linear_trend, axis=-1)
53 | expo_trend = jnp.ones(expo_trend_shape)
54 |
55 | trend_values = trend._trend_with_exponent(
56 | coef_trend=coef_trend, trend=linear_trend, expo_trend=expo_trend)
57 |
58 | self.assertEqual(trend_values.shape,
59 | (linear_trend.shape[0], *coef_trend_shape))
60 |
61 | @parameterized.named_parameters([
62 | dict(testcase_name="national", data_shape=(150, 3)),
63 | dict(testcase_name="geo", data_shape=(150, 3, 5)),
64 | ])
65 | def test_trend_with_exponent_produces_correct_shape(self, data_shape):
66 |
67 | def mock_model_function(data):
68 | numpyro.deterministic(
69 | "trend", trend.trend_with_exponent(
70 | data=data,
71 | custom_priors={},
72 | ))
73 |
74 | num_samples = 10
75 | data = jnp.ones(data_shape)
76 | kernel = numpyro.infer.NUTS(model=mock_model_function)
77 | mcmc = numpyro.infer.MCMC(
78 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
79 | rng_key = jax.random.PRNGKey(0)
80 | coef_expected_shape = () if data.ndim == 2 else (data.shape[2],)
81 |
82 | mcmc.run(rng_key, data=data)
83 | trend_values = mcmc.get_samples()["trend"]
84 |
85 | self.assertEqual(trend_values.shape,
86 | (num_samples, data.shape[0], *coef_expected_shape))
87 |
88 | @parameterized.named_parameters(
89 | dict(
90 | testcase_name=f"model_{priors.COEF_TREND}",
91 | prior_name=priors.COEF_TREND,
92 | ),
93 | dict(
94 | testcase_name=f"model_{priors.EXPO_TREND}",
95 | prior_name=priors.EXPO_TREND,
96 | ),
97 | )
98 | def test_trend_with_exponent_custom_priors_are_taken_correctly(
99 | self, prior_name):
100 | expected_value1, expected_value2 = 5.2, 7.56
101 | custom_priors = {
102 | prior_name:
103 | dist.Kumaraswamy(
104 | concentration1=expected_value1, concentration0=expected_value2)
105 | }
106 | media = jnp.ones((10, 5, 5))
107 |
108 | trace_handler = handlers.trace(
109 | handlers.seed(trend.trend_with_exponent, rng_seed=0))
110 | trace = trace_handler.get_trace(
111 | data=media,
112 | custom_priors=custom_priors,
113 | )
114 | values_and_dists = {
115 | name: site["fn"] for name, site in trace.items() if "fn" in site
116 | }
117 |
118 | used_distribution = values_and_dists[prior_name]
119 | if isinstance(used_distribution, dist.ExpandedDistribution):
120 | used_distribution = used_distribution.base_dist
121 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
122 | self.assertEqual(used_distribution.concentration0, expected_value2)
123 | self.assertEqual(used_distribution.concentration1, expected_value1)
124 |
125 | @parameterized.named_parameters([
126 | dict(
127 | testcase_name="dynamic_trend_national_shape",
128 | number_periods=100,
129 | initial_level_shape=(),
130 | initial_slope_shape=(),
131 | variance_level_shape=(),
132 | variance_slope_shape=(),
133 | ),
134 | dict(
135 | testcase_name="dynamic_trend_geo_shape",
136 | number_periods=100,
137 | initial_level_shape=(2,),
138 | initial_slope_shape=(2,),
139 | variance_level_shape=(2,),
140 | variance_slope_shape=(2,),
141 | ),
142 | ])
143 | def test_core_dynamic_trend_produces_correct_shape(
144 | self, number_periods, initial_level_shape, initial_slope_shape,
145 | variance_level_shape, variance_slope_shape):
146 | initial_level = jnp.ones(initial_level_shape)
147 | initial_slope = jnp.ones(initial_slope_shape)
148 | variance_level = jnp.ones(variance_level_shape)
149 | variance_slope = jnp.ones(variance_slope_shape)
150 | random_walk_level = jnp.arange(number_periods)
151 | random_walk_slope = jnp.arange(number_periods)
152 | if initial_level.ndim == 1: # For geo model's case
153 | random_walk_level = jnp.expand_dims(random_walk_level, axis=-1)
154 | random_walk_slope = jnp.expand_dims(random_walk_slope, axis=-1)
155 |
156 | dynamic_trend_values = trend._dynamic_trend(
157 | number_periods=number_periods,
158 | random_walk_level=random_walk_level,
159 | random_walk_slope=random_walk_slope,
160 | initial_level=initial_level,
161 | initial_slope=initial_slope,
162 | variance_level=variance_level,
163 | variance_slope=variance_slope,
164 | )
165 |
166 | self.assertEqual(dynamic_trend_values.shape,
167 | (number_periods, *initial_level_shape))
168 |
169 | def test_core_dynamic_trend_produces_correct_value(self):
170 | number_periods = 5
171 | initial_level = jnp.ones(())
172 | initial_slope = jnp.ones(())
173 | variance_level = jnp.ones(())
174 | variance_slope = jnp.ones(())
175 | random_walk_level = jnp.arange(number_periods)
176 | random_walk_slope = jnp.arange(number_periods)
177 | dynamic_trend_expected_value = jnp.array([1, 3, 7, 14, 25])
178 |
179 | dynamic_trend_values = trend._dynamic_trend(
180 | number_periods=number_periods,
181 | random_walk_level=random_walk_level,
182 | random_walk_slope=random_walk_slope,
183 | initial_level=initial_level,
184 | initial_slope=initial_slope,
185 | variance_level=variance_level,
186 | variance_slope=variance_slope,
187 | )
188 |
189 | np.testing.assert_array_equal(dynamic_trend_values,
190 | dynamic_trend_expected_value)
191 |
192 | @parameterized.named_parameters([
193 | dict(
194 | testcase_name="national_with_prediction_is_true",
195 | data_shape=(100, 3),
196 | is_trend_prediction=True),
197 | dict(
198 | testcase_name="geo_with_prediction_is_true",
199 | data_shape=(150, 3, 5),
200 | is_trend_prediction=True),
201 | dict(
202 | testcase_name="national_with_prediction_is_false",
203 | data_shape=(100, 3),
204 | is_trend_prediction=False),
205 | dict(
206 | testcase_name="geo_with_prediction_is_false",
207 | data_shape=(150, 3, 5),
208 | is_trend_prediction=False),
209 | ])
210 | def test_dynamic_trend_produces_correct_shape(
211 | self, data_shape, is_trend_prediction):
212 |
213 | def mock_model_function(geo_size, data_size):
214 | numpyro.deterministic(
215 | "trend", trend.dynamic_trend(
216 | geo_size=geo_size,
217 | data_size=data_size,
218 | is_trend_prediction=is_trend_prediction,
219 | custom_priors={},
220 | ))
221 | num_samples = 10
222 | data = jnp.ones(data_shape)
223 | geo_size = core_utils.get_number_geos(data)
224 | kernel = numpyro.infer.NUTS(model=mock_model_function)
225 | mcmc = numpyro.infer.MCMC(
226 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
227 | rng_key = jax.random.PRNGKey(0)
228 | coef_expected_shape = core_utils.get_geo_shape(data)
229 |
230 | mcmc.run(rng_key, geo_size=geo_size, data_size=data_shape[0])
231 | trend_values = mcmc.get_samples()["trend"]
232 |
233 | self.assertEqual(trend_values.shape,
234 | (num_samples, data.shape[0], *coef_expected_shape))
235 |
236 | @parameterized.named_parameters(
237 | dict(
238 | testcase_name=f"model_{priors.DYNAMIC_TREND_INITIAL_LEVEL}",
239 | prior_name=priors.DYNAMIC_TREND_INITIAL_LEVEL,
240 | ),
241 | dict(
242 | testcase_name=f"model_{priors.DYNAMIC_TREND_INITIAL_SLOPE}",
243 | prior_name=priors.DYNAMIC_TREND_INITIAL_SLOPE,
244 | ),
245 | dict(
246 | testcase_name=f"model_{priors.DYNAMIC_TREND_LEVEL_VARIANCE}",
247 | prior_name=priors.DYNAMIC_TREND_LEVEL_VARIANCE,
248 | ),
249 | dict(
250 | testcase_name=f"model_{priors.DYNAMIC_TREND_SLOPE_VARIANCE}",
251 | prior_name=priors.DYNAMIC_TREND_SLOPE_VARIANCE,
252 | ),
253 | )
254 | def test_core_dynamic_trend_custom_priors_are_taken_correctly(
255 | self, prior_name):
256 | expected_value1, expected_value2 = 5.2, 7.56
257 | custom_priors = {
258 | prior_name:
259 | dist.Kumaraswamy(
260 | concentration1=expected_value1, concentration0=expected_value2)
261 | }
262 | geo_size = 1
263 | data_size = 10
264 | trace_handler = handlers.trace(
265 | handlers.seed(trend.dynamic_trend, rng_seed=0))
266 | trace = trace_handler.get_trace(
267 | geo_size=geo_size,
268 | data_size=data_size,
269 | is_trend_prediction=False,
270 | custom_priors=custom_priors,
271 | )
272 | values_and_dists = {
273 | name: site["fn"] for name, site in trace.items() if "fn" in site
274 | }
275 |
276 | used_distribution = values_and_dists[prior_name]
277 | if isinstance(used_distribution, dist.ExpandedDistribution):
278 | used_distribution = used_distribution.base_dist
279 |
280 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
281 | self.assertEqual(used_distribution.concentration0, expected_value2)
282 | self.assertEqual(used_distribution.concentration1, expected_value1)
283 |
284 | @parameterized.named_parameters([
285 | dict(
286 | testcase_name="trend_prediction_is_true",
287 | is_trend_prediction=True,
288 | expected_trend_parameter=[
289 | "random_walk_level_prediction", "random_walk_slope_prediction"]
290 | ),
291 | dict(
292 | testcase_name="trend_prediction_is_false",
293 | is_trend_prediction=False,
294 | expected_trend_parameter=[
295 | "random_walk_level", "random_walk_slope"]
296 | ),
297 | ])
298 |   def test_dynamic_trend_is_trend_prediction_produces_correct_parameter_names(
299 | self, is_trend_prediction, expected_trend_parameter):
300 |
301 | def mock_model_function(geo_size, data_size):
302 | numpyro.deterministic(
303 | "trend", trend.dynamic_trend(
304 | geo_size=geo_size,
305 | data_size=data_size,
306 | is_trend_prediction=is_trend_prediction,
307 | custom_priors={},
308 | ))
309 | num_samples = 10
310 | data_shape = (10, 3)
311 | data = jnp.ones(data_shape)
312 | geo_size = core_utils.get_number_geos(data)
313 | kernel = numpyro.infer.NUTS(model=mock_model_function)
314 | mcmc = numpyro.infer.MCMC(
315 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
316 | rng_key = jax.random.PRNGKey(0)
317 |
318 | mcmc.run(rng_key, geo_size=geo_size, data_size=data_shape[0])
319 | trend_parameter = [
320 | parameter for parameter, _ in mcmc.get_samples().items()
321 | if parameter.startswith("random_walk")]
322 |
323 | self.assertEqual(trend_parameter, expected_trend_parameter)
324 |
325 | if __name__ == "__main__":
326 | absltest.main()
327 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/transformations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/transformations/identity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module for identity transformations."""
16 |
17 | from typing import Any
18 | import jax.numpy as jnp
19 |
20 |
21 | def identity_transform(
22 |     data: jnp.ndarray,  # pylint: disable=unused-argument
23 | *args: Any,
24 | **kwargs: Any,
25 | ) -> jnp.ndarray:
26 | """Identity transform. Returns the main input as is."""
27 | return data
28 |
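A trivial usage sketch: because identity_transform accepts and ignores arbitrary extra arguments, it satisfies the TransformFunction-style call signature and can stand in for any media transformation.

import jax.numpy as jnp
from lightweight_mmm.core.transformations import identity

data = jnp.ones((3, 2))
out = identity.identity_transform(data, {}, prefix="")  # extras are ignored
assert (out == data).all()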
--------------------------------------------------------------------------------
/lightweight_mmm/core/transformations/lagging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Set of core and modelling lagging functions."""
16 |
17 | import functools
18 | from typing import Mapping, Union
19 |
20 | import jax
21 | import jax.numpy as jnp
22 | import numpyro
23 | import numpyro.distributions as dist
24 | from lightweight_mmm.core import priors
25 |
26 |
27 | @functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1)
28 | def _carryover_convolve(data: jnp.ndarray, weights: jnp.ndarray,
29 | number_lags: int) -> jnp.ndarray:
30 | """Applies the convolution between the data and the weights for the carryover.
31 |
32 | Args:
33 | data: Input data.
34 | weights: Window weights for the carryover.
35 | number_lags: Number of lags the window has.
36 |
37 | Returns:
38 | The result values from convolving the data and the weights with padding.
39 | """
40 | window = jnp.concatenate([jnp.zeros(number_lags - 1), weights])
41 | return jax.scipy.signal.convolve(data, window, mode="same") / weights.sum()
42 |
43 |
44 | @functools.partial(jax.jit, static_argnames=("number_lags",))
45 | def _carryover(
46 | data: jnp.ndarray,
47 | ad_effect_retention_rate: jnp.ndarray,
48 | peak_effect_delay: jnp.ndarray,
49 | number_lags: int,
50 | ) -> jnp.ndarray:
51 | """Calculates media carryover.
52 |
53 | More details about this function can be found in:
54 | https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46001.pdf
55 |
56 | Args:
57 | data: Input data. It is expected that data has either 2 dimensions for
58 | national models and 3 for geo models.
59 |     ad_effect_retention_rate: Retention rate of the advertisement effect.
60 |       Expected to lie in [0, 1].
61 |     peak_effect_delay: Delay (in time periods) of the peak effect in the
62 |       carryover function.
63 |     number_lags: Number of lags to include in the carryover calculation. The
64 |       public carryover API defaults this to 13.
65 |
66 | Returns:
67 | The carryover values for the given data with the given parameters.
68 | """
69 | lags_arange = jnp.expand_dims(
70 | jnp.arange(number_lags, dtype=jnp.float32), axis=-1)
71 | convolve_func = _carryover_convolve
72 | if data.ndim == 3:
73 |     # Since _carryover_convolve is already vmapped in its decorator, we only
74 |     # need to vmap once more here to handle geo-level data. The weights stay
75 |     # two-dimensional even for three-dimensional data; we vmap over only the
76 |     # extra data dimension.
77 | convolve_func = jax.vmap(
78 | fun=_carryover_convolve, in_axes=(2, None, None), out_axes=2)
79 | weights = ad_effect_retention_rate**((lags_arange - peak_effect_delay)**2)
80 | return convolve_func(data, weights, number_lags)
81 |
82 |
83 | def carryover(
84 | data: jnp.ndarray,
85 | custom_priors: Mapping[str, dist.Distribution],
86 | *,
87 | number_lags: int = 13,
88 | prefix: str = "",
89 | ) -> jnp.ndarray:
90 | """Transforms the input data with the carryover function.
91 |
92 | Args:
93 | data: Media data to be transformed. It is expected to have 2 dims for
94 | national models and 3 for geo models.
95 | custom_priors: The custom priors we want the model to take instead of the
96 | default ones.
97 | number_lags: Number of lags for the carryover function.
98 | prefix: Prefix to use in the variable name for Numpyro.
99 |
100 | Returns:
101 | The transformed media data.
102 | """
103 | default_priors = priors.get_default_priors()
104 | with numpyro.plate(
105 | name=f"{prefix}{priors.AD_EFFECT_RETENTION_RATE}_plate",
106 | size=data.shape[1]):
107 | ad_effect_retention_rate = numpyro.sample(
108 | name=f"{prefix}{priors.AD_EFFECT_RETENTION_RATE}",
109 | fn=custom_priors.get(priors.AD_EFFECT_RETENTION_RATE,
110 | default_priors[priors.AD_EFFECT_RETENTION_RATE]))
111 |
112 | with numpyro.plate(
113 | name=f"{prefix}{priors.PEAK_EFFECT_DELAY}_plate", size=data.shape[1]):
114 | peak_effect_delay = numpyro.sample(
115 | name=f"{prefix}{priors.PEAK_EFFECT_DELAY}",
116 | fn=custom_priors.get(priors.PEAK_EFFECT_DELAY,
117 | default_priors[priors.PEAK_EFFECT_DELAY]))
118 |
119 | return _carryover(
120 | data=data,
121 | ad_effect_retention_rate=ad_effect_retention_rate,
122 | peak_effect_delay=peak_effect_delay,
123 | number_lags=number_lags)
124 |
125 |
126 | @jax.jit
127 | def _adstock(
128 | data: jnp.ndarray,
129 | lag_weight: Union[float, jnp.ndarray] = .9,
130 | normalise: bool = True,
131 | ) -> jnp.ndarray:
132 | """Calculates the adstock value of a given array.
133 |
134 | To learn more about advertising lag:
135 | https://en.wikipedia.org/wiki/Advertising_adstock
136 |
137 | Args:
138 | data: Input array.
139 | lag_weight: lag_weight effect of the adstock function. Default is 0.9.
140 |     normalise: Whether to normalise the output values. Normalisation divides
141 |       the output by 1 / (1 - lag_weight), i.e. scales it by (1 - lag_weight).
142 |
143 | Returns:
144 | The adstock output of the input array.
145 | """
146 |
147 | def adstock_internal(
148 | prev_adstock: jnp.ndarray,
149 | data: jnp.ndarray,
150 | lag_weight: Union[float, jnp.ndarray] = lag_weight,
151 | ) -> jnp.ndarray:
152 | adstock_value = prev_adstock * lag_weight + data
153 |     return adstock_value, adstock_value
154 |
155 | _, adstock_values = jax.lax.scan(
156 | f=adstock_internal, init=data[0, ...], xs=data[1:, ...])
157 | adstock_values = jnp.concatenate([jnp.array([data[0, ...]]), adstock_values])
158 | return jax.lax.cond(
159 | normalise,
160 | lambda adstock_values: adstock_values / (1. / (1 - lag_weight)),
161 | lambda adstock_values: adstock_values,
162 | operand=adstock_values)
163 |
164 |
165 | def adstock(
166 | data: jnp.ndarray,
167 | custom_priors: Mapping[str, dist.Distribution],
168 | *,
169 | normalise: bool = True,
170 | prefix: str = "",
171 | ) -> jnp.ndarray:
172 |   """Transforms the input data with the adstock function.
173 |
174 | Args:
175 | data: Media data to be transformed. It is expected to have 2 dims for
176 | national models and 3 for geo models.
177 | custom_priors: The custom priors we want the model to take instead of the
178 |       default ones. The parameter name used by this transformation is
179 |       "lag_weight".
180 | normalise: Whether to normalise the output values.
181 | prefix: Prefix to use in the variable name for Numpyro.
182 |
183 | Returns:
184 | The transformed media data.
185 | """
186 | default_priors = priors.get_default_priors()
187 | with numpyro.plate(
188 | name=f"{prefix}{priors.LAG_WEIGHT}_plate", size=data.shape[1]):
189 | lag_weight = numpyro.sample(
190 | name=f"{prefix}{priors.LAG_WEIGHT}",
191 | fn=custom_priors.get(priors.LAG_WEIGHT,
192 | default_priors[priors.LAG_WEIGHT]))
193 |
194 | if data.ndim == 3:
195 | lag_weight = jnp.expand_dims(lag_weight, axis=-1)
196 |
197 | return _adstock(data=data, lag_weight=lag_weight, normalise=normalise)
198 |
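A sketch of both lagging transforms on a single impulse of spend, calling the private helpers directly (illustration only; the public wrappers sample the rates from their priors). The impulse makes the decay patterns easy to read.

import jax.numpy as jnp
from lightweight_mmm.core.transformations import lagging

impulse = jnp.zeros((6, 1)).at[0, 0].set(100.)  # one channel, spend at t=0

# Geometric adstock: each period retains lag_weight of the previous value.
adstocked = lagging._adstock(data=impulse, lag_weight=0.5, normalise=False)
# -> [[100.], [50.], [25.], [12.5], [6.25], [3.125]]

# Carryover: a window whose peak is shifted by peak_effect_delay periods.
carried = lagging._carryover(
    data=impulse,
    ad_effect_retention_rate=jnp.array([0.5]),
    peak_effect_delay=jnp.array([2.]),
    number_lags=4,
)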
--------------------------------------------------------------------------------
/lightweight_mmm/core/transformations/lagging_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for lagging."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | import numpyro
23 | from numpyro import handlers
24 | import numpyro.distributions as dist
25 |
26 | from lightweight_mmm.core import priors
27 | from lightweight_mmm.core.transformations import lagging
28 |
29 |
30 | class LaggingTest(parameterized.TestCase):
31 |
32 | @parameterized.named_parameters(
33 | dict(
34 | testcase_name="national",
35 | data_shape=(150, 3),
36 | ad_effect_retention_rate_shape=(3,),
37 | peak_effect_delay_shape=(3,),
38 | number_lags=13,
39 | ),
40 | dict(
41 | testcase_name="geo",
42 | data_shape=(150, 3, 5),
43 | ad_effect_retention_rate_shape=(3,),
44 | peak_effect_delay_shape=(3,),
45 | number_lags=13,
46 | ),
47 | )
48 | def test_core_carryover_produces_correct_shape(
49 | self,
50 | data_shape,
51 | ad_effect_retention_rate_shape,
52 | peak_effect_delay_shape,
53 | number_lags,
54 | ):
55 | data = jnp.ones(data_shape)
56 | ad_effect_retention_rate = jnp.ones(ad_effect_retention_rate_shape)
57 | peak_effect_delay = jnp.ones(peak_effect_delay_shape)
58 |
59 | output = lagging._carryover(
60 | data=data,
61 | ad_effect_retention_rate=ad_effect_retention_rate,
62 | peak_effect_delay=peak_effect_delay,
63 | number_lags=number_lags,
64 | )
65 |
66 | self.assertEqual(output.shape, data_shape)
67 |
68 | @parameterized.named_parameters(
69 | dict(
70 | testcase_name="national",
71 | data_shape=(150, 3),
72 | ),
73 | dict(
74 | testcase_name="geo",
75 | data_shape=(150, 3, 5),
76 | ),
77 | )
78 | def test_carryover_produces_correct_shape(self, data_shape):
79 |
80 | def mock_model_function(data, number_lags):
81 | numpyro.deterministic(
82 | "carryover",
83 | lagging.carryover(
84 | data=data, custom_priors={}, number_lags=number_lags))
85 |
86 | num_samples = 10
87 | data = jnp.ones(data_shape)
88 | number_lags = 15
89 | kernel = numpyro.infer.NUTS(model=mock_model_function)
90 | mcmc = numpyro.infer.MCMC(
91 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
92 | rng_key = jax.random.PRNGKey(0)
93 |
94 | mcmc.run(rng_key, data=data, number_lags=number_lags)
95 | carryover_values = mcmc.get_samples()["carryover"]
96 |
97 | self.assertEqual(carryover_values.shape, (num_samples, *data.shape))
98 |
99 | @parameterized.named_parameters(
100 | dict(
101 | testcase_name="ad_effect_retention_rate",
102 | prior_name=priors.AD_EFFECT_RETENTION_RATE,
103 | ),
104 | dict(
105 | testcase_name="peak_effect_delay",
106 | prior_name=priors.PEAK_EFFECT_DELAY,
107 | ),
108 | )
109 | def test_carryover_custom_priors_are_taken_correctly(self, prior_name):
110 | expected_value1, expected_value2 = 5.2, 7.56
111 | custom_priors = {
112 | prior_name:
113 | dist.Kumaraswamy(
114 | concentration1=expected_value1, concentration0=expected_value2)
115 | }
116 | media = jnp.ones((10, 5, 5))
117 | number_lags = 13
118 |
119 | trace_handler = handlers.trace(handlers.seed(lagging.carryover, rng_seed=0))
120 | trace = trace_handler.get_trace(
121 | data=media,
122 | custom_priors=custom_priors,
123 | number_lags=number_lags,
124 | )
125 | values_and_dists = {
126 | name: site["fn"] for name, site in trace.items() if "fn" in site
127 | }
128 |
129 | used_distribution = values_and_dists[prior_name]
130 | used_distribution = used_distribution.base_dist
131 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
132 | self.assertEqual(used_distribution.concentration0, expected_value2)
133 | self.assertEqual(used_distribution.concentration1, expected_value1)
134 |
135 | @parameterized.named_parameters(
136 | dict(
137 | testcase_name="national",
138 | data_shape=(150, 3),
139 | lag_weight_shape=(3,),
140 | ),
141 | dict(
142 | testcase_name="geo",
143 | data_shape=(150, 3, 5),
144 | lag_weight_shape=(3, 1),
145 | ),
146 | )
147 | def test_core_adstock_produces_correct_shape(self, data_shape,
148 | lag_weight_shape):
149 | data = jnp.ones(data_shape)
150 | lag_weight = jnp.ones(lag_weight_shape)
151 |
152 | output = lagging._adstock(data=data, lag_weight=lag_weight)
153 |
154 | self.assertEqual(output.shape, data_shape)
155 |
156 | @parameterized.named_parameters(
157 | dict(
158 | testcase_name="national",
159 | data_shape=(150, 3),
160 | ),
161 | dict(
162 | testcase_name="geo",
163 | data_shape=(150, 3, 5),
164 | ),
165 | )
166 | def test_adstock_produces_correct_shape(self, data_shape):
167 |
168 | def mock_model_function(data, normalise):
169 | numpyro.deterministic(
170 | "adstock",
171 | lagging.adstock(data=data, custom_priors={}, normalise=normalise))
172 |
173 | num_samples = 10
174 | data = jnp.ones(data_shape)
175 | normalise = True
176 | kernel = numpyro.infer.NUTS(model=mock_model_function)
177 | mcmc = numpyro.infer.MCMC(
178 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
179 | rng_key = jax.random.PRNGKey(0)
180 |
181 | mcmc.run(rng_key, data=data, normalise=normalise)
182 | adstock_values = mcmc.get_samples()["adstock"]
183 |
184 | self.assertEqual(adstock_values.shape, (num_samples, *data.shape))
185 |
186 | def test_adstock_custom_priors_are_taken_correctly(self):
187 | prior_name = priors.LAG_WEIGHT
188 | expected_value1, expected_value2 = 5.2, 7.56
189 | custom_priors = {
190 | prior_name:
191 | dist.Kumaraswamy(
192 | concentration1=expected_value1, concentration0=expected_value2)
193 | }
194 | data = jnp.ones((10, 5, 5))
195 |
196 | trace_handler = handlers.trace(handlers.seed(lagging.adstock, rng_seed=0))
197 | trace = trace_handler.get_trace(
198 | data=data,
199 | custom_priors=custom_priors,
200 | normalise=True,
201 | )
202 | values_and_dists = {
203 | name: site["fn"] for name, site in trace.items() if "fn" in site
204 | }
205 |
206 | used_distribution = values_and_dists[prior_name]
207 | used_distribution = used_distribution.base_dist
208 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
209 | self.assertEqual(used_distribution.concentration0, expected_value2)
210 | self.assertEqual(used_distribution.concentration1, expected_value1)
211 |
212 | def test_adstock_zeros_stay_zeros(self):
213 | data = jnp.zeros((10, 5))
214 | lag_weight = jnp.full(5, 0.5)
215 |
216 | generated_output = lagging._adstock(data=data, lag_weight=lag_weight)
217 |
218 | np.testing.assert_array_equal(generated_output, data)
219 |
220 | def test_carryover_zeros_stay_zeros(self):
221 | data = jnp.zeros((10, 5))
222 | ad_effect_retention_rate = jnp.full(5, 0.5)
223 | peak_effect_delay = jnp.full(5, 0.5)
224 |
225 | generated_output = lagging._carryover(
226 | data=data,
227 | ad_effect_retention_rate=ad_effect_retention_rate,
228 | peak_effect_delay=peak_effect_delay,
229 | number_lags=7,
230 | )
231 |
232 | np.testing.assert_array_equal(generated_output, data)
233 |
234 |
235 | if __name__ == "__main__":
236 | absltest.main()
237 |
--------------------------------------------------------------------------------
/lightweight_mmm/core/transformations/saturation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Set of core and modelling saturation functions."""
16 |
17 | from typing import Mapping
18 | import jax
19 | import jax.numpy as jnp
20 | import numpyro
21 | from numpyro import distributions as dist
22 |
23 | from lightweight_mmm.core import core_utils
24 | from lightweight_mmm.core import priors
25 |
26 |
27 | @jax.jit
28 | def _hill(
29 | data: jnp.ndarray,
30 | half_max_effective_concentration: jnp.ndarray,
31 | slope: jnp.ndarray,
32 | ) -> jnp.ndarray:
33 | """Calculates the hill function for a given array of values.
34 |
35 | Refer to the following link for detailed information on this equation:
36 | https://en.wikipedia.org/wiki/Hill_equation_(biochemistry)
37 |
38 | Args:
39 | data: Input data.
40 | half_max_effective_concentration: ec50 value for the hill function.
41 | slope: Slope of the hill function.
42 |
43 | Returns:
44 | The hill values for the respective input data.
45 | """
46 |   safe_transform = core_utils.apply_exponent_safe(
47 |       data=data / half_max_effective_concentration, exponent=-slope)
48 |   return jnp.where(safe_transform == 0, 0, 1.0 / (1 + safe_transform))
49 |
50 |
51 | def hill(
52 | data: jnp.ndarray,
53 | custom_priors: Mapping[str, dist.Distribution],
54 | *,
55 | prefix: str = "",
56 | ) -> jnp.ndarray:
57 | """Transforms the input data with the adstock and hill functions.
58 |
59 | Args:
60 | data: Media data to be transformed. It is expected to have 2 dims for
61 | national models and 3 for geo models.
62 | custom_priors: The custom priors we want the model to take instead of the
63 |       default ones. The possible names of parameters for the hill function are
64 |       "half_max_effective_concentration" and "slope".
65 | prefix: Prefix to use in the variable name for Numpyro.
66 |
67 | Returns:
68 | The transformed media data.
69 | """
70 | default_priors = priors.get_default_priors()
71 |
72 | with numpyro.plate(
73 | name=f"{prefix}{priors.HALF_MAX_EFFECTIVE_CONCENTRATION}_plate",
74 | size=data.shape[1]):
75 | half_max_effective_concentration = numpyro.sample(
76 | name=f"{prefix}{priors.HALF_MAX_EFFECTIVE_CONCENTRATION}",
77 | fn=custom_priors.get(
78 | priors.HALF_MAX_EFFECTIVE_CONCENTRATION,
79 | default_priors[priors.HALF_MAX_EFFECTIVE_CONCENTRATION]))
80 |
81 | with numpyro.plate(name=f"{prefix}{priors.SLOPE}_plate", size=data.shape[1]):
82 | slope = numpyro.sample(
83 | name=f"{prefix}{priors.SLOPE}",
84 | fn=custom_priors.get(priors.SLOPE, default_priors[priors.SLOPE]))
85 |
86 | if data.ndim == 3:
87 | half_max_effective_concentration = jnp.expand_dims(
88 | half_max_effective_concentration, axis=-1)
89 | slope = jnp.expand_dims(slope, axis=-1)
90 |
91 | return _hill(
92 | data=data,
93 | half_max_effective_concentration=half_max_effective_concentration,
94 | slope=slope)
95 |
96 |
97 | def _exponent(data: jnp.ndarray, exponent_values: jnp.ndarray) -> jnp.ndarray:
98 | """Applies exponent to the given data."""
99 | return core_utils.apply_exponent_safe(data=data, exponent=exponent_values)
100 |
101 |
102 | def exponent(
103 | data: jnp.ndarray,
104 | custom_priors: Mapping[str, dist.Distribution],
105 | *,
106 | prefix: str = "",
107 | ) -> jnp.ndarray:
108 | """Transforms the input data with the carryover function and exponent.
109 |
110 | Args:
111 | data: Media data to be transformed. It is expected to have 2 dims for
112 | national models and 3 for geo models.
113 | custom_priors: The custom priors we want the model to take instead of the
114 |       default ones. The possible name of the parameter is "exponent".
115 | prefix: Prefix to use in the variable name for Numpyro.
116 |
117 | Returns:
118 | The transformed media data.
119 | """
120 | default_priors = priors.get_default_priors()
121 |
122 | with numpyro.plate(
123 | name=f"{prefix}{priors.EXPONENT}_plate", size=data.shape[1]):
124 | exponent_values = numpyro.sample(
125 | name=f"{prefix}{priors.EXPONENT}",
126 | fn=custom_priors.get(priors.EXPONENT, default_priors[priors.EXPONENT]))
127 |
128 | if data.ndim == 3:
129 | exponent_values = jnp.expand_dims(exponent_values, axis=-1)
130 | return _exponent(data=data, exponent_values=exponent_values)
131 |
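A minimal usage sketch of the hill saturation above (illustrative only, not part of the repository; it calls the private `_hill` kernel directly so it can run outside of a numpyro model):

import jax.numpy as jnp

from lightweight_mmm.core.transformations import saturation

# With ec50 = 1 and slope = 1, hill(x) = 1 / (1 + (x / ec50) ** -slope)
# passes through 0.5 at x = 1; zeros are mapped to zero by the safe exponent.
data = jnp.array([[0.0], [0.5], [1.0], [2.0]])
values = saturation._hill(
    data=data,
    half_max_effective_concentration=jnp.array([1.0]),
    slope=jnp.array([1.0]),
)
print(values)  # ~ [[0.], [0.333], [0.5], [0.667]]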
--------------------------------------------------------------------------------
/lightweight_mmm/core/transformations/saturation_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for saturation."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | import numpyro
23 | from numpyro import handlers
24 | import numpyro.distributions as dist
25 |
26 | from lightweight_mmm.core import priors
27 | from lightweight_mmm.core.transformations import saturation
28 |
29 |
30 | class SaturationTest(parameterized.TestCase):
31 |
32 | @parameterized.named_parameters(
33 | dict(
34 | testcase_name="national",
35 | data_shape=(150, 3),
36 | half_max_effective_concentration_shape=(3,),
37 | slope_shape=(3,),
38 | ),
39 | dict(
40 | testcase_name="geo",
41 | data_shape=(150, 3, 5),
42 | half_max_effective_concentration_shape=(3, 1),
43 | slope_shape=(3, 1),
44 | ),
45 | )
46 | def test_hill_core_produces_correct_shape(
47 | self, data_shape, half_max_effective_concentration_shape, slope_shape):
48 | data = jnp.ones(data_shape)
49 | half_max_effective_concentration = jnp.ones(
50 | half_max_effective_concentration_shape)
51 | slope = jnp.ones(slope_shape)
52 |
53 | output = saturation._hill(
54 | data=data,
55 | half_max_effective_concentration=half_max_effective_concentration,
56 | slope=slope,
57 | )
58 |
59 | self.assertEqual(output.shape, data_shape)
60 |
61 | @parameterized.named_parameters(
62 | dict(
63 | testcase_name="national",
64 | data_shape=(150, 3),
65 | ),
66 | dict(
67 | testcase_name="geo",
68 | data_shape=(150, 3, 5),
69 | ),
70 | )
71 | def test_hill_produces_correct_shape(self, data_shape):
72 |
73 | def mock_model_function(data):
74 | numpyro.deterministic("hill",
75 | saturation.hill(data=data, custom_priors={}))
76 |
77 | num_samples = 10
78 | data = jnp.ones(data_shape)
79 | kernel = numpyro.infer.NUTS(model=mock_model_function)
80 | mcmc = numpyro.infer.MCMC(
81 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
82 | rng_key = jax.random.PRNGKey(0)
83 |
84 | mcmc.run(rng_key, data=data)
85 | output_values = mcmc.get_samples()["hill"]
86 |
87 | self.assertEqual(output_values.shape, (num_samples, *data.shape))
88 |
89 | @parameterized.named_parameters(
90 | dict(
91 | testcase_name="half_max_effective_concentration",
92 | prior_name=priors.HALF_MAX_EFFECTIVE_CONCENTRATION,
93 | ),
94 | dict(
95 | testcase_name="slope",
96 | prior_name=priors.SLOPE,
97 | ),
98 | )
99 | def test_hill_custom_priors_are_taken_correctly(self, prior_name):
100 | expected_value1, expected_value2 = 5.2, 7.56
101 | custom_priors = {
102 | prior_name:
103 | dist.Kumaraswamy(
104 | concentration1=expected_value1, concentration0=expected_value2)
105 | }
106 | media = jnp.ones((10, 5, 5))
107 |
108 | trace_handler = handlers.trace(handlers.seed(saturation.hill, rng_seed=0))
109 | trace = trace_handler.get_trace(
110 | data=media,
111 | custom_priors=custom_priors,
112 | )
113 | values_and_dists = {
114 | name: site["fn"] for name, site in trace.items() if "fn" in site
115 | }
116 |
117 | used_distribution = values_and_dists[prior_name]
118 | used_distribution = used_distribution.base_dist
119 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
120 | self.assertEqual(used_distribution.concentration0, expected_value2)
121 | self.assertEqual(used_distribution.concentration1, expected_value1)
122 |
123 | def test_exponent_core_produces_correct_shape(self):
124 |     self.assertEqual(saturation._exponent(data=jnp.ones((10, 5)), exponent_values=jnp.full(5, 0.5)).shape, (10, 5))
125 |
126 | @parameterized.named_parameters(
127 | dict(
128 | testcase_name="national",
129 | data_shape=(150, 3),
130 | ),
131 | dict(
132 | testcase_name="geo",
133 | data_shape=(150, 3, 5),
134 | ),
135 | )
136 | def test_exponent_produces_correct_shape(self, data_shape):
137 |
138 | def mock_model_function(data):
139 | numpyro.deterministic("outer_exponent",
140 | saturation.exponent(data=data, custom_priors={}))
141 |
142 | num_samples = 10
143 | data = jnp.ones(data_shape)
144 | kernel = numpyro.infer.NUTS(model=mock_model_function)
145 | mcmc = numpyro.infer.MCMC(
146 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
147 | rng_key = jax.random.PRNGKey(0)
148 |
149 | mcmc.run(rng_key, data=data)
150 | output_values = mcmc.get_samples()["outer_exponent"]
151 |
152 | self.assertEqual(output_values.shape, (num_samples, *data.shape))
153 |
154 | def test_exponent_custom_priors_are_taken_correctly(self):
155 | prior_name = priors.EXPONENT
156 | expected_value1, expected_value2 = 5.2, 7.56
157 | custom_priors = {
158 | prior_name:
159 | dist.Kumaraswamy(
160 | concentration1=expected_value1, concentration0=expected_value2)
161 | }
162 | media = jnp.ones((10, 5, 5))
163 |
164 | trace_handler = handlers.trace(
165 | handlers.seed(saturation.exponent, rng_seed=0))
166 | trace = trace_handler.get_trace(
167 | data=media,
168 | custom_priors=custom_priors,
169 | )
170 | values_and_dists = {
171 | name: site["fn"] for name, site in trace.items() if "fn" in site
172 | }
173 |
174 | used_distribution = values_and_dists[prior_name]
175 | used_distribution = used_distribution.base_dist
176 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
177 | self.assertEqual(used_distribution.concentration0, expected_value2)
178 | self.assertEqual(used_distribution.concentration1, expected_value1)
179 |
180 | def test_hill_zeros_stay_zeros(self):
181 | data = jnp.zeros((10, 5))
182 | half_max_effective_concentration = jnp.full(5, 0.5)
183 | slope = jnp.full(5, 0.5)
184 |
185 | generated_output = saturation._hill(
186 | data=data,
187 | half_max_effective_concentration=half_max_effective_concentration,
188 | slope=slope,
189 | )
190 |
191 | np.testing.assert_array_equal(generated_output, data)
192 |
193 |   def test_exponent_zeros_stay_zeros(self):
194 | data = jnp.zeros((10, 5))
195 | exponent_values = jnp.full(5, 0.5)
196 |
197 | generated_output = saturation._exponent(
198 | data=data,
199 | exponent_values=exponent_values,
200 | )
201 |
202 | np.testing.assert_array_equal(generated_output, data)
203 |
204 |
205 | if __name__ == "__main__":
206 | absltest.main()
207 |
--------------------------------------------------------------------------------
/lightweight_mmm/media_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Media transformations for accounting for lagging or media effects."""
16 |
17 | import functools
18 | from typing import Union
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 |
24 | @functools.partial(jax.jit, static_argnums=[0, 1])
25 | def calculate_seasonality(
26 | number_periods: int,
27 | degrees: int,
28 | gamma_seasonality: Union[int, float, jnp.ndarray],
29 | frequency: int = 52,
30 | ) -> jnp.ndarray:
31 | """Calculates cyclic variation seasonality using Fourier terms.
32 |
33 | For detailed info check:
34 | https://en.wikipedia.org/wiki/Seasonality#Modeling
35 |
36 | Args:
37 |     number_periods: Number of seasonal periods in the data. E.g. for 1 year of
38 |       weekly data it will be 52 and for 3 years of the same kind 156.
39 |     degrees: Number of degrees to use. Must be greater than or equal to 1.
40 |     gamma_seasonality: Factor to multiply each degree calculation by. Shape must
41 |       be aligned with the number of degrees.
42 |     frequency: Frequency of the seasonality being computed. By default it is 52
43 |       for weekly data (52 weeks in a year).
44 |
45 | Returns:
46 | An array with the seasonality values.
47 | """
48 |
49 | seasonality_range = jnp.expand_dims(a=jnp.arange(number_periods), axis=-1)
50 | degrees_range = jnp.arange(1, degrees+1)
51 | inner_value = seasonality_range * 2 * jnp.pi * degrees_range / frequency
52 | season_matrix_sin = jnp.sin(inner_value)
53 | season_matrix_cos = jnp.cos(inner_value)
54 | season_matrix = jnp.concatenate([
55 | jnp.expand_dims(a=season_matrix_sin, axis=-1),
56 | jnp.expand_dims(a=season_matrix_cos, axis=-1)
57 | ],
58 | axis=-1)
59 | return (season_matrix * gamma_seasonality).sum(axis=2).sum(axis=1)
60 |
61 |
62 | @jax.jit
63 | def adstock(data: jnp.ndarray,
64 | lag_weight: float = .9,
65 | normalise: bool = True) -> jnp.ndarray:
66 | """Calculates the adstock value of a given array.
67 |
68 | To learn more about advertising lag:
69 | https://en.wikipedia.org/wiki/Advertising_adstock
70 |
71 | Args:
72 | data: Input array.
73 |     lag_weight: Lag weight (decay rate) of the adstock function. Default is 0.9.
74 | normalise: Whether to normalise the output value. This normalization will
75 | divide the output values by (1 / (1 - lag_weight)).
76 |
77 | Returns:
78 | The adstock output of the input array.
79 | """
80 |
81 | def adstock_internal(prev_adstock: jnp.ndarray,
82 | data: jnp.ndarray,
83 | lag_weight: float = lag_weight) -> jnp.ndarray:
84 | adstock_value = prev_adstock * lag_weight + data
85 |     return adstock_value, adstock_value
86 |
87 | _, adstock_values = jax.lax.scan(
88 | f=adstock_internal, init=data[0, ...], xs=data[1:, ...])
89 | adstock_values = jnp.concatenate([jnp.array([data[0, ...]]), adstock_values])
90 | return jax.lax.cond(
91 | normalise,
92 | lambda adstock_values: adstock_values / (1. / (1 - lag_weight)),
93 | lambda adstock_values: adstock_values,
94 | operand=adstock_values)
95 |
96 |
97 | @jax.jit
98 | def hill(data: jnp.ndarray, half_max_effective_concentration: jnp.ndarray,
99 | slope: jnp.ndarray) -> jnp.ndarray:
100 | """Calculates the hill function for a given array of values.
101 |
102 | Refer to the following link for detailed information on this equation:
103 | https://en.wikipedia.org/wiki/Hill_equation_(biochemistry)
104 |
105 | Args:
106 | data: Input data.
107 | half_max_effective_concentration: ec50 value for the hill function.
108 | slope: Slope of the hill function.
109 |
110 | Returns:
111 | The hill values for the respective input data.
112 | """
113 |   safe_transform = apply_exponent_safe(
114 |       data=data / half_max_effective_concentration, exponent=-slope)
115 |   return jnp.where(safe_transform == 0, 0, 1.0 / (1 + safe_transform))
116 |
117 |
118 | @functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1)
119 | def _carryover_convolve(data: jnp.ndarray,
120 | weights: jnp.ndarray,
121 | number_lags: int) -> jnp.ndarray:
122 | """Applies the convolution between the data and the weights for the carryover.
123 |
124 | Args:
125 | data: Input data.
126 | weights: Window weights for the carryover.
127 | number_lags: Number of lags the window has.
128 |
129 | Returns:
130 | The result values from convolving the data and the weights with padding.
131 | """
132 | window = jnp.concatenate([jnp.zeros(number_lags - 1), weights])
133 | return jax.scipy.signal.convolve(data, window, mode="same") / weights.sum()
134 |
135 |
136 | @functools.partial(jax.jit, static_argnames=("number_lags",))
137 | def carryover(data: jnp.ndarray,
138 | ad_effect_retention_rate: jnp.ndarray,
139 | peak_effect_delay: jnp.ndarray,
140 | number_lags: int = 13) -> jnp.ndarray:
141 | """Calculates media carryover.
142 |
143 | More details about this function can be found in:
144 | https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46001.pdf
145 |
146 | Args:
147 | data: Input data. It is expected that data has either 2 dimensions for
148 | national models and 3 for geo models.
149 | ad_effect_retention_rate: Retention rate of the advertisement effect.
150 | Default is 0.5.
151 | peak_effect_delay: Delay of the peak effect in the carryover function.
152 | Default is 1.
153 | number_lags: Number of lags to include in the carryover calculation. Default
154 | is 13.
155 |
156 | Returns:
157 | The carryover values for the given data with the given parameters.
158 | """
159 | lags_arange = jnp.expand_dims(jnp.arange(number_lags, dtype=jnp.float32),
160 | axis=-1)
161 | convolve_func = _carryover_convolve
162 | if data.ndim == 3:
163 |     # Since _carryover_convolve is already vmapped in the decorator we only need
164 |     # to vmap it once here to handle the geo level data. We keep the windows
165 |     # two-dimensional also for three-dimensional data and vmap over only the
166 |     # extra data dimension.
167 | convolve_func = jax.vmap(
168 | fun=_carryover_convolve, in_axes=(2, None, None), out_axes=2)
169 | weights = ad_effect_retention_rate**((lags_arange - peak_effect_delay)**2)
170 | return convolve_func(data, weights, number_lags)
171 |
172 | @jax.jit
173 | def apply_exponent_safe(
174 | data: jnp.ndarray,
175 | exponent: jnp.ndarray,
176 | ) -> jnp.ndarray:
177 | """Applies an exponent to given data in a gradient safe way.
178 |
179 | More info on the double jnp.where can be found:
180 | https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
181 |
182 | Args:
183 | data: Input data to use.
184 | exponent: Exponent required for the operations.
185 |
186 | Returns:
187 | The result of the exponent operation with the inputs provided.
188 | """
189 | exponent_safe = jnp.where(data == 0, 1, data) ** exponent
190 | return jnp.where(data == 0, 0, exponent_safe)
191 |
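A quick, illustrative check of the adstock recursion above (a sketch, not part of the repository): with normalisation off, a single pulse of spend decays geometrically with the lag weight.

import jax.numpy as jnp

from lightweight_mmm import media_transforms

# One channel, one unit of spend at t = 0 and nothing afterwards.
spend = jnp.array([[1.0], [0.0], [0.0], [0.0]])
decayed = media_transforms.adstock(data=spend, lag_weight=0.8, normalise=False)
print(decayed)  # ~ [[1.], [0.8], [0.64], [0.512]]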
--------------------------------------------------------------------------------
/lightweight_mmm/media_transforms_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for media_transforms."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 |
23 | from lightweight_mmm import media_transforms
24 |
25 |
26 | class MediaTransformsTest(parameterized.TestCase):
27 |
28 | @parameterized.named_parameters([
29 | dict(
30 | testcase_name="2d_four_channels",
31 | data=np.ones((100, 4)),
32 | ad_effect_retention_rate=np.array([0.9, 0.8, 0.7, 1]),
33 | peak_effect_delay=np.array([0.9, 0.8, 0.7, 1]),
34 | number_lags=5),
35 | dict(
36 | testcase_name="2d_one_channel",
37 | data=np.ones((300, 1)),
38 | ad_effect_retention_rate=np.array([0.2]),
39 | peak_effect_delay=np.array([1]),
40 | number_lags=10),
41 | dict(
42 | testcase_name="3d_10channels_10geos",
43 | data=np.ones((100, 10, 10)),
44 | ad_effect_retention_rate=np.ones(10),
45 | peak_effect_delay=np.ones(10),
46 | number_lags=13),
47 | dict(
48 | testcase_name="3d_10channels_8geos",
49 | data=np.ones((100, 10, 8)),
50 | ad_effect_retention_rate=np.ones(10),
51 | peak_effect_delay=np.ones(10),
52 | number_lags=13),
53 | ])
54 | def test_carryover_produces_correct_shape(self, data,
55 | ad_effect_retention_rate,
56 | peak_effect_delay, number_lags):
57 |
58 | generated_output = media_transforms.carryover(data,
59 | ad_effect_retention_rate,
60 | peak_effect_delay,
61 | number_lags)
62 | self.assertEqual(generated_output.shape, data.shape)
63 |
64 | @parameterized.named_parameters([
65 | dict(
66 | testcase_name="2d_three_channels",
67 | data=np.ones((100, 3)),
68 | half_max_effective_concentration=np.array([0.9, 0.8, 0.7]),
69 | slope=np.array([2, 2, 1])),
70 | dict(
71 | testcase_name="2d_one_channels",
72 | data=np.ones((100, 1)),
73 | half_max_effective_concentration=np.array([0.9]),
74 | slope=np.array([5])),
75 | dict(
76 | testcase_name="3d_10channels_5geos",
77 | data=np.ones((100, 10, 5)),
78 | half_max_effective_concentration=np.expand_dims(np.ones(10), axis=-1),
79 | slope=np.expand_dims(np.ones(10), axis=-1)),
80 | dict(
81 | testcase_name="3d_8channels_10geos",
82 | data=np.ones((100, 8, 10)),
83 | half_max_effective_concentration=np.expand_dims(np.ones(8), axis=-1),
84 | slope=np.expand_dims(np.ones(8), axis=-1)),
85 | ])
86 | def test_hill_produces_correct_shape(self, data,
87 | half_max_effective_concentration, slope):
88 | generated_output = media_transforms.hill(
89 | data=data,
90 | half_max_effective_concentration=half_max_effective_concentration,
91 | slope=slope)
92 |
93 | self.assertEqual(generated_output.shape, data.shape)
94 |
95 | @parameterized.named_parameters([
96 | dict(
97 | testcase_name="2d_five_channels",
98 | data=np.ones((100, 5)),
99 | lag_weight=np.array([0.2, 0.3, 0.8, 0.2, 0.1]),
100 | normalise=True),
101 | dict(
102 | testcase_name="2d_one_channels",
103 | data=np.ones((100, 1)),
104 | lag_weight=np.array([0.4]),
105 | normalise=False),
106 | dict(
107 | testcase_name="3d_10channels_5geos",
108 | data=np.ones((100, 10, 5)),
109 | lag_weight=np.expand_dims(np.ones(10), axis=-1),
110 | normalise=True),
111 | dict(
112 | testcase_name="3d_8channels_10geos",
113 | data=np.ones((100, 8, 10)),
114 | lag_weight=np.expand_dims(np.ones(8), axis=-1),
115 | normalise=True),
116 | ])
117 | def test_adstock_produces_correct_shape(self, data, lag_weight, normalise):
118 | generated_output = media_transforms.adstock(
119 | data=data, lag_weight=lag_weight, normalise=normalise)
120 |
121 | self.assertEqual(generated_output.shape, data.shape)
122 |
123 |   def test_apply_exponent_safe_produces_same_exponent_results(self):
124 | data = jnp.arange(50).reshape((10, 5))
125 | exponent = jnp.full(5, 0.5)
126 |
127 | output = media_transforms.apply_exponent_safe(data=data, exponent=exponent)
128 |
129 | np.testing.assert_array_equal(output, data**exponent)
130 |
131 |   def test_apply_exponent_safe_produces_correct_shape(self):
132 | data = jnp.ones((10, 5))
133 | exponent = jnp.full(5, 0.5)
134 |
135 | output = media_transforms.apply_exponent_safe(data=data, exponent=exponent)
136 |
137 | self.assertEqual(output.shape, data.shape)
138 |
139 | def test_apply_exponent_safe_produces_non_nan_or_inf_grads(self):
140 | def f_safe(data, exponent):
141 | x = media_transforms.apply_exponent_safe(data=data, exponent=exponent)
142 | return x.sum()
143 | data = jnp.ones((10, 5))
144 | data = data.at[0, 0].set(0.)
145 | exponent = jnp.full(5, 0.5)
146 |
147 | grads = jax.grad(f_safe)(data, exponent)
148 |
149 | self.assertFalse(np.isnan(grads).any())
150 | self.assertFalse(np.isinf(grads).any())
151 |
152 | def test_adstock_zeros_stay_zeros(self):
153 | data = jnp.zeros((10, 5))
154 | lag_weight = jnp.full(5, 0.5)
155 |
156 | generated_output = media_transforms.adstock(
157 | data=data, lag_weight=lag_weight)
158 |
159 | np.testing.assert_array_equal(generated_output, data)
160 |
161 | def test_hill_zeros_stay_zeros(self):
162 | data = jnp.zeros((10, 5))
163 | half_max_effective_concentration = jnp.full(5, 0.5)
164 | slope = jnp.full(5, 0.5)
165 |
166 | generated_output = media_transforms.hill(
167 | data=data,
168 | half_max_effective_concentration=half_max_effective_concentration,
169 | slope=slope)
170 |
171 | np.testing.assert_array_equal(generated_output, data)
172 |
173 | def test_carryover_zeros_stay_zeros(self):
174 | data = jnp.zeros((10, 5))
175 | ad_effect_retention_rate = jnp.full(5, 0.5)
176 | peak_effect_delay = jnp.full(5, 0.5)
177 |
178 | generated_output = media_transforms.carryover(
179 | data=data,
180 | ad_effect_retention_rate=ad_effect_retention_rate,
181 | peak_effect_delay=peak_effect_delay)
182 |
183 | np.testing.assert_array_equal(generated_output, data)
184 |
185 |
186 | @parameterized.parameters(range(1, 5))
187 | def test_calculate_seasonality_produces_correct_standard_deviation(
188 | self, degrees):
189 |     # Each degree contributes a sin and a cos term, each with variance 1/2 over
190 |     # a full cycle, so the seasonal curve's standard deviation is sqrt(degrees).
191 | expected_standard_deviation = jnp.sqrt(degrees)
192 |
193 | seasonal_curve = media_transforms.calculate_seasonality(
194 |         number_periods=1200,
195 | degrees=degrees,
196 | gamma_seasonality=1,
197 | frequency=1200,
198 | )
199 | observed_standard_deviation = jnp.std(seasonal_curve)
200 |
201 | self.assertAlmostEqual(
202 | observed_standard_deviation, expected_standard_deviation, delta=0.01)
203 |
204 |
205 | if __name__ == "__main__":
206 | absltest.main()
207 |
--------------------------------------------------------------------------------
/lightweight_mmm/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module containing the different models available in the lightweightMMM lib.
16 |
17 | Currently this file contains a main model with three possible options for
18 | processing the media data, which essentially makes it possible to build three
19 | different models:
20 | - Adstock
21 | - Hill-Adstock
22 | - Carryover
23 | """
24 | import sys
25 | # pylint: disable=g-import-not-at-top
26 | if sys.version_info >= (3, 8):
27 | from typing import Protocol
28 | else:
29 | from typing_extensions import Protocol
30 |
31 | from typing import Any, Dict, Mapping, MutableMapping, Optional, Sequence, Union
32 |
33 | import immutabledict
34 | import jax.numpy as jnp
35 | import numpyro
36 | from numpyro import distributions as dist
37 |
38 | from lightweight_mmm import media_transforms
39 |
40 | Prior = Union[
41 | dist.Distribution,
42 | Dict[str, float],
43 | Sequence[float],
44 | float
45 | ]
46 |
47 |
48 | class TransformFunction(Protocol):
49 |
50 | def __call__(
51 | self,
52 | media_data: jnp.ndarray,
53 | custom_priors: MutableMapping[str, Prior],
54 | **kwargs: Any) -> jnp.ndarray:
55 | ...
56 |
57 |
58 | _INTERCEPT = "intercept"
59 | _COEF_TREND = "coef_trend"
60 | _EXPO_TREND = "expo_trend"
61 | _SIGMA = "sigma"
62 | _GAMMA_SEASONALITY = "gamma_seasonality"
63 | _WEEKDAY = "weekday"
64 | _COEF_EXTRA_FEATURES = "coef_extra_features"
65 | _COEF_SEASONALITY = "coef_seasonality"
66 |
67 | MODEL_PRIORS_NAMES = frozenset((
68 | _INTERCEPT,
69 | _COEF_TREND,
70 | _EXPO_TREND,
71 | _SIGMA,
72 | _GAMMA_SEASONALITY,
73 | _WEEKDAY,
74 | _COEF_EXTRA_FEATURES,
75 | _COEF_SEASONALITY))
76 |
77 | _EXPONENT = "exponent"
78 | _LAG_WEIGHT = "lag_weight"
79 | _HALF_MAX_EFFECTIVE_CONCENTRATION = "half_max_effective_concentration"
80 | _SLOPE = "slope"
81 | _AD_EFFECT_RETENTION_RATE = "ad_effect_retention_rate"
82 | _PEAK_EFFECT_DELAY = "peak_effect_delay"
83 |
84 | TRANSFORM_PRIORS_NAMES = immutabledict.immutabledict({
85 | "carryover":
86 | frozenset((_AD_EFFECT_RETENTION_RATE, _PEAK_EFFECT_DELAY, _EXPONENT)),
87 | "adstock":
88 | frozenset((_EXPONENT, _LAG_WEIGHT)),
89 | "hill_adstock":
90 | frozenset((_LAG_WEIGHT, _HALF_MAX_EFFECTIVE_CONCENTRATION, _SLOPE))
91 | })
92 |
93 | GEO_ONLY_PRIORS = frozenset((_COEF_SEASONALITY,))
94 |
95 |
96 | def _get_default_priors() -> Mapping[str, Prior]:
97 | # Since JAX cannot be called before absl.app.run in tests we get default
98 | # priors from a function.
99 | return immutabledict.immutabledict({
100 | _INTERCEPT: dist.HalfNormal(scale=2.),
101 | _COEF_TREND: dist.Normal(loc=0., scale=1.),
102 | _EXPO_TREND: dist.Uniform(low=0.5, high=1.5),
103 | _SIGMA: dist.Gamma(concentration=1., rate=1.),
104 | _GAMMA_SEASONALITY: dist.Normal(loc=0., scale=1.),
105 | _WEEKDAY: dist.Normal(loc=0., scale=.5),
106 | _COEF_EXTRA_FEATURES: dist.Normal(loc=0., scale=1.),
107 | _COEF_SEASONALITY: dist.HalfNormal(scale=.5)
108 | })
109 |
110 |
111 | def _get_transform_default_priors() -> Mapping[str, Prior]:
112 | # Since JAX cannot be called before absl.app.run in tests we get default
113 | # priors from a function.
114 | return immutabledict.immutabledict({
115 | "carryover":
116 | immutabledict.immutabledict({
117 | _AD_EFFECT_RETENTION_RATE:
118 | dist.Beta(concentration1=1., concentration0=1.),
119 | _PEAK_EFFECT_DELAY:
120 | dist.HalfNormal(scale=2.),
121 | _EXPONENT:
122 | dist.Beta(concentration1=9., concentration0=1.)
123 | }),
124 | "adstock":
125 | immutabledict.immutabledict({
126 | _EXPONENT: dist.Beta(concentration1=9., concentration0=1.),
127 | _LAG_WEIGHT: dist.Beta(concentration1=2., concentration0=1.)
128 | }),
129 | "hill_adstock":
130 | immutabledict.immutabledict({
131 | _LAG_WEIGHT:
132 | dist.Beta(concentration1=2., concentration0=1.),
133 | _HALF_MAX_EFFECTIVE_CONCENTRATION:
134 | dist.Gamma(concentration=1., rate=1.),
135 | _SLOPE:
136 | dist.Gamma(concentration=1., rate=1.)
137 | })
138 | })
139 |
140 |
141 | def transform_adstock(media_data: jnp.ndarray,
142 | custom_priors: MutableMapping[str, Prior],
143 | normalise: bool = True) -> jnp.ndarray:
144 | """Transforms the input data with the adstock function and exponent.
145 |
146 | Args:
147 | media_data: Media data to be transformed. It is expected to have 2 dims for
148 | national models and 3 for geo models.
149 | custom_priors: The custom priors we want the model to take instead of the
150 | default ones. The possible names of parameters for adstock and exponent
151 | are "lag_weight" and "exponent".
152 | normalise: Whether to normalise the output values.
153 |
154 | Returns:
155 | The transformed media data.
156 | """
157 | transform_default_priors = _get_transform_default_priors()["adstock"]
158 | with numpyro.plate(name=f"{_LAG_WEIGHT}_plate",
159 | size=media_data.shape[1]):
160 | lag_weight = numpyro.sample(
161 | name=_LAG_WEIGHT,
162 | fn=custom_priors.get(_LAG_WEIGHT,
163 | transform_default_priors[_LAG_WEIGHT]))
164 |
165 | with numpyro.plate(name=f"{_EXPONENT}_plate",
166 | size=media_data.shape[1]):
167 | exponent = numpyro.sample(
168 | name=_EXPONENT,
169 | fn=custom_priors.get(_EXPONENT,
170 | transform_default_priors[_EXPONENT]))
171 |
172 | if media_data.ndim == 3:
173 | lag_weight = jnp.expand_dims(lag_weight, axis=-1)
174 | exponent = jnp.expand_dims(exponent, axis=-1)
175 |
176 | adstock = media_transforms.adstock(
177 | data=media_data, lag_weight=lag_weight, normalise=normalise)
178 |
179 | return media_transforms.apply_exponent_safe(data=adstock, exponent=exponent)
180 |
181 |
182 | def transform_hill_adstock(media_data: jnp.ndarray,
183 | custom_priors: MutableMapping[str, Prior],
184 | normalise: bool = True) -> jnp.ndarray:
185 | """Transforms the input data with the adstock and hill functions.
186 |
187 | Args:
188 | media_data: Media data to be transformed. It is expected to have 2 dims for
189 | national models and 3 for geo models.
190 | custom_priors: The custom priors we want the model to take instead of the
191 | default ones. The possible names of parameters for hill_adstock and
192 | exponent are "lag_weight", "half_max_effective_concentration" and "slope".
193 | normalise: Whether to normalise the output values.
194 |
195 | Returns:
196 | The transformed media data.
197 | """
198 | transform_default_priors = _get_transform_default_priors()["hill_adstock"]
199 | with numpyro.plate(name=f"{_LAG_WEIGHT}_plate",
200 | size=media_data.shape[1]):
201 | lag_weight = numpyro.sample(
202 | name=_LAG_WEIGHT,
203 | fn=custom_priors.get(_LAG_WEIGHT,
204 | transform_default_priors[_LAG_WEIGHT]))
205 |
206 | with numpyro.plate(name=f"{_HALF_MAX_EFFECTIVE_CONCENTRATION}_plate",
207 | size=media_data.shape[1]):
208 | half_max_effective_concentration = numpyro.sample(
209 | name=_HALF_MAX_EFFECTIVE_CONCENTRATION,
210 | fn=custom_priors.get(
211 | _HALF_MAX_EFFECTIVE_CONCENTRATION,
212 | transform_default_priors[_HALF_MAX_EFFECTIVE_CONCENTRATION]))
213 |
214 | with numpyro.plate(name=f"{_SLOPE}_plate",
215 | size=media_data.shape[1]):
216 | slope = numpyro.sample(
217 | name=_SLOPE,
218 | fn=custom_priors.get(_SLOPE, transform_default_priors[_SLOPE]))
219 |
220 | if media_data.ndim == 3:
221 | lag_weight = jnp.expand_dims(lag_weight, axis=-1)
222 | half_max_effective_concentration = jnp.expand_dims(
223 | half_max_effective_concentration, axis=-1)
224 | slope = jnp.expand_dims(slope, axis=-1)
225 |
226 | return media_transforms.hill(
227 | data=media_transforms.adstock(
228 | data=media_data, lag_weight=lag_weight, normalise=normalise),
229 | half_max_effective_concentration=half_max_effective_concentration,
230 | slope=slope)
231 |
232 |
233 | def transform_carryover(media_data: jnp.ndarray,
234 | custom_priors: MutableMapping[str, Prior],
235 | number_lags: int = 13) -> jnp.ndarray:
236 | """Transforms the input data with the carryover function and exponent.
237 |
238 | Args:
239 | media_data: Media data to be transformed. It is expected to have 2 dims for
240 | national models and 3 for geo models.
241 | custom_priors: The custom priors we want the model to take instead of the
242 | default ones. The possible names of parameters for carryover and exponent
243 | are "ad_effect_retention_rate_plate", "peak_effect_delay_plate" and
244 | "exponent".
245 | number_lags: Number of lags for the carryover function.
246 |
247 | Returns:
248 | The transformed media data.
249 | """
250 | transform_default_priors = _get_transform_default_priors()["carryover"]
251 | with numpyro.plate(name=f"{_AD_EFFECT_RETENTION_RATE}_plate",
252 | size=media_data.shape[1]):
253 | ad_effect_retention_rate = numpyro.sample(
254 | name=_AD_EFFECT_RETENTION_RATE,
255 | fn=custom_priors.get(
256 | _AD_EFFECT_RETENTION_RATE,
257 | transform_default_priors[_AD_EFFECT_RETENTION_RATE]))
258 |
259 | with numpyro.plate(name=f"{_PEAK_EFFECT_DELAY}_plate",
260 | size=media_data.shape[1]):
261 | peak_effect_delay = numpyro.sample(
262 | name=_PEAK_EFFECT_DELAY,
263 | fn=custom_priors.get(
264 | _PEAK_EFFECT_DELAY, transform_default_priors[_PEAK_EFFECT_DELAY]))
265 |
266 | with numpyro.plate(name=f"{_EXPONENT}_plate",
267 | size=media_data.shape[1]):
268 | exponent = numpyro.sample(
269 | name=_EXPONENT,
270 | fn=custom_priors.get(_EXPONENT,
271 | transform_default_priors[_EXPONENT]))
272 | carryover = media_transforms.carryover(
273 | data=media_data,
274 | ad_effect_retention_rate=ad_effect_retention_rate,
275 | peak_effect_delay=peak_effect_delay,
276 | number_lags=number_lags)
277 |
278 | if media_data.ndim == 3:
279 | exponent = jnp.expand_dims(exponent, axis=-1)
280 | return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent)
281 |
282 |
283 | def media_mix_model(
284 | media_data: jnp.ndarray,
285 | target_data: jnp.ndarray,
286 | media_prior: jnp.ndarray,
287 | degrees_seasonality: int,
288 | frequency: int,
289 | transform_function: TransformFunction,
290 | custom_priors: MutableMapping[str, Prior],
291 | transform_kwargs: Optional[MutableMapping[str, Any]] = None,
292 | weekday_seasonality: bool = False,
293 | extra_features: Optional[jnp.ndarray] = None,
294 | ) -> None:
295 | """Media mix model.
296 |
297 | Args:
298 |     media_data: Media data to be used in the model.
299 | target_data: Target data for the model.
300 | media_prior: Cost prior for each of the media channels.
301 | degrees_seasonality: Number of degrees of seasonality to use.
302 | frequency: Frequency of the time span which was used to aggregate the data.
303 |       E.g. if the data is weekly, frequency is 52.
304 | transform_function: Function to use to transform the media data in the
305 | model. Currently the following are supported: 'transform_adstock',
306 | 'transform_carryover' and 'transform_hill_adstock'.
307 | custom_priors: The custom priors we want the model to take instead of the
308 | default ones. See our custom_priors documentation for details about the
309 | API and possible options.
310 | transform_kwargs: Any extra keyword arguments to pass to the transform
311 |       function. For example the adstock function can take a boolean that decides
312 |       whether to normalise the output or not.
313 |     weekday_seasonality: In case of daily data, whether to estimate a weekday
314 |       effect (7 parameters).
315 | extra_features: Extra features data to include in the model.
316 | """
317 | default_priors = _get_default_priors()
318 | data_size = media_data.shape[0]
319 | n_channels = media_data.shape[1]
320 | geo_shape = (media_data.shape[2],) if media_data.ndim == 3 else ()
321 | n_geos = media_data.shape[2] if media_data.ndim == 3 else 1
322 |
323 | with numpyro.plate(name=f"{_INTERCEPT}_plate", size=n_geos):
324 | intercept = numpyro.sample(
325 | name=_INTERCEPT,
326 | fn=custom_priors.get(_INTERCEPT, default_priors[_INTERCEPT]))
327 |
328 | with numpyro.plate(name=f"{_SIGMA}_plate", size=n_geos):
329 | sigma = numpyro.sample(
330 | name=_SIGMA,
331 | fn=custom_priors.get(_SIGMA, default_priors[_SIGMA]))
332 |
333 | # TODO(): Force all geos to have the same trend sign.
334 | with numpyro.plate(name=f"{_COEF_TREND}_plate", size=n_geos):
335 | coef_trend = numpyro.sample(
336 | name=_COEF_TREND,
337 | fn=custom_priors.get(_COEF_TREND, default_priors[_COEF_TREND]))
338 |
339 | expo_trend = numpyro.sample(
340 | name=_EXPO_TREND,
341 | fn=custom_priors.get(
342 | _EXPO_TREND, default_priors[_EXPO_TREND]))
343 |
344 | with numpyro.plate(
345 | name="channel_media_plate",
346 | size=n_channels,
347 | dim=-2 if media_data.ndim == 3 else -1):
348 | coef_media = numpyro.sample(
349 | name="channel_coef_media" if media_data.ndim == 3 else "coef_media",
350 | fn=dist.HalfNormal(scale=media_prior))
351 | if media_data.ndim == 3:
352 | with numpyro.plate(
353 | name="geo_media_plate",
354 | size=n_geos,
355 | dim=-1):
356 | # Corrects the mean to be the same as in the channel only case.
357 | normalisation_factor = jnp.sqrt(2.0 / jnp.pi)
358 | coef_media = numpyro.sample(
359 | name="coef_media",
360 | fn=dist.HalfNormal(scale=coef_media * normalisation_factor)
361 | )
362 |
363 | with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2):
364 | with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",
365 | size=degrees_seasonality):
366 | gamma_seasonality = numpyro.sample(
367 | name=_GAMMA_SEASONALITY,
368 | fn=custom_priors.get(
369 | _GAMMA_SEASONALITY, default_priors[_GAMMA_SEASONALITY]))
370 |
371 | if weekday_seasonality:
372 | with numpyro.plate(name=f"{_WEEKDAY}_plate", size=7):
373 | weekday = numpyro.sample(
374 | name=_WEEKDAY,
375 | fn=custom_priors.get(_WEEKDAY, default_priors[_WEEKDAY]))
376 | weekday_series = weekday[jnp.arange(data_size) % 7]
377 | # In case of daily data, number of lags should be 13*7.
378 | if transform_function == "carryover" and transform_kwargs and "number_lags" not in transform_kwargs:
379 | transform_kwargs["number_lags"] = 13 * 7
380 | elif transform_function == "carryover" and not transform_kwargs:
381 | transform_kwargs = {"number_lags": 13 * 7}
382 |
383 | media_transformed = numpyro.deterministic(
384 | name="media_transformed",
385 | value=transform_function(media_data,
386 | custom_priors=custom_priors,
387 | **transform_kwargs if transform_kwargs else {}))
388 | seasonality = media_transforms.calculate_seasonality(
389 | number_periods=data_size,
390 | degrees=degrees_seasonality,
391 | frequency=frequency,
392 | gamma_seasonality=gamma_seasonality)
393 | # For national model's case
394 | trend = jnp.arange(data_size)
395 | media_einsum = "tc, c -> t" # t = time, c = channel
396 | coef_seasonality = 1
397 |
398 | # TODO(): Add conversion of prior for HalfNormal distribution.
399 | if media_data.ndim == 3: # For geo model's case
400 | trend = jnp.expand_dims(trend, axis=-1)
401 | seasonality = jnp.expand_dims(seasonality, axis=-1)
402 | media_einsum = "tcg, cg -> tg" # t = time, c = channel, g = geo
403 | if weekday_seasonality:
404 | weekday_series = jnp.expand_dims(weekday_series, axis=-1)
405 | with numpyro.plate(name="seasonality_plate", size=n_geos):
406 | coef_seasonality = numpyro.sample(
407 | name=_COEF_SEASONALITY,
408 | fn=custom_priors.get(
409 | _COEF_SEASONALITY, default_priors[_COEF_SEASONALITY]))
410 |   # expo_trend defaults to Uniform(.5, 1.5), so the exponent on time is in [.5, 1.5].
411 | prediction = (
412 | intercept + coef_trend * trend ** expo_trend +
413 | seasonality * coef_seasonality +
414 | jnp.einsum(media_einsum, media_transformed, coef_media))
415 | if extra_features is not None:
416 | plate_prefixes = ("extra_feature",)
417 | extra_features_einsum = "tf, f -> t" # t = time, f = feature
418 | extra_features_plates_shape = (extra_features.shape[1],)
419 | if extra_features.ndim == 3:
420 | plate_prefixes = ("extra_feature", "geo")
421 | extra_features_einsum = "tfg, fg -> tg" # t = time, f = feature, g = geo
422 | extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
423 | with numpyro.plate_stack(
424 | str(plate_prefixes), sizes=list(extra_features_plates_shape)
425 | ):
426 | coef_extra_features = numpyro.sample(
427 | name=_COEF_EXTRA_FEATURES,
428 | fn=custom_priors.get(
429 | _COEF_EXTRA_FEATURES, default_priors[_COEF_EXTRA_FEATURES]))
430 | extra_features_effect = jnp.einsum(extra_features_einsum,
431 | extra_features,
432 | coef_extra_features)
433 | prediction += extra_features_effect
434 |
435 | if weekday_seasonality:
436 | prediction += weekday_series
437 | mu = numpyro.deterministic(name="mu", value=prediction)
438 |
439 | numpyro.sample(
440 | name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)
441 |
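A minimal end-to-end sketch of sampling from the model above with NUTS (synthetic data and tiny sampler settings, purely illustrative; in practice the library's LightweightMMM class wraps this model):

import jax
import jax.numpy as jnp
import numpyro

from lightweight_mmm import models

media = jnp.ones((104, 3))   # Two years of weekly data, three channels.
target = jnp.ones(104)
media_prior = jnp.ones(3)    # E.g. total cost per channel.

mcmc = numpyro.infer.MCMC(
    sampler=numpyro.infer.NUTS(model=models.media_mix_model),
    num_warmup=100, num_samples=100, num_chains=1)
mcmc.run(
    jax.random.PRNGKey(0),
    media_data=media,
    target_data=target,
    media_prior=media_prior,
    degrees_seasonality=2,
    frequency=52,
    transform_function=models.transform_adstock,
    custom_priors={})
samples = mcmc.get_samples()  # E.g. samples["coef_media"], samples["mu"].

Custom priors plug in by site name, e.g. custom_priors={"intercept": dist.HalfNormal(1.)} with dist imported from numpyro.distributions.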
--------------------------------------------------------------------------------
/lightweight_mmm/models_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for models."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpyro
22 | from numpyro import distributions as dist
23 | from numpyro import handlers
24 |
25 | from lightweight_mmm import models
26 |
27 |
28 | class ModelsTest(parameterized.TestCase):
29 |
30 | @parameterized.named_parameters(
31 | dict(testcase_name="one_channel", shape=(10, 1)),
32 | dict(testcase_name="five_channel", shape=(10, 5)),
33 | dict(testcase_name="same_channels_as_rows", shape=(10, 10)),
34 | dict(testcase_name="geo_shape_1", shape=(10, 10, 5)),
35 | dict(testcase_name="geo_shape_2", shape=(10, 5, 2)),
36 | dict(testcase_name="one_channel_one_row", shape=(1, 1)))
37 | def test_transform_adstock_produces_correct_output_shape(self, shape):
38 |
39 | def mock_model_function(media_data):
40 | numpyro.deterministic(
41 | "transformed_media",
42 | models.transform_adstock(media_data, custom_priors={}))
43 |
44 | media = jnp.ones(shape)
45 | kernel = numpyro.infer.NUTS(model=mock_model_function)
46 | mcmc = numpyro.infer.MCMC(
47 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1)
48 | rng_key = jax.random.PRNGKey(0)
49 |
50 | mcmc.run(rng_key, media_data=media)
51 | transformed_media = mcmc.get_samples()["transformed_media"].mean(axis=0)
52 |
53 | self.assertEqual(media.shape, transformed_media.shape)
54 |
55 | @parameterized.named_parameters(
56 | dict(testcase_name="one_channel", shape=(10, 1)),
57 | dict(testcase_name="five_channel", shape=(10, 5)),
58 | dict(testcase_name="same_channels_as_rows", shape=(10, 10)),
59 | dict(testcase_name="geo_shape_1", shape=(10, 10, 5)),
60 | dict(testcase_name="geo_shape_2", shape=(10, 5, 2)),
61 | dict(testcase_name="one_channel_one_row", shape=(1, 1)))
62 | def test_transform_hill_adstock_produces_correct_output_shape(self, shape):
63 |
64 | def mock_model_function(media_data):
65 | numpyro.deterministic(
66 | "transformed_media",
67 | models.transform_hill_adstock(media_data, custom_priors={}))
68 |
69 | media = jnp.ones(shape)
70 | kernel = numpyro.infer.NUTS(model=mock_model_function)
71 | mcmc = numpyro.infer.MCMC(
72 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1)
73 | rng_key = jax.random.PRNGKey(0)
74 |
75 | mcmc.run(rng_key, media_data=media)
76 | transformed_media = mcmc.get_samples()["transformed_media"].mean(axis=0)
77 |
78 | self.assertEqual(media.shape, transformed_media.shape)
79 |
80 | @parameterized.named_parameters(
81 | dict(testcase_name="one_channel", shape=(10, 1)),
82 | dict(testcase_name="five_channel", shape=(10, 5)),
83 | dict(testcase_name="same_channels_as_rows", shape=(10, 10)),
84 | dict(testcase_name="geo_shape_1", shape=(10, 10, 5)),
85 | dict(testcase_name="geo_shape_2", shape=(10, 5, 2)),
86 | dict(testcase_name="one_channel_one_row", shape=(1, 1)))
87 | def test_transform_carryover_produces_correct_output_shape(self, shape):
88 |
89 | def mock_model_function(media_data):
90 | numpyro.deterministic(
91 | "transformed_media",
92 | models.transform_carryover(media_data, custom_priors={}))
93 |
94 | media = jnp.ones(shape)
95 | kernel = numpyro.infer.NUTS(model=mock_model_function)
96 | mcmc = numpyro.infer.MCMC(
97 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1)
98 | rng_key = jax.random.PRNGKey(0)
99 |
100 | mcmc.run(rng_key, media_data=media)
101 | transformed_media = mcmc.get_samples()["transformed_media"].mean(axis=0)
102 |
103 | self.assertEqual(media.shape, transformed_media.shape)
104 |
105 | @parameterized.named_parameters(
106 | dict(
107 | testcase_name="national_no_extra",
108 | media_shape=(10, 3),
109 | extra_features_shape=(),
110 | target_shape=(10,),
111 | total_costs_shape=(3,)),
112 | dict(
113 | testcase_name="national_extra",
114 | media_shape=(10, 5),
115 | extra_features_shape=(10, 2),
116 | target_shape=(10,),
117 | total_costs_shape=(5,)),
118 | dict(
119 | testcase_name="geo_extra_3d",
120 | media_shape=(10, 7, 5),
121 | extra_features_shape=(10, 8, 5),
122 | target_shape=(10, 5),
123 | total_costs_shape=(7, 1)),
124 | dict(
125 | testcase_name="geo_no_extra",
126 | media_shape=(10, 7, 5),
127 | extra_features_shape=(),
128 | target_shape=(10, 5),
129 | total_costs_shape=(7, 1)))
130 | def test_media_mix_model_parameters_have_correct_shapes(
131 | self, media_shape, extra_features_shape, target_shape, total_costs_shape):
132 | media = jnp.ones(media_shape)
133 | extra_features = None if not extra_features_shape else jnp.ones(
134 | extra_features_shape)
135 | costs_prior = jnp.ones(total_costs_shape)
136 | degrees = 2
137 | target = jnp.ones(target_shape)
138 | kernel = numpyro.infer.NUTS(model=models.media_mix_model)
139 | mcmc = numpyro.infer.MCMC(
140 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1)
141 | rng_key = jax.random.PRNGKey(0)
142 |
143 | mcmc.run(
144 | rng_key,
145 | media_data=media,
146 | extra_features=extra_features,
147 | target_data=target,
148 | media_prior=costs_prior,
149 | degrees_seasonality=degrees,
150 | custom_priors={},
151 | frequency=52,
152 | transform_function=models.transform_carryover)
153 | trace = mcmc.get_samples()
154 |
155 | self.assertEqual(
156 | jnp.squeeze(trace["intercept"].mean(axis=0)).shape, target_shape[1:])
157 | self.assertEqual(
158 | jnp.squeeze(trace["sigma"].mean(axis=0)).shape, target_shape[1:])
159 | self.assertEqual(
160 | jnp.squeeze(trace["expo_trend"].mean(axis=0)).shape, ())
161 | self.assertEqual(
162 | jnp.squeeze(trace["coef_trend"].mean(axis=0)).shape, target_shape[1:])
163 | self.assertEqual(
164 | jnp.squeeze(trace["coef_media"].mean(axis=0)).shape, media_shape[1:])
165 | if extra_features_shape:
166 | self.assertEqual(trace["coef_extra_features"].mean(axis=0).shape,
167 | extra_features.shape[1:])
168 | self.assertEqual(trace["gamma_seasonality"].mean(axis=0).shape,
169 | (degrees, 2))
170 | self.assertEqual(trace["mu"].mean(axis=0).shape, target_shape)
171 |
172 | @parameterized.named_parameters(
173 | dict(
174 | testcase_name=f"model_{models._INTERCEPT}",
175 | prior_name=models._INTERCEPT,
176 | transform_function=models.transform_carryover),
177 | dict(
178 | testcase_name=f"model_{models._COEF_TREND}",
179 | prior_name=models._COEF_TREND,
180 | transform_function=models.transform_carryover),
181 | dict(
182 | testcase_name=f"model_{models._EXPO_TREND}",
183 | prior_name=models._EXPO_TREND,
184 | transform_function=models.transform_carryover),
185 | dict(
186 | testcase_name=f"model_{models._SIGMA}",
187 | prior_name=models._SIGMA,
188 | transform_function=models.transform_carryover),
189 | dict(
190 | testcase_name=f"model_{models._GAMMA_SEASONALITY}",
191 | prior_name=models._GAMMA_SEASONALITY,
192 | transform_function=models.transform_carryover),
193 | dict(
194 | testcase_name=f"model_{models._WEEKDAY}",
195 | prior_name=models._WEEKDAY,
196 | transform_function=models.transform_carryover),
197 | dict(
198 | testcase_name=f"model_{models._COEF_EXTRA_FEATURES}",
199 | prior_name=models._COEF_EXTRA_FEATURES,
200 | transform_function=models.transform_carryover),
201 | dict(
202 | testcase_name=f"model_{models._COEF_SEASONALITY}",
203 | prior_name=models._COEF_SEASONALITY,
204 | transform_function=models.transform_carryover),
205 | dict(
206 | testcase_name=f"carryover_{models._AD_EFFECT_RETENTION_RATE}",
207 | prior_name=models._AD_EFFECT_RETENTION_RATE,
208 | transform_function=models.transform_carryover),
209 | dict(
210 | testcase_name=f"carryover_{models._PEAK_EFFECT_DELAY}",
211 | prior_name=models._PEAK_EFFECT_DELAY,
212 | transform_function=models.transform_carryover),
213 | dict(
214 | testcase_name=f"carryover_{models._EXPONENT}",
215 | prior_name=models._EXPONENT,
216 | transform_function=models.transform_carryover),
217 | dict(
218 | testcase_name=f"adstock_{models._EXPONENT}",
219 | prior_name=models._EXPONENT,
220 | transform_function=models.transform_adstock),
221 | dict(
222 | testcase_name=f"adstock_{models._LAG_WEIGHT}",
223 | prior_name=models._LAG_WEIGHT,
224 | transform_function=models.transform_adstock),
225 | dict(
226 | testcase_name=f"hilladstock_{models._LAG_WEIGHT}",
227 | prior_name=models._LAG_WEIGHT,
228 | transform_function=models.transform_hill_adstock),
229 | dict(
230 | testcase_name=f"hilladstock_{models._HALF_MAX_EFFECTIVE_CONCENTRATION}",
231 | prior_name=models._HALF_MAX_EFFECTIVE_CONCENTRATION,
232 | transform_function=models.transform_hill_adstock),
233 | dict(
234 | testcase_name=f"hilladstock_{models._SLOPE}",
235 | prior_name=models._SLOPE,
236 | transform_function=models.transform_hill_adstock),
237 | )
238 | def test_media_mix_model_custom_priors_are_taken_correctly(
239 | self, prior_name, transform_function):
240 | expected_value1, expected_value2 = 5.2, 7.56
241 | custom_priors = {
242 | prior_name: dist.Kumaraswamy(
243 | concentration1=expected_value1, concentration0=expected_value2)}
244 | media = jnp.ones((10, 5, 5))
245 | extra_features = jnp.ones((10, 3, 5))
246 | costs_prior = jnp.ones((5, 1))
247 | target = jnp.ones((10, 5))
248 |
249 | trace_handler = handlers.trace(handlers.seed(
250 | models.media_mix_model, rng_seed=0))
251 | trace = trace_handler.get_trace(
252 | media_data=media,
253 | extra_features=extra_features,
254 | target_data=target,
255 | media_prior=costs_prior,
256 | custom_priors=custom_priors,
257 | degrees_seasonality=2,
258 | frequency=52,
259 | transform_function=transform_function,
260 | weekday_seasonality=True
261 | )
262 | values_and_dists = {
263 | name: site["fn"]
264 | for name, site in trace.items() if "fn" in site
265 | }
266 |
267 | used_distribution = values_and_dists[prior_name]
268 | if isinstance(used_distribution, dist.ExpandedDistribution):
269 | used_distribution = used_distribution.base_dist
270 | self.assertIsInstance(used_distribution, dist.Kumaraswamy)
271 | self.assertEqual(used_distribution.concentration0, expected_value2)
272 | self.assertEqual(used_distribution.concentration1, expected_value1)
273 |
274 |
275 | if __name__ == "__main__":
276 | absltest.main()
277 |
--------------------------------------------------------------------------------
/lightweight_mmm/optimize_media.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for optimizing your media based on media mix models."""
16 | import functools
17 | from typing import Optional, Tuple, Union
18 | from absl import logging
19 | import jax
20 | import jax.numpy as jnp
21 | from scipy import optimize
22 |
23 | from lightweight_mmm import lightweight_mmm
24 | from lightweight_mmm import preprocessing
25 |
26 |
27 | @functools.partial(
28 | jax.jit,
29 | static_argnames=("media_mix_model", "media_input_shape", "target_scaler",
30 | "media_scaler"))
31 | def _objective_function(
32 | extra_features: jnp.ndarray,
33 | media_mix_model: lightweight_mmm.LightweightMMM,
34 | media_input_shape: Tuple[int, int],
35 | media_gap: Optional[int],
36 | target_scaler: Optional[preprocessing.CustomScaler],
37 | media_scaler: preprocessing.CustomScaler,
38 | geo_ratio: jnp.ndarray,
39 | seed: Optional[int],
40 | media_values: jnp.ndarray,
41 | ) -> jnp.float64:
42 | """Objective function to calculate the sum of all predictions of the model.
43 |
44 | Args:
45 | extra_features: Extra features the model requires for prediction.
46 | media_mix_model: Media mix model to use. Must have a predict method to be
47 | used.
48 | media_input_shape: Input shape of the data required by the model to get
49 | predictions. This is needed since optimization might flatten some arrays
50 | and they need to be reshaped before running new predictions.
51 | media_gap: Media data gap between the end of training data and the start of
52 |       the out of sample media given. E.g. if 100 weeks of data were used for
53 |       training and prediction starts 2 months after the training data finished, we
54 | need to provide the 8 weeks missing between the training data and the
55 | prediction data so data transformations (adstock, carryover, ...) can take
56 | place correctly.
57 | target_scaler: Scaler that was used to scale the target before training.
58 | media_scaler: Scaler that was used to scale the media data before training.
59 | geo_ratio: The ratio to split channel media across geo. Should sum up to 1
60 | for each channel and should have shape (c, g).
61 |     seed: Seed to use for PRNGKey during sampling. For reproducibility, run
62 | this function and any other function that gets predictions with the same
63 | seed.
64 | media_values: Media values required by the model to run predictions.
65 |
66 | Returns:
67 | The negative value of the sum of all predictions.
68 | """
69 | if hasattr(media_mix_model, "n_geos") and media_mix_model.n_geos > 1:
70 | media_values = geo_ratio * jnp.expand_dims(media_values, axis=-1)
71 | media_values = jnp.tile(
72 | media_values / media_input_shape[0], reps=media_input_shape[0])
73 |   # Distribute the budget of each channel across time.
74 | media_values = jnp.reshape(a=media_values, shape=media_input_shape)
75 | media_values = media_scaler.transform(media_values)
76 | return -jnp.sum(
77 | media_mix_model.predict(
78 | media=media_values.reshape(media_input_shape),
79 | extra_features=extra_features,
80 | media_gap=media_gap,
81 | target_scaler=target_scaler,
82 | seed=seed).mean(axis=0))
83 |
84 |
85 | @jax.jit
86 | def _budget_constraint(media: jnp.ndarray,
87 | prices: jnp.ndarray,
88 | budget: jnp.ndarray) -> jnp.float64:
89 | """Calculates optimization constraint to keep spend equal to the budget.
90 |
91 | Args:
92 | media: Array with the values of the media for this iteration.
93 | prices: Prices of each media channel at any given time.
94 | budget: Total budget of the optimization.
95 |
96 | Returns:
97 |     The result of subtracting the budget from the total spend.
98 | """
99 | media = media.reshape((-1, len(prices)))
100 | return jnp.sum(media * prices) - budget
101 |
102 |
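# Editor's worked example (hypothetical numbers, mirroring the unit tests in
# optimize_media_test.py): with three channels priced (1, 2, 3) and one unit
# allocated per channel over three time periods, total spend is
# 3 * (1 + 2 + 3) = 18, so
#   _budget_constraint(media=jnp.ones(9), prices=jnp.array([1., 2., 3.]),
#                      budget=jnp.asarray(18.))  # -> 0.0 (constraint met)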
103 | def _get_lower_and_upper_bounds(
104 | media: jnp.ndarray,
105 | n_time_periods: int,
106 | lower_pct: jnp.ndarray,
107 | upper_pct: jnp.ndarray,
108 | media_scaler: Optional[preprocessing.CustomScaler] = None
109 | ) -> optimize.Bounds:
110 |   """Gets the lower and upper bounds for optimization based on historic data.
111 |
112 | It creates an upper bound based on a percentage above the mean value on
113 | each channel and a lower bound based on a relative decrease of the mean
114 | value.
115 |
116 | Args:
117 | media: Media data to get historic mean.
118 |     n_time_periods: Number of time periods to optimize for. If the model is
119 |       built on weekly data, this would be the number of weeks ahead to optimize.
120 | lower_pct: Relative percentage decrease from the mean value to consider as
121 | new lower bound.
122 | upper_pct: Relative percentage increase from the mean value to consider as
123 | new upper bound.
124 | media_scaler: Scaler that was used to scale the media data before training.
125 |
126 | Returns:
127 |     An optimize.Bounds object with the lower and upper bound per media channel.
128 | """
129 | if media.ndim == 3:
130 | lower_pct = jnp.expand_dims(lower_pct, axis=-1)
131 | upper_pct = jnp.expand_dims(upper_pct, axis=-1)
132 |
133 | mean_data = media.mean(axis=0)
134 | lower_bounds = jnp.maximum(mean_data * (1 - lower_pct), 0)
135 | upper_bounds = mean_data * (1 + upper_pct)
136 |
137 | if media_scaler:
138 | lower_bounds = media_scaler.inverse_transform(lower_bounds)
139 | upper_bounds = media_scaler.inverse_transform(upper_bounds)
140 |
141 | if media.ndim == 3:
142 | lower_bounds = lower_bounds.sum(axis=-1)
143 | upper_bounds = upper_bounds.sum(axis=-1)
144 |
145 | return optimize.Bounds(lb=lower_bounds * n_time_periods,
146 | ub=upper_bounds * n_time_periods)
147 |
148 |
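# Editor's note: with all-ones historic media, lower_pct = upper_pct = 0.2 and
# n_time_periods = 15, each channel gets bounds (0.8 * 15, 1.2 * 15), i.e.
# (12., 18.), matching the expectations in optimize_media_test.py.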
149 | def _generate_starting_values(
150 | n_time_periods: int, media: jnp.ndarray,
151 | media_scaler: preprocessing.CustomScaler,
152 | budget: Union[float, int],
153 | prices: jnp.ndarray,
154 | ) -> jnp.ndarray:
155 | """Generates starting values based on historic allocation and budget.
156 |
157 |   In order to make a comparison we take the average historic allocation
158 |   scaled to `n_time_periods` and to the given budget. Given this, one can
159 |   compare how these initial values (based on average historic allocation)
160 |   perform against the output of the optimization in terms of sales/KPI.
161 |
162 | Args:
163 | n_time_periods: Number of time periods the optimization will be done with.
164 | media: Historic media data the model was trained with.
165 | media_scaler: Scaler that was used to scale the media data before training.
166 | budget: Total budget to allocate during the optimization time.
167 | prices: An array with shape (n_media_channels,) for the cost of each media
168 | channel unit.
169 |
170 | Returns:
171 | An array with the starting value for each media channel for the
172 | optimization.
173 | """
174 | previous_allocation = media.mean(axis=0) * n_time_periods
175 | if media_scaler: # Scale before sum as geo scaler has shape (c, g).
176 | previous_allocation = media_scaler.inverse_transform(previous_allocation)
177 |
178 | if media.ndim == 3:
179 | previous_allocation = previous_allocation.sum(axis=-1)
180 |
181 | avg_spend_per_channel = previous_allocation * prices
182 | pct_spend_per_channel = avg_spend_per_channel / avg_spend_per_channel.sum()
183 | budget_per_channel = budget * pct_spend_per_channel
184 | media_unit_per_channel = budget_per_channel / prices
185 | return media_unit_per_channel
186 |
187 |
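# Editor's worked example: with all-ones historic media over five channels,
# n_time_periods = 10 and budget = 50, every channel's historic spend share is
# price / sum(prices), so each channel starts at budget / sum(prices) media
# units; for the test prices (which sum to 5) that is 10 units per channel.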
188 | def find_optimal_budgets(
189 | n_time_periods: int,
190 | media_mix_model: lightweight_mmm.LightweightMMM,
191 | budget: Union[float, int],
192 | prices: jnp.ndarray,
193 | extra_features: Optional[jnp.ndarray] = None,
194 | media_gap: Optional[jnp.ndarray] = None,
195 | target_scaler: Optional[preprocessing.CustomScaler] = None,
196 | media_scaler: Optional[preprocessing.CustomScaler] = None,
197 | bounds_lower_pct: Union[float, jnp.ndarray] = .2,
198 | bounds_upper_pct: Union[float, jnp.ndarray] = .2,
199 | max_iterations: int = 200,
200 | solver_func_tolerance: float = 1e-06,
201 | solver_step_size: float = 1.4901161193847656e-08,
202 | seed: Optional[int] = None) -> optimize.OptimizeResult:
203 | """Finds the best media allocation based on MMM model, prices and a budget.
204 |
205 | Args:
206 |     n_time_periods: Number of time periods to optimize for. If the model is
207 |       built on weekly data, this would be the number of weeks ahead to optimize.
208 | media_mix_model: Media mix model to use for the optimization.
209 | budget: Total budget to allocate during the optimization time.
210 | prices: An array with shape (n_media_channels,) for the cost of each media
211 | channel unit.
212 | extra_features: Extra features needed for the model to predict.
213 | media_gap: Media data gap between the end of training data and the start of
214 | the out of sample media given. Eg. if 100 weeks of data were used for
215 | training and prediction starts 8 weeks after training data finished we
216 | need to provide the 8 weeks missing between the training data and the
217 | prediction data so data transformations (adstock, carryover, ...) can take
218 | place correctly.
219 | target_scaler: Scaler that was used to scale the target before training.
220 | media_scaler: Scaler that was used to scale the media data before training.
221 | bounds_lower_pct: Relative percentage decrease from the mean value to
222 | consider as new lower bound.
223 | bounds_upper_pct: Relative percentage increase from the mean value to
224 | consider as new upper bound.
225 |     max_iterations: Maximum number of iterations to use for the SLSQP scipy
226 |       optimizer. Default is 200.
227 | solver_func_tolerance: Precision goal for the value of the prediction in
228 | the stopping criterion. Maps directly to scipy's `ftol`. Intended only
229 | for advanced users. For more details see:
230 | https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.
231 | solver_step_size: Step size used for numerical approximation of the
232 | Jacobian. Maps directly to scipy's `eps`. Intended only for advanced
233 | users. For more details see:
234 | https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.
235 | seed: Seed to use for PRNGKey during sampling. For replicability run
236 | this function and any other function that gets predictions with the same
237 | seed.
238 |
239 | Returns:
240 | solution: OptimizeResult object containing the results of the optimization.
241 | kpi_without_optim: Predicted target based on original allocation proportion
242 | among channels from the historical data.
243 |     starting_values: Budget allocation based on the original allocation
244 |       proportions and the given total budget.
245 | """
246 | if not hasattr(media_mix_model, "media"):
247 | raise ValueError(
248 | "The passed model has not been trained. Please fit the model before "
249 | "running optimization.")
250 | jax.config.update("jax_enable_x64", True)
251 |
252 | if isinstance(bounds_lower_pct, float):
253 | bounds_lower_pct = jnp.repeat(a=bounds_lower_pct, repeats=len(prices))
254 | if isinstance(bounds_upper_pct, float):
255 | bounds_upper_pct = jnp.repeat(a=bounds_upper_pct, repeats=len(prices))
256 |
257 | bounds = _get_lower_and_upper_bounds(
258 | media=media_mix_model.media,
259 | n_time_periods=n_time_periods,
260 | lower_pct=bounds_lower_pct,
261 | upper_pct=bounds_upper_pct,
262 | media_scaler=media_scaler)
263 | if jnp.sum(bounds.lb * prices) > budget:
264 | logging.warning(
265 | "Budget given is smaller than the lower bounds of the constraints for "
266 | "optimization. This will lead to faulty optimization. Please either "
267 | "increase the budget or change the lower bound by increasing the "
268 | "percentage decrease with the `bounds_lower_pct` parameter.")
269 | if jnp.sum(bounds.ub * prices) < budget:
270 | logging.warning(
271 | "Budget given is larger than the upper bounds of the constraints for "
272 | "optimization. This will lead to faulty optimization. Please either "
273 | "reduce the budget or change the upper bound by increasing the "
274 | "percentage increase with the `bounds_upper_pct` parameter.")
275 |
276 | starting_values = _generate_starting_values(
277 | n_time_periods=n_time_periods,
278 | media=media_mix_model.media,
279 | media_scaler=media_scaler,
280 | budget=budget,
281 | prices=prices,
282 | )
283 | if not media_scaler:
284 | media_scaler = preprocessing.CustomScaler(multiply_by=1, divide_by=1)
285 | if media_mix_model.n_geos == 1:
286 | geo_ratio = 1.0
287 | else:
288 | average_per_time = media_mix_model.media.mean(axis=0)
289 | geo_ratio = average_per_time / jnp.expand_dims(
290 | average_per_time.sum(axis=-1), axis=-1)
291 | media_input_shape = (n_time_periods, *media_mix_model.media.shape[1:])
292 | partial_objective_function = functools.partial(
293 | _objective_function, extra_features, media_mix_model,
294 | media_input_shape, media_gap,
295 | target_scaler, media_scaler, geo_ratio, seed)
296 | solution = optimize.minimize(
297 | fun=partial_objective_function,
298 | x0=starting_values,
299 | bounds=bounds,
300 | method="SLSQP",
301 | jac="3-point",
302 | options={
303 | "maxiter": max_iterations,
304 | "disp": True,
305 | "ftol": solver_func_tolerance,
306 | "eps": solver_step_size,
307 | },
308 | constraints={
309 | "type": "eq",
310 | "fun": _budget_constraint,
311 | "args": (prices, budget)
312 | })
313 |
314 | kpi_without_optim = _objective_function(extra_features=extra_features,
315 | media_mix_model=media_mix_model,
316 | media_input_shape=media_input_shape,
317 | media_gap=media_gap,
318 | target_scaler=target_scaler,
319 | media_scaler=media_scaler,
320 | seed=seed,
321 | geo_ratio=geo_ratio,
322 | media_values=starting_values)
323 | logging.info("KPI without optimization: %r", -1 * kpi_without_optim.item())
324 | logging.info("KPI with optimization: %r", -1 * solution.fun)
325 |
326 | jax.config.update("jax_enable_x64", False)
327 | # TODO(yukaabe): Create an object to contain the results of this function.
328 | return solution, kpi_without_optim, starting_values
329 |
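# Editor's usage sketch (illustrative only, not part of the library API): a
# minimal end-to-end run, assuming a model fitted as in the unit tests, with
# dummy data from `utils.simulate_dummy_data` and tiny sampler settings.
if __name__ == "__main__":
  from lightweight_mmm import utils

  media, _, target, costs = utils.simulate_dummy_data(
      data_size=50, n_media_channels=3, n_extra_features=1)
  mmm = lightweight_mmm.LightweightMMM()
  mmm.fit(media=media, target=target, media_prior=costs,
          number_warmup=2, number_samples=2, number_chains=1)
  solution, kpi_without_optim, starting_values = find_optimal_budgets(
      n_time_periods=5,
      media_mix_model=mmm,
      budget=300.,
      prices=jnp.ones(3))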
--------------------------------------------------------------------------------
/lightweight_mmm/optimize_media_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for optimize_media."""
16 | from unittest import mock
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 |
24 | from lightweight_mmm import lightweight_mmm
25 | from lightweight_mmm import optimize_media
26 | from lightweight_mmm import preprocessing
27 |
28 |
29 | class OptimizeMediaTest(parameterized.TestCase):
30 |
31 | @classmethod
32 | def setUpClass(cls):
33 | super(OptimizeMediaTest, cls).setUpClass()
34 | cls.national_mmm = lightweight_mmm.LightweightMMM()
35 | cls.national_mmm.fit(
36 | media=jnp.ones((50, 5)),
37 | target=jnp.ones(50),
38 | media_prior=jnp.ones(5) * 50,
39 | number_warmup=2,
40 | number_samples=2,
41 | number_chains=1)
42 | cls.geo_mmm = lightweight_mmm.LightweightMMM()
43 | cls.geo_mmm.fit(
44 | media=jnp.ones((50, 5, 3)),
45 | target=jnp.ones((50, 3)),
46 | media_prior=jnp.ones(5) * 50,
47 | number_warmup=2,
48 | number_samples=2,
49 | number_chains=1)
50 |
51 | def setUp(self):
52 | super().setUp()
53 | self.mock_minimize = self.enter_context(
54 | mock.patch.object(optimize_media.optimize, "minimize", autospec=True))
55 |
56 | @parameterized.named_parameters([
57 | dict(
58 | testcase_name="national",
59 | model_name="national_mmm",
60 | geo_ratio=1),
61 | dict(
62 | testcase_name="geo",
63 | model_name="geo_mmm",
64 | geo_ratio=np.tile(0.33, reps=(5, 3)))
65 | ])
66 | def test_objective_function_generates_correct_value_type_and_sign(
67 | self, model_name, geo_ratio):
68 |
69 | mmm = getattr(self, model_name)
70 | extra_features = mmm._extra_features
71 | time_periods = 10
72 |
73 | kpi_predicted = optimize_media._objective_function(
74 | extra_features=extra_features,
75 | media_mix_model=mmm,
76 | media_input_shape=(time_periods, *mmm.media.shape[1:]),
77 | media_gap=None,
78 | target_scaler=None,
79 | media_scaler=preprocessing.CustomScaler(),
80 | media_values=jnp.ones(mmm.n_media_channels) * time_periods,
81 | geo_ratio=geo_ratio,
82 | seed=10)
83 |
84 | self.assertIsInstance(kpi_predicted, jax.Array)
85 | self.assertLessEqual(kpi_predicted, 0)
86 | self.assertEqual(kpi_predicted.shape, ())
87 |
88 | @parameterized.named_parameters([
89 | dict(
90 | testcase_name="zero_output",
91 | media=np.ones(9),
92 | prices=np.array([1, 2, 3]),
93 | budget=18,
94 | expected_value=0),
95 | dict(
96 | testcase_name="negative_output",
97 | media=np.ones(9),
98 | prices=np.array([1, 2, 3]),
99 | budget=20,
100 | expected_value=-2),
101 | dict(
102 | testcase_name="positive_output",
103 | media=np.ones(9),
104 | prices=np.array([1, 2, 3]),
105 | budget=16,
106 | expected_value=2),
107 | dict(
108 | testcase_name="bigger_array",
109 | media=np.ones(18),
110 | prices=np.array([2, 2, 2]),
111 | budget=36,
112 | expected_value=0),
113 | ])
114 | def test_budget_constraint(self, media, prices, budget, expected_value):
115 | generated_value = optimize_media._budget_constraint(
116 | media=media, prices=prices, budget=budget)
117 |
118 | self.assertEqual(generated_value, expected_value)
119 |
120 | @parameterized.named_parameters([
121 | dict(
122 | testcase_name="national_media_scaler",
123 | model_name="national_mmm"),
124 | dict(
125 | testcase_name="geo_media_scaler",
126 | model_name="geo_mmm")
127 | ])
128 | def test_find_optimal_budgets_with_scaler_optimize_called_with_right_params(
129 | self, model_name):
130 |
131 | mmm = getattr(self, model_name)
132 | media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
133 | media_scaler.fit(2 * jnp.ones((10, *mmm.media.shape[1:])))
134 | optimize_media.find_optimal_budgets(
135 | n_time_periods=15,
136 | media_mix_model=mmm,
137 | budget=30,
138 | prices=jnp.ones(mmm.n_media_channels),
139 | target_scaler=None,
140 | media_scaler=media_scaler)
141 |
142 | _, call_kwargs = self.mock_minimize.call_args_list[0]
143 |     # Mean media of 1 over 15 periods at ±20% gives bounds 12. and 18.; the x2 scaler makes these 24. and 36.
144 | np.testing.assert_array_almost_equal(call_kwargs["bounds"].lb,
145 | np.repeat(24., repeats=5) * mmm.n_geos,
146 | decimal=3)
147 | np.testing.assert_array_almost_equal(call_kwargs["bounds"].ub,
148 | np.repeat(36., repeats=5) * mmm.n_geos,
149 | decimal=3)
150 |     # We only added a scaler with a divide operation, so we only expect x2
151 |     # in the divide_by parameter.
152 | np.testing.assert_array_almost_equal(call_kwargs["fun"].args[5].divide_by,
153 | 2 * jnp.ones(mmm.media.shape[1:]),
154 | decimal=3)
155 | np.testing.assert_array_almost_equal(call_kwargs["fun"].args[5].multiply_by,
156 | jnp.ones(mmm.media.shape[1:]),
157 | decimal=3)
158 |
159 | @parameterized.named_parameters([
160 | dict(
161 | testcase_name="national",
162 | model_name="national_mmm"),
163 | dict(
164 | testcase_name="geo",
165 | model_name="geo_mmm")
166 | ])
167 | def test_find_optimal_budgets_without_scaler_optimize_called_with_right_params(
168 | self, model_name):
169 |
170 | mmm = getattr(self, model_name)
171 | optimize_media.find_optimal_budgets(
172 | n_time_periods=15,
173 | media_mix_model=mmm,
174 | budget=30,
175 | prices=jnp.ones(mmm.n_media_channels),
176 | target_scaler=None,
177 | media_scaler=None)
178 |
179 | _, call_kwargs = self.mock_minimize.call_args_list[0]
180 |     # Mean media of 1 over 15 periods at ±20% gives bounds 12. and 18.
181 | np.testing.assert_array_almost_equal(
182 | call_kwargs["bounds"].lb,
183 | np.repeat(12., repeats=5) * mmm.n_geos,
184 | decimal=3)
185 | np.testing.assert_array_almost_equal(
186 | call_kwargs["bounds"].ub,
187 | np.repeat(18., repeats=5) * mmm.n_geos,
188 | decimal=3)
189 |
190 | np.testing.assert_array_almost_equal(
191 | call_kwargs["fun"].args[5].divide_by,
192 | jnp.ones(mmm.n_media_channels),
193 | decimal=3)
194 | np.testing.assert_array_almost_equal(
195 | call_kwargs["fun"].args[5].multiply_by,
196 | jnp.ones(mmm.n_media_channels),
197 | decimal=3)
198 |
199 | @parameterized.named_parameters([
200 | dict(
201 | testcase_name="national",
202 | model_name="national_mmm"),
203 | dict(
204 | testcase_name="geo",
205 | model_name="geo_mmm")
206 | ])
207 | def test_predict_called_with_right_args(self, model_name):
208 | mmm = getattr(self, model_name)
209 |
210 | optimize_media.find_optimal_budgets(
211 | n_time_periods=15,
212 | media_mix_model=mmm,
213 | budget=30,
214 | prices=jnp.ones(mmm.n_media_channels),
215 | target_scaler=None,
216 | media_scaler=None)
217 |
218 | @parameterized.named_parameters([
219 | dict(
220 | testcase_name="national",
221 | model_name="national_mmm"),
222 | dict(
223 | testcase_name="geo",
224 | model_name="geo_mmm")
225 | ])
226 | def test_budget_lower_than_constraints_warns_user(self, model_name):
227 | mmm = getattr(self, model_name)
228 | expected_warning = (
229 | "Budget given is smaller than the lower bounds of the constraints for "
230 | "optimization. This will lead to faulty optimization. Please either "
231 | "increase the budget or change the lower bound by increasing the "
232 | "percentage decrease with the `bounds_lower_pct` parameter.")
233 |
234 | with self.assertLogs(level="WARNING") as context_manager:
235 | optimize_media.find_optimal_budgets(
236 | n_time_periods=5,
237 | media_mix_model=mmm,
238 | budget=1,
239 | prices=jnp.ones(mmm.n_media_channels))
240 | self.assertEqual(f"WARNING:absl:{expected_warning}",
241 | context_manager.output[0])
242 |
243 | @parameterized.named_parameters([
244 | dict(
245 | testcase_name="national",
246 | model_name="national_mmm"),
247 | dict(
248 | testcase_name="geo",
249 | model_name="geo_mmm")
250 | ])
251 | def test_budget_higher_than_constraints_warns_user(self, model_name):
252 | mmm = getattr(self, model_name)
253 | expected_warning = (
254 | "Budget given is larger than the upper bounds of the constraints for "
255 | "optimization. This will lead to faulty optimization. Please either "
256 | "reduce the budget or change the upper bound by increasing the "
257 | "percentage increase with the `bounds_upper_pct` parameter.")
258 |
259 | with self.assertLogs(level="WARNING") as context_manager:
260 | optimize_media.find_optimal_budgets(
261 | n_time_periods=5,
262 | media_mix_model=mmm,
263 | budget=2000,
264 | prices=jnp.ones(5))
265 | self.assertEqual(f"WARNING:absl:{expected_warning}",
266 | context_manager.output[0])
267 |
268 | @parameterized.named_parameters([
269 | dict(
270 | testcase_name="national",
271 | model_name="national_mmm",
272 | expected_len=3),
273 | dict(
274 | testcase_name="geo",
275 | model_name="geo_mmm",
276 | expected_len=3)
277 | ])
278 | def test_find_optimal_budgets_has_right_output_length_datatype(
279 | self, model_name, expected_len):
280 |
281 | mmm = getattr(self, model_name)
282 | media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
283 | media_scaler.fit(2 * jnp.ones((10, *mmm.media.shape[1:])))
284 | results = optimize_media.find_optimal_budgets(
285 | n_time_periods=15,
286 | media_mix_model=mmm,
287 | budget=30,
288 | prices=jnp.ones(mmm.n_media_channels),
289 | target_scaler=None,
290 | media_scaler=media_scaler)
291 | self.assertLen(results, expected_len)
292 | self.assertIsInstance(results[1], jax.Array)
293 | self.assertIsInstance(results[2], jax.Array)
294 |
295 | @parameterized.named_parameters([
296 | dict(
297 | testcase_name="national_prices",
298 | model_name="national_mmm",
299 | prices=np.array([1., 0.8, 1.2, 1.5, 0.5]),
300 | ),
301 | dict(
302 | testcase_name="national_ones",
303 | model_name="national_mmm",
304 | prices=np.ones(5),
305 | ),
306 | dict(
307 | testcase_name="geo_prices",
308 | model_name="geo_mmm",
309 | prices=np.array([1., 0.8, 1.2, 1.5, 0.5]),
310 | ),
311 | dict(
312 | testcase_name="geo_ones",
313 | model_name="geo_mmm",
314 | prices=np.ones(5),
315 | ),
316 | ])
317 | def test_generate_starting_values_calculates_correct_values(
318 | self, model_name, prices):
319 | mmm = getattr(self, model_name)
320 | n_time_periods = 10
321 | budget = mmm.n_media_channels * n_time_periods
322 | starting_values = optimize_media._generate_starting_values(
323 | n_time_periods=10,
324 | media_scaler=None,
325 | media=mmm.media,
326 | budget=budget,
327 | prices=prices,
328 | )
329 |
330 |     # With all-ones data and prices summing to n_channels, starting values
331 |     # equal n_time_periods for every channel.
331 | np.testing.assert_array_almost_equal(
332 | starting_values, jnp.repeat(n_time_periods, repeats=5))
333 |
334 | if __name__ == "__main__":
335 | absltest.main()
336 |
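# Editor's note: these tests can be run on their own with pytest (installed
# via requirements/requirements_tests.txt, as in test.sh), e.g.
#   python -m pytest lightweight_mmm/optimize_media_test.py -k national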
--------------------------------------------------------------------------------
/lightweight_mmm/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Set of utilities for the LightweightMMM package."""
16 | import pickle
17 | import time
18 | from typing import Any, List, Optional, Tuple
19 |
20 | from absl import logging
21 | from jax import random
22 | import jax.numpy as jnp
23 | import numpy as np
24 | import pandas as pd
25 | from scipy import interpolate
26 | from scipy import optimize
27 | from scipy import spatial
28 | from scipy import stats
29 | from tensorflow.io import gfile
30 |
31 | from lightweight_mmm import media_transforms
32 |
33 |
34 | def save_model(
35 | media_mix_model: Any,
36 | file_path: str
37 | ) -> None:
38 | """Saves the given model in the given path.
39 |
40 | Args:
41 | media_mix_model: Model to save on disk.
42 | file_path: File path where the model should be placed.
43 | """
44 | with gfile.GFile(file_path, "wb") as file:
45 | pickle.dump(obj=media_mix_model, file=file)
46 |
47 |
48 | def load_model(file_path: str) -> Any:
49 | """Loads a model given a string path.
50 |
51 | Args:
52 | file_path: Path of the file containing the model.
53 |
54 | Returns:
55 | The LightweightMMM object that was stored in the given path.
56 | """
57 | with gfile.GFile(file_path, "rb") as file:
58 | media_mix_model = pickle.load(file=file)
59 |
60 | for attr in dir(media_mix_model):
61 | if attr.startswith("__"):
62 | continue
63 | attr_value = getattr(media_mix_model, attr)
64 | if isinstance(attr_value, np.ndarray):
65 | setattr(media_mix_model, attr, jnp.array(attr_value))
66 |
67 | return media_mix_model
68 |
69 |
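# Editor's sketch (illustrative file path, not part of the module): a save /
# load round trip, after which numpy arrays on the loaded object come back as
# jax arrays thanks to the conversion loop above:
#   save_model(media_mix_model=mmm, file_path="/tmp/media_mix_model.pkl")
#   mmm = load_model(file_path="/tmp/media_mix_model.pkl")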
70 | def get_time_seed() -> int:
71 | """Generates an integer using the last decimals of time.time().
72 |
73 | Returns:
74 | Integer to be used as seed.
75 | """
76 | # time.time() has the following format: 1645174953.0429401
77 | return int(str(time.time()).split(".")[1])
78 |
79 |
80 | def simulate_dummy_data(
81 | data_size: int,
82 | n_media_channels: int,
83 | n_extra_features: int,
84 | geos: int = 1,
85 | seed: int = 5
86 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
87 | """Simulates dummy data needed for media mix modelling.
88 |
89 |   This function's goal is to be super simple and not have many parameters.
90 |   It does not generate a fully realistic dataset and is only meant to be
91 |   used for demo/tutorial purposes. Uses carryover for lagging but has no
92 | saturation and no trend.
93 |
94 | The data simulated includes the media data, extra features, a target/KPI and
95 | costs.
96 |
97 | Args:
98 | data_size: Number of rows to generate.
99 | n_media_channels: Number of media channels to generate.
100 | n_extra_features: Number of extra features to generate.
101 | geos: Number of geos for geo level data (default = 1 for national).
102 | seed: Random seed.
103 |
104 | Returns:
105 | The simulated media, extra features, target and costs.
106 | """
107 | if data_size < 1 or n_media_channels < 1 or n_extra_features < 1:
108 | raise ValueError(
109 |         "Data size, n_media_channels and n_extra_features must be greater"
110 |         " than 0. Please check that the provided values are positive.")
111 | data_offset = int(data_size * 0.2)
112 | data_size += data_offset
113 | key = random.PRNGKey(seed)
114 | sub_keys = random.split(key=key, num=7)
115 | media_data = random.normal(key=sub_keys[0],
116 | shape=(data_size, n_media_channels)) * 1.5 + 20
117 |
118 | extra_features = random.normal(key=sub_keys[1],
119 | shape=(data_size, n_extra_features)) + 5
120 | # Reduce the costs to make ROI realistic.
121 | costs = media_data[data_offset:].sum(axis=0) * .1
122 |
123 | seasonality = media_transforms.calculate_seasonality(
124 | number_periods=data_size,
125 | degrees=2,
126 | frequency=52,
127 | gamma_seasonality=1)
128 | target_noise = random.normal(key=sub_keys[2], shape=(data_size,)) + 3
129 |
131 | media_data_transformed = media_transforms.carryover(
132 | data=media_data,
133 | ad_effect_retention_rate=jnp.full((n_media_channels,), fill_value=.5),
134 | peak_effect_delay=jnp.full((n_media_channels,), fill_value=1.))
135 | beta_media = random.normal(key=sub_keys[3], shape=(n_media_channels,)) + 1
136 | beta_extra_features = random.normal(key=sub_keys[4],
137 | shape=(n_extra_features,))
138 | # There is no trend to keep this very simple.
139 | target = 15 + seasonality + media_data_transformed.dot(
140 | beta_media) + extra_features.dot(beta_extra_features) + target_noise
141 |
142 | logging.info("Correlation between transformed media and target")
143 | logging.info([
144 | np.corrcoef(target[data_offset:], media_data_transformed[data_offset:,
145 | i])[0, 1]
146 | for i in range(n_media_channels)
147 | ])
148 |
149 | logging.info("True ROI for media channels")
150 | logging.info([
151 | sum(media_data_transformed[data_offset:, i] * beta_media[i]) / costs[i]
152 | for i in range(n_media_channels)
153 | ])
154 |
155 | if geos > 1:
156 | # Distribute national data to geo and add some more noise.
157 | weights = random.uniform(key=sub_keys[5], shape=(1, geos))
158 |     weights /= weights.sum()  # Normalize so geo weights sum to 1.
159 | target_noise = random.normal(key=sub_keys[6], shape=(data_size, geos)) * .5
160 | target = target[:, np.newaxis].dot(weights) + target_noise
161 | media_data = media_data[:, :, np.newaxis].dot(weights)
162 | extra_features = extra_features[:, :, np.newaxis].dot(weights)
163 |
164 | return (media_data[data_offset:], extra_features[data_offset:],
165 | target[data_offset:], costs)
166 |
167 |
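# Editor's sketch of typical demo usage (the shapes shown are what these
# illustrative arguments produce):
#   media, extra, target, costs = simulate_dummy_data(
#       data_size=100, n_media_channels=3, n_extra_features=2)
#   # media: (100, 3), extra: (100, 2), target: (100,), costs: (3,)
# With geos=4 the media shape becomes (100, 3, 4) and the target (100, 4).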
168 | def _split_array_into_list(
169 | dataframe: pd.DataFrame,
170 | split_level_feature: str,
171 | features: List[str],
172 |     national_model_flag: bool = True) -> jnp.ndarray:
173 |   """Splits the data frame by a feature and stacks the slices into an array.
174 | 
175 |   Args:
176 |     dataframe: Dataframe with all the modeling features.
177 |     split_level_feature: Feature that will be used to split.
178 |     features: List of features to export from the data frame.
179 |     national_model_flag: Whether the data frame is used for a national model.
180 | 
181 |   Returns:
182 |     A jax array with one leading entry per unique value of the split feature.
183 |   """
184 | split_level = dataframe[split_level_feature].unique()
185 | array_list_by_level = [
186 | dataframe.loc[dataframe[split_level_feature] == level, features].values.T
187 | for level in split_level
188 | ]
189 | feature_array = jnp.stack(array_list_by_level)
190 | if national_model_flag:
191 | feature_array = jnp.squeeze(feature_array, axis=2)
192 |   return feature_array
193 |
194 |
195 | def dataframe_to_jax(
196 | dataframe: pd.DataFrame,
197 | media_features: List[str],
198 | extra_features: List[str],
199 | date_feature: str,
200 | target: str,
201 | geo_feature: Optional[str] = None,
202 | cost_features: Optional[List[str]] = None
203 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
204 | """Converts pandas dataframe to right data format for media mix model.
205 |
206 |   This function's goal is to convert a dataframe, which most data scientists
207 |   are familiar with, into jax arrays, helping users who are less familiar
208 |   with arrays to use the lightweight MMM library more easily.
209 |
210 | Args:
211 | dataframe: Dataframe with geo, KPI, media and non-media features.
212 | media_features: List of media feature names.
213 | extra_features: List of non media feature names.
214 | date_feature: Date feature name.
215 | target: Target variables name.
216 |     geo_feature: Geo feature name. Optional if the data is at the national
217 |       level.
218 |     cost_features: List of media cost variables. Optional if the user uses
219 |       actual media costs as the media features in the model.
220 |
221 | Returns:
222 | Media, extra features, target and costs arrays.
223 |
224 | Raises:
225 | ValueError: If each geo has unequal number of weeks or there is only one
226 | value in the geo feature.
227 | """
228 | if geo_feature is not None:
229 | if dataframe[geo_feature].nunique() == 1:
230 |       raise ValueError(
231 |           "The geo feature must have at least two geos; leave it as None "
232 |           "(default) for a national model.")
233 | count_by_geo = dataframe.groupby(
234 | geo_feature)[date_feature].count().reset_index()
235 | unique_date_count = count_by_geo[date_feature].nunique()
236 | if unique_date_count != 1:
237 |       raise ValueError("Not all geos have the same number of weeks.")
238 | national_model_flag = False
239 | features_to_sort = [date_feature, geo_feature]
240 | else:
241 | national_model_flag = True
242 | features_to_sort = [date_feature]
243 |
244 | df_sorted = dataframe.sort_values(by=features_to_sort)
245 | media_features_data = _split_array_into_list(
246 | dataframe=df_sorted,
247 | split_level_feature=date_feature,
248 | features=media_features,
249 | national_model_flag=national_model_flag)
250 |
251 | extra_features_data = _split_array_into_list(
252 | dataframe=df_sorted,
253 | split_level_feature=date_feature,
254 | features=extra_features,
255 | national_model_flag=national_model_flag)
256 |
257 | target_data = _split_array_into_list(
258 | dataframe=df_sorted,
259 | split_level_feature=date_feature,
260 | features=[target],
261 | national_model_flag=national_model_flag)
262 |   target_data = jnp.squeeze(target_data)
263 |
264 | if cost_features:
265 | cost_data = jnp.dot(
266 | jnp.full(len(dataframe), 1), dataframe[cost_features].values)
267 | else:
268 | cost_data = jnp.dot(
269 | jnp.full(len(dataframe), 1), dataframe[media_features].values)
270 |   return (media_features_data, extra_features_data, target_data, cost_data)
271 |
272 |
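# Editor's sketch: converting a tiny national-level frame (hypothetical data):
#   df = pd.DataFrame({
#       "date": ["2021-01-03", "2021-01-10", "2021-01-17"],
#       "tv": [10., 20., 30.], "search": [1., 2., 3.],
#       "holiday": [0., 1., 0.], "sales": [100., 120., 140.]})
#   media, extra, target, costs = dataframe_to_jax(
#       df, media_features=["tv", "search"], extra_features=["holiday"],
#       date_feature="date", target="sales")
#   # media: (3, 2), extra: (3, 1), target: (3,), costs: (2,) (column sums).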
273 | def get_halfnormal_mean_from_scale(scale: float) -> float:
274 |   """Returns the mean of the half-normal distribution."""
275 | # https://en.wikipedia.org/wiki/Half-normal_distribution
276 | return scale * np.sqrt(2) / np.sqrt(np.pi)
277 |
278 |
279 | def get_halfnormal_scale_from_mean(mean: float) -> float:
280 | """Returns the scale of the half-normal distribution."""
281 | # https://en.wikipedia.org/wiki/Half-normal_distribution
282 | return mean * np.sqrt(np.pi) / np.sqrt(2)
283 |
284 |
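# Editor's note: the two helpers above are exact inverses. Since the
# half-normal mean equals scale * sqrt(2 / pi):
#   get_halfnormal_mean_from_scale(1.0)     # ~0.798
#   get_halfnormal_scale_from_mean(0.798)   # ~1.0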
285 | def get_beta_params_from_mu_sigma(mu: float,
286 | sigma: float,
287 | bracket: Tuple[float, float] = (.5, 100.)
288 | ) -> Tuple[float, float]:
289 | """Deterministically estimates (a, b) from (mu, sigma) of a beta variable.
290 |
291 | https://en.wikipedia.org/wiki/Beta_distribution
292 |
293 | Args:
294 | mu: The sample mean of the beta distributed variable.
295 | sigma: The sample standard deviation of the beta distributed variable.
296 | bracket: Search bracket for b.
297 |
298 | Returns:
299 | Tuple of the (a, b) parameters.
300 | """
301 | # Assume a = 1 to find b.
302 | def _f(x):
303 | return x ** 2 + 4 * x + 5 + 2 / x - 1 / sigma ** 2
304 | b = optimize.root_scalar(_f, bracket=bracket, method="brentq").root
305 | # Given b, now find a better a.
306 | a = b / (1 / mu - 1)
307 | return a, b
308 |
309 |
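# Editor's note: for mu = 0.5 the second step gives a = b / (1 / 0.5 - 1) = b,
# i.e. a symmetric Beta(a, a); smaller sigma values push the root (and hence
# both parameters) higher.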
310 | def _estimate_pdf(p: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
311 | """Estimates smooth pdf with Gaussian kernel.
312 |
313 | Args:
314 | p: Samples.
315 | x: The continuous x space (sorted).
316 |
317 | Returns:
318 | A density vector.
319 | """
320 | density = sum(stats.norm(xi).pdf(x) for xi in p)
321 | return density / density.sum()
322 |
323 |
324 | def _pmf(p: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
325 | """Estimates discrete pmf.
326 |
327 | Args:
328 | p: Samples.
329 | x: The discrete x space (sorted).
330 |
331 | Returns:
332 | A pmf vector.
333 | """
334 | p_cdf = jnp.array([jnp.sum(p <= x[i]) for i in range(len(x))])
335 | p_pmf = np.concatenate([[p_cdf[0]], jnp.diff(p_cdf)])
336 | return p_pmf / p_pmf.sum()
337 |
338 |
339 | def distance_pior_posterior(p: jnp.ndarray, q: jnp.ndarray, method: str = "KS",
340 | discrete: bool = True) -> float:
341 | """Quantifies the distance between two distributions.
342 |
343 | Note we do not use KL divergence because it's not defined when a probability
344 | is 0.
345 |
346 | https://en.wikipedia.org/wiki/Hellinger_distance
347 |
348 | Args:
349 | p: Samples for distribution 1.
350 | q: Samples for distribution 2.
351 |     method: One of "KS", "Hellinger", "JS" or "min".
352 | discrete: Whether input data is discrete or continuous.
353 |
354 | Returns:
355 | The distance metric (between 0 and 1).
356 | """
357 |
358 | if method == "KS":
359 | # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ks_2samp.html
360 | return stats.ks_2samp(p, q).statistic
361 | elif method in ["Hellinger", "JS", "min"]:
362 | if discrete:
363 | x = jnp.unique(jnp.concatenate((p, q)))
364 | p_pdf = _pmf(p, x)
365 | q_pdf = _pmf(q, x)
366 | else:
367 | minx, maxx = min(p.min(), q.min()), max(p.max(), q.max())
368 | x = np.linspace(minx, maxx, 100)
369 | p_pdf = _estimate_pdf(p, x)
370 | q_pdf = _estimate_pdf(q, x)
371 | if method == "Hellinger":
372 | return np.sqrt(jnp.sum((np.sqrt(p_pdf) - np.sqrt(q_pdf)) ** 2)) / np.sqrt(2)
373 | elif method == "JS":
374 | # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html
375 | return spatial.distance.jensenshannon(p_pdf, q_pdf)
376 | else:
377 | return 1 - np.minimum(p_pdf, q_pdf).sum()
378 |
379 |
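# Editor's sketch: identical samples give a distance of 0 for every method,
# e.g. distance_pior_posterior(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]),
# method="KS") returns 0.0, while disjoint samples push the metric towards 1.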
380 | def interpolate_outliers(x: jnp.ndarray,
381 | outlier_idx: jnp.ndarray) -> jnp.ndarray:
382 | """Overwrites outliers in x with interpolated values.
383 |
384 | Args:
385 | x: The original univariate variable with outliers.
386 | outlier_idx: Indices of the outliers in x.
387 |
388 | Returns:
389 | A cleaned x with outliers overwritten.
390 |
391 | """
392 | time_idx = jnp.arange(len(x))
393 | inverse_idx = jnp.array([i for i in range(len(x)) if i not in outlier_idx])
394 | interp_func = interpolate.interp1d(
395 | time_idx[inverse_idx], x[inverse_idx], kind="linear")
396 | x = x.at[outlier_idx].set(interp_func(time_idx[outlier_idx]))
397 | return x
398 |
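# Editor's sketch: a spike at index 2 is replaced by the linear interpolation
# of its neighbours (hypothetical values):
#   interpolate_outliers(jnp.array([1., 2., 100., 4., 5.]),
#                        outlier_idx=jnp.array([2]))  # -> [1., 2., 3., 4., 5.]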
--------------------------------------------------------------------------------
/readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Read the Docs configuration file
16 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
17 | version: 2
18 | build:
19 | os: ubuntu-22.04
20 | tools:
21 | python: "3.10"
22 | sphinx:
23 | builder: html
24 | configuration: docs/conf.py
25 | fail_on_warning: false
26 | python:
27 | install:
28 | - requirements: requirements/requirements_docs.txt
29 | - requirements: requirements/requirements.txt
30 | - method: setuptools
31 | path: .
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | absl-py
16 | arviz>=0.11.2
17 | immutabledict>=2.0.0
18 | jax>=0.3.18
19 | jaxlib>=0.3.18
20 | matplotlib==3.6.1
21 | numpy>=1.12
22 | numpyro>=0.9.2
23 | pandas>=1.1.5
24 | scipy
25 | seaborn==0.11.1
26 | scikit-learn
27 | statsmodels>=0.13.0
28 | tensorflow>=2.7.2
--------------------------------------------------------------------------------
/requirements/requirements_docs.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | sphinx==4.5.0
16 | sphinx_rtd_theme==1.0.0
17 | sphinxcontrib-katex==0.8.6
18 | sphinxcontrib-bibtex==1.0.0
19 | sphinx-autodoc-typehints==1.11.1
20 | IPython==7.16.3
21 | ipykernel==5.3.4
22 | pandoc==1.0.2
23 | myst_nb==0.13.1
24 | docutils==0.16
25 | matplotlib==3.6.1
26 | nbsphinx
27 | sphinx_markdown_tables
28 | ipython_genutils
--------------------------------------------------------------------------------
/requirements/requirements_tests.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | pytest-xdist
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup for the lightweight_mmm package."""
16 |
17 | import os
18 |
19 | from setuptools import find_packages
20 | from setuptools import setup
21 |
22 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
23 |
24 |
25 | def _get_readme():
26 | try:
27 | readme = open(
28 | os.path.join(_CURRENT_DIR, "README.md"), encoding="utf-8").read()
29 | except OSError:
30 | readme = ""
31 | return readme
32 |
33 |
34 | def _get_version():
35 | with open(os.path.join(_CURRENT_DIR, "lightweight_mmm", "__init__.py")) as fp:
36 | for line in fp:
37 | if line.startswith("__version__") and "=" in line:
38 | version = line[line.find("=") + 1:].strip(" '\"\n")
39 | if version:
40 | return version
41 | raise ValueError(
42 | "`__version__` not defined in `lightweight_mmm/__init__.py`")
43 |
44 |
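# Editor's note: _get_version strips quotes and whitespace around the value,
# so a (hypothetical) line `__version__ = "0.1.9"` in
# lightweight_mmm/__init__.py yields the string "0.1.9".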
45 | def _parse_requirements(path):
46 |
47 | with open(os.path.join(_CURRENT_DIR, path)) as f:
48 | return [
49 | line.rstrip()
50 | for line in f
51 | if not (line.isspace() or line.startswith("#"))
52 | ]
53 |
54 | _VERSION = _get_version()
55 | _README = _get_readme()
56 | _INSTALL_REQUIREMENTS = _parse_requirements(os.path.join(
57 | _CURRENT_DIR, "requirements", "requirements.txt"))
58 | _TEST_REQUIREMENTS = _parse_requirements(os.path.join(
59 | _CURRENT_DIR, "requirements", "requirements_tests.txt"))
60 |
61 | setup(
62 | name="lightweight_mmm",
63 | version=_VERSION,
64 | description="Package for Media-Mix-Modelling",
65 |     long_description=_README,
66 | long_description_content_type="text/markdown",
67 | author="Google LLC",
68 | author_email="no-reply@google.com",
69 | license="Apache 2.0",
70 | packages=find_packages(),
71 | install_requires=_INSTALL_REQUIREMENTS,
72 | tests_require=_TEST_REQUIREMENTS,
73 | url="https://github.com/google/lightweight_mmm",
74 | classifiers=[
75 | "Development Status :: 3 - Alpha",
76 | "Intended Audience :: Developers",
77 | "Intended Audience :: Science/Research",
78 | "License :: OSI Approved :: Apache Software License",
79 | "Topic :: Scientific/Engineering :: Mathematics",
80 | "Programming Language :: Python :: 3.8",
81 | "Programming Language :: Python :: 3.9",
82 | "Programming Language :: Python :: 3.10",
83 |
84 | ],
85 | )
86 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2024 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | set -e
18 |
19 | readonly VENV_DIR=/tmp/test_env
20 | echo "Creating virtual environment under ${VENV_DIR}."
21 | echo "You might want to remove this when you no longer need it."
22 |
23 | # Install deps in a virtual env.
24 | python -m venv "${VENV_DIR}"
25 | source "${VENV_DIR}/bin/activate"
26 | python --version
27 |
28 | # Install JAX.
29 | python -m pip install --upgrade pip setuptools
30 | python -m pip install -r requirements/requirements.txt
31 | python -c 'import jax; print(jax.__version__)'
32 |
33 | # Run setup.py, this installs the python dependencies
34 | python -m pip install .
35 |
36 | # Python test dependencies.
37 | python -m pip install -r requirements/requirements_tests.txt
38 |
39 | # Run tests using pytest.
40 | python -m pytest lightweight_mmm
--------------------------------------------------------------------------------