├── .github └── workflows │ ├── docs.yml │ ├── periodic.yml │ ├── release.yml │ └── testing.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── cluster_experiments ├── __init__.py ├── cupac.py ├── experiment_analysis.py ├── inference │ ├── __init__.py │ ├── analysis_plan.py │ ├── analysis_plan_config.py │ ├── analysis_results.py │ ├── dimension.py │ ├── hypothesis_test.py │ ├── metric.py │ └── variant.py ├── perturbator.py ├── power_analysis.py ├── power_config.py ├── random_splitter.py ├── synthetic_control_utils.py ├── utils.py └── washover.py ├── docs ├── aa_test.ipynb ├── analysis_with_different_hypotheses.ipynb ├── api │ ├── analysis_plan.md │ ├── analysis_results.md │ ├── cupac_model.md │ ├── dimension.md │ ├── experiment_analysis.md │ ├── hypothesis_test.md │ ├── metric.md │ ├── perturbator.md │ ├── power_analysis.md │ ├── power_config.md │ ├── random_splitter.md │ ├── variant.md │ └── washover.md ├── create_custom_classes.ipynb ├── cupac_example.ipynb ├── delta_method.ipynb ├── e2e_mde.ipynb ├── experiment_analysis.ipynb ├── multivariate.ipynb ├── normal_power.ipynb ├── normal_power_lines.ipynb ├── paired_ttest.ipynb ├── plot_calendars.ipynb ├── plot_calendars_hours.ipynb ├── switchback.ipynb ├── synthetic_control.ipynb └── washover_example.ipynb ├── examples ├── cupac_example_gbm.py ├── cupac_example_target_mean.py ├── long_example.py ├── parallel_example.py ├── short_example.py ├── short_example_config.py ├── short_example_dict.py ├── short_example_paired_ttest.py └── short_example_synthetic_control.py ├── mkdocs.yml ├── pyproject.toml ├── ruff.toml ├── setup.py ├── tests ├── __init__.py ├── analysis │ ├── __init__.py │ ├── conftest.py │ ├── test_analysis.py │ ├── test_formula.py │ ├── test_hypothesis.py │ ├── test_ols_analysis.py │ └── test_synthetic_analysis.py ├── cupac │ ├── __init__.py │ ├── conftest.py │ ├── test_aggregator.py │ └── test_cupac_handler.py ├── inference │ ├── __init__.py │ ├── test_analysis_plan.py │ ├── test_analysis_plan_config.py │ ├── test_analysis_results.py │ ├── test_dimension.py │ ├── test_hypothesis_test.py │ ├── test_metric.py │ └── test_variant.py ├── perturbator │ ├── __init__.py │ ├── conftest.py │ └── test_perturbator.py ├── power_analysis │ ├── __init__.py │ ├── conftest.py │ ├── test_cupac_power.py │ ├── test_multivariate.py │ ├── test_normal_power_analysis.py │ ├── test_parallel.py │ ├── test_power_analysis.py │ ├── test_power_analysis_with_pre_experiment_data.py │ ├── test_power_raises.py │ ├── test_seed.py │ └── test_switchback_power.py ├── power_config │ ├── __init__.py │ ├── test_missing_arguments_error.py │ ├── test_params_flow.py │ └── test_warnings_superfluous_params.py ├── splitter │ ├── __init__.py │ ├── conftest.py │ ├── test_fixed_size_clusters_splitter.py │ ├── test_splitter.py │ ├── test_switchback_splitter.py │ ├── test_time_col.py │ └── test_washover.py ├── test_docs.py ├── test_non_clustered.py ├── test_utils.py └── utils.py ├── theme ├── flow.png ├── icon-cluster.png └── icon-cluster.svg ├── tox.ini └── uv.lock /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.9 16 | - run: cp README.md docs/index.md 17 | - run: cp -r theme docs/theme 18 | - run: | 19 | make 
install-dev-gh-action 20 | source .venv/bin/activate 21 | pip install lxml_html_clean 22 | mkdocs gh-deploy --force 23 | -------------------------------------------------------------------------------- /.github/workflows/periodic.yml: -------------------------------------------------------------------------------- 1 | name: Release unit Tests 2 | 3 | on: 4 | schedule: 5 | - cron: "0 0 * * *" 6 | 7 | jobs: 8 | test-release-ubuntu: 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | python-version: ['3.9', '3.10', '3.11', '3.12'] 13 | os: [ubuntu-latest] 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install cluster-experiments 25 | pip freeze 26 | - name: Test with pytest 27 | run: | 28 | make install-test 29 | source .venv/bin/activate 30 | make test 31 | 32 | test-release-windows: 33 | runs-on: ${{ matrix.os }} 34 | strategy: 35 | matrix: 36 | python-version: ['3.9', '3.10', '3.11', '3.12'] 37 | os: [windows-latest] 38 | 39 | steps: 40 | - uses: actions/checkout@v2 41 | - name: Set up Python ${{ matrix.python-version }} 42 | uses: actions/setup-python@v1 43 | with: 44 | python-version: ${{ matrix.python-version }} 45 | - name: Install dependencies 46 | run: | 47 | python -m pip install --upgrade pip 48 | pip install cluster-experiments 49 | pip freeze 50 | - name: Test with pytest 51 | run: | 52 | make install-test 53 | .venv\\Scripts\\activate 54 | make test 55 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release to PyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build-and-deploy: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v2 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.9 20 | 21 | - name: Install dependencies 22 | run: make install-dev-gh-action 23 | 24 | - name: Prepare dist/ 25 | run: make prep-dist 26 | 27 | - name: Publish 28 | uses: pypa/gh-action-pypi-publish@release/v1 29 | with: 30 | skip-existing: true 31 | user: __token__ 32 | password: ${{ secrets.PYPI_API_TOKEN }} 33 | verify-metadata: false 34 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | name: Code Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test-coverage: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ['3.9', '3.10', '3.11', '3.12'] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v1 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install Testing Dependencies 25 | run: make install-test 26 | - name: Automated Checking Mechanism 27 | run: | 28 | source .venv/bin/activate 29 | make check 30 | - name: Code coverage 31 | uses: codecov/codecov-action@v4 32 | with: 33 | fail_ci_if_error: true 34 | files: coverage.xml 35 | env: 36 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 37 | 
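# Note (illustrative): the steps above map to Makefile targets, so the same
# checks can be reproduced locally, assuming uv is installed, with:
#   make install-test && source .venv/bin/activate && make check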
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # VsCode 163 | .vscode/ 164 | 165 | .DS_Store 166 | 167 | docs/index.md 168 | docs/theme/ 169 | todos.txt 170 | 171 | # 172 | experiments/ 173 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.3.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | - repo: https://github.com/psf/black 12 | rev: 25.1.0 13 | hooks: 14 | - id: black 15 | language_version: python3 16 | - repo: https://github.com/charliermarsh/ruff-pre-commit 17 | rev: 'v0.0.261' 18 | hooks: 19 | - id: ruff 20 | args: [--fix, --exit-non-zero-on-fix] 21 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | uv is needed as package manager. 
If you haven't installed it, run the installation command: 4 | 5 | ```bash 6 | curl -LsSf https://astral.sh/uv/install.sh | sh 7 | ``` 8 | 9 | ### Project setup 10 | 11 | Clone repo and go to the project directory: 12 | 13 | ```bash 14 | git clone git@github.com:david26694/cluster-experiments.git 15 | cd cluster-experiments 16 | ``` 17 | 18 | Create virtual environment and activate it: 19 | 20 | ```bash 21 | uv venv -p 3.10 22 | source .venv/bin/activate 23 | ``` 24 | 25 | After creating the virtual environment, install the project dependencies: 26 | 27 | ```bash 28 | make install-dev 29 | ``` 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 david26694 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build 2 | 3 | black: 4 | black cluster_experiments tests setup.py --check 5 | 6 | ruff: 7 | ruff check cluster_experiments tests setup.py 8 | 9 | test: 10 | pytest --cov=./cluster_experiments 11 | 12 | coverage_xml: 13 | coverage xml 14 | 15 | check: black ruff test coverage_xml 16 | 17 | install: 18 | python -m pip install uv 19 | uv sync 20 | 21 | install-dev: 22 | pip install --upgrade pip setuptools wheel 23 | python -m pip install uv 24 | uv sync --extra "dev" 25 | pre-commit install 26 | 27 | install-dev-gh-action: 28 | pip install --upgrade pip setuptools wheel 29 | python -m pip install uv 30 | uv sync --extra "dev" 31 | 32 | install-test: 33 | pip install --upgrade pip setuptools wheel 34 | python -m pip install uv 35 | uv sync --extra "test" 36 | 37 | docs-deploy: 38 | mkdocs gh-deploy 39 | 40 | docs-serve: 41 | cp README.md docs/index.md 42 | rm -rf docs/theme 43 | cp -r theme docs/theme/ 44 | mkdocs serve 45 | 46 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 47 | 48 | clean-build: ## remove build artifacts 49 | rm -fr build/ 50 | rm -fr dist/ 51 | rm -fr .eggs/ 52 | find . 
-name '*.egg-info' -exec rm -fr {} + 53 | find . -name '*.egg' -exec rm -f {} + 54 | 55 | clean-pyc: ## remove Python file artifacts 56 | find . -name '*.pyc' -exec rm -f {} + 57 | find . -name '*.pyo' -exec rm -f {} + 58 | find . -name '*~' -exec rm -f {} + 59 | find . -name '__pycache__' -exec rm -fr {} + 60 | 61 | clean-test: ## remove test and coverage artifacts 62 | rm -fr .tox/ 63 | rm -f .coverage 64 | rm -fr htmlcov/ 65 | rm -fr .pytest_cache 66 | 67 | prep-dist: clean 68 | uv build 69 | 70 | pypi: prep-dist 71 | twine upload --repository cluster-experiments dist/* 72 | 73 | pypi-gh-actions: prep-dist 74 | # todo: fix this 75 | twine upload --skip-existing dist/* 76 | 77 | # Report log 78 | report-log: 79 | pytest --report-log experiments/reportlog.jsonl 80 | 81 | duration-insights: 82 | pytest-duration-insights explore experiments/reportlog.jsonl 83 | -------------------------------------------------------------------------------- /cluster_experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from cluster_experiments.cupac import EmptyRegressor, TargetAggregation 2 | from cluster_experiments.experiment_analysis import ( 3 | ClusteredOLSAnalysis, 4 | DeltaMethodAnalysis, 5 | ExperimentAnalysis, 6 | GeeExperimentAnalysis, 7 | MLMExperimentAnalysis, 8 | OLSAnalysis, 9 | PairedTTestClusteredAnalysis, 10 | SyntheticControlAnalysis, 11 | TTestClusteredAnalysis, 12 | ) 13 | from cluster_experiments.inference.analysis_plan import AnalysisPlan 14 | from cluster_experiments.inference.dimension import Dimension 15 | from cluster_experiments.inference.hypothesis_test import HypothesisTest 16 | from cluster_experiments.inference.metric import Metric, RatioMetric, SimpleMetric 17 | from cluster_experiments.inference.variant import Variant 18 | from cluster_experiments.perturbator import ( 19 | BetaRelativePerturbator, 20 | BetaRelativePositivePerturbator, 21 | BinaryPerturbator, 22 | ConstantPerturbator, 23 | NormalPerturbator, 24 | Perturbator, 25 | RelativeMixedPerturbator, 26 | RelativePositivePerturbator, 27 | SegmentedBetaRelativePerturbator, 28 | UniformPerturbator, 29 | ) 30 | from cluster_experiments.power_analysis import NormalPowerAnalysis, PowerAnalysis 31 | from cluster_experiments.power_config import PowerConfig 32 | from cluster_experiments.random_splitter import ( 33 | BalancedClusteredSplitter, 34 | BalancedSwitchbackSplitter, 35 | ClusteredSplitter, 36 | FixedSizeClusteredSplitter, 37 | NonClusteredSplitter, 38 | RandomSplitter, 39 | RepeatedSampler, 40 | StratifiedClusteredSplitter, 41 | StratifiedSwitchbackSplitter, 42 | SwitchbackSplitter, 43 | ) 44 | from cluster_experiments.washover import ConstantWashover, EmptyWashover, Washover 45 | 46 | __all__ = [ 47 | "ExperimentAnalysis", 48 | "GeeExperimentAnalysis", 49 | "DeltaMethodAnalysis", 50 | "OLSAnalysis", 51 | "BinaryPerturbator", 52 | "Perturbator", 53 | "ConstantPerturbator", 54 | "UniformPerturbator", 55 | "RelativePositivePerturbator", 56 | "NormalPerturbator", 57 | "BetaRelativePositivePerturbator", 58 | "BetaRelativePerturbator", 59 | "SegmentedBetaRelativePerturbator", 60 | "PowerAnalysis", 61 | "NormalPowerAnalysis", 62 | "PowerConfig", 63 | "EmptyRegressor", 64 | "TargetAggregation", 65 | "BalancedClusteredSplitter", 66 | "ClusteredSplitter", 67 | "RandomSplitter", 68 | "NonClusteredSplitter", 69 | "StratifiedClusteredSplitter", 70 | "SwitchbackSplitter", 71 | "BalancedSwitchbackSplitter", 72 | "StratifiedSwitchbackSplitter", 73 | "RepeatedSampler", 74 | 
"ClusteredOLSAnalysis", 75 | "TTestClusteredAnalysis", 76 | "PairedTTestClusteredAnalysis", 77 | "EmptyWashover", 78 | "ConstantWashover", 79 | "Washover", 80 | "MLMExperimentAnalysis", 81 | "SyntheticControlAnalysis", 82 | "FixedSizeClusteredSplitter", 83 | "AnalysisPlan", 84 | "Metric", 85 | "SimpleMetric", 86 | "RatioMetric", 87 | "Dimension", 88 | "Variant", 89 | "HypothesisTest", 90 | "RelativeMixedPerturbator", 91 | ] 92 | -------------------------------------------------------------------------------- /cluster_experiments/cupac.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import pandas as pd 4 | from numpy.typing import ArrayLike 5 | from sklearn.base import BaseEstimator 6 | from sklearn.utils.validation import NotFittedError, check_is_fitted 7 | 8 | 9 | class EmptyRegressor(BaseEstimator): 10 | """ 11 | Empty regressor class. It does not do anything, used to glue the code of other estimators and PowerAnalysis 12 | 13 | Each Regressor should have: 14 | - fit method: Uses pre experiment data to fit some kind of model to be used as a covariate and reduce variance. 15 | - predict method: Uses the fitted model to add the covariate on the experiment data. 16 | 17 | It can add aggregates of the target in older data as a covariate, or a model (cupac) to predict the target. 18 | """ 19 | 20 | @classmethod 21 | def from_config(cls, config): 22 | return cls() 23 | 24 | 25 | class TargetAggregation(BaseEstimator): 26 | """ 27 | Adds average of target using pre-experiment data 28 | 29 | Args: 30 | agg_col: Column to group by to aggregate target 31 | target_col: Column to aggregate 32 | smoothing_factor: Smoothing factor for the smoothed mean 33 | Usage: 34 | ```python 35 | import pandas as pd 36 | from cluster_experiments.cupac import TargetAggregation 37 | 38 | df = pd.DataFrame({"agg_col": ["a", "a", "b", "b", "c", "c"], "target_col": [1, 2, 3, 4, 5, 6]}) 39 | new_df = pd.DataFrame({"agg_col": ["a", "a", "b", "b", "c", "c"]}) 40 | target_agg = TargetAggregation("agg_col", "target_col") 41 | target_agg.fit(df.drop(columns="target_col"), df["target_col"]) 42 | df_with_target_agg = target_agg.predict(new_df) 43 | print(df_with_target_agg) 44 | ``` 45 | """ 46 | 47 | def __init__( 48 | self, 49 | agg_col: str, 50 | target_col: str = "target", 51 | smoothing_factor: int = 20, 52 | ): 53 | self.agg_col = agg_col 54 | self.target_col = target_col 55 | self.smoothing_factor = smoothing_factor 56 | self.is_empty = False 57 | self.mean_target_col = f"{self.target_col}_mean" 58 | self.smooth_mean_target_col = f"{self.target_col}_smooth_mean" 59 | self.pre_experiment_agg_df = pd.DataFrame() 60 | 61 | def _get_pre_experiment_mean(self, pre_experiment_df: pd.DataFrame) -> float: 62 | return pre_experiment_df[self.target_col].mean() 63 | 64 | def fit(self, X: pd.DataFrame, y: pd.Series) -> "TargetAggregation": 65 | """Fits "target encoder" model to pre-experiment data""" 66 | pre_experiment_df = X.copy() 67 | pre_experiment_df[self.target_col] = y 68 | 69 | self.pre_experiment_mean = self._get_pre_experiment_mean(pre_experiment_df) 70 | self.pre_experiment_agg_df = ( 71 | pre_experiment_df.assign(count=1) 72 | .groupby(self.agg_col, as_index=False) 73 | .agg({self.target_col: "sum", "count": "sum"}) 74 | .assign( 75 | **{ 76 | self.mean_target_col: lambda x: x[self.target_col] / x["count"], 77 | self.smooth_mean_target_col: lambda x: ( 78 | x[self.target_col] 79 | + self.smoothing_factor * self.pre_experiment_mean 
80 | ) 81 | / (x["count"] + self.smoothing_factor), 82 | } 83 | ) 84 | .drop(columns=["count", self.target_col]) 85 | ) 86 | return self 87 | 88 | def predict(self, X: pd.DataFrame) -> ArrayLike: 89 | """Adds average target of pre-experiment data to experiment data""" 90 | return ( 91 | X.merge(self.pre_experiment_agg_df, how="left", on=self.agg_col)[ 92 | self.smooth_mean_target_col 93 | ] 94 | .fillna(self.pre_experiment_mean) 95 | .values 96 | ) 97 | 98 | @classmethod 99 | def from_config(cls, config): 100 | """Creates TargetAggregation from PowerConfig""" 101 | return cls( 102 | agg_col=config.agg_col, 103 | target_col=config.target_col, 104 | smoothing_factor=config.smoothing_factor, 105 | ) 106 | 107 | 108 | class CupacHandler: 109 | """ 110 | CupacHandler class. It handles operations related to the cupac model. 111 | 112 | Its main goal is to call the add_covariates method, where it will add the ouptut from the cupac model, 113 | and this should be used as covariates in the regression method for the hypothesis test. 114 | """ 115 | 116 | def __init__( 117 | self, 118 | cupac_model: Optional[BaseEstimator] = None, 119 | target_col: str = "target", 120 | features_cupac_model: Optional[List[str]] = None, 121 | cache_fit: bool = True, 122 | ): 123 | self.cupac_model: BaseEstimator = cupac_model or EmptyRegressor() 124 | self.target_col = target_col 125 | self.cupac_outcome_name = f"estimate_{target_col}" 126 | self.features_cupac_model: List[str] = features_cupac_model or [] 127 | self.is_cupac = not isinstance(self.cupac_model, EmptyRegressor) 128 | self.cache_fit = cache_fit 129 | 130 | def _prep_data_cupac( 131 | self, df: pd.DataFrame, pre_experiment_df: pd.DataFrame 132 | ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]: 133 | """Prepares data for training and prediction""" 134 | df = df.copy() 135 | pre_experiment_df = pre_experiment_df.copy() 136 | df_predict = df.drop(columns=[self.target_col]) 137 | # Split data into X and y 138 | pre_experiment_x = pre_experiment_df.drop(columns=[self.target_col]) 139 | pre_experiment_y = pre_experiment_df[self.target_col] 140 | 141 | # Keep only cupac features 142 | if self.features_cupac_model: 143 | pre_experiment_x = pre_experiment_x[self.features_cupac_model] 144 | df_predict = df_predict[self.features_cupac_model] 145 | 146 | return df_predict, pre_experiment_x, pre_experiment_y 147 | 148 | def add_covariates( 149 | self, df: pd.DataFrame, pre_experiment_df: Optional[pd.DataFrame] = None 150 | ) -> pd.DataFrame: 151 | """ 152 | Train model to predict outcome variable (based on pre-experiment data) 153 | and add the prediction to the experiment dataframe. Only do this if 154 | we use cupac 155 | Args: 156 | pre_experiment_df: Dataframe with pre-experiment data. 157 | df: Dataframe with outcome and treatment variables. 
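        Usage (illustrative sketch; the model choice and column names here are
        assumptions, and any sklearn-style estimator with a fit and
        predict/predict_proba interface can be used instead):
        ```python
        from sklearn.ensemble import HistGradientBoostingRegressor

        handler = CupacHandler(
            cupac_model=HistGradientBoostingRegressor(),
            target_col="target",
        )
        # Fits on pre-experiment data, then adds an "estimate_target"
        # column with the model predictions to the experiment dataframe.
        df_with_covariate = handler.add_covariates(df, pre_experiment_df)
        ```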
158 | """ 159 | self.check_cupac_inputs(pre_experiment_df) 160 | 161 | # Early return if no need to add covariates 162 | if not self.need_covariates(pre_experiment_df): 163 | return df 164 | 165 | df = df.copy() 166 | pre_experiment_df = pre_experiment_df.copy() 167 | df_predict, pre_experiment_x, pre_experiment_y = self._prep_data_cupac( 168 | df=df, pre_experiment_df=pre_experiment_df 169 | ) 170 | 171 | # Fit model if it has not been fitted before 172 | self._fit_cupac_model(pre_experiment_x, pre_experiment_y) 173 | 174 | # Predict 175 | estimated_target = self._predict_cupac_model(df_predict) 176 | 177 | # Add cupac outcome name to df 178 | df[self.cupac_outcome_name] = estimated_target 179 | return df 180 | 181 | def _fit_cupac_model( 182 | self, pre_experiment_x: pd.DataFrame, pre_experiment_y: pd.Series 183 | ): 184 | """Fits the cupac model. 185 | Caches the fitted model in the object, so we only fit it once. 186 | We can disable this by setting cache_fit to False. 187 | """ 188 | if not self.cache_fit: 189 | self.cupac_model.fit(pre_experiment_x, pre_experiment_y) 190 | return 191 | 192 | try: 193 | check_is_fitted(self.cupac_model) 194 | except NotFittedError: 195 | self.cupac_model.fit(pre_experiment_x, pre_experiment_y) 196 | 197 | def _predict_cupac_model(self, df_predict: pd.DataFrame) -> ArrayLike: 198 | """Predicts the cupac model""" 199 | if hasattr(self.cupac_model, "predict_proba"): 200 | return self.cupac_model.predict_proba(df_predict)[:, 1] 201 | if hasattr(self.cupac_model, "predict"): 202 | return self.cupac_model.predict(df_predict) 203 | raise ValueError("cupac_model should have predict or predict_proba method.") 204 | 205 | def need_covariates(self, pre_experiment_df: Optional[pd.DataFrame] = None) -> bool: 206 | return pre_experiment_df is not None and self.is_cupac 207 | 208 | def check_cupac_inputs(self, pre_experiment_df: Optional[pd.DataFrame] = None): 209 | if self.is_cupac and pre_experiment_df is None: 210 | raise ValueError("If cupac is used, pre_experiment_df should be provided.") 211 | 212 | if not self.is_cupac and pre_experiment_df is not None: 213 | raise ValueError( 214 | "If cupac is not used, pre_experiment_df should not be provided - " 215 | "remove pre_experiment_df argument or set cupac_model to not None." 
216 | ) 217 | -------------------------------------------------------------------------------- /cluster_experiments/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/cluster_experiments/inference/__init__.py -------------------------------------------------------------------------------- /cluster_experiments/inference/analysis_plan_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, List, Optional, Union 3 | 4 | 5 | @dataclass(eq=True) 6 | class AnalysisPlanMetricsConfig: 7 | metrics: List[Dict[str, str]] 8 | variants: List[Dict[str, Union[str, bool]]] 9 | analysis_type: str 10 | variant_col: str = "experiment_group" 11 | alpha: float = 0.05 12 | dimensions: List[Dict[str, Union[str, List]]] = field(default_factory=lambda: []) 13 | analysis_config: Dict = field(default_factory=lambda: {}) 14 | custom_analysis_type_mapper: Optional[Dict] = None 15 | 16 | 17 | @dataclass(eq=True) 18 | class AnalysisPlanConfig: 19 | tests: List[Dict[str, Union[List[Dict], Dict, str]]] 20 | variants: List[Dict[str, Union[str, bool]]] 21 | variant_col: str = "experiment_group" 22 | alpha: float = 0.05 23 | -------------------------------------------------------------------------------- /cluster_experiments/inference/analysis_results.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from typing import List 3 | 4 | import pandas as pd 5 | 6 | 7 | @dataclass 8 | class AnalysisPlanResults: 9 | """ 10 | A dataclass used to represent the results of the experiment analysis. 
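    Each attribute is a list with one entry per executed test, so partial results
    can be concatenated with "+" (see __add__ below) and exported via to_dataframe().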
11 | 12 | Attributes 13 | ---------- 14 | metric_alias : List[str] 15 | The alias of the metric used in the test 16 | control_variant_name : List[str] 17 | The name of the control variant 18 | treatment_variant_name : List[str] 19 | The name of the treatment variant 20 | control_variant_mean : List[float] 21 | The mean value of the control variant 22 | treatment_variant_mean : List[float] 23 | The mean value of the treatment variant 24 | analysis_type : List[str] 25 | The type of analysis performed 26 | ate : List[float] 27 | The average treatment effect 28 | ate_ci_lower : List[float] 29 | The lower bound of the confidence interval for the ATE 30 | ate_ci_upper : List[float] 31 | The upper bound of the confidence interval for the ATE 32 | p_value : List[float] 33 | The p-value of the test 34 | std_error : List[float] 35 | The standard error of the test 36 | dimension_name : List[str] 37 | The name of the dimension 38 | dimension_value : List[str] 39 | The value of the dimension 40 | alpha: List[float] 41 | The significance level of the test 42 | """ 43 | 44 | metric_alias: List[str] = field(default_factory=lambda: []) 45 | control_variant_name: List[str] = field(default_factory=lambda: []) 46 | treatment_variant_name: List[str] = field(default_factory=lambda: []) 47 | control_variant_mean: List[float] = field(default_factory=lambda: []) 48 | treatment_variant_mean: List[float] = field(default_factory=lambda: []) 49 | analysis_type: List[str] = field(default_factory=lambda: []) 50 | ate: List[float] = field(default_factory=lambda: []) 51 | ate_ci_lower: List[float] = field(default_factory=lambda: []) 52 | ate_ci_upper: List[float] = field(default_factory=lambda: []) 53 | p_value: List[float] = field(default_factory=lambda: []) 54 | std_error: List[float] = field(default_factory=lambda: []) 55 | dimension_name: List[str] = field(default_factory=lambda: []) 56 | dimension_value: List[str] = field(default_factory=lambda: []) 57 | alpha: List[float] = field(default_factory=lambda: []) 58 | 59 | def __add__(self, other): 60 | if not isinstance(other, AnalysisPlanResults): 61 | return NotImplemented 62 | 63 | return AnalysisPlanResults( 64 | metric_alias=self.metric_alias + other.metric_alias, 65 | control_variant_name=self.control_variant_name + other.control_variant_name, 66 | treatment_variant_name=self.treatment_variant_name 67 | + other.treatment_variant_name, 68 | control_variant_mean=self.control_variant_mean + other.control_variant_mean, 69 | treatment_variant_mean=self.treatment_variant_mean 70 | + other.treatment_variant_mean, 71 | analysis_type=self.analysis_type + other.analysis_type, 72 | ate=self.ate + other.ate, 73 | ate_ci_lower=self.ate_ci_lower + other.ate_ci_lower, 74 | ate_ci_upper=self.ate_ci_upper + other.ate_ci_upper, 75 | p_value=self.p_value + other.p_value, 76 | std_error=self.std_error + other.std_error, 77 | dimension_name=self.dimension_name + other.dimension_name, 78 | dimension_value=self.dimension_value + other.dimension_value, 79 | alpha=self.alpha + other.alpha, 80 | ) 81 | 82 | def to_dataframe(self): 83 | return pd.DataFrame(asdict(self)) 84 | -------------------------------------------------------------------------------- /cluster_experiments/inference/dimension.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | 5 | @dataclass 6 | class Dimension: 7 | """ 8 | A class used to represent a Dimension with a name and values. 
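    For example (illustrative): Dimension(name="country", values=["ES", "IT"])
    lets the analysis results be sliced per country.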
9 | 10 | Attributes 11 | ---------- 12 | name : str 13 | The name of the dimension 14 | values : List[str] 15 | A list of strings representing the possible values of the dimension 16 | """ 17 | 18 | name: str 19 | values: List[str] 20 | 21 | def __post_init__(self): 22 | """ 23 | Validates the inputs after initialization. 24 | """ 25 | self._validate_inputs() 26 | 27 | def _validate_inputs(self): 28 | """ 29 | Validates the inputs for the Dimension class. 30 | 31 | Raises 32 | ------ 33 | TypeError 34 | If the name is not a string or if values is not a list of strings. 35 | """ 36 | if not isinstance(self.name, str): 37 | raise TypeError("Dimension name must be a string") 38 | if not isinstance(self.values, list) or not all( 39 | isinstance(val, str) for val in self.values 40 | ): 41 | raise TypeError("Dimension values must be a list of strings") 42 | 43 | def iterate_dimension_values(self): 44 | """ 45 | A generator method to yield name and values from the dimension. 46 | 47 | Yields 48 | ------ 49 | Any 50 | A unique value from the dimension. 51 | """ 52 | seen = set() 53 | for value in self.values: 54 | if value not in seen: 55 | seen.add(value) 56 | yield value 57 | 58 | @classmethod 59 | def from_metrics_config(cls, config: dict) -> "Dimension": 60 | """ 61 | Creates a Dimension object from a configuration dictionary. 62 | 63 | Parameters 64 | ---------- 65 | config : dict 66 | A dictionary containing the configuration for the Dimension 67 | 68 | Returns 69 | ------- 70 | Dimension 71 | A Dimension object 72 | """ 73 | return cls(name=config["name"], values=config["values"]) 74 | 75 | 76 | @dataclass 77 | class DefaultDimension(Dimension): 78 | """ 79 | A class used to represent a Dimension with a default value representing total, i.e. no slicing. 80 | """ 81 | 82 | def __init__(self): 83 | super().__init__(name="__total_dimension", values=["total"]) 84 | -------------------------------------------------------------------------------- /cluster_experiments/inference/metric.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | import pandas as pd 5 | 6 | 7 | class Metric(ABC): 8 | """ 9 | An abstract base class used to represent a Metric with an alias. 10 | 11 | Attributes 12 | ---------- 13 | alias : str 14 | A string representing the alias of the metric 15 | """ 16 | 17 | def __init__(self, alias: str): 18 | """ 19 | Parameters 20 | ---------- 21 | alias : str 22 | The alias of the metric 23 | """ 24 | self.alias = alias 25 | self._validate_alias() 26 | 27 | def _validate_alias(self): 28 | """ 29 | Validates the alias input for the Metric class. 30 | 31 | Raises 32 | ------ 33 | TypeError 34 | If the alias is not a string 35 | """ 36 | if not isinstance(self.alias, str): 37 | raise TypeError("Metric alias must be a string") 38 | 39 | @property 40 | @abstractmethod 41 | def target_column(self) -> str: 42 | """ 43 | Abstract property to return the target column to feed the experiment analysis class, from the metric definition. 44 | 45 | Returns 46 | ------- 47 | str 48 | The target column name 49 | """ 50 | pass 51 | 52 | @property 53 | def scale_column(self) -> Optional[str]: 54 | """ 55 | Abstract property to return the scale column to feed the experiment analysis class, from the metric definition. 
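        The base implementation returns None (i.e. no scale column); subclasses
        such as RatioMetric override it to return the denominator column.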
56 | 57 | Returns 58 | ------- 59 | str 60 | The scale column name 61 | """ 62 | return None 63 | 64 | @abstractmethod 65 | def get_mean(self, df: pd.DataFrame) -> float: 66 | """ 67 | Abstract method to return the mean value of the metric, given a dataframe. 68 | 69 | Returns 70 | ------- 71 | float 72 | The mean value of the metric 73 | """ 74 | pass 75 | 76 | @classmethod 77 | def from_metrics_config(cls, config: dict) -> "Metric": 78 | """ 79 | Class method to create a Metric instance from a configuration dictionary. 80 | 81 | Parameters 82 | ---------- 83 | config : dict 84 | A dictionary containing the configuration of the metric 85 | 86 | Returns 87 | ------- 88 | Metric 89 | A Metric instance 90 | """ 91 | if "numerator_name" in config: 92 | return RatioMetric.from_metrics_config(config) 93 | return SimpleMetric.from_metrics_config(config) 94 | 95 | 96 | class SimpleMetric(Metric): 97 | """ 98 | A class used to represent a Simple Metric with an alias and a name. 99 | To be used when the metric is defined at the same level of the data used for the analysis. 100 | 101 | Example 102 | ---------- 103 | In a clustered experiment the participants were randomised based on their country of residence. 104 | The metric of interest is the salary of each participant. If the dataset fed into the analysis is at participant-level, 105 | then a SimpleMetric must be used. However, if the dataset fed into the analysis is at country-level, then a RatioMetric must be used. 106 | 107 | Attributes 108 | ---------- 109 | alias : str 110 | A string representing the alias of the metric 111 | name : str 112 | A string representing the name of the metric 113 | """ 114 | 115 | def __init__(self, alias: str, name: str): 116 | """ 117 | Parameters 118 | ---------- 119 | alias : str 120 | The alias of the metric 121 | name : str 122 | The name of the metric 123 | """ 124 | super().__init__(alias) 125 | self.name = name 126 | self._validate_name() 127 | 128 | def _validate_name(self): 129 | """ 130 | Validates the name input for the SimpleMetric class. 131 | 132 | Raises 133 | ------ 134 | TypeError 135 | If the name is not a string 136 | """ 137 | if not isinstance(self.name, str): 138 | raise TypeError("SimpleMetric name must be a string") 139 | 140 | @property 141 | def target_column(self) -> str: 142 | """ 143 | Returns the target column for the SimpleMetric. 144 | 145 | Returns 146 | ------- 147 | str 148 | The name of the metric 149 | """ 150 | return self.name 151 | 152 | def get_mean(self, df: pd.DataFrame) -> float: 153 | """ 154 | Returns the mean value of the metric, given a dataframe. 155 | 156 | Returns 157 | ------- 158 | float 159 | The mean value of the metric 160 | """ 161 | return df[self.name].mean() 162 | 163 | @classmethod 164 | def from_metrics_config(cls, config: dict) -> "Metric": 165 | """ 166 | Class method to create a SimpleMetric instance from a configuration dictionary. 167 | 168 | Parameters 169 | ---------- 170 | config : dict 171 | A dictionary containing the configuration of the metric 172 | 173 | Returns 174 | ------- 175 | SimpleMetric 176 | A SimpleMetric instance 177 | """ 178 | return cls(alias=config["alias"], name=config["name"]) 179 | 180 | 181 | class RatioMetric(Metric): 182 | """ 183 | A class used to represent a Ratio Metric with an alias, a numerator name, and a denominator name. 184 | To be used when the metric is defined at a lower level than the data used for the analysis. 
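    Its mean is computed as mean(numerator) / mean(denominator) over the
    analysis-level rows (see get_mean below).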
185 | 186 | Example 187 | ---------- 188 | In a clustered experiment the participants were randomised based on their country of residence. 189 | The metric of interest is the salary of each participant. If the dataset fed into the analysis is at country-level, 190 | then a RatioMetric must be used: the numerator would be the sum of all salaries in the country, 191 | the denominator would be the number of participants in the country. 192 | 193 | Attributes 194 | ---------- 195 | alias : str 196 | A string representing the alias of the metric 197 | numerator_name : str 198 | A string representing the numerator name of the metric 199 | denominator_name : str 200 | A string representing the denominator name of the metric 201 | """ 202 | 203 | def __init__(self, alias: str, numerator_name: str, denominator_name: str): 204 | """ 205 | Parameters 206 | ---------- 207 | alias : str 208 | The alias of the metric 209 | numerator_name : str 210 | The numerator name of the metric 211 | denominator_name : str 212 | The denominator name of the metric 213 | """ 214 | super().__init__(alias) 215 | self.numerator_name = numerator_name 216 | self.denominator_name = denominator_name 217 | self._validate_names() 218 | 219 | def _validate_names(self): 220 | """ 221 | Validates the numerator and denominator names input for the RatioMetric class. 222 | 223 | Raises 224 | ------ 225 | TypeError 226 | If the numerator or denominator names are not strings 227 | """ 228 | if not isinstance(self.numerator_name, str) or not isinstance( 229 | self.denominator_name, str 230 | ): 231 | raise TypeError("RatioMetric names must be strings") 232 | 233 | @property 234 | def target_column(self) -> str: 235 | """ 236 | Returns the target column for the RatioMetric. 237 | 238 | Returns 239 | ------- 240 | str 241 | The numerator name of the metric 242 | """ 243 | return self.numerator_name 244 | 245 | @property 246 | def scale_column(self) -> str: 247 | """ 248 | Returns the scale column for the RatioMetric. 249 | 250 | Returns 251 | ------- 252 | str 253 | The denominator name of the metric 254 | """ 255 | return self.denominator_name 256 | 257 | def get_mean(self, df: pd.DataFrame) -> float: 258 | """ 259 | Returns the mean value of the metric, given a dataframe. 260 | 261 | Returns 262 | ------- 263 | float 264 | The mean value of the metric 265 | """ 266 | return df[self.numerator_name].mean() / df[self.denominator_name].mean() 267 | 268 | @classmethod 269 | def from_metrics_config(cls, config: dict) -> "Metric": 270 | """ 271 | Class method to create a RatioMetric instance from a configuration dictionary. 272 | 273 | Parameters 274 | ---------- 275 | config : dict 276 | A dictionary containing the configuration of the metric 277 | 278 | Returns 279 | ------- 280 | RatioMetric 281 | A RatioMetric instance 282 | """ 283 | return cls( 284 | alias=config["alias"], 285 | numerator_name=config["numerator_name"], 286 | denominator_name=config["denominator_name"], 287 | ) 288 | -------------------------------------------------------------------------------- /cluster_experiments/inference/variant.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Variant: 6 | """ 7 | A class used to represent a Variant with a name and a control flag. 
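    For example (illustrative): Variant(name="control", is_control=True).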
8 | 9 | Attributes 10 | ---------- 11 | name : str 12 | The name of the variant 13 | is_control : bool 14 | A boolean indicating if the variant is a control variant 15 | """ 16 | 17 | name: str 18 | is_control: bool 19 | 20 | def __post_init__(self): 21 | """ 22 | Validates the inputs after initialization. 23 | """ 24 | self._validate_inputs() 25 | 26 | def _validate_inputs(self): 27 | """ 28 | Validates the inputs for the Variant class. 29 | 30 | Raises 31 | ------ 32 | TypeError 33 | If the name is not a string or if is_control is not a boolean. 34 | """ 35 | if not isinstance(self.name, str): 36 | raise TypeError("Variant name must be a string") 37 | if not isinstance(self.is_control, bool): 38 | raise TypeError("Variant is_control must be a boolean") 39 | 40 | @classmethod 41 | def from_metrics_config(cls, config: dict) -> "Variant": 42 | """ 43 | Creates a Variant object from a configuration dictionary. 44 | 45 | Parameters 46 | ---------- 47 | config : dict 48 | A dictionary containing the configuration for the Variant 49 | 50 | Returns 51 | ------- 52 | Variant 53 | A Variant object 54 | """ 55 | return cls(name=config["name"], is_control=config["is_control"]) 56 | -------------------------------------------------------------------------------- /cluster_experiments/synthetic_control_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import product 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from scipy.optimize import fmin_slsqp 7 | 8 | 9 | def loss_w(W: np.ndarray, X: np.ndarray, y: np.ndarray) -> float: 10 | """ 11 | This function calculates the root mean square error (RMSE) between the actual and predicted values in a linear model. 12 | It is used as an objective function for optimization problems where the goal is to minimize the RMSE. 13 | 14 | Parameters: 15 | W (numpy.ndarray): The weights vector used for predictions. 16 | X (numpy.ndarray): The input data matrix. 17 | y (numpy.ndarray): The actual output vector. 18 | 19 | Returns: 20 | float: The calculated RMSE. 
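    In formula form, matching the implementation below: RMSE = sqrt(mean((y - X @ W) ** 2)).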
21 | """ 22 | return np.sqrt(np.mean((y - X.dot(W)) ** 2)) 23 | 24 | 25 | def get_w(X, y, verbose=False) -> np.ndarray: 26 | """ 27 | Get weights per unit, constraint in the loss function that sum equals 1; bounds 0 and 1) 28 | """ 29 | w_start = np.full(X.shape[1], 1 / X.shape[1]) 30 | bounds = [(0.0, 1.0)] * len(w_start) 31 | 32 | weights = fmin_slsqp( 33 | partial(loss_w, X=X, y=y), 34 | w_start, 35 | f_eqcons=lambda x: np.sum(x) - 1, 36 | bounds=bounds, 37 | disp=verbose, 38 | ) 39 | return weights 40 | 41 | 42 | def generate_synthetic_control_data(N, start_date, end_date): 43 | """Create df for synthetic control cases, where we need a time variable, a target metric, and some clusters.""" 44 | # Generate a list of dates between start_date and end_date 45 | dates = pd.date_range(start_date, end_date, freq="d") 46 | 47 | users = [f"User {i}" for i in range(N)] 48 | 49 | # Create a combination of each date with each user 50 | combinations = list(product(users, dates)) 51 | 52 | target_values = np.random.normal(100, 10, size=len(combinations)) 53 | 54 | df = pd.DataFrame(combinations, columns=["user", "date"]) 55 | df["target"] = target_values 56 | 57 | df["date"] = pd.to_datetime(df["date"]) 58 | 59 | return df 60 | -------------------------------------------------------------------------------- /cluster_experiments/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict 3 | 4 | 5 | def _original_time_column(time_col: str) -> str: 6 | """ 7 | Usage: 8 | ```python 9 | from cluster_experiments.utils import _original_time_column 10 | 11 | assert _original_time_column("hola") == "original___hola" 12 | ``` 13 | """ 14 | return f"original___{time_col}" 15 | 16 | 17 | def _get_mapping_key(mapping, key): 18 | try: 19 | return mapping[key] 20 | except KeyError: 21 | raise KeyError( 22 | f"Could not find {key = } in mapping. 
All options are the following: {list(mapping.keys())}" 23 | ) 24 | 25 | 26 | class HypothesisEntries(Enum): 27 | TWO_SIDED = "two-sided" 28 | LESS = "less" 29 | GREATER = "greater" 30 | 31 | 32 | class ModelResults: 33 | def __init__(self, params: Dict, pvalues: Dict): 34 | self.params = params 35 | self.pvalues = pvalues 36 | -------------------------------------------------------------------------------- /docs/api/analysis_plan.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.inference.analysis_plan import *` 2 | 3 | ::: cluster_experiments.inference.analysis_plan 4 | -------------------------------------------------------------------------------- /docs/api/analysis_results.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.inference.analysis_results import *` 2 | 3 | ::: cluster_experiments.inference.analysis_results 4 | -------------------------------------------------------------------------------- /docs/api/cupac_model.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.cupac import *` 2 | 3 | ::: cluster_experiments.cupac 4 | -------------------------------------------------------------------------------- /docs/api/dimension.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.inference.dimension import *` 2 | 3 | ::: cluster_experiments.inference.dimension 4 | -------------------------------------------------------------------------------- /docs/api/experiment_analysis.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.experiment_analysis import *` 2 | 3 | ::: cluster_experiments.experiment_analysis 4 | -------------------------------------------------------------------------------- /docs/api/hypothesis_test.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.inference.hypothesis_test import *` 2 | 3 | ::: cluster_experiments.inference.hypothesis_test 4 | -------------------------------------------------------------------------------- /docs/api/metric.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.inference.metric import *` 2 | 3 | ::: cluster_experiments.inference.metric 4 | -------------------------------------------------------------------------------- /docs/api/perturbator.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.perturbator import *` 2 | 3 | ::: cluster_experiments.perturbator 4 | -------------------------------------------------------------------------------- /docs/api/power_analysis.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.power_analysis import *` 2 | 3 | ::: cluster_experiments.power_analysis 4 | -------------------------------------------------------------------------------- /docs/api/power_config.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.power_config import *` 2 | 3 | ::: cluster_experiments.power_config 4 | -------------------------------------------------------------------------------- /docs/api/random_splitter.md: 
-------------------------------------------------------------------------------- 1 | # `from cluster_experiments.random_splitter import *` 2 | 3 | ::: cluster_experiments.random_splitter 4 | -------------------------------------------------------------------------------- /docs/api/variant.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.inference.variant import *` 2 | 3 | ::: cluster_experiments.inference.variant 4 | -------------------------------------------------------------------------------- /docs/api/washover.md: -------------------------------------------------------------------------------- 1 | # `from cluster_experiments.washover import *` 2 | 3 | ::: cluster_experiments.washover 4 | -------------------------------------------------------------------------------- /docs/create_custom_classes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "Examples on how to create:\n", 9 | "* a custom perturbator\n", 10 | "* a custom splitter\n", 11 | "* a custom hypothesis test\n", 12 | "\n", 13 | "The names of you custom classes don't need to be CustomX, they are completely free. The only requirement is that they inherit from the base class. For example, if you want to create a custom perturbator, you need to inherit from the Perturbator base class. The same applies to the other classes." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from cluster_experiments import ExperimentAnalysis\n", 23 | "import pandas as pd\n", 24 | "from scipy.stats import ttest_ind\n", 25 | "\n", 26 | "class CustomExperimentAnalysis(ExperimentAnalysis):\n", 27 | " def analysis_pvalue(self, df: pd.DataFrame, verbose: bool = True) -> float:\n", 28 | " treatment_data = df.query(f\"{self.treatment_col} == 1\")[self.target_col]\n", 29 | " control_data = df.query(f\"{self.treatment_col} == 0\")[self.target_col]\n", 30 | " t_test_results = ttest_ind(treatment_data, control_data, equal_var=False)\n", 31 | " return t_test_results.pvalue" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "from cluster_experiments import RandomSplitter\n", 41 | "import numpy as np\n", 42 | "\n", 43 | "class CustomRandomSplitter(RandomSplitter):\n", 44 | " def assign_treatment_df(self, df: pd.DataFrame) -> pd.DataFrame:\n", 45 | " df = df.copy()\n", 46 | " # Power users get treatment with 90% probability\n", 47 | " df_power_users = df.query(\"power_user\")\n", 48 | " df_power_users[self.treatment_col] = np.random.choice(\n", 49 | " [\"A\", \"B\"], size=len(df_power_users), p=[0.1, 0.9]\n", 50 | " )\n", 51 | " # Non-power users get treatment with 10% probability\n", 52 | " df_non_power_users = df.query(\"not power_user\")\n", 53 | " df_non_power_users[self.treatment_col] = np.random.choice(\n", 54 | " [\"A\", \"B\"], size=len(df_non_power_users), p=[0.9, 0.1]\n", 55 | " )\n", 56 | " return pd.concat([df_power_users, df_non_power_users])" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from cluster_experiments import Perturbator\n", 66 | "import pandas as pd\n", 67 | "\n", 68 | "class CustomPerturbator(Perturbator):\n", 69 | " def perturbate(self, df: 
pd.DataFrame, average_effect: float) -> pd.DataFrame:\n", 70 | " df = df.copy().reset_index(drop=True)\n", 71 | " n = (df[self.treatment_col] == self.treatment).sum()\n", 72 | " df.loc[\n", 73 | " df[self.treatment_col] == self.treatment, self.target_col\n", 74 | " ] += np.random.normal(average_effect, 1, size=n)\n", 75 | " return df" 76 | ] 77 | } 78 | ], 79 | "metadata": { 80 | "kernelspec": { 81 | "display_name": "Python 3.8.6 ('venv': venv)", 82 | "language": "python", 83 | "name": "python3" 84 | }, 85 | "language_info": { 86 | "codemirror_mode": { 87 | "name": "ipython", 88 | "version": 3 89 | }, 90 | "file_extension": ".py", 91 | "mimetype": "text/x-python", 92 | "name": "python", 93 | "nbconvert_exporter": "python", 94 | "pygments_lexer": "ipython3", 95 | "version": "3.8.6 (default, Jan 17 2022, 12:11:54) \n[Clang 12.0.5 (clang-1205.0.22.11)]" 96 | }, 97 | "orig_nbformat": 4, 98 | "vscode": { 99 | "interpreter": { 100 | "hash": "29c447d2129f0d56b23b7ba3abc571cfa9d42454e0e2bba301a881797dc4c0e2" 101 | } 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 2 106 | } 107 | -------------------------------------------------------------------------------- /docs/multivariate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "This notebook shows how to use the multivariate module. The idea is to use several treatments in the splitter and only one of them is used to run the hypothesis test." 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import numpy as np\n", 18 | "import pandas as pd\n", 19 | "from cluster_experiments import PowerAnalysis\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# Create fake data\n", 29 | "N = 1_000\n", 30 | "df = pd.DataFrame(\n", 31 | " {\n", 32 | " \"target\": np.random.normal(0, 1, size=N),\n", 33 | " }\n", 34 | ")" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "0.18" 46 | ] 47 | }, 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "# Run power analysis using 3 variants\n", 55 | "config_abc = {\n", 56 | " \"analysis\": \"ols_non_clustered\",\n", 57 | " \"perturbator\": \"constant\",\n", 58 | " \"splitter\": \"non_clustered\",\n", 59 | " \"treatments\": [\"A\", \"B\", \"C\"],\n", 60 | " \"control\": \"A\",\n", 61 | " \"treatment\": \"B\",\n", 62 | " \"n_simulations\": 50,\n", 63 | "}\n", 64 | "\n", 65 | "power_abc = PowerAnalysis.from_dict(config_abc)\n", 66 | "power_abc.power_analysis(df, average_effect=0.1)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "0.28" 78 | ] 79 | }, 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "# Run power analysis using 2 variants\n", 87 | "config_ab = {\n", 88 | " \"analysis\": \"ols_non_clustered\",\n", 89 | " \"perturbator\": \"constant\",\n", 90 | " \"splitter\": \"non_clustered\",\n", 91 | " \"treatments\": [\"A\", \"B\"],\n", 92 | " \"control\": \"A\",\n", 93 | " \"treatment\": \"B\",\n", 94 | " 
\"n_simulations\": 50,\n", 95 | "}\n", 96 | "power_ab = PowerAnalysis.from_dict(config_ab)\n", 97 | "power_ab.power_analysis(df, average_effect=0.1)" 98 | ] 99 | }, 100 | { 101 | "attachments": {}, 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "The power of the AB test is higher than the ABC test, which makes sense." 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [] 112 | } 113 | ], 114 | "metadata": { 115 | "kernelspec": { 116 | "display_name": "venv", 117 | "language": "python", 118 | "name": "python3" 119 | }, 120 | "language_info": { 121 | "codemirror_mode": { 122 | "name": "ipython", 123 | "version": 3 124 | }, 125 | "file_extension": ".py", 126 | "mimetype": "text/x-python", 127 | "name": "python", 128 | "nbconvert_exporter": "python", 129 | "pygments_lexer": "ipython3", 130 | "version": "3.8.6" 131 | }, 132 | "orig_nbformat": 4, 133 | "vscode": { 134 | "interpreter": { 135 | "hash": "29c447d2129f0d56b23b7ba3abc571cfa9d42454e0e2bba301a881797dc4c0e2" 136 | } 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 2 141 | } 142 | -------------------------------------------------------------------------------- /docs/paired_ttest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This notebook shows how the PairedTTestClusteredAnalysis class is performing the paired t test. It's important to get a grasp on the difference between cluster and strata columns.\n" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "pycharm": { 15 | "name": "#%%\n" 16 | } 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "from cluster_experiments.experiment_analysis import PairedTTestClusteredAnalysis\n", 21 | "\n", 22 | "import pandas as pd\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": { 29 | "pycharm": { 30 | "name": "#%%\n" 31 | } 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# Let's generate some fake switchback data (the clusters here would be city and date\n", 36 | "df = pd.DataFrame(\n", 37 | " {\n", 38 | " \"country_code\": [\"ES\"] * 4 + [\"IT\"] * 4 + [\"PL\"] * 4 + [\"RO\"] * 4,\n", 39 | " \"date\": [\"2022-01-01\", \"2022-01-02\", \"2022-01-03\", \"2022-01-04\"] * 4,\n", 40 | " \"treatment\": [\"A\", 'B'] * 8,\n", 41 | " \"target\": [0.01] * 15 + [0.1],\n", 42 | " }\n", 43 | " )" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "Let's see what the PairedTTestClusteredAnalysis class is doing under the hood. As I am passing already the treatment column, there's no need for splitter nor perturbator\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": { 57 | "pycharm": { 58 | "name": "#%%\n" 59 | } 60 | }, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "performing paired t test in this data \n", 67 | " treatment A B\n", 68 | "country_code \n", 69 | "ES 0.01 0.010\n", 70 | "IT 0.01 0.010\n", 71 | "PL 0.01 0.010\n", 72 | "RO 0.01 0.055 \n", 73 | "\n" 74 | ] 75 | }, 76 | { 77 | "data": { 78 | "text/plain": "treatment A B\ncountry_code \nES 0.01 0.010\nIT 0.01 0.010\nPL 0.01 0.010\nRO 0.01 0.055", 79 | "text/html": "
<div>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th>treatment</th>\n <th>A</th>\n <th>B</th>\n </tr>\n <tr>\n <th>country_code</th>\n <th></th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>ES</th>\n <td>0.01</td>\n <td>0.010</td>\n </tr>\n <tr>\n <th>IT</th>\n <td>0.01</td>\n <td>0.010</td>\n </tr>\n <tr>\n <th>PL</th>\n <td>0.01</td>\n <td>0.010</td>\n </tr>\n <tr>\n <th>RO</th>\n <td>0.01</td>\n <td>0.055</td>\n </tr>\n </tbody>\n</table>\n</div>
" 80 | }, 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "analysis = PairedTTestClusteredAnalysis(\n", 88 | " cluster_cols=[\"country_code\", \"date\"], strata_cols = ['country_code']\n", 89 | ")\n", 90 | "\n", 91 | "analysis._preprocessing(df, verbose=True)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "source": [ 97 | "Keep in mind that strata_cols needs to be a subset of cluster_cols and it will be used as the index for pivoting." 98 | ], 99 | "metadata": { 100 | "collapsed": false 101 | } 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "paired t test results: \n", 112 | " TtestResult(statistic=-1.0, pvalue=0.39100221895577053, df=3) \n", 113 | "\n" 114 | ] 115 | }, 116 | { 117 | "data": { 118 | "text/plain": "0.39100221895577053" 119 | }, 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "analysis.analysis_pvalue(df, verbose=True)\n" 127 | ], 128 | "metadata": { 129 | "collapsed": false, 130 | "pycharm": { 131 | "name": "#%%\n" 132 | } 133 | } 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "outputs": [], 139 | "source": [], 140 | "metadata": { 141 | "collapsed": false, 142 | "pycharm": { 143 | "name": "#%%\n" 144 | } 145 | } 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3 (ipykernel)", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.8.6" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 1 169 | } 170 | -------------------------------------------------------------------------------- /examples/cupac_example_gbm.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.ensemble import HistGradientBoostingRegressor 6 | 7 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis 8 | from cluster_experiments.perturbator import ConstantPerturbator 9 | from cluster_experiments.power_analysis import PowerAnalysis 10 | from cluster_experiments.random_splitter import ClusteredSplitter 11 | 12 | 13 | def generate_random_data(clusters, dates, N): 14 | 15 | # Every cluster has a mean 16 | df_clusters = pd.DataFrame( 17 | { 18 | "cluster": clusters, 19 | "cluster_mean": np.random.normal(0, 0.1, size=len(clusters)), 20 | } 21 | ) 22 | # The target is the sum of: user mean, cluster mean and random residual 23 | df = ( 24 | pd.DataFrame( 25 | { 26 | "cluster": np.random.choice(clusters, size=N), 27 | "residual": np.random.normal(0, 1, size=N), 28 | "date": np.random.choice(dates, size=N), 29 | "x1": np.random.normal(0, 1, size=N), 30 | "x2": np.random.normal(0, 1, size=N), 31 | "x3": np.random.normal(0, 1, size=N), 32 | "x4": np.random.normal(0, 1, size=N), 33 | } 34 | ) 35 | .merge(df_clusters, on="cluster") 36 | .assign( 37 | target=lambda x: x["x1"] * x["x2"] 38 | + x["x3"] ** 2 39 | + x["x4"] 40 | + x["cluster_mean"] 41 | + x["residual"] 42 | ) 43 | ) 44 | 45 | return df 46 | 47 | 48 | if 
__name__ == "__main__": 49 | clusters = [f"Cluster {i}" for i in range(100)] 50 | dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 51 | experiment_dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)] 52 | N = 10_000 53 | df = generate_random_data(clusters, dates, N) 54 | df_analysis = df.query(f"date.isin({experiment_dates})") 55 | df_pre = df.query(f"~date.isin({experiment_dates})") 56 | print(df) 57 | 58 | # Splitter and perturbator 59 | sw = ClusteredSplitter( 60 | cluster_cols=["cluster", "date"], 61 | ) 62 | 63 | perturbator = ConstantPerturbator( 64 | average_effect=0.1, 65 | ) 66 | 67 | # Vainilla GEE 68 | analysis = GeeExperimentAnalysis( 69 | cluster_cols=["cluster", "date"], 70 | ) 71 | pw_vainilla = PowerAnalysis( 72 | perturbator=perturbator, 73 | splitter=sw, 74 | analysis=analysis, 75 | n_simulations=50, 76 | ) 77 | 78 | power = pw_vainilla.power_analysis(df_analysis) 79 | print(f"Not using cupac: {power = }") 80 | 81 | # Cupac GEE 82 | analysis = GeeExperimentAnalysis( 83 | cluster_cols=["cluster", "date"], covariates=["estimate_target"] 84 | ) 85 | 86 | gbm = HistGradientBoostingRegressor() 87 | pw_cupac = PowerAnalysis( 88 | perturbator=perturbator, 89 | splitter=sw, 90 | analysis=analysis, 91 | n_simulations=50, 92 | cupac_model=gbm, 93 | features_cupac_model=["x1", "x2", "x3", "x4"], 94 | ) 95 | 96 | power = pw_cupac.power_analysis(df_analysis, df_pre) 97 | print(f"Using cupac: {power = }") 98 | -------------------------------------------------------------------------------- /examples/cupac_example_target_mean.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cluster_experiments.cupac import TargetAggregation 7 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis 8 | from cluster_experiments.perturbator import ConstantPerturbator 9 | from cluster_experiments.power_analysis import PowerAnalysis 10 | from cluster_experiments.random_splitter import ClusteredSplitter 11 | 12 | 13 | def generate_random_data(clusters, dates, N, n_users=1000): 14 | # Generate random data with clusters and target 15 | users = [f"User {i}" for i in range(n_users)] 16 | 17 | # Every user has a mean 18 | df_users = pd.DataFrame( 19 | {"user": users, "user_mean": np.random.normal(0, 3, size=n_users)} 20 | ) 21 | 22 | # Every cluster has a mean 23 | df_clusters = pd.DataFrame( 24 | { 25 | "cluster": clusters, 26 | "cluster_mean": np.random.normal(0, 0.1, size=len(clusters)), 27 | } 28 | ) 29 | # The target is the sum of: user mean, cluster mean and random residual 30 | df = ( 31 | pd.DataFrame( 32 | { 33 | "cluster": np.random.choice(clusters, size=N), 34 | "residual": np.random.normal(0, 1, size=N), 35 | "user": np.random.choice(users, size=N), 36 | "date": np.random.choice(dates, size=N), 37 | } 38 | ) 39 | .merge(df_users, on="user") 40 | .merge(df_clusters, on="cluster") 41 | .assign(target=lambda x: x["residual"] + x["user_mean"] + x["cluster_mean"]) 42 | ) 43 | 44 | return df 45 | 46 | 47 | if __name__ == "__main__": 48 | clusters = [f"Cluster {i}" for i in range(100)] 49 | dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 50 | experiment_dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)] 51 | N = 10_000 52 | df = generate_random_data(clusters, dates, N) 53 | df_analysis = df.query(f"date.isin({experiment_dates})") 54 | df_pre = df.query(f"~date.isin({experiment_dates})") 55 | print(df) 56 | 57 
| # Splitter and perturbator 58 | sw = ClusteredSplitter( 59 | cluster_cols=["cluster", "date"], 60 | ) 61 | 62 | perturbator = ConstantPerturbator( 63 | average_effect=0.1, 64 | ) 65 | 66 | # Vanilla GEE 67 | analysis = GeeExperimentAnalysis( 68 | cluster_cols=["cluster", "date"], 69 | ) 70 | pw_vanilla = PowerAnalysis( 71 | perturbator=perturbator, 72 | splitter=sw, 73 | analysis=analysis, 74 | n_simulations=50, 75 | ) 76 | 77 | power = pw_vanilla.power_analysis(df_analysis) 78 | print(f"Not using cupac: {power = }") 79 | 80 | # Cupac GEE 81 | analysis = GeeExperimentAnalysis( 82 | cluster_cols=["cluster", "date"], covariates=["estimate_target"] 83 | ) 84 | 85 | target_agg = TargetAggregation(target_col="target", agg_col="user") 86 | pw_cupac = PowerAnalysis( 87 | perturbator=perturbator, 88 | splitter=sw, 89 | analysis=analysis, 90 | n_simulations=50, 91 | cupac_model=target_agg, 92 | ) 93 | 94 | power = pw_cupac.power_analysis(df_analysis, df_pre) 95 | print(f"Using cupac: {power = }") 96 | -------------------------------------------------------------------------------- /examples/long_example.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis 7 | from cluster_experiments.perturbator import ConstantPerturbator 8 | from cluster_experiments.power_analysis import PowerAnalysis 9 | from cluster_experiments.random_splitter import ClusteredSplitter 10 | 11 | 12 | def generate_random_data(clusters, dates, N): 13 | # Generate random data with clusters and target 14 | users = [f"User {i}" for i in range(1000)] 15 | df = pd.DataFrame( 16 | { 17 | "cluster": np.random.choice(clusters, size=N), 18 | "target": np.random.normal(0, 1, size=N), 19 | "user": np.random.choice(users, size=N), 20 | "date": np.random.choice(dates, size=N), 21 | } 22 | ) 23 | 24 | return df 25 | 26 | 27 | if __name__ == "__main__": 28 | clusters = [f"Cluster {i}" for i in range(100)] 29 | dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 30 | experiment_dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)] 31 | N = 10_000 32 | df = generate_random_data(clusters, dates, N) 33 | sw = ClusteredSplitter( 34 | treatments=["A", "B"], 35 | cluster_cols=["cluster", "date"], 36 | ) 37 | 38 | treatment_assignment_df = sw.assign_treatment_df(df) 39 | # NaNs because of data prior to the experiment 40 | print(treatment_assignment_df) 41 | 42 | perturbator = ConstantPerturbator( 43 | average_effect=0.01, 44 | target_col="target", 45 | treatment_col="treatment", 46 | ) 47 | 48 | perturbated_df = perturbator.perturbate(treatment_assignment_df) 49 | print(perturbated_df.groupby(["treatment"]).mean()) 50 | 51 | analysis = GeeExperimentAnalysis( 52 | target_col="target", 53 | treatment_col="treatment", 54 | cluster_cols=["cluster", "date"], 55 | ) 56 | 57 | p_val = analysis.get_pvalue(perturbated_df.query("treatment.notnull()")) 58 | print(f"{p_val = }") 59 | 60 | pw = PowerAnalysis( 61 | target_col="target", 62 | treatment_col="treatment", 63 | treatment="B", 64 | perturbator=perturbator, 65 | splitter=sw, 66 | analysis=analysis, 67 | ) 68 | 69 | print(df) 70 | power = pw.power_analysis(df) 71 | print(f"{power = }") 72 | -------------------------------------------------------------------------------- /examples/parallel_example.py: -------------------------------------------------------------------------------- 1 | # %% 2 |
import time 3 | from datetime import date 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis 9 | from cluster_experiments.perturbator import ConstantPerturbator 10 | from cluster_experiments.power_analysis import PowerAnalysis 11 | from cluster_experiments.random_splitter import ClusteredSplitter 12 | 13 | 14 | def generate_random_data(clusters, dates, N): 15 | # Generate random data with clusters and target 16 | users = [f"User {i}" for i in range(1000)] 17 | df = pd.DataFrame( 18 | { 19 | "cluster": np.random.choice(clusters, size=N), 20 | "target": np.random.normal(0, 1, size=N), 21 | "user": np.random.choice(users, size=N), 22 | "date": np.random.choice(dates, size=N), 23 | } 24 | ) 25 | 26 | return df 27 | 28 | 29 | # %% 30 | clusters = [f"Cluster {i}" for i in range(1000)] 31 | dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 32 | N = 1_000_000 33 | df = generate_random_data(clusters, dates, N) 34 | sw = ClusteredSplitter( 35 | cluster_cols=["cluster", "date"], 36 | ) 37 | 38 | perturbator = ConstantPerturbator( 39 | average_effect=0.1, 40 | ) 41 | 42 | analysis = GeeExperimentAnalysis( 43 | cluster_cols=["cluster", "date"], 44 | ) 45 | 46 | pw = PowerAnalysis(perturbator=perturbator, splitter=sw, analysis=analysis) 47 | 48 | print(df) 49 | # %% 50 | 51 | if __name__ == "__main__": 52 | 53 | n_simulations = 16 54 | n_jobs = 8 55 | # parallel_start = time.time() 56 | # parallel_sim = pw.power_analysis_parallel( 57 | # df=df, n_simulations=n_simulations, average_effect=-0.01, n_jobs=16 58 | # ) 59 | # parallel_end = time.time() 60 | # print("Parallel execution finished") 61 | # parallel_duration = parallel_end - parallel_start 62 | # print(f"{parallel_duration=}") 63 | 64 | non_parallel_start = time.time() 65 | simple_sim = pw.power_analysis( 66 | df=df, n_simulations=n_simulations, average_effect=-0.01 67 | ) 68 | non_parallel_end = time.time() 69 | print("Non Parallel execution finished") 70 | non_parallel_duration = non_parallel_end - non_parallel_start 71 | print(f"{non_parallel_duration=}") 72 | 73 | parallel_start = time.time() 74 | parallel_sim = pw.power_analysis( 75 | df=df, 76 | n_simulations=n_simulations, 77 | average_effect=-0.01, 78 | n_jobs=n_jobs, 79 | ) 80 | parallel_end = time.time() 81 | print("Parallel mp execution finished") 82 | parallel_duration = parallel_end - parallel_start 83 | print(f"{parallel_duration=}") 84 | -------------------------------------------------------------------------------- /examples/short_example.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis 7 | from cluster_experiments.perturbator import ConstantPerturbator 8 | from cluster_experiments.power_analysis import PowerAnalysis 9 | from cluster_experiments.random_splitter import ClusteredSplitter 10 | 11 | 12 | def generate_random_data(clusters, dates, N): 13 | # Generate random data with clusters and target 14 | users = [f"User {i}" for i in range(1000)] 15 | df = pd.DataFrame( 16 | { 17 | "cluster": np.random.choice(clusters, size=N), 18 | "target": np.random.normal(0, 1, size=N), 19 | "user": np.random.choice(users, size=N), 20 | "date": np.random.choice(dates, size=N), 21 | } 22 | ) 23 | 24 | return df 25 | 26 | 27 | if __name__ == "__main__": 28 | clusters = [f"Cluster {i}" for i in range(100)] 29 
| dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 30 | N = 10_000 31 | df = generate_random_data(clusters, dates, N) 32 | sw = ClusteredSplitter( 33 | cluster_cols=["cluster", "date"], 34 | ) 35 | 36 | perturbator = ConstantPerturbator( 37 | average_effect=0.1, 38 | ) 39 | 40 | analysis = GeeExperimentAnalysis( 41 | cluster_cols=["cluster", "date"], 42 | ) 43 | 44 | pw = PowerAnalysis( 45 | perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50 46 | ) 47 | 48 | print(df) 49 | power = pw.power_analysis(df) 50 | print(f"{power = }") 51 | -------------------------------------------------------------------------------- /examples/short_example_config.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cluster_experiments.power_analysis import PowerAnalysis 7 | from cluster_experiments.power_config import PowerConfig 8 | 9 | 10 | def generate_random_data(clusters, dates, N): 11 | # Generate random data with clusters and target 12 | users = [f"User {i}" for i in range(1000)] 13 | df = pd.DataFrame( 14 | { 15 | "cluster": np.random.choice(clusters, size=N), 16 | "target": np.random.normal(0, 1, size=N), 17 | "user": np.random.choice(users, size=N), 18 | "date": np.random.choice(dates, size=N), 19 | } 20 | ) 21 | 22 | return df 23 | 24 | 25 | if __name__ == "__main__": 26 | clusters = [f"Cluster {i}" for i in range(100)] 27 | dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 28 | N = 10_000 29 | df = generate_random_data(clusters, dates, N) 30 | config = PowerConfig( 31 | cluster_cols=["cluster", "date"], 32 | analysis="gee", 33 | perturbator="constant", 34 | splitter="clustered", 35 | n_simulations=100, 36 | ) 37 | pw = PowerAnalysis.from_config(config) 38 | 39 | print(df) 40 | power = pw.power_analysis(df) 41 | print(f"{power = }") 42 | -------------------------------------------------------------------------------- /examples/short_example_dict.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cluster_experiments.power_analysis import PowerAnalysis 7 | 8 | 9 | def generate_random_data(clusters, dates, N): 10 | # Generate random data with clusters and target 11 | users = [f"User {i}" for i in range(1000)] 12 | df = pd.DataFrame( 13 | { 14 | "cluster": np.random.choice(clusters, size=N), 15 | "target": np.random.normal(0, 1, size=N), 16 | "user": np.random.choice(users, size=N), 17 | "date": np.random.choice(dates, size=N), 18 | } 19 | ) 20 | 21 | return df 22 | 23 | 24 | if __name__ == "__main__": 25 | clusters = [f"Cluster {i}" for i in range(100)] 26 | dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 27 | N = 1_000 28 | df = generate_random_data(clusters, dates, N) 29 | config = { 30 | "cluster_cols": ["cluster", "date"], 31 | "analysis": "gee", 32 | "perturbator": "constant", 33 | "splitter": "clustered", 34 | "n_simulations": 50, 35 | } 36 | pw = PowerAnalysis.from_dict(config) 37 | 38 | print(df) 39 | power = pw.power_analysis(df) 40 | print(f"{power = }") 41 | -------------------------------------------------------------------------------- /examples/short_example_paired_ttest.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from 
cluster_experiments.experiment_analysis import PairedTTestClusteredAnalysis 7 | from cluster_experiments.perturbator import ConstantPerturbator 8 | from cluster_experiments.power_analysis import PowerAnalysis 9 | from cluster_experiments.random_splitter import StratifiedSwitchbackSplitter 10 | 11 | 12 | def generate_random_data(clusters, dates, N): 13 | # Generate random data with clusters and target 14 | users = [f"User {i}" for i in range(1000)] 15 | df = pd.DataFrame( 16 | { 17 | "cluster": np.random.choice(clusters, size=N), 18 | "target": np.random.normal(0, 1, size=N), 19 | "user": np.random.choice(users, size=N), 20 | "date": np.random.choice(dates, size=N), 21 | } 22 | ) 23 | 24 | df["date"] = pd.to_datetime(df["date"]) 25 | df["dow"] = df["date"].dt.day_name() 26 | 27 | return df 28 | 29 | 30 | if __name__ == "__main__": 31 | clusters = [f"Cluster {i}" for i in range(100)] 32 | dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 33 | N = 10_000 34 | df = generate_random_data(clusters, dates, N) 35 | sw = StratifiedSwitchbackSplitter( 36 | cluster_cols=["cluster", "date"], strata_cols=["dow"] 37 | ) 38 | 39 | perturbator = ConstantPerturbator( 40 | average_effect=0.1, 41 | ) 42 | 43 | analysis = PairedTTestClusteredAnalysis( 44 | cluster_cols=["cluster", "date"], strata_cols=["cluster"] 45 | ) 46 | 47 | pw = PowerAnalysis( 48 | perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50 49 | ) 50 | 51 | print(df) 52 | power = pw.power_analysis(df) 53 | print(f"{power = }") 54 | -------------------------------------------------------------------------------- /examples/short_example_synthetic_control.py: -------------------------------------------------------------------------------- 1 | from cluster_experiments.experiment_analysis import SyntheticControlAnalysis 2 | from cluster_experiments.perturbator import ConstantPerturbator 3 | from cluster_experiments.power_analysis import PowerAnalysisWithPreExperimentData 4 | from cluster_experiments.random_splitter import FixedSizeClusteredSplitter 5 | from cluster_experiments.synthetic_control_utils import generate_synthetic_control_data 6 | 7 | df = generate_synthetic_control_data(10, "2022-01-01", "2022-01-30") 8 | 9 | sw = FixedSizeClusteredSplitter(n_treatment_clusters=2, cluster_cols=["user"]) 10 | 11 | perturbator = ConstantPerturbator( 12 | average_effect=0.1, 13 | ) 14 | 15 | analysis = SyntheticControlAnalysis( 16 | cluster_cols=["user"], time_col="date", intervention_date="2022-01-15" 17 | ) 18 | 19 | pw = PowerAnalysisWithPreExperimentData( 20 | perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50 21 | ) 22 | 23 | power = pw.power_analysis(df) 24 | print(f"{power = }") 25 | print(pw.power_line(df, average_effects=[0.1, 0.2, 0.5, 1, 1.5], n_jobs=-1)) 26 | print(pw.simulate_point_estimate(df)) 27 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Cluster Experiments Docs 2 | extra_css: [style.css] 3 | repo_url: https://github.com/david26694/cluster-experiments 4 | site_url: https://david26694.github.io/cluster-experiments/ 5 | site_description: Functions to design and run clustered experiments 6 | site_author: David Masip 7 | use_directory_urls: false 8 | edit_uri: blob/main/docs/ 9 | nav: 10 | - Home: 11 | - Index: index.md 12 | - End-to-end example: e2e_mde.ipynb 13 | - Cupac example: cupac_example.ipynb 14 | - Custom classes: 
create_custom_classes.ipynb 15 | - Switchback: 16 | - Stratified switchback: switchback.ipynb 17 | - Switchback calendar visualization: plot_calendars.ipynb 18 | - Visualization - 4-hour switches: plot_calendars_hours.ipynb 19 | - Multiple treatments: multivariate.ipynb 20 | - AA test clustered: aa_test.ipynb 21 | - Paired T test: paired_ttest.ipynb 22 | - Different hypotheses tests: analysis_with_different_hypotheses.ipynb 23 | - Washover: washover_example.ipynb 24 | - Normal Power: 25 | - Compare with simulation: normal_power.ipynb 26 | - Time-lines: normal_power_lines.ipynb 27 | - Synthetic control: synthetic_control.ipynb 28 | - Experiment analysis workflow: experiment_analysis.ipynb 29 | - Delta Method Analysis: delta_method.ipynb 30 | - API: 31 | - Experiment analysis methods: api/experiment_analysis.md 32 | - Perturbators: api/perturbator.md 33 | - Splitter: api/random_splitter.md 34 | - Pre experiment outcome model: api/cupac_model.md 35 | - Power config: api/power_config.md 36 | - Power analysis: api/power_analysis.md 37 | - Washover: api/washover.md 38 | - Metric: api/metric.md 39 | - Variant: api/variant.md 40 | - Dimension: api/dimension.md 41 | - Hypothesis Test: api/hypothesis_test.md 42 | - Analysis Plan: api/analysis_plan.md 43 | plugins: 44 | - mkdocstrings: 45 | watch: 46 | - cluster_experiments 47 | - mkdocs-jupyter 48 | - search 49 | copyright: Copyright © 2022 Maintained by David Masip. 50 | theme: 51 | name: material 52 | font: 53 | text: Ubuntu 54 | code: Ubuntu Mono 55 | feature: 56 | tabs: true 57 | palette: 58 | primary: indigo 59 | accent: blue 60 | markdown_extensions: 61 | - codehilite 62 | - pymdownx.inlinehilite 63 | - pymdownx.superfences 64 | - pymdownx.details 65 | - pymdownx.tabbed 66 | - pymdownx.snippets 67 | - pymdownx.highlight: 68 | use_pygments: true 69 | - toc: 70 | permalink: true 71 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | authors = [] 3 | requires-python = "<3.13,>=3.9" 4 | dependencies = [ 5 | "pip>=22.2.2", 6 | "statsmodels>=0.13.2", 7 | "pandas>=1.2.0", 8 | "scikit-learn>=1.0.0", 9 | "tqdm>=4.0.0", 10 | "numpy>=1.20.0", 11 | ] 12 | name = "cluster-experiments" 13 | version = "0.26.0" 14 | description = "" 15 | readme = "README.md" 16 | classifiers=[ 17 | "Development Status :: 4 - Beta", 18 | "Intended Audience :: Developers", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Programming Language :: Python :: 3 :: Only", 25 | "Operating System :: OS Independent", 26 | "License :: OSI Approved :: MIT License", 27 | ] 28 | 29 | [project.optional-dependencies] 30 | dev = [ 31 | "pytest<9.0.0,>=5.4.3", 32 | "black<25.0.0,>=22.12.0", 33 | "ruff<1.0.0,>=0.7.4", 34 | "mktestdocs<1.0.0,>=0.2.2", 35 | "pytest-cov<7.0.0,>=2.10.1", 36 | "pytest-sugar<2.0.0,>=0.9.4", 37 | "pytest-slow-last<1.0.0,>=0.1.3", 38 | "coverage<8.0.0,>=7.6.7", 39 | "pytest-reportlog<1.0.0,>=0.4.0", 40 | "pytest-duration-insights<1.0.0,>=0.1.2", 41 | "pytest-clarity<2.0.0,>=1.0.1", 42 | "pytest-xdist<4.0.0,>=3.6.1", 43 | "pre-commit<5.0.0,>=2.6.0", 44 | "ipykernel<7.0.0,>=6.15.1", 45 | "twine<6.0.0,>=5.1.1", 46 | "build<2.0.0.0,>=1.2.2.post1", 47 | "tox<5.0.0,>=4.23.2", 48 | "mkdocs<2.0.0,>=1.4.0", 49 | "mkdocs-material<10.0.0,>=8.5.0", 50 | 
"mkdocstrings[python]<1.0.0,>=0.25.0", 51 | "jinja2<4.0.0,>=3.1.0", 52 | "mkdocs-jupyter<1.0.0,>=0.22.0", 53 | "matplotlib<4.0.0,>=3.4.3", 54 | "plotnine<1.0.0,>=0.8.0", 55 | ] 56 | 57 | test = [ 58 | "pytest<9.0.0,>=5.4.3", 59 | "black<25.0.0,>=22.12.0", 60 | "ruff<1.0.0,>=0.7.4", 61 | "mktestdocs<1.0.0,>=0.2.2", 62 | "pytest-cov<7.0.0,>=2.10.1", 63 | "pytest-sugar<2.0.0,>=0.9.4", 64 | "pytest-slow-last<1.0.0,>=0.1.3", 65 | "coverage<8.0.0,>=7.6.7", 66 | "pytest-reportlog<1.0.0,>=0.4.0", 67 | "pytest-duration-insights<1.0.0,>=0.1.2", 68 | "pytest-clarity<2.0.0,>=1.0.1", 69 | "pytest-xdist<4.0.0,>=3.6.1", 70 | ] 71 | 72 | docs = [ 73 | "mkdocs<2.0.0,>=1.4.0", 74 | "mkdocs-material<10.0.0,>=8.5.0", 75 | "mkdocstrings<1.0.0,>=0.18.0", 76 | "jinja2<4.0.0,>=3.1.0", 77 | "mkdocs-jupyter<1.0.0,>=0.22.0", 78 | "plotnine<1.0.0,>=0.8.0", 79 | "matplotlib<4.0.0,>=3.4.3", 80 | ] 81 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 160 2 | 3 | extend-select = ["I001"] 4 | 5 | ignore = ["E501"] 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | packages=find_packages(), 5 | ) 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/__init__.py -------------------------------------------------------------------------------- /tests/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/analysis/__init__.py -------------------------------------------------------------------------------- /tests/analysis/conftest.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | from tests.utils import generate_ratio_metric_data 8 | 9 | N = 50_000 10 | 11 | 12 | @pytest.fixture 13 | def dates(): 14 | return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 15 | 16 | 17 | @pytest.fixture 18 | def experiment_dates(): 19 | return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)] 20 | 21 | 22 | @pytest.fixture 23 | def analysis_df(): 24 | return pd.DataFrame( 25 | { 26 | "target": [0, 1, 0, 1], 27 | "treatment": ["A", "B", "B", "A"], 28 | "cluster": ["Cluster 1", "Cluster 1", "Cluster 1", "Cluster 1"], 29 | "date": ["2022-01-01", "2022-01-01", "2022-01-01", "2022-01-01"], 30 | } 31 | ) 32 | 33 | 34 | @pytest.fixture 35 | def analysis_ratio_df(dates, experiment_dates): 36 | pre_exp_dates = [d for d in dates if d not in experiment_dates] 37 | 38 | user_sample_mean = 0.3 39 | user_standard_error = 0.15 40 | users = 2000 41 | 42 | user_target_means = np.random.normal(user_sample_mean, user_standard_error, users) 43 | 44 | pre_data = generate_ratio_metric_data( 45 | pre_exp_dates, N, user_target_means, users, treatment_effect=0 46 | ) 47 | post_data = generate_ratio_metric_data( 48 | experiment_dates, N, user_target_means, users 49 | ) 50 | return 
pd.concat([pre_data, post_data]) 51 | 52 | 53 | @pytest.fixture 54 | def covariate_data(): 55 | """generates data via y ~ T + X""" 56 | N = 1000 57 | np.random.seed(123) 58 | X = np.random.normal(size=N) 59 | T = np.random.choice(["A", "B"], size=N) 60 | y = 0.5 * X + 0.1 * (T == "B") + np.random.normal(size=N) 61 | df = pd.DataFrame({"y": y, "T": T, "X": X}) 62 | return df 63 | -------------------------------------------------------------------------------- /tests/analysis/test_formula.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cluster_experiments.experiment_analysis import ( 4 | ClusteredOLSAnalysis, 5 | GeeExperimentAnalysis, 6 | MLMExperimentAnalysis, 7 | ) 8 | 9 | parametrisation = pytest.mark.parametrize( 10 | "analysis_class", 11 | [ 12 | ClusteredOLSAnalysis, 13 | GeeExperimentAnalysis, 14 | MLMExperimentAnalysis, 15 | ], 16 | ) 17 | 18 | 19 | @parametrisation 20 | def test_formula_no_covariates(analysis_class): 21 | analysis = analysis_class( 22 | cluster_cols=["cluster"], 23 | treatment_col="treatment", 24 | target_col="y", 25 | ) 26 | 27 | assert analysis.formula == "y ~ treatment" 28 | 29 | 30 | @parametrisation 31 | def test_formula_with_covariates(analysis_class): 32 | analysis = analysis_class( 33 | cluster_cols=["cluster"], 34 | treatment_col="treatment", 35 | target_col="y", 36 | covariates=["covariate1", "covariate2"], 37 | ) 38 | 39 | assert analysis.formula == "y ~ treatment + covariate1 + covariate2" 40 | 41 | 42 | @parametrisation 43 | def test_formula_with_interaction(analysis_class): 44 | analysis = analysis_class( 45 | cluster_cols=["cluster"], 46 | treatment_col="treatment", 47 | target_col="y", 48 | covariates=["covariate1", "covariate2"], 49 | add_covariate_interaction=True, 50 | ) 51 | 52 | assert ( 53 | analysis.formula 54 | == "y ~ treatment + covariate1 + covariate2 + __covariate1__interaction + __covariate2__interaction" 55 | ) 56 | -------------------------------------------------------------------------------- /tests/analysis/test_hypothesis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | from cluster_experiments.experiment_analysis import ( 5 | ClusteredOLSAnalysis, 6 | GeeExperimentAnalysis, 7 | MLMExperimentAnalysis, 8 | OLSAnalysis, 9 | SyntheticControlAnalysis, 10 | TTestClusteredAnalysis, 11 | ) 12 | from cluster_experiments.synthetic_control_utils import generate_synthetic_control_data 13 | from tests.utils import generate_clustered_data 14 | 15 | 16 | @pytest.mark.parametrize("hypothesis", ["less", "greater", "two-sided"]) 17 | @pytest.mark.parametrize("analysis_class", [OLSAnalysis]) 18 | def test_get_pvalue_hypothesis(analysis_class, hypothesis, analysis_df): 19 | analysis_df_full = pd.concat([analysis_df for _ in range(100)]) 20 | analyser = analysis_class(hypothesis=hypothesis) 21 | assert analyser.get_pvalue(analysis_df_full) >= 0 22 | 23 | 24 | @pytest.mark.parametrize("hypothesis", ["less", "greater", "two-sided"]) 25 | @pytest.mark.parametrize( 26 | "analysis_class", 27 | [ 28 | ClusteredOLSAnalysis, 29 | GeeExperimentAnalysis, 30 | TTestClusteredAnalysis, 31 | MLMExperimentAnalysis, 32 | ], 33 | ) 34 | def test_get_pvalue_hypothesis_clustered(analysis_class, hypothesis): 35 | 36 | analysis_df_full = generate_clustered_data() 37 | analyser = analysis_class(hypothesis=hypothesis, cluster_cols=["user_id"]) 38 | assert analyser.get_pvalue(analysis_df_full) >= 0 39 | 40 | 41 | 
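# Usage sketch (illustration only, mirroring the API exercised in the tests
# below): outside the test suite, a one-sided analysis runs directly on a
# dataframe that has the default `target` and `treatment` columns, e.g.:
#
#   from cluster_experiments.experiment_analysis import OLSAnalysis
#
#   analyser = OLSAnalysis(hypothesis="less")
#   p_value = analyser.get_pvalue(df)  # df: assumed experiment dataframe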
@pytest.mark.parametrize("analysis_class", [OLSAnalysis]) 42 | def test_get_pvalue_hypothesis_default(analysis_class, analysis_df): 43 | analysis_df_full = pd.concat([analysis_df for _ in range(100)]) 44 | analyser = analysis_class() 45 | assert analyser.get_pvalue(analysis_df_full) >= 0 46 | 47 | 48 | @pytest.mark.parametrize("analysis_class", [OLSAnalysis]) 49 | def test_get_pvalue_hypothesis_wrong_input(analysis_class, analysis_df): 50 | analysis_df_full = pd.concat([analysis_df for _ in range(100)]) 51 | 52 | # Use pytest.raises to check for ValueError 53 | with pytest.raises(ValueError) as excinfo: 54 | analyser = analysis_class(hypothesis="wrong_input") 55 | analyser.get_pvalue(analysis_df_full) >= 0 56 | 57 | # Check if the error message is as expected 58 | assert "'wrong_input' is not a valid HypothesisEntries" in str(excinfo.value) 59 | 60 | 61 | @pytest.mark.parametrize("analysis_class", [OLSAnalysis]) 62 | def test_several_hypothesis(analysis_class, analysis_df): 63 | analysis_df_full = pd.concat([analysis_df for _ in range(100)]) 64 | analysis_less = analysis_class(hypothesis="less") 65 | analysis_greater = analysis_class(hypothesis="greater") 66 | analysis_two_sided = analysis_class(hypothesis="two-sided") 67 | 68 | assert ( 69 | analysis_less.get_pvalue(analysis_df_full) 70 | == analysis_two_sided.get_pvalue(analysis_df_full) / 2 71 | ) 72 | assert ( 73 | analysis_greater.get_pvalue(analysis_df_full) 74 | == 1 - analysis_two_sided.get_pvalue(analysis_df_full) / 2 75 | ) 76 | 77 | 78 | @pytest.mark.parametrize("hypothesis", ["less", "greater", "two-sided"]) 79 | def test_hypothesis_synthetic(hypothesis): 80 | 81 | df = generate_synthetic_control_data( 82 | N=10, start_date="2022-01-01", end_date="2022-01-30" 83 | ) 84 | # Add treatment column to only 1 user 85 | df["treatment"] = 0 86 | df.loc[(df["user"] == "User 5"), "treatment"] = 1 87 | 88 | analysis = SyntheticControlAnalysis( 89 | hypothesis=hypothesis, cluster_cols=["user"], intervention_date="2022-01-15" 90 | ) 91 | assert analysis.analysis_pvalue(df) >= 0 92 | -------------------------------------------------------------------------------- /tests/analysis/test_ols_analysis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from cluster_experiments.experiment_analysis import OLSAnalysis 4 | 5 | 6 | def test_binary_treatment(analysis_df): 7 | analyser = OLSAnalysis() 8 | assert ( 9 | analyser._create_binary_treatment(analysis_df)["treatment"] 10 | == pd.Series([0, 1, 1, 0]) 11 | ).all() 12 | 13 | 14 | def test_get_pvalue(analysis_df): 15 | analysis_df_full = pd.concat([analysis_df for _ in range(100)]) 16 | analyser = OLSAnalysis() 17 | assert analyser.get_pvalue(analysis_df_full) >= 0 18 | 19 | 20 | def test_cov_type(analysis_df): 21 | # given 22 | analyser_hc1 = OLSAnalysis(cov_type="HC1") 23 | analyser_hc3 = OLSAnalysis(cov_type="HC3") 24 | 25 | # then: point estimates are the same 26 | assert analyser_hc1.get_point_estimate( 27 | analysis_df 28 | ) == analyser_hc3.get_point_estimate(analysis_df) 29 | 30 | # then: standard errors are different 31 | assert analyser_hc1.get_standard_error( 32 | analysis_df 33 | ) != analyser_hc3.get_standard_error(analysis_df) 34 | 35 | 36 | def test_covariate_interaction(covariate_data): 37 | # given 38 | analysis_interaction = OLSAnalysis( 39 | treatment_col="T", 40 | target_col="y", 41 | covariates=["X"], 42 | add_covariate_interaction=True, 43 | ) 44 | analysis_no_interaction = OLSAnalysis( 45 | treatment_col="T", 46 
| target_col="y", 47 | covariates=["X"], 48 | add_covariate_interaction=False, 49 | ) 50 | 51 | # when: calculating point estimates 52 | point_estimate_interaction = analysis_interaction.get_point_estimate(covariate_data) 53 | point_estimate_no_interaction = analysis_no_interaction.get_point_estimate( 54 | covariate_data 55 | ) 56 | 57 | # then: point estimates are different 58 | assert analysis_interaction.formula != analysis_no_interaction.formula 59 | assert point_estimate_interaction != point_estimate_no_interaction 60 | -------------------------------------------------------------------------------- /tests/analysis/test_synthetic_analysis.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | from cluster_experiments.experiment_analysis import SyntheticControlAnalysis 8 | from cluster_experiments.synthetic_control_utils import get_w 9 | 10 | 11 | def generate_2_clusters_data(N, start_date, end_date): 12 | dates = pd.date_range(start_date, end_date, freq="d") 13 | country = ["US", "UK"] 14 | users = [f"User {i}" for i in range(N)] 15 | 16 | # Get the combination of each date with each user 17 | combinations = list(product(users, dates, country)) 18 | 19 | df = pd.DataFrame(combinations, columns=["user", "date", "country"]) 20 | 21 | df["target"] = np.random.normal(0, 1, size=len(combinations)) 22 | # Ensure 'date' column is of datetime type 23 | df["date"] = pd.to_datetime(df["date"]) 24 | 25 | return df 26 | 27 | 28 | def test_synthetic_control_analysis(): 29 | df = generate_2_clusters_data(10, "2022-01-01", "2022-01-30") 30 | 31 | # Add treatment column to only 1 user 32 | df["treatment"] = "A" 33 | df.loc[(df["user"] == "User 5") & (df["country"] == "US"), "treatment"] = "B" 34 | 35 | analysis = SyntheticControlAnalysis( 36 | cluster_cols=["user", "country"], intervention_date="2022-01-06" 37 | ) 38 | 39 | p_value = analysis.get_pvalue(df) 40 | assert 0 <= p_value <= 1 41 | 42 | 43 | @pytest.mark.parametrize( 44 | "X, y", 45 | [ 46 | ( 47 | np.array([[1, 2], [3, 4]]), 48 | np.array([1, 1]), 49 | ), # Scenario with positive integers 50 | ( 51 | np.array([[1, -2], [-3, 4]]), 52 | np.array([1, 1]), 53 | ), # Scenario with negative integers 54 | ( 55 | np.array([[1.5, 2.5], [3.5, 4.5]]), 56 | np.array([1, 1]), 57 | ), # Scenario with positive floats 58 | ( 59 | np.array([[1.5, -2.5], [-3.5, 4.5]]), 60 | np.array([1, 1]), 61 | ), # Scenario with negative floats 62 | ], 63 | ) 64 | def test_get_w_weights( 65 | X, y 66 | ): # this function is not part of the analysis, but it is used in it 67 | expected_sum = 1 68 | expected_bounds = (0, 1) 69 | weights = get_w(X, y) 70 | assert np.isclose(np.sum(weights), expected_sum), "Weights sum should be close to 1" 71 | assert all( 72 | expected_bounds[0] <= w <= expected_bounds[1] for w in weights 73 | ), "Each weight should be between 0 and 1" 74 | 75 | 76 | def test_get_treatment_cluster(): 77 | analysis = SyntheticControlAnalysis( 78 | cluster_cols=["cluster"], intervention_date="2022-01-06" 79 | ) 80 | df = pd.DataFrame( 81 | { 82 | "target": [1, 2, 3, 4, 5, 6], 83 | "treatment": [0, 0, 1, 1, 1, 0], 84 | "cluster": [ 85 | "cluster1", 86 | "cluster2", 87 | "cluster3", 88 | "cluster3", 89 | "cluster3", 90 | "cluster2", 91 | ], 92 | } 93 | ) 94 | expected_cluster = "cluster3" 95 | assert analysis._get_treatment_cluster(df) == expected_cluster 96 | 97 | 98 | def test_point_estimate_synthetic_control(): 99 | df 
= generate_2_clusters_data(10, "2022-01-01", "2022-01-30") 100 | 101 | # Add treatment column to only 1 cluster 102 | df["treatment"] = 0 103 | df.loc[(df["user"] == "User 5") & (df["country"] == "US"), "treatment"] = 1 104 | 105 | df.loc[(df["user"] == "User 5") & (df["country"] == "US"), "target"] = 10 106 | 107 | analysis = SyntheticControlAnalysis( 108 | cluster_cols=["user", "country"], intervention_date="2022-01-06" 109 | ) 110 | 111 | effect = analysis.analysis_point_estimate(df) 112 | assert 9 <= effect <= 11 113 | 114 | 115 | def test_predict(): 116 | analysis = SyntheticControlAnalysis( 117 | cluster_cols=["user", "country"], intervention_date="2022-01-06" 118 | ) 119 | df = generate_2_clusters_data(5, "2021-01-01", "2021-01-10") 120 | df["treatment"] = "A" 121 | df.loc[(df["user"] == "User 4") & (df["country"] == "US"), "treatment"] = "B" 122 | 123 | # Same effect to every donor cluster 124 | weights = np.array([0.2] * 9) 125 | 126 | treatment_cluster = "User 4US" 127 | 128 | result = analysis._predict(df, weights, treatment_cluster) 129 | 130 | # Check the results 131 | assert ( 132 | "synthetic" in result.columns 133 | ), "The result DataFrame should include a 'synthetic' column" 134 | assert all( 135 | result["treatment"] == "B" 136 | ), "The result DataFrame should only contain the treatment cluster" 137 | assert len(result) > 0, "Should have at least one entry for the treatment cluster" 138 | assert ( 139 | not result["synthetic"].isnull().any() 140 | ), "Synthetic column should not have null values" 141 | -------------------------------------------------------------------------------- /tests/cupac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/cupac/__init__.py -------------------------------------------------------------------------------- /tests/cupac/conftest.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | 5 | @pytest.fixture 6 | def binary_df(): 7 | return pd.DataFrame( 8 | { 9 | "target": [0, 1, 0, 1], 10 | "treatment": ["A", "B", "B", "A"], 11 | } 12 | ) 13 | -------------------------------------------------------------------------------- /tests/cupac/test_aggregator.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pandas as pd 4 | 5 | from cluster_experiments.cupac import TargetAggregation 6 | 7 | 8 | def split_x_y(binary_df_agg: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series]: 9 | return binary_df_agg.drop("target", axis=1), binary_df_agg["target"] 10 | 11 | 12 | def test_set_target_aggs(binary_df): 13 | binary_df["user"] = [1, 1, 1, 1] 14 | ta = TargetAggregation(agg_col="user") 15 | X, y = split_x_y(binary_df) 16 | ta.fit(X, y) 17 | 18 | assert len(ta.pre_experiment_agg_df) == 1 19 | assert ta.pre_experiment_mean == 0.5 20 | 21 | 22 | def test_smoothing_0(binary_df): 23 | binary_df["user"] = binary_df["target"] 24 | ta = TargetAggregation(agg_col="user", smoothing_factor=0) 25 | X, y = split_x_y(binary_df) 26 | ta.fit(X, y) 27 | assert ( 28 | ta.pre_experiment_agg_df["target_mean"] 29 | == ta.pre_experiment_agg_df["target_smooth_mean"] 30 | ).all() 31 | 32 | 33 | def test_smoothing_non_0(binary_df): 34 | binary_df["user"] = binary_df["target"] 35 | ta = TargetAggregation(agg_col="user", smoothing_factor=2) 36 | X, y = split_x_y(binary_df) 37 | 
ta.fit(X, y) 38 | assert ( 39 | ta.pre_experiment_agg_df["target_mean"] 40 | != ta.pre_experiment_agg_df["target_smooth_mean"] 41 | ).all() 42 | assert ( 43 | ta.pre_experiment_agg_df["target_smooth_mean"].loc[[0, 1]] == [0.25, 0.75] 44 | ).all() 45 | 46 | 47 | def test_add_aggs(binary_df): 48 | binary_df["user"] = binary_df["target"] 49 | ta = TargetAggregation(agg_col="user", smoothing_factor=2) 50 | X, y = split_x_y(binary_df) 51 | ta.fit(X, y) 52 | binary_df["target_smooth_mean"] = ta.predict(binary_df) 53 | assert (binary_df.query("user == 0")["target_smooth_mean"] == 0.25).all() 54 | -------------------------------------------------------------------------------- /tests/cupac/test_cupac_handler.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import numpy as np 4 | import pytest 5 | from sklearn.ensemble import HistGradientBoostingRegressor 6 | 7 | from cluster_experiments.cupac import CupacHandler, TargetAggregation 8 | from tests.utils import generate_random_data 9 | 10 | N = 1_000 11 | 12 | 13 | @pytest.fixture 14 | def clusters(): 15 | return [f"Cluster {i}" for i in range(100)] 16 | 17 | 18 | @pytest.fixture 19 | def dates(): 20 | return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 21 | 22 | 23 | @pytest.fixture 24 | def experiment_dates(): 25 | return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)] 26 | 27 | 28 | @pytest.fixture 29 | def df(clusters, dates): 30 | return generate_random_data(clusters, dates, N) 31 | 32 | 33 | @pytest.fixture 34 | def df_feats(clusters, dates): 35 | df = generate_random_data(clusters, dates, N) 36 | df["x1"] = np.random.normal(0, 1, N) 37 | df["x2"] = np.random.normal(0, 1, N) 38 | return df 39 | 40 | 41 | @pytest.fixture 42 | def cupac_handler_base(): 43 | return CupacHandler( 44 | cupac_model=TargetAggregation( 45 | agg_col="user", 46 | ), 47 | target_col="target", 48 | # features_cupac_model=["user"], 49 | ) 50 | 51 | 52 | @pytest.fixture 53 | def cupac_handler_model(): 54 | return CupacHandler( 55 | cupac_model=HistGradientBoostingRegressor(max_iter=3), 56 | target_col="target", 57 | features_cupac_model=["x1", "x2"], 58 | ) 59 | 60 | 61 | @pytest.fixture 62 | def missing_cupac(): 63 | return CupacHandler( 64 | None, 65 | ) 66 | 67 | 68 | @pytest.mark.parametrize( 69 | "cupac_handler", 70 | [ 71 | "cupac_handler_base", 72 | "cupac_handler_model", 73 | ], 74 | ) 75 | def test_add_covariates(cupac_handler, df_feats, request): 76 | cupac_handler = request.getfixturevalue(cupac_handler) 77 | df = cupac_handler.add_covariates(df_feats, df_feats.head(10)) 78 | assert df["estimate_target"].isna().sum() == 0 79 | assert (df["estimate_target"] <= df["target"].max()).all() 80 | assert (df["estimate_target"] >= df["target"].min()).all() 81 | 82 | 83 | def test_no_target(missing_cupac, df_feats): 84 | """Checks that no target is added when the there is no cupac model""" 85 | df = missing_cupac.add_covariates(df_feats) 86 | assert "estimate_target" not in df.columns 87 | 88 | 89 | def test_no_pre_experiment(cupac_handler_base, df_feats): 90 | """Checks that if there is a cupac model, pre_experiment_df should be provided""" 91 | with pytest.raises(ValueError, match="pre_experiment_df should be provided"): 92 | cupac_handler_base.add_covariates(df_feats) 93 | 94 | 95 | def test_no_cupac(missing_cupac, df_feats): 96 | """Checks that if there is no cupac model, pre_experiment_df should not be provided""" 97 | with pytest.raises(ValueError, match="remove 
pre_experiment_df argument"): 98 | missing_cupac.add_covariates(df_feats, df_feats.head(10)) 99 | -------------------------------------------------------------------------------- /tests/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/inference/__init__.py -------------------------------------------------------------------------------- /tests/inference/test_analysis_plan_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from cluster_experiments.inference.analysis_plan import AnalysisPlan 6 | from cluster_experiments.inference.dimension import Dimension 7 | from cluster_experiments.inference.hypothesis_test import HypothesisTest 8 | from cluster_experiments.inference.metric import SimpleMetric 9 | from cluster_experiments.inference.variant import Variant 10 | 11 | 12 | @pytest.fixture 13 | def experiment_data(): 14 | N = 1_000 15 | return pd.DataFrame( 16 | { 17 | "order_value": np.random.normal(100, 10, size=N), 18 | "delivery_time": np.random.normal(10, 1, size=N), 19 | "experiment_group": np.random.choice(["control", "treatment"], size=N), 20 | "city": np.random.choice(["NYC", "LA"], size=N), 21 | "customer_id": np.random.randint(1, 100, size=N), 22 | "customer_age": np.random.randint(20, 60, size=N), 23 | } 24 | ) 25 | 26 | 27 | def test_from_metrics_dict(): 28 | d = { 29 | "metrics": [{"alias": "AOV", "name": "order_value"}], 30 | "variants": [ 31 | {"name": "control", "is_control": True}, 32 | {"name": "treatment_1", "is_control": False}, 33 | ], 34 | "variant_col": "experiment_group", 35 | "alpha": 0.05, 36 | "dimensions": [{"name": "city", "values": ["NYC", "LA"]}], 37 | "analysis_type": "clustered_ols", 38 | "analysis_config": {"cluster_cols": ["customer_id"]}, 39 | } 40 | plan = AnalysisPlan.from_metrics_dict(d) 41 | assert isinstance(plan, AnalysisPlan) 42 | assert len(plan.tests) == 1 43 | assert isinstance(plan.tests[0], HypothesisTest) 44 | assert plan.variant_col == "experiment_group" 45 | assert plan.alpha == 0.05 46 | assert len(plan.variants) == 2 47 | 48 | 49 | def test_analyze_from_metrics_dict(experiment_data): 50 | # given 51 | plan = AnalysisPlan.from_metrics_dict( 52 | { 53 | "metrics": [ 54 | {"alias": "AOV", "name": "order_value"}, 55 | {"alias": "delivery_time", "name": "delivery_time"}, 56 | ], 57 | "variants": [ 58 | {"name": "control", "is_control": True}, 59 | {"name": "treatment", "is_control": False}, 60 | ], 61 | "variant_col": "experiment_group", 62 | "alpha": 0.05, 63 | "dimensions": [ 64 | {"name": "city", "values": ["NYC", "LA"]}, 65 | ], 66 | "analysis_type": "clustered_ols", 67 | "analysis_config": {"cluster_cols": ["customer_id"]}, 68 | } 69 | ) 70 | 71 | # when 72 | results = plan.analyze(experiment_data) 73 | results_df = results.to_dataframe() 74 | 75 | # then 76 | assert ( 77 | len(results_df) == 6 78 | ), "There should be 6 rows in the results DataFrame, 2 metrics x 3 dimension values" 79 | assert set(results_df["metric_alias"]) == { 80 | "AOV", 81 | "delivery_time", 82 | }, "The metric aliases should be present in the DataFrame" 83 | assert set(results_df["dimension_value"]) == { 84 | "total", 85 | "" "NYC", 86 | "LA", 87 | }, "The dimension values should be present in the DataFrame" 88 | 89 | 90 | def test_from_dict_config(): 91 | # ensures that we get the same 
object when creating from a dict or a config 92 | # given 93 | d = { 94 | "tests": [ 95 | { 96 | "metric": {"alias": "AOV", "name": "order_value"}, 97 | "analysis_type": "clustered_ols", 98 | "analysis_config": {"cluster_cols": ["customer_id"]}, 99 | "dimensions": [{"name": "city", "values": ["NYC", "LA"]}], 100 | }, 101 | { 102 | "metric": {"alias": "delivery_time", "name": "delivery_time"}, 103 | "analysis_type": "clustered_ols", 104 | "analysis_config": {"cluster_cols": ["customer_id"]}, 105 | "dimensions": [{"name": "city", "values": ["NYC", "LA"]}], 106 | }, 107 | ], 108 | "variants": [ 109 | {"name": "control", "is_control": True}, 110 | {"name": "treatment", "is_control": False}, 111 | ], 112 | "variant_col": "experiment_group", 113 | "alpha": 0.05, 114 | } 115 | plan = AnalysisPlan( 116 | variants=[ 117 | Variant(name="control", is_control=True), 118 | Variant(name="treatment", is_control=False), 119 | ], 120 | variant_col="experiment_group", 121 | alpha=0.05, 122 | tests=[ 123 | HypothesisTest( 124 | metric=SimpleMetric(alias="AOV", name="order_value"), 125 | analysis_type="clustered_ols", 126 | analysis_config={"cluster_cols": ["customer_id"]}, 127 | dimensions=[Dimension(name="city", values=["NYC", "LA"])], 128 | ), 129 | HypothesisTest( 130 | metric=SimpleMetric(alias="delivery_time", name="delivery_time"), 131 | analysis_type="clustered_ols", 132 | analysis_config={"cluster_cols": ["customer_id"]}, 133 | dimensions=[Dimension(name="city", values=["NYC", "LA"])], 134 | ), 135 | ], 136 | ) 137 | 138 | # when 139 | plan_from_config = AnalysisPlan.from_dict(d) 140 | 141 | # then 142 | assert plan.variant_col == plan_from_config.variant_col 143 | assert plan.alpha == plan_from_config.alpha 144 | for variant in plan.variants: 145 | assert variant in plan_from_config.variants 146 | 147 | 148 | def test_from_dict(): 149 | # given 150 | d = { 151 | "tests": [ 152 | { 153 | "metric": {"alias": "AOV", "name": "order_value"}, 154 | "analysis_type": "clustered_ols", 155 | "analysis_config": {"cluster_cols": ["customer_id"]}, 156 | "dimensions": [{"name": "city", "values": ["NYC", "LA"]}], 157 | }, 158 | { 159 | "metric": {"alias": "DT", "name": "delivery_time"}, 160 | "analysis_type": "clustered_ols", 161 | "analysis_config": {"cluster_cols": ["customer_id"]}, 162 | "dimensions": [{"name": "city", "values": ["NYC", "LA"]}], 163 | }, 164 | ], 165 | "variants": [ 166 | {"name": "control", "is_control": True}, 167 | {"name": "treatment_1", "is_control": False}, 168 | ], 169 | "variant_col": "experiment_group", 170 | "alpha": 0.05, 171 | } 172 | 173 | # when 174 | plan = AnalysisPlan.from_dict(d) 175 | 176 | # then 177 | assert isinstance(plan, AnalysisPlan) 178 | assert len(plan.tests) == 2 179 | assert isinstance(plan.tests[0], HypothesisTest) 180 | assert plan.variant_col == "experiment_group" 181 | assert plan.alpha == 0.05 182 | assert len(plan.variants) == 2 183 | assert plan.tests[1].metric.alias == "DT" 184 | 185 | 186 | def test_analyze_from_dict(experiment_data): 187 | # given 188 | d = { 189 | "tests": [ 190 | { 191 | "metric": {"alias": "AOV", "name": "order_value"}, 192 | "analysis_type": "clustered_ols", 193 | "analysis_config": {"cluster_cols": ["customer_id"]}, 194 | "dimensions": [{"name": "city", "values": ["NYC", "LA"]}], 195 | }, 196 | { 197 | "metric": {"alias": "DT", "name": "delivery_time"}, 198 | "analysis_type": "clustered_ols", 199 | "analysis_config": {"cluster_cols": ["customer_id"]}, 200 | "dimensions": [{"name": "city", "values": ["NYC", "LA"]}], 201 | }, 202 | ], 
203 | "variants": [ 204 | {"name": "control", "is_control": True}, 205 | {"name": "treatment_1", "is_control": False}, 206 | ], 207 | "variant_col": "experiment_group", 208 | "alpha": 0.05, 209 | } 210 | plan = AnalysisPlan.from_dict(d) 211 | 212 | # when 213 | results = plan.analyze(experiment_data) 214 | results_df = results.to_dataframe() 215 | 216 | # then 217 | assert ( 218 | len(results_df) == 6 219 | ), "There should be 6 rows in the results DataFrame, 2 metrics x 3 dimension values" 220 | assert set(results_df["metric_alias"]) == { 221 | "AOV", 222 | "DT", 223 | }, "The metric aliases should be present in the DataFrame" 224 | assert set(results_df["dimension_value"]) == { 225 | "total", 226 | "" "NYC", 227 | "LA", 228 | }, "The dimension values should be present in the DataFrame" 229 | -------------------------------------------------------------------------------- /tests/inference/test_analysis_results.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | 3 | import pandas as pd 4 | import pytest 5 | 6 | from cluster_experiments.inference.analysis_results import AnalysisPlanResults 7 | 8 | 9 | def test_analysis_plan_results_initialization(): 10 | """Test that AnalysisPlanResults initializes with empty lists by default.""" 11 | results = AnalysisPlanResults() 12 | assert results.metric_alias == [] 13 | assert results.control_variant_name == [] 14 | assert results.treatment_variant_name == [] 15 | assert results.control_variant_mean == [] 16 | assert results.treatment_variant_mean == [] 17 | assert results.analysis_type == [] 18 | assert results.ate == [] 19 | assert results.ate_ci_lower == [] 20 | assert results.ate_ci_upper == [] 21 | assert results.p_value == [] 22 | assert results.std_error == [] 23 | assert results.dimension_name == [] 24 | assert results.dimension_value == [] 25 | assert results.alpha == [] 26 | 27 | 28 | def test_analysis_plan_results_custom_initialization(): 29 | """Test that AnalysisPlanResults initializes with custom data.""" 30 | results = AnalysisPlanResults( 31 | metric_alias=["metric1"], 32 | control_variant_name=["Control"], 33 | treatment_variant_name=["Treatment"], 34 | control_variant_mean=[0.5], 35 | treatment_variant_mean=[0.6], 36 | analysis_type=["AB Test"], 37 | ate=[0.1], 38 | ate_ci_lower=[0.05], 39 | ate_ci_upper=[0.15], 40 | p_value=[0.04], 41 | std_error=[0.02], 42 | dimension_name=["Country"], 43 | dimension_value=["US"], 44 | alpha=[0.05], 45 | ) 46 | assert results.metric_alias == ["metric1"] 47 | assert results.control_variant_name == ["Control"] 48 | assert results.treatment_variant_name == ["Treatment"] 49 | assert results.control_variant_mean == [0.5] 50 | assert results.treatment_variant_mean == [0.6] 51 | assert results.analysis_type == ["AB Test"] 52 | assert results.ate == [0.1] 53 | assert results.ate_ci_lower == [0.05] 54 | assert results.ate_ci_upper == [0.15] 55 | assert results.p_value == [0.04] 56 | assert results.std_error == [0.02] 57 | assert results.dimension_name == ["Country"] 58 | assert results.dimension_value == ["US"] 59 | assert results.alpha == [0.05] 60 | 61 | 62 | def test_analysis_plan_results_addition(): 63 | """Test that two AnalysisPlanResults instances can be added together.""" 64 | results1 = AnalysisPlanResults( 65 | metric_alias=["metric1"], 66 | control_variant_name=["Control"], 67 | treatment_variant_name=["Treatment"], 68 | control_variant_mean=[0.5], 69 | treatment_variant_mean=[0.6], 70 | analysis_type=["AB Test"], 71 | 
ate=[0.1], 72 | ate_ci_lower=[0.05], 73 | ate_ci_upper=[0.15], 74 | p_value=[0.04], 75 | std_error=[0.02], 76 | dimension_name=["Country"], 77 | dimension_value=["US"], 78 | alpha=[0.05], 79 | ) 80 | results2 = AnalysisPlanResults( 81 | metric_alias=["metric2"], 82 | control_variant_name=["Control"], 83 | treatment_variant_name=["Treatment"], 84 | control_variant_mean=[0.55], 85 | treatment_variant_mean=[0.65], 86 | analysis_type=["AB Test"], 87 | ate=[0.1], 88 | ate_ci_lower=[0.05], 89 | ate_ci_upper=[0.15], 90 | p_value=[0.03], 91 | std_error=[0.01], 92 | dimension_name=["Country"], 93 | dimension_value=["CA"], 94 | alpha=[0.05], 95 | ) 96 | combined_results = results1 + results2 97 | 98 | assert combined_results.metric_alias == ["metric1", "metric2"] 99 | assert combined_results.control_variant_name == ["Control", "Control"] 100 | assert combined_results.treatment_variant_name == ["Treatment", "Treatment"] 101 | assert combined_results.control_variant_mean == [0.5, 0.55] 102 | assert combined_results.treatment_variant_mean == [0.6, 0.65] 103 | assert combined_results.analysis_type == ["AB Test", "AB Test"] 104 | assert combined_results.ate == [0.1, 0.1] 105 | assert combined_results.ate_ci_lower == [0.05, 0.05] 106 | assert combined_results.ate_ci_upper == [0.15, 0.15] 107 | assert combined_results.p_value == [0.04, 0.03] 108 | assert combined_results.std_error == [0.02, 0.01] 109 | assert combined_results.dimension_name == ["Country", "Country"] 110 | assert combined_results.dimension_value == ["US", "CA"] 111 | assert combined_results.alpha == [0.05, 0.05] 112 | 113 | 114 | def test_analysis_plan_results_addition_type_error(): 115 | """Test that adding a non-AnalysisPlanResults object raises a TypeError.""" 116 | results = AnalysisPlanResults(metric_alias=["metric1"]) 117 | with pytest.raises(TypeError): 118 | results + "not_an_analysis_plan_results" # Should raise TypeError 119 | 120 | 121 | def test_analysis_plan_results_to_dataframe(): 122 | """Test that AnalysisPlanResults converts to a DataFrame correctly.""" 123 | results = AnalysisPlanResults( 124 | metric_alias=["metric1"], 125 | control_variant_name=["Control"], 126 | treatment_variant_name=["Treatment"], 127 | control_variant_mean=[0.5], 128 | treatment_variant_mean=[0.6], 129 | analysis_type=["AB Test"], 130 | ate=[0.1], 131 | ate_ci_lower=[0.05], 132 | ate_ci_upper=[0.15], 133 | p_value=[0.04], 134 | std_error=[0.02], 135 | dimension_name=["Country"], 136 | dimension_value=["US"], 137 | alpha=[0.05], 138 | ) 139 | df = results.to_dataframe() 140 | 141 | assert isinstance(df, pd.DataFrame) 142 | assert df.shape[0] == 1 # Only one entry 143 | assert set(df.columns) == set(asdict(results).keys()) # Columns match attributes 144 | assert df["metric_alias"].iloc[0] == "metric1" 145 | assert df["control_variant_name"].iloc[0] == "Control" 146 | assert df["treatment_variant_name"].iloc[0] == "Treatment" 147 | assert df["control_variant_mean"].iloc[0] == 0.5 148 | assert df["treatment_variant_mean"].iloc[0] == 0.6 149 | assert df["ate"].iloc[0] == 0.1 150 | assert df["ate_ci_lower"].iloc[0] == 0.05 151 | assert df["ate_ci_upper"].iloc[0] == 0.15 152 | assert df["p_value"].iloc[0] == 0.04 153 | assert df["std_error"].iloc[0] == 0.02 154 | assert df["dimension_name"].iloc[0] == "Country" 155 | assert df["dimension_value"].iloc[0] == "US" 156 | assert df["alpha"].iloc[0] == 0.05 157 | -------------------------------------------------------------------------------- /tests/inference/test_dimension.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cluster_experiments.inference.dimension import DefaultDimension, Dimension 4 | 5 | 6 | def test_dimension_initialization(): 7 | """Test Dimension initialization with valid inputs.""" 8 | dim = Dimension(name="Country", values=["US", "CA", "UK"]) 9 | assert dim.name == "Country" 10 | assert dim.values == ["US", "CA", "UK"] 11 | 12 | 13 | def test_dimension_name_type(): 14 | """Test that Dimension raises TypeError if name is not a string.""" 15 | with pytest.raises(TypeError, match="Dimension name must be a string"): 16 | Dimension(name=123, values=["US", "CA", "UK"]) # Name should be a string 17 | 18 | 19 | def test_dimension_values_type(): 20 | """Test that Dimension raises TypeError if values is not a list of strings.""" 21 | # Values should be a list 22 | with pytest.raises(TypeError, match="Dimension values must be a list of strings"): 23 | Dimension(name="Country", values="US, CA, UK") # Should be a list of strings 24 | 25 | # Values should be a list of strings 26 | with pytest.raises(TypeError, match="Dimension values must be a list of strings"): 27 | Dimension( 28 | name="Country", values=["US", 123, "UK"] 29 | ) # All elements should be strings 30 | 31 | 32 | def test_dimension_iterate_dimension_values(): 33 | """Test Dimension iterate_dimension_values method to ensure unique values are returned.""" 34 | dim = Dimension(name="Country", values=["US", "CA", "US", "UK", "CA"]) 35 | unique_values = list(dim.iterate_dimension_values()) 36 | assert unique_values == ["US", "CA", "UK"] # Ensures unique, ordered values 37 | 38 | 39 | def test_default_dimension_initialization(): 40 | """Test DefaultDimension initialization.""" 41 | default_dim = DefaultDimension() 42 | assert default_dim.name == "__total_dimension" 43 | assert default_dim.values == ["total"] 44 | 45 | 46 | def test_default_dimension_iterate_dimension_values(): 47 | """Test that DefaultDimension's iterate_dimension_values yields 'total'.""" 48 | default_dim = DefaultDimension() 49 | values = list(default_dim.iterate_dimension_values()) 50 | assert values == ["total"] 51 | -------------------------------------------------------------------------------- /tests/inference/test_metric.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | from cluster_experiments.inference.metric import Metric, RatioMetric, SimpleMetric 5 | 6 | # Sample DataFrame for testing 7 | sample_data = pd.DataFrame( 8 | { 9 | "salary": [50000, 60000, 70000], 10 | "numerator": [150000, 180000, 210000], 11 | "denominator": [3, 3, 3], 12 | } 13 | ) 14 | 15 | 16 | def test_metric_abstract_instantiation(): 17 | """Test that Metric cannot be instantiated directly.""" 18 | with pytest.raises(TypeError): 19 | Metric("test_metric") 20 | 21 | 22 | def test_metric_alias_type(): 23 | """Test that Metric raises TypeError if alias is not a string.""" 24 | with pytest.raises(TypeError): 25 | SimpleMetric(123, "salary") # Alias should be a string 26 | 27 | 28 | def test_simple_metric_initialization(): 29 | """Test SimpleMetric initialization and target column.""" 30 | metric = SimpleMetric("test_metric", "salary") 31 | assert metric.alias == "test_metric" 32 | assert metric.target_column == "salary" 33 | 34 | 35 | def test_simple_metric_name_type(): 36 | """Test that SimpleMetric raises TypeError if name is not a string.""" 37 | with pytest.raises(TypeError): 38 | 
SimpleMetric("test_metric", 123) # Name should be a string 39 | 40 | 41 | def test_simple_metric_get_mean(): 42 | """Test SimpleMetric get_mean() calculation.""" 43 | metric = SimpleMetric("test_metric", "salary") 44 | mean_value = metric.get_mean(sample_data) 45 | assert mean_value == 60000 # Mean of [50000, 60000, 70000] 46 | 47 | 48 | def test_ratio_metric_initialization(): 49 | """Test RatioMetric initialization and target column.""" 50 | metric = RatioMetric("test_ratio_metric", "numerator", "denominator") 51 | assert metric.alias == "test_ratio_metric" 52 | assert metric.target_column == "numerator" 53 | 54 | 55 | def test_ratio_metric_names_type(): 56 | """Test that RatioMetric raises TypeError if numerator or denominator are not strings.""" 57 | with pytest.raises(TypeError): 58 | RatioMetric( 59 | "test_ratio_metric", "numerator", 123 60 | ) # Denominator should be a string 61 | with pytest.raises(TypeError): 62 | RatioMetric( 63 | "test_ratio_metric", 123, "denominator" 64 | ) # Numerator should be a string 65 | 66 | 67 | def test_ratio_metric_get_mean(): 68 | """Test RatioMetric get_mean() calculation.""" 69 | metric = RatioMetric("test_ratio_metric", "numerator", "denominator") 70 | mean_value = metric.get_mean(sample_data) 71 | expected_value = sample_data["numerator"].mean() / sample_data["denominator"].mean() 72 | assert mean_value == expected_value 73 | -------------------------------------------------------------------------------- /tests/inference/test_variant.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cluster_experiments.inference.variant import Variant 4 | 5 | 6 | def test_variant_initialization(): 7 | """Test Variant initialization with valid inputs.""" 8 | variant = Variant(name="Test Variant", is_control=True) 9 | assert variant.name == "Test Variant" 10 | assert variant.is_control is True 11 | 12 | 13 | def test_variant_name_type(): 14 | """Test that Variant raises TypeError if name is not a string.""" 15 | with pytest.raises(TypeError, match="Variant name must be a string"): 16 | Variant(name=123, is_control=True) # Name should be a string 17 | 18 | 19 | def test_variant_is_control_type(): 20 | """Test that Variant raises TypeError if is_control is not a boolean.""" 21 | with pytest.raises(TypeError, match="Variant is_control must be a boolean"): 22 | Variant(name="Test Variant", is_control="yes") # is_control should be a boolean 23 | 24 | 25 | def test_variant_is_control_default_behavior(): 26 | """Test Variant behavior when is_control is set to False.""" 27 | variant = Variant(name="Variant B", is_control=False) 28 | assert variant.name == "Variant B" 29 | assert variant.is_control is False 30 | -------------------------------------------------------------------------------- /tests/perturbator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/perturbator/__init__.py -------------------------------------------------------------------------------- /tests/perturbator/conftest.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | 5 | @pytest.fixture 6 | def binary_df(): 7 | return pd.DataFrame( 8 | { 9 | "target": [0, 1, 0, 1], 10 | "treatment": ["A", "B", "B", "A"], 11 | } 12 | ) 13 | 14 | 15 | @pytest.fixture 16 | def continuous_df(): 17 | return pd.DataFrame( 
18 | { 19 | "target": [0.5, 0.5, 0.5, 0.5], 20 | "treatment": ["A", "B", "B", "A"], 21 | } 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | def generate_clustered_data() -> pd.DataFrame: 27 | analysis_df = pd.DataFrame( 28 | { 29 | "country_code": ["ES"] * 4 + ["IT"] * 4 + ["PL"] * 4 + ["RO"] * 4, 30 | "city_code": ["BCN", "BCN", "MAD", "BCN"] 31 | + ["NAP"] * 4 32 | + ["WAW"] * 4 33 | + ["BUC"] * 4, 34 | "user_id": [1, 1, 2, 1, 3, 4, 5, 6, 7, 8, 8, 8, 9, 9, 9, 10], 35 | "date": ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"] * 4, 36 | "treatment": [ 37 | "A", 38 | "A", 39 | "B", 40 | "A", 41 | "B", 42 | "B", 43 | "A", 44 | "B", 45 | "B", 46 | "A", 47 | "A", 48 | "A", 49 | "B", 50 | "B", 51 | "B", 52 | "A", 53 | ], # Randomization is done at user level, so same user will always have same treatment 54 | "target": [0.01] * 15 + [0.1], 55 | } 56 | ) 57 | return analysis_df 58 | 59 | 60 | @pytest.fixture 61 | def continuous_mixed_df(): 62 | return pd.DataFrame( 63 | { 64 | "target": [0.5, -50, 50, 0.5], 65 | "treatment": ["A", "B", "B", "A"], 66 | } 67 | ) 68 | -------------------------------------------------------------------------------- /tests/power_analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/power_analysis/__init__.py -------------------------------------------------------------------------------- /tests/power_analysis/conftest.py: -------------------------------------------------------------------------------- 1 | from datetime import date, timedelta 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | from cluster_experiments.cupac import TargetAggregation 8 | from cluster_experiments.experiment_analysis import ( 9 | ClusteredOLSAnalysis, 10 | GeeExperimentAnalysis, 11 | MLMExperimentAnalysis, 12 | ) 13 | from cluster_experiments.perturbator import ConstantPerturbator 14 | from cluster_experiments.power_analysis import PowerAnalysis 15 | from cluster_experiments.random_splitter import ( 16 | ClusteredSplitter, 17 | StratifiedSwitchbackSplitter, 18 | ) 19 | from tests.utils import generate_random_data, generate_ratio_metric_data 20 | 21 | N = 1_000 22 | 23 | 24 | @pytest.fixture 25 | def clusters(): 26 | return [f"Cluster {i}" for i in range(100)] 27 | 28 | 29 | @pytest.fixture 30 | def dates(): 31 | return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)] 32 | 33 | 34 | @pytest.fixture 35 | def experiment_dates(): 36 | return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)] 37 | 38 | 39 | @pytest.fixture 40 | def df(clusters, dates): 41 | return generate_random_data(clusters, dates, N) 42 | 43 | 44 | @pytest.fixture 45 | def correlated_df(): 46 | _n_rows = 10_000 47 | _clusters = [f"Cluster {i}" for i in range(10)] 48 | _dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 15)] 49 | df = pd.DataFrame( 50 | { 51 | "cluster": np.random.choice(_clusters, size=_n_rows), 52 | "date": np.random.choice(_dates, size=_n_rows), 53 | } 54 | ).assign( 55 | # Target is a linear combination of cluster and day of week, plus some noise 56 | cluster_id=lambda df: df["cluster"].astype("category").cat.codes, 57 | day_of_week=lambda df: pd.to_datetime(df["date"]).dt.dayofweek, 58 | target=lambda df: df["cluster_id"] 59 | + df["day_of_week"] 60 | + np.random.normal(size=_n_rows), 61 | ) 62 | return df 63 | 64 | 65 | @pytest.fixture 66 | def df_feats(clusters, dates): 67 | df = 
generate_random_data(clusters, dates, N) 68 | df["x1"] = np.random.normal(0, 1, N) 69 | df["x2"] = np.random.normal(0, 1, N) 70 | return df 71 | 72 | 73 | @pytest.fixture 74 | def df_binary(clusters, dates): 75 | return generate_random_data(clusters, dates, N, target="binary") 76 | 77 | 78 | @pytest.fixture 79 | def perturbator(): 80 | return ConstantPerturbator(average_effect=0.1) 81 | 82 | 83 | @pytest.fixture 84 | def analysis_gee_vainilla(): 85 | return GeeExperimentAnalysis( 86 | cluster_cols=["cluster", "date"], 87 | ) 88 | 89 | 90 | @pytest.fixture 91 | def analysis_clusterd_ols(): 92 | return ClusteredOLSAnalysis( 93 | cluster_cols=["cluster", "date"], 94 | ) 95 | 96 | 97 | @pytest.fixture 98 | def analysis_mlm(): 99 | return MLMExperimentAnalysis( 100 | cluster_cols=["cluster", "date"], 101 | ) 102 | 103 | 104 | @pytest.fixture 105 | def analysis_gee(): 106 | return GeeExperimentAnalysis( 107 | cluster_cols=["cluster", "date"], 108 | covariates=["estimate_target"], 109 | ) 110 | 111 | 112 | @pytest.fixture 113 | def cupac_power_analysis(perturbator, analysis_gee): 114 | sw = ClusteredSplitter( 115 | cluster_cols=["cluster", "date"], 116 | ) 117 | 118 | target_agg = TargetAggregation( 119 | agg_col="cluster", 120 | ) 121 | 122 | return PowerAnalysis( 123 | perturbator=perturbator, 124 | splitter=sw, 125 | analysis=analysis_gee, 126 | cupac_model=target_agg, 127 | n_simulations=3, 128 | ) 129 | 130 | 131 | @pytest.fixture 132 | def switchback_power_analysis(perturbator, analysis_gee_vainilla): 133 | sw = StratifiedSwitchbackSplitter( 134 | time_col="date", 135 | switch_frequency="1D", 136 | strata_cols=["cluster"], 137 | cluster_cols=["cluster", "date"], 138 | ) 139 | 140 | return PowerAnalysis( 141 | perturbator=perturbator, 142 | splitter=sw, 143 | analysis=analysis_gee_vainilla, 144 | n_simulations=3, 145 | seed=123, 146 | ) 147 | 148 | 149 | @pytest.fixture 150 | def switchback_power_analysis_hourly(perturbator, analysis_gee_vainilla): 151 | sw = StratifiedSwitchbackSplitter( 152 | time_col="date", 153 | switch_frequency="1H", 154 | strata_cols=["cluster"], 155 | cluster_cols=["cluster", "date"], 156 | ) 157 | 158 | return PowerAnalysis( 159 | perturbator=perturbator, 160 | splitter=sw, 161 | analysis=analysis_gee_vainilla, 162 | n_simulations=3, 163 | ) 164 | 165 | 166 | @pytest.fixture 167 | def switchback_washover(): 168 | return PowerAnalysis.from_dict( 169 | { 170 | "time_col": "date", 171 | "switch_frequency": "1D", 172 | "perturbator": "constant", 173 | "analysis": "ols_clustered", 174 | "splitter": "switchback_balance", 175 | "cluster_cols": ["cluster", "date"], 176 | "strata_cols": ["cluster"], 177 | "washover": "constant_washover", 178 | "washover_time_delta": timedelta(hours=2), 179 | } 180 | ) 181 | 182 | 183 | @pytest.fixture 184 | def delta_df(experiment_dates): 185 | 186 | user_sample_mean = 0.3 187 | user_standard_error = 0.15 188 | users = 2000 189 | N = 50_000 190 | 191 | user_target_means = np.random.normal(user_sample_mean, user_standard_error, users) 192 | 193 | data = generate_ratio_metric_data( 194 | experiment_dates, N, user_target_means, users, treatment_effect=0 195 | ) 196 | return data 197 | -------------------------------------------------------------------------------- /tests/power_analysis/test_cupac_power.py: -------------------------------------------------------------------------------- 1 | from sklearn.ensemble import HistGradientBoostingRegressor 2 | 3 | 4 | def test_power_analysis_aggregate(df, experiment_dates, cupac_power_analysis): 5 |
df_analysis = df.query(f"date.isin({experiment_dates})") 6 | df_pre = df.query(f"~date.isin({experiment_dates})") 7 | power = cupac_power_analysis.power_analysis(df_analysis, df_pre) 8 | assert power >= 0 9 | assert power <= 1 10 | 11 | 12 | def test_add_covariates(df, experiment_dates, cupac_power_analysis): 13 | df_analysis = df.query(f"date.isin({experiment_dates})") 14 | df_pre = df.query(f"~date.isin({experiment_dates})") 15 | estimated_target = cupac_power_analysis.cupac_handler.add_covariates( 16 | df_analysis, df_pre 17 | )["estimate_target"] 18 | assert estimated_target.isnull().sum() == 0 19 | assert (estimated_target <= df_pre["target"].max()).all() 20 | assert (estimated_target >= df_pre["target"].min()).all() 21 | assert "estimate_target" in cupac_power_analysis.analysis.covariates 22 | 23 | 24 | def test_prep_data(df_feats, experiment_dates, cupac_power_analysis): 25 | df = df_feats.copy() 26 | df_analysis = df.query(f"date.isin({experiment_dates})") 27 | df_pre = df.query(f"~date.isin({experiment_dates})") 28 | cupac_power_analysis.cupac_handler.features_cupac_model = ["x1", "x2"] 29 | ( 30 | df_predict, 31 | pre_experiment_x, 32 | pre_experiment_y, 33 | ) = cupac_power_analysis.cupac_handler._prep_data_cupac(df_analysis, df_pre) 34 | assert list(df_predict.columns) == ["x1", "x2"] 35 | assert list(pre_experiment_x.columns) == ["x1", "x2"] 36 | assert (df_predict["x1"] == df_analysis["x1"]).all() 37 | assert (pre_experiment_x["x1"] == df_pre["x1"]).all() 38 | assert (pre_experiment_y == df_pre["target"]).all() 39 | 40 | 41 | def test_cupac_gbm(df_feats, experiment_dates, cupac_power_analysis): 42 | df = df_feats.copy() 43 | df_analysis = df.query(f"date.isin({experiment_dates})") 44 | df_pre = df.query(f"~date.isin({experiment_dates})") 45 | cupac_power_analysis.features_cupac_model = ["x1", "x2"] 46 | cupac_power_analysis.cupac_model = HistGradientBoostingRegressor() 47 | power = cupac_power_analysis.power_analysis(df_analysis, df_pre) 48 | assert power >= 0 49 | assert power <= 1 50 | -------------------------------------------------------------------------------- /tests/power_analysis/test_multivariate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cluster_experiments import ( 4 | ConstantPerturbator, 5 | NonClusteredSplitter, 6 | OLSAnalysis, 7 | PowerAnalysis, 8 | ) 9 | from tests.utils import generate_non_clustered_data 10 | 11 | 12 | @pytest.fixture 13 | def df(): 14 | return generate_non_clustered_data( 15 | N=1000, 16 | n_users=100, 17 | ) 18 | 19 | 20 | @pytest.fixture 21 | def perturbator(): 22 | return ConstantPerturbator() 23 | 24 | 25 | @pytest.fixture 26 | def ols(): 27 | return OLSAnalysis() 28 | 29 | 30 | @pytest.fixture 31 | def binary_hypothesis_power(perturbator, ols): 32 | splitter = NonClusteredSplitter( 33 | treatments=["A", "B"], 34 | treatment_col="treatment", 35 | ) 36 | return PowerAnalysis( 37 | splitter=splitter, 38 | perturbator=perturbator, 39 | analysis=ols, 40 | ) 41 | 42 | 43 | @pytest.fixture 44 | def multivariate_hypothesis_power(perturbator, ols): 45 | splitter = NonClusteredSplitter( 46 | treatments=["A", "B", "C", "D", "E", "F", "G"], 47 | treatment_col="treatment", 48 | ) 49 | return PowerAnalysis( 50 | splitter=splitter, 51 | perturbator=perturbator, 52 | analysis=ols, 53 | ) 54 | 55 | 56 | @pytest.fixture 57 | def binary_hypothesis_power_config(): 58 | config = { 59 | "analysis": "ols_non_clustered", 60 | "perturbator": "constant", 61 | "splitter": "non_clustered", 
62 | "n_simulations": 50, 63 | "seed": 220924, 64 | } 65 | return PowerAnalysis.from_dict(config) 66 | 67 | 68 | @pytest.fixture 69 | def multivariate_hypothesis_power_config(): 70 | config = { 71 | "analysis": "ols_non_clustered", 72 | "perturbator": "constant", 73 | "splitter": "non_clustered", 74 | "n_simulations": 50, 75 | "treatments": ["A", "B", "C", "D", "E", "F", "G"], 76 | "seed": 220924, 77 | } 78 | return PowerAnalysis.from_dict(config) 79 | 80 | 81 | def test_higher_power_analysis( 82 | multivariate_hypothesis_power, 83 | binary_hypothesis_power, 84 | df, 85 | ): 86 | power_multi = multivariate_hypothesis_power.power_analysis(df, average_effect=0.1) 87 | power_binary = binary_hypothesis_power.power_analysis(df, average_effect=0.1) 88 | assert power_multi < power_binary, f"{power_multi = } > {power_binary = }" 89 | 90 | 91 | def test_higher_power_analysis_config( 92 | multivariate_hypothesis_power_config, 93 | binary_hypothesis_power_config, 94 | df, 95 | ): 96 | power_multi = multivariate_hypothesis_power_config.power_analysis( 97 | df, average_effect=0.1 98 | ) 99 | power_binary = binary_hypothesis_power_config.power_analysis(df, average_effect=0.1) 100 | assert power_multi < power_binary, f"{power_multi = } > {power_binary = }" 101 | 102 | 103 | def test_raise_if_control_not_in_treatments( 104 | perturbator, 105 | ols, 106 | ): 107 | with pytest.raises(AssertionError): 108 | splitter = NonClusteredSplitter( 109 | treatments=["a", "B"], 110 | treatment_col="treatment", 111 | splitter_weights=[0.5, 0.5], 112 | ) 113 | PowerAnalysis( 114 | splitter=splitter, 115 | perturbator=perturbator, 116 | analysis=ols, 117 | ) 118 | with pytest.raises(AssertionError): 119 | splitter = NonClusteredSplitter( 120 | treatment_col="treatment", 121 | splitter_weights=[0.5, 0.5], 122 | ) 123 | PowerAnalysis( 124 | splitter=splitter, perturbator=perturbator, analysis=ols, control="X" 125 | ) 126 | -------------------------------------------------------------------------------- /tests/power_analysis/test_parallel.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cluster_experiments.power_analysis import PowerAnalysis 4 | from cluster_experiments.power_config import PowerConfig 5 | 6 | 7 | def test_raise_n_jobs(df): 8 | config = PowerConfig( 9 | cluster_cols=["cluster", "date"], 10 | analysis="gee", 11 | perturbator="constant", 12 | splitter="clustered", 13 | n_simulations=4, 14 | ) 15 | pw = PowerAnalysis.from_config(config) 16 | with pytest.raises(ValueError): 17 | pw.power_analysis(df, average_effect=0.0, n_jobs=0) 18 | with pytest.raises(ValueError): 19 | pw.power_analysis(df, average_effect=0.0, n_jobs=-10) 20 | 21 | 22 | def test_similar_n_jobs(df): 23 | config = PowerConfig( 24 | analysis="ols_non_clustered", 25 | perturbator="constant", 26 | splitter="non_clustered", 27 | n_simulations=100, 28 | seed=123, 29 | ) 30 | pw = PowerAnalysis.from_config(config) 31 | power = pw.power_analysis(df, average_effect=0.0, n_jobs=1) 32 | power2 = pw.power_analysis(df, average_effect=0.0, n_jobs=2) 33 | power3 = pw.power_analysis(df, average_effect=0.0, n_jobs=-1) 34 | assert abs(power - power2) <= 0.1 35 | assert abs(power - power3) <= 0.1 36 | -------------------------------------------------------------------------------- /tests/power_analysis/test_power_analysis.py: -------------------------------------------------------------------------------- 1 | from cluster_experiments.power_analysis import PowerAnalysis 2 | from 
cluster_experiments.power_config import PowerConfig 3 | from cluster_experiments.random_splitter import ClusteredSplitter 4 | 5 | 6 | def test_power_analysis(df, perturbator, analysis_gee_vainilla): 7 | sw = ClusteredSplitter( 8 | cluster_cols=["cluster", "date"], 9 | ) 10 | 11 | pw = PowerAnalysis( 12 | perturbator=perturbator, 13 | splitter=sw, 14 | analysis=analysis_gee_vainilla, 15 | n_simulations=3, 16 | ) 17 | 18 | power = pw.power_analysis(df) 19 | assert power >= 0 20 | assert power <= 1 21 | 22 | 23 | def test_power_analysis_config(df): 24 | config = PowerConfig( 25 | cluster_cols=["cluster", "date"], 26 | analysis="gee", 27 | perturbator="constant", 28 | splitter="clustered", 29 | n_simulations=4, 30 | average_effect=0.0, 31 | ) 32 | pw = PowerAnalysis.from_config(config) 33 | power = pw.power_analysis(df) 34 | assert power >= 0 35 | assert power <= 1 36 | 37 | 38 | def test_power_analysis_config_avg_effect(df): 39 | config = PowerConfig( 40 | cluster_cols=["cluster", "date"], 41 | analysis="gee", 42 | perturbator="constant", 43 | splitter="clustered", 44 | n_simulations=4, 45 | ) 46 | pw = PowerAnalysis.from_config(config) 47 | power = pw.power_analysis(df, average_effect=0.0) 48 | assert power >= 0 49 | assert power <= 1 50 | 51 | 52 | def test_power_analysis_dict(df): 53 | config = dict( 54 | analysis="ols_non_clustered", 55 | perturbator="constant", 56 | splitter="non_clustered", 57 | n_simulations=4, 58 | ) 59 | pw = PowerAnalysis.from_dict(config) 60 | power = pw.power_analysis(df, average_effect=0.0) 61 | assert power >= 0 62 | assert power <= 1 63 | 64 | power_verbose = pw.power_analysis(df, verbose=True, average_effect=0.0) 65 | assert power_verbose >= 0 66 | assert power_verbose <= 1 67 | 68 | 69 | def test_different_names(df): 70 | df = df.rename( 71 | columns={ 72 | "cluster": "cluster_0", 73 | "target": "target_0", 74 | "date": "date_0", 75 | } 76 | ) 77 | config = dict( 78 | cluster_cols=["cluster_0", "date_0"], 79 | analysis="ols_clustered", 80 | perturbator="constant", 81 | splitter="clustered", 82 | n_simulations=4, 83 | treatment_col="treatment_0", 84 | target_col="target_0", 85 | ) 86 | pw = PowerAnalysis.from_dict(config) 87 | power = pw.power_analysis(df, average_effect=0.0) 88 | assert power >= 0 89 | assert power <= 1 90 | 91 | power_verbose = pw.power_analysis(df, verbose=True, average_effect=0.0) 92 | assert power_verbose >= 0 93 | assert power_verbose <= 1 94 | 95 | 96 | def test_ttest(df): 97 | config = dict( 98 | cluster_cols=["cluster", "date"], 99 | analysis="ttest_clustered", 100 | perturbator="constant", 101 | splitter="clustered", 102 | n_simulations=4, 103 | ) 104 | pw = PowerAnalysis.from_dict(config) 105 | power = pw.power_analysis(df, average_effect=0.0) 106 | assert power >= 0 107 | assert power <= 1 108 | 109 | power_verbose = pw.power_analysis(df, verbose=True, average_effect=0.0) 110 | assert power_verbose >= 0 111 | assert power_verbose <= 1 112 | 113 | 114 | def test_paired_ttest(df): 115 | config = dict( 116 | cluster_cols=["cluster", "date"], 117 | strata_cols=["cluster"], 118 | analysis="paired_ttest_clustered", 119 | perturbator="constant", 120 | splitter="clustered", 121 | n_simulations=4, 122 | ) 123 | pw = PowerAnalysis.from_dict(config) 124 | 125 | power = pw.power_analysis(df, average_effect=0.0) 126 | assert power >= 0 127 | assert power <= 1 128 | 129 | power_verbose = pw.power_analysis(df, verbose=True, average_effect=0.0) 130 | assert power_verbose >= 0 131 | assert power_verbose <= 1 132 | 133 | 134 | def 
test_delta(delta_df): 135 | config = dict( 136 | cluster_cols=["user", "date"], 137 | scale_col="scale", 138 | analysis="delta", 139 | perturbator="constant", 140 | splitter="clustered", 141 | n_simulations=4, 142 | ) 143 | pw = PowerAnalysis.from_dict(config) 144 | 145 | delta_df = delta_df.drop(columns=["treatment"]) 146 | 147 | power = pw.power_analysis(delta_df, average_effect=0.0) 148 | assert power >= 0 149 | assert power <= 1 150 | 151 | 152 | def test_power_alpha(df): 153 | config = PowerConfig( 154 | analysis="ols_non_clustered", 155 | perturbator="constant", 156 | splitter="non_clustered", 157 | n_simulations=10, 158 | average_effect=0.0, 159 | alpha=0.05, 160 | ) 161 | pw = PowerAnalysis.from_config(config) 162 | power_50 = pw.power_analysis(df, alpha=0.5, verbose=True) 163 | power_01 = pw.power_analysis(df, alpha=0.01) 164 | 165 | assert power_50 > power_01 166 | 167 | 168 | def test_length_simulation(df): 169 | config = PowerConfig( 170 | cluster_cols=["cluster", "date"], 171 | analysis="ols_clustered", 172 | perturbator="constant", 173 | splitter="clustered", 174 | n_simulations=10, 175 | average_effect=0.0, 176 | alpha=0.05, 177 | ) 178 | pw = PowerAnalysis.from_config(config) 179 | i = 0 180 | for _ in pw.simulate_pvalue(df, n_simulations=5): 181 | i += 1 182 | assert i == 5 183 | 184 | 185 | def test_point_estimate_gee(df): 186 | config = PowerConfig( 187 | cluster_cols=["cluster", "date"], 188 | analysis="gee", 189 | perturbator="constant", 190 | splitter="clustered", 191 | n_simulations=10, 192 | average_effect=5.0, 193 | alpha=0.05, 194 | ) 195 | pw = PowerAnalysis.from_config(config) 196 | for point_estimate in pw.simulate_point_estimate(df, n_simulations=1): 197 | assert point_estimate > 0.0 198 | 199 | 200 | def test_point_estimate_clustered_ols(df): 201 | config = PowerConfig( 202 | cluster_cols=["cluster", "date"], 203 | analysis="ols_clustered", 204 | perturbator="constant", 205 | splitter="clustered", 206 | n_simulations=10, 207 | average_effect=5.0, 208 | alpha=0.05, 209 | ) 210 | pw = PowerAnalysis.from_config(config) 211 | for point_estimate in pw.simulate_point_estimate(df, n_simulations=1): 212 | assert point_estimate > 0.0 213 | 214 | 215 | def test_point_estimate_ols(df): 216 | config = PowerConfig( 217 | analysis="ols_non_clustered", 218 | perturbator="constant", 219 | splitter="non_clustered", 220 | n_simulations=10, 221 | average_effect=5.0, 222 | alpha=0.05, 223 | ) 224 | pw = PowerAnalysis.from_config(config) 225 | for point_estimate in pw.simulate_point_estimate(df, n_simulations=1): 226 | assert point_estimate > 0.0 227 | 228 | 229 | def test_power_line(df): 230 | config = PowerConfig( 231 | analysis="ols_non_clustered", 232 | perturbator="constant", 233 | splitter="non_clustered", 234 | n_simulations=10, 235 | average_effect=0.0, 236 | alpha=0.05, 237 | ) 238 | pw = PowerAnalysis.from_config(config) 239 | 240 | power_line = pw.power_line(df, average_effects=[0.0, 1.0], n_simulations=10) 241 | assert len(power_line) == 2 242 | assert power_line[0.0] >= 0 243 | assert power_line[1.0] >= power_line[0.0] 244 | 245 | 246 | def test_running_power(df): 247 | # given 248 | config = PowerConfig( 249 | analysis="ols_non_clustered", 250 | perturbator="constant", 251 | splitter="non_clustered", 252 | n_simulations=50, 253 | average_effect=0.0, 254 | alpha=0.05, 255 | ) 256 | pw = PowerAnalysis.from_config(config) 257 | previous_power = 0 258 | for i, power in enumerate(pw.running_power_analysis(df, average_effect=0.1)): 259 | # when: we've run enough 
simulations 260 | if i > 20: 261 | # then: the power should be stable (no changes bigger than 0.05) 262 | assert abs(power - previous_power) <= 0.05 263 | previous_power = power 264 | 265 | 266 | def test_hypothesis_from_dict(df): 267 | # given 268 | config_less = dict( 269 | analysis="ols_non_clustered", 270 | perturbator="constant", 271 | splitter="non_clustered", 272 | hypothesis="less", 273 | n_simulations=20, 274 | ) 275 | pw_less = PowerAnalysis.from_dict(config_less) 276 | 277 | config_greater = config_less.copy() 278 | config_greater["hypothesis"] = "greater" 279 | pw_greater = PowerAnalysis.from_dict(config_greater) 280 | 281 | config_two_sided = config_less.copy() 282 | config_two_sided["hypothesis"] = "two-sided" 283 | pw_two_sided = PowerAnalysis.from_dict(config_two_sided) 284 | 285 | # when 286 | power_less = pw_less.power_analysis(df, average_effect=1.0) 287 | power_greater = pw_greater.power_analysis(df, average_effect=1.0) 288 | power_two_sided = pw_two_sided.power_analysis(df, average_effect=1.0) 289 | 290 | # then 291 | assert power_less < power_greater 292 | assert power_less < power_two_sided 293 | -------------------------------------------------------------------------------- /tests/power_analysis/test_power_analysis_with_pre_experiment_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from cluster_experiments.experiment_analysis import SyntheticControlAnalysis 4 | from cluster_experiments.perturbator import ConstantPerturbator 5 | from cluster_experiments.power_analysis import PowerAnalysisWithPreExperimentData 6 | from cluster_experiments.random_splitter import FixedSizeClusteredSplitter 7 | from cluster_experiments.synthetic_control_utils import generate_synthetic_control_data 8 | 9 | 10 | def test_power_analysis_with_pre_experiment_data(): 11 | df = generate_synthetic_control_data(10, "2022-01-01", "2022-01-30") 12 | 13 | sw = FixedSizeClusteredSplitter(n_treatment_clusters=1, cluster_cols=["user"]) 14 | 15 | perturbator = ConstantPerturbator( 16 | average_effect=0.3, 17 | ) 18 | 19 | analysis = SyntheticControlAnalysis( 20 | cluster_cols=["user"], time_col="date", intervention_date="2022-01-15" 21 | ) 22 | 23 | pw = PowerAnalysisWithPreExperimentData( 24 | perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50 25 | ) 26 | 27 | power = pw.power_analysis(df) 28 | pw.power_line(df, average_effects=[0.3, 0.4]) 29 | assert 0 <= power <= 1 30 | values = list(pw.power_line(df, average_effects=[0.3, 0.4]).values()) 31 | assert all(0 <= value <= 1 for value in values) 32 | 33 | 34 | def test_simulate_point_estimate(): 35 | df = generate_synthetic_control_data(10, "2022-01-01", "2022-01-30") 36 | 37 | sw = FixedSizeClusteredSplitter(n_treatment_clusters=1, cluster_cols=["user"]) 38 | 39 | perturbator = ConstantPerturbator( 40 | average_effect=10, 41 | ) 42 | 43 | analysis = SyntheticControlAnalysis( 44 | cluster_cols=["user"], time_col="date", intervention_date="2022-01-15" 45 | ) 46 | 47 | pw = PowerAnalysisWithPreExperimentData( 48 | perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50 49 | ) 50 | 51 | point_estimates = list(pw.simulate_point_estimate(df)) 52 | assert ( 53 | 8 <= pd.Series(point_estimates).mean() <= 11 54 | ), "Point estimate is too far from the real effect." 
55 | -------------------------------------------------------------------------------- /tests/power_analysis/test_power_raises.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis 4 | from cluster_experiments.perturbator import ConstantPerturbator 5 | from cluster_experiments.power_analysis import PowerAnalysis 6 | from cluster_experiments.random_splitter import ClusteredSplitter, NonClusteredSplitter 7 | 8 | 9 | def test_raises_cupac(): 10 | config = dict( 11 | cluster_cols=["cluster", "date"], 12 | analysis="gee", 13 | perturbator="constant", 14 | splitter="clustered", 15 | cupac_model="mean_cupac_model", 16 | n_simulations=4, 17 | ) 18 | with pytest.raises(AssertionError): 19 | PowerAnalysis.from_dict(config) 20 | 21 | 22 | def test_data_checks(df): 23 | config = dict( 24 | cluster_cols=["cluster", "date"], 25 | analysis="gee", 26 | perturbator="constant", 27 | splitter="clustered", 28 | n_simulations=4, 29 | ) 30 | pw = PowerAnalysis.from_dict(config) 31 | df["target"] = df["target"] == 1 32 | with pytest.raises(ValueError): 33 | pw.power_analysis(df, average_effect=0.0) 34 | 35 | 36 | def test_raise_target(): 37 | sw = ClusteredSplitter( 38 | cluster_cols=["cluster", "date"], 39 | ) 40 | 41 | perturbator = ConstantPerturbator( 42 | average_effect=0.1, 43 | target_col="another_target", 44 | ) 45 | 46 | analysis = GeeExperimentAnalysis( 47 | cluster_cols=["cluster", "date"], 48 | ) 49 | 50 | with pytest.raises(AssertionError): 51 | PowerAnalysis( 52 | perturbator=perturbator, 53 | splitter=sw, 54 | analysis=analysis, 55 | n_simulations=3, 56 | ) 57 | 58 | 59 | def test_raise_treatment(): 60 | sw = ClusteredSplitter( 61 | cluster_cols=["cluster", "date"], 62 | ) 63 | 64 | perturbator = ConstantPerturbator(average_effect=0.1, treatment="C") 65 | 66 | analysis = GeeExperimentAnalysis( 67 | cluster_cols=["cluster", "date"], 68 | ) 69 | 70 | with pytest.raises(AssertionError): 71 | PowerAnalysis( 72 | perturbator=perturbator, 73 | splitter=sw, 74 | analysis=analysis, 75 | n_simulations=3, 76 | ) 77 | 78 | 79 | def test_raise_treatment_col(): 80 | sw = ClusteredSplitter( 81 | cluster_cols=["cluster", "date"], 82 | ) 83 | 84 | perturbator = ConstantPerturbator( 85 | average_effect=0.1, 86 | treatment_col="another_treatment", 87 | ) 88 | 89 | analysis = GeeExperimentAnalysis( 90 | cluster_cols=["cluster", "date"], 91 | ) 92 | 93 | with pytest.raises(AssertionError): 94 | PowerAnalysis( 95 | perturbator=perturbator, 96 | splitter=sw, 97 | analysis=analysis, 98 | n_simulations=3, 99 | ) 100 | 101 | 102 | def test_raise_treatment_col_2(): 103 | sw = ClusteredSplitter( 104 | cluster_cols=["cluster", "date"], 105 | ) 106 | 107 | perturbator = ConstantPerturbator( 108 | average_effect=0.1, 109 | ) 110 | 111 | analysis = GeeExperimentAnalysis( 112 | cluster_cols=["cluster", "date"], 113 | treatment_col="another_treatment", 114 | ) 115 | 116 | with pytest.raises(AssertionError): 117 | PowerAnalysis( 118 | perturbator=perturbator, 119 | splitter=sw, 120 | analysis=analysis, 121 | n_simulations=3, 122 | ) 123 | 124 | 125 | def test_raise_cluster_cols(): 126 | sw = ClusteredSplitter( 127 | cluster_cols=["cluster"], 128 | ) 129 | 130 | perturbator = ConstantPerturbator( 131 | average_effect=0.1, 132 | target_col="another_target", 133 | ) 134 | 135 | analysis = GeeExperimentAnalysis( 136 | cluster_cols=["cluster", "date"], 137 | ) 138 | 139 | with pytest.raises(AssertionError): 
140 | PowerAnalysis( 141 | perturbator=perturbator, 142 | splitter=sw, 143 | analysis=analysis, 144 | n_simulations=3, 145 | ) 146 | 147 | 148 | def test_raise_clustering_mismatch(): 149 | sw = NonClusteredSplitter() 150 | 151 | perturbator = ConstantPerturbator( 152 | average_effect=0.1, 153 | target_col="another_target", 154 | ) 155 | 156 | analysis = GeeExperimentAnalysis( 157 | cluster_cols=["cluster", "date"], 158 | ) 159 | 160 | with pytest.raises(AssertionError): 161 | PowerAnalysis( 162 | perturbator=perturbator, 163 | splitter=sw, 164 | analysis=analysis, 165 | n_simulations=3, 166 | ) 167 | 168 | 169 | def test_raise_treatment_same_control(): 170 | sw = ClusteredSplitter(cluster_cols=["cluster", "date"]) 171 | 172 | perturbator = ConstantPerturbator(average_effect=0.1) 173 | 174 | analysis = GeeExperimentAnalysis(cluster_cols=["cluster", "date"]) 175 | 176 | with pytest.raises(AssertionError): 177 | PowerAnalysis( 178 | perturbator=perturbator, 179 | splitter=sw, 180 | analysis=analysis, 181 | treatment="A", 182 | control="A", # same as treatment 183 | n_simulations=3, 184 | ) 185 | -------------------------------------------------------------------------------- /tests/power_analysis/test_seed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from cluster_experiments.power_analysis import PowerAnalysis 4 | 5 | 6 | def get_config(perturbator: str) -> dict: 7 | return { 8 | "cluster_cols": ["cluster"], 9 | "analysis": "ols_clustered", 10 | "splitter": "clustered", 11 | "n_simulations": 15, 12 | "seed": 123, 13 | "perturbator": perturbator, 14 | } 15 | 16 | 17 | def test_power_analysis_constant_perturbator_seed(df): 18 | config_dict = get_config("constant") 19 | 20 | powers = [] 21 | for _ in range(10): 22 | pw = PowerAnalysis.from_dict(config_dict) 23 | powers.append(pw.power_analysis(df, average_effect=10)) 24 | 25 | assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10) 26 | 27 | 28 | def test_power_analysis_binary_perturbator_seed(df_binary): 29 | config_dict = get_config("binary") 30 | 31 | powers = [] 32 | for _ in range(10): 33 | pw = PowerAnalysis.from_dict(config_dict) 34 | powers.append(pw.power_analysis(df_binary, average_effect=0.08)) 35 | 36 | assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10) 37 | -------------------------------------------------------------------------------- /tests/power_analysis/test_switchback_power.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | from cluster_experiments.power_analysis import PowerAnalysis 8 | from cluster_experiments.washover import ConstantWashover 9 | 10 | 11 | def test_switchback(switchback_power_analysis, df): 12 | power = switchback_power_analysis.power_analysis( 13 | df, 14 | average_effect=0.0, 15 | verbose=True, 16 | ) 17 | assert power >= 0 18 | assert power <= 1 19 | 20 | 21 | def test_switchback_hour(switchback_power_analysis, df): 22 | # Random timestamps in the 24h window 2021-06-25 18:40 - 2021-06-26 18:40 UTC 23 | df["date"] = pd.to_datetime( 24 | np.random.randint( 25 | 1624646400, 26 | 1624732800, 27 | size=len(df), 28 | ), 29 | unit="s", 30 | ) 31 | power = switchback_power_analysis.power_analysis( 32 | df, 33 | average_effect=0.0, 34 | verbose=True, 35 | ) 36 | assert power >= 0 37 | assert power <= 1 38 | 39 | 40 | def test_switchback_washover(switchback_power_analysis, df): 41 | power_no_washover =
switchback_power_analysis.power_analysis( 42 | df, 43 | average_effect=0.1, 44 | n_simulations=10, 45 | ) 46 | switchback_power_analysis.splitter.washover = ConstantWashover( 47 | washover_time_delta=datetime.timedelta(hours=23) 48 | ) 49 | 50 | power = switchback_power_analysis.power_analysis( 51 | df, 52 | average_effect=0.1, 53 | n_simulations=10, 54 | ) 55 | assert power >= 0 56 | assert power <= 1 57 | assert power_no_washover >= power 58 | 59 | 60 | def test_raise_no_delta(): 61 | with pytest.raises(ValueError): 62 | PowerAnalysis.from_dict( 63 | { 64 | "time_col": "date", 65 | "switch_frequency": "1D", 66 | "perturbator": "constant", 67 | "analysis": "ols_clustered", 68 | "splitter": "switchback_balance", 69 | "cluster_cols": ["cluster", "date"], 70 | "strata_cols": ["cluster"], 71 | "washover": "constant_washover", 72 | } 73 | ) 74 | 75 | 76 | def test_switchback_washover_config(switchback_washover, df): 77 | power = switchback_washover.power_analysis( 78 | df, 79 | average_effect=0.1, 80 | n_simulations=10, 81 | ) 82 | assert power >= 0 83 | assert power <= 1 84 | 85 | 86 | def test_switchback_strata(): 87 | # Define bihourly switchback splitter 88 | config = { 89 | "time_col": "time", 90 | "switch_frequency": "30min", 91 | "perturbator": "constant", 92 | "analysis": "ols_clustered", 93 | "splitter": "switchback_stratified", 94 | "cluster_cols": ["time", "city"], 95 | "strata_cols": ["day_of_week", "hour_of_day", "city"], 96 | "target_col": "y", 97 | "n_simulations": 3, 98 | } 99 | 100 | power = PowerAnalysis.from_dict(config) 101 | np.random.seed(42) 102 | df_raw = pd.DataFrame( 103 | { 104 | "time": pd.date_range("2021-01-01", "2021-01-10 23:59", freq="1T"), 105 | "y": np.random.randn(10 * 24 * 60), 106 | } 107 | ).assign( 108 | day_of_week=lambda df: df.time.dt.dayofweek, 109 | hour_of_day=lambda df: df.time.dt.hour, 110 | ) 111 | df = pd.concat([df_raw.assign(city=city) for city in ("TGN", "NYC", "LON")]) 112 | pw = power.power_analysis(df, average_effect=0.1) 113 | assert pw >= 0 114 | assert pw <= 1 115 | -------------------------------------------------------------------------------- /tests/power_config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/power_config/__init__.py -------------------------------------------------------------------------------- /tests/power_config/test_missing_arguments_error.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from cluster_experiments.power_config import MissingArgumentError, PowerConfig 4 | 5 | 6 | def test_missing_argument_segment_cols(caplog): 7 | msg = "segment_cols is required when using perturbator = segmented_beta_relative." 
8 | with pytest.raises(MissingArgumentError, match=msg): 9 | PowerConfig( 10 | cluster_cols=["cluster"], 11 | analysis="mlm", 12 | perturbator="segmented_beta_relative", 13 | splitter="non_clustered", 14 | n_simulations=4, 15 | average_effect=1.5, 16 | ) 17 | -------------------------------------------------------------------------------- /tests/power_config/test_params_flow.py: -------------------------------------------------------------------------------- 1 | from cluster_experiments.power_analysis import NormalPowerAnalysis 2 | 3 | 4 | def test_cov_type_flows(): 5 | # given 6 | config = { 7 | "analysis": "ols_non_clustered", 8 | "perturbator": "constant", 9 | "splitter": "non_clustered", 10 | "cov_type": "HC1", 11 | } 12 | 13 | # when 14 | power_analysis = NormalPowerAnalysis.from_dict(config) 15 | 16 | # then 17 | assert power_analysis.analysis.cov_type == "HC1" 18 | 19 | 20 | def test_cov_type_default(): 21 | # given 22 | config = { 23 | "analysis": "ols_non_clustered", 24 | "perturbator": "constant", 25 | "splitter": "non_clustered", 26 | } 27 | 28 | # when 29 | power_analysis = NormalPowerAnalysis.from_dict(config) 30 | 31 | # then 32 | assert power_analysis.analysis.cov_type == "HC3" 33 | 34 | 35 | def test_covariate_interaction_flows(): 36 | # given 37 | config = { 38 | "analysis": "ols_non_clustered", 39 | "perturbator": "constant", 40 | "splitter": "non_clustered", 41 | "add_covariate_interaction": True, 42 | } 43 | 44 | # when 45 | power_analysis = NormalPowerAnalysis.from_dict(config) 46 | 47 | # then 48 | assert power_analysis.analysis.add_covariate_interaction is True 49 | -------------------------------------------------------------------------------- /tests/power_config/test_warnings_superfluous_params.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from cluster_experiments.power_config import PowerConfig 4 | 5 | 6 | def test_config_warning_superfluous_param_switch_frequency(caplog): 7 | msg = "switch_frequency = 1H has no effect with splitter = non_clustered. Overriding switch_frequency to None." 8 | with caplog.at_level(logging.WARNING): 9 | PowerConfig( 10 | cluster_cols=["cluster", "date"], 11 | analysis="ols_non_clustered", 12 | perturbator="constant", 13 | splitter="non_clustered", 14 | n_simulations=4, 15 | average_effect=1.5, 16 | switch_frequency="1H", 17 | ) 18 | assert msg in caplog.text 19 | 20 | 21 | def test_config_warning_superfluous_param_washover_time_delta(caplog): 22 | msg = "washover_time_delta = 30 has no effect with splitter = non_clustered. Overriding washover_time_delta to None." 23 | with caplog.at_level(logging.WARNING): 24 | PowerConfig( 25 | cluster_cols=["cluster", "date"], 26 | analysis="ols_non_clustered", 27 | perturbator="constant", 28 | splitter="non_clustered", 29 | n_simulations=4, 30 | average_effect=1.5, 31 | washover_time_delta=30, 32 | ) 33 | assert msg in caplog.text 34 | 35 | 36 | def test_config_warning_superfluous_param_washover(caplog): 37 | msg = "washover = constant_washover has no effect with splitter = non_clustered. Overriding washover to ." 
38 | with caplog.at_level(logging.WARNING): 39 | PowerConfig( 40 | cluster_cols=["cluster", "date"], 41 | analysis="ols_non_clustered", 42 | perturbator="constant", 43 | splitter="non_clustered", 44 | n_simulations=4, 45 | average_effect=1.5, 46 | washover="constant_washover", 47 | ) 48 | assert msg in caplog.text 49 | 50 | 51 | def test_config_warning_superfluous_param_time_col(caplog): 52 | msg = "time_col = datetime has no effect with splitter = non_clustered. Overriding time_col to None." 53 | with caplog.at_level(logging.WARNING): 54 | PowerConfig( 55 | cluster_cols=["cluster", "date"], 56 | analysis="ols_non_clustered", 57 | perturbator="constant", 58 | splitter="non_clustered", 59 | n_simulations=4, 60 | average_effect=1.5, 61 | time_col="datetime", 62 | ) 63 | assert msg in caplog.text 64 | 65 | 66 | def test_config_warning_superfluous_param_perturbator(caplog): 67 | msg = "scale = 0.5 has no effect with perturbator = constant. Overriding scale to None." 68 | with caplog.at_level(logging.WARNING): 69 | PowerConfig( 70 | cluster_cols=["cluster", "date"], 71 | analysis="ols_non_clustered", 72 | perturbator="constant", 73 | splitter="non_clustered", 74 | n_simulations=4, 75 | average_effect=1.5, 76 | scale=0.5, 77 | ) 78 | assert msg in caplog.text 79 | 80 | 81 | def test_config_warning_superfluous_param_strata_cols(caplog): 82 | msg = "strata_cols = ['group'] has no effect with splitter = non_clustered. Overriding strata_cols to None." 83 | with caplog.at_level(logging.WARNING): 84 | PowerConfig( 85 | cluster_cols=["cluster", "date"], 86 | analysis="ols_non_clustered", 87 | perturbator="constant", 88 | splitter="non_clustered", 89 | n_simulations=4, 90 | average_effect=1.5, 91 | strata_cols=["group"], 92 | ) 93 | assert msg in caplog.text 94 | 95 | 96 | def test_config_warning_superfluous_param_splitter_weights(caplog): 97 | msg = "splitter_weights = [0.5, 0.5] has no effect with splitter = clustered_stratified. Overriding splitter_weights to None." 98 | with caplog.at_level(logging.WARNING): 99 | PowerConfig( 100 | cluster_cols=["cluster", "date"], 101 | analysis="ols_non_clustered", 102 | perturbator="constant", 103 | splitter="clustered_stratified", 104 | splitter_weights=[0.5, 0.5], 105 | n_simulations=4, 106 | average_effect=1.5, 107 | ) 108 | assert msg in caplog.text 109 | 110 | 111 | def test_config_warning_superfluous_param_agg_col(caplog): 112 | msg = "agg_col = agg_col has no effect with cupac_model = . Overriding agg_col to ." 113 | with caplog.at_level(logging.WARNING): 114 | PowerConfig( 115 | cluster_cols=["cluster", "date"], 116 | analysis="ols_non_clustered", 117 | perturbator="constant", 118 | splitter="non_clustered", 119 | n_simulations=4, 120 | average_effect=1.5, 121 | agg_col="agg_col", 122 | ) 123 | assert msg in caplog.text 124 | 125 | 126 | def test_config_warning_superfluous_param_smoothing_factor(caplog): 127 | msg = "smoothing_factor = 0.5 has no effect with cupac_model = . Overriding smoothing_factor to 20." 128 | with caplog.at_level(logging.WARNING): 129 | PowerConfig( 130 | cluster_cols=["cluster", "date"], 131 | analysis="ols_non_clustered", 132 | perturbator="constant", 133 | splitter="non_clustered", 134 | n_simulations=4, 135 | average_effect=1.5, 136 | smoothing_factor=0.5, 137 | ) 138 | assert msg in caplog.text 139 | 140 | 141 | def test_config_warning_superfluous_param_features_cupac_model(caplog): 142 | msg = "features_cupac_model = ['feature1'] has no effect with cupac_model = . Overriding features_cupac_model to None." 
143 | with caplog.at_level(logging.WARNING): 144 | PowerConfig( 145 | cluster_cols=["cluster", "date"], 146 | analysis="ols_clustered", 147 | perturbator="constant", 148 | splitter="non_clustered", 149 | n_simulations=4, 150 | average_effect=1.5, 151 | features_cupac_model=["feature1"], 152 | ) 153 | assert msg in caplog.text 154 | 155 | 156 | def test_config_warning_superfluous_param_covariates(caplog): 157 | msg = "covariates = ['covariate1'] has no effect with analysis = ttest_clustered. Overriding covariates to None." 158 | with caplog.at_level(logging.WARNING): 159 | PowerConfig( 160 | cluster_cols=["cluster", "date"], 161 | analysis="ttest_clustered", 162 | perturbator="constant", 163 | splitter="non_clustered", 164 | n_simulations=4, 165 | average_effect=1.5, 166 | covariates=["covariate1"], 167 | ) 168 | assert msg in caplog.text 169 | -------------------------------------------------------------------------------- /tests/splitter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/splitter/__init__.py -------------------------------------------------------------------------------- /tests/splitter/conftest.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | from cluster_experiments.power_analysis import PowerAnalysis 8 | from cluster_experiments.random_splitter import ( 9 | BalancedSwitchbackSplitter, 10 | StratifiedSwitchbackSplitter, 11 | SwitchbackSplitter, 12 | ) 13 | 14 | 15 | @pytest.fixture 16 | def switchback_splitter(): 17 | return SwitchbackSplitter( 18 | time_col="time", switch_frequency="1D", cluster_cols=["time"] 19 | ) 20 | 21 | 22 | @pytest.fixture 23 | def switchback_splitter_config(): 24 | config = { 25 | "time_col": "time", 26 | "switch_frequency": "1D", 27 | "perturbator": "constant", 28 | "analysis": "ols_clustered", 29 | "splitter": "switchback", 30 | "cluster_cols": ["time"], 31 | } 32 | return PowerAnalysis.from_dict(config).splitter 33 | 34 | 35 | switchback_splitter_parametrize = pytest.mark.parametrize( 36 | "splitter", 37 | [ 38 | "switchback_splitter", 39 | "switchback_splitter_config", 40 | ], 41 | ) 42 | 43 | 44 | @pytest.fixture 45 | def balanced_splitter(): 46 | return BalancedSwitchbackSplitter( 47 | time_col="time", switch_frequency="1D", cluster_cols=["time"] 48 | ) 49 | 50 | 51 | @pytest.fixture 52 | def balanced_splitter_config(): 53 | config = { 54 | "time_col": "time", 55 | "switch_frequency": "1D", 56 | "perturbator": "constant", 57 | "analysis": "ols_clustered", 58 | "splitter": "switchback_balance", 59 | "cluster_cols": ["time"], 60 | } 61 | return PowerAnalysis.from_dict(config).splitter 62 | 63 | 64 | balanced_splitter_parametrize = pytest.mark.parametrize( 65 | "splitter", 66 | [ 67 | "balanced_splitter", 68 | "balanced_splitter_config", 69 | ], 70 | ) 71 | 72 | 73 | @pytest.fixture 74 | def stratified_switchback_splitter(): 75 | return StratifiedSwitchbackSplitter( 76 | time_col="time", 77 | switch_frequency="1D", 78 | strata_cols=["day_of_week"], 79 | cluster_cols=["time"], 80 | ) 81 | 82 | 83 | @pytest.fixture 84 | def stratified_switchback_splitter_config(): 85 | config = { 86 | "time_col": "time", 87 | "switch_frequency": "1D", 88 | "perturbator": "constant", 89 | "analysis": "ols_clustered", 90 | "splitter": "switchback_stratified", 91 | "cluster_cols": 
["time"], 92 | "strata_cols": ["day_of_week"], 93 | } 94 | return PowerAnalysis.from_dict(config).splitter 95 | 96 | 97 | stratified_splitter_parametrize = pytest.mark.parametrize( 98 | "splitter", 99 | [ 100 | "stratified_switchback_splitter", 101 | "stratified_switchback_splitter_config", 102 | ], 103 | ) 104 | 105 | 106 | @pytest.fixture 107 | def date_df(): 108 | return pd.DataFrame( 109 | { 110 | "time": pd.date_range("2020-01-01", "2020-01-10", freq="1D"), 111 | "y": np.random.randn(10), 112 | } 113 | ) 114 | 115 | 116 | @pytest.fixture 117 | def biweekly_df(): 118 | df = pd.DataFrame( 119 | { 120 | "time": pd.date_range("2020-01-01", "2020-01-14", freq="1D"), 121 | "y": np.random.randn(14), 122 | } 123 | ).assign( 124 | day_of_week=lambda df: df["time"].dt.day_name(), 125 | ) 126 | return pd.concat([df.assign(cluster=f"Cluster {i}") for i in range(10)]) 127 | 128 | 129 | @pytest.fixture 130 | def washover_df(): 131 | # Define data with random dates 132 | df_raw = pd.DataFrame( 133 | { 134 | "time": pd.date_range("2021-01-01", "2021-01-02", freq="1min")[ 135 | np.random.randint(24 * 60, size=7 * 24 * 60) 136 | ], 137 | "y": np.random.randn(7 * 24 * 60), 138 | } 139 | ).assign( 140 | day_of_week=lambda df: df.time.dt.dayofweek, 141 | hour_of_day=lambda df: df.time.dt.hour, 142 | ) 143 | df = pd.concat([df_raw.assign(city=city) for city in ("TGN", "NYC", "LON", "REU")]) 144 | return df 145 | 146 | 147 | @pytest.fixture 148 | def washover_base_df(): 149 | df = pd.DataFrame( 150 | { 151 | "original___time": [ 152 | pd.to_datetime("2022-01-01 00:20:00"), 153 | pd.to_datetime("2022-01-01 00:31:00"), 154 | pd.to_datetime("2022-01-01 01:14:00"), 155 | pd.to_datetime("2022-01-01 01:31:00"), 156 | ], 157 | "treatment": ["A", "A", "B", "B"], 158 | "city": ["TGN"] * 4, 159 | } 160 | ).assign(time=lambda x: x["original___time"].dt.floor("1h")) 161 | return df 162 | 163 | 164 | @pytest.fixture 165 | def washover_df_no_switch(): 166 | df = pd.DataFrame( 167 | { 168 | "original___time": [ 169 | pd.to_datetime("2022-01-01 00:20:00"), 170 | pd.to_datetime("2022-01-01 00:31:00"), 171 | pd.to_datetime("2022-01-01 01:14:00"), 172 | pd.to_datetime("2022-01-01 01:31:00"), 173 | pd.to_datetime("2022-01-01 02:01:00"), 174 | pd.to_datetime("2022-01-01 02:31:00"), 175 | ], 176 | "treatment": ["A", "A", "B", "B", "B", "B"], 177 | "city": ["TGN"] * 6, 178 | } 179 | ).assign(time=lambda x: x["original___time"].dt.floor("1h")) 180 | return df 181 | 182 | 183 | @pytest.fixture 184 | def washover_df_multi_city(): 185 | df = pd.DataFrame( 186 | { 187 | "original___time": [ 188 | pd.to_datetime("2022-01-01 01:14:00"), 189 | pd.to_datetime("2022-01-01 00:20:00"), 190 | pd.to_datetime("2022-01-01 00:31:00"), 191 | pd.to_datetime("2022-01-01 02:01:00"), 192 | pd.to_datetime("2022-01-01 01:31:00"), 193 | pd.to_datetime("2022-01-01 02:31:00"), 194 | ] 195 | * 2, 196 | "treatment": ["B", "A", "A", "B", "B", "B"] 197 | + ["A", "A", "A", "B", "A", "B"], 198 | "city": ["TGN"] * 6 + ["BCN"] * 6, 199 | } 200 | ).assign(time=lambda x: x["original___time"].dt.floor("1h")) 201 | return df 202 | 203 | 204 | @pytest.fixture 205 | def washover_split_df(n): 206 | # Return 207 | return pd.DataFrame( 208 | { 209 | # Random time each minute in 2022-01-01, length 1000 210 | "time": pd.date_range("2022-01-01", "2022-01-02", freq="1min")[ 211 | np.random.randint(24 * 60, size=n) 212 | ], 213 | "city": random.choices(["TGN", "NYC", "LON", "REU"], k=n), 214 | } 215 | ) 216 | 217 | 218 | @pytest.fixture 219 | def washover_split_no_city_df(n): 
220 | # n rows of random event times, without a city column 221 | return pd.DataFrame( 222 | { 223 | # Random minute-level timestamps within 2022-01-01, n rows 224 | "time": pd.date_range("2022-01-01", "2022-01-02", freq="1min")[ 225 | np.random.randint(24 * 60, size=n) 226 | ], 227 | } 228 | ) 229 | -------------------------------------------------------------------------------- /tests/splitter/test_fixed_size_clusters_splitter.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cluster_experiments import ConstantPerturbator, PowerAnalysis 7 | from cluster_experiments.experiment_analysis import ClusteredOLSAnalysis 8 | from cluster_experiments.random_splitter import ( 9 | FixedSizeClusteredSplitter, 10 | ) 11 | 12 | 13 | def test_predefined_treatment_clusters_splitter(): 14 | # Create a DataFrame with mock data 15 | df = pd.DataFrame({"cluster": ["A", "A", "B", "B", "C", "C", "D", "D", "E", "E"]}) 16 | 17 | split = FixedSizeClusteredSplitter(cluster_cols=["cluster"], n_treatment_clusters=1) 18 | 19 | df = split.assign_treatment_df(df) 20 | 21 | # Verify that the treatments were assigned correctly 22 | assert df[split.treatment_col].value_counts()[split.treatments[0]] == 8 23 | assert df[split.treatment_col].value_counts()[split.treatments[1]] == 2 24 | 25 | 26 | def test_sample_treatment_with_balanced_clusters(): 27 | splitter = FixedSizeClusteredSplitter(cluster_cols=["city"], n_treatment_clusters=2) 28 | df = pd.DataFrame({"city": ["A", "B", "C", "D"]}) 29 | treatments = splitter.sample_treatment(df) 30 | assert len(treatments) == len(df) 31 | assert treatments.count("A") == 2 32 | assert treatments.count("B") == 2 33 | 34 | 35 | def generate_data(N, start_date, end_date): 36 | dates = pd.date_range(start_date, end_date, freq="d") 37 | 38 | users = [f"User {i}" for i in range(N)] 39 | 40 | combinations = list(product(users, dates)) 41 | 42 | target_values = np.random.normal(100, 10, size=len(combinations)) 43 | 44 | df = pd.DataFrame(combinations, columns=["user", "date"]) 45 | df["target"] = target_values 46 | 47 | return df 48 | 49 | 50 | def test_ols_fixed_size_treatment(): 51 | df = generate_data(100, "2021-01-01", "2021-01-15") 52 | 53 | analysis = ClusteredOLSAnalysis(cluster_cols=["user"]) 54 | 55 | sw = FixedSizeClusteredSplitter(n_treatment_clusters=1, cluster_cols=["user"]) 56 | 57 | perturbator = ConstantPerturbator( 58 | average_effect=0, 59 | ) 60 | 61 | pw = PowerAnalysis( 62 | perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=200 63 | ) 64 | pw.power_analysis(df, average_effect=0) 65 | # TODO: finish this test; with average_effect=0 the estimated power is the type I error rate, so it should not be much above the significance level 66 | -------------------------------------------------------------------------------- /tests/splitter/test_switchback_splitter.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | from cluster_experiments.experiment_analysis import OLSAnalysis 5 | from cluster_experiments.perturbator import ConstantPerturbator 6 | from cluster_experiments.power_analysis import PowerAnalysis 7 | from cluster_experiments.random_splitter import SwitchbackSplitter 8 | from tests.splitter.conftest import ( 9 | balanced_splitter_parametrize, 10 | stratified_splitter_parametrize, 11 | switchback_splitter_parametrize, 12 | ) 13 | 14 | add_cluster_cols_parametrize = pytest.mark.parametrize( 15 | "add_cluster_cols", [True, False] 16 | ) 17 | 18 | 19 |
@switchback_splitter_parametrize 20 | def test_switchback_splitter(splitter, date_df, request): 21 | switchback_splitter = request.getfixturevalue(splitter) 22 | 23 | treatment_assignment = switchback_splitter.assign_treatment_df(date_df) 24 | assert "time" in switchback_splitter.cluster_cols 25 | 26 | # Only 1 treatment per date 27 | assert (treatment_assignment.groupby("time")["treatment"].nunique() == 1).all() 28 | 29 | 30 | @switchback_splitter_parametrize 31 | def test_clustered_switchback_splitter(splitter, biweekly_df, request): 32 | switchback_splitter = request.getfixturevalue(splitter) 33 | biweekly_df_long = pd.concat([biweekly_df for _ in range(3)]) 34 | switchback_splitter.cluster_cols = ["cluster", "time"] 35 | treatment_assignment = switchback_splitter.assign_treatment_df(biweekly_df_long) 36 | assert "time" in switchback_splitter.cluster_cols 37 | 38 | # Only 1 treatment per cluster 39 | assert ( 40 | treatment_assignment.groupby(["cluster", "time"])["treatment"].nunique() == 1 41 | ).all() 42 | 43 | 44 | @balanced_splitter_parametrize 45 | @add_cluster_cols_parametrize 46 | def test_clustered_switchback_splitter_balance( 47 | splitter, add_cluster_cols, biweekly_df, request 48 | ): 49 | balanced_splitter = request.getfixturevalue(splitter) 50 | 51 | if add_cluster_cols: 52 | balanced_splitter.cluster_cols += ["cluster"] 53 | treatment_assignment = balanced_splitter.assign_treatment_df(biweekly_df) 54 | assert "time" in balanced_splitter.cluster_cols 55 | # Assert that the treatment assignment is balanced 56 | assert (treatment_assignment.treatment.value_counts() == 70).all() 57 | 58 | 59 | @stratified_splitter_parametrize 60 | @add_cluster_cols_parametrize 61 | def test_stratified_splitter(splitter, add_cluster_cols, biweekly_df, request): 62 | stratified_switchback_splitter = request.getfixturevalue(splitter) 63 | 64 | if add_cluster_cols: 65 | stratified_switchback_splitter.cluster_cols += ["cluster"] 66 | stratified_switchback_splitter.strata_cols += ["cluster"] 67 | 68 | treatment_assignment = stratified_switchback_splitter.assign_treatment_df( 69 | biweekly_df 70 | ) 71 | assert "time" in stratified_switchback_splitter.cluster_cols 72 | assert (treatment_assignment.treatment.value_counts() == 70).all() 73 | # Per cluster, there are 2 treatments. Per day of week too 74 | for col in ["cluster", "day_of_week"]: 75 | assert (treatment_assignment.groupby([col])["treatment"].nunique() == 2).all() 76 | 77 | # Check stratification. Count day_of_week and treatment, we should always 78 | # have the same number of observations. 
Same for cluster 79 | for col in ["cluster", "day_of_week"]: 80 | assert treatment_assignment.groupby([col, "treatment"]).size().nunique() == 1 81 | 82 | 83 | def test_raise_time_col_not_in_df(): 84 | with pytest.raises( 85 | AssertionError, 86 | match="in switchback splitters, time_col must be in cluster_cols", 87 | ): 88 | sw = SwitchbackSplitter(time_col="time") 89 | perturbator = ConstantPerturbator() 90 | analysis = OLSAnalysis() 91 | _ = PowerAnalysis( 92 | splitter=sw, 93 | perturbator=perturbator, 94 | analysis=analysis, 95 | ) 96 | 97 | 98 | def test_raise_time_col_not_in_df_splitter(): 99 | with pytest.raises( 100 | AssertionError, 101 | match="in switchback splitters, time_col must be in cluster_cols", 102 | ): 103 | data = pd.DataFrame( 104 | { 105 | "activation_time": pd.date_range( 106 | start="2021-01-01", periods=10, freq="D" 107 | ), 108 | "city": ["A" for _ in range(10)], 109 | } 110 | ) 111 | time_col = "activation_time" 112 | switch_frequency = "6h" 113 | cluster_cols = ["city"] 114 | 115 | splitter = SwitchbackSplitter( 116 | time_col=time_col, 117 | cluster_cols=cluster_cols, 118 | switch_frequency=switch_frequency, 119 | ) 120 | _ = splitter.assign_treatment_df(data) 121 | -------------------------------------------------------------------------------- /tests/splitter/test_time_col.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from tests.splitter.conftest import switchback_splitter_parametrize 6 | 7 | 8 | @pytest.fixture 9 | def yearly_df(): 10 | return pd.DataFrame( 11 | { 12 | "time": pd.date_range("2021-01-01", "2021-12-31", freq="1D"), 13 | "y": np.random.randn(365), 14 | } 15 | ) 16 | 17 | 18 | @pytest.fixture 19 | def hourly_df(): 20 | return pd.DataFrame( 21 | { 22 | "time": pd.date_range("2021-01-01", "2021-01-10 23:00", freq="1H"), 23 | "y": np.random.randn(10 * 24), # 10 days, 24 hours per day, 1 hour per row 24 | } 25 | ) 26 | 27 | 28 | @pytest.fixture 29 | def minute_df(): 30 | return pd.DataFrame( 31 | { 32 | "time": pd.date_range("2021-01-01", "2021-01-10 23:59", freq="1T"), 33 | "y": np.random.randn( 34 | 10 * 24 * 60 35 | ), # 10 days, 24 hours per day, 1 minute per row 36 | } 37 | ) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "df,switchback_freq,n_splits", 42 | [ 43 | ("date_df", "1D", 10), 44 | ("date_df", "2D", 5), 45 | ("date_df", "4D", 3), 46 | ("hourly_df", "H", 240), 47 | ("hourly_df", "2H", 120), 48 | ("hourly_df", "4H", 60), 49 | ("hourly_df", "6H", 40), 50 | ("hourly_df", "12H", 20), 51 | ("minute_df", "min", 14400), 52 | ("minute_df", "2min", 7200), 53 | ("minute_df", "4min", 3600), 54 | ("minute_df", "30min", 480), 55 | ("minute_df", "60min", 240), 56 | ("yearly_df", "W", 53), 57 | ("yearly_df", "M", 12), 58 | ], 59 | ) 60 | @switchback_splitter_parametrize 61 | def test_date_col(splitter, df, switchback_freq, n_splits, request): 62 | date_df = request.getfixturevalue(df) 63 | splitter = request.getfixturevalue(splitter) 64 | splitter.switch_frequency = switchback_freq 65 | time_col = splitter._get_time_col_cluster(date_df) 66 | assert time_col.dtype == "datetime64[ns]" 67 | assert time_col.nunique() == n_splits 68 | 69 | if "W" not in switchback_freq and "M" not in switchback_freq: 70 | pd.testing.assert_series_equal( 71 | time_col, date_df["time"].dt.floor(switchback_freq) 72 | ) 73 | 74 | 75 | @pytest.mark.parametrize( 76 | "switchback_freq,day_of_week", 77 | [ 78 | ("W-MON", "Tuesday"), 79 | ("W-TUE", "Wednesday"), 80 |
("W-SUN", "Monday"), 81 | ("W", "Monday"), 82 | ], 83 | ) 84 | @switchback_splitter_parametrize 85 | def test_week_col_date(splitter, date_df, switchback_freq, day_of_week, request): 86 | splitter = request.getfixturevalue(splitter) 87 | splitter.switch_frequency = switchback_freq 88 | time_col = splitter._get_time_col_cluster(date_df) 89 | assert time_col.dtype == "datetime64[ns]" 90 | pd.testing.assert_series_equal( 91 | time_col, date_df["time"].dt.to_period(switchback_freq).dt.start_time 92 | ) 93 | assert time_col.nunique() == 2 94 | # Assert that the first day of the week is correct 95 | assert ( 96 | time_col.dt.day_name().iloc[0] == day_of_week 97 | ), f"{switchback_freq} failed, day_of_week is {time_col.dt.day_name().iloc[0]}" 98 | -------------------------------------------------------------------------------- /tests/splitter/test_washover.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import timedelta 3 | 4 | import pandas as pd 5 | import pytest 6 | 7 | from cluster_experiments import SwitchbackSplitter 8 | from cluster_experiments.washover import ConstantWashover, EmptyWashover 9 | 10 | 11 | @pytest.mark.parametrize("minutes, n_rows", [(30, 2), (10, 4), (15, 3)]) 12 | def test_constant_washover_base(minutes, n_rows, washover_base_df): 13 | out_df = ConstantWashover(washover_time_delta=timedelta(minutes=minutes)).washover( 14 | df=washover_base_df, 15 | truncated_time_col="time", 16 | cluster_cols=["city", "time"], 17 | treatment_col="treatment", 18 | original_time_col="original___time", 19 | ) 20 | 21 | assert len(out_df) == n_rows 22 | assert (out_df["original___time"].dt.minute > minutes).all() 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "minutes, n_rows, df", 27 | [ 28 | (30, 4, "washover_df_no_switch"), 29 | (30, 4 + 4, "washover_df_multi_city"), 30 | ], 31 | ) 32 | def test_constant_washover_no_switch(minutes, n_rows, df, request): 33 | washover_df = request.getfixturevalue(df) 34 | 35 | out_df = ConstantWashover(washover_time_delta=timedelta(minutes=minutes)).washover( 36 | df=washover_df, 37 | truncated_time_col="time", 38 | cluster_cols=["city", "time"], 39 | treatment_col="treatment", 40 | ) 41 | assert len(out_df) == n_rows 42 | if df == "washover_df_no_switch": 43 | # Check that, after 2022-01-01 02:00:00, we keep all the rows of the original 44 | # dataframe 45 | assert washover_df.query("time >= '2022-01-01 02:00:00'").equals( 46 | out_df.query("time >= '2022-01-01 02:00:00'") 47 | ) 48 | # Check that, after 2022-01-01 01:00:00, we don't have the same rows as the 49 | # original dataframe 50 | assert not washover_df.query("time >= '2022-01-01 01:00:00'").equals( 51 | out_df.query("time >= '2022-01-01 01:00:00'") 52 | ) 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "minutes, n,", 57 | [ 58 | (15, 10000), 59 | ], 60 | ) 61 | def test_constant_washover_split(minutes, n, washover_split_df): 62 | washover = ConstantWashover(washover_time_delta=timedelta(minutes=minutes)) 63 | 64 | splitter = SwitchbackSplitter( 65 | washover=washover, 66 | time_col="time", 67 | cluster_cols=["city", "time"], 68 | treatment_col="treatment", 69 | switch_frequency="30T", 70 | ) 71 | 72 | out_df = splitter.assign_treatment_df(df=washover_split_df) 73 | 74 | # Assert A and B in out_df 75 | assert set(out_df["treatment"].unique()) == {"A", "B"} 76 | 77 | # We need to have less than 10000 rows 78 | assert len(out_df) < n 79 | 80 | # We need to have more than 5000 rows (this is because ABB doesn't 
do washover on the second split) 81 | assert len(out_df) > n / 2 82 | 83 | 84 | @pytest.mark.parametrize( 85 | "minutes, n,", 86 | [ 87 | (15, 1000), 88 | ], 89 | ) 90 | def test_constant_washover_split_no_city(minutes, n, washover_split_no_city_df): 91 | washover = ConstantWashover(washover_time_delta=timedelta(minutes=minutes)) 92 | 93 | splitter = SwitchbackSplitter( 94 | washover=washover, 95 | time_col="time", 96 | cluster_cols=["time"], 97 | treatment_col="treatment", 98 | switch_frequency="30T", 99 | ) 100 | 101 | out_df = splitter.assign_treatment_df(df=washover_split_no_city_df) 102 | 103 | # Assert A and B in out_df 104 | assert set(out_df["treatment"].unique()) == {"A", "B"} 105 | 106 | # We need to have less than 1000 rows 107 | assert len(out_df) < n 108 | 109 | # We need to have more than 500 rows (this is because ABB doesn't do washover on the second split) 110 | assert len(out_df) > n / 2 111 | 112 | 113 | @pytest.mark.parametrize( 114 | "minutes, n", 115 | [ 116 | (15, 1000), 117 | ], 118 | ) 119 | def test_no_washover_split(minutes, n, washover_split_df): 120 | washover = EmptyWashover() 121 | 122 | splitter = SwitchbackSplitter( 123 | washover=washover, 124 | time_col="time", 125 | cluster_cols=["city", "time"], 126 | treatment_col="treatment", 127 | switch_frequency="30T", 128 | ) 129 | 130 | out_df = splitter.assign_treatment_df(df=washover_split_df) 131 | 132 | # Assert A and B in out_df 133 | assert set(out_df["treatment"].unique()) == {"A", "B"} 134 | 135 | # We need to have exactly 1000 rows 136 | assert len(out_df) == n 137 | 138 | 139 | @pytest.mark.parametrize( 140 | "minutes, n_rows, df", 141 | [ 142 | (30, 4, "washover_df_no_switch"), 143 | (30, 4 + 4, "washover_df_multi_city"), 144 | ], 145 | ) 146 | def test_constant_washover_no_switch_instantiated_int(minutes, n_rows, df, request): 147 | washover_df = request.getfixturevalue(df) 148 | 149 | @dataclass 150 | class Cfg: 151 | washover_time_delta: int 152 | 153 | cw = ConstantWashover.from_config(Cfg(minutes)) 154 | out_df = cw.washover( 155 | df=washover_df, 156 | truncated_time_col="time", 157 | cluster_cols=["city", "time"], 158 | treatment_col="treatment", 159 | ) 160 | assert len(out_df) == n_rows 161 | if df == "washover_df_no_switch": 162 | # Check that, after 2022-01-01 02:00:00, we keep all the rows of the original 163 | # dataframe 164 | assert washover_df.query("time >= '2022-01-01 02:00:00'").equals( 165 | out_df.query("time >= '2022-01-01 02:00:00'") 166 | ) 167 | # Check that, after 2022-01-01 01:00:00, we don't have the same rows as the 168 | # original dataframe 169 | assert not washover_df.query("time >= '2022-01-01 01:00:00'").equals( 170 | out_df.query("time >= '2022-01-01 01:00:00'") 171 | ) 172 | 173 | 174 | def test_truncated_time_not_in_cluster_cols(): 175 | msg = "is not in the cluster columns." 176 | df = pd.DataFrame(columns=["time_bin", "city", "time", "treatment"]) 177 | 178 | # Check that the truncated_time_col is also included in the cluster_cols. 179 | # An error is raised because "time_bin" is not in the cluster_cols 180 | with pytest.raises(ValueError, match=msg): 181 | 182 | ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover( 183 | df=df, 184 | truncated_time_col="time_bin", 185 | cluster_cols=["city"], 186 | original_time_col="time", 187 | treatment_col="treatment", 188 | ) 189 | 190 | 191 | def test_missing_original_time_col(): 192 | msg = "columns and/or not specified as an input."
193 | df = pd.DataFrame(columns=["time_bin", "city", "treatment"]) 194 | 195 | # Check that the original_time_col is specified as an input and is in the dataframe columns. 196 | # An error is raised because "time" is not specified as an input for the washover 197 | with pytest.raises(ValueError, match=msg): 198 | 199 | ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover( 200 | df=df, 201 | truncated_time_col="time_bin", 202 | cluster_cols=["city", "time_bin"], 203 | treatment_col="treatment", 204 | ) 205 | 206 | 207 | def test_cluster_cols_missing_in_df(): 208 | msg = "cluster is not in the dataframe columns." 209 | df = pd.DataFrame(columns=["time_bin", "time", "treatment"]) 210 | 211 | # Check that all the cluster_cols are in the dataframe columns 212 | # An error is raised because "city" is not in the dataframe columns 213 | with pytest.raises(ValueError, match=msg): 214 | 215 | ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover( 216 | df=df, 217 | truncated_time_col="time_bin", 218 | cluster_cols=["city", "time_bin"], 219 | original_time_col="time", 220 | treatment_col="treatment", 221 | ) 222 | -------------------------------------------------------------------------------- /tests/test_docs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mktestdocs import check_docstring, check_md_file, get_codeblock_members 3 | 4 | from cluster_experiments import ( 5 | AnalysisPlan, 6 | BalancedClusteredSplitter, 7 | BalancedSwitchbackSplitter, 8 | BetaRelativePerturbator, 9 | BetaRelativePositivePerturbator, 10 | BinaryPerturbator, 11 | ClusteredOLSAnalysis, 12 | ClusteredSplitter, 13 | ConstantPerturbator, 14 | ConstantWashover, 15 | DeltaMethodAnalysis, 16 | Dimension, 17 | EmptyRegressor, 18 | EmptyWashover, 19 | ExperimentAnalysis, 20 | FixedSizeClusteredSplitter, 21 | GeeExperimentAnalysis, 22 | HypothesisTest, 23 | Metric, 24 | MLMExperimentAnalysis, 25 | NonClusteredSplitter, 26 | NormalPerturbator, 27 | NormalPowerAnalysis, 28 | OLSAnalysis, 29 | PairedTTestClusteredAnalysis, 30 | Perturbator, 31 | PowerAnalysis, 32 | PowerConfig, 33 | RandomSplitter, 34 | RatioMetric, 35 | RelativeMixedPerturbator, 36 | RelativePositivePerturbator, 37 | RepeatedSampler, 38 | SegmentedBetaRelativePerturbator, 39 | SimpleMetric, 40 | StratifiedClusteredSplitter, 41 | StratifiedSwitchbackSplitter, 42 | SwitchbackSplitter, 43 | SyntheticControlAnalysis, 44 | TargetAggregation, 45 | TTestClusteredAnalysis, 46 | UniformPerturbator, 47 | Variant, 48 | ) 49 | from cluster_experiments.utils import _original_time_column 50 | 51 | all_objects = [ 52 | BalancedClusteredSplitter, 53 | BinaryPerturbator, 54 | ClusteredOLSAnalysis, 55 | ClusteredSplitter, 56 | DeltaMethodAnalysis, 57 | EmptyRegressor, 58 | ExperimentAnalysis, 59 | FixedSizeClusteredSplitter, 60 | GeeExperimentAnalysis, 61 | NonClusteredSplitter, 62 | OLSAnalysis, 63 | Perturbator, 64 | PowerAnalysis, 65 | NormalPowerAnalysis, 66 | PowerConfig, 67 | RandomSplitter, 68 | StratifiedClusteredSplitter, 69 | SyntheticControlAnalysis, 70 | TargetAggregation, 71 | TTestClusteredAnalysis, 72 | PairedTTestClusteredAnalysis, 73 | ConstantPerturbator, 74 | UniformPerturbator, 75 | _original_time_column, 76 | ConstantWashover, 77 | EmptyWashover, 78 | BalancedSwitchbackSplitter, 79 | StratifiedSwitchbackSplitter, 80 | SwitchbackSplitter, 81 | RepeatedSampler, 82 | MLMExperimentAnalysis, 83 | RelativePositivePerturbator, 84 | NormalPerturbator, 85 |
BetaRelativePositivePerturbator, 86 | BetaRelativePerturbator, 87 | SegmentedBetaRelativePerturbator, 88 | AnalysisPlan, 89 | Metric, 90 | SimpleMetric, 91 | RatioMetric, 92 | Dimension, 93 | Variant, 94 | HypothesisTest, 95 | RelativeMixedPerturbator, 96 | ] 97 | 98 | 99 | def flatten(items): 100 | """Flattens a list""" 101 | return [item for sublist in items for item in sublist] 102 | 103 | 104 | # This way we ensure that each item in `all_members` points to a method 105 | # that could have a docstring. 106 | all_members = flatten([get_codeblock_members(o) for o in all_objects]) 107 | 108 | 109 | @pytest.mark.parametrize("func", all_members, ids=lambda d: d.__qualname__) 110 | def test_function_docstrings(func): 111 | """Test the python example in each method in each object.""" 112 | check_docstring(obj=func) 113 | 114 | 115 | @pytest.mark.parametrize( 116 | "fpath", 117 | [ 118 | "README.md", 119 | ], 120 | ) 121 | def test_quickstart_docs_file(fpath): 122 | """Test the quickstart files.""" 123 | check_md_file(fpath, memory=True) 124 | -------------------------------------------------------------------------------- /tests/test_non_clustered.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import pytest 5 | from sklearn.model_selection import train_test_split 6 | 7 | from cluster_experiments.cupac import TargetAggregation 8 | from cluster_experiments.experiment_analysis import OLSAnalysis 9 | from cluster_experiments.perturbator import ConstantPerturbator 10 | from cluster_experiments.power_analysis import PowerAnalysis 11 | from cluster_experiments.random_splitter import NonClusteredSplitter 12 | from tests.utils import generate_non_clustered_data 13 | 14 | N = 10_000 15 | n_users = 1000 16 | random.seed(41) 17 | 18 | 19 | @pytest.fixture 20 | def df(): 21 | return generate_non_clustered_data(N, n_users) 22 | 23 | 24 | @pytest.fixture 25 | def df_feats(): 26 | df = generate_non_clustered_data(N, n_users) 27 | df["x1"] = np.random.normal(0, 1, N) 28 | df["x2"] = np.random.normal(0, 1, N) 29 | return df 30 | 31 | 32 | @pytest.fixture 33 | def cupac_power_analysis(): 34 | sw = NonClusteredSplitter() 35 | 36 | perturbator = ConstantPerturbator( 37 | average_effect=0.1, 38 | ) 39 | 40 | analysis = OLSAnalysis( 41 | covariates=["estimate_target"], 42 | ) 43 | 44 | target_agg = TargetAggregation( 45 | agg_col="user", 46 | ) 47 | 48 | return PowerAnalysis( 49 | perturbator=perturbator, 50 | splitter=sw, 51 | analysis=analysis, 52 | cupac_model=target_agg, 53 | n_simulations=3, 54 | ) 55 | 56 | 57 | @pytest.fixture 58 | def cupac_from_config(): 59 | return PowerAnalysis.from_dict( 60 | dict( 61 | analysis="ols_non_clustered", 62 | perturbator="constant", 63 | splitter="non_clustered", 64 | cupac_model="mean_cupac_model", 65 | average_effect=0.1, 66 | n_simulations=4, 67 | covariates=["estimate_target"], 68 | agg_col="user", 69 | ) 70 | ) 71 | 72 | 73 | def test_power_analysis(cupac_power_analysis, df): 74 | pre_df, df = train_test_split(df) 75 | power = cupac_power_analysis.power_analysis(df, pre_df) 76 | assert power >= 0 77 | assert power <= 1 78 | 79 | 80 | def test_power_analysis_config(cupac_from_config, df): 81 | pre_df, df = train_test_split(df) 82 | power = cupac_from_config.power_analysis(df, pre_df) 83 | assert power >= 0 84 | assert power <= 1 85 | 86 | 87 | def test_splitter(df): 88 | splitter = NonClusteredSplitter() 89 | # Check counts A and B are 50/50 90 | treatment_assignment = 
splitter.assign_treatment_df(df) 91 | n_a = treatment_assignment.treatment.value_counts()["A"] 92 | assert n_a >= -200 + len(treatment_assignment) / 2 93 | assert n_a <= 200 + len(treatment_assignment) / 2 94 | 95 | 96 | def test_splitter_weighted(df): 97 | splitter = NonClusteredSplitter(splitter_weights=[0.1, 0.9]) 98 | # Check counts A and B are 10/90 99 | treatment_assignment = splitter.assign_treatment_df(df) 100 | n_a = treatment_assignment.treatment.value_counts()["A"] 101 | assert n_a >= -100 + len(treatment_assignment) * 0.1 102 | assert n_a <= 100 + len(treatment_assignment) * 0.1 103 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from cluster_experiments.synthetic_control_utils import loss_w 5 | from cluster_experiments.utils import _get_mapping_key 6 | 7 | 8 | def test_get_mapping_key(): 9 | with pytest.raises(KeyError): 10 | mapping = {"a": 1, "b": 2} 11 | _get_mapping_key(mapping, "c") 12 | 13 | 14 | def test_loss_w(): 15 | W = np.array([2, 1]) # Weights vector 16 | X = np.array([[1, 2], [3, 4], [5, 6]]) # Input matrix 17 | y = np.array([6, 14, 22]) # Actual outputs 18 | 19 | # Calculate expected result 20 | # Predictions are calculated as follows: 21 | # [1*2 + 2*1, 3*2 + 4*1, 5*2 + 6*1] = [4, 10, 16] 22 | # RMSE is sqrt(mean([(6-4)^2, (14-10)^2, (22-16)^2])) 23 | expected_rmse = np.sqrt( 24 | np.mean((np.array([6, 14, 22]) - np.array([4, 10, 16])) ** 2) 25 | ) 26 | 27 | # Call the function 28 | calculated_rmse = loss_w(W, X, y) 29 | 30 | # Assert if the calculated RMSE matches the expected RMSE 31 | assert np.isclose( 32 | calculated_rmse, expected_rmse 33 | ), f"Expected RMSE: {expected_rmse}, but got: {calculated_rmse}" 34 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cluster_experiments.random_splitter import RandomSplitter 7 | 8 | TARGETS = { 9 | "binary": lambda x: np.random.choice([0, 1], size=x), 10 | "continuous": lambda x: np.random.normal(0, 1, x), 11 | } 12 | 13 | 14 | def combine_columns(df, cols_list): 15 | # combines columns for testing the stratified splitter in case of multiple strata or clusters 16 | if len(cols_list) > 1: 17 | return df[cols_list].agg("-".join, axis=1) 18 | else: 19 | return df[cols_list] 20 | 21 | 22 | def assert_balanced_strata( 23 | splitter: RandomSplitter, 24 | df: pd.DataFrame, 25 | strata_cols: List[str], 26 | cluster_cols: List[str], 27 | treatments: List[str], 28 | ): 29 | # asserts the balance of the stratified splitter for a given input data frame 30 | treatment_df = splitter.assign_treatment_df(df) 31 | 32 | treatment_df_unique = treatment_df[ 33 | strata_cols + cluster_cols + ["treatment"] 34 | ].drop_duplicates() 35 | 36 | treatment_df_unique["clusters_test"] = combine_columns( 37 | treatment_df_unique, cluster_cols 38 | ) 39 | treatment_df_unique["strata_test"] = combine_columns( 40 | treatment_df_unique, strata_cols 41 | ) 42 | 43 | for treatment in treatments: 44 | for stratus in treatment_df_unique["strata_test"].unique(): 45 | assert ( 46 | treatment_df_unique.query(f"strata_test == '{stratus}'")["treatment"] 47 | .value_counts(normalize=True)[treatment] 48 | .squeeze() 49 | ) == 1 / len(treatments) 50 | 
51 | 52 | def generate_random_data(clusters, dates, N, target="continuous"): 53 | # Generate random data with clusters and target 54 | users = [f"User {i}" for i in range(1000)] 55 | 56 | target_values = TARGETS[target](N) 57 | df = pd.DataFrame( 58 | { 59 | "cluster": np.random.choice(clusters, size=N), 60 | "target": target_values, 61 | "user": np.random.choice(users, size=N), 62 | "date": np.random.choice(dates, size=N), 63 | } 64 | ) 65 | 66 | return df 67 | 68 | 69 | def generate_non_clustered_data(N, n_users): 70 | users = [f"User {i}" for i in range(n_users)] 71 | df = pd.DataFrame( 72 | { 73 | "target": np.random.normal(0, 1, size=N), 74 | "user": np.random.choice(users, size=N), 75 | } 76 | ) 77 | return df 78 | 79 | 80 | def generate_clustered_data() -> pd.DataFrame: 81 | analysis_df = pd.DataFrame( 82 | { 83 | "country_code": ["ES"] * 4 + ["IT"] * 4 + ["PL"] * 4 + ["RO"] * 4, 84 | "city_code": ["BCN", "BCN", "MAD", "BCN"] 85 | + ["NAP"] * 4 86 | + ["WAW"] * 4 87 | + ["BUC"] * 4, 88 | "user_id": [1, 1, 2, 1, 3, 4, 5, 6, 7, 8, 8, 8, 9, 9, 9, 10], 89 | "date": ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"] * 4, 90 | "treatment": [ 91 | "A", 92 | "A", 93 | "B", 94 | "A", 95 | "B", 96 | "B", 97 | "A", 98 | "B", 99 | "B", 100 | "A", 101 | "A", 102 | "A", 103 | "B", 104 | "B", 105 | "B", 106 | "A", 107 | ], # Randomization is done at user level, so same user will always have same treatment 108 | "target": [0.01] * 15 + [0.1], 109 | } 110 | ) 111 | return analysis_df 112 | 113 | 114 | def generate_ratio_metric_data( 115 | dates, 116 | N, 117 | user_target_means: Optional[np.ndarray] = None, 118 | num_users=2000, 119 | treatment_effect=0.25, 120 | ) -> pd.DataFrame: 121 | 122 | if user_target_means is None: 123 | user_target_means = np.random.normal(0.3, 0.15, num_users) 124 | 125 | user_sessions = np.random.choice(num_users, N) 126 | user_dates = np.random.choice(dates, N) 127 | 128 | # assign treatment groups 129 | treatment = np.random.choice([0, 1], num_users) 130 | 131 | x1 = np.random.normal(0, 0.01, N) 132 | x2 = np.random.normal(0, 0.01, N) 133 | noise = np.random.normal(0, 0.01, N) 134 | 135 | # create target rate per session level 136 | target_percent_per_session = ( 137 | treatment_effect * treatment[user_sessions] 138 | + user_target_means[user_sessions] 139 | + x1 140 | + x2**2 141 | + noise 142 | ) 143 | 144 | # Remove <0 or >1 145 | target_percent_per_session[target_percent_per_session > 1] = 1 146 | target_percent_per_session[target_percent_per_session < 0] = 0 147 | 148 | targets_observed = np.random.binomial(1, target_percent_per_session) 149 | 150 | # rename treatment array 0-->A, 1-->B 151 | mapped_treatment = np.where(treatment == 0, "A", "B") 152 | 153 | return pd.DataFrame( 154 | { 155 | "user": user_sessions, 156 | "date": user_dates, 157 | "treatment": mapped_treatment[user_sessions], 158 | "target": targets_observed, 159 | "scale": np.ones_like(user_sessions), 160 | "x1": x1, 161 | "x2": x2, 162 | "user_target_means": user_target_means[user_sessions], 163 | } 164 | ) 165 | -------------------------------------------------------------------------------- /theme/flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/theme/flow.png -------------------------------------------------------------------------------- /theme/icon-cluster.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/theme/icon-cluster.png -------------------------------------------------------------------------------- /theme/icon-cluster.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | isolated_build = True 3 | envlist = py39,py310,py311,py312 4 | 5 | [testenv] 6 | deps = .[dev] 7 | 8 | commands = 9 | black --check cluster_experiments 10 | coverage run --source=cluster_experiments --branch -m pytest . 11 | coverage report -m --------------------------------------------------------------------------------