├── .coverage ├── .coveragerc ├── .env.dev ├── .env.test ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── scripts │ └── bump_chart_version.py └── workflows │ ├── black.yml │ ├── docs_publish.yml │ ├── helm_charts_release.yml │ ├── main.yml │ ├── pull_request.yml │ ├── release_pypi_package.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .streamlit └── config.toml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENCE ├── README.md ├── SECURITY.md ├── assets ├── logo.svg ├── ui-prototype-demo.gif └── ui-prototype-demo.mp4 ├── docs └── mkdocs │ ├── docs │ ├── css │ │ ├── custom.css │ │ ├── extra.css │ │ └── termynal.css │ ├── features.md │ ├── img │ │ ├── favicon.png │ │ ├── logo-white.png │ │ └── logo.svg │ ├── index.md │ ├── js │ │ ├── custom.js │ │ ├── mathjax-config.js │ │ └── termynal.js │ ├── metric-definitions.md │ ├── ml-monitoring.md │ ├── sdk-docs.md │ └── tutorial │ │ ├── installation.md │ │ ├── metrics.md │ │ ├── monitors_alerts.md │ │ └── sdk.md │ └── mkdocs.yml ├── examples ├── docker-compose │ └── docker-compose.yml └── notebooks │ └── sdk-example.ipynb ├── helm_charts └── whitebox │ ├── Chart.yaml │ ├── artifacthub-repo.yml │ ├── templates │ ├── _helpers.tpl │ ├── deployment.yaml │ ├── ingress.yaml │ └── service.yaml │ └── values.yaml ├── pytest.ini ├── requirements.txt ├── scripts └── decrypt_api_key.py ├── setup.py └── whitebox ├── .streamlit └── config.toml ├── __init__.py ├── analytics ├── __init__.py ├── data │ └── testing │ │ ├── classification_test_data.csv │ │ ├── metrics_test_data.csv │ │ ├── regression_test_data.csv │ │ └── udemy_fin_adj.csv ├── drift │ └── pipelines.py ├── metrics │ ├── functions.py │ └── pipelines.py ├── models │ └── pipelines.py ├── tests │ ├── __init__.py │ └── test_pipelines.py └── xai_models │ └── pipelines.py ├── api ├── __init__.py └── v1 │ ├── __init__.py │ ├── alerts.py │ ├── cron_tasks.py │ ├── dataset_rows.py │ ├── docs.py │ ├── drifting_metrics.py │ ├── health.py │ ├── inference_rows.py │ ├── model_integrity_metrics.py │ ├── model_monitors.py │ ├── models.py │ └── performance_metrics.py ├── core ├── __init__.py ├── db.py ├── manager.py └── settings.py ├── cron.py ├── cron_tasks ├── monitoring_alerts.py ├── monitoring_metrics.py ├── shared.py └── tasks.py ├── crud ├── __init__.py ├── alerts.py ├── base.py ├── dataset_rows.py ├── drifting_metrics.py ├── inference_rows.py ├── model_integrity_metrics.py ├── model_monitors.py ├── models.py ├── performance_metrics.py └── users.py ├── entities ├── Alert.py ├── Base.py ├── DatasetRow.py ├── DriftingMetric.py ├── Inference.py ├── Model.py ├── ModelIntegrityMetric.py ├── ModelMonitor.py ├── PerformanceMetric.py ├── User.py └── __init__.py ├── main.py ├── middleware ├── __initi__.py └── auth.py ├── schemas ├── __init__.py ├── alert.py ├── base.py ├── datasetRow.py ├── driftingMetric.py ├── inferenceRow.py ├── model.py ├── modelIntegrityMetric.py ├── modelMonitor.py ├── performanceMetric.py ├── task.py ├── user.py └── utils.py ├── sdk ├── __init__.py └── whitebox.py ├── streamlit ├── app.py ├── cards.py ├── classification_test_data copy.csv ├── classification_test_data.csv ├── config │ └── config_readme.toml ├── mock │ ├── alerts.json │ ├── drift.json │ ├── inferences.json │ ├── monitors.json │ ├── performance.json │ └── t.ipynb ├── mock_app.py ├── references │ ├── logo.png │ └── whitebox2.png ├── tabs │ ├── alerts.py │ ├── drifting.py │ ├── inferences.py │ ├── monitors.py │ ├── 
overview.py │ ├── performance.py │ └── sidebar.py └── utils │ ├── export.py │ ├── graphs.py │ ├── load.py │ ├── style.css │ └── transformation.py ├── tests ├── __init__.py ├── unit_tests │ └── test_unit.py ├── utils │ └── maps.py └── v1 │ ├── __init__.py │ ├── conftest.py │ ├── mock_data.py │ ├── test_alerts.py │ ├── test_cron_tasks.py │ ├── test_dataset_rows.py │ ├── test_drifting_metrics.py │ ├── test_errors.py │ ├── test_health.py │ ├── test_inference_rows.py │ ├── test_model_integrity_metrics.py │ ├── test_model_monitors.py │ ├── test_models.py │ ├── test_performance_metrics.py │ └── test_sdk.py └── utils ├── __init__.py ├── errors.py ├── exceptions.py ├── id_gen.py ├── logger.py └── passwords.py /.coverage: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/.coverage -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | *__init__* 4 | */usr/local/lib* 5 | *test* 6 | 7 | [report] 8 | omit = 9 | *test* 10 | */usr/local/lib* 11 | */__init__.py 12 | -------------------------------------------------------------------------------- /.env.dev: -------------------------------------------------------------------------------- 1 | ENV=dev 2 | APP_NAME=Whitebox | Development 3 | APP_NAME_CRON=Whitebox | Development 4 | 5 | DATABASE_URL=postgresql://postgres:postgres@localhost:5432/postgres 6 | 7 | VERSION=0.1.0 8 | METRICS_CRON=*/15 * * * * 9 | 10 | MODEL_PATH=models 11 | -------------------------------------------------------------------------------- /.env.test: -------------------------------------------------------------------------------- 1 | ENV=test 2 | APP_NAME=Whitebox | Test 3 | APP_NAME_CRON=Whitebox | Test 4 | SECRET_KEY=3beae33e30bcdaf6b172e17dc8f26341 5 | 6 | DATABASE_URL=postgresql://postgres:postgres@localhost:5432/test 7 | 8 | VERSION=0.1.0 9 | METRICS_CRON=*/15 * * * * 10 | 11 | MODEL_PATH=models/test_model -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Ignore Jupyter notebooks from Git stats 2 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 
22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | - [ ] closes #xxxx (Replace xxxx with the GitHub issue number) 2 | - [ ] `Tests added and passed` if fixing a bug or adding a new feature. 3 | - [ ] All `code checks passed`. 4 | - [ ] Added `type annotations` to new arguments/methods/functions. 5 | - [ ] Updated documentation. -------------------------------------------------------------------------------- /.github/scripts/bump_chart_version.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import sys 3 | 4 | with open("helm_charts/whitebox/Chart.yaml", "r") as f: 5 | chart = yaml.safe_load(f) 6 | 7 | if len(sys.argv) < 2: 8 | raise Exception("Version number not provided") 9 | 10 | # If it starts with a v, remove it 11 | if sys.argv[1].startswith("v"): 12 | sys.argv[1] = sys.argv[1][1:] 13 | 14 | chart["version"] = sys.argv[1] 15 | 16 | with open("helm_charts/whitebox/Chart.yaml", "w") as f: 17 | yaml.dump(chart, f) 18 | 19 | print(f"Updated chart version to {sys.argv[1]}") 20 | -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Black Formatter 2 | 3 | on: pull_request 4 | jobs: 5 | lint: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v2 9 | - uses: psf/black@23.1.0 -------------------------------------------------------------------------------- /.github/workflows/docs_publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs 2 | on: 3 | release: 4 | types: [published] 5 | 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.10" 16 | - run: pip install mkdocs-material 17 | - run: mkdocs gh-deploy --config-file docs/mkdocs/mkdocs.yml --force 18 | -------------------------------------------------------------------------------- /.github/workflows/helm_charts_release.yml: -------------------------------------------------------------------------------- 1 | name: Release Charts 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | release_charts: 9 | permissions: 10 | contents: write 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: 
Checkout 14 | uses: actions/checkout@v2 15 | with: 16 | fetch-depth: 0 17 | 18 | - name: Install Helm 19 | uses: azure/setup-helm@v3 20 | with: 21 | version: v3.10.0 22 | 23 | - name: Set up Python 3.9 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: 3.9 27 | 28 | - name: Extract chart version and remove v prefix 29 | id: extract_version 30 | run: | 31 | CHART_VERSION=$(echo ${{ github.ref }} | cut -d'/' -f3) 32 | echo "::set-output name=version::${CHART_VERSION#v}" 33 | 34 | - name: Package whitebox chart 35 | run: | 36 | helm package -u --version ${{ steps.extract_version.outputs.version }} helm_charts/whitebox 37 | 38 | - name: Publish whitebox chart 39 | run: | 40 | curl --data-binary "@whitebox-${{ steps.extract_version.outputs.version }}.tgz" chartmuseum.squaredev.io/api/charts -u ${{ secrets.CHARTMUSEUM_USERNAME }}:${{ secrets.CHARTMUSEUM_PASSWORD }} 41 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Test & Publish 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | 7 | workflow_dispatch: 8 | 9 | jobs: 10 | test: 11 | uses: ./.github/workflows/test.yml 12 | 13 | whitebox: 14 | needs: 15 | - test 16 | uses: squaredev-io/gh-workflows/.github/workflows/base_build_publish.yml@main 17 | with: 18 | image: whitebox 19 | dockerfile: Dockerfile 20 | runs-on: ubuntu-latest 21 | secrets: 22 | docker_username: ${{ secrets.DOCKER_USERNAME }} 23 | docker_access_token: ${{ secrets.DOCKER_ACCESS_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/pull_request.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | paths-ignore: 7 | - 'docs/**' 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | test: 13 | uses: ./.github/workflows/test.yml 14 | -------------------------------------------------------------------------------- /.github/workflows/release_pypi_package.yml: -------------------------------------------------------------------------------- 1 | name: Release PyPI Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | release_package: 9 | permissions: 10 | contents: write 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v2 15 | with: 16 | fetch-depth: 0 17 | 18 | - name: Set up Python 3.9 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: 3.9 22 | 23 | - name: Build a binary wheel and a source tarball. 
24 | run: pip install wheel && python setup.py sdist bdist_wheel 25 | 26 | - name: Publish distribution 📦 to PyPI 27 | if: startsWith(github.ref, 'refs/tags') 28 | uses: pypa/gh-action-pypi-publish@release/v1 29 | with: 30 | password: ${{ secrets.PYPI_API_TOKEN }} 31 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # On workflow call 2 | 3 | name: test 4 | 5 | on: workflow_call 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | services: 11 | postgres: 12 | image: postgres 13 | env: 14 | POSTGRES_DB: test 15 | POSTGRES_PASSWORD: postgres 16 | ports: 17 | - '5432:5432' 18 | options: >- 19 | --health-cmd pg_isready 20 | --health-interval 10s 21 | --health-timeout 5s 22 | --health-retries 5 23 | steps: 24 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 25 | - uses: actions/checkout@v3 26 | - uses: actions/setup-python@v4 27 | with: 28 | python-version: '3.10' 29 | cache: 'pip' # caching pip dependencies 30 | - run: sudo apt-get update && sudo apt-get install libpq-dev postgresql-client -y 31 | - run: pip install -r requirements.txt 32 | - name: Test 33 | run: ENV="test" pytest 34 | timeout-minutes: 4 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__/ 3 | *__pycache__ 4 | *.vscode 5 | *.py[cod] 6 | *$py.class 7 | .pytest_cache 8 | *.DS_Store 9 | check_data/ 10 | /models/ 11 | *.ipynb_checkpoints 12 | *.db 13 | site/ 14 | test.ipynb 15 | dist/ 16 | build/ 17 | whitebox_sdk.egg-info 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.1.0 4 | hooks: 5 | - id: black 6 | entry: bash -c 'black "$@"; git add -u' -- 7 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | base="dark" 3 | primaryColor="#21babe" 4 | backgroundColor="#1e2025" 5 | secondaryBackgroundColor="#252a33" 6 | 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. 
Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Whitebox 2 | We welcome contributions to the Whitebox project! If you're interested in helping out, please join our Discord server: https://discord.gg/TPw5vXqn 3 | 4 | ## How to contribute 5 | 1. Fork the Whitebox repository 6 | 2. Create a new branch for your changes 7 | 3. Make your changes 8 | 4. Submit a pull request 9 | Please make sure to follow the project's coding style and to add test cases for any new or changed functionality. Also update the docs where needed. 10 | 11 | ## Bug reports and feature requests 12 | If you find a bug or have a feature request, please open an issue on the GitHub repository. 13 | 14 | Thank you for your interest in contributing to Whitebox! 15 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | RUN apt-get update && apt-get install libpq-dev python3-dev -y 4 | 5 | WORKDIR /whitebox 6 | 7 | COPY requirements.txt requirements.txt 8 | 9 | RUN pip install -r requirements.txt 10 | 11 | COPY . . 
12 | 13 | EXPOSE 8000 14 | 15 | ENTRYPOINT ENV=dev uvicorn whitebox.main:app --reload --host 0.0.0.0 --port 8000 16 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2022] [Squaredev BV] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Reporting a Vulnerability 4 | 5 | Open an issue with the security label. Use the bug template. 6 | -------------------------------------------------------------------------------- /assets/ui-prototype-demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/assets/ui-prototype-demo.gif -------------------------------------------------------------------------------- /assets/ui-prototype-demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/assets/ui-prototype-demo.mp4 -------------------------------------------------------------------------------- /docs/mkdocs/docs/css/custom.css: -------------------------------------------------------------------------------- 1 | .termynal-comment { 2 | color: #4a968f; 3 | font-style: italic; 4 | display: block; 5 | } 6 | 7 | .termy { 8 | /* For right to left languages */ 9 | direction: ltr; 10 | } 11 | 12 | .termy [data-termynal] { 13 | white-space: pre-wrap; 14 | } 15 | 16 | a.external-link { 17 | /* For right to left languages */ 18 | direction: ltr; 19 | display: inline-block; 20 | } 21 | 22 | a.external-link::after { 23 | /* \00A0 is a non-breaking space 24 | to make the mark be on the same line as the link 25 | */ 26 | content: "\00A0[↪]"; 27 | } 28 | 29 | a.internal-link::after { 30 | /* \00A0 is a non-breaking space 31 | to make the mark be on the same line as the link 32 | */ 33 | content: "\00A0↪"; 34 | } 35 | 36 | .shadow { 37 | box-shadow: 5px 5px 10px #999; 38 | } 39 | 40 | /* Give space to lower icons so Gitter chat doesn't get on top of them */ 41 | .md-footer-meta { 42 | padding-bottom: 2em; 43 | } 44 | 45 | .user-list { 46 | display: flex; 47 
| flex-wrap: wrap; 48 | margin-bottom: 2rem; 49 | } 50 | 51 | .user-list-center { 52 | justify-content: space-evenly; 53 | } 54 | 55 | .user { 56 | margin: 1em; 57 | min-width: 7em; 58 | } 59 | 60 | .user .avatar-wrapper { 61 | width: 80px; 62 | height: 80px; 63 | margin: 10px auto; 64 | overflow: hidden; 65 | border-radius: 50%; 66 | position: relative; 67 | } 68 | 69 | .user .avatar-wrapper img { 70 | position: absolute; 71 | top: 50%; 72 | left: 50%; 73 | transform: translate(-50%, -50%); 74 | } 75 | 76 | .user .title { 77 | text-align: center; 78 | } 79 | 80 | .user .count { 81 | font-size: 80%; 82 | text-align: center; 83 | } 84 | 85 | a.announce-link:link, 86 | a.announce-link:visited { 87 | color: #fff; 88 | } 89 | 90 | a.announce-link:hover { 91 | color: var(--md-accent-fg-color); 92 | } 93 | 94 | .announce-wrapper { 95 | display: flex; 96 | justify-content: space-between; 97 | flex-wrap: wrap; 98 | align-items: center; 99 | } 100 | 101 | .announce-wrapper div.item { 102 | display: none; 103 | } 104 | 105 | .announce-wrapper .sponsor-badge { 106 | display: block; 107 | position: absolute; 108 | top: -10px; 109 | right: 0; 110 | font-size: 0.5rem; 111 | color: #999; 112 | background-color: #666; 113 | border-radius: 10px; 114 | padding: 0 10px; 115 | z-index: 10; 116 | } 117 | 118 | .announce-wrapper .sponsor-image { 119 | display: block; 120 | border-radius: 20px; 121 | } 122 | 123 | .announce-wrapper>div { 124 | min-height: 40px; 125 | display: flex; 126 | align-items: center; 127 | } 128 | 129 | .twitter { 130 | color: #00acee; 131 | } 132 | 133 | /* Right to left languages */ 134 | code { 135 | direction: ltr; 136 | display: inline-block; 137 | } 138 | 139 | .md-content__inner h1 { 140 | direction: ltr !important; 141 | } 142 | 143 | .illustration { 144 | margin-top: 2em; 145 | margin-bottom: 2em; 146 | } 147 | -------------------------------------------------------------------------------- /docs/mkdocs/docs/css/extra.css: -------------------------------------------------------------------------------- 1 | :root { 2 | 3 | /* Primary color shades */ 4 | --md-primary-fg-color: #21babe; 5 | --md-primary-fg-color--light: #21babe; 6 | --md-primary-fg-color--dark: #86e6e9; 7 | --md-primary-bg-color: hsla(0, 0%, 100%, 1); 8 | --md-primary-bg-color--light: hsla(0, 0%, 100%, 0.7); 9 | --md-typeset-a-color: #21babe; 10 | 11 | /* Accent color shades */ 12 | --md-accent-fg-color: #006493; 13 | --md-accent-fg-color--transparent: hsla(189, 100%, 37%, 0.1); 14 | --md-accent-bg-color: hsla(0, 0%, 100%, 1); 15 | --md-accent-bg-color--light: hsla(0, 0%, 100%, 0.7); 16 | } 17 | 18 | :root > * { 19 | 20 | /* Code block color shades */ 21 | --md-code-bg-color: hsla(0, 0%, 96%, 1); 22 | --md-code-fg-color: hsla(200, 18%, 26%, 1); 23 | 24 | /* Footer */ 25 | --md-footer-bg-color: #21babe; 26 | --md-footer-bg-color--dark: hsla(0, 0%, 0%, 0.32); 27 | --md-footer-fg-color: hsla(0, 0%, 100%, 1); 28 | --md-footer-fg-color--light: hsla(0, 0%, 100%, 0.7); 29 | --md-footer-fg-color--lighter: hsla(0, 0%, 100%, 0.3); 30 | } -------------------------------------------------------------------------------- /docs/mkdocs/docs/css/termynal.css: -------------------------------------------------------------------------------- 1 | /** 2 | * termynal.js 3 | * 4 | * @author Ines Montani 5 | * @version 0.0.1 6 | * @license MIT 7 | */ 8 | 9 | :root { 10 | --color-bg: #252a33; 11 | --color-text: #eee; 12 | --color-text-subtle: #a2a2a2; 13 | } 14 | 15 | [data-termynal] { 16 | width: 750px; 17 | max-width: 100%; 18 | 
background: var(--color-bg); 19 | color: var(--color-text); 20 | /* font-size: 18px; */ 21 | font-size: 15px; 22 | /* font-family: 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; */ 23 | font-family: 'Roboto Mono', 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; 24 | border-radius: 4px; 25 | padding: 75px 45px 35px; 26 | position: relative; 27 | -webkit-box-sizing: border-box; 28 | box-sizing: border-box; 29 | } 30 | 31 | [data-termynal]:before { 32 | content: ''; 33 | position: absolute; 34 | top: 15px; 35 | left: 15px; 36 | display: inline-block; 37 | width: 15px; 38 | height: 15px; 39 | border-radius: 50%; 40 | /* A little hack to display the window buttons in one pseudo element. */ 41 | background: #d9515d; 42 | -webkit-box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 43 | box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 44 | } 45 | 46 | [data-termynal]:after { 47 | content: 'bash'; 48 | position: absolute; 49 | color: var(--color-text-subtle); 50 | top: 5px; 51 | left: 0; 52 | width: 100%; 53 | text-align: center; 54 | } 55 | 56 | a[data-terminal-control] { 57 | text-align: right; 58 | display: block; 59 | color: #aebbff; 60 | } 61 | 62 | [data-ty] { 63 | display: block; 64 | line-height: 2; 65 | } 66 | 67 | [data-ty]:before { 68 | /* Set up defaults and ensure empty lines are displayed. */ 69 | content: ''; 70 | display: inline-block; 71 | vertical-align: middle; 72 | } 73 | 74 | [data-ty="input"]:before, 75 | [data-ty-prompt]:before { 76 | margin-right: 0.75em; 77 | color: var(--color-text-subtle); 78 | } 79 | 80 | [data-ty="input"]:before { 81 | content: '$'; 82 | } 83 | 84 | [data-ty][data-ty-prompt]:before { 85 | content: attr(data-ty-prompt); 86 | } 87 | 88 | [data-ty-cursor]:after { 89 | content: attr(data-ty-cursor); 90 | font-family: monospace; 91 | margin-left: 0.5em; 92 | -webkit-animation: blink 1s infinite; 93 | animation: blink 1s infinite; 94 | } 95 | 96 | 97 | /* Cursor animation */ 98 | 99 | @-webkit-keyframes blink { 100 | 50% { 101 | opacity: 0; 102 | } 103 | } 104 | 105 | @keyframes blink { 106 | 50% { 107 | opacity: 0; 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /docs/mkdocs/docs/features.md: -------------------------------------------------------------------------------- 1 | # Features 2 | 3 | ## Design decisions 4 | 5 | - **Easy**: Very easy to set up and get started with. 6 | - **Intuitive**: Designed to be intuitive and easy to use. 7 | - **Pythonic SDK**: Pythonic SDK for building your own monitoring infrastructure. 8 | - **Robust**: Get a production-ready MLOps system. 9 | - **Kubernetes**: Production-ready deployment on Kubernetes, with automatic interactive API documentation. 10 | 11 | ## Descriptive Statistics 12 | 13 | Whitebox provides a [list of descriptive statistics](../metric-definitions/#descriptive-statistics) for the input dataset, making it easy to get an overview of your data. 14 | 15 | ## Model Metrics 16 | 17 | ### Classification Models 18 | 19 | Whitebox includes comprehensive [metrics](../metric-definitions/#evaluation-metrics) tracking for classification models. This allows users to easily evaluate the performance of their classification models and identify areas for improvement. Additionally, users can set custom thresholds for each metric to receive alerts when performance deviates from expected results. 20 | 21 | ### Regression Models 22 | 23 | Whitebox includes comprehensive [metrics](../metric-definitions/#evaluation-metrics) tracking for regression models. 
This allows users to easily evaluate the performance of their regression models and identify areas for improvement. Additionally, users can set custom thresholds for each metric to receive alerts when performance deviates from expected results. 24 | 25 | ## Data / Concept Drift Monitoring 26 | 27 | Whitebox includes monitoring for [data and concept drift](../metric-definitions/#statistical-tests-and-techniques). This feature tracks changes in the distribution of the data used to train models and alerts users when significant changes occur. Additionally, it detects changes in the performance of deployed models and alerts users when significant drift is detected. This allows users to identify and address data and model drift early, reducing the risk of poor model performance. 28 | 29 | ## Explainable AI 30 | 31 | Whitebox also includes [model explanation](../metric-definitions/#explainable-ai-models). Explainability is provided through an explainability report, which lets users see at any time which features had the greatest impact on the model's predictions. 32 | 33 | ## Alerts 34 | 35 | Whitebox includes an [alerting system](../tutorial/monitors_alerts/#monitors-and-alerts) that allows users to set custom thresholds for metrics and receive notifications when performance deviates from expected results. These alerts can be delivered via email, SMS, or webhook, and can be customized to fit the needs of any organization. This allows users to quickly respond to changes in model performance and take action to improve results. 36 | -------------------------------------------------------------------------------- /docs/mkdocs/docs/img/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/docs/mkdocs/docs/img/favicon.png -------------------------------------------------------------------------------- /docs/mkdocs/docs/img/logo-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/docs/mkdocs/docs/img/logo-white.png -------------------------------------------------------------------------------- /docs/mkdocs/docs/index.md: -------------------------------------------------------------------------------- 1 | # Whitebox 2 | 3 | 
*Whitebox logo*

Whitebox is an open source E2E ML monitoring platform with edge capabilities that plays nicely with kubernetes
12 | 13 | --- 14 | 15 | **Documentation**: https://whitebox-ai.github.io/whitebox/ 16 | 17 | **Source Code**: https://github.com/whitebox-ai/whitebox 18 | 19 | **Roadmap**: https://github.com/whitebox-ai/whitebox/milestones 20 | 21 | **Issue tracking**: https://github.com/orgs/whitebox-ai/projects/1/views/3 22 | 23 | **Discord**: https://discord.gg/bkAcsx4V 24 | 25 | --- 26 | 27 | Whitebox is a dynamic open source E2E ML monitoring platform with edge capabilities that plays nicely with kubernetes. 28 | 29 | The current key features are: 30 | 31 | - **Descriptive statistics** 32 | - **Classification models evaluation metrics** 33 | - **Data / Concept drift monitoring** 34 | - **Explainable AI** 35 | - **Alerts** 36 | 37 | Design guidelines: 38 | 39 | - **Easy**: Very easy to set up and get started with. 40 | - **Intuitive**: Designed to be intuitive and easy to use. 41 | - **Pythonic SDK**: Pythonic SDK for building your own monitoring infrastructure. 42 | - **Robust**: Get a production-ready MLOps system. 43 | - **Kubernetes**: Production-ready deployment on Kubernetes, with automatic interactive API documentation. 44 | -------------------------------------------------------------------------------- /docs/mkdocs/docs/js/mathjax-config.js: -------------------------------------------------------------------------------- 1 | /* mathjax-loader.js file */ 2 | /* ref: http://facelessuser.github.io/pymdown-extensions/extensions/arithmatex/ */ 3 | (function (win, doc) { 4 | win.MathJax = { 5 | config: ["MMLorHTML.js"], 6 | extensions: ["tex2jax.js"], 7 | jax: ["input/TeX"], 8 | tex2jax: { 9 | inlineMath: [ ["\\(","\\)"] ], 10 | displayMath: [ ["\\[","\\]"] ] 11 | }, 12 | TeX: { 13 | TagSide: "right", 14 | TagIndent: ".8em", 15 | MultLineWidth: "85%", 16 | equationNumbers: { 17 | autoNumber: "AMS", 18 | }, 19 | unicode: { 20 | fonts: "STIXGeneral,'Arial Unicode MS'" 21 | } 22 | }, 23 | displayAlign: 'center', 24 | showProcessingMessages: false, 25 | messageStyle: 'none' 26 | }; 27 | })(window, document); -------------------------------------------------------------------------------- /docs/mkdocs/docs/ml-monitoring.md: -------------------------------------------------------------------------------- 1 | # Understanding machine learning monitoring 2 | 3 | Machine learning (ML) monitoring is a crucial process for ensuring the performance and reliability of ML models and systems. It involves tracking metrics, identifying issues, and improving overall performance. In this article, we will explore the key aspects of ML monitoring and its importance in today's data-driven world. 4 | 5 | ## What is machine learning monitoring? 6 | 7 | ML monitoring is the process of tracking the performance and behavior of ML models and systems. This includes monitoring metrics such as accuracy, precision, recall, and model performance over time. By monitoring these metrics, organizations can identify when a model is performing poorly or behaving unexpectedly and take action to correct it. Additionally, ML monitoring can be used to track the performance of different models and compare them to identify which model is performing best. 8 | 9 | ## Detecting poor performance 10 | 11 | One of the key aspects of ML monitoring is being able to detect when a model is performing poorly or behaving unexpectedly. This can be done by setting up alerts on certain metrics or by monitoring for unexpected changes in the model's behavior. 
For example, if a model's accuracy suddenly drops, an alert can be triggered to notify the team responsible for maintaining the model. This allows them to quickly investigate and address the issue, minimizing the impact on the organization's operations. 12 | 13 | ## Monitoring data and concept drift 14 | 15 | Another important aspect of ML monitoring is the ability to detect and address data and concept drift. Data drift refers to the gradual change in the distribution or characteristics of the input data over time. As the data changes, the model's performance may decrease, which is why it is important to detect data drift and retrain the model with new data. Concept drift, on the other hand, happens when the statistical properties of the target variable, which the model is trying to predict, change over time. This can happen for various reasons, such as changes in the underlying data distribution, overfitting, or degradation of the model's parameters. To address these issues, besides drift detection, organizations can use techniques such as monitoring the model's performance over time and retraining the model regularly; a minimal sketch of a drift check appears at the end of this article. 16 | 17 | ## Monitoring input data 18 | 19 | Another important aspect of ML monitoring is being able to track the input data that is being fed into the model. This can be used to identify any issues with the data, such as missing values, and to ensure that the data is being processed correctly. This helps to ensure that the model is making accurate predictions and that the data is being used effectively. 20 | 21 | ## Monitoring infrastructure and resources 22 | 23 | Finally, it is important to monitor the infrastructure and resources that the ML models are running on. This includes monitoring things like CPU and memory usage, disk space, and network traffic. This helps to ensure that the models have the resources they need to perform well and to identify any potential bottlenecks that may be impacting performance. By monitoring the infrastructure, organizations can ensure that their models are running smoothly and can make adjustments as needed. 24 | 25 | ## Conclusion 26 | 27 | In conclusion, ML monitoring is an essential process that helps organizations ensure the performance and reliability of their ML models and systems. By tracking metrics, identifying issues, and monitoring input data and infrastructure, organizations can ensure that their models are delivering accurate and reliable results and that their systems are running smoothly. With the increasing reliance on data and ML, the importance of ML monitoring will only continue to grow. 
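To make the drift checks discussed above concrete, here is a minimal sketch of a per-feature data-drift test. This is not Whitebox's internal implementation (the project depends on dedicated drift tooling such as `evidently`, listed in `requirements.txt`); the two-sample Kolmogorov-Smirnov test, the `alpha` threshold, and the function name are illustrative assumptions.

```Python
import pandas as pd
from scipy.stats import ks_2samp  # scipy is also listed in requirements.txt


def detect_feature_drift(
    reference: pd.DataFrame, current: pd.DataFrame, alpha: float = 0.05
) -> dict:
    """Flag numeric features whose distribution shifted between two samples."""
    drifted = {}
    for column in reference.columns.intersection(current.columns):
        # A small p-value means the two empirical distributions differ more
        # than chance alone would explain, so the feature may have drifted.
        _, p_value = ks_2samp(reference[column].dropna(), current[column].dropna())
        drifted[column] = p_value < alpha
    return drifted
```

A production check would also pick tests per feature type (e.g. a chi-squared test for categorical columns), correct for multiple comparisons, and evaluate over rolling time windows; automating that bookkeeping is precisely what a monitoring platform is for.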
28 | -------------------------------------------------------------------------------- /docs/mkdocs/docs/tutorial/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Docker 4 | 5 | Install whitebox server and all of its dependencies using `docker-compose` 6 | 7 | Copy the following code in a file named `docker-compose.yml`: 8 | 9 | ```yaml 10 | version: "3.10" 11 | name: Whitebox 12 | services: 13 | postgres: 14 | image: postgres:15 15 | restart: unless-stopped 16 | environment: 17 | - POSTGRES_USER=postgres 18 | - POSTGRES_PASSWORD=postgres 19 | - POSTGRES_MULTIPLE_DATABASES=test # postgres db is created by default 20 | logging: 21 | options: 22 | max-size: 10m 23 | max-file: "3" 24 | ports: 25 | - "5432:5432" 26 | volumes: 27 | - wb_data:/var/lib/postgresql/data 28 | networks: 29 | - whitebox 30 | 31 | whitebox: 32 | image: sqdhub/whitebox:main 33 | restart: unless-stopped 34 | environment: 35 | - APP_NAME=Whitebox | Docker 36 | - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/postgres 37 | - SECRET_KEY= # Optional, if not set the API key won't be encrypted 38 | ports: 39 | - "8000:8000" 40 | depends_on: 41 | - postgres 42 | networks: 43 | - whitebox 44 | 45 | volumes: 46 | wb_data: 47 | 48 | networks: 49 | whitebox: 50 | name: whitebox 51 | ``` 52 | 53 | With your terminal navigate to `docker-compose.yml`'s location and then run the following command: 54 | 55 |
56 | 57 | ```console 58 | $ docker-compose up 59 | 60 | ``` 61 | 62 |
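Once the containers are up, you can sanity-check that the API is reachable. Whitebox serves on port 8000 and ships a health endpoint (see `whitebox/api/v1/health.py` in the repository tree), but the exact route and the response shown below are assumptions for illustration:

```console
$ curl http://127.0.0.1:8000/v1/health
{"status":"ok"}
```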
63 | 64 | ## Kubernetes 65 | 66 | You can also install Whitebox server and all of its dependencies in your k8s cluster using `helm`: 67 | 68 | ```bash 69 | helm repo add squaredev https://chartmuseum.squaredev.io/ 70 | helm repo update 71 | helm install whitebox squaredev/whitebox 72 | ``` 73 | -------------------------------------------------------------------------------- /docs/mkdocs/docs/tutorial/monitors_alerts.md: -------------------------------------------------------------------------------- 1 | # Monitors and Alerts 2 | 3 | ## Monitors 4 | 5 | You can create a monitor in Whitebox so that alerts are created automatically when a value goes out of bounds. Here is an example: 6 | 7 | ```Python 8 | from whitebox import Whitebox, MonitorStatus, MonitorMetrics, AlertSeverity 9 | 10 | wb = Whitebox(host="127.0.0.1:8000", api_key="some_api_key") 11 | 12 | model_monitor = wb.create_model_monitor( 13 | model_id="mock_model_id", 14 | name="test", 15 | status=MonitorStatus.active, 16 | metric=MonitorMetrics.accuracy, 17 | severity=AlertSeverity.high, 18 | email="jackie.chan@somemail.io", 19 | lower_threshold=0.7 20 | ) 21 | ``` 22 | 23 | ## Alerts 24 | 25 | Once the metrics reports have been produced, the monitoring alert pipeline is triggered. This means that if you have created any model monitors for a specific metric, alerts will be created if certain criteria are met, based on the thresholds and the monitor types you have specified. 26 | -------------------------------------------------------------------------------- /docs/mkdocs/docs/tutorial/sdk.md: -------------------------------------------------------------------------------- 1 | # Using Whitebox SDK 2 | 3 | ## Installation 4 | 5 | Installing Whitebox is a pretty easy job. Just install it like any other Python package. 6 | 7 | Install the SDK with `pip`: 8 | 9 | 
10 | 11 | ```console 12 | $ pip install whitebox-sdk 13 | ``` 14 | 15 |
16 | 17 | All the required packages will be automatically installed! 18 | 19 | Now you're good to go! 20 | 21 | ## Initial Setup 22 | 23 | In order to run Whitebox, you will need the application's API key. 24 | This key will be produced for you during the initial run of the Uvicorn live server. 25 | Assuming you run the server with docker compose (you can find more details on the installation page of this tutorial), you will see the following output: 26 | 27 | 
28 | 29 | ```console 30 | $ docker compose up 31 | 32 | ... 33 | INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) 34 | INFO: Started reloader process [4450] using StatReload 35 | INFO: Started server process [4452] 36 | INFO: Waiting for application startup. 37 | INFO: Created username: admin, API key: some_api_key 38 | INFO: Application startup complete. 39 | ... 40 | ``` 41 | 42 |
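If the server was started with a `SECRET_KEY` set, the stored API key is encrypted; the repository ships `scripts/decrypt_api_key.py` to recover the plain key. A usage sketch (both the argument placeholder and the printed value below are illustrative):

```console
$ python scripts/decrypt_api_key.py <your_encrypted_api_key>
some_api_key
```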
43 | 44 | !!! info 45 | 46 | Keep this API key somewhere safe! 47 | 48 | If you lose it, you will need to delete the admin user in your database and re-run the live server to produce a new key! 49 | 50 | After you get the API key, all you have to do is create an instance of the Whitebox class adding your host and API key as parameters: 51 | 52 | ```Python 53 | from whitebox import Whitebox 54 | 55 | wb = Whitebox(host="http://127.0.0.1:8000", api_key="some_api_key") 56 | ``` 57 | 58 | Now you're ready to start using Whitebox! 59 | 60 | ## Models 61 | 62 | ### Creating a Model 63 | 64 | In order to start adding training datasets and inferences, you first need to create a model. 65 | 66 | Let's create a sample model: 67 | 68 | ```Python 69 | wb.create_model( 70 | name="Model 1", 71 | type="binary", 72 | labels={ 73 | 'additionalProp1': 0, 74 | 'additionalProp2': 1 75 | }, 76 | target_column="target", 77 | granularity="1D" 78 | ) 79 | ``` 80 | 81 | For more details about the property types the schema accepts, visit the [Models section](../../sdk-docs/#models) in the SDK documentation. 82 | 83 | ### Fetching a Model 84 | 85 | Getting a model from the database is as easy as it sounds. You'll just need the `model_id`: 86 | 87 | ```Python 88 | wb.get_model("some_model_id") 89 | ``` 90 | 91 | ### Deleting a Model 92 | 93 | Deleting a model is as easy as fetching a model. Just use the `model_id`: 94 | 95 | ```Python 96 | wb.delete_model("some_model_id") 97 | ``` 98 | 99 | !!! warning 100 | 101 | You will have to be extra careful when deleting a model, because all datasets, inferences, monitors and literally everything will be deleted from the database along with the model itself! 102 | 103 | ## Loading Training Datasets 104 | 105 | Once you have created a model you can start loading your data. Let's start with the training dataset! 106 | 107 | In our example we will create a `pd.DataFrame` from a `.csv` file. Of course you can use any method you like to create your `pd.DataFrame`, as long as your non-processed and processed datasets have **the same amount of rows** (a.k.a. the same length) and there is **more than one row**! 108 | 109 | ```Python 110 | import pandas as pd 111 | non_processed_df = pd.read_csv("path/to/file/non_processed_data.csv") 112 | processed_df = pd.read_csv("path/to/file/processed_data.csv") 113 | 114 | wb.log_training_dataset( 115 | model_id="some_model_id", 116 | non_processed=non_processed_df, 117 | processed=processed_df 118 | ) 119 | ``` 120 | 121 | !!! note 122 | 123 | When your training dataset is saved in the database, the model training process will begin executing, based on this dataset and the model it's associated with. That's why you need to load all the rows of your training dataset in the same batch. 124 | 125 | ## Loading Inferences 126 | 127 | To load your inferences you have to follow the exact same procedure as with the training datasets. The only difference is that you need to provide a `pd.Series` with the timestamps and (optionally) a `pd.Series` with the actuals, whose indices should match the ones in the non-processed and processed `pd.DataFrames`. 
128 | 129 | In our example let's assume that both the non-processed and processed `pd.DataFrames` have 10 rows each: 130 | 131 | ```Python 132 | import pandas as pd 133 | non_processed_df = pd.read_csv("path/to/file/non_processed_data.csv") 134 | processed_df = pd.read_csv("path/to/file/processed_data.csv") 135 | 136 | # Timestamps and actuals should have a length of 10 137 | timestamps = pd.Series(["2022-12-22T12:13:27.879738"] * 10) 138 | actuals = pd.Series([0, 1, 1, 1, 0, 0, 1, 1, 0, 0]) 139 | 140 | wb.log_inferences( 141 | model_id="some_model_id", 142 | non_processed=non_processed_df, 143 | processed=processed_df, 144 | timestamps=timestamps, 145 | actuals=actuals 146 | ) 147 | ``` 148 | 149 | !!! warning 150 | 151 | Make sure you add the actuals if you already know them, because as of now, there's no ability to add them at a later time by updating the inference rows. 152 | -------------------------------------------------------------------------------- /docs/mkdocs/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Whitebox 2 | theme: 3 | name: material 4 | logo: img/logo-white.png 5 | favicon: img/favicon.png 6 | palette: 7 | - media: "(prefers-color-scheme: light)" 8 | scheme: default 9 | toggle: 10 | icon: material/weather-sunny 11 | name: Switch to dark mode 12 | - media: "(prefers-color-scheme: dark)" 13 | scheme: slate 14 | toggle: 15 | icon: material/weather-night 16 | name: Switch to light mode 17 | icon: 18 | repo: fontawesome/brands/git-alt 19 | features: 20 | - navigation.footer 21 | repo_url: https://github.com/whitebox-ai/whitebox 22 | repo_name: squaredev/whitebox 23 | edit_uri: "" 24 | 25 | nav: 26 | - Overview: index.md 27 | - Features: features.md 28 | - Tutorial: 29 | - tutorial/installation.md 30 | - tutorial/sdk.md 31 | - tutorial/metrics.md 32 | - tutorial/monitors_alerts.md 33 | 34 | - MLOps Monitoring Intro: 35 | - ml-monitoring.md 36 | - metric-definitions.md 37 | 38 | - sdk-docs.md 39 | 40 | markdown_extensions: 41 | - pymdownx.arithmatex 42 | - pymdownx.highlight: 43 | anchor_linenums: true 44 | - admonition 45 | - def_list 46 | - pymdownx.details 47 | - pymdownx.highlight 48 | - pymdownx.inlinehilite 49 | - pymdownx.snippets 50 | - pymdownx.superfences 51 | - toc: 52 | permalink: true 53 | extra_css: 54 | - css/termynal.css 55 | - css/custom.css 56 | - css/extra.css 57 | extra_javascript: 58 | - js/termynal.js 59 | - js/custom.js 60 | - js/mathjax-config.js 61 | - https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML 62 | extra: 63 | social: 64 | - icon: fontawesome/brands/github 65 | link: https://github.com/whitebox-ai 66 | - icon: fontawesome/brands/linkedin 67 | link: https://www.linkedin.com/company/squaredev 68 | - icon: fontawesome/brands/python 69 | link: https://pypi.org/project/whitebox-sdk/ 70 | -------------------------------------------------------------------------------- /examples/docker-compose/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.10" 2 | name: Whitebox 3 | services: 4 | postgres: 5 | image: postgres:15 6 | restart: unless-stopped 7 | environment: 8 | - POSTGRES_USER=postgres 9 | - POSTGRES_PASSWORD=postgres 10 | - POSTGRES_MULTIPLE_DATABASES=test # postgres db is created by default 11 | logging: 12 | options: 13 | max-size: 10m 14 | max-file: "3" 15 | ports: 16 | - "5432:5432" 17 | volumes: 18 | - wb_data:/var/lib/postgresql/data 19 | networks: 20 | - whitebox 21 | 22 | whitebox: 
23 | image: sqdhub/whitebox:main 24 | platform: linux/amd64 25 | restart: unless-stopped 26 | environment: 27 | - APP_NAME=Whitebox | Docker 28 | - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/postgres 29 | - SECRET_KEY= 30 | ports: 31 | - "8000:8000" 32 | depends_on: 33 | - postgres 34 | networks: 35 | - whitebox 36 | 37 | volumes: 38 | wb_data: 39 | 40 | networks: 41 | whitebox: 42 | name: whitebox 43 | -------------------------------------------------------------------------------- /helm_charts/whitebox/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: whitebox 3 | description: A Machine learning monitoring platform 4 | 5 | type: application 6 | version: 0.1.0 7 | appVersion: "0.1.0" 8 | 9 | dependencies: 10 | - name: postgresql 11 | version: 12.1.5 12 | repository: https://charts.bitnami.com/bitnami 13 | condition: postgresql.enabled 14 | -------------------------------------------------------------------------------- /helm_charts/whitebox/artifacthub-repo.yml: -------------------------------------------------------------------------------- 1 | repositoryID: 5276bf98-7f0a-407d-ba4d-5b3083801cd6 2 | owners: 3 | - name: Squaredev 4 | email: hello@squaredev.io 5 | -------------------------------------------------------------------------------- /helm_charts/whitebox/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "whitebox.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 11 | If release name contains chart name it will be used as a full name. 12 | */}} 13 | {{- define "whitebox.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 28 | */}} 29 | {{- define "whitebox.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "whitebox.labels" -}} 37 | helm.sh/chart: {{ include "whitebox.chart" . }} 38 | {{ include "whitebox.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "whitebox.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "whitebox.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | 53 | {{/* 54 | Create the name of the service account to use 55 | */}} 56 | {{- define "whitebox.serviceAccountName" -}} 57 | {{- if .Values.serviceAccount.create }} 58 | {{- default (include "whitebox.fullname" .) 
.Values.serviceAccount.name }} 59 | {{- else }} 60 | {{- default "default" .Values.serviceAccount.name }} 61 | {{- end }} 62 | {{- end }} 63 | -------------------------------------------------------------------------------- /helm_charts/whitebox/templates/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: {{ include "whitebox.fullname" . }} 5 | labels: 6 | {{- include "whitebox.labels" . | nindent 4 }} 7 | spec: 8 | replicas: {{ .Values.replicaCount }} 9 | selector: 10 | matchLabels: 11 | {{- include "whitebox.selectorLabels" . | nindent 6 }} 12 | template: 13 | metadata: 14 | {{- with .Values.podAnnotations }} 15 | annotations: 16 | {{- toYaml . | nindent 8 }} 17 | {{- end }} 18 | labels: 19 | {{- include "whitebox.selectorLabels" . | nindent 8 }} 20 | spec: 21 | securityContext: 22 | {{- toYaml .Values.podSecurityContext | nindent 8 }} 23 | containers: 24 | - name: whitebox 25 | securityContext: 26 | {{- toYaml .Values.securityContext | nindent 12 }} 27 | image: sqdhub/whitebox:{{ .Values.image.tag }} 28 | ports: 29 | - name: http 30 | containerPort: 8000 31 | protocol: TCP 32 | env: 33 | - name: DATABASE_URL 34 | value: postgresql://{{ .Values.postgresql.auth.username | default "postgres" }}:{{ .Values.postgresql.auth.password | default "postgres" }}@{{ .Release.Name }}-postgresql/postgres 35 | resources: 36 | {{- toYaml .Values.resources | nindent 12 }} 37 | {{- with .Values.nodeSelector }} 38 | nodeSelector: 39 | {{- toYaml . | nindent 8 }} 40 | {{- end }} 41 | {{- with .Values.affinity }} 42 | affinity: 43 | {{- toYaml . | nindent 8 }} 44 | {{- end }} 45 | {{- with .Values.tolerations }} 46 | tolerations: 47 | {{- toYaml . | nindent 8 }} 48 | {{- end }} 49 | -------------------------------------------------------------------------------- /helm_charts/whitebox/templates/ingress.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.ingress.enabled -}} 2 | {{- $fullName := include "whitebox.fullname" . -}} 3 | {{- $svcPort := .Values.service.port -}} 4 | {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} 5 | {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }} 6 | {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}} 7 | {{- end }} 8 | {{- end }} 9 | {{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}} 10 | apiVersion: networking.k8s.io/v1 11 | {{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}} 12 | apiVersion: networking.k8s.io/v1beta1 13 | {{- else -}} 14 | apiVersion: extensions/v1beta1 15 | {{- end }} 16 | kind: Ingress 17 | metadata: 18 | name: {{ $fullName }} 19 | labels: 20 | {{- include "whitebox.labels" . | nindent 4 }} 21 | {{- with .Values.ingress.annotations }} 22 | annotations: 23 | {{- toYaml . | nindent 4 }} 24 | {{- end }} 25 | spec: 26 | {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }} 27 | ingressClassName: {{ .Values.ingress.className }} 28 | {{- end }} 29 | {{- if .Values.ingress.tls }} 30 | tls: 31 | {{- range .Values.ingress.tls }} 32 | - hosts: 33 | {{- range .hosts }} 34 | - {{ . 
| quote }} 35 | {{- end }} 36 | secretName: {{ .secretName }} 37 | {{- end }} 38 | {{- end }} 39 | rules: 40 | {{- range .Values.ingress.hosts }} 41 | - host: {{ .host | quote }} 42 | http: 43 | paths: 44 | {{- range .paths }} 45 | - path: {{ .path }} 46 | {{- if and .pathType (semverCompare ">=1.18-0" $.Capabilities.KubeVersion.GitVersion) }} 47 | pathType: {{ .pathType }} 48 | {{- end }} 49 | backend: 50 | {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} 51 | service: 52 | name: {{ $fullName }} 53 | port: 54 | number: {{ $svcPort }} 55 | {{- else }} 56 | serviceName: {{ $fullName }} 57 | servicePort: {{ $svcPort }} 58 | {{- end }} 59 | {{- end }} 60 | {{- end }} 61 | {{- end }} 62 | -------------------------------------------------------------------------------- /helm_charts/whitebox/templates/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: {{ include "whitebox.fullname" . }} 5 | labels: 6 | {{- include "whitebox.labels" . | nindent 4 }} 7 | spec: 8 | type: {{ .Values.service.type }} 9 | ports: 10 | - port: {{ .Values.service.port }} 11 | targetPort: http 12 | protocol: TCP 13 | name: http 14 | selector: 15 | {{- include "whitebox.selectorLabels" . | nindent 4 }} 16 | -------------------------------------------------------------------------------- /helm_charts/whitebox/values.yaml: -------------------------------------------------------------------------------- 1 | replicaCount: 1 2 | 3 | image: 4 | tag: main 5 | 6 | nameOverride: '' 7 | fullnameOverride: '' 8 | 9 | podAnnotations: {} 10 | 11 | podSecurityContext: {} 12 | 13 | securityContext: {} 14 | 15 | service: 16 | type: ClusterIP 17 | port: 80 18 | 19 | 20 | ingress: 21 | enabled: false 22 | # className: '' 23 | # annotations: 24 | # kubernetes.io/ingress.class: nginx 25 | # kubernetes.io/tls-acme: "true" 26 | # cert-manager.io/cluster-issuer: letsencrypt-prod 27 | # hosts: 28 | # - host: whitebox.example.com 29 | # paths: 30 | # - path: / 31 | # pathType: Prefix 32 | # tls: 33 | # - secretName: whitebox-tls 34 | # hosts: 35 | # - whitebox.example.com 36 | 37 | resources: {} 38 | 39 | nodeSelector: {} 40 | 41 | tolerations: [] 42 | 43 | affinity: {} 44 | 45 | postgresql: 46 | enabled: true 47 | storageClass: standard 48 | 49 | auth: 50 | username: whitebox 51 | password: whitebox 52 | 53 | primary: 54 | persistence: 55 | size: 100Mi 56 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | pythonpath = . 
3 | log_cli=true 4 | ; log_cli_level = "INFO" 5 | ; addopts = --ignore=whitebox/tests/v1/**.py 6 | markers = order -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==3.6.2 2 | appdirs==1.4.4 3 | asyncpg==0.27.0 4 | attrs==22.2.0 5 | bcrypt==4.0.1 6 | black==23.1.0 7 | brotlipy==0.7.0 8 | certifi==2022.12.7 9 | cffi==1.15.1 10 | cfgv==3.3.1 11 | charset-normalizer==3.0.1 12 | click==8.1.3 13 | colorama==0.4.6 14 | contourpy==1.0.7 15 | coverage==7.1.0 16 | crontab==1.0.0 17 | cryptography==39.0.1 18 | cycler==0.11.0 19 | databases==0.7.0 20 | distlib==0.3.6 21 | docopt==0.6.2 22 | evidently==0.2.4 23 | exceptiongroup==1.1.0 24 | fastapi==0.90.1 25 | filelock==3.9.0 26 | fonttools==4.38.0 27 | ghp-import==2.1.0 28 | h11==0.14.0 29 | httpcore==0.16.3 30 | httpx==0.23.3 31 | identify==2.5.17 32 | idna==3.4 33 | imageio==2.25.0 34 | iniconfig==2.0.0 35 | Jinja2==3.1.2 36 | joblib==1.2.0 37 | kiwisolver==1.4.4 38 | lightgbm==3.3.5 39 | lime==0.2.0.1 40 | Markdown==3.3.7 41 | MarkupSafe==2.1.2 42 | matplotlib==3.6.3 43 | mergedeep==1.3.4 44 | mkdocs==1.4.2 45 | mkdocs-material==9.0.12 46 | mkdocs-material-extensions==1.1.1 47 | mypy-extensions==1.0.0 48 | networkx==3.0 49 | nltk==3.8.1 50 | nodeenv==1.7.0 51 | numpy==1.24.2 52 | packaging==23.0 53 | pandas==1.5.3 54 | pathspec==0.11.0 55 | patsy==0.5.3 56 | Pillow==9.4.0 57 | pip==23.0 58 | platformdirs==3.0.0 59 | plotly==5.13.0 60 | pluggy==1.0.0 61 | pooch==1.6.0 62 | pre-commit==3.0.4 63 | psycopg2-binary==2.9.5 64 | pycparser==2.21 65 | pydantic==1.10.4 66 | Pygments==2.14.0 67 | pymdown-extensions==9.9.2 68 | pyOpenSSL==23.0.0 69 | pyparsing==3.0.9 70 | PySocks==1.7.1 71 | pytest==7.2.1 72 | pytest-order==1.0.1 73 | pytest-watch==4.2.0 74 | python-dateutil==2.8.2 75 | python-dotenv==0.21.1 76 | pytz==2022.7.1 77 | PyWavelets==1.4.1 78 | PyYAML==5.4.1 79 | pyyaml_env_tag==0.1 80 | regex==2022.10.31 81 | requests==2.28.2 82 | requests-mock==1.10.0 83 | rfc3986==1.5.0 84 | scikit-image==0.19.3 85 | scikit-learn==1.2.1 86 | scipy==1.10.0 87 | setuptools==67.1.0 88 | six==1.16.0 89 | sniffio==1.3.0 90 | SQLAlchemy==1.4.46 91 | stack-data==0.5.1 92 | starlette==0.23.1 93 | statsmodels==0.13.5 94 | streamlit==1.18.1 95 | tenacity==8.2.1 96 | threadpoolctl==3.1.0 97 | tifffile==2023.2.3 98 | tomli==2.0.1 99 | tqdm==4.64.1 100 | traitlets==5.4.0 101 | typed-ast==1.5.4 102 | typer==0.6.1 103 | types-python-dateutil==2.8.19 104 | typing_extensions==4.4.0 105 | unicodedata2==14.0.0 106 | urllib3==1.26.14 107 | uvicorn==0.20.0 108 | virtualenv==20.19.0 109 | watchdog==2.2.1 110 | wcwidth==0.2.5 111 | wheel==0.38.4 112 | wrapt==1.14.1 113 | xgboost==1.6.2 114 | yarl==1.8.1 115 | zipp==3.10.0 116 | 117 | -------------------------------------------------------------------------------- /scripts/decrypt_api_key.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | from whitebox.utils.passwords import decrypt_api_key 7 | from whitebox.core.settings import get_settings 8 | 9 | value = sys.argv[1] 10 | settings = get_settings() 11 | 12 | if __name__ == "__main__": 13 | api_key = decrypt_api_key(value, settings.SECRET_KEY.encode()) 14 | print(api_key) 15 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from pathlib import Path 3 | 4 | 5 | VERSION = "0.0.20" 6 | 7 | DESCRIPTION = "Whitebox is an open source E2E ML monitoring platform with edge capabilities that plays nicely with kubernetes" 8 | LONG_DESCRIPTION = (Path(__file__).parent / "README.md").read_text() 9 | 10 | # Setting up 11 | setup( 12 | # the name must match the folder name 'verysimplemodule' 13 | name="whitebox-sdk", 14 | version=VERSION, 15 | author="Squaredev", 16 | author_email="hello@squaredev.io", 17 | description=DESCRIPTION, 18 | long_description=LONG_DESCRIPTION, # add README.md 19 | long_description_content_type="text/markdown", 20 | packages=find_packages(), 21 | install_requires=[], # add any additional packages that 22 | # needs to be installed along with your package. Eg: 'caer' 23 | keywords=["python", "model monitoring", "whitebox", "mlops"], 24 | classifiers=[ 25 | "Programming Language :: Python :: 3", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /whitebox/.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | base="dark" 3 | primaryColor="#21babe" 4 | backgroundColor="#1e2025" 5 | secondaryBackgroundColor="#252a33" 6 | 7 | -------------------------------------------------------------------------------- /whitebox/__init__.py: -------------------------------------------------------------------------------- 1 | from whitebox.sdk import * 2 | -------------------------------------------------------------------------------- /whitebox/analytics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/analytics/__init__.py -------------------------------------------------------------------------------- /whitebox/analytics/data/testing/classification_test_data.csv: -------------------------------------------------------------------------------- 1 | y_testing_binary,y_prediction_binary,y_testing_multi,y_prediction_multi 2 | 0,0,0,1 3 | 0,0,2,2 4 | 0,1,0,0 5 | 1,1,1,1 6 | 0,0,2,0 7 | 1,0,1,1 8 | 1,0,1,0 9 | 1,1,2,2 10 | 0,1,0,1 11 | 1,1,1,1 12 | -------------------------------------------------------------------------------- /whitebox/analytics/data/testing/metrics_test_data.csv: -------------------------------------------------------------------------------- 1 | num1,num2,num3,cat1,cat2 2 | 15.0,,0,Cat,True 3 | 40.0,0.25,2,Dog,False 4 | 200.0,1.456,0,Cat,True 5 | ,45.896,1,Dog,False 6 | 60.0,2.67,2,, 7 | 48.0,9.748,1,Dog,False 8 | 1000.0,,1,Cat,True 9 | 43.0,1.67,2,Dog,False 10 | 1.0,0.00054,0,Cat,True 11 | 0.0,12.1,1,Dog, 12 | -------------------------------------------------------------------------------- /whitebox/analytics/data/testing/regression_test_data.csv: -------------------------------------------------------------------------------- 1 | y_test,y_prediction 2 | 0.11,0.12 3 | 0.56,0.55 4 | 0.43,0.43 5 | 0.77,0.51 6 | 0.23,0.2 7 | 0.54,0.54 8 | 0.48,0.47 9 | 0.13,0.13 10 | 0.01,0.0 11 | 0.86,0.9 12 | -------------------------------------------------------------------------------- /whitebox/analytics/drift/pipelines.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | from evidently.report import Report 4 | from evidently.metric_preset import 
DataDriftPreset, TargetDriftPreset
5 | from whitebox.schemas.driftingMetric import DataDriftTable, ConceptDriftTable
6 |
7 |
8 | def run_data_drift_pipeline(
9 |     reference_dataset: pd.DataFrame, current_dataset: pd.DataFrame
10 | ) -> DataDriftTable:
11 |     """
12 |     Two datasets are needed.
13 |     The reference dataset serves as a benchmark.
14 |     The analysis compares the current production data against that benchmark.
15 |
16 |     Both datasets should include the features to be evaluated for drift.
17 |     The schema of both datasets should be identical.
18 |     - In the case of a pandas DataFrame, all column names should be strings.
19 |     - All feature columns analyzed for drift should have a numerical type (np.number).
20 |     - Categorical data can be encoded as numerical labels and specified in the column_mapping.
21 |     - A DateTime column is the only exception. If available, it can be used as the x-axis in the plots.
22 |
23 |     Potentially, any two datasets can be compared, but only the reference dataset
24 |     is used as the basis for comparison.
25 |
26 |     How it works
27 |
28 |     To estimate data drift, Evidently compares the distributions of each feature in the two datasets
29 |     and applies statistical tests to detect whether a distribution has changed significantly. There is a default
30 |     logic for choosing the appropriate statistical test based on:
31 |     - feature type: categorical or numerical
32 |     - the number of observations in the reference dataset
33 |     - the number of unique values in the feature (n_unique)
34 |
35 |     For small data with <= 1000 observations in the reference dataset:
36 |     - For numerical features (n_unique > 5): two-sample Kolmogorov-Smirnov test.
37 |     - For categorical features or numerical features with n_unique <= 5: chi-squared test.
38 |     - For binary categorical features (n_unique <= 2): the proportion difference test for independent samples based on Z-score.
39 |
40 |     All tests use a 0.95 confidence level by default.
41 |
42 |     For larger data with > 1000 observations in the reference dataset:
43 |     - For numerical features (n_unique > 5): Wasserstein distance.
44 |     - For categorical features or numerical features with n_unique <= 5: Jensen–Shannon divergence.
45 |
46 |     All tests use a threshold = 0.1 by default.
47 |
48 |     """
49 |     drift_report = Report(metrics=[DataDriftPreset()])
50 |     drift_report.run(reference_data=reference_dataset, current_data=current_dataset)
51 |
52 |     initial_report = drift_report.json()
53 |     initial_report = json.loads(initial_report)
54 |
55 |     data_drift_report = {}
56 |     data_drift_report["drift_summary"] = initial_report["metrics"][1]["result"]
57 |
58 |     return DataDriftTable(**data_drift_report["drift_summary"])
59 |
60 |
61 | def run_concept_drift_pipeline(
62 |     reference_dataset: pd.DataFrame, current_dataset: pd.DataFrame, target_feature: str
63 | ) -> ConceptDriftTable:
64 |     """
65 |     To estimate categorical target drift, we compare the distribution of the target in the two datasets.
66 |     This solution works for both binary and multi-class classification.
67 |     Because the underlying report expects a fixed column name, the target column is explicitly
68 |     renamed to "target" before the report is run.
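    Example (illustrative only; the DataFrames and the "label" column below are invented,
    and the rename above mutates the inputs in place):

        import pandas as pd
        from whitebox.analytics.drift.pipelines import run_concept_drift_pipeline

        reference = pd.DataFrame({"f1": [0.1, 0.4, 0.9, 0.3], "label": [0, 1, 1, 0]})
        current = pd.DataFrame({"f1": [0.2, 0.8, 0.7, 0.9], "label": [1, 1, 0, 1]})
        report = run_concept_drift_pipeline(reference, current, target_feature="label")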
69 | 70 | There is a default logic to choosing the appropriate statistical test, based on: 71 | - the number of observations in the reference dataset 72 | - the number of unique values in the target (n_unique) 73 | 74 | For small data with <= 1000 observations in the reference dataset: 75 | - For categorical target with n_unique > 2: chi-squared test. 76 | - For binary categorical target (n_unique <= 2), we use the proportion difference test for independent samples based on Z-score. 77 | 78 | All tests use a 0.95 confidence level by default. 79 | 80 | For larger data with > 1000 observations in the reference dataset we use Jensen–Shannon divergence with a threshold = 0.1. 81 | 82 | """ 83 | reference_dataset.rename(columns={target_feature: "target"}, inplace=True) 84 | current_dataset.rename(columns={target_feature: "target"}, inplace=True) 85 | drift_report = Report(metrics=[TargetDriftPreset()]) 86 | drift_report.run(reference_data=reference_dataset, current_data=current_dataset) 87 | initial_report = drift_report.json() 88 | initial_report = json.loads(initial_report) 89 | concept_drift_report = {} 90 | concept_drift_report["concept_drift_summary"] = initial_report["metrics"][0][ 91 | "result" 92 | ] 93 | concept_drift_report["column_correlation"] = initial_report["metrics"][1]["result"] 94 | 95 | return ConceptDriftTable( 96 | concept_drift_summary=concept_drift_report["concept_drift_summary"], 97 | column_correlation=concept_drift_report["column_correlation"], 98 | ) 99 | -------------------------------------------------------------------------------- /whitebox/analytics/metrics/functions.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import multilabel_confusion_matrix 2 | import pandas as pd 3 | from typing import Dict, Union, List 4 | 5 | 6 | def format_feature_metrics( 7 | missing_count: Dict[str, int], 8 | non_missing_count: Dict[str, int], 9 | mean: Dict[str, float], 10 | minimum: Dict[str, float], 11 | maximum: Dict[str, float], 12 | sum: Dict[str, float], 13 | standard_deviation: Dict[str, float], 14 | variance: Dict[str, float], 15 | ) -> Dict[str, Union[int, float]]: 16 | formated_metrics = { 17 | "missing_count": missing_count, 18 | "non_missing_count": non_missing_count, 19 | "mean": mean, 20 | "minimum": minimum, 21 | "maximum": maximum, 22 | "sum": sum, 23 | "standard_deviation": standard_deviation, 24 | "variance": variance, 25 | } 26 | 27 | return formated_metrics 28 | 29 | 30 | def format_evaluation_metrics_binary( 31 | accuracy: float, 32 | precision: float, 33 | recall: float, 34 | f1: float, 35 | tn: int, 36 | fp: int, 37 | fn: int, 38 | tp: int, 39 | ) -> Dict[str, Union[int, float]]: 40 | formated_metrics_for_binary = { 41 | "accuracy": accuracy, 42 | "precision": precision, 43 | "recall": recall, 44 | "f1": f1, 45 | "true_negative": tn, 46 | "false_positive": fp, 47 | "false_negative": fn, 48 | "true_positive": tp, 49 | } 50 | 51 | return formated_metrics_for_binary 52 | 53 | 54 | def format_evaluation_metrics_multiple( 55 | accuracy: float, 56 | precision_statistics: Dict[str, float], 57 | recall_statistics: Dict[str, float], 58 | f1_statistics: Dict[str, float], 59 | conf_matrix: Dict[str, Dict[str, int]], 60 | ) -> Dict[str, Union[float, Dict[str, Union[int, float]]]]: 61 | formated_metrics_for_multiple = { 62 | "accuracy": accuracy, 63 | "precision": precision_statistics, 64 | "recall": recall_statistics, 65 | "f1": f1_statistics, 66 | "confusion_matrix": conf_matrix, 67 | } 68 | 69 | return 
formated_metrics_for_multiple
70 |
71 |
72 | def format_evaluation_metrics_regression(
73 |     r_square: float,
74 |     mean_squared_error: float,
75 |     mean_absolute_error: float,
76 | ) -> Dict[str, Union[int, float]]:
77 |     formated_metrics_for_regression = {
78 |         "r_square": r_square,
79 |         "mean_squared_error": mean_squared_error,
80 |         "mean_absolute_error": mean_absolute_error,
81 |     }
82 |
83 |     return formated_metrics_for_regression
84 |
85 |
86 | def confusion_for_multiclass(
87 |     test_set: pd.DataFrame, prediction_set: pd.DataFrame, labels: List[int]
88 | ) -> Dict[str, Dict[str, int]]:
89 |     """
90 |     Takes ground-truth and predicted labels from a multiclass classification task and calculates
91 |     the corresponding per-class confusion matrix outputs: tn, fp, fn, tp.
92 |
93 |     Parameters
94 |     ----------
95 |     test_set : pd.DataFrame
96 |         Multiclass ground truth labels.
97 |
98 |     prediction_set : pd.DataFrame
99 |         Multiclass predicted labels.
100 |
101 |     Returns
102 |     -------
103 |     mult_dict : Dict
104 |         A dict keyed by class (e.g. "class0") holding that class's confusion matrix counts.
105 |     """
106 |     cm = multilabel_confusion_matrix(test_set, prediction_set, labels=labels)
107 |     mult_dict = {}
108 |     class_key = 0
109 |     for i in cm:
110 |         tn, fp, fn, tp = i.ravel()
111 |         eval_dict = {
112 |             "true_negative": tn,
113 |             "false_positive": fp,
114 |             "false_negative": fn,
115 |             "true_positive": tp,
116 |         }
117 |         mult_dict["class{}".format(class_key)] = eval_dict
118 |         class_key = class_key + 1
119 |     return mult_dict
120 |
--------------------------------------------------------------------------------
/whitebox/analytics/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/analytics/tests/__init__.py
--------------------------------------------------------------------------------
/whitebox/analytics/xai_models/pipelines.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from typing import Dict
3 | import joblib
4 | import lime
5 | import lime.lime_tabular
6 | from whitebox.analytics.models.pipelines import *
7 | from whitebox.core.settings import get_settings
8 | from whitebox.schemas.model import ModelType
9 |
10 |
11 | settings = get_settings()
12 |
13 |
14 | def create_xai_pipeline_per_inference_row(
15 |     training_set: pd.DataFrame,
16 |     target: str,
17 |     inference_row: pd.Series,
18 |     type_of_task: str,
19 |     model_id: str,
20 | ) -> Dict[str, float]:
21 |     model_base_path = settings.MODEL_PATH
22 |     model_path = f"{model_base_path}/{model_id}"
23 |
24 |     xai_dataset = training_set.drop(columns=[target])
25 |     explainability_report = {}
26 |
27 |     # Build a mapping dict that is used later to map the explainer's
28 |     # feature indices to feature names
29 |
30 |     mapping_dict = {}
31 |     for feature in range(0, len(xai_dataset.columns.tolist())):
32 |         mapping_dict[feature] = xai_dataset.columns.tolist()[feature]
33 |
34 |     # Explainability for both classification tasks and regression.
35 |     # TODO: revisit this in the future; when the model is loaded from the
36 |     # file system it makes no difference whether the task is binary or multiclass.
37 |
38 |     if type_of_task == ModelType.multi_class:
39 |         # Giving the option of retrieving the local model
40 |
41 |         model = joblib.load(f"{model_path}/lgb_multi.pkl")
42 |         explainer = lime.lime_tabular.LimeTabularExplainer(
43 |             xai_dataset.values,
44 |             feature_names=xai_dataset.columns.values.tolist(),
45 |             mode="classification",
46 |             random_state=1,
47 |         )
48 |
49 |         exp =
explainer.explain_instance(inference_row, model.predict) 50 | med_report = exp.as_map() 51 | temp_dict = dict(list(med_report.values())[0]) 52 | explainability_report = { 53 | mapping_dict[name]: val for name, val in temp_dict.items() 54 | } 55 | 56 | elif type_of_task == ModelType.binary: 57 | # Giving the option of retrieving the local model 58 | 59 | model = joblib.load(f"{model_path}/lgb_binary.pkl") 60 | explainer = lime.lime_tabular.LimeTabularExplainer( 61 | xai_dataset.values, 62 | feature_names=xai_dataset.columns.values.tolist(), 63 | mode="classification", 64 | random_state=1, 65 | ) 66 | 67 | exp = explainer.explain_instance(inference_row, model.predict_proba) 68 | med_report = exp.as_map() 69 | temp_dict = dict(list(med_report.values())[0]) 70 | explainability_report = { 71 | mapping_dict[name]: val for name, val in temp_dict.items() 72 | } 73 | 74 | elif type_of_task == ModelType.regression: 75 | # Giving the option of retrieving the local model 76 | 77 | model = joblib.load(f"{model_path}/lgb_reg.pkl") 78 | explainer = lime.lime_tabular.LimeTabularExplainer( 79 | xai_dataset.values, 80 | feature_names=xai_dataset.columns.values.tolist(), 81 | mode="regression", 82 | random_state=1, 83 | ) 84 | 85 | exp = explainer.explain_instance(inference_row, model.predict) 86 | med_report = exp.as_map() 87 | temp_dict = dict(list(med_report.values())[0]) 88 | explainability_report = { 89 | mapping_dict[name]: val for name, val in temp_dict.items() 90 | } 91 | 92 | return explainability_report 93 | -------------------------------------------------------------------------------- /whitebox/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/api/__init__.py -------------------------------------------------------------------------------- /whitebox/api/v1/__init__.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from .health import health_router 3 | 4 | from .models import models_router 5 | from .dataset_rows import dataset_rows_router 6 | from .inference_rows import inference_rows_router 7 | from .performance_metrics import performance_metrics_router 8 | from .drifting_metrics import drifting_metrics_router 9 | from .model_integrity_metrics import model_integrity_metrics_router 10 | from .model_monitors import model_monitors_router 11 | from .alerts import alerts_router 12 | from .cron_tasks import cron_tasks_router 13 | 14 | 15 | v1_router = APIRouter() 16 | v1 = "/v1" 17 | 18 | v1_router.include_router(health_router, prefix=v1) 19 | v1_router.include_router(models_router, prefix=v1) 20 | v1_router.include_router(dataset_rows_router, prefix=v1) 21 | v1_router.include_router(inference_rows_router, prefix=v1) 22 | v1_router.include_router(performance_metrics_router, prefix=v1) 23 | v1_router.include_router(drifting_metrics_router, prefix=v1) 24 | v1_router.include_router(model_integrity_metrics_router, prefix=v1) 25 | v1_router.include_router(model_monitors_router, prefix=v1) 26 | v1_router.include_router(alerts_router, prefix=v1) 27 | v1_router.include_router(cron_tasks_router, prefix=v1) 28 | -------------------------------------------------------------------------------- /whitebox/api/v1/alerts.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from fastapi import APIRouter, Depends, status 3 | 
from whitebox import crud
4 | from sqlalchemy.orm import Session
5 | from whitebox.core.db import get_db
6 | from whitebox.middleware.auth import authenticate_user
7 | from whitebox.schemas.alert import Alert
8 | from whitebox.schemas.user import User
9 | from whitebox.utils.errors import add_error_responses, errors
10 |
11 |
12 | alerts_router = APIRouter()
13 |
14 |
15 | @alerts_router.get(
16 |     "/alerts",
17 |     tags=["Alerts"],
18 |     response_model=List[Alert],
19 |     summary="Get all model's alerts",
20 |     status_code=status.HTTP_200_OK,
21 |     responses=add_error_responses([401, 404]),
22 | )
23 | async def get_alerts(
24 |     model_id: Union[str, None] = None,
25 |     db: Session = Depends(get_db),
26 |     authenticated_user: User = Depends(authenticate_user),
27 | ):
28 |     """
29 |     Fetches alerts from the database.
30 |     \n If a model id is provided, only the alerts for that model are fetched.
31 |     \n If a model id is not provided, all alerts in the database are fetched.
32 |     """
33 |
34 |     if model_id:
35 |         model = crud.models.get(db, model_id)
36 |         if model:
37 |             return crud.alerts.get_model_alerts_by_model(db=db, model_id=model_id)
38 |         else:
39 |             return errors.not_found("Model not found")
40 |     else:
41 |         return crud.alerts.get_all(db=db)
42 |
--------------------------------------------------------------------------------
/whitebox/api/v1/cron_tasks.py:
--------------------------------------------------------------------------------
1 | from whitebox.schemas.utils import HealthCheck
2 | from fastapi import APIRouter, status
3 | from whitebox.cron_tasks.monitoring_metrics import run_calculate_metrics_pipeline
4 | from whitebox.cron_tasks.monitoring_alerts import run_create_alerts_pipeline
5 |
6 |
7 | cron_tasks_router = APIRouter()
8 |
9 |
10 | @cron_tasks_router.post(
11 |     "/cron-tasks/run",
12 |     tags=["Cron Tasks"],
13 |     summary="Helper endpoint",
14 |     status_code=status.HTTP_200_OK,
15 |     response_description="Result of cron tasks",
16 | )
17 | async def run_cron():
18 |     """A helper endpoint that triggers the metrics and alerts pipelines while testing."""
19 |
20 |     await run_calculate_metrics_pipeline()
21 |     await run_create_alerts_pipeline()
22 |     return HealthCheck(status="OK")
23 |
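For orientation, this helper can be exercised over HTTP once the service is up. A minimal sketch follows; the port matches the compose/deployment configs earlier in this repository, but the "api-key" header name is an assumption (see whitebox/middleware/auth.py for the actual scheme):

import requests

# Hypothetical key and header name; only the path and port are confirmed by this repo.
resp = requests.post(
    "http://localhost:8000/v1/cron-tasks/run",
    headers={"api-key": "<your-api-key>"},
)
print(resp.json())  # expected: {"status": "OK"}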
--------------------------------------------------------------------------------
/whitebox/api/v1/dataset_rows.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from typing import List
3 | from whitebox.analytics.models.pipelines import (
4 |     create_binary_classification_training_model_pipeline,
5 |     create_multiclass_classification_training_model_pipeline,
6 |     create_regression_training_model_pipeline,
7 | )
8 | from whitebox.middleware.auth import authenticate_user
9 | from whitebox.schemas.datasetRow import DatasetRow, DatasetRowCreate
10 | from whitebox.schemas.model import ModelType
11 | from whitebox.schemas.user import User
12 | from fastapi import APIRouter, BackgroundTasks, Depends, status
13 | from fastapi.encoders import jsonable_encoder
14 | from whitebox import crud
15 | from sqlalchemy.orm import Session
16 | from whitebox.core.db import get_db
17 | from whitebox.utils.errors import add_error_responses, errors
18 |
19 |
20 | dataset_rows_router = APIRouter()
21 |
22 |
23 | @dataset_rows_router.post(
24 |     "/dataset-rows",
25 |     tags=["Dataset Rows"],
26 |     response_model=List[DatasetRow],
27 |     summary="Create dataset rows",
28 |     status_code=status.HTTP_201_CREATED,
29 |     responses=add_error_responses([400, 401, 404, 409]),
30 | )
31 | async def create_dataset_rows(
32 |     body: List[DatasetRowCreate],
33 |     background_tasks: BackgroundTasks,
34 |     db: Session = Depends(get_db),
35 |     authenticated_user: User = Depends(authenticate_user),
36 | ) -> DatasetRow:
37 |     """
38 |     Inserts a set of dataset rows into the database.
39 |     \nWhen the dataset rows are successfully saved, the pipeline for training the model is triggered.
40 |     """
41 |
42 |     if len(body) <= 1:
43 |         return errors.bad_request("Training dataset must contain more than 1 row!")
44 |
45 |     model = crud.models.get(db=db, _id=dict(body[0])["model_id"])
46 |     if model:
47 |         for row in body:
48 |             if not model.target_column in row.processed:
49 |                 return errors.bad_request(
50 |                     f'Column "{model.target_column}" is missing from one or more rows of the provided training dataset. Please try again!'
51 |                 )
52 |
53 |         predictions = list(set(vars(x)["processed"][model.target_column] for x in body))
54 |         if len(predictions) <= 1:
55 |             return errors.bad_request(
56 |                 f'Training dataset\'s "{model.target_column}" column must contain at least 2 distinct values!'
57 |             )
58 |
59 |         new_dataset_rows = crud.dataset_rows.create_many(db=db, obj_list=body)
60 |         processed_dataset_rows = [
61 |             x["processed"] for x in jsonable_encoder(new_dataset_rows)
62 |         ]
63 |         processed_dataset_rows_pd = pd.DataFrame(processed_dataset_rows)
64 |
65 |         if model.type == ModelType.binary:
66 |             background_tasks.add_task(
67 |                 create_binary_classification_training_model_pipeline,
68 |                 processed_dataset_rows_pd,
69 |                 model.target_column,
70 |                 model.id,
71 |             )
72 |         elif model.type == ModelType.multi_class:
73 |             background_tasks.add_task(
74 |                 create_multiclass_classification_training_model_pipeline,
75 |                 processed_dataset_rows_pd,
76 |                 model.target_column,
77 |                 model.id,
78 |             )
79 |         elif model.type == ModelType.regression:
80 |             background_tasks.add_task(
81 |                 create_regression_training_model_pipeline,
82 |                 processed_dataset_rows_pd,
83 |                 model.target_column,
84 |                 model.id,
85 |             )
86 |         return new_dataset_rows
87 |     else:
88 |         return errors.not_found(f"Model with id: {dict(body[0])['model_id']} not found")
89 |
90 |
91 | @dataset_rows_router.get(
92 |     "/dataset-rows",
93 |     tags=["Dataset Rows"],
94 |     response_model=List[DatasetRow],
95 |     summary="Get all model's dataset rows",
96 |     status_code=status.HTTP_200_OK,
97 |     responses=add_error_responses([401, 404]),
98 | )
99 | async def get_all_dataset_rows(
100 |     model_id: str,
101 |     db: Session = Depends(get_db),
102 |     authenticated_user: User = Depends(authenticate_user),
103 | ):
104 |     """Fetches the dataset rows of a specific model. A model id is required."""
105 |
106 |     model = crud.models.get(db, model_id)
107 |     if model:
108 |         return crud.dataset_rows.get_dataset_rows_by_model(db=db, model_id=model_id)
109 |     else:
110 |         return errors.not_found("Model not found")
111 |
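Tying the validation above together, a hypothetical request that satisfies it: more than one row, the target column present in every row, and at least two distinct target values. Only the fields referenced by the handler (model_id, processed) are shown, "target" stands in for the model's configured target_column, and the "api-key" header name is an assumption; consult whitebox/schemas/datasetRow.py for the full DatasetRowCreate shape:

import requests

rows = [
    {"model_id": "<model-id>", "processed": {"feature_a": 0.1, "target": 0}},
    {"model_id": "<model-id>", "processed": {"feature_a": 0.9, "target": 1}},
]
resp = requests.post(
    "http://localhost:8000/v1/dataset-rows",
    json=rows,
    headers={"api-key": "<your-api-key>"},
)
print(resp.status_code)  # 201 on success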
--------------------------------------------------------------------------------
/whitebox/api/v1/docs.py:
--------------------------------------------------------------------------------
1 | from whitebox.schemas.utils import ErrorResponse
2 |
3 |
4 | tags_metadata = [
5 |     {
6 |         "name": "Health",
7 |         "description": "Health endpoints are used for checking the status of the service",
8 |     },
9 |     {
10 |         "name": "Models",
11 |         "description": "This set of endpoints handles the models that a user creates.",
12 |     },
13 |     {
14 |         "name": "Dataset Rows",
15 |         "description": "This set of endpoints handles the dataset rows.",
16 |     },
17 |     {
18 |         "name": "Inference Rows",
19 |         "description": "This set of endpoints handles a model's inference rows.",
20 |     },
21 |     {
22 |         "name": "Performance Metrics",
23 |         "description": "This set of endpoints handles a model's performance metrics.",
24 |     },
25 |     {
26 |         "name": "Drifting Metrics",
27 |         "description": "This set of endpoints handles a model's drifting metrics.",
28 |     },
29 |     {
30 |         "name": "Model Integrity Metrics",
31 |         "description": "This set of endpoints handles a model's integrity metrics.",
32 |     },
33 |     {
34 |         "name": "Model Monitors",
35 |         "description": "This set of endpoints handles a model's model monitors.",
36 |     },
37 |     {
38 |         "name": "Alerts",
39 |         "description": "This set of endpoints handles a model's alerts.",
40 |     },
41 |     {
42 |         "name": "Cron Tasks",
43 |         "description": "This is a helper endpoint to trigger cron tasks for tests.",
44 |     },
45 | ]
46 |
47 |
48 | bad_request: ErrorResponse = {
49 |     "title": "BadRequest",
50 |     "type": "object",
51 |     "properties": {
52 |         "error": {"title": "Error Message", "type": "string"},
53 |         "status_code": {"title": "Status code", "type": "integer"},
54 |     },
55 | }
56 |
57 | validation_error: ErrorResponse = {
58 |     "title": "HTTPValidationError",
59 |     "type": "object",
60 |     "properties": {
61 |         "error": {"title": "Error Message", "type": "string"},
62 |         "status_code": {"title": "Status code", "type": "integer"},
63 |     },
64 | }
65 |
66 | authorization_error: ErrorResponse = {
67 |     "title": "AuthorizationError",
68 |     "type": "object",
69 |     "properties": {
70 |         "error": {"title": "Error Message", "type": "string"},
71 |         "status_code": {"title": "Status code", "type": "integer"},
72 |     },
73 | }
74 |
75 | not_found_error: ErrorResponse = {
76 |     "title": "NotFoundError",
77 |     "type": "object",
78 |     "properties": {
79 |         "error": {"title": "Error Message", "type": "string"},
80 |         "status_code": {"title": "Status code", "type": "integer"},
81 |     },
82 | }
83 |
84 | conflict_error: ErrorResponse = {
85 |     "title": "ConflictError",
86 |     "type": "object",
87 |     "properties": {
88 |         "error": {"title": "Error Message", "type": "string"},
89 |         "status_code": {"title": "Status code", "type": "integer"},
90 |     },
91 | }
92 |
101 |
102 | content_gone: ErrorResponse = {
103 |     "title": "ContentGone",
104 |     "type": "object",
105 |     "properties": {
106 |         "error": {"title": "Error Message", "type": "string"},
107 |         "status_code": {"title":
"Status code", "type": "integer"}, 108 | }, 109 | } 110 | -------------------------------------------------------------------------------- /whitebox/api/v1/drifting_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from fastapi import APIRouter, Depends, status 3 | from whitebox import crud 4 | from sqlalchemy.orm import Session 5 | from whitebox.core.db import get_db 6 | from whitebox.middleware.auth import authenticate_user 7 | from whitebox.schemas.driftingMetric import DriftingMetricBase 8 | from whitebox.schemas.user import User 9 | from whitebox.utils.errors import add_error_responses, errors 10 | 11 | 12 | drifting_metrics_router = APIRouter() 13 | 14 | 15 | @drifting_metrics_router.get( 16 | "/drifting-metrics", 17 | tags=["Drifting Metrics"], 18 | response_model=List[DriftingMetricBase], 19 | summary="Get all model's drifting metrics", 20 | status_code=status.HTTP_200_OK, 21 | responses=add_error_responses([401, 404]), 22 | ) 23 | async def get_all_models_drifting_metrics( 24 | model_id: str, 25 | db: Session = Depends(get_db), 26 | authenticated_user: User = Depends(authenticate_user), 27 | ): 28 | """Fetches the drifting metrics of a specific model. A model id is required.""" 29 | 30 | model = crud.models.get(db, model_id) 31 | if model: 32 | return crud.drifting_metrics.get_drifting_metrics_by_model( 33 | db=db, model_id=model_id 34 | ) 35 | else: 36 | return errors.not_found("Model not found") 37 | -------------------------------------------------------------------------------- /whitebox/api/v1/health.py: -------------------------------------------------------------------------------- 1 | from whitebox.schemas.utils import HealthCheck 2 | from fastapi import APIRouter, status 3 | 4 | health_router = APIRouter() 5 | 6 | 7 | @health_router.get( 8 | "/health", 9 | tags=["Health"], 10 | response_model=HealthCheck, 11 | summary="Health check the service", 12 | status_code=status.HTTP_200_OK, 13 | response_description="Status of the service", 14 | ) 15 | def health_check(): 16 | """Responds with the status of the service.""" 17 | return HealthCheck(status="OK") 18 | -------------------------------------------------------------------------------- /whitebox/api/v1/model_integrity_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from fastapi import APIRouter, Depends, status 3 | from whitebox import crud 4 | from sqlalchemy.orm import Session 5 | from whitebox.core.db import get_db 6 | from whitebox.middleware.auth import authenticate_user 7 | from whitebox.schemas.modelIntegrityMetric import ModelIntegrityMetric 8 | from whitebox.schemas.user import User 9 | from whitebox.utils.errors import add_error_responses, errors 10 | 11 | 12 | model_integrity_metrics_router = APIRouter() 13 | 14 | 15 | @model_integrity_metrics_router.get( 16 | "/model-integrity-metrics", 17 | tags=["Model Integrity Metrics"], 18 | response_model=List[ModelIntegrityMetric], 19 | summary="Get all model's model integrity metrics", 20 | status_code=status.HTTP_200_OK, 21 | responses=add_error_responses([401, 404]), 22 | ) 23 | async def get_all_models_model_integrity_metrics( 24 | model_id: str, 25 | db: Session = Depends(get_db), 26 | authenticated_user: User = Depends(authenticate_user), 27 | ): 28 | """Fetches the model integrity metrics of a specific model. 
A model id is required.""" 29 | 30 | model = crud.models.get(db, model_id) 31 | if model: 32 | return crud.model_integrity_metrics.get_model_integrity_metrics_by_model( 33 | db=db, model_id=model_id 34 | ) 35 | else: 36 | return errors.not_found("Model not found") 37 | -------------------------------------------------------------------------------- /whitebox/api/v1/models.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from whitebox.middleware.auth import authenticate_user 3 | from whitebox.schemas.model import Model, ModelCreateDto, ModelUpdateDto 4 | from fastapi import APIRouter, Depends, status 5 | from whitebox import crud 6 | from sqlalchemy.orm import Session 7 | from whitebox.core.db import get_db 8 | from whitebox.schemas.utils import StatusCode 9 | from whitebox.schemas.user import User 10 | from whitebox.utils.errors import add_error_responses, errors 11 | 12 | 13 | models_router = APIRouter() 14 | 15 | 16 | @models_router.post( 17 | "/models", 18 | tags=["Models"], 19 | response_model=Model, 20 | summary="Create model", 21 | status_code=status.HTTP_201_CREATED, 22 | responses=add_error_responses([400, 401]), 23 | ) 24 | async def create_model( 25 | body: ModelCreateDto, 26 | db: Session = Depends(get_db), 27 | authenticated_user: User = Depends(authenticate_user), 28 | ) -> Model: 29 | """Inserts a model into the database""" 30 | 31 | granularity = body.granularity 32 | 33 | try: 34 | granularity_amount = float(granularity[:-1]) 35 | except ValueError: 36 | return errors.bad_request("Granularity amount that was given is not a number!") 37 | 38 | if not granularity_amount.is_integer(): 39 | return errors.bad_request( 40 | "Granularity amount should be an integer and not a float (e.g. 1D)!" 41 | ) 42 | 43 | granularity_type = granularity[-1] 44 | if granularity_type not in ["T", "H", "D", "W"]: 45 | return errors.bad_request( 46 | "Wrong granularity type. 
Accepted values: T (minutes), H (hours), D (days), W (weeks)" 47 | ) 48 | 49 | new_model = crud.models.create(db=db, obj_in=body) 50 | return new_model 51 | 52 | 53 | @models_router.get( 54 | "/models", 55 | tags=["Models"], 56 | response_model=List[Model], 57 | summary="Get all models", 58 | status_code=status.HTTP_200_OK, 59 | responses=add_error_responses([401]), 60 | ) 61 | async def get_all_models( 62 | db: Session = Depends(get_db), authenticated_user: User = Depends(authenticate_user) 63 | ): 64 | """Fetches all models from the database""" 65 | 66 | models_in_db = crud.models.get_all(db=db) 67 | return models_in_db 68 | 69 | 70 | @models_router.get( 71 | "/models/{model_id}", 72 | tags=["Models"], 73 | response_model=Model, 74 | summary="Get model by id", 75 | status_code=status.HTTP_200_OK, 76 | responses=add_error_responses([401, 404]), 77 | ) 78 | async def get_model( 79 | model_id: str, 80 | db: Session = Depends(get_db), 81 | authenticated_user: User = Depends(authenticate_user), 82 | ): 83 | """Fetches the model with the specified id from the database""" 84 | 85 | model = crud.models.get(db=db, _id=model_id) 86 | 87 | if not model: 88 | return errors.not_found("Model not found") 89 | 90 | return model 91 | 92 | 93 | @models_router.put( 94 | "/models/{model_id}", 95 | tags=["Models"], 96 | response_model=Model, 97 | summary="Update model", 98 | status_code=status.HTTP_200_OK, 99 | responses=add_error_responses([400, 401, 404]), 100 | ) 101 | async def update_model( 102 | model_id: str, 103 | body: ModelUpdateDto, 104 | db: Session = Depends(get_db), 105 | authenticated_user: User = Depends(authenticate_user), 106 | ) -> Model: 107 | """Updates record of the model with the specified id""" 108 | 109 | # Remove all unset properties (with None values) from the update object 110 | filtered_body = {k: v for k, v in dict(body).items() if v is not None} 111 | 112 | model = crud.models.get(db=db, _id=model_id) 113 | 114 | if not model: 115 | return errors.not_found("Model not found") 116 | 117 | return crud.models.update(db=db, db_obj=model, obj_in=filtered_body) 118 | 119 | 120 | @models_router.delete( 121 | "/models/{model_id}", 122 | tags=["Models"], 123 | response_model=StatusCode, 124 | summary="Delete model", 125 | status_code=status.HTTP_200_OK, 126 | responses=add_error_responses([401, 404]), 127 | ) 128 | async def delete_model( 129 | model_id: str, 130 | db: Session = Depends(get_db), 131 | authenticated_user: User = Depends(authenticate_user), 132 | ) -> StatusCode: 133 | """Deletes the model with the specified id from the database""" 134 | 135 | model = crud.models.get(db=db, _id=model_id) 136 | if not model: 137 | return errors.not_found("Model not found") 138 | 139 | crud.models.remove(db=db, _id=model_id) 140 | return {"status_code": status.HTTP_200_OK} 141 | -------------------------------------------------------------------------------- /whitebox/api/v1/performance_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from fastapi import APIRouter, Depends, status 3 | from whitebox import crud 4 | from sqlalchemy.orm import Session 5 | from whitebox.core.db import get_db 6 | from whitebox.middleware.auth import authenticate_user 7 | from whitebox.schemas.performanceMetric import ( 8 | BinaryClassificationMetrics, 9 | MultiClassificationMetrics, 10 | RegressionMetrics, 11 | ) 12 | from whitebox.schemas.user import User 13 | from whitebox.schemas.model import ModelType 14 | from 
whitebox.utils.errors import add_error_responses, errors 15 | 16 | 17 | performance_metrics_router = APIRouter() 18 | 19 | 20 | @performance_metrics_router.get( 21 | "/performance-metrics", 22 | tags=["Performance Metrics"], 23 | response_model=Union[ 24 | List[BinaryClassificationMetrics], 25 | List[MultiClassificationMetrics], 26 | List[RegressionMetrics], 27 | ], 28 | summary="Get all model's performance metrics", 29 | status_code=status.HTTP_200_OK, 30 | responses=add_error_responses([401, 404]), 31 | ) 32 | async def get_all_models_performance_metrics( 33 | model_id: str, 34 | db: Session = Depends(get_db), 35 | authenticated_user: User = Depends(authenticate_user), 36 | ): 37 | """Fetches the performance metrics of a specific model. A model id is required.""" 38 | 39 | model = crud.models.get(db, model_id) 40 | if model: 41 | if model.type == ModelType.binary: 42 | return crud.binary_classification_metrics.get_performance_metrics_by_model( 43 | db=db, model_id=model_id 44 | ) 45 | elif model.type == ModelType.multi_class: 46 | return crud.multi_classification_metrics.get_performance_metrics_by_model( 47 | db=db, model_id=model_id 48 | ) 49 | elif model.type == ModelType.regression: 50 | return crud.regression_metrics.get_performance_metrics_by_model( 51 | db=db, model_id=model_id 52 | ) 53 | else: 54 | return errors.not_found("Model not found") 55 | -------------------------------------------------------------------------------- /whitebox/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/core/__init__.py -------------------------------------------------------------------------------- /whitebox/core/db.py: -------------------------------------------------------------------------------- 1 | from whitebox.core.settings import get_settings 2 | import databases 3 | import sqlalchemy 4 | from sqlalchemy.orm import sessionmaker 5 | from whitebox.entities.Base import Base 6 | from whitebox.schemas.user import UserCreateDto 7 | 8 | from whitebox import crud 9 | from whitebox.utils.passwords import encrypt_api_key 10 | from whitebox.utils.logger import cronLogger as logger 11 | 12 | from secrets import token_hex 13 | 14 | settings = get_settings() 15 | database = databases.Database(settings.DATABASE_URL) 16 | engine = sqlalchemy.create_engine(settings.DATABASE_URL) 17 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 18 | 19 | 20 | def get_db(): 21 | db = SessionLocal() 22 | try: 23 | yield db 24 | finally: 25 | db.close() 26 | 27 | 28 | async def connect(): 29 | """ 30 | Connect to DB 31 | """ 32 | Base.metadata.create_all(engine) 33 | db = SessionLocal() 34 | 35 | admin_exists = crud.users.get_first_by_filter(db=db, username="admin") 36 | if not admin_exists: 37 | plain_api_key = token_hex(32) 38 | secret_key = settings.SECRET_KEY 39 | api_key = ( 40 | encrypt_api_key(plain_api_key, secret_key.encode()) 41 | if secret_key 42 | else plain_api_key 43 | ) 44 | 45 | obj_in = UserCreateDto(username="admin", api_key=api_key) 46 | crud.users.create(db=db, obj_in=obj_in) 47 | logger.info(f"Created username: admin, API key: {plain_api_key}") 48 | await database.connect() 49 | 50 | 51 | async def close(): 52 | """ 53 | Close DB Connection 54 | """ 55 | await database.disconnect() 56 | # logging.info("Closed connection with DB") 57 | -------------------------------------------------------------------------------- 
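The bootstrap above only stores an encrypted key when SECRET_KEY is set; otherwise the plain key is stored as-is. A minimal roundtrip sketch, reusing the exact call shapes that appear in connect() and in scripts/decrypt_api_key.py (the roundtrip equality is the property being illustrated, not a documented API guarantee):

from secrets import token_hex

from whitebox.utils.passwords import encrypt_api_key, decrypt_api_key

secret_key = token_hex(16)     # stands in for settings.SECRET_KEY (32 hex chars, as in .env.test)
plain_api_key = token_hex(32)  # same generation as in connect()

encrypted = encrypt_api_key(plain_api_key, secret_key.encode())
assert decrypt_api_key(encrypted, secret_key.encode()) == plain_api_key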
/whitebox/core/settings.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | from pydantic import BaseSettings
3 | import os
4 |
5 |
6 | class Settings(BaseSettings):
7 |     APP_NAME: str = ""
8 |     ENV: str = ""
9 |     DATABASE_URL: str = ""
10 |     VERSION: str = ""
11 |     MODEL_PATH: str = ""
12 |     SECRET_KEY: str = ""
13 |     GRANULARITY: str = ""
14 |
15 |     class Config:
16 |         env_file = f".env.{os.getenv('ENV', 'dev')}"  # fall back to .env.dev when ENV is unset
17 |
18 |
19 | @lru_cache()
20 | def get_settings():
21 |     return Settings()
22 |
23 |
24 | class CronSettings(Settings):
25 |     APP_NAME_CRON: str
26 |     METRICS_CRON: str
27 |
28 |     class Config:
29 |         env_file = f".env.{os.getenv('ENV', 'dev')}"  # fall back to .env.dev when ENV is unset
30 |
31 |
32 | @lru_cache()
33 | def get_cron_settings():
34 |     return CronSettings()
35 |
--------------------------------------------------------------------------------
/whitebox/cron.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Depends
2 | import asyncio
3 | import json
4 | from whitebox.utils.logger import cronLogger as logger
5 |
6 | from whitebox.core.settings import get_cron_settings
7 | from whitebox.cron_tasks.tasks import task_manager
8 | from fastapi.openapi.utils import get_openapi
9 |
10 |
11 | settings = get_cron_settings()
12 | cron_app = FastAPI(title=settings.APP_NAME_CRON, redoc_url="/")
13 |
14 |
15 | @cron_app.on_event("startup")
16 | async def init():
17 |     # Start the task manager
18 |     asyncio.get_event_loop().create_task(task_manager.run())
19 |
20 |
21 | @cron_app.on_event("shutdown")
22 | async def shutdown():
23 |     logger.info("App is shutting down...")
24 |     logger.info("Task Manager is shutting down...")
25 |     await task_manager.shutdown()
26 |
27 |
28 | def app_openapi():
29 |     if cron_app.openapi_schema:
30 |         return cron_app.openapi_schema
31 |     openapi_schema = get_openapi(
32 |         title="Cron API", version=settings.VERSION, routes=cron_app.routes
33 |     )
34 |
35 |     cron_app.openapi_schema = openapi_schema
36 |     return cron_app.openapi_schema
37 |
38 |
39 | cron_app.openapi = app_openapi
40 |
--------------------------------------------------------------------------------
/whitebox/cron_tasks/monitoring_alerts.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import time
3 | from sqlalchemy import create_engine
4 | from sqlalchemy.orm import sessionmaker, Session
5 |
6 | from whitebox import crud, entities
7 | from whitebox.core.settings import get_settings
8 | from whitebox.cron_tasks.shared import (
9 |     get_all_models,
10 |     get_latest_drift_metrics_report,
11 |     get_latest_performance_metrics_report,
12 |     get_active_model_monitors,
13 | )
14 | from whitebox.schemas.model import Model, ModelType
15 | from whitebox.schemas.modelMonitor import ModelMonitor, MonitorMetrics
16 | from whitebox.utils.logger import cronLogger as logger
17 |
18 | settings = get_settings()
19 |
20 | engine = create_engine(settings.DATABASE_URL)
21 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
22 | db: Session = SessionLocal()
23 |
24 |
25 | async def run_create_performance_metric_alert_pipeline(
26 |     model: Model, monitor: ModelMonitor
27 | ):
28 |     """
29 |     Run the pipeline to find any alerts for a metric in the performance metrics.
30 |     If one is found, it is saved in the database.
31 |     """
32 |
33 |     last_performance_metrics_report = await get_latest_performance_metrics_report(
34 |         db, model
35 |     )
36 |
37 | if not last_performance_metrics_report: 38 | logger.info( 39 | f"No alert created for monitor: {monitor.id} because no performance report was found!" 40 | ) 41 | return 42 | 43 | # Performance metrics reports for not multi_class models have the same format: metric: float. 44 | # Same if the metric is accuracy. 45 | if ( 46 | model.type is not ModelType.multi_class 47 | or monitor.metric == MonitorMetrics.accuracy 48 | ): 49 | metric_value = vars(last_performance_metrics_report)[monitor.metric] 50 | else: 51 | metric_value = vars(last_performance_metrics_report)[monitor.metric]["weighted"] 52 | 53 | if metric_value < monitor.lower_threshold: 54 | new_alert = entities.Alert( 55 | model_id=model.id, 56 | model_monitor_id=monitor.id, 57 | timestamp=str(datetime.utcnow()), 58 | description=f"{monitor.metric} fell below the threshold of {monitor.lower_threshold} at value {metric_value}.", 59 | ) 60 | crud.alerts.create(db, obj_in=new_alert) 61 | logger.info(f"Created alert for monitor {monitor.id}!") 62 | 63 | 64 | async def run_create_drift_alert_pipeline(model: Model, monitor: ModelMonitor): 65 | """ 66 | Run the pipeline to find any alerts for a metric in drift metrics 67 | If one is found it is saved in the database 68 | """ 69 | 70 | last_drift_report = await get_latest_drift_metrics_report(db, model) 71 | 72 | if not last_drift_report: 73 | logger.info( 74 | f"No alert created for monitor: {monitor.id} because no drift report was found!" 75 | ) 76 | return 77 | 78 | if monitor.metric == MonitorMetrics.data_drift: 79 | drift_detected: bool = last_drift_report.data_drift_summary["drift_by_columns"][ 80 | monitor.feature 81 | ]["drift_detected"] 82 | else: 83 | drift_detected: bool = last_drift_report.concept_drift_summary[ 84 | "concept_drift_summary" 85 | ]["drift_detected"] 86 | 87 | if drift_detected: 88 | new_alert = entities.Alert( 89 | model_id=model.id, 90 | model_monitor_id=monitor.id, 91 | timestamp=str(datetime.utcnow()), 92 | description=f'{monitor.metric.capitalize().replace("_", " ")} found in "{monitor.feature}" feature.', 93 | ) 94 | crud.alerts.create(db, obj_in=new_alert) 95 | logger.info(f"Created alert for monitor {monitor.id}!") 96 | 97 | 98 | async def run_create_alerts_pipeline(): 99 | logger.info("Beginning Alerts pipeline for all models!") 100 | start = time.time() 101 | engine.connect() 102 | 103 | models = await get_all_models(db) 104 | if not models: 105 | logger.info("No models found! 
Skipping pipeline") 106 | else: 107 | for model in models: 108 | model_monitors = await get_active_model_monitors(db, model_id=model.id) 109 | for monitor in model_monitors: 110 | if monitor.metric in [ 111 | MonitorMetrics.accuracy, 112 | MonitorMetrics.precision, 113 | MonitorMetrics.recall, 114 | MonitorMetrics.f1, 115 | MonitorMetrics.r_square, 116 | MonitorMetrics.mean_squared_error, 117 | MonitorMetrics.mean_absolute_error, 118 | ]: 119 | await run_create_performance_metric_alert_pipeline(model, monitor) 120 | elif ( 121 | monitor.metric == MonitorMetrics.data_drift 122 | or monitor.metric == MonitorMetrics.concept_drift 123 | ): 124 | await run_create_drift_alert_pipeline(model, monitor) 125 | 126 | db.close() 127 | end = time.time() 128 | logger.info("Alerts pipeline ended for all models!") 129 | logger.info("Runtime of Alerts pipeline took {}".format(end - start)) 130 | -------------------------------------------------------------------------------- /whitebox/cron_tasks/tasks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from whitebox.core.manager import get_task_manager 3 | from whitebox.cron_tasks.monitoring_metrics import run_calculate_metrics_pipeline 4 | from whitebox.cron_tasks.monitoring_alerts import run_create_alerts_pipeline 5 | 6 | task_manager = get_task_manager() 7 | 8 | metrics_cron = os.getenv("METRICS_CRON") or "0 12 * * *" 9 | 10 | task_manager.register( 11 | name="metrics_cron", 12 | async_callable=run_calculate_metrics_pipeline, 13 | crontab=metrics_cron, 14 | ) 15 | 16 | task_manager.register( 17 | name="alerts_cron", 18 | async_callable=run_create_alerts_pipeline, 19 | crontab=metrics_cron, 20 | ) 21 | -------------------------------------------------------------------------------- /whitebox/crud/__init__.py: -------------------------------------------------------------------------------- 1 | from .drifting_metrics import * 2 | from .dataset_rows import * 3 | from .inference_rows import * 4 | from .model_integrity_metrics import * 5 | from .model_monitors import * 6 | from .performance_metrics import * 7 | from .alerts import * 8 | from .users import * 9 | from .models import * 10 | from .base import * 11 | -------------------------------------------------------------------------------- /whitebox/crud/alerts.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from sqlalchemy.orm import Session 3 | from whitebox.crud.base import CRUDBase 4 | from whitebox.entities.Alert import Alert as AlertEntity 5 | from whitebox.schemas.alert import Alert 6 | 7 | 8 | class CRUD(CRUDBase[Alert, Any, Any]): 9 | def get_model_alerts_by_model(self, db: Session, *, model_id: str) -> List[Alert]: 10 | return db.query(self.model).filter(AlertEntity.model_id == model_id).all() 11 | 12 | 13 | alerts = CRUD(AlertEntity) 14 | -------------------------------------------------------------------------------- /whitebox/crud/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union 2 | from fastapi.encoders import jsonable_encoder 3 | from pydantic import BaseModel 4 | from sqlalchemy.orm import Session 5 | from whitebox.core.db import Base 6 | import datetime 7 | 8 | ModelType = TypeVar("ModelType", bound=Base) 9 | CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) 10 | UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) 11 | 12 | 
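For orientation before the class definition below: the modules in whitebox/crud bind this generic base to concrete entity and schema types. A sketch mirroring whitebox/crud/models.py further down (the entity and schema names come from that module; fetch_one is a hypothetical helper):

from sqlalchemy.orm import Session

from whitebox.crud.base import CRUDBase
from whitebox.entities.Model import Model as ModelEntity
from whitebox.schemas.model import Model, ModelCreateDto, ModelUpdateDto


class ModelCRUD(CRUDBase[Model, ModelCreateDto, ModelUpdateDto]):
    pass  # inherits get/get_all/create/update/remove unchanged


models = ModelCRUD(ModelEntity)


def fetch_one(db: Session, model_id: str):
    return models.get(db, model_id)  # returns the row or None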
13 | class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): 14 | def __init__(self, model: Type[ModelType]): 15 | """ 16 | CRUD object with default methods to Create, Read, Update, Delete (CRUD). 17 | 18 | **Parameters** 19 | 20 | * `model`: A SQLAlchemy model class 21 | * `schema`: A Pydantic model (schema) class 22 | """ 23 | self.model = model 24 | 25 | def get(self, db: Session, _id: str) -> Optional[ModelType]: 26 | return db.query(self.model).filter(self.model.id == _id).first() 27 | 28 | def get_all( 29 | self, db: Session, *, skip: int = 0, limit: int = 100 30 | ) -> List[ModelType]: 31 | return db.query(self.model).offset(skip).limit(limit).all() 32 | 33 | def get_first_by_filter(self, db: Session, **kwargs: Any) -> Optional[ModelType]: 34 | return db.query(self.model).filter_by(**kwargs).first() 35 | 36 | def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: 37 | date_now = datetime.datetime.utcnow() 38 | obj_in_data = jsonable_encoder(obj_in) 39 | db_obj = self.model(**obj_in_data, created_at=date_now, updated_at=date_now) 40 | db.add(db_obj) 41 | db.commit() 42 | db.refresh(db_obj) 43 | return db_obj 44 | 45 | def create_many( 46 | self, db: Session, *, obj_list: List[CreateSchemaType] 47 | ) -> List[ModelType]: 48 | date_now = datetime.datetime.utcnow() 49 | obj_list_in_data = jsonable_encoder(obj_list) 50 | db_obj_list = list( 51 | map( 52 | lambda x: self.model(**x, created_at=date_now, updated_at=date_now), 53 | obj_list_in_data, 54 | ) 55 | ) 56 | db.add_all(db_obj_list) 57 | db.commit() 58 | for obj in db_obj_list: 59 | db.refresh(obj) 60 | return db_obj_list 61 | 62 | def update( 63 | self, 64 | db: Session, 65 | *, 66 | db_obj: ModelType, 67 | obj_in: Union[UpdateSchemaType, Dict[str, Any]] 68 | ) -> ModelType: 69 | date_now = datetime.datetime.utcnow() 70 | obj_data = jsonable_encoder(db_obj) 71 | if isinstance(obj_in, dict): 72 | update_data = obj_in 73 | else: 74 | update_data = obj_in.dict(exclude_unset=True) 75 | for field in obj_data: 76 | if field in update_data: 77 | setattr(db_obj, field, update_data[field]) 78 | setattr(db_obj, "updated_at", date_now) 79 | db.commit() 80 | db.refresh(db_obj) 81 | return db_obj 82 | 83 | def remove(self, db: Session, *, _id: str): 84 | db.query(self.model).filter(self.model.id == _id).delete() 85 | db.commit() 86 | return 87 | -------------------------------------------------------------------------------- /whitebox/crud/dataset_rows.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from sqlalchemy.orm import Session 3 | from whitebox.crud.base import CRUDBase 4 | from whitebox.schemas.datasetRow import DatasetRow, DatasetRowCreate 5 | from whitebox.entities.DatasetRow import DatasetRow as DatasetRowEntity 6 | 7 | 8 | class CRUD(CRUDBase[DatasetRow, DatasetRowCreate, Any]): 9 | def get_dataset_rows_by_model( 10 | self, db: Session, *, model_id: str 11 | ) -> List[DatasetRow]: 12 | return db.query(self.model).filter(DatasetRowEntity.model_id == model_id).all() 13 | 14 | 15 | dataset_rows = CRUD(DatasetRowEntity) 16 | -------------------------------------------------------------------------------- /whitebox/crud/drifting_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from sqlalchemy.orm import Session 3 | from sqlalchemy import asc, desc 4 | from whitebox.crud.base import CRUDBase 5 | from whitebox.entities.DriftingMetric import 
DriftingMetric as DriftingMetricEntity 6 | from whitebox.schemas.driftingMetric import DriftingMetric 7 | 8 | 9 | class CRUD(CRUDBase[DriftingMetric, Any, Any]): 10 | def get_drifting_metrics_by_model( 11 | self, db: Session, *, model_id: str 12 | ) -> List[DriftingMetric]: 13 | return ( 14 | db.query(self.model) 15 | .filter(DriftingMetricEntity.model_id == model_id) 16 | .order_by(asc("timestamp")) 17 | .all() 18 | ) 19 | 20 | def get_latest_report_by_model( 21 | self, db: Session, *, model_id: int 22 | ) -> DriftingMetric: 23 | return ( 24 | db.query(self.model) 25 | .filter(self.model.model_id == model_id) 26 | .order_by(desc("created_at")) 27 | .first() 28 | ) 29 | 30 | 31 | drifting_metrics = CRUD(DriftingMetricEntity) 32 | -------------------------------------------------------------------------------- /whitebox/crud/inference_rows.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from datetime import datetime 3 | from sqlalchemy import and_ 4 | from sqlalchemy.orm import Session 5 | from whitebox.crud.base import CRUDBase 6 | from whitebox.schemas.inferenceRow import InferenceRow, InferenceRowPreDb 7 | from whitebox.entities.Inference import InferenceRow as InferenceRowEntity 8 | 9 | 10 | class CRUD(CRUDBase[InferenceRow, InferenceRowPreDb, Any]): 11 | def get_inference_rows_by_model( 12 | self, db: Session, *, model_id: str 13 | ) -> List[InferenceRow]: 14 | return ( 15 | db.query(self.model).filter(InferenceRowEntity.model_id == model_id).all() 16 | ) 17 | 18 | def get_unused_inference_rows( 19 | self, db: Session, *, model_id: str 20 | ) -> List[InferenceRow]: 21 | return ( 22 | db.query(self.model) 23 | .filter( 24 | InferenceRowEntity.model_id == model_id, 25 | InferenceRowEntity.is_used == False, 26 | ) 27 | .all() 28 | ) 29 | 30 | def get_inference_rows_betweet_dates( 31 | self, db: Session, *, model_id: str, min_date: datetime, max_date: datetime 32 | ) -> List[InferenceRow]: 33 | return ( 34 | db.query(self.model) 35 | .filter( 36 | and_( 37 | InferenceRowEntity.model_id == model_id, 38 | InferenceRowEntity.is_used == True, 39 | InferenceRowEntity.timestamp >= min_date, 40 | InferenceRowEntity.timestamp < max_date, 41 | ) 42 | ) 43 | .all() 44 | ) 45 | 46 | 47 | inference_rows = CRUD(InferenceRowEntity) 48 | -------------------------------------------------------------------------------- /whitebox/crud/model_integrity_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from sqlalchemy.orm import Session 3 | from sqlalchemy import asc 4 | from whitebox.crud.base import CRUDBase 5 | from whitebox.entities.ModelIntegrityMetric import ( 6 | ModelIntegrityMetric as ModelIntegrityMetricEntity, 7 | ) 8 | from whitebox.schemas.modelIntegrityMetric import ( 9 | ModelIntegrityMetricCreate, 10 | ModelIntegrityMetric, 11 | ) 12 | 13 | 14 | class CRUD(CRUDBase[ModelIntegrityMetric, ModelIntegrityMetricCreate, Any]): 15 | def get_model_integrity_metrics_by_model( 16 | self, db: Session, *, model_id: str 17 | ) -> List[ModelIntegrityMetric]: 18 | return ( 19 | db.query(self.model) 20 | .filter(ModelIntegrityMetricEntity.model_id == model_id) 21 | .order_by(asc("timestamp")) 22 | .all() 23 | ) 24 | 25 | 26 | model_integrity_metrics = CRUD(ModelIntegrityMetricEntity) 27 | -------------------------------------------------------------------------------- /whitebox/crud/model_monitors.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from whitebox.crud.base import CRUDBase 3 | from sqlalchemy.orm import Session 4 | from whitebox.schemas.modelMonitor import ( 5 | ModelMonitor, 6 | ModelMonitorCreateDto, 7 | MonitorStatus, 8 | ) 9 | from whitebox.entities.ModelMonitor import ModelMonitor as ModelMonitorEntity 10 | 11 | 12 | class CRUD(CRUDBase[ModelMonitor, ModelMonitorCreateDto, Any]): 13 | def get_model_monitors_by_model( 14 | self, db: Session, *, model_id: str 15 | ) -> List[ModelMonitor]: 16 | return ( 17 | db.query(self.model).filter(ModelMonitorEntity.model_id == model_id).all() 18 | ) 19 | 20 | def get_active_model_monitors_by_model( 21 | self, db: Session, *, model_id: str 22 | ) -> List[ModelMonitor]: 23 | return ( 24 | db.query(self.model) 25 | .filter( 26 | ModelMonitorEntity.model_id == model_id, 27 | ModelMonitorEntity.status == MonitorStatus.active, 28 | ) 29 | .all() 30 | ) 31 | 32 | 33 | model_monitors = CRUD(ModelMonitorEntity) 34 | -------------------------------------------------------------------------------- /whitebox/crud/models.py: -------------------------------------------------------------------------------- 1 | from whitebox.crud.base import CRUDBase 2 | from whitebox.schemas.model import Model, ModelCreateDto, ModelUpdateDto 3 | from whitebox.entities.Model import Model as ModelEntity 4 | 5 | 6 | class CRUD(CRUDBase[Model, ModelCreateDto, ModelUpdateDto]): 7 | pass 8 | 9 | 10 | models = CRUD(ModelEntity) 11 | -------------------------------------------------------------------------------- /whitebox/crud/performance_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | from sqlalchemy.orm import Session 3 | from sqlalchemy import asc, desc 4 | from whitebox.crud.base import CRUDBase 5 | from whitebox.entities.PerformanceMetric import ( 6 | BinaryClassificationMetrics as BinaryClassificationMetricsEntity, 7 | MultiClassificationMetrics as MultiClassificationMetricsEntity, 8 | RegressionMetrics as RegressionMetricsEntity, 9 | ) 10 | from whitebox.schemas.performanceMetric import ( 11 | BinaryClassificationMetrics, 12 | MultiClassificationMetrics, 13 | RegressionMetrics, 14 | ) 15 | 16 | 17 | class CRUD( 18 | CRUDBase[Union[BinaryClassificationMetrics, MultiClassificationMetrics], Any, Any] 19 | ): 20 | def get_performance_metrics_by_model( 21 | self, db: Session, *, model_id: int 22 | ) -> Union[ 23 | List[BinaryClassificationMetrics], 24 | List[MultiClassificationMetrics], 25 | List[RegressionMetrics], 26 | ]: 27 | return ( 28 | db.query(self.model) 29 | .filter(self.model.model_id == model_id) 30 | .order_by(asc("timestamp")) 31 | .all() 32 | ) 33 | 34 | def get_latest_report_by_model( 35 | self, db: Session, *, model_id: int 36 | ) -> Union[ 37 | BinaryClassificationMetrics, MultiClassificationMetrics, RegressionMetrics 38 | ]: 39 | return ( 40 | db.query(self.model) 41 | .filter(self.model.model_id == model_id) 42 | .order_by(desc("created_at")) 43 | .first() 44 | ) 45 | 46 | 47 | binary_classification_metrics = CRUD(BinaryClassificationMetricsEntity) 48 | multi_classification_metrics = CRUD(MultiClassificationMetricsEntity) 49 | regression_metrics = CRUD(RegressionMetricsEntity) 50 | -------------------------------------------------------------------------------- /whitebox/crud/users.py: -------------------------------------------------------------------------------- 1 | from typing import Any 
2 | from whitebox.crud.base import CRUDBase 3 | from whitebox.schemas.user import User, UserCreateDto 4 | from whitebox.entities.User import User as UserEntity 5 | 6 | 7 | class CRUD(CRUDBase[User, UserCreateDto, Any]): 8 | pass 9 | 10 | 11 | users = CRUD(UserEntity) 12 | -------------------------------------------------------------------------------- /whitebox/entities/Alert.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, ForeignKey, DateTime 2 | from whitebox.entities.Base import Base 3 | from whitebox.utils.id_gen import generate_uuid 4 | 5 | 6 | class Alert(Base): 7 | __tablename__ = "alerts" 8 | 9 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 10 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 11 | model_monitor_id = Column( 12 | String, ForeignKey("model_monitors.id", ondelete="CASCADE") 13 | ) 14 | timestamp = Column(DateTime) 15 | description = Column(String) 16 | created_at = Column(DateTime) 17 | updated_at = Column(DateTime) 18 | -------------------------------------------------------------------------------- /whitebox/entities/Base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from sqlalchemy.ext.declarative import as_declarative, declared_attr 3 | 4 | 5 | class_registry: Dict = {} 6 | 7 | 8 | @as_declarative(class_registry=class_registry) 9 | class Base: 10 | id: Any 11 | __name__: str 12 | 13 | # Generate __tablename__ automatically 14 | @declared_attr 15 | def __tablename__(cls) -> str: 16 | return cls.__name__.lower() 17 | -------------------------------------------------------------------------------- /whitebox/entities/DatasetRow.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, ForeignKey, DateTime, JSON 2 | from whitebox.entities.Base import Base 3 | from whitebox.utils.id_gen import generate_uuid 4 | 5 | 6 | class DatasetRow(Base): 7 | __tablename__ = "dataset_rows" 8 | 9 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 10 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 11 | nonprocessed = Column(JSON) 12 | processed = Column(JSON) 13 | created_at = Column(DateTime) 14 | updated_at = Column(DateTime) 15 | -------------------------------------------------------------------------------- /whitebox/entities/DriftingMetric.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, ForeignKey, String, DateTime, JSON 2 | from whitebox.entities.Base import Base 3 | from whitebox.utils.id_gen import generate_uuid 4 | 5 | 6 | class DriftingMetric(Base): 7 | __tablename__ = "drifting_metrics" 8 | 9 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 10 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 11 | timestamp = Column(DateTime) 12 | concept_drift_summary = Column(JSON) 13 | data_drift_summary = Column(JSON) 14 | created_at = Column(DateTime) 15 | updated_at = Column(DateTime) 16 | -------------------------------------------------------------------------------- /whitebox/entities/Inference.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Boolean, Column, String, ForeignKey, DateTime, JSON, Float 2 | from whitebox.entities.Base import Base 3 | from
whitebox.utils.id_gen import generate_uuid 4 | 5 | 6 | class InferenceRow(Base): 7 | __tablename__ = "inference_rows" 8 | 9 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 10 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 11 | timestamp = Column(DateTime) 12 | nonprocessed = Column(JSON) 13 | processed = Column(JSON) 14 | is_used = Column(Boolean) 15 | actual = Column(Float, nullable=True) 16 | 17 | created_at = Column(DateTime) 18 | updated_at = Column(DateTime) 19 | -------------------------------------------------------------------------------- /whitebox/entities/Model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, DateTime, JSON, Enum 2 | from whitebox.entities.Base import Base 3 | from whitebox.utils.id_gen import generate_uuid 4 | from sqlalchemy.orm import relationship 5 | from whitebox.schemas.model import ModelType 6 | 7 | 8 | class Model(Base): 9 | __tablename__ = "models" 10 | 11 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 12 | name = Column(String) 13 | description = Column(String) 14 | type = Column("type", Enum(ModelType)) 15 | target_column = Column(String) 16 | granularity = Column(String) 17 | labels = Column(JSON, nullable=True) 18 | created_at = Column(DateTime) 19 | updated_at = Column(DateTime) 20 | 21 | dataset_rows = relationship("DatasetRow") 22 | inference_rows = relationship("InferenceRow") 23 | binary_classification_metrics = relationship("BinaryClassificationMetrics") 24 | multi_classification_metrics = relationship("MultiClassificationMetrics") 25 | regression_metrics = relationship("RegressionMetrics") 26 | drifting_metrics = relationship("DriftingMetric") 27 | model_integrity_metrics = relationship("ModelIntegrityMetric") 28 | model_monitors = relationship("ModelMonitor") 29 | -------------------------------------------------------------------------------- /whitebox/entities/ModelIntegrityMetric.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, JSON, Float, ForeignKey, String, DateTime 2 | from whitebox.entities.Base import Base 3 | from whitebox.utils.id_gen import generate_uuid 4 | 5 | 6 | class ModelIntegrityMetric(Base): 7 | __tablename__ = "model_integrity_metrics" 8 | 9 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 10 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 11 | timestamp = Column(DateTime) 12 | feature_metrics = Column(JSON) 13 | created_at = Column(DateTime) 14 | updated_at = Column(DateTime) 15 | -------------------------------------------------------------------------------- /whitebox/entities/ModelMonitor.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Numeric, Enum, String, ForeignKey, DateTime 2 | from whitebox.entities.Base import Base 3 | from sqlalchemy.orm import relationship 4 | from whitebox.schemas.modelMonitor import AlertSeverity, MonitorMetrics, MonitorStatus 5 | from whitebox.utils.id_gen import generate_uuid 6 | 7 | 8 | class ModelMonitor(Base): 9 | __tablename__ = "model_monitors" 10 | 11 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 12 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 13 | name = Column(String) 14 | status = Column("status", Enum(MonitorStatus)) 15 | metric = Column("metric", Enum(MonitorMetrics)) 
16 | feature = Column(String, nullable=True) 17 | lower_threshold = Column(Numeric, nullable=True) 18 | severity = Column("severity", Enum(AlertSeverity)) 19 | email = Column(String) 20 | created_at = Column(DateTime) 21 | updated_at = Column(DateTime) 22 | 23 | alerts = relationship("Alert") 24 | -------------------------------------------------------------------------------- /whitebox/entities/PerformanceMetric.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, Float, ForeignKey, String, DateTime, JSON 2 | from whitebox.entities.Base import Base 3 | from whitebox.utils.id_gen import generate_uuid 4 | 5 | 6 | class BinaryClassificationMetrics(Base): 7 | __tablename__ = "binary_classification_metrics" 8 | 9 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 10 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 11 | timestamp = Column(DateTime) 12 | accuracy = Column(Float) 13 | precision = Column(Float) 14 | recall = Column(Float) 15 | f1 = Column(Float) 16 | true_negative = Column(Integer) 17 | false_positive = Column(Integer) 18 | false_negative = Column(Integer) 19 | true_positive = Column(Integer) 20 | created_at = Column(DateTime) 21 | updated_at = Column(DateTime) 22 | 23 | 24 | class MultiClassificationMetrics(Base): 25 | __tablename__ = "multi_classification_metrics" 26 | 27 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 28 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 29 | timestamp = Column(DateTime) 30 | accuracy = Column(Float) 31 | precision = Column(JSON) 32 | recall = Column(JSON) 33 | f1 = Column(JSON) 34 | confusion_matrix = Column(JSON) 35 | created_at = Column(DateTime) 36 | updated_at = Column(DateTime) 37 | 38 | 39 | class RegressionMetrics(Base): 40 | __tablename__ = "regression_metrics" 41 | 42 | id = Column(String, primary_key=True, unique=True, default=generate_uuid) 43 | model_id = Column(String, ForeignKey("models.id", ondelete="CASCADE")) 44 | timestamp = Column(DateTime) 45 | r_square = Column(Float) 46 | mean_squared_error = Column(Float) 47 | mean_absolute_error = Column(Float) 48 | created_at = Column(DateTime) 49 | updated_at = Column(DateTime) 50 | -------------------------------------------------------------------------------- /whitebox/entities/User.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, DateTime 2 | from whitebox.entities.Base import Base 3 | from whitebox.utils.id_gen import generate_uuid 4 | from sqlalchemy.orm import deferred, relationship 5 | 6 | 7 | class User(Base): 8 | __tablename__ = "users" 9 | 10 | id = Column(String, unique=True, primary_key=True, default=generate_uuid) 11 | username = Column(String) 12 | api_key = Column(String) 13 | created_at = Column(DateTime) 14 | updated_at = Column(DateTime) 15 | -------------------------------------------------------------------------------- /whitebox/entities/__init__.py: -------------------------------------------------------------------------------- 1 | from .Alert import Alert 2 | from .DatasetRow import DatasetRow 3 | from .DriftingMetric import DriftingMetric 4 | from .Inference import InferenceRow 5 | from .Model import Model 6 | from .ModelIntegrityMetric import ModelIntegrityMetric 7 | from .ModelMonitor import ModelMonitor 8 | from .PerformanceMetric import ( 9 | BinaryClassificationMetrics, 10 | MultiClassificationMetrics, 11 
| RegressionMetrics, 12 | ) 13 | from .User import User 14 | -------------------------------------------------------------------------------- /whitebox/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from fastapi.middleware.cors import CORSMiddleware 3 | from whitebox.api.v1 import v1_router 4 | from fastapi.openapi.utils import get_openapi 5 | from whitebox.api.v1.docs import ( 6 | tags_metadata, 7 | validation_error, 8 | authorization_error, 9 | not_found_error, 10 | conflict_error, 11 | content_gone, 12 | bad_request, 13 | ) 14 | from whitebox.core.settings import get_settings 15 | from whitebox.core.db import connect, close 16 | from starlette.exceptions import HTTPException as StarletteHTTPException 17 | from fastapi.exceptions import RequestValidationError 18 | 19 | from whitebox.utils.errors import errors 20 | 21 | settings = get_settings() 22 | 23 | app = FastAPI(title=settings.APP_NAME, redoc_url="/") 24 | 25 | app.add_middleware( 26 | CORSMiddleware, 27 | allow_origins=["*"], 28 | allow_credentials=True, 29 | allow_methods=["*"], 30 | allow_headers=["*"], 31 | ) 32 | 33 | app.include_router(v1_router) 34 | 35 | 36 | app.add_exception_handler(StarletteHTTPException, errors.http_exception_handler) 37 | app.add_exception_handler(RequestValidationError, errors.validation_exception_handler) 38 | 39 | 40 | @app.on_event("startup") 41 | async def on_app_start(): 42 | """Anything that needs to be done while app starts""" 43 | await connect() 44 | 45 | 46 | @app.on_event("shutdown") 47 | async def on_app_shutdown(): 48 | """Anything that needs to be done while app shutdown""" 49 | await close() 50 | 51 | 52 | def app_openapi(): 53 | if app.openapi_schema: 54 | return app.openapi_schema 55 | openapi_schema = get_openapi( 56 | title="Whitebox", 57 | version=settings.VERSION, 58 | routes=app.routes, 59 | tags=tags_metadata, 60 | ) 61 | 62 | openapi_schema["components"]["schemas"]["HTTPValidationError"] = validation_error 63 | openapi_schema["components"]["schemas"]["AuthorizationError"] = authorization_error 64 | openapi_schema["components"]["schemas"]["NotFoundError"] = not_found_error 65 | openapi_schema["components"]["schemas"]["ConflictError"] = conflict_error 66 | openapi_schema["components"]["schemas"]["BadRequest"] = bad_request 67 | openapi_schema["components"]["schemas"]["ContentGone"] = content_gone 68 | 69 | app.openapi_schema = openapi_schema 70 | return app.openapi_schema 71 | 72 | 73 | app.openapi = app_openapi 74 | -------------------------------------------------------------------------------- /whitebox/middleware/__initi__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/middleware/__initi__.py -------------------------------------------------------------------------------- /whitebox/middleware/auth.py: -------------------------------------------------------------------------------- 1 | from fastapi import Depends, HTTPException, status, Header 2 | from sqlalchemy.orm import Session 3 | from whitebox import crud 4 | from whitebox.schemas.user import User 5 | from whitebox.core.db import get_db 6 | from whitebox.utils.passwords import passwords_match 7 | 8 | 9 | async def authenticate_user( 10 | api_key: str = Header(), 11 | db: Session = Depends(get_db), 12 | ) -> User: 13 | user = crud.users.get_first_by_filter(db, username="admin") 14 | if not 
passwords_match(user.api_key, api_key): 15 | raise HTTPException( 16 | detail="Invalid API key", status_code=status.HTTP_401_UNAUTHORIZED 17 | ) 18 | return user 19 | -------------------------------------------------------------------------------- /whitebox/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | from .inferenceRow import * 2 | from .model import * 3 | from .alert import * 4 | from .base import * 5 | from .datasetRow import * 6 | from .driftingMetric import * 7 | from .modelIntegrityMetric import * 8 | from .modelMonitor import * 9 | from .performanceMetric import * 10 | from .user import * 11 | from .utils import * 12 | -------------------------------------------------------------------------------- /whitebox/schemas/alert.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Union 3 | from pydantic import BaseModel 4 | from whitebox.schemas.base import ItemBase 5 | 6 | 7 | class AlertBase(BaseModel): 8 | model_id: str 9 | model_monitor_id: str 10 | timestamp: Union[str, datetime] 11 | description: str 12 | 13 | 14 | class Alert(AlertBase, ItemBase): 15 | pass 16 | 17 | 18 | class AlertCreateDto(AlertBase): 19 | pass 20 | -------------------------------------------------------------------------------- /whitebox/schemas/base.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pydantic import BaseModel 3 | 4 | 5 | class ItemBase(BaseModel): 6 | id: str 7 | created_at: datetime.datetime 8 | updated_at: datetime.datetime 9 | 10 | class Config: 11 | orm_mode = True 12 | -------------------------------------------------------------------------------- /whitebox/schemas/datasetRow.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from whitebox.schemas.base import ItemBase 3 | from typing import Dict, Any 4 | 5 | 6 | class DatasetRowBase(BaseModel): 7 | model_id: str 8 | # Data before any processing 9 | nonprocessed: Dict[str, Any] 10 | # Before model entry 11 | processed: Dict[str, float] 12 | 13 | 14 | class DatasetRow(DatasetRowBase, ItemBase): 15 | pass 16 | 17 | 18 | class DatasetRowCreate(DatasetRowBase): 19 | pass 20 | -------------------------------------------------------------------------------- /whitebox/schemas/driftingMetric.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Dict, List, Union 3 | from pydantic import BaseModel 4 | from whitebox.schemas.base import ItemBase 5 | 6 | 7 | class ColumnDataDriftMetrics(BaseModel): 8 | """One column drift metrics""" 9 | 10 | column_name: str 11 | column_type: str 12 | stattest_name: str 13 | drift_score: float 14 | drift_detected: bool 15 | threshold: float 16 | 17 | 18 | class DataDriftTable(BaseModel): 19 | number_of_columns: int 20 | number_of_drifted_columns: int 21 | share_of_drifted_columns: float 22 | dataset_drift: bool 23 | drift_by_columns: Dict[str, ColumnDataDriftMetrics] 24 | 25 | 26 | class CramerV(BaseModel): 27 | """CramerV statistics""" 28 | 29 | column_name: str 30 | kind: str 31 | values: Dict[str, List[str]] 32 | 33 | 34 | class ColumnConceptDriftCorrelationMetrics(BaseModel): 35 | """One column concept drift correlation metrics""" 36 | 37 | column_name: str 38 | current: Dict[str, CramerV] 39 | reference: Dict[str, CramerV] 40 | 41 | 42 | class 
ColumnConceptDriftMetrics(BaseModel): 43 | """One column concept drift metrics""" 44 | 45 | column_name: str 46 | column_type: str 47 | stattest_name: str 48 | drift_score: float 49 | drift_detected: bool 50 | stattest_threshold: float 51 | 52 | 53 | class ConceptDriftTable(BaseModel): 54 | """Concept drift table metrics""" 55 | 56 | concept_drift_summary: ColumnConceptDriftMetrics 57 | column_correlation: ColumnConceptDriftCorrelationMetrics 58 | 59 | 60 | class DriftingMetricBase(BaseModel): 61 | model_id: str 62 | timestamp: Union[str, datetime] 63 | concept_drift_summary: ConceptDriftTable 64 | data_drift_summary: DataDriftTable 65 | 66 | 67 | class DriftingMetric(DriftingMetricBase, ItemBase): 68 | pass 69 | -------------------------------------------------------------------------------- /whitebox/schemas/inferenceRow.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Any, Dict, Union, Optional 3 | from pydantic import BaseModel 4 | from whitebox.schemas.base import ItemBase 5 | 6 | 7 | class InferenceRowBase(BaseModel): 8 | model_id: str 9 | timestamp: Union[str, datetime] 10 | # Prediction is included in nonprocessed & processed 11 | nonprocessed: Dict[str, Any] 12 | processed: Dict[str, float] 13 | actual: Optional[float] 14 | 15 | 16 | class InferenceRowCreateDto(InferenceRowBase): 17 | pass 18 | 19 | 20 | class InferenceRowPreDb(InferenceRowBase): 21 | is_used: bool 22 | 23 | 24 | class InferenceRow(InferenceRowPreDb, ItemBase): 25 | pass 26 | -------------------------------------------------------------------------------- /whitebox/schemas/model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from pydantic import BaseModel 3 | from whitebox.schemas.base import ItemBase 4 | import enum 5 | 6 | 7 | class ModelType(str, enum.Enum): 8 | binary = "binary" 9 | multi_class = "multi_class" 10 | regression = "regression" 11 | 12 | 13 | class ModelBase(BaseModel): 14 | name: str 15 | description: str 16 | type: ModelType 17 | target_column: str 18 | granularity: str 19 | labels: Optional[Dict[str, int]] 20 | 21 | 22 | class Model(ModelBase, ItemBase): 23 | pass 24 | 25 | 26 | class ModelCreateDto(ModelBase): 27 | pass 28 | 29 | 30 | class ModelUpdateDto(BaseModel): 31 | name: Optional[str] 32 | description: Optional[str] 33 | -------------------------------------------------------------------------------- /whitebox/schemas/modelIntegrityMetric.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | from pydantic import BaseModel 3 | from whitebox.schemas.base import ItemBase 4 | from datetime import datetime 5 | 6 | 7 | class FeatureMetrics(BaseModel): 8 | missing_count: Dict[str, int] 9 | non_missing_count: Dict[str, int] 10 | mean: Dict[str, float] 11 | minimum: Dict[str, float] 12 | maximum: Dict[str, float] 13 | sum: Dict[str, float] 14 | standard_deviation: Dict[str, float] 15 | variance: Dict[str, float] 16 | 17 | 18 | class ModelIntegrityMetricBase(BaseModel): 19 | model_id: str 20 | timestamp: Union[str, datetime] 21 | feature_metrics: FeatureMetrics 22 | 23 | 24 | class ModelIntegrityMetric(ModelIntegrityMetricBase, ItemBase): 25 | pass 26 | 27 | 28 | class ModelIntegrityMetricCreate(ModelIntegrityMetricBase): 29 | pass 30 | -------------------------------------------------------------------------------- /whitebox/schemas/modelMonitor.py:
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pydantic import BaseModel 3 | import enum 4 | from whitebox.schemas.base import ItemBase 5 | 6 | 7 | class MonitorStatus(str, enum.Enum): 8 | active = "active" 9 | inactive = "inactive" 10 | 11 | 12 | class MonitorMetrics(str, enum.Enum): 13 | # Performance metrics 14 | accuracy = "accuracy" 15 | precision = "precision" 16 | recall = "recall" 17 | f1 = "f1" 18 | r_square = "r_square" 19 | mean_squared_error = "mean_squared_error" 20 | mean_absolute_error = "mean_absolute_error" 21 | 22 | # Drifting metrics 23 | data_drift = "data_drift" 24 | concept_drift = "concept_drift" 25 | 26 | 27 | class AlertSeverity(str, enum.Enum): 28 | low = "low" 29 | mid = "mid" 30 | high = "high" 31 | 32 | 33 | class ModelMonitorBase(BaseModel): 34 | model_id: str 35 | name: str 36 | status: MonitorStatus 37 | metric: MonitorMetrics 38 | severity: AlertSeverity 39 | email: str 40 | feature: Optional[str] 41 | lower_threshold: Optional[float] 42 | 43 | 44 | class ModelMonitor(ModelMonitorBase, ItemBase): 45 | pass 46 | 47 | 48 | class ModelMonitorCreateDto(ModelMonitorBase): 49 | pass 50 | 51 | 52 | class ModelMonitorUpdateDto(BaseModel): 53 | name: Optional[str] 54 | status: Optional[MonitorStatus] 55 | severity: Optional[AlertSeverity] 56 | email: Optional[str] 57 | lower_threshold: Optional[float] 58 | -------------------------------------------------------------------------------- /whitebox/schemas/performanceMetric.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pydantic import BaseModel 3 | from typing import Dict, Union 4 | from whitebox.schemas.base import ItemBase 5 | 6 | # TODO: Include comments of what each class represents 7 | 8 | 9 | class BinaryClassificationMetricsPipelineResult(BaseModel): 10 | """This class is used to store the results of the pipeline that calculates the binary classification metrics""" 11 | 12 | accuracy: float 13 | precision: float 14 | recall: float 15 | f1: float 16 | true_negative: int 17 | false_positive: int 18 | false_negative: int 19 | true_positive: int 20 | 21 | 22 | class BinaryClassificationMetricsBase(BinaryClassificationMetricsPipelineResult): 23 | model_id: str 24 | timestamp: Union[str, datetime] 25 | 26 | 27 | class BinaryClassificationMetrics(BinaryClassificationMetricsBase, ItemBase): 28 | pass 29 | 30 | 31 | class DifferentStatistics(BaseModel): 32 | micro: float 33 | macro: float 34 | weighted: float 35 | 36 | 37 | class ConfusionMatrix(BaseModel): 38 | true_negative: int 39 | false_positive: int 40 | false_negative: int 41 | true_positive: int 42 | 43 | 44 | class MultiClassificationMetricsPipelineResult(BaseModel): 45 | """This class is used to store the results of the pipeline that calculates the multi classification metrics""" 46 | 47 | accuracy: float 48 | precision: DifferentStatistics 49 | recall: DifferentStatistics 50 | f1: DifferentStatistics 51 | confusion_matrix: Dict[str, ConfusionMatrix] 52 | 53 | 54 | class MultiClassificationMetricsBase(MultiClassificationMetricsPipelineResult): 55 | model_id: str 56 | timestamp: Union[str, datetime] 57 | 58 | 59 | class MultiClassificationMetrics(MultiClassificationMetricsBase, ItemBase): 60 | pass 61 | 62 | 63 | class RegressionMetricsPipelineResult(BaseModel): 64 | """This class is used to store the results of the pipeline that calculates the regression metrics""" 65 | 66 | r_square: float 67 | 
mean_squared_error: float 68 | mean_absolute_error: float 69 | 70 | 71 | class RegressionMetricsBase(RegressionMetricsPipelineResult): 72 | model_id: str 73 | timestamp: Union[str, datetime] 74 | 75 | 76 | class RegressionMetrics(RegressionMetricsBase, ItemBase): 77 | pass 78 | -------------------------------------------------------------------------------- /whitebox/schemas/task.py: -------------------------------------------------------------------------------- 1 | from asyncio.tasks import Task 2 | from dataclasses import dataclass 3 | 4 | from pydantic import BaseModel, Field 5 | from typing import Callable, Coroutine, Optional, List, Dict, Deque, Literal, Union 6 | from uuid import UUID, uuid4 7 | 8 | import datetime 9 | import pytz 10 | 11 | 12 | def now(): 13 | return datetime.datetime.utcnow().replace(tzinfo=pytz.utc) 14 | 15 | 16 | TaskStatus = Literal[ 17 | "registered", "running", "finished", "pending", "cancelled", "failed" 18 | ] 19 | 20 | EventType = Literal[ 21 | "task_registered", 22 | "task_started", 23 | "task_failed", 24 | "task_finished", 25 | "task_cancelled", 26 | "task_disabled", 27 | ] 28 | 29 | 30 | class TaskInfo(BaseModel): 31 | uid: UUID = Field(default_factory=uuid4) 32 | name: str 33 | status: str 34 | previous_status: str 35 | enabled: bool 36 | crontab: Union[str, None] = None 37 | created_at: Union[datetime.datetime, float] = Field(default_factory=lambda: now().timestamp()) 38 | started_at: Union[datetime.datetime, float, None] = None 39 | stopped_at: Union[datetime.datetime, float, None] = None 40 | next_run_in: Union[int, None] = None 41 | 42 | 43 | class TaskDefinition(BaseModel): 44 | name: str 45 | async_callable: Callable[[], Coroutine] 46 | enabled: bool = True 47 | crontab: Optional[str] = None 48 | 49 | 50 | @dataclass 51 | class RunningTask: 52 | task_definition: TaskDefinition 53 | asyncio_task: Task 54 | since: Union[datetime.datetime, float] 55 | 56 | 57 | class TaskLog(BaseModel): 58 | event_type: EventType 59 | task_name: str 60 | crontab: Union[str, None] = None 61 | enabled: bool 62 | error: Union[str, None] = None 63 | timestamp: int = Field(default_factory=lambda: datetime.datetime.now().timestamp()) 64 | 65 | 66 | class State(BaseModel): 67 | created_at: datetime.datetime 68 | tasks_info: List[TaskInfo] 69 | 70 | 71 | class TaskRealTimeInfo(BaseModel): 72 | name: str 73 | status: TaskStatus 74 | previous_status: Optional[TaskStatus] 75 | next_run_ts: Optional[int] 76 | started_at: Optional[Union[datetime.datetime, float]] 77 | stopped_at: Optional[Union[datetime.datetime, float]] 78 | -------------------------------------------------------------------------------- /whitebox/schemas/user.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from pydantic import BaseModel 3 | from whitebox.schemas.base import ItemBase 4 | 5 | 6 | class UserBase(BaseModel): 7 | username: str 8 | 9 | 10 | class User(UserBase, ItemBase): 11 | pass 12 | 13 | 14 | class UserCreateDto(UserBase): 15 | api_key: str 16 | -------------------------------------------------------------------------------- /whitebox/schemas/utils.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class HealthCheck(BaseModel): 5 | status: str 6 | 7 | 8 | class StatusCode(BaseModel): 9 | status_code: str 10 | 11 | 12 | class ErrorProps(BaseModel): 13 | error: str 14 | status_code: int 15 | 16 | 17 | class ErrorResponse(BaseModel): 18 | title: str 19 | type: str 20 | properties: ErrorProps
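# Illustrative sketch (field values assumed, not taken from the repo): the
# error schemas above describe payloads such as
#
#   ErrorResponse(
#       title="AuthorizationError",
#       type="error",
#       properties=ErrorProps(error="Invalid API key", status_code=401),
#   )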
21 | -------------------------------------------------------------------------------- /whitebox/sdk/__init__.py: -------------------------------------------------------------------------------- 1 | from .whitebox import * 2 | -------------------------------------------------------------------------------- /whitebox/streamlit/app.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import streamlit as st 4 | from typing import Dict, Union 5 | 6 | from tabs.drifting import * 7 | from tabs.sidebar import * 8 | from tabs.overview import * 9 | from tabs.performance import * 10 | from tabs.inferences import * 11 | from tabs.monitors import * 12 | from tabs.alerts import * 13 | from cards import * 14 | from utils.transformation import get_model_from_name 15 | 16 | st.set_option("deprecation.showPyplotGlobalUse", False) 17 | 18 | 19 | # ---------------------------------------------- 20 | def format_evaluation_metrics_binary( 21 | accuracy: float, 22 | precision: float, 23 | recall: float, 24 | f1: float, 25 | tn: int, 26 | fp: int, 27 | fn: int, 28 | tp: int, 29 | ) -> Dict[str, Union[int, float]]: 30 | formated_metrics_for_binary = { 31 | "accuracy": accuracy, 32 | "precision": precision, 33 | "recall": recall, 34 | "f1": f1, 35 | "true_negative": tn, 36 | "false_positive": fp, 37 | "false_negative": fn, 38 | "true_positive": tp, 39 | } 40 | 41 | return formated_metrics_for_binary 42 | 43 | 44 | evaluation_metrics_binary = format_evaluation_metrics_binary( 45 | 0.64, 0.5, 0.11, 0.72, 1200, 600, 840, 260 46 | ) 47 | evaluation_metrics_binary_df = pd.DataFrame(evaluation_metrics_binary, index=[0]) 48 | base_evaluation_metrics_binary_df = evaluation_metrics_binary_df[ 49 | ["accuracy", "precision", "recall", "f1"] 50 | ] 51 | # Conf matrix 52 | first_part = [ 53 | evaluation_metrics_binary["true_positive"], 54 | evaluation_metrics_binary["false_positive"], 55 | ] 56 | second_part = [ 57 | evaluation_metrics_binary["false_negative"], 58 | evaluation_metrics_binary["true_negative"], 59 | ] 60 | cm = np.array([first_part, second_part]) 61 | 62 | # ----------------------------------- 63 | overview, performance, drifting, inferences, monitors, alerts = st.tabs( 64 | ["Overview", "Performance", "Drifting", "Inferences", "Monitors", "Alerts"] 65 | ) 66 | model_option, models_list, checkbox, wb = create_sidebar() 67 | 68 | if checkbox: 69 | model = get_model_from_name(models_list, model_option) 70 | 71 | if model: 72 | pred_column = model["target_column"] 73 | model_id = model["id"] 74 | model_type = model["type"] 75 | # TODO: Need to connect this one with the db. 76 | with overview: 77 | create_overview_tab(model, cm, base_evaluation_metrics_binary_df) 78 | 79 | with performance: 80 | create_performance_tab(wb, model_id, model_type) 81 | 82 | with drifting: 83 | create_drift_tab(wb, model_id) 84 | 85 | with inferences: 86 | create_inferences_tab(wb, model_id, pred_column) 87 | 88 | with monitors: 89 | create_monitors_tab(wb, model_id, model_type) 90 | 91 | with alerts: 92 | create_alerts_tab(wb, model_id) 93 | -------------------------------------------------------------------------------- /whitebox/streamlit/cards.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | 4 | def card1(text1, text2): 5 | st.markdown( 6 | f""" 7 |
<!-- markup reconstructed; tag and class names assumed --> 8 | <div class="card"> <div class="card-body"> 9 | <p class="card-text"> {text1} </p> 10 | <p class="card-text"> {text2} </p> 11 | </div> 12 | </div> 13 | """, 14 | unsafe_allow_html=True, 15 | ) 16 | 17 | 18 | def card(text1, text2, text3): 19 | st.markdown( 20 | f""" 21 | <!-- markup reconstructed; tag and class names assumed --> 22 | <div class="card"> <div class="card-body"> 23 | <p class="card-title"> {text1} </p> 24 | <p class="card-text"> type: {text2} </p> 25 | <p class="card-text"> {text3} </p> 26 | </div> 27 | </div>
28 | """, 29 | unsafe_allow_html=True, 30 | ) 31 | -------------------------------------------------------------------------------- /whitebox/streamlit/classification_test_data copy.csv: -------------------------------------------------------------------------------- 1 | y_testing_binary,y_prediction_binary,y_testing_multi,y_prediction_multi 2 | 0,0,0,1 3 | 0,0,2,2 4 | 0,1,0,0 5 | 1,1,1,1 6 | 0,0,2,0 7 | 1,0,1,1 8 | 1,0,1,0 9 | 1,1,2,2 10 | 0,1,0,1 11 | 1,1,1,1 12 | -------------------------------------------------------------------------------- /whitebox/streamlit/classification_test_data.csv: -------------------------------------------------------------------------------- 1 | y_testing_binary,y_prediction_binary,y_testing_multi,y_prediction_multi 2 | 0,0,0,1 3 | 0,0,2,2 4 | 0,1,0,0 5 | 1,1,1,1 6 | 0,0,2,0 7 | 1,0,1,1 8 | 1,0,1,0 9 | 1,1,2,2 10 | 0,1,0,1 11 | 1,1,1,1 12 | -------------------------------------------------------------------------------- /whitebox/streamlit/config/config_readme.toml: -------------------------------------------------------------------------------- 1 | [app] 2 | app_intro = """ 3 | This streamlit app allows you to visualise and monitor the results and all related analytics produced by Whitebox. 4 | 5 | All you have to do is to select your desired model which is stored in your database and then navigate through the different tabs. 6 | """ 7 | [tooltips] 8 | model_option = """ 9 | These are the models that are saved in your connected database. 10 | """ 11 | overview_performance = """ 12 | These are the model's performance metrics after it was trained with the training data.  13 | """ 14 | monitor_name = """ 15 | Add your preference name for the new monitor.  16 | """ 17 | monitor_use_case = """ 18 | Add your preference type of monitor.  19 | """ 20 | stat_thresh_monitor = """ 21 | Enter a value between 0 and 1 with max two decimal points, e.g. 0.25.  22 | """ 23 | alert_trig_monitor = """ 24 | Alert will trigger if below the lower threshold or above the upper threshold.  25 | """ 26 | alert_severity_monitor = """ 27 | What alert severity should be associated with the notifications being sent? 28 | """ 29 | notifications_monitor = """ 30 | Notifications will be sent via email. Please provide your email below: 31 | """ 32 | host = """ 33 | Your host and port combined in a url. 34 | """ 35 | api_key = """ 36 | Your api key as was created from the initialisation of Whitebox. 37 | """ 38 | model_name = """ 39 | Your desired model name. 40 | """ 41 | model_description = """ 42 | Your desired model description. 43 | """ 44 | model_type = """ 45 | Your desired model type. 46 | """ 47 | target_column = """ 48 | Your desired target column as depicted in your data. 49 | """ 50 | granularity_amount = """ 51 | Your data's granularity amount (must be an integer). 52 | """ 53 | granularity_type = """ 54 | Your data's granularity type. 
55 | """ 56 | [links] 57 | repo = "https://github.com/squaredev-io/whitebox" -------------------------------------------------------------------------------- /whitebox/streamlit/mock/alerts.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": 44, 4 | "model_monitor_id": 1, 5 | "timestamp": "2023-03-14", 6 | "description": "fell below the threshold" 7 | }, 8 | { 9 | "id": 45, 10 | "model_monitor_id": 1, 11 | "timestamp": "2023-03-16", 12 | "description": "fell below the threshold" 13 | }, 14 | { 15 | "id": 46, 16 | "model_monitor_id": 3, 17 | "timestamp": "2023-03-16", 18 | "description": "fell increasinlgy below the threshold" 19 | } 20 | ] -------------------------------------------------------------------------------- /whitebox/streamlit/mock/inferences.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": 0, 4 | "model_id": "test", 5 | "nonprocessed": { 6 | "first_col": 5.1, 7 | "sec_col": 3.5, 8 | "third_col": 1.4, 9 | "four_col": 0.2, 10 | "target": 0 11 | }, 12 | "processed": null, 13 | "timestamp": "2023-2-14", 14 | "actual": 0 15 | }, 16 | { 17 | "id": 1, 18 | "model_id": "test", 19 | "nonprocessed": { 20 | "first_col": 4.9, 21 | "sec_col": 3.0, 22 | "third_col": 1.4, 23 | "four_col": 0.2, 24 | "target": 0 25 | }, 26 | "processed": null, 27 | "timestamp": "2023-2-14", 28 | "actual": 0 29 | }, 30 | { 31 | "id": 2, 32 | "model_id": "test", 33 | "nonprocessed": { 34 | "first_col": 4.7, 35 | "sec_col": 3.2, 36 | "third_col": 1.3, 37 | "four_col": 0.2, 38 | "target": 0 39 | }, 40 | "processed": null, 41 | "timestamp": "2023-2-14", 42 | "actual": 0 43 | }, 44 | { 45 | "id": 3, 46 | "model_id": "test", 47 | "nonprocessed": { 48 | "first_col": 4.6, 49 | "sec_col": 3.1, 50 | "third_col": 1.5, 51 | "four_col": 0.2, 52 | "target": 1 53 | }, 54 | "processed": null, 55 | "timestamp": "2023-2-14", 56 | "actual": 0 57 | }, 58 | { 59 | "id": 4, 60 | "model_id": "test", 61 | "nonprocessed": { 62 | "first_col": 5.0, 63 | "sec_col": 3.6, 64 | "third_col": 1.4, 65 | "four_col": 0.2, 66 | "target": 0 67 | }, 68 | "processed": null, 69 | "timestamp": "2023-2-14", 70 | "actual": 0 71 | }, 72 | { 73 | "id": 5, 74 | "model_id": "test", 75 | "nonprocessed": { 76 | "first_col": 5.4, 77 | "sec_col": 3.9, 78 | "third_col": 1.7, 79 | "four_col": 0.4, 80 | "target": 1 81 | }, 82 | "processed": null, 83 | "timestamp": "2023-2-14", 84 | "actual": 0 85 | }, 86 | { 87 | "id": 6, 88 | "model_id": "test", 89 | "nonprocessed": { 90 | "first_col": 4.6, 91 | "sec_col": 3.4, 92 | "third_col": 1.4, 93 | "four_col": 0.3, 94 | "target": 0 95 | }, 96 | "processed": null, 97 | "timestamp": "2023-2-14", 98 | "actual": 0 99 | }, 100 | { 101 | "id": 7, 102 | "model_id": "test", 103 | "nonprocessed": { 104 | "first_col": 5.0, 105 | "sec_col": 3.4, 106 | "third_col": 1.5, 107 | "four_col": 0.2, 108 | "target": 0 109 | }, 110 | "processed": null, 111 | "timestamp": "2023-2-14", 112 | "actual": 0 113 | }, 114 | { 115 | "id": 8, 116 | "model_id": "test", 117 | "nonprocessed": { 118 | "first_col": 4.4, 119 | "sec_col": 2.9, 120 | "third_col": 1.4, 121 | "four_col": 0.2, 122 | "target": 1 123 | }, 124 | "processed": null, 125 | "timestamp": "2023-2-14", 126 | "actual": 0 127 | }, 128 | { 129 | "id": 9, 130 | "model_id": "test", 131 | "nonprocessed": { 132 | "first_col": 4.9, 133 | "sec_col": 3.1, 134 | "third_col": 1.5, 135 | "four_col": 0.1, 136 | "target": 0 137 | }, 138 | "processed": null, 139 | "timestamp": "2023-2-14", 
140 | "actual": 0 141 | } 142 | ] -------------------------------------------------------------------------------- /whitebox/streamlit/mock/monitors.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": 1, 4 | "model_id": "model_test", 5 | "name": "my_custom_metric", 6 | "status": "active", 7 | "metric": "accuracy", 8 | "severity": "low", 9 | "email": "test@squaredev.io", 10 | "lower_threshold": 0.7, 11 | "updated_at": "2023-02-16" 12 | }, 13 | { 14 | "id": 2, 15 | "model_id": "model_test", 16 | "name": "another_metric", 17 | "status": "active", 18 | "metric": "data_drift", 19 | "severity": "low", 20 | "email": "test@squaredev.io", 21 | "lower_threshold": 0.5, 22 | "updated_at": "2023-02-14" 23 | }, 24 | { 25 | "id": 3, 26 | "model_id": "model_test", 27 | "name": "check this metric", 28 | "status": "active", 29 | "metric": "accuracy", 30 | "severity": "low", 31 | "email": "test@squaredev.io", 32 | "lower_threshold": 0.5, 33 | "updated_at": "2023-02-14" 34 | } 35 | ] -------------------------------------------------------------------------------- /whitebox/streamlit/mock/performance.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "b4bff1cd-76bd-47e6-8446-4ce690096dac", 4 | "created_at": "2023-02-09T14:33:00.473312", 5 | "updated_at": "2023-02-09T14:33:00.473312", 6 | "accuracy": 0.6, 7 | "precision": { 8 | "micro": 0.6, 9 | "macro": 0.7333333333333334, 10 | "weighted": 0.9199999999999999 11 | }, 12 | "recall": { 13 | "micro": 0.6, 14 | "macro": 0.7000000000000001, 15 | "weighted": 0.6 16 | }, 17 | "f1": { 18 | "micro": 0.6, 19 | "macro": 0.5833333333333334, 20 | "weighted": 0.675 21 | }, 22 | "confusion_matrix": { 23 | "class0": { 24 | "true_negative": 5, 25 | "false_positive": 4, 26 | "false_negative": 0, 27 | "true_positive": 1 28 | }, 29 | "class1": { 30 | "true_negative": 5, 31 | "false_positive": 0, 32 | "false_negative": 2, 33 | "true_positive": 3 34 | }, 35 | "class2": { 36 | "true_negative": 6, 37 | "false_positive": 0, 38 | "false_negative": 2, 39 | "true_positive": 2 40 | } 41 | }, 42 | "model_id": "83539c2b-579f-4a2c-b7ba-02d31c9408d8", 43 | "timestamp": "2022-12-22T00:00:00" 44 | }, 45 | { 46 | "id": "0ffa0e1d-8e8a-4658-8443-7405a9cd6fd5", 47 | "created_at": "2023-02-09T14:36:00.522083", 48 | "updated_at": "2023-02-09T14:36:00.522083", 49 | "accuracy": 0.45, 50 | "precision": { 51 | "micro": 0.45, 52 | "macro": 0.48611111111111116, 53 | "weighted": 0.6125 54 | }, 55 | "recall": { 56 | "micro": 0.45, 57 | "macro": 0.46296296296296297, 58 | "weighted": 0.45 59 | }, 60 | "f1": { 61 | "micro": 0.45, 62 | "macro": 0.4222222222222222, 63 | "weighted": 0.5000000000000001 64 | }, 65 | "confusion_matrix": { 66 | "class0": { 67 | "true_negative": 11, 68 | "false_positive": 7, 69 | "false_negative": 1, 70 | "true_positive": 1 71 | }, 72 | "class1": { 73 | "true_negative": 10, 74 | "false_positive": 1, 75 | "false_negative": 4, 76 | "true_positive": 5 77 | }, 78 | "class2": { 79 | "true_negative": 8, 80 | "false_positive": 3, 81 | "false_negative": 6, 82 | "true_positive": 3 83 | } 84 | }, 85 | "model_id": "83539c2b-579f-4a2c-b7ba-02d31c9408d8", 86 | "timestamp": "2022-12-23T00:00:00" 87 | }, 88 | { 89 | "id": "ca19193b-5c9b-478b-93a7-48c7b08af306", 90 | "created_at": "2023-02-09T14:39:00.618561", 91 | "updated_at": "2023-02-09T14:39:00.618561", 92 | "accuracy": 0.43333333333333335, 93 | "precision": { 94 | "micro": 0.43333333333333335, 95 | "macro": 0.48860398860398857, 96 | 
"weighted": 0.6921652421652421 97 | }, 98 | "recall": { 99 | "micro": 0.43333333333333335, 100 | "macro": 0.4447415329768271, 101 | "weighted": 0.43333333333333335 102 | }, 103 | "f1": { 104 | "micro": 0.43333333333333335, 105 | "macro": 0.389923526765632, 106 | "weighted": 0.5119928025191183 107 | }, 108 | "confusion_matrix": { 109 | "class0": { 110 | "true_negative": 16, 111 | "false_positive": 12, 112 | "false_negative": 1, 113 | "true_positive": 1 114 | }, 115 | "class1": { 116 | "true_negative": 12, 117 | "false_positive": 1, 118 | "false_negative": 9, 119 | "true_positive": 8 120 | }, 121 | "class2": { 122 | "true_negative": 15, 123 | "false_positive": 4, 124 | "false_negative": 7, 125 | "true_positive": 4 126 | } 127 | }, 128 | "model_id": "83539c2b-579f-4a2c-b7ba-02d31c9408d8", 129 | "timestamp": "2022-12-24T00:00:00" 130 | } 131 | ] -------------------------------------------------------------------------------- /whitebox/streamlit/mock_app.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import streamlit as st 4 | from typing import Dict, Union, List 5 | from matplotlib import pyplot as plt 6 | import json 7 | 8 | 9 | from tabs.drifting import * 10 | from tabs.sidebar import * 11 | from tabs.overview import * 12 | from tabs.performance import * 13 | from tabs.inferences import * 14 | from tabs.monitors import * 15 | from tabs.alerts import * 16 | from cards import * 17 | 18 | st.set_option("deprecation.showPyplotGlobalUse", False) 19 | 20 | 21 | # The below lines are temp until we have the performance metricc functionality 22 | # ---------------------------------------- 23 | def format_evaluation_metrics_binary( 24 | accuracy: float, 25 | precision: float, 26 | recall: float, 27 | f1: float, 28 | tn: int, 29 | fp: int, 30 | fn: int, 31 | tp: int, 32 | ) -> Dict[str, Union[int, float]]: 33 | formated_metrics_for_binary = { 34 | "accuracy": accuracy, 35 | "precision": precision, 36 | "recall": recall, 37 | "f1": f1, 38 | "true_negative": tn, 39 | "false_positive": fp, 40 | "false_negative": fn, 41 | "true_positive": tp, 42 | } 43 | 44 | return formated_metrics_for_binary 45 | 46 | 47 | model_names = ["model test", "decision_tree", "random_forest", "custom_tf_model"] 48 | 49 | model = { 50 | "id": "001", 51 | "name": "model test", 52 | "description": "a model for testing visualisations", 53 | "type": "binary", 54 | "target_column": "target", 55 | "labels": {"default": 0, "no_default": 1}, 56 | "created_at": "2022-05-05", 57 | "updated_at": "2022-05-05", 58 | } 59 | evaluation_metrics_binary = format_evaluation_metrics_binary( 60 | 0.64, 0.5, 0.11, 0.72, 1200, 600, 840, 260 61 | ) 62 | evaluation_metrics_binary_df = pd.DataFrame(evaluation_metrics_binary, index=[0]) 63 | base_evaluation_metrics_binary_df = evaluation_metrics_binary_df[ 64 | ["accuracy", "precision", "recall", "f1"] 65 | ] 66 | # Conf matrix 67 | first_part = [ 68 | evaluation_metrics_binary["true_positive"], 69 | evaluation_metrics_binary["false_positive"], 70 | ] 71 | second_part = [ 72 | evaluation_metrics_binary["false_negative"], 73 | evaluation_metrics_binary["true_negative"], 74 | ] 75 | cm = np.array([first_part, second_part]) 76 | 77 | f = open("whitebox/streamlit/mock/drift.json") 78 | drift = json.load(f) 79 | f.close() 80 | 81 | f = open("whitebox/streamlit/mock/performance.json") 82 | perf = json.load(f) 83 | f.close() 84 | 85 | f = open("whitebox/streamlit/mock/inferences.json") 86 | inf = json.load(f) 87 | f.close() 
88 | 89 | f = open("whitebox/streamlit/mock/monitors.json") 90 | mon = json.load(f) 91 | f.close() 92 | 93 | f = open("whitebox/streamlit/mock/alerts.json") 94 | al = json.load(f) 95 | f.close() 96 | 97 | pred_column = model["target_column"] 98 | 99 | # ----------------------------------- 100 | overview, performance, drifting, inferences, monitors, alerts = st.tabs( 101 | ["Overview", "Performance", "Drifting", "Inferences", "Monitors", "Alerts"] 102 | ) 103 | 104 | model_option, button = create_sidebar(model_names) 105 | 106 | if button: 107 | with overview: 108 | create_overview_tab(model, cm, base_evaluation_metrics_binary_df) 109 | 110 | with performance: 111 | create_performance_tab(perf, model) 112 | 113 | with drifting: 114 | create_drift_tab(drift) 115 | 116 | with inferences: 117 | create_inferences_tab(inf, pred_column) 118 | 119 | with monitors: 120 | create_monitors_tab(mon, al) 121 | 122 | with alerts: 123 | create_alerts_tab(al, mon) 124 | -------------------------------------------------------------------------------- /whitebox/streamlit/references/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/streamlit/references/logo.png -------------------------------------------------------------------------------- /whitebox/streamlit/references/whitebox2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/streamlit/references/whitebox2.png -------------------------------------------------------------------------------- /whitebox/streamlit/tabs/alerts.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | from utils.export import structure 4 | from utils.transformation import combine_monitor_with_alert_for_alerts 5 | 6 | from pandas.api.types import ( 7 | is_categorical_dtype, 8 | is_datetime64_any_dtype, 9 | is_numeric_dtype, 10 | is_object_dtype, 11 | ) 12 | 13 | import os, sys 14 | 15 | sys.path.insert(0, os.path.abspath("./")) 16 | from whitebox import Whitebox 17 | 18 | 19 | def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame: 20 | """ 21 | Creates filtering on top of a dataframe 22 | 23 | Args: 24 | df (pd.DataFrame): Input dataframe 25 | 26 | Returns: 27 | pd.DataFrame: Filtered dataframe 28 | """ 29 | filter_checkbox = st.checkbox("Add filters") 30 | 31 | if not filter_checkbox: 32 | return df 33 | 34 | df = df.copy() 35 | 36 | # Convert datetimes into a standard format (datetime with no timezone) 37 | for col in df.columns: 38 | if is_object_dtype(df[col]): 39 | try: 40 | df[col] = pd.to_datetime(df[col]) 41 | except Exception: 42 | pass 43 | 44 | if is_datetime64_any_dtype(df[col]): 45 | df[col] = df[col].dt.tz_localize(None) 46 | 47 | filter_container = st.container() 48 | 49 | with filter_container: 50 | filtering_columns = st.multiselect("Filter dataframe on", df.columns) 51 | 52 | for column in filtering_columns: 53 | left, right = st.columns((1, 20)) 54 | 55 | # Treat columns with < 10 unique values as categorical 56 | if is_categorical_dtype(df[column]) or df[column].nunique() < 10: 57 | user_cat_input = right.multiselect( 58 | f"Values for {column}", 59 | df[column].unique(), 60 | default=list(df[column].unique()), 61 | ) 62 | df = df[df[column].isin(user_cat_input)] 63 | 64 | elif
is_numeric_dtype(df[column]): 65 | _min = float(df[column].min()) 66 | _max = float(df[column].max()) 67 | step = (_max - _min) / 100 68 | user_num_input = right.slider( 69 | f"Values for {column}", 70 | min_value=_min, 71 | max_value=_max, 72 | value=(_min, _max), 73 | step=step, 74 | ) 75 | df = df[df[column].between(*user_num_input)] 76 | 77 | elif is_datetime64_any_dtype(df[column]): 78 | user_date_input = right.date_input( 79 | f"Values for {column}", 80 | value=( 81 | df[column].min(), 82 | df[column].max(), 83 | ), 84 | ) 85 | 86 | if len(user_date_input) == 2: 87 | user_date_input = tuple(map(pd.to_datetime, user_date_input)) 88 | start_date, end_date = user_date_input 89 | df = df.loc[df[column].between(start_date, end_date)] 90 | 91 | else: 92 | user_text_input = right.text_input( 93 | f"Substring or regex in {column}", 94 | ) 95 | 96 | if user_text_input: 97 | df = df[df[column].astype(str).str.contains(user_text_input)] 98 | 99 | return df 100 | 101 | 102 | def create_alerts_tab(wb: Whitebox, model_id: str) -> None: 103 | """ 104 | Creates the alerts tab in Streamlit. 105 | A table with all the alerts is visualised. 106 | """ 107 | 108 | with st.spinner("Loading alerts..."): 109 | structure() 110 | monitors = wb.get_model_monitors(model_id) 111 | alerts = wb.get_alerts(model_id) 112 | total_alerts = len(alerts) 113 | st.title("Alerts (" + str(total_alerts) + ")") 114 | 115 | alerts_df = pd.DataFrame(alerts) 116 | monitors_df = pd.DataFrame(monitors) 117 | 118 | if (len(alerts_df) > 0) & (len(monitors_df) > 0): 119 | merged_df = combine_monitor_with_alert_for_alerts(monitors_df, alerts_df) 120 | show_df = merged_df[["timestamp", "metric", "description", "name"]] 121 | show_df.columns = [ 122 | "Anomaly timestamp", 123 | "Metric", 124 | "Anomaly details", 125 | "Monitor Name", 126 | ] 127 | filtered_df = filter_dataframe(show_df) 128 | st.dataframe(filtered_df, width=1200, height=300) 129 | -------------------------------------------------------------------------------- /whitebox/streamlit/tabs/drifting.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import plotly.express as px 3 | 4 | from utils.transformation import export_drift_timeseries 5 | from utils.export import structure 6 | 7 | import os, sys 8 | 9 | sys.path.insert(0, os.path.abspath("./")) 10 | 11 | from whitebox import Whitebox 12 | 13 | 14 | def create_drift_tab(wb: Whitebox, model_id: str) -> None: 15 | """ 16 | Creates the drift tab in Streamlit. 17 | 18 | Gets the drift object and plots the drifting graphs via Streamlit. 19 | It creates two tabs of graphs: one with the combined drift of all variables 20 | and one with a separate graph for each variable.
21 | """ 22 | with st.spinner("Loading model drift..."): 23 | structure() 24 | st.title("Drifting") 25 | drift = wb.get_drifting_metrics(model_id) 26 | # Isolate timeseties parts from the drift object 27 | value_df, drift_df = export_drift_timeseries(drift) 28 | # Keep the columns except the time/index 29 | df_columns = value_df.drop("index", axis=1).columns 30 | # We have 2 different representations 31 | common_tab, sep_tab = st.tabs( 32 | ["Common representation", "Separeted representation"] 33 | ) 34 | 35 | with sep_tab: 36 | # Create a graph for each column/variable 37 | for column in df_columns: 38 | # Check if we have drift in order to mention it (as a subtitle) 39 | drift_detected = (drift_df[[column]] == True).any()[0] 40 | viz_df = value_df[[column, "index"]] 41 | viz_df.columns = ["drift_score", "time"] 42 | subtitle = "" 43 | 44 | if drift_detected: 45 | subtitle = "Drift detected" 46 | 47 | fig = px.line( 48 | viz_df, 49 | x="time", 50 | y="drift_score", 51 | title=f"{column}
{subtitle}", 52 | ) 53 | st.plotly_chart(fig) 54 | 55 | with common_tab: 56 | # If dift is detected in at least one column/variable it 57 | # mentions drift as a whole 58 | drift_detected = True in drift_df.values 59 | subtitle = "" 60 | 61 | if drift_detected: 62 | subtitle = "Drift detected" 63 | 64 | fig = px.line( 65 | value_df, 66 | x="index", 67 | y=df_columns, 68 | title=f"All variables
{subtitle}", 69 | markers=True, 70 | ) 71 | st.plotly_chart(fig) 72 | -------------------------------------------------------------------------------- /whitebox/streamlit/tabs/inferences.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | from utils.transformation import convert_inference_to_df 4 | from utils.export import structure 5 | import time 6 | 7 | import os, sys 8 | 9 | sys.path.insert(0, os.path.abspath("./")) 10 | 11 | from whitebox import Whitebox 12 | 13 | 14 | @st.cache_data(ttl=300) 15 | def highlight_rows(row: pd.DataFrame, pred_column: str): 16 | """ 17 | Part of styling function of dataframe. 18 | It highlights the rows where actual is inequal to to prediction 19 | """ 20 | actual = row.loc["actual"] 21 | pred = row.loc[pred_column] 22 | if actual != pred: 23 | color = "#8FC4C7" 24 | else: 25 | color = "" 26 | 27 | return ["background-color: {}".format(color) for r in row] 28 | 29 | 30 | def viz_inference_df(inf_df: pd.DataFrame, pred_column: str): 31 | """ 32 | Visualises the highlighted dataframe in Streamlit 33 | """ 34 | # Style and mark the columns when actual is not equal to prediction 35 | inf_df = inf_df.style.apply(lambda x: highlight_rows(x, pred_column), axis=1) 36 | st.dataframe(inf_df, width=1200, height=390) 37 | 38 | 39 | def create_inferences_tab(wb: Whitebox, model_id: str, pred_column: str) -> None: 40 | """ 41 | Creates the Inferences tab in Streamlit. 42 | It visualises the dataframe of the inferences and also spawns 43 | the explanation part based on explainability. 44 | """ 45 | with st.spinner("Loading inferences..."): 46 | structure() 47 | st.title("Inferences") 48 | inf = wb.get_inferences(model_id) 49 | 50 | if inf: 51 | inf_df = convert_inference_to_df(inf, pred_column) 52 | 53 | explain = st.checkbox("Explain inferences") 54 | if explain: 55 | col1, col2 = st.columns(2) 56 | 57 | with col1: 58 | # TODO: Add filter for dates (eg. Show data from 'last month') 59 | viz_inference_df(inf_df, pred_column) 60 | 61 | with col2: 62 | text_input = st.text_input( 63 | "Explain an inference based on id:", 64 | placeholder="an inference id", 65 | ) 66 | 67 | if text_input: 68 | with st.spinner("Loading explanations for id: " + text_input): 69 | time.sleep(8) 70 | 71 | else: 72 | viz_inference_df(inf_df, pred_column) 73 | -------------------------------------------------------------------------------- /whitebox/streamlit/tabs/overview.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from matplotlib import pyplot as plt 3 | import pandas as pd 4 | from numpy import ndarray 5 | from sklearn.metrics import ConfusionMatrixDisplay 6 | from cards import * 7 | from utils.export import structure 8 | from utils.load import load_config 9 | 10 | import os, sys 11 | 12 | sys.path.insert(0, os.path.abspath("./")) 13 | from whitebox.schemas.model import Model 14 | 15 | 16 | # TODO: Need to connect this one with the db. 17 | # Currently one shot running for training data is not supported! 
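# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): how the two inputs consumed
# by the tab below -- a confusion matrix (ndarray) and a one-row metrics
# DataFrame -- could be produced with scikit-learn once this is wired to the
# db. The `y_true`/`y_pred` names and values are hypothetical placeholders.
#
#     import pandas as pd
#     from sklearn.metrics import (
#         accuracy_score,
#         confusion_matrix,
#         f1_score,
#         precision_score,
#         recall_score,
#     )
#
#     y_true = [0, 1, 1, 0, 1]  # placeholder ground-truth labels
#     y_pred = [0, 1, 0, 0, 1]  # placeholder model predictions
#
#     cm = confusion_matrix(y_true, y_pred)  # ndarray for plot_confusion_matrix
#     base_evaluation_metrics = pd.DataFrame(  # one-row frame; .iloc[0] is read below
#         [{
#             "accuracy": accuracy_score(y_true, y_pred),
#             "precision": precision_score(y_true, y_pred),
#             "recall": recall_score(y_true, y_pred),
#             "f1": f1_score(y_true, y_pred),
#         }]
#     )
# ---------------------------------------------------------------------------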
18 | def create_classification_performance_metrics( 19 | base_evaluation_metrics: pd.DataFrame, 20 | ) -> None: 21 | """ 22 | Create the performance metrics visualisation for a classification model in Streamlit 23 | """ 24 | col1, col2, col3, col4 = st.columns(4) 25 | col1.metric( 26 | label="Accuracy", 27 | value=base_evaluation_metrics["accuracy"].iloc[0], 28 | ) 29 | 30 | col2.metric( 31 | label="Precision", 32 | value=base_evaluation_metrics["precision"].iloc[0], 33 | ) 34 | 35 | col3.metric( 36 | label="Recall", 37 | value=base_evaluation_metrics["recall"].iloc[0], 38 | ) 39 | 40 | col4.metric( 41 | label="F1", 42 | value=base_evaluation_metrics["f1"].iloc[0], 43 | ) 44 | 45 | 46 | def create_regression_performance_metrics( 47 | base_evaluation_metrics: pd.DataFrame, 48 | ) -> None: 49 | """ 50 | Create the performance metrics visualisation for a regression model in Streamlit 51 | """ 52 | col1, col2, col3 = st.columns(3) 53 | col1.metric( 54 | label="R2", 55 | value=base_evaluation_metrics["r_square"].iloc[0], 56 | ) 57 | 58 | col2.metric( 59 | label="MSE", 60 | value=base_evaluation_metrics["mean_squared_error"].iloc[0], 61 | ) 62 | 63 | col3.metric( 64 | label="MAE", 65 | value=base_evaluation_metrics["mean_absolute_error"].iloc[0], 66 | ) 67 | 68 | 69 | def plot_confusion_matrix(confusion_matrix: ndarray, model: Model): 70 | st.header("Confusion matrix") 71 | 72 | if model["labels"]: 73 | display_labels = list(model["labels"].values()) 74 | else: 75 | display_labels = None 76 | 77 | disp = ConfusionMatrixDisplay( 78 | confusion_matrix=confusion_matrix, 79 | display_labels=display_labels, 80 | ) 81 | fig, ax = plt.subplots(figsize=(5, 5)) 82 | disp.plot(ax=ax) 83 | st.pyplot(fig) 84 | 85 | 86 | def create_overview_tab( 87 | model: Model, confusion_matrix: ndarray, base_evaluation_metrics: pd.DataFrame 88 | ) -> None: 89 | """ 90 | Creates the overview tab in Streamlit 91 | """ 92 | with st.spinner("Loading overview of the model..."): 93 | readme = load_config("config_readme.toml") 94 | structure() 95 | st.title("Overview") 96 | st.write(card(model["name"], model["type"], model["description"])) 97 | st.header("Performance") 98 | 99 | with st.expander("See explanation"): 100 | st.write(readme["tooltips"]["overview_performance"]) 101 | 102 | if model["type"] == "binary": 103 | create_classification_performance_metrics(base_evaluation_metrics) 104 | plot_confusion_matrix(confusion_matrix, model) 105 | 106 | elif model["type"] == "multi_class": 107 | create_classification_performance_metrics(base_evaluation_metrics) 108 | 109 | elif model["type"] == "regression": 110 | create_regression_performance_metrics(base_evaluation_metrics) 111 | -------------------------------------------------------------------------------- /whitebox/streamlit/tabs/performance.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import numpy as np 4 | from utils.export import structure 5 | from utils.transformation import ( 6 | get_dataframe_from_classification_performance_metrics, 7 | get_dataframe_from_regression_performance_metrics, 8 | ) 9 | from utils.graphs import create_line_graph 10 | 11 | import os, sys 12 | 13 | sys.path.insert(0, os.path.abspath("./")) 14 | from whitebox import Whitebox 15 | 16 | 17 | def create_performance_graphs(performance_df: pd.DataFrame, perf_column: str) -> None: 18 | """ 19 | Creates a graph based on a performance metric 20 | """ 21 | viz_df = performance_df[[perf_column, "timestamp"]] 22 | viz_df.columns
= ["score (%)", "time"] 23 | mean_score = round(np.mean(viz_df["score (%)"]) * 100, 2) 24 | subtitle = str(mean_score) + " %" 25 | create_line_graph(viz_df, "time", "score (%)", perf_column, subtitle, 400, 380) 26 | 27 | 28 | def create_performance_tab(wb: Whitebox, model_id: str, model_type: str) -> None: 29 | """ 30 | Creates the performance tab in Streamlit 31 | """ 32 | 33 | with st.spinner("Loading performance of the model..."): 34 | structure() 35 | st.title("Performance") 36 | performance = wb.get_performance_metrics(model_id) 37 | 38 | if performance: 39 | # Set the graphs in two columns (side by side) 40 | col1, col2 = st.columns(2) 41 | 42 | if model_type in ("binary", "multi_class"): 43 | performance_df = get_dataframe_from_classification_performance_metrics( 44 | performance 45 | ) 46 | else: 47 | # For now the only other case is regression 48 | performance_df = get_dataframe_from_regression_performance_metrics( 49 | performance 50 | ) 51 | # Keep only the metrics columns to be visualised as separate graphs 52 | perf_columns = performance_df.drop("timestamp", axis=1).columns 53 | 54 | for i in range(len(perf_columns)): 55 | if (i % 2) == 0: 56 | with col1: 57 | create_performance_graphs(performance_df, perf_columns[i]) 58 | else: 59 | with col2: 60 | create_performance_graphs(performance_df, perf_columns[i]) 61 | -------------------------------------------------------------------------------- /whitebox/streamlit/utils/export.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | import streamlit as st 3 | 4 | 5 | def display_links(repo_link: str) -> None: 6 | """Displays a link to the repository""" 7 | st.sidebar.markdown( 8 | f"<a href='{repo_link}'>Source code</a>", 9 | unsafe_allow_html=True, 10 | ) 11 | 12 | 13 | def structure() -> None: 14 | """Structures the tabs in Streamlit""" 15 | st.markdown( 16 | """ 17 | 18 | """, 19 | unsafe_allow_html=True, 20 | ) 21 | 22 | 23 | def center_image() -> None: 24 | """Markdown for setting the logo in the center""" 25 | st.markdown( 26 | """ 27 | 36 | """, 37 | unsafe_allow_html=True, 38 | ) 39 | 40 | 41 | def text_markdown(text: str, color: str, font_size: str) -> None: 42 | """Sets a text in a specific color and font size""" 43 | st.markdown( 44 | f'<p style="color: {color}; font-size: {font_size}">{text}</p>', 45 | unsafe_allow_html=True, 46 | ) 47 | -------------------------------------------------------------------------------- /whitebox/streamlit/utils/graphs.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import plotly.express as px 4 | 5 | 6 | def create_line_graph( 7 | df: pd.DataFrame, 8 | x: str, 9 | y: str, 10 | title: str, 11 | subtitle: str, 12 | height: float, 13 | width: float, 14 | markers: bool = False, 15 | ) -> None: 16 | """Plots a plotly chart in Streamlit""" 17 | fig = px.line( 18 | df, 19 | x=x, 20 | y=y, 21 | title=f"{title} <br> {subtitle}", 22 | height=height, 23 | width=width, 24 | markers=markers, 25 | ) 26 | st.plotly_chart(fig) 27 | -------------------------------------------------------------------------------- /whitebox/streamlit/utils/load.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from typing import Dict, Any 3 | import toml 4 | from pathlib import Path 5 | from PIL import Image 6 | 7 | 8 | @st.cache_data(ttl=300) 9 | def load_config(config_readme_filename: str) -> Dict[str, Any]: 10 | """Loads a configuration file. 11 | 12 | Parameters 13 | ---------- 14 | config_readme_filename : str 15 | Filename of the readme configuration file. 16 | 17 | Returns 18 | ------- 19 | dict 20 | Readme configuration file. 21 | """ 22 | config_readme = toml.load( 23 | Path(f"whitebox/streamlit/config/{config_readme_filename}") 24 | ) 25 | return dict(config_readme) 26 | 27 | 28 | @st.cache_data(ttl=300) 29 | def load_image(image_name: str): 30 | """Loads an image from the references directory. 31 | 32 | Parameters 33 | ---------- 34 | image_name : str 35 | Local path of the image. 36 | 37 | Returns 38 | ------- 39 | Image 40 | Image to be displayed. 41 | """ 42 | return Image.open(f"whitebox/streamlit/references/{image_name}") 43 | 44 | 45 | def local_css(file_name: str): 46 | with open(file_name) as f: 47 | st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True) 48 | -------------------------------------------------------------------------------- /whitebox/streamlit/utils/style.css: -------------------------------------------------------------------------------- 1 | div[data-testid="stExpander"] div[role="button"] p { 2 | font-size: 1.5rem; 3 | } -------------------------------------------------------------------------------- /whitebox/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/tests/__init__.py -------------------------------------------------------------------------------- /whitebox/tests/unit_tests/test_unit.py: -------------------------------------------------------------------------------- 1 | from whitebox.cron_tasks.shared import change_timestamp 2 | from datetime import datetime 3 | 4 | 5 | class TestNodes: 6 | def test_round_timestamp(self): 7 | timestamp = datetime(2023, 3, 7, 15, 34, 23) 8 | start_time = datetime(2023, 3, 6) 9 | 10 | assert change_timestamp(timestamp, start_time, 15, "T") == datetime( 11 | 2023, 3, 7, 15, 45 12 | ) 13 | assert change_timestamp(timestamp, start_time, 5, "H") == datetime( 14 | 2023, 3, 7, 16, 0 15 | ) 16 | assert change_timestamp(timestamp, start_time, 2, "D") == datetime(2023, 3, 8) 17 | assert change_timestamp(timestamp, start_time, 1, "W") == datetime(2023, 3, 13) 18 | -------------------------------------------------------------------------------- /whitebox/tests/utils/maps.py: -------------------------------------------------------------------------------- 1 | v1_test_order_map = [ 2 | "health", 3 | "models_no_api_key", 4 | "models_wrong_api_key", 5 | "cron_tasks_run_no_models", 6 | "models_create", 7 | "models_get_all", 8 | "models_get", 9 | "models_update", 10 | "cron_tasks_run_no_inference", 11 | "inference_rows_create", 12 | "inference_rows_create_many", 13 | "inference_rows_get_model's_all", 14 | "inference_rows_get", 15 | "dataset_rows_wrong_training_dataset", 16 | "dataset_rows_create", 17 | "dataset_rows_get_model's_all", 18 | "inference_rows_xai",
"model_monitor_create", 20 | "model_monitors_get_model_all", 21 | "model_monitor_update", 22 | "cron_tasks_run_ok", 23 | "performance_metrics_get_model_all", 24 | "drifting_metrics_get_model_all", 25 | "model_integrity_metrics_get_model_all", 26 | "alerts_get", 27 | "inference_rows_create_many_after_x_time", 28 | "cron_tasks_run_after_x_time", 29 | "drifting_metrics_get_binary_model_after_x_time", 30 | "model_monitor_delete", 31 | "models_delete", 32 | # SDK tests 33 | "sdk_init", 34 | "sdk_create_model", 35 | "sdk_get_model", 36 | "sdk_delete_model", 37 | "sdk_log_training_dataset", 38 | "sdk_log_inferences", 39 | "sdk_create_model_monitor", 40 | "sdk_update_model_monitor", 41 | "sdk_delete_model_monitor", 42 | "sdk_get_alerts", 43 | "sdk_get_drifting_metrics", 44 | "sdk_get_descriptive_statistics", 45 | "sdk_get_performance_metrics", 46 | "sdk_get_xai_row", 47 | ] 48 | -------------------------------------------------------------------------------- /whitebox/tests/v1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/tests/v1/__init__.py -------------------------------------------------------------------------------- /whitebox/tests/v1/conftest.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from fastapi.testclient import TestClient 3 | from pytest import fixture 4 | from sqlalchemy.orm import close_all_sessions 5 | from whitebox import crud 6 | from whitebox.core.settings import get_settings 7 | from whitebox.entities.Base import Base 8 | from whitebox.main import app 9 | from whitebox.sdk.whitebox import Whitebox 10 | from whitebox.tests.utils.maps import v1_test_order_map 11 | from whitebox.utils.passwords import decrypt_api_key 12 | from whitebox.core.db import SessionLocal, engine 13 | 14 | settings = get_settings() 15 | 16 | 17 | def get_order_number(task): 18 | return v1_test_order_map.index(task) 19 | 20 | 21 | @fixture(scope="session") 22 | def client(): 23 | with TestClient(app) as client: 24 | yield client 25 | 26 | 27 | @fixture(scope="session") 28 | def api_key(): 29 | db = SessionLocal() 30 | 31 | user = crud.users.get_first_by_filter(db=db, username="admin") 32 | api_key = ( 33 | decrypt_api_key(user.api_key, settings.SECRET_KEY.encode()) 34 | if settings.SECRET_KEY 35 | else user.api_key 36 | ) 37 | 38 | yield api_key 39 | 40 | 41 | @fixture(scope="session", autouse=True) 42 | def drop_db(): 43 | yield 44 | # Removes the folder "test_model" with the models trained during testing 45 | test_model_path = settings.MODEL_PATH 46 | shutil.rmtree(test_model_path) 47 | 48 | close_all_sessions() 49 | # Drops all test database tables 50 | Base.metadata.drop_all(engine) 51 | 52 | 53 | class TestsState: 54 | user: dict = {} 55 | model_binary: dict = {} 56 | model_multi: dict = {} 57 | model_multi_2: dict = {} 58 | model_multi_3: dict = {} 59 | model_regression: dict = {} 60 | inference_row_multi: dict = {} 61 | inference_row_binary: dict = {} 62 | concept_drift_monitor: dict = {} 63 | 64 | 65 | state = TestsState() 66 | 67 | 68 | class TestsSDKState: 69 | wb: Whitebox 70 | 71 | 72 | state_sdk = TestsSDKState() 73 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_alerts.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from whitebox import schemas 3 | from 
whitebox.tests.v1.conftest import get_order_number, state 4 | from fastapi import status 5 | 6 | 7 | @pytest.mark.order(get_order_number("alerts_get")) 8 | def test_alerts_get_model_all(client, api_key): 9 | response_all = client.get( 10 | f"/v1/alerts", 11 | headers={"api-key": api_key}, 12 | ) 13 | 14 | response_model_all = client.get( 15 | f"/v1/alerts?model_id={state.model_multi['id']}", 16 | headers={"api-key": api_key}, 17 | ) 18 | 19 | response_wrong_model = client.get( 20 | f"/v1/alerts?model_id=wrong_model_id", 21 | headers={"api-key": api_key}, 22 | ) 23 | 24 | assert len(response_all.json()) == 6 25 | 26 | assert response_all.status_code == status.HTTP_200_OK 27 | assert response_model_all.status_code == status.HTTP_200_OK 28 | assert response_wrong_model.status_code == status.HTTP_404_NOT_FOUND 29 | 30 | validated = [schemas.Alert(**m) for m in response_all.json()] 31 | validated = [schemas.Alert(**m) for m in response_model_all.json()] 32 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_cron_tasks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from whitebox.tests.v1.conftest import get_order_number 3 | from fastapi import status 4 | 5 | 6 | @pytest.mark.order(get_order_number("cron_tasks_run_no_models")) 7 | def test_cron_tasks_no_models(client): 8 | response = client.post("/v1/cron-tasks/run") 9 | assert response.status_code == status.HTTP_200_OK 10 | 11 | 12 | @pytest.mark.order(get_order_number("cron_tasks_run_no_inference")) 13 | def test_cron_tasks_no_inference(client): 14 | response = client.post("/v1/cron-tasks/run") 15 | assert response.status_code == status.HTTP_200_OK 16 | 17 | 18 | @pytest.mark.order(get_order_number("cron_tasks_run_ok")) 19 | def test_cron_tasks_ok(client): 20 | response = client.post("/v1/cron-tasks/run") 21 | assert response.status_code == status.HTTP_200_OK 22 | 23 | 24 | @pytest.mark.order(get_order_number("cron_tasks_run_after_x_time")) 25 | def test_cron_tasks_run_after_x_time(client): 26 | response = client.post("/v1/cron-tasks/run") 27 | assert response.status_code == status.HTTP_200_OK 28 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_dataset_rows.py: -------------------------------------------------------------------------------- 1 | from whitebox.tests.v1.mock_data import ( 2 | dataset_rows_single_row_column_payload, 3 | dataset_rows_no_prediction_column_payload, 4 | dataset_rows_one_prediction_value_payload, 5 | dataset_rows_create_multi_class_payload, 6 | dataset_rows_create_binary_payload, 7 | dataset_rows_create_wrong_model_payload, 8 | dataset_rows_create_reg_payload, 9 | ) 10 | import pytest 11 | from whitebox import schemas 12 | from whitebox.tests.v1.conftest import get_order_number, state 13 | from fastapi import status 14 | 15 | 16 | @pytest.mark.order(get_order_number("dataset_rows_wrong_training_dataset")) 17 | def test_dataset_rows_wrong_training_data(client, api_key): 18 | response_single_row = client.post( 19 | "/v1/dataset-rows", 20 | json=list( 21 | map( 22 | lambda x: {**x, "model_id": state.model_multi["id"]}, 23 | dataset_rows_single_row_column_payload, 24 | ) 25 | ), 26 | headers={"api-key": api_key}, 27 | ) 28 | 29 | response_no_prediction = client.post( 30 | "/v1/dataset-rows", 31 | json=list( 32 | map( 33 | lambda x: {**x, "model_id": state.model_multi["id"]}, 34 | dataset_rows_no_prediction_column_payload, 35 | ) 36 | ), 37 | 
headers={"api-key": api_key}, 38 | ) 39 | 40 | response_one_prediction_value = client.post( 41 | "/v1/dataset-rows", 42 | json=list( 43 | map( 44 | lambda x: {**x, "model_id": state.model_multi["id"]}, 45 | dataset_rows_one_prediction_value_payload, 46 | ) 47 | ), 48 | headers={"api-key": api_key}, 49 | ) 50 | 51 | assert response_single_row.status_code == status.HTTP_400_BAD_REQUEST 52 | assert response_no_prediction.status_code == status.HTTP_400_BAD_REQUEST 53 | assert response_one_prediction_value.status_code == status.HTTP_400_BAD_REQUEST 54 | 55 | 56 | @pytest.mark.order(get_order_number("dataset_rows_create")) 57 | def test_dataset_row_create_many(client, api_key): 58 | response_model_multi = client.post( 59 | "/v1/dataset-rows", 60 | json=list( 61 | map( 62 | lambda x: {**x, "model_id": state.model_multi["id"]}, 63 | dataset_rows_create_multi_class_payload, 64 | ) 65 | ), 66 | headers={"api-key": api_key}, 67 | ) 68 | 69 | response_model_binary = client.post( 70 | "/v1/dataset-rows", 71 | json=list( 72 | map( 73 | lambda x: {**x, "model_id": state.model_binary["id"]}, 74 | dataset_rows_create_binary_payload, 75 | ) 76 | ), 77 | headers={"api-key": api_key}, 78 | ) 79 | 80 | response_model_reg = client.post( 81 | "/v1/dataset-rows", 82 | json=list( 83 | map( 84 | lambda x: {**x, "model_id": state.model_regression["id"]}, 85 | dataset_rows_create_reg_payload, 86 | ) 87 | ), 88 | headers={"api-key": api_key}, 89 | ) 90 | 91 | response_wrong_model = client.post( 92 | "/v1/dataset-rows", 93 | json=(dataset_rows_create_wrong_model_payload), 94 | headers={"api-key": api_key}, 95 | ) 96 | 97 | assert response_model_multi.status_code == status.HTTP_201_CREATED 98 | assert response_model_binary.status_code == status.HTTP_201_CREATED 99 | assert response_model_reg.status_code == status.HTTP_201_CREATED 100 | assert response_wrong_model.status_code == status.HTTP_404_NOT_FOUND 101 | validated = [schemas.DatasetRow(**m) for m in response_model_multi.json()] 102 | validated = [schemas.DatasetRow(**m) for m in response_model_binary.json()] 103 | validated = [schemas.DatasetRow(**m) for m in response_model_reg.json()] 104 | 105 | 106 | @pytest.mark.order(get_order_number("dataset_rows_get_model's_all")) 107 | def test_dataset_row_get_models_all(client, api_key): 108 | response_model_multi = client.get( 109 | f"/v1/dataset-rows?model_id={state.model_multi['id']}", 110 | headers={"api-key": api_key}, 111 | ) 112 | response_model_binary = client.get( 113 | f"/v1/dataset-rows?model_id={state.model_binary['id']}", 114 | headers={"api-key": api_key}, 115 | ) 116 | response_model_not_found = client.get( 117 | f"/v1/dataset-rows?model_id=wrong_model_id", 118 | headers={"api-key": api_key}, 119 | ) 120 | 121 | assert response_model_multi.status_code == status.HTTP_200_OK 122 | assert response_model_binary.status_code == status.HTTP_200_OK 123 | assert response_model_not_found.status_code == status.HTTP_404_NOT_FOUND 124 | validated = [schemas.DatasetRow(**m) for m in response_model_multi.json()] 125 | validated = [schemas.DatasetRow(**m) for m in response_model_binary.json()] 126 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_drifting_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from whitebox import schemas 3 | from whitebox.tests.v1.conftest import get_order_number, state 4 | from fastapi import status 5 | 6 | 7 | 
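# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): the v1 suite appears to be
# serialized with a pytest ordering plugin (e.g. pytest-order). Each test
# takes its slot from the index of a key in v1_test_order_map
# (whitebox/tests/utils/maps.py), so state produced by earlier tests (such as
# created models stashed in `state`) is available to later ones. The test
# names below are hypothetical:
#
#     @pytest.mark.order(get_order_number("models_create"))
#     def test_runs_first(client, api_key):
#         ...  # create resources and stash them in `state`
#
#     @pytest.mark.order(get_order_number("alerts_get"))
#     def test_runs_later(client, api_key):
#         ...  # "alerts_get" comes after "models_create" in the map
# ---------------------------------------------------------------------------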
@pytest.mark.order(get_order_number("drifting_metrics_get_model_all")) 8 | def test_drifting_metric_get_model_all(client, api_key): 9 | response_multi = client.get( 10 | f"/v1/drifting-metrics?model_id={state.model_multi['id']}", 11 | headers={"api-key": api_key}, 12 | ) 13 | 14 | response_binary = client.get( 15 | f"/v1/drifting-metrics?model_id={state.model_binary['id']}", 16 | headers={"api-key": api_key}, 17 | ) 18 | response_wrong_model = client.get( 19 | f"/v1/drifting-metrics?model_id=wrong_model_id", 20 | headers={"api-key": api_key}, 21 | ) 22 | 23 | response_binary_json = response_binary.json() 24 | response_multi_json = response_multi.json() 25 | 26 | assert len(response_multi_json) == 1 27 | assert len(response_binary_json) == 1 28 | 29 | assert response_multi_json[0]["timestamp"] == "2023-03-06T12:15:00" 30 | assert response_binary_json[0]["timestamp"] == "2023-03-07T00:00:00" 31 | 32 | assert response_multi.status_code == status.HTTP_200_OK 33 | assert response_binary.status_code == status.HTTP_200_OK 34 | assert response_wrong_model.status_code == status.HTTP_404_NOT_FOUND 35 | 36 | validated = [schemas.DriftingMetric(**m) for m in response_multi_json] 37 | validated = [schemas.DriftingMetric(**m) for m in response_binary_json] 38 | 39 | 40 | @pytest.mark.order(get_order_number("drifting_metrics_get_binary_model_after_x_time")) 41 | def test_drifting_metrics_get_binary_model_after_x_time(client, api_key): 42 | response_binary = client.get( 43 | f"/v1/drifting-metrics?model_id={state.model_binary['id']}", 44 | headers={"api-key": api_key}, 45 | ) 46 | 47 | response_binary_json = response_binary.json() 48 | 49 | assert len(response_binary_json) == 1 50 | 51 | assert response_binary_json[0]["timestamp"] == "2023-03-07T00:00:00" 52 | 53 | assert response_binary.status_code == status.HTTP_200_OK 54 | 55 | validated = [schemas.DriftingMetric(**m) for m in response_binary_json] 56 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_errors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from whitebox.tests.v1.conftest import get_order_number 3 | from fastapi import status 4 | 5 | 6 | @pytest.mark.order(get_order_number("models_no_api_key")) 7 | def test_model_no_api_key(client): 8 | response = client.get("/v1/models") 9 | 10 | assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY 11 | 12 | 13 | @pytest.mark.order(get_order_number("models_wrong_api_key")) 14 | def test_model_wrong_api_key(client): 15 | response = client.get( 16 | "/v1/models", 17 | headers={"api-key": "1234567890"}, 18 | ) 19 | 20 | assert response.status_code == status.HTTP_401_UNAUTHORIZED 21 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_health.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from whitebox import schemas 3 | from whitebox.tests.v1.conftest import get_order_number 4 | from fastapi import status 5 | 6 | 7 | @pytest.mark.order(get_order_number("health")) 8 | def test_health(client): 9 | response = client.get("/v1/health") 10 | assert response.status_code == status.HTTP_200_OK 11 | validated = schemas.HealthCheck(**response.json()) 12 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_model_integrity_metrics.py: -------------------------------------------------------------------------------- 1 | import 
pytest 2 | from whitebox import schemas 3 | from whitebox.tests.v1.conftest import get_order_number, state 4 | from fastapi import status 5 | 6 | 7 | @pytest.mark.order(get_order_number("model_integrity_metrics_get_model_all")) 8 | def test_model_integrity_metric_get_model_all(client, api_key): 9 | response_multi = client.get( 10 | f"/v1/model-integrity-metrics?model_id={state.model_multi['id']}", 11 | headers={"api-key": api_key}, 12 | ) 13 | response_binary = client.get( 14 | f"/v1/model-integrity-metrics?model_id={state.model_binary['id']}", 15 | headers={"api-key": api_key}, 16 | ) 17 | response_wrong_model = client.get( 18 | f"/v1/model-integrity-metrics?model_id=wrong_model_id", 19 | headers={"api-key": api_key}, 20 | ) 21 | 22 | assert len(response_multi.json()) == 1 23 | assert len(response_binary.json()) == 1 24 | 25 | assert response_multi.status_code == status.HTTP_200_OK 26 | assert response_binary.status_code == status.HTTP_200_OK 27 | assert response_wrong_model.status_code == status.HTTP_404_NOT_FOUND 28 | 29 | validated = [schemas.ModelIntegrityMetric(**m) for m in response_multi.json()] 30 | validated = [schemas.ModelIntegrityMetric(**m) for m in response_binary.json()] 31 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_models.py: -------------------------------------------------------------------------------- 1 | from whitebox.tests.v1.mock_data import ( 2 | model_binary_create_payload, 3 | model_multi_create_payload, 4 | model_multi_2_create_payload, 5 | model_multi_3_create_payload, 6 | model_regression_create_payload, 7 | model_update_payload, 8 | ) 9 | import pytest 10 | from whitebox import schemas 11 | from whitebox.tests.v1.conftest import get_order_number, state 12 | from fastapi import status 13 | 14 | 15 | @pytest.mark.order(get_order_number("models_create")) 16 | def test_model_create(client, api_key): 17 | response_binary = client.post( 18 | "/v1/models", 19 | json={**model_binary_create_payload}, 20 | headers={"api-key": api_key}, 21 | ) 22 | 23 | response_multi = client.post( 24 | "/v1/models", 25 | json={**model_multi_create_payload}, 26 | headers={"api-key": api_key}, 27 | ) 28 | 29 | response_multi_2 = client.post( 30 | "/v1/models", 31 | json={**model_multi_2_create_payload}, 32 | headers={"api-key": api_key}, 33 | ) 34 | 35 | response_multi_3 = client.post( 36 | "/v1/models", 37 | json={**model_multi_3_create_payload}, 38 | headers={"api-key": api_key}, 39 | ) 40 | 41 | response_regression = client.post( 42 | "/v1/models", 43 | json={**model_regression_create_payload}, 44 | headers={"api-key": api_key}, 45 | ) 46 | 47 | state.model_binary = response_binary.json() 48 | state.model_multi = response_multi.json() 49 | state.model_multi_2 = response_multi_2.json() 50 | state.model_multi_3 = response_multi_3.json() 51 | state.model_regression = response_regression.json() 52 | 53 | assert response_binary.status_code == status.HTTP_201_CREATED 54 | assert response_multi.status_code == status.HTTP_201_CREATED 55 | assert response_multi_2.status_code == status.HTTP_201_CREATED 56 | assert response_multi_3.status_code == status.HTTP_201_CREATED 57 | assert response_regression.status_code == status.HTTP_201_CREATED 58 | 59 | validated = schemas.Model(**response_binary.json()) 60 | validated = schemas.Model(**response_multi.json()) 61 | validated = schemas.Model(**response_multi_2.json()) 62 | validated = schemas.Model(**response_multi_3.json()) 63 | validated = schemas.Model(**response_regression.json()) 64 
| 65 | 66 | @pytest.mark.order(get_order_number("models_get_all")) 67 | def test_model_get_all(client, api_key): 68 | response = client.get(f"/v1/models", headers={"api-key": api_key}) 69 | assert response.status_code == status.HTTP_200_OK 70 | validated = [schemas.Model(**m) for m in response.json()] 71 | 72 | 73 | @pytest.mark.order(get_order_number("models_get")) 74 | def test_model_get(client, api_key): 75 | response = client.get( 76 | f"/v1/models/{state.model_multi['id']}", headers={"api-key": api_key} 77 | ) 78 | response_wrong_model = client.get( 79 | f"/v1/models/wrong_model_id", headers={"api-key": api_key} 80 | ) 81 | 82 | assert response.status_code == status.HTTP_200_OK 83 | assert response_wrong_model.status_code == status.HTTP_404_NOT_FOUND 84 | 85 | validated = schemas.Model(**response.json()) 86 | 87 | 88 | @pytest.mark.order(get_order_number("models_update")) 89 | def test_model_update(client, api_key): 90 | response = client.put( 91 | f"/v1/models/{state.model_multi['id']}", 92 | json=model_update_payload, 93 | headers={"api-key": api_key}, 94 | ) 95 | response_wrong_model = client.put( 96 | f"/v1/models/wrong_model_id", 97 | json=model_update_payload, 98 | headers={"api-key": api_key}, 99 | ) 100 | 101 | assert response.status_code == status.HTTP_200_OK 102 | assert response_wrong_model.status_code == status.HTTP_404_NOT_FOUND 103 | 104 | validated = schemas.Model(**response.json()) 105 | 106 | 107 | @pytest.mark.order(get_order_number("models_delete")) 108 | def test_model_delete(client, api_key): 109 | response_binary = client.delete( 110 | f"/v1/models/{state.model_binary['id']}", headers={"api-key": api_key} 111 | ) 112 | response_multi = client.delete( 113 | f"/v1/models/{state.model_multi['id']}", headers={"api-key": api_key} 114 | ) 115 | response_multi_2 = client.delete( 116 | f"/v1/models/{state.model_multi_2['id']}", headers={"api-key": api_key} 117 | ) 118 | response_multi_3 = client.delete( 119 | f"/v1/models/{state.model_multi_3['id']}", headers={"api-key": api_key} 120 | ) 121 | response_regression = client.delete( 122 | f"/v1/models/{state.model_regression['id']}", headers={"api-key": api_key} 123 | ) 124 | response_no_model = client.delete( 125 | f"/v1/models/{state.model_binary['id']}", headers={"api-key": api_key} 126 | ) 127 | 128 | assert response_binary.status_code == status.HTTP_200_OK 129 | assert response_multi.status_code == status.HTTP_200_OK 130 | assert response_multi_2.status_code == status.HTTP_200_OK 131 | assert response_multi_3.status_code == status.HTTP_200_OK 132 | assert response_regression.status_code == status.HTTP_200_OK 133 | assert response_no_model.status_code == status.HTTP_404_NOT_FOUND 134 | -------------------------------------------------------------------------------- /whitebox/tests/v1/test_performance_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from whitebox import schemas 3 | from whitebox.tests.v1.conftest import get_order_number, state 4 | from fastapi import status 5 | 6 | 7 | @pytest.mark.order(get_order_number("performance_metrics_get_model_all")) 8 | def test_performance_metric_get_model_all(client, api_key): 9 | response_multi = client.get( 10 | f"/v1/performance-metrics?model_id={state.model_multi['id']}", 11 | headers={"api-key": api_key}, 12 | ) 13 | response_binary = client.get( 14 | f"/v1/performance-metrics?model_id={state.model_binary['id']}", 15 | headers={"api-key": api_key}, 16 | ) 17 | response_reg = client.get( 18 | 
f"/v1/performance-metrics?model_id={state.model_regression['id']}", 19 | headers={"api-key": api_key}, 20 | ) 21 | response_wrong_model = client.get( 22 | f"/v1/performance-metrics?model_id=wrong_model_id", 23 | headers={"api-key": api_key}, 24 | ) 25 | 26 | assert response_multi.status_code == status.HTTP_200_OK 27 | assert response_binary.status_code == status.HTTP_200_OK 28 | assert response_reg.status_code == status.HTTP_200_OK 29 | assert response_wrong_model.status_code == status.HTTP_404_NOT_FOUND 30 | 31 | validated = [schemas.MultiClassificationMetrics(**m) for m in response_multi.json()] 32 | validated = [ 33 | schemas.BinaryClassificationMetrics(**m) for m in response_binary.json() 34 | ] 35 | validated = [schemas.RegressionMetrics(**m) for m in response_reg.json()] 36 | -------------------------------------------------------------------------------- /whitebox/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/squaredev-io/whitebox/9524d86fa07a135536811a8bc70dcc2f5eabc468/whitebox/utils/__init__.py -------------------------------------------------------------------------------- /whitebox/utils/errors.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from fastapi import status 3 | from fastapi.encoders import jsonable_encoder 4 | from fastapi.responses import JSONResponse 5 | from starlette.exceptions import HTTPException as StarletteHTTPException 6 | from fastapi.exceptions import HTTPException 7 | from starlette.requests import Request 8 | from whitebox.schemas.utils import ErrorProps 9 | from whitebox.utils.logger import log 10 | 11 | 12 | class CustomError(BaseException): 13 | async def http_exception_handler( 14 | self, _: Request, exc: StarletteHTTPException 15 | ) -> ErrorProps: 16 | log.error(f"{exc.status_code}: {exc.detail}") 17 | return JSONResponse( 18 | status_code=exc.status_code, 19 | content=jsonable_encoder( 20 | {"error": f"{str(exc.detail)}", "status_code": exc.status_code} 21 | ), 22 | ) 23 | 24 | async def validation_exception_handler( 25 | self, 26 | _: Request, 27 | exc: HTTPException, 28 | ) -> ErrorProps: 29 | responsible_value = exc.errors()[0]["loc"][-1] 30 | reason = exc.errors()[0]["msg"] 31 | log.error( 32 | f"{status.HTTP_422_UNPROCESSABLE_ENTITY}: ({str(responsible_value)}) {str(reason)}" 33 | ) 34 | return JSONResponse( 35 | status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, 36 | content=jsonable_encoder( 37 | { 38 | "error": f"({str(responsible_value)}) {str(reason)}", 39 | "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY, 40 | } 41 | ), 42 | ) 43 | 44 | def bad_request(self, msg: str = "Bad request") -> ErrorProps: 45 | log.error(f"{status.HTTP_400_BAD_REQUEST}: {str(msg)}") 46 | return JSONResponse( 47 | status_code=status.HTTP_400_BAD_REQUEST, 48 | content=jsonable_encoder( 49 | {"error": str(msg), "status_code": status.HTTP_400_BAD_REQUEST} 50 | ), 51 | ) 52 | 53 | def not_found(self, msg: str = "Content not found") -> ErrorProps: 54 | log.error(f"{status.HTTP_404_NOT_FOUND}: {str(msg)}") 55 | return JSONResponse( 56 | status_code=status.HTTP_404_NOT_FOUND, 57 | content=jsonable_encoder( 58 | {"error": str(msg), "status_code": status.HTTP_404_NOT_FOUND} 59 | ), 60 | ) 61 | 62 | 63 | errors = CustomError() 64 | 65 | 66 | def add_error_responses(status_codes) -> List[ErrorProps]: 67 | """ 68 | For the schema to work the part after schemas/ should correspond to a title error schema in 
whitebox/api/app/v1/docs.py 69 | """ 70 | 71 | error_responses = { 72 | 400: { 73 | "description": "Bad Request", 74 | "content": { 75 | "application/json": { 76 | "schema": {"$ref": "#/components/schemas/BadRequest"} 77 | } 78 | }, 79 | }, 80 | 401: { 81 | "description": "Authorization Error", 82 | "content": { 83 | "application/json": { 84 | "schema": {"$ref": "#/components/schemas/AuthorizationError"} 85 | } 86 | }, 87 | }, 88 | 404: { 89 | "description": "Not Found Error", 90 | "content": { 91 | "application/json": { 92 | "schema": {"$ref": "#/components/schemas/NotFoundError"} 93 | } 94 | }, 95 | }, 96 | 409: { 97 | "description": "Conflict Error", 98 | "content": { 99 | "application/json": { 100 | "schema": {"$ref": "#/components/schemas/ConflictError"} 101 | } 102 | }, 103 | }, 104 | 410: { 105 | "description": "Content Gone", 106 | "content": { 107 | "application/json": { 108 | "schema": {"$ref": "#/components/schemas/ContenGone"} 109 | } 110 | }, 111 | }, 112 | 422: { 113 | "description": "Validation Error", 114 | "content": { 115 | "application/json": { 116 | "schema": {"$ref": "#/components/schemas/HTTPValidationError"} 117 | } 118 | }, 119 | }, 120 | } 121 | responses = {} 122 | for code in status_codes: 123 | responses[code] = error_responses[code] 124 | return responses 125 | -------------------------------------------------------------------------------- /whitebox/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | class TaskAlreadyRunningException(Exception): 2 | pass 3 | 4 | 5 | class TaskNotFoundException(Exception): 6 | pass 7 | 8 | 9 | class TaskNotRunningException(Exception): 10 | pass 11 | -------------------------------------------------------------------------------- /whitebox/utils/id_gen.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | 3 | 4 | def generate_uuid(): 5 | return str(uuid4()) 6 | -------------------------------------------------------------------------------- /whitebox/utils/logger.py: -------------------------------------------------------------------------------- 1 | from colorama import Fore, Back, Style 2 | import logging 3 | 4 | 5 | class Logger: 6 | def info(self, info): 7 | print(Fore.LIGHTBLUE_EX + "INFO" + Fore.BLACK + ":" + 4 * " ", info) 8 | 9 | def error(self, error): 10 | print(Fore.RED + "ERROR" + Fore.BLACK + ":" + 3 * " ", error) 11 | 12 | def success(self, msg): 13 | print(Fore.GREEN + "SUCCESS" + Fore.BLACK + ":" + 1 * " ", msg) 14 | 15 | 16 | log = Logger() 17 | 18 | 19 | logging.basicConfig(level="INFO") 20 | 21 | cronLogger = logging.getLogger("cron") 22 | -------------------------------------------------------------------------------- /whitebox/utils/passwords.py: -------------------------------------------------------------------------------- 1 | from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 2 | from cryptography.hazmat.primitives import padding 3 | from cryptography.hazmat.backends import default_backend 4 | import secrets 5 | from whitebox.core.settings import get_settings 6 | 7 | 8 | settings = get_settings() 9 | 10 | 11 | def to_utf8(ps): 12 | return ps.encode("utf-8") 13 | 14 | 15 | def encrypt_api_key(password: str, key: bytes) -> str: 16 | iv = secrets.token_bytes(16) 17 | cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) 18 | encryptor = cipher.encryptor() 19 | padder = padding.PKCS7(128).padder() 20 | padded_data = 
padder.update(to_utf8(password)) + padder.finalize() 21 | ct = encryptor.update(padded_data) + encryptor.finalize() 22 | return (iv + ct).hex() 23 | 24 | 25 | def decrypt_api_key(password: str, key: bytes) -> str: 26 | ct = bytes.fromhex(password) 27 | iv = ct[:16] 28 | ct = ct[16:] 29 | cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) 30 | decryptor = cipher.decryptor() 31 | pt = decryptor.update(ct) + decryptor.finalize() 32 | unpadder = padding.PKCS7(128).unpadder() 33 | data = unpadder.update(pt) + unpadder.finalize() 34 | return data.decode("utf-8") 35 | 36 | 37 | def passwords_match(api_key_in_db: str, api_key_in_headers: str): 38 | if not settings.SECRET_KEY: 39 | return api_key_in_db == api_key_in_headers 40 | 41 | api_key = decrypt_api_key(api_key_in_db, settings.SECRET_KEY.encode()) 42 | return api_key == api_key_in_headers 43 | --------------------------------------------------------------------------------
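A quick round-trip check of the API-key helpers in whitebox/utils/passwords.py — a minimal sketch, assuming a 32-byte key (AES-256); in the app the key comes from settings.SECRET_KEY (e.g. the 32-character SECRET_KEY in .env.test encodes to 32 bytes):

import secrets
from whitebox.utils.passwords import encrypt_api_key, decrypt_api_key

key = secrets.token_bytes(32)  # illustrative key; AES accepts 16-, 24- or 32-byte keys
token = encrypt_api_key("my-api-key", key)  # hex string: random 16-byte IV + ciphertext
assert decrypt_api_key(token, key) == "my-api-key"  # round trip recovers the plaintext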