├── mostlyai └── sdk │ ├── py.typed │ ├── _data │ ├── db │ │ ├── __init__.py │ │ ├── sqlite.py │ │ ├── types_coercion.py │ │ ├── mysql.py │ │ ├── snowflake.py │ │ ├── postgresql.py │ │ └── mssql.py │ ├── file │ │ ├── __init__.py │ │ ├── table │ │ │ ├── __init__.py │ │ │ ├── feather.py │ │ │ ├── parquet.py │ │ │ ├── csv.py │ │ │ └── json.py │ │ ├── container │ │ │ ├── __init__.py │ │ │ ├── minio.py │ │ │ ├── gcs.py │ │ │ └── bucket_based.py │ │ └── utils.py │ ├── util │ │ ├── __init__.py │ │ └── kerberos.py │ ├── exceptions.py │ ├── __init__.py │ ├── pull_context.py │ ├── progress_callback.py │ ├── language_model.py │ ├── pull.py │ ├── auto_detect.py │ └── conversions.py │ ├── _local │ ├── __init__.py │ ├── execution │ │ ├── __init__.py │ │ ├── step_encode_training_data.py │ │ ├── migration.py │ │ ├── step_create_data_report.py │ │ ├── step_train_model.py │ │ ├── step_deliver_data.py │ │ ├── step_pull_training_data.py │ │ ├── step_generate_model_report_data.py │ │ └── step_analyze_training_data.py │ ├── cli.py │ └── server.py │ ├── client │ ├── __init__.py │ ├── exceptions.py │ ├── artifacts.py │ ├── _naming_conventions.py │ ├── _base_utils.py │ └── integrations.py │ └── __init__.py ├── docs ├── logo.png ├── favicon.png ├── index.md ├── TabularARGN-benchmark.png ├── tutorials │ ├── multi-table │ │ ├── berka-sqlite.db │ │ ├── berka-ui-1.png │ │ ├── berka-ui-2.png │ │ ├── berka-original.png │ │ ├── berka-configuration.png │ │ └── migrate-sqlite.py │ ├── rebalancing │ │ └── rebalancing.png │ ├── fake-or-real │ │ └── fake-or-real.png │ ├── train-synthetic-test-real │ │ └── TSTR.png │ ├── size-vs-accuracy │ │ └── size-vs-accuracy.png │ ├── star-schema-correlations │ │ └── baseball_table_relationships.png │ └── quality-assurance │ │ └── quality-assurance.ipynb ├── api_client.md └── api_domain.md ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.yml │ └── bug_report.yml ├── pull_request_template.md ├── workflows │ ├── pre-commit-check.yaml │ ├── run-tests-gpu.yaml │ ├── workflow.yaml │ ├── build-docker-image.yaml │ └── run-tests-cpu.yaml └── changelog_config.json ├── .gitignore ├── tests ├── __init__.py ├── _data │ ├── __init__.py │ └── unit │ │ ├── __init__.py │ │ ├── db │ │ ├── __init__.py │ │ ├── test_db.py │ │ └── test_types_coercion.py │ │ ├── file │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_feather.py │ │ ├── test_json.py │ │ └── test_utils.py │ │ └── util │ │ ├── __init__.py │ │ └── test_kerberos.py ├── _local │ ├── __init__.py │ ├── unit │ │ ├── __init__.py │ │ ├── test_server.py │ │ └── test_migration.py │ └── end_to_end │ │ └── __init__.py └── client │ ├── __init__.py │ └── unit │ ├── __init__.py │ ├── test_naming_conventions.py │ └── test_base.py ├── .devcontainer ├── setup.sh ├── devcontainer.json └── local │ └── devcontainer.json ├── tools ├── README.md ├── docker_entrypoint.py └── extend_model.py ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── mkdocs.yml └── Dockerfile /mostlyai/sdk/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/logo.png -------------------------------------------------------------------------------- /docs/favicon.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/favicon.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | 6 | --8<-- "README.md" 7 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @mostly-ai/mostly-developers 2 | /.github/workflows/* @mostly-ai/mostly-devops 3 | -------------------------------------------------------------------------------- /docs/TabularARGN-benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/TabularARGN-benchmark.png -------------------------------------------------------------------------------- /docs/tutorials/multi-table/berka-sqlite.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/multi-table/berka-sqlite.db -------------------------------------------------------------------------------- /docs/tutorials/multi-table/berka-ui-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/multi-table/berka-ui-1.png -------------------------------------------------------------------------------- /docs/tutorials/multi-table/berka-ui-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/multi-table/berka-ui-2.png -------------------------------------------------------------------------------- /docs/tutorials/rebalancing/rebalancing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/rebalancing/rebalancing.png -------------------------------------------------------------------------------- /docs/tutorials/fake-or-real/fake-or-real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/fake-or-real/fake-or-real.png -------------------------------------------------------------------------------- /docs/tutorials/multi-table/berka-original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/multi-table/berka-original.png -------------------------------------------------------------------------------- /docs/tutorials/train-synthetic-test-real/TSTR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/train-synthetic-test-real/TSTR.png -------------------------------------------------------------------------------- /docs/tutorials/multi-table/berka-configuration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/multi-table/berka-configuration.png -------------------------------------------------------------------------------- /docs/tutorials/size-vs-accuracy/size-vs-accuracy.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/size-vs-accuracy/size-vs-accuracy.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | dist/ 3 | LICENSE_HEADER 4 | site/ 5 | __pycache__/ 6 | .vscode/ 7 | .pytest_cache/ 8 | .ipynb_checkpoints 9 | .DS_Store 10 | .env 11 | -------------------------------------------------------------------------------- /docs/tutorials/star-schema-correlations/baseball_table_relationships.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostly-ai/mostlyai/HEAD/docs/tutorials/star-schema-correlations/baseball_table_relationships.png -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false # Disables the option to create blank issues 2 | 3 | contact_links: 4 | - name: Contact Support 5 | url: mailto:support@mostly.ai 6 | about: For any support-related queries, please email us directly at support@mostly.ai. 7 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Pull Request 2 | 3 | ## Changes 4 | 5 | Briefly describe your changes here. 6 | 7 | ## Why this change? 8 | 9 | Explain the reason for the change. 10 | 11 | ## Testing 12 | 13 | How was the change tested? 14 | 15 | ## Additional Notes 16 | 17 | Any additional information or context you want to provide? 18 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/_data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/_local/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/_data/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/_local/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/client/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/db/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/_data/unit/db/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/_data/unit/file/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/_data/unit/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/client/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /mostlyai/sdk/client/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/_local/end_to_end/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/table/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/container/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /docs/api_client.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | 6 | # API Reference 7 | 8 | ## MOSTLY AI Client 9 | 10 | ::: mostlyai.sdk.client.api.MostlyAI 11 | 12 | ## Generators 13 | 14 | ::: mostlyai.sdk.client.generators._MostlyGeneratorsClient 15 | 16 | ## Generator 17 | 18 | ::: mostlyai.sdk.domain.Generator 19 | 20 | ## Synthetic Datasets 21 | 22 | ::: mostlyai.sdk.client.synthetic_datasets._MostlySyntheticDatasetsClient 23 | 24 | ## Synthetic Dataset 25 | 26 | ::: mostlyai.sdk.domain.SyntheticDataset 27 | 28 | ## Connectors 29 | 30 | ::: mostlyai.sdk.client.connectors._MostlyConnectorsClient 31 | 32 | ## Connector 33 | 34 | ::: mostlyai.sdk.domain.Connector 35 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/exceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class MostlyDataException(Exception): 17 | pass 18 | -------------------------------------------------------------------------------- /docs/api_domain.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | 6 | # Schema References for `mostlyai.sdk.domain` 7 | 8 | This module is auto-generated and represents the `pydantic`-based classes of the schema defined in the [Public API](https://github.com/mostly-ai/mostly-openapi/blob/main/public-api.yaml).
9 | 10 | ::: mostlyai.sdk.domain 11 | options: 12 | show_root_heading: true 13 | show_root_full_path: true 14 | show_object_full_path: false 15 | show_root_toc_entry: false 16 | filters: 17 | - "!^Assistant.*" 18 | - "!^Share.*" 19 | - "!^ResourceShares" 20 | - "!^LiteLlm.*" 21 | - "!^DataLlm.*" 22 | - "!.*PatchConfig.*" 23 | - "!.*CloneConfig.*" 24 | - "!^UsageReport.*" 25 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-check.yaml: -------------------------------------------------------------------------------- 1 | name: 'mostlyai Pre-Commit Check' 2 | 3 | on: [workflow_call] 4 | 5 | env: 6 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 7 | FORCE_COLOR: '1' 8 | 9 | jobs: 10 | pre-commit-check: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 14 | - name: Set up Python 15 | uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 16 | with: 17 | python-version: '3.10' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install pre-commit 22 | pre-commit install 23 | - name: Run pre-commit 24 | run: pre-commit run --all-files 25 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from mostlyai.sdk._data.pull import pull 16 | from mostlyai.sdk._data.pull_context import pull_context 17 | 18 | __all__ = ["pull", "pull_context"] 19 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/container/minio.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from mostlyai.sdk._data.file.container.aws import AwsS3FileContainer 16 | 17 | 18 | class MinIOContainer(AwsS3FileContainer): 19 | pass 20 | -------------------------------------------------------------------------------- /.devcontainer/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Install uv package manager 3 | pip install uv 4 | 5 | # Set up virtual environment using uv 6 | uv venv 7 | 8 | # Conditionally add `--extra local` if SDK_MODE is "local" 9 | if [ "$SDK_MODE" == "local" ]; then 10 | echo "Running in LOCAL SDK mode" 11 | uv sync --extra local --extra dev --frozen 12 | else 13 | echo "Running in CLIENT SDK mode" 14 | uv sync --extra dev --frozen 15 | fi 16 | 17 | # Activate the virtual environment 18 | source .venv/bin/activate 19 | 20 | # Ensure pip and Jupyter (along with useful related packages) are installed and up-to-date 21 | uv pip install --upgrade --force-reinstall pip jupyter ipywidgets ipykernel jupyter_contrib_nbextensions 22 | 23 | # Register the Jupyter kernel explicitly 24 | python -m ipykernel install --user --name=python3 --display-name "Python 3 (Dev Container)" 25 | -------------------------------------------------------------------------------- /tests/_local/unit/test_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from mostlyai.sdk import MostlyAI 17 | from mostlyai.sdk.domain import AboutService 18 | 19 | 20 | def test_server(tmp_path): 21 | mostly = MostlyAI(local=True, local_dir=str(tmp_path), quiet=True) 22 | assert isinstance(mostly.about(), AboutService) 23 | -------------------------------------------------------------------------------- /.github/changelog_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "# What's Changed\n\n#{{CHANGELOG}}\n\n**Full Changelog**: [#{{FROM_TAG}}...#{{TO_TAG}}](#{{RELEASE_DIFF}})", 3 | "pr_template": "- #{{TITLE}} [##{{NUMBER}}](#{{URL}})", 4 | "empty_template": "# No Changes", 5 | "categories": [ 6 | { 7 | "title": "## 🚀 Features", 8 | "labels": ["feat"] 9 | }, 10 | { 11 | "title": "## 🐛 Fixes", 12 | "labels": ["fix"] 13 | }, 14 | { 15 | "title": "## 📦 Uncategorized", 16 | "labels": ["chore", "build", "docs", "refactor", "style"] 17 | } 18 | ], 19 | "ignore_labels": ["bump", "ci"], 20 | "label_extractor": [ 21 | { 22 | "pattern": "^([\\w-]+)(?:\\(([^)]+)\\))?: (.+)$", 23 | "target": "$1", 24 | "on_property": "title" 25 | } 26 | ], 27 | "transformers": [ 28 | { 29 | "pattern": "^(?:.*?:\\s+)?(.*)$", 30 | "target": "$1", 31 | "on_property": "title" 32 | } 33 | ] 34 | } 35 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Client Mode (light)", 3 | "image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye", 4 | "postCreateCommand": "/bin/bash -c 'source .devcontainer/setup.sh'", 5 | "containerEnv": { 6 | "UV_LINK_MODE": "copy", 7 | "SDK_MODE": "client" 8 | }, 9 | "extensions": [ 10 | "ms-python.python", 11 | "ms-python.vscode-pylance", 12 | "ms-toolsai.jupyter", 13 | "charliermarsh.ruff", 14 | "kevinrose.vsc-python-indent" 15 | ], 16 | "forwardPorts": [8000], 17 | "settings": { 18 | "default": true, 19 | "python.defaultInterpreterPath": "/workspaces/mostlyai/.venv/bin/python", 20 | "jupyter.defaultKernel": "Python 3 (Dev Container)", 21 | "jupyter.jupyterServerType": "local", 22 | "python.terminal.activateEnvironment": true 23 | }, 24 | "customizations": { 25 | "vscode": { 26 | "settings": { 27 | "workbench.editorAssociations": { 28 | "*.md": "vscode.markdown.preview.editor" 29 | } 30 | } 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Suggest a new feature or improvement 3 | title: "[FEATURE]: " 4 | labels: [enhancement] 5 | body: 6 | - type: markdown 7 | id: intro 8 | attributes: 9 | value: "Thanks for suggesting a feature! Please share the details below." 10 | - type: input 11 | id: feature_summary 12 | attributes: 13 | label: Feature Summary 14 | description: A clear and concise summary of the feature. 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: problem_solution 19 | attributes: 20 | label: Problem and Solution 21 | description: Explain the problem your feature aims to solve and how it does so. 22 | - type: textarea 23 | id: potential_alternatives 24 | attributes: 25 | label: Potential Alternatives 26 | description: Any alternative solutions or features you've considered. 
27 | - type: textarea 28 | id: additional_context 29 | attributes: 30 | label: Additional Context 31 | description: Add any other context about the feature request. 32 | -------------------------------------------------------------------------------- /.devcontainer/local/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Local Mode (full bundle)", 3 | "image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye", 4 | "hostRequirements": { 5 | "cpu": 8, 6 | "memory": "32gb" 7 | }, 8 | "postCreateCommand": "/bin/bash -c 'source .devcontainer/setup.sh'", 9 | "containerEnv": { 10 | "UV_LINK_MODE": "copy", 11 | "SDK_MODE": "local" 12 | }, 13 | "extensions": [ 14 | "ms-python.python", 15 | "ms-python.vscode-pylance", 16 | "ms-toolsai.jupyter", 17 | "charliermarsh.ruff", 18 | "kevinrose.vsc-python-indent" 19 | ], 20 | "forwardPorts": [8000], 21 | "settings": { 22 | "default": false, 23 | "python.defaultInterpreterPath": "/workspaces/mostlyai/.venv/bin/python", 24 | "jupyter.defaultKernel": "Python 3 (Dev Container)", 25 | "jupyter.jupyterServerType": "local", 26 | "python.terminal.activateEnvironment": true 27 | }, 28 | "customizations": { 29 | "vscode": { 30 | "settings": { 31 | "workbench.editorAssociations": { 32 | "*.md": "vscode.markdown.preview.editor" 33 | } 34 | } 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | # Tools 2 | 3 | ## model extension 4 | 5 | While `datamodel-codegen` is a very handy tool for converting OpenAPI definitions into 6 | pydantic objects (see `mostlyai/sdk/client/domain.py`), it lacks the ability to directly add extra functionality (e.g. methods) 7 | to the classes it creates. For example, a method like `Generator.add_table(...)` is out of its scope, but 8 | there is a trick used here: 9 | - `datamodel-codegen` works with `Jinja2` templates, and custom templates can be specified 10 | - That specific template is located in `custom_template/pydantic_v2/BaseModel.jinja2` 11 | - That template is generated from `tools/model.py`, which contains the functionality to add to the generated classes 12 | - Running `extend_model.py` updates the template based on `tools/model.py` 13 | 14 | tl;dr: the content of `tools/model.py` is stitched onto what `datamodel-codegen` generates natively. 15 | Moreover, all of that happens by simply running `make gen-public-model`. 16 | 17 | ## Updating public model (based on `public-api.yaml`) 18 | 19 | `make gen-public-model` performs all the actions required to rewrite and format `mostlyai/sdk/client/domain.py`. The relevant code sections from `tools/model.py` are stitched into the resulting `mostlyai/sdk/client/domain.py`. 20 | -------------------------------------------------------------------------------- /mostlyai/sdk/client/exceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from rich.console import Console 16 | from rich.panel import Panel 17 | from rich.text import Text 18 | 19 | 20 | class APIError(Exception): 21 | def __init__(self, message: str | None = None, do_rich_print: bool = True): 22 | super().__init__(message) 23 | self.message = message 24 | if do_rich_print: 25 | console = Console() 26 | error_message = Text(self.message, style="bold red") 27 | error_panel = Panel(error_message, expand=False) 28 | console.print(error_panel) 29 | 30 | def __str__(self): 31 | return self.message 32 | 33 | 34 | class APIStatusError(APIError): 35 | pass 36 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/step_encode_training_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from collections.abc import Callable 17 | from pathlib import Path 18 | 19 | from mostlyai.sdk._local.execution.migration import migrate_workspace 20 | 21 | 22 | def execute_step_encode_training_data( 23 | *, 24 | workspace_dir: Path, 25 | update_progress: Callable, 26 | ): 27 | # import ENGINE here to avoid premature loading of large ENGINE dependencies 28 | from mostlyai import engine 29 | 30 | # ensure backward compatibility 31 | migrate_workspace(workspace_dir) 32 | 33 | # call ENCODE 34 | engine.encode( 35 | workspace_dir=workspace_dir, 36 | update_progress=update_progress, 37 | ) 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report 3 | title: "[BUG]: " 4 | labels: [bug] 5 | body: 6 | - type: markdown 7 | id: intro 8 | attributes: 9 | value: "Thanks for reporting a bug! Please fill out the information below." 10 | - type: input 11 | id: bug_description 12 | attributes: 13 | label: Bug Description 14 | description: A clear and concise description of what the bug is. 15 | validations: 16 | required: true 17 | - type: input 18 | id: python_version 19 | attributes: 20 | label: Python Version 21 | description: The version of Python you're using. 22 | placeholder: "e.g., 3.8, 3.9, 3.10" 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: steps_to_reproduce 27 | attributes: 28 | label: Steps to Reproduce 29 | description: Steps to reproduce the behavior.
30 | validations: 31 | required: true 32 | - type: textarea 33 | id: expected_behavior 34 | attributes: 35 | label: Expected Behavior 36 | description: What you expected to happen. 37 | validations: 38 | required: true 39 | - type: textarea 40 | id: additional_context 41 | attributes: 42 | label: Additional Context 43 | description: Add any other context like OS, environment setup, etc. 44 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/table/feather.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pandas as pd 16 | import pyarrow.dataset as ds 17 | 18 | from mostlyai.sdk._data.file.base import FileContainer, FileDataTable, LocalFileContainer 19 | 20 | 21 | class FeatherDataTable(FileDataTable): 22 | DATA_TABLE_TYPE = "feather" 23 | IS_WRITE_APPEND_ALLOWED = False 24 | 25 | def __init__(self, *args, **kwargs): 26 | super().__init__(*args, **kwargs) 27 | 28 | @classmethod 29 | def container_class(cls) -> type["FileContainer"]: 30 | return LocalFileContainer 31 | 32 | def _get_dataset_format(self): 33 | return ds.FeatherFileFormat() 34 | 35 | def write_data(self, df: pd.DataFrame, if_exists: str = "replace", **kwargs): 36 | self.handle_if_exists(if_exists) # will gracefully handle append as replace 37 | df.to_feather( 38 | self.container.path_str, 39 | storage_options=self.container.storage_options, 40 | ) 41 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/migration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from pathlib import Path 16 | 17 | from mostlyai.engine._workspace import Workspace 18 | 19 | 20 | def migrate_workspace(workspace_dir: Path) -> None: 21 | workspace = Workspace(workspace_dir) 22 | # migrate min5/max5 in column stats to min/max (<= 4.5.6) 23 | for stats_pathdesc in [workspace.ctx_stats, workspace.tgt_stats]: 24 | stats = stats_pathdesc.read() 25 | if stats: 26 | for col, col_stats in stats.get("columns", {}).items(): 27 | if col_stats.get("min5") is not None: 28 | col_stats["min"] = min(col_stats["min5"]) if len(col_stats["min5"]) > 0 else None 29 | col_stats.pop("min5") 30 | if col_stats.get("max5") is not None: 31 | col_stats["max"] = max(col_stats["max5"]) if len(col_stats["max5"]) > 0 else None 32 | col_stats.pop("max5") 33 | stats_pathdesc.write(stats) 34 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/cli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import warnings 15 | from pathlib import Path 16 | 17 | import typer 18 | 19 | from mostlyai.sdk._local.execution.jobs import execute_generation_job, execute_training_job 20 | 21 | cli = typer.Typer(pretty_exceptions_enable=False) 22 | 23 | 24 | @cli.command() 25 | def run_training(generator_id: str, home_dir: Path): 26 | # suppress any deprecation warnings 27 | with warnings.catch_warnings(): 28 | warnings.simplefilter("ignore", DeprecationWarning) 29 | execute_training_job(generator_id, home_dir) 30 | 31 | 32 | @cli.command() 33 | def run_generation(synthetic_dataset_id: str, home_dir: Path): 34 | # suppress any deprecation warnings 35 | with warnings.catch_warnings(): 36 | warnings.simplefilter("ignore", DeprecationWarning) 37 | execute_generation_job(synthetic_dataset_id, home_dir) 38 | 39 | 40 | def run_cli(): 41 | cli(standalone_mode=False) 42 | 43 | 44 | if __name__ == "__main__": 45 | run_cli() 46 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/step_create_data_report.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from collections.abc import Callable 16 | from pathlib import Path 17 | 18 | from mostlyai.sdk._local.execution.step_create_model_report import create_report 19 | from mostlyai.sdk.domain import Generator, ModelType, StepCode 20 | 21 | 22 | def execute_step_create_data_report( 23 | *, 24 | generator: Generator, 25 | target_table_name: str, 26 | model_type: ModelType, 27 | workspace_dir: Path, 28 | report_credits: str = "", 29 | update_progress: Callable, 30 | ): 31 | # create model report and return metrics 32 | create_report( 33 | step_code=StepCode.create_data_report_tabular 34 | if model_type == ModelType.tabular 35 | else StepCode.create_data_report_language, 36 | generator=generator, 37 | workspace_dir=workspace_dir, 38 | model_type=model_type, 39 | target_table_name=target_table_name, 40 | report_credits=report_credits, 41 | update_progress=update_progress, 42 | ) 43 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-gpu.yaml: -------------------------------------------------------------------------------- 1 | name: '[GPU] mostlyai Tests' 2 | 3 | on: [workflow_call] 4 | 5 | env: 6 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 7 | FORCE_COLOR: '1' 8 | 9 | jobs: 10 | run-test-gpu: 11 | runs-on: gha-gpu-public 12 | container: 13 | image: nvidia/cuda:13.0.2-cudnn-runtime-ubuntu24.04 14 | options: --gpus all 15 | permissions: 16 | contents: read 17 | packages: write 18 | steps: 19 | - name: Setup | Install Git 20 | run: | 21 | export DEBIAN_FRONTEND=noninteractive 22 | ln -fs /usr/share/zoneinfo/Etc/UTC /etc/localtime 23 | echo "Etc/UTC" > /etc/timezone 24 | apt-get update -qq 25 | apt-get install -y --no-install-recommends git tzdata build-essential 26 | 27 | - name: Setup | Checkout 28 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 29 | with: 30 | fetch-depth: 1 31 | submodules: 'true' 32 | 33 | - name: Setup | uv 34 | uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v7.1.5 35 | with: 36 | enable-cache: false 37 | python-version: '3.10' 38 | 39 | - name: Setup | Dependencies 40 | run: | 41 | uv sync --frozen --only-group dev 42 | uv pip install ".[local-gpu]" 43 | 44 | - name: Setup | Check for available GPUs 45 | run: nvidia-smi 46 | 47 | - name: Test | End-to-End Tests 48 | # client mode e2e test will be skipped when no extra environment variable is provided 49 | run: | 50 | uv run --no-sync pytest -vv tests/_local/end_to_end 51 | -------------------------------------------------------------------------------- /.github/workflows/workflow.yaml: -------------------------------------------------------------------------------- 1 | name: 'mostlyai CI' 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened, synchronize, edited] 7 | 8 | jobs: 9 | pre-commit-check: 10 | if: | 11 | github.event_name == 'push' || 12 | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 13 | uses: ./.github/workflows/pre-commit-check.yaml 14 | secrets: inherit 15 | run-tests-cpu: 16 | if: | 17 | github.event_name == 'push' || 18 | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 19 | uses: ./.github/workflows/run-tests-cpu.yaml 20 | secrets: inherit 21 | run-tests-gpu: 22 | if: | 23 | ( 24 | github.event_name == 'push' || 25 | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 26 | ) && 27 | ( 28 |
github.ref == 'refs/heads/main' || 29 | startsWith(github.ref, 'refs/tags/') || 30 | contains(github.event.head_commit.message, '[gpu]') || 31 | contains(github.event.pull_request.title, '[gpu]') 32 | ) 33 | uses: ./.github/workflows/run-tests-gpu.yaml 34 | secrets: inherit 35 | build-docker-image: 36 | if: | 37 | ( 38 | github.event_name == 'push' || 39 | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) 40 | ) && 41 | ( 42 | github.ref == 'refs/heads/main' || 43 | startsWith(github.ref, 'refs/tags/') 44 | ) 45 | needs: [pre-commit-check, run-tests-cpu, run-tests-gpu] 46 | secrets: inherit 47 | uses: ./.github/workflows/build-docker-image.yaml 48 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/db/sqlite.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | import sqlalchemy as sa 18 | from sqlalchemy.dialects.sqlite.base import SQLiteDialect 19 | 20 | from mostlyai.sdk._data.db.base import DBDType, SqlAlchemyContainer, SqlAlchemyTable 21 | 22 | _LOG = logging.getLogger(__name__) 23 | 24 | 25 | class SqliteDType(DBDType): 26 | @classmethod 27 | def sa_dialect_class(cls): 28 | return SQLiteDialect 29 | 30 | 31 | class SqliteContainer(SqlAlchemyContainer): 32 | SCHEMES = ["sqlite"] 33 | SA_CONNECT_ARGS_ACCESS_ENGINE = {"timeout": 3} 34 | 35 | @property 36 | def sa_uri(self): 37 | return f"sqlite+pysqlite:///{self.dbname}" 38 | 39 | @classmethod 40 | def table_class(cls): 41 | return SqliteTable 42 | 43 | def _is_schema_exist(self): 44 | return True 45 | 46 | def does_database_exist(self) -> bool: 47 | return True 48 | 49 | 50 | class SqliteTable(SqlAlchemyTable): 51 | DATA_TABLE_TYPE = "sqlite" 52 | SA_RANDOM = sa.func.random() 53 | 54 | @classmethod 55 | def dtype_class(cls): 56 | return SqliteDType 57 | 58 | @classmethod 59 | def container_class(cls): 60 | return SqliteContainer 61 | -------------------------------------------------------------------------------- /tools/docker_entrypoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | def main() -> None: 17 | """ 18 | Entrypoint for the Synthetic Data SDK Docker image. 
19 | Can be called without any arguments, in which case it starts in local mode, running on port 8080. 20 | Alternatively, any arguments can be passed as key-value pairs, and they will be used when instantiating the MostlyAI class. 21 | """ 22 | from argparse import ArgumentParser 23 | from time import sleep 24 | 25 | parser = ArgumentParser(description="Synthetic Data SDK Docker Entrypoint") 26 | _, args = parser.parse_known_args() 27 | kwargs = {} 28 | for arg in args: 29 | if arg.startswith("--"): 30 | key, value = arg.lstrip("--").split("=", 1) 31 | kwargs[key] = value 32 | if len(kwargs) == 0: 33 | kwargs = {"local": True, "local_port": 8080} 34 | 35 | print("Startup may take a few seconds while libraries are being loaded...") 36 | 37 | from mostlyai.sdk import MostlyAI 38 | 39 | MostlyAI(**kwargs) 40 | 41 | try: 42 | while True: 43 | sleep(1) 44 | except KeyboardInterrupt: 45 | print("Shutting down...") 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: local 4 | hooks: 5 | - id: generate-license-header 6 | name: Generate temporary license header file 7 | entry: | 8 | bash -c ' 9 | HEADER_CONTENT="Copyright 2024 MOSTLY AI\n\ 10 | \n\ 11 | Licensed under the Apache License, Version 2.0 (the \"License\");\n\ 12 | you may not use this file except in compliance with the License.\n\ 13 | You may obtain a copy of the License at\n\ 14 | \n\ 15 | http://www.apache.org/licenses/LICENSE-2.0\n\ 16 | \n\ 17 | Unless required by applicable law or agreed to in writing, software\n\ 18 | distributed under the License is distributed on an \"AS IS\" BASIS,\n\ 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n\ 20 | See the License for the specific language governing permissions and\n\ 21 | limitations under the License."
22 | 
23 |           echo -e "$HEADER_CONTENT" > LICENSE_HEADER
24 |           '
25 |         language: system
26 |   - repo: https://github.com/Lucas-C/pre-commit-hooks
27 |     rev: v1.5.5
28 |     hooks:
29 |       - id: insert-license
30 |         files: \.py$
31 |         args:
32 |           # - --remove-header
33 |           - --license-filepath
34 |           - LICENSE_HEADER
35 |           - --use-current-year
36 |   - repo: https://github.com/pre-commit/pre-commit-hooks
37 |     rev: v5.0.0
38 |     hooks:
39 |       - id: end-of-file-fixer
40 |       - id: trailing-whitespace
41 |       - id: check-json
42 |       - id: mixed-line-ending
43 |         args: [--fix=lf]
44 |   - repo: https://github.com/asottile/pyupgrade
45 |     rev: v3.19.1
46 |     hooks:
47 |       - id: pyupgrade
48 |         args: [--py310-plus]
49 |   - repo: https://github.com/astral-sh/ruff-pre-commit
50 |     rev: v0.11.6
51 |     hooks:
52 |       - id: ruff
53 |         args: [--fix]
54 |       - id: ruff-format
55 | 
-------------------------------------------------------------------------------- /.github/workflows/build-docker-image.yaml: --------------------------------------------------------------------------------
1 | name: Build mostlyai Docker Image
2 | 
3 | on:
4 |   workflow_dispatch:
5 |   workflow_call:
6 | 
7 | env:
8 |   PLATFORMS: linux/amd64
9 | 
10 | jobs:
11 |   build:
12 |     runs-on: ubuntu-latest
13 |     permissions:
14 |       contents: read
15 |       packages: write
16 |       id-token: write
17 |     steps:
18 |       - name: Setup | Checkout
19 |         uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
20 |         with: {fetch-depth: 1}
21 | 
22 |       - name: Setup | Docker Buildx
23 |         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
24 |         with:
25 |           cache-binary: false
26 |           cleanup: false
27 |           platforms: ${{ env.PLATFORMS }}
28 | 
29 |       - name: Setup | Docker Build Metadata
30 |         uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5.10.0
31 |         id: meta
32 |         with:
33 |           images: ghcr.io/mostly-ai/sdk
34 | 
35 |       - name: Setup | Docker Login
36 |         uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
37 |         with:
38 |           registry: ghcr.io
39 |           username: ${{ github.actor }}
40 |           password: ${{ secrets.GITHUB_TOKEN }}
41 | 
42 |       - name: Build | Build Docker Image
43 |         uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0
44 |         env:
45 |           BUILDX_NO_DEFAULT_ATTESTATIONS: '1'
46 |           DOCKER_BUILD_SUMMARY: 'false'
47 |           DOCKER_BUILD_CHECKS_ANNOTATIONS: 'false'
48 |           DOCKER_BUILD_RECORD_UPLOAD: 'false'
49 |         with:
50 |           platforms: ${{ env.PLATFORMS }}
51 |           outputs: type=registry,compression=zstd,force-compression=true,oci-mediatypes=true,compression-level=9
52 |           tags: ${{ steps.meta.outputs.tags }}
53 |           labels: ${{ steps.meta.outputs.labels }}
54 |           push: true
55 | 
-------------------------------------------------------------------------------- /tests/_data/unit/file/conftest.py: --------------------------------------------------------------------------------
1 | # Copyright 2025 MOSTLY AI
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
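# Shared fixtures for the file-table tests: `sample_csv_file` writes a small
# mixed-dtype frame (bools, ints, floats, dates, timestamps, unicode text,
# bigints, each with missing values) to CSV; `sample_parquet_file` re-reads
# it with a parametrized pandas dtype backend and stores it as Parquet.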
14 | 15 | import numpy as np 16 | import pandas as pd 17 | import pytest 18 | 19 | 20 | @pytest.fixture 21 | def sample_csv_file(tmp_path): 22 | df = pd.DataFrame( 23 | { 24 | "id": [1, 2, 3, 4], 25 | "bool": [True, False, pd.NA, np.nan], 26 | "int": [1212, 2512, pd.NA, np.nan], 27 | "float": [0.1, -1.0, pd.NA, np.nan], 28 | "date": ["2020-04-04", "1877-05-05", pd.NA, np.nan], 29 | "ts_s": ["2020-04-04 14:14:14", "1877-05-05 01:01:01", pd.NA, np.nan], 30 | "ts_ns": ["2020-04-04 14:14:14.44", "1877-05-05 01:01:01.11", pd.NA, np.nan], 31 | "ts_tz": ["1999-09-15T09:37:50.871127Z", "1998-02-19T08:12:02.573302Z", pd.NA, np.nan], 32 | "text": ['This is a "quoted" text', "Row 日本", pd.NA, np.nan], 33 | "bigint": [1372636854620000520, 1372637091620000337, pd.NA, np.nan], 34 | } 35 | ) 36 | fn = tmp_path / "sample.csv" 37 | df.to_csv(fn, index=False) 38 | return fn 39 | 40 | 41 | @pytest.fixture 42 | def sample_parquet_file(tmp_path, sample_csv_file, request): 43 | df = pd.read_csv(sample_csv_file, engine="pyarrow", dtype_backend=request.param) 44 | fn = tmp_path / f"sample_{request.param}.parquet" 45 | df.to_parquet(fn, engine="pyarrow") 46 | return fn 47 | -------------------------------------------------------------------------------- /tests/_data/unit/file/test_feather.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
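# Round-trip test for FeatherDataTable: writes the sample frame to Feather,
# reads it back, and asserts that shape and column dtypes survive intact.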
14 | 15 | 16 | import pandas as pd 17 | import pytest 18 | 19 | from mostlyai.sdk._data.dtype import ( 20 | is_boolean_dtype, 21 | is_date_dtype, 22 | is_float_dtype, 23 | is_integer_dtype, 24 | is_string_dtype, 25 | is_timestamp_dtype, 26 | ) 27 | from mostlyai.sdk._data.file.table.feather import FeatherDataTable 28 | 29 | 30 | @pytest.mark.parametrize("sample_parquet_file", ["pyarrow"], indirect=True) 31 | def test_read_write_data(tmp_path, sample_parquet_file): 32 | # write data 33 | sample_data = pd.read_parquet(sample_parquet_file, dtype_backend="pyarrow") 34 | table1 = FeatherDataTable(path=tmp_path / "sample.feather", is_output=True) 35 | table1.write_data(sample_data) 36 | # read data 37 | table2 = FeatherDataTable(path=tmp_path / "sample.feather") 38 | data = table2.read_data() 39 | # compare data 40 | assert data.shape == sample_data.shape 41 | assert is_integer_dtype(data["id"]) 42 | assert is_boolean_dtype(data["bool"]) 43 | assert is_integer_dtype(data["int"]) 44 | assert is_float_dtype(data["float"]) 45 | assert is_date_dtype(data["date"]) 46 | assert is_timestamp_dtype(data["ts_s"]) 47 | assert is_timestamp_dtype(data["ts_ns"]) 48 | assert is_timestamp_dtype(data["ts_tz"]) 49 | assert is_string_dtype(data["text"]) 50 | -------------------------------------------------------------------------------- /mostlyai/sdk/client/artifacts.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from mostlyai.sdk.client.base import ( 17 | GET, 18 | PATCH, 19 | _MostlyBaseClient, 20 | ) 21 | from mostlyai.sdk.domain import ( 22 | Artifact, 23 | ArtifactPatchConfig, 24 | ) 25 | 26 | 27 | class _MostlyArtifactsClient(_MostlyBaseClient): 28 | SECTION = ["artifacts"] 29 | 30 | def get(self, artifact_id: str) -> Artifact: 31 | """ 32 | Retrieve artifact metadata including the shareable URL where the artifact can be viewed. 33 | Unauthenticated access is allowed to enable public sharing of artifacts. 34 | 35 | Args: 36 | artifact_id: The unique identifier of the artifact. 37 | 38 | Returns: 39 | The retrieved Artifact object. 
40 | 41 | Example for retrieving an artifact: 42 | ```python 43 | from mostlyai.sdk import MostlyAI 44 | mostly = MostlyAI() 45 | art = mostly.artifacts.get("INSERT_YOUR_ARTIFACT_ID") 46 | art 47 | ``` 48 | """ 49 | response = self.request(verb=GET, path=[artifact_id], response_type=Artifact) 50 | return response 51 | 52 | def _update( 53 | self, 54 | artifact_id: str, 55 | config: ArtifactPatchConfig, 56 | ) -> Artifact: 57 | response = self.request( 58 | verb=PATCH, 59 | path=[artifact_id], 60 | json=config, 61 | exclude_none_in_json=True, 62 | response_type=Artifact, 63 | ) 64 | return response 65 | -------------------------------------------------------------------------------- /mostlyai/sdk/client/_naming_conventions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | from collections.abc import Callable 17 | 18 | 19 | def _snake_to_camel(snake_str: str) -> str: 20 | components = snake_str.split("_") 21 | return components[0] + "".join(x.title() for x in components[1:]) 22 | 23 | 24 | def _camel_to_snake(camel_str: str) -> str: 25 | # handle acronyms by treating consecutive uppercase letters as a group 26 | s1 = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", camel_str) # e.g., "GPUTime" -> "GPU_Time" 27 | s2 = re.sub(r"([a-z\d])([A-Z])", r"\1_\2", s1) # e.g., "virtualGPU" -> "virtual_GPU" 28 | return s2.lower() 29 | 30 | 31 | def _convert_case(input_data: dict, conv_func: Callable[[str], str]) -> dict: 32 | if not isinstance(input_data, dict): 33 | return input_data 34 | 35 | new_dict = {} 36 | for key, value in input_data.items(): 37 | new_key = conv_func(key) 38 | # recursively convert nested dictionaries or lists 39 | if isinstance(value, dict): 40 | new_dict[new_key] = _convert_case(value, conv_func) 41 | elif isinstance(value, list): 42 | new_dict[new_key] = [_convert_case(item, conv_func) if isinstance(item, dict) else item for item in value] 43 | else: 44 | new_dict[new_key] = value 45 | return new_dict 46 | 47 | 48 | def map_snake_to_camel_case(input_dict: dict) -> dict: 49 | return _convert_case(input_dict, _snake_to_camel) 50 | 51 | 52 | def map_camel_to_snake_case(input_dict: dict) -> dict: 53 | return _convert_case(input_dict, _camel_to_snake) 54 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Synthetic Data SDK 2 | 3 | Thanks for your interest in contributing to Synthetic Data SDK! Follow these guidelines to set up your environment and streamline your contributions. 4 | 5 | ## Setup 6 | 7 | 1. 
**Clone the repository**:
8 |    ```bash
9 |    git clone https://github.com/mostly-ai/mostlyai.git
10 |    cd mostlyai
11 |    ```
12 |    If you don’t have direct write access to `mostlyai`, fork the repository first and clone your fork:
13 |    ```bash
14 |    git clone https://github.com/<your-username>/mostlyai.git
15 |    cd mostlyai
16 |    ```
17 | 
18 | 2. **Install `uv` (if not installed already)**:
19 |    ```bash
20 |    curl -LsSf https://astral.sh/uv/install.sh | sh
21 |    ```
22 |    For alternative installation methods, visit the [uv installation guide](https://docs.astral.sh/uv/getting-started/installation/).
23 | 
24 | 3. **Create a virtual environment and install dependencies**:
25 |    ```bash
26 |    uv sync --frozen --extra local --python=3.10 # For CPU
27 |    source .venv/bin/activate
28 |    ```
29 |    If using GPU, run:
30 |    ```bash
31 |    uv sync --frozen --extra local-gpu --python=3.10 # For GPU support
32 |    source .venv/bin/activate
33 |    ```
34 | 
35 | 4. **Install pre-commit hooks**:
36 |    ```bash
37 |    pre-commit install
38 |    ```
39 | 
40 | ## Development Workflow
41 | 
42 | 1. **Ensure your local `main` branch is up to date**:
43 |    ```bash
44 |    git checkout main
45 |    git reset --hard origin/main
46 |    git pull origin main
47 |    ```
48 | 
49 | 2. **Create a new feature or bugfix branch**:
50 |    ```bash
51 |    git checkout -b my-feature-branch
52 |    ```
53 | 
54 | 3. **Implement your changes.**
55 | 
56 | 4. **Run tests and pre-commit hooks**:
57 |    ```bash
58 |    pytest
59 |    pre-commit run
60 |    ```
61 | 
62 | 5. **Commit your changes with a descriptive message**:
63 |    ```bash
64 |    git add .
65 |    git commit -m "feat: add a clear description of your feature"
66 |    ```
67 |    Follow the [Conventional Commits](https://gist.github.com/qoomon/5dfcdf8eec66a051ecd85625518cfd13) format.
68 | 
69 | 6. **Push your changes**:
70 |    ```bash
71 |    git push origin my-feature-branch
72 |    ```
73 | 
74 | 7. 
**Open a pull request on GitHub.** 75 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: "mostlyai" 2 | site_url: "https://mostly-ai.github.io/mostlyai/" 3 | repo_url: "https://github.com/mostly-ai/mostlyai" 4 | repo_name: "mostly-ai/mostlyai" 5 | 6 | theme: 7 | name: material 8 | logo: logo.png 9 | favicon: favicon.png 10 | font: 11 | text: Lato 12 | features: 13 | - navigation.top 14 | - navigation.tracking 15 | - navigation.tabs 16 | - navigation.tabs.sticky 17 | - content.code.select 18 | - content.code.copy 19 | - navigation.footer 20 | 21 | palette: 22 | - scheme: default 23 | toggle: 24 | icon: material/brightness-7 25 | name: Switch to dark mode 26 | - scheme: slate 27 | toggle: 28 | icon: material/brightness-2 29 | name: Switch to light mode 30 | 31 | nav: 32 | - Getting Started: index.md 33 | - Usage Examples: usage.md 34 | - API Reference: api_client.md 35 | - Schema Reference: api_domain.md 36 | - Cheat Sheet: syntax.md 37 | - Tutorials: tutorials.md 38 | 39 | plugins: 40 | - search 41 | - mkdocstrings: 42 | handlers: 43 | python: 44 | options: 45 | heading_level: 3 46 | show_root_toc_entry: false 47 | show_root_heading: false 48 | show_object_full_path: true 49 | show_bases: false 50 | show_docstring: true 51 | show_source: false 52 | show_signature: true 53 | separate_signature: true 54 | show_docstring_examples: true 55 | docstring_section_style: table 56 | extensions: 57 | - griffe_fieldz 58 | docstring_style: google 59 | - llmstxt: 60 | full_output: llms-full.txt 61 | sections: 62 | Getting started: 63 | - index.md 64 | Usage Examples: 65 | - usage.md 66 | API Reference: 67 | - api_client.md 68 | Schema Reference: 69 | - api_domain.md 70 | Cheat Sheet: 71 | - syntax.md 72 | 73 | markdown_extensions: 74 | - pymdownx.highlight: 75 | anchor_linenums: true 76 | line_spans: __span 77 | pygments_lang_class: true 78 | - pymdownx.inlinehilite 79 | - pymdownx.snippets 80 | - pymdownx.superfences 81 | - toc: 82 | permalink: true 83 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/table/parquet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
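# ParquetDataTable reads and writes Parquet via pyarrow datasets. Complex
# (nested) column types and pandas index columns are excluded from the
# reported columns, and appends are gracefully handled as full replacements.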
14 | 15 | import pandas as pd 16 | import pyarrow as pa 17 | import pyarrow.dataset as ds 18 | 19 | from mostlyai.sdk._data.file.base import FileContainer, FileDataTable, LocalFileContainer 20 | 21 | 22 | class ParquetDataTable(FileDataTable): 23 | DATA_TABLE_TYPE = "parquet" 24 | IS_WRITE_APPEND_ALLOWED = False 25 | 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | 29 | @classmethod 30 | def container_class(cls) -> type["FileContainer"]: 31 | return LocalFileContainer 32 | 33 | def _get_dataset_format(self): 34 | return ds.ParquetFileFormat() 35 | 36 | def get_columns(self, exclude_complex_types=True): 37 | idx_columns = ( 38 | self.dataset.schema.pandas_metadata.get("index_columns", []) 39 | if self.dataset.schema and isinstance(self.dataset.schema.pandas_metadata, dict) 40 | else [] 41 | ) 42 | excluded = (pa.ListType, pa.StructType, pa.MapType, pa.UnionType) if exclude_complex_types else () 43 | columns = [ 44 | c.name for c in self.dataset.schema if not isinstance(c.type, excluded) and (c.name not in idx_columns) 45 | ] 46 | return columns 47 | 48 | def _get_columns(self): 49 | return self.get_columns(exclude_complex_types=True) 50 | 51 | def write_data(self, df: pd.DataFrame, if_exists: str = "replace", **kwargs): 52 | self.handle_if_exists(if_exists) # will gracefully handle append as replace 53 | df.to_parquet( 54 | self.container.path_str, 55 | storage_options=self.container.storage_options, 56 | index=False, 57 | ) 58 | -------------------------------------------------------------------------------- /tests/_data/unit/db/test_db.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
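# SQLite-backed tests for SqlAlchemyTable basics: row counting, writing empty
# frames under different chunk sizes, and reading a table whose logical name
# differs from its physical name in the database (name_in_db).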
14 | 15 | import pandas as pd 16 | import pytest 17 | 18 | from mostlyai.sdk._data.db.sqlite import SqliteContainer, SqliteTable 19 | 20 | 21 | @pytest.fixture() 22 | def temp_table(tmp_path): 23 | container = SqliteContainer(dbname=str(tmp_path / "database.db")) 24 | return SqliteTable(name="data", container=container, is_output=True) 25 | 26 | 27 | def test_count_rows(temp_table): 28 | df = pd.DataFrame({"id": [1, 2, 3]}) 29 | temp_table.write_data(df, if_exists="replace") 30 | assert temp_table.row_count == df.shape[0] 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "write_chunk_size", 35 | [0, 100], 36 | ) 37 | def test_write_data_empty(temp_table, write_chunk_size): 38 | df = pd.DataFrame({"id": [], "col": []}) 39 | temp_table.WRITE_CHUNK_SIZE = write_chunk_size 40 | temp_table.write_data(df, if_exists="replace") 41 | df_read = temp_table.read_data() 42 | assert df_read.empty 43 | 44 | 45 | def test_name_in_db_differs_from_logical_name(temp_table): 46 | # create a table with actual name "users_v2" but logical name "users" 47 | temp_table.name = "users_v2" 48 | df = pd.DataFrame({"name": ["alice", "bob", "charlie"]}) 49 | temp_table.write_data(df, if_exists="replace") 50 | 51 | # now access with different logical name but same name_in_db 52 | table_with_alias = SqliteTable(name="users", container=temp_table.container, name_in_db="users_v2", is_output=False) 53 | df_read = table_with_alias.read_data() 54 | 55 | assert df_read.shape == df.shape 56 | pd.testing.assert_frame_equal( 57 | df_read.sort_values("name").reset_index(drop=True), 58 | df.sort_values("name").reset_index(drop=True), 59 | check_dtype=False, 60 | ) 61 | -------------------------------------------------------------------------------- /mostlyai/sdk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Synthetic Data SDK - Python toolkit for high-fidelity, privacy-safe synthetic data. 16 | 17 | The SDK supports two operating modes: 18 | - LOCAL mode: Train and generate synthetic data locally on your own compute resources 19 | - CLIENT mode: Connect to a remote MOSTLY AI platform for training & generation 20 | 21 | Key Resources 22 | ------------- 23 | The SDK manages four core resources: 24 | 25 | 1. **Generators** - Train synthetic data generators on tabular or language data 26 | 2. **Synthetic Datasets** - Create synthetic samples from trained generators 27 | 3. **Connectors** - Connect to data sources (databases, cloud storage) 28 | 4. 
**Datasets** - Create datasets with instructions (CLIENT mode only) 29 | 30 | Core Operations 31 | --------------- 32 | - `mostly.train(config)` - Train a generator on tabular or language data 33 | - `mostly.generate(g, config)` - Generate synthetic data records 34 | - `mostly.probe(g, config)` - Live probe the generator on demand 35 | - `mostly.connect(config)` - Connect to external data sources 36 | 37 | Key Features 38 | ------------ 39 | - Broad data support (mixed-type, single/multi-table, time-series) 40 | - Multiple model types (TabularARGN, Hugging Face LLMs, LSTM) 41 | - Advanced training options (GPU/CPU, Differential Privacy) 42 | - Automated quality assurance with metrics and HTML reports 43 | - Flexible sampling (up-sampling, conditional simulation, rebalancing) 44 | - Seamless integration with external data sources 45 | 46 | For more information, visit: https://mostly-ai.github.io/mostlyai/ 47 | """ 48 | 49 | from mostlyai.sdk.client.api import MostlyAI 50 | 51 | __all__ = ["MostlyAI"] 52 | __version__ = "5.9.2" # Do not set this manually. Use poetry version [params]. 53 | -------------------------------------------------------------------------------- /tests/_data/unit/db/test_types_coercion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
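# Exercises coerce_to_sql_dtype() against a range of postgresql column types,
# asserting that the right coercion function is picked and that values which
# cannot be coerced come back as missing.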
14 | 15 | import datetime 16 | 17 | import pandas as pd 18 | from sqlalchemy.dialects import postgresql 19 | 20 | from mostlyai.sdk._data.db.types_coercion import coerce_to_sql_dtype 21 | from mostlyai.sdk._data.dtype import BOOL, FLOAT64, INT64, STRING 22 | 23 | 24 | def test_coerce_to_sql_dtype(): 25 | # this test is primarily to check if appropriate function was selected for coercion 26 | 27 | def _assert_coerced(in_data, in_sql_dtype, out_data, out_pd_dtype): 28 | s = coerce_to_sql_dtype(pd.Series(in_data), in_sql_dtype) 29 | pd.testing.assert_series_equal(s, pd.Series(out_data, dtype=out_pd_dtype)) 30 | 31 | _assert_coerced( 32 | in_data=["a", True], 33 | in_sql_dtype=postgresql.BOOLEAN(), 34 | out_data=[pd.NA, True], 35 | out_pd_dtype=BOOL, 36 | ) 37 | _assert_coerced( 38 | in_data=["a", "2020-01-01"], 39 | in_sql_dtype=postgresql.DATE(), 40 | out_data=[pd.NA, "2020-01-01"], 41 | out_pd_dtype="datetime64[ns]", 42 | ) 43 | _assert_coerced( 44 | in_data=["a", 1.0], 45 | in_sql_dtype=postgresql.FLOAT(), 46 | out_data=[pd.NA, 1.0], 47 | out_pd_dtype=FLOAT64, 48 | ) 49 | _assert_coerced( 50 | in_data=["a", 1], 51 | in_sql_dtype=postgresql.INTEGER(), 52 | out_data=[pd.NA, 1], 53 | out_pd_dtype=INT64, 54 | ) 55 | _assert_coerced( 56 | in_data=["abcde", "abc"], 57 | in_sql_dtype=postgresql.VARCHAR(4), 58 | out_data=["abcd", "abc"], 59 | out_pd_dtype=STRING, 60 | ) 61 | _assert_coerced( 62 | in_data=["abcde", "10:20"], 63 | in_sql_dtype=postgresql.TIME(), 64 | out_data=[pd.NA, datetime.time(10, 20)], 65 | out_pd_dtype="object", 66 | ) 67 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/pull_context.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
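# pull_context materializes only the *context* tables needed for a target
# table: it validates the schema, resolves the context tables, prepares the
# workspace directory, and delegates to pull_split() with do_ctx_only=True.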
14 | 15 | import logging 16 | import time 17 | from pathlib import Path 18 | from typing import Any 19 | 20 | from mostlyai.sdk._data.base import Schema 21 | from mostlyai.sdk._data.progress_callback import ProgressCallbackWrapper 22 | from mostlyai.sdk._data.pull_utils import handle_workspace_dir, pull_split 23 | from mostlyai.sdk.domain import ModelType 24 | 25 | _LOG = logging.getLogger(__name__) 26 | 27 | 28 | def pull_context( 29 | *, 30 | tgt: str, 31 | schema: Schema, 32 | max_sample_size: int | None = None, 33 | model_type: str | ModelType = ModelType.tabular, 34 | workspace_dir: str | Path = "engine-ws", 35 | ): 36 | t0 = time.time() 37 | workspace_dir = Path(workspace_dir) 38 | model_type = ModelType(model_type) 39 | if tgt not in schema.tables: 40 | raise ValueError(f"table '{tgt}' not defined in schema") 41 | schema.preprocess_schema_before_pull() 42 | # gather context_tables 43 | context_tables = schema.get_context_tables(tgt) 44 | _LOG.info(f"context_tables (size: {len(context_tables)}): {context_tables}") 45 | # handle workspace_dir 46 | workspace_dir = handle_workspace_dir(workspace_dir=workspace_dir) 47 | # ensure that max_sample_size is a positive integer, if given 48 | if max_sample_size is not None: 49 | max_sample_size = max(1, max_sample_size) 50 | # log arguments 51 | _LOG.info(f"tgt: {tgt}") 52 | _LOG.info(f"model_type: {model_type}") 53 | _LOG.info(f"max_sample_size: {max_sample_size}") 54 | 55 | def update_progress(*args: Any, **kwargs: Any) -> None: ... 56 | 57 | pull_split( 58 | tgt=tgt, 59 | schema=schema, 60 | trn_val_split=None, 61 | model_type=model_type, 62 | do_ctx_only=True, 63 | workspace_dir=workspace_dir, 64 | progress=ProgressCallbackWrapper(update_progress=update_progress), 65 | ) 66 | 67 | _LOG.info(f"pull_context total time: {time.time() - t0:.2f}s") 68 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/progress_callback.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Callable 16 | from functools import partial 17 | from typing import Protocol 18 | 19 | from rich.progress import Progress 20 | 21 | 22 | class ProgressCallback(Protocol): 23 | def __call__( 24 | self, 25 | total: int | None = None, 26 | completed: int | None = None, 27 | advance: int | None = None, 28 | **kwargs, 29 | ) -> None: ... 
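# A minimal usage sketch (illustrative only; the callback name `print_progress`
# is hypothetical, everything else matches the types defined in this module):
#
#     def print_progress(total=None, completed=None, advance=None, **kwargs):
#         print(f"progress: completed={completed} advance={advance} total={total}")
#
#     with ProgressCallbackWrapper(print_progress) as progress:  # defined below
#         progress.update(total=100)
#         for _ in range(10):
#             progress.update(advance=10)
#
# When no callback is passed, the wrapper falls back to a rich progress bar
# (task kwargs, e.g. `description`, are forwarded to `add_task`).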
30 | 31 | 32 | class ProgressCallbackWrapper: 33 | @staticmethod 34 | def _wrap_progress_callback( 35 | update_progress: ProgressCallback | None = None, **kwargs 36 | ) -> tuple[ProgressCallback, Callable]: 37 | if not update_progress: 38 | rich_progress = Progress() 39 | rich_progress.start() 40 | task_id = rich_progress.add_task(**kwargs) 41 | update_progress = partial(rich_progress.update, task_id=task_id) 42 | else: 43 | rich_progress = None 44 | 45 | def teardown_progress(): 46 | if rich_progress: 47 | rich_progress.refresh() 48 | rich_progress.stop() 49 | 50 | return update_progress, teardown_progress 51 | 52 | def update( 53 | self, 54 | total: int | None = None, 55 | completed: int | None = None, 56 | advance: int | None = None, 57 | **kwargs, 58 | ) -> None: 59 | self._update_progress(total=total, completed=completed, advance=advance, **kwargs) 60 | 61 | def __init__(self, update_progress: ProgressCallback | None = None, **kwargs): 62 | self._update_progress, self._teardown_progress = self._wrap_progress_callback(update_progress, **kwargs) 63 | 64 | def __enter__(self): 65 | self._update_progress(completed=0, total=1) 66 | return self 67 | 68 | def __exit__(self, exc_type, exc_value, traceback): 69 | if exc_type is None: 70 | self._update_progress(completed=1, total=1) 71 | self._teardown_progress() 72 | -------------------------------------------------------------------------------- /docs/tutorials/multi-table/migrate-sqlite.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # helper script to convert CSV files into a single SQLite database file 16 | 17 | from pathlib import Path 18 | 19 | import pandas as pd 20 | import sqlalchemy as sa 21 | 22 | 23 | def create_table_sql(df, tbl_name): 24 | cols = [] 25 | pks = [] 26 | fks = [] 27 | for c in df.columns: 28 | if c == f"{tbl_name}_id": 29 | pks.append(f'PRIMARY KEY("{c}")') 30 | elif c.endswith("_id"): 31 | fks.append(f'FOREIGN KEY("{c}") REFERENCES "{c[:-3]}"("{c}")') 32 | dtype = df[c].dtype 33 | if pd.api.types.is_integer_dtype(dtype): 34 | ctype = "BIGINT" 35 | elif pd.api.types.is_float_dtype(dtype): 36 | ctype = "FLOAT" 37 | elif pd.api.types.is_datetime64_any_dtype(dtype): 38 | ctype = "DATETIME" 39 | else: 40 | ctype = "TEXT" 41 | cols.append(f'"{c}" {ctype}') 42 | stmt = f'CREATE TABLE "{tbl_name}" ({", ".join(cols + pks + fks)})' 43 | return stmt 44 | 45 | 46 | engine = sa.create_engine("sqlite+pysqlite:///berka-sqlite.db", echo=False) 47 | 48 | data = {} 49 | for fn in Path(".").glob("*.csv"): 50 | df = pd.read_csv(fn) 51 | # convert dtypes 52 | for col in df.columns: 53 | if col in ["date", "issued"]: 54 | df[col] = pd.to_datetime(df[col]) 55 | if col.endswith("_id"): 56 | df[col] = df[col].astype(str) 57 | # get filename w/o extension 58 | tbl_name = fn.stem 59 | data[tbl_name] = df 60 | 61 | with engine.connect() as conn: 62 | for tbl_name, df in data.items(): 63 | # create table 64 | stmt = create_table_sql(df, tbl_name) 65 | conn.execute(sa.text(stmt)) 66 | print(f"created table {tbl_name}") 67 | conn.commit() 68 | conn.close() 69 | 70 | with engine.connect() as conn: 71 | for tbl_name, df in data.items(): 72 | # insert records 73 | df.to_sql(tbl_name, conn, index=False, if_exists="append") 74 | print(f"loaded data to {tbl_name}") 75 | conn.commit() 76 | conn.close() 77 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/util/kerberos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
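# Parses `klist` output to decide whether a Kerberos ticket for a given
# service principal is still valid. Both common header variants
# ("Issued  Expires  Principal" and "Valid starting  Expires  Service principal")
# are recognized, and a ticket only counts as alive if it expires more than
# 60 seconds from now.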
14 | 15 | import re 16 | from datetime import datetime 17 | from io import StringIO 18 | 19 | import pandas as pd 20 | from dateutil.parser import parse 21 | 22 | 23 | def is_kerberos_ticket_alive(klist_result: str, service_principal: str): 24 | # Extract the relevant part of the data that contains the tickets 25 | tickets_section_match = re.search( 26 | r"((Issued\s+Expires\s+Principal)|(Valid starting\s+Expires\s+Service principal))\n(.*)", 27 | klist_result, 28 | re.DOTALL, 29 | ) 30 | if not tickets_section_match: 31 | return False 32 | 33 | tickets_data = tickets_section_match.group(4) 34 | 35 | # Read the tickets data using pandas 36 | try: 37 | df = pd.read_csv( 38 | StringIO(tickets_data), 39 | sep=r"\s+", 40 | header=None, 41 | dtype="str", 42 | ) 43 | df.columns = [str(c) for c in df.columns] 44 | except pd.errors.EmptyDataError: 45 | return False 46 | 47 | # Convert dates to datetime objects using dateutil.parser.parse 48 | def safe_parse(x): 49 | try: 50 | return parse(x, fuzzy=True) 51 | except (ValueError, TypeError): 52 | return None 53 | 54 | # equally split the columns into issued and expires except for last column which is principal 55 | df.columns.values[-1] = "principal" 56 | n_date_columns = (df.shape[1] - 1) // 2 57 | issued = df.iloc[:, :n_date_columns].apply(lambda x: " ".join(x.values), axis=1) 58 | expires = df.iloc[:, n_date_columns:-1].apply(lambda x: " ".join(x.values), axis=1) 59 | 60 | df["issued"] = issued.apply(safe_parse) 61 | df["expires"] = expires.apply(safe_parse) 62 | 63 | # Drop rows where date parsing failed 64 | df.dropna(subset=["issued", "expires"], inplace=True) 65 | 66 | # Check if the principal's ticket has not expired 67 | now = datetime.now() 68 | 69 | # Check if the ticket for a specific service principal is alive 70 | is_alive = any((df["principal"] == service_principal) & (df["expires"] > now + pd.Timedelta(seconds=60))) 71 | 72 | return is_alive 73 | -------------------------------------------------------------------------------- /tools/extend_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
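# Syncs hand-written class bodies from tools/model.py into the custom
# pydantic v2 BaseModel Jinja template: existing `{%- if class_name == ... %}`
# blocks are replaced in place, new ones are appended, and lines between
# `# skip` / `# /skip` markers are omitted from the extracted content.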
14 | 15 | import ast 16 | import re 17 | 18 | 19 | def extract_class_contents(filename): 20 | with open(filename) as file: 21 | source = file.read() 22 | 23 | parsed_source = ast.parse(source) 24 | classes = [node for node in parsed_source.body if isinstance(node, ast.ClassDef)] 25 | 26 | class_contents = {} 27 | 28 | for class_node in classes: 29 | start_line = class_node.lineno 30 | end_line = class_node.end_lineno 31 | 32 | # Extract lines and handle skipping 33 | class_lines = [] 34 | skip = False 35 | for line in source.splitlines()[start_line:end_line]: 36 | if line.strip() == "# skip": 37 | skip = True 38 | continue 39 | if line.strip() == "# /skip": 40 | skip = False 41 | continue 42 | if not skip: 43 | class_lines.append(line) 44 | 45 | class_contents[class_node.name] = "\n".join(class_lines) 46 | 47 | return class_contents 48 | 49 | 50 | def append_or_replace_in_jinja_template(template_filename, classes_content): 51 | with open(template_filename) as file: 52 | template_content = file.read() 53 | 54 | for class_name, content in classes_content.items(): 55 | class_block_pattern = r"{%- if class_name == \"" + re.escape(class_name) + r"\" %}.*?{%- endif %}" 56 | new_block = '{%- if class_name == "' + class_name + '" %}\n' + content + "\n{%- endif %}" 57 | 58 | # Check if class block exists 59 | if re.search(class_block_pattern, template_content, re.DOTALL): 60 | # Replace existing block 61 | template_content = re.sub(class_block_pattern, new_block, template_content, flags=re.DOTALL) 62 | else: 63 | # Append new block 64 | template_content += new_block 65 | 66 | # Write back to the template file 67 | with open(template_filename, "w") as file: 68 | file.write(template_content) 69 | 70 | 71 | # Example Usage 72 | source_filename = "tools/model.py" 73 | template_filename = "tools/custom_template/pydantic_v2/BaseModel.jinja2" 74 | classes_content = extract_class_contents(source_filename) 75 | append_or_replace_in_jinja_template(template_filename, classes_content) 76 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-cpu.yaml: -------------------------------------------------------------------------------- 1 | name: '[CPU] mostlyai Tests' 2 | 3 | on: [workflow_call] 4 | 5 | env: 6 | PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring 7 | FORCE_COLOR: '1' 8 | 9 | jobs: 10 | run-test-cpu-local: 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: read 14 | packages: write 15 | steps: 16 | - name: Setup | Checkout 17 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 18 | with: 19 | fetch-depth: 1 20 | submodules: 'true' 21 | 22 | - name: Setup | uv 23 | uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v7.1.5 24 | with: 25 | enable-cache: false 26 | python-version: '3.10' 27 | 28 | - name: Setup | Dependencies 29 | run: | 30 | uv sync --frozen --only-group dev 31 | uv pip install --index-strategy unsafe-first-match torch==2.8.0+cpu torchvision==0.23.0+cpu ".[local]" --extra-index-url https://download.pytorch.org/whl/cpu 32 | 33 | - name: Test | End-to-End (Local mode only) 34 | run: | 35 | uv run --no-sync pytest -vv tests/_local/end_to_end -k 'not (client and mode)' 36 | 37 | - name: Test | Unit Tests 38 | run: | 39 | uv run --no-sync pytest -vv tests/client/unit 40 | uv run --no-sync pytest -vv tests/_data/unit 41 | uv run --no-sync pytest -vv tests/_local/unit 42 | uv run --no-sync pytest -vv tests/test_domain.py 43 | 44 | run-test-cpu-client: 45 | if: false 46 | runs-on: 
ubuntu-latest 47 | permissions: 48 | contents: read 49 | packages: write 50 | steps: 51 | - name: Setup | Checkout 52 | uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 53 | with: 54 | fetch-depth: 1 55 | submodules: 'true' 56 | 57 | - name: Setup | uv 58 | uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # v7.1.5 59 | with: 60 | enable-cache: false 61 | python-version: '3.10' 62 | 63 | - name: Setup | Dependencies 64 | run: | 65 | uv sync --frozen --only-group dev 66 | uv pip install --index-strategy unsafe-first-match torch==2.8.0+cpu torchvision==0.23.0+cpu ".[local]" --extra-index-url https://download.pytorch.org/whl/cpu 67 | 68 | - name: Test | End-to-End (Client mode only) 69 | env: 70 | MOSTLY_API_KEY: ${{ secrets.E2E_CLIENT_MOSTLY_API_KEY }} 71 | MOSTLY_BASE_URL: ${{ secrets.E2E_CLIENT_MOSTLY_BASE_URL }} 72 | E2E_CLIENT_S3_ACCESS_KEY: ${{ secrets.E2E_CLIENT_S3_ACCESS_KEY }} 73 | E2E_CLIENT_S3_SECRET_KEY: ${{ secrets.E2E_CLIENT_S3_SECRET_KEY }} 74 | E2E_CLIENT_S3_BUCKET: ${{ secrets.E2E_CLIENT_S3_BUCKET }} 75 | run: | 76 | uv run --no-sync pytest -vv tests/_local/end_to_end -k 'client and mode' 77 | -------------------------------------------------------------------------------- /tests/_data/unit/file/test_json.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
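# Round-trip tests for JsonDataTable over plain and gzip-compressed JSON:
# checks shape, chunked reads, the inferred column dtypes, and the dtypes
# reported through the table's metadata.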
14 | 15 | 16 | import pandas as pd 17 | import pyarrow as pa 18 | import pytest 19 | 20 | from mostlyai.sdk._data.dtype import ( 21 | is_boolean_dtype, 22 | is_float_dtype, 23 | is_integer_dtype, 24 | is_string_dtype, 25 | is_timestamp_dtype, 26 | ) 27 | from mostlyai.sdk._data.file.table.json import JsonDataTable 28 | 29 | 30 | @pytest.mark.parametrize("file_name", ["sample.json", "sample.json.gz"]) 31 | @pytest.mark.parametrize("sample_parquet_file", ["pyarrow"], indirect=True) 32 | def test_read_write_data(tmp_path, sample_parquet_file, file_name): 33 | sample_data = pd.read_parquet(sample_parquet_file, dtype_backend="pyarrow") 34 | sample_data["date"] = sample_data["date"].astype(pd.ArrowDtype(pa.date32())) 35 | # write 36 | table1 = JsonDataTable(path=tmp_path / file_name, is_output=True) 37 | table1.write_data(sample_data) 38 | # read 39 | table2 = JsonDataTable(path=tmp_path / file_name, is_output=False) 40 | data = table2.read_data() 41 | chunk_df = next(table2.read_chunks()) # assuming a single chunk, due to a small size 42 | # compare data 43 | assert data.shape == sample_data.shape 44 | assert chunk_df.shape == sample_data.shape 45 | # check dtypes from data 46 | assert is_integer_dtype(data["id"]) 47 | assert is_boolean_dtype(data["bool"]) 48 | assert is_integer_dtype(data["int"]) 49 | assert is_float_dtype(data["float"]) 50 | assert is_timestamp_dtype(data["date"]) 51 | assert is_timestamp_dtype(data["ts_s"]) 52 | assert is_timestamp_dtype(data["ts_ns"]) 53 | assert is_timestamp_dtype(data["ts_tz"]) 54 | assert is_string_dtype(data["text"]) 55 | # check dtypes from meta-data 56 | dtypes = table2.dtypes 57 | assert pd.api.types.is_bool_dtype(dtypes["bool"].wrapped) 58 | assert pd.api.types.is_integer_dtype(dtypes["int"].wrapped) 59 | assert pd.api.types.is_float_dtype(dtypes["float"].wrapped) 60 | assert pd.api.types.is_datetime64_any_dtype(dtypes["date"].wrapped) 61 | assert pd.api.types.is_datetime64_any_dtype(dtypes["ts_s"].wrapped) 62 | assert pd.api.types.is_datetime64_any_dtype(dtypes["ts_ns"].wrapped) 63 | assert pd.api.types.is_datetime64_any_dtype(dtypes["ts_tz"].wrapped) 64 | assert pd.api.types.is_string_dtype(dtypes["text"].wrapped) 65 | -------------------------------------------------------------------------------- /tests/_local/unit/test_migration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
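# Verifies workspace migration: legacy `min5`/`max5` lists in the stored
# column statistics are collapsed into scalar `min`/`max` values, for both
# the context and the target stats files (empty lists become None).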
14 | 15 | import pytest 16 | 17 | from mostlyai.engine._workspace import Workspace 18 | from mostlyai.sdk._local.execution.migration import migrate_workspace 19 | 20 | 21 | @pytest.mark.parametrize( 22 | "ctx_stats, tgt_stats", 23 | [ 24 | ( 25 | None, 26 | { 27 | "columns": { 28 | "datetime": { 29 | "min5": [f"2024-01-0{i}" for i in range(1, 6)], 30 | "max5": [f"2025-01-0{i}" for i in range(1, 6)], 31 | } 32 | } 33 | }, 34 | ), 35 | ( 36 | { 37 | "columns": { 38 | "int": { 39 | "min5": [i for i in range(1, 6)], 40 | "max5": [100 + i for i in range(1, 6)], 41 | }, 42 | "float": { 43 | "min5": [], 44 | "max5": [], 45 | }, 46 | } 47 | }, 48 | { 49 | "columns": { 50 | "datetime": { 51 | "min5": [f"2024-01-0{i}" for i in range(1, 6)], 52 | "max5": [f"2025-01-0{i}" for i in range(1, 6)], 53 | } 54 | } 55 | }, 56 | ), 57 | ], 58 | ) 59 | def test_migrate_workspace(tmp_path, ctx_stats, tgt_stats): 60 | workspace_dir = tmp_path / "ModelStore" 61 | workspace = Workspace(workspace_dir) 62 | if ctx_stats: 63 | workspace.ctx_stats.write(ctx_stats) 64 | workspace.tgt_stats.write(tgt_stats) 65 | migrate_workspace(workspace_dir) 66 | if ctx_stats: 67 | migrated_ctx_stats = workspace.ctx_stats.read() 68 | assert migrated_ctx_stats["columns"]["int"]["min"] == 1 69 | assert migrated_ctx_stats["columns"]["int"]["max"] == 105 70 | assert migrated_ctx_stats["columns"]["float"]["min"] is None 71 | assert migrated_ctx_stats["columns"]["float"]["max"] is None 72 | else: 73 | assert not workspace.ctx_stats.path.exists() 74 | migrated_tgt_stats = workspace.tgt_stats.read() 75 | assert migrated_tgt_stats["columns"]["datetime"]["min"] == "2024-01-01" 76 | assert migrated_tgt_stats["columns"]["datetime"]["max"] == "2025-01-05" 77 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/step_train_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
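# TRAIN_MODEL step of a local generator job: resolves the per-table model
# configuration from the Generator, maps SDK domain objects onto the engine
# domain (differential privacy config, model-state strategy for restarts),
# and delegates the actual training to mostlyai.engine.train().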
14 | import logging 15 | from collections.abc import Callable 16 | from pathlib import Path 17 | 18 | from mostlyai.sdk.domain import Generator, ModelType 19 | 20 | _LOG = logging.getLogger(__name__) 21 | 22 | 23 | def execute_step_train_model( 24 | *, 25 | generator: Generator, 26 | model_type: ModelType, 27 | target_table_name: str, 28 | restarts: int, 29 | workspace_dir: Path, 30 | update_progress: Callable, 31 | upload_model_data_callback: Callable | None, 32 | ): 33 | # import ENGINE here to avoid pre-mature loading of large ENGINE dependencies 34 | from mostlyai import engine 35 | from mostlyai.engine.domain import DifferentialPrivacyConfig, ModelStateStrategy 36 | 37 | _LOG.info(f"mostlyai-engine: {engine.__version__}") 38 | 39 | # fetch model_config 40 | tgt_table = next(t for t in generator.tables if t.name == target_table_name) 41 | if model_type == ModelType.language: 42 | model_config = tgt_table.language_model_configuration 43 | else: 44 | model_config = tgt_table.tabular_model_configuration 45 | 46 | # convert from SDK domain to ENGINE domain 47 | if model_config.differential_privacy: 48 | differential_privacy = DifferentialPrivacyConfig(**model_config.differential_privacy.model_dump()) 49 | else: 50 | differential_privacy = None 51 | 52 | # ensure disallowed arguments are set to None 53 | if model_type == ModelType.language: 54 | max_sequence_window = None 55 | else: # model_type == ModelType.tabular 56 | max_sequence_window = model_config.max_sequence_window 57 | 58 | # call TRAIN 59 | engine.train( 60 | model=model_config.model, 61 | max_training_time=model_config.max_training_time, 62 | max_epochs=model_config.max_epochs, 63 | batch_size=model_config.batch_size, 64 | gradient_accumulation_steps=model_config.gradient_accumulation_steps, 65 | enable_flexible_generation=model_config.enable_flexible_generation, 66 | max_sequence_window=max_sequence_window, 67 | differential_privacy=differential_privacy, 68 | model_state_strategy=ModelStateStrategy.resume if restarts > 0 else ModelStateStrategy.reuse, 69 | workspace_dir=workspace_dir, 70 | upload_model_data_callback=upload_model_data_callback, 71 | update_progress=update_progress, 72 | ) 73 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/container/gcs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
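# Google Cloud Storage file container: validates the service-account key
# file, builds both a google-cloud-storage client and a gcsfs filesystem
# from it, and registers that filesystem with DuckDB instead of using HMAC
# keys.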
14 | 
15 | import logging
16 | from typing import Any
17 | 
18 | import duckdb
19 | import gcsfs
20 | from cloudpathlib.gs import GSClient, GSPath
21 | from google.cloud import storage
22 | 
23 | from mostlyai.sdk._data.exceptions import MostlyDataException
24 | from mostlyai.sdk._data.file.container.bucket_based import BucketBasedContainer
25 | from mostlyai.sdk._data.util.common import validate_gcs_key_file
26 | 
27 | _LOG = logging.getLogger(__name__)
28 | 
29 | 
30 | class GcsContainer(BucketBasedContainer):
31 |     SCHEMES = ["http", "https", "gs"]
32 |     DEFAULT_SCHEME = "gs"
33 |     DELIMITER_SCHEMA = "gs"
34 |     SECRET_ATTR_NAME = "key_file"
35 | 
36 |     def __init__(self, *args, key_file, **kwargs):
37 |         super().__init__(*args, **kwargs)
38 |         self.key_file = key_file
39 |         self.decrypt_secret()
40 | 
41 |         if not self.key_file:
42 |             raise MostlyDataException("Provide the key file.")
43 |         try:
44 |             self.client = storage.Client.from_service_account_info(self.key_file)
45 |             self.fs = gcsfs.GCSFileSystem(project=self.key_file["project_id"], token=self.key_file)
46 |             self._client = GSClient(project=self.key_file["project_id"], storage_client=self.client)
47 |         except Exception as e:
48 |             if "unsupported algorithm" in str(e).lower():
49 |                 raise MostlyDataException("Key file is incorrect.")
50 |             raise  # re-raise anything else; swallowing it would leave the container half-initialized
51 | 
52 |     def decrypt_secret(self, secret_attr_name: str | None = None) -> None:
53 |         super().decrypt_secret()
54 |         self.key_file = validate_gcs_key_file(self.key_file)
55 |         if not self.key_file:
56 |             raise MostlyDataException("Key file is incorrect.")
57 | 
58 |     @classmethod
59 |     def cloud_path_cls(cls):
60 |         return GSPath
61 | 
62 |     @property
63 |     def storage_options(self) -> dict:
64 |         return self.fs.storage_options
65 | 
66 |     @property
67 |     def transport_params(self) -> dict | None:
68 |         return dict(client=self.client)
69 | 
70 |     @property
71 |     def file_system(self) -> Any:
72 |         return self.fs
73 | 
74 |     def _check_authenticity(self) -> bool:
75 |         return gcsfs.GCSFileSystem(project=self.key_file["project_id"], token=self.key_file) is not None
76 | 
77 |     def _init_duckdb(self, con: duckdb.DuckDBPyConnection) -> None:
78 |         # register the GCS filesystem with DuckDB using service account credentials
79 |         # this is an alternative to HMAC keys, which we don't use
80 |         con.register_filesystem(self.fs)
81 | 
-------------------------------------------------------------------------------- /mostlyai/sdk/_data/db/types_coercion.py: --------------------------------------------------------------------------------
1 | # Copyright 2025 MOSTLY AI
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """
16 | Provides a function for coercing data to a given SQL dtype.
17 | """
18 | 
19 | import datetime
20 | import decimal
21 | import logging
22 | from collections import defaultdict
23 | from typing import Any
24 | 
25 | import pandas as pd
26 | from sqlalchemy.dialects import postgresql
27 | 
28 | from mostlyai.sdk._data.dtype import (
29 |     bool_coerce,
30 |     datetime_coerce,
31 |     float_coerce,
32 |     int_coerce,
33 |     str_coerce,
34 |     time_coerce,
35 | )
36 | 
37 | _LOG = logging.getLogger(__name__)
38 | 
39 | # primary mapping from SQL dtype class to coercion function
40 | SQL_DTYPE__COERCE = {
41 |     # postgresql
42 |     postgresql.VARCHAR: str_coerce,
43 |     # NOTE: we may add more specific type coercions here
44 | }
45 | 
46 | 
47 | # fallback mapping from python type to coercion function
48 | PYTHON_DTYPE__COERCE = {
49 |     bool: bool_coerce,
50 |     datetime.time: time_coerce,
51 |     datetime.date: datetime_coerce,
52 |     datetime.datetime: datetime_coerce,
53 |     decimal.Decimal: float_coerce,
54 |     float: float_coerce,
55 |     int: int_coerce,
56 |     str: str_coerce,
57 | }
58 | 
59 | COERCE__SQL_DTYPE__KWARGS = defaultdict(
60 |     lambda: lambda _: {},
61 |     {str_coerce: lambda sql_dtype: {"max_length": getattr(sql_dtype, "length", None)}},
62 | )
63 | 
64 | 
65 | def coerce_to_sql_dtype(s: pd.Series, sql_dtype: Any) -> pd.Series | None:
66 |     """
67 |     Coerces data to the given SQL dtype.
68 | 
69 |     :param s: data
70 |     :param sql_dtype: target SQL dtype for the data to fit
71 |     :return: coerced data if a coercion function is found for the SQL dtype, otherwise None
72 |     """
73 | 
74 |     # check if a coercion function is available for the SQL dtype class
75 |     sql_dtype_class = type(sql_dtype)
76 |     if sql_dtype_class in SQL_DTYPE__COERCE:
77 |         coerce = SQL_DTYPE__COERCE[sql_dtype_class]
78 |         coerce_kwargs = COERCE__SQL_DTYPE__KWARGS[coerce](sql_dtype)
79 |         return coerce(s, **coerce_kwargs)
80 | 
81 |     # check if a coercion function is available for the SQL dtype's python_type
82 |     # note that not every SQL dtype has python_type implemented!
83 |     try:
84 |         python_dtype = sql_dtype.python_type
85 |     except NotImplementedError:
86 |         _LOG.warning(f"python_type not implemented for SQL dtype {sql_dtype}")
87 |     else:
88 |         if python_dtype in PYTHON_DTYPE__COERCE:
89 |             coerce = PYTHON_DTYPE__COERCE[python_dtype]
90 |             coerce_kwargs = COERCE__SQL_DTYPE__KWARGS[coerce](sql_dtype)
91 |             return coerce(s, **coerce_kwargs)
92 | 
93 |     _LOG.warning(f"coercion function not found for column {s.name} and SQL dtype {sql_dtype}")
94 |     # return None if no coercion function was found
95 |     return None
96 | 
-------------------------------------------------------------------------------- /Dockerfile: --------------------------------------------------------------------------------
1 | FROM alpine:latest AS alpine
2 | # ! We're using alpine as a hack to fetch the musl libraries.
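# Base stage: glibc-based Wolfi with Python 3.12 and uv. The musl loaders
# copied in from the alpine stage above (see the COPY --from=alpine line)
# are presumably needed so that musl-linked binaries still run on this image.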
3 | 4 | FROM cgr.dev/chainguard/wolfi-base:latest AS base 5 | ENV LANG="C.UTF-8" 6 | ENV UV_FROZEN=true 7 | ENV UV_NO_CACHE=true 8 | 9 | WORKDIR /app 10 | RUN chmod 777 /app 11 | RUN apk add --no-cache wget bash tzdata 12 | RUN apk add --no-cache python-3.12-dev 13 | RUN apk add --no-cache uv 14 | COPY --from=alpine --chmod=755 /lib/*musl-*.so.1 /lib/ 15 | USER nonroot 16 | 17 | FROM base AS final 18 | USER root 19 | RUN apk add --no-cache krb5-dev libpq unixodbc-dev libaio krb5 20 | RUN CURRENT_ARCH=$(uname -m | sed 's|x86_64|amd64|g') \ 21 | && wget https://download.microsoft.com/download/fae28b9a-d880-42fd-9b98-d779f0fdd77f/msodbcsql18_18.5.1.1-1_$CURRENT_ARCH.apk -qO /tmp/msodbcsql.apk \ 22 | && apk add --no-cache --allow-untrusted /tmp/msodbcsql.apk \ 23 | && rm -rf /tmp/msodbcsql.apk \ 24 | && apk add --no-cache glibc-iconv 25 | 26 | RUN CURRENT_ARCH=$(uname -m | sed 's|x86_64|x64|g') \ 27 | && if [ "$CURRENT_ARCH" != "x64" ]; then exit 0; fi \ 28 | && wget https://download.oracle.com/otn_software/linux/instantclient/211000/instantclient-basic-linux.$CURRENT_ARCH-21.1.0.0.0.zip -qO /tmp/oracle-instantclient.zip \ 29 | && wget https://download.oracle.com/otn_software/linux/instantclient/211000/instantclient-sqlplus-linux.$CURRENT_ARCH-21.1.0.0.0.zip -qO /tmp/oracle-sqlplus.zip \ 30 | && unzip /tmp/oracle-instantclient.zip -d /opt/oracle \ 31 | && unzip /tmp/oracle-sqlplus.zip -d /opt/oracle \ 32 | && rm -rf /tmp/oracle-sqlplus.zip /tmp/oracle-instantclient.zip \ 33 | && sh -c "echo '/opt/oracle/instantclient_21_1' > /etc/ld.so.conf.d/oracle-instantclient.conf" 34 | ENV PATH="$PATH:/opt/oracle/instantclient_21_1" 35 | ENV ORACLE_HOME="/opt/oracle/instantclient_21_1" 36 | 37 | COPY ./uv.lock ./pyproject.toml ./ 38 | RUN apk add --no-cache --virtual temp-build-deps gcc~12 postgresql-dev \ 39 | && uv sync --no-editable --all-extras --no-extra local-gpu \ 40 | --no-install-package torch \ 41 | --no-install-package torchvision \ 42 | --no-install-package torchaudio \ 43 | --no-install-package vllm \ 44 | --no-install-package bitsandbytes \ 45 | --no-install-package nvidia-cudnn-cu12 \ 46 | --no-install-package nvidia-cublas-cu12 \ 47 | --no-install-package nvidia-cusparse-cu12 \ 48 | --no-install-package nvidia-cufft-cu12 \ 49 | --no-install-package nvidia-cuda-cupti-cu12 \ 50 | --no-install-package nvidia-nvjitlink-cu12 \ 51 | --no-install-package nvidia-cuda-nvrtc-cu12 \ 52 | --no-install-package nvidia-curand-cu12 \ 53 | --no-install-package nvidia-cusolver-cu12 \ 54 | --no-install-package nvidia-cusparselt-cu12 \ 55 | --no-install-package nvidia-nccl-cu12 \ 56 | --no-install-package nvidia-cuda-runtime-cu12 \ 57 | --no-install-package nvidia-nvtx-cu12 \ 58 | --no-install-package ray \ 59 | --no-install-package cupy-cuda12x \ 60 | --no-install-package triton \ 61 | --no-install-package mostlyai \ 62 | --no-install-project \ 63 | && apk del temp-build-deps 64 | 65 | RUN uv pip install torch==2.8.0 torchvision==0.23.0 --torch-backend=cpu 66 | COPY mostlyai ./mostlyai 67 | COPY README.md ./ 68 | RUN uv pip install -e . 
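# The SDK source is installed editable on top of the pre-resolved dependency
# set; the entrypoint below is then executed via `uv run --no-sync`, so the
# container starts inside the project environment without re-resolving
# dependencies at startup.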
69 | COPY ./tools/docker_entrypoint.py /app/entrypoint.py 70 | 71 | USER nonroot 72 | 73 | EXPOSE 8080 74 | ENTRYPOINT [ "uv", "run", "--no-sync", "--project", "/app", "--"] 75 | CMD ["/app/entrypoint.py"] 76 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/step_deliver_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pathlib import Path 16 | 17 | from mostlyai.sdk._data.base import DataContainer, Schema 18 | from mostlyai.sdk._data.conversions import create_container_from_connector 19 | from mostlyai.sdk._data.db.base import SqlAlchemyContainer 20 | from mostlyai.sdk._data.file.container.bucket_based import BucketBasedContainer 21 | from mostlyai.sdk._data.file.table.parquet import ParquetDataTable 22 | from mostlyai.sdk._data.file.utils import make_data_table_from_container 23 | from mostlyai.sdk._data.push import push_data, push_data_by_copying 24 | from mostlyai.sdk.domain import Connector, Generator, SyntheticDatasetDelivery 25 | 26 | 27 | def execute_step_deliver_data( 28 | *, 29 | generator: Generator, 30 | delivery: SyntheticDatasetDelivery, 31 | connector: Connector | None, 32 | schema: Schema, 33 | job_workspace_dir: Path, 34 | ): 35 | # skip DELIVER_DATA step if no destination connector is provided 36 | if connector is None: 37 | return 38 | 39 | # create destination container 40 | container = create_container_from_connector(connector) 41 | container.set_location(delivery.location) 42 | 43 | overwrite_tables = delivery.overwrite_tables 44 | for table_name in schema.tables: 45 | local_path = job_workspace_dir / "FinalizedSyntheticData" / table_name / "parquet" 46 | if isinstance(container, BucketBasedContainer): 47 | bucket_path = container.path / table_name 48 | push_data_by_copying( 49 | source=local_path, 50 | destination=bucket_path, 51 | overwrite_tables=overwrite_tables, 52 | ) 53 | elif isinstance(container, SqlAlchemyContainer): 54 | src_table = ParquetDataTable(path=local_path) 55 | table = _create_destination_table(table_name, generator, container) 56 | push_data( 57 | source=src_table, 58 | destination=table, 59 | schema=schema, 60 | overwrite_tables=overwrite_tables, 61 | ) 62 | else: 63 | raise ValueError(f"Unsupported destination container type: {container}") 64 | 65 | 66 | def _create_destination_table( 67 | table_name: str, 68 | generator: Generator, 69 | data_container: DataContainer, 70 | ): 71 | # create destination table 72 | source_table = next(t for t in generator.tables if t.name == table_name) 73 | data_table = make_data_table_from_container(data_container) 74 | data_table.name = source_table.name 75 | data_table.primary_key = source_table.primary_key 76 | data_table.columns = [c.name for c in source_table.columns if c.included] 77 | data_table.encoding_types = {c.name: c.model_encoding_type for c in 
source_table.columns if c.included} 78 | data_table.is_output = True 79 | return data_table 80 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/step_pull_training_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections.abc import Callable 16 | from pathlib import Path 17 | 18 | from mostlyai.sdk import _data as data 19 | from mostlyai.sdk._data.base import ForeignKey, Schema 20 | from mostlyai.sdk._data.conversions import create_container_from_connector 21 | from mostlyai.sdk._data.db.base import SqlAlchemyTable 22 | from mostlyai.sdk._data.file.utils import make_data_table_from_container 23 | from mostlyai.sdk.domain import Connector, Generator, ModelType 24 | 25 | 26 | def execute_step_pull_training_data( 27 | *, 28 | generator: Generator, 29 | connectors: list[Connector], 30 | model_type: ModelType, 31 | target_table_name: str, 32 | workspace_dir: Path, 33 | update_progress: Callable, 34 | ) -> tuple[list[str], int]: 35 | schema = create_training_schema(generator=generator, connectors=connectors) 36 | 37 | # fetch total rows 38 | tgt_table_total_rows = schema.tables[target_table_name].row_count 39 | # fetch columns 40 | tgt_table_columns = schema.tables[target_table_name].columns 41 | 42 | # fetch model_config 43 | tgt_table = next(t for t in generator.tables if t.name == target_table_name) 44 | if model_type == ModelType.language: 45 | model_config = tgt_table.language_model_configuration 46 | else: 47 | model_config = tgt_table.tabular_model_configuration 48 | 49 | # call PULL 50 | data.pull( 51 | tgt=target_table_name, 52 | schema=schema, 53 | model_type=model_type, 54 | max_sample_size=model_config.max_sample_size, 55 | workspace_dir=workspace_dir, 56 | update_progress=update_progress, 57 | ) 58 | return tgt_table_columns, tgt_table_total_rows 59 | 60 | 61 | def create_training_schema(generator: Generator, connectors: list[Connector]) -> Schema: 62 | tables = {} 63 | for table in generator.tables: 64 | # create DataContainer 65 | connector_id = table.source_connector_id 66 | connector = next(c for c in connectors if c.id == connector_id) 67 | container = create_container_from_connector(connector) 68 | meta = container.set_location(table.location) 69 | # create DataTable 70 | data_table = make_data_table_from_container(container, lazy_fetch_primary_key=False) 71 | # preserve actual database table name for database containers 72 | if isinstance(data_table, SqlAlchemyTable) and meta and "table_name" in meta: 73 | data_table.name_in_db = meta["table_name"] 74 | data_table.name = table.name 75 | data_table.primary_key = table.primary_key 76 | if table.columns: 77 | data_table.columns = [c.name for c in table.columns if c.included] 78 | data_table.encoding_types = {c.name: c.model_encoding_type for c in table.columns if c.included} 79 | 
data_table.is_output = False 80 | data_table.foreign_keys = [ 81 | ForeignKey(column=fk.column, referenced_table=fk.referenced_table, is_context=fk.is_context) 82 | for fk in table.foreign_keys or [] 83 | ] 84 | tables[table.name] = data_table 85 | return Schema(tables=tables) 86 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/language_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import uuid 17 | 18 | import pandas as pd 19 | 20 | from mostlyai.sdk._data.base import Schema 21 | from mostlyai.sdk._data.util.common import TABLE_COLUMN_INFIX, TEMPORARY_PRIMARY_KEY 22 | from mostlyai.sdk.domain import ModelType 23 | 24 | _LOG = logging.getLogger(__name__) 25 | 26 | 27 | def split_language_model( 28 | schema: Schema, 29 | tgt: str, 30 | tgt_data: pd.DataFrame, 31 | ctx_data: pd.DataFrame | None = None, 32 | ) -> tuple[pd.DataFrame, pd.DataFrame]: 33 | """ 34 | Split LANGUAGE cols into `tgt_data`, and add all other columns to `ctx_data`. 35 | 36 | :return: ctx_data, tgt_data 37 | """ 38 | enctypes = schema.tables[tgt].encoding_types 39 | language_cols = [col for col in enctypes if enctypes[col].startswith(ModelType.language)] 40 | if len(language_cols) == 0: 41 | # if no LANGUAGE columns are present, then leave data as-is 42 | return ctx_data, tgt_data 43 | _LOG.info("split_language_model") 44 | # split into LANGUAGE and TABULAR columns 45 | other_data = tgt_data[[c for c in tgt_data if c not in language_cols]] 46 | # context data must be prefixed 47 | other_data = other_data.add_prefix(tgt + TABLE_COLUMN_INFIX) 48 | # if tgt_data lacks the LANGUAGE cols, we are pulling context only; since the LANGUAGE cols cannot be taken from tgt_data, create an empty frame instead 49 | if set(language_cols).issubset(set(tgt_data.columns)): 50 | tgt_data = tgt_data[language_cols] 51 | else: 52 | tgt_data = pd.DataFrame() 53 | 54 | if ctx_data is None: 55 | # handle single table case: split all TABULAR columns into `ctx_data`, 56 | # and only keep the LANGUAGE (text) columns as `tgt_data`.
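# illustrative sketch (hypothetical names): for a single table `reviews` with columns
# ["age", "review_text"], where only `review_text` is LANGUAGE-encoded, tgt_data becomes
# reviews[["review_text"]], while ctx_data carries `age` under the prefixed name
# f"reviews{TABLE_COLUMN_INFIX}age"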
57 | ctx_data = other_data 58 | else: 59 | # handle two table case: inner-join all TABULAR columns to ctx_data, 60 | # and only keep the LANGUAGE (text) columns as tgt_data 61 | ctx_relation = schema.get_parent_context_relation(tgt) 62 | 63 | if ctx_relation: 64 | ctx_pk = ctx_relation.parent.ref_name() 65 | tgt_fk = ctx_relation.child.ref_name() 66 | ctx_data = pd.merge(ctx_data, other_data, how="inner", left_on=ctx_pk, right_on=tgt_fk) 67 | 68 | tmp_keys = [str(uuid.uuid4()) for _ in range(len(ctx_data))] 69 | tgt_data.insert(0, TEMPORARY_PRIMARY_KEY, tmp_keys) 70 | ctx_data.insert(0, f"{tgt}{TABLE_COLUMN_INFIX}{TEMPORARY_PRIMARY_KEY}", tmp_keys) 71 | return ctx_data, tgt_data 72 | 73 | 74 | def drop_language_columns_in_target( 75 | tgt: str, 76 | schema: Schema, 77 | tgt_data: pd.DataFrame, 78 | ) -> pd.DataFrame: 79 | """ 80 | Drop language columns when pulling data for a tabular model. 81 | 82 | :param tgt: target table 83 | :param schema: database schema 84 | :param tgt_data: the target DataFrame 85 | :return: the target DataFrame with unsupported encoding types dropped 86 | """ 87 | tgt_table = schema.tables[tgt] 88 | drop_columns = [] 89 | for col_name, encoding_type in tgt_table.encoding_types.items(): 90 | if encoding_type.startswith(ModelType.language): 91 | drop_columns.append(col_name) 92 | if drop_columns: 93 | _LOG.info(f"drop LANGUAGE columns from target: {drop_columns}") 94 | return tgt_data.drop(columns=drop_columns) 95 | -------------------------------------------------------------------------------- /mostlyai/sdk/client/_base_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import base64 16 | import csv 17 | import io 18 | import warnings 19 | from pathlib import Path 20 | from typing import Any, Literal 21 | 22 | import pandas as pd 23 | 24 | warnings.simplefilter("always", DeprecationWarning) 25 | 26 | 27 | def convert_to_base64( 28 | df: pd.DataFrame | list[dict[str, Any]], 29 | format: Literal["parquet", "jsonl"] = "parquet", 30 | ) -> str: 31 | """ 32 | Convert a DataFrame to a base64 encoded string, representing the content in Parquet or JSONL format. 33 | 34 | Args: 35 | df: The DataFrame (or a list of record dicts) to convert. 36 | format: The format to use for the conversion. Either "parquet" or "jsonl". 37 | 38 | Returns: 39 | The base64 encoded string.
40 | """ 41 | if df.__class__.__name__ == "DataFrame" and df.__class__.__module__.startswith("pyspark.sql"): 42 | # Convert PySpark DataFrame to Pandas DataFrame (safely) 43 | df = pd.DataFrame(df.collect(), columns=df.columns) 44 | elif not isinstance(df, pd.DataFrame): 45 | df = pd.DataFrame(df) 46 | # Save the DataFrame to a buffer in Parquet / JSONL format 47 | buffer = io.BytesIO() 48 | if format == "parquet": 49 | # clear any (potentially non-serializable) attributes that might stop us from saving to PQT 50 | if df.attrs: 51 | df.attrs.clear() 52 | # persist the DataFrame to Parquet format 53 | df.to_parquet(buffer, index=False) 54 | else: # format == "jsonl" 55 | # persist the DataFrame to JSONL format 56 | df.to_json(buffer, orient="records", date_format="iso", lines=True, index=False) 57 | # read in persisted file as base64 encoded string 58 | buffer.seek(0) 59 | binary_data = buffer.read() 60 | base64_encoded_str = base64.b64encode(binary_data).decode() 61 | return base64_encoded_str 62 | 63 | 64 | def convert_to_df(data: str, format: Literal["parquet", "jsonl"] = "parquet") -> pd.DataFrame: 65 | # Load the DataFrame from a base64 encoded string 66 | binary_data = base64.b64decode(data) 67 | buffer = io.BytesIO(binary_data) 68 | if format == "parquet": 69 | df = pd.read_parquet(buffer) 70 | else: # format == "jsonl" 71 | df = pd.read_json(buffer, orient="records", lines=True) 72 | return df 73 | 74 | 75 | def read_table_from_path(path: str | Path) -> tuple[str, pd.DataFrame]: 76 | # read data from file 77 | fn = str(path) 78 | if fn.lower().endswith((".pqt", ".parquet")): 79 | df = pd.read_parquet(fn) 80 | else: 81 | delimiter = "," 82 | if fn.lower().endswith((".csv", ".tsv")): 83 | try: 84 | with open(fn) as f: 85 | header = f.readline() 86 | sniffer = csv.Sniffer() 87 | delimiter = sniffer.sniff(header, ",;|\t' :").delimiter 88 | except (csv.Error, FileNotFoundError): 89 | # csv.Error: happens for example for single column CSV files 90 | # FileNotFoundError: happens for example for remote files 91 | pass 92 | df = pd.read_csv(fn, low_memory=False, delimiter=delimiter) 93 | if fn.lower().endswith((".gz", ".gzip", ".bz2")): 94 | fn = fn.rsplit(".", 1)[0] 95 | name = Path(fn).stem 96 | return name, df 97 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/db/mysql.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import logging 16 | from urllib.parse import quote 17 | 18 | import sqlalchemy as sa 19 | from sqlalchemy import text 20 | from sqlalchemy.dialects.mysql.base import MySQLDialect 21 | 22 | from mostlyai.sdk._data.db.base import DBDType, SqlAlchemyContainer, SqlAlchemyTable 23 | 24 | _LOG = logging.getLogger(__name__) 25 | 26 | _MY_SQL_PRIVILEGE_DB_NAME = "mysql" 27 | _MY_SQL_PRIVILEGE_TABLE_NAME = "user" 28 | 29 | 30 | class MysqlDType(DBDType): 31 | UNBOUNDED_VARCHAR_ALLOWED = False 32 | 33 | @classmethod 34 | def sa_dialect_class(cls): 35 | return MySQLDialect 36 | 37 | 38 | class BaseMySqlContainer(SqlAlchemyContainer): 39 | SCHEMES = ["mysql", "mariadb"] 40 | SA_CONNECTION_KWARGS = {} 41 | SA_SSL_ATTR_KEY_MAP = { 42 | "root_certificate_path": "ssl_ca", 43 | "ssl_certificate_path": "ssl_cert", 44 | "ssl_certificate_key_path": "ssl_key", 45 | } 46 | DIALECT = "" 47 | SQL_FETCH_FOREIGN_KEYS = """ 48 | SELECT 49 | TABLE_NAME as TABLE_NAME, 50 | COLUMN_NAME as COLUMN_NAME, 51 | REFERENCED_TABLE_NAME as REFERENCED_TABLE_NAME, 52 | REFERENCED_COLUMN_NAME as REFERENCED_COLUMN_NAME 53 | FROM 54 | information_schema.KEY_COLUMN_USAGE 55 | WHERE 56 | CONSTRAINT_SCHEMA = :schema_name AND 57 | REFERENCED_TABLE_NAME IS NOT NULL; 58 | """ 59 | INIT_DEFAULT_VALUES = {"dbname": "", "port": "3306"} 60 | 61 | def _get_uri_without_dbname(self): 62 | # User and password are needed to avoid double-encoding of @ character 63 | username = quote(self.username) 64 | password = quote(self.password) 65 | return f"{self.DIALECT}+mysqlconnector://{username}:{password}@{self.host}:{self.port}" 66 | 67 | @property 68 | def sa_uri(self): 69 | sa_uri = f"{self._get_uri_without_dbname()}/{self.dbname}" 70 | return sa_uri 71 | 72 | @property 73 | def sa_uri_for_does_database_exist(self): 74 | return self._get_uri_without_dbname() 75 | 76 | @classmethod 77 | def table_class(cls): 78 | return MysqlLikeTable 79 | 80 | def does_database_exist(self) -> bool: 81 | with self.init_sa_connection("db_exist_check") as connection: 82 | result = connection.execute( 83 | text(f"SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME='{self.dbname}'") 84 | ) 85 | return bool(result.rowcount) 86 | 87 | def update_dbschema(self, dbschema: str | None) -> None: 88 | # schema (as a prefix) is equivalent to db name in mysql 89 | self.dbname = dbschema or self.INIT_DEFAULT_VALUES.get("dbname") 90 | # reset engine 91 | self._sa_engine_for_read = None 92 | self._sa_engine_for_write = None 93 | 94 | 95 | class MysqlContainer(BaseMySqlContainer): 96 | DIALECT = "mysql" 97 | 98 | 99 | class MariadbContainer(BaseMySqlContainer): 100 | DIALECT = "mariadb" 101 | 102 | 103 | class MysqlLikeTable(SqlAlchemyTable): 104 | DATA_TABLE_TYPE = "mysql" 105 | SA_RANDOM = sa.func.rand() 106 | 107 | @classmethod 108 | def dtype_class(cls): 109 | return MysqlDType 110 | 111 | @classmethod 112 | def container_class(cls): 113 | return MysqlContainer 114 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/pull.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import shutil 17 | import time 18 | from pathlib import Path 19 | 20 | from mostlyai.sdk._data.base import Schema 21 | from mostlyai.sdk._data.progress_callback import ProgressCallback, ProgressCallbackWrapper 22 | from mostlyai.sdk._data.pull_utils import ( 23 | handle_workspace_dir, 24 | pull_fetch, 25 | pull_keys, 26 | pull_split, 27 | remake_schema_after_pull_fetch, 28 | ) 29 | from mostlyai.sdk.domain import ModelType 30 | 31 | _LOG = logging.getLogger(__name__) 32 | 33 | 34 | def pull( 35 | *, 36 | tgt: str, 37 | schema: Schema, 38 | model_type: str | ModelType = ModelType.tabular, 39 | max_sample_size: int | None = None, 40 | trn_val_split: float | None = 0.8, 41 | workspace_dir: str | Path = "engine-ws", 42 | update_progress: ProgressCallback | None = None, 43 | ): 44 | t0 = time.time() 45 | with ProgressCallbackWrapper(update_progress, description="Pull training data") as progress: 46 | workspace_dir = Path(workspace_dir) 47 | model_type = ModelType(model_type) 48 | if tgt not in schema.tables: 49 | raise ValueError(f"table '{tgt}' not defined in schema") 50 | schema.preprocess_schema_before_pull() 51 | # gather context_tables 52 | context_tables = schema.get_context_tables(tgt) 53 | _LOG.info(f"context_tables (size: {len(context_tables)}): {context_tables}") 54 | # handle workspace_dir 55 | workspace_dir = handle_workspace_dir(workspace_dir=workspace_dir) 56 | # ensure that max_sample_size is a positive integer, if given 57 | if max_sample_size is not None: 58 | max_sample_size = max(1, max_sample_size) 59 | # log arguments 60 | _LOG.info(f"tgt: {tgt}") 61 | _LOG.info(f"model_type: {model_type}") 62 | _LOG.info(f"max_sample_size: {max_sample_size}") 63 | 64 | # initialize progress counter 65 | tbl_count_rows = 0 66 | tbl_count_rows += schema.tables[tgt].row_count 67 | for ctx_table in schema.get_context_tables(tgt): 68 | tbl_count_rows += schema.tables[ctx_table].row_count 69 | progress_plan = 1000 70 | progress_fetch = tbl_count_rows 71 | progress_split = tbl_count_rows 72 | progress.update(completed=0, total=progress_plan + progress_fetch + progress_split + 1) 73 | 74 | keys = pull_keys( 75 | tgt=tgt, 76 | schema=schema, 77 | max_sample_size=max_sample_size, 78 | model_type=model_type, 79 | ) 80 | progress.update(advance=progress_plan) 81 | 82 | pull_fetch( 83 | tgt=tgt, 84 | schema=schema, 85 | keys=keys, 86 | max_sample_size=max_sample_size, 87 | workspace_dir=workspace_dir, 88 | progress=progress, 89 | ) 90 | schema = remake_schema_after_pull_fetch(tgt=tgt, schema=schema, workspace_dir=workspace_dir) 91 | 92 | pull_split( 93 | tgt=tgt, 94 | schema=schema, 95 | trn_val_split=trn_val_split, 96 | model_type=model_type, 97 | do_ctx_only=False, 98 | workspace_dir=workspace_dir, 99 | progress=progress, 100 | ) 101 | 102 | _LOG.info("clean up temporary fetch directory") 103 | shutil.rmtree(workspace_dir / "__PULL_FETCH", ignore_errors=True) 104 | _LOG.info(f"pull total time: {time.time() - t0:.2f}s") 105 | -------------------------------------------------------------------------------- 
/mostlyai/sdk/_data/db/snowflake.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | from urllib.parse import quote 17 | 18 | import sqlalchemy as sa 19 | from snowflake.sqlalchemy import URL 20 | from snowflake.sqlalchemy.snowdialect import SnowflakeDialect 21 | 22 | from mostlyai.sdk._data.db.base import DBDType, SqlAlchemyContainer, SqlAlchemyTable 23 | from mostlyai.sdk._data.exceptions import MostlyDataException 24 | 25 | _LOG = logging.getLogger(__name__) 26 | 27 | ACCOUNT_SUFFIX_TO_REMOVE = ".snowflakecomputing.com" 28 | ACCOUNT_PREFIX_TO_REMOVE = "https://" 29 | DEFAULT_WAREHOUSE = "COMPUTE_WH" 30 | 31 | 32 | class SnowflakeDType(DBDType): 33 | @classmethod 34 | def sa_dialect_class(cls): 35 | return SnowflakeDialect 36 | 37 | 38 | class SnowflakeContainer(SqlAlchemyContainer): 39 | SCHEMES = ["snowflake"] 40 | INIT_DEFAULT_VALUES = {"dbname": ""} 41 | 42 | def __init__(self, *args, account, **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.account = account 45 | 46 | @property 47 | def sa_uri(self): 48 | # User and password are needed to avoid double-encoding of @ character 49 | username = quote(self.username) 50 | password = quote(self.password) 51 | self.host = self.host or DEFAULT_WAREHOUSE 52 | return URL( 53 | user=username, 54 | password=password, 55 | account=self.account.replace(ACCOUNT_SUFFIX_TO_REMOVE, "").replace(ACCOUNT_PREFIX_TO_REMOVE, ""), 56 | warehouse=self.host, 57 | database=self.dbname, 58 | cache_column_metadata=True, 59 | ) 60 | 61 | @classmethod 62 | def table_class(cls): 63 | return SnowflakeTable 64 | 65 | def does_database_exist(self) -> bool: 66 | with self.init_sa_connection("db_exist_check") as connection: 67 | result = connection.execute(sa.text(f"SHOW DATABASES LIKE '{self.dbname}'")) 68 | if result.rowcount: 69 | db_exist = True 70 | else: 71 | db_exist = False 72 | return db_exist 73 | 74 | def _is_schema_exist(self) -> bool: 75 | if self.dbschema is None: 76 | return True 77 | with self.use_sa_engine() as sa_engine: 78 | schema_names = sa.inspect(sa_engine).get_schema_names() 79 | return self.dbschema.lower() in schema_names 80 | 81 | def is_accessible(self) -> bool: 82 | try: 83 | with self.init_sa_connection("access_check"): 84 | if self.dbname and not self.does_database_exist(): 85 | raise MostlyDataException(f"Database `{self.dbname}` does not exist.") 86 | elif self.dbschema and not self._is_schema_exist(): 87 | raise MostlyDataException(f"Schema `{self.dbschema}` does not exist.") 88 | else: 89 | return True 90 | except sa.exc.SQLAlchemyError as e: 91 | error_message = str(e).lower() 92 | _LOG.error(f"Database connection failed with error: {e}") 93 | if "password" in error_message or "user" in error_message or ("account name" in error_message and "snowflake" in error_message): 94 | raise MostlyDataException("Credentials are incorrect.") 95 | else: 96 | raise 97 | 98 | 99 | class 
SnowflakeTable(SqlAlchemyTable): 100 | DATA_TABLE_TYPE = "snowflake" 101 | SA_RANDOM = sa.func.random() 102 | 103 | @classmethod 104 | def dtype_class(cls): 105 | return SnowflakeDType 106 | 107 | @classmethod 108 | def container_class(cls): 109 | return SnowflakeContainer 110 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/db/postgresql.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | from urllib.parse import quote 17 | 18 | import numpy as np 19 | import sqlalchemy as sa 20 | from psycopg2.extensions import AsIs, register_adapter 21 | from sqlalchemy.dialects.postgresql.base import PGDialect 22 | 23 | from mostlyai.sdk._data.db.base import DBDType, SqlAlchemyContainer, SqlAlchemyTable 24 | 25 | _LOG = logging.getLogger(__name__) 26 | 27 | register_adapter(np.int64, AsIs) 28 | register_adapter(np.float64, AsIs) 29 | 30 | 31 | class PostgresqlDType(DBDType): 32 | FROM_VIRTUAL_DATETIME = sa.TIMESTAMP 33 | 34 | @classmethod 35 | def sa_dialect_class(cls): 36 | return PGDialect 37 | 38 | 39 | class PostgresqlContainer(SqlAlchemyContainer): 40 | SCHEMES = ["postgresql"] 41 | SA_CONNECTION_KWARGS = {"sslmode": "require"} 42 | SA_SSL_ATTR_KEY_MAP = { 43 | "root_certificate_path": "sslrootcert", 44 | "ssl_certificate_path": "sslcert", 45 | "ssl_certificate_key_path": "sslkey", 46 | } 47 | SQL_FETCH_FOREIGN_KEYS = """ 48 | SELECT 49 | kcu.table_name AS TABLE_NAME, 50 | kcu.column_name AS COLUMN_NAME, 51 | ccu.table_name AS REFERENCED_TABLE_NAME, 52 | ccu.column_name AS REFERENCED_COLUMN_NAME 53 | FROM 54 | information_schema.table_constraints AS tc 55 | JOIN information_schema.key_column_usage AS kcu 56 | ON tc.constraint_name = kcu.constraint_name 57 | AND tc.table_schema = kcu.table_schema 58 | JOIN information_schema.constraint_column_usage AS ccu 59 | ON ccu.constraint_name = tc.constraint_name 60 | AND ccu.table_schema = tc.table_schema 61 | WHERE 62 | tc.constraint_type = 'FOREIGN KEY' 63 | AND tc.table_schema = :schema_name; 64 | """ 65 | INIT_DEFAULT_VALUES = {"dbname": "", "port": "5432"} 66 | 67 | @property 68 | def sa_uri(self): 69 | # User and password are needed to avoid double-encoding of @ character 70 | username = quote(self.username) 71 | password = quote(self.password) 72 | return f"postgresql+psycopg2://{username}:{password}@{self.host}:{self.port}/{self.dbname}" 73 | 74 | @property 75 | def sa_create_engine_kwargs(self) -> dict: 76 | return { 77 | # read more: https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#psycopg2-fast-execution-helpers 78 | "executemany_mode": "values_only", 79 | # the following knobs can be used in attempts to improve speed of writing 80 | # however, no better setting was found than the default 81 | # "executemany_values_page_size": 1_000, # default 82 | # "executemany_batch_page_size": 100 # default 83 | } 84 | 85 
| @classmethod 86 | def table_class(cls): 87 | return PostgresqlTable 88 | 89 | def does_database_exist(self) -> bool: 90 | try: 91 | with self.init_sa_connection() as connection: 92 | result = connection.execute(sa.text(f"SELECT 1 FROM pg_database WHERE datname='{self.dbname}'")) 93 | return bool(result.rowcount) 94 | except Exception as e: 95 | _LOG.error(f"Error when checking if database exists: {e}") 96 | return False 97 | 98 | 99 | class PostgresqlTable(SqlAlchemyTable): 100 | DATA_TABLE_TYPE = "postgresql" 101 | SA_RANDOM = sa.func.random() 102 | 103 | @classmethod 104 | def dtype_class(cls): 105 | return PostgresqlDType 106 | 107 | @classmethod 108 | def container_class(cls): 109 | return PostgresqlContainer 110 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import re 17 | 18 | from mostlyai.sdk._data.base import DataContainer, DataTable 19 | from mostlyai.sdk._data.db.base import SqlAlchemyContainer 20 | from mostlyai.sdk._data.exceptions import MostlyDataException 21 | from mostlyai.sdk._data.file.base import ( 22 | FILE_TYPE_COMPRESSED, 23 | FileContainer, 24 | FileDataTable, 25 | FileType, 26 | get_file_name_and_type, 27 | ) 28 | from mostlyai.sdk._data.file.table.csv import CsvDataTable 29 | from mostlyai.sdk._data.file.table.feather import FeatherDataTable 30 | from mostlyai.sdk._data.file.table.json import JsonDataTable 31 | from mostlyai.sdk._data.file.table.parquet import ParquetDataTable 32 | 33 | _LOG = logging.getLogger(__name__) 34 | 35 | 36 | FILE_EXT_DATA_TABLE_CLASS_MAP = { 37 | FileType.csv: CsvDataTable, 38 | FileType.tsv: CsvDataTable, 39 | FileType.parquet: ParquetDataTable, 40 | FileType.feather: FeatherDataTable, 41 | FileType.json: JsonDataTable, 42 | } 43 | 44 | 45 | def read_data_table_from_path( 46 | container_object: FileContainer, return_class: bool = False 47 | ) -> FileDataTable | type[FileDataTable]: 48 | return _fetch_file_data_table(container_object, return_class, is_read=True) 49 | 50 | 51 | def write_data_table_to_path( 52 | container_object: FileContainer, return_class: bool = False 53 | ) -> FileDataTable | type[FileDataTable]: 54 | return _fetch_file_data_table(container_object, return_class, is_read=False) 55 | 56 | 57 | def _fetch_file_data_table( 58 | container_object: FileContainer, return_class: bool, is_read: bool 59 | ) -> FileDataTable | type[FileDataTable]: 60 | # determine table_name 61 | table_name = container_object.path.absolute().name 62 | for ext in FILE_TYPE_COMPRESSED: 63 | table_name = re.sub(f"\\.{ext}$", "", table_name, flags=re.IGNORECASE) 64 | if "." 
in table_name: 65 | table_name = ".".join(table_name.split(".")[:-1]) 66 | # determine file extension 67 | if is_read: 68 | _LOG.info(f"detect data files for `{container_object.path}`") 69 | file_list = container_object.list_valid_files() 70 | _LOG.info(container_object.path) 71 | _LOG.info(file_list) 72 | _LOG.info(f"detected {len(file_list)} files: {file_list}") 73 | if len(file_list) == 0: 74 | raise MostlyDataException("No data files found.") 75 | else: 76 | file_list = [container_object.path] 77 | _, file_type = get_file_name_and_type(file_list[0]) 78 | data_table_cls = FILE_EXT_DATA_TABLE_CLASS_MAP.get(file_type) 79 | if return_class: 80 | return data_table_cls 81 | _LOG.info("create FileDataTable") 82 | data_table = data_table_cls( 83 | container=container_object, 84 | path=file_list, 85 | name=table_name, 86 | is_output=not is_read, 87 | ) 88 | return data_table 89 | 90 | 91 | def make_data_table_from_container( 92 | container: DataContainer, is_output: bool = False, lazy_fetch_primary_key: bool = True 93 | ) -> DataTable: 94 | if isinstance(container, SqlAlchemyContainer): 95 | # handle DB containers 96 | data_table_class = container.table_class() 97 | elif isinstance(container, FileContainer): 98 | # handle local fs and bucket containers 99 | if is_output: 100 | data_table_class = write_data_table_to_path(container, return_class=True) 101 | else: 102 | data_table_class = read_data_table_from_path(container, return_class=True) 103 | else: 104 | raise RuntimeError(f"Unknown container type: {type(container)}") 105 | return data_table_class(container=container, is_output=is_output, lazy_fetch_primary_key=lazy_fetch_primary_key) 106 | -------------------------------------------------------------------------------- /tests/_data/unit/file/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
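# the tests below cover two helpers: file-type detection when reading a data table from
# a local path (read_data_table_from_path), and normalization of bucket-style URIs
# (BucketBasedContainer.normalize_bucket_location)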
14 | 15 | 16 | import pandas as pd 17 | 18 | from mostlyai.sdk._data.file.base import LocalFileContainer 19 | from mostlyai.sdk._data.file.container.bucket_based import BucketBasedContainer 20 | from mostlyai.sdk._data.file.table.csv import CsvDataTable 21 | from mostlyai.sdk._data.file.table.parquet import ParquetDataTable 22 | from mostlyai.sdk._data.file.utils import read_data_table_from_path 23 | 24 | 25 | def test_read_data_table_from_local_path(tmp_path): 26 | def _create_csv(file): 27 | file.parent.mkdir(exist_ok=True) 28 | pd.DataFrame({"x": [1]}).to_csv(file) 29 | 30 | def _create_pqt(file): 31 | file.parent.mkdir(exist_ok=True) 32 | pd.DataFrame({"x": [1]}).to_parquet(file) 33 | 34 | # test single CSV file 35 | file = tmp_path / "folder1" / "my_data.2023.csv.gz" 36 | _create_csv(file) 37 | container = LocalFileContainer(file_path=file) 38 | table = read_data_table_from_path(container) 39 | assert table.name == "my_data.2023" 40 | assert isinstance(table, CsvDataTable) 41 | 42 | # test multiple CSV files, mixed with some parquet 43 | file1 = tmp_path / "folder2" / "my_data.2023.csv.gz" 44 | file2 = tmp_path / "folder2" / "my_data.2024.csv.gz" 45 | file3 = tmp_path / "folder2" / "my_data.2025.parquet" 46 | _create_csv(file1) 47 | _create_csv(file2) 48 | _create_pqt(file3) 49 | container = LocalFileContainer(file_path=tmp_path / "folder2") 50 | table = read_data_table_from_path(container) 51 | assert table.name == "folder2" 52 | assert isinstance(table, ParquetDataTable) 53 | 54 | # test multiple Parquet files 55 | file1 = tmp_path / "folder3" / "my_data.2023.parquet" 56 | file2 = tmp_path / "folder3" / "my_data.2024.parquet" 57 | _create_pqt(file1) 58 | _create_pqt(file2) 59 | container = LocalFileContainer(file_path=tmp_path / "folder3") 60 | table = read_data_table_from_path(container) 61 | assert table.name == "folder3" 62 | assert isinstance(table, ParquetDataTable) 63 | 64 | 65 | def test_normalize_bucket_location(): 66 | uris = [ 67 | "s3:///bucketname/bucketpath/filename.xlsx/", 68 | "s3:///bucketname/bucketpath/filename.xlsx", 69 | "s3:///bucketname/bucketpath/", 70 | "s3:///bucketname/bucketpath", 71 | "s3:///bucketname/", 72 | "s3:///bucketname", 73 | "s3://bucketname/bucketpath/filename.xlsx/", 74 | "s3://bucketname/bucketpath/filename.xlsx", 75 | "s3://bucketname/bucketpath/", 76 | "s3://bucketname/bucketpath", 77 | "s3://bucketname/", 78 | "s3://bucketname", 79 | "/bucketname/bucketpath/filename.xlsx/", 80 | "/bucketname/bucketpath/filename.xlsx", 81 | "/bucketname/bucketpath/", 82 | "/bucketname/bucketpath", 83 | "bucketname/bucketpath/filename.xlsx/", 84 | "bucketname/bucketpath/filename.xlsx", 85 | "bucketname/bucketpath/", 86 | "bucketname/bucketpath", 87 | "bucketname/", 88 | "bucketname", 89 | ] 90 | results = [BucketBasedContainer.normalize_bucket_location(uri) for uri in uris] 91 | assert results == [ 92 | "bucketname/bucketpath/filename.xlsx", 93 | "bucketname/bucketpath/filename.xlsx", 94 | "bucketname/bucketpath", 95 | "bucketname/bucketpath", 96 | "bucketname/", 97 | "bucketname/", 98 | "bucketname/bucketpath/filename.xlsx", 99 | "bucketname/bucketpath/filename.xlsx", 100 | "bucketname/bucketpath", 101 | "bucketname/bucketpath", 102 | "bucketname/", 103 | "bucketname/", 104 | "bucketname/bucketpath/filename.xlsx", 105 | "bucketname/bucketpath/filename.xlsx", 106 | "bucketname/bucketpath", 107 | "bucketname/bucketpath", 108 | "bucketname/bucketpath/filename.xlsx", 109 | "bucketname/bucketpath/filename.xlsx", 110 | "bucketname/bucketpath", 111 | 
"bucketname/bucketpath", 112 | "bucketname/", 113 | "bucketname/", 114 | ] 115 | -------------------------------------------------------------------------------- /docs/tutorials/quality-assurance/quality-assurance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3fd5150a-d33e-435b-ac47-3fdf15fb05c2", 6 | "metadata": {}, 7 | "source": [ 8 | "# Quality Assurance \"Run\n", 9 | "\n", 10 | "In this tutorial we will leverage `mostlyai-qa`, the open-source Python toolkit to assess Synthetic Data quality. See also https://mostly-ai.github.io/mostlyai-qa/ for more info on that toolkit." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "c311dfb2", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# Install SDK in CLIENT mode\n", 21 | "!uv pip install -U mostlyai\n", 22 | "# Or install in LOCAL mode\n", 23 | "!uv pip install -U 'mostlyai[local]' \n", 24 | "# Note: Restart kernel session after installation!" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "baa3bd75-3d32-44e2-87d3-b9895753b27f", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import webbrowser\n", 35 | "\n", 36 | "import pandas as pd\n", 37 | "\n", 38 | "from mostlyai import qa\n", 39 | "\n", 40 | "# initialize logging to stdout\n", 41 | "qa.init_logging()\n", 42 | "\n", 43 | "# print version\n", 44 | "print(f\"loaded mostlyai-qa {qa.__version__}\")" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "08145f83-a985-4f4b-a02f-6f3ed3dbe6f3", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "repo = \"https://github.com/mostly-ai/paper-fidelity-accuracy/raw/refs/heads/main/2024-12/data\"\n", 55 | "trn = pd.read_csv(f\"{repo}/adult_trn.csv.gz\")\n", 56 | "hol = pd.read_csv(f\"{repo}/adult_hol.csv.gz\")\n", 57 | "syn = pd.read_csv(f\"{repo}/adult_mostlyai.csv.gz\")\n", 58 | "print(f\"fetched training data with {trn.shape[0]:,} records and {trn.shape[1]} attributes\")\n", 59 | "print(f\"fetched holdout data with {hol.shape[0]:,} records and {hol.shape[1]} attributes\")\n", 60 | "print(f\"fetched synthetic data with {syn.shape[0]:,} records and {syn.shape[1]} attributes\")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "c9ce2c48-89e8-4f87-bcfe-97be95e25212", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "trn.sample(n=3)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "2d955e83-d15d-4a74-a85a-55c60c248d01", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "syn.sample(n=3)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "2e6e7026-494d-484d-b935-e366a4d695f4", 86 | "metadata": {}, 87 | "source": [ 88 | "## Generate HTML Report with Metrics" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "0c056183-1d3c-40a2-b528-90f058ce1d44", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# takes about 1-2 minutes\n", 99 | "report_path, metrics = qa.report(\n", 100 | " syn_tgt_data=syn,\n", 101 | " trn_tgt_data=trn,\n", 102 | " hol_tgt_data=hol,\n", 103 | " max_sample_size_embeddings=1_000, # set limit to speed up demo; remove limit for best measures\n", 104 | ")\n", 105 | "\n", 106 | "# pretty print metrics\n", 107 | "print(metrics.model_dump_json(indent=4))\n", 108 | "\n", 109 | "# open up HTML report in new 
browser window\n", 110 | "webbrowser.open(f\"file://{report_path.absolute()}\")" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "23776f87-1bf3-4145-9942-38386c77a923", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [] 120 | } 121 | ], 122 | "metadata": { 123 | "kernelspec": { 124 | "display_name": "Python 3 (ipykernel)", 125 | "language": "python", 126 | "name": "python3" 127 | }, 128 | "language_info": { 129 | "codemirror_mode": { 130 | "name": "ipython", 131 | "version": 3 132 | }, 133 | "file_extension": ".py", 134 | "mimetype": "text/x-python", 135 | "name": "python", 136 | "nbconvert_exporter": "python", 137 | "pygments_lexer": "ipython3", 138 | "version": "3.12.8" 139 | } 140 | }, 141 | "nbformat": 4, 142 | "nbformat_minor": 5 143 | } 144 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/table/csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import csv 16 | import logging 17 | 18 | import pandas as pd 19 | import pyarrow.csv as pa_csv 20 | import pyarrow.dataset as ds 21 | import smart_open 22 | 23 | from mostlyai.sdk._data.file.base import ( 24 | FILE_DATA_TABLE_LAZY_INIT_FIELDS, 25 | FileContainer, 26 | FileDataTable, 27 | LocalFileContainer, 28 | ) 29 | 30 | CSV_DATA_TABLE_LAZY_INIT_FIELDS = FILE_DATA_TABLE_LAZY_INIT_FIELDS + [ 31 | "delimiter", 32 | ] 33 | 34 | _LOG = logging.getLogger(__name__) 35 | 36 | 37 | class CsvDataTable(FileDataTable): 38 | DATA_TABLE_TYPE = "csv" 39 | LAZY_INIT_FIELDS = frozenset(CSV_DATA_TABLE_LAZY_INIT_FIELDS) 40 | IS_WRITE_APPEND_ALLOWED = True 41 | 42 | def __init__(self, *args, **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.delimiter: str | None = None 45 | 46 | @classmethod 47 | def container_class(cls) -> type["FileContainer"]: 48 | return LocalFileContainer 49 | 50 | def _get_delimiter(self) -> str: 51 | try: 52 | # only use first file to determine CSV delimiter for all files 53 | file = self.container.list_valid_files()[0] # AnyPath 54 | # substitute scheme prefix to work with smart_open (e.g. 
Azure in particular) 55 | file = f"{self.container.delimiter_prefix}{str(file).split('//')[-1]}" 56 | header = smart_open.open( 57 | file, 58 | "r", 59 | errors="backslashreplace", 60 | transport_params=self.container.transport_params, 61 | ).readline() 62 | sniffer = csv.Sniffer() 63 | try: 64 | delimiter = sniffer.sniff(header, ",;|\t' '").delimiter 65 | except csv.Error: 66 | # happens for example for single column CSV files 67 | delimiter = "," 68 | return delimiter 69 | except Exception as err: 70 | _LOG.warning(f"{err=} of {type(err)=}, defaulting to ',' delimiter") 71 | return "," 72 | 73 | def _get_dataset_format(self, **kwargs): 74 | if "delimiter" in kwargs: 75 | delimiter = kwargs.get("delimiter") 76 | else: 77 | delimiter = self.delimiter 78 | fmt = ds.CsvFileFormat( 79 | parse_options=pa_csv.ParseOptions( 80 | delimiter=delimiter, 81 | # silently drop any invalid rows 82 | invalid_row_handler=lambda x: "skip", 83 | ), 84 | # use 100MB to increase reliability of dtype detection; 85 | # if e.g. dtype is detected as int64 and then later on in the file a float occurs, 86 | # an error is raised; the same issue also occurs if multiple CSV files are provided 87 | # with the first file consisting of integers and the others of floats; so, there still 88 | # might be scenarios where errors occur due to dtype mismatch in CSV chunks; in these 89 | # cases we shall advise to convert the source data to Parquet 90 | read_options=pa_csv.ReadOptions(block_size=100 * 1024 * 1024), 91 | # add additional formats for datetime conversion 92 | convert_options=pa_csv.ConvertOptions( 93 | timestamp_parsers=[ 94 | pa_csv.ISO8601, 95 | "%m/%d/%Y %H:%M:%S", 96 | "%m/%d/%Y %H:%M", 97 | "%m/%d/%Y", 98 | ] 99 | ), 100 | ) 101 | return fmt 102 | 103 | def _lazy_fetch(self, item: str) -> None: 104 | if item == "delimiter": 105 | self.delimiter = self._get_delimiter() 106 | else: 107 | super()._lazy_fetch(item) 108 | return 109 | 110 | def write_data(self, df: pd.DataFrame, if_exists: str = "append", **kwargs): 111 | mode = self.handle_if_exists(if_exists) 112 | df.to_csv( 113 | self.container.path_str, 114 | mode=mode, 115 | # write the header only during an initial "write", not during "append" 116 | header=(mode == "w"), 117 | storage_options=self.container.storage_options, 118 | index=False, 119 | ) 120 | -------------------------------------------------------------------------------- /tests/client/unit/test_naming_conventions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
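# the tests below exercise the camelCase <-> snake_case mapping helpers on nested dicts,
# lists of dicts, and acronym-heavy identifiers (e.g. "HTTPRequest" -> "http_request")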
14 | 15 | import pytest 16 | from datamodel_code_generator.reference import camel_to_snake 17 | 18 | from mostlyai.sdk.client._naming_conventions import ( 19 | _snake_to_camel, 20 | map_camel_to_snake_case, 21 | map_snake_to_camel_case, 22 | ) 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "snake_str, expected", 27 | [ 28 | ("snake_case", "snakeCase"), 29 | ("this_is_a_test", "thisIsATest"), 30 | ("alreadyCamelCase", "alreadyCamelCase"), 31 | ], 32 | ) 33 | def test__snake_to_camel_param(snake_str, expected): 34 | assert _snake_to_camel(snake_str) == expected 35 | 36 | 37 | def test_map_snake_to_camel_case(): 38 | input_data = { 39 | "snake_case_key": "value", 40 | "nested_dict": { 41 | "another_snake_case": 123, 42 | "deep_nested_dict": { 43 | "yet_another_key": "deepValue", 44 | "someCamelCaseKey": "someOtherValue", 45 | "normalNumber": 42, 46 | }, 47 | }, 48 | "list_of_dicts": [{"list_snake_case": "item1"}, {"another_list_key": "item2"}], 49 | "mixed_list": ["simpleString", {"nested_snake": "nestedValue"}, 3.14], 50 | "alreadyCamelCase": "unchangedValue", 51 | "no_underscore_key": "noChange", 52 | "empty_string_key": "", 53 | } 54 | 55 | expected_output = { 56 | "snakeCaseKey": "value", 57 | "nestedDict": { 58 | "anotherSnakeCase": 123, 59 | "deepNestedDict": { 60 | "yetAnotherKey": "deepValue", 61 | "someCamelCaseKey": "someOtherValue", 62 | "normalNumber": 42, 63 | }, 64 | }, 65 | "listOfDicts": [{"listSnakeCase": "item1"}, {"anotherListKey": "item2"}], 66 | "mixedList": ["simpleString", {"nestedSnake": "nestedValue"}, 3.14], 67 | "alreadyCamelCase": "unchangedValue", 68 | "noUnderscoreKey": "noChange", 69 | "emptyStringKey": "", 70 | } 71 | 72 | assert map_snake_to_camel_case(input_data) == expected_output 73 | 74 | 75 | @pytest.mark.parametrize( 76 | "camel_str, expected", 77 | [ 78 | ("camelCase", "camel_case"), 79 | ("thisIsATest", "this_is_a_test"), 80 | ("already_snake_case", "already_snake_case"), 81 | ("HTTPRequest", "http_request"), 82 | ("parseJSONResponse", "parse_json_response"), 83 | ("XMLHttpRequest", "xml_http_request"), 84 | ("totalVirtualGPUTime", "total_virtual_gpu_time"), 85 | ("totalVirtualCPUTime", "total_virtual_cpu_time"), 86 | ("userIDToken", "user_id_token"), 87 | ("AWSConfig", "aws_config"), 88 | ("useHTTPSConnection", "use_https_connection"), 89 | ], 90 | ) 91 | def test_camel_to_snake(camel_str, expected): 92 | assert camel_to_snake(camel_str) == expected 93 | 94 | 95 | def test_map_camel_to_snake_case(): 96 | input_data = { 97 | "camelCaseKey": "value", 98 | "nestedDict": { 99 | "anotherCamelCase": 123, 100 | "deepNestedDict": { 101 | "yetAnotherKey": "deepValue", 102 | "some_snake_case_key": "someOtherValue", 103 | "NormalNumber": 42, 104 | }, 105 | }, 106 | "listOfDicts": [{"listCamelCase": "item1"}, {"anotherListKey": "item2"}], 107 | "mixedList": ["simpleString", {"nestedCamel": "nestedValue"}, 3.14], 108 | "AlreadySnakeCase": "unchangedValue", 109 | "NoUnderscoreKey": "noChange", 110 | "EmptyStringKey": "", 111 | } 112 | 113 | expected_output = { 114 | "camel_case_key": "value", 115 | "nested_dict": { 116 | "another_camel_case": 123, 117 | "deep_nested_dict": { 118 | "yet_another_key": "deepValue", 119 | "some_snake_case_key": "someOtherValue", 120 | "normal_number": 42, 121 | }, 122 | }, 123 | "list_of_dicts": [{"list_camel_case": "item1"}, {"another_list_key": "item2"}], 124 | "mixed_list": ["simpleString", {"nested_camel": "nestedValue"}, 3.14], 125 | "already_snake_case": "unchangedValue", 126 | "no_underscore_key": "noChange", 127 | 
"empty_string_key": "", 128 | } 129 | 130 | assert map_camel_to_snake_case(input_data) == expected_output 131 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/auto_detect.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import re 17 | import time 18 | 19 | import pandas as pd 20 | 21 | from mostlyai.engine._common import safe_convert_datetime 22 | from mostlyai.sdk._data.base import DataTable 23 | from mostlyai.sdk._data.dtype import VirtualDType, VirtualInteger, VirtualVarchar 24 | from mostlyai.sdk._data.util.common import absorb_errors, run_with_timeout_unsafe 25 | from mostlyai.sdk.domain import ModelEncodingType 26 | 27 | AUTODETECT_SAMPLE_SIZE = 10_000 28 | AUTODETECT_TIMEOUT = 15 29 | LAT_LONG_REGEX = re.compile(r"^\s*-?\d+(\.\d+)?,\s*-?\d+(\.\d+)?\s*$") 30 | PK_POSSIBLE_VIRTUAL_DTYPES = (VirtualVarchar, VirtualInteger) 31 | 32 | _LOG = logging.getLogger(__name__) 33 | 34 | 35 | def auto_detect_encoding_types_and_pk(table: DataTable) -> tuple[dict[str, ModelEncodingType], str | None]: 36 | # sub-select only the columns which got the default tabular categorical encoding type 37 | columns_to_auto_detect = [ 38 | c for c, enc in table.encoding_types.items() if enc == ModelEncodingType.tabular_categorical 39 | ] 40 | dtypes = VirtualDType.from_dtypes(table.dtypes) 41 | primary_key_candidates = [ 42 | c for c, t in dtypes.items() if type(t) in PK_POSSIBLE_VIRTUAL_DTYPES and c.lower().endswith("id") 43 | ] # sub-select primary key candidates before sampling the data 44 | columns_to_sample = primary_key_candidates + [c for c in columns_to_auto_detect if c not in primary_key_candidates] 45 | fallback = ( 46 | {c: ModelEncodingType.tabular_categorical for c in columns_to_auto_detect}, 47 | None, 48 | ) 49 | 50 | def auto_detection_logic(): 51 | return_vals = None 52 | t0 = time.time() 53 | with absorb_errors(): 54 | data_sample = next(table.read_chunks(columns=columns_to_sample, fetch_chunk_size=AUTODETECT_SAMPLE_SIZE)) 55 | primary_key = auto_detect_primary_key(data_sample[primary_key_candidates]) 56 | remaining_columns_to_auto_detect = [c for c in columns_to_auto_detect if c != primary_key] 57 | # auto-detect encoding types for the sampled data 58 | return_vals = ( 59 | {c: auto_detect_encoding_type(data_sample[c]) for c in remaining_columns_to_auto_detect}, 60 | primary_key, 61 | ) 62 | _LOG.info(f"auto_detect_encoding_types_and_pk logic for table={table.name} took {time.time() - t0:.2f} seconds") 63 | return return_vals if return_vals is not None else fallback 64 | 65 | # wrap the detection logic with timeout and fallback 66 | return run_with_timeout_unsafe( 67 | auto_detection_logic, 68 | timeout=AUTODETECT_TIMEOUT, 69 | fallback=fallback, 70 | ) 71 | 72 | 73 | def auto_detect_primary_key(sample: pd.DataFrame) -> str | None: 74 | # assuming sample columns 
are (1) ending with "id" and (2) of PK_POSSIBLE_VIRTUAL_DTYPES dtype 75 | for c in sample.columns: 76 | # check (3) all values unique and non-null and (4) of max len 36 (e.g. uuid len) 77 | if sample[c].is_unique and sample[c].notnull().all() and sample[c].astype(str).str.len().max() <= 36: 78 | return c 79 | return None 80 | 81 | 82 | def auto_detect_encoding_type(x: pd.Series) -> ModelEncodingType: 83 | x = x.dropna() 84 | x = x.astype(str) 85 | x = x[x.str.strip() != ""] # filter out empty and whitespace-only strings 86 | 87 | # if all values are null or empty, default to categorical 88 | if len(x) == 0: 89 | return ModelEncodingType.tabular_categorical 90 | 91 | # if all non-null values can be converted to datetime -> datetime encoding 92 | if safe_convert_datetime(x).notna().all(): 93 | return ModelEncodingType.tabular_datetime 94 | 95 | # if all values match lat/long pattern -> lat_long (geo) encoding 96 | if x.str.match(LAT_LONG_REGEX).all(): 97 | return ModelEncodingType.tabular_lat_long 98 | 99 | # if more than 5% of rows contain unique values 100 | if len(x) >= 100 and x.value_counts().eq(1).reindex(x).mean() > 0.05: 101 | if ( 102 | x.str.len().nunique() == 1 or x.str.len().max() <= 36 103 | ): # if all values are of the same length or shorter than 36 chars -> character encoding 104 | return ModelEncodingType.tabular_character 105 | else: # if values are of different lengths -> text encoding 106 | return ModelEncodingType.language_text 107 | 108 | return ModelEncodingType.tabular_categorical 109 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/db/mssql.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
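# usage sketch (assumptions: the "ODBC Driver 18 for SQL Server" is installed, the server
# is reachable, and the keyword arguments below mirror the generic container fields;
# dbname="master" and port="1433" match the defaults set in this module):
#   container = MssqlContainer(username="sa", password="secret", host="localhost", port="1433", dbname="master")
#   assert container.does_database_exist()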
14 | 15 | import logging 16 | from urllib.parse import quote, urlencode 17 | 18 | import pyodbc 19 | import sqlalchemy as sa 20 | from sqlalchemy.dialects.mssql.base import MSDialect 21 | 22 | from mostlyai.sdk._data.db.base import DBDType, SqlAlchemyContainer, SqlAlchemyTable 23 | 24 | # disable pyodbc pooling, SQLAlchemy utilises its own mechanism 25 | # read more: https://docs.sqlalchemy.org/en/14/dialects/mssql.html#pyodbc-pooling-connection-close-behavior 26 | pyodbc.pooling = False 27 | 28 | _LOG = logging.getLogger(__name__) 29 | 30 | 31 | class MssqlDType(DBDType): 32 | FROM_VIRTUAL_TIMESTAMP = sa.DATETIME 33 | 34 | @classmethod 35 | def sa_dialect_class(cls): 36 | return MSDialect 37 | 38 | 39 | class MssqlContainer(SqlAlchemyContainer): 40 | SCHEMES = ["mssql"] 41 | SA_CONNECT_ARGS_ACCESS_ENGINE = {"timeout": 3} 42 | SA_CONNECTION_KWARGS = { 43 | "ssl": "True", 44 | } 45 | SA_SSL_ATTR_KEY_MAP = { 46 | "root_certificate_path": "Certificate", 47 | } 48 | SQL_FETCH_FOREIGN_KEYS = """ 49 | SELECT 50 | FK.TABLE_NAME as TABLE_NAME, 51 | CU.COLUMN_NAME as COLUMN_NAME, 52 | PK.TABLE_NAME as REFERENCED_TABLE_NAME, 53 | PT.COLUMN_NAME as REFERENCED_COLUMN_NAME 54 | FROM 55 | INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS C 56 | INNER JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS FK ON C.CONSTRAINT_NAME = FK.CONSTRAINT_NAME 57 | INNER JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS PK ON C.UNIQUE_CONSTRAINT_NAME = PK.CONSTRAINT_NAME 58 | INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE CU ON C.CONSTRAINT_NAME = CU.CONSTRAINT_NAME 59 | INNER JOIN ( 60 | SELECT 61 | i1.TABLE_NAME, 62 | i2.COLUMN_NAME 63 | FROM 64 | INFORMATION_SCHEMA.TABLE_CONSTRAINTS i1 65 | INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE i2 ON i1.CONSTRAINT_NAME = i2.CONSTRAINT_NAME 66 | WHERE 67 | i1.CONSTRAINT_TYPE = 'PRIMARY KEY' 68 | AND i1.TABLE_SCHEMA = :schema_name 69 | ) PT ON PT.TABLE_NAME = PK.TABLE_NAME 70 | WHERE 71 | FK.TABLE_SCHEMA = :schema_name 72 | AND PK.TABLE_SCHEMA = :schema_name 73 | AND CU.TABLE_SCHEMA = :schema_name; 74 | """ 75 | _MSSQL_PYODBC_DRIVER = "ODBC Driver 18 for SQL Server" 76 | INIT_DEFAULT_VALUES = {"dbname": "master", "port": "1433"} 77 | 78 | @property 79 | def sa_uri(self): 80 | # User and password are needed to avoid double-encoding of @ character 81 | username = quote(self.username) 82 | password = quote(self.password) 83 | props = urlencode( 84 | { 85 | "driver": self._MSSQL_PYODBC_DRIVER, 86 | "TrustServerCertificate": "yes", 87 | **self.sa_engine_connection_kwargs, 88 | } 89 | ) 90 | uri = f"mssql+pyodbc://{username}:{password}@{self.host}:{self.port}/{self.dbname}?{props}" 91 | return uri 92 | 93 | @property 94 | def sa_create_engine_kwargs(self) -> dict: 95 | return { 96 | # improves speed of write dramatically 97 | # read more: https://docs.sqlalchemy.org/en/14/dialects/mssql.html#fast-executemany-mode 98 | "fast_executemany": True 99 | } 100 | 101 | @classmethod 102 | def table_class(cls): 103 | return MssqlTable 104 | 105 | def does_database_exist(self) -> bool: 106 | try: 107 | with self.init_sa_connection("db_exist_check") as connection: 108 | result = connection.execute(sa.text(f"SELECT name FROM sys.databases WHERE name='{self.dbname}'")) 109 | if result.fetchone(): 110 | db_exist = True 111 | else: 112 | db_exist = False 113 | return db_exist 114 | except Exception: 115 | return False 116 | 117 | 118 | class MssqlTable(SqlAlchemyTable): 119 | DATA_TABLE_TYPE = "mssql" 120 | SA_RANDOM = sa.func.newid() 121 | # MSSQL has upper bound of 2100 on number of bound parameters in a query, 122 | # 
each batch contributes len(batch) bound parameters to the counter, 123 | # so max batch size must be significantly smaller than 2100 to 124 | # leave some space for other bound parameters in a query, e.g. column names 125 | SA_MAX_VALS_PER_BATCH = 1_900 126 | 127 | @classmethod 128 | def dtype_class(cls): 129 | return MssqlDType 130 | 131 | @classmethod 132 | def container_class(cls): 133 | return MssqlContainer 134 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/step_generate_model_report_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import logging 15 | from collections.abc import Callable 16 | from pathlib import Path 17 | 18 | import pandas as pd 19 | 20 | from mostlyai.sdk.domain import ModelType 21 | 22 | _LOG = logging.getLogger(__name__) 23 | 24 | 25 | def execute_step_generate_model_report_data( 26 | *, 27 | workspace_dir: Path, 28 | model_type: ModelType, 29 | update_progress: Callable, 30 | ): 31 | # import ENGINE here to avoid pre-mature loading of large ENGINE dependencies 32 | from mostlyai import engine 33 | from mostlyai.engine._workspace import Workspace 34 | 35 | # determine max sample size for generated report samples 36 | workspace = Workspace(workspace_dir) 37 | tgt_stats = workspace.tgt_stats.read() 38 | max_sample_size = qa_sample_size_heuristic(tgt_stats=tgt_stats, model_type=model_type) 39 | 40 | # pull context data for report generation (if applicable) 41 | has_context = workspace.ctx_stats.path.exists() 42 | if has_context: 43 | ctx_stats = workspace.ctx_stats.read() 44 | ctx_primary_key = ctx_stats.get("keys", {}).get("primary_key") 45 | ctx_input_path = workspace_dir / "report-ctx-data" 46 | _pull_context_for_report_generation( 47 | ctx_data_path=workspace.ctx_data_path, 48 | output_path=ctx_input_path, 49 | max_sample_size=max_sample_size, 50 | ctx_primary_key=ctx_primary_key, 51 | ) 52 | ctx_data = pd.read_parquet(ctx_input_path) 53 | else: 54 | ctx_data = None 55 | 56 | # call GENERATE 57 | engine.generate( 58 | ctx_data=ctx_data, 59 | sample_size=max_sample_size, 60 | workspace_dir=workspace_dir, 61 | update_progress=update_progress, 62 | ) 63 | 64 | 65 | def qa_sample_size_heuristic(tgt_stats: dict, model_type: ModelType) -> int: 66 | # import ENGINE here to avoid pre-mature loading of large ENGINE dependencies 67 | from mostlyai.engine._common import get_cardinalities, get_sequence_length_stats 68 | 69 | if model_type == ModelType.language: 70 | return 1_000 71 | trn_sample_size = tgt_stats["no_of_training_records"] + tgt_stats["no_of_validation_records"] 72 | no_tgt_sub_columns = len(get_cardinalities(tgt_stats)) 73 | tgt_q50_seqlen = get_sequence_length_stats(tgt_stats)["median"] 74 | data_points = no_tgt_sub_columns * tgt_q50_seqlen 75 | if data_points > 1_000: 76 | gen_sample_size = 10_000 77 | else: 78 | gen_sample_size 
= 100_000 79 | return min(gen_sample_size, trn_sample_size) 80 | 81 | 82 | def _pull_context_for_report_generation( 83 | *, ctx_data_path: Path, output_path: Path, max_sample_size: int, ctx_primary_key: str 84 | ): 85 | ctx_trn_files = sorted(ctx_data_path.glob("part.*-trn.parquet")) 86 | ctx_val_files = sorted(ctx_data_path.glob("part.*-val.parquet")) 87 | # fetch keys alone first 88 | ctx_trn_keys = pd.concat([pd.read_parquet(f, columns=[ctx_primary_key]) for f in ctx_trn_files], ignore_index=True) 89 | ctx_val_keys = pd.concat([pd.read_parquet(f, columns=[ctx_primary_key]) for f in ctx_val_files], ignore_index=True) 90 | ctx_trn_keys = ctx_trn_keys.sample(frac=1).reset_index(drop=True) 91 | ctx_val_keys = ctx_val_keys.sample(frac=1).reset_index(drop=True) 92 | # attempt to balance the training and validation sets 93 | trn_keys = int(max_sample_size * 0.50) 94 | val_keys = max_sample_size - trn_keys 95 | keys = pd.concat( 96 | [ 97 | ctx_trn_keys[:trn_keys].assign(set="trn"), 98 | ctx_val_keys[:val_keys].assign(set="val"), 99 | ctx_trn_keys[trn_keys:].assign(set="trn"), 100 | ctx_val_keys[val_keys:].assign(set="val"), 101 | ], 102 | ignore_index=True, 103 | ) 104 | keys = keys.head(max_sample_size) 105 | ctx_trn_keys = keys[keys["set"] == "trn"][[ctx_primary_key]] 106 | ctx_val_keys = keys[keys["set"] == "val"][[ctx_primary_key]] 107 | # fetch rest of the context data 108 | df_trn_ctx = pd.concat([pd.read_parquet(f).merge(ctx_trn_keys, on=ctx_primary_key) for f in ctx_trn_files]) 109 | df_val_ctx = pd.concat([pd.read_parquet(f).merge(ctx_val_keys, on=ctx_primary_key) for f in ctx_val_files]) 110 | df_trn_ctx = df_trn_ctx.reset_index(drop=True) 111 | df_val_ctx = df_val_ctx.reset_index(drop=True) 112 | output_path.mkdir(parents=True, exist_ok=True) 113 | df_trn_ctx.to_parquet(output_path / f"part.{0:06}.{0:06}-trn.parquet") 114 | df_val_ctx.to_parquet(output_path / f"part.{0:06}.{0:06}-val.parquet") 115 | _LOG.info(f"pulled context data for model report ({len(df_trn_ctx)=:,} {len(df_val_ctx)=:,})") 116 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/execution/step_analyze_training_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
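--------------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not from the repository): the report
# sample-size heuristic from step_generate_model_report_data.py above, worked
# through with invented numbers.
trn_sample_size = 60_000 + 6_000  # training + validation records from tgt_stats
data_points = 8 * 150  # 8 target sub-columns x median sequence length of 150
gen_sample_size = 10_000 if data_points > 1_000 else 100_000  # 1,200 > 1,000 -> 10,000
assert min(gen_sample_size, trn_sample_size) == 10_000  # the model report samples 10,000 records
--------------------------------------------------------------------------------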
14 | 15 | 16 | from collections import ChainMap, defaultdict 17 | from collections.abc import Callable 18 | from pathlib import Path 19 | 20 | from mostlyai.engine.domain import DifferentialPrivacyConfig 21 | from mostlyai.sdk.domain import Generator, ModelEncodingType, ModelType, SourceColumnValueRange 22 | 23 | 24 | def execute_step_analyze_training_data( 25 | *, 26 | generator: Generator, 27 | model_type: ModelType, 28 | target_table_name: str, 29 | workspace_dir: Path, 30 | update_progress: Callable, 31 | ) -> tuple[dict[str, ModelEncodingType], dict[str, SourceColumnValueRange]]: 32 | # import ENGINE here to avoid pre-mature loading of large ENGINE dependencies 33 | from mostlyai import engine 34 | from mostlyai.engine._workspace import Workspace 35 | 36 | # fetch model_config 37 | tgt_table = next(t for t in generator.tables if t.name == target_table_name) 38 | if model_type == ModelType.language: 39 | model_config = tgt_table.language_model_configuration 40 | else: 41 | model_config = tgt_table.tabular_model_configuration 42 | 43 | # convert from SDK domain to ENGINE domain 44 | if model_config.differential_privacy: 45 | differential_privacy = DifferentialPrivacyConfig(**model_config.differential_privacy.model_dump()) 46 | else: 47 | differential_privacy = None 48 | 49 | # call ANALYZE 50 | engine.analyze( 51 | workspace_dir=workspace_dir, 52 | value_protection=model_config.value_protection, 53 | differential_privacy=differential_privacy, 54 | update_progress=update_progress, 55 | ) 56 | 57 | # read stats 58 | workspace = Workspace(workspace_dir) 59 | tgt_stats = workspace.tgt_stats.read() 60 | encoding_types = _get_encoding_types(tgt_stats) 61 | value_ranges = _get_value_ranges(tgt_stats) 62 | return encoding_types, value_ranges 63 | 64 | 65 | def _get_encoding_types(stats: dict) -> dict[str, ModelEncodingType]: 66 | encoding_types = {} 67 | for col, col_stats in stats.get("columns", {}).items(): 68 | encoding_type = col_stats.get("encoding_type") 69 | if encoding_type is not None: 70 | encoding_types[col] = ModelEncodingType(encoding_type) 71 | return encoding_types 72 | 73 | 74 | def _get_value_ranges(stats: dict) -> dict[str, SourceColumnValueRange]: 75 | # import ENGINE here to avoid pre-mature loading of large ENGINE dependencies 76 | from mostlyai.engine._encoding_types.tabular.categorical import CATEGORICAL_NULL_TOKEN, CATEGORICAL_UNKNOWN_TOKEN 77 | 78 | def parse_values(col_stats: dict) -> dict: 79 | size_limit = 1_000 80 | values = [ 81 | code 82 | for code in col_stats.get("codes", {}).keys() 83 | if code not in [CATEGORICAL_UNKNOWN_TOKEN, CATEGORICAL_NULL_TOKEN] 84 | ][:size_limit] 85 | return {"values": values} 86 | 87 | def parse_min_max(col_stats: dict) -> dict: 88 | values = col_stats.get("bins", []) + col_stats.get("min5", []) + col_stats.get("max5", []) 89 | min_ = str(min(values)) if values else None 90 | max_ = str(max(values)) if values else None 91 | return {"min": min_, "max": max_} 92 | 93 | def parse_has_null(col_stats: dict) -> dict: 94 | has_null = any( 95 | [ 96 | CATEGORICAL_NULL_TOKEN in col_stats.get("codes", {}).keys(), 97 | col_stats.get("has_nan", False), 98 | col_stats.get("has_na", False), 99 | ] 100 | ) 101 | return {"has_null": has_null} 102 | 103 | def combine(*parsers): 104 | def pipe(col_stats: dict) -> SourceColumnValueRange: 105 | return SourceColumnValueRange(**ChainMap(*[parser(col_stats) for parser in parsers])) 106 | 107 | return pipe 108 | 109 | parsers = defaultdict( 110 | lambda: combine(parse_has_null), 111 | { 112 | 
ModelEncodingType.tabular_categorical: combine(parse_values, parse_has_null), 113 | ModelEncodingType.tabular_numeric_discrete: combine(parse_values, parse_has_null), 114 | ModelEncodingType.tabular_numeric_binned: combine(parse_min_max, parse_has_null), 115 | ModelEncodingType.tabular_numeric_digit: combine(parse_min_max, parse_has_null), 116 | ModelEncodingType.tabular_datetime: combine(parse_min_max, parse_has_null), 117 | }, 118 | ) 119 | 120 | value_ranges = {} 121 | for col, col_stats in stats.get("columns", {}).items(): 122 | encoding_type = col_stats.get("encoding_type") 123 | value_ranges[col] = parsers[encoding_type](col_stats) 124 | 125 | return value_ranges 126 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/conversions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | from pydoc import locate 17 | from typing import Any 18 | 19 | from mostlyai.sdk._data.db.base import SqlAlchemyContainer 20 | from mostlyai.sdk._data.file.base import FileContainer 21 | from mostlyai.sdk._data.file.container.bucket_based import BucketBasedContainer 22 | from mostlyai.sdk.domain import Connector, ConnectorType 23 | 24 | _LOG = logging.getLogger(__name__) 25 | 26 | CONNECTOR_TYPE_CONTAINER_CLASS_MAP = { 27 | ConnectorType.mysql: "mostlyai.sdk._data.db.mysql.MysqlContainer", 28 | ConnectorType.postgres: "mostlyai.sdk._data.db.postgresql.PostgresqlContainer", 29 | ConnectorType.redshift: "mostlyai.sdk._data.db.redshift.RedshiftContainer", 30 | ConnectorType.mssql: "mostlyai.sdk._data.db.mssql.MssqlContainer", 31 | ConnectorType.oracle: "mostlyai.sdk._data.db.oracle.OracleContainer", 32 | ConnectorType.mariadb: "mostlyai.sdk._data.db.mysql.MariadbContainer", 33 | ConnectorType.snowflake: "mostlyai.sdk._data.db.snowflake.SnowflakeContainer", 34 | ConnectorType.bigquery: "mostlyai.sdk._data.db.bigquery.BigQueryContainer", 35 | ConnectorType.databricks: "mostlyai.sdk._data.db.databricks.DatabricksContainer", 36 | ConnectorType.hive: "mostlyai.sdk._data.db.hive.HiveContainer", 37 | ConnectorType.sqlite: "mostlyai.sdk._data.db.sqlite.SqliteContainer", 38 | ConnectorType.azure_storage: "mostlyai.sdk._data.file.container.azure.AzureBlobFileContainer", 39 | ConnectorType.google_cloud_storage: "mostlyai.sdk._data.file.container.gcs.GcsContainer", 40 | ConnectorType.s3_storage: "mostlyai.sdk._data.file.container.aws.AwsS3FileContainer", 41 | ConnectorType.file_upload: "mostlyai.sdk._data.file.base.LocalFileContainer", 42 | } 43 | 44 | CONNECTOR_TYPE_CONTAINER_PARAMS_CLASS_MAP = { 45 | ConnectorType.mysql: "mostlyai.sdk._data.metadata_objects.SqlAlchemyContainerParameters", 46 | ConnectorType.postgres: "mostlyai.sdk._data.metadata_objects.SqlAlchemyContainerParameters", 47 | ConnectorType.redshift: "mostlyai.sdk._data.metadata_objects.SqlAlchemyContainerParameters", 48 | 
ConnectorType.mssql: "mostlyai.sdk._data.metadata_objects.SqlAlchemyContainerParameters", 49 | ConnectorType.oracle: "mostlyai.sdk._data.metadata_objects.OracleContainerParameters", 50 | ConnectorType.mariadb: "mostlyai.sdk._data.metadata_objects.SqlAlchemyContainerParameters", 51 | ConnectorType.snowflake: "mostlyai.sdk._data.metadata_objects.SnowflakeContainerParameters", 52 | ConnectorType.bigquery: "mostlyai.sdk._data.metadata_objects.BigQueryContainerParameters", 53 | ConnectorType.databricks: "mostlyai.sdk._data.metadata_objects.DatabricksContainerParameters", 54 | ConnectorType.hive: "mostlyai.sdk._data.metadata_objects.SqlAlchemyContainerParameters", 55 | ConnectorType.sqlite: "mostlyai.sdk._data.metadata_objects.SqlAlchemyContainerParameters", 56 | ConnectorType.azure_storage: "mostlyai.sdk._data.metadata_objects.AzureBlobFileContainerParameters", 57 | ConnectorType.google_cloud_storage: "mostlyai.sdk._data.metadata_objects.GcsContainerParameters", 58 | ConnectorType.s3_storage: "mostlyai.sdk._data.metadata_objects.AwsS3FileContainerParameters", 59 | ConnectorType.file_upload: "mostlyai.sdk._data.metadata_objects.LocalFileContainerParameters", 60 | } 61 | 62 | 63 | def convert_connector_params_to_container_params(connector: Connector) -> dict[str, Any]: 64 | """ 65 | Merge `config`, `secrets` and `ssl` of the Connector into one dictionary 66 | and then validate it individually based on the connector type. 67 | """ 68 | connector.config = connector.config or {} 69 | connector.secrets = connector.secrets or {} 70 | connector.ssl = connector.ssl or {} 71 | 72 | container_params_cls_path = CONNECTOR_TYPE_CONTAINER_PARAMS_CLASS_MAP.get(connector.type) 73 | container_params_pydantic_cls = locate(container_params_cls_path) 74 | 75 | container_params = connector.config | connector.secrets | connector.ssl 76 | container_params = container_params_pydantic_cls.model_validate(container_params).model_dump() 77 | return container_params 78 | 79 | 80 | def create_container_from_connector( 81 | connector: Connector, 82 | ) -> SqlAlchemyContainer | BucketBasedContainer | FileContainer: 83 | container_cls_path = CONNECTOR_TYPE_CONTAINER_CLASS_MAP.get(connector.type) 84 | if not container_cls_path: 85 | raise ValueError("Unsupported connector type!") 86 | container_cls = locate(container_cls_path) 87 | container_params = convert_connector_params_to_container_params(connector) 88 | container = container_cls(**container_params) 89 | # Check if the container is accessible before __repr__ (workaround broken logic of sa and __repr__) 90 | is_accessible = container.is_accessible() 91 | _LOG.info(f"Container accessible: {is_accessible}") 92 | _LOG.info(f"Container created: {container}") 93 | return container 94 | -------------------------------------------------------------------------------- /tests/client/unit/test_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | from unittest import mock 17 | 18 | import pytest 19 | import respx 20 | from httpx import NetworkError, Response 21 | 22 | from mostlyai.sdk.client.base import DEFAULT_BASE_URL, CustomBaseModel, Paginator, _MostlyBaseClient 23 | from mostlyai.sdk.client.exceptions import APIError, APIStatusError 24 | 25 | 26 | @pytest.fixture 27 | def mostly_base_client(): 28 | """Fixture to provide a _MostlyBaseClient instance.""" 29 | return _MostlyBaseClient(api_key="test_api_key") 30 | 31 | 32 | @pytest.fixture 33 | def more_specific_client(): 34 | class MoreSpecificClient(_MostlyBaseClient): 35 | SECTION = ["more", "specific"] 36 | 37 | return MoreSpecificClient(api_key="test_api_key") 38 | 39 | 40 | class TestMostlyBaseClient: 41 | def test_initialization(self): 42 | # Test with all parameters provided 43 | client = _MostlyBaseClient(base_url="https://custom.url", api_key="12345") 44 | assert client.base_url == "https://custom.url" 45 | assert client.api_key == "12345" 46 | 47 | # Test with all required parameters provided 48 | client = _MostlyBaseClient(api_key="12345") 49 | assert client.base_url == DEFAULT_BASE_URL 50 | assert client.api_key == "12345" 51 | 52 | @respx.mock 53 | def test_request_success(self, mostly_base_client): 54 | mock_url = respx.get("https://app.mostly.ai/api/v2/test").mock( 55 | return_value=Response(200, json={"success": True}) 56 | ) 57 | response = mostly_base_client.request(path="test", verb="GET") 58 | 59 | assert mock_url.called 60 | assert response == {"success": True} 61 | 62 | @respx.mock 63 | def test_request_http_error(self, mostly_base_client): 64 | respx.get("https://app.mostly.ai/api/v2/test").mock(return_value=Response(404, json={"message": "Not found"})) 65 | 66 | with pytest.raises(APIStatusError) as excinfo: 67 | mostly_base_client.request(path="test", verb="GET") 68 | 69 | assert "HTTP 404: Not found" in str(excinfo.value) 70 | 71 | @respx.mock 72 | def test_client_request_network_error(self, mostly_base_client): 73 | respx.get("https://app.mostly.ai/api/v2/test").mock(side_effect=NetworkError("Network error")) 74 | 75 | with pytest.raises(APIError) as excinfo: 76 | mostly_base_client.request("test", "GET") 77 | 78 | assert "An error occurred while requesting" in str(excinfo.value) 79 | 80 | @respx.mock 81 | def test_client_post_request(self, mostly_base_client): 82 | test_data = CustomBaseModel(name="Test") 83 | respx.post("https://app.mostly.ai/api/v2/create").mock(return_value=Response(201, json={"success": True})) 84 | 85 | response = mostly_base_client.request("create", "POST", json=test_data) 86 | 87 | assert response == {"success": True} 88 | 89 | @respx.mock 90 | def test_more_specific_client_request(self, more_specific_client): 91 | mock_url = respx.get("https://app.mostly.ai/api/v2/more/specific/test").mock( 92 | return_value=Response(200, json={"success": True}) 93 | ) 94 | response = more_specific_client.request(path="test", verb="GET") 95 | 96 | assert mock_url.called 97 | assert response == {"success": True} 98 | 99 | 100 | class TestPaginator: 101 | @respx.mock 102 | def test_iteration(self, mostly_base_client): 103 | # Define mock responses 104 | offset_page_map = { 105 | 0: {"results": [{"id": 1}, {"id": 2}], "totalCount": 4}, 106 | 2: {"results": [{"id": 3}, {"id": 4}], "totalCount": 4}, 107 | 4: {"results": [], "totalCount": 4}, 108 | } 109 | 110 | # Using a callback to differentiate the response 
based on offset and page 111 | def request_callback(request): 112 | url_pattern = re.compile(r"offset=(\d+)") 113 | match = re.search(url_pattern, str(request.url.query)) 114 | offset = int(match.group(1)) 115 | 116 | return Response(200, json=offset_page_map.get(offset)) 117 | 118 | # Mock any GET request and use the callback to handle it 119 | respx.get(url=mock.ANY).mock(side_effect=request_callback) 120 | 121 | paginator = Paginator(mostly_base_client, dict, limit=2) 122 | 123 | items = list(paginator) 124 | assert [item["id"] for item in items] == [1, 2] 125 | 126 | @respx.mock 127 | def test_paginator_no_results(self, mostly_base_client): 128 | respx.get("https://app.mostly.ai/api/v2?offset=0&limit=50").mock( 129 | return_value=Response(200, json={"results": [], "totalCount": 0}) 130 | ) 131 | 132 | paginator = Paginator(mostly_base_client, dict) 133 | 134 | items = list(paginator) 135 | assert len(items) == 0 136 | -------------------------------------------------------------------------------- /tests/_data/unit/util/test_kerberos.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from datetime import datetime 16 | 17 | import pytest 18 | 19 | from mostlyai.sdk._data.util.kerberos import is_kerberos_ticket_alive 20 | 21 | 22 | @pytest.fixture 23 | def mock_datetime_now(monkeypatch): 24 | class MockDatetime: 25 | @classmethod 26 | def now(cls): 27 | return datetime(2023, 6, 11, 14, 0, 0) 28 | 29 | monkeypatch.setattr("mostlyai.sdk._data.util.kerberos.datetime", MockDatetime) 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "klist_result, expected", 34 | [ 35 | ( 36 | """ 37 | Ticket cache: FILE:/tmp/ba42aff58f0e6d3d6b55a531459a56cc666c4c58ddc53d3b9a34083da6b2739a 38 | Default principal: hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 39 | 40 | Valid starting Expires Service principal 41 | 06/11/2024 14:47:57 06/11/2024 14:48:51 krbtgt/INTERNAL@INTERNAL 42 | 06/11/2023 11:00:00 06/11/2023 15:00:00 hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 43 | """, 44 | True, 45 | ), 46 | ( 47 | """ 48 | Ticket cache: FILE:/tmp/ba42aff58f0e6d3d6b55a531459a56cc666c4c58ddc53d3b9a34083da6b2739a 49 | Default principal: hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 50 | 51 | Valid starting Expires Service principal 52 | 06/11/2024 14:47:57 06/11/2024 14:48:51 krbtgt/INTERNAL@INTERNAL 53 | 06/11/2023 11:00:00 06/11/2023 14:01:10 hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 54 | """, 55 | True, 56 | ), 57 | ( 58 | """ 59 | Ticket cache: FILE:/tmp/ba42aff58f0e6d3d6b55a531459a56cc666c4c58ddc53d3b9a34083da6b2739a 60 | Default principal: hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 61 | 62 | Valid starting Expires Service principal 63 | 06/11/2024 14:47:57 06/11/2024 14:48:51 krbtgt/INTERNAL@INTERNAL 64 | 06/11/2023 11:00:00 06/11/2023 14:00:50 hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 65 | """, 66 | False, 67 
| ), 68 | ( 69 | """ 70 | Ticket cache: FILE:/tmp/ba42aff58f0e6d3d6b55a531459a56cc666c4c58ddc53d3b9a34083da6b2739a 71 | Default principal: hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 72 | 73 | Valid starting Expires Service principal 74 | 06/11/2024 14:47:57 06/11/2024 14:48:51 krbtgt/INTERNAL@INTERNAL 75 | 06/11/2023 11:00:00 06/10/2023 14:01:10 hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 76 | """, 77 | False, 78 | ), 79 | ( 80 | """ 81 | Ticket cache: FILE:/tmp/ba42aff58f0e6d3d6b55a531459a56cc666c4c58ddc53d3b9a34083da6b2739a 82 | Default principal: hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 83 | 84 | Valid starting Expires Service principal 85 | 06/11/2024 14:47:57 06/11/2024 14:48:51 krbtgt/INTERNAL@INTERNAL 86 | 06/11/2023 11:00:00 06/12/2023 14:00:50 hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 87 | """, 88 | True, 89 | ), 90 | ( 91 | """ 92 | Ticket cache: FILE:/tmp/krb5cc_1000 93 | Default principal: user@EXAMPLE.COM 94 | """, 95 | False, 96 | ), 97 | ( 98 | """ 99 | Credentials cache: FILE:/var/folders/0w/qcxns49j7fn6k52f0q1n41740000gn/T/456688f9f3ee845a58b75ced8c634f9e58d86a86ed42a947855363a3ca7fbc58 100 | Principal: hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 101 | 102 | Issued Expires Principal 103 | Jun 14 17:52:18 2024 Jun 15 03:52:18 2024 krbtgt/INTERNAL@INTERNAL 104 | Jun 14 17:52:18 2024 Jun 15 03:52:18 2024 hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 105 | """, 106 | True, 107 | ), 108 | ( 109 | """ 110 | Credentials cache: FILE:/var/folders/0w/qcxns49j7fn6k52f0q1n41740000gn/T/625797d527c72b4062bf2a8c721428e9ae2a5d17e94f923f78a2a78924565f43 111 | Principal: randomuser/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 112 | 113 | Issued Expires Principal 114 | Jun 14 17:52:18 2024 Jun 15 03:52:18 2024 krbtgt/INTERNAL@INTERNAL 115 | Jun 14 17:52:18 2024 Jun 15 03:52:18 2024 hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL 116 | """, 117 | True, 118 | ), 119 | ], 120 | ) 121 | def test_is_kerberos_ticket_alive(mock_datetime_now, klist_result, expected): 122 | service_principal = "hive/hive-kerberized-ssl.test.mostlylab.com@INTERNAL" 123 | assert is_kerberos_ticket_alive(klist_result, service_principal) == expected 124 | -------------------------------------------------------------------------------- /mostlyai/sdk/_local/server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import atexit 15 | import os 16 | import tempfile 17 | import time 18 | from pathlib import Path 19 | from threading import Thread 20 | 21 | import rich 22 | import uvicorn 23 | from fastapi import FastAPI 24 | from fastapi.responses import JSONResponse 25 | from pydantic import ValidationError 26 | 27 | from mostlyai.sdk._local.routes import Routes 28 | 29 | 30 | class LocalServer: 31 | """ 32 | Instantiate a local server for the Synthetic Data SDK. 
33 | 34 | Args: 35 | home_dir: The directory where the SDK stores its data. Defaults to `~/mostlyai`. 36 | port: The port to bind the server to. If `None`, a Unix Domain Socket (UDS) will be used. Defaults to `None` on Unix and `8080` on Windows. 37 | """ 38 | 39 | def __init__( 40 | self, 41 | home_dir: str | Path | None = None, 42 | port: int | None = None, 43 | ): 44 | self.home_dir = Path(home_dir or "~/mostlyai").expanduser() 45 | self.home_dir.mkdir(parents=True, exist_ok=True) 46 | # check read/write access to `home_dir` 47 | if not os.access(self.home_dir, os.R_OK) or not os.access(self.home_dir, os.W_OK): 48 | raise PermissionError(f"Cannot read/write to {self.home_dir}") 49 | if port is None and os.name == "nt": 50 | port = 8080 # use TCP by default on Windows 51 | self.port = port 52 | # binding to all interfaces (0.0.0.0) is required for docker use case 53 | self.host = "0.0.0.0" if port is not None else None 54 | self.uds = ( 55 | tempfile.NamedTemporaryFile(prefix=".mostlyai-", suffix=".sock", delete=False).name 56 | if port is None 57 | else None 58 | ) 59 | self.base_url = "http://127.0.0.1" + (f":{port}" if port else "") 60 | self._app = FastAPI( 61 | root_path="/api/v2", 62 | title="Synthetic Data SDK ✨", 63 | description="Welcome! This is your Local Server instance of the Synthetic Data SDK. " 64 | "Connect via the MOSTLY AI client to train models and generate synthetic data locally. " 65 | "Share the knowledge of your synthetic data generators with your team or the world by " 66 | "deploying these then to a MOSTLY AI platform. Enjoy!", 67 | version="1.0.0", 68 | ) 69 | routes = Routes(self.home_dir) 70 | self._app.include_router(routes.router) 71 | self.register_exception_handlers() 72 | self._server = None 73 | self._thread = None 74 | self.start() # Automatically start the server during initialization 75 | 76 | def _clear_socket_file(self): 77 | if self.uds and os.path.exists(self.uds): 78 | os.remove(self.uds) 79 | 80 | def _create_server(self): 81 | self._clear_socket_file() 82 | config = uvicorn.Config( 83 | self._app, host=self.host, port=self.port, uds=self.uds, log_level="error", reload=False 84 | ) 85 | self._server = uvicorn.Server(config) 86 | 87 | def _run_server(self): 88 | if self._server: 89 | self._server.run() 90 | 91 | def start(self): 92 | if not self._server: 93 | self._create_server() 94 | self._thread = Thread(target=self._run_server, daemon=True) 95 | self._thread.start() 96 | # make sure the socket file is cleaned up on exit 97 | atexit.register(self._clear_socket_file) 98 | # give the server a moment to start 99 | time.sleep(0.5) 100 | 101 | def stop(self): 102 | if self._server and self._server.started: 103 | rich.print("Stopping Synthetic Data SDK in local mode") 104 | self._server.should_exit = True # Signal the server to shut down 105 | self._thread.join() # Wait for the server thread to finish 106 | self._clear_socket_file() 107 | 108 | def __enter__(self): 109 | # Ensure the server is running 110 | self.start() 111 | return self 112 | 113 | def __exit__(self, exc_type, exc_value, traceback): 114 | # Stop the server when exiting the context 115 | self.stop() 116 | 117 | def __del__(self): 118 | # Backup cleanup in case `stop` was not called explicitly or via context 119 | if self._server and self._server.started: 120 | print("Automatically shutting down server") 121 | self.stop() 122 | 123 | def register_exception_handlers(self): 124 | @self._app.exception_handler(Exception) 125 | async def global_exception_handler(request, exc): 126 | if 
isinstance(exc, ValidationError): 127 | return JSONResponse(status_code=422, content={"detail": str(exc)}) 128 | # for fastapi.HTTPException: it will be raised as is 129 | # for other unhandled exceptions: client will receive a 500 Internal Server Error 130 | raise exc 131 | -------------------------------------------------------------------------------- /mostlyai/sdk/client/integrations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import rich 18 | 19 | from mostlyai.sdk.client.base import ( 20 | DELETE, 21 | GET, 22 | POST, 23 | _MostlyBaseClient, 24 | ) 25 | from mostlyai.sdk.domain import ( 26 | Integration, 27 | IntegrationAuthorizationRequest, 28 | ) 29 | 30 | 31 | class _MostlyIntegrationsClient(_MostlyBaseClient): 32 | SECTION = ["integrations"] 33 | 34 | # PUBLIC METHODS # 35 | 36 | def list(self) -> list[Integration]: 37 | """ 38 | List integrations. 39 | 40 | Returns all integrations accessible by the user. 41 | 42 | Example for listing all integrations: 43 | ```python 44 | from mostlyai.sdk import MostlyAI 45 | mostly = MostlyAI() 46 | integrations = mostly.integrations.list() 47 | for i in integrations: 48 | print(f"Integration `{i.provider_name}` ({i.status}, {i.provider_id})") 49 | ``` 50 | 51 | Returns: 52 | list[Integration]: A list of integration objects. 53 | """ 54 | response = self.request(verb=GET, path=[]) 55 | return [Integration(**item) for item in response] 56 | 57 | def get(self, provider_id: str) -> Integration: 58 | """ 59 | Retrieve an integration by its provider ID. 60 | 61 | Args: 62 | provider_id: The provider identifier (e.g., "google", "slack", "github"). 63 | 64 | Returns: 65 | Integration: The retrieved integration object. 66 | 67 | Example for retrieving an integration: 68 | ```python 69 | from mostlyai.sdk import MostlyAI 70 | mostly = MostlyAI() 71 | i = mostly.integrations.get('google') 72 | i 73 | ``` 74 | """ 75 | response = self.request(verb=GET, path=[provider_id], response_type=Integration) 76 | return response 77 | 78 | def authorize( 79 | self, 80 | provider: str, 81 | scope_ids: list[str], 82 | ) -> str: 83 | """ 84 | Generate an OAuth authorization URL for connecting an integration. 85 | 86 | Args: 87 | provider: The OAuth provider identifier (e.g., "google", "slack", "github"). 88 | scope_ids: List of scope identifiers for this integration. 89 | 90 | Returns: 91 | str: The OAuth authorization URL. 
92 | 93 | Example for generating an authorization URL: 94 | ```python 95 | from mostlyai.sdk import MostlyAI 96 | mostly = MostlyAI() 97 | url = mostly.integrations.authorize( 98 | provider='google', 99 | scope_ids=['550e8400-e29b-41d4-a716-446655440000'] 100 | ) 101 | print(f"Visit this URL to authorize: {url}") 102 | ``` 103 | """ 104 | config = IntegrationAuthorizationRequest(scope_ids=scope_ids) 105 | response = self.request( 106 | verb=POST, 107 | path=[provider, "authorize"], 108 | json=config, 109 | response_type=dict, 110 | do_response_dict_snake_case=True, 111 | ) 112 | return response.get("authorization_url", "") 113 | 114 | def refresh(self, provider_id: str) -> None: 115 | """ 116 | Refresh integration OAuth token. 117 | 118 | Refresh the OAuth access token for an integration. 119 | If the integration has a refresh token, it will be used to obtain new access tokens. 120 | If no refresh token exists, a new authorization flow must be initiated. 121 | 122 | Args: 123 | provider_id: The provider identifier (e.g., "google", "slack", "github"). 124 | 125 | Example for refreshing an integration token: 126 | ```python 127 | from mostlyai.sdk import MostlyAI 128 | mostly = MostlyAI() 129 | mostly.integrations.refresh('google') 130 | ``` 131 | 132 | Raises: 133 | APIError: If the integration is not found (404) or no refresh token is available (400). 134 | """ 135 | self.request(verb=POST, path=[provider_id, "refresh"]) 136 | 137 | def disconnect(self, provider_id: str) -> None: 138 | """ 139 | Disconnect an integration. 140 | 141 | Args: 142 | provider_id: The provider identifier (e.g., "google", "slack", "github"). 143 | 144 | Example for disconnecting an integration: 145 | ```python 146 | from mostlyai.sdk import MostlyAI 147 | mostly = MostlyAI() 148 | mostly.integrations.disconnect('google') 149 | ``` 150 | """ 151 | self.request(verb=DELETE, path=[provider_id]) 152 | rich.print( 153 | f"Disconnected integration [link={self.base_url}/d/integrations/{provider_id} dodger_blue2 underline]{provider_id}[/]" 154 | ) 155 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/table/json.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
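--------------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not from the repository): chaining the
# integrations client calls documented above. Assumes a reachable MOSTLY AI
# instance; the provider name and scope id are placeholders.
from mostlyai.sdk import MostlyAI

mostly = MostlyAI()
connected = {i.provider_id for i in mostly.integrations.list()}
if "google" in connected:
    mostly.integrations.refresh("google")  # renew the OAuth access token
else:
    url = mostly.integrations.authorize(provider="google", scope_ids=["<scope-id>"])
    print(f"Visit this URL to authorize: {url}")
--------------------------------------------------------------------------------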
14 | 15 | import functools 16 | import logging 17 | import time 18 | from collections.abc import Iterable 19 | from typing import Any 20 | 21 | import pandas as pd 22 | import pyarrow.dataset as ds 23 | import smart_open 24 | from pyarrow import json as pa_json 25 | 26 | from mostlyai.sdk._data.base import order_df_by 27 | from mostlyai.sdk._data.dtype import ( 28 | coerce_dtypes_by_encoding, 29 | is_date_dtype, 30 | is_timestamp_dtype, 31 | pyarrow_to_pandas_map, 32 | ) 33 | from mostlyai.sdk._data.file.base import FileContainer, FileDataTable, LocalFileContainer 34 | from mostlyai.sdk._data.util.common import OrderBy 35 | 36 | _LOG = logging.getLogger(__name__) 37 | 38 | 39 | class JsonDataTable(FileDataTable): 40 | DATA_TABLE_TYPE = "json" 41 | # append is only supported when to_json(..., orient="records", lines=True) 42 | IS_WRITE_APPEND_ALLOWED = True 43 | 44 | def __init__(self, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | 47 | @classmethod 48 | def container_class(cls) -> type["FileContainer"]: 49 | return LocalFileContainer 50 | 51 | def read_data( 52 | self, 53 | where: dict[str, Any] | None = None, 54 | limit: int | None = None, 55 | columns: list[str] | None = None, 56 | shuffle: bool | None = False, 57 | order_by: OrderBy | None = None, 58 | do_coerce_dtypes: bool | None = False, 59 | ) -> pd.DataFrame: 60 | t0 = time.time() 61 | if where: 62 | filters = [] 63 | for c, v in where.items(): 64 | # make sure values is a list of unique values 65 | values = list(set(v)) if (isinstance(v, Iterable) and not isinstance(v, str)) else [v] 66 | filters.append(ds.field(c).isin(values)) 67 | filter = functools.reduce(lambda x, y: x & y, filters) 68 | else: 69 | filter = ds.scalar(True) 70 | files = [f"{self.container.path_prefix}{file}" for file in self.container.valid_files_without_scheme] 71 | df = pd.concat( 72 | [ 73 | pa_json.read_json( 74 | smart_open.open( 75 | file, 76 | "rb", 77 | transport_params=self.container.transport_params, 78 | ) 79 | ) 80 | .filter(filter) 81 | .to_pandas( 82 | # convert to pyarrow DTypes 83 | types_mapper=pyarrow_to_pandas_map.get, 84 | # reduce memory of conversion 85 | # see https://arrow.apache.org/docs/python/pandas.html#reducing-memory-use-in-table-to-pandas 86 | split_blocks=True, 87 | ) 88 | for file in files 89 | ], 90 | axis=0, 91 | ) 92 | if columns: 93 | df = df[columns] 94 | if shuffle: 95 | df = df.sample(frac=1) 96 | if limit is not None: 97 | df = df.head(limit) 98 | if order_by: 99 | df = order_df_by(df, order_by) 100 | if do_coerce_dtypes: 101 | df = coerce_dtypes_by_encoding(df, self.encoding_types) 102 | df = df.reset_index(drop=True) 103 | _LOG.info(f"read {self.DATA_TABLE_TYPE} data `{self.name}` {df.shape} in {time.time() - t0:.2f}s") 104 | return df 105 | 106 | @functools.cached_property 107 | def row_count(self) -> int: 108 | # Note: this currently reads all data; optimize later 109 | return len(self.read_data()) 110 | 111 | def _get_columns(self): 112 | # Note: this currently reads all data; optimize later 113 | return list(self.read_data().columns) 114 | 115 | def _get_dataset_format(self) -> ds.FileFormat: 116 | return ds.JsonFileFormat() 117 | 118 | def fetch_dtypes(self) -> dict[str, Any]: 119 | # Note: this currently reads all data; optimize later 120 | return self.read_data().dtypes.to_dict() 121 | 122 | def write_data(self, df: pd.DataFrame, if_exists: str = "append", **kwargs): 123 | # Convert to ISO format so that pyarrow.json.read_json can auto-detect these 124 | for c in df: 125 | if 
is_date_dtype(df[c]): 126 | df[c] = df[c].dt.strftime("%Y-%m-%d") 127 | elif is_timestamp_dtype(df[c]): 128 | # strip the trailing microseconds fraction ('.ffffff', i.e. the last 7 chars) from the formatted string 129 | df[c] = ( 130 | df[c] 131 | .dt.tz_localize(None) 132 | .astype("timestamp[us][pyarrow]") 133 | .dt.strftime("%Y-%m-%d %H:%M:%S") 134 | .str[:-7] 135 | ) 136 | mode = self.handle_if_exists(if_exists) 137 | df.to_json(self.container.path_str, orient="records", lines=True, mode=mode) 138 | # raise MostlyException("write to cloud buckets not yet supported") 139 | -------------------------------------------------------------------------------- /mostlyai/sdk/_data/file/container/bucket_based.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 MOSTLY AI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | import functools 17 | import re 18 | from typing import Any 19 | from urllib.parse import urlparse 20 | 21 | from mostlyai.sdk._data.exceptions import MostlyDataException 22 | from mostlyai.sdk._data.file.base import FileContainer 23 | 24 | 25 | class BucketBasedContainer(FileContainer, abc.ABC): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self.bucket_name = None 29 | self.bucket_path = None 30 | self._client = None  # assumed to be provided by the subclass 31 | 32 | @property 33 | def inner_path(self): 34 | return f"{self.bucket_name}/{self.bucket_path}" 35 | 36 | def __hash__(self): 37 | return hash(self.inner_path) 38 | 39 | @classmethod 40 | @abc.abstractmethod 41 | def cloud_path_cls(cls): ... 42 | 43 | @property 44 | @abc.abstractmethod 45 | def file_system(self) -> Any: ...
46 | 47 | @property 48 | @functools.cache 49 | def path(self): 50 | return self.cloud_path_cls()(cloud_path=f"{self.path_prefix}{self.inner_path}", client=self._client) 51 | 52 | def is_accessible(self) -> bool: 53 | return self._check_authenticity() and (self._is_bucket_accessible() if self.bucket_name else True) 54 | 55 | def _is_bucket_accessible(self): 56 | bucket_path = self.bucket_name if self.bucket_path is None else self.path_without_scheme 57 | return self.file_system.ls(bucket_path) 58 | 59 | @abc.abstractmethod 60 | def _check_authenticity(self) -> bool: 61 | pass 62 | 63 | def set_uri(self, uri: str): 64 | def strip_leading_slashes(s: str): 65 | return re.sub(r"^/+", "", s) 66 | 67 | try: 68 | match = self._re_match_uri(uri) 69 | self.bucket_name = strip_leading_slashes(match.group(1)) 70 | self.bucket_path = strip_leading_slashes(match.group(2)) 71 | except Exception: 72 | raise MostlyDataException("The location must contain the full path which includes the bucket.") 73 | 74 | @staticmethod 75 | def normalize_bucket_location(location: str) -> str: 76 | uri = urlparse(location) 77 | parts = [uri.netloc.strip("/"), uri.path.strip("/")] 78 | joined = "/".join([part for part in parts if part]) 79 | if "/" not in joined: 80 | joined = joined + "/" 81 | return joined 82 | 83 | def set_location(self, location: str) -> dict: 84 | location = self.normalize_bucket_location(location) 85 | return super().set_location(location) 86 | 87 | def list_locations(self, prefix: str | None) -> list[str]: 88 | """ 89 | List the available locations of a given prefix. 90 | If the prefix is None or an empty string, it will list the buckets. 91 | If the prefix refers to a bucket or a directory, it will return itself and the objects under it. 92 | If the prefix refers to a file, it will return a list containing itself. 93 | If the prefix refers to a non-existent object, it will return an empty list. 94 | 95 | The workaround implemented here is meant to handle the inconsistent behaviors of 96 | the different file systems and to unify the output format.
97 | 98 | TODO: [Known issue] GCSFileSystem can only list a folder properly when the prefix ends with a slash 99 | """ 100 | locations = [] 101 | try: 102 | protocol = self.path_prefix 103 | prefix = prefix or "" 104 | # strip ending slashes that are not part of the protocol 105 | prefix = protocol + prefix.removeprefix(protocol).rstrip("/") 106 | # NOTE: with gcsfs >= 2025.5.0, GCSFileSystem's exists() and isdir() do not work properly for root prefix 107 | # so we need to ensure that those functions are not called in this case 108 | is_root = prefix == protocol 109 | if is_root or self.file_system.exists(prefix): 110 | if not is_root: 111 | # only add ending slash for directories 112 | prefix = prefix + "/" if self.file_system.isdir(prefix) else prefix 113 | # Directories come first 114 | locations = sorted( 115 | self.file_system.ls(prefix, detail=True), 116 | key=lambda loc: loc["type"], 117 | ) 118 | # Append an ending slash for directories if it is not there yet 119 | locations = [ 120 | loc["name"] + "/" if loc["type"] == "directory" and not loc["name"].endswith("/") else loc["name"] 121 | for loc in locations 122 | ] 123 | # Add the protocol prefix to the locations if not there yet 124 | locations = [protocol + loc if not loc.lower().startswith(protocol) else loc for loc in locations] 125 | # Append the prefix itself to the list if it is not there yet 126 | if prefix not in locations and prefix != "/": 127 | locations = [prefix] + locations 128 | except FileNotFoundError: 129 | pass 130 | return locations 131 | --------------------------------------------------------------------------------
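--------------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not from the repository): what
# `BucketBasedContainer.normalize_bucket_location` above does to typical inputs.
# It is a staticmethod doing pure string handling, so it can be called on the
# abstract class directly; the bucket names below are invented.
from mostlyai.sdk._data.file.container.bucket_based import BucketBasedContainer

# the scheme is dropped; netloc and path are joined with single slashes
assert BucketBasedContainer.normalize_bucket_location("s3://my-bucket/data/file.csv") == "my-bucket/data/file.csv"
# a lone bucket gains a trailing slash so it still reads as a location
assert BucketBasedContainer.normalize_bucket_location("s3://my-bucket") == "my-bucket/"
# redundant slashes at the seams are stripped
assert BucketBasedContainer.normalize_bucket_location("gs://bucket//nested/dir/") == "bucket/nested/dir"
--------------------------------------------------------------------------------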