├── .bumpversion.cfg ├── .editorconfig ├── .github └── workflows │ ├── main.yml │ ├── release.yml │ └── update-sdk-versions.yml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── growthbook ├── __init__.py ├── common_types.py ├── core.py ├── growthbook.py ├── growthbook_client.py └── py.typed ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── setup.cfg ├── setup.py └── tests ├── cases.json ├── conftest.py ├── test_growthbook.py └── test_growthbook_client.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.2.1 3 | commit = True 4 | tag = True 5 | tag_name = v{new_version} 6 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) 7 | serialize = {major}.{minor}.{patch} 8 | 9 | [bumpversion:file:growthbook/__init__.py] 10 | search = __version__ = "{current_version}" 11 | replace = __version__ = "{new_version}" -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.json] 14 | indent_size = 2 15 | 16 | [*.md] 17 | indent_size = 2 18 | 19 | [*.yml] 20 | indent_size = 2 21 | 22 | [*.bat] 23 | indent_style = tab 24 | end_of_line = crlf 25 | 26 | [LICENSE] 27 | insert_final_newline = false 28 | 29 | [Makefile] 30 | indent_style = tab 31 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | on: [push] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-22.04 6 | strategy: 7 | matrix: 8 | python-version: [3.7, 3.8, 3.9, "3.10", "3.11", 
"3.12"] 9 | 10 | steps: 11 | - name: "Begin CI..." 12 | uses: actions/checkout@v2 13 | 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements_dev.txt 23 | pip install -r requirements.txt 24 | python -m pip install . 25 | 26 | - name: Lint with flake8 27 | run: | 28 | # stop the build if there are Python syntax errors or undefined names 29 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 30 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 31 | flake8 . --count --exit-zero --max-complexity=25 --max-line-length=127 --statistics 32 | 33 | - name: Static type checking with MyPy 34 | run: | 35 | mypy growthbook/growthbook*.py --implicit-optional 36 | 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | workflow_dispatch: 4 | inputs: 5 | release_type: 6 | description: 'Release type (major, minor, patch)' 7 | required: true 8 | default: 'patch' 9 | type: choice 10 | options: 11 | - patch 12 | - minor 13 | - major 14 | 15 | jobs: 16 | release: 17 | runs-on: ubuntu-latest 18 | permissions: 19 | contents: write 20 | id-token: write 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 # Get all history and tags 26 | token: ${{ secrets.GITHUB_TOKEN }} 27 | submodules: true # Get submodules if any 28 | persist-credentials: false 29 | fetch-tags: true 30 | 31 | - name: Set up Python 32 | uses: actions/setup-python@v4 33 | with: 34 | python-version: '3.11' 35 | 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip setuptools 
39 | pip install -r requirements_dev.txt 40 | pip install -r requirements.txt 41 | 42 | - name: Check Git Files 43 | run: ls -la .git 44 | 45 | - name: Run tests 46 | run: pytest 47 | 48 | - name: Configure Git 49 | run: | 50 | git config --global user.name "github-actions" 51 | git config --global user.email "github-actions@github.com" 52 | 53 | - name: Create Release & Publish to PyPI 54 | env: 55 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 56 | TWINE_USERNAME: __token__ 57 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 58 | run: | 59 | # Install tools 60 | pip install bump2version build twine 61 | 62 | # 1. Bump version 63 | case "${{ github.event.inputs.release_type }}" in 64 | "major") 65 | bump2version major 66 | ;; 67 | "minor") 68 | bump2version minor 69 | ;; 70 | *) 71 | bump2version patch 72 | ;; 73 | esac 74 | 75 | # 2. Build package 76 | python -m build 77 | 78 | # 3. Create GitHub Release 79 | VERSION=$(grep "__version__" growthbook/__init__.py | cut -d'"' -f2) 80 | git push && git push --tags 81 | gh release create "v${VERSION}" \ 82 | --title "v${VERSION}" \ 83 | --notes "Release v${VERSION}" \ 84 | dist/* 85 | 86 | # 4. 
Publish to PyPI 87 | twine upload dist/* -------------------------------------------------------------------------------- /.github/workflows/update-sdk-versions.yml: -------------------------------------------------------------------------------- 1 | name: Update SDK Version 2 | 3 | on: 4 | release: 5 | types: [published] # Triggers when a release is published 6 | workflow_dispatch: {} # Manual trigger without inputs 7 | 8 | jobs: 9 | update-versions: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: write 13 | pull-requests: write 14 | 15 | steps: 16 | - name: Get Latest Release 17 | id: latest_release 18 | run: | 19 | LATEST_TAG=$(gh api repos/growthbook/growthbook-python/releases/latest --jq .tag_name) 20 | echo "version=${LATEST_TAG#v}" >> $GITHUB_OUTPUT 21 | env: 22 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 23 | 24 | - name: Update SDK Versions Repository 25 | env: 26 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 27 | run: | 28 | # Get version from latest release 29 | VERSION=${{ steps.latest_release.outputs.version }} 30 | 31 | # Validate version format (semver) 32 | if ! 
echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$'; then 33 | echo "Error: Version '$VERSION' is not in semantic versioning format (x.y.z)" 34 | exit 1 35 | fi 36 | 37 | # Remove directory if it exists and clone fresh 38 | rm -rf growthbook || true 39 | git clone https://github.com/growthbook/growthbook.git || { 40 | echo "Failed to clone repository" 41 | exit 1 42 | } 43 | cd growthbook/packages/shared/src/sdk-versioning/sdk-versions 44 | 45 | # Check if version already exists 46 | if jq -e --arg v "$VERSION" '.versions[] | select(.version == $v)' python.json > /dev/null; then 47 | echo "Version $VERSION already exists in python.json" 48 | exit 0 # Exit successfully since this is not an error condition 49 | fi 50 | 51 | # Create a new branch 52 | git checkout -b update-python-sdk-${VERSION} 53 | 54 | # Update the JSON file 55 | jq --arg v "$VERSION" \ 56 | '.versions = ([{"version": $v}] + .versions)' \ 57 | python.json > python.json.tmp && mv python.json.tmp python.json 58 | 59 | # Commit and push changes 60 | git config user.name "github-actions" 61 | git config user.email "github-actions@github.com" 62 | git add python.json 63 | git commit -m "chore: update Python SDK to ${VERSION}" 64 | git push origin update-python-sdk-${VERSION} 65 | 66 | # Create Pull Request 67 | gh pr create \ 68 | --title "Update Python SDK to ${VERSION}" \ 69 | --body "Automated PR to update Python SDK version to ${VERSION} in capabilities matrix" \ 70 | --repo growthbook/growthbook \ 71 | --base main -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 
23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | .pytest_cache/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # IDE settings 104 | .vscode/ 105 | 106 | # Generated version file 107 | growthbook/_version.py -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## **1.1.0** - Apr 11, 2024 4 | 5 | - Support for prerequisite feature flags 6 | - Optional Sticky Bucketing for experiment variation assignments 7 | - SemVer targeting support 8 | - Fixed multiple bugs and edge cases when comparing different data types 9 | - Fixed bugs with the $in and $nin operators 10 | - Now, we ignore unknown fields in feature definitions 
instead of throwing Exceptions 11 | - Support for feature rule ids (for easier debugging) 12 | 13 | ## **1.0.0** - Apr 23, 2023 14 | 15 | - Update to the official 0.4.1 GrowthBook SDK spec version 16 | - Built-in fetching and caching of feature flags from the GrowthBook API 17 | - Added detailed logging for easier debugging 18 | - Support for new feature/experiment properties that enable holdout groups, meta info, better hashing algorithms, and more 19 | 20 | ## **0.3.1** - Aug 1, 2022 21 | 22 | - Bug fix - skip experiment when the hashAttribute's value is `None` 23 | 24 | ## **0.3.0** - May 24, 2022 25 | 26 | - Bug fix - don't skip feature rules when experiment variation is forced 27 | 28 | ## **0.2.0** - Feb 13, 2022 29 | 30 | - Support for Feature Flags 31 | 32 | ## **0.1.1** - Jun 15, 2021 33 | 34 | - Initial release (inline experiments only) 35 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guide 2 | 3 | We welcome all contributions! 4 | 5 | ## Type checking 6 | 7 | We use MyPy for type checking. 8 | 9 | ```bash 10 | make type-check 11 | ``` 12 | 13 | ## Tests 14 | 15 | Run the test suite with 16 | 17 | ```bash 18 | make test 19 | ``` 20 | 21 | ## Linting 22 | 23 | Lint the code with flake8 24 | 25 | ```bash 26 | make lint 27 | ``` 28 | 29 | ## Releasing 30 | 31 | 1. Bump the version in `growthbook/__init__.py` (e.g. run `bump2version patch`, configured via `.bumpversion.cfg`) 32 | 2. Merge code to `main` 33 | 3. Create a new release on GitHub with your version as the tag (e.g. `v0.2.0`) 34 | 4. Run `make dist` to create a distribution file 35 | 5. 
Run `make release` to upload to pypi 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021, GrowthBook 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include growthbook/*.py 2 | include setup.py 3 | include README.md 4 | include LICENSE 5 | include tests/*.py 6 | include growthbook/py.typed -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | clean-build: ## remove build artifacts 32 | rm -fr build/ 33 | rm -fr dist/ 34 | rm -fr .eggs/ 35 | find . -name '*.egg-info' -exec rm -fr {} + 36 | find . -name '*.egg' -exec rm -f {} + 37 | 38 | clean-pyc: ## remove Python file artifacts 39 | find . -name '*.pyc' -exec rm -f {} + 40 | find . -name '*.pyo' -exec rm -f {} + 41 | find . -name '*~' -exec rm -f {} + 42 | find . 
-name '__pycache__' -exec rm -fr {} + 43 | 44 | clean-test: ## remove test and coverage artifacts 45 | rm -f .coverage 46 | rm -fr htmlcov/ 47 | rm -fr .pytest_cache 48 | 49 | lint: ## check style with flake8 50 | flake8 growthbook/growthbook.py --max-line-length=150 51 | 52 | type-check: 53 | mypy growthbook/growthbook.py --implicit-optional 54 | 55 | test: ## run tests quickly with the default Python 56 | pytest 57 | 58 | coverage: ## check code coverage quickly with the default Python 59 | coverage run --source growthbook -m pytest 60 | coverage report -m 61 | coverage html 62 | $(BROWSER) htmlcov/index.html 63 | 64 | release: dist ## package and upload a release 65 | twine upload dist/* 66 | 67 | dist: clean ## builds source and wheel package 68 | python setup.py sdist 69 | python setup.py bdist_wheel 70 | ls -l dist 71 | 72 | install: clean ## install the package to the active Python's site-packages 73 | python setup.py install 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GrowthBook Python SDK 2 | 3 | Powerful Feature flagging and A/B testing for Python apps. 
4 | 5 | ![Build Status](https://github.com/growthbook/growthbook-python/workflows/Build/badge.svg) 6 | 7 | - **Lightweight and fast** 8 | - **Local evaluation**, no network requests required 9 | - Python 3.6+ 10 | - 100% test coverage 11 | - Flexible **targeting** 12 | - **Use your existing event tracking** (GA, Segment, Mixpanel, custom) 13 | - **Remote configuration** to change feature flags without deploying new code 14 | - **Async support** with real-time feature updates 15 | 16 | ## Installation 17 | 18 | `pip install growthbook` (recommended) or copy `growthbook.py` into your project 19 | 20 | ## Quick Usage 21 | 22 | ```python 23 | from growthbook import GrowthBook 24 | 25 | # User attributes for targeting and experimentation 26 | attributes = { 27 | "id": "123", 28 | "customUserAttribute": "foo" 29 | } 30 | 31 | def on_experiment_viewed(experiment, result): 32 | # Use whatever event tracking system you want 33 | analytics.track(attributes["id"], "Experiment Viewed", { 34 | 'experimentId': experiment.key, 35 | 'variationId': result.variationId 36 | }) 37 | 38 | # Create a GrowthBook instance 39 | gb = GrowthBook( 40 | attributes = attributes, 41 | on_experiment_viewed = on_experiment_viewed, 42 | api_host = "https://cdn.growthbook.io", 43 | client_key = "sdk-abc123" 44 | ) 45 | 46 | # Load features from the GrowthBook API with caching 47 | gb.load_features() 48 | 49 | # Simple on/off feature gating 50 | if gb.is_on("my-feature"): 51 | print("My feature is on!") 52 | 53 | # Get the value of a feature with a fallback 54 | color = gb.get_feature_value("button-color-feature", "blue") 55 | ``` 56 | 57 | ### Web Frameworks (Django, Flask, etc.) 58 | 59 | For web frameworks, you should create a new `GrowthBook` instance for every incoming request and call `destroy()` at the end of the request to clean up resources. 
60 | 61 | In Django, for example, this is best done with a simple middleware: 62 | 63 | ```python 64 | from growthbook import GrowthBook 65 | 66 | def growthbook_middleware(get_response): 67 | def middleware(request): 68 | request.gb = GrowthBook( 69 | # ... 70 | ) 71 | request.gb.load_features() 72 | 73 | response = get_response(request) 74 | 75 | request.gb.destroy() # Cleanup 76 | 77 | return response 78 | return middleware 79 | ``` 80 | 81 | Then, you can easily use GrowthBook in any of your views: 82 | 83 | ```python 84 | def index(request): 85 | feature_enabled = request.gb.is_on("my-feature") 86 | # ... 87 | ``` 88 | 89 | ## Quick Usage - Async Client 90 | 91 | ```python 92 | from growthbook import GrowthBookClient, Options, UserContext, FeatureRefreshStrategy 93 | import asyncio 94 | 95 | async def main(): 96 | # Create client options 97 | options = Options( 98 | api_host="https://cdn.growthbook.io", 99 | client_key="sdk-abc123", 100 | # Optional: Enable real-time feature updates 101 | refresh_strategy=FeatureRefreshStrategy.SERVER_SENT_EVENTS 102 | ) 103 | 104 | # Create and initialize client 105 | client = GrowthBookClient(options) 106 | try: 107 | # Initialize the client before using it 108 | success = await client.initialize() 109 | if not success: 110 | print("Failed to initialize GrowthBook client") 111 | return 112 | 113 | # Create user context for targeting 114 | user = UserContext( 115 | attributes={ 116 | "id": "123", 117 | "country": "US", 118 | "premium": True 119 | } 120 | ) 121 | 122 | # Simple feature evaluation 123 | if await client.is_on("new-homepage", user): 124 | print("New homepage is enabled!") 125 | 126 | # Get feature value with fallback 127 | color = await client.get_feature_value("button-color", "blue", user) 128 | print(f"Button color is {color}") 129 | 130 | # Run an experiment 131 | result = await client.run( 132 | Experiment( 133 | key="my-test", 134 | variations=["A", "B"] 135 | ), 136 | user 137 | ) 138 | print(f"User got 
variation: {result.value}") 139 | finally: 140 | # Always close the client when done 141 | await client.close() 142 | 143 | # Run the async code 144 | asyncio.run(main()) 145 | ``` 146 | 147 | ### Async Web Framework Integration 148 | 149 | The async client works great with async web frameworks like FastAPI: 150 | 151 | ```python 152 | from fastapi import FastAPI, Depends 153 | from growthbook import GrowthBookClient, Options, UserContext 154 | 155 | app = FastAPI() 156 | 157 | # Create a single client instance 158 | gb_client = GrowthBookClient( 159 | Options( 160 | api_host="https://cdn.growthbook.io", 161 | client_key="sdk-abc123" 162 | ) 163 | ) 164 | 165 | @app.on_event("startup") 166 | async def startup(): 167 | # Initialize the client when the app starts 168 | await gb_client.initialize() 169 | 170 | @app.on_event("shutdown") 171 | async def shutdown(): 172 | # Clean up when the app shuts down 173 | await gb_client.close() 174 | 175 | @app.get("/") 176 | async def root(user_id: str): 177 | # Create user context for the request 178 | user = UserContext(attributes={"id": user_id}) 179 | 180 | # Use features 181 | show_new_ui = await gb_client.is_on("new-ui", user) 182 | return {"new_ui": show_new_ui} 183 | ``` 184 | 185 | ### Real-time Feature Updates 186 | 187 | The async client supports real-time feature updates using Server-Sent Events: 188 | 189 | ```python 190 | from growthbook import GrowthBookClient, Options, FeatureRefreshStrategy 191 | 192 | client = GrowthBookClient( 193 | Options( 194 | api_host="https://cdn.growthbook.io", 195 | client_key="sdk-abc123", 196 | # Enable SSE for real-time updates 197 | refresh_strategy=FeatureRefreshStrategy.SERVER_SENT_EVENTS 198 | ) 199 | ) 200 | ``` 201 | 202 | ### Concurrency and Thread Safety 203 | 204 | The async client is designed to be thread-safe and handle concurrent requests efficiently. You can safely use a single client instance across multiple coroutines. 
For web applications, you can create a single client instance at startup and share it across requests. Here's an example: 205 | 206 | ```python 207 | from fastapi import FastAPI 208 | from growthbook import GrowthBookClient, Options, UserContext 209 | import asyncio 210 | 211 | app = FastAPI() 212 | 213 | # Single client instance shared across all requests 214 | gb_client = GrowthBookClient(Options( 215 | api_host="https://cdn.growthbook.io", 216 | client_key="sdk-abc123" 217 | )) 218 | 219 | @app.on_event("startup") 220 | async def startup(): 221 | await gb_client.initialize() 222 | 223 | @app.on_event("shutdown") 224 | async def shutdown(): 225 | await gb_client.close() 226 | 227 | @app.get("/batch") 228 | async def batch_process(user_ids: list[str]): 229 | # Safely process multiple users concurrently 230 | tasks = [] 231 | for user_id in user_ids: 232 | user = UserContext(attributes={"id": user_id}) 233 | tasks.append(gb_client.eval_feature("new-feature", user)) 234 | 235 | results = await asyncio.gather(*tasks) 236 | return {"results": results} 237 | ``` 238 | 239 | Note: While the client is thread-safe, you should not share a single `UserContext` instance across different requests. Create a new `UserContext` for each request to maintain proper isolation. 240 | 241 | ## Loading Features 242 | 243 | There are two ways to load feature flags into the GrowthBook SDK. You can either use the built-in fetching/caching logic or implement your own custom solution. 244 | 245 | ### Built-in Fetching and Caching 246 | 247 | To use the built-in fetching and caching logic, in the `GrowthBook` constructor, pass in your GrowthBook `api_host` and `client_key`. If you have encryption enabled for your GrowthBook endpoint, you also need to pass the `decryption_key` into the constructor. 248 | 249 | Then, call the `load_features()` method to initiate the HTTP request with a cache layer. 
250 | 251 | Here's a full example: 252 | 253 | ```python 254 | gb = GrowthBook( 255 | api_host = "https://cdn.growthbook.io", 256 | client_key = "sdk-abc123", 257 | # How long to cache features in seconds (Optional, default 60s) 258 | cache_ttl = 60, 259 | ) 260 | gb.load_features() 261 | ``` 262 | 263 | #### Caching 264 | 265 | GrowthBook comes with a custom in-memory cache. If you run Python in a multi-process mode, the different processes cannot share memory, so you likely want to switch to a distributed cache system like Redis instead. 266 | 267 | Here is an example of using Redis: 268 | 269 | ```python 270 | from redis import Redis 271 | import json 272 | from growthbook import GrowthBook, AbstractFeatureCache, feature_repo 273 | 274 | class RedisFeatureCache(AbstractFeatureCache): 275 | def __init__(self): 276 | self.r = Redis(host='localhost', port=6379) 277 | self.prefix = "gb:" 278 | 279 | def get(self, key: str): 280 | data = self.r.get(self.prefix + key) 281 | # Data stored as a JSON string, parse into dict before returning 282 | return None if data is None else json.loads(data) 283 | 284 | def set(self, key: str, value: dict, ttl: int) -> None: 285 | self.r.set(self.prefix + key, json.dumps(value)) 286 | self.r.expire(self.prefix + key, ttl) 287 | 288 | # Configure GrowthBook to use your custom cache class 289 | feature_repo.set_cache(RedisFeatureCache()) 290 | ``` 291 | 292 | ### Custom Implementation 293 | 294 | If you prefer to handle the entire fetching/caching logic yourself, you can just pass in a `dict` of features from the GrowthBook API directly into the constructor: 295 | 296 | ```python 297 | # From the GrowthBook API 298 | features = {'my-feature':{'defaultValue':False}} 299 | 300 | gb = GrowthBook( 301 | features = features 302 | ) 303 | ``` 304 | 305 | Note: When doing this, you do not need to specify your `api_host` or `client_key` and you don't need to call `gb.load_features()`. 
306 | 307 | ## GrowthBook class 308 | 309 | The GrowthBook constructor has the following parameters: 310 | 311 | - **enabled** (`bool`) - Flag to globally disable all experiments. Default true. 312 | - **attributes** (`dict`) - Dictionary of user attributes that are used for targeting and to assign variations 313 | - **url** (`str`) - The URL of the current request (if applicable) 314 | - **qa_mode** (`boolean`) - If true, random assignment is disabled and only explicitly forced variations are used. 315 | - **on_experiment_viewed** (`callable`) - A function that takes `experiment` and `result` as arguments. 316 | - **api_host** (`str`) - The GrowthBook API host to fetch feature flags from. Defaults to `https://cdn.growthbook.io` 317 | - **client_key** (`str`) - The client key that will be passed to the API Host to fetch feature flags 318 | - **decryption_key** (`str`) - If the GrowthBook API endpoint has encryption enabled, specify the decryption key here 319 | - **cache_ttl** (`int`) - How long to cache features in-memory from the GrowthBook API (seconds, default `60`) 320 | - **features** (`dict`) - Feature definitions from the GrowthBook API (only required if `client_key` is not specified) 321 | - **forced_variations** (`dict`) - Dictionary of forced experiment variations (used for QA) 322 | 323 | There are also getter and setter methods for features and attributes if you need to update them later in the request: 324 | 325 | ```python 326 | gb.set_features(gb.get_features()) 327 | gb.set_attributes(gb.get_attributes()) 328 | ``` 329 | 330 | ### Attributes 331 | 332 | You can specify attributes about the current user and request. These are used for two things: 333 | 334 | 1. Feature targeting (e.g. paid users get one value, free users get another) 335 | 2. Assigning persistent variations in A/B tests (e.g. user id "123" always gets variation B) 336 | 337 | Attributes can be any JSON data type - boolean, integer, float, string, list, or dict. 
338 | 339 | ```python 340 | attributes = { 341 | 'id': "123", 342 | 'loggedIn': True, 343 | 'age': 21.5, 344 | 'tags': ["tag1", "tag2"], 345 | 'account': { 346 | 'age': 90 347 | } 348 | } 349 | 350 | # Pass into constructor 351 | gb = GrowthBook(attributes = attributes) 352 | 353 | # Or set later 354 | gb.set_attributes(attributes) 355 | ``` 356 | 357 | ### Tracking Experiments 358 | 359 | Any time an experiment is run to determine the value of a feature, you want to track that event in your analytics system. 360 | 361 | You can use the `on_experiment_viewed` option to do this: 362 | 363 | ```python 364 | from growthbook import GrowthBook, Experiment, Result 365 | 366 | def on_experiment_viewed(experiment: Experiment, result: Result): 367 | # Use whatever event tracking system you want 368 | analytics.track(attributes["id"], "Experiment Viewed", { 369 | 'experimentId': experiment.key, 370 | 'variationId': result.variationId 371 | }) 372 | 373 | # Pass into constructor 374 | gb = GrowthBook( 375 | on_experiment_viewed = on_experiment_viewed 376 | ) 377 | ``` 378 | 379 | ## Using Features 380 | 381 | There are 3 main methods for interacting with features. 382 | 383 | - `gb.is_on("feature-key")` returns true if the feature is on 384 | - `gb.is_off("feature-key")` returns false if the feature is on 385 | - `gb.get_feature_value("feature-key", "default")` returns the value of the feature with a fallback 386 | 387 | In addition, you can use `gb.evalFeature("feature-key")` to get back a `FeatureResult` object with the following properties: 388 | 389 | - **value** - The JSON-decoded value of the feature (or `None` if not defined) 390 | - **on** and **off** - The JSON-decoded value cast to booleans 391 | - **source** - Why the value was assigned to the user. 
One of `unknownFeature`, `defaultValue`, `force`, or `experiment` 392 | - **experiment** - Information about the experiment (if any) which was used to assign the value to the user 393 | - **experimentResult** - The result of the experiment (if any) which was used to assign the value to the user 394 | 395 | ## Sticky Bucketing 396 | 397 | By default GrowthBook does not persist assigned experiment variations for a user. We rely on deterministic hashing to ensure that the same user attributes always map to the same experiment variation. However, there are cases where this isn't good enough. For example, if you change targeting conditions in the middle of an experiment, users may stop being shown a variation even if they were previously bucketed into it. 398 | 399 | Sticky Bucketing is a solution to these issues. You can provide a Sticky Bucket Service to the GrowthBook instance to persist previously seen variations and ensure that the user experience remains consistent for your users. 400 | 401 | A sample `InMemoryStickyBucketService` implementation is provided for reference, but in production you will definitely want to implement your own version using a database, cookies, or similar for persistence. 402 | 403 | Sticky Bucket documents contain three fields 404 | 405 | - `attributeName` - The name of the attribute used to identify the user (e.g. `id`, `cookie_id`, etc.) 406 | - `attributeValue` - The value of the attribute (e.g. `123`) 407 | - `assignments` - A dictionary of persisted experiment assignments. For example: `{"exp1__0":"control"}` 408 | 409 | The attributeName/attributeValue combo is the primary key. 
410 | 411 | Here's an example implementation using a theoretical `db` object: 412 | 413 | ```python 414 | from growthbook import AbstractStickyBucketService, GrowthBook 415 | 416 | class MyStickyBucketService(AbstractStickyBucketService): 417 | # Lookup a sticky bucket document 418 | def get_assignments(self, attributeName: str, attributeValue: str) -> Optional[Dict]: 419 | return db.find({ 420 | "attributeName": attributeName, 421 | "attributeValue": attributeValue 422 | }) 423 | 424 | def save_assignments(self, doc: Dict) -> None: 425 | # Insert new record if not exists, otherwise update 426 | db.upsert({ 427 | "attributeName": doc["attributeName"], 428 | "attributeValue": doc["attributeValue"] 429 | }, { 430 | "$set": { 431 | "assignments": doc["assignments"] 432 | } 433 | }) 434 | 435 | # Pass in an instance of this service to your GrowthBook constructor 436 | 437 | gb = GrowthBook( 438 | sticky_bucket_service = MyStickyBucketService() 439 | ) 440 | ``` 441 | 442 | ## Inline Experiments 443 | 444 | Instead of declaring all features up-front and referencing them by ids in your code, you can also just run an experiment directly. This is done with the `run` method: 445 | 446 | ```python 447 | from growthbook import Experiment 448 | 449 | exp = Experiment( 450 | key = "my-experiment", 451 | variations = ["red", "blue", "green"] 452 | ) 453 | 454 | # Either "red", "blue", or "green" 455 | print(gb.run(exp).value) 456 | ``` 457 | 458 | As you can see, there are 2 required parameters for experiments, a string key, and an array of variations. Variations can be any data type, not just strings. 459 | 460 | There are a number of additional settings to control the experiment behavior: 461 | 462 | - **key** (`str`) - The globally unique tracking key for the experiment 463 | - **variations** (`any[]`) - The different variations to choose between 464 | - **seed** (`str`) - Added to the user id when hashing to determine a variation. 
Defaults to the experiment `key` 465 | - **weights** (`float[]`) - How to weight traffic between variations. Must add to 1. 466 | - **coverage** (`float`) - What percent of users should be included in the experiment (between 0 and 1, inclusive) 467 | - **condition** (`dict`) - Targeting conditions 468 | - **force** (`int`) - All users included in the experiment will be forced into the specified variation index 469 | - **hashAttribute** (`string`) - What user attribute should be used to assign variations (defaults to "id") 470 | - **hashVersion** (`int`) - What version of our hashing algorithm to use. We recommend using the latest version `2`. 471 | - **namespace** (`tuple[str,float,float]`) - Used to run mutually exclusive experiments. 472 | 473 | Here's an example that uses all of them: 474 | 475 | ```python 476 | exp = Experiment( 477 | key="my-test", 478 | # Variations can be a list of any data type 479 | variations=[0, 1], 480 | # If this changes, it will re-randomize all users in the experiment 481 | seed="abcdef123456", 482 | # Run a 40/60 experiment instead of the default even split (50/50) 483 | weights=[0.4, 0.6], 484 | # Only include 20% of users in the experiment 485 | coverage=0.2, 486 | # Targeting condition using a MongoDB-like syntax 487 | condition={ 488 | 'country': 'US', 489 | 'browser': { 490 | '$in': ['chrome', 'firefox'] 491 | } 492 | }, 493 | # Use an alternate attribute for assigning variations (default is 'id') 494 | hashAttribute="sessionId", 495 | # Use the latest hashing algorithm 496 | hashVersion=2, 497 | # Includes the first 50% of users in the "pricing" namespace 498 | # Another experiment with a non-overlapping range will be mutually exclusive (e.g. 
[0.5, 1]) 499 | namespace=("pricing", 0, 0.5), 500 | ) 501 | ``` 502 | 503 | ### Inline Experiment Return Value 504 | 505 | A call to `run` returns a `Result` object with a few useful properties: 506 | 507 | ```python 508 | result = gb.run(exp) 509 | 510 | # If user is part of the experiment 511 | print(result.inExperiment) # True or False 512 | 513 | # The index of the assigned variation 514 | print(result.variationId) # e.g. 0 or 1 515 | 516 | # The value of the assigned variation 517 | print(result.value) # e.g. "A" or "B" 518 | 519 | # If the variation was randomly assigned by hashing user attributes 520 | print(result.hashUsed) # True or False 521 | 522 | # The user attribute used to assign a variation 523 | print(result.hashAttribute) # "id" 524 | 525 | # The value of that attribute 526 | print(result.hashValue) # e.g. "123" 527 | ``` 528 | 529 | The `inExperiment` flag will be false if the user was excluded from being part of the experiment for any reason (e.g. failed targeting conditions). 530 | 531 | The `hashUsed` flag will only be true if the user was randomly assigned a variation. If the user was forced into a specific variation instead, this flag will be false. 
532 | 533 | ### Example Experiments 534 | 535 | 3-way experiment with uneven variation weights: 536 | 537 | ```python 538 | gb.run(Experiment( 539 | key = "3-way-uneven", 540 | variations = ["A","B","C"], 541 | weights = [0.5, 0.25, 0.25] 542 | )) 543 | ``` 544 | 545 | Slow rollout (10% of users who match the targeting condition): 546 | 547 | ```python 548 | # User is marked as being in "qa" and "beta" 549 | gb = GrowthBook( 550 | attributes = { 551 | "id": "123", 552 | "beta": True, 553 | "qa": True, 554 | }, 555 | ) 556 | 557 | gb.run(Experiment( 558 | key = "slow-rollout", 559 | variations = ["A", "B"], 560 | coverage = 0.1, 561 | condition = { 562 | 'beta': True 563 | } 564 | )) 565 | ``` 566 | 567 | Complex variations 568 | 569 | ```python 570 | result = gb.run(Experiment( 571 | key = "complex-variations", 572 | variations = [ 573 | ("blue", "large"), 574 | ("green", "small") 575 | ], 576 | )) 577 | 578 | # Either "blue,large" OR "green,small" 579 | print(result.value[0] + "," + result.value[1]) 580 | ``` 581 | 582 | Assign variations based on something other than user id 583 | 584 | ```python 585 | gb = GrowthBook( 586 | attributes = { 587 | "id": "123", 588 | "company": "growthbook" 589 | } 590 | ) 591 | 592 | # Users in the same company will always get the same variation 593 | gb.run(Experiment( 594 | key = "by-company-id", 595 | variations = ["A", "B"], 596 | hashAttribute = "company" 597 | )) 598 | ``` 599 | 600 | ## Logging 601 | 602 | The GrowthBook SDK uses a Python logger with the name `growthbook` and includes helpful info for debugging as well as warnings/errors if something is misconfigured. 
603 | 604 | Here's an example of logging to the console 605 | 606 | ```python 607 | import logging 608 | 609 | logger = logging.getLogger('growthbook') 610 | logger.setLevel(logging.DEBUG) 611 | 612 | handler = logging.StreamHandler() 613 | formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s') 614 | handler.setFormatter(formatter) 615 | logger.addHandler(handler) 616 | ``` 617 | -------------------------------------------------------------------------------- /growthbook/__init__.py: -------------------------------------------------------------------------------- 1 | from .growthbook import * 2 | 3 | from .growthbook_client import ( 4 | GrowthBookClient, 5 | EnhancedFeatureRepository, 6 | FeatureCache, 7 | BackoffStrategy 8 | ) 9 | 10 | __version__ = "1.2.1" 11 | -------------------------------------------------------------------------------- /growthbook/common_types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | # Only require typing_extensions if using Python 3.7 or earlier 5 | if sys.version_info >= (3, 8): 6 | from typing import TypedDict 7 | else: 8 | from typing_extensions import TypedDict 9 | 10 | from dataclasses import dataclass, field 11 | from typing import Any, Dict, List, Optional, Union, Set, Tuple 12 | from enum import Enum 13 | from abc import ABC, abstractmethod 14 | 15 | class VariationMeta(TypedDict): 16 | key: str 17 | name: str 18 | passthrough: bool 19 | 20 | 21 | class Filter(TypedDict): 22 | seed: str 23 | ranges: List[Tuple[float, float]] 24 | hashVersion: int 25 | attribute: str 26 | 27 | class Experiment(object): 28 | def __init__( 29 | self, 30 | key: str, 31 | variations: list, 32 | weights: List[float] = None, 33 | active: bool = True, 34 | status: str = "running", 35 | coverage: int = None, 36 | condition: dict = None, 37 | namespace: Tuple[str, float, float] = None, 38 | url: str = "", 39 | include=None, 40 | groups: list = 
None, 41 | force: int = None, 42 | hashAttribute: str = "id", 43 | fallbackAttribute: str = None, 44 | hashVersion: int = None, 45 | ranges: List[Tuple[float, float]] = None, 46 | meta: List[VariationMeta] = None, 47 | filters: List[Filter] = None, 48 | seed: str = None, 49 | name: str = None, 50 | phase: str = None, 51 | disableStickyBucketing: bool = False, 52 | bucketVersion: int = None, 53 | minBucketVersion: int = None, 54 | parentConditions: List[dict] = None, 55 | ) -> None: 56 | self.key = key 57 | self.variations = variations 58 | self.weights = weights 59 | self.active = active 60 | self.coverage = coverage 61 | self.condition = condition 62 | self.namespace = namespace 63 | self.force = force 64 | self.hashAttribute = hashAttribute 65 | self.hashVersion = hashVersion or 1 66 | self.ranges = ranges 67 | self.meta = meta 68 | self.filters = filters 69 | self.seed = seed 70 | self.name = name 71 | self.phase = phase 72 | self.disableStickyBucketing = disableStickyBucketing 73 | self.bucketVersion = bucketVersion or 0 74 | self.minBucketVersion = minBucketVersion or 0 75 | self.parentConditions = parentConditions 76 | 77 | self.fallbackAttribute = None 78 | if not self.disableStickyBucketing: 79 | self.fallbackAttribute = fallbackAttribute 80 | 81 | # Deprecated properties 82 | self.status = status 83 | self.url = url 84 | self.include = include 85 | self.groups = groups 86 | 87 | def to_dict(self): 88 | obj = { 89 | "key": self.key, 90 | "variations": self.variations, 91 | "weights": self.weights, 92 | "active": self.active, 93 | "coverage": self.coverage or 1, 94 | "condition": self.condition, 95 | "namespace": self.namespace, 96 | "force": self.force, 97 | "hashAttribute": self.hashAttribute, 98 | "hashVersion": self.hashVersion, 99 | "ranges": self.ranges, 100 | "meta": self.meta, 101 | "filters": self.filters, 102 | "seed": self.seed, 103 | "name": self.name, 104 | "phase": self.phase, 105 | } 106 | 107 | if self.fallbackAttribute: 108 | 
obj["fallbackAttribute"] = self.fallbackAttribute 109 | if self.disableStickyBucketing: 110 | obj["disableStickyBucketing"] = True 111 | if self.bucketVersion: 112 | obj["bucketVersion"] = self.bucketVersion 113 | if self.minBucketVersion: 114 | obj["minBucketVersion"] = self.minBucketVersion 115 | if self.parentConditions: 116 | obj["parentConditions"] = self.parentConditions 117 | 118 | return obj 119 | 120 | def update(self, data: dict) -> None: 121 | weights = data.get("weights", None) 122 | status = data.get("status", None) 123 | coverage = data.get("coverage", None) 124 | url = data.get("url", None) 125 | groups = data.get("groups", None) 126 | force = data.get("force", None) 127 | 128 | if weights is not None: 129 | self.weights = weights 130 | if status is not None: 131 | self.status = status 132 | if coverage is not None: 133 | self.coverage = coverage 134 | if url is not None: 135 | self.url = url 136 | if groups is not None: 137 | self.groups = groups 138 | if force is not None: 139 | self.force = force 140 | 141 | 142 | class Result(object): 143 | def __init__( 144 | self, 145 | variationId: int, 146 | inExperiment: bool, 147 | value, 148 | hashUsed: bool, 149 | hashAttribute: str, 150 | hashValue: str, 151 | featureId: Optional[str], 152 | meta: VariationMeta = None, 153 | bucket: float = None, 154 | stickyBucketUsed: bool = False, 155 | ) -> None: 156 | self.variationId = variationId 157 | self.inExperiment = inExperiment 158 | self.value = value 159 | self.hashUsed = hashUsed 160 | self.hashAttribute = hashAttribute 161 | self.hashValue = hashValue 162 | self.featureId = featureId or None 163 | self.bucket = bucket 164 | self.stickyBucketUsed = stickyBucketUsed 165 | 166 | self.key = str(variationId) 167 | self.name = "" 168 | self.passthrough = False 169 | 170 | if meta: 171 | if "name" in meta: 172 | self.name = meta["name"] 173 | if "key" in meta: 174 | self.key = meta["key"] 175 | if "passthrough" in meta: 176 | self.passthrough = 
meta["passthrough"] 177 | 178 | def to_dict(self) -> dict: 179 | obj = { 180 | "featureId": self.featureId, 181 | "variationId": self.variationId, 182 | "inExperiment": self.inExperiment, 183 | "value": self.value, 184 | "hashUsed": self.hashUsed, 185 | "hashAttribute": self.hashAttribute, 186 | "hashValue": self.hashValue, 187 | "key": self.key, 188 | "stickyBucketUsed": self.stickyBucketUsed, 189 | } 190 | 191 | if self.bucket is not None: 192 | obj["bucket"] = self.bucket 193 | if self.name: 194 | obj["name"] = self.name 195 | if self.passthrough: 196 | obj["passthrough"] = True 197 | 198 | return obj 199 | 200 | class FeatureResult(object): 201 | def __init__( 202 | self, 203 | value, 204 | source: str, 205 | experiment: Experiment = None, 206 | experimentResult: Result = None, 207 | ruleId: str = None, 208 | ) -> None: 209 | self.value = value 210 | self.source = source 211 | self.ruleId = ruleId 212 | self.experiment = experiment 213 | self.experimentResult = experimentResult 214 | self.on = bool(value) 215 | self.off = not bool(value) 216 | 217 | def to_dict(self) -> dict: 218 | data = { 219 | "value": self.value, 220 | "source": self.source, 221 | "on": self.on, 222 | "off": self.off, 223 | "ruleId": self.ruleId or "", 224 | } 225 | if self.experiment: 226 | data["experiment"] = self.experiment.to_dict() 227 | if self.experimentResult: 228 | data["experimentResult"] = self.experimentResult.to_dict() 229 | 230 | return data 231 | 232 | class Feature(object): 233 | def __init__(self, defaultValue=None, rules: list = []) -> None: 234 | self.defaultValue = defaultValue 235 | self.rules: List[FeatureRule] = [] 236 | for rule in rules: 237 | if isinstance(rule, FeatureRule): 238 | self.rules.append(rule) 239 | else: 240 | self.rules.append(FeatureRule( 241 | id=rule.get("id", None), 242 | key=rule.get("key", ""), 243 | variations=rule.get("variations", None), 244 | weights=rule.get("weights", None), 245 | coverage=rule.get("coverage", None), 246 | 
condition=rule.get("condition", None), 247 | namespace=rule.get("namespace", None), 248 | force=rule.get("force", None), 249 | hashAttribute=rule.get("hashAttribute", "id"), 250 | fallbackAttribute=rule.get("fallbackAttribute", None), 251 | hashVersion=rule.get("hashVersion", None), 252 | range=rule.get("range", None), 253 | ranges=rule.get("ranges", None), 254 | meta=rule.get("meta", None), 255 | filters=rule.get("filters", None), 256 | seed=rule.get("seed", None), 257 | name=rule.get("name", None), 258 | phase=rule.get("phase", None), 259 | disableStickyBucketing=rule.get("disableStickyBucketing", False), 260 | bucketVersion=rule.get("bucketVersion", None), 261 | minBucketVersion=rule.get("minBucketVersion", None), 262 | parentConditions=rule.get("parentConditions", None), 263 | )) 264 | 265 | def to_dict(self) -> dict: 266 | return { 267 | "defaultValue": self.defaultValue, 268 | "rules": [rule.to_dict() for rule in self.rules], 269 | } 270 | 271 | class FeatureRule(object): 272 | def __init__( 273 | self, 274 | id: str = None, 275 | key: str = "", 276 | variations: list = None, 277 | weights: List[float] = None, 278 | coverage: int = None, 279 | condition: dict = None, 280 | namespace: Tuple[str, float, float] = None, 281 | force=None, 282 | hashAttribute: str = "id", 283 | fallbackAttribute: str = None, 284 | hashVersion: int = None, 285 | range: Tuple[float, float] = None, 286 | ranges: List[Tuple[float, float]] = None, 287 | meta: List[VariationMeta] = None, 288 | filters: List[Filter] = None, 289 | seed: str = None, 290 | name: str = None, 291 | phase: str = None, 292 | disableStickyBucketing: bool = False, 293 | bucketVersion: int = None, 294 | minBucketVersion: int = None, 295 | parentConditions: List[dict] = None, 296 | ) -> None: 297 | 298 | if disableStickyBucketing: 299 | fallbackAttribute = None 300 | 301 | self.id = id 302 | self.key = key 303 | self.variations = variations 304 | self.weights = weights 305 | self.coverage = coverage 306 | 
self.condition = condition 307 | self.namespace = namespace 308 | self.force = force 309 | self.hashAttribute = hashAttribute 310 | self.fallbackAttribute = fallbackAttribute 311 | self.hashVersion = hashVersion or 1 312 | self.range = range 313 | self.ranges = ranges 314 | self.meta = meta 315 | self.filters = filters 316 | self.seed = seed 317 | self.name = name 318 | self.phase = phase 319 | self.disableStickyBucketing = disableStickyBucketing 320 | self.bucketVersion = bucketVersion or 0 321 | self.minBucketVersion = minBucketVersion or 0 322 | self.parentConditions = parentConditions 323 | 324 | def to_dict(self) -> dict: 325 | data: Dict[str, Any] = {} 326 | if self.id: 327 | data["id"] = self.id 328 | if self.key: 329 | data["key"] = self.key 330 | if self.variations is not None: 331 | data["variations"] = self.variations 332 | if self.weights is not None: 333 | data["weights"] = self.weights 334 | if self.coverage and self.coverage != 1: 335 | data["coverage"] = self.coverage 336 | if self.condition is not None: 337 | data["condition"] = self.condition 338 | if self.namespace is not None: 339 | data["namespace"] = self.namespace 340 | if self.force is not None: 341 | data["force"] = self.force 342 | if self.hashAttribute != "id": 343 | data["hashAttribute"] = self.hashAttribute 344 | if self.hashVersion: 345 | data["hashVersion"] = self.hashVersion 346 | if self.range is not None: 347 | data["range"] = self.range 348 | if self.ranges is not None: 349 | data["ranges"] = self.ranges 350 | if self.meta is not None: 351 | data["meta"] = self.meta 352 | if self.filters is not None: 353 | data["filters"] = self.filters 354 | if self.seed is not None: 355 | data["seed"] = self.seed 356 | if self.name is not None: 357 | data["name"] = self.name 358 | if self.phase is not None: 359 | data["phase"] = self.phase 360 | if self.fallbackAttribute: 361 | data["fallbackAttribute"] = self.fallbackAttribute 362 | if self.disableStickyBucketing: 363 | 
data["disableStickyBucketing"] = True 364 | if self.bucketVersion: 365 | data["bucketVersion"] = self.bucketVersion 366 | if self.minBucketVersion: 367 | data["minBucketVersion"] = self.minBucketVersion 368 | if self.parentConditions: 369 | data["parentConditions"] = self.parentConditions 370 | 371 | return data 372 | 373 | class AbstractStickyBucketService(ABC): 374 | @abstractmethod 375 | def get_assignments(self, attributeName: str, attributeValue: str) -> Optional[Dict]: 376 | pass 377 | 378 | @abstractmethod 379 | def save_assignments(self, doc: Dict) -> None: 380 | pass 381 | 382 | def get_key(self, attributeName: str, attributeValue: str) -> str: 383 | return f"{attributeName}||{attributeValue}" 384 | 385 | # By default, just loop through all attributes and call get_assignments 386 | # Override this method in subclasses to perform a multi-query instead 387 | def get_all_assignments(self, attributes: Dict[str, str]) -> Dict[str, Dict]: 388 | docs = {} 389 | for attributeName, attributeValue in attributes.items(): 390 | doc = self.get_assignments(attributeName, attributeValue) 391 | if doc: 392 | docs[self.get_key(attributeName, attributeValue)] = doc 393 | return docs 394 | 395 | @dataclass 396 | class StackContext: 397 | id: Optional[str] = None 398 | evaluated_features: Set[str] = field(default_factory=set) 399 | 400 | class FeatureRefreshStrategy(Enum): 401 | STALE_WHILE_REVALIDATE = 'HTTP_REFRESH' 402 | SERVER_SENT_EVENTS = 'SSE' 403 | 404 | @dataclass 405 | class Options: 406 | url: Optional[str] = None 407 | api_host: Optional[str] = "https://cdn.growthbook.io" 408 | client_key: Optional[str] = None 409 | decryption_key: Optional[str] = None 410 | cache_ttl: int = 60 411 | enabled: bool = True 412 | qa_mode: bool = False 413 | enable_dev_mode: bool = False 414 | # forced_variations: Dict[str, Any] = field(default_factory=dict) 415 | refresh_strategy: Optional[FeatureRefreshStrategy] = FeatureRefreshStrategy.STALE_WHILE_REVALIDATE 416 | 
sticky_bucket_service: Optional[AbstractStickyBucketService] = None 417 | sticky_bucket_identifier_attributes: Optional[List[str]] = None 418 | on_experiment_viewed=None 419 | 420 | @dataclass 421 | class UserContext: 422 | # user_id: Optional[str] = None 423 | url: str = "" 424 | attributes: Dict[str, Any] = field(default_factory=dict) 425 | groups: Dict[str, str] = field(default_factory=dict) 426 | forced_variations: Dict[str, Any] = field(default_factory=dict) 427 | overrides: Dict[str, Any] = field(default_factory=dict) 428 | sticky_bucket_assignment_docs: Dict[str, Any] = field(default_factory=dict) 429 | 430 | @dataclass 431 | class GlobalContext: 432 | options: Options 433 | features: Dict[str, Any] = field(default_factory=dict) 434 | saved_groups: Dict[str, Any] = field(default_factory=dict) 435 | 436 | @dataclass 437 | class EvaluationContext: 438 | user: UserContext 439 | global_ctx: GlobalContext 440 | stack: StackContext 441 | -------------------------------------------------------------------------------- /growthbook/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import json 4 | 5 | from urllib.parse import urlparse, parse_qs 6 | from typing import Callable, Optional, Any, Set, Tuple, List, Dict 7 | from .common_types import EvaluationContext, FeatureResult, Experiment, Filter, Result, VariationMeta 8 | 9 | 10 | logger = logging.getLogger("growthbook.core") 11 | 12 | def evalCondition(attributes: dict, condition: dict, savedGroups: dict = None) -> bool: 13 | for key, value in condition.items(): 14 | if key == "$or": 15 | if not evalOr(attributes, value, savedGroups): 16 | return False 17 | elif key == "$nor": 18 | if evalOr(attributes, value, savedGroups): 19 | return False 20 | elif key == "$and": 21 | if not evalAnd(attributes, value, savedGroups): 22 | return False 23 | elif key == "$not": 24 | if evalCondition(attributes, value, savedGroups): 25 | return False 26 | elif 
not evalConditionValue(value, getPath(attributes, key), savedGroups): 27 | return False 28 | 29 | return True 30 | 31 | def evalOr(attributes, conditions, savedGroups) -> bool: 32 | if len(conditions) == 0: 33 | return True 34 | 35 | for condition in conditions: 36 | if evalCondition(attributes, condition, savedGroups): 37 | return True 38 | return False 39 | 40 | 41 | def evalAnd(attributes, conditions, savedGroups) -> bool: 42 | for condition in conditions: 43 | if not evalCondition(attributes, condition, savedGroups): 44 | return False 45 | return True 46 | 47 | def isOperatorObject(obj) -> bool: 48 | for key in obj.keys(): 49 | if key[0] != "$": 50 | return False 51 | return True 52 | 53 | def getType(attributeValue) -> str: 54 | t = type(attributeValue) 55 | 56 | if attributeValue is None: 57 | return "null" 58 | if t is int or t is float: 59 | return "number" 60 | if t is str: 61 | return "string" 62 | if t is list or t is set: 63 | return "array" 64 | if t is dict: 65 | return "object" 66 | if t is bool: 67 | return "boolean" 68 | return "unknown" 69 | 70 | def getPath(attributes, path): 71 | current = attributes 72 | for segment in path.split("."): 73 | if type(current) is dict and segment in current: 74 | current = current[segment] 75 | else: 76 | return None 77 | return current 78 | 79 | def evalConditionValue(conditionValue, attributeValue, savedGroups) -> bool: 80 | if type(conditionValue) is dict and isOperatorObject(conditionValue): 81 | for key, value in conditionValue.items(): 82 | if not evalOperatorCondition(key, attributeValue, value, savedGroups): 83 | return False 84 | return True 85 | return conditionValue == attributeValue 86 | 87 | def elemMatch(condition, attributeValue, savedGroups) -> bool: 88 | if not type(attributeValue) is list: 89 | return False 90 | 91 | for item in attributeValue: 92 | if isOperatorObject(condition): 93 | if evalConditionValue(condition, item, savedGroups): 94 | return True 95 | else: 96 | if evalCondition(item, 
condition, savedGroups): 97 | return True 98 | 99 | return False 100 | 101 | def compare(val1, val2) -> int: 102 | if (type(val1) is int or type(val1) is float) and not (type(val2) is int or type(val2) is float): 103 | if (val2 is None): 104 | val2 = 0 105 | else: 106 | val2 = float(val2) 107 | 108 | if (type(val2) is int or type(val2) is float) and not (type(val1) is int or type(val1) is float): 109 | if (val1 is None): 110 | val1 = 0 111 | else: 112 | val1 = float(val1) 113 | 114 | if val1 > val2: 115 | return 1 116 | if val1 < val2: 117 | return -1 118 | return 0 119 | 120 | def evalOperatorCondition(operator, attributeValue, conditionValue, savedGroups) -> bool: 121 | if operator == "$eq": 122 | try: 123 | return compare(attributeValue, conditionValue) == 0 124 | except Exception: 125 | return False 126 | elif operator == "$ne": 127 | try: 128 | return compare(attributeValue, conditionValue) != 0 129 | except Exception: 130 | return False 131 | elif operator == "$lt": 132 | try: 133 | return compare(attributeValue, conditionValue) < 0 134 | except Exception: 135 | return False 136 | elif operator == "$lte": 137 | try: 138 | return compare(attributeValue, conditionValue) <= 0 139 | except Exception: 140 | return False 141 | elif operator == "$gt": 142 | try: 143 | return compare(attributeValue, conditionValue) > 0 144 | except Exception: 145 | return False 146 | elif operator == "$gte": 147 | try: 148 | return compare(attributeValue, conditionValue) >= 0 149 | except Exception: 150 | return False 151 | elif operator == "$veq": 152 | return paddedVersionString(attributeValue) == paddedVersionString(conditionValue) 153 | elif operator == "$vne": 154 | return paddedVersionString(attributeValue) != paddedVersionString(conditionValue) 155 | elif operator == "$vlt": 156 | return paddedVersionString(attributeValue) < paddedVersionString(conditionValue) 157 | elif operator == "$vlte": 158 | return paddedVersionString(attributeValue) <= 
paddedVersionString(conditionValue) 159 | elif operator == "$vgt": 160 | return paddedVersionString(attributeValue) > paddedVersionString(conditionValue) 161 | elif operator == "$vgte": 162 | return paddedVersionString(attributeValue) >= paddedVersionString(conditionValue) 163 | elif operator == "$inGroup": 164 | if not type(conditionValue) is str: 165 | return False 166 | if not conditionValue in savedGroups: 167 | return False 168 | return isIn(savedGroups[conditionValue] or [], attributeValue) 169 | elif operator == "$notInGroup": 170 | if not type(conditionValue) is str: 171 | return False 172 | if not conditionValue in savedGroups: 173 | return True 174 | return not isIn(savedGroups[conditionValue] or [], attributeValue) 175 | elif operator == "$regex": 176 | try: 177 | r = re.compile(conditionValue) 178 | return bool(r.search(attributeValue)) 179 | except Exception: 180 | return False 181 | elif operator == "$in": 182 | if not type(conditionValue) is list: 183 | return False 184 | return isIn(conditionValue, attributeValue) 185 | elif operator == "$nin": 186 | if not type(conditionValue) is list: 187 | return False 188 | return not isIn(conditionValue, attributeValue) 189 | elif operator == "$elemMatch": 190 | return elemMatch(conditionValue, attributeValue, savedGroups) 191 | elif operator == "$size": 192 | if not (type(attributeValue) is list): 193 | return False 194 | return evalConditionValue(conditionValue, len(attributeValue), savedGroups) 195 | elif operator == "$all": 196 | if not (type(attributeValue) is list): 197 | return False 198 | for cond in conditionValue: 199 | passing = False 200 | for attr in attributeValue: 201 | if evalConditionValue(cond, attr, savedGroups): 202 | passing = True 203 | if not passing: 204 | return False 205 | return True 206 | elif operator == "$exists": 207 | if not conditionValue: 208 | return attributeValue is None 209 | return attributeValue is not None 210 | elif operator == "$type": 211 | return 
getType(attributeValue) == conditionValue 212 | elif operator == "$not": 213 | return not evalConditionValue(conditionValue, attributeValue, savedGroups) 214 | return False 215 | 216 | def paddedVersionString(input) -> str: 217 | # If input is a number, convert to a string 218 | if type(input) is int or type(input) is float: 219 | input = str(input) 220 | 221 | if not input or type(input) is not str: 222 | input = "0" 223 | 224 | # Remove build info and leading `v` if any 225 | input = re.sub(r"(^v|\+.*$)", "", input) 226 | # Split version into parts (both core version numbers and pre-release tags) 227 | # "v1.2.3-rc.1+build123" -> ["1","2","3","rc","1"] 228 | parts = re.split(r"[-.]", input) 229 | # If it's SemVer without a pre-release, add `~` to the end 230 | # ["1","0","0"] -> ["1","0","0","~"] 231 | # "~" is the largest ASCII character, so this will make "1.0.0" greater than "1.0.0-beta" for example 232 | if len(parts) == 3: 233 | parts.append("~") 234 | # Left pad each numeric part with spaces so string comparisons will work ("9">"10", but " 9"<"10") 235 | # Then, join back together into a single string 236 | return "-".join([v.rjust(5, " ") if re.match(r"^[0-9]+$", v) else v for v in parts]) 237 | 238 | 239 | def isIn(conditionValue, attributeValue) -> bool: 240 | if type(attributeValue) is list: 241 | return bool(set(conditionValue) & set(attributeValue)) 242 | return attributeValue in conditionValue 243 | 244 | def _getOrigHashValue( 245 | eval_context: EvaluationContext, 246 | attr: Optional[str] = "id", 247 | fallbackAttr: Optional[str] = None 248 | ) -> Tuple[str, str]: 249 | # attr = attr or "id" -- Fix for the flaky behavior of sticky bucket assignment 250 | actual_attr: str = attr if attr is not None else "" 251 | val = "" 252 | 253 | if actual_attr in eval_context.user.attributes: 254 | val = "" if eval_context.user.attributes[actual_attr] is None else eval_context.user.attributes[actual_attr] 255 | 256 | # If no match, try fallback 257 | if (not 
val or val == "") and fallbackAttr and eval_context.global_ctx.options.sticky_bucket_service: 258 | if fallbackAttr in eval_context.user.attributes: 259 | val = "" if eval_context.user.attributes[fallbackAttr] is None else eval_context.user.attributes[fallbackAttr] 260 | 261 | if not val or val != "": 262 | actual_attr = fallbackAttr 263 | 264 | return (actual_attr, val) 265 | 266 | def _getHashValue(eval_context: EvaluationContext, attr: str = None, fallbackAttr: str = None) -> Tuple[str, str]: 267 | (attr, val) = _getOrigHashValue(attr=attr, fallbackAttr=fallbackAttr, eval_context=eval_context) 268 | return (attr, str(val)) 269 | 270 | def _isIncludedInRollout( 271 | seed: str, 272 | eval_context: EvaluationContext, 273 | hashAttribute: str = None, 274 | fallbackAttribute: str = None, 275 | range: Tuple[float, float] = None, 276 | coverage: float = None, 277 | hashVersion: int = None 278 | ) -> bool: 279 | if coverage is None and range is None: 280 | return True 281 | 282 | if coverage == 0 and range is None: 283 | return False 284 | 285 | (_, hash_value) = _getHashValue(attr=hashAttribute, fallbackAttr=fallbackAttribute, eval_context=eval_context) 286 | if hash_value == "": 287 | return False 288 | 289 | n = gbhash(seed, hash_value, hashVersion or 1) 290 | if n is None: 291 | return False 292 | 293 | if range: 294 | return inRange(n, range) 295 | elif coverage is not None: 296 | return n <= coverage 297 | 298 | return True 299 | 300 | def _isFilteredOut(filters: List[Filter], eval_context: EvaluationContext) -> bool: 301 | for filter in filters: 302 | (_, hash_value) = _getHashValue(attr=filter.get("attribute", "id"), eval_context=eval_context) 303 | if hash_value == "": 304 | return False 305 | 306 | n = gbhash(filter.get("seed", ""), hash_value, filter.get("hashVersion", 2)) 307 | if n is None: 308 | return False 309 | 310 | filtered = False 311 | for range in filter["ranges"]: 312 | if inRange(n, range): 313 | filtered = True 314 | break 315 | if not 
filtered: 316 | return True 317 | return False 318 | 319 | 320 | def fnv1a32(str: str) -> int: 321 | hval = 0x811C9DC5 322 | prime = 0x01000193 323 | uint32_max = 2 ** 32 324 | for s in str: 325 | hval = hval ^ ord(s) 326 | hval = (hval * prime) % uint32_max 327 | return hval 328 | 329 | def inNamespace(userId: str, namespace: Tuple[str, float, float]) -> bool: 330 | n = gbhash("__" + namespace[0], userId, 1) 331 | if n is None: 332 | return False 333 | return namespace[1] <= n < namespace[2] 334 | 335 | def gbhash(seed: str, value: str, version: int) -> Optional[float]: 336 | if version == 2: 337 | n = fnv1a32(str(fnv1a32(seed + value))) 338 | return (n % 10000) / 10000 339 | if version == 1: 340 | n = fnv1a32(value + seed) 341 | return (n % 1000) / 1000 342 | return None 343 | 344 | def inRange(n: float, range: Tuple[float, float]) -> bool: 345 | return range[0] <= n < range[1] 346 | 347 | def chooseVariation(n: float, ranges: List[Tuple[float, float]]) -> int: 348 | for i, r in enumerate(ranges): 349 | if inRange(n, r): 350 | return i 351 | return -1 352 | 353 | def getQueryStringOverride(id: str, url: str, numVariations: int) -> Optional[int]: 354 | res = urlparse(url) 355 | if not res.query: 356 | return None 357 | qs = parse_qs(res.query) 358 | if id not in qs: 359 | return None 360 | variation = qs[id][0] 361 | if variation is None or not variation.isdigit(): 362 | return None 363 | varId = int(variation) 364 | if varId < 0 or varId >= numVariations: 365 | return None 366 | return varId 367 | 368 | def _urlIsValid(url: Optional[str], pattern: str) -> bool: 369 | if not url: # it was self._url! Ignored the param passed in. 
370 | return False 371 | 372 | try: 373 | r = re.compile(pattern) 374 | if r.search(url): 375 | return True 376 | 377 | pathOnly = re.sub(r"^[^/]*/", "/", re.sub(r"^https?:\/\/", "", url)) 378 | if r.search(pathOnly): 379 | return True 380 | return False 381 | except Exception: 382 | return True 383 | 384 | def getEqualWeights(numVariations: int) -> List[float]: 385 | if numVariations < 1: 386 | return [] 387 | return [1 / numVariations for _ in range(numVariations)] 388 | 389 | 390 | def getBucketRanges( 391 | numVariations: int, coverage: float = 1, weights: List[float] = None 392 | ) -> List[Tuple[float, float]]: 393 | if coverage < 0: 394 | coverage = 0 395 | if coverage > 1: 396 | coverage = 1 397 | if weights is None: 398 | weights = getEqualWeights(numVariations) 399 | if len(weights) != numVariations: 400 | weights = getEqualWeights(numVariations) 401 | if sum(weights) < 0.99 or sum(weights) > 1.01: 402 | weights = getEqualWeights(numVariations) 403 | 404 | cumulative: float = 0 405 | ranges = [] 406 | for w in weights: 407 | start = cumulative 408 | cumulative += w 409 | ranges.append((start, start + coverage * w)) 410 | 411 | return ranges 412 | 413 | def eval_feature( 414 | key: str, 415 | evalContext: EvaluationContext = None, 416 | callback_subscription: Callable[[Experiment, Result], None] = None 417 | ) -> FeatureResult: 418 | """Core feature evaluation logic as a standalone function""" 419 | 420 | if evalContext is None: 421 | raise ValueError("evalContext is required - eval_feature") 422 | 423 | if key not in evalContext.global_ctx.features: 424 | logger.warning("Unknown feature %s", key) 425 | return FeatureResult(None, "unknownFeature") 426 | 427 | if key in evalContext.stack.evaluated_features: 428 | logger.warning("Cyclic prerequisite detected, stack: %s", evalContext.stack.evaluated_features) 429 | return FeatureResult(None, "cyclicPrerequisite") 430 | 431 | evalContext.stack.evaluated_features.add(key) 432 | 433 | feature = 
evalContext.global_ctx.features[key] 434 | 435 | evaluated_features = evalContext.stack.evaluated_features.copy() 436 | 437 | for rule in feature.rules: 438 | # Reset the stack for each rule 439 | evalContext.stack.evaluated_features = evaluated_features.copy() 440 | 441 | if (rule.parentConditions): 442 | prereq_res = eval_prereqs(parentConditions=rule.parentConditions, evalContext=evalContext) 443 | if prereq_res == "gate": 444 | logger.debug("Top-level prerequisite failed, return None, feature %s", key) 445 | return FeatureResult(None, "prerequisite") 446 | if prereq_res == "cyclic": 447 | # Warning already logged in this case 448 | return FeatureResult(None, "cyclicPrerequisite") 449 | if prereq_res == "fail": 450 | logger.debug("Skip rule because of failing prerequisite, feature %s", key) 451 | continue 452 | 453 | if rule.condition: 454 | if not evalCondition(evalContext.user.attributes, rule.condition, evalContext.global_ctx.saved_groups): 455 | logger.debug( 456 | "Skip rule because of failed condition, feature %s", key 457 | ) 458 | continue 459 | if rule.filters: 460 | if _isFilteredOut(rule.filters, evalContext): 461 | logger.debug( 462 | "Skip rule because of filters/namespaces, feature %s", key 463 | ) 464 | continue 465 | if rule.force is not None: 466 | if not _isIncludedInRollout( 467 | seed=rule.seed or key, 468 | hashAttribute=rule.hashAttribute, 469 | fallbackAttribute=rule.fallbackAttribute, 470 | range=rule.range, 471 | coverage=rule.coverage, 472 | hashVersion=rule.hashVersion, 473 | eval_context=evalContext 474 | ): 475 | logger.debug( 476 | "Skip rule because user not included in percentage rollout, feature %s", 477 | key, 478 | ) 479 | continue 480 | 481 | logger.debug("Force value from rule, feature %s", key) 482 | return FeatureResult(rule.force, "force", ruleId=rule.id) 483 | 484 | if rule.variations is None: 485 | logger.warning("Skip invalid rule, feature %s", key) 486 | continue 487 | 488 | exp = Experiment( 489 | key=rule.key or key, 
490 | variations=rule.variations, 491 | coverage=rule.coverage, 492 | weights=rule.weights, 493 | hashAttribute=rule.hashAttribute, 494 | fallbackAttribute=rule.fallbackAttribute, 495 | namespace=rule.namespace, 496 | hashVersion=rule.hashVersion, 497 | meta=rule.meta, 498 | ranges=rule.ranges, 499 | name=rule.name, 500 | phase=rule.phase, 501 | seed=rule.seed, 502 | filters=rule.filters, 503 | condition=rule.condition, 504 | disableStickyBucketing=rule.disableStickyBucketing, 505 | bucketVersion=rule.bucketVersion, 506 | minBucketVersion=rule.minBucketVersion, 507 | ) 508 | 509 | result = run_experiment(experiment=exp, featureId=key, evalContext=evalContext) 510 | 511 | if callback_subscription: 512 | callback_subscription(exp, result) 513 | 514 | if not result.inExperiment: 515 | logger.debug( 516 | "Skip rule because user not included in experiment, feature %s", key 517 | ) 518 | continue 519 | 520 | if result.passthrough: 521 | logger.debug("Continue to next rule, feature %s", key) 522 | continue 523 | 524 | logger.debug("Assign value from experiment, feature %s", key) 525 | return FeatureResult( 526 | result.value, "experiment", exp, result, ruleId=rule.id 527 | ) 528 | 529 | logger.debug("Use default value for feature %s", key) 530 | return FeatureResult(feature.defaultValue, "defaultValue") 531 | 532 | def eval_prereqs(parentConditions: List[dict], evalContext: EvaluationContext) -> str: 533 | evaluated_features = evalContext.stack.evaluated_features.copy() 534 | 535 | for parentCondition in parentConditions: 536 | # Reset the stack in each iteration 537 | evalContext.stack.evaluated_features = evaluated_features.copy() 538 | 539 | parentRes = eval_feature(key=parentCondition.get("id", None), evalContext=evalContext) 540 | 541 | if parentRes.source == "cyclicPrerequisite": 542 | return "cyclic" 543 | 544 | if not evalCondition({'value': parentRes.value}, parentCondition.get("condition", None), evalContext.global_ctx.saved_groups): 545 | if 
parentCondition.get("gate", False): 546 | return "gate" 547 | return "fail" 548 | return "pass" 549 | 550 | def _get_sticky_bucket_experiment_key(experiment_key: str, bucket_version: int = 0) -> str: 551 | return experiment_key + "__" + str(bucket_version) 552 | 553 | def _get_sticky_bucket_assignments(evalContext: EvaluationContext, 554 | attr: str = None, 555 | fallback: str = None) -> Dict[str, str]: 556 | merged: Dict[str, str] = {} 557 | 558 | # Search for docs stored for attribute(id) 559 | _, hashValue = _getHashValue(attr=attr, eval_context=evalContext) 560 | key = f"{attr}||{hashValue}" 561 | if key in evalContext.user.sticky_bucket_assignment_docs: 562 | merged = evalContext.user.sticky_bucket_assignment_docs[key].get("assignments", {}) 563 | 564 | # Search for docs stored for fallback attribute 565 | if fallback: 566 | _, hashValue = _getHashValue(fallbackAttr=fallback, eval_context=evalContext) 567 | key = f"{fallback}||{hashValue}" 568 | if key in evalContext.user.sticky_bucket_assignment_docs: 569 | # Merge the fallback assignments, but don't overwrite existing ones 570 | for k, v in evalContext.user.sticky_bucket_assignment_docs[key].get("assignments", {}).items(): 571 | if k not in merged: 572 | merged[k] = v 573 | 574 | return merged 575 | 576 | def _is_blocked( 577 | assignments: Dict[str, str], 578 | experiment_key: str, 579 | min_bucket_version: int 580 | ) -> bool: 581 | if min_bucket_version > 0: 582 | for i in range(min_bucket_version): 583 | blocked_key = _get_sticky_bucket_experiment_key(experiment_key, i) 584 | if blocked_key in assignments: 585 | return True 586 | return False 587 | 588 | def _get_sticky_bucket_variation( 589 | experiment_key: str, 590 | evalContext: EvaluationContext, 591 | bucket_version: int = None, 592 | min_bucket_version: int = None, 593 | meta: List[VariationMeta] = None, 594 | hash_attribute: str = None, 595 | fallback_attribute: str = None, 596 | ) -> dict: 597 | bucket_version = bucket_version or 0 598 | 
min_bucket_version = min_bucket_version or 0 599 | meta = meta or [] 600 | 601 | id = _get_sticky_bucket_experiment_key(experiment_key, bucket_version) 602 | 603 | assignments = _get_sticky_bucket_assignments(attr=hash_attribute, fallback=fallback_attribute, evalContext=evalContext) 604 | if _is_blocked(assignments, experiment_key, min_bucket_version): 605 | return { 606 | 'variation': -1, 607 | 'versionIsBlocked': True 608 | } 609 | 610 | variation_key = assignments.get(id, None) 611 | if not variation_key: 612 | return { 613 | 'variation': -1 614 | } 615 | 616 | # Find the key in meta 617 | variation = next((i for i, v in enumerate(meta) if v.get("key") == variation_key), -1) 618 | if variation < 0: 619 | return { 620 | 'variation': -1 621 | } 622 | 623 | return {'variation': variation} 624 | 625 | def run_experiment(experiment: Experiment, 626 | featureId: Optional[str] = None, 627 | evalContext: EvaluationContext = None, 628 | tracking_cb: Callable[[Experiment, Result], None] = None 629 | ) -> Result: 630 | if evalContext is None: 631 | raise ValueError("evalContext is required - run_experiment") 632 | # 1. If experiment has less than 2 variations, return immediately 633 | if len(experiment.variations) < 2: 634 | logger.warning( 635 | "Experiment %s has less than 2 variations, skip", experiment.key 636 | ) 637 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 638 | # 2. If growthbook is disabled, return immediately 639 | if not evalContext.global_ctx.options.enabled: 640 | logger.debug( 641 | "Skip experiment %s because GrowthBook is disabled", experiment.key 642 | ) 643 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 644 | # 2.5. If the experiment props have been overridden, merge them in 645 | if evalContext.user.overrides.get(experiment.key, None): 646 | experiment.update(evalContext.user.overrides[experiment.key]) 647 | # 3. 
If experiment is forced via a querystring in the url 648 | qs = getQueryStringOverride( 649 | experiment.key, evalContext.user.url, len(experiment.variations) 650 | ) 651 | if qs is not None: 652 | logger.debug( 653 | "Force variation %d from URL querystring, experiment %s", 654 | qs, 655 | experiment.key, 656 | ) 657 | return _getExperimentResult(experiment=experiment, variationId=qs, featureId=featureId, evalContext=evalContext) 658 | # 4. If variation is forced in the context 659 | if evalContext.user.forced_variations.get(experiment.key, None) is not None: 660 | logger.debug( 661 | "Force variation %d from GrowthBook context, experiment %s", 662 | evalContext.user.forced_variations[experiment.key], 663 | experiment.key, 664 | ) 665 | return _getExperimentResult( 666 | experiment=experiment, variationId=evalContext.user.forced_variations[experiment.key], featureId=featureId, evalContext=evalContext 667 | ) 668 | # 5. If experiment is a draft or not active, return immediately 669 | if experiment.status == "draft" or not experiment.active: 670 | logger.debug("Experiment %s is not active, skip", experiment.key) 671 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 672 | 673 | # 6. 
Get the user hash attribute and value 674 | (hashAttribute, hashValue) = _getHashValue(attr=experiment.hashAttribute, fallbackAttr=experiment.fallbackAttribute, eval_context=evalContext) 675 | if not hashValue: 676 | logger.debug( 677 | "Skip experiment %s because user's hashAttribute value is empty", 678 | experiment.key, 679 | ) 680 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 681 | 682 | assigned = -1 683 | 684 | found_sticky_bucket = False 685 | sticky_bucket_version_is_blocked = False 686 | if evalContext.global_ctx.options.sticky_bucket_service and not experiment.disableStickyBucketing: 687 | sticky_bucket = _get_sticky_bucket_variation( 688 | experiment_key=experiment.key, 689 | bucket_version=experiment.bucketVersion, 690 | min_bucket_version=experiment.minBucketVersion, 691 | meta=experiment.meta, 692 | hash_attribute=experiment.hashAttribute, 693 | fallback_attribute=experiment.fallbackAttribute, 694 | evalContext=evalContext 695 | ) 696 | found_sticky_bucket = sticky_bucket.get('variation', 0) >= 0 697 | assigned = sticky_bucket.get('variation', 0) 698 | sticky_bucket_version_is_blocked = sticky_bucket.get('versionIsBlocked', False) 699 | 700 | if found_sticky_bucket: 701 | logger.debug("Found sticky bucket for experiment %s, assigning sticky variation %s", experiment.key, assigned) 702 | 703 | # Some checks are not needed if we already have a sticky bucket 704 | if not found_sticky_bucket: 705 | # 7. 
Filtered out / not in namespace 706 | if experiment.filters: 707 | if _isFilteredOut(experiment.filters, evalContext): 708 | logger.debug( 709 | "Skip experiment %s because of filters/namespaces", experiment.key 710 | ) 711 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 712 | elif experiment.namespace and not inNamespace(hashValue, experiment.namespace): 713 | logger.debug("Skip experiment %s because of namespace", experiment.key) 714 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 715 | 716 | # 7.5. If experiment has an include property 717 | if experiment.include: 718 | try: 719 | if not experiment.include(): 720 | logger.debug( 721 | "Skip experiment %s because include() returned false", 722 | experiment.key, 723 | ) 724 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 725 | except Exception: 726 | logger.warning( 727 | "Skip experiment %s because include() raised an Exception", 728 | experiment.key, 729 | ) 730 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 731 | 732 | # 8. 
Exclude if condition is false 733 | if experiment.condition and not evalCondition( 734 | evalContext.user.attributes, experiment.condition, evalContext.global_ctx.saved_groups 735 | ): 736 | logger.debug( 737 | "Skip experiment %s because user failed the condition", experiment.key 738 | ) 739 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 740 | 741 | # 8.05 Exclude if parent conditions are not met 742 | if (experiment.parentConditions): 743 | prereq_res = eval_prereqs(parentConditions=experiment.parentConditions, evalContext=evalContext) 744 | if prereq_res == "gate" or prereq_res == "fail": 745 | logger.debug("Skip experiment %s because of failing prerequisite", experiment.key) 746 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 747 | if prereq_res == "cyclic": 748 | logger.debug("Skip experiment %s because of cyclic prerequisite", experiment.key) 749 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 750 | 751 | # 8.1. Make sure user is in a matching group 752 | if experiment.groups and len(experiment.groups): 753 | expGroups = evalContext.user.groups or {} 754 | matched = False 755 | for group in experiment.groups: 756 | if expGroups[group]: 757 | matched = True 758 | if not matched: 759 | logger.debug( 760 | "Skip experiment %s because user not in required group", 761 | experiment.key, 762 | ) 763 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 764 | 765 | # The following apply even when in a sticky bucket 766 | 767 | # 8.2. 
If experiment.url is set, see if it's valid 768 | if experiment.url: 769 | if not _urlIsValid(url=evalContext.global_ctx.options.url, pattern=experiment.url): 770 | logger.debug( 771 | "Skip experiment %s because current URL is not targeted", 772 | experiment.key, 773 | ) 774 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 775 | 776 | # 9. Get bucket ranges and choose variation 777 | n = gbhash( 778 | experiment.seed or experiment.key, hashValue, experiment.hashVersion or 1 779 | ) 780 | if n is None: 781 | logger.warning( 782 | "Skip experiment %s because of invalid hashVersion", experiment.key 783 | ) 784 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 785 | 786 | if not found_sticky_bucket: 787 | c = experiment.coverage 788 | ranges = experiment.ranges or getBucketRanges( 789 | len(experiment.variations), c if c is not None else 1, experiment.weights 790 | ) 791 | assigned = chooseVariation(n, ranges) 792 | 793 | # Unenroll if any prior sticky buckets are blocked by version 794 | if sticky_bucket_version_is_blocked: 795 | logger.debug("Skip experiment %s because sticky bucket version is blocked", experiment.key) 796 | return _getExperimentResult(experiment=experiment, featureId=featureId, stickyBucketUsed=True, evalContext=evalContext) 797 | 798 | # 10. Return if not in experiment 799 | if assigned < 0: 800 | logger.debug( 801 | "Skip experiment %s because user is not included in the rollout", 802 | experiment.key, 803 | ) 804 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 805 | 806 | # 11. 
If experiment is forced, return immediately 807 | if experiment.force is not None: 808 | logger.debug( 809 | "Force variation %d in experiment %s", experiment.force, experiment.key 810 | ) 811 | return _getExperimentResult( 812 | experiment=experiment, variationId=experiment.force, featureId=featureId, evalContext=evalContext 813 | ) 814 | 815 | # 12. Exclude if in QA mode 816 | if evalContext.global_ctx.options.qa_mode: 817 | logger.debug("Skip experiment %s because of QA Mode", experiment.key) 818 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 819 | 820 | # 12.5. If experiment is stopped, return immediately 821 | if experiment.status == "stopped": 822 | logger.debug("Skip experiment %s because it is stopped", experiment.key) 823 | return _getExperimentResult(experiment=experiment, featureId=featureId, evalContext=evalContext) 824 | 825 | # 13. Build the result object 826 | result = _getExperimentResult( 827 | experiment=experiment, variationId=assigned, hashUsed=True, featureId=featureId, bucket=n, stickyBucketUsed=found_sticky_bucket, evalContext=evalContext 828 | ) 829 | 830 | # 13.5 Persist sticky bucket 831 | if evalContext.global_ctx.options.sticky_bucket_service and not experiment.disableStickyBucketing: 832 | assignment = {} 833 | assignment[_get_sticky_bucket_experiment_key( 834 | experiment.key, 835 | experiment.bucketVersion 836 | )] = result.key 837 | 838 | data = _generate_sticky_bucket_assignment_doc( 839 | attribute_name=hashAttribute, 840 | attribute_value=hashValue, 841 | assignments=assignment, 842 | evalContext=evalContext 843 | ) 844 | doc = data.get("doc", None) 845 | if doc and data.get('changed', False): 846 | if not evalContext.user.sticky_bucket_assignment_docs: 847 | evalContext.user.sticky_bucket_assignment_docs = {} 848 | evalContext.user.sticky_bucket_assignment_docs[data.get('key')] = doc 849 | evalContext.global_ctx.options.sticky_bucket_service.save_assignments(doc) 850 | 851 | # 
14. Fire the tracking callback if set 852 | if tracking_cb: 853 | tracking_cb(experiment, result) 854 | 855 | # 15. Return the result 856 | logger.debug("Assigned variation %d in experiment %s", assigned, experiment.key) 857 | return result 858 | 859 | def _generate_sticky_bucket_assignment_doc(attribute_name: str, attribute_value: str, assignments: dict, evalContext: EvaluationContext): 860 | key = attribute_name + "||" + attribute_value 861 | existing_assignments = evalContext.user.sticky_bucket_assignment_docs.get(key, {}).get("assignments", {}) 862 | 863 | new_assignments = {**existing_assignments, **assignments} 864 | 865 | # Compare JSON strings to see if they have changed 866 | existing_json = json.dumps(existing_assignments, sort_keys=True) 867 | new_json = json.dumps(new_assignments, sort_keys=True) 868 | changed = existing_json != new_json 869 | 870 | return { 871 | 'key': key, 872 | 'doc': { 873 | 'attributeName': attribute_name, 874 | 'attributeValue': attribute_value, 875 | 'assignments': new_assignments 876 | }, 877 | 'changed': changed 878 | } 879 | 880 | def _getExperimentResult( 881 | experiment: Experiment, 882 | evalContext: EvaluationContext, 883 | variationId: int = -1, 884 | hashUsed: bool = False, 885 | featureId: str = None, 886 | bucket: float = None, 887 | stickyBucketUsed: bool = False 888 | ) -> Result: 889 | inExperiment = True 890 | if variationId < 0 or variationId > len(experiment.variations) - 1: 891 | variationId = 0 892 | inExperiment = False 893 | 894 | meta = None 895 | if experiment.meta: 896 | meta = experiment.meta[variationId] 897 | 898 | (hashAttribute, hashValue) = _getOrigHashValue(attr=experiment.hashAttribute, 899 | fallbackAttr=experiment.fallbackAttribute, 900 | eval_context=evalContext) 901 | 902 | return Result( 903 | featureId=featureId, 904 | inExperiment=inExperiment, 905 | variationId=variationId, 906 | value=experiment.variations[variationId], 907 | hashUsed=hashUsed, 908 | hashAttribute=hashAttribute, 909 | 
hashValue=hashValue, 910 | meta=meta, 911 | bucket=bucket, 912 | stickyBucketUsed=stickyBucketUsed 913 | ) -------------------------------------------------------------------------------- /growthbook/growthbook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | This is the Python client library for GrowthBook, the open-source 4 | feature flagging and A/B testing platform. 5 | More info at https://www.growthbook.io 6 | """ 7 | 8 | import sys 9 | import json 10 | import threading 11 | import logging 12 | 13 | from abc import ABC, abstractmethod 14 | from typing import Optional, Any, Set, Tuple, List, Dict, Callable 15 | 16 | from .common_types import ( EvaluationContext, 17 | Experiment, 18 | FeatureResult, 19 | Feature, 20 | GlobalContext, 21 | Options, 22 | Result, StackContext, 23 | UserContext, 24 | AbstractStickyBucketService, 25 | FeatureRule 26 | ) 27 | 28 | # Only require typing_extensions if using Python 3.7 or earlier 29 | if sys.version_info >= (3, 8): 30 | from typing import TypedDict 31 | else: 32 | from typing_extensions import TypedDict 33 | 34 | from base64 import b64decode 35 | from time import time 36 | import aiohttp 37 | import asyncio 38 | 39 | from aiohttp.client_exceptions import ClientConnectorError, ClientResponseError, ClientPayloadError 40 | from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 41 | from cryptography.hazmat.primitives import padding 42 | from urllib3 import PoolManager 43 | 44 | from .core import _getHashValue, eval_feature as core_eval_feature, run_experiment 45 | 46 | logger = logging.getLogger("growthbook") 47 | 48 | def decrypt(encrypted_str: str, key_str: str) -> str: 49 | iv_str, ct_str = encrypted_str.split(".", 2) 50 | 51 | key = b64decode(key_str) 52 | iv = b64decode(iv_str) 53 | ct = b64decode(ct_str) 54 | 55 | cipher = Cipher(algorithms.AES128(key), modes.CBC(iv)) 56 | decryptor = cipher.decryptor() 57 | 58 | decrypted 
= decryptor.update(ct) + decryptor.finalize() 59 | 60 | unpadder = padding.PKCS7(128).unpadder() 61 | bytestring = unpadder.update(decrypted) + unpadder.finalize() 62 | 63 | return bytestring.decode("utf-8") 64 | 65 | class AbstractFeatureCache(ABC): 66 | @abstractmethod 67 | def get(self, key: str) -> Optional[Dict]: 68 | pass 69 | 70 | @abstractmethod 71 | def set(self, key: str, value: Dict, ttl: int) -> None: 72 | pass 73 | 74 | def clear(self) -> None: 75 | pass 76 | 77 | 78 | class CacheEntry(object): 79 | def __init__(self, value: Dict, ttl: int) -> None: 80 | self.value = value 81 | self.ttl = ttl 82 | self.expires = time() + ttl 83 | 84 | def update(self, value: Dict): 85 | self.value = value 86 | self.expires = time() + self.ttl 87 | 88 | 89 | class InMemoryFeatureCache(AbstractFeatureCache): 90 | def __init__(self) -> None: 91 | self.cache: Dict[str, CacheEntry] = {} 92 | 93 | def get(self, key: str) -> Optional[Dict]: 94 | if key in self.cache: 95 | entry = self.cache[key] 96 | if entry.expires >= time(): 97 | return entry.value 98 | return None 99 | 100 | def set(self, key: str, value: Dict, ttl: int) -> None: 101 | if key in self.cache: 102 | self.cache[key].update(value) 103 | self.cache[key] = CacheEntry(value, ttl) 104 | 105 | def clear(self) -> None: 106 | self.cache.clear() 107 | 108 | class InMemoryStickyBucketService(AbstractStickyBucketService): 109 | def __init__(self) -> None: 110 | self.docs: Dict[str, Dict] = {} 111 | 112 | def get_assignments(self, attributeName: str, attributeValue: str) -> Optional[Dict]: 113 | return self.docs.get(self.get_key(attributeName, attributeValue), None) 114 | 115 | def save_assignments(self, doc: Dict) -> None: 116 | self.docs[self.get_key(doc["attributeName"], doc["attributeValue"])] = doc 117 | 118 | def destroy(self) -> None: 119 | self.docs.clear() 120 | 121 | 122 | class SSEClient: 123 | def __init__(self, api_host, client_key, on_event, reconnect_delay=5, headers=None): 124 | self.api_host = api_host 
125 | self.client_key = client_key 126 | 127 | self.on_event = on_event 128 | self.reconnect_delay = reconnect_delay 129 | 130 | self._sse_session = None 131 | self._sse_thread = None 132 | self._loop = None 133 | 134 | self.is_running = False 135 | 136 | self.headers = { 137 | "Accept": "application/json; q=0.5, text/event-stream", 138 | "Cache-Control": "no-cache", 139 | } 140 | 141 | if headers: 142 | self.headers.update(headers) 143 | 144 | def connect(self): 145 | if self.is_running: 146 | logger.debug("Streaming session is already running.") 147 | return 148 | 149 | self.is_running = True 150 | self._sse_thread = threading.Thread(target=self._run_sse_channel) 151 | self._sse_thread.start() 152 | 153 | def disconnect(self): 154 | self.is_running = False 155 | if self._loop and self._loop.is_running(): 156 | future = asyncio.run_coroutine_threadsafe(self._stop_session(), self._loop) 157 | try: 158 | future.result() 159 | except Exception as e: 160 | logger.error(f"Streaming disconnect error: {e}") 161 | 162 | if self._sse_thread: 163 | self._sse_thread.join(timeout=5) 164 | 165 | logger.debug("Streaming session disconnected") 166 | 167 | def _get_sse_url(self, api_host: str, client_key: str) -> str: 168 | api_host = (api_host or "https://cdn.growthbook.io").rstrip("/") 169 | return f"{api_host}/sub/{client_key}" 170 | 171 | async def _init_session(self): 172 | url = self._get_sse_url(self.api_host, self.client_key) 173 | 174 | while self.is_running: 175 | try: 176 | async with aiohttp.ClientSession(headers=self.headers) as session: 177 | self._sse_session = session 178 | 179 | async with session.get(url) as response: 180 | response.raise_for_status() 181 | await self._process_response(response) 182 | except ClientResponseError as e: 183 | logger.error(f"Streaming error, closing connection: {e.status} {e.message}") 184 | self.is_running = False 185 | break 186 | except (ClientConnectorError, ClientPayloadError) as e: 187 | logger.error(f"Streaming error: {e}") 
188 | if not self.is_running: 189 | break 190 | await self._wait_for_reconnect() 191 | except TimeoutError: 192 | logger.warning(f"Streaming connection timed out after {self.timeout} seconds.") 193 | await self._wait_for_reconnect() 194 | except asyncio.CancelledError: 195 | logger.debug("Streaming was cancelled.") 196 | break 197 | finally: 198 | await self._close_session() 199 | 200 | async def _process_response(self, response): 201 | event_data = {} 202 | async for line in response.content: 203 | decoded_line = line.decode('utf-8').strip() 204 | if decoded_line.startswith("event:"): 205 | event_data['type'] = decoded_line[len("event:"):].strip() 206 | elif decoded_line.startswith("data:"): 207 | event_data['data'] = event_data.get('data', '') + f"\n{decoded_line[len('data:'):].strip()}" 208 | elif not decoded_line: 209 | if 'type' in event_data and 'data' in event_data: 210 | self.on_event(event_data) 211 | event_data = {} 212 | 213 | if 'type' in event_data and 'data' in event_data: 214 | self.on_event(event_data) 215 | 216 | async def _wait_for_reconnect(self): 217 | logger.debug(f"Attempting to reconnect streaming in {self.reconnect_delay}") 218 | await asyncio.sleep(self.reconnect_delay) 219 | 220 | async def _close_session(self): 221 | if self._sse_session: 222 | await self._sse_session.close() 223 | logger.debug("Streaming session closed.") 224 | 225 | def _run_sse_channel(self): 226 | self._loop = asyncio.new_event_loop() 227 | 228 | try: 229 | self._loop.run_until_complete(self._init_session()) 230 | except asyncio.CancelledError: 231 | pass 232 | finally: 233 | self._loop.run_until_complete(self._loop.shutdown_asyncgens()) 234 | self._loop.close() 235 | 236 | async def _stop_session(self): 237 | if self._sse_session: 238 | await self._sse_session.close() 239 | 240 | if self._loop and self._loop.is_running(): 241 | tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()] 242 | for task in tasks: 243 | task.cancel() 244 | try: 245 | 
await task 246 | except asyncio.CancelledError: 247 | pass 248 | 249 | class FeatureRepository(object): 250 | def __init__(self) -> None: 251 | self.cache: AbstractFeatureCache = InMemoryFeatureCache() 252 | self.http: Optional[PoolManager] = None 253 | self.sse_client: Optional[SSEClient] = None 254 | self._feature_update_callbacks: List[Callable[[Dict], None]] = [] 255 | 256 | def set_cache(self, cache: AbstractFeatureCache) -> None: 257 | self.cache = cache 258 | 259 | def clear_cache(self): 260 | self.cache.clear() 261 | 262 | def save_in_cache(self, key: str, res, ttl: int = 600): 263 | self.cache.set(key, res, ttl) 264 | 265 | def add_feature_update_callback(self, callback: Callable[[Dict], None]) -> None: 266 | """Add a callback to be notified when features are updated due to cache expiry""" 267 | if callback not in self._feature_update_callbacks: 268 | self._feature_update_callbacks.append(callback) 269 | 270 | def remove_feature_update_callback(self, callback: Callable[[Dict], None]) -> None: 271 | """Remove a feature update callback""" 272 | if callback in self._feature_update_callbacks: 273 | self._feature_update_callbacks.remove(callback) 274 | 275 | def _notify_feature_update_callbacks(self, features_data: Dict) -> None: 276 | """Notify all registered callbacks about feature updates""" 277 | for callback in self._feature_update_callbacks: 278 | try: 279 | callback(features_data) 280 | except Exception as e: 281 | logger.warning(f"Error in feature update callback: {e}") 282 | 283 | # Loads features with an in-memory cache in front using stale-while-revalidate approach 284 | def load_features( 285 | self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 600 286 | ) -> Optional[Dict]: 287 | if not client_key: 288 | raise ValueError("Must specify `client_key` to refresh features") 289 | 290 | key = api_host + "::" + client_key 291 | 292 | cached = self.cache.get(key) 293 | if not cached: 294 | res = self._fetch_features(api_host, 
client_key, decryption_key) 295 | if res is not None: 296 | self.cache.set(key, res, ttl) 297 | logger.debug("Fetched features from API, stored in cache") 298 | # Notify callbacks about fresh features 299 | self._notify_feature_update_callbacks(res) 300 | return res 301 | return cached 302 | 303 | async def load_features_async( 304 | self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 600 305 | ) -> Optional[Dict]: 306 | key = api_host + "::" + client_key 307 | 308 | cached = self.cache.get(key) 309 | if not cached: 310 | res = await self._fetch_features_async(api_host, client_key, decryption_key) 311 | if res is not None: 312 | self.cache.set(key, res, ttl) 313 | logger.debug("Fetched features from API, stored in cache") 314 | # Notify callbacks about fresh features 315 | self._notify_feature_update_callbacks(res) 316 | return res 317 | return cached 318 | 319 | # Perform the GET request (separate method for easy mocking) 320 | def _get(self, url: str): 321 | self.http = self.http or PoolManager() 322 | return self.http.request("GET", url) 323 | 324 | def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: 325 | try: 326 | r = self._get(self._get_features_url(api_host, client_key)) 327 | if r.status >= 400: 328 | logger.warning( 329 | "Failed to fetch features, received status code %d", r.status 330 | ) 331 | return None 332 | decoded = json.loads(r.data.decode("utf-8")) 333 | return decoded 334 | except Exception: 335 | logger.warning("Failed to decode feature JSON from GrowthBook API") 336 | return None 337 | 338 | async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optional[Dict]: 339 | try: 340 | url = self._get_features_url(api_host, client_key) 341 | async with aiohttp.ClientSession() as session: 342 | async with session.get(url) as response: 343 | if response.status >= 400: 344 | logger.warning("Failed to fetch features, received status code %d", response.status) 345 | return None 346 | 
decoded = await response.json() 347 | return decoded 348 | except aiohttp.ClientError as e: 349 | logger.warning(f"HTTP request failed: {e}") 350 | return None 351 | except Exception as e: 352 | logger.warning("Failed to decode feature JSON from GrowthBook API: %s", e) 353 | return None 354 | 355 | def decrypt_response(self, data, decryption_key: str): 356 | if "encryptedFeatures" in data: 357 | if not decryption_key: 358 | raise ValueError("Must specify decryption_key") 359 | try: 360 | decryptedFeatures = decrypt(data["encryptedFeatures"], decryption_key) 361 | data['features'] = json.loads(decryptedFeatures) 362 | del data['encryptedFeatures'] 363 | except Exception: 364 | logger.warning( 365 | "Failed to decrypt features from GrowthBook API response" 366 | ) 367 | return None 368 | elif "features" not in data: 369 | logger.warning("GrowthBook API response missing features") 370 | 371 | if "encryptedSavedGroups" in data: 372 | if not decryption_key: 373 | raise ValueError("Must specify decryption_key") 374 | try: 375 | decryptedFeatures = decrypt(data["encryptedSavedGroups"], decryption_key) 376 | data['savedGroups'] = json.loads(decryptedFeatures) 377 | del data['encryptedSavedGroups'] 378 | return data 379 | except Exception: 380 | logger.warning( 381 | "Failed to decrypt saved groups from GrowthBook API response" 382 | ) 383 | 384 | return data 385 | 386 | # Fetch features from the GrowthBook API 387 | def _fetch_features( 388 | self, api_host: str, client_key: str, decryption_key: str = "" 389 | ) -> Optional[Dict]: 390 | decoded = self._fetch_and_decode(api_host, client_key) 391 | if not decoded: 392 | return None 393 | 394 | data = self.decrypt_response(decoded, decryption_key) 395 | 396 | return data 397 | 398 | async def _fetch_features_async( 399 | self, api_host: str, client_key: str, decryption_key: str = "" 400 | ) -> Optional[Dict]: 401 | decoded = await self._fetch_and_decode_async(api_host, client_key) 402 | if not decoded: 403 | return None 404 | 
405 | data = self.decrypt_response(decoded, decryption_key) 406 | 407 | return data 408 | 409 | 410 | def startAutoRefresh(self, api_host, client_key, cb): 411 | if not client_key: 412 | raise ValueError("Must specify `client_key` to start features streaming") 413 | self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb) 414 | self.sse_client.connect() 415 | 416 | def stopAutoRefresh(self): 417 | self.sse_client.disconnect() 418 | 419 | @staticmethod 420 | def _get_features_url(api_host: str, client_key: str) -> str: 421 | api_host = (api_host or "https://cdn.growthbook.io").rstrip("/") 422 | return api_host + "/api/features/" + client_key 423 | 424 | 425 | # Singleton instance 426 | feature_repo = FeatureRepository() 427 | 428 | class GrowthBook(object): 429 | def __init__( 430 | self, 431 | enabled: bool = True, 432 | attributes: dict = {}, 433 | url: str = "", 434 | features: dict = {}, 435 | qa_mode: bool = False, 436 | on_experiment_viewed=None, 437 | api_host: str = "", 438 | client_key: str = "", 439 | decryption_key: str = "", 440 | cache_ttl: int = 600, 441 | forced_variations: dict = {}, 442 | sticky_bucket_service: AbstractStickyBucketService = None, 443 | sticky_bucket_identifier_attributes: List[str] = None, 444 | savedGroups: dict = {}, 445 | streaming: bool = False, 446 | # Deprecated args 447 | trackingCallback=None, 448 | qaMode: bool = False, 449 | user: dict = {}, 450 | groups: dict = {}, 451 | overrides: dict = {}, 452 | forcedVariations: dict = {}, 453 | ): 454 | self._enabled = enabled 455 | self._attributes = attributes 456 | self._url = url 457 | self._features: Dict[str, Feature] = {} 458 | self._saved_groups = savedGroups 459 | self._api_host = api_host 460 | self._client_key = client_key 461 | self._decryption_key = decryption_key 462 | self._cache_ttl = cache_ttl 463 | self.sticky_bucket_identifier_attributes = sticky_bucket_identifier_attributes 464 | self.sticky_bucket_service = 
sticky_bucket_service 465 | self._sticky_bucket_assignment_docs: dict = {} 466 | self._using_derived_sticky_bucket_attributes = not sticky_bucket_identifier_attributes 467 | self._sticky_bucket_attributes: Optional[dict] = None 468 | 469 | self._qaMode = qa_mode or qaMode 470 | self._trackingCallback = on_experiment_viewed or trackingCallback 471 | 472 | self._streaming = streaming 473 | 474 | # Deprecated args 475 | self._user = user 476 | self._groups = groups 477 | self._overrides = overrides 478 | self._forcedVariations = forced_variations or forcedVariations 479 | 480 | self._tracked: Dict[str, Any] = {} 481 | self._assigned: Dict[str, Any] = {} 482 | self._subscriptions: Set[Any] = set() 483 | 484 | self._global_ctx = GlobalContext( 485 | options=Options( 486 | url=self._url, 487 | api_host=self._api_host, 488 | client_key=self._client_key, 489 | decryption_key=self._decryption_key, 490 | cache_ttl=self._cache_ttl, 491 | sticky_bucket_service=self.sticky_bucket_service, 492 | sticky_bucket_identifier_attributes=self.sticky_bucket_identifier_attributes, 493 | enabled=self._enabled, 494 | qa_mode=self._qaMode 495 | ), 496 | features={}, 497 | saved_groups=self._saved_groups 498 | ) 499 | # Create a user context for the current user 500 | self._user_ctx: UserContext = UserContext( 501 | url=self._url, 502 | attributes=self._attributes, 503 | groups=self._groups, 504 | forced_variations=self._forcedVariations, 505 | overrides=self._overrides, 506 | sticky_bucket_assignment_docs=self._sticky_bucket_assignment_docs 507 | ) 508 | 509 | if features: 510 | self.setFeatures(features) 511 | 512 | # Register for automatic feature updates when cache expires 513 | if self._client_key: 514 | feature_repo.add_feature_update_callback(self._on_feature_update) 515 | 516 | if self._streaming: 517 | self.load_features() 518 | self.startAutoRefresh() 519 | 520 | def _on_feature_update(self, features_data: Dict) -> None: 521 | """Callback to handle automatic feature updates from 
FeatureRepository""" 522 | if features_data and "features" in features_data: 523 | self.set_features(features_data["features"]) 524 | if features_data and "savedGroups" in features_data: 525 | self._saved_groups = features_data["savedGroups"] 526 | 527 | def load_features(self) -> None: 528 | 529 | response = feature_repo.load_features( 530 | self._api_host, self._client_key, self._decryption_key, self._cache_ttl 531 | ) 532 | if response is not None and "features" in response.keys(): 533 | self.setFeatures(response["features"]) 534 | 535 | if response is not None and "savedGroups" in response: 536 | self._saved_groups = response["savedGroups"] 537 | 538 | async def load_features_async(self) -> None: 539 | if not self._client_key: 540 | raise ValueError("Must specify `client_key` to refresh features") 541 | 542 | features = await feature_repo.load_features_async( 543 | self._api_host, self._client_key, self._decryption_key, self._cache_ttl 544 | ) 545 | 546 | if features is not None: 547 | if "features" in features: 548 | self.setFeatures(features["features"]) 549 | if "savedGroups" in features: 550 | self._saved_groups = features["savedGroups"] 551 | feature_repo.save_in_cache(self._client_key, features, self._cache_ttl) 552 | 553 | def _features_event_handler(self, features): 554 | decoded = json.loads(features) 555 | if not decoded: 556 | return None 557 | 558 | data = feature_repo.decrypt_response(decoded, self._decryption_key) 559 | 560 | if data is not None: 561 | if "features" in data: 562 | self.setFeatures(data["features"]) 563 | if "savedGroups" in data: 564 | self._saved_groups = data["savedGroups"] 565 | feature_repo.save_in_cache(self._client_key, features, self._cache_ttl) 566 | 567 | def _dispatch_sse_event(self, event_data): 568 | event_type = event_data['type'] 569 | data = event_data['data'] 570 | if event_type == 'features-updated': 571 | self.load_features() 572 | elif event_type == 'features': 573 | self._features_event_handler(data) 574 | 575 
| 576 | def startAutoRefresh(self): 577 | if not self._client_key: 578 | raise ValueError("Must specify `client_key` to start features streaming") 579 | 580 | feature_repo.startAutoRefresh( 581 | api_host=self._api_host, 582 | client_key=self._client_key, 583 | cb=self._dispatch_sse_event 584 | ) 585 | 586 | def stopAutoRefresh(self): 587 | feature_repo.stopAutoRefresh() 588 | 589 | # @deprecated, use set_features 590 | def setFeatures(self, features: dict) -> None: 591 | return self.set_features(features) 592 | 593 | def set_features(self, features: dict) -> None: 594 | self._features = {} 595 | for key, feature in features.items(): 596 | if isinstance(feature, Feature): 597 | self._features[key] = feature 598 | else: 599 | self._features[key] = Feature( 600 | rules=feature.get("rules", []), 601 | defaultValue=feature.get("defaultValue", None), 602 | ) 603 | # Update the global context with the new features and saved groups 604 | self._global_ctx.features = self._features 605 | self._global_ctx.saved_groups = self._saved_groups 606 | self.refresh_sticky_buckets() 607 | 608 | # @deprecated, use get_features 609 | def getFeatures(self) -> Dict[str, Feature]: 610 | return self.get_features() 611 | 612 | def get_features(self) -> Dict[str, Feature]: 613 | return self._features 614 | 615 | # @deprecated, use set_attributes 616 | def setAttributes(self, attributes: dict) -> None: 617 | return self.set_attributes(attributes) 618 | 619 | def set_attributes(self, attributes: dict) -> None: 620 | self._attributes = attributes 621 | self.refresh_sticky_buckets() 622 | 623 | # @deprecated, use get_attributes 624 | def getAttributes(self) -> dict: 625 | return self.get_attributes() 626 | 627 | def get_attributes(self) -> dict: 628 | return self._attributes 629 | 630 | def destroy(self) -> None: 631 | # Clean up feature update callback 632 | if self._client_key: 633 | feature_repo.remove_feature_update_callback(self._on_feature_update) 634 | 635 | self._subscriptions.clear() 
636 | self._tracked.clear() 637 | self._assigned.clear() 638 | self._trackingCallback = None 639 | self._forcedVariations.clear() 640 | self._overrides.clear() 641 | self._groups.clear() 642 | self._attributes.clear() 643 | self._features.clear() 644 | 645 | # @deprecated, use is_on 646 | def isOn(self, key: str) -> bool: 647 | return self.is_on(key) 648 | 649 | def is_on(self, key: str) -> bool: 650 | return self.evalFeature(key).on 651 | 652 | # @deprecated, use is_off 653 | def isOff(self, key: str) -> bool: 654 | return self.is_off(key) 655 | 656 | def is_off(self, key: str) -> bool: 657 | return self.evalFeature(key).off 658 | 659 | # @deprecated, use get_feature_value 660 | def getFeatureValue(self, key: str, fallback): 661 | return self.get_feature_value(key, fallback) 662 | 663 | def get_feature_value(self, key: str, fallback): 664 | res = self.evalFeature(key) 665 | return res.value if res.value is not None else fallback 666 | 667 | # @deprecated, use eval_feature 668 | def evalFeature(self, key: str) -> FeatureResult: 669 | return self.eval_feature(key) 670 | 671 | def _ensure_fresh_features(self) -> None: 672 | """Lazy refresh: Check cache expiry and refresh if needed, but only if client_key is provided""" 673 | 674 | if self._streaming or not self._client_key: 675 | return # Skip cache checks - SSE handles freshness for streaming users 676 | 677 | try: 678 | self.load_features() 679 | except Exception as e: 680 | logger.warning(f"Failed to refresh features: {e}") 681 | 682 | def _get_eval_context(self) -> EvaluationContext: 683 | # Lazy refresh: ensure features are fresh before evaluation 684 | self._ensure_fresh_features() 685 | 686 | # use the latest attributes for every evaluation. 687 | self._user_ctx.attributes = self._attributes 688 | self._user_ctx.url = self._url 689 | self._user_ctx.overrides = self._overrides 690 | # set the url for every evaluation. 
(unlikely to change) 691 | self._global_ctx.options.url = self._url 692 | return EvaluationContext( 693 | global_ctx = self._global_ctx, 694 | user = self._user_ctx, 695 | stack = StackContext(evaluated_features=set()) 696 | ) 697 | 698 | def eval_feature(self, key: str) -> FeatureResult: 699 | return core_eval_feature(key=key, 700 | evalContext=self._get_eval_context(), 701 | callback_subscription=self._fireSubscriptions 702 | ) 703 | 704 | # @deprecated, use get_all_results 705 | def getAllResults(self): 706 | return self.get_all_results() 707 | 708 | def get_all_results(self): 709 | return self._assigned.copy() 710 | 711 | def _fireSubscriptions(self, experiment: Experiment, result: Result): 712 | if experiment is None: 713 | return 714 | 715 | prev = self._assigned.get(experiment.key, None) 716 | if ( 717 | not prev 718 | or prev["result"].inExperiment != result.inExperiment 719 | or prev["result"].variationId != result.variationId 720 | ): 721 | self._assigned[experiment.key] = { 722 | "experiment": experiment, 723 | "result": result, 724 | } 725 | for cb in self._subscriptions: 726 | try: 727 | cb(experiment, result) 728 | except Exception: 729 | pass 730 | 731 | def run(self, experiment: Experiment) -> Result: 732 | # result = self._run(experiment) 733 | result = run_experiment(experiment=experiment, 734 | evalContext=self._get_eval_context(), 735 | tracking_cb=self._track 736 | ) 737 | 738 | self._fireSubscriptions(experiment, result) 739 | return result 740 | 741 | def subscribe(self, callback): 742 | self._subscriptions.add(callback) 743 | return lambda: self._subscriptions.remove(callback) 744 | 745 | def _track(self, experiment: Experiment, result: Result) -> None: 746 | if not self._trackingCallback: 747 | return None 748 | key = ( 749 | result.hashAttribute 750 | + str(result.hashValue) 751 | + experiment.key 752 | + str(result.variationId) 753 | ) 754 | if not self._tracked.get(key): 755 | try: 756 | self._trackingCallback(experiment=experiment, 
result=result) 757 | self._tracked[key] = True 758 | except Exception: 759 | pass 760 | 761 | def _derive_sticky_bucket_identifier_attributes(self) -> List[str]: 762 | attributes = set() 763 | for key, feature in self._features.items(): 764 | for rule in feature.rules: 765 | if rule.variations: 766 | attributes.add(rule.hashAttribute or "id") 767 | if rule.fallbackAttribute: 768 | attributes.add(rule.fallbackAttribute) 769 | return list(attributes) 770 | 771 | def _get_sticky_bucket_attributes(self) -> dict: 772 | attributes: Dict[str, str] = {} 773 | if self._using_derived_sticky_bucket_attributes: 774 | self.sticky_bucket_identifier_attributes = self._derive_sticky_bucket_identifier_attributes() 775 | 776 | if not self.sticky_bucket_identifier_attributes: 777 | return attributes 778 | 779 | for attr in self.sticky_bucket_identifier_attributes: 780 | _, hash_value = _getHashValue(attr=attr, eval_context=self._get_eval_context()) 781 | if hash_value: 782 | attributes[attr] = hash_value 783 | return attributes 784 | 785 | def refresh_sticky_buckets(self, force: bool = False) -> None: 786 | if not self.sticky_bucket_service: 787 | return 788 | 789 | attributes = self._get_sticky_bucket_attributes() 790 | if not force and attributes == self._sticky_bucket_attributes: 791 | logger.debug("Skipping refresh of sticky bucket assignments, no changes") 792 | return 793 | 794 | self._sticky_bucket_attributes = attributes 795 | self._sticky_bucket_assignment_docs = self.sticky_bucket_service.get_all_assignments(attributes) 796 | # Update the user context with the new sticky bucket assignment docs 797 | self._user_ctx.sticky_bucket_assignment_docs = self._sticky_bucket_assignment_docs 798 | -------------------------------------------------------------------------------- /growthbook/growthbook_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from dataclasses import dataclass, field 4 | import random 5 | 
import logging 6 | from typing import Any, Dict, List, Optional, Union, Callable, Awaitable 7 | from typing import Set 8 | import asyncio 9 | import threading 10 | import traceback 11 | from datetime import datetime 12 | from growthbook import FeatureRepository 13 | from contextlib import asynccontextmanager 14 | 15 | from .core import eval_feature as core_eval_feature, run_experiment 16 | from .common_types import ( 17 | Feature, 18 | GlobalContext, 19 | Options, 20 | Result, 21 | UserContext, 22 | EvaluationContext, 23 | StackContext, 24 | FeatureResult, 25 | FeatureRefreshStrategy, 26 | Experiment 27 | ) 28 | 29 | logger = logging.getLogger("growthbook.growthbook_client") 30 | 31 | class SingletonMeta(type): 32 | """Thread-safe implementation of Singleton pattern""" 33 | _instances: Dict[type, Any] = {} 34 | _lock = threading.Lock() 35 | 36 | def __call__(cls, *args, **kwargs): 37 | with cls._lock: 38 | if cls not in cls._instances: 39 | instance = super().__call__(*args, **kwargs) 40 | cls._instances[cls] = instance 41 | return cls._instances[cls] 42 | 43 | class BackoffStrategy: 44 | """Exponential backoff with jitter for failed requests""" 45 | def __init__( 46 | self, 47 | initial_delay: float = 1.0, 48 | max_delay: float = 60.0, 49 | multiplier: float = 2.0, 50 | jitter: float = 0.1 51 | ): 52 | self.initial_delay = initial_delay 53 | self.max_delay = max_delay 54 | self.multiplier = multiplier 55 | self.jitter = jitter 56 | self.current_delay = initial_delay 57 | self.attempt = 0 58 | 59 | def next_delay(self) -> float: 60 | """Calculate next delay with jitter""" 61 | delay = min( 62 | self.current_delay * (self.multiplier ** self.attempt), 63 | self.max_delay 64 | ) 65 | # Add random jitter 66 | jitter_amount = delay * self.jitter 67 | delay = delay + (random.random() * 2 - 1) * jitter_amount 68 | self.attempt += 1 69 | return max(delay, self.initial_delay) 70 | 71 | def reset(self) -> None: 72 | """Reset backoff state""" 73 | self.current_delay = 
self.initial_delay 74 | self.attempt = 0 75 | 76 | class WeakRefWrapper: 77 | """A wrapper class to allow weak references for otherwise non-weak-referenceable objects.""" 78 | def __init__(self, obj): 79 | self.obj = obj 80 | 81 | class FeatureCache: 82 | """Thread-safe feature cache""" 83 | def __init__(self): 84 | self._cache = { 85 | 'features': {}, 86 | 'savedGroups': {} 87 | } 88 | self._lock = threading.Lock() 89 | 90 | def update(self, features: Dict[str, Any], saved_groups: Dict[str, Any]) -> None: 91 | """Simple thread-safe update of cache with new API data""" 92 | with self._lock: 93 | self._cache['features'].update(features) 94 | self._cache['savedGroups'].update(saved_groups) 95 | 96 | def get_current_state(self) -> Dict[str, Any]: 97 | """Get current cache state""" 98 | with self._lock: 99 | return { 100 | "features": dict(self._cache['features']), 101 | "savedGroups": self._cache['savedGroups'] 102 | } 103 | 104 | class EnhancedFeatureRepository(FeatureRepository, metaclass=SingletonMeta): 105 | def __init__(self, api_host: str, client_key: str, decryption_key: str = "", cache_ttl: int = 60): 106 | FeatureRepository.__init__(self) 107 | self._api_host = api_host 108 | self._client_key = client_key 109 | self._decryption_key = decryption_key 110 | self._cache_ttl = cache_ttl 111 | self._refresh_lock = threading.Lock() 112 | self._refresh_task: Optional[asyncio.Task] = None 113 | self._stop_event = asyncio.Event() 114 | self._backoff = BackoffStrategy() 115 | self._feature_cache = FeatureCache() 116 | self._callbacks: List[Callable[[Dict[str, Any]], Awaitable[None]]] = [] 117 | self._last_successful_refresh = None 118 | self._refresh_in_progress = asyncio.Lock() 119 | 120 | @asynccontextmanager 121 | async def refresh_operation(self): 122 | """Context manager for feature refresh with proper cleanup""" 123 | if self._refresh_in_progress.locked(): 124 | yield False 125 | return 126 | 127 | # async with self._refresh_in_progress: 128 | try: 129 | await 
self._refresh_in_progress.acquire() 130 | yield True 131 | self._backoff.reset() 132 | self._last_successful_refresh = datetime.now() 133 | except Exception as e: 134 | delay = self._backoff.next_delay() 135 | logger.error(f"Refresh failed, next attempt in {delay:.2f}s: {str(e)}") 136 | traceback.print_exc() 137 | raise 138 | finally: 139 | if self._refresh_in_progress.locked(): 140 | self._refresh_in_progress.release() 141 | 142 | async def _handle_feature_update(self, data: Dict[str, Any]) -> None: 143 | """Update features with memory optimization""" 144 | # Directly update with new features 145 | self._feature_cache.update( 146 | data.get("features", {}), 147 | data.get("savedGroups", {}) 148 | ) 149 | 150 | # Create a copy of callbacks to avoid modification during iteration 151 | with self._refresh_lock: 152 | callbacks = self._callbacks.copy() 153 | 154 | for callback in callbacks: 155 | try: 156 | await callback(dict(self._feature_cache.get_current_state())) 157 | except Exception: 158 | traceback.print_exc() 159 | 160 | def add_callback(self, callback: Callable[[Dict[str, Any]], Awaitable[None]]) -> None: 161 | """Add callback to the list""" 162 | with self._refresh_lock: 163 | if callback not in self._callbacks: 164 | self._callbacks.append(callback) 165 | 166 | def remove_callback(self, callback: Callable[[Dict[str, Any]], Awaitable[None]]) -> None: 167 | """Remove callback from the list""" 168 | with self._refresh_lock: 169 | if callback in self._callbacks: 170 | self._callbacks.remove(callback) 171 | 172 | async def _start_sse_refresh(self) -> None: 173 | """Start SSE-based feature refresh""" 174 | with self._refresh_lock: 175 | if self._refresh_task is not None: # Already running 176 | return 177 | 178 | async def sse_handler(event_data: Dict[str, Any]) -> None: 179 | try: 180 | if event_data['type'] == 'features-updated': 181 | response = await self.load_features_async( 182 | self._api_host, self._client_key, self._decryption_key, self._cache_ttl 183 | 
) 184 | if response is not None: 185 | await self._handle_feature_update(response) 186 | elif event_data['type'] == 'features': 187 | await self._handle_feature_update(event_data['data']) 188 | except Exception: 189 | traceback.print_exc() 190 | 191 | # Start the SSE connection task 192 | self._refresh_task = asyncio.create_task( 193 | self._maintain_sse_connection(sse_handler) 194 | ) 195 | 196 | async def _maintain_sse_connection(self, handler: Callable) -> None: 197 | """Maintain SSE connection with automatic reconnection""" 198 | while not self._stop_event.is_set(): 199 | try: 200 | await self.startAutoRefresh(self._api_host, self._client_key, handler) 201 | except Exception as e: 202 | if not self._stop_event.is_set(): 203 | delay = self._backoff.next_delay() 204 | logger.error(f"SSE connection lost, reconnecting in {delay:.2f}s: {str(e)}") 205 | await asyncio.sleep(delay) 206 | 207 | async def _start_http_refresh(self, interval: int = 60) -> None: 208 | """Enhanced HTTP polling with backoff""" 209 | if self._refresh_task: 210 | return 211 | 212 | async def refresh_loop() -> None: 213 | try: 214 | while not self._stop_event.is_set(): 215 | async with self.refresh_operation() as should_refresh: 216 | if should_refresh: 217 | try: 218 | response = await self.load_features_async( 219 | api_host=self._api_host, 220 | client_key=self._client_key, 221 | decryption_key=self._decryption_key, 222 | ttl=self._cache_ttl 223 | ) 224 | if response is not None: 225 | await self._handle_feature_update(response) 226 | # On success, reset backoff and use normal interval 227 | self._backoff.reset() 228 | try: 229 | await asyncio.sleep(interval) 230 | except asyncio.CancelledError: 231 | # Allow cancellation during sleep 232 | raise 233 | except Exception as e: 234 | # On failure, use backoff delay 235 | delay = self._backoff.next_delay() 236 | logger.error(f"Refresh failed, next attempt in {delay:.2f}s: {str(e)}") 237 | traceback.print_exc() 238 | try: 239 | await 
asyncio.sleep(delay) 240 | except asyncio.CancelledError: 241 | # Allow cancellation during sleep 242 | raise 243 | except asyncio.CancelledError: 244 | # Clean exit on cancellation 245 | raise 246 | finally: 247 | # Ensure we're marked as stopped 248 | self._stop_event.set() 249 | 250 | self._refresh_task = asyncio.create_task(refresh_loop()) 251 | 252 | async def start_feature_refresh(self, strategy: FeatureRefreshStrategy, callback=None): 253 | """Initialize feature refresh based on strategy""" 254 | self._refresh_callback = callback 255 | 256 | if strategy == FeatureRefreshStrategy.SERVER_SENT_EVENTS: 257 | await self._start_sse_refresh() 258 | else: 259 | await self._start_http_refresh() 260 | 261 | async def stop_refresh(self) -> None: 262 | """Clean shutdown of refresh tasks""" 263 | self._stop_event.set() 264 | if self._refresh_task: 265 | # Cancel the task 266 | self._refresh_task.cancel() 267 | try: 268 | # Wait for it to actually finish 269 | await self._refresh_task 270 | except asyncio.CancelledError: 271 | pass 272 | except Exception as e: 273 | logger.error(f"Error during refresh task cleanup: {e}") 274 | finally: 275 | self._refresh_task = None 276 | self._backoff.reset() 277 | self._stop_event.clear() 278 | 279 | async def __aenter__(self): 280 | return self 281 | 282 | async def __aexit__(self, exc_type, exc_val, exc_tb): 283 | await self.stop_refresh() 284 | 285 | async def load_features_async( 286 | self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 60 287 | ) -> Optional[Dict]: 288 | # Use stored values when called internally 289 | if api_host == self._api_host and client_key == self._client_key: 290 | decryption_key = self._decryption_key 291 | ttl = self._cache_ttl 292 | return await super().load_features_async(api_host, client_key, decryption_key, ttl) 293 | 294 | class GrowthBookClient: 295 | def __init__( 296 | self, 297 | options: Optional[Union[Dict[str, Any], Options]] = None 298 | ): 299 | self.options = ( 300 | 
options if isinstance(options, Options) 301 | else Options(**options) if options 302 | else Options() 303 | ) 304 | 305 | # Thread-safe tracking state 306 | self._tracked: Dict[str, bool] = {} # Access only within async context 307 | self._tracked_lock = threading.Lock() 308 | 309 | # Thread-safe subscription management 310 | self._subscriptions: Set[Callable[[Experiment, Result], None]] = set() 311 | self._subscriptions_lock = threading.Lock() 312 | 313 | # Add sticky bucket cache 314 | self._sticky_bucket_cache: Dict[str, Dict[str, Any]] = { 315 | 'attributes': {}, 316 | 'assignments': {} 317 | } 318 | self._sticky_bucket_cache_lock = False 319 | 320 | self._features_repository = ( 321 | EnhancedFeatureRepository( 322 | self.options.api_host or "https://cdn.growthbook.io", 323 | self.options.client_key or "", 324 | self.options.decryption_key or "", 325 | self.options.cache_ttl 326 | ) 327 | if self.options.client_key 328 | else None 329 | ) 330 | 331 | self._global_context: Optional[GlobalContext] = None 332 | self._context_lock = asyncio.Lock() 333 | 334 | def _track(self, experiment: Experiment, result: Result) -> None: 335 | """Thread-safe tracking implementation""" 336 | if not self.options.on_experiment_viewed: 337 | return 338 | 339 | # Create unique key for this tracking event 340 | key = ( 341 | result.hashAttribute 342 | + str(result.hashValue) 343 | + experiment.key 344 | + str(result.variationId) 345 | ) 346 | 347 | with self._tracked_lock: 348 | if not self._tracked.get(key): 349 | try: 350 | self.options.on_experiment_viewed(experiment=experiment, result=result) 351 | self._tracked[key] = True 352 | except Exception: 353 | logger.exception("Error in tracking callback") 354 | 355 | def subscribe(self, callback: Callable[[Experiment, Result], None]) -> Callable[[], None]: 356 | """Thread-safe subscription management""" 357 | with self._subscriptions_lock: 358 | self._subscriptions.add(callback) 359 | def unsubscribe(): 360 | with 
self._subscriptions_lock: 361 | self._subscriptions.discard(callback) 362 | return unsubscribe 363 | 364 | def _fire_subscriptions(self, experiment: Experiment, result: Result) -> None: 365 | """Thread-safe subscription notifications""" 366 | with self._subscriptions_lock: 367 | subscriptions = self._subscriptions.copy() 368 | 369 | for callback in subscriptions: 370 | try: 371 | callback(experiment, result) 372 | except Exception: 373 | logger.exception("Error in subscription callback") 374 | 375 | async def _refresh_sticky_buckets(self, attributes: Dict[str, Any]) -> Dict[str, Any]: 376 | """Refresh sticky bucket assignments only if attributes have changed""" 377 | if not self.options.sticky_bucket_service: 378 | return {} 379 | 380 | # Use compare-and-swap pattern 381 | while not self._sticky_bucket_cache_lock: 382 | if attributes == self._sticky_bucket_cache['attributes']: 383 | return self._sticky_bucket_cache['assignments'] 384 | 385 | self._sticky_bucket_cache_lock = True 386 | try: 387 | assignments = self.options.sticky_bucket_service.get_all_assignments(attributes) 388 | self._sticky_bucket_cache['attributes'] = attributes.copy() 389 | self._sticky_bucket_cache['assignments'] = assignments 390 | return assignments 391 | finally: 392 | self._sticky_bucket_cache_lock = False 393 | 394 | # Fallback return for edge case where loop condition is never satisfied 395 | return {} 396 | 397 | async def initialize(self) -> bool: 398 | """Initialize client with features and start refresh""" 399 | if not self._features_repository: 400 | logger.error("No features repository available") 401 | return False 402 | 403 | try: 404 | # Initial feature load 405 | initial_features = await self._features_repository.load_features_async( 406 | self.options.api_host or "https://cdn.growthbook.io", 407 | self.options.client_key or "", 408 | self.options.decryption_key or "", 409 | self.options.cache_ttl 410 | ) 411 | if not initial_features: 412 | logger.error("Failed to load 
initial features") 413 | return False 414 | 415 | # Create global context with initial features 416 | await self._feature_update_callback(initial_features) 417 | 418 | # Set up callback for future updates 419 | self._features_repository.add_callback(self._feature_update_callback) 420 | 421 | # Start feature refresh 422 | refresh_strategy = self.options.refresh_strategy or FeatureRefreshStrategy.STALE_WHILE_REVALIDATE 423 | await self._features_repository.start_feature_refresh(refresh_strategy) 424 | return True 425 | 426 | except Exception as e: 427 | logger.error(f"Initialization failed: {str(e)}", exc_info=True) 428 | traceback.print_exc() 429 | return False 430 | 431 | async def _feature_update_callback(self, features_data: Dict[str, Any]) -> None: 432 | """Handle feature updates and manage global context""" 433 | if not features_data: 434 | logger.warning("Warning: Received empty features data") 435 | return 436 | 437 | async with self._context_lock: 438 | features = {} 439 | 440 | for key, feature in features_data.get("features", {}).items(): 441 | if isinstance(feature, Feature): 442 | features[key] = feature 443 | else: 444 | features[key] = Feature( 445 | rules=feature.get("rules", []), 446 | defaultValue=feature.get("defaultValue", None), 447 | ) 448 | 449 | if self._global_context is None: 450 | # Initial creation of global context 451 | self._global_context = GlobalContext( 452 | options=self.options, 453 | features=features, 454 | saved_groups=features_data.get("savedGroups", {}) 455 | ) 456 | else: 457 | # Update existing global context 458 | self._global_context.features = features 459 | self._global_context.saved_groups = features_data.get("savedGroups", {}) 460 | 461 | async def __aenter__(self): 462 | await self.initialize() 463 | return self 464 | 465 | async def __aexit__(self, exc_type, exc_val, exc_tb): 466 | await self.close() 467 | 468 | async def create_evaluation_context(self, user_context: UserContext) -> EvaluationContext: 469 | """Create 
evaluation context for feature evaluation""" 470 | if self._global_context is None: 471 | raise RuntimeError("GrowthBook client not properly initialized") 472 | 473 | # Get sticky bucket assignments if needed 474 | sticky_assignments = await self._refresh_sticky_buckets(user_context.attributes) 475 | 476 | # update user context with sticky bucket assignments 477 | user_context.sticky_bucket_assignment_docs = sticky_assignments 478 | 479 | return EvaluationContext( 480 | user=user_context, 481 | global_ctx=self._global_context, 482 | stack=StackContext(evaluated_features=set()) 483 | ) 484 | 485 | async def eval_feature(self, key: str, user_context: UserContext) -> FeatureResult: 486 | """Evaluate a feature with proper async context management""" 487 | async with self._context_lock: 488 | context = await self.create_evaluation_context(user_context) 489 | result = core_eval_feature(key=key, evalContext=context) 490 | return result 491 | 492 | async def is_on(self, key: str, user_context: UserContext) -> bool: 493 | """Check if a feature is enabled with proper async context management""" 494 | async with self._context_lock: 495 | context = await self.create_evaluation_context(user_context) 496 | return core_eval_feature(key=key, evalContext=context).on 497 | 498 | async def is_off(self, key: str, user_context: UserContext) -> bool: 499 | """Check if a feature is set to off with proper async context management""" 500 | async with self._context_lock: 501 | context = await self.create_evaluation_context(user_context) 502 | return core_eval_feature(key=key, evalContext=context).off 503 | 504 | async def get_feature_value(self, key: str, fallback: Any, user_context: UserContext) -> Any: 505 | async with self._context_lock: 506 | context = await self.create_evaluation_context(user_context) 507 | result = core_eval_feature(key=key, evalContext=context) 508 | return result.value if result.value is not None else fallback 509 | 510 | async def run(self, experiment: Experiment, 
user_context: UserContext) -> Result: 511 | """Run experiment with tracking""" 512 | async with self._context_lock: 513 | context = await self.create_evaluation_context(user_context) 514 | result = run_experiment( 515 | experiment=experiment, 516 | evalContext=context, 517 | tracking_cb=self._track 518 | ) 519 | # Fire subscriptions synchronously 520 | self._fire_subscriptions(experiment, result) 521 | return result 522 | 523 | async def close(self) -> None: 524 | """Clean shutdown with proper cleanup""" 525 | if self._features_repository: 526 | await self._features_repository.stop_refresh() 527 | 528 | # Clear tracking and subscription state 529 | with self._tracked_lock: 530 | self._tracked.clear() 531 | with self._subscriptions_lock: 532 | self._subscriptions.clear() 533 | 534 | # Clear context 535 | async with self._context_lock: 536 | self._global_context = None -------------------------------------------------------------------------------- /growthbook/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/growthbook/growthbook-python/2d83c4156f845b609402a14ad70afa02da9cef44/growthbook/py.typed -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Stick to old setup and keep only pytest config 2 | [tool.pytest.ini_options] 3 | asyncio_mode = "strict" 4 | asyncio_default_fixture_loop_scope = "function" 5 | testpaths = ["tests"] 6 | pythonpath = "." 
7 | 
8 | [project]
9 | dynamic = ["version"]
10 | name = "growthbook"
11 | description = "Powerful Feature flagging and A/B testing for Python apps"
12 | readme = "README.md"
13 | requires-python = ">=3.6"
14 | license = {text = "MIT"}
15 | authors = [
16 | {name = "GrowthBook", email = "hello@growthbook.io"}
17 | ]
18 | keywords = ["growthbook"]
19 | dependencies = [
20 | "cryptography",
21 | "typing_extensions",
22 | "urllib3",
23 | 'dataclasses;python_version<"3.7"',
24 | 'async-generator;python_version<"3.7"',
25 | "aiohttp>=3.6.0",
26 | 'importlib-metadata;python_version<"3.8"'
27 | ]
28 | classifiers = [
29 | "Development Status :: 4 - Beta",
30 | "Intended Audience :: Developers",
31 | "License :: OSI Approved :: MIT License",
32 | "Programming Language :: Python :: 3",
33 | "Programming Language :: Python :: 3.6",
34 | "Programming Language :: Python :: 3.7",
35 | "Programming Language :: Python :: 3.8",
36 | "Programming Language :: Python :: 3.9",
37 | "Programming Language :: Python :: 3.10",
38 | "Programming Language :: Python :: 3.11",
39 | "Programming Language :: Python :: 3.12",
40 | ]
41 | 
42 | [build-system]
43 | requires = ["setuptools>=45"]
44 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cryptography>=35.0.0
2 | typing_extensions>=3.8.0
3 | urllib3>=1.26.0
4 | aiohttp>=3.8.6
5 | # asyncio>=3.4.3  -- removed: "asyncio" on PyPI is an obsolete Python 3.3 backport; asyncio ships with every supported Python 3 and installing the PyPI package can shadow the standard-library module
6 | pytest-mock>=3.0.0  # NOTE(review): test-only dependency; consider moving to requirements_dev.txt
7 | dataclasses;python_version<"3.7"
8 | async-generator;python_version<"3.7"
9 | importlib-metadata;python_version<"3.8"
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | pip
2 | bump2version>=1.0.0
3 | wheel
4 | watchdog
5 | flake8
6 | coverage
7 | twine
8 | pytest
9 | pytest-asyncio>=0.10.0
10 | pytest-mock
11 | 
mypy 12 | types-dataclasses>=0.6.6;python_version<"3.7" 13 | types-urllib3 14 | aiohttp-sse-client -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.2.0 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bdist_wheel] 11 | universal = 1 12 | 13 | [flake8] 14 | exclude = docs 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | from os import path 5 | import re 6 | 7 | this_directory = path.abspath(path.dirname(__file__)) 8 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 9 | long_description = f.read() 10 | 11 | requirements = [ 12 | 'cryptography', 13 | 'typing_extensions', 14 | 'urllib3', 15 | 'aiohttp>=3.6.0', # For async HTTP support 16 | ] 17 | 18 | test_requirements = [ 19 | 'pytest>=3', 20 | 'pytest-asyncio>=0.10.0', 21 | ] 22 | 23 | def get_version(): 24 | with open('growthbook/__init__.py', 'r') as f: 25 | content = f.read() 26 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) 27 | if version_match: 28 | return version_match.group(1) 29 | raise RuntimeError("Unable to find version string.") 30 | 31 | setup( 32 | name='growthbook', 33 | author="GrowthBook", 34 | author_email='hello@growthbook.io', 35 | python_requires='>=3.7', 36 | version=get_version(), # Read version from __init__.py 37 | classifiers=[ 38 | 'Development Status :: 4 - Beta', 39 | 'Intended Audience :: Developers', 40 | 'License :: OSI Approved :: MIT License', 41 | 'Natural Language :: English', 42 | 'Programming Language :: Python :: 3', 43 | 'Programming 
Language :: Python :: 3.7', 44 | 'Programming Language :: Python :: 3.8', 45 | 'Programming Language :: Python :: 3.9', 46 | 'Programming Language :: Python :: 3.10', 47 | 'Programming Language :: Python :: 3.11', 48 | 'Programming Language :: Python :: 3.12', 49 | ], 50 | description="Powerful Feature flagging and A/B testing for Python apps", 51 | long_description=long_description, 52 | long_description_content_type='text/markdown', 53 | install_requires=requirements, 54 | license="MIT", 55 | include_package_data=True, 56 | packages=find_packages(include=['growthbook', 'growthbook.*']), 57 | package_data={"growthbook": ["py.typed"]}, 58 | keywords='growthbook', 59 | test_suite='tests', 60 | tests_require=test_requirements, 61 | url='https://github.com/growthbook/growthbook-python', 62 | ) 63 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import asyncio 3 | import os 4 | import sys 5 | from growthbook.growthbook_client import EnhancedFeatureRepository, SingletonMeta 6 | 7 | # Only define the event_loop_policy fixture, and let pytest-asyncio handle event_loop 8 | @pytest.fixture(scope="session") 9 | def event_loop_policy(): 10 | return asyncio.get_event_loop_policy() 11 | 12 | @pytest.fixture(autouse=True) 13 | def reset_singleton(): 14 | """Reset the EnhancedFeatureRepository singleton between tests""" 15 | # Let the test run first 16 | yield 17 | # Only clear after test is completely done 18 | if hasattr(SingletonMeta, '_instances'): 19 | # Ensure any async operations are complete 20 | for instance in SingletonMeta._instances.values(): 21 | if hasattr(instance, '_stop_event'): 22 | instance._stop_event.set() 23 | SingletonMeta._instances.clear() 24 | 25 | @pytest.fixture(autouse=True) 26 | async def cleanup_tasks(): 27 | """Cleanup any pending tasks after each test.""" 28 | yield 29 | loop = 
asyncio.get_event_loop() 30 | # Let any pending callbacks complete 31 | await asyncio.sleep(0) 32 | # Ensure singleton instances are cleaned up first 33 | if hasattr(SingletonMeta, '_instances'): 34 | for instance in SingletonMeta._instances.values(): 35 | if hasattr(instance, 'stop_refresh'): 36 | await instance.stop_refresh() 37 | # Clear singleton instances 38 | SingletonMeta._instances.clear() 39 | 40 | # Add project root to Python path 41 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -------------------------------------------------------------------------------- /tests/test_growthbook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | import os 5 | from growthbook import ( 6 | FeatureRule, 7 | GrowthBook, 8 | Experiment, 9 | Feature, 10 | InMemoryStickyBucketService, 11 | decrypt, 12 | feature_repo, 13 | logger, 14 | ) 15 | 16 | from growthbook.core import ( 17 | getBucketRanges, 18 | gbhash, 19 | chooseVariation, 20 | paddedVersionString, 21 | getQueryStringOverride, 22 | inNamespace, 23 | getEqualWeights, 24 | evalCondition, 25 | ) 26 | 27 | from time import time 28 | import pytest 29 | 30 | logger.setLevel("DEBUG") 31 | 32 | 33 | def pytest_generate_tests(metafunc): 34 | folder = os.path.abspath(os.path.dirname(__file__)) 35 | jsonfile = os.path.join(folder, "cases.json") 36 | with open(jsonfile) as file: 37 | data = json.load(file) 38 | 39 | for func, cases in data.items(): 40 | key = func + "_data" 41 | 42 | if (func == "versionCompare"): 43 | for method, cases in cases.items(): 44 | key = func + "_" + method + "_data" 45 | if (key in metafunc.fixturenames): 46 | metafunc.parametrize(key, cases) 47 | elif key in metafunc.fixturenames: 48 | metafunc.parametrize(key, cases) 49 | 50 | 51 | def test_hash(hash_data): 52 | seed, value, version, expected = hash_data 53 | assert gbhash(seed, value, version) == expected 54 | 55 | 56 | def 
round_list(item): 57 | is_tuple = type(item) is tuple 58 | 59 | if is_tuple: 60 | item = list(item) 61 | 62 | for i, value in enumerate(item): 63 | item[i] = round(value, 6) 64 | 65 | return item 66 | 67 | 68 | def round_list_of_lists(item): 69 | for i, value in enumerate(item): 70 | item[i] = round_list(value) 71 | return item 72 | 73 | 74 | def test_get_bucket_range(getBucketRange_data): 75 | _, args, expected = getBucketRange_data 76 | numVariations, coverage, weights = args 77 | 78 | actual = getBucketRanges(numVariations, coverage, weights) 79 | 80 | assert round_list_of_lists(actual) == round_list_of_lists(expected) 81 | 82 | 83 | def test_choose_variation(chooseVariation_data): 84 | _, n, ranges, expected = chooseVariation_data 85 | assert chooseVariation(n, ranges) == expected 86 | 87 | 88 | def test_get_qs_override(getQueryStringOverride_data): 89 | _, id, url, numVariations, expected = getQueryStringOverride_data 90 | assert getQueryStringOverride(id, url, numVariations) == expected 91 | 92 | 93 | def test_namespace(inNamespace_data): 94 | _, id, namespace, expected = inNamespace_data 95 | assert inNamespace(id, namespace) == expected 96 | 97 | 98 | def test_equal_weights(getEqualWeights_data): 99 | numVariations, expected = getEqualWeights_data 100 | weights = getEqualWeights(numVariations) 101 | assert round_list(weights) == round_list(expected) 102 | 103 | 104 | def test_conditions(evalCondition_data): 105 | _, condition, attributes, expected, savedGroups = (evalCondition_data + [None]*5)[:5] 106 | assert evalCondition(attributes, condition, savedGroups) == expected 107 | 108 | 109 | def test_decrypt(decrypt_data): 110 | _, encrypted, key, expected = decrypt_data 111 | try: 112 | assert (decrypt(encrypted, key)) == expected 113 | except Exception: 114 | assert (expected) is None 115 | 116 | 117 | def test_feature(feature_data): 118 | _, ctx, key, expected = feature_data 119 | gb = GrowthBook(**ctx) 120 | res = gb.evalFeature(key) 121 | 122 | if 
"experiment" in expected: 123 | expected["experiment"] = Experiment(**expected["experiment"]).to_dict() 124 | 125 | actual = res.to_dict() 126 | 127 | assert actual == expected 128 | gb.destroy() 129 | 130 | 131 | def test_run(run_data): 132 | _, ctx, exp, value, inExperiment, hashUsed = run_data 133 | gb = GrowthBook(**ctx) 134 | 135 | res = gb.run(Experiment(**exp)) 136 | assert res.value == value 137 | assert res.inExperiment == inExperiment 138 | assert res.hashUsed == hashUsed 139 | 140 | gb.destroy() 141 | 142 | 143 | def test_stickyBucket(stickyBucket_data): 144 | _, ctx, initial_docs, key, expected_result, expected_docs = stickyBucket_data 145 | # Just use the interface directly, which passes and doesn't persist anywhere 146 | service = InMemoryStickyBucketService() 147 | 148 | for doc in initial_docs: 149 | service.save_assignments(doc) 150 | 151 | ctx['sticky_bucket_service'] = service 152 | 153 | if 'stickyBucketIdentifierAttributes' in ctx: 154 | ctx['sticky_bucket_identifier_attributes'] = ctx['stickyBucketIdentifierAttributes'] 155 | ctx.pop('stickyBucketIdentifierAttributes') 156 | 157 | if 'stickyBucketAssignmentDocs' in ctx: 158 | service.docs = ctx['stickyBucketAssignmentDocs'] 159 | ctx.pop('stickyBucketAssignmentDocs') 160 | 161 | gb = GrowthBook(**ctx) 162 | res = gb.eval_feature(key) 163 | 164 | if not res.experimentResult: 165 | assert None == expected_result 166 | else: 167 | assert res.experimentResult.to_dict() == expected_result 168 | 169 | # Ignore extra docs in service, just make sure each expected one matches 170 | for key, value in expected_docs.items(): 171 | assert service.docs[key] == value 172 | 173 | service.destroy() 174 | gb.destroy() 175 | 176 | 177 | def getTrackingMock(gb: GrowthBook): 178 | calls = [] 179 | 180 | def track(experiment, result): 181 | return calls.append([experiment, result]) 182 | 183 | gb._trackingCallback = track 184 | return lambda: calls 185 | 186 | 187 | def test_tracking(): 188 | gb = 
GrowthBook(attributes={"id": "1"}) 189 | 190 | getMockedCalls = getTrackingMock(gb) 191 | 192 | exp1 = Experiment( 193 | key="my-tracked-test", 194 | variations=[0, 1], 195 | ) 196 | exp2 = Experiment( 197 | key="my-other-tracked-test", 198 | variations=[0, 1], 199 | ) 200 | 201 | res1 = gb.run(exp1) 202 | gb.run(exp1) 203 | gb.run(exp1) 204 | res4 = gb.run(exp2) 205 | gb._attributes = {"id": "2"} 206 | res5 = gb.run(exp2) 207 | 208 | calls = getMockedCalls() 209 | assert len(calls) == 3 210 | assert calls[0] == [exp1, res1] 211 | assert calls[1] == [exp2, res4] 212 | assert calls[2] == [exp2, res5] 213 | 214 | gb.destroy() 215 | 216 | 217 | def test_handles_weird_experiment_values(): 218 | gb = GrowthBook(attributes={"id": "1"}) 219 | 220 | assert ( 221 | gb.run( 222 | Experiment( 223 | key="my-test", 224 | variations=[0, 1], 225 | include=lambda: 1 / 0, 226 | ) 227 | ).inExperiment 228 | is False 229 | ) 230 | 231 | # Should fail gracefully 232 | gb._trackingCallback = lambda experiment, result: 1 / 0 233 | assert gb.run(Experiment(key="my-test", variations=[0, 1])).value == 1 234 | 235 | gb.subscribe(lambda: 1 / 0) 236 | assert gb.run(Experiment(key="my-new-test", variations=[0, 1])).value == 0 237 | 238 | gb.destroy() 239 | 240 | 241 | def test_force_variation(): 242 | gb = GrowthBook(attributes={"id": "6"}) 243 | exp = Experiment(key="forced-test", variations=[0, 1]) 244 | assert gb.run(exp).value == 0 245 | 246 | getMockedCalls = getTrackingMock(gb) 247 | 248 | gb._overrides = { 249 | "forced-test": { 250 | "force": 1, 251 | }, 252 | } 253 | assert gb.run(exp).value == 1 254 | 255 | calls = getMockedCalls() 256 | assert len(calls) == 0 257 | 258 | gb.destroy() 259 | 260 | 261 | def test_uses_overrides(): 262 | gb = GrowthBook( 263 | attributes={"id": "1"}, 264 | overrides={ 265 | "my-test": { 266 | "coverage": 0.01, 267 | }, 268 | }, 269 | ) 270 | 271 | assert ( 272 | gb.run( 273 | Experiment( 274 | key="my-test", 275 | variations=[0, 1], 276 | ) 277 | 
).inExperiment 278 | is False 279 | ) 280 | 281 | gb._overrides = { 282 | "my-test": { 283 | "url": r"^\\/path", 284 | }, 285 | } 286 | 287 | assert ( 288 | gb.run( 289 | Experiment( 290 | key="my-test", 291 | variations=[0, 1], 292 | ) 293 | ).inExperiment 294 | is False 295 | ) 296 | 297 | gb.destroy() 298 | 299 | 300 | def test_filters_user_groups(): 301 | gb = GrowthBook( 302 | attributes={"id": "123"}, 303 | groups={ 304 | "alpha": True, 305 | "beta": True, 306 | "internal": False, 307 | "qa": False, 308 | }, 309 | ) 310 | 311 | assert ( 312 | gb.run( 313 | Experiment( 314 | key="my-test", 315 | variations=[0, 1], 316 | groups=["internal", "qa"], 317 | ) 318 | ).inExperiment 319 | is False 320 | ) 321 | 322 | assert ( 323 | gb.run( 324 | Experiment( 325 | key="my-test", 326 | variations=[0, 1], 327 | groups=["internal", "qa", "beta"], 328 | ) 329 | ).inExperiment 330 | is True 331 | ) 332 | 333 | assert ( 334 | gb.run( 335 | Experiment( 336 | key="my-test", 337 | variations=[0, 1], 338 | ) 339 | ).inExperiment 340 | is True 341 | ) 342 | 343 | gb.destroy() 344 | 345 | 346 | def test_runs_custom_include_callback(): 347 | gb = GrowthBook(user={"id": "1"}) 348 | assert ( 349 | gb.run( 350 | Experiment(key="my-test", variations=[0, 1], include=lambda: False) 351 | ).inExperiment 352 | is False 353 | ) 354 | 355 | gb.destroy() 356 | 357 | 358 | def test_supports_custom_user_hash_keys(): 359 | gb = GrowthBook(attributes={"id": "1", "company": "abc"}) 360 | 361 | exp = Experiment(key="my-test", variations=[0, 1], hashAttribute="company") 362 | 363 | res = gb.run(exp) 364 | 365 | assert res.hashAttribute == "company" 366 | assert res.hashValue == "abc" 367 | 368 | gb.destroy() 369 | 370 | 371 | def test_querystring_force_disabled_tracking(): 372 | gb = GrowthBook( 373 | attributes={"id": "1"}, 374 | url="http://example.com?forced-test-qs=1", 375 | ) 376 | getMockedCalls = getTrackingMock(gb) 377 | 378 | exp = Experiment( 379 | key="forced-test-qs", 380 | 
variations=[0, 1], 381 | ) 382 | gb.run(exp) 383 | 384 | calls = getMockedCalls() 385 | assert len(calls) == 0 386 | 387 | 388 | def test_url_targeting(): 389 | gb = GrowthBook( 390 | attributes={"id": "1"}, 391 | url="http://example.com", 392 | ) 393 | 394 | exp = Experiment( 395 | key="my-test", 396 | variations=[0, 1], 397 | url="^\\/post\\/[0-9]+", 398 | ) 399 | 400 | res = gb.run(exp) 401 | assert res.inExperiment is False 402 | assert res.value == 0 403 | 404 | gb._url = "http://example.com/post/123" 405 | res = gb.run(exp) 406 | assert res.inExperiment is True 407 | assert res.value == 1 408 | 409 | exp.url = "http:\\/\\/example.com\\/post\\/[0-9]+" 410 | res = gb.run(exp) 411 | assert res.inExperiment is True 412 | assert res.value == 1 413 | 414 | gb.destroy() 415 | 416 | 417 | def test_invalid_url_regex(): 418 | gb = GrowthBook( 419 | attributes={"id": "1"}, 420 | overrides={ 421 | "my-test": { 422 | "url": "???***[)", 423 | }, 424 | }, 425 | url="http://example.com", 426 | ) 427 | 428 | assert ( 429 | gb.run( 430 | Experiment( 431 | key="my-test", 432 | variations=[0, 1], 433 | ) 434 | ).value 435 | == 1 436 | ) 437 | 438 | gb.destroy() 439 | 440 | 441 | def test_ignores_draft_experiments(): 442 | gb = GrowthBook(attributes={"id": "1"}) 443 | exp = Experiment( 444 | key="my-test", 445 | status="draft", 446 | variations=[0, 1], 447 | ) 448 | 449 | res1 = gb.run(exp) 450 | gb._url = "http://example.com/?my-test=1" 451 | res2 = gb.run(exp) 452 | 453 | assert res1.inExperiment is False 454 | assert res1.hashUsed is False 455 | assert res1.value == 0 456 | assert res2.inExperiment is True 457 | assert res2.hashUsed is False 458 | assert res2.value == 1 459 | 460 | gb.destroy() 461 | 462 | 463 | def test_ignores_stopped_experiments_unless_forced(): 464 | gb = GrowthBook(attributes={"id": "1"}) 465 | expLose = Experiment( 466 | key="my-test", 467 | status="stopped", 468 | variations=[0, 1, 2], 469 | ) 470 | expWin = Experiment( 471 | key="my-test", 472 | 
status="stopped", 473 | variations=[0, 1, 2], 474 | force=2, 475 | ) 476 | 477 | res1 = gb.run(expLose) 478 | res2 = gb.run(expWin) 479 | 480 | assert res1.value == 0 481 | assert res1.inExperiment is False 482 | assert res2.value == 2 483 | assert res2.inExperiment is True 484 | 485 | gb.destroy() 486 | 487 | 488 | fired = {} 489 | 490 | 491 | def flagSubscription(experiment, result): 492 | fired["value"] = True 493 | 494 | 495 | def hasFired(): 496 | return fired.get("value", False) 497 | 498 | 499 | def resetFiredFlag(): 500 | fired["value"] = False 501 | 502 | 503 | def test_destroy_removes_subscriptions(): 504 | gb = GrowthBook(user={"id": "1"}) 505 | 506 | resetFiredFlag() 507 | gb.subscribe(flagSubscription) 508 | 509 | gb.run( 510 | Experiment( 511 | key="my-test", 512 | variations=[0, 1], 513 | ) 514 | ) 515 | 516 | assert hasFired() is True 517 | 518 | resetFiredFlag() 519 | gb.destroy() 520 | 521 | gb.run( 522 | Experiment( 523 | key="my-other-test", 524 | variations=[0, 1], 525 | ) 526 | ) 527 | 528 | assert hasFired() is False 529 | 530 | gb.destroy() 531 | 532 | 533 | def test_fires_subscriptions_correctly(): 534 | gb = GrowthBook( 535 | user={ 536 | "id": "1", 537 | }, 538 | ) 539 | 540 | resetFiredFlag() 541 | unsubscriber = gb.subscribe(flagSubscription) 542 | 543 | assert hasFired() is False 544 | 545 | exp = Experiment( 546 | key="my-test", 547 | variations=[0, 1], 548 | ) 549 | 550 | # Should fire when user is put in an experiment 551 | gb.run(exp) 552 | assert hasFired() is True 553 | 554 | # Does not fire if nothing has changed 555 | resetFiredFlag() 556 | gb.run(exp) 557 | assert hasFired() is False 558 | 559 | # Does not fire after unsubscribed 560 | unsubscriber() 561 | gb.run( 562 | Experiment( 563 | key="other-test", 564 | variations=[0, 1], 565 | ) 566 | ) 567 | 568 | assert hasFired() is False 569 | 570 | gb.destroy() 571 | 572 | 573 | def test_stores_assigned_variations_in_the_user(): 574 | gb = GrowthBook( 575 | attributes={ 576 | 
"id": "1", 577 | }, 578 | ) 579 | 580 | gb.run(Experiment(key="my-test", variations=[0, 1])) 581 | gb.run(Experiment(key="my-test-3", variations=[0, 1])) 582 | 583 | assigned = gb.getAllResults() 584 | assignedArr = [] 585 | 586 | for e in assigned: 587 | assignedArr.append({"key": e, "variation": assigned[e]["result"].variationId}) 588 | 589 | assert len(assignedArr) == 2 590 | assert assignedArr[0]["key"] == "my-test" 591 | assert assignedArr[0]["variation"] == 1 592 | assert assignedArr[1]["key"] == "my-test-3" 593 | assert assignedArr[1]["variation"] == 0 594 | 595 | gb.destroy() 596 | 597 | 598 | def test_getters_setters(): 599 | gb = GrowthBook() 600 | 601 | feat = Feature(defaultValue="yes", rules=[FeatureRule(force="no")]) 602 | featuresInput = {"feature-1": feat.to_dict()} 603 | attributes = {"id": "123", "url": "/"} 604 | 605 | gb.setFeatures(featuresInput) 606 | gb.setAttributes(attributes) 607 | 608 | featuresOutput = {k: v.to_dict() for (k, v) in gb.getFeatures().items()} 609 | 610 | assert featuresOutput == featuresInput 611 | assert attributes == gb.getAttributes() 612 | 613 | newAttrs = {"url": "/hello"} 614 | gb.setAttributes(newAttrs) 615 | assert newAttrs == gb.getAttributes() 616 | 617 | gb.destroy() 618 | 619 | 620 | def test_return_ruleid_when_evaluating_a_feature(): 621 | gb = GrowthBook( 622 | features={"feature": {"defaultValue": 0, "rules": [{"force": 1, "id": "foo"}]}} 623 | ) 624 | assert gb.eval_feature("feature").ruleId == "foo" 625 | gb.destroy() 626 | 627 | 628 | def test_feature_methods(): 629 | gb = GrowthBook( 630 | features={ 631 | "featureOn": {"defaultValue": 12}, 632 | "featureNone": {"defaultValue": None}, 633 | "featureOff": {"defaultValue": 0}, 634 | } 635 | ) 636 | 637 | assert gb.isOn("featureOn") is True 638 | assert gb.isOff("featureOn") is False 639 | assert gb.getFeatureValue("featureOn", 15) == 12 640 | 641 | assert gb.isOn("featureOff") is False 642 | assert gb.isOff("featureOff") is True 643 | assert 
gb.getFeatureValue("featureOff", 10) == 0 644 | 645 | assert gb.isOn("featureNone") is False 646 | assert gb.isOff("featureNone") is True 647 | assert gb.getFeatureValue("featureNone", 10) == 10 648 | 649 | gb.destroy() 650 | 651 | 652 | class MockHttpResp: 653 | def __init__(self, status: int, data: str) -> None: 654 | self.status = status 655 | self.data = data.encode("utf-8") 656 | 657 | 658 | def test_feature_repository(mocker): 659 | m = mocker.patch.object(feature_repo, "_get") 660 | expected = {"features": {"feature": {"defaultValue": 5}}} 661 | m.return_value = MockHttpResp(200, json.dumps(expected)) 662 | features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123") 663 | 664 | m.assert_called_once_with("https://cdn.growthbook.io/api/features/sdk-abc123") 665 | assert features == expected 666 | 667 | # Uses in-memory cache for the 2nd call 668 | features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123") 669 | assert m.call_count == 1 670 | assert features == expected 671 | 672 | # Does a new request if cache entry is expired 673 | feature_repo.cache.cache["https://cdn.growthbook.io::sdk-abc123"].expires = ( 674 | time() - 10 675 | ) 676 | features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123") 677 | assert m.call_count == 2 678 | assert features == expected 679 | 680 | feature_repo.clear_cache() 681 | 682 | 683 | def test_feature_repository_error(mocker): 684 | m = mocker.patch.object(feature_repo, "_get") 685 | m.return_value = MockHttpResp(400, "400 Error") 686 | features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123") 687 | 688 | m.assert_called_once_with("https://cdn.growthbook.io/api/features/sdk-abc123") 689 | assert features is None 690 | 691 | # Does not cache errors 692 | features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123") 693 | assert m.call_count == 2 694 | assert features is None 695 | 696 | # Handles broken JSON response 697 | 
m.return_value = MockHttpResp(200, "{'corrupted':6('4") 698 | features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123") 699 | assert m.call_count == 3 700 | assert features is None 701 | 702 | feature_repo.clear_cache() 703 | 704 | 705 | def test_feature_repository_encrypted(mocker): 706 | m = mocker.patch.object(feature_repo, "_get") 707 | m.return_value = MockHttpResp( 708 | 200, 709 | json.dumps( 710 | { 711 | "features": {}, 712 | "encryptedFeatures": "m5ylFM6ndyOJA2OPadubkw==.Uu7ViqgKEt/dWvCyhI46q088PkAEJbnXKf3KPZjf9IEQQ+A8fojNoxw4wIbPX3aj", 713 | } 714 | ), 715 | ) 716 | features = feature_repo.load_features( 717 | "https://cdn.growthbook.io", "sdk-abc123", "Zvwv/+uhpFDznZ6SX28Yjg==" 718 | ) 719 | 720 | m.assert_called_once_with("https://cdn.growthbook.io/api/features/sdk-abc123") 721 | assert features == {"features": {"feature": {"defaultValue": True}}} 722 | 723 | feature_repo.clear_cache() 724 | 725 | # Raises exception if missing decryption key 726 | with pytest.raises(Exception): 727 | feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123") 728 | 729 | 730 | def test_load_features(mocker): 731 | m = mocker.patch.object(feature_repo, "_get") 732 | m.return_value = MockHttpResp( 733 | 200, json.dumps({"features": {"feature": {"defaultValue": 5}}}) 734 | ) 735 | 736 | gb = GrowthBook(api_host="https://cdn.growthbook.io", client_key="sdk-abc123") 737 | 738 | assert m.call_count == 0 739 | 740 | gb.load_features() 741 | m.assert_called_once_with("https://cdn.growthbook.io/api/features/sdk-abc123") 742 | 743 | assert gb.get_features()["feature"].to_dict() == {"defaultValue": 5, "rules": []} 744 | 745 | feature_repo.clear_cache() 746 | gb.destroy() 747 | 748 | 749 | def test_loose_unmarshalling(mocker): 750 | m = mocker.patch.object(feature_repo, "_get") 751 | m.return_value = MockHttpResp(200, json.dumps({ 752 | "features": { 753 | "feature": { 754 | "defaultValue": 5, 755 | "rules": [ 756 | { 757 | "condition": {"country": 
"US"}, 758 | "force": 3, 759 | "hashVersion": 1, 760 | "unknown": "foo" 761 | }, 762 | { 763 | "key": "my-exp", 764 | "hashVersion": 2, 765 | "variations": [0, 1], 766 | "meta": [ 767 | { 768 | "key": "control", 769 | "unknown": "foo" 770 | }, 771 | { 772 | "key": "variation1", 773 | "unknown": "foo" 774 | } 775 | ], 776 | "filters": [ 777 | { 778 | "seed": "abc123", 779 | "ranges": [[0, 0.0001]], 780 | "hashVersion": 2, 781 | "attribute": "id", 782 | "unknown": "foo" 783 | } 784 | ] 785 | }, 786 | { 787 | "unknownRuleType": "foo" 788 | } 789 | ], 790 | "unknown": "foo" 791 | } 792 | }, 793 | "unknown": "foo" 794 | })) 795 | 796 | gb = GrowthBook(api_host="https://cdn.growthbook.io", client_key="sdk-abc123") 797 | 798 | assert m.call_count == 0 799 | 800 | gb.load_features() 801 | m.assert_called_once_with("https://cdn.growthbook.io/api/features/sdk-abc123") 802 | 803 | assert gb.get_features()["feature"].to_dict() == { 804 | "defaultValue": 5, 805 | "rules": [ 806 | { 807 | "condition": {"country": "US"}, 808 | "force": 3, 809 | "hashVersion": 1 810 | }, 811 | { 812 | "key": "my-exp", 813 | "hashVersion": 2, 814 | "variations": [0, 1], 815 | "meta": [ 816 | { 817 | "key": "control", 818 | "unknown": "foo" 819 | }, 820 | { 821 | "key": "variation1", 822 | "unknown": "foo" 823 | } 824 | ], 825 | "filters": [ 826 | { 827 | "seed": "abc123", 828 | "ranges": [[0, 0.0001]], 829 | "hashVersion": 2, 830 | "attribute": "id", 831 | "unknown": "foo" 832 | } 833 | ] 834 | }, 835 | { 836 | "hashVersion": 1 837 | } 838 | ] 839 | } 840 | 841 | value = gb.get_feature_value("feature", -1) 842 | assert value == 5 843 | 844 | feature_repo.clear_cache() 845 | gb.destroy() 846 | 847 | 848 | def test_sticky_bucket_service(mocker): 849 | # Start forcing everyone to variation1 850 | features = { 851 | "feature": { 852 | "defaultValue": 5, 853 | "rules": [{ 854 | "key": "exp", 855 | "variations": [0, 1], 856 | "weights": [0, 1], 857 | "meta": [ 858 | {"key": "control"}, 859 | {"key": 
"variation1"} 860 | ] 861 | }] 862 | }, 863 | } 864 | 865 | service = InMemoryStickyBucketService() 866 | gb = GrowthBook( 867 | sticky_bucket_service=service, 868 | attributes={ 869 | "id": "1" 870 | }, 871 | features=features 872 | ) 873 | 874 | assert gb.get_feature_value("feature", -1) == 1 875 | assert service.get_assignments("id", "1") == { 876 | "attributeName": "id", 877 | "attributeValue": "1", 878 | "assignments": { 879 | "exp__0": "variation1" 880 | } 881 | } 882 | 883 | logger.debug("Change weights and ensure old user still gets variation") 884 | features["feature"]["rules"][0]["weights"] = [1, 0] 885 | gb.set_features(features) 886 | assert gb.get_feature_value("feature", -1) == 1 887 | 888 | logger.debug("New GrowthBook instance should also get variation") 889 | gb2 = GrowthBook( 890 | sticky_bucket_service=service, 891 | attributes={ 892 | "id": "1" 893 | }, 894 | features=features 895 | ) 896 | assert gb2.get_feature_value("feature", -1) == 1 897 | gb2.destroy() 898 | 899 | logger.debug("New users should get control") 900 | gb.set_attributes({"id": "2"}) 901 | assert gb.get_feature_value("feature", -1) == 0 902 | 903 | logger.debug("Bumping bucketVersion, should reset sticky buckets") 904 | gb.set_attributes({"id": "1"}) 905 | features["feature"]["rules"][0]["bucketVersion"] = 1 906 | gb.set_features(features) 907 | assert gb.get_feature_value("feature", -1) == 0 908 | 909 | assert service.get_assignments("id", "1") == { 910 | "attributeName": "id", 911 | "attributeValue": "1", 912 | "assignments": { 913 | "exp__0": "variation1", 914 | "exp__1": "control" 915 | } 916 | } 917 | gb.destroy() 918 | service.destroy() 919 | 920 | 921 | def test_ttl_automatic_feature_refresh(mocker): 922 | """Test that GrowthBook instances automatically get updated features when cache expires during evaluation""" 923 | # Mock responses to simulate feature flag changes 924 | mock_responses = [ 925 | {"features": {"test_feature": {"defaultValue": False}}, "savedGroups": 
{}}, 926 | {"features": {"test_feature": {"defaultValue": True}}, "savedGroups": {}} 927 | ] 928 | 929 | call_count = 0 930 | def mock_fetch_features(api_host, client_key, decryption_key=""): 931 | nonlocal call_count 932 | response = mock_responses[min(call_count, len(mock_responses) - 1)] 933 | call_count += 1 934 | return response 935 | 936 | # Clear cache and mock the fetch method 937 | feature_repo.clear_cache() 938 | m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features) 939 | 940 | # Create GrowthBook instance with short TTL 941 | gb = GrowthBook( 942 | api_host="https://cdn.growthbook.io", 943 | client_key="test-key", 944 | cache_ttl=1 # 1 second TTL for testing 945 | ) 946 | 947 | try: 948 | # Initial evaluation - should trigger first load 949 | assert gb.is_on('test_feature') == False 950 | assert call_count == 1 951 | 952 | # Manually expire the cache by setting expiry time to past 953 | cache_key = "https://cdn.growthbook.io::test-key" 954 | if hasattr(feature_repo.cache, 'cache') and cache_key in feature_repo.cache.cache: 955 | feature_repo.cache.cache[cache_key].expires = time() - 10 956 | 957 | # Next evaluation should automatically refresh cache and update features 958 | assert gb.is_on('test_feature') == True 959 | assert call_count == 2 960 | 961 | finally: 962 | gb.destroy() 963 | feature_repo.clear_cache() 964 | 965 | 966 | def test_multiple_instances_get_updated_on_cache_expiry(mocker): 967 | """Test that multiple GrowthBook instances all get updated when cache expires during evaluation""" 968 | mock_responses = [ 969 | {"features": {"test_feature": {"defaultValue": "v1"}}, "savedGroups": {}}, 970 | {"features": {"test_feature": {"defaultValue": "v2"}}, "savedGroups": {}} 971 | ] 972 | 973 | call_count = 0 974 | def mock_fetch_features(api_host, client_key, decryption_key=""): 975 | nonlocal call_count 976 | response = mock_responses[min(call_count, len(mock_responses) - 1)] 977 | call_count += 1 978 | return 
response 979 | 980 | feature_repo.clear_cache() 981 | m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features) 982 | 983 | # Create multiple GrowthBook instances 984 | gb1 = GrowthBook(api_host="https://cdn.growthbook.io", client_key="test-key") 985 | gb2 = GrowthBook(api_host="https://cdn.growthbook.io", client_key="test-key") 986 | 987 | try: 988 | # Initial evaluation from first instance - should trigger first load 989 | assert gb1.get_feature_value('test_feature', 'default') == "v1" 990 | assert call_count == 1 991 | 992 | # Second instance should use cached value (no additional API call) 993 | assert gb2.get_feature_value('test_feature', 'default') == "v1" 994 | assert call_count == 1 # Still 1, used cache 995 | 996 | # Manually expire the cache 997 | cache_key = "https://cdn.growthbook.io::test-key" 998 | if hasattr(feature_repo.cache, 'cache') and cache_key in feature_repo.cache.cache: 999 | feature_repo.cache.cache[cache_key].expires = time() - 10 1000 | 1001 | # Next evaluation should automatically refresh and notify both instances via callbacks 1002 | assert gb1.get_feature_value('test_feature', 'default') == "v2" 1003 | assert call_count == 2 1004 | 1005 | # Second instance should also have the updated value due to callbacks 1006 | assert gb2.get_feature_value('test_feature', 'default') == "v2" 1007 | 1008 | finally: 1009 | gb1.destroy() 1010 | gb2.destroy() 1011 | feature_repo.clear_cache() -------------------------------------------------------------------------------- /tests/test_growthbook_client.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from unittest.mock import patch 3 | 4 | import pytest_asyncio 5 | 6 | try: 7 | from unittest.mock import AsyncMock 8 | except ImportError: 9 | # For Python 3.7 compatibility 10 | from unittest.mock import MagicMock 11 | class AsyncMock(MagicMock): 12 | async def __call__(self, *args, **kwargs): 13 | 
return super(AsyncMock, self).__call__(*args, **kwargs) 14 | 15 | from growthbook import InMemoryStickyBucketService 16 | import pytest 17 | import asyncio 18 | import os 19 | import json 20 | 21 | from growthbook.common_types import Experiment, Options 22 | from growthbook.growthbook_client import ( 23 | GrowthBookClient, 24 | UserContext, 25 | FeatureRefreshStrategy, 26 | EnhancedFeatureRepository 27 | ) 28 | 29 | @pytest.fixture 30 | def mock_features_response(): 31 | return { 32 | "features": { 33 | "test-feature": { 34 | "defaultValue": True, 35 | "rules": [] 36 | } 37 | }, 38 | "savedGroups": {} 39 | } 40 | 41 | @pytest.fixture 42 | def mock_options(): 43 | return Options( 44 | api_host="https://test.growthbook.io", 45 | client_key="test_key", 46 | decryption_key="test_decrypt", 47 | cache_ttl=60, 48 | enabled=True, 49 | refresh_strategy=FeatureRefreshStrategy.STALE_WHILE_REVALIDATE 50 | ) 51 | 52 | 53 | @pytest.fixture 54 | def mock_sse_data(): 55 | return { 56 | 'type': 'features', 57 | 'data': { 58 | 'features': { 59 | 'feature-1': {'defaultValue': True}, 60 | 'feature-2': {'defaultValue': False} 61 | } 62 | } 63 | } 64 | 65 | @pytest_asyncio.fixture(autouse=True) 66 | async def cleanup_singleton(): 67 | """Clean up singleton instance between tests""" 68 | yield 69 | # Clear singleton instances after each test 70 | EnhancedFeatureRepository._instances = {} 71 | await asyncio.sleep(0.1) # Allow tasks to clean up 72 | 73 | @pytest.mark.asyncio 74 | async def test_initialization_for_failure(mock_options): 75 | with patch('growthbook.growthbook_client.EnhancedFeatureRepository.load_features_async') as mock_load: 76 | mock_load.side_effect = Exception("Network error") 77 | client = GrowthBookClient(mock_options) 78 | success = await client.initialize() 79 | assert success == False 80 | assert mock_load.call_count == 1 81 | 82 | @pytest.mark.asyncio 83 | async def test_sse_connection_lifecycle(mock_options, mock_features_response): 84 | with 
patch('growthbook.growthbook_client.EnhancedFeatureRepository.load_features_async', 85 | new_callable=AsyncMock, return_value=mock_features_response) as mock_load: 86 | 87 | client = GrowthBookClient( 88 | Options(**{**mock_options.__dict__, 89 | "refresh_strategy": FeatureRefreshStrategy.SERVER_SENT_EVENTS}) 90 | ) 91 | 92 | with patch('growthbook.growthbook_client.EnhancedFeatureRepository._maintain_sse_connection') as mock_sse: 93 | await client.initialize() 94 | assert mock_sse.called 95 | await client.close() 96 | 97 | @pytest.mark.asyncio 98 | async def test_feature_repository_load(): 99 | repo = EnhancedFeatureRepository( 100 | api_host="https://test.growthbook.io", 101 | client_key="test_key" 102 | ) 103 | features_response = { 104 | "features": {"test-feature": {"defaultValue": True}}, 105 | "savedGroups": {} 106 | } 107 | 108 | with patch('growthbook.FeatureRepository.load_features_async', 109 | new_callable=AsyncMock, return_value=features_response) as mock_load: 110 | result = await repo.load_features_async(api_host="", client_key="") 111 | assert result == features_response 112 | 113 | @pytest.mark.asyncio 114 | async def test_initialize_success(mock_options, mock_features_response): 115 | with patch('growthbook.growthbook_client.EnhancedFeatureRepository.load_features_async', 116 | new_callable=AsyncMock, return_value=mock_features_response) as mock_load, \ 117 | patch('growthbook.growthbook_client.EnhancedFeatureRepository.start_feature_refresh', 118 | new_callable=AsyncMock, return_value=None): 119 | 120 | client = GrowthBookClient(mock_options) 121 | success = await client.initialize() 122 | 123 | # result = client.eval_feature('test-feature') 124 | # print(f'result= {result}') 125 | assert success == True 126 | 127 | @pytest.mark.asyncio 128 | async def test_refresh_operation_lock(): 129 | """Verify refresh_operation lock prevents concurrent refreshes""" 130 | repo = EnhancedFeatureRepository( 131 | api_host="https://test.growthbook.io", 132 | 
client_key="test_key" 133 | ) 134 | 135 | results = [] 136 | async def refresh_task(): 137 | async with repo.refresh_operation() as should_refresh: 138 | results.append(should_refresh) 139 | await asyncio.sleep(0.1) # Simulate work 140 | return should_refresh 141 | 142 | await asyncio.gather(*[refresh_task() for _ in range(5)]) 143 | assert sum(1 for r in results if r) == 1 # Only one task should get True 144 | assert sum(1 for r in results if not r) == 4 # Rest should get False 145 | 146 | 147 | @pytest.mark.asyncio 148 | async def test_concurrent_feature_updates(): 149 | """Verify FeatureCache thread safety during concurrent updates""" 150 | repo = EnhancedFeatureRepository( 151 | api_host="https://test.growthbook.io", 152 | client_key="test_key" 153 | ) 154 | features = {f"feature-{i}": {"defaultValue": i} for i in range(10)} 155 | 156 | async def update_features(feature_subset): 157 | await repo._handle_feature_update({"features": feature_subset, "savedGroups": {}}) 158 | 159 | await asyncio.gather(*[ 160 | update_features({k: features[k]}) 161 | for k in features 162 | ]) 163 | 164 | cache_state = repo._feature_cache.get_current_state() 165 | # Verify all features were properly stored 166 | assert cache_state["features"] == features 167 | assert cache_state["savedGroups"] == {} 168 | 169 | @pytest.mark.asyncio 170 | async def test_callback_thread_safety(): 171 | """Verify callback invocations are thread-safe""" 172 | repo = EnhancedFeatureRepository( 173 | api_host="https://test.growthbook.io", 174 | client_key="test_key" 175 | ) 176 | 177 | received_callbacks = [] 178 | async def test_callback(features): 179 | received_callbacks.append(features) 180 | 181 | repo.add_callback(test_callback) 182 | test_features = [{"features": {f"f{i}": {"value": i}}, "savedGroups": {}} for i in range(5)] 183 | 184 | await asyncio.gather(*[ 185 | repo._handle_feature_update(update) 186 | for update in test_features 187 | ]) 188 | 189 | assert len(received_callbacks) == 5 190 | 
@pytest.mark.asyncio
async def test_http_refresh():
    """Verify HTTP refresh mechanism works correctly.

    Patches ``load_features_async`` with two successive payloads and checks
    that the background refresh loop polls repeatedly and the cache ends up
    holding the newest payload.
    """
    repo = EnhancedFeatureRepository(
        api_host="https://test.growthbook.io",
        client_key="test_key"
    )

    # Mock responses for load_features_async: first cycle sees v1, every
    # later cycle sees v2 (extra copies guard against additional cycles
    # firing before cleanup).
    feature_updates = [
        {"features": {"feature1": {"defaultValue": 1}}, "savedGroups": {}},
        {"features": {"feature1": {"defaultValue": 2}}, "savedGroups": {}}
    ]

    mock_load = AsyncMock()
    mock_load.side_effect = [feature_updates[0], feature_updates[1], *[feature_updates[1]] * 10]

    try:
        with patch('growthbook.FeatureRepository.load_features_async', mock_load):
            # Start HTTP refresh with a short interval for testing; keep a
            # reference so the task is not garbage-collected mid-test.
            refresh_task = asyncio.create_task(repo._start_http_refresh(interval=0.1))

            # Wait for two refresh cycles
            await asyncio.sleep(0.3)

            # One initial load plus two 0.1s cycles within the 0.3s window.
            assert mock_load.call_count == 3

            # Verify the latest feature state reached the cache
            cache_state = repo._feature_cache.get_current_state()
            assert cache_state["features"]["feature1"] == {"defaultValue": 2}
    finally:
        # Ensure cleanup happens even if test fails
        await repo.stop_refresh()
        # Wait a bit to ensure task is fully cleaned up
        await asyncio.sleep(0.1)

@pytest.mark.asyncio
async def test_initialization_state_verification(mock_options, mock_features_response):
    """Verify feature state and callback registration after initialization."""
    callback_called = False
    features_received = None

    async def test_callback(features):
        # Records the payload delivered through the repository callback.
        nonlocal callback_called, features_received
        callback_called = True
        features_received = features

    with patch('growthbook.FeatureRepository.load_features_async',
              new_callable=AsyncMock, return_value=mock_features_response):

        client = GrowthBookClient(mock_options)
        client._features_repository.add_callback(test_callback)

        success = await client.initialize()
        # Yield once so the async callback scheduled during initialize runs.
        await asyncio.sleep(0)

        assert success == True
        assert callback_called == True
        assert features_received == mock_features_response
        # Convert Feature objects to dict for comparison
        features_dict = {
            key: {"defaultValue": feature.defaultValue, "rules": feature.rules}
            for key, feature in client._global_context.features.items()
        }
        assert features_dict == mock_features_response["features"]

@pytest.mark.asyncio
async def test_sse_event_handling(mock_options):
    """Test SSE event handling and reconnection logic.

    Events are fed straight into ``_handle_feature_update`` rather than over
    a real SSE connection; non-'features' events must be ignored.
    """
    events = [
        {'type': 'features', 'data': {'features': {'feature1': {'defaultValue': 1}}}},
        {'type': 'ping', 'data': {}},  # Should be ignored
        {'type': 'features', 'data': {'features': {'feature1': {'defaultValue': 2}}}}
    ]

    with patch('growthbook.FeatureRepository.load_features_async',
              new_callable=AsyncMock, return_value={"features": {}, "savedGroups": {}}):

        # Create options with SSE strategy
        sse_options = Options(
            api_host=mock_options.api_host,
            client_key=mock_options.client_key,
            refresh_strategy=FeatureRefreshStrategy.SERVER_SENT_EVENTS
        )

        client = GrowthBookClient(sse_options)

        try:
            await client.initialize()

            # Simulate SSE events directly; 'ping' events are skipped
            for event in events:
                if event['type'] == 'features':
                    await client._features_repository._handle_feature_update(event['data'])

            # The last 'features' event must win
            assert client._features_repository._feature_cache.get_current_state()["features"]["feature1"]["defaultValue"] == 2
        finally:
            # Ensure we clean up the SSE connection
            await client.close()

@pytest.mark.asyncio
async def test_http_refresh_backoff():
    """Test HTTP refresh backoff strategy.

    The mocked loader fails three times, then succeeds; the test checks the
    delays grow (roughly) during failures and return to the normal interval
    after the first success.
    """
    repo = EnhancedFeatureRepository(
        api_host="https://test.growthbook.io",
        client_key="test_key"
    )

    call_times = []
    failure_count = 0
    success_time = None
    done = asyncio.Event()

    async def mock_load(*args, **kwargs):
        # Fail the first 3 calls, then succeed; signal `done` once enough
        # post-success calls have been observed.
        nonlocal failure_count
        current_time = asyncio.get_event_loop().time()
        call_times.append(current_time)

        if failure_count < 3:
            failure_count += 1
            raise ConnectionError("Network error")

        nonlocal success_time
        if not success_time:
            success_time = current_time
        # Wait for at least one more call after success to verify normal interval
        if len(call_times) >= 5:
            done.set()
        return {"features": {}, "savedGroups": {}}

    try:
        with patch('growthbook.FeatureRepository.load_features_async', side_effect=mock_load):
            # Keep a reference so the refresh task is not garbage-collected.
            refresh_task = asyncio.create_task(repo._start_http_refresh(interval=0.1))
            try:
                await asyncio.wait_for(done.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                pass

            # Verify we had failures followed by success
            assert failure_count == 3, f"Expected 3 failures, got {failure_count}"
            assert len(call_times) >= 4, f"Expected at least 4 calls, got {len(call_times)}"

            # Verify backoff behavior - delays should generally increase during failures
            if len(call_times) >= 3:
                first_delay = call_times[1] - call_times[0]
                second_delay = call_times[2] - call_times[1]
                # Allow some flexibility in CI environments
                assert second_delay >= first_delay * 0.8, f"Second delay ({second_delay:.3f}) should be >= 80% of first delay ({first_delay:.3f})"

            # After success, verify we have reasonable timing for normal operation
            if len(call_times) >= 5:
                post_success_delay = call_times[4] - call_times[3]
                assert 0.05 <= post_success_delay <= 0.2, f"Post-success delay should be near 0.1s, got {post_success_delay:.3f}"

    finally:
        # Ensure cleanup happens even if test fails
        await repo.stop_refresh()
        # Wait a bit to ensure task is fully cleaned up
        await asyncio.sleep(0.1)

@pytest.mark.asyncio
async def test_concurrent_initialization():
    """Test concurrent initialization attempts.

    Five initialize() calls run at once against a loader gated on an event;
    all must succeed, and the shared counter shows the loader ran more than
    once (i.e. concurrent attempts were not collapsed into a single load).
    """
    shared_response = {
        "features": {
            "test-feature": {"defaultValue": 0}
        },
        "savedGroups": {}
    }
    loading_started = asyncio.Event()
    loading_wait = asyncio.Event()
    load_count = 0

    async def mock_load(*args, **kwargs):
        # Block every load on `loading_wait` so all initializations overlap.
        nonlocal load_count
        load_count += 1
        loading_started.set()
        await loading_wait.wait()
        shared_response["features"]["test-feature"]["defaultValue"] += 1
        return shared_response

    with patch('growthbook.FeatureRepository.load_features_async', side_effect=mock_load):
        client = GrowthBookClient(Options(
            api_host="https://test.growthbook.io",
            client_key="test_key"
        ))

        try:
            # Start concurrent initializations
            init_tasks = [asyncio.create_task(client.initialize()) for _ in range(5)]

            # Wait for the first load attempt to start
            await loading_started.wait()
            await asyncio.sleep(0.1)
            loading_wait.set()

            results = await asyncio.gather(*init_tasks, return_exceptions=True)

            # Verify results
            assert all(r == True for r in results)
            assert load_count > 1
            final_cache = client._features_repository._feature_cache.get_current_state()
            assert final_cache["features"]["test-feature"]["defaultValue"] == 6
        finally:
            # Ensure proper cleanup
            await client.close()
            # Wait for any pending tasks to complete
            await asyncio.sleep(0.1)
            # Get all tasks and cancel any remaining ones
            for task in asyncio.all_tasks():
                if not task.done() and task != asyncio.current_task():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

def pytest_generate_tests(metafunc):
    """Generate test cases from cases.json.

    Parametrizes any `*_data` fixture from the matching section of the
    shared cases.json spec file.
    """
    # Skip if the test doesn't need case data
    if not any(x.endswith('_data') for x in metafunc.fixturenames):
        return

    folder = os.path.abspath(os.path.dirname(__file__))
    jsonfile = os.path.join(folder, "cases.json")
    with open(jsonfile) as file:
        data = json.load(file)

    # Map test functions to their data section in cases.json
    test_data_map = {
        'test_eval_feature': 'feature',
        'test_experiment_run': 'run',
        'test_sticky_bucket': 'stickyBucket'
    }

    for func, data_key in test_data_map.items():
        fixture_name = f"{func}_data"
        if fixture_name in metafunc.fixturenames:
            metafunc.parametrize(fixture_name, data.get(data_key, []))

@pytest.mark.asyncio
async def test_eval_feature(test_eval_feature_data, base_client_setup):
    """Test feature evaluation similar to test_feature in test_growthbook.py"""
    _, ctx, key, expected = test_eval_feature_data

    # Get base setup
    user_attrs, client_opts, features_data = base_client_setup(ctx)

    # Clear any existing singleton instances
    EnhancedFeatureRepository._instances = {}

    # Bound before try: if client construction fails, the finally block must
    # not raise NameError and mask the original exception.
    client = None
    try:
        # Set up mocks for both FeatureRepository and EnhancedFeatureRepository
        with patch('growthbook.FeatureRepository.load_features_async',
                  new_callable=AsyncMock, return_value=features_data), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.start_feature_refresh',
                  new_callable=AsyncMock), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.stop_refresh',
                  new_callable=AsyncMock):

            # Create and initialize client
            async with GrowthBookClient(Options(**client_opts)) as client:
                result = await client.eval_feature(key, UserContext(**user_attrs))

                if "experiment" in expected:
                    expected["experiment"] = Experiment(**expected["experiment"]).to_dict()

                assert result.to_dict() == expected
    except Exception as e:
        print(f"Error during test execution: {str(e)}")
        raise
    finally:
        if client is not None:
            await client.close()
        await asyncio.sleep(0.1)

@pytest.mark.asyncio
async def test_experiment_run(test_experiment_run_data, base_client_setup):
    """Test experiment running similar to test_run in test_growthbook.py"""
    _, ctx, exp, value, inExperiment, hashUsed = test_experiment_run_data

    # Get base setup
    user_attrs, client_opts, features_data = base_client_setup(ctx)

    # Clear any existing singleton instances
    EnhancedFeatureRepository._instances = {}

    # Bound before try so the finally block cannot fail with NameError when
    # client construction raises.
    client = None
    try:
        # Set up mocks for both FeatureRepository and EnhancedFeatureRepository
        with patch('growthbook.FeatureRepository.load_features_async',
                  new_callable=AsyncMock, return_value=features_data), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.start_feature_refresh',
                  new_callable=AsyncMock), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.stop_refresh',
                  new_callable=AsyncMock):

            # Create and initialize client
            async with GrowthBookClient(Options(**client_opts)) as client:
                result = await client.run(Experiment(**exp), UserContext(**user_attrs))

                # Verify experiment results
                assert result.value == value
                assert result.inExperiment == inExperiment
                assert result.hashUsed == hashUsed
    except Exception as e:
        print(f"Error during test execution: {str(e)}")
        raise
    finally:
        if client is not None:
            await client.close()
        await asyncio.sleep(0.1)

@pytest.mark.asyncio
async def test_feature_methods():
    """Test feature helper methods (isOn, isOff, getFeatureValue)"""
    features_data = {
        "features": {
            "featureOn": {"defaultValue": 12},
            "featureNone": {"defaultValue": None},
            "featureOff": {"defaultValue": 0}
        },
        "savedGroups": {}
    }

    # Simple client options
    client_opts = {
        'api_host': "https://localhost.growthbook.io",
        'client_key': "test-key",
        'enabled': True
    }

    # Clear any existing singleton instances
    EnhancedFeatureRepository._instances = {}
    user_context = UserContext(attributes={"id": "user-1"})

    # Bound before try so cleanup cannot mask a construction failure.
    client = None
    try:
        # Set up mocks for both FeatureRepository and EnhancedFeatureRepository
        with patch('growthbook.FeatureRepository.load_features_async',
                  new_callable=AsyncMock, return_value=features_data), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.start_feature_refresh',
                  new_callable=AsyncMock), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.stop_refresh',
                  new_callable=AsyncMock):

            # Create and initialize client
            async with GrowthBookClient(Options(**client_opts)) as client:
                # Test isOn: truthy values on, falsy (0/None) off
                assert await client.is_on("featureOn", user_context) is True
                assert await client.is_on("featureOff", user_context) is False
                assert await client.is_on("featureNone", user_context) is False

                # Test isOff
                assert await client.is_off("featureOn", user_context) is False
                assert await client.is_off("featureOff", user_context) is True
                assert await client.is_off("featureNone", user_context) is True

                # Test getFeatureValue: fallback is used only when the
                # feature value is None or the feature is missing
                assert await client.get_feature_value("featureOn", 15, user_context) == 12
                assert await client.get_feature_value("featureOff", 10, user_context) == 0
                assert await client.get_feature_value("featureNone", 10, user_context) == 10
                assert await client.get_feature_value("nonexistent", "default", user_context) == "default"
    except Exception as e:
        print(f"Error during test execution: {str(e)}")
        raise
    finally:
        if client is not None:
            await client.close()
        await asyncio.sleep(0.1)

@pytest.fixture
def base_client_setup():
    """Common setup for client tests.

    Returns a callable that splits a cases.json context dict into
    (user attributes, client options, features payload).
    """
    def _setup(ctx):
        # Separate client options from user context
        user_attrs = {
            "attributes": ctx.get("attributes", {}),
            "url": ctx.get("url", ""),
            "groups": ctx.get("groups", {}),
            "forced_variations": ctx.get("forcedVariations", {})
        }

        # Base client options
        client_opts = {
            'api_host': "https://localhost.growthbook.io",
            'client_key': "test-key",
            'enabled': ctx.get("enabled", True),
            'qa_mode': ctx.get("qaMode", False)
        }

        # Features data structure
        features_data = {
            "features": ctx.get("features", {}),
            "savedGroups": ctx.get("savedGroups", {})
        }

        return user_attrs, client_opts, features_data
    return _setup

@pytest.mark.asyncio
async def test_sticky_bucket(test_sticky_bucket_data, base_client_setup):
    """Test sticky bucket functionality in GrowthBookClient"""
    _, ctx, initial_docs, key, expected_result, expected_docs = test_sticky_bucket_data

    # Initialize sticky bucket service with test data
    service = InMemoryStickyBucketService()

    # Add initial documents to the service
    for doc in initial_docs:
        service.save_assignments(doc)

    # Handle sticky bucket identifier attributes mapping (camelCase spec key
    # -> snake_case Options key)
    if 'stickyBucketIdentifierAttributes' in ctx:
        ctx['sticky_bucket_identifier_attributes'] = ctx['stickyBucketIdentifierAttributes']
        ctx.pop('stickyBucketIdentifierAttributes')

    # Handle sticky bucket assignment docs
    if 'stickyBucketAssignmentDocs' in ctx:
        service.docs = ctx['stickyBucketAssignmentDocs']
        ctx.pop('stickyBucketAssignmentDocs')

    # Get base setup
    user_attrs, client_opts, features_data = base_client_setup(ctx)

    # Add sticky bucket service to client options
    client_opts['sticky_bucket_service'] = service

    # Clear any existing singleton instances
    EnhancedFeatureRepository._instances = {}

    # Bound before try so the finally block cannot fail with NameError when
    # client construction raises.
    client = None
    try:
        # Set up mocks
        with patch('growthbook.FeatureRepository.load_features_async',
                  new_callable=AsyncMock, return_value=features_data), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.start_feature_refresh',
                  new_callable=AsyncMock), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.stop_refresh',
                  new_callable=AsyncMock):

            # Create and initialize client
            async with GrowthBookClient(Options(**client_opts)) as client:
                # Evaluate feature
                result = await client.eval_feature(key, UserContext(**user_attrs))

                # Verify experiment result
                if not result.experimentResult:
                    assert None == expected_result
                else:
                    assert result.experimentResult.to_dict() == expected_result

                # Verify sticky bucket assignments - check each expected doc individually
                for doc_key, expected_doc in expected_docs.items():
                    assert service.docs[doc_key] == expected_doc
    except Exception as e:
        print(f"Error during test execution: {str(e)}")
        raise
    finally:
        if client is not None:
            await client.close()
        service.destroy()
        await asyncio.sleep(0.1)

async def getTrackingMock(client: GrowthBookClient):
    """Helper function to mock tracking for tests.

    Installs a recording callback as ``on_experiment_viewed`` and returns a
    zero-argument callable that yields the recorded [experiment, result]
    pairs. (Declared async for call-site symmetry; it performs no awaiting.)
    """
    calls = []

    def track(experiment, result):
        calls.append([experiment, result])

    client.options.on_experiment_viewed = track
    return lambda: calls

@pytest.mark.asyncio
async def test_tracking():
    """Test experiment tracking behavior.

    The same experiment+user must be tracked once; a new experiment or a
    changed user id triggers a fresh tracking call.
    """
    # Create client with minimal options
    client = GrowthBookClient(Options(
        api_host="https://localhost.growthbook.io",
        client_key="test-key",
        enabled=True
    ))

    getMockedCalls = await getTrackingMock(client)

    # Create test experiments
    exp1 = Experiment(
        key="my-tracked-test",
        variations=[0, 1],
    )
    exp2 = Experiment(
        key="my-other-tracked-test",
        variations=[0, 1],
    )

    # Create user context
    user_context = UserContext(attributes={"id": "1"})

    try:
        # Set up mocks for feature repository
        with patch('growthbook.FeatureRepository.load_features_async',
                  new_callable=AsyncMock, return_value={"features": {}, "savedGroups": {}}), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.start_feature_refresh',
                  new_callable=AsyncMock), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.stop_refresh',
                  new_callable=AsyncMock):

            # Initialize client
            await client.initialize()

            # Run experiments
            res1 = await client.run(exp1, user_context)
            await client.run(exp1, user_context)  # Should not track duplicate
            await client.run(exp1, user_context)  # Should not track duplicate
            res4 = await client.run(exp2, user_context)

            # Change user attributes
            user_context.attributes = {"id": "2"}
            res5 = await client.run(exp2, user_context)

            # Verify tracking calls
            calls = getMockedCalls()
            assert len(calls) == 3, "Expected exactly 3 tracking calls"
            assert calls[0] == [exp1, res1], "First tracking call mismatch"
            assert calls[1] == [exp2, res4], "Second tracking call mismatch"
            assert calls[2] == [exp2, res5], "Third tracking call mismatch"

    finally:
        await client.close()

@pytest.mark.asyncio
async def test_handles_tracking_errors():
    """Test graceful handling of tracking callback errors."""
    client = GrowthBookClient(Options(
        api_host="https://localhost.growthbook.io",
        client_key="test-key",
        enabled=True
    ))

    # Set up tracking callback that raises an error
    def failing_track(experiment, result):
        raise Exception("Tracking failed")

    client.options.on_experiment_viewed = failing_track

    # Create test experiment
    exp = Experiment(
        key="error-test",
        variations=[0, 1],
    )
    user_context = UserContext(attributes={"id": "1"})

    try:
        # Set up mocks
        with patch('growthbook.FeatureRepository.load_features_async',
                  new_callable=AsyncMock, return_value={"features": {}, "savedGroups": {}}), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.start_feature_refresh',
                  new_callable=AsyncMock), \
             patch('growthbook.growthbook_client.EnhancedFeatureRepository.stop_refresh',
                  new_callable=AsyncMock):

            await client.initialize()

            # Should not raise exception despite tracking error
            result = await client.run(exp, user_context)
            assert result is not None, "Experiment should run despite tracking error"

    finally:
        await client.close()