├── .coveragerc ├── .github ├── changelog_template.md ├── get-changelog-diff.sh ├── has-functional-changes.sh ├── is-version-number-acceptable.sh ├── publish-git-tag.sh ├── request-model-versions.sh └── workflows │ ├── pr.yml │ └── push.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── changelog.yaml ├── changelog_entry.yaml ├── dashboard ├── app.py └── experiments │ ├── gpt4_api_user.py │ └── requirements.txt ├── docker └── Dockerfile ├── gcp ├── Dockerfile ├── README.md ├── bump_country_package.py ├── dispatch.yaml ├── export.py └── policyengine_api │ ├── Dockerfile │ ├── app.yaml │ └── start.sh ├── policyengine_api ├── ai_prompts │ ├── __init__.py │ └── simulation_analysis_prompt.py ├── ai_templates │ └── simulation_analysis_template.py ├── api.py ├── constants.py ├── country.py ├── data │ ├── README.md │ ├── __init__.py │ ├── data.py │ ├── initialise.sql │ └── initialise_local.sql ├── endpoints │ ├── __init__.py │ ├── economy │ │ ├── compare.py │ │ └── reform_impact.py │ ├── home.py │ ├── household.py │ ├── policy.py │ └── simulation.py ├── gcp_logging.py ├── jobs │ ├── __init__.py │ ├── base_job.py │ ├── calculate_economy_simulation_job.py │ └── tasks │ │ ├── __init__.py │ │ └── compute_general_economy.py ├── openapi_spec.yaml ├── routes │ ├── ai_prompt_routes.py │ ├── economy_routes.py │ ├── error_routes.py │ ├── household_routes.py │ ├── metadata_routes.py │ ├── policy_routes.py │ ├── simulation_analysis_routes.py │ ├── tracer_analysis_routes.py │ └── user_profile_routes.py ├── services │ ├── __init__.py │ ├── ai_analysis_service.py │ ├── ai_prompt_service.py │ ├── economy_service.py │ ├── household_service.py │ ├── job_service.py │ ├── metadata_service.py │ ├── policy_service.py │ ├── reform_impacts_service.py │ ├── simulation_analysis_service.py │ ├── tracer_analysis_service.py │ └── user_service.py ├── setup_data.py ├── utils │ ├── __init__.py │ ├── cache_utils.py │ ├── get_current_law.py │ ├── hugging_face.py │ ├── json.py │ 
├── payload_validators │ │ ├── __init__.py │ │ ├── ai │ │ │ ├── __init__.py │ │ │ └── validate_sim_analysis_payload.py │ │ ├── validate_country.py │ │ ├── validate_household_payload.py │ │ ├── validate_set_policy_payload.py │ │ └── validate_tracer_analysis_payload.py │ ├── singleton.py │ └── v2_v1_comparison.py └── worker.py ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── data ├── calculate_us_1_data.json ├── calculate_us_2_data.json ├── or_rebate_measure_118.json ├── test_economy_1_policy_1.json ├── uk_household.json ├── us_household.json └── utah_reform.json ├── env_variables └── test_environment_variables.py ├── fixtures ├── __init__.py ├── jobs │ ├── __init__.py │ └── calculate_economy_simulation_job.py ├── services │ ├── ai_analysis_service.py │ ├── household_fixtures.py │ ├── policy_service.py │ ├── tracer_analysis_service.py │ ├── tracer_fixture_service.py │ └── user_service.py ├── simulation_analysis_prompt_fixtures.py └── utils │ └── v2_v1_comparison.py ├── snapshots ├── simulation_analysis_prompt_dataset_enhanced_cps.txt ├── simulation_analysis_prompt_region_enhanced_us.txt ├── simulation_analysis_prompt_uk.txt └── simulation_analysis_prompt_us.txt ├── to_refactor ├── api │ ├── test_api.py │ ├── test_hello_world.yaml │ ├── test_liveness.yaml │ ├── test_readiness.yaml │ ├── test_uk_baseline_policy.yaml │ ├── test_uk_metadata.yaml │ └── test_us_create_empty_household.yaml ├── fixtures │ ├── simulation_analysis_fixtures.py │ └── to_refactor_household_fixtures.py └── python │ ├── test_ai_analysis_service_old.py │ ├── test_calculate_us_1.py │ ├── test_data.py │ ├── test_economy_1.py │ ├── test_error_routes.py │ ├── test_household_routes.py │ ├── test_policy.py │ ├── test_policy_service_old.py │ ├── test_simulation_analysis_routes.py │ ├── test_tracer_analysis_routes.py │ ├── test_us_policy_macro.py │ ├── test_user_profile_routes.py │ ├── test_validate_country.py │ ├── test_validate_household_payload.py │ ├── test_validate_sim_analysis_payload.py │ 
├── test_varying_your_earnings.py │ └── test_yearly_var_removal.py └── unit ├── __init__.py ├── ai_prompts └── test_simulation_analysis_prompt.py ├── conftest.py ├── jobs ├── __init__.py └── test_calculate_economy_simulation_job.py ├── services ├── test_ai_analysis_service.py ├── test_execute_analysis.py ├── test_household_service.py ├── test_metadata_service.py ├── test_policy_service.py ├── test_tracer_analysis_service.py ├── test_tracer_service.py ├── test_update_profile_service.py └── test_user_service.py └── utils └── test_v2_v1_comparison.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = policyengine_api 3 | 4 | [report] 5 | omit = 6 | tests/* 7 | */tests/* 8 | tests/conftest.py 9 | 10 | -------------------------------------------------------------------------------- /.github/changelog_template.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | {{changelog}} -------------------------------------------------------------------------------- /.github/get-changelog-diff.sh: -------------------------------------------------------------------------------- 1 | last_tagged_commit=`git describe --tags --abbrev=0 --first-parent` 2 | git --no-pager diff $last_tagged_commit -- CHANGELOG.md -------------------------------------------------------------------------------- /.github/has-functional-changes.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | IGNORE_DIFF_ON="README.md CONTRIBUTING.md Makefile docs/* .gitignore LICENSE* .github/* data/*" 4 | 5 | last_tagged_commit=`git describe --tags --abbrev=0 --first-parent` # --first-parent ensures we don't follow tags not published in master through an unlikely intermediary merge commit 6 | 7 | if git diff-index --name-only --exit-code $last_tagged_commit -- . `echo " $IGNORE_DIFF_ON" | sed 's/ / :(exclude)/g'` # Check if any file that has not been listed in IGNORE_DIFF_ON has changed since the last tag was published. 8 | then 9 | echo "No functional changes detected." 10 | exit 1 11 | else echo "The functional files above were changed." 12 | fi 13 | -------------------------------------------------------------------------------- /.github/is-version-number-acceptable.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | if [[ ${GITHUB_REF#refs/heads/} == master ]] 4 | then 5 | echo "No need for a version check on master." 6 | exit 0 7 | fi 8 | 9 | if ! $(dirname "$BASH_SOURCE")/has-functional-changes.sh 10 | then 11 | echo "No need for a version update." 12 | exit 0 13 | fi 14 | 15 | current_version=`python setup.py --version` 16 | 17 | if git rev-parse --verify --quiet $current_version 18 | then 19 | echo "Version $current_version already exists in commit:" 20 | git --no-pager log -1 $current_version 21 | echo 22 | echo "Update the version number in setup.py before merging this branch into master." 23 | echo "Look at the CONTRIBUTING.md file to learn how the version number should be updated." 24 | exit 1 25 | fi 26 | 27 | if ! $(dirname "$BASH_SOURCE")/has-functional-changes.sh | grep --quiet CHANGELOG.md 28 | then 29 | echo "CHANGELOG.md has not been modified, while functional changes were made." 30 | echo "Explain what you changed before merging this branch into master." 31 | echo "Look at the CONTRIBUTING.md file to learn how to write the changelog."
32 | exit 2 33 | fi 34 | -------------------------------------------------------------------------------- /.github/publish-git-tag.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | git tag `python setup.py --version` 4 | git push --tags || true # update the repository version 5 | -------------------------------------------------------------------------------- /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | name: Pull request 2 | 3 | on: pull_request 4 | 5 | env: 6 | ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true 7 | 8 | jobs: 9 | lint: 10 | name: Lint 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout repo 14 | uses: actions/checkout@v4 15 | - name: Setup Python 16 | uses: actions/setup-python@v5 17 | - name: Format with Black 18 | uses: psf/black@stable 19 | with: 20 | options: ". -l 79 --check" 21 | check-version: 22 | name: Check version 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Checkout repo 26 | uses: actions/checkout@v4 27 | with: 28 | fetch-depth: 0 29 | repository: ${{ github.event.pull_request.head.repo.full_name }} 30 | ref: ${{ github.event.pull_request.head.ref }} 31 | - name: Set up Python 32 | uses: actions/setup-python@v5 33 | with: 34 | python-version: "3.11" 35 | - name: Build changelog 36 | run: pip install yaml-changelog && make changelog 37 | - name: Preview changelog update 38 | run: ".github/get-changelog-diff.sh" 39 | - name: Check version number has been properly updated 40 | run: ".github/is-version-number-acceptable.sh" 41 | test_container_builds: 42 | name: Docker 43 | runs-on: ubuntu-latest 44 | steps: 45 | - name: Checkout repo 46 | uses: actions/checkout@v4 47 | - name: Log in to the Container registry 48 | uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 49 | with: 50 | registry: ghcr.io 51 | username: ${{ github.actor }} 52 | password: ${{ 
secrets.POLICYENGINE_DOCKER }} 53 | - name: Build container 54 | run: docker build -t ghcr.io/policyengine/policyengine docker 55 | test_env_vars: 56 | name: Test environment variables 57 | runs-on: ubuntu-latest 58 | steps: 59 | - name: Checkout repo 60 | uses: actions/checkout@v4 61 | - name: Set up Python 62 | uses: actions/setup-python@v5 63 | with: 64 | python-version: "3.11" 65 | - name: Auth 66 | uses: google-github-actions/auth@v2 67 | with: 68 | credentials_json: ${{ secrets.GCP_SA_KEY }} 69 | - name: Install dependencies 70 | run: make install 71 | - name: Run environment variable tests 72 | run: pytest tests/env_variables/test_environment_variables.py 73 | env: 74 | POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN: ${{ secrets.POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN }} 75 | HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} 76 | POLICYENGINE_DB_PASSWORD: ${{ secrets.POLICYENGINE_DB_PASSWORD }} 77 | test: 78 | name: Test 79 | runs-on: ubuntu-latest 80 | needs: test_env_vars 81 | container: 82 | image: policyengine/policyengine-api 83 | steps: 84 | - name: Checkout repo 85 | uses: actions/checkout@v4 86 | - name: Auth 87 | uses: google-github-actions/auth@v2 88 | with: 89 | credentials_json: ${{ secrets.GCP_SA_KEY }} 90 | - name: Set up Cloud SDK 91 | uses: google-github-actions/setup-gcloud@v2 92 | with: 93 | project_id: policyengine-api 94 | - name: Install dependencies 95 | run: make install 96 | - name: Test the API 97 | run: make test 98 | env: 99 | POLICYENGINE_DB_PASSWORD: ${{ secrets.POLICYENGINE_DB_PASSWORD }} 100 | POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN: ${{ secrets.POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN }} 101 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 102 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 103 | HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} 104 | - name: Upload coverage to Codecov 105 | uses: codecov/codecov-action@v5 106 | with: 107 | token: ${{ secrets.CODECOV_TOKEN }} 108 | slug: PolicyEngine/policyengine-api 
-------------------------------------------------------------------------------- /.github/workflows/push.yml: -------------------------------------------------------------------------------- 1 | name: Push 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | env: 9 | ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true 10 | 11 | concurrency: 12 | group: deploy 13 | 14 | jobs: 15 | Lint: 16 | runs-on: ubuntu-latest 17 | if: | 18 | (github.repository == 'PolicyEngine/policyengine-api') 19 | && (github.event.head_commit.message == 'Update PolicyEngine API') 20 | steps: 21 | - name: Checkout repo 22 | uses: actions/checkout@v4 23 | - name: Check formatting 24 | uses: "lgeiger/black-action@master" 25 | with: 26 | args: ". -l 79 --check" 27 | versioning: 28 | name: Update versioning 29 | if: | 30 | (github.repository == 'PolicyEngine/policyengine-api') 31 | && !(github.event.head_commit.message == 'Update PolicyEngine API') 32 | runs-on: ubuntu-latest 33 | steps: 34 | - name: Checkout repo 35 | uses: actions/checkout@v4 36 | with: 37 | repository: ${{ github.event.pull_request.head.repo.full_name }} 38 | ref: ${{ github.event.pull_request.head.ref }} 39 | token: ${{ secrets.POLICYENGINE_GITHUB }} 40 | - name: Setup Python 41 | uses: actions/setup-python@v5 42 | with: 43 | python-version: "3.11" 44 | - name: Build changelog 45 | run: pip install yaml-changelog && make changelog 46 | - name: Preview changelog update 47 | run: ".github/get-changelog-diff.sh" 48 | - name: Update changelog 49 | uses: EndBug/add-and-commit@v9 50 | with: 51 | add: "."
52 | committer_name: Github Actions[bot] 53 | author_name: Github Actions[bot] 54 | message: Update PolicyEngine API 55 | deploy: 56 | name: Deploy API 57 | runs-on: ubuntu-latest 58 | if: | 59 | (github.repository == 'PolicyEngine/policyengine-api') 60 | && (github.event.head_commit.message == 'Update PolicyEngine API') 61 | steps: 62 | - name: Checkout repo 63 | uses: actions/checkout@v4 64 | - name: Setup Python 65 | uses: actions/setup-python@v5 66 | with: 67 | python-version: "3.11" 68 | - name: GCP authentication 69 | uses: "google-github-actions/auth@v2" 70 | with: 71 | credentials_json: "${{ secrets.GCP_SA_KEY }}" 72 | - name: Set up GCloud 73 | uses: "google-github-actions/setup-gcloud@v2" 74 | - name: Deploy 75 | run: make deploy 76 | env: 77 | POLICYENGINE_DB_PASSWORD: ${{ secrets.POLICYENGINE_DB_PASSWORD }} 78 | GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GCP_SA_KEY }} 79 | POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN: ${{ secrets.POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN }} 80 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 81 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 82 | HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} 83 | docker: 84 | name: Docker 85 | runs-on: ubuntu-latest 86 | steps: 87 | - name: Checkout repo 88 | uses: actions/checkout@v4 89 | - name: Log in to the Container registry 90 | uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 91 | with: 92 | registry: ghcr.io 93 | username: ${{ github.actor }} 94 | password: ${{ secrets.POLICYENGINE_DOCKER }} 95 | - name: Build container 96 | run: docker build -t ghcr.io/policyengine/policyengine docker 97 | - name: Push container 98 | run: docker push ghcr.io/policyengine/policyengine 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | **/__pycache__ 3 | *.egg-info 4 | .pytest_cache 5 | **/*.db 6 | **/*.db-journal 7 | .gac.json 8 
| .dbpw 9 | .github_microdata_token 10 | /app.yaml 11 | /Dockerfile 12 | dist/* 13 | **/*.rdb 14 | **/*.h5 15 | **/*.csv.gz 16 | .env 17 | 18 | # Ignore generated credentials from google-github-actions/auth 19 | gha-creds-*.json 20 | 21 | ## vscode settings 22 | /.vscode 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: 2 | pip install -e ".[dev]" --config-settings editable_mode=compat 3 | 4 | debug: 5 | FLASK_APP=policyengine_api.api FLASK_DEBUG=1 flask run --without-threads 6 | 7 | test-env-vars: 8 | pytest tests/env_variables 9 | 10 | test: 11 | MAX_HOUSEHOLDS=1000 coverage run -a --branch -m pytest tests/to_refactor tests/unit --disable-pytest-warnings 12 | coverage xml -i 13 | 14 | debug-test: 15 | MAX_HOUSEHOLDS=1000 FLASK_DEBUG=1 pytest -vv --durations=0 tests 16 | 17 | format: 18 | black . -l 79 19 | 20 | deploy: 21 | python gcp/export.py 22 | gcloud config set app/cloud_build_timeout 2400 23 | cp gcp/policyengine_api/* . 
24 | y | gcloud app deploy --service-account=github-deployment@policyengine-api.iam.gserviceaccount.com 25 | rm app.yaml 26 | rm Dockerfile 27 | rm .gac.json 28 | rm .dbpw 29 | 30 | changelog: 31 | build-changelog changelog.yaml --output changelog.yaml --update-last-date --start-from 0.1.0 --append-file changelog_entry.yaml 32 | build-changelog changelog.yaml --org PolicyEngine --repo policyengine-api --output CHANGELOG.md --template .github/changelog_template.md 33 | bump-version changelog.yaml setup.py policyengine_api/constants.py 34 | rm changelog_entry.yaml || true 35 | touch changelog_entry.yaml -------------------------------------------------------------------------------- /changelog_entry.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolicyEngine/policyengine-api/5a3e2ce40857c086299cfe0b53956c1e28a8e3ad/changelog_entry.yaml -------------------------------------------------------------------------------- /dashboard/app.py: -------------------------------------------------------------------------------- 1 | from policyengine_api.data.data import database 2 | import streamlit as st 3 | 4 | st.title("PolicyEngine API dashboard") 5 | 6 | # Add a text box that the user can enter a SQL query into, with a submit button and a table showing the results. 7 | 8 | st.subheader("Run a SQL query") 9 | 10 | query = st.text_area("Enter a SQL query", "SELECT * FROM policy LIMIT 10;") 11 | if st.button("Submit"): 12 | try: 13 | results = database.query(query) 14 | st.table(results.fetchall()) 15 | except Exception as e: 16 | st.error(e) 17 | 18 | # Enable the user to look up a policy by ID. 
19 | 20 | st.subheader("Look up a policy") 21 | 22 | policy_id = int( 23 | st.text_input("Enter a policy ID", "1", key="policy_lookup_text") 24 | ) 25 | country_id = st.text_input( 26 | "Enter a country ID", "uk", key="policy_lookup_country" 27 | ) 28 | if st.button("Look up policy", key="policy_lookup"): 29 | try: 30 | results = database.query( 31 | f"SELECT * FROM policy WHERE id IS '{policy_id}' AND country_id IS '{country_id}' LIMIT 10;" 32 | ) 33 | st.table(results.fetchall()) 34 | except Exception as e: 35 | st.error(e) 36 | 37 | # Enable the user to set the label of a policy. 38 | 39 | st.subheader("Set a policy's label") 40 | 41 | policy_id = int(st.text_input("Enter a policy ID", "1")) 42 | country_id = st.text_input("Enter a country ID", "uk") 43 | new_label = st.text_input( 44 | "Enter a new label", "New label", key="policy_label_text" 45 | ) 46 | if st.button("Set policy label", key="policy_label"): 47 | try: 48 | database.set_policy_label(policy_id, country_id, new_label) 49 | st.success("Success!") 50 | except Exception as e: 51 | st.error(e) 52 | 53 | # Enable the user to delete a policy. 
54 | 55 | st.subheader("Delete a policy") 56 | 57 | policy_id = int( 58 | st.text_input("Enter a policy ID", "1", key="policy_delete_text") 59 | ) 60 | country_id = st.text_input( 61 | "Enter a country ID", "uk", key="policy_delete_country" 62 | ) 63 | if st.button("Delete policy", key="policy_delete"): 64 | try: 65 | database.delete_policy(policy_id, country_id) 66 | st.success("Success!") 67 | except Exception as e: 68 | st.error(e) 69 | -------------------------------------------------------------------------------- /dashboard/experiments/requirements.txt: -------------------------------------------------------------------------------- 1 | pip>=23 2 | streamlit 3 | openai 4 | pandas 5 | requests 6 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11 2 | RUN pip install policyengine-core policyengine-uk policyengine-us ipython 3 | -------------------------------------------------------------------------------- /gcp/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11 2 | ENV VIRTUAL_ENV /env 3 | ENV PATH /env/bin:$PATH 4 | RUN apt-get update && apt-get install -y build-essential checkinstall 5 | RUN python3.11 -m pip install --upgrade pip --trusted-host pypi.python.org --trusted-host pypi.org --trusted-host files.pythonhosted.org 6 | RUN apt-get update && apt-get install -y redis-server 7 | RUN pip install git+https://github.com/policyengine/policyengine-api 8 | -------------------------------------------------------------------------------- /gcp/README.md: -------------------------------------------------------------------------------- 1 | # Docker guidance 2 | 3 | The deployment actions build Docker images and deploy them to Google App Engine.
The docker images themselves are based off a starter image (to save each API docker image having to spend 5 minutes installing the same dependencies). The starter image is the `Dockerfile` in this directory. 4 | 5 | To update the starter image: 6 | * `python setup.py sdist` to build the python package 7 | * `twine upload dist/*` to upload the package to pypi as `policyengine-api` 8 | * `cd gcp` 9 | * `docker build .` 10 | * `docker images` to get the image id (the most recent one should be the one you just built) 11 | * `docker tag policyengine/policyengine-api` 12 | * `docker push policyengine/policyengine-api` 13 | -------------------------------------------------------------------------------- /gcp/bump_country_package.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import time 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | # Usage: python bump_country_package.py --country=policyengine-uk --version=1.0.1 10 | # Country can be: policyengine-uk, policyengine-us 11 | # Version should be in the format: 1.0.1 12 | parser.add_argument( 13 | "--country", 14 | type=str, 15 | required=True, 16 | help="Country package to bump", 17 | choices=["policyengine-uk", "policyengine-us", "policyengine-canada"], 18 | ) 19 | parser.add_argument( 20 | "--version", type=str, required=True, help="Version to bump to" 21 | ) 22 | args = parser.parse_args() 23 | country = args.country 24 | version = args.version 25 | bump_country_package(country, version) 26 | 27 | 28 | def bump_country_package(country, version): 29 | time.sleep(60 * 5) 30 | # Update the version in the country package's setup.py 31 | setup_py_path = f"setup.py" 32 | with open(setup_py_path, "r") as f: 33 | setup_py = f.read() 34 | # Find where it says {country}=={old version} and replace it with {country}=={new version} 35 | country_package_name = country.replace("-", "_") 36 | package_version_regex = 
rf"{country_package_name}==(\d+\.\d+\.\d+)" 37 | match = re.search(package_version_regex, setup_py) 38 | 39 | # If the line was found, replace it with the new package version 40 | if match: 41 | new_line = f"{country_package_name}=={version}" 42 | setup_py = setup_py.replace(match.group(0), new_line) 43 | # Write setup_py to setup.py 44 | with open(setup_py_path, "w") as f: 45 | f.write(setup_py) 46 | 47 | country_package_full_name = country.replace( 48 | "policyengine", "PolicyEngine" 49 | ).replace("-", " ") 50 | country_id = country.replace("policyengine-", "") 51 | country_package_full_name = country_package_full_name.replace( 52 | country_id, country_id.upper() 53 | ) 54 | 55 | changelog_yaml = f"""- bump: patch\n changes:\n changed:\n - Update {country_package_full_name} to {version}\n""" 56 | # Write changelog_yaml to changelog.yaml 57 | with open("changelog_entry.yaml", "w") as f: 58 | f.write(changelog_yaml) 59 | 60 | # Commit the change and push to a branch 61 | branch_name = f"bump-{country}-to-{version}" 62 | # Checkout a new branch locally, add all the files, commit, and push using the GitHub CLI only 63 | 64 | # First, create a new branch off master 65 | os.system(f"git checkout -b {branch_name}") 66 | # Add all the files 67 | os.system("git add -A") 68 | # Commit the change 69 | os.system(f"git config --global user.name 'PolicyEngine[bot]'") 70 | os.system(f"git config --global user.email 'hello@policyengine.org'") 71 | os.system(f'git commit -m "Bump {country_package_full_name} to {version}"') 72 | # Push the branch to GitHub, using the personal access token stored in GITHUB_TOKEN 73 | os.system(f"git push -u origin {branch_name} -f") 74 | # Create a pull request using the GitHub CLI 75 | os.system( 76 | f"gh pr create --title 'Update {country_package_full_name} to {version}' --body 'Update {country_package_full_name} to {version}' --base master --head {branch_name}" 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | 
-------------------------------------------------------------------------------- /gcp/dispatch.yaml: -------------------------------------------------------------------------------- 1 | dispatch: 2 | - url: api.policyengine.org/* 3 | service: default 4 | -------------------------------------------------------------------------------- /gcp/export.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | GAE = os.environ["GOOGLE_APPLICATION_CREDENTIALS"] 5 | # If it's a filepath, read the file. Otherwise, it'll be JSON 6 | try: 7 | Path(GAE).resolve(strict=True) 8 | with open(GAE, "r") as f: 9 | GAE = f.read() 10 | except Exception as e: 11 | pass 12 | DB_PD = os.environ["POLICYENGINE_DB_PASSWORD"] 13 | GITHUB_MICRODATA_TOKEN = os.environ["POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN"] 14 | ANTHROPIC_API_KEY = os.environ["ANTHROPIC_API_KEY"] 15 | OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] 16 | HUGGING_FACE_TOKEN = os.environ["HUGGING_FACE_TOKEN"] 17 | 18 | # Export GAE to .gac.json and DB_PD to .dbpw in the current directory 19 | 20 | with open(".gac.json", "w") as f: 21 | f.write(GAE) 22 | 23 | with open(".dbpw", "w") as f: 24 | f.write(DB_PD) 25 | 26 | # in gcp/policyengine_api/Dockerfile, replace .github_microdata_token with the contents of the file 27 | for dockerfile_location in [ 28 | "gcp/policyengine_api/Dockerfile", 29 | ]: 30 | with open(dockerfile_location, "r") as f: 31 | dockerfile = f.read() 32 | dockerfile = dockerfile.replace( 33 | ".github_microdata_token", GITHUB_MICRODATA_TOKEN 34 | ) 35 | dockerfile = dockerfile.replace( 36 | ".anthropic_api_key", ANTHROPIC_API_KEY 37 | ) 38 | dockerfile = dockerfile.replace(".openai_api_key", OPENAI_API_KEY) 39 | dockerfile = dockerfile.replace( 40 | ".hugging_face_token", HUGGING_FACE_TOKEN 41 | ) 42 | 43 | with open(dockerfile_location, "w") as f: 44 | f.write(dockerfile) 45 |
-------------------------------------------------------------------------------- /gcp/policyengine_api/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM policyengine/policyengine-api:latest 2 | 3 | ENV GOOGLE_APPLICATION_CREDENTIALS .gac.json 4 | ENV POLICYENGINE_DB_PASSWORD .dbpw 5 | ENV POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN .github_microdata_token 6 | ENV ANTHROPIC_API_KEY .anthropic_api_key 7 | ENV OPENAI_API_KEY .openai_api_key 8 | ENV HUGGING_FACE_TOKEN .hugging_face_token 9 | ENV CREDENTIALS_JSON_API_V2 .credentials_json_api_v2 10 | 11 | WORKDIR /app 12 | 13 | # Copy application 14 | ADD . /app 15 | 16 | # Make start.sh executable 17 | RUN chmod +x /app/start.sh 18 | 19 | RUN cd /app && make install && make test 20 | 21 | # Use full path to start.sh 22 | CMD ["/bin/sh", "/app/start.sh"] 23 | -------------------------------------------------------------------------------- /gcp/policyengine_api/app.yaml: -------------------------------------------------------------------------------- 1 | runtime: custom 2 | env: flex 3 | resources: 4 | cpu: 24 5 | memory_gb: 128 6 | disk_size_gb: 128 7 | automatic_scaling: 8 | min_num_instances: 1 9 | max_num_instances: 1 10 | cool_down_period_sec: 180 11 | cpu_utilization: 12 | target_utilization: 0.8 13 | liveness_check: 14 | path: "/liveness-check" 15 | check_interval_sec: 30 16 | timeout_sec: 30 17 | failure_threshold: 5 18 | success_threshold: 2 19 | runtime_config: 20 | operating_system: "ubuntu22" 21 | runtime_version: "22" 22 | readiness_check: 23 | path: "/readiness-check" 24 | app_start_timeout_sec: 400 25 | -------------------------------------------------------------------------------- /gcp/policyengine_api/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Environment variables 3 | PORT="${PORT:-8080}" 4 | WORKER_COUNT="${WORKER_COUNT:-3}" 5 | REDIS_PORT="${REDIS_PORT:-6379}" 6 | 7 | # Start the API 8 
| gunicorn -b :"$PORT" policyengine_api.api --timeout 300 --workers 5 & 9 | 10 | # Start Redis with configuration for multiple clients 11 | redis-server --protected-mode no \ 12 | --maxclients 10000 \ 13 | --timeout 0 & 14 | 15 | # Wait for Redis to be ready 16 | sleep 2 17 | 18 | # Start multiple workers using POSIX-compliant loop 19 | i=1 20 | while [ $i -le "$WORKER_COUNT" ] 21 | do 22 | echo "Starting worker $i..." 23 | python3 policyengine_api/worker.py & 24 | i=$((i + 1)) 25 | done 26 | 27 | # Keep the script running and handle shutdown gracefully 28 | trap "pkill -P $$; exit 1" INT TERM 29 | 30 | wait -------------------------------------------------------------------------------- /policyengine_api/ai_prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from .simulation_analysis_prompt import generate_simulation_analysis_prompt 2 | -------------------------------------------------------------------------------- /policyengine_api/ai_templates/simulation_analysis_template.py: -------------------------------------------------------------------------------- 1 | simulation_analysis_template = """ 2 | I'm using PolicyEngine, a free, open source tool to compute the impact of 3 | public policy. I'm writing up an economic analysis of a hypothetical tax-benefit 4 | policy reform. Please write the analysis for me using the details below, in 5 | their order. You should: 6 | 7 | - First explain each provision of the reform, noting that it's hypothetical and 8 | won't represents policy reforms for {time_period} and {region}. Explain how 9 | the parameters are changing from the baseline to the reform values using the given data. 10 | 11 | {enhanced_cps_template} 12 | 13 | - Round large numbers like: {currency}3.1 billion, {currency}300 million, 14 | {currency}106,000, {currency}1.50 (never {currency}1.5). 15 | 16 | - Round percentages to one decimal place. 
17 | 18 | - Avoid normative language like 'requires', 'should', 'must', and use quantitative statements 19 | over general adjectives and adverbs. If you don't know what something is, don't make it up. 20 | 21 | - Avoid speculating about the intent of the policy or inferring any motives; only describe the 22 | observable effects and impacts of the policy. Refrain from using subjective language or making 23 | assumptions about the recipients and their needs. 24 | 25 | - Use the active voice where possible; for example, write phrases where the reform is the subject, 26 | such as "the reform [or a description of the reform] reduces poverty by x%". 27 | 28 | - Use {dialect} English spelling and grammar. 29 | 30 | - Cite PolicyEngine {country_id_uppercase} v{selected_version} and the {data_source} microdata 31 | when describing policy impacts. 32 | 33 | - When describing poverty impacts, note that the poverty measure reported is {poverty_measure} 34 | 35 | - Don't use headers, but do use Markdown formatting. Use - for bullets, and include a newline after each bullet. 36 | 37 | - Include the following embeds inline, without a header so it flows. 38 | 39 | - Immediately after you describe the changes by decile, include the text: '{{{{distributionalImpact.incomeDecile.relative}}}}' 40 | 41 | - And after the poverty rate changes, include the text: '{{{{povertyImpact.regular.byAge}}}}' 42 | 43 | {poverty_rate_change_text} 44 | 45 | - And after the inequality changes, include the text: "{{{{inequalityImpact}}}}" 46 | 47 | - Make sure to accurately represent the changes observed in the data. 
48 | 49 | - This JSON snippet describes the default parameter values: {relevant_parameter_baseline_values} 50 | 51 | - This JSON snippet describes the baseline and reform policies being compared: {policy} 52 | 53 | - {policy_label} has the following impacts from the PolicyEngine microsimulation model: 54 | 55 | - This JSON snippet describes the relevant parameters with more details: {relevant_parameters} 56 | 57 | - This JSON describes the total budgetary impact, the change to tax revenues and benefit 58 | spending (ignore 'households' and 'baseline_net_income': {impact_budget}) 59 | 60 | - This JSON describes how common different outcomes were at each income decile: {impact_intra_decile} 61 | 62 | - This JSON describes the average and relative changes to income by each income decile: {impact_decile} 63 | 64 | - This JSON describes the baseline and reform poverty rates by age group 65 | (describe the relative changes): {impact_poverty} 66 | 67 | - This JSON describes the baseline and reform deep poverty rates by age group 68 | (describe the relative changes): {impact_deep_poverty} 69 | 70 | - This JSON describes the baseline and reform poverty and deep poverty rates 71 | by gender (briefly describe the relative changes): {impact_poverty_by_gender} 72 | 73 | {poverty_by_race_text} 74 | 75 | - This JSON describes three inequality metrics in the baseline and reform, the Gini 76 | coefficient of income inequality, the share of income held by the top 10% of households 77 | and the share held by the top 1% (describe the relative changes): {impact_inequality} 78 | 79 | {audience_description} 80 | """ 81 | 82 | audience_descriptions = { 83 | "ELI5": "Write this for a layperson who doesn't know much about economics or policy. 
Explain fundamental concepts like taxes, poverty rates, and inequality as needed.", 84 | "Normal": "Write this for a policy analyst who knows a bit about economics and policy.", 85 | "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.", 86 | } 87 | -------------------------------------------------------------------------------- /policyengine_api/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the main Flask app for the PolicyEngine API. 3 | """ 4 | 5 | from pathlib import Path 6 | from flask_cors import CORS 7 | import flask 8 | import yaml 9 | from flask_caching import Cache 10 | from policyengine_api.utils import make_cache_key 11 | from .constants import VERSION 12 | 13 | # from werkzeug.middleware.profiler import ProfilerMiddleware 14 | 15 | # Endpoints 16 | from policyengine_api.routes.error_routes import error_bp 17 | from policyengine_api.routes.economy_routes import economy_bp 18 | from policyengine_api.routes.household_routes import household_bp 19 | from policyengine_api.routes.simulation_analysis_routes import ( 20 | simulation_analysis_bp, 21 | ) 22 | from policyengine_api.routes.policy_routes import policy_bp 23 | from policyengine_api.routes.tracer_analysis_routes import tracer_analysis_bp 24 | from policyengine_api.routes.metadata_routes import metadata_bp 25 | from policyengine_api.routes.user_profile_routes import user_profile_bp 26 | from policyengine_api.routes.ai_prompt_routes import ai_prompt_bp 27 | 28 | from .endpoints import ( 29 | get_home, 30 | get_policy_search, 31 | get_household_under_policy, 32 | get_calculate, 33 | set_user_policy, 34 | get_user_policy, 35 | update_user_policy, 36 | get_simulations, 37 | ) 38 | 39 | print("Initialising API...") 40 | 41 | app = application = flask.Flask(__name__) 42 | 43 | app.config.from_mapping( 44 | { 45 | "CACHE_TYPE": 
"RedisCache", 46 | "CACHE_KEY_PREFIX": "policyengine", 47 | "CACHE_REDIS_HOST": "127.0.0.1", 48 | "CACHE_REDIS_PORT": 6379, 49 | "CACHE_DEFAULT_TIMEOUT": 300, 50 | } 51 | ) 52 | cache = Cache(app) 53 | 54 | CORS(app) 55 | 56 | app.register_blueprint(error_bp) 57 | 58 | app.route("/", methods=["GET"])(get_home) 59 | 60 | app.register_blueprint(metadata_bp) 61 | 62 | app.register_blueprint(household_bp) 63 | 64 | # Routes for getting and setting a "policy" record 65 | app.register_blueprint(policy_bp) 66 | 67 | app.route("//policies", methods=["GET"])(get_policy_search) 68 | 69 | app.route( 70 | "//household//policy/", 71 | methods=["GET"], 72 | )(get_household_under_policy) 73 | 74 | app.route("//calculate", methods=["POST"])( 75 | cache.cached(make_cache_key=make_cache_key)(get_calculate) 76 | ) 77 | 78 | app.route("//calculate-full", methods=["POST"])( 79 | cache.cached(make_cache_key=make_cache_key)( 80 | lambda *args, **kwargs: get_calculate( 81 | *args, **kwargs, add_missing=True 82 | ) 83 | ) 84 | ) 85 | 86 | # Routes for economy microsimulation 87 | app.register_blueprint(economy_bp) 88 | 89 | # Routes for AI analysis of economy microsim runs 90 | app.register_blueprint(simulation_analysis_bp) 91 | 92 | app.route("//user-policy", methods=["POST"])(set_user_policy) 93 | 94 | app.route("//user-policy", methods=["PUT"])(update_user_policy) 95 | 96 | app.route("//user-policy/", methods=["GET"])( 97 | get_user_policy 98 | ) 99 | 100 | app.register_blueprint(user_profile_bp) 101 | 102 | app.route("/simulations", methods=["GET"])(get_simulations) 103 | 104 | app.register_blueprint(tracer_analysis_bp) 105 | 106 | app.register_blueprint(ai_prompt_bp) 107 | 108 | 109 | @app.route("/liveness-check", methods=["GET"]) 110 | def liveness_check(): 111 | return flask.Response( 112 | "OK", status=200, headers={"Content-Type": "text/plain"} 113 | ) 114 | 115 | 116 | @app.route("/readiness-check", methods=["GET"]) 117 | def readiness_check(): 118 | return flask.Response( 119 | 
"OK", status=200, headers={"Content-Type": "text/plain"} 120 | ) 121 | 122 | 123 | # Add OpenAPI spec (__file__.parent / openapi_spec.yaml) 124 | 125 | with open(Path(__file__).parent / "openapi_spec.yaml", encoding="utf-8") as f: 126 | openapi_spec = yaml.safe_load(f) 127 | openapi_spec["info"]["version"] = VERSION 128 | 129 | 130 | @app.route("/specification", methods=["GET"]) 131 | def get_specification(): 132 | return flask.jsonify(openapi_spec) 133 | 134 | 135 | print("API initialised.") 136 | -------------------------------------------------------------------------------- /policyengine_api/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pkg_resources 3 | 4 | REPO = Path(__file__).parents[1] 5 | GET = "GET" 6 | POST = "POST" 7 | UPDATE = "UPDATE" 8 | LIST = "LIST" 9 | VERSION = "3.23.2" 10 | COUNTRIES = ("uk", "us", "ca", "ng", "il") 11 | COUNTRY_PACKAGE_NAMES = ( 12 | "policyengine_uk", 13 | "policyengine_us", 14 | "policyengine_canada", 15 | "policyengine_ng", 16 | "policyengine_il", 17 | ) 18 | try: 19 | COUNTRY_PACKAGE_VERSIONS = { 20 | country: pkg_resources.get_distribution(package_name).version 21 | for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES) 22 | } 23 | except: 24 | COUNTRY_PACKAGE_VERSIONS = {country: "0.0.0" for country in COUNTRIES} 25 | __version__ = VERSION 26 | -------------------------------------------------------------------------------- /policyengine_api/data/README.md: -------------------------------------------------------------------------------- 1 | # `policyengine_api.data` 2 | 3 | This module contains the code managing the database used by the main and compute-intensive APIs. 4 | 5 | ## Tables 6 | 7 | ### `household` 8 | 9 | Stores an individual household's data. 
10 | 11 | | Column | Type | Description | 12 | | --- | --- | --- | 13 | | `id` | `int` | Unique identifier for the household | 14 | | `label` | `string` | Display name for the household | 15 | | `version` | `string` | Version of the PolicyEngine API when the household was last updated | 16 | | `household_json` | `json` | JSON representation of the household | 17 | 18 | Initialisation SQL: 19 | 20 | ```sql 21 | CREATE TABLE IF NOT EXISTS household ( 22 | id SERIAL PRIMARY KEY, 23 | label VARCHAR(255) NOT NULL, 24 | version VARCHAR(255) NOT NULL, 25 | household_json JSONB NOT NULL 26 | ); 27 | ``` 28 | 29 | ### `policy` 30 | 31 | Stores a policy - a set of time-period-dated parameter overrides from current law. 32 | 33 | | Column | Type | Description | 34 | | --- | --- | --- | 35 | | `id` | `int` | Unique identifier for the policy | 36 | | `label` | `string` | Display name for the policy | 37 | | `version` | `string` | Version of the PolicyEngine API when the policy was last updated | 38 | | `policy_json` | `json` | JSON representation of the policy | 39 | 40 | Initialisation SQL: 41 | 42 | ```sql 43 | CREATE TABLE IF NOT EXISTS policy ( 44 | id SERIAL PRIMARY KEY, 45 | label VARCHAR(255) NOT NULL, 46 | version VARCHAR(255) NOT NULL, 47 | policy_json JSONB NOT NULL 48 | ); 49 | ``` 50 | 51 | ### `economy` 52 | 53 | Stores the high-level outputs of a microsimulation run, under a given policy. 
54 | 55 | | Column | Type | Description | 56 | | --- | --- | --- | 57 | | `household_id` | `int` | Unique identifier for the household | 58 | | `policy_id` | `int` | Unique identifier for the policy | 59 | | `version` | `string` | Version of the PolicyEngine API when the economy was last updated | 60 | | `economy_json` | `json` | JSON representation of the economy | 61 | 62 | Initialisation SQL: 63 | 64 | ```sql 65 | CREATE TABLE IF NOT EXISTS economy ( 66 | household_id INT NOT NULL, 67 | policy_id INT NOT NULL, 68 | version VARCHAR(255) NOT NULL, 69 | economy_json JSONB NOT NULL, 70 | PRIMARY KEY (household_id, policy_id) 71 | ); 72 | ``` -------------------------------------------------------------------------------- /policyengine_api/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import PolicyEngineDatabase, database, local_database 2 | -------------------------------------------------------------------------------- /policyengine_api/data/initialise.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS household ( 2 | id INTEGER PRIMARY KEY AUTO_INCREMENT, 3 | country_id VARCHAR(3) NOT NULL, 4 | label VARCHAR(255), 5 | api_version VARCHAR(255) NOT NULL, 6 | household_json JSON NOT NULL, 7 | household_hash VARCHAR(255) NOT NULL 8 | ); 9 | 10 | CREATE TABLE IF NOT EXISTS computed_household ( 11 | household_id INT NOT NULL, 12 | policy_id INT NOT NULL, 13 | country_id VARCHAR(3) NOT NULL, 14 | api_version VARCHAR(10) NOT NULL, 15 | computed_household_json JSON NOT NULL, 16 | status VARCHAR(32), 17 | PRIMARY KEY (household_id, policy_id, country_id) 18 | ); 19 | 20 | CREATE TABLE IF NOT EXISTS policy ( 21 | id INTEGER AUTO_INCREMENT, 22 | country_id VARCHAR(3) NOT NULL, 23 | label VARCHAR(255), 24 | api_version VARCHAR(10) NOT NULL, 25 | policy_json JSON NOT NULL, 26 | policy_hash VARCHAR(255) NOT NULL, 27 | PRIMARY KEY (id, country_id, 
policy_hash) 28 | ); 29 | 30 | CREATE TABLE IF NOT EXISTS economy ( 31 | economy_id INTEGER PRIMARY KEY AUTO_INCREMENT, 32 | policy_id INT NOT NULL, 33 | country_id VARCHAR(3) NOT NULL, 34 | region VARCHAR(32), 35 | time_period VARCHAR(32), 36 | options_json JSON NOT NULL, 37 | options_hash VARCHAR(255) NOT NULL, 38 | api_version VARCHAR(10) NOT NULL, 39 | economy_json JSON, 40 | status VARCHAR(32) NOT NULL, 41 | message VARCHAR(255) 42 | ); 43 | 44 | CREATE TABLE IF NOT EXISTS reform_impact ( 45 | reform_impact_id INTEGER PRIMARY KEY AUTO_INCREMENT, 46 | baseline_policy_id INT NOT NULL, 47 | reform_policy_id INT NOT NULL, 48 | country_id VARCHAR(3) NOT NULL, 49 | region VARCHAR(32) NOT NULL, 50 | dataset VARCHAR(255) NOT NULL, 51 | time_period VARCHAR(32) NOT NULL, 52 | options_json JSON, 53 | options_hash VARCHAR(255), 54 | api_version VARCHAR(10) NOT NULL, 55 | reform_impact_json JSON NOT NULL, 56 | status VARCHAR(32) NOT NULL, 57 | message VARCHAR(255), 58 | start_time DATETIME 59 | ); 60 | 61 | CREATE TABLE IF NOT EXISTS analysis ( 62 | prompt_id INTEGER PRIMARY KEY AUTO_INCREMENT, 63 | prompt LONGTEXT NOT NULL, 64 | analysis LONGTEXT, 65 | status VARCHAR(32) NOT NULL 66 | ) 67 | 68 | -- The dataset row below was added while the table is in prod; 69 | -- we must allow NULL values for this column 70 | CREATE TABLE IF NOT EXISTS user_policies ( 71 | id INTEGER PRIMARY KEY AUTO_INCREMENT, 72 | country_id VARCHAR(3) NOT NULL, 73 | reform_id INTEGER NOT NULL, 74 | reform_label VARCHAR(255), 75 | baseline_id INTEGER NOT NULL, 76 | baseline_label VARCHAR(255), 77 | user_id VARCHAR(255) NOT NULL, 78 | year VARCHAR(32) NOT NULL, 79 | geography VARCHAR(255) NOT NULL, 80 | dataset VARCHAR(255), 81 | number_of_provisions INTEGER NOT NULL, 82 | api_version VARCHAR(32) NOT NULL, 83 | added_date BIGINT NOT NULL, 84 | updated_date BIGINT NOT NULL, 85 | budgetary_impact VARCHAR(255), 86 | type VARCHAR(255) 87 | ); 88 | 89 | CREATE TABLE IF NOT EXISTS user_profiles ( 90 | 
user_id INTEGER PRIMARY KEY AUTO_INCREMENT, 91 | auth0_id VARCHAR(255) NOT NULL UNIQUE, 92 | username VARCHAR(255) UNIQUE, 93 | primary_country VARCHAR(3) NOT NULL, 94 | user_since BIGINT NOT NULL 95 | ); 96 | 97 | CREATE TABLE IF NOT EXISTS tracers ( 98 | id INTEGER PRIMARY KEY AUTO_INCREMENT, 99 | household_id INT NOT NULL, 100 | policy_id INT NOT NULL, 101 | country_id VARCHAR(3) NOT NULL, 102 | api_version VARCHAR(10) NOT NULL, 103 | tracer_output JSON NOT NULL 104 | ); -------------------------------------------------------------------------------- /policyengine_api/data/initialise_local.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS household; 2 | DROP TABLE IF EXISTS computed_household; 3 | DROP TABLE IF EXISTS policy; 4 | DROP TABLE IF EXISTS economy; 5 | DROP TABLE IF EXISTS reform_impact; 6 | DROP TABLE IF EXISTS analysis; 7 | DROP TABLE IF EXISTS user_policies; 8 | DROP TABLE IF EXISTS tracers; 9 | 10 | CREATE TABLE IF NOT EXISTS household ( 11 | id INTEGER PRIMARY KEY, 12 | country_id VARCHAR(3) NOT NULL, 13 | label VARCHAR(255), 14 | api_version VARCHAR(255) NOT NULL, 15 | household_json JSONB NOT NULL, 16 | household_hash VARCHAR(255) NOT NULL 17 | ); 18 | 19 | CREATE TABLE IF NOT EXISTS computed_household ( 20 | household_id INT NOT NULL, 21 | policy_id INT NOT NULL, 22 | country_id VARCHAR(3) NOT NULL, 23 | api_version VARCHAR(10) NOT NULL, 24 | computed_household_json JSONB NOT NULL, 25 | status VARCHAR(32), 26 | PRIMARY KEY (household_id, policy_id, country_id) 27 | ); 28 | 29 | CREATE TABLE IF NOT EXISTS policy ( 30 | id INTEGER PRIMARY KEY, 31 | country_id VARCHAR(3) NOT NULL, 32 | label VARCHAR(255), 33 | api_version VARCHAR(10) NOT NULL, 34 | policy_json JSONB NOT NULL, 35 | policy_hash VARCHAR(255) NOT NULL 36 | ); 37 | 38 | CREATE TABLE IF NOT EXISTS economy ( 39 | economy_id INTEGER PRIMARY KEY, 40 | policy_id INT NOT NULL, 41 | country_id VARCHAR(3) NOT NULL, 42 | region 
VARCHAR(32), 43 | time_period VARCHAR(32), 44 | options_json JSON NOT NULL, 45 | options_hash VARCHAR(255) NOT NULL, 46 | api_version VARCHAR(10) NOT NULL, 47 | economy_json JSON, 48 | status VARCHAR(32) NOT NULL, 49 | message VARCHAR(255) 50 | ); 51 | 52 | CREATE TABLE IF NOT EXISTS reform_impact ( 53 | reform_impact_id INTEGER PRIMARY KEY, 54 | baseline_policy_id INT NOT NULL, 55 | reform_policy_id INT NOT NULL, 56 | country_id VARCHAR(3) NOT NULL, 57 | region VARCHAR(32) NOT NULL, 58 | dataset VARCHAR(255) NOT NULL, 59 | time_period VARCHAR(32) NOT NULL, 60 | options_json JSON NOT NULL, 61 | options_hash VARCHAR(255) NOT NULL, 62 | api_version VARCHAR(10) NOT NULL, 63 | reform_impact_json JSON NOT NULL, 64 | status VARCHAR(32) NOT NULL, 65 | message VARCHAR(255), 66 | start_time DATETIME NOT NULL, 67 | end_time DATETIME 68 | ); 69 | 70 | CREATE TABLE IF NOT EXISTS analysis ( 71 | prompt_id INTEGER PRIMARY KEY, 72 | prompt LONGTEXT NOT NULL, 73 | analysis LONGTEXT, 74 | status VARCHAR(32) NOT NULL 75 | ); 76 | 77 | -- The dataset row below was added while the table is in prod; 78 | -- we must allow NULL values for this column 79 | CREATE TABLE IF NOT EXISTS user_policies ( 80 | id INTEGER PRIMARY KEY, 81 | country_id VARCHAR(3) NOT NULL, 82 | reform_id INTEGER NOT NULL, 83 | reform_label VARCHAR(255), 84 | baseline_id INTEGER NOT NULL, 85 | baseline_label VARCHAR(255), 86 | user_id VARCHAR(255) NOT NULL, 87 | year VARCHAR(32) NOT NULL, 88 | geography VARCHAR(255) NOT NULL, 89 | dataset VARCHAR(255), 90 | number_of_provisions INTEGER NOT NULL, 91 | api_version VARCHAR(32) NOT NULL, 92 | added_date BIGINT NOT NULL, 93 | updated_date BIGINT NOT NULL, 94 | budgetary_impact VARCHAR(255), 95 | type VARCHAR(255) 96 | ); 97 | 98 | CREATE TABLE IF NOT EXISTS user_profiles ( 99 | user_id INTEGER PRIMARY KEY, 100 | auth0_id VARCHAR(255) NOT NULL UNIQUE, 101 | username VARCHAR(255) UNIQUE, 102 | primary_country VARCHAR(3) NOT NULL, 103 | user_since BIGINT NOT NULL 104 | ); 
105 | 106 | CREATE TABLE IF NOT EXISTS tracers ( 107 | id INTEGER PRIMARY KEY, 108 | household_id INT NOT NULL, 109 | policy_id INT NOT NULL, 110 | country_id VARCHAR(3) NOT NULL, 111 | api_version VARCHAR(10) NOT NULL, 112 | tracer_output JSON NOT NULL 113 | ); -------------------------------------------------------------------------------- /policyengine_api/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | from .home import get_home 2 | from .household import ( 3 | get_household_under_policy, 4 | get_calculate, 5 | ) 6 | from .policy import ( 7 | get_policy_search, 8 | set_user_policy, 9 | get_user_policy, 10 | update_user_policy, 11 | ) 12 | 13 | from .simulation import get_simulations 14 | -------------------------------------------------------------------------------- /policyengine_api/endpoints/economy/reform_impact.py: -------------------------------------------------------------------------------- 1 | from policyengine_api.data import local_database 2 | 3 | 4 | def set_comment_on_job( 5 | comment: str, 6 | country_id, 7 | policy_id, 8 | baseline_policy_id, 9 | region, 10 | dataset, 11 | time_period, 12 | options_hash, 13 | ): 14 | query = ( 15 | "UPDATE reform_impact SET message = ? WHERE country_id = ? AND " 16 | "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " 17 | "time_period = ? AND options_hash = ? AND dataset = ?" 18 | ) 19 | 20 | local_database.query( 21 | query, 22 | ( 23 | comment, 24 | country_id, 25 | policy_id, 26 | baseline_policy_id, 27 | region, 28 | time_period, 29 | options_hash, 30 | dataset, 31 | ), 32 | ) 33 | -------------------------------------------------------------------------------- /policyengine_api/endpoints/home.py: -------------------------------------------------------------------------------- 1 | def get_home() -> str: 2 | """Get the home page of the PolicyEngine API. 3 | 4 | Returns: 5 | str: The home page. 6 | """ 7 | return f"

PolicyEngine households API

Use this API to compute the impact of public policy on individual households.

" 8 | -------------------------------------------------------------------------------- /policyengine_api/endpoints/simulation.py: -------------------------------------------------------------------------------- 1 | from policyengine_api.data import local_database 2 | 3 | """ 4 | 5 | CREATE TABLE IF NOT EXISTS reform_impact ( 6 | reform_impact_id INTEGER PRIMARY KEY AUTO_INCREMENT, 7 | baseline_policy_id INT NOT NULL, 8 | reform_policy_id INT NOT NULL, 9 | country_id VARCHAR(3) NOT NULL, 10 | region VARCHAR(32) NOT NULL, 11 | time_period VARCHAR(32) NOT NULL, 12 | options_json JSON, 13 | options_hash VARCHAR(255), 14 | api_version VARCHAR(10) NOT NULL, 15 | reform_impact_json JSON NOT NULL, 16 | status VARCHAR(32) NOT NULL, 17 | message VARCHAR(255), 18 | start_time DATETIME 19 | ); 20 | 21 | """ 22 | 23 | 24 | def get_simulations( 25 | max_results: int = 100, 26 | ): 27 | # Get the last N simulations ordered by start time 28 | 29 | desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" 30 | 31 | result = local_database.query( 32 | f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", 33 | ).fetchall() 34 | 35 | # Format into [{}] 36 | 37 | return {"result": [dict(r) for r in result]} 38 | -------------------------------------------------------------------------------- /policyengine_api/gcp_logging.py: -------------------------------------------------------------------------------- 1 | from google.cloud.logging import Client 2 | 3 | logger = Client().logger("policyengine-api") 4 | -------------------------------------------------------------------------------- /policyengine_api/jobs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_job import BaseJob 2 | from .calculate_economy_simulation_job import CalculateEconomySimulationJob 3 | -------------------------------------------------------------------------------- /policyengine_api/jobs/base_job.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import datetime 3 | from enum import Enum 4 | 5 | 6 | class JobStatus(Enum): 7 | PENDING = "pending" 8 | RUNNING = "running" 9 | COMPLETED = "completed" 10 | FAILED = "failed" 11 | 12 | 13 | class BaseJob(ABC): 14 | def __init__(self): 15 | self.started_at = None 16 | self.completed_at = None 17 | self.status = JobStatus.PENDING 18 | self.result = None 19 | self.error = None 20 | 21 | # Individual jobs should implement this method 22 | @abstractmethod 23 | def run(self, *args, **kwargs): 24 | pass 25 | 26 | def execute(self, *args, **kwargs): 27 | try: 28 | self.start_time = datetime.datetime.now(datetime.timezone.utc) 29 | self.status = JobStatus.RUNNING 30 | self.result = self.run(*args, **kwargs) 31 | self.status = JobStatus.COMPLETED 32 | except Exception as e: 33 | self.status = JobStatus.FAILED 34 | self.error = str(e) 35 | raise 36 | finally: 37 | self.end_time = datetime.datetime.now(datetime.timezone.utc) 38 | -------------------------------------------------------------------------------- /policyengine_api/jobs/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .compute_general_economy import compute_general_economy 2 | -------------------------------------------------------------------------------- /policyengine_api/routes/ai_prompt_routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, Response, request 2 | from copy import deepcopy 3 | from policyengine_api.services.ai_prompt_service import AIPromptService 4 | from policyengine_api.utils.payload_validators import validate_country 5 | from policyengine_api.utils.payload_validators.ai import ( 6 | validate_sim_analysis_payload, 7 | ) 8 | from werkzeug.exceptions import NotFound, BadRequest 9 | import json 10 | 11 | 12 | ai_prompt_bp = 
Blueprint("ai_prompt", __name__) 13 | ai_prompt_service = AIPromptService() 14 | 15 | 16 | @validate_country 17 | @ai_prompt_bp.route( 18 | "//ai-prompts/", 19 | methods=["POST"], 20 | ) 21 | def generate_ai_prompt(country_id, prompt_name: str) -> Response: 22 | """ 23 | Get an AI prompt with a given name, filled with the given data. 24 | """ 25 | print(f"Got request for AI prompt {prompt_name}") 26 | 27 | payload = request.json 28 | 29 | is_payload_valid, message = validate_sim_analysis_payload(payload) 30 | if not is_payload_valid: 31 | raise BadRequest(f"Invalid JSON data; details: {message}") 32 | 33 | input_data = { 34 | **deepcopy(payload), 35 | "country_id": country_id, 36 | } 37 | 38 | prompt: str | None = ai_prompt_service.get_prompt( 39 | name=prompt_name, input_data=input_data 40 | ) 41 | if prompt is None: 42 | raise NotFound(f"Prompt with name {prompt_name} not found.") 43 | 44 | return Response( 45 | json.dumps( 46 | { 47 | "status": "ok", 48 | "message": None, 49 | "result": prompt, 50 | } 51 | ), 52 | status=200, 53 | mimetype="application/json", 54 | ) 55 | -------------------------------------------------------------------------------- /policyengine_api/routes/economy_routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint 2 | from policyengine_api.services.economy_service import EconomyService 3 | from policyengine_api.utils import get_current_law_policy_id 4 | from policyengine_api.utils.payload_validators import validate_country 5 | from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS 6 | from flask import request 7 | import json 8 | 9 | economy_bp = Blueprint("economy", __name__) 10 | economy_service = EconomyService() 11 | 12 | 13 | @validate_country 14 | @economy_bp.route( 15 | "//economy//over/", 16 | methods=["GET"], 17 | ) 18 | def get_economic_impact( 19 | country_id: str, policy_id: int, baseline_policy_id: int 20 | ): 21 | 22 | policy_id = int(policy_id or 
get_current_law_policy_id(country_id)) 23 | baseline_policy_id = int( 24 | baseline_policy_id or get_current_law_policy_id(country_id) 25 | ) 26 | 27 | # Pop items from query params 28 | query_parameters = request.args 29 | options = dict(query_parameters) 30 | options = json.loads(json.dumps(options)) 31 | region = options.pop("region") 32 | dataset = options.pop("dataset", "default") 33 | time_period = options.pop("time_period") 34 | api_version = options.pop( 35 | "version", COUNTRY_PACKAGE_VERSIONS.get(country_id) 36 | ) 37 | 38 | result = economy_service.get_economic_impact( 39 | country_id, 40 | policy_id, 41 | baseline_policy_id, 42 | region, 43 | dataset, 44 | time_period, 45 | options, 46 | api_version, 47 | ) 48 | return result 49 | -------------------------------------------------------------------------------- /policyengine_api/routes/error_routes.py: -------------------------------------------------------------------------------- 1 | import json 2 | from flask import Response, Blueprint 3 | from werkzeug.exceptions import ( 4 | HTTPException, 5 | ) 6 | 7 | error_bp = Blueprint("error", __name__) 8 | 9 | 10 | @error_bp.app_errorhandler(404) 11 | def response_404(error) -> Response: 12 | """Specific handler for 404 Not Found errors""" 13 | return make_error_response(error, 404) 14 | 15 | 16 | @error_bp.app_errorhandler(400) 17 | def response_400(error) -> Response: 18 | """Specific handler for 400 Bad Request errors""" 19 | return make_error_response(error, 400) 20 | 21 | 22 | @error_bp.app_errorhandler(401) 23 | def response_401(error) -> Response: 24 | """Specific handler for 401 Unauthorized errors""" 25 | return make_error_response(error, 401) 26 | 27 | 28 | @error_bp.app_errorhandler(403) 29 | def response_403(error) -> Response: 30 | """Specific handler for 403 Forbidden errors""" 31 | return make_error_response(error, 403) 32 | 33 | 34 | @error_bp.app_errorhandler(500) 35 | def response_500(error) -> Response: 36 | """Specific handler for 500 
Internal Server errors""" 37 | return make_error_response(error, 500) 38 | 39 | 40 | @error_bp.app_errorhandler(HTTPException) 41 | def response_http_exception(error: HTTPException) -> Response: 42 | """Generic handler for HTTPException; should be raised if no specific handler is found""" 43 | return make_error_response(str(error), error.code) 44 | 45 | 46 | @error_bp.app_errorhandler(Exception) 47 | def response_generic_error(error: Exception) -> Response: 48 | """Handler for any unhandled exceptions""" 49 | return make_error_response(str(error), 500) 50 | 51 | 52 | def make_error_response( 53 | error, 54 | status_code: int, 55 | ) -> Response: 56 | """Create a generic error response""" 57 | return Response( 58 | json.dumps( 59 | { 60 | "status": "error", 61 | "message": str(error), 62 | "result": None, 63 | } 64 | ), 65 | status_code, 66 | mimetype="application/json", 67 | ) 68 | -------------------------------------------------------------------------------- /policyengine_api/routes/household_routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, Response, request 2 | from werkzeug.exceptions import NotFound, BadRequest 3 | import json 4 | 5 | from policyengine_api.services.household_service import HouseholdService 6 | from policyengine_api.utils.payload_validators import ( 7 | validate_household_payload, 8 | validate_country, 9 | ) 10 | 11 | 12 | household_bp = Blueprint("household", __name__) 13 | household_service = HouseholdService() 14 | 15 | 16 | @household_bp.route( 17 | "//household/", methods=["GET"] 18 | ) 19 | @validate_country 20 | def get_household(country_id: str, household_id: int) -> Response: 21 | """ 22 | Get a household's input data with a given ID. 23 | 24 | Args: 25 | country_id (str): The country ID. 26 | household_id (int): The household ID. 
27 | """ 28 | print(f"Got request for household {household_id} in country {country_id}") 29 | 30 | household: dict | None = household_service.get_household( 31 | country_id, household_id 32 | ) 33 | if household is None: 34 | raise NotFound(f"Household #{household_id} not found.") 35 | else: 36 | return Response( 37 | json.dumps( 38 | { 39 | "status": "ok", 40 | "message": None, 41 | "result": household, 42 | } 43 | ), 44 | status=200, 45 | mimetype="application/json", 46 | ) 47 | 48 | 49 | @household_bp.route("//household", methods=["POST"]) 50 | @validate_country 51 | def post_household(country_id: str) -> Response: 52 | """ 53 | Set a household's input data. 54 | 55 | Args: 56 | country_id (str): The country ID. 57 | """ 58 | 59 | # Validate payload 60 | payload = request.json 61 | is_payload_valid, message = validate_household_payload(payload) 62 | if not is_payload_valid: 63 | raise BadRequest(f"Unable to create new household; details: {message}") 64 | 65 | # The household label appears to be unimplemented at this time, 66 | # thus it should always be 'None' 67 | label: str | None = payload.get("label") 68 | household_json: dict = payload.get("data") 69 | 70 | household_id = household_service.create_household( 71 | country_id, household_json, label 72 | ) 73 | 74 | return Response( 75 | json.dumps( 76 | { 77 | "status": "ok", 78 | "message": None, 79 | "result": { 80 | "household_id": household_id, 81 | }, 82 | } 83 | ), 84 | status=201, 85 | mimetype="application/json", 86 | ) 87 | 88 | 89 | @household_bp.route( 90 | "//household/", methods=["PUT"] 91 | ) 92 | @validate_country 93 | def update_household(country_id: str, household_id: int) -> Response: 94 | """ 95 | Update a household's input data. 96 | 97 | Args: 98 | country_id (str): The country ID. 99 | household_id (int): The household ID. 
100 | """ 101 | 102 | # Validate payload 103 | payload = request.json 104 | is_payload_valid, message = validate_household_payload(payload) 105 | if not is_payload_valid: 106 | raise BadRequest( 107 | f"Unable to update household #{household_id}; details: {message}" 108 | ) 109 | 110 | # First, attempt to fetch the existing household 111 | label: str | None = payload.get("label") 112 | household_json: dict = payload.get("data") 113 | 114 | household: dict | None = household_service.get_household( 115 | country_id, household_id 116 | ) 117 | if household is None: 118 | raise NotFound(f"Household #{household_id} not found.") 119 | 120 | # Next, update the household 121 | updated_household: dict = household_service.update_household( 122 | country_id, household_id, household_json, label 123 | ) 124 | return Response( 125 | json.dumps( 126 | { 127 | "status": "ok", 128 | "message": None, 129 | "result": { 130 | "household_id": household_id, 131 | "household_json": updated_household["household_json"], 132 | }, 133 | } 134 | ), 135 | status=200, 136 | mimetype="application/json", 137 | ) 138 | -------------------------------------------------------------------------------- /policyengine_api/routes/metadata_routes.py: -------------------------------------------------------------------------------- 1 | import json 2 | from flask import Blueprint, Response 3 | 4 | from policyengine_api.utils.payload_validators import validate_country 5 | from policyengine_api.services.metadata_service import MetadataService 6 | 7 | metadata_bp = Blueprint("metadata", __name__) 8 | metadata_service = MetadataService() 9 | 10 | 11 | @metadata_bp.route("//metadata", methods=["GET"]) 12 | @validate_country 13 | def get_metadata(country_id: str) -> Response: 14 | """Get metadata for a country. 15 | 16 | Args: 17 | country_id (str): The country ID. 
18 | """ 19 | 20 | # Retrieve country metadata and add status and message to the response 21 | country_metadata = metadata_service.get_metadata(country_id) 22 | return Response( 23 | json.dumps( 24 | {"status": "ok", "message": None, "result": country_metadata} 25 | ), 26 | status=200, 27 | mimetype="application/json", 28 | ) 29 | -------------------------------------------------------------------------------- /policyengine_api/routes/policy_routes.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, Response, request 2 | import json 3 | 4 | from policyengine_api.services.policy_service import PolicyService 5 | from werkzeug.exceptions import NotFound, BadRequest 6 | from policyengine_api.utils.payload_validators import ( 7 | validate_country, 8 | validate_set_policy_payload, 9 | ) 10 | 11 | policy_bp = Blueprint("policy", __name__) 12 | policy_service = PolicyService() 13 | 14 | 15 | @policy_bp.route("//policy/", methods=["GET"]) 16 | @validate_country 17 | def get_policy(country_id: str, policy_id: int | str) -> Response: 18 | """ 19 | Get policy data for a given country and policy ID. 20 | 21 | Args: 22 | country_id (str) 23 | policy_id (int | str) 24 | 25 | Returns: 26 | Response: A Flask response object containing the 27 | policy data in JSON format 28 | """ 29 | 30 | # Specifically cast policy_id to an integer 31 | policy_id = int(policy_id) 32 | 33 | policy: dict | None = policy_service.get_policy(country_id, policy_id) 34 | 35 | if policy is None: 36 | raise NotFound(f"Policy #{policy_id} not found.") 37 | 38 | return Response( 39 | json.dumps({"status": "ok", "message": None, "result": policy}), 40 | status=200, 41 | ) 42 | 43 | 44 | @policy_bp.route("//policy", methods=["POST"]) 45 | @validate_country 46 | def set_policy(country_id: str) -> Response: 47 | """ 48 | Set policy data for given country and policy. If policy already exists, 49 | return existing policy and 200. 
@simulation_analysis_bp.route(
    "//simulation-analysis", methods=["POST"]
)
@validate_country
def execute_simulation_analysis(country_id):
    """Run (or replay) an AI analysis of an economy-wide simulation."""
    print("Got POST request for simulation analysis")

    # Validate the incoming JSON body before reading any fields
    payload = request.json
    valid, error_detail = validate_sim_analysis_payload(payload)
    if not valid:
        raise BadRequest(f"Invalid JSON data; details: {error_detail}")

    # Pass the analysis inputs through to the service in its
    # expected positional order
    analysis, analysis_type = simulation_analysis_service.execute_analysis(
        country_id,
        payload.get("currency"),
        payload.get("dataset"),
        payload.get("selected_version"),
        payload.get("time_period"),
        payload.get("impact"),
        payload.get("policy_label"),
        payload.get("policy"),
        payload.get("region"),
        payload.get("relevant_parameters"),
        payload.get("relevant_parameter_baseline_values"),
        payload.get("audience", ""),
    )

    # A previously-computed analysis is returned whole, in one response
    if analysis_type == "static":
        static_body = {"status": "ok", "result": analysis, "message": None}
        return Response(
            json.dumps(static_body),
            mimetype="application/json",
        )

    # Otherwise, stream the analysis as newline-delimited JSON
    streaming_response = Response(
        stream_with_context(analysis),
        status=200,
        mimetype="application/x-ndjson",
    )

    # Set header to prevent buffering on Google App Engine deployment
    # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering)
    streaming_response.headers["X-Accel-Buffering"] = "no"

    return streaming_response
@tracer_analysis_bp.route("//tracer-analysis", methods=["POST"])
@validate_country
def execute_tracer_analysis(country_id):
    """Run (or replay) an AI analysis of a household tracer output."""

    payload = request.json

    # Reject malformed payloads up front
    valid, error_detail = validate_tracer_analysis_payload(payload)
    if not valid:
        raise BadRequest(f"Invalid JSON data; details: {error_detail}")

    analysis, analysis_type = tracer_analysis_service.execute_analysis(
        country_id,
        payload.get("household_id"),
        payload.get("policy_id"),
        payload.get("variable"),
    )

    # A previously-stored analysis comes back whole, in one response
    if analysis_type == "static":
        static_body = {"status": "ok", "result": analysis, "message": None}
        return Response(
            json.dumps(static_body),
            status=200,
            mimetype="application/json",
        )

    # Otherwise, stream newline-delimited JSON as it is generated
    streaming_response = Response(
        stream_with_context(analysis),
        status=200,
        mimetype="application/x-ndjson",
    )

    # Set header to prevent buffering on Google App Engine deployment
    # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering)
    streaming_response.headers["X-Accel-Buffering"] = "no"

    return streaming_response
@user_profile_bp.route("//user-profile", methods=["POST"])
@validate_country
def set_user_profile(country_id: str) -> Response:
    """
    Creates a new user_profile.

    Args:
        country_id (str): The country ID, used as the profile's
            primary country.

    Raises:
        BadRequest: If the payload is missing or lacks a required field.
    """

    payload = request.json
    if payload is None:
        raise BadRequest("Payload missing from request")

    # Required fields: surface a 400 instead of letting an unhandled
    # KeyError from pop() bubble up as a 500
    for required_field in ("auth0_id", "user_since"):
        if required_field not in payload:
            raise BadRequest(f"Payload missing {required_field}")

    auth0_id = payload.pop("auth0_id")
    username = payload.pop("username", None)
    user_since = payload.pop("user_since")

    created, row = user_service.create_profile(
        primary_country=country_id,
        auth0_id=auth0_id,
        username=username,
        user_since=user_since,
    )

    response = dict(
        status="ok",
        message="Record created successfully" if created else "Record exists",
        result=dict(
            user_id=row["user_id"],
            primary_country=row["primary_country"],
            username=row["username"],
            user_since=row["user_since"],
        ),
    )
    # 201 only when a brand-new record was inserted
    return Response(
        json.dumps(response),
        status=201 if created else 200,
        mimetype="application/json",
    )
@user_profile_bp.route("//user-profile", methods=["PUT"])
@validate_country
def update_user_profile(country_id: str) -> Response:
    """
    Update any part of a user_profile, given a user_id,
    except the auth0_id value; any attempt to edit this
    will assume malicious intent and 403

    Raises:
        BadRequest: If the payload is missing or has no user_id.
        NotFound: If no profile exists for the given user_id.
    """
    # NOTE(review): the 403-on-auth0_id behavior described above is not
    # implemented here — an "auth0_id" key in the payload is simply
    # ignored; confirm whether enforcement happens elsewhere.

    payload = request.json

    if payload is None:
        raise BadRequest("No user data provided in request")

    # TODO: we should validate the payload
    # to ensure type safety https://github.com/PolicyEngine/policyengine-api/issues/2054
    # Bug fix: default to None so a missing user_id reaches the explicit
    # BadRequest below instead of raising an unhandled KeyError (500);
    # previously the `if user_id is None` check was unreachable.
    user_id = payload.pop("user_id", None)
    username = payload.pop("username", None)
    primary_country = payload.pop("primary_country", None)
    user_since = payload.pop("user_since", None)

    if user_id is None:
        raise BadRequest("Payload must include user_id")

    updated = user_service.update_profile(
        user_id=user_id,
        primary_country=primary_country,
        username=username,
        user_since=user_since,
    )

    if not updated:
        raise NotFound("No such user id")

    response_body = dict(
        status="ok",
        message=f"User profile #{user_id} updated successfully",
        result=dict(user_id=user_id),
    )

    return Response(
        json.dumps(response_body),
        status=200,
        mimetype="application/json",
    )
/policyengine_api/services/__init__.py: -------------------------------------------------------------------------------- 1 | # from .economy_service import EconomyService, JobService, PolicyService 2 | -------------------------------------------------------------------------------- /policyengine_api/services/ai_analysis_service.py: -------------------------------------------------------------------------------- 1 | import anthropic 2 | import os 3 | import json 4 | from typing import Generator, Optional, Literal 5 | from policyengine_api.data import local_database 6 | from pydantic import BaseModel 7 | 8 | 9 | class StreamEvent(BaseModel): 10 | type: str 11 | 12 | 13 | class TextEvent(StreamEvent): 14 | type: str = "text" 15 | stream: str 16 | 17 | 18 | class ErrorEvent(StreamEvent): 19 | type: str = "error" 20 | error: str 21 | 22 | 23 | class AIAnalysisService: 24 | """ 25 | Base class for various AI analysis-based services, 26 | including SimulationAnalysisService, that connects with the analysis 27 | local database table 28 | """ 29 | 30 | def get_existing_analysis(self, prompt: str) -> Optional[str]: 31 | """ 32 | Get existing analysis from the local database 33 | """ 34 | 35 | analysis = local_database.query( 36 | f"SELECT analysis FROM analysis WHERE prompt = ?", 37 | (prompt,), 38 | ).fetchone() 39 | 40 | if analysis is None: 41 | return None 42 | 43 | return json.dumps(analysis["analysis"]) 44 | 45 | def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]: 46 | 47 | # Configure a Claude client 48 | claude_client = anthropic.Anthropic( 49 | api_key=os.getenv("ANTHROPIC_API_KEY") 50 | ) 51 | 52 | def generate(): 53 | response_text = "" 54 | 55 | # Temporarily downgrading Claude to 3.5 Sonnet to prevent unwanted 56 | # quotes in Explain with AI feature responses. 57 | # If Claude is still at 3.5 Sonnet on July 1, 2025, file an issue. 
58 | # See https://github.com/PolicyEngine/policyengine-app/issues/2584 59 | with claude_client.messages.stream( 60 | model="claude-3-5-sonnet-20240620", 61 | max_tokens=1500, 62 | temperature=0.0, 63 | system="Respond with a historical quote", 64 | messages=[{"role": "user", "content": prompt}], 65 | ) as stream: 66 | for event in stream: 67 | # Docs on structure of Anthropic error events at https://docs.anthropic.com/en/api/messages-streaming#error-events 68 | if event.type == "error": 69 | error: dict[str, str] = event.error 70 | error_type: str = error["type"] 71 | return_event = ErrorEvent(error=error_type) 72 | yield json.dumps(return_event.model_dump()) + "\n" 73 | return 74 | if event.type == "text": 75 | response_text += event.text 76 | return_event = TextEvent(stream=event.text) 77 | yield json.dumps(return_event.model_dump()) + "\n" 78 | 79 | # Update the analysis record and return if no error occurred 80 | local_database.query( 81 | f"INSERT INTO analysis (prompt, analysis, status) VALUES (?, ?, ?)", 82 | (prompt, response_text, "ok"), 83 | ) 84 | 85 | return generate() 86 | -------------------------------------------------------------------------------- /policyengine_api/services/ai_prompt_service.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any 2 | from policyengine_api.ai_prompts.simulation_analysis_prompt import ( 3 | generate_simulation_analysis_prompt, 4 | ) 5 | 6 | AIPrompt = Callable[[dict[str, Any]], str] 7 | 8 | ALL_AI_PROMPTS: dict[str, AIPrompt] = { 9 | "simulation_analysis": generate_simulation_analysis_prompt, 10 | } 11 | 12 | 13 | class AIPromptService: 14 | 15 | def get_prompt(self, name: str, input_data: dict) -> str | None: 16 | """ 17 | Get an AI prompt with a given name, filled with the given data. 
18 | """ 19 | 20 | if name in ALL_AI_PROMPTS: 21 | return ALL_AI_PROMPTS[name](input_data) 22 | 23 | return None 24 | -------------------------------------------------------------------------------- /policyengine_api/services/household_service.py: -------------------------------------------------------------------------------- 1 | import json 2 | from sqlalchemy.engine.row import LegacyRow 3 | 4 | from policyengine_api.data import database 5 | from policyengine_api.utils import hash_object 6 | from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS 7 | 8 | 9 | class HouseholdService: 10 | 11 | def get_household(self, country_id: str, household_id: int) -> dict | None: 12 | """ 13 | Get a household's input data with a given ID. 14 | 15 | Args: 16 | country_id (str): The country ID. 17 | household_id (int): The household ID. 18 | """ 19 | print("Getting household data") 20 | 21 | try: 22 | if type(household_id) is not int or household_id < 0: 23 | raise Exception( 24 | f"Invalid household ID: {household_id}. Must be a positive integer." 25 | ) 26 | 27 | row: LegacyRow | None = database.query( 28 | f"SELECT * FROM household WHERE id = ? AND country_id = ?", 29 | (household_id, country_id), 30 | ).fetchone() 31 | 32 | # If row is present, we must JSON.loads the household_json 33 | household = None 34 | if row is not None: 35 | household = dict(row) 36 | if household["household_json"]: 37 | household["household_json"] = json.loads( 38 | household["household_json"] 39 | ) 40 | return household 41 | 42 | except Exception as e: 43 | print( 44 | f"Error fetching household #{household_id}. Details: {str(e)}" 45 | ) 46 | raise e 47 | 48 | def create_household( 49 | self, 50 | country_id: str, 51 | household_json: dict, 52 | label: str | None, 53 | ) -> int: 54 | """ 55 | Create a new household with the given data. 56 | 57 | Args: 58 | country_id (str): The country ID. 59 | household_json (dict): The household data. 
60 | household_hash (int): The hash of the household data. 61 | label (str): The label for the household. 62 | api_version (str): The API version. 63 | """ 64 | 65 | print("Creating new household") 66 | 67 | try: 68 | household_hash: str = hash_object(household_json) 69 | api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) 70 | 71 | database.query( 72 | f"INSERT INTO household (country_id, household_json, household_hash, label, api_version) VALUES (?, ?, ?, ?, ?)", 73 | ( 74 | country_id, 75 | json.dumps(household_json), 76 | household_hash, 77 | label, 78 | api_version, 79 | ), 80 | ) 81 | 82 | household_id = database.query( 83 | f"SELECT id FROM household WHERE country_id = ? AND household_hash = ?", 84 | (country_id, household_hash), 85 | ).fetchone()["id"] 86 | 87 | return household_id 88 | except Exception as e: 89 | print(f"Error creating household. Details: {str(e)}") 90 | raise e 91 | 92 | def update_household( 93 | self, 94 | country_id: str, 95 | household_id: int, 96 | household_json: dict, 97 | label: str, 98 | ) -> dict: 99 | """ 100 | Update a household with the given data. 101 | 102 | Args: 103 | country_id (str): The country ID. 104 | household_id (int): The household ID. 105 | payload (dict): The data to update the household with. 106 | """ 107 | print("Updating household") 108 | 109 | try: 110 | 111 | household_hash: str = hash_object(household_json) 112 | api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) 113 | 114 | database.query( 115 | f"UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? 
WHERE id = ?", 116 | ( 117 | json.dumps(household_json), 118 | household_hash, 119 | label, 120 | api_version, 121 | household_id, 122 | ), 123 | ) 124 | 125 | # Fetch the updated JSON back from the table 126 | updated_household: dict = self.get_household( 127 | country_id, household_id 128 | ) 129 | return updated_household 130 | except Exception as e: 131 | print( 132 | f"Error updating household #{household_id}. Details: {str(e)}" 133 | ) 134 | raise e 135 | -------------------------------------------------------------------------------- /policyengine_api/services/job_service.py: -------------------------------------------------------------------------------- 1 | from redis import Redis 2 | from rq import Queue 3 | from rq.job import Job 4 | from policyengine_api.utils import Singleton 5 | from policyengine_api.jobs import CalculateEconomySimulationJob 6 | from datetime import datetime 7 | from enum import Enum 8 | 9 | calc_ec_sim_job = CalculateEconomySimulationJob() 10 | 11 | queue = Queue(connection=Redis()) 12 | 13 | 14 | class JobStatus(Enum): 15 | PENDING = "pending" 16 | RUNNING = "running" 17 | COMPLETED = "completed" 18 | FAILED = "failed" 19 | 20 | 21 | class JobService(metaclass=Singleton): 22 | """ 23 | Hybrid service used to manage backend economy-wide simulation 24 | jobs. This is not connected to any routes or tables, but interfaces 25 | with the Redis queue to enqueue jobs and track their status. 
26 | """ 27 | 28 | def __init__(self): 29 | self.recent_jobs = {} 30 | 31 | def execute_job(self, job_id, job_timeout, type, *args, **kwargs): 32 | try: 33 | # Prevent duplicate jobs 34 | try: 35 | existing_job = Job.fetch(job_id, connection=queue.connection) 36 | if existing_job and existing_job.get_status() not in [ 37 | "finished", 38 | "failed", 39 | ]: 40 | print( 41 | f"Job {job_id} already exists and is {existing_job.get_status()}" 42 | ) 43 | return 44 | except Exception as e: 45 | # Job doesn't exist, continue with creation 46 | pass 47 | 48 | match type: 49 | case "calculate_economy_simulation": 50 | queue.enqueue( 51 | f=calc_ec_sim_job.run, 52 | *args, 53 | **kwargs, 54 | job_id=job_id, 55 | job_timeout=job_timeout, 56 | ) 57 | case _: 58 | raise ValueError(f"Invalid job type: {type}") 59 | 60 | self._prune_recent_jobs() 61 | except Exception as e: 62 | print(f"Error executing job: {str(e)}") 63 | raise e 64 | 65 | def fetch_job_queue_pos(self, job_id): 66 | try: 67 | job = Job.fetch(job_id, connection=queue.connection) 68 | pos = ( 69 | job.get_position() 70 | if type(job.get_position()) == (int or float) 71 | else 0 72 | ) 73 | return pos 74 | except Exception as e: 75 | print(f"Error fetching job queue position: {str(e)}") 76 | raise e 77 | 78 | def get_recent_jobs(self): 79 | return self.recent_jobs 80 | 81 | def update_recent_job(self, job_id, key, value): 82 | self.recent_jobs[job_id][key] = value 83 | 84 | def add_recent_job(self, type, job_id, start_time, end_time): 85 | self.recent_jobs[job_id] = dict( 86 | type=type, start_time=start_time, end_time=end_time 87 | ) 88 | 89 | def _prune_recent_jobs(self): 90 | if len(self.recent_jobs) > 100: 91 | oldest_job_id = min( 92 | self.recent_jobs, 93 | key=lambda k: self.recent_jobs[k]["start_time"], 94 | ) 95 | del self.recent_jobs[oldest_job_id] 96 | 97 | def get_average_time(self): 98 | """Get the average time for the last 10 jobs. 
class MetadataService:
    """Serves per-country metadata from the loaded COUNTRIES registry."""

    def get_metadata(self, country_id: str) -> dict:
        """
        Get metadata for a country.

        Args:
            country_id (str): The country ID.

        Raises:
            RuntimeError: If the country ID is not recognized.
        """
        country = COUNTRIES.get(country_id)
        # Idiom fix: compare against None with `is`, not `==`
        if country is None:
            raise RuntimeError(
                f"Attempted to get metadata for a nonexistant country: '{country_id}'"
            )

        return country.metadata
selected_version: str, 24 | time_period: str, 25 | impact: dict, 26 | policy_label: str, 27 | policy: dict, 28 | region: str, 29 | relevant_parameters: list[dict], 30 | relevant_parameter_baseline_values: list[dict], 31 | audience: str | None, 32 | ) -> tuple[ 33 | Generator[str, None, None] | str, Literal["streaming", "static"] 34 | ]: 35 | """ 36 | Execute AI analysis for economy-wide simulation 37 | 38 | Returns a tuple of: 39 | - The AI analysis as either a streaming output (if new) or 40 | a string (if existing in database) 41 | - The return type (either "streaming" or "static") 42 | 43 | """ 44 | 45 | print("Generating prompt for economy-wide simulation analysis") 46 | 47 | # Create prompt based on data 48 | prompt = self._generate_simulation_analysis_prompt( 49 | time_period, 50 | region, 51 | currency, 52 | policy, 53 | impact, 54 | relevant_parameters, 55 | relevant_parameter_baseline_values, 56 | selected_version, 57 | country_id, 58 | policy_label, 59 | audience, 60 | dataset=dataset, 61 | ) 62 | 63 | print("Checking if AI analysis already exists for this prompt") 64 | # If a calculated record exists for this prompt, return it as a 65 | # streaming response 66 | existing_analysis = self.get_existing_analysis(prompt) 67 | if existing_analysis is not None: 68 | return existing_analysis, "static" 69 | 70 | print( 71 | "Found no existing AI analysis; triggering new analysis with Claude" 72 | ) 73 | # Otherwise, pass prompt to Claude, then return streaming function 74 | try: 75 | analysis = self.trigger_ai_analysis(prompt) 76 | return analysis, "streaming" 77 | except Exception as e: 78 | raise e 79 | 80 | def _generate_simulation_analysis_prompt( 81 | self, 82 | time_period, 83 | region, 84 | currency, 85 | policy, 86 | impact, 87 | relevant_parameters, 88 | relevant_parameter_baseline_values, 89 | selected_version, 90 | country_id, 91 | policy_label, 92 | audience, 93 | dataset, 94 | ): 95 | 96 | prompt_data: dict = { 97 | "time_period": time_period, 98 | 
"region": region, 99 | "currency": currency, 100 | "policy": policy, 101 | "impact": impact, 102 | "relevant_parameters": relevant_parameters, 103 | "relevant_parameter_baseline_values": relevant_parameter_baseline_values, 104 | "selected_version": selected_version, 105 | "country_id": country_id, 106 | "policy_label": policy_label, 107 | "audience": audience, 108 | "dataset": dataset, 109 | } 110 | 111 | try: 112 | prompt = ai_prompt_service.get_prompt( 113 | "simulation_analysis", prompt_data 114 | ) 115 | return prompt 116 | 117 | except Exception as e: 118 | raise e 119 | -------------------------------------------------------------------------------- /policyengine_api/services/user_service.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any 3 | from policyengine_api.data import database 4 | 5 | 6 | class UserService: 7 | def create_profile( 8 | self, 9 | primary_country: str, 10 | auth0_id: str, 11 | username: str | None, 12 | user_since: str, 13 | ) -> tuple[bool, Any]: 14 | """ 15 | returns true if a new record was created and false otherwise. 16 | """ 17 | # TODO: this is not written as an atomic operation. This will cause intermittent errors 18 | # in some cases 19 | # https://github.com/PolicyEngine/policyengine-api/issues/2058 to resolve after 20 | # this refactor. 
21 | row = self.get_profile(auth0_id=auth0_id) 22 | if row is not None: 23 | return False, row 24 | # Unfortunately, it's not possible to use RETURNING 25 | # with SQLite3 without rewriting the PolicyEngineDatabase 26 | # object or implementing a true ORM, thus the double query 27 | database.query( 28 | f"INSERT INTO user_profiles (primary_country, auth0_id, username, user_since) VALUES (?, ?, ?, ?)", 29 | (primary_country, auth0_id, username, user_since), 30 | ) 31 | 32 | row = self.get_profile(auth0_id=auth0_id) 33 | 34 | return (True, row) 35 | 36 | def get_profile( 37 | self, auth0_id: str | None = None, user_id: str | None = None 38 | ) -> Any | None: 39 | key = "user_id" if auth0_id is None else "auth0_id" 40 | value = user_id if auth0_id is None else auth0_id 41 | if value is None: 42 | raise ValueError("you must specify either auth0_id or user_id") 43 | row = database.query( 44 | f"SELECT * FROM user_profiles WHERE {key} = ?", 45 | (value,), 46 | ).fetchone() 47 | 48 | return row 49 | 50 | def update_profile( 51 | self, 52 | user_id: str, 53 | primary_country: str | None, 54 | username: str | None, 55 | user_since: str, 56 | ) -> bool: 57 | fields = dict( 58 | primary_country=primary_country, 59 | username=username, 60 | user_since=user_since, 61 | ) 62 | if self.get_profile(user_id=user_id) is None: 63 | return False 64 | 65 | with_values = [key for key in fields if fields[key] is not None] 66 | fields_update = ",".join([f"{key} = ?" for key in with_values]) 67 | query = f"UPDATE user_profiles SET {fields_update} WHERE user_id = ?" 
def setup_data():
    """Perform data setup by importing the search endpoints module.

    The import is done for its side effects only; the module's
    import-time code carries out the actual setup work.
    """
    # Deferred (function-scope) import so the side effects run only
    # when setup is explicitly requested, not at module import time
    import policyengine_api.endpoints.search
13 | """ 14 | data = "" 15 | if flask.request.content_type == "application/json": 16 | data = flask.request.get_json() 17 | elif flask.request.content_type in [ 18 | "application/x-www-form-urlencoded", 19 | "multipart/form-data", 20 | ]: 21 | data = flask.request.form.to_dict() 22 | if data != "": 23 | data = json.dumps(data, separators=("", "")) 24 | 25 | cache_key = str(hash(flask.request.full_path + data)) 26 | logging.basicConfig(level=logging.DEBUG) 27 | logging.getLogger().debug( 28 | "PATH: %s, CACHE_KEY: %s", flask.request.full_path, cache_key 29 | ) 30 | return cache_key 31 | -------------------------------------------------------------------------------- /policyengine_api/utils/get_current_law.py: -------------------------------------------------------------------------------- 1 | def get_current_law_policy_id(country_id: str) -> int: 2 | return { 3 | "uk": 1, 4 | "us": 2, 5 | "ca": 3, 6 | "ng": 4, 7 | }[country_id] 8 | -------------------------------------------------------------------------------- /policyengine_api/utils/hugging_face.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import ( 2 | hf_hub_download, 3 | model_info, 4 | ModelInfo, 5 | HfApi, 6 | ) 7 | from huggingface_hub.errors import RepositoryNotFoundError 8 | from getpass import getpass 9 | import os 10 | import warnings 11 | import traceback 12 | 13 | with warnings.catch_warnings(): 14 | warnings.simplefilter("ignore") 15 | 16 | 17 | def get_latest_commit_tag(repo_id, repo_type="model"): 18 | """ 19 | Get the tag associated with the latest commit in a HF repo. 20 | Returns the tag name or None if no tag is associated. 
def check_is_repo_private(repo: str) -> bool:
    """
    Check if a Hugging Face repository is private.

    Args:
        repo (str): The Hugging Face repo name, in format "{org}/{repo}".

    Returns:
        bool: True if the repo is private, False otherwise.

    Raises:
        Exception: If the lookup fails for any reason other than the
            repository not being found; the original error is chained.
    """
    try:
        fetched_model_info: ModelInfo = model_info(repo)
    except RepositoryNotFoundError:
        # An unauthenticated probe cannot distinguish "private" from
        # "missing", so assume private and let callers authenticate.
        return True
    except Exception as e:
        # Chain the cause (`from e`) so the root traceback is preserved
        # for debugging instead of being swallowed by the wrapper.
        raise Exception(
            f"Unable to check if repo {repo} is private. The full error is {traceback.format_exc()}"
        ) from e
    return fetched_model_info.private
def make_hashable(o):
    """Recursively convert *o* into a hashable equivalent.

    Dicts become sorted (key, value) tuples, sets/frozensets become
    sorted tuples, and lists/tuples become tuples, so structurally
    equal inputs map to equal (hashable) outputs. Anything else is
    returned unchanged.
    """
    if isinstance(o, dict):
        return tuple(sorted((key, make_hashable(val)) for key, val in o.items()))

    if isinstance(o, (set, frozenset)):
        return tuple(sorted(make_hashable(item) for item in o))

    if isinstance(o, (tuple, list)):
        return tuple(make_hashable(item) for item in o)

    return o
"Infinity" 31 | elif value == -np.inf: 32 | return "-Infinity" 33 | return value 34 | if isinstance(value, str): 35 | return value 36 | if isinstance(value, dict): 37 | return {k: get_safe_json(v) for k, v in value.items()} 38 | if isinstance(value, list): 39 | return [get_safe_json(v) for v in value] 40 | return None 41 | -------------------------------------------------------------------------------- /policyengine_api/utils/payload_validators/__init__.py: -------------------------------------------------------------------------------- 1 | from .validate_tracer_analysis_payload import validate_tracer_analysis_payload 2 | from .validate_country import validate_country 3 | from .validate_set_policy_payload import validate_set_policy_payload 4 | from .validate_household_payload import validate_household_payload 5 | -------------------------------------------------------------------------------- /policyengine_api/utils/payload_validators/ai/__init__.py: -------------------------------------------------------------------------------- 1 | from .validate_sim_analysis_payload import validate_sim_analysis_payload 2 | -------------------------------------------------------------------------------- /policyengine_api/utils/payload_validators/ai/validate_sim_analysis_payload.py: -------------------------------------------------------------------------------- 1 | def validate_sim_analysis_payload(payload: dict) -> tuple[bool, str | None]: 2 | # Check if all required keys are present; note 3 | # that the audience key is optional 4 | required_keys = [ 5 | "currency", 6 | "selected_version", 7 | "time_period", 8 | "impact", 9 | "policy_label", 10 | "policy", 11 | "region", 12 | "relevant_parameters", 13 | "relevant_parameter_baseline_values", 14 | ] 15 | str_keys = [ 16 | "currency", 17 | "selected_version", 18 | "time_period", 19 | "policy_label", 20 | "region", 21 | ] 22 | dict_keys = [ 23 | "policy", 24 | "impact", 25 | ] 26 | list_keys = ["relevant_parameters", 
"relevant_parameter_baseline_values"] 27 | missing_keys = [key for key in required_keys if key not in payload] 28 | if missing_keys: 29 | return False, f"Missing required keys: {missing_keys}" 30 | 31 | # Check if all keys are of the right type 32 | for key, value in payload.items(): 33 | if key in str_keys and not isinstance(value, str): 34 | return False, f"Key '{key}' must be a string" 35 | elif key in dict_keys and not isinstance(value, dict): 36 | return False, f"Key '{key}' must be a dictionary" 37 | elif key in list_keys and not isinstance(value, list): 38 | return False, f"Key '{key}' must be a list" 39 | 40 | return True, None 41 | -------------------------------------------------------------------------------- /policyengine_api/utils/payload_validators/validate_country.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import Union 3 | from flask import Response 4 | import json 5 | from policyengine_api.constants import COUNTRIES 6 | 7 | 8 | def validate_country(func): 9 | """Validate that a country ID is valid. If not, return a 400 response. 10 | 11 | Args: 12 | country_id (str): The country ID to validate. 13 | 14 | Returns: 15 | Response(400) if country is not valid, else continues 16 | """ 17 | 18 | @wraps(func) 19 | def validate_country_wrapper( 20 | country_id: str, *args, **kwargs 21 | ) -> Union[None, Response]: 22 | print("Validating country") 23 | if country_id not in COUNTRIES: 24 | body = dict( 25 | status="error", 26 | message=f"Country {country_id} not found. 
def validate_household_payload(payload):
    """
    Validate the payload for a POST request to set a household's input data.

    Args:
        payload (dict): The payload to validate.

    Returns:
        tuple[bool, str]: (valid, message); message is None when valid.
    """
    # "data" is the only required key.
    if "data" not in payload:
        return False, "Missing required keys: ['data']"

    # An explicit None label is allowed; anything else must be a string.
    label = payload.get("label")
    if label is not None and not isinstance(label, str):
        return False, "Label must be a string or None"

    # The household itself must arrive as a JSON object (dict).
    if not isinstance(payload["data"], dict):
        return False, "Unable to parse household JSON data"

    return True, None
def validate_tracer_analysis_payload(payload: dict):
    """Validate a tracer-analysis request payload.

    Returns (valid, message); message is None when valid. The first
    missing required key (in declaration order) is reported.
    """
    if not payload:
        return False, "No payload provided"

    absent = next(
        (
            field
            for field in ("household_id", "policy_id", "variable")
            if field not in payload
        ),
        None,
    )
    if absent is not None:
        return False, f"Missing required key: {absent}"

    return True, None
11 | w = Worker(["default"], connection=Redis()) 12 | w.work() 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from policyengine_api.constants import __version__ 3 | 4 | setup( 5 | name="policyengine-api", 6 | version=__version__, 7 | author="PolicyEngine", 8 | author_email="hello@policyengine.org", 9 | description="PolicyEngine API", 10 | packages=find_packages(), 11 | install_requires=[ 12 | "anthropic", 13 | "assertpy", 14 | "click>=8,<9", 15 | "cloud-sql-python-connector", 16 | "google-cloud-workflows", 17 | "faiss-cpu<1.8.0", 18 | "flask>=3,<4", 19 | "flask-cors>=5,<6", 20 | "google-cloud-logging", 21 | "gunicorn", 22 | "markupsafe>=3,<4", 23 | "openai", 24 | "policyengine_canada==0.96.2", 25 | "policyengine-ng==0.5.1", 26 | "policyengine-il==0.1.0", 27 | "policyengine_uk==2.28.2", 28 | "policyengine_us==1.307.0", 29 | "policyengine_core>=3.16.6", 30 | "policyengine>=0.3.0", 31 | "pydantic", 32 | "pymysql", 33 | "python-dotenv", 34 | "redis", 35 | "rq", 36 | "sqlalchemy>=1.4,<2", 37 | "streamlit", 38 | "werkzeug", 39 | "Flask-Caching>=2,<3", 40 | "google-cloud-logging>=3,<4", 41 | ], 42 | extras_require={ 43 | "dev": ["pytest-timeout", "coverage", "pytest-snapshot"], 44 | }, 45 | # script policyengine-api-setup -> policyengine_api.setup_data:setup_data 46 | entry_points={ 47 | "console_scripts": [ 48 | "policyengine-api-setup=policyengine_api.setup_data:setup_data", 49 | ], 50 | }, 51 | ) 52 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolicyEngine/policyengine-api/5a3e2ce40857c086299cfe0b53956c1e28a8e3ad/tests/__init__.py -------------------------------------------------------------------------------- 
@contextmanager
def running(process_arguments, seconds_to_wait_after_launch=0):
    """Run a subprocess for the duration of the with-block, then stop it.

    Args:
        process_arguments: argv list passed straight to Popen.
        seconds_to_wait_after_launch: settle time before yielding, for
            servers (redis, workers) that need a moment to come up.

    Yields:
        The running Popen object.
    """
    process = Popen(process_arguments)
    time.sleep(seconds_to_wait_after_launch)
    try:
        yield process
    finally:
        # Graceful first, forceful second: the original code sent SIGKILL
        # and then fell back to the *weaker* terminate(), which could never
        # help once kill() had already been tried.
        process.terminate()
        try:
            process.wait(10)
        except TimeoutExpired:
            process.kill()
            process.wait()
-------------------------------------------------------------------------------- /tests/data/calculate_us_2_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "household": { 3 | "people": { 4 | "you": { 5 | "age": { 6 | "2023": 40 7 | }, 8 | "people": { 9 | "2023": 1 10 | } 11 | } 12 | }, 13 | "axes": [ 14 | [ 15 | { 16 | "name": "employment_income", 17 | "period": "2023", 18 | "min": 0, 19 | "max": 200000, 20 | "count": 401 21 | } 22 | ] 23 | ] 24 | }, 25 | "policy": { 26 | "gov.irs.income.exemption.amount": { 27 | "2023-01-01.2028-12-31": "100" 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /tests/data/or_rebate_measure_118.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "gov.contrib.ubi_center.basic_income.amount.person.flat": { 4 | "2024-01-01.2100-12-31": 1160 5 | } 6 | } 7 | } -------------------------------------------------------------------------------- /tests/data/test_economy_1_policy_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "gov.abolitions.ccdf_income": { 4 | "2023-01-01.2028-12-31": true 5 | }, 6 | "gov.irs.income.exemption.amount": { 7 | "2023-01-01.2028-12-31": "100" 8 | } 9 | } 10 | } -------------------------------------------------------------------------------- /tests/data/uk_household.json: -------------------------------------------------------------------------------- 1 | { 2 | "benunits": { 3 | "your immediate family": { 4 | "is_married": { 5 | "2023": true 6 | }, 7 | "members": [ 8 | "you", 9 | "your partner", 10 | "your first child", 11 | "your second child", 12 | "your third child", 13 | "your fourth child" 14 | ] 15 | } 16 | }, 17 | "households": { 18 | "your household": { 19 | "BRMA": { 20 | "2023": "MAIDSTONE" 21 | }, 22 | "local_authority": { 23 | "2023": "MAIDSTONE" 24 | }, 25 | "members": [ 26 | 
"you", 27 | "your partner", 28 | "your first child", 29 | "your second child", 30 | "your third child", 31 | "your fourth child" 32 | ], 33 | "region": { 34 | "2023": "LONDON" 35 | } 36 | } 37 | }, 38 | "people": { 39 | "you": { 40 | "age": { 41 | "2023": 40 42 | }, 43 | "employment_income": { 44 | "2023": 40000 45 | } 46 | }, 47 | "your first child": { 48 | "age": { 49 | "2023": 10 50 | }, 51 | "employment_income": { 52 | "2023": 0 53 | } 54 | }, 55 | "your fourth child": { 56 | "age": { 57 | "2023": 10 58 | }, 59 | "employment_income": { 60 | "2023": 0 61 | } 62 | }, 63 | "your partner": { 64 | "age": { 65 | "2023": 40 66 | }, 67 | "employment_income": { 68 | "2023": 0 69 | } 70 | }, 71 | "your second child": { 72 | "age": { 73 | "2023": 10 74 | }, 75 | "employment_income": { 76 | "2023": 0 77 | } 78 | }, 79 | "your third child": { 80 | "age": { 81 | "2023": 10 82 | }, 83 | "employment_income": { 84 | "2023": 0 85 | } 86 | } 87 | } 88 | } -------------------------------------------------------------------------------- /tests/data/us_household.json: -------------------------------------------------------------------------------- 1 | { 2 | "families": { 3 | "your family": { 4 | "members": [ 5 | "you", 6 | "your partner", 7 | "your first dependent", 8 | "your second dependent", 9 | "your third dependent", 10 | "your fourth dependent" 11 | ] 12 | } 13 | }, 14 | "households": { 15 | "your household": { 16 | "members": [ 17 | "you", 18 | "your partner", 19 | "your first dependent", 20 | "your second dependent", 21 | "your third dependent", 22 | "your fourth dependent" 23 | ], 24 | "state_name": { 25 | "2023": "AL" 26 | } 27 | } 28 | }, 29 | "marital_units": { 30 | "your marital unit": { 31 | "members": [ 32 | "you", 33 | "your partner" 34 | ] 35 | }, 36 | "your first dependent's marital unit": { 37 | "marital_unit_id": { 38 | "2023": 1 39 | }, 40 | "members": [ 41 | "your first dependent" 42 | ] 43 | }, 44 | "your second dependent's marital unit": { 45 | 
"marital_unit_id": { 46 | "2023": 2 47 | }, 48 | "members": [ 49 | "your second dependent" 50 | ] 51 | }, 52 | "your third dependent's marital unit": { 53 | "marital_unit_id": { 54 | "2023": 3 55 | }, 56 | "members": [ 57 | "your third dependent" 58 | ] 59 | }, 60 | "your fourth dependent's marital unit": { 61 | "marital_unit_id": { 62 | "2023": 4 63 | }, 64 | "members": [ 65 | "your fourth dependent" 66 | ] 67 | } 68 | }, 69 | "people": { 70 | "you": { 71 | "age": { 72 | "2023": 40 73 | }, 74 | "employment_income": { 75 | "2023": 40000 76 | } 77 | }, 78 | "your first dependent": { 79 | "age": { 80 | "2023": 10 81 | }, 82 | "employment_income": { 83 | "2023": 0 84 | }, 85 | "is_tax_unit_dependent": { 86 | "2023": true 87 | } 88 | }, 89 | "your partner": { 90 | "age": { 91 | "2023": 40 92 | }, 93 | "employment_income": { 94 | "2023": 0 95 | } 96 | }, 97 | "your second dependent": { 98 | "age": { 99 | "2023": 10 100 | }, 101 | "employment_income": { 102 | "2023": 0 103 | }, 104 | "is_tax_unit_dependent": { 105 | "2023": true 106 | } 107 | }, 108 | "your third dependent": { 109 | "age": { 110 | "2023": 10 111 | }, 112 | "employment_income": { 113 | "2023": 0 114 | }, 115 | "is_tax_unit_dependent": { 116 | "2023": true 117 | } 118 | }, 119 | "your fourth dependent": { 120 | "age": { 121 | "2023": 10 122 | }, 123 | "employment_income": { 124 | "2023": 0 125 | }, 126 | "is_tax_unit_dependent": { 127 | "2023": true 128 | } 129 | } 130 | }, 131 | "spm_units": { 132 | "your household": { 133 | "members": [ 134 | "you", 135 | "your partner", 136 | "your first dependent", 137 | "your second dependent", 138 | "your third dependent", 139 | "your fourth dependent" 140 | ] 141 | } 142 | }, 143 | "tax_units": { 144 | "your tax unit": { 145 | "members": [ 146 | "you", 147 | "your partner", 148 | "your first dependent", 149 | "your second dependent", 150 | "your third dependent", 151 | "your fourth dependent" 152 | ] 153 | } 154 | } 155 | } 156 | 
-------------------------------------------------------------------------------- /tests/data/utah_reform.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "gov.states.ut.tax.income.rate": { 4 | "2025-01-01.2100-12-31": 0.0455 5 | }, 6 | "gov.states.ut.tax.income.credits.ctc.child_age_threshold[0].amount": { 7 | "2025-01-01.2100-12-31": false 8 | }, 9 | "gov.states.ut.tax.income.credits.ctc.child_age_threshold[2].threshold": { 10 | "2025-01-01.2100-12-31": 5 11 | }, 12 | "gov.states.ut.tax.income.credits.ss_benefits.phase_out.threshold.JOINT": { 13 | "2025-01-01.2100-12-31": 75000 14 | }, 15 | "gov.states.ut.tax.income.credits.ss_benefits.phase_out.threshold.SINGLE": { 16 | "2025-01-01.2100-12-31": 45000 17 | }, 18 | "gov.states.ut.tax.income.credits.ss_benefits.phase_out.threshold.SEPARATE": { 19 | "2025-01-01.2100-12-31": 37500 20 | }, 21 | "gov.states.ut.tax.income.credits.ss_benefits.phase_out.threshold.SURVIVING_SPOUSE": { 22 | "2025-01-01.2100-12-31": 75000 23 | }, 24 | "gov.states.ut.tax.income.credits.ss_benefits.phase_out.threshold.HEAD_OF_HOUSEHOLD": { 25 | "2025-01-01.2100-12-31": 75000 26 | } 27 | } 28 | } -------------------------------------------------------------------------------- /tests/env_variables/test_environment_variables.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import requests 4 | 5 | HUGGING_FACE_API_URL = "https://huggingface.co/api/whoami-v2" 6 | GITHUB_API_URL = "https://api.github.com/user" 7 | 8 | do_not_run_in_debug = lambda: os.getenv("FLASK_DEBUG") == "1" 9 | 10 | 11 | class TestEnvironmentVariables: 12 | """Tests for expiring environment variables.""" 13 | 14 | @pytest.mark.skipif( 15 | do_not_run_in_debug(), 16 | reason="Skipping in debug mode", 17 | ) 18 | def test_hugging_face_token(self): 19 | """Test if HUGGING_FACE_TOKEN is valid by querying Hugging Face API.""" 20 | 21 | token = 
os.getenv("HUGGING_FACE_TOKEN") 22 | assert token is not None, "HUGGING_FACE_TOKEN is not set" 23 | 24 | token_validation_response = requests.get( 25 | HUGGING_FACE_API_URL, 26 | headers={"Authorization": f"Bearer {token}"}, 27 | timeout=5, 28 | ) 29 | 30 | assert ( 31 | token_validation_response.status_code == 200 32 | ), f"Invalid HUGGING_FACE_TOKEN: {token_validation_response.text}" 33 | 34 | @pytest.mark.skipif( 35 | do_not_run_in_debug(), 36 | reason="Skipping in debug mode", 37 | ) 38 | def test_github_microdata_auth_token(self): 39 | """Test if POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is valid by querying GitHub user API.""" 40 | 41 | token = os.getenv("POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN") 42 | assert ( 43 | token is not None 44 | ), "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is not set" 45 | 46 | headers = { 47 | "Authorization": f"Bearer {token}", 48 | "Accept": "application/vnd.github+json", 49 | "X-GitHub-Api-Version": "2022-11-28", 50 | } 51 | 52 | token_validation_response = requests.get( 53 | GITHUB_API_URL, 54 | headers=headers, 55 | timeout=5, 56 | ) 57 | 58 | assert ( 59 | token_validation_response.status_code == 200 60 | ), f"Invalid POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN: {token_validation_response.text}" 61 | 62 | token_user_details = token_validation_response.json() 63 | assert ( 64 | "login" in token_user_details 65 | ), "Token is valid but did not return expected user details" 66 | -------------------------------------------------------------------------------- /tests/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PolicyEngine/policyengine-api/5a3e2ce40857c086299cfe0b53956c1e28a8e3ad/tests/fixtures/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/jobs/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PolicyEngine/policyengine-api/5a3e2ce40857c086299cfe0b53956c1e28a8e3ad/tests/fixtures/jobs/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/jobs/calculate_economy_simulation_job.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import unittest.mock as mock 3 | import numpy as np 4 | import pandas as pd 5 | import h5py 6 | 7 | 8 | @pytest.fixture 9 | def mock_huggingface_downloads(monkeypatch): 10 | """Mock the huggingface dataset downloads.""" 11 | 12 | def mock_download(repo, repo_filename): 13 | # Return mock file paths for constituency data 14 | if "constituency_weights" in repo_filename: 15 | return "mock_weights.h5" 16 | elif "constituencies_2024.csv" in repo_filename: 17 | return "mock_constituencies.csv" 18 | return repo_filename 19 | 20 | monkeypatch.setattr( 21 | "policyengine_api.jobs.calculate_economy_simulation_job.download_huggingface_dataset", 22 | mock_download, 23 | ) 24 | 25 | 26 | @pytest.fixture 27 | def mock_country(): 28 | """Create a mock UK country object.""" 29 | mock_country = mock.MagicMock() 30 | mock_country.name = "uk" 31 | return mock_country 32 | 33 | 34 | @pytest.fixture 35 | def mock_h5py_weights(monkeypatch): 36 | """Mock reading h5py weights.""" 37 | # Create a weight matrix with 650 constituencies and 100 households 38 | mock_weights = np.ones((650, 100)) 39 | 40 | # Create a mock dataset that works with [...] 
@pytest.fixture
def mock_constituency_names(monkeypatch):
    """Mock constituency names dataframe."""
    # 650 constituencies to match the weights array shape: English (E),
    # Scottish (S), Welsh (W) and Northern Irish (N), identified by the
    # leading letter of their code.
    regions = [
        ("E", "English", 400),
        ("S", "Scottish", 150),
        ("W", "Welsh", 50),
        ("N", "Northern Irish", 50),
    ]
    codes = [
        f"{prefix}{index:07d}"
        for prefix, _, count in regions
        for index in range(count)
    ]
    names = [
        f"{label} Constituency {index}"
        for _, label, count in regions
        for index in range(count)
    ]
    mock_df = pd.DataFrame({"code": codes, "name": names})

    # Any pd.read_csv call during the test yields this frame.
    monkeypatch.setattr(pd, "read_csv", lambda path: mock_df)
    return mock_df
simulation.get_holder.side_effect = lambda name: { 103 | "person_weight": person_holder, 104 | "benunit_weight": benunit_holder, 105 | }.get(name) 106 | 107 | return simulation 108 | 109 | 110 | @pytest.fixture 111 | def mock_reform_impacts_service(): 112 | with mock.patch( 113 | "policyengine_api.jobs.tasks.economy_simulation.reform_impacts_service" 114 | ) as mock_service: 115 | yield mock_service 116 | 117 | 118 | @pytest.fixture 119 | def mock_logging(): 120 | with mock.patch( 121 | "policyengine_api.jobs.tasks.economy_simulation.logging" 122 | ) as mock_logging: 123 | yield mock_logging 124 | -------------------------------------------------------------------------------- /tests/fixtures/services/ai_analysis_service.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from typing import Optional 3 | from unittest.mock import MagicMock, patch 4 | 5 | 6 | # Event class to mimic Anthropic's streaming response events 7 | class MockEvent: 8 | def __init__( 9 | self, 10 | event_type: str, 11 | text: Optional[str] = None, 12 | error: Optional[dict[str, str]] = None, 13 | ): 14 | self.type = event_type 15 | self.text = text 16 | self.error = error 17 | 18 | 19 | @pytest.fixture() 20 | def patch_anthropic(): 21 | """ 22 | Fixture that patches the anthropic module at the root level. 23 | This ensures all imports of anthropic.Anthropic use our mock. 24 | """ 25 | with patch("anthropic.Anthropic") as mock: 26 | yield mock 27 | 28 | 29 | @pytest.fixture 30 | def mock_stream_text_events(patch_anthropic): 31 | """ 32 | Fixture that configures the mock Anthropic client to stream text events. 
def parse_to_chunks(input: str) -> list[str]:
    """
    The AI analysis service streams output in chunks of 5 characters;
    split any string into that shape. The final chunk may be shorter,
    and an empty string yields an empty list.
    """
    pieces: list[str] = []
    remainder = input
    while remainder:
        pieces.append(remainder[:5])
        remainder = remainder[5:]
    return pieces
patch 4 | 5 | valid_policy_json = { 6 | "data": { 7 | "gov.irs.income.bracket.rates.2": {"2024-01-01.2024-12-31": 0.2433} 8 | }, 9 | } 10 | 11 | valid_hash_value = "NgJhpeuRVnIAwgYWuJsd2fI/N88rIE6Kcj8q4TPD/i4=" 12 | 13 | # Sample valid policy data 14 | valid_policy_data = { 15 | "id": 11, 16 | "country_id": "us", 17 | "label": None, 18 | "api_version": "1.180.1", 19 | "policy_json": json.dumps(valid_policy_json["data"]), 20 | "policy_hash": valid_hash_value, 21 | } 22 | 23 | 24 | @pytest.fixture 25 | def mock_database(): 26 | """Mock the database module.""" 27 | with patch("policyengine_api.services.policy_service.database") as mock_db: 28 | yield mock_db 29 | 30 | 31 | @pytest.fixture 32 | def mock_hash_object(): 33 | """Mock the hash_object function.""" 34 | with patch("policyengine_api.services.policy_service.hash_object") as mock: 35 | mock.return_value = valid_hash_value 36 | yield mock 37 | 38 | 39 | @pytest.fixture 40 | def existing_policy_record(test_db): 41 | """Insert an existing policy record into the database.""" 42 | test_db.query( 43 | "INSERT INTO policy (id, country_id, policy_json, policy_hash, label, api_version) VALUES (?, ?, ?, ?, ?, ?)", 44 | ( 45 | valid_policy_data["id"], 46 | valid_policy_data["country_id"], 47 | valid_policy_data["policy_json"], 48 | valid_policy_data["policy_hash"], 49 | valid_policy_data["label"], 50 | valid_policy_data["api_version"], 51 | ), 52 | ) 53 | 54 | inserted_row = test_db.query( 55 | "SELECT * FROM policy WHERE id = ?", (valid_policy_data["id"],) 56 | ).fetchone() 57 | 58 | return inserted_row 59 | -------------------------------------------------------------------------------- /tests/fixtures/services/tracer_analysis_service.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import json 3 | from policyengine_api.services.tracer_analysis_service import ( 4 | TracerAnalysisService, 5 | ) 6 | from unittest.mock import patch 7 | 8 | 9 | valid_tracer_output = 
[ 10 | " snap<2027, (default)> = [6769.799]", 11 | " snap<2027-01, (default)> = [561.117]", 12 | " takes_up_snap_if_eligible<2027-01, (default)> = [ True]", 13 | " snap_normal_allotment<2027-01, (default)> = [561.117]", 14 | " is_snap_eligible<2027-01, (default)> = [ True]", 15 | " meets_snap_net_income_test<2027-01, (default)> = [ True]", 16 | " snap_net_income_fpg_ratio<2027-01, (default)> = [0.]", 17 | " snap_net_income<2027-01, (default)> = [0.]", 18 | " snap_fpg<2027-01, (default)> = [1806.4779]", 19 | ] 20 | 21 | invalid_tracer_output = { 22 | "variable": "only_government_benefit <1500>", 23 | "variable": " market_income <1000>", 24 | } 25 | 26 | spliced_valid_tracer_output_root_variable = valid_tracer_output[0:] 27 | 28 | spliced_valid_tracer_output_nested_variable = valid_tracer_output[2:3] 29 | 30 | spliced_valid_tracer_output_leaf_variable = valid_tracer_output[8:] 31 | 32 | spliced_valid_tracer_output_for_variable_that_is_substring_of_another = ( 33 | valid_tracer_output[7:8] 34 | ) 35 | 36 | empty_tracer = [] 37 | 38 | 39 | @pytest.fixture 40 | def sample_tracer_data(): 41 | return valid_tracer_output 42 | 43 | 44 | @pytest.fixture 45 | def sample_expected_segment(): 46 | return spliced_valid_tracer_output_nested_variable 47 | 48 | 49 | @pytest.fixture 50 | def mock_get_tracer(sample_tracer_data): 51 | with patch.object( 52 | TracerAnalysisService, "get_tracer", return_value=sample_tracer_data 53 | ) as mock: 54 | yield mock 55 | 56 | 57 | @pytest.fixture 58 | def mock_parse_tracer_output(sample_expected_segment): 59 | with patch.object( 60 | TracerAnalysisService, 61 | "_parse_tracer_output", 62 | return_value=sample_expected_segment, 63 | ) as mock: 64 | yield mock 65 | 66 | 67 | @pytest.fixture 68 | def mock_get_existing_analysis(): 69 | with patch.object( 70 | TracerAnalysisService, 71 | "get_existing_analysis", 72 | return_value="Existing static analysis", 73 | ) as mock: 74 | yield mock 75 | 76 | 77 | @pytest.fixture 78 | def 
mock_trigger_ai_analysis(): 79 | def dummy_generator(): 80 | yield "stream chunk 1" 81 | yield "stream chunk 2" 82 | 83 | with patch.object( 84 | TracerAnalysisService, 85 | "trigger_ai_analysis", 86 | return_value=dummy_generator(), 87 | ) as mock: 88 | yield mock 89 | -------------------------------------------------------------------------------- /tests/fixtures/services/tracer_fixture_service.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import json 3 | from policyengine_api.services.tracer_analysis_service import ( 4 | TracerAnalysisService, 5 | ) 6 | 7 | valid_tracer = { 8 | "tracer_output": [ 9 | "only_government_benefit <1500>", 10 | " market_income <1000>", 11 | " employment_income <1000>", 12 | " main_employment_income <1000>", 13 | " non_market_income <500>", 14 | " pension_income <500>", 15 | ] 16 | } 17 | 18 | valid_tracer_row = { 19 | "household_id": "71424", 20 | "policy_id": "2", 21 | "country_id": "us", 22 | "api_version": "1.150.0", 23 | "tracer_output": json.dumps(valid_tracer["tracer_output"]), 24 | } 25 | 26 | 27 | @pytest.fixture 28 | def test_tracer_data(test_db): 29 | 30 | # Insert data using query() 31 | test_db.query( 32 | """ 33 | INSERT INTO tracers (household_id, policy_id, country_id, api_version, tracer_output) 34 | VALUES (?, ?, ?, ?, ?) 35 | """, 36 | ( 37 | valid_tracer_row["household_id"], 38 | valid_tracer_row["policy_id"], 39 | valid_tracer_row["country_id"], 40 | valid_tracer_row["api_version"], 41 | valid_tracer_row["tracer_output"], 42 | ), 43 | ) 44 | 45 | # Verify that the data has been inserted 46 | inserted_row = test_db.query( 47 | "SELECT * FROM tracers WHERE household_id = ? AND policy_id = ? AND country_id = ? 
AND api_version = ?", 48 | ( 49 | valid_tracer_row["household_id"], 50 | valid_tracer_row["policy_id"], 51 | valid_tracer_row["country_id"], 52 | valid_tracer_row["api_version"], 53 | ), 54 | ).fetchone() 55 | 56 | return inserted_row 57 | -------------------------------------------------------------------------------- /tests/fixtures/services/user_service.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | valid_user_record = { 5 | "user_id": 1, 6 | "auth0_id": "123", 7 | "username": "person1", 8 | "primary_country": "US", 9 | "user_since": 1678658906, 10 | } 11 | 12 | 13 | @pytest.fixture 14 | def existing_user_profile(test_db): 15 | """Insert an existing user record into the database.""" 16 | test_db.query( 17 | "INSERT INTO user_profiles (user_id, auth0_id, username, primary_country, user_since) VALUES (?, ?, ?, ?, ?)", 18 | ( 19 | valid_user_record["user_id"], 20 | valid_user_record["auth0_id"], 21 | valid_user_record["username"], 22 | valid_user_record["primary_country"], 23 | valid_user_record["user_since"], 24 | ), 25 | ) 26 | inserted_row = test_db.query( 27 | "SELECT * FROM user_profiles WHERE auth0_id = ?", 28 | (valid_user_record["auth0_id"],), 29 | ).fetchone() 30 | 31 | return inserted_row 32 | -------------------------------------------------------------------------------- /tests/fixtures/utils/v2_v1_comparison.py: -------------------------------------------------------------------------------- 1 | VALID_JOB_ID = "valid_job_id" 2 | 3 | valid_v2_v1_comparison = { 4 | "country_id": "us", 5 | "region": "ca", 6 | "reform_policy": '{"tax_rate": 0.2}', 7 | "baseline_policy": '{"tax_rate": 0.15}', 8 | "reform_policy_id": 1, 9 | "baseline_policy_id": 7, 10 | "time_period": "2023-01-01.2028-12-31", 11 | "dataset": "test_dataset", 12 | "v2_id": "v2_workflow_id", 13 | "v2_error": None, 14 | "v1_country_package_version": "1.0.0", 15 | "v2_country_package_version": "2.0.0", 16 | "v1_impact": 
{"impact_value": 100}, 17 | "v2_impact": {"impact_value": 120}, 18 | "v1_v2_diff": {"impact_value": 20}, 19 | "message": None, 20 | "job_id": VALID_JOB_ID, 21 | } 22 | 23 | invalid_v2_v1_comparison = { 24 | **valid_v2_v1_comparison, 25 | "reform_policy_id": "invalid_id", # Invalid type 26 | } 27 | -------------------------------------------------------------------------------- /tests/snapshots/simulation_analysis_prompt_dataset_enhanced_cps.txt: -------------------------------------------------------------------------------- 1 | 2 | I'm using PolicyEngine, a free, open source tool to compute the impact of 3 | public policy. I'm writing up an economic analysis of a hypothetical tax-benefit 4 | policy reform. Please write the analysis for me using the details below, in 5 | their order. You should: 6 | 7 | - First explain each provision of the reform, noting that it's hypothetical and 8 | won't represents policy reforms for 2022 and us. Explain how 9 | the parameters are changing from the baseline to the reform values using the given data. 10 | 11 | - Explicitly mention that this analysis uses PolicyEngine Enhanced CPS, constructed 12 | from the 2023 Current Population Survey and the 2015 IRS Public Use File, and calibrated 13 | to tax, benefit, income, and demographic aggregates. 14 | 15 | - Round large numbers like: $3.1 billion, $300 million, 16 | $106,000, $1.50 (never $1.5). 17 | 18 | - Round percentages to one decimal place. 19 | 20 | - Avoid normative language like 'requires', 'should', 'must', and use quantitative statements 21 | over general adjectives and adverbs. If you don't know what something is, don't make it up. 22 | 23 | - Avoid speculating about the intent of the policy or inferring any motives; only describe the 24 | observable effects and impacts of the policy. Refrain from using subjective language or making 25 | assumptions about the recipients and their needs. 
26 | 27 | - Use the active voice where possible; for example, write phrases where the reform is the subject, 28 | such as "the reform [or a description of the reform] reduces poverty by x%". 29 | 30 | - Use American English spelling and grammar. 31 | 32 | - Cite PolicyEngine US v1.2.3 and the 2022 Current Population Survey March Supplement microdata 33 | when describing policy impacts. 34 | 35 | - When describing poverty impacts, note that the poverty measure reported is the Supplemental Poverty Measure 36 | 37 | - Don't use headers, but do use Markdown formatting. Use - for bullets, and include a newline after each bullet. 38 | 39 | - Include the following embeds inline, without a header so it flows. 40 | 41 | - Immediately after you describe the changes by decile, include the text: '{{distributionalImpact.incomeDecile.relative}}' 42 | 43 | - And after the poverty rate changes, include the text: '{{povertyImpact.regular.byAge}}' 44 | 45 | - After the racial breakdown of poverty rate changes, include the text: '{{povertyImpact.regular.byRace}}' 46 | 47 | - And after the inequality changes, include the text: "{{inequalityImpact}}" 48 | 49 | - Make sure to accurately represent the changes observed in the data. 
50 | 51 | - This JSON snippet describes the default parameter values: [{'parameter1': 100, 'parameter2': 200}] 52 | 53 | - This JSON snippet describes the baseline and reform policies being compared: {'gov.test.parameter': 0.1} 54 | 55 | - policy_label has the following impacts from the PolicyEngine microsimulation model: 56 | 57 | - This JSON snippet describes the relevant parameters with more details: [{'parameter1': 100, 'parameter2': 200}] 58 | 59 | - This JSON describes the total budgetary impact, the change to tax revenues and benefit 60 | spending (ignore 'households' and 'baseline_net_income': {"baseline": 0.0, "reform": 0.1}) 61 | 62 | - This JSON describes how common different outcomes were at each income decile: {"baseline": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}, "reform": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}} 63 | 64 | - This JSON describes the average and relative changes to income by each income decile: {"baseline": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}, "reform": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}} 65 | 66 | - This JSON describes the baseline and reform poverty rates by age group 67 | (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 68 | 69 | - This JSON describes the baseline and reform deep poverty rates by age group 70 | (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 71 | 72 | - This JSON describes the baseline and reform poverty and deep poverty rates 73 | by gender (briefly describe the relative changes): {"baseline": 0.1, "reform": 0.2} 74 | 75 | - This JSON describes the baseline and reform poverty impacts by racial group (briefly describe the relative changes): {"HISPANIC": 0.1, "WHITE": 0.2} 76 | 77 | - This JSON describes three inequality metrics in 
the baseline and reform, the Gini 78 | coefficient of income inequality, the share of income held by the top 10% of households 79 | and the share held by the top 1% (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 80 | 81 | Write this for a policy analyst who knows a bit about economics and policy. 82 | -------------------------------------------------------------------------------- /tests/snapshots/simulation_analysis_prompt_uk.txt: -------------------------------------------------------------------------------- 1 | 2 | I'm using PolicyEngine, a free, open source tool to compute the impact of 3 | public policy. I'm writing up an economic analysis of a hypothetical tax-benefit 4 | policy reform. Please write the analysis for me using the details below, in 5 | their order. You should: 6 | 7 | - First explain each provision of the reform, noting that it's hypothetical and 8 | won't represents policy reforms for 2022 and uk. Explain how 9 | the parameters are changing from the baseline to the reform values using the given data. 10 | 11 | 12 | 13 | - Round large numbers like: £3.1 billion, £300 million, 14 | £106,000, £1.50 (never £1.5). 15 | 16 | - Round percentages to one decimal place. 17 | 18 | - Avoid normative language like 'requires', 'should', 'must', and use quantitative statements 19 | over general adjectives and adverbs. If you don't know what something is, don't make it up. 20 | 21 | - Avoid speculating about the intent of the policy or inferring any motives; only describe the 22 | observable effects and impacts of the policy. Refrain from using subjective language or making 23 | assumptions about the recipients and their needs. 24 | 25 | - Use the active voice where possible; for example, write phrases where the reform is the subject, 26 | such as "the reform [or a description of the reform] reduces poverty by x%". 27 | 28 | - Use British English spelling and grammar. 
29 | 30 | - Cite PolicyEngine UK v1.2.3 and the PolicyEngine-enhanced 2019 Family Resources Survey microdata 31 | when describing policy impacts. 32 | 33 | - When describing poverty impacts, note that the poverty measure reported is absolute poverty before housing costs 34 | 35 | - Don't use headers, but do use Markdown formatting. Use - for bullets, and include a newline after each bullet. 36 | 37 | - Include the following embeds inline, without a header so it flows. 38 | 39 | - Immediately after you describe the changes by decile, include the text: '{{distributionalImpact.incomeDecile.relative}}' 40 | 41 | - And after the poverty rate changes, include the text: '{{povertyImpact.regular.byAge}}' 42 | 43 | 44 | 45 | - And after the inequality changes, include the text: "{{inequalityImpact}}" 46 | 47 | - Make sure to accurately represent the changes observed in the data. 48 | 49 | - This JSON snippet describes the default parameter values: [{'parameter1': 100, 'parameter2': 200}] 50 | 51 | - This JSON snippet describes the baseline and reform policies being compared: {'gov.test.parameter': 0.1} 52 | 53 | - policy_label has the following impacts from the PolicyEngine microsimulation model: 54 | 55 | - This JSON snippet describes the relevant parameters with more details: [{'parameter1': 100, 'parameter2': 200}] 56 | 57 | - This JSON describes the total budgetary impact, the change to tax revenues and benefit 58 | spending (ignore 'households' and 'baseline_net_income': {"baseline": 0.0, "reform": 0.1}) 59 | 60 | - This JSON describes how common different outcomes were at each income decile: {"baseline": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}, "reform": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}} 61 | 62 | - This JSON describes the average and relative changes to income by each income decile: {"baseline": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, 
"5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}, "reform": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}} 63 | 64 | - This JSON describes the baseline and reform poverty rates by age group 65 | (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 66 | 67 | - This JSON describes the baseline and reform deep poverty rates by age group 68 | (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 69 | 70 | - This JSON describes the baseline and reform poverty and deep poverty rates 71 | by gender (briefly describe the relative changes): {"baseline": 0.1, "reform": 0.2} 72 | 73 | 74 | 75 | - This JSON describes three inequality metrics in the baseline and reform, the Gini 76 | coefficient of income inequality, the share of income held by the top 10% of households 77 | and the share held by the top 1% (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 78 | 79 | Write this for a policy analyst who knows a bit about economics and policy. 80 | -------------------------------------------------------------------------------- /tests/snapshots/simulation_analysis_prompt_us.txt: -------------------------------------------------------------------------------- 1 | 2 | I'm using PolicyEngine, a free, open source tool to compute the impact of 3 | public policy. I'm writing up an economic analysis of a hypothetical tax-benefit 4 | policy reform. Please write the analysis for me using the details below, in 5 | their order. You should: 6 | 7 | - First explain each provision of the reform, noting that it's hypothetical and 8 | won't represents policy reforms for 2022 and us. Explain how 9 | the parameters are changing from the baseline to the reform values using the given data. 10 | 11 | 12 | 13 | - Round large numbers like: $3.1 billion, $300 million, 14 | $106,000, $1.50 (never $1.5). 15 | 16 | - Round percentages to one decimal place. 
17 | 18 | - Avoid normative language like 'requires', 'should', 'must', and use quantitative statements 19 | over general adjectives and adverbs. If you don't know what something is, don't make it up. 20 | 21 | - Avoid speculating about the intent of the policy or inferring any motives; only describe the 22 | observable effects and impacts of the policy. Refrain from using subjective language or making 23 | assumptions about the recipients and their needs. 24 | 25 | - Use the active voice where possible; for example, write phrases where the reform is the subject, 26 | such as "the reform [or a description of the reform] reduces poverty by x%". 27 | 28 | - Use American English spelling and grammar. 29 | 30 | - Cite PolicyEngine US v1.2.3 and the 2022 Current Population Survey March Supplement microdata 31 | when describing policy impacts. 32 | 33 | - When describing poverty impacts, note that the poverty measure reported is the Supplemental Poverty Measure 34 | 35 | - Don't use headers, but do use Markdown formatting. Use - for bullets, and include a newline after each bullet. 36 | 37 | - Include the following embeds inline, without a header so it flows. 38 | 39 | - Immediately after you describe the changes by decile, include the text: '{{distributionalImpact.incomeDecile.relative}}' 40 | 41 | - And after the poverty rate changes, include the text: '{{povertyImpact.regular.byAge}}' 42 | 43 | - After the racial breakdown of poverty rate changes, include the text: '{{povertyImpact.regular.byRace}}' 44 | 45 | - And after the inequality changes, include the text: "{{inequalityImpact}}" 46 | 47 | - Make sure to accurately represent the changes observed in the data. 
48 | 49 | - This JSON snippet describes the default parameter values: [{'parameter1': 100, 'parameter2': 200}] 50 | 51 | - This JSON snippet describes the baseline and reform policies being compared: {'gov.test.parameter': 0.1} 52 | 53 | - policy_label has the following impacts from the PolicyEngine microsimulation model: 54 | 55 | - This JSON snippet describes the relevant parameters with more details: [{'parameter1': 100, 'parameter2': 200}] 56 | 57 | - This JSON describes the total budgetary impact, the change to tax revenues and benefit 58 | spending (ignore 'households' and 'baseline_net_income': {"baseline": 0.0, "reform": 0.1}) 59 | 60 | - This JSON describes how common different outcomes were at each income decile: {"baseline": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}, "reform": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}} 61 | 62 | - This JSON describes the average and relative changes to income by each income decile: {"baseline": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}, "reform": {"1": 0.1, "2": 0.2, "3": 0.3, "4": 0.4, "5": 0.5, "6": 0.6, "7": 0.7, "8": 0.8, "9": 0.9, "10": 1.0}} 63 | 64 | - This JSON describes the baseline and reform poverty rates by age group 65 | (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 66 | 67 | - This JSON describes the baseline and reform deep poverty rates by age group 68 | (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 69 | 70 | - This JSON describes the baseline and reform poverty and deep poverty rates 71 | by gender (briefly describe the relative changes): {"baseline": 0.1, "reform": 0.2} 72 | 73 | - This JSON describes the baseline and reform poverty impacts by racial group (briefly describe the relative changes): {"HISPANIC": 0.1, "WHITE": 0.2} 74 | 75 | - This JSON describes three inequality metrics in 
the baseline and reform, the Gini 76 | coefficient of income inequality, the share of income held by the top 10% of households 77 | and the share held by the top 1% (describe the relative changes): {"baseline": 0.1, "reform": 0.2} 78 | 79 | Write this for a policy analyst who knows a bit about economics and policy. 80 | -------------------------------------------------------------------------------- /tests/to_refactor/api/test_api.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | import pytest 4 | from pathlib import Path 5 | from policyengine_api.api import app 6 | 7 | 8 | @pytest.fixture 9 | def client(): 10 | app.config["TESTING"] = True 11 | with app.test_client() as client: 12 | yield client 13 | 14 | 15 | # Each YAML file in the tests directory is a test case. 16 | # The test case is a dictionary with the following keys: 17 | 18 | # - name: the name of the test case 19 | # - endpoint: the endpoint to test 20 | # - method: the HTTP method to use 21 | # - data: the data to send to the endpoint if the method is POST 22 | # - expected_status: the expected HTTP status code 23 | # - expected_result: the expected result of the endpoint 24 | 25 | test_paths = [ 26 | path 27 | for path in (Path(__file__).parent).rglob("*") 28 | if path.suffix == ".yaml" 29 | ] 30 | test_data = [yaml.safe_load(path.read_text()) for path in test_paths] 31 | test_names = [test["name"] for test in test_data] 32 | 33 | 34 | def assert_response_data_matches_expected(data: dict, expected: dict): 35 | # For every key in expected, check that the corresponding key in data 36 | # has the same value. 
37 | for key, value in expected.items(): 38 | if key not in data: 39 | raise ValueError(f"Key {key} not found in response data.\n {data}") 40 | if isinstance(value, dict): 41 | assert_response_data_matches_expected(data[key], value) 42 | elif data[key] != value: 43 | raise ValueError( 44 | f"Value {data[key]} for key {key} does not match expected value {value}.\n {data}" 45 | ) 46 | 47 | 48 | @pytest.mark.parametrize("test", test_data, ids=test_names) 49 | def test_response(client, test: dict): 50 | if test.get("method", "GET") == "GET": 51 | response = client.get(test["endpoint"]) 52 | elif test.get("method") == "POST": 53 | response = client.post( 54 | test["endpoint"], 55 | data=json.dumps(test["data"]), 56 | content_type="application/json", 57 | ) 58 | elif test.get("method") == "PUT": 59 | response = client.put( 60 | test["endpoint"], 61 | data=json.dumps(test["data"]), 62 | content_type="application/json", 63 | ) 64 | else: 65 | raise ValueError(f"Unknown HTTP method: {test['method']}") 66 | 67 | assert response.status_code == test.get("response", {}).get("status", 200) 68 | if "data" in test.get("response", {}): 69 | assert_response_data_matches_expected( 70 | json.loads(response.data), test.get("response", {}).get("data", {}) 71 | ) 72 | elif "html" in test.get("response", {}): 73 | assert response.data.decode("utf-8") == test.get("response", {}).get( 74 | "html", "" 75 | ) 76 | -------------------------------------------------------------------------------- /tests/to_refactor/api/test_hello_world.yaml: -------------------------------------------------------------------------------- 1 | name: Hello World 2 | endpoint: / 3 | response: 4 | status: 200 5 | -------------------------------------------------------------------------------- /tests/to_refactor/api/test_liveness.yaml: -------------------------------------------------------------------------------- 1 | name: Liveness 2 | endpoint: /liveness-check 3 | response: 4 | status: 200 5 | 
-------------------------------------------------------------------------------- /tests/to_refactor/api/test_readiness.yaml: -------------------------------------------------------------------------------- 1 | name: Readiness 2 | endpoint: /readiness-check 3 | response: 4 | status: 200 5 | -------------------------------------------------------------------------------- /tests/to_refactor/api/test_uk_baseline_policy.yaml: -------------------------------------------------------------------------------- 1 | name: UK baseline policy 2 | endpoint: /uk/policy/1 3 | method: GET 4 | data: {} # only for POST 5 | response: 6 | data: 7 | result: 8 | label: Current law 9 | status: 200 10 | -------------------------------------------------------------------------------- /tests/to_refactor/api/test_uk_metadata.yaml: -------------------------------------------------------------------------------- 1 | name: UK metadata 2 | endpoint: /uk/metadata 3 | response: 4 | status: 200 5 | -------------------------------------------------------------------------------- /tests/to_refactor/api/test_us_create_empty_household.yaml: -------------------------------------------------------------------------------- 1 | name: create US empty household 2 | endpoint: /us/household 3 | method: POST 4 | data: 5 | label: Empty Household 6 | data: {} 7 | response: 8 | data: 9 | status: ok 10 | status: 201 11 | -------------------------------------------------------------------------------- /tests/to_refactor/fixtures/simulation_analysis_fixtures.py: -------------------------------------------------------------------------------- 1 | test_impact = { 2 | "budget": { 3 | "baseline": 0.0, 4 | "reform": 0.1, 5 | "change": 0.2, 6 | }, 7 | "intra_decile": { 8 | "baseline": { 9 | "1": 0.1, 10 | "2": 0.2, 11 | "3": 0.3, 12 | "4": 0.4, 13 | "5": 0.5, 14 | "6": 0.6, 15 | "7": 0.7, 16 | "8": 0.8, 17 | "9": 0.9, 18 | "10": 1.0, 19 | }, 20 | "reform": { 21 | "1": 0.1, 22 | "2": 0.2, 23 | "3": 0.3, 24 | "4": 0.4, 25 | 
"5": 0.5, 26 | "6": 0.6, 27 | "7": 0.7, 28 | "8": 0.8, 29 | "9": 0.9, 30 | "10": 1.0, 31 | }, 32 | }, 33 | "decile": { 34 | "1": 0.1, 35 | "2": 0.2, 36 | "3": 0.3, 37 | "4": 0.4, 38 | "5": 0.5, 39 | "6": 0.6, 40 | "7": 0.7, 41 | "8": 0.8, 42 | "9": 0.9, 43 | "10": 1.0, 44 | }, 45 | "poverty": { 46 | "poverty": 0.3, 47 | "deep_poverty": 0.4, 48 | }, 49 | "poverty_by_gender": { 50 | "baseline": { 51 | "male": 0.5, 52 | "female": 0.6, 53 | }, 54 | "reform": { 55 | "male": 0.7, 56 | "female": 0.8, 57 | }, 58 | }, 59 | "poverty_by_race": {"poverty": 0.6}, 60 | "inequality": { 61 | "baseline": 0.7, 62 | "reform": 0.8, 63 | "change": 0.9, 64 | }, 65 | } 66 | 67 | test_json = { 68 | "currency": "USD", 69 | "selected_version": "2023", 70 | "time_period": "2023", 71 | "dataset": None, 72 | "impact": test_impact, 73 | "policy_label": "Test Policy", 74 | "policy": dict(policy_json="policy details"), 75 | "region": "US", 76 | "relevant_parameters": [ 77 | {"param1": 100}, 78 | {"param2": 200}, 79 | ], 80 | "relevant_parameter_baseline_values": [ 81 | {"param1": 100}, 82 | {"param2": 200}, 83 | ], 84 | "audience": "Normal", 85 | } 86 | -------------------------------------------------------------------------------- /tests/to_refactor/fixtures/to_refactor_household_fixtures.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pytest 3 | from unittest.mock import patch 4 | 5 | valid_request_body = { 6 | "data": {"people": {"person1": {"age": 30, "income": 50000}}}, 7 | "label": "Test Household", 8 | } 9 | 10 | valid_db_row = { 11 | "id": 10, 12 | "country_id": "us", 13 | "household_json": json.dumps(valid_request_body["data"]), 14 | "household_hash": "some-hash", 15 | "label": "Test Household", 16 | "api_version": "3.0.0", 17 | } 18 | 19 | valid_hash_value = "some-hash" 20 | 21 | 22 | @pytest.fixture 23 | def mock_hash_object(): 24 | """Mock the hash_object function.""" 25 | with patch( 26 | 
"policyengine_api.services.household_service.hash_object" 27 | ) as mock: 28 | mock.return_value = valid_hash_value 29 | yield mock 30 | 31 | 32 | @pytest.fixture 33 | def mock_database(): 34 | """Mock the database module.""" 35 | with patch( 36 | "policyengine_api.services.household_service.database" 37 | ) as mock_db: 38 | yield mock_db 39 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_ai_analysis_service_old.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch, MagicMock 3 | import json 4 | import os 5 | from policyengine_api.services.ai_analysis_service import AIAnalysisService 6 | 7 | test_ai_service = AIAnalysisService() 8 | 9 | 10 | @patch("policyengine_api.services.ai_analysis_service.local_database") 11 | def test_get_existing_analysis_found(mock_db): 12 | mock_db.query.return_value.fetchone.return_value = { 13 | "analysis": "Existing analysis" 14 | } 15 | 16 | prompt = "Test prompt" 17 | output = test_ai_service.get_existing_analysis(prompt) 18 | 19 | assert output == json.dumps("Existing analysis") 20 | 21 | # Check database query 22 | mock_db.query.assert_called_once_with( 23 | f"SELECT analysis FROM analysis WHERE prompt = ?", 24 | (prompt,), 25 | ) 26 | 27 | 28 | @patch("policyengine_api.services.ai_analysis_service.local_database") 29 | def test_get_existing_analysis_not_found(mock_db): 30 | mock_db.query.return_value.fetchone.return_value = None 31 | 32 | prompt = "Test prompt" 33 | result = test_ai_service.get_existing_analysis(prompt) 34 | 35 | assert result is None 36 | mock_db.query.assert_called_once_with( 37 | f"SELECT analysis FROM analysis WHERE prompt = ?", 38 | (prompt,), 39 | ) 40 | 41 | 42 | # Additional test to check environment variable 43 | def test_anthropic_api_key(): 44 | with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_key"}): 45 | assert os.getenv("ANTHROPIC_API_KEY") == 
"test_key" 46 | 47 | 48 | # Test error handling in trigger_ai_analysis 49 | @patch("policyengine_api.services.ai_analysis_service.anthropic.Anthropic") 50 | def test_trigger_ai_analysis_error(mock_anthropic): 51 | mock_client = MagicMock() 52 | mock_anthropic.return_value = mock_client 53 | mock_client.messages.stream.side_effect = Exception("API Error") 54 | 55 | prompt = "Test prompt" 56 | generator = test_ai_service.trigger_ai_analysis(prompt) 57 | 58 | # The generator should stop after the initial yield due to the error 59 | with pytest.raises(Exception, match="API Error"): 60 | list(generator) 61 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_calculate_us_1.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Test: run a few calls to /calculate, running them with --durations=0 should 3 | # show that chaching is working (the ones suffixed by _repeat should be hits 4 | # and run much faster than their equivalent without the _repeat suffix). 
5 | """ 6 | 7 | 8 | def test_calculate_us_1(rest_client): 9 | """This should be a cache miss as no other requests have been made yet.""" 10 | response = rest_client.post( 11 | "/us/calculate", 12 | headers={"Content-Type": "application/json"}, 13 | data=open( 14 | "./tests/data/calculate_us_1_data.json", 15 | "r", 16 | encoding="utf-8", 17 | ), 18 | ) 19 | assert response.status_code == 200, response.text 20 | 21 | 22 | def test_calculate_us_2(rest_client): 23 | """This should be a miss as the data is different to test_calculate_us_1""" 24 | response = rest_client.post( 25 | "/us/calculate", 26 | headers={"Content-Type": "application/json"}, 27 | data=open( 28 | "./tests/data/calculate_us_2_data.json", 29 | "r", 30 | encoding="utf-8", 31 | ), 32 | ) 33 | assert response.status_code == 200, response.text 34 | 35 | 36 | def test_calculate_us_1_repeat_1(rest_client): 37 | """This should be a hit as the data is the same as test_calculate_us_1""" 38 | response = rest_client.post( 39 | "/us/calculate", 40 | headers={"Content-Type": "application/json"}, 41 | data=open( 42 | "./tests/data/calculate_us_1_data.json", 43 | "r", 44 | encoding="utf-8", 45 | ), 46 | ) 47 | assert response.status_code == 200, response.text 48 | 49 | 50 | def test_calculate_us_2_repeat_1(rest_client): 51 | """This should be a cache hit as the data is the same as 52 | test_calculate_us_2 53 | """ 54 | response = rest_client.post( 55 | "/us/calculate", 56 | headers={"Content-Type": "application/json"}, 57 | data=open( 58 | "./tests/data/calculate_us_2_data.json", 59 | "r", 60 | encoding="utf-8", 61 | ), 62 | ) 63 | assert response.status_code == 200, response.text 64 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_economy_1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test basics for /economy endpoints. 
3 | """ 4 | 5 | # Temporarily disabling until this test can be refactored 6 | 7 | # import json 8 | # import time 9 | # import datetime 10 | # from policyengine_api.data import local_database 11 | # 12 | # 13 | # def test_economy_1(rest_client): 14 | # """Add a simple policy and get /economy for that over 2.""" 15 | # 16 | # with open( 17 | # "./tests/python/data/test_economy_1_policy_1.json", 18 | # "r", 19 | # encoding="utf-8", 20 | # ) as f: 21 | # data_object = json.load(f) 22 | # 23 | # local_database.query("DELETE FROM reform_impact WHERE country_id = 'us'") 24 | # policy_create = rest_client.post( 25 | # "/us/policy", 26 | # headers={"Content-Type": "application/json"}, 27 | # json=data_object, 28 | # ) 29 | # assert policy_create.status_code in [200, 201] 30 | # assert policy_create.json["result"] is not None 31 | # policy_id = policy_create.json["result"]["policy_id"] 32 | # assert policy_id is not None 33 | # policy_response = rest_client.get(f"/us/policy/{policy_id}") 34 | # assert policy_response.status_code == 200 35 | # assert policy_response.json["status"] == "ok" 36 | # assert policy_response.json["result"]["id"] == int(policy_id) 37 | # policy = policy_response.json["result"]["policy_json"] 38 | # assert policy is not None 39 | # assert policy["gov.abolitions.ccdf_income"] is not None 40 | # assert policy["gov.irs.income.exemption.amount"] is not None 41 | # assert ( 42 | # policy["gov.abolitions.ccdf_income"]["2023-01-01.2028-12-31"] is True 43 | # ) 44 | # assert ( 45 | # policy["gov.irs.income.exemption.amount"]["2023-01-01.2028-12-31"] 46 | # == "100" 47 | # ) 48 | # query = f"/us/economy/{policy_id}/over/2?region=us&time_period=2023" 49 | # economy_response = rest_client.get(query) 50 | # assert economy_response.status_code == 200 51 | # assert economy_response.json["status"] == "computing", ( 52 | # f'Expected first answer status to be "computing" but it is ' 53 | # f'{str(economy_response.json["status"])}' 54 | # ) 55 | # while 
economy_response.json["status"] == "computing": 56 | # print("Before sleep:", datetime.datetime.now()) 57 | # time.sleep(3) 58 | # print("After sleep:", datetime.datetime.now()) 59 | # economy_response = rest_client.get(query) 60 | # print(json.dumps(economy_response.json)) 61 | # assert ( 62 | # economy_response.json["status"] == "ok" 63 | # ), f'Expected status "ok", got {economy_response.json["status"]}' 64 | # 65 | # local_database.query( 66 | # f"DELETE FROM policy WHERE id = ? ", 67 | # (policy_id,), 68 | # ) 69 | # 70 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_error_routes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from flask import Flask 3 | from policyengine_api.routes.error_routes import error_bp 4 | from werkzeug.exceptions import ( 5 | NotFound, 6 | BadRequest, 7 | Unauthorized, 8 | Forbidden, 9 | InternalServerError, 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def app(): 15 | """Create and configure a new app instance for each test.""" 16 | app = Flask(__name__) 17 | app.register_blueprint(error_bp) 18 | return app 19 | 20 | 21 | @pytest.fixture 22 | def client(app): 23 | """Create a test client for the app.""" 24 | return app.test_client() 25 | 26 | 27 | def test_404_handler(app, client): 28 | """Test 404 Not Found error handling""" 29 | 30 | @app.route("/nonexistent") 31 | def nonexistent(): 32 | raise NotFound("Custom not found message") 33 | 34 | response = client.get("/nonexistent") 35 | data = response.get_json() 36 | 37 | assert response.status_code == 404 38 | assert data["status"] == "error" 39 | assert "Custom not found message" in data["message"] 40 | assert data["result"] is None 41 | 42 | 43 | def test_400_handler(app, client): 44 | """Test 400 Bad Request error handling""" 45 | 46 | @app.route("/bad-request") 47 | def bad_request(): 48 | raise BadRequest("Invalid parameters") 49 | 50 | response = 
client.get("/bad-request") 51 | data = response.get_json() 52 | 53 | assert response.status_code == 400 54 | assert data["status"] == "error" 55 | assert "Invalid parameters" in data["message"] 56 | assert data["result"] is None 57 | 58 | 59 | def test_401_handler(app, client): 60 | """Test 401 Unauthorized error handling""" 61 | 62 | @app.route("/unauthorized") 63 | def unauthorized(): 64 | raise Unauthorized("Invalid credentials") 65 | 66 | response = client.get("/unauthorized") 67 | data = response.get_json() 68 | 69 | assert response.status_code == 401 70 | assert data["status"] == "error" 71 | assert "Invalid credentials" in data["message"] 72 | assert data["result"] is None 73 | 74 | 75 | def test_403_handler(app, client): 76 | """Test 403 Forbidden error handling""" 77 | 78 | @app.route("/forbidden") 79 | def forbidden(): 80 | raise Forbidden("Access denied") 81 | 82 | response = client.get("/forbidden") 83 | data = response.get_json() 84 | 85 | assert response.status_code == 403 86 | assert data["status"] == "error" 87 | assert "Access denied" in data["message"] 88 | assert data["result"] is None 89 | 90 | 91 | def test_500_handler(app, client): 92 | """Test 500 Internal Server Error handling""" 93 | 94 | @app.route("/server-error") 95 | def server_error(): 96 | raise InternalServerError("Database connection failed") 97 | 98 | response = client.get("/server-error") 99 | data = response.get_json() 100 | 101 | assert response.status_code == 500 102 | assert data["status"] == "error" 103 | assert "Database connection failed" in data["message"] 104 | assert data["result"] is None 105 | 106 | 107 | def test_generic_exception_handler(app, client): 108 | """Test handling of generic exceptions""" 109 | 110 | @app.route("/generic-error") 111 | def generic_error(): 112 | raise ValueError("Something went wrong") 113 | 114 | response = client.get("/generic-error") 115 | data = response.get_json() 116 | 117 | assert response.status_code == 500 118 | assert 
data["status"] == "error" 119 | assert "Something went wrong" in data["message"] 120 | assert data["result"] is None 121 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_policy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import json 3 | import time 4 | import sqlite3 5 | from policyengine_api.data import database 6 | from policyengine_api.utils import hash_object 7 | from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS 8 | 9 | 10 | class TestPolicyCreation: 11 | # Define the policy to test against 12 | country_id = "us" 13 | policy_json = {"sample_parameter": {"2024-01-01.2025-12-31": True}} 14 | label = "dworkin" 15 | test_policy = {"data": policy_json, "label": label} 16 | policy_hash = hash_object(policy_json) 17 | 18 | """ 19 | Test creating a policy, then ensure that duplicating 20 | that policy generates the correct response within the 21 | app; this requires sequential processing, hence the 22 | need for a separate Python-based test 23 | """ 24 | 25 | def test_create_unique_policy(self, rest_client): 26 | database.query( 27 | f"DELETE FROM policy WHERE policy_hash = ? AND label = ? AND country_id = ?", 28 | (self.policy_hash, self.label, self.country_id), 29 | ) 30 | 31 | res = rest_client.post("/us/policy", json=self.test_policy) 32 | return_object = json.loads(res.text) 33 | 34 | assert return_object["status"] == "ok" 35 | assert res.status_code == 201 36 | 37 | def test_create_nonunique_policy(self, rest_client): 38 | res = rest_client.post("/us/policy", json=self.test_policy) 39 | return_object = json.loads(res.text) 40 | 41 | assert return_object["status"] == "ok" 42 | assert res.status_code == 200 43 | 44 | database.query( 45 | f"DELETE FROM policy WHERE policy_hash = ? AND label = ? 
AND country_id = ?", 46 | (self.policy_hash, self.label, self.country_id), 47 | ) 48 | 49 | def test_create_policy_invalid_country(self, rest_client): 50 | res = rest_client.post("/au/policy", json=self.test_policy) 51 | assert res.status_code == 400 52 | 53 | 54 | # The below test is very prone to race conditions due to emission against live db; 55 | # should be refactored to use mock/test db and reinstated 56 | 57 | # class TestPolicySearch: 58 | # country_id = "us" 59 | # policy_json = {"sample_input": {"2023-01-01.2024-12-31": True}} 60 | # label = "maxwell" 61 | # policy_hash = hash_object(policy_json) 62 | # api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) 63 | # 64 | # # Pre-seed database with duplicate policies 65 | # for i in range(2): 66 | # database.query( 67 | # f"INSERT INTO policy (country_id, label, policy_json, policy_hash, api_version) VALUES (?, ?, ?, ?, ?)", 68 | # ( 69 | # country_id, 70 | # label, 71 | # json.dumps(policy_json), 72 | # policy_hash, 73 | # api_version, 74 | # ), 75 | # ) 76 | # 77 | # db_output = database.query( 78 | # f"SELECT * FROM policy WHERE label = ?", 79 | # (label,), 80 | # ).fetchall() 81 | # 82 | # def test_search_all_policies(self, rest_client): 83 | # res = rest_client.get("/us/policies") 84 | # return_object = json.loads(res.text) 85 | # 86 | # filtered_return = list( 87 | # filter(lambda x: x["label"] == self.label, return_object["result"]) 88 | # ) 89 | # 90 | # assert return_object["status"] == "ok" 91 | # assert len(filtered_return) == len(self.db_output) 92 | # 93 | # def test_search_unique_policies(self, rest_client): 94 | # res = rest_client.get("/us/policies?unique_only=true") 95 | # return_object = json.loads(res.text) 96 | # 97 | # filtered_return = list( 98 | # filter(lambda x: x["label"] == self.label, return_object["result"]) 99 | # ) 100 | # 101 | # assert return_object["status"] == "ok" 102 | # assert len(filtered_return) == 1 103 | # 104 | # # Clean up duplicate policies created 105 | # 
database.query( 106 | # f"DELETE FROM policy WHERE policy_hash = ? AND label = ? AND country_id = ?", 107 | # (self.policy_hash, self.label, self.country_id), 108 | # ) 109 | # 110 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_simulation_analysis_routes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch 3 | from flask import Flask 4 | 5 | from policyengine_api.services.simulation_analysis_service import ( 6 | SimulationAnalysisService, 7 | ) 8 | from policyengine_api.routes.simulation_analysis_routes import ( 9 | execute_simulation_analysis, 10 | ) 11 | 12 | from tests.to_refactor.fixtures.simulation_analysis_fixtures import ( 13 | test_json, 14 | test_impact, 15 | ) 16 | 17 | test_service = SimulationAnalysisService() 18 | 19 | 20 | def test_execute_simulation_analysis_existing_analysis(rest_client): 21 | 22 | with patch( 23 | "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" 24 | ) as mock_get_existing: 25 | mock_get_existing.return_value = "Existing analysis" 26 | 27 | response = rest_client.post("/us/simulation-analysis", json=test_json) 28 | 29 | assert response.status_code == 200 30 | assert "Existing analysis" in response.json["result"] 31 | 32 | 33 | def test_execute_simulation_analysis_new_analysis(rest_client): 34 | with patch( 35 | "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" 36 | ) as mock_get_existing: 37 | mock_get_existing.return_value = None 38 | with patch( 39 | "policyengine_api.services.simulation_analysis_service.AIAnalysisService.trigger_ai_analysis" 40 | ) as mock_trigger: 41 | mock_trigger.return_value = (s for s in ["New analysis"]) 42 | 43 | response = rest_client.post( 44 | "/us/simulation-analysis", json=test_json 45 | ) 46 | 47 | assert response.status_code == 200 48 | assert b"New analysis" in 
response.data 49 | 50 | 51 | def test_execute_simulation_analysis_error(rest_client): 52 | with patch( 53 | "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" 54 | ) as mock_get_existing: 55 | mock_get_existing.return_value = None 56 | with patch( 57 | "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" 58 | ) as mock_trigger: 59 | mock_trigger.side_effect = Exception("Test error") 60 | 61 | response = rest_client.post( 62 | "/us/simulation-analysis", json=test_json 63 | ) 64 | 65 | assert response.status_code == 500 66 | assert "Test error" in response.json.get("message") 67 | 68 | 69 | def test_execute_simulation_analysis_enhanced_cps(rest_client): 70 | policy_details = dict(policy_json="policy details") 71 | 72 | test_json_enhanced_cps = { 73 | "currency": "USD", 74 | "selected_version": "2023", 75 | "time_period": "2023", 76 | "impact": test_impact, 77 | "policy_label": "Test Policy", 78 | "policy": policy_details, 79 | "region": "us", 80 | "dataset": "enhanced_cps", 81 | "relevant_parameters": ["param1", "param2"], 82 | "relevant_parameter_baseline_values": [ 83 | {"param1": 100}, 84 | {"param2": 200}, 85 | ], 86 | "audience": "Normal", 87 | } 88 | with patch( 89 | "policyengine_api.services.simulation_analysis_service.SimulationAnalysisService._generate_simulation_analysis_prompt" 90 | ) as mock_generate_prompt: 91 | with patch( 92 | "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" 93 | ) as mock_get_existing: 94 | mock_get_existing.return_value = None 95 | with patch( 96 | "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" 97 | ) as mock_trigger: 98 | mock_trigger.return_value = ( 99 | s for s in ["Enhanced CPS analysis"] 100 | ) 101 | 102 | response = rest_client.post( 103 | "/us/simulation-analysis", json=test_json_enhanced_cps 104 | ) 105 | 106 | assert response.status_code == 200 107 | assert b"Enhanced CPS 
analysis" in response.data 108 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_tracer_analysis_routes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from flask import json 3 | from unittest.mock import patch 4 | 5 | 6 | @patch("policyengine_api.services.tracer_analysis_service.local_database") 7 | @patch( 8 | "policyengine_api.services.tracer_analysis_service.TracerAnalysisService.trigger_ai_analysis" 9 | ) 10 | def test_execute_tracer_analysis_success( 11 | mock_trigger_ai_analysis, mock_db, rest_client 12 | ): 13 | mock_db.query.return_value.fetchone.return_value = { 14 | "tracer_output": json.dumps( 15 | ["disposable_income <1000>", " market_income <1000>"] 16 | ) 17 | } 18 | mock_trigger_ai_analysis.return_value = "AI analysis result" 19 | test_household_id = 1500 20 | 21 | # Set this to US current law 22 | test_policy_id = 2 23 | 24 | response = rest_client.post( 25 | "/us/tracer-analysis", 26 | json={ 27 | "household_id": test_household_id, 28 | "policy_id": test_policy_id, 29 | "variable": "disposable_income", 30 | }, 31 | ) 32 | 33 | assert response.status_code == 200 34 | assert b"AI analysis result" in response.data 35 | 36 | 37 | @patch("policyengine_api.services.tracer_analysis_service.local_database") 38 | def test_execute_tracer_analysis_no_tracer(mock_db, rest_client): 39 | mock_db.query.return_value.fetchone.return_value = None 40 | 41 | response = rest_client.post( 42 | "/us/tracer-analysis", 43 | json={ 44 | "household_id": "test_household", 45 | "policy_id": "test_policy", 46 | "variable": "disposable_income", 47 | }, 48 | ) 49 | 50 | assert response.status_code == 404 51 | assert ( 52 | "No household simulation tracer found" 53 | in json.loads(response.data)["message"] 54 | ) 55 | 56 | 57 | @patch("policyengine_api.services.tracer_analysis_service.local_database") 58 | @patch( 59 | 
"policyengine_api.services.tracer_analysis_service.TracerAnalysisService.trigger_ai_analysis" 60 | ) 61 | def test_execute_tracer_analysis_ai_error( 62 | mock_trigger_ai_analysis, mock_db, rest_client 63 | ): 64 | mock_db.query.return_value.fetchone.return_value = { 65 | "tracer_output": json.dumps( 66 | ["disposable_income <1000>", " market_income <1000>"] 67 | ) 68 | } 69 | mock_trigger_ai_analysis.side_effect = Exception(KeyError) 70 | 71 | test_household_id = 1500 72 | test_policy_id = 2 73 | 74 | # Use the test client to make the request instead of calling the function directly 75 | response = rest_client.post( 76 | "/us/tracer-analysis", 77 | json={ 78 | "household_id": test_household_id, 79 | "policy_id": test_policy_id, 80 | "variable": "disposable_income", 81 | }, 82 | ) 83 | 84 | assert response.status_code == 500 85 | assert json.loads(response.data)["status"] == "error" 86 | 87 | 88 | # Test invalid country 89 | def test_invalid_country(rest_client): 90 | response = rest_client.post( 91 | "/invalid_country/tracer-analysis", 92 | json={ 93 | "household_id": "test_household", 94 | "policy_id": "test_policy", 95 | "variable": "disposable_income", 96 | }, 97 | ) 98 | assert response.status_code == 400 99 | assert b"Country invalid_country not found" in response.data 100 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_us_policy_macro.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import datetime 4 | from policyengine_api.data import local_database 5 | import sys 6 | 7 | # This test is a temporary test to ensure state impacts work properly. 8 | # It should be replaced with more comprehensive integration tests, then removed. 
9 | 10 | 11 | def test_utah(rest_client): 12 | """ 13 | Test that the a given Utah policy is calculated 14 | and provides logical outputs for a sim 15 | """ 16 | 17 | return utah_reform_runner(rest_client, "ut") 18 | 19 | 20 | def utah_reform_runner(rest_client, region: str = "us"): 21 | """ 22 | Run the given Utah policy test, depending on provided 23 | region (defaults to "us") 24 | """ 25 | 26 | test_year = 2025 27 | default_policy = 2 28 | 29 | with open( 30 | "./tests/data/utah_reform.json", 31 | "r", 32 | encoding="utf-8", 33 | ) as f: 34 | data_object = json.load(f) 35 | 36 | local_database.query("DELETE FROM reform_impact WHERE country_id = 'us'") 37 | 38 | policy_create = rest_client.post( 39 | "/us/policy", 40 | headers={"Content-Type": "application/json"}, 41 | json=data_object, 42 | ) 43 | 44 | assert policy_create.status_code in [200, 201] 45 | assert policy_create.json["result"] is not None 46 | 47 | policy_id = policy_create.json["result"]["policy_id"] 48 | assert policy_id is not None 49 | 50 | query = f"/us/economy/{policy_id}/over/{default_policy}?region={region}&time_period={test_year}" 51 | economy_response = rest_client.get(query) 52 | assert economy_response.status_code == 200 53 | assert economy_response.json["status"] == "computing", ( 54 | f'Expected first answer status to be "computing" but it is ' 55 | f'{str(economy_response.json["status"])}' 56 | ) 57 | while economy_response.json["status"] == "computing": 58 | print("Before sleep:", datetime.datetime.now()) 59 | time.sleep(3) 60 | print("After sleep:", datetime.datetime.now()) 61 | economy_response = rest_client.get(query) 62 | print(json.dumps(economy_response.json)) 63 | assert ( 64 | economy_response.json["status"] == "ok" 65 | ), f'Expected status "ok", got {economy_response.json["status"]} with message "{economy_response.json}"' 66 | 67 | result = economy_response.json["result"] 68 | 69 | assert result is not None 70 | 71 | # Ensure that there is some budgetary impact 72 | cost = 
round(result["budget"]["budgetary_impact"] / 1e6, 1) 73 | assert ( 74 | cost / 95.4 - 1 75 | ) < 0.01, ( 76 | f"Expected budgetary impact to be 95.4 million, got {cost} million" 77 | ) 78 | 79 | assert ( 80 | result["intra_decile"]["all"]["Lose less than 5%"] / 0.637 - 1 81 | ) < 0.01, ( 82 | f"Expected 63.7% of people to lose less than 5%, got " 83 | f"{result['intra_decile']['all']['Lose less than 5%']}" 84 | ) 85 | 86 | local_database.query( 87 | f"DELETE FROM policy WHERE id = ? ", 88 | (policy_id,), 89 | ) 90 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_user_profile_routes.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | from policyengine_api.data import database 4 | import time 5 | 6 | 7 | class TestUserProfiles: 8 | # Define the profile to test against 9 | auth0_id = "dworkin" 10 | primary_country = "us" 11 | # Simulate JS's Date.now() 12 | user_since = int(time.time()) 13 | 14 | test_profile = { 15 | "auth0_id": auth0_id, 16 | "primary_country": primary_country, 17 | "user_since": user_since, 18 | } 19 | 20 | """ 21 | Test adding a record to user_profiles 22 | """ 23 | 24 | def test_set_and_get_record(self, rest_client): 25 | database.query( 26 | f"DELETE FROM user_profiles WHERE auth0_id = ? 
AND primary_country = ?", 27 | ( 28 | self.auth0_id, 29 | self.primary_country, 30 | ), 31 | ) 32 | 33 | res = rest_client.post("/us/user-profile", json=self.test_profile) 34 | return_object = json.loads(res.text) 35 | 36 | assert return_object["status"] == "ok" 37 | assert res.status_code == 201 38 | 39 | res = rest_client.get(f"/us/user-profile?auth0_id={self.auth0_id}") 40 | return_object = json.loads(res.text) 41 | 42 | assert res.status_code == 200 43 | assert return_object["status"] == "ok" 44 | assert return_object["result"]["auth0_id"] == self.auth0_id 45 | assert ( 46 | return_object["result"]["primary_country"] == self.primary_country 47 | ) 48 | assert return_object["result"]["username"] == None 49 | 50 | user_id = return_object["result"]["user_id"] 51 | 52 | res = rest_client.get(f"/us/user-profile?user_id={user_id}") 53 | return_object = json.loads(res.text) 54 | 55 | assert res.status_code == 200 56 | assert return_object["status"] == "ok" 57 | assert ( 58 | return_object["result"]["primary_country"] == self.primary_country 59 | ) 60 | assert return_object["result"].get("auth0_id") is None 61 | assert return_object["result"]["username"] == None 62 | 63 | test_username = "maxwell" 64 | updated_profile = {"user_id": user_id, "username": test_username} 65 | 66 | res = rest_client.put("/us/user-profile", json=updated_profile) 67 | return_object = json.loads(res.text) 68 | 69 | assert return_object["status"] == "ok" 70 | assert res.status_code == 200 71 | 72 | row = database.query( 73 | f"SELECT * FROM user_profiles WHERE user_id = ? 
AND username = ?", 74 | (user_id, test_username), 75 | ).fetchone() 76 | assert row is not None 77 | 78 | malicious_updated_profile = {**updated_profile, "auth0_id": "BOGUS"} 79 | 80 | res = rest_client.put( 81 | "/us/user-profile", json=malicious_updated_profile 82 | ) 83 | return_object = json.loads(res.text) 84 | 85 | assert res.status_code == 200 86 | 87 | row = database.query( 88 | f"SELECT * FROM user_profiles WHERE username = ?", 89 | (test_username,), 90 | ).fetchone() 91 | 92 | assert row["auth0_id"] == self.auth0_id 93 | 94 | database.query( 95 | f"DELETE FROM user_profiles WHERE user_id = ? AND auth0_id = ? AND primary_country = ?", 96 | (user_id, self.auth0_id, self.primary_country), 97 | ) 98 | 99 | def test_non_existent_record(self, rest_client): 100 | non_existent_auth0_id = 15303 101 | 102 | res = rest_client.get( 103 | f"/us/user-profile?auth0_id={non_existent_auth0_id}" 104 | ) 105 | return_object = json.loads(res.text) 106 | 107 | assert res.status_code == 404 108 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_validate_country.py: -------------------------------------------------------------------------------- 1 | from flask import Response 2 | from policyengine_api.utils.payload_validators import validate_country 3 | 4 | 5 | @validate_country 6 | def foo(country_id, other): 7 | """ 8 | A simple dummy test method for validation testing. Must be defined outside of the class (or within 9 | the test functions themselves) due to complications with the `self` parameter for class methods. 10 | """ 11 | return "bar" 12 | 13 | 14 | class TestValidateCountry: 15 | """ 16 | Test that the @validate_country decorator returns 404 if the country does not exist, otherwise 17 | continues execution of the function. 
18 | """ 19 | 20 | def test_valid_country(self): 21 | result = foo("us", "extra_arg") 22 | assert result == "bar" 23 | 24 | def test_invalid_country(self): 25 | result = foo("baz", "extra_arg") 26 | assert isinstance(result, Response) 27 | assert result.status_code == 400 28 | -------------------------------------------------------------------------------- /tests/to_refactor/python/test_validate_household_payload.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | class TestHouseholdRouteValidation: 5 | """Test validation and error handling in household routes.""" 6 | 7 | @pytest.mark.parametrize( 8 | "invalid_payload", 9 | [ 10 | {}, # Empty payload 11 | {"label": "Test"}, # Missing data field 12 | {"data": None}, # None data 13 | {"data": "not_a_dict"}, # Non-dict data 14 | {"data": {}, "label": 123}, # Invalid label type 15 | ], 16 | ) 17 | def test_post_household_invalid_payload( 18 | self, rest_client, invalid_payload 19 | ): 20 | """Test POST endpoint with various invalid payloads.""" 21 | response = rest_client.post( 22 | "/us/household", 23 | json=invalid_payload, 24 | content_type="application/json", 25 | ) 26 | 27 | assert response.status_code == 400 28 | assert b"Unable to create new household" in response.data 29 | 30 | @pytest.mark.parametrize( 31 | "invalid_id", 32 | [ 33 | "abc", # Non-numeric 34 | "1.5", # Float 35 | ], 36 | ) 37 | def test_get_household_invalid_id(self, rest_client, invalid_id): 38 | """Test GET endpoint with invalid household IDs.""" 39 | response = rest_client.get(f"/us/household/{invalid_id}") 40 | 41 | # Default Werkzeug validation returns 404, not 400 42 | assert response.status_code == 404 43 | assert ( 44 | b"The requested URL was not found on the server" in response.data 45 | ) 46 | 47 | @pytest.mark.parametrize( 48 | "country_id", 49 | [ 50 | "123", # Numeric 51 | "us!!", # Special characters 52 | "zz", # Non-ISO 53 | "a" * 100, # Too long 54 | ], 55 | ) 56 | 
import pytest

# Payloads the POST /household endpoint must reject.
_INVALID_POST_PAYLOADS = [
    {},  # Empty payload
    {"label": "Test"},  # Missing data field
    {"data": None},  # None data
    {"data": "not_a_dict"},  # Non-dict data
    {"data": {}, "label": 123},  # Invalid label type
]

# Household IDs that do not satisfy the integer route converter.
_INVALID_HOUSEHOLD_IDS = [
    "abc",  # Non-numeric
    "1.5",  # Float
]

# Country IDs that must fail country validation.
_INVALID_COUNTRY_IDS = [
    "123",  # Numeric
    "us!!",  # Special characters
    "zz",  # Non-ISO
    "a" * 100,  # Too long
]


class TestHouseholdRouteValidation:
    """Test validation and error handling in household routes."""

    @pytest.mark.parametrize("invalid_payload", _INVALID_POST_PAYLOADS)
    def test_post_household_invalid_payload(
        self, rest_client, invalid_payload
    ):
        """POST with a malformed payload is rejected with a 400."""
        response = rest_client.post(
            "/us/household",
            json=invalid_payload,
            content_type="application/json",
        )

        assert response.status_code == 400
        assert b"Unable to create new household" in response.data

    @pytest.mark.parametrize("invalid_id", _INVALID_HOUSEHOLD_IDS)
    def test_get_household_invalid_id(self, rest_client, invalid_id):
        """GET with a non-integer household ID is rejected by routing."""
        response = rest_client.get(f"/us/household/{invalid_id}")

        # Default Werkzeug validation returns 404, not 400
        assert response.status_code == 404
        assert (
            b"The requested URL was not found on the server" in response.data
        )

    @pytest.mark.parametrize("country_id", _INVALID_COUNTRY_IDS)
    def test_invalid_country_id(self, rest_client, country_id):
        """GET, POST and PUT with an invalid country ID all return a 400."""
        # Test GET
        get_response = rest_client.get(f"/{country_id}/household/1")
        assert get_response.status_code == 400

        # Test POST
        post_response = rest_client.post(
            f"/{country_id}/household",
            json={"data": {}},
            content_type="application/json",
        )
        assert post_response.status_code == 400

        # Test PUT
        put_response = rest_client.put(
            f"/{country_id}/household/1",
            json={"data": {}},
            content_type="application/json",
        )
        assert put_response.status_code == 400
import pytest
from typing import Dict, Any, Tuple

from policyengine_api.utils.payload_validators.ai import (
    validate_sim_analysis_payload,
)

# Every key the validator treats as required; also drives the per-key test.
_REQUIRED_KEYS = [
    "currency",
    "selected_version",
    "time_period",
    "impact",
    "policy_label",
    "policy",
    "region",
    "relevant_parameters",
    "relevant_parameter_baseline_values",
]


@pytest.fixture
def valid_payload() -> Dict[str, Any]:
    """A payload that satisfies every required key and type check."""
    return {
        "currency": "USD",
        "selected_version": "v1.0",
        "time_period": "2024",
        "impact": {"value": 100},
        "policy_label": "Test Policy",
        "policy": {"type": "tax", "rate": 0.1},
        "region": "NA",
        "relevant_parameters": ["param1", "param2"],
        "relevant_parameter_baseline_values": [1.0, 2.0],
    }


def test_valid_payload(valid_payload):
    """A fully-populated payload passes validation with no error."""
    is_valid, error = validate_sim_analysis_payload(valid_payload)
    assert is_valid is True
    assert error is None


def test_missing_required_key(valid_payload):
    """Removing a required key fails validation and names the key."""
    del valid_payload["currency"]
    is_valid, error = validate_sim_analysis_payload(valid_payload)
    assert is_valid is False
    assert "Missing required keys: ['currency']" in error


def test_invalid_string_type(valid_payload):
    """A non-string value for a string field is rejected."""
    valid_payload["currency"] = 123  # Should be string
    is_valid, error = validate_sim_analysis_payload(valid_payload)
    assert is_valid is False
    assert "Key 'currency' must be a string" in error


def test_invalid_dict_type(valid_payload):
    """A non-dict value for a dictionary field is rejected."""
    valid_payload["impact"] = ["not", "a", "dict"]  # Should be dict
    is_valid, error = validate_sim_analysis_payload(valid_payload)
    assert is_valid is False
    assert "Key 'impact' must be a dictionary" in error


def test_invalid_list_type(valid_payload):
    """A non-list value for a list field is rejected."""
    valid_payload["relevant_parameters"] = "not a list"  # Should be list
    is_valid, error = validate_sim_analysis_payload(valid_payload)
    assert is_valid is False
    assert "Key 'relevant_parameters' must be a list" in error


def test_extra_keys_allowed(valid_payload):
    """Unknown keys are tolerated and do not fail validation."""
    valid_payload["extra_key"] = "some value"
    is_valid, error = validate_sim_analysis_payload(valid_payload)
    assert is_valid is True
    assert error is None


@pytest.mark.parametrize("key", _REQUIRED_KEYS)
def test_individual_required_keys(valid_payload, key):
    """Each required key is individually enforced by the validator."""
    del valid_payload[key]
    is_valid, error = validate_sim_analysis_payload(valid_payload)
    assert is_valid is False
    assert f"Missing required keys: ['{key}']" in error
import pytest
from policyengine_api.ai_prompts.simulation_analysis_prompt import (
    generate_simulation_analysis_prompt,
)
from tests.fixtures.simulation_analysis_prompt_fixtures import (
    valid_input_us,
    valid_input_uk,
    invalid_data_missing_input_field,
    given_valid_data_and_dataset_is_enhanced_cps,
)

# All snapshot files for these tests live in one directory.
_SNAPSHOT_DIR = "tests/snapshots"


class TestGenerateSimulationAnalysisPrompt:
    """Snapshot and error-path tests for the simulation analysis prompt builder."""

    def test_given_valid_us_input(self, snapshot):
        # Valid US input renders the US prompt snapshot.
        snapshot.snapshot_dir = _SNAPSHOT_DIR
        prompt = generate_simulation_analysis_prompt(valid_input_us)
        snapshot.assert_match(prompt, "simulation_analysis_prompt_us.txt")

    def test_given_valid_uk_input(self, snapshot):
        # Valid UK input renders the UK prompt snapshot.
        snapshot.snapshot_dir = _SNAPSHOT_DIR
        prompt = generate_simulation_analysis_prompt(valid_input_uk)
        snapshot.assert_match(prompt, "simulation_analysis_prompt_uk.txt")

    def test_given_dataset_is_enhanced_cps(self, snapshot):
        # The enhanced-CPS dataset variant renders its own snapshot.
        snapshot.snapshot_dir = _SNAPSHOT_DIR
        enhanced_cps_input = given_valid_data_and_dataset_is_enhanced_cps(
            valid_input_us
        )

        prompt = generate_simulation_analysis_prompt(enhanced_cps_input)
        snapshot.assert_match(
            prompt, "simulation_analysis_prompt_dataset_enhanced_cps.txt"
        )

    def test_given_missing_input_field(self):
        # A payload missing `time_period` fails pydantic validation.
        with pytest.raises(
            Exception,
            match="1 validation error for InboundParameters\ntime_period\n  Field required",
        ):
            generate_simulation_analysis_prompt(
                invalid_data_missing_input_field
            )
import pytest
import sqlite3
from policyengine_api.data import PolicyEngineDatabase
from policyengine_api.constants import REPO


def _dict_factory(cursor, row):
    """sqlite3 row factory mapping each row to a {column_name: value} dict."""
    return {col[0]: row[idx] for idx, col in enumerate(cursor.description)}


class TestPolicyEngineDatabase(PolicyEngineDatabase):
    """Test version of PolicyEngineDatabase that uses in-memory SQLite."""

    def __init__(self, initialize: bool = True):
        self.local = True  # Always use SQLite for tests
        if initialize:
            self._setup_connection()
            self.initialize()

    def _setup_connection(self):
        """Create the shared in-memory connection (idempotent)."""
        if not hasattr(self, "_connection"):
            self._connection = sqlite3.connect(":memory:")
            # One shared dict row factory; previously this helper was
            # duplicated inline in both _setup_connection and query.
            self._connection.row_factory = _dict_factory

    def initialize(self):
        """
        Override initialize to avoid file operations from parent class.

        Reads the schema SQL shipped with the package and executes each
        statement against the in-memory connection.
        """
        self._setup_connection()

        # Read the SQL initialization file
        init_file = (
            REPO
            / "policyengine_api"
            / "data"
            / f"initialise{'_local' if self.local else ''}.sql"
        )
        with open(init_file, "r") as f:
            full_query = f.read()

        # Naive split on ";" is sufficient here: the schema files contain no
        # string literals with embedded semicolons.
        for query in full_query.split(";"):
            if query.strip():  # Skip empty queries
                self.query(query)

    def query(self, *query):
        """Run a statement on the in-memory connection, commit, return the cursor result.

        BUG FIX: the original fallback reconnected via ``self.db_url`` — an
        attribute this class never sets (AttributeError), and which would in
        any case open a *new* database and lose all prior state. Reuse the
        shared in-memory connection via the idempotent setup instead.
        """
        self._setup_connection()

        cursor = self._connection.cursor()
        result = cursor.execute(*query)
        self._connection.commit()
        return result

    def clean(self):
        """Clear all data from tables while preserving the schema."""
        if hasattr(self, "_connection"):
            cursor = self._connection.cursor()

            # Get all table names
            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            ).fetchall()

            # Disable foreign key checks temporarily so delete order
            # doesn't matter.
            cursor.execute("PRAGMA foreign_keys=OFF")

            # Delete all data from each table
            for table in tables:
                cursor.execute(f"DELETE FROM {table['name']}")

            # Re-enable foreign key checks
            cursor.execute("PRAGMA foreign_keys=ON")

            self._connection.commit()


@pytest.fixture(scope="session")
def test_db():
    """Create a test database instance that persists for the whole test session."""
    db = TestPolicyEngineDatabase(initialize=True)
    yield db
    # Clean up the connection when done
    if hasattr(db, "_connection"):
        db._connection.close()


@pytest.fixture(autouse=True)
def override_database(test_db, monkeypatch):
    """
    Global database override that affects all imports of the database.
    This fixture automatically applies to all tests.
    """
    test_db.clean()

    # Patch at the root module level where database is defined
    import policyengine_api.data

    monkeypatch.setattr(policyengine_api.data, "database", test_db)

    # Also patch the module-level variable for any existing imports
    import sys

    for module_name, module in list(sys.modules.items()):
        if module_name.startswith("policyengine_api."):
            if hasattr(module, "database"):
                monkeypatch.setattr(module, "database", test_db)
            if hasattr(module, "local_database"):
                monkeypatch.setattr(module, "local_database", test_db)

    yield test_db
import json
from policyengine_api.services.ai_analysis_service import AIAnalysisService
from tests.fixtures.services.ai_analysis_service import (
    mock_stream_text_events,
    mock_stream_error_event,
    patch_anthropic,
    parse_to_chunks,
)
import pytest

# Initialize the service
service = AIAnalysisService()


class TestTriggerAIAnalysis:
    """Streaming and error behaviour of AIAnalysisService.trigger_ai_analysis."""

    def test_trigger_ai_analysis_given_successful_streaming(
        self, mock_stream_text_events, test_db
    ):
        # GIVEN a series of successful text messages from the Claude API
        expected_response = "This is a historical quote."
        text_chunks = parse_to_chunks(expected_response)
        mock_stream_text_events(text_chunks=text_chunks)

        # WHEN we call trigger_ai_analysis and drain the generator
        prompt = "Tell me a historical quote"
        results = list(service.trigger_ai_analysis(prompt))

        # THEN each yielded item is a newline-terminated JSON "text" event
        for index, chunk in enumerate(results):
            if index < len(text_chunks):
                expected_chunk = (
                    json.dumps(
                        {"type": "text", "stream": text_chunks[index][:5]}
                    )
                    + "\n"
                )
                assert chunk == expected_chunk

        # AND the complete response was persisted with an "ok" status
        analysis_record = test_db.query(
            "SELECT * FROM analysis WHERE prompt = ?", (prompt,)
        ).fetchone()

        assert analysis_record is not None
        assert analysis_record["analysis"] == expected_response
        assert analysis_record["status"] == "ok"

    @pytest.mark.parametrize(
        "error_type",
        ["overloaded_error", "api_error", "unknown_error"],
    )
    def test_trigger_ai_analysis_given_error(
        self, mock_stream_error_event, test_db, error_type
    ):
        # GIVEN an error event of the given type from the Claude API
        mock_stream_error_event(error_type)

        # WHEN we call trigger_ai_analysis and drain the generator
        prompt = "Tell me a historical quote about erroneous systems"
        results = list(service.trigger_ai_analysis(prompt))

        # THEN the first yielded item is a newline-terminated JSON "error" event
        expected_error = (
            json.dumps({"type": "error", "error": error_type}) + "\n"
        )
        assert results[0] == expected_error

        # AND nothing was persisted for this prompt
        analysis_record = test_db.query(
            "SELECT * FROM analysis WHERE prompt = ?", (prompt,)
        ).fetchone()

        assert analysis_record is None
import pytest
import json
from policyengine_api.services.tracer_analysis_service import (
    TracerAnalysisService,
)
from werkzeug.exceptions import NotFound

from tests.fixtures.services.tracer_analysis_service import (
    sample_tracer_data,
    sample_expected_segment,
    mock_get_tracer,
    mock_get_existing_analysis,
    mock_parse_tracer_output,
    mock_trigger_ai_analysis,
)

service = TracerAnalysisService()
country_id = "us"
household_id = "71424"
policy_id = "2"
target_variable = "takes_up_snap_if_eligible"


class TestExecuteAnalysis:
    """execute_analysis returns a cached ('static') result or a streaming generator."""

    def test_execute_analysis_static(
        self,
        mock_get_tracer,
        mock_parse_tracer_output,
        mock_get_existing_analysis,
    ):
        """
        GIVEN a valid tracer data and an expected parsed segment (included as fixture),
        AND get_existing_analysis returns a static analysis (included as fixture),
        WHEN execute_analysis is called,
        THEN a static analysis with the "static" flag should be returned.
        """
        analysis, analysis_type = service.execute_analysis(
            country_id, household_id, policy_id, target_variable
        )

        assert analysis == "Existing static analysis"
        assert analysis_type == "static"

    def test_execute_analysis_streaming(
        self,
        mock_get_tracer,
        mock_parse_tracer_output,
        mock_get_existing_analysis,
        mock_trigger_ai_analysis,
    ):
        """
        GIVEN a valid tracer data and an expected parsed segment,
        AND get_existing_analysis returns None,
        WHEN execute_analysis is called,
        THEN trigger_ai_analysis is called and returns a generator with the "streaming" flag.
        """
        # Force the "no cached analysis" path.
        mock_get_existing_analysis.return_value = None

        analysis, analysis_type = service.execute_analysis(
            country_id, household_id, policy_id, target_variable
        )

        # The streaming result is a generator; drain it to compare contents.
        assert list(analysis) == ["stream chunk 1", "stream chunk 2"]
        assert analysis_type == "streaming"
import pytest
from policyengine_api.services.metadata_service import MetadataService
from policyengine_api.country import COUNTRIES

# Top-level keys every country's metadata payload must expose.
_EXPECTED_TOP_LEVEL_KEYS = (
    "variables",
    "parameters",
    "entities",
    "variableModules",
    "current_law_id",
    "economy_options",
    "basicInputs",
    "modelled_policies",
    "version",
)


class TestMetadataService:
    """Error handling and per-country content of MetadataService.get_metadata."""

    def test_get_metadata_nonexistent_country(self):
        service = MetadataService()
        # GIVEN a non-existent country ID
        invalid_country_id = "invalid_country"

        # WHEN/THEN: an Exception naming the country is raised.
        # (sic: "nonexistant" matches the service's actual message text)
        with pytest.raises(
            Exception,
            match=f"Attempted to get metadata for a nonexistant country: '{invalid_country_id}'",
        ):
            service.get_metadata(invalid_country_id)

    def test_get_metadata_empty_country_id(self):
        service = MetadataService()
        # GIVEN an empty country ID
        empty_country_id = ""

        # WHEN/THEN: the same Exception is raised for an empty ID.
        with pytest.raises(
            Exception,
            match=f"Attempted to get metadata for a nonexistant country: '{empty_country_id}'",
        ):
            service.get_metadata(empty_country_id)

    @pytest.mark.parametrize(
        "country_id, current_law_id, test_regions",
        [
            (
                "uk",
                1,
                [
                    "uk",
                    "country/england",
                    "country/scotland",
                    "country/wales",
                    "country/ni",
                ],
            ),
            ("us", 2, ["us", "ca", "ny", "tx", "fl"]),
            ("ca", 3, ["ca"]),
            ("ng", 4, ["ng"]),
            ("il", 5, ["il"]),
        ],
    )
    def test_verify_metadata_for_given_country(
        self, country_id, current_law_id, test_regions
    ):
        """
        Verifies metadata for a specific country contains expected values.

        Args:
            country_id: Country identifier string
            current_law_id: Expected current law ID for this country
            test_regions: List of region codes that should be present for this country
        """
        # Get metadata for the country
        service = MetadataService()
        metadata = service.get_metadata(country_id)

        # Verify basic structure: a dict carrying every expected top-level key.
        assert metadata is not None
        assert isinstance(metadata, dict)
        for key in _EXPECTED_TOP_LEVEL_KEYS:
            assert key in metadata

        # Verify country-specific data
        assert metadata["current_law_id"] == current_law_id

        # Verify region data exists and each expected region is present by name.
        assert "region" in metadata["economy_options"]
        regions = metadata["economy_options"]["region"]
        for region in test_regions:
            assert any(
                entry["name"] == region for entry in regions
            ), f"Expected region '{region}' not found"

        # Verify time periods exist and have correct structure
        assert "time_period" in metadata["economy_options"]
        time_periods = metadata["economy_options"]["time_period"]
        assert isinstance(time_periods, list)
        assert len(time_periods) > 0

        # Check time period structure instead of specific values
        for period in time_periods:
            assert "name" in period
            assert "label" in period
            assert isinstance(period["name"], int)
            assert isinstance(period["label"], str)

        # Verify datasets exist and are of correct type
        assert "datasets" in metadata["economy_options"]
        assert isinstance(metadata["economy_options"]["datasets"], list)
import pytest
import json
from policyengine_api.services.tracer_analysis_service import (
    TracerAnalysisService,
)
from werkzeug.exceptions import NotFound

from tests.fixtures.services.tracer_fixture_service import (
    test_tracer_data,
    valid_tracer_row,
    valid_tracer,
)

tracer_service = TracerAnalysisService()


def test_get_tracer_valid(test_tracer_data):
    """get_tracer returns the stored tracer output for a valid lookup key."""
    result = tracer_service.get_tracer(
        test_tracer_data["country_id"],
        test_tracer_data["household_id"],
        test_tracer_data["policy_id"],
        test_tracer_data["api_version"],
    )

    # Must match the fixture's expected tracer output exactly.
    assert result == valid_tracer["tracer_output"]


def test_get_tracer_not_found():
    """get_tracer raises NotFound when no matching record exists."""
    lookup_with_no_match = [
        "us",  # valid country value present in the DB
        "9999999",  # household ID not in the DB
        "999",  # policy ID not in the DB
        "9.999.0",  # API version not in the DB
    ]
    with pytest.raises(NotFound):
        tracer_service.get_tracer(*lookup_with_no_match)


def test_get_tracer_database_error(test_db):
    """get_tracer surfaces lower-level errors (provoked here by an empty country_id)."""
    bad_lookup = [
        "",  # missing country_id triggers the database exception
        "71424",  # valid household ID
        "2",  # valid policy ID
        "1.150.0",  # valid API version
    ]
    with pytest.raises(Exception):
        tracer_service.get_tracer(*bad_lookup)
import pytest
from policyengine_api.services.user_service import UserService

from tests.fixtures.services.user_service import (
    valid_user_record,
    existing_user_profile,
)

service = UserService()


class TestUpdateProfile:
    """UserService.update_profile for existing, missing and failing records."""

    def test_update_profile_given_existing_record(
        self, test_db, existing_user_profile
    ):
        # GIVEN an existing profile record (from fixture)
        # WHEN we update it with a new username and country
        new_username = "updated_username"
        new_country = "uk"

        result = service.update_profile(
            user_id=existing_user_profile["user_id"],
            primary_country=new_country,
            username=new_username,
            user_since=existing_user_profile["user_since"],
        )

        # THEN the update reports success
        assert result is True

        # AND the stored row reflects the new values
        stored = test_db.query(
            "SELECT * FROM user_profiles WHERE user_id = ?",
            (existing_user_profile["user_id"],),
        ).fetchone()

        assert stored["username"] == new_username
        assert stored["primary_country"] == new_country

    def test_update_profile_given_nonexistent_record(self, test_db):
        # GIVEN a profile ID with no matching record
        NONEXISTENT_ID = 999

        # WHEN we attempt the update
        result = service.update_profile(
            user_id=NONEXISTENT_ID,
            primary_country="uk",
            username="newuser",
            user_since="2024-01-01",
        )

        # THEN the update reports failure
        assert result is False

    def test_update_profile_with_partial_fields(
        self, test_db, existing_user_profile
    ):
        # GIVEN an existing profile record (from fixture)
        # WHEN only the country is provided (username is None)
        new_country = "CA"
        original_username = existing_user_profile["username"]

        result = service.update_profile(
            user_id=existing_user_profile["user_id"],
            primary_country=new_country,
            username=None,
            user_since=existing_user_profile["user_since"],
        )

        # THEN the update reports success
        assert result is True

        # AND only the provided field changed
        stored = test_db.query(
            "SELECT * FROM user_profiles WHERE user_id = ?",
            (existing_user_profile["user_id"],),
        ).fetchone()

        assert stored["primary_country"] == new_country
        assert (
            stored["username"] == original_username
        )  # Username should remain unchanged

    def test_update_profile_with_database_error(
        self, monkeypatch, existing_user_profile
    ):
        # GIVEN an existing profile record (from fixture)
        # AND a database layer that always raises
        def mock_db_query_error(*args, **kwargs):
            raise Exception("Database error")

        monkeypatch.setattr(
            "policyengine_api.data.database.query", mock_db_query_error
        )

        # WHEN we call update_profile THEN the error propagates unchanged
        with pytest.raises(Exception, match="Database error"):
            service.update_profile(
                user_id=existing_user_profile["user_id"],
                primary_country="US",
                username="testuser",
                user_since="2023-01-01",
            )

    def test_update_profile_id_not_specified(self):
        # GIVEN no user_id specified
        # WHEN we call update_profile with None as user_id
        # THEN a ValueError should be raised
        with pytest.raises(
            ValueError, match="you must specify either auth0_id or user_id"
        ):
            service.update_profile(
                user_id=None,
                primary_country="US",
                username="testuser",
                user_since="2023-01-01",
            )
import pytest
from policyengine_api.services.user_service import UserService

from tests.fixtures.services.user_service import (
    valid_user_record,
    existing_user_profile,
)

service = UserService()


class TestGetProfile:
    """UserService.get_profile lookup behaviour."""

    def test_get_profile_id_not_specified(self):
        # GIVEN neither auth0_id nor user_id
        # THEN a ValueError is raised
        with pytest.raises(
            ValueError, match="you must specify either auth0_id or user_id"
        ):
            service.get_profile()

    def test_get_profile_nonexistent_record(self):
        # GIVEN an auth0_id with no matching record
        INVALID_RECORD_ID = "invalid"

        # WHEN we look it up THEN the result is None
        assert service.get_profile(auth0_id=INVALID_RECORD_ID) is None

    def test_get_profile_auth0_id(self, existing_user_profile):
        # Lookup by auth0_id returns the stored record.
        result = service.get_profile(
            auth0_id=existing_user_profile["auth0_id"]
        )
        assert result == existing_user_profile

    def test_get_profile_user_id(self, existing_user_profile):
        # Lookup by user_id returns the stored record.
        result = service.get_profile(
            user_id=existing_user_profile["user_id"]
        )
        assert result == existing_user_profile

    def test_get_profile_id_priority(self, test_db, existing_user_profile):
        # WHEN both IDs are supplied, the auth0_id lookup takes priority.
        result = service.get_profile(
            auth0_id=existing_user_profile["auth0_id"],
            user_id=existing_user_profile["user_id"],
        )

        # THEN the returned row is the one keyed by auth0_id.
        record = test_db.query(
            "SELECT * FROM user_profiles WHERE auth0_id = ?",
            (valid_user_record["auth0_id"],),
        ).fetchone()

        assert result == record