├── .github
└── workflows
│ ├── docs.yml
│ ├── periodic.yml
│ ├── release.yml
│ └── testing.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── cluster_experiments
├── __init__.py
├── cupac.py
├── experiment_analysis.py
├── inference
│ ├── __init__.py
│ ├── analysis_plan.py
│ ├── analysis_plan_config.py
│ ├── analysis_results.py
│ ├── dimension.py
│ ├── hypothesis_test.py
│ ├── metric.py
│ └── variant.py
├── perturbator.py
├── power_analysis.py
├── power_config.py
├── random_splitter.py
├── synthetic_control_utils.py
├── utils.py
└── washover.py
├── docs
├── aa_test.ipynb
├── analysis_with_different_hypotheses.ipynb
├── api
│ ├── analysis_plan.md
│ ├── analysis_results.md
│ ├── cupac_model.md
│ ├── dimension.md
│ ├── experiment_analysis.md
│ ├── hypothesis_test.md
│ ├── metric.md
│ ├── perturbator.md
│ ├── power_analysis.md
│ ├── power_config.md
│ ├── random_splitter.md
│ ├── variant.md
│ └── washover.md
├── create_custom_classes.ipynb
├── cupac_example.ipynb
├── delta_method.ipynb
├── e2e_mde.ipynb
├── experiment_analysis.ipynb
├── multivariate.ipynb
├── normal_power.ipynb
├── normal_power_lines.ipynb
├── paired_ttest.ipynb
├── plot_calendars.ipynb
├── plot_calendars_hours.ipynb
├── switchback.ipynb
├── synthetic_control.ipynb
└── washover_example.ipynb
├── examples
├── cupac_example_gbm.py
├── cupac_example_target_mean.py
├── long_example.py
├── parallel_example.py
├── short_example.py
├── short_example_config.py
├── short_example_dict.py
├── short_example_paired_ttest.py
└── short_example_synthetic_control.py
├── mkdocs.yml
├── pyproject.toml
├── ruff.toml
├── setup.py
├── tests
├── __init__.py
├── analysis
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_analysis.py
│ ├── test_formula.py
│ ├── test_hypothesis.py
│ ├── test_ols_analysis.py
│ └── test_synthetic_analysis.py
├── cupac
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_aggregator.py
│ └── test_cupac_handler.py
├── inference
│ ├── __init__.py
│ ├── test_analysis_plan.py
│ ├── test_analysis_plan_config.py
│ ├── test_analysis_results.py
│ ├── test_dimension.py
│ ├── test_hypothesis_test.py
│ ├── test_metric.py
│ └── test_variant.py
├── perturbator
│ ├── __init__.py
│ ├── conftest.py
│ └── test_perturbator.py
├── power_analysis
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_cupac_power.py
│ ├── test_multivariate.py
│ ├── test_normal_power_analysis.py
│ ├── test_parallel.py
│ ├── test_power_analysis.py
│ ├── test_power_analysis_with_pre_experiment_data.py
│ ├── test_power_raises.py
│ ├── test_seed.py
│ └── test_switchback_power.py
├── power_config
│ ├── __init__.py
│ ├── test_missing_arguments_error.py
│ ├── test_params_flow.py
│ └── test_warnings_superfluous_params.py
├── splitter
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_fixed_size_clusters_splitter.py
│ ├── test_splitter.py
│ ├── test_switchback_splitter.py
│ ├── test_time_col.py
│ └── test_washover.py
├── test_docs.py
├── test_non_clustered.py
├── test_utils.py
└── utils.py
├── theme
├── flow.png
├── icon-cluster.png
└── icon-cluster.svg
├── tox.ini
└── uv.lock
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Docs
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v2
13 | - uses: actions/setup-python@v1
14 | with:
15 | python-version: 3.9
16 | - run: cp README.md docs/index.md
17 | - run: cp -r theme docs/theme
18 | - run: |
19 | make install-dev-gh-action
20 | source .venv/bin/activate
21 | pip install lxml_html_clean
22 | mkdocs gh-deploy --force
23 |
--------------------------------------------------------------------------------
/.github/workflows/periodic.yml:
--------------------------------------------------------------------------------
1 | name: Release unit Tests
2 |
3 | on:
4 | schedule:
5 | - cron: "0 0 * * *"
6 |
7 | jobs:
8 | test-release-ubuntu:
9 | runs-on: ${{ matrix.os }}
10 | strategy:
11 | matrix:
12 | python-version: ['3.9', '3.10', '3.11', '3.12']
13 | os: [ubuntu-latest]
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v1
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install cluster-experiments
25 | pip freeze
26 | - name: Test with pytest
27 | run: |
28 | make install-test
29 | source .venv/bin/activate
30 | make test
31 |
32 | test-release-windows:
33 | runs-on: ${{ matrix.os }}
34 | strategy:
35 | matrix:
36 | python-version: ['3.9', '3.10', '3.11', '3.12']
37 | os: [windows-latest]
38 |
39 | steps:
40 | - uses: actions/checkout@v2
41 | - name: Set up Python ${{ matrix.python-version }}
42 | uses: actions/setup-python@v1
43 | with:
44 | python-version: ${{ matrix.python-version }}
45 | - name: Install dependencies
46 | run: |
47 | python -m pip install --upgrade pip
48 | pip install cluster-experiments
49 | pip freeze
50 | - name: Test with pytest
51 | run: |
52 | make install-test
53 |           .venv\Scripts\activate
54 | make test
55 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release to PyPI
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build-and-deploy:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Checkout code
14 | uses: actions/checkout@v2
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v2
18 | with:
19 | python-version: 3.9
20 |
21 | - name: Install dependencies
22 | run: make install-dev-gh-action
23 |
24 | - name: Prepare dist/
25 | run: make prep-dist
26 |
27 | - name: Publish
28 | uses: pypa/gh-action-pypi-publish@release/v1
29 | with:
30 | skip-existing: true
31 | user: __token__
32 | password: ${{ secrets.PYPI_API_TOKEN }}
33 | verify-metadata: false
34 |
--------------------------------------------------------------------------------
/.github/workflows/testing.yml:
--------------------------------------------------------------------------------
1 | name: Code Checks
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | test-coverage:
13 | runs-on: ubuntu-latest
14 | strategy:
15 | matrix:
16 | python-version: ['3.9', '3.10', '3.11', '3.12']
17 |
18 | steps:
19 | - uses: actions/checkout@v2
20 | - name: Set up Python ${{ matrix.python-version }}
21 | uses: actions/setup-python@v1
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Install Testing Dependencies
25 | run: make install-test
26 | - name: Automated Checking Mechanism
27 | run: |
28 | source .venv/bin/activate
29 | make check
30 | - name: Code coverage
31 | uses: codecov/codecov-action@v4
32 | with:
33 | fail_ci_if_error: true
34 | files: coverage.xml
35 | env:
36 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # VsCode
163 | .vscode/
164 |
165 | .DS_Store
166 |
167 | docs/index.md
168 | docs/theme/
169 | todos.txt
170 |
171 | # Local experiment outputs
172 | experiments/
173 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.3.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-added-large-files
11 | - repo: https://github.com/psf/black
12 | rev: 25.1.0
13 | hooks:
14 | - id: black
15 | language_version: python3
16 | - repo: https://github.com/charliermarsh/ruff-pre-commit
17 | rev: 'v0.0.261'
18 | hooks:
19 | - id: ruff
20 | args: [--fix, --exit-non-zero-on-fix]
21 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contributing
2 |
3 | uv is needed as the package manager. If you haven't installed it yet, run the installation command:
4 |
5 | ```bash
6 | curl -LsSf https://astral.sh/uv/install.sh | sh
7 | ```
8 |
9 | ### Project setup
10 |
11 | Clone repo and go to the project directory:
12 |
13 | ```bash
14 | git clone git@github.com:david26694/cluster-experiments.git
15 | cd cluster-experiments
16 | ```
17 |
18 | Create virtual environment and activate it:
19 |
20 | ```bash
21 | uv venv -p 3.10
22 | source .venv/bin/activate
23 | ```
24 |
25 | After creating the virtual environment, install the project dependencies:
26 |
27 | ```bash
28 | make install-dev
29 | ```
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 david26694
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: clean clean-test clean-pyc clean-build
2 |
3 | black:
4 | black cluster_experiments tests setup.py --check
5 |
6 | ruff:
7 | ruff check cluster_experiments tests setup.py
8 |
9 | test:
10 | pytest --cov=./cluster_experiments
11 |
12 | coverage_xml:
13 | coverage xml
14 |
15 | check: black ruff test coverage_xml
16 |
17 | install:
18 | python -m pip install uv
19 | uv sync
20 |
21 | install-dev:
22 | pip install --upgrade pip setuptools wheel
23 | python -m pip install uv
24 | uv sync --extra "dev"
25 | pre-commit install
26 |
27 | install-dev-gh-action:
28 | pip install --upgrade pip setuptools wheel
29 | python -m pip install uv
30 | uv sync --extra "dev"
31 |
32 | install-test:
33 | pip install --upgrade pip setuptools wheel
34 | python -m pip install uv
35 | uv sync --extra "test"
36 |
37 | docs-deploy:
38 | mkdocs gh-deploy
39 |
40 | docs-serve:
41 | cp README.md docs/index.md
42 | rm -rf docs/theme
43 | cp -r theme docs/theme/
44 | mkdocs serve
45 |
46 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
47 |
48 | clean-build: ## remove build artifacts
49 | rm -fr build/
50 | rm -fr dist/
51 | rm -fr .eggs/
52 | find . -name '*.egg-info' -exec rm -fr {} +
53 | find . -name '*.egg' -exec rm -f {} +
54 |
55 | clean-pyc: ## remove Python file artifacts
56 | find . -name '*.pyc' -exec rm -f {} +
57 | find . -name '*.pyo' -exec rm -f {} +
58 | find . -name '*~' -exec rm -f {} +
59 | find . -name '__pycache__' -exec rm -fr {} +
60 |
61 | clean-test: ## remove test and coverage artifacts
62 | rm -fr .tox/
63 | rm -f .coverage
64 | rm -fr htmlcov/
65 | rm -fr .pytest_cache
66 |
67 | prep-dist: clean
68 | uv build
69 |
70 | pypi: prep-dist
71 | twine upload --repository cluster-experiments dist/*
72 |
73 | pypi-gh-actions: prep-dist
74 | # todo: fix this
75 | twine upload --skip-existing dist/*
76 |
77 | # Report log
78 | report-log:
79 | pytest --report-log experiments/reportlog.jsonl
80 |
81 | duration-insights:
82 | pytest-duration-insights explore experiments/reportlog.jsonl
83 |
--------------------------------------------------------------------------------
/cluster_experiments/__init__.py:
--------------------------------------------------------------------------------
# Public API of the cluster_experiments package.
# The imports below re-export the user-facing classes from their submodules,
# and __all__ pins the names exposed via `from cluster_experiments import *`.

# Covariate / variance-reduction models (CUPAC).
from cluster_experiments.cupac import EmptyRegressor, TargetAggregation

# Hypothesis-test / effect-estimation methods.
from cluster_experiments.experiment_analysis import (
    ClusteredOLSAnalysis,
    DeltaMethodAnalysis,
    ExperimentAnalysis,
    GeeExperimentAnalysis,
    MLMExperimentAnalysis,
    OLSAnalysis,
    PairedTTestClusteredAnalysis,
    SyntheticControlAnalysis,
    TTestClusteredAnalysis,
)

# Scorecard-style inference building blocks.
from cluster_experiments.inference.analysis_plan import AnalysisPlan
from cluster_experiments.inference.dimension import Dimension
from cluster_experiments.inference.hypothesis_test import HypothesisTest
from cluster_experiments.inference.metric import Metric, RatioMetric, SimpleMetric
from cluster_experiments.inference.variant import Variant

# Effect simulators used by power analysis.
from cluster_experiments.perturbator import (
    BetaRelativePerturbator,
    BetaRelativePositivePerturbator,
    BinaryPerturbator,
    ConstantPerturbator,
    NormalPerturbator,
    Perturbator,
    RelativeMixedPerturbator,
    RelativePositivePerturbator,
    SegmentedBetaRelativePerturbator,
    UniformPerturbator,
)

# Power analysis entry points and their config object.
from cluster_experiments.power_analysis import NormalPowerAnalysis, PowerAnalysis
from cluster_experiments.power_config import PowerConfig

# Treatment-assignment strategies.
from cluster_experiments.random_splitter import (
    BalancedClusteredSplitter,
    BalancedSwitchbackSplitter,
    ClusteredSplitter,
    FixedSizeClusteredSplitter,
    NonClusteredSplitter,
    RandomSplitter,
    RepeatedSampler,
    StratifiedClusteredSplitter,
    StratifiedSwitchbackSplitter,
    SwitchbackSplitter,
)

# Washover handling for switchback experiments.
from cluster_experiments.washover import ConstantWashover, EmptyWashover, Washover

__all__ = [
    "ExperimentAnalysis",
    "GeeExperimentAnalysis",
    "DeltaMethodAnalysis",
    "OLSAnalysis",
    "BinaryPerturbator",
    "Perturbator",
    "ConstantPerturbator",
    "UniformPerturbator",
    "RelativePositivePerturbator",
    "NormalPerturbator",
    "BetaRelativePositivePerturbator",
    "BetaRelativePerturbator",
    "SegmentedBetaRelativePerturbator",
    "PowerAnalysis",
    "NormalPowerAnalysis",
    "PowerConfig",
    "EmptyRegressor",
    "TargetAggregation",
    "BalancedClusteredSplitter",
    "ClusteredSplitter",
    "RandomSplitter",
    "NonClusteredSplitter",
    "StratifiedClusteredSplitter",
    "SwitchbackSplitter",
    "BalancedSwitchbackSplitter",
    "StratifiedSwitchbackSplitter",
    "RepeatedSampler",
    "ClusteredOLSAnalysis",
    "TTestClusteredAnalysis",
    "PairedTTestClusteredAnalysis",
    "EmptyWashover",
    "ConstantWashover",
    "Washover",
    "MLMExperimentAnalysis",
    "SyntheticControlAnalysis",
    "FixedSizeClusteredSplitter",
    "AnalysisPlan",
    "Metric",
    "SimpleMetric",
    "RatioMetric",
    "Dimension",
    "Variant",
    "HypothesisTest",
    "RelativeMixedPerturbator",
]
92 |
--------------------------------------------------------------------------------
/cluster_experiments/cupac.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import pandas as pd
4 | from numpy.typing import ArrayLike
5 | from sklearn.base import BaseEstimator
6 | from sklearn.utils.validation import NotFittedError, check_is_fitted
7 |
8 |
class EmptyRegressor(BaseEstimator):
    """
    Placeholder regressor that performs no work.

    It exists so that PowerAnalysis and the other estimators can share a single
    code path whether or not a covariate model is in use.

    Any real regressor plugged in instead of this one is expected to provide:
    - fit method: uses pre-experiment data to fit some kind of model to be used
      as a covariate and reduce variance.
    - predict method: uses the fitted model to add the covariate on the
      experiment data.

    It can add aggregates of the target in older data as a covariate, or a
    model (cupac) to predict the target.
    """

    @classmethod
    def from_config(cls, config):
        # The config carries nothing this regressor needs; build a bare instance.
        return cls()
23 |
24 |
class TargetAggregation(BaseEstimator):
    """
    Covariate builder that attaches a smoothed per-group mean of the target,
    computed on pre-experiment data.

    Args:
        agg_col: Column to group by to aggregate target
        target_col: Column to aggregate
        smoothing_factor: Smoothing factor for the smoothed mean
    Usage:
        ```python
        import pandas as pd
        from cluster_experiments.cupac import TargetAggregation

        df = pd.DataFrame({"agg_col": ["a", "a", "b", "b", "c", "c"], "target_col": [1, 2, 3, 4, 5, 6]})
        new_df = pd.DataFrame({"agg_col": ["a", "a", "b", "b", "c", "c"]})
        target_agg = TargetAggregation("agg_col", "target_col")
        target_agg.fit(df.drop(columns="target_col"), df["target_col"])
        df_with_target_agg = target_agg.predict(new_df)
        print(df_with_target_agg)
        ```
    """

    def __init__(
        self,
        agg_col: str,
        target_col: str = "target",
        smoothing_factor: int = 20,
    ):
        self.agg_col = agg_col
        self.target_col = target_col
        self.smoothing_factor = smoothing_factor
        self.is_empty = False
        # Names of the derived columns produced during fit.
        self.mean_target_col = f"{self.target_col}_mean"
        self.smooth_mean_target_col = f"{self.target_col}_smooth_mean"
        self.pre_experiment_agg_df = pd.DataFrame()

    def _get_pre_experiment_mean(self, pre_experiment_df: pd.DataFrame) -> float:
        """Global mean of the target over the pre-experiment period."""
        return pre_experiment_df[self.target_col].mean()

    def fit(self, X: pd.DataFrame, y: pd.Series) -> "TargetAggregation":
        """Fits "target encoder" model to pre-experiment data"""
        pre_experiment_df = X.copy()
        pre_experiment_df[self.target_col] = y

        self.pre_experiment_mean = self._get_pre_experiment_mean(pre_experiment_df)

        # Per-group sums and counts of the target.
        grouped = (
            pre_experiment_df.assign(count=1)
            .groupby(self.agg_col, as_index=False)
            .agg({self.target_col: "sum", "count": "sum"})
        )
        # Raw per-group mean.
        grouped[self.mean_target_col] = grouped[self.target_col] / grouped["count"]
        # Smoothed mean: shrink small groups towards the global mean.
        grouped[self.smooth_mean_target_col] = (
            grouped[self.target_col] + self.smoothing_factor * self.pre_experiment_mean
        ) / (grouped["count"] + self.smoothing_factor)
        self.pre_experiment_agg_df = grouped.drop(columns=["count", self.target_col])
        return self

    def predict(self, X: pd.DataFrame) -> ArrayLike:
        """Adds average target of pre-experiment data to experiment data"""
        merged = X.merge(self.pre_experiment_agg_df, how="left", on=self.agg_col)
        # Groups unseen during fit fall back to the global pre-experiment mean.
        smoothed = merged[self.smooth_mean_target_col].fillna(self.pre_experiment_mean)
        return smoothed.values

    @classmethod
    def from_config(cls, config):
        """Creates TargetAggregation from PowerConfig"""
        return cls(
            agg_col=config.agg_col,
            target_col=config.target_col,
            smoothing_factor=config.smoothing_factor,
        )
106 |
107 |
class CupacHandler:
    """
    Coordinates all operations related to the cupac model.

    Its main entry point is add_covariates, which appends the output of the
    cupac model to the experiment dataframe; that column should then be used as
    a covariate in the regression method for the hypothesis test.
    """

    def __init__(
        self,
        cupac_model: Optional[BaseEstimator] = None,
        target_col: str = "target",
        features_cupac_model: Optional[List[str]] = None,
        cache_fit: bool = True,
    ):
        # Fall back to the no-op regressor when no model is supplied.
        self.cupac_model: BaseEstimator = cupac_model or EmptyRegressor()
        self.target_col = target_col
        self.cupac_outcome_name = f"estimate_{target_col}"
        self.features_cupac_model: List[str] = features_cupac_model or []
        # Cupac is active only when a real (non-empty) model was provided.
        self.is_cupac = not isinstance(self.cupac_model, EmptyRegressor)
        self.cache_fit = cache_fit

    def _prep_data_cupac(
        self, df: pd.DataFrame, pre_experiment_df: pd.DataFrame
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
        """Builds the prediction frame and the (X, y) pair used for fitting."""
        experiment_df = df.copy()
        historical_df = pre_experiment_df.copy()

        prediction_x = experiment_df.drop(columns=[self.target_col])
        # Split pre-experiment data into features and target.
        fit_x = historical_df.drop(columns=[self.target_col])
        fit_y = historical_df[self.target_col]

        # Restrict both frames to the configured cupac features, if any.
        if self.features_cupac_model:
            fit_x = fit_x[self.features_cupac_model]
            prediction_x = prediction_x[self.features_cupac_model]

        return prediction_x, fit_x, fit_y

    def add_covariates(
        self, df: pd.DataFrame, pre_experiment_df: Optional[pd.DataFrame] = None
    ) -> pd.DataFrame:
        """
        Train model to predict outcome variable (based on pre-experiment data)
        and add the prediction to the experiment dataframe. Only do this if
        we use cupac.
        Args:
            pre_experiment_df: Dataframe with pre-experiment data.
            df: Dataframe with outcome and treatment variables.
        """
        self.check_cupac_inputs(pre_experiment_df)

        # Nothing to add: either cupac is off or no pre-experiment data given.
        if not self.need_covariates(pre_experiment_df):
            return df

        df = df.copy()
        pre_experiment_df = pre_experiment_df.copy()
        df_predict, pre_experiment_x, pre_experiment_y = self._prep_data_cupac(
            df=df, pre_experiment_df=pre_experiment_df
        )

        # Fit (or reuse a cached fit of) the cupac model, then predict.
        self._fit_cupac_model(pre_experiment_x, pre_experiment_y)
        estimated_target = self._predict_cupac_model(df_predict)

        # Attach the prediction under the cupac outcome column name.
        df[self.cupac_outcome_name] = estimated_target
        return df

    def _fit_cupac_model(
        self, pre_experiment_x: pd.DataFrame, pre_experiment_y: pd.Series
    ):
        """Fits the cupac model.
        Caches the fitted model in the object, so we only fit it once.
        We can disable this by setting cache_fit to False.
        """
        if self.cache_fit:
            try:
                check_is_fitted(self.cupac_model)
                # Already fitted: reuse the cached fit.
                return
            except NotFittedError:
                pass
        self.cupac_model.fit(pre_experiment_x, pre_experiment_y)

    def _predict_cupac_model(self, df_predict: pd.DataFrame) -> ArrayLike:
        """Predicts the cupac model"""
        # Prefer the positive-class probability when the model is a classifier.
        predict_proba = getattr(self.cupac_model, "predict_proba", None)
        if predict_proba is not None:
            return predict_proba(df_predict)[:, 1]
        predict = getattr(self.cupac_model, "predict", None)
        if predict is not None:
            return predict(df_predict)
        raise ValueError("cupac_model should have predict or predict_proba method.")

    def need_covariates(self, pre_experiment_df: Optional[pd.DataFrame] = None) -> bool:
        """True when cupac is active and pre-experiment data is available."""
        return pre_experiment_df is not None and self.is_cupac

    def check_cupac_inputs(self, pre_experiment_df: Optional[pd.DataFrame] = None):
        """Rejects inconsistent combinations of cupac model and input data."""
        if self.is_cupac and pre_experiment_df is None:
            raise ValueError("If cupac is used, pre_experiment_df should be provided.")

        if not self.is_cupac and pre_experiment_df is not None:
            raise ValueError(
                "If cupac is not used, pre_experiment_df should not be provided - "
                "remove pre_experiment_df argument or set cupac_model to not None."
            )
217 |
--------------------------------------------------------------------------------
/cluster_experiments/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/cluster_experiments/inference/__init__.py
--------------------------------------------------------------------------------
/cluster_experiments/inference/analysis_plan_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Dict, List, Optional, Union
3 |
4 |
@dataclass(eq=True)
class AnalysisPlanMetricsConfig:
    """
    Plain-dict configuration for an analysis plan defined at the metric level.

    Attributes:
        metrics: list of dicts, one per metric.
        variants: list of dicts, one per variant.
        analysis_type: name of the analysis method to run.
        variant_col: column holding the variant assignment.
        alpha: significance level of the tests.
        dimensions: list of dicts describing slicing dimensions.
        analysis_config: extra keyword configuration for the analysis.
        custom_analysis_type_mapper: optional mapping for custom analysis types.
    """

    metrics: List[Dict[str, str]]
    variants: List[Dict[str, Union[str, bool]]]
    analysis_type: str
    variant_col: str = "experiment_group"
    alpha: float = 0.05
    dimensions: List[Dict[str, Union[str, List]]] = field(default_factory=list)
    analysis_config: Dict = field(default_factory=dict)
    custom_analysis_type_mapper: Optional[Dict] = None


@dataclass(eq=True)
class AnalysisPlanConfig:
    """
    Plain-dict configuration for an analysis plan defined at the test level.

    Attributes:
        tests: list of dicts, one per hypothesis test.
        variants: list of dicts, one per variant.
        variant_col: column holding the variant assignment.
        alpha: significance level of the tests.
    """

    tests: List[Dict[str, Union[List[Dict], Dict, str]]]
    variants: List[Dict[str, Union[str, bool]]]
    variant_col: str = "experiment_group"
    alpha: float = 0.05
23 |
--------------------------------------------------------------------------------
/cluster_experiments/inference/analysis_results.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass, field
2 | from typing import List
3 |
4 | import pandas as pd
5 |
6 |
@dataclass
class AnalysisPlanResults:
    """
    A dataclass used to represent the results of the experiment analysis.

    Every attribute is a list, with one entry per executed test, so that two
    result objects can be concatenated with ``+`` and the whole collection can
    be turned into a tabular form with ``to_dataframe``.

    Attributes
    ----------
    metric_alias : List[str]
        The alias of the metric used in the test
    control_variant_name : List[str]
        The name of the control variant
    treatment_variant_name : List[str]
        The name of the treatment variant
    control_variant_mean : List[float]
        The mean value of the control variant
    treatment_variant_mean : List[float]
        The mean value of the treatment variant
    analysis_type : List[str]
        The type of analysis performed
    ate : List[float]
        The average treatment effect
    ate_ci_lower : List[float]
        The lower bound of the confidence interval for the ATE
    ate_ci_upper : List[float]
        The upper bound of the confidence interval for the ATE
    p_value : List[float]
        The p-value of the test
    std_error : List[float]
        The standard error of the test
    dimension_name : List[str]
        The name of the dimension
    dimension_value : List[str]
        The value of the dimension
    alpha: List[float]
        The significance level of the test
    """

    metric_alias: List[str] = field(default_factory=list)
    control_variant_name: List[str] = field(default_factory=list)
    treatment_variant_name: List[str] = field(default_factory=list)
    control_variant_mean: List[float] = field(default_factory=list)
    treatment_variant_mean: List[float] = field(default_factory=list)
    analysis_type: List[str] = field(default_factory=list)
    ate: List[float] = field(default_factory=list)
    ate_ci_lower: List[float] = field(default_factory=list)
    ate_ci_upper: List[float] = field(default_factory=list)
    p_value: List[float] = field(default_factory=list)
    std_error: List[float] = field(default_factory=list)
    dimension_name: List[str] = field(default_factory=list)
    dimension_value: List[str] = field(default_factory=list)
    alpha: List[float] = field(default_factory=list)

    def __add__(self, other):
        """Concatenate two result sets field by field into a new object."""
        if not isinstance(other, AnalysisPlanResults):
            return NotImplemented

        # Field names are identical on both sides, so concatenate per key.
        merged = asdict(self)
        for name, values in asdict(other).items():
            merged[name] = merged[name] + values
        return AnalysisPlanResults(**merged)

    def to_dataframe(self):
        """Return the results as a pandas DataFrame, one row per test."""
        return pd.DataFrame(asdict(self))
84 |
--------------------------------------------------------------------------------
/cluster_experiments/inference/dimension.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 |
@dataclass
class Dimension:
    """
    Represents a slicing dimension: a named attribute together with the
    values it can take.

    Attributes
    ----------
    name : str
        The name of the dimension
    values : List[str]
        The possible values of the dimension
    """

    name: str
    values: List[str]

    def __post_init__(self):
        """Run input validation right after dataclass initialization."""
        self._validate_inputs()

    def _validate_inputs(self):
        """
        Check the types of the dimension's attributes.

        Raises
        ------
        TypeError
            If ``name`` is not a string or ``values`` is not a list of strings.
        """
        if not isinstance(self.name, str):
            raise TypeError("Dimension name must be a string")
        values_ok = isinstance(self.values, list) and all(
            isinstance(item, str) for item in self.values
        )
        if not values_ok:
            raise TypeError("Dimension values must be a list of strings")

    def iterate_dimension_values(self):
        """
        Yield each distinct dimension value once, preserving first-seen order.

        Yields
        ------
        Any
            A unique value from the dimension.
        """
        # dict.fromkeys deduplicates while keeping insertion order.
        yield from dict.fromkeys(self.values)

    @classmethod
    def from_metrics_config(cls, config: dict) -> "Dimension":
        """
        Build a Dimension from a configuration dictionary.

        Parameters
        ----------
        config : dict
            Must contain the keys ``"name"`` and ``"values"``.

        Returns
        -------
        Dimension
            The configured Dimension object.
        """
        return cls(name=config["name"], values=config["values"])
74 |
75 |
@dataclass
class DefaultDimension(Dimension):
    """
    Dimension representing the absence of slicing: a single pseudo-dimension
    whose only value is "total".
    """

    def __init__(self):
        # Delegate to the parent initializer with fixed sentinel values.
        super().__init__(name="__total_dimension", values=["total"])
84 |
--------------------------------------------------------------------------------
/cluster_experiments/inference/metric.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional
3 |
4 | import pandas as pd
5 |
6 |
class Metric(ABC):
    """
    Abstract base class for metrics identified by an alias.

    Attributes
    ----------
    alias : str
        A string representing the alias of the metric
    """

    def __init__(self, alias: str):
        """
        Parameters
        ----------
        alias : str
            The alias of the metric
        """
        self.alias = alias
        self._validate_alias()

    def _validate_alias(self):
        """
        Check that the alias is a string.

        Raises
        ------
        TypeError
            If the alias is not a string
        """
        if not isinstance(self.alias, str):
            raise TypeError("Metric alias must be a string")

    @property
    @abstractmethod
    def target_column(self) -> str:
        """
        Column to feed as the target to the experiment analysis class.

        Returns
        -------
        str
            The target column name
        """

    @property
    def scale_column(self) -> Optional[str]:
        """
        Column to feed as the scale to the experiment analysis class.

        Returns
        -------
        Optional[str]
            The scale column name, or None when the metric has no scale;
            subclasses may override this default.
        """
        return None

    @abstractmethod
    def get_mean(self, df: pd.DataFrame) -> float:
        """
        Compute the mean value of the metric over the given dataframe.

        Returns
        -------
        float
            The mean value of the metric
        """

    @classmethod
    def from_metrics_config(cls, config: dict) -> "Metric":
        """
        Build the appropriate Metric subclass from a configuration dictionary.

        Parameters
        ----------
        config : dict
            A dictionary containing the configuration of the metric

        Returns
        -------
        Metric
            A RatioMetric when the config contains "numerator_name",
            a SimpleMetric otherwise.
        """
        if "numerator_name" in config:
            return RatioMetric.from_metrics_config(config)
        return SimpleMetric.from_metrics_config(config)
94 |
95 |
class SimpleMetric(Metric):
    """
    Metric defined at the same granularity as the data used for the analysis.

    Example
    ----------
    In a clustered experiment the participants were randomised based on their country of residence.
    The metric of interest is the salary of each participant. If the dataset fed into the analysis is at participant-level,
    then a SimpleMetric must be used. However, if the dataset fed into the analysis is at country-level, then a RatioMetric must be used.

    Attributes
    ----------
    alias : str
        A string representing the alias of the metric
    name : str
        A string representing the name of the metric
    """

    def __init__(self, alias: str, name: str):
        """
        Parameters
        ----------
        alias : str
            The alias of the metric
        name : str
            The name of the metric
        """
        super().__init__(alias)
        self.name = name
        self._validate_name()

    def _validate_name(self):
        """
        Check that the metric name is a string.

        Raises
        ------
        TypeError
            If the name is not a string
        """
        if not isinstance(self.name, str):
            raise TypeError("SimpleMetric name must be a string")

    @property
    def target_column(self) -> str:
        """
        Returns
        -------
        str
            The metric name, used directly as the target column.
        """
        return self.name

    def get_mean(self, df: pd.DataFrame) -> float:
        """
        Compute the mean of the metric column in the given dataframe.

        Returns
        -------
        float
            The mean value of the metric
        """
        metric_column = df[self.name]
        return metric_column.mean()

    @classmethod
    def from_metrics_config(cls, config: dict) -> "Metric":
        """
        Build a SimpleMetric from a configuration dictionary.

        Parameters
        ----------
        config : dict
            Must contain the keys "alias" and "name".

        Returns
        -------
        SimpleMetric
            A SimpleMetric instance
        """
        return cls(alias=config["alias"], name=config["name"])
179 |
180 |
class RatioMetric(Metric):
    """
    Metric defined at a lower granularity than the analysis data, expressed
    as a numerator column divided by a denominator column.

    Example
    ----------
    In a clustered experiment the participants were randomised based on their country of residence.
    The metric of interest is the salary of each participant. If the dataset fed into the analysis is at country-level,
    then a RatioMetric must be used: the numerator would be the sum of all salaries in the country,
    the denominator would be the number of participants in the country.

    Attributes
    ----------
    alias : str
        A string representing the alias of the metric
    numerator_name : str
        A string representing the numerator name of the metric
    denominator_name : str
        A string representing the denominator name of the metric
    """

    def __init__(self, alias: str, numerator_name: str, denominator_name: str):
        """
        Parameters
        ----------
        alias : str
            The alias of the metric
        numerator_name : str
            The numerator name of the metric
        denominator_name : str
            The denominator name of the metric
        """
        super().__init__(alias)
        self.numerator_name = numerator_name
        self.denominator_name = denominator_name
        self._validate_names()

    def _validate_names(self):
        """
        Check that both column names are strings.

        Raises
        ------
        TypeError
            If the numerator or denominator names are not strings
        """
        names_ok = isinstance(self.numerator_name, str) and isinstance(
            self.denominator_name, str
        )
        if not names_ok:
            raise TypeError("RatioMetric names must be strings")

    @property
    def target_column(self) -> str:
        """
        Returns
        -------
        str
            The numerator column, used as the analysis target.
        """
        return self.numerator_name

    @property
    def scale_column(self) -> str:
        """
        Returns
        -------
        str
            The denominator column, used as the analysis scale.
        """
        return self.denominator_name

    def get_mean(self, df: pd.DataFrame) -> float:
        """
        Compute the ratio of column means; for equal-length columns this
        equals sum(numerator) / sum(denominator).

        Returns
        -------
        float
            The mean value of the metric
        """
        numerator_mean = df[self.numerator_name].mean()
        denominator_mean = df[self.denominator_name].mean()
        return numerator_mean / denominator_mean

    @classmethod
    def from_metrics_config(cls, config: dict) -> "Metric":
        """
        Build a RatioMetric from a configuration dictionary.

        Parameters
        ----------
        config : dict
            Must contain the keys "alias", "numerator_name" and
            "denominator_name".

        Returns
        -------
        RatioMetric
            A RatioMetric instance
        """
        return cls(
            alias=config["alias"],
            numerator_name=config["numerator_name"],
            denominator_name=config["denominator_name"],
        )
288 |
--------------------------------------------------------------------------------
/cluster_experiments/inference/variant.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
@dataclass
class Variant:
    """
    Represents an experiment variant: a name plus a flag marking whether it
    is the control group.

    Attributes
    ----------
    name : str
        The name of the variant
    is_control : bool
        True when this variant is the control variant
    """

    name: str
    is_control: bool

    def __post_init__(self):
        """Run input validation right after dataclass initialization."""
        self._validate_inputs()

    def _validate_inputs(self):
        """
        Check the types of the variant's attributes.

        Raises
        ------
        TypeError
            If the name is not a string or if is_control is not a boolean.
        """
        if not isinstance(self.name, str):
            raise TypeError("Variant name must be a string")
        if not isinstance(self.is_control, bool):
            raise TypeError("Variant is_control must be a boolean")

    @classmethod
    def from_metrics_config(cls, config: dict) -> "Variant":
        """
        Build a Variant from a configuration dictionary.

        Parameters
        ----------
        config : dict
            Must contain the keys "name" and "is_control".

        Returns
        -------
        Variant
            A Variant object
        """
        return cls(name=config["name"], is_control=config["is_control"])
56 |
--------------------------------------------------------------------------------
/cluster_experiments/synthetic_control_utils.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import product
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from scipy.optimize import fmin_slsqp
7 |
8 |
def loss_w(W: np.ndarray, X: np.ndarray, y: np.ndarray) -> float:
    """
    Root mean square error (RMSE) of the linear prediction ``X @ W``
    against the observed outcome ``y``.

    Used as the objective function when fitting synthetic-control weights.

    Parameters:
        W (numpy.ndarray): The weights vector used for predictions.
        X (numpy.ndarray): The input data matrix.
        y (numpy.ndarray): The actual output vector.

    Returns:
        float: The calculated RMSE.
    """
    prediction_error = y - X @ W
    mse = np.mean(prediction_error**2)
    return np.sqrt(mse)
23 |
24 |
def get_w(X, y, verbose=False) -> np.ndarray:
    """
    Fit synthetic-control weights: one weight per unit (column of X), each
    bounded to [0, 1], with the equality constraint that the weights sum to 1.
    """
    n_units = X.shape[1]
    # Start from a uniform allocation across units.
    initial_guess = np.full(n_units, 1 / n_units)

    return fmin_slsqp(
        partial(loss_w, X=X, y=y),
        initial_guess,
        f_eqcons=lambda w: np.sum(w) - 1,
        bounds=[(0.0, 1.0)] * n_units,
        disp=verbose,
    )
40 |
41 |
def generate_synthetic_control_data(N, start_date, end_date):
    """Build a toy synthetic-control dataset: one row per (user, date) pair
    with a normally distributed target metric (mean 100, std 10)."""
    # Daily date range spanning the requested window (inclusive).
    dates = pd.date_range(start_date, end_date, freq="d")
    users = [f"User {i}" for i in range(N)]

    # Cartesian product: every user is observed on every date.
    rows = list(product(users, dates))
    df = pd.DataFrame(rows, columns=["user", "date"])
    df["target"] = np.random.normal(100, 10, size=len(rows))
    df["date"] = pd.to_datetime(df["date"])

    return df
60 |
--------------------------------------------------------------------------------
/cluster_experiments/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Dict
3 |
4 |
5 | def _original_time_column(time_col: str) -> str:
6 | """
7 | Usage:
8 | ```python
9 | from cluster_experiments.utils import _original_time_column
10 |
11 | assert _original_time_column("hola") == "original___hola"
12 | ```
13 | """
14 | return f"original___{time_col}"
15 |
16 |
17 | def _get_mapping_key(mapping, key):
18 | try:
19 | return mapping[key]
20 | except KeyError:
21 | raise KeyError(
22 | f"Could not find {key = } in mapping. All options are the following: {list(mapping.keys())}"
23 | )
24 |
25 |
class HypothesisEntries(Enum):
    """Closed set of supported alternative-hypothesis directions."""

    TWO_SIDED = "two-sided"
    LESS = "less"
    GREATER = "greater"
30 |
31 |
class ModelResults:
    """
    Lightweight container bundling fitted model parameters with their
    p-values.

    Attributes
    ----------
    params : Dict
        Mapping of coefficient name to estimated value.
    pvalues : Dict
        Mapping of coefficient name to p-value.
    """

    def __init__(self, params: Dict, pvalues: Dict):
        self.pvalues = pvalues
        self.params = params
36 |
--------------------------------------------------------------------------------
/docs/api/analysis_plan.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.inference.analysis_plan import *`
2 |
3 | ::: cluster_experiments.inference.analysis_plan
4 |
--------------------------------------------------------------------------------
/docs/api/analysis_results.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.inference.analysis_results import *`
2 |
3 | ::: cluster_experiments.inference.analysis_results
4 |
--------------------------------------------------------------------------------
/docs/api/cupac_model.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.cupac import *`
2 |
3 | ::: cluster_experiments.cupac
4 |
--------------------------------------------------------------------------------
/docs/api/dimension.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.inference.dimension import *`
2 |
3 | ::: cluster_experiments.inference.dimension
4 |
--------------------------------------------------------------------------------
/docs/api/experiment_analysis.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.experiment_analysis import *`
2 |
3 | ::: cluster_experiments.experiment_analysis
4 |
--------------------------------------------------------------------------------
/docs/api/hypothesis_test.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.inference.hypothesis_test import *`
2 |
3 | ::: cluster_experiments.inference.hypothesis_test
4 |
--------------------------------------------------------------------------------
/docs/api/metric.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.inference.metric import *`
2 |
3 | ::: cluster_experiments.inference.metric
4 |
--------------------------------------------------------------------------------
/docs/api/perturbator.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.perturbator import *`
2 |
3 | ::: cluster_experiments.perturbator
4 |
--------------------------------------------------------------------------------
/docs/api/power_analysis.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.power_analysis import *`
2 |
3 | ::: cluster_experiments.power_analysis
4 |
--------------------------------------------------------------------------------
/docs/api/power_config.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.power_config import *`
2 |
3 | ::: cluster_experiments.power_config
4 |
--------------------------------------------------------------------------------
/docs/api/random_splitter.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.random_splitter import *`
2 |
3 | ::: cluster_experiments.random_splitter
4 |
--------------------------------------------------------------------------------
/docs/api/variant.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.inference.variant import *`
2 |
3 | ::: cluster_experiments.inference.variant
4 |
--------------------------------------------------------------------------------
/docs/api/washover.md:
--------------------------------------------------------------------------------
1 | # `from cluster_experiments.washover import *`
2 |
3 | ::: cluster_experiments.washover
4 |
--------------------------------------------------------------------------------
/docs/create_custom_classes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "Examples on how to create:\n",
9 | "* a custom perturbator\n",
10 | "* a custom splitter\n",
11 | "* a custom hypothesis test\n",
12 | "\n",
13 |     "The names of your custom classes don't need to be CustomX, they are completely free. The only requirement is that they inherit from the base class. For example, if you want to create a custom perturbator, you need to inherit from the Perturbator base class. The same applies to the other classes."
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 1,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "from cluster_experiments import ExperimentAnalysis\n",
23 | "import pandas as pd\n",
24 | "from scipy.stats import ttest_ind\n",
25 | "\n",
26 | "class CustomExperimentAnalysis(ExperimentAnalysis):\n",
27 | " def analysis_pvalue(self, df: pd.DataFrame, verbose: bool = True) -> float:\n",
28 | " treatment_data = df.query(f\"{self.treatment_col} == 1\")[self.target_col]\n",
29 | " control_data = df.query(f\"{self.treatment_col} == 0\")[self.target_col]\n",
30 | " t_test_results = ttest_ind(treatment_data, control_data, equal_var=False)\n",
31 | " return t_test_results.pvalue"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 2,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "from cluster_experiments import RandomSplitter\n",
41 | "import numpy as np\n",
42 | "\n",
43 | "class CustomRandomSplitter(RandomSplitter):\n",
44 | " def assign_treatment_df(self, df: pd.DataFrame) -> pd.DataFrame:\n",
45 | " df = df.copy()\n",
46 | " # Power users get treatment with 90% probability\n",
47 | " df_power_users = df.query(\"power_user\")\n",
48 | " df_power_users[self.treatment_col] = np.random.choice(\n",
49 | " [\"A\", \"B\"], size=len(df_power_users), p=[0.1, 0.9]\n",
50 | " )\n",
51 | " # Non-power users get treatment with 10% probability\n",
52 | " df_non_power_users = df.query(\"not power_user\")\n",
53 | " df_non_power_users[self.treatment_col] = np.random.choice(\n",
54 | " [\"A\", \"B\"], size=len(df_non_power_users), p=[0.9, 0.1]\n",
55 | " )\n",
56 | " return pd.concat([df_power_users, df_non_power_users])"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 3,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from cluster_experiments import Perturbator\n",
66 | "import pandas as pd\n",
67 | "\n",
68 | "class CustomPerturbator(Perturbator):\n",
69 | " def perturbate(self, df: pd.DataFrame, average_effect: float) -> pd.DataFrame:\n",
70 | " df = df.copy().reset_index(drop=True)\n",
71 | " n = (df[self.treatment_col] == self.treatment).sum()\n",
72 | " df.loc[\n",
73 | " df[self.treatment_col] == self.treatment, self.target_col\n",
74 | " ] += np.random.normal(average_effect, 1, size=n)\n",
75 | " return df"
76 | ]
77 | }
78 | ],
79 | "metadata": {
80 | "kernelspec": {
81 | "display_name": "Python 3.8.6 ('venv': venv)",
82 | "language": "python",
83 | "name": "python3"
84 | },
85 | "language_info": {
86 | "codemirror_mode": {
87 | "name": "ipython",
88 | "version": 3
89 | },
90 | "file_extension": ".py",
91 | "mimetype": "text/x-python",
92 | "name": "python",
93 | "nbconvert_exporter": "python",
94 | "pygments_lexer": "ipython3",
95 | "version": "3.8.6 (default, Jan 17 2022, 12:11:54) \n[Clang 12.0.5 (clang-1205.0.22.11)]"
96 | },
97 | "orig_nbformat": 4,
98 | "vscode": {
99 | "interpreter": {
100 | "hash": "29c447d2129f0d56b23b7ba3abc571cfa9d42454e0e2bba301a881797dc4c0e2"
101 | }
102 | }
103 | },
104 | "nbformat": 4,
105 | "nbformat_minor": 2
106 | }
107 |
--------------------------------------------------------------------------------
/docs/multivariate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "This notebook shows how to use the multivariate module. The idea is to use several treatments in the splitter and only one of them is used to run the hypothesis test."
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import numpy as np\n",
18 | "import pandas as pd\n",
19 | "from cluster_experiments import PowerAnalysis\n"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "# Create fake data\n",
29 | "N = 1_000\n",
30 | "df = pd.DataFrame(\n",
31 | " {\n",
32 | " \"target\": np.random.normal(0, 1, size=N),\n",
33 | " }\n",
34 | ")"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 3,
40 | "metadata": {},
41 | "outputs": [
42 | {
43 | "data": {
44 | "text/plain": [
45 | "0.18"
46 | ]
47 | },
48 | "execution_count": 3,
49 | "metadata": {},
50 | "output_type": "execute_result"
51 | }
52 | ],
53 | "source": [
54 | "# Run power analysis using 3 variants\n",
55 | "config_abc = {\n",
56 | " \"analysis\": \"ols_non_clustered\",\n",
57 | " \"perturbator\": \"constant\",\n",
58 | " \"splitter\": \"non_clustered\",\n",
59 | " \"treatments\": [\"A\", \"B\", \"C\"],\n",
60 | " \"control\": \"A\",\n",
61 | " \"treatment\": \"B\",\n",
62 | " \"n_simulations\": 50,\n",
63 | "}\n",
64 | "\n",
65 | "power_abc = PowerAnalysis.from_dict(config_abc)\n",
66 | "power_abc.power_analysis(df, average_effect=0.1)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 4,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "data": {
76 | "text/plain": [
77 | "0.28"
78 | ]
79 | },
80 | "execution_count": 4,
81 | "metadata": {},
82 | "output_type": "execute_result"
83 | }
84 | ],
85 | "source": [
86 | "# Run power analysis using 2 variants\n",
87 | "config_ab = {\n",
88 | " \"analysis\": \"ols_non_clustered\",\n",
89 | " \"perturbator\": \"constant\",\n",
90 | " \"splitter\": \"non_clustered\",\n",
91 | " \"treatments\": [\"A\", \"B\"],\n",
92 | " \"control\": \"A\",\n",
93 | " \"treatment\": \"B\",\n",
94 | " \"n_simulations\": 50,\n",
95 | "}\n",
96 | "power_ab = PowerAnalysis.from_dict(config_ab)\n",
97 | "power_ab.power_analysis(df, average_effect=0.1)"
98 | ]
99 | },
100 | {
101 | "attachments": {},
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "The power of the AB test is higher than the ABC test, which makes sense."
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": []
112 | }
113 | ],
114 | "metadata": {
115 | "kernelspec": {
116 | "display_name": "venv",
117 | "language": "python",
118 | "name": "python3"
119 | },
120 | "language_info": {
121 | "codemirror_mode": {
122 | "name": "ipython",
123 | "version": 3
124 | },
125 | "file_extension": ".py",
126 | "mimetype": "text/x-python",
127 | "name": "python",
128 | "nbconvert_exporter": "python",
129 | "pygments_lexer": "ipython3",
130 | "version": "3.8.6"
131 | },
132 | "orig_nbformat": 4,
133 | "vscode": {
134 | "interpreter": {
135 | "hash": "29c447d2129f0d56b23b7ba3abc571cfa9d42454e0e2bba301a881797dc4c0e2"
136 | }
137 | }
138 | },
139 | "nbformat": 4,
140 | "nbformat_minor": 2
141 | }
142 |
--------------------------------------------------------------------------------
/docs/paired_ttest.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 |     "This notebook shows how the PairedTTestClusteredAnalysis class performs the paired t-test. It's important to get a grasp on the difference between cluster and strata columns.\n"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {
14 | "pycharm": {
15 | "name": "#%%\n"
16 | }
17 | },
18 | "outputs": [],
19 | "source": [
20 | "from cluster_experiments.experiment_analysis import PairedTTestClusteredAnalysis\n",
21 | "\n",
22 | "import pandas as pd\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {
29 | "pycharm": {
30 | "name": "#%%\n"
31 | }
32 | },
33 | "outputs": [],
34 | "source": [
35 |     "# Let's generate some fake switchback data (the clusters here would be city and date)\n",
36 | "df = pd.DataFrame(\n",
37 | " {\n",
38 | " \"country_code\": [\"ES\"] * 4 + [\"IT\"] * 4 + [\"PL\"] * 4 + [\"RO\"] * 4,\n",
39 | " \"date\": [\"2022-01-01\", \"2022-01-02\", \"2022-01-03\", \"2022-01-04\"] * 4,\n",
40 | " \"treatment\": [\"A\", 'B'] * 8,\n",
41 | " \"target\": [0.01] * 15 + [0.1],\n",
42 | " }\n",
43 | " )"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 |     "Let's see what the PairedTTestClusteredAnalysis class is doing under the hood. As I am already passing the treatment column, there's no need for a splitter or a perturbator\n"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 3,
56 | "metadata": {
57 | "pycharm": {
58 | "name": "#%%\n"
59 | }
60 | },
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "performing paired t test in this data \n",
67 | " treatment A B\n",
68 | "country_code \n",
69 | "ES 0.01 0.010\n",
70 | "IT 0.01 0.010\n",
71 | "PL 0.01 0.010\n",
72 | "RO 0.01 0.055 \n",
73 | "\n"
74 | ]
75 | },
76 | {
77 | "data": {
78 | "text/plain": "treatment A B\ncountry_code \nES 0.01 0.010\nIT 0.01 0.010\nPL 0.01 0.010\nRO 0.01 0.055",
79 | "text/html": "
\n\n
\n \n \n treatment | \n A | \n B | \n
\n \n country_code | \n | \n | \n
\n \n \n \n ES | \n 0.01 | \n 0.010 | \n
\n \n IT | \n 0.01 | \n 0.010 | \n
\n \n PL | \n 0.01 | \n 0.010 | \n
\n \n RO | \n 0.01 | \n 0.055 | \n
\n \n
\n
"
80 | },
81 | "execution_count": 3,
82 | "metadata": {},
83 | "output_type": "execute_result"
84 | }
85 | ],
86 | "source": [
87 | "analysis = PairedTTestClusteredAnalysis(\n",
88 | " cluster_cols=[\"country_code\", \"date\"], strata_cols = ['country_code']\n",
89 | ")\n",
90 | "\n",
91 | "analysis._preprocessing(df, verbose=True)"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "source": [
97 | "Keep in mind that strata_cols needs to be a subset of cluster_cols and it will be used as the index for pivoting."
98 | ],
99 | "metadata": {
100 | "collapsed": false
101 | }
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 4,
106 | "outputs": [
107 | {
108 | "name": "stdout",
109 | "output_type": "stream",
110 | "text": [
111 | "paired t test results: \n",
112 | " TtestResult(statistic=-1.0, pvalue=0.39100221895577053, df=3) \n",
113 | "\n"
114 | ]
115 | },
116 | {
117 | "data": {
118 | "text/plain": "0.39100221895577053"
119 | },
120 | "execution_count": 4,
121 | "metadata": {},
122 | "output_type": "execute_result"
123 | }
124 | ],
125 | "source": [
126 | "analysis.analysis_pvalue(df, verbose=True)\n"
127 | ],
128 | "metadata": {
129 | "collapsed": false,
130 | "pycharm": {
131 | "name": "#%%\n"
132 | }
133 | }
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 4,
138 | "outputs": [],
139 | "source": [],
140 | "metadata": {
141 | "collapsed": false,
142 | "pycharm": {
143 | "name": "#%%\n"
144 | }
145 | }
146 | }
147 | ],
148 | "metadata": {
149 | "kernelspec": {
150 | "display_name": "Python 3 (ipykernel)",
151 | "language": "python",
152 | "name": "python3"
153 | },
154 | "language_info": {
155 | "codemirror_mode": {
156 | "name": "ipython",
157 | "version": 3
158 | },
159 | "file_extension": ".py",
160 | "mimetype": "text/x-python",
161 | "name": "python",
162 | "nbconvert_exporter": "python",
163 | "pygments_lexer": "ipython3",
164 | "version": "3.8.6"
165 | }
166 | },
167 | "nbformat": 4,
168 | "nbformat_minor": 1
169 | }
170 |
--------------------------------------------------------------------------------
/examples/cupac_example_gbm.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from sklearn.ensemble import HistGradientBoostingRegressor
6 |
7 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis
8 | from cluster_experiments.perturbator import ConstantPerturbator
9 | from cluster_experiments.power_analysis import PowerAnalysis
10 | from cluster_experiments.random_splitter import ClusteredSplitter
11 |
12 |
def generate_random_data(clusters, dates, N):
    """Simulate N observations with a cluster-level random effect.

    Each cluster draws a mean from N(0, 0.1); each row draws four covariates
    x1..x4 and a residual from N(0, 1). The target mixes the covariates
    non-linearly with the cluster mean and the residual, so a flexible model
    (e.g. a GBM) has signal to recover.
    """
    # One mean per cluster, shared by every row assigned to that cluster.
    cluster_means = pd.DataFrame(
        {
            "cluster": clusters,
            "cluster_mean": np.random.normal(0, 0.1, size=len(clusters)),
        }
    )

    # Row-level draws: cluster/date assignment, covariates and noise.
    rows = pd.DataFrame(
        {
            "cluster": np.random.choice(clusters, size=N),
            "residual": np.random.normal(0, 1, size=N),
            "date": np.random.choice(dates, size=N),
            "x1": np.random.normal(0, 1, size=N),
            "x2": np.random.normal(0, 1, size=N),
            "x3": np.random.normal(0, 1, size=N),
            "x4": np.random.normal(0, 1, size=N),
        }
    )

    df = rows.merge(cluster_means, on="cluster")
    # Non-linear signal in the covariates + cluster effect + noise.
    df["target"] = (
        df["x1"] * df["x2"]
        + df["x3"] ** 2
        + df["x4"]
        + df["cluster_mean"]
        + df["residual"]
    )
    return df
46 |
47 |
if __name__ == "__main__":
    # Simulated panel: 100 clusters observed daily through January 2022;
    # the second half of the month is the experiment period.
    clusters = [f"Cluster {i}" for i in range(100)]
    dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
    experiment_dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)]
    N = 10_000
    df = generate_random_data(clusters, dates, N)
    # Split into experiment-period data and pre-experiment data for cupac.
    df_analysis = df.query(f"date.isin({experiment_dates})")
    df_pre = df.query(f"~date.isin({experiment_dates})")
    print(df)

    # Splitter and perturbator shared by both power analyses.
    sw = ClusteredSplitter(
        cluster_cols=["cluster", "date"],
    )

    perturbator = ConstantPerturbator(
        average_effect=0.1,
    )

    # Vanilla GEE: no covariate adjustment.
    analysis = GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"],
    )
    pw_vainilla = PowerAnalysis(
        perturbator=perturbator,
        splitter=sw,
        analysis=analysis,
        n_simulations=50,
    )

    power = pw_vainilla.power_analysis(df_analysis)
    print(f"Not using cupac: {power = }")

    # Cupac GEE: adjust for the model-based covariate "estimate_target";
    # presumably the GBM learns target from x1..x4 on df_pre — see
    # PowerAnalysis for the exact fitting protocol.
    analysis = GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"], covariates=["estimate_target"]
    )

    gbm = HistGradientBoostingRegressor()
    pw_cupac = PowerAnalysis(
        perturbator=perturbator,
        splitter=sw,
        analysis=analysis,
        n_simulations=50,
        cupac_model=gbm,
        features_cupac_model=["x1", "x2", "x3", "x4"],
    )

    power = pw_cupac.power_analysis(df_analysis, df_pre)
    print(f"Using cupac: {power = }")
98 |
--------------------------------------------------------------------------------
/examples/cupac_example_target_mean.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments.cupac import TargetAggregation
7 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis
8 | from cluster_experiments.perturbator import ConstantPerturbator
9 | from cluster_experiments.power_analysis import PowerAnalysis
10 | from cluster_experiments.random_splitter import ClusteredSplitter
11 |
12 |
def generate_random_data(clusters, dates, N, n_users=1000):
    """Simulate N rows where target = user mean + cluster mean + N(0, 1) noise."""
    users = [f"User {i}" for i in range(n_users)]

    # Per-user random effect, N(0, 3).
    user_means = pd.DataFrame(
        {"user": users, "user_mean": np.random.normal(0, 3, size=n_users)}
    )

    # Per-cluster random effect, N(0, 0.1).
    cluster_means = pd.DataFrame(
        {
            "cluster": clusters,
            "cluster_mean": np.random.normal(0, 0.1, size=len(clusters)),
        }
    )

    # Row-level assignments and noise; join both random effects in, then
    # sum them with the residual to form the target.
    rows = pd.DataFrame(
        {
            "cluster": np.random.choice(clusters, size=N),
            "residual": np.random.normal(0, 1, size=N),
            "user": np.random.choice(users, size=N),
            "date": np.random.choice(dates, size=N),
        }
    )
    df = rows.merge(user_means, on="user").merge(cluster_means, on="cluster")
    df["target"] = df["residual"] + df["user_mean"] + df["cluster_mean"]
    return df
45 |
46 |
if __name__ == "__main__":
    # Simulated panel: 100 clusters observed daily through January 2022;
    # the second half of the month is the experiment period.
    clusters = [f"Cluster {i}" for i in range(100)]
    dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
    experiment_dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)]
    N = 10_000
    df = generate_random_data(clusters, dates, N)
    # Split into experiment-period data and pre-experiment data for cupac.
    df_analysis = df.query(f"date.isin({experiment_dates})")
    df_pre = df.query(f"~date.isin({experiment_dates})")
    print(df)

    # Splitter and perturbator shared by both power analyses.
    sw = ClusteredSplitter(
        cluster_cols=["cluster", "date"],
    )

    perturbator = ConstantPerturbator(
        average_effect=0.1,
    )

    # Vanilla GEE: no covariate adjustment.
    analysis = GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"],
    )
    pw_vainilla = PowerAnalysis(
        perturbator=perturbator,
        splitter=sw,
        analysis=analysis,
        n_simulations=50,
    )

    power = pw_vainilla.power_analysis(df_analysis)
    print(f"Not using cupac: {power = }")

    # Cupac GEE: adjust for "estimate_target" — presumably the per-user
    # aggregate of the target computed from df_pre by TargetAggregation.
    analysis = GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"], covariates=["estimate_target"]
    )

    target_agg = TargetAggregation(target_col="target", agg_col="user")
    pw_cupac = PowerAnalysis(
        perturbator=perturbator,
        splitter=sw,
        analysis=analysis,
        n_simulations=50,
        cupac_model=target_agg,
    )

    power = pw_cupac.power_analysis(df_analysis, df_pre)
    print(f"Using cupac: {power = }")
95 | print(f"Using cupac: {power = }")
96 |
--------------------------------------------------------------------------------
/examples/long_example.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis
7 | from cluster_experiments.perturbator import ConstantPerturbator
8 | from cluster_experiments.power_analysis import PowerAnalysis
9 | from cluster_experiments.random_splitter import ClusteredSplitter
10 |
11 |
def generate_random_data(clusters, dates, N):
    """Build N random rows of (cluster, target, user, date) with a N(0, 1) target."""
    user_pool = [f"User {i}" for i in range(1000)]
    columns = {
        "cluster": np.random.choice(clusters, size=N),
        "target": np.random.normal(0, 1, size=N),
        "user": np.random.choice(user_pool, size=N),
        "date": np.random.choice(dates, size=N),
    }
    return pd.DataFrame(columns)
25 |
26 |
if __name__ == "__main__":
    # Walk through the building blocks of a power analysis one by one:
    # split, perturbate, analyse, then run the full simulation loop.
    clusters = [f"Cluster {i}" for i in range(100)]
    dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
    N = 10_000
    df = generate_random_data(clusters, dates, N)
    sw = ClusteredSplitter(
        treatments=["A", "B"],
        cluster_cols=["cluster", "date"],
    )

    treatment_assignment_df = sw.assign_treatment_df(df)
    # NaNs because of data previous to experiment
    print(treatment_assignment_df)

    perturbator = ConstantPerturbator(
        average_effect=0.01,
        target_col="target",
        treatment_col="treatment",
    )

    # Add the simulated effect to the treated rows.
    perturbated_df = perturbator.perturbate(treatment_assignment_df)
    # Average only the numeric target column: calling .mean() on the whole
    # frame raises TypeError on pandas >= 2.0 because of the string columns
    # (cluster, user, date).
    print(perturbated_df.groupby("treatment")["target"].mean())

    analysis = GeeExperimentAnalysis(
        target_col="target",
        treatment_col="treatment",
        cluster_cols=["cluster", "date"],
    )

    # Rows without an assignment cannot enter the analysis.
    p_val = analysis.get_pvalue(perturbated_df.query("treatment.notnull()"))
    print(f"{p_val = }")

    pw = PowerAnalysis(
        target_col="target",
        treatment_col="treatment",
        treatment="B",
        perturbator=perturbator,
        splitter=sw,
        analysis=analysis,
    )

    print(df)
    power = pw.power_analysis(df)
    print(f"{power = }")
72 |
--------------------------------------------------------------------------------
/examples/parallel_example.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import time
3 | from datetime import date
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis
9 | from cluster_experiments.perturbator import ConstantPerturbator
10 | from cluster_experiments.power_analysis import PowerAnalysis
11 | from cluster_experiments.random_splitter import ClusteredSplitter
12 |
13 |
def generate_random_data(clusters, dates, N):
    """Return a frame of N random (cluster, target, user, date) rows."""
    user_pool = [f"User {i}" for i in range(1000)]
    columns = {
        "cluster": np.random.choice(clusters, size=N),
        "target": np.random.normal(0, 1, size=N),
        "user": np.random.choice(user_pool, size=N),
        "date": np.random.choice(dates, size=N),
    }
    return pd.DataFrame(columns)
27 |
28 |
# %%
# Module-level setup shared by the timing runs below: a large simulated
# dataset (1M rows, 1000 clusters) so any parallel speed-up is visible.
clusters = [f"Cluster {i}" for i in range(1000)]
dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
N = 1_000_000
df = generate_random_data(clusters, dates, N)
sw = ClusteredSplitter(
    cluster_cols=["cluster", "date"],
)

perturbator = ConstantPerturbator(
    average_effect=0.1,
)

analysis = GeeExperimentAnalysis(
    cluster_cols=["cluster", "date"],
)

pw = PowerAnalysis(perturbator=perturbator, splitter=sw, analysis=analysis)

print(df)
49 | # %%
50 |
if __name__ == "__main__":
    # Compare wall-clock time of a sequential run against a parallel run
    # of the same power analysis. (Dead commented-out code that referenced
    # a non-existent `power_analysis_parallel` API has been removed.)
    n_simulations = 16
    n_jobs = 8

    # Sequential baseline.
    non_parallel_start = time.time()
    simple_sim = pw.power_analysis(
        df=df, n_simulations=n_simulations, average_effect=-0.01
    )
    non_parallel_end = time.time()
    print("Non Parallel execution finished")
    non_parallel_duration = non_parallel_end - non_parallel_start
    print(f"{non_parallel_duration=}")

    # Same analysis fanned out over n_jobs workers.
    parallel_start = time.time()
    parallel_sim = pw.power_analysis(
        df=df,
        n_simulations=n_simulations,
        average_effect=-0.01,
        n_jobs=n_jobs,
    )
    parallel_end = time.time()
    print("Parallel mp execution finished")
    parallel_duration = parallel_end - parallel_start
    print(f"{parallel_duration=}")
84 |
--------------------------------------------------------------------------------
/examples/short_example.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis
7 | from cluster_experiments.perturbator import ConstantPerturbator
8 | from cluster_experiments.power_analysis import PowerAnalysis
9 | from cluster_experiments.random_splitter import ClusteredSplitter
10 |
11 |
def generate_random_data(clusters, dates, N):
    """Return N random rows of (cluster, target, user, date); target ~ N(0, 1)."""
    user_pool = [f"User {i}" for i in range(1000)]
    columns = {
        "cluster": np.random.choice(clusters, size=N),
        "target": np.random.normal(0, 1, size=N),
        "user": np.random.choice(user_pool, size=N),
        "date": np.random.choice(dates, size=N),
    }
    return pd.DataFrame(columns)
25 |
26 |
if __name__ == "__main__":
    # 100 clusters observed daily through January 2022.
    clusters = [f"Cluster {i}" for i in range(100)]
    dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
    N = 10_000
    df = generate_random_data(clusters, dates, N)
    # Randomize at the (cluster, date) level.
    sw = ClusteredSplitter(
        cluster_cols=["cluster", "date"],
    )

    # Simulated effect added by the perturbator during each simulation.
    perturbator = ConstantPerturbator(
        average_effect=0.1,
    )

    # GEE analysis clustered at the same level as the randomization.
    analysis = GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"],
    )

    pw = PowerAnalysis(
        perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50
    )

    print(df)
    power = pw.power_analysis(df)
    print(f"{power = }")
51 |
--------------------------------------------------------------------------------
/examples/short_example_config.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments.power_analysis import PowerAnalysis
7 | from cluster_experiments.power_config import PowerConfig
8 |
9 |
def generate_random_data(clusters, dates, N):
    """Return N random rows of (cluster, target, user, date); target ~ N(0, 1)."""
    user_pool = [f"User {i}" for i in range(1000)]
    columns = {
        "cluster": np.random.choice(clusters, size=N),
        "target": np.random.normal(0, 1, size=N),
        "user": np.random.choice(user_pool, size=N),
        "date": np.random.choice(dates, size=N),
    }
    return pd.DataFrame(columns)
23 |
24 |
if __name__ == "__main__":
    clusters = [f"Cluster {i}" for i in range(100)]
    dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
    N = 10_000
    df = generate_random_data(clusters, dates, N)
    # Same setup as short_example.py, but built declaratively from a
    # PowerConfig instead of instantiating each component by hand.
    config = PowerConfig(
        cluster_cols=["cluster", "date"],
        analysis="gee",
        perturbator="constant",
        splitter="clustered",
        n_simulations=100,
    )
    pw = PowerAnalysis.from_config(config)

    print(df)
    power = pw.power_analysis(df)
    print(f"{power = }")
42 |
--------------------------------------------------------------------------------
/examples/short_example_dict.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments.power_analysis import PowerAnalysis
7 |
8 |
def generate_random_data(clusters, dates, N):
    """Return N random rows of (cluster, target, user, date); target ~ N(0, 1)."""
    user_pool = [f"User {i}" for i in range(1000)]
    columns = {
        "cluster": np.random.choice(clusters, size=N),
        "target": np.random.normal(0, 1, size=N),
        "user": np.random.choice(user_pool, size=N),
        "date": np.random.choice(dates, size=N),
    }
    return pd.DataFrame(columns)
22 |
23 |
if __name__ == "__main__":
    clusters = [f"Cluster {i}" for i in range(100)]
    dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
    N = 1_000
    df = generate_random_data(clusters, dates, N)
    # Same setup as short_example.py, but configured from a plain dict.
    config = {
        "cluster_cols": ["cluster", "date"],
        "analysis": "gee",
        "perturbator": "constant",
        "splitter": "clustered",
        "n_simulations": 50,
    }
    pw = PowerAnalysis.from_dict(config)

    print(df)
    power = pw.power_analysis(df)
    print(f"{power = }")
41 |
--------------------------------------------------------------------------------
/examples/short_example_paired_ttest.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments.experiment_analysis import PairedTTestClusteredAnalysis
7 | from cluster_experiments.perturbator import ConstantPerturbator
8 | from cluster_experiments.power_analysis import PowerAnalysis
9 | from cluster_experiments.random_splitter import StratifiedSwitchbackSplitter
10 |
11 |
def generate_random_data(clusters, dates, N):
    """N random rows of cluster/target/user plus a parsed date and its day-of-week."""
    user_pool = [f"User {i}" for i in range(1000)]
    raw = {
        "cluster": np.random.choice(clusters, size=N),
        "target": np.random.normal(0, 1, size=N),
        "user": np.random.choice(user_pool, size=N),
        "date": np.random.choice(dates, size=N),
    }
    df = pd.DataFrame(raw)

    # Parse dates so the day-of-week stratum column can be derived.
    df["date"] = pd.to_datetime(df["date"])
    return df.assign(dow=df["date"].dt.day_name())
28 |
29 |
if __name__ == "__main__":
    clusters = [f"Cluster {i}" for i in range(100)]
    dates = [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]
    N = 10_000
    df = generate_random_data(clusters, dates, N)
    # Switchback splitter stratified by the day-of-week column.
    sw = StratifiedSwitchbackSplitter(
        cluster_cols=["cluster", "date"], strata_cols=["dow"]
    )

    perturbator = ConstantPerturbator(
        average_effect=0.1,
    )

    # Paired t-test over clusters; strata_cols must be a subset of
    # cluster_cols (it is used as the pivot index).
    analysis = PairedTTestClusteredAnalysis(
        cluster_cols=["cluster", "date"], strata_cols=["cluster"]
    )

    pw = PowerAnalysis(
        perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50
    )

    print(df)
    power = pw.power_analysis(df)
    print(f"{power = }")
54 |
--------------------------------------------------------------------------------
/examples/short_example_synthetic_control.py:
--------------------------------------------------------------------------------
from cluster_experiments.experiment_analysis import SyntheticControlAnalysis
from cluster_experiments.perturbator import ConstantPerturbator
from cluster_experiments.power_analysis import PowerAnalysisWithPreExperimentData
from cluster_experiments.random_splitter import FixedSizeClusteredSplitter
from cluster_experiments.synthetic_control_utils import generate_synthetic_control_data

# Synthetic-control power analysis: 10 users observed daily over January 2022.
df = generate_synthetic_control_data(10, "2022-01-01", "2022-01-30")

# Exactly 2 of the user clusters are assigned to treatment.
sw = FixedSizeClusteredSplitter(n_treatment_clusters=2, cluster_cols=["user"])

perturbator = ConstantPerturbator(
    average_effect=0.1,
)

# The intervention starts mid-month; dates before it presumably form the
# synthetic-control fitting window — see SyntheticControlAnalysis.
analysis = SyntheticControlAnalysis(
    cluster_cols=["user"], time_col="date", intervention_date="2022-01-15"
)

pw = PowerAnalysisWithPreExperimentData(
    perturbator=perturbator, splitter=sw, analysis=analysis, n_simulations=50
)

power = pw.power_analysis(df)
print(f"{power = }")
# Power across a grid of effect sizes, using all available cores.
print(pw.power_line(df, average_effects=[0.1, 0.2, 0.5, 1, 1.5], n_jobs=-1))
print(pw.simulate_point_estimate(df))
27 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Cluster Experiments Docs
2 | extra_css: [style.css]
3 | repo_url: https://github.com/david26694/cluster-experiments
4 | site_url: https://david26694.github.io/cluster-experiments/
5 | site_description: Functions to design and run clustered experiments
6 | site_author: David Masip
7 | use_directory_urls: false
8 | edit_uri: blob/main/docs/
9 | nav:
10 | - Home:
11 | - Index: index.md
12 | - End-to-end example: e2e_mde.ipynb
13 | - Cupac example: cupac_example.ipynb
14 | - Custom classes: create_custom_classes.ipynb
15 | - Switchback:
16 | - Stratified switchback: switchback.ipynb
17 | - Switchback calendar visualization: plot_calendars.ipynb
18 | - Visualization - 4-hour switches: plot_calendars_hours.ipynb
19 | - Multiple treatments: multivariate.ipynb
20 | - AA test clustered: aa_test.ipynb
21 | - Paired T test: paired_ttest.ipynb
22 | - Different hypotheses tests: analysis_with_different_hypotheses.ipynb
23 | - Washover: washover_example.ipynb
24 | - Normal Power:
25 | - Compare with simulation: normal_power.ipynb
26 | - Time-lines: normal_power_lines.ipynb
27 | - Synthetic control: synthetic_control.ipynb
28 | - Experiment analysis workflow: experiment_analysis.ipynb
29 | - Delta Method Analysis: delta_method.ipynb
30 | - API:
31 | - Experiment analysis methods: api/experiment_analysis.md
32 | - Perturbators: api/perturbator.md
33 | - Splitter: api/random_splitter.md
34 | - Pre experiment outcome model: api/cupac_model.md
35 | - Power config: api/power_config.md
36 | - Power analysis: api/power_analysis.md
37 | - Washover: api/washover.md
38 | - Metric: api/metric.md
39 | - Variant: api/variant.md
40 | - Dimension: api/dimension.md
41 | - Hypothesis Test: api/hypothesis_test.md
42 | - Analysis Plan: api/analysis_plan.md
43 | plugins:
44 | - mkdocstrings:
45 | watch:
46 | - cluster_experiments
47 | - mkdocs-jupyter
48 | - search
49 | copyright: Copyright © 2022 Maintained by David Masip.
50 | theme:
51 | name: material
52 | font:
53 | text: Ubuntu
54 | code: Ubuntu Mono
55 | feature:
56 | tabs: true
57 | palette:
58 | primary: indigo
59 | accent: blue
60 | markdown_extensions:
61 | - codehilite
62 | - pymdownx.inlinehilite
63 | - pymdownx.superfences
64 | - pymdownx.details
65 | - pymdownx.tabbed
66 | - pymdownx.snippets
67 | - pymdownx.highlight:
68 | use_pygments: true
69 | - toc:
70 | permalink: true
71 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | authors = []
3 | requires-python = "<3.13,>=3.9"
4 | dependencies = [
5 | "pip>=22.2.2",
6 | "statsmodels>=0.13.2",
7 | "pandas>=1.2.0",
8 | "scikit-learn>=1.0.0",
9 | "tqdm>=4.0.0",
10 | "numpy>=1.20.0",
11 | ]
12 | name = "cluster-experiments"
13 | version = "0.26.0"
14 | description = ""
15 | readme = "README.md"
16 | classifiers=[
17 | "Development Status :: 4 - Beta",
18 | "Intended Audience :: Developers",
19 | "Programming Language :: Python",
20 | "Programming Language :: Python :: 3.9",
21 | "Programming Language :: Python :: 3.10",
22 | "Programming Language :: Python :: 3.11",
23 | "Programming Language :: Python :: 3.12",
24 | "Programming Language :: Python :: 3 :: Only",
25 | "Operating System :: OS Independent",
26 | "License :: OSI Approved :: MIT License",
27 | ]
28 |
29 | [project.optional-dependencies]
30 | dev = [
31 | "pytest<9.0.0,>=5.4.3",
32 | "black<25.0.0,>=22.12.0",
33 | "ruff<1.0.0,>=0.7.4",
34 | "mktestdocs<1.0.0,>=0.2.2",
35 | "pytest-cov<7.0.0,>=2.10.1",
36 | "pytest-sugar<2.0.0,>=0.9.4",
37 | "pytest-slow-last<1.0.0,>=0.1.3",
38 | "coverage<8.0.0,>=7.6.7",
39 | "pytest-reportlog<1.0.0,>=0.4.0",
40 | "pytest-duration-insights<1.0.0,>=0.1.2",
41 | "pytest-clarity<2.0.0,>=1.0.1",
42 | "pytest-xdist<4.0.0,>=3.6.1",
43 | "pre-commit<5.0.0,>=2.6.0",
44 | "ipykernel<7.0.0,>=6.15.1",
45 | "twine<6.0.0,>=5.1.1",
46 | "build<2.0.0.0,>=1.2.2.post1",
47 | "tox<5.0.0,>=4.23.2",
48 | "mkdocs<2.0.0,>=1.4.0",
49 | "mkdocs-material<10.0.0,>=8.5.0",
50 | "mkdocstrings[python]<1.0.0,>=0.25.0",
51 | "jinja2<4.0.0,>=3.1.0",
52 | "mkdocs-jupyter<1.0.0,>=0.22.0",
53 | "matplotlib<4.0.0,>=3.4.3",
54 | "plotnine<1.0.0,>=0.8.0",
55 | ]
56 |
57 | test = [
58 | "pytest<9.0.0,>=5.4.3",
59 | "black<25.0.0,>=22.12.0",
60 | "ruff<1.0.0,>=0.7.4",
61 | "mktestdocs<1.0.0,>=0.2.2",
62 | "pytest-cov<7.0.0,>=2.10.1",
63 | "pytest-sugar<2.0.0,>=0.9.4",
64 | "pytest-slow-last<1.0.0,>=0.1.3",
65 | "coverage<8.0.0,>=7.6.7",
66 | "pytest-reportlog<1.0.0,>=0.4.0",
67 | "pytest-duration-insights<1.0.0,>=0.1.2",
68 | "pytest-clarity<2.0.0,>=1.0.1",
69 | "pytest-xdist<4.0.0,>=3.6.1",
70 | ]
71 |
72 | docs = [
73 | "mkdocs<2.0.0,>=1.4.0",
74 | "mkdocs-material<10.0.0,>=8.5.0",
75 | "mkdocstrings<1.0.0,>=0.18.0",
76 | "jinja2<4.0.0,>=3.1.0",
77 | "mkdocs-jupyter<1.0.0,>=0.22.0",
78 | "plotnine<1.0.0,>=0.8.0",
79 | "matplotlib<4.0.0,>=3.4.3",
80 | ]
81 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | line-length = 160
2 |
3 | extend-select = ["I001"]
4 |
5 | ignore = ["E501"]
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

# Package metadata and dependencies live in pyproject.toml; this shim only
# tells setuptools which packages to include in the distribution.
setup(
    packages=find_packages(),
)
6 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/__init__.py
--------------------------------------------------------------------------------
/tests/analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/analysis/__init__.py
--------------------------------------------------------------------------------
/tests/analysis/conftest.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 |
7 | from tests.utils import generate_ratio_metric_data
8 |
# Sample size passed to generate_ratio_metric_data for each period.
N = 50_000


@pytest.fixture
def dates():
    """All days of January 2022, ISO formatted."""
    return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(1, 32)]


@pytest.fixture
def experiment_dates():
    """The experiment window: January 15th onwards."""
    return [f"{date(2022, 1, i):%Y-%m-%d}" for i in range(15, 32)]


@pytest.fixture
def analysis_df():
    """Minimal single-cluster frame with a binary target and two variants."""
    return pd.DataFrame(
        {
            "target": [0, 1, 0, 1],
            "treatment": ["A", "B", "B", "A"],
            "cluster": ["Cluster 1", "Cluster 1", "Cluster 1", "Cluster 1"],
            "date": ["2022-01-01", "2022-01-01", "2022-01-01", "2022-01-01"],
        }
    )


@pytest.fixture
def analysis_ratio_df(dates, experiment_dates):
    """Pre-experiment data (zero effect) concatenated with experiment-period data."""
    pre_exp_dates = [d for d in dates if d not in experiment_dates]

    # Per-user mean parameters for the ratio metric, drawn around 0.3.
    user_sample_mean = 0.3
    user_standard_error = 0.15
    users = 2000

    user_target_means = np.random.normal(user_sample_mean, user_standard_error, users)

    pre_data = generate_ratio_metric_data(
        pre_exp_dates, N, user_target_means, users, treatment_effect=0
    )
    post_data = generate_ratio_metric_data(
        experiment_dates, N, user_target_means, users
    )
    return pd.concat([pre_data, post_data])


@pytest.fixture
def covariate_data():
    """Generate data via y = 0.5*X + 0.1*(T == "B") + noise, i.e. y ~ T + X."""
    N = 1000
    np.random.seed(123)  # deterministic across test runs
    X = np.random.normal(size=N)
    T = np.random.choice(["A", "B"], size=N)
    y = 0.5 * X + 0.1 * (T == "B") + np.random.normal(size=N)
    df = pd.DataFrame({"y": y, "T": T, "X": X})
    return df
63 |
--------------------------------------------------------------------------------
/tests/analysis/test_formula.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cluster_experiments.experiment_analysis import (
4 | ClusteredOLSAnalysis,
5 | GeeExperimentAnalysis,
6 | MLMExperimentAnalysis,
7 | )
8 |
# Shared parametrisation: every formula-based analysis class under test.
parametrisation = pytest.mark.parametrize(
    "analysis_class",
    [
        ClusteredOLSAnalysis,
        GeeExperimentAnalysis,
        MLMExperimentAnalysis,
    ],
)


@parametrisation
def test_formula_no_covariates(analysis_class):
    """Without covariates the formula is just target ~ treatment."""
    analysis = analysis_class(
        cluster_cols=["cluster"],
        treatment_col="treatment",
        target_col="y",
    )

    assert analysis.formula == "y ~ treatment"


@parametrisation
def test_formula_with_covariates(analysis_class):
    """Covariates are appended to the formula in the order given."""
    analysis = analysis_class(
        cluster_cols=["cluster"],
        treatment_col="treatment",
        target_col="y",
        covariates=["covariate1", "covariate2"],
    )

    assert analysis.formula == "y ~ treatment + covariate1 + covariate2"


@parametrisation
def test_formula_with_interaction(analysis_class):
    """add_covariate_interaction adds one __<covariate>__interaction term per covariate."""
    analysis = analysis_class(
        cluster_cols=["cluster"],
        treatment_col="treatment",
        target_col="y",
        covariates=["covariate1", "covariate2"],
        add_covariate_interaction=True,
    )

    assert (
        analysis.formula
        == "y ~ treatment + covariate1 + covariate2 + __covariate1__interaction + __covariate2__interaction"
    )
56 |
--------------------------------------------------------------------------------
/tests/analysis/test_hypothesis.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | from cluster_experiments.experiment_analysis import (
5 | ClusteredOLSAnalysis,
6 | GeeExperimentAnalysis,
7 | MLMExperimentAnalysis,
8 | OLSAnalysis,
9 | SyntheticControlAnalysis,
10 | TTestClusteredAnalysis,
11 | )
12 | from cluster_experiments.synthetic_control_utils import generate_synthetic_control_data
13 | from tests.utils import generate_clustered_data
14 |
15 |
@pytest.mark.parametrize("hypothesis", ["less", "greater", "two-sided"])
@pytest.mark.parametrize("analysis_class", [OLSAnalysis])
def test_get_pvalue_hypothesis(analysis_class, hypothesis, analysis_df):
    """Every supported hypothesis yields a non-negative p-value on OLS."""
    # Replicate the 4-row fixture so the fit has enough data.
    analysis_df_full = pd.concat([analysis_df for _ in range(100)])
    analyser = analysis_class(hypothesis=hypothesis)
    assert analyser.get_pvalue(analysis_df_full) >= 0


@pytest.mark.parametrize("hypothesis", ["less", "greater", "two-sided"])
@pytest.mark.parametrize(
    "analysis_class",
    [
        ClusteredOLSAnalysis,
        GeeExperimentAnalysis,
        TTestClusteredAnalysis,
        MLMExperimentAnalysis,
    ],
)
def test_get_pvalue_hypothesis_clustered(analysis_class, hypothesis):
    """Clustered analyses accept every hypothesis and return a valid p-value."""
    analysis_df_full = generate_clustered_data()
    analyser = analysis_class(hypothesis=hypothesis, cluster_cols=["user_id"])
    assert analyser.get_pvalue(analysis_df_full) >= 0


@pytest.mark.parametrize("analysis_class", [OLSAnalysis])
def test_get_pvalue_hypothesis_default(analysis_class, analysis_df):
    """Omitting the hypothesis argument falls back to a working default."""
    analysis_df_full = pd.concat([analysis_df for _ in range(100)])
    analyser = analysis_class()
    assert analyser.get_pvalue(analysis_df_full) >= 0
46 |
47 |
@pytest.mark.parametrize("analysis_class", [OLSAnalysis])
def test_get_pvalue_hypothesis_wrong_input(analysis_class, analysis_df):
    """An unknown hypothesis string raises a ValueError naming the bad value."""
    analysis_df_full = pd.concat([analysis_df for _ in range(100)])

    # Use pytest.raises to check for ValueError. The call's result is not
    # needed — the previous `>= 0` comparison was a discarded expression
    # (flake8-bugbear B015) and has been dropped.
    with pytest.raises(ValueError) as excinfo:
        analyser = analysis_class(hypothesis="wrong_input")
        analyser.get_pvalue(analysis_df_full)

    # Check if the error message is as expected
    assert "'wrong_input' is not a valid HypothesisEntries" in str(excinfo.value)
59 |
60 |
@pytest.mark.parametrize("analysis_class", [OLSAnalysis])
def test_several_hypothesis(analysis_class, analysis_df):
    """One-sided p-values are consistent halves of the two-sided p-value."""
    analysis_df_full = pd.concat([analysis_df for _ in range(100)])
    analysis_less = analysis_class(hypothesis="less")
    analysis_greater = analysis_class(hypothesis="greater")
    analysis_two_sided = analysis_class(hypothesis="two-sided")

    # p_less == p_two_sided / 2 and p_greater == 1 - p_two_sided / 2
    assert (
        analysis_less.get_pvalue(analysis_df_full)
        == analysis_two_sided.get_pvalue(analysis_df_full) / 2
    )
    assert (
        analysis_greater.get_pvalue(analysis_df_full)
        == 1 - analysis_two_sided.get_pvalue(analysis_df_full) / 2
    )


@pytest.mark.parametrize("hypothesis", ["less", "greater", "two-sided"])
def test_hypothesis_synthetic(hypothesis):
    """Synthetic-control analysis returns a valid p-value for each hypothesis."""
    df = generate_synthetic_control_data(
        N=10, start_date="2022-01-01", end_date="2022-01-30"
    )
    # Add treatment column to only 1 user
    df["treatment"] = 0
    df.loc[(df["user"] == "User 5"), "treatment"] = 1

    analysis = SyntheticControlAnalysis(
        hypothesis=hypothesis, cluster_cols=["user"], intervention_date="2022-01-15"
    )
    assert analysis.analysis_pvalue(df) >= 0
92 |
--------------------------------------------------------------------------------
/tests/analysis/test_ols_analysis.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from cluster_experiments.experiment_analysis import OLSAnalysis
4 |
5 |
def test_binary_treatment(analysis_df):
    """_create_binary_treatment maps the treatment labels onto 0/1."""
    analyser = OLSAnalysis()
    binarized = analyser._create_binary_treatment(analysis_df)["treatment"]
    expected = pd.Series([0, 1, 1, 0])
    assert binarized.eq(expected).all()
12 |
13 |
def test_get_pvalue(analysis_df):
    """get_pvalue returns a non-negative number on a replicated sample."""
    big_df = pd.concat([analysis_df] * 100)
    assert OLSAnalysis().get_pvalue(big_df) >= 0
18 |
19 |
def test_cov_type(analysis_df):
    """Robust covariance choice affects standard errors but not estimates."""
    # given
    hc1 = OLSAnalysis(cov_type="HC1")
    hc3 = OLSAnalysis(cov_type="HC3")

    # then: identical point estimates regardless of the covariance estimator
    assert hc1.get_point_estimate(analysis_df) == hc3.get_point_estimate(analysis_df)

    # then: differing standard errors between HC1 and HC3
    assert hc1.get_standard_error(analysis_df) != hc3.get_standard_error(analysis_df)
34 |
35 |
def test_covariate_interaction(covariate_data):
    """Adding covariate interactions changes the fitted point estimate."""

    def build(interaction: bool) -> OLSAnalysis:
        # Identical configurations except for the interaction flag.
        return OLSAnalysis(
            treatment_col="T",
            target_col="y",
            covariates=["X"],
            add_covariate_interaction=interaction,
        )

    with_interaction = build(True)
    without_interaction = build(False)

    # when: estimating the effect under both specifications
    estimate_with = with_interaction.get_point_estimate(covariate_data)
    estimate_without = without_interaction.get_point_estimate(covariate_data)

    # then: both the model formulas and the estimates differ
    assert with_interaction.formula != without_interaction.formula
    assert estimate_with != estimate_without
60 |
--------------------------------------------------------------------------------
/tests/analysis/test_synthetic_analysis.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 |
7 | from cluster_experiments.experiment_analysis import SyntheticControlAnalysis
8 | from cluster_experiments.synthetic_control_utils import get_w
9 |
10 |
def generate_2_clusters_data(N, start_date, end_date):
    """Build a synthetic panel of N users x daily dates x 2 countries.

    Args:
        N: number of users to generate.
        start_date: first date of the panel (inclusive), parseable by pandas.
        end_date: last date of the panel (inclusive).

    Returns:
        DataFrame with one row per (user, date, country) combination and
        columns ``user``, ``date`` (datetime64), ``country`` and a standard
        normal ``target``.
    """
    dates = pd.date_range(start_date, end_date, freq="d")
    countries = ["US", "UK"]
    users = [f"User {i}" for i in range(N)]

    # Cartesian product: every user appears on every date in every country.
    combinations = list(product(users, dates, countries))

    df = pd.DataFrame(combinations, columns=["user", "date", "country"])
    df["target"] = np.random.normal(0, 1, size=len(combinations))
    # No extra conversion needed: "date" already holds datetime64 values
    # because pd.date_range yields Timestamps (the original re-converted it).
    return df
26 |
27 |
def test_synthetic_control_analysis():
    """End-to-end p-value of the synthetic control analysis lies in [0, 1]."""
    df = generate_2_clusters_data(10, "2022-01-01", "2022-01-30")

    # Mark a single (user, country) cluster as treated.
    df["treatment"] = "A"
    treated = (df["user"] == "User 5") & (df["country"] == "US")
    df.loc[treated, "treatment"] = "B"

    analysis = SyntheticControlAnalysis(
        cluster_cols=["user", "country"], intervention_date="2022-01-06"
    )

    assert 0 <= analysis.get_pvalue(df) <= 1
41 |
42 |
@pytest.mark.parametrize(
    "X, y",
    [
        (np.array([[1, 2], [3, 4]]), np.array([1, 1])),  # positive integers
        (np.array([[1, -2], [-3, 4]]), np.array([1, 1])),  # negative integers
        (np.array([[1.5, 2.5], [3.5, 4.5]]), np.array([1, 1])),  # positive floats
        (np.array([[1.5, -2.5], [-3.5, 4.5]]), np.array([1, 1])),  # negative floats
    ],
)
def test_get_w_weights(X, y):
    """get_w (a helper used by the analysis) returns a convex weight vector."""
    weights = get_w(X, y)
    assert np.isclose(np.sum(weights), 1), "Weights sum should be close to 1"
    assert all(0 <= w <= 1 for w in weights), "Each weight should be between 0 and 1"
74 |
75 |
def test_get_treatment_cluster():
    """_get_treatment_cluster returns the cluster whose rows are treated."""
    analysis = SyntheticControlAnalysis(
        cluster_cols=["cluster"], intervention_date="2022-01-06"
    )
    cluster_labels = [
        "cluster1",
        "cluster2",
        "cluster3",
        "cluster3",
        "cluster3",
        "cluster2",
    ]
    df = pd.DataFrame(
        {
            "target": [1, 2, 3, 4, 5, 6],
            "treatment": [0, 0, 1, 1, 1, 0],
            "cluster": cluster_labels,
        }
    )
    # Only cluster3 rows carry treatment == 1.
    assert analysis._get_treatment_cluster(df) == "cluster3"
96 |
97 |
def test_point_estimate_synthetic_control():
    """The estimated effect recovers the injected lift of roughly 10."""
    df = generate_2_clusters_data(10, "2022-01-01", "2022-01-30")

    # Single treated cluster with a constant target of 10; everything else
    # stays standard normal, so the effect should be close to 10.
    treated = (df["user"] == "User 5") & (df["country"] == "US")
    df["treatment"] = 0
    df.loc[treated, "treatment"] = 1
    df.loc[treated, "target"] = 10

    analysis = SyntheticControlAnalysis(
        cluster_cols=["user", "country"], intervention_date="2022-01-06"
    )

    assert 9 <= analysis.analysis_point_estimate(df) <= 11
113 |
114 |
def test_predict():
    """_predict adds a complete 'synthetic' series for the treated cluster."""
    analysis = SyntheticControlAnalysis(
        cluster_cols=["user", "country"], intervention_date="2022-01-06"
    )
    df = generate_2_clusters_data(5, "2021-01-01", "2021-01-10")
    df["treatment"] = "A"
    df.loc[(df["user"] == "User 4") & (df["country"] == "US"), "treatment"] = "B"

    # Equal weight for each of the nine donor clusters.
    weights = np.full(9, 0.2)

    result = analysis._predict(df, weights, "User 4US")

    # Check the results
    assert (
        "synthetic" in result.columns
    ), "The result DataFrame should include a 'synthetic' column"
    assert all(
        result["treatment"] == "B"
    ), "The result DataFrame should only contain the treatment cluster"
    assert len(result) > 0, "Should have at least one entry for the treatment cluster"
    assert (
        not result["synthetic"].isnull().any()
    ), "Synthetic column should not have null values"
141 |
--------------------------------------------------------------------------------
/tests/cupac/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/cupac/__init__.py
--------------------------------------------------------------------------------
/tests/cupac/conftest.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 |
@pytest.fixture
def binary_df():
    """Minimal 4-row frame with a binary target and two treatment groups."""
    data = {"target": [0, 1, 0, 1], "treatment": ["A", "B", "B", "A"]}
    return pd.DataFrame(data)
13 |
--------------------------------------------------------------------------------
/tests/cupac/test_aggregator.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pandas as pd
4 |
5 | from cluster_experiments.cupac import TargetAggregation
6 |
7 |
def split_x_y(binary_df_agg: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series]:
    """Separate the feature columns from the "target" column."""
    features = binary_df_agg.drop(columns="target")
    return features, binary_df_agg["target"]
10 |
11 |
def test_set_target_aggs(binary_df):
    """Fitting on a single user yields one aggregated row and the global mean."""
    binary_df["user"] = [1] * 4
    ta = TargetAggregation(agg_col="user")
    ta.fit(*split_x_y(binary_df))

    assert len(ta.pre_experiment_agg_df) == 1
    assert ta.pre_experiment_mean == 0.5
20 |
21 |
def test_smoothing_0(binary_df):
    """With no smoothing the smoothed mean equals the raw group mean."""
    binary_df["user"] = binary_df["target"]
    ta = TargetAggregation(agg_col="user", smoothing_factor=0)
    ta.fit(*split_x_y(binary_df))

    agg = ta.pre_experiment_agg_df
    assert agg["target_mean"].eq(agg["target_smooth_mean"]).all()
31 |
32 |
def test_smoothing_non_0(binary_df):
    """Positive smoothing pulls every group mean towards the global mean."""
    binary_df["user"] = binary_df["target"]
    ta = TargetAggregation(agg_col="user", smoothing_factor=2)
    ta.fit(*split_x_y(binary_df))

    agg = ta.pre_experiment_agg_df
    # Smoothed means must differ from the raw means everywhere ...
    assert agg["target_mean"].ne(agg["target_smooth_mean"]).all()
    # ... and land on the expected shrunken values for the two groups.
    assert agg["target_smooth_mean"].loc[[0, 1]].eq([0.25, 0.75]).all()
45 |
46 |
def test_add_aggs(binary_df):
    """predict attaches the smoothed group mean for each user."""
    binary_df["user"] = binary_df["target"]
    ta = TargetAggregation(agg_col="user", smoothing_factor=2)
    ta.fit(*split_x_y(binary_df))

    binary_df["target_smooth_mean"] = ta.predict(binary_df)
    assert (binary_df.query("user == 0")["target_smooth_mean"] == 0.25).all()
54 |
--------------------------------------------------------------------------------
/tests/cupac/test_cupac_handler.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | import numpy as np
4 | import pytest
5 | from sklearn.ensemble import HistGradientBoostingRegressor
6 |
7 | from cluster_experiments.cupac import CupacHandler, TargetAggregation
8 | from tests.utils import generate_random_data
9 |
10 | N = 1_000
11 |
12 |
@pytest.fixture
def clusters():
    """One hundred synthetic cluster labels."""
    return [f"Cluster {i}" for i in range(100)]
16 |
17 |
@pytest.fixture
def dates():
    """Every day of January 2022, ISO-formatted."""
    return [f"{date(2022, 1, day):%Y-%m-%d}" for day in range(1, 32)]
21 |
22 |
@pytest.fixture
def experiment_dates():
    """The experimental window: Jan 15 through Jan 31, 2022."""
    return [f"{date(2022, 1, day):%Y-%m-%d}" for day in range(15, 32)]
26 |
27 |
@pytest.fixture
def df(clusters, dates):
    """Random experiment data spanning all clusters and dates."""
    return generate_random_data(clusters, dates, N)
31 |
32 |
@pytest.fixture
def df_feats(clusters, dates):
    """Random experiment data enriched with two gaussian features."""
    frame = generate_random_data(clusters, dates, N)
    # Same draw order as before: x1 first, then x2.
    for feature in ("x1", "x2"):
        frame[feature] = np.random.normal(0, 1, N)
    return frame
39 |
40 |
@pytest.fixture
def cupac_handler_base():
    """CupacHandler backed by a per-user TargetAggregation model."""
    aggregation_model = TargetAggregation(agg_col="user")
    return CupacHandler(cupac_model=aggregation_model, target_col="target")
50 |
51 |
@pytest.fixture
def cupac_handler_model():
    """CupacHandler backed by a tiny gradient-boosting regressor on x1/x2."""
    gbm = HistGradientBoostingRegressor(max_iter=3)
    return CupacHandler(
        cupac_model=gbm,
        target_col="target",
        features_cupac_model=["x1", "x2"],
    )
59 |
60 |
@pytest.fixture
def missing_cupac():
    """CupacHandler configured without any cupac model."""
    return CupacHandler(None)
66 |
67 |
@pytest.mark.parametrize(
    "cupac_handler",
    ["cupac_handler_base", "cupac_handler_model"],
)
def test_add_covariates(cupac_handler, df_feats, request):
    """add_covariates yields a fully populated, target-bounded estimate column."""
    handler = request.getfixturevalue(cupac_handler)
    augmented = handler.add_covariates(df_feats, df_feats.head(10))

    estimate = augmented["estimate_target"]
    assert estimate.isna().sum() == 0
    # Estimates should stay inside the observed target range.
    assert estimate.le(augmented["target"].max()).all()
    assert estimate.ge(augmented["target"].min()).all()
81 |
82 |
def test_no_target(missing_cupac, df_feats):
    """Without a cupac model, no estimate column is created."""
    result = missing_cupac.add_covariates(df_feats)
    assert "estimate_target" not in result.columns
87 |
88 |
def test_no_pre_experiment(cupac_handler_base, df_feats):
    """A cupac model without pre-experiment data must be rejected."""
    with pytest.raises(ValueError, match="pre_experiment_df should be provided"):
        cupac_handler_base.add_covariates(df_feats)
93 |
94 |
def test_no_cupac(missing_cupac, df_feats):
    """Pre-experiment data without a cupac model must be rejected."""
    with pytest.raises(ValueError, match="remove pre_experiment_df argument"):
        missing_cupac.add_covariates(df_feats, df_feats.head(10))
99 |
--------------------------------------------------------------------------------
/tests/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/inference/__init__.py
--------------------------------------------------------------------------------
/tests/inference/test_analysis_plan_config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 |
5 | from cluster_experiments.inference.analysis_plan import AnalysisPlan
6 | from cluster_experiments.inference.dimension import Dimension
7 | from cluster_experiments.inference.hypothesis_test import HypothesisTest
8 | from cluster_experiments.inference.metric import SimpleMetric
9 | from cluster_experiments.inference.variant import Variant
10 |
11 |
@pytest.fixture
def experiment_data():
    """Random order-level data with variants, cities and customer ids."""
    n_rows = 1_000
    return pd.DataFrame(
        {
            "order_value": np.random.normal(100, 10, size=n_rows),
            "delivery_time": np.random.normal(10, 1, size=n_rows),
            "experiment_group": np.random.choice(
                ["control", "treatment"], size=n_rows
            ),
            "city": np.random.choice(["NYC", "LA"], size=n_rows),
            "customer_id": np.random.randint(1, 100, size=n_rows),
            "customer_age": np.random.randint(20, 60, size=n_rows),
        }
    )
25 |
26 |
def test_from_metrics_dict():
    """from_metrics_dict builds a fully configured AnalysisPlan."""
    config = {
        "metrics": [{"alias": "AOV", "name": "order_value"}],
        "variants": [
            {"name": "control", "is_control": True},
            {"name": "treatment_1", "is_control": False},
        ],
        "variant_col": "experiment_group",
        "alpha": 0.05,
        "dimensions": [{"name": "city", "values": ["NYC", "LA"]}],
        "analysis_type": "clustered_ols",
        "analysis_config": {"cluster_cols": ["customer_id"]},
    }

    plan = AnalysisPlan.from_metrics_dict(config)

    assert isinstance(plan, AnalysisPlan)
    assert len(plan.tests) == 1
    assert isinstance(plan.tests[0], HypothesisTest)
    assert plan.variant_col == "experiment_group"
    assert plan.alpha == 0.05
    assert len(plan.variants) == 2
47 |
48 |
def test_analyze_from_metrics_dict(experiment_data):
    """An AnalysisPlan built via from_metrics_dict analyzes end to end.

    Two metrics evaluated over the total population plus the two city
    dimension values should produce six result rows.
    """
    # given
    plan = AnalysisPlan.from_metrics_dict(
        {
            "metrics": [
                {"alias": "AOV", "name": "order_value"},
                {"alias": "delivery_time", "name": "delivery_time"},
            ],
            "variants": [
                {"name": "control", "is_control": True},
                {"name": "treatment", "is_control": False},
            ],
            "variant_col": "experiment_group",
            "alpha": 0.05,
            "dimensions": [
                {"name": "city", "values": ["NYC", "LA"]},
            ],
            "analysis_type": "clustered_ols",
            "analysis_config": {"cluster_cols": ["customer_id"]},
        }
    )

    # when
    results = plan.analyze(experiment_data)
    results_df = results.to_dataframe()

    # then
    assert (
        len(results_df) == 6
    ), "There should be 6 rows in the results DataFrame, 2 metrics x 3 dimension values"
    assert set(results_df["metric_alias"]) == {
        "AOV",
        "delivery_time",
    }, "The metric aliases should be present in the DataFrame"
    # Fixed: the set used to contain an accidental adjacent string literal
    # ('"" "NYC"'), which silently concatenated to "NYC".
    assert set(results_df["dimension_value"]) == {
        "total",
        "NYC",
        "LA",
    }, "The dimension values should be present in the DataFrame"
88 |
89 |
def test_from_dict_config():
    """AnalysisPlan.from_dict must match a manually constructed plan."""

    def serialized_test(alias: str, name: str) -> dict:
        # One hypothesis-test entry in its dict (serialized) form.
        return {
            "metric": {"alias": alias, "name": name},
            "analysis_type": "clustered_ols",
            "analysis_config": {"cluster_cols": ["customer_id"]},
            "dimensions": [{"name": "city", "values": ["NYC", "LA"]}],
        }

    def built_test(alias: str, name: str) -> HypothesisTest:
        # The equivalent in-memory hypothesis test.
        return HypothesisTest(
            metric=SimpleMetric(alias=alias, name=name),
            analysis_type="clustered_ols",
            analysis_config={"cluster_cols": ["customer_id"]},
            dimensions=[Dimension(name="city", values=["NYC", "LA"])],
        )

    # given: the same plan expressed as a dict and as objects
    d = {
        "tests": [
            serialized_test("AOV", "order_value"),
            serialized_test("delivery_time", "delivery_time"),
        ],
        "variants": [
            {"name": "control", "is_control": True},
            {"name": "treatment", "is_control": False},
        ],
        "variant_col": "experiment_group",
        "alpha": 0.05,
    }
    plan = AnalysisPlan(
        variants=[
            Variant(name="control", is_control=True),
            Variant(name="treatment", is_control=False),
        ],
        variant_col="experiment_group",
        alpha=0.05,
        tests=[
            built_test("AOV", "order_value"),
            built_test("delivery_time", "delivery_time"),
        ],
    )

    # when
    plan_from_config = AnalysisPlan.from_dict(d)

    # then
    assert plan.variant_col == plan_from_config.variant_col
    assert plan.alpha == plan_from_config.alpha
    for variant in plan.variants:
        assert variant in plan_from_config.variants
146 |
147 |
def test_from_dict():
    """from_dict deserializes tests, variants and plan-level settings."""

    def serialized_test(alias: str, name: str) -> dict:
        # One hypothesis-test entry in its dict form.
        return {
            "metric": {"alias": alias, "name": name},
            "analysis_type": "clustered_ols",
            "analysis_config": {"cluster_cols": ["customer_id"]},
            "dimensions": [{"name": "city", "values": ["NYC", "LA"]}],
        }

    # given
    d = {
        "tests": [
            serialized_test("AOV", "order_value"),
            serialized_test("DT", "delivery_time"),
        ],
        "variants": [
            {"name": "control", "is_control": True},
            {"name": "treatment_1", "is_control": False},
        ],
        "variant_col": "experiment_group",
        "alpha": 0.05,
    }

    # when
    plan = AnalysisPlan.from_dict(d)

    # then
    assert isinstance(plan, AnalysisPlan)
    assert len(plan.tests) == 2
    assert isinstance(plan.tests[0], HypothesisTest)
    assert plan.variant_col == "experiment_group"
    assert plan.alpha == 0.05
    assert len(plan.variants) == 2
    assert plan.tests[1].metric.alias == "DT"
184 |
185 |
def test_analyze_from_dict(experiment_data):
    """An AnalysisPlan built via from_dict runs end to end.

    Two metrics over the total population plus the two city values should
    give six result rows.
    """
    # given
    d = {
        "tests": [
            {
                "metric": {"alias": "AOV", "name": "order_value"},
                "analysis_type": "clustered_ols",
                "analysis_config": {"cluster_cols": ["customer_id"]},
                "dimensions": [{"name": "city", "values": ["NYC", "LA"]}],
            },
            {
                "metric": {"alias": "DT", "name": "delivery_time"},
                "analysis_type": "clustered_ols",
                "analysis_config": {"cluster_cols": ["customer_id"]},
                "dimensions": [{"name": "city", "values": ["NYC", "LA"]}],
            },
        ],
        "variants": [
            {"name": "control", "is_control": True},
            {"name": "treatment_1", "is_control": False},
        ],
        "variant_col": "experiment_group",
        "alpha": 0.05,
    }
    plan = AnalysisPlan.from_dict(d)

    # when
    results = plan.analyze(experiment_data)
    results_df = results.to_dataframe()

    # then
    assert (
        len(results_df) == 6
    ), "There should be 6 rows in the results DataFrame, 2 metrics x 3 dimension values"
    assert set(results_df["metric_alias"]) == {
        "AOV",
        "DT",
    }, "The metric aliases should be present in the DataFrame"
    # Fixed: the set used to contain an accidental adjacent string literal
    # ('"" "NYC"'), which silently concatenated to "NYC".
    assert set(results_df["dimension_value"]) == {
        "total",
        "NYC",
        "LA",
    }, "The dimension values should be present in the DataFrame"
229 |
--------------------------------------------------------------------------------
/tests/inference/test_analysis_results.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict
2 |
3 | import pandas as pd
4 | import pytest
5 |
6 | from cluster_experiments.inference.analysis_results import AnalysisPlanResults
7 |
8 |
def test_analysis_plan_results_initialization():
    """Every field of a default AnalysisPlanResults starts as an empty list."""
    results = AnalysisPlanResults()
    default_fields = [
        results.metric_alias,
        results.control_variant_name,
        results.treatment_variant_name,
        results.control_variant_mean,
        results.treatment_variant_mean,
        results.analysis_type,
        results.ate,
        results.ate_ci_lower,
        results.ate_ci_upper,
        results.p_value,
        results.std_error,
        results.dimension_name,
        results.dimension_value,
        results.alpha,
    ]
    assert all(field == [] for field in default_fields)
26 |
27 |
def test_analysis_plan_results_custom_initialization():
    """Custom field values are stored untouched."""
    kwargs = {
        "metric_alias": ["metric1"],
        "control_variant_name": ["Control"],
        "treatment_variant_name": ["Treatment"],
        "control_variant_mean": [0.5],
        "treatment_variant_mean": [0.6],
        "analysis_type": ["AB Test"],
        "ate": [0.1],
        "ate_ci_lower": [0.05],
        "ate_ci_upper": [0.15],
        "p_value": [0.04],
        "std_error": [0.02],
        "dimension_name": ["Country"],
        "dimension_value": ["US"],
        "alpha": [0.05],
    }
    results = AnalysisPlanResults(**kwargs)

    # Every constructor argument should round-trip through the attribute.
    for field_name, expected in kwargs.items():
        assert getattr(results, field_name) == expected
60 |
61 |
def test_analysis_plan_results_addition():
    """Adding two AnalysisPlanResults concatenates every field pairwise."""
    # Fields shared by both operands.
    shared = {
        "control_variant_name": ["Control"],
        "treatment_variant_name": ["Treatment"],
        "analysis_type": ["AB Test"],
        "ate": [0.1],
        "ate_ci_lower": [0.05],
        "ate_ci_upper": [0.15],
        "dimension_name": ["Country"],
        "alpha": [0.05],
    }
    results1 = AnalysisPlanResults(
        metric_alias=["metric1"],
        control_variant_mean=[0.5],
        treatment_variant_mean=[0.6],
        p_value=[0.04],
        std_error=[0.02],
        dimension_value=["US"],
        **shared,
    )
    results2 = AnalysisPlanResults(
        metric_alias=["metric2"],
        control_variant_mean=[0.55],
        treatment_variant_mean=[0.65],
        p_value=[0.03],
        std_error=[0.01],
        dimension_value=["CA"],
        **shared,
    )

    combined_results = results1 + results2

    assert combined_results.metric_alias == ["metric1", "metric2"]
    assert combined_results.control_variant_name == ["Control", "Control"]
    assert combined_results.treatment_variant_name == ["Treatment", "Treatment"]
    assert combined_results.control_variant_mean == [0.5, 0.55]
    assert combined_results.treatment_variant_mean == [0.6, 0.65]
    assert combined_results.analysis_type == ["AB Test", "AB Test"]
    assert combined_results.ate == [0.1, 0.1]
    assert combined_results.ate_ci_lower == [0.05, 0.05]
    assert combined_results.ate_ci_upper == [0.15, 0.15]
    assert combined_results.p_value == [0.04, 0.03]
    assert combined_results.std_error == [0.02, 0.01]
    assert combined_results.dimension_name == ["Country", "Country"]
    assert combined_results.dimension_value == ["US", "CA"]
    assert combined_results.alpha == [0.05, 0.05]
112 |
113 |
def test_analysis_plan_results_addition_type_error():
    """Adding anything but another AnalysisPlanResults raises TypeError."""
    results = AnalysisPlanResults(metric_alias=["metric1"])
    with pytest.raises(TypeError):
        results + "not_an_analysis_plan_results"
119 |
120 |
def test_analysis_plan_results_to_dataframe():
    """to_dataframe produces one row whose cells mirror the fields."""
    kwargs = {
        "metric_alias": ["metric1"],
        "control_variant_name": ["Control"],
        "treatment_variant_name": ["Treatment"],
        "control_variant_mean": [0.5],
        "treatment_variant_mean": [0.6],
        "analysis_type": ["AB Test"],
        "ate": [0.1],
        "ate_ci_lower": [0.05],
        "ate_ci_upper": [0.15],
        "p_value": [0.04],
        "std_error": [0.02],
        "dimension_name": ["Country"],
        "dimension_value": ["US"],
        "alpha": [0.05],
    }
    results = AnalysisPlanResults(**kwargs)

    df = results.to_dataframe()

    assert isinstance(df, pd.DataFrame)
    assert df.shape[0] == 1  # a single entry yields a single row
    assert set(df.columns) == set(asdict(results).keys())  # columns match fields
    row = df.iloc[0]
    for field_name, values in kwargs.items():
        if field_name == "analysis_type":
            continue  # this cell's value is intentionally not pinned
        assert row[field_name] == values[0]
157 |
--------------------------------------------------------------------------------
/tests/inference/test_dimension.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cluster_experiments.inference.dimension import DefaultDimension, Dimension
4 |
5 |
def test_dimension_initialization():
    """A Dimension stores its name and values exactly as given."""
    dim = Dimension(name="Country", values=["US", "CA", "UK"])
    assert dim.name == "Country"
    assert dim.values == ["US", "CA", "UK"]
11 |
12 |
def test_dimension_name_type():
    """A non-string dimension name is rejected with TypeError."""
    with pytest.raises(TypeError, match="Dimension name must be a string"):
        Dimension(name=123, values=["US", "CA", "UK"])
17 |
18 |
def test_dimension_values_type():
    """Values must be a list whose elements are all strings."""
    expected_message = "Dimension values must be a list of strings"

    # Not a list at all.
    with pytest.raises(TypeError, match=expected_message):
        Dimension(name="Country", values="US, CA, UK")

    # A list containing a non-string element.
    with pytest.raises(TypeError, match=expected_message):
        Dimension(name="Country", values=["US", 123, "UK"])
30 |
31 |
def test_dimension_iterate_dimension_values():
    """iterate_dimension_values deduplicates while preserving first-seen order."""
    dim = Dimension(name="Country", values=["US", "CA", "US", "UK", "CA"])
    assert list(dim.iterate_dimension_values()) == ["US", "CA", "UK"]
37 |
38 |
def test_default_dimension_initialization():
    """DefaultDimension represents the 'total' pseudo-dimension."""
    default_dim = DefaultDimension()
    assert default_dim.name == "__total_dimension"
    assert default_dim.values == ["total"]
44 |
45 |
def test_default_dimension_iterate_dimension_values():
    """Iterating a DefaultDimension yields the single value 'total'."""
    assert list(DefaultDimension().iterate_dimension_values()) == ["total"]
51 |
--------------------------------------------------------------------------------
/tests/inference/test_metric.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | from cluster_experiments.inference.metric import Metric, RatioMetric, SimpleMetric
5 |
# Shared fixture frame used by the metric tests below.
sample_data = pd.DataFrame(
    {
        "salary": [50000, 60000, 70000],
        "numerator": [150000, 180000, 210000],
        "denominator": [3] * 3,
    }
)
14 |
15 |
def test_metric_abstract_instantiation():
    """The abstract Metric base class cannot be instantiated directly."""
    with pytest.raises(TypeError):
        Metric("test_metric")
20 |
21 |
def test_metric_alias_type():
    """A non-string alias is rejected with TypeError."""
    with pytest.raises(TypeError):
        SimpleMetric(123, "salary")  # alias must be a string
26 |
27 |
def test_simple_metric_initialization():
    """SimpleMetric exposes its alias and targets the named column."""
    metric = SimpleMetric("test_metric", "salary")
    assert metric.alias == "test_metric"
    assert metric.target_column == "salary"
33 |
34 |
def test_simple_metric_name_type():
    """A non-string column name is rejected with TypeError."""
    with pytest.raises(TypeError):
        SimpleMetric("test_metric", 123)  # name must be a string
39 |
40 |
def test_simple_metric_get_mean():
    """get_mean averages the metric's target column."""
    metric = SimpleMetric("test_metric", "salary")
    # Mean of [50000, 60000, 70000].
    assert metric.get_mean(sample_data) == 60000
46 |
47 |
def test_ratio_metric_initialization():
    """RatioMetric exposes its alias and targets its numerator column."""
    metric = RatioMetric("test_ratio_metric", "numerator", "denominator")
    assert metric.alias == "test_ratio_metric"
    assert metric.target_column == "numerator"
53 |
54 |
def test_ratio_metric_names_type():
    """Both numerator and denominator names must be strings."""
    bad_pairs = [
        ("numerator", 123),  # denominator must be a string
        (123, "denominator"),  # numerator must be a string
    ]
    for numerator, denominator in bad_pairs:
        with pytest.raises(TypeError):
            RatioMetric("test_ratio_metric", numerator, denominator)
65 |
66 |
def test_ratio_metric_get_mean():
    """get_mean divides the numerator mean by the denominator mean."""
    metric = RatioMetric("test_ratio_metric", "numerator", "denominator")
    expected = sample_data["numerator"].mean() / sample_data["denominator"].mean()
    assert metric.get_mean(sample_data) == expected
73 |
--------------------------------------------------------------------------------
/tests/inference/test_variant.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cluster_experiments.inference.variant import Variant
4 |
5 |
def test_variant_initialization():
    """A Variant stores its name and control flag."""
    variant = Variant(name="Test Variant", is_control=True)
    assert variant.name == "Test Variant"
    assert variant.is_control is True
11 |
12 |
def test_variant_name_type():
    """A non-string variant name is rejected with TypeError."""
    with pytest.raises(TypeError, match="Variant name must be a string"):
        Variant(name=123, is_control=True)
17 |
18 |
def test_variant_is_control_type():
    """A non-boolean is_control flag is rejected with TypeError."""
    with pytest.raises(TypeError, match="Variant is_control must be a boolean"):
        Variant(name="Test Variant", is_control="yes")
23 |
24 |
def test_variant_is_control_default_behavior():
    """A non-control variant keeps is_control False and its given name."""
    treatment_variant = Variant(name="Variant B", is_control=False)
    assert treatment_variant.name == "Variant B"
    assert treatment_variant.is_control is False
30 |
--------------------------------------------------------------------------------
/tests/perturbator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/perturbator/__init__.py
--------------------------------------------------------------------------------
/tests/perturbator/conftest.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 |
@pytest.fixture
def binary_df():
    """Minimal 4-row frame with a binary target and two treatment groups."""
    return pd.DataFrame(
        {"target": [0, 1, 0, 1], "treatment": ["A", "B", "B", "A"]}
    )
13 |
14 |
@pytest.fixture
def continuous_df():
    """4-row frame with a constant continuous target across both groups."""
    return pd.DataFrame(
        {"target": [0.5, 0.5, 0.5, 0.5], "treatment": ["A", "B", "B", "A"]}
    )
23 |
24 |
@pytest.fixture
def generate_clustered_data() -> pd.DataFrame:
    """Clustered analysis frame: country > city > user, four dates per country.

    Randomization is done at the user level, so every row of a given user
    carries the same treatment label.
    """
    countries = ["ES"] * 4 + ["IT"] * 4 + ["PL"] * 4 + ["RO"] * 4
    cities = (
        ["BCN", "BCN", "MAD", "BCN"] + ["NAP"] * 4 + ["WAW"] * 4 + ["BUC"] * 4
    )
    return pd.DataFrame(
        {
            "country_code": countries,
            "city_code": cities,
            "user_id": [1, 1, 2, 1, 3, 4, 5, 6, 7, 8, 8, 8, 9, 9, 9, 10],
            "date": ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"] * 4,
            # same letters as the user_id pattern above: one label per user
            "treatment": list("AABABBABBAAABBBA"),
            "target": [0.01] * 15 + [0.1],
        }
    )
58 |
59 |
@pytest.fixture
def continuous_mixed_df():
    """4-row frame mixing small and extreme continuous target values."""
    return pd.DataFrame(
        {"target": [0.5, -50, 50, 0.5], "treatment": ["A", "B", "B", "A"]}
    )
68 |
--------------------------------------------------------------------------------
/tests/power_analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/power_analysis/__init__.py
--------------------------------------------------------------------------------
/tests/power_analysis/conftest.py:
--------------------------------------------------------------------------------
1 | from datetime import date, timedelta
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 |
7 | from cluster_experiments.cupac import TargetAggregation
8 | from cluster_experiments.experiment_analysis import (
9 | ClusteredOLSAnalysis,
10 | GeeExperimentAnalysis,
11 | MLMExperimentAnalysis,
12 | )
13 | from cluster_experiments.perturbator import ConstantPerturbator
14 | from cluster_experiments.power_analysis import PowerAnalysis
15 | from cluster_experiments.random_splitter import (
16 | ClusteredSplitter,
17 | StratifiedSwitchbackSplitter,
18 | )
19 | from tests.utils import generate_random_data, generate_ratio_metric_data
20 |
21 | N = 1_000
22 |
23 |
@pytest.fixture
def clusters():
    """100 synthetic cluster labels: 'Cluster 0' … 'Cluster 99'."""
    return [f"Cluster {idx}" for idx in range(100)]
27 |
28 |
@pytest.fixture
def dates():
    """Every day of January 2022, formatted YYYY-MM-DD."""
    return [date(2022, 1, day).isoformat() for day in range(1, 32)]
32 |
33 |
@pytest.fixture
def experiment_dates():
    """The experiment period: Jan 15–31 2022, formatted YYYY-MM-DD."""
    return [date(2022, 1, day).isoformat() for day in range(15, 32)]
37 |
38 |
@pytest.fixture
def df(clusters, dates):
    """Random clustered dataset with N rows over the fixture clusters and dates."""
    return generate_random_data(clusters, dates, N)
42 |
43 |
@pytest.fixture
def correlated_df():
    """10k rows whose target correlates with cluster id and day of week."""
    n_rows = 10_000
    cluster_names = [f"Cluster {idx}" for idx in range(10)]
    date_names = [f"{date(2022, 1, day):%Y-%m-%d}" for day in range(1, 15)]
    frame = pd.DataFrame(
        {
            "cluster": np.random.choice(cluster_names, size=n_rows),
            "date": np.random.choice(date_names, size=n_rows),
        }
    )
    frame["cluster_id"] = frame["cluster"].astype("category").cat.codes
    frame["day_of_week"] = pd.to_datetime(frame["date"]).dt.dayofweek
    # target = cluster effect + weekday effect + gaussian noise
    frame["target"] = (
        frame["cluster_id"] + frame["day_of_week"] + np.random.normal(size=n_rows)
    )
    return frame
63 |
64 |
@pytest.fixture
def df_feats(clusters, dates):
    """Random dataset plus two iid standard-normal feature columns x1 and x2."""
    frame = generate_random_data(clusters, dates, N)
    for feature in ("x1", "x2"):
        frame[feature] = np.random.normal(0, 1, N)
    return frame
71 |
72 |
@pytest.fixture
def df_binary(clusters, dates):
    """Random clustered dataset with a binary target column."""
    return generate_random_data(clusters, dates, N, target="binary")
76 |
77 |
@pytest.fixture
def perturbator():
    """Constant additive perturbator with a 0.1 average effect."""
    return ConstantPerturbator(average_effect=0.1)
81 |
82 |
@pytest.fixture
def analysis_gee_vainilla():
    """GEE analysis without covariates.

    NOTE(review): 'vainilla' is a typo for 'vanilla', kept because tests
    reference the fixture by this name.
    """
    return GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"],
    )
88 |
89 |
@pytest.fixture
def analysis_clusterd_ols():
    """Cluster-robust OLS analysis over (cluster, date).

    NOTE(review): 'clusterd' is a typo for 'clustered', kept because renaming
    a fixture would break any test requesting it by name.
    """
    return ClusteredOLSAnalysis(
        cluster_cols=["cluster", "date"],
    )
95 |
96 |
@pytest.fixture
def analysis_mlm():
    """Mixed-linear-model analysis clustered by (cluster, date)."""
    return MLMExperimentAnalysis(
        cluster_cols=["cluster", "date"],
    )
102 |
103 |
@pytest.fixture
def analysis_gee():
    """GEE analysis adjusting for the cupac-produced 'estimate_target' covariate."""
    return GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"],
        covariates=["estimate_target"],
    )
110 |
111 |
@pytest.fixture
def cupac_power_analysis(perturbator, analysis_gee):
    """PowerAnalysis wired with a TargetAggregation cupac model."""
    return PowerAnalysis(
        perturbator=perturbator,
        splitter=ClusteredSplitter(cluster_cols=["cluster", "date"]),
        analysis=analysis_gee,
        cupac_model=TargetAggregation(agg_col="cluster"),
        n_simulations=3,
    )
129 |
130 |
@pytest.fixture
def switchback_power_analysis(perturbator, analysis_gee_vainilla):
    """Daily switchback power analysis stratified by cluster, with a fixed seed."""
    daily_splitter = StratifiedSwitchbackSplitter(
        time_col="date",
        switch_frequency="1D",
        strata_cols=["cluster"],
        cluster_cols=["cluster", "date"],
    )
    return PowerAnalysis(
        perturbator=perturbator,
        splitter=daily_splitter,
        analysis=analysis_gee_vainilla,
        n_simulations=3,
        seed=123,
    )
147 |
148 |
@pytest.fixture
def switchback_power_analysis_hourly(perturbator, analysis_gee_vainilla):
    """Hourly switchback power analysis stratified by cluster (unseeded)."""
    hourly_splitter = StratifiedSwitchbackSplitter(
        time_col="date",
        switch_frequency="1H",
        strata_cols=["cluster"],
        cluster_cols=["cluster", "date"],
    )
    return PowerAnalysis(
        perturbator=perturbator,
        splitter=hourly_splitter,
        analysis=analysis_gee_vainilla,
        n_simulations=3,
    )
164 |
165 |
@pytest.fixture
def switchback_washover():
    """Config-built switchback power analysis with a 2-hour constant washover."""
    config = {
        "time_col": "date",
        "switch_frequency": "1D",
        "perturbator": "constant",
        "analysis": "ols_clustered",
        "splitter": "switchback_balance",
        "cluster_cols": ["cluster", "date"],
        "strata_cols": ["cluster"],
        "washover": "constant_washover",
        "washover_time_delta": timedelta(hours=2),
    }
    return PowerAnalysis.from_dict(config)
181 |
182 |
@pytest.fixture
def delta_df(experiment_dates):
    """Ratio-metric dataset with zero treatment effect for delta-method tests.

    User-level target means are drawn from a normal distribution and then
    expanded into 50k observations across the experiment dates.
    """
    user_sample_mean = 0.3
    user_standard_error = 0.15
    users = 2000
    # local name: the original `N = 50_000` shadowed the module-level constant N
    n_observations = 50_000

    user_target_means = np.random.normal(user_sample_mean, user_standard_error, users)

    return generate_ratio_metric_data(
        experiment_dates, n_observations, user_target_means, users, treatment_effect=0
    )
197 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_cupac_power.py:
--------------------------------------------------------------------------------
1 | from sklearn.ensemble import HistGradientBoostingRegressor
2 |
3 |
def test_power_analyis_aggregate(df, experiment_dates, cupac_power_analysis):
    """Cupac power analysis over an experiment/pre-experiment split is a probability."""
    in_experiment = df["date"].isin(experiment_dates)
    power = cupac_power_analysis.power_analysis(df[in_experiment], df[~in_experiment])
    assert 0 <= power <= 1
10 |
11 |
def test_add_covariates(df, experiment_dates, cupac_power_analysis):
    """add_covariates yields a complete estimate bounded by pre-period targets."""
    in_experiment = df["date"].isin(experiment_dates)
    df_analysis, df_pre = df[in_experiment], df[~in_experiment]
    estimated_target = cupac_power_analysis.cupac_handler.add_covariates(
        df_analysis, df_pre
    )["estimate_target"]
    assert estimated_target.notnull().all()
    assert estimated_target.between(df_pre["target"].min(), df_pre["target"].max()).all()
    assert "estimate_target" in cupac_power_analysis.analysis.covariates
22 |
23 |
def test_prep_data(df_feats, experiment_dates, cupac_power_analysis):
    """_prep_data_cupac selects only the configured feature columns, row-aligned."""
    data = df_feats.copy()
    in_experiment = data["date"].isin(experiment_dates)
    df_analysis, df_pre = data[in_experiment], data[~in_experiment]
    handler = cupac_power_analysis.cupac_handler
    handler.features_cupac_model = ["x1", "x2"]
    df_predict, pre_experiment_x, pre_experiment_y = handler._prep_data_cupac(
        df_analysis, df_pre
    )
    assert list(df_predict.columns) == ["x1", "x2"]
    assert list(pre_experiment_x.columns) == ["x1", "x2"]
    assert (df_predict["x1"] == df_analysis["x1"]).all()
    assert (pre_experiment_x["x1"] == df_pre["x1"]).all()
    assert (pre_experiment_y == df_pre["target"]).all()
39 |
40 |
def test_cupac_gbm(df_feats, experiment_dates, cupac_power_analysis):
    """Swapping in a gradient-boosting cupac model still yields a probability."""
    data = df_feats.copy()
    in_experiment = data["date"].isin(experiment_dates)
    cupac_power_analysis.features_cupac_model = ["x1", "x2"]
    cupac_power_analysis.cupac_model = HistGradientBoostingRegressor()
    power = cupac_power_analysis.power_analysis(data[in_experiment], data[~in_experiment])
    assert 0 <= power <= 1
50 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_multivariate.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cluster_experiments import (
4 | ConstantPerturbator,
5 | NonClusteredSplitter,
6 | OLSAnalysis,
7 | PowerAnalysis,
8 | )
9 | from tests.utils import generate_non_clustered_data
10 |
11 |
@pytest.fixture
def df():
    """1000 non-clustered rows spread over 100 users."""
    return generate_non_clustered_data(N=1000, n_users=100)
18 |
19 |
@pytest.fixture
def perturbator():
    """Constant perturbator with no preset effect (effect passed per call)."""
    return ConstantPerturbator()
23 |
24 |
@pytest.fixture
def ols():
    """Plain (non-clustered) OLS analysis with default columns."""
    return OLSAnalysis()
28 |
29 |
@pytest.fixture
def binary_hypothesis_power(perturbator, ols):
    """Power analysis with a simple two-variant (A/B) non-clustered splitter."""
    return PowerAnalysis(
        splitter=NonClusteredSplitter(
            treatments=["A", "B"], treatment_col="treatment"
        ),
        perturbator=perturbator,
        analysis=ols,
    )
41 |
42 |
@pytest.fixture
def multivariate_hypothesis_power(perturbator, ols):
    """Power analysis with seven variants (A–G) on a non-clustered splitter."""
    return PowerAnalysis(
        splitter=NonClusteredSplitter(
            treatments=list("ABCDEFG"), treatment_col="treatment"
        ),
        perturbator=perturbator,
        analysis=ols,
    )
54 |
55 |
@pytest.fixture
def binary_hypothesis_power_config():
    """Config-built power analysis with the default two variants."""
    return PowerAnalysis.from_dict(
        {
            "analysis": "ols_non_clustered",
            "perturbator": "constant",
            "splitter": "non_clustered",
            "n_simulations": 50,
            "seed": 220924,
        }
    )
66 |
67 |
@pytest.fixture
def multivariate_hypothesis_power_config():
    """Config-built power analysis with seven treatments A–G."""
    return PowerAnalysis.from_dict(
        {
            "analysis": "ols_non_clustered",
            "perturbator": "constant",
            "splitter": "non_clustered",
            "n_simulations": 50,
            "treatments": ["A", "B", "C", "D", "E", "F", "G"],
            "seed": 220924,
        }
    )
79 |
80 |
def test_higher_power_analysis(
    multivariate_hypothesis_power,
    binary_hypothesis_power,
    df,
):
    """Testing across more variants lowers power versus a simple A/B."""
    power_multi = multivariate_hypothesis_power.power_analysis(df, average_effect=0.1)
    power_binary = binary_hypothesis_power.power_analysis(df, average_effect=0.1)
    assert power_multi < power_binary, f"{power_multi = } > {power_binary = }"
89 |
90 |
def test_higher_power_analysis_config(
    multivariate_hypothesis_power_config,
    binary_hypothesis_power_config,
    df,
):
    """Same multi-variant power penalty holds for config-built objects."""
    power_multi = multivariate_hypothesis_power_config.power_analysis(
        df, average_effect=0.1
    )
    power_binary = binary_hypothesis_power_config.power_analysis(df, average_effect=0.1)
    assert power_multi < power_binary, f"{power_multi = } > {power_binary = }"
101 |
102 |
def test_raise_if_control_not_in_treatments(
    perturbator,
    ols,
):
    """Construction must fail when the control label is absent from treatments.

    Each `pytest.raises` block wraps both the splitter construction and the
    PowerAnalysis construction; the assertion may fire in either step.
    """
    # "a" != default control "A": lowercase treatment list excludes the control
    with pytest.raises(AssertionError):
        splitter = NonClusteredSplitter(
            treatments=["a", "B"],
            treatment_col="treatment",
            splitter_weights=[0.5, 0.5],
        )
        PowerAnalysis(
            splitter=splitter,
            perturbator=perturbator,
            analysis=ols,
        )
    # explicit control "X" not among the splitter's default treatments
    with pytest.raises(AssertionError):
        splitter = NonClusteredSplitter(
            treatment_col="treatment",
            splitter_weights=[0.5, 0.5],
        )
        PowerAnalysis(
            splitter=splitter, perturbator=perturbator, analysis=ols, control="X"
        )
126 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_parallel.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cluster_experiments.power_analysis import PowerAnalysis
4 | from cluster_experiments.power_config import PowerConfig
5 |
6 |
def test_raise_n_jobs(df):
    """n_jobs must be -1 or a positive integer; 0 and other negatives raise."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            cluster_cols=["cluster", "date"],
            analysis="gee",
            perturbator="constant",
            splitter="clustered",
            n_simulations=4,
        )
    )
    for bad_n_jobs in (0, -10):
        with pytest.raises(ValueError):
            pw.power_analysis(df, average_effect=0.0, n_jobs=bad_n_jobs)
20 |
21 |
def test_similar_n_jobs(df):
    """Parallel runs (2 workers, all cores) stay close to the serial estimate."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            analysis="ols_non_clustered",
            perturbator="constant",
            splitter="non_clustered",
            n_simulations=100,
            seed=123,
        )
    )
    serial_power = pw.power_analysis(df, average_effect=0.0, n_jobs=1)
    for parallel_n_jobs in (2, -1):
        parallel_power = pw.power_analysis(df, average_effect=0.0, n_jobs=parallel_n_jobs)
        assert abs(serial_power - parallel_power) <= 0.1
36 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_power_analysis.py:
--------------------------------------------------------------------------------
1 | from cluster_experiments.power_analysis import PowerAnalysis
2 | from cluster_experiments.power_config import PowerConfig
3 | from cluster_experiments.random_splitter import ClusteredSplitter
4 |
5 |
def test_power_analysis(df, perturbator, analysis_gee_vainilla):
    """Object-based construction: the estimated power is a probability."""
    splitter = ClusteredSplitter(cluster_cols=["cluster", "date"])
    pw = PowerAnalysis(
        perturbator=perturbator,
        splitter=splitter,
        analysis=analysis_gee_vainilla,
        n_simulations=3,
    )
    assert 0 <= pw.power_analysis(df) <= 1
21 |
22 |
def test_power_analysis_config(df):
    """PowerConfig with a preset average_effect yields a probability."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            cluster_cols=["cluster", "date"],
            analysis="gee",
            perturbator="constant",
            splitter="clustered",
            n_simulations=4,
            average_effect=0.0,
        )
    )
    assert 0 <= pw.power_analysis(df) <= 1
36 |
37 |
def test_power_analysis_config_avg_effect(df):
    """The effect size may be supplied at call time instead of in the config."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            cluster_cols=["cluster", "date"],
            analysis="gee",
            perturbator="constant",
            splitter="clustered",
            n_simulations=4,
        )
    )
    assert 0 <= pw.power_analysis(df, average_effect=0.0) <= 1
50 |
51 |
def test_power_analysis_dict(df):
    """from_dict construction works, with and without verbose output."""
    pw = PowerAnalysis.from_dict(
        dict(
            analysis="ols_non_clustered",
            perturbator="constant",
            splitter="non_clustered",
            n_simulations=4,
        )
    )
    assert 0 <= pw.power_analysis(df, average_effect=0.0) <= 1
    assert 0 <= pw.power_analysis(df, verbose=True, average_effect=0.0) <= 1
67 |
68 |
def test_different_names(df):
    """Non-default column names work end to end through the config."""
    renamed = df.rename(
        columns={"cluster": "cluster_0", "target": "target_0", "date": "date_0"}
    )
    pw = PowerAnalysis.from_dict(
        dict(
            cluster_cols=["cluster_0", "date_0"],
            analysis="ols_clustered",
            perturbator="constant",
            splitter="clustered",
            n_simulations=4,
            treatment_col="treatment_0",
            target_col="target_0",
        )
    )
    assert 0 <= pw.power_analysis(renamed, average_effect=0.0) <= 1
    assert 0 <= pw.power_analysis(renamed, verbose=True, average_effect=0.0) <= 1
94 |
95 |
def test_ttest(df):
    """Clustered t-test analysis runs, quietly and verbosely."""
    pw = PowerAnalysis.from_dict(
        dict(
            cluster_cols=["cluster", "date"],
            analysis="ttest_clustered",
            perturbator="constant",
            splitter="clustered",
            n_simulations=4,
        )
    )
    assert 0 <= pw.power_analysis(df, average_effect=0.0) <= 1
    assert 0 <= pw.power_analysis(df, verbose=True, average_effect=0.0) <= 1
112 |
113 |
def test_paired_ttest(df):
    """Paired clustered t-test (stratified by cluster) runs, quietly and verbosely."""
    pw = PowerAnalysis.from_dict(
        dict(
            cluster_cols=["cluster", "date"],
            strata_cols=["cluster"],
            analysis="paired_ttest_clustered",
            perturbator="constant",
            splitter="clustered",
            n_simulations=4,
        )
    )
    assert 0 <= pw.power_analysis(df, average_effect=0.0) <= 1
    assert 0 <= pw.power_analysis(df, verbose=True, average_effect=0.0) <= 1
132 |
133 |
def test_delta(delta_df):
    """Delta-method analysis runs on ratio-metric data without a treatment column."""
    pw = PowerAnalysis.from_dict(
        dict(
            cluster_cols=["user", "date"],
            scale_col="scale",
            analysis="delta",
            perturbator="constant",
            splitter="clustered",
            n_simulations=4,
        )
    )
    # drop the precomputed assignment: the splitter must create its own
    unassigned = delta_df.drop(columns=["treatment"])
    assert 0 <= pw.power_analysis(unassigned, average_effect=0.0) <= 1
150 |
151 |
def test_power_alpha(df):
    """A looser significance level (alpha=0.5) yields more power than alpha=0.01."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            analysis="ols_non_clustered",
            perturbator="constant",
            splitter="non_clustered",
            n_simulations=10,
            average_effect=0.0,
            alpha=0.05,
        )
    )
    power_50 = pw.power_analysis(df, alpha=0.5, verbose=True)
    power_01 = pw.power_analysis(df, alpha=0.01)
    assert power_50 > power_01
166 |
167 |
def test_length_simulation(df):
    """simulate_pvalue yields exactly the requested number of p-values."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            cluster_cols=["cluster", "date"],
            analysis="ols_clustered",
            perturbator="constant",
            splitter="clustered",
            n_simulations=10,
            average_effect=0.0,
            alpha=0.05,
        )
    )
    n_pvalues = sum(1 for _ in pw.simulate_pvalue(df, n_simulations=5))
    assert n_pvalues == 5
183 |
184 |
def test_point_estimate_gee(df):
    """GEE point estimate is positive under a +5.0 injected effect."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            cluster_cols=["cluster", "date"],
            analysis="gee",
            perturbator="constant",
            splitter="clustered",
            n_simulations=10,
            average_effect=5.0,
            alpha=0.05,
        )
    )
    for point_estimate in pw.simulate_point_estimate(df, n_simulations=1):
        assert point_estimate > 0.0
198 |
199 |
def test_point_estimate_clustered_ols(df):
    """Clustered-OLS point estimate is positive under a +5.0 injected effect."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            cluster_cols=["cluster", "date"],
            analysis="ols_clustered",
            perturbator="constant",
            splitter="clustered",
            n_simulations=10,
            average_effect=5.0,
            alpha=0.05,
        )
    )
    for point_estimate in pw.simulate_point_estimate(df, n_simulations=1):
        assert point_estimate > 0.0
213 |
214 |
def test_point_estimate_ols(df):
    """Plain OLS point estimate is positive under a +5.0 injected effect."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            analysis="ols_non_clustered",
            perturbator="constant",
            splitter="non_clustered",
            n_simulations=10,
            average_effect=5.0,
            alpha=0.05,
        )
    )
    for point_estimate in pw.simulate_point_estimate(df, n_simulations=1):
        assert point_estimate > 0.0
227 |
228 |
def test_power_line(df):
    """power_line maps each effect to a power; bigger effect, no less power."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            analysis="ols_non_clustered",
            perturbator="constant",
            splitter="non_clustered",
            n_simulations=10,
            average_effect=0.0,
            alpha=0.05,
        )
    )
    power_line = pw.power_line(df, average_effects=[0.0, 1.0], n_simulations=10)
    assert len(power_line) == 2
    assert power_line[0.0] >= 0
    assert power_line[1.0] >= power_line[0.0]
244 |
245 |
def test_running_power(df):
    """The running power estimate stabilizes after enough simulations."""
    pw = PowerAnalysis.from_config(
        PowerConfig(
            analysis="ols_non_clustered",
            perturbator="constant",
            splitter="non_clustered",
            n_simulations=50,
            average_effect=0.0,
            alpha=0.05,
        )
    )
    previous_power = 0
    for iteration, power in enumerate(
        pw.running_power_analysis(df, average_effect=0.1)
    ):
        if iteration > 20:
            # after a burn-in, consecutive estimates may move at most 0.05
            assert abs(power - previous_power) <= 0.05
        previous_power = power
264 |
265 |
def test_hypothesis_from_dict(df):
    """For a positive effect, the 'less' hypothesis has the lowest power."""
    base_config = dict(
        analysis="ols_non_clustered",
        perturbator="constant",
        splitter="non_clustered",
        n_simulations=20,
    )
    powers = {}
    for hypothesis in ("less", "greater", "two-sided"):
        pw = PowerAnalysis.from_dict({**base_config, "hypothesis": hypothesis})
        powers[hypothesis] = pw.power_analysis(df, average_effect=1.0)

    assert powers["less"] < powers["greater"]
    assert powers["less"] < powers["two-sided"]
293 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_power_analysis_with_pre_experiment_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from cluster_experiments.experiment_analysis import SyntheticControlAnalysis
4 | from cluster_experiments.perturbator import ConstantPerturbator
5 | from cluster_experiments.power_analysis import PowerAnalysisWithPreExperimentData
6 | from cluster_experiments.random_splitter import FixedSizeClusteredSplitter
7 | from cluster_experiments.synthetic_control_utils import generate_synthetic_control_data
8 |
9 |
def test_power_analysis_with_pre_experiment_data():
    """Synthetic-control power analysis returns probabilities for single effects
    and for a power line.

    Fix: the original called `pw.power_line(df, average_effects=[0.3, 0.4])`
    twice, discarding the first (50-simulation) result; it is now called once.
    """
    df = generate_synthetic_control_data(10, "2022-01-01", "2022-01-30")

    splitter = FixedSizeClusteredSplitter(n_treatment_clusters=1, cluster_cols=["user"])
    perturbator = ConstantPerturbator(average_effect=0.3)
    analysis = SyntheticControlAnalysis(
        cluster_cols=["user"], time_col="date", intervention_date="2022-01-15"
    )
    pw = PowerAnalysisWithPreExperimentData(
        perturbator=perturbator, splitter=splitter, analysis=analysis, n_simulations=50
    )

    power = pw.power_analysis(df)
    assert 0 <= power <= 1

    power_line_values = pw.power_line(df, average_effects=[0.3, 0.4]).values()
    assert all(0 <= value <= 1 for value in power_line_values)
32 |
33 |
def test_simulate_point_estimate():
    """The mean simulated point estimate recovers the injected effect (~10)."""
    df = generate_synthetic_control_data(10, "2022-01-01", "2022-01-30")

    splitter = FixedSizeClusteredSplitter(n_treatment_clusters=1, cluster_cols=["user"])
    perturbator = ConstantPerturbator(average_effect=10)
    analysis = SyntheticControlAnalysis(
        cluster_cols=["user"], time_col="date", intervention_date="2022-01-15"
    )
    pw = PowerAnalysisWithPreExperimentData(
        perturbator=perturbator, splitter=splitter, analysis=analysis, n_simulations=50
    )

    mean_estimate = pd.Series(list(pw.simulate_point_estimate(df))).mean()
    assert 8 <= mean_estimate <= 11, "Point estimate is too far from the real effect."
55 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_power_raises.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cluster_experiments.experiment_analysis import GeeExperimentAnalysis
4 | from cluster_experiments.perturbator import ConstantPerturbator
5 | from cluster_experiments.power_analysis import PowerAnalysis
6 | from cluster_experiments.random_splitter import ClusteredSplitter, NonClusteredSplitter
7 |
8 |
def test_raises_cupac():
    """A cupac model in the config must be rejected when unsupported by the setup."""
    config = {
        "cluster_cols": ["cluster", "date"],
        "analysis": "gee",
        "perturbator": "constant",
        "splitter": "clustered",
        "cupac_model": "mean_cupac_model",
        "n_simulations": 4,
    }
    with pytest.raises(AssertionError):
        PowerAnalysis.from_dict(config)
20 |
21 |
def test_data_checks(df):
    """A boolean target column must be rejected by the pre-run data checks."""
    pw = PowerAnalysis.from_dict(
        {
            "cluster_cols": ["cluster", "date"],
            "analysis": "gee",
            "perturbator": "constant",
            "splitter": "clustered",
            "n_simulations": 4,
        }
    )
    df["target"] = df["target"] == 1  # booleans are not a valid target dtype
    with pytest.raises(ValueError):
        pw.power_analysis(df, average_effect=0.0)
34 |
35 |
def test_raise_target():
    """Mismatched target_col between perturbator and analysis must be rejected."""
    splitter = ClusteredSplitter(cluster_cols=["cluster", "date"])
    perturbator = ConstantPerturbator(average_effect=0.1, target_col="another_target")
    analysis = GeeExperimentAnalysis(cluster_cols=["cluster", "date"])
    with pytest.raises(AssertionError):
        PowerAnalysis(
            perturbator=perturbator,
            splitter=splitter,
            analysis=analysis,
            n_simulations=3,
        )
57 |
58 |
def test_raise_treatment():
    """A perturbator treatment label unknown to the splitter must be rejected."""
    splitter = ClusteredSplitter(cluster_cols=["cluster", "date"])
    perturbator = ConstantPerturbator(average_effect=0.1, treatment="C")
    analysis = GeeExperimentAnalysis(cluster_cols=["cluster", "date"])
    with pytest.raises(AssertionError):
        PowerAnalysis(
            perturbator=perturbator,
            splitter=splitter,
            analysis=analysis,
            n_simulations=3,
        )
77 |
78 |
def test_raise_treatment_col():
    """A perturbator treatment_col mismatch must be rejected at construction."""
    splitter = ClusteredSplitter(cluster_cols=["cluster", "date"])
    perturbator = ConstantPerturbator(
        average_effect=0.1, treatment_col="another_treatment"
    )
    analysis = GeeExperimentAnalysis(cluster_cols=["cluster", "date"])
    with pytest.raises(AssertionError):
        PowerAnalysis(
            perturbator=perturbator,
            splitter=splitter,
            analysis=analysis,
            n_simulations=3,
        )
100 |
101 |
def test_raise_treatment_col_2():
    """An analysis treatment_col mismatch must also be rejected at construction."""
    splitter = ClusteredSplitter(cluster_cols=["cluster", "date"])
    perturbator = ConstantPerturbator(average_effect=0.1)
    analysis = GeeExperimentAnalysis(
        cluster_cols=["cluster", "date"],
        treatment_col="another_treatment",
    )
    with pytest.raises(AssertionError):
        PowerAnalysis(
            perturbator=perturbator,
            splitter=splitter,
            analysis=analysis,
            n_simulations=3,
        )
123 |
124 |
def test_raise_cluster_cols():
    """Splitter/analysis cluster_cols disagreement must be rejected."""
    splitter = ClusteredSplitter(cluster_cols=["cluster"])  # missing "date"
    perturbator = ConstantPerturbator(average_effect=0.1, target_col="another_target")
    analysis = GeeExperimentAnalysis(cluster_cols=["cluster", "date"])
    with pytest.raises(AssertionError):
        PowerAnalysis(
            perturbator=perturbator,
            splitter=splitter,
            analysis=analysis,
            n_simulations=3,
        )
146 |
147 |
def test_raise_clustering_mismatch():
    """A non-clustered splitter with a clustered analysis must be rejected."""
    splitter = NonClusteredSplitter()
    perturbator = ConstantPerturbator(average_effect=0.1, target_col="another_target")
    analysis = GeeExperimentAnalysis(cluster_cols=["cluster", "date"])
    with pytest.raises(AssertionError):
        PowerAnalysis(
            perturbator=perturbator,
            splitter=splitter,
            analysis=analysis,
            n_simulations=3,
        )
167 |
168 |
def test_raise_treatment_same_control():
    """treatment and control must not share the same label."""
    splitter = ClusteredSplitter(cluster_cols=["cluster", "date"])
    perturbator = ConstantPerturbator(average_effect=0.1)
    analysis = GeeExperimentAnalysis(cluster_cols=["cluster", "date"])
    with pytest.raises(AssertionError):
        PowerAnalysis(
            perturbator=perturbator,
            splitter=splitter,
            analysis=analysis,
            treatment="A",
            control="A",  # identical to treatment: invalid
            n_simulations=3,
        )
185 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_seed.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from cluster_experiments.power_analysis import PowerAnalysis
4 |
5 |
def get_config(perturbator: str) -> dict:
    """Shared seeded power-analysis config; only the perturbator name varies."""
    base = {
        "cluster_cols": ["cluster"],
        "analysis": "ols_clustered",
        "splitter": "clustered",
        "n_simulations": 15,
        "seed": 123,
    }
    return {**base, "perturbator": perturbator}
15 |
16 |
def test_power_analysis_constant_perturbator_seed(df):
    """With a fixed seed, repeated constant-perturbator runs give identical power."""
    config_dict = get_config("constant")
    powers = [
        PowerAnalysis.from_dict(config_dict).power_analysis(df, average_effect=10)
        for _ in range(10)
    ]
    assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10)
26 |
27 |
def test_power_analysis_binary_perturbator_seed(df_binary):
    """With a fixed seed, repeated binary-perturbator runs give identical power."""
    config_dict = get_config("binary")
    powers = [
        PowerAnalysis.from_dict(config_dict).power_analysis(df_binary, average_effect=0.08)
        for _ in range(10)
    ]
    assert np.isclose(np.var(np.asarray(powers)), 0, atol=1e-10)
37 |
--------------------------------------------------------------------------------
/tests/power_analysis/test_switchback_power.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 |
7 | from cluster_experiments.power_analysis import PowerAnalysis
8 | from cluster_experiments.washover import ConstantWashover
9 |
10 |
def test_switchback(switchback_power_analysis, df):
    """Smoke test: switchback power estimate is a valid probability."""
    power = switchback_power_analysis.power_analysis(
        df,
        average_effect=0.0,
        verbose=True,
    )
    assert 0 <= power <= 1
19 |
20 |
def test_switchback_hour(switchback_power_analysis, df):
    """Switchback power analysis with timestamps spread over a 24-hour window."""
    # Uniform random epoch seconds over one day: 1624646400..1624732800,
    # i.e. 2021-06-25 18:40 to 2021-06-26 18:40 UTC.
    epoch_seconds = np.random.randint(1624646400, 1624732800, size=len(df))
    df["date"] = pd.to_datetime(epoch_seconds, unit="s")

    power = switchback_power_analysis.power_analysis(
        df,
        average_effect=0.0,
        verbose=True,
    )
    assert 0 <= power <= 1
38 |
39 |
def test_switchback_washover(switchback_power_analysis, df):
    """A near-total washover cannot increase the estimated power."""

    def run_power() -> float:
        # Same call before and after attaching the washover.
        return switchback_power_analysis.power_analysis(
            df, average_effect=0.1, n_simulations=10
        )

    power_no_washover = run_power()

    # Wash out 23 hours of every (daily) switch window.
    switchback_power_analysis.splitter.washover = ConstantWashover(
        washover_time_delta=datetime.timedelta(hours=23)
    )
    power = run_power()

    assert 0 <= power <= 1
    assert power_no_washover >= power
58 |
59 |
def test_raise_no_delta():
    """A constant washover configured without a time delta must raise ValueError."""
    config = {
        "time_col": "date",
        "switch_frequency": "1D",
        "perturbator": "constant",
        "analysis": "ols_clustered",
        "splitter": "switchback_balance",
        "cluster_cols": ["cluster", "date"],
        "strata_cols": ["cluster"],
        "washover": "constant_washover",
    }
    with pytest.raises(ValueError):
        PowerAnalysis.from_dict(config)
74 |
75 |
def test_switchback_washover_config(switchback_washover, df):
    """Config-built washover power analysis yields a valid probability."""
    power = switchback_washover.power_analysis(
        df,
        average_effect=0.1,
        n_simulations=10,
    )
    assert 0 <= power <= 1
84 |
85 |
def test_switchback_strata():
    """End-to-end power analysis with a stratified 30-minute switchback."""
    config = {
        "time_col": "time",
        "switch_frequency": "30min",
        "perturbator": "constant",
        "analysis": "ols_clustered",
        "splitter": "switchback_stratified",
        "cluster_cols": ["time", "city"],
        "strata_cols": ["day_of_week", "hour_of_day", "city"],
        "target_col": "y",
        "n_simulations": 3,
    }
    power_analysis = PowerAnalysis.from_dict(config)

    # 10 days of minute-level data, replicated across three cities.
    np.random.seed(42)
    minutes = pd.date_range("2021-01-01", "2021-01-10 23:59", freq="1T")
    df_raw = pd.DataFrame({"time": minutes, "y": np.random.randn(len(minutes))})
    df_raw["day_of_week"] = df_raw["time"].dt.dayofweek
    df_raw["hour_of_day"] = df_raw["time"].dt.hour
    df = pd.concat([df_raw.assign(city=city) for city in ("TGN", "NYC", "LON")])

    pw = power_analysis.power_analysis(df, average_effect=0.1)
    assert 0 <= pw <= 1
115 |
--------------------------------------------------------------------------------
/tests/power_config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/power_config/__init__.py
--------------------------------------------------------------------------------
/tests/power_config/test_missing_arguments_error.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from cluster_experiments.power_config import MissingArgumentError, PowerConfig
4 |
5 |
def test_missing_argument_segment_cols():
    """segmented_beta_relative perturbator requires segment_cols in PowerConfig.

    Fix: the original requested the ``caplog`` fixture but never used it;
    the unused parameter is removed.
    """
    msg = "segment_cols is required when using perturbator = segmented_beta_relative."
    with pytest.raises(MissingArgumentError, match=msg):
        PowerConfig(
            cluster_cols=["cluster"],
            analysis="mlm",
            perturbator="segmented_beta_relative",
            splitter="non_clustered",
            n_simulations=4,
            average_effect=1.5,
        )
17 |
--------------------------------------------------------------------------------
/tests/power_config/test_params_flow.py:
--------------------------------------------------------------------------------
1 | from cluster_experiments.power_analysis import NormalPowerAnalysis
2 |
3 |
def test_cov_type_flows():
    """cov_type given in the config must propagate to the analysis object."""
    power_analysis = NormalPowerAnalysis.from_dict(
        {
            "analysis": "ols_non_clustered",
            "perturbator": "constant",
            "splitter": "non_clustered",
            "cov_type": "HC1",
        }
    )
    assert power_analysis.analysis.cov_type == "HC1"
18 |
19 |
def test_cov_type_default():
    """When cov_type is omitted, the analysis defaults to HC3."""
    power_analysis = NormalPowerAnalysis.from_dict(
        {
            "analysis": "ols_non_clustered",
            "perturbator": "constant",
            "splitter": "non_clustered",
        }
    )
    assert power_analysis.analysis.cov_type == "HC3"
33 |
34 |
def test_covariate_interaction_flows():
    """add_covariate_interaction from the config must reach the analysis object."""
    power_analysis = NormalPowerAnalysis.from_dict(
        {
            "analysis": "ols_non_clustered",
            "perturbator": "constant",
            "splitter": "non_clustered",
            "add_covariate_interaction": True,
        }
    )
    assert power_analysis.analysis.add_covariate_interaction is True
49 |
--------------------------------------------------------------------------------
/tests/power_config/test_warnings_superfluous_params.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from cluster_experiments.power_config import PowerConfig
4 |
5 |
def test_config_warning_superfluous_param_switch_frequency(caplog):
    """PowerConfig warns that switch_frequency is ignored for non-switchback splitters."""
    msg = "switch_frequency = 1H has no effect with splitter = non_clustered. Overriding switch_frequency to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            switch_frequency="1H",
        )
    assert msg in caplog.text
19 |
20 |
def test_config_warning_superfluous_param_washover_time_delta(caplog):
    """PowerConfig warns that washover_time_delta is ignored without a switchback splitter."""
    msg = "washover_time_delta = 30 has no effect with splitter = non_clustered. Overriding washover_time_delta to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            washover_time_delta=30,
        )
    assert msg in caplog.text
34 |
35 |
def test_config_warning_superfluous_param_washover(caplog):
    """PowerConfig warns that washover is ignored without a switchback splitter."""
    msg = "washover = constant_washover has no effect with splitter = non_clustered. Overriding washover to ."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            washover="constant_washover",
        )
    assert msg in caplog.text
49 |
50 |
def test_config_warning_superfluous_param_time_col(caplog):
    """PowerConfig warns that time_col is ignored without a switchback splitter."""
    msg = "time_col = datetime has no effect with splitter = non_clustered. Overriding time_col to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            time_col="datetime",
        )
    assert msg in caplog.text
64 |
65 |
def test_config_warning_superfluous_param_perturbator(caplog):
    """PowerConfig warns that scale is ignored by the constant perturbator."""
    msg = "scale = 0.5 has no effect with perturbator = constant. Overriding scale to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            scale=0.5,
        )
    assert msg in caplog.text
79 |
80 |
def test_config_warning_superfluous_param_strata_cols(caplog):
    """PowerConfig warns that strata_cols are ignored by non-stratified splitters."""
    msg = "strata_cols = ['group'] has no effect with splitter = non_clustered. Overriding strata_cols to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            strata_cols=["group"],
        )
    assert msg in caplog.text
94 |
95 |
def test_config_warning_superfluous_param_splitter_weights(caplog):
    """PowerConfig warns that splitter_weights are ignored by the stratified splitter."""
    msg = "splitter_weights = [0.5, 0.5] has no effect with splitter = clustered_stratified. Overriding splitter_weights to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="clustered_stratified",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            splitter_weights=[0.5, 0.5],
        )
    assert msg in caplog.text
109 |
110 |
def test_config_warning_superfluous_param_agg_col(caplog):
    """PowerConfig warns that agg_col is ignored when no cupac model is set."""
    msg = "agg_col = agg_col has no effect with cupac_model = . Overriding agg_col to ."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            agg_col="agg_col",
        )
    assert msg in caplog.text
124 |
125 |
def test_config_warning_superfluous_param_smoothing_factor(caplog):
    """PowerConfig warns that smoothing_factor is ignored when no cupac model is set."""
    msg = "smoothing_factor = 0.5 has no effect with cupac_model = . Overriding smoothing_factor to 20."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_non_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            smoothing_factor=0.5,
        )
    assert msg in caplog.text
139 |
140 |
def test_config_warning_superfluous_param_features_cupac_model(caplog):
    """PowerConfig warns that cupac features are ignored when no cupac model is set."""
    msg = "features_cupac_model = ['feature1'] has no effect with cupac_model = . Overriding features_cupac_model to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ols_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            features_cupac_model=["feature1"],
        )
    assert msg in caplog.text
154 |
155 |
def test_config_warning_superfluous_param_covariates(caplog):
    """PowerConfig warns that covariates are ignored by the clustered t-test analysis."""
    msg = "covariates = ['covariate1'] has no effect with analysis = ttest_clustered. Overriding covariates to None."
    with caplog.at_level(logging.WARNING):
        PowerConfig(
            splitter="non_clustered",
            analysis="ttest_clustered",
            perturbator="constant",
            cluster_cols=["cluster", "date"],
            average_effect=1.5,
            n_simulations=4,
            covariates=["covariate1"],
        )
    assert msg in caplog.text
169 |
--------------------------------------------------------------------------------
/tests/splitter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/tests/splitter/__init__.py
--------------------------------------------------------------------------------
/tests/splitter/conftest.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 |
7 | from cluster_experiments.power_analysis import PowerAnalysis
8 | from cluster_experiments.random_splitter import (
9 | BalancedSwitchbackSplitter,
10 | StratifiedSwitchbackSplitter,
11 | SwitchbackSplitter,
12 | )
13 |
14 |
@pytest.fixture
def switchback_splitter():
    """Daily switchback splitter clustering on the time column only."""
    return SwitchbackSplitter(
        cluster_cols=["time"], time_col="time", switch_frequency="1D"
    )
20 |
21 |
@pytest.fixture
def switchback_splitter_config():
    """Same splitter as `switchback_splitter`, but built via PowerAnalysis config."""
    power = PowerAnalysis.from_dict(
        {
            "time_col": "time",
            "switch_frequency": "1D",
            "perturbator": "constant",
            "analysis": "ols_clustered",
            "splitter": "switchback",
            "cluster_cols": ["time"],
        }
    )
    return power.splitter
33 |
34 |
# Run a test once with the directly-built splitter and once with the
# config-built one.
switchback_splitter_parametrize = pytest.mark.parametrize(
    "splitter", ["switchback_splitter", "switchback_splitter_config"]
)
42 |
43 |
@pytest.fixture
def balanced_splitter():
    """Daily balanced switchback splitter clustering on the time column only."""
    return BalancedSwitchbackSplitter(
        cluster_cols=["time"], time_col="time", switch_frequency="1D"
    )
49 |
50 |
@pytest.fixture
def balanced_splitter_config():
    """Same splitter as `balanced_splitter`, but built via PowerAnalysis config."""
    power = PowerAnalysis.from_dict(
        {
            "time_col": "time",
            "switch_frequency": "1D",
            "perturbator": "constant",
            "analysis": "ols_clustered",
            "splitter": "switchback_balance",
            "cluster_cols": ["time"],
        }
    )
    return power.splitter
62 |
63 |
# Run a test with both the direct and the config-built balanced splitter.
balanced_splitter_parametrize = pytest.mark.parametrize(
    "splitter", ["balanced_splitter", "balanced_splitter_config"]
)
71 |
72 |
@pytest.fixture
def stratified_switchback_splitter():
    """Daily switchback splitter stratified by day of week."""
    return StratifiedSwitchbackSplitter(
        cluster_cols=["time"],
        strata_cols=["day_of_week"],
        time_col="time",
        switch_frequency="1D",
    )
81 |
82 |
@pytest.fixture
def stratified_switchback_splitter_config():
    """Same splitter as `stratified_switchback_splitter`, via PowerAnalysis config."""
    power = PowerAnalysis.from_dict(
        {
            "time_col": "time",
            "switch_frequency": "1D",
            "perturbator": "constant",
            "analysis": "ols_clustered",
            "splitter": "switchback_stratified",
            "cluster_cols": ["time"],
            "strata_cols": ["day_of_week"],
        }
    )
    return power.splitter
95 |
96 |
# Run a test with both the direct and the config-built stratified splitter.
stratified_splitter_parametrize = pytest.mark.parametrize(
    "splitter",
    ["stratified_switchback_splitter", "stratified_switchback_splitter_config"],
)
104 |
105 |
@pytest.fixture
def date_df():
    """Ten daily timestamps with random targets."""
    days = pd.date_range("2020-01-01", "2020-01-10", freq="1D")
    return pd.DataFrame({"time": days, "y": np.random.randn(len(days))})
114 |
115 |
@pytest.fixture
def biweekly_df():
    """14 days x 10 clusters, with day-of-week names for stratification."""
    days = pd.date_range("2020-01-01", "2020-01-14", freq="1D")
    base = pd.DataFrame({"time": days, "y": np.random.randn(len(days))})
    base["day_of_week"] = base["time"].dt.day_name()
    return pd.concat([base.assign(cluster=f"Cluster {i}") for i in range(10)])
127 |
128 |
@pytest.fixture
def washover_df():
    """Random intra-day minute timestamps, replicated across four cities."""
    n = 7 * 24 * 60
    minutes = pd.date_range("2021-01-01", "2021-01-02", freq="1min")
    base = pd.DataFrame(
        {
            # Sample n random minutes within the first day.
            "time": minutes[np.random.randint(24 * 60, size=n)],
            "y": np.random.randn(n),
        }
    )
    base["day_of_week"] = base["time"].dt.dayofweek
    base["hour_of_day"] = base["time"].dt.hour
    return pd.concat(
        [base.assign(city=city) for city in ("TGN", "NYC", "LON", "REU")]
    )
145 |
146 |
@pytest.fixture
def washover_base_df():
    """Four rows in one city around a single A->B switch at 01:00."""
    raw_times = pd.to_datetime(
        [
            "2022-01-01 00:20:00",
            "2022-01-01 00:31:00",
            "2022-01-01 01:14:00",
            "2022-01-01 01:31:00",
        ]
    )
    df = pd.DataFrame(
        {
            "original___time": raw_times,
            "treatment": ["A", "A", "B", "B"],
            "city": ["TGN"] * 4,
        }
    )
    # Hour-truncated switchback period.
    df["time"] = df["original___time"].dt.floor("1h")
    return df
162 |
163 |
@pytest.fixture
def washover_df_no_switch():
    """Six rows where the 01:00->02:00 transition has no treatment switch."""
    raw_times = pd.to_datetime(
        [
            "2022-01-01 00:20:00",
            "2022-01-01 00:31:00",
            "2022-01-01 01:14:00",
            "2022-01-01 01:31:00",
            "2022-01-01 02:01:00",
            "2022-01-01 02:31:00",
        ]
    )
    df = pd.DataFrame(
        {
            "original___time": raw_times,
            "treatment": ["A", "A", "B", "B", "B", "B"],
            "city": ["TGN"] * 6,
        }
    )
    # Hour-truncated switchback period.
    df["time"] = df["original___time"].dt.floor("1h")
    return df
181 |
182 |
@pytest.fixture
def washover_df_multi_city():
    """Two cities sharing the same timestamps but different switch patterns."""
    raw_times = pd.to_datetime(
        [
            "2022-01-01 01:14:00",
            "2022-01-01 00:20:00",
            "2022-01-01 00:31:00",
            "2022-01-01 02:01:00",
            "2022-01-01 01:31:00",
            "2022-01-01 02:31:00",
        ]
    )
    df = pd.DataFrame(
        {
            "original___time": list(raw_times) * 2,
            "treatment": ["B", "A", "A", "B", "B", "B"]
            + ["A", "A", "A", "B", "A", "B"],
            "city": ["TGN"] * 6 + ["BCN"] * 6,
        }
    )
    # Hour-truncated switchback period.
    df["time"] = df["original___time"].dt.floor("1h")
    return df
202 |
203 |
@pytest.fixture
def washover_split_df(n):
    """n rows with random minute-level times in 2022-01-01 and random cities."""
    minutes = pd.date_range("2022-01-01", "2022-01-02", freq="1min")
    return pd.DataFrame(
        {
            "time": minutes[np.random.randint(24 * 60, size=n)],
            "city": random.choices(["TGN", "NYC", "LON", "REU"], k=n),
        }
    )
216 |
217 |
@pytest.fixture
def washover_split_no_city_df(n):
    """n rows with random minute-level times in 2022-01-01 and no city column."""
    minutes = pd.date_range("2022-01-01", "2022-01-02", freq="1min")
    return pd.DataFrame(
        {
            "time": minutes[np.random.randint(24 * 60, size=n)],
        }
    )
229 |
--------------------------------------------------------------------------------
/tests/splitter/test_fixed_size_clusters_splitter.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments import ConstantPerturbator, PowerAnalysis
7 | from cluster_experiments.experiment_analysis import ClusteredOLSAnalysis
8 | from cluster_experiments.random_splitter import (
9 | FixedSizeClusteredSplitter,
10 | )
11 |
12 |
def test_predefined_treatment_clusters_splitter():
    """With 1 of 5 clusters treated, 2 of the 10 rows get the treatment label."""
    df = pd.DataFrame({"cluster": ["A", "A", "B", "B", "C", "C", "D", "D", "E", "E"]})
    splitter = FixedSizeClusteredSplitter(
        cluster_cols=["cluster"], n_treatment_clusters=1
    )

    assigned = splitter.assign_treatment_df(df)

    counts = assigned[splitter.treatment_col].value_counts()
    assert counts[splitter.treatments[0]] == 8
    assert counts[splitter.treatments[1]] == 2
24 |
25 |
def test_sample_treatment_with_balanced_clusters():
    """sample_treatment returns one label per cluster, split 2/2."""
    splitter = FixedSizeClusteredSplitter(cluster_cols=["city"], n_treatment_clusters=2)
    df = pd.DataFrame({"city": list("ABCD")})
    treatments = splitter.sample_treatment(df)
    assert len(treatments) == len(df)
    assert treatments.count("A") == 2
    assert treatments.count("B") == 2
33 |
34 |
def generate_data(N, start_date, end_date):
    """Cross N users with daily dates in [start_date, end_date] and add a noisy target.

    Returns a DataFrame with columns ``user``, ``date`` and ``target``, where
    target ~ Normal(100, 10).
    """
    dates = pd.date_range(start_date, end_date, freq="d")
    users = [f"User {i}" for i in range(N)]

    df = pd.DataFrame(list(product(users, dates)), columns=["user", "date"])
    df["target"] = np.random.normal(100, 10, size=len(df))
    return df
48 |
49 |
def test_ols_fixed_size_treatment():
    """Power analysis runs with a single treated cluster out of 100.

    Fix: the original discarded the result of ``power_analysis`` and asserted
    nothing. With zero effect the estimate is a false-positive rate; it is only
    bounded loosely here because clustered OLS inference with a single treated
    cluster is known to be fragile (see the TODO below).
    """
    df = generate_data(100, "2021-01-01", "2021-01-15")

    analysis = ClusteredOLSAnalysis(cluster_cols=["user"])
    splitter = FixedSizeClusteredSplitter(n_treatment_clusters=1, cluster_cols=["user"])
    perturbator = ConstantPerturbator(
        average_effect=0,
    )

    pw = PowerAnalysis(
        perturbator=perturbator, splitter=splitter, analysis=analysis, n_simulations=200
    )
    power = pw.power_analysis(df, average_effect=0)
    # TODO: tighten this bound; with no effect, power should stay near alpha.
    assert 0 <= power <= 1
66 |
--------------------------------------------------------------------------------
/tests/splitter/test_switchback_splitter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | from cluster_experiments.experiment_analysis import OLSAnalysis
5 | from cluster_experiments.perturbator import ConstantPerturbator
6 | from cluster_experiments.power_analysis import PowerAnalysis
7 | from cluster_experiments.random_splitter import SwitchbackSplitter
8 | from tests.splitter.conftest import (
9 | balanced_splitter_parametrize,
10 | stratified_splitter_parametrize,
11 | switchback_splitter_parametrize,
12 | )
13 |
# Run each test with and without an extra "cluster" entry in cluster_cols.
add_cluster_cols_parametrize = pytest.mark.parametrize(
    "add_cluster_cols", [True, False]
)
17 |
18 |
@switchback_splitter_parametrize
def test_switchback_splitter(splitter, date_df, request):
    """Each switchback period receives exactly one treatment."""
    sb_splitter = request.getfixturevalue(splitter)

    assigned = sb_splitter.assign_treatment_df(date_df)
    assert "time" in sb_splitter.cluster_cols

    per_period = assigned.groupby("time")["treatment"].nunique()
    assert (per_period == 1).all()
28 |
29 |
@switchback_splitter_parametrize
def test_clustered_switchback_splitter(splitter, biweekly_df, request):
    """With (cluster, time) clustering, each pair has a single treatment."""
    sb_splitter = request.getfixturevalue(splitter)
    long_df = pd.concat([biweekly_df] * 3)
    sb_splitter.cluster_cols = ["cluster", "time"]

    assigned = sb_splitter.assign_treatment_df(long_df)
    assert "time" in sb_splitter.cluster_cols

    per_cluster = assigned.groupby(["cluster", "time"])["treatment"].nunique()
    assert (per_cluster == 1).all()
42 |
43 |
@balanced_splitter_parametrize
@add_cluster_cols_parametrize
def test_clustered_switchback_splitter_balance(
    splitter, add_cluster_cols, biweekly_df, request
):
    """The balanced splitter assigns exactly half the rows to each treatment."""
    bal_splitter = request.getfixturevalue(splitter)

    if add_cluster_cols:
        bal_splitter.cluster_cols += ["cluster"]

    assigned = bal_splitter.assign_treatment_df(biweekly_df)
    assert "time" in bal_splitter.cluster_cols
    # 140 rows total (14 days x 10 clusters) -> 70 per treatment.
    assert (assigned.treatment.value_counts() == 70).all()
57 |
58 |
@stratified_splitter_parametrize
@add_cluster_cols_parametrize
def test_stratified_splitter(splitter, add_cluster_cols, biweekly_df, request):
    """The stratified splitter balances treatments within every stratum."""
    strat_splitter = request.getfixturevalue(splitter)

    if add_cluster_cols:
        strat_splitter.cluster_cols += ["cluster"]
        strat_splitter.strata_cols += ["cluster"]

    assigned = strat_splitter.assign_treatment_df(biweekly_df)
    assert "time" in strat_splitter.cluster_cols
    # 140 rows total -> 70 per treatment.
    assert (assigned.treatment.value_counts() == 70).all()

    # Both treatments appear within every cluster and every day of week.
    for col in ("cluster", "day_of_week"):
        assert (assigned.groupby([col])["treatment"].nunique() == 2).all()

    # Stratification: each (stratum value, treatment) cell has equal size.
    for col in ("cluster", "day_of_week"):
        assert assigned.groupby([col, "treatment"]).size().nunique() == 1
81 |
82 |
def test_raise_time_col_not_in_df():
    """PowerAnalysis validates that time_col is part of cluster_cols."""
    with pytest.raises(
        AssertionError,
        match="in switchback splitters, time_col must be in cluster_cols",
    ):
        PowerAnalysis(
            splitter=SwitchbackSplitter(time_col="time"),
            perturbator=ConstantPerturbator(),
            analysis=OLSAnalysis(),
        )
96 |
97 |
def test_raise_time_col_not_in_df_splitter():
    """assign_treatment_df also validates time_col membership in cluster_cols."""
    with pytest.raises(
        AssertionError,
        match="in switchback splitters, time_col must be in cluster_cols",
    ):
        data = pd.DataFrame(
            {
                "activation_time": pd.date_range(
                    start="2021-01-01", periods=10, freq="D"
                ),
                "city": ["A"] * 10,
            }
        )
        # time_col is deliberately absent from cluster_cols.
        splitter = SwitchbackSplitter(
            time_col="activation_time",
            cluster_cols=["city"],
            switch_frequency="6h",
        )
        splitter.assign_treatment_df(data)
121 |
--------------------------------------------------------------------------------
/tests/splitter/test_time_col.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 |
5 | from tests.splitter.conftest import switchback_splitter_parametrize
6 |
7 |
@pytest.fixture
def yearly_df():
    """One row per day of 2021 (365 rows)."""
    days = pd.date_range("2021-01-01", "2021-12-31", freq="1D")
    return pd.DataFrame({"time": days, "y": np.random.randn(len(days))})
16 |
17 |
@pytest.fixture
def hourly_df():
    """One row per hour over 10 days (240 rows)."""
    hours = pd.date_range("2021-01-01", "2021-01-10 23:00", freq="1H")
    return pd.DataFrame({"time": hours, "y": np.random.randn(len(hours))})
26 |
27 |
@pytest.fixture
def minute_df():
    """One row per minute over 10 days (14400 rows)."""
    minutes = pd.date_range("2021-01-01", "2021-01-10 23:59", freq="1T")
    return pd.DataFrame({"time": minutes, "y": np.random.randn(len(minutes))})
38 |
39 |
@pytest.mark.parametrize(
    "df,switchback_freq,n_splits",
    [
        ("date_df", "1D", 10),
        ("date_df", "2D", 5),
        ("date_df", "4D", 3),
        ("hourly_df", "H", 240),
        ("hourly_df", "2H", 120),
        ("hourly_df", "4H", 60),
        ("hourly_df", "6H", 40),
        ("hourly_df", "12H", 20),
        ("minute_df", "min", 14400),
        ("minute_df", "2min", 7200),
        ("minute_df", "4min", 3600),
        ("minute_df", "30min", 480),
        ("minute_df", "60min", 240),
        ("yearly_df", "W", 53),
        ("yearly_df", "M", 12),
    ],
)
@switchback_splitter_parametrize
def test_date_col(splitter, df, switchback_freq, n_splits, request):
    """_get_time_col_cluster truncates timestamps to the switch frequency."""
    data = request.getfixturevalue(df)
    sb_splitter = request.getfixturevalue(splitter)
    sb_splitter.switch_frequency = switchback_freq

    time_col = sb_splitter._get_time_col_cluster(data)
    assert time_col.dtype == "datetime64[ns]"
    assert time_col.nunique() == n_splits

    # Weekly/monthly frequencies use period starts rather than flooring.
    if "W" not in switchback_freq and "M" not in switchback_freq:
        pd.testing.assert_series_equal(
            time_col, data["time"].dt.floor(switchback_freq)
        )
73 |
74 |
@pytest.mark.parametrize(
    "switchback_freq,day_of_week",
    [
        ("W-MON", "Tuesday"),
        ("W-TUE", "Wednesday"),
        ("W-SUN", "Monday"),
        ("W", "Monday"),
    ],
)
@switchback_splitter_parametrize
def test_week_col_date(splitter, date_df, switchback_freq, day_of_week, request):
    """Weekly frequencies map timestamps to the correct week start."""
    sb_splitter = request.getfixturevalue(splitter)
    sb_splitter.switch_frequency = switchback_freq

    time_col = sb_splitter._get_time_col_cluster(date_df)
    assert time_col.dtype == "datetime64[ns]"
    pd.testing.assert_series_equal(
        time_col, date_df["time"].dt.to_period(switchback_freq).dt.start_time
    )
    # 10 days span exactly two weekly periods.
    assert time_col.nunique() == 2
    first_day = time_col.dt.day_name().iloc[0]
    assert (
        first_day == day_of_week
    ), f"{switchback_freq} failed, day_of_week is {first_day}"
98 |
--------------------------------------------------------------------------------
/tests/splitter/test_washover.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from datetime import timedelta
3 |
4 | import pandas as pd
5 | import pytest
6 |
7 | from cluster_experiments import SwitchbackSplitter
8 | from cluster_experiments.washover import ConstantWashover, EmptyWashover
9 |
10 |
@pytest.mark.parametrize("minutes, n_rows", [(30, 2), (10, 4), (15, 3)])
def test_constant_washover_base(minutes, n_rows, washover_base_df):
    """Rows within `minutes` of a treatment switch are dropped."""
    washover = ConstantWashover(washover_time_delta=timedelta(minutes=minutes))
    out_df = washover.washover(
        df=washover_base_df,
        truncated_time_col="time",
        cluster_cols=["city", "time"],
        treatment_col="treatment",
        original_time_col="original___time",
    )

    assert len(out_df) == n_rows
    # Every surviving row lies strictly after the washover window.
    assert (out_df["original___time"].dt.minute > minutes).all()
23 |
24 |
@pytest.mark.parametrize(
    "minutes, n_rows, df",
    [
        (30, 4, "washover_df_no_switch"),
        (30, 4 + 4, "washover_df_multi_city"),
    ],
)
def test_constant_washover_no_switch(minutes, n_rows, df, request):
    """Rows following a period with no treatment switch are kept."""
    washover_df = request.getfixturevalue(df)

    washover = ConstantWashover(washover_time_delta=timedelta(minutes=minutes))
    out_df = washover.washover(
        df=washover_df,
        truncated_time_col="time",
        cluster_cols=["city", "time"],
        treatment_col="treatment",
    )
    assert len(out_df) == n_rows
    if df == "washover_df_no_switch":
        after_two = "time >= '2022-01-01 02:00:00'"
        after_one = "time >= '2022-01-01 01:00:00'"
        # No switch at 02:00, so every row from then on survives...
        assert washover_df.query(after_two).equals(out_df.query(after_two))
        # ...but the 01:00 period saw a switch, so some rows were dropped.
        assert not washover_df.query(after_one).equals(out_df.query(after_one))
53 |
54 |
@pytest.mark.parametrize(
    "minutes, n,",
    [
        (15, 10000),
    ],
)
def test_constant_washover_split(minutes, n, washover_split_df):
    """Washover removes some — but fewer than half — rows during assignment."""
    splitter = SwitchbackSplitter(
        washover=ConstantWashover(washover_time_delta=timedelta(minutes=minutes)),
        time_col="time",
        cluster_cols=["city", "time"],
        treatment_col="treatment",
        switch_frequency="30T",
    )

    out_df = splitter.assign_treatment_df(df=washover_split_df)

    assert set(out_df["treatment"].unique()) == {"A", "B"}
    # Strictly fewer rows than the input, but more than half remain
    # (no washover happens when the treatment repeats, e.g. A-B-B).
    assert n / 2 < len(out_df) < n
82 |
83 |
@pytest.mark.parametrize(
    "minutes, n,",
    [
        (15, 1000),
    ],
)
def test_constant_washover_split_no_city(minutes, n, washover_split_no_city_df):
    """Washover also works when clustering on time alone."""
    splitter = SwitchbackSplitter(
        washover=ConstantWashover(washover_time_delta=timedelta(minutes=minutes)),
        time_col="time",
        cluster_cols=["time"],
        treatment_col="treatment",
        switch_frequency="30T",
    )

    out_df = splitter.assign_treatment_df(df=washover_split_no_city_df)

    assert set(out_df["treatment"].unique()) == {"A", "B"}
    # Strictly fewer rows than the input, but more than half remain
    # (no washover happens when the treatment repeats, e.g. A-B-B).
    assert n / 2 < len(out_df) < n
111 |
112 |
@pytest.mark.parametrize(
    "minutes, n",
    [
        (15, 1000),
    ],
)
def test_no_washover_split(minutes, n, washover_split_df):
    """EmptyWashover keeps every row during treatment assignment."""
    splitter = SwitchbackSplitter(
        washover=EmptyWashover(),
        time_col="time",
        cluster_cols=["city", "time"],
        treatment_col="treatment",
        switch_frequency="30T",
    )

    out_df = splitter.assign_treatment_df(df=washover_split_df)

    assert set(out_df["treatment"].unique()) == {"A", "B"}
    # Nothing is washed out: row count is unchanged.
    assert len(out_df) == n
137 |
138 |
@pytest.mark.parametrize(
    "minutes, n_rows, df",
    [
        (30, 4, "washover_df_no_switch"),
        (30, 4 + 4, "washover_df_multi_city"),
    ],
)
def test_constant_washover_no_switch_instantiated_int(minutes, n_rows, df, request):
    """ConstantWashover.from_config accepts an integer washover_time_delta (minutes)."""
    washover_df = request.getfixturevalue(df)

    @dataclass
    class IntDeltaConfig:
        washover_time_delta: int

    washed = ConstantWashover.from_config(IntDeltaConfig(minutes)).washover(
        df=washover_df,
        truncated_time_col="time",
        cluster_cols=["city", "time"],
        treatment_col="treatment",
    )
    assert len(washed) == n_rows
    if df == "washover_df_no_switch":
        # Rows at or after the 02:00 switch must be identical to the input:
        # the washover never touches them.
        assert washover_df.query("time >= '2022-01-01 02:00:00'").equals(
            washed.query("time >= '2022-01-01 02:00:00'")
        )
        # Rows right after the 01:00 switch are affected, so the frames differ.
        assert not washover_df.query("time >= '2022-01-01 01:00:00'").equals(
            washed.query("time >= '2022-01-01 01:00:00'")
        )
172 |
173 |
def test_truncated_time_not_in_cluster_cols():
    """Washover raises when truncated_time_col is missing from cluster_cols."""
    input_df = pd.DataFrame(columns=["time_bin", "city", "time", "treatment"])

    # "time_bin" is deliberately left out of cluster_cols, which must raise.
    with pytest.raises(ValueError, match="is not in the cluster columns."):
        ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover(
            df=input_df,
            truncated_time_col="time_bin",
            cluster_cols=["city"],
            original_time_col="time",
            treatment_col="treatment",
        )
189 |
190 |
def test_missing_original_time_col():
    """Washover raises when original_time_col is neither passed nor in the frame."""
    input_df = pd.DataFrame(columns=["time_bin", "city", "treatment"])

    # The original time column ("time") is neither specified as an input nor
    # present in the dataframe, so the washover must raise.
    with pytest.raises(ValueError, match="columns and/or not specified as an input."):
        ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover(
            df=input_df,
            truncated_time_col="time_bin",
            cluster_cols=["city", "time_bin"],
            treatment_col="treatment",
        )
205 |
206 |
def test_cluster_cols_missing_in_df():
    """Washover raises when a cluster column is absent from the dataframe."""
    input_df = pd.DataFrame(columns=["time_bin", "time", "treatment"])

    # "city" is listed in cluster_cols but missing from the dataframe columns,
    # so the washover must raise.
    with pytest.raises(ValueError, match="cluster is not in the dataframe columns."):
        ConstantWashover(washover_time_delta=timedelta(minutes=30)).washover(
            df=input_df,
            truncated_time_col="time_bin",
            cluster_cols=["city", "time_bin"],
            original_time_col="time",
            treatment_col="treatment",
        )
222 |
--------------------------------------------------------------------------------
/tests/test_docs.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from mktestdocs import check_docstring, check_md_file, get_codeblock_members
3 |
4 | from cluster_experiments import (
5 | AnalysisPlan,
6 | BalancedClusteredSplitter,
7 | BalancedSwitchbackSplitter,
8 | BetaRelativePerturbator,
9 | BetaRelativePositivePerturbator,
10 | BinaryPerturbator,
11 | ClusteredOLSAnalysis,
12 | ClusteredSplitter,
13 | ConstantPerturbator,
14 | ConstantWashover,
15 | DeltaMethodAnalysis,
16 | Dimension,
17 | EmptyRegressor,
18 | EmptyWashover,
19 | ExperimentAnalysis,
20 | FixedSizeClusteredSplitter,
21 | GeeExperimentAnalysis,
22 | HypothesisTest,
23 | Metric,
24 | MLMExperimentAnalysis,
25 | NonClusteredSplitter,
26 | NormalPerturbator,
27 | NormalPowerAnalysis,
28 | OLSAnalysis,
29 | PairedTTestClusteredAnalysis,
30 | Perturbator,
31 | PowerAnalysis,
32 | PowerConfig,
33 | RandomSplitter,
34 | RatioMetric,
35 | RelativeMixedPerturbator,
36 | RelativePositivePerturbator,
37 | RepeatedSampler,
38 | SegmentedBetaRelativePerturbator,
39 | SimpleMetric,
40 | StratifiedClusteredSplitter,
41 | StratifiedSwitchbackSplitter,
42 | SwitchbackSplitter,
43 | SyntheticControlAnalysis,
44 | TargetAggregation,
45 | TTestClusteredAnalysis,
46 | UniformPerturbator,
47 | Variant,
48 | )
49 | from cluster_experiments.utils import _original_time_column
50 |
# Public objects of cluster_experiments whose docstring code examples are
# executed by the tests below (see `all_members`).
all_objects = [
    BalancedClusteredSplitter,
    BinaryPerturbator,
    ClusteredOLSAnalysis,
    ClusteredSplitter,
    DeltaMethodAnalysis,
    EmptyRegressor,
    ExperimentAnalysis,
    FixedSizeClusteredSplitter,
    GeeExperimentAnalysis,
    NonClusteredSplitter,
    OLSAnalysis,
    Perturbator,
    PowerAnalysis,
    NormalPowerAnalysis,
    PowerConfig,
    RandomSplitter,
    StratifiedClusteredSplitter,
    SyntheticControlAnalysis,
    TargetAggregation,
    TTestClusteredAnalysis,
    PairedTTestClusteredAnalysis,
    ConstantPerturbator,
    UniformPerturbator,
    _original_time_column,
    ConstantWashover,
    EmptyWashover,
    BalancedSwitchbackSplitter,
    StratifiedSwitchbackSplitter,
    SwitchbackSplitter,
    RepeatedSampler,
    MLMExperimentAnalysis,
    RelativePositivePerturbator,
    NormalPerturbator,
    BetaRelativePositivePerturbator,
    BetaRelativePerturbator,
    SegmentedBetaRelativePerturbator,
    AnalysisPlan,
    Metric,
    SimpleMetric,
    RatioMetric,
    Dimension,
    Variant,
    HypothesisTest,
    RelativeMixedPerturbator,
]
97 |
98 |
def flatten(items):
    """Flatten one level of nesting: a list of lists becomes a single list."""
    flat = []
    for sublist in items:
        flat.extend(sublist)
    return flat
102 |
103 |
# Collect every documented member of the objects above; each item is a method
# (or similar attribute) that can carry a docstring with a code example.
all_members = flatten([get_codeblock_members(o) for o in all_objects])
107 |
108 |
@pytest.mark.parametrize("func", all_members, ids=lambda d: d.__qualname__)
def test_function_docstrings(func):
    """Execute the python example embedded in this member's docstring."""
    check_docstring(obj=func)
113 |
114 |
@pytest.mark.parametrize(
    "fpath",
    [
        "README.md",
    ],
)
def test_quickstart_docs_file(fpath):
    """Run the code blocks of the quickstart markdown file, sharing state between them."""
    check_md_file(fpath, memory=True)
124 |
--------------------------------------------------------------------------------
/tests/test_non_clustered.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import pytest
5 | from sklearn.model_selection import train_test_split
6 |
7 | from cluster_experiments.cupac import TargetAggregation
8 | from cluster_experiments.experiment_analysis import OLSAnalysis
9 | from cluster_experiments.perturbator import ConstantPerturbator
10 | from cluster_experiments.power_analysis import PowerAnalysis
11 | from cluster_experiments.random_splitter import NonClusteredSplitter
12 | from tests.utils import generate_non_clustered_data
13 |
N = 10_000  # rows generated per test dataframe
n_users = 1000  # distinct users in the generated data
random.seed(41)  # keep splitter randomization reproducible across runs
17 |
18 |
@pytest.fixture
def df():
    """Non-clustered dataframe with `target` and `user` columns."""
    return generate_non_clustered_data(N, n_users)
22 |
23 |
@pytest.fixture
def df_feats():
    """Non-clustered dataframe with two extra standard-normal covariates."""
    df = generate_non_clustered_data(N, n_users)
    df["x1"] = np.random.normal(0, 1, N)
    df["x2"] = np.random.normal(0, 1, N)
    return df
30 |
31 |
@pytest.fixture
def cupac_power_analysis():
    """Power analysis wired with a TargetAggregation cupac model."""
    return PowerAnalysis(
        perturbator=ConstantPerturbator(average_effect=0.1),
        splitter=NonClusteredSplitter(),
        analysis=OLSAnalysis(covariates=["estimate_target"]),
        cupac_model=TargetAggregation(agg_col="user"),
        n_simulations=3,
    )
55 |
56 |
@pytest.fixture
def cupac_from_config():
    """Cupac power analysis built from a plain config dict."""
    config = {
        "analysis": "ols_non_clustered",
        "perturbator": "constant",
        "splitter": "non_clustered",
        "cupac_model": "mean_cupac_model",
        "average_effect": 0.1,
        "n_simulations": 4,
        "covariates": ["estimate_target"],
        "agg_col": "user",
    }
    return PowerAnalysis.from_dict(config)
71 |
72 |
def test_power_analysis(cupac_power_analysis, df):
    """Power estimated with a cupac model is a valid probability."""
    pre_experiment_df, experiment_df = train_test_split(df)
    power = cupac_power_analysis.power_analysis(experiment_df, pre_experiment_df)
    assert 0 <= power <= 1
78 |
79 |
def test_power_analysis_config(cupac_from_config, df):
    """Power estimated from a config-built analysis is a valid probability."""
    pre_experiment_df, experiment_df = train_test_split(df)
    power = cupac_from_config.power_analysis(experiment_df, pre_experiment_df)
    assert 0 <= power <= 1
85 |
86 |
def test_splitter(df):
    """NonClusteredSplitter assigns roughly half the rows to each variant."""
    assignments = NonClusteredSplitter().assign_treatment_df(df)
    n_a = assignments.treatment.value_counts()["A"]
    # Within 200 rows of a perfect 50/50 split.
    assert abs(n_a - len(assignments) / 2) <= 200
94 |
95 |
def test_splitter_weighted(df):
    """A [0.1, 0.9]-weighted splitter sends roughly 10% of rows to variant A."""
    assignments = NonClusteredSplitter(splitter_weights=[0.1, 0.9]).assign_treatment_df(
        df
    )
    n_a = assignments.treatment.value_counts()["A"]
    # Within 100 rows of the expected 10% share.
    assert abs(n_a - len(assignments) * 0.1) <= 100
103 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from cluster_experiments.synthetic_control_utils import loss_w
5 | from cluster_experiments.utils import _get_mapping_key
6 |
7 |
def test_get_mapping_key():
    """_get_mapping_key raises KeyError for a key absent from the mapping."""
    # Build the mapping outside the raises-block so that only the lookup
    # itself can satisfy the expected exception.
    mapping = {"a": 1, "b": 2}
    with pytest.raises(KeyError):
        _get_mapping_key(mapping, "c")
12 |
13 |
def test_loss_w():
    """loss_w returns the RMSE between the predictions X @ W and y."""
    weights = np.array([2, 1])
    features = np.array([[1, 2], [3, 4], [5, 6]])
    actuals = np.array([6, 14, 22])

    # Predictions are X @ W = [4, 10, 16], so the residuals are [2, 4, 6]
    # and the expected RMSE is sqrt(mean([4, 16, 36])).
    predictions = features @ weights
    expected_rmse = np.sqrt(np.mean((actuals - predictions) ** 2))

    calculated_rmse = loss_w(weights, features, actuals)

    assert np.isclose(
        calculated_rmse, expected_rmse
    ), f"Expected RMSE: {expected_rmse}, but got: {calculated_rmse}"
34 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from cluster_experiments.random_splitter import RandomSplitter
7 |
# Factories for the synthetic `target` column, keyed by target type; each
# callable takes the number of samples to draw.
TARGETS = {
    "binary": lambda x: np.random.choice([0, 1], size=x),
    "continuous": lambda x: np.random.normal(0, 1, x),
}
12 |
13 |
def combine_columns(df, cols_list):
    """Join the given columns row-wise with "-" so multi-column strata or
    clusters can be compared as a single key; a single column is returned
    unchanged (as a one-column frame)."""
    if len(cols_list) <= 1:
        return df[cols_list]
    return df[cols_list].agg("-".join, axis=1)
20 |
21 |
def assert_balanced_strata(
    splitter: RandomSplitter,
    df: pd.DataFrame,
    strata_cols: List[str],
    cluster_cols: List[str],
    treatments: List[str],
):
    """Assert that, within every stratum, each treatment receives an equal
    share of the deduplicated clusters."""
    assigned = splitter.assign_treatment_df(df)

    unique_assignments = assigned[
        strata_cols + cluster_cols + ["treatment"]
    ].drop_duplicates()

    unique_assignments["clusters_test"] = combine_columns(
        unique_assignments, cluster_cols
    )
    unique_assignments["strata_test"] = combine_columns(
        unique_assignments, strata_cols
    )

    expected_share = 1 / len(treatments)
    for treatment in treatments:
        for stratum in unique_assignments["strata_test"].unique():
            in_stratum = unique_assignments.query(f"strata_test == '{stratum}'")
            observed_share = (
                in_stratum["treatment"]
                .value_counts(normalize=True)[treatment]
                .squeeze()
            )
            assert observed_share == expected_share
50 |
51 |
def generate_random_data(clusters, dates, N, target="continuous"):
    """Build an N-row dataframe with random cluster, target, user and date.

    `target` selects the generator from TARGETS ("continuous" or "binary").
    """
    user_pool = [f"User {i}" for i in range(1000)]

    target_values = TARGETS[target](N)
    return pd.DataFrame(
        {
            "cluster": np.random.choice(clusters, size=N),
            "target": target_values,
            "user": np.random.choice(user_pool, size=N),
            "date": np.random.choice(dates, size=N),
        }
    )
67 |
68 |
def generate_non_clustered_data(N, n_users):
    """N-row dataframe with a standard-normal `target` and a `user` drawn
    uniformly from n_users distinct labels."""
    user_pool = [f"User {i}" for i in range(n_users)]
    return pd.DataFrame(
        {
            "target": np.random.normal(0, 1, size=N),
            "user": np.random.choice(user_pool, size=N),
        }
    )
78 |
79 |
def generate_clustered_data() -> pd.DataFrame:
    """Fixed 16-row dataframe clustered by country, city, user and date.

    Randomization is done at user level, so the same user always carries the
    same treatment; every target is 0.01 except the last row (0.1).
    """
    countries = [code for code in ("ES", "IT", "PL", "RO") for _ in range(4)]
    cities = ["BCN", "BCN", "MAD", "BCN"] + ["NAP"] * 4 + ["WAW"] * 4 + ["BUC"] * 4
    return pd.DataFrame(
        {
            "country_code": countries,
            "city_code": cities,
            "user_id": [1, 1, 2, 1, 3, 4, 5, 6, 7, 8, 8, 8, 9, 9, 9, 10],
            "date": ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"] * 4,
            "treatment": list("AABABBABBAAABBBA"),
            "target": [0.01] * 15 + [0.1],
        }
    )
112 |
113 |
def generate_ratio_metric_data(
    dates,
    N,
    user_target_means: Optional[np.ndarray] = None,
    num_users=2000,
    treatment_effect=0.25,
) -> pd.DataFrame:
    """Session-level dataframe for ratio metrics.

    Each of the N sessions belongs to a random user. Users (not sessions) are
    randomized into control (A) or treatment (B); the binary per-session
    `target` is drawn from a success rate built from the user's mean, two
    small covariates and noise, shifted by `treatment_effect` for treated
    users and clamped into [0, 1].
    """
    if user_target_means is None:
        user_target_means = np.random.normal(0.3, 0.15, num_users)

    # One user and one date per session.
    session_users = np.random.choice(num_users, N)
    session_dates = np.random.choice(dates, N)

    # User-level assignment: 0 = control, 1 = treated.
    user_treatment = np.random.choice([0, 1], num_users)

    x1 = np.random.normal(0, 0.01, N)
    x2 = np.random.normal(0, 0.01, N)
    noise = np.random.normal(0, 0.01, N)

    # Per-session success rate, clamped to a valid probability.
    session_rate = np.clip(
        treatment_effect * user_treatment[session_users]
        + user_target_means[session_users]
        + x1
        + x2**2
        + noise,
        0,
        1,
    )

    observed_targets = np.random.binomial(1, session_rate)

    # Map the 0/1 treatment codes to the "A"/"B" labels used everywhere else.
    treatment_labels = np.where(user_treatment == 0, "A", "B")

    return pd.DataFrame(
        {
            "user": session_users,
            "date": session_dates,
            "treatment": treatment_labels[session_users],
            "target": observed_targets,
            "scale": np.ones_like(session_users),
            "x1": x1,
            "x2": x2,
            "user_target_means": user_target_means[session_users],
        }
    )
165 |
--------------------------------------------------------------------------------
/theme/flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/theme/flow.png
--------------------------------------------------------------------------------
/theme/icon-cluster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david26694/cluster-experiments/b5c39ed993ff68a5acf5df59f54ff6920a60e99f/theme/icon-cluster.png
--------------------------------------------------------------------------------
/theme/icon-cluster.svg:
--------------------------------------------------------------------------------
1 |
2 |
5 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | isolated_build = True
3 | envlist = py38,py39,py310
4 |
5 | [testenv]
6 | deps = .[dev]
7 |
8 | commands =
9 | black --check cluster_experiments
10 | coverage run --source=cluster_experiments --branch -m pytest .
11 | coverage report -m
12 |
--------------------------------------------------------------------------------