├── .coveragerc ├── .github ├── pull_request_template.md └── workflows │ ├── publish.yml │ ├── requirements.txt │ ├── run_tests.yml │ ├── run_tests_prod.yml │ └── run_tests_staging.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── .ruff.toml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── clarifai ├── __init__.py ├── cli.py ├── cli │ ├── README.md │ ├── __init__.py │ ├── __main__.py │ ├── base.py │ ├── compute_cluster.py │ ├── deployment.py │ ├── model.py │ ├── model_templates.py │ └── nodepool.py ├── client │ ├── __init__.py │ ├── app.py │ ├── auth │ │ ├── __init__.py │ │ ├── helper.py │ │ ├── register.py │ │ └── stub.py │ ├── base.py │ ├── compute_cluster.py │ ├── dataset.py │ ├── deployment.py │ ├── input.py │ ├── lister.py │ ├── model.py │ ├── model_client.py │ ├── module.py │ ├── nodepool.py │ ├── runner.py │ ├── search.py │ ├── user.py │ └── workflow.py ├── constants │ ├── base.py │ ├── dataset.py │ ├── input.py │ ├── model.py │ ├── rag.py │ ├── search.py │ └── workflow.py ├── datasets │ ├── __init__.py │ ├── export │ │ ├── __init__.py │ │ └── inputs_annotations.py │ └── upload │ │ ├── __init__.py │ │ ├── base.py │ │ ├── features.py │ │ ├── image.py │ │ ├── loaders │ │ ├── README.md │ │ ├── __init__.py │ │ ├── coco_captions.py │ │ ├── coco_detection.py │ │ ├── imagenet_classification.py │ │ └── xview_detection.py │ │ ├── multimodal.py │ │ ├── text.py │ │ └── utils.py ├── errors.py ├── models │ ├── __init__.py │ └── api.py ├── modules │ ├── README.md │ ├── __init__.py │ ├── css.py │ ├── pages.py │ └── style.css ├── rag │ ├── __init__.py │ ├── rag.py │ └── utils.py ├── runners │ ├── __init__.py │ ├── dockerfile_template │ │ └── Dockerfile.template │ ├── models │ │ ├── __init__.py │ │ ├── dummy_openai_model.py │ │ ├── mcp_class.py │ │ ├── model_builder.py │ │ ├── model_class.py │ │ ├── model_run_locally.py │ │ ├── model_runner.py │ │ ├── model_servicer.py │ │ ├── openai_class.py │ │ ├── visual_classifier_class.py │ │ └── visual_detector_class.py │ ├── server.py │ └── utils │ │ ├── __init__.py │ │ ├── code_script.py │ │ ├── const.py │ │ ├── data_types │ │ ├── __init__.py │ │ └── data_types.py │ │ ├── data_utils.py │ │ ├── loader.py │ │ ├── method_signatures.py │ │ ├── openai_convertor.py │ │ ├── serializers.py │ │ └── url_fetcher.py ├── schema │ └── search.py ├── urls │ └── helper.py ├── utils │ ├── __init__.py │ ├── cli.py │ ├── config.py │ ├── constants.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── helpers.py │ │ ├── main.py │ │ └── testset_annotation_parser.py │ ├── logging.py │ ├── misc.py │ ├── model_train.py │ └── protobuf.py ├── versions.py └── workflows │ ├── __init__.py │ ├── export.py │ ├── utils.py │ └── validate.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── scripts └── key_for_tests.py ├── setup.py └── tests ├── __init__.py ├── assets ├── coco_detection │ ├── images │ │ ├── 3176048.jpg │ │ ├── architectural-design-architecture-asphalt-2445783.jpg │ │ └── architecture-buildings-commerce-2308592.jpg │ └── instances_default.json ├── imagenet_classification │ ├── LOC_synset_mapping.txt │ └── train │ │ ├── n01855672 │ │ ├── n01855672_0.JPEG │ │ ├── n01855672_1.JPEG │ │ ├── n01855672_2.JPEG │ │ ├── n01855672_3.JPEG │ │ └── n01855672_4.JPEG │ │ └── n02113799 │ │ ├── n02113799_0.JPEG │ │ ├── n02113799_1.JPEG │ │ ├── n02113799_2.JPEG │ │ ├── n02113799_3.JPEG │ │ └── n02113799_4.JPEG ├── red-truck.png ├── sample.csv ├── sample.mp3 ├── sample.mp4 ├── sample.txt ├── sample_texts │ ├── sample1.txt │ ├── sample2.txt │ └── 
sample3.txt ├── test │ ├── zorua.png │ ├── zubat.png │ └── zweilous.png └── voc │ ├── __init__.py │ ├── annotations │ ├── 2007_000464.xml │ ├── 2008_000853.xml │ ├── 2008_003182.xml │ ├── 2008_008526.xml │ ├── 2009_004315.xml │ ├── 2009_004382.xml │ ├── 2011_000430.xml │ ├── 2011_001610.xml │ ├── 2011_006412.xml │ └── 2012_000690.xml │ ├── dataset.py │ └── images │ ├── 2007_000464.jpg │ ├── 2008_000853.jpg │ ├── 2008_003182.jpg │ ├── 2008_008526.jpg │ ├── 2009_004315.jpg │ ├── 2009_004382.jpg │ ├── 2011_000430.jpg │ ├── 2011_001610.jpg │ ├── 2011_006412.jpg │ └── 2012_000690.jpg ├── cli └── test_compute_orchestration.py ├── client └── test_model_upload_predict.py ├── compute_orchestration └── configs │ ├── example_compute_cluster_config.yaml │ ├── example_deployment_config.yaml │ └── example_nodepool_config.yaml ├── conftest.py ├── openai_model_test.py ├── requirements.txt ├── runners ├── dummy_mcp_model │ ├── 1 │ │ └── model.py │ ├── config.yaml │ └── requirements.txt ├── dummy_runner_models │ ├── 1 │ │ └── model.py │ ├── config.yaml │ └── requirements.txt ├── hf_mbart_model │ ├── 1 │ │ └── model.py │ ├── config.yaml │ └── requirements.txt ├── test_data_handler.py ├── test_download_checkpoints.py ├── test_model_classes.py ├── test_model_run_locally-container.py ├── test_model_run_locally.py ├── test_model_signatures.py ├── test_model_upload.py ├── test_num_threads_config.py ├── test_openai_model.py ├── test_runners.py └── test_url_fetcher.py ├── test_app.py ├── test_auth.py ├── test_data_upload.py ├── test_eval.py ├── test_misc.py ├── test_model_predict.py ├── test_model_train.py ├── test_modules.py ├── test_rag.py ├── test_search.py ├── test_stub.py └── workflow ├── fixtures ├── general.yml ├── multi_branch.yml ├── single_branch_with_custom_cropper_model-version.yml ├── single_branch_with_custom_cropper_model.yml ├── single_branch_with_public_cropper_model.yml ├── single_branch_with_public_cropper_model_and_latest_version.yml └── single_node.yml ├── test_create_delete.py ├── test_export.py ├── test_nodes_display.py ├── test_predict.py └── test_validate.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = tests/*, setup.py, clarifai/models/api.py 3 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | 14 | 15 | 16 | 17 | ### Why 18 | 19 | * 20 | 21 | ### How 22 | 23 | * 24 | 25 | ### Tests 26 | 27 | * 28 | 29 | ### Notes 30 | 31 | * 32 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Publish package 5 | 6 | on: 7 | push: 8 | tags: 9 | - '[0-9]+.[0-9]+.[0-9a-zA-Z]+' # Matches 1.2.3, 1.2.3alpha1 etc. 
10 | 11 | jobs: 12 | publish-pypi: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.9' 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install -r .github/workflows/requirements.txt 24 | - name: Build and publish 25 | env: 26 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 28 | run: | 29 | python -m build 30 | twine upload --non-interactive dist/* 31 | publish-github-release: 32 | needs: publish-pypi 33 | name: Create Release 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@v4 37 | - name: Create Release 38 | uses: softprops/action-gh-release@v1 39 | with: 40 | name: Release ${{ github.ref_name }} 41 | draft: false 42 | prerelease: false 43 | -------------------------------------------------------------------------------- /.github/workflows/requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools==70.0.0 2 | build==1.2.1 3 | twine==6.1.0 4 | -------------------------------------------------------------------------------- /.github/workflows/run_tests_prod.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests - Prod 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: '0 6,18 * * *' 7 | 8 | jobs: 9 | sdk-python-tests-prod: 10 | uses: Clarifai/clarifai-python/.github/workflows/run_tests.yml@master 11 | with: 12 | PERIODIC_CHECKS: "true" 13 | CLARIFAI_ENV: "prod" 14 | CLARIFAI_GRPC_BASE: "api.clarifai.com" 15 | secrets: inherit 16 | -------------------------------------------------------------------------------- /.github/workflows/run_tests_staging.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests - Staging 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: '0 6,18 * * *' 7 | 8 | jobs: 9 | sdk-python-tests-staging: 10 | uses: Clarifai/clarifai-python/.github/workflows/run_tests.yml@master 11 | with: 12 | PERIODIC_CHECKS: "true" 13 | CLARIFAI_ENV: "staging" 14 | CLARIFAI_GRPC_BASE: "api-staging.clarifai.com" 15 | secrets: inherit 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Editor Directories 105 | .idea 106 | *.swp 107 | .vscode 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # macOS specific files 137 | .DS_Store 138 | 139 | # temp files. 140 | *~ 141 | *# 142 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length=99 3 | 4 | # Ignore differences in whitespace. 5 | ignore_whitespace=True 6 | 7 | known_third_party= 8 | grpc, 9 | yaml, 10 | cached_property, 11 | lap, 12 | magic, 13 | PIL, 14 | pydub, 15 | google, 16 | mock, 17 | requests, 18 | googleapiclient, 19 | oauth2client, 20 | dns, 21 | noise, 22 | urllib3, 23 | simplejson, 24 | imageio, 25 | retrying, 26 | json_lines 27 | 28 | 29 | # isort being dumb 30 | known_future_library= 31 | future, 32 | six, 33 | past 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_stages: 2 | - pre-commit 3 | - manual 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | - id: check-added-large-files 9 | - id: check-executables-have-shebangs 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: end-of-file-fixer 13 | - id: mixed-line-ending 14 | - id: trailing-whitespace 15 | - repo: https://github.com/astral-sh/ruff-pre-commit 16 | # Ruff version. 17 | rev: v0.11.4 18 | hooks: 19 | # Run the linter. 20 | - id: ruff 21 | args: [ --fix, --exit-non-zero-on-fix ] 22 | # Run the formatter. 
23 | - id: ruff-format 24 | -------------------------------------------------------------------------------- /.ruff.toml: -------------------------------------------------------------------------------- 1 | indent-width = 4 2 | line-length = 99 3 | 4 | [lint] 5 | select = [ 6 | "F", 7 | "E", 8 | "W", 9 | "I", 10 | "PLE", # pylint errors 11 | "PLW", # pylint warnings 12 | "PLC", # pylint conventions 13 | "PLR", # pylint refactor 14 | 15 | # TODO to add more, just bigger diff: 16 | # "C", # comprehension cleanups 17 | # "D", # docstring formatting 18 | # "RUF", # additional ruff specific things 19 | ] 20 | 21 | ignore = [ 22 | # Things we shouldn't bother fixing: 23 | "E111", # indentation multiple of 4. 24 | "E402", # Module level import not at top of file 25 | "E501", # line-too-long 26 | "E701", # Multiple statements on one line (colon) 27 | "E722", # bare-except 28 | "E731", # Do not assign a `lambda` expression, use a `def 29 | "E741", # Ambiguous variable name 30 | "E743", # Ambiguous function name 31 | "W605", # invalid escape sequence 32 | "C416", # unnecessary-comprehension 33 | "C901", # too complex 34 | "PLC1802", # use len without comparison 35 | "PLC0206", # values from dict without items 36 | "PLR1714", # merging multiple comparisons 37 | "PLW1508", # invalid-envvar-default 38 | "PLR0911", # too-many-return-statements 39 | "PLR0912", # too-many-branches 40 | "PLR0913", # too-many-arguments 41 | "PLR0915", # too-many-statements 42 | 43 | # TODO: Should fix: 44 | "F841", # unused-variable 45 | "W291", # trailing-whitespace 46 | 47 | # TODO: pylint ones we haven't fixed yet but should: 48 | "PLW0603", # global-statement 49 | "PLW2901", # loop var overwritten by assignment target 50 | "PLR2004", # needs a constant 51 | "PLR1704", # redefined-argument-from-local 52 | ] 53 | 54 | [lint.per-file-ignores] 55 | # Ignore autogenerate proto quirks 56 | "proto/**.py" = ["F401","E712","F403","I001"] 57 | "*_pb2*.py" = ["F401","E712","F821","E501","I001"] 58 | "*.ipynb" = ["F401","F821","I001"] 59 | # we do lots of type checking in here 60 | "utils/argspec.py" = ["E721"] 61 | # wildcard imports 62 | "conf/segmentation/test.py" = ["F403","F405"] 63 | "conf/tf_striate/slim/__init__.py" = ["F403"] 64 | 65 | [format] 66 | # just keep what we have so it's less of a change. 67 | quote-style = "preserve" 68 | # always use spaces instead of tabs 69 | indent-style = "space" 70 | 71 | exclude = [ 72 | "*pb2*.py", # Exclude generated protos 73 | "proto/**.py", # Exclude generated protos 74 | "*.ipynb", # Skip notebooks. 75 | ] 76 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Clarifai, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-include *.css 2 | include clarifai/modules/style.css 3 | recursive-include clarifai * 4 | include requirements.txt 5 | -------------------------------------------------------------------------------- /clarifai/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "11.4.10" 2 | -------------------------------------------------------------------------------- /clarifai/cli.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/cli.py -------------------------------------------------------------------------------- /clarifai/cli/README.md: -------------------------------------------------------------------------------- 1 | # Clarifai CLI 2 | 3 | ## Overview 4 | 5 | Clarifai offers a user-friendly interface for deploying your local model into production with Clarifai, featuring: 6 | 7 | * A convenient command-line interface (CLI) 8 | * Easy implementation and testing in Python 9 | * No need for MLOps expertise. 10 | 11 | ## Compute Orchestration 12 | 13 | A quick example of deploying a `visual-classifier` model: 14 | 15 | ### Login 16 | 17 | First, log in to the CLI using your Clarifai account details in a config file, as shown below: 18 | 19 | ```bash 20 | $ clarifai login --config <config-filepath> 21 | ``` 22 | 23 | ### Setup 24 | 25 | Before deploying, set up a Compute Cluster with a Nodepool that has the server configuration the model requires. 26 | 27 | First, create a new Compute Cluster: 28 | ```bash 29 | $ clarifai computecluster create --config <config-filepath> 30 | ``` 31 | 32 | Then, create a new Nodepool in the new Compute Cluster: 33 | ```bash 34 | $ clarifai nodepool create --config <config-filepath> 35 | ``` 36 | 37 | ### Deployment 38 | 39 | After setup, deploy the `visual-classifier` model using a deployment config file, as shown below: 40 | 41 | ```bash 42 | $ clarifai deployment create --config <config-filepath> 43 | ``` 44 | 45 | ### List Resources 46 | 47 | List existing Compute Clusters: 48 | 49 | ```bash 50 | $ clarifai computecluster list 51 | ``` 52 | 53 | List existing Nodepools: 54 | 55 | ```bash 56 | $ clarifai nodepool list --compute_cluster_id <compute-cluster-id> 57 | ``` 58 | 59 | List existing Deployments: 60 | 61 | ```bash 62 | $ clarifai deployment list --nodepool_id <nodepool-id> 63 | ``` 64 | 65 | ### Delete Resources 66 | 67 | Delete an existing Deployment: 68 | 69 | ```bash 70 | $ clarifai deployment delete --nodepool_id <nodepool-id> --deployment_id <deployment-id> 71 | ``` 72 | 73 | Delete an existing Nodepool: 74 | 75 | ```bash 76 | $ clarifai nodepool delete --compute_cluster_id <compute-cluster-id> --nodepool_id <nodepool-id> 77 | ``` 78 | 79 | Delete an existing Compute Cluster: 80 | 81 | ```bash 82 | $ clarifai computecluster delete --compute_cluster_id <compute-cluster-id> 83 | ``` 84 | 85 | ## Learn More 86 | 87 | * [Example Configs](https://github.com/Clarifai/examples/tree/main/ComputeOrchestration/configs)
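
For reference, the `--config` files used above are plain YAML. A minimal compute cluster config might look like the sketch below; the field names here are illustrative assumptions, so treat the example configs linked above as the authoritative schema:

```yaml
# Illustrative sketch only -- see the linked example configs for the real schema.
compute_cluster:
  id: "my-compute-cluster"
  description: "Cluster for visual-classifier deployments"
  cloud_provider:
    id: "aws"
  region: "us-east-1"
```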
88 | -------------------------------------------------------------------------------- /clarifai/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/cli/__init__.py -------------------------------------------------------------------------------- /clarifai/cli/__main__.py: -------------------------------------------------------------------------------- 1 | from clarifai.cli.base import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /clarifai/cli/compute_cluster.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import click 4 | 5 | from clarifai.cli.base import cli 6 | from clarifai.utils.cli import AliasedGroup, display_co_resources, validate_context 7 | 8 | 9 | @cli.group( 10 | ['computecluster', 'cc'], 11 | cls=AliasedGroup, 12 | context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, 13 | ) 14 | def computecluster(): 15 | """Manage Compute Clusters: create, delete, list""" 16 | 17 | 18 | @computecluster.command(['c']) 19 | @click.argument('compute_cluster_id') 20 | @click.option( 21 | '--config', 22 | type=click.Path(exists=True), 23 | required=True, 24 | help='Path to the compute cluster config file.', 25 | ) 26 | @click.pass_context 27 | def create(ctx, compute_cluster_id, config): 28 | """Create a new Compute Cluster with the given config file.""" 29 | from clarifai.client.user import User 30 | 31 | validate_context(ctx) 32 | user = User( 33 | user_id=ctx.obj.current.user_id, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base 34 | ) 35 | if compute_cluster_id: 36 | user.create_compute_cluster(config, compute_cluster_id=compute_cluster_id) 37 | else: 38 | user.create_compute_cluster(config) 39 | 40 | 41 | @computecluster.command(['ls']) 42 | @click.option('--page_no', required=False, help='Page number to list.', default=1) 43 | @click.option('--per_page', required=False, help='Number of items per page.', default=16) 44 | @click.pass_context 45 | def list(ctx, page_no, per_page): 46 | """List all compute clusters for the user.""" 47 | from clarifai.client.user import User 48 | 49 | validate_context(ctx) 50 | user = User( 51 | user_id=ctx.obj.current.user_id, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base 52 | ) 53 | response = user.list_compute_clusters(page_no, per_page) 54 | display_co_resources( 55 | response, 56 | custom_columns={ 57 | 'ID': lambda c: c.id, 58 | 'USER_ID': lambda c: c.user_id, 59 | 'DESCRIPTION': lambda c: c.description, 60 | }, 61 | ) 62 | 63 | 64 | @computecluster.command(['rm']) 65 | @click.argument('compute_cluster_id') 66 | @click.pass_context 67 | def delete(ctx, compute_cluster_id): 68 | """Deletes a compute cluster for the user.""" 69 | from clarifai.client.user import User 70 | 71 | validate_context(ctx) 72 | user = User( 73 | user_id=ctx.obj.current.user_id, pat=ctx.obj.current.pat, base_url=ctx.obj.current.api_base 74 | ) 75 | user.delete_compute_clusters([compute_cluster_id]) 76 | -------------------------------------------------------------------------------- /clarifai/cli/deployment.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import click 4 | 5 | from clarifai.cli.base import cli 6 | from clarifai.utils.cli import AliasedGroup, display_co_resources, from_yaml, validate_context 7 | 8 | 9 | @cli.group( 10 | ['deployment', 'dp'], 11 | cls=AliasedGroup, 12 | context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, 13 | ) 14 | def deployment(): 15 | """Manage Deployments: create, delete, list""" 16 | 17 | 18 | @deployment.command(['c']) 19 | @click.argument('nodepool_id') 20 | @click.argument('deployment_id') 21 | @click.option( 22 |
'--config', 23 | type=click.Path(exists=True), 24 | required=True, 25 | help='Path to the deployment config file.', 26 | ) 27 | @click.pass_context 28 | def create(ctx, nodepool_id, deployment_id, config): 29 | """Create a new Deployment with the given config file.""" 30 | 31 | from clarifai.client.nodepool import Nodepool 32 | 33 | validate_context(ctx) 34 | if not nodepool_id: 35 | deployment_config = from_yaml(config) 36 | nodepool_id = deployment_config['deployment']['nodepools'][0]['id'] 37 | 38 | nodepool = Nodepool( 39 | nodepool_id=nodepool_id, 40 | user_id=ctx.obj.current.user_id, 41 | pat=ctx.obj.current.pat, 42 | base_url=ctx.obj.current.api_base, 43 | ) 44 | if deployment_id: 45 | nodepool.create_deployment(config, deployment_id=deployment_id) 46 | else: 47 | nodepool.create_deployment(config) 48 | 49 | 50 | @deployment.command(['ls']) 51 | @click.argument('nodepool_id', default="") 52 | @click.option('--page_no', required=False, help='Page number to list.', default=1) 53 | @click.option('--per_page', required=False, help='Number of items per page.', default=16) 54 | @click.pass_context 55 | def list(ctx, nodepool_id, page_no, per_page): 56 | """List all deployments for the nodepool.""" 57 | from clarifai.client.compute_cluster import ComputeCluster 58 | from clarifai.client.nodepool import Nodepool 59 | from clarifai.client.user import User 60 | 61 | validate_context(ctx) 62 | if nodepool_id: 63 | nodepool = Nodepool( 64 | nodepool_id=nodepool_id, 65 | user_id=ctx.obj.current.user_id, 66 | pat=ctx.obj.current.pat, 67 | base_url=ctx.obj.current.api_base, 68 | ) 69 | response = nodepool.list_deployments(page_no=page_no, per_page=per_page) 70 | else: 71 | user = User( 72 | user_id=ctx.obj.current.user_id, 73 | pat=ctx.obj.current.pat, 74 | base_url=ctx.obj.current.api_base, 75 | ) 76 | ccs = user.list_compute_clusters(page_no, per_page) 77 | nps = [] 78 | for cc in ccs: 79 | compute_cluster = ComputeCluster( 80 | compute_cluster_id=cc.id, 81 | user_id=ctx.obj.current.user_id, 82 | pat=ctx.obj.current.pat, 83 | base_url=ctx.obj.current.api_base, 84 | ) 85 | nps.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)]) 86 | response = [] 87 | for np in nps: 88 | nodepool = Nodepool( 89 | nodepool_id=np.id, 90 | user_id=ctx.obj.current.user_id, 91 | pat=ctx.obj.current.pat, 92 | base_url=ctx.obj.current.api_base, 93 | ) 94 | response.extend( 95 | [i for i in nodepool.list_deployments(page_no=page_no, per_page=per_page)] 96 | ) 97 | 98 | display_co_resources( 99 | response, 100 | custom_columns={ 101 | 'ID': lambda c: c.id, 102 | 'USER_ID': lambda c: c.user_id, 103 | 'COMPUTE_CLUSTER_ID': lambda c: c.nodepools[0].compute_cluster.id, 104 | 'NODEPOOL_ID': lambda c: c.nodepools[0].id, 105 | 'MODEL_USER_ID': lambda c: c.worker.model.user_id, 106 | 'MODEL_APP_ID': lambda c: c.worker.model.app_id, 107 | 'MODEL_ID': lambda c: c.worker.model.id, 108 | 'MODEL_VERSION_ID': lambda c: c.worker.model.model_version.id, 109 | 'DESCRIPTION': lambda c: c.description, 110 | }, 111 | ) 112 | 113 | 114 | @deployment.command(['rm']) 115 | @click.argument('nodepool_id') 116 | @click.argument('deployment_id') 117 | @click.pass_context 118 | def delete(ctx, nodepool_id, deployment_id): 119 | """Deletes a deployment for the nodepool.""" 120 | from clarifai.client.nodepool import Nodepool 121 | 122 | validate_context(ctx) 123 | nodepool = Nodepool( 124 | nodepool_id=nodepool_id, 125 | user_id=ctx.obj.current.user_id, 126 | pat=ctx.obj.current.pat, 127 | base_url=ctx.obj.current.api_base, 128 | ) 
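# The Nodepool client constructed above is scoped to the target nodepool; the call below performs the actual deletion.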
129 | nodepool.delete_deployments([deployment_id]) 130 | -------------------------------------------------------------------------------- /clarifai/cli/nodepool.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import click 4 | 5 | from clarifai.cli.base import cli 6 | from clarifai.utils.cli import ( 7 | AliasedGroup, 8 | display_co_resources, 9 | dump_yaml, 10 | from_yaml, 11 | validate_context, 12 | ) 13 | 14 | 15 | @cli.group( 16 | ['nodepool', 'np'], 17 | cls=AliasedGroup, 18 | context_settings={'max_content_width': shutil.get_terminal_size().columns - 10}, 19 | ) 20 | def nodepool(): 21 | """Manage Nodepools: create, delete, list""" 22 | 23 | 24 | @nodepool.command(['c']) 25 | @click.argument('compute_cluster_id') 26 | @click.argument('nodepool_id') 27 | @click.option( 28 | '--config', 29 | type=click.Path(exists=True), 30 | required=True, 31 | help='Path to the nodepool config file.', 32 | ) 33 | @click.pass_context 34 | def create(ctx, compute_cluster_id, nodepool_id, config): 35 | """Create a new Nodepool with the given config file.""" 36 | from clarifai.client.compute_cluster import ComputeCluster 37 | 38 | validate_context(ctx) 39 | nodepool_config = from_yaml(config) 40 | if not compute_cluster_id: 41 | if 'compute_cluster' not in nodepool_config['nodepool']: 42 | click.echo( 43 | "Please provide a compute cluster ID either in the config file or using --compute_cluster_id flag", 44 | err=True, 45 | ) 46 | return 47 | compute_cluster_id = nodepool_config['nodepool']['compute_cluster']['id'] 48 | elif 'compute_cluster' not in nodepool_config['nodepool']: 49 | nodepool_config['nodepool']['compute_cluster'] = {'id': compute_cluster_id}  # create the missing section instead of indexing into it 50 | dump_yaml(config, nodepool_config) 51 | 52 | compute_cluster = ComputeCluster( 53 | compute_cluster_id=compute_cluster_id, 54 | user_id=ctx.obj.current.user_id, 55 | pat=ctx.obj.current.pat, 56 | base_url=ctx.obj.current.api_base, 57 | ) 58 | if nodepool_id: 59 | compute_cluster.create_nodepool(config, nodepool_id=nodepool_id) 60 | else: 61 | compute_cluster.create_nodepool(config) 62 | 63 | 64 | @nodepool.command(['ls']) 65 | @click.argument('compute_cluster_id', default="") 66 | @click.option('--page_no', required=False, help='Page number to list.', default=1) 67 | @click.option('--per_page', required=False, help='Number of items per page.', default=128) 68 | @click.pass_context 69 | def list(ctx, compute_cluster_id, page_no, per_page): 70 | """List all nodepools for the user across all compute clusters.
If compute_cluster_id is provided 71 | it will list only within that compute cluster.""" 72 | from clarifai.client.compute_cluster import ComputeCluster 73 | from clarifai.client.user import User 74 | 75 | validate_context(ctx) 76 | 77 | cc_id = compute_cluster_id 78 | 79 | if cc_id: 80 | compute_cluster = ComputeCluster( 81 | compute_cluster_id=cc_id, 82 | user_id=ctx.obj.current.user_id, 83 | pat=ctx.obj.current.pat, 84 | base_url=ctx.obj.current.api_base, 85 | ) 86 | response = compute_cluster.list_nodepools(page_no, per_page) 87 | else: 88 | user = User( 89 | user_id=ctx.obj.current.user_id, 90 | pat=ctx.obj.current.pat, 91 | base_url=ctx.obj.current.api_base, 92 | ) 93 | ccs = user.list_compute_clusters(page_no, per_page) 94 | response = [] 95 | for cc in ccs: 96 | compute_cluster = ComputeCluster( 97 | compute_cluster_id=cc.id, 98 | user_id=ctx.obj.current.user_id, 99 | pat=ctx.obj.current.pat, 100 | base_url=ctx.obj.current.api_base, 101 | ) 102 | response.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)]) 103 | 104 | display_co_resources( 105 | response, 106 | custom_columns={ 107 | 'ID': lambda c: c.id, 108 | 'USER_ID': lambda c: c.compute_cluster.user_id, 109 | 'COMPUTE_CLUSTER_ID': lambda c: c.compute_cluster.id, 110 | 'DESCRIPTION': lambda c: c.description, 111 | }, 112 | ) 113 | 114 | 115 | @nodepool.command(['rm']) 116 | @click.argument('compute_cluster_id') 117 | @click.argument('nodepool_id') 118 | @click.pass_context 119 | def delete(ctx, compute_cluster_id, nodepool_id): 120 | """Deletes a nodepool for the user.""" 121 | from clarifai.client.compute_cluster import ComputeCluster 122 | 123 | validate_context(ctx) 124 | compute_cluster = ComputeCluster( 125 | compute_cluster_id=compute_cluster_id, 126 | user_id=ctx.obj.current.user_id, 127 | pat=ctx.obj.current.pat, 128 | base_url=ctx.obj.current.api_base, 129 | ) 130 | compute_cluster.delete_nodepools([nodepool_id]) 131 | -------------------------------------------------------------------------------- /clarifai/client/__init__.py: -------------------------------------------------------------------------------- 1 | from clarifai.client.app import App 2 | from clarifai.client.auth.register import V2Stub 3 | from clarifai.client.auth.stub import create_stub 4 | from clarifai.client.base import BaseClient 5 | from clarifai.client.dataset import Dataset 6 | from clarifai.client.input import Inputs 7 | from clarifai.client.lister import Lister 8 | from clarifai.client.model import Model 9 | from clarifai.client.module import Module 10 | from clarifai.client.search import Search 11 | from clarifai.client.user import User 12 | from clarifai.client.workflow import Workflow 13 | 14 | __all__ = [ 15 | 'V2Stub', 16 | 'create_stub', 17 | 'User', 18 | 'App', 19 | 'Model', 20 | 'Workflow', 21 | 'Module', 22 | 'Lister', 23 | 'Dataset', 24 | 'Inputs', 25 | 'BaseClient', 26 | 'Search', 27 | ] 28 | -------------------------------------------------------------------------------- /clarifai/client/auth/__init__.py: -------------------------------------------------------------------------------- 1 | from clarifai.client.auth.register import V2Stub 2 | from clarifai.client.auth.stub import create_stub 3 | 4 | __all__ = ('V2Stub', 'create_stub') 5 | -------------------------------------------------------------------------------- /clarifai/client/auth/register.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import grpc 4 | from clarifai_grpc.grpc.api import 
service_pb2_grpc 5 | 6 | 7 | class V2Stub(abc.ABC): 8 | """Abstract base class of clarifai api rpc client stubs""" 9 | 10 | 11 | class RpcCallable(abc.ABC): 12 | """Abstract base class of clarifai api rpc callables""" 13 | 14 | 15 | # add grpc classes as subclasses of the abcs, so they also succeed in isinstance calls 16 | def _register_classes(): 17 | V2Stub.register(service_pb2_grpc.V2Stub) 18 | for name in dir(grpc): 19 | if name.endswith('Callable'): 20 | RpcCallable.register(getattr(grpc, name)) 21 | 22 | 23 | _register_classes() 24 | -------------------------------------------------------------------------------- /clarifai/client/deployment.py: -------------------------------------------------------------------------------- 1 | from clarifai_grpc.grpc.api import resources_pb2 2 | 3 | from clarifai.client.base import BaseClient 4 | from clarifai.client.lister import Lister 5 | from clarifai.utils.constants import DEFAULT_BASE 6 | from clarifai.utils.logging import logger 7 | from clarifai.utils.protobuf import dict_to_protobuf 8 | 9 | 10 | class Deployment(Lister, BaseClient): 11 | """Deployment is a class that provides access to Clarifai API endpoints related to Deployment information.""" 12 | 13 | def __init__( 14 | self, 15 | deployment_id: str = None, 16 | user_id: str = None, 17 | base_url: str = DEFAULT_BASE, 18 | pat: str = None, 19 | token: str = None, 20 | root_certificates_path: str = None, 21 | **kwargs, 22 | ): 23 | """Initializes a Deployment object. 24 | 25 | Args: 26 | deployment_id (str): The Deployment ID for the Deployment to interact with. 27 | user_id (str): The user ID of the user. 28 | base_url (str): Base API url. Default "https://api.clarifai.com" 29 | pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT 30 | token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN 31 | root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections. 32 | **kwargs: Additional keyword arguments to be passed to the deployment. 33 | """ 34 | self.kwargs = {**kwargs, 'id': deployment_id, 'user_id': user_id} 35 | self.deployment_info = resources_pb2.Deployment() 36 | dict_to_protobuf(self.deployment_info, self.kwargs) 37 | self.logger = logger 38 | BaseClient.__init__( 39 | self, 40 | user_id=user_id, 41 | base=base_url, 42 | pat=pat, 43 | token=token, 44 | root_certificates_path=root_certificates_path, 45 | ) 46 | Lister.__init__(self) 47 | 48 | @staticmethod 49 | def get_runner_selector(user_id: str, deployment_id: str) -> resources_pb2.RunnerSelector: 50 | """Returns a RunnerSelector object for the given deployment_id. 51 | 52 | Args: 53 | deployment_id (str): The deployment ID for the deployment. 54 | 55 | Returns: 56 | resources_pb2.RunnerSelector: A RunnerSelector object for the given deployment_id. 
57 | """ 58 | return resources_pb2.RunnerSelector( 59 | deployment=resources_pb2.Deployment(id=deployment_id, user_id=user_id) 60 | ) 61 | 62 | def __getattr__(self, name): 63 | return getattr(self.deployment_info, name) 64 | 65 | def __str__(self): 66 | init_params = [param for param in self.kwargs.keys()] 67 | attribute_strings = [ 68 | f"{param}={getattr(self.deployment_info, param)}" 69 | for param in init_params 70 | if hasattr(self.deployment_info, param) 71 | ] 72 | return f"Deployment Details: \n{', '.join(attribute_strings)}\n" 73 | -------------------------------------------------------------------------------- /clarifai/client/lister.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Generator 2 | 3 | from clarifai_grpc.grpc.api.status import status_code_pb2 4 | from google.protobuf.json_format import MessageToDict 5 | 6 | from clarifai.client.base import BaseClient 7 | 8 | 9 | class Lister(BaseClient): 10 | """Lister class for obtaining paginated results from the Clarifai API.""" 11 | 12 | def __init__(self, page_size: int = 16): 13 | self.default_page_size = page_size 14 | 15 | def list_pages_generator( 16 | self, 17 | endpoint: Callable, 18 | proto_message: Any, 19 | request_data: Dict[str, Any], 20 | page_no: int = None, 21 | per_page: int = None, 22 | ) -> Generator[Dict[str, Any], None, None]: 23 | """Lists pages of a resource. 24 | 25 | Args: 26 | endpoint (Callable): The endpoint to call. 27 | proto_message (Any): The proto message to use. 28 | request_data (dict): The request data to use. 29 | page_no (int): The page number to list. 30 | per_page (int): The number of items per page. 31 | 32 | Yields: 33 | response_dict: The next item in the listing. 34 | """ 35 | page = 1 if not page_no else page_no 36 | if page_no and not per_page: 37 | per_page = self.default_page_size 38 | while True: 39 | request_data['page'] = page 40 | request_data['per_page'] = per_page 41 | response = self._grpc_request(endpoint, proto_message(**request_data)) 42 | dict_response = MessageToDict(response, preserving_proto_field_name=True) 43 | if response.status.code != status_code_pb2.SUCCESS: 44 | raise Exception(f"Listing failed with response {response!r}") 45 | if len(list(dict_response.keys())) == 1: 46 | break 47 | else: 48 | listing_resource = list(dict_response.keys())[1] 49 | for item in dict_response[listing_resource]: 50 | if listing_resource == "dataset_inputs": 51 | yield self.process_response_keys(item["input"], listing_resource[:-1]) 52 | else: 53 | yield self.process_response_keys(item, listing_resource[:-1]) 54 | if page_no is not None or per_page is not None: 55 | break 56 | page += 1 57 | -------------------------------------------------------------------------------- /clarifai/client/module.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Generator 2 | 3 | from clarifai_grpc.grpc.api import resources_pb2, service_pb2 4 | 5 | from clarifai.client.base import BaseClient 6 | from clarifai.client.lister import Lister 7 | from clarifai.errors import UserError 8 | from clarifai.urls.helper import ClarifaiUrlHelper 9 | from clarifai.utils.constants import DEFAULT_BASE 10 | from clarifai.utils.logging import logger 11 | 12 | 13 | class Module(Lister, BaseClient): 14 | """Module is a class that provides access to Clarifai API endpoints related to Module information.""" 15 | 16 | def __init__( 17 | self, 18 | url: str = None, 19 | module_id: str 
= None, 20 | module_version: Dict = {'id': ""}, 21 | base_url: str = DEFAULT_BASE, 22 | pat: str = None, 23 | token: str = None, 24 | root_certificates_path: str = None, 25 | **kwargs, 26 | ): 27 | """Initializes a Module object. 28 | 29 | Args: 30 | url (str): The URL to initialize the module object. 31 | module_id (str): The Module ID to interact with. 32 | module_version (dict): The Module Version to interact with. 33 | base_url (str): Base API url. Default "https://api.clarifai.com" 34 | pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT. 35 | token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN. 36 | root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections. 37 | **kwargs: Additional keyword arguments to be passed to the Module. 38 | """ 39 | if url and module_id: 40 | raise UserError("You can only specify one of url or module_id.") 41 | if not url and not module_id: 42 | raise UserError("You must specify one of url or module_id.") 43 | if url: 44 | user_id, app_id, module_id, module_version_id = ClarifaiUrlHelper.split_module_ui_url( 45 | url 46 | ) 47 | module_version = {'id': module_version_id} 48 | kwargs = {'user_id': user_id, 'app_id': app_id} 49 | 50 | self.kwargs = {**kwargs, 'id': module_id, 'module_version': module_version} 51 | self.module_info = resources_pb2.Module(**self.kwargs) 52 | self.logger = logger 53 | BaseClient.__init__( 54 | self, 55 | user_id=self.user_id, 56 | app_id=self.app_id, 57 | base=base_url, 58 | pat=pat, 59 | token=token, 60 | root_certificates_path=root_certificates_path, 61 | ) 62 | Lister.__init__(self) 63 | 64 | def list_versions( 65 | self, page_no: int = None, per_page: int = None 66 | ) -> Generator['Module', None, None]: 67 | """Lists all the module versions for the module. 68 | 69 | Args: 70 | page_no (int): The page number to list. 71 | per_page (int): The number of items per page. 72 | 73 | Yields: 74 | Module: Module objects for versions of the module. 75 | 76 | Example: 77 | >>> from clarifai.client.module import Module 78 | >>> module = Module(module_id='module_id', user_id='user_id', app_id='app_id') 79 | >>> all_module_versions = list(module.list_versions()) 80 | 81 | Note: 82 | Defaults to 16 per page if page_no is specified and per_page is not specified. 83 | If both page_no and per_page are None, then lists all the resources.
84 | """ 85 | request_data = dict( 86 | user_app_id=self.user_app_id, 87 | module_id=self.id, 88 | ) 89 | all_module_versions_info = self.list_pages_generator( 90 | self.STUB.ListModuleVersions, 91 | service_pb2.ListModuleVersionsRequest, 92 | request_data, 93 | per_page=per_page, 94 | page_no=page_no, 95 | ) 96 | 97 | for module_version_info in all_module_versions_info: 98 | module_version_info['id'] = module_version_info['module_version_id'] 99 | del module_version_info['module_version_id'] 100 | yield Module.from_auth_helper( 101 | self.auth_helper, 102 | module_id=self.id, 103 | **dict(self.kwargs, module_version=module_version_info), 104 | ) 105 | 106 | def __getattr__(self, name): 107 | return getattr(self.module_info, name) 108 | 109 | def __str__(self): 110 | init_params = [param for param in self.kwargs.keys()] 111 | attribute_strings = [ 112 | f"{param}={getattr(self.module_info, param)}" 113 | for param in init_params 114 | if hasattr(self.module_info, param) 115 | ] 116 | return f"Module Details: \n{', '.join(attribute_strings)}\n" 117 | -------------------------------------------------------------------------------- /clarifai/client/runner.py: -------------------------------------------------------------------------------- 1 | from clarifai_grpc.grpc.api import resources_pb2 2 | 3 | from clarifai.client.base import BaseClient 4 | from clarifai.client.lister import Lister 5 | from clarifai.utils.constants import DEFAULT_BASE 6 | from clarifai.utils.logging import logger 7 | from clarifai.utils.protobuf import dict_to_protobuf 8 | 9 | 10 | class Runner(Lister, BaseClient): 11 | """Runner is a class that provides access to Clarifai API endpoints related to Runner information.""" 12 | 13 | def __init__( 14 | self, 15 | runner_id: str = None, 16 | user_id: str = None, 17 | base_url: str = DEFAULT_BASE, 18 | pat: str = None, 19 | token: str = None, 20 | root_certificates_path: str = None, 21 | **kwargs, 22 | ): 23 | """Initializes a Runner object. 24 | 25 | Args: 26 | runner_id (str): The Runner ID for the Runner to interact with. 27 | user_id (str): The user ID of the user. 28 | base_url (str): Base API url. Default "https://api.clarifai.com" 29 | pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT 30 | token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN 31 | root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections. 32 | **kwargs: Additional keyword arguments to be passed to the runner. 
33 | """ 34 | self.kwargs = {**kwargs, 'id': runner_id} 35 | self.runner_info = resources_pb2.Runner() 36 | dict_to_protobuf(self.runner_info, self.kwargs) 37 | 38 | self.logger = logger 39 | BaseClient.__init__( 40 | self, 41 | user_id=user_id, 42 | base=base_url, 43 | pat=pat, 44 | token=token, 45 | root_certificates_path=root_certificates_path, 46 | ) 47 | Lister.__init__(self) 48 | 49 | def __getattr__(self, name): 50 | return getattr(self.runner_info, name) 51 | 52 | def __str__(self): 53 | init_params = [param for param in self.kwargs.keys()] 54 | attribute_strings = [ 55 | f"{param}={getattr(self.runner_info, param)}" 56 | for param in init_params 57 | if hasattr(self.runner_info, param) 58 | ] 59 | return f"Runner Details: \n{', '.join(attribute_strings)}\n" 60 | -------------------------------------------------------------------------------- /clarifai/constants/base.py: -------------------------------------------------------------------------------- 1 | COMPUTE_ORCHESTRATION_RESOURCES = ['Runner', 'ComputeCluster', 'Nodepool', 'Deployment'] 2 | -------------------------------------------------------------------------------- /clarifai/constants/dataset.py: -------------------------------------------------------------------------------- 1 | DATASET_UPLOAD_TASKS = [ 2 | "visual_classification", 3 | "text_classification", 4 | "visual_detection", 5 | "visual_segmentation", 6 | "visual_captioning", 7 | "multimodal_dataset", 8 | ] 9 | 10 | TASK_TO_ANNOTATION_TYPE = { 11 | "visual_classification": {"concepts": "labels"}, 12 | "text_classification": {"concepts": "labels"}, 13 | "visual_captioning": {"concepts": "labels"}, 14 | "visual_detection": {"bboxes": "bboxes"}, 15 | "visual_segmentation": {"polygons": "polygons"}, 16 | } 17 | 18 | MAX_RETRIES = 2 19 | 20 | CONTENT_TYPE = {"json": "application/json", "zip": "application/zip"} 21 | -------------------------------------------------------------------------------- /clarifai/constants/input.py: -------------------------------------------------------------------------------- 1 | MAX_UPLOAD_BATCH_SIZE = 128 2 | -------------------------------------------------------------------------------- /clarifai/constants/model.py: -------------------------------------------------------------------------------- 1 | TRAINABLE_MODEL_TYPES = [ 2 | 'visual-classifier', 3 | 'visual-detector', 4 | 'visual-segmenter', 5 | 'visual-embedder', 6 | 'clusterer', 7 | 'text-classifier', 8 | 'embedding-classifier', 9 | 'text-to-text', 10 | ] 11 | MAX_MODEL_PREDICT_INPUTS = 128 12 | MODEL_EXPORT_TIMEOUT = 1800 13 | MIN_RANGE_SIZE = 4194304 # 4MB 14 | MAX_RANGE_SIZE = 314572800 # 300MB 15 | MIN_CHUNK_SIZE = 131072 # 128KB 16 | MAX_CHUNK_SIZE = 10485760 # 10MB 17 | RANGE_SIZE = 31457280 # 30MB 18 | CHUNK_SIZE = 1048576 # 1MB 19 | -------------------------------------------------------------------------------- /clarifai/constants/rag.py: -------------------------------------------------------------------------------- 1 | MAX_UPLOAD_BATCH_SIZE = 128 2 | -------------------------------------------------------------------------------- /clarifai/constants/search.py: -------------------------------------------------------------------------------- 1 | DEFAULT_TOP_K = 10 2 | DEFAULT_SEARCH_METRIC = "euclidean" 3 | DEFAULT_SEARCH_ALGORITHM = "nearest_neighbor" 4 | -------------------------------------------------------------------------------- /clarifai/constants/workflow.py: -------------------------------------------------------------------------------- 1 | 
MAX_WORKFLOW_PREDICT_INPUTS = 32 2 | -------------------------------------------------------------------------------- /clarifai/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/datasets/__init__.py -------------------------------------------------------------------------------- /clarifai/datasets/export/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/datasets/export/__init__.py -------------------------------------------------------------------------------- /clarifai/datasets/upload/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/datasets/upload/__init__.py -------------------------------------------------------------------------------- /clarifai/datasets/upload/base.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Iterator, List, Tuple, TypeVar, Union 3 | 4 | from clarifai_grpc.grpc.api import resources_pb2 5 | 6 | from clarifai.constants.dataset import DATASET_UPLOAD_TASKS 7 | from clarifai.datasets.upload.features import ( 8 | MultiModalFeatures, 9 | TextFeatures, 10 | VisualClassificationFeatures, 11 | VisualDetectionFeatures, 12 | VisualSegmentationFeatures, 13 | ) 14 | 15 | OutputFeaturesType = TypeVar( 16 | 'OutputFeaturesType', 17 | bound=Union[ 18 | TextFeatures, 19 | VisualClassificationFeatures, 20 | VisualDetectionFeatures, 21 | VisualSegmentationFeatures, 22 | MultiModalFeatures, 23 | ], 24 | ) 25 | 26 | 27 | class ClarifaiDataset: 28 | """Clarifai datasets base class.""" 29 | 30 | def __init__( 31 | self, data_generator: 'ClarifaiDataLoader', dataset_id: str, max_workers: int = 4 32 | ) -> None: 33 | self.data_generator = data_generator 34 | self.dataset_id = dataset_id 35 | self.max_workers = max_workers 36 | self.all_input_ids = {} 37 | self._all_input_protos = {} 38 | self._all_annotation_protos = defaultdict(list) 39 | 40 | def __len__(self) -> int: 41 | """Get size of all input protos""" 42 | return len(self.data_generator) 43 | 44 | def _to_list(self, input_protos: Iterator) -> List: 45 | """Parse protos iterator to list.""" 46 | return list(input_protos) 47 | 48 | def _extract_protos(self, input_ids: List[int]) -> None: 49 | """Create input protos for each data generator item; signature matches the call in get_protos below.""" 50 | raise NotImplementedError() 51 | 52 | def get_protos( 53 | self, input_ids: List[int] 54 | ) -> Tuple[List[resources_pb2.Input], List[resources_pb2.Annotation]]: 55 | """Get input and annotation protos based on input_ids. 56 | Args: 57 | input_ids: List of input IDs to retrieve the protos for. 58 | Returns: 59 | Input and Annotation proto iterators for the specified input IDs.
60 | """ 61 | input_protos, annotation_protos = self._extract_protos(input_ids) 62 | 63 | return input_protos, annotation_protos 64 | 65 | 66 | class ClarifaiDataLoader: 67 | """Clarifai data loader base class.""" 68 | 69 | def __init__(self) -> None: 70 | pass 71 | 72 | @property 73 | def task(self): 74 | raise NotImplementedError("Task should be one of {}".format(DATASET_UPLOAD_TASKS)) 75 | 76 | def load_data(self) -> None: 77 | raise NotImplementedError() 78 | 79 | def __len__(self) -> int: 80 | raise NotImplementedError() 81 | 82 | def __getitem__(self, index: int) -> OutputFeaturesType: 83 | raise NotImplementedError() 84 | -------------------------------------------------------------------------------- /clarifai/datasets/upload/features.py: -------------------------------------------------------------------------------- 1 | #! dataset output features (output from preprocessing & input to clarifai data proto builders) 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Union 4 | 5 | 6 | @dataclass 7 | class TextFeatures: 8 | """Text classification datasets preprocessing output features.""" 9 | 10 | text: str 11 | labels: List[Union[str, int]] = None # List[str or int] to cater for multi-class tasks 12 | id: Optional[int] = None # text_id 13 | metadata: Optional[dict] = None 14 | label_ids: Optional[List[str]] = None 15 | 16 | 17 | @dataclass 18 | class VisualClassificationFeatures: 19 | """Image classification datasets preprocessing output features.""" 20 | 21 | image_path: str 22 | labels: List[Union[str, int]] # List[str or int] to cater for multi-class tasks 23 | geo_info: Optional[List[float]] = None # [Longitude, Latitude] 24 | id: Optional[int] = None # image_id 25 | metadata: Optional[dict] = None 26 | image_bytes: Optional[bytes] = None 27 | label_ids: Optional[List[str]] = None 28 | 29 | 30 | @dataclass 31 | class VisualDetectionFeatures: 32 | """Image Detection datasets preprocessing output features.""" 33 | 34 | image_path: str 35 | labels: List[Union[str, int]] 36 | bboxes: List[List[float]] 37 | geo_info: Optional[List[float]] = None # [Longitude, Latitude] 38 | id: Optional[int] = None # image_id 39 | metadata: Optional[dict] = None 40 | image_bytes: Optional[bytes] = None 41 | label_ids: Optional[List[str]] = None 42 | 43 | 44 | @dataclass 45 | class VisualSegmentationFeatures: 46 | """Image Segmentation datasets preprocessing output features.""" 47 | 48 | image_path: str 49 | labels: List[Union[str, int]] 50 | polygons: List[List[List[float]]] 51 | geo_info: Optional[List[float]] = None # [Longitude, Latitude] 52 | id: Optional[int] = None # image_id 53 | metadata: Optional[dict] = None 54 | image_bytes: Optional[bytes] = None 55 | label_ids: Optional[List[str]] = None 56 | 57 | 58 | @dataclass 59 | class MultiModalFeatures: 60 | """Multi-modal datasets preprocessing output features.""" 61 | 62 | text: str 63 | image_bytes: str 64 | labels: List[Union[str, int]] = None # List[str or int] to cater for multi-class tasks 65 | id: Optional[int] = None # image_id 66 | metadata: Optional[dict] = None 67 | -------------------------------------------------------------------------------- /clarifai/datasets/upload/loaders/README.md: -------------------------------------------------------------------------------- 1 | ## Dataset Loaders 2 | 3 | A collection of data preprocessing modules for popular public datasets to allow for compatible upload into Clarifai user app datasets. 
4 | 5 | ## Usage 6 | 7 | If a dataset module exists in the zoo, the dataset can be uploaded by creating a Python script (or via the command line) and passing the dataloader object to the `dataloader` parameter of the `Dataset` class's `upload_dataset` method, i.e.: 8 | 9 | ```python 10 | from clarifai.client.app import App 11 | from clarifai.datasets.upload.loaders.coco_detection import COCODetectionDataLoader 12 | 13 | app = App(app_id="", user_id="") 14 | # Create a dataset in Clarifai App 15 | dataset = app.create_dataset(dataset_id="") 16 | # instantiate dataloader object 17 | coco_det_dataloader = COCODetectionDataLoader(images_dir="", label_filepath="") 18 | # execute data upload to Clarifai app dataset 19 | dataset.upload_dataset(dataloader=coco_det_dataloader) 20 | ``` 21 | 22 | ## Dataset Loaders 23 | 24 | | dataset name | task | module name (.py) | 25 | | --- | --- | --- | 26 | | [COCO 2017](https://cocodataset.org/#download) | Detection | `coco_detection` | 27 | | | Captions | `coco_captions` | 28 | | [xVIEW](http://xviewdataset.org/) | Detection | `xview_detection` | 29 | | [ImageNet](https://www.image-net.org/) | Classification | `imagenet_classification` | 30 | ## Contributing To Loaders 31 | 32 | A dataloader (preprocessing) module is a Python script containing a dataloader class that implements the dataloader methods. 33 | 34 | The class name should end in `DataLoader` (for example, `COCODetectionDataLoader`). The dataset class must inherit from `ClarifaiDataLoader`, and the `__getitem__` method must return one of `VisualClassificationFeatures()`, `VisualDetectionFeatures()`, `VisualSegmentationFeatures()` or `TextFeatures()` as defined in [clarifai/datasets/upload/features.py](../features.py). Other methods can be added as seen fit, but the required ones must be inherited from the parent `ClarifaiDataLoader` base class [clarifai/datasets/upload/base.py](../base.py). 35 | The existing dataset modules in the zoo can serve as references during development (see the minimal sketch at the end of this README). 36 | 37 | ## Notes 38 | 39 | * COCO Format: To reuse the COCO modules above on your own COCO-format data, first ensure the criteria described above are met. If so, instantiate any of the COCO loaders above with your `images_dir` and `label_filepath` and pass it to the `dataloader=` parameter in `upload_dataset()`. 40 | 41 | * xVIEW Dataset: To upload, you have to register and download the images and labels from [xviewdataset](http://xviewdataset.org/#dataset), then follow the above-mentioned steps to place the extracted folder in the `data` directory. Finally, pass the xVIEW `data_dir` to the loader and hand it to the `dataloader=` parameter in `upload_dataset()`. 42 | 43 | / 44 | ├── train_images/ 45 | ├── xview_train.geojson 46 | 47 | * ImageNet Dataset: The ImageNet dataset should be downloaded and placed in the `data` folder along with the [label mapping file](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data?select=LOC_synset_mapping.txt). 48 | 49 | / 50 | ├── train/ 51 | ├── LOC_synset_mapping.txt
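
As a companion to the contributing guidelines above, here is a minimal sketch of a custom loader. `MyFolderDataLoader` and its folder layout are hypothetical, assuming one label per image taken from the image's parent folder name:

```python
import os

from clarifai.datasets.upload.base import ClarifaiDataLoader
from clarifai.datasets.upload.features import VisualClassificationFeatures


class MyFolderDataLoader(ClarifaiDataLoader):
    """Illustrative loader: data_dir contains one subfolder per label, each holding images."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.image_paths = []
        self.labels = []
        self.load_data()

    @property
    def task(self):
        return "visual_classification"  # must be one of DATASET_UPLOAD_TASKS

    def load_data(self):
        # Collect (image_path, label) pairs from the folder layout.
        for label in os.listdir(self.data_dir):
            label_dir = os.path.join(self.data_dir, label)
            if not os.path.isdir(label_dir):
                continue
            for fname in os.listdir(label_dir):
                self.image_paths.append(os.path.join(label_dir, fname))
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        return VisualClassificationFeatures(
            image_path=self.image_paths[index],
            labels=[self.labels[index]],
            id=str(index),
        )
```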
COCO image captioning dataset 2 | 3 | import os 4 | 5 | from clarifai.datasets.upload.base import ClarifaiDataLoader 6 | 7 | from ..features import VisualClassificationFeatures 8 | 9 | # pycocotools is a dependency for this loader 10 | try: 11 | from pycocotools.coco import COCO 12 | except ImportError: 13 | raise ImportError( 14 | "Could not import pycocotools package. " 15 | "Please do `pip install 'clarifai[all]'` to import pycocotools." 16 | ) 17 | 18 | 19 | class COCOCaptionsDataLoader(ClarifaiDataLoader): 20 | """COCO Image Captioning Dataset.""" 21 | 22 | def __init__(self, images_dir, label_filepath): 23 | """ 24 | Args: 25 | images_dir: Directory containing the images. 26 | label_filepath: Path to the COCO annotation file. 27 | """ 28 | self.images_dir = images_dir 29 | self.label_filepath = label_filepath 30 | 31 | self.map_ids = {} 32 | self.load_data() 33 | 34 | @property 35 | def task(self): 36 | return "visual_captioning" 37 | 38 | def load_data(self) -> None: 39 | self.coco = COCO(self.label_filepath) 40 | self.map_ids = {i: img_id for i, img_id in enumerate(list(self.coco.imgs.keys()))} 41 | 42 | def __len__(self): 43 | return len(self.coco.imgs) 44 | 45 | def __getitem__(self, index): 46 | value = self.coco.imgs[self.map_ids[index]] 47 | image_path = os.path.join(self.images_dir, value['file_name']) 48 | annots = [] 49 | 50 | input_ann_ids = self.coco.getAnnIds(imgIds=[value['id']]) 51 | input_anns = self.coco.loadAnns(input_ann_ids) 52 | 53 | for ann in input_anns: 54 | annots.append(ann['caption']) 55 | 56 | return VisualClassificationFeatures(image_path, labels=annots[0], id=str(value['id'])) 57 | -------------------------------------------------------------------------------- /clarifai/datasets/upload/loaders/coco_detection.py: -------------------------------------------------------------------------------- 1 | #! COCO detection dataset 2 | 3 | import os 4 | 5 | from ..base import ClarifaiDataLoader 6 | from ..features import VisualDetectionFeatures 7 | 8 | # pycocotools is a dependency for this loader 9 | try: 10 | from pycocotools.coco import COCO 11 | except ImportError: 12 | raise ImportError( 13 | "Could not import pycocotools package. " 14 | "Please do `pip install 'clarifai[all]'` to import pycocotools." 15 | ) 16 | 17 | 18 | class COCODetectionDataLoader(ClarifaiDataLoader): 19 | def __init__(self, images_dir, label_filepath): 20 | """ 21 | Args: 22 | images_dir: Directory containing the images. 23 | label_filepath: Path to the COCO annotation file. 
 24 |         """
 25 |         self.images_dir = images_dir
 26 |         self.label_filepath = label_filepath
 27 | 
 28 |         self.map_ids = {}
 29 |         self.load_data()
 30 | 
 31 |     @property
 32 |     def task(self):
 33 |         return "visual_detection"
 34 | 
 35 |     def load_data(self) -> None:
 36 |         self.coco = COCO(self.label_filepath)
 37 |         self.map_ids = {i: img_id for i, img_id in enumerate(list(self.coco.imgs.keys()))}
 38 | 
 39 |     def __getitem__(self, index: int):
 40 |         value = self.coco.imgs[self.map_ids[index]]
 41 |         image_path = os.path.join(self.images_dir, value['file_name'])
 42 |         annots = []  # bboxes
 43 |         concept_ids = []
 44 | 
 45 |         input_ann_ids = self.coco.getAnnIds(imgIds=[value['id']])
 46 |         input_anns = self.coco.loadAnns(input_ann_ids)
 47 | 
 48 |         for ann in input_anns:
 49 |             # get concept info
 50 |             # note1: concept_name can be human readable
 51 |             # note2: concept_id can only be alphanumeric, up to 32 characters, with no special chars except `-` and `_`
 52 |             concept_name = self.coco.cats[ann['category_id']]['name']
 53 |             concept_id = concept_name.lower().replace(' ', '-')
 54 | 
 55 |             # get bbox information
 56 |             # note1: coco bboxes are `[x_min, y_min, width, height]` in pixels
 57 |             # note2: clarifai bboxes are `[x_min, y_min, x_max, y_max]` normalized between 0-1.0
 58 |             coco_bbox = ann['bbox']
 59 |             clarifai_bbox = {
 60 |                 'left_col': max(0, coco_bbox[0] / value['width']),
 61 |                 'top_row': max(0, coco_bbox[1] / value['height']),
 62 |                 'right_col': min(1, (coco_bbox[0] + coco_bbox[2]) / value['width']),
 63 |                 'bottom_row': min(1, (coco_bbox[1] + coco_bbox[3]) / value['height']),
 64 |             }
 65 |             if (clarifai_bbox['left_col'] >= clarifai_bbox['right_col']) or (
 66 |                 clarifai_bbox['top_row'] >= clarifai_bbox['bottom_row']
 67 |             ):
 68 |                 continue
 69 |             annots.append(
 70 |                 [
 71 |                     clarifai_bbox['left_col'],
 72 |                     clarifai_bbox['top_row'],
 73 |                     clarifai_bbox['right_col'],
 74 |                     clarifai_bbox['bottom_row'],
 75 |                 ]
 76 |             )
 77 |             concept_ids.append(concept_id)
 78 | 
 79 |         assert len(concept_ids) == len(annots), (
 80 |             f"Num concepts must match num bbox annotations "
 81 |             f"for a single image. Found {len(concept_ids)} concepts and {len(annots)} bboxes."
 82 |         )
 83 | 
 84 |         return VisualDetectionFeatures(image_path, concept_ids, annots, id=str(value['id']))
 85 | 
 86 |     def __len__(self):
 87 |         return len(self.coco.imgs)
 88 | 
-------------------------------------------------------------------------------- /clarifai/datasets/upload/loaders/imagenet_classification.py: --------------------------------------------------------------------------------
  1 | #! ImageNet Classification dataset
  2 | 
  3 | import os
  4 | 
  5 | from clarifai.datasets.upload.base import ClarifaiDataLoader
  6 | 
  7 | from ..features import VisualClassificationFeatures
  8 | 
  9 | 
 10 | class ImageNetDataLoader(ClarifaiDataLoader):
 11 |     """ImageNet Dataset."""
 12 | 
 13 |     def __init__(self, data_dir, split: str = "train"):
 14 |         """
 15 |         Initialize dataset params.
 16 |         Args:
 17 |             data_dir: the local dataset directory.
 18 |             split: "train" or "test"
 19 |         """
 20 |         self.split = split
 21 |         self.data_dir = data_dir
 22 |         self.label_map = dict()
 23 |         self.concepts = []
 24 |         self.image_paths = []
 25 | 
 26 |         self.load_data()
 27 | 
 28 |     @property
 29 |     def task(self):
 30 |         return "visual_classification"
 31 | 
 32 |     def load_data(self):
 33 |         # Creating label map
 34 |         with open(os.path.join(self.data_dir, "LOC_synset_mapping.txt")) as _file:
 35 |             for _id in _file:
 36 |                 # Strip spaces and apostrophes from each synonym, convert to a set to remove duplicates, then back to a list for compatibility.
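                # e.g. the mapping line "n01855672 white stork, Ciconia ciconia" becomes
                # {"n01855672": ["whitestork", "Ciconiaciconia"]} (set ordering may vary).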
 37 |                 self.label_map[_id.split(" ")[0]] = list(
 38 |                     {
 39 |                         "".join(("".join((label.rstrip().lstrip().split(" ")))).split("'"))
 40 |                         for label in _id[_id.find(" ") + 1 :].split(",")
 41 |                     }
 42 |                 )
 43 | 
 44 |         for _folder in os.listdir(os.path.join(self.data_dir, self.split)):
 45 |             try:
 46 |                 concept = self.label_map[_folder]  # concepts
 47 |             except Exception:
 48 |                 continue
 49 |             folder_path = os.path.join(self.data_dir, self.split) + "/" + _folder
 50 |             for _img in os.listdir(folder_path):
 51 |                 if _img.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
 52 |                     self.concepts.append(concept)
 53 |                     self.image_paths.append(folder_path + "/" + _img)
 54 | 
 55 |         assert len(self.concepts) == len(self.image_paths), (
 56 |             "Number of concepts and images are not equal")
 57 | 
 58 |     def __len__(self):
 59 |         return len(self.image_paths)
 60 | 
 61 |     def __getitem__(self, idx):
 62 |         return VisualClassificationFeatures(
 63 |             image_path=self.image_paths[idx],
 64 |             labels=self.concepts[idx],
 65 |             id=self.image_paths[idx].split('.')[0].split('/')[-1],
 66 |         )
 67 | 
-------------------------------------------------------------------------------- /clarifai/datasets/upload/multimodal.py: --------------------------------------------------------------------------------
  1 | from concurrent.futures import ThreadPoolExecutor
  2 | from typing import List, Tuple, Type
  3 | 
  4 | from clarifai_grpc.grpc.api import resources_pb2
  5 | from google.protobuf.struct_pb2 import Struct
  6 | 
  7 | from clarifai.client.input import Inputs
  8 | from clarifai.datasets.upload.base import ClarifaiDataLoader, ClarifaiDataset
  9 | 
 10 | 
 11 | class MultiModalDataset(ClarifaiDataset):
 12 |     def __init__(
 13 |         self, data_generator: Type[ClarifaiDataLoader], dataset_id: str, max_workers: int = 4
 14 |     ) -> None:
 15 |         super().__init__(data_generator, dataset_id, max_workers)
 16 | 
 17 |     def _extract_protos(
 18 |         self,
 19 |         batch_input_ids: List[str],
 20 |     ) -> Tuple[List[resources_pb2.Input], List[resources_pb2.Annotation]]:
 21 |         """Creates multimodal (image and text) input protos for a batch of input ids.
 22 |         Args:
 23 |             batch_input_ids: List of input IDs to retrieve the protos for.
 24 |         Returns:
 25 |             input_protos: List of input protos.
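            annotation_protos: List of annotation protos (always empty here, since multimodal labels are attached directly to the input protos).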
 26 | 
 27 |         """
 28 |         input_protos, annotation_protos = [], []
 29 | 
 30 |         def process_data_item(id):
 31 |             data_item = self.data_generator[id]
 32 |             metadata = Struct()
 33 |             image_bytes = data_item.image_bytes
 34 |             text = data_item.text
 35 |             labels = (
 36 |                 data_item.labels
 37 |                 if ((data_item.labels is None) or isinstance(data_item.labels, list))
 38 |                 else [data_item.labels]
 39 |             )
 40 |             input_id = (
 41 |                 f"{self.dataset_id}-{id}"
 42 |                 if data_item.id is None
 43 |                 else f"{self.dataset_id}-{str(data_item.id)}"
 44 |             )
 45 |             if data_item.metadata is not None:
 46 |                 metadata.update(data_item.metadata)
 47 |             else:
 48 |                 metadata = None
 49 | 
 50 |             self.all_input_ids[id] = input_id
 51 |             if data_item.image_bytes is not None:
 52 |                 input_protos.append(
 53 |                     Inputs.get_input_from_bytes(
 54 |                         input_id=input_id,
 55 |                         image_bytes=image_bytes,
 56 |                         dataset_id=self.dataset_id,
 57 |                         labels=labels,
 58 |                         metadata=metadata,
 59 |                     )
 60 |                 )
 61 |             else:
 62 |                 input_protos.append(
 63 |                     Inputs.get_text_input(
 64 |                         input_id=input_id,
 65 |                         raw_text=text,
 66 |                         dataset_id=self.dataset_id,
 67 |                         labels=labels,
 68 |                         metadata=metadata,
 69 |                     )
 70 |                 )
 71 | 
 72 |         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
 73 |             futures = [executor.submit(process_data_item, id) for id in batch_input_ids]
 74 | 
 75 |             for job in futures:
 76 |                 job.result()
 77 | 
 78 |         return input_protos, annotation_protos
 79 | 
-------------------------------------------------------------------------------- /clarifai/datasets/upload/text.py: --------------------------------------------------------------------------------
  1 | from concurrent.futures import ThreadPoolExecutor
  2 | from typing import List, Tuple, Type
  3 | 
  4 | from clarifai_grpc.grpc.api import resources_pb2
  5 | from google.protobuf.struct_pb2 import Struct
  6 | 
  7 | from clarifai.client.input import Inputs
  8 | from clarifai.utils.misc import get_uuid
  9 | 
 10 | from .base import ClarifaiDataLoader, ClarifaiDataset
 11 | 
 12 | 
 13 | class TextClassificationDataset(ClarifaiDataset):
 14 |     """Upload text classification datasets to clarifai datasets"""
 15 | 
 16 |     def __init__(
 17 |         self, data_generator: Type[ClarifaiDataLoader], dataset_id: str, max_workers: int = 4
 18 |     ) -> None:
 19 |         super().__init__(data_generator, dataset_id, max_workers)
 20 | 
 21 |     def _extract_protos(
 22 |         self, batch_input_ids: List[int]
 23 |     ) -> Tuple[List[resources_pb2.Input], List[resources_pb2.Annotation]]:
 24 |         """Create input text and annotation protos for a batch of input ids.
 25 |         Args:
 26 |             batch_input_ids: List of input IDs to retrieve the protos for.
 27 |         Returns:
 28 |             input_protos: List of input protos.
 29 |             annotation_protos: List of annotation protos.
30 | """ 31 | input_protos, annotation_protos = [], [] 32 | 33 | def process_data_item(id): 34 | data_item = self.data_generator[id] 35 | metadata = Struct() 36 | text = data_item.text 37 | labels = ( 38 | data_item.labels 39 | if ((data_item.labels is None) or isinstance(data_item.labels, list)) 40 | else [data_item.labels] 41 | ) # clarifai concept 42 | label_ids = data_item.label_ids 43 | input_id = ( 44 | f"{self.dataset_id}-{get_uuid(8)}" 45 | if data_item.id is None 46 | else f"{self.dataset_id}-{str(data_item.id)}" 47 | ) 48 | if data_item.metadata is not None: 49 | metadata.update(data_item.metadata) 50 | 51 | self.all_input_ids[id] = input_id 52 | input_protos.append( 53 | Inputs.get_text_input( 54 | input_id=input_id, 55 | raw_text=text, 56 | dataset_id=self.dataset_id, 57 | labels=labels, 58 | label_ids=label_ids, 59 | metadata=metadata, 60 | ) 61 | ) 62 | 63 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 64 | futures = [executor.submit(process_data_item, id) for id in batch_input_ids] 65 | for job in futures: 66 | job.result() 67 | 68 | return input_protos, annotation_protos 69 | -------------------------------------------------------------------------------- /clarifai/errors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import time 4 | 5 | import requests # noqa 6 | from google.protobuf.json_format import MessageToDict 7 | 8 | from clarifai.versions import CLIENT_VERSION, OS_VER, PYTHON_VERSION 9 | 10 | 11 | class TokenError(Exception): 12 | pass 13 | 14 | 15 | class ApiError(Exception): 16 | """API Server error""" 17 | 18 | def __init__( 19 | self, resource: str, params: dict, method: str, response: requests.Response = None 20 | ) -> None: 21 | self.resource = resource 22 | self.params = params 23 | self.method = method 24 | self.response = response 25 | 26 | self.error_code = 'N/A' 27 | self.error_desc = 'N/A' 28 | self.error_details = 'N/A' 29 | response_json = 'N/A' 30 | 31 | if response is not None: 32 | response_json_dict = MessageToDict(response) 33 | 34 | self.error_code = response_json_dict.get('status', {}).get('code', None) 35 | self.error_desc = response_json_dict.get('status', {}).get('description', None) 36 | self.error_details = response_json_dict.get('status', {}).get('details', None) 37 | response_json = json.dumps(response_json_dict['status'], indent=2) 38 | 39 | current_ts_str = str(time.time()) 40 | 41 | msg = """%(method)s %(resource)s FAILED(%(time_ts)s). 
error_code: %(error_code)s, error_description: %(error_desc)s, error_details: %(error_details)s
 42 | >> Python client %(client_version)s with Python %(python_version)s on %(os_version)s
 43 | >> %(method)s %(resource)s
 44 | >> REQUEST(%(time_ts)s) %(request)s
 45 | >> RESPONSE(%(time_ts)s) %(response)s""" % {
 46 |             'method': method,
 47 |             'resource': resource,
 48 |             'error_code': self.error_code,
 49 |             'error_desc': self.error_desc,
 50 |             'error_details': self.error_details,
 51 |             'request': json.dumps(params, indent=2),
 52 |             'response': response_json,
 53 |             'time_ts': current_ts_str,
 54 |             'client_version': CLIENT_VERSION,
 55 |             'python_version': PYTHON_VERSION,
 56 |             'os_version': OS_VER,
 57 |         }
 58 | 
 59 |         super(ApiError, self).__init__(msg)
 60 | 
 61 | 
 62 | class ApiClientError(Exception):
 63 |     """API Client Error"""
 64 | 
 65 | 
 66 | class UserError(Exception):
 67 |     """User Error"""
 68 | 
 69 | 
 70 | class AuthError(Exception):
 71 |     """Raised when a client has missing or invalid authentication."""
 72 | 
 73 | 
 74 | def _base_url(url: str) -> str:
 75 |     """
 76 |     Extracts the base URL from the url, which is everything before the 4th slash character.
 77 |     https://www.clarifai.com/v2/models/1/output -> https://www.clarifai.com/v2/
 78 |     """
 79 |     try:
 80 |         return url[: _find_nth(url, '/', 4) + 1]
 81 |     except Exception:
 82 |         return ''
 83 | 
 84 | 
 85 | def _find_nth(haystack: str, needle: str, n: int) -> int:
 86 |     start = haystack.find(needle)
 87 |     while start >= 0 and n > 1:
 88 |         start = haystack.find(needle, start + len(needle))
 89 |         n -= 1
 90 |     return start
 91 | 
-------------------------------------------------------------------------------- /clarifai/models/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/models/__init__.py
-------------------------------------------------------------------------------- /clarifai/modules/README.md: --------------------------------------------------------------------------------
  1 | # Module Utils
  2 | 
  3 | Additional helper functions for creating Clarifai Modules should be placed here so that they can be reused across modules.
  4 | 
  5 | This should still not import streamlit, as we want to keep clarifai-python-utils lightweight. If you find we need utilities for streamlit itself, we should start a new repo for that. Please contact support@clarifai.com to do so.
  6 | 
-------------------------------------------------------------------------------- /clarifai/modules/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/modules/__init__.py
-------------------------------------------------------------------------------- /clarifai/modules/css.py: --------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | 
  4 | class ClarifaiStreamlitCSS(object):
  5 |     """ClarifaiStreamlitCSS helps get a consistent style by default for Clarifai provided
  6 |     streamlit apps.
  7 |     """
  8 | 
  9 |     @classmethod
 10 |     def insert_default_css(cls, st):
 11 |         """Inserts the default style provided in style.css in this folder into the streamlit page
 12 | 
 13 |         Example:
 14 |             ClarifaiStreamlitCSS.insert_default_css(st)
 15 | 
 16 |         Note:
 17 |             This must be placed in both the app.py AND all the pages/*.py files to get the custom styles.
 18 |         """
 19 |         file_name = os.path.join(os.path.dirname(__file__), "style.css")
 20 |         cls.insert_css_file(file_name, st)
 21 | 
 22 |     @classmethod
 23 |     def insert_css_file(cls, css_file, st):
 24 |         """Open the full filename to the css file and insert its contents into the style of the page."""
 25 |         with open(css_file) as f:
 26 |             st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
 27 | 
 28 |     @classmethod
 29 |     def buttonlink(cls, st, label, link, target="_parent", style=""):
 30 |         """
 31 |         This is a streamlit button that links to another page (or to the same page when target is _self).
 32 |         It is styled to look like the other stButton>button buttons that are created with st.button().
 33 | 
 34 |         You must insert_default_css(st) before using on a page.
 35 | 
 36 |         Example:
 37 |             ClarifaiStreamlitCSS.insert_default_css(st)
 38 |             cols = st.columns(4)
 39 |             ClarifaiStreamlitCSS.buttonlink(cols[3], "Button", "https://clarifai.com", "_blank")
 40 | 
 41 |         Args:
 42 |             st: the streamlit package.
 43 |             label: the text string to display in the button.
 44 |             link: the url to link the button to.
 45 |             target: to open the link in the same page (_parent) or a new tab (_blank).
 46 |             style: additional style to apply to the button link.
 47 |                 Ex: "background-color: rgb(45, 164, 78); color:white;" gives the button a green background with white text.
 48 |         """
 49 |         astyle = ""
 50 |         if style:
 51 |             astyle = f'style="{style}"'
 52 | 
 53 |         st.markdown(
 54 |             f'''
 55 | <a href="{link}" target="{target}" {astyle}>
 56 |     {label}
 57 | </a>
 58 |             ''',
 59 |             unsafe_allow_html=True,
 60 |         )
 61 | 
-------------------------------------------------------------------------------- /clarifai/modules/pages.py: --------------------------------------------------------------------------------
  1 | import glob
  2 | import importlib.util
  3 | 
  4 | 
  5 | class ClarifaiModulePageManager(object):
  6 |     def __init__(self):
  7 |         # List all the available pages.
  8 |         page_files = sorted(glob.glob("pages/*.py"))
  9 |         self.page_names = [f.replace("pages/", "").replace(".py", "") for f in page_files]
 10 | 
 11 |     def get_page_from_query_params(self, qp):
 12 |         """
 13 |         Args:
 14 |             qp: the streamlit query params st.experimental_get_query_params()
 15 |         """
 16 |         # Get the page from the query params, or default to the first page.
 17 |         page = qp.get("page", [None])[0]
 18 |         if page is None:
 19 |             page = self.page_names[0]
 20 |         # Check that the page coming in is one of the pages in the folder.
 21 |         if page not in self.page_names:
 22 |             raise Exception(
 23 |                 "Page '%s' is not valid, there is no pages/%s.py file for this page. Valid page names are: %s"
 24 |                 % (page, page, str(self.page_names))
 25 |             )
 26 | 
 27 |         return page
 28 | 
 29 |     def get_page_names(self):
 30 |         return self.page_names
 31 | 
 32 |     def render_page(self, page):
 33 |         # Since the page re-renders every time the selectbox changes, we'll always have the latest page out
 34 |         # of the query params.
 35 |         module_str = "pages.%s" % page
 36 |         # Check that the page module actually exists.
 37 |         spec = importlib.util.find_spec(module_str)
 38 |         if spec is None:
 39 |             raise Exception("Page %s was not found." % page)
 40 | 
 41 |         current_page = importlib.import_module(module_str)
 42 |         current_page.display()
 43 | 
-------------------------------------------------------------------------------- /clarifai/rag/__init__.py: --------------------------------------------------------------------------------
  1 | from .rag import RAG
  2 | 
  3 | __all__ = ["RAG"]
  4 | 
-------------------------------------------------------------------------------- /clarifai/rag/utils.py: --------------------------------------------------------------------------------
  1 | import io
  2 | from pathlib import Path
  3 | from typing import Any, List
  4 | 
  5 | import requests
  6 | 
  7 | 
  8 | ## TODO: Make this token-aware.
  9 | def convert_messages_to_str(messages: List[dict]) -> str:
 10 |     """Convert messages in OpenAI API format into a single string.
 11 | 
 12 |     Args:
 13 |         messages (List[dict]): A list of dictionaries in the following format:
 14 |         ```
 15 |         [
 16 |             {"role": "user", "content": "Hello there."},
 17 |             {"role": "assistant", "content": "Hi, I'm Claude. How can I help you?"},
 18 |             {"role": "user", "content": "Can you explain LLMs in plain English?"},
 19 |         ]
 20 |         ```
 21 |     """
 22 |     final_str = ""
 23 |     for msg in messages:
 24 |         if "role" in msg and "content" in msg:
 25 |             role = msg.get("role", "")
 26 |             content = msg.get("content", "")
 27 |             final_str += f"\n\n{role}: {content}"
 28 |     return final_str
 29 | 
 30 | 
 31 | def format_assistant_message(raw_text: str) -> dict:
 32 |     return {"role": "assistant", "content": raw_text}
 33 | 
 34 | 
 35 | def load_documents(file_path: str = None, folder_path: str = None, url: str = None) -> List[Any]:
 36 |     """Loads documents from a local file, a local directory, or a public url.
 37 | 
 38 |     Args:
 39 |         file_path (str): The path to the filename.
 40 |         folder_path (str): The path to the folder.
 41 |         url (str): The url to the file.
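    Returns:
        A list of llama-index Document objects parsed from whichever source was provided.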
42 | """ 43 | # check import packages 44 | try: 45 | from llama_index.core import Document, SimpleDirectoryReader 46 | from llama_index.core.readers.download import download_loader 47 | except ImportError: 48 | raise ImportError( 49 | "Could not import llama index package. " 50 | "Please install it with `pip install llama-index-core==0.10.1`." 51 | ) 52 | # document loaders for filepath 53 | if file_path: 54 | if file_path.endswith(".pdf"): 55 | PDFReader = download_loader("PDFReader") 56 | loader = PDFReader() 57 | documents = loader.load_data(file=Path(file_path)) 58 | elif file_path.endswith(".docx"): 59 | docReader = download_loader("DocxReader") 60 | loader = docReader() 61 | documents = loader.load_data(file=Path(file_path)) 62 | elif file_path.endswith(".txt"): 63 | with open(file_path, 'r') as file: 64 | text_content = file.read() 65 | documents = [Document(text=text_content)] 66 | else: 67 | raise ValueError("Only .pdf, .docx, and .txt files are supported.") 68 | 69 | # document loaders for folderpath 70 | if folder_path: 71 | documents = SimpleDirectoryReader( 72 | input_dir=Path(folder_path), required_exts=[".pdf", ".docx", ".txt"] 73 | ).load_data() 74 | 75 | # document loaders for url 76 | if url: 77 | response = requests.get(url) 78 | if response.status_code != 200: 79 | raise ValueError(f"Invalid url {url}.") 80 | # for text files 81 | try: 82 | documents = [Document(text=response.content)] 83 | # for pdf files 84 | except Exception: 85 | # check import packages 86 | try: 87 | from pypdf import PdfReader 88 | except ImportError: 89 | raise ImportError( 90 | "Could not import pypdf package. " 91 | "Please install it with `pip install pypdf==3.17.4`." 92 | ) 93 | documents = [] 94 | pdf_file = PdfReader(io.BytesIO(response.content)) 95 | num_pages = len(pdf_file.pages) 96 | for page in range(num_pages): 97 | page_text = pdf_file.pages[page].extract_text() 98 | documents.append(Document(text=page_text)) 99 | else: 100 | raise ValueError(f"Invalid url {url}.") 101 | 102 | return documents 103 | 104 | 105 | def split_document(text: str, chunk_size: int, chunk_overlap: int, **kwargs) -> List[str]: 106 | """Splits a document into chunks of text. 107 | 108 | Args: 109 | text (str): The text to split. 110 | chunk_size (int): The size of each chunk. 111 | chunk_overlap (int): The amount of overlap between each chunk. 112 | **kwargs: Additional keyword arguments for the SentenceSplitter. 113 | """ 114 | # check import packages 115 | try: 116 | from llama_index.core.node_parser.text import SentenceSplitter 117 | except ImportError: 118 | raise ImportError( 119 | "Could not import llama index package. " 120 | "Please install it with `pip install llama-index-core==0.10.24`." 
121 | ) 122 | # document 123 | text_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs) 124 | text_chunks = text_parser.split_text(text) 125 | return text_chunks 126 | -------------------------------------------------------------------------------- /clarifai/runners/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.model_builder import ModelBuilder 2 | from .models.model_class import ModelClass 3 | from .models.model_runner import ModelRunner 4 | from .models.openai_class import OpenAIModelClass 5 | 6 | __all__ = [ 7 | "ModelRunner", 8 | "ModelBuilder", 9 | "ModelClass", 10 | "OpenAIModelClass", 11 | ] 12 | -------------------------------------------------------------------------------- /clarifai/runners/dockerfile_template/Dockerfile.template: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1.13-labs 2 | FROM --platform=$TARGETPLATFORM ${FINAL_IMAGE} as final 3 | 4 | COPY --link requirements.txt /home/nonroot/requirements.txt 5 | 6 | # Update clarifai package so we always have latest protocol to the API. Everything should land in /venv 7 | RUN ["pip", "install", "--no-cache-dir", "-r", "/home/nonroot/requirements.txt"] 8 | RUN ["pip", "show", "clarifai"] 9 | 10 | # Set the NUMBA cache dir to /tmp 11 | # Set the TORCHINDUCTOR cache dir to /tmp 12 | # The CLARIFAI* will be set by the templaing system. 13 | ENV NUMBA_CACHE_DIR=/tmp/numba_cache \ 14 | TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_cache \ 15 | HOME=/tmp \ 16 | DEBIAN_FRONTEND=noninteractive 17 | 18 | ##### 19 | # Copy the files needed to download 20 | ##### 21 | # This creates the directory that HF downloader will populate and with nonroot:nonroot permissions up. 22 | COPY --chown=nonroot:nonroot downloader/unused.yaml /home/nonroot/main/1/checkpoints/.cache/unused.yaml 23 | 24 | ##### 25 | # Download checkpoints if config.yaml has checkpoints.when = "build" 26 | COPY --link=true config.yaml /home/nonroot/main/ 27 | RUN ["python", "-m", "clarifai.cli", "model", "download-checkpoints", "/home/nonroot/main", "--out_path", "/home/nonroot/main/1/checkpoints", "--stage", "build"] 28 | ##### 29 | 30 | # Copy in the actual files like config.yaml, requirements.txt, and most importantly 1/model.py 31 | # for the actual model. 32 | # If checkpoints aren't downloaded since a checkpoints: block is not provided, then they will 33 | # be in the build context and copied here as well. 34 | COPY --link=true 1 /home/nonroot/main/1 35 | # At this point we only need these for validation in the SDK. 36 | COPY --link=true requirements.txt config.yaml /home/nonroot/main/ 37 | 38 | # Add the model directory to the python path. 39 | ENV PYTHONPATH=${PYTHONPATH}:/home/nonroot/main \ 40 | CLARIFAI_PAT=${CLARIFAI_PAT} \ 41 | CLARIFAI_USER_ID=${CLARIFAI_USER_ID} \ 42 | CLARIFAI_RUNNER_ID=${CLARIFAI_RUNNER_ID} \ 43 | CLARIFAI_NODEPOOL_ID=${CLARIFAI_NODEPOOL_ID} \ 44 | CLARIFAI_COMPUTE_CLUSTER_ID=${CLARIFAI_COMPUTE_CLUSTER_ID} \ 45 | CLARIFAI_API_BASE=${CLARIFAI_API_BASE:-https://api.clarifai.com} 46 | 47 | # Finally run the clarifai entrypoint to start the runner loop and local dev server. 48 | # Note(zeiler): we may want to make this a clarifai CLI call. 
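# CMD holds only the default --model_path argument, so it can be overridden at `docker run` time
# while the ENTRYPOINT itself stays fixed.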
 49 | ENTRYPOINT ["python", "-m", "clarifai.runners.server"]
 50 | CMD ["--model_path", "/home/nonroot/main"]
 51 | #############################
 52 | 
-------------------------------------------------------------------------------- /clarifai/runners/models/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/runners/models/__init__.py
-------------------------------------------------------------------------------- /clarifai/runners/models/model_servicer.py: --------------------------------------------------------------------------------
  1 | import os
  2 | from itertools import tee
  3 | from typing import Iterator
  4 | 
  5 | from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
  6 | from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
  7 | 
  8 | from ..utils.url_fetcher import ensure_urls_downloaded
  9 | 
 10 | _RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() in ("true", "1")
 11 | 
 12 | 
 13 | class ModelServicer(service_pb2_grpc.V2Servicer):
 14 |     """
 15 |     This is the servicer that will handle the gRPC requests from either the dev server or runner loop.
 16 |     """
 17 | 
 18 |     def __init__(self, model):
 19 |         """
 20 |         Args:
 21 |             model: The class that will handle the model logic. Must implement predict(),
 22 |                 generate(), stream().
 23 |         """
 24 |         self.model = model
 25 | 
 26 |     def PostModelOutputs(
 27 |         self, request: service_pb2.PostModelOutputsRequest, context=None
 28 |     ) -> service_pb2.MultiOutputResponse:
 29 |         """
 30 |         Handle a unary prediction request: ensure any url inputs are downloaded, then
 31 |         delegate to the model's predict_wrapper and return its response.
 32 |         """
 33 | 
 34 |         # Download any urls that are not already bytes.
 35 |         ensure_urls_downloaded(request)
 36 | 
 37 |         try:
 38 |             return self.model.predict_wrapper(request)
 39 |         except Exception as e:
 40 |             if _RAISE_EXCEPTIONS:
 41 |                 raise
 42 |             return service_pb2.MultiOutputResponse(
 43 |                 status=status_pb2.Status(
 44 |                     code=status_code_pb2.MODEL_PREDICTION_FAILED,
 45 |                     description="Failed",
 46 |                     details="",
 47 |                     internal_details=str(e),
 48 |                 )
 49 |             )
 50 | 
 51 |     def GenerateModelOutputs(
 52 |         self, request: service_pb2.PostModelOutputsRequest, context=None
 53 |     ) -> Iterator[service_pb2.MultiOutputResponse]:
 54 |         """
 55 |         Handle a server-streaming request: ensure any url inputs are downloaded, then
 56 |         yield the responses produced by the model's generate_wrapper.
 57 |         """
 58 |         # Download any urls that are not already bytes.
 59 |         ensure_urls_downloaded(request)
 60 | 
 61 |         try:
 62 |             yield from self.model.generate_wrapper(request)
 63 |         except Exception as e:
 64 |             if _RAISE_EXCEPTIONS:
 65 |                 raise
 66 |             yield service_pb2.MultiOutputResponse(
 67 |                 status=status_pb2.Status(
 68 |                     code=status_code_pb2.MODEL_PREDICTION_FAILED,
 69 |                     description="Failed",
 70 |                     details="",
 71 |                     internal_details=str(e),
 72 |                 )
 73 |             )
 74 | 
 75 |     def StreamModelOutputs(
 76 |         self, request: Iterator[service_pb2.PostModelOutputsRequest], context=None
 77 |     ) -> Iterator[service_pb2.MultiOutputResponse]:
 78 |         """
 79 |         Handle a bidirectional streaming request: download url inputs for each incoming
 80 |         request, then yield the responses produced by the model's stream_wrapper.
 81 |         """
 82 |         # Duplicate the iterator
 83 |         request, request_copy = tee(request)
 84 | 
 85 |         # Download any urls that are not already bytes.
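        # Consuming this tee'd iterator mutates the shared request protos in place, so the copy
        # replayed to stream_wrapper below sees inputs whose urls are already downloaded.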
86 | for req in request: 87 | ensure_urls_downloaded(req) 88 | 89 | try: 90 | yield from self.model.stream_wrapper(request_copy) 91 | except Exception as e: 92 | if _RAISE_EXCEPTIONS: 93 | raise 94 | yield service_pb2.MultiOutputResponse( 95 | status=status_pb2.Status( 96 | code=status_code_pb2.MODEL_PREDICTION_FAILED, 97 | description="Failed", 98 | details="", 99 | internal_details=str(e), 100 | ) 101 | ) 102 | -------------------------------------------------------------------------------- /clarifai/runners/models/openai_class.py: -------------------------------------------------------------------------------- 1 | """Base class for creating OpenAI-compatible API server.""" 2 | 3 | import json 4 | from typing import Any, Dict, Iterator 5 | 6 | from clarifai.runners.models.model_class import ModelClass 7 | 8 | 9 | class OpenAIModelClass(ModelClass): 10 | """Base class for wrapping OpenAI-compatible servers as a model running in Clarifai. 11 | This handles all the transport between the API and the OpenAI-compatible server. 12 | 13 | To use this class, create a subclass and set the following class attributes: 14 | - client: The OpenAI-compatible client instance 15 | - model: The name of the model to use with the client 16 | 17 | Example: 18 | class MyOpenAIModel(OpenAIModelClass): 19 | client = OpenAI(api_key="your-key") 20 | model = "gpt-4" 21 | """ 22 | 23 | # These should be overridden in subclasses 24 | client = None 25 | model = None 26 | 27 | def __init__(self) -> None: 28 | if self.client is None: 29 | raise NotImplementedError("Subclasses must set the 'client' class attribute") 30 | if self.model is None: 31 | try: 32 | self.model = self.client.models.list().data[0].id 33 | except Exception as e: 34 | raise NotImplementedError( 35 | "Subclasses must set the 'model' class attribute or ensure the client can list models" 36 | ) from e 37 | 38 | def _create_completion_args(self, params: Dict[str, Any]) -> Dict[str, Any]: 39 | """Create the completion arguments dictionary from parameters. 40 | 41 | Args: 42 | params: Dictionary of parameters extracted from request 43 | 44 | Returns: 45 | Dict containing the completion arguments 46 | """ 47 | completion_args = {**params} 48 | completion_args.update({"model": self.model}) 49 | stream = completion_args.pop("stream", False) 50 | if stream: 51 | # Force to use usage 52 | stream_options = params.pop("stream_options", {}) 53 | stream_options.update({"include_usage": True}) 54 | completion_args["stream_options"] = stream_options 55 | completion_args["stream"] = stream 56 | 57 | return completion_args 58 | 59 | def _set_usage(self, resp): 60 | if resp.usage and resp.usage.prompt_tokens and resp.usage.completion_tokens: 61 | self.set_output_context( 62 | prompt_tokens=resp.usage.prompt_tokens, 63 | completion_tokens=resp.usage.completion_tokens, 64 | ) 65 | 66 | @ModelClass.method 67 | def openai_transport(self, msg: str) -> str: 68 | """The single model method to get the OpenAI-compatible request and send it to the OpenAI server 69 | then return its response. 
70 | 71 | Args: 72 | msg: JSON string containing the request parameters 73 | 74 | Returns: 75 | JSON string containing the response or error 76 | """ 77 | try: 78 | request_data = json.loads(msg) 79 | completion_args = self._create_completion_args(request_data) 80 | completion = self.client.chat.completions.create(**completion_args) 81 | self._set_usage(completion) 82 | return json.dumps(completion.model_dump()) 83 | 84 | except Exception as e: 85 | return f"Error: {e}" 86 | 87 | @ModelClass.method 88 | def openai_stream_transport(self, msg: str) -> Iterator[str]: 89 | """Process an OpenAI-compatible request and return a streaming response iterator. 90 | This method is used when stream=True and returns an iterator of strings directly, 91 | without converting to a list or JSON serializing. 92 | 93 | Args: 94 | msg: The request as a JSON string. 95 | 96 | Returns: 97 | Iterator[str]: An iterator yielding text chunks from the streaming response. 98 | """ 99 | try: 100 | request_data = json.loads(msg) 101 | completion_args = self._create_completion_args(request_data) 102 | completion_stream = self.client.chat.completions.create(**completion_args) 103 | for chunk in completion_stream: 104 | self._set_usage(chunk) 105 | yield json.dumps(chunk.model_dump()) 106 | 107 | except Exception as e: 108 | yield f"Error: {e}" 109 | -------------------------------------------------------------------------------- /clarifai/runners/models/visual_classifier_class.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from io import BytesIO 4 | from typing import Dict, Iterator, List 5 | 6 | import cv2 7 | import torch 8 | from PIL import Image as PILImage 9 | 10 | from clarifai.runners.models.model_class import ModelClass 11 | from clarifai.runners.utils.data_types import Concept, Frame, Image 12 | from clarifai.utils.logging import logger 13 | 14 | 15 | class VisualClassifierClass(ModelClass): 16 | """Base class for visual classification models supporting image and video processing.""" 17 | 18 | @staticmethod 19 | def preprocess_image(image_bytes: bytes) -> PILImage: 20 | """Convert image bytes to PIL Image.""" 21 | return PILImage.open(BytesIO(image_bytes)).convert("RGB") 22 | 23 | @staticmethod 24 | def video_to_frames(video_bytes: bytes) -> Iterator[Frame]: 25 | """Convert video bytes to frames. 26 | 27 | Args: 28 | video_bytes: Raw video data in bytes 29 | 30 | Yields: 31 | Frame with JPEG encoded frame data as bytes and timestamp in milliseconds 32 | """ 33 | with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file: 34 | temp_video_file.write(video_bytes) 35 | temp_video_path = temp_video_file.name 36 | logger.debug(f"temp_video_path: {temp_video_path}") 37 | 38 | video = cv2.VideoCapture(temp_video_path) 39 | logger.debug(f"video opened: {video.isOpened()}") 40 | 41 | while video.isOpened(): 42 | ret, frame = video.read() 43 | if not ret: 44 | break 45 | # Get frame timestamp in milliseconds 46 | timestamp_ms = video.get(cv2.CAP_PROP_POS_MSEC) 47 | frame_bytes = cv2.imencode('.jpg', frame)[1].tobytes() 48 | yield Frame(image=Image(bytes=frame_bytes), time=timestamp_ms) 49 | 50 | video.release() 51 | os.unlink(temp_video_path) 52 | 53 | @staticmethod 54 | def process_concepts( 55 | logits: torch.Tensor, threshold: float, model_labels: Dict[int, str] 56 | ) -> List[List[Concept]]: 57 | """Convert model logits into a structured format of concepts. 
58 | 59 | Args: 60 | logits: Model output logits as a tensor (batch_size x num_classes) 61 | model_labels: Dictionary mapping label indices to label names 62 | 63 | Returns: 64 | List of lists containing Concept objects for each input in the batch 65 | """ 66 | outputs = [] 67 | for logit in logits: 68 | probs = torch.softmax(logit, dim=-1) 69 | sorted_indices = torch.argsort(probs, dim=-1, descending=True) 70 | output_concepts = [] 71 | for idx in sorted_indices: 72 | concept = Concept(name=model_labels[idx.item()], value=probs[idx].item()) 73 | output_concepts.append(concept) 74 | outputs.append(output_concepts) 75 | return outputs 76 | -------------------------------------------------------------------------------- /clarifai/runners/models/visual_detector_class.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from io import BytesIO 4 | from typing import Dict, Iterator, List 5 | 6 | import cv2 7 | import torch 8 | from PIL import Image as PILImage 9 | 10 | from clarifai.runners.models.model_class import ModelClass 11 | from clarifai.runners.utils.data_types import Concept, Frame, Image, Region 12 | from clarifai.utils.logging import logger 13 | 14 | 15 | class VisualDetectorClass(ModelClass): 16 | """Base class for visual detection models supporting image and video processing.""" 17 | 18 | @staticmethod 19 | def preprocess_image(image_bytes: bytes) -> PILImage: 20 | """Convert image bytes to PIL Image.""" 21 | return PILImage.open(BytesIO(image_bytes)).convert("RGB") 22 | 23 | @staticmethod 24 | def video_to_frames(video_bytes: bytes) -> Iterator[Frame]: 25 | """Convert video bytes to frames. 26 | 27 | Args: 28 | video_bytes: Raw video data in bytes 29 | 30 | Yields: 31 | Frame with JPEG encoded frame data as bytes and timestamp in milliseconds 32 | """ 33 | with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file: 34 | temp_video_file.write(video_bytes) 35 | temp_video_path = temp_video_file.name 36 | logger.debug(f"temp_video_path: {temp_video_path}") 37 | 38 | video = cv2.VideoCapture(temp_video_path) 39 | logger.debug(f"video opened: {video.isOpened()}") 40 | 41 | while video.isOpened(): 42 | ret, frame = video.read() 43 | if not ret: 44 | break 45 | # Get frame timestamp in milliseconds 46 | timestamp_ms = video.get(cv2.CAP_PROP_POS_MSEC) 47 | frame_bytes = cv2.imencode('.jpg', frame)[1].tobytes() 48 | yield Frame(image=Image(bytes=frame_bytes), time=timestamp_ms) 49 | 50 | video.release() 51 | os.unlink(temp_video_path) 52 | 53 | @staticmethod 54 | def process_detections( 55 | results: List[Dict[str, torch.Tensor]], threshold: float, model_labels: Dict[int, str] 56 | ) -> List[List[Region]]: 57 | """Convert model outputs into a structured format of detections. 
58 | 59 | Args: 60 | results: Raw detection results from model 61 | threshold: Confidence threshold for detections 62 | model_labels: Dictionary mapping label indices to names 63 | 64 | Returns: 65 | List of lists containing Region objects for each detection 66 | """ 67 | outputs = [] 68 | for result in results: 69 | detections = [] 70 | for score, label_idx, box in zip(result["scores"], result["labels"], result["boxes"]): 71 | if score > threshold: 72 | label = model_labels[label_idx.item()] 73 | detections.append( 74 | Region( 75 | box=box.tolist(), concepts=[Concept(name=label, value=score.item())] 76 | ) 77 | ) 78 | outputs.append(detections) 79 | return outputs 80 | -------------------------------------------------------------------------------- /clarifai/runners/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/runners/utils/__init__.py -------------------------------------------------------------------------------- /clarifai/runners/utils/const.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarifai-models') 4 | 5 | GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb" 6 | 7 | AMD_GIT_SHA = "81e942130173f54927e7c9a65aabc7e32780616d" 8 | 9 | PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA 10 | TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA 11 | 12 | AMD_PYTHON_BASE_IMAGE = registry + '/amd-python-base:{python_version}-' + AMD_GIT_SHA 13 | AMD_TORCH_BASE_IMAGE = ( 14 | registry + '/amd-torch:{torch_version}-py{python_version}-{gpu_version}-' + AMD_GIT_SHA 15 | ) 16 | AMD_VLLM_BASE_IMAGE = ( 17 | registry + '/amd-vllm:{torch_version}-py{python_version}-{gpu_version}-' + AMD_GIT_SHA 18 | ) 19 | 20 | # List of available python base images 21 | AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12'] 22 | 23 | DEFAULT_PYTHON_VERSION = 3.12 24 | 25 | DEFAULT_AMD_TORCH_VERSION = '2.8.0.dev20250511+rocm6.4' 26 | 27 | DEFAULT_AMD_GPU_VERSION = 'rocm6.4' 28 | 29 | # By default we download at runtime. 30 | DEFAULT_DOWNLOAD_CHECKPOINT_WHEN = "runtime" 31 | 32 | # Folder for downloading checkpoints at runtime. 33 | DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache") 34 | 35 | # List of available torch images 36 | # Keep sorted by most recent cuda version. 
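# Each entry follows the '{torch_version}-py{python_version}-{gpu_version}' tag pattern
# consumed by TORCH_BASE_IMAGE above.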
37 | AVAILABLE_TORCH_IMAGES = [ 38 | '2.4.1-py3.11-cu124', 39 | '2.5.1-py3.11-cu124', 40 | '2.4.1-py3.12-cu124', 41 | '2.5.1-py3.12-cu124', 42 | '2.6.0-py3.12-cu126', 43 | '2.7.0-py3.12-cu128', 44 | '2.7.0-py3.12-rocm6.3', 45 | ] 46 | 47 | CONCEPTS_REQUIRED_MODEL_TYPE = [ 48 | 'visual-classifier', 49 | 'visual-detector', 50 | 'visual-segmenter', 51 | 'text-classifier', 52 | ] 53 | -------------------------------------------------------------------------------- /clarifai/runners/utils/data_types/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_types import JSON # noqa 2 | from .data_types import Audio # noqa 3 | from .data_types import Concept # noqa 4 | from .data_types import Frame # noqa 5 | from .data_types import Image # noqa 6 | from .data_types import MessageData # noqa 7 | from .data_types import NamedFields # noqa 8 | from .data_types import NamedFieldsMeta # noqa 9 | from .data_types import Region # noqa 10 | from .data_types import Text # noqa 11 | from .data_types import Video # noqa 12 | from .data_types import cast # noqa 13 | -------------------------------------------------------------------------------- /clarifai/runners/utils/url_fetcher.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | 3 | import fsspec 4 | 5 | from clarifai.utils.logging import logger 6 | 7 | 8 | def download_input(input): 9 | _download_input_data(input.data) 10 | if input.data.parts: 11 | for i in range(len(input.data.parts)): 12 | _download_input_data(input.data.parts[i].data) 13 | 14 | 15 | def _download_input_data(input_data): 16 | """ 17 | This function will download any urls that are not already bytes. 18 | """ 19 | if input_data.image.url and not input_data.image.base64: 20 | # Download the image 21 | with fsspec.open(input_data.image.url, 'rb') as f: 22 | input_data.image.base64 = f.read() 23 | if input_data.video.url and not input_data.video.base64: 24 | # Download the video 25 | with fsspec.open(input_data.video.url, 'rb') as f: 26 | input_data.video.base64 = f.read() 27 | if input_data.audio.url and not input_data.audio.base64: 28 | # Download the audio 29 | with fsspec.open(input_data.audio.url, 'rb') as f: 30 | input_data.audio.base64 = f.read() 31 | if input_data.text.url and not input_data.text.raw: 32 | # Download the text 33 | with fsspec.open(input_data.text.url, 'r') as f: 34 | input_data.text.raw = f.read() 35 | 36 | 37 | def ensure_urls_downloaded(request, max_threads=128): 38 | """ 39 | This function will download any urls that are not already bytes and parallelize with a thread pool. 40 | """ 41 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor: 42 | futures = [] 43 | for input in request.inputs: 44 | futures.append(executor.submit(download_input, input)) 45 | for future in concurrent.futures.as_completed(futures): 46 | try: 47 | future.result() 48 | except Exception as e: 49 | logger.exception(f"Error downloading input: {e}") 50 | -------------------------------------------------------------------------------- /clarifai/schema/search.py: -------------------------------------------------------------------------------- 1 | from schema import And, Optional, Regex, Schema 2 | 3 | 4 | def get_schema() -> Schema: 5 | """Initialize the schema for rank and filter. 
6 | 7 | This schema validates: 8 | 9 | - Rank and filter must be a list 10 | - Each item in the list must be a dict 11 | - The dict can contain these optional keys: 12 | - 'image_url': Valid URL string 13 | - 'text_raw': Non-empty string 14 | - 'metadata': Dict 15 | - 'image_bytes': Bytes 16 | - 'geo_point': Dict with 'longitude', 'latitude' and 'geo_limit' as float, float and int respectively 17 | - 'concepts': List where each item is a concept dict 18 | - Concept dict requires at least one of: 19 | - 'name': Non-empty string with dashes/underscores 20 | - 'id': Non-empty string 21 | - 'language': Non-empty string 22 | - 'value': 0 or 1 integer 23 | - 'input_types': List of 'image', 'video', 'text' or 'audio' 24 | - 'input_dataset_ids': List of strings 25 | - 'input_status_code': Integer 26 | 27 | Returns: 28 | Schema: The schema for rank and filter. 29 | """ 30 | # Schema for a single concept 31 | concept_schema = Schema( 32 | { 33 | Optional('value'): And(int, lambda x: x in [0, 1]), 34 | Optional('id'): And(str, len), 35 | Optional('language'): And(str, len), 36 | # Non-empty strings with internal dashes and underscores. 37 | Optional('name'): And(str, len, Regex(r'^[0-9A-Za-z]+([-_][0-9A-Za-z]+)*$')), 38 | } 39 | ) 40 | 41 | # Schema for a rank or filter item 42 | rank_filter_item_schema = Schema( 43 | { 44 | Optional('image_url'): And(str, Regex(r'^https?://')), 45 | Optional('text_raw'): And(str, len), 46 | Optional('metadata'): dict, 47 | Optional('image_bytes'): bytes, 48 | Optional('geo_point'): {'longitude': float, 'latitude': float, 'geo_limit': int}, 49 | Optional("concepts"): And( 50 | list, lambda x: all(concept_schema.is_valid(item) and len(item) > 0 for item in x) 51 | ), 52 | ## input filters 53 | Optional('input_types'): And( 54 | list, 55 | lambda input_types: all( 56 | input_type in ('image', 'video', 'text', 'audio') for input_type in input_types 57 | ), 58 | ), 59 | Optional('input_dataset_ids'): list, 60 | Optional('input_status_code'): int, 61 | } 62 | ) 63 | 64 | # Schema for rank and filter args 65 | return Schema([rank_filter_item_schema]) 66 | -------------------------------------------------------------------------------- /clarifai/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/utils/__init__.py -------------------------------------------------------------------------------- /clarifai/utils/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DEFAULT_UI = os.environ.get("CLARIFAI_UI", "https://clarifai.com") 4 | DEFAULT_BASE = os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com") 5 | 6 | MCP_TRANSPORT_NAME = "mcp_transport" 7 | OPENAI_TRANSPORT_NAME = "openai_transport" 8 | 9 | CLARIFAI_PAT_ENV_VAR = "CLARIFAI_PAT" 10 | CLARIFAI_SESSION_TOKEN_ENV_VAR = "CLARIFAI_SESSION_TOKEN" 11 | CLARIFAI_USER_ID_ENV_VAR = "CLARIFAI_USER_ID" 12 | 13 | DEFAULT_CONFIG = f'{os.environ["HOME"]}/.config/clarifai/config' 14 | 15 | # Default clusters, etc. for local dev runner easy setup 16 | DEFAULT_LOCAL_DEV_COMPUTE_CLUSTER_ID = "local-dev-compute-cluster" 17 | DEFAULT_LOCAL_DEV_NODEPOOL_ID = "local-dev-nodepool" 18 | DEFAULT_LOCAL_DEV_DEPLOYMENT_ID = "local-dev-deployment" 19 | DEFAULT_LOCAL_DEV_MODEL_ID = "local-dev-model" 20 | DEFAULT_LOCAL_DEV_APP_ID = "local-dev-runner-app" 21 | 22 | # FIXME: should have any-to-any for these cases. 
23 | DEFAULT_LOCAL_DEV_MODEL_TYPE = "text-to-text" 24 | 25 | DEFAULT_LOCAL_DEV_COMPUTE_CLUSTER_CONFIG = { 26 | "compute_cluster": { 27 | "id": DEFAULT_LOCAL_DEV_COMPUTE_CLUSTER_ID, 28 | "description": "Default Local Dev Compute Cluster", 29 | "cloud_provider": { 30 | "id": "local", 31 | }, 32 | "region": "na", 33 | "managed_by": "user", 34 | "cluster_type": "local-dev", 35 | } 36 | } 37 | 38 | DEFAULT_LOCAL_DEV_NODEPOOL_CONFIG = { 39 | "nodepool": { 40 | "id": DEFAULT_LOCAL_DEV_NODEPOOL_ID, 41 | "description": "Default Local Dev Nodepool", 42 | "compute_cluster": { 43 | "id": DEFAULT_LOCAL_DEV_COMPUTE_CLUSTER_ID, 44 | "user_id": None, # This will be set when creating the compute cluster 45 | }, 46 | "instance_types": [ 47 | { 48 | "id": "local-cpu", 49 | "compute_info": { 50 | "cpu_limit": str(os.cpu_count()), 51 | "cpu_memory": "16Gi", # made up as we don't schedule based on this for local dev. 52 | "num_accelerators": 0, # TODO if we need accelerator detection for local dev. 53 | }, 54 | } 55 | ], 56 | "node_capacity_type": { 57 | "capacity_types": [1], 58 | }, 59 | "min_instances": 1, 60 | "max_instances": 1, 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /clarifai/utils/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import EvalResultCompare 2 | 3 | __all__ = ["EvalResultCompare"] 4 | -------------------------------------------------------------------------------- /clarifai/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import uuid 4 | from typing import Any, Dict, List 5 | 6 | from clarifai_grpc.grpc.api.status import status_code_pb2 7 | 8 | from clarifai.errors import UserError 9 | 10 | RETRYABLE_CODES = [ 11 | status_code_pb2.MODEL_DEPLOYING, 12 | status_code_pb2.MODEL_LOADING, 13 | status_code_pb2.MODEL_BUSY_PLEASE_RETRY, 14 | ] 15 | 16 | DEFAULT_CONFIG = f'{os.environ["HOME"]}/.config/clarifai/config' 17 | 18 | 19 | def status_is_retryable(status_code: int) -> bool: 20 | """Check if a status code is retryable.""" 21 | return status_code in RETRYABLE_CODES 22 | 23 | 24 | class Chunker: 25 | """Split an input sequence into small chunks.""" 26 | 27 | def __init__(self, seq: List, size: int) -> None: 28 | self.seq = seq 29 | self.size = size 30 | 31 | def chunk(self) -> List[List]: 32 | """Chunk input sequence.""" 33 | return [self.seq[pos : pos + self.size] for pos in range(0, len(self.seq), self.size)] 34 | 35 | 36 | class BackoffIterator: 37 | """Iterator that returns a sequence of backoff values.""" 38 | 39 | def __init__(self, count=0): 40 | self.count = count 41 | 42 | def __iter__(self): 43 | return self 44 | 45 | def __next__(self): 46 | self.count += 1 47 | return 0.1 * (1.3**self.count) 48 | 49 | 50 | def get_from_dict_or_env(key: str, env_key: str, **data) -> str: 51 | """Get a value from a dictionary or an environment variable.""" 52 | if key in data and data[key]: 53 | return data[key] 54 | else: 55 | return get_from_env(key, env_key) 56 | 57 | 58 | def get_from_env(key: str, env_key: str) -> str: 59 | """Get a value from a dictionary or an environment variable.""" 60 | if env_key in os.environ and os.environ[env_key]: 61 | return os.environ[env_key] 62 | else: 63 | raise UserError( 64 | f"Did not find `{key}`, please add an environment variable" 65 | f" `{env_key}` which contains it, or pass" 66 | f" `{key}` as a named parameter." 
67 | ) 68 | 69 | 70 | def concept_relations_accumulation( 71 | relations_dict: Dict[str, Any], subject_concept: str, object_concept: str, predicate: str 72 | ) -> Dict[str, Any]: 73 | """Append the concept relation to relations dict based on its predicate. 74 | 75 | Args: 76 | relations_dict (dict): A dict of concept relations info. 77 | """ 78 | if predicate == 'hyponym': 79 | if object_concept in relations_dict: 80 | relations_dict[object_concept].append(subject_concept) 81 | else: 82 | relations_dict[object_concept] = [subject_concept] 83 | elif predicate == 'hypernym': 84 | if subject_concept in relations_dict: 85 | relations_dict[subject_concept].append(object_concept) 86 | else: 87 | relations_dict[subject_concept] = [object_concept] 88 | else: 89 | relations_dict[object_concept] = [] 90 | relations_dict[subject_concept] = [] 91 | return relations_dict 92 | 93 | 94 | def get_uuid(val: int) -> str: 95 | """Generates a UUID.""" 96 | return uuid.uuid4().hex[:val] 97 | 98 | 99 | def clean_input_id(input_id: str) -> str: 100 | """Clean input_id string into a valid input ID""" 101 | input_id = re.sub('[., /]+', '_', input_id) 102 | input_id = re.sub('[_]+', '_', input_id) 103 | input_id = re.sub('[-]+', '-', input_id) 104 | input_id = input_id.lower().strip('_-') 105 | input_id = re.sub('[^a-z0-9-_]+', '', input_id) 106 | return input_id 107 | -------------------------------------------------------------------------------- /clarifai/versions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from clarifai import __version__ 4 | 5 | CLIENT_VERSION = __version__ 6 | OS_VER = os.sys.platform 7 | PYTHON_VERSION = '.'.join( 8 | map(str, [os.sys.version_info.major, os.sys.version_info.minor, os.sys.version_info.micro]) 9 | ) 10 | -------------------------------------------------------------------------------- /clarifai/workflows/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/clarifai/workflows/__init__.py -------------------------------------------------------------------------------- /clarifai/workflows/export.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import yaml 4 | from google.protobuf.json_format import MessageToDict 5 | 6 | VALID_YAML_KEYS = ["workflow", "id", "nodes", "node_inputs", "node_id", "model"] 7 | 8 | 9 | def clean_up_unused_keys(wf: dict): 10 | """Removes unused keys from dict before exporting to yaml. Supports nested dicts.""" 11 | new_wf = dict() 12 | for key, val in wf.items(): 13 | if key not in VALID_YAML_KEYS: 14 | continue 15 | if key == "model": 16 | new_wf["model"] = { 17 | "model_id": wf["model"]["id"], 18 | "model_version_id": wf["model"]["model_version"]["id"], 19 | } 20 | # If the model is not from clarifai main, add the app_id and user_id to the model dict. 
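            # (Models under the public clarifai/main app can be resolved without explicit
            # user_id/app_id, so those fields are only exported for other apps.)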
21 | if wf["model"]["user_id"] != "clarifai" and wf["model"]["app_id"] != "main": 22 | new_wf["model"].update( 23 | {"app_id": wf["model"]["app_id"], "user_id": wf["model"]["user_id"]} 24 | ) 25 | elif isinstance(val, dict): 26 | new_wf[key] = clean_up_unused_keys(val) 27 | elif isinstance(val, list): 28 | new_list = [] 29 | for i in val: 30 | new_list.append(clean_up_unused_keys(i)) 31 | new_wf[key] = new_list 32 | else: 33 | new_wf[key] = val 34 | return new_wf 35 | 36 | 37 | class Exporter: 38 | def __init__(self, workflow): 39 | self.wf = workflow 40 | 41 | def __enter__(self): 42 | return self 43 | 44 | def parse(self) -> Dict[str, Any]: 45 | """Reads a resources_pb2.Workflow object (e.g. from a GetWorkflow response) 46 | 47 | Returns: 48 | dict: A dict representation of the workflow. 49 | """ 50 | if isinstance(self.wf, list): 51 | self.wf = self.wf[0] 52 | wf = {"workflow": MessageToDict(self.wf, preserving_proto_field_name=True)} 53 | clean_wf = clean_up_unused_keys(wf) 54 | self.wf_dict = clean_wf 55 | return clean_wf 56 | 57 | def export(self, out_path): 58 | with open(out_path, 'w') as out_file: 59 | yaml.dump(self.wf_dict["workflow"], out_file, default_flow_style=False) 60 | 61 | def __exit__(self, *args): 62 | self.close() 63 | 64 | def close(self): 65 | del self.wf 66 | del self.wf_dict 67 | -------------------------------------------------------------------------------- /clarifai/workflows/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Set 2 | 3 | from clarifai_grpc.grpc.api import resources_pb2 4 | from google.protobuf import struct_pb2 5 | from google.protobuf.json_format import MessageToDict 6 | 7 | 8 | def get_yaml_output_info_proto(yaml_model_output_info: Dict) -> Optional[resources_pb2.OutputInfo]: 9 | """Converts a yaml model output info to an api model output info.""" 10 | if not yaml_model_output_info: 11 | return None 12 | 13 | return resources_pb2.OutputInfo( 14 | params=convert_yaml_params_to_api_params(yaml_model_output_info.get('params')) 15 | ) 16 | 17 | 18 | def convert_yaml_params_to_api_params(yaml_params: Dict) -> Optional[struct_pb2.Struct]: 19 | """Converts a yaml model output info params to an api model output info params.""" 20 | if not yaml_params: 21 | return None 22 | 23 | s = struct_pb2.Struct() 24 | s.update(yaml_params) 25 | 26 | return s 27 | 28 | 29 | def is_same_yaml_model(api_model: resources_pb2.Model, yaml_model: Dict) -> bool: 30 | """Compares a model from the API with a model from a yaml file.""" 31 | api_model = MessageToDict(api_model, preserving_proto_field_name=True) 32 | 33 | yaml_model_from_api = dict() 34 | for k, _ in yaml_model.items(): 35 | if k == "output_info" and api_model["model_version"].get("output_info", "") != "": 36 | yaml_model_from_api[k] = dict( 37 | params=api_model["model_version"]["output_info"].get("params") 38 | ) 39 | else: 40 | yaml_model_from_api[k] = api_model.get(k) 41 | yaml_model_from_api.update({"model_id": api_model.get("id")}) 42 | 43 | ignore_keys = {} 44 | 45 | return is_dict_in_dict(yaml_model, yaml_model_from_api, ignore_keys) 46 | 47 | 48 | def is_dict_in_dict(d1: Dict, d2: Dict, ignore_keys: Set = None) -> bool: 49 | """Compares two dicts recursively.""" 50 | for k, v in d1.items(): 51 | if ignore_keys and k in ignore_keys: 52 | continue 53 | if k not in d2: 54 | return False 55 | if isinstance(v, dict): 56 | if not isinstance(d2[k], dict): 57 | return False 58 | return is_dict_in_dict(d1[k], d2[k], None) 59 | elif 
v != d2[k]: 60 | return False 61 | 62 | return True 63 | -------------------------------------------------------------------------------- /clarifai/workflows/validate.py: -------------------------------------------------------------------------------- 1 | from schema import And, Optional, Regex, Schema, SchemaError, Use 2 | 3 | # Non-empty, up to 48-character ASCII strings with internal dashes and underscores. 4 | _id_validator = And(str, lambda s: 0 < len(s) <= 48, Regex(r'^[0-9A-Za-z]+([-_][0-9A-Za-z]+)*$')) 5 | 6 | # 32-character hex string, converted to lower-case. 7 | _hex_id_validator = And(str, Use(str.lower), Regex(r'^[0-9a-f]{32}$')) 8 | 9 | 10 | def _model_does_not_have_model_version_id_and_other_fields(m): 11 | """Validate that model does not set both model_version_id and other model fields.""" 12 | if ('model_version_id' in m) and _model_has_other_fields(m): 13 | raise SchemaError( 14 | f"model should not set both model_version_id and other model fields: {m};" 15 | f" please remove model_version_id or the other model fields." 16 | ) 17 | return True 18 | 19 | 20 | def _model_has_other_fields(m): 21 | return any(k not in ['model_id', 'model_version_id', 'user_id', 'app_id'] for k in m.keys()) 22 | 23 | 24 | def _workflow_nodes_have_valid_dependencies(nodes): 25 | """Validate that all inputs to a node are declared before it.""" 26 | node_ids = set() 27 | for node in nodes: 28 | for node_input in node.get("node_inputs", []): 29 | if node_input["node_id"] not in node_ids: 30 | raise SchemaError( 31 | f"missing input '{node_input['node_id']}' for node '{node['id']}'" 32 | ) 33 | node_ids.add(node["id"]) 34 | 35 | return True 36 | 37 | 38 | _data_schema = Schema( 39 | { 40 | "workflow": { 41 | "id": _id_validator, 42 | "nodes": And( 43 | len, 44 | [ 45 | { 46 | "id": And(str, len), # Node IDs are not validated as IDs by the API.
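# An illustrative workflow document that this schema accepts (an editor's sketch
# for orientation, not a fixture from this repo; the model_id values are made up):
#
#   workflow:
#     id: my-workflow
#     nodes:
#       - id: detector
#         model:
#           model_id: some-detector-model
#       - id: cropper
#         model:
#           model_id: some-cropper-model
#         node_inputs:
#           - node_id: detector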
47 | "model": And( 48 | { 49 | "model_id": _id_validator, 50 | Optional("app_id"): _id_validator, 51 | Optional("user_id"): _id_validator, 52 | Optional("model_version_id"): _hex_id_validator, 53 | Optional("model_type_id"): _id_validator, 54 | Optional("description"): str, 55 | Optional("output_info"): { 56 | Optional("params"): dict, 57 | }, 58 | }, 59 | _model_does_not_have_model_version_id_and_other_fields, 60 | ), 61 | Optional("node_inputs"): And( 62 | len, 63 | [ 64 | { 65 | "node_id": And(str, len), 66 | } 67 | ], 68 | ), 69 | } 70 | ], 71 | _workflow_nodes_have_valid_dependencies, 72 | ), 73 | }, 74 | } 75 | ) 76 | 77 | 78 | def validate(data): 79 | return _data_schema.validate(data) 80 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "clarifai" 3 | requires-python = ">=3.8" 4 | dynamic = [ 5 | "version", 6 | "authors", 7 | "license", 8 | "classifiers", 9 | "scripts", 10 | "dependencies", 11 | "optional-dependencies", 12 | "readme" # For long_description 13 | ] 14 | 15 | [tool.pytest.ini_options] 16 | markers = ["requires_secrets: mark a test as requiring secrets to run", "coverage_only: mark a test as required to run for coverage purpose only"] 17 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit==2.20.0 2 | ruff==0.11.4 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | clarifai-grpc>=11.3.4 2 | clarifai-protocol>=0.0.23 3 | numpy>=1.22.0 4 | tqdm>=4.65.0 5 | PyYAML>=6.0.1 6 | schema==0.7.5 7 | Pillow>=9.5.0 8 | tabulate>=0.9.0 9 | fsspec>=2024.6.1 10 | click>=8.1.7 11 | requests>=2.32.3 12 | aiohttp>=3.10.0 13 | -------------------------------------------------------------------------------- /scripts/key_for_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import json 5 | import os 6 | import sys 7 | 8 | try: 9 | from urllib.error import HTTPError 10 | from urllib.request import HTTPHandler, Request, build_opener 11 | except ImportError: 12 | from urllib2 import HTTPError, HTTPHandler, Request, build_opener 13 | 14 | EMAIL = os.environ["CLARIFAI_USER_EMAIL"] 15 | PASSWORD = os.environ["CLARIFAI_USER_PASSWORD"] 16 | 17 | 18 | def _assert_response_success(response): 19 | assert "status" in response, f"Invalid response {response}" 20 | assert "code" in response["status"], f"Invalid response {response}" 21 | assert response["status"]["code"] == 10000, f"Invalid response {response}" 22 | 23 | 24 | def _request(method, url, payload={}, headers={}): 25 | base_url = os.environ.get("CLARIFAI_GRPC_BASE", "api.clarifai.com") 26 | 27 | opener = build_opener(HTTPHandler) 28 | full_url = f"https://{base_url}/v2{url}" 29 | request = Request(full_url, data=json.dumps(payload).encode()) 30 | for k in headers.keys(): 31 | request.add_header(k, headers[k]) 32 | request.get_method = lambda: method 33 | try: 34 | response = opener.open(request).read().decode() 35 | except HTTPError as e: 36 | error_body = e.read().decode() 37 | try: 38 | error_body = json.dumps(json.loads(error_body), indent=4) 39 | except Exception: 40 | pass 41 | raise Exception( 42 | "ERROR after 
an HTTP request to: %s %s" % (method, full_url) 43 | + ". Response: %d %s:\n%s" % (e.code, e.reason, error_body) 44 | ) 45 | return json.loads(response) 46 | 47 | 48 | def login(): 49 | url = "/login" 50 | payload = {"email": EMAIL, "password": PASSWORD} 51 | data = _request(method="POST", url=url, payload=payload) 52 | _assert_response_success(data) 53 | 54 | assert "v2_user_id" in data, f"Invalid response {data}" 55 | user_id = data["v2_user_id"] 56 | assert user_id, f"Invalid response {data}" 57 | 58 | assert "session_token" in data, f"Invalid response {data}" 59 | session_token = data["session_token"] 60 | assert session_token, f"Invalid response {data}" 61 | 62 | return session_token, user_id 63 | 64 | 65 | def _auth_headers(session_token): 66 | headers = {"Content-Type": "application/json", "X-Clarifai-Session-Token": session_token} 67 | return headers 68 | 69 | 70 | def create_pat(): 71 | session_token, user_id = login() 72 | os.environ["CLARIFAI_USER_ID"] = user_id 73 | 74 | url = "/users/%s/keys" % user_id 75 | payload = { 76 | "keys": [ 77 | { 78 | "description": "Auto-created in a CI test run", 79 | "scopes": ["All"], 80 | "type": "personal_access_token", 81 | "apps": [], 82 | } 83 | ] 84 | } 85 | data = _request(method="POST", url=url, payload=payload, headers=_auth_headers(session_token)) 86 | _assert_response_success(data) 87 | 88 | assert "keys" in data, f"Invalid response {data}" 89 | assert len(data["keys"]) == 1, f"Invalid response {data}" 90 | assert "id" in data["keys"][0], f"Invalid response {data}" 91 | pat_id = data["keys"][0]["id"] 92 | assert pat_id, f"Invalid response {data}" 93 | 94 | # This print needs to be present so we can read the value in CI. 95 | print(pat_id) 96 | 97 | 98 | def run(arguments): 99 | if arguments.email: 100 | global EMAIL 101 | EMAIL = arguments.email # override the default testing email 102 | if arguments.password: 103 | global PASSWORD 104 | PASSWORD = arguments.password # override the default testing password 105 | # these options are mutually exclusive 106 | if arguments.create_pat: 107 | create_pat() 108 | elif arguments.get_userid: 109 | _, user_id = login() 110 | # This print needs to be present so we can read the value in CI. 111 | print(user_id) 112 | else: 113 | print( 114 | f"No relevant arguments specified. Run {sys.argv[0]} --help to see available options" 115 | ) 116 | sys.exit(1) 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser( 121 | description="Create Personal Access Tokens and fetch user IDs for testing." 122 | ) 123 | parser.add_argument( 124 | "--user-email", 125 | dest="email", 126 | help="The email of the account for which the command will run. (Defaults to ${CLARIFAI_USER_EMAIL})", 127 | ) 128 | parser.add_argument( 129 | "--user-password", 130 | dest="password", 131 | help="The password of the account for which the command will run.
(Defaults to ${CLARIFAI_USER_PASSWORD})", 132 | ) 133 | group = parser.add_mutually_exclusive_group() 134 | group.add_argument("--create-pat", action="store_true", help="Creates a new PAT key.") 135 | group.add_argument("--get-userid", action="store_true", help="Gets the user id.") 136 | 137 | args = parser.parse_args() 138 | run(args) 139 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import setuptools 4 | 5 | with open("README.md", "r", encoding="utf-8") as fh: 6 | long_description = fh.read() 7 | 8 | with open("./clarifai/__init__.py", encoding="utf-8") as f: 9 | content = f.read() 10 | _search_version = re.search(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', content) 11 | assert _search_version 12 | version = _search_version.group(1) 13 | 14 | with open("requirements.txt", "r", encoding="utf-8") as fh: 15 | install_requires = fh.read().split('\n') 16 | 17 | if install_requires and install_requires[-1] == '': 18 | # Remove the last empty line 19 | install_requires = install_requires[:-1] 20 | 21 | packages = setuptools.find_namespace_packages(include=["clarifai*"]) 22 | 23 | setuptools.setup( 24 | name="clarifai", 25 | version=version, 26 | author="Clarifai", 27 | author_email="support@clarifai.com", 28 | description="Clarifai Python SDK", 29 | long_description=long_description, 30 | long_description_content_type="text/markdown", 31 | url="https://github.com/Clarifai/clarifai-python", 32 | packages=packages, 33 | classifiers=[ 34 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 35 | "Programming Language :: Python :: 3", 36 | "Programming Language :: Python :: 3 :: Only", 37 | "Programming Language :: Python :: 3.8", 38 | "Programming Language :: Python :: 3.9", 39 | "Programming Language :: Python :: 3.10", 40 | "Programming Language :: Python :: 3.11", 41 | "Programming Language :: Python :: 3.12", 42 | "Programming Language :: Python :: Implementation :: CPython", 43 | "License :: OSI Approved :: Apache Software License", 44 | "Operating System :: OS Independent", 45 | ], 46 | license="Apache 2.0", 47 | python_requires='>=3.8', 48 | install_requires=install_requires, 49 | extras_require={ 50 | 'all': ["pycocotools>=2.0.7"], 51 | }, 52 | entry_points={ 53 | "console_scripts": [ 54 | "clarifai = clarifai.cli.base:cli", 55 | ], 56 | }, 57 | include_package_data=True, 58 | ) 59 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/__init__.py -------------------------------------------------------------------------------- /tests/assets/coco_detection/images/3176048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/coco_detection/images/3176048.jpg -------------------------------------------------------------------------------- /tests/assets/coco_detection/images/architectural-design-architecture-asphalt-2445783.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/coco_detection/images/architectural-design-architecture-asphalt-2445783.jpg -------------------------------------------------------------------------------- /tests/assets/coco_detection/images/architecture-buildings-commerce-2308592.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/coco_detection/images/architecture-buildings-commerce-2308592.jpg -------------------------------------------------------------------------------- /tests/assets/coco_detection/instances_default.json: -------------------------------------------------------------------------------- 1 | {"licenses":[{"name":"","id":0,"url":""}],"info":{"contributor":"","date_created":"","description":"","url":"","version":"","year":""},"categories":[{"id":1,"name":"person","supercategory":""}],"images":[{"id":5,"width":1152,"height":768,"file_name":"3176048.jpg","license":0,"flickr_url":"","coco_url":"","date_captured":0},{"id":17,"width":768,"height":1152,"file_name":"architectural-design-architecture-asphalt-2445783.jpg","license":0,"flickr_url":"","coco_url":"","date_captured":0},{"id":19,"width":1152,"height":768,"file_name":"architecture-buildings-commerce-2308592.jpg","license":0,"flickr_url":"","coco_url":"","date_captured":0}],"annotations":[{"id":8,"image_id":5,"category_id":1,"segmentation":[],"area":9711.203999999994,"bbox":[518.23,338.97,76.4,127.11],"iscrowd":0,"attributes":{"occluded":false,"rotation":0.0}},{"id":9,"image_id":5,"category_id":1,"segmentation":[],"area":-3624.5103999999947,"bbox":[683.42,356.48,-44.56,81.34],"iscrowd":0,"attributes":{"occluded":false,"rotation":0.0}},{"id":33,"image_id":17,"category_id":1,"segmentation":[],"area":22327.3641,"bbox":[316.45,644.41,83.19,268.39],"iscrowd":0,"attributes":{"occluded":false,"rotation":0.0}},{"id":35,"image_id":19,"category_id":1,"segmentation":[],"area":8162.334000000001,"bbox":[521.24,563.73,49.23,165.8],"iscrowd":0,"attributes":{"occluded":false,"rotation":0.0}},{"id":36,"image_id":19,"category_id":1,"segmentation":[],"area":8020.044800000009,"bbox":[469.43,561.98,48.76,164.48],"iscrowd":0,"attributes":{"occluded":false,"rotation":0.0}}]} 2 | -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/LOC_synset_mapping.txt: -------------------------------------------------------------------------------- 1 | n01855672 goose 2 | n02113799 standard poodle 3 | -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n01855672/n01855672_0.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n01855672/n01855672_0.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n01855672/n01855672_1.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n01855672/n01855672_1.JPEG -------------------------------------------------------------------------------- 
/tests/assets/imagenet_classification/train/n01855672/n01855672_2.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n01855672/n01855672_2.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n01855672/n01855672_3.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n01855672/n01855672_3.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n01855672/n01855672_4.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n01855672/n01855672_4.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n02113799/n02113799_0.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n02113799/n02113799_0.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n02113799/n02113799_1.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n02113799/n02113799_1.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n02113799/n02113799_2.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n02113799/n02113799_2.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n02113799/n02113799_3.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n02113799/n02113799_3.JPEG -------------------------------------------------------------------------------- /tests/assets/imagenet_classification/train/n02113799/n02113799_4.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/imagenet_classification/train/n02113799/n02113799_4.JPEG -------------------------------------------------------------------------------- /tests/assets/red-truck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/red-truck.png -------------------------------------------------------------------------------- 
/tests/assets/sample.csv: -------------------------------------------------------------------------------- 1 | input,concepts 2 | "Now, I won't deny that when I purchased this off eBay, I had high expectations. This was an incredible out-of-print work from the master of comedy that I so enjoy. However, I was soon to be disappointed. Apologies to those who enjoyed it, but I just found the Compleat Al to be very difficult to watch. I got a few smiles, sure, but the majority of the funny came from the music videos (which I've got on DVD) and the rest was basically filler. You could tell that this was not Al's greatest video achievement (that honor goes to UHF). Honestly, I doubt if this will ever make the jump to DVD, so if you're an ultra-hardcore Al fan and just HAVE to own everything, buy the tape off eBay. Just don't pay too much for it.",neg 3 | "The saddest thing about this ""tribute"" is that almost all the singers (including the otherwise incredibly talented Nick Cave) seem to have missed the whole point where Cohen's intensity lies: by delivering his lines in an almost tuneless poise, Cohen transmits the full extent of his poetry, his irony, his all-round humanity, laughter and tears in one.

To see some of these singer upstarts make convoluted suffering faces, launch their pathetic squeals in the patent effort to scream ""I'm a singer!,"" is a true pain. It's the same feeling many of you probably had listening in to some horrendous operatic versions of simple songs such as Lennon's ""Imagine."" Nothing, simply nothing gets close to the simplicity and directness of the original. If there is a form of art that doesn't need embellishments, it's Cohen's art. Embellishments cast it in the street looking like the tasteless make-up of sex for sale.

In this Cohen's tribute I found myself suffering and suffering through pitiful tributes and awful reinterpretations, all of them entirely lacking the original irony of the master and, if truth be told, several of these singers sounded as if they had been recruited at some asylum talent show. It's Cohen doing a tribute to them by letting them sing his material, really, not the other way around: they may have been friends, or his daughter's, he could have become very tender-hearted and in the mood for a gift. Too bad it didn't stay in the family.

Fortunately, but only at the very end, Cohen himself performed his majestic ""Tower of Song,"" but even that flower was spoiled by the totally incongruous background of the U2, all of them carrying the expression that bored kids have when they visit their poor grandpa at the nursing home.

A sad show, really, and sadder if you truly love Cohen as I do.",neg 4 | "Last night I decided to watch the prequel or shall I say the so called prequel to Carlito's Way - ""Carlito's Way: Rise to Power (2005)"" which went straight to DVD...no wonder .....it completely ...and I mean completely S%&KS !!! waist of time watching it and I think it would be a pure waist of time writing about it.... I don't understand how De Palma agreed on producing this sh#t-fest of a movie....except for only one fact that I tip my hat to... Jay Hernandez who plays the young Brigante.... reminded me how De Niro got into the shoes of Brando to portray the young Don Corleone in Godfather II ...but the difference De Niro was amazing and even got an Oscar for it !!! Jay Hernandez well he has guts for trying to be a young Pacino.... too bad for him I don't think he will be playing in film anymore and by the way after I watched this sh#$%ty movie, I sat down and watched the original Carlitos way to get the bad taste out of my mouth.",neg 5 | "I have to admit that i liked the first half of Sleepers. It looked good, the acting was even better, the story of childhood, pain and revenge was interesting and moving. A superior hollywood film. But...No one mentioned this so far (at least in the latest 20 comments), when it came to the courtroom scenes and Brat Pitt´s character followed his plan to rescue his two friends, who are rightly accused of murder, i felt cheated. This movie insulted my intelligence.

Warning spoilers!!

Why did anyone accept their false alibi, witnessed by the priest? If these two guys had been with him, why shouldn´t they tell this during the investigation? Amnesia? If you were the judge or member of the jury, would you believe it? Is it wise to give the motif of the murderers away?

I am sorry, but in the end, the story is very weak, and this angers me. This movie had great potential. 4/10",neg 6 | "I was not impressed about this film especially for the fact that I went to the cinema with my family in good faith to see a film which was certificate rated 12A here in the UK. To my dismay, this film was full of embarrassing sexual jokes. (Which is not a problem to me as an adult, but not good for watching with children). This film at times was very crude at times with fart jokes, getting hit in the groin etc... and for the most part of the film not very funny.

The premise of the film is that Calvin Sims who is a 2inch midget, gets out of jail and steals a giant sized diamond but is then forced to put it in a womens handbag. So the rest of the movie sees him passing himself off as an abandoned baby, getting into this womens house so he can get this diamond back.

Up until now, I have enjoyed most of the output from the Wayans Brothers - but this film is certainly taking the biscuit.

A Bit of good advice - wait till it comes on TV or Cable",pos 7 | -------------------------------------------------------------------------------- /tests/assets/sample.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/sample.mp3 -------------------------------------------------------------------------------- /tests/assets/sample.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/sample.mp4 -------------------------------------------------------------------------------- /tests/assets/sample.txt: -------------------------------------------------------------------------------- 1 | Hello, I'm looking to buy a new smartphone. Could you recommend a good smartphone with an excellent camera? 2 | -------------------------------------------------------------------------------- /tests/assets/sample_texts/sample1.txt: -------------------------------------------------------------------------------- 1 | Text Sample 1 2 | -------------------------------------------------------------------------------- /tests/assets/sample_texts/sample2.txt: -------------------------------------------------------------------------------- 1 | Text Sample 2 2 | -------------------------------------------------------------------------------- /tests/assets/sample_texts/sample3.txt: -------------------------------------------------------------------------------- 1 | Text Sample 3 2 | -------------------------------------------------------------------------------- /tests/assets/test/zorua.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/test/zorua.png -------------------------------------------------------------------------------- /tests/assets/test/zubat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/test/zubat.png -------------------------------------------------------------------------------- /tests/assets/test/zweilous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/test/zweilous.png -------------------------------------------------------------------------------- /tests/assets/voc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/__init__.py -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2007_000464.xml: -------------------------------------------------------------------------------- 1 | 2 | VOC2012 3 | 2007_000464.jpg 4 | 5 | The VOC2007 Database 6 | PASCAL VOC2007 7 | flickr 8 | 9 | 10 | 375 11 | 500 12 | 3 13 | 14 | 1 15 | 16 | cow 17 | Left 18 | 0 19 | 0 20 | 21 | 71 22 | 252 23 | 216 24 | 314 25 | 26 | 27 | 28 | cow 29 | Left 30 | 0 31 | 0 32 | 33 | 58 34 | 202 35 | 241 36 | 295 37 | 38 | 39 | 40 | 
-------------------------------------------------------------------------------- /tests/assets/voc/annotations/2008_000853.xml: -------------------------------------------------------------------------------- 1 | 2 | VOC2012 3 | 2008_000853.jpg 4 | 5 | The VOC2008 Database 6 | PASCAL VOC2008 7 | flickr 8 | 9 | 10 | 375 11 | 500 12 | 3 13 | 14 | 1 15 | 16 | cat 17 | Frontal 18 | 0 19 | 1 20 | 21 | 37 22 | 345 23 | 186 24 | 417 25 | 26 | 0 27 | 28 | 29 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2008_003182.xml: -------------------------------------------------------------------------------- 1 | 2 | VOC2012 3 | 2008_003182.jpg 4 | 5 | The VOC2008 Database 6 | PASCAL VOC2008 7 | flickr 8 | 9 | 10 | 333 11 | 500 12 | 3 13 | 14 | 0 15 | 16 | person 17 | Left 18 | 1 19 | 0 20 | 21 | 181 22 | 187 23 | 319 24 | 500 25 | 26 | 0 27 | 28 | 29 | horse 30 | Unspecified 31 | 1 32 | 1 33 | 34 | 71 35 | 181 36 | 333 37 | 500 38 | 39 | 0 40 | 41 | 42 | bottle 43 | Unspecified 44 | 0 45 | 1 46 | 47 | 211 48 | 312 49 | 235 50 | 358 51 | 52 | 0 53 | 54 | 55 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2008_008526.xml: -------------------------------------------------------------------------------- 1 | 2 | VOC2012 3 | 2008_008526.jpg 4 | 5 | The VOC2008 Database 6 | PASCAL VOC2008 7 | flickr 8 | 9 | 10 | 500 11 | 375 12 | 3 13 | 14 | 0 15 | 16 | sofa 17 | Frontal 18 | 1 19 | 1 20 | 21 | 1 22 | 152 23 | 500 24 | 375 25 | 26 | 0 27 | 28 | 29 | bottle 30 | Unspecified 31 | 0 32 | 1 33 | 34 | 162 35 | 232 36 | 197 37 | 338 38 | 39 | 0 40 | 41 | 42 | person 43 | Frontal 44 | 1 45 | 1 46 | 47 | 66 48 | 138 49 | 352 50 | 375 51 | 52 | 0 53 | 54 | 55 | person 56 | Frontal 57 | 1 58 | 1 59 | 60 | 234 61 | 141 62 | 500 63 | 375 64 | 65 | 0 66 | 67 | 68 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2009_004315.xml: -------------------------------------------------------------------------------- 1 | 2 | 2009_004315.jpg 3 | VOC2012 4 | 5 | bird 6 | 7 | 461 8 | 219 9 | 327 10 | 51 11 | 12 | 0 13 | 0 14 | Left 15 | 0 16 | 17 | 0 18 | 19 | 3 20 | 463 21 | 500 22 | 23 | 24 | PASCAL VOC2009 25 | The VOC2009 Database 26 | flickr 27 | 28 | 29 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2009_004382.xml: -------------------------------------------------------------------------------- 1 | 2 | 2009_004382.jpg 3 | VOC2012 4 | 5 | cat 6 | 7 | 264 8 | 50 9 | 500 10 | 1 11 | 12 | 0 13 | 0 14 | Frontal 15 | 0 16 | 17 | 0 18 | 19 | 3 20 | 500 21 | 313 22 | 23 | 24 | PASCAL VOC2009 25 | The VOC2009 Database 26 | flickr 27 | 28 | 29 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2011_000430.xml: -------------------------------------------------------------------------------- 1 | 2 | 2011_000430.jpg 3 | VOC2012 4 | 5 | dog 6 | 7 | 500 8 | 17 9 | 365 10 | 33 11 | 12 | 0 13 | 0 14 | Left 15 | 1 16 | 17 | 0 18 | 19 | 3 20 | 375 21 | 500 22 | 23 | 24 | PASCAL VOC2011 25 | The VOC2011 Database 26 | flickr 27 | 28 | 29 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2011_001610.xml: -------------------------------------------------------------------------------- 1 | 2 | 2011_001610.jpg 3 | VOC2012 4 | 5 | person 6 | 7 | 283 8 | 113 9 | 390 10 | 112 11 
| 12 | 0 13 | 1 14 | Frontal 15 | 1 16 | 17 | hand 18 | 19 | 128 20 | 231 21 | 183 22 | 283 23 | 24 | 25 | 26 | head 27 | 28 | 123 29 | 115 30 | 209 31 | 226 32 | 33 | 34 | 35 | 0 36 | 37 | 3 38 | 500 39 | 375 40 | 41 | 42 | PASCAL VOC2011 43 | The VOC2011 Database 44 | flickr 45 | 46 | 47 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2011_006412.xml: -------------------------------------------------------------------------------- 1 | 2 | 2011_006412.jpg 3 | VOC2011 4 | 5 | person 6 | 7 | 0 8 | 0 9 | 0 10 | 0 11 | 0 12 | 0 13 | 0 14 | 0 15 | 0 16 | 0 17 | 1 18 | 19 | 20 | 58 21 | 1 22 | 375 23 | 173 24 | 25 | 0 26 | Unspecified 27 | 28 | 18 29 | 257 30 | 31 | 32 | 33 | person 34 | 35 | 0 36 | 0 37 | 0 38 | 0 39 | 0 40 | 0 41 | 1 42 | 0 43 | 0 44 | 0 45 | 0 46 | 47 | 48 | 223 49 | 123 50 | 330 51 | 106 52 | 53 | 0 54 | Unspecified 55 | 56 | 177 57 | 183 58 | 59 | 60 | 61 | person 62 | 63 | 0 64 | 0 65 | 0 66 | 0 67 | 0 68 | 0 69 | 1 70 | 0 71 | 0 72 | 0 73 | 0 74 | 75 | 76 | 397 77 | 293 78 | 303 79 | 111 80 | 81 | 0 82 | Unspecified 83 | 84 | 350 85 | 178 86 | 87 | 88 | 0 89 | 90 | 3 91 | 375 92 | 500 93 | 94 | 95 | PASCAL VOC2011 96 | The VOC2011 Database 97 | flickr 98 | 99 | 100 | -------------------------------------------------------------------------------- /tests/assets/voc/annotations/2012_000690.xml: -------------------------------------------------------------------------------- 1 | 2 | 2012_000690.jpg 3 | VOC2012 4 | 5 | person 6 | 7 | 0 8 | 0 9 | 0 10 | 0 11 | 0 12 | 0 13 | 0 14 | 0 15 | 1 16 | 0 17 | 0 18 | 19 | 20 | 142 21 | 63 22 | 375 23 | 137 24 | 25 | 0 26 | Unspecified 27 | 28 | 93 29 | 221 30 | 31 | 32 | 0 33 | 34 | 3 35 | 375 36 | 500 37 | 38 | 39 | PASCAL VOC2012 40 | The VOC2012 Database 41 | flickr 42 | 43 | 44 | -------------------------------------------------------------------------------- /tests/assets/voc/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from clarifai.datasets.upload.base import ClarifaiDataLoader 5 | from clarifai.datasets.upload.features import VisualDetectionFeatures 6 | 7 | 8 | class VOCDetectionDataLoader(ClarifaiDataLoader): 9 | """PASCAL VOC 2012 Image Detection Dataset.""" 10 | 11 | voc_concepts = [ 12 | 'aeroplane', 13 | 'bicycle', 14 | 'bird', 15 | 'boat', 16 | 'bottle', 17 | 'bus', 18 | 'car', 19 | 'cat', 20 | 'chair', 21 | 'cow', 22 | 'diningtable', 23 | 'dog', 24 | 'horse', 25 | 'motorbike', 26 | 'person', 27 | 'pottedplant', 28 | 'sheep', 29 | 'sofa', 30 | 'train', 31 | 'tvmonitor', 32 | ] 33 | 34 | def __init__(self, split: str = "train"): 35 | self.split = split 36 | self.image_dir = {"train": os.path.join(os.path.dirname(__file__), "images")} 37 | self.annotations_dir = {"train": os.path.join(os.path.dirname(__file__), "annotations")} 38 | self.annotations = [] 39 | 40 | self.load_data() 41 | 42 | @property 43 | def task(self): 44 | return "visual_detection" 45 | 46 | def load_data(self): 47 | all_imgs = os.listdir(self.image_dir[self.split]) 48 | img_ids = [img_filename.split('.')[0] for img_filename in all_imgs] 49 | 50 | for _id in img_ids: 51 | annot_path = os.path.join(self.annotations_dir[self.split], _id + ".xml") 52 | root = ET.parse(annot_path).getroot() 53 | 54 | annots = [] 55 | class_names = [] 56 | for obj in root.iter('object'): 57 | concept = obj.find('name').text.strip().lower() 58 | if concept not in self.voc_concepts: 59 | continue 60 | 
xml_box = obj.find('bndbox') 61 | width = float(root.find('size').find('width').text) 62 | height = float(root.find('size').find('height').text) 63 | 64 | # Normalize the bounding box coordinates to the [0, 1] range 65 | x_min = max(min((float(xml_box.find('xmin').text) - 1) / width, 1.0), 0.0) 66 | y_min = max(min((float(xml_box.find('ymin').text) - 1) / height, 1.0), 0.0) 67 | x_max = max(min((float(xml_box.find('xmax').text) - 1) / width, 1.0), 0.0) 68 | y_max = max(min((float(xml_box.find('ymax').text) - 1) / height, 1.0), 0.0) 69 | 70 | if (x_min >= x_max) or (y_min >= y_max): 71 | continue 72 | annots.append([x_min, y_min, x_max, y_max]) 73 | class_names.append(concept) 74 | 75 | assert len(class_names) == len(annots), ( 76 | f"Num classes must match num bbox annotations\ 77 | for a single image. Found {len(class_names)} classes and {len(annots)} bboxes." 78 | ) 79 | 80 | self.annotations.append( 81 | { 82 | "image_id": _id, 83 | "image_path": os.path.join(self.image_dir[self.split], _id + ".jpg"), 84 | "class_names": class_names, 85 | "annots": annots, 86 | } 87 | ) 88 | 89 | def __getitem__(self, idx): 90 | annot = self.annotations[idx] 91 | image_path = annot["image_path"] 92 | class_names = annot["class_names"] 93 | annots = annot["annots"] 94 | return VisualDetectionFeatures(image_path, class_names, annots, id=annot["image_id"]) 95 | 96 | def __len__(self): 97 | return len(self.annotations) 98 | -------------------------------------------------------------------------------- /tests/assets/voc/images/2007_000464.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2007_000464.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2008_000853.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2008_000853.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2008_003182.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2008_003182.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2008_008526.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2008_008526.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2009_004315.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2009_004315.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2009_004382.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2009_004382.jpg --------------------------------------------------------------------------------
/tests/assets/voc/images/2011_000430.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2011_000430.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2011_001610.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2011_001610.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2011_006412.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2011_006412.jpg -------------------------------------------------------------------------------- /tests/assets/voc/images/2012_000690.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clarifai/clarifai-python/9377514c621dd62692a9abb97016442425062d18/tests/assets/voc/images/2012_000690.jpg -------------------------------------------------------------------------------- /tests/compute_orchestration/configs/example_compute_cluster_config.yaml: -------------------------------------------------------------------------------- 1 | compute_cluster: 2 | id: "test-aws-cluster" 3 | description: "My AWS cluster" 4 | cloud_provider: 5 | id: "aws" 6 | region: "us-east-1" 7 | managed_by: "clarifai" 8 | cluster_type: "dedicated" 9 | visibility: 10 | gettable: 10 11 | -------------------------------------------------------------------------------- /tests/compute_orchestration/configs/example_deployment_config.yaml: -------------------------------------------------------------------------------- 1 | deployment: 2 | id: "my_string_cat_8_thread_dep" 3 | description: "some random deployment" 4 | autoscale_config: 5 | min_replicas: 0 6 | max_replicas: 1 7 | traffic_history_seconds: 100 8 | scale_down_delay_seconds: 30 9 | scale_up_delay_seconds: 30 10 | scale_to_zero_delay_seconds: 50 11 | disable_packing: false 12 | worker: 13 | model: 14 | id: "python_string_cat" 15 | model_version: 16 | id: "b7038e059a0c4ddca29c22aec561824d" 17 | user_id: "clarifai" 18 | app_id: "Test-Model-Upload" 19 | scheduling_choice: 4 20 | nodepools: 21 | - id: "test-nodepool-6" 22 | compute_cluster: 23 | id: "test-aws-cluster" 24 | -------------------------------------------------------------------------------- /tests/compute_orchestration/configs/example_nodepool_config.yaml: -------------------------------------------------------------------------------- 1 | nodepool: 2 | id: "test-nodepool-6" 3 | compute_cluster: 4 | id: "test-aws-cluster" 5 | description: "First nodepool in AWS in a proper compute cluster" 6 | instance_types: 7 | - id: "g5.xlarge" 8 | compute_info: 9 | cpu_limit: "8" 10 | cpu_memory: "16Gi" 11 | accelerator_type: 12 | - "a10" 13 | num_accelerators: 1 14 | accelerator_memory: "40Gi" 15 | node_capacity_type: 16 | capacity_types: 17 | - 1 18 | max_instances: 1 19 | min_instances: 0 20 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Conftest for OpenAI tests.""" 2 | 3 | import sys 4 | from 
unittest import mock 5 | 6 | # Create mock modules 7 | mock_fastmcp = mock.MagicMock() 8 | mock_fastmcp.Client = mock.MagicMock() 9 | mock_fastmcp.FastMCP = mock.MagicMock() 10 | 11 | mock_mcp = mock.MagicMock() 12 | mock_mcp.types = mock.MagicMock() 13 | mock_mcp.shared = mock.MagicMock() 14 | mock_mcp.shared.exceptions = mock.MagicMock() 15 | mock_mcp.shared.exceptions.McpError = Exception 16 | 17 | # Mock the fastmcp and mcp modules 18 | sys.modules['fastmcp'] = mock_fastmcp 19 | sys.modules['mcp'] = mock_mcp 20 | sys.modules['mcp.shared'] = mock_mcp.shared 21 | sys.modules['mcp.shared.exceptions'] = mock_mcp.shared.exceptions 22 | -------------------------------------------------------------------------------- /tests/openai_model_test.py: -------------------------------------------------------------------------------- 1 | """Tests for OpenAI class. 2 | 3 | This test uses conftest.py in the same directory to set up mocks for dependencies. 4 | """ 5 | 6 | import json 7 | import os 8 | import sys 9 | 10 | # Import pytest after sys.modules updates in conftest.py 11 | import pytest 12 | 13 | # Add the base directory to the path to allow direct imports 14 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) 15 | 16 | # Now we can import our module 17 | from clarifai.runners.models.dummy_openai_model import DummyOpenAIModel 18 | from clarifai.runners.models.model_class import ModelClass 19 | from clarifai.runners.models.openai_class import OpenAIModelClass 20 | 21 | 22 | class TestOpenAIModelClass: 23 | def test_inheritance(self): 24 | """Test that OpenAIModelClass inherits from ModelClass.""" 25 | assert issubclass(OpenAIModelClass, ModelClass) 26 | 27 | def test_abstract_method(self): 28 | """Test that accessing `client` on the base class raises NotImplementedError.""" 29 | with pytest.raises(NotImplementedError): 30 | OpenAIModelClass().client 31 | 32 | def test_dummy_model(self): 33 | """Test that DummyOpenAIModel works.""" 34 | model = DummyOpenAIModel() 35 | assert isinstance(model, OpenAIModelClass) 36 | 37 | client = model.client 38 | assert client is not None 39 | assert hasattr(client, 'chat') 40 | assert hasattr(client, 'completions') 41 | 42 | def test_transport_method_non_streaming(self): 43 | """Test the openai_transport method with non-streaming.""" 44 | model = DummyOpenAIModel() 45 | model.load_model() 46 | 47 | request = { 48 | "model": "test-model", 49 | "messages": [{"role": "user", "content": "Hello world"}], 50 | "stream": False, 51 | } 52 | 53 | response = model.openai_transport(json.dumps(request)) 54 | data = json.loads(response) 55 | 56 | # Verify response structure 57 | assert "id" in data 58 | assert "created" in data 59 | assert "model" in data 60 | assert "choices" in data 61 | assert len(data["choices"]) > 0 62 | assert "message" in data["choices"][0] 63 | assert "content" in data["choices"][0]["message"] 64 | assert "Echo: Hello world" in data["choices"][0]["message"]["content"] 65 | assert "usage" in data 66 | 67 | def test_transport_method_streaming(self): 68 | """Test the openai_stream_transport method with streaming.""" 69 | model = DummyOpenAIModel() 70 | model.load_model() 71 | 72 | request = { 73 | "model": "test-model", 74 | "messages": [{"role": "user", "content": "Hello world"}], 75 | "stream": True, 76 | } 77 | 78 | response = model.openai_stream_transport(json.dumps(request)) 79 | data = [json.loads(resp) for resp in response] 80 | 81 | assert isinstance(data, list) 82 | assert len(data) > 0 83 | 84 | # Check first chunk for content 85 | first_chunk = data[0]
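# A representative streamed chunk (assumed shape, mirroring OpenAI's
# chat.completion.chunk objects) looks roughly like:
#   {"id": "...", "created": 1700000000, "model": "test-model",
#    "choices": [{"delta": {"content": "Echo: Hello world"}}]}
# The assertions below check exactly these fields.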
86 | assert "id" in first_chunk 87 | assert "created" in first_chunk 88 | assert "model" in first_chunk 89 | assert "choices" in first_chunk 90 | assert len(first_chunk["choices"]) > 0 91 | assert "delta" in first_chunk["choices"][0] 92 | assert "content" in first_chunk["choices"][0]["delta"] 93 | assert "Echo: Hello world" in first_chunk["choices"][0]["delta"]["content"] 94 | 95 | # Check remaining chunks for structure 96 | for chunk in data[1:]: 97 | assert "id" in chunk 98 | assert "created" in chunk 99 | assert "model" in chunk 100 | assert "choices" in chunk 101 | assert len(chunk["choices"]) > 0 102 | assert "delta" in chunk["choices"][0] 103 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest==7.1.2 2 | pytest-cov==5.0.0 3 | pytest-xdist==2.5.0 4 | llama-index-core==0.12.33 5 | huggingface_hub[hf_transfer]==0.27.1 6 | pypdf==3.17.4 7 | seaborn==0.13.2 8 | pycocotools>=2.0.7 9 | rich>=13.4.2 10 | -------------------------------------------------------------------------------- /tests/runners/dummy_mcp_model/1/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from fastmcp import FastMCP  # use fastmcp v2, not the built-in mcp 4 | from pydantic import Field 5 | 6 | from clarifai.runners.models.mcp_class import MCPModelClass 7 | 8 | server = FastMCP("my-first-mcp-server", instructions="", stateless_http=True) 9 | 10 | 11 | @server.tool("calculate_sum", description="Add two numbers together") 12 | def sum(a: Any = Field(description="first number"), b: Any = Field(description="second number")): 13 | return float(a) + float(b) 14 | 15 | 16 | # Static resource 17 | @server.resource("config://version") 18 | def get_version(): 19 | return "2.0.1" 20 | 21 | 22 | @server.prompt() 23 | def summarize_request(text: str) -> str: 24 | """Generate a prompt asking for a summary.""" 25 | return f"Please summarize the following text:\n\n{text}" 26 | 27 | 28 | class MyModelClass(MCPModelClass): 29 | def get_server(self) -> FastMCP: 30 | return server 31 | -------------------------------------------------------------------------------- /tests/runners/dummy_mcp_model/config.yaml: -------------------------------------------------------------------------------- 1 | # This is the sample config file for an MCP model.
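# Editor's sketch: once a model like the one in 1/model.py above is served, its
# tool can be exercised with the fastmcp v2 client roughly like this (the URL
# and transport wiring here are assumptions, not part of this repo):
#   import asyncio
#   from fastmcp import Client
#   async def main():
#       async with Client("http://localhost:8000/mcp") as client:
#           result = await client.call_tool("calculate_sum", {"a": 1, "b": 2})
#           print(result)
#   asyncio.run(main())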
2 | 3 | model: 4 | id: "dummy-runner-model" 5 | user_id: "user_id" 6 | app_id: "app_id" 7 | model_type_id: "mcp" 8 | 9 | build_info: 10 | python_version: "3.12" 11 | 12 | inference_compute_info: 13 | cpu_limit: "1" 14 | cpu_memory: "1Gi" 15 | num_accelerators: 0 16 | -------------------------------------------------------------------------------- /tests/runners/dummy_mcp_model/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | requests 3 | clarifai 4 | huggingface_hub 5 | fastmcp>=2.3.4 6 | -------------------------------------------------------------------------------- /tests/runners/dummy_runner_models/1/model.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | from clarifai.runners.models.model_class import ModelClass 4 | from clarifai.runners.utils.data_types import Text 5 | 6 | 7 | class MyModel(ModelClass): 8 | """A custom runner that adds "Hello World" to the end of the text.""" 9 | 10 | def load_model(self): 11 | """Load the model here.""" 12 | 13 | @ModelClass.method 14 | def predict(self, text1: Text = "") -> Text: 15 | """This is the method that will be called when the runner is run. It takes in an input and 16 | returns an output. 17 | """ 18 | 19 | output_text = text1.text + "Hello World" 20 | 21 | return Text(output_text) 22 | 23 | @ModelClass.method 24 | def generate(self, text1: Text = Text("")) -> Iterator[Text]: 25 | """Example generator that yields a stream of outputs for a single input.""" 26 | 27 | for i in range(10):  # fake generation by yielding 10 outputs. 28 | output_text = text1.text + f"Generate Hello World {i}" 29 | yield Text(output_text) 30 | 31 | @ModelClass.method 32 | def stream(self, input_iterator: Iterator[Text]) -> Iterator[Text]: 33 | """Example stream method that echoes each streamed input back.""" 34 | 35 | for i, input in enumerate(input_iterator): 36 | output_text = input.text + f"Stream Hello World {i}" 37 | yield Text(output_text) 38 | 39 | def test(self): 40 | res = self.predict(Text("test")) 41 | assert res.text == "testHello World" 42 | 43 | res = self.generate(Text("test")) 44 | for i, r in enumerate(res): 45 | assert r.text == f"testGenerate Hello World {i}" 46 | 47 | res = self.stream(iter([Text("test")] * 5)) 48 | for i, r in enumerate(res): 49 | assert r.text == f"testStream Hello World {i}" 50 | -------------------------------------------------------------------------------- /tests/runners/dummy_runner_models/config.yaml: -------------------------------------------------------------------------------- 1 | # This is the sample config file for the dummy runner model.
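# Editor's note: the runner in 1/model.py above ships its own self-check; a
# local smoke run is roughly MyModel().test(), assuming the ModelClass.method
# decorator keeps the methods directly callable (an assumption here, not
# something this config asserts).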
2 | 3 | model: 4 | id: "dummy-runner-model" 5 | user_id: "user_id" 6 | app_id: "app_id" 7 | model_type_id: "multimodal-to-text" 8 | 9 | build_info: 10 | python_version: "3.12" 11 | 12 | inference_compute_info: 13 | cpu_limit: "1" 14 | cpu_memory: "1Gi" 15 | num_accelerators: 0 16 | 17 | 18 | checkpoints: 19 | type: "huggingface" 20 | repo_id: "timm/mobilenetv3_small_100.lamb_in1k" 21 | -------------------------------------------------------------------------------- /tests/runners/dummy_runner_models/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | requests 3 | clarifai 4 | huggingface_hub 5 | -------------------------------------------------------------------------------- /tests/runners/hf_mbart_model/1/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Iterator 3 | 4 | import torch 5 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 6 | 7 | from clarifai.runners.models.model_class import ModelClass 8 | 9 | 10 | class MyModel(ModelClass): 11 | """A custom runner that loads an mbart model and generates text with Hugging Face transformers.""" 12 | 13 | def load_model(self): 14 | """Load the model here""" 15 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints") 17 | 18 | # If a checkpoints section is present in config.yaml, the checkpoints are downloaded to this path at model upload time. 19 | self.tokenizer = AutoTokenizer.from_pretrained(checkpoints) 20 | self.model = AutoModelForSeq2SeqLM.from_pretrained( 21 | checkpoints, torch_dtype="auto", device_map=self.device 22 | ) 23 | 24 | @ModelClass.method 25 | def predict(self, prompt: str = "") -> str: 26 | """This is the method that will be called when the runner is run. It takes in an input and 27 | returns an output.
28 | """ 29 | inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) 30 | outputs = self.model.generate(inputs) 31 | output_text = self.tokenizer.decode(outputs[0]) 32 | return output_text 33 | 34 | def generate(self, prompt: str = "") -> Iterator[str]: 35 | """Example yielding a whole batch of streamed stuff back.""" 36 | raise NotImplementedError("This method is not implemented yet.") 37 | 38 | def stream(self, input_iterator: Iterator[str]) -> Iterator[str]: 39 | """Example yielding a whole batch of streamed stuff back.""" 40 | raise NotImplementedError("This method is not implemented yet.") 41 | -------------------------------------------------------------------------------- /tests/runners/hf_mbart_model/config.yaml: -------------------------------------------------------------------------------- 1 | # Config file for the VLLM runner 2 | 3 | model: 4 | id: "hf-mbart-model" 5 | user_id: "user_id" 6 | app_id: "app_id" 7 | model_type_id: "text-to-text" 8 | 9 | build_info: 10 | python_version: "3.12" 11 | 12 | inference_compute_info: 13 | cpu_limit: "500m" 14 | cpu_memory: "500Mi" 15 | num_accelerators: 0 16 | 17 | checkpoints: 18 | type: "huggingface" 19 | repo_id: "sshleifer/tiny-mbart" 20 | when: "build" 21 | -------------------------------------------------------------------------------- /tests/runners/hf_mbart_model/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.6.0 2 | blobfile 3 | clarifai 4 | requests 5 | sentencepiece==0.2.0 6 | tiktoken==0.9.0 7 | tokenizers==0.21.1 8 | torch==2.6.0 9 | transformers==4.51.3 10 | -------------------------------------------------------------------------------- /tests/runners/test_data_handler.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from clarifai_grpc.grpc.api import resources_pb2 5 | from PIL import Image 6 | 7 | from clarifai.runners.utils.data_utils import image_to_bytes 8 | 9 | IMAGE = np.ones([50, 50, 3], dtype="uint8") 10 | AUDIO = b"000" 11 | TEXT = "ABC" 12 | CONCEPTS = dict(a=0.0, b=0.2, c=1.0) 13 | EMBEDDINGS = [0.1, 1.1, 2.0] 14 | 15 | INPUT_DATA_PROTO = resources_pb2.Input( 16 | data=resources_pb2.Data( 17 | image=resources_pb2.Image(base64=image_to_bytes(Image.fromarray(IMAGE))), 18 | text=resources_pb2.Text(raw=TEXT), 19 | audio=resources_pb2.Audio(base64=AUDIO), 20 | ) 21 | ) 22 | 23 | 24 | class TestDataHandler(unittest.TestCase): 25 | def test_input_proto_to_python(self): 26 | pass 27 | 28 | def test_output_python_to_proto(self): 29 | pass 30 | -------------------------------------------------------------------------------- /tests/runners/test_download_checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | 5 | import pytest 6 | 7 | from clarifai.runners.models.model_builder import ModelBuilder 8 | from clarifai.runners.utils.loader import HuggingFaceLoader 9 | 10 | MODEL_ID = "timm/mobilenetv3_small_100.lamb_in1k" 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def checkpoint_dir(): 15 | # Create a temporary directory for the test checkpoints 16 | temp_dir = os.path.join(tempfile.gettempdir(), MODEL_ID[5:]) 17 | if not os.path.exists(temp_dir): 18 | os.makedirs(temp_dir) 19 | yield temp_dir # Provide the directory to the tests 20 | # Cleanup: remove the directory after all tests are complete 21 | shutil.rmtree(temp_dir, ignore_errors=True) 22 | 23 | 
24 | # Pytest fixture to delete the checkpoints in dummy runner models folder after tests complete 25 | @pytest.fixture(scope="function") 26 | def dummy_runner_models_dir(): 27 | model_folder_path = os.path.join(os.path.dirname(__file__), "dummy_runner_models") 28 | checkpoints_path = os.path.join(model_folder_path, "1", "checkpoints") 29 | yield checkpoints_path 30 | # Cleanup the checkpoints folder after the test 31 | if os.path.exists(checkpoints_path): 32 | shutil.rmtree(checkpoints_path) 33 | 34 | 35 | @pytest.fixture(scope="function", autouse=True) 36 | def override_environment_variables(): 37 | # Backup the existing environment variable value 38 | original_clarifai_pat = os.environ.get("CLARIFAI_PAT") 39 | if "CLARIFAI_PAT" in os.environ: 40 | del os.environ["CLARIFAI_PAT"]  # Temporarily unset the variable for the tests 41 | yield 42 | # Restore the original environment variable value after tests 43 | if original_clarifai_pat: 44 | os.environ["CLARIFAI_PAT"] = original_clarifai_pat 45 | 46 | 47 | def test_loader_download_checkpoints(checkpoint_dir): 48 | loader = HuggingFaceLoader(repo_id=MODEL_ID) 49 | loader.download_checkpoints(checkpoint_path=checkpoint_dir) 50 | assert len(os.listdir(checkpoint_dir)) == 4 51 | 52 | 53 | def test_validate_download(checkpoint_dir): 54 | loader = HuggingFaceLoader(repo_id=MODEL_ID) 55 | assert ( 56 | loader.validate_download( 57 | checkpoint_path=checkpoint_dir, allowed_file_patterns=None, ignore_file_patterns=None 58 | ) 59 | is True 60 | ) 61 | 62 | 63 | def test_download_checkpoints(dummy_runner_models_dir): 64 | # This model doesn't have a when field in its config.yaml, so the runtime stage applies. 65 | model_folder_path = os.path.join(os.path.dirname(__file__), "dummy_runner_models") 66 | model_builder = ModelBuilder(model_folder_path, download_validation_only=True) 67 | # defaults to runtime stage which matches config.yaml not having a when field. 68 | # get whatever stage is in config.yaml to force download now 69 | # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses 70 | _, _, _, when, _, _ = model_builder._validate_config_checkpoints() 71 | checkpoint_dir = model_builder.download_checkpoints( 72 | stage=when, checkpoint_path_override=model_builder.checkpoint_path 73 | ) 74 | assert checkpoint_dir == model_builder.checkpoint_path 75 | 76 | # This model has when: "build" in its config.yaml, so the build stage applies. 77 | model_folder_path = os.path.join(os.path.dirname(__file__), "hf_mbart_model") 78 | model_builder = ModelBuilder(model_folder_path, download_validation_only=True) 79 | # the when: "build" field in config.yaml selects the build stage.
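# For reference, the two cases differ only in the checkpoints section of each
# config.yaml (fragments shown earlier in this dump): the dummy runner config
# lists type/repo_id with no when field, while hf_mbart_model adds
# when: "build" to pin checkpoint downloads to the build stage.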
80 | # get whatever stage is in config.yaml to force download now 81 | # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses 82 | _, _, _, when, _, _ = model_builder._validate_config_checkpoints() 83 | checkpoint_dir = model_builder.download_checkpoints( 84 | stage=when, checkpoint_path_override=model_builder.checkpoint_path 85 | ) 86 | assert checkpoint_dir == os.path.join( 87 | os.path.dirname(__file__), "hf_mbart_model", "1", "checkpoints" 88 | ) 89 | -------------------------------------------------------------------------------- /tests/runners/test_model_classes.py: -------------------------------------------------------------------------------- 1 | """Test cases for MCPModelClass and OpenAIModelClass.""" 2 | 3 | import json 4 | 5 | import pytest 6 | 7 | from clarifai.runners.models.dummy_openai_model import DummyOpenAIModel 8 | from clarifai.runners.models.mcp_class import MCPModelClass 9 | from clarifai.runners.models.openai_class import OpenAIModelClass 10 | 11 | 12 | class TestModelClasses: 13 | """Tests for model classes.""" 14 | 15 | def test_mcp_model_initialization(self): 16 | """Test that MCPModelClass requires subclass implementation.""" 17 | # Test that subclass must implement get_server() 18 | with pytest.raises(NotImplementedError): 19 | MCPModelClass().get_server() 20 | 21 | def test_openai_model_initialization(self): 22 | """Test that OpenAIModelClass can be initialized.""" 23 | model = DummyOpenAIModel() 24 | assert isinstance(model, OpenAIModelClass) 25 | 26 | # Test that subclass must have `client` attribute 27 | with pytest.raises(NotImplementedError): 28 | OpenAIModelClass().client 29 | 30 | # Test that client has required attributes 31 | client = model.client 32 | assert hasattr(client, 'chat') 33 | assert hasattr(client, 'completions') 34 | 35 | def test_openai_transport_non_streaming(self): 36 | """Test OpenAI transport method with non-streaming request.""" 37 | model = DummyOpenAIModel() 38 | model.load_model() 39 | 40 | # Create a simple chat request 41 | request = { 42 | "model": "gpt-3.5-turbo", 43 | "messages": [ 44 | {"role": "system", "content": "You are a helpful assistant."}, 45 | {"role": "user", "content": "Hello, world!"}, 46 | ], 47 | "stream": False, 48 | } 49 | 50 | # Call the transport method 51 | response_str = model.openai_transport(json.dumps(request)) 52 | response = json.loads(response_str) 53 | 54 | # Verify response structure 55 | assert "id" in response 56 | assert "created" in response 57 | assert "model" in response 58 | assert "choices" in response 59 | assert len(response["choices"]) > 0 60 | assert "message" in response["choices"][0] 61 | assert "content" in response["choices"][0]["message"] 62 | assert "Echo: Hello, world!" 
in response["choices"][0]["message"]["content"] 63 | assert "usage" in response 64 | 65 | def test_openai_transport_streaming(self): 66 | """Test OpenAI transport method with streaming request.""" 67 | model = DummyOpenAIModel() 68 | model.load_model() 69 | 70 | # Create a simple chat request with streaming 71 | request = { 72 | "model": "gpt-3.5-turbo", 73 | "messages": [ 74 | {"role": "system", "content": "You are a helpful assistant."}, 75 | {"role": "user", "content": "Hello, world!"}, 76 | ], 77 | "stream": True, 78 | } 79 | 80 | # Call the transport method 81 | response = model.openai_stream_transport(json.dumps(request)) 82 | response_chunks = [json.loads(resp) for resp in response] 83 | 84 | assert isinstance(response_chunks, list) 85 | assert len(response_chunks) > 0 86 | 87 | # Check first chunk for content 88 | first_chunk = response_chunks[0] 89 | assert "id" in first_chunk 90 | assert "created" in first_chunk 91 | assert "model" in first_chunk 92 | assert "choices" in first_chunk 93 | assert len(first_chunk["choices"]) > 0 94 | assert "delta" in first_chunk["choices"][0] 95 | assert "content" in first_chunk["choices"][0]["delta"] 96 | assert "Echo: Hello, world!" in first_chunk["choices"][0]["delta"]["content"] 97 | 98 | def test_custom_method(self): 99 | """Test custom method on the DummyOpenAIModel.""" 100 | model = DummyOpenAIModel() 101 | result = model.test_method("test input") 102 | assert result == "Test: test input" 103 | -------------------------------------------------------------------------------- /tests/runners/test_num_threads_config.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | import pytest 5 | import yaml 6 | 7 | from clarifai.runners.models.model_builder import ModelBuilder 8 | 9 | 10 | @pytest.fixture 11 | def my_tmp_path(tmp_path): 12 | return tmp_path 13 | 14 | 15 | @pytest.mark.parametrize("num_threads", [-1, 0, 3, 1.5, "a", None]) 16 | def test_num_threads(my_tmp_path, num_threads, monkeypatch): 17 | """ 18 | Clone dummy_runner_models with different num_threads settings for testing 19 | """ 20 | tests_dir = Path(__file__).parent.resolve() 21 | original_dummy_path = tests_dir / "dummy_runner_models" 22 | if not original_dummy_path.exists(): 23 | # Adjust or raise an error if you cannot locate the dummy_runner_models folder 24 | raise FileNotFoundError( 25 | f"Could not find dummy_runner_models at {original_dummy_path}. " 26 | "Adjust path or ensure it exists." 
27 | ) 28 | 29 | # Copy the entire folder to tmp_path 30 | target_folder = my_tmp_path / "dummy_runner_models" 31 | shutil.copytree(original_dummy_path, target_folder) 32 | 33 | # Update the copied config.yaml to set the num_threads value under test 34 | config_yaml_path = target_folder / "config.yaml" 35 | with config_yaml_path.open("r") as f: 36 | config = yaml.safe_load(f) 37 | 38 | monkeypatch.delenv("CLARIFAI_NUM_THREADS", raising=False) 39 | if num_threads is not None: 40 | config["num_threads"] = num_threads 41 | 42 | # Rewrite config.yaml 43 | with config_yaml_path.open("w") as f: 44 | yaml.dump(config, f, sort_keys=False) 45 | 46 | # no num_threads 47 | if num_threads is None: 48 | # default is 16 49 | builder = ModelBuilder(target_folder, validate_api_ids=False) 50 | assert builder.config.get("num_threads") == 16 51 | # set by env var if unset in config.yaml 52 | monkeypatch.setenv("CLARIFAI_NUM_THREADS", "4") 53 | builder = ModelBuilder(target_folder, validate_api_ids=False) 54 | assert builder.config.get("num_threads") == 4 55 | 56 | elif num_threads == 3: 57 | builder = ModelBuilder(target_folder, validate_api_ids=False) 58 | assert builder.config.get("num_threads") == num_threads 59 | 60 | # a value set in config.yaml takes precedence over the env var 61 | monkeypatch.setenv("CLARIFAI_NUM_THREADS", "14") 62 | builder = ModelBuilder(target_folder, validate_api_ids=False) 63 | assert builder.config.get("num_threads") == num_threads 64 | 65 | elif num_threads in [-1, 0, "a", 1.5]: 66 | with pytest.raises(AssertionError): 67 | builder = ModelBuilder(target_folder, validate_api_ids=False) 68 | -------------------------------------------------------------------------------- /tests/runners/test_url_fetcher.py: -------------------------------------------------------------------------------- 1 | from clarifai_grpc.grpc.api import resources_pb2, service_pb2 2 | 3 | from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded 4 | 5 | image_url = "https://samples.clarifai.com/metro-north.jpg" 6 | audio_url = "https://samples.clarifai.com/GoodMorning.wav" 7 | text_url = "https://samples.clarifai.com/negative_sentence_12.txt" 8 | 9 | 10 | def test_url_fetcher(): 11 | request = service_pb2.PostModelOutputsRequest( 12 | model_id="model_id", 13 | version_id="version_id", 14 | user_app_id=resources_pb2.UserAppIDSet(user_id="user_id", app_id="app_id"), 15 | inputs=[ 16 | resources_pb2.Input( 17 | data=resources_pb2.Data( 18 | image=resources_pb2.Image(url=image_url), 19 | text=resources_pb2.Text(url=text_url), 20 | audio=resources_pb2.Audio(url=audio_url), 21 | ), 22 | ), 23 | ], 24 | ) 25 | ensure_urls_downloaded(request) 26 | for input in request.inputs: 27 | assert input.data.image.base64 and len(input.data.image.base64) == 70911, ( 28 | f"Expected length of image base64 to be 70911, but got {len(input.data.image.base64)}" 29 | ) 30 | assert input.data.audio.base64 and len(input.data.audio.base64) == 200406, ( 31 | f"Expected length of audio base64 to be 200406, but got {len(input.data.audio.base64)}" 32 | ) 33 | assert input.data.text.raw and len(input.data.text.raw) == 35, ( 34 | f"Expected length of text raw to be 35, but got {len(input.data.text.raw)}" 35 | ) 36 | -------------------------------------------------------------------------------- /tests/test_auth.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | from unittest.mock import Mock 3 | 4 | import pytest 5 | 6 | from clarifai.client.auth.helper import
ClarifaiAuthHelper, clear_cache 7 | 8 | 9 | @pytest.fixture(autouse=True) 10 | def clear_caches(): 11 | clear_cache() 12 | 13 | 14 | def test_ui_default_url(): 15 | default = ClarifaiAuthHelper("clarifai", "main", "fake_pat") 16 | assert default.ui == "https://clarifai.com" 17 | assert default.pat == "fake_pat" 18 | 19 | 20 | @pytest.mark.parametrize( 21 | ("input_url", "expected_url"), 22 | ( 23 | ("http://localhost:3002", "http://localhost:3002"), 24 | ("https://localhost:3002", "https://localhost:3002"), 25 | ("https://clarifai.com", "https://clarifai.com"), 26 | ), 27 | ) 28 | def test_ui_urls(input_url, expected_url): 29 | helper = ClarifaiAuthHelper("clarifai", "main", "fake_pat", ui=input_url) 30 | assert helper.ui == expected_url 31 | 32 | 33 | def test_passing_no_schema_url_use_https_when_server_is_running(): 34 | def raise_exception(): 35 | return Mock() 36 | 37 | with mock.patch('urllib.request.urlopen', new_callable=raise_exception): 38 | helper = ClarifaiAuthHelper("clarifai", "main", "fake_pat", ui="server") 39 | assert helper.ui == "https://server" 40 | 41 | 42 | def test_passing_no_schema_url_show_error_when_not_server_running(): 43 | def raise_exception(): 44 | return Mock(side_effect=Exception("http_exception")) 45 | 46 | with mock.patch('urllib.request.urlopen', new_callable=raise_exception): 47 | with pytest.raises( 48 | Exception, 49 | match="Could not get a valid response from url: localhost:3002, is the API running there?", 50 | ): 51 | ClarifaiAuthHelper("clarifai", "main", "fake_pat", ui="localhost:3002") 52 | 53 | 54 | def test_passing_no_schema_url_detect_http_when_SSL_in_error(): 55 | def raise_exception(): 56 | return Mock(side_effect=Exception("Has SSL in error")) 57 | 58 | with mock.patch('urllib.request.urlopen', new_callable=raise_exception): 59 | helper = ClarifaiAuthHelper("clarifai", "main", "fake_pat", ui="localhost:3002") 60 | assert helper.ui == "http://localhost:3002" 61 | 62 | 63 | def test_passing_no_schema_url_require_port(): 64 | def raise_exception(): 65 | return Mock(side_effect=Exception("Has SSL in error")) 66 | 67 | with mock.patch('urllib.request.urlopen', new_callable=raise_exception): 68 | with pytest.raises( 69 | Exception, match="When providing an insecure url it must have both host:port format" 70 | ): 71 | ClarifaiAuthHelper("clarifai", "main", "fake_pat", ui="localhost") 72 | 73 | 74 | def test_exception_empty_user(): 75 | ClarifaiAuthHelper("", "main", "fake_pat", validate=False) 76 | with pytest.raises( 77 | Exception, 78 | match="Need 'user_id' to not be empty in the query params or user CLARIFAI_USER_ID env var", 79 | ): 80 | ClarifaiAuthHelper("", "main", "fake_pat") 81 | 82 | 83 | def test_exception_empty_pat(): 84 | ClarifaiAuthHelper("clarifai", "main", "", validate=False) 85 | with pytest.raises( 86 | Exception, 87 | match="Need 'pat' or 'token' in the query params or use one of the CLARIFAI_PAT or CLARIFAI_SESSION_TOKEN env vars", 88 | ): 89 | ClarifaiAuthHelper("clarifai", "main", "") 90 | 91 | 92 | def test_exception_path_root_cert(): 93 | ClarifaiAuthHelper( 94 | "clarifai", "main", "fake_pat", root_certificates_path='fake_file.crt', validate=False 95 | ) 96 | with pytest.raises(Exception, match="Root certificates path fake_file.crt does not exist"): 97 | ClarifaiAuthHelper("clarifai", "main", "fake_pat", root_certificates_path='fake_file.crt') 98 | -------------------------------------------------------------------------------- /tests/test_misc.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | 5 | from clarifai.utils.logging import ( 6 | JsonFormatter, 7 | TerminalFormatter, 8 | _get_library_name, 9 | get_logger, 10 | set_logger_context, 11 | ) 12 | 13 | 14 | def test_get_logger(): 15 | logger = get_logger("DEBUG", "test_logger") 16 | assert logger.level == logging.DEBUG 17 | assert logger.name == "test_logger" 18 | assert isinstance(logger.handlers[0], logging.StreamHandler) 19 | assert isinstance(logger.handlers[0].formatter, TerminalFormatter) 20 | 21 | 22 | def test_get_logger_defaults(): 23 | logger = get_logger() 24 | assert logger.level == logging.NOTSET 25 | assert logger.name == _get_library_name() 26 | assert isinstance(logger.handlers[0], logging.StreamHandler) 27 | assert isinstance(logger.handlers[0].formatter, TerminalFormatter) 28 | 29 | 30 | def test_get_json_logger_defaults(monkeypatch): 31 | # with env setting of ENABLE_JSON_LOGGER to true 32 | # we should get a json logger. 33 | monkeypatch.setenv("ENABLE_JSON_LOGGER", "true") 34 | logger = get_logger() 35 | assert logger.level == logging.NOTSET 36 | assert logger.name == _get_library_name() 37 | assert isinstance(logger.handlers[0], logging.StreamHandler) 38 | assert isinstance(logger.handlers[0].formatter, JsonFormatter) 39 | 40 | 41 | def test_json_logger(): 42 | filename = "testy.py" 43 | msg = "testy" 44 | lineno = 1 45 | sinfo = "testf2" 46 | r = logging.LogRecord( 47 | name="testy", 48 | level=logging.INFO, 49 | pathname=filename, 50 | lineno=lineno, 51 | msg=msg, 52 | args=(), 53 | exc_info=None, 54 | func="testf", 55 | sinfo=sinfo, 56 | ) 57 | jf = JsonFormatter() 58 | # format the record as a json line. 59 | json_line = jf.format(r) 60 | result = json.loads(json_line) 61 | # parse timestamp of format "2024-09-24T22:06:49.573038Z" in @timestamp field. 
62 | assert result["@timestamp"] is not None 63 | ts = result["@timestamp"] 64 | # assert the ts was within 10 seconds of now (in UTC time) 65 | assert abs( 66 | datetime.datetime.utcnow() - datetime.datetime.strptime(ts, "%Y-%m-%dT%H:%M:%S.%fZ") 67 | ) < datetime.timedelta(seconds=10) 68 | assert result['filename'] == filename 69 | assert result['msg'] == msg 70 | assert result['lineno'] == lineno 71 | assert result['stack_info'] == sinfo 72 | assert result['level'] == "info" 73 | assert 'req_id' not in result 74 | 75 | req_id = "test_req_id" 76 | set_logger_context(req_id=req_id) 77 | json_line = jf.format(r) 78 | result = json.loads(json_line) 79 | assert abs( 80 | datetime.datetime.utcnow() - datetime.datetime.strptime(ts, "%Y-%m-%dT%H:%M:%S.%fZ") 81 | ) < datetime.timedelta(seconds=10) 82 | assert result['filename'] == filename 83 | assert result['msg'] == msg 84 | assert result['lineno'] == lineno 85 | assert result['stack_info'] == sinfo 86 | assert result['level'] == "info" 87 | assert result['req_id'] == req_id 88 | -------------------------------------------------------------------------------- /tests/test_rag.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import pytest 5 | 6 | from clarifai.client import User 7 | from clarifai.rag import RAG 8 | from clarifai.urls.helper import ClarifaiUrlHelper 9 | 10 | CREATE_APP_USER_ID = os.environ["CLARIFAI_USER_ID"] 11 | 12 | TEXT_FILE_PATH = os.path.dirname(__file__) + "/assets/sample.txt" 13 | PDF_URL = "https://samples.clarifai.com/test_doc.pdf" 14 | 15 | CLARIFAI_API_BASE = os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com") 16 | 17 | 18 | def client(): 19 | return User(user_id=CREATE_APP_USER_ID, base_url=CLARIFAI_API_BASE) 20 | 21 | 22 | @pytest.mark.requires_secrets 23 | class TestRAG: 24 | @classmethod 25 | def setup_class(self): 26 | self.rag = RAG.setup(user_id=CREATE_APP_USER_ID, base_url=CLARIFAI_API_BASE) 27 | wf = self.rag._prompt_workflow 28 | self.workflow_url = ClarifaiUrlHelper().clarifai_url( 29 | wf.user_id, wf.app_id, "workflows", wf.id 30 | ) 31 | 32 | def test_setup_correct(self): 33 | assert len(self.rag._prompt_workflow.workflow_info.nodes) == 2 34 | 35 | def test_from_existing_workflow(self): 36 | agent = RAG(workflow_url=self.workflow_url) 37 | assert agent._app.id == self.rag._app.id 38 | 39 | def test_predict_client_manage_state(self): 40 | messages = [{"role": "human", "content": "What is 1 + 1?"}] 41 | new_messages = self.rag.chat(messages, client_manage_state=True) 42 | assert len(new_messages) == 2 43 | 44 | @pytest.mark.skip(reason="Not yet supported. 
Work in progress.") 45 | def test_predict_server_manage_state(self): 46 | messages = [{"role": "human", "content": "What is 1 + 1?"}] 47 | new_messages = self.rag.chat(messages) 48 | assert len(new_messages) == 1 49 | 50 | def test_upload_docs_filepath(self, caplog): 51 | with caplog.at_level(logging.INFO): 52 | self.rag.upload(file_path=TEXT_FILE_PATH) 53 | assert "SUCCESS" in caplog.text 54 | 55 | def test_upload_docs_from_url(self, caplog): 56 | with caplog.at_level(logging.INFO): 57 | self.rag.upload(url=PDF_URL) 58 | assert "SUCCESS" in caplog.text 59 | 60 | @classmethod 61 | def teardown_class(self): 62 | client().delete_app(self.rag._app.id) 63 | -------------------------------------------------------------------------------- /tests/test_stub.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import grpc 4 | import pytest 5 | from clarifai_grpc.grpc.api import service_pb2 6 | from clarifai_grpc.grpc.api.status import status_code_pb2 7 | 8 | from clarifai.client.auth.helper import ClarifaiAuthHelper, clear_cache 9 | from clarifai.client.auth.stub import AuthorizedStub, RetryStub 10 | 11 | 12 | class MockRpcError(grpc.RpcError): 13 | pass 14 | 15 | 16 | @pytest.fixture(autouse=True) 17 | def clear_caches(): 18 | clear_cache() 19 | 20 | 21 | def test_auth_unary_unary(): 22 | auth = ClarifaiAuthHelper("clarifai", "main", "fake_pat") 23 | stub = AuthorizedStub(auth) 24 | with mock.patch.object(stub.stub, 'ListInputs', spec=stub.stub.ListInputs) as mock_f: 25 | req = service_pb2.ListInputsRequest() 26 | req.user_app_id.app_id = 'test_auth_unary_unary' 27 | stub.ListInputs(req) 28 | mock_f.assert_called_with(req, metadata=auth.metadata) 29 | 30 | 31 | def test_auth_unary_unary_future(): 32 | auth = ClarifaiAuthHelper("clarifai", "main", "fake_pat") 33 | stub = AuthorizedStub(auth) 34 | with mock.patch.object(stub.stub, 'ListInputs', spec=stub.stub.ListInputs) as mock_f: 35 | req = service_pb2.ListInputsRequest() 36 | req.user_app_id.app_id = 'test_auth_unary_unary_future' 37 | stub.ListInputs.future(req) 38 | mock_f.future.assert_called_with(req, metadata=auth.metadata) 39 | 40 | 41 | def test_auth_unary_stream(): 42 | auth = ClarifaiAuthHelper("clarifai", "main", "fake_pat") 43 | stub = AuthorizedStub(auth) 44 | with mock.patch.object(stub.stub, 'StreamInputs', spec=stub.stub.StreamInputs) as mock_f: 45 | req = service_pb2.StreamInputsRequest() 46 | req.user_app_id.app_id = 'test_auth_unary_stream' 47 | stub.StreamInputs(req) 48 | mock_f.assert_called_with(req, metadata=auth.metadata) 49 | 50 | 51 | def test_retry_unary_unary(): 52 | max_attempts = 5 53 | auth = ClarifaiAuthHelper("clarifai", "main", "fake_pat") 54 | stub = RetryStub(AuthorizedStub(auth), max_attempts=max_attempts, backoff_time=0.0001) 55 | retry_response = service_pb2.MultiInputResponse() 56 | retry_response.status.code = status_code_pb2.CONN_THROTTLED 57 | success_response = service_pb2.MultiInputResponse() 58 | success_response.status.code = status_code_pb2.SUCCESS 59 | for nfailures in range(0, max_attempts + 1): 60 | mock_resps = [retry_response] * nfailures + [success_response] 61 | with mock.patch.object( 62 | stub.stub, 'ListInputs', spec=stub.stub.stub.ListInputs, side_effect=mock_resps 63 | ) as mock_f: 64 | req = service_pb2.ListInputsRequest() 65 | req.user_app_id.app_id = 'test_retry_unary_unary' 66 | res = stub.ListInputs(req) 67 | assert mock_f.call_count == min(max_attempts, len(mock_resps)) 68 | if nfailures < max_attempts: 69 | 
assert res is success_response 70 | else: 71 | assert res is retry_response 72 | 73 | 74 | def test_retry_grpcconn_unary_unary(): 75 | max_attempts = 5 76 | auth = ClarifaiAuthHelper("clarifai", "main", "fake_pat") 77 | stub = RetryStub(AuthorizedStub(auth), max_attempts=max_attempts, backoff_time=0.0001) 78 | retry_response = service_pb2.MultiInputResponse() 79 | retry_response.status.code = status_code_pb2.CONN_THROTTLED 80 | success_response = service_pb2.MultiInputResponse() 81 | success_response.status.code = status_code_pb2.SUCCESS 82 | error = MockRpcError() 83 | error.code = lambda: grpc.StatusCode.UNAVAILABLE 84 | for nfailures in range(0, max_attempts + 1): 85 | mock_resps = [error] * nfailures + [success_response] 86 | with mock.patch.object( 87 | stub.stub, 'ListInputs', spec=stub.stub.stub.ListInputs, side_effect=mock_resps 88 | ) as mock_f: 89 | req = service_pb2.ListInputsRequest() 90 | req.user_app_id.app_id = 'test_retry_unary_unary' 91 | try: 92 | res = stub.ListInputs(req) 93 | except Exception as e: 94 | res = e 95 | assert mock_f.call_count == min(max_attempts, len(mock_resps)) 96 | if nfailures < max_attempts: 97 | assert res is success_response 98 | else: 99 | assert res is error 100 | -------------------------------------------------------------------------------- /tests/workflow/fixtures/general.yml: -------------------------------------------------------------------------------- 1 | workflow: 2 | id: General 3 | nodes: 4 | - id: general-v1.5-concept 5 | model: 6 | model_id: general-image-recognition 7 | model_version_id: aa7f35c01e0642fda5cf400f543e7c40 8 | - id: general-v1.5-embed 9 | model: 10 | model_id: general-image-embedding 11 | model_version_id: bb186755eda04f9cbb6fe32e816be104 12 | - id: general-v1.5-cluster 13 | model: 14 | model_id: general-clusterering 15 | model_version_id: cc2074cff6dc4c02b6f4e1b8606dcb54 16 | node_inputs: 17 | - node_id: general-v1.5-embed 18 | -------------------------------------------------------------------------------- /tests/workflow/fixtures/multi_branch.yml: -------------------------------------------------------------------------------- 1 | workflow: 2 | id: test-mb 3 | nodes: 4 | - id: detector 5 | model: 6 | model_id: face-detection 7 | model_version_id: fe995da8cb73490f8556416ecf25cea3 8 | - id: moderation 9 | model: 10 | model_id: moderation-recognition 11 | model_version_id: aa8be956dbaa4b7a858826a84253cab9 12 | -------------------------------------------------------------------------------- /tests/workflow/fixtures/single_branch_with_custom_cropper_model-version.yml: -------------------------------------------------------------------------------- 1 | workflow: 2 | id: test-sb 3 | nodes: 4 | - id: detector 5 | model: 6 | model_id: face-detection 7 | model_version_id: fe995da8cb73490f8556416ecf25cea3 8 | - id: cropper 9 | model: 10 | model_id: margin-100-image-crop-custom # Uses the same model ID as the other workflow with custom cropper model 11 | model_type_id: image-crop 12 | description: Custom crop model 13 | output_info: 14 | params: 15 | margin: 1.5 # Uses different margin than previous model to trigger the creation of a new model version. 
16 | node_inputs: 17 | - node_id: detector 18 | -------------------------------------------------------------------------------- /tests/workflow/fixtures/single_branch_with_custom_cropper_model.yml: -------------------------------------------------------------------------------- 1 | workflow: 2 | id: test-sb 3 | nodes: 4 | - id: detector 5 | model: 6 | model_id: face-detection 7 | model_version_id: fe995da8cb73490f8556416ecf25cea3 8 | - id: cropper 9 | model: 10 | model_id: margin-100-image-crop-custom # such a model ID does not exist, so it will be created using the below model fields 11 | model_type_id: image-crop 12 | description: Custom crop model 13 | output_info: 14 | params: 15 | margin: 1.33 16 | node_inputs: 17 | - node_id: detector 18 | -------------------------------------------------------------------------------- /tests/workflow/fixtures/single_branch_with_public_cropper_model.yml: -------------------------------------------------------------------------------- 1 | workflow: 2 | id: test-sb 3 | nodes: 4 | - id: detector 5 | model: 6 | model_id: face-detection 7 | model_version_id: fe995da8cb73490f8556416ecf25cea3 8 | - id: cropper 9 | model: 10 | model_id: margin-110-image-crop 11 | model_version_id: b9987421b40a46649566826ef9325303 12 | node_inputs: 13 | - node_id: detector 14 | -------------------------------------------------------------------------------- /tests/workflow/fixtures/single_branch_with_public_cropper_model_and_latest_version.yml: -------------------------------------------------------------------------------- 1 | workflow: 2 | id: test-sb 3 | nodes: 4 | - id: detector 5 | model: 6 | model_id: a403429f2ddf4b49b307e318f00e528b 7 | model_version_id: 34ce21a40cc24b6b96ffee54aabff139 8 | - id: cropper 9 | model: 10 | model_id: margin-110-image-crop 11 | 12 | node_inputs: 13 | - node_id: detector 14 | -------------------------------------------------------------------------------- /tests/workflow/fixtures/single_node.yml: -------------------------------------------------------------------------------- 1 | # A single node workflow 2 | workflow: 3 | id: test-sn 4 | nodes: 5 | - id: detector 6 | model: 7 | model_id: face-detection 8 | model_version_id: fe995da8cb73490f8556416ecf25cea3 9 | -------------------------------------------------------------------------------- /tests/workflow/test_create_delete.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import typing 5 | import uuid 6 | 7 | import pytest 8 | 9 | from clarifai.client.user import User 10 | 11 | NOW = uuid.uuid4().hex[:10] 12 | CREATE_APP_USER_ID = os.environ["CLARIFAI_USER_ID"] 13 | CREATE_APP_ID = f"test_workflow_create_delete_app_{NOW}" 14 | 15 | # assets 16 | IMAGE_URL = "https://samples.clarifai.com/metro-north.jpg" 17 | 18 | CLARIFAI_API_BASE = os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com") 19 | 20 | 21 | def get_test_parse_workflow_creation_workflows() -> typing.List[str]: 22 | filenames = glob.glob("tests/workflow/fixtures/*.yml") 23 | return filenames 24 | 25 | 26 | @pytest.mark.requires_secrets 27 | class TestWorkflowCreate: 28 | @classmethod 29 | def setup_class(cls): 30 | cls.client = User(user_id=CREATE_APP_USER_ID, base_url=CLARIFAI_API_BASE) 31 | try: 32 | cls.app = cls.client.create_app(app_id=CREATE_APP_ID, base_workflow="Empty") 33 | except Exception as e: 34 | if "already exists" in str(e): 35 | cls.app = cls.client.app(app_id=CREATE_APP_ID) 36 | 37 | 
@pytest.mark.parametrize("filename", get_test_parse_workflow_creation_workflows()) 38 | def test_parse_workflow_creation(self, filename: str, caplog): 39 | with caplog.at_level(logging.INFO): 40 | if "general.yml" in filename: 41 | generate_new_id = False 42 | else: 43 | generate_new_id = True 44 | self.app.create_workflow(filename, generate_new_id=generate_new_id) 45 | assert "Workflow created" in caplog.text 46 | 47 | def test_patch_workflow(self, caplog): 48 | with caplog.at_level(logging.INFO): 49 | workflow_id = list(self.app.list_workflows())[0].id 50 | self.app.patch_workflow( 51 | workflow_id=workflow_id, 52 | config_filepath='tests/workflow/fixtures/general.yml', 53 | visibility=10, 54 | description='Workflow Patching Test', 55 | notes='Workflow Patching Test', 56 | image_url=IMAGE_URL, 57 | ) 58 | assert "Workflow patched" in caplog.text 59 | 60 | def test_delete_workflow(self, caplog): 61 | with caplog.at_level(logging.INFO): 62 | self.app.delete_workflow("General") 63 | assert "Workflow Deleted" in caplog.text 64 | 65 | @classmethod 66 | def teardown_class(cls): 67 | cls.client.delete_app(app_id=CREATE_APP_ID) 68 | -------------------------------------------------------------------------------- /tests/workflow/test_export.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import yaml 5 | 6 | from clarifai.client.workflow import Workflow 7 | 8 | CLARIFAI_API_BASE = os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com") 9 | 10 | 11 | @pytest.mark.requires_secrets 12 | def test_export_workflow_general(): 13 | workflow = Workflow( 14 | workflow_id="General", user_id="clarifai", app_id="main", base_url=CLARIFAI_API_BASE 15 | ) 16 | 17 | workflow.export('tests/workflow/fixtures/export_general.yml') 18 | # assert this to the reader result 19 | with open('tests/workflow/fixtures/general.yml', 'r') as file: 20 | expected_data = yaml.safe_load(file) 21 | with open('tests/workflow/fixtures/export_general.yml', 'r') as file: 22 | actual_data = yaml.safe_load(file) 23 | assert actual_data == expected_data, f"dicts did not match: actual: {actual_data}" 24 | 25 | # cleanup 26 | os.remove('tests/workflow/fixtures/export_general.yml') 27 | -------------------------------------------------------------------------------- /tests/workflow/test_nodes_display.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pytest 4 | from rich.console import Console 5 | from rich.tree import Tree 6 | 7 | 8 | def get_workflow_tree_test_data() -> typing.List[typing.Dict]: 9 | test_data = [ 10 | { # Single Branch Single Node 11 | "adjacency_dict": {"Input": [1]}, 12 | "expected_pattern": r""" 13 | Input 14 | └── 1 15 | """, 16 | }, 17 | { # Multi Branch Multiple Nodes 18 | "adjacency_dict": {"Input": [1, 2], 2: [3, 4, 5], 4: [6, 7], 6: [8]}, 19 | "expected_pattern": r""" 20 | Input 21 | ├── 1 22 | └── 2 23 | ├── 3 24 | ├── 4 25 | │ ├── 6 26 | │ │ └── 8 27 | │ └── 7 28 | └── 5 29 | """, 30 | }, 31 | { # Single Branch Multiple Nodes 32 | "adjacency_dict": {"Input": [1], 1: [2], 2: [3]}, 33 | "expected_pattern": r""" 34 | Input 35 | └── 1 36 | └── 2 37 | └── 3 38 | """, 39 | }, 40 | ] 41 | 42 | return test_data 43 | 44 | 45 | class TestDisplayWorkflowTree: 46 | def setup_method(self): 47 | self.console = Console( 48 | force_terminal=True, width=80 49 | ) # Ensure consistent terminal behavior 50 | 51 | def build_node_tree(self, adj, node_id="Input"): 52 | """Recursively 
builds a rich tree of the workflow nodes. Simplified version of the function in clarifai/utils/logging.py""" 53 | tree = Tree(str(node_id)) 54 | for child in adj.get(node_id, []): 55 | tree.add(self.build_node_tree(adj, child)) 56 | return tree 57 | 58 | @pytest.mark.parametrize("test_data", get_workflow_tree_test_data()) 59 | def test_display_workflow_tree(self, test_data: typing.Dict): 60 | tree = self.build_node_tree(test_data["adjacency_dict"]) 61 | with self.console.capture() as capture: 62 | self.console.print(tree) 63 | 64 | actual_pattern = capture.get() 65 | assert actual_pattern.strip() == test_data["expected_pattern"].strip() 66 | -------------------------------------------------------------------------------- /tests/workflow/test_predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from clarifai_grpc.grpc.api import resources_pb2 5 | 6 | from clarifai.client.workflow import Workflow 7 | 8 | DOG_IMAGE_URL = "https://samples.clarifai.com/dog2.jpeg" 9 | NON_EXISTING_IMAGE_URL = "http://example.com/non-existing.jpg" 10 | RED_TRUCK_IMAGE_FILE_PATH = "tests/assets/red-truck.png" 11 | BEER_VIDEO_URL = "https://samples.clarifai.com/beer.mp4" 12 | 13 | MAIN_APP_ID = "main" 14 | MAIN_APP_USER_ID = "clarifai" 15 | WORKFLOW_ID = "General" 16 | 17 | CLARIFAI_PAT = os.environ["CLARIFAI_PAT"] 18 | CLARIFAI_API_BASE = os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com") 19 | 20 | 21 | @pytest.fixture 22 | def workflow(): 23 | return Workflow( 24 | user_id=MAIN_APP_USER_ID, 25 | app_id=MAIN_APP_ID, 26 | workflow_id=WORKFLOW_ID, 27 | output_config=resources_pb2.OutputConfig(max_concepts=3), 28 | pat=CLARIFAI_PAT, 29 | base_url=CLARIFAI_API_BASE, 30 | ) 31 | 32 | 33 | @pytest.mark.requires_secrets 34 | class TestWorkflowPredict: 35 | def test_workflow_predict_image_url(self, workflow): 36 | post_workflows_response = workflow.predict_by_url(DOG_IMAGE_URL, input_type="image") 37 | 38 | assert len(post_workflows_response.results[0].outputs[0].data.concepts) > 0 39 | 40 | def test_workflow_predict_image_bytes(self, workflow): 41 | with open(RED_TRUCK_IMAGE_FILE_PATH, "rb") as f: 42 | file_bytes = f.read() 43 | post_workflows_response = workflow.predict_by_bytes(file_bytes, input_type="image") 44 | 45 | assert len(post_workflows_response.results[0].outputs[0].data.concepts) > 0 46 | 47 | def test_workflow_predict_file_path(self, workflow): 48 | post_workflows_response = workflow.predict_by_filepath( 49 | RED_TRUCK_IMAGE_FILE_PATH, input_type="image" 50 | ) 51 | 52 | assert len(post_workflows_response.results[0].outputs[0].data.concepts) > 0 53 | 54 | def test_workflow_predict_max_concepts(self, workflow): 55 | post_workflows_response = workflow.predict_by_url(DOG_IMAGE_URL, input_type="image") 56 | 57 | assert len(post_workflows_response.results[0].outputs[0].data.concepts) == 3 58 | -------------------------------------------------------------------------------- /tests/workflow/test_validate.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | 4 | import pytest 5 | import yaml 6 | from schema import SchemaError 7 | 8 | from clarifai.workflows.validate import validate 9 | 10 | 11 | @pytest.mark.parametrize("filename", glob.glob("tests/workflow/fixtures/*.yml")) 12 | def test_validate_fixtures(filename): 13 | with open(filename, "r") as file: 14 | validate(yaml.safe_load(file)) 15 | 16 | 17 | def test_validate_invalid_id(): 18 | with pytest.raises( 
19 | SchemaError, match="Key 'id' error:\nRegex(.*) does not match 'id with spaces'" 20 | ): 21 | validate({"workflow": {"id": "id with spaces"}}) 22 | 23 | 24 | def test_validate_empty_nodes(): 25 | with pytest.raises( 26 | SchemaError, match=re.escape("Key 'nodes' error:\nlen([]) should evaluate to True") 27 | ): 28 | validate({"workflow": {"id": "workflow-id", "nodes": []}}) 29 | 30 | 31 | def test_validate_invalid_hex_id(): 32 | with pytest.raises( 33 | SchemaError, match="Key 'model_version_id' error:\nRegex(.*) does not match 'not-a-hex-id'" 34 | ): 35 | validate( 36 | { 37 | "workflow": { 38 | "id": "workflow-id", 39 | "nodes": [ 40 | { 41 | "id": "node-id", 42 | "model": { 43 | "model_id": "model-id", 44 | "model_version_id": "not-a-hex-id", 45 | }, 46 | } 47 | ], 48 | } 49 | } 50 | ) 51 | 52 | 53 | def test_validate_upper_hex_id(): 54 | data = validate( 55 | { 56 | "workflow": { 57 | "id": "workflow-id", 58 | "nodes": [ 59 | { 60 | "id": "node-id", 61 | "model": { 62 | "model_id": "model-id", 63 | "model_version_id": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", 64 | }, 65 | } 66 | ], 67 | } 68 | } 69 | ) 70 | assert ( 71 | data["workflow"]["nodes"][0]["model"]["model_version_id"] 72 | == "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" 73 | ) 74 | 75 | 76 | def test_validate_missing_input(): 77 | with pytest.raises(SchemaError, match="missing input 'previous-node-id' for node 'node-id'"): 78 | validate( 79 | { 80 | "workflow": { 81 | "id": "workflow-id", 82 | "nodes": [ 83 | { 84 | "id": "node-id", 85 | "model": { 86 | "model_id": "model-id", 87 | "model_version_id": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 88 | }, 89 | "node_inputs": [ 90 | { 91 | "node_id": "previous-node-id", 92 | } 93 | ], 94 | } 95 | ], 96 | } 97 | } 98 | ) 99 | 100 | 101 | def test_validate_model_has_model_version_id_and_other_model_fields(): 102 | with pytest.raises( 103 | SchemaError, match="model should not set model_version_id and other model fields" 104 | ): 105 | validate( 106 | { 107 | "workflow": { 108 | "id": "workflow-id", 109 | "nodes": [ 110 | { 111 | "id": "node-id", 112 | "model": { 113 | "model_id": "model-id", 114 | "model_version_id": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 115 | "description": "hello", 116 | }, 117 | } 118 | ], 119 | } 120 | } 121 | ) 122 | --------------------------------------------------------------------------------