├── .env.example ├── .github └── workflows │ └── main.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── aixplain ├── __init__.py ├── base │ └── parameters.py ├── cli_groups.py ├── decorators │ ├── __init__.py │ └── api_key_checker.py ├── enums │ ├── __init__.py │ ├── asset_status.py │ ├── data_split.py │ ├── data_subtype.py │ ├── data_type.py │ ├── database_source.py │ ├── embedding_model.py │ ├── error_handler.py │ ├── file_type.py │ ├── function.py │ ├── index_stores.py │ ├── language.py │ ├── license.py │ ├── onboard_status.py │ ├── ownership_type.py │ ├── privacy.py │ ├── response_status.py │ ├── sort_by.py │ ├── sort_order.py │ ├── status.py │ ├── storage_type.py │ └── supplier.py ├── exceptions │ ├── __init__.py │ └── types.py ├── factories │ ├── __init__.py │ ├── agent_factory │ │ ├── __init__.py │ │ └── utils.py │ ├── api_key_factory.py │ ├── asset_factory.py │ ├── benchmark_factory.py │ ├── cli │ │ └── model_factory_cli.py │ ├── corpus_factory.py │ ├── data_factory.py │ ├── dataset_factory.py │ ├── file_factory.py │ ├── finetune_factory │ │ ├── __init__.py │ │ └── prompt_validator.py │ ├── index_factory │ │ ├── __init__.py │ │ └── utils.py │ ├── metric_factory.py │ ├── model_factory │ │ ├── __init__.py │ │ └── utils.py │ ├── pipeline_factory │ │ ├── __init__.py │ │ └── utils.py │ ├── script_factory.py │ ├── team_agent_factory │ │ ├── __init__.py │ │ └── utils.py │ └── wallet_factory.py ├── modules │ ├── __init__.py │ ├── agent │ │ ├── __init__.py │ │ ├── agent_response.py │ │ ├── agent_response_data.py │ │ ├── agent_task.py │ │ ├── output_format.py │ │ ├── tool │ │ │ ├── __init__.py │ │ │ ├── custom_python_code_tool.py │ │ │ ├── model_tool.py │ │ │ ├── pipeline_tool.py │ │ │ ├── python_interpreter_tool.py │ │ │ └── sql_tool.py │ │ └── utils.py │ ├── api_key.py │ ├── asset.py │ ├── benchmark.py │ ├── benchmark_job.py │ ├── content_interval.py │ ├── corpus.py │ ├── data.py │ ├── dataset.py │ ├── file.py │ ├── finetune │ │ ├── 
__init__.py │ │ ├── cost.py │ │ ├── hyperparameters.py │ │ └── status.py │ ├── metadata.py │ ├── metric.py │ ├── mixins.py │ ├── model │ │ ├── __init__.py │ │ ├── index_model.py │ │ ├── llm_model.py │ │ ├── model_parameters.py │ │ ├── model_response_streamer.py │ │ ├── record.py │ │ ├── response.py │ │ ├── utility_model.py │ │ └── utils.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── asset.py │ │ ├── default.py │ │ ├── designer │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── enums.py │ │ │ ├── mixins.py │ │ │ ├── nodes.py │ │ │ ├── pipeline.py │ │ │ └── utils.py │ │ ├── pipeline.py │ │ └── response.py │ ├── team_agent │ │ └── __init__.py │ └── wallet.py ├── processes │ ├── __init__.py │ └── data_onboarding │ │ ├── __init__.py │ │ ├── onboard_functions.py │ │ ├── process_media_files.py │ │ └── process_text_files.py ├── utils │ ├── __init__.py │ ├── cache_utils.py │ ├── config.py │ ├── convert_datatype_utils.py │ ├── file_utils.py │ ├── request_utils.py │ └── validation_utils.py └── v2 │ ├── __init__.py │ ├── agent.py │ ├── api_key.py │ ├── benchmark.py │ ├── client.py │ ├── core.py │ ├── corpus.py │ ├── data.py │ ├── dataset.py │ ├── enums.py │ ├── enums_include.py │ ├── file.py │ ├── finetune.py │ ├── metric.py │ ├── model.py │ ├── pipeline.py │ ├── resource.py │ ├── script.py │ ├── team_agent.py │ └── wallet.py ├── docs ├── assets │ ├── aixplain-brandmark-line.png │ ├── architecture.png │ ├── data-onboard.png │ ├── designer-subtitling-sample.png │ ├── model-id-on-platform.png │ ├── navigate-api-key.png │ └── subtitle-generator-output.json ├── development │ └── developer_guide.md ├── samples │ ├── corpus_onboarding │ │ ├── corpus_onboarding.ipynb │ │ └── data.csv │ ├── dataset_onboarding │ │ ├── data.csv │ │ └── dataset_onboarding.ipynb │ ├── label_dataset_onboarding │ │ ├── corpus │ │ │ ├── images │ │ │ │ ├── 1.jpg │ │ │ │ └── 2.png │ │ │ ├── index.csv │ │ │ └── labels │ │ │ │ ├── 1.json │ │ │ │ └── 2.json │ │ └── label_dataset_onboarding.ipynb 
│ └── subtitle_generator │ │ ├── README.md │ │ └── subtitle_generator.py ├── streaming │ ├── README.md │ ├── aixplain_diarization_streaming_client.py │ ├── aixplain_speech_transcription_streaming_client.py │ ├── make_audio_compatible.py │ ├── proto │ │ ├── aixplain_diarization_streaming.proto │ │ └── aixplain_speech_transcription_streaming.proto │ ├── requirements.txt │ └── test_dia.wav └── user │ ├── api_setup.md │ └── user_doc.md ├── generate.py ├── pipeline_test2.ipynb ├── pyproject.toml ├── pytest.ini ├── setup.cfg └── tests ├── __init__.py ├── conftest.py ├── functional ├── agent │ ├── agent_functional_test.py │ └── data │ │ └── agent_test_end2end.json ├── apikey │ ├── README.md │ ├── apikey.json │ └── test_api.py ├── benchmark │ ├── benchmark_functional_test.py │ └── data │ │ ├── benchmark_module_test_data.json │ │ └── benchmark_test_run_data.json ├── data_asset │ ├── __init__.py │ ├── corpus_onboarding_test.py │ ├── dataset_onboarding_test.py │ └── input │ │ ├── audio-en_url.csv │ │ ├── audio-en_with_invalid_split_url.csv │ │ └── audio-en_with_split_url.csv ├── file_asset │ ├── __init__.py │ ├── file_create_test.py │ └── input │ │ └── test.csv ├── finetune │ ├── __init__.py │ ├── data │ │ ├── finetune_test_cost_estimation.json │ │ ├── finetune_test_end2end.json │ │ ├── finetune_test_list_data.json │ │ └── finetune_test_prompt_validator.json │ └── finetune_functional_test.py ├── general_assets │ ├── asset_functional_test.py │ └── data │ │ └── asset_run_test_data.json ├── model │ ├── data │ │ └── test_input.txt │ ├── hf_onboarding_test.py │ ├── image_upload_e2e_test.py │ ├── image_upload_functional_test.py │ ├── run_model_test.py │ └── run_utility_model_test.py ├── pipelines │ ├── create_test.py │ ├── data │ │ ├── pipeline.json │ │ └── script.py │ ├── designer_test.py │ └── run_test.py └── team_agent │ ├── data │ └── team_agent_test_end2end.json │ └── team_agent_functional_test.py ├── mock_responses ├── create_asset_repo_response.json ├── 
list_functions_response.json ├── list_host_machines_response.json ├── list_image_repo_tags_response.json └── login_response.json ├── test_requests └── create_asset_request.json ├── test_utils.py └── unit ├── agent ├── agent_factory_utils_test.py ├── agent_test.py ├── model_tool_test.py └── sql_tool_test.py ├── api_key_test.py ├── benchmark_test.py ├── corpus_test.py ├── data └── create_finetune_percentage_exception.json ├── dataset_test.py ├── designer_unit_test.py ├── finetune_test.py ├── hyperparameters_test.py ├── image_upload_test.py ├── index_model_test.py ├── llm_test.py ├── mock_responses ├── cost_estimation_response.json ├── finetune_response.json ├── finetune_status_response.json ├── finetune_status_response_2.json ├── list_models_response.json └── model_response.json ├── model_test.py ├── pipeline_test.py ├── team_agent_test.py ├── utility_test.py ├── utility_tool_decorator_test.py ├── v2 ├── test_core.py └── test_resource.py └── wallet_test.py /.env.example: -------------------------------------------------------------------------------- 1 | BACKEND_URL=https://platform-api.aixplain.com 2 | MODELS_RUN_URL=https://models.aixplain.com/api/v1/execute 3 | PIPELINE_API_KEY= 4 | MODEL_API_KEY= 5 | LOG_LEVEL=DEBUG 6 | TEAM_API_KEY= -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .aixplain_cache/ 6 | 7 | setup_env_ahmet.sh 8 | # C extensions 9 | *.so 10 | *.ipynb 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python 
script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Vscode 135 | .vscode 136 | .DS_Store 137 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: pytest-check 5 | name: pytest-check 6 | entry: coverage run --source=. -m pytest tests/unit 7 | language: python 8 | pass_filenames: false 9 | types: [python] 10 | always_run: true 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 22.10.0 14 | hooks: 15 | - id: black 16 | language_version: python3 17 | args: # arguments to configure black 18 | - --line-length=128 19 | 20 | - repo: https://github.com/pre-commit/pre-commit-hooks 21 | rev: v2.0.0 # Use the latest version 22 | hooks: 23 | - id: flake8 24 | args: # arguments to configure flake8 25 | - --ignore=E402,E501,E203,W503 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | aiXplain logo 2 | 3 | # aiXplain 4 | 5 | aixplain is a software development kit (SDK) for the [aiXplain](https://aixplain.com/) platform. With aixplain, developers can quickly and easily: 6 | 7 | - [Discover](https://aixplain.com/platform/discovery/) aiXplain’s ever-expanding catalog of 35,000+ ready-to-use AI models and utilize them. 
8 | - [Benchmark](https://aixplain.com/platform/benchmark/) AI systems by choosing models, datasets and metrics. 9 | - [Design](https://aixplain.com/platform/studio/) their own custom pipelines and run them. 10 | - [FineTune](https://aixplain.com/platform/finetune/) pre-trained models by tuning them using your data, enhancing their performance. 11 | 12 | 🔎 **Find [models](https://platform.aixplain.com/discovery/models), [datasets](https://platform.aixplain.com/discovery/datasets), [metrics](https://platform.aixplain.com/discovery/metrics) on the platform.** 13 | 14 | 💛 Our repository is constantly evolving. With the help of the scientific community, we plan to add even more datasets, models, and metrics across domains and tasks. 15 | 16 | ## Getting Started 17 | 18 | ### Installation 19 | To install the base package, simply, 20 | ```bash 21 | pip install aixplain 22 | ``` 23 | 24 | To install aiXplain with additional model building support: 25 | ```bash 26 | pip install aixplain[model-builder] 27 | ``` 28 | 29 | ### API Key Setup 30 | Before you can use the aixplain SDK, you'll need to obtain an API key from our platform. For details refer this [Team API Key Guide](docs/user/api_setup.md). 31 | 32 | Once you get the API key, you'll need to add this API key as an environment variable on your system. 33 | 34 | #### Linux or macOS 35 | ```bash 36 | export TEAM_API_KEY=YOUR_API_KEY 37 | ``` 38 | #### Windows 39 | ```bash 40 | set TEAM_API_KEY=YOUR_API_KEY 41 | ``` 42 | #### Jupyter Notebook 43 | ``` 44 | %env TEAM_API_KEY=YOUR_API_KEY 45 | ``` 46 | 47 | ### Usage 48 | 49 | Let’s see how we can use aixplain to run a machine translation model. The following example shows an [English to French translation model](https://platform.aixplain.com/discovery/model/61dc52976eb5634cf06e97cc). 50 | 51 | ```python 52 | from aixplain.factories import ModelFactory 53 | model = ModelFactory.get("61dc52976eb5634cf06e97cc") # Get the ID of a model from our platform. 
54 | translation = model.run("This is a sample text") # Alternatively, you can input a public URL or provide a file path on your local machine. 55 | ``` 56 | *Check out the [explore section](docs/user/user_doc.md#explore) of our guide on Models to get the ID of your desired model* 57 | 58 | ## Quick Links 59 | 60 | * [Team API Key Guide](docs/user/api_setup.md) 61 | * [User Documentation](docs/user/user_doc.md) 62 | * [Developer Guide](docs/development/developer_guide.md) 63 | * [API Reference](https://docs.aixplain.com) 64 | * [Release notes](https://github.com/aixplain/aiXplain/releases) 65 | 66 | ## Support 67 | Raise issues for support in this repository. 68 | Pull requests are welcome! 69 | 70 | ## Note 71 | The **aiXtend** python package was renamed to **aiXplain** from the release v0.1.1. -------------------------------------------------------------------------------- /aixplain/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | aiXplain SDK Library. 3 | --- 4 | 5 | aiXplain SDK enables python programmers to add AI functions 6 | to their software. 7 | 8 | Copyright 2022 The aiXplain SDK authors 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 
21 | """ 22 | 23 | import os 24 | import logging 25 | from dotenv import load_dotenv 26 | 27 | load_dotenv() 28 | 29 | from .v2.core import Aixplain # noqa 30 | 31 | LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper() 32 | logging.basicConfig(level=LOG_LEVEL) 33 | 34 | 35 | aixplain_v2 = None 36 | try: 37 | aixplain_v2 = Aixplain() 38 | except Exception: 39 | pass 40 | 41 | 42 | __all__ = ["Aixplain", "aixplain_v2"] 43 | -------------------------------------------------------------------------------- /aixplain/base/parameters.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Optional, List 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class Parameter: 7 | name: str 8 | required: bool 9 | value: Optional[Any] = None 10 | 11 | 12 | class BaseParameters: 13 | def __init__(self) -> None: 14 | """Initialize base parameters class""" 15 | self.parameters: Dict[str, Parameter] = {} 16 | 17 | def get_parameter(self, name: str) -> Optional[Parameter]: 18 | """Get a parameter by name. 19 | 20 | Args: 21 | name (str): Name of the parameter 22 | 23 | Returns: 24 | Optional[Parameter]: Parameter object if found, None otherwise 25 | """ 26 | return self.parameters.get(name) 27 | 28 | def to_dict(self) -> Dict[str, Dict[str, Any]]: 29 | """Convert parameters back to dictionary format. 30 | 31 | Returns: 32 | Dict[str, Dict[str, Any]]: Dictionary representation of parameters 33 | """ 34 | return {param.name: {"required": param.required, "value": param.value} for param in self.parameters.values()} 35 | 36 | def to_list(self) -> List[str]: 37 | """Convert parameters back to list format. 38 | 39 | Returns: 40 | List[str]: List representation of parameters 41 | """ 42 | return [{"name": param.name, "value": param.value} for param in self.parameters.values() if param.value is not None] 43 | 44 | def __str__(self) -> str: 45 | """Create a pretty string representation of the parameters. 
46 | 47 | Returns: 48 | str: Formatted string showing all parameters 49 | """ 50 | if not self.parameters: 51 | return "No parameters defined" 52 | 53 | lines = ["Parameters:"] 54 | for param in self.parameters.values(): 55 | value_str = str(param.value) if param.value is not None else "Not set" 56 | required_str = "(Required)" if param.required else "(Optional)" 57 | lines.append(f" - {param.name}: {value_str} {required_str}") 58 | 59 | return "\n".join(lines) 60 | 61 | def __setattr__(self, name: str, value: Any) -> None: 62 | """Allow setting parameters using attribute syntax (e.g., params.text = "Hello"). 63 | 64 | Args: 65 | name (str): Name of the parameter 66 | value (Any): Value to set for the parameter 67 | """ 68 | if name == "parameters": # Allow setting the parameters dict normally 69 | super().__setattr__(name, value) 70 | return 71 | 72 | if name in self.parameters: 73 | self.parameters[name].value = value 74 | else: 75 | raise AttributeError(f"Parameter '{name}' is not defined") 76 | 77 | def __getattr__(self, name: str) -> Any: 78 | """Allow getting parameter values using attribute syntax (e.g., params.text). 79 | 80 | Args: 81 | name (str): Name of the parameter 82 | 83 | Returns: 84 | Any: Value of the parameter 85 | 86 | Raises: 87 | AttributeError: If parameter is not defined 88 | """ 89 | if name in self.parameters: 90 | return self.parameters[name].value 91 | raise AttributeError(f"Parameter '{name}' is not defined") 92 | -------------------------------------------------------------------------------- /aixplain/cli_groups.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 
8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Michael Lam 19 | Date: September 18th 2023 20 | Description: 21 | CLI Runner 22 | """ 23 | import click 24 | from aixplain.factories.cli.model_factory_cli import ( 25 | list_host_machines, 26 | list_functions, 27 | create_asset_repo, 28 | asset_repo_login, 29 | onboard_model, 30 | deploy_huggingface_model, 31 | get_huggingface_model_status, 32 | list_gpus, 33 | ) 34 | 35 | 36 | @click.group("cli") 37 | def cli(): 38 | pass 39 | 40 | 41 | @click.group("list") 42 | def list(): 43 | pass 44 | 45 | 46 | @click.group("get") 47 | def get(): 48 | pass 49 | 50 | 51 | @click.group("create") 52 | def create(): 53 | pass 54 | 55 | 56 | @click.group("onboard") 57 | def onboard(): 58 | pass 59 | 60 | 61 | cli.add_command(list) 62 | cli.add_command(get) 63 | cli.add_command(create) 64 | cli.add_command(onboard) 65 | 66 | create.add_command(create_asset_repo) 67 | list.add_command(list_host_machines) 68 | list.add_command(list_functions) 69 | list.add_command(list_gpus) 70 | get.add_command(asset_repo_login) 71 | get.add_command(get_huggingface_model_status) 72 | onboard.add_command(onboard_model) 73 | onboard.add_command(deploy_huggingface_model) 74 | 75 | 76 | def run_cli(): 77 | cli() 78 | -------------------------------------------------------------------------------- /aixplain/decorators/__init__.py: -------------------------------------------------------------------------------- 1 | from .api_key_checker import check_api_key 2 | -------------------------------------------------------------------------------- 
/aixplain/decorators/api_key_checker.py: -------------------------------------------------------------------------------- 1 | from aixplain.utils import config 2 | 3 | 4 | def check_api_key(method): 5 | def wrapper(*args, **kwargs): 6 | if config.TEAM_API_KEY == "" and config.AIXPLAIN_API_KEY == "": 7 | raise Exception( 8 | "A 'TEAM_API_KEY' is required to run an asset. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)" 9 | ) 10 | return method(*args, **kwargs) 11 | 12 | return wrapper 13 | -------------------------------------------------------------------------------- /aixplain/enums/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401 // to ignore the F401 (unused import) 2 | from .data_split import DataSplit 3 | from .data_subtype import DataSubtype 4 | from .data_type import DataType 5 | from .error_handler import ErrorHandler 6 | from .file_type import FileType 7 | from .function import Function, FunctionInputOutput 8 | from .language import Language 9 | from .license import License 10 | from .onboard_status import OnboardStatus 11 | from .ownership_type import OwnershipType 12 | from .privacy import Privacy 13 | from .storage_type import StorageType 14 | from .supplier import Supplier 15 | from .sort_by import SortBy 16 | from .sort_order import SortOrder 17 | from .response_status import ResponseStatus 18 | from .database_source import DatabaseSourceType 19 | from .embedding_model import EmbeddingModel 20 | from .asset_status import AssetStatus 21 | from .index_stores import IndexStores 22 | -------------------------------------------------------------------------------- /aixplain/enums/asset_status.py: -------------------------------------------------------------------------------- 1 | __author__ = "thiagocastroferreira" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the 
"License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 19 | Date: February 21st 2024 20 | Description: 21 | Asset Enum 22 | """ 23 | 24 | from enum import Enum 25 | from typing import Text 26 | 27 | 28 | class AssetStatus(Text, Enum): 29 | DRAFT = "draft" 30 | HIDDEN = "hidden" 31 | SCHEDULED = "scheduled" 32 | ONBOARDING = "onboarding" 33 | ONBOARDED = "onboarded" 34 | PENDING = "pending" 35 | FAILED = "failed" 36 | TRAINING = "training" 37 | REJECTED = "rejected" 38 | ENABLING = "enabling" 39 | DELETING = "deleting" 40 | DISABLED = "disabled" 41 | DELETED = "deleted" 42 | IN_PROGRESS = "in_progress" 43 | COMPLETED = "completed" 44 | CANCELING = "canceling" 45 | CANCELED = "canceled" 46 | DEPRECATED_DRAFT = "deprecated_draft" 47 | -------------------------------------------------------------------------------- /aixplain/enums/data_split.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 
8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | Data Split Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class DataSplit(Enum): 28 | TRAIN = "train" 29 | VALIDATION = "validation" 30 | TEST = "test" 31 | -------------------------------------------------------------------------------- /aixplain/enums/data_subtype.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: May 3rd 2023 20 | Description: 21 | Data Subtype Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class DataSubtype(Enum): 28 | AGE = "age" 29 | GENDER = "gender" 30 | INTERVAL = "interval" 31 | OTHER = "other" 32 | RACE = "race" 33 | SPLIT = "split" 34 | TOPIC = "topic" 35 | 36 | def __str__(self): 37 | return self._value_ 38 | -------------------------------------------------------------------------------- /aixplain/enums/data_type.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | Data Type Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class DataType(str, Enum): 28 | AUDIO = "audio" 29 | FLOAT = "float" 30 | IMAGE = "image" 31 | INTEGER = "integer" 32 | LABEL = "label" 33 | TENSOR = "tensor" 34 | TEXT = "text" 35 | VIDEO = "video" 36 | EMBEDDING = "embedding" 37 | NUMBER = "number" 38 | BOOLEAN = "boolean" 39 | 40 | def __str__(self): 41 | return self._value_ 42 | -------------------------------------------------------------------------------- /aixplain/enums/database_source.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: Lucas Pavanelli and Thiago Castro Ferreira and Ahmet Gunduz 19 | Date: March 7th 2025 20 | Description: 21 | Database Source Type Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class DatabaseSourceType(Enum): 28 | """Enum for database source types""" 29 | 30 | POSTGRESQL = "postgresql" 31 | SQLITE = "sqlite" 32 | CSV = "csv" 33 | 34 | @classmethod 35 | def from_string(cls, source_type: str) -> "DatabaseSourceType": 36 | """Convert string to DatabaseSourceType enum 37 | 38 | Args: 39 | source_type (str): Source type string 40 | 41 | Returns: 42 | DatabaseSourceType: Corresponding enum value 43 | """ 44 | try: 45 | return cls[source_type.upper()] 46 | except KeyError: 47 | raise ValueError(f"Invalid source type: {source_type}") 48 | -------------------------------------------------------------------------------- /aixplain/enums/embedding_model.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | Author: aiXplain team 15 | Date: February 17th 2025 16 | Description: 17 | Embedding Model Enum 18 | """ 19 | 20 | from enum import Enum 21 | 22 | 23 | class EmbeddingModel(str, Enum): 24 | SNOWFLAKE_ARCTIC_EMBED_M_LONG = "6658d40729985c2cf72f42ec" 25 | OPENAI_ADA002 = "6734c55df127847059324d9e" 26 | SNOWFLAKE_ARCTIC_EMBED_L_V2_0 = "678a4f8547f687504744960a" 27 | JINA_CLIP_V2_MULTIMODAL = "67c5f705d8f6a65d6f74d732" 28 | MULTILINGUAL_E5_LARGE = "67efd0772a0a850afa045af3" 29 | BGE_M3 = "67efd4f92a0a850afa045af7" 30 | AIXPLAIN_LEGAL_EMBEDDINGS = "681254b668e47e7844c1f15a" 31 | 32 | 33 | def __str__(self): 34 | return self._value_ 35 | -------------------------------------------------------------------------------- /aixplain/enums/error_handler.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain team 19 | Date: May 26th 2023 20 | Description: 21 | Error Handler Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class ErrorHandler(Enum): 28 | """ 29 | Enumeration class defining different error handler strategies. 30 | 31 | Attributes: 32 | SKIP (str): skip failed rows. 33 | FAIL (str): raise an exception. 
34 | """ 35 | 36 | SKIP = "skip" 37 | FAIL = "fail" 38 | -------------------------------------------------------------------------------- /aixplain/enums/file_type.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | File Type Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class FileType(Enum): 28 | CSV = ".csv" 29 | JSON = ".json" 30 | TXT = ".txt" 31 | XML = ".xml" 32 | FLAC = ".flac" 33 | MP3 = ".mp3" 34 | WAV = ".wav" 35 | JPEG = ".jpeg" 36 | PNG = ".png" 37 | JPG = ".jpg" 38 | GIF = ".gif" 39 | WEBP = ".webp" 40 | AVI = ".avi" 41 | MP4 = ".mp4" 42 | MOV = ".mov" 43 | MPEG4 = ".mpeg4" 44 | -------------------------------------------------------------------------------- /aixplain/enums/index_stores.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class IndexStores(Enum): 5 | AIR = {"name": "air", "id": "66eae6656eb56311f2595011"} 6 | VECTARA = {"name": "vectara", "id": "655e20f46eb563062a1aa301"} 7 | GRAPHRAG = {"name": "graphrag", "id": "67dd6d487cbf0a57cf4b72f3"} 8 | ZERO_ENTROPY = {"name": "zeroentropy", "id": "6807949168e47e7844c1f0c5"} 9 | 10 | def __str__(self): 11 | return self.value["name"] 12 | 13 | def 
get_model_id(self): 14 | return self.value["id"] 15 | -------------------------------------------------------------------------------- /aixplain/enums/language.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain team 19 | Date: March 22th 2023 20 | Description: 21 | Language Enum 22 | """ 23 | 24 | from enum import Enum 25 | from urllib.parse import urljoin 26 | from aixplain.utils import config 27 | from aixplain.utils.request_utils import _request_with_retry 28 | from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER 29 | 30 | CACHE_FILE = f"{CACHE_FOLDER}/languages.json" 31 | 32 | 33 | def load_languages(): 34 | resp = load_from_cache(CACHE_FILE) 35 | if resp is None: 36 | api_key = config.TEAM_API_KEY 37 | backend_url = config.BACKEND_URL 38 | 39 | url = urljoin(backend_url, "sdk/languages") 40 | 41 | headers = {"x-api-key": api_key, "Content-Type": "application/json"} 42 | r = _request_with_retry("get", url, headers=headers) 43 | if not 200 <= r.status_code < 300: 44 | raise Exception( 45 | f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. 
For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' 46 | ) 47 | resp = r.json() 48 | save_to_cache(CACHE_FILE, resp) 49 | 50 | languages = {} 51 | for w in resp: 52 | language = w["value"] 53 | language_label = "_".join(w["label"].split()) 54 | languages[language_label] = {"language": language, "dialect": ""} 55 | for dialect in w["dialects"]: 56 | dialect_label = "_".join(dialect["label"].split()).upper() 57 | dialect_value = dialect["value"] 58 | 59 | languages[language_label + "_" + dialect_label] = {"language": language, "dialect": dialect_value} 60 | return Enum("Language", languages, type=dict) 61 | 62 | 63 | Language = load_languages() 64 | -------------------------------------------------------------------------------- /aixplain/enums/license.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | License Enum 22 | """ 23 | 24 | import logging 25 | from enum import Enum 26 | from urllib.parse import urljoin 27 | from aixplain.utils import config 28 | from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER 29 | from aixplain.utils.request_utils import _request_with_retry 30 | 31 | CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" 32 | 33 | 34 | def load_licenses(): 35 | resp = load_from_cache(CACHE_FILE) 36 | 37 | try: 38 | if resp is None: 39 | api_key = config.TEAM_API_KEY 40 | backend_url = config.BACKEND_URL 41 | 42 | url = urljoin(backend_url, "sdk/licenses") 43 | 44 | headers = {"x-api-key": api_key, "Content-Type": "application/json"} 45 | r = _request_with_retry("get", url, headers=headers) 46 | if not 200 <= r.status_code < 300: 47 | raise Exception( 48 | f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' 49 | ) 50 | resp = r.json() 51 | save_to_cache(CACHE_FILE, resp) 52 | licenses = {"_".join(w["name"].split()): w["id"] for w in resp} 53 | return Enum("License", licenses, type=str) 54 | except Exception: 55 | logging.exception("License Loading Error") 56 | raise Exception("License Loading Error") 57 | 58 | 59 | License = load_licenses() 60 | -------------------------------------------------------------------------------- /aixplain/enums/onboard_status.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 
8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain team 19 | Date: March 22th 2023 20 | Description: 21 | Onboard Status Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class OnboardStatus(Enum): 28 | ONBOARDING = "onboarding" 29 | ONBOARDED = "onboarded" 30 | FAILED = "failed" 31 | DELETED = "deleted" 32 | -------------------------------------------------------------------------------- /aixplain/enums/ownership_type.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: November 22nd 2023 20 | Description: 21 | Asset Ownership Type 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class OwnershipType(Enum): 28 | SUBSCRIBED = "SUBSCRIBED" 29 | OWNED = "OWNED" 30 | 31 | def __str__(self): 32 | return self._value_ 33 | -------------------------------------------------------------------------------- /aixplain/enums/privacy.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | Privacy Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class Privacy(Enum): 28 | PUBLIC = "Public" 29 | PRIVATE = "Private" 30 | RESTRICTED = "Restricted" 31 | -------------------------------------------------------------------------------- /aixplain/enums/response_status.py: -------------------------------------------------------------------------------- 1 | __author__ = "thiagocastroferreira" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 
8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 19 | Date: February 21st 2024 20 | Description: 21 | Asset Enum 22 | """ 23 | 24 | from enum import Enum 25 | from typing import Text 26 | 27 | 28 | class ResponseStatus(Text, Enum): 29 | IN_PROGRESS = "IN_PROGRESS" 30 | SUCCESS = "SUCCESS" 31 | FAILED = "FAILED" 32 | 33 | def __str__(self): 34 | return self.value 35 | -------------------------------------------------------------------------------- /aixplain/enums/sort_by.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | Sort By Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class SortBy(Enum): 28 | CREATION_DATE = "createdAt" 29 | PRICE = "normalizedPrice" 30 | POPULARITY = "totalSubscribed" 31 | -------------------------------------------------------------------------------- /aixplain/enums/sort_order.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | Sort By Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class SortOrder(Enum): 28 | ASCENDING = 1 29 | DESCENDING = -1 30 | -------------------------------------------------------------------------------- /aixplain/enums/status.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Text 3 | 4 | 5 | class Status(Text, Enum): 6 | FAILED = "failed" 7 | IN_PROGRESS = "in_progress" 8 | SUCCESS = "success" 9 | -------------------------------------------------------------------------------- /aixplain/enums/storage_type.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | Storage Type Enum 22 | """ 23 | 24 | from enum import Enum 25 | 26 | 27 | class StorageType(Enum): 28 | TEXT = "text" 29 | URL = "url" 30 | FILE = "file" 31 | 32 | def __str__(self): 33 | return self._value_ 34 | -------------------------------------------------------------------------------- /aixplain/enums/supplier.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: September 25th 2023 20 | Description: 21 | Supplier Enum 22 | """ 23 | 24 | import logging 25 | 26 | from aixplain.utils import config 27 | from aixplain.utils.request_utils import _request_with_retry 28 | from enum import Enum 29 | from urllib.parse import urljoin 30 | import re 31 | 32 | 33 | def clean_name(name): 34 | cleaned_name = re.sub(r"[ -]+", "_", name) 35 | cleaned_name = re.sub(r"[^a-zA-Z0-9_]", "", cleaned_name) 36 | cleaned_name = re.sub(r"^\d+", "", cleaned_name) 37 | return cleaned_name.upper() 38 | 39 | 40 | def load_suppliers(): 41 | api_key = config.TEAM_API_KEY 42 | backend_url = config.BACKEND_URL 43 | 44 | url = urljoin(backend_url, "sdk/suppliers") 45 | 46 | headers = {"x-api-key": api_key, "Content-Type": "application/json"} 47 | logging.debug(f"Start service for GET API Creation - {url} - {headers}") 48 | r = _request_with_retry("get", url, headers=headers) 49 | if not 200 <= r.status_code < 300: 50 | raise Exception( 51 | f'Suppliers could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' 52 | ) 53 | resp = r.json() 54 | suppliers = Enum( 55 | "Supplier", {clean_name(w["name"]): {"id": w["id"], "name": w["name"], "code": w["code"]} for w in resp}, type=dict 56 | ) 57 | suppliers.__str__ = lambda self: self.value["name"] 58 | 59 | return suppliers 60 | 61 | 62 | Supplier = load_suppliers() 63 | -------------------------------------------------------------------------------- /aixplain/factories/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | aiXplain SDK Library. 3 | --- 4 | 5 | aiXplain SDK enables python programmers to add AI functions 6 | to their software. 
7 | 8 | Copyright 2022 The aiXplain SDK authors 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 21 | """ 22 | from .asset_factory import AssetFactory 23 | from .agent_factory import AgentFactory 24 | from .team_agent_factory import TeamAgentFactory 25 | from .benchmark_factory import BenchmarkFactory 26 | from .corpus_factory import CorpusFactory 27 | from .data_factory import DataFactory 28 | from .dataset_factory import DatasetFactory 29 | from .file_factory import FileFactory 30 | from .metric_factory import MetricFactory 31 | from .model_factory import ModelFactory 32 | from .pipeline_factory import PipelineFactory 33 | from .finetune_factory import FinetuneFactory 34 | from .wallet_factory import WalletFactory 35 | from .api_key_factory import APIKeyFactory 36 | from .index_factory import IndexFactory 37 | -------------------------------------------------------------------------------- /aixplain/factories/asset_factory.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 
8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 19 | Date: December 27th 2022 20 | Description: 21 | Asset Factory Class 22 | """ 23 | 24 | from abc import abstractmethod 25 | from typing import Text 26 | from aixplain.modules.asset import Asset 27 | from aixplain.utils import config 28 | 29 | 30 | class AssetFactory: 31 | 32 | backend_url = config.BACKEND_URL 33 | 34 | @abstractmethod 35 | def get(self, asset_id: Text) -> Asset: 36 | """Create a 'Asset' object from id 37 | 38 | Args: 39 | asset_id (str): ID of required asset. 40 | 41 | Returns: 42 | Asset: Created 'Asset' object 43 | """ 44 | pass 45 | -------------------------------------------------------------------------------- /aixplain/factories/data_factory.py: -------------------------------------------------------------------------------- 1 | __author__ = "shreyassharma" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 19 | Date: May 15th 2023 20 | Description: 21 | Data Factory Class 22 | """ 23 | 24 | import aixplain.utils.config as config 25 | import logging 26 | 27 | from aixplain.factories.asset_factory import AssetFactory 28 | from aixplain.modules.data import Data 29 | from aixplain.enums.data_subtype import DataSubtype 30 | from aixplain.enums.data_type import DataType 31 | from aixplain.enums.language import Language 32 | from aixplain.enums.privacy import Privacy 33 | from aixplain.utils.request_utils import _request_with_retry 34 | from typing import Dict, Text 35 | from urllib.parse import urljoin 36 | 37 | 38 | class DataFactory(AssetFactory): 39 | """A static class for creating and exploring Dataset Objects. 40 | 41 | Attributes: 42 | backend_url (str): The URL for the backend. 43 | """ 44 | 45 | backend_url = config.BACKEND_URL 46 | 47 | @classmethod 48 | def __from_response(cls, response: Dict) -> Data: 49 | """Converts Json response to 'Data' object 50 | 51 | Args: 52 | response (dict): Json from API 53 | 54 | Returns: 55 | Data: Converted 'Data' object 56 | """ 57 | languages = [] 58 | if "languages" in response["metadata"]: 59 | languages = [] 60 | for lng in response["metadata"]["languages"]: 61 | if "dialect" not in lng: 62 | lng["dialect"] = "" 63 | languages.append(Language(lng)) 64 | 65 | data = Data( 66 | id=response["id"], 67 | name=response["name"], 68 | dtype=DataType(response["dataType"]), 69 | dsubtype=DataSubtype(response["dataSubtype"]), 70 | privacy=Privacy.PRIVATE, 71 | languages=languages, 72 | onboard_status=response["status"], 73 | length=int(response["segmentsCount"]) 74 | if "segmentsCount" in response and response["segmentsCount"] is not None 75 | else None, 76 | ) 77 | return data 78 | 79 | @classmethod 80 | def get(cls, data_id: Text) -> Data: 81 | """Create a 'Data' object from dataset id 82 | 83 | Args: 84 | data_id (Text): Data ID of 
required dataset. 85 | 86 | Returns: 87 | Data: Created 'Data' object 88 | """ 89 | url = urljoin(cls.backend_url, f"sdk/data/{data_id}/overview") 90 | 91 | headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} 92 | logging.info(f"Start service for GET Data - {url} - {headers}") 93 | r = _request_with_retry("get", url, headers=headers) 94 | resp = r.json() 95 | if "statusCode" in resp and resp["statusCode"] == 404: 96 | raise Exception(f"Data GET Error: Data {data_id} not found.") 97 | return cls.__from_response(resp) 98 | -------------------------------------------------------------------------------- /aixplain/factories/finetune_factory/prompt_validator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Text 2 | from aixplain.modules.dataset import Dataset 3 | import re 4 | 5 | 6 | def _get_data_list(dataset: Dataset): 7 | flatten_target_values = [item for sublist in list(dataset.target_data.values()) for item in sublist] 8 | data_list = list(dataset.source_data.values()) + flatten_target_values 9 | return data_list 10 | 11 | 12 | def validate_prompt(prompt: Text, dataset_list: List[Dataset]) -> Text: 13 | result_prompt = prompt 14 | referenced_data = set(re.findall("<<(.+?)>>", prompt)) 15 | for dataset in dataset_list: 16 | data_list = _get_data_list(dataset) 17 | for data in data_list: 18 | if data.id in referenced_data: 19 | result_prompt = result_prompt.replace(f"<<{data.id}>>", f"<<{data.name}>>") 20 | referenced_data.remove(data.id) 21 | referenced_data.add(data.name) 22 | 23 | # check if dataset list has same data name and it is referenced 24 | name_set = set() 25 | for dataset in dataset_list: 26 | data_list = _get_data_list(dataset) 27 | for data in data_list: 28 | assert not ( 29 | data.name in name_set and data.name in referenced_data 30 | ), "Datasets must not have more than one referenced data with same name" 31 | name_set.add(data.name) 32 | 33 
| # check if all referenced data have a respective data in dataset list 34 | for dataset in dataset_list: 35 | data_list = _get_data_list(dataset) 36 | for data in data_list: 37 | if data.name in referenced_data: 38 | result_prompt = result_prompt.replace(f"<<{data.name}>>", f"{{{data.name}}}") 39 | referenced_data.remove(data.name) 40 | assert len(referenced_data) == 0, "Referenced data are not present in dataset list" 41 | return result_prompt 42 | -------------------------------------------------------------------------------- /aixplain/factories/index_factory/utils.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | from typing import Text, Optional, ClassVar, Dict 3 | from aixplain.enums import IndexStores, EmbeddingModel 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class BaseIndexParams(BaseModel, ABC): 8 | model_config = ConfigDict(use_enum_values=True) 9 | name: Text 10 | description: Optional[Text] = "" 11 | 12 | def to_dict(self): 13 | data = self.model_dump(exclude_none=True) 14 | data["data"] = data.pop("name") 15 | return data 16 | 17 | @property 18 | @abstractmethod 19 | def id(self) -> str: 20 | """Abstract property that must be implemented in subclasses.""" 21 | pass 22 | 23 | 24 | class BaseIndexParamsWithEmbeddingModel(BaseIndexParams, ABC): 25 | embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 26 | embedding_size: Optional[int] = None 27 | 28 | def to_dict(self): 29 | data = super().to_dict() 30 | data["model"] = data.pop("embedding_model") 31 | if data.get("embedding_size"): 32 | data["additional_params"] = {"embedding_size": data.pop("embedding_size")} 33 | return data 34 | 35 | 36 | class VectaraParams(BaseIndexParams): 37 | _id: ClassVar[str] = IndexStores.VECTARA.get_model_id() 38 | 39 | @property 40 | def id(self) -> str: 41 | return self._id 42 | 43 | 44 | class ZeroEntropyParams(BaseIndexParams): 45 | _id: ClassVar[str] = 
IndexStores.ZERO_ENTROPY.get_model_id() 46 | 47 | @property 48 | def id(self) -> str: 49 | return self._id 50 | 51 | 52 | class AirParams(BaseIndexParamsWithEmbeddingModel): 53 | _id: ClassVar[str] = IndexStores.AIR.get_model_id() 54 | 55 | @property 56 | def id(self) -> str: 57 | return self._id 58 | 59 | 60 | class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): 61 | _id: ClassVar[str] = IndexStores.GRAPHRAG.get_model_id() 62 | llm: Optional[Text] = None 63 | 64 | @property 65 | def id(self) -> str: 66 | return self._id 67 | -------------------------------------------------------------------------------- /aixplain/factories/script_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Tuple 4 | 5 | import requests 6 | 7 | from aixplain.utils import config 8 | 9 | 10 | class ScriptFactory: 11 | @classmethod 12 | def upload_script(cls, script_path: str) -> Tuple[str, str]: 13 | try: 14 | url = f"{config.BACKEND_URL}/sdk/pipelines/script" 15 | headers = {"Authorization": f"Token {config.TEAM_API_KEY}"} 16 | r = requests.post(url, headers=headers, files={"file": open(script_path, "rb")}) 17 | if 200 <= r.status_code < 300: 18 | response = r.json() 19 | else: 20 | raise Exception() 21 | except Exception: 22 | response = {"fileId": ""} 23 | 24 | # get metadata info 25 | fname = os.path.splitext(os.path.basename(script_path))[0] 26 | file_size = int(os.path.getsize(script_path)) 27 | metadata = json.dumps({"name": fname, "size": file_size}) 28 | return response["fileId"], metadata 29 | -------------------------------------------------------------------------------- /aixplain/factories/team_agent_factory/utils.py: -------------------------------------------------------------------------------- 1 | __author__ = "lucaspavanelli" 2 | 3 | import logging 4 | from typing import Dict, Text, List 5 | from urllib.parse import urljoin 6 | 7 | import aixplain.utils.config as config 8 | 
from aixplain.enums.asset_status import AssetStatus 9 | from aixplain.modules.agent import Agent 10 | from aixplain.modules.team_agent import TeamAgent, InspectorTarget 11 | 12 | 13 | GPT_4o_ID = "6646261c6eb563165658bbb1" 14 | 15 | 16 | def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = config.TEAM_API_KEY) -> TeamAgent: 17 | """Instantiate a new team agent in the platform.""" 18 | from aixplain.factories.agent_factory import AgentFactory 19 | 20 | agents_dict = payload["agents"] 21 | payload_agents = agents 22 | if payload_agents is None: 23 | payload_agents = [] 24 | for i, agent in enumerate(agents_dict): 25 | try: 26 | payload_agents.append(AgentFactory.get(agent["assetId"])) 27 | except Exception: 28 | logging.warning( 29 | f"Agent {agent['assetId']} not found. Make sure it exists or you have access to it. " 30 | "If you think this is an error, please contact the administrators." 31 | ) 32 | continue 33 | 34 | inspector_targets = [InspectorTarget(target.lower()) for target in payload.get("inspectorTargets", [])] 35 | 36 | team_agent = TeamAgent( 37 | id=payload.get("id", ""), 38 | name=payload.get("name", ""), 39 | agents=payload_agents, 40 | description=payload.get("description", ""), 41 | instructions=payload.get("role", None), 42 | supplier=payload.get("teamId", None), 43 | version=payload.get("version", None), 44 | cost=payload.get("cost", None), 45 | llm_id=payload.get("llmId", GPT_4o_ID), 46 | use_mentalist=True if payload.get("plannerId", None) is not None else False, 47 | use_inspector=True if payload.get("inspectorId", None) is not None else False, 48 | max_inspectors=payload.get("maxInspectors", 1), 49 | inspector_targets=inspector_targets, 50 | api_key=api_key, 51 | status=AssetStatus(payload["status"]), 52 | ) 53 | team_agent.url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{team_agent.id}/run") 54 | 55 | # fill up dependencies 56 | all_tasks = {} 57 | for agent in team_agent.agents: 58 | for task in 
agent.tasks: 59 | all_tasks[task.name] = task 60 | 61 | for idx, agent in enumerate(team_agent.agents): 62 | for i, task in enumerate(agent.tasks): 63 | for j, dependency in enumerate(task.dependencies or []): 64 | if isinstance(dependency, Text): 65 | task_dependency = all_tasks.get(dependency, None) 66 | if task_dependency: 67 | team_agent.agents[idx].tasks[i].dependencies[j] = task_dependency 68 | else: 69 | raise Exception(f"Team Agent Creation Error: Task dependency not found - {dependency}") 70 | return team_agent 71 | -------------------------------------------------------------------------------- /aixplain/factories/wallet_factory.py: -------------------------------------------------------------------------------- 1 | import aixplain.utils.config as config 2 | from aixplain.modules.wallet import Wallet 3 | from aixplain.utils.request_utils import _request_with_retry 4 | import logging 5 | from typing import Text 6 | 7 | 8 | class WalletFactory: 9 | backend_url = config.BACKEND_URL 10 | 11 | @classmethod 12 | def get(cls, api_key: Text = config.TEAM_API_KEY) -> Wallet: 13 | """Get wallet information""" 14 | try: 15 | resp = None 16 | url = f"{cls.backend_url}/sdk/billing/wallet" 17 | # the headers carry the API key, so they must never be written to the log 18 | headers = {"Content-Type": "application/json", "x-api-key": api_key} 19 | logging.info(f"Start fetching billing information from - {url}") 20 | r = _request_with_retry("get", url, headers=headers) 21 | resp = r.json() 22 | total_balance = float(resp.get("totalBalance", 0.0)) 23 | reserved_balance = float(resp.get("reservedBalance", 0.0)) 24 | 25 | return Wallet(total_balance=total_balance, reserved_balance=reserved_balance) 26 | except Exception as e: 27 | raise Exception(f"Failed to get the wallet credit information.
Error: {str(e)}") 28 | -------------------------------------------------------------------------------- /aixplain/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | aiXplain SDK Library. 3 | --- 4 | 5 | aiXplain SDK enables python programmers to add AI functions 6 | to their software. 7 | 8 | Copyright 2022 The aiXplain SDK authors 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 21 | """ 22 | from .asset import Asset 23 | from .corpus import Corpus 24 | from .data import Data 25 | from .dataset import Dataset 26 | from .file import File 27 | from .metadata import MetaData 28 | from .metric import Metric 29 | from .model import Model 30 | from .model.llm_model import LLM 31 | from .pipeline import Pipeline 32 | from .finetune import Finetune, FinetuneCost 33 | from .finetune.status import FinetuneStatus 34 | from .benchmark import Benchmark 35 | from .benchmark_job import BenchmarkJob 36 | from .agent import Agent 37 | from .agent.tool import Tool 38 | from .team_agent import TeamAgent 39 | from .api_key import APIKey, APIKeyLimits, APIKeyUsageLimit 40 | from .model.index_model import IndexModel 41 | -------------------------------------------------------------------------------- /aixplain/modules/agent/agent_response.py: -------------------------------------------------------------------------------- 1 | from aixplain.enums import ResponseStatus 2 | from typing import Any, Dict, 
Optional, Text, Union, List 3 | from aixplain.modules.agent.agent_response_data import AgentResponseData 4 | from aixplain.modules.model.response import ModelResponse 5 | 6 | 7 | class AgentResponse(ModelResponse): 8 | def __init__( 9 | self, 10 | status: ResponseStatus = ResponseStatus.FAILED, 11 | data: Optional[AgentResponseData] = None, 12 | details: Optional[Union[Dict, List]] = {}, 13 | completed: bool = False, 14 | error_message: Text = "", 15 | used_credits: float = 0.0, 16 | run_time: float = 0.0, 17 | usage: Optional[Dict] = None, 18 | url: Optional[Text] = None, 19 | **kwargs, 20 | ): 21 | 22 | super().__init__( 23 | status=status, 24 | data="", 25 | details=details, 26 | completed=completed, 27 | error_message=error_message, 28 | used_credits=used_credits, 29 | run_time=run_time, 30 | usage=usage, 31 | url=url, 32 | **kwargs, 33 | ) 34 | self.data = data or AgentResponseData() 35 | 36 | def __getitem__(self, key: Text) -> Any: 37 | if key == "data": 38 | return self.data.to_dict() 39 | return super().__getitem__(key) 40 | 41 | def __setitem__(self, key: Text, value: Any) -> None: 42 | if key == "data" and isinstance(value, Dict): 43 | self.data = AgentResponseData.from_dict(value) 44 | elif key == "data" and isinstance(value, AgentResponseData): 45 | self.data = value 46 | else: 47 | super().__setitem__(key, value) 48 | 49 | def to_dict(self) -> Dict[Text, Any]: 50 | base_dict = super().to_dict() 51 | base_dict["data"] = self.data.to_dict() 52 | return base_dict 53 | 54 | def __repr__(self) -> str: 55 | fields = super().__repr__()[len("ModelResponse(") : -1] 56 | return f"AgentResponse({fields})" 57 | -------------------------------------------------------------------------------- /aixplain/modules/agent/agent_response_data.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional, Text 2 | 3 | 4 | class AgentResponseData: 5 | def __init__( 6 | self, 7 | input: Optional[Any] = 
None, 8 | output: Optional[Any] = None, 9 | session_id: str = "", 10 | intermediate_steps: Optional[List[Any]] = None, 11 | execution_stats: Optional[Dict[str, Any]] = None, 12 | ): 13 | self.input = input 14 | self.output = output 15 | self.session_id = session_id 16 | self.intermediate_steps = intermediate_steps or [] 17 | self.execution_stats = execution_stats 18 | 19 | @classmethod 20 | def from_dict(cls, data: Dict[str, Any]) -> "AgentResponseData": 21 | return cls( 22 | input=data.get("input"), 23 | output=data.get("output"), 24 | session_id=data.get("session_id", ""), 25 | intermediate_steps=data.get("intermediate_steps", []), 26 | execution_stats=data.get("executionStats"), 27 | ) 28 | 29 | def to_dict(self) -> Dict[str, Any]: 30 | return { 31 | "input": self.input, 32 | "output": self.output, 33 | "session_id": self.session_id, 34 | "intermediate_steps": self.intermediate_steps, 35 | "executionStats": self.execution_stats, 36 | "execution_stats": self.execution_stats, 37 | } 38 | 39 | def __getitem__(self, key): 40 | return getattr(self, key, None) 41 | 42 | def __setitem__(self, key, value): 43 | if hasattr(self, key): 44 | setattr(self, key, value) 45 | else: 46 | raise KeyError(f"{key} is not a valid attribute of {self.__class__.__name__}") 47 | 48 | def __repr__(self) -> str: 49 | return ( 50 | f"{self.__class__.__name__}(" 51 | f"input={self.input}, " 52 | f"output={self.output}, " 53 | f"session_id='{self.session_id}', " 54 | f"intermediate_steps={self.intermediate_steps}, " 55 | f"execution_stats={self.execution_stats})" 56 | ) 57 | 58 | def __contains__(self, key: Text) -> bool: 59 | try: 60 | self[key] 61 | return True 62 | except KeyError: 63 | return False 64 | -------------------------------------------------------------------------------- /aixplain/modules/agent/agent_task.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Text, Union 2 | 3 | 4 | class AgentTask: 5 | def 
__init__( 6 | self, 7 | name: Text, 8 | description: Text, 9 | expected_output: Text, 10 | dependencies: Optional[List[Union[Text, "AgentTask"]]] = None, 11 | ): 12 | self.name = name 13 | self.description = description 14 | self.expected_output = expected_output 15 | self.dependencies = dependencies 16 | 17 | def to_dict(self): 18 | agent_task_dict = { 19 | "name": self.name, 20 | "description": self.description, 21 | "expectedOutput": self.expected_output, 22 | "dependencies": None if self.dependencies is None else list(self.dependencies), 23 | } 24 | # resolve AgentTask -> name on the copy above so self.dependencies is not mutated 25 | if self.dependencies: 26 | for i, dependency in enumerate(agent_task_dict["dependencies"]): 27 | if isinstance(dependency, AgentTask): 28 | agent_task_dict["dependencies"][i] = dependency.name 29 | return agent_task_dict 30 | -------------------------------------------------------------------------------- /aixplain/modules/agent/output_format.py: -------------------------------------------------------------------------------- 1 | __author__ = "thiagocastroferreira" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License.
17 | 18 | Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 19 | Date: February 21st 2024 20 | Description: 21 | Asset Enum 22 | """ 23 | 24 | from enum import Enum 25 | from typing import Text 26 | 27 | 28 | class OutputFormat(Text, Enum): 29 | MARKDOWN = "markdown" 30 | TEXT = "text" 31 | -------------------------------------------------------------------------------- /aixplain/modules/agent/tool/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Lucas Pavanelli and Thiago Castro Ferreira 19 | Date: May 16th 2024 20 | Description: 21 | Agentification Class 22 | """ 23 | from abc import ABC 24 | from typing import Optional, Text 25 | from aixplain.utils import config 26 | from aixplain.enums import AssetStatus 27 | 28 | 29 | class Tool(ABC): 30 | """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. 
31 | 32 | Attributes: 33 | name (Text): name of the tool 34 | description (Text): description of the tool 35 | version (Text): version of the tool 36 | """ 37 | 38 | def __init__( 39 | self, 40 | name: Text, 41 | description: Text, 42 | version: Optional[Text] = None, 43 | api_key: Optional[Text] = config.TEAM_API_KEY, 44 | status: Optional[AssetStatus] = AssetStatus.DRAFT, 45 | **additional_info, 46 | ) -> None: 47 | """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. 48 | 49 | Args: 50 | name (Text): name of the tool 51 | description (Text): descriptiion of the tool 52 | version (Text): version of the tool 53 | api_key (Text): api key of the tool. Defaults to config.TEAM_API_KEY. 54 | """ 55 | self.name = name 56 | self.description = description 57 | self.version = version 58 | self.api_key = api_key 59 | self.additional_info = additional_info 60 | self.status = status 61 | 62 | def to_dict(self): 63 | """Converts the tool to a dictionary.""" 64 | raise NotImplementedError 65 | 66 | def validate(self): 67 | raise NotImplementedError 68 | -------------------------------------------------------------------------------- /aixplain/modules/agent/tool/custom_python_code_tool.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Lucas Pavanelli and Thiago Castro Ferreira 19 | Date: May 16th 2024 20 | Description: 21 | Agentification Class 22 | """ 23 | 24 | from typing import Text, Union, Callable, Optional 25 | from aixplain.modules.agent.tool import Tool 26 | import logging 27 | from aixplain.enums import AssetStatus 28 | 29 | 30 | class CustomPythonCodeTool(Tool): 31 | """Custom Python Code Tool""" 32 | 33 | def __init__( 34 | self, code: Union[Text, Callable], description: Text = "", name: Optional[Text] = None, **additional_info 35 | ) -> None: 36 | """Custom Python Code Tool""" 37 | super().__init__(name=name or "", description=description, **additional_info) 38 | self.code = code 39 | self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool 40 | 41 | self.validate() 42 | 43 | def to_dict(self): 44 | return { 45 | "name": self.name, 46 | "description": self.description, 47 | "type": "utility", 48 | "utility": "custom_python_code", 49 | "utilityCode": self.code, 50 | } 51 | 52 | def validate(self): 53 | from aixplain.modules.model.utils import parse_code_decorated 54 | 55 | if not str(self.code).startswith("s3://"): 56 | self.code, _, description, name = parse_code_decorated(self.code) 57 | else: 58 | logging.info("Utility Model Already Exists, skipping code validation") 59 | return 60 | 61 | # Set description from parsed code if not already set 62 | if not self.description or self.description.strip() == "": 63 | self.description = description 64 | # Set name from parsed code if could find it 65 | if name and name.strip() != "": 66 | self.name = name 67 | 68 | assert ( 69 | self.description and self.description.strip() != "" 70 | ), "Custom Python Code Tool Error: Tool description is required" 71 | assert self.code and self.code.strip() != "", "Custom Python Code Tool Error: Code is required" 72 | assert self.name and 
self.name.strip() != "", "Custom Python Code Tool Error: Name is required" 73 | assert self.status in [ 74 | AssetStatus.DRAFT, 75 | AssetStatus.ONBOARDED, 76 | ], "Custom Python Code Tool Error: Status must be DRAFT or ONBOARDED" 77 | 78 | 79 | 80 | def __repr__(self) -> Text: 81 | return f"CustomPythonCodeTool(name={self.name})" 82 | -------------------------------------------------------------------------------- /aixplain/modules/agent/tool/pipeline_tool.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Lucas Pavanelli and Thiago Castro Ferreira 19 | Date: May 16th 2024 20 | Description: 21 | Agentification Class 22 | """ 23 | from typing import Text, Union, Optional 24 | 25 | from aixplain.modules.agent.tool import Tool 26 | from aixplain.modules.pipeline import Pipeline 27 | from aixplain.enums import AssetStatus 28 | 29 | 30 | class PipelineTool(Tool): 31 | """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. 
32 | 33 | Attributes: 34 | description (Text): description of the tool 35 | pipeline (Union[Text, Pipeline]): pipeline 36 | """ 37 | 38 | def __init__( 39 | self, 40 | description: Text, 41 | pipeline: Union[Text, Pipeline], 42 | name: Optional[Text] = None, 43 | **additional_info, 44 | ) -> None: 45 | """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. 46 | 47 | Args: 48 | description (Text): description of the tool 49 | pipeline (Union[Text, Pipeline]): pipeline 50 | """ 51 | name = name or "" 52 | super().__init__(name=name, description=description, **additional_info) 53 | 54 | self.status = AssetStatus.DRAFT 55 | 56 | self.pipeline = pipeline 57 | self.validate() 58 | 59 | def to_dict(self): 60 | return { 61 | "assetId": self.pipeline.id if isinstance(self.pipeline, Pipeline) else self.pipeline, 62 | "name": self.name, 63 | "description": self.description, 64 | "type": "pipeline", 65 | "status": self.status, 66 | } 67 | 68 | def __repr__(self) -> Text: 69 | return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" 70 | 71 | def validate(self): 72 | from aixplain.factories.pipeline_factory import PipelineFactory 73 | 74 | if isinstance(self.pipeline, Pipeline): 75 | pipeline_obj = self.pipeline 76 | else: 77 | try: 78 | pipeline_obj = PipelineFactory.get(self.pipeline, api_key=self.api_key) 79 | except Exception: 80 | raise Exception( 81 | f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it."
82 | ) 83 | 84 | if self.name.strip() == "": 85 | self.name = pipeline_obj.name 86 | self.status = pipeline_obj.status 87 | 88 | -------------------------------------------------------------------------------- /aixplain/modules/agent/tool/python_interpreter_tool.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Lucas Pavanelli and Thiago Castro Ferreira 19 | Date: May 16th 2024 20 | Description: 21 | Agentification Class 22 | """ 23 | 24 | from aixplain.modules.agent.tool import Tool 25 | from aixplain.enums import AssetStatus 26 | 27 | from typing import Text 28 | 29 | 30 | class PythonInterpreterTool(Tool): 31 | """Python Interpreter Tool""" 32 | 33 | def __init__(self, **additional_info) -> None: 34 | """Python Interpreter Tool""" 35 | description = "A Python shell. Use this to execute python commands. Input should be a valid python command." 
36 | super().__init__(name="Python Interpreter", description=description, **additional_info) 37 | self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool 38 | 39 | 40 | def to_dict(self): 41 | return { 42 | "description": self.description, 43 | "type": "utility", 44 | "utility": "custom_python_code", 45 | } 46 | 47 | def validate(self): 48 | pass 49 | 50 | def __repr__(self) -> Text: 51 | return "PythonInterpreterTool()" 52 | -------------------------------------------------------------------------------- /aixplain/modules/agent/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Text, Union 2 | import re 3 | 4 | 5 | def process_variables(query: Union[Text, Dict], data: Union[Dict, Text], parameters: Dict, agent_description: Text) -> Text: 6 | from aixplain.factories.file_factory import FileFactory 7 | 8 | if isinstance(query, dict): 9 | for key, value in query.items(): 10 | assert isinstance(value, str), "When providing a dictionary, all values must be strings." 11 | query[key] = FileFactory.to_link(value) 12 | input_data = query 13 | else: 14 | input_data = {"input": FileFactory.to_link(query)} 15 | 16 | variables = re.findall(r"(? None: 41 | """Create an Asset with the necessary information 42 | 43 | Args: 44 | id (Text): ID of the Asset 45 | name (Text): Name of the Asset 46 | description (Text): Description of the Asset 47 | supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". 48 | version (Optional[Text], optional): asset version. Defaults to "1.0". 49 | cost (Optional[Union[Dict, float]], optional): asset price. Defaults to None. 
50 | """ 51 | self.id = id 52 | self.name = name 53 | self.description = description 54 | try: 55 | if isinstance(supplier, Supplier) is True: 56 | self.supplier = supplier 57 | elif isinstance(supplier, Dict) is True: 58 | self.supplier = Supplier(supplier) 59 | else: 60 | self.supplier = None 61 | for supplier_ in Supplier: 62 | if supplier.lower() in [supplier_.value["code"].lower(), supplier_.value["name"].lower()]: 63 | self.supplier = supplier_ 64 | break 65 | if self.supplier is None: 66 | self.supplier = supplier 67 | except Exception: 68 | self.supplier = Supplier.AIXPLAIN 69 | self.version = version 70 | self.license = license 71 | self.privacy = privacy 72 | self.cost = cost 73 | 74 | def to_dict(self) -> dict: 75 | """Get the asset info as a Dictionary 76 | 77 | Returns: 78 | dict: Asset Information 79 | """ 80 | return self.__dict__ 81 | -------------------------------------------------------------------------------- /aixplain/modules/content_interval.py: -------------------------------------------------------------------------------- 1 | __author__ = "thiagocastroferreira" 2 | 3 | """ 4 | Copyright 2023 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: aiXplain team 19 | Date: June 6th 2023 20 | Description: 21 | Content Interval 22 | """ 23 | 24 | from dataclasses import dataclass 25 | from typing import List, Optional, Text, Tuple, Union 26 | 27 | 28 | @dataclass 29 | class ContentInterval: 30 | content: Text 31 | content_id: int 32 | 33 | 34 | @dataclass 35 | class TextContentInterval(ContentInterval): 36 | start: Union[int, Tuple[int, int]] 37 | end: Union[int, Tuple[int, int]] 38 | 39 | 40 | @dataclass 41 | class AudioContentInterval(ContentInterval): 42 | start_time: float 43 | end_time: float 44 | 45 | 46 | @dataclass 47 | class ImageContentInterval(ContentInterval): 48 | x: Union[float, List[float]] 49 | y: Union[float, List[float]] 50 | width: Optional[float] = None 51 | height: Optional[float] = None 52 | rotation: Optional[float] = None 53 | 54 | 55 | @dataclass 56 | class VideoContentInterval(ContentInterval): 57 | start_time: float 58 | end_time: float 59 | x: Optional[Union[float, List[float]]] = None 60 | y: Optional[Union[float, List[float]]] = None 61 | width: Optional[float] = None 62 | height: Optional[float] = None 63 | rotation: Optional[float] = None 64 | -------------------------------------------------------------------------------- /aixplain/modules/file.py: -------------------------------------------------------------------------------- 1 | __author__ = "aiXplain" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain team 19 | Date: March 20th 2023 20 | Description: 21 | File Class 22 | """ 23 | 24 | import pathlib 25 | 26 | from aixplain.enums.data_split import DataSplit 27 | from aixplain.enums.file_type import FileType 28 | from typing import Optional, Text, Union 29 | 30 | 31 | class File: 32 | def __init__( 33 | self, 34 | path: Union[Text, pathlib.Path], 35 | extension: Union[Text, FileType], 36 | data_split: Optional[DataSplit] = None, 37 | compression: Optional[Text] = None, 38 | ) -> None: 39 | """File Class 40 | 41 | Description: 42 | File where samples of a data is stored in 43 | 44 | Args: 45 | path (Union[Text, pathlib.Path]): File path 46 | extension (Union[Text, FileType]): File extension (e.g. CSV, TXT, etc.) 47 | data_split (Optional[DataSplit], optional): Data split of the file. Defaults to None. 48 | compression (Optional[Text], optional): Compression extension (e.g., .gz). Defaults to None. 49 | """ 50 | self.path = path 51 | 52 | if isinstance(extension, FileType): 53 | self.extension = extension 54 | else: 55 | try: 56 | self.extension = FileType(extension) 57 | except Exception as e: 58 | raise Exception("File Error: This file extension is not supported.") 59 | 60 | self.compression = compression 61 | self.data_split = data_split 62 | -------------------------------------------------------------------------------- /aixplain/modules/finetune/cost.py: -------------------------------------------------------------------------------- 1 | __author__ = "lucaspavanelli" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 
8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 19 | Date: June 14th 2023 20 | Description: 21 | FinetuneCost Class 22 | """ 23 | 24 | from typing import Dict 25 | 26 | 27 | class FinetuneCost: 28 | def __init__( 29 | self, 30 | training: Dict, 31 | inference: Dict, 32 | hosting: Dict, 33 | ) -> None: 34 | """Create a FinetuneCost object with training, inference, and hosting cost information. 35 | 36 | Args: 37 | training (Dict): Dictionary containing training cost information. 38 | inference (Dict): Dictionary containing inference cost information. 39 | hosting (Dict): Dictionary containing hosting cost information. 40 | """ 41 | self.training = training 42 | self.inference = inference 43 | self.hosting = hosting 44 | 45 | def to_dict(self) -> Dict: 46 | """Convert the FinetuneCost object to a dictionary. 47 | 48 | Returns: 49 | Dict: A dictionary representation of the FinetuneCost object. 
50 | """ 51 | return {"trainingCost": self.training, "inferenceCost": self.inference, "hostingCost": self.hosting} 52 | -------------------------------------------------------------------------------- /aixplain/modules/finetune/hyperparameters.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from dataclasses_json import dataclass_json 3 | from enum import Enum 4 | from typing import Text 5 | 6 | 7 | class SchedulerType(Text, Enum): 8 | LINEAR = "linear" 9 | COSINE = "cosine" 10 | COSINE_WITH_RESTARTS = "cosine_with_restarts" 11 | POLYNOMIAL = "polynomial" 12 | CONSTANT = "constant" 13 | CONSTANT_WITH_WARMUP = "constant_with_warmup" 14 | INVERSE_SQRT = "inverse_sqrt" 15 | REDUCE_ON_PLATEAU = "reduce_lr_on_plateau" 16 | 17 | 18 | EPOCHS_MAX_VALUE = 4 19 | BATCH_SIZE_VALUES = [1, 2, 4, 8, 16, 32, 64] 20 | MAX_SEQ_LENGTH_MAX_VALUE = 4096 21 | 22 | 23 | @dataclass_json 24 | @dataclass 25 | class Hyperparameters(object): 26 | epochs: int = 1 27 | train_batch_size: int = 4 28 | eval_batch_size: int = 4 29 | learning_rate: float = 1e-5 30 | max_seq_length: int = 4096 31 | warmup_ratio: float = 0.0 32 | warmup_steps: int = 0 33 | lr_scheduler_type: SchedulerType = SchedulerType.LINEAR 34 | 35 | def __post_init__(self): 36 | if not isinstance(self.epochs, int): 37 | raise TypeError("epochs should be of type int") 38 | 39 | if not isinstance(self.train_batch_size, int): 40 | raise TypeError("train_batch_size should be of type int") 41 | 42 | if not isinstance(self.eval_batch_size, int): 43 | raise TypeError("eval_batch_size should be of type int") 44 | 45 | if not isinstance(self.learning_rate, float): 46 | raise TypeError("learning_rate should be of type float") 47 | 48 | if not isinstance(self.max_seq_length, int): 49 | raise TypeError("max_seq_length should be of type int") 50 | 51 | if not isinstance(self.warmup_ratio, float): 52 | raise TypeError("warmup_ratio should be of type float") 53 | 54 
| if not isinstance(self.warmup_steps, int): 55 | raise TypeError("warmup_steps should be of type int") 56 | 57 | if not isinstance(self.lr_scheduler_type, SchedulerType): 58 | raise TypeError("lr_scheduler_type should be of type SchedulerType") 59 | 60 | if self.epochs > EPOCHS_MAX_VALUE: 61 | raise ValueError(f"epochs must be less or equal to {EPOCHS_MAX_VALUE}") 62 | 63 | if self.train_batch_size not in BATCH_SIZE_VALUES: 64 | raise ValueError(f"train_batch_size must be one of the following values: {BATCH_SIZE_VALUES}") 65 | 66 | if self.eval_batch_size not in BATCH_SIZE_VALUES: 67 | raise ValueError(f"eval_batch_size must be one of the following values: {BATCH_SIZE_VALUES}") 68 | 69 | if self.max_seq_length > MAX_SEQ_LENGTH_MAX_VALUE: 70 | raise ValueError(f"max_seq_length must be less or equal to {MAX_SEQ_LENGTH_MAX_VALUE}") 71 | -------------------------------------------------------------------------------- /aixplain/modules/finetune/status.py: -------------------------------------------------------------------------------- 1 | __author__ = "thiagocastroferreira" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | 18 | Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 19 | Date: February 21st 2024 20 | Description: 21 | FinetuneCost Class 22 | """ 23 | 24 | from aixplain.enums.asset_status import AssetStatus 25 | from dataclasses import dataclass 26 | from dataclasses_json import dataclass_json 27 | from typing import Optional, Text 28 | 29 | 30 | @dataclass_json 31 | @dataclass 32 | class FinetuneStatus(object): 33 | status: "AssetStatus" 34 | model_status: "AssetStatus" 35 | epoch: Optional[float] = None 36 | training_loss: Optional[float] = None 37 | validation_loss: Optional[float] = None 38 | -------------------------------------------------------------------------------- /aixplain/modules/mixins.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2024 The aiXplain SDK authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | 16 | Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli 17 | Date: November 25th 2024 18 | Description: 19 | Mixins for common functionality across different asset types 20 | """ 21 | from abc import ABC 22 | from typing import TypeVar, Generic 23 | from aixplain.enums import AssetStatus 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | class DeployableMixin(ABC, Generic[T]): 29 | """A mixin that provides common deployment-related functionality for assets. 30 | 31 | This mixin provides methods for: 32 | 1. 
Filtering items that are not onboarded 33 | 2. Validating if an asset is ready to be deployed 34 | 3. Deploying an asset 35 | 36 | Classes that inherit from this mixin should: 37 | 1. Implement _validate_deployment_readiness to call the parent implementation with their specific asset type 38 | 2. Optionally override deploy() if they need special deployment handling 39 | """ 40 | 41 | def _validate_deployment_readiness(self) -> None: 42 | """Validate if the asset is ready to be deployed. 43 | 44 | Args: 45 | asset_type (str): Type of asset being validated (e.g. "Agent", "Team Agent", "Pipeline") 46 | items (Optional[List[T]], optional): List of items to validate (e.g. tools for Agent, agents for TeamAgent) 47 | 48 | Raises: 49 | ValueError: If the asset is not ready to be deployed 50 | """ 51 | asset_type = self.__class__.__name__ 52 | if self.status == AssetStatus.ONBOARDED: 53 | raise ValueError(f"{asset_type} is already deployed.") 54 | 55 | if self.status != AssetStatus.DRAFT: 56 | raise ValueError(f"{asset_type} must be in DRAFT status to be deployed.") 57 | 58 | def deploy(self) -> None: 59 | """Deploy the asset. 60 | 61 | This method validates that the asset is ready to be deployed and updates its status to ONBOARDED. 62 | Classes that need special deployment handling should override this method. 
63 | 64 | Raises: 65 | ValueError: If the asset is not ready to be deployed 66 | """ 67 | self._validate_deployment_readiness() 68 | previous_status = self.status 69 | try: 70 | self.status = AssetStatus.ONBOARDED 71 | self.update() 72 | except Exception as e: 73 | self.status = previous_status 74 | raise Exception(f"Error deploying because of backend error: {e}") from e 75 | -------------------------------------------------------------------------------- /aixplain/modules/model/model_parameters.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from aixplain.base.parameters import BaseParameters, Parameter 3 | 4 | 5 | class ModelParameters(BaseParameters): 6 | def __init__(self, input_params: Dict[str, Dict[str, Any]]) -> None: 7 | """Initialize ModelParameters with input parameters dictionary. 8 | 9 | Args: 10 | input_params (Dict[str, Dict[str, Any]]): Dictionary containing parameter configurations 11 | """ 12 | super().__init__() 13 | for param_name, param_config in input_params.items(): 14 | self.parameters[param_name] = Parameter(name=param_name, required=param_config["required"]) 15 | -------------------------------------------------------------------------------- /aixplain/modules/model/model_response_streamer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterator 3 | 4 | from aixplain.modules.model.response import ModelResponse, ResponseStatus 5 | 6 | 7 | class ModelResponseStreamer: 8 | def __init__(self, iterator: Iterator): 9 | self.iterator = iterator 10 | self.status = ResponseStatus.IN_PROGRESS 11 | 12 | def __next__(self): 13 | """ 14 | Returns the next chunk of the response. 
15 | """ 16 | line = next(self.iterator).replace("data: ", "") 17 | try: 18 | data = json.loads(line) 19 | except json.JSONDecodeError: 20 | data = {"data": line} 21 | content = data.get("data", "") 22 | if content == "[DONE]": 23 | self.status = ResponseStatus.SUCCESS 24 | content = "" 25 | return ModelResponse(status=self.status, data=content) 26 | 27 | def __iter__(self): 28 | return self 29 | -------------------------------------------------------------------------------- /aixplain/modules/model/record.py: -------------------------------------------------------------------------------- 1 | from aixplain.enums import DataType, StorageType 2 | from typing import Optional 3 | from uuid import uuid4 4 | 5 | 6 | class Record: 7 | def __init__( 8 | self, 9 | value: str = "", 10 | value_type: DataType = DataType.TEXT, 11 | id: Optional[str] = None, 12 | uri: str = "", 13 | attributes: dict = {}, 14 | ): 15 | self.value = value 16 | self.value_type = value_type 17 | self.id = id if id is not None else str(uuid4()) 18 | self.uri = uri 19 | self.attributes = attributes 20 | 21 | def to_dict(self): 22 | return { 23 | "data": self.value, 24 | "dataType": str(self.value_type), 25 | "document_id": self.id, 26 | "uri": self.uri, 27 | "attributes": self.attributes, 28 | } 29 | 30 | def validate(self): 31 | """Validate the record""" 32 | from aixplain.factories import FileFactory 33 | from aixplain.modules.model.utils import is_supported_image_type 34 | 35 | assert self.value_type in [DataType.TEXT, DataType.IMAGE], "Index Upsert Error: Invalid value type" 36 | if self.value_type == DataType.IMAGE: 37 | assert self.uri is not None and self.uri != "", "Index Upsert Error: URI is required for image records" 38 | else: 39 | assert self.value is not None and self.value != "", "Index Upsert Error: Value is required for text records" 40 | 41 | storage_type = FileFactory.check_storage_type(self.uri) 42 | 43 | # Check if value is an image file or URL 44 | if storage_type in 
[StorageType.FILE, StorageType.URL]: 45 | if is_supported_image_type(self.uri): 46 | self.value_type = DataType.IMAGE 47 | self.uri = FileFactory.to_link(self.uri) if storage_type == StorageType.FILE else self.uri 48 | else: 49 | raise Exception(f"Index Upsert Error: Unsupported file type ({self.uri})") 50 | -------------------------------------------------------------------------------- /aixplain/modules/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import Pipeline 2 | 3 | __all__ = ["Pipeline"] 4 | -------------------------------------------------------------------------------- /aixplain/modules/pipeline/default.py: -------------------------------------------------------------------------------- 1 | from .asset import Pipeline as PipelineAsset 2 | from .designer import DesignerPipeline 3 | 4 | 5 | class DefaultPipeline(PipelineAsset, DesignerPipeline): 6 | def __init__(self, *args, **kwargs): 7 | PipelineAsset.__init__(self, *args, **kwargs) 8 | DesignerPipeline.__init__(self) 9 | 10 | def save(self, *args, **kwargs): 11 | self.auto_infer() 12 | self.validate() 13 | super().save(*args, **kwargs) 14 | 15 | def to_dict(self) -> dict: 16 | return self.serialize() 17 | -------------------------------------------------------------------------------- /aixplain/modules/pipeline/designer/__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import ( 2 | AssetNode, 3 | Decision, 4 | Script, 5 | Input, 6 | Output, 7 | Route, 8 | Router, 9 | BaseReconstructor, 10 | BaseSegmentor, 11 | BaseMetric, 12 | BareAsset, 13 | BareMetric, 14 | BareSegmentor, 15 | BareReconstructor, 16 | ) 17 | from .pipeline import DesignerPipeline 18 | from .base import ( 19 | Node, 20 | Link, 21 | Param, 22 | ParamProxy, 23 | InputParam, 24 | OutputParam, 25 | Inputs, 26 | Outputs, 27 | TI, 28 | TO, 29 | ) 30 | from .enums import ( 31 | ParamType, 32 | RouteType, 33 | 
Operation, 34 | NodeType, 35 | AssetType, 36 | FunctionType, 37 | ) 38 | from .mixins import LinkableMixin, OutputableMixin, RoutableMixin 39 | 40 | 41 | __all__ = [ 42 | "DesignerPipeline", 43 | "AssetNode", 44 | "BareAsset", 45 | "Decision", 46 | "Script", 47 | "Input", 48 | "Output", 49 | "Route", 50 | "Router", 51 | "BaseReconstructor", 52 | "BaseSegmentor", 53 | "Node", 54 | "Link", 55 | "Param", 56 | "ParamType", 57 | "InputParam", 58 | "OutputParam", 59 | "RouteType", 60 | "Operation", 61 | "NodeType", 62 | "AssetType", 63 | "FunctionType", 64 | "LinkableMixin", 65 | "OutputableMixin", 66 | "RoutableMixin", 67 | "Inputs", 68 | "Outputs", 69 | "ParamProxy", 70 | "TI", 71 | "TO", 72 | "BaseMetric", 73 | "BareMetric", 74 | ] 75 | -------------------------------------------------------------------------------- /aixplain/modules/pipeline/designer/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class RouteType(str, Enum): 5 | CHECK_TYPE = "checkType" 6 | CHECK_VALUE = "checkValue" 7 | 8 | 9 | class Operation(str, Enum): 10 | GREATER_THAN = "greaterThan" 11 | GREATER_THAN_OR_EQUAL = "greaterThanOrEqual" 12 | LESS_THAN = "lessThan" 13 | LESS_THAN_OR_EQUAL = "lessThanOrEqual" 14 | EQUAL = "equal" 15 | DIFFERENT = "different" 16 | CONTAIN = "contain" 17 | NOT_CONTAIN = "notContain" 18 | 19 | 20 | class NodeType(str, Enum): 21 | ASSET = "ASSET" 22 | INPUT = "INPUT" 23 | OUTPUT = "OUTPUT" 24 | SCRIPT = "SCRIPT" 25 | ROUTER = "ROUTER" 26 | DECISION = "DECISION" 27 | 28 | 29 | class AssetType(str, Enum): 30 | MODEL = "MODEL" 31 | 32 | 33 | class FunctionType(str, Enum): 34 | AI = "ai" 35 | SEGMENTOR = "segmentor" 36 | RECONSTRUCTOR = "reconstructor" 37 | UTILITY = "utility" 38 | METRIC = "metric" 39 | 40 | 41 | class ParamType: 42 | INPUT = "INPUT" 43 | OUTPUT = "OUTPUT" 44 | -------------------------------------------------------------------------------- 
/aixplain/modules/pipeline/designer/mixins.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from .base import Node, Link, Param 3 | 4 | 5 | class LinkableMixin: 6 | """ 7 | Linkable mixin class, this class will be used to link the output of the 8 | node to the input of another node. 9 | 10 | This class will be used to link the output of the node to the input of 11 | another node. 12 | """ 13 | 14 | def link( 15 | self, 16 | to_node: Node, 17 | from_param: Union[str, Param], 18 | to_param: Union[str, Param], 19 | ) -> Link: 20 | """ 21 | Link the output of the node to the input of another node. This method 22 | will link the output of the node to the input of another node. 23 | 24 | :param to_node: the node to link to the output 25 | :param from_param: the output parameter or the code of the output 26 | parameter 27 | :param to_param: the input parameter or the code of the input parameter 28 | :return: the link 29 | """ 30 | return Link( 31 | pipeline=self.pipeline, 32 | from_node=self, 33 | to_node=to_node, 34 | from_param=from_param, 35 | to_param=to_param, 36 | ) 37 | 38 | 39 | class RoutableMixin: 40 | """ 41 | Routable mixin class, this class will be used to route the input data to 42 | different nodes based on the input data type. 43 | """ 44 | 45 | def route(self, *params: Param) -> Node: 46 | """ 47 | Route the input data to different nodes based on the input data type. 48 | This method will automatically link the input data to the output data 49 | of the node. 
50 | 51 | :param params: the output parameters 52 | :return: the router node 53 | """ 54 | assert self.pipeline, "Node not attached to a pipeline" 55 | 56 | router = self.pipeline.router([(param.data_type, param.node) for param in params]) 57 | self.outputs.input.link(router.inputs.input) 58 | for param in params: 59 | router.outputs.input.link(param) 60 | return router 61 | 62 | 63 | class OutputableMixin: 64 | """ 65 | Outputable mixin class, this class will be used to link the output of the 66 | node to the output node of the pipeline. 67 | """ 68 | 69 | def use_output(self, param: Union[str, Param]) -> Node: 70 | """ 71 | Use the output of the node as the output of the pipeline. 72 | This method will automatically link the output of the node to the 73 | output node of the pipeline. 74 | 75 | :param param: the output parameter or the code of the output parameter 76 | :return: the output node 77 | """ 78 | assert self.pipeline, "Node not attached to a pipeline" 79 | output = self.pipeline.output() 80 | if isinstance(param, str): 81 | param = self.outputs[param] 82 | param.link(output.inputs.output) 83 | return output 84 | -------------------------------------------------------------------------------- /aixplain/modules/pipeline/designer/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | 5 | def find_prompt_params(prompt: str) -> List[str]: 6 | """ 7 | This method will find the prompt parameters in the prompt string. 
8 | 9 | :param prompt: the prompt string 10 | :return: list of prompt parameters 11 | """ 12 | param_regex = re.compile(r"\{\{([^\}]+)\}\}") 13 | return param_regex.findall(prompt) 14 | -------------------------------------------------------------------------------- /aixplain/modules/pipeline/response.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Optional, Dict, Text 3 | from aixplain.enums import ResponseStatus 4 | 5 | 6 | @dataclass 7 | class PipelineResponse: 8 | def __init__( 9 | self, 10 | status: ResponseStatus, 11 | error: Optional[Dict[str, Any]] = None, 12 | elapsed_time: Optional[float] = 0.0, 13 | data: Optional[Text] = None, 14 | url: Optional[Text] = "", 15 | **kwargs, 16 | ): 17 | self.status = status 18 | self.error = error 19 | self.elapsed_time = elapsed_time 20 | self.data = data 21 | self.additional_fields = kwargs 22 | self.url = url 23 | 24 | def __getattr__(self, key: str) -> Any: 25 | if self.additional_fields and key in self.additional_fields: 26 | return self.additional_fields[key] 27 | 28 | raise AttributeError() 29 | 30 | def get(self, key: str, default: Any = None) -> Any: 31 | return getattr(self, key, default) 32 | 33 | def __getitem__(self, key: str) -> Any: 34 | return getattr(self, key) 35 | 36 | def __repr__(self) -> str: 37 | fields = [] 38 | if self.status: 39 | fields.append(f"status={self.status}") 40 | if self.error: 41 | fields.append(f"error={self.error}") 42 | if self.elapsed_time is not None: 43 | fields.append(f"elapsed_time={self.elapsed_time}") 44 | if self.data: 45 | fields.append(f"data={self.data}") 46 | if self.additional_fields: 47 | fields.extend([f"{k}={repr(v)}" for k, v in self.additional_fields.items()]) 48 | return f"PipelineResponse({', '.join(fields)})" 49 | 50 | def __contains__(self, key: str) -> bool: 51 | return hasattr(self, key) 52 | 
-------------------------------------------------------------------------------- /aixplain/modules/wallet.py: -------------------------------------------------------------------------------- 1 | __author__ = "aixplain" 2 | 3 | """ 4 | Copyright 2024 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | 18 | Author: aiXplain Team 19 | Date: August 20th 2024 20 | Description: 21 | Wallet Class 22 | """ 23 | 24 | 25 | class Wallet: 26 | def __init__(self, total_balance: float, reserved_balance: float): 27 | """ 28 | Args: 29 | total_balance (float): total credit balance 30 | reserved_balance (float): reserved credit balance 31 | available_balance (float): available balance (total - credit) 32 | """ 33 | self.total_balance = total_balance 34 | self.reserved_balance = reserved_balance 35 | self.available_balance = total_balance - reserved_balance 36 | -------------------------------------------------------------------------------- /aixplain/processes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/aixplain/processes/__init__.py -------------------------------------------------------------------------------- /aixplain/processes/data_onboarding/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/aixplain/processes/data_onboarding/__init__.py -------------------------------------------------------------------------------- /aixplain/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | aiXplain SDK Library. 3 | --- 4 | 5 | aiXplain SDK enables python programmers to add AI functions 6 | to their software. 7 | 8 | Copyright 2022 The aiXplain SDK authors 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 
21 | """ 22 | -------------------------------------------------------------------------------- /aixplain/utils/cache_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import logging 5 | 6 | CACHE_DURATION = 24 * 60 * 60 7 | CACHE_FOLDER = ".aixplain_cache" 8 | 9 | 10 | def save_to_cache(cache_file, data): 11 | try: 12 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 13 | with open(cache_file, "w") as f: 14 | json.dump({"timestamp": time.time(), "data": data}, f) 15 | except Exception as e: 16 | logging.error(f"Failed to save cache to {cache_file}: {e}") 17 | 18 | 19 | def load_from_cache(cache_file): 20 | if os.path.exists(cache_file) is True: 21 | with open(cache_file, "r") as f: 22 | cache_data = json.load(f) 23 | if time.time() - cache_data["timestamp"] < CACHE_DURATION: 24 | return cache_data["data"] 25 | else: 26 | return None 27 | return None 28 | -------------------------------------------------------------------------------- /aixplain/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2022 The aiXplain SDK authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import os 18 | import logging 19 | import sentry_sdk 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | BACKEND_URL = os.getenv("BACKEND_URL", "https://platform-api.aixplain.com") 24 | MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute") 25 | # GET THE API KEY FROM CMD 26 | TEAM_API_KEY = os.getenv("TEAM_API_KEY", "") 27 | AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "") 28 | 29 | ENV = "dev" if "dev" in BACKEND_URL else "test" if "test" in BACKEND_URL else "prod" 30 | 31 | if not TEAM_API_KEY and not AIXPLAIN_API_KEY: 32 | raise Exception( 33 | "'TEAM_API_KEY' has not been set properly and is empty. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)" 34 | ) 35 | 36 | 37 | if AIXPLAIN_API_KEY and TEAM_API_KEY and AIXPLAIN_API_KEY != TEAM_API_KEY: 38 | raise Exception( 39 | "Conflicting API keys: 'AIXPLAIN_API_KEY' and 'TEAM_API_KEY' are both provided but do not match. Please provide only one API key." 40 | ) 41 | 42 | 43 | if AIXPLAIN_API_KEY and not TEAM_API_KEY: 44 | TEAM_API_KEY = AIXPLAIN_API_KEY 45 | 46 | PIPELINE_API_KEY = os.getenv("PIPELINE_API_KEY", "") 47 | MODEL_API_KEY = os.getenv("MODEL_API_KEY", "") 48 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") 49 | HF_TOKEN = os.getenv("HF_TOKEN", "") 50 | SENTRY_DSN = os.getenv("SENTRY_DSN") 51 | 52 | if SENTRY_DSN: 53 | sentry_sdk.init( 54 | dsn=SENTRY_DSN, 55 | environment=ENV, 56 | send_default_pii=True, 57 | ) 58 | -------------------------------------------------------------------------------- /aixplain/utils/convert_datatype_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2022 The aiXplain SDK authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from typing import Union, Dict, List 18 | from aixplain.modules.metadata import MetaData 19 | 20 | 21 | def dict_to_metadata(metadatas: List[Union[Dict, MetaData]]) -> None: 22 | 23 | """Convert all the Dicts to MetaData 24 | 25 | Args: 26 | metadatas (List[Union[Dict, MetaData]], optional): metadata of metadata information of the dataset. 27 | 28 | """ 29 | try: 30 | for i in range(len(metadatas)): 31 | if isinstance(metadatas[i], dict): 32 | metadatas[i] = MetaData(**metadatas[i]) 33 | except TypeError: 34 | raise TypeError(f"Data Asset Onboarding Error: One or more elements in the metadata_schema are not well-structured") 35 | -------------------------------------------------------------------------------- /aixplain/utils/request_utils.py: -------------------------------------------------------------------------------- 1 | from requests.adapters import HTTPAdapter, Retry 2 | import requests 3 | from typing import Text 4 | 5 | 6 | def _request_with_retry(method: Text, url: Text, **params) -> requests.Response: 7 | """Wrapper around requests with Session to retry in case it fails 8 | 9 | Args: 10 | method (Text): HTTP method, such as 'GET' or 'HEAD'. 11 | url (Text): The URL of the resource to fetch. 
12 | **params: Params to pass to request function 13 | 14 | Returns: 15 | requests.Response: Response object of the request 16 | """ 17 | session = requests.Session() 18 | retries = Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]) 19 | session.mount("https://", HTTPAdapter(max_retries=retries)) 20 | response = session.request(method=method.upper(), url=url, **params) 21 | return response 22 | -------------------------------------------------------------------------------- /aixplain/v2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/aixplain/v2/__init__.py -------------------------------------------------------------------------------- /aixplain/v2/api_key.py: -------------------------------------------------------------------------------- 1 | from .resource import ( 2 | BaseResource, 3 | BareListParams, 4 | BareGetParams, 5 | GetResourceMixin, 6 | ListResourceMixin, 7 | CreateResourceMixin, 8 | BaseCreateParams, 9 | Page, 10 | ) 11 | from datetime import datetime 12 | from typing_extensions import Unpack, NotRequired, TYPE_CHECKING 13 | from typing import Dict, List, Optional, Text, Union 14 | 15 | if TYPE_CHECKING: 16 | from aixplain.modules import APIKeyLimits, APIKeyUsageLimit 17 | 18 | 19 | class APIKeyCreateParams(BaseCreateParams): 20 | name: Text 21 | budget: int 22 | global_limits: Union[Dict, "APIKeyLimits"] 23 | asset_limits: List[Union[Dict, "APIKeyLimits"]] 24 | expires_at: datetime 25 | 26 | 27 | class APIKeyGetParams(BareGetParams): 28 | api_key: NotRequired[str] 29 | 30 | 31 | class APIKey( 32 | BaseResource, 33 | GetResourceMixin[APIKeyGetParams, "APIKey"], 34 | ListResourceMixin[BareListParams, "APIKey"], 35 | CreateResourceMixin[APIKeyCreateParams, "APIKey"], 36 | ): 37 | @classmethod 38 | def get(cls, **kwargs: Unpack[APIKeyGetParams]) -> "APIKey": 39 | from aixplain.factories import 
APIKeyFactory 40 | import aixplain.utils.config as config 41 | 42 | api_key = kwargs.get("api_key", config.TEAM_API_KEY) 43 | return APIKeyFactory.get(api_key=api_key) 44 | 45 | @classmethod 46 | def list(cls, **kwargs: Unpack[BareListParams]) -> Page["APIKey"]: 47 | from aixplain.factories import APIKeyFactory 48 | 49 | return APIKeyFactory.list(**kwargs) 50 | 51 | @classmethod 52 | def create(cls, *args, **kwargs: Unpack[APIKeyCreateParams]) -> "APIKey": 53 | from aixplain.factories import APIKeyFactory 54 | 55 | return APIKeyFactory.create(*args, **kwargs) 56 | 57 | @classmethod 58 | def update(cls, api_key: "APIKey") -> "APIKey": 59 | from aixplain.factories import APIKeyFactory 60 | 61 | return APIKeyFactory.update(api_key) 62 | 63 | @classmethod 64 | def get_usage_limits( 65 | cls, api_key: Text = None, asset_id: Optional[Text] = None 66 | ) -> List["APIKeyUsageLimit"]: 67 | from aixplain.factories import APIKeyFactory 68 | from aixplain.utils import config 69 | 70 | api_key = api_key or config.TEAM_API_KEY 71 | 72 | return APIKeyFactory.get_usage_limits(api_key, asset_id) 73 | -------------------------------------------------------------------------------- /aixplain/v2/benchmark.py: -------------------------------------------------------------------------------- 1 | from typing import List, TYPE_CHECKING 2 | from typing_extensions import Unpack 3 | 4 | from aixplain.v2.resource import ( 5 | BaseResource, 6 | GetResourceMixin, 7 | BareGetParams, 8 | CreateResourceMixin, 9 | BareCreateParams, 10 | ) 11 | 12 | if TYPE_CHECKING: 13 | from aixplain.modules.metric import Metric 14 | from aixplain.modules.model import Model 15 | from aixplain.modules.dataset import Dataset 16 | 17 | 18 | class BenchmarkCreateParams(BareCreateParams): 19 | """Parameters for creating a benchmark. 20 | 21 | Attributes: 22 | name: str: The name of the benchmark. 23 | dataset_list: List["Dataset"]: The list of datasets. 24 | model_list: List["Model"]: The list of models. 
25 | metric_list: List["Metric"]: The list of metrics. 26 | """ 27 | 28 | name: str 29 | dataset_list: List["Dataset"] 30 | model_list: List["Model"] 31 | metric_list: List["Metric"] 32 | 33 | 34 | class Benchmark( 35 | BaseResource, 36 | GetResourceMixin[BareGetParams, "Benchmark"], 37 | CreateResourceMixin[BenchmarkCreateParams, "Benchmark"], 38 | ): 39 | """Resource for benchmarks.""" 40 | 41 | RESOURCE_PATH = "sdk/benchmarks" 42 | 43 | @classmethod 44 | def get(cls, id: str, **kwargs: Unpack[BareGetParams]) -> "Benchmark": 45 | from aixplain.factories import BenchmarkFactory 46 | 47 | return BenchmarkFactory.get(benchmark_id=id) 48 | 49 | @classmethod 50 | def create(cls, *args, **kwargs: Unpack[BenchmarkCreateParams]) -> "Benchmark": 51 | from aixplain.factories import BenchmarkFactory 52 | 53 | return BenchmarkFactory.create(*args, **kwargs) 54 | 55 | @classmethod 56 | def list_normalization_options(cls, metric: "Metric", model: "Model") -> List[str]: 57 | """ 58 | List the normalization options for a metric and a model. 59 | 60 | Args: 61 | metric: "Metric": The metric. 62 | model: "Model": The model. 63 | 64 | Returns: 65 | List[str]: The list of normalization options. 66 | """ 67 | from aixplain.factories import BenchmarkFactory 68 | 69 | return BenchmarkFactory.list_normalization_options(metric, model) 70 | 71 | 72 | class BenchmarkJob( 73 | BaseResource, 74 | GetResourceMixin[BareGetParams, "BenchmarkJob"], 75 | ): 76 | """Resource for benchmark jobs.""" 77 | 78 | RESOURCE_PATH = "sdk/benchmarks/jobs" 79 | 80 | @classmethod 81 | def get(cls, **kwargs: Unpack[BareGetParams]) -> "BenchmarkJob": 82 | from aixplain.factories import BenchmarkFactory 83 | 84 | return BenchmarkFactory.get_job(job_id=kwargs["id"]) 85 | 86 | def get_scores(self) -> dict: 87 | """ 88 | Get the scores for a benchmark job. 89 | 90 | Returns: 91 | dict: The scores. 
92 | """ 93 | from aixplain.factories import BenchmarkFactory 94 | 95 | return BenchmarkFactory.get_benchmark_job_scores(self.id) 96 | -------------------------------------------------------------------------------- /aixplain/v2/corpus.py: -------------------------------------------------------------------------------- 1 | from .resource import ( 2 | BaseResource, 3 | BaseCreateParams, 4 | BaseListParams, 5 | ListResourceMixin, 6 | GetResourceMixin, 7 | BareGetParams, 8 | Page, 9 | ) 10 | 11 | from .enums import DataType, Function, Language, License, Privacy, ErrorHandler 12 | from pathlib import Path 13 | from typing_extensions import Unpack, NotRequired, TYPE_CHECKING 14 | from typing import Any, Dict, List, Text, Union 15 | 16 | if TYPE_CHECKING: 17 | from aixplain.modules.metadata import MetaData 18 | 19 | 20 | class CorpusCreateParams(BaseCreateParams): 21 | name: Text 22 | description: Text 23 | license: License 24 | content_path: Union[Union[Text, Path], List[Union[Text, Path]]] 25 | schema: List[Union[Dict, "MetaData"]] 26 | ref_data: List[Any] 27 | tags: List[Text] 28 | functions: List[Function] 29 | privacy: Privacy 30 | error_handler: ErrorHandler 31 | api_key: NotRequired[Text] 32 | 33 | 34 | class CorpusListParams(BaseListParams): 35 | """Parameters for listing corpora. 36 | 37 | Attributes: 38 | query: Optional[Text]: A search query. 39 | function: Optional[Function]: The function of the model. 40 | suppliers: Union[Supplier, List[Supplier]: The suppliers of the model. 41 | source_languages: Union[Language, List[Language]: The source languages of the model. 42 | target_languages: Union[Language, List[Language]: The target languages of the model. 43 | is_finetunable: bool: Whether the model is finetunable. 
44 | """ 45 | 46 | query: NotRequired[Text] 47 | function: NotRequired[Function] 48 | language: NotRequired[Union[Language, List[Language]]] 49 | data_type: NotRequired[DataType] 50 | license: NotRequired[License] 51 | page_number: int 52 | page_size: int 53 | 54 | 55 | class Corpus( 56 | BaseResource, 57 | ListResourceMixin[CorpusListParams, "Corpus"], 58 | GetResourceMixin[BareGetParams, "Corpus"], 59 | ): 60 | @classmethod 61 | def get(cls, id: str, **kwargs: Unpack[BareGetParams]) -> "Corpus": 62 | from aixplain.factories import CorpusFactory 63 | 64 | return CorpusFactory.get(corpus_id=id) 65 | 66 | @classmethod 67 | def list(cls, **kwargs: Unpack[CorpusListParams]) -> Page["Corpus"]: 68 | from aixplain.factories import CorpusFactory 69 | 70 | kwargs.setdefault("page_number", cls.PAGINATE_DEFAULT_PAGE_NUMBER) 71 | kwargs.setdefault("page_size", cls.PAGINATE_DEFAULT_PAGE_SIZE) 72 | 73 | return CorpusFactory.list(**kwargs) 74 | 75 | @classmethod 76 | def create(cls, *args, **kwargs: Unpack[CorpusCreateParams]) -> Dict: 77 | from aixplain.factories import CorpusFactory 78 | 79 | kwargs.setdefault("ref_data", []) 80 | kwargs.setdefault("tags", []) 81 | kwargs.setdefault("functions", []) 82 | kwargs.setdefault("privacy", Privacy.PRIVATE) 83 | kwargs.setdefault("error_handler", ErrorHandler.SKIP) 84 | return CorpusFactory.create(*args, **kwargs) 85 | -------------------------------------------------------------------------------- /aixplain/v2/data.py: -------------------------------------------------------------------------------- 1 | from .resource import ( 2 | BaseResource, 3 | GetResourceMixin, 4 | BareGetParams, 5 | ) 6 | 7 | from typing_extensions import Unpack 8 | 9 | 10 | class Data( 11 | BaseResource, 12 | GetResourceMixin[BareGetParams, "Data"], 13 | ): 14 | @classmethod 15 | def get(cls, id: str, **kwargs: Unpack[BareGetParams]) -> "Data": 16 | from aixplain.factories import DataFactory 17 | 18 | return DataFactory.get(data_id=id) 19 | 
-------------------------------------------------------------------------------- /aixplain/v2/enums_include.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | # Enums that mix in str compare equal to their string values, e.g. DataSplit.TRAIN == "train". 3 | 4 | class AssetStatus(str, Enum): 5 | DRAFT = "draft" 6 | HIDDEN = "hidden" 7 | SCHEDULED = "scheduled" 8 | ONBOARDING = "onboarding" 9 | ONBOARDED = "onboarded" 10 | PENDING = "pending" 11 | FAILED = "failed" 12 | TRAINING = "training" 13 | REJECTED = "rejected" 14 | ENABLING = "enabling" 15 | DELETING = "deleting" 16 | DISABLED = "disabled" 17 | DELETED = "deleted" 18 | IN_PROGRESS = "in_progress" 19 | COMPLETED = "completed" 20 | CANCELING = "canceling" 21 | CANCELED = "canceled" 22 | 23 | 24 | class DataSplit(str, Enum): 25 | TRAIN = "train" 26 | VALIDATION = "validation" 27 | TEST = "test" 28 | 29 | 30 | class DataSubtype(str, Enum): 31 | AGE = "age" 32 | GENDER = "gender" 33 | INTERVAL = "interval" 34 | OTHER = "other" 35 | RACE = "race" 36 | SPLIT = "split" 37 | TOPIC = "topic" 38 | 39 | 40 | class DataType(str, Enum): 41 | AUDIO = "audio" 42 | FLOAT = "float" 43 | IMAGE = "image" 44 | INTEGER = "integer" 45 | LABEL = "label" 46 | TENSOR = "tensor" 47 | TEXT = "text" 48 | VIDEO = "video" 49 | EMBEDDING = "embedding" 50 | NUMBER = "number" 51 | BOOLEAN = "boolean" 52 | 53 | 54 | class ErrorHandler(str, Enum): 55 | """ 56 | Enumeration class defining different error handler strategies. 57 | 58 | Attributes: 59 | SKIP (str): skip failed rows. 60 | FAIL (str): raise an exception.
61 | """ 62 | 63 | SKIP = "skip" 64 | FAIL = "fail" 65 | 66 | 67 | class FileType(str, Enum): 68 | CSV = ".csv" 69 | JSON = ".json" 70 | TXT = ".txt" 71 | XML = ".xml" 72 | FLAC = ".flac" 73 | MP3 = ".mp3" 74 | WAV = ".wav" 75 | JPEG = ".jpeg" 76 | PNG = ".png" 77 | JPG = ".jpg" 78 | GIF = ".gif" 79 | WEBP = ".webp" 80 | AVI = ".avi" 81 | MP4 = ".mp4" 82 | MOV = ".mov" 83 | MPEG4 = ".mpeg4" 84 | 85 | 86 | class OnboardStatus(str, Enum): 87 | ONBOARDING = "onboarding" 88 | ONBOARDED = "onboarded" 89 | FAILED = "failed" 90 | DELETED = "deleted" 91 | 92 | 93 | class OwnershipType(str, Enum): 94 | SUBSCRIBED = "SUBSCRIBED" 95 | OWNED = "OWNED" 96 | 97 | 98 | class Privacy(str, Enum): 99 | PUBLIC = "Public" 100 | PRIVATE = "Private" 101 | RESTRICTED = "Restricted" 102 | 103 | 104 | class ResponseStatus(str, Enum): 105 | IN_PROGRESS = "IN_PROGRESS" 106 | SUCCESS = "SUCCESS" 107 | FAILED = "FAILED" 108 | 109 | 110 | class SortBy(str, Enum): 111 | CREATION_DATE = "createdAt" 112 | PRICE = "normalizedPrice" 113 | POPULARITY = "totalSubscribed" 114 | 115 | # SortOrder is a plain Enum (no int mixin): its members do not compare equal to 1 / -1, so use .value when an int is needed. 116 | class SortOrder(Enum): 117 | ASCENDING = 1 118 | DESCENDING = -1 119 | 120 | 121 | class StorageType(str, Enum): 122 | TEXT = "text" 123 | URL = "url" 124 | FILE = "file" 125 | -------------------------------------------------------------------------------- /aixplain/v2/file.py: -------------------------------------------------------------------------------- 1 | from typing import List, TYPE_CHECKING 2 | from typing_extensions import Unpack, NotRequired 3 | 4 | from aixplain.v2.resource import BaseResource, CreateResourceMixin, BaseCreateParams 5 | 6 | if TYPE_CHECKING: 7 | from aixplain.v2.enums import License, StorageType 8 | 9 | 10 | class FileCreateParams(BaseCreateParams): 11 | """Parameters for creating a file.""" 12 | 13 | local_path: str 14 | tags: NotRequired[List[str]] 15 | license: NotRequired["License"] 16 | is_temp: NotRequired[bool] 17 | 18 | 19 | class File(BaseResource,
CreateResourceMixin[FileCreateParams, "File"]): 20 | """Resource for files.""" 21 | 22 | RESOURCE_PATH = "sdk/files" 23 | 24 | @classmethod 25 | def create(cls, *args, **kwargs: Unpack[FileCreateParams]) -> "File": 26 | """Create a file.""" 27 | from aixplain.factories import FileFactory 28 | 29 | kwargs.setdefault("is_temp", True) 30 | kwargs.setdefault("license", None) 31 | kwargs.setdefault("tags", None) 32 | 33 | return FileFactory.create(*args, **kwargs) 34 | 35 | @classmethod 36 | def to_link(cls, local_path: str) -> str: 37 | """Convert a local path to a link. 38 | 39 | Args: 40 | local_path: str: The local path to the file. 41 | 42 | Returns: 43 | str: The link to the file. 44 | """ 45 | from aixplain.factories import FileFactory 46 | 47 | return FileFactory.to_link(local_path) 48 | 49 | @classmethod 50 | def upload( 51 | cls, 52 | local_path: str, 53 | tags: List[str] = None, 54 | license: "License" = None, 55 | is_temp: bool = True, 56 | ) -> str: 57 | """Upload a file. 58 | 59 | Args: 60 | local_path: str: The local path to the file; tags, license and is_temp (default True) are forwarded unchanged to FileFactory.upload. 61 | 62 | Returns: 63 | str: The upload URL. 64 | """ 65 | from aixplain.factories import FileFactory 66 | 67 | return FileFactory.upload(local_path, tags, license, is_temp) 68 | 69 | @classmethod 70 | def check_storage_type(cls, upload_url: str) -> "StorageType": 71 | """Check the storage type of a file. 72 | 73 | Args: 74 | upload_url: str: The upload URL. 75 | 76 | Returns: 77 | StorageType: The storage type of the file.
78 | """ 79 | from aixplain.factories import FileFactory 80 | 81 | return FileFactory.check_storage_type(upload_url) 82 | -------------------------------------------------------------------------------- /aixplain/v2/finetune.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, TYPE_CHECKING 2 | from typing_extensions import Unpack, NotRequired 3 | 4 | from aixplain.v2.resource import ( 5 | BaseResource, 6 | CreateResourceMixin, 7 | BareCreateParams, 8 | ) 9 | 10 | if TYPE_CHECKING: 11 | from aixplain.modules.finetune import Hyperparameters 12 | from aixplain.modules.dataset import Dataset 13 | from aixplain.modules.model import Model 14 | 15 | 16 | class FinetuneCreateParams(BareCreateParams): 17 | """Parameters for creating a finetune. 18 | 19 | Attributes: 20 | name: str: The name of the finetune. 21 | dataset_list: List[Dataset]: The list of datasets. 22 | model: Union[Model, str]: The model. 23 | prompt_template: str: The prompt template. 24 | hyperparameters: Hyperparameters: The hyperparameters. 25 | train_percentage: float: The train percentage. 26 | dev_percentage: float: The dev percentage. 
27 | """ 28 | 29 | name: str 30 | dataset_list: List["Dataset"] 31 | model: Union["Model", str] 32 | prompt_template: NotRequired[str] 33 | hyperparameters: NotRequired["Hyperparameters"] 34 | train_percentage: NotRequired[float] 35 | dev_percentage: NotRequired[float] 36 | 37 | 38 | class Finetune( 39 | BaseResource, 40 | CreateResourceMixin[FinetuneCreateParams, "Finetune"], 41 | ): 42 | """Resource for finetunes.""" 43 | 44 | RESOURCE_PATH = "sdk/finetunes" 45 | 46 | @classmethod 47 | def create(cls, *args, **kwargs: Unpack[FinetuneCreateParams]) -> "Finetune": 48 | from aixplain.factories import FinetuneFactory 49 | 50 | return FinetuneFactory.create(*args, **kwargs) 51 | -------------------------------------------------------------------------------- /aixplain/v2/metric.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import Unpack, NotRequired 2 | 3 | from aixplain.v2.resource import ( 4 | BaseResource, 5 | ListResourceMixin, 6 | GetResourceMixin, 7 | BareGetParams, 8 | BaseListParams, 9 | Page, 10 | ) 11 | 12 | 13 | class MetricListParams(BaseListParams): 14 | """Parameters for listing metrics. 15 | 16 | Attributes: 17 | model_id: str: The model ID. 18 | is_source_required: bool: Whether the source is required. 19 | is_reference_required: bool: Whether the reference is required. 
20 | """ 21 | 22 | model_id: str 23 | is_source_required: NotRequired[bool] 24 | is_reference_required: NotRequired[bool] 25 | 26 | 27 | class Metric( 28 | BaseResource, 29 | ListResourceMixin[MetricListParams, "Metric"], 30 | GetResourceMixin[BareGetParams, "Metric"], 31 | ): 32 | """Resource for metrics.""" 33 | 34 | RESOURCE_PATH = "sdk/metrics" 35 | 36 | @classmethod 37 | def get(cls, id: str, **kwargs: Unpack[BareGetParams]) -> "Metric": 38 | from aixplain.factories.metric_factory import MetricFactory 39 | 40 | return MetricFactory.get(metric_id=id) 41 | 42 | @classmethod 43 | def list(cls, **kwargs: Unpack[MetricListParams]) -> Page["Metric"]: 44 | from aixplain.factories.metric_factory import MetricFactory 45 | 46 | kwargs.setdefault("is_source_required", None) 47 | kwargs.setdefault("is_reference_required", None) 48 | kwargs.setdefault("page_number", cls.PAGINATE_DEFAULT_PAGE_NUMBER) 49 | kwargs.setdefault("page_size", cls.PAGINATE_DEFAULT_PAGE_SIZE) 50 | 51 | return MetricFactory.list(**kwargs) 52 | -------------------------------------------------------------------------------- /aixplain/v2/pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | from typing_extensions import Unpack, NotRequired 3 | 4 | from .resource import ( 5 | BaseResource, 6 | ListResourceMixin, 7 | GetResourceMixin, 8 | BareListParams, 9 | BareGetParams, 10 | CreateResourceMixin, 11 | BaseCreateParams, 12 | Page, 13 | ) 14 | from .enums import Function, Supplier, DataType 15 | from .model import Model 16 | 17 | 18 | class PipelineListParams(BareListParams): 19 | """Parameters for listing pipelines. 20 | 21 | Attributes: 22 | functions: Union[Function, List[Function]]: The functions of the pipeline. 23 | suppliers: Union[Supplier, List[Supplier]]: The suppliers of the pipeline. 24 | models: Union[Model, List[Model]]: The models of the pipeline. 
25 | input_data_types: Union[DataType, List[DataType]]: The input data types of the pipeline. 26 | output_data_types: Union[DataType, List[DataType]]: The output data types of the pipeline. 27 | drafts_only: bool: Whether to list only drafts. 28 | """ 29 | 30 | functions: NotRequired[Union[Function, List[Function]]] 31 | suppliers: NotRequired[Union[Supplier, List[Supplier]]] 32 | models: NotRequired[Union[Model, List[Model]]] 33 | input_data_types: NotRequired[Union[DataType, List[DataType]]] 34 | output_data_types: NotRequired[Union[DataType, List[DataType]]] 35 | drafts_only: NotRequired[bool] 36 | 37 | 38 | class PipelineCreateParams(BaseCreateParams): 39 | name: str 40 | pipeline: Union[str, dict] 41 | api_key: NotRequired[str] = None 42 | 43 | 44 | class Pipeline( 45 | BaseResource, 46 | ListResourceMixin[PipelineListParams, "Pipeline"], 47 | GetResourceMixin[BareGetParams, "Pipeline"], 48 | CreateResourceMixin[PipelineCreateParams, "Pipeline"], 49 | ): 50 | """Resource for pipelines. 51 | 52 | Attributes: 53 | RESOURCE_PATH: str: The resource path. 
54 | """ 55 | 56 | RESOURCE_PATH = "sdk/pipelines" 57 | 58 | @classmethod 59 | def list(cls, **kwargs: Unpack[PipelineListParams]) -> Page["Pipeline"]: 60 | from aixplain.factories import PipelineFactory 61 | 62 | kwargs.setdefault("page_number", cls.PAGINATE_DEFAULT_PAGE_NUMBER) 63 | kwargs.setdefault("page_size", cls.PAGINATE_DEFAULT_PAGE_SIZE) 64 | 65 | return PipelineFactory.list(**kwargs) 66 | 67 | @classmethod 68 | def get(cls, id: str, **kwargs: Unpack[BareGetParams]) -> "Pipeline": 69 | from aixplain.factories import PipelineFactory 70 | 71 | return PipelineFactory.get(pipeline_id=id) 72 | 73 | @classmethod 74 | def create(cls, *args, **kwargs: Unpack[PipelineCreateParams]) -> "Pipeline": 75 | from aixplain.factories import PipelineFactory 76 | from aixplain.utils import config 77 | 78 | kwargs.setdefault("api_key", config.TEAM_API_KEY) 79 | return PipelineFactory.create(*args, **kwargs) 80 | 81 | @classmethod 82 | def init(cls, name: str, api_key: str = None) -> "Pipeline": 83 | from aixplain.factories import PipelineFactory 84 | 85 | return PipelineFactory.init(name, api_key=api_key) 86 | -------------------------------------------------------------------------------- /aixplain/v2/script.py: -------------------------------------------------------------------------------- 1 | from aixplain.v2.resource import BaseResource 2 | 3 | 4 | class Script(BaseResource): 5 | 6 | @classmethod 7 | def upload(cls, script_path: str) -> "Script": 8 | """Upload a script to the server. 9 | 10 | Args: 11 | script_path: str: The path to the script. 12 | 13 | Returns: 14 | Script: The script. 
15 | """ 16 | from aixplain.factories.script_factory import ScriptFactory 17 | 18 | file_id, metadata = ScriptFactory.upload_script(script_path) 19 | 20 | return ScriptFactory.upload_script({"fileId": file_id, "metadata": metadata}) 21 | -------------------------------------------------------------------------------- /aixplain/v2/team_agent.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import ( 2 | Dict, 3 | Unpack, 4 | List, 5 | Union, 6 | TYPE_CHECKING, 7 | NotRequired, 8 | Text, 9 | ) 10 | 11 | from .resource import ( 12 | BaseResource, 13 | ListResourceMixin, 14 | GetResourceMixin, 15 | BareListParams, 16 | BareGetParams, 17 | BaseCreateParams, 18 | Page, 19 | ) 20 | 21 | if TYPE_CHECKING: 22 | from aixplain.modules.agent import Agent 23 | from aixplain.enums import Supplier 24 | 25 | 26 | class TeamAgentCreateParams(BaseCreateParams): 27 | name: Text 28 | agents: List[Union[Text, "Agent"]] 29 | llm_id: Text 30 | description: Text 31 | api_key: Text 32 | supplier: Union[Dict, Text, "Supplier", int] 33 | version: NotRequired[Text] 34 | use_mentalist_and_inspector: bool 35 | 36 | 37 | class TeamAgentGetParams(BareGetParams): 38 | api_key: NotRequired[str] 39 | 40 | 41 | class TeamAgent( 42 | BaseResource, 43 | ListResourceMixin[BareListParams, "TeamAgent"], 44 | GetResourceMixin[BareGetParams, "TeamAgent"], 45 | ): 46 | """Resource for agents. 47 | 48 | Attributes: 49 | RESOURCE_PATH: str: The resource path. 50 | PAGINATE_PATH: None: The path for pagination. 51 | PAGINATE_METHOD: str: The method for pagination. 52 | PAGINATE_ITEMS_KEY: None: The key for the response. 
53 | """ 54 | 55 | RESOURCE_PATH = "sdk/agent-communities" 56 | PAGINATE_PATH = None 57 | PAGINATE_METHOD = "get" 58 | PAGINATE_ITEMS_KEY = None 59 | 60 | LLM_ID = "669a63646eb56306647e1091" 61 | SUPPLIER = "aiXplain" 62 | 63 | @classmethod 64 | def list(cls, **kwargs: Unpack[BareListParams]) -> Page["TeamAgent"]: 65 | from aixplain.factories import TeamAgentFactory 66 | 67 | return TeamAgentFactory.list(**kwargs) 68 | 69 | @classmethod 70 | def get(cls, id: str, **kwargs: Unpack[TeamAgentGetParams]) -> "TeamAgent": 71 | from aixplain.factories import TeamAgentFactory 72 | 73 | return TeamAgentFactory.get(id, **kwargs) 74 | 75 | @classmethod 76 | def create(cls, *args, **kwargs: Unpack[TeamAgentCreateParams]) -> "TeamAgent": 77 | from aixplain.factories import TeamAgentFactory 78 | from aixplain.utils import config 79 | 80 | kwargs.setdefault("llm_id", cls.LLM_ID) 81 | kwargs.setdefault("api_key", config.TEAM_API_KEY) 82 | kwargs.setdefault("supplier", cls.SUPPLIER) 83 | kwargs.setdefault("description", "") 84 | kwargs.setdefault("use_mentalist_and_inspector", True) 85 | 86 | return TeamAgentFactory.create(*args, **kwargs) 87 | -------------------------------------------------------------------------------- /aixplain/v2/wallet.py: -------------------------------------------------------------------------------- 1 | from .resource import ( 2 | BaseResource, 3 | GetResourceMixin, 4 | BareGetParams, 5 | ) 6 | from typing_extensions import Unpack, NotRequired 7 | 8 | 9 | class WalletGetParams(BareGetParams): 10 | api_key: NotRequired[str] 11 | 12 | 13 | class Wallet( 14 | BaseResource, 15 | GetResourceMixin[WalletGetParams, "Wallet"], 16 | ): 17 | @classmethod 18 | def get(cls, **kwargs: Unpack[WalletGetParams]) -> "Wallet": 19 | from aixplain.factories import WalletFactory 20 | import aixplain.utils.config as config 21 | 22 | api_key = kwargs.get("api_key", config.TEAM_API_KEY) 23 | return WalletFactory.get(api_key=api_key) 24 | 
-------------------------------------------------------------------------------- /docs/assets/aixplain-brandmark-line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/assets/aixplain-brandmark-line.png -------------------------------------------------------------------------------- /docs/assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/assets/architecture.png -------------------------------------------------------------------------------- /docs/assets/data-onboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/assets/data-onboard.png -------------------------------------------------------------------------------- /docs/assets/designer-subtitling-sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/assets/designer-subtitling-sample.png -------------------------------------------------------------------------------- /docs/assets/model-id-on-platform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/assets/model-id-on-platform.png -------------------------------------------------------------------------------- /docs/assets/navigate-api-key.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/assets/navigate-api-key.png 
-------------------------------------------------------------------------------- /docs/assets/subtitle-generator-output.json: -------------------------------------------------------------------------------- 1 | { 2 | "success":true, 3 | "response":{ 4 | "status":"SUCCESS", 5 | "data":"https://aixplain-output.s3.amazonaws.com/0279dc6e-8e55-45a4-88b1-dc7242e908a1.out", 6 | "completed":true, 7 | "numofsegments":4, 8 | "numoffailedsegments":0, 9 | "progress":100, 10 | "segments":[ 11 | { 12 | "index":0, 13 | "separator":" ", 14 | "size":1499, 15 | "inputs":{ 16 | "1":{ 17 | "value":"https://aixplain-output.s3.amazonaws.com/6026307d-70c2-4070-9f09-a1ad1d4dc540_0.txt", 18 | "input_type":"text" 19 | } 20 | }, 21 | "success":true, 22 | "response":"https://aixplain-output.s3.amazonaws.com/868bc682-b0e8-478c-9158-0f60a43e64a3.tmp", 23 | "output_type":"text", 24 | "output_name":"3_ASSET", 25 | "details":"", 26 | "credits_used":0.01499, 27 | "timeSpent":0.969 28 | }, 29 | { 30 | "index":1, 31 | "separator":"\n", 32 | "size":101, 33 | "inputs":{ 34 | "1":{ 35 | "value":"https://aixplain-output.s3.amazonaws.com/6026307d-70c2-4070-9f09-a1ad1d4dc540_1.txt", 36 | "input_type":"text" 37 | } 38 | }, 39 | "success":true, 40 | "response":"https://aixplain-output.s3.amazonaws.com/ca20396d-fb04-41eb-9b48-b253e1a0692d.tmp", 41 | "output_type":"text", 42 | "output_name":"3_ASSET", 43 | "details":"", 44 | "credits_used":0.00101, 45 | "timeSpent":0.925 46 | }, 47 | { 48 | "index":2, 49 | "separator":" ", 50 | "size":1494, 51 | "inputs":{ 52 | "1":{ 53 | "value":"https://aixplain-output.s3.amazonaws.com/6026307d-70c2-4070-9f09-a1ad1d4dc540_2.txt", 54 | "input_type":"text" 55 | } 56 | }, 57 | "success":true, 58 | "response":"https://aixplain-output.s3.amazonaws.com/941e6a08-f814-4ef7-85a0-fbfb5abcc4e6.tmp", 59 | "output_type":"text", 60 | "output_name":"3_ASSET", 61 | "details":"", 62 | "credits_used":0.01494, 63 | "timeSpent":0.938 64 | }, 65 | { 66 | "index":3, 67 | "separator":"\n", 68 | 
"size":56, 69 | "inputs":{ 70 | "1":{ 71 | "value":"https://aixplain-output.s3.amazonaws.com/6026307d-70c2-4070-9f09-a1ad1d4dc540_3.txt", 72 | "input_type":"text" 73 | } 74 | }, 75 | "success":true, 76 | "response":"https://aixplain-output.s3.amazonaws.com/b26a9055-0fa0-4612-be3b-c62511b616ab.tmp", 77 | "output_type":"text", 78 | "output_name":"3_ASSET", 79 | "details":"", 80 | "credits_used":0.00056, 81 | "timeSpent":0.782 82 | } 83 | ] 84 | }, 85 | "error":"None", 86 | "elapsed_time":0.609717845916748 87 | } -------------------------------------------------------------------------------- /docs/development/developer_guide.md: -------------------------------------------------------------------------------- 1 | # Developer Guide 2 | 3 | ## Requirements 4 | 5 | - Install [Python](https://www.python.org/) 3.5+ 6 | 7 | ## Installation 8 | ``` 9 | pip install -e . 10 | ``` 11 | 12 | ## Running Tests 13 | 14 | ### Setup Environment 15 | 16 | ``` 17 | cp .env.example .env 18 | ``` 19 | 20 | Populate values in ```.env``` for pytest consumption. 21 | 22 | ### Run ```pytest``` 23 | 24 | ``` 25 | pytest 26 | ``` 27 | 28 | ## Changing logging level 29 | 30 | #### Linux or macOS 31 | ```bash 32 | export LOG_LEVEL=DEBUG 33 | ``` 34 | #### Windows 35 | ```bash 36 | set LOG_LEVEL=DEBUG 37 | ``` 38 | #### Jupyter Notebook 39 | ```bash 40 | %env LOG_LEVEL=DEBUG 41 | ``` 42 | 43 | ## Architecture 44 | 45 | ### Diagram 46 | 47 | Data Asset Onboard Process 48 | 49 | ### Data Asset Onboard 50 | 51 | The image below depicts the onboard process of a data asset (e.g. 
corpora and datasets): 52 | 53 | Data Asset Onboard Process -------------------------------------------------------------------------------- /docs/samples/label_dataset_onboarding/corpus/images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/samples/label_dataset_onboarding/corpus/images/1.jpg -------------------------------------------------------------------------------- /docs/samples/label_dataset_onboarding/corpus/images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/samples/label_dataset_onboarding/corpus/images/2.png -------------------------------------------------------------------------------- /docs/samples/label_dataset_onboarding/corpus/index.csv: -------------------------------------------------------------------------------- 1 | ,images,labels 2 | 0,corpus/images/1.jpg,corpus/labels/1.json 3 | 1,corpus/images/2.png,corpus/labels/2.json 4 | -------------------------------------------------------------------------------- /docs/samples/label_dataset_onboarding/corpus/labels/1.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "arcade", 3 | "boundingBox": { 4 | "top": 0, 5 | "bottom": 0, 6 | "left": 0, 7 | "right": 0 8 | } 9 | } -------------------------------------------------------------------------------- /docs/samples/label_dataset_onboarding/corpus/labels/2.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "building", 3 | "boundingBox": { 4 | "top": 0, 5 | "bottom": 0, 6 | "left": 0, 7 | "right": 0 8 | } 9 | } -------------------------------------------------------------------------------- /docs/samples/subtitle_generator/README.md: 
-------------------------------------------------------------------------------- 1 | # Video Subtitling Example 2 | 3 | The code in `subtitle_generator.py` subtitles a video spoken in one source language into a target language and generates an SRT file as output. 4 | 5 | In this example, the pipeline receives an English video as input and first extracts its audio, which is then fed to a voice activity detection (VAD) node that finds the voice-activity segments of the audio input. These segments are then transcribed by a speech recognition node and translated into Spanish by an English-to-Spanish translation node. Finally, a speech synthesis node is added to dub the Spanish translation. 6 | 7 | 8 | 9 | ## Run the Example 10 | 11 | ### Build a Pipeline using aiXplain 12 | 13 | To build a subtitle generation pipeline, log in to the [aiXplain platform](https://platform.aixplain.com/) and use its web UI to design the pipeline. 14 | 15 | ### Run Code 16 | 17 | Generate an http(s) link to the video file you want to subtitle. 18 | Using the ID of the subtitling pipeline created in the step above, run the code: 19 | 20 | ``` 21 | python3 subtitle_generator.py --video-pt-path <VIDEO_URL> \ 22 | --srt-path pt.srt \ 23 | -k 24 | ``` 25 | 26 | ### Sample Output 27 | 28 | You can find a sample output in this [file](../../assets/subtitle-generator-output.json). 29 | -------------------------------------------------------------------------------- /docs/streaming/README.md: -------------------------------------------------------------------------------- 1 | # aiXplain Client Streaming Sample 2 | 3 | This guide will walk you through the process of connecting to and using the aiXplain streaming services. 4 | 5 | ## Prerequisites 6 | 7 | Ensure you have Python and pip installed on your system.
8 | 9 | ## Installation 10 | 11 | To install the necessary requirements, navigate to the project directory and run the following command: 12 | 13 | 14 | ```sh 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Generating Stubs 19 | 20 | We use Protocol Buffers (protobuf) to define our service interface. To generate the stubs from the .proto files, use the following commands: 21 | 22 | ```bash 23 | # For diarization 24 | python -m grpc_tools.protoc -I./proto --python_out=. --pyi_out=. --grpc_python_out=. proto/aixplain_diarization_streaming.proto 25 | 26 | # For speech transcription 27 | python -m grpc_tools.protoc -I./proto --python_out=. --pyi_out=. --grpc_python_out=. proto/aixplain_speech_transcription_streaming.proto 28 | ``` 29 | 30 | Note: You can generate the stubs in any language of your choice as long as you have the protobuf compiler (protoc) installed and configured. 31 | 32 | ## Running the Client 33 | 34 | aiXplain provides the necessary certificates for mTLS. You can pass them to the client when running it. Here's an example of how to use the speech transcription client: 35 | 36 | ```bash 37 | python3.8 aixplain_speech_transcription_streaming_client.py --file-path=test_dia.wav --cacert=./client-crt/ca.crt --cert=./client-crt/tls.crt --key=./client-crt/tls.key --addr <host>:<port> 38 | ``` 39 | For transcription, you can also enable a subtitle-like print style with the argument `--print-style=subtitle`. The diarization client is run the same way: 40 | 41 | ```bash 42 | python3.8 aixplain_diarization_streaming_client.py --file-path=test_dia.wav --cacert=./client-crt/ca.crt --cert=./client-crt/tls.crt --key=./client-crt/tls.key --addr <host>:<port> 43 | ``` 44 | 45 | The arguments `--cacert`, `--cert`, and `--key` provide the paths to the certificate files required for mTLS. 46 | 47 | The `--file-path` argument specifies the path to the input audio file. 48 | 49 | You can configure the model's latency by setting the `--latency` argument. Values between 0.5 and 5.0 seconds are supported.
50 | 51 | ## Audio requirements 52 | 53 | Our service is configured to process audio streamed as a single channel with a sampling rate of 16000Hz. 54 | 55 | If the audio does not meet these specifications, the service may yield unexpected results. 56 | 57 | To ensure compatibility, we provide a helper script that adjusts your audio files to meet these requirements. 58 | 59 | ### Installing Dependencies to run the helper script 60 | 61 | The helper script relies on the pydub package. If you need to use the script to adjust your audio files, install pydub using pip: 62 | 63 | ```sh 64 | pip install pydub==0.25.1 65 | ``` 66 | 67 | ### Using the Helper Script 68 | 69 | If your audio files need to be converted to meet our service's specifications, you can do this with our helper script as follows: 70 | 71 | `python make_audio_compatible.py --source_path=input.wav --dest_path=test_dia.wav` 72 | 73 | If your audio files already meet the specifications, you don't need to use this script or install its dependencies. 74 | -------------------------------------------------------------------------------- /docs/streaming/make_audio_compatible.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pydub import AudioSegment 4 | 5 | FRAME_RATE = 16000 6 | 7 | 8 | def create_compatible_audio(source_path, dest_path): 9 | """ 10 | Function to resample an audio file and change the number of channels if there are more than 1. 11 | """ 12 | # Load the audio file 13 | sound_file = AudioSegment.from_file(source_path) 14 | updated = False 15 | if sound_file.frame_rate != FRAME_RATE: 16 | # Resample the audio file 17 | logging.info(f"Resampling {sound_file.frame_rate} -> {FRAME_RATE}") 18 | sound_file = sound_file.set_frame_rate(FRAME_RATE) 19 | updated = True 20 | # If the audio file has more than one channel, convert it to mono 21 | if sound_file.channels > 1: 22 | logging.info(f"Changing no. 
channels {sound_file.channels} -> 1") 23 | sound_file = sound_file.set_channels(1) 24 | updated = True 25 | if updated: 26 | # Export the processed audio file 27 | sound_file.export(dest_path, format="wav") 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser(description="Process some audio files.") 32 | parser.add_argument("--source_path", required=True, help="Source path for the audio file") 33 | parser.add_argument("--dest_path", required=True, help="Destination path for the processed audio file") 34 | 35 | args = parser.parse_args() 36 | 37 | create_compatible_audio(args.source_path, args.dest_path) 38 | -------------------------------------------------------------------------------- /docs/streaming/requirements.txt: -------------------------------------------------------------------------------- 1 | grpcio==1.54.3 2 | grpcio-tools==1.54.0 3 | -------------------------------------------------------------------------------- /docs/streaming/test_dia.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/docs/streaming/test_dia.wav -------------------------------------------------------------------------------- /docs/user/api_setup.md: -------------------------------------------------------------------------------- 1 | # Team API Key Guide 2 | [Sign up](https://platform.aixplain.com/register) or [login](https://platform.aixplain.com/login) for an account on aiXplain. Then from the Dashboard, navigate to the [Integrations](https://platform.aixplain.com/account/integrations). 3 | 4 | Please refer to the image below. 5 | 6 | 7 | 8 | ### Creating a New API Key 9 | On the **Integrations** page, you can find the **Create a team access key** button on the top right corner. You can create a new key by clicking that button, then specifiying a label and an (optional) expiry date. 
10 | 11 | ### Manage API Keys 12 | On the **Integrations** page, you can view all the existing Team API keys. You can also delete keys on this page. 13 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # aixplain/pyproject.toml 2 | [build-system] 3 | requires = ["setuptools", "setuptools-scm"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [tool.setuptools.packages.find] 7 | where = ["."] 8 | include = ["aixplain", "tests"] 9 | namespaces = true 10 | 11 | [project] 12 | name = "aiXplain" 13 | version = "0.2.26" 14 | description = "aiXplain SDK adds AI functions to software." 15 | readme = "README.md" 16 | requires-python = ">=3.5, <4" 17 | license = { text = "Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0" } 18 | authors = [ 19 | {email = "thiago.ferreira@aixplain.com"}, 20 | {email = "krishna.durai@aixplain.com"}, 21 | {email = "lucas.pavanelli@aixplain.com"} 22 | ] 23 | classifiers = [ 24 | "Development Status :: 2 - Pre-Alpha", 25 | "Environment :: Web Environment", 26 | "Intended Audience :: Developers", 27 | "License :: OSI Approved :: Apache Software License", 28 | "Natural Language :: English", 29 | "Operating System :: OS Independent", 30 | "Programming Language :: Python", 31 | "Programming Language :: Python :: 3", 32 | "Programming Language :: Python :: 3.5", 33 | "Programming Language :: Python :: 3.6", 34 | "Programming Language :: Python :: 3.7", 35 | "Programming Language :: Python :: 3.8", 36 | "Programming Language :: Python :: 3.9", 37 | "Programming Language :: Python :: 3.10", 38 | "Programming Language :: Python :: 3.11", 39 | "Programming Language :: Python :: 3 :: Only", 40 | "Programming Language :: Python :: Implementation :: CPython", 41 | "Programming Language :: Python :: Implementation :: PyPy", 42 | "Topic :: Internet :: WWW/HTTP", 43 | "Topic :: Software Development :: 
Libraries", 44 | ] 45 | dependencies = [ 46 | "requests>=2.1.0", 47 | "tqdm>=4.1.0", 48 | "pandas>=1.2.1", 49 | "python-dotenv>=1.0.0", 50 | "validators>=0.20.0", 51 | "filetype>=1.2.0", 52 | "click>=7.1.2", 53 | "PyYAML>=6.0.1", 54 | "dataclasses-json>=0.5.2", 55 | "Jinja2==3.1.6", 56 | "sentry-sdk>=1.0.0", 57 | "pydantic>=2.10.6" 58 | ] 59 | 60 | [project.urls] 61 | Homepage = "https://github.com/aixplain/aiXplain" 62 | Documentation = "https://github.com/aixplain/pipelines/tree/main/docs" 63 | 64 | [project.scripts] 65 | aixplain = "aixplain.cli_groups:run_cli" 66 | 67 | [project.optional-dependencies] 68 | model-builder = [ 69 | "model-interfaces~=0.0.1" 70 | ] 71 | test = [ 72 | "pytest>=6.1.0", 73 | "docker>=6.1.3", 74 | "requests-mock>=1.11.0", 75 | "pytest-mock>=3.10.0" 76 | ] 77 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = 3 | tests -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description_file=README.md 3 | license_files=LICENSE.rst 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | aiXplain SDK Library. 3 | --- 4 | 5 | aiXplain SDK enables python programmers to add AI functions 6 | to their software. 7 | 8 | Copyright 2022 The aiXplain SDK authors 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 
12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 21 | """ 22 | -------------------------------------------------------------------------------- /tests/functional/agent/data/agent_test_end2end.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "agent_name": "TEST Translation agent", 4 | "llm_id": "6626a3a8c8f1d089790cf5a2", 5 | "llm_name": "Groq Llama 3 70B", 6 | "query": "Who is the president of Brazil right now? Translate to pt", 7 | "model_tools": [ 8 | { 9 | "function": "translation", 10 | "supplier": "AWS" 11 | }, 12 | { 13 | "model": "60ddefca8d38c51c58860108", 14 | "function": null, 15 | "supplier": null 16 | } 17 | ] 18 | } 19 | ] 20 | -------------------------------------------------------------------------------- /tests/functional/apikey/README.md: -------------------------------------------------------------------------------- 1 | # API Key Tests 2 | 3 | This directory contains tests for the API Key functionality in the aiXplain SDK. 4 | 5 | ## Prerequisites 6 | 7 | To run these tests, you need: 8 | 9 | 1. An admin API key with permissions to: 10 | - Create new API keys 11 | - Update existing API keys 12 | - Delete API keys 13 | - List API keys 14 | - View API key usage 15 | 16 | 2. 
Available API key slots: 17 | - The tests create and delete API keys during execution 18 | - Make sure you have at least one available slot for API key creation 19 | - The tests will fail if you've reached the maximum number of allowed API keys 20 | 21 | -------------------------------------------------------------------------------- /tests/functional/apikey/apikey.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Test API Key", 3 | "asset_limits": [ 4 | { 5 | "model": "640b517694bf816d35a59125", 6 | "token_per_minute": 100, 7 | "token_per_day": 1000, 8 | "request_per_day": 1000, 9 | "request_per_minute": 100 10 | } 11 | ], 12 | "global_limits": { 13 | "token_per_minute": 100, 14 | "token_per_day": 1000, 15 | "request_per_day": 1000, 16 | "request_per_minute": 100 17 | }, 18 | "budget": 1000, 19 | "expires_at": "2024-12-12T00:00:00Z" 20 | } 21 | -------------------------------------------------------------------------------- /tests/functional/benchmark/benchmark_functional_test.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import pandas as pd 3 | import json 4 | from dotenv import load_dotenv 5 | import time 6 | 7 | load_dotenv() 8 | from aixplain.factories import ModelFactory, DatasetFactory, MetricFactory, BenchmarkFactory 9 | from aixplain.modules.benchmark import Benchmark 10 | from aixplain.modules.benchmark_job import BenchmarkJob 11 | from pathlib import Path 12 | 13 | import pytest 14 | 15 | import logging 16 | 17 | from aixplain import aixplain_v2 as v2 18 | 19 | logger = logging.getLogger() 20 | logger.setLevel(logging.DEBUG) 21 | 22 | TIMEOUT = 60 * 30 23 | RUN_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_test_run_data.json")) 24 | MODULE_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_module_test_data.json")) 25 | 26 | 27 | def read_data(data_path): 28 | return json.load(open(data_path, "r")) 29 | 30 | 31 | 
@pytest.fixture(scope="module", params=read_data(RUN_FILE)) 32 | def run_input_map(request): 33 | return request.param 34 | 35 | 36 | @pytest.fixture(scope="module", params=read_data(MODULE_FILE)) 37 | def module_input_map(request): 38 | return request.param 39 | 40 | 41 | def is_job_finshed(benchmark_job): 42 | time_taken = 0 43 | sleep_time = 15 44 | timeout = 15 * 60 45 | while True: 46 | if time_taken > timeout: 47 | break 48 | job_status = benchmark_job.check_status() 49 | if job_status == "in_progress": 50 | time.sleep(sleep_time) 51 | time_taken += sleep_time 52 | elif job_status == "completed": 53 | return True 54 | else: 55 | break 56 | return False 57 | 58 | 59 | def assert_correct_results(benchmark_job): 60 | df = benchmark_job.download_results_as_csv(return_dataframe=True) 61 | assert type(df) is pd.DataFrame, "Couldn't download CSV" 62 | model_success_rate = (sum(df["Model_success"]) * 100) / len(df.index) 63 | assert model_success_rate > 80, f"Low model success rate ({model_success_rate})" 64 | metric_name = "BLEU by sacrebleu" 65 | mean_score = df[metric_name].mean() 66 | assert mean_score != 0, f"Zero Mean Score - Please check metric ({metric_name})" 67 | 68 | 69 | @pytest.mark.parametrize("BenchmarkFactory", [BenchmarkFactory, v2.Benchmark]) 70 | def test_create_and_run(run_input_map, BenchmarkFactory): 71 | model_list = [ModelFactory.get(model_id) for model_id in run_input_map["model_ids"]] 72 | dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in run_input_map["dataset_names"]] 73 | metric_list = [MetricFactory.get(metric_id) for metric_id in run_input_map["metric_ids"]] 74 | benchmark = BenchmarkFactory.create(f"SDK Benchmark Test {uuid.uuid4()}", dataset_list, model_list, metric_list) 75 | assert type(benchmark) is Benchmark, "Couldn't create benchmark" 76 | benchmark_job = benchmark.start() 77 | assert type(benchmark_job) is BenchmarkJob, "Couldn't start job" 78 | assert is_job_finshed(benchmark_job), "Job 
did not finish in time" 79 | assert_correct_results(benchmark_job) 80 | 81 | 82 | # def test_module(module_input_map): 83 | # benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"]) 84 | # assert benchmark.id == module_input_map["benchmark_id"] 85 | # benchmark_job = benchmark.job_list[0] 86 | # assert benchmark_job.benchmark_id == module_input_map["benchmark_id"] 87 | # job_status = benchmark_job.check_status() 88 | # assert job_status in ["in_progress", "completed"] 89 | # df = benchmark_job.download_results_as_csv(return_dataframe=True) 90 | # assert type(df) is pd.DataFrame 91 | -------------------------------------------------------------------------------- /tests/functional/benchmark/data/benchmark_module_test_data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "benchmark_id" : "64da356e13d879bec2323aa8" 4 | } 5 | ] -------------------------------------------------------------------------------- /tests/functional/benchmark/data/benchmark_test_run_data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "model_ids": ["61b097551efecf30109d32da", "60ddefbe8d38c51c5885f98a"], 4 | "dataset_ids": ["64da34a813d879bec2323aa3"], 5 | "dataset_names": ["EnHi SDK Test - Benchmark Dataset"], 6 | "metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"] 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /tests/functional/data_asset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/tests/functional/data_asset/__init__.py -------------------------------------------------------------------------------- /tests/functional/data_asset/corpus_onboarding_test.py: -------------------------------------------------------------------------------- 1 | __author__ = 
"thiagocastroferreira" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import pytest 20 | import time 21 | from aixplain.enums import Language, License, OnboardStatus 22 | from aixplain.factories.corpus_factory import CorpusFactory 23 | from uuid import uuid4 24 | from aixplain import aixplain_v2 as v2 25 | 26 | 27 | @pytest.mark.parametrize("CorpusFactory", [CorpusFactory, v2.Corpus]) 28 | def test_corpus_onboard_get_delete(CorpusFactory): 29 | upload_file = "tests/functional/data_asset/input/audio-en_url.csv" 30 | schema = [ 31 | { 32 | "name": "audio", 33 | "dtype": "audio", 34 | "storage_type": "url", 35 | "start_column": "audio_start_time", 36 | "end_column": "audio_end_time", 37 | "languages": [Language.English_UNITED_STATES], 38 | }, 39 | {"name": "text", "dtype": "text", "storage_type": "text", "languages": [Language.English_UNITED_STATES]}, 40 | ] 41 | 42 | response = CorpusFactory.create( 43 | name=str(uuid4()), 44 | description="This corpus contain 20 English audios with their corresponding transcriptions.", 45 | license=License.MIT, 46 | content_path=upload_file, 47 | schema=schema, 48 | ) 49 | asset_id = response["asset_id"] 50 | onboard_status = OnboardStatus(response["status"]) 51 | while onboard_status == OnboardStatus.ONBOARDING: 52 | corpus = CorpusFactory.get(asset_id) 53 | onboard_status = corpus.onboard_status 54 | time.sleep(1) 55 | # assert the asset was 
onboarded 56 | assert onboard_status == OnboardStatus.ONBOARDED 57 | # assert the asset was deleted 58 | corpus.delete() 59 | with pytest.raises(Exception): 60 | corpus = CorpusFactory.get(asset_id) 61 | 62 | 63 | @pytest.mark.parametrize("CorpusFactory", [CorpusFactory, v2.Corpus]) 64 | def test_corpus_listing(CorpusFactory): 65 | response = CorpusFactory.list() 66 | assert "results" in response 67 | 68 | 69 | @pytest.mark.parametrize("CorpusFactory", [CorpusFactory, v2.Corpus]) 70 | def test_corpus_get_error(CorpusFactory): 71 | with pytest.raises(Exception): 72 | response = CorpusFactory.get("131312") -------------------------------------------------------------------------------- /tests/functional/data_asset/input/audio-en_with_invalid_split_url.csv: -------------------------------------------------------------------------------- 1 | ,audio,text,audio_start_time,audio_end_time,split,split-2 2 | 0,https://aixplain-platform-assets.s3.amazonaws.com/samples/en/discovery_demo.wav,Welcome to another episode of Explain using discover to find and benchmark AI models.,0.9,6.56,TRAIN,TRAIN -------------------------------------------------------------------------------- /tests/functional/file_asset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/tests/functional/file_asset/__init__.py -------------------------------------------------------------------------------- /tests/functional/file_asset/file_create_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "mohammedalyafeai" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 
8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import pytest 20 | from aixplain.enums import License 21 | from aixplain.factories import FileFactory 22 | from aixplain import aixplain_v2 as v2 23 | 24 | 25 | @pytest.mark.parametrize("FileFactory", [FileFactory, v2.File]) 26 | def test_file_create(FileFactory): 27 | upload_file = "tests/functional/file_asset/input/test.csv" 28 | s3_link = FileFactory.create(local_path=upload_file, tags=["test1", "test2"], license=License.MIT, is_temp=False) 29 | assert s3_link.startswith("s3") 30 | 31 | 32 | @pytest.mark.parametrize("FileFactory", [FileFactory, v2.File]) 33 | def test_file_create_temp(FileFactory): 34 | upload_file = "tests/functional/file_asset/input/test.csv" 35 | s3_link = FileFactory.create(local_path=upload_file, tags=["test1", "test2"], license=License.MIT, is_temp=True) 36 | assert s3_link.startswith("s3") 37 | -------------------------------------------------------------------------------- /tests/functional/file_asset/input/test.csv: -------------------------------------------------------------------------------- 1 | A,B 2 | 1,2 3 | 3,4 -------------------------------------------------------------------------------- /tests/functional/finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/aiXplain/92aad6a317c157cc37e91c532c4551ba369191e7/tests/functional/finetune/__init__.py -------------------------------------------------------------------------------- /tests/functional/finetune/data/finetune_test_cost_estimation.json: 
-------------------------------------------------------------------------------- 1 | [ 2 | {"model_name": "Llama 2 7b", "model_id": "6543cb991f695e72028e9428", "dataset_name": "Test text generation dataset"}, 3 | {"model_name": "Llama 2 7B Chat", "model_id": "65519ee7bf42e6037ab109d8", "dataset_name": "Test text generation dataset"}, 4 | {"model_name": "Mistral 7b", "model_id": "6551a9e7bf42e6037ab109de", "dataset_name": "Test text generation dataset"}, 5 | {"model_name": "Mistral 7B Instruct v0.3", "model_id": "6551a9e7bf42e6037ab109de", "dataset_name": "Test text generation dataset"}, 6 | {"model_name": "Falcon 7b", "model_id": "6551bff9bf42e6037ab109e1", "dataset_name": "Test text generation dataset"}, 7 | {"model_name": "Falcon 7b Instruct", "model_id": "65519d57bf42e6037ab109d5", "dataset_name": "Test text generation dataset"}, 8 | {"model_name": "MPT 7b", "model_id": "6551a72bbf42e6037ab109d9", "dataset_name": "Test text generation dataset"}, 9 | {"model_name": "MPT 7b storywriter", "model_id": "6551a870bf42e6037ab109db", "dataset_name": "Test text generation dataset"}, 10 | {"model_name": "BloomZ 7b", "model_id": "6551ab17bf42e6037ab109e0", "dataset_name": "Test text generation dataset"}, 11 | {"model_name": "BloomZ 7b MT", "model_id": "656e80147ca71e334752d5a3", "dataset_name": "Test text generation dataset"} 12 | ] 13 | -------------------------------------------------------------------------------- /tests/functional/finetune/data/finetune_test_end2end.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "model_name": "llama2 7b", 4 | "model_id": "6543cb991f695e72028e9428", 5 | "dataset_name": "Test text generation dataset", 6 | "inference_data": "Hello!", 7 | "required_dev": true, 8 | "search_metadata": false 9 | } 10 | ] 11 | -------------------------------------------------------------------------------- /tests/functional/finetune/data/finetune_test_list_data.json: 
-------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "function": "text-generation" 4 | } 5 | ] -------------------------------------------------------------------------------- /tests/functional/finetune/data/finetune_test_prompt_validator.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "model_name": "llama2 7b", 4 | "model_id": "6543cb991f695e72028e9428", 5 | "dataset_name": "Test text generation dataset", 6 | "prompt_template": "Source: <>\nReference: <>", 7 | "is_valid": true 8 | }, 9 | { 10 | "model_name": "llama2 7b", 11 | "model_id": "6543cb991f695e72028e9428", 12 | "dataset_name": "Test text generation dataset", 13 | "prompt_template": "Source: <>\nReference: <>", 14 | "is_valid": false 15 | } 16 | ] -------------------------------------------------------------------------------- /tests/functional/general_assets/data/asset_run_test_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "model" : { 3 | "id" : "61b097551efecf30109d32da", 4 | "data": "This is a test sentence." 5 | }, 6 | "model2" : { 7 | "id" : "60ddefab8d38c51c5885ee38", 8 | "data": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/myname.mp3" 9 | }, 10 | "model3" : { 11 | "id" : "6736411cf127849667606689", 12 | "data": "How to cook a shrimp risotto?" 13 | }, 14 | "pipeline": { 15 | "name": "SingleNodePipeline", 16 | "data": "This is a test sentence." 17 | }, 18 | "metric": { 19 | "id" : "639874ab506c987b1ae1acc6", 20 | "data": { 21 | "hypothesis": "hello world", 22 | "reference": "hello world" 23 | } 24 | } 25 | } -------------------------------------------------------------------------------- /tests/functional/model/data/test_input.txt: -------------------------------------------------------------------------------- 1 | Hello! Here is a robot emoji: 🤖 Response should contain this emoji. 
-------------------------------------------------------------------------------- /tests/functional/model/hf_onboarding_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "michaellam" 2 | 3 | import pytest 4 | import time 5 | 6 | from aixplain.factories.model_factory import ModelFactory 7 | from tests.test_utils import delete_asset 8 | from aixplain.utils import config 9 | 10 | 11 | @pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") 12 | def test_deploy_model(): 13 | # Start the deployment 14 | model_name = "Test Model" 15 | repo_id = "tiiuae/falcon-7b" 16 | response = ModelFactory.deploy_huggingface_model(model_name, repo_id, hf_token=config.HF_TOKEN) 17 | assert "id" in response.keys() 18 | 19 | # Check for status 20 | model_id = response["id"] 21 | num_retries = 120 22 | counter = 0 23 | while ModelFactory.get_huggingface_model_status(model_id)["status"].lower() != "onboarded": 24 | time.sleep(10) 25 | counter += 1 26 | if counter == num_retries: 27 | assert ModelFactory.get_huggingface_model_status(model_id)["status"].lower() == "onboarded" 28 | 29 | # Clean up 30 | delete_asset(model_id, config.TEAM_API_KEY) 31 | 32 | 33 | # @pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") 34 | def test_nonexistent_model(): 35 | # Start the deployment 36 | model_name = "Test Model" 37 | repo_id = "nonexistent-supplier/nonexistent-model" 38 | response = ModelFactory.deploy_huggingface_model(model_name, repo_id, hf_token=config.HF_TOKEN) 39 | assert response["statusCode"] == 400 40 | assert response["message"] == "err.unable_to_onboard_model" 41 | 42 | 43 | # @pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") 44 | def test_size_limit(): 45 | # Start the deployment 46 | model_name = "Test Model" 47 | repo_id = "tiiuae/falcon-40b" 48 | response = ModelFactory.deploy_huggingface_model(model_name, repo_id, hf_token=config.HF_TOKEN) 49 | assert 
response["statusCode"] == 400 50 | assert response["message"] == "err.unable_to_onboard_model" 51 | 52 | 53 | # @pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") 54 | def test_gated_model(): 55 | # Start the deployment 56 | model_name = "Test Model" 57 | repo_id = "meta-llama/Llama-2-7b-hf" 58 | response = ModelFactory.deploy_huggingface_model(model_name, repo_id, hf_token="mock_key") 59 | assert response["statusCode"] == 400 60 | assert response["message"] == "err.unable_to_onboard_model" 61 | -------------------------------------------------------------------------------- /tests/functional/model/image_upload_e2e_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "michaellam" 2 | 3 | from pathlib import Path 4 | import json 5 | from aixplain.factories.model_factory import ModelFactory 6 | from tests.test_utils import delete_asset, delete_service_account 7 | from aixplain.utils import config 8 | import docker 9 | import pytest 10 | 11 | 12 | def test_create_and_upload_model(): 13 | # List the host machines 14 | host_response = ModelFactory.list_host_machines() 15 | for hosting_machine_dict in host_response: 16 | assert "code" in hosting_machine_dict.keys() 17 | assert "type" in hosting_machine_dict.keys() 18 | assert "cores" in hosting_machine_dict.keys() 19 | assert "memory" in hosting_machine_dict.keys() 20 | assert "hourlyCost" in hosting_machine_dict.keys() 21 | 22 | # List the functions 23 | response = ModelFactory.list_functions() 24 | items = response["items"] 25 | for item in items: 26 | assert "output" not in item.keys() 27 | assert "params" not in item.keys() 28 | assert "id" not in item.keys() 29 | assert "name" in item.keys() 30 | 31 | # Register the model, and create an image repository for it. 
32 | with open(Path("tests/test_requests/create_asset_request.json")) as f: 33 | mock_register_payload = json.load(f) 34 | name = mock_register_payload["name"] 35 | description = mock_register_payload["description"] 36 | function = mock_register_payload["function"] 37 | source_language = mock_register_payload["sourceLanguage"] 38 | input_modality = mock_register_payload["input_modality"] 39 | output_modality = mock_register_payload["output_modality"] 40 | documentation_url = mock_register_payload["documentation_url"] 41 | register_response = ModelFactory.create_asset_repo( 42 | name, description, function, source_language, input_modality, output_modality, documentation_url, config.TEAM_API_KEY 43 | ) 44 | assert "id" in register_response.keys() 45 | assert "repositoryName" in register_response.keys() 46 | model_id = register_response["id"] 47 | repo_name = register_response["repositoryName"] 48 | 49 | # Log into the image repository. 50 | login_response = ModelFactory.asset_repo_login() 51 | 52 | assert login_response["username"] == "AWS" 53 | assert login_response["registry"] == "535945872701.dkr.ecr.us-east-1.amazonaws.com" 54 | assert "password" in login_response.keys() 55 | 56 | username = login_response["username"] 57 | password = login_response["password"] 58 | registry = login_response["registry"] 59 | 60 | # Push an image to ECR 61 | low_level_client = docker.APIClient(base_url="unix://var/run/docker.sock") 62 | low_level_client.pull("bash") 63 | low_level_client.tag("bash", f"{registry}/{repo_name}") 64 | low_level_client.push(f"{registry}/{repo_name}", auth_config={"username": username, "password": password}) 65 | 66 | # Send an email to finalize onboarding process 67 | ModelFactory.onboard_model(model_id, "latest", "fake_hash") 68 | 69 | # Clean up 70 | delete_service_account(config.TEAM_API_KEY) 71 | delete_asset(model_id, config.TEAM_API_KEY) 72 | -------------------------------------------------------------------------------- 
/tests/functional/model/image_upload_functional_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "michaellam" 2 | from pathlib import Path 3 | import json 4 | from aixplain.factories.model_factory import ModelFactory 5 | from tests.test_utils import delete_asset, delete_service_account 6 | from aixplain.utils import config 7 | import docker 8 | import pytest 9 | 10 | 11 | def test_login(): 12 | response = ModelFactory.asset_repo_login() 13 | assert response["username"] == "AWS" 14 | assert response["registry"] == "535945872701.dkr.ecr.us-east-1.amazonaws.com" 15 | assert "password" in response.keys() 16 | 17 | # Test cleanup 18 | delete_service_account(config.TEAM_API_KEY) 19 | 20 | 21 | def test_create_asset_repo(): 22 | with open(Path("tests/test_requests/create_asset_request.json")) as f: 23 | mock_register_payload = json.load(f) 24 | name = mock_register_payload["name"] 25 | description = mock_register_payload["description"] 26 | function = mock_register_payload["function"] 27 | source_language = mock_register_payload["sourceLanguage"] 28 | input_modality = mock_register_payload["input_modality"] 29 | output_modality = mock_register_payload["output_modality"] 30 | documentation_url = mock_register_payload["documentation_url"] 31 | response = ModelFactory.create_asset_repo( 32 | name, description, function, source_language, input_modality, output_modality, documentation_url, config.TEAM_API_KEY 33 | ) 34 | response_dict = dict(response) 35 | assert "id" in response_dict.keys() 36 | assert "repositoryName" in response_dict.keys() 37 | 38 | # Test cleanup 39 | delete_asset(response["id"], config.TEAM_API_KEY) 40 | 41 | 42 | def test_list_host_machines(): 43 | response = ModelFactory.list_host_machines() 44 | for hosting_machine_dict in response: 45 | assert "code" in hosting_machine_dict.keys() 46 | assert "type" in hosting_machine_dict.keys() 47 | assert "cores" in hosting_machine_dict.keys() 48 | assert 
"memory" in hosting_machine_dict.keys() 49 | assert "hourlyCost" in hosting_machine_dict.keys() 50 | 51 | 52 | def test_get_functions(): 53 | # Verbose 54 | response = ModelFactory.list_functions(True) 55 | items = response["items"] 56 | for item in items: 57 | assert "output" in item.keys() 58 | assert "params" in item.keys() 59 | assert "id" in item.keys() 60 | assert "name" in item.keys() 61 | 62 | # Non-verbose 63 | response = ModelFactory.list_functions() # Not verbose by default 64 | items = response["items"] 65 | for item in items: 66 | assert "output" not in item.keys() 67 | assert "params" not in item.keys() 68 | assert "id" not in item.keys() 69 | assert "name" in item.keys() 70 | 71 | 72 | @pytest.mark.skip(reason="Not included in first release") 73 | def list_image_repo_tags(): 74 | response = ModelFactory.list_image_repo_tags() 75 | assert "Image tags" in response.keys() 76 | assert "nextToken" in response.keys() 77 | -------------------------------------------------------------------------------- /tests/functional/pipelines/create_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "thiagocastroferreira" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | 19 | import json 20 | import pytest 21 | from aixplain.factories import PipelineFactory 22 | from aixplain.modules import Pipeline 23 | from uuid import uuid4 24 | from aixplain import aixplain_v2 as v2 25 | 26 | @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) 27 | def test_create_pipeline_from_json(PipelineFactory): 28 | pipeline_json = "tests/functional/pipelines/data/pipeline.json" 29 | pipeline_name = str(uuid4()) 30 | pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_json) 31 | 32 | assert isinstance(pipeline, Pipeline) 33 | assert pipeline.id != "" 34 | pipeline.delete() 35 | 36 | 37 | @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) 38 | def test_create_pipeline_from_string(PipelineFactory): 39 | pipeline_json = "tests/functional/pipelines/data/pipeline.json" 40 | with open(pipeline_json) as f: 41 | pipeline_dict = json.load(f) 42 | 43 | pipeline_name = str(uuid4()) 44 | pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_dict) 45 | 46 | assert isinstance(pipeline, Pipeline) 47 | assert pipeline.id != "" 48 | assert pipeline.status.value == "draft" 49 | 50 | pipeline.deploy() 51 | pipeline = PipelineFactory.get(pipeline.id) 52 | assert pipeline.status.value == "onboarded" 53 | pipeline.delete() 54 | 55 | 56 | @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) 57 | def test_update_pipeline(PipelineFactory): 58 | pipeline_json = "tests/functional/pipelines/data/pipeline.json" 59 | with open(pipeline_json) as f: 60 | pipeline_dict = json.load(f) 61 | 62 | pipeline_name = str(uuid4()) 63 | pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_dict) 64 | 65 | pipeline.update(pipeline=pipeline_json, save_as_asset=True, name="NEW NAME") 66 | assert pipeline.name == "NEW NAME" 67 | assert isinstance(pipeline, Pipeline) 68 | assert pipeline.id != "" 69 | pipeline.delete() 70 | 71 | 72 | 
@pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) 73 | def test_create_pipeline_wrong_path(PipelineFactory): 74 | pipeline_name = str(uuid4()) 75 | 76 | with pytest.raises(Exception): 77 | PipelineFactory.create(name=pipeline_name, pipeline="/") 78 | -------------------------------------------------------------------------------- /tests/functional/pipelines/data/pipeline.json: -------------------------------------------------------------------------------- 1 | { 2 | "links": [ 3 | { 4 | "from": 0, 5 | "to": 1, 6 | "paramMapping": [ 7 | { 8 | "from": "input", 9 | "to": "text" 10 | } 11 | ] 12 | }, 13 | { 14 | "from": 1, 15 | "to": 2, 16 | "paramMapping": [ 17 | { 18 | "from": "data", 19 | "to": "text" 20 | } 21 | ] 22 | }, 23 | { 24 | "from": 2, 25 | "to": 3, 26 | "paramMapping": [ 27 | { 28 | "from": "data", 29 | "to": "output" 30 | } 31 | ] 32 | } 33 | ], 34 | "nodes": [ 35 | { 36 | "number": 0, 37 | "type": "INPUT" 38 | }, 39 | { 40 | "number": 1, 41 | "type": "ASSET", 42 | "function": "sentiment-analysis", 43 | "inputValues": [ 44 | { 45 | "code": "language", 46 | "value": "en" 47 | }, 48 | { 49 | "code": "text", 50 | "dataType": "text" 51 | } 52 | ], 53 | "assetId": "6172874f720b09325cbcdc33", 54 | "assetType": "MODEL", 55 | "autoSelectOptions": [], 56 | "functionType": "AI", 57 | "status": "Exists", 58 | "outputValues": [ 59 | { 60 | "code": "data", 61 | "dataType": "label" 62 | } 63 | ] 64 | }, 65 | { 66 | "number": 2, 67 | "type": "ASSET", 68 | "function": "translation", 69 | "inputValues": [ 70 | { 71 | "code": "sourcelanguage", 72 | "value": "en" 73 | }, 74 | { 75 | "code": "targetlanguage", 76 | "value": "es" 77 | }, 78 | { 79 | "code": "text", 80 | "dataType": "text" 81 | } 82 | ], 83 | "assetId": "61b097551efecf30109d3316", 84 | "assetType": "MODEL", 85 | "autoSelectOptions": [], 86 | "functionType": "AI", 87 | "status": "Exists", 88 | "outputValues": [ 89 | { 90 | "code": "data", 91 | "dataType": "text" 92 | } 93 | ] 94 | }, 
95 | { 96 | "number": 3, 97 | "type": "OUTPUT" 98 | } 99 | ] 100 | } -------------------------------------------------------------------------------- /tests/functional/pipelines/data/script.py: -------------------------------------------------------------------------------- 1 | def main(speakers): 2 | # build the response 3 | response = [] 4 | for i, speaker in enumerate(speakers): 5 | print(f"Processing speaker at index={i}") 6 | data = speaker["data"] 7 | data_modified = f"SCRIPT MODIFIED: {data}" 8 | response.append( 9 | { 10 | "index": i, 11 | "success": True, 12 | "input_type": "text", 13 | "is_url": False, 14 | "details": {}, 15 | "data": data_modified, 16 | "input": data_modified, 17 | } 18 | ) 19 | return response 20 | -------------------------------------------------------------------------------- /tests/functional/team_agent/data/team_agent_test_end2end.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "team_agent_name": "TEST Multi agent", 4 | "llm_id": "6626a3a8c8f1d089790cf5a2", 5 | "llm_name": "Groq Llama 3 70B", 6 | "query": "Who is the president of Brazil right now? 
Translate to pt and synthesize in audio", 7 | "agents": [ 8 | { 9 | "agent_name": "TEST Translation agent", 10 | "llm_id": "6626a3a8c8f1d089790cf5a2", 11 | "llm_name": "Groq Llama 3 70B", 12 | "model_tools": [ 13 | { 14 | "function": "translation", 15 | "supplier": "AWS" 16 | } 17 | ] 18 | }, 19 | { 20 | "agent_name": "TEST Speech Synthesis agent", 21 | "llm_id": "6626a3a8c8f1d089790cf5a2", 22 | "llm_name": "Groq Llama 3 70B", 23 | "model_tools": [ 24 | { 25 | "function": "speech-synthesis", 26 | "supplier": "Google" 27 | } 28 | ] 29 | } 30 | ] 31 | } 32 | ] 33 | -------------------------------------------------------------------------------- /tests/mock_responses/create_asset_repo_response.json: -------------------------------------------------------------------------------- 1 | { 2 | "modelId": "mockId" 3 | } -------------------------------------------------------------------------------- /tests/mock_responses/list_host_machines_response.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "64dce914adc92335dc35beb5", 4 | "code": "aix-2c-8g-od", 5 | "type": "on-demand", 6 | "cores": 2, 7 | "memory": 8, 8 | "hourlyCost": 0.12 9 | }, 10 | { 11 | "id": "64dceafdadc92335dc35beb6", 12 | "code": "aix-2c-8g", 13 | "type": "always-on", 14 | "cores": 2, 15 | "memory": 8, 16 | "hourlyCost": 0.096 17 | } 18 | ] -------------------------------------------------------------------------------- /tests/mock_responses/list_image_repo_tags_response.json: -------------------------------------------------------------------------------- 1 | ["tag1", "tag2", "tag3", "tag4"] -------------------------------------------------------------------------------- /tests/mock_responses/login_response.json: -------------------------------------------------------------------------------- 1 | { 2 | "username": "AWS", 3 | "password": 
"eyJwYXlsb2FkIjoiNlNkQmp0WkRWbDRtQmxYMk5ZWHh0dFJQRTJFeWpScDFVczI1RTl2WUJMRmN5SVU1TE1wd3hiK2FaZWtnbjZidy9Ea3UxQ1FpcnIwRURwNXZNMTJuZXBJVzhITjNkVEtmeFNCa1RwcTNESSt2ZnVtSm5MVXM3KzlPNEo3cmRySDE5NjdnazYyb0NIRVV1WmZvOUFuUm5CeHUyU2ZmZWFndFlYVyt0dDVXeWtMQjRCRlNFaGJUelNtSnllSW9pNTlkNFNYdGtXY3pDT1RZQ281MUVlVEI0L1c4NGZMVVZQRVF6VThmdmtYRVl1TDNEUWFzc3F3dUxxcHp2bWtrSCtNOHNrdFp6bHZubXlxMnFGYkR0aElhamNXNW1Ud1BkVjJMN2w0ZFJVSTlTQ3Y1SlExbnlZZ01obUxHeDRDRG5KYmh0NndzeEtWcVpxbmMzMDR6WXZnQlZTcWFEY2VvWXV0SFEwSTVSQU1DaUtNd09SZHF0Skt1Y3FxRVBwTkxPaDhlcUFScmd4bkVCYnhQZm4zZ0M5L0x2bHBiZ1I5UFRIWGlqZlFWczNnUW5vTzFmd0R2d1dudTRsMjJDWjdSUTN4WlRNL29NdFNtZ2RScmplclpqNWo0RVMycTdQTEFXOU9UcUtieDRpZklMRUVucTIxbDBXaFNtc0xlR2g4Rm9GZkpOSGJ5L2wzUklTY2hjUzBYUUdYMXJ0cFhFOTc3bUVtdzY0WDdYT3h5UGlnZytzNWowMjhFY0VqSzV6R01sNzdDYUprcVVyZjVUUWZraTU4VURCMTNXWDlvVDVGQVUvcU9DY3F0SlQ5TlBZTnFXQ0xhamdFdk93TXFsQndkVzhKTEhwMTkwZ3psNE1nN0YwRDIvTFpScWRDVVh2SXRBSFJJUmROa1U3RDI1Y3VoL0xjSjlhZUQ2MnJiVDA1R2FIWkV5Z0d5MmxnRWlmekUvbWhPSGNUclBOSnlPTGhHaFc2L0F5dCt1MDRxNEdqMzVFQk1GSHZ0a0lLUEQ5MU04NTVKZnVMV3F3d09QR1NlZnNGRXlRNExxRGZtMkNueVpqd3NuNWRFSlR5VUZhTUMyODMwbCtBV3lZMFBQQ3l6eTFJK0FoMHV0VkJvMlBabkFPZVk1c3hOL05uOFhlbmRMbTA0Mm1wTENWOCtHd3lzYnVFM1BHRDdNV3pDaVVicm0rbXdBLzk0c3hTODlTNkJpVWhnUHp5RC84TWhyVUNNL1FTRGNFY3ZUTjVFc0N0UDM1cUdUT28yOWdxc3VzdWRLZHdEQkhWMlpkaGNNR0xQMElWNEZKN01CQVZSMnd4OTRiZXpDMm4xU3V5TGRGVVBQYVFKa2wwWmw2M3E4MU5FRjdMSzQ0M0FJbzlpV3FuazltbFBYRVo1OHdVUERnMUpZbWw4b3BCYVprazJtM2dvYk5HdEFWUHY1dDlXZitXY2Q2MDN3WnJ1TlhwUTNPSlk2WWI4ZXBMNlZpN1ErTkpaa2Z0NWl5M1FQRFpUUFZjSCs0c1VjZ0E2dmFMSUY2aEZCUncwWitRS0pvK0VZUWtFK0RTQXhMaldFYkt5ZzBSN1V3UHg0VThENjQ4My9mMlV2cU5jSFRORHNkbXRKcjlXcUwxNHRoc1BqQTNqZ3Bqc0pydDJJWTA1bEdNOWJJbGpmbUtGWFdsemppQ2ptSUNsSm14SUxIdzgvSTlKb3JYb2NmNXpoSHVzbCswUkdKc1NMTHAyOWc9PSIsImRhdGFrZXkiOiJBUUVCQUhod20wWWFJU0plUnRKbTVuMUc2dXFlZWtYdW9YWFBlNVVGY2U5UnE4LzE0d0FBQUg0d2ZBWUpLb1pJaHZjTkFRY0dvRzh3YlFJQkFEQm9CZ2txaGtpRzl3MEJCd0V3SGdZSllJWklBV1VEQkFFdU1CRUVERGFyODZkalUxNVFHNCtZaEFJQkVJQTdvY0xIeWFpUHViY2VTQ0g5djB6THd2UFZGbHU0WmJ
qZ09JSGkrdmxiNEpCVTBlNyt5VmpnT3BpcWVmQlkxbFBGWktKalgvMEIwMkJDcU1nPSIsInZlcnNpb24iOiIyIiwidHlwZSI6IkRBVEFfS0VZIiwiZXhwaXJhdGlvbiI6MTY5MjYxNDYwMX0=", 4 | "registry": "https://535945872701.dkr.ecr.us-east-1.amazonaws.com" 5 | } -------------------------------------------------------------------------------- /tests/test_requests/create_asset_request.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mock_name", 3 | "description": "mock_description", 4 | "function": "Text Generation", 5 | "sourceLanguage": "en", 6 | "input_modality": "text", 7 | "output_modality": "text", 8 | "documentation_url": "" 9 | } -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from aixplain.utils.request_utils import _request_with_retry 2 | from urllib.parse import urljoin 3 | import logging 4 | from aixplain.utils import config 5 | 6 | 7 | def delete_asset(model_id, api_key): 8 | delete_url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") 9 | logging.debug(f"URL: {delete_url}") 10 | headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} 11 | _ = _request_with_retry("delete", delete_url, headers=headers) 12 | 13 | 14 | def delete_service_account(api_key): 15 | delete_url = urljoin(config.BACKEND_URL, "sdk/ecr/logout") 16 | logging.debug(f"URL: {delete_url}") 17 | headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} 18 | _ = _request_with_retry("post", delete_url, headers=headers) 19 | -------------------------------------------------------------------------------- /tests/unit/benchmark_test.py: -------------------------------------------------------------------------------- 1 | import requests_mock 2 | import pytest 3 | from urllib.parse import urljoin 4 | from aixplain.utils import config 5 | from aixplain.factories import 
MetricFactory, BenchmarkFactory 6 | from aixplain.modules.model import Model 7 | from aixplain.modules.dataset import Dataset 8 | 9 | 10 | def test_create_benchmark_error_response(): 11 | metric_list = [MetricFactory.get("66df3e2d6eb56336b6628171")] 12 | with requests_mock.Mocker() as mock: 13 | name = "test-benchmark" 14 | dataset_list = [ 15 | Dataset( 16 | id="dataset1", 17 | name="Dataset 1", 18 | description="Test dataset", 19 | function="test_func", 20 | source_data="src", 21 | target_data="tgt", 22 | onboard_status="onboarded", 23 | ) 24 | ] 25 | model_list = [ 26 | Model(id="model1", name="Model 1", description="Test model", supplier="Test supplier", cost=10, version="v1") 27 | ] 28 | 29 | url = urljoin(config.BACKEND_URL, "sdk/benchmarks") 30 | headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} 31 | 32 | error_response = {"statusCode": 400, "message": "Invalid request"} 33 | mock.post(url, headers=headers, json=error_response, status_code=400) 34 | 35 | with pytest.raises(Exception) as excinfo: 36 | BenchmarkFactory.create(name=name, dataset_list=dataset_list, model_list=model_list, metric_list=metric_list) 37 | 38 | assert "Benchmark Creation Error: Status 400 - {'statusCode': 400, 'message': 'Invalid request'}" in str(excinfo.value) 39 | 40 | 41 | def test_get_benchmark_error(): 42 | with requests_mock.Mocker() as mock: 43 | benchmark_id = "test-benchmark-id" 44 | url = urljoin(config.BACKEND_URL, f"sdk/benchmarks/{benchmark_id}") 45 | headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} 46 | 47 | error_response = {"statusCode": 404, "message": "Benchmark not found"} 48 | mock.get(url, headers=headers, json=error_response, status_code=404) 49 | 50 | with pytest.raises(Exception) as excinfo: 51 | BenchmarkFactory.get(benchmark_id) 52 | 53 | assert "Benchmark GET Error: Status 404 - {'statusCode': 404, 'message': 'Benchmark not found'}" in str(excinfo.value) 54 | 55 
| 56 | def test_list_normalization_options_error(): 57 | metric = MetricFactory.get("66df3e2d6eb56336b6628171") 58 | with requests_mock.Mocker() as mock: 59 | model = Model(id="model1", name="Test Model", description="Test model", supplier="Test supplier", cost=10, version="v1") 60 | 61 | url = urljoin(config.BACKEND_URL, "sdk/benchmarks/normalization-options") 62 | headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} 63 | 64 | error_response = {"message": "Internal Server Error"} 65 | mock.post(url, headers=headers, json=error_response, status_code=500) 66 | 67 | with pytest.raises(Exception) as excinfo: 68 | BenchmarkFactory.list_normalization_options(metric, model) 69 | 70 | assert "Error listing normalization options: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) 71 | -------------------------------------------------------------------------------- /tests/unit/corpus_test.py: -------------------------------------------------------------------------------- 1 | from aixplain.factories import CorpusFactory 2 | import pytest 3 | import requests_mock 4 | from urllib.parse import urljoin 5 | from aixplain.utils import config 6 | 7 | 8 | def test_get_corpus_error_response(): 9 | with requests_mock.Mocker() as mock: 10 | corpus_id = "invalid_corpus_id" 11 | url = urljoin(config.BACKEND_URL, f"sdk/corpora/{corpus_id}/overview") 12 | headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} 13 | 14 | error_response = {"message": "Not Found"} 15 | mock.get(url, headers=headers, json=error_response, status_code=404) 16 | 17 | with pytest.raises(Exception) as excinfo: 18 | CorpusFactory.get(corpus_id=corpus_id) 19 | 20 | assert "Corpus GET Error: Status 404 - {'message': 'Not Found'}" in str(excinfo.value) 21 | 22 | 23 | def test_list_corpus_error_response(): 24 | with requests_mock.Mocker() as mock: 25 | url = urljoin(config.BACKEND_URL, 
"sdk/corpora/paginate") 26 | headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} 27 | 28 | error_response = {"message": "Internal Server Error"} 29 | mock.post(url, headers=headers, json=error_response, status_code=500) 30 | 31 | with pytest.raises(Exception) as excinfo: 32 | CorpusFactory.list(query="test_query", page_number=0, page_size=20) 33 | 34 | assert "Corpus List Error: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) 35 | -------------------------------------------------------------------------------- /tests/unit/data/create_finetune_percentage_exception.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "train_percentage": 0, 4 | "dev_percentage": 100 5 | }, 6 | { 7 | "train_percentage": 0, 8 | "dev_percentage": 0 9 | }, 10 | { 11 | "train_percentage": 80, 12 | "dev_percentage": 30 13 | } 14 | ] -------------------------------------------------------------------------------- /tests/unit/dataset_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests_mock 3 | from aixplain.factories import DatasetFactory 4 | from urllib.parse import urljoin 5 | from aixplain.utils import config 6 | 7 | 8 | def test_list_dataset_error_response(): 9 | with requests_mock.Mocker() as mock: 10 | url = urljoin(config.BACKEND_URL, "sdk/datasets/paginate") 11 | headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} 12 | 13 | error_response = {"message": "Internal Server Error"} 14 | mock.post(url, headers=headers, json=error_response, status_code=500) 15 | 16 | with pytest.raises(Exception) as excinfo: 17 | DatasetFactory.list(query="test_query", page_number=0, page_size=20) 18 | 19 | assert "Dataset List Error: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) 20 | 21 | 22 | def test_get_dataset_error_response(): 23 | 
with requests_mock.Mocker() as mock: 24 | dataset_id = "invalid_dataset_id" 25 | url = urljoin(config.BACKEND_URL, f"sdk/datasets/{dataset_id}/overview") 26 | headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} 27 | 28 | error_response = {"message": "Not Found"} 29 | mock.get(url, headers=headers, json=error_response, status_code=404) 30 | 31 | with pytest.raises(Exception) as excinfo: 32 | DatasetFactory.get(dataset_id=dataset_id) 33 | 34 | assert "Dataset GET Error: Status 404 - {'message': 'Not Found'}" in str(excinfo.value) 35 | -------------------------------------------------------------------------------- /tests/unit/hyperparameters_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "lucaspavanelli" 2 | 3 | """ 4 | Copyright 2022 The aiXplain SDK authors 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | 19 | from dotenv import load_dotenv 20 | 21 | load_dotenv() 22 | 23 | from aixplain.modules.finetune import Hyperparameters 24 | from aixplain.modules.finetune.hyperparameters import ( 25 | EPOCHS_MAX_VALUE, 26 | BATCH_SIZE_VALUES, 27 | MAX_SEQ_LENGTH_MAX_VALUE, 28 | ) 29 | 30 | 31 | import pytest 32 | 33 | 34 | def test_create(): 35 | hyp = Hyperparameters() 36 | assert hyp is not None 37 | 38 | 39 | @pytest.mark.parametrize( 40 | "params", 41 | [ 42 | {"epochs": "string"}, 43 | {"train_batch_size": "string"}, 44 | {"eval_batch_size": "string"}, 45 | {"learning_rate": "string"}, 46 | {"max_seq_length": "string"}, 47 | {"warmup_ratio": "string"}, 48 | {"warmup_steps": "string"}, 49 | {"lr_scheduler_type": "string"}, 50 | ], 51 | ) 52 | def test_create_invalid_type(params): 53 | with pytest.raises(Exception) as exc_info: 54 | Hyperparameters(**params) 55 | assert exc_info.type is TypeError 56 | 57 | 58 | @pytest.mark.parametrize( 59 | "params", 60 | [ 61 | {"epochs": EPOCHS_MAX_VALUE + 1}, 62 | {"train_batch_size": max(BATCH_SIZE_VALUES) + 1}, 63 | {"eval_batch_size": max(BATCH_SIZE_VALUES) + 1}, 64 | {"max_seq_length": MAX_SEQ_LENGTH_MAX_VALUE + 1}, 65 | ], 66 | ) 67 | def test_create_invalid_values(params): 68 | with pytest.raises(Exception) as exc_info: 69 | Hyperparameters(**params) 70 | assert exc_info.type is ValueError 71 | -------------------------------------------------------------------------------- /tests/unit/image_upload_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "michaellam" 2 | 3 | import json 4 | import requests_mock 5 | from pathlib import Path 6 | from aixplain.utils import config 7 | from urllib.parse import urljoin 8 | import pytest 9 | 10 | from aixplain.factories.model_factory import ModelFactory 11 | 12 | AUTH_FIXED_HEADER = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} 13 | API_FIXED_HEADER = {"x-api-key": 
f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} 14 | 15 | 16 | def test_login(): 17 | url = urljoin(config.BACKEND_URL, f"sdk/ecr/login") 18 | with requests_mock.Mocker() as mock: 19 | with open(Path("tests/mock_responses/login_response.json")) as f: 20 | mock_json = json.load(f) 21 | mock.post(url, headers=AUTH_FIXED_HEADER, json=mock_json) 22 | creds = ModelFactory.asset_repo_login(config.TEAM_API_KEY) 23 | assert creds == mock_json 24 | 25 | 26 | def test_create_asset_repo(): 27 | url_register = urljoin(config.BACKEND_URL, f"sdk/models/onboard") 28 | url_function = urljoin(config.BACKEND_URL, f"sdk/functions") 29 | print(f"URL_Register {url_register}") 30 | with requests_mock.Mocker() as mock: 31 | with open(Path("tests/mock_responses/create_asset_repo_response.json")) as f: 32 | mock_json_register = json.load(f) 33 | mock.post(url_register, headers=API_FIXED_HEADER, json=mock_json_register, status_code=201) 34 | 35 | with open(Path("tests/mock_responses/list_functions_response.json")) as f: 36 | mock_json_functions = json.load(f) 37 | mock.get(url_function, headers=AUTH_FIXED_HEADER, json=mock_json_functions) 38 | 39 | model_id = ModelFactory.create_asset_repo( 40 | "mock_name", "mock_description", "Text Generation", "en", "text", "text", api_key=config.TEAM_API_KEY 41 | ) 42 | # print(f"Model ID {model_id}") 43 | assert model_id == mock_json_register 44 | 45 | 46 | def test_list_host_machines(): 47 | url = urljoin(config.BACKEND_URL, f"sdk/hosting-machines") 48 | with requests_mock.Mocker() as mock: 49 | with open(Path("tests/mock_responses/list_host_machines_response.json")) as f: 50 | mock_json = json.load(f) 51 | mock.get(url, headers=API_FIXED_HEADER, json=mock_json) 52 | machines = ModelFactory.list_host_machines(config.TEAM_API_KEY) 53 | for i in range(len(machines)): 54 | machine_dict = machines[i] 55 | mock_json_dict = mock_json[i] 56 | for key in machine_dict.keys(): 57 | assert machine_dict[key] == mock_json_dict[key] 58 | 59 | 60 | def 
test_get_functions(): 61 | url = urljoin(config.BACKEND_URL, f"sdk/functions") 62 | with requests_mock.Mocker() as mock: 63 | with open(Path("tests/mock_responses/list_functions_response.json")) as f: 64 | mock_json = json.load(f) 65 | mock.get(url, headers=AUTH_FIXED_HEADER, json=mock_json) 66 | functions = ModelFactory.list_functions(config.TEAM_API_KEY) 67 | assert functions == mock_json 68 | 69 | 70 | @pytest.mark.skip(reason="Not currently supported.") 71 | def test_list_image_repo_tags(): 72 | model_id = "mock_id" 73 | url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}/images") 74 | with requests_mock.Mocker() as mock: 75 | with open(Path("tests/mock_responses/list_image_repo_tags_response.json")) as f: 76 | mock_json = json.load(f) 77 | mock.get(url, headers=AUTH_FIXED_HEADER, json=mock_json) 78 | tags = ModelFactory.list_image_repo_tags(model_id, config.TEAM_API_KEY) 79 | assert tags == mock_json 80 | -------------------------------------------------------------------------------- /tests/unit/mock_responses/cost_estimation_response.json: -------------------------------------------------------------------------------- 1 | { 2 | "trainingCost": { 3 | "total": 1, 4 | "supplierCost": 0, 5 | "overheadCost": 1, 6 | "isDependingOnTrainingTime": false, 7 | "willRefundIfLowerThanMax": false, 8 | "totalVolume": 98.72, 9 | "unitPrice": 0, 10 | "timeScale": null 11 | }, 12 | "inferenceCost": [ 13 | { 14 | "unitPrice": 0.023333333333333334, 15 | "unitType": "TIME", 16 | "unitTypeScale": "MINUTE", 17 | "volume": 0 18 | } 19 | ], 20 | "hostingCost": { 21 | "currentMonthPrice": 28.3526, 22 | "monthlyPrice": 38.736, 23 | "pricePerCycle": 0.0538, 24 | "supplierBillingCycle": "HOUR", 25 | "willRefundIfLowerThanMax": true 26 | } 27 | } -------------------------------------------------------------------------------- /tests/unit/mock_responses/finetune_response.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": 
"MODEL_ID", 3 | "status": "MODEL_STATUS" 4 | } -------------------------------------------------------------------------------- /tests/unit/mock_responses/finetune_status_response.json: -------------------------------------------------------------------------------- 1 | { 2 | "finetuneStatus": "onboarding", 3 | "modelStatus": "onboarded", 4 | "logs": [ 5 | { 6 | "epoch": 1, 7 | "learningRate": 9.938725490196079e-05, 8 | "trainLoss": 0.1, 9 | "evalLoss": 0.1106, 10 | "step": 10 11 | }, 12 | { 13 | "epoch": 2, 14 | "learningRate": 9.877450980392157e-05, 15 | "trainLoss": 0.2, 16 | "evalLoss": 0.0482, 17 | "step": 20 18 | }, 19 | { 20 | "epoch": 3, 21 | "learningRate": 9.816176470588235e-05, 22 | "trainLoss": 0.3, 23 | "evalLoss": 0.0251, 24 | "step": 30 25 | }, 26 | { 27 | "epoch": 4, 28 | "learningRate": 9.754901960784314e-05, 29 | "trainLoss": 0.9, 30 | "evalLoss": 0.0228, 31 | "step": 40 32 | }, 33 | { 34 | "epoch": 5, 35 | "learningRate": 9.693627450980392e-05, 36 | "trainLoss": 0.4, 37 | "evalLoss": 0.0217, 38 | "step": 50 39 | } 40 | ] 41 | } -------------------------------------------------------------------------------- /tests/unit/mock_responses/finetune_status_response_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "65fb26268fe9153a6c9c29c4", 3 | "finetuneStatus": "in_progress", 4 | "modelStatus": "training", 5 | "logs": [ 6 | { 7 | "epoch": 1, 8 | "learningRate": null, 9 | "trainLoss": null, 10 | "validationLoss": null, 11 | "step": null, 12 | "evalLoss": 2.684150457382, 13 | "totalFlos": null, 14 | "evalRuntime": 12.4129, 15 | "trainRuntime": null, 16 | "evalStepsPerSecond": 0.322, 17 | "trainStepsPerSecond": null, 18 | "evalSamplesPerSecond": 16.112 19 | }, 20 | { 21 | "epoch": 2, 22 | "learningRate": null, 23 | "trainLoss": null, 24 | "validationLoss": null, 25 | "step": null, 26 | "evalLoss": 2.596168756485, 27 | "totalFlos": null, 28 | "evalRuntime": 11.8249, 29 | "trainRuntime": null, 30 | 
"evalStepsPerSecond": 0.338, 31 | "trainStepsPerSecond": null, 32 | "evalSamplesPerSecond": 16.913 33 | }, 34 | { 35 | "epoch": 2, 36 | "learningRate": null, 37 | "trainLoss": 2.657801408034, 38 | "validationLoss": null, 39 | "step": null, 40 | "evalLoss": null, 41 | "totalFlos": 11893948284928, 42 | "evalRuntime": null, 43 | "trainRuntime": 221.7946, 44 | "evalStepsPerSecond": null, 45 | "trainStepsPerSecond": 0.117, 46 | "evalSamplesPerSecond": null 47 | } 48 | ] 49 | } -------------------------------------------------------------------------------- /tests/unit/utility_tool_decorator_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from aixplain.enums import DataType 3 | from aixplain.enums.asset_status import AssetStatus 4 | from aixplain.modules.model.utility_model import utility_tool, UtilityModelInput 5 | 6 | def test_utility_tool_basic_decoration(): 7 | """Test basic decoration with minimal parameters""" 8 | @utility_tool( 9 | name="test_function", 10 | description="Test function description" 11 | ) 12 | def test_func(input_text: str) -> str: 13 | return input_text 14 | 15 | assert hasattr(test_func, '_is_utility_tool') 16 | assert test_func._is_utility_tool is True 17 | assert test_func._tool_name == "test_function" 18 | assert test_func._tool_description == "Test function description" 19 | assert test_func._tool_inputs == [] 20 | assert test_func._tool_output_examples == "" 21 | assert test_func._tool_status == AssetStatus.DRAFT 22 | 23 | def test_utility_tool_with_all_parameters(): 24 | """Test decoration with all optional parameters""" 25 | inputs = [ 26 | UtilityModelInput(name="text_input", type=DataType.TEXT, description="A text input"), 27 | UtilityModelInput(name="num_input", type=DataType.NUMBER, description="A number input") 28 | ] 29 | 30 | @utility_tool( 31 | name="full_test_function", 32 | description="Full test function description", 33 | inputs=inputs, 34 | 
output_examples="Example output: Hello World", 35 | status=AssetStatus.ONBOARDED 36 | ) 37 | def test_func(text_input: str, num_input: int) -> str: 38 | return f"{text_input} {num_input}" 39 | 40 | assert test_func._is_utility_tool is True 41 | assert test_func._tool_name == "full_test_function" 42 | assert test_func._tool_description == "Full test function description" 43 | assert len(test_func._tool_inputs) == 2 44 | assert test_func._tool_inputs == inputs 45 | assert test_func._tool_output_examples == "Example output: Hello World" 46 | assert test_func._tool_status == AssetStatus.ONBOARDED 47 | 48 | def test_utility_tool_function_still_callable(): 49 | """Test that decorated function remains callable""" 50 | @utility_tool( 51 | name="callable_test", 52 | description="Test function callable" 53 | ) 54 | def test_func(x: int, y: int) -> int: 55 | return x + y 56 | 57 | assert test_func(2, 3) == 5 58 | assert test_func._is_utility_tool is True 59 | 60 | def test_utility_tool_invalid_inputs(): 61 | """Test validation of invalid inputs""" 62 | with pytest.raises(ValueError): 63 | @utility_tool( 64 | name="", # Empty name should raise error 65 | description="Test description" 66 | ) 67 | def test_func(): 68 | pass 69 | 70 | with pytest.raises(ValueError): 71 | @utility_tool( 72 | name="test_name", 73 | description="" # Empty description should raise error 74 | ) 75 | def test_func(): 76 | pass -------------------------------------------------------------------------------- /tests/unit/v2/test_core.py: -------------------------------------------------------------------------------- 1 | import os 2 | from aixplain.v2.core import Aixplain 3 | from aixplain.v2.model import Model 4 | from aixplain.v2.pipeline import Pipeline 5 | from aixplain.v2.agent import Agent 6 | from unittest.mock import patch 7 | 8 | 9 | def test_aixplain_instance(): 10 | # mock init_env, init_client, init_resources 11 | with patch.object(Aixplain, "init_env"): 12 | with patch.object(Aixplain, 
"init_client"): 13 | with patch.object(Aixplain, "init_resources"): 14 | aixplain = Aixplain(api_key="test") 15 | assert aixplain is not None 16 | assert aixplain.api_key == "test" 17 | assert ( 18 | aixplain.base_url == os.getenv("BACKEND_URL") 19 | or "https://platform-api.aixplain.com" 20 | ) 21 | assert ( 22 | aixplain.pipeline_url == os.getenv("PIPELINES_RUN_URL") 23 | or "https://platform-api.aixplain.com/assets/pipeline/execution/run" 24 | ) 25 | assert ( 26 | aixplain.model_url == os.getenv("MODELS_RUN_URL") 27 | or "https://models.aixplain.com/api/v1/execute" 28 | ) 29 | aixplain.init_env.assert_called_once() 30 | aixplain.init_client.assert_called_once() 31 | aixplain.init_resources.assert_called_once() 32 | 33 | 34 | def test_aixplain_init_env(): 35 | aixplain = Aixplain( 36 | api_key="test", 37 | backend_url="https://platform-api.aixplain.com", 38 | pipeline_url="https://platform-api.aixplain.com/assets/pipeline/execution/run", 39 | model_url="https://models.aixplain.com/api/v1/execute", 40 | ) 41 | with patch.object(os, "environ", new=dict()) as mock_environ: 42 | aixplain.init_env() 43 | assert mock_environ["TEAM_API_KEY"] == "test" 44 | assert mock_environ["BACKEND_URL"] == "https://platform-api.aixplain.com" 45 | assert ( 46 | mock_environ["PIPELINE_URL"] 47 | == "https://platform-api.aixplain.com/assets/pipeline/execution/run" 48 | ) 49 | assert mock_environ["MODEL_URL"] == "https://models.aixplain.com/api/v1/execute" 50 | 51 | 52 | def test_aixplain_init_client(): 53 | aixplain = Aixplain(api_key="test") 54 | with patch("aixplain.v2.core.AixplainClient") as mock_client: 55 | aixplain.init_client() 56 | mock_client.assert_called_once_with( 57 | base_url="https://platform-api.aixplain.com", 58 | team_api_key="test", 59 | ) 60 | assert aixplain.client is not None 61 | 62 | 63 | def test_aixplain_init_resources(): 64 | aixplain = Aixplain(api_key="test") 65 | with patch.object(Aixplain, "init_resources"): 66 | aixplain.init_resources() 67 | assert 
aixplain.Model is not None 68 | assert aixplain.Pipeline is not None 69 | assert aixplain.Agent is not None 70 | assert aixplain.Model.context == aixplain 71 | assert aixplain.Pipeline.context == aixplain 72 | assert aixplain.Agent.context == aixplain 73 | 74 | assert issubclass(aixplain.Model, Model) 75 | assert issubclass(aixplain.Pipeline, Pipeline) 76 | assert issubclass(aixplain.Agent, Agent) 77 | 78 | # check if the resources are NOT the same class type 79 | assert aixplain.Pipeline != Pipeline 80 | assert aixplain.Model != Model 81 | assert aixplain.Agent != Agent 82 | -------------------------------------------------------------------------------- /tests/unit/wallet_test.py: -------------------------------------------------------------------------------- 1 | __author__ = "aixplain" 2 | 3 | from aixplain.factories import WalletFactory 4 | import aixplain.utils.config as config 5 | import requests_mock 6 | 7 | 8 | def test_wallet_service(): 9 | with requests_mock.Mocker() as mock: 10 | url = f"{config.BACKEND_URL}/sdk/billing/wallet" 11 | headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} 12 | ref_response = {"totalBalance": 5, "reservedBalance": "0"} 13 | mock.get(url, headers=headers, json=ref_response) 14 | wallet = WalletFactory.get() 15 | assert wallet.total_balance == float(ref_response["totalBalance"]) 16 | assert wallet.reserved_balance == float(ref_response["reservedBalance"]) 17 | assert wallet.available_balance == float(ref_response["totalBalance"]) - float(ref_response["reservedBalance"]) 18 | --------------------------------------------------------------------------------