├── tests ├── unit │ ├── fixtures │ │ └── workspace │ │ │ ├── all │ │ │ ├── SyntheticData │ │ │ │ ├── part.000001.parquet │ │ │ │ └── part.000002.parquet │ │ │ ├── ModelStore │ │ │ │ ├── model-data │ │ │ │ │ ├── model-weights.pt │ │ │ │ │ └── model-configs.json │ │ │ │ ├── ctx-meta │ │ │ │ │ ├── keys.json │ │ │ │ │ └── encoding-types.json │ │ │ │ ├── tgt-meta │ │ │ │ │ ├── encoding-types.json │ │ │ │ │ └── keys.json │ │ │ │ ├── ctx-stats │ │ │ │ │ ├── part.000000-trn.json │ │ │ │ │ ├── part.000000-val.json │ │ │ │ │ └── stats.json │ │ │ │ └── tgt-stats │ │ │ │ │ ├── part.000000-trn.json │ │ │ │ │ ├── part.000000-val.json │ │ │ │ │ └── stats.json │ │ │ └── OriginalData │ │ │ │ ├── ctx-data │ │ │ │ ├── part.000000-trn.parquet │ │ │ │ └── part.000000-val.parquet │ │ │ │ ├── encoded-data │ │ │ │ ├── part.000000-trn.parquet │ │ │ │ └── part.000000-val.parquet │ │ │ │ └── tgt-data │ │ │ │ ├── part.000000-trn.parquet │ │ │ │ └── part.000000-val.parquet │ │ │ └── some │ │ │ ├── OriginalData │ │ │ └── tgt-data │ │ │ │ ├── part.000000-trn.parquet │ │ │ │ └── part.000000-val.parquet │ │ │ └── ModelStore │ │ │ └── tgt-meta │ │ │ ├── encoding-types.json │ │ │ └── keys.json │ ├── __init__.py │ ├── encoding_types │ │ ├── __init__.py │ │ ├── language │ │ │ ├── __init__.py │ │ │ ├── test_categorical.py │ │ │ ├── test_numeric.py │ │ │ └── test_datetime.py │ │ └── tabular │ │ │ ├── __init__.py │ │ │ ├── test_character.py │ │ │ ├── test_datetime.py │ │ │ └── test_itt.py │ ├── test_memory.py │ ├── test_domain.py │ ├── test_fairness.py │ ├── test_workspace.py │ └── test_tabular_common.py ├── __init__.py └── end_to_end │ ├── __init__.py │ ├── conftest.py │ ├── test_language_interface.py │ └── test_numeric.py ├── docs ├── index.md ├── logo.png ├── favicon.png └── api.md ├── .gitignore ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── documentation.md │ ├── feature_request.md │ └── bug_report.md ├── workflows │ ├── pre-commit-check.yml │ ├── workflow.yaml │ ├── run-tests-gpu.yaml │ └── run-tests-cpu.yaml └── changelog_config.json ├── mostlyai └── engine │ ├── _tabular │ └── __init__.py │ ├── _encoding_types │ ├── __init__.py │ ├── language │ │ ├── __init__.py │ │ ├── text.py │ │ ├── categorical.py │ │ ├── numeric.py │ │ └── datetime.py │ └── tabular │ │ ├── __init__.py │ │ ├── categorical.py │ │ └── character.py │ ├── _language │ ├── engine │ │ ├── __init__.py │ │ ├── base.py │ │ ├── hf_engine.py │ │ └── vllm_engine.py │ ├── __init__.py │ ├── common.py │ ├── lstm.py │ └── tokenizer_utils.py │ ├── __init__.py │ ├── logging.py │ ├── random_state.py │ ├── encoding.py │ ├── _dtypes.py │ ├── _memory.py │ ├── generation.py │ ├── training.py │ └── _training_utils.py ├── .pre-commit-config.yaml ├── mkdocs.yml ├── CONTRIBUTING.md ├── pyproject.toml ├── examples ├── language.ipynb ├── flat.ipynb └── sequential.ipynb └── Makefile /tests/unit/fixtures/workspace/all/SyntheticData/part.000001.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/SyntheticData/part.000002.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/model-data/model-weights.pt: -------------------------------------------------------------------------------- 1 | 
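The fixture tree at the top of this listing defines the on-disk workspace layout that the unit tests run against (`SyntheticData`, `ModelStore` with meta/stats folders, and partitioned `OriginalData` parquet files). Below is a minimal sketch, not part of the repository, of how the `all` fixture skeleton could be scaffolded locally for experimentation; the helper name is hypothetical, the relative paths are taken from the tree above, and empty placeholder files stand in for the real fixtures, which ship with content.

```python
# scaffold_fixture_workspace.py -- hypothetical helper, not part of the repository.
# Recreates only the *layout* of tests/unit/fixtures/workspace/all shown in the
# tree above, using empty placeholder files.
from pathlib import Path

FIXTURE_FILES = [
    "SyntheticData/part.000001.parquet",
    "SyntheticData/part.000002.parquet",
    "ModelStore/model-data/model-weights.pt",
    "ModelStore/model-data/model-configs.json",
    "ModelStore/ctx-meta/keys.json",
    "ModelStore/ctx-meta/encoding-types.json",
    "ModelStore/tgt-meta/keys.json",
    "ModelStore/tgt-meta/encoding-types.json",
    "ModelStore/ctx-stats/part.000000-trn.json",
    "ModelStore/ctx-stats/part.000000-val.json",
    "ModelStore/ctx-stats/stats.json",
    "ModelStore/tgt-stats/part.000000-trn.json",
    "ModelStore/tgt-stats/part.000000-val.json",
    "ModelStore/tgt-stats/stats.json",
    "OriginalData/ctx-data/part.000000-trn.parquet",
    "OriginalData/ctx-data/part.000000-val.parquet",
    "OriginalData/encoded-data/part.000000-trn.parquet",
    "OriginalData/encoded-data/part.000000-val.parquet",
    "OriginalData/tgt-data/part.000000-trn.parquet",
    "OriginalData/tgt-data/part.000000-val.parquet",
]


def scaffold(root: Path) -> None:
    """Create the directory skeleton with empty placeholder files."""
    for rel in FIXTURE_FILES:
        path = root / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()


if __name__ == "__main__":
    scaffold(Path("tests/unit/fixtures/workspace/all"))
```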
-------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/ctx-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/ctx-data/part.000000-val.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/encoded-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/encoded-data/part.000000-val.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/tgt-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/OriginalData/tgt-data/part.000000-val.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/OriginalData/tgt-data/part.000000-trn.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/OriginalData/tgt-data/part.000000-val.parquet: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | 6 | --8<-- "README.md" 7 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/HEAD/docs/logo.png -------------------------------------------------------------------------------- /docs/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai-engine/HEAD/docs/favicon.png -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-meta/keys.json: -------------------------------------------------------------------------------- 1 | { 2 | "primary_key": "__primary_key" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-meta/encoding-types.json: -------------------------------------------------------------------------------- 1 | { 2 | "desc": "LANGUAGE_TEXT" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-meta/keys.json: -------------------------------------------------------------------------------- 
1 | { 2 | "context_key": "__primary_key" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/ModelStore/tgt-meta/encoding-types.json: -------------------------------------------------------------------------------- 1 | { 2 | "desc": "LANGUAGE_TEXT" 3 | } 4 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/some/ModelStore/tgt-meta/keys.json: -------------------------------------------------------------------------------- 1 | { 2 | "context_key": "__primary_key" 3 | } 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | .vscode/ 4 | .ipynb_checkpoints/ 5 | .DS_Store 6 | dist/ 7 | examples/ws-*/ 8 | LICENSE_HEADER 9 | /site/ 10 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-meta/encoding-types.json: -------------------------------------------------------------------------------- 1 | { 2 | "deathDate": "TABULAR_DATETIME", 3 | "bats": "TABULAR_CATEGORICAL" 4 | } 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Questions 4 | url: https://github.com/mostly-ai/mostlyai-engine/discussions 5 | about: Ask questions and discuss with other community members 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Documentation" 3 | about: Report an issue related to the documentation. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the issue** 11 | Please provide a clear and concise description of what the issue is. 12 | 13 | **Expected behavior** 14 | A clear and concise description of what you expected to see. 15 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | 6 | ## Engine Reference 7 | 8 | ::: mostlyai.engine 9 | options: 10 | members: 11 | - split 12 | - analyze 13 | - encode 14 | - train 15 | - generate 16 | 17 | ## Schema Reference 18 | 19 | ::: mostlyai.engine.domain 20 | options: 21 | filters: 22 | - "!^CustomBaseModel" 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature Request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/model-data/model-configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "consistency_correction_tf": [], 3 | "model_size": { 4 | "embed.tgt.__ridx": 5, 5 | "embed.tgt.__stop": 2, 6 | "embed.tgt.c0__tokens": 9, 7 | "embed.ctx.c0__nan": 2, 8 | "embed.ctx.c0__year": 12, 9 | "embed.ctx.c0__month": 6, 10 | "embed.ctx.c0__day": 9, 11 | "embed.ctx.c1__cat": 5, 12 | "context_0": 256, 13 | "history_0": 256, 14 | "reg.tgt.__ridx_0": 16, 15 | "reg.tgt.__stop_0": 16, 16 | "reg.tgt.c0__tokens_0": 32 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/end_to_end/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_tabular/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/tabular/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/tabular/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-check.yml: -------------------------------------------------------------------------------- 1 | name: Pre-Commit Check 2 | 3 | on: [workflow_call] 4 | 5 | jobs: 6 | pre-commit-check: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 10 | - name: Set up Python 11 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 12 | with: 13 | python-version: '3.10' 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install pre-commit 18 | pre-commit install 19 | - name: Run pre-commit 20 | run: pre-commit run --all-files 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Create a report to help us reproduce and fix the bug 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | Please provide a clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Code to reproduce the behavior: 15 | ``` 16 | # All necessary imports at the beginning 17 | import pandas as pd 18 | from mostlyai import engine 19 | # A succinct reproducing example trimmed down to the essential parts: 20 | df = pd.DataFrame({'x': [1, 2, 3]}) 21 | engine.split(...) 22 | engine.analyze(...) 23 | engine.encode(...) 24 | engine.train(...) 25 | engine.generate(...) 26 | ``` 27 | 28 | **Expected behavior** 29 | A clear and concise description of what you expected to happen. 30 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | 17 | from mostlyai.engine._language.lstm import register_mostly_lstm_model 18 | 19 | # suppress xgrammar max_rollback_tokens deprecation warnings 20 | warnings.filterwarnings("ignore", message=".*max_rollback_tokens.*", category=DeprecationWarning) 21 | 22 | register_mostly_lstm_model() 23 | -------------------------------------------------------------------------------- /.github/changelog_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "# What's Changed\n\n#{{CHANGELOG}}\n\n**Full Changelog**: [#{{FROM_TAG}}...#{{TO_TAG}}](#{{RELEASE_DIFF}})", 3 | "pr_template": "- #{{TITLE}} [##{{NUMBER}}](#{{URL}})", 4 | "empty_template": "No Changes", 5 | "categories": [ 6 | { 7 | "title": "## 🚀 Features", 8 | "labels": ["feat"] 9 | }, 10 | { 11 | "title": "## 🐛 Fixes", 12 | "labels": ["fix"] 13 | }, 14 | { 15 | "title": "## 📦 Uncategorized", 16 | "labels": ["chore", "build", "docs", "refactor", "style"] 17 | } 18 | ], 19 | "ignore_labels": ["bump", "ci"], 20 | "label_extractor": [ 21 | { 22 | "pattern": "^([\\w-]+)(?:\\(([^)]+)\\))?: (.+)$", 23 | "target": "$1", 24 | "on_property": "title" 25 | } 26 | ], 27 | "transformers": [ 28 | { 29 | "pattern": "^(?:[^:]+:\\s*)?(.*)$", 30 | "method": "replace", 31 | "target": "$1", 32 | "on_property": "title" 33 | } 34 | ] 35 | } 36 | -------------------------------------------------------------------------------- /.github/workflows/workflow.yaml: -------------------------------------------------------------------------------- 1 | name: Complete Workflow 2 | 3 | on: [push, pull_request] 4 | 5 | env: 6 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 7 | FORCE_COLOR: '1' 8 | 9 | jobs: 10 | pre-commit-check: 11 | if: | 12 | github.event_name == 'push' || 13 | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 14 | uses: ./.github/workflows/pre-commit-check.yml 15 | secrets: inherit 16 | run-tests-cpu: 17 | if: | 18 | github.event_name == 'push' || 19 | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 20 | uses: ./.github/workflows/run-tests-cpu.yaml 21 | secrets: inherit 22 | run-tests-gpu: 23 | if: | 24 | github.ref == 'refs/heads/main' || 25 | startsWith(github.ref, 'refs/tags/') || 26 | contains(github.event.head_commit.message, '[gpu]') 27 | uses: ./.github/workflows/run-tests-gpu.yaml 28 | secrets: inherit 29 | -------------------------------------------------------------------------------- /tests/unit/test_memory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from mostlyai.engine._memory import extract_memory_from_string 16 | 17 | 18 | def test_extract_memory_from_string(): 19 | assert extract_memory_from_string("3.2GB") == int(3.2 * 1024**3) 20 | assert extract_memory_from_string("3.2Gi") == int(3.2 * 1024**3) 21 | assert extract_memory_from_string(" 3 g ") == 3 * 1024**3 22 | assert extract_memory_from_string("0.23GB") == int(0.23 * 1024**3) 23 | assert extract_memory_from_string("32804 gb") == 32804 * 1024**3 24 | assert extract_memory_from_string("4B") == 4 25 | assert extract_memory_from_string("4") == 4 26 | assert extract_memory_from_string("") is None 27 | assert extract_memory_from_string() is None 28 | -------------------------------------------------------------------------------- /tests/unit/test_domain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | from pydantic import ValidationError 17 | 18 | from mostlyai.engine.domain import RebalancingConfig 19 | 20 | 21 | def test_rebalancing_config_valid(): 22 | config = RebalancingConfig(column="test_column", probabilities={"A": 0.3, "B": 0.5}) 23 | assert config.column == "test_column" 24 | assert config.probabilities == {"A": 0.3, "B": 0.5} 25 | 26 | 27 | def test_rebalancing_config_invalid_probabilities_values_out_of_range(): 28 | with pytest.raises(ValidationError): 29 | RebalancingConfig(column="test_column", probabilities={"A": -0.5, "B": 1.5}) 30 | 31 | 32 | def test_rebalancing_config_invalid_probabilities_values_sum(): 33 | with pytest.raises(ValidationError): 34 | RebalancingConfig(column="test_column", probabilities={"A": 0.3, "B": 0.8}) 35 | -------------------------------------------------------------------------------- /mostlyai/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import warnings 15 | 16 | from mostlyai.engine._language.interface import LanguageModel 17 | from mostlyai.engine._tabular.interface import TabularARGN 18 | from mostlyai.engine.analysis import analyze 19 | from mostlyai.engine.encoding import encode 20 | from mostlyai.engine.generation import generate 21 | from mostlyai.engine.logging import init_logging 22 | from mostlyai.engine.random_state import set_random_state 23 | from mostlyai.engine.splitting import split 24 | from mostlyai.engine.training import train 25 | 26 | __all__ = [ 27 | "split", 28 | "analyze", 29 | "encode", 30 | "train", 31 | "generate", 32 | "init_logging", 33 | "set_random_state", 34 | "TabularARGN", 35 | "LanguageModel", 36 | ] 37 | __version__ = "2.3.3" 38 | 39 | # suppress specific warning related to os.fork() in multi-threaded processes 40 | warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*multi-threaded.*fork.*") 41 | -------------------------------------------------------------------------------- /mostlyai/engine/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import sys 17 | 18 | _LOG = logging.getLogger(__name__.rsplit(".", 1)[0]) # get the logger with the root module name (mostlyai.engine) 19 | 20 | 21 | def init_logging() -> None: 22 | """ 23 | Initialize the logging configuration to stdout. 24 | """ 25 | 26 | _LOG.propagate = False 27 | if not _LOG.hasHandlers(): 28 | handler = logging.StreamHandler(stream=sys.stdout) 29 | handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)-7s: %(message)s")) 30 | handler.setLevel(logging.INFO) 31 | _LOG.addHandler(handler) 32 | _LOG.setLevel(logging.INFO) 33 | 34 | 35 | def disable_logging() -> None: 36 | """ 37 | Disable the logging by removing all handlers and resetting to default state. 38 | """ 39 | # Remove all handlers 40 | for handler in _LOG.handlers[:]: 41 | handler.close() 42 | _LOG.removeHandler(handler) 43 | 44 | # Reset to default state 45 | _LOG.setLevel(logging.WARNING) # Default Python logging level 46 | _LOG.propagate = True # Restore default propagation 47 | -------------------------------------------------------------------------------- /mostlyai/engine/random_state.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | import random 18 | import struct 19 | 20 | import numpy as np 21 | import torch 22 | 23 | _LOG = logging.getLogger(__name__) 24 | 25 | 26 | def set_random_state(random_state: int | None = None, worker: bool = False): 27 | def get_random_int_from_os() -> int: 28 | # 32-bit, cryptographically secure random int from os 29 | return int(struct.unpack("I", os.urandom(4))[0]) 30 | 31 | if worker: # worker process 32 | if "MOSTLYAI_ENGINE_SEED" in os.environ: 33 | random_state = int(os.environ["MOSTLYAI_ENGINE_SEED"]) 34 | else: 35 | # don't set seed for worker process if not set in main process 36 | return 37 | else: # main process 38 | if random_state is not None: 39 | _LOG.info(f"Global random_state set to `{random_state}`") 40 | 41 | if random_state is None: 42 | random_state = get_random_int_from_os() 43 | 44 | os.environ["MOSTLYAI_ENGINE_SEED"] = str(random_state) 45 | 46 | random.seed(random_state) 47 | np.random.seed(random_state) 48 | torch.manual_seed(random_state) 49 | torch.cuda.manual_seed_all(random_state) 50 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-gpu.yaml: -------------------------------------------------------------------------------- 1 | name: '[GPU] mostlyai-engine Tests' 2 | 3 | on: 4 | workflow_call: 5 | 6 | env: 7 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 8 | FORCE_COLOR: '1' 9 | 10 | jobs: 11 | run-tests-gpu: 12 | runs-on: gha-gpu-public-internal 13 | container: 14 | image: nvidia/cuda:13.0.2-cudnn-runtime-ubuntu24.04 15 | options: --gpus all 16 | permissions: 17 | contents: read 18 | packages: write 19 | steps: 20 | - name: Setup | Install Git 21 | run: | 22 | apt-get update -qq 23 | apt-get install -y --no-install-recommends git build-essential 24 | 25 | - name: Setup | Checkout 26 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 27 | with: 28 | fetch-depth: 0 29 | submodules: 'recursive' 30 | 31 | - name: Setup | uv 32 | uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v7.1.5 33 | with: 34 | enable-cache: false 35 | python-version: '3.10' 36 | 37 | - name: Setup | Dependencies 38 | run: | 39 | uv sync --frozen --only-group dev 40 | uv pip install ".[gpu]" 41 | 42 | - name: Setup | Check for available GPU-s 43 | run: nvidia-smi 44 | 45 | - name: Run tests -> end_to_end -> sequential 46 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py 47 | 48 | - name: Run tests -> end_to_end -> sequential context 49 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py 50 | 51 | - name: Run tests -> end_to_end all except sequential 52 | run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ 53 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import STRING, safe_convert_string 18 | 19 | 20 | def analyze_text(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 21 | # ideally, we should ensure that values are converted to string in a consistent way across analyze/encode/qa steps 22 | values = safe_convert_string(values) 23 | nchars = values.map(str).str.len() 24 | stats = {"nchar_max": int(nchars.max()), "nchar_sum": int(nchars.sum()), "count": len(values)} 25 | return stats 26 | 27 | 28 | def analyze_reduce_text( 29 | stats_list: list[dict], 30 | value_protection: bool = True, 31 | value_protection_epsilon: float | None = None, 32 | ) -> dict: 33 | nchar_max = 0 34 | nchar_sum = 0 35 | count = 0 36 | for stats in stats_list: 37 | nchar_max = max(stats["nchar_max"], nchar_max) 38 | nchar_sum += stats["nchar_sum"] 39 | count += stats["count"] 40 | 41 | stats = { 42 | "nchar_avg": round(nchar_sum / count, 1), 43 | "nchar_max": nchar_max, 44 | } 45 | return stats 46 | 47 | 48 | def decode_text(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: 49 | return x.astype(STRING) 50 | -------------------------------------------------------------------------------- /mostlyai/engine/encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pathlib import Path 16 | 17 | from mostlyai.engine._common import ProgressCallback 18 | from mostlyai.engine._workspace import resolve_model_type 19 | from mostlyai.engine.domain import ModelType 20 | 21 | 22 | def encode( 23 | *, 24 | workspace_dir: str | Path = "engine-ws", 25 | update_progress: ProgressCallback | None = None, 26 | ) -> None: 27 | """ 28 | Encodes data in the workspace that has already been split and analyzed. 29 | 30 | Creates the following folder structure within the `workspace_dir`: 31 | 32 | - `OriginalData/encoded-data`: Encoded data for training, stored as parquet files. 33 | 34 | Args: 35 | workspace_dir: Directory path for workspace. 36 | update_progress: Callback for progress updates. 
37 | """ 38 | model_type = resolve_model_type(workspace_dir) 39 | if model_type == ModelType.tabular: 40 | from mostlyai.engine._tabular.encoding import encode as encode_tabular 41 | 42 | return encode_tabular(workspace_dir=workspace_dir, update_progress=update_progress) 43 | else: 44 | from mostlyai.engine._language.encoding import encode as encode_language 45 | 46 | return encode_language(workspace_dir=workspace_dir, update_progress=update_progress) 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | exclude: '^(examples)/' 3 | repos: 4 | - repo: local 5 | hooks: 6 | - id: generate-license-header 7 | name: Generate temporary license header file 8 | entry: | 9 | bash -c ' 10 | HEADER_CONTENT="Copyright 2025 MOSTLY AI\n\ 11 | \n\ 12 | Licensed under the Apache License, Version 2.0 (the \"License\");\n\ 13 | you may not use this file except in compliance with the License.\n\ 14 | You may obtain a copy of the License at\n\ 15 | \n\ 16 | http://www.apache.org/licenses/LICENSE-2.0\n\ 17 | \n\ 18 | Unless required by applicable law or agreed to in writing, software\n\ 19 | distributed under the License is distributed on an \"AS IS\" BASIS,\n\ 20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n\ 21 | See the License for the specific language governing permissions and\n\ 22 | limitations under the License." 23 | 24 | echo -e "$HEADER_CONTENT" > LICENSE_HEADER 25 | ' 26 | language: system 27 | - repo: https://github.com/Lucas-C/pre-commit-hooks 28 | rev: v1.5.5 29 | hooks: 30 | - id: insert-license 31 | files: \.py$ 32 | args: 33 | # - --remove-header 34 | - --license-filepath 35 | - LICENSE_HEADER 36 | - --use-current-year 37 | - repo: https://github.com/pre-commit/pre-commit-hooks 38 | rev: v5.0.0 39 | hooks: 40 | - id: end-of-file-fixer 41 | - id: trailing-whitespace 42 | - id: end-of-file-fixer 43 | - id: check-json 44 | - id: mixed-line-ending 45 | args: [--fix=lf] 46 | - repo: https://github.com/asottile/pyupgrade 47 | rev: v3.19.1 48 | hooks: 49 | - id: pyupgrade 50 | args: [--py310-plus] 51 | - repo: https://github.com/astral-sh/ruff-pre-commit 52 | rev: v0.11.6 53 | hooks: 54 | - id: ruff 55 | args: [ --fix ] 56 | - id: ruff-format 57 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-stats/part.000000-trn.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "deathDate": { 4 | "has_nan": true, 5 | "min_values": { 6 | "year": 1873, 7 | "month": 1, 8 | "day": 1, 9 | "hour": 0, 10 | "minute": 0, 11 | "second": 0, 12 | "ms_E2": 0, 13 | "ms_E1": 0, 14 | "ms_E0": 0 15 | }, 16 | "max_values": { 17 | "year": 2019, 18 | "month": 12, 19 | "day": 31, 20 | "hour": 0, 21 | "minute": 0, 22 | "second": 0, 23 | "ms_E2": 0, 24 | "ms_E1": 0, 25 | "ms_E0": 0 26 | }, 27 | "min10": [ 28 | "1873-02-26", 29 | "1876-10-18", 30 | "1879-06-18", 31 | "1881-03-01", 32 | "1881-05-10", 33 | "1884-04-29", 34 | "1884-09-26", 35 | "1886-02-13", 36 | "1886-05-21", 37 | "1886-08-09" 38 | ], 39 | "max10": [ 40 | "2019-12-29", 41 | "2019-12-16", 42 | "2019-12-15", 43 | "2019-12-08", 44 | "2019-11-28", 45 | "2019-11-23", 46 | "2019-09-07", 47 | "2019-09-06", 48 | "2019-09-06", 49 | "2019-08-26" 50 | ], 51 | "encoding_type": "TABULAR_DATETIME" 52 | }, 53 | "bats": { 54 | "has_nan": false, 55 | "cnt_values": { 
56 | "": 361, 57 | "B": 363, 58 | "L": 1586, 59 | "R": 3690 60 | }, 61 | "encoding_type": "TABULAR_CATEGORICAL" 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-stats/part.000000-val.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "deathDate": { 4 | "has_nan": true, 5 | "min_values": { 6 | "year": 1873, 7 | "month": 1, 8 | "day": 1, 9 | "hour": 0, 10 | "minute": 0, 11 | "second": 0, 12 | "ms_E2": 0, 13 | "ms_E1": 0, 14 | "ms_E0": 0 15 | }, 16 | "max_values": { 17 | "year": 2019, 18 | "month": 12, 19 | "day": 31, 20 | "hour": 0, 21 | "minute": 0, 22 | "second": 0, 23 | "ms_E2": 0, 24 | "ms_E1": 0, 25 | "ms_E0": 0 26 | }, 27 | "min10": [ 28 | "1873-02-26", 29 | "1876-10-18", 30 | "1879-06-18", 31 | "1881-03-01", 32 | "1881-05-10", 33 | "1884-04-29", 34 | "1884-09-26", 35 | "1886-02-13", 36 | "1886-05-21", 37 | "1886-08-09" 38 | ], 39 | "max10": [ 40 | "2019-12-29", 41 | "2019-12-16", 42 | "2019-12-15", 43 | "2019-12-08", 44 | "2019-11-28", 45 | "2019-11-23", 46 | "2019-09-07", 47 | "2019-09-06", 48 | "2019-09-06", 49 | "2019-08-26" 50 | ], 51 | "encoding_type": "TABULAR_DATETIME" 52 | }, 53 | "bats": { 54 | "has_nan": false, 55 | "cnt_values": { 56 | "": 361, 57 | "B": 363, 58 | "L": 1586, 59 | "R": 3690 60 | }, 61 | "encoding_type": "TABULAR_CATEGORICAL" 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: "mostlyai-engine" 2 | site_url: "https://mostly-ai.github.io/mostlyai-engine/" 3 | repo_url: "https://github.com/mostly-ai/mostlyai-engine" 4 | repo_name: "mostly-ai/mostlyai-engine" 5 | 6 | theme: 7 | name: material 8 | logo: logo.png 9 | favicon: favicon.png 10 | font: 11 | text: Lato 12 | features: 13 | - navigation.top 14 | - navigation.tracking 15 | - navigation.tabs 16 | - navigation.tabs.sticky 17 | - content.code.select 18 | - content.code.copy 19 | - navigation.footer 20 | 21 | palette: 22 | - scheme: default 23 | toggle: 24 | icon: material/brightness-7 25 | name: Switch to dark mode 26 | - scheme: slate 27 | toggle: 28 | icon: material/brightness-2 29 | name: Switch to light mode 30 | 31 | nav: 32 | - Getting started: index.md 33 | - API Reference: api.md 34 | 35 | plugins: 36 | - search 37 | - mkdocstrings: 38 | handlers: 39 | python: 40 | options: 41 | heading_level: 3 42 | show_root_toc_entry: false 43 | show_root_heading: false 44 | show_object_full_path: true 45 | show_bases: false 46 | show_docstring: true 47 | show_source: false 48 | show_signature: true 49 | separate_signature: true 50 | show_docstring_examples: true 51 | docstring_section_style: table 52 | extensions: 53 | - griffe_fieldz 54 | docstring_style: google 55 | - llmstxt: 56 | full_output: llms-full.txt 57 | sections: 58 | Getting started: 59 | - index.md 60 | API Reference: 61 | - api.md 62 | 63 | markdown_extensions: 64 | - pymdownx.highlight: 65 | anchor_linenums: true 66 | line_spans: __span 67 | pygments_lang_class: true 68 | - pymdownx.inlinehilite 69 | - pymdownx.snippets 70 | - pymdownx.superfences 71 | - toc: 72 | permalink: true 73 | -------------------------------------------------------------------------------- /mostlyai/engine/_dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # 
Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pandas as pd 16 | import pyarrow as pa 17 | 18 | 19 | def is_string_dtype(x: pd.Series) -> bool: 20 | if isinstance(x.dtype, pd.ArrowDtype): 21 | return pa.types.is_string(x.dtype.pyarrow_dtype) 22 | else: 23 | return pd.api.types.is_string_dtype(x) 24 | 25 | 26 | def is_integer_dtype(x: pd.Series) -> bool: 27 | if isinstance(x.dtype, pd.ArrowDtype): 28 | return pa.types.is_integer(x.dtype.pyarrow_dtype) 29 | else: 30 | return pd.api.types.is_integer_dtype(x) 31 | 32 | 33 | def is_float_dtype(x: pd.Series) -> bool: 34 | if isinstance(x.dtype, pd.ArrowDtype): 35 | return pa.types.is_floating(x.dtype.pyarrow_dtype) 36 | else: 37 | return pd.api.types.is_float_dtype(x) 38 | 39 | 40 | def is_date_dtype(x: pd.Series) -> bool: 41 | if isinstance(x.dtype, pd.ArrowDtype): 42 | return pa.types.is_date(x.dtype.pyarrow_dtype) 43 | else: 44 | return False 45 | 46 | 47 | def is_timestamp_dtype(x: pd.Series) -> bool: 48 | if isinstance(x.dtype, pd.ArrowDtype): 49 | return pa.types.is_timestamp(x.dtype.pyarrow_dtype) 50 | else: 51 | return pd.api.types.is_datetime64_any_dtype(x) 52 | 53 | 54 | def is_boolean_dtype(x: pd.Series) -> bool: 55 | if isinstance(x.dtype, pd.ArrowDtype): 56 | return pa.types.is_boolean(x.dtype.pyarrow_dtype) 57 | else: 58 | return pd.api.types.is_bool_dtype(x) 59 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/ctx-stats/stats.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "deathDate": { 4 | "cardinalities": { 5 | "nan": 2, 6 | "year": 136, 7 | "month": 12, 8 | "day": 31 9 | }, 10 | "has_nan": true, 11 | "has_time": false, 12 | "has_ms": false, 13 | "min_values": { 14 | "year": 1884, 15 | "month": 1, 16 | "day": 1, 17 | "hour": 0, 18 | "minute": 0, 19 | "second": 0, 20 | "ms_E2": 0, 21 | "ms_E1": 0, 22 | "ms_E0": 0 23 | }, 24 | "max_values": { 25 | "year": 2019, 26 | "month": 12, 27 | "day": 31, 28 | "hour": 0, 29 | "minute": 0, 30 | "second": 0, 31 | "ms_E2": 0, 32 | "ms_E1": 0, 33 | "ms_E0": 0 34 | }, 35 | "min5": [ 36 | "1884-04-29", 37 | "1884-09-26", 38 | "1886-02-13", 39 | "1886-05-21", 40 | "1886-08-09" 41 | ], 42 | "max5": [ 43 | "2019-11-23", 44 | "2019-09-07", 45 | "2019-09-06", 46 | "2019-09-06", 47 | "2019-08-26" 48 | ], 49 | "encoding_type": "TABULAR_DATETIME", 50 | "tf_name": "c0" 51 | }, 52 | "bats": { 53 | "no_of_rare_categories": 0, 54 | "codes": { 55 | "_RARE_": 0, 56 | "": 1, 57 | "B": 2, 58 | "L": 3, 59 | "R": 4 60 | }, 61 | "cardinalities": { 62 | "cat": 5 63 | }, 64 | "encoding_type": "TABULAR_CATEGORICAL", 65 | "tf_name": "c1" 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /mostlyai/engine/_memory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 
2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | import re 18 | 19 | import psutil 20 | import torch 21 | 22 | _LOG = logging.getLogger(__name__) 23 | 24 | 25 | def get_available_vram_for_heuristics() -> int: 26 | if not torch.cuda.is_available(): 27 | return 0 28 | free, total = torch.cuda.mem_get_info() 29 | return total 30 | 31 | 32 | def get_available_ram_for_heuristics() -> int: 33 | mem_limit = extract_memory_from_string(os.getenv("MOSTLY_ENGINE_AVAILABLE_RAM_FOR_HEURISTICS", default=None)) 34 | if mem_limit is None: 35 | mem_limit = psutil.virtual_memory().available 36 | return mem_limit 37 | 38 | 39 | def extract_memory_from_string(memory_str: str | None = None) -> int | None: 40 | """ 41 | Extract the memory in bytes from a string. 42 | 43 | :param memory_str: The memory string to extract the memory from. 44 | :return: The memory in bytes. 45 | """ 46 | if not memory_str: 47 | return None 48 | 49 | # Conversion factors, considering metric (decimal) vs. binary (IEC) units 50 | units = { 51 | "": 1, 52 | "b": 1, 53 | "k": 1024, 54 | "m": 1024**2, 55 | "g": 1024**3, 56 | "t": 1024**4, 57 | } 58 | match = re.match(r"(\d+(?:\.\d+)?)[ ]?([a-z]?)", memory_str.strip().lower()) 59 | if not match: 60 | return None 61 | 62 | value, unit = match.groups() 63 | value = float(value) 64 | 65 | # Convert to bytes 66 | if unit in units: 67 | return int(value * units[unit]) 68 | else: 69 | return None 70 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Synthetic Data Engine 2 | 3 | Thanks for your interest in contributing to Synthetic Data Engine! Follow these guidelines to set up your environment and streamline your contributions. 4 | 5 | ## Setup 6 | 7 | 1. **Clone the repository**: 8 | ```bash 9 | git clone https://github.com/mostly-ai/mostlyai-engine.git 10 | cd mostlyai-engine 11 | ``` 12 | If you don’t have direct write access to `mostlyai-engine`, fork the repository first and clone your fork: 13 | ```bash 14 | git clone https://github.com//mostlyai-engine.git 15 | cd mostlyai-engine 16 | ``` 17 | 18 | 2. **Install `uv` (if not installed already)**: 19 | ```bash 20 | curl -LsSf https://astral.sh/uv/install.sh | sh 21 | ``` 22 | For alternative installation methods, visit the [uv installation guide](https://docs.astral.sh/uv/getting-started/installation/). 23 | 24 | 3. **Create a virtual environment and install dependencies**: 25 | ```bash 26 | uv sync --frozen --extra cpu --python=3.10 # For CPU-only 27 | source .venv/bin/activate 28 | ``` 29 | If using GPU, run: 30 | ```bash 31 | uv sync --frozen --extra gpu --python=3.10 # For GPU support 32 | source .venv/bin/activate 33 | ``` 34 | 35 | 4. **Install pre-commit hooks**: 36 | ```bash 37 | pre-commit install 38 | ``` 39 | 40 | ## Development Workflow 41 | 42 | 1. 
**Ensure your local `main` branch is up to date**: 43 | ```bash 44 | git checkout main 45 | git reset --hard origin/main 46 | git pull origin main 47 | ``` 48 | 49 | 2. **Create a new feature or bugfix branch**: 50 | ```bash 51 | git checkout -b my-feature-branch 52 | ``` 53 | 54 | 3. **Implement your changes.** 55 | 56 | 4. **Run tests and pre-commit hooks**: 57 | ```bash 58 | pytest 59 | pre-commit run 60 | ``` 61 | 62 | 5. **Commit your changes with a descriptive message**: 63 | ```bash 64 | git add . 65 | git commit -m "feat: add a clear description of your feature" 66 | ``` 67 | Follow the [Conventional Commits](https://gist.github.com/qoomon/5dfcdf8eec66a051ecd85625518cfd13) format. 68 | 69 | 6. **Push your changes**: 70 | ```bash 71 | git push origin my-feature-branch 72 | ``` 73 | 74 | 7. **Open a pull request on GitHub.** 75 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/engine/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from dataclasses import dataclass 17 | 18 | from pydantic import BaseModel 19 | 20 | 21 | @dataclass 22 | class EngineMetrics: 23 | tokenize_time: float 24 | generate_time: float 25 | 26 | 27 | class LanguageEngine(ABC): 28 | @abstractmethod 29 | def generate( 30 | self, text: list[str], sampling_temperature: float, sampling_top_p: float 31 | ) -> tuple[list[int], EngineMetrics]: 32 | pass 33 | 34 | @abstractmethod 35 | def get_default_batch_size(self) -> int: 36 | pass 37 | 38 | @abstractmethod 39 | def supports_json_enforcing(self) -> bool: 40 | pass 41 | 42 | @abstractmethod 43 | def cleanup(self): 44 | pass 45 | 46 | def update_json_constraints(self, schemas: list[BaseModel] | None) -> None: 47 | """Update JSON schema constraints for the next generation call. 48 | 49 | Args: 50 | schemas: Schema constraints to apply to the next generate() call. 51 | None to clear any existing constraints. 52 | 53 | Default implementation does nothing. 54 | Engines that support JSON constraints should override this method. 55 | """ 56 | pass 57 | 58 | def can_reuse_schemas(self) -> bool: 59 | """Whether the engine can reuse JSON schema constraints across batches with different sizes. 60 | 61 | Returns: 62 | True if the engine can handle variable batch sizes with reused schema constraints, 63 | False if schema constraints need to be recreated for each batch. 64 | """ 65 | return True 66 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/categorical.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Categorical encoding for language models. 17 | """ 18 | 19 | import pandas as pd 20 | 21 | from mostlyai.engine._common import STRING, safe_convert_string 22 | from mostlyai.engine._encoding_types.tabular.categorical import analyze_categorical, analyze_reduce_categorical 23 | 24 | CATEGORICAL_UNKNOWN_TOKEN = "_RARE_" 25 | 26 | 27 | def analyze_language_categorical(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 28 | return analyze_categorical(values, root_keys, _, safe_escape=False) 29 | 30 | 31 | def analyze_reduce_language_categorical( 32 | stats_list: list[dict], 33 | value_protection: bool = True, 34 | value_protection_epsilon: float | None = None, 35 | ) -> dict: 36 | stats = analyze_reduce_categorical(stats_list, value_protection, value_protection_epsilon) 37 | stats["categories"] = list(stats["codes"].keys()) 38 | if any([j["has_nan"] for j in stats_list]): 39 | # when has_nan, tabular stats are like [CATEGORICAL_UNKNOWN_TOKEN, CATEGORICAL_NULL_TOKEN, ...] 40 | # and we need to replace CATEGORICAL_NULL_TOKEN with None for language 41 | stats["categories"][1] = None 42 | # drop tabular stats 43 | stats.pop("codes") 44 | stats.pop("cardinalities") 45 | return stats 46 | 47 | 48 | def encode_language_categorical(values: pd.Series, stats: dict) -> pd.Series: 49 | values = safe_convert_string(values) 50 | values = values.copy() 51 | known_categories = stats["categories"] 52 | mask = ~values.isin(known_categories) 53 | if None in known_categories: 54 | mask &= ~pd.isna(values) 55 | values[mask] = CATEGORICAL_UNKNOWN_TOKEN 56 | return values 57 | 58 | 59 | def decode_language_categorical(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: 60 | x = x.astype(STRING) 61 | allowed_categories = col_stats.get("categories", []) 62 | return x.where(x.isin(allowed_categories), other=None) 63 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-stats/part.000000-trn.json: -------------------------------------------------------------------------------- 1 | { 2 | "no_of_training_records": 5400, 3 | "no_of_validation_records": 0, 4 | "seq_len": { 5 | "cnt_lengths": { 6 | "1": 6000 7 | } 8 | }, 9 | "columns": { 10 | "desc": { 11 | "has_na": false, 12 | "cnt_values": { 13 | "a": 0, 14 | "d": 0, 15 | "e": 0, 16 | "f": 0, 17 | "g": 0, 18 | "h": 0, 19 | "i": 0, 20 | "l": 0, 21 | "n": 0, 22 | "p": 0, 23 | "r": 0, 24 | "s": 0, 25 | "t": 0, 26 | "w": 0, 27 | "y": 0, 28 | "▁": 0, 29 | "an": 0, 30 | "ay": 0, 31 | "de": 0, 32 | "er": 0, 33 | "han": 0, 34 | "lay": 0, 35 | "play": 0, 36 | "▁han": 0, 37 | "▁play": 0, 38 | "ded": 0, 39 | "▁handed": 6000, 40 | "▁player": 6000, 41 | "gh": 0, 42 | "igh": 0, 43 | "righ": 0, 44 | "▁righ": 0, 45 | "▁right": 3690, 46 | "is": 0, 47 | "▁is": 3062, 48 | "as": 0, 49 | "was": 0, 50 | "▁was": 2938, 51 | "ef": 0, 52 | "lef": 0, 53 | "▁lef": 0, 54 | "▁left": 2310 55 | }, 56 | "cnt_lengths": { 57 | "4": 6000 58 | }, 59 | "merges": [ 60 | "a n", 61 | "a y", 62 | "d e", 63 | "e r", 64 | "h an", 65 | "l ay", 66 | "p 
lay", 67 | "▁ han", 68 | "▁ play", 69 | "de d", 70 | "▁han ded", 71 | "▁play er", 72 | "g h", 73 | "i gh", 74 | "r igh", 75 | "▁ righ", 76 | "▁righ t", 77 | "i s", 78 | "▁ is", 79 | "a s", 80 | "w as", 81 | "▁ was", 82 | "e f", 83 | "l ef", 84 | "▁ lef", 85 | "▁lef t" 86 | ], 87 | "encoding_type": "LANGUAGE_TEXT" 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-stats/part.000000-val.json: -------------------------------------------------------------------------------- 1 | { 2 | "no_of_training_records": 0, 3 | "no_of_validation_records": 600, 4 | "seq_len": { 5 | "cnt_lengths": { 6 | "1": 6000 7 | } 8 | }, 9 | "columns": { 10 | "desc": { 11 | "has_na": false, 12 | "cnt_values": { 13 | "a": 0, 14 | "d": 0, 15 | "e": 0, 16 | "f": 0, 17 | "g": 0, 18 | "h": 0, 19 | "i": 0, 20 | "l": 0, 21 | "n": 0, 22 | "p": 0, 23 | "r": 0, 24 | "s": 0, 25 | "t": 0, 26 | "w": 0, 27 | "y": 0, 28 | "▁": 0, 29 | "an": 0, 30 | "ay": 0, 31 | "de": 0, 32 | "er": 0, 33 | "han": 0, 34 | "lay": 0, 35 | "play": 0, 36 | "▁han": 0, 37 | "▁play": 0, 38 | "ded": 0, 39 | "▁handed": 6000, 40 | "▁player": 6000, 41 | "gh": 0, 42 | "igh": 0, 43 | "righ": 0, 44 | "▁righ": 0, 45 | "▁right": 3690, 46 | "is": 0, 47 | "▁is": 3062, 48 | "as": 0, 49 | "was": 0, 50 | "▁was": 2938, 51 | "ef": 0, 52 | "lef": 0, 53 | "▁lef": 0, 54 | "▁left": 2310 55 | }, 56 | "cnt_lengths": { 57 | "4": 6000 58 | }, 59 | "merges": [ 60 | "a n", 61 | "a y", 62 | "d e", 63 | "e r", 64 | "h an", 65 | "l ay", 66 | "p lay", 67 | "▁ han", 68 | "▁ play", 69 | "de d", 70 | "▁han ded", 71 | "▁play er", 72 | "g h", 73 | "i gh", 74 | "r igh", 75 | "▁ righ", 76 | "▁righ t", 77 | "i s", 78 | "▁ is", 79 | "a s", 80 | "w as", 81 | "▁ was", 82 | "e f", 83 | "l ef", 84 | "▁ lef", 85 | "▁lef t" 86 | ], 87 | "encoding_type": "LANGUAGE_TEXT" 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-cpu.yaml: -------------------------------------------------------------------------------- 1 | name: '[CPU] mostlyai-engine Tests' 2 | 3 | on: 4 | workflow_call: 5 | 6 | env: 7 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 8 | FORCE_COLOR: '1' 9 | 10 | jobs: 11 | run-tests-cpu-unit-sequential: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | contents: read 15 | packages: write 16 | steps: 17 | - name: Setup | Checkout 18 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 19 | with: 20 | fetch-depth: 0 21 | submodules: 'recursive' 22 | 23 | - name: Setup | uv 24 | uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v7.1.5 25 | with: 26 | enable-cache: false 27 | python-version: '3.10' 28 | 29 | - name: Setup | Dependencies 30 | run: | 31 | uv sync --frozen --only-group dev --only-group docs 32 | uv pip install --index-strategy unsafe-first-match torch==2.9.1+cpu torchvision==0.24.1+cpu . 
--extra-index-url https://download.pytorch.org/whl/cpu 33 | 34 | - name: Run | Tests -> unit 35 | run: uv run --no-sync pytest tests/unit 36 | 37 | - name: Build mkdocs 38 | run: uv run --no-sync mkdocs build --strict 39 | 40 | - name: Run tests -> end_to_end -> sequential 41 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py 42 | 43 | - name: Run tests -> end_to_end -> sequential context 44 | run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py 45 | 46 | run-tests-cpu-end-to-end-nonsequential: 47 | runs-on: ubuntu-latest 48 | permissions: 49 | contents: read 50 | packages: write 51 | steps: 52 | - name: Setup | Checkout 53 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 54 | with: 55 | fetch-depth: 0 56 | submodules: 'recursive' 57 | 58 | - name: Setup | uv 59 | uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v7.1.5 60 | with: 61 | enable-cache: false 62 | python-version: '3.10' 63 | 64 | - name: Setup | Dependencies 65 | run: | 66 | uv sync --frozen --only-group dev 67 | uv pip install --index-strategy unsafe-first-match torch==2.9.1+cpu torchvision==0.24.1+cpu . --extra-index-url https://download.pytorch.org/whl/cpu 68 | 69 | - name: Run tests -> end_to_end all except sequential 70 | run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ 71 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mostlyai-engine" 3 | version = "2.3.3" 4 | description = "Synthetic Data Engine" 5 | authors = [{ name = "MOSTLY AI", email = "dev@mostly.ai" }] 6 | requires-python = ">=3.10" 7 | readme = "README.md" 8 | license = "Apache-2.0" 9 | classifiers = [ 10 | "Development Status :: 5 - Production/Stable", 11 | "Intended Audience :: Developers", 12 | "Intended Audience :: Science/Research", 13 | "Intended Audience :: Information Technology", 14 | "Intended Audience :: Financial and Insurance Industry", 15 | "Intended Audience :: Healthcare Industry", 16 | "Intended Audience :: Telecommunications Industry", 17 | "Programming Language :: Python :: 3.10", 18 | "Programming Language :: Python :: 3.11", 19 | "Programming Language :: Python :: 3.12", 20 | "Programming Language :: Python :: 3.13", 21 | "License :: OSI Approved :: Apache Software License", 22 | "Operating System :: OS Independent", 23 | "Topic :: Software Development :: Libraries", 24 | "Typing :: Typed", 25 | ] 26 | 27 | dependencies = [ 28 | "setuptools>=77.0.3", 29 | "numpy>=2.0.0", 30 | "pandas>=2.2.0", 31 | "pyarrow>=16.0.0", 32 | "joblib>=1.4.2", 33 | "scikit-learn>=1.4.0", 34 | "psutil>=5.9.5,<6", # upgrade when colab psutil is updated 35 | "tokenizers>=0.21.0", 36 | "transformers>=4.55.0", 37 | "datasets>=3.0.0", 38 | "accelerate>=1.5.0", 39 | "peft>=0.12.0", 40 | "huggingface-hub[hf-xet]>=0.30.2", 41 | "opacus>=1.5.4", 42 | "xgrammar>=0.1.21", 43 | "json-repair>=0.47.0", 44 | "torch>=2.9.0,<2.10.0", 45 | "torchaudio>=2.9.0,<2.10.0", 46 | "torchvision>=0.24.0,<0.25.0" 47 | ] 48 | 49 | [project.optional-dependencies] 50 | gpu = [ 51 | "bitsandbytes==0.42.0; sys_platform == 'darwin'", 52 | "bitsandbytes>=0.45.5; sys_platform == 'linux'", 53 | "vllm==0.12.0; sys_platform == 'linux' or sys_platform == 'darwin'", 54 | ] 55 | 56 | [dependency-groups] 57 | dev = [ 58 | "pytest>=8.0", 59 | 
"pytest-rerunfailures>=15.0", 60 | "ruff>=0.11", # sync'ed with .pre-commit-config 61 | "pre-commit>=4.0", 62 | "twine>=6.1", 63 | "ipykernel>=6.25", 64 | ] 65 | docs = [ 66 | "mkdocs>=1.6", 67 | "mkdocstrings[crystal, python]>=0.29", 68 | "mkdocs-material>=9.0", 69 | "mkdocs-llmstxt>=0.2", 70 | "griffe>=1.0", 71 | "pymdown-extensions>=10.0", 72 | "griffe-fieldz>=0.2", 73 | "black>=25.0", 74 | ] 75 | 76 | [project.urls] 77 | homepage = "https://github.com/mostly-ai/mostlyai-engine" 78 | repository = "https://github.com/mostly-ai/mostlyai-engine" 79 | documentation = "https://mostly-ai.github.io/mostlyai-engine/" 80 | 81 | [tool.uv] 82 | default-groups = ["dev", "docs"] 83 | 84 | [tool.hatch.build.targets.sdist] 85 | include = ["mostlyai/engine"] 86 | 87 | [tool.hatch.build.targets.wheel] 88 | include = ["mostlyai/engine"] 89 | 90 | [tool.hatch.metadata] 91 | allow-direct-references = true 92 | 93 | [build-system] 94 | requires = ["hatchling", "hatch-vcs"] 95 | build-backend = "hatchling.build" 96 | 97 | [tool.ruff] 98 | target-version = "py310" 99 | line-length = 120 100 | [tool.ruff.format] 101 | exclude = ["examples/*.ipynb"] 102 | [tool.ruff.lint] 103 | extend-select = ["I"] 104 | -------------------------------------------------------------------------------- /examples/language.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "tags": [] 7 | }, 8 | "source": [ 9 | "# Language Model: flat data, without context\n", 10 | "\n", 11 | "**Note**: The default model is `MOSTLY_AI/LSTMFromScratch-3m`, a lightweight LSTM model trained from scratch (**GPU strongly recommended**). You can also use pre-trained HuggingFace models by setting e.g. 
`model=\"microsoft/phi-1.5\"` (**GPU required**).\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/language.ipynb)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import pandas as pd\n", 28 | "from mostlyai.engine import LanguageModel\n", 29 | "\n", 30 | "# load original data\n", 31 | "url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/arxiv\"\n", 32 | "trn_df = pd.read_parquet(f\"{url}/synthetic-data-papers.parquet\")[['category', 'title']]\n", 33 | "\n", 34 | "# create and fit the model\n", 35 | "lm = LanguageModel(\n", 36 | " model=\"MOSTLY_AI/LSTMFromScratch-3m\", # use a light-weight LSTM model, trained from scratch (GPU recommended)\n", 37 | " # model=\"microsoft/phi-1.5\", # or alternatively use a HF-hosted LLM model (GPU required)\n", 38 | " max_training_time=10, # limit training to 10 minutes for demo purposes\n", 39 | " tgt_encoding_types={\n", 40 | " 'category': 'LANGUAGE_CATEGORICAL',\n", 41 | " 'title': 'LANGUAGE_TEXT',\n", 42 | " },\n", 43 | " verbose=1,\n", 44 | ")\n", 45 | "lm.fit(trn_df)\n", 46 | "\n", 47 | "# generate synthetic samples\n", 48 | "syn_tgt_df = lm.sample(n_samples=100)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "syn_tgt_df.head(5)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [] 66 | } 67 | ], 68 | "metadata": { 69 | "kernelspec": { 70 | "display_name": ".venv", 71 | "language": "python", 72 | "name": "python3" 73 | }, 74 | "language_info": { 75 | "codemirror_mode": { 76 | "name": "ipython", 77 | "version": 3 78 | }, 79 | "file_extension": ".py", 80 | "mimetype": "text/x-python", 81 | "name": "python", 82 | "nbconvert_exporter": "python", 83 | "pygments_lexer": "ipython3", 84 | "version": "3.12.8" 85 | }, 86 | "toc": { 87 | "base_numbering": 1, 88 | "nav_menu": {}, 89 | "number_sections": false, 90 | "sideBar": true, 91 | "skip_h1_title": false, 92 | "title_cell": "Table of Contents", 93 | "title_sidebar": "Contents", 94 | "toc_cell": false, 95 | "toc_position": {}, 96 | "toc_section_display": true, 97 | "toc_window_display": false 98 | } 99 | }, 100 | "nbformat": 4, 101 | "nbformat_minor": 4 102 | } 103 | -------------------------------------------------------------------------------- /tests/unit/fixtures/workspace/all/ModelStore/tgt-stats/stats.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": { 3 | "desc": { 4 | "cardinalities": { 5 | "tokens": 43 6 | }, 7 | "has_na": false, 8 | "tokens": [ 9 | "▁handed", 10 | "▁player", 11 | "▁right", 12 | "▁is", 13 | "▁was", 14 | "▁left", 15 | "a", 16 | "d", 17 | "e", 18 | "f", 19 | "g", 20 | "h", 21 | "i", 22 | "l", 23 | "n", 24 | "p", 25 | "r", 26 | "s", 27 | "t", 28 | "w", 29 | "y", 30 | "▁", 31 | "an", 32 | "ay", 33 | "de", 34 | "er", 35 | "han", 36 | "lay", 37 | "play", 38 | "▁han", 39 | "▁play", 40 | "ded", 41 | "gh", 42 | "igh", 43 | "righ", 44 | "▁righ", 45 | "is", 46 | "as", 47 | "was", 48 | "ef", 49 | "lef", 50 | "▁lef" 51 | ], 52 | "merges": [ 53 | "a n", 54 | "a y", 55 | "d e", 56 | "e r", 57 | "h an", 58 | "l ay", 59 | "p lay", 60 | "▁ han", 
61 | "▁ play", 62 | "de d", 63 | "▁han ded", 64 | "▁play er", 65 | "g h", 66 | "i gh", 67 | "r igh", 68 | "▁ righ", 69 | "▁righ t", 70 | "i s", 71 | "▁ is", 72 | "a s", 73 | "w as", 74 | "▁ was", 75 | "e f", 76 | "l ef", 77 | "▁ lef", 78 | "▁lef t" 79 | ], 80 | "seq_len": { 81 | "min": 4, 82 | "max": 4, 83 | "median": 4, 84 | "deciles": [ 85 | 4, 86 | 4, 87 | 4, 88 | 4, 89 | 4, 90 | 4, 91 | 4, 92 | 4, 93 | 4, 94 | 4, 95 | 4 96 | ] 97 | }, 98 | "encoding_type": "LANGUAGE_TEXT" 99 | } 100 | }, 101 | "no_of_training_records": 5400, 102 | "no_of_validation_records": 600, 103 | "seq_len": { 104 | "min": 1, 105 | "max": 1, 106 | "median": 1, 107 | "deciles": [ 108 | 1, 109 | 1, 110 | 1, 111 | 1, 112 | 1, 113 | 1, 114 | 1, 115 | 1, 116 | 1, 117 | 1, 118 | 1 119 | ] 120 | }, 121 | "is_sequential": false, 122 | "has_sequential_columns": true 123 | } 124 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/test_categorical.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | from mostlyai.engine._encoding_types.language.categorical import ( 20 | CATEGORICAL_UNKNOWN_TOKEN, 21 | analyze_language_categorical, 22 | analyze_reduce_language_categorical, 23 | decode_language_categorical, 24 | encode_language_categorical, 25 | ) 26 | 27 | 28 | class TestLanguageCategoricalAnalyze: 29 | def test_3_frequent_and_1_rare_values(self): 30 | values = pd.Series(np.repeat(["secret", "male", "female", pd.NA], 100), name="gender") 31 | ids = pd.Series( 32 | np.concatenate([np.repeat(0, 100), range(100), range(100, 200), range(200, 300)]), 33 | name="subject_id", 34 | ) 35 | stats = analyze_language_categorical(values, ids) 36 | assert stats == { 37 | "cnt_values": {"female": 100, "male": 100, "secret": 1}, 38 | "has_nan": True, 39 | } 40 | 41 | 42 | class TestLanguageCategoricalAnalyzeReduce: 43 | @pytest.fixture 44 | def stats_list(self): 45 | stats1 = { 46 | "cnt_values": {"secret1": 1, "male": 100}, 47 | "has_nan": True, 48 | } 49 | stats2 = { 50 | "cnt_values": {"secret2": 1, "male": 100, "female": 100}, 51 | "has_nan": False, 52 | } 53 | return stats1, stats2 54 | 55 | def test_with_value_protection(self, stats_list): 56 | stats1, stats2 = stats_list 57 | stats = analyze_reduce_language_categorical([stats1, stats2], value_protection=True) 58 | assert stats == { 59 | "categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "female", "male"], 60 | "no_of_rare_categories": 2, 61 | } 62 | 63 | 64 | class TestLanguageCategoricalEncode: 65 | def test_2_frequent_and_1_rare_and_1_null_values(self): 66 | values = pd.Series(np.repeat(["secret", "male", "female", pd.NA], 100), name="gender") 67 | stats = { 68 | "categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "female", "male"], 69 | "no_of_rare_categories": 1, 70 | } 71 | expected = pd.Series( 
72 | np.repeat([CATEGORICAL_UNKNOWN_TOKEN, "male", "female", pd.NA], 100), name="gender", dtype="string" 73 | ) 74 | encoded = encode_language_categorical(values, stats) 75 | pd.testing.assert_series_equal(encoded, expected) 76 | 77 | 78 | class TestLanguageCategoricalDecode: 79 | @pytest.fixture 80 | def col_stats(self): 81 | return {"categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "apple", "banana", "cherry"]} 82 | 83 | @pytest.fixture 84 | def sample_values(self): 85 | return pd.Series(["apple", "durian", "banana", "elderberry", "cherry", "fig", None]) 86 | 87 | def test_language_categorical_decode(self, sample_values, col_stats): 88 | decoded = decode_language_categorical(sample_values, col_stats) 89 | expected = pd.Series(["apple", None, "banana", None, "cherry", None, None], dtype=decoded.dtype) 90 | pd.testing.assert_series_equal(decoded, expected) 91 | -------------------------------------------------------------------------------- /tests/end_to_end/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import pandas as pd 18 | import pytest 19 | from joblib.externals.loky import get_reusable_executor 20 | 21 | from mostlyai.engine._common import STRING 22 | 23 | 24 | @pytest.fixture() 25 | def cleanup_joblib_pool(): 26 | # make sure the test is using a fresh joblib pool 27 | get_reusable_executor().shutdown(wait=True) 28 | yield 29 | get_reusable_executor().shutdown(wait=True) 30 | 31 | 32 | class MockData: 33 | def __init__(self, n_samples: int): 34 | self.n_samples = n_samples 35 | self.df = pd.DataFrame(index=range(self.n_samples)) 36 | 37 | def add_index_column(self, name: str): 38 | values = pd.DataFrame({name: range(len(self.df))}).astype(STRING) 39 | self.df = pd.concat([self.df, values], axis=1) 40 | 41 | def add_categorical_column( 42 | self, name: str, probabilities: dict[str, float], rare_categories: list[str] | None = None 43 | ): 44 | values = np.random.choice( 45 | list(probabilities.keys()), 46 | size=len(self.df), 47 | p=list(probabilities.values()), 48 | ) 49 | self.df = pd.concat([self.df, pd.DataFrame({name: values})], axis=1) 50 | if rare_categories: 51 | self.df.loc[np.random.choice(self.df.index, len(rare_categories), replace=False), name] = rare_categories 52 | 53 | def add_numeric_column(self, name: str, quantiles: dict[float, float], dtype: str = "float32"): 54 | uniform_samples = np.random.rand(len(self.df)) 55 | values = np.interp(uniform_samples, list(quantiles.keys()), list(quantiles.values())).astype(dtype) 56 | self.df = pd.concat([self.df, pd.DataFrame({name: values})], axis=1) 57 | 58 | def add_datetime_column(self, name: str, start_date: str, end_date: str, freq: str = "s"): 59 | date_range = pd.date_range(start=start_date, end=end_date, freq=freq) 60 | values = np.random.choice(date_range, len(self.df), replace=True) 61 | self.df = pd.concat([self.df, 
pd.DataFrame({name: values})], axis=1) 62 | 63 | def add_date_column(self, name: str, start_date: str, end_date: str): 64 | self.add_datetime_column(name, start_date, end_date, freq="D") 65 | 66 | def add_lat_long_column(self, name: str, lat_limit: tuple[float, float], long_limit: tuple[float, float]): 67 | latitude = np.random.uniform(lat_limit[0], lat_limit[1], len(self.df)) 68 | longitude = np.random.uniform(long_limit[0], long_limit[1], len(self.df)) 69 | values = [f"{lat:.4f}, {long:.4f}" for lat, long in zip(latitude, longitude)] 70 | self.df = pd.concat([self.df, pd.DataFrame({name: values})], axis=1) 71 | 72 | def add_sequential_column(self, name: str, seq_len_quantiles: dict[float, float]): 73 | self.add_numeric_column("seq_len", seq_len_quantiles, dtype="int32") 74 | # if seq_len is 3, it will populate a sequence ["0", "1", "2"] and then explode the list to 3 rows 75 | self.df[name] = self.df["seq_len"].apply(lambda x: [str(i) for i in range(x)]) 76 | self.df = self.df.explode(name).drop(columns="seq_len").reset_index(drop=True) 77 | -------------------------------------------------------------------------------- /tests/unit/test_fairness.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
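# Note (illustrative reading of the fixture stats used below): each sensitive column appears to
# contribute one group per known category, plus one extra group when it has rare categories folded
# into the _RARE_ bucket. E.g. c1_cat (3 categories, no rare) combined with c2_cat (5 categories
# + 1 rare bucket) yields 3 * (5 + 1) = 18 sensitive groups, while c1_cat combined with c4_cat
# (7 categories, no rare) yields 3 * 7 = 21.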
14 | 15 | import pytest 16 | 17 | from mostlyai.engine._common import ARGN_COLUMN, ARGN_PROCESSOR, ARGN_TABLE, get_argn_name 18 | from mostlyai.engine._encoding_types.tabular.categorical import CATEGORICAL_SUB_COL_SUFFIX 19 | from mostlyai.engine._tabular.fairness import _get_sensitive_groups 20 | from mostlyai.engine.domain import ModelEncodingType 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def tgt_stats(): 25 | return { 26 | "columns": { 27 | "c0_cat": { 28 | "encoding_type": ModelEncodingType.tabular_categorical.value, 29 | "argn_processor": "tgt", 30 | "argn_table": "t0", 31 | "argn_column": "c0", 32 | "no_of_rare_categories": 0, 33 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(2)}}, 34 | }, 35 | "c1_cat": { 36 | "encoding_type": ModelEncodingType.tabular_categorical.value, 37 | "argn_processor": "tgt", 38 | "argn_table": "t0", 39 | "argn_column": "c1", 40 | "no_of_rare_categories": 0, 41 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(3)}}, 42 | }, 43 | "c2_cat": { 44 | "encoding_type": ModelEncodingType.tabular_categorical.value, 45 | "argn_processor": "tgt", 46 | "argn_table": "t0", 47 | "argn_column": "c2", 48 | "no_of_rare_categories": 1, 49 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(5)}}, 50 | }, 51 | "c3_num": { 52 | "encoding_type": ModelEncodingType.tabular_numeric_auto.value, 53 | "argn_processor": "tgt", 54 | "argn_table": "t0", 55 | "argn_column": "c3", 56 | }, 57 | "c4_cat": { 58 | "encoding_type": ModelEncodingType.tabular_categorical.value, 59 | "argn_processor": "tgt", 60 | "argn_table": "t0", 61 | "argn_column": "c4", 62 | "no_of_rare_categories": 0, 63 | "codes": {"_RARE_": 0, **{f"cat_{i}": i + 1 for i in range(7)}}, 64 | }, 65 | } 66 | } 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "target_column, sensitive_columns, expected_n_rows", 71 | [ 72 | ("c0_cat", ["c1_cat", "c2_cat"], 18), # 3 * (5+1) 73 | ("c0_cat", ["c1_cat", "c4_cat"], 21), # 3 * 7 74 | ], 75 | ) 76 | def test_get_sensitive_category_groups(tgt_stats, target_column, sensitive_columns, expected_n_rows): 77 | column_stats = tgt_stats["columns"] 78 | sensitive_sub_cols = [ 79 | get_argn_name( 80 | argn_processor=tgt_stats["columns"][col][ARGN_PROCESSOR], 81 | argn_table=tgt_stats["columns"][col][ARGN_TABLE], 82 | argn_column=tgt_stats["columns"][col][ARGN_COLUMN], 83 | argn_sub_column=CATEGORICAL_SUB_COL_SUFFIX, 84 | ) 85 | for col in sensitive_columns 86 | ] 87 | groups_df = _get_sensitive_groups(column_stats, sensitive_columns, sensitive_sub_cols) 88 | assert groups_df.shape[0] == expected_n_rows 89 | -------------------------------------------------------------------------------- /tests/unit/test_workspace.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
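# The tests below run against the bundled fixture workspaces; roughly (as asserted further down),
# a fully populated workspace follows this layout:
#
#   {workspace}/OriginalData/{tgt-data, ctx-data, encoded-data}/part.000000-{trn,val}.parquet
#   {workspace}/ModelStore/{tgt-stats, ctx-stats}/part.000000-{trn,val}.json plus stats.json
#   {workspace}/ModelStore/model-data/             (model weights and configs)
#   {workspace}/SyntheticData/part.00000N.parquet  (generated output)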
14 | 15 | import os 16 | from pathlib import Path 17 | from unittest import mock 18 | 19 | from mostlyai.engine._common import read_json, write_json 20 | from mostlyai.engine._workspace import Workspace 21 | 22 | FIXTURES_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")) 23 | 24 | 25 | class TestWorkspace: 26 | def test_workspace_all_objects(self): 27 | ws_path = Path(FIXTURES_PATH) / "workspace" / "all" 28 | ws = Workspace(ws_path) 29 | 30 | # Split-related 31 | assert ws.tgt_data_path == ws_path / "OriginalData" / "tgt-data" 32 | tgt_data_file_names = [i.name for i in ws.tgt_data.fetch_all()] 33 | assert tgt_data_file_names == ["part.000000-trn.parquet", "part.000000-val.parquet"] 34 | assert isinstance(ws.tgt_encoding_types.read(), dict) 35 | assert isinstance(ws.tgt_keys.read(), dict) 36 | 37 | assert ws.ctx_data_path == ws_path / "OriginalData" / "ctx-data" 38 | ctx_data_file_names = [i.name for i in ws.ctx_data.fetch_all()] 39 | assert ctx_data_file_names == ["part.000000-trn.parquet", "part.000000-val.parquet"] 40 | assert isinstance(ws.ctx_encoding_types.read(), dict) 41 | assert isinstance(ws.ctx_keys.read(), dict) 42 | 43 | # Analyze-related 44 | assert ws.tgt_stats_path == Path(ws_path) / "ModelStore" / "tgt-stats" 45 | tgt_all_stats_file_names = [i.name for i in ws.tgt_all_stats.fetch_all()] 46 | assert tgt_all_stats_file_names == ["part.000000-trn.json", "part.000000-val.json"] 47 | assert isinstance(ws.tgt_stats.read(), dict) 48 | assert ws.ctx_stats_path == Path(ws_path) / "ModelStore" / "ctx-stats" 49 | ctx_all_stats_file_names = [i.name for i in ws.ctx_all_stats.fetch_all()] 50 | assert ctx_all_stats_file_names == ["part.000000-trn.json", "part.000000-val.json"] 51 | assert isinstance(ws.ctx_stats.read(), dict) 52 | 53 | # Encode-related 54 | assert ws.encoded_data_path == Path(ws_path) / "OriginalData" / "encoded-data" 55 | assert len(ws.encoded_data_val.fetch_all()) == 1 56 | assert len(ws.encoded_data_trn.fetch_all()) == 1 57 | 58 | # Train-related 59 | assert ws.model_path == Path(ws_path) / "ModelStore" / "model-data" 60 | assert ws.model_tabular_weights_path.exists() 61 | assert isinstance(ws.model_configs.read(), dict) 62 | 63 | # Generate-related 64 | assert ws.generated_data_path == Path(ws_path) / "SyntheticData" 65 | generated_data_file_names = [i.name for i in ws.generated_data.fetch_all()] 66 | assert generated_data_file_names == ["part.000001.parquet", "part.000002.parquet"] 67 | 68 | def test_read_write_json(self): 69 | ws_path = Path(FIXTURES_PATH) / "workspace" / "some" 70 | ws = Workspace(ws_path) 71 | 72 | assert ws.tgt_keys.read_handler == read_json 73 | ws.tgt_keys.read() == {"context_key": "__primary_key"} 74 | assert ws.tgt_keys.write_handler == write_json 75 | with mock.patch.object(ws.tgt_keys, "write_handler") as write_mock: 76 | new_key_data = {"new_key": "test_key"} 77 | ws.tgt_keys.write(new_key_data) 78 | assert write_mock.call_args[0] == ( 79 | new_key_data, 80 | ws_path / "OriginalData" / "tgt-meta" / "keys.json", 81 | ) 82 | -------------------------------------------------------------------------------- /examples/flat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "tags": [] 7 | }, 8 | "source": [ 9 | "# Tabular Model: flat data, without context" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "[![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/flat.ipynb)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import pandas as pd\n", 26 | "from mostlyai.engine import TabularARGN\n", 27 | "\n", 28 | "# load original data\n", 29 | "url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/census\"\n", 30 | "trn_df = pd.read_csv(f\"{url}/census.csv.gz\")\n", 31 | "\n", 32 | "# create and fit the model\n", 33 | "argn = TabularARGN(verbose=1)\n", 34 | "argn.fit(trn_df)\n", 35 | "\n", 36 | "# generate synthetic samples\n", 37 | "syn_df = argn.sample(n_samples=len(trn_df))" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "syn_df.head(5)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "### QUALITY ASSURANCE" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "tags": [] 60 | }, 61 | "source": [ 62 | "#### univariate `age`" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "print(\"Original Age: \" + \", \".join([f'q{q*100:.0f}: {trn_df[\"age\"].quantile(q):.0f}' for q in [.1, .25, .5, .75, .9]]))\n", 72 | "print(\"Synthetic Age: \" + \", \".join([f'q{q*100:.0f}: {syn_df[\"age\"].quantile(q):.0f}' for q in [.1, .25, .5, .75, .9]]))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "#### bivariate `sex` ~ `income`: income gap" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "trn_gap = (trn_df[trn_df[\"sex\"] == \"Male\"][\"income\"] == \">50K\").mean() - (trn_df[trn_df[\"sex\"] == \"Female\"][\"income\"] == \">50K\").mean()\n", 89 | "syn_gap = (syn_df[syn_df[\"sex\"] == \"Male\"][\"income\"] == \">50K\").mean() - (syn_df[syn_df[\"sex\"] == \"Female\"][\"income\"] == \">50K\").mean()\n", 90 | "print(f\"Income Gap {trn_gap:.1%} vs. 
{syn_gap:.1%}\")" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "#### check consistency between `education` and `education.num`" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "pd.crosstab(syn_df[\"education\"], syn_df[\"education_num\"])" 107 | ] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": ".venv", 113 | "language": "python", 114 | "name": "python3" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": { 118 | "name": "ipython", 119 | "version": 3 120 | }, 121 | "file_extension": ".py", 122 | "mimetype": "text/x-python", 123 | "name": "python", 124 | "nbconvert_exporter": "python", 125 | "pygments_lexer": "ipython3", 126 | "version": "3.12.8" 127 | }, 128 | "toc": { 129 | "base_numbering": 1, 130 | "nav_menu": {}, 131 | "number_sections": false, 132 | "sideBar": true, 133 | "skip_h1_title": false, 134 | "title_cell": "Table of Contents", 135 | "title_sidebar": "Contents", 136 | "toc_cell": false, 137 | "toc_position": {}, 138 | "toc_section_display": true, 139 | "toc_window_display": false 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 4 144 | } 145 | -------------------------------------------------------------------------------- /examples/sequential.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "tags": [] 7 | }, 8 | "source": [ 9 | "# Tabular Model: sequential data, with context" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mostly-ai/mostlyai-engine/blob/main/examples/sequential.ipynb)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "tags": [] 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "import pandas as pd\n", 28 | "import numpy as np\n", 29 | "from mostlyai.engine import TabularARGN\n", 30 | "\n", 31 | "# load original data\n", 32 | "url = \"https://github.com/mostly-ai/public-demo-data/raw/refs/heads/dev/baseball\"\n", 33 | "trn_ctx_df = pd.read_csv(f\"{url}/players.csv.gz\") # context data\n", 34 | "trn_tgt_df = pd.read_csv(f\"{url}/batting.csv.gz\") # target data\n", 35 | "\n", 36 | "# create and fit the model with context data\n", 37 | "argn = TabularARGN(\n", 38 | " tgt_context_key=\"players_id\",\n", 39 | " ctx_primary_key=\"id\",\n", 40 | " ctx_data=trn_ctx_df,\n", 41 | " max_training_time=2, # limit training to 2 minutes for demo purposes\n", 42 | " verbose=1,\n", 43 | ")\n", 44 | "argn.fit(trn_tgt_df)\n", 45 | "\n", 46 | "# generate synthetic samples\n", 47 | "syn_tgt_df = argn.sample(n_samples=len(trn_tgt_df))" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "syn_tgt_df.head(5)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "### QUALITY ASSURANCE" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "#### sequence lengths" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "tags": [] 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "trn_seq_lens = trn_tgt_df.groupby(\"players_id\").size()\n", 82 | 
"syn_seq_lens = syn_tgt_df.groupby(\"players_id\").size()\n", 83 | "print(\"tgt: \", np.quantile(trn_seq_lens, np.arange(0, 1.1, 0.1), method=\"inverted_cdf\"))\n", 84 | "print(\"syn: \", np.quantile(syn_seq_lens, np.arange(0, 1.1, 0.1), method=\"inverted_cdf\"))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": { 90 | "tags": [] 91 | }, 92 | "source": [ 93 | "#### coherence" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "syn_avg_teams_per_player = syn_tgt_df.groupby(\"players_id\")[\"team\"].nunique().mean().round(1)\n", 103 | "trn_avg_teams_per_player = trn_tgt_df.groupby(\"players_id\")[\"team\"].nunique().mean().round(1)\n", 104 | "syn_avg_teams_per_player, trn_avg_teams_per_player" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": ".venv", 118 | "language": "python", 119 | "name": "python3" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.12.8" 132 | }, 133 | "toc": { 134 | "base_numbering": 1, 135 | "nav_menu": {}, 136 | "number_sections": false, 137 | "sideBar": true, 138 | "skip_h1_title": false, 139 | "title_cell": "Table of Contents", 140 | "title_sidebar": "Contents", 141 | "toc_cell": false, 142 | "toc_position": {}, 143 | "toc_section_display": true, 144 | "toc_window_display": false 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 4 149 | } 150 | -------------------------------------------------------------------------------- /tests/unit/test_tabular_common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pytest 16 | 17 | from mostlyai.engine._common import ( 18 | ARGN_COLUMN, 19 | ARGN_PROCESSOR, 20 | ARGN_TABLE, 21 | ) 22 | from mostlyai.engine._encoding_types.tabular.categorical import ( 23 | CATEGORICAL_SUB_COL_SUFFIX, 24 | CATEGORICAL_UNKNOWN_TOKEN, 25 | ) 26 | from mostlyai.engine._encoding_types.tabular.numeric import ( 27 | NUMERIC_BINNED_SUB_COL_SUFFIX, 28 | NUMERIC_BINNED_UNKNOWN_TOKEN, 29 | NUMERIC_DISCRETE_SUB_COL_SUFFIX, 30 | NUMERIC_DISCRETE_UNKNOWN_TOKEN, 31 | ) 32 | from mostlyai.engine._tabular.common import ( 33 | fix_rare_token_probs, 34 | translate_fixed_probs, 35 | ) 36 | from mostlyai.engine.domain import ModelEncodingType, RareCategoryReplacementMethod 37 | 38 | 39 | class TestFixRareTokenProbs: 40 | @pytest.mark.parametrize( 41 | "encoding_type", 42 | [ 43 | ModelEncodingType.tabular_numeric_binned, 44 | ModelEncodingType.tabular_numeric_discrete, 45 | ModelEncodingType.tabular_numeric_digit, 46 | ], 47 | ) 48 | def test_numerics(self, encoding_type): 49 | subcol, code = { 50 | ModelEncodingType.tabular_numeric_binned: (NUMERIC_BINNED_SUB_COL_SUFFIX, 1), 51 | ModelEncodingType.tabular_numeric_discrete: (NUMERIC_DISCRETE_SUB_COL_SUFFIX, 0), 52 | ModelEncodingType.tabular_numeric_digit: (None, None), 53 | }[encoding_type] 54 | 55 | def get_stats() -> dict: 56 | return { 57 | "columns": { 58 | "column": { 59 | "encoding_type": encoding_type, 60 | "codes": { 61 | NUMERIC_DISCRETE_UNKNOWN_TOKEN: 0, 62 | NUMERIC_BINNED_UNKNOWN_TOKEN: 1, 63 | }, 64 | }, 65 | } 66 | } 67 | 68 | stats = get_stats() 69 | fixed_probs = fix_rare_token_probs(stats) 70 | expected = {"column": {subcol: {code: 0.0}}} if subcol else {} 71 | assert fixed_probs == expected 72 | 73 | @pytest.mark.parametrize( 74 | "no_of_rare_categories,rare_category_replacement_method,do_fix", 75 | [ 76 | (0, None, True), 77 | (1, None, False), 78 | (1, RareCategoryReplacementMethod.sample, True), 79 | ], 80 | ) 81 | def test_categoricals(self, no_of_rare_categories, rare_category_replacement_method, do_fix): 82 | def get_stats() -> dict: 83 | return { 84 | "columns": { 85 | "column": { 86 | "encoding_type": ModelEncodingType.tabular_categorical.value, 87 | "no_of_rare_categories": no_of_rare_categories, 88 | "codes": {CATEGORICAL_UNKNOWN_TOKEN: 0}, 89 | }, 90 | } 91 | } 92 | 93 | fixed_probs = fix_rare_token_probs( 94 | stats=get_stats(), 95 | rare_category_replacement_method=rare_category_replacement_method, 96 | ) 97 | expected = {"column": {CATEGORICAL_SUB_COL_SUFFIX: {0: 0.0}}} if do_fix else {} 98 | assert fixed_probs == expected 99 | 100 | 101 | class TestTranslateFixedProbs: 102 | def test(self): 103 | fixed_probs = {"column": {"cat": {0: 0.0}}} 104 | stats = { 105 | "columns": { 106 | "column": { 107 | ARGN_PROCESSOR: "tgt", 108 | ARGN_TABLE: "t0", 109 | ARGN_COLUMN: "c0", 110 | } 111 | } 112 | } 113 | fixed_probs_model = translate_fixed_probs(fixed_probs, stats) 114 | assert fixed_probs_model == {"tgt:t0/c0__cat": {0: 0.0}} 115 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/test_numeric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | from mostlyai.engine._common import ANALYZE_MIN_MAX_TOP_N 20 | from mostlyai.engine._encoding_types.language.numeric import ( 21 | analyze_language_numeric, 22 | analyze_reduce_language_numeric, 23 | decode_language_numeric, 24 | encode_language_numeric, 25 | ) 26 | from mostlyai.engine.domain import ModelEncodingType 27 | 28 | 29 | class TestLanguageNumericAnalyze: 30 | def test_analyze_language_numeric(self): 31 | values = pd.Series([0, 1, 2, 3, 4, 5] * ANALYZE_MIN_MAX_TOP_N, name="value") 32 | ids = pd.Series(range(len(values)), name="id") 33 | stats = analyze_language_numeric(values, ids) 34 | assert stats["has_nan"] is False 35 | assert stats["max_n"] == [5] * ANALYZE_MIN_MAX_TOP_N 36 | assert stats["min_n"] == [0] * ANALYZE_MIN_MAX_TOP_N 37 | 38 | 39 | class TestLanguageNumericAnalyzeReduce: 40 | def test_analyze_reduce_language_numeric(self): 41 | stats1 = { 42 | "has_nan": False, 43 | "max_n": [5] * ANALYZE_MIN_MAX_TOP_N, 44 | "min_n": [0] * ANALYZE_MIN_MAX_TOP_N, 45 | "max_scale": 0, 46 | } 47 | stats2 = { 48 | "has_nan": True, 49 | "max_n": [10] * ANALYZE_MIN_MAX_TOP_N, 50 | "min_n": [6] * ANALYZE_MIN_MAX_TOP_N, 51 | "max_scale": 1, 52 | } 53 | reduced = analyze_reduce_language_numeric([stats1, stats2]) 54 | assert reduced["has_nan"] is True 55 | assert reduced["max"] == 10 56 | assert reduced["min"] == 0 57 | assert reduced["max_scale"] == 1 58 | 59 | 60 | class TestLanguageNumericEncode: 61 | def test_encode_language_numeric(self): 62 | values = pd.Series([-1, 0, 1, 2, 3, 4, 5, 6], name="value") 63 | stats = { 64 | "has_nan": False, 65 | "max": 5, 66 | "min": 0, 67 | "max_scale": 0, 68 | } 69 | encoded = encode_language_numeric(values, stats) 70 | assert encoded.dtype == "Int64" 71 | assert encoded.isna().sum() == 0 72 | assert encoded.iloc[0] == 0 73 | assert encoded.iloc[1] == 0 74 | assert encoded.iloc[2] == 1 75 | assert encoded.iloc[3] == 2 76 | assert encoded.iloc[4] == 3 77 | assert encoded.iloc[5] == 4 78 | assert encoded.iloc[6] == 5 79 | assert encoded.iloc[7] == 5 80 | 81 | 82 | class TestLanguageNumericDecode: 83 | @pytest.fixture 84 | def int_stats(self): 85 | return { 86 | "encoding_type": ModelEncodingType.language_numeric, 87 | "has_nan": False, 88 | "max": 91, 89 | "max_scale": 0, 90 | "min": 17, 91 | } 92 | 93 | @pytest.fixture 94 | def float_stats(self): 95 | return { 96 | "encoding_type": ModelEncodingType.language_numeric, 97 | "has_nan": False, 98 | "max": 91.12, 99 | "max_scale": 2, 100 | "min": 17.0, 101 | } 102 | 103 | @pytest.fixture 104 | def sample_values(self): 105 | return pd.Series(["25.3541", "99.99", "-312.0", "61", None, "35.10091", "-1.223"]) 106 | 107 | @pytest.mark.parametrize( 108 | "stats_name, expected_dtype", 109 | [ 110 | ("int_stats", "Int64"), 111 | ("float_stats", float), 112 | ], 113 | ) 114 | def test_decode_language_numeric(self, sample_values, request, stats_name, expected_dtype): 115 | stats = request.getfixturevalue(stats_name) 116 | decoded = decode_language_numeric(sample_values, stats) 117 | assert 
decoded.dtype == expected_dtype 118 | non_null = decoded.dropna() # we don't enforce compatability with "has_nan" 119 | round_digits = stats["max_scale"] 120 | for v in non_null: 121 | assert np.isclose(v, round(v, round_digits), atol=1e-8) 122 | assert all(non_null <= stats["max"]) 123 | assert all(non_null >= stats["min"]) 124 | -------------------------------------------------------------------------------- /tests/end_to_end/test_language_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Functional tests for LanguageModel interface. 17 | """ 18 | 19 | import pandas as pd 20 | import pytest 21 | 22 | from mostlyai.engine import LanguageModel 23 | from mostlyai.engine.domain import ModelEncodingType 24 | 25 | 26 | @pytest.fixture 27 | def simple_language_data(): 28 | """Create minimal language data for testing.""" 29 | data = pd.DataFrame( 30 | { 31 | "category": ["business", "tech", "sports", "business", "tech"] * 10, 32 | "headline": [ 33 | "Company announces new product", 34 | "Tech innovation changes industry", 35 | "Team wins championship", 36 | "Market analysis shows growth", 37 | "AI breakthrough announced", 38 | ] 39 | * 10, 40 | "date": pd.date_range("2024-01-01", periods=50, freq="D"), 41 | } 42 | ) 43 | return data 44 | 45 | 46 | class TestLanguageModelBasic: 47 | """Test basic LanguageModel functionality: fit and unconditional sampling.""" 48 | 49 | def test_fit_and_unconditional_sample(self, simple_language_data, tmp_path_factory): 50 | """Test fit() and unconditional sample().""" 51 | data = simple_language_data 52 | 53 | lm = LanguageModel( 54 | model="MOSTLY_AI/LSTMFromScratch-3m", 55 | tgt_encoding_types={ 56 | "category": ModelEncodingType.language_categorical.value, 57 | "headline": ModelEncodingType.language_text.value, 58 | "date": ModelEncodingType.language_datetime.value, 59 | }, 60 | max_epochs=1, 61 | verbose=0, 62 | workspace_dir=tmp_path_factory.mktemp("workspace"), 63 | ) 64 | 65 | # Fit the model 66 | lm.fit(X=data) 67 | assert lm._fitted is True 68 | 69 | # Generate unconditional samples 70 | syn_data = lm.sample( 71 | n_samples=10, 72 | sampling_temperature=0.5, 73 | ) 74 | 75 | # Verify output shape and columns 76 | assert syn_data.shape[0] == 10 77 | assert set(syn_data.columns) == set(data.columns) 78 | assert all(col in syn_data.columns for col in data.columns) 79 | # Verify text columns are strings 80 | assert syn_data["headline"].dtype == "string" or str(syn_data["headline"].dtype).startswith("string") 81 | 82 | 83 | class TestLanguageModelConditional: 84 | """Test conditional sampling with seed data.""" 85 | 86 | @pytest.fixture 87 | def fitted_model(self, simple_language_data, tmp_path_factory): 88 | """Create a fitted model for reuse in tests.""" 89 | data = simple_language_data 90 | lm = LanguageModel( 91 | model="MOSTLY_AI/LSTMFromScratch-3m", 92 | tgt_encoding_types={ 93 
| "category": ModelEncodingType.language_categorical.value, 94 | "headline": ModelEncodingType.language_text.value, 95 | "date": ModelEncodingType.language_datetime.value, 96 | }, 97 | max_epochs=1, 98 | verbose=0, 99 | workspace_dir=tmp_path_factory.mktemp("workspace"), 100 | ) 101 | lm.fit(X=data) 102 | return lm 103 | 104 | def test_conditional_sample(self, fitted_model, simple_language_data): 105 | """Test conditional sampling with seed_data.""" 106 | lm = fitted_model 107 | 108 | # Prepare seed data 109 | seed_data = pd.DataFrame( 110 | { 111 | "category": ["business", "tech"], 112 | } 113 | ) 114 | 115 | # Generate conditional samples 116 | syn_data = lm.sample( 117 | seed_data=seed_data, 118 | sampling_temperature=0.5, 119 | ) 120 | 121 | # Verify seeded columns are preserved 122 | assert len(syn_data) == 2 123 | assert all(syn_data["category"] == seed_data["category"]) 124 | # Verify all columns are present 125 | assert set(syn_data.columns) == set(simple_language_data.columns) 126 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help 2 | help: ## Show definition of each function 3 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z1-9_-]+:.*?## / {printf "\033[36m%-25s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) 4 | 5 | .PHONY: clean 6 | clean: ## Remove .gitignore files 7 | git clean -fdX 8 | 9 | .PHONY: install 10 | install: # Install dependencies 11 | uv sync --frozen 12 | 13 | .PHONY: lint 14 | lint: ## Run lints 15 | uv run --no-sync pre-commit run --all-files 16 | 17 | .PHONY: test 18 | test: ## Run tests 19 | uv run --no-sync pytest 20 | 21 | .PHONY: all 22 | all: clean install lint test ## Run the commands: clean install lint test 23 | 24 | # Default files to update 25 | PYPROJECT_TOML = pyproject.toml 26 | INIT_FILE = mostlyai/engine/__init__.py 27 | 28 | # Internal Variables for Release Workflow 29 | BUMP_TYPE ?= patch 30 | CURRENT_VERSION := $(shell grep -m 1 'version = ' $(PYPROJECT_TOML) | sed -e 's/version = "\(.*\)"/\1/') 31 | # Assuming current_version is already set from pyproject.toml 32 | NEW_VERSION := $(shell echo $(CURRENT_VERSION) | awk -F. -v bump=$(BUMP_TYPE) '{ \ 33 | if (bump == "patch") { \ 34 | printf("%d.%d.%d", $$1, $$2, $$3 + 1); \ 35 | } else if (bump == "minor") { \ 36 | printf("%d.%d.0", $$1, $$2 + 1); \ 37 | } else if (bump == "major") { \ 38 | printf("%d.0.0", $$1 + 1); \ 39 | } else { \ 40 | print "Error: Invalid BUMP_TYPE. Expected patch, minor or major. 
Input was BUMP_TYPE=" bump; \ 41 | exit 1; \ 42 | } \ 43 | }') 44 | 45 | # Targets for Release Workflow/Automation 46 | .PHONY: update-version-gh release-pypi docs 47 | 48 | update-version-gh: pull-main bump-version update-vars-version create-branch ## Update version in GitHub: pull main, bump version, create and push the new branch 49 | 50 | release-pypi: clean-dist pull-main build upload-pypi docs ## Release to PyPI: pull main, build and upload to PyPI 51 | 52 | pull-main: # Pull main branch 53 | # stash changes 54 | @git stash 55 | # switch to main branch 56 | @git checkout main 57 | # fetch latest changes 58 | @git fetch origin main 59 | # get a clean copy of main branch 60 | @git reset --hard origin/main 61 | # clean 62 | @git clean -fdX 63 | 64 | bump-version: # Bump version (default: patch, options: patch, minor, major) 65 | @echo "Bumping $(BUMP_TYPE) version from $(CURRENT_VERSION) to $(NEW_VERSION)" 66 | @echo "Replaces $(CURRENT_VERSION) to $(NEW_VERSION) in $(PYPROJECT_TOML)" 67 | @echo "Replaces $(CURRENT_VERSION) to $(NEW_VERSION) in $(INIT_FILE)" 68 | @echo "Current directory: $(shell pwd)" 69 | # Check if current version was found 70 | @if [ -z "$(CURRENT_VERSION)" ]; then \ 71 | echo "Error: Could not find current version in $(PYPROJECT_TOML)"; \ 72 | exit 1; \ 73 | fi 74 | # Replace the version in pyproject.toml 75 | @if [[ "$(shell uname -s)" == "Darwin" ]]; then \ 76 | sed -i '' 's/version = "$(CURRENT_VERSION)"/version = "$(NEW_VERSION)"/g' $(PYPROJECT_TOML); \ 77 | sed -i '' 's/__version__ = "$(CURRENT_VERSION)"/__version__ = "$(NEW_VERSION)"/g' $(INIT_FILE); \ 78 | else \ 79 | sed -i 's/version = "$(CURRENT_VERSION)"/version = "$(NEW_VERSION)"/g' $(PYPROJECT_TOML); \ 80 | sed -i 's/__version__ = "$(CURRENT_VERSION)"/__version__ = "$(NEW_VERSION)"/g' $(INIT_FILE); \ 81 | fi 82 | 83 | update-vars-version: # Update the required variables after bump 84 | $(eval VERSION := $(shell python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])")) 85 | $(eval BRANCH := verbump_$(shell echo $(VERSION) | tr '.' '_')) 86 | $(eval TAG := $(VERSION)) 87 | @echo "Updated VERSION to $(VERSION), BRANCH to $(BRANCH), TAG to $(TAG)" 88 | 89 | create-branch: # Create verbump_{new_ver} branch 90 | @git checkout -b $(BRANCH) 91 | @echo "Created branch $(BRANCH)" 92 | # commit the version bump 93 | @git add $(INIT_FILE) 94 | @git add $(PYPROJECT_TOML) 95 | @git commit -m "Version Bump to $(VERSION)" 96 | @echo "Committed version bump to $(VERSION)" 97 | @git push --set-upstream origin $(BRANCH) 98 | @echo "Pushed branch $(BRANCH) to origin" 99 | 100 | clean-dist: # Remove "volatile" directory dist 101 | @rm -rf dist 102 | @echo "Cleaned up dist directory" 103 | 104 | build: # Build the project and create the dist directory if it doesn't exist 105 | @mkdir -p dist 106 | @uv build 107 | @echo "Built the project" 108 | @twine check --strict dist/* 109 | @echo "Project is checked" 110 | 111 | confirm-upload: # Confirm before the irreversible zone 112 | @echo "Are you sure you want to upload to PyPI? 
(yes/no)" 113 | @read ans && [ $${ans:-no} = yes ] 114 | 115 | upload-pypi: confirm-upload # Upload to PyPI (ensure the token is present in .pypirc file before running upload) 116 | @twine upload dist/*$(VERSION)* --verbose 117 | @echo "Uploaded version $(VERSION) to PyPI" 118 | 119 | docs: ## Update docs site 120 | @mkdocs gh-deploy 121 | @echo "Deployed docs" 122 | -------------------------------------------------------------------------------- /mostlyai/engine/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pathlib import Path 16 | 17 | import pandas as pd 18 | 19 | from mostlyai.engine._common import ProgressCallback 20 | from mostlyai.engine._workspace import resolve_model_type 21 | from mostlyai.engine.domain import ( 22 | FairnessConfig, 23 | ImputationConfig, 24 | ModelType, 25 | RareCategoryReplacementMethod, 26 | RebalancingConfig, 27 | ) 28 | 29 | 30 | def generate( 31 | *, 32 | ctx_data: pd.DataFrame | None = None, 33 | seed_data: pd.DataFrame | None = None, 34 | sample_size: int | None = None, 35 | batch_size: int | None = None, 36 | sampling_temperature: float = 1.0, 37 | sampling_top_p: float = 1.0, 38 | device: str | None = None, 39 | rare_category_replacement_method: RareCategoryReplacementMethod | str = RareCategoryReplacementMethod.constant, 40 | rebalancing: RebalancingConfig | dict | None = None, 41 | imputation: ImputationConfig | dict | None = None, 42 | fairness: FairnessConfig | dict | None = None, 43 | workspace_dir: str | Path = "engine-ws", 44 | update_progress: ProgressCallback | None = None, 45 | ) -> None: 46 | """ 47 | Generates synthetic data from a trained model. 48 | 49 | Creates the following folder structure within the `workspace_dir`: 50 | 51 | - `SyntheticData`: Generated synthetic data, stored as parquet files. 52 | 53 | Args: 54 | ctx_data: Context data to be used for generation. 55 | seed_data: Seed data to condition generation on fixed target columns. 56 | sample_size: Number of samples to generate. Defaults to number of original samples. 57 | batch_size: Batch size for generation. If None, determined automatically. 58 | sampling_temperature: Sampling temperature. Higher values increase randomness. 59 | sampling_top_p: Nucleus sampling probability threshold. 60 | device: Device to run generation on ('cuda' or 'cpu'). Defaults to 'cuda' if available, else 'cpu'. 61 | rare_category_replacement_method: Method for handling rare categories. Only applicable for tabular models. 62 | rebalancing: Configuration for rebalancing column distributions. Only applicable for tabular models. 63 | imputation: List of columns to impute missing values. Only applicable for tabular models. 64 | fairness: Configuration for fairness constraints. Only applicable for tabular models. 65 | workspace_dir: Directory path for workspace. 66 | update_progress: Callback for progress updates. 
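    Example: a minimal usage sketch (illustrative only; it assumes that `workspace_dir` already
        contains a model trained by the engine's training step, and uses only the parameters shown above):

            from mostlyai.engine.generation import generate

            generate(
                sample_size=1_000,
                sampling_temperature=0.8,
                workspace_dir="engine-ws",  # the same workspace that was used for training
            )
            # synthetic samples are then written to engine-ws/SyntheticData/ as parquet files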
67 | """ 68 | model_type = resolve_model_type(workspace_dir) 69 | if model_type == ModelType.tabular: 70 | from mostlyai.engine._tabular.generation import generate as generate_tabular 71 | 72 | return generate_tabular( 73 | ctx_data=ctx_data, 74 | seed_data=seed_data, 75 | sample_size=sample_size, 76 | batch_size=batch_size, 77 | sampling_temperature=sampling_temperature, 78 | sampling_top_p=sampling_top_p, 79 | rare_category_replacement_method=rare_category_replacement_method, 80 | rebalancing=rebalancing, 81 | imputation=imputation, 82 | fairness=fairness, 83 | device=device, 84 | workspace_dir=workspace_dir, 85 | update_progress=update_progress, 86 | ) 87 | else: 88 | from mostlyai.engine._language.generation import generate as generate_language 89 | 90 | if imputation is not None: 91 | raise ValueError("imputation is not supported for language models") 92 | if fairness is not None: 93 | raise ValueError("fairness is not supported for language models") 94 | if rebalancing is not None: 95 | raise ValueError("rebalancing is not supported for language models") 96 | return generate_language( 97 | ctx_data=ctx_data, 98 | seed_data=seed_data, 99 | sample_size=sample_size, 100 | batch_size=batch_size, 101 | sampling_temperature=sampling_temperature, 102 | sampling_top_p=sampling_top_p, 103 | rare_category_replacement_method=rare_category_replacement_method, 104 | device=device, 105 | workspace_dir=workspace_dir, 106 | update_progress=update_progress, 107 | ) 108 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/tabular/test_character.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | import pandas as pd 17 | 18 | from mostlyai.engine._common import read_json, write_json 19 | from mostlyai.engine._encoding_types.tabular.character import ( 20 | MAX_LENGTH_CHARS, 21 | analyze_character, 22 | analyze_reduce_character, 23 | decode_character, 24 | encode_character, 25 | ) 26 | 27 | 28 | def test_character(tmp_path): 29 | # create sequence of common strings, with some of those being overly long 30 | vals = np.repeat(["word", "sentence", "_".join("too_long") * 10], 100) 31 | # inject canaries, to then check whether those tokens are suppressed 32 | canary = "§§§" 33 | no_of_canaries = 3 34 | values1 = pd.Series([canary] * no_of_canaries + list(vals), name="chars") 35 | ids1 = pd.Series(np.arange(len(values1)), name="subject_id") 36 | # create sequence of common strings, with some of those missing 37 | values2 = pd.Series([pd.NA, "random_word", pd.NA] * 100, name="chars") 38 | ids2 = pd.Series(np.arange(len(values2)), name="subject_id") 39 | unseen_values = pd.Series(["a_sentence", "new_word"], name="chars") 40 | 41 | stats1 = analyze_character(values1, ids1) 42 | stats2 = analyze_character(values2, ids2) 43 | assert stats1["max_string_length"] == MAX_LENGTH_CHARS 44 | assert len(stats1["characters"]) == MAX_LENGTH_CHARS 45 | assert stats2["max_string_length"] == values2.str.len().max() 46 | assert len(stats2["characters"]) == values2.str.len().max() 47 | write_json(stats1, tmp_path / "stats1.json") 48 | write_json(stats2, tmp_path / "stats2.json") 49 | 50 | stats1 = read_json(tmp_path / "stats1.json") 51 | stats2 = read_json(tmp_path / "stats2.json") 52 | stats = analyze_reduce_character([stats1, stats2]) 53 | assert len(stats["codes"]) == MAX_LENGTH_CHARS 54 | # check that those rare characters don't occur in any vocabulary set 55 | for p in stats["codes"]: 56 | assert "§" not in stats["codes"][p] 57 | write_json(stats, tmp_path / "stats.json") 58 | 59 | stats = read_json(tmp_path / "stats.json") 60 | encoded1 = encode_character(values1, stats) 61 | decoded1 = decode_character(encoded1, stats) 62 | assert decoded1[no_of_canaries:].equals(values1[no_of_canaries:].str.slice(stop=MAX_LENGTH_CHARS)) 63 | encoded2 = encode_character(values2, stats) 64 | decoded2 = decode_character(encoded2, stats) 65 | assert decoded2.equals(values2.str.slice(stop=MAX_LENGTH_CHARS)) 66 | 67 | unseen_encoded = encode_character(unseen_values, stats) 68 | assert all(unseen_encoded.drop("nan", axis=1).values.flatten() >= 0) 69 | 70 | 71 | def test_character_empty(): 72 | values = pd.Series([None, None, None], name="value") 73 | ids = pd.Series(np.arange(len(values)), name="subject_id") 74 | stats = analyze_reduce_character([analyze_character(values, ids)]) 75 | df_encoded = encode_character(values, stats) 76 | df_decoded = decode_character(df_encoded, stats) 77 | assert all(df_decoded.isna()) 78 | assert df_decoded.index.is_unique 79 | 80 | values = pd.Series(["hello", None, None], name="value") 81 | df_encoded = encode_character(values, stats) 82 | df_decoded = decode_character(df_encoded, stats) 83 | assert all(df_decoded.isna()) 84 | 85 | # no values at all 86 | values = pd.Series([], name="value") 87 | ids = pd.Series(np.arange(len(values)), name="subject_id") 88 | partition_stats = analyze_character(values, ids) 89 | stats = analyze_reduce_character([partition_stats]) 90 | df_encoded = encode_character(values, stats) 91 | df_decoded = decode_character(df_encoded, stats) 92 | assert partition_stats == { 93 | "characters": {}, 94 | "has_nan": False, 95 | 
"max_string_length": 0, 96 | } 97 | assert stats == { 98 | "cardinalities": {}, 99 | "codes": {}, 100 | "has_nan": False, 101 | "max_string_length": 0, 102 | } 103 | assert df_encoded.empty, df_encoded.columns.tolist() == (True, []) 104 | assert df_decoded.empty, df_encoded.columns.tolist() == (True, []) 105 | 106 | 107 | def test_character_noempties(): 108 | values = pd.Series(["hello", "world", "!"], name="value") 109 | ids = pd.Series(np.arange(len(values)), name="subject_id") 110 | stats = analyze_reduce_character([analyze_character(values, ids)]) 111 | values = pd.Series([None, None, None], name="value") 112 | df_encoded = encode_character(values, stats) 113 | df_decoded = decode_character(df_encoded, stats) 114 | assert df_decoded.size == values.size 115 | -------------------------------------------------------------------------------- /tests/end_to_end/test_numeric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import shutil 16 | from pathlib import Path 17 | 18 | import numpy as np 19 | import pandas as pd 20 | import pytest 21 | 22 | from mostlyai.engine import analyze, encode 23 | from mostlyai.engine._common import write_json 24 | from mostlyai.engine._tabular.generation import generate 25 | from mostlyai.engine._tabular.training import train 26 | from mostlyai.engine.domain import ModelEncodingType 27 | 28 | 29 | @pytest.fixture 30 | def sum_df(): 31 | df = pd.DataFrame( 32 | { 33 | "n": np.random.randint(0, 98, size=1000) * 100, 34 | "m": np.random.randint(0, 98, size=1000), 35 | } 36 | ) 37 | 38 | # Create column 'o' as sum of 'n' and 'm' 39 | df["o"] = df["n"] + df["m"] 40 | 41 | yield df 42 | 43 | 44 | @pytest.fixture 45 | def product_df(): 46 | # Generate 1000 uniformly distributed random prices between 0 and 300 47 | prices = np.random.uniform(0, 300, 1000).astype(int) 48 | 49 | # Replace the decimal part with one of the following: 00, 05, 50, 95, 99. 
50 | decimals = np.random.choice([0, 0.05, 0.5, 0.95, 0.99], 1000) 51 | prices = prices.astype(int) + decimals 52 | 53 | # Create a DataFrame 54 | df = pd.DataFrame( 55 | { 56 | "price": prices, 57 | } 58 | ) 59 | 60 | yield df 61 | 62 | 63 | def prepare_ws(tmp_path: Path, df: pd.DataFrame, keys: dict, encoding_types: dict) -> Path: 64 | workspace_dir = tmp_path / "ws" 65 | shutil.rmtree(workspace_dir, ignore_errors=True) # cleanup 66 | tgt_meta_path = workspace_dir / "OriginalData" / "tgt-meta" 67 | tgt_data_path = workspace_dir / "OriginalData" / "tgt-data" 68 | for path in [ 69 | workspace_dir, 70 | tgt_meta_path, 71 | tgt_data_path, 72 | ]: 73 | path.mkdir(exist_ok=True, parents=True) 74 | 75 | df.to_parquet(tgt_data_path / "part.000000-trn.parquet") 76 | write_json(keys, tgt_meta_path / "keys.json") 77 | write_json(encoding_types, tgt_meta_path / "encoding-types.json") 78 | 79 | return workspace_dir 80 | 81 | 82 | def synthetize(ws_dir: Path) -> pd.DataFrame: 83 | analyze(workspace_dir=ws_dir) 84 | encode(workspace_dir=ws_dir) 85 | train(max_epochs=5, workspace_dir=ws_dir) 86 | generate(workspace_dir=ws_dir) 87 | syn_data_path = ws_dir / "SyntheticData" 88 | syn = pd.read_parquet(syn_data_path) 89 | 90 | return syn 91 | 92 | 93 | def compare_numeric_encodings( 94 | tmp_path, 95 | df, 96 | numeric_cols, 97 | first=ModelEncodingType.tabular_numeric_auto, 98 | second=ModelEncodingType.tabular_numeric_digit, 99 | ): 100 | syn = [] 101 | for numeric_encoding in [first, second]: 102 | ws = prepare_ws( 103 | tmp_path=tmp_path, 104 | df=df, 105 | keys={}, 106 | encoding_types={k: numeric_encoding.value for k in numeric_cols}, 107 | ) 108 | syn.append(synthetize(ws)) 109 | 110 | return syn[0], syn[1] 111 | 112 | 113 | def test_numeric_sum_quality(tmp_path, sum_df): 114 | sum_syn_auto, sum_syn_digit = compare_numeric_encodings(tmp_path=tmp_path, df=sum_df, numeric_cols=["n", "m", "o"]) 115 | 116 | assert sum_syn_auto.shape == sum_syn_digit.shape 117 | 118 | def calculate_sum_square_errors(df: pd.DataFrame, expected: str, actual: str): 119 | # Calculate the squares of the % errors 120 | squared_error = np.square((df[actual] - df[expected]) / df[actual]) 121 | return np.sum(squared_error) 122 | 123 | sum_syn_auto["expected"] = sum_syn_auto["n"] + sum_syn_auto["m"] 124 | sum_syn_auto_errors = calculate_sum_square_errors(df=sum_syn_auto, expected="expected", actual="o") 125 | sum_syn_digit["expected"] = sum_syn_digit["n"] + sum_syn_digit["m"] 126 | sum_syn_digit_errors = calculate_sum_square_errors(df=sum_syn_digit, expected="expected", actual="o") 127 | 128 | # ensure the quality is reasonable 129 | assert sum_syn_auto_errors / sum_syn_digit_errors < 10 130 | 131 | 132 | def test_numeric_price_quality(tmp_path, product_df): 133 | prod_syn_auto, prod_syn_digit = compare_numeric_encodings(tmp_path=tmp_path, df=product_df, numeric_cols=["price"]) 134 | 135 | assert prod_syn_auto.shape == prod_syn_digit.shape 136 | 137 | def similar_quantiles(ser_first, ser_second, threshold=0.05) -> bool: 138 | quantiles = [0.25, 0.5, 0.75, 1] 139 | q_first = ser_first.quantile(quantiles) 140 | q_second = ser_second.quantile(quantiles) 141 | return bool(all(np.abs((q_first - q_second) / ((q_first + q_second) / 2)) <= threshold)) 142 | 143 | assert similar_quantiles(product_df, prod_syn_auto) 144 | assert similar_quantiles(product_df, prod_syn_digit) 145 | -------------------------------------------------------------------------------- /mostlyai/engine/training.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from collections.abc import Callable 17 | from pathlib import Path 18 | 19 | import torch 20 | 21 | from mostlyai.engine._common import ProgressCallback 22 | from mostlyai.engine._workspace import resolve_model_type 23 | from mostlyai.engine.domain import DifferentialPrivacyConfig, ModelStateStrategy, ModelType 24 | 25 | 26 | def train( 27 | *, 28 | model: str | None = None, 29 | max_training_time: float | None = 14400.0, # 10 days 30 | max_epochs: float | None = 100.0, # 100 epochs 31 | batch_size: int | None = None, 32 | gradient_accumulation_steps: int | None = None, 33 | enable_flexible_generation: bool = True, 34 | max_sequence_window: int | None = None, 35 | differential_privacy: DifferentialPrivacyConfig | dict | None = None, 36 | model_state_strategy: ModelStateStrategy = ModelStateStrategy.reset, 37 | device: torch.device | str | None = None, 38 | workspace_dir: str | Path = "engine-ws", 39 | update_progress: ProgressCallback | None = None, 40 | upload_model_data_callback: Callable | None = None, 41 | ) -> None: 42 | """ 43 | Trains a model with optional early stopping and differential privacy. 44 | 45 | Creates the following folder structure within the `workspace_dir`: 46 | 47 | - `ModelStore`: Trained model checkpoints and logs. 48 | 49 | Args: 50 | model: The identifier of the model to train. If tabular, defaults to MOSTLY_AI/Medium. If language, defaults to MOSTLY_AI/LSTMFromScratch-3m. 51 | max_training_time: Maximum training time in minutes. If None, defaults to 10 days. 52 | max_epochs: Maximum number of training epochs. If None, defaults to 100 epochs. 53 | batch_size: Per-device batch size for training and validation. If None, determined automatically. 54 | gradient_accumulation_steps: Number of steps to accumulate gradients. If None, determined automatically. 55 | enable_flexible_generation: Whether to enable flexible order generation. Defaults to True. 56 | max_sequence_window: Maximum sequence window for tabular sequential models. Only applicable for tabular models. 57 | differential_privacy: Configuration for differential privacy training. If None, DP is disabled. 58 | model_state_strategy: Strategy for handling existing model state (reset/resume/reuse). 59 | device: Device to run training on ('cuda' or 'cpu'). Defaults to 'cuda' if available, else 'cpu'. 60 | workspace_dir: Directory path for workspace. Training outputs are stored in ModelStore subdirectory. 61 | update_progress: Callback function to report training progress. 62 | upload_model_data_callback: Callback function to upload model data during training. 
63 | """ 64 | model_type = resolve_model_type(workspace_dir) 65 | if model_type == ModelType.tabular: 66 | from mostlyai.engine._tabular.training import train as train_tabular 67 | 68 | args = inspect.signature(train_tabular).parameters 69 | train_tabular( 70 | model=model if model else args["model"].default, 71 | workspace_dir=workspace_dir, 72 | max_training_time=max_training_time if max_training_time else args["max_training_time"].default, 73 | max_epochs=max_epochs if max_epochs else args["max_epochs"].default, 74 | batch_size=batch_size, 75 | gradient_accumulation_steps=gradient_accumulation_steps, 76 | enable_flexible_generation=enable_flexible_generation, 77 | differential_privacy=differential_privacy, 78 | update_progress=update_progress, 79 | upload_model_data_callback=upload_model_data_callback, 80 | model_state_strategy=model_state_strategy, 81 | device=device, 82 | max_sequence_window=max_sequence_window if max_sequence_window else args["max_sequence_window"].default, 83 | ) 84 | else: 85 | from mostlyai.engine._language.training import train as train_language 86 | 87 | if max_sequence_window is not None: 88 | raise ValueError("max_sequence_window is not supported for language models") 89 | 90 | args = inspect.signature(train_language).parameters 91 | train_language( 92 | model=model if model else args["model"].default, 93 | workspace_dir=workspace_dir, 94 | max_training_time=max_training_time if max_training_time else args["max_training_time"].default, 95 | max_epochs=max_epochs if max_epochs else args["max_epochs"].default, 96 | batch_size=batch_size, 97 | gradient_accumulation_steps=gradient_accumulation_steps, 98 | enable_flexible_generation=enable_flexible_generation, 99 | differential_privacy=differential_privacy, 100 | update_progress=update_progress, 101 | upload_model_data_callback=upload_model_data_callback, 102 | model_state_strategy=model_state_strategy, 103 | device=device, 104 | ) 105 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/tabular/categorical.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Categorical encoding maps each categorical value to its own integer code. 
17 | """ 18 | 19 | import pandas as pd 20 | 21 | from mostlyai.engine._common import dp_non_rare, get_stochastic_rare_threshold, safe_convert_string 22 | 23 | CATEGORICAL_UNKNOWN_TOKEN = "_RARE_" 24 | CATEGORICAL_NULL_TOKEN = "<>" 25 | CATEGORICAL_SUB_COL_SUFFIX = "cat" 26 | CATEGORICAL_ESCAPE_CHAR = "\x01" 27 | 28 | 29 | def safe_categorical_escape(values: pd.Series) -> pd.Series: 30 | """Inplace escaping of categorical values""" 31 | reserved_tokens = (CATEGORICAL_UNKNOWN_TOKEN, CATEGORICAL_NULL_TOKEN) 32 | reserved_tokens_replacement_map = {t: CATEGORICAL_ESCAPE_CHAR + t for t in reserved_tokens} 33 | # first, prefix values starting with escape char with another escape char 34 | mask = values.str.startswith(CATEGORICAL_ESCAPE_CHAR, na=False) 35 | values.loc[mask] = values.loc[mask].str.slice_replace(stop=1, repl=CATEGORICAL_ESCAPE_CHAR * 2) 36 | # second, add escape char to all reserved tokens 37 | values = values.replace(reserved_tokens_replacement_map) 38 | return values 39 | 40 | 41 | def safe_categorical_unescape(values: pd.Series) -> pd.Series: 42 | """Inplace un-escaping of categorical values""" 43 | # de-prefix all values starting with escape char by removing just the first one 44 | mask = values.str.startswith(CATEGORICAL_ESCAPE_CHAR, na=False) 45 | values.loc[mask] = values.loc[mask].str[1:] 46 | return values 47 | 48 | 49 | def analyze_categorical( 50 | values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None, *, safe_escape: bool = True 51 | ) -> dict: 52 | # ensure a safe representation of values: 1. string dtype; 2. escape reserved tokens 53 | values = safe_convert_string(values) 54 | if safe_escape: 55 | values = safe_categorical_escape(values) 56 | # count distinct root_keys per categorical value for rare-category protection 57 | df = pd.concat([root_keys, values], axis=1) 58 | cnt_values = df.groupby(values.name)[root_keys.name].nunique().to_dict() 59 | stats = {"has_nan": sum(values.isna()) > 0, "cnt_values": cnt_values} 60 | return stats 61 | 62 | 63 | def analyze_reduce_categorical( 64 | stats_list: list[dict], 65 | value_protection: bool = True, 66 | value_protection_epsilon: float | None = None, 67 | ) -> dict: 68 | # sum up all counts for each categorical value 69 | cnt_values: dict[str, int] = {} 70 | for item in stats_list: 71 | for value, count in item["cnt_values"].items(): 72 | cnt_values[value] = cnt_values.get(value, 0) + count 73 | cnt_values = dict(sorted(cnt_values.items())) 74 | known_categories = list(cnt_values.keys()) 75 | if value_protection: 76 | if value_protection_epsilon is not None: 77 | categories, _ = dp_non_rare(cnt_values, value_protection_epsilon, threshold=5) 78 | else: 79 | rare_min = get_stochastic_rare_threshold(min_threshold=5) 80 | categories = [k for k in known_categories if cnt_values[k] >= rare_min] 81 | else: 82 | categories = known_categories 83 | no_of_rare_categories = len(known_categories) - len(categories) 84 | # add special token for MISSING categories, if any are present 85 | if any([j["has_nan"] for j in stats_list]): 86 | categories = [CATEGORICAL_NULL_TOKEN] + categories 87 | # add special token for UNKNOWN categories at first position 88 | categories = [CATEGORICAL_UNKNOWN_TOKEN] + categories 89 | stats = { 90 | "no_of_rare_categories": no_of_rare_categories, 91 | "codes": {categories[i]: i for i in range(len(categories))}, 92 | "cardinalities": {CATEGORICAL_SUB_COL_SUFFIX: len(categories)}, 93 | } 94 | return stats 95 | 96 | 97 | def encode_categorical(values: pd.Series, stats: dict, _: pd.Series | None = 
None) -> pd.DataFrame: 98 | # ensure a safe representation of values: 1. string dtype; 2. escape reserved tokens 99 | values = safe_categorical_escape(safe_convert_string(values)) 100 | known_categories = [str(k) for k in stats["codes"].keys()] 101 | values = values.copy() 102 | if CATEGORICAL_NULL_TOKEN in known_categories: 103 | values[values.isna()] = CATEGORICAL_NULL_TOKEN 104 | values[~values.isin(known_categories)] = CATEGORICAL_UNKNOWN_TOKEN 105 | 106 | # map categories to their corresponding codes 107 | codes = pd.Series( 108 | pd.Categorical(values, categories=known_categories).codes, 109 | name=CATEGORICAL_SUB_COL_SUFFIX, 110 | index=values.index, 111 | ) 112 | return codes.to_frame() 113 | 114 | 115 | def decode_categorical(df_encoded: pd.DataFrame, stats: dict) -> pd.Series: 116 | categories = stats["codes"].keys() 117 | values = pd.Series( 118 | pd.Categorical.from_codes(df_encoded[CATEGORICAL_SUB_COL_SUFFIX], categories=categories), 119 | dtype="string", 120 | ) 121 | values[values == CATEGORICAL_NULL_TOKEN] = pd.NA 122 | # convert escaped values to their original representation 123 | values = safe_categorical_unescape(values) 124 | return values 125 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
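# A minimal, illustrative sketch of how the helpers defined below might be called. The
# model path is a placeholder (an assumption, not a real checkpoint); loading requires
# the corresponding model files on disk.
def _load_model_sketch():
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # bf16 is only reported as supported on CUDA devices with compute capability >= 8
    bf16 = is_bf16_supported(device)
    # returns the instantiated model together with its (possibly adjusted) config
    model, config = load_base_model_and_config(
        "path/to/model-or-adapter",  # placeholder path
        device=device,
        is_peft_adapter=False,
        is_training=False,
    )
    return model, config, bf16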
14 | 15 | import importlib 16 | import logging 17 | from pathlib import Path 18 | 19 | import torch 20 | from peft import PeftConfig, prepare_model_for_kbit_training 21 | from transformers import ( 22 | AutoConfig, 23 | AutoModel, 24 | AutoModelForCausalLM, 25 | AutoModelForImageTextToText, 26 | BitsAndBytesConfig, 27 | PretrainedConfig, 28 | PreTrainedModel, 29 | ) 30 | from transformers.quantizers import AutoQuantizationConfig 31 | 32 | from mostlyai.engine._language.lstm import LSTMFromScratchConfig 33 | 34 | _LOG = logging.getLogger(__name__) 35 | 36 | MAX_LENGTH = 10_000 37 | 38 | 39 | def is_bf16_supported(device: torch.device) -> bool: 40 | if device.type != "cuda": 41 | return False 42 | compute_capability = torch.cuda.get_device_capability(device) 43 | return compute_capability[0] >= 8 44 | 45 | 46 | def get_attention_implementation(config: PretrainedConfig) -> str | None: 47 | model_cls = AutoModel._model_mapping[type(config)] 48 | attn_implementation = None 49 | if getattr(model_cls, "_supports_sdpa", False): 50 | attn_implementation = "sdpa" 51 | return attn_implementation 52 | 53 | 54 | def load_base_model_and_config( 55 | model_id_or_path: str | Path, device: torch.device, is_peft_adapter: bool, is_training: bool 56 | ) -> tuple[PreTrainedModel, PretrainedConfig]: 57 | # opacus DP does not support parallel/sharded training 58 | model_id_or_path = str(model_id_or_path) 59 | if is_peft_adapter: 60 | # get the base model name from adapter_config.json 61 | peft_config = PeftConfig.from_pretrained(model_id_or_path) 62 | model_id_or_path = peft_config.base_model_name_or_path 63 | config = AutoConfig.from_pretrained(model_id_or_path) 64 | else: 65 | config = AutoConfig.from_pretrained(model_id_or_path) 66 | if config.model_type == LSTMFromScratchConfig.model_id: 67 | # make sure that we use standard LSTM layers during inference for the model trained with DP 68 | # (see https://opacus.ai/api/dp_rnn.html#opacus.layers.dp_rnn.DPLSTM for more details) 69 | if not is_training: 70 | config.with_dp = False 71 | return AutoModelForCausalLM.from_pretrained(model_id_or_path, config=config, device_map=device), config 72 | 73 | # Load pretrained base model 74 | use_cache = not is_training # KV cache is not needed during training 75 | is_gpu_training = is_training and device.type == "cuda" 76 | is_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None 77 | if is_gpu_training and not is_bitsandbytes_available: 78 | _LOG.warning( 79 | "CUDA device was found but bitsandbytes is not available. Please use extra [gpu] to install bitsandbytes for quantization." 
80 | ) 81 | bf16_supported = is_bf16_supported(device) 82 | if bf16_supported: 83 | attn_implementation = get_attention_implementation(config) 84 | torch_dtype = torch.bfloat16 85 | else: 86 | attn_implementation = None 87 | torch_dtype = torch.float32 88 | if hasattr(config, "quantization_config"): 89 | quantization_config = AutoQuantizationConfig.from_dict(config.quantization_config) 90 | elif is_gpu_training and is_bitsandbytes_available: 91 | quantization_config = BitsAndBytesConfig( 92 | load_in_4bit=True, 93 | bnb_4bit_quant_type="nf4", 94 | bnb_4bit_use_double_quant=False, 95 | bnb_4bit_compute_dtype=torch_dtype, 96 | ) 97 | else: 98 | quantization_config = None 99 | 100 | if device.type == "cuda" and device.index is None: 101 | device_map = "auto" 102 | else: # device is `cpu` or `cuda:0` (when using single GPU on a multi-GPU instance) 103 | device_map = str(device) 104 | 105 | if hasattr(config, "text_config") and hasattr(config, "vision_config"): 106 | config.text_config.use_cache = use_cache 107 | config.text_config.attn_implementation = attn_implementation 108 | auto_model_cls = AutoModelForImageTextToText 109 | elif hasattr(config, "use_cache"): 110 | config.use_cache = use_cache 111 | config.attn_implementation = attn_implementation 112 | auto_model_cls = AutoModelForCausalLM 113 | else: 114 | raise ValueError("Unsupported model") 115 | 116 | model = auto_model_cls.from_pretrained( 117 | model_id_or_path, 118 | config=config, 119 | device_map=device_map, 120 | quantization_config=quantization_config, 121 | torch_dtype=torch_dtype, 122 | ) 123 | if isinstance(quantization_config, BitsAndBytesConfig): 124 | # convert all non-kbit layers to float32 125 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) 126 | if is_gpu_training and model.supports_gradient_checkpointing: 127 | # pay 50% time penalty for _large_ memory savings 128 | _LOG.info("enable gradient checkpointing") 129 | model.gradient_checkpointing_enable() 130 | model.enable_input_require_grads() 131 | return model, config 132 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/engine/hf_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
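# A minimal, illustrative sketch of driving the engine defined below. The model path is
# a placeholder (an assumption); generation needs a previously trained model or PEFT
# adapter stored at that location, and the prompt text here is made up.
def _hf_engine_sketch():
    import torch

    engine = HuggingFaceEngine(
        model_path="path/to/trained-model",  # placeholder path
        device=torch.device("cpu"),
        max_new_tokens=128,
        tokenizer_max_length=1024,
    )
    prompts = ["example prompt text"]  # made-up prompt; real prompts are built elsewhere in the engine
    # returns generated token ids per prompt plus tokenize/generate timings
    token_ids, metrics = engine.generate(prompts, sampling_temperature=1.0, sampling_top_p=1.0)
    return token_ids, metrics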
14 | 15 | from __future__ import annotations 16 | 17 | import time 18 | from os import PathLike 19 | from pathlib import Path 20 | 21 | import torch 22 | from peft import PeftModel 23 | from pydantic import BaseModel 24 | from transformers import AutoTokenizer 25 | from xgrammar.contrib.hf import LogitsProcessor 26 | 27 | from mostlyai.engine._language.common import load_base_model_and_config 28 | from mostlyai.engine._language.engine.base import EngineMetrics, LanguageEngine 29 | from mostlyai.engine._language.tokenizer_utils import tokenize_fn 30 | from mostlyai.engine._language.xgrammar_utils import create_compiled_grammars 31 | 32 | 33 | class HuggingFaceEngine(LanguageEngine): 34 | def __init__( 35 | self, model_path: PathLike | str, device: torch.device, max_new_tokens: int, tokenizer_max_length: int 36 | ): 37 | self.device = device 38 | self.max_new_tokens = max_new_tokens 39 | self.tokenizer_max_length = tokenizer_max_length 40 | self.is_peft_adapter = (Path(model_path) / "adapter_config.json").exists() 41 | 42 | model_path = str(model_path) 43 | self._model, self._model_config = load_base_model_and_config( 44 | model_path, device=device, is_peft_adapter=self.is_peft_adapter, is_training=False 45 | ) 46 | if self.is_peft_adapter: 47 | self._model = PeftModel.from_pretrained(self._model, model_path, is_trainable=False) 48 | self._model = self._model.merge_and_unload() 49 | self._default_batch_size = 64 50 | else: 51 | # only the LSTM model does not have an adapter 52 | self._default_batch_size = 128 53 | 54 | self.tokenizer = AutoTokenizer.from_pretrained( 55 | model_path, 56 | padding_side="left", 57 | truncation_side="left", 58 | legacy=True, 59 | # these must be False at initialization, as we manually add them later in tokenize_fn 60 | add_bos_token=False, 61 | add_eos_token=False, 62 | ) 63 | 64 | # we can't enforce JSON output if LSTM tokenizer training was skipped 65 | is_trained_lstm_tokenizer = not self.is_peft_adapter and self.tokenizer.vocab_size > len( 66 | self.tokenizer.special_tokens_map 67 | ) 68 | self._json_enforcing_possible = self.is_peft_adapter or is_trained_lstm_tokenizer 69 | self._logits_processors = None 70 | 71 | def get_default_batch_size(self) -> int: 72 | return self._default_batch_size 73 | 74 | def supports_json_enforcing(self) -> bool: 75 | return self._json_enforcing_possible 76 | 77 | def generate( 78 | self, text: list[str], sampling_temperature: float, sampling_top_p: float 79 | ) -> tuple[list[int], EngineMetrics]: 80 | do_sample = sampling_temperature > 0.0 81 | 82 | tokenize_kwargs = dict( 83 | tokenizer=self.tokenizer, 84 | return_tensors="pt", 85 | add_bos_token=True, 86 | add_eos_token=False, 87 | padding=True, 88 | truncation=True, 89 | max_length=self.tokenizer_max_length, # truncates input 90 | ) 91 | t_tokenize = time.time() 92 | inputs = tokenize_fn(text=text, **tokenize_kwargs).to(self.device) 93 | tokenize_time = time.time() - t_tokenize 94 | 95 | generate_kwargs = dict( 96 | do_sample=do_sample, 97 | max_new_tokens=self.max_new_tokens, 98 | temperature=sampling_temperature if do_sample else None, 99 | top_p=sampling_top_p if do_sample else None, 100 | bos_token_id=self.tokenizer.bos_token_id, 101 | pad_token_id=self.tokenizer.pad_token_id, 102 | eos_token_id=self.tokenizer.eos_token_id, 103 | ) 104 | 105 | t_generate = time.time() 106 | outputs = self._model.generate(**inputs, **generate_kwargs, logits_processor=self._logits_processors) 107 | generate_time = time.time() - t_generate 108 | 109 | _, input_length = 
inputs["input_ids"].shape 110 | # truncate the prompt from the outputs 111 | outputs = outputs[:, input_length:] 112 | metrics = EngineMetrics(tokenize_time=tokenize_time, generate_time=generate_time) 113 | return outputs.detach().cpu().tolist(), metrics 114 | 115 | def cleanup(self): 116 | pass 117 | 118 | def update_json_constraints(self, schemas: list[BaseModel] | None) -> None: 119 | """Update JSON schema constraints for the next generation call.""" 120 | if schemas: 121 | compiled_grammars = create_compiled_grammars( 122 | schemas=schemas, 123 | tokenizer=self.tokenizer, 124 | vocab_size=self._model_config.vocab_size, 125 | is_peft_adapter=self.is_peft_adapter, 126 | ) 127 | self._logits_processors = [LogitsProcessor(list(compiled_grammars))] 128 | else: 129 | self._logits_processors = None 130 | 131 | def can_reuse_schemas(self) -> bool: 132 | """HuggingFaceEngine cannot reuse LogitsProcessor across different batch sizes.""" 133 | return False 134 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/tabular/test_datetime.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import read_json, write_json 18 | from mostlyai.engine._encoding_types.tabular.datetime import ( 19 | analyze_datetime, 20 | analyze_reduce_datetime, 21 | decode_datetime, 22 | encode_datetime, 23 | split_sub_columns_datetime, 24 | ) 25 | 26 | 27 | def test_datetime(tmp_path): 28 | s1 = pd.Series( 29 | [ 30 | "1910-01-01", 31 | "", 32 | "1930-01-31", 33 | "1940-02-12", 34 | "", 35 | "1971-09-01", 36 | "1983-05-19", 37 | "1998-05-24", 38 | ], 39 | name="birth_date", 40 | ) 41 | i1 = pd.Series([1, 2, 3, 4, 5, 6, 7, 8], name="id") 42 | s2 = pd.Series( 43 | [ 44 | "1912-01-01", 45 | "", 46 | "1932-01-31", 47 | "1942-02-12", 48 | "", 49 | "1972-09-01", 50 | "1984-05-19", 51 | "1994-05-24", 52 | ], 53 | name="birth_date", 54 | ) 55 | i2 = pd.Series([11, 12, 13, 14, 15, 16, 17], name="id") 56 | 57 | stats1 = analyze_datetime(s1, i1) 58 | stats2 = analyze_datetime(s2, i2) 59 | write_json(stats1, tmp_path / "stats1.json") 60 | write_json(stats2, tmp_path / "stats2.json") 61 | 62 | stats = analyze_reduce_datetime([stats1, stats2], value_protection=False) 63 | write_json(stats, tmp_path / "stats.json") 64 | 65 | stats = read_json(tmp_path / "stats.json") 66 | df_encoded = encode_datetime(s1, stats) 67 | df_decoded = decode_datetime(df_encoded, stats) 68 | assert pd.to_datetime(s1).astype("datetime64[ns]").equals(df_decoded) 69 | 70 | 71 | def test_datetime_empty(tmp_path): 72 | values = pd.to_datetime(pd.Series([pd.NaT, pd.NaT, pd.NaT], name="value")).astype("datetime64[ns]") 73 | root_keys = pd.Series(range(len(values)), name="id") 74 | stats = analyze_reduce_datetime([analyze_datetime(values, root_keys)], value_protection=False) 75 | df_encoded = encode_datetime(values, stats) 76 | df_decoded = decode_datetime(df_encoded, stats) 77 | assert values.equals(df_decoded) 78 | assert all(df_decoded.isna()) 79 | 80 | values = pd.to_datetime(pd.Series(["2020-05-24", pd.NaT, pd.NaT], name="value")) 81 | df_encoded = encode_datetime(values, stats) 82 | df_decoded = decode_datetime(df_encoded, stats) 83 | assert all(df_decoded.isna()) 84 | 85 | # no values at all 86 | values = pd.to_datetime(pd.Series([], name="value")) 87 | root_keys = pd.Series(range(len(values)), name="id") 88 | partition_stats = analyze_datetime(values, root_keys) 89 | stats = analyze_reduce_datetime([partition_stats]) 90 | df_encoded = encode_datetime(values, stats) 91 | df_decoded = decode_datetime(df_encoded, stats) 92 | min_max_values = { 93 | "day": 1, 94 | "hour": 0, 95 | "minute": 0, 96 | "month": 1, 97 | "ms_E0": 0, 98 | "ms_E1": 0, 99 | "ms_E2": 0, 100 | "second": 0, 101 | "year": 2022, 102 | } 103 | assert partition_stats == { 104 | "has_nan": False, 105 | "max_n": [], 106 | "max_values": min_max_values, 107 | "min_n": [], 108 | "min_values": min_max_values, 109 | "log_hist": [0.0] * 128, 110 | } 111 | assert stats == { 112 | "cardinalities": {"day": 1, "month": 1, "year": 1}, 113 | "has_ms": False, 114 | "has_nan": False, 115 | "has_time": False, 116 | "max": None, 117 | "max_values": min_max_values, 118 | "min": None, 119 | "min_values": min_max_values, 120 | } 121 | assert df_encoded.empty, df_encoded.columns.tolist() == (True, []) 122 | assert df_decoded.empty, df_encoded.columns.tolist() == (True, []) 123 | 124 | 125 | def test_datetime_noempties(tmp_path): 126 | values = pd.to_datetime(pd.Series(["2020-05-24", "2021-05-24", "2022-05-24"], name="value")) 127 | root_keys = pd.Series(range(len(values)), name="id") 128 | stats = 
analyze_reduce_datetime([analyze_datetime(values, root_keys)], value_protection=False) 129 | values = pd.to_datetime(pd.Series([pd.NaT, pd.NaT, pd.NaT], name="value")) 130 | df_encoded = encode_datetime(values, stats) 131 | df_decoded = decode_datetime(df_encoded, stats) 132 | assert all(df_decoded.notna()) 133 | 134 | 135 | def test_datetime_min_max_overlapping(): 136 | root_keys = pd.Series(list(range(100)), name="id") 137 | values = pd.Series([pd.to_datetime(f"01-01-{2000 + y}") for y in range(100)], name="value") 138 | stats = analyze_reduce_datetime([analyze_datetime(values, root_keys)]) 139 | for pos, card in stats["cardinalities"].items(): 140 | assert card > 0 141 | 142 | 143 | def test_split_sub_columns_datetime(): 144 | values = pd.Series([pd.to_datetime("2020-01-01"), pd.NaT], name="dt", index=[1, 1]) 145 | df = split_sub_columns_datetime(values) 146 | cols = [ 147 | "nan", 148 | "year", 149 | "month", 150 | "day", 151 | "hour", 152 | "minute", 153 | "second", 154 | "ms_E2", 155 | "ms_E1", 156 | "ms_E0", 157 | ] 158 | vals = [ 159 | [0, 2020, 1, 1, 0, 0, 0, 0, 0, 0], 160 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 161 | ] 162 | pd.testing.assert_frame_equal(df, pd.DataFrame(vals, columns=cols)) 163 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/language/test_datetime.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
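# A minimal, illustrative sketch of how raw model output strings are turned back into
# bounded datetimes by the decoder tested below. The stats dict is hand-written for this
# example; in the engine it comes from the analyze/reduce steps.
def _language_datetime_decode_sketch():
    import pandas as pd

    from mostlyai.engine._encoding_types.language.datetime import decode_language_datetime

    stats = {"has_nan": True, "min": "2000-01-01", "max": "2024-12-31"}  # assumed bounds
    raw = pd.Series(["2021-02-30", "1999-12-31", "not a date"])
    decoded = decode_language_datetime(raw, stats)
    # the invalid day is clamped to the end of its month, the out-of-range date is clipped
    # to the [min, max] bounds, and the unparseable string becomes NaT
    return decoded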
14 | 15 | import pandas as pd 16 | import pytest 17 | 18 | from mostlyai.engine._common import ANALYZE_MIN_MAX_TOP_N 19 | from mostlyai.engine._encoding_types.language.datetime import ( 20 | analyze_language_datetime, 21 | analyze_reduce_language_datetime, 22 | decode_language_datetime, 23 | encode_language_datetime, 24 | ) 25 | from mostlyai.engine.domain import ModelEncodingType 26 | 27 | 28 | class TestLanguageDatetimeAnalyze: 29 | def test_analyze_language_datetime(self): 30 | birth_dates = pd.Series( 31 | [ 32 | "1910-01-01", 33 | "", 34 | "1930-01-31", 35 | "1940-02-12", 36 | "", 37 | "1971-09-01", 38 | "1983-05-19", 39 | "1998-05-24", 40 | ] 41 | * ANALYZE_MIN_MAX_TOP_N, 42 | name="birth_date", 43 | ) 44 | keys = pd.Series(range(len(birth_dates)), name="id") 45 | stats = analyze_language_datetime(birth_dates, keys) 46 | assert stats["has_nan"] is True 47 | assert stats["min_n"] == ["1910-01-01"] * ANALYZE_MIN_MAX_TOP_N 48 | assert stats["max_n"] == ["1998-05-24"] * ANALYZE_MIN_MAX_TOP_N 49 | 50 | 51 | class TestLanguageDatetimeAnalyzeReduce: 52 | def test_analyze_reduce_language_datetime(self): 53 | stats1 = { 54 | "has_nan": True, 55 | "min_n": ["1910-01-01"] * ANALYZE_MIN_MAX_TOP_N, 56 | "max_n": ["1998-05-24"] * ANALYZE_MIN_MAX_TOP_N, 57 | } 58 | stats2 = { 59 | "has_nan": False, 60 | "min_n": ["2000-01-01"] * ANALYZE_MIN_MAX_TOP_N, 61 | "max_n": ["2024-12-31"] * ANALYZE_MIN_MAX_TOP_N, 62 | } 63 | reduced = analyze_reduce_language_datetime([stats1, stats2]) 64 | assert reduced["has_nan"] is True 65 | assert reduced["min"] == "1910-01-01" 66 | assert reduced["max"] == "2024-12-31" 67 | 68 | 69 | class TestLanguageDatetimeEncode: 70 | def test_encode_language_datetime(self): 71 | values = pd.Series( 72 | [ 73 | "1910-01-01", 74 | "", 75 | "1930-01-31", 76 | "1940-02-12", 77 | "", 78 | "1971-09-01", 79 | "1983-05-19", 80 | "1998-05-24", 81 | ], 82 | name="birth_date", 83 | ) 84 | stats = { 85 | "has_nan": True, 86 | "min": "1930-01-31", 87 | "max": "2024-12-31", 88 | } 89 | encoded = encode_language_datetime(values, stats) 90 | assert encoded.dtype == "datetime64[us]" 91 | assert encoded.isna().sum() == 2 92 | assert encoded.iloc[0] == pd.Timestamp("1930-01-31") 93 | assert encoded.iloc[1] is pd.NaT 94 | assert encoded.iloc[2] == pd.Timestamp("1930-01-31") 95 | assert encoded.iloc[3] == pd.Timestamp("1940-02-12") 96 | assert encoded.iloc[4] is pd.NaT 97 | assert encoded.iloc[5] == pd.Timestamp("1971-09-01") 98 | assert encoded.iloc[6] == pd.Timestamp("1983-05-19") 99 | 100 | 101 | class TestLanguageDatetimeDecode: 102 | @pytest.fixture 103 | def datetime_stats(self): 104 | return { 105 | "encoding_type": ModelEncodingType.language_datetime, 106 | "has_nan": True, 107 | "min": "2000-01-01", 108 | "max": "2024-12-31", 109 | } 110 | 111 | @pytest.fixture 112 | def no_clip_stats(self): 113 | return { 114 | "encoding_type": ModelEncodingType.language_datetime, 115 | "has_nan": True, 116 | "min": "1900-01-01", 117 | "max": "2100-01-01", 118 | } 119 | 120 | @pytest.fixture 121 | def sample_dates(self): 122 | return pd.Series( 123 | [ 124 | "2021-05-20 14:30:00", # valid datetime with time 125 | "2020-02-30", # Feb 30 is invalid; should be clamped to Feb 29, 2020 126 | "1999-12-31", # below the min bound -> will be clipped upward 127 | "2025-01-01", # above the max bound -> will be clipped downward 128 | "abcd", # invalid date string -> becomes NaT 129 | "", # empty string -> becomes NaT 130 | "_INVALID_", # marked as invalid -> becomes NaT 131 | "2010-10-10", # valid date without explicit 
time (defaults to 00:00:00) 132 | ] 133 | ) 134 | 135 | def test_datetime_dtype_bounds_and_invalids(self, sample_dates, datetime_stats): 136 | decoded = decode_language_datetime(sample_dates, datetime_stats) 137 | assert decoded.dtype == "datetime64[ns]" 138 | non_null = decoded.dropna() 139 | min_bound = pd.to_datetime(datetime_stats["min"]) 140 | max_bound = pd.to_datetime(datetime_stats["max"]) 141 | for dt in non_null: 142 | assert dt >= min_bound 143 | assert dt <= max_bound 144 | assert all(pd.isna(decoded.iloc[4:7])) 145 | 146 | def test_date_day_clamping(self, no_clip_stats): 147 | s = pd.Series(["2021-04-31"]) 148 | decoded = decode_language_datetime(s, no_clip_stats) 149 | expected = pd.Timestamp("2021-04-30 00:00:00") 150 | assert decoded.iloc[0] == expected 151 | 152 | def test_time_extraction(self, no_clip_stats): 153 | s = pd.Series(["2021-07-15T23:59:59.123"]) 154 | decoded = decode_language_datetime(s, no_clip_stats) 155 | expected = pd.Timestamp("2021-07-15 23:59:59.123") 156 | assert decoded.iloc[0] == expected 157 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
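# A minimal, illustrative sketch of instantiating the from-scratch LSTM language model
# defined below. Vocabulary size and batch shapes are arbitrary example values.
def _lstm_sketch():
    import torch

    config = LSTMFromScratchConfig(vocab_size=100, embedding_size=64, hidden_size=64, num_layers=1)
    model = LSTMFromScratchLMHeadModel(config)
    input_ids = torch.randint(0, 100, (2, 8))  # batch of 2 sequences, 8 tokens each
    attention_mask = torch.ones_like(input_ids)
    # passing labels makes forward() return the causal LM loss alongside the logits
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    return out.loss, out.logits.shape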
14 | 15 | import logging 16 | 17 | import torch 18 | import torch.nn as nn 19 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, GenerationMixin, PretrainedConfig, PreTrainedModel 20 | from transformers.modeling_outputs import CausalLMOutput 21 | 22 | _LOG = logging.getLogger(__name__) 23 | 24 | 25 | class LSTMFromScratchConfig(PretrainedConfig): 26 | model_type = model_id = "MOSTLY_AI/LSTMFromScratch-3m" 27 | 28 | # Map standard transformer attributes to our custom LSTM attributes 29 | attribute_map = { 30 | "num_hidden_layers": "num_layers", 31 | } 32 | 33 | def __init__( 34 | self, 35 | vocab_size: int | None = None, 36 | embedding_size: int = 256, 37 | hidden_size: int = 256, 38 | num_layers: int = 1, 39 | dropout: float = 0.25, 40 | with_dp: bool = False, 41 | **kwargs, 42 | ): 43 | self.vocab_size = vocab_size 44 | self.embedding_size = embedding_size 45 | self.hidden_size = hidden_size 46 | self.num_layers = num_layers 47 | self.dropout = dropout 48 | self.with_dp = with_dp 49 | super().__init__(**kwargs) 50 | 51 | 52 | class LSTMFromScratchLMHeadModel(PreTrainedModel, GenerationMixin): 53 | config_class = LSTMFromScratchConfig 54 | 55 | def __init__(self, config: LSTMFromScratchConfig): 56 | super().__init__(config) 57 | self.config = config 58 | 59 | self.embedding = nn.Embedding(self.config.vocab_size, self.config.embedding_size) 60 | self.dropout = nn.Dropout(self.config.dropout) 61 | if self.config.with_dp: 62 | from opacus.layers import DPLSTM 63 | 64 | lstm_cls = DPLSTM 65 | else: 66 | lstm_cls = nn.LSTM 67 | self.lstm = lstm_cls( 68 | input_size=self.config.embedding_size, 69 | hidden_size=self.config.hidden_size, 70 | num_layers=self.config.num_layers, 71 | dropout=self.config.dropout if self.config.num_layers > 1 else 0.0, 72 | batch_first=True, 73 | ) 74 | self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size) 75 | self.loss_fn = nn.CrossEntropyLoss() 76 | 77 | # this will be filled by left_to_right_padding() during the generation 78 | self.pad_token_id = None 79 | 80 | def forward( 81 | self, 82 | input_ids: torch.Tensor, 83 | attention_mask: torch.Tensor, 84 | labels: torch.Tensor | None = None, 85 | **kwargs, 86 | ) -> CausalLMOutput: 87 | lengths = attention_mask.sum(dim=1) 88 | embeddings = self.embedding(input_ids) 89 | embeddings = self.dropout(embeddings) 90 | 91 | # (DP)LSTM layers without pack_padded_sequence/pad_packed_sequence 92 | lstm_outputs, _ = self.lstm(embeddings) 93 | 94 | logits = self.lm_head(lstm_outputs) 95 | 96 | loss = None 97 | if labels is not None: 98 | labels = labels[:, 1:].contiguous() 99 | shifted_prediction_scores = logits[:, :-1, :].contiguous() 100 | loss = self.loss_fn(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 101 | else: 102 | # overwrite the logit of the last time step with the logit of the actual last token 103 | # so that Hugging Face Transformers' generate() will sample on the right probabilities 104 | logits[:, -1, :] = torch.stack([logits[i, length - 1, :] for i, length in enumerate(lengths)]) 105 | return CausalLMOutput( 106 | loss=loss, 107 | logits=logits, 108 | ) 109 | 110 | def prepare_inputs_for_generation( 111 | self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs 112 | ) -> dict[str, torch.Tensor]: 113 | """ 114 | This function is mandatory so that the model is able to use the Hugging Face `.generate()` method. 
115 | Since `.generate()` works with left-padded sequences but the model is trained with right-padded sequences, 116 | we need to convert the padding side here to make it work properly. 117 | """ 118 | lengths = attention_mask.sum(dim=1) 119 | return { 120 | "input_ids": self.left_to_right_padding(input_ids, lengths), 121 | "attention_mask": attention_mask, 122 | } 123 | 124 | def left_to_right_padding(self, left_padded_tensors: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: 125 | batch_size, max_length = left_padded_tensors.size() 126 | indices = torch.nonzero(lengths < max_length) 127 | if len(indices) == 0: 128 | # none of the samples are padded, so we can just return them as they are 129 | return left_padded_tensors 130 | else: 131 | if self.pad_token_id is None: 132 | # get the pad token id from the first padded sample 133 | self.pad_token_id = left_padded_tensors[indices[0], -1].item() 134 | right_padded_tensors = torch.full_like(left_padded_tensors, self.pad_token_id) 135 | for i in range(batch_size): 136 | right_padded_tensors[i, : lengths[i]] = left_padded_tensors[i, max_length - lengths[i] :] 137 | return right_padded_tensors 138 | 139 | 140 | def register_mostly_lstm_model(): 141 | # register the model so that we can load it with `AutoModelForCausalLM.from_pretrained()` later 142 | AutoConfig.register(LSTMFromScratchConfig.model_id, LSTMFromScratchConfig) 143 | AutoModel.register(LSTMFromScratchConfig, LSTMFromScratchLMHeadModel) 144 | AutoModelForCausalLM.register(LSTMFromScratchConfig, LSTMFromScratchLMHeadModel) 145 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/numeric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
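# A minimal, illustrative sketch of the numeric analyze -> reduce -> encode/decode path
# implemented below. The values are made up and value protection is disabled so the tiny
# sample keeps its true min/max.
def _language_numeric_sketch():
    import pandas as pd

    values = pd.Series([1.5, 2.25, 3.0], name="amount")
    root_keys = pd.Series([1, 2, 3], name="id")
    stats = analyze_reduce_language_numeric(
        [analyze_language_numeric(values, root_keys)], value_protection=False
    )
    # encoding clips values to the [min, max] range; decoding parses model output strings
    # back into numbers, rounds them to max_scale, and applies the same clipping
    encoded = encode_language_numeric(values, stats)
    decoded = decode_language_numeric(pd.Series(["2.75", "9999", "oops"]), stats)
    return encoded, decoded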
14 | import numpy as np 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import ( 18 | ANALYZE_MIN_MAX_TOP_N, 19 | ANALYZE_REDUCE_MIN_MAX_N, 20 | compute_log_histogram, 21 | dp_approx_bounds, 22 | get_stochastic_rare_threshold, 23 | safe_convert_numeric, 24 | ) 25 | from mostlyai.engine._encoding_types.tabular.numeric import _type_safe_numeric_series 26 | from mostlyai.engine.domain import ModelEncodingType 27 | 28 | 29 | def analyze_language_numeric(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 30 | values = safe_convert_numeric(values) 31 | # compute log histogram for DP bounds 32 | log_hist = compute_log_histogram(values.dropna()) 33 | 34 | # determine lowest/highest values by root ID, and return top ANALYZE_MIN_MAX_TOP_N 35 | df = pd.concat([root_keys, values], axis=1) 36 | min_values = df.groupby(root_keys.name)[values.name].min().dropna() 37 | min_n = min_values.sort_values(ascending=True).head(ANALYZE_MIN_MAX_TOP_N).tolist() 38 | max_values = df.groupby(root_keys.name)[values.name].max().dropna() 39 | max_n = max_values.sort_values(ascending=False).head(ANALYZE_MIN_MAX_TOP_N).tolist() 40 | 41 | # determine if there are any NaN values 42 | has_nan = bool(values.isna().any()) 43 | 44 | # determine max scale 45 | def count_scale(num: float) -> int: 46 | # represent number as fixed point string, remove trailing zeros and decimal point 47 | num = format(num, "f").rstrip("0").rstrip(".") 48 | if "." in num: 49 | # in case of decimal, return number of digits after decimal point 50 | return len(num.split(".")[1]) 51 | # in case of integer, return 0 52 | return 0 53 | 54 | max_scale = int(values.apply(count_scale).max()) 55 | 56 | stats = { 57 | "has_nan": has_nan, 58 | "max_scale": max_scale, 59 | "min_n": min_n, 60 | "max_n": max_n, 61 | "log_hist": log_hist, 62 | } 63 | return stats 64 | 65 | 66 | def analyze_reduce_language_numeric( 67 | stats_list: list[dict], 68 | value_protection: bool = True, 69 | value_protection_epsilon: float | None = None, 70 | ) -> dict: 71 | # check for occurrence of NaN values 72 | has_nan = any([j["has_nan"] for j in stats_list]) 73 | 74 | # determine max scale 75 | max_scale = max([j["max_scale"] for j in stats_list]) 76 | 77 | reduced_min_n = sorted([v for min_n in [j["min_n"] for j in stats_list] for v in min_n], reverse=False) 78 | reduced_max_n = sorted([v for max_n in [j["max_n"] for j in stats_list] for v in max_n], reverse=True) 79 | if value_protection: 80 | if len(reduced_min_n) < ANALYZE_REDUCE_MIN_MAX_N or len(reduced_max_n) < ANALYZE_REDUCE_MIN_MAX_N: 81 | # protect all values if there are less than ANALYZE_REDUCE_MIN_MAX_N values 82 | reduced_min = None 83 | reduced_max = None 84 | else: 85 | if value_protection_epsilon is not None: 86 | # Sum up log histograms bin-wise from all partitions 87 | log_hist = [sum(bin) for bin in zip(*[j["log_hist"] for j in stats_list])] 88 | reduced_min, reduced_max = dp_approx_bounds(log_hist, value_protection_epsilon) 89 | if reduced_min is not None and reduced_max is not None and max_scale == 0: 90 | reduced_min = int(reduced_min) 91 | reduced_max = int(reduced_max) 92 | else: 93 | reduced_min = reduced_min_n[get_stochastic_rare_threshold(min_threshold=5)] 94 | reduced_max = reduced_max_n[get_stochastic_rare_threshold(min_threshold=5)] 95 | else: 96 | reduced_min = reduced_min_n[0] if len(reduced_min_n) > 0 else None 97 | reduced_max = reduced_max_n[0] if len(reduced_max_n) > 0 else None 98 | 99 | stats = { 100 | "encoding_type": 
ModelEncodingType.language_numeric.value, 101 | "has_nan": has_nan, 102 | "max_scale": max_scale, 103 | "min": reduced_min, 104 | "max": reduced_max, 105 | } 106 | 107 | return stats 108 | 109 | 110 | def encode_language_numeric(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.DataFrame: 111 | values = safe_convert_numeric(values) 112 | # try to convert to int, if possible 113 | dtype = "Int64" if stats["max_scale"] == 0 else "Float64" 114 | if dtype == "Int64": 115 | values = values.round() 116 | try: 117 | values = values.astype(dtype) 118 | except TypeError: 119 | if dtype == "Int64": # if couldn't safely convert to int, stick to float 120 | dtype = "Float64" 121 | values = values.astype(dtype) 122 | # reset index, as `values.mask` can throw errors for misaligned indices 123 | values.reset_index(drop=True, inplace=True) 124 | if stats["min"] is not None: 125 | reduced_min = _type_safe_numeric_series([stats["min"]], dtype).iloc[0] 126 | values.loc[values < reduced_min] = reduced_min 127 | if stats["max"] is not None: 128 | reduced_max = _type_safe_numeric_series([stats["max"]], dtype).iloc[0] 129 | values.loc[values > reduced_max] = reduced_max 130 | return values 131 | 132 | 133 | def decode_language_numeric(x: pd.Series, stats: dict[str, str]) -> pd.Series: 134 | x = pd.to_numeric(x, errors="coerce") 135 | x = x.round(stats["max_scale"]) 136 | if stats["min"] is not None: 137 | reduced_min = np.dtype(x.dtype).type(stats["min"]) 138 | x.loc[x < reduced_min] = reduced_min 139 | if stats["max"] is not None: 140 | reduced_max = np.dtype(x.dtype).type(stats["max"]) 141 | x.loc[x > reduced_max] = reduced_max 142 | dtype = "Int64" if stats["max_scale"] == 0 else float 143 | return x.astype(dtype) 144 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/language/datetime.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
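# A minimal, illustrative sketch of the analyze -> reduce -> encode path implemented
# below: encoding clips out-of-range timestamps to the reduced [min, max] bounds. The
# dates are made up and value protection is disabled for the tiny sample.
def _language_datetime_encode_sketch():
    import pandas as pd

    values = pd.Series(["2010-01-01", "2015-06-15", "2020-12-31"], name="event_date")
    root_keys = pd.Series([1, 2, 3], name="id")
    stats = analyze_reduce_language_datetime(
        [analyze_language_datetime(values, root_keys)], value_protection=False
    )
    # timestamps earlier than stats["min"] or later than stats["max"] are clipped
    encoded = encode_language_datetime(pd.Series(["1990-01-01", "2030-01-01"]), stats)
    return stats, encoded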
14 | import calendar 15 | 16 | import numpy as np 17 | import pandas as pd 18 | 19 | from mostlyai.engine._common import ( 20 | ANALYZE_MIN_MAX_TOP_N, 21 | ANALYZE_REDUCE_MIN_MAX_N, 22 | compute_log_histogram, 23 | dp_approx_bounds, 24 | get_stochastic_rare_threshold, 25 | safe_convert_datetime, 26 | ) 27 | 28 | 29 | def analyze_language_datetime(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 30 | values = safe_convert_datetime(values) 31 | # compute log histogram for DP bounds 32 | log_hist = compute_log_histogram(values.dropna().astype("int64")) 33 | 34 | df = pd.concat([root_keys, values], axis=1) 35 | # determine lowest/highest values by root ID, and return Top 10 36 | min_dates = df.groupby(root_keys.name)[values.name].min().dropna() 37 | min_n = min_dates.sort_values(ascending=True).head(ANALYZE_MIN_MAX_TOP_N).astype(str).tolist() 38 | max_dates = df.groupby(root_keys.name)[values.name].max().dropna() 39 | max_n = max_dates.sort_values(ascending=False).head(ANALYZE_MIN_MAX_TOP_N).astype(str).tolist() 40 | # determine if there are any NaN values 41 | has_nan = bool(values.isna().any()) 42 | # return stats 43 | stats = { 44 | "has_nan": has_nan, 45 | "min_n": min_n, 46 | "max_n": max_n, 47 | "log_hist": log_hist, 48 | } 49 | return stats 50 | 51 | 52 | def analyze_reduce_language_datetime( 53 | stats_list: list[dict], 54 | value_protection: bool = True, 55 | value_protection_epsilon: float | None = None, 56 | ) -> dict: 57 | # check if there are missing values 58 | has_nan = any([j["has_nan"] for j in stats_list]) 59 | reduced_min_n = sorted([v for min_n in [j["min_n"] for j in stats_list] for v in min_n], reverse=False) 60 | reduced_max_n = sorted([v for max_n in [j["max_n"] for j in stats_list] for v in max_n], reverse=True) 61 | if value_protection: 62 | if len(reduced_min_n) < ANALYZE_REDUCE_MIN_MAX_N or len(reduced_max_n) < ANALYZE_REDUCE_MIN_MAX_N: 63 | # protect all values if there are less than ANALYZE_REDUCE_MIN_MAX_N values 64 | reduced_min = None 65 | reduced_max = None 66 | else: 67 | if value_protection_epsilon is not None: 68 | if any(len(v) > 10 for v in reduced_min_n + reduced_max_n): 69 | dt_format = "%Y-%m-%d %H:%M:%S" 70 | else: 71 | dt_format = "%Y-%m-%d" 72 | # Sum up log histograms bin-wise from all partitions 73 | log_hist = [sum(bin) for bin in zip(*[j["log_hist"] for j in stats_list])] 74 | reduced_min, reduced_max = dp_approx_bounds(log_hist, value_protection_epsilon) 75 | if reduced_min is not None and reduced_max is not None: 76 | # convert back to the original string format 77 | reduced_min = pd.to_datetime(int(reduced_min), unit="us").strftime(dt_format) 78 | reduced_max = pd.to_datetime(int(reduced_max), unit="us").strftime(dt_format) 79 | else: 80 | reduced_min = str(reduced_min_n[get_stochastic_rare_threshold(min_threshold=5)]) 81 | reduced_max = str(reduced_max_n[get_stochastic_rare_threshold(min_threshold=5)]) 82 | else: 83 | reduced_min = str(reduced_min_n[0]) if len(reduced_min_n) > 0 else None 84 | reduced_max = str(reduced_max_n[0]) if len(reduced_max_n) > 0 else None 85 | stats = { 86 | "has_nan": has_nan, 87 | "min": reduced_min, 88 | "max": reduced_max, 89 | } 90 | return stats 91 | 92 | 93 | def _clip_datetime(values: pd.Series, stats: dict) -> pd.Series: 94 | if stats["min"] is not None: 95 | reduced_min = np.datetime64(stats["min"], "ns") 96 | values.loc[values < reduced_min] = reduced_min 97 | if stats["max"] is not None: 98 | reduced_max = np.datetime64(stats["max"], "ns") 99 | values.loc[values > 
reduced_max] = reduced_max 100 | return values 101 | 102 | 103 | def encode_language_datetime(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.Series: 104 | # convert 105 | values = safe_convert_datetime(values) 106 | values = values.copy() 107 | # reset index, as `values.mask` can throw errors for misaligned indices 108 | values.reset_index(drop=True, inplace=True) 109 | # replace extreme values with min/max 110 | values = _clip_datetime(values, stats) 111 | return values 112 | 113 | 114 | def decode_language_datetime(x: pd.Series, stats: dict[str, str]) -> pd.Series: 115 | x = x.where(~x.isin(["", "_INVALID_"]), np.nan) 116 | 117 | valid_mask = ( 118 | x.str.len().ge(10) 119 | & x.str.slice(0, 4).str.isdigit() 120 | & x.str.slice(5, 7).str.isdigit() 121 | & x.str.slice(8, 10).str.isdigit() 122 | ) 123 | if valid_mask.sum() > 0: # expected "YYYY-MM-DD" prefix 124 | # handle the date portion, ensuring validity 125 | years = x[valid_mask].str.slice(0, 4).astype(int) 126 | months = x[valid_mask].str.slice(5, 7).astype(int) 127 | days = x[valid_mask].str.slice(8, 10).astype(int) 128 | 129 | # clamp days according to maximum possible day of the month of a given year 130 | last_days = np.array([calendar.monthrange(y, m)[1] for y, m in zip(years, months)]) 131 | clamped_days = np.minimum(days, last_days) 132 | 133 | # rebuild the date portion 134 | new_date = ( 135 | years.astype(str).str.zfill(4) 136 | + "-" 137 | + months.astype(str).str.zfill(2) 138 | + "-" 139 | + pd.Series(clamped_days, index=years.index).astype(str).str.zfill(2) 140 | ) 141 | 142 | # handle the time portion, ensuring validity 143 | remainder = x[valid_mask].str.slice(10) 144 | 145 | time_regex = r"^[ T]?(\d{2}:\d{2}:\d{2}(?:\.\d+)?)" 146 | valid_time = remainder.str.extract(time_regex, expand=False) 147 | valid_time = valid_time.fillna("00:00:00") 148 | valid_time = " " + valid_time 149 | 150 | new_date = new_date + valid_time 151 | x.loc[valid_mask] = new_date 152 | 153 | x = pd.to_datetime(x, errors="coerce") 154 | x = _clip_datetime(x, stats) 155 | return x.astype("datetime64[ns]") 156 | -------------------------------------------------------------------------------- /mostlyai/engine/_training_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
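# Shared training helpers: ProgressMessage (rounded progress payloads), EarlyStopper (stop once val_loss
# stops improving), ModelCheckpoint (persist the best weights plus optimizer/LR-scheduler and DP-accountant
# state) and gpu_memory_cleanup (decorator that frees CUDA memory after the wrapped call).
#
# Minimal sketch of the early-stopping behaviour (loss values are made up):
#
#   stopper = EarlyStopper(val_loss_patience=2)
#   stopper(1.00)  # improvement, counter reset     -> False
#   stopper(1.10)  # no improvement, counter = 1    -> False
#   stopper(1.20)  # no improvement, counter = 2    -> False
#   stopper(1.30)  # counter = 3 > patience, stop   -> True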
14 | 15 | import abc 16 | import functools 17 | import gc 18 | import logging 19 | import time 20 | 21 | import pandas as pd 22 | import torch 23 | from opacus.accountants import IAccountant 24 | from pydantic import BaseModel, Field, field_validator 25 | 26 | from mostlyai.engine._workspace import Workspace 27 | 28 | _LOG = logging.getLogger(__name__) 29 | 30 | 31 | class ProgressMessage(BaseModel, extra="allow"): 32 | epoch: float | None = Field(None, description="Current epoch number") 33 | is_checkpoint: bool | int | None = Field(0, description="Whether this progress is a checkpoint") 34 | steps: int | None = Field(None, description="Number of processed steps") 35 | samples: int | None = Field(None, description="Number of processed samples") 36 | trn_loss: float | None = Field(None, description="Training loss") 37 | val_loss: float | None = Field(None, description="Validation loss") 38 | total_time: float | None = Field(None, description="Elapsed total time (s)") 39 | learn_rate: float | None = Field(None, description="Learning rate") 40 | dp_eps: float | None = Field(None, description="Differential privacy epsilon") 41 | dp_delta: float | None = Field(None, description="Differential privacy delta") 42 | 43 | @field_validator("epoch", "trn_loss", "val_loss", "learn_rate", "total_time", "dp_eps", "dp_delta") 44 | @classmethod 45 | def round_float(cls, v, info) -> float: 46 | field_decimal_places = { 47 | "epoch": 2, 48 | "trn_loss": 4, 49 | "val_loss": 4, 50 | "learn_rate": 6, 51 | "total_time": 1, 52 | "dp_eps": 2, 53 | "dp_delta": 8, 54 | } 55 | if isinstance(v, float) and info.field_name in field_decimal_places: 56 | return round(v, field_decimal_places[info.field_name]) 57 | return v 58 | 59 | @field_validator("is_checkpoint") 60 | @classmethod 61 | def cast_to_int(cls, v) -> int: 62 | return int(v) 63 | 64 | 65 | class EarlyStopper: 66 | """ 67 | Stop training when val_loss stopped improving for a while 68 | """ 69 | 70 | def __init__(self, val_loss_patience: int) -> None: 71 | self.val_loss_patience = val_loss_patience 72 | self.best_loss = float("inf") 73 | self.val_loss_cnt = 0 74 | 75 | def __call__(self, val_loss: float) -> bool: 76 | do_stop = False 77 | # check val_loss 78 | if not pd.isna(val_loss) and val_loss < self.best_loss: 79 | # remember best val_loss 80 | self.best_loss = val_loss 81 | # reset counter 82 | self.val_loss_cnt = 0 83 | else: 84 | self.val_loss_cnt += 1 85 | if self.val_loss_cnt > self.val_loss_patience: 86 | _LOG.info("early stopping: val_loss stopped improving") 87 | do_stop = True 88 | return do_stop 89 | 90 | 91 | class ModelCheckpoint(abc.ABC): 92 | """ 93 | Save model weights for best model. 
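If provided, the optimizer, LR-scheduler and DP-accountant state are persisted alongside the weights, so that training can later resume from the best checkpoint.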
94 | """ 95 | 96 | def __init__(self, workspace: Workspace, initial_best_val_loss: float = float("inf")) -> None: 97 | self.workspace = workspace 98 | self.best_val_loss = initial_best_val_loss 99 | self.last_save_time = time.time() 100 | self.save_count = 0 101 | 102 | def optimizer_and_lr_scheduler_paths_exist(self) -> bool: 103 | return self.workspace.model_optimizer_path.exists() and self.workspace.model_lr_scheduler_path.exists() 104 | 105 | @abc.abstractmethod 106 | def model_weights_path_exists(self) -> None: 107 | pass 108 | 109 | def clear_checkpoint(self): 110 | self.workspace.model_optimizer_path.unlink(missing_ok=True) 111 | self.workspace.model_lr_scheduler_path.unlink(missing_ok=True) 112 | self._clear_model_weights() 113 | 114 | def save_checkpoint_if_best( 115 | self, 116 | val_loss: float, 117 | model: torch.nn.Module, 118 | optimizer: torch.optim.Optimizer | None = None, 119 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, 120 | dp_accountant: IAccountant | None = None, 121 | ) -> bool: 122 | # save model weights if validation loss has improved 123 | if val_loss < self.best_val_loss: 124 | self.best_val_loss = val_loss 125 | self.save_checkpoint(model, optimizer, lr_scheduler, dp_accountant) 126 | return True 127 | else: 128 | return False 129 | 130 | def save_checkpoint( 131 | self, 132 | model: torch.nn.Module, 133 | optimizer: torch.optim.Optimizer | None = None, 134 | lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, 135 | dp_accountant: IAccountant | None = None, 136 | ) -> None: 137 | if optimizer is not None and lr_scheduler is not None: 138 | torch.save(optimizer.state_dict(), self.workspace.model_optimizer_path) 139 | torch.save(lr_scheduler.state_dict(), self.workspace.model_lr_scheduler_path) 140 | if dp_accountant is not None: 141 | torch.save(dp_accountant.state_dict(), self.workspace.model_dp_accountant_path) 142 | self._save_model_weights(model) 143 | self.last_save_time = time.time() 144 | self.save_count += 1 145 | 146 | def has_saved_once(self) -> bool: 147 | return self.save_count > 0 148 | 149 | @abc.abstractmethod 150 | def _save_model_weights(self, model: torch.nn.Module) -> None: 151 | pass 152 | 153 | @abc.abstractmethod 154 | def _clear_model_weights(self) -> None: 155 | pass 156 | 157 | 158 | def check_early_training_exit(workspace: Workspace, trn_cnt: int, val_cnt: int) -> bool: 159 | trn_files = workspace.encoded_data_trn.fetch_all() 160 | val_files = workspace.encoded_data_val.fetch_all() 161 | return any((len(trn_files) == 0, len(val_files) == 0, trn_cnt == 0, val_cnt == 0)) 162 | 163 | 164 | def gpu_memory_cleanup(func): 165 | """Decorator to clean up GPU memory after function execution.""" 166 | 167 | @functools.wraps(func) 168 | def wrapper(*args, **kwargs): 169 | try: 170 | return func(*args, **kwargs) 171 | finally: 172 | for _ in range(5): 173 | gc.collect() 174 | torch.cuda.empty_cache() 175 | 176 | return wrapper 177 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Iterator, Mapping 16 | from dataclasses import dataclass 17 | from typing import Any 18 | 19 | from transformers import BatchEncoding, DataCollatorForLanguageModeling, LlamaTokenizerFast, PreTrainedTokenizerFast 20 | from transformers.data.data_collator import _torch_collate_batch, pad_without_fast_tokenizer_warning 21 | 22 | from mostlyai.engine.domain import ModelEncodingType 23 | 24 | ################# 25 | ### TOKENIZER ### 26 | ################# 27 | 28 | 29 | def train_tokenizer( 30 | training_iterator: Iterator | list | None = None, 31 | tokenizer_kwargs: dict[str, Any] | None = None, 32 | tgt_stats: dict[str, Any] | None = None, 33 | ): 34 | if tokenizer_kwargs is None: 35 | tokenizer_kwargs = {} 36 | from tokenizers import Tokenizer, decoders 37 | from tokenizers.models import BPE 38 | from tokenizers.normalizers import Replace 39 | from tokenizers.pre_tokenizers import Metaspace, Punctuation, Sequence, Split 40 | from tokenizers.trainers import BpeTrainer 41 | 42 | special_tokens = { 43 | "unk_token": "<unk>", 44 | "pad_token": "<pad>", 45 | "bos_token": "<s>", 46 | "eos_token": "</s>", 47 | } 48 | SPECIAL_TOKENS = list(special_tokens.values()) 49 | NEW_LINE_VALUE = "\n" 50 | NEW_LINE_SYMBOL = "\u240a" # https://www.fileformat.info/info/unicode/char/240a/index.htm 51 | MIN_FREQ_MERGE = 20 52 | VOCAB_SIZE = 5000 53 | 54 | # add initial alphabet for numeric and datetime columns if needed 55 | has_numeric_columns = any( 56 | col_stats["encoding_type"] == ModelEncodingType.language_numeric for col_stats in tgt_stats["columns"].values() 57 | ) 58 | has_datetime_columns = any( 59 | col_stats["encoding_type"] == ModelEncodingType.language_datetime for col_stats in tgt_stats["columns"].values() 60 | ) 61 | initial_alphabet = set() 62 | if has_numeric_columns: 63 | # FIXME: maybe the set can be more fine-grained based on max_scale in stats 64 | initial_alphabet |= {str(i) for i in range(10)} | {".", "-", "+", "e", "E"} 65 | if has_datetime_columns: 66 | initial_alphabet |= {str(i) for i in range(10)} | {".", "-", ":", "T", "Z"} 67 | initial_alphabet = list(initial_alphabet) 68 | 69 | # Builds a BPE raw_tokenizer, and optionally trains it based on provided text 70 | training_iterator = training_iterator or [] # allow easy training skip 71 | raw_tokenizer = Tokenizer(BPE(unk_token=special_tokens["unk_token"])) 72 | trainer = BpeTrainer( 73 | initial_alphabet=initial_alphabet, 74 | special_tokens=SPECIAL_TOKENS, 75 | min_frequency=MIN_FREQ_MERGE, 76 | vocab_size=VOCAB_SIZE, 77 | show_progress=False, 78 | ) 79 | raw_tokenizer.normalizer = Replace(NEW_LINE_VALUE, NEW_LINE_SYMBOL) 80 | raw_tokenizer.pre_tokenizer = Sequence( 81 | [ 82 | Metaspace(), 83 | Split(pattern=NEW_LINE_SYMBOL, behavior="isolated"), 84 | Punctuation(), 85 | ] 86 | ) 87 | raw_tokenizer.decoder = decoders.Sequence( 88 | [ 89 | decoders.Metaspace(), 90 | decoders.Replace(NEW_LINE_SYMBOL, NEW_LINE_VALUE), 91 | ] 92 | ) 93 | raw_tokenizer.train_from_iterator(iterator=training_iterator, trainer=trainer) 94 | tokenizer = 
LlamaTokenizerFast(tokenizer_object=raw_tokenizer, **special_tokens, **tokenizer_kwargs) 95 | return tokenizer 96 | 97 | 98 | def tokenize_fn( 99 | text: dict[str, str] | dict[str, list[str]] | list[str], 100 | tokenizer: PreTrainedTokenizerFast, 101 | text_key: str | None = None, 102 | return_tensors: str | None = None, 103 | padding: bool | str = True, 104 | truncation: bool = True, 105 | add_bos_token: bool = True, 106 | add_eos_token: bool = True, 107 | max_length: int = 1024, 108 | ) -> BatchEncoding: 109 | if text_key: 110 | text = text[text_key] 111 | # make sure the tokenizer is configured as expected 112 | if getattr(tokenizer, "add_bos_token", False) or getattr(tokenizer, "add_eos_token", False): 113 | raise RuntimeError("Tokenizer must be configured as add_bos_token=False and add_eos_token=False") 114 | if tokenizer.bos_token is None or tokenizer.eos_token is None: 115 | raise RuntimeError("Tokenizer must have bos_token and eos_token set") 116 | prefix = tokenizer.bos_token if add_bos_token else "" 117 | suffix = tokenizer.eos_token if add_eos_token else "" 118 | # NOTE: here we add bos/eos tokens before truncation and padding, 119 | # which means that they may be truncated for long sequences 120 | if isinstance(text, str): 121 | text = f"{prefix}{text}{suffix}" 122 | else: 123 | for i, t in enumerate(text): 124 | text[i] = f"{prefix}{t}{suffix}" 125 | tokenized_content = tokenizer( 126 | text, 127 | padding=padding, 128 | truncation=truncation, 129 | max_length=max_length, 130 | return_tensors=return_tensors, 131 | ) 132 | return tokenized_content 133 | 134 | 135 | ##################### 136 | ### DATA COLLATOR ### 137 | ##################### 138 | 139 | 140 | @dataclass 141 | class MostlyDataCollatorForLanguageModeling(DataCollatorForLanguageModeling): 142 | def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]: 143 | """ 144 | A variation of the original `DataCollatorForLanguageModeling.torch_call` method. 145 | 146 | This method can mask tokens based on the attention mask, so that bos and eos tokens will not be masked 147 | even if they are identical to pad token. 148 | If attention mask is not provided, it will fall back to masking pad tokens. 149 | """ 150 | if isinstance(examples[0], Mapping): 151 | batch = pad_without_fast_tokenizer_warning( 152 | self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=None 153 | ) 154 | else: 155 | batch = {"input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=None)} 156 | 157 | labels = batch["input_ids"].clone() 158 | attention_mask = batch.get("attention_mask", None) 159 | if attention_mask is not None: 160 | labels[(attention_mask == 0)] = -100 161 | else: 162 | if self.tokenizer.pad_token_id is not None: 163 | labels[labels == self.tokenizer.pad_token_id] = -100 164 | batch["labels"] = labels 165 | return batch 166 | -------------------------------------------------------------------------------- /mostlyai/engine/_language/engine/vllm_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import os 18 | 19 | os.environ["VLLM_USE_V1"] = "1" 20 | 21 | 22 | import time 23 | from os import PathLike 24 | 25 | import torch 26 | from peft import PeftConfig 27 | from pydantic import BaseModel 28 | from transformers import AutoConfig, AutoTokenizer 29 | from vllm import LLM, SamplingParams 30 | from vllm.distributed import cleanup_dist_env_and_memory 31 | from vllm.inputs.data import TokensPrompt 32 | from vllm.lora.request import LoRARequest 33 | from vllm.sampling_params import StructuredOutputsParams 34 | 35 | from mostlyai.engine._language.common import is_bf16_supported 36 | from mostlyai.engine._language.engine.base import EngineMetrics, LanguageEngine 37 | from mostlyai.engine._language.tokenizer_utils import tokenize_fn 38 | 39 | 40 | def get_dynamic_gpu_memory_utilization(utilization_ratio: float = 0.9) -> float: 41 | """ 42 | Calculate dynamic GPU memory utilization based on available memory. 43 | 44 | Args: 45 | utilization_ratio: Fraction of available GPU memory to use (default: 0.9) 46 | 47 | Returns: 48 | GPU memory utilization as a fraction. 49 | """ 50 | if not torch.cuda.is_available(): 51 | return utilization_ratio # fallback for non-GPU environments 52 | 53 | try: 54 | # Get free and total memory from CUDA 55 | free_memory, total_memory = torch.cuda.mem_get_info() 56 | 57 | # Use specified ratio of free memory 58 | target_memory = free_memory * utilization_ratio 59 | utilization = target_memory / total_memory 60 | 61 | # Ensure utilization is within reasonable bounds (0.1 to 0.95) 62 | return max(0.1, min(0.95, utilization)) 63 | 64 | except Exception: 65 | # Fallback to provided ratio if anything goes wrong 66 | return utilization_ratio 67 | 68 | 69 | class VLLMEngine(LanguageEngine): 70 | def __init__( 71 | self, model_path: PathLike | str, device: torch.device, max_new_tokens: int, tokenizer_max_length: int 72 | ): 73 | self.device = device 74 | self.tokenizer_max_length = tokenizer_max_length 75 | self.max_new_tokens = max_new_tokens 76 | 77 | peft_config = PeftConfig.from_pretrained(model_path) 78 | base_config = AutoConfig.from_pretrained(peft_config.base_model_name_or_path) 79 | 80 | model_path = str(model_path) 81 | self._lora_request = LoRARequest("adapter", 1, model_path) 82 | # Get max model length from config (different models use different attribute names) 83 | config_max_model_len = getattr( 84 | base_config, 85 | "max_position_embeddings", 86 | getattr(base_config, "n_positions", getattr(base_config, "max_sequence_length", 2048)), 87 | ) 88 | 89 | self.llm = LLM( 90 | model=peft_config.base_model_name_or_path, 91 | tokenizer=model_path, 92 | max_model_len=min(config_max_model_len, self.tokenizer_max_length + max_new_tokens), 93 | enable_lora=True, 94 | dtype=torch.bfloat16 if is_bf16_supported(device) else torch.float16, 95 | # enforce_eager=True, # results in big slowdown, but is needed when running pytest locally 96 | swap_space=0, 97 | disable_log_stats=True, 98 | tensor_parallel_size=torch.cuda.device_count(), 99 | 
gpu_memory_utilization=get_dynamic_gpu_memory_utilization(), 100 | ) 101 | self.tokenizer = AutoTokenizer.from_pretrained( 102 | model_path, 103 | padding_side="left", 104 | truncation_side="left", 105 | legacy=True, 106 | # these must be False at initialization, as we manually add them later in tokenize_fn 107 | add_bos_token=False, 108 | add_eos_token=False, 109 | ) 110 | self._prepared_schemas = None 111 | 112 | def get_default_batch_size(self) -> int: 113 | return 192 114 | 115 | def supports_json_enforcing(self) -> bool: 116 | return True 117 | 118 | def generate( 119 | self, text: list[str], sampling_temperature: float, sampling_top_p: float 120 | ) -> tuple[list[int], EngineMetrics]: 121 | tokenize_kwargs = dict( 122 | tokenizer=self.tokenizer, 123 | return_tensors=None, 124 | add_bos_token=True, 125 | add_eos_token=False, 126 | padding=False, 127 | truncation=True, 128 | max_length=self.tokenizer_max_length, # truncates input 129 | ) 130 | t_tokenize = time.time() 131 | inputs = tokenize_fn(text=text, **tokenize_kwargs) 132 | tokenize_time = time.time() - t_tokenize 133 | 134 | actual_batch_size = len(inputs["input_ids"]) 135 | 136 | # Create sampling params with guided decoding if schemas are prepared 137 | effective_schemas = self._prepared_schemas 138 | 139 | sampling_params = [] 140 | for i in range(actual_batch_size): 141 | structured_outputs = None 142 | if effective_schemas and i < len(effective_schemas): 143 | # Convert Pydantic model to JSON schema for structured output 144 | schema_dict = effective_schemas[i].model_json_schema() 145 | structured_outputs = StructuredOutputsParams(json=schema_dict) 146 | 147 | sampling_params.append( 148 | SamplingParams( 149 | max_tokens=self.max_new_tokens, 150 | temperature=sampling_temperature, 151 | top_p=sampling_top_p, 152 | structured_outputs=structured_outputs, 153 | ) 154 | ) 155 | t_generate = time.time() 156 | outputs = self.llm.generate( 157 | prompts=[TokensPrompt(prompt_token_ids=token_ids) for token_ids in inputs["input_ids"]], 158 | sampling_params=sampling_params, 159 | use_tqdm=False, 160 | lora_request=self._lora_request, 161 | ) 162 | generate_time = time.time() - t_generate 163 | metrics = EngineMetrics(tokenize_time=tokenize_time, generate_time=generate_time) 164 | return [r.outputs[0].token_ids for r in outputs], metrics 165 | 166 | def cleanup(self): 167 | del self.llm 168 | cleanup_dist_env_and_memory() 169 | 170 | def update_json_constraints(self, schemas: list[BaseModel] | None) -> None: 171 | """Update JSON schema constraints for the next generation call.""" 172 | self._prepared_schemas = list(schemas) if schemas else None 173 | 174 | def can_reuse_schemas(self) -> bool: 175 | """VLLMEngine can handle variable batch sizes since it creates sampling params per sample.""" 176 | return True 177 | -------------------------------------------------------------------------------- /mostlyai/engine/_encoding_types/tabular/character.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Character encoding splits any value into its characters, and encodes each position then separately as a categorical. 17 | """ 18 | 19 | import numpy as np 20 | import pandas as pd 21 | 22 | from mostlyai.engine._common import ( 23 | dp_non_rare, 24 | get_stochastic_rare_threshold, 25 | impute_from_non_nan_distribution, 26 | safe_convert_string, 27 | ) 28 | 29 | UNKNOWN_TOKEN = "\0" 30 | MAX_LENGTH_CHARS = 50 31 | 32 | 33 | def analyze_character(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: 34 | values = safe_convert_string(values) 35 | df_split = split_sub_columns_character(values) 36 | has_nan = sum(df_split["nan"]) > 0 37 | # count distinct root_keys per token for each character position 38 | df = pd.concat([root_keys, df_split], axis=1) 39 | characters = { 40 | sub_col: df.groupby(sub_col)[root_keys.name].nunique().to_dict() 41 | for sub_col in df_split.columns 42 | if sub_col.startswith("P") 43 | } 44 | stats = { 45 | "max_string_length": len(characters), 46 | "has_nan": has_nan, 47 | "characters": characters, 48 | } 49 | return stats 50 | 51 | 52 | def analyze_reduce_character( 53 | stats_list: list[dict], 54 | value_protection: bool = True, 55 | value_protection_epsilon: float | None = None, 56 | ) -> dict: 57 | # gather maximum string length across partitions 58 | max_string_length = max(stats["max_string_length"] for stats in stats_list) 59 | positions = [f"P{idx}" for idx in range(max_string_length)] 60 | # gather codes for each position 61 | codes: dict[str, dict[str, int]] = {pos: {} for pos in positions} 62 | for pos in positions: 63 | cnt_values: dict[str, int] = {} 64 | # sum up all counts for each token 65 | for item in stats_list: 66 | for value, count in item["characters"].get(pos, {}).items(): 67 | cnt_values[value] = cnt_values.get(value, 0) + count 68 | cnt_values = dict(sorted(cnt_values.items())) 69 | known_categories = list(cnt_values.keys()) 70 | if value_protection: 71 | if value_protection_epsilon is not None: 72 | categories, _ = dp_non_rare(cnt_values, value_protection_epsilon, threshold=5) 73 | else: 74 | rare_min = get_stochastic_rare_threshold(min_threshold=5) 75 | categories = [k for k in known_categories if cnt_values[k] >= rare_min] 76 | else: 77 | categories = known_categories 78 | # add special token for UNKNOWN at first position 79 | categories = [UNKNOWN_TOKEN] + [c for c in categories if c != UNKNOWN_TOKEN] 80 | # assign codes for each token 81 | codes[pos] = {categories[i]: i for i in range(len(categories))} 82 | # determine cardinalities 83 | cardinalities = {} 84 | has_nan = any([s["has_nan"] for s in stats_list]) 85 | if has_nan: 86 | cardinalities["nan"] = 2 # binary 87 | for sub_col, sub_col_codes in codes.items(): 88 | cardinalities[sub_col] = len(sub_col_codes) 89 | stats = { 90 | "has_nan": has_nan, 91 | "max_string_length": max_string_length, 92 | "codes": codes, 93 | "cardinalities": cardinalities, 94 | } 95 | return stats 96 | 97 | 98 | def encode_character(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.DataFrame: 99 | values = safe_convert_string(values) 100 | values, nan_mask = impute_from_non_nan_distribution(values, stats) 101 | max_string_length = stats["max_string_length"] 102 | df_split = split_sub_columns_character(values, max_string_length) 103 | for idx in range(max_string_length): 104 | sub_col = f"P{idx}" 105 | np_codes = 
np.array(pd.Categorical(df_split[sub_col], categories=stats["codes"][sub_col]).codes) 106 | np.place(np_codes, np_codes == -1, 0) 107 | df_split[sub_col] = np_codes 108 | if stats["has_nan"]: 109 | df_split["nan"] = nan_mask 110 | else: 111 | df_split.drop(["nan"], axis=1, inplace=True) 112 | return df_split 113 | 114 | 115 | def split_sub_columns_character( 116 | values: pd.Series, 117 | max_string_length: int | None = None, 118 | ) -> pd.DataFrame: 119 | if not pd.api.types.is_string_dtype(values): 120 | raise ValueError("expected to be string") 121 | is_na = pd.Series(values.isna().astype("int"), name="nan").to_frame() 122 | values = values.fillna("") 123 | # trim strings to a maximum length 124 | values = values.str.slice(stop=MAX_LENGTH_CHARS) 125 | # pad strings to string_length 126 | if max_string_length is None: 127 | max_string_length = values.str.len().max() 128 | max_string_length = ( 129 | int(max_string_length) # type: ignore 130 | if np.isscalar(max_string_length) and not np.isnan(max_string_length) 131 | else 0 132 | ) 133 | else: 134 | values = values.str.slice(stop=max_string_length) 135 | # explode to wide dataframe 136 | padded_values = values.str.ljust(max_string_length, UNKNOWN_TOKEN) 137 | chars_df = padded_values.str.split("", expand=True) 138 | if not chars_df.empty: 139 | chars_df = chars_df.drop([0, max_string_length + 1], axis=1) 140 | chars_df.columns = [f"P{idx}" for idx in range(max_string_length)] 141 | else: # chars_df.empty is True 142 | # even though the input is empty, we still need to return a dataframe with the correct columns 143 | chars_df = pd.DataFrame(columns=[f"P{idx}" for idx in range(max_string_length)]) 144 | df = pd.concat([is_na, chars_df], axis=1) 145 | return df 146 | 147 | 148 | def decode_character(df_encoded: pd.DataFrame, stats: dict) -> pd.Series: 149 | if len(stats["codes"].keys()) > 0: 150 | df_decoded = pd.DataFrame( 151 | { 152 | sub_col: pd.Series( 153 | pd.Categorical.from_codes(df_encoded[sub_col], categories=stats["codes"][sub_col]), 154 | dtype="string", 155 | ) 156 | for sub_col in stats["codes"].keys() 157 | }, 158 | ) 159 | values = df_decoded.apply(lambda item: "".join(item), axis=1, result_type="reduce").astype( 160 | str 161 | ) # necessary to keep string dtype for empty df_decoded 162 | # remove unknown tokens and strip trailing whitespaces 163 | values = values.apply(lambda item: item.replace(UNKNOWN_TOKEN, "")).str.rstrip() 164 | else: 165 | # handle de-generate case, where no tokens were stored 166 | values = pd.Series(pd.NA, index=range(df_encoded.shape[0])) 167 | if stats["has_nan"]: 168 | values[df_encoded["nan"] == 1] = pd.NA 169 | return values 170 | -------------------------------------------------------------------------------- /tests/unit/encoding_types/tabular/test_itt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
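# Round-trip tests for the ITT (relative datetime) encoding, which represents a sequence of timestamps via
# absolute start_* components (start_year/month/day/...) plus deltas to the previous row
# (itt_week/day/hour/minute/second). The cases below cover date-only and datetime values, all-None and
# empty inputs, 1:1 key mappings, and decoding that continues from previously generated steps (prev_steps).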
14 | 15 | import pandas as pd 16 | 17 | from mostlyai.engine._common import read_json, write_json 18 | from mostlyai.engine._encoding_types.tabular.itt import ( 19 | analyze_itt, 20 | analyze_reduce_itt, 21 | decode_itt, 22 | encode_itt, 23 | ) 24 | 25 | 26 | def test_itt_date(tmp_path): 27 | values = pd.to_datetime( 28 | pd.Series( 29 | [None, "1978-05-24", "1976-06-22", "1992-12-24", None], 30 | name="date", 31 | dtype="datetime64[us]", 32 | ) 33 | ) 34 | context_keys = pd.Series(["a", "a", "a", "b", "c"], name="__context_key") 35 | root_keys = context_keys.copy() 36 | root_keys.name = "__root_key" 37 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 38 | write_json(stats1, tmp_path / "stats1.json") 39 | stats1 = read_json(tmp_path / "stats1.json") 40 | stats = analyze_reduce_itt([stats1], value_protection=False) 41 | write_json(stats, tmp_path / "stats.json") 42 | stats = read_json(tmp_path / "stats.json") 43 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 44 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys) 45 | assert values.equals(df_decoded) 46 | 47 | 48 | def test_itt_datetime(tmp_path): 49 | values = pd.to_datetime( 50 | pd.Series( 51 | [ 52 | None, 53 | "1978-05-24 12:23:43", 54 | "1976-06-22 17:32:00", 55 | "1992-12-24 01:32:59", 56 | None, 57 | ], 58 | name="date", 59 | dtype="datetime64[us]", 60 | ) 61 | ) 62 | context_keys = pd.Series(["a", "a", "a", "b", "c"], name="__context_key") 63 | root_keys = context_keys.copy() 64 | root_keys.name = "__root_key" 65 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 66 | stats = analyze_reduce_itt([stats1], value_protection=False) 67 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 68 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys) 69 | assert values.equals(df_decoded) 70 | 71 | 72 | def test_itt_nones_only(tmp_path): 73 | values = pd.to_datetime(pd.Series([None, None, None], name="value", dtype="datetime64[us]")) 74 | context_keys = pd.Series(["a", "a", "b"], name="id") 75 | root_keys = pd.Series(["a", "a", "b"], name="rid") 76 | stats = analyze_reduce_itt([analyze_itt(values, root_keys, context_keys)], value_protection=False) 77 | df_encoded = encode_itt(values, stats, context_keys) 78 | df_decoded = decode_itt(df_encoded, stats, context_keys) 79 | assert all(df_decoded.isna()) 80 | 81 | 82 | def test_itt_empty(tmp_path): 83 | values = pd.Series([], name="value") 84 | root_keys = pd.Series([], name="rid") 85 | context_keys = pd.Series([], name="id") 86 | partition_stats = analyze_itt(values, root_keys, context_keys) 87 | stats = analyze_reduce_itt([partition_stats]) 88 | df_encoded = encode_itt(values, stats, context_keys) 89 | df_decoded = decode_itt(df_encoded, stats, context_keys) 90 | min_max_values = { 91 | "itt_day": 0, 92 | "itt_hour": 0, 93 | "itt_minute": 0, 94 | "itt_second": 0, 95 | "itt_week": 0, 96 | "start_day": 1, 97 | "start_hour": 0, 98 | "start_minute": 0, 99 | "start_month": 1, 100 | "start_second": 0, 101 | "start_year": 2022, 102 | } 103 | assert partition_stats == { 104 | "has_nan": False, 105 | "has_neg": False, 106 | "max_n": [], 107 | "max_values": min_max_values, 108 | "min_n": [], 109 | "min_values": min_max_values, 110 | "log_hist": [0.0] * 128, 111 | } 112 | assert stats == { 113 | "cardinalities": { 114 | "itt_day": 1, 115 | "itt_week": 1, 116 | "start_day": 1, 117 | "start_month": 1, 118 | 
"start_year": 1, 119 | }, 120 | "has_nan": False, 121 | "has_neg": False, 122 | "has_time": False, 123 | "max": None, 124 | "max_values": min_max_values, 125 | "min": None, 126 | "min_values": min_max_values, 127 | } 128 | assert df_encoded.empty, df_encoded.columns.tolist() == (True, []) 129 | assert df_decoded.empty, df_encoded.columns.tolist() == (True, []) 130 | 131 | 132 | def test_itt_1to1(tmp_path): 133 | values = pd.to_datetime( 134 | pd.Series( 135 | [None, "1978-05-24", "1976-06-22", "1992-12-24", None], 136 | name="date", 137 | dtype="datetime64[us]", 138 | ) 139 | ) 140 | context_keys = pd.Series(["a", "b", "c", "d", "e"], name="__context_key") 141 | root_keys = context_keys.copy() 142 | root_keys.name = "__root_key" 143 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 144 | write_json(stats1, tmp_path / "stats1.json") 145 | stats1 = read_json(tmp_path / "stats1.json") 146 | stats = analyze_reduce_itt([stats1], value_protection=False) 147 | write_json(stats, tmp_path / "stats.json") 148 | stats = read_json(tmp_path / "stats.json") 149 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 150 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys) 151 | assert values.equals(df_decoded) 152 | 153 | 154 | def test_itt_with_prev_steps(tmp_path): 155 | values = pd.to_datetime( 156 | pd.Series( 157 | ["1978-05-24", "1976-06-22", "1976-06-23", "1976-06-24"], 158 | name="date", 159 | dtype="datetime64[us]", 160 | ) 161 | ) 162 | context_keys = pd.Series(["a", "b", "b", "b"], name="__context_key") 163 | root_keys = context_keys.copy() 164 | root_keys.name = "__root_key" 165 | stats1 = analyze_itt(values=values, root_keys=root_keys, context_keys=context_keys) 166 | write_json(stats1, tmp_path / "stats1.json") 167 | stats1 = read_json(tmp_path / "stats1.json") 168 | stats = analyze_reduce_itt([stats1], value_protection=False) 169 | write_json(stats, tmp_path / "stats.json") 170 | stats = read_json(tmp_path / "stats.json") 171 | df_encoded = encode_itt(values=values, stats=stats, context_keys=context_keys) 172 | prev_steps = { 173 | "prev_dts": pd.DataFrame( 174 | { 175 | "__CONTEXT_KEYS": ["a", "b"], 176 | "__STARTS": pd.to_datetime(pd.Series(["1978-05-23", "1976-06-21"], dtype="datetime64[us]")), 177 | } 178 | ) 179 | } 180 | df_decoded = decode_itt(df_encoded=df_encoded, stats=stats, context_keys=context_keys, prev_steps=prev_steps) 181 | assert values.equals(df_decoded) 182 | --------------------------------------------------------------------------------