├── .github
│   └── workflows
│       ├── ci.yaml
│       ├── manual-python-publish.yaml
│       └── python_publish.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── .gitignore
│   ├── Makefile
│   ├── api.rst
│   ├── conf.py
│   ├── custom_priors.ipynb
│   ├── faq.md
│   ├── index.rst
│   └── models.ipynb
├── examples
│   ├── end_to_end_demo_with_multiple_geos.ipynb
│   └── simple_end_to_end_demo.ipynb
├── images
│   ├── flowchart.png
│   ├── lightweight_mmm_logo_colored_1000.png
│   ├── lightweight_mmm_logo_colored_250.png
│   ├── lightweight_mmm_logo_colored_500.png
│   └── lightweight_mmm_logo_colored_750.png
├── lightweight_mmm
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── baseline
│   │   │   ├── __init__.py
│   │   │   ├── intercept.py
│   │   │   └── intercept_test.py
│   │   ├── core_utils.py
│   │   ├── core_utils_test.py
│   │   ├── priors.py
│   │   ├── time
│   │   │   ├── __init__.py
│   │   │   ├── seasonality.py
│   │   │   ├── seasonality_test.py
│   │   │   ├── trend.py
│   │   │   └── trend_test.py
│   │   └── transformations
│   │       ├── __init__.py
│   │       ├── identity.py
│   │       ├── lagging.py
│   │       ├── lagging_test.py
│   │       ├── saturation.py
│   │       └── saturation_test.py
│   ├── lightweight_mmm.py
│   ├── lightweight_mmm_test.py
│   ├── media_transforms.py
│   ├── media_transforms_test.py
│   ├── models.py
│   ├── models_test.py
│   ├── optimize_media.py
│   ├── optimize_media_test.py
│   ├── plot.py
│   ├── plot_test.py
│   ├── preprocessing.py
│   ├── preprocessing_test.py
│   ├── utils.py
│   └── utils_test.py
├── readthedocs.yaml
├── requirements
│   ├── requirements.txt
│   ├── requirements_docs.txt
│   └── requirements_tests.txt
├── setup.py
└── test.sh
/.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: ci 16 | 17 | on: 18 | push: 19 | branches: ["main"] 20 | pull_request: 21 | branches: ["main"] 22 | 23 | jobs: 24 | build-and-test: 25 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 26 | runs-on: "${{ matrix.os }}" 27 | 28 | strategy: 29 | matrix: 30 | python-version: ["3.8", "3.9", "3.10"] 31 | os: [ubuntu-latest] 32 | 33 | steps: 34 | - uses: "actions/checkout@v2" 35 | - uses: "actions/setup-python@v1" 36 | with: 37 | python-version: "${{ matrix.python-version }}" 38 | - name: Run CI tests 39 | run: bash test.sh 40 | shell: bash 41 | -------------------------------------------------------------------------------- /.github/workflows/manual-python-publish.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This workflow will upload a Python Package using Twine when a release is created 16 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 17 | 18 | # This workflow uses actions that are not certified by GitHub. 19 | # They are provided by a third-party and are governed by 20 | # separate terms of service, privacy policy, and support 21 | # documentation. 22 | 23 | name: Upload Python Package 24 | 25 | on: 26 | release: 27 | types: [published] 28 | workflow_dispatch: 29 | 30 | 31 | permissions: 32 | contents: read 33 | 34 | jobs: 35 | deploy: 36 | 37 | runs-on: ubuntu-latest 38 | 39 | steps: 40 | - uses: actions/checkout@v3 41 | - name: Set up Python 42 | uses: actions/setup-python@v3 43 | with: 44 | python-version: '3.x' 45 | - name: Install dependencies 46 | run: | 47 | python -m pip install --upgrade pip 48 | pip install setuptools wheel twine 49 | - name: Build and publish 50 | env: 51 | TWINE_USERNAME: __token__ 52 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 53 | run: | 54 | python setup.py sdist bdist_wheel 55 | twine upload dist/* -------------------------------------------------------------------------------- /.github/workflows/python_publish.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Upload Python Package 16 | 17 | on: 18 | release: 19 | types: [created] 20 | 21 | jobs: 22 | deploy: 23 | runs-on: ubuntu-latest 24 | 25 | steps: 26 | - uses: actions/checkout@v2 27 | - uses: actions/setup-python@v2 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install setuptools wheel twine 32 | - name: Build and publish 33 | env: 34 | TWINE_USERNAME: __token__ 35 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 36 | run: | 37 | python setup.py sdist bdist_wheel 38 | twine upload dist/* -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to <https://cla.developers.google.com/> 12 | to see your current agreements on file or 13 | to sign a new one. 
14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | _build -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # Minimal makefile for Sphinx documentation 17 | # 18 | # You can set these variables from the command line. 19 | SPHINXOPTS = 20 | SPHINXBUILD = sphinx-build 21 | SOURCEDIR = . 22 | BUILDDIR = _build 23 | # Put it first so that "make" without argument is like "make help". 24 | help: 25 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 26 | .PHONY: help Makefile 27 | # Catch-all target: route all unknown targets to Sphinx using the new 28 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 29 | %: Makefile 30 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | LightweightMMM 2 | =============== 3 | .. currentmodule:: lightweight_mmm 4 | 5 | 6 | LightweightMMM object 7 | ====================== 8 | .. currentmodule:: lightweight_mmm.lightweight_mmm 9 | .. autosummary:: 10 | LightweightMMM 11 | 12 | .. autoclass:: LightweightMMM 13 | :members: 14 | 15 | 16 | Preprocessing / Scaling 17 | ======================== 18 | .. currentmodule:: lightweight_mmm.preprocessing 19 | .. autosummary:: 20 | CustomScaler 21 | 22 | .. autoclass:: CustomScaler 23 | :members: 24 | 25 | 26 | Optimize Media 27 | =============== 28 | .. currentmodule:: lightweight_mmm.optimize_media 29 | .. autosummary:: 30 | find_optimal_budgets 31 | 32 | .. autofunction:: find_optimal_budgets 33 | 34 | Plot 35 | ===== 36 | .. currentmodule:: lightweight_mmm.plot 37 | .. autosummary:: 38 | plot_response_curves 39 | plot_cross_correlate 40 | plot_var_cost 41 | plot_model_fit 42 | plot_out_of_sample_model_fit 43 | plot_media_channel_posteriors 44 | plot_prior_and_posterior 45 | plot_bars_media_metrics 46 | plot_pre_post_budget_allocation_comparison 47 | plot_media_baseline_contribution_area_plot 48 | create_media_baseline_contribution_df 49 | 50 | .. autofunction:: plot_response_curves 51 | .. autofunction:: plot_cross_correlate 52 | .. autofunction:: plot_var_cost 53 | .. autofunction:: plot_model_fit 54 | .. autofunction:: plot_out_of_sample_model_fit 55 | .. autofunction:: plot_media_channel_posteriors 56 | .. 
autofunction:: plot_prior_and_posterior 57 | .. autofunction:: plot_bars_media_metrics 58 | .. autofunction:: plot_pre_post_budget_allocation_comparison 59 | .. autofunction:: plot_media_baseline_contribution_area_plot 60 | .. autofunction:: create_media_baseline_contribution_df 61 | 62 | 63 | Models 64 | ======= 65 | .. currentmodule:: lightweight_mmm.models 66 | .. autosummary:: 67 | transform_adstock 68 | transform_hill_adstock 69 | transform_carryover 70 | media_mix_model 71 | 72 | .. autofunction:: transform_adstock 73 | .. autofunction:: transform_hill_adstock 74 | .. autofunction:: transform_carryover 75 | .. autofunction:: media_mix_model 76 | 77 | Media Transforms 78 | ================= 79 | .. currentmodule:: lightweight_mmm.media_transforms 80 | .. autosummary:: 81 | calculate_seasonality 82 | adstock 83 | hill 84 | carryover 85 | apply_exponent_safe 86 | 87 | .. autofunction:: calculate_seasonality 88 | .. autofunction:: adstock 89 | .. autofunction:: hill 90 | .. autofunction:: carryover 91 | .. autofunction:: apply_exponent_safe 92 | 93 | Utils 94 | ====== 95 | .. currentmodule:: lightweight_mmm.utils 96 | .. autosummary:: 97 | save_model 98 | load_model 99 | simulate_dummy_data 100 | get_halfnormal_mean_from_scale 101 | get_halfnormal_scale_from_mean 102 | get_beta_params_from_mu_sigma 103 | distance_pior_posterior 104 | interpolate_outliers 105 | dataframe_to_jax 106 | 107 | .. autofunction:: save_model 108 | .. autofunction:: load_model 109 | .. autofunction:: simulate_dummy_data 110 | .. autofunction:: get_halfnormal_mean_from_scale 111 | .. autofunction:: get_halfnormal_scale_from_mean 112 | .. autofunction:: get_beta_params_from_mu_sigma 113 | .. autofunction:: distance_pior_posterior 114 | .. autofunction:: interpolate_outliers 115 | .. autofunction:: dataframe_to_jax 116 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Configuration file for the Sphinx documentation builder.""" 16 | 17 | 18 | # This file only contains a selection of the most common options. For a full 19 | # list see the documentation: 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 21 | 22 | # -- Path setup -------------------------------------------------------------- 23 | 24 | # If extensions (or modules to document with autodoc) are in another directory, 25 | # add these directories to sys.path here. If the directory is relative to the 26 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
27 | # 28 | # import os 29 | # import sys 30 | # sys.path.insert(0, os.path.abspath('.')) 31 | 32 | import os 33 | import sys 34 | sys.path.insert(0, os.path.abspath('..')) 35 | 36 | # -- Project information ----------------------------------------------------- 37 | 38 | project = 'LightweightMMM' 39 | copyright = '2022, The LightweightMMM authors' 40 | author = 'The LightweightMMM authors' 41 | 42 | 43 | # -- General configuration --------------------------------------------------- 44 | 45 | # Add any Sphinx extension module names here, as strings. They can be 46 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 47 | # ones. 48 | extensions = [ 49 | 'sphinx.ext.autodoc', 50 | 'sphinx.ext.autosummary', 51 | 'sphinx.ext.autosectionlabel', 52 | 'sphinx.ext.doctest', 53 | 'sphinx.ext.intersphinx', 54 | 'sphinx.ext.mathjax', 55 | 'sphinx.ext.napoleon', 56 | 'sphinx.ext.viewcode', 57 | 'nbsphinx', 58 | 'myst_nb', # This is used for the .ipynb notebooks 59 | ] 60 | 61 | # Add any paths that contain templates here, relative to this directory. 62 | templates_path = ['_templates'] 63 | 64 | # List of patterns, relative to source directory, that match files and 65 | # directories to ignore when looking for source files. 66 | # This pattern also affects html_static_path and html_extra_path. 67 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 68 | 69 | source_suffix = ['.rst', '.md', '.ipynb'] 70 | 71 | autosummary_generate = True 72 | 73 | master_doc = 'index' 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | 77 | # The theme to use for HTML and HTML Help pages. See the documentation for 78 | # a list of builtin themes. 79 | # 80 | html_theme = 'sphinx_rtd_theme' 81 | 82 | 83 | # Add any paths that contain custom static files (such as style sheets) here, 84 | # relative to this directory. They are copied after the builtin static files, 85 | # so a file named "default.css" will overwrite the builtin "default.css". 86 | html_static_path = ['_static'] 87 | 88 | nbsphinx_codecell_lexer = 'ipython3' 89 | nbsphinx_execute = 'never' 90 | jupyter_execute_notebooks = 'off' 91 | 92 | # -- Extension configuration ------------------------------------------------- 93 | 94 | # Tell sphinx-autodoc-typehints to generate stub parameter annotations including 95 | # types, even if the parameters aren't explicitly documented. 96 | always_document_param_types = True 97 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | ## Importing 2 | 3 | *What do I do if I get a numpy ImportError when importing lightweight_mmm?* 4 | If you get an error like: 5 | 6 | ```{python} 7 | ImportError: numpy.core.multiarray failed to import 8 | ``` 9 | 10 | Then this probably means you have the wrong version of numpy. 11 | The correct version should be installed by pip when you install lightweight_mmm, but often this error happens when you are using Google Colab or a Jupyter notebook and installing directly from a notebook cell. 12 | To resolve this, make sure you have restarted your Google Colab / Jupyter session after running `!pip install lightweight_mmm`. 13 | This makes sure you have the newly installed version of numpy imported. 14 | 15 | ## Input 16 | *Which media channel metrics can be used as input?* 17 | 18 | You can use impressions, clicks or cost, especially for non-digital data. For TV you could, for example, use TV rating points or cost. The model only takes the variation within a channel into account.
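As a rough illustration of the expected input layout (the shapes and values below are placeholder assumptions, not requirements of any particular dataset), national-level media data is a 2-dimensional array of shape (time, channels), while the geo-level model adds a third geo axis:

```python
import jax.numpy as jnp

# Hypothetical national-level input: 104 weeks x 3 channels
# (e.g. weekly impressions, clicks or cost per channel).
media_data = jnp.ones((104, 3))

# Hypothetical geo-level input: 104 weeks x 3 channels x 5 geos.
geo_media_data = jnp.ones((104, 3, 5))
```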
19 | 20 | *Can I run MMM at campaign level?* 21 | 22 | We generally don't recommend this. MMM is a macro tool that works well at the channel level. If you use distinct campaigns that have hard starts and stops, you risk losing the memory of the adstock. 23 | If you are interested in more granular insights, we recommend data-driven multi-touch attribution for your digital channels. You can find an example open-source package [here](https://github.com/google/fractribution/tree/master/py). 24 | 25 | *What is the ideal ratio between %train and %test data for LMMM?* 26 | 27 | Remember that we treat LMMM not as a forecasting model, so test data is not always needed. When opting for a train / test split you can use at least 13 weeks for a test. However, we recommend refraining from too long a testing period; if forecasting is the goal, consider running a separate model that is specifically built for forecasting. 28 | 29 | *What are best practices for lead generating businesses, with long sales cycles?* 30 | 31 | It really depends on your target variable, i.e. what outcome you would like to measure. If generating a lead takes multiple months, you can use more immediate KPIs, such as ‘number of conversions’, ‘number of site visits’ or ‘number of form entries’. 32 | 33 | 34 | 35 | ## Modelling 36 | 37 | 38 | *What can I do if the baseline is too low and total media contribution is too high?* 39 | 40 | You can try various things: 41 | 1) You can include non-media variables. 42 | 2) You can lower the prior for the beta (in front of the transformed media). 43 | 3) You can set a bigger prior for the intercept. 44 | 45 | 46 | *What are the different ways we can inform the media priors?* 47 | 48 | By default, the media priors are informed by the costs: channels with more spend get bigger priors. 49 | You can also base media priors on (geo) experiments or use a heuristic like "the percentage of weeks a channel was used". The intuition is that the more a channel is used, the higher the marketer believes its contribution to be. 50 | Outputs from multi-touch attribution (MTA) can also be used as priors for an MMM. 51 | Think with Google has recently published an [article](https://www.thinkwithgoogle.com/_qs/documents/13385/TwGxOP_Unified_Marketing_Measurement.pdf) on combining results from MTA and MMM. 52 | 53 | 54 | *How should I refresh my model and how often?* 55 | 56 | This depends on the data frequency (daily, weekly) but also on the time frame in which the marketer makes decisions. If decisions are quarterly, we'd recommend running the model each quarter. 57 | The data window can be expanded each time, so that older data still has an influence on the most recent estimate. Alternatively, old data can also be discarded, for instance if media effectiveness and strategies have changed more drastically over time. Note however that you can always use the posteriors from a previous modeling cycle as priors when you refresh the model, as sketched below.
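As a minimal sketch of the posterior-as-prior idea (assuming a previously fitted model `mmm_old`, a fresh instance `mmm_new`, and hypothetical input arrays `media_data`, `costs` and `target`), you can summarise an old posterior with a parametric distribution and pass it back in via `custom_priors`:

```python
import numpyro.distributions as dist

# Hypothetical: mmm_old is a LightweightMMM instance fitted in a previous cycle.
intercept_samples = mmm_old.trace["intercept"]

# Summarise the old posterior with a simple parametric fit; pick a
# distribution family that matches the parameter's support.
refreshed_prior = dist.Normal(loc=intercept_samples.mean(),
                              scale=intercept_samples.std())

# Use it as the prior when refitting on the refreshed data window.
mmm_new.fit(media=media_data, media_prior=costs, target=target,
            custom_priors={"intercept": refreshed_prior})
```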
58 | 59 | *Why is your model additive?* 60 | 61 | We might make the model multiplicative in future versions, but to keep things simple and lightweight we have opted for the additive model for now. 62 | 63 | *How does MCMC sampling work in LMMM?* 64 | 65 | LMMM uses the [NUTS algorithm](https://mc-stan.org/docs/2_18/stan-users-guide/sampling-difficulties-with-problematic-priors.html) for MCMC sampling of the posterior. NUTS only needs the priors and the likelihood for each parameter and uses all of the data; the budget allocation question is solved separately, on top of the fitted posteriors. 66 | 67 | ## Evaluation 68 | 69 | *How important is OOS predictive performance?* 70 | 71 | Remember that MMM is not a forecasting model but a contribution and optimisation tool. Test performance should be looked at, but contribution is more important. 72 | 73 | *Which metric is recommended for evaluating goodness of fit on test data?* 74 | 75 | We recommend looking at MAPE or median APE instead of the R-squared metric, as those are more interpretable from a business perspective and less influenced by outliers. 76 | 77 | *How is media effectiveness defined?* 78 | 79 | Media effectiveness shows you how much each media channel contributes, in percentage terms, to the target variable (e.g. y := sum of sales). -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/google/lightweight_mmm/tree/main/docs 2 | 3 | LightweightMMM Documentation 4 | ============================ 5 | 6 | LightweightMMM 🦇 is a lightweight Bayesian Marketing Mix Modeling (MMM) 7 | library that allows users to easily train MMMs and obtain channel attribution 8 | information. The library also includes capabilities for optimizing media 9 | allocation as well as plotting common graphs in the field. 10 | 11 | It is built in Python 3 and makes use of NumPyro and JAX. 12 | 13 | Installation 14 | ------------ 15 | 16 | We have kept JAX as part of the dependencies to install when installing 17 | LightweightMMM, however if you wish to install a different version of JAX or 18 | jaxlib for specific CUDA/CuDNN versions see 19 | https://github.com/jax-ml/jax?tab=readme-ov-file#instructions for instructions on installing 20 | JAX. Otherwise our installation assumes a CPU setup. 21 | 22 | The recommended way of installing lightweight_mmm is through PyPI: 23 | 24 | ``pip install --upgrade pip`` 25 | ``pip install lightweight_mmm`` 26 | 27 | If you want to use the most recent and slightly less stable version you can install it from GitHub: 28 | 29 | ``pip install --upgrade git+https://github.com/google/lightweight_mmm.git`` 30 | 31 | 32 | .. toctree:: 33 | :caption: Model Documentation 34 | :maxdepth: 1 35 | 36 | models 37 | 38 | .. toctree:: 39 | :caption: Custom priors 40 | :maxdepth: 1 41 | 42 | custom_priors 43 | 44 | .. toctree:: 45 | :caption: API Documentation 46 | :maxdepth: 2 47 | 48 | api 49 | 50 | .. toctree:: 51 | :caption: FAQ 52 | :maxdepth: 2 53 | 54 | faq 55 | 56 | Contribute 57 | ---------- 58 | 59 | - Issue tracker: https://github.com/google/lightweight_mmm/issues 60 | - Source code: https://github.com/google/lightweight_mmm/tree/main 61 | 62 | Support 63 | ------- 64 | 65 | If you are having issues, please let us know by filing an issue on our 66 | `issue tracker <https://github.com/google/lightweight_mmm/issues>`_. 67 | 68 | License 69 | ------- 70 | 71 | LightweightMMM is licensed under the Apache 2.0 License.
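Quickstart example
------------------

A minimal, illustrative sketch on simulated data (the sizes and sampler
settings below are arbitrary placeholders, not recommendations)::

    from lightweight_mmm import lightweight_mmm
    from lightweight_mmm import utils

    # Simulate dummy national-level data: media, extra features, target, costs.
    media_data, extra_features, target, costs = utils.simulate_dummy_data(
        data_size=104, n_media_channels=3, n_extra_features=1)

    # Fit a media mix model with the hill-adstock media transform.
    mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
    mmm.fit(media=media_data, media_prior=costs, target=target,
            number_warmup=1000, number_samples=1000, number_chains=1)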
72 | 73 | Indices and tables 74 | ================== 75 | 76 | * :ref:`genindex` -------------------------------------------------------------------------------- /docs/models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "YTW94aX0fWCQ" 7 | }, 8 | "source": [ 9 | "# LightweightMMM Models.\n", 10 | "\n", 11 | "The LightweightMMM can either be run using data aggregated at the national level (standard approach) or using data aggregated at a geo level (sub-national hierarchical approach). These models are documented below.\n", 12 | "\n", 13 | "## National level (standard approach)\n", 14 | "\n", 15 | "All the parameters in our Bayesian model have priors which have been set based on simulated studies that produce stable results. We also set out our three different approaches to saturation and media lagging effects: carryover (with exponent), adstock (with exponent) and hill adstock. Please see Jin, Y. et al., (2017) for more details on these models and choice of priors.\n", 16 | "\n", 17 | " $$kpi_{t} = \\alpha + trend_{t} + seasonality_{t} + media\\ channels_{t} + other\\ factors_{t} $$\n", 18 | "\n", 19 | "*Intercept*\u003cbr\u003e\n", 20 | "- $\\alpha \\sim HalfNormal(2)$\n", 21 | "\n", 22 | "*Trend*\u003cbr\u003e\n", 23 | "- $trend_{t} = \\mu t^{\\kappa}$\u003cbr\u003e\n", 24 | "- $\\mu \\sim Normal(0,1)$\u003cbr\u003e\n", 25 | "- $\\kappa \\sim Uniform(0.5,1.5)$\u003cbr\u003e\n", 26 | "- Where $t$ is a linear trend input\n", 27 | "\n", 28 | "*Seasonality (for models using* **weekly observations**)\u003cbr\u003e\n", 29 | "- $seasonality_{t} = \\displaystyle\\sum_{d=1}^{2} (\\gamma_{1,d} cos(\\frac{2 \\pi dt}{52}) + \\gamma_{2,d} sin(\\frac{2 \\pi dt}{52}))$\u003cbr\u003e\n", 30 | "- $\\gamma_{1,d}, \\gamma_{2,d} \\sim Normal(0,1)$\u003cbr\u003e\n", 31 | "\n", 32 | "*Seasonality (for models using* **daily observations**)\u003cbr\u003e\n", 33 | "- $seasonality_{t} = \\displaystyle\\sum_{d=1}^{2} (\\gamma_{1,d} cos(\\frac{2 \\pi dt}{365}) + \\gamma_{2,d} sin(\\frac{2 \\pi dt}{365})) + \\delta_{t \\bmod 7}$\u003cbr\u003e\n", 34 | "- $\\gamma_{1,d}, \\gamma_{2,d} \\sim Normal(0,1)$\u003cbr\u003e\n", 35 | "- $\\delta_{i} \\sim Normal(0,0.5)$\u003cbr\u003e\n", 36 | "\n", 37 | "*Other factors*\u003cbr\u003e\n", 38 | "- $other\\ factors_{t} = \\displaystyle\\sum_{i=1}^{N} \\lambda_{i}Z_{it}$\u003cbr\u003e\n", 39 | "- $\\lambda_{i} \\sim Normal(0,1)$\u003cbr\u003e\n", 40 | "- Where $Z_{i}$ are other factors and $N$ is the number of other factors.\n", 41 | "\n", 42 | "*Media Effect*\u003cbr\u003e\n", 43 | "- $\\beta_{m} \\sim HalfNormal(v_{m})$\u003cbr\u003e\n", 44 | "- Where $v_{m}$ is a scalar equal to the sum of the total cost of media channel $m$.\n", 45 | "\n", 46 | "*Media Channels (for the* **carryover** *model)*\u003cbr\u003e\n", 47 | "- $media\\ channels_{t} = x_{t,m}^{*\\rho_{m}}$\u003cbr\u003e\n", 48 | "- $x_{t,m}^{*} = \\frac{\\displaystyle\\sum_{l=0}^{L} \\tau_{m}^{(l-\\theta_{m})^2}x_{t-l,m}}{\\displaystyle\\sum_{l=0}^{L}\\tau_{m}^{(l-\\theta_{m})^2}}$ where $L=12$\u003cbr\u003e\n", 49 | "- $\\tau_{m} \\sim Beta(1,1)$\u003cbr\u003e\n", 50 | "- $\\theta_{m} \\sim HalfNormal(2)$\u003cbr\u003e\n", 51 | "- $\\rho_{m} \\sim Beta(9,1)$\n", 52 | "- Where $x_{t,m}$ is the media spend or impressions in week $t$ from media channel $m$\u003cbr\u003e\n", 53 | "\n", 54 | "*Media Channels (for the* **adstock** *model)*\u003cbr\u003e\n", 55 | "- $media\\ channels_{t} = 
x_{t,m,s}^{*\\rho_{m}}$\u003cbr\u003e\n", 56 | "- $x_{t,m,s}^{*} = \\frac{x_{t,m}^{*}}{1/(1-\\lambda_{m})}$\u003cbr\u003e\n", 57 | "- $x_{t,m}^{*} = x_{t,m} + \\lambda_{m} x_{t-1,m}^{*}$ where $t=2,..,N$\u003cbr\u003e\n", 58 | "- $x_{1,m}^{*} = x_{1,m}$\u003cbr\u003e\n", 59 | "- $\\lambda_{m} \\sim Beta(2,1)$\u003cbr\u003e\n", 60 | "- $\\rho_{m} \\sim Beta(9,1)$\n", 61 | "- Where $x_{t,m}$ is the media spend or impressions in week $t$ from media channel $m$\u003cbr\u003e\n", 62 | "\n", 63 | "*Media Channels (for the* **hill_adstock** *model)*\u003cbr\u003e\n", 64 | "- $media\\ channels_{t} = \\frac{1}{1+(x_{t,m}^{*} / K_{m})^{-S_{m}}}$\u003cbr\u003e\n", 65 | "- $x_{t,m}^{*} = x_{t,m} + \\lambda_{m} x_{t-1,m}^{*}$ where $t=2,..,N$\u003cbr\u003e\n", 66 | "- $x_{1,m}^{*} = x_{1,m}$\u003cbr\u003e\n", 67 | "- $K_{m} \\sim Gamma(1,1)$\u003cbr\u003e\n", 68 | "- $S_{m} \\sim Gamma(1,1)$\u003cbr\u003e\n", 69 | "- $\\lambda_{m} \\sim Beta(2,1)$\n", 70 | "- Where $x_{t,m}$ is the media spend or impressions in week $t$ from media channel $m$\u003cbr\u003e\n", 71 | "\n", 72 | "## Geo level (sub-national hierarchical approach)\n", 73 | "\n", 74 | "The hierarchical model is analogous to the standard model except there is an additional dimension of region. In the geo level model seasonality is learned at the sub-national level and at the national level. For more details on this model, please see Sun, Y et al., (2017)." 75 | ] 76 | } 77 | ], 78 | "metadata": { 79 | "colab": { 80 | "collapsed_sections": [], 81 | "name": "models.ipynb", 82 | "private_outputs": true, 83 | "provenance": [] 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 0 88 | } 89 | -------------------------------------------------------------------------------- /images/flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/flowchart.png -------------------------------------------------------------------------------- /images/lightweight_mmm_logo_colored_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_1000.png -------------------------------------------------------------------------------- /images/lightweight_mmm_logo_colored_250.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_250.png -------------------------------------------------------------------------------- /images/lightweight_mmm_logo_colored_500.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_500.png -------------------------------------------------------------------------------- /images/lightweight_mmm_logo_colored_750.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/lightweight_mmm/9b4256513f144172dbb698100fbe68600e62d2ad/images/lightweight_mmm_logo_colored_750.png -------------------------------------------------------------------------------- /lightweight_mmm/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """LightweightMMM library. 16 | 17 | Detailed documentation and examples can be found in the 18 | [GitHub repository](https://github.com/google/lightweight_mmm). 19 | """ 20 | __version__ = "0.1.9" 21 | -------------------------------------------------------------------------------- /lightweight_mmm/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Parses absl flags when tests are run by pytest under GitHub Actions.""" 16 | 17 | from absl import flags 18 | 19 | 20 | def pytest_configure(config): 21 | flags.FLAGS.mark_as_parsed() 22 | -------------------------------------------------------------------------------- /lightweight_mmm/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /lightweight_mmm/core/baseline/intercept.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for modeling the intercept.""" 16 | 17 | from typing import Mapping 18 | 19 | import immutabledict 20 | import jax.numpy as jnp 21 | import numpyro 22 | from numpyro import distributions as dist 23 | 24 | from lightweight_mmm.core import core_utils 25 | from lightweight_mmm.core import priors 26 | 27 | 28 | def simple_intercept( 29 | data: jnp.ndarray, 30 | custom_priors: Mapping[str, 31 | dist.Distribution] = immutabledict.immutabledict(), 32 | ) -> jnp.ndarray: 33 | """Calculates a national or geo intercept. 34 | Note that this intercept is constant over time. 35 | 36 | Args: 37 | data: Media input data. Media data must have either 2 dims for national 38 | model or 3 for geo models. 39 | custom_priors: The custom priors we want the model to take instead of the 40 | default ones. Refer to the full documentation on custom priors for 41 | details. 42 | 43 | Returns: 44 | The values of the intercept. 45 | """ 46 | default_priors = priors.get_default_priors() 47 | n_geos = core_utils.get_number_geos(data=data) 48 | 49 | with numpyro.plate(name=f"{priors.INTERCEPT}_plate", size=n_geos): 50 | intercept = numpyro.sample( 51 | name=priors.INTERCEPT, 52 | fn=custom_priors.get(priors.INTERCEPT, 53 | default_priors[priors.INTERCEPT]), 54 | ) 55 | return jnp.asarray(intercept) 56 | -------------------------------------------------------------------------------- /lightweight_mmm/core/baseline/intercept_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for intercept.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpyro 22 | from numpyro import handlers 23 | import numpyro.distributions as dist 24 | 25 | from lightweight_mmm.core import core_utils 26 | from lightweight_mmm.core import priors 27 | from lightweight_mmm.core.baseline import intercept 28 | 29 | 30 | class InterceptTest(parameterized.TestCase): 31 | 32 | @parameterized.named_parameters( 33 | dict( 34 | testcase_name="national", 35 | data_shape=(150, 3), 36 | ), 37 | dict( 38 | testcase_name="geo", 39 | data_shape=(150, 3, 5), 40 | ), 41 | ) 42 | def test_simple_intercept_produces_output_correct_shape(self, data_shape): 43 | 44 | def mock_model_function(data): 45 | numpyro.deterministic( 46 | "intercept_values", 47 | intercept.simple_intercept(data=data, custom_priors={})) 48 | 49 | num_samples = 10 50 | data = jnp.ones(data_shape) 51 | n_geos = core_utils.get_number_geos(data=data) 52 | kernel = numpyro.infer.NUTS(model=mock_model_function) 53 | mcmc = numpyro.infer.MCMC( 54 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 55 | rng_key = jax.random.PRNGKey(0) 56 | 57 | mcmc.run(rng_key, data=data) 58 | intercept_values = mcmc.get_samples()["intercept_values"] 59 | 60 | self.assertEqual(intercept_values.shape, (num_samples, n_geos)) 61 | 62 | @parameterized.named_parameters( 63 | dict( 64 | testcase_name="national", 65 | data_shape=(150, 3), 66 | ), 67 | dict( 68 | testcase_name="geo", 69 | data_shape=(150, 3, 5), 70 | ), 71 | ) 72 | def test_simple_intercept_takes_custom_priors_correctly(self, data_shape): 73 | prior_name = priors.INTERCEPT 74 | expected_value1, expected_value2 = 5.2, 7.56 75 | custom_priors = { 76 | prior_name: 77 | dist.Kumaraswamy( 78 | concentration1=expected_value1, concentration0=expected_value2) 79 | } 80 | media = jnp.ones(data_shape) 81 | 82 | trace_handler = handlers.trace( 83 | handlers.seed(intercept.simple_intercept, rng_seed=0)) 84 | trace = trace_handler.get_trace(data=media, custom_priors=custom_priors) 85 | values_and_dists = { 86 | name: site["fn"] for name, site in trace.items() if "fn" in site 87 | } 88 | 89 | used_distribution = values_and_dists[prior_name].base_dist 90 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 91 | self.assertEqual(used_distribution.concentration0, expected_value2) 92 | self.assertEqual(used_distribution.concentration1, expected_value1) 93 | 94 | 95 | if __name__ == "__main__": 96 | absltest.main() 97 | -------------------------------------------------------------------------------- /lightweight_mmm/core/core_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Sets of utilities used across the core components of LightweightMMM.""" 16 | 17 | import sys 18 | from typing import Any, Mapping, Tuple, Union 19 | 20 | import jax.numpy as jnp 21 | 22 | from numpyro import distributions as dist 23 | 24 | # pylint: disable=g-import-not-at-top 25 | if sys.version_info >= (3, 8): 26 | from typing import Protocol 27 | else: 28 | from typing_extensions import Protocol 29 | 30 | 31 | class TransformFunction(Protocol): 32 | 33 | def __call__( 34 | self, 35 | data: jnp.ndarray, 36 | custom_priors: Mapping[str, dist.Distribution], 37 | prefix: str, 38 | **kwargs: Any, 39 | ) -> jnp.ndarray: 40 | ... 41 | 42 | 43 | class Module(Protocol): 44 | 45 | def __call__( 46 | self, 47 | *args: Any, 48 | **kwargs: Any, 49 | ) -> jnp.ndarray: 50 | ... 51 | 52 | 53 | def get_number_geos(data: jnp.ndarray) -> int: 54 | return data.shape[2] if data.ndim == 3 else 1 55 | 56 | 57 | def get_geo_shape(data: jnp.ndarray) -> Union[Tuple[int], Tuple[()]]: 58 | return (data.shape[2],) if data.ndim == 3 else () 59 | 60 | 61 | def apply_exponent_safe(data: jnp.ndarray, 62 | exponent: jnp.ndarray) -> jnp.ndarray: 63 | """Applies an exponent to given data in a gradient safe way. 64 | 65 | More info on the double jnp.where can be found: 66 | https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf 67 | 68 | Args: 69 | data: Input data to use. 70 | exponent: Exponent required for the operations. 71 | 72 | Returns: 73 | The result of the exponent operation with the inputs provided. 74 | """ 75 | exponent_safe = jnp.where(data == 0, 1, data) ** exponent 76 | return jnp.where(data == 0, 0, exponent_safe) 77 | -------------------------------------------------------------------------------- /lightweight_mmm/core/core_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for core_utils.""" 16 | 17 | from lightweight_mmm.core import core_utils 18 | from absl.testing import absltest 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | 24 | class CoreUtilsTest(absltest.TestCase): 25 | 26 | def test_apply_exponent_safe_produces_same_exponent_results(self): 27 | data = jnp.arange(50).reshape((10, 5)) 28 | exponent = jnp.full(5, 0.5) 29 | 30 | output = core_utils.apply_exponent_safe(data=data, exponent=exponent) 31 | 32 | np.testing.assert_array_equal(output, data**exponent) 33 | 34 | def test_apply_exponent_safe_produces_correct_shape(self): 35 | data = jnp.ones((10, 5)) 36 | exponent = jnp.full(5, 0.5) 37 | 38 | output = core_utils.apply_exponent_safe(data=data, exponent=exponent) 39 | 40 | self.assertEqual(output.shape, data.shape) 41 | 42 | def test_apply_exponent_safe_produces_non_nan_or_inf_grads(self): 43 | 44 | def f_safe(data, exponent): 45 | x = core_utils.apply_exponent_safe(data=data, exponent=exponent) 46 | return x.sum() 47 | 48 | data = jnp.ones((10, 5)) 49 | data = data.at[0, 0].set(0.) 50 | exponent = jnp.full(5, 0.5) 51 | 52 | grads = jax.grad(f_safe)(data, exponent) 53 | 54 | self.assertFalse(np.isnan(grads).any()) 55 | self.assertFalse(np.isinf(grads).any()) 56 | 57 | 58 | if __name__ == '__main__': 59 | absltest.main() 60 | -------------------------------------------------------------------------------- /lightweight_mmm/core/priors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Sets priors and prior related constants for LMMM.""" 16 | from typing import Mapping 17 | 18 | import immutabledict 19 | from numpyro import distributions as dist 20 | 21 | # Core model priors 22 | INTERCEPT = "intercept" 23 | COEF_TREND = "coef_trend" 24 | EXPO_TREND = "expo_trend" 25 | SIGMA = "sigma" 26 | GAMMA_SEASONALITY = "gamma_seasonality" 27 | WEEKDAY = "weekday" 28 | COEF_EXTRA_FEATURES = "coef_extra_features" 29 | COEF_SEASONALITY = "coef_seasonality" 30 | 31 | # Lagging priors 32 | LAG_WEIGHT = "lag_weight" 33 | AD_EFFECT_RETENTION_RATE = "ad_effect_retention_rate" 34 | PEAK_EFFECT_DELAY = "peak_effect_delay" 35 | 36 | # Saturation priors 37 | EXPONENT = "exponent" 38 | HALF_MAX_EFFECTIVE_CONCENTRATION = "half_max_effective_concentration" 39 | SLOPE = "slope" 40 | 41 | # Dynamic trend priors 42 | DYNAMIC_TREND_INITIAL_LEVEL = "dynamic_trend_initial_level" 43 | DYNAMIC_TREND_INITIAL_SLOPE = "dynamic_trend_initial_slope" 44 | DYNAMIC_TREND_LEVEL_VARIANCE = "dynamic_trend_level_variance" 45 | DYNAMIC_TREND_SLOPE_VARIANCE = "dynamic_trend_slope_variance" 46 | 47 | MODEL_PRIORS_NAMES = frozenset(( 48 | INTERCEPT, 49 | COEF_TREND, 50 | EXPO_TREND, 51 | SIGMA, 52 | GAMMA_SEASONALITY, 53 | WEEKDAY, 54 | COEF_EXTRA_FEATURES, 55 | COEF_SEASONALITY, 56 | LAG_WEIGHT, 57 | AD_EFFECT_RETENTION_RATE, 58 | PEAK_EFFECT_DELAY, 59 | EXPONENT, 60 | HALF_MAX_EFFECTIVE_CONCENTRATION, 61 | SLOPE, 62 | DYNAMIC_TREND_INITIAL_LEVEL, 63 | DYNAMIC_TREND_INITIAL_SLOPE, 64 | DYNAMIC_TREND_LEVEL_VARIANCE, 65 | DYNAMIC_TREND_SLOPE_VARIANCE, 66 | )) 67 | 68 | GEO_ONLY_PRIORS = frozenset((COEF_SEASONALITY,)) 69 | 70 | 71 | def get_default_priors() -> Mapping[str, dist.Distribution]: 72 | # Since JAX cannot be called before absl.app.run in tests we get default 73 | # priors from a function. 74 | return immutabledict.immutabledict({ 75 | INTERCEPT: dist.HalfNormal(scale=2.), 76 | COEF_TREND: dist.Normal(loc=0., scale=1.), 77 | EXPO_TREND: dist.Uniform(low=0.5, high=1.5), 78 | SIGMA: dist.Gamma(concentration=1., rate=1.), 79 | GAMMA_SEASONALITY: dist.Normal(loc=0., scale=1.), 80 | WEEKDAY: dist.Normal(loc=0., scale=.5), 81 | COEF_EXTRA_FEATURES: dist.Normal(loc=0., scale=1.), 82 | COEF_SEASONALITY: dist.HalfNormal(scale=.5), 83 | AD_EFFECT_RETENTION_RATE: dist.Beta(concentration1=1., concentration0=1.), 84 | PEAK_EFFECT_DELAY: dist.HalfNormal(scale=2.), 85 | EXPONENT: dist.Beta(concentration1=9., concentration0=1.), 86 | LAG_WEIGHT: dist.Beta(concentration1=2., concentration0=1.), 87 | HALF_MAX_EFFECTIVE_CONCENTRATION: dist.Gamma(concentration=1., rate=1.), 88 | SLOPE: dist.Gamma(concentration=1., rate=1.), 89 | DYNAMIC_TREND_INITIAL_LEVEL: dist.Normal(loc=.5, scale=2.5), 90 | DYNAMIC_TREND_INITIAL_SLOPE: dist.Normal(loc=0., scale=.2), 91 | DYNAMIC_TREND_LEVEL_VARIANCE: dist.Uniform(low=0., high=.1), 92 | DYNAMIC_TREND_SLOPE_VARIANCE: dist.Uniform(low=0., high=.01), 93 | }) 94 | -------------------------------------------------------------------------------- /lightweight_mmm/core/time/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | -------------------------------------------------------------------------------- /lightweight_mmm/core/time/seasonality.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Core and modelling functions for seasonality.""" 16 | 17 | from typing import Mapping 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | import numpyro 22 | from numpyro import distributions as dist 23 | 24 | from lightweight_mmm.core import priors 25 | from lightweight_mmm.core import core_utils 26 | 27 | 28 | @jax.jit 29 | def _sinusoidal_seasonality( 30 | seasonality_arange: jnp.ndarray, 31 | degrees_arange: jnp.ndarray, 32 | gamma_seasonality: jnp.ndarray, 33 | frequency: int, 34 | ) -> jnp.ndarray: 35 | """Core calculation of cyclic variation seasonality. 36 | 37 | Args: 38 | seasonality_arange: Array with range [0, N - 1] where N is the size of the 39 | data for which the seasonality is modelled. 40 | degrees_arange: Array with range [0, D - 1] where D is the number of degrees 41 | to use. Must be greater than or equal to 1. 42 | gamma_seasonality: Factor to multiply each degree calculation by. Shape must 43 | be aligned with the number of degrees. 44 | frequency: Frequency at which the seasonality is computed. 45 | 46 | Returns: 47 | An array with the seasonality values. 48 | """ 49 | inner_value = seasonality_arange * 2 * jnp.pi * degrees_arange / frequency 50 | season_matrix_sin = jnp.sin(inner_value) 51 | season_matrix_cos = jnp.cos(inner_value) 52 | season_matrix = jnp.concatenate([ 53 | jnp.expand_dims(a=season_matrix_sin, axis=-1), 54 | jnp.expand_dims(a=season_matrix_cos, axis=-1) 55 | ], 56 | axis=-1) 57 | return jnp.einsum("tds, ds -> t", season_matrix, gamma_seasonality) 58 | 59 | 60 | def sinusoidal_seasonality( 61 | data: jnp.ndarray, 62 | custom_priors: Mapping[str, dist.Distribution], 63 | *, 64 | degrees_seasonality: int = 2, 65 | frequency: int = 52, 66 | ) -> jnp.ndarray: 67 | """Calculates cyclic variation seasonality. 68 | 69 | For detailed info check: 70 | https://en.wikipedia.org/wiki/Seasonality#Modeling 71 | 72 | Args: 73 | data: Data for which the seasonality will be modelled. It is used to 74 | obtain the length of the time dimension, axis 0. 75 | custom_priors: The custom priors we want the model to take instead of 76 | default ones. 77 | degrees_seasonality: Number of degrees to use. Must be greater than or equal 78 | to 1.
79 | frequency: Frequency at which the seasonality is computed. By default it is 80 | 52 for weekly data (52 weeks in a year). 81 | 82 | Returns: 83 | An array with the seasonality values. 84 | """ 85 | number_periods = data.shape[0] 86 | default_priors = priors.get_default_priors() 87 | n_geos = core_utils.get_number_geos(data=data) 88 | with numpyro.plate(name=f"{priors.GAMMA_SEASONALITY}_sin_cos_plate", size=2): 89 | with numpyro.plate( 90 | name=f"{priors.GAMMA_SEASONALITY}_plate", size=degrees_seasonality): 91 | gamma_seasonality = numpyro.sample( 92 | name=priors.GAMMA_SEASONALITY, 93 | fn=custom_priors.get(priors.GAMMA_SEASONALITY, 94 | default_priors[priors.GAMMA_SEASONALITY])) 95 | seasonality_arange = jnp.expand_dims(a=jnp.arange(number_periods), axis=-1) 96 | degrees_arange = jnp.arange(degrees_seasonality) 97 | seasonality_values = _sinusoidal_seasonality( 98 | seasonality_arange=seasonality_arange, 99 | degrees_arange=degrees_arange, 100 | frequency=frequency, 101 | gamma_seasonality=gamma_seasonality, 102 | ) 103 | if n_geos > 1: 104 | seasonality_values = jnp.expand_dims(seasonality_values, axis=-1) 105 | return seasonality_values 106 | 107 | 108 | def _intra_week_seasonality( 109 | data: jnp.ndarray, 110 | weekday: jnp.ndarray, 111 | ) -> jnp.ndarray: 112 | data_size = data.shape[0] 113 | return weekday[jnp.arange(data_size) % 7] 114 | 115 | 116 | def intra_week_seasonality( 117 | data: jnp.ndarray, 118 | custom_priors: Mapping[str, dist.Distribution], 119 | ) -> jnp.ndarray: 120 | """Models intra-week seasonality. 121 | 122 | Args: 123 | data: Data for which the seasonality will be modelled. It is used to 124 | obtain the length of the time dimension, axis 0. 125 | custom_priors: The custom priors we want the model to take instead of 126 | default ones. 127 | 128 | Returns: 129 | The contribution of the weekday seasonality. 130 | """ 131 | default_priors = priors.get_default_priors() 132 | with numpyro.plate(name=f"{priors.WEEKDAY}_plate", size=7): 133 | weekday = numpyro.sample( 134 | name=priors.WEEKDAY, 135 | fn=custom_priors.get(priors.WEEKDAY, default_priors[priors.WEEKDAY])) 136 | 137 | weekday_series = _intra_week_seasonality(data=data, weekday=weekday) 138 | 139 | if data.ndim == 3: # For geo model's case 140 | weekday_series = jnp.expand_dims(weekday_series, axis=-1) 141 | 142 | return weekday_series 143 | -------------------------------------------------------------------------------- /lightweight_mmm/core/time/seasonality_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
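
A quick usage sketch before the tests (illustrative only; the shapes are made up): sinusoidal_seasonality samples gamma_seasonality inside numpyro plates, so outside of inference it must run under a seed handler:

import jax.numpy as jnp
from numpyro import handlers

from lightweight_mmm.core.time import seasonality

data = jnp.ones((104, 3))  # two years of weekly national data, 3 channels
with handlers.seed(rng_seed=0):
  values = seasonality.sinusoidal_seasonality(
      data=data, custom_priors={}, degrees_seasonality=2, frequency=52)
print(values.shape)  # (104,): one seasonality value per time period
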
14 | 15 | """Tests for seasonality.""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import numpyro 20 | from numpyro import distributions as dist 21 | from numpyro import handlers 22 | 23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | from lightweight_mmm.core import priors 26 | from lightweight_mmm.core.time import seasonality 27 | 28 | 29 | class SeasonalityTest(parameterized.TestCase): 30 | 31 | @parameterized.named_parameters([ 32 | dict( 33 | testcase_name="2_degrees", 34 | seasonality_arange_value=150, 35 | degrees_arange_shape=5, 36 | gamma_seasonality_shape=(5, 2), 37 | ), 38 | dict( 39 | testcase_name="10_degree", 40 | seasonality_arange_value=150, 41 | degrees_arange_shape=10, 42 | gamma_seasonality_shape=(10, 2), 43 | ), 44 | dict( 45 | testcase_name="1_degree", 46 | seasonality_arange_value=200, 47 | degrees_arange_shape=1, 48 | gamma_seasonality_shape=(1, 2), 49 | ), 50 | ]) 51 | def test_core_sinusoidal_seasonality_produces_correct_shape( 52 | self, seasonality_arange_value, degrees_arange_shape, 53 | gamma_seasonality_shape): 54 | seasonality_arange = jnp.expand_dims( 55 | jnp.arange(seasonality_arange_value), axis=-1) 56 | degrees_arange = jnp.arange(degrees_arange_shape) 57 | gamma_seasonality = jnp.ones(gamma_seasonality_shape) 58 | 59 | seasonality_values = seasonality._sinusoidal_seasonality( 60 | seasonality_arange=seasonality_arange, 61 | degrees_arange=degrees_arange, 62 | gamma_seasonality=gamma_seasonality, 63 | frequency=52, 64 | ) 65 | self.assertEqual(seasonality_values.shape, (seasonality_arange_value,)) 66 | 67 | @parameterized.named_parameters( 68 | dict( 69 | testcase_name="ten_degrees_national", 70 | data_shape=(500, 5), 71 | degrees_seasonality=10, 72 | expected_shape=(10, 500), 73 | ), 74 | dict( 75 | testcase_name="ten_degrees_geo", 76 | data_shape=(500, 5, 5), 77 | degrees_seasonality=10, 78 | expected_shape=(10, 500, 1), 79 | ), 80 | dict( 81 | testcase_name="one_degrees_national", 82 | data_shape=(500, 5), 83 | degrees_seasonality=1, 84 | expected_shape=(10, 500), 85 | ), 86 | dict( 87 | testcase_name="one_degrees_geo", 88 | data_shape=(500, 5, 5), 89 | degrees_seasonality=1, 90 | expected_shape=(10, 500, 1), 91 | ), 92 | ) 93 | def test_model_sinusoidal_seasonality_produces_correct_shape( 94 | self, data_shape, degrees_seasonality, expected_shape): 95 | 96 | def mock_model_function(data, degrees_seasonality, frequency): 97 | numpyro.deterministic( 98 | "seasonality", 99 | seasonality.sinusoidal_seasonality( 100 | data=data, 101 | degrees_seasonality=degrees_seasonality, 102 | custom_priors={}, 103 | frequency=frequency)) 104 | 105 | num_samples = 10 106 | data = jnp.ones(data_shape) 107 | kernel = numpyro.infer.NUTS(model=mock_model_function) 108 | mcmc = numpyro.infer.MCMC( 109 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 110 | rng_key = jax.random.PRNGKey(0) 111 | 112 | mcmc.run( 113 | rng_key, 114 | data=data, 115 | degrees_seasonality=degrees_seasonality, 116 | frequency=52, 117 | ) 118 | seasonality_values = mcmc.get_samples()["seasonality"] 119 | 120 | self.assertEqual(seasonality_values.shape, expected_shape) 121 | 122 | def test_sinusoidal_seasonality_custom_priors_are_taken_correctly(self): 123 | prior_name = priors.GAMMA_SEASONALITY 124 | expected_value1, expected_value2 = 5.2, 7.56 125 | custom_priors = { 126 | prior_name: 127 | dist.Kumaraswamy( 128 | concentration1=expected_value1, concentration0=expected_value2) 129 | } 130 | media = jnp.ones((10, 5, 5)) 131 | 
degrees_seasonality = 3 132 | frequency = 365 133 | 134 | trace_handler = handlers.trace( 135 | handlers.seed(seasonality.sinusoidal_seasonality, rng_seed=0)) 136 | trace = trace_handler.get_trace( 137 | data=media, 138 | custom_priors=custom_priors, 139 | degrees_seasonality=degrees_seasonality, 140 | frequency=frequency, 141 | ) 142 | values_and_dists = { 143 | name: site["fn"] for name, site in trace.items() if "fn" in site 144 | } 145 | 146 | used_distribution = values_and_dists[prior_name] 147 | if isinstance(used_distribution, dist.ExpandedDistribution): 148 | used_distribution = used_distribution.base_dist 149 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 150 | self.assertEqual(used_distribution.concentration0, expected_value2) 151 | self.assertEqual(used_distribution.concentration1, expected_value1) 152 | 153 | @parameterized.named_parameters( 154 | dict( 155 | testcase_name="ten_degrees", 156 | data_shape=(500, 3), 157 | expected_shape=(10, 500), 158 | ), 159 | dict( 160 | testcase_name="five_degrees", 161 | data_shape=(500, 3, 5), 162 | expected_shape=(10, 500, 1), 163 | ), 164 | ) 165 | def test_intra_week_seasonality_produces_correct_shape( 166 | self, data_shape, expected_shape): 167 | 168 | def mock_model_function(data): 169 | numpyro.deterministic( 170 | "intra_week", 171 | seasonality.intra_week_seasonality( 172 | data=data, 173 | custom_priors={}, 174 | )) 175 | 176 | num_samples = 10 177 | data = jnp.ones(data_shape) 178 | kernel = numpyro.infer.NUTS(model=mock_model_function) 179 | mcmc = numpyro.infer.MCMC( 180 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 181 | rng_key = jax.random.PRNGKey(0) 182 | 183 | mcmc.run(rng_key, data=data) 184 | seasonality_values = mcmc.get_samples()["intra_week"] 185 | 186 | self.assertEqual(seasonality_values.shape, expected_shape) 187 | 188 | def test_intra_week_seasonality_custom_priors_are_taken_correctly(self): 189 | prior_name = priors.WEEKDAY 190 | expected_value1, expected_value2 = 5.2, 7.56 191 | custom_priors = { 192 | prior_name: 193 | dist.Kumaraswamy( 194 | concentration1=expected_value1, concentration0=expected_value2) 195 | } 196 | media = jnp.ones((10, 5, 5)) 197 | 198 | trace_handler = handlers.trace( 199 | handlers.seed(seasonality.intra_week_seasonality, rng_seed=0)) 200 | trace = trace_handler.get_trace( 201 | data=media, 202 | custom_priors=custom_priors, 203 | ) 204 | values_and_dists = { 205 | name: site["fn"] for name, site in trace.items() if "fn" in site 206 | } 207 | 208 | used_distribution = values_and_dists[prior_name] 209 | if isinstance(used_distribution, dist.ExpandedDistribution): 210 | used_distribution = used_distribution.base_dist 211 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 212 | self.assertEqual(used_distribution.concentration0, expected_value2) 213 | self.assertEqual(used_distribution.concentration1, expected_value1) 214 | 215 | 216 | if __name__ == "__main__": 217 | absltest.main() 218 | -------------------------------------------------------------------------------- /lightweight_mmm/core/time/trend.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Core and modelling functions for trend.""" 16 | 17 | import functools 18 | from typing import Mapping 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import numpyro 23 | from numpyro import distributions as dist 24 | 25 | from lightweight_mmm.core import core_utils 26 | from lightweight_mmm.core import priors 27 | 28 | 29 | @jax.jit 30 | def _trend_with_exponent(coef_trend: jnp.ndarray, trend: jnp.ndarray, 31 | expo_trend: jnp.ndarray) -> jnp.ndarray: 32 | """Applies the coefficient and exponent to the trend to obtain trend values. 33 | 34 | Args: 35 | coef_trend: Coefficient to be multiplied by the trend. 36 | trend: Initial trend values. 37 | expo_trend: Exponent to be applied to the trend. 38 | 39 | Returns: 40 | The trend values generated. 41 | """ 42 | return coef_trend * trend**expo_trend 43 | 44 | 45 | def trend_with_exponent( 46 | data: jnp.ndarray, 47 | custom_priors: Mapping[str, dist.Distribution], 48 | ) -> jnp.ndarray: 49 | """Trend with exponent for curvature. 50 | 51 | Args: 52 | data: Data for which trend will be created. 53 | custom_priors: The custom priors we want the model to take instead of the 54 | default ones. See our custom_priors documentation for details about the 55 | API and possible options. 56 | 57 | Returns: 58 | The values of the trend. 59 | """ 60 | default_priors = priors.get_default_priors() 61 | n_geos = core_utils.get_number_geos(data=data) 62 | # TODO(): Force all geos to have the same trend sign. 63 | with numpyro.plate(name=f"{priors.COEF_TREND}_plate", size=n_geos): 64 | coef_trend = numpyro.sample( 65 | name=priors.COEF_TREND, 66 | fn=custom_priors.get(priors.COEF_TREND, 67 | default_priors[priors.COEF_TREND])) 68 | 69 | expo_trend = numpyro.sample( 70 | name=priors.EXPO_TREND, 71 | fn=custom_priors.get(priors.EXPO_TREND, 72 | default_priors[priors.EXPO_TREND])) 73 | linear_trend = jnp.arange(data.shape[0]) 74 | if n_geos > 1: # For geo model's case 75 | linear_trend = jnp.expand_dims(linear_trend, axis=-1) 76 | return _trend_with_exponent( 77 | coef_trend=coef_trend, trend=linear_trend, expo_trend=expo_trend) 78 | 79 | 80 | @functools.partial(jax.jit, static_argnames=("number_periods",)) 81 | def _dynamic_trend( 82 | number_periods: int, 83 | random_walk_level: jnp.ndarray, 84 | random_walk_slope: jnp.ndarray, 85 | initial_level: jnp.ndarray, 86 | initial_slope: jnp.ndarray, 87 | variance_level: jnp.ndarray, 88 | variance_slope: jnp.ndarray, 89 | ) -> jnp.ndarray: 90 | """Calculates dynamic trend using local linear trend method. 91 | 92 | More details about this function can be found in: 93 | https://storage.googleapis.com/pub-tools-public-publication-data/pdf/41854.pdf 94 | 95 | Args: 96 | number_periods: Number of time periods in the data. 97 | random_walk_level: Random walk of level from sample. 98 | random_walk_slope: Random walk of slope from sample. 99 | initial_level: The initial value for level in local linear trend model. 100 | initial_slope: The initial value for slope in local linear trend model. 
101 | variance_level: The variance of the expected increase in level between time steps. 102 | variance_slope: The variance of the expected increase in slope between time steps. 103 | 104 | Returns: 105 | The dynamic trend values for the given data with the given parameters. 106 | """ 107 | # Simulate a Gaussian random walk of the level with the initial level. 108 | random_level = variance_level * random_walk_level 109 | random_level_with_initial_level = jnp.concatenate( 110 | [jnp.array([random_level[0] + initial_level]), random_level[1:]]) 111 | level_trend_t = jnp.cumsum(random_level_with_initial_level, axis=0) 112 | # Simulate a Gaussian random walk of the slope with the initial slope. 113 | random_slope = variance_slope * random_walk_slope 114 | random_slope_with_initial_slope = jnp.concatenate( 115 | [jnp.array([random_slope[0] + initial_slope]), random_slope[1:]]) 116 | slope_trend_t = jnp.cumsum(random_slope_with_initial_slope, axis=0) 117 | # Accumulate the slope series to account for the latent slope in the update 118 | # level_t = level_{t-1} + slope_{t-1}. 119 | initial_zero_shape = [(1, 0)] if slope_trend_t.ndim == 1 else [(1, 0), (0, 0)] 120 | slope_trend_cumsum = jnp.pad( 121 | jnp.cumsum(slope_trend_t, axis=0)[:number_periods - 1], 122 | initial_zero_shape, mode="constant", constant_values=0) 123 | return level_trend_t + slope_trend_cumsum 124 | 125 | 126 | def dynamic_trend( 127 | geo_size: int, 128 | data_size: int, 129 | is_trend_prediction: bool, 130 | custom_priors: Mapping[str, dist.Distribution], 131 | ) -> jnp.ndarray: 132 | """Generates the dynamic trend to capture the baseline of the KPI. 133 | 134 | Args: 135 | geo_size: Number of geos in the model. 136 | data_size: Number of time samples in the model. 137 | is_trend_prediction: Whether it is used for prediction or fitting. 138 | custom_priors: The custom priors we want the model to take instead of the 139 | default ones. See our custom_priors documentation for details about the 140 | API and possible options. 141 | 142 | Returns: 143 | A JAX array with the trend value for each time t.
144 | """ 145 | default_priors = priors.get_default_priors() 146 | if not is_trend_prediction: 147 | random_walk_level = numpyro.sample("random_walk_level", 148 | fn=dist.Normal(), 149 | sample_shape=(data_size, 1)) 150 | random_walk_slope = numpyro.sample("random_walk_slope", 151 | fn=dist.Normal(), 152 | sample_shape=(data_size, 1)) 153 | else: 154 | random_walk_level = numpyro.sample("random_walk_level_prediction", 155 | fn=dist.Normal(), 156 | sample_shape=(data_size, 1)) 157 | 158 | random_walk_slope = numpyro.sample("random_walk_slope_prediction", 159 | fn=dist.Normal(), 160 | sample_shape=(data_size, 1)) 161 | 162 | with numpyro.plate( 163 | name=f"{priors.DYNAMIC_TREND_INITIAL_LEVEL}_plate", size=geo_size): 164 | trend_initial_level = numpyro.sample( 165 | name=priors.DYNAMIC_TREND_INITIAL_LEVEL, 166 | fn=custom_priors.get( 167 | priors.DYNAMIC_TREND_INITIAL_LEVEL, 168 | default_priors[priors.DYNAMIC_TREND_INITIAL_LEVEL])) 169 | 170 | with numpyro.plate( 171 | name=f"{priors.DYNAMIC_TREND_INITIAL_SLOPE}_plate", size=geo_size): 172 | trend_initial_slope = numpyro.sample( 173 | name=priors.DYNAMIC_TREND_INITIAL_SLOPE, 174 | fn=custom_priors.get( 175 | priors.DYNAMIC_TREND_INITIAL_SLOPE, 176 | default_priors[priors.DYNAMIC_TREND_INITIAL_SLOPE])) 177 | 178 | with numpyro.plate( 179 | name=f"{priors.DYNAMIC_TREND_LEVEL_VARIANCE}_plate", size=geo_size): 180 | trend_level_variance = numpyro.sample( 181 | name=priors.DYNAMIC_TREND_LEVEL_VARIANCE, 182 | fn=custom_priors.get( 183 | priors.DYNAMIC_TREND_LEVEL_VARIANCE, 184 | default_priors[priors.DYNAMIC_TREND_LEVEL_VARIANCE])) 185 | 186 | with numpyro.plate( 187 | name=f"{priors.DYNAMIC_TREND_SLOPE_VARIANCE}_plate", size=geo_size): 188 | trend_slope_variance = numpyro.sample( 189 | name=priors.DYNAMIC_TREND_SLOPE_VARIANCE, 190 | fn=custom_priors.get( 191 | priors.DYNAMIC_TREND_SLOPE_VARIANCE, 192 | default_priors[priors.DYNAMIC_TREND_SLOPE_VARIANCE])) 193 | 194 | if geo_size == 1: # National level model case. 195 | random_walk_level = jnp.squeeze(random_walk_level) 196 | random_walk_slope = jnp.squeeze(random_walk_slope) 197 | trend_initial_level = jnp.squeeze(trend_initial_level) 198 | trend_initial_slope = jnp.squeeze(trend_initial_slope) 199 | trend_level_variance = jnp.squeeze(trend_level_variance) 200 | trend_slope_variance = jnp.squeeze(trend_slope_variance) 201 | 202 | return _dynamic_trend( 203 | number_periods=data_size, 204 | random_walk_level=random_walk_level, 205 | random_walk_slope=random_walk_slope, 206 | initial_level=trend_initial_level, 207 | initial_slope=trend_initial_slope, 208 | variance_level=trend_level_variance, 209 | variance_slope=trend_slope_variance) 210 | -------------------------------------------------------------------------------- /lightweight_mmm/core/time/trend_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for trend.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | import numpyro 23 | from numpyro import distributions as dist 24 | from numpyro import handlers 25 | 26 | from lightweight_mmm.core import core_utils 27 | from lightweight_mmm.core import priors 28 | from lightweight_mmm.core.time import trend 29 | 30 | 31 | class TrendTest(parameterized.TestCase): 32 | 33 | @parameterized.named_parameters([ 34 | dict( 35 | testcase_name="national", 36 | coef_trend_shape=(), 37 | trend_length=150, 38 | expo_trend_shape=(), 39 | ), 40 | dict( 41 | testcase_name="geo", 42 | coef_trend_shape=(5,), 43 | trend_length=150, 44 | expo_trend_shape=(), 45 | ), 46 | ]) 47 | def test_core_trend_with_exponent_produces_correct_shape( 48 | self, coef_trend_shape, trend_length, expo_trend_shape): 49 | coef_trend = jnp.ones(coef_trend_shape) 50 | linear_trend = jnp.arange(trend_length) 51 | if coef_trend.ndim == 1: # For geo model's case 52 | linear_trend = jnp.expand_dims(linear_trend, axis=-1) 53 | expo_trend = jnp.ones(expo_trend_shape) 54 | 55 | trend_values = trend._trend_with_exponent( 56 | coef_trend=coef_trend, trend=linear_trend, expo_trend=expo_trend) 57 | 58 | self.assertEqual(trend_values.shape, 59 | (linear_trend.shape[0], *coef_trend_shape)) 60 | 61 | @parameterized.named_parameters([ 62 | dict(testcase_name="national", data_shape=(150, 3)), 63 | dict(testcase_name="geo", data_shape=(150, 3, 5)), 64 | ]) 65 | def test_trend_with_exponent_produces_correct_shape(self, data_shape): 66 | 67 | def mock_model_function(data): 68 | numpyro.deterministic( 69 | "trend", trend.trend_with_exponent( 70 | data=data, 71 | custom_priors={}, 72 | )) 73 | 74 | num_samples = 10 75 | data = jnp.ones(data_shape) 76 | kernel = numpyro.infer.NUTS(model=mock_model_function) 77 | mcmc = numpyro.infer.MCMC( 78 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 79 | rng_key = jax.random.PRNGKey(0) 80 | coef_expected_shape = () if data.ndim == 2 else (data.shape[2],) 81 | 82 | mcmc.run(rng_key, data=data) 83 | trend_values = mcmc.get_samples()["trend"] 84 | 85 | self.assertEqual(trend_values.shape, 86 | (num_samples, data.shape[0], *coef_expected_shape)) 87 | 88 | @parameterized.named_parameters( 89 | dict( 90 | testcase_name=f"model_{priors.COEF_TREND}", 91 | prior_name=priors.COEF_TREND, 92 | ), 93 | dict( 94 | testcase_name=f"model_{priors.EXPO_TREND}", 95 | prior_name=priors.EXPO_TREND, 96 | ), 97 | ) 98 | def test_trend_with_exponent_custom_priors_are_taken_correctly( 99 | self, prior_name): 100 | expected_value1, expected_value2 = 5.2, 7.56 101 | custom_priors = { 102 | prior_name: 103 | dist.Kumaraswamy( 104 | concentration1=expected_value1, concentration0=expected_value2) 105 | } 106 | media = jnp.ones((10, 5, 5)) 107 | 108 | trace_handler = handlers.trace( 109 | handlers.seed(trend.trend_with_exponent, rng_seed=0)) 110 | trace = trace_handler.get_trace( 111 | data=media, 112 | custom_priors=custom_priors, 113 | ) 114 | values_and_dists = { 115 | name: site["fn"] for name, site in trace.items() if "fn" in site 116 | } 117 | 118 | used_distribution = values_and_dists[prior_name] 119 | if isinstance(used_distribution, dist.ExpandedDistribution): 120 | used_distribution = used_distribution.base_dist 121 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 122 | self.assertEqual(used_distribution.concentration0, expected_value2) 123 | 
self.assertEqual(used_distribution.concentration1, expected_value1) 124 | 125 | @parameterized.named_parameters([ 126 | dict( 127 | testcase_name="dynamic_trend_national_shape", 128 | number_periods=100, 129 | initial_level_shape=(), 130 | initial_slope_shape=(), 131 | variance_level_shape=(), 132 | variance_slope_shape=(), 133 | ), 134 | dict( 135 | testcase_name="dynamic_trend_geo_shape", 136 | number_periods=100, 137 | initial_level_shape=(2,), 138 | initial_slope_shape=(2,), 139 | variance_level_shape=(2,), 140 | variance_slope_shape=(2,), 141 | ), 142 | ]) 143 | def test_core_dynamic_trend_produces_correct_shape( 144 | self, number_periods, initial_level_shape, initial_slope_shape, 145 | variance_level_shape, variance_slope_shape): 146 | initial_level = jnp.ones(initial_level_shape) 147 | initial_slope = jnp.ones(initial_slope_shape) 148 | variance_level = jnp.ones(variance_level_shape) 149 | variance_slope = jnp.ones(variance_slope_shape) 150 | random_walk_level = jnp.arange(number_periods) 151 | random_walk_slope = jnp.arange(number_periods) 152 | if initial_level.ndim == 1: # For geo model's case 153 | random_walk_level = jnp.expand_dims(random_walk_level, axis=-1) 154 | random_walk_slope = jnp.expand_dims(random_walk_slope, axis=-1) 155 | 156 | dynamic_trend_values = trend._dynamic_trend( 157 | number_periods=number_periods, 158 | random_walk_level=random_walk_level, 159 | random_walk_slope=random_walk_slope, 160 | initial_level=initial_level, 161 | initial_slope=initial_slope, 162 | variance_level=variance_level, 163 | variance_slope=variance_slope, 164 | ) 165 | 166 | self.assertEqual(dynamic_trend_values.shape, 167 | (number_periods, *initial_level_shape)) 168 | 169 | def test_core_dynamic_trend_produces_correct_value(self): 170 | number_periods = 5 171 | initial_level = jnp.ones(()) 172 | initial_slope = jnp.ones(()) 173 | variance_level = jnp.ones(()) 174 | variance_slope = jnp.ones(()) 175 | random_walk_level = jnp.arange(number_periods) 176 | random_walk_slope = jnp.arange(number_periods) 177 | dynamic_trend_expected_value = jnp.array([1, 3, 7, 14, 25]) 178 | 179 | dynamic_trend_values = trend._dynamic_trend( 180 | number_periods=number_periods, 181 | random_walk_level=random_walk_level, 182 | random_walk_slope=random_walk_slope, 183 | initial_level=initial_level, 184 | initial_slope=initial_slope, 185 | variance_level=variance_level, 186 | variance_slope=variance_slope, 187 | ) 188 | 189 | np.testing.assert_array_equal(dynamic_trend_values, 190 | dynamic_trend_expected_value) 191 | 192 | @parameterized.named_parameters([ 193 | dict( 194 | testcase_name="national_with_prediction_is_true", 195 | data_shape=(100, 3), 196 | is_trend_prediction=True), 197 | dict( 198 | testcase_name="geo_with_prediction_is_true", 199 | data_shape=(150, 3, 5), 200 | is_trend_prediction=True), 201 | dict( 202 | testcase_name="national_with_prediction_is_false", 203 | data_shape=(100, 3), 204 | is_trend_prediction=False), 205 | dict( 206 | testcase_name="geo_with_prediction_is_false", 207 | data_shape=(150, 3, 5), 208 | is_trend_prediction=False), 209 | ]) 210 | def test_dynamic_trend_produces_correct_shape( 211 | self, data_shape, is_trend_prediction): 212 | 213 | def mock_model_function(geo_size, data_size): 214 | numpyro.deterministic( 215 | "trend", trend.dynamic_trend( 216 | geo_size=geo_size, 217 | data_size=data_size, 218 | is_trend_prediction=is_trend_prediction, 219 | custom_priors={}, 220 | )) 221 | num_samples = 10 222 | data = jnp.ones(data_shape) 223 | geo_size = 
core_utils.get_number_geos(data) 224 | kernel = numpyro.infer.NUTS(model=mock_model_function) 225 | mcmc = numpyro.infer.MCMC( 226 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 227 | rng_key = jax.random.PRNGKey(0) 228 | coef_expected_shape = core_utils.get_geo_shape(data) 229 | 230 | mcmc.run(rng_key, geo_size=geo_size, data_size=data_shape[0]) 231 | trend_values = mcmc.get_samples()["trend"] 232 | 233 | self.assertEqual(trend_values.shape, 234 | (num_samples, data.shape[0], *coef_expected_shape)) 235 | 236 | @parameterized.named_parameters( 237 | dict( 238 | testcase_name=f"model_{priors.DYNAMIC_TREND_INITIAL_LEVEL}", 239 | prior_name=priors.DYNAMIC_TREND_INITIAL_LEVEL, 240 | ), 241 | dict( 242 | testcase_name=f"model_{priors.DYNAMIC_TREND_INITIAL_SLOPE}", 243 | prior_name=priors.DYNAMIC_TREND_INITIAL_SLOPE, 244 | ), 245 | dict( 246 | testcase_name=f"model_{priors.DYNAMIC_TREND_LEVEL_VARIANCE}", 247 | prior_name=priors.DYNAMIC_TREND_LEVEL_VARIANCE, 248 | ), 249 | dict( 250 | testcase_name=f"model_{priors.DYNAMIC_TREND_SLOPE_VARIANCE}", 251 | prior_name=priors.DYNAMIC_TREND_SLOPE_VARIANCE, 252 | ), 253 | ) 254 | def test_core_dynamic_trend_custom_priors_are_taken_correctly( 255 | self, prior_name): 256 | expected_value1, expected_value2 = 5.2, 7.56 257 | custom_priors = { 258 | prior_name: 259 | dist.Kumaraswamy( 260 | concentration1=expected_value1, concentration0=expected_value2) 261 | } 262 | geo_size = 1 263 | data_size = 10 264 | trace_handler = handlers.trace( 265 | handlers.seed(trend.dynamic_trend, rng_seed=0)) 266 | trace = trace_handler.get_trace( 267 | geo_size=geo_size, 268 | data_size=data_size, 269 | is_trend_prediction=False, 270 | custom_priors=custom_priors, 271 | ) 272 | values_and_dists = { 273 | name: site["fn"] for name, site in trace.items() if "fn" in site 274 | } 275 | 276 | used_distribution = values_and_dists[prior_name] 277 | if isinstance(used_distribution, dist.ExpandedDistribution): 278 | used_distribution = used_distribution.base_dist 279 | 280 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 281 | self.assertEqual(used_distribution.concentration0, expected_value2) 282 | self.assertEqual(used_distribution.concentration1, expected_value1) 283 | 284 | @parameterized.named_parameters([ 285 | dict( 286 | testcase_name="trend_prediction_is_true", 287 | is_trend_prediction=True, 288 | expected_trend_parameter=[ 289 | "random_walk_level_prediction", "random_walk_slope_prediction"] 290 | ), 291 | dict( 292 | testcase_name="trend_prediction_is_false", 293 | is_trend_prediction=False, 294 | expected_trend_parameter=[ 295 | "random_walk_level", "random_walk_slope"] 296 | ), 297 | ]) 298 | def test_dynamic_trend_is_trend_prediction_produces_correct_parameter_names( 299 | self, is_trend_prediction, expected_trend_parameter): 300 | 301 | def mock_model_function(geo_size, data_size): 302 | numpyro.deterministic( 303 | "trend", trend.dynamic_trend( 304 | geo_size=geo_size, 305 | data_size=data_size, 306 | is_trend_prediction=is_trend_prediction, 307 | custom_priors={}, 308 | )) 309 | num_samples = 10 310 | data_shape = (10, 3) 311 | data = jnp.ones(data_shape) 312 | geo_size = core_utils.get_number_geos(data) 313 | kernel = numpyro.infer.NUTS(model=mock_model_function) 314 | mcmc = numpyro.infer.MCMC( 315 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 316 | rng_key = jax.random.PRNGKey(0) 317 | 318 | mcmc.run(rng_key, geo_size=geo_size, data_size=data_shape[0]) 319 | trend_parameter = [ 320 | parameter for
parameter, _ in mcmc.get_samples().items() 321 | if parameter.startswith("random_walk")] 322 | 323 | self.assertEqual(trend_parameter, expected_trend_parameter) 324 | 325 | if __name__ == "__main__": 326 | absltest.main() 327 | -------------------------------------------------------------------------------- /lightweight_mmm/core/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /lightweight_mmm/core/transformations/identity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for identity transformations.""" 16 | 17 | from typing import Any 18 | import jax.numpy as jnp 19 | 20 | 21 | def identity_transform( 22 | data: jnp.ndarray,  # pylint: disable=unused-argument 23 | *args: Any, 24 | **kwargs: Any, 25 | ) -> jnp.ndarray: 26 | """Identity transform. Returns the main input as is.""" 27 | return data 28 | -------------------------------------------------------------------------------- /lightweight_mmm/core/transformations/lagging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
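
Before the lagging functions, a small reference sketch (plain NumPy, illustration only) of the adstock recursion that _adstock below expresses with jax.lax.scan, including the optional normalisation that divides by 1 / (1 - lag_weight):

import numpy as np

def reference_adstock(x, lag_weight, normalise=True):
  out = [x[0]]
  for t in range(1, len(x)):  # y_t = y_{t-1} * lag_weight + x_t
    out.append(out[-1] * lag_weight + x[t])
  out = np.array(out)
  return out * (1 - lag_weight) if normalise else out

print(reference_adstock(np.array([100., 0., 0., 0.]), lag_weight=0.5))
# -> [50.   25.   12.5   6.25]
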
14 | 15 | """Set of core and modelling lagging functions.""" 16 | 17 | import functools 18 | from typing import Mapping, Union 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import numpyro 23 | import numpyro.distributions as dist 24 | from lightweight_mmm.core import priors 25 | 26 | 27 | @functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1) 28 | def _carryover_convolve(data: jnp.ndarray, weights: jnp.ndarray, 29 | number_lags: int) -> jnp.ndarray: 30 | """Applies the convolution between the data and the weights for the carryover. 31 | 32 | Args: 33 | data: Input data. 34 | weights: Window weights for the carryover. 35 | number_lags: Number of lags the window has. 36 | 37 | Returns: 38 | The result values from convolving the data and the weights with padding. 39 | """ 40 | window = jnp.concatenate([jnp.zeros(number_lags - 1), weights]) 41 | return jax.scipy.signal.convolve(data, window, mode="same") / weights.sum() 42 | 43 | 44 | @functools.partial(jax.jit, static_argnames=("number_lags",)) 45 | def _carryover( 46 | data: jnp.ndarray, 47 | ad_effect_retention_rate: jnp.ndarray, 48 | peak_effect_delay: jnp.ndarray, 49 | number_lags: int, 50 | ) -> jnp.ndarray: 51 | """Calculates media carryover. 52 | 53 | More details about this function can be found in: 54 | https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46001.pdf 55 | 56 | Args: 57 | data: Input data. It is expected that data has either 2 dimensions for 58 | national models and 3 for geo models. 59 | ad_effect_retention_rate: Retention rate of the advertisement effect. 60 | Default is 0.5. 61 | peak_effect_delay: Delay of the peak effect in the carryover function. 62 | Default is 1. 63 | number_lags: Number of lags to include in the carryover calculation. Default 64 | is 13. 65 | 66 | Returns: 67 | The carryover values for the given data with the given parameters. 68 | """ 69 | lags_arange = jnp.expand_dims( 70 | jnp.arange(number_lags, dtype=jnp.float32), axis=-1) 71 | convolve_func = _carryover_convolve 72 | if data.ndim == 3: 73 | # Since _carryover_convolve is already vmaped in the decorator we only need 74 | # to vmap it once here to handle the geo level data. We keep the windows bi 75 | # dimensional also for three dims data and vmap over only the extra data 76 | # dimension. 77 | convolve_func = jax.vmap( 78 | fun=_carryover_convolve, in_axes=(2, None, None), out_axes=2) 79 | weights = ad_effect_retention_rate**((lags_arange - peak_effect_delay)**2) 80 | return convolve_func(data, weights, number_lags) 81 | 82 | 83 | def carryover( 84 | data: jnp.ndarray, 85 | custom_priors: Mapping[str, dist.Distribution], 86 | *, 87 | number_lags: int = 13, 88 | prefix: str = "", 89 | ) -> jnp.ndarray: 90 | """Transforms the input data with the carryover function. 91 | 92 | Args: 93 | data: Media data to be transformed. It is expected to have 2 dims for 94 | national models and 3 for geo models. 95 | custom_priors: The custom priors we want the model to take instead of the 96 | default ones. 97 | number_lags: Number of lags for the carryover function. 98 | prefix: Prefix to use in the variable name for Numpyro. 99 | 100 | Returns: 101 | The transformed media data. 
102 | """ 103 | default_priors = priors.get_default_priors() 104 | with numpyro.plate( 105 | name=f"{prefix}{priors.AD_EFFECT_RETENTION_RATE}_plate", 106 | size=data.shape[1]): 107 | ad_effect_retention_rate = numpyro.sample( 108 | name=f"{prefix}{priors.AD_EFFECT_RETENTION_RATE}", 109 | fn=custom_priors.get(priors.AD_EFFECT_RETENTION_RATE, 110 | default_priors[priors.AD_EFFECT_RETENTION_RATE])) 111 | 112 | with numpyro.plate( 113 | name=f"{prefix}{priors.PEAK_EFFECT_DELAY}_plate", size=data.shape[1]): 114 | peak_effect_delay = numpyro.sample( 115 | name=f"{prefix}{priors.PEAK_EFFECT_DELAY}", 116 | fn=custom_priors.get(priors.PEAK_EFFECT_DELAY, 117 | default_priors[priors.PEAK_EFFECT_DELAY])) 118 | 119 | return _carryover( 120 | data=data, 121 | ad_effect_retention_rate=ad_effect_retention_rate, 122 | peak_effect_delay=peak_effect_delay, 123 | number_lags=number_lags) 124 | 125 | 126 | @jax.jit 127 | def _adstock( 128 | data: jnp.ndarray, 129 | lag_weight: Union[float, jnp.ndarray] = .9, 130 | normalise: bool = True, 131 | ) -> jnp.ndarray: 132 | """Calculates the adstock value of a given array. 133 | 134 | To learn more about advertising lag: 135 | https://en.wikipedia.org/wiki/Advertising_adstock 136 | 137 | Args: 138 | data: Input array. 139 | lag_weight: lag_weight effect of the adstock function. Default is 0.9. 140 | normalise: Whether to normalise the output value. This normalization will 141 | divide the output values by (1 / (1 - lag_weight)). 142 | 143 | Returns: 144 | The adstock output of the input array. 145 | """ 146 | 147 | def adstock_internal( 148 | prev_adstock: jnp.ndarray, 149 | data: jnp.ndarray, 150 | lag_weight: Union[float, jnp.ndarray] = lag_weight, 151 | ) -> jnp.ndarray: 152 | adstock_value = prev_adstock * lag_weight + data 153 | return adstock_value, adstock_value# jax-ndarray 154 | 155 | _, adstock_values = jax.lax.scan( 156 | f=adstock_internal, init=data[0, ...], xs=data[1:, ...]) 157 | adstock_values = jnp.concatenate([jnp.array([data[0, ...]]), adstock_values]) 158 | return jax.lax.cond( 159 | normalise, 160 | lambda adstock_values: adstock_values / (1. / (1 - lag_weight)), 161 | lambda adstock_values: adstock_values, 162 | operand=adstock_values) 163 | 164 | 165 | def adstock( 166 | data: jnp.ndarray, 167 | custom_priors: Mapping[str, dist.Distribution], 168 | *, 169 | normalise: bool = True, 170 | prefix: str = "", 171 | ) -> jnp.ndarray: 172 | """Transforms the input data with the adstock function and exponent. 173 | 174 | Args: 175 | data: Media data to be transformed. It is expected to have 2 dims for 176 | national models and 3 for geo models. 177 | custom_priors: The custom priors we want the model to take instead of the 178 | default ones. The possible names of parameters for adstock and exponent 179 | are "lag_weight" and "exponent". 180 | normalise: Whether to normalise the output values. 181 | prefix: Prefix to use in the variable name for Numpyro. 182 | 183 | Returns: 184 | The transformed media data. 
185 | """ 186 | default_priors = priors.get_default_priors() 187 | with numpyro.plate( 188 | name=f"{prefix}{priors.LAG_WEIGHT}_plate", size=data.shape[1]): 189 | lag_weight = numpyro.sample( 190 | name=f"{prefix}{priors.LAG_WEIGHT}", 191 | fn=custom_priors.get(priors.LAG_WEIGHT, 192 | default_priors[priors.LAG_WEIGHT])) 193 | 194 | if data.ndim == 3: 195 | lag_weight = jnp.expand_dims(lag_weight, axis=-1) 196 | 197 | return _adstock(data=data, lag_weight=lag_weight, normalise=normalise) 198 | -------------------------------------------------------------------------------- /lightweight_mmm/core/transformations/lagging_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for lagging.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | import numpyro 23 | from numpyro import handlers 24 | import numpyro.distributions as dist 25 | 26 | from lightweight_mmm.core import priors 27 | from lightweight_mmm.core.transformations import lagging 28 | 29 | 30 | class LaggingTest(parameterized.TestCase): 31 | 32 | @parameterized.named_parameters( 33 | dict( 34 | testcase_name="national", 35 | data_shape=(150, 3), 36 | ad_effect_retention_rate_shape=(3,), 37 | peak_effect_delay_shape=(3,), 38 | number_lags=13, 39 | ), 40 | dict( 41 | testcase_name="geo", 42 | data_shape=(150, 3, 5), 43 | ad_effect_retention_rate_shape=(3,), 44 | peak_effect_delay_shape=(3,), 45 | number_lags=13, 46 | ), 47 | ) 48 | def test_core_carryover_produces_correct_shape( 49 | self, 50 | data_shape, 51 | ad_effect_retention_rate_shape, 52 | peak_effect_delay_shape, 53 | number_lags, 54 | ): 55 | data = jnp.ones(data_shape) 56 | ad_effect_retention_rate = jnp.ones(ad_effect_retention_rate_shape) 57 | peak_effect_delay = jnp.ones(peak_effect_delay_shape) 58 | 59 | output = lagging._carryover( 60 | data=data, 61 | ad_effect_retention_rate=ad_effect_retention_rate, 62 | peak_effect_delay=peak_effect_delay, 63 | number_lags=number_lags, 64 | ) 65 | 66 | self.assertEqual(output.shape, data_shape) 67 | 68 | @parameterized.named_parameters( 69 | dict( 70 | testcase_name="national", 71 | data_shape=(150, 3), 72 | ), 73 | dict( 74 | testcase_name="geo", 75 | data_shape=(150, 3, 5), 76 | ), 77 | ) 78 | def test_carryover_produces_correct_shape(self, data_shape): 79 | 80 | def mock_model_function(data, number_lags): 81 | numpyro.deterministic( 82 | "carryover", 83 | lagging.carryover( 84 | data=data, custom_priors={}, number_lags=number_lags)) 85 | 86 | num_samples = 10 87 | data = jnp.ones(data_shape) 88 | number_lags = 15 89 | kernel = numpyro.infer.NUTS(model=mock_model_function) 90 | mcmc = numpyro.infer.MCMC( 91 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 92 | rng_key = jax.random.PRNGKey(0) 93 | 94 | mcmc.run(rng_key, data=data, 
number_lags=number_lags) 95 | carryover_values = mcmc.get_samples()["carryover"] 96 | 97 | self.assertEqual(carryover_values.shape, (num_samples, *data.shape)) 98 | 99 | @parameterized.named_parameters( 100 | dict( 101 | testcase_name="ad_effect_retention_rate", 102 | prior_name=priors.AD_EFFECT_RETENTION_RATE, 103 | ), 104 | dict( 105 | testcase_name="peak_effect_delay", 106 | prior_name=priors.PEAK_EFFECT_DELAY, 107 | ), 108 | ) 109 | def test_carryover_custom_priors_are_taken_correctly(self, prior_name): 110 | expected_value1, expected_value2 = 5.2, 7.56 111 | custom_priors = { 112 | prior_name: 113 | dist.Kumaraswamy( 114 | concentration1=expected_value1, concentration0=expected_value2) 115 | } 116 | media = jnp.ones((10, 5, 5)) 117 | number_lags = 13 118 | 119 | trace_handler = handlers.trace(handlers.seed(lagging.carryover, rng_seed=0)) 120 | trace = trace_handler.get_trace( 121 | data=media, 122 | custom_priors=custom_priors, 123 | number_lags=number_lags, 124 | ) 125 | values_and_dists = { 126 | name: site["fn"] for name, site in trace.items() if "fn" in site 127 | } 128 | 129 | used_distribution = values_and_dists[prior_name] 130 | used_distribution = used_distribution.base_dist 131 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 132 | self.assertEqual(used_distribution.concentration0, expected_value2) 133 | self.assertEqual(used_distribution.concentration1, expected_value1) 134 | 135 | @parameterized.named_parameters( 136 | dict( 137 | testcase_name="national", 138 | data_shape=(150, 3), 139 | lag_weight_shape=(3,), 140 | ), 141 | dict( 142 | testcase_name="geo", 143 | data_shape=(150, 3, 5), 144 | lag_weight_shape=(3, 1), 145 | ), 146 | ) 147 | def test_core_adstock_produces_correct_shape(self, data_shape, 148 | lag_weight_shape): 149 | data = jnp.ones(data_shape) 150 | lag_weight = jnp.ones(lag_weight_shape) 151 | 152 | output = lagging._adstock(data=data, lag_weight=lag_weight) 153 | 154 | self.assertEqual(output.shape, data_shape) 155 | 156 | @parameterized.named_parameters( 157 | dict( 158 | testcase_name="national", 159 | data_shape=(150, 3), 160 | ), 161 | dict( 162 | testcase_name="geo", 163 | data_shape=(150, 3, 5), 164 | ), 165 | ) 166 | def test_adstock_produces_correct_shape(self, data_shape): 167 | 168 | def mock_model_function(data, normalise): 169 | numpyro.deterministic( 170 | "adstock", 171 | lagging.adstock(data=data, custom_priors={}, normalise=normalise)) 172 | 173 | num_samples = 10 174 | data = jnp.ones(data_shape) 175 | normalise = True 176 | kernel = numpyro.infer.NUTS(model=mock_model_function) 177 | mcmc = numpyro.infer.MCMC( 178 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 179 | rng_key = jax.random.PRNGKey(0) 180 | 181 | mcmc.run(rng_key, data=data, normalise=normalise) 182 | adstock_values = mcmc.get_samples()["adstock"] 183 | 184 | self.assertEqual(adstock_values.shape, (num_samples, *data.shape)) 185 | 186 | def test_adstock_custom_priors_are_taken_correctly(self): 187 | prior_name = priors.LAG_WEIGHT 188 | expected_value1, expected_value2 = 5.2, 7.56 189 | custom_priors = { 190 | prior_name: 191 | dist.Kumaraswamy( 192 | concentration1=expected_value1, concentration0=expected_value2) 193 | } 194 | data = jnp.ones((10, 5, 5)) 195 | 196 | trace_handler = handlers.trace(handlers.seed(lagging.adstock, rng_seed=0)) 197 | trace = trace_handler.get_trace( 198 | data=data, 199 | custom_priors=custom_priors, 200 | normalise=True, 201 | ) 202 | values_and_dists = { 203 | name: site["fn"] for name, site in trace.items() 
if "fn" in site 204 | } 205 | 206 | used_distribution = values_and_dists[prior_name] 207 | used_distribution = used_distribution.base_dist 208 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 209 | self.assertEqual(used_distribution.concentration0, expected_value2) 210 | self.assertEqual(used_distribution.concentration1, expected_value1) 211 | 212 | def test_adstock_zeros_stay_zeros(self): 213 | data = jnp.zeros((10, 5)) 214 | lag_weight = jnp.full(5, 0.5) 215 | 216 | generated_output = lagging._adstock(data=data, lag_weight=lag_weight) 217 | 218 | np.testing.assert_array_equal(generated_output, data) 219 | 220 | def test_carryover_zeros_stay_zeros(self): 221 | data = jnp.zeros((10, 5)) 222 | ad_effect_retention_rate = jnp.full(5, 0.5) 223 | peak_effect_delay = jnp.full(5, 0.5) 224 | 225 | generated_output = lagging._carryover( 226 | data=data, 227 | ad_effect_retention_rate=ad_effect_retention_rate, 228 | peak_effect_delay=peak_effect_delay, 229 | number_lags=7, 230 | ) 231 | 232 | np.testing.assert_array_equal(generated_output, data) 233 | 234 | 235 | if __name__ == "__main__": 236 | absltest.main() 237 | -------------------------------------------------------------------------------- /lightweight_mmm/core/transformations/saturation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Set of core and modelling saturation functions.""" 16 | 17 | from typing import Mapping 18 | import jax 19 | import jax.numpy as jnp 20 | import numpyro 21 | from numpyro import distributions as dist 22 | 23 | from lightweight_mmm.core import core_utils 24 | from lightweight_mmm.core import priors 25 | 26 | 27 | @jax.jit 28 | def _hill( 29 | data: jnp.ndarray, 30 | half_max_effective_concentration: jnp.ndarray, 31 | slope: jnp.ndarray, 32 | ) -> jnp.ndarray: 33 | """Calculates the hill function for a given array of values. 34 | 35 | Refer to the following link for detailed information on this equation: 36 | https://en.wikipedia.org/wiki/Hill_equation_(biochemistry) 37 | 38 | Args: 39 | data: Input data. 40 | half_max_effective_concentration: ec50 value for the hill function. 41 | slope: Slope of the hill function. 42 | 43 | Returns: 44 | The hill values for the respective input data. 45 | """ 46 | save_transform = core_utils.apply_exponent_safe( 47 | data=data / half_max_effective_concentration, exponent=-slope) 48 | return jnp.where(save_transform == 0, 0, 1.0 / (1 + save_transform)) 49 | 50 | 51 | def hill( 52 | data: jnp.ndarray, 53 | custom_priors: Mapping[str, dist.Distribution], 54 | *, 55 | prefix: str = "", 56 | ) -> jnp.ndarray: 57 | """Transforms the input data with the adstock and hill functions. 58 | 59 | Args: 60 | data: Media data to be transformed. It is expected to have 2 dims for 61 | national models and 3 for geo models. 62 | custom_priors: The custom priors we want the model to take instead of the 63 | default ones. 
The possible names of parameters for the hill function are 64 | "half_max_effective_concentration" and "slope". 65 | prefix: Prefix to use in the variable name for Numpyro. 66 | 67 | Returns: 68 | The transformed media data. 69 | """ 70 | default_priors = priors.get_default_priors() 71 | 72 | with numpyro.plate( 73 | name=f"{prefix}{priors.HALF_MAX_EFFECTIVE_CONCENTRATION}_plate", 74 | size=data.shape[1]): 75 | half_max_effective_concentration = numpyro.sample( 76 | name=f"{prefix}{priors.HALF_MAX_EFFECTIVE_CONCENTRATION}", 77 | fn=custom_priors.get( 78 | priors.HALF_MAX_EFFECTIVE_CONCENTRATION, 79 | default_priors[priors.HALF_MAX_EFFECTIVE_CONCENTRATION])) 80 | 81 | with numpyro.plate(name=f"{prefix}{priors.SLOPE}_plate", size=data.shape[1]): 82 | slope = numpyro.sample( 83 | name=f"{prefix}{priors.SLOPE}", 84 | fn=custom_priors.get(priors.SLOPE, default_priors[priors.SLOPE])) 85 | 86 | if data.ndim == 3: 87 | half_max_effective_concentration = jnp.expand_dims( 88 | half_max_effective_concentration, axis=-1) 89 | slope = jnp.expand_dims(slope, axis=-1) 90 | 91 | return _hill( 92 | data=data, 93 | half_max_effective_concentration=half_max_effective_concentration, 94 | slope=slope) 95 | 96 | 97 | def _exponent(data: jnp.ndarray, exponent_values: jnp.ndarray) -> jnp.ndarray: 98 | """Applies exponent to the given data.""" 99 | return core_utils.apply_exponent_safe(data=data, exponent=exponent_values) 100 | 101 | 102 | def exponent( 103 | data: jnp.ndarray, 104 | custom_priors: Mapping[str, dist.Distribution], 105 | *, 106 | prefix: str = "", 107 | ) -> jnp.ndarray: 108 | """Transforms the input data with the exponent saturation function. 109 | 110 | Args: 111 | data: Media data to be transformed. It is expected to have 2 dims for 112 | national models and 3 for geo models. 113 | custom_priors: The custom priors we want the model to take instead of the 114 | default ones. 115 | prefix: Prefix to use in the variable name for Numpyro. 116 | 117 | Returns: 118 | The transformed media data. 119 | """ 120 | default_priors = priors.get_default_priors() 121 | 122 | with numpyro.plate( 123 | name=f"{prefix}{priors.EXPONENT}_plate", size=data.shape[1]): 124 | exponent_values = numpyro.sample( 125 | name=f"{prefix}{priors.EXPONENT}", 126 | fn=custom_priors.get(priors.EXPONENT, default_priors[priors.EXPONENT])) 127 | 128 | if data.ndim == 3: 129 | exponent_values = jnp.expand_dims(exponent_values, axis=-1) 130 | return _exponent(data=data, exponent_values=exponent_values) 131 | -------------------------------------------------------------------------------- /lightweight_mmm/core/transformations/saturation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
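
A quick numeric check of the Hill curve above (illustration only; it calls the private _hill helper directly with made-up values): at data equal to half_max_effective_concentration the response is exactly 0.5, and zeros stay zero thanks to apply_exponent_safe:

import jax.numpy as jnp

from lightweight_mmm.core.transformations import saturation

spend = jnp.array([[0.], [0.5], [1.], [2.]])
out = saturation._hill(
    data=spend,
    half_max_effective_concentration=jnp.array([1.]),
    slope=jnp.array([1.]))
print(out[:, 0])  # [0.  0.3333  0.5  0.6667]: 0.5 at the ec50, 0 stays 0
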
14 | 
15 | """Tests for saturation."""
16 | 
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | import numpyro
23 | from numpyro import handlers
24 | import numpyro.distributions as dist
25 | 
26 | from lightweight_mmm.core import priors
27 | from lightweight_mmm.core.transformations import saturation
28 | 
29 | 
30 | class SaturationTest(parameterized.TestCase):
31 | 
32 |   @parameterized.named_parameters(
33 |       dict(
34 |           testcase_name="national",
35 |           data_shape=(150, 3),
36 |           half_max_effective_concentration_shape=(3,),
37 |           slope_shape=(3,),
38 |       ),
39 |       dict(
40 |           testcase_name="geo",
41 |           data_shape=(150, 3, 5),
42 |           half_max_effective_concentration_shape=(3, 1),
43 |           slope_shape=(3, 1),
44 |       ),
45 |   )
46 |   def test_hill_core_produces_correct_shape(
47 |       self, data_shape, half_max_effective_concentration_shape, slope_shape):
48 |     data = jnp.ones(data_shape)
49 |     half_max_effective_concentration = jnp.ones(
50 |         half_max_effective_concentration_shape)
51 |     slope = jnp.ones(slope_shape)
52 | 
53 |     output = saturation._hill(
54 |         data=data,
55 |         half_max_effective_concentration=half_max_effective_concentration,
56 |         slope=slope,
57 |     )
58 | 
59 |     self.assertEqual(output.shape, data_shape)
60 | 
61 |   @parameterized.named_parameters(
62 |       dict(
63 |           testcase_name="national",
64 |           data_shape=(150, 3),
65 |       ),
66 |       dict(
67 |           testcase_name="geo",
68 |           data_shape=(150, 3, 5),
69 |       ),
70 |   )
71 |   def test_hill_produces_correct_shape(self, data_shape):
72 | 
73 |     def mock_model_function(data):
74 |       numpyro.deterministic("hill",
75 |                             saturation.hill(data=data, custom_priors={}))
76 | 
77 |     num_samples = 10
78 |     data = jnp.ones(data_shape)
79 |     kernel = numpyro.infer.NUTS(model=mock_model_function)
80 |     mcmc = numpyro.infer.MCMC(
81 |         sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1)
82 |     rng_key = jax.random.PRNGKey(0)
83 | 
84 |     mcmc.run(rng_key, data=data)
85 |     output_values = mcmc.get_samples()["hill"]
86 | 
87 |     self.assertEqual(output_values.shape, (num_samples, *data.shape))
88 | 
89 |   @parameterized.named_parameters(
90 |       dict(
91 |           testcase_name="half_max_effective_concentration",
92 |           prior_name=priors.HALF_MAX_EFFECTIVE_CONCENTRATION,
93 |       ),
94 |       dict(
95 |           testcase_name="slope",
96 |           prior_name=priors.SLOPE,
97 |       ),
98 |   )
99 |   def test_hill_custom_priors_are_taken_correctly(self, prior_name):
100 |     expected_value1, expected_value2 = 5.2, 7.56
101 |     custom_priors = {
102 |         prior_name:
103 |             dist.Kumaraswamy(
104 |                 concentration1=expected_value1, concentration0=expected_value2)
105 |     }
106 |     media = jnp.ones((10, 5, 5))
107 | 
108 |     trace_handler = handlers.trace(handlers.seed(saturation.hill, rng_seed=0))
109 |     trace = trace_handler.get_trace(
110 |         data=media,
111 |         custom_priors=custom_priors,
112 |     )
113 |     values_and_dists = {
114 |         name: site["fn"] for name, site in trace.items() if "fn" in site
115 |     }
116 | 
117 |     used_distribution = values_and_dists[prior_name]
118 |     used_distribution = used_distribution.base_dist
119 |     self.assertIsInstance(used_distribution, dist.Kumaraswamy)
120 |     self.assertEqual(used_distribution.concentration0, expected_value2)
121 |     self.assertEqual(used_distribution.concentration1, expected_value1)
122 | 
123 |   def test_exponent_core_produces_correct_shape(self):
124 |     self.assertEqual(saturation._exponent(data=jnp.ones((10, 5)), exponent_values=jnp.full(5, 0.5)).shape, (10, 5))
125 | 
126 |   @parameterized.named_parameters(
127 |       dict(
128 |           testcase_name="national",
129 |           data_shape=(150, 3),
130 |       ),
131 |       dict(
132 |           testcase_name="geo",
133 |
data_shape=(150, 3, 5), 134 | ), 135 | ) 136 | def test_exponent_produces_correct_shape(self, data_shape): 137 | 138 | def mock_model_function(data): 139 | numpyro.deterministic("outer_exponent", 140 | saturation.exponent(data=data, custom_priors={})) 141 | 142 | num_samples = 10 143 | data = jnp.ones(data_shape) 144 | kernel = numpyro.infer.NUTS(model=mock_model_function) 145 | mcmc = numpyro.infer.MCMC( 146 | sampler=kernel, num_warmup=10, num_samples=num_samples, num_chains=1) 147 | rng_key = jax.random.PRNGKey(0) 148 | 149 | mcmc.run(rng_key, data=data) 150 | output_values = mcmc.get_samples()["outer_exponent"] 151 | 152 | self.assertEqual(output_values.shape, (num_samples, *data.shape)) 153 | 154 | def test_exponent_custom_priors_are_taken_correctly(self): 155 | prior_name = priors.EXPONENT 156 | expected_value1, expected_value2 = 5.2, 7.56 157 | custom_priors = { 158 | prior_name: 159 | dist.Kumaraswamy( 160 | concentration1=expected_value1, concentration0=expected_value2) 161 | } 162 | media = jnp.ones((10, 5, 5)) 163 | 164 | trace_handler = handlers.trace( 165 | handlers.seed(saturation.exponent, rng_seed=0)) 166 | trace = trace_handler.get_trace( 167 | data=media, 168 | custom_priors=custom_priors, 169 | ) 170 | values_and_dists = { 171 | name: site["fn"] for name, site in trace.items() if "fn" in site 172 | } 173 | 174 | used_distribution = values_and_dists[prior_name] 175 | used_distribution = used_distribution.base_dist 176 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 177 | self.assertEqual(used_distribution.concentration0, expected_value2) 178 | self.assertEqual(used_distribution.concentration1, expected_value1) 179 | 180 | def test_hill_zeros_stay_zeros(self): 181 | data = jnp.zeros((10, 5)) 182 | half_max_effective_concentration = jnp.full(5, 0.5) 183 | slope = jnp.full(5, 0.5) 184 | 185 | generated_output = saturation._hill( 186 | data=data, 187 | half_max_effective_concentration=half_max_effective_concentration, 188 | slope=slope, 189 | ) 190 | 191 | np.testing.assert_array_equal(generated_output, data) 192 | 193 | def test_exponent_zeros_stay_zero(self): 194 | data = jnp.zeros((10, 5)) 195 | exponent_values = jnp.full(5, 0.5) 196 | 197 | generated_output = saturation._exponent( 198 | data=data, 199 | exponent_values=exponent_values, 200 | ) 201 | 202 | np.testing.assert_array_equal(generated_output, data) 203 | 204 | 205 | if __name__ == "__main__": 206 | absltest.main() 207 | -------------------------------------------------------------------------------- /lightweight_mmm/media_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
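As a numeric companion to the saturation tests above, a minimal sketch of the hill curve's behaviour at a few illustrative spend levels (the spend values are chosen for illustration only; `_hill` is the private helper exercised by the tests):

```python
import jax.numpy as jnp

from lightweight_mmm.core.transformations import saturation

spend = jnp.array([[0.0], [0.5], [1.0], [2.0], [8.0]])
response = saturation._hill(
    data=spend,
    half_max_effective_concentration=jnp.array([1.0]),
    slope=jnp.array([1.0]))
# Zero spend maps to zero response, spend equal to the ec50 maps to
# exactly 0.5, and the curve saturates towards 1.0 for large spend.
print(response[:, 0])  # approximately [0. 0.3333 0.5 0.6667 0.8889]
```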
14 | 
15 | """Media transformations for accounting for lagging or media effects."""
16 | 
17 | import functools
18 | from typing import Union
19 | 
20 | import jax
21 | import jax.numpy as jnp
22 | 
23 | 
24 | @functools.partial(jax.jit, static_argnums=[0, 1])
25 | def calculate_seasonality(
26 |     number_periods: int,
27 |     degrees: int,
28 |     gamma_seasonality: Union[int, float, jnp.ndarray],
29 |     frequency: int = 52,
30 | ) -> jnp.ndarray:
31 |   """Calculates cyclic variation seasonality using Fourier terms.
32 | 
33 |   For detailed info check:
34 |   https://en.wikipedia.org/wiki/Seasonality#Modeling
35 | 
36 |   Args:
37 |     number_periods: Number of seasonal periods in the data. E.g. for 1 year of
38 |       weekly data it will be 52, and for 3 years of the same kind, 156.
39 |     degrees: Number of degrees to use. Must be greater than or equal to 1.
40 |     gamma_seasonality: Factor by which to multiply each degree calculation.
41 |       Shape must be aligned with the number of degrees.
42 |     frequency: Frequency of the seasonality being computed. Default is 52 for
43 |       weekly data (52 weeks in a year).
44 | 
45 |   Returns:
46 |     An array with the seasonality values.
47 |   """
48 | 
49 |   seasonality_range = jnp.expand_dims(a=jnp.arange(number_periods), axis=-1)
50 |   degrees_range = jnp.arange(1, degrees+1)
51 |   inner_value = seasonality_range * 2 * jnp.pi * degrees_range / frequency
52 |   season_matrix_sin = jnp.sin(inner_value)
53 |   season_matrix_cos = jnp.cos(inner_value)
54 |   season_matrix = jnp.concatenate([
55 |       jnp.expand_dims(a=season_matrix_sin, axis=-1),
56 |       jnp.expand_dims(a=season_matrix_cos, axis=-1)
57 |   ],
58 |                                   axis=-1)
59 |   return (season_matrix * gamma_seasonality).sum(axis=2).sum(axis=1)
60 | 
61 | 
62 | @jax.jit
63 | def adstock(data: jnp.ndarray,
64 |             lag_weight: float = .9,
65 |             normalise: bool = True) -> jnp.ndarray:
66 |   """Calculates the adstock value of a given array.
67 | 
68 |   To learn more about advertising lag:
69 |   https://en.wikipedia.org/wiki/Advertising_adstock
70 | 
71 |   Args:
72 |     data: Input array.
73 |     lag_weight: Lag weight of the adstock function. Default is 0.9.
74 |     normalise: Whether to normalise the output value. This normalization will
75 |       divide the output values by (1 / (1 - lag_weight)).
76 | 
77 |   Returns:
78 |     The adstock output of the input array.
79 |   """
80 | 
81 |   def adstock_internal(prev_adstock: jnp.ndarray,
82 |                        data: jnp.ndarray,
83 |                        lag_weight: float = lag_weight) -> jnp.ndarray:
84 |     adstock_value = prev_adstock * lag_weight + data
85 |     return adstock_value, adstock_value
86 | 
87 |   _, adstock_values = jax.lax.scan(
88 |       f=adstock_internal, init=data[0, ...], xs=data[1:, ...])
89 |   adstock_values = jnp.concatenate([jnp.array([data[0, ...]]), adstock_values])
90 |   return jax.lax.cond(
91 |       normalise,
92 |       lambda adstock_values: adstock_values / (1. / (1 - lag_weight)),
93 |       lambda adstock_values: adstock_values,
94 |       operand=adstock_values)
95 | 
96 | 
97 | @jax.jit
98 | def hill(data: jnp.ndarray, half_max_effective_concentration: jnp.ndarray,
99 |          slope: jnp.ndarray) -> jnp.ndarray:
100 |   """Calculates the hill function for a given array of values.
101 | 
102 |   Refer to the following link for detailed information on this equation:
103 |   https://en.wikipedia.org/wiki/Hill_equation_(biochemistry)
104 | 
105 |   Args:
106 |     data: Input data.
107 |     half_max_effective_concentration: ec50 value for the hill function.
108 |     slope: Slope of the hill function.
109 | 
110 |   Returns:
111 |     The hill values for the respective input data.
112 |   """
113 |   safe_transform = apply_exponent_safe(
114 |       data=data / half_max_effective_concentration, exponent=-slope)
115 |   return jnp.where(safe_transform == 0, 0, 1.0 / (1 + safe_transform))
116 | 
117 | 
118 | @functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1)
119 | def _carryover_convolve(data: jnp.ndarray,
120 |                         weights: jnp.ndarray,
121 |                         number_lags: int) -> jnp.ndarray:
122 |   """Applies the convolution between the data and the weights for the carryover.
123 | 
124 |   Args:
125 |     data: Input data.
126 |     weights: Window weights for the carryover.
127 |     number_lags: Number of lags the window has.
128 | 
129 |   Returns:
130 |     The result values from convolving the data and the weights with padding.
131 |   """
132 |   window = jnp.concatenate([jnp.zeros(number_lags - 1), weights])
133 |   return jax.scipy.signal.convolve(data, window, mode="same") / weights.sum()
134 | 
135 | 
136 | @functools.partial(jax.jit, static_argnames=("number_lags",))
137 | def carryover(data: jnp.ndarray,
138 |               ad_effect_retention_rate: jnp.ndarray,
139 |               peak_effect_delay: jnp.ndarray,
140 |               number_lags: int = 13) -> jnp.ndarray:
141 |   """Calculates media carryover.
142 | 
143 |   More details about this function can be found in:
144 |   https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46001.pdf
145 | 
146 |   Args:
147 |     data: Input data. It is expected that data has either 2 dimensions for
148 |       national models and 3 for geo models.
149 |     ad_effect_retention_rate: Retention rate of the advertisement effect,
150 |       with one value per media channel.
151 |     peak_effect_delay: Delay of the peak effect in the carryover function,
152 |       with one value per media channel.
153 |     number_lags: Number of lags to include in the carryover calculation. Default
154 |       is 13.
155 | 
156 |   Returns:
157 |     The carryover values for the given data with the given parameters.
158 |   """
159 |   lags_arange = jnp.expand_dims(jnp.arange(number_lags, dtype=jnp.float32),
160 |                                 axis=-1)
161 |   convolve_func = _carryover_convolve
162 |   if data.ndim == 3:
163 |     # Since _carryover_convolve is already vmapped in the decorator we only
164 |     # need to vmap it once here to handle the geo level data. We keep the
165 |     # windows bi-dimensional also for three-dim data and vmap over only the
166 |     # extra data dimension.
167 |     convolve_func = jax.vmap(
168 |         fun=_carryover_convolve, in_axes=(2, None, None), out_axes=2)
169 |   weights = ad_effect_retention_rate**((lags_arange - peak_effect_delay)**2)
170 |   return convolve_func(data, weights, number_lags)
171 | 
172 | @jax.jit
173 | def apply_exponent_safe(
174 |     data: jnp.ndarray,
175 |     exponent: jnp.ndarray,
176 | ) -> jnp.ndarray:
177 |   """Applies an exponent to given data in a gradient safe way.
178 | 
179 |   More info on the double jnp.where can be found:
180 |   https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
181 | 
182 |   Args:
183 |     data: Input data to use.
184 |     exponent: Exponent required for the operations.
185 | 
186 |   Returns:
187 |     The result of the exponent operation with the inputs provided.
188 |   """
189 |   exponent_safe = jnp.where(data == 0, 1, data) ** exponent
190 |   return jnp.where(data == 0, 0, exponent_safe)
191 | 
--------------------------------------------------------------------------------
/lightweight_mmm/media_transforms_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for media_transforms.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | from lightweight_mmm import media_transforms 24 | 25 | 26 | class MediaTransformsTest(parameterized.TestCase): 27 | 28 | @parameterized.named_parameters([ 29 | dict( 30 | testcase_name="2d_four_channels", 31 | data=np.ones((100, 4)), 32 | ad_effect_retention_rate=np.array([0.9, 0.8, 0.7, 1]), 33 | peak_effect_delay=np.array([0.9, 0.8, 0.7, 1]), 34 | number_lags=5), 35 | dict( 36 | testcase_name="2d_one_channel", 37 | data=np.ones((300, 1)), 38 | ad_effect_retention_rate=np.array([0.2]), 39 | peak_effect_delay=np.array([1]), 40 | number_lags=10), 41 | dict( 42 | testcase_name="3d_10channels_10geos", 43 | data=np.ones((100, 10, 10)), 44 | ad_effect_retention_rate=np.ones(10), 45 | peak_effect_delay=np.ones(10), 46 | number_lags=13), 47 | dict( 48 | testcase_name="3d_10channels_8geos", 49 | data=np.ones((100, 10, 8)), 50 | ad_effect_retention_rate=np.ones(10), 51 | peak_effect_delay=np.ones(10), 52 | number_lags=13), 53 | ]) 54 | def test_carryover_produces_correct_shape(self, data, 55 | ad_effect_retention_rate, 56 | peak_effect_delay, number_lags): 57 | 58 | generated_output = media_transforms.carryover(data, 59 | ad_effect_retention_rate, 60 | peak_effect_delay, 61 | number_lags) 62 | self.assertEqual(generated_output.shape, data.shape) 63 | 64 | @parameterized.named_parameters([ 65 | dict( 66 | testcase_name="2d_three_channels", 67 | data=np.ones((100, 3)), 68 | half_max_effective_concentration=np.array([0.9, 0.8, 0.7]), 69 | slope=np.array([2, 2, 1])), 70 | dict( 71 | testcase_name="2d_one_channels", 72 | data=np.ones((100, 1)), 73 | half_max_effective_concentration=np.array([0.9]), 74 | slope=np.array([5])), 75 | dict( 76 | testcase_name="3d_10channels_5geos", 77 | data=np.ones((100, 10, 5)), 78 | half_max_effective_concentration=np.expand_dims(np.ones(10), axis=-1), 79 | slope=np.expand_dims(np.ones(10), axis=-1)), 80 | dict( 81 | testcase_name="3d_8channels_10geos", 82 | data=np.ones((100, 8, 10)), 83 | half_max_effective_concentration=np.expand_dims(np.ones(8), axis=-1), 84 | slope=np.expand_dims(np.ones(8), axis=-1)), 85 | ]) 86 | def test_hill_produces_correct_shape(self, data, 87 | half_max_effective_concentration, slope): 88 | generated_output = media_transforms.hill( 89 | data=data, 90 | half_max_effective_concentration=half_max_effective_concentration, 91 | slope=slope) 92 | 93 | self.assertEqual(generated_output.shape, data.shape) 94 | 95 | @parameterized.named_parameters([ 96 | dict( 97 | testcase_name="2d_five_channels", 98 | data=np.ones((100, 5)), 99 | lag_weight=np.array([0.2, 0.3, 0.8, 0.2, 0.1]), 100 | normalise=True), 101 | dict( 102 | testcase_name="2d_one_channels", 103 | data=np.ones((100, 1)), 104 | lag_weight=np.array([0.4]), 105 | normalise=False), 106 | dict( 107 | 
testcase_name="3d_10channels_5geos", 108 | data=np.ones((100, 10, 5)), 109 | lag_weight=np.expand_dims(np.ones(10), axis=-1), 110 | normalise=True), 111 | dict( 112 | testcase_name="3d_8channels_10geos", 113 | data=np.ones((100, 8, 10)), 114 | lag_weight=np.expand_dims(np.ones(8), axis=-1), 115 | normalise=True), 116 | ]) 117 | def test_adstock_produces_correct_shape(self, data, lag_weight, normalise): 118 | generated_output = media_transforms.adstock( 119 | data=data, lag_weight=lag_weight, normalise=normalise) 120 | 121 | self.assertEqual(generated_output.shape, data.shape) 122 | 123 | def test_apply_exponent_safe_produces_correct_shape(self): 124 | data = jnp.arange(50).reshape((10, 5)) 125 | exponent = jnp.full(5, 0.5) 126 | 127 | output = media_transforms.apply_exponent_safe(data=data, exponent=exponent) 128 | 129 | np.testing.assert_array_equal(output, data**exponent) 130 | 131 | def test_apply_exponent_safe_produces_same_exponent_results(self): 132 | data = jnp.ones((10, 5)) 133 | exponent = jnp.full(5, 0.5) 134 | 135 | output = media_transforms.apply_exponent_safe(data=data, exponent=exponent) 136 | 137 | self.assertEqual(output.shape, data.shape) 138 | 139 | def test_apply_exponent_safe_produces_non_nan_or_inf_grads(self): 140 | def f_safe(data, exponent): 141 | x = media_transforms.apply_exponent_safe(data=data, exponent=exponent) 142 | return x.sum() 143 | data = jnp.ones((10, 5)) 144 | data = data.at[0, 0].set(0.) 145 | exponent = jnp.full(5, 0.5) 146 | 147 | grads = jax.grad(f_safe)(data, exponent) 148 | 149 | self.assertFalse(np.isnan(grads).any()) 150 | self.assertFalse(np.isinf(grads).any()) 151 | 152 | def test_adstock_zeros_stay_zeros(self): 153 | data = jnp.zeros((10, 5)) 154 | lag_weight = jnp.full(5, 0.5) 155 | 156 | generated_output = media_transforms.adstock( 157 | data=data, lag_weight=lag_weight) 158 | 159 | np.testing.assert_array_equal(generated_output, data) 160 | 161 | def test_hill_zeros_stay_zeros(self): 162 | data = jnp.zeros((10, 5)) 163 | half_max_effective_concentration = jnp.full(5, 0.5) 164 | slope = jnp.full(5, 0.5) 165 | 166 | generated_output = media_transforms.hill( 167 | data=data, 168 | half_max_effective_concentration=half_max_effective_concentration, 169 | slope=slope) 170 | 171 | np.testing.assert_array_equal(generated_output, data) 172 | 173 | def test_carryover_zeros_stay_zeros(self): 174 | data = jnp.zeros((10, 5)) 175 | ad_effect_retention_rate = jnp.full(5, 0.5) 176 | peak_effect_delay = jnp.full(5, 0.5) 177 | 178 | generated_output = media_transforms.carryover( 179 | data=data, 180 | ad_effect_retention_rate=ad_effect_retention_rate, 181 | peak_effect_delay=peak_effect_delay) 182 | 183 | np.testing.assert_array_equal(generated_output, data) 184 | 185 | 186 | @parameterized.parameters(range(1, 5)) 187 | def test_calculate_seasonality_produces_correct_standard_deviation( 188 | self, degrees): 189 | # It's not very obvious that this is the expected standard deviation, but it 190 | # seems to be true mathematically and this makes a very convenient unit test. 
191 |     expected_standard_deviation = jnp.sqrt(degrees)
192 | 
193 |     seasonal_curve = media_transforms.calculate_seasonality(
194 |         number_periods=1200,
195 |         degrees=degrees,
196 |         gamma_seasonality=1,
197 |         frequency=1200,
198 |     )
199 |     observed_standard_deviation = jnp.std(seasonal_curve)
200 | 
201 |     self.assertAlmostEqual(
202 |         observed_standard_deviation, expected_standard_deviation, delta=0.01)
203 | 
204 | 
205 | if __name__ == "__main__":
206 |   absltest.main()
207 | 
--------------------------------------------------------------------------------
/lightweight_mmm/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Module containing the different models available in the lightweightMMM lib.
16 | 
17 | Currently this file contains a main model with three possible options for
18 | processing the media data, which effectively makes it possible to build three
19 | different models:
20 | - Adstock
21 | - Hill-Adstock
22 | - Carryover
23 | """
24 | import sys
25 | # pylint: disable=g-import-not-at-top
26 | if sys.version_info >= (3, 8):
27 |   from typing import Protocol
28 | else:
29 |   from typing_extensions import Protocol
30 | 
31 | from typing import Any, Dict, Mapping, MutableMapping, Optional, Sequence, Union
32 | 
33 | import immutabledict
34 | import jax.numpy as jnp
35 | import numpyro
36 | from numpyro import distributions as dist
37 | 
38 | from lightweight_mmm import media_transforms
39 | 
40 | Prior = Union[
41 |     dist.Distribution,
42 |     Dict[str, float],
43 |     Sequence[float],
44 |     float
45 | ]
46 | 
47 | 
48 | class TransformFunction(Protocol):
49 | 
50 |   def __call__(
51 |       self,
52 |       media_data: jnp.ndarray,
53 |       custom_priors: MutableMapping[str, Prior],
54 |       **kwargs: Any) -> jnp.ndarray:
55 |     ...
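For orientation, a minimal sketch of a callable satisfying the TransformFunction protocol above; `identity_transform` is illustrative only and not part of the library:

```python
import jax.numpy as jnp


def identity_transform(media_data, custom_priors, **kwargs):
  # Conforms to TransformFunction: accepts the media array plus the custom
  # priors mapping (and any extra keyword arguments) and returns an array of
  # the same shape.
  del custom_priors, kwargs  # An identity transform samples no priors.
  return media_data
```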
56 | 57 | 58 | _INTERCEPT = "intercept" 59 | _COEF_TREND = "coef_trend" 60 | _EXPO_TREND = "expo_trend" 61 | _SIGMA = "sigma" 62 | _GAMMA_SEASONALITY = "gamma_seasonality" 63 | _WEEKDAY = "weekday" 64 | _COEF_EXTRA_FEATURES = "coef_extra_features" 65 | _COEF_SEASONALITY = "coef_seasonality" 66 | 67 | MODEL_PRIORS_NAMES = frozenset(( 68 | _INTERCEPT, 69 | _COEF_TREND, 70 | _EXPO_TREND, 71 | _SIGMA, 72 | _GAMMA_SEASONALITY, 73 | _WEEKDAY, 74 | _COEF_EXTRA_FEATURES, 75 | _COEF_SEASONALITY)) 76 | 77 | _EXPONENT = "exponent" 78 | _LAG_WEIGHT = "lag_weight" 79 | _HALF_MAX_EFFECTIVE_CONCENTRATION = "half_max_effective_concentration" 80 | _SLOPE = "slope" 81 | _AD_EFFECT_RETENTION_RATE = "ad_effect_retention_rate" 82 | _PEAK_EFFECT_DELAY = "peak_effect_delay" 83 | 84 | TRANSFORM_PRIORS_NAMES = immutabledict.immutabledict({ 85 | "carryover": 86 | frozenset((_AD_EFFECT_RETENTION_RATE, _PEAK_EFFECT_DELAY, _EXPONENT)), 87 | "adstock": 88 | frozenset((_EXPONENT, _LAG_WEIGHT)), 89 | "hill_adstock": 90 | frozenset((_LAG_WEIGHT, _HALF_MAX_EFFECTIVE_CONCENTRATION, _SLOPE)) 91 | }) 92 | 93 | GEO_ONLY_PRIORS = frozenset((_COEF_SEASONALITY,)) 94 | 95 | 96 | def _get_default_priors() -> Mapping[str, Prior]: 97 | # Since JAX cannot be called before absl.app.run in tests we get default 98 | # priors from a function. 99 | return immutabledict.immutabledict({ 100 | _INTERCEPT: dist.HalfNormal(scale=2.), 101 | _COEF_TREND: dist.Normal(loc=0., scale=1.), 102 | _EXPO_TREND: dist.Uniform(low=0.5, high=1.5), 103 | _SIGMA: dist.Gamma(concentration=1., rate=1.), 104 | _GAMMA_SEASONALITY: dist.Normal(loc=0., scale=1.), 105 | _WEEKDAY: dist.Normal(loc=0., scale=.5), 106 | _COEF_EXTRA_FEATURES: dist.Normal(loc=0., scale=1.), 107 | _COEF_SEASONALITY: dist.HalfNormal(scale=.5) 108 | }) 109 | 110 | 111 | def _get_transform_default_priors() -> Mapping[str, Prior]: 112 | # Since JAX cannot be called before absl.app.run in tests we get default 113 | # priors from a function. 114 | return immutabledict.immutabledict({ 115 | "carryover": 116 | immutabledict.immutabledict({ 117 | _AD_EFFECT_RETENTION_RATE: 118 | dist.Beta(concentration1=1., concentration0=1.), 119 | _PEAK_EFFECT_DELAY: 120 | dist.HalfNormal(scale=2.), 121 | _EXPONENT: 122 | dist.Beta(concentration1=9., concentration0=1.) 123 | }), 124 | "adstock": 125 | immutabledict.immutabledict({ 126 | _EXPONENT: dist.Beta(concentration1=9., concentration0=1.), 127 | _LAG_WEIGHT: dist.Beta(concentration1=2., concentration0=1.) 128 | }), 129 | "hill_adstock": 130 | immutabledict.immutabledict({ 131 | _LAG_WEIGHT: 132 | dist.Beta(concentration1=2., concentration0=1.), 133 | _HALF_MAX_EFFECTIVE_CONCENTRATION: 134 | dist.Gamma(concentration=1., rate=1.), 135 | _SLOPE: 136 | dist.Gamma(concentration=1., rate=1.) 137 | }) 138 | }) 139 | 140 | 141 | def transform_adstock(media_data: jnp.ndarray, 142 | custom_priors: MutableMapping[str, Prior], 143 | normalise: bool = True) -> jnp.ndarray: 144 | """Transforms the input data with the adstock function and exponent. 145 | 146 | Args: 147 | media_data: Media data to be transformed. It is expected to have 2 dims for 148 | national models and 3 for geo models. 149 | custom_priors: The custom priors we want the model to take instead of the 150 | default ones. The possible names of parameters for adstock and exponent 151 | are "lag_weight" and "exponent". 152 | normalise: Whether to normalise the output values. 153 | 154 | Returns: 155 | The transformed media data. 
156 |   """
157 |   transform_default_priors = _get_transform_default_priors()["adstock"]
158 |   with numpyro.plate(name=f"{_LAG_WEIGHT}_plate",
159 |                      size=media_data.shape[1]):
160 |     lag_weight = numpyro.sample(
161 |         name=_LAG_WEIGHT,
162 |         fn=custom_priors.get(_LAG_WEIGHT,
163 |                              transform_default_priors[_LAG_WEIGHT]))
164 | 
165 |   with numpyro.plate(name=f"{_EXPONENT}_plate",
166 |                      size=media_data.shape[1]):
167 |     exponent = numpyro.sample(
168 |         name=_EXPONENT,
169 |         fn=custom_priors.get(_EXPONENT,
170 |                              transform_default_priors[_EXPONENT]))
171 | 
172 |   if media_data.ndim == 3:
173 |     lag_weight = jnp.expand_dims(lag_weight, axis=-1)
174 |     exponent = jnp.expand_dims(exponent, axis=-1)
175 | 
176 |   adstock = media_transforms.adstock(
177 |       data=media_data, lag_weight=lag_weight, normalise=normalise)
178 | 
179 |   return media_transforms.apply_exponent_safe(data=adstock, exponent=exponent)
180 | 
181 | 
182 | def transform_hill_adstock(media_data: jnp.ndarray,
183 |                            custom_priors: MutableMapping[str, Prior],
184 |                            normalise: bool = True) -> jnp.ndarray:
185 |   """Transforms the input data with the adstock and hill functions.
186 | 
187 |   Args:
188 |     media_data: Media data to be transformed. It is expected to have 2 dims for
189 |       national models and 3 for geo models.
190 |     custom_priors: The custom priors we want the model to take instead of the
191 |       default ones. The possible names of parameters for hill_adstock are
192 |       "lag_weight", "half_max_effective_concentration" and "slope".
193 |     normalise: Whether to normalise the output values.
194 | 
195 |   Returns:
196 |     The transformed media data.
197 |   """
198 |   transform_default_priors = _get_transform_default_priors()["hill_adstock"]
199 |   with numpyro.plate(name=f"{_LAG_WEIGHT}_plate",
200 |                      size=media_data.shape[1]):
201 |     lag_weight = numpyro.sample(
202 |         name=_LAG_WEIGHT,
203 |         fn=custom_priors.get(_LAG_WEIGHT,
204 |                              transform_default_priors[_LAG_WEIGHT]))
205 | 
206 |   with numpyro.plate(name=f"{_HALF_MAX_EFFECTIVE_CONCENTRATION}_plate",
207 |                      size=media_data.shape[1]):
208 |     half_max_effective_concentration = numpyro.sample(
209 |         name=_HALF_MAX_EFFECTIVE_CONCENTRATION,
210 |         fn=custom_priors.get(
211 |             _HALF_MAX_EFFECTIVE_CONCENTRATION,
212 |             transform_default_priors[_HALF_MAX_EFFECTIVE_CONCENTRATION]))
213 | 
214 |   with numpyro.plate(name=f"{_SLOPE}_plate",
215 |                      size=media_data.shape[1]):
216 |     slope = numpyro.sample(
217 |         name=_SLOPE,
218 |         fn=custom_priors.get(_SLOPE, transform_default_priors[_SLOPE]))
219 | 
220 |   if media_data.ndim == 3:
221 |     lag_weight = jnp.expand_dims(lag_weight, axis=-1)
222 |     half_max_effective_concentration = jnp.expand_dims(
223 |         half_max_effective_concentration, axis=-1)
224 |     slope = jnp.expand_dims(slope, axis=-1)
225 | 
226 |   return media_transforms.hill(
227 |       data=media_transforms.adstock(
228 |           data=media_data, lag_weight=lag_weight, normalise=normalise),
229 |       half_max_effective_concentration=half_max_effective_concentration,
230 |       slope=slope)
231 | 
232 | 
233 | def transform_carryover(media_data: jnp.ndarray,
234 |                         custom_priors: MutableMapping[str, Prior],
235 |                         number_lags: int = 13) -> jnp.ndarray:
236 |   """Transforms the input data with the carryover function and exponent.
237 | 
238 |   Args:
239 |     media_data: Media data to be transformed. It is expected to have 2 dims for
240 |       national models and 3 for geo models.
241 |     custom_priors: The custom priors we want the model to take instead of the
242 |       default ones.
The possible names of parameters for carryover and exponent
243 |       are "ad_effect_retention_rate", "peak_effect_delay" and
244 |       "exponent".
245 |     number_lags: Number of lags for the carryover function.
246 | 
247 |   Returns:
248 |     The transformed media data.
249 |   """
250 |   transform_default_priors = _get_transform_default_priors()["carryover"]
251 |   with numpyro.plate(name=f"{_AD_EFFECT_RETENTION_RATE}_plate",
252 |                      size=media_data.shape[1]):
253 |     ad_effect_retention_rate = numpyro.sample(
254 |         name=_AD_EFFECT_RETENTION_RATE,
255 |         fn=custom_priors.get(
256 |             _AD_EFFECT_RETENTION_RATE,
257 |             transform_default_priors[_AD_EFFECT_RETENTION_RATE]))
258 | 
259 |   with numpyro.plate(name=f"{_PEAK_EFFECT_DELAY}_plate",
260 |                      size=media_data.shape[1]):
261 |     peak_effect_delay = numpyro.sample(
262 |         name=_PEAK_EFFECT_DELAY,
263 |         fn=custom_priors.get(
264 |             _PEAK_EFFECT_DELAY, transform_default_priors[_PEAK_EFFECT_DELAY]))
265 | 
266 |   with numpyro.plate(name=f"{_EXPONENT}_plate",
267 |                      size=media_data.shape[1]):
268 |     exponent = numpyro.sample(
269 |         name=_EXPONENT,
270 |         fn=custom_priors.get(_EXPONENT,
271 |                              transform_default_priors[_EXPONENT]))
272 |   carryover = media_transforms.carryover(
273 |       data=media_data,
274 |       ad_effect_retention_rate=ad_effect_retention_rate,
275 |       peak_effect_delay=peak_effect_delay,
276 |       number_lags=number_lags)
277 | 
278 |   if media_data.ndim == 3:
279 |     exponent = jnp.expand_dims(exponent, axis=-1)
280 |   return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent)
281 | 
282 | 
283 | def media_mix_model(
284 |     media_data: jnp.ndarray,
285 |     target_data: jnp.ndarray,
286 |     media_prior: jnp.ndarray,
287 |     degrees_seasonality: int,
288 |     frequency: int,
289 |     transform_function: TransformFunction,
290 |     custom_priors: MutableMapping[str, Prior],
291 |     transform_kwargs: Optional[MutableMapping[str, Any]] = None,
292 |     weekday_seasonality: bool = False,
293 |     extra_features: Optional[jnp.ndarray] = None,
294 | ) -> None:
295 |   """Media mix model.
296 | 
297 |   Args:
298 |     media_data: Media data to be used in the model.
299 |     target_data: Target data for the model.
300 |     media_prior: Cost prior for each of the media channels.
301 |     degrees_seasonality: Number of degrees of seasonality to use.
302 |     frequency: Frequency of the time span which was used to aggregate the data.
303 |       E.g. if weekly data then frequency is 52.
304 |     transform_function: Function to use to transform the media data in the
305 |       model. Currently the following are supported: 'transform_adstock',
306 |       'transform_carryover' and 'transform_hill_adstock'.
307 |     custom_priors: The custom priors we want the model to take instead of the
308 |       default ones. See our custom_priors documentation for details about the
309 |       API and possible options.
310 |     transform_kwargs: Any extra keyword arguments to pass to the transform
311 |       function. For example the adstock function can take a boolean to
312 |       normalise output or not.
313 |     weekday_seasonality: In case of daily data you can estimate a weekday (7)
314 |       parameter.
315 |     extra_features: Extra features data to include in the model.
316 |   """
317 |   default_priors = _get_default_priors()
318 |   data_size = media_data.shape[0]
319 |   n_channels = media_data.shape[1]
320 |   geo_shape = (media_data.shape[2],) if media_data.ndim == 3 else ()
321 |   n_geos = media_data.shape[2] if media_data.ndim == 3 else 1
322 | 
323 |   with numpyro.plate(name=f"{_INTERCEPT}_plate", size=n_geos):
324 |     intercept = numpyro.sample(
325 |         name=_INTERCEPT,
326 |         fn=custom_priors.get(_INTERCEPT, default_priors[_INTERCEPT]))
327 | 
328 |   with numpyro.plate(name=f"{_SIGMA}_plate", size=n_geos):
329 |     sigma = numpyro.sample(
330 |         name=_SIGMA,
331 |         fn=custom_priors.get(_SIGMA, default_priors[_SIGMA]))
332 | 
333 |   # TODO(): Force all geos to have the same trend sign.
334 |   with numpyro.plate(name=f"{_COEF_TREND}_plate", size=n_geos):
335 |     coef_trend = numpyro.sample(
336 |         name=_COEF_TREND,
337 |         fn=custom_priors.get(_COEF_TREND, default_priors[_COEF_TREND]))
338 | 
339 |   expo_trend = numpyro.sample(
340 |       name=_EXPO_TREND,
341 |       fn=custom_priors.get(
342 |           _EXPO_TREND, default_priors[_EXPO_TREND]))
343 | 
344 |   with numpyro.plate(
345 |       name="channel_media_plate",
346 |       size=n_channels,
347 |       dim=-2 if media_data.ndim == 3 else -1):
348 |     coef_media = numpyro.sample(
349 |         name="channel_coef_media" if media_data.ndim == 3 else "coef_media",
350 |         fn=dist.HalfNormal(scale=media_prior))
351 |     if media_data.ndim == 3:
352 |       with numpyro.plate(
353 |           name="geo_media_plate",
354 |           size=n_geos,
355 |           dim=-1):
356 |         # Corrects the mean to be the same as in the channel only case.
357 |         normalisation_factor = jnp.sqrt(2.0 / jnp.pi)
358 |         coef_media = numpyro.sample(
359 |             name="coef_media",
360 |             fn=dist.HalfNormal(scale=coef_media * normalisation_factor)
361 |         )
362 | 
363 |   with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2):
364 |     with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",
365 |                        size=degrees_seasonality):
366 |       gamma_seasonality = numpyro.sample(
367 |           name=_GAMMA_SEASONALITY,
368 |           fn=custom_priors.get(
369 |               _GAMMA_SEASONALITY, default_priors[_GAMMA_SEASONALITY]))
370 | 
371 |   if weekday_seasonality:
372 |     with numpyro.plate(name=f"{_WEEKDAY}_plate", size=7):
373 |       weekday = numpyro.sample(
374 |           name=_WEEKDAY,
375 |           fn=custom_priors.get(_WEEKDAY, default_priors[_WEEKDAY]))
376 |     weekday_series = weekday[jnp.arange(data_size) % 7]
377 |     # In case of daily data, number of lags should be 13*7.
378 |     if transform_function is transform_carryover and transform_kwargs and "number_lags" not in transform_kwargs:
379 |       transform_kwargs["number_lags"] = 13 * 7
380 |     elif transform_function is transform_carryover and not transform_kwargs:
381 |       transform_kwargs = {"number_lags": 13 * 7}
382 | 
383 |   media_transformed = numpyro.deterministic(
384 |       name="media_transformed",
385 |       value=transform_function(media_data,
386 |                                custom_priors=custom_priors,
387 |                                **transform_kwargs if transform_kwargs else {}))
388 |   seasonality = media_transforms.calculate_seasonality(
389 |       number_periods=data_size,
390 |       degrees=degrees_seasonality,
391 |       frequency=frequency,
392 |       gamma_seasonality=gamma_seasonality)
393 |   # For national model's case
394 |   trend = jnp.arange(data_size)
395 |   media_einsum = "tc, c -> t"  # t = time, c = channel
396 |   coef_seasonality = 1
397 | 
398 |   # TODO(): Add conversion of prior for HalfNormal distribution.
399 |   if media_data.ndim == 3:  # For geo model's case
400 |     trend = jnp.expand_dims(trend, axis=-1)
401 |     seasonality = jnp.expand_dims(seasonality, axis=-1)
402 |     media_einsum = "tcg, cg -> tg"  # t = time, c = channel, g = geo
403 |     if weekday_seasonality:
404 |       weekday_series = jnp.expand_dims(weekday_series, axis=-1)
405 |     with numpyro.plate(name="seasonality_plate", size=n_geos):
406 |       coef_seasonality = numpyro.sample(
407 |           name=_COEF_SEASONALITY,
408 |           fn=custom_priors.get(
409 |               _COEF_SEASONALITY, default_priors[_COEF_SEASONALITY]))
410 |   # expo_trend has a Uniform(0.5, 1.5) prior so that the exponent on time is
411 |   # in [.5, 1.5].
412 |   prediction = (
413 |       intercept + coef_trend * trend ** expo_trend +
414 |       seasonality * coef_seasonality +
415 |       jnp.einsum(media_einsum, media_transformed, coef_media))
416 |   if extra_features is not None:
417 |     plate_prefixes = ("extra_feature",)
418 |     extra_features_einsum = "tf, f -> t"  # t = time, f = feature
419 |     extra_features_plates_shape = (extra_features.shape[1],)
420 |     if extra_features.ndim == 3:
421 |       plate_prefixes = ("extra_feature", "geo")
422 |       extra_features_einsum = "tfg, fg -> tg"  # t = time, f = feature, g = geo
423 |       extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
424 |     with numpyro.plate_stack(
425 |         str(plate_prefixes), sizes=list(extra_features_plates_shape)
426 |     ):
427 |       coef_extra_features = numpyro.sample(
428 |           name=_COEF_EXTRA_FEATURES,
429 |           fn=custom_priors.get(
430 |               _COEF_EXTRA_FEATURES, default_priors[_COEF_EXTRA_FEATURES]))
431 |     extra_features_effect = jnp.einsum(extra_features_einsum,
432 |                                        extra_features,
433 |                                        coef_extra_features)
434 |     prediction += extra_features_effect
435 | 
436 |   if weekday_seasonality:
437 |     prediction += weekday_series
438 |   mu = numpyro.deterministic(name="mu", value=prediction)
439 | 
440 |   numpyro.sample(
441 |       name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)
--------------------------------------------------------------------------------
/lightweight_mmm/models_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
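As a minimal sketch of how media_mix_model above is driven in practice (the pattern mirrored by the tests that follow; shapes and hyperparameters here are illustrative only):

```python
import jax
import jax.numpy as jnp
import numpyro

from lightweight_mmm import models

media = jnp.ones((20, 3))    # 20 time periods, 3 media channels.
target = jnp.ones(20)
media_prior = jnp.ones(3)    # E.g. total cost per channel.

kernel = numpyro.infer.NUTS(model=models.media_mix_model)
mcmc = numpyro.infer.MCMC(
    sampler=kernel, num_warmup=10, num_samples=10, num_chains=1)
mcmc.run(
    jax.random.PRNGKey(0),
    media_data=media,
    target_data=target,
    media_prior=media_prior,
    degrees_seasonality=2,
    frequency=52,
    transform_function=models.transform_adstock,
    custom_priors={})
posterior = mcmc.get_samples()  # e.g. posterior["coef_media"].shape == (10, 3)
```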
14 | 15 | """Tests for models.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import jax.numpy as jnp 21 | import numpyro 22 | from numpyro import distributions as dist 23 | from numpyro import handlers 24 | 25 | from lightweight_mmm import models 26 | 27 | 28 | class ModelsTest(parameterized.TestCase): 29 | 30 | @parameterized.named_parameters( 31 | dict(testcase_name="one_channel", shape=(10, 1)), 32 | dict(testcase_name="five_channel", shape=(10, 5)), 33 | dict(testcase_name="same_channels_as_rows", shape=(10, 10)), 34 | dict(testcase_name="geo_shape_1", shape=(10, 10, 5)), 35 | dict(testcase_name="geo_shape_2", shape=(10, 5, 2)), 36 | dict(testcase_name="one_channel_one_row", shape=(1, 1))) 37 | def test_transform_adstock_produces_correct_output_shape(self, shape): 38 | 39 | def mock_model_function(media_data): 40 | numpyro.deterministic( 41 | "transformed_media", 42 | models.transform_adstock(media_data, custom_priors={})) 43 | 44 | media = jnp.ones(shape) 45 | kernel = numpyro.infer.NUTS(model=mock_model_function) 46 | mcmc = numpyro.infer.MCMC( 47 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1) 48 | rng_key = jax.random.PRNGKey(0) 49 | 50 | mcmc.run(rng_key, media_data=media) 51 | transformed_media = mcmc.get_samples()["transformed_media"].mean(axis=0) 52 | 53 | self.assertEqual(media.shape, transformed_media.shape) 54 | 55 | @parameterized.named_parameters( 56 | dict(testcase_name="one_channel", shape=(10, 1)), 57 | dict(testcase_name="five_channel", shape=(10, 5)), 58 | dict(testcase_name="same_channels_as_rows", shape=(10, 10)), 59 | dict(testcase_name="geo_shape_1", shape=(10, 10, 5)), 60 | dict(testcase_name="geo_shape_2", shape=(10, 5, 2)), 61 | dict(testcase_name="one_channel_one_row", shape=(1, 1))) 62 | def test_transform_hill_adstock_produces_correct_output_shape(self, shape): 63 | 64 | def mock_model_function(media_data): 65 | numpyro.deterministic( 66 | "transformed_media", 67 | models.transform_hill_adstock(media_data, custom_priors={})) 68 | 69 | media = jnp.ones(shape) 70 | kernel = numpyro.infer.NUTS(model=mock_model_function) 71 | mcmc = numpyro.infer.MCMC( 72 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1) 73 | rng_key = jax.random.PRNGKey(0) 74 | 75 | mcmc.run(rng_key, media_data=media) 76 | transformed_media = mcmc.get_samples()["transformed_media"].mean(axis=0) 77 | 78 | self.assertEqual(media.shape, transformed_media.shape) 79 | 80 | @parameterized.named_parameters( 81 | dict(testcase_name="one_channel", shape=(10, 1)), 82 | dict(testcase_name="five_channel", shape=(10, 5)), 83 | dict(testcase_name="same_channels_as_rows", shape=(10, 10)), 84 | dict(testcase_name="geo_shape_1", shape=(10, 10, 5)), 85 | dict(testcase_name="geo_shape_2", shape=(10, 5, 2)), 86 | dict(testcase_name="one_channel_one_row", shape=(1, 1))) 87 | def test_transform_carryover_produces_correct_output_shape(self, shape): 88 | 89 | def mock_model_function(media_data): 90 | numpyro.deterministic( 91 | "transformed_media", 92 | models.transform_carryover(media_data, custom_priors={})) 93 | 94 | media = jnp.ones(shape) 95 | kernel = numpyro.infer.NUTS(model=mock_model_function) 96 | mcmc = numpyro.infer.MCMC( 97 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1) 98 | rng_key = jax.random.PRNGKey(0) 99 | 100 | mcmc.run(rng_key, media_data=media) 101 | transformed_media = mcmc.get_samples()["transformed_media"].mean(axis=0) 102 | 103 | self.assertEqual(media.shape, transformed_media.shape) 
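  # Note: the three shape tests above run a short NUTS chain purely to
  # materialise the numpyro sample sites inside each transform; the
  # custom-prior test further down instead uses handlers.seed plus
  # handlers.trace, which records the same sites without any inference.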
104 | 105 | @parameterized.named_parameters( 106 | dict( 107 | testcase_name="national_no_extra", 108 | media_shape=(10, 3), 109 | extra_features_shape=(), 110 | target_shape=(10,), 111 | total_costs_shape=(3,)), 112 | dict( 113 | testcase_name="national_extra", 114 | media_shape=(10, 5), 115 | extra_features_shape=(10, 2), 116 | target_shape=(10,), 117 | total_costs_shape=(5,)), 118 | dict( 119 | testcase_name="geo_extra_3d", 120 | media_shape=(10, 7, 5), 121 | extra_features_shape=(10, 8, 5), 122 | target_shape=(10, 5), 123 | total_costs_shape=(7, 1)), 124 | dict( 125 | testcase_name="geo_no_extra", 126 | media_shape=(10, 7, 5), 127 | extra_features_shape=(), 128 | target_shape=(10, 5), 129 | total_costs_shape=(7, 1))) 130 | def test_media_mix_model_parameters_have_correct_shapes( 131 | self, media_shape, extra_features_shape, target_shape, total_costs_shape): 132 | media = jnp.ones(media_shape) 133 | extra_features = None if not extra_features_shape else jnp.ones( 134 | extra_features_shape) 135 | costs_prior = jnp.ones(total_costs_shape) 136 | degrees = 2 137 | target = jnp.ones(target_shape) 138 | kernel = numpyro.infer.NUTS(model=models.media_mix_model) 139 | mcmc = numpyro.infer.MCMC( 140 | sampler=kernel, num_warmup=10, num_samples=10, num_chains=1) 141 | rng_key = jax.random.PRNGKey(0) 142 | 143 | mcmc.run( 144 | rng_key, 145 | media_data=media, 146 | extra_features=extra_features, 147 | target_data=target, 148 | media_prior=costs_prior, 149 | degrees_seasonality=degrees, 150 | custom_priors={}, 151 | frequency=52, 152 | transform_function=models.transform_carryover) 153 | trace = mcmc.get_samples() 154 | 155 | self.assertEqual( 156 | jnp.squeeze(trace["intercept"].mean(axis=0)).shape, target_shape[1:]) 157 | self.assertEqual( 158 | jnp.squeeze(trace["sigma"].mean(axis=0)).shape, target_shape[1:]) 159 | self.assertEqual( 160 | jnp.squeeze(trace["expo_trend"].mean(axis=0)).shape, ()) 161 | self.assertEqual( 162 | jnp.squeeze(trace["coef_trend"].mean(axis=0)).shape, target_shape[1:]) 163 | self.assertEqual( 164 | jnp.squeeze(trace["coef_media"].mean(axis=0)).shape, media_shape[1:]) 165 | if extra_features_shape: 166 | self.assertEqual(trace["coef_extra_features"].mean(axis=0).shape, 167 | extra_features.shape[1:]) 168 | self.assertEqual(trace["gamma_seasonality"].mean(axis=0).shape, 169 | (degrees, 2)) 170 | self.assertEqual(trace["mu"].mean(axis=0).shape, target_shape) 171 | 172 | @parameterized.named_parameters( 173 | dict( 174 | testcase_name=f"model_{models._INTERCEPT}", 175 | prior_name=models._INTERCEPT, 176 | transform_function=models.transform_carryover), 177 | dict( 178 | testcase_name=f"model_{models._COEF_TREND}", 179 | prior_name=models._COEF_TREND, 180 | transform_function=models.transform_carryover), 181 | dict( 182 | testcase_name=f"model_{models._EXPO_TREND}", 183 | prior_name=models._EXPO_TREND, 184 | transform_function=models.transform_carryover), 185 | dict( 186 | testcase_name=f"model_{models._SIGMA}", 187 | prior_name=models._SIGMA, 188 | transform_function=models.transform_carryover), 189 | dict( 190 | testcase_name=f"model_{models._GAMMA_SEASONALITY}", 191 | prior_name=models._GAMMA_SEASONALITY, 192 | transform_function=models.transform_carryover), 193 | dict( 194 | testcase_name=f"model_{models._WEEKDAY}", 195 | prior_name=models._WEEKDAY, 196 | transform_function=models.transform_carryover), 197 | dict( 198 | testcase_name=f"model_{models._COEF_EXTRA_FEATURES}", 199 | prior_name=models._COEF_EXTRA_FEATURES, 200 | 
transform_function=models.transform_carryover), 201 | dict( 202 | testcase_name=f"model_{models._COEF_SEASONALITY}", 203 | prior_name=models._COEF_SEASONALITY, 204 | transform_function=models.transform_carryover), 205 | dict( 206 | testcase_name=f"carryover_{models._AD_EFFECT_RETENTION_RATE}", 207 | prior_name=models._AD_EFFECT_RETENTION_RATE, 208 | transform_function=models.transform_carryover), 209 | dict( 210 | testcase_name=f"carryover_{models._PEAK_EFFECT_DELAY}", 211 | prior_name=models._PEAK_EFFECT_DELAY, 212 | transform_function=models.transform_carryover), 213 | dict( 214 | testcase_name=f"carryover_{models._EXPONENT}", 215 | prior_name=models._EXPONENT, 216 | transform_function=models.transform_carryover), 217 | dict( 218 | testcase_name=f"adstock_{models._EXPONENT}", 219 | prior_name=models._EXPONENT, 220 | transform_function=models.transform_adstock), 221 | dict( 222 | testcase_name=f"adstock_{models._LAG_WEIGHT}", 223 | prior_name=models._LAG_WEIGHT, 224 | transform_function=models.transform_adstock), 225 | dict( 226 | testcase_name=f"hilladstock_{models._LAG_WEIGHT}", 227 | prior_name=models._LAG_WEIGHT, 228 | transform_function=models.transform_hill_adstock), 229 | dict( 230 | testcase_name=f"hilladstock_{models._HALF_MAX_EFFECTIVE_CONCENTRATION}", 231 | prior_name=models._HALF_MAX_EFFECTIVE_CONCENTRATION, 232 | transform_function=models.transform_hill_adstock), 233 | dict( 234 | testcase_name=f"hilladstock_{models._SLOPE}", 235 | prior_name=models._SLOPE, 236 | transform_function=models.transform_hill_adstock), 237 | ) 238 | def test_media_mix_model_custom_priors_are_taken_correctly( 239 | self, prior_name, transform_function): 240 | expected_value1, expected_value2 = 5.2, 7.56 241 | custom_priors = { 242 | prior_name: dist.Kumaraswamy( 243 | concentration1=expected_value1, concentration0=expected_value2)} 244 | media = jnp.ones((10, 5, 5)) 245 | extra_features = jnp.ones((10, 3, 5)) 246 | costs_prior = jnp.ones((5, 1)) 247 | target = jnp.ones((10, 5)) 248 | 249 | trace_handler = handlers.trace(handlers.seed( 250 | models.media_mix_model, rng_seed=0)) 251 | trace = trace_handler.get_trace( 252 | media_data=media, 253 | extra_features=extra_features, 254 | target_data=target, 255 | media_prior=costs_prior, 256 | custom_priors=custom_priors, 257 | degrees_seasonality=2, 258 | frequency=52, 259 | transform_function=transform_function, 260 | weekday_seasonality=True 261 | ) 262 | values_and_dists = { 263 | name: site["fn"] 264 | for name, site in trace.items() if "fn" in site 265 | } 266 | 267 | used_distribution = values_and_dists[prior_name] 268 | if isinstance(used_distribution, dist.ExpandedDistribution): 269 | used_distribution = used_distribution.base_dist 270 | self.assertIsInstance(used_distribution, dist.Kumaraswamy) 271 | self.assertEqual(used_distribution.concentration0, expected_value2) 272 | self.assertEqual(used_distribution.concentration1, expected_value1) 273 | 274 | 275 | if __name__ == "__main__": 276 | absltest.main() 277 | -------------------------------------------------------------------------------- /lightweight_mmm/optimize_media.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Utilities for optimizing your media based on media mix models."""
16 | import functools
17 | from typing import Optional, Tuple, Union
18 | from absl import logging
19 | import jax
20 | import jax.numpy as jnp
21 | from scipy import optimize
22 | 
23 | from lightweight_mmm import lightweight_mmm
24 | from lightweight_mmm import preprocessing
25 | 
26 | 
27 | @functools.partial(
28 |     jax.jit,
29 |     static_argnames=("media_mix_model", "media_input_shape", "target_scaler",
30 |                      "media_scaler"))
31 | def _objective_function(
32 |     extra_features: jnp.ndarray,
33 |     media_mix_model: lightweight_mmm.LightweightMMM,
34 |     media_input_shape: Tuple[int, int],
35 |     media_gap: Optional[int],
36 |     target_scaler: Optional[preprocessing.CustomScaler],
37 |     media_scaler: preprocessing.CustomScaler,
38 |     geo_ratio: jnp.ndarray,
39 |     seed: Optional[int],
40 |     media_values: jnp.ndarray,
41 | ) -> jnp.float64:
42 |   """Objective function to calculate the sum of all predictions of the model.
43 | 
44 |   Args:
45 |     extra_features: Extra features the model requires for prediction.
46 |     media_mix_model: Media mix model to use. Must have a predict method to be
47 |       used.
48 |     media_input_shape: Input shape of the data required by the model to get
49 |       predictions. This is needed since optimization might flatten some arrays
50 |       and they need to be reshaped before running new predictions.
51 |     media_gap: Media data gap between the end of training data and the start of
52 |       the out of sample media given. E.g. if 100 weeks of data were used for
53 |       training and prediction starts 8 weeks after training data finished we
54 |       need to provide the 8 weeks missing between the training data and the
55 |       prediction data so data transformations (adstock, carryover, ...) can take
56 |       place correctly.
57 |     target_scaler: Scaler that was used to scale the target before training.
58 |     media_scaler: Scaler that was used to scale the media data before training.
59 |     geo_ratio: The ratio to split channel media across geo. Should sum up to 1
60 |       for each channel and should have shape (c, g).
61 |     seed: Seed to use for PRNGKey during sampling. For replicability run
62 |       this function and any other function that gets predictions with the same
63 |       seed.
64 |     media_values: Media values required by the model to run predictions.
65 | 
66 |   Returns:
67 |     The negative value of the sum of all predictions.
68 |   """
69 |   if hasattr(media_mix_model, "n_geos") and media_mix_model.n_geos > 1:
70 |     media_values = geo_ratio * jnp.expand_dims(media_values, axis=-1)
71 |   media_values = jnp.tile(
72 |       media_values / media_input_shape[0], reps=media_input_shape[0])
73 |   # Distribute the budget of each channel across time.
74 |   media_values = jnp.reshape(a=media_values, shape=media_input_shape)
75 |   media_values = media_scaler.transform(media_values)
76 |   return -jnp.sum(
77 |       media_mix_model.predict(
78 |           media=media_values.reshape(media_input_shape),
79 |           extra_features=extra_features,
80 |           media_gap=media_gap,
81 |           target_scaler=target_scaler,
82 |           seed=seed).mean(axis=0))
83 | 
84 | 
85 | @jax.jit
86 | def _budget_constraint(media: jnp.ndarray,
87 |                        prices: jnp.ndarray,
88 |                        budget: jnp.ndarray) -> jnp.float64:
89 |   """Calculates optimization constraint to keep spend equal to the budget.
90 | 
91 |   Args:
92 |     media: Array with the values of the media for this iteration.
93 |     prices: Prices of each media channel at any given time.
94 |     budget: Total budget of the optimization.
95 | 
96 |   Returns:
97 |     The result of subtracting the budget from the total spend.
98 |   """
99 |   media = media.reshape((-1, len(prices)))
100 |   return jnp.sum(media * prices) - budget
101 | 
102 | 
103 | def _get_lower_and_upper_bounds(
104 |     media: jnp.ndarray,
105 |     n_time_periods: int,
106 |     lower_pct: jnp.ndarray,
107 |     upper_pct: jnp.ndarray,
108 |     media_scaler: Optional[preprocessing.CustomScaler] = None
109 | ) -> optimize.Bounds:
110 |   """Gets the lower and upper bounds for optimisation based on historic data.
111 | 
112 |   It creates an upper bound based on a percentage above the mean value on
113 |   each channel and a lower bound based on a relative decrease of the mean
114 |   value.
115 | 
116 |   Args:
117 |     media: Media data to get historic mean.
118 |     n_time_periods: Number of time periods to optimize for. If model is built on
119 |       weekly data, this would be the number of weeks ahead to optimize.
120 |     lower_pct: Relative percentage decrease from the mean value to consider as
121 |       new lower bound.
122 |     upper_pct: Relative percentage increase from the mean value to consider as
123 |       new upper bound.
124 |     media_scaler: Scaler that was used to scale the media data before training.
125 | 
126 |   Returns:
127 |     An optimize.Bounds object with the lower and upper bound per media channel.
128 |   """
129 |   if media.ndim == 3:
130 |     lower_pct = jnp.expand_dims(lower_pct, axis=-1)
131 |     upper_pct = jnp.expand_dims(upper_pct, axis=-1)
132 | 
133 |   mean_data = media.mean(axis=0)
134 |   lower_bounds = jnp.maximum(mean_data * (1 - lower_pct), 0)
135 |   upper_bounds = mean_data * (1 + upper_pct)
136 | 
137 |   if media_scaler:
138 |     lower_bounds = media_scaler.inverse_transform(lower_bounds)
139 |     upper_bounds = media_scaler.inverse_transform(upper_bounds)
140 | 
141 |   if media.ndim == 3:
142 |     lower_bounds = lower_bounds.sum(axis=-1)
143 |     upper_bounds = upper_bounds.sum(axis=-1)
144 | 
145 |   return optimize.Bounds(lb=lower_bounds * n_time_periods,
146 |                          ub=upper_bounds * n_time_periods)
147 | 
148 | 
149 | def _generate_starting_values(
150 |     n_time_periods: int, media: jnp.ndarray,
151 |     media_scaler: preprocessing.CustomScaler,
152 |     budget: Union[float, int],
153 |     prices: jnp.ndarray,
154 | ) -> jnp.ndarray:
155 |   """Generates starting values based on historic allocation and budget.
156 | 
157 |   In order to make a comparison we can take the allocation of the last
158 |   `n_time_periods` and scale it based on the given budget. Given this, one can
159 |   see how these initial values (based on average historic allocation) compare
160 |   to the output of the optimisation in terms of sales/KPI.
161 | 
162 |   Args:
163 |     n_time_periods: Number of time periods the optimization will be done with.
164 |     media: Historic media data the model was trained with.
165 | media_scaler: Scaler that was used to scale the media data before training. 166 | budget: Total budget to allocate during the optimization time. 167 | prices: An array with shape (n_media_channels,) for the cost of each media 168 | channel unit. 169 | 170 | Returns: 171 | An array with the starting value for each media channel for the 172 | optimization. 173 | """ 174 | previous_allocation = media.mean(axis=0) * n_time_periods 175 | if media_scaler: # Scale before sum as geo scaler has shape (c, g). 176 | previous_allocation = media_scaler.inverse_transform(previous_allocation) 177 | 178 | if media.ndim == 3: 179 | previous_allocation = previous_allocation.sum(axis=-1) 180 | 181 | avg_spend_per_channel = previous_allocation * prices 182 | pct_spend_per_channel = avg_spend_per_channel / avg_spend_per_channel.sum() 183 | budget_per_channel = budget * pct_spend_per_channel 184 | media_unit_per_channel = budget_per_channel / prices 185 | return media_unit_per_channel 186 | 187 | 188 | def find_optimal_budgets( 189 | n_time_periods: int, 190 | media_mix_model: lightweight_mmm.LightweightMMM, 191 | budget: Union[float, int], 192 | prices: jnp.ndarray, 193 | extra_features: Optional[jnp.ndarray] = None, 194 | media_gap: Optional[jnp.ndarray] = None, 195 | target_scaler: Optional[preprocessing.CustomScaler] = None, 196 | media_scaler: Optional[preprocessing.CustomScaler] = None, 197 | bounds_lower_pct: Union[float, jnp.ndarray] = .2, 198 | bounds_upper_pct: Union[float, jnp.ndarray] = .2, 199 | max_iterations: int = 200, 200 | solver_func_tolerance: float = 1e-06, 201 | solver_step_size: float = 1.4901161193847656e-08, 202 | seed: Optional[int] = None) -> optimize.OptimizeResult: 203 | """Finds the best media allocation based on MMM model, prices and a budget. 204 | 205 | Args: 206 | n_time_periods: Number of time periods to optimize for. If model is built on 207 | weekly data, this would be the number of weeks ahead to optimize. 208 | media_mix_model: Media mix model to use for the optimization. 209 | budget: Total budget to allocate during the optimization time. 210 | prices: An array with shape (n_media_channels,) for the cost of each media 211 | channel unit. 212 | extra_features: Extra features needed for the model to predict. 213 | media_gap: Media data gap between the end of training data and the start of 214 | the out of sample media given. Eg. if 100 weeks of data were used for 215 | training and prediction starts 8 weeks after training data finished we 216 | need to provide the 8 weeks missing between the training data and the 217 | prediction data so data transformations (adstock, carryover, ...) can take 218 | place correctly. 219 | target_scaler: Scaler that was used to scale the target before training. 220 | media_scaler: Scaler that was used to scale the media data before training. 221 | bounds_lower_pct: Relative percentage decrease from the mean value to 222 | consider as new lower bound. 223 | bounds_upper_pct: Relative percentage increase from the mean value to 224 | consider as new upper bound. 225 | max_iterations: Number of max iterations to use for the SLSQP scipy 226 | optimizer. Default is 200. 227 | solver_func_tolerance: Precision goal for the value of the prediction in 228 | the stopping criterion. Maps directly to scipy's `ftol`. Intended only 229 | for advanced users. For more details see: 230 | https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp. 
231 | solver_step_size: Step size used for numerical approximation of the 232 | Jacobian. Maps directly to scipy's `eps`. Intended only for advanced 233 | users. For more details see: 234 | https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp. 235 | seed: Seed to use for PRNGKey during sampling. For replicability run 236 | this function and any other function that gets predictions with the same 237 | seed. 238 | 239 | Returns: 240 | solution: OptimizeResult object containing the results of the optimization. 241 | kpi_without_optim: Predicted target based on original allocation proportion 242 | among channels from the historical data. 243 | starting_values: Budget Allocation based on original allocation proportion 244 | and the given total budget. 245 | """ 246 | if not hasattr(media_mix_model, "media"): 247 | raise ValueError( 248 | "The passed model has not been trained. Please fit the model before " 249 | "running optimization.") 250 | jax.config.update("jax_enable_x64", True) 251 | 252 | if isinstance(bounds_lower_pct, float): 253 | bounds_lower_pct = jnp.repeat(a=bounds_lower_pct, repeats=len(prices)) 254 | if isinstance(bounds_upper_pct, float): 255 | bounds_upper_pct = jnp.repeat(a=bounds_upper_pct, repeats=len(prices)) 256 | 257 | bounds = _get_lower_and_upper_bounds( 258 | media=media_mix_model.media, 259 | n_time_periods=n_time_periods, 260 | lower_pct=bounds_lower_pct, 261 | upper_pct=bounds_upper_pct, 262 | media_scaler=media_scaler) 263 | if jnp.sum(bounds.lb * prices) > budget: 264 | logging.warning( 265 | "Budget given is smaller than the lower bounds of the constraints for " 266 | "optimization. This will lead to faulty optimization. Please either " 267 | "increase the budget or change the lower bound by increasing the " 268 | "percentage decrease with the `bounds_lower_pct` parameter.") 269 | if jnp.sum(bounds.ub * prices) < budget: 270 | logging.warning( 271 | "Budget given is larger than the upper bounds of the constraints for " 272 | "optimization. This will lead to faulty optimization. 
Please either " 273 | "reduce the budget or change the upper bound by increasing the " 274 | "percentage increase with the `bounds_upper_pct` parameter.") 275 | 276 | starting_values = _generate_starting_values( 277 | n_time_periods=n_time_periods, 278 | media=media_mix_model.media, 279 | media_scaler=media_scaler, 280 | budget=budget, 281 | prices=prices, 282 | ) 283 | if not media_scaler: 284 | media_scaler = preprocessing.CustomScaler(multiply_by=1, divide_by=1) 285 | if media_mix_model.n_geos == 1: 286 | geo_ratio = 1.0 287 | else: 288 | average_per_time = media_mix_model.media.mean(axis=0) 289 | geo_ratio = average_per_time / jnp.expand_dims( 290 | average_per_time.sum(axis=-1), axis=-1) 291 | media_input_shape = (n_time_periods, *media_mix_model.media.shape[1:]) 292 | partial_objective_function = functools.partial( 293 | _objective_function, extra_features, media_mix_model, 294 | media_input_shape, media_gap, 295 | target_scaler, media_scaler, geo_ratio, seed) 296 | solution = optimize.minimize( 297 | fun=partial_objective_function, 298 | x0=starting_values, 299 | bounds=bounds, 300 | method="SLSQP", 301 | jac="3-point", 302 | options={ 303 | "maxiter": max_iterations, 304 | "disp": True, 305 | "ftol": solver_func_tolerance, 306 | "eps": solver_step_size, 307 | }, 308 | constraints={ 309 | "type": "eq", 310 | "fun": _budget_constraint, 311 | "args": (prices, budget) 312 | }) 313 | 314 | kpi_without_optim = _objective_function(extra_features=extra_features, 315 | media_mix_model=media_mix_model, 316 | media_input_shape=media_input_shape, 317 | media_gap=media_gap, 318 | target_scaler=target_scaler, 319 | media_scaler=media_scaler, 320 | seed=seed, 321 | geo_ratio=geo_ratio, 322 | media_values=starting_values) 323 | logging.info("KPI without optimization: %r", -1 * kpi_without_optim.item()) 324 | logging.info("KPI with optimization: %r", -1 * solution.fun) 325 | 326 | jax.config.update("jax_enable_x64", False) 327 | # TODO(yukaabe): Create an object to contain the results of this function. 328 | return solution, kpi_without_optim, starting_values 329 | -------------------------------------------------------------------------------- /lightweight_mmm/optimize_media_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for optimize_media.""" 16 | from unittest import mock 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | from lightweight_mmm import lightweight_mmm 25 | from lightweight_mmm import optimize_media 26 | from lightweight_mmm import preprocessing 27 | 28 | 29 | class OptimizeMediaTest(parameterized.TestCase): 30 | 31 | @classmethod 32 | def setUpClass(cls): 33 | super(OptimizeMediaTest, cls).setUpClass() 34 | cls.national_mmm = lightweight_mmm.LightweightMMM() 35 | cls.national_mmm.fit( 36 | media=jnp.ones((50, 5)), 37 | target=jnp.ones(50), 38 | media_prior=jnp.ones(5) * 50, 39 | number_warmup=2, 40 | number_samples=2, 41 | number_chains=1) 42 | cls.geo_mmm = lightweight_mmm.LightweightMMM() 43 | cls.geo_mmm.fit( 44 | media=jnp.ones((50, 5, 3)), 45 | target=jnp.ones((50, 3)), 46 | media_prior=jnp.ones(5) * 50, 47 | number_warmup=2, 48 | number_samples=2, 49 | number_chains=1) 50 | 51 | def setUp(self): 52 | super().setUp() 53 | self.mock_minimize = self.enter_context( 54 | mock.patch.object(optimize_media.optimize, "minimize", autospec=True)) 55 | 56 | @parameterized.named_parameters([ 57 | dict( 58 | testcase_name="national", 59 | model_name="national_mmm", 60 | geo_ratio=1), 61 | dict( 62 | testcase_name="geo", 63 | model_name="geo_mmm", 64 | geo_ratio=np.tile(0.33, reps=(5, 3))) 65 | ]) 66 | def test_objective_function_generates_correct_value_type_and_sign( 67 | self, model_name, geo_ratio): 68 | 69 | mmm = getattr(self, model_name) 70 | extra_features = mmm._extra_features 71 | time_periods = 10 72 | 73 | kpi_predicted = optimize_media._objective_function( 74 | extra_features=extra_features, 75 | media_mix_model=mmm, 76 | media_input_shape=(time_periods, *mmm.media.shape[1:]), 77 | media_gap=None, 78 | target_scaler=None, 79 | media_scaler=preprocessing.CustomScaler(), 80 | media_values=jnp.ones(mmm.n_media_channels) * time_periods, 81 | geo_ratio=geo_ratio, 82 | seed=10) 83 | 84 | self.assertIsInstance(kpi_predicted, jax.Array) 85 | self.assertLessEqual(kpi_predicted, 0) 86 | self.assertEqual(kpi_predicted.shape, ()) 87 | 88 | @parameterized.named_parameters([ 89 | dict( 90 | testcase_name="zero_output", 91 | media=np.ones(9), 92 | prices=np.array([1, 2, 3]), 93 | budget=18, 94 | expected_value=0), 95 | dict( 96 | testcase_name="negative_output", 97 | media=np.ones(9), 98 | prices=np.array([1, 2, 3]), 99 | budget=20, 100 | expected_value=-2), 101 | dict( 102 | testcase_name="positive_output", 103 | media=np.ones(9), 104 | prices=np.array([1, 2, 3]), 105 | budget=16, 106 | expected_value=2), 107 | dict( 108 | testcase_name="bigger_array", 109 | media=np.ones(18), 110 | prices=np.array([2, 2, 2]), 111 | budget=36, 112 | expected_value=0), 113 | ]) 114 | def test_budget_constraint(self, media, prices, budget, expected_value): 115 | generated_value = optimize_media._budget_constraint( 116 | media=media, prices=prices, budget=budget) 117 | 118 | self.assertEqual(generated_value, expected_value) 119 | 120 | @parameterized.named_parameters([ 121 | dict( 122 | testcase_name="national_media_scaler", 123 | model_name="national_mmm"), 124 | dict( 125 | testcase_name="geo_media_scaler", 126 | model_name="geo_mmm") 127 | ]) 128 | def test_find_optimal_budgets_with_scaler_optimize_called_with_right_params( 129 | self, model_name): 130 | 131 | mmm = getattr(self, model_name) 132 | media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean) 133 | 
media_scaler.fit(2 * jnp.ones((10, *mmm.media.shape[1:]))) 134 | optimize_media.find_optimal_budgets( 135 | n_time_periods=15, 136 | media_mix_model=mmm, 137 | budget=30, 138 | prices=jnp.ones(mmm.n_media_channels), 139 | target_scaler=None, 140 | media_scaler=media_scaler) 141 | 142 | _, call_kwargs = self.mock_minimize.call_args_list[0] 143 | # 15 weeks at +/-20% of a mean of 1 gives 12. and 18. bounds; times 2 (scaler) gives 24. and 36. 144 | np.testing.assert_array_almost_equal(call_kwargs["bounds"].lb, 145 | np.repeat(24., repeats=5) * mmm.n_geos, 146 | decimal=3) 147 | np.testing.assert_array_almost_equal(call_kwargs["bounds"].ub, 148 | np.repeat(36., repeats=5) * mmm.n_geos, 149 | decimal=3) 150 | # We only added a scaler with a divide operation so we only expect x2 in 151 | # the divide_by parameter. 152 | np.testing.assert_array_almost_equal(call_kwargs["fun"].args[5].divide_by, 153 | 2 * jnp.ones(mmm.media.shape[1:]), 154 | decimal=3) 155 | np.testing.assert_array_almost_equal(call_kwargs["fun"].args[5].multiply_by, 156 | jnp.ones(mmm.media.shape[1:]), 157 | decimal=3) 158 | 159 | @parameterized.named_parameters([ 160 | dict( 161 | testcase_name="national", 162 | model_name="national_mmm"), 163 | dict( 164 | testcase_name="geo", 165 | model_name="geo_mmm") 166 | ]) 167 | def test_find_optimal_budgets_without_scaler_optimize_called_with_right_params( 168 | self, model_name): 169 | 170 | mmm = getattr(self, model_name) 171 | optimize_media.find_optimal_budgets( 172 | n_time_periods=15, 173 | media_mix_model=mmm, 174 | budget=30, 175 | prices=jnp.ones(mmm.n_media_channels), 176 | target_scaler=None, 177 | media_scaler=None) 178 | 179 | _, call_kwargs = self.mock_minimize.call_args_list[0] 180 | # 15 weeks at +/-20% of a mean of 1 gives 12. and 18. bounds 181 | np.testing.assert_array_almost_equal( 182 | call_kwargs["bounds"].lb, 183 | np.repeat(12., repeats=5) * mmm.n_geos, 184 | decimal=3) 185 | np.testing.assert_array_almost_equal( 186 | call_kwargs["bounds"].ub, 187 | np.repeat(18., repeats=5) * mmm.n_geos, 188 | decimal=3) 189 | 190 | np.testing.assert_array_almost_equal( 191 | call_kwargs["fun"].args[5].divide_by, 192 | jnp.ones(mmm.n_media_channels), 193 | decimal=3) 194 | np.testing.assert_array_almost_equal( 195 | call_kwargs["fun"].args[5].multiply_by, 196 | jnp.ones(mmm.n_media_channels), 197 | decimal=3) 198 | 199 | @parameterized.named_parameters([ 200 | dict( 201 | testcase_name="national", 202 | model_name="national_mmm"), 203 | dict( 204 | testcase_name="geo", 205 | model_name="geo_mmm") 206 | ]) 207 | def test_predict_called_with_right_args(self, model_name): 208 | mmm = getattr(self, model_name) 209 | 210 | optimize_media.find_optimal_budgets( 211 | n_time_periods=15, 212 | media_mix_model=mmm, 213 | budget=30, 214 | prices=jnp.ones(mmm.n_media_channels), 215 | target_scaler=None, 216 | media_scaler=None) 217 | 218 | @parameterized.named_parameters([ 219 | dict( 220 | testcase_name="national", 221 | model_name="national_mmm"), 222 | dict( 223 | testcase_name="geo", 224 | model_name="geo_mmm") 225 | ]) 226 | def test_budget_lower_than_constraints_warns_user(self, model_name): 227 | mmm = getattr(self, model_name) 228 | expected_warning = ( 229 | "Budget given is smaller than the lower bounds of the constraints for " 230 | "optimization. This will lead to faulty optimization. 
Please either " 231 | "increase the budget or change the lower bound by increasing the " 232 | "percentage decrease with the `bounds_lower_pct` parameter.") 233 | 234 | with self.assertLogs(level="WARNING") as context_manager: 235 | optimize_media.find_optimal_budgets( 236 | n_time_periods=5, 237 | media_mix_model=mmm, 238 | budget=1, 239 | prices=jnp.ones(mmm.n_media_channels)) 240 | self.assertEqual(f"WARNING:absl:{expected_warning}", 241 | context_manager.output[0]) 242 | 243 | @parameterized.named_parameters([ 244 | dict( 245 | testcase_name="national", 246 | model_name="national_mmm"), 247 | dict( 248 | testcase_name="geo", 249 | model_name="geo_mmm") 250 | ]) 251 | def test_budget_higher_than_constraints_warns_user(self, model_name): 252 | mmm = getattr(self, model_name) 253 | expected_warning = ( 254 | "Budget given is larger than the upper bounds of the constraints for " 255 | "optimization. This will lead to faulty optimization. Please either " 256 | "reduce the budget or change the upper bound by increasing the " 257 | "percentage increase with the `bounds_upper_pct` parameter.") 258 | 259 | with self.assertLogs(level="WARNING") as context_manager: 260 | optimize_media.find_optimal_budgets( 261 | n_time_periods=5, 262 | media_mix_model=mmm, 263 | budget=2000, 264 | prices=jnp.ones(5)) 265 | self.assertEqual(f"WARNING:absl:{expected_warning}", 266 | context_manager.output[0]) 267 | 268 | @parameterized.named_parameters([ 269 | dict( 270 | testcase_name="national", 271 | model_name="national_mmm", 272 | expected_len=3), 273 | dict( 274 | testcase_name="geo", 275 | model_name="geo_mmm", 276 | expected_len=3) 277 | ]) 278 | def test_find_optimal_budgets_has_right_output_length_datatype( 279 | self, model_name, expected_len): 280 | 281 | mmm = getattr(self, model_name) 282 | media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean) 283 | media_scaler.fit(2 * jnp.ones((10, *mmm.media.shape[1:]))) 284 | results = optimize_media.find_optimal_budgets( 285 | n_time_periods=15, 286 | media_mix_model=mmm, 287 | budget=30, 288 | prices=jnp.ones(mmm.n_media_channels), 289 | target_scaler=None, 290 | media_scaler=media_scaler) 291 | self.assertLen(results, expected_len) 292 | self.assertIsInstance(results[1], jax.Array) 293 | self.assertIsInstance(results[2], jax.Array) 294 | 295 | @parameterized.named_parameters([ 296 | dict( 297 | testcase_name="national_prices", 298 | model_name="national_mmm", 299 | prices=np.array([1., 0.8, 1.2, 1.5, 0.5]), 300 | ), 301 | dict( 302 | testcase_name="national_ones", 303 | model_name="national_mmm", 304 | prices=np.ones(5), 305 | ), 306 | dict( 307 | testcase_name="geo_prices", 308 | model_name="geo_mmm", 309 | prices=np.array([1., 0.8, 1.2, 1.5, 0.5]), 310 | ), 311 | dict( 312 | testcase_name="geo_ones", 313 | model_name="geo_mmm", 314 | prices=np.ones(5), 315 | ), 316 | ]) 317 | def test_generate_starting_values_calculates_correct_values( 318 | self, model_name, prices): 319 | mmm = getattr(self, model_name) 320 | n_time_periods = 10 321 | budget = mmm.n_media_channels * n_time_periods 322 | starting_values = optimize_media._generate_starting_values( 323 | n_time_periods=10, 324 | media_scaler=None, 325 | media=mmm.media, 326 | budget=budget, 327 | prices=prices, 328 | ) 329 | 330 | # Given that data is all ones, starting values will be equal to prices. 
331 | np.testing.assert_array_almost_equal( 332 | starting_values, jnp.repeat(n_time_periods, repeats=5)) 333 | 334 | if __name__ == "__main__": 335 | absltest.main() 336 | -------------------------------------------------------------------------------- /lightweight_mmm/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Set of utilities for the LightweightMMM package.""" 16 | import pickle 17 | import time 18 | from typing import Any, List, Optional, Tuple 19 | 20 | from absl import logging 21 | from jax import random 22 | import jax.numpy as jnp 23 | import numpy as np 24 | import pandas as pd 25 | from scipy import interpolate 26 | from scipy import optimize 27 | from scipy import spatial 28 | from scipy import stats 29 | from tensorflow.io import gfile 30 | 31 | from lightweight_mmm import media_transforms 32 | 33 | 34 | def save_model( 35 | media_mix_model: Any, 36 | file_path: str 37 | ) -> None: 38 | """Saves the given model in the given path. 39 | 40 | Args: 41 | media_mix_model: Model to save on disk. 42 | file_path: File path where the model should be placed. 43 | """ 44 | with gfile.GFile(file_path, "wb") as file: 45 | pickle.dump(obj=media_mix_model, file=file) 46 | 47 | 48 | def load_model(file_path: str) -> Any: 49 | """Loads a model given a string path. 50 | 51 | Args: 52 | file_path: Path of the file containing the model. 53 | 54 | Returns: 55 | The LightweightMMM object that was stored in the given path. 56 | """ 57 | with gfile.GFile(file_path, "rb") as file: 58 | media_mix_model = pickle.load(file=file) 59 | 60 | for attr in dir(media_mix_model): 61 | if attr.startswith("__"): 62 | continue 63 | attr_value = getattr(media_mix_model, attr) 64 | if isinstance(attr_value, np.ndarray): 65 | setattr(media_mix_model, attr, jnp.array(attr_value)) 66 | 67 | return media_mix_model 68 | 69 | 70 | def get_time_seed() -> int: 71 | """Generates an integer using the last decimals of time.time(). 72 | 73 | Returns: 74 | Integer to be used as seed. 75 | """ 76 | # time.time() has the following format: 1645174953.0429401 77 | return int(str(time.time()).split(".")[1]) 78 | 79 | 80 | def simulate_dummy_data( 81 | data_size: int, 82 | n_media_channels: int, 83 | n_extra_features: int, 84 | geos: int = 1, 85 | seed: int = 5 86 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: 87 | """Simulates dummy data needed for media mix modelling. 88 | 89 | This function's goal is to be super simple and not have many parameters. 90 | It does not generate a fully realistic dataset and is only meant to be 91 | used for demo/tutorial purposes. It uses carryover for lagging but has no 92 | saturation and no trend. 93 | 94 | The data simulated includes the media data, extra features, a target/KPI and 95 | costs. 96 | 97 | Args: 98 | data_size: Number of rows to generate.
99 | n_media_channels: Number of media channels to generate. 100 | n_extra_features: Number of extra features to generate. 101 | geos: Number of geos for geo level data (default = 1 for national). 102 | seed: Random seed. 103 | 104 | Returns: 105 | The simulated media, extra features, target and costs. 106 | """ 107 | if data_size < 1 or n_media_channels < 1 or n_extra_features < 1: 108 | raise ValueError( 109 | "Data size, n_media_channels and n_extra_features must be greater than" 110 | " 0. Please check the values introduced are greater than zero.") 111 | data_offset = int(data_size * 0.2) 112 | data_size += data_offset 113 | key = random.PRNGKey(seed) 114 | sub_keys = random.split(key=key, num=7) 115 | media_data = random.normal(key=sub_keys[0], 116 | shape=(data_size, n_media_channels)) * 1.5 + 20 117 | 118 | extra_features = random.normal(key=sub_keys[1], 119 | shape=(data_size, n_extra_features)) + 5 120 | # Reduce the costs to make ROI realistic. 121 | costs = media_data[data_offset:].sum(axis=0) * .1 122 | 123 | seasonality = media_transforms.calculate_seasonality( 124 | number_periods=data_size, 125 | degrees=2, 126 | frequency=52, 127 | gamma_seasonality=1) 128 | target_noise = random.normal(key=sub_keys[2], shape=(data_size,)) + 3 129 | 130 | # media_data_transformed = media_transforms.adstock(media_data) 131 | media_data_transformed = media_transforms.carryover( 132 | data=media_data, 133 | ad_effect_retention_rate=jnp.full((n_media_channels,), fill_value=.5), 134 | peak_effect_delay=jnp.full((n_media_channels,), fill_value=1.)) 135 | beta_media = random.normal(key=sub_keys[3], shape=(n_media_channels,)) + 1 136 | beta_extra_features = random.normal(key=sub_keys[4], 137 | shape=(n_extra_features,)) 138 | # There is no trend to keep this very simple. 139 | target = 15 + seasonality + media_data_transformed.dot( 140 | beta_media) + extra_features.dot(beta_extra_features) + target_noise 141 | 142 | logging.info("Correlation between transformed media and target") 143 | logging.info([ 144 | np.corrcoef(target[data_offset:], media_data_transformed[data_offset:, 145 | i])[0, 1] 146 | for i in range(n_media_channels) 147 | ]) 148 | 149 | logging.info("True ROI for media channels") 150 | logging.info([ 151 | sum(media_data_transformed[data_offset:, i] * beta_media[i]) / costs[i] 152 | for i in range(n_media_channels) 153 | ]) 154 | 155 | if geos > 1: 156 | # Distribute national data to geo and add some more noise. 157 | weights = random.uniform(key=sub_keys[5], shape=(1, geos)) 158 | weights /= weights.sum() 159 | target_noise = random.normal(key=sub_keys[6], shape=(data_size, geos)) * .5 160 | target = target[:, np.newaxis].dot(weights) + target_noise 161 | media_data = media_data[:, :, np.newaxis].dot(weights) 162 | extra_features = extra_features[:, :, np.newaxis].dot(weights) 163 | 164 | return (media_data[data_offset:], extra_features[data_offset:], 165 | target[data_offset:], costs) 166 | 167 | 168 | def _split_array_into_list( 169 | dataframe: pd.DataFrame, 170 | split_level_feature: str, 171 | features: List[str], 172 | national_model_flag: bool = True) -> jnp.ndarray: 173 | """Splits the data frame by a feature level and stacks into a jax array. 174 | 175 | Args: 176 | dataframe: Dataframe with all the modeling features. 177 | split_level_feature: Feature that will be used to split. 178 | features: List of features to export from the data frame. 179 | national_model_flag: Whether the data frame is used for a national model. 180 | 181 | Returns: 182 | A jax array stacked across the split level (squeezed for national models).
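Example (illustrative shapes, assuming a national dataframe with three dates and features ["tv", "radio"]): each date contributes a transposed block of shape (2, 1), the blocks stack to (3, 2, 1), and the squeeze yields (3, 2); with g geos and national_model_flag=False the result keeps shape (3, 2, g).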
183 | """ 184 | split_level = dataframe[split_level_feature].unique() 185 | array_list_by_level = [ 186 | dataframe.loc[dataframe[split_level_feature] == level, features].values.T 187 | for level in split_level 188 | ] 189 | feature_array = jnp.stack(array_list_by_level) 190 | if national_model_flag: 191 | feature_array = jnp.squeeze(feature_array, axis=2) 192 | return feature_array# jnp-type 193 | 194 | 195 | def dataframe_to_jax( 196 | dataframe: pd.DataFrame, 197 | media_features: List[str], 198 | extra_features: List[str], 199 | date_feature: str, 200 | target: str, 201 | geo_feature: Optional[str] = None, 202 | cost_features: Optional[List[str]] = None 203 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: 204 | """Converts pandas dataframe to right data format for media mix model. 205 | 206 | This function's goal is to convert dataframe which is most familar with data 207 | scientists to jax arrays to help the users who are not familar with array to 208 | use the lightweight MMM library easier. 209 | 210 | Args: 211 | dataframe: Dataframe with geo, KPI, media and non-media features. 212 | media_features: List of media feature names. 213 | extra_features: List of non media feature names. 214 | date_feature: Date feature name. 215 | target: Target variables name. 216 | geo_feature: Geo feature name and it is optional if the data is at national 217 | level. 218 | cost_features: List of media cost variables and it is optional if user 219 | use actual media cost as their media features in the model. 220 | 221 | Returns: 222 | Media, extra features, target and costs arrays. 223 | 224 | Raises: 225 | ValueError: If each geo has unequal number of weeks or there is only one 226 | value in the geo feature. 227 | """ 228 | if geo_feature is not None: 229 | if dataframe[geo_feature].nunique() == 1: 230 | raise ValueError( 231 | "Geo feature has at least two geos or keep default for national model" 232 | ) 233 | count_by_geo = dataframe.groupby( 234 | geo_feature)[date_feature].count().reset_index() 235 | unique_date_count = count_by_geo[date_feature].nunique() 236 | if unique_date_count != 1: 237 | raise ValueError("Not all the geos have same number of weeks.") 238 | national_model_flag = False 239 | features_to_sort = [date_feature, geo_feature] 240 | else: 241 | national_model_flag = True 242 | features_to_sort = [date_feature] 243 | 244 | df_sorted = dataframe.sort_values(by=features_to_sort) 245 | media_features_data = _split_array_into_list( 246 | dataframe=df_sorted, 247 | split_level_feature=date_feature, 248 | features=media_features, 249 | national_model_flag=national_model_flag) 250 | 251 | extra_features_data = _split_array_into_list( 252 | dataframe=df_sorted, 253 | split_level_feature=date_feature, 254 | features=extra_features, 255 | national_model_flag=national_model_flag) 256 | 257 | target_data = _split_array_into_list( 258 | dataframe=df_sorted, 259 | split_level_feature=date_feature, 260 | features=[target], 261 | national_model_flag=national_model_flag) 262 | target_data = jnp.squeeze(target_data)# jnp-type 263 | 264 | if cost_features: 265 | cost_data = jnp.dot( 266 | jnp.full(len(dataframe), 1), dataframe[cost_features].values) 267 | else: 268 | cost_data = jnp.dot( 269 | jnp.full(len(dataframe), 1), dataframe[media_features].values) 270 | return (media_features_data, extra_features_data, target_data, cost_data)# jax-ndarray 271 | 272 | 273 | def get_halfnormal_mean_from_scale(scale: float) -> float: 274 | """Returns the mean of the half-normal 
distribution.""" 275 | # https://en.wikipedia.org/wiki/Half-normal_distribution 276 | return scale * np.sqrt(2) / np.sqrt(np.pi) 277 | 278 | 279 | def get_halfnormal_scale_from_mean(mean: float) -> float: 280 | """Returns the scale of the half-normal distribution.""" 281 | # https://en.wikipedia.org/wiki/Half-normal_distribution 282 | return mean * np.sqrt(np.pi) / np.sqrt(2) 283 | 284 | 285 | def get_beta_params_from_mu_sigma(mu: float, 286 | sigma: float, 287 | bracket: Tuple[float, float] = (.5, 100.) 288 | ) -> Tuple[float, float]: 289 | """Deterministically estimates (a, b) from (mu, sigma) of a beta variable. 290 | 291 | https://en.wikipedia.org/wiki/Beta_distribution 292 | 293 | Args: 294 | mu: The sample mean of the beta distributed variable. 295 | sigma: The sample standard deviation of the beta distributed variable. 296 | bracket: Search bracket for b. 297 | 298 | Returns: 299 | Tuple of the (a, b) parameters. 300 | """ 301 | # Assume a = 1 to find b: _f solves 1 / Var(Beta(1, b)) = 1 / sigma**2. 302 | def _f(x): 303 | return x ** 2 + 4 * x + 5 + 2 / x - 1 / sigma ** 2 304 | b = optimize.root_scalar(_f, bracket=bracket, method="brentq").root 305 | # Given b, now find a better a by matching the mean: a / (a + b) = mu. 306 | a = b / (1 / mu - 1) 307 | return a, b 308 | 309 | 310 | def _estimate_pdf(p: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray: 311 | """Estimates smooth pdf with Gaussian kernel. 312 | 313 | Args: 314 | p: Samples. 315 | x: The continuous x space (sorted). 316 | 317 | Returns: 318 | A density vector. 319 | """ 320 | density = sum(stats.norm(xi).pdf(x) for xi in p) 321 | return density / density.sum() 322 | 323 | 324 | def _pmf(p: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray: 325 | """Estimates discrete pmf. 326 | 327 | Args: 328 | p: Samples. 329 | x: The discrete x space (sorted). 330 | 331 | Returns: 332 | A pmf vector. 333 | """ 334 | p_cdf = jnp.array([jnp.sum(p <= x[i]) for i in range(len(x))]) 335 | p_pmf = np.concatenate([[p_cdf[0]], jnp.diff(p_cdf)]) 336 | return p_pmf / p_pmf.sum() 337 | 338 | 339 | def distance_pior_posterior(p: jnp.ndarray, q: jnp.ndarray, method: str = "KS", 340 | discrete: bool = True) -> float: 341 | """Quantifies the distance between two distributions. 342 | 343 | Note we do not use KL divergence because it's not defined when a probability 344 | is 0. 345 | 346 | https://en.wikipedia.org/wiki/Hellinger_distance 347 | 348 | Args: 349 | p: Samples for distribution 1. 350 | q: Samples for distribution 2. 351 | method: We can have four methods: KS, Hellinger, JS and min. 352 | discrete: Whether input data is discrete or continuous. 353 | 354 | Returns: 355 | The distance metric (between 0 and 1).
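Example (illustrative, with tiny made-up samples):
  p = jnp.array([0., 0., 1., 1.])
  q = jnp.array([0., 1., 1., 1.])
  distance_pior_posterior(p, q, method="KS")  # -> 0.25
The empirical CDFs differ by at most 0.25 (at x = 0).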
356 | """ 357 | 358 | if method == "KS": 359 | # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ks_2samp.html 360 | return stats.ks_2samp(p, q).statistic 361 | elif method in ["Hellinger", "JS", "min"]: 362 | if discrete: 363 | x = jnp.unique(jnp.concatenate((p, q))) 364 | p_pdf = _pmf(p, x) 365 | q_pdf = _pmf(q, x) 366 | else: 367 | minx, maxx = min(p.min(), q.min()), max(p.max(), q.max()) 368 | x = np.linspace(minx, maxx, 100) 369 | p_pdf = _estimate_pdf(p, x) 370 | q_pdf = _estimate_pdf(q, x) 371 | if method == "Hellinger": 372 | return np.sqrt(jnp.sum((np.sqrt(p_pdf) - np.sqrt(q_pdf)) ** 2)) / np.sqrt(2) 373 | elif method == "JS": 374 | # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html 375 | return spatial.distance.jensenshannon(p_pdf, q_pdf) 376 | else: 377 | return 1 - np.minimum(p_pdf, q_pdf).sum() 378 | 379 | 380 | def interpolate_outliers(x: jnp.ndarray, 381 | outlier_idx: jnp.ndarray) -> jnp.ndarray: 382 | """Overwrites outliers in x with interpolated values. 383 | 384 | Args: 385 | x: The original univariate variable with outliers. 386 | outlier_idx: Indices of the outliers in x. 387 | 388 | Returns: 389 | A cleaned x with outliers overwritten. 390 | 391 | """ 392 | time_idx = jnp.arange(len(x)) 393 | inverse_idx = jnp.array([i for i in range(len(x)) if i not in outlier_idx]) 394 | interp_func = interpolate.interp1d( 395 | time_idx[inverse_idx], x[inverse_idx], kind="linear") 396 | x = x.at[outlier_idx].set(interp_func(time_idx[outlier_idx])) 397 | return x 398 | -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Read the Docs configuration file 16 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 17 | version: 2 18 | build: 19 | os: ubuntu-22.04 20 | tools: 21 | python: "3.10" 22 | sphinx: 23 | builder: html 24 | configuration: docs/conf.py 25 | fail_on_warning: false 26 | python: 27 | install: 28 | - requirements: requirements/requirements_docs.txt 29 | - requirements: requirements/requirements.txt 30 | - method: setuptools 31 | path: . -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | absl-py 16 | arviz>=0.11.2 17 | immutabledict>=2.0.0 18 | jax>=0.3.18 19 | jaxlib>=0.3.18 20 | matplotlib==3.6.1 21 | numpy>=1.12 22 | numpyro>=0.9.2 23 | pandas>=1.1.5 24 | scipy 25 | seaborn==0.11.1 26 | scikit-learn 27 | statsmodels>=0.13.0 28 | tensorflow>=2.7.2 -------------------------------------------------------------------------------- /requirements/requirements_docs.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | sphinx==4.5.0 16 | sphinx_rtd_theme==1.0.0 17 | sphinxcontrib-katex==0.8.6 18 | sphinxcontrib-bibtex==1.0.0 19 | sphinx-autodoc-typehints==1.11.1 20 | IPython==7.16.3 21 | ipykernel==5.3.4 22 | pandoc==1.0.2 23 | myst_nb==0.13.1 24 | docutils==0.16 25 | matplotlib==3.6.1 26 | nbsphinx 27 | sphinx_markdown_tables 28 | ipython_genutils -------------------------------------------------------------------------------- /requirements/requirements_tests.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | pytest-xdist 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Setup for lightweight_mmm value package.""" 16 | 17 | import os 18 | 19 | from setuptools import find_packages 20 | from setuptools import setup 21 | 22 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | 24 | 25 | def _get_readme(): 26 | try: 27 | readme = open( 28 | os.path.join(_CURRENT_DIR, "README.md"), encoding="utf-8").read() 29 | except OSError: 30 | readme = "" 31 | return readme 32 | 33 | 34 | def _get_version(): 35 | with open(os.path.join(_CURRENT_DIR, "lightweight_mmm", "__init__.py")) as fp: 36 | for line in fp: 37 | if line.startswith("__version__") and "=" in line: 38 | version = line[line.find("=") + 1:].strip(" '\"\n") 39 | if version: 40 | return version 41 | raise ValueError( 42 | "`__version__` not defined in `lightweight_mmm/__init__.py`") 43 | 44 | 45 | def _parse_requirements(path): 46 | 47 | with open(os.path.join(_CURRENT_DIR, path)) as f: 48 | return [ 49 | line.rstrip() 50 | for line in f 51 | if not (line.isspace() or line.startswith("#")) 52 | ] 53 | 54 | _VERSION = _get_version() 55 | _README = _get_readme() 56 | _INSTALL_REQUIREMENTS = _parse_requirements(os.path.join( 57 | _CURRENT_DIR, "requirements", "requirements.txt")) 58 | _TEST_REQUIREMENTS = _parse_requirements(os.path.join( 59 | _CURRENT_DIR, "requirements", "requirements_tests.txt")) 60 | 61 | setup( 62 | name="lightweight_mmm", 63 | version=_VERSION, 64 | description="Package for Media-Mix-Modelling", 65 | long_description="\n".join([_README]), 66 | long_description_content_type="text/markdown", 67 | author="Google LLC", 68 | author_email="no-reply@google.com", 69 | license="Apache 2.0", 70 | packages=find_packages(), 71 | install_requires=_INSTALL_REQUIREMENTS, 72 | tests_require=_TEST_REQUIREMENTS, 73 | url="https://github.com/google/lightweight_mmm", 74 | classifiers=[ 75 | "Development Status :: 3 - Alpha", 76 | "Intended Audience :: Developers", 77 | "Intended Audience :: Science/Research", 78 | "License :: OSI Approved :: Apache Software License", 79 | "Topic :: Scientific/Engineering :: Mathematics", 80 | "Programming Language :: Python :: 3.8", 81 | "Programming Language :: Python :: 3.9", 82 | "Programming Language :: Python :: 3.10", 83 | 84 | ], 85 | ) 86 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | set -e 18 | 19 | readonly VENV_DIR=/tmp/test_env 20 | echo "Creating virtual environment under ${VENV_DIR}." 21 | echo "You might want to remove this when you no longer need it." 22 | 23 | # Install deps in a virtual env. 24 | python -m venv "${VENV_DIR}" 25 | source "${VENV_DIR}/bin/activate" 26 | python --version 27 | 28 | # Install JAX. 
29 | python -m pip install --upgrade pip setuptools 30 | python -m pip install -r requirements/requirements.txt 31 | python -c 'import jax; print(jax.__version__)' 32 | 33 | # Run setup.py; this installs the package and its Python dependencies. 34 | python -m pip install . 35 | 36 | # Python test dependencies. 37 | python -m pip install -r requirements/requirements_tests.txt 38 | 39 | # Run tests using pytest. 40 | python -m pytest lightweight_mmm --------------------------------------------------------------------------------
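Closing illustration (not part of the repository): a minimal end-to-end sketch of how the utilities and optimizer shown above fit together. The dataset size, sampling settings, budget and prices are arbitrary assumptions for demonstration; the budget is chosen near the historic spend level so the bounds warnings above do not fire.

import jax.numpy as jnp

from lightweight_mmm import lightweight_mmm
from lightweight_mmm import optimize_media
from lightweight_mmm import utils

# Simulate a small national dataset (two years of weekly data, 3 channels).
media, extra_features, target, costs = utils.simulate_dummy_data(
    data_size=104, n_media_channels=3, n_extra_features=2)

# Fit a lightweight MMM, using the channel costs as the media prior.
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media, extra_features=extra_features, target=target,
        media_prior=costs, number_warmup=100, number_samples=100,
        number_chains=1)

# Allocate an assumed budget over the next 8 periods at unit prices.
solution, kpi_without_optim, starting_values = (
    optimize_media.find_optimal_budgets(
        n_time_periods=8,
        media_mix_model=mmm,
        budget=500.0,
        prices=jnp.ones(3),
        extra_features=extra_features[-8:]))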