├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ ├── documentation.md │ └── feature_request.md ├── changelog_config.json └── workflows │ ├── pre-commit-check.yml │ ├── run-tests-cpu.yaml │ ├── run-tests-gpu.yaml │ └── workflow.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── docs ├── api.md ├── favicon.png ├── index.md └── logo.png ├── examples ├── flat.ipynb ├── language.ipynb └── sequential.ipynb ├── mkdocs.yml ├── mostlyai └── engine │ ├── __init__.py │ ├── _common.py │ ├── _dtypes.py │ ├── _encoding_types │ ├── __init__.py │ ├── language │ │ ├── __init__.py │ │ ├── categorical.py │ │ ├── datetime.py │ │ ├── numeric.py │ │ └── text.py │ └── tabular │ │ ├── __init__.py │ │ ├── categorical.py │ │ ├── character.py │ │ ├── datetime.py │ │ ├── itt.py │ │ ├── lat_long.py │ │ └── numeric.py │ ├── _language │ ├── __init__.py │ ├── common.py │ ├── encoding.py │ ├── engine │ │ ├── __init__.py │ │ ├── base.py │ │ ├── hf_engine.py │ │ └── vllm_engine.py │ ├── generation.py │ ├── lstm.py │ ├── tokenizer_utils.py │ ├── training.py │ └── xgrammar_utils.py │ ├── _memory.py │ ├── _tabular │ ├── __init__.py │ ├── argn.py │ ├── common.py │ ├── encoding.py │ ├── fairness.py │ ├── generation.py │ └── training.py │ ├── _training_utils.py │ ├── _workspace.py │ ├── analysis.py │ ├── domain.py │ ├── encoding.py │ ├── generation.py │ ├── logging.py │ ├── random_state.py │ ├── splitting.py │ └── training.py ├── pyproject.toml ├── tests ├── __init__.py ├── end_to_end │ ├── __init__.py │ ├── conftest.py │ ├── test_language.py │ ├── test_numeric.py │ ├── test_tabular_flat.py │ ├── test_tabular_sequential.py │ └── test_tabular_sequential_context.py └── unit │ ├── __init__.py │ ├── encoding_types │ ├── __init__.py │ ├── language │ │ ├── __init__.py │ │ ├── test_categorical.py │ │ ├── test_datetime.py │ │ └── test_numeric.py │ └── tabular │ │ ├── __init__.py │ │ ├── test_categorical.py │ │ ├── test_character.py │ │ ├── test_datetime.py │ │ ├── test_itt.py │ │ ├── test_lat_long.py │ │ └── test_numeric.py │ ├── fixtures │ └── workspace │ │ ├── all │ │ ├── ModelStore │ │ │ ├── ctx-meta │ │ │ │ ├── encoding-types.json │ │ │ │ └── keys.json │ │ │ ├── ctx-stats │ │ │ │ ├── part.000000-trn.json │ │ │ │ ├── part.000000-val.json │ │ │ │ └── stats.json │ │ │ ├── model-data │ │ │ │ ├── model-configs.json │ │ │ │ └── model-weights.pt │ │ │ ├── tgt-meta │ │ │ │ ├── encoding-types.json │ │ │ │ └── keys.json │ │ │ └── tgt-stats │ │ │ │ ├── part.000000-trn.json │ │ │ │ ├── part.000000-val.json │ │ │ │ └── stats.json │ │ ├── OriginalData │ │ │ ├── ctx-data │ │ │ │ ├── part.000000-trn.parquet │ │ │ │ └── part.000000-val.parquet │ │ │ ├── encoded-data │ │ │ │ ├── part.000000-trn.parquet │ │ │ │ └── part.000000-val.parquet │ │ │ └── tgt-data │ │ │ │ ├── part.000000-trn.parquet │ │ │ │ └── part.000000-val.parquet │ │ └── SyntheticData │ │ │ ├── part.000001.parquet │ │ │ └── part.000002.parquet │ │ └── some │ │ ├── ModelStore │ │ └── tgt-meta │ │ │ ├── encoding-types.json │ │ │ └── keys.json │ │ └── OriginalData │ │ └── tgt-data │ │ ├── part.000000-trn.parquet │ │ └── part.000000-val.parquet │ ├── test_analysis.py │ ├── test_argn.py │ ├── test_common.py │ ├── test_domain.py │ ├── test_encoding.py │ ├── test_fairness.py │ ├── test_generation.py │ ├── test_memory.py │ ├── test_splitting.py │ └── test_workspace.py └── uv.lock /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 
"\U0001F41B Bug Report" 3 | about: Create a report to help us reproduce and fix the bug 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | Please provide a clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Code to reproduce the behavior: 15 | ``` 16 | # All necessary imports at the beginning 17 | import pandas as pd 18 | from mostlyai import engine 19 | # A succinct reproducing example trimmed down to the essential parts: 20 | df = pd.DataFrame({'x': [1, 2, 3]}) 21 | engine.split(...) 22 | engine.analyze(...) 23 | engine.encode(...) 24 | engine.train(...) 25 | engine.generate(...) 26 | ``` 27 | 28 | **Expected behavior** 29 | A clear and concise description of what you expected to happen. 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Questions 4 | url: https://github.com/mostly-ai/mostlyai-engine/discussions 5 | about: Ask questions and discuss with other community members 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Documentation" 3 | about: Report an issue related to the documentation. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the issue** 11 | Please provide a clear and concise description of what the issue is. 12 | 13 | **Expected behavior** 14 | A clear and concise description of what you expected to see. 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature Request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | -------------------------------------------------------------------------------- /.github/changelog_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "# What's Changed\n\n#{{CHANGELOG}}\n\n**Full Changelog**: [#{{FROM_TAG}}...#{{TO_TAG}}](#{{RELEASE_DIFF}})", 3 | "pr_template": "- #{{TITLE}} [##{{NUMBER}}](#{{URL}})", 4 | "empty_template": "No Changes", 5 | "categories": [ 6 | { 7 | "title": "## 🚀 Features", 8 | "labels": ["feat"] 9 | }, 10 | { 11 | "title": "## 🐛 Fixes", 12 | "labels": ["fix"] 13 | }, 14 | { 15 | "title": "## 📦 Uncategorized", 16 | "labels": ["chore", "build", "docs", "refactor", "style"] 17 | } 18 | ], 19 | "ignore_labels": ["bump", "ci"], 20 | "label_extractor": [ 21 | { 22 | "pattern": "^([\\w-]+)(?:\\(([^)]+)\\))?: (.+)$", 23 | "target": "$1", 24 | "on_property": "title" 25 | } 26 | ], 27 | "transformers": [ 28 | { 29 | "pattern": "^(?:[^:]+:\\s*)?(.*)$", 30 | "method": "replace", 31 | "target": "$1", 32 | "on_property": "title" 33 | } 34 | ] 35 | } 36 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-check.yml: -------------------------------------------------------------------------------- 1 | name: Pre-Commit Check 2 | 3 | on: [workflow_call] 4 | 5 | jobs: 6 | pre-commit-check: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: '3.10' 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install pre-commit 18 | pre-commit install 19 | - name: Run pre-commit 20 | run: pre-commit run --all-files 21 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-cpu.yaml: -------------------------------------------------------------------------------- 1 | name: "[CPU] mostlyai-engine Tests" 2 | 3 | on: 4 | workflow_call: 5 | 6 | 7 | env: 8 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 9 | FORCE_COLOR: "1" 10 | 11 | jobs: 12 | run-tests-cpu-unit-sequential: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: read 16 | packages: write 17 | steps: 18 | - name: Setup | Checkout 19 | uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 22 | submodules: 'recursive' 23 | 24 | 25 | - name: Setup | uv 26 | uses: astral-sh/setup-uv@v5 27 | with: 28 | enable-cache: false 29 | python-version: '3.10' 30 | 31 | - name: Setup | Dependencies 32 | run: | 33 | uv sync --frozen --only-group dev --only-group docs 34 | uv pip install --index-strategy unsafe-first-match torch==2.6.0+cpu torchvision==0.21.0+cpu . 
--extra-index-url https://download.pytorch.org/whl/cpu 35 | 36 | - name: Run | Tests -> unit 37 | run: uv run --no-sync pytest tests/unit 38 | 39 | - name: Build mkdocs 40 | run: uv run --no-sync mkdocs build --strict 41 | 42 | - name: Run tests -> end_to_end -> sequential 43 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py 44 | 45 | - name: Run tests -> end_to_end -> sequential context 46 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py 47 | 48 | run-tests-cpu-end-to-end-nonsequential: 49 | runs-on: ubuntu-latest 50 | permissions: 51 | contents: read 52 | packages: write 53 | steps: 54 | - name: Setup | Checkout 55 | uses: actions/checkout@v4 56 | with: 57 | fetch-depth: 0 58 | submodules: 'recursive' 59 | 60 | 61 | - name: Setup | uv 62 | uses: astral-sh/setup-uv@v5 63 | with: 64 | enable-cache: false 65 | python-version: '3.10' 66 | 67 | - name: Setup | Dependencies 68 | run: | 69 | uv sync --frozen --only-group dev 70 | uv pip install --index-strategy unsafe-first-match torch==2.6.0+cpu torchvision==0.21.0+cpu . --extra-index-url https://download.pytorch.org/whl/cpu 71 | 72 | - name: Run tests -> end_to_end all except sequential 73 | run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ 74 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-gpu.yaml: -------------------------------------------------------------------------------- 1 | name: "[GPU] mostlyai-engine Tests" 2 | 3 | on: 4 | workflow_call: 5 | 6 | 7 | env: 8 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 9 | FORCE_COLOR: "1" 10 | 11 | jobs: 12 | run-tests-gpu: 13 | runs-on: gha-gpu-public 14 | container: 15 | image: nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 16 | options: --gpus all 17 | permissions: 18 | contents: read 19 | packages: write 20 | steps: 21 | - name: Setup | Install Git 22 | run: | 23 | apt-get update -qq 24 | apt-get install -y --no-install-recommends git build-essential 25 | 26 | - name: Setup | Checkout 27 | uses: actions/checkout@v4 28 | with: 29 | fetch-depth: 0 30 | submodules: 'recursive' 31 | 32 | - name: Setup | uv 33 | uses: astral-sh/setup-uv@v5 34 | with: 35 | enable-cache: false 36 | python-version: '3.10' 37 | 38 | - name: Setup | Dependencies 39 | run: | 40 | uv sync --frozen --only-group dev 41 | uv pip install ".[gpu]" 42 | 43 | - name: Setup | Check for available GPU-s 44 | run: nvidia-smi 45 | 46 | - name: Run tests -> end_to_end -> sequential 47 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py 48 | 49 | - name: Run tests -> end_to_end -> sequential context 50 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py 51 | 52 | - name: Run tests -> end_to_end all except sequential 53 | run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ 54 | -------------------------------------------------------------------------------- /.github/workflows/workflow.yaml: -------------------------------------------------------------------------------- 1 | name: Complete Workflow 2 | 3 | on: [push, pull_request] 4 | 5 | env: 6 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 7 | FORCE_COLOR: "1" 8 | 9 | jobs: 10 | pre-commit-check: 11 | if: | 12 | github.event_name == 'push' || 13 | (github.event_name == 
'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 14 | uses: ./.github/workflows/pre-commit-check.yml 15 | secrets: inherit 16 | run-tests-cpu: 17 | if: | 18 | github.event_name == 'push' || 19 | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 20 | uses: ./.github/workflows/run-tests-cpu.yaml 21 | secrets: inherit 22 | run-tests-gpu: 23 | if: | 24 | github.ref == 'refs/heads/main' || 25 | startsWith(github.ref, 'refs/tags/') || 26 | contains(github.event.head_commit.message, '[gpu]') 27 | uses: ./.github/workflows/run-tests-gpu.yaml 28 | secrets: inherit 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | .vscode/ 4 | .ipynb_checkpoints/ 5 | .DS_Store 6 | dist/ 7 | examples/ws-*/ 8 | LICENSE_HEADER 9 | /site/ 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | exclude: '^(examples)/' 3 | repos: 4 | - repo: local 5 | hooks: 6 | - id: generate-license-header 7 | name: Generate temporary license header file 8 | entry: | 9 | bash -c ' 10 | HEADER_CONTENT="Copyright 2025 MOSTLY AI\n\ 11 | \n\ 12 | Licensed under the Apache License, Version 2.0 (the \"License\");\n\ 13 | you may not use this file except in compliance with the License.\n\ 14 | You may obtain a copy of the License at\n\ 15 | \n\ 16 | http://www.apache.org/licenses/LICENSE-2.0\n\ 17 | \n\ 18 | Unless required by applicable law or agreed to in writing, software\n\ 19 | distributed under the License is distributed on an \"AS IS\" BASIS,\n\ 20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n\ 21 | See the License for the specific language governing permissions and\n\ 22 | limitations under the License." 23 | 24 | echo -e "$HEADER_CONTENT" > LICENSE_HEADER 25 | ' 26 | language: system 27 | - repo: https://github.com/Lucas-C/pre-commit-hooks 28 | rev: v1.5.5 29 | hooks: 30 | - id: insert-license 31 | files: \.py$ 32 | args: 33 | # - --remove-header 34 | - --license-filepath 35 | - LICENSE_HEADER 36 | - --use-current-year 37 | - repo: https://github.com/pre-commit/pre-commit-hooks 38 | rev: v5.0.0 39 | hooks: 40 | - id: end-of-file-fixer 41 | - id: trailing-whitespace 42 | - id: end-of-file-fixer 43 | - id: check-json 44 | - id: mixed-line-ending 45 | args: [--fix=lf] 46 | - repo: https://github.com/asottile/pyupgrade 47 | rev: v3.19.1 48 | hooks: 49 | - id: pyupgrade 50 | args: [--py310-plus] 51 | - repo: https://github.com/astral-sh/ruff-pre-commit 52 | rev: v0.11.6 53 | hooks: 54 | - id: ruff 55 | args: [ --fix ] 56 | - id: ruff-format 57 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Synthetic Data Engine 2 | 3 | Thanks for your interest in contributing to Synthetic Data Engine! Follow these guidelines to set up your environment and streamline your contributions. 4 | 5 | ## Setup 6 | 7 | 1. 
**Clone the repository**: 8 | ```bash 9 | git clone https://github.com/mostly-ai/mostlyai-engine.git 10 | cd mostlyai-engine 11 | ``` 12 | If you don’t have direct write access to `mostlyai-engine`, fork the repository first and clone your fork: 13 | ```bash 14 | git clone https://github.com//mostlyai-engine.git 15 | cd mostlyai-engine 16 | ``` 17 | 18 | 2. **Install `uv` (if not installed already)**: 19 | ```bash 20 | curl -LsSf https://astral.sh/uv/install.sh | sh 21 | ``` 22 | For alternative installation methods, visit the [uv installation guide](https://docs.astral.sh/uv/getting-started/installation/). 23 | 24 | 3. **Create a virtual environment and install dependencies**: 25 | ```bash 26 | uv sync --frozen --extra cpu --python=3.10 # For CPU-only 27 | source .venv/bin/activate 28 | ``` 29 | If using GPU, run: 30 | ```bash 31 | uv sync --frozen --extra gpu --python=3.10 # For GPU support 32 | source .venv/bin/activate 33 | ``` 34 | 35 | 4. **Install pre-commit hooks**: 36 | ```bash 37 | pre-commit install 38 | ``` 39 | 40 | ## Development Workflow 41 | 42 | 1. **Ensure your local `main` branch is up to date**: 43 | ```bash 44 | git checkout main 45 | git reset --hard origin/main 46 | git pull origin main 47 | ``` 48 | 49 | 2. **Create a new feature or bugfix branch**: 50 | ```bash 51 | git checkout -b my-feature-branch 52 | ``` 53 | 54 | 3. **Implement your changes.** 55 | 56 | 4. **Run tests and pre-commit hooks**: 57 | ```bash 58 | pytest 59 | pre-commit run 60 | ``` 61 | 62 | 5. **Commit your changes with a descriptive message**: 63 | ```bash 64 | git add . 65 | git commit -m "feat: add a clear description of your feature" 66 | ``` 67 | Follow the [Conventional Commits](https://gist.github.com/qoomon/5dfcdf8eec66a051ecd85625518cfd13) format. 68 | 69 | 6. **Push your changes**: 70 | ```bash 71 | git push origin my-feature-branch 72 | ``` 73 | 74 | 7. **Open a pull request on GitHub.** 75 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help 2 | help: ## Show definition of each function 3 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z1-9_-]+:.*?## / {printf "\033[36m%-25s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) 4 | 5 | .PHONY: clean 6 | clean: ## Remove .gitignore files 7 | git clean -fdX 8 | 9 | .PHONY: install 10 | install: # Install dependencies 11 | uv sync --frozen 12 | 13 | .PHONY: lint 14 | lint: ## Run lints 15 | uv run --no-sync pre-commit run --all-files 16 | 17 | .PHONY: test 18 | test: ## Run tests 19 | uv run --no-sync pytest 20 | 21 | .PHONY: all 22 | all: clean install lint test ## Run the commands: clean install lint test 23 | 24 | # Default files to update 25 | PYPROJECT_TOML = pyproject.toml 26 | INIT_FILE = mostlyai/engine/__init__.py 27 | 28 | # Internal Variables for Release Workflow 29 | BUMP_TYPE ?= patch 30 | CURRENT_VERSION := $(shell grep -m 1 'version = ' $(PYPROJECT_TOML) | sed -e 's/version = "\(.*\)"/\1/') 31 | # Assuming current_version is already set from pyproject.toml 32 | NEW_VERSION := $(shell echo $(CURRENT_VERSION) | awk -F. -v bump=$(BUMP_TYPE) '{ \ 33 | if (bump == "patch") { \ 34 | printf("%d.%d.%d", $$1, $$2, $$3 + 1); \ 35 | } else if (bump == "minor") { \ 36 | printf("%d.%d.0", $$1, $$2 + 1); \ 37 | } else if (bump == "major") { \ 38 | printf("%d.0.0", $$1 + 1); \ 39 | } else { \ 40 | print "Error: Invalid BUMP_TYPE. Expected patch, minor or major. 
Input was BUMP_TYPE=" bump; \ 41 | exit 1; \ 42 | } \ 43 | }') 44 | 45 | # Targets for Release Workflow/Automation 46 | .PHONY: update-version-gh release-pypi docs 47 | 48 | update-version-gh: pull-main bump-version update-vars-version create-branch ## Update version in GitHub: pull main, bump version, create and push the new branch 49 | 50 | release-pypi: clean-dist pull-main build upload-pypi docs ## Release to PyPI: pull main, build and upload to PyPI 51 | 52 | pull-main: # Pull main branch 53 | # stash changes 54 | @git stash 55 | # switch to main branch 56 | @git checkout main 57 | # fetch latest changes 58 | @git fetch origin main 59 | # get a clean copy of main branch 60 | @git reset --hard origin/main 61 | # clean 62 | @git clean -fdX 63 | 64 | bump-version: # Bump version (default: patch, options: patch, minor, major) 65 | @echo "Bumping $(BUMP_TYPE) version from $(CURRENT_VERSION) to $(NEW_VERSION)" 66 | @echo "Replaces $(CURRENT_VERSION) to $(NEW_VERSION) in $(PYPROJECT_TOML)" 67 | @echo "Replaces $(CURRENT_VERSION) to $(NEW_VERSION) in $(INIT_FILE)" 68 | @echo "Current directory: $(shell pwd)" 69 | # Check if current version was found 70 | @if [ -z "$(CURRENT_VERSION)" ]; then \ 71 | echo "Error: Could not find current version in $(PYPROJECT_TOML)"; \ 72 | exit 1; \ 73 | fi 74 | # Replace the version in pyproject.toml 75 | @if [[ "$(shell uname -s)" == "Darwin" ]]; then \ 76 | sed -i '' 's/version = "$(CURRENT_VERSION)"/version = "$(NEW_VERSION)"/g' $(PYPROJECT_TOML); \ 77 | sed -i '' 's/__version__ = "$(CURRENT_VERSION)"/__version__ = "$(NEW_VERSION)"/g' $(INIT_FILE); \ 78 | else \ 79 | sed -i 's/version = "$(CURRENT_VERSION)"/version = "$(NEW_VERSION)"/g' $(PYPROJECT_TOML); \ 80 | sed -i 's/__version__ = "$(CURRENT_VERSION)"/__version__ = "$(NEW_VERSION)"/g' $(INIT_FILE); \ 81 | fi 82 | 83 | update-vars-version: # Update the required variables after bump 84 | $(eval VERSION := $(shell python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])")) 85 | $(eval BRANCH := verbump_$(shell echo $(VERSION) | tr '.' '_')) 86 | $(eval TAG := $(VERSION)) 87 | @echo "Updated VERSION to $(VERSION), BRANCH to $(BRANCH), TAG to $(TAG)" 88 | 89 | create-branch: # Create verbump_{new_ver} branch 90 | @git checkout -b $(BRANCH) 91 | @echo "Created branch $(BRANCH)" 92 | # commit the version bump 93 | @git add $(INIT_FILE) 94 | @git add $(PYPROJECT_TOML) 95 | @git commit -m "Version Bump to $(VERSION)" 96 | @echo "Committed version bump to $(VERSION)" 97 | @git push --set-upstream origin $(BRANCH) 98 | @echo "Pushed branch $(BRANCH) to origin" 99 | 100 | clean-dist: # Remove "volatile" directory dist 101 | @rm -rf dist 102 | @echo "Cleaned up dist directory" 103 | 104 | build: # Build the project and create the dist directory if it doesn't exist 105 | @mkdir -p dist 106 | @uv build 107 | @echo "Built the project" 108 | @twine check --strict dist/* 109 | @echo "Project is checked" 110 | 111 | confirm-upload: # Confirm before the irreversible zone 112 | @echo "Are you sure you want to upload to PyPI? 
(yes/no)" 113 | @read ans && [ $${ans:-no} = yes ] 114 | 115 | upload-pypi: confirm-upload # Upload to PyPI (ensure the token is present in .pypirc file before running upload) 116 | @twine upload dist/*$(VERSION)* --verbose 117 | @echo "Uploaded version $(VERSION) to PyPI" 118 | 119 | docs: ## Update docs site 120 | @mkdocs gh-deploy 121 | @echo "Deployed docs" 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Synthetic Data Engine 💎 2 | 3 | ![GitHub Release](https://img.shields.io/github/v/release/mostly-ai/mostlyai-engine) 4 | [![Documentation](https://img.shields.io/badge/docs-latest-green)](https://mostly-ai.github.io/mostlyai-engine/) 5 | [![stats](https://pepy.tech/badge/mostlyai-engine)](https://pypi.org/project/mostlyai-engine/) 6 | ![license](https://img.shields.io/github/license/mostly-ai/mostlyai-engine) 7 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mostlyai-engine) 8 | 9 | [Documentation](https://mostly-ai.github.io/mostlyai-engine/) | [Technical Paper](https://arxiv.org/abs/2501.12012) | [Free Cloud Service](https://app.mostly.ai/) 10 | 11 | Create high-fidelity privacy-safe synthetic data: 12 | 13 | 1. prepare, analyze, and encode original data 14 | 2. train a generative model on the encoded data 15 | 3. generate synthetic data samples to your needs: 16 | * up-sample / down-sample 17 | * conditionally generate 18 | * rebalance categories 19 | * impute missings 20 | * incorporate fairness 21 | * adjust sampling temperature 22 | 23 | ...all within your safe compute environment, all with a few lines of Python code 💥. 24 | 25 | Note: This library is the underlying model engine of the [Synthetic Data SDK](https://github.com/mostly-ai/mostlyai). Please refer to the latter, for an easy-to-use, higher-level software toolkit. 
26 | 27 | 28 | ## Installation 29 | 30 | The latest release of `mostlyai-engine` can be installed via pip: 31 | 32 | ```bash 33 | pip install -U mostlyai-engine 34 | ``` 35 | 36 | or alternatively for a GPU setup (needed for LLM finetuning and inference): 37 | ```bash 38 | pip install -U 'mostlyai-engine[gpu]' 39 | ``` 40 | 41 | On Linux, one can explicitly install the CPU-only variant of torch together with `mostlyai-engine`: 42 | 43 | ```bash 44 | pip install -U torch==2.6.0+cpu torchvision==0.21.0+cpu mostlyai-engine --extra-index-url https://download.pytorch.org/whl/cpu 45 | ``` 46 | 47 | ## Quick start 48 | 49 | ### Tabular Model: flat data, without context 50 | 51 | ```python 52 | from pathlib import Path 53 | import pandas as pd 54 | from mostlyai import engine 55 | 56 | # set up workspace and default logging 57 | ws = Path("ws-tabular-flat") 58 | engine.init_logging() 59 | 60 | # load original data 61 | url = "https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/census" 62 | trn_df = pd.read_csv(f"{url}/census.csv.gz") 63 | 64 | # execute the engine steps 65 | engine.split( # split data as PQT files for `trn` + `val` to `{ws}/OriginalData/tgt-data` 66 | workspace_dir=ws, 67 | tgt_data=trn_df, 68 | model_type="TABULAR", 69 | ) 70 | engine.analyze(workspace_dir=ws) # generate column-level statistics to `{ws}/ModelData/tgt-stats/stats.json` 71 | engine.encode(workspace_dir=ws) # encode training data to `{ws}/OriginalData/encoded-data` 72 | engine.train( # train model and store to `{ws}/ModelStore/model-data` 73 | workspace_dir=ws, 74 | max_training_time=1, # limit TRAIN to 1 minute for demo purposes 75 | ) 76 | engine.generate(workspace_dir=ws) # use model to generate synthetic samples to `{ws}/SyntheticData` 77 | pd.read_parquet(ws / "SyntheticData") # load synthetic data 78 | ``` 79 | 80 | ### Tabular Model: sequential data, with context 81 | 82 | ```python 83 | from pathlib import Path 84 | import pandas as pd 85 | from mostlyai import engine 86 | 87 | engine.init_logging() 88 | 89 | # set up workspace and default logging 90 | ws = Path("ws-tabular-sequential") 91 | engine.init_logging() 92 | 93 | # load original data 94 | url = "https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/baseball" 95 | trn_ctx_df = pd.read_csv(f"{url}/players.csv.gz") # context data 96 | trn_tgt_df = pd.read_csv(f"{url}/batting.csv.gz") # target data 97 | 98 | # execute the engine steps 99 | engine.split( # split data as PQT files for `trn` + `val` to `{ws}/OriginalData/(tgt|ctx)-data` 100 | workspace_dir=ws, 101 | tgt_data=trn_tgt_df, 102 | ctx_data=trn_ctx_df, 103 | tgt_context_key="players_id", 104 | ctx_primary_key="id", 105 | model_type="TABULAR", 106 | ) 107 | engine.analyze(workspace_dir=ws) # generate column-level statistics to `{ws}/ModelStore/(tgt|ctx)-data/stats.json` 108 | engine.encode(workspace_dir=ws) # encode training data to `{ws}/OriginalData/encoded-data` 109 | engine.train( # train model and store to `{ws}/ModelStore/model-data` 110 | workspace_dir=ws, 111 | max_training_time=1, # limit TRAIN to 1 minute for demo purposes 112 | ) 113 | engine.generate(workspace_dir=ws) # use model to generate synthetic samples to `{ws}/SyntheticData` 114 | pd.read_parquet(ws / "SyntheticData") # load synthetic data 115 | ``` 116 | 117 | ### Language Model: flat data, without context 118 | 119 | ```python 120 | from pathlib import Path 121 | import pandas as pd 122 | from mostlyai import engine 123 | 124 | # init workspace and logging 125 | ws = Path("ws-language-flat") 126 | 
engine.init_logging() 127 | 128 | # load original data 129 | trn_df = pd.read_parquet("https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/headlines/headlines.parquet") 130 | trn_df = trn_df.sample(n=10_000, random_state=42) 131 | 132 | # execute the engine steps 133 | engine.split( # split data as PQT files for `trn` + `val` to `{ws}/OriginalData/tgt-data` 134 | workspace_dir=ws, 135 | tgt_data=trn_df, 136 | tgt_encoding_types={ 137 | 'category': 'LANGUAGE_CATEGORICAL', 138 | 'date': 'LANGUAGE_DATETIME', 139 | 'headline': 'LANGUAGE_TEXT', 140 | } 141 | ) 142 | engine.analyze(workspace_dir=ws) # generate column-level statistics to `{ws}/ModelStore/tgt-stats/stats.json` 143 | engine.encode(workspace_dir=ws) # encode training data to `{ws}/OriginalData/encoded-data` 144 | engine.train( # train model and store to `{ws}/ModelStore/model-data` 145 | workspace_dir=ws, 146 | max_training_time=2, # limit TRAIN to 2 minute for demo purposes 147 | model="MOSTLY_AI/LSTMFromScratch-3m", # use a light-weight LSTM model, trained from scratch (GPU recommended) 148 | # model="microsoft/phi-1.5", # alternatively use a pre-trained HF-hosted LLM model (GPU required) 149 | ) 150 | engine.generate( # use model to generate synthetic samples to `{ws}/SyntheticData` 151 | workspace_dir=ws, 152 | sample_size=10, 153 | ) 154 | pd.read_parquet(ws / "SyntheticData") # load synthetic data 155 | ``` 156 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | 6 | ## Engine Reference 7 | 8 | ::: mostlyai.engine 9 | options: 10 | members: 11 | - split 12 | - analyze 13 | - encode 14 | - train 15 | - generate 16 | 17 | ## Schema Reference 18 | 19 | ::: mostlyai.engine.domain 20 | options: 21 | filters: 22 | - "!^CustomBaseModel" 23 | -------------------------------------------------------------------------------- /docs/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/docs/favicon.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | 6 | --8<-- "README.md" 7 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/docs/logo.png -------------------------------------------------------------------------------- /examples/flat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "tags": [] 7 | }, 8 | "source": [ 9 | "# Tabular Model: flat data, without context" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/flat.ipynb)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from pathlib import Path\n", 26 | 
"import pandas as pd\n", 27 | "from mostlyai import engine\n", 28 | "\n", 29 | "# init workspace and logging\n", 30 | "ws = Path(\"ws-tabular-flat\")\n", 31 | "engine.init_logging()\n", 32 | "\n", 33 | "# load original data\n", 34 | "url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/census\"\n", 35 | "trn_df = pd.read_csv(f\"{url}/census.csv.gz\")\n", 36 | "\n", 37 | "# execute the engine steps\n", 38 | "engine.split( # split data as PQT files for `trn` + `val` to `{ws}/OriginalData/tgt-data`\n", 39 | " workspace_dir=ws,\n", 40 | " tgt_data=trn_df,\n", 41 | " model_type=\"TABULAR\",\n", 42 | ")\n", 43 | "engine.analyze(workspace_dir=ws) # generate column-level statistics to `{ws}/ModelData/tgt-stats/stats.json`\n", 44 | "engine.encode(workspace_dir=ws) # encode training data to `{ws}/OriginalData/encoded-data`\n", 45 | "engine.train(workspace_dir=ws) # train model and store to `{ws}/ModelData/model-data`\n", 46 | "engine.generate(workspace_dir=ws) # use model to generate synthetic samples to `{ws}/SyntheticData`" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# load synthetic data\n", 56 | "syn_df = pd.read_parquet(ws / \"SyntheticData\")\n", 57 | "syn_df.head(5)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "### QUALITY ASSURANCE" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": { 70 | "tags": [] 71 | }, 72 | "source": [ 73 | "#### univariate `age`" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "print(\"Original Age: \" + \", \".join([f'q{q*100:.0f}: {trn_df[\"age\"].quantile(q):.0f}' for q in [.1, .25, .5, .75, .9]]))\n", 83 | "print(\"Synthetic Age: \" + \", \".join([f'q{q*100:.0f}: {syn_df[\"age\"].quantile(q):.0f}' for q in [.1, .25, .5, .75, .9]]))\n", 84 | "#syn_df[\"age\"].quantile(np.linspace(0, 1, 11))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "#### bivariate `sex` ~ `income`: income gap" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "trn_gap = (trn_df[trn_df[\"sex\"] == \"Male\"][\"income\"] == \">50K\").mean() - (trn_df[trn_df[\"sex\"] == \"Female\"][\"income\"] == \">50K\").mean()\n", 101 | "syn_gap = (syn_df[syn_df[\"sex\"] == \"Male\"][\"income\"] == \">50K\").mean() - (syn_df[syn_df[\"sex\"] == \"Female\"][\"income\"] == \">50K\").mean()\n", 102 | "print(f\"Income Gap {trn_gap:.1%} vs. 
{syn_gap:.1%}\")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "#### check consistency between `education` and `education.num`" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "pd.crosstab(syn_df[\"education\"], syn_df[\"education_num\"])" 119 | ] 120 | } 121 | ], 122 | "metadata": { 123 | "kernelspec": { 124 | "display_name": "Python 3 (ipykernel)", 125 | "language": "python", 126 | "name": "python3" 127 | }, 128 | "language_info": { 129 | "codemirror_mode": { 130 | "name": "ipython", 131 | "version": 3 132 | }, 133 | "file_extension": ".py", 134 | "mimetype": "text/x-python", 135 | "name": "python", 136 | "nbconvert_exporter": "python", 137 | "pygments_lexer": "ipython3", 138 | "version": "3.12.3" 139 | }, 140 | "toc": { 141 | "base_numbering": 1, 142 | "nav_menu": {}, 143 | "number_sections": false, 144 | "sideBar": true, 145 | "skip_h1_title": false, 146 | "title_cell": "Table of Contents", 147 | "title_sidebar": "Contents", 148 | "toc_cell": false, 149 | "toc_position": {}, 150 | "toc_section_display": true, 151 | "toc_window_display": false 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 4 156 | } 157 | -------------------------------------------------------------------------------- /examples/language.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "tags": [] 7 | }, 8 | "source": [ 9 | "# Language Model: flat data, without context" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/language.ipynb)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from pathlib import Path\n", 26 | "import pandas as pd\n", 27 | "from mostlyai import engine\n", 28 | "\n", 29 | "# init workspace and logging\n", 30 | "ws = Path(\"ws-language-flat\")\n", 31 | "engine.init_logging()\n", 32 | "\n", 33 | "# load original data\n", 34 | "url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/arxiv\"\n", 35 | "trn_df = pd.read_parquet(f\"{url}/synthetic-data-papers.parquet\")[['category', 'title']]\n", 36 | "\n", 37 | "# execute the engine steps\n", 38 | "engine.split( # split data as PQT files for `trn` + `val` to `{ws}/OriginalData/tgt-data`\n", 39 | " workspace_dir=ws,\n", 40 | " tgt_data=trn_df,\n", 41 | " model_type=\"LANGUAGE\",\n", 42 | ")\n", 43 | "engine.analyze(workspace_dir=ws) # generate column-level statistics to `{ws}/ModelStore/tgt-stats/stats.json`\n", 44 | "engine.encode(workspace_dir=ws) # encode training data to `{ws}/OriginalData/encoded-data`\n", 45 | "engine.train( # train model and store to `{ws}/ModelStore/model-data`\n", 46 | " workspace_dir=ws,\n", 47 | " model=\"MOSTLY_AI/LSTMFromScratch-3m\", # use a light-weight LSTM model, trained from scratch (GPU recommended)\n", 48 | " # model=\"microsoft/phi-1.5\", # or alternatively use a HF-hosted LLM model (GPU required)\n", 49 | " max_training_time=10, # limit TRAIN to 10 minute for demo purposes\n", 50 | ")\n", 51 | "engine.generate( # use model to generate synthetic samples to `{ws}/SyntheticData`\n", 52 | " workspace_dir=ws, \n", 53 | " 
sample_size=100,\n", 54 | ")" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "syn_tgt_df = pd.read_parquet(ws / \"SyntheticData\") # load synthetic data\n", 64 | "syn_tgt_df.head(5)" 65 | ] 66 | } 67 | ], 68 | "metadata": { 69 | "kernelspec": { 70 | "display_name": "Python 3 (ipykernel)", 71 | "language": "python", 72 | "name": "python3" 73 | }, 74 | "language_info": { 75 | "codemirror_mode": { 76 | "name": "ipython", 77 | "version": 3 78 | }, 79 | "file_extension": ".py", 80 | "mimetype": "text/x-python", 81 | "name": "python", 82 | "nbconvert_exporter": "python", 83 | "pygments_lexer": "ipython3", 84 | "version": "3.12.3" 85 | }, 86 | "toc": { 87 | "base_numbering": 1, 88 | "nav_menu": {}, 89 | "number_sections": false, 90 | "sideBar": true, 91 | "skip_h1_title": false, 92 | "title_cell": "Table of Contents", 93 | "title_sidebar": "Contents", 94 | "toc_cell": false, 95 | "toc_position": {}, 96 | "toc_section_display": true, 97 | "toc_window_display": false 98 | } 99 | }, 100 | "nbformat": 4, 101 | "nbformat_minor": 4 102 | } 103 | -------------------------------------------------------------------------------- /examples/sequential.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "tags": [] 7 | }, 8 | "source": [ 9 | "# Tabular Model: sequential data, with context" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/sequential.ipynb)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "tags": [] 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "from pathlib import Path\n", 28 | "import pandas as pd\n", 29 | "import numpy as np\n", 30 | "from mostlyai import engine\n", 31 | "\n", 32 | "# init workspace and logging\n", 33 | "ws = Path(\"ws-tabular-sequential\")\n", 34 | "engine.init_logging()\n", 35 | "\n", 36 | "# load original data\n", 37 | "url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/baseball\"\n", 38 | "trn_ctx_df = pd.read_csv(f\"{url}/players.csv.gz\") # context data\n", 39 | "trn_tgt_df = pd.read_csv(f\"{url}/batting.csv.gz\") # target data\n", 40 | "\n", 41 | "# execute the engine steps\n", 42 | "engine.split( # split data as PQT files for `trn` + `val` to `{ws}/OriginalData/(tgt|ctx)-data`\n", 43 | " workspace_dir=ws,\n", 44 | " tgt_data=trn_tgt_df,\n", 45 | " ctx_data=trn_ctx_df,\n", 46 | " tgt_context_key=\"players_id\",\n", 47 | " ctx_primary_key=\"id\",\n", 48 | " model_type=\"TABULAR\",\n", 49 | ")\n", 50 | "engine.analyze(workspace_dir=ws) # generate column-level statistics to `{ws}/ModelStore/(tgt|ctx)-data/stats.json`\n", 51 | "engine.encode(workspace_dir=ws) # encode training data to `{ws}/OriginalData/encoded-data`\n", 52 | "engine.train( # train model and store to `{ws}/ModelStore/model-data`\n", 53 | " workspace_dir=ws,\n", 54 | " max_training_time=2, # limit TRAIN to 2 minute for demo purposes\n", 55 | ")\n", 56 | "engine.generate(workspace_dir=ws) # use model to generate synthetic samples to `{ws}/SyntheticData`" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# load synthetic data\n", 
66 | "syn_tgt_df = pd.read_parquet(ws / \"SyntheticData\")\n", 67 | "syn_tgt_df.head(5)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "### QUALITY ASSURANCE" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "#### sequence lengths" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": { 88 | "tags": [] 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "trn_seq_lens = trn_tgt_df.groupby(\"players_id\").size()\n", 93 | "syn_seq_lens = syn_tgt_df.groupby(\"players_id\").size()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "print(\"tgt: \", np.quantile(trn_seq_lens, np.arange(0, 1.1, 0.1), method=\"inverted_cdf\"))\n", 103 | "print(\"syn: \", np.quantile(syn_seq_lens, np.arange(0, 1.1, 0.1), method=\"inverted_cdf\"))" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": { 109 | "tags": [] 110 | }, 111 | "source": [ 112 | "#### coherence" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "syn_avg_teams_per_player = syn_tgt_df.groupby(\"players_id\")[\"team\"].nunique().mean().round(1)\n", 122 | "trn_avg_teams_per_player = trn_tgt_df.groupby(\"players_id\")[\"team\"].nunique().mean().round(1)\n", 123 | "syn_avg_teams_per_player, trn_avg_teams_per_player" 124 | ] 125 | } 126 | ], 127 | "metadata": { 128 | "kernelspec": { 129 | "display_name": "Python 3 (ipykernel)", 130 | "language": "python", 131 | "name": "python3" 132 | }, 133 | "language_info": { 134 | "codemirror_mode": { 135 | "name": "ipython", 136 | "version": 3 137 | }, 138 | "file_extension": ".py", 139 | "mimetype": "text/x-python", 140 | "name": "python", 141 | "nbconvert_exporter": "python", 142 | "pygments_lexer": "ipython3", 143 | "version": "3.12.3" 144 | }, 145 | "toc": { 146 | "base_numbering": 1, 147 | "nav_menu": {}, 148 | "number_sections": false, 149 | "sideBar": true, 150 | "skip_h1_title": false, 151 | "title_cell": "Table of Contents", 152 | "title_sidebar": "Contents", 153 | "toc_cell": false, 154 | "toc_position": {}, 155 | "toc_section_display": true, 156 | "toc_window_display": false 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 4 161 | } 162 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: "mostlyai-engine" 2 | site_url: "https://mostly-ai.github.io/mostlyai-engine/" 3 | repo_url: "https://github.com/mostly-ai/mostlyai-engine" 4 | repo_name: "mostly-ai/mostlyai-engine" 5 | 6 | theme: 7 | name: material 8 | logo: logo.png 9 | favicon: favicon.png 10 | font: 11 | text: Lato 12 | features: 13 | - navigation.top 14 | - navigation.tracking 15 | - navigation.tabs 16 | - navigation.tabs.sticky 17 | - content.code.select 18 | - content.code.copy 19 | - navigation.footer 20 | 21 | palette: 22 | - scheme: default 23 | toggle: 24 | icon: material/brightness-7 25 | name: Switch to dark mode 26 | - scheme: slate 27 | toggle: 28 | icon: material/brightness-2 29 | name: Switch to light mode 30 | 31 | nav: 32 | - Getting started: index.md 33 | - API Reference: api.md 34 | 35 | plugins: 36 | - search 37 | - mkdocstrings: 38 | handlers: 39 | python: 40 | options: 41 | heading_level: 3 42 | show_root_toc_entry: false 43 | 
show_root_heading: false 44 | show_object_full_path: true 45 | show_bases: false 46 | show_docstring: true 47 | show_source: false 48 | show_signature: true 49 | separate_signature: true 50 | show_docstring_examples: true 51 | docstring_section_style: table 52 | extensions: 53 | - griffe_fieldz 54 | docstring_style: google 55 | 56 | markdown_extensions: 57 | - pymdownx.highlight: 58 | anchor_linenums: true 59 | line_spans: __span 60 | pygments_lang_class: true 61 | - pymdownx.inlinehilite 62 | - pymdownx.snippets 63 | - pymdownx.superfences 64 | - toc: 65 | permalink: true 66 | -------------------------------------------------------------------------------- /mostlyai/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import warnings 15 | 16 | from mostlyai.engine.analysis import analyze 17 | from mostlyai.engine.encoding import encode 18 | from mostlyai.engine.generation import generate 19 | from mostlyai.engine.logging import init_logging 20 | from mostlyai.engine.random_state import set_random_state 21 | from mostlyai.engine.splitting import split 22 | from mostlyai.engine.training import train 23 | 24 | __all__ = ["split", "analyze", "encode", "train", "generate", "init_logging", "set_random_state"] 25 | __version__ = "1.4.3" 26 | 27 | # suppress specific warning related to os.fork() in multi-threaded processes 28 | warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*multi-threaded.*fork.*") 29 | -------------------------------------------------------------------------------- /mostlyai/engine/_dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
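# The helpers in this module answer "is this Series a string / integer / float / date /
# timestamp / boolean column?" consistently across storage backends: pyarrow-backed
# columns (pd.ArrowDtype) are inspected via pyarrow.types, while numpy-backed columns
# fall back to the pandas.api.types checks (plain dates are only detected for
# pyarrow-backed columns).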
14 | 15 | import pandas as pd 16 | import pyarrow as pa 17 | 18 | 19 | def is_string_dtype(x: pd.Series) -> bool: 20 | if isinstance(x.dtype, pd.ArrowDtype): 21 | return pa.types.is_string(x.dtype.pyarrow_dtype) 22 | else: 23 | return pd.api.types.is_string_dtype(x) 24 | 25 | 26 | def is_integer_dtype(x: pd.Series) -> bool: 27 | if isinstance(x.dtype, pd.ArrowDtype): 28 | return pa.types.is_integer(x.dtype.pyarrow_dtype) 29 | else: 30 | return pd.api.types.is_integer_dtype(x) 31 | 32 | 33 | def is_float_dtype(x: pd.Series) -> bool: 34 | if isinstance(x.dtype, pd.ArrowDtype): 35 | return pa.types.is_floating(x.dtype.pyarrow_dtype) 36 | else: 37 | return pd.api.types.is_float_dtype(x) 38 | 39 | 40 | def is_date_dtype(x: pd.Series) -> bool: 41 | if isinstance(x.dtype, pd.ArrowDtype): 42 | return pa.types.is_date(x.dtype.pyarrow_dtype) 43 | else: 44 | return False 45 | 46 | 47 | def is_timestamp_dtype(x: pd.Series) -> bool: 48 | if isinstance(x.dtype, pd.ArrowDtype): 49 | return pa.types.is_timestamp(x.dtype.pyarrow_dtype) 50 | else: 51 | return pd.api.types.is_datetime64_any_dtype(x) 52 | 53 | 54 | def is_boolean_dtype(x: pd.Series) -> bool: 55 | if isinstance(x.dtype, pd.ArrowDtype): 56 | return pa.types.is_boolean(x.dtype.pyarrow_dtype) 57 | else: 58 | return pd.api.types.is_bool_dtype(x) 59 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/categorical.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Categorical encoding for language models. 17 | """ 18 | 19 | import pandas as pd 20 | 21 | from mostlyai.engine._common import STRING, safe_convert_string 22 | from mostlyai.engine._encoding_types.tabular.categorical import analyze_categorical, analyze_reduce_categorical 23 | 24 | CATEGORICAL_UNKNOWN_TOKEN = "_RARE_" 25 | 26 | 27 | def analyze_language_categorical(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 28 | return analyze_categorical(values, root_keys, _, safe_escape=False) 29 | 30 | 31 | def analyze_reduce_language_categorical( 32 | stats_list: list[dict], 33 | value_protection: bool = True, 34 | value_protection_epsilon: float | None = None, 35 | ) -> dict: 36 | stats = analyze_reduce_categorical(stats_list, value_protection, value_protection_epsilon) 37 | stats["categories"] = list(stats["codes"].keys()) 38 | if any([j["has_nan"] for j in stats_list]): 39 | # when has_nan, tabular stats are like [CATEGORICAL_UNKNOWN_TOKEN, CATEGORICAL_NULL_TOKEN, ...] 40 | # and we need to replace CATEGORICAL_NULL_TOKEN with None for language 41 | stats["categories"][1] = None 42 | # drop tabular stats 43 | stats.pop("codes") 44 | stats.pop("cardinalities") 45 | return stats 46 | 47 | 48 | def encode_language_categorical(values: pd.Series, stats: dict) -> pd.Series: 49 | values = safe_convert_string(values) 50 | values = values.copy() 51 | known_categories = stats["categories"] 52 | mask = ~values.isin(known_categories) 53 | if None in known_categories: 54 | mask &= ~pd.isna(values) 55 | values[mask] = CATEGORICAL_UNKNOWN_TOKEN 56 | return values 57 | 58 | 59 | def decode_language_categorical(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: 60 | x = x.astype(STRING) 61 | allowed_categories = col_stats.get("categories", []) 62 | return x.where(x.isin(allowed_categories), other=None) 63 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/datetime.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
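# Language-model datetime encoding, in four steps:
# - analyze_language_datetime: per-partition stats (top/bottom values per root key,
#   a has_nan flag, and a log-scale histogram for DP bounds)
# - analyze_reduce_language_datetime: merges partition stats and applies value protection,
#   via DP approximate bounds when an epsilon is given, otherwise via a stochastic
#   rare-value threshold
# - encode_language_datetime: clips raw values to the protected [min, max] range
# - decode_language_datetime: parses generated strings back to datetimes, clamping
#   impossible days to the end of the month and clipping to the protected range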
14 | import calendar 15 | 16 | import numpy as np 17 | import pandas as pd 18 | 19 | from mostlyai.engine._common import ( 20 | ANALYZE_MIN_MAX_TOP_N, 21 | ANALYZE_REDUCE_MIN_MAX_N, 22 | compute_log_histogram, 23 | dp_approx_bounds, 24 | get_stochastic_rare_threshold, 25 | safe_convert_datetime, 26 | ) 27 | 28 | 29 | def analyze_language_datetime(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 30 | values = safe_convert_datetime(values) 31 | # compute log histogram for DP bounds 32 | log_hist = compute_log_histogram(values.dropna().astype("int64")) 33 | 34 | df = pd.concat([root_keys, values], axis=1) 35 | # determine lowest/highest values by root ID, and return Top 10 36 | min_dates = df.groupby(root_keys.name)[values.name].min().dropna() 37 | min_n = min_dates.sort_values(ascending=True).head(ANALYZE_MIN_MAX_TOP_N).astype(str).tolist() 38 | max_dates = df.groupby(root_keys.name)[values.name].max().dropna() 39 | max_n = max_dates.sort_values(ascending=False).head(ANALYZE_MIN_MAX_TOP_N).astype(str).tolist() 40 | # determine if there are any NaN values 41 | has_nan = bool(values.isna().any()) 42 | # return stats 43 | stats = { 44 | "has_nan": has_nan, 45 | "min_n": min_n, 46 | "max_n": max_n, 47 | "log_hist": log_hist, 48 | } 49 | return stats 50 | 51 | 52 | def analyze_reduce_language_datetime( 53 | stats_list: list[dict], 54 | value_protection: bool = True, 55 | value_protection_epsilon: float | None = None, 56 | ) -> dict: 57 | # check if there are missing values 58 | has_nan = any([j["has_nan"] for j in stats_list]) 59 | reduced_min_n = sorted([v for min_n in [j["min_n"] for j in stats_list] for v in min_n], reverse=False) 60 | reduced_max_n = sorted([v for max_n in [j["max_n"] for j in stats_list] for v in max_n], reverse=True) 61 | if value_protection: 62 | if len(reduced_min_n) < ANALYZE_REDUCE_MIN_MAX_N or len(reduced_max_n) < ANALYZE_REDUCE_MIN_MAX_N: 63 | # protect all values if there are less than ANALYZE_REDUCE_MIN_MAX_N values 64 | reduced_min = None 65 | reduced_max = None 66 | else: 67 | if value_protection_epsilon is not None: 68 | if any(len(v) > 10 for v in reduced_min_n + reduced_max_n): 69 | dt_format = "%Y-%m-%d %H:%M:%S" 70 | else: 71 | dt_format = "%Y-%m-%d" 72 | # Sum up log histograms bin-wise from all partitions 73 | log_hist = [sum(bin) for bin in zip(*[j["log_hist"] for j in stats_list])] 74 | reduced_min, reduced_max = dp_approx_bounds(log_hist, value_protection_epsilon) 75 | if reduced_min is not None and reduced_max is not None: 76 | # convert back to the original string format 77 | reduced_min = pd.to_datetime(int(reduced_min), unit="us").strftime(dt_format) 78 | reduced_max = pd.to_datetime(int(reduced_max), unit="us").strftime(dt_format) 79 | else: 80 | reduced_min = str(reduced_min_n[get_stochastic_rare_threshold(min_threshold=5)]) 81 | reduced_max = str(reduced_max_n[get_stochastic_rare_threshold(min_threshold=5)]) 82 | else: 83 | reduced_min = str(reduced_min_n[0]) if len(reduced_min_n) > 0 else None 84 | reduced_max = str(reduced_max_n[0]) if len(reduced_max_n) > 0 else None 85 | stats = { 86 | "has_nan": has_nan, 87 | "min": reduced_min, 88 | "max": reduced_max, 89 | } 90 | return stats 91 | 92 | 93 | def _clip_datetime(values: pd.Series, stats: dict) -> pd.Series: 94 | if stats["min"] is not None: 95 | reduced_min = np.datetime64(stats["min"], "ns") 96 | values.loc[values < reduced_min] = reduced_min 97 | if stats["max"] is not None: 98 | reduced_max = np.datetime64(stats["max"], "ns") 99 | values.loc[values > 
reduced_max] = reduced_max 100 | return values 101 | 102 | 103 | def encode_language_datetime(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.Series: 104 | # convert 105 | values = safe_convert_datetime(values) 106 | values = values.copy() 107 | # reset index, as `values.mask` can throw errors for misaligned indices 108 | values.reset_index(drop=True, inplace=True) 109 | # replace extreme values with min/max 110 | values = _clip_datetime(values, stats) 111 | return values 112 | 113 | 114 | def decode_language_datetime(x: pd.Series, stats: dict[str, str]) -> pd.Series: 115 | x = x.where(~x.isin(["", "_INVALID_"]), np.nan) 116 | 117 | valid_mask = ( 118 | x.str.len().ge(10) 119 | & x.str.slice(0, 4).str.isdigit() 120 | & x.str.slice(5, 7).str.isdigit() 121 | & x.str.slice(8, 10).str.isdigit() 122 | ) 123 | if valid_mask.sum() > 0: # expected "YYYY-MM-DD" prefix 124 | # handle the date portion, ensuring validity 125 | years = x[valid_mask].str.slice(0, 4).astype(int) 126 | months = x[valid_mask].str.slice(5, 7).astype(int) 127 | days = x[valid_mask].str.slice(8, 10).astype(int) 128 | 129 | # clamp days according to maximum possible day of the month of a given year 130 | last_days = np.array([calendar.monthrange(y, m)[1] for y, m in zip(years, months)]) 131 | clamped_days = np.minimum(days, last_days) 132 | 133 | # rebuild the date portion 134 | new_date = ( 135 | years.astype(str).str.zfill(4) 136 | + "-" 137 | + months.astype(str).str.zfill(2) 138 | + "-" 139 | + pd.Series(clamped_days, index=years.index).astype(str).str.zfill(2) 140 | ) 141 | 142 | # handle the time portion, ensuring validity 143 | remainder = x[valid_mask].str.slice(10) 144 | 145 | time_regex = r"^[ T]?(\d{2}:\d{2}:\d{2}(?:\.\d+)?)" 146 | valid_time = remainder.str.extract(time_regex, expand=False) 147 | valid_time = valid_time.fillna("00:00:00") 148 | valid_time = " " + valid_time 149 | 150 | new_date = new_date + valid_time 151 | x.loc[valid_mask] = new_date 152 | 153 | x = pd.to_datetime(x, errors="coerce") 154 | x = _clip_datetime(x, stats) 155 | return x.astype("datetime64[ns]") 156 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/numeric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import numpy as np 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import ( 18 | ANALYZE_MIN_MAX_TOP_N, 19 | ANALYZE_REDUCE_MIN_MAX_N, 20 | compute_log_histogram, 21 | dp_approx_bounds, 22 | get_stochastic_rare_threshold, 23 | safe_convert_numeric, 24 | ) 25 | from mostlyai.engine._encoding_types.tabular.numeric import _type_safe_numeric_series 26 | from mostlyai.engine.domain import ModelEncodingType 27 | 28 | 29 | def analyze_language_numeric(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 30 | values = safe_convert_numeric(values) 31 | # compute log histogram for DP bounds 32 | log_hist = compute_log_histogram(values.dropna()) 33 | 34 | # determine lowest/highest values by root ID, and return top ANALYZE_MIN_MAX_TOP_N 35 | df = pd.concat([root_keys, values], axis=1) 36 | min_values = df.groupby(root_keys.name)[values.name].min().dropna() 37 | min_n = min_values.sort_values(ascending=True).head(ANALYZE_MIN_MAX_TOP_N).tolist() 38 | max_values = df.groupby(root_keys.name)[values.name].max().dropna() 39 | max_n = max_values.sort_values(ascending=False).head(ANALYZE_MIN_MAX_TOP_N).tolist() 40 | 41 | # determine if there are any NaN values 42 | has_nan = bool(values.isna().any()) 43 | 44 | # determine max scale 45 | def count_scale(num: float) -> int: 46 | # represent number as fixed point string, remove trailing zeros and decimal point 47 | num = format(num, "f").rstrip("0").rstrip(".") 48 | if "." in num: 49 | # in case of decimal, return number of digits after decimal point 50 | return len(num.split(".")[1]) 51 | # in case of integer, return 0 52 | return 0 53 | 54 | max_scale = int(values.apply(count_scale).max()) 55 | 56 | stats = { 57 | "has_nan": has_nan, 58 | "max_scale": max_scale, 59 | "min_n": min_n, 60 | "max_n": max_n, 61 | "log_hist": log_hist, 62 | } 63 | return stats 64 | 65 | 66 | def analyze_reduce_language_numeric( 67 | stats_list: list[dict], 68 | value_protection: bool = True, 69 | value_protection_epsilon: float | None = None, 70 | ) -> dict: 71 | # check for occurrence of NaN values 72 | has_nan = any([j["has_nan"] for j in stats_list]) 73 | 74 | # determine max scale 75 | max_scale = max([j["max_scale"] for j in stats_list]) 76 | 77 | reduced_min_n = sorted([v for min_n in [j["min_n"] for j in stats_list] for v in min_n], reverse=False) 78 | reduced_max_n = sorted([v for max_n in [j["max_n"] for j in stats_list] for v in max_n], reverse=True) 79 | if value_protection: 80 | if len(reduced_min_n) < ANALYZE_REDUCE_MIN_MAX_N or len(reduced_max_n) < ANALYZE_REDUCE_MIN_MAX_N: 81 | # protect all values if there are less than ANALYZE_REDUCE_MIN_MAX_N values 82 | reduced_min = None 83 | reduced_max = None 84 | else: 85 | if value_protection_epsilon is not None: 86 | # Sum up log histograms bin-wise from all partitions 87 | log_hist = [sum(bin) for bin in zip(*[j["log_hist"] for j in stats_list])] 88 | reduced_min, reduced_max = dp_approx_bounds(log_hist, value_protection_epsilon) 89 | if reduced_min is not None and reduced_max is not None and max_scale == 0: 90 | reduced_min = int(reduced_min) 91 | reduced_max = int(reduced_max) 92 | else: 93 | reduced_min = reduced_min_n[get_stochastic_rare_threshold(min_threshold=5)] 94 | reduced_max = reduced_max_n[get_stochastic_rare_threshold(min_threshold=5)] 95 | else: 96 | reduced_min = reduced_min_n[0] if len(reduced_min_n) > 0 else None 97 | reduced_max = reduced_max_n[0] if len(reduced_max_n) > 0 else None 98 | 99 | stats = { 100 | "encoding_type": 
ModelEncodingType.language_numeric.value, 101 | "has_nan": has_nan, 102 | "max_scale": max_scale, 103 | "min": reduced_min, 104 | "max": reduced_max, 105 | } 106 | 107 | return stats 108 | 109 | 110 | def encode_language_numeric(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.DataFrame: 111 | values = safe_convert_numeric(values) 112 | # try to convert to int, if possible 113 | dtype = "Int64" if stats["max_scale"] == 0 else "Float64" 114 | if dtype == "Int64": 115 | values = values.round() 116 | try: 117 | values = values.astype(dtype) 118 | except TypeError: 119 | if dtype == "Int64": # if couldn't safely convert to int, stick to float 120 | dtype = "Float64" 121 | values = values.astype(dtype) 122 | # reset index, as `values.mask` can throw errors for misaligned indices 123 | values.reset_index(drop=True, inplace=True) 124 | if stats["min"] is not None: 125 | reduced_min = _type_safe_numeric_series([stats["min"]], dtype).iloc[0] 126 | values.loc[values < reduced_min] = reduced_min 127 | if stats["max"] is not None: 128 | reduced_max = _type_safe_numeric_series([stats["max"]], dtype).iloc[0] 129 | values.loc[values > reduced_max] = reduced_max 130 | return values 131 | 132 | 133 | def decode_language_numeric(x: pd.Series, stats: dict[str, str]) -> pd.Series: 134 | x = pd.to_numeric(x, errors="coerce") 135 | x = x.round(stats["max_scale"]) 136 | if stats["min"] is not None: 137 | reduced_min = np.dtype(x.dtype).type(stats["min"]) 138 | x.loc[x < reduced_min] = reduced_min 139 | if stats["max"] is not None: 140 | reduced_max = np.dtype(x.dtype).type(stats["max"]) 141 | x.loc[x > reduced_max] = reduced_max 142 | dtype = "Int64" if stats["max_scale"] == 0 else float 143 | return x.astype(dtype) 144 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
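A short sketch of the numeric round-trip implemented in language/numeric.py above: encoding clamps raw values into the (possibly DP-derived) [min, max] bounds, decoding parses generated text back into numbers and clamps again. The stats dict is hand-written for illustration:

import pandas as pd

from mostlyai.engine._encoding_types.language.numeric import (
    decode_language_numeric,
    encode_language_numeric,
)

# hypothetical reduced stats for an integer column bounded to [18, 90]
stats = {"has_nan": False, "max_scale": 0, "min": 18, "max": 90}

encoded = encode_language_numeric(pd.Series([25, 7, 103]), stats)
# out-of-range values are clipped to the protected bounds -> [25, 18, 90]

decoded = decode_language_numeric(pd.Series(["42", "abc", "120"]), stats)
# unparsable text becomes <NA>, valid values are clipped -> [42, <NA>, 90]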
14 | 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import STRING, safe_convert_string 18 | 19 | 20 | def analyze_text(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 21 | # ideally, we should ensure that values are converted to string in a consistent way across analyze/encode/qa steps 22 | values = safe_convert_string(values) 23 | nchars = values.map(str).str.len() 24 | stats = {"nchar_max": int(nchars.max()), "nchar_sum": int(nchars.sum()), "count": len(values)} 25 | return stats 26 | 27 | 28 | def analyze_reduce_text( 29 | stats_list: list[dict], 30 | value_protection: bool = True, 31 | value_protection_epsilon: float | None = None, 32 | ) -> dict: 33 | nchar_max = 0 34 | nchar_sum = 0 35 | count = 0 36 | for stats in stats_list: 37 | nchar_max = max(stats["nchar_max"], nchar_max) 38 | nchar_sum += stats["nchar_sum"] 39 | count += stats["count"] 40 | 41 | stats = { 42 | "nchar_avg": round(nchar_sum / count, 1), 43 | "nchar_max": nchar_max, 44 | } 45 | return stats 46 | 47 | 48 | def decode_text(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: 49 | return x.astype(STRING) 50 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/tabular/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/tabular/categorical.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Categorical encoding maps each categorical value to its own integer code. 
17 | """ 18 | 19 | import pandas as pd 20 | 21 | from mostlyai.engine._common import dp_non_rare, get_stochastic_rare_threshold, safe_convert_string 22 | 23 | CATEGORICAL_UNKNOWN_TOKEN = "_RARE_" 24 | CATEGORICAL_NULL_TOKEN = "<>" 25 | CATEGORICAL_SUB_COL_SUFFIX = "cat" 26 | CATEGORICAL_ESCAPE_CHAR = "\x01" 27 | 28 | 29 | def safe_categorical_escape(values: pd.Series) -> pd.Series: 30 | """Inplace escaping of categorical values""" 31 | reserved_tokens = (CATEGORICAL_UNKNOWN_TOKEN, CATEGORICAL_NULL_TOKEN) 32 | reserved_tokens_replacement_map = {t: CATEGORICAL_ESCAPE_CHAR + t for t in reserved_tokens} 33 | # first, prefix values starting with escape char with another escape char 34 | mask = values.str.startswith(CATEGORICAL_ESCAPE_CHAR, na=False) 35 | values.loc[mask] = values.loc[mask].str.slice_replace(stop=1, repl=CATEGORICAL_ESCAPE_CHAR * 2) 36 | # second, add escape char to all reserved tokens 37 | values = values.replace(reserved_tokens_replacement_map) 38 | return values 39 | 40 | 41 | def safe_categorical_unescape(values: pd.Series) -> pd.Series: 42 | """Inplace un-escaping of categorical values""" 43 | # de-prefix all values starting with escape char by removing just the first one 44 | mask = values.str.startswith(CATEGORICAL_ESCAPE_CHAR, na=False) 45 | values.loc[mask] = values.loc[mask].str[1:] 46 | return values 47 | 48 | 49 | def analyze_categorical( 50 | values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None, *, safe_escape: bool = True 51 | ) -> dict: 52 | # ensure a safe representation of values: 1. string dtype; 2. escape reserved tokens 53 | values = safe_convert_string(values) 54 | if safe_escape: 55 | values = safe_categorical_escape(values) 56 | # count distinct root_keys per categorical value for rare-category protection 57 | df = pd.concat([root_keys, values], axis=1) 58 | cnt_values = df.groupby(values.name)[root_keys.name].nunique().to_dict() 59 | stats = {"has_nan": sum(values.isna()) > 0, "cnt_values": cnt_values} 60 | return stats 61 | 62 | 63 | def analyze_reduce_categorical( 64 | stats_list: list[dict], 65 | value_protection: bool = True, 66 | value_protection_epsilon: float | None = None, 67 | ) -> dict: 68 | # sum up all counts for each categorical value 69 | cnt_values: dict[str, int] = {} 70 | for item in stats_list: 71 | for value, count in item["cnt_values"].items(): 72 | cnt_values[value] = cnt_values.get(value, 0) + count 73 | cnt_values = dict(sorted(cnt_values.items())) 74 | known_categories = list(cnt_values.keys()) 75 | if value_protection: 76 | if value_protection_epsilon is not None: 77 | categories, _ = dp_non_rare(cnt_values, value_protection_epsilon, threshold=5) 78 | else: 79 | rare_min = get_stochastic_rare_threshold(min_threshold=5) 80 | categories = [k for k in known_categories if cnt_values[k] >= rare_min] 81 | else: 82 | categories = known_categories 83 | no_of_rare_categories = len(known_categories) - len(categories) 84 | # add special token for MISSING categories, if any are present 85 | if any([j["has_nan"] for j in stats_list]): 86 | categories = [CATEGORICAL_NULL_TOKEN] + categories 87 | # add special token for UNKNOWN categories at first position 88 | categories = [CATEGORICAL_UNKNOWN_TOKEN] + categories 89 | stats = { 90 | "no_of_rare_categories": no_of_rare_categories, 91 | "codes": {categories[i]: i for i in range(len(categories))}, 92 | "cardinalities": {CATEGORICAL_SUB_COL_SUFFIX: len(categories)}, 93 | } 94 | return stats 95 | 96 | 97 | def encode_categorical(values: pd.Series, stats: dict, _: pd.Series | None = 
None) -> pd.DataFrame: 98 | # ensure a safe representation of values: 1. string dtype; 2. escape reserved tokens 99 | values = safe_categorical_escape(safe_convert_string(values)) 100 | known_categories = [str(k) for k in stats["codes"].keys()] 101 | values = values.copy() 102 | if CATEGORICAL_NULL_TOKEN in known_categories: 103 | values[values.isna()] = CATEGORICAL_NULL_TOKEN 104 | values[~values.isin(known_categories)] = CATEGORICAL_UNKNOWN_TOKEN 105 | 106 | # map categories to their corresponding codes 107 | codes = pd.Series( 108 | pd.Categorical(values, categories=known_categories).codes, 109 | name=CATEGORICAL_SUB_COL_SUFFIX, 110 | index=values.index, 111 | ) 112 | return codes.to_frame() 113 | 114 | 115 | def decode_categorical(df_encoded: pd.DataFrame, stats: dict) -> pd.Series: 116 | categories = stats["codes"].keys() 117 | values = pd.Series( 118 | pd.Categorical.from_codes(df_encoded[CATEGORICAL_SUB_COL_SUFFIX], categories=categories), 119 | dtype="string", 120 | ) 121 | values[values == CATEGORICAL_NULL_TOKEN] = pd.NA 122 | # convert escaped values to their original representation 123 | values = safe_categorical_unescape(values) 124 | return values 125 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/tabular/character.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Character encoding splits any value into its characters, and encodes each position then separately as a categorical. 
17 | """ 18 | 19 | import numpy as np 20 | import pandas as pd 21 | 22 | from mostlyai.engine._common import dp_non_rare, get_stochastic_rare_threshold, safe_convert_string 23 | 24 | UNKNOWN_TOKEN = "\0" 25 | MAX_LENGTH_CHARS = 50 26 | 27 | 28 | def analyze_character(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 29 | values = safe_convert_string(values) 30 | df_split = split_sub_columns_character(values) 31 | has_nan = sum(df_split["nan"]) > 0 32 | # count distinct root_keys per token for each character position 33 | df = pd.concat([root_keys, df_split], axis=1) 34 | characters = { 35 | sub_col: df.groupby(sub_col)[root_keys.name].nunique().to_dict() 36 | for sub_col in df_split.columns 37 | if sub_col.startswith("P") 38 | } 39 | stats = { 40 | "max_string_length": len(characters), 41 | "has_nan": has_nan, 42 | "characters": characters, 43 | } 44 | return stats 45 | 46 | 47 | def analyze_reduce_character( 48 | stats_list: list[dict], 49 | value_protection: bool = True, 50 | value_protection_epsilon: float | None = None, 51 | ) -> dict: 52 | # gather maximum string length across partitions 53 | max_string_length = max(stats["max_string_length"] for stats in stats_list) 54 | positions = [f"P{idx}" for idx in range(max_string_length)] 55 | # gather codes for each position 56 | codes: dict[str, dict[str, int]] = {pos: {} for pos in positions} 57 | for pos in positions: 58 | cnt_values: dict[str, int] = {} 59 | # sum up all counts for each token 60 | for item in stats_list: 61 | for value, count in item["characters"].get(pos, {}).items(): 62 | cnt_values[value] = cnt_values.get(value, 0) + count 63 | cnt_values = dict(sorted(cnt_values.items())) 64 | known_categories = list(cnt_values.keys()) 65 | if value_protection: 66 | if value_protection_epsilon is not None: 67 | categories, _ = dp_non_rare(cnt_values, value_protection_epsilon, threshold=5) 68 | else: 69 | rare_min = get_stochastic_rare_threshold(min_threshold=5) 70 | categories = [k for k in known_categories if cnt_values[k] >= rare_min] 71 | else: 72 | categories = known_categories 73 | # add special token for UNKNOWN at first position 74 | categories = [UNKNOWN_TOKEN] + [c for c in categories if c != UNKNOWN_TOKEN] 75 | # assign codes for each token 76 | codes[pos] = {categories[i]: i for i in range(len(categories))} 77 | # determine cardinalities 78 | cardinalities = {} 79 | has_nan = any([s["has_nan"] for s in stats_list]) 80 | if has_nan: 81 | cardinalities["nan"] = 2 # binary 82 | for sub_col, sub_col_codes in codes.items(): 83 | cardinalities[sub_col] = len(sub_col_codes) 84 | stats = { 85 | "has_nan": has_nan, 86 | "max_string_length": max_string_length, 87 | "codes": codes, 88 | "cardinalities": cardinalities, 89 | } 90 | return stats 91 | 92 | 93 | def encode_character(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.DataFrame: 94 | values = safe_convert_string(values) 95 | max_string_length = stats["max_string_length"] 96 | df_split = split_sub_columns_character(values, max_string_length) 97 | if not stats["has_nan"]: 98 | df_split.drop(["nan"], axis=1, inplace=True) 99 | for idx in range(max_string_length): 100 | sub_col = f"P{idx}" 101 | np_codes = np.array(pd.Categorical(df_split[sub_col], categories=stats["codes"][sub_col]).codes) 102 | np.place(np_codes, np_codes == -1, 0) 103 | df_split[sub_col] = np_codes 104 | return df_split 105 | 106 | 107 | def split_sub_columns_character( 108 | values: pd.Series, 109 | max_string_length: int | None = None, 110 | ) -> pd.DataFrame: 111 
| if not pd.api.types.is_string_dtype(values): 112 | raise ValueError("expected to be string") 113 | is_na = pd.Series(values.isna().astype("int"), name="nan").to_frame() 114 | values = values.fillna("") 115 | # trim strings to a maximum length 116 | values = values.str.slice(stop=MAX_LENGTH_CHARS) 117 | # pad strings to string_length 118 | if max_string_length is None: 119 | max_string_length = values.str.len().max() 120 | max_string_length = ( 121 | int(max_string_length) # type: ignore 122 | if np.isscalar(max_string_length) and not np.isnan(max_string_length) 123 | else 0 124 | ) 125 | else: 126 | values = values.str.slice(stop=max_string_length) 127 | # explode to wide dataframe 128 | padded_values = values.str.ljust(max_string_length, UNKNOWN_TOKEN) 129 | chars_df = padded_values.str.split("", expand=True) 130 | if not chars_df.empty: 131 | chars_df = chars_df.drop([0, max_string_length + 1], axis=1) 132 | chars_df.columns = [f"P{idx}" for idx in range(max_string_length)] 133 | else: # chars_df.empty is True 134 | # even though the input is empty, we still need to return a dataframe with the correct columns 135 | chars_df = pd.DataFrame(columns=[f"P{idx}" for idx in range(max_string_length)]) 136 | df = pd.concat([is_na, chars_df], axis=1) 137 | return df 138 | 139 | 140 | def decode_character(df_encoded: pd.DataFrame, stats: dict) -> pd.Series: 141 | if len(stats["codes"].keys()) > 0: 142 | df_decoded = pd.DataFrame( 143 | { 144 | sub_col: pd.Series( 145 | pd.Categorical.from_codes(df_encoded[sub_col], categories=stats["codes"][sub_col]), 146 | dtype="string", 147 | ) 148 | for sub_col in stats["codes"].keys() 149 | }, 150 | ) 151 | values = df_decoded.apply(lambda item: "".join(item), axis=1, result_type="reduce").astype( 152 | str 153 | ) # necessary to keep string dtype for empty df_decoded 154 | # remove unknown tokens and strip trailing whitespaces 155 | values = values.apply(lambda item: item.replace(UNKNOWN_TOKEN, "")).str.rstrip() 156 | else: 157 | # handle de-generate case, where no tokens were stored 158 | values = pd.Series(pd.NA).repeat(df_encoded.shape[0]) 159 | if stats["has_nan"]: 160 | values[df_encoded["nan"] == 1] = pd.NA 161 | return values 162 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from mostlyai.engine._language.lstm import register_mostly_lstm_model 16 | 17 | register_mostly_lstm_model() 18 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import importlib 16 | import logging 17 | from pathlib import Path 18 | 19 | import torch 20 | from peft import PeftConfig, prepare_model_for_kbit_training 21 | from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, PretrainedConfig, PreTrainedModel 22 | 23 | from mostlyai.engine._language.lstm import LSTMFromScratchConfig 24 | 25 | _LOG = logging.getLogger(__name__) 26 | 27 | MAX_LENGTH = 10_000 28 | 29 | 30 | def is_bf16_supported(device: torch.device) -> bool: 31 | if device.type != "cuda": 32 | return False 33 | compute_capability = torch.cuda.get_device_capability(device) 34 | return compute_capability[0] >= 8 35 | 36 | 37 | def get_attention_implementation(config: PretrainedConfig) -> str | None: 38 | model_cls = AutoModelForCausalLM._model_mapping[type(config)] 39 | attn_implementation = None 40 | if getattr(model_cls, "_supports_sdpa", False): 41 | attn_implementation = "sdpa" 42 | return attn_implementation 43 | 44 | 45 | def load_base_model_and_config( 46 | model_id_or_path: str | Path, device: torch.device, is_peft_adapter: bool, is_training: bool 47 | ) -> tuple[PreTrainedModel, PretrainedConfig]: 48 | model_id_or_path = str(model_id_or_path) 49 | if is_peft_adapter: 50 | # get the base model name from adapter_config.json 51 | peft_config = PeftConfig.from_pretrained(model_id_or_path) 52 | model_id_or_path = peft_config.base_model_name_or_path 53 | config = AutoConfig.from_pretrained(model_id_or_path) 54 | else: 55 | config = AutoConfig.from_pretrained(model_id_or_path) 56 | if config.model_type == LSTMFromScratchConfig.model_id: 57 | # make sure that we use standard LSTM layers during inference for the model trained with DP 58 | # (see https://opacus.ai/api/dp_rnn.html#opacus.layers.dp_rnn.DPLSTM for more details) 59 | if not is_training: 60 | config.with_dp = False 61 | return AutoModelForCausalLM.from_pretrained(model_id_or_path, config=config, device_map=device), config 62 | 63 | # Load pretrained base model 64 | use_cache = not is_training # KV cache is not needed during training 65 | is_gpu_training = is_training and device.type == "cuda" 66 | is_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None 67 | if is_gpu_training and not is_bitsandbytes_available: 68 | _LOG.warning( 69 | "CUDA device was found but bitsandbytes is not available. Please use extra [gpu] to install bitsandbytes for quantization." 
70 | ) 71 | bf16_supported = is_bf16_supported(device) 72 | if bf16_supported: 73 | attn_implementation = get_attention_implementation(config) 74 | torch_dtype = torch.bfloat16 75 | else: 76 | attn_implementation = None 77 | torch_dtype = torch.float32 78 | if is_gpu_training and is_bitsandbytes_available: 79 | quantization_config = BitsAndBytesConfig( 80 | load_in_4bit=True, 81 | bnb_4bit_quant_type="nf4", 82 | bnb_4bit_use_double_quant=False, 83 | bnb_4bit_compute_dtype=torch_dtype, 84 | ) 85 | else: 86 | quantization_config = None 87 | model = AutoModelForCausalLM.from_pretrained( 88 | model_id_or_path, 89 | torch_dtype=torch_dtype, 90 | attn_implementation=attn_implementation, 91 | use_cache=use_cache, 92 | device_map=device, 93 | quantization_config=quantization_config, 94 | ) 95 | if quantization_config: 96 | # convert all non-kbit layers to float32 97 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) 98 | if is_gpu_training and model.supports_gradient_checkpointing: 99 | # pay 50% time penalty for _large_ memory savings 100 | _LOG.info("enable gradient checkpointing") 101 | model.gradient_checkpointing_enable() 102 | model.enable_input_require_grads() 103 | return model, config 104 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/engine/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
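As a rough usage sketch of load_base_model_and_config from _language/common.py above (the model id below is a placeholder; on CPU the helper falls back to float32 and skips quantization, and a fine-tuned PEFT adapter directory would instead be passed with is_peft_adapter=True):

import torch

from mostlyai.engine._language.common import load_base_model_and_config

model, config = load_base_model_and_config(
    "microsoft/phi-1_5",  # placeholder Hugging Face model id
    device=torch.device("cpu"),
    is_peft_adapter=False,
    is_training=False,
)
print(config.model_type)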
14 | 15 | from abc import ABC, abstractmethod 16 | from collections.abc import Generator 17 | from dataclasses import dataclass 18 | 19 | from pydantic import BaseModel 20 | 21 | 22 | @dataclass 23 | class EngineMetrics: 24 | tokenize_time: float 25 | generate_time: float 26 | 27 | 28 | class LanguageEngine(ABC): 29 | @abstractmethod 30 | def initialize_logits_processors(self, schemas: Generator[BaseModel]): 31 | pass 32 | 33 | @abstractmethod 34 | def generate( 35 | self, text: list[str], sampling_temperature: float, sampling_top_p: float 36 | ) -> tuple[list[int], EngineMetrics]: 37 | pass 38 | 39 | @abstractmethod 40 | def get_default_batch_size(self) -> int: 41 | pass 42 | 43 | @abstractmethod 44 | def supports_json_enforcing(self) -> bool: 45 | pass 46 | 47 | @abstractmethod 48 | def cleanup(self): 49 | pass 50 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/engine/hf_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import time 18 | from collections.abc import Generator 19 | from os import PathLike 20 | from pathlib import Path 21 | 22 | import torch 23 | from peft import PeftModel 24 | from pydantic import BaseModel 25 | from transformers import AutoTokenizer 26 | from xgrammar.contrib.hf import LogitsProcessor 27 | 28 | from mostlyai.engine._language.common import load_base_model_and_config 29 | from mostlyai.engine._language.engine.base import EngineMetrics, LanguageEngine 30 | from mostlyai.engine._language.tokenizer_utils import tokenize_fn 31 | from mostlyai.engine._language.xgrammar_utils import create_compiled_grammars 32 | 33 | 34 | class HuggingFaceEngine(LanguageEngine): 35 | def __init__( 36 | self, model_path: PathLike | str, device: torch.device, max_new_tokens: int, tokenizer_max_length: int 37 | ): 38 | self.device = device 39 | self.max_new_tokens = max_new_tokens 40 | self.tokenizer_max_length = tokenizer_max_length 41 | self.is_peft_adapter = (Path(model_path) / "adapter_config.json").exists() 42 | 43 | model_path = str(model_path) 44 | self._model, self._model_config = load_base_model_and_config( 45 | model_path, device=device, is_peft_adapter=self.is_peft_adapter, is_training=False 46 | ) 47 | if self.is_peft_adapter: 48 | self._model = PeftModel.from_pretrained(self._model, model_path, is_trainable=False) 49 | self._model = self._model.merge_and_unload() 50 | self._default_batch_size = 64 51 | else: 52 | # only the LSTM model does not have an adapter 53 | self._default_batch_size = 128 54 | 55 | self.tokenizer = AutoTokenizer.from_pretrained( 56 | model_path, 57 | padding_side="left", 58 | truncation_side="left", 59 | legacy=True, 60 | # these must be False at initialization, as we manually add them later in tokenize_fn 61 | add_bos_token=False, 62 | add_eos_token=False, 63 | ) 64 | 65 | # we 
can't enforce JSON output if LSTM tokenizer training was skipped 66 | is_trained_lstm_tokenizer = not self.is_peft_adapter and self.tokenizer.vocab_size > len( 67 | self.tokenizer.special_tokens_map 68 | ) 69 | self._json_enforcing_possible = self.is_peft_adapter or is_trained_lstm_tokenizer 70 | self._logits_processors = None 71 | 72 | def get_default_batch_size(self) -> int: 73 | return self._default_batch_size 74 | 75 | def supports_json_enforcing(self) -> bool: 76 | return self._json_enforcing_possible 77 | 78 | def initialize_logits_processors(self, schemas: Generator[BaseModel]): 79 | compiled_grammars = create_compiled_grammars( 80 | schemas=schemas, 81 | tokenizer=self.tokenizer, 82 | vocab_size=self._model_config.vocab_size, 83 | is_peft_adapter=self.is_peft_adapter, 84 | ) 85 | self._logits_processors = [LogitsProcessor(list(compiled_grammars))] 86 | 87 | def generate( 88 | self, text: list[str], sampling_temperature: float, sampling_top_p: float 89 | ) -> tuple[list[int], EngineMetrics]: 90 | do_sample = sampling_temperature > 0.0 91 | 92 | tokenize_kwargs = dict( 93 | tokenizer=self.tokenizer, 94 | return_tensors="pt", 95 | add_bos_token=True, 96 | add_eos_token=False, 97 | padding=True, 98 | truncation=True, 99 | max_length=self.tokenizer_max_length, # truncates input 100 | ) 101 | t_tokenize = time.time() 102 | inputs = tokenize_fn(text=text, **tokenize_kwargs).to(self.device) 103 | tokenize_time = time.time() - t_tokenize 104 | 105 | generate_kwargs = dict( 106 | do_sample=do_sample, 107 | max_new_tokens=self.max_new_tokens, 108 | temperature=sampling_temperature if do_sample else None, 109 | top_p=sampling_top_p if do_sample else None, 110 | bos_token_id=self.tokenizer.bos_token_id, 111 | pad_token_id=self.tokenizer.pad_token_id, 112 | eos_token_id=self.tokenizer.eos_token_id, 113 | ) 114 | 115 | t_generate = time.time() 116 | outputs = self._model.generate(**inputs, **generate_kwargs, logits_processor=self._logits_processors) 117 | generate_time = time.time() - t_generate 118 | 119 | _, input_length = inputs["input_ids"].shape 120 | # truncate the prompt from the outputs 121 | outputs = outputs[:, input_length:] 122 | metrics = EngineMetrics(tokenize_time=tokenize_time, generate_time=generate_time) 123 | return outputs.detach().cpu().tolist(), metrics 124 | 125 | def cleanup(self): 126 | pass 127 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
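A hedged sketch of driving the HuggingFaceEngine defined in hf_engine.py above; the model path and prompts are hypothetical, and initialize_logits_processors would additionally be called when JSON enforcement is wanted and supported:

import torch

from mostlyai.engine._language.engine.hf_engine import HuggingFaceEngine

engine = HuggingFaceEngine(
    model_path="engine-ws/ModelStore/model-data",  # hypothetical workspace path
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    max_new_tokens=256,
    tokenizer_max_length=1024,
)
prompts = ["sample prompt 1", "sample prompt 2"]  # hypothetical prompts
token_ids, metrics = engine.generate(prompts, sampling_temperature=1.0, sampling_top_p=1.0)
print(len(token_ids), metrics.tokenize_time, metrics.generate_time)
engine.cleanup()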
14 | 15 | import logging 16 | 17 | import torch 18 | import torch.nn as nn 19 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, GenerationMixin, PretrainedConfig, PreTrainedModel 20 | from transformers.modeling_outputs import CausalLMOutput 21 | 22 | _LOG = logging.getLogger(__name__) 23 | 24 | 25 | class LSTMFromScratchConfig(PretrainedConfig): 26 | model_type = model_id = "MOSTLY_AI/LSTMFromScratch-3m" 27 | 28 | def __init__( 29 | self, 30 | vocab_size: int | None = None, 31 | embedding_size: int = 256, 32 | hidden_size: int = 256, 33 | num_layers: int = 1, 34 | dropout: float = 0.25, 35 | with_dp: bool = False, 36 | **kwargs, 37 | ): 38 | self.vocab_size = vocab_size 39 | self.embedding_size = embedding_size 40 | self.hidden_size = hidden_size 41 | self.num_layers = num_layers 42 | self.dropout = dropout 43 | self.with_dp = with_dp 44 | super().__init__(**kwargs) 45 | 46 | 47 | class LSTMFromScratchLMHeadModel(PreTrainedModel, GenerationMixin): 48 | config_class = LSTMFromScratchConfig 49 | 50 | def __init__(self, config: LSTMFromScratchConfig): 51 | super().__init__(config) 52 | self.config = config 53 | 54 | self.embedding = nn.Embedding(self.config.vocab_size, self.config.embedding_size) 55 | self.dropout = nn.Dropout(self.config.dropout) 56 | if self.config.with_dp: 57 | from opacus.layers import DPLSTM 58 | 59 | lstm_cls = DPLSTM 60 | else: 61 | lstm_cls = nn.LSTM 62 | self.lstm = lstm_cls( 63 | input_size=self.config.embedding_size, 64 | hidden_size=self.config.hidden_size, 65 | num_layers=self.config.num_layers, 66 | dropout=self.config.dropout if self.config.num_layers > 1 else 0.0, 67 | batch_first=True, 68 | ) 69 | self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size) 70 | self.loss_fn = nn.CrossEntropyLoss() 71 | 72 | # this will be filled by left_to_right_padding() during the generation 73 | self.pad_token_id = None 74 | 75 | def forward( 76 | self, 77 | input_ids: torch.Tensor, 78 | attention_mask: torch.Tensor, 79 | labels: torch.Tensor | None = None, 80 | **kwargs, 81 | ) -> CausalLMOutput: 82 | lengths = attention_mask.sum(dim=1) 83 | embeddings = self.embedding(input_ids) 84 | embeddings = self.dropout(embeddings) 85 | 86 | # (DP)LSTM layers without pack_padded_sequence/pad_packed_sequence 87 | lstm_outputs, _ = self.lstm(embeddings) 88 | 89 | logits = self.lm_head(lstm_outputs) 90 | 91 | loss = None 92 | if labels is not None: 93 | labels = labels[:, 1:].contiguous() 94 | shifted_prediction_scores = logits[:, :-1, :].contiguous() 95 | loss = self.loss_fn(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 96 | else: 97 | # overwrite the logit of the last time step with the logit of the actual last token 98 | # so that Hugging Face Transformers' generate() will sample on the right probabilities 99 | logits[:, -1, :] = torch.stack([logits[i, length - 1, :] for i, length in enumerate(lengths)]) 100 | return CausalLMOutput( 101 | loss=loss, 102 | logits=logits, 103 | ) 104 | 105 | def prepare_inputs_for_generation( 106 | self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs 107 | ) -> dict[str, torch.Tensor]: 108 | """ 109 | This function is mandatory so that the model is able to use the Hugging Face `.generate()` method. 110 | Since `.generate()` works with left-padded sequences but the model is trained with right-padded sequences, 111 | we need to convert the padding side here to make it work properly. 
112 | """ 113 | lengths = attention_mask.sum(dim=1) 114 | return { 115 | "input_ids": self.left_to_right_padding(input_ids, lengths), 116 | "attention_mask": attention_mask, 117 | } 118 | 119 | def left_to_right_padding(self, left_padded_tensors: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: 120 | batch_size, max_length = left_padded_tensors.size() 121 | indices = torch.nonzero(lengths < max_length) 122 | if len(indices) == 0: 123 | # none of the samples are padded, so we can just return them as they are 124 | return left_padded_tensors 125 | else: 126 | if self.pad_token_id is None: 127 | # get the pad token id from the first padded sample 128 | self.pad_token_id = left_padded_tensors[indices[0], -1].item() 129 | right_padded_tensors = torch.full_like(left_padded_tensors, self.pad_token_id) 130 | for i in range(batch_size): 131 | right_padded_tensors[i, : lengths[i]] = left_padded_tensors[i, max_length - lengths[i] :] 132 | return right_padded_tensors 133 | 134 | 135 | def register_mostly_lstm_model(): 136 | # register the model so that we can load it with `AutoModelForCausalLM.from_pretrained()` later 137 | AutoConfig.register(LSTMFromScratchConfig.model_id, LSTMFromScratchConfig) 138 | AutoModel.register(LSTMFromScratchConfig, LSTMFromScratchLMHeadModel) 139 | AutoModelForCausalLM.register(LSTMFromScratchConfig, LSTMFromScratchLMHeadModel) 140 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
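A small sketch of instantiating the from-scratch LSTM defined in lstm.py above and running a forward pass; the vocabulary size and token ids are made up (in the engine they come from the trained tokenizer), and model registration already happens when the _language package is imported:

import torch

from mostlyai.engine._language.lstm import LSTMFromScratchConfig, LSTMFromScratchLMHeadModel

config = LSTMFromScratchConfig(vocab_size=5000)  # hypothetical vocab size
model = LSTMFromScratchLMHeadModel(config)

input_ids = torch.randint(0, 5000, (2, 6))       # made-up right-padded batch
attention_mask = torch.ones_like(input_ids)
out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.logits.shape)  # torch.Size([2, 6, 5000])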
14 | 15 | from collections.abc import Iterator, Mapping 16 | from dataclasses import dataclass 17 | from typing import Any 18 | 19 | from transformers import BatchEncoding, DataCollatorForLanguageModeling, LlamaTokenizerFast, PreTrainedTokenizerFast 20 | from transformers.data.data_collator import _torch_collate_batch, pad_without_fast_tokenizer_warning 21 | 22 | from mostlyai.engine.domain import ModelEncodingType 23 | 24 | ################# 25 | ### TOKENIZER ### 26 | ################# 27 | 28 | 29 | def train_tokenizer( 30 | training_iterator: Iterator | list | None = None, 31 | tokenizer_kwargs: dict[str, Any] | None = None, 32 | tgt_stats: dict[str, Any] | None = None, 33 | ): 34 | if tokenizer_kwargs is None: 35 | tokenizer_kwargs = {} 36 | from tokenizers import Tokenizer, decoders 37 | from tokenizers.models import BPE 38 | from tokenizers.normalizers import Replace 39 | from tokenizers.pre_tokenizers import Metaspace, Punctuation, Sequence, Split 40 | from tokenizers.trainers import BpeTrainer 41 | 42 | special_tokens = { 43 | "unk_token": "<unk>", 44 | "pad_token": "<pad>", 45 | "bos_token": "<s>", 46 | "eos_token": "</s>", 47 | } 48 | SPECIAL_TOKENS = list(special_tokens.values()) 49 | NEW_LINE_VALUE = "\n" 50 | NEW_LINE_SYMBOL = "\u240a" # https://www.fileformat.info/info/unicode/char/240a/index.htm 51 | MIN_FREQ_MERGE = 20 52 | VOCAB_SIZE = 5000 53 | 54 | # add initial alphabet for numeric and datetime columns if needed 55 | has_numeric_columns = any( 56 | col_stats["encoding_type"] == ModelEncodingType.language_numeric for col_stats in tgt_stats["columns"].values() 57 | ) 58 | has_datetime_columns = any( 59 | col_stats["encoding_type"] == ModelEncodingType.language_datetime for col_stats in tgt_stats["columns"].values() 60 | ) 61 | initial_alphabet = set() 62 | if has_numeric_columns: 63 | # FIXME: maybe the set can be more fine-grained based on max_scale in stats 64 | initial_alphabet |= {str(i) for i in range(10)} | {".", "-", "+", "e", "E"} 65 | if has_datetime_columns: 66 | initial_alphabet |= {str(i) for i in range(10)} | {".", "-", ":", "T", "Z"} 67 | initial_alphabet = list(initial_alphabet) 68 | 69 | # Builds a BPE raw_tokenizer, and optionally trains it based on provided text 70 | training_iterator = training_iterator or [] # allow easy training skip 71 | raw_tokenizer = Tokenizer(BPE(unk_token=special_tokens["unk_token"])) 72 | trainer = BpeTrainer( 73 | initial_alphabet=initial_alphabet, 74 | special_tokens=SPECIAL_TOKENS, 75 | min_frequency=MIN_FREQ_MERGE, 76 | vocab_size=VOCAB_SIZE, 77 | show_progress=False, 78 | ) 79 | raw_tokenizer.normalizer = Replace(NEW_LINE_VALUE, NEW_LINE_SYMBOL) 80 | raw_tokenizer.pre_tokenizer = Sequence( 81 | [ 82 | Metaspace(), 83 | Split(pattern=NEW_LINE_SYMBOL, behavior="isolated"), 84 | Punctuation(), 85 | ] 86 | ) 87 | raw_tokenizer.decoder = decoders.Sequence( 88 | [ 89 | decoders.Metaspace(), 90 | decoders.Replace(NEW_LINE_SYMBOL, NEW_LINE_VALUE), 91 | ] 92 | ) 93 | raw_tokenizer.train_from_iterator(iterator=training_iterator, trainer=trainer) 94 | tokenizer = LlamaTokenizerFast(tokenizer_object=raw_tokenizer, **special_tokens, **tokenizer_kwargs) 95 | return tokenizer 96 | 97 | 98 | def tokenize_fn( 99 | text: dict[str, str] | dict[str, list[str]] | list[str], 100 | tokenizer: PreTrainedTokenizerFast, 101 | text_key: str | None = None, 102 | return_tensors: str | None = None, 103 | padding: bool | str = True, 104 | truncation: bool = True, 105 | add_bos_token: bool = True, 106 | add_eos_token: bool = True, 107 | max_length: int = 1024, 108 | )
-> BatchEncoding: 109 | if text_key: 110 | text = text[text_key] 111 | # make sure the tokenizer is configured as expected 112 | if getattr(tokenizer, "add_bos_token", False) or getattr(tokenizer, "add_eos_token", False): 113 | raise RuntimeError("Tokenizer must be configured as add_bos_token=False and add_eos_token=False") 114 | if tokenizer.bos_token is None or tokenizer.eos_token is None: 115 | raise RuntimeError("Tokenizer must have bos_token and eos_token set") 116 | prefix = tokenizer.bos_token if add_bos_token else "" 117 | suffix = tokenizer.eos_token if add_eos_token else "" 118 | # NOTE: here we add bos/eos tokens before truncation and padding, 119 | # which means that they may be truncated for long sequences 120 | if isinstance(text, str): 121 | text = f"{prefix}{text}{suffix}" 122 | else: 123 | for i, t in enumerate(text): 124 | text[i] = f"{prefix}{t}{suffix}" 125 | tokenized_content = tokenizer( 126 | text, 127 | padding=padding, 128 | truncation=truncation, 129 | max_length=max_length, 130 | return_tensors=return_tensors, 131 | ) 132 | return tokenized_content 133 | 134 | 135 | ##################### 136 | ### DATA COLLATOR ### 137 | ##################### 138 | 139 | 140 | @dataclass 141 | class MostlyDataCollatorForLanguageModeling(DataCollatorForLanguageModeling): 142 | def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]: 143 | """ 144 | A variation of the original `DataCollatorForLanguageModeling.torch_call` method. 145 | 146 | This method can mask tokens based on the attention mask, so that bos and eos tokens will not be masked 147 | even if they are identical to pad token. 148 | If attention mask is not provided, it will fall back to masking pad tokens. 149 | """ 150 | if isinstance(examples[0], Mapping): 151 | batch = pad_without_fast_tokenizer_warning( 152 | self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=None 153 | ) 154 | else: 155 | batch = {"input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=None)} 156 | 157 | labels = batch["input_ids"].clone() 158 | attention_mask = batch.get("attention_mask", None) 159 | if attention_mask is not None: 160 | labels[(attention_mask == 0)] = -100 161 | else: 162 | if self.tokenizer.pad_token_id is not None: 163 | labels[labels == self.tokenizer.pad_token_id] = -100 164 | batch["labels"] = labels 165 | return batch 166 | -------------------------------------------------------------------------------- /mostlyai/engine/_memory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
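The collator above masks labels wherever attention_mask is 0 so that padding does not contribute to the loss; a tiny hand-worked illustration of that rule with made-up token ids:

import torch

input_ids = torch.tensor([[2, 15, 27, 0, 0]])   # 0 is a hypothetical pad id
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

labels = input_ids.clone()
labels[attention_mask == 0] = -100              # -100 is ignored by CrossEntropyLoss
print(labels)  # tensor([[   2,   15,   27, -100, -100]])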
14 | 15 | import logging 16 | import os 17 | import re 18 | 19 | import psutil 20 | import torch 21 | 22 | _LOG = logging.getLogger(__name__) 23 | 24 | 25 | def get_available_vram_for_heuristics() -> int: 26 | if not torch.cuda.is_available(): 27 | return 0 28 | free, total = torch.cuda.mem_get_info() 29 | return total 30 | 31 | 32 | def get_available_ram_for_heuristics() -> int: 33 | mem_limit = extract_memory_from_string(os.getenv("MOSTLY_ENGINE_AVAILABLE_RAM_FOR_HEURISTICS", default=None)) 34 | if mem_limit is None: 35 | mem_limit = psutil.virtual_memory().available 36 | return mem_limit 37 | 38 | 39 | def extract_memory_from_string(memory_str: str | None = None) -> int | None: 40 | """ 41 | Extract the memory in bytes from a string. 42 | 43 | :param memory_str: The memory string to extract the memory from. 44 | :return: The memory in bytes. 45 | """ 46 | if not memory_str: 47 | return None 48 | 49 | # Conversion factors, considering metric (decimal) vs. binary (IEC) units 50 | units = { 51 | "": 1, 52 | "b": 1, 53 | "k": 1024, 54 | "m": 1024**2, 55 | "g": 1024**3, 56 | "t": 1024**4, 57 | } 58 | match = re.match(r"(\d+(?:\.\d+)?)[ ]?([a-z]?)", memory_str.strip().lower()) 59 | if not match: 60 | return None 61 | 62 | value, unit = match.groups() 63 | value = float(value) 64 | 65 | # Convert to bytes 66 | if unit in units: 67 | return int(value * units[unit]) 68 | else: 69 | return None 70 | -------------------------------------------------------------------------------- /mostlyai/engine/_tabular/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_tabular/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
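A few worked examples of extract_memory_from_string from _memory.py above; note that the single-letter suffixes are interpreted with binary multiples (e.g. "g" is 1024**3 bytes):

from mostlyai.engine._memory import extract_memory_from_string

assert extract_memory_from_string("512") == 512
assert extract_memory_from_string("1.5k") == 1536          # 1.5 * 1024
assert extract_memory_from_string("4g") == 4 * 1024**3
assert extract_memory_from_string("not-a-size") is None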
14 | 15 | import logging 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | 21 | _LOG = logging.getLogger(__name__) 22 | 23 | 24 | def load_model_weights(model: torch.nn.Module, path: Path, device: torch.device): 25 | try: 26 | t00 = time.time() 27 | model.load_state_dict(torch.load(f=path, map_location=device, weights_only=True)) 28 | _LOG.info(f"loaded model weights in {time.time() - t00:.2f}s") 29 | except Exception as e: 30 | _LOG.warning(f"failed to load model weights: {e}") 31 | -------------------------------------------------------------------------------- /mostlyai/engine/_training_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | import logging 17 | import time 18 | 19 | import pandas as pd 20 | import torch 21 | from opacus.accountants import IAccountant 22 | from pydantic import BaseModel, Field, field_validator 23 | 24 | from mostlyai.engine._workspace import Workspace 25 | 26 | _LOG = logging.getLogger(__name__) 27 | 28 | 29 | class ProgressMessage(BaseModel, extra="allow"): 30 | epoch: float | None = Field(None, description="Current epoch number") 31 | is_checkpoint: bool | int | None = Field(0, description="Whether this progress is a checkpoint") 32 | steps: int | None = Field(None, description="Number of processed steps") 33 | samples: int | None = Field(None, description="Number of processed samples") 34 | trn_loss: float | None = Field(None, description="Training loss") 35 | val_loss: float | None = Field(None, description="Validation loss") 36 | total_time: float | None = Field(None, description="Elapsed total time (s)") 37 | learn_rate: float | None = Field(None, description="Learning rate") 38 | dp_eps: float | None = Field(None, description="Differential privacy epsilon") 39 | dp_delta: float | None = Field(None, description="Differential privacy delta") 40 | 41 | @field_validator("epoch", "trn_loss", "val_loss", "learn_rate", "total_time", "dp_eps", "dp_delta") 42 | @classmethod 43 | def round_float(cls, v, info) -> float: 44 | field_decimal_places = { 45 | "epoch": 2, 46 | "trn_loss": 4, 47 | "val_loss": 4, 48 | "learn_rate": 6, 49 | "total_time": 1, 50 | "dp_eps": 2, 51 | "dp_delta": 8, 52 | } 53 | if isinstance(v, float) and info.field_name in field_decimal_places: 54 | return round(v, field_decimal_places[info.field_name]) 55 | return v 56 | 57 | @field_validator("is_checkpoint") 58 | @classmethod 59 | def cast_to_int(cls, v) -> int: 60 | return int(v) 61 | 62 | 63 | class EarlyStopper: 64 | """ 65 | Stop training when val_loss stopped improving for a while 66 | """ 67 | 68 | def __init__(self, val_loss_patience: int) -> None: 69 | self.val_loss_patience = val_loss_patience 70 | self.best_loss = float("inf") 71 | self.val_loss_cnt = 0 72 | 73 | def __call__(self, val_loss: float) -> bool: 74 | do_stop = False 75 | # check val_loss 76 | if not pd.isna(val_loss) 
and val_loss < self.best_loss: 77 | # remember best val_loss 78 | self.best_loss = val_loss 79 | # reset counter 80 | self.val_loss_cnt = 0 81 | else: 82 | self.val_loss_cnt += 1 83 | if self.val_loss_cnt > self.val_loss_patience: 84 | _LOG.info("early stopping: val_loss stopped improving") 85 | do_stop = True 86 | return do_stop 87 | 88 | 89 | class ModelCheckpoint(abc.ABC): 90 | """ 91 | Save model weights for best model. 92 | """ 93 | 94 | def __init__(self, workspace: Workspace, initial_best_val_loss: float = float("inf")) -> None: 95 | self.workspace = workspace 96 | self.best_val_loss = initial_best_val_loss 97 | self.last_save_time = time.time() 98 | self.save_count = 0 99 | 100 | def optimizer_and_lr_scheduler_paths_exist(self) -> bool: 101 | return self.workspace.model_optimizer_path.exists() and self.workspace.model_lr_scheduler_path.exists() 102 | 103 | @abc.abstractmethod 104 | def model_weights_path_exists(self) -> None: 105 | pass 106 | 107 | def clear_checkpoint(self): 108 | self.workspace.model_optimizer_path.unlink(missing_ok=True) 109 | self.workspace.model_lr_scheduler_path.unlink(missing_ok=True) 110 | self._clear_model_weights() 111 | 112 | def save_checkpoint_if_best( 113 | self, 114 | val_loss: float, 115 | model: torch.nn.Module, 116 | optimizer: torch.optim.Optimizer | None = None, 117 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, 118 | dp_accountant: IAccountant | None = None, 119 | ) -> bool: 120 | # save model weights if validation loss has improved 121 | if val_loss < self.best_val_loss: 122 | self.best_val_loss = val_loss 123 | self.save_checkpoint(model, optimizer, lr_scheduler, dp_accountant) 124 | return True 125 | else: 126 | return False 127 | 128 | def save_checkpoint( 129 | self, 130 | model: torch.nn.Module, 131 | optimizer: torch.optim.Optimizer | None = None, 132 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, 133 | dp_accountant: IAccountant | None = None, 134 | ) -> None: 135 | if optimizer is not None and lr_scheduler is not None: 136 | torch.save(optimizer.state_dict(), self.workspace.model_optimizer_path) 137 | torch.save(lr_scheduler.state_dict(), self.workspace.model_lr_scheduler_path) 138 | if dp_accountant is not None: 139 | torch.save(dp_accountant.state_dict(), self.workspace.model_dp_accountant_path) 140 | self._save_model_weights(model) 141 | self.last_save_time = time.time() 142 | self.save_count += 1 143 | 144 | def has_saved_once(self) -> bool: 145 | return self.save_count > 0 146 | 147 | @abc.abstractmethod 148 | def _save_model_weights(self, model: torch.nn.Module) -> None: 149 | pass 150 | 151 | @abc.abstractmethod 152 | def _clear_model_weights(self) -> None: 153 | pass 154 | 155 | 156 | def check_early_training_exit(workspace: Workspace, trn_cnt: int, val_cnt: int) -> bool: 157 | trn_files = workspace.encoded_data_trn.fetch_all() 158 | val_files = workspace.encoded_data_val.fetch_all() 159 | return any((len(trn_files) == 0, len(val_files) == 0, trn_cnt == 0, val_cnt == 0)) 160 | -------------------------------------------------------------------------------- /mostlyai/engine/encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pathlib import Path 16 | 17 | from mostlyai.engine._common import ProgressCallback 18 | from mostlyai.engine._workspace import resolve_model_type 19 | from mostlyai.engine.domain import ModelType 20 | 21 | 22 | def encode( 23 | *, 24 | workspace_dir: str | Path = "engine-ws", 25 | update_progress: ProgressCallback | None = None, 26 | ) -> None: 27 | """ 28 | Encodes data in the workspace that has already been split and analyzed. 29 | 30 | Creates the following folder structure within the `workspace_dir`: 31 | 32 | - `OriginalData/encoded-data`: Encoded data for training, stored as parquet files. 33 | 34 | Args: 35 | workspace_dir: Directory path for workspace. 36 | update_progress: Callback for progress updates. 37 | """ 38 | model_type = resolve_model_type(workspace_dir) 39 | if model_type == ModelType.tabular: 40 | from mostlyai.engine._tabular.encoding import encode as encode_tabular 41 | 42 | return encode_tabular(workspace_dir=workspace_dir, update_progress=update_progress) 43 | else: 44 | from mostlyai.engine._language.encoding import encode as encode_language 45 | 46 | return encode_language(workspace_dir=workspace_dir, update_progress=update_progress) 47 | -------------------------------------------------------------------------------- /mostlyai/engine/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from pathlib import Path 16 | 17 | import pandas as pd 18 | 19 | from mostlyai.engine._common import ProgressCallback 20 | from mostlyai.engine._workspace import resolve_model_type 21 | from mostlyai.engine.domain import ( 22 | FairnessConfig, 23 | ImputationConfig, 24 | ModelType, 25 | RareCategoryReplacementMethod, 26 | RebalancingConfig, 27 | ) 28 | 29 | 30 | def generate( 31 | *, 32 | ctx_data: pd.DataFrame | None = None, 33 | seed_data: pd.DataFrame | None = None, 34 | sample_size: int | None = None, 35 | batch_size: int | None = None, 36 | sampling_temperature: float = 1.0, 37 | sampling_top_p: float = 1.0, 38 | device: str | None = None, 39 | rare_category_replacement_method: RareCategoryReplacementMethod | str = RareCategoryReplacementMethod.constant, 40 | rebalancing: RebalancingConfig | dict | None = None, 41 | imputation: ImputationConfig | dict | None = None, 42 | fairness: FairnessConfig | dict | None = None, 43 | workspace_dir: str | Path = "engine-ws", 44 | update_progress: ProgressCallback | None = None, 45 | ) -> None: 46 | """ 47 | Generates synthetic data from a trained model. 48 | 49 | Creates the following folder structure within the `workspace_dir`: 50 | 51 | - `SyntheticData`: Generated synthetic data, stored as parquet files. 52 | 53 | Args: 54 | ctx_data: Context data to be used for generation. 55 | seed_data: Seed data to condition generation on fixed target columns. 56 | sample_size: Number of samples to generate. Defaults to number of original samples. 57 | batch_size: Batch size for generation. If None, determined automatically. 58 | sampling_temperature: Sampling temperature. Higher values increase randomness. 59 | sampling_top_p: Nucleus sampling probability threshold. 60 | device: Device to run generation on ('cuda' or 'cpu'). Defaults to 'cuda' if available, else 'cpu'. 61 | rare_category_replacement_method: Method for handling rare categories. Only applicable for tabular models. 62 | rebalancing: Configuration for rebalancing column distributions. Only applicable for tabular models. 63 | imputation: List of columns to impute missing values. Only applicable for tabular models. 64 | fairness: Configuration for fairness constraints. Only applicable for tabular models. 65 | workspace_dir: Directory path for workspace. 66 | update_progress: Callback for progress updates. 
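Example: a minimal usage sketch (illustrative only, not exhaustive; it assumes the same `workspace_dir` has already been prepared by this package's analyze, encode, and train steps):

    from mostlyai.engine.generation import generate

    # sample 1,000 synthetic rows from the trained model stored in the workspace;
    # leaving `sample_size` unset falls back to the number of original samples
    generate(sample_size=1_000, workspace_dir="engine-ws")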
67 | """ 68 | model_type = resolve_model_type(workspace_dir) 69 | if model_type == ModelType.tabular: 70 | from mostlyai.engine._tabular.generation import generate as generate_tabular 71 | 72 | return generate_tabular( 73 | ctx_data=ctx_data, 74 | seed_data=seed_data, 75 | sample_size=sample_size, 76 | batch_size=batch_size, 77 | sampling_temperature=sampling_temperature, 78 | sampling_top_p=sampling_top_p, 79 | rare_category_replacement_method=rare_category_replacement_method, 80 | rebalancing=rebalancing, 81 | imputation=imputation, 82 | fairness=fairness, 83 | device=device, 84 | workspace_dir=workspace_dir, 85 | update_progress=update_progress, 86 | ) 87 | else: 88 | from mostlyai.engine._language.generation import generate as generate_language 89 | 90 | if imputation is not None: 91 | raise ValueError("imputation is not supported for language models") 92 | if fairness is not None: 93 | raise ValueError("fairness is not supported for language models") 94 | if rebalancing is not None: 95 | raise ValueError("rebalancing is not supported for language models") 96 | return generate_language( 97 | ctx_data=ctx_data, 98 | seed_data=seed_data, 99 | sample_size=sample_size, 100 | batch_size=batch_size, 101 | sampling_temperature=sampling_temperature, 102 | sampling_top_p=sampling_top_p, 103 | rare_category_replacement_method=rare_category_replacement_method, 104 | device=device, 105 | workspace_dir=workspace_dir, 106 | update_progress=update_progress, 107 | ) 108 | -------------------------------------------------------------------------------- /mostlyai/engine/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import sys 17 | 18 | _LOG = logging.getLogger(__name__.rsplit(".", 1)[0]) # get the logger with the root module name (mostlyai.engine) 19 | 20 | 21 | def init_logging() -> None: 22 | """ 23 | Initialize the logging configuration to stdout. 24 | """ 25 | 26 | _LOG.propagate = False 27 | if not _LOG.hasHandlers(): 28 | handler = logging.StreamHandler(stream=sys.stdout) 29 | handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)-7s: %(message)s")) 30 | handler.setLevel(logging.INFO) 31 | _LOG.addHandler(handler) 32 | _LOG.setLevel(logging.INFO) 33 | -------------------------------------------------------------------------------- /mostlyai/engine/random_state.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | import random 18 | import struct 19 | 20 | import numpy as np 21 | import torch 22 | 23 | _LOG = logging.getLogger(__name__) 24 | 25 | 26 | def set_random_state(random_state: int | None = None, worker: bool = False): 27 | def get_random_int_from_os() -> int: 28 | # 32-bit, cryptographically secure random int from os 29 | return int(struct.unpack("I", os.urandom(4))[0]) 30 | 31 | if worker: # worker process 32 | if "MOSTLYAI_ENGINE_SEED" in os.environ: 33 | random_state = int(os.environ["MOSTLYAI_ENGINE_SEED"]) 34 | else: 35 | # don't set seed for worker process if not set in main process 36 | return 37 | else: # main process 38 | if random_state is not None: 39 | _LOG.info(f"Global random_state set to `{random_state}`") 40 | 41 | if random_state is None: 42 | random_state = get_random_int_from_os() 43 | 44 | os.environ["MOSTLYAI_ENGINE_SEED"] = str(random_state) 45 | 46 | random.seed(random_state) 47 | np.random.seed(random_state) 48 | torch.manual_seed(random_state) 49 | torch.cuda.manual_seed_all(random_state) 50 | -------------------------------------------------------------------------------- /mostlyai/engine/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from collections.abc import Callable 17 | from pathlib import Path 18 | 19 | import torch 20 | 21 | from mostlyai.engine._common import ProgressCallback 22 | from mostlyai.engine._workspace import resolve_model_type 23 | from mostlyai.engine.domain import DifferentialPrivacyConfig, ModelStateStrategy, ModelType 24 | 25 | 26 | def train( 27 | *, 28 | model: str | None = None, 29 | max_training_time: float | None = 14400.0, # 10 days 30 | max_epochs: float | None = 100.0, # 100 epochs 31 | batch_size: int | None = None, 32 | gradient_accumulation_steps: int | None = None, 33 | enable_flexible_generation: bool = True, 34 | max_sequence_window: int | None = None, 35 | differential_privacy: DifferentialPrivacyConfig | dict | None = None, 36 | model_state_strategy: ModelStateStrategy = ModelStateStrategy.reset, 37 | device: torch.device | str | None = None, 38 | workspace_dir: str | Path = "engine-ws", 39 | update_progress: ProgressCallback | None = None, 40 | upload_model_data_callback: Callable | None = None, 41 | ) -> None: 42 | """ 43 | Trains a model with optional early stopping and differential privacy. 
44 | 45 | Creates the following folder structure within the `workspace_dir`: 46 | 47 | - `ModelStore`: Trained model checkpoints and logs. 48 | 49 | Args: 50 | model: The identifier of the model to train. If tabular, defaults to MOSTLY_AI/Medium. If language, defaults to MOSTLY_AI/LSTMFromScratch-3m. 51 | max_training_time: Maximum training time in minutes. If None, defaults to 10 days. 52 | max_epochs: Maximum number of training epochs. If None, defaults to 100 epochs. 53 | batch_size: Per-device batch size for training and validation. If None, determined automatically. 54 | gradient_accumulation_steps: Number of steps to accumulate gradients. If None, determined automatically. 55 | enable_flexible_generation: Whether to enable flexible order generation. Defaults to True. 56 | max_sequence_window: Maximum sequence window for tabular sequential models. Only applicable for tabular models. 57 | differential_privacy: Configuration for differential privacy training. If None, DP is disabled. 58 | model_state_strategy: Strategy for handling existing model state (reset/resume/reuse). 59 | device: Device to run training on ('cuda' or 'cpu'). Defaults to 'cuda' if available, else 'cpu'. 60 | workspace_dir: Directory path for workspace. Training outputs are stored in ModelStore subdirectory. 61 | update_progress: Callback function to report training progress. 62 | upload_model_data_callback: Callback function to upload model data during training. 63 | """ 64 | model_type = resolve_model_type(workspace_dir) 65 | if model_type == ModelType.tabular: 66 | from mostlyai.engine._tabular.training import train as train_tabular 67 | 68 | args = inspect.signature(train_tabular).parameters 69 | train_tabular( 70 | model=model if model else args["model"].default, 71 | workspace_dir=workspace_dir, 72 | max_training_time=max_training_time if max_training_time else args["max_training_time"].default, 73 | max_epochs=max_epochs if max_epochs else args["max_epochs"].default, 74 | batch_size=batch_size, 75 | gradient_accumulation_steps=gradient_accumulation_steps, 76 | enable_flexible_generation=enable_flexible_generation, 77 | differential_privacy=differential_privacy, 78 | update_progress=update_progress, 79 | upload_model_data_callback=upload_model_data_callback, 80 | model_state_strategy=model_state_strategy, 81 | device=device, 82 | max_sequence_window=max_sequence_window if max_sequence_window else args["max_sequence_window"].default, 83 | ) 84 | else: 85 | from mostlyai.engine._language.training import train as train_language 86 | 87 | if max_sequence_window is not None: 88 | raise ValueError("max_sequence_window is not supported for language models") 89 | 90 | args = inspect.signature(train_language).parameters 91 | train_language( 92 | model=model if model else args["model"].default, 93 | workspace_dir=workspace_dir, 94 | max_training_time=max_training_time if max_training_time else args["max_training_time"].default, 95 | max_epochs=max_epochs if max_epochs else args["max_epochs"].default, 96 | batch_size=batch_size, 97 | gradient_accumulation_steps=gradient_accumulation_steps, 98 | enable_flexible_generation=enable_flexible_generation, 99 | differential_privacy=differential_privacy, 100 | update_progress=update_progress, 101 | upload_model_data_callback=upload_model_data_callback, 102 | model_state_strategy=model_state_strategy, 103 | device=device, 104 | ) 105 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "mostlyai-engine" 3 | version = "1.4.3" 4 | description = "Synthetic Data Engine" 5 | authors = [{ name = "MOSTLY AI", email = "dev@mostly.ai" }] 6 | requires-python = ">=3.10" 7 | readme = "README.md" 8 | license = "Apache-2.0" 9 | classifiers = [ 10 | "Development Status :: 5 - Production/Stable", 11 | "Intended Audience :: Developers", 12 | "Intended Audience :: Science/Research", 13 | "Intended Audience :: Information Technology", 14 | "Intended Audience :: Financial and Insurance Industry", 15 | "Intended Audience :: Healthcare Industry", 16 | "Intended Audience :: Telecommunications Industry", 17 | "Programming Language :: Python :: 3.10", 18 | "Programming Language :: Python :: 3.11", 19 | "Programming Language :: Python :: 3.12", 20 | "Programming Language :: Python :: 3.13", 21 | "License :: OSI Approved :: Apache Software License", 22 | "Operating System :: OS Independent", 23 | "Topic :: Software Development :: Libraries", 24 | "Typing :: Typed", 25 | ] 26 | 27 | dependencies = [ 28 | "setuptools>=77.0.3", # similar to vllm 0.8.5.post1 29 | "numpy>=1.26.3", 30 | "pandas~=2.2.0", 31 | "pyarrow>=16.0.0", 32 | "joblib>=1.4.2", 33 | "psutil>=5.9.5,<6", # upgrade when colab psutil is updated 34 | "tokenizers>=0.21.0", 35 | "transformers>=4.51.0", 36 | "datasets>=3.0.0", 37 | "accelerate>=1.5.0", 38 | "peft>=0.12.0", 39 | "huggingface-hub[hf-xet]>=0.30.2", 40 | "opacus>=1.5.2", # switch to 1.5.4 (once released) to allow numpy 2 41 | "xgrammar>=0.1.18", # for vllm 0.8.5.post1 compatibility 42 | "json-repair>=0.30.0", 43 | "torch>=2.6.0,<2.6.1", 44 | "torchaudio>=2.6.0,<2.6.1", # for vllm 0.8.5.post1 compatibility 45 | "torchvision>=0.21.0,<0.21.1" # for vllm 0.8.5.post1 compatibility 46 | ] 47 | 48 | [project.optional-dependencies] 49 | gpu = [ 50 | "bitsandbytes==0.42.0; sys_platform == 'darwin'", 51 | "bitsandbytes>=0.45.5; sys_platform == 'linux'", 52 | "vllm==0.8.5.post1; sys_platform == 'linux' or sys_platform == 'darwin'", 53 | ] 54 | 55 | [dependency-groups] 56 | dev = [ 57 | "pytest>=8.0", 58 | "ruff>=0.11", # sync'ed with .pre-commit-config 59 | "pre-commit>=4.0", 60 | "twine>=6.1", 61 | "ipykernel>=6.25", 62 | ] 63 | docs = [ 64 | "mkdocs>=1.6", 65 | "mkdocstrings[crystal, python]>=0.29", 66 | "mkdocs-material>=9.0", 67 | "griffe>=1.0", 68 | "pymdown-extensions>=10.0", 69 | "griffe-fieldz>=0.2", 70 | "black>=25.0", 71 | ] 72 | 73 | [project.urls] 74 | homepage = "https://github.com/mostly-ai/mostlyai-engine" 75 | repository = "https://github.com/mostly-ai/mostlyai-engine" 76 | documentation = "https://mostly-ai.github.io/mostlyai-engine/" 77 | 78 | [tool.uv] 79 | default-groups = ["dev", "docs"] 80 | 81 | [tool.hatch.build.targets.sdist] 82 | include = ["mostlyai/engine"] 83 | 84 | [tool.hatch.build.targets.wheel] 85 | include = ["mostlyai/engine"] 86 | 87 | [tool.hatch.metadata] 88 | allow-direct-references = true 89 | 90 | [build-system] 91 | requires = ["hatchling", "hatch-vcs"] 92 | build-backend = "hatchling.build" 93 | 94 | [tool.ruff] 95 | target-version = "py310" 96 | line-length = 120 97 | 98 | [tool.ruff.lint] 99 | extend-select = ["I"] 100 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file 
except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/end_to_end/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/end_to_end/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import random 16 | 17 | import numpy as np 18 | import pandas as pd 19 | import pytest 20 | 21 | from mostlyai.engine._common import STRING 22 | 23 | 24 | @pytest.fixture 25 | def set_random_seed(): 26 | random.seed(0) 27 | np.random.seed(0) 28 | 29 | 30 | class MockData: 31 | def __init__(self, n_samples: int): 32 | self.n_samples = n_samples 33 | self.df = pd.DataFrame(index=range(self.n_samples)) 34 | 35 | def add_index_column(self, name: str): 36 | values = pd.DataFrame({name: range(len(self.df))}).astype(STRING) 37 | self.df = pd.concat([self.df, values], axis=1) 38 | 39 | def add_categorical_column( 40 | self, name: str, probabilities: dict[str, float], rare_categories: list[str] | None = None 41 | ): 42 | values = np.random.choice( 43 | list(probabilities.keys()), 44 | size=len(self.df), 45 | p=list(probabilities.values()), 46 | ) 47 | self.df = pd.concat([self.df, pd.DataFrame({name: values})], axis=1) 48 | if rare_categories: 49 | self.df.loc[np.random.choice(self.df.index, len(rare_categories), replace=False), name] = rare_categories 50 | 51 | def add_numeric_column(self, name: str, quantiles: dict[float, float], dtype: str = "float32"): 52 | uniform_samples = np.random.rand(len(self.df)) 53 | values = np.interp(uniform_samples, list(quantiles.keys()), list(quantiles.values())).astype(dtype) 54 | self.df = pd.concat([self.df, pd.DataFrame({name: values})], axis=1) 55 | 56 | def add_datetime_column(self, name: str, start_date: str, end_date: str, freq: str = "s"): 57 | date_range = pd.date_range(start=start_date, end=end_date, freq=freq) 58 | values = np.random.choice(date_range, len(self.df), replace=True) 59 | self.df = pd.concat([self.df, pd.DataFrame({name: values})], axis=1) 60 | 61 | def add_date_column(self, name: str, start_date: str, end_date: str): 62 | self.add_datetime_column(name, start_date, end_date, freq="D") 63 | 64 | def add_lat_long_column(self, name: str, lat_limit: tuple[float, float], long_limit: tuple[float, float]): 65 | latitude = np.random.uniform(lat_limit[0], lat_limit[1], len(self.df)) 66 | longitude = np.random.uniform(long_limit[0], long_limit[1], len(self.df)) 67 | values = [f"{lat:.4f}, {long:.4f}" for lat, long in zip(latitude, longitude)] 68 | self.df = pd.concat([self.df, pd.DataFrame({name: values})], axis=1) 69 | 70 | def add_sequential_column(self, name: str, seq_len_quantiles: dict[float, float]): 71 | self.add_numeric_column("seq_len", seq_len_quantiles, dtype="int32") 72 | # if seq_len is 3, it will populate a sequence ["0", "1", "2"] and then explode the list to 3 rows 73 | self.df[name] = self.df["seq_len"].apply(lambda x: [str(i) for i in range(x)]) 74 | self.df = self.df.explode(name).drop(columns="seq_len").reset_index(drop=True) 75 | -------------------------------------------------------------------------------- /tests/end_to_end/test_numeric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import shutil 16 | from pathlib import Path 17 | 18 | import numpy as np 19 | import pandas as pd 20 | import pytest 21 | 22 | from mostlyai.engine import analyze, encode 23 | from mostlyai.engine._common import write_json 24 | from mostlyai.engine._tabular.generation import generate 25 | from mostlyai.engine._tabular.training import train 26 | from mostlyai.engine.domain import ModelEncodingType 27 | 28 | 29 | @pytest.fixture 30 | def sum_df(): 31 | df = pd.DataFrame( 32 | { 33 | "n": np.random.randint(0, 98, size=1000) * 100, 34 | "m": np.random.randint(0, 98, size=1000), 35 | } 36 | ) 37 | 38 | # Create column 'o' as sum of 'n' and 'm' 39 | df["o"] = df["n"] + df["m"] 40 | 41 | yield df 42 | 43 | 44 | @pytest.fixture 45 | def product_df(): 46 | # Generate 1000 uniformly distributed random prices between 0 and 300 47 | prices = np.random.uniform(0, 300, 1000).astype(int) 48 | 49 | # Replace the decimal part with one of the following: 00, 05, 50, 95, 99. 50 | decimals = np.random.choice([0, 0.05, 0.5, 0.95, 0.99], 1000) 51 | prices = prices.astype(int) + decimals 52 | 53 | # Create a DataFrame 54 | df = pd.DataFrame( 55 | { 56 | "price": prices, 57 | } 58 | ) 59 | 60 | yield df 61 | 62 | 63 | def prepare_ws(tmp_path: Path, df: pd.DataFrame, keys: dict, encoding_types: dict) -> Path: 64 | workspace_dir = tmp_path / "ws" 65 | shutil.rmtree(workspace_dir, ignore_errors=True) # cleanup 66 | tgt_meta_path = workspace_dir / "OriginalData" / "tgt-meta" 67 | tgt_data_path = workspace_dir / "OriginalData" / "tgt-data" 68 | for path in [ 69 | workspace_dir, 70 | tgt_meta_path, 71 | tgt_data_path, 72 | ]: 73 | path.mkdir(exist_ok=True, parents=True) 74 | 75 | df.to_parquet(tgt_data_path / "part.000000-trn.parquet") 76 | write_json(keys, tgt_meta_path / "keys.json") 77 | write_json(encoding_types, tgt_meta_path / "encoding-types.json") 78 | 79 | return workspace_dir 80 | 81 | 82 | def synthetize(ws_dir: Path) -> pd.DataFrame: 83 | analyze(workspace_dir=ws_dir) 84 | encode(workspace_dir=ws_dir) 85 | train(max_epochs=5, workspace_dir=ws_dir) 86 | generate(workspace_dir=ws_dir) 87 | syn_data_path = ws_dir / "SyntheticData" 88 | syn = pd.read_parquet(syn_data_path) 89 | 90 | return syn 91 | 92 | 93 | def compare_numeric_encodings( 94 | tmp_path, 95 | df, 96 | numeric_cols, 97 | first=ModelEncodingType.tabular_numeric_auto, 98 | second=ModelEncodingType.tabular_numeric_digit, 99 | ): 100 | syn = [] 101 | for numeric_encoding in [first, second]: 102 | ws = prepare_ws( 103 | tmp_path=tmp_path, 104 | df=df, 105 | keys={}, 106 | encoding_types={k: numeric_encoding.value for k in numeric_cols}, 107 | ) 108 | syn.append(synthetize(ws)) 109 | 110 | return syn[0], syn[1] 111 | 112 | 113 | def test_numeric_sum_quality(tmp_path, sum_df): 114 | sum_syn_auto, sum_syn_digit = compare_numeric_encodings(tmp_path=tmp_path, df=sum_df, numeric_cols=["n", "m", "o"]) 115 | 116 | assert sum_syn_auto.shape == sum_syn_digit.shape 117 | 118 | def calculate_sum_square_errors(df: pd.DataFrame, expected: str, actual: str): 119 | # Calculate the squares of the % errors 120 | squared_error = np.square((df[actual] - df[expected]) / df[actual]) 121 | return np.sum(squared_error) 122 | 123 | sum_syn_auto["expected"] = sum_syn_auto["n"] + sum_syn_auto["m"] 124 | sum_syn_auto_errors = calculate_sum_square_errors(df=sum_syn_auto, expected="expected", actual="o") 125 | sum_syn_digit["expected"] = sum_syn_digit["n"] + 
sum_syn_digit["m"] 126 | sum_syn_digit_errors = calculate_sum_square_errors(df=sum_syn_digit, expected="expected", actual="o") 127 | 128 | # ensure the quality is reasonable 129 | assert sum_syn_auto_errors / sum_syn_digit_errors < 10 130 | 131 | 132 | def test_numeric_price_quality(tmp_path, product_df): 133 | prod_syn_auto, prod_syn_digit = compare_numeric_encodings(tmp_path=tmp_path, df=product_df, numeric_cols=["price"]) 134 | 135 | assert prod_syn_auto.shape == prod_syn_digit.shape 136 | 137 | def similar_quantiles(ser_first, ser_second, threshold=0.05) -> bool: 138 | quantiles = [0.25, 0.5, 0.75, 1] 139 | q_first = ser_first.quantile(quantiles) 140 | q_second = ser_second.quantile(quantiles) 141 | return bool(all(np.abs((q_first - q_second) / ((q_first + q_second) / 2)) <= threshold)) 142 | 143 | assert similar_quantiles(product_df, prod_syn_auto) 144 | assert similar_quantiles(product_df, prod_syn_digit) 145 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/test_categorical.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | from mostlyai.engine._encoding_types.language.categorical import ( 20 | CATEGORICAL_UNKNOWN_TOKEN, 21 | analyze_language_categorical, 22 | analyze_reduce_language_categorical, 23 | decode_language_categorical, 24 | encode_language_categorical, 25 | ) 26 | 27 | 28 | class TestLanguageCategoricalAnalyze: 29 | def test_3_frequent_and_1_rare_values(self): 30 | values = pd.Series(np.repeat(["secret", "male", "female", pd.NA], 100), name="gender") 31 | ids = pd.Series( 32 | np.concatenate([np.repeat(0, 100), range(100), range(100, 200), range(200, 300)]), 33 | name="subject_id", 34 | ) 35 | stats = analyze_language_categorical(values, ids) 36 | assert stats == { 37 | "cnt_values": {"female": 100, "male": 100, "secret": 1}, 38 | "has_nan": True, 39 | } 40 | 41 | 42 | class TestLanguageCategoricalAnalyzeReduce: 43 | @pytest.fixture 44 | def stats_list(self): 45 | stats1 = { 46 | "cnt_values": {"secret1": 1, "male": 100}, 47 | "has_nan": True, 48 | } 49 | stats2 = { 50 | "cnt_values": {"secret2": 1, "male": 100, "female": 100}, 51 | "has_nan": False, 52 | } 53 | return stats1, stats2 54 | 55 | def test_with_value_protection(self, stats_list): 56 | stats1, stats2 = stats_list 57 | stats = analyze_reduce_language_categorical([stats1, stats2], value_protection=True) 58 | assert stats == { 59 | "categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "female", "male"], 60 | "no_of_rare_categories": 2, 61 | } 62 | 63 | 64 | class TestLanguageCategoricalEncode: 65 | def test_2_frequent_and_1_rare_and_1_null_values(self): 66 | values = pd.Series(np.repeat(["secret", "male", "female", pd.NA], 100), name="gender") 67 | stats = { 68 | "categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "female", "male"], 69 | "no_of_rare_categories": 1, 70 | } 71 | expected = pd.Series( 72 | np.repeat([CATEGORICAL_UNKNOWN_TOKEN, "male", "female", pd.NA], 100), name="gender", dtype="string" 73 | ) 74 | encoded = encode_language_categorical(values, stats) 75 | pd.testing.assert_series_equal(encoded, expected) 76 | 77 | 78 | class TestLanguageCategoricalDecode: 79 | @pytest.fixture 80 | def col_stats(self): 81 | return {"categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "apple", "banana", "cherry"]} 82 | 83 | @pytest.fixture 84 | def sample_values(self): 85 | return pd.Series(["apple", "durian", "banana", "elderberry", "cherry", "fig", None]) 86 | 87 | def test_language_categorical_decode(self, sample_values, col_stats): 88 | decoded = decode_language_categorical(sample_values, col_stats) 89 | expected = pd.Series(["apple", None, "banana", None, "cherry", None, None], dtype=decoded.dtype) 90 | 
pd.testing.assert_series_equal(decoded, expected) 91 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/test_datetime.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pandas as pd 16 | import pytest 17 | 18 | from mostlyai.engine._common import ANALYZE_MIN_MAX_TOP_N 19 | from mostlyai.engine._encoding_types.language.datetime import ( 20 | analyze_language_datetime, 21 | analyze_reduce_language_datetime, 22 | decode_language_datetime, 23 | encode_language_datetime, 24 | ) 25 | from mostlyai.engine.domain import ModelEncodingType 26 | 27 | 28 | class TestLanguageDatetimeAnalyze: 29 | def test_analyze_language_datetime(self): 30 | birth_dates = pd.Series( 31 | [ 32 | "1910-01-01", 33 | "", 34 | "1930-01-31", 35 | "1940-02-12", 36 | "", 37 | "1971-09-01", 38 | "1983-05-19", 39 | "1998-05-24", 40 | ] 41 | * ANALYZE_MIN_MAX_TOP_N, 42 | name="birth_date", 43 | ) 44 | keys = pd.Series(range(len(birth_dates)), name="id") 45 | stats = analyze_language_datetime(birth_dates, keys) 46 | assert stats["has_nan"] is True 47 | assert stats["min_n"] == ["1910-01-01"] * ANALYZE_MIN_MAX_TOP_N 48 | assert stats["max_n"] == ["1998-05-24"] * ANALYZE_MIN_MAX_TOP_N 49 | 50 | 51 | class TestLanguageDatetimeAnalyzeReduce: 52 | def test_analyze_reduce_language_datetime(self): 53 | stats1 = { 54 | "has_nan": True, 55 | "min_n": ["1910-01-01"] * ANALYZE_MIN_MAX_TOP_N, 56 | "max_n": ["1998-05-24"] * ANALYZE_MIN_MAX_TOP_N, 57 | } 58 | stats2 = { 59 | "has_nan": False, 60 | "min_n": ["2000-01-01"] * ANALYZE_MIN_MAX_TOP_N, 61 | "max_n": ["2024-12-31"] * ANALYZE_MIN_MAX_TOP_N, 62 | } 63 | reduced = analyze_reduce_language_datetime([stats1, stats2]) 64 | assert reduced["has_nan"] is True 65 | assert reduced["min"] == "1910-01-01" 66 | assert reduced["max"] == "2024-12-31" 67 | 68 | 69 | class TestLanguageDatetimeEncode: 70 | def test_encode_language_datetime(self): 71 | values = pd.Series( 72 | [ 73 | "1910-01-01", 74 | "", 75 | "1930-01-31", 76 | "1940-02-12", 77 | "", 78 | "1971-09-01", 79 | "1983-05-19", 80 | "1998-05-24", 81 | ], 82 | name="birth_date", 83 | ) 84 | stats = { 85 | "has_nan": True, 86 | "min": "1930-01-31", 87 | "max": "2024-12-31", 88 | } 89 | encoded = encode_language_datetime(values, stats) 90 | assert encoded.dtype == "datetime64[us]" 91 | assert encoded.isna().sum() == 2 92 | assert encoded.iloc[0] == pd.Timestamp("1930-01-31") 93 | assert encoded.iloc[1] is pd.NaT 94 | assert encoded.iloc[2] == pd.Timestamp("1930-01-31") 95 | assert encoded.iloc[3] == pd.Timestamp("1940-02-12") 96 | assert encoded.iloc[4] is pd.NaT 97 | assert encoded.iloc[5] == pd.Timestamp("1971-09-01") 98 | assert encoded.iloc[6] == pd.Timestamp("1983-05-19") 99 | 100 | 101 | class TestLanguageDatetimeDecode: 102 | @pytest.fixture 103 | def datetime_stats(self): 104 | return { 105 | "encoding_type": 
ModelEncodingType.language_datetime, 106 | "has_nan": True, 107 | "min": "2000-01-01", 108 | "max": "2024-12-31", 109 | } 110 | 111 | @pytest.fixture 112 | def no_clip_stats(self): 113 | return { 114 | "encoding_type": ModelEncodingType.language_datetime, 115 | "has_nan": True, 116 | "min": "1900-01-01", 117 | "max": "2100-01-01", 118 | } 119 | 120 | @pytest.fixture 121 | def sample_dates(self): 122 | return pd.Series( 123 | [ 124 | "2021-05-20 14:30:00", # valid datetime with time 125 | "2020-02-30", # Feb 30 is invalid; should be clamped to Feb 29, 2020 126 | "1999-12-31", # below the min bound -> will be clipped upward 127 | "2025-01-01", # above the max bound -> will be clipped downward 128 | "abcd", # invalid date string -> becomes NaT 129 | "", # empty string -> becomes NaT 130 | "_INVALID_", # marked as invalid -> becomes NaT 131 | "2010-10-10", # valid date without explicit time (defaults to 00:00:00) 132 | ] 133 | ) 134 | 135 | def test_datetime_dtype_bounds_and_invalids(self, sample_dates, datetime_stats): 136 | decoded = decode_language_datetime(sample_dates, datetime_stats) 137 | assert decoded.dtype == "datetime64[ns]" 138 | non_null = decoded.dropna() 139 | min_bound = pd.to_datetime(datetime_stats["min"]) 140 | max_bound = pd.to_datetime(datetime_stats["max"]) 141 | for dt in non_null: 142 | assert dt >= min_bound 143 | assert dt <= max_bound 144 | assert all(pd.isna(decoded.iloc[4:7])) 145 | 146 | def test_date_day_clamping(self, no_clip_stats): 147 | s = pd.Series(["2021-04-31"]) 148 | decoded = decode_language_datetime(s, no_clip_stats) 149 | expected = pd.Timestamp("2021-04-30 00:00:00") 150 | assert decoded.iloc[0] == expected 151 | 152 | def test_time_extraction(self, no_clip_stats): 153 | s = pd.Series(["2021-07-15T23:59:59.123"]) 154 | decoded = decode_language_datetime(s, no_clip_stats) 155 | expected = pd.Timestamp("2021-07-15 23:59:59.123") 156 | assert decoded.iloc[0] == expected 157 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/test_numeric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | from mostlyai.engine._common import ANALYZE_MIN_MAX_TOP_N 20 | from mostlyai.engine._encoding_types.language.numeric import ( 21 | analyze_language_numeric, 22 | analyze_reduce_language_numeric, 23 | decode_language_numeric, 24 | encode_language_numeric, 25 | ) 26 | from mostlyai.engine.domain import ModelEncodingType 27 | 28 | 29 | class TestLanguageNumericAnalyze: 30 | def test_analyze_language_numeric(self): 31 | values = pd.Series([0, 1, 2, 3, 4, 5] * ANALYZE_MIN_MAX_TOP_N, name="value") 32 | ids = pd.Series(range(len(values)), name="id") 33 | stats = analyze_language_numeric(values, ids) 34 | assert stats["has_nan"] is False 35 | assert stats["max_n"] == [5] * ANALYZE_MIN_MAX_TOP_N 36 | assert stats["min_n"] == [0] * ANALYZE_MIN_MAX_TOP_N 37 | 38 | 39 | class TestLanguageNumericAnalyzeReduce: 40 | def test_analyze_reduce_language_numeric(self): 41 | stats1 = { 42 | "has_nan": False, 43 | "max_n": [5] * ANALYZE_MIN_MAX_TOP_N, 44 | "min_n": [0] * ANALYZE_MIN_MAX_TOP_N, 45 | "max_scale": 0, 46 | } 47 | stats2 = { 48 | "has_nan": True, 49 | "max_n": [10] * ANALYZE_MIN_MAX_TOP_N, 50 | "min_n": [6] * ANALYZE_MIN_MAX_TOP_N, 51 | "max_scale": 1, 52 | } 53 | reduced = analyze_reduce_language_numeric([stats1, stats2]) 54 | assert reduced["has_nan"] is True 55 | assert reduced["max"] == 10 56 | assert reduced["min"] == 0 57 | assert reduced["max_scale"] == 1 58 | 59 | 60 | class TestLanguageNumericEncode: 61 | def test_encode_language_numeric(self): 62 | values = pd.Series([-1, 0, 1, 2, 3, 4, 5, 6], name="value") 63 | stats = { 64 | "has_nan": False, 65 | "max": 5, 66 | "min": 0, 67 | "max_scale": 0, 68 | } 69 | encoded = encode_language_numeric(values, stats) 70 | assert encoded.dtype == "Int64" 71 | assert encoded.isna().sum() == 0 72 | assert encoded.iloc[0] == 0 73 | assert encoded.iloc[1] == 0 74 | assert encoded.iloc[2] == 1 75 | assert encoded.iloc[3] == 2 76 | assert encoded.iloc[4] == 3 77 | assert encoded.iloc[5] == 4 78 | assert encoded.iloc[6] == 5 79 | assert encoded.iloc[7] == 5 80 | 81 | 82 | class TestLanguageNumericDecode: 83 | @pytest.fixture 84 | def int_stats(self): 85 | return { 86 | "encoding_type": ModelEncodingType.language_numeric, 87 | "has_nan": False, 88 | "max": 91, 89 | "max_scale": 0, 90 | "min": 17, 91 | } 92 | 93 | @pytest.fixture 94 | def float_stats(self): 95 | return { 96 | "encoding_type": ModelEncodingType.language_numeric, 97 | "has_nan": False, 98 | "max": 91.12, 99 | "max_scale": 2, 100 | "min": 17.0, 101 | } 102 | 103 | @pytest.fixture 104 | def sample_values(self): 105 | return pd.Series(["25.3541", "99.99", "-312.0", "61", None, "35.10091", "-1.223"]) 106 | 107 | @pytest.mark.parametrize( 108 | "stats_name, expected_dtype", 109 | [ 110 | ("int_stats", "Int64"), 111 | ("float_stats", float), 112 | ], 113 | ) 114 | def test_decode_language_numeric(self, sample_values, request, stats_name, expected_dtype): 115 | stats = request.getfixturevalue(stats_name) 116 | decoded = decode_language_numeric(sample_values, stats) 117 | assert decoded.dtype == expected_dtype 118 | non_null = decoded.dropna() # we don't enforce compatability with "has_nan" 119 | round_digits = stats["max_scale"] 120 | for v in non_null: 121 | assert np.isclose(v, round(v, round_digits), atol=1e-8) 122 | assert all(non_null <= stats["max"]) 123 | assert all(non_null >= stats["min"]) 124 | -------------------------------------------------------------------------------- 
/tests/unit/encoding_types/tabular/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/tabular/test_character.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pandas as pd 17 | 18 | from mostlyai.engine._common import read_json, write_json 19 | from mostlyai.engine._encoding_types.tabular.character import ( 20 | MAX_LENGTH_CHARS, 21 | analyze_character, 22 | analyze_reduce_character, 23 | decode_character, 24 | encode_character, 25 | ) 26 | 27 | 28 | def test_character(tmp_path): 29 | # create sequence of common strings, with some of those being overly long 30 | vals = np.repeat(["word", "sentence", "_".join("too_long") * 10], 100) 31 | # inject canaries, to then check whether those tokens are suppressed 32 | canary = "§§§" 33 | no_of_canaries = 3 34 | values1 = pd.Series([canary] * no_of_canaries + list(vals), name="chars") 35 | ids1 = pd.Series(np.arange(len(values1)), name="subject_id") 36 | # create sequence of common strings, with some of those missing 37 | values2 = pd.Series([pd.NA, "random_word", pd.NA] * 100, name="chars") 38 | ids2 = pd.Series(np.arange(len(values2)), name="subject_id") 39 | unseen_values = pd.Series(["a_sentence", "new_word"], name="chars") 40 | 41 | stats1 = analyze_character(values1, ids1) 42 | stats2 = analyze_character(values2, ids2) 43 | assert stats1["max_string_length"] == MAX_LENGTH_CHARS 44 | assert len(stats1["characters"]) == MAX_LENGTH_CHARS 45 | assert stats2["max_string_length"] == values2.str.len().max() 46 | assert len(stats2["characters"]) == values2.str.len().max() 47 | write_json(stats1, tmp_path / "stats1.json") 48 | write_json(stats2, tmp_path / "stats2.json") 49 | 50 | stats1 = read_json(tmp_path / "stats1.json") 51 | stats2 = read_json(tmp_path / "stats2.json") 52 | stats = analyze_reduce_character([stats1, stats2]) 53 | assert len(stats["codes"]) == MAX_LENGTH_CHARS 54 | # check that those rare characters don't occur in any vocabulary set 55 | for p in stats["codes"]: 56 | assert "§" not in stats["codes"][p] 57 | write_json(stats, tmp_path / "stats.json") 58 | 59 | stats = 
read_json(tmp_path / "stats.json") 60 | encoded1 = encode_character(values1, stats) 61 | decoded1 = decode_character(encoded1, stats) 62 | assert decoded1[no_of_canaries:].equals(values1[no_of_canaries:].str.slice(stop=MAX_LENGTH_CHARS)) 63 | encoded2 = encode_character(values2, stats) 64 | decoded2 = decode_character(encoded2, stats) 65 | assert decoded2.equals(values2.str.slice(stop=MAX_LENGTH_CHARS)) 66 | 67 | unseen_encoded = encode_character(unseen_values, stats) 68 | assert all(unseen_encoded.drop("nan", axis=1).values.flatten() >= 0) 69 | 70 | 71 | def test_character_empty(): 72 | values = pd.Series([None, None, None], name="value") 73 | ids = pd.Series(np.arange(len(values)), name="subject_id") 74 | stats = analyze_reduce_character([analyze_character(values, ids)]) 75 | df_encoded = encode_character(values, stats) 76 | df_decoded = decode_character(df_encoded, stats) 77 | assert all(df_decoded.isna()) 78 | 79 | values = pd.Series(["hello", None, None], name="value") 80 | df_encoded = encode_character(values, stats) 81 | df_decoded = decode_character(df_encoded, stats) 82 | assert all(df_decoded.isna()) 83 | 84 | # no values at all 85 | values = pd.Series([], name="value") 86 | ids = pd.Series(np.arange(len(values)), name="subject_id") 87 | partition_stats = analyze_character(values, ids) 88 | stats = analyze_reduce_character([partition_stats]) 89 | df_encoded = encode_character(values, stats) 90 | df_decoded = decode_character(df_encoded, stats) 91 | assert partition_stats == { 92 | "characters": {}, 93 | "has_nan": False, 94 | "max_string_length": 0, 95 | } 96 | assert stats == { 97 | "cardinalities": {}, 98 | "codes": {}, 99 | "has_nan": False, 100 | "max_string_length": 0, 101 | } 102 | assert df_encoded.empty, df_encoded.columns.tolist() == (True, []) 103 | assert df_decoded.empty, df_encoded.columns.tolist() == (True, []) 104 | 105 | 106 | def test_character_noempties(): 107 | values = pd.Series(["hello", "world", "!"], name="value") 108 | ids = pd.Series(np.arange(len(values)), name="subject_id") 109 | stats = analyze_reduce_character([analyze_character(values, ids)]) 110 | values = pd.Series([None, None, None], name="value") 111 | df_encoded = encode_character(values, stats) 112 | df_decoded = decode_character(df_encoded, stats) 113 | assert df_decoded.size == values.size 114 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/tabular/test_datetime.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import read_json, write_json 18 | from mostlyai.engine._encoding_types.tabular.datetime import ( 19 | analyze_datetime, 20 | analyze_reduce_datetime, 21 | decode_datetime, 22 | encode_datetime, 23 | split_sub_columns_datetime, 24 | ) 25 | 26 | 27 | def test_datetime(tmp_path): 28 | s1 = pd.Series( 29 | [ 30 | "1910-01-01", 31 | "", 32 | "1930-01-31", 33 | "1940-02-12", 34 | "", 35 | "1971-09-01", 36 | "1983-05-19", 37 | "1998-05-24", 38 | ], 39 | name="birth_date", 40 | ) 41 | i1 = pd.Series([1, 2, 3, 4, 5, 6, 7, 8], name="id") 42 | s2 = pd.Series( 43 | [ 44 | "1912-01-01", 45 | "", 46 | "1932-01-31", 47 | "1942-02-12", 48 | "", 49 | "1972-09-01", 50 | "1984-05-19", 51 | "1994-05-24", 52 | ], 53 | name="birth_date", 54 | ) 55 | i2 = pd.Series([11, 12, 13, 14, 15, 16, 17], name="id") 56 | 57 | stats1 = analyze_datetime(s1, i1) 58 | stats2 = analyze_datetime(s2, i2) 59 | write_json(stats1, tmp_path / "stats1.json") 60 | write_json(stats2, tmp_path / "stats2.json") 61 | 62 | stats = analyze_reduce_datetime([stats1, stats2], value_protection=False) 63 | write_json(stats, tmp_path / "stats.json") 64 | 65 | stats = read_json(tmp_path / "stats.json") 66 | df_encoded = encode_datetime(s1, stats) 67 | df_decoded = decode_datetime(df_encoded, stats) 68 | assert pd.to_datetime(s1).astype("datetime64[ns]").equals(df_decoded) 69 | 70 | 71 | def test_datetime_empty(tmp_path): 72 | values = pd.to_datetime(pd.Series([pd.NaT, pd.NaT, pd.NaT], name="value")).astype("datetime64[ns]") 73 | root_keys = pd.Series(range(len(values)), name="id") 74 | stats = analyze_reduce_datetime([analyze_datetime(values, root_keys)], value_protection=False) 75 | df_encoded = encode_datetime(values, stats) 76 | df_decoded = decode_datetime(df_encoded, stats) 77 | assert values.equals(df_decoded) 78 | assert all(df_decoded.isna()) 79 | 80 | values = pd.to_datetime(pd.Series(["2020-05-24", pd.NaT, pd.NaT], name="value")) 81 | df_encoded = encode_datetime(values, stats) 82 | df_decoded = decode_datetime(df_encoded, stats) 83 | assert all(df_decoded.isna()) 84 | 85 | # no values at all 86 | values = pd.to_datetime(pd.Series([], name="value")) 87 | root_keys = pd.Series(range(len(values)), name="id") 88 | partition_stats = analyze_datetime(values, root_keys) 89 | stats = analyze_reduce_datetime([partition_stats]) 90 | df_encoded = encode_datetime(values, stats) 91 | df_decoded = decode_datetime(df_encoded, stats) 92 | min_max_values = { 93 | "day": 1, 94 | "hour": 0, 95 | "minute": 0, 96 | "month": 1, 97 | "ms_E0": 0, 98 | "ms_E1": 0, 99 | "ms_E2": 0, 100 | "second": 0, 101 | "year": 2022, 102 | } 103 | assert partition_stats == { 104 | "has_nan": False, 105 | "max_n": [], 106 | "max_values": min_max_values, 107 | "min_n": [], 108 | "min_values": min_max_values, 109 | "log_hist": [0.0] * 128, 110 | } 111 | assert stats == { 112 | "cardinalities": {"day": 1, "month": 1, "year": 1}, 113 | "has_ms": False, 114 | "has_nan": False, 115 | "has_time": False, 116 | "max": None, 117 | "max_values": min_max_values, 118 | "min": None, 119 | "min_values": min_max_values, 120 | } 121 | assert df_encoded.empty, df_encoded.columns.tolist() == (True, []) 122 | assert df_decoded.empty, df_encoded.columns.tolist() == (True, []) 123 | 124 | 125 | def test_datetime_noempties(tmp_path): 126 | values = pd.to_datetime(pd.Series(["2020-05-24", "2021-05-24", "2022-05-24"], name="value")) 127 | root_keys = pd.Series(range(len(values)), name="id") 128 | stats = 
analyze_reduce_datetime([analyze_datetime(values, root_keys)], value_protection=False) 129 | values = pd.to_datetime(pd.Series([pd.NaT, pd.NaT, pd.NaT], name="value")) 130 | df_encoded = encode_datetime(values, stats) 131 | df_decoded = decode_datetime(df_encoded, stats) 132 | assert all(df_decoded.notna()) 133 | 134 | 135 | def test_datetime_min_max_overlapping(): 136 | root_keys = pd.Series(list(range(100)), name="id") 137 | values = pd.Series([pd.to_datetime(f"01-01-{2000 + y}") for y in range(100)], name="value") 138 | stats = analyze_reduce_datetime([analyze_datetime(values, root_keys)]) 139 | for pos, card in stats["cardinalities"].items(): 140 | assert card > 0 141 | 142 | 143 | def test_split_sub_columns_datetime(): 144 | values = pd.Series([pd.to_datetime("2020-01-01"), pd.NaT], name="dt", index=[1, 1]) 145 | df = split_sub_columns_datetime(values) 146 | cols = [ 147 | "nan", 148 | "year", 149 | "month", 150 | "day", 151 | "hour", 152 | "minute", 153 | "second", 154 | "ms_E2", 155 | "ms_E1", 156 | "ms_E0", 157 | ] 158 | vals = [ 159 | [0, 2020, 1, 1, 0, 0, 0, 0, 0, 0], 160 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 161 | ] 162 | pd.testing.assert_frame_equal(df, pd.DataFrame(vals, columns=cols)) 163 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/tabular/test_itt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import read_json, write_json 18 | from mostlyai.engine._encoding_types.tabular.itt import ( 19 | analyze_itt, 20 | analyze_reduce_itt, 21 | decode_itt, 22 | encode_itt, 23 | ) 24 | 25 | 26 | def test_itt_date(tmp_path): 27 | values = pd.to_datetime( 28 | pd.Series( 29 | [None, "1978-05-24", "1976-06-22", "1992-12-24", None], 30 | name="date", 31 | dtype="datetime64[us]", 32 | ) 33 | ) 34 | context_keys = pd.Series(["a", "a", "a", "b", "c"], name="__context_key") 35 | root_keys = context_keys.copy() 36 | root_keys.name = "__root_key" 37 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 38 | write_json(stats1, tmp_path / "stats1.json") 39 | stats1 = read_json(tmp_path / "stats1.json") 40 | stats = analyze_reduce_itt([stats1], value_protection=False) 41 | write_json(stats, tmp_path / "stats.json") 42 | stats = read_json(tmp_path / "stats.json") 43 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 44 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys) 45 | assert values.equals(df_decoded) 46 | 47 | 48 | def test_itt_datetime(tmp_path): 49 | values = pd.to_datetime( 50 | pd.Series( 51 | [ 52 | None, 53 | "1978-05-24 12:23:43", 54 | "1976-06-22 17:32:00", 55 | "1992-12-24 01:32:59", 56 | None, 57 | ], 58 | name="date", 59 | dtype="datetime64[us]", 60 | ) 61 | ) 62 | context_keys = pd.Series(["a", "a", "a", "b", "c"], name="__context_key") 63 | root_keys = context_keys.copy() 64 | root_keys.name = "__root_key" 65 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 66 | stats = analyze_reduce_itt([stats1], value_protection=False) 67 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 68 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys) 69 | assert values.equals(df_decoded) 70 | 71 | 72 | def test_itt_nones_only(tmp_path): 73 | values = pd.to_datetime(pd.Series([None, None, None], name="value", dtype="datetime64[us]")) 74 | context_keys = pd.Series(["a", "a", "b"], name="id") 75 | root_keys = pd.Series(["a", "a", "b"], name="rid") 76 | stats = analyze_reduce_itt([analyze_itt(values, root_keys, context_keys)], value_protection=False) 77 | df_encoded = encode_itt(values, stats, context_keys) 78 | df_decoded = decode_itt(df_encoded, stats, context_keys) 79 | assert all(df_decoded.isna()) 80 | 81 | 82 | def test_itt_empty(tmp_path): 83 | values = pd.Series([], name="value") 84 | root_keys = pd.Series([], name="rid") 85 | context_keys = pd.Series([], name="id") 86 | partition_stats = analyze_itt(values, root_keys, context_keys) 87 | stats = analyze_reduce_itt([partition_stats]) 88 | df_encoded = encode_itt(values, stats, context_keys) 89 | df_decoded = decode_itt(df_encoded, stats, context_keys) 90 | min_max_values = { 91 | "itt_day": 0, 92 | "itt_hour": 0, 93 | "itt_minute": 0, 94 | "itt_second": 0, 95 | "itt_week": 0, 96 | "start_day": 1, 97 | "start_hour": 0, 98 | "start_minute": 0, 99 | "start_month": 1, 100 | "start_second": 0, 101 | "start_year": 2022, 102 | } 103 | assert partition_stats == { 104 | "has_nan": False, 105 | "has_neg": False, 106 | "max_n": [], 107 | "max_values": min_max_values, 108 | "min_n": [], 109 | "min_values": min_max_values, 110 | "log_hist": [0.0] * 128, 111 | } 112 | assert stats == { 113 | "cardinalities": { 114 | "itt_day": 1, 115 | "itt_week": 1, 116 | "start_day": 1, 117 | "start_month": 1, 118 | 
"start_year": 1, 119 | }, 120 | "has_nan": False, 121 | "has_neg": False, 122 | "has_time": False, 123 | "max": None, 124 | "max_values": min_max_values, 125 | "min": None, 126 | "min_values": min_max_values, 127 | } 128 | assert df_encoded.empty, df_encoded.columns.tolist() == (True, []) 129 | assert df_decoded.empty, df_encoded.columns.tolist() == (True, []) 130 | 131 | 132 | def test_itt_1to1(tmp_path): 133 | values = pd.to_datetime( 134 | pd.Series( 135 | [None, "1978-05-24", "1976-06-22", "1992-12-24", None], 136 | name="date", 137 | dtype="datetime64[us]", 138 | ) 139 | ) 140 | context_keys = pd.Series(["a", "b", "c", "d", "e"], name="__context_key") 141 | root_keys = context_keys.copy() 142 | root_keys.name = "__root_key" 143 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 144 | write_json(stats1, tmp_path / "stats1.json") 145 | stats1 = read_json(tmp_path / "stats1.json") 146 | stats = analyze_reduce_itt([stats1], value_protection=False) 147 | write_json(stats, tmp_path / "stats.json") 148 | stats = read_json(tmp_path / "stats.json") 149 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 150 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys) 151 | assert values.equals(df_decoded) 152 | 153 | 154 | def test_itt_with_prev_steps(tmp_path): 155 | values = pd.to_datetime( 156 | pd.Series( 157 | ["1978-05-24", "1976-06-22", "1976-06-23", "1976-06-24"], 158 | name="date", 159 | dtype="datetime64[us]", 160 | ) 161 | ) 162 | context_keys = pd.Series(["a", "b", "b", "b"], name="__context_key") 163 | root_keys = context_keys.copy() 164 | root_keys.name = "__root_key" 165 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 166 | write_json(stats1, tmp_path / "stats1.json") 167 | stats1 = read_json(tmp_path / "stats1.json") 168 | stats = analyze_reduce_itt([stats1], value_protection=False) 169 | write_json(stats, tmp_path / "stats.json") 170 | stats = read_json(tmp_path / "stats.json") 171 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 172 | prev_steps = { 173 | "prev_dts": pd.DataFrame( 174 | { 175 | "__CONTEXT_KEYS": ["a", "b"], 176 | "__STARTS": pd.to_datetime(pd.Series(["1978-05-23", "1976-06-21"], dtype="datetime64[us]")), 177 | } 178 | ) 179 | } 180 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys, prev_steps=prev_steps) 181 | assert values.equals(df_decoded) 182 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-meta/encoding-types.json: -------------------------------------------------------------------------------- 1 | { 2 | "deathDate": "TABULAR_DATETIME", 3 | "bats": "TABULAR_CATEGORICAL" 4 | } 5 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-meta/keys.json: -------------------------------------------------------------------------------- 1 | { 2 | "primary_key": "__primary_key" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-stats/part.000000-trn.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "deathDate": { 4 | "has_nan": true, 5 | "min_values": { 6 | "year": 1873, 7 | "month": 1, 8 | "day": 1, 9 | "hour": 0, 10 | "minute": 0, 11 | "second": 0, 12 | 
"ms_E2": 0, 13 | "ms_E1": 0, 14 | "ms_E0": 0 15 | }, 16 | "max_values": { 17 | "year": 2019, 18 | "month": 12, 19 | "day": 31, 20 | "hour": 0, 21 | "minute": 0, 22 | "second": 0, 23 | "ms_E2": 0, 24 | "ms_E1": 0, 25 | "ms_E0": 0 26 | }, 27 | "min10": [ 28 | "1873-02-26", 29 | "1876-10-18", 30 | "1879-06-18", 31 | "1881-03-01", 32 | "1881-05-10", 33 | "1884-04-29", 34 | "1884-09-26", 35 | "1886-02-13", 36 | "1886-05-21", 37 | "1886-08-09" 38 | ], 39 | "max10": [ 40 | "2019-12-29", 41 | "2019-12-16", 42 | "2019-12-15", 43 | "2019-12-08", 44 | "2019-11-28", 45 | "2019-11-23", 46 | "2019-09-07", 47 | "2019-09-06", 48 | "2019-09-06", 49 | "2019-08-26" 50 | ], 51 | "encoding_type": "TABULAR_DATETIME" 52 | }, 53 | "bats": { 54 | "has_nan": false, 55 | "cnt_values": { 56 | "": 361, 57 | "B": 363, 58 | "L": 1586, 59 | "R": 3690 60 | }, 61 | "encoding_type": "TABULAR_CATEGORICAL" 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-stats/part.000000-val.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "deathDate": { 4 | "has_nan": true, 5 | "min_values": { 6 | "year": 1873, 7 | "month": 1, 8 | "day": 1, 9 | "hour": 0, 10 | "minute": 0, 11 | "second": 0, 12 | "ms_E2": 0, 13 | "ms_E1": 0, 14 | "ms_E0": 0 15 | }, 16 | "max_values": { 17 | "year": 2019, 18 | "month": 12, 19 | "day": 31, 20 | "hour": 0, 21 | "minute": 0, 22 | "second": 0, 23 | "ms_E2": 0, 24 | "ms_E1": 0, 25 | "ms_E0": 0 26 | }, 27 | "min10": [ 28 | "1873-02-26", 29 | "1876-10-18", 30 | "1879-06-18", 31 | "1881-03-01", 32 | "1881-05-10", 33 | "1884-04-29", 34 | "1884-09-26", 35 | "1886-02-13", 36 | "1886-05-21", 37 | "1886-08-09" 38 | ], 39 | "max10": [ 40 | "2019-12-29", 41 | "2019-12-16", 42 | "2019-12-15", 43 | "2019-12-08", 44 | "2019-11-28", 45 | "2019-11-23", 46 | "2019-09-07", 47 | "2019-09-06", 48 | "2019-09-06", 49 | "2019-08-26" 50 | ], 51 | "encoding_type": "TABULAR_DATETIME" 52 | }, 53 | "bats": { 54 | "has_nan": false, 55 | "cnt_values": { 56 | "": 361, 57 | "B": 363, 58 | "L": 1586, 59 | "R": 3690 60 | }, 61 | "encoding_type": "TABULAR_CATEGORICAL" 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-stats/stats.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "deathDate": { 4 | "cardinalities": { 5 | "nan": 2, 6 | "year": 136, 7 | "month": 12, 8 | "day": 31 9 | }, 10 | "has_nan": true, 11 | "has_time": false, 12 | "has_ms": false, 13 | "min_values": { 14 | "year": 1884, 15 | "month": 1, 16 | "day": 1, 17 | "hour": 0, 18 | "minute": 0, 19 | "second": 0, 20 | "ms_E2": 0, 21 | "ms_E1": 0, 22 | "ms_E0": 0 23 | }, 24 | "max_values": { 25 | "year": 2019, 26 | "month": 12, 27 | "day": 31, 28 | "hour": 0, 29 | "minute": 0, 30 | "second": 0, 31 | "ms_E2": 0, 32 | "ms_E1": 0, 33 | "ms_E0": 0 34 | }, 35 | "min5": [ 36 | "1884-04-29", 37 | "1884-09-26", 38 | "1886-02-13", 39 | "1886-05-21", 40 | "1886-08-09" 41 | ], 42 | "max5": [ 43 | "2019-11-23", 44 | "2019-09-07", 45 | "2019-09-06", 46 | "2019-09-06", 47 | "2019-08-26" 48 | ], 49 | "encoding_type": "TABULAR_DATETIME", 50 | "tf_name": "c0" 51 | }, 52 | "bats": { 53 | "no_of_rare_categories": 0, 54 | "codes": { 55 | "_RARE_": 0, 56 | "": 1, 57 | "B": 2, 58 | "L": 3, 59 | "R": 4 60 | }, 61 | "cardinalities": { 62 | "cat": 5 63 | }, 64 | 
"encoding_type": "TABULAR_CATEGORICAL", 65 | "tf_name": "c1" 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/model-data/model-configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "consistency_correction_tf": [], 3 | "model_size": { 4 | "embed.tgt.__ridx": 5, 5 | "embed.tgt.__stop": 2, 6 | "embed.tgt.c0__tokens": 9, 7 | "embed.ctx.c0__nan": 2, 8 | "embed.ctx.c0__year": 12, 9 | "embed.ctx.c0__month": 6, 10 | "embed.ctx.c0__day": 9, 11 | "embed.ctx.c1__cat": 5, 12 | "context_0": 256, 13 | "history_0": 256, 14 | "reg.tgt.__ridx_0": 16, 15 | "reg.tgt.__stop_0": 16, 16 | "reg.tgt.c0__tokens_0": 32 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/model-data/model-weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/ModelStore/model-data/model-weights.pt -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-meta/encoding-types.json: -------------------------------------------------------------------------------- 1 | { 2 | "desc": "LANGUAGE_TEXT" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-meta/keys.json: -------------------------------------------------------------------------------- 1 | { 2 | "context_key": "__primary_key" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-stats/part.000000-trn.json: -------------------------------------------------------------------------------- 1 | { 2 | "no_of_training_records": 5400, 3 | "no_of_validation_records": 0, 4 | "seq_len": { 5 | "cnt_lengths": { 6 | "1": 6000 7 | } 8 | }, 9 | "columns": { 10 | "desc": { 11 | "has_na": false, 12 | "cnt_values": { 13 | "a": 0, 14 | "d": 0, 15 | "e": 0, 16 | "f": 0, 17 | "g": 0, 18 | "h": 0, 19 | "i": 0, 20 | "l": 0, 21 | "n": 0, 22 | "p": 0, 23 | "r": 0, 24 | "s": 0, 25 | "t": 0, 26 | "w": 0, 27 | "y": 0, 28 | "▁": 0, 29 | "an": 0, 30 | "ay": 0, 31 | "de": 0, 32 | "er": 0, 33 | "han": 0, 34 | "lay": 0, 35 | "play": 0, 36 | "▁han": 0, 37 | "▁play": 0, 38 | "ded": 0, 39 | "▁handed": 6000, 40 | "▁player": 6000, 41 | "gh": 0, 42 | "igh": 0, 43 | "righ": 0, 44 | "▁righ": 0, 45 | "▁right": 3690, 46 | "is": 0, 47 | "▁is": 3062, 48 | "as": 0, 49 | "was": 0, 50 | "▁was": 2938, 51 | "ef": 0, 52 | "lef": 0, 53 | "▁lef": 0, 54 | "▁left": 2310 55 | }, 56 | "cnt_lengths": { 57 | "4": 6000 58 | }, 59 | "merges": [ 60 | "a n", 61 | "a y", 62 | "d e", 63 | "e r", 64 | "h an", 65 | "l ay", 66 | "p lay", 67 | "▁ han", 68 | "▁ play", 69 | "de d", 70 | "▁han ded", 71 | "▁play er", 72 | "g h", 73 | "i gh", 74 | "r igh", 75 | "▁ righ", 76 | "▁righ t", 77 | "i s", 78 | "▁ is", 79 | "a s", 80 | "w as", 81 | "▁ was", 82 | "e f", 83 | "l ef", 84 | "▁ lef", 85 | "▁lef t" 86 | ], 87 | "encoding_type": "LANGUAGE_TEXT" 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-stats/part.000000-val.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "no_of_training_records": 0, 3 | "no_of_validation_records": 600, 4 | "seq_len": { 5 | "cnt_lengths": { 6 | "1": 6000 7 | } 8 | }, 9 | "columns": { 10 | "desc": { 11 | "has_na": false, 12 | "cnt_values": { 13 | "a": 0, 14 | "d": 0, 15 | "e": 0, 16 | "f": 0, 17 | "g": 0, 18 | "h": 0, 19 | "i": 0, 20 | "l": 0, 21 | "n": 0, 22 | "p": 0, 23 | "r": 0, 24 | "s": 0, 25 | "t": 0, 26 | "w": 0, 27 | "y": 0, 28 | "▁": 0, 29 | "an": 0, 30 | "ay": 0, 31 | "de": 0, 32 | "er": 0, 33 | "han": 0, 34 | "lay": 0, 35 | "play": 0, 36 | "▁han": 0, 37 | "▁play": 0, 38 | "ded": 0, 39 | "▁handed": 6000, 40 | "▁player": 6000, 41 | "gh": 0, 42 | "igh": 0, 43 | "righ": 0, 44 | "▁righ": 0, 45 | "▁right": 3690, 46 | "is": 0, 47 | "▁is": 3062, 48 | "as": 0, 49 | "was": 0, 50 | "▁was": 2938, 51 | "ef": 0, 52 | "lef": 0, 53 | "▁lef": 0, 54 | "▁left": 2310 55 | }, 56 | "cnt_lengths": { 57 | "4": 6000 58 | }, 59 | "merges": [ 60 | "a n", 61 | "a y", 62 | "d e", 63 | "e r", 64 | "h an", 65 | "l ay", 66 | "p lay", 67 | "▁ han", 68 | "▁ play", 69 | "de d", 70 | "▁han ded", 71 | "▁play er", 72 | "g h", 73 | "i gh", 74 | "r igh", 75 | "▁ righ", 76 | "▁righ t", 77 | "i s", 78 | "▁ is", 79 | "a s", 80 | "w as", 81 | "▁ was", 82 | "e f", 83 | "l ef", 84 | "▁ lef", 85 | "▁lef t" 86 | ], 87 | "encoding_type": "LANGUAGE_TEXT" 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-stats/stats.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "desc": { 4 | "cardinalities": { 5 | "tokens": 43 6 | }, 7 | "has_na": false, 8 | "tokens": [ 9 | "▁handed", 10 | "▁player", 11 | "▁right", 12 | "▁is", 13 | "▁was", 14 | "▁left", 15 | "a", 16 | "d", 17 | "e", 18 | "f", 19 | "g", 20 | "h", 21 | "i", 22 | "l", 23 | "n", 24 | "p", 25 | "r", 26 | "s", 27 | "t", 28 | "w", 29 | "y", 30 | "▁", 31 | "an", 32 | "ay", 33 | "de", 34 | "er", 35 | "han", 36 | "lay", 37 | "play", 38 | "▁han", 39 | "▁play", 40 | "ded", 41 | "gh", 42 | "igh", 43 | "righ", 44 | "▁righ", 45 | "is", 46 | "as", 47 | "was", 48 | "ef", 49 | "lef", 50 | "▁lef" 51 | ], 52 | "merges": [ 53 | "a n", 54 | "a y", 55 | "d e", 56 | "e r", 57 | "h an", 58 | "l ay", 59 | "p lay", 60 | "▁ han", 61 | "▁ play", 62 | "de d", 63 | "▁han ded", 64 | "▁play er", 65 | "g h", 66 | "i gh", 67 | "r igh", 68 | "▁ righ", 69 | "▁righ t", 70 | "i s", 71 | "▁ is", 72 | "a s", 73 | "w as", 74 | "▁ was", 75 | "e f", 76 | "l ef", 77 | "▁ lef", 78 | "▁lef t" 79 | ], 80 | "seq_len": { 81 | "min": 4, 82 | "max": 4, 83 | "median": 4, 84 | "deciles": [ 85 | 4, 86 | 4, 87 | 4, 88 | 4, 89 | 4, 90 | 4, 91 | 4, 92 | 4, 93 | 4, 94 | 4, 95 | 4 96 | ] 97 | }, 98 | "encoding_type": "LANGUAGE_TEXT" 99 | } 100 | }, 101 | "no_of_training_records": 5400, 102 | "no_of_validation_records": 600, 103 | "seq_len": { 104 | "min": 1, 105 | "max": 1, 106 | "median": 1, 107 | "deciles": [ 108 | 1, 109 | 1, 110 | 1, 111 | 1, 112 | 1, 113 | 1, 114 | 1, 115 | 1, 116 | 1, 117 | 1, 118 | 1 119 | ] 120 | }, 121 | "is_sequential": false, 122 | "has_sequential_columns": true 123 | } 124 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/ctx-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/OriginalData/ctx-data/part.000000-trn.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/ctx-data/part.000000-val.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/OriginalData/ctx-data/part.000000-val.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/encoded-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/OriginalData/encoded-data/part.000000-trn.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/encoded-data/part.000000-val.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/OriginalData/encoded-data/part.000000-val.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/tgt-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/OriginalData/tgt-data/part.000000-trn.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/tgt-data/part.000000-val.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/OriginalData/tgt-data/part.000000-val.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/SyntheticData/part.000001.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/SyntheticData/part.000001.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/SyntheticData/part.000002.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/all/SyntheticData/part.000002.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/ModelStore/tgt-meta/encoding-types.json: -------------------------------------------------------------------------------- 1 | { 2 | "desc": "LANGUAGE_TEXT" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/ModelStore/tgt-meta/keys.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "context_key": "__primary_key" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/OriginalData/tgt-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/some/OriginalData/tgt-data/part.000000-trn.parquet -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/OriginalData/tgt-data/part.000000-val.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/f82fcb43a1cee759ca9287af01ce02c85d6befdc/tests/unit/fixtures/workspace/some/OriginalData/tgt-data/part.000000-val.parquet -------------------------------------------------------------------------------- /tests/unit/test_analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | import pandas as pd 17 | 18 | from mostlyai.engine._common import ANALYZE_MIN_MAX_TOP_N, read_json, write_json 19 | from mostlyai.engine.analysis import ( 20 | _analyze_col, 21 | _analyze_partition, 22 | _analyze_reduce_seq_len, 23 | _analyze_seq_len, 24 | ) 25 | from mostlyai.engine.domain import ModelEncodingType 26 | 27 | 28 | def test_analyze_cnt(tmp_path): 29 | events = pd.DataFrame({"context_key": [1, 1, 1, 2, 3, 4], "x": [1, 1, 1, 1, 1, 1]}) 30 | no_of_records = events["context_key"].nunique() 31 | 32 | events.to_parquet(tmp_path / "part.000000-trn.parquet") 33 | stats_path = tmp_path / "stats" 34 | stats_path.mkdir() 35 | _analyze_partition( 36 | tmp_path / "part.000000-trn.parquet", 37 | stats_path, 38 | tgt_context_key="context_key", 39 | tgt_encoding_types={"x": ModelEncodingType.tabular_numeric_digit}, 40 | ) 41 | stats = read_json(stats_path / "part.000000-trn.json") 42 | assert stats["no_of_training_records"] == no_of_records 43 | assert stats["no_of_validation_records"] == 0 44 | 45 | 46 | def test_analyze_seq_len(tmp_path): 47 | tgt_context_keys = pd.Series(np.repeat(range(21), range(21)), name="account_id") 48 | partition_stats = _analyze_seq_len(tgt_context_keys=tgt_context_keys, ctx_primary_keys=tgt_context_keys) 49 | write_json(partition_stats, tmp_path / "stats1.json") 50 | partition_stats = read_json(tmp_path / "stats1.json") 51 | global_stats = _analyze_reduce_seq_len([partition_stats]) 52 | write_json(global_stats, tmp_path / "stats.json") 53 | global_stats = read_json(tmp_path / "stats.json") 54 | assert global_stats["max"] >= 12 and global_stats["max"] <= 15 55 | assert isinstance(global_stats["max"], int) 56 | global_stats = _analyze_reduce_seq_len([partition_stats for i in range(9)]) 57 | assert global_stats["max"] == 20 58 | 59 | tgt_context_keys = pd.Series(np.repeat(range(21), range(21)), name="account_id") 60 | partition_stats = _analyze_seq_len( 61 | tgt_context_keys=tgt_context_keys, 62 | ctx_primary_keys=pd.concat([tgt_context_keys, pd.Series([100])]), 63 | ) 64 | global_stats = _analyze_reduce_seq_len([partition_stats], value_protection=False) 65 | assert global_stats["min"] == 0 66 | assert global_stats["max"] == 20 67 | 68 | 69 | def test_analyze_root_key(tmp_path): 70 | tgt_context_keys = pd.Series(np.repeat(range(40), 2), name="tgt_context_key") 71 | tgt_values = pd.Series(list(range(80)), name="tgt_values") 72 | tgt = pd.concat([tgt_context_keys, tgt_values], axis=1) 73 | 74 | ctx_root_keys = pd.Series(np.repeat(range(20), 2), name="ctx_root_keys") 75 | ctx_primary_keys = pd.Series(list(range(40)), name="ctx_primary_key") 76 | ctx_values = pd.Series(list(range(40)), name="ctx_values") 77 | ctx = pd.concat([ctx_root_keys, ctx_primary_keys, ctx_values], axis=1) 78 | 79 | tgt_partition_path, ctx_partition_path = ( 80 | tmp_path / "tgt.000000-trn.parquet", 81 | tmp_path / "ctx.000000-trn.parquet", 82 | ) 83 | tgt.to_parquet(tgt_partition_path), ctx.to_parquet(ctx_partition_path) 84 | 85 | tgt_stats_path, ctx_stats_path = tmp_path / "tgt_stats", tmp_path / "ctx_stats" 86 | tgt_stats_path.mkdir(), ctx_stats_path.mkdir() 87 | 88 | # root key column is in tgt table 89 | _analyze_partition( 90 | tgt_partition_file=tgt_partition_path, 91 | tgt_stats_path=tgt_stats_path, 92 | tgt_encoding_types={tgt_values.name: ModelEncodingType.tabular_numeric_digit}, 93 | tgt_context_key=tgt_context_keys.name, 94 | ctx_partition_file=ctx_partition_path, 95 | ctx_stats_path=ctx_stats_path, 96 | ctx_encoding_types={ctx_values.name: 
ModelEncodingType.tabular_numeric_digit}, 97 | ctx_primary_key=ctx_primary_keys.name, 98 | ctx_root_key=ctx_root_keys.name, 99 | ) 100 | ctx_stats = read_json(ctx_stats_path / "part.000000-trn.json") 101 | assert ctx_stats["columns"][ctx_values.name]["max_n"] == list(range(40))[::-2][:ANALYZE_MIN_MAX_TOP_N] 102 | assert ctx_stats["columns"][ctx_values.name]["min_n"] == list(range(40))[::2][:ANALYZE_MIN_MAX_TOP_N] 103 | 104 | # root key column is in ctx table 105 | _analyze_partition( 106 | tgt_partition_file=tgt_partition_path, 107 | tgt_stats_path=tgt_stats_path, 108 | tgt_encoding_types={tgt_values.name: ModelEncodingType.tabular_numeric_digit}, 109 | tgt_context_key=tgt_context_keys.name, 110 | ctx_partition_file=ctx_partition_path, 111 | ctx_stats_path=ctx_stats_path, 112 | ctx_encoding_types={ctx_values.name: ModelEncodingType.tabular_numeric_digit}, 113 | ctx_primary_key=ctx_primary_keys.name, 114 | ctx_root_key=ctx_root_keys.name, 115 | ) 116 | tgt_stats = read_json(tgt_stats_path / "part.000000-trn.json") 117 | assert tgt_stats["columns"][tgt_values.name]["max_n"] == list(range(80))[::-1][:ANALYZE_MIN_MAX_TOP_N] 118 | assert tgt_stats["columns"][tgt_values.name]["min_n"] == list(range(80))[::1][:ANALYZE_MIN_MAX_TOP_N] 119 | ctx_stats = read_json(ctx_stats_path / "part.000000-trn.json") 120 | assert ctx_stats["columns"][ctx_values.name]["max_n"] == list(range(40))[::-2][:ANALYZE_MIN_MAX_TOP_N] 121 | assert ctx_stats["columns"][ctx_values.name]["min_n"] == list(range(40))[::2][:ANALYZE_MIN_MAX_TOP_N] 122 | 123 | 124 | class TestAnalyzeCol: 125 | def test_empty_values(self): 126 | values = pd.Series([], name="values") 127 | root_keys = pd.Series([], name="root_keys") 128 | stats = _analyze_col(values=values, encoding_type=ModelEncodingType.tabular_categorical, root_keys=root_keys) 129 | assert stats == {"encoding_type": ModelEncodingType.tabular_categorical.value} 130 | 131 | def test_flat_values(self): 132 | values = pd.Series([1, 2, 3], name="values") 133 | root_keys = pd.Series([1, 2, 3], name="root_keys") 134 | stats = _analyze_col( 135 | values=values, encoding_type=ModelEncodingType.tabular_categorical.value, root_keys=root_keys 136 | ) 137 | assert stats == { 138 | "encoding_type": ModelEncodingType.tabular_categorical.value, 139 | "cnt_values": {"1": 1, "2": 1, "3": 1}, 140 | "has_nan": False, 141 | } 142 | 143 | def test_sequential_values(self): 144 | values = pd.Series([[1, 2, 3], [], [3], [], [2]], name="values") 145 | root_keys = pd.Series([1, 2, 3, 4, 5], name="root_keys") 146 | stats = _analyze_col(values=values, encoding_type=ModelEncodingType.tabular_categorical, root_keys=root_keys) 147 | assert stats == { 148 | "encoding_type": ModelEncodingType.tabular_categorical.value, 149 | "cnt_values": {"1": 1, "2": 2, "3": 2}, 150 | "has_nan": False, 151 | "seq_len": {"cnt_lengths": {0: 2, 1: 2, 3: 1}}, 152 | } 153 | -------------------------------------------------------------------------------- /tests/unit/test_domain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | from pydantic import ValidationError 17 | 18 | from mostlyai.engine.domain import RebalancingConfig 19 | 20 | 21 | def test_rebalancing_config_valid(): 22 | config = RebalancingConfig(column="test_column", probabilities={"A": 0.3, "B": 0.5}) 23 | assert config.column == "test_column" 24 | assert config.probabilities == {"A": 0.3, "B": 0.5} 25 | 26 | 27 | def test_rebalancing_config_invalid_probabilities_values_out_of_range(): 28 | with pytest.raises(ValidationError): 29 | RebalancingConfig(column="test_column", probabilities={"A": -0.5, "B": 1.5}) 30 | 31 | 32 | def test_rebalancing_config_invalid_probabilities_values_sum(): 33 | with pytest.raises(ValidationError): 34 | RebalancingConfig(column="test_column", probabilities={"A": 0.3, "B": 0.8}) 35 | -------------------------------------------------------------------------------- /tests/unit/test_fairness.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pytest 16 | 17 | from mostlyai.engine._common import ARGN_COLUMN, ARGN_PROCESSOR, ARGN_TABLE, get_argn_name 18 | from mostlyai.engine._encoding_types.tabular.categorical import CATEGORICAL_SUB_COL_SUFFIX 19 | from mostlyai.engine._tabular.fairness import _get_sensitive_groups 20 | from mostlyai.engine.domain import ModelEncodingType 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def tgt_stats(): 25 | return { 26 | "columns": { 27 | "c0_cat": { 28 | "encoding_type": ModelEncodingType.tabular_categorical.value, 29 | "argn_processor": "tgt", 30 | "argn_table": "t0", 31 | "argn_column": "c0", 32 | "no_of_rare_categories": 0, 33 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(2)}}, 34 | }, 35 | "c1_cat": { 36 | "encoding_type": ModelEncodingType.tabular_categorical.value, 37 | "argn_processor": "tgt", 38 | "argn_table": "t0", 39 | "argn_column": "c1", 40 | "no_of_rare_categories": 0, 41 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(3)}}, 42 | }, 43 | "c2_cat": { 44 | "encoding_type": ModelEncodingType.tabular_categorical.value, 45 | "argn_processor": "tgt", 46 | "argn_table": "t0", 47 | "argn_column": "c2", 48 | "no_of_rare_categories": 1, 49 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(5)}}, 50 | }, 51 | "c3_num": { 52 | "encoding_type": ModelEncodingType.tabular_numeric_auto.value, 53 | "argn_processor": "tgt", 54 | "argn_table": "t0", 55 | "argn_column": "c3", 56 | }, 57 | "c4_cat": { 58 | "encoding_type": ModelEncodingType.tabular_categorical.value, 59 | "argn_processor": "tgt", 60 | "argn_table": "t0", 61 | "argn_column": "c4", 62 | "no_of_rare_categories": 0, 63 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(7)}}, 64 | }, 65 | } 66 | } 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "target_column, sensitive_columns, expected_n_rows", 71 | [ 72 | ("c0_cat", ["c1_cat", "c2_cat"], 18), # 3 * (5+1) 73 | ("c0_cat", ["c1_cat", "c4_cat"], 21), # 3 * 7 74 | ], 75 | ) 76 | def test_get_sensitive_category_groups(tgt_stats, target_column, sensitive_columns, expected_n_rows): 77 | column_stats = tgt_stats["columns"] 78 | sensitive_sub_cols = [ 79 | get_argn_name( 80 | argn_processor=tgt_stats["columns"][col][ARGN_PROCESSOR], 81 | argn_table=tgt_stats["columns"][col][ARGN_TABLE], 82 | argn_column=tgt_stats["columns"][col][ARGN_COLUMN], 83 | argn_sub_column=CATEGORICAL_SUB_COL_SUFFIX, 84 | ) 85 | for col in sensitive_columns 86 | ] 87 | groups_df = _get_sensitive_groups(column_stats, sensitive_columns, sensitive_sub_cols) 88 | assert groups_df.shape[0] == expected_n_rows 89 | -------------------------------------------------------------------------------- /tests/unit/test_memory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from mostlyai.engine._memory import extract_memory_from_string 16 | 17 | 18 | def test_extract_memory_from_string(): 19 | assert extract_memory_from_string("3.2GB") == int(3.2 * 1024**3) 20 | assert extract_memory_from_string("3.2Gi") == int(3.2 * 1024**3) 21 | assert extract_memory_from_string(" 3 g ") == 3 * 1024**3 22 | assert extract_memory_from_string("0.23GB") == int(0.23 * 1024**3) 23 | assert extract_memory_from_string("32804 gb") == 32804 * 1024**3 24 | assert extract_memory_from_string("4B") == 4 25 | assert extract_memory_from_string("4") == 4 26 | assert extract_memory_from_string("") is None 27 | assert extract_memory_from_string() is None 28 | -------------------------------------------------------------------------------- /tests/unit/test_workspace.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from pathlib import Path 17 | from unittest import mock 18 | 19 | from mostlyai.engine._common import read_json, write_json 20 | from mostlyai.engine._workspace import Workspace 21 | 22 | FIXTURES_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")) 23 | 24 | 25 | class TestWorkspace: 26 | def test_workspace_all_objects(self): 27 | ws_path = Path(FIXTURES_PATH) / "workspace" / "all" 28 | ws = Workspace(ws_path) 29 | 30 | # Split-related 31 | assert ws.tgt_data_path == ws_path / "OriginalData" / "tgt-data" 32 | tgt_data_file_names = [i.name for i in ws.tgt_data.fetch_all()] 33 | assert tgt_data_file_names == ["part.000000-trn.parquet", "part.000000-val.parquet"] 34 | assert isinstance(ws.tgt_encoding_types.read(), dict) 35 | assert isinstance(ws.tgt_keys.read(), dict) 36 | 37 | assert ws.ctx_data_path == ws_path / "OriginalData" / "ctx-data" 38 | ctx_data_file_names = [i.name for i in ws.ctx_data.fetch_all()] 39 | assert ctx_data_file_names == ["part.000000-trn.parquet", "part.000000-val.parquet"] 40 | assert isinstance(ws.ctx_encoding_types.read(), dict) 41 | assert isinstance(ws.ctx_keys.read(), dict) 42 | 43 | # Analyze-related 44 | assert ws.tgt_stats_path == Path(ws_path) / "ModelStore" / "tgt-stats" 45 | tgt_all_stats_file_names = [i.name for i in ws.tgt_all_stats.fetch_all()] 46 | assert tgt_all_stats_file_names == ["part.000000-trn.json", "part.000000-val.json"] 47 | assert isinstance(ws.tgt_stats.read(), dict) 48 | assert ws.ctx_stats_path == Path(ws_path) / "ModelStore" / "ctx-stats" 49 | ctx_all_stats_file_names = [i.name for i in ws.ctx_all_stats.fetch_all()] 50 | assert ctx_all_stats_file_names == ["part.000000-trn.json", "part.000000-val.json"] 51 | assert isinstance(ws.tgt_stats.read(), dict) 52 | 53 | # Encode-related 54 | assert ws.encoded_data_path == Path(ws_path) / "OriginalData" / "encoded-data" 55 | assert len(ws.encoded_data_val.fetch_all()) == 1 56 | assert len(ws.encoded_data_trn.fetch_all()) == 1 57 | 58 | # Train-related 59 | assert 
ws.model_path == Path(ws_path) / "ModelStore" / "model-data" 60 | assert ws.model_tabular_weights_path.exists() 61 | assert isinstance(ws.model_configs.read(), dict) 62 | 63 | # Generate-related 64 | assert ws.generated_data_path == Path(ws_path) / "SyntheticData" 65 | generated_data_file_names = [i.name for i in ws.generated_data.fetch_all()] 66 | assert generated_data_file_names == ["part.000001.parquet", "part.000002.parquet"] 67 | 68 | def test_read_write_json(self): 69 | ws_path = Path(FIXTURES_PATH) / "workspace" / "some" 70 | ws = Workspace(ws_path) 71 | 72 | assert ws.tgt_keys.read_handler == read_json 73 | assert ws.tgt_keys.read() == {"context_key": "__primary_key"} 74 | assert ws.tgt_keys.write_handler == write_json 75 | with mock.patch.object(ws.tgt_keys, "write_handler") as write_mock: 76 | new_key_data = {"new_key": "test_key"} 77 | ws.tgt_keys.write(new_key_data) 78 | assert write_mock.call_args[0] == ( 79 | new_key_data, 80 | ws_path / "OriginalData" / "tgt-meta" / "keys.json", 81 | ) 82 | --------------------------------------------------------------------------------