├── tests ├── __init__.py ├── test_hardware.py ├── conftest.py ├── test_pagination.py ├── test_account.py ├── test_collection.py ├── cassettes │ ├── run__invalid-token.yaml │ ├── trainings-create__invalid-destination.yaml │ ├── models-predictions-create.yaml │ ├── test_predictions_create_by_model[False].yaml │ ├── trainings-create.yaml │ ├── hardware-list.yaml │ ├── models-create.yaml │ ├── trainings-get.yaml │ ├── test_predictions_create_by_model[True].yaml │ ├── trainings-cancel.yaml │ ├── collections-list.yaml │ └── predictions-get.yaml ├── test_identifier.py ├── test_stream.py ├── test_version.py ├── test_client.py ├── test_model.py ├── test_run.py ├── test_training.py └── test_deployment.py ├── .gitattributes ├── script ├── test ├── format ├── setup └── lint ├── replicate ├── __about__.py ├── __init__.py ├── resource.py ├── schema.py ├── files.py ├── json.py ├── account.py ├── hardware.py ├── identifier.py ├── pagination.py ├── exceptions.py ├── run.py ├── collection.py ├── version.py ├── stream.py ├── model.py ├── client.py ├── training.py └── deployment.py ├── .vscode ├── extensions.json └── settings.json ├── .gitignore ├── .github └── workflows │ ├── release.yaml │ └── ci.yaml ├── requirements.txt ├── requirements-dev.txt ├── pyproject.toml ├── CONTRIBUTING.md ├── LICENSE └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | tests/cassettes/** binary 2 | -------------------------------------------------------------------------------- /script/test: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python -m pytest -v 6 | -------------------------------------------------------------------------------- /script/format: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python -m ruff format . 6 | -------------------------------------------------------------------------------- /replicate/__about__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | __version__ = version(__package__) 4 | -------------------------------------------------------------------------------- /script/setup: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python -m pip install -r requirements.txt -r requirements-dev.txt . 
6 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff", 4 | "ms-python.python", 5 | "ms-python.vscode-pylance", 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | -------------------------------------------------------------------------------- /script/lint: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | STATUS=0 6 | 7 | echo "Running pyright" 8 | python -m pyright replicate || STATUS=$? 9 | echo "" 10 | 11 | echo "Running pylint" 12 | python -m pylint --exit-zero replicate || STATUS=$? 13 | echo "" 14 | 15 | echo "Running ruff check" 16 | python -m ruff check . || STATUS=$? 17 | echo "" 18 | 19 | echo "Running ruff format check" 20 | python -m ruff format --check . || STATUS=$? 21 | echo "" 22 | 23 | exit $STATUS 24 | -------------------------------------------------------------------------------- /tests/test_hardware.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.vcr("hardware-list.yaml") 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_hardware_list(async_flag): 10 | if async_flag: 11 | hardware = await replicate.hardware.async_list() 12 | else: 13 | hardware = replicate.hardware.list() 14 | 15 | assert hardware is not None 16 | assert isinstance(hardware, list) 17 | assert len(hardware) > 0 18 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: ["*"] 6 | 7 | jobs: 8 | release: 9 | runs-on: ubuntu-latest 10 | 11 | name: "Publish to PyPI" 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - uses: actions/setup-python@v3 16 | with: 17 | python-version: "3.10" 18 | - name: Install pypa/build 19 | run: python -m pip install build --user 20 | - name: Build a package 21 | run: python -m build 22 | - name: Publish distribution 📦 to PyPI 23 | uses: pypa/gh-action-pypi-publish@release/v1 24 | with: 25 | password: ${{ secrets.PYPI_API_TOKEN }} 26 | -------------------------------------------------------------------------------- /replicate/__init__.py: -------------------------------------------------------------------------------- 1 | from replicate.client import Client 2 | from replicate.pagination import async_paginate as _async_paginate 3 | from replicate.pagination import paginate as _paginate 4 | 5 | default_client = Client() 6 | 7 | run = default_client.run 8 | async_run = default_client.async_run 9 | 10 | stream = default_client.stream 11 | async_stream = default_client.async_stream 12 | 13 | paginate = _paginate 14 | async_paginate = 
_async_paginate 15 | 16 | collections = default_client.collections 17 | hardware = default_client.hardware 18 | deployments = default_client.deployments 19 | models = default_client.models 20 | predictions = default_client.predictions 21 | trainings = default_client.trainings 22 | -------------------------------------------------------------------------------- /replicate/resource.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import TYPE_CHECKING 3 | 4 | try: 5 | from pydantic import v1 as pydantic # type: ignore 6 | except ImportError: 7 | import pydantic # type: ignore 8 | 9 | if TYPE_CHECKING: 10 | from replicate.client import Client 11 | 12 | 13 | class Resource(pydantic.BaseModel): # type: ignore 14 | """ 15 | A base class for representing a single object on the server. 16 | """ 17 | 18 | 19 | class Namespace(abc.ABC): 20 | """ 21 | A base class for representing objects of a particular type on the server. 22 | """ 23 | 24 | _client: "Client" 25 | 26 | def __init__(self, client: "Client") -> None: 27 | self._client = client 28 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "editor.formatOnType": true, 4 | "editor.formatOnPaste": true, 5 | "editor.renderControlCharacters": true, 6 | "editor.suggest.localityBonus": true, 7 | "files.insertFinalNewline": true, 8 | "files.trimFinalNewlines": true, 9 | "[python]": { 10 | "editor.defaultFormatter": "charliermarsh.ruff", 11 | "editor.formatOnSave": true, 12 | "editor.codeActionsOnSave": { 13 | "source.fixAll": "explicit", 14 | "source.organizeImports": "explicit" 15 | } 16 | }, 17 | "python.languageServer": "Pylance", 18 | "python.analysis.typeCheckingMode": "basic", 19 | "python.testing.pytestArgs": [ 20 | "-vvv", 21 | "python" 22 | ], 23 | "python.testing.unittestEnabled": false, 24 | "python.testing.pytestEnabled": true, 25 | "ruff.lint.args": [ 26 | "--config=pyproject.toml" 27 | ], 28 | "ruff.format.args": [ 29 | "--config=pyproject.toml" 30 | ], 31 | } 32 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | 7 | pull_request: 8 | branches: ["main"] 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | 14 | name: "Test Python ${{ matrix.python-version }}" 15 | 16 | env: 17 | REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} 18 | 19 | timeout-minutes: 10 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 25 | 26 | defaults: 27 | run: 28 | shell: bash 29 | 30 | steps: 31 | - uses: actions/checkout@v3 32 | - uses: actions/setup-python@v3 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | cache: "pip" 36 | 37 | - name: Setup 38 | run: ./script/setup 39 | 40 | - name: Test 41 | run: ./script/test 42 | 43 | - name: Lint 44 | run: ./script/lint 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.11 3 | # by the following command: 4 | # 5 | # pip-compile --output-file=requirements.txt --resolver=backtracking 
pyproject.toml 6 | # 7 | annotated-types==0.5.0 8 | # via pydantic 9 | anyio==3.7.1 10 | # via httpcore 11 | certifi==2023.7.22 12 | # via 13 | # httpcore 14 | # httpx 15 | h11==0.14.0 16 | # via httpcore 17 | httpcore==0.17.3 18 | # via httpx 19 | httpx==0.24.1 20 | # via replicate (pyproject.toml) 21 | idna==3.4 22 | # via 23 | # anyio 24 | # httpx 25 | packaging==23.1 26 | # via replicate (pyproject.toml) 27 | pydantic==2.0.3 28 | # via replicate (pyproject.toml) 29 | pydantic-core==2.3.0 30 | # via pydantic 31 | sniffio==1.3.0 32 | # via 33 | # anyio 34 | # httpcore 35 | # httpx 36 | typing-extensions==4.7.1 37 | # via 38 | # pydantic 39 | # pydantic-core 40 | # replicate (pyproject.toml) 41 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from unittest import mock 4 | 5 | import pytest 6 | import pytest_asyncio 7 | 8 | 9 | @pytest_asyncio.fixture(scope="session", autouse=True) 10 | def event_loop(): 11 | event_loop_policy = asyncio.get_event_loop_policy() 12 | loop = event_loop_policy.new_event_loop() 13 | yield loop 14 | loop.close() 15 | 16 | 17 | @pytest.fixture(scope="session") 18 | def mock_replicate_api_token(scope="class"): 19 | if os.environ.get("REPLICATE_API_TOKEN", "") != "": 20 | yield 21 | else: 22 | with mock.patch.dict( 23 | os.environ, 24 | {"REPLICATE_API_TOKEN": "test-token", "REPLICATE_POLL_INTERVAL": "0.0"}, 25 | ): 26 | yield 27 | 28 | 29 | @pytest.fixture(scope="module") 30 | def vcr_config(): 31 | return {"allowed_hosts": ["api.replicate.com"], "filter_headers": ["authorization"]} 32 | 33 | 34 | @pytest.fixture(scope="module") 35 | def vcr_cassette_dir(request): 36 | module = request.node.fspath 37 | return os.path.join(module.dirname, "cassettes") 38 | -------------------------------------------------------------------------------- /replicate/schema.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from packaging import version 4 | 5 | # TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth? 
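# Example (illustrative, based on the functions defined below): a model pushed
# with cog_version "0.3.8" whose Output schema is {"type": "array", ...} predates
# x-cog-array-type, so make_schema_backwards_compatible() tags the output as an
# "iterator"; schemas produced by Cog 0.3.9 or later are returned unchanged.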
6 | 7 | 8 | def version_has_no_array_type(cog_version: str) -> Optional[bool]: 9 | """Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward""" 10 | try: 11 | return version.parse(cog_version) < version.parse("0.3.9") 12 | except version.InvalidVersion: 13 | return None 14 | 15 | 16 | def make_schema_backwards_compatible( 17 | schema: dict, 18 | cog_version: str, 19 | ) -> dict: 20 | """A place to add backwards compatibility logic for our openapi schema""" 21 | 22 | # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type 23 | if version_has_no_array_type(cog_version): 24 | output = schema["components"]["schemas"]["Output"] 25 | if output.get("type") == "array": 26 | output["x-cog-array-type"] = "iterator" 27 | return schema 28 | -------------------------------------------------------------------------------- /tests/test_pagination.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.asyncio 7 | async def test_paginate_with_none_cursor(mock_replicate_api_token): 8 | with pytest.raises(ValueError): 9 | replicate.models.list(None) 10 | 11 | 12 | @pytest.mark.vcr("collections-list.yaml") 13 | @pytest.mark.asyncio 14 | @pytest.mark.parametrize("async_flag", [True, False]) 15 | async def test_paginate(async_flag): 16 | found = False 17 | 18 | if async_flag: 19 | async for page in replicate.async_paginate(replicate.collections.async_list): 20 | assert page.next is None 21 | assert page.previous is None 22 | 23 | for collection in page: 24 | if collection.slug == "text-to-image": 25 | found = True 26 | break 27 | 28 | else: 29 | for page in replicate.paginate(replicate.collections.list): 30 | assert page.next is None 31 | assert page.previous is None 32 | 33 | for collection in page: 34 | if collection.slug == "text-to-image": 35 | found = True 36 | break 37 | 38 | assert found 39 | -------------------------------------------------------------------------------- /replicate/files.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import mimetypes 4 | import os 5 | from typing import Optional 6 | 7 | import httpx 8 | 9 | 10 | def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str: 11 | """ 12 | Upload a file to the server. 13 | 14 | Args: 15 | file: A file handle to upload. 16 | output_file_prefix: A string to prepend to the output file name. 17 | Returns: 18 | str: A URL to the uploaded file. 
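        When output_file_prefix is omitted, nothing is uploaded; the file's
        contents are returned inline as a base64-encoded data URI instead
        (see the implementation below).

    Example (sketch only; "image.png" is a hypothetical local file):
        >>> with open("image.png", "rb") as f:
        ...     url_or_data_uri = upload_file(f)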
19 | """ 20 | # Lifted straight from cog.files 21 | 22 | file.seek(0) 23 | 24 | if output_file_prefix is not None: 25 | name = getattr(file, "name", "output") 26 | url = output_file_prefix + os.path.basename(name) 27 | resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore 28 | resp.raise_for_status() 29 | 30 | return url 31 | 32 | body = file.read() 33 | # Ensure the file handle is in bytes 34 | body = body.encode("utf-8") if isinstance(body, str) else body 35 | encoded_body = base64.b64encode(body).decode("utf-8") 36 | # Use getattr to avoid mypy complaints about io.IOBase having no attribute name 37 | mime_type = ( 38 | mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream" 39 | ) 40 | return f"data:{mime_type};base64,{encoded_body}" 41 | -------------------------------------------------------------------------------- /replicate/json.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from types import GeneratorType 4 | from typing import Any, Callable 5 | 6 | try: 7 | import numpy as np # type: ignore 8 | 9 | HAS_NUMPY = True 10 | except ImportError: 11 | HAS_NUMPY = False 12 | 13 | 14 | # pylint: disable=too-many-return-statements 15 | def encode_json( 16 | obj: Any, # noqa: ANN401 17 | upload_file: Callable[[io.IOBase], str], 18 | ) -> Any: # noqa: ANN401 19 | """ 20 | Return a JSON-compatible version of the object. 21 | """ 22 | # Effectively the same thing as cog.json.encode_json. 23 | 24 | if isinstance(obj, dict): 25 | return {key: encode_json(value, upload_file) for key, value in obj.items()} 26 | if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): 27 | return [encode_json(value, upload_file) for value in obj] 28 | if isinstance(obj, Path): 29 | with obj.open("rb") as file: 30 | return upload_file(file) 31 | if isinstance(obj, io.IOBase): 32 | return upload_file(obj) 33 | if HAS_NUMPY: 34 | if isinstance(obj, np.integer): # type: ignore 35 | return int(obj) 36 | if isinstance(obj, np.floating): # type: ignore 37 | return float(obj) 38 | if isinstance(obj, np.ndarray): # type: ignore 39 | return obj.tolist() 40 | return obj 41 | -------------------------------------------------------------------------------- /tests/test_account.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | import respx 4 | 5 | from replicate.account import Account 6 | from replicate.client import Client 7 | 8 | router = respx.Router(base_url="https://api.replicate.com/v1") 9 | router.route( 10 | method="GET", 11 | path="/account", 12 | name="accounts.current", 13 | ).mock( 14 | return_value=httpx.Response( 15 | 200, 16 | json={ 17 | "type": "organization", 18 | "username": "replicate", 19 | "name": "Replicate", 20 | "github_url": "https://github.com/replicate", 21 | }, 22 | ) 23 | ) 24 | router.route(host="api.replicate.com").pass_through() 25 | 26 | 27 | @pytest.mark.asyncio 28 | @pytest.mark.parametrize("async_flag", [True, False]) 29 | async def test_account_current(async_flag): 30 | client = Client( 31 | api_token="test-token", transport=httpx.MockTransport(router.handler) 32 | ) 33 | 34 | if async_flag: 35 | account = await client.accounts.async_current() 36 | else: 37 | account = client.accounts.current() 38 | 39 | assert router["accounts.current"].called 40 | assert isinstance(account, Account) 41 | assert account.type == "organization" 42 | assert account.username == "replicate" 43 | assert 
account.name == "Replicate" 44 | assert account.github_url == "https://github.com/replicate" 45 | -------------------------------------------------------------------------------- /tests/test_collection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.vcr("collections-list.yaml") 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_collections_list(async_flag): 10 | if async_flag: 11 | page = await replicate.collections.async_list() 12 | else: 13 | page = replicate.collections.list() 14 | 15 | assert page.next is None 16 | assert page.previous is None 17 | 18 | found = False 19 | for collection in page.results: 20 | if collection.slug == "text-to-image": 21 | found = True 22 | break 23 | 24 | assert found 25 | 26 | 27 | @pytest.mark.vcr("collections-get.yaml") 28 | @pytest.mark.asyncio 29 | @pytest.mark.parametrize("async_flag", [True, False]) 30 | async def test_collections_get(async_flag): 31 | if async_flag: 32 | collection = await replicate.collections.async_get("text-to-image") 33 | else: 34 | collection = replicate.collections.get("text-to-image") 35 | 36 | assert collection.slug == "text-to-image" 37 | assert collection.name == "Text to image" 38 | assert collection.models is not None 39 | assert len(collection.models) > 0 40 | 41 | found = False 42 | for model in collection.models: 43 | if model.name == "stable-diffusion": 44 | found = True 45 | break 46 | 47 | assert found 48 | -------------------------------------------------------------------------------- /replicate/account.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Literal, Optional 2 | 3 | from replicate.resource import Namespace, Resource 4 | 5 | 6 | class Account(Resource): 7 | """ 8 | A user or organization account on Replicate. 9 | """ 10 | 11 | type: Literal["user", "organization"] 12 | """The type of account.""" 13 | 14 | username: str 15 | """The username of the account.""" 16 | 17 | name: str 18 | """The name of the account.""" 19 | 20 | github_url: Optional[str] 21 | """The GitHub URL of the account.""" 22 | 23 | 24 | class Accounts(Namespace): 25 | """ 26 | Namespace for operations related to accounts. 27 | """ 28 | 29 | def current(self) -> Account: 30 | """ 31 | Get the current account. 32 | 33 | Returns: 34 | Account: The current account. 35 | """ 36 | 37 | resp = self._client._request("GET", "/v1/account") 38 | obj = resp.json() 39 | 40 | return _json_to_account(obj) 41 | 42 | async def async_current(self) -> Account: 43 | """ 44 | Get the current account. 45 | 46 | Returns: 47 | Account: The current account. 
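        Example (sketch; assumes a client constructed with a valid API token,
        mirroring tests/test_account.py — the account returned depends on that token):
            >>> from replicate.client import Client
            >>> client = Client(api_token="...")
            >>> account = await client.accounts.async_current()
            >>> account.username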
48 | """ 49 | 50 | resp = await self._client._async_request("GET", "/v1/account") 51 | obj = resp.json() 52 | 53 | return _json_to_account(obj) 54 | 55 | 56 | def _json_to_account(json: Dict[str, Any]) -> Account: 57 | return Account(**json) 58 | -------------------------------------------------------------------------------- /tests/cassettes/run__invalid-token.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '{"input": {}, "version": "73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9"}' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | content-length: 12 | - '92' 13 | content-type: 14 | - application/json 15 | host: 16 | - api.replicate.com 17 | user-agent: 18 | - replicate-python/0.15.6 19 | method: POST 20 | uri: https://api.replicate.com/v1/predictions 21 | response: 22 | content: '{"title":"Unauthenticated","detail":"You did not pass a valid authentication 23 | token","status":401} 24 | 25 | ' 26 | headers: 27 | CF-Cache-Status: 28 | - DYNAMIC 29 | CF-RAY: 30 | - 826f7f62fe026fbc-IAD 31 | Connection: 32 | - keep-alive 33 | Content-Length: 34 | - '98' 35 | Content-Type: 36 | - application/problem+json 37 | Date: 38 | - Thu, 16 Nov 2023 11:47:09 GMT 39 | NEL: 40 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 41 | Report-To: 42 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=wrRb9OS6whWRSaIb4X88Ipf0nma4uddnvh%2FpB7n4RKKuy9G6j6uUcoEz920jRaH0qq0FHHxkrKqUrxv0HkFVWnLoMJ4Q9As%2FfwUSGNLKvA0vbUfOJa7%2FYiET6%2Bdg0BjJ3qtWg3dArNIIHNi7%2BBX0"}],"group":"cf-nel","max_age":604800}' 43 | Server: 44 | - cloudflare 45 | Strict-Transport-Security: 46 | - max-age=15552000 47 | via: 48 | - 1.1 google 49 | http_version: HTTP/1.1 50 | status_code: 401 51 | version: 1 52 | -------------------------------------------------------------------------------- /tests/test_identifier.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from replicate.identifier import ModelVersionIdentifier 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "id, expected", 8 | [ 9 | ( 10 | "meta/llama-2-70b-chat", 11 | { 12 | "owner": "meta", 13 | "name": "llama-2-70b-chat", 14 | "version": None, 15 | "error": False, 16 | }, 17 | ), 18 | ( 19 | "mistralai/mistral-7b-instruct-v1.4", 20 | { 21 | "owner": "mistralai", 22 | "name": "mistral-7b-instruct-v1.4", 23 | "version": None, 24 | "error": False, 25 | }, 26 | ), 27 | ( 28 | "nateraw/video-llava:a494250c04691c458f57f2f8ef5785f25bc851e0c91fd349995081d4362322dd", 29 | { 30 | "owner": "nateraw", 31 | "name": "video-llava", 32 | "version": "a494250c04691c458f57f2f8ef5785f25bc851e0c91fd349995081d4362322dd", 33 | "error": False, 34 | }, 35 | ), 36 | ( 37 | "", 38 | {"error": True}, 39 | ), 40 | ( 41 | "invalid", 42 | {"error": True}, 43 | ), 44 | ( 45 | "invalid/id/format", 46 | {"error": True}, 47 | ), 48 | ], 49 | ) 50 | def test_parse_model_id(id, expected): 51 | try: 52 | result = ModelVersionIdentifier.parse(id) 53 | assert result.owner == expected["owner"] 54 | assert result.name == expected["name"] 55 | assert result.version == expected["version"] 56 | except ValueError: 57 | assert expected["error"] 58 | -------------------------------------------------------------------------------- /tests/test_stream.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 
from replicate.stream import ServerSentEvent 5 | 6 | 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_stream(async_flag, record_mode): 10 | model = "replicate/canary:30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" 11 | input = { 12 | "text": "Hello", 13 | } 14 | 15 | events = [] 16 | 17 | if async_flag: 18 | async for event in await replicate.async_stream( 19 | model, 20 | input=input, 21 | ): 22 | events.append(event) 23 | else: 24 | for event in replicate.stream( 25 | model, 26 | input=input, 27 | ): 28 | events.append(event) 29 | 30 | assert len(events) > 0 31 | assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events) 32 | assert any(event.event == ServerSentEvent.EventType.DONE for event in events) 33 | 34 | 35 | @pytest.mark.asyncio 36 | @pytest.mark.parametrize("async_flag", [True, False]) 37 | async def test_stream_prediction(async_flag, record_mode): 38 | version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" 39 | input = { 40 | "text": "Hello", 41 | } 42 | 43 | events = [] 44 | 45 | if async_flag: 46 | async for event in replicate.predictions.create( 47 | version=version, input=input, stream=True 48 | ).async_stream(): 49 | events.append(event) 50 | else: 51 | for event in replicate.predictions.create( 52 | version=version, input=input, stream=True 53 | ).stream(): 54 | events.append(event) 55 | 56 | assert len(events) > 0 57 | -------------------------------------------------------------------------------- /replicate/hardware.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, List 2 | 3 | from typing_extensions import deprecated 4 | 5 | from replicate.resource import Namespace, Resource 6 | 7 | if TYPE_CHECKING: 8 | pass 9 | 10 | 11 | class Hardware(Resource): 12 | """ 13 | Hardware for running a model on Replicate. 14 | """ 15 | 16 | sku: str 17 | """ 18 | The SKU of the hardware. 19 | """ 20 | 21 | name: str 22 | """ 23 | The name of the hardware. 24 | """ 25 | 26 | @property 27 | @deprecated("Use `sku` instead of `id`") 28 | def id(self) -> str: 29 | """ 30 | DEPRECATED: Use `sku` instead. 31 | """ 32 | return self.sku 33 | 34 | 35 | class HardwareNamespace(Namespace): 36 | """ 37 | Namespace for operations related to hardware. 38 | """ 39 | 40 | def list(self) -> List[Hardware]: 41 | """ 42 | List all hardware available for you to run models on Replicate. 43 | 44 | Returns: 45 | List[Hardware]: A list of hardware. 46 | """ 47 | 48 | resp = self._client._request("GET", "/v1/hardware") 49 | obj = resp.json() 50 | 51 | return [_json_to_hardware(entry) for entry in obj] 52 | 53 | async def async_list(self) -> List[Hardware]: 54 | """ 55 | List all hardware available for you to run models on Replicate. 56 | 57 | Returns: 58 | List[Hardware]: A list of hardware. 
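        Example (sketch; uses the module-level default client exposed by the
        ``replicate`` package, as exercised in tests/test_hardware.py):
            >>> import replicate
            >>> hardware = await replicate.hardware.async_list()
            >>> [hw.sku for hw in hardware]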
59 | """ 60 | 61 | resp = await self._client._async_request("GET", "/v1/hardware") 62 | obj = resp.json() 63 | 64 | return [_json_to_hardware(entry) for entry in obj] 65 | 66 | 67 | def _json_to_hardware(json: Dict[str, Any]) -> Hardware: 68 | return Hardware(**json) 69 | -------------------------------------------------------------------------------- /replicate/identifier.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union 3 | 4 | if TYPE_CHECKING: 5 | from replicate.model import Model 6 | from replicate.version import Version 7 | 8 | 9 | class ModelVersionIdentifier(NamedTuple): 10 | """ 11 | A reference to a model version in the format owner/name or owner/name:version. 12 | """ 13 | 14 | owner: str 15 | name: str 16 | version: Optional[str] = None 17 | 18 | @classmethod 19 | def parse(cls, ref: str) -> "ModelVersionIdentifier": 20 | """ 21 | Split a reference in the format owner/name:version into its components. 22 | """ 23 | 24 | match = re.match(r"^(?P[^/]+)/(?P[^/:]+)(:(?P.+))?$", ref) 25 | if not match: 26 | raise ValueError( 27 | f"Invalid reference to model version: {ref}. Expected format: owner/name:version" 28 | ) 29 | 30 | return cls(match.group("owner"), match.group("name"), match.group("version")) 31 | 32 | 33 | def _resolve( 34 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 35 | ) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]: 36 | from replicate.model import Model # pylint: disable=import-outside-toplevel 37 | from replicate.version import Version # pylint: disable=import-outside-toplevel 38 | 39 | version = None 40 | owner, name, version_id = None, None, None 41 | if isinstance(ref, Model): 42 | owner, name = ref.owner, ref.name 43 | elif isinstance(ref, Version): 44 | version = ref 45 | version_id = ref.id 46 | elif isinstance(ref, ModelVersionIdentifier): 47 | owner, name, version_id = ref 48 | elif isinstance(ref, str): 49 | owner, name, version_id = ModelVersionIdentifier.parse(ref) 50 | return version, owner, name, version_id 51 | -------------------------------------------------------------------------------- /tests/cassettes/trainings-create__invalid-destination.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: 4 | '{"input": {"input_images": "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip"}, 5 | "destination": ""}' 6 | headers: 7 | accept: 8 | - "*/*" 9 | accept-encoding: 10 | - gzip, deflate 11 | connection: 12 | - keep-alive 13 | content-length: 14 | - "148" 15 | content-type: 16 | - application/json 17 | host: 18 | - api.replicate.com 19 | user-agent: 20 | - replicate-python/0.11.0 21 | method: POST 22 | uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b/trainings 23 | response: 24 | content: 25 | '{"detail":"The specified training destination does not exist","status":404} 26 | 27 | ' 28 | headers: 29 | CF-Cache-Status: 30 | - DYNAMIC 31 | CF-RAY: 32 | - 7f7c2190ed8c281a-SEA 33 | Connection: 34 | - keep-alive 35 | Content-Length: 36 | - "76" 37 | Content-Type: 38 | - application/problem+json 39 | Date: 40 | - Wed, 16 Aug 2023 19:37:18 GMT 41 | NEL: 42 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 43 | Report-To: 44 | - 
'{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=0vMWFlGDyffyF0A%2FL4%2FH830OVHnZd0gZDww4oocSSHq7eMAt327ut6v%2B2qAda7fThmH4WcElLTM%2B3PFyrsa1w1SHgfEdWyJSv8TYYi2nWXMqeP5EJc1SDjV958HGKSKDnjH5"}],"group":"cf-nel","max_age":604800}' 45 | Server: 46 | - cloudflare 47 | Strict-Transport-Security: 48 | - max-age=15552000 49 | ratelimit-remaining: 50 | - "2999" 51 | ratelimit-reset: 52 | - "1" 53 | via: 54 | - 1.1 google 55 | http_version: HTTP/1.1 56 | status_code: 404 57 | version: 1 58 | -------------------------------------------------------------------------------- /tests/cassettes/models-predictions-create.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '{"input": {"prompt": "Please write a haiku about llamas."}}' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | content-length: 12 | - '59' 13 | content-type: 14 | - application/json 15 | host: 16 | - api.replicate.com 17 | user-agent: 18 | - replicate-python/0.21.0 19 | method: POST 20 | uri: https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions 21 | response: 22 | content: '{"id":"heat2o3bzn3ahtr6bjfftvbaci","model":"replicate/lifeboat-70b","version":"d-c6559c5791b50af57b69f4a73f8e021c","input":{"prompt":"Please 23 | write a haiku about llamas."},"logs":"","error":null,"status":"starting","created_at":"2023-11-27T13:35:45.99397566Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel","get":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"}} 24 | 25 | ' 26 | headers: 27 | CF-Cache-Status: 28 | - DYNAMIC 29 | CF-RAY: 30 | - 82cac197efaec53d-SEA 31 | Connection: 32 | - keep-alive 33 | Content-Length: 34 | - '431' 35 | Content-Type: 36 | - application/json 37 | Date: 38 | - Mon, 27 Nov 2023 13:35:46 GMT 39 | NEL: 40 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 41 | Report-To: 42 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=7R5RONMF6xaGRc39n0wnSe3jU1FbpX64Xz4U%2B%2F2nasvFaz0pKARxPhnzDgYkLaWgdK9zWrD2jxU04aKOy5HMPHAXboJ993L4zfsOyto56lBtdqSjNgkptzzxYEsKD%2FxIhe2F"}],"group":"cf-nel","max_age":604800}' 43 | Server: 44 | - cloudflare 45 | Strict-Transport-Security: 46 | - max-age=15552000 47 | ratelimit-remaining: 48 | - '599' 49 | ratelimit-reset: 50 | - '1' 51 | via: 52 | - 1.1 google 53 | http_version: HTTP/1.1 54 | status_code: 201 55 | version: 1 56 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile --extra=dev --output-file=requirements-dev.txt pyproject.toml 6 | # 7 | annotated-types==0.5.0 8 | # via pydantic 9 | anyio==3.7.1 10 | # via httpcore 11 | astroid==3.0.1 12 | # via pylint 13 | certifi==2023.7.22 14 | # via 15 | # httpcore 16 | # httpx 17 | dill==0.3.7 18 | # via pylint 19 | h11==0.14.0 20 | # via httpcore 21 | httpcore==0.17.3 22 | # via httpx 23 | httpx==0.24.1 24 | # via 25 | # replicate (pyproject.toml) 26 | # respx 27 | idna==3.4 28 | # via 29 | # anyio 30 | # httpx 31 | # yarl 32 | iniconfig==2.0.0 33 | # via pytest 34 | isort==5.12.0 35 | # via pylint 36 | mccabe==0.7.0 37 | # via pylint 38 | multidict==6.0.4 39 | # via yarl 40 | nodeenv==1.8.0 41 | # via pyright 42 | 
packaging==23.1 43 | # via 44 | # pytest 45 | # replicate (pyproject.toml) 46 | platformdirs==3.11.0 47 | # via pylint 48 | pluggy==1.2.0 49 | # via pytest 50 | pydantic==2.0.3 51 | # via replicate (pyproject.toml) 52 | pydantic-core==2.3.0 53 | # via pydantic 54 | pylint==3.0.2 55 | # via replicate (pyproject.toml) 56 | pyright==1.1.337 57 | # via replicate (pyproject.toml) 58 | pytest==7.4.0 59 | # via 60 | # pytest-asyncio 61 | # pytest-recording 62 | # replicate (pyproject.toml) 63 | pytest-asyncio==0.21.1 64 | # via replicate (pyproject.toml) 65 | pytest-recording==0.13.0 66 | # via replicate (pyproject.toml) 67 | pyyaml==6.0.1 68 | # via vcrpy 69 | respx==0.20.2 70 | # via replicate (pyproject.toml) 71 | ruff==0.3.3 72 | # via replicate (pyproject.toml) 73 | sniffio==1.3.0 74 | # via 75 | # anyio 76 | # httpcore 77 | # httpx 78 | tomlkit==0.12.1 79 | # via pylint 80 | typing-extensions==4.7.1 81 | # via 82 | # pydantic 83 | # pydantic-core 84 | # replicate (pyproject.toml) 85 | vcrpy==5.1.0 86 | # via pytest-recording 87 | wrapt==1.15.0 88 | # via vcrpy 89 | yarl==1.9.2 90 | # via vcrpy 91 | 92 | # The following packages are considered to be unsafe in a requirements file: 93 | # setuptools 94 | -------------------------------------------------------------------------------- /tests/cassettes/test_predictions_create_by_model[False].yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '{"input": {"prompt": "write a haiku about llamas"}}' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | content-length: 12 | - '51' 13 | content-type: 14 | - application/json 15 | host: 16 | - api.replicate.com 17 | user-agent: 18 | - replicate-python/0.25.1 19 | method: POST 20 | uri: https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions 21 | response: 22 | content: '{"id":"cjt2hahk61rgp0cf0p2arcdp9r","model":"replicate-internal/llama-3-8b-instruct-int8-triton","version":"dp-a557b7387b4940df25b23f779dc534c4","input":{"prompt":"write 23 | a haiku about llamas"},"logs":"","error":null,"status":"starting","created_at":"2024-04-22T11:15:01.04Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/cjt2hahk61rgp0cf0p2arcdp9r/cancel","get":"https://api.replicate.com/v1/predictions/cjt2hahk61rgp0cf0p2arcdp9r"}} 24 | 25 | ' 26 | headers: 27 | CF-Cache-Status: 28 | - DYNAMIC 29 | CF-RAY: 30 | - 8785318aea5c760a-SEA 31 | Connection: 32 | - keep-alive 33 | Content-Length: 34 | - '446' 35 | Content-Type: 36 | - application/json; charset=utf-8 37 | Date: 38 | - Mon, 22 Apr 2024 11:15:01 GMT 39 | NEL: 40 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 41 | Report-To: 42 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=z9hxBCXF5XyK0SB20HzzACvU%2F0RHmG8ds%2BcCyJc%2FBRZ2Y8CJXU%2F4GYO9zcHQYxw6kShxUxGTnNj6%2FwPF%2BJyf3wJT2AIRPrN0uZ17jodNA2iLGns0DKNySww4AXiQdpj%2FnzHl"}],"group":"cf-nel","max_age":604800}' 43 | Server: 44 | - cloudflare 45 | Strict-Transport-Security: 46 | - max-age=15552000 47 | alt-svc: 48 | - h3=":443"; ma=86400 49 | ratelimit-remaining: 50 | - '599' 51 | ratelimit-reset: 52 | - '1' 53 | replicate-edge-cluster: 54 | - us-central1 55 | replicate-target-cluster: 56 | - coreweave-us 57 | via: 58 | - 1.1 google 59 | http_version: HTTP/1.1 60 | status_code: 201 61 | version: 1 62 | -------------------------------------------------------------------------------- /replicate/pagination.py: 
-------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | AsyncGenerator, 4 | Awaitable, 5 | Callable, 6 | Generator, 7 | Generic, 8 | List, 9 | Optional, 10 | TypeVar, 11 | Union, 12 | ) 13 | 14 | try: 15 | from pydantic import v1 as pydantic # type: ignore 16 | except ImportError: 17 | import pydantic # type: ignore 18 | 19 | from replicate.resource import Resource 20 | 21 | T = TypeVar("T", bound=Resource) 22 | 23 | if TYPE_CHECKING: 24 | pass 25 | 26 | 27 | class Page(pydantic.BaseModel, Generic[T]): # type: ignore 28 | """ 29 | A page of results from the API. 30 | """ 31 | 32 | previous: Optional[str] = None 33 | """A pointer to the previous page of results""" 34 | 35 | next: Optional[str] = None 36 | """A pointer to the next page of results""" 37 | 38 | results: List[T] 39 | """The results on this page""" 40 | 41 | def __iter__(self): # noqa: ANN204 42 | return iter(self.results) 43 | 44 | def __getitem__(self, index: int) -> T: 45 | return self.results[index] 46 | 47 | def __len__(self) -> int: 48 | return len(self.results) 49 | 50 | 51 | def paginate( 52 | list_method: Callable[[Union[str, "ellipsis", None]], Page[T]], # noqa: F821 53 | ) -> Generator[Page[T], None, None]: 54 | """ 55 | Iterate over all items using the provided list method. 56 | 57 | Args: 58 | list_method: A method that takes a cursor argument and returns a Page of items. 59 | """ 60 | cursor: Union[str, "ellipsis", None] = ... # noqa: F821 61 | while cursor is not None: 62 | page = list_method(cursor) 63 | yield page 64 | cursor = page.next 65 | 66 | 67 | async def async_paginate( 68 | list_method: Callable[[Union[str, "ellipsis", None]], Awaitable[Page[T]]], # noqa: F821 69 | ) -> AsyncGenerator[Page[T], None]: 70 | """ 71 | Asynchronously iterate over all items using the provided list method. 72 | 73 | Args: 74 | list_method: An async method that takes a cursor argument and returns a Page of items. 75 | """ 76 | cursor: Union[str, "ellipsis", None] = ... # noqa: F821 77 | while cursor is not None: 78 | page = await list_method(cursor) 79 | yield page 80 | cursor = page.next 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "replicate" 7 | version = "0.25.2" 8 | description = "Python client for Replicate" 9 | readme = "README.md" 10 | license = { file = "LICENSE" } 11 | authors = [{ name = "Replicate, Inc." 
}] 12 | requires-python = ">=3.8" 13 | dependencies = [ 14 | "httpx>=0.21.0,<1", 15 | "packaging", 16 | "pydantic>1.10.7", 17 | "typing_extensions>=4.5.0", 18 | ] 19 | optional-dependencies = { dev = [ 20 | "pylint", 21 | "pyright", 22 | "pytest", 23 | "pytest-asyncio", 24 | "pytest-recording", 25 | "respx", 26 | "ruff>=0.3.3", 27 | ] } 28 | 29 | [project.urls] 30 | homepage = "https://replicate.com" 31 | repository = "https://github.com/replicate/replicate-python" 32 | 33 | [tool.pytest.ini_options] 34 | testpaths = "tests/" 35 | 36 | [tool.setuptools] 37 | packages = ["replicate"] 38 | 39 | [tool.setuptools.package-data] 40 | "replicate" = ["py.typed"] 41 | 42 | [tool.pylint.main] 43 | disable = [ 44 | "C0301", # Line too long 45 | "C0413", # Import should be placed at the top of the module 46 | "C0114", # Missing module docstring 47 | "R0801", # Similar lines in N files 48 | "W0212", # Access to a protected member 49 | "W0622", # Redefining built-in 50 | "R0903", # Too few public methods 51 | ] 52 | good-names = ["id"] 53 | 54 | [tool.ruff.lint] 55 | select = [ 56 | "E", # pycodestyle error 57 | "F", # Pyflakes 58 | "I", # isort 59 | "W", # pycodestyle warning 60 | "UP", # pyupgrade 61 | "S", # flake8-bandit 62 | "BLE", # flake8-blind-except 63 | "FBT", # flake8-boolean-trap 64 | "B", # flake8-bugbear 65 | "ANN", # flake8-annotations 66 | ] 67 | ignore = [ 68 | "E501", # Line too long 69 | "S113", # Probable use of requests call without timeout 70 | "ANN001", # Missing type annotation for function argument 71 | "ANN002", # Missing type annotation for `*args` 72 | "ANN003", # Missing type annotation for `**kwargs` 73 | "ANN101", # Missing type annotation for self in method 74 | "ANN102", # Missing type annotation for cls in classmethod 75 | "W191", # Indentation contains tabs 76 | ] 77 | 78 | [tool.ruff.lint.per-file-ignores] 79 | "tests/*" = [ 80 | "S101", # Use of assert 81 | "S106", # Possible use of hard-coded password function arguments 82 | "ANN201", # Missing return type annotation for public function 83 | ] 84 | -------------------------------------------------------------------------------- /tests/cassettes/trainings-create.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: 4 | '{"input": {"input_images": "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip", 5 | "use_face_detection_instead": true}, "destination": "replicate/dreambooth-sdxl"}' 6 | headers: 7 | accept: 8 | - "*/*" 9 | accept-encoding: 10 | - gzip, deflate 11 | connection: 12 | - keep-alive 13 | content-length: 14 | - "196" 15 | content-type: 16 | - application/json 17 | host: 18 | - api.replicate.com 19 | user-agent: 20 | - replicate-python/0.11.0 21 | method: POST 22 | uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b/trainings 23 | response: 24 | content: 25 | '{"id":"wj4is6lbdqqkdepnwpr6v2kuva","model":"stability-ai/sdxl","version":"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":"","error":null,"status":"starting","created_at":"2023-08-16T19:33:26.678378653Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/wj4is6lbdqqkdepnwpr6v2kuva/cancel","get":"https://api.replicate.com/v1/predictions/wj4is6lbdqqkdepnwpr6v2kuva"}} 26 | 27 | ' 
28 | headers: 29 | CF-Cache-Status: 30 | - DYNAMIC 31 | CF-RAY: 32 | - 7f7c1be93b9b279c-SEA 33 | Connection: 34 | - keep-alive 35 | Content-Type: 36 | - application/json 37 | Date: 38 | - Wed, 16 Aug 2023 19:33:26 GMT 39 | NEL: 40 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 41 | Report-To: 42 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=EaGGxbTVzjLhQ1yuc9h2fjPpejc%2F7r%2BTp6OVNUv8otlmgnXW3NxEiamWXhkk%2BSFO58xMQNdwHIEpjT1R%2ByLQYt2GKSF8IMyGHqLizpYmi8%2Bi5qBev1M9JV4hj2UpAXMBfsti"}],"group":"cf-nel","max_age":604800}' 43 | Server: 44 | - cloudflare 45 | Strict-Transport-Security: 46 | - max-age=15552000 47 | Transfer-Encoding: 48 | - chunked 49 | ratelimit-remaining: 50 | - "2999" 51 | ratelimit-reset: 52 | - "1" 53 | via: 54 | - 1.1 google 55 | http_version: HTTP/1.1 56 | status_code: 201 57 | version: 1 58 | -------------------------------------------------------------------------------- /tests/cassettes/hardware-list.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.15.5 15 | method: GET 16 | uri: https://api.replicate.com/v1/hardware 17 | response: 18 | content: '[{"sku":"cpu","name":"CPU"},{"sku":"gpu-t4","name":"Nvidia T4 GPU"},{"sku":"gpu-a40-small","name":"Nvidia 19 | A40 GPU"},{"sku":"gpu-a40-large","name":"Nvidia A40 (Large) GPU"}]' 20 | headers: 21 | CF-Cache-Status: 22 | - DYNAMIC 23 | CF-RAY: 24 | - 81fbfed29fe1c58a-SEA 25 | Connection: 26 | - keep-alive 27 | Content-Encoding: 28 | - gzip 29 | Content-Type: 30 | - application/json 31 | Date: 32 | - Thu, 02 Nov 2023 11:21:41 GMT 33 | Server: 34 | - cloudflare 35 | Strict-Transport-Security: 36 | - max-age=15552000 37 | Transfer-Encoding: 38 | - chunked 39 | allow: 40 | - OPTIONS, GET 41 | content-security-policy-report-only: 42 | - 'connect-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery 43 | https://*.rudderlabs.com https://*.rudderstack.com https://*.mux.com https://*.sentry.io; 44 | worker-src ''none''; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; 45 | style-src ''report-sample'' ''self'' ''unsafe-inline''; font-src ''report-sample'' 46 | ''self'' data:; img-src ''report-sample'' ''self'' data: https://replicate.delivery 47 | https://*.replicate.delivery https://*.githubusercontent.com https://github.com; 48 | default-src ''self''; media-src ''report-sample'' ''self'' https://replicate.delivery 49 | https://*.replicate.delivery https://*.mux.com https://*.sentry.io; report-uri' 50 | cross-origin-opener-policy: 51 | - same-origin 52 | nel: 53 | - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' 54 | ratelimit-remaining: 55 | - '2999' 56 | ratelimit-reset: 57 | - '1' 58 | referrer-policy: 59 | - same-origin 60 | report-to: 61 | - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D"}]}' 62 | reporting-endpoints: 63 | - heroku-nel=https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D 64 | vary: 65 | - Cookie, origin 66 | via: 67 | - 1.1 
vegur, 1.1 google 68 | x-content-type-options: 69 | - nosniff 70 | x-frame-options: 71 | - DENY 72 | http_version: HTTP/1.1 73 | status_code: 200 74 | version: 1 75 | -------------------------------------------------------------------------------- /replicate/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import httpx 4 | 5 | 6 | class ReplicateException(Exception): 7 | """A base class for all Replicate exceptions.""" 8 | 9 | 10 | class ModelError(ReplicateException): 11 | """An error from user's code in a model.""" 12 | 13 | 14 | class ReplicateError(ReplicateException): 15 | """ 16 | An error from Replicate's API. 17 | 18 | This class represents a problem details response as defined in RFC 7807. 19 | """ 20 | 21 | type: Optional[str] 22 | """A URI that identifies the error type.""" 23 | 24 | title: Optional[str] 25 | """A short, human-readable summary of the error.""" 26 | 27 | status: Optional[int] 28 | """The HTTP status code.""" 29 | 30 | detail: Optional[str] 31 | """A human-readable explanation specific to this occurrence of the error.""" 32 | 33 | instance: Optional[str] 34 | """A URI that identifies the specific occurrence of the error.""" 35 | 36 | def __init__( # pylint: disable=too-many-arguments 37 | self, 38 | type: Optional[str] = None, 39 | title: Optional[str] = None, 40 | status: Optional[int] = None, 41 | detail: Optional[str] = None, 42 | instance: Optional[str] = None, 43 | ) -> None: 44 | self.type = type 45 | self.title = title 46 | self.status = status 47 | self.detail = detail 48 | self.instance = instance 49 | 50 | @classmethod 51 | def from_response(cls, response: httpx.Response) -> "ReplicateError": 52 | """Create a ReplicateError from an HTTP response.""" 53 | 54 | try: 55 | data = response.json() 56 | except ValueError: 57 | data = {} 58 | 59 | return cls( 60 | type=data.get("type"), 61 | title=data.get("title"), 62 | detail=data.get("detail"), 63 | status=response.status_code, 64 | instance=data.get("instance"), 65 | ) 66 | 67 | def to_dict(self) -> dict: 68 | """Get a dictionary representation of the error.""" 69 | 70 | return { 71 | key: value 72 | for key, value in { 73 | "type": self.type, 74 | "title": self.title, 75 | "status": self.status, 76 | "detail": self.detail, 77 | "instance": self.instance, 78 | }.items() 79 | if value is not None 80 | } 81 | 82 | def __str__(self) -> str: 83 | return "ReplicateError Details:\n" + "\n".join( 84 | [f"{key}: {value}" for key, value in self.to_dict().items()] 85 | ) 86 | 87 | def __repr__(self) -> str: 88 | class_name = self.__class__.__name__ 89 | params = ", ".join( 90 | [ 91 | f"type={repr(self.type)}", 92 | f"title={repr(self.title)}", 93 | f"status={repr(self.status)}", 94 | f"detail={repr(self.detail)}", 95 | f"instance={repr(self.instance)}", 96 | ] 97 | ) 98 | return f"{class_name}({params})" 99 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | import respx 4 | 5 | from replicate.client import Client 6 | 7 | router = respx.Router(base_url="https://api.replicate.com/v1") 8 | 9 | router.route( 10 | method="GET", 11 | path="/models/replicate/hello-world", 12 | name="models.get", 13 | ).mock( 14 | return_value=httpx.Response( 15 | 200, 16 | json={ 17 | "owner": "replicate", 18 | "name": "hello-world", 19 | "description": "A tiny model that says 
hello", 20 | "visibility": "public", 21 | "run_count": 1e10, 22 | "url": "https://replicate.com/replicate/hello-world", 23 | "created_at": "2022-04-26T19:13:45.911328Z", 24 | "latest_version": { 25 | "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 26 | "cog_version": "0.3.0", 27 | "openapi_schema": { 28 | "openapi": "3.0.2", 29 | "info": {"title": "Cog", "version": "0.1.0"}, 30 | "components": { 31 | "schemas": { 32 | "Input": { 33 | "type": "object", 34 | "title": "Input", 35 | "required": ["text"], 36 | "properties": { 37 | "text": { 38 | "type": "string", 39 | "title": "Text", 40 | "x-order": 0, 41 | "description": "Text to prefix with 'hello '", 42 | } 43 | }, 44 | }, 45 | "Output": {"type": "string", "title": "Output"}, 46 | } 47 | }, 48 | }, 49 | "created_at": "2022-04-26T19:29:04.418669Z", 50 | }, 51 | }, 52 | ) 53 | ) 54 | 55 | router.route( 56 | method="DELETE", 57 | path__regex=r"^/models/replicate/hello-world/versions/(?P\w+)/?", 58 | name="models.versions.delete", 59 | ).mock( 60 | return_value=httpx.Response( 61 | 202, 62 | ) 63 | ) 64 | 65 | 66 | @pytest.mark.asyncio 67 | @pytest.mark.parametrize("async_flag", [True, False]) 68 | async def test_version_delete(async_flag): 69 | client = Client( 70 | api_token="test-token", transport=httpx.MockTransport(router.handler) 71 | ) 72 | 73 | if async_flag: 74 | model = await client.models.async_get("replicate/hello-world") 75 | assert model is not None 76 | assert model.latest_version is not None 77 | 78 | await model.versions.async_delete(model.latest_version.id) 79 | else: 80 | model = client.models.get("replicate/hello-world") 81 | assert model is not None 82 | assert model.latest_version is not None 83 | 84 | model.versions.delete(model.latest_version.id) 85 | 86 | assert router["models.get"].called 87 | assert router["models.versions.delete"].called 88 | -------------------------------------------------------------------------------- /tests/cassettes/models-create.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '{"owner": "test", "name": "python-example", "visibility": "private", "hardware": 4 | "cpu", "description": "An example model"}' 5 | headers: 6 | accept: 7 | - '*/*' 8 | accept-encoding: 9 | - gzip, deflate 10 | connection: 11 | - keep-alive 12 | content-length: 13 | - '123' 14 | content-type: 15 | - application/json 16 | host: 17 | - api.replicate.com 18 | user-agent: 19 | - replicate-python/0.15.6 20 | method: POST 21 | uri: https://api.replicate.com/v1/models 22 | response: 23 | content: '{"url": "https://replicate.com/test/python-example", "owner": "test", 24 | "name": "python-example", "description": "An example model", "visibility": "private", 25 | "github_url": null, "paper_url": null, "license_url": null, "run_count": 0, 26 | "cover_image_url": null, "default_example": null, "latest_version": null}' 27 | headers: 28 | CF-Cache-Status: 29 | - DYNAMIC 30 | CF-RAY: 31 | - 81ff2e098ec0eb5b-SEA 32 | Connection: 33 | - keep-alive 34 | Content-Length: 35 | - '307' 36 | Content-Type: 37 | - application/json 38 | Date: 39 | - Thu, 02 Nov 2023 20:38:12 GMT 40 | Server: 41 | - cloudflare 42 | Strict-Transport-Security: 43 | - max-age=15552000 44 | allow: 45 | - GET, POST, HEAD, OPTIONS 46 | content-security-policy-report-only: 47 | - 'font-src ''report-sample'' ''self'' data:; img-src ''report-sample'' ''self'' 48 | data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com 49 | 
https://github.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; 50 | style-src ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample'' 51 | ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com 52 | https://*.rudderstack.com https://*.mux.com https://*.sentry.io; worker-src 53 | ''none''; media-src ''report-sample'' ''self'' https://replicate.delivery 54 | https://*.replicate.delivery https://*.mux.com https://*.sentry.io; default-src 55 | ''self''; report-uri' 56 | cross-origin-opener-policy: 57 | - same-origin 58 | nel: 59 | - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' 60 | ratelimit-remaining: 61 | - '2999' 62 | ratelimit-reset: 63 | - '1' 64 | referrer-policy: 65 | - same-origin 66 | report-to: 67 | - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D"}]}' 68 | reporting-endpoints: 69 | - heroku-nel=https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D 70 | vary: 71 | - Cookie, origin 72 | via: 73 | - 1.1 vegur, 1.1 google 74 | x-content-type-options: 75 | - nosniff 76 | x-frame-options: 77 | - DENY 78 | http_version: HTTP/1.1 79 | status_code: 201 80 | version: 1 81 | -------------------------------------------------------------------------------- /tests/cassettes/trainings-get.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: "" 4 | headers: 5 | accept: 6 | - "*/*" 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.11.0 15 | method: GET 16 | uri: https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte 17 | response: 18 | content: '{"completed_at":null,"created_at":"2023-08-16T19:33:26.906823Z","error":null,"id":"medrnz3bm5dd6ultvad2tejrte","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{},"output":null,"started_at":"2023-08-16T19:33:42.114513Z","status":"processing","urls":{"get":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte","cancel":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte/cancel"},"model":"stability-ai/sdxl","version":"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b","webhook_completed":null}' 19 | headers: 20 | CF-Cache-Status: 21 | - DYNAMIC 22 | CF-RAY: 23 | - 7f7c1beaedff279c-SEA 24 | Connection: 25 | - keep-alive 26 | Content-Encoding: 27 | - gzip 28 | Content-Type: 29 | - application/json 30 | Date: 31 | - Wed, 16 Aug 2023 19:33:26 GMT 32 | NEL: 33 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 34 | Report-To: 35 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=SntiwLHCR4wiv49Qmn%2BR1ZblcX%2FgoVlIgsek4yZliZiWts2SqPjqTjrSkB%2Bwch8oHqR%2BBNVs1cSbihlHd8MWPXsbwC2uShz0c6tD4nclaecblb3FnEp4Mccy9hlZ39izF9Tm"}],"group":"cf-nel","max_age":604800}' 36 | Server: 37 | - cloudflare 38 | Strict-Transport-Security: 39 | - max-age=15552000 40 | Transfer-Encoding: 41 | - chunked 42 | allow: 43 | - OPTIONS, GET 44 | content-security-policy-report-only: 45 | - 
"style-src 'report-sample' 'self' 'unsafe-inline' https://fonts.googleapis.com; 46 | img-src 'report-sample' 'self' data: https://replicate.delivery https://*.replicate.delivery 47 | https://*.githubusercontent.com https://github.com; worker-src 'none'; media-src 48 | 'report-sample' 'self' https://replicate.delivery https://*.replicate.delivery 49 | https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src 'report-sample' 50 | 'self' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com 51 | https://*.rudderstack.com https://*.mux.com https://*.sentry.io; script-src 52 | 'report-sample' 'self' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; 53 | font-src 'report-sample' 'self' data: https://fonts.replicate.ai https://fonts.gstatic.com; 54 | default-src 'self'; report-uri" 55 | cross-origin-opener-policy: 56 | - same-origin 57 | ratelimit-remaining: 58 | - "2999" 59 | ratelimit-reset: 60 | - "1" 61 | referrer-policy: 62 | - same-origin 63 | vary: 64 | - Cookie, origin 65 | via: 66 | - 1.1 vegur, 1.1 google 67 | x-content-type-options: 68 | - nosniff 69 | x-frame-options: 70 | - DENY 71 | http_version: HTTP/1.1 72 | status_code: 200 73 | version: 1 74 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import mock 3 | 4 | import httpx 5 | import pytest 6 | import respx 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_authorization_when_setting_environ_after_import(): 11 | import replicate 12 | 13 | router = respx.Router() 14 | router.route( 15 | method="GET", 16 | url="https://api.replicate.com/", 17 | headers={"Authorization": "Bearer test-set-after-import"}, 18 | ).mock( 19 | return_value=httpx.Response( 20 | 200, 21 | json={}, 22 | ) 23 | ) 24 | 25 | token = "test-set-after-import" # noqa: S105 26 | 27 | with mock.patch.dict( 28 | os.environ, 29 | {"REPLICATE_API_TOKEN": token}, 30 | ): 31 | client = replicate.Client(transport=httpx.MockTransport(router.handler)) 32 | resp = client._request("GET", "/") 33 | assert resp.status_code == 200 34 | 35 | 36 | @pytest.mark.asyncio 37 | async def test_client_error_handling(): 38 | import replicate 39 | from replicate.exceptions import ReplicateError 40 | 41 | router = respx.Router() 42 | router.route( 43 | method="GET", 44 | url="https://api.replicate.com/", 45 | headers={"Authorization": "Bearer test-client-error"}, 46 | ).mock( 47 | return_value=httpx.Response( 48 | 400, 49 | json={"detail": "Client error occurred"}, 50 | ) 51 | ) 52 | 53 | token = "test-client-error" # noqa: S105 54 | 55 | with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): 56 | client = replicate.Client(transport=httpx.MockTransport(router.handler)) 57 | with pytest.raises(ReplicateError) as exc_info: 58 | client._request("GET", "/") 59 | assert "status: 400" in str(exc_info.value) 60 | assert "detail: Client error occurred" in str(exc_info.value) 61 | 62 | 63 | @pytest.mark.asyncio 64 | async def test_server_error_handling(): 65 | import replicate 66 | from replicate.exceptions import ReplicateError 67 | 68 | router = respx.Router() 69 | router.route( 70 | method="GET", 71 | url="https://api.replicate.com/", 72 | headers={"Authorization": "Bearer test-server-error"}, 73 | ).mock( 74 | return_value=httpx.Response( 75 | 500, 76 | json={"detail": "Server error occurred"}, 77 | ) 78 | ) 79 | 80 | token = "test-server-error" # noqa: S105 81 | 82 | 
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): 83 | client = replicate.Client(transport=httpx.MockTransport(router.handler)) 84 | with pytest.raises(ReplicateError) as exc_info: 85 | client._request("GET", "/") 86 | assert "status: 500" in str(exc_info.value) 87 | assert "detail: Server error occurred" in str(exc_info.value) 88 | 89 | 90 | def test_custom_headers_are_applied(): 91 | import replicate 92 | from replicate.exceptions import ReplicateError 93 | 94 | custom_headers = {"Custom-Header": "CustomValue"} 95 | 96 | def mock_send(request: httpx.Request, **kwargs) -> httpx.Response: 97 | assert "Custom-Header" in request.headers 98 | assert request.headers["Custom-Header"] == "CustomValue" 99 | 100 | return httpx.Response(401, json={}) 101 | 102 | client = replicate.Client( 103 | api_token="dummy_token", 104 | headers=custom_headers, 105 | transport=httpx.MockTransport(mock_send), 106 | ) 107 | 108 | try: 109 | client.accounts.current() 110 | except ReplicateError: 111 | pass 112 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.vcr("models-get.yaml") 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_models_get(async_flag): 10 | if async_flag: 11 | sdxl = await replicate.models.async_get("stability-ai/sdxl") 12 | else: 13 | sdxl = replicate.models.get("stability-ai/sdxl") 14 | 15 | assert sdxl is not None 16 | assert sdxl.owner == "stability-ai" 17 | assert sdxl.name == "sdxl" 18 | assert sdxl.visibility == "public" 19 | 20 | if async_flag: 21 | empty = await replicate.models.async_get("mattt/empty") 22 | else: 23 | empty = replicate.models.get("mattt/empty") 24 | 25 | assert empty.default_example is None 26 | 27 | 28 | @pytest.mark.vcr("models-list.yaml") 29 | @pytest.mark.asyncio 30 | @pytest.mark.parametrize("async_flag", [True, False]) 31 | async def test_models_list(async_flag): 32 | if async_flag: 33 | models = await replicate.models.async_list() 34 | else: 35 | models = replicate.models.list() 36 | 37 | assert len(models) > 0 38 | assert models[0].owner is not None 39 | assert models[0].name is not None 40 | assert models[0].visibility == "public" 41 | 42 | 43 | @pytest.mark.vcr("models-list__pagination.yaml") 44 | @pytest.mark.asyncio 45 | @pytest.mark.parametrize("async_flag", [True, False]) 46 | async def test_models_list_pagination(async_flag): 47 | if async_flag: 48 | page1 = await replicate.models.async_list() 49 | else: 50 | page1 = replicate.models.list() 51 | assert len(page1) > 0 52 | assert page1.next is not None 53 | 54 | if async_flag: 55 | page2 = await replicate.models.async_list(cursor=page1.next) 56 | else: 57 | page2 = replicate.models.list(cursor=page1.next) 58 | assert len(page2) > 0 59 | assert page2.previous is not None 60 | 61 | 62 | @pytest.mark.vcr("models-create.yaml") 63 | @pytest.mark.asyncio 64 | @pytest.mark.parametrize("async_flag", [True, False]) 65 | async def test_models_create(async_flag): 66 | if async_flag: 67 | model = await replicate.models.async_create( 68 | owner="test", 69 | name="python-example", 70 | visibility="private", 71 | hardware="cpu", 72 | description="An example model", 73 | ) 74 | else: 75 | model = replicate.models.create( 76 | owner="test", 77 | name="python-example", 78 | visibility="private", 79 | hardware="cpu", 80 | description="An example model", 81 
| ) 82 | 83 | assert model.owner == "test" 84 | assert model.name == "python-example" 85 | assert model.visibility == "private" 86 | 87 | 88 | @pytest.mark.vcr("models-create.yaml") 89 | @pytest.mark.asyncio 90 | @pytest.mark.parametrize("async_flag", [True, False]) 91 | async def test_models_create_with_positional_arguments(async_flag): 92 | if async_flag: 93 | model = await replicate.models.async_create( 94 | "test", 95 | "python-example", 96 | visibility="private", 97 | hardware="cpu", 98 | ) 99 | else: 100 | model = replicate.models.create( 101 | "test", 102 | "python-example", 103 | visibility="private", 104 | hardware="cpu", 105 | ) 106 | 107 | assert model.owner == "test" 108 | assert model.name == "python-example" 109 | assert model.visibility == "private" 110 | 111 | 112 | @pytest.mark.vcr("models-predictions-create.yaml") 113 | @pytest.mark.asyncio 114 | @pytest.mark.parametrize("async_flag", [True, False]) 115 | async def test_models_predictions_create(async_flag): 116 | input = { 117 | "prompt": "Please write a haiku about llamas.", 118 | } 119 | 120 | if async_flag: 121 | prediction = await replicate.models.predictions.async_create( 122 | "meta/llama-2-70b-chat", input=input 123 | ) 124 | else: 125 | prediction = replicate.models.predictions.create( 126 | "meta/llama-2-70b-chat", input=input 127 | ) 128 | 129 | assert prediction.id is not None 130 | # assert prediction.model == "meta/llama-2-70b-chat" 131 | assert prediction.model == "replicate/lifeboat-70b" # FIXME: this is temporary 132 | assert prediction.status == "starting" 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | - [Making a contribution](#making-a-contribution) 4 | - [Signing your work](#signing-your-work) 5 | - [How to sign off your commits](#how-to-sign-off-your-commits) 6 | - [Development](#development) 7 | - [Environment variables](#environment-variables) 8 | - [Publishing a release](#publishing-a-release) 9 | 10 | ## Making a contribution 11 | 12 | ### Signing your work 13 | 14 | Each commit you contribute to this repo must be signed off (not to be confused with **[signing](https://git-scm.com/book/en/v2/Git-Tools-Signing-Your-Work)**). It certifies that you wrote the patch, or have the right to contribute it. It is called the [Developer Certificate of Origin](https://developercertificate.org/) and was originally developed for the Linux kernel. 15 | 16 | If you can certify the following: 17 | 18 | ``` 19 | By making a contribution to this project, I certify that: 20 | 21 | (a) The contribution was created in whole or in part by me and I 22 | have the right to submit it under the open source license 23 | indicated in the file; or 24 | 25 | (b) The contribution is based upon previous work that, to the best 26 | of my knowledge, is covered under an appropriate open source 27 | license and I have the right under that license to submit that 28 | work with modifications, whether created in whole or in part 29 | by me, under the same open source license (unless I am 30 | permitted to submit under a different license), as indicated 31 | in the file; or 32 | 33 | (c) The contribution was provided directly to me by some other 34 | person who certified (a), (b) or (c) and I have not modified 35 | it. 
36 | 37 | (d) I understand and agree that this project and the contribution 38 | are public and that a record of the contribution (including all 39 | personal information I submit with it, including my sign-off) is 40 | maintained indefinitely and may be redistributed consistent with 41 | this project or the open source license(s) involved. 42 | ``` 43 | 44 | Then add this line to each of your Git commit messages, with your name and email: 45 | 46 | ``` 47 | Signed-off-by: Sam Smith <sam.smith@example.com> 48 | ``` 49 | 50 | ### How to sign off your commits 51 | 52 | If you're using the `git` CLI, you can sign off a commit by passing the `-s` option: `git commit -s -m "Reticulate splines"` 53 | 54 | You can also create a git hook which will sign off all your commits automatically. Using hooks also allows you to sign off commits when using non-command-line tools like GitHub Desktop or VS Code. 55 | 56 | First, create the hook file and make it executable: 57 | 58 | ```sh 59 | cd your/checkout/of/replicate-python 60 | touch .git/hooks/prepare-commit-msg 61 | chmod +x .git/hooks/prepare-commit-msg 62 | ``` 63 | 64 | Then paste the following into the file: 65 | 66 | ``` 67 | #!/bin/sh 68 | 69 | NAME=$(git config user.name) 70 | EMAIL=$(git config user.email) 71 | 72 | if [ -z "$NAME" ]; then 73 | echo "empty git config user.name" 74 | exit 1 75 | fi 76 | 77 | if [ -z "$EMAIL" ]; then 78 | echo "empty git config user.email" 79 | exit 1 80 | fi 81 | 82 | git interpret-trailers --if-exists doNothing --trailer \ 83 | "Signed-off-by: $NAME <$EMAIL>" \ 84 | --in-place "$1" 85 | ``` 86 | 87 | ## Development 88 | 89 | To run the tests: 90 | 91 | ```sh 92 | pip install -r requirements-dev.txt 93 | pytest 94 | ``` 95 | 96 | To install the package in development mode: 97 | 98 | ```sh 99 | pip install -e . 100 | ``` 101 | 102 | ### Environment variables 103 | 104 | - `REPLICATE_API_BASE_URL`: Defaults to `https://api.replicate.com` but can be overridden to point the client at a development host. 105 | - `REPLICATE_API_TOKEN`: Required. Find your token at https://replicate.com/#token (a smoke-test sketch that uses these variables appears at the end of this guide) 106 | 107 | ## Publishing a release 108 | 109 | This project has a [GitHub Actions workflow](/.github/workflows/release.yaml) that publishes the `replicate` package to PyPI. The release process is triggered by manually creating and pushing a new git tag. 110 | 111 | First, set the version number in [pyproject.toml](pyproject.toml) and commit it to the `main` branch: 112 | 113 | ``` 114 | version = "0.7.0" 115 | ``` 116 | 117 | Then run the following in your local checkout: 118 | 119 | ```sh 120 | git checkout main 121 | git fetch --all --tags 122 | git tag 0.7.0 123 | git push --tags 124 | ``` 125 | 126 | Then visit [github.com/replicate/replicate-python/actions](https://github.com/replicate/replicate-python/actions) to monitor the release process.
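Beyond the unit tests, it can be useful to exercise an editable install against the live API. The snippet below is only a sketch (the `smoke_test.py` name is arbitrary and the file is not part of this repository): it assumes `REPLICATE_API_TOKEN` is set as described above, and it reuses the public `stability-ai/sdxl` version that the test cassettes record. Running it creates a real prediction, which may incur cost.

```python
# smoke_test.py: an illustrative sketch, not part of the repository or its test suite.
import os

import replicate

# The client reads REPLICATE_API_TOKEN from the environment; per the section above,
# REPLICATE_API_BASE_URL can optionally point it at a development host.
assert os.environ.get("REPLICATE_API_TOKEN"), "Set REPLICATE_API_TOKEN first"

# A cheap authenticated call: fetch a page of public models.
models = replicate.models.list()
print(f"Fetched {len(models)} models; first is {models[0].owner}/{models[0].name}")

# An end-to-end run, using the same public SDXL version the test suite records.
# Note: this creates a real prediction on Replicate and may incur cost.
output = replicate.run(
    "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={"prompt": "a studio photo of a rainbow colored corgi"},
)
print(output)
```

Run it with `python smoke_test.py` after `pip install -e .`; both calls hit the live API.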
127 | -------------------------------------------------------------------------------- /replicate/run.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | AsyncIterator, 5 | Dict, 6 | Iterator, 7 | List, 8 | Optional, 9 | Union, 10 | ) 11 | 12 | from typing_extensions import Unpack 13 | 14 | from replicate import identifier 15 | from replicate.exceptions import ModelError 16 | from replicate.model import Model 17 | from replicate.prediction import Prediction 18 | from replicate.schema import make_schema_backwards_compatible 19 | from replicate.version import Version, Versions 20 | 21 | if TYPE_CHECKING: 22 | from replicate.client import Client 23 | from replicate.identifier import ModelVersionIdentifier 24 | from replicate.prediction import Predictions 25 | 26 | 27 | def run( 28 | client: "Client", 29 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 30 | input: Optional[Dict[str, Any]] = None, 31 | **params: Unpack["Predictions.CreatePredictionParams"], 32 | ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 33 | """ 34 | Run a model and wait for its output. 35 | """ 36 | 37 | version, owner, name, version_id = identifier._resolve(ref) 38 | 39 | if version_id is not None: 40 | prediction = client.predictions.create( 41 | version=version_id, input=input or {}, **params 42 | ) 43 | elif owner and name: 44 | prediction = client.models.predictions.create( 45 | model=(owner, name), input=input or {}, **params 46 | ) 47 | else: 48 | raise ValueError( 49 | f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" 50 | ) 51 | 52 | if not version and (owner and name and version_id): 53 | version = Versions(client, model=(owner, name)).get(version_id) 54 | 55 | if version and (iterator := _make_output_iterator(version, prediction)): 56 | return iterator 57 | 58 | prediction.wait() 59 | 60 | if prediction.status == "failed": 61 | raise ModelError(prediction.error) 62 | 63 | return prediction.output 64 | 65 | 66 | async def async_run( 67 | client: "Client", 68 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 69 | input: Optional[Dict[str, Any]] = None, 70 | **params: Unpack["Predictions.CreatePredictionParams"], 71 | ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 72 | """ 73 | Run a model and wait for its output asynchronously. 74 | """ 75 | 76 | version, owner, name, version_id = identifier._resolve(ref) 77 | 78 | if version or version_id: 79 | prediction = await client.predictions.async_create( 80 | version=(version or version_id), input=input or {}, **params 81 | ) 82 | elif owner and name: 83 | prediction = await client.models.predictions.async_create( 84 | model=(owner, name), input=input or {}, **params 85 | ) 86 | else: 87 | raise ValueError( 88 | f"Invalid argument: {ref}. 
Expected model, version, or reference in the format owner/name or owner/name:version" 89 | ) 90 | 91 | if not version and (owner and name and version_id): 92 | version = await Versions(client, model=(owner, name)).async_get(version_id) 93 | 94 | if version and (iterator := _make_async_output_iterator(version, prediction)): 95 | return iterator 96 | 97 | await prediction.async_wait() 98 | 99 | if prediction.status == "failed": 100 | raise ModelError(prediction.error) 101 | 102 | return prediction.output 103 | 104 | 105 | def _has_output_iterator_array_type(version: Version) -> bool: 106 | schema = make_schema_backwards_compatible( 107 | version.openapi_schema, version.cog_version 108 | ) 109 | output = schema.get("components", {}).get("schemas", {}).get("Output", {}) 110 | return ( 111 | output.get("type") == "array" and output.get("x-cog-array-type") == "iterator" 112 | ) 113 | 114 | 115 | def _make_output_iterator( 116 | version: Version, prediction: Prediction 117 | ) -> Optional[Iterator[Any]]: 118 | if _has_output_iterator_array_type(version): 119 | return prediction.output_iterator() 120 | 121 | return None 122 | 123 | 124 | def _make_async_output_iterator( 125 | version: Version, prediction: Prediction 126 | ) -> Optional[AsyncIterator[Any]]: 127 | if _has_output_iterator_array_type(version): 128 | return prediction.async_output_iterator() 129 | 130 | return None 131 | 132 | 133 | __all__: List = [] 134 | -------------------------------------------------------------------------------- /tests/cassettes/test_predictions_create_by_model[True].yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '{"input": {"prompt": "write a haiku about llamas"}}' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | content-length: 12 | - '51' 13 | content-type: 14 | - application/json 15 | host: 16 | - api.replicate.com 17 | user-agent: 18 | - replicate-python/0.25.1 19 | method: POST 20 | uri: https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions 21 | response: 22 | content: '{"id":"vpx8dks2pnrgg0cf0p2b7p13hc","model":"replicate-internal/llama-3-8b-instruct-int8-triton","version":"dp-a557b7387b4940df25b23f779dc534c4","input":{"prompt":"write 23 | a haiku about llamas"},"logs":"","error":null,"status":"starting","created_at":"2024-04-22T11:14:56.821Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/vpx8dks2pnrgg0cf0p2b7p13hc/cancel","get":"https://api.replicate.com/v1/predictions/vpx8dks2pnrgg0cf0p2b7p13hc"}} 24 | 25 | ' 26 | headers: 27 | CF-Cache-Status: 28 | - DYNAMIC 29 | CF-RAY: 30 | - 878531709ab476d6-SEA 31 | Connection: 32 | - keep-alive 33 | Content-Length: 34 | - '447' 35 | Content-Type: 36 | - application/json; charset=utf-8 37 | Date: 38 | - Mon, 22 Apr 2024 11:14:56 GMT 39 | NEL: 40 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 41 | Report-To: 42 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=PGxIXq%2Fp%2FXsMcmw0GyjVGenwcrvsBa%2FAasnwL8Mm8c0ODadYfL9gsDVQMV0Cx%2Fg1xNNyWYL51rraqeV7xQkD9yj0UsgKYy70W4COulpbw6srEkT4TzSDdDOZA7Qo49eu2c5l"}],"group":"cf-nel","max_age":604800}' 43 | Server: 44 | - cloudflare 45 | Strict-Transport-Security: 46 | - max-age=15552000 47 | alt-svc: 48 | - h3=":443"; ma=86400 49 | ratelimit-remaining: 50 | - '599' 51 | ratelimit-reset: 52 | - '1' 53 | replicate-edge-cluster: 54 | - us-central1 55 | replicate-target-cluster: 56 | - coreweave-us 57 | via: 
58 | - 1.1 google 59 | http_version: HTTP/1.1 60 | status_code: 201 61 | - request: 62 | body: '{"input": {"prompt": "write a haiku about llamas"}}' 63 | headers: 64 | accept: 65 | - '*/*' 66 | accept-encoding: 67 | - gzip, deflate 68 | connection: 69 | - keep-alive 70 | content-length: 71 | - '51' 72 | content-type: 73 | - application/json 74 | host: 75 | - api.replicate.com 76 | user-agent: 77 | - replicate-python/0.25.1 78 | method: POST 79 | uri: https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions 80 | response: 81 | content: '{"id":"vcp9g11jexrgp0cf0p2bweyqyg","model":"replicate-internal/llama-3-8b-instruct-int8-triton","version":"dp-a557b7387b4940df25b23f779dc534c4","input":{"prompt":"write 82 | a haiku about llamas"},"logs":"","error":null,"status":"starting","created_at":"2024-04-22T11:15:00.855Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/vcp9g11jexrgp0cf0p2bweyqyg/cancel","get":"https://api.replicate.com/v1/predictions/vcp9g11jexrgp0cf0p2bweyqyg"}} 83 | 84 | ' 85 | headers: 86 | CF-Cache-Status: 87 | - DYNAMIC 88 | CF-RAY: 89 | - 87853189c84a9b79-SEA 90 | Connection: 91 | - keep-alive 92 | Content-Length: 93 | - '447' 94 | Content-Type: 95 | - application/json; charset=utf-8 96 | Date: 97 | - Mon, 22 Apr 2024 11:15:00 GMT 98 | NEL: 99 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 100 | Report-To: 101 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=3IETw%2FWj35BwKNZG7vaA8O981ehk76n1iTstHAvJCMknGcuVRg9jDcQ1HPM0yiWRrl4oongZAJV%2Bky%2B36ee8yyz6%2Bqudt99I5D3yz9tzxzYGEZkrq05TFIAsVsS%2BaA2Sexnz"}],"group":"cf-nel","max_age":604800}' 102 | Server: 103 | - cloudflare 104 | Strict-Transport-Security: 105 | - max-age=15552000 106 | alt-svc: 107 | - h3=":443"; ma=86400 108 | ratelimit-remaining: 109 | - '599' 110 | ratelimit-reset: 111 | - '1' 112 | replicate-edge-cluster: 113 | - us-central1 114 | replicate-target-cluster: 115 | - coreweave-us 116 | via: 117 | - 1.1 google 118 | http_version: HTTP/1.1 119 | status_code: 201 120 | version: 1 121 | -------------------------------------------------------------------------------- /replicate/collection.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterator, List, Optional, Union, overload 2 | 3 | from typing_extensions import deprecated 4 | 5 | from replicate.model import Model 6 | from replicate.pagination import Page 7 | from replicate.resource import Namespace, Resource 8 | 9 | 10 | class Collection(Resource): 11 | """ 12 | A collection of models on Replicate. 13 | """ 14 | 15 | slug: str 16 | """The slug used to identify the collection.""" 17 | 18 | name: str 19 | """The name of the collection.""" 20 | 21 | description: str 22 | """A description of the collection.""" 23 | 24 | models: Optional[List[Model]] = None 25 | """The models in the collection.""" 26 | 27 | @property 28 | @deprecated("Use `slug` instead of `id`") 29 | def id(self) -> str: 30 | """ 31 | DEPRECATED: Use `slug` instead. 32 | """ 33 | return self.slug 34 | 35 | def __iter__(self) -> Iterator[Model]: 36 | if self.models is not None: 37 | return iter(self.models) 38 | return iter([]) 39 | 40 | @overload 41 | def __getitem__(self, index: int) -> Optional[Model]: ... 42 | 43 | @overload 44 | def __getitem__(self, index: slice) -> Optional[List[Model]]: ... 
45 | 46 | def __getitem__( 47 | self, index: Union[int, slice] 48 | ) -> Union[Optional[Model], Optional[List[Model]]]: 49 | if self.models is not None: 50 | return self.models[index] 51 | return None 52 | 53 | def __len__(self) -> int: 54 | if self.models is not None: 55 | return len(self.models) 56 | 57 | return 0 58 | 59 | 60 | class Collections(Namespace): 61 | """ 62 | A namespace for operations related to collections of models. 63 | """ 64 | 65 | def list( 66 | self, 67 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 68 | ) -> Page[Collection]: 69 | """ 70 | List collections of models. 71 | 72 | Parameters: 73 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 74 | Returns: 75 | Page[Collection]: A page of model collections. 76 | Raises: 77 | ValueError: If `cursor` is `None`. 78 | """ 79 | 80 | if cursor is None: 81 | raise ValueError("cursor cannot be None") 82 | 83 | resp = self._client._request( 84 | "GET", "/v1/collections" if cursor is ... else cursor 85 | ) 86 | 87 | obj = resp.json() 88 | obj["results"] = [_json_to_collection(result) for result in obj["results"]] 89 | 90 | return Page[Collection](**obj) 91 | 92 | async def async_list( 93 | self, 94 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 95 | ) -> Page[Collection]: 96 | """ 97 | List collections of models. 98 | 99 | Parameters: 100 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 101 | Returns: 102 | Page[Collection]: A page of model collections. 103 | Raises: 104 | ValueError: If `cursor` is `None`. 105 | """ 106 | 107 | if cursor is None: 108 | raise ValueError("cursor cannot be None") 109 | 110 | resp = await self._client._async_request( 111 | "GET", "/v1/collections" if cursor is ... else cursor 112 | ) 113 | 114 | obj = resp.json() 115 | obj["results"] = [_json_to_collection(result) for result in obj["results"]] 116 | 117 | return Page[Collection](**obj) 118 | 119 | def get(self, slug: str) -> Collection: 120 | """Get a collection of models by slug. 121 | 122 | Args: 123 | slug: The slug of the collection, e.g. `super-resolution`. 124 | Returns: 125 | The collection of models. 126 | """ 127 | 128 | resp = self._client._request("GET", f"/v1/collections/{slug}") 129 | 130 | return _json_to_collection(resp.json()) 131 | 132 | async def async_get(self, slug: str) -> Collection: 133 | """Get a collection of models by slug. 134 | 135 | Args: 136 | slug: The slug of the collection, e.g. `super-resolution`. 137 | Returns: 138 | The collection of models.
139 | """ 140 | 141 | resp = await self._client._async_request("GET", f"/v1/collections/{slug}") 142 | 143 | return _json_to_collection(resp.json()) 144 | 145 | 146 | def _json_to_collection(json: Dict[str, Any]) -> Collection: 147 | return Collection(**json) 148 | -------------------------------------------------------------------------------- /tests/cassettes/trainings-cancel.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: 4 | '{"input": {"input_images": "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip", 5 | "use_face_detection_instead": true}, "destination": "replicate/dreambooth-sdxl"}' 6 | headers: 7 | accept: 8 | - "*/*" 9 | accept-encoding: 10 | - gzip, deflate 11 | connection: 12 | - keep-alive 13 | content-length: 14 | - "196" 15 | content-type: 16 | - application/json 17 | host: 18 | - api.replicate.com 19 | user-agent: 20 | - replicate-python/0.11.0 21 | method: POST 22 | uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b/trainings 23 | response: 24 | content: 25 | '{"id":"xxi2kydbnla3ocjp64sclkewku","model":"stability-ai/sdxl","version":"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":"","error":null,"status":"starting","created_at":"2023-08-16T19:33:27.07526337Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/xxi2kydbnla3ocjp64sclkewku/cancel","get":"https://api.replicate.com/v1/predictions/xxi2kydbnla3ocjp64sclkewku"}} 26 | 27 | ' 28 | headers: 29 | CF-Cache-Status: 30 | - DYNAMIC 31 | CF-RAY: 32 | - 7f7c1bebbefb279c-SEA 33 | Connection: 34 | - keep-alive 35 | Content-Type: 36 | - application/json 37 | Date: 38 | - Wed, 16 Aug 2023 19:33:27 GMT 39 | NEL: 40 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 41 | Report-To: 42 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=8nwXblNwwv%2B3MNruaLvo105IDtUra2KhqpdDIxWIzz0LXn8aAFXd3LaW1jyNse6YOfiuX5%2BFuBrA4%2FdFnF%2FmnqqTC58iRNJw8qlNjr8ZvnEXwJrJqxDEWa%2FtzJGC8wVt%2BQko"}],"group":"cf-nel","max_age":604800}' 43 | Server: 44 | - cloudflare 45 | Strict-Transport-Security: 46 | - max-age=15552000 47 | Transfer-Encoding: 48 | - chunked 49 | ratelimit-remaining: 50 | - "2999" 51 | ratelimit-reset: 52 | - "1" 53 | via: 54 | - 1.1 google 55 | http_version: HTTP/1.1 56 | status_code: 201 57 | - request: 58 | body: "" 59 | headers: 60 | accept: 61 | - "*/*" 62 | accept-encoding: 63 | - gzip, deflate 64 | connection: 65 | - keep-alive 66 | content-length: 67 | - "0" 68 | host: 69 | - api.replicate.com 70 | user-agent: 71 | - replicate-python/0.11.0 72 | method: POST 73 | uri: https://api.replicate.com/v1/trainings/xxi2kydbnla3ocjp64sclkewku/cancel 74 | response: 75 | content: 
'{"completed_at":"2023-08-16T19:33:27.215203Z","created_at":"2023-08-16T19:33:27.135361Z","error":null,"id":"xxi2kydbnla3ocjp64sclkewku","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{"predict_time":3e-06},"output":null,"started_at":"2023-08-16T19:33:27.215200Z","status":"canceled","urls":{"get":"https://api.replicate.com/v1/predictions/xxi2kydbnla3ocjp64sclkewku","cancel":"https://api.replicate.com/v1/predictions/xxi2kydbnla3ocjp64sclkewku/cancel"},"model":"stability-ai/sdxl","version":"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b","webhook_completed":null}' 76 | headers: 77 | CF-Cache-Status: 78 | - DYNAMIC 79 | CF-RAY: 80 | - 7f7c1bec8813279c-SEA 81 | Connection: 82 | - keep-alive 83 | Content-Encoding: 84 | - gzip 85 | Content-Type: 86 | - application/json 87 | Date: 88 | - Wed, 16 Aug 2023 19:33:27 GMT 89 | NEL: 90 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 91 | Report-To: 92 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=Y93Ig0lcI3wRbWLTTyTwQuz%2FQPzwlzM%2FI79Q17bYeJIu382PgqxEwZxl3dduPP%2FYm91xCRzfjW41b44fxB0YD6XXF%2FSMKmO19I1LpI%2FMd3p0jwepg7i3B1cNsCc6Rfpo0qN%2B"}],"group":"cf-nel","max_age":604800}' 93 | Server: 94 | - cloudflare 95 | Strict-Transport-Security: 96 | - max-age=15552000 97 | Transfer-Encoding: 98 | - chunked 99 | ratelimit-remaining: 100 | - "2999" 101 | ratelimit-reset: 102 | - "1" 103 | via: 104 | - 1.1 google 105 | http_version: HTTP/1.1 106 | status_code: 200 107 | version: 1 108 | -------------------------------------------------------------------------------- /replicate/version.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import TYPE_CHECKING, Any, Dict, Tuple, Union 3 | 4 | if TYPE_CHECKING: 5 | from replicate.client import Client 6 | from replicate.model import Model 7 | 8 | from replicate.pagination import Page 9 | from replicate.resource import Namespace, Resource 10 | 11 | 12 | class Version(Resource): 13 | """ 14 | A version of a model. 15 | """ 16 | 17 | id: str 18 | """The unique ID of the version.""" 19 | 20 | created_at: datetime.datetime 21 | """When the version was created.""" 22 | 23 | cog_version: str 24 | """The version of the Cog used to create the version.""" 25 | 26 | openapi_schema: dict 27 | """An OpenAPI description of the model inputs and outputs.""" 28 | 29 | 30 | class Versions(Namespace): 31 | """ 32 | Namespace for operations related to model versions. 33 | """ 34 | 35 | model: Tuple[str, str] 36 | 37 | def __init__( 38 | self, client: "Client", model: Union[str, Tuple[str, str], "Model"] 39 | ) -> None: 40 | super().__init__(client=client) 41 | 42 | from replicate.model import Model # pylint: disable=import-outside-toplevel 43 | 44 | if isinstance(model, Model): 45 | self.model = (model.owner, model.name) 46 | elif isinstance(model, str): 47 | owner, name = model.split("/", 1) 48 | self.model = (owner, name) 49 | else: 50 | self.model = model 51 | 52 | def get(self, id: str) -> Version: 53 | """ 54 | Get a specific model version. 55 | 56 | Args: 57 | id: The version ID. 58 | Returns: 59 | The model version. 
60 | """ 61 | 62 | resp = self._client._request( 63 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 64 | ) 65 | 66 | return _json_to_version(resp.json()) 67 | 68 | async def async_get(self, id: str) -> Version: 69 | """ 70 | Get a specific model version. 71 | 72 | Args: 73 | id: The version ID. 74 | Returns: 75 | The model version. 76 | """ 77 | 78 | resp = await self._client._async_request( 79 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 80 | ) 81 | 82 | return _json_to_version(resp.json()) 83 | 84 | def list(self) -> Page[Version]: 85 | """ 86 | Return a list of all versions for a model. 87 | 88 | Returns: 89 | List[Version]: A list of version objects. 90 | """ 91 | 92 | resp = self._client._request( 93 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions" 94 | ) 95 | obj = resp.json() 96 | obj["results"] = [_json_to_version(result) for result in obj["results"]] 97 | 98 | return Page[Version](**obj) 99 | 100 | async def async_list(self) -> Page[Version]: 101 | """ 102 | Return a list of all versions for a model. 103 | 104 | Returns: 105 | List[Version]: A list of version objects. 106 | """ 107 | 108 | resp = await self._client._async_request( 109 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions" 110 | ) 111 | obj = resp.json() 112 | obj["results"] = [_json_to_version(result) for result in obj["results"]] 113 | 114 | return Page[Version](**obj) 115 | 116 | def delete(self, id: str) -> bool: 117 | """ 118 | Delete a model version and all associated predictions, including all output files. 119 | 120 | Model version deletion has some restrictions: 121 | 122 | * You can only delete versions from models you own. 123 | * You can only delete versions from private models. 124 | * You cannot delete a version if someone other than you 125 | has run predictions with it. 126 | 127 | Args: 128 | id: The version ID. 129 | """ 130 | 131 | resp = self._client._request( 132 | "DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 133 | ) 134 | return resp.status_code == 204 135 | 136 | async def async_delete(self, id: str) -> bool: 137 | """ 138 | Delete a model version and all associated predictions, including all output files. 139 | 140 | Model version deletion has some restrictions: 141 | 142 | * You can only delete versions from models you own. 143 | * You can only delete versions from private models. 144 | * You cannot delete a version if someone other than you 145 | has run predictions with it. 146 | 147 | Args: 148 | id: The version ID. 
149 | """ 150 | 151 | resp = await self._client._async_request( 152 | "DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 153 | ) 154 | return resp.status_code == 204 155 | 156 | 157 | def _json_to_version(json: Dict[str, Any]) -> Version: 158 | return Version(**json) 159 | -------------------------------------------------------------------------------- /tests/cassettes/collections-list.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.15.6 15 | method: GET 16 | uri: https://api.replicate.com/v1/collections 17 | response: 18 | content: '{"next":null,"previous":null,"results":[{"name":"Vision models","slug":"vision-models","description":"Multimodal 19 | large language models with vision capabilities like object detection and optical 20 | character recognition (OCR)"},{"name":"T2I-Adapter","slug":"t2i-adapter","description":"T2I-Adapter 21 | models to modify images"},{"name":"Language models with support for grammars 22 | and jsonschema","slug":"language-models-with-grammar","description":"Language 23 | models that support grammar-based decoding as well as jsonschema constraints."},{"name":"SDXL 24 | fine-tunes","slug":"sdxl-fine-tunes","description":"Some of our favorite SDXL 25 | fine-tunes."},{"name":"Streaming language models","slug":"streaming-language-models","description":"Language 26 | models that support streaming responses. See https://replicate.com/docs/streaming"},{"name":"Image 27 | editing","slug":"image-editing","description":"Tools for manipulating images."},{"name":"Embedding 28 | models","slug":"embedding-models","description":"Models that generate embeddings 29 | from inputs"},{"name":"Trainable language models","slug":"trainable-language-models","description":"Language 30 | models that you can fine-tune using Replicate''s training API."},{"name":"Language 31 | models","slug":"language-models","description":"Models that can understand and 32 | generate text"},{"name":"ControlNet","slug":"control-net","description":"Control 33 | diffusion models"},{"name":"Audio generation","slug":"audio-generation","description":"Models 34 | to generate and modify audio"},{"name":"Diffusion models","slug":"diffusion-models","description":"Image 35 | and video generation models trained with diffusion processes"},{"name":"Videos","slug":"text-to-video","description":"Models 36 | that create and edit videos"},{"name":"Image to text","slug":"image-to-text","description":"Models 37 | that generate text prompts and captions from images"},{"name":"Super resolution","slug":"super-resolution","description":"Upscaling 38 | models that create high-quality images from low-quality images"},{"name":"Style 39 | transfer","slug":"style-transfer","description":"Models that take a content 40 | image and a style reference to produce a new image"},{"name":"ML makeovers","slug":"ml-makeovers","description":"Models 41 | that let you change facial features"},{"name":"Image restoration","slug":"image-restoration","description":"Models 42 | that improve or restore images by deblurring, colorization, and removing noise"},{"name":"Text 43 | to image","slug":"text-to-image","description":"Models that generate images 44 | from text prompts"}]}' 45 | headers: 46 | CF-Cache-Status: 47 | - DYNAMIC 48 | CF-RAY: 49 | - 
827025392eae200a-IAD 50 | Connection: 51 | - keep-alive 52 | Content-Encoding: 53 | - gzip 54 | Content-Type: 55 | - application/json 56 | Date: 57 | - Thu, 16 Nov 2023 13:40:22 GMT 58 | Server: 59 | - cloudflare 60 | Strict-Transport-Security: 61 | - max-age=15552000 62 | Transfer-Encoding: 63 | - chunked 64 | allow: 65 | - GET, HEAD, OPTIONS 66 | content-security-policy-report-only: 67 | - 'media-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery 68 | https://*.mux.com https://*.sentry.io; default-src ''self''; script-src ''report-sample'' 69 | ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; img-src 70 | ''report-sample'' ''self'' data: https://replicate.delivery https://*.replicate.delivery 71 | https://*.githubusercontent.com https://github.com; worker-src ''none''; style-src 72 | ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample'' 73 | ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com 74 | https://*.rudderstack.com https://*.mux.com https://*.sentry.io; font-src 75 | ''report-sample'' ''self'' data:; report-uri' 76 | cross-origin-opener-policy: 77 | - same-origin 78 | nel: 79 | - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' 80 | ratelimit-remaining: 81 | - '2999' 82 | ratelimit-reset: 83 | - '1' 84 | referrer-policy: 85 | - same-origin 86 | report-to: 87 | - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1700142022&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=AUQIFO072WbZjKq785Xqd67vUUwGAhLFqu5%2BlLug%2BWE%3D"}]}' 88 | reporting-endpoints: 89 | - heroku-nel=https://nel.heroku.com/reports?ts=1700142022&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=AUQIFO072WbZjKq785Xqd67vUUwGAhLFqu5%2BlLug%2BWE%3D 90 | vary: 91 | - Cookie, origin 92 | via: 93 | - 1.1 vegur, 1.1 google 94 | x-content-type-options: 95 | - nosniff 96 | x-frame-options: 97 | - DENY 98 | http_version: HTTP/1.1 99 | status_code: 200 100 | version: 1 101 | -------------------------------------------------------------------------------- /tests/test_run.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import sys 3 | 4 | import httpx 5 | import pytest 6 | import respx 7 | 8 | import replicate 9 | from replicate.client import Client 10 | from replicate.exceptions import ReplicateError 11 | 12 | 13 | @pytest.mark.vcr("run.yaml") 14 | @pytest.mark.asyncio 15 | @pytest.mark.parametrize("async_flag", [True, False]) 16 | async def test_run(async_flag, record_mode): 17 | if record_mode == "none": 18 | replicate.default_client.poll_interval = 0.001 19 | 20 | version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" 21 | 22 | input = { 23 | "prompt": "a studio photo of a rainbow colored corgi", 24 | "width": 512, 25 | "height": 512, 26 | "seed": 42069, 27 | } 28 | 29 | if async_flag: 30 | output = await replicate.async_run( 31 | f"stability-ai/sdxl:{version}", 32 | input=input, 33 | ) 34 | else: 35 | output = replicate.run( 36 | f"stability-ai/sdxl:{version}", 37 | input=input, 38 | ) 39 | 40 | assert output is not None 41 | assert isinstance(output, list) 42 | assert len(output) > 0 43 | assert output[0].startswith("https://") 44 | 45 | 46 | @pytest.mark.vcr("run__concurrently.yaml") 47 | @pytest.mark.asyncio 48 | @pytest.mark.skipif( 49 | sys.version_info < (3, 11), reason="asyncio.TaskGroup requires Python 3.11" 50 | ) 51 | 
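# The test below fans out one client.async_run() call per prompt inside an
# asyncio.TaskGroup (hence the Python 3.11+ skip above), then gathers the results
# and checks that every concurrent run returned a non-empty list of outputs.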
async def test_run_concurrently(mock_replicate_api_token, record_mode): 52 | client = replicate.Client() 53 | if record_mode == "none": 54 | client.poll_interval = 0.001 55 | 56 | version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" 57 | 58 | prompts = [ 59 | f"A chariot pulled by a team of {count} rainbow unicorns" 60 | for count in ["two", "four", "six", "eight"] 61 | ] 62 | 63 | async with asyncio.TaskGroup() as tg: 64 | tasks = [ 65 | tg.create_task( 66 | client.async_run( 67 | f"stability-ai/sdxl:{version}", input={"prompt": prompt} 68 | ) 69 | ) 70 | for prompt in prompts 71 | ] 72 | 73 | results = await asyncio.gather(*tasks) 74 | assert len(results) == len(prompts) 75 | assert all(isinstance(result, list) for result in results) 76 | assert all(len(result) > 0 for result in results) 77 | 78 | 79 | @pytest.mark.vcr("run.yaml") 80 | @pytest.mark.asyncio 81 | async def test_run_with_invalid_identifier(mock_replicate_api_token): 82 | with pytest.raises(ValueError): 83 | replicate.run("invalid") 84 | 85 | 86 | @pytest.mark.vcr("run__invalid-token.yaml") 87 | @pytest.mark.asyncio 88 | async def test_run_with_invalid_token(): 89 | with pytest.raises(ReplicateError) as excinfo: 90 | client = replicate.Client(api_token="invalid") 91 | 92 | version = "73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9" 93 | client.run( 94 | f"meta/llama-2-7b:{version}", 95 | ) 96 | 97 | assert "You did not pass a valid authentication token" in str(excinfo.value) 98 | 99 | 100 | @pytest.mark.asyncio 101 | async def test_run_version_with_invalid_cog_version(mock_replicate_api_token): 102 | def prediction_with_status(status: str) -> dict: 103 | return { 104 | "id": "p1", 105 | "model": "test/example", 106 | "version": "v1", 107 | "urls": { 108 | "get": "https://api.replicate.com/v1/predictions/p1", 109 | "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", 110 | }, 111 | "created_at": "2023-10-05T12:00:00.000000Z", 112 | "source": "api", 113 | "status": status, 114 | "input": {"text": "world"}, 115 | "output": "Hello, world!" 
if status == "succeeded" else None, 116 | "error": None, 117 | "logs": "", 118 | } 119 | 120 | router = respx.Router(base_url="https://api.replicate.com/v1") 121 | router.route(method="POST", path="/predictions").mock( 122 | return_value=httpx.Response( 123 | 201, 124 | json=prediction_with_status("processing"), 125 | ) 126 | ) 127 | router.route(method="GET", path="/predictions/p1").mock( 128 | return_value=httpx.Response( 129 | 200, 130 | json=prediction_with_status("succeeded"), 131 | ) 132 | ) 133 | router.route( 134 | method="GET", 135 | path="/models/test/example/versions/invalid", 136 | ).mock( 137 | return_value=httpx.Response( 138 | 201, 139 | json={ 140 | "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", 141 | "created_at": "2022-03-16T00:35:56.210272Z", 142 | "cog_version": "dev", 143 | "openapi_schema": { 144 | "openapi": "3.0.2", 145 | "info": {"title": "Cog", "version": "0.1.0"}, 146 | "paths": {}, 147 | "components": { 148 | "schemas": { 149 | "Input": { 150 | "type": "object", 151 | "title": "Input", 152 | "required": ["text"], 153 | "properties": { 154 | "text": { 155 | "type": "string", 156 | "title": "Text", 157 | "x-order": 0, 158 | "description": "The text input", 159 | }, 160 | }, 161 | }, 162 | "Output": { 163 | "type": "string", 164 | "title": "Output", 165 | }, 166 | } 167 | }, 168 | }, 169 | }, 170 | ) 171 | ) 172 | router.route(host="api.replicate.com").pass_through() 173 | 174 | client = Client( 175 | api_token="test-token", transport=httpx.MockTransport(router.handler) 176 | ) 177 | client.poll_interval = 0.001 178 | 179 | output = client.run( 180 | "test/example:invalid", 181 | input={ 182 | "text": "Hello, world!", 183 | }, 184 | ) 185 | 186 | assert output == "Hello, world!" 187 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | from replicate.exceptions import ReplicateException 5 | 6 | input_images_url = "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip" 7 | 8 | 9 | @pytest.mark.vcr("trainings-create.yaml") 10 | @pytest.mark.asyncio 11 | @pytest.mark.parametrize("async_flag", [True, False]) 12 | async def test_trainings_create(async_flag, mock_replicate_api_token): 13 | if async_flag: 14 | training = await replicate.trainings.async_create( 15 | model="stability-ai/sdxl", 16 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 17 | input={ 18 | "input_images": input_images_url, 19 | "use_face_detection_instead": True, 20 | }, 21 | destination="replicate/dreambooth-sdxl", 22 | ) 23 | else: 24 | training = replicate.trainings.create( 25 | model="stability-ai/sdxl", 26 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 27 | input={ 28 | "input_images": input_images_url, 29 | "use_face_detection_instead": True, 30 | }, 31 | destination="replicate/dreambooth-sdxl", 32 | ) 33 | 34 | assert training.id is not None 35 | assert training.status == "starting" 36 | 37 | 38 | @pytest.mark.vcr("trainings-create.yaml") 39 | @pytest.mark.asyncio 40 | @pytest.mark.parametrize("async_flag", [True, False]) 41 | async def test_trainings_create_with_named_version_argument( 42 | async_flag, mock_replicate_api_token 43 | ): 44 | if async_flag: 45 | # The overload with a model version identifier is soft-deprecated 46 | # and not supported in the async version. 
47 | return 48 | else: 49 | training = replicate.trainings.create( 50 | version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 51 | input={ 52 | "input_images": input_images_url, 53 | "use_face_detection_instead": True, 54 | }, 55 | destination="replicate/dreambooth-sdxl", 56 | ) 57 | 58 | assert training.id is not None 59 | assert training.status == "starting" 60 | 61 | 62 | @pytest.mark.vcr("trainings-create.yaml") 63 | @pytest.mark.asyncio 64 | @pytest.mark.parametrize("async_flag", [True, False]) 65 | async def test_trainings_create_with_positional_argument( 66 | async_flag, mock_replicate_api_token 67 | ): 68 | if async_flag: 69 | # The overload with positional arguments is soft-deprecated 70 | # and not supported in the async version. 71 | return 72 | else: 73 | training = replicate.trainings.create( 74 | "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 75 | { 76 | "input_images": input_images_url, 77 | "use_face_detection_instead": True, 78 | }, 79 | "replicate/dreambooth-sdxl", 80 | ) 81 | 82 | assert training.id is not None 83 | assert training.status == "starting" 84 | 85 | 86 | @pytest.mark.vcr("trainings-create__invalid-destination.yaml") 87 | @pytest.mark.asyncio 88 | @pytest.mark.parametrize("async_flag", [True, False]) 89 | async def test_trainings_create_with_invalid_destination( 90 | async_flag, mock_replicate_api_token 91 | ): 92 | with pytest.raises(ReplicateException): 93 | if async_flag: 94 | await replicate.trainings.async_create( 95 | model="stability-ai/sdxl", 96 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 97 | input={ 98 | "input_images": input_images_url, 99 | "use_face_detection_instead": True, 100 | }, 101 | destination="", 102 | ) 103 | else: 104 | replicate.trainings.create( 105 | model="stability-ai/sdxl", 106 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 107 | input={ 108 | "input_images": input_images_url, 109 | }, 110 | destination="", 111 | ) 112 | 113 | 114 | @pytest.mark.vcr("trainings-get.yaml") 115 | @pytest.mark.asyncio 116 | @pytest.mark.parametrize("async_flag", [True, False]) 117 | async def test_trainings_get(async_flag, mock_replicate_api_token): 118 | id = "medrnz3bm5dd6ultvad2tejrte" 119 | 120 | if async_flag: 121 | training = await replicate.trainings.async_get(id) 122 | else: 123 | training = replicate.trainings.get(id) 124 | 125 | assert training.id == id 126 | assert training.status == "processing" 127 | 128 | 129 | @pytest.mark.vcr("trainings-cancel.yaml") 130 | @pytest.mark.asyncio 131 | @pytest.mark.parametrize("async_flag", [True, False]) 132 | async def test_trainings_cancel(async_flag, mock_replicate_api_token): 133 | input = { 134 | "input_images": input_images_url, 135 | "use_face_detection_instead": True, 136 | } 137 | 138 | destination = "replicate/dreambooth-sdxl" 139 | 140 | if async_flag: 141 | training = await replicate.trainings.async_create( 142 | model="stability-ai/sdxl", 143 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 144 | input=input, 145 | destination=destination, 146 | ) 147 | 148 | assert training.status == "starting" 149 | 150 | training = replicate.trainings.cancel(training.id) 151 | assert training.status == "canceled" 152 | else: 153 | training = replicate.trainings.create( 154 | version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 155 | destination=destination, 156 | input=input, 157 | ) 158 | 159 
| assert training.status == "starting" 160 | 161 | training = replicate.trainings.cancel(training.id) 162 | assert training.status == "canceled" 163 | 164 | 165 | @pytest.mark.vcr("trainings-cancel.yaml") 166 | @pytest.mark.asyncio 167 | @pytest.mark.parametrize("async_flag", [True, False]) 168 | async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_token): 169 | input = { 170 | "input_images": input_images_url, 171 | "use_face_detection_instead": True, 172 | } 173 | 174 | destination = "replicate/dreambooth-sdxl" 175 | 176 | if async_flag: 177 | # The cancel instance method is soft-deprecated, 178 | # and not supported in the async version. 179 | return 180 | else: 181 | training = replicate.trainings.create( 182 | version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 183 | destination=destination, 184 | input=input, 185 | ) 186 | 187 | assert training.status == "starting" 188 | 189 | training.cancel() 190 | assert training.status == "canceled" 191 | -------------------------------------------------------------------------------- /replicate/stream.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import ( 3 | TYPE_CHECKING, 4 | Any, 5 | AsyncIterator, 6 | Dict, 7 | Iterator, 8 | List, 9 | Optional, 10 | Union, 11 | ) 12 | 13 | from typing_extensions import Unpack 14 | 15 | from replicate import identifier 16 | from replicate.exceptions import ReplicateError 17 | 18 | try: 19 | from pydantic import v1 as pydantic # type: ignore 20 | except ImportError: 21 | import pydantic # type: ignore 22 | 23 | 24 | if TYPE_CHECKING: 25 | import httpx 26 | 27 | from replicate.client import Client 28 | from replicate.identifier import ModelVersionIdentifier 29 | from replicate.model import Model 30 | from replicate.prediction import Predictions 31 | from replicate.version import Version 32 | 33 | 34 | class ServerSentEvent(pydantic.BaseModel): # type: ignore 35 | """ 36 | A server-sent event. 37 | """ 38 | 39 | class EventType(Enum): 40 | """ 41 | A server-sent event type. 42 | """ 43 | 44 | OUTPUT = "output" 45 | LOGS = "logs" 46 | ERROR = "error" 47 | DONE = "done" 48 | 49 | event: EventType 50 | data: str 51 | id: str 52 | retry: Optional[int] 53 | 54 | def __str__(self) -> str: 55 | if self.event == ServerSentEvent.EventType.OUTPUT: 56 | return self.data 57 | 58 | return "" 59 | 60 | 61 | class EventSource: 62 | """ 63 | A server-sent event source. 64 | """ 65 | 66 | response: "httpx.Response" 67 | 68 | def __init__(self, response: "httpx.Response") -> None: 69 | self.response = response 70 | content_type, _, _ = response.headers["content-type"].partition(";") 71 | if content_type != "text/event-stream": 72 | raise ValueError( 73 | "Expected response Content-Type to be 'text/event-stream', " 74 | f"got {content_type!r}" 75 | ) 76 | 77 | class Decoder: 78 | """ 79 | A decoder for server-sent events. 80 | """ 81 | 82 | event: Optional["ServerSentEvent.EventType"] 83 | data: List[str] 84 | last_event_id: Optional[str] 85 | retry: Optional[int] 86 | 87 | def __init__(self) -> None: 88 | self.event = None 89 | self.data = [] 90 | self.last_event_id = None 91 | self.retry = None 92 | 93 | def decode(self, line: str) -> Optional[ServerSentEvent]: 94 | """ 95 | Decode a line and return a server-sent event if applicable. 
96 | """ 97 | 98 | if not line: 99 | if ( 100 | not any([self.event, self.data, self.last_event_id, self.retry]) 101 | or self.event is None 102 | or self.last_event_id is None 103 | ): 104 | return None 105 | 106 | sse = ServerSentEvent( 107 | event=self.event, 108 | data="\n".join(self.data), 109 | id=self.last_event_id, 110 | retry=self.retry, 111 | ) 112 | 113 | self.event = None 114 | self.data = [] 115 | self.retry = None 116 | 117 | return sse 118 | 119 | if line.startswith(":"): 120 | return None 121 | 122 | fieldname, _, value = line.partition(":") 123 | value = value[1:] if value.startswith(" ") else value 124 | 125 | if fieldname == "event": 126 | if event := ServerSentEvent.EventType(value): 127 | self.event = event 128 | elif fieldname == "data": 129 | self.data.append(value) 130 | elif fieldname == "id": 131 | if "\0" not in value: 132 | self.last_event_id = value 133 | elif fieldname == "retry": 134 | try: 135 | self.retry = int(value) 136 | except (TypeError, ValueError): 137 | pass 138 | 139 | return None 140 | 141 | def __iter__(self) -> Iterator[ServerSentEvent]: 142 | decoder = EventSource.Decoder() 143 | 144 | for line in self.response.iter_lines(): 145 | line = line.rstrip("\n") 146 | sse = decoder.decode(line) 147 | if sse is not None: 148 | if sse.event == ServerSentEvent.EventType.ERROR: 149 | raise RuntimeError(sse.data) 150 | 151 | yield sse 152 | 153 | if sse.event == ServerSentEvent.EventType.DONE: 154 | return 155 | 156 | async def __aiter__(self) -> AsyncIterator[ServerSentEvent]: 157 | decoder = EventSource.Decoder() 158 | async for line in self.response.aiter_lines(): 159 | line = line.rstrip("\n") 160 | sse = decoder.decode(line) 161 | if sse is not None: 162 | if sse.event == ServerSentEvent.EventType.ERROR: 163 | raise RuntimeError(sse.data) 164 | 165 | yield sse 166 | 167 | if sse.event == ServerSentEvent.EventType.DONE: 168 | return 169 | 170 | 171 | def stream( 172 | client: "Client", 173 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 174 | input: Optional[Dict[str, Any]] = None, 175 | **params: Unpack["Predictions.CreatePredictionParams"], 176 | ) -> Iterator[ServerSentEvent]: 177 | """ 178 | Run a model and stream its output. 179 | """ 180 | 181 | params = params or {} 182 | params["stream"] = True 183 | 184 | version, owner, name, version_id = identifier._resolve(ref) 185 | 186 | if version or version_id: 187 | prediction = client.predictions.create( 188 | version=(version or version_id), input=input or {}, **params 189 | ) 190 | elif owner and name: 191 | prediction = client.models.predictions.create( 192 | model=(owner, name), input=input or {}, **params 193 | ) 194 | else: 195 | raise ValueError( 196 | f"Invalid argument: {ref}. 
Expected model, version, or reference in the format owner/name or owner/name:version" 197 | ) 198 | 199 | url = prediction.urls and prediction.urls.get("stream", None) 200 | if not url or not isinstance(url, str): 201 | raise ReplicateError("Model does not support streaming") 202 | 203 | headers = {} 204 | headers["Accept"] = "text/event-stream" 205 | headers["Cache-Control"] = "no-store" 206 | 207 | with client._client.stream("GET", url, headers=headers) as response: 208 | yield from EventSource(response) 209 | 210 | 211 | async def async_stream( 212 | client: "Client", 213 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 214 | input: Optional[Dict[str, Any]] = None, 215 | **params: Unpack["Predictions.CreatePredictionParams"], 216 | ) -> AsyncIterator[ServerSentEvent]: 217 | """ 218 | Run a model and stream its output asynchronously. 219 | """ 220 | 221 | params = params or {} 222 | params["stream"] = True 223 | 224 | version, owner, name, version_id = identifier._resolve(ref) 225 | 226 | if version or version_id: 227 | prediction = await client.predictions.async_create( 228 | version=(version or version_id), input=input or {}, **params 229 | ) 230 | elif owner and name: 231 | prediction = await client.models.predictions.async_create( 232 | model=(owner, name), input=input or {}, **params 233 | ) 234 | else: 235 | raise ValueError( 236 | f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" 237 | ) 238 | 239 | url = prediction.urls and prediction.urls.get("stream", None) 240 | if not url or not isinstance(url, str): 241 | raise ReplicateError("Model does not support streaming") 242 | 243 | headers = {} 244 | headers["Accept"] = "text/event-stream" 245 | headers["Cache-Control"] = "no-store" 246 | 247 | async with client._async_client.stream("GET", url, headers=headers) as response: 248 | async for event in EventSource(response): 249 | yield event 250 | 251 | 252 | __all__ = ["ServerSentEvent"] 253 | -------------------------------------------------------------------------------- /tests/cassettes/predictions-get.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: "" 4 | headers: 5 | accept: 6 | - "*/*" 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.11.0 15 | method: GET 16 | uri: https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu 17 | response: 18 | content: 19 | "{\"id\":\"vgcm4plb7tgzlyznry5d5jkgvu\",\"model\":\"stability-ai/sdxl\",\"version\":\"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b\",\"input\":{\"height\":512,\"prompt\":\"a 20 | studio photo of a rainbow colored corgi\",\"seed\":42069,\"width\":512},\"logs\":\"Using 21 | seed: 42069\\nPrompt: a studio photo of a rainbow colored corgi\\ntxt2img mode\\n 22 | \ 0%| | 0/50 [00:00\\u003c?, ?it/s]\\n 4%|\u258D | 2/50 [00:00\\u003c00:02, 23 | 16.47it/s]\\n 8%|\u258A | 4/50 [00:00\\u003c00:02, 16.39it/s]\\n 12%|\u2588\u258F 24 | \ | 6/50 [00:00\\u003c00:02, 16.60it/s]\\n 16%|\u2588\u258C | 8/50 25 | [00:00\\u003c00:02, 16.53it/s]\\n 20%|\u2588\u2588 | 10/50 [00:00\\u003c00:02, 26 | 16.76it/s]\\n 24%|\u2588\u2588\u258D | 12/50 [00:00\\u003c00:02, 16.93it/s]\\n 27 | 28%|\u2588\u2588\u258A | 14/50 [00:00\\u003c00:02, 17.04it/s]\\n 32%|\u2588\u2588\u2588\u258F 28 | \ | 16/50 [00:00\\u003c00:01, 17.10it/s]\\n 36%|\u2588\u2588\u2588\u258C 29 | \ | 
18/50 [00:01\\u003c00:01, 17.12it/s]\\n 40%|\u2588\u2588\u2588\u2588 30 | \ | 20/50 [00:01\\u003c00:01, 17.15it/s]\\n 44%|\u2588\u2588\u2588\u2588\u258D 31 | \ | 22/50 [00:01\\u003c00:01, 17.16it/s]\\n 48%|\u2588\u2588\u2588\u2588\u258A 32 | \ | 24/50 [00:01\\u003c00:01, 17.17it/s]\\n 52%|\u2588\u2588\u2588\u2588\u2588\u258F 33 | \ | 26/50 [00:01\\u003c00:01, 17.20it/s]\\n 56%|\u2588\u2588\u2588\u2588\u2588\u258C 34 | \ | 28/50 [00:01\\u003c00:01, 17.21it/s]\\n 60%|\u2588\u2588\u2588\u2588\u2588\u2588 35 | \ | 30/50 [00:01\\u003c00:01, 17.19it/s]\\n 64%|\u2588\u2588\u2588\u2588\u2588\u2588\u258D 36 | \ | 32/50 [00:01\\u003c00:01, 17.18it/s]\\n 68%|\u2588\u2588\u2588\u2588\u2588\u2588\u258A 37 | \ | 34/50 [00:01\\u003c00:00, 17.18it/s]\\n 72%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258F 38 | \ | 36/50 [00:02\\u003c00:00, 17.20it/s]\\n 76%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258C 39 | \ | 38/50 [00:02\\u003c00:00, 17.21it/s]\\n 80%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 40 | \ | 40/50 [00:02\\u003c00:00, 17.19it/s]\\n 84%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258D 41 | | 42/50 [00:02\\u003c00:00, 17.19it/s]\\n 88%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258A 42 | | 44/50 [00:02\\u003c00:00, 17.19it/s]\\n 92%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258F| 43 | 46/50 [00:02\\u003c00:00, 17.20it/s]\\n 96%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258C| 44 | 48/50 [00:02\\u003c00:00, 17.22it/s]\\n100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 45 | 50/50 [00:02\\u003c00:00, 17.19it/s]\\n100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 46 | 50/50 [00:02\\u003c00:00, 17.09it/s]\\n\",\"output\":[\"https://replicate.delivery/pbxt/9inf36wjsEWuQ6XTf84iezPftv9QZdePfGySnU5tUai3BOrWE/out-0.png\"],\"error\":null,\"status\":\"succeeded\",\"created_at\":\"2023-08-16T18:57:08.360785Z\",\"started_at\":\"2023-08-16T18:57:08.366092Z\",\"completed_at\":\"2023-08-16T18:57:12.17042Z\",\"metrics\":{\"predict_time\":3.804328},\"urls\":{\"cancel\":\"https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu/cancel\",\"get\":\"https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu\"}}\n" 47 | headers: 48 | CF-Cache-Status: 49 | - DYNAMIC 50 | CF-RAY: 51 | - 7f7be8b17b47f8d1-SEA 52 | Connection: 53 | - keep-alive 54 | Content-Encoding: 55 | - gzip 56 | Content-Type: 57 | - application/json 58 | Date: 59 | - Wed, 16 Aug 2023 18:58:28 GMT 60 | NEL: 61 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 62 | Report-To: 63 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=UuflCrg7N4NAE6re4SQXd4aIWksjHl5BGKMFC3j9Rh9twIDDFOZzdAiU%2F%2FcWNM%2FofzDCdfSG628ZUoRTySoOZY04dhbVtbL4FCJ6YCsfEfkB%2B282Tfjs0VSoavvJmvBcSN%2B0"}],"group":"cf-nel","max_age":604800}' 64 | Server: 65 | - cloudflare 66 | Strict-Transport-Security: 67 | - max-age=15552000 68 | Transfer-Encoding: 69 | - chunked 70 | ratelimit-remaining: 71 | - "59999" 72 | ratelimit-reset: 73 | - "1" 74 | via: 75 | - 1.1 google 76 | http_version: HTTP/1.1 77 | status_code: 200 78 | - request: 79 | body: "" 80 | headers: 81 | accept: 82 | - "*/*" 83 | accept-encoding: 84 | - gzip, deflate 85 | connection: 86 | - keep-alive 87 | host: 88 | - api.replicate.com 89 | user-agent: 90 | - replicate-python/0.11.0 91 | method: GET 92 | uri: https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu 93 | response: 94 | content: 95 | 
"{\"completed_at\":\"2023-08-16T18:57:12.170420Z\",\"created_at\":\"2023-08-16T18:57:08.394251Z\",\"error\":null,\"id\":\"vgcm4plb7tgzlyznry5d5jkgvu\",\"input\":{\"seed\":42069,\"width\":512,\"height\":512,\"prompt\":\"a 96 | studio photo of a rainbow colored corgi\"},\"logs\":\"Using seed: 42069\\nPrompt: 97 | a studio photo of a rainbow colored corgi\\ntxt2img mode\\n 0%| | 98 | 0/50 [00:00 **👋** Check out an interactive version of this tutorial on [Google Colab](https://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c). 6 | > 7 | > [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c) 8 | 9 | 10 | ## Install 11 | 12 | ```sh 13 | pip install replicate 14 | ``` 15 | 16 | ## Authenticate 17 | 18 | Before running any Python scripts that use the API, you need to set your Replicate API token in your environment. 19 | 20 | Grab your token from [replicate.com/account](https://replicate.com/account) and set it as an environment variable: 21 | 22 | ``` 23 | export REPLICATE_API_TOKEN= 24 | ``` 25 | 26 | We recommend not adding the token directly to your source code, because you don't want to put your credentials in source control. If anyone used your API key, their usage would be charged to your account. 27 | 28 | ## Run a model 29 | 30 | Create a new Python file and add the following code, replacing the model identifier and input with your own: 31 | 32 | ```python 33 | >>> import replicate 34 | >>> replicate.run( 35 | "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", 36 | input={"prompt": "a 19th century portrait of a wombat gentleman"} 37 | ) 38 | 39 | ['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png'] 40 | ``` 41 | 42 | Some models, particularly language models, may not require the version string. Refer to the API documentation for the model for more on the specifics: 43 | 44 | ```python 45 | replicate.run( 46 | "meta/meta-llama-3-70b-instruct", 47 | input={ 48 | "prompt": "Can you write a poem about open source machine learning?", 49 | "system_prompt": "You are a helpful, respectful and honest assistant.", 50 | }, 51 | ) 52 | ``` 53 | 54 | Some models, like [andreasjansson/blip-2](https://replicate.com/andreasjansson/blip-2), have files as inputs. 55 | To run a model that takes a file input, 56 | pass a URL to a publicly accessible file. 57 | Or, for smaller files (<10MB), you can pass a file handle directly. 58 | 59 | ```python 60 | >>> output = replicate.run( 61 | "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9", 62 | input={ "image": open("path/to/mystery.jpg") } 63 | ) 64 | 65 | "an astronaut riding a horse" 66 | ``` 67 | 68 | > [!NOTE] 69 | > You can also use the Replicate client asynchronously by prepending `async_` to the method name. 
70 | > 71 | > Here's an example of how to run several predictions concurrently and wait for them all to complete: 72 | > 73 | > ```python 74 | > import asyncio 75 | > import replicate 76 | > 77 | > # https://replicate.com/stability-ai/sdxl 78 | > model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" 79 | > prompts = [ 80 | > f"A chariot pulled by a team of {count} rainbow unicorns" 81 | > for count in ["two", "four", "six", "eight"] 82 | > ] 83 | > 84 | > async with asyncio.TaskGroup() as tg: 85 | > tasks = [ 86 | > tg.create_task(replicate.async_run(model_version, input={"prompt": prompt})) 87 | > for prompt in prompts 88 | > ] 89 | > 90 | > results = await asyncio.gather(*tasks) 91 | > print(results) 92 | > ``` 93 | 94 | ## Run a model and stream its output 95 | 96 | Replicate’s API supports server-sent event streams (SSEs) for language models. 97 | Use the `stream` method to consume tokens as they're produced by the model. 98 | 99 | ```python 100 | import replicate 101 | 102 | for event in replicate.stream( 103 | "meta/meta-llama-3-70b-instruct", 104 | input={ 105 | "prompt": "Please write a haiku about llamas.", 106 | }, 107 | ): 108 | print(str(event), end="") 109 | ``` 110 | 111 | You can also stream the output of a prediction you create. 112 | This is helpful when you want the ID of the prediction separate from its output. 113 | 114 | ```python 115 | version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" 116 | prediction = replicate.predictions.create( 117 | version=version, 118 | input={"prompt": "Please write a haiku about llamas."}, 119 | stream=True, 120 | ) 121 | 122 | for event in prediction.stream(): 123 | print(str(event), end="") 124 | ``` 125 | 126 | For more information, see 127 | ["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs. 128 | 129 | 130 | ## Run a model in the background 131 | 132 | You can start a model and run it in the background: 133 | 134 | ```python 135 | >>> model = replicate.models.get("kvfrans/clipdraw") 136 | >>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") 137 | >>> prediction = replicate.predictions.create( 138 | version=version, 139 | input={"prompt":"Watercolor painting of an underwater submarine"}) 140 | 141 | >>> prediction 142 | Prediction(...) 
143 | 144 | >>> prediction.status 145 | 'starting' 146 | 147 | >>> dict(prediction) 148 | {"id": "...", "status": "starting", ...} 149 | 150 | >>> prediction.reload() 151 | >>> prediction.status 152 | 'processing' 153 | 154 | >>> print(prediction.logs) 155 | iteration: 0, render:loss: -0.6171875 156 | iteration: 10, render:loss: -0.92236328125 157 | iteration: 20, render:loss: -1.197265625 158 | iteration: 30, render:loss: -1.3994140625 159 | 160 | >>> prediction.wait() 161 | 162 | >>> prediction.status 163 | 'succeeded' 164 | 165 | >>> prediction.output 166 | 'https://.../output.png' 167 | ``` 168 | 169 | ## Run a model in the background and get a webhook 170 | 171 | You can run a model and get a webhook when it completes, instead of waiting for it to finish: 172 | 173 | ```python 174 | model = replicate.models.get("ai-forever/kandinsky-2.2") 175 | version = model.versions.get("ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463") 176 | prediction = replicate.predictions.create( 177 | version=version, 178 | input={"prompt":"Watercolor painting of an underwater submarine"}, 179 | webhook="https://example.com/your-webhook", 180 | webhook_events_filter=["completed"] 181 | ) 182 | ``` 183 | 184 | For details on receiving webhooks, see [replicate.com/docs/webhooks](https://replicate.com/docs/webhooks). 185 | 186 | ## Compose models into a pipeline 187 | 188 | You can run a model and feed the output into another model: 189 | 190 | ```python 191 | laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05") 192 | swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a") 193 | image = laionide.predict(prompt="avocado armchair") 194 | upscaled_image = swinir.predict(image=image) 195 | ``` 196 | 197 | ## Get output from a running model 198 | 199 | Run a model and get its output while it's running: 200 | 201 | ```python 202 | iterator = replicate.run( 203 | "pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf", 204 | input={"prompts": "san francisco sunset"} 205 | ) 206 | 207 | for image in iterator: 208 | display(image) 209 | ``` 210 | 211 | ## Cancel a prediction 212 | 213 | You can cancel a running prediction: 214 | 215 | ```python 216 | >>> model = replicate.models.get("kvfrans/clipdraw") 217 | >>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") 218 | >>> prediction = replicate.predictions.create( 219 | version=version, 220 | input={"prompt":"Watercolor painting of an underwater submarine"} 221 | ) 222 | 223 | >>> prediction.status 224 | 'starting' 225 | 226 | >>> prediction.cancel() 227 | 228 | >>> prediction.reload() 229 | >>> prediction.status 230 | 'canceled' 231 | ``` 232 | 233 | ## List predictions 234 | 235 | You can list all the predictions you've run: 236 | 237 | ```python 238 | replicate.predictions.list() 239 | # [, ] 240 | ``` 241 | 242 | Lists of predictions are paginated. You can get the next page of predictions by passing the `next` property as an argument to the `list` method: 243 | 244 | ```python 245 | page1 = replicate.predictions.list() 246 | 247 | if page1.next: 248 | page2 = replicate.predictions.list(page1.next) 249 | ``` 250 | 251 | ## Load output files 252 | 253 | Output files are returned as HTTPS URLs. 
You can download an output file and open it, for example with PIL: 254 | 255 | ```python 256 | import replicate 257 | from PIL import Image 258 | from urllib.request import urlretrieve 259 | 260 | out = replicate.run( 261 | "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", 262 | input={"prompt": "wavy colorful abstract patterns, oceans"} 263 | ) 264 | 265 | urlretrieve(out[0], "/tmp/out.png") 266 | background = Image.open("/tmp/out.png") 267 | ``` 268 | 269 | ## List models 270 | 271 | You can list the models on Replicate: 272 | 273 | ```python 274 | replicate.models.list() 275 | ``` 276 | 277 | Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method, or you can use the `paginate` method to fetch pages automatically. 278 | 279 | ```python 280 | # Automatic pagination using `replicate.paginate` (recommended) 281 | models = [] 282 | for page in replicate.paginate(replicate.models.list): 283 | models.extend(page.results) 284 | if len(models) > 100: 285 | break 286 | 287 | # Manual pagination using `next` cursors 288 | page = replicate.models.list() 289 | while page: 290 | models.extend(page.results) 291 | if len(models) > 100: 292 | break 293 | page = replicate.models.list(page.next) if page.next else None 294 | ``` 295 | 296 | You can also find collections of featured models on Replicate: 297 | 298 | ```python 299 | >>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page] 300 | >>> collections[0].slug 301 | "vision-models" 302 | >>> collections[0].description 303 | "Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)" 304 | 305 | >>> replicate.collections.get("text-to-image").models 306 | [, ...] 307 | ``` 308 | 309 | ## Create a model 310 | 311 | You can create a model for a user or organization 312 | with a given name, visibility, and hardware SKU: 313 | 314 | ```python 315 | import replicate 316 | 317 | model = replicate.models.create( 318 | owner="your-username", 319 | name="my-model", 320 | visibility="public", 321 | hardware="gpu-a40-large" 322 | ) 323 | ``` 324 | 325 | Here's how to list all the available hardware for running models on Replicate: 326 | 327 | ```python 328 | >>> [hw.sku for hw in replicate.hardware.list()] 329 | ['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large'] 330 | ``` 331 | 332 | ## Fine-tune a model 333 | 334 | Use the [training API](https://replicate.com/docs/fine-tuning) 335 | to fine-tune models to make them better at a particular task. 336 | To see what **language models** currently support fine-tuning, 337 | check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models). 338 | 339 | If you're looking to fine-tune **image models**, 340 | check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model). 
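
You can also browse that collection of trainable language models from the Python client using the collections API shown above. A minimal sketch, assuming the collection's API slug matches the last path segment of its URL (`trainable-language-models`):

```python
import replicate

# Fetch the collection of trainable language models
# (slug assumed from the collection URL above).
collection = replicate.collections.get("trainable-language-models")
print(collection.description)

# Each entry is a Model, so `owner` and `name` are available.
for model in collection.models:
    print(f"{model.owner}/{model.name}")
```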
341 | 342 | Here's how to fine-tune a model on Replicate: 343 | 344 | ```python 345 | training = replicate.trainings.create( 346 | model="stability-ai/sdxl", 347 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 348 | input={ 349 | "input_images": "https://my-domain/training-images.zip", 350 | "token_string": "TOK", 351 | "caption_prefix": "a photo of TOK", 352 | "max_train_steps": 1000, 353 | "use_face_detection_instead": False 354 | }, 355 | # You need to create a model on Replicate that will be the destination for the trained version. 356 | destination="your-username/model-name" 357 | ) 358 | ``` 359 | 360 | ## Development 361 | 362 | See [CONTRIBUTING.md](CONTRIBUTING.md) 363 | -------------------------------------------------------------------------------- /replicate/model.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union 2 | 3 | from typing_extensions import NotRequired, TypedDict, Unpack, deprecated 4 | 5 | from replicate.exceptions import ReplicateException 6 | from replicate.identifier import ModelVersionIdentifier 7 | from replicate.pagination import Page 8 | from replicate.prediction import ( 9 | Prediction, 10 | _create_prediction_body, 11 | _json_to_prediction, 12 | ) 13 | from replicate.resource import Namespace, Resource 14 | from replicate.version import Version, Versions 15 | 16 | try: 17 | from pydantic import v1 as pydantic # type: ignore 18 | except ImportError: 19 | import pydantic # type: ignore 20 | 21 | 22 | if TYPE_CHECKING: 23 | from replicate.client import Client 24 | from replicate.prediction import Predictions 25 | 26 | 27 | class Model(Resource): 28 | """ 29 | A machine learning model hosted on Replicate. 30 | """ 31 | 32 | _client: "Client" = pydantic.PrivateAttr() 33 | 34 | url: str 35 | """ 36 | The URL of the model. 37 | """ 38 | 39 | owner: str 40 | """ 41 | The owner of the model. 42 | """ 43 | 44 | name: str 45 | """ 46 | The name of the model. 47 | """ 48 | 49 | description: Optional[str] 50 | """ 51 | The description of the model. 52 | """ 53 | 54 | visibility: Literal["public", "private"] 55 | """ 56 | The visibility of the model. Can be 'public' or 'private'. 57 | """ 58 | 59 | github_url: Optional[str] 60 | """ 61 | The GitHub URL of the model. 62 | """ 63 | 64 | paper_url: Optional[str] 65 | """ 66 | The URL of the paper related to the model. 67 | """ 68 | 69 | license_url: Optional[str] 70 | """ 71 | The URL of the license for the model. 72 | """ 73 | 74 | run_count: int 75 | """ 76 | The number of runs of the model. 77 | """ 78 | 79 | cover_image_url: Optional[str] 80 | """ 81 | The URL of the cover image for the model. 82 | """ 83 | 84 | default_example: Optional[Prediction] 85 | """ 86 | The default example of the model. 87 | """ 88 | 89 | latest_version: Optional[Version] 90 | """ 91 | The latest version of the model. 92 | """ 93 | 94 | @property 95 | def id(self) -> str: 96 | """ 97 | Return the qualified model name, in the format `owner/name`. 98 | """ 99 | return f"{self.owner}/{self.name}" 100 | 101 | @property 102 | @deprecated("Use `model.owner` instead.") 103 | def username(self) -> str: 104 | """ 105 | The name of the user or organization that owns the model. 106 | This attribute is deprecated and will be removed in future versions. 
107 | """ 108 | return self.owner 109 | 110 | @username.setter 111 | @deprecated("Use `model.owner` instead.") 112 | def username(self, value: str) -> None: 113 | self.owner = value 114 | 115 | def predict(self, *args, **kwargs) -> None: 116 | """ 117 | DEPRECATED: Use `replicate.run()` instead. 118 | """ 119 | 120 | raise ReplicateException( 121 | "The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `replicate.run()` instead. See https://github.com/replicate/replicate-python#readme" 122 | ) 123 | 124 | @property 125 | def versions(self) -> Versions: 126 | """ 127 | Get the versions of this model. 128 | """ 129 | 130 | return Versions(client=self._client, model=self) 131 | 132 | def reload(self) -> None: 133 | """ 134 | Load this object from the server. 135 | """ 136 | 137 | obj = self._client.models.get(f"{self.owner}/{self.name}") 138 | for name, value in obj.dict().items(): 139 | setattr(self, name, value) 140 | 141 | 142 | class Models(Namespace): 143 | """ 144 | Namespace for operations related to models. 145 | """ 146 | 147 | model = Model 148 | 149 | @property 150 | def predictions(self) -> "ModelsPredictions": 151 | """ 152 | Get a namespace for operations related to predictions on a model. 153 | """ 154 | 155 | return ModelsPredictions(client=self._client) 156 | 157 | def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821 158 | """ 159 | List all public models. 160 | 161 | Parameters: 162 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 163 | Returns: 164 | Page[Model]: A page of models. 165 | Raises: 166 | ValueError: If `cursor` is `None`. 167 | """ 168 | 169 | if cursor is None: 170 | raise ValueError("cursor cannot be None") 171 | 172 | resp = self._client._request("GET", "/v1/models" if cursor is ... else cursor) 173 | 174 | obj = resp.json() 175 | obj["results"] = [ 176 | _json_to_model(self._client, result) for result in obj["results"] 177 | ] 178 | 179 | return Page[Model](**obj) 180 | 181 | async def async_list( 182 | self, 183 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 184 | ) -> Page[Model]: 185 | """ 186 | List all public models. 187 | 188 | Parameters: 189 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 190 | Returns: 191 | Page[Model]: A page of models. 192 | Raises: 193 | ValueError: If `cursor` is `None`. 194 | """ 195 | 196 | if cursor is None: 197 | raise ValueError("cursor cannot be None") 198 | 199 | resp = await self._client._async_request( 200 | "GET", "/v1/models" if cursor is ... else cursor 201 | ) 202 | 203 | obj = resp.json() 204 | obj["results"] = [ 205 | _json_to_model(self._client, result) for result in obj["results"] 206 | ] 207 | 208 | return Page[Model](**obj) 209 | 210 | def get(self, key: str) -> Model: 211 | """ 212 | Get a model by name. 213 | 214 | Args: 215 | key: The qualified name of the model, in the format `owner/model-name`. 216 | Returns: 217 | The model. 218 | """ 219 | 220 | resp = self._client._request("GET", f"/v1/models/{key}") 221 | 222 | return _json_to_model(self._client, resp.json()) 223 | 224 | async def async_get(self, key: str) -> Model: 225 | """ 226 | Get a model by name. 227 | 228 | Args: 229 | key: The qualified name of the model, in the format `owner/model-name`. 230 | Returns: 231 | The model. 
232 | """ 233 | 234 | resp = await self._client._async_request("GET", f"/v1/models/{key}") 235 | 236 | return _json_to_model(self._client, resp.json()) 237 | 238 | class CreateModelParams(TypedDict): 239 | """Parameters for creating a model.""" 240 | 241 | hardware: str 242 | """The SKU for the hardware used to run the model. 243 | 244 | Possible values can be found by calling `replicate.hardware.list()`.""" 245 | 246 | visibility: Literal["public", "private"] 247 | """Whether the model should be public or private.""" 248 | 249 | description: NotRequired[str] 250 | """The description of the model.""" 251 | 252 | github_url: NotRequired[str] 253 | """A URL for the model's source code on GitHub.""" 254 | 255 | paper_url: NotRequired[str] 256 | """A URL for the model's paper.""" 257 | 258 | license_url: NotRequired[str] 259 | """A URL for the model's license.""" 260 | 261 | cover_image_url: NotRequired[str] 262 | """A URL for the model's cover image.""" 263 | 264 | def create( 265 | self, 266 | owner: str, 267 | name: str, 268 | **params: Unpack["Models.CreateModelParams"], 269 | ) -> Model: 270 | """ 271 | Create a model. 272 | """ 273 | 274 | body = _create_model_body(owner, name, **params) 275 | resp = self._client._request("POST", "/v1/models", json=body) 276 | 277 | return _json_to_model(self._client, resp.json()) 278 | 279 | async def async_create( 280 | self, owner: str, name: str, **params: Unpack["Models.CreateModelParams"] 281 | ) -> Model: 282 | """ 283 | Create a model. 284 | """ 285 | 286 | body = _create_model_body(owner, name, **params) 287 | resp = await self._client._async_request("POST", "/v1/models", json=body) 288 | 289 | return _json_to_model(self._client, resp.json()) 290 | 291 | 292 | class ModelsPredictions(Namespace): 293 | """ 294 | Namespace for operations related to predictions on a model. 295 | """ 296 | 297 | def create( 298 | self, 299 | model: Union[str, Tuple[str, str], "Model"], 300 | input: Dict[str, Any], 301 | **params: Unpack["Predictions.CreatePredictionParams"], 302 | ) -> Prediction: 303 | """ 304 | Create a new prediction with the model. 305 | """ 306 | 307 | url = _create_prediction_url_from_model(model) 308 | body = _create_prediction_body(version=None, input=input, **params) 309 | 310 | resp = self._client._request( 311 | "POST", 312 | url, 313 | json=body, 314 | ) 315 | 316 | return _json_to_prediction(self._client, resp.json()) 317 | 318 | async def async_create( 319 | self, 320 | model: Union[str, Tuple[str, str], "Model"], 321 | input: Dict[str, Any], 322 | **params: Unpack["Predictions.CreatePredictionParams"], 323 | ) -> Prediction: 324 | """ 325 | Create a new prediction with the model. 
326 | """ 327 | 328 | url = _create_prediction_url_from_model(model) 329 | body = _create_prediction_body(version=None, input=input, **params) 330 | 331 | resp = await self._client._async_request( 332 | "POST", 333 | url, 334 | json=body, 335 | ) 336 | 337 | return _json_to_prediction(self._client, resp.json()) 338 | 339 | 340 | def _create_model_body( # pylint: disable=too-many-arguments 341 | owner: str, 342 | name: str, 343 | *, 344 | visibility: str, 345 | hardware: str, 346 | description: Optional[str] = None, 347 | github_url: Optional[str] = None, 348 | paper_url: Optional[str] = None, 349 | license_url: Optional[str] = None, 350 | cover_image_url: Optional[str] = None, 351 | ) -> Dict[str, Any]: 352 | body = { 353 | "owner": owner, 354 | "name": name, 355 | "visibility": visibility, 356 | "hardware": hardware, 357 | } 358 | 359 | if description is not None: 360 | body["description"] = description 361 | 362 | if github_url is not None: 363 | body["github_url"] = github_url 364 | 365 | if paper_url is not None: 366 | body["paper_url"] = paper_url 367 | 368 | if license_url is not None: 369 | body["license_url"] = license_url 370 | 371 | if cover_image_url is not None: 372 | body["cover_image_url"] = cover_image_url 373 | 374 | return body 375 | 376 | 377 | def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: 378 | model = Model(**json) 379 | model._client = client 380 | if model.default_example is not None: 381 | model.default_example._client = client 382 | return model 383 | 384 | 385 | def _create_prediction_url_from_model( 386 | model: Union[str, Tuple[str, str], "Model"], 387 | ) -> str: 388 | owner, name = None, None 389 | if isinstance(model, Model): 390 | owner, name = model.owner, model.name 391 | elif isinstance(model, tuple): 392 | owner, name = model[0], model[1] 393 | elif isinstance(model, str): 394 | owner, name, version_id = ModelVersionIdentifier.parse(model) 395 | if version_id is not None: 396 | raise ValueError( 397 | f"Invalid reference to model version: {model}. 
Expected model or reference in the format owner/name" 398 | ) 399 | 400 | if owner is None or name is None: 401 | raise ValueError( 402 | "model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'" 403 | ) 404 | 405 | return f"/v1/models/{owner}/{name}/predictions" 406 | -------------------------------------------------------------------------------- /replicate/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import random 4 | import time 5 | from datetime import datetime 6 | from typing import ( 7 | TYPE_CHECKING, 8 | Any, 9 | AsyncIterator, 10 | Dict, 11 | Iterable, 12 | Iterator, 13 | Mapping, 14 | Optional, 15 | Type, 16 | Union, 17 | ) 18 | 19 | import httpx 20 | from typing_extensions import Unpack 21 | 22 | from replicate.__about__ import __version__ 23 | from replicate.account import Accounts 24 | from replicate.collection import Collections 25 | from replicate.deployment import Deployments 26 | from replicate.exceptions import ReplicateError 27 | from replicate.hardware import HardwareNamespace as Hardware 28 | from replicate.model import Models 29 | from replicate.prediction import Predictions 30 | from replicate.run import async_run, run 31 | from replicate.stream import async_stream, stream 32 | from replicate.training import Trainings 33 | 34 | if TYPE_CHECKING: 35 | from replicate.stream import ServerSentEvent 36 | 37 | 38 | class Client: 39 | """A Replicate API client library""" 40 | 41 | __client: Optional[httpx.Client] = None 42 | __async_client: Optional[httpx.AsyncClient] = None 43 | 44 | def __init__( 45 | self, 46 | api_token: Optional[str] = None, 47 | *, 48 | base_url: Optional[str] = None, 49 | timeout: Optional[httpx.Timeout] = None, 50 | **kwargs, 51 | ) -> None: 52 | super().__init__() 53 | 54 | self._api_token = api_token 55 | self._base_url = base_url 56 | self._timeout = timeout 57 | self._client_kwargs = kwargs 58 | 59 | self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) 60 | 61 | @property 62 | def _client(self) -> httpx.Client: 63 | if not self.__client: 64 | self.__client = _build_httpx_client( 65 | httpx.Client, 66 | self._api_token, 67 | self._base_url, 68 | self._timeout, 69 | **self._client_kwargs, 70 | ) # type: ignore[assignment] 71 | return self.__client # type: ignore[return-value] 72 | 73 | @property 74 | def _async_client(self) -> httpx.AsyncClient: 75 | if not self.__async_client: 76 | self.__async_client = _build_httpx_client( 77 | httpx.AsyncClient, 78 | self._api_token, 79 | self._base_url, 80 | self._timeout, 81 | **self._client_kwargs, 82 | ) # type: ignore[assignment] 83 | return self.__async_client # type: ignore[return-value] 84 | 85 | def _request(self, method: str, path: str, **kwargs) -> httpx.Response: 86 | resp = self._client.request(method, path, **kwargs) 87 | _raise_for_status(resp) 88 | 89 | return resp 90 | 91 | async def _async_request(self, method: str, path: str, **kwargs) -> httpx.Response: 92 | resp = await self._async_client.request(method, path, **kwargs) 93 | _raise_for_status(resp) 94 | 95 | return resp 96 | 97 | @property 98 | def accounts(self) -> Accounts: 99 | """ 100 | Namespace for operations related to accounts. 101 | """ 102 | 103 | return Accounts(client=self) 104 | 105 | @property 106 | def collections(self) -> Collections: 107 | """ 108 | Namespace for operations related to collections of models. 
109 | """ 110 | return Collections(client=self) 111 | 112 | @property 113 | def deployments(self) -> Deployments: 114 | """ 115 | Namespace for operations related to deployments. 116 | """ 117 | return Deployments(client=self) 118 | 119 | @property 120 | def hardware(self) -> Hardware: 121 | """ 122 | Namespace for operations related to hardware. 123 | """ 124 | return Hardware(client=self) 125 | 126 | @property 127 | def models(self) -> Models: 128 | """ 129 | Namespace for operations related to models. 130 | """ 131 | return Models(client=self) 132 | 133 | @property 134 | def predictions(self) -> Predictions: 135 | """ 136 | Namespace for operations related to predictions. 137 | """ 138 | return Predictions(client=self) 139 | 140 | @property 141 | def trainings(self) -> Trainings: 142 | """ 143 | Namespace for operations related to trainings. 144 | """ 145 | return Trainings(client=self) 146 | 147 | def run( 148 | self, 149 | ref: str, 150 | input: Optional[Dict[str, Any]] = None, 151 | **params: Unpack["Predictions.CreatePredictionParams"], 152 | ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 153 | """ 154 | Run a model and wait for its output. 155 | """ 156 | 157 | return run(self, ref, input, **params) 158 | 159 | async def async_run( 160 | self, 161 | ref: str, 162 | input: Optional[Dict[str, Any]] = None, 163 | **params: Unpack["Predictions.CreatePredictionParams"], 164 | ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 165 | """ 166 | Run a model and wait for its output asynchronously. 167 | """ 168 | 169 | return await async_run(self, ref, input, **params) 170 | 171 | def stream( 172 | self, 173 | ref: str, 174 | input: Optional[Dict[str, Any]] = None, 175 | **params: Unpack["Predictions.CreatePredictionParams"], 176 | ) -> Iterator["ServerSentEvent"]: 177 | """ 178 | Stream a model's output. 179 | """ 180 | 181 | return stream(self, ref, input, **params) 182 | 183 | async def async_stream( 184 | self, 185 | ref: str, 186 | input: Optional[Dict[str, Any]] = None, 187 | **params: Unpack["Predictions.CreatePredictionParams"], 188 | ) -> AsyncIterator["ServerSentEvent"]: 189 | """ 190 | Stream a model's output asynchronously. 191 | """ 192 | 193 | return async_stream(self, ref, input, **params) 194 | 195 | 196 | # Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155 197 | class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport): 198 | """A custom HTTP transport that automatically retries requests using an exponential backoff strategy 199 | for specific HTTP status codes and request methods. 
200 | """ 201 | 202 | RETRYABLE_METHODS = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]) 203 | RETRYABLE_STATUS_CODES = frozenset( 204 | [ 205 | 429, # Too Many Requests 206 | 503, # Service Unavailable 207 | 504, # Gateway Timeout 208 | ] 209 | ) 210 | MAX_BACKOFF_WAIT = 60 211 | 212 | def __init__( # pylint: disable=too-many-arguments 213 | self, 214 | wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], 215 | *, 216 | max_attempts: int = 10, 217 | max_backoff_wait: float = MAX_BACKOFF_WAIT, 218 | backoff_factor: float = 0.1, 219 | jitter_ratio: float = 0.1, 220 | retryable_methods: Optional[Iterable[str]] = None, 221 | retry_status_codes: Optional[Iterable[int]] = None, 222 | ) -> None: 223 | self._wrapped_transport = wrapped_transport 224 | 225 | if jitter_ratio < 0 or jitter_ratio > 0.5: 226 | raise ValueError( 227 | f"jitter ratio should be between 0 and 0.5, actual {jitter_ratio}" 228 | ) 229 | 230 | self.max_attempts = max_attempts 231 | self.backoff_factor = backoff_factor 232 | self.retryable_methods = ( 233 | frozenset(retryable_methods) 234 | if retryable_methods 235 | else self.RETRYABLE_METHODS 236 | ) 237 | self.retry_status_codes = ( 238 | frozenset(retry_status_codes) 239 | if retry_status_codes 240 | else self.RETRYABLE_STATUS_CODES 241 | ) 242 | self.jitter_ratio = jitter_ratio 243 | self.max_backoff_wait = max_backoff_wait 244 | 245 | def _calculate_sleep( 246 | self, attempts_made: int, headers: Union[httpx.Headers, Mapping[str, str]] 247 | ) -> float: 248 | retry_after_header = (headers.get("Retry-After") or "").strip() 249 | if retry_after_header: 250 | if retry_after_header.isdigit(): 251 | return float(retry_after_header) 252 | 253 | try: 254 | parsed_date = datetime.fromisoformat(retry_after_header).astimezone() 255 | diff = (parsed_date - datetime.now().astimezone()).total_seconds() 256 | if diff > 0: 257 | return min(diff, self.max_backoff_wait) 258 | except ValueError: 259 | pass 260 | 261 | backoff = self.backoff_factor * (2 ** (attempts_made - 1)) 262 | jitter = (backoff * self.jitter_ratio) * random.choice([1, -1]) # noqa: S311 263 | total_backoff = backoff + jitter 264 | return min(total_backoff, self.max_backoff_wait) 265 | 266 | def handle_request(self, request: httpx.Request) -> httpx.Response: 267 | response = self._wrapped_transport.handle_request(request) # type: ignore 268 | 269 | if request.method not in self.retryable_methods: 270 | return response 271 | 272 | remaining_attempts = self.max_attempts - 1 273 | attempts_made = 1 274 | 275 | while True: 276 | if ( 277 | remaining_attempts < 1 278 | or response.status_code not in self.retry_status_codes 279 | ): 280 | return response 281 | 282 | sleep_for = self._calculate_sleep(attempts_made, response.headers) 283 | time.sleep(sleep_for) 284 | 285 | response = self._wrapped_transport.handle_request(request) # type: ignore 286 | 287 | attempts_made += 1 288 | remaining_attempts -= 1 289 | 290 | async def handle_async_request(self, request: httpx.Request) -> httpx.Response: 291 | response = await self._wrapped_transport.handle_async_request(request) # type: ignore 292 | 293 | if request.method not in self.retryable_methods: 294 | return response 295 | 296 | remaining_attempts = self.max_attempts - 1 297 | attempts_made = 1 298 | 299 | while True: 300 | if ( 301 | remaining_attempts < 1 302 | or response.status_code not in self.retry_status_codes 303 | ): 304 | return response 305 | 306 | sleep_for = self._calculate_sleep(attempts_made, response.headers) 307 | await 
asyncio.sleep(sleep_for) 308 | 309 | response = await self._wrapped_transport.handle_async_request(request) # type: ignore 310 | 311 | attempts_made += 1 312 | remaining_attempts -= 1 313 | 314 | def close(self) -> None: 315 | self._wrapped_transport.close() # type: ignore 316 | 317 | async def aclose(self) -> None: 318 | await self._wrapped_transport.aclose() # type: ignore 319 | 320 | 321 | def _build_httpx_client( 322 | client_type: Type[Union[httpx.Client, httpx.AsyncClient]], 323 | api_token: Optional[str] = None, 324 | base_url: Optional[str] = None, 325 | timeout: Optional[httpx.Timeout] = None, 326 | **kwargs, 327 | ) -> Union[httpx.Client, httpx.AsyncClient]: 328 | headers = kwargs.pop("headers", {}) 329 | headers["User-Agent"] = f"replicate-python/{__version__}" 330 | 331 | if ( 332 | api_token := api_token or os.environ.get("REPLICATE_API_TOKEN") 333 | ) and api_token != "": 334 | headers["Authorization"] = f"Bearer {api_token}" 335 | 336 | base_url = ( 337 | base_url or os.environ.get("REPLICATE_BASE_URL") or "https://api.replicate.com" 338 | ) 339 | if base_url == "": 340 | base_url = "https://api.replicate.com" 341 | 342 | timeout = timeout or httpx.Timeout( 343 | 5.0, read=30.0, write=30.0, connect=5.0, pool=10.0 344 | ) 345 | 346 | transport = kwargs.pop("transport", None) or ( 347 | httpx.HTTPTransport() 348 | if client_type is httpx.Client 349 | else httpx.AsyncHTTPTransport() 350 | ) 351 | 352 | return client_type( 353 | base_url=base_url, 354 | headers=headers, 355 | timeout=timeout, 356 | transport=RetryTransport(wrapped_transport=transport), # type: ignore[arg-type] 357 | **kwargs, 358 | ) 359 | 360 | 361 | def _raise_for_status(resp: httpx.Response) -> None: 362 | if 400 <= resp.status_code < 600: 363 | raise ReplicateError.from_response(resp) 364 | -------------------------------------------------------------------------------- /replicate/training.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | Dict, 5 | List, 6 | Literal, 7 | Optional, 8 | Tuple, 9 | TypedDict, 10 | Union, 11 | overload, 12 | ) 13 | 14 | from typing_extensions import NotRequired, Unpack 15 | 16 | from replicate.files import upload_file 17 | from replicate.identifier import ModelVersionIdentifier 18 | from replicate.json import encode_json 19 | from replicate.model import Model 20 | from replicate.pagination import Page 21 | from replicate.resource import Namespace, Resource 22 | from replicate.version import Version 23 | 24 | try: 25 | from pydantic import v1 as pydantic # type: ignore 26 | except ImportError: 27 | import pydantic # type: ignore 28 | 29 | if TYPE_CHECKING: 30 | from replicate.client import Client 31 | 32 | 33 | class Training(Resource): 34 | """ 35 | A training made for a model hosted on Replicate. 
36 | """ 37 | 38 | _client: "Client" = pydantic.PrivateAttr() 39 | 40 | id: str 41 | """The unique ID of the training.""" 42 | 43 | model: str 44 | """An identifier for the model used to create the prediction, in the form `owner/name`.""" 45 | 46 | version: Union[str, Version] 47 | """The version of the model used to create the training.""" 48 | 49 | destination: Optional[str] 50 | """The model destination of the training.""" 51 | 52 | status: Literal["starting", "processing", "succeeded", "failed", "canceled"] 53 | """The status of the training.""" 54 | 55 | input: Optional[Dict[str, Any]] 56 | """The input to the training.""" 57 | 58 | output: Optional[Any] 59 | """The output of the training.""" 60 | 61 | logs: Optional[str] 62 | """The logs of the training.""" 63 | 64 | error: Optional[str] 65 | """The error encountered during the training, if any.""" 66 | 67 | created_at: Optional[str] 68 | """When the training was created.""" 69 | 70 | started_at: Optional[str] 71 | """When the training was started.""" 72 | 73 | completed_at: Optional[str] 74 | """When the training was completed, if finished.""" 75 | 76 | urls: Optional[Dict[str, str]] 77 | """ 78 | URLs associated with the training. 79 | 80 | The following keys are available: 81 | - `get`: A URL to fetch the training. 82 | - `cancel`: A URL to cancel the training. 83 | """ 84 | 85 | def cancel(self) -> None: 86 | """ 87 | Cancel a running training. 88 | """ 89 | 90 | canceled = self._client.trainings.cancel(self.id) 91 | for name, value in canceled.dict().items(): 92 | setattr(self, name, value) 93 | 94 | async def async_cancel(self) -> None: 95 | """ 96 | Cancel a running training asynchronously. 97 | """ 98 | 99 | canceled = await self._client.trainings.async_cancel(self.id) 100 | for name, value in canceled.dict().items(): 101 | setattr(self, name, value) 102 | 103 | def reload(self) -> None: 104 | """ 105 | Load the training from the server. 106 | """ 107 | 108 | updated = self._client.trainings.get(self.id) 109 | for name, value in updated.dict().items(): 110 | setattr(self, name, value) 111 | 112 | async def async_reload(self) -> None: 113 | """ 114 | Load the training from the server asynchronously. 115 | """ 116 | 117 | updated = await self._client.trainings.async_get(self.id) 118 | for name, value in updated.dict().items(): 119 | setattr(self, name, value) 120 | 121 | 122 | class Trainings(Namespace): 123 | """ 124 | Namespace for operations related to trainings. 125 | """ 126 | 127 | def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Training]: # noqa: F821 128 | """ 129 | List your trainings. 130 | 131 | Parameters: 132 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 133 | Returns: 134 | Page[Training]: A page of trainings. 135 | Raises: 136 | ValueError: If `cursor` is `None`. 137 | """ 138 | 139 | if cursor is None: 140 | raise ValueError("cursor cannot be None") 141 | 142 | resp = self._client._request( 143 | "GET", "/v1/trainings" if cursor is ... else cursor 144 | ) 145 | 146 | obj = resp.json() 147 | obj["results"] = [ 148 | _json_to_training(self._client, result) for result in obj["results"] 149 | ] 150 | 151 | return Page[Training](**obj) 152 | 153 | async def async_list( 154 | self, 155 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 156 | ) -> Page[Training]: 157 | """ 158 | List your trainings. 159 | 160 | Parameters: 161 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 
162 | Returns: 163 | Page[Training]: A page of trainings. 164 | Raises: 165 | ValueError: If `cursor` is `None`. 166 | """ 167 | 168 | if cursor is None: 169 | raise ValueError("cursor cannot be None") 170 | 171 | resp = await self._client._async_request( 172 | "GET", "/v1/trainings" if cursor is ... else cursor 173 | ) 174 | 175 | obj = resp.json() 176 | obj["results"] = [ 177 | _json_to_training(self._client, result) for result in obj["results"] 178 | ] 179 | 180 | return Page[Training](**obj) 181 | 182 | def get(self, id: str) -> Training: 183 | """ 184 | Get a training by ID. 185 | 186 | Args: 187 | id: The ID of the training. 188 | Returns: 189 | Training: The training object. 190 | """ 191 | 192 | resp = self._client._request( 193 | "GET", 194 | f"/v1/trainings/{id}", 195 | ) 196 | 197 | return _json_to_training(self._client, resp.json()) 198 | 199 | async def async_get(self, id: str) -> Training: 200 | """ 201 | Get a training by ID. 202 | 203 | Args: 204 | id: The ID of the training. 205 | Returns: 206 | Training: The training object. 207 | """ 208 | 209 | resp = await self._client._async_request( 210 | "GET", 211 | f"/v1/trainings/{id}", 212 | ) 213 | 214 | return _json_to_training(self._client, resp.json()) 215 | 216 | class CreateTrainingParams(TypedDict): 217 | """Parameters for creating a training.""" 218 | 219 | destination: Union[str, Tuple[str, str], "Model"] 220 | webhook: NotRequired[str] 221 | webhook_completed: NotRequired[str] 222 | webhook_events_filter: NotRequired[List[str]] 223 | 224 | @overload 225 | def create( # pylint: disable=too-many-arguments 226 | self, 227 | version: str, 228 | input: Dict[str, Any], 229 | destination: str, 230 | webhook: Optional[str] = None, 231 | webhook_events_filter: Optional[List[str]] = None, 232 | **kwargs, 233 | ) -> Training: ... 234 | 235 | @overload 236 | def create( 237 | self, 238 | model: Union[str, Tuple[str, str], "Model"], 239 | version: Union[str, Version], 240 | input: Optional[Dict[str, Any]] = None, 241 | **params: Unpack["Trainings.CreateTrainingParams"], 242 | ) -> Training: ... 243 | 244 | def create( # type: ignore 245 | self, 246 | *args, 247 | model: Optional[Union[str, Tuple[str, str], "Model"]] = None, 248 | version: Optional[Union[str, Version]] = None, 249 | input: Optional[Dict[str, Any]] = None, 250 | **params: Unpack["Trainings.CreateTrainingParams"], 251 | ) -> Training: 252 | """ 253 | Create a new training using the specified model version as a base. 
254 | """ 255 | 256 | url = None 257 | 258 | # Support positional arguments for backwards compatibility 259 | if args: 260 | if shorthand := args[0] if len(args) > 0 else None: 261 | url = _create_training_url_from_shorthand(shorthand) 262 | 263 | input = args[1] if len(args) > 1 else input 264 | if len(args) > 2: 265 | params["destination"] = args[2] 266 | if len(args) > 3: 267 | params["webhook"] = args[3] 268 | if len(args) > 4: 269 | params["webhook_completed"] = args[4] 270 | if len(args) > 5: 271 | params["webhook_events_filter"] = args[5] 272 | elif model and version: 273 | url = _create_training_url_from_model_and_version(model, version) 274 | elif model is None and isinstance(version, str): 275 | url = _create_training_url_from_shorthand(version) 276 | 277 | if not url: 278 | raise ValueError("model and version or shorthand version must be specified") 279 | 280 | body = _create_training_body(input, **params) 281 | resp = self._client._request( 282 | "POST", 283 | url, 284 | json=body, 285 | ) 286 | 287 | return _json_to_training(self._client, resp.json()) 288 | 289 | async def async_create( 290 | self, 291 | model: Union[str, Tuple[str, str], "Model"], 292 | version: Union[str, Version], 293 | input: Dict[str, Any], 294 | **params: Unpack["Trainings.CreateTrainingParams"], 295 | ) -> Training: 296 | """ 297 | Create a new training using the specified model version as a base. 298 | 299 | Args: 300 | version: The ID of the base model version that you're using to train a new model version. 301 | input: The input to the training. 302 | destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. 303 | webhook: The URL to send a POST request to when the training is completed. Defaults to None. 304 | webhook_completed: The URL to receive a POST request when the prediction is completed. 305 | webhook_events_filter: The events to send to the webhook. Defaults to None. 306 | Returns: 307 | The training object. 308 | """ 309 | 310 | url = _create_training_url_from_model_and_version(model, version) 311 | body = _create_training_body(input, **params) 312 | resp = await self._client._async_request( 313 | "POST", 314 | url, 315 | json=body, 316 | ) 317 | 318 | return _json_to_training(self._client, resp.json()) 319 | 320 | def cancel(self, id: str) -> Training: 321 | """ 322 | Cancel a training. 323 | 324 | Args: 325 | id: The ID of the training to cancel. 326 | Returns: 327 | Training: The canceled training object. 328 | """ 329 | 330 | resp = self._client._request( 331 | "POST", 332 | f"/v1/trainings/{id}/cancel", 333 | ) 334 | 335 | return _json_to_training(self._client, resp.json()) 336 | 337 | async def async_cancel(self, id: str) -> Training: 338 | """ 339 | Cancel a training. 340 | 341 | Args: 342 | id: The ID of the training to cancel. 343 | Returns: 344 | Training: The canceled training object. 
345 | """ 346 | 347 | resp = await self._client._async_request( 348 | "POST", 349 | f"/v1/trainings/{id}/cancel", 350 | ) 351 | 352 | return _json_to_training(self._client, resp.json()) 353 | 354 | 355 | def _create_training_body( 356 | input: Optional[Dict[str, Any]] = None, 357 | *, 358 | destination: Optional[Union[str, Tuple[str, str], "Model"]] = None, 359 | webhook: Optional[str] = None, 360 | webhook_completed: Optional[str] = None, 361 | webhook_events_filter: Optional[List[str]] = None, 362 | ) -> Dict[str, Any]: 363 | body = {} 364 | 365 | if input is not None: 366 | body["input"] = encode_json(input, upload_file=upload_file) 367 | 368 | if destination is None: 369 | raise ValueError( 370 | "A destination must be provided as a positional or keyword argument." 371 | ) 372 | if isinstance(destination, Model): 373 | destination = f"{destination.owner}/{destination.name}" 374 | elif isinstance(destination, tuple): 375 | destination = f"{destination[0]}/{destination[1]}" 376 | body["destination"] = destination 377 | 378 | if webhook is not None: 379 | body["webhook"] = webhook 380 | 381 | if webhook_completed is not None: 382 | body["webhook_completed"] = webhook_completed 383 | 384 | if webhook_events_filter is not None: 385 | body["webhook_events_filter"] = webhook_events_filter 386 | 387 | return body 388 | 389 | 390 | def _create_training_url_from_shorthand(ref: str) -> str: 391 | owner, name, version_id = ModelVersionIdentifier.parse(ref) 392 | return f"/v1/models/{owner}/{name}/versions/{version_id}/trainings" 393 | 394 | 395 | def _create_training_url_from_model_and_version( 396 | model: Union[str, Tuple[str, str], "Model"], 397 | version: Union[str, "Version"], 398 | ) -> str: 399 | if isinstance(model, Model): 400 | owner, name = model.owner, model.name 401 | elif isinstance(model, tuple): 402 | owner, name = model[0], model[1] 403 | elif isinstance(model, str): 404 | owner, name, _ = ModelVersionIdentifier.parse(model) 405 | else: 406 | raise ValueError( 407 | "model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'" 408 | ) 409 | 410 | if isinstance(version, Version): 411 | version_id = version.id 412 | else: 413 | version_id = version 414 | 415 | return f"/v1/models/{owner}/{name}/versions/{version_id}/trainings" 416 | 417 | 418 | def _json_to_training(client: "Client", json: Dict[str, Any]) -> Training: 419 | training = Training(**json) 420 | training._client = client 421 | return training 422 | -------------------------------------------------------------------------------- /replicate/deployment.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union 2 | 3 | from typing_extensions import Unpack, deprecated 4 | 5 | from replicate.account import Account 6 | from replicate.pagination import Page 7 | from replicate.prediction import ( 8 | Prediction, 9 | _create_prediction_body, 10 | _json_to_prediction, 11 | ) 12 | from replicate.resource import Namespace, Resource 13 | 14 | try: 15 | from pydantic import v1 as pydantic # type: ignore 16 | except ImportError: 17 | import pydantic # type: ignore 18 | 19 | 20 | if TYPE_CHECKING: 21 | from replicate.client import Client 22 | from replicate.prediction import Predictions 23 | 24 | 25 | class Deployment(Resource): 26 | """ 27 | A deployment of a model hosted on Replicate. 
28 | """ 29 | 30 | _client: "Client" = pydantic.PrivateAttr() 31 | 32 | owner: str 33 | """ 34 | The name of the user or organization that owns the deployment. 35 | """ 36 | 37 | name: str 38 | """ 39 | The name of the deployment. 40 | """ 41 | 42 | class Release(Resource): 43 | """ 44 | A release of a deployment. 45 | """ 46 | 47 | number: int 48 | """ 49 | The release number. 50 | """ 51 | 52 | model: str 53 | """ 54 | The model identifier string in the format of `{model_owner}/{model_name}`. 55 | """ 56 | 57 | version: str 58 | """ 59 | The ID of the model version used in the release. 60 | """ 61 | 62 | created_at: str 63 | """ 64 | The time the release was created. 65 | """ 66 | 67 | created_by: Optional[Account] 68 | """ 69 | The account that created the release. 70 | """ 71 | 72 | class Configuration(Resource): 73 | """ 74 | A configuration for a deployment. 75 | """ 76 | 77 | hardware: str 78 | """ 79 | The SKU for the hardware used to run the model. 80 | """ 81 | 82 | min_instances: int 83 | """ 84 | The minimum number of instances for scaling. 85 | """ 86 | 87 | max_instances: int 88 | """ 89 | The maximum number of instances for scaling. 90 | """ 91 | 92 | configuration: Configuration 93 | """ 94 | The deployment configuration. 95 | """ 96 | 97 | current_release: Optional[Release] 98 | """ 99 | The current release of the deployment. 100 | """ 101 | 102 | @property 103 | @deprecated("Use `deployment.owner` instead.") 104 | def username(self) -> str: 105 | """ 106 | The name of the user or organization that owns the deployment. 107 | This attribute is deprecated and will be removed in future versions. 108 | """ 109 | return self.owner 110 | 111 | @property 112 | def id(self) -> str: 113 | """ 114 | Return the qualified deployment name, in the format `owner/name`. 115 | """ 116 | return f"{self.owner}/{self.name}" 117 | 118 | @property 119 | def predictions(self) -> "DeploymentPredictions": 120 | """ 121 | Get the predictions for this deployment. 122 | """ 123 | 124 | return DeploymentPredictions(client=self._client, deployment=self) 125 | 126 | 127 | class Deployments(Namespace): 128 | """ 129 | Namespace for operations related to deployments. 130 | """ 131 | 132 | _client: "Client" 133 | 134 | def list( 135 | self, 136 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 137 | ) -> Page[Deployment]: 138 | """ 139 | List all deployments. 140 | 141 | Returns: 142 | A page of Deployments. 143 | """ 144 | 145 | if cursor is None: 146 | raise ValueError("cursor cannot be None") 147 | 148 | resp = self._client._request( 149 | "GET", "/v1/deployments" if cursor is ... else cursor 150 | ) 151 | 152 | obj = resp.json() 153 | obj["results"] = [ 154 | _json_to_deployment(self._client, result) for result in obj["results"] 155 | ] 156 | 157 | return Page[Deployment](**obj) 158 | 159 | async def async_list( 160 | self, 161 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 162 | ) -> Page[Deployment]: 163 | """ 164 | List all deployments. 165 | 166 | Returns: 167 | A page of Deployments. 168 | """ 169 | if cursor is None: 170 | raise ValueError("cursor cannot be None") 171 | 172 | resp = await self._client._async_request( 173 | "GET", "/v1/deployments" if cursor is ... else cursor 174 | ) 175 | 176 | obj = resp.json() 177 | obj["results"] = [ 178 | _json_to_deployment(self._client, result) for result in obj["results"] 179 | ] 180 | 181 | return Page[Deployment](**obj) 182 | 183 | def get(self, name: str) -> Deployment: 184 | """ 185 | Get a deployment by name. 
186 | 187 | Args: 188 | name: The name of the deployment, in the format `owner/deployment-name`. 189 | Returns: 190 | The deployment. 191 | """ 192 | 193 | owner, name = name.split("/", 1) 194 | 195 | resp = self._client._request( 196 | "GET", 197 | f"/v1/deployments/{owner}/{name}", 198 | ) 199 | 200 | return _json_to_deployment(self._client, resp.json()) 201 | 202 | async def async_get(self, name: str) -> Deployment: 203 | """ 204 | Get a deployment by name. 205 | 206 | Args: 207 | name: The name of the deployment, in the format `owner/deployment-name`. 208 | Returns: 209 | The deployment. 210 | """ 211 | 212 | owner, name = name.split("/", 1) 213 | 214 | resp = await self._client._async_request( 215 | "GET", 216 | f"/v1/deployments/{owner}/{name}", 217 | ) 218 | 219 | return _json_to_deployment(self._client, resp.json()) 220 | 221 | class CreateDeploymentParams(TypedDict): 222 | """ 223 | Parameters for creating a new deployment. 224 | """ 225 | 226 | name: str 227 | """The name of the deployment.""" 228 | 229 | model: str 230 | """The model identifier string in the format of `{model_owner}/{model_name}`.""" 231 | 232 | version: str 233 | """The version of the model to deploy.""" 234 | 235 | hardware: str 236 | """The SKU for the hardware used to run the model.""" 237 | 238 | min_instances: int 239 | """The minimum number of instances for scaling.""" 240 | 241 | max_instances: int 242 | """The maximum number of instances for scaling.""" 243 | 244 | def create(self, **params: Unpack[CreateDeploymentParams]) -> Deployment: 245 | """ 246 | Create a new deployment. 247 | 248 | Args: 249 | params: Configuration for the new deployment. 250 | Returns: 251 | The newly created Deployment. 252 | """ 253 | 254 | if name := params.get("name", None): 255 | if "/" in name: 256 | _, name = name.split("/", 1) 257 | params["name"] = name 258 | 259 | resp = self._client._request( 260 | "POST", 261 | "/v1/deployments", 262 | json=params, 263 | ) 264 | 265 | return _json_to_deployment(self._client, resp.json()) 266 | 267 | async def async_create( 268 | self, **params: Unpack[CreateDeploymentParams] 269 | ) -> Deployment: 270 | """ 271 | Create a new deployment. 272 | 273 | Args: 274 | params: Configuration for the new deployment. 275 | Returns: 276 | The newly created Deployment. 277 | """ 278 | 279 | if name := params.get("name", None): 280 | if "/" in name: 281 | _, name = name.split("/", 1) 282 | params["name"] = name 283 | 284 | resp = await self._client._async_request( 285 | "POST", 286 | "/v1/deployments", 287 | json=params, 288 | ) 289 | 290 | return _json_to_deployment(self._client, resp.json()) 291 | 292 | class UpdateDeploymentParams(TypedDict, total=False): 293 | """ 294 | Parameters for updating an existing deployment. 295 | """ 296 | 297 | version: str 298 | """The version of the model to deploy.""" 299 | 300 | hardware: str 301 | """The SKU for the hardware used to run the model.""" 302 | 303 | min_instances: int 304 | """The minimum number of instances for scaling.""" 305 | 306 | max_instances: int 307 | """The maximum number of instances for scaling.""" 308 | 309 | def update( 310 | self, 311 | deployment_owner: str, 312 | deployment_name: str, 313 | **params: Unpack[UpdateDeploymentParams], 314 | ) -> Deployment: 315 | """ 316 | Update an existing deployment. 317 | 318 | Args: 319 | deployment_owner: The owner of the deployment. 320 | deployment_name: The name of the deployment. 321 | params: Configuration updates for the deployment. 322 | Returns: 323 | The updated Deployment.
324 | """ 325 | 326 | resp = self._client._request( 327 | "PATCH", 328 | f"/v1/deployments/{deployment_owner}/{deployment_name}", 329 | json=params, 330 | ) 331 | 332 | return _json_to_deployment(self._client, resp.json()) 333 | 334 | async def async_update( 335 | self, 336 | deployment_owner: str, 337 | deployment_name: str, 338 | **params: Unpack[UpdateDeploymentParams], 339 | ) -> Deployment: 340 | """ 341 | Update an existing deployment. 342 | 343 | Args: 344 | deployment_owner: The owner of the deployment. 345 | deployment_name: The name of the deployment. 346 | params: Configuration updates for the deployment. 347 | Returns: 348 | The updated Deployment. 349 | """ 350 | 351 | resp = await self._client._async_request( 352 | "PATCH", 353 | f"/v1/deployments/{deployment_owner}/{deployment_name}", 354 | json=params, 355 | ) 356 | 357 | return _json_to_deployment(self._client, resp.json()) 358 | 359 | @property 360 | def predictions(self) -> "DeploymentsPredictions": 361 | """ 362 | Get predictions for deployments. 363 | """ 364 | 365 | return DeploymentsPredictions(client=self._client) 366 | 367 | 368 | def _json_to_deployment(client: "Client", json: Dict[str, Any]) -> Deployment: 369 | deployment = Deployment(**json) 370 | deployment._client = client 371 | return deployment 372 | 373 | 374 | class DeploymentPredictions(Namespace): 375 | """ 376 | Namespace for operations related to predictions in a deployment. 377 | """ 378 | 379 | _deployment: Deployment 380 | 381 | def __init__(self, client: "Client", deployment: Deployment) -> None: 382 | super().__init__(client=client) 383 | self._deployment = deployment 384 | 385 | def create( 386 | self, 387 | input: Dict[str, Any], 388 | **params: Unpack["Predictions.CreatePredictionParams"], 389 | ) -> Prediction: 390 | """ 391 | Create a new prediction with the deployment. 392 | """ 393 | 394 | body = _create_prediction_body(version=None, input=input, **params) 395 | 396 | resp = self._client._request( 397 | "POST", 398 | f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", 399 | json=body, 400 | ) 401 | 402 | return _json_to_prediction(self._client, resp.json()) 403 | 404 | async def async_create( 405 | self, 406 | input: Dict[str, Any], 407 | **params: Unpack["Predictions.CreatePredictionParams"], 408 | ) -> Prediction: 409 | """ 410 | Create a new prediction with the deployment. 411 | """ 412 | 413 | body = _create_prediction_body(version=None, input=input, **params) 414 | 415 | resp = await self._client._async_request( 416 | "POST", 417 | f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", 418 | json=body, 419 | ) 420 | 421 | return _json_to_prediction(self._client, resp.json()) 422 | 423 | 424 | class DeploymentsPredictions(Namespace): 425 | """ 426 | Namespace for operations related to predictions in deployments. 427 | """ 428 | 429 | def create( 430 | self, 431 | deployment: Union[str, Tuple[str, str], Deployment], 432 | input: Dict[str, Any], 433 | **params: Unpack["Predictions.CreatePredictionParams"], 434 | ) -> Prediction: 435 | """ 436 | Create a new prediction with the deployment. 
437 | """ 438 | 439 | url = _create_prediction_url_from_deployment(deployment) 440 | body = _create_prediction_body(version=None, input=input, **params) 441 | 442 | resp = self._client._request( 443 | "POST", 444 | url, 445 | json=body, 446 | ) 447 | 448 | return _json_to_prediction(self._client, resp.json()) 449 | 450 | async def async_create( 451 | self, 452 | deployment: Union[str, Tuple[str, str], Deployment], 453 | input: Dict[str, Any], 454 | **params: Unpack["Predictions.CreatePredictionParams"], 455 | ) -> Prediction: 456 | """ 457 | Create a new prediction with the deployment. 458 | """ 459 | 460 | url = _create_prediction_url_from_deployment(deployment) 461 | body = _create_prediction_body(version=None, input=input, **params) 462 | 463 | resp = await self._client._async_request( 464 | "POST", 465 | url, 466 | json=body, 467 | ) 468 | 469 | return _json_to_prediction(self._client, resp.json()) 470 | 471 | 472 | def _create_prediction_url_from_deployment( 473 | deployment: Union[str, Tuple[str, str], Deployment], 474 | ) -> str: 475 | owner, name = None, None 476 | if isinstance(deployment, Deployment): 477 | owner, name = deployment.owner, deployment.name 478 | elif isinstance(deployment, tuple): 479 | owner, name = deployment[0], deployment[1] 480 | elif isinstance(deployment, str): 481 | owner, name = deployment.split("/", 1) 482 | 483 | if owner is None or name is None: 484 | raise ValueError( 485 | "deployment must be a Deployment, a tuple of (owner, name), or a string in the format 'owner/name'" 486 | ) 487 | 488 | return f"/v1/deployments/{owner}/{name}/predictions" 489 | -------------------------------------------------------------------------------- /tests/test_deployment.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import httpx 4 | import pytest 5 | import respx 6 | 7 | from replicate.client import Client 8 | 9 | router = respx.Router(base_url="https://api.replicate.com/v1") 10 | 11 | router.route( 12 | method="GET", 13 | path="/deployments/replicate/my-app-image-generator", 14 | name="deployments.get", 15 | ).mock( 16 | return_value=httpx.Response( 17 | 201, 18 | json={ 19 | "owner": "replicate", 20 | "name": "my-app-image-generator", 21 | "current_release": { 22 | "number": 1, 23 | "model": "stability-ai/sdxl", 24 | "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", 25 | "created_at": "2024-02-15T16:32:57.018467Z", 26 | "created_by": { 27 | "type": "organization", 28 | "username": "acme", 29 | "name": "Acme Corp, Inc.", 30 | "github_url": "https://github.com/acme", 31 | }, 32 | "configuration": { 33 | "hardware": "gpu-t4", 34 | "min_instances": 1, 35 | "max_instances": 5, 36 | }, 37 | }, 38 | }, 39 | ) 40 | ) 41 | router.route( 42 | method="POST", 43 | path="/deployments/replicate/my-app-image-generator/predictions", 44 | name="deployments.predictions.create", 45 | ).mock( 46 | return_value=httpx.Response( 47 | 201, 48 | json={ 49 | "id": "p1", 50 | "model": "replicate/my-app-image-generator", 51 | "version": "v1", 52 | "urls": { 53 | "get": "https://api.replicate.com/v1/predictions/p1", 54 | "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", 55 | }, 56 | "created_at": "2022-04-26T20:00:40.658234Z", 57 | "source": "api", 58 | "status": "processing", 59 | "input": {"text": "world"}, 60 | "output": None, 61 | "error": None, 62 | "logs": "", 63 | }, 64 | ) 65 | ) 66 | router.route( 67 | method="GET", 68 | path="/deployments", 69 | name="deployments.list", 70 | ).mock( 
71 | return_value=httpx.Response( 72 | 200, 73 | json={ 74 | "results": [ 75 | { 76 | "owner": "acme", 77 | "name": "image-upscaler", 78 | "current_release": { 79 | "number": 1, 80 | "model": "acme/esrgan", 81 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 82 | "created_at": "2022-01-01T00:00:00Z", 83 | "created_by": { 84 | "type": "organization", 85 | "username": "acme", 86 | "name": "Acme, Inc.", 87 | }, 88 | "configuration": { 89 | "hardware": "gpu-t4", 90 | "min_instances": 1, 91 | "max_instances": 5, 92 | }, 93 | }, 94 | }, 95 | { 96 | "owner": "acme", 97 | "name": "text-generator", 98 | "current_release": { 99 | "number": 2, 100 | "model": "acme/acme-llama", 101 | "version": "4b7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccbb", 102 | "created_at": "2022-02-02T00:00:00Z", 103 | "created_by": { 104 | "type": "organization", 105 | "username": "acme", 106 | "name": "Acme, Inc.", 107 | }, 108 | "configuration": { 109 | "hardware": "cpu", 110 | "min_instances": 2, 111 | "max_instances": 10, 112 | }, 113 | }, 114 | }, 115 | ] 116 | }, 117 | ) 118 | ) 119 | 120 | router.route( 121 | method="POST", 122 | path="/deployments", 123 | name="deployments.create", 124 | ).mock( 125 | return_value=httpx.Response( 126 | 201, 127 | json={ 128 | "owner": "acme", 129 | "name": "new-deployment", 130 | "current_release": { 131 | "number": 1, 132 | "model": "acme/new-model", 133 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 134 | "created_at": "2022-01-01T00:00:00Z", 135 | "created_by": { 136 | "type": "organization", 137 | "username": "acme", 138 | "name": "Acme, Inc.", 139 | }, 140 | "configuration": { 141 | "hardware": "gpu-t4", 142 | "min_instances": 1, 143 | "max_instances": 5, 144 | }, 145 | }, 146 | }, 147 | ) 148 | ) 149 | 150 | 151 | router.route( 152 | method="PATCH", 153 | path="/deployments/acme/image-upscaler", 154 | name="deployments.update", 155 | ).mock( 156 | return_value=httpx.Response( 157 | 200, 158 | json={ 159 | "owner": "acme", 160 | "name": "image-upscaler", 161 | "current_release": { 162 | "number": 2, 163 | "model": "acme/esrgan-updated", 164 | "version": "new-version-id", 165 | "created_at": "2022-02-02T00:00:00Z", 166 | "created_by": { 167 | "type": "organization", 168 | "username": "acme", 169 | "name": "Acme, Inc.", 170 | }, 171 | "configuration": { 172 | "hardware": "gpu-v100", 173 | "min_instances": 2, 174 | "max_instances": 10, 175 | }, 176 | }, 177 | }, 178 | ) 179 | ) 180 | 181 | 182 | router.route(host="api.replicate.com").pass_through() 183 | 184 | 185 | @pytest.mark.asyncio 186 | @pytest.mark.parametrize("async_flag", [True, False]) 187 | async def test_deployment_get(async_flag): 188 | client = Client( 189 | api_token="test-token", transport=httpx.MockTransport(router.handler) 190 | ) 191 | 192 | if async_flag: 193 | deployment = await client.deployments.async_get( 194 | "replicate/my-app-image-generator" 195 | ) 196 | else: 197 | deployment = client.deployments.get("replicate/my-app-image-generator") 198 | 199 | assert router["deployments.get"].called 200 | 201 | assert deployment.owner == "replicate" 202 | assert deployment.name == "my-app-image-generator" 203 | assert deployment.current_release is not None 204 | assert deployment.current_release.number == 1 205 | assert deployment.current_release.model == "stability-ai/sdxl" 206 | assert ( 207 | deployment.current_release.version 208 | == "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf" 209 | ) 210 | assert 
deployment.current_release is not None 211 | assert deployment.current_release.created_by is not None 212 | assert deployment.current_release.created_by.type == "organization" 213 | assert deployment.current_release.created_by.username == "acme" 214 | assert deployment.current_release.created_by.name == "Acme Corp, Inc." 215 | assert deployment.current_release.created_by.github_url == "https://github.com/acme" 216 | 217 | 218 | @pytest.mark.asyncio 219 | @pytest.mark.parametrize("async_flag", [True, False]) 220 | async def test_deployment_predictions_create(async_flag): 221 | client = Client( 222 | api_token="test-token", transport=httpx.MockTransport(router.handler) 223 | ) 224 | 225 | if async_flag: 226 | deployment = await client.deployments.async_get( 227 | "replicate/my-app-image-generator" 228 | ) 229 | 230 | prediction = await deployment.predictions.async_create( 231 | input={"text": "world"}, 232 | webhook="https://example.com/webhook", 233 | webhook_events_filter=["completed"], 234 | stream=True, 235 | ) 236 | else: 237 | deployment = client.deployments.get("replicate/my-app-image-generator") 238 | 239 | prediction = deployment.predictions.create( 240 | input={"text": "world"}, 241 | webhook="https://example.com/webhook", 242 | webhook_events_filter=["completed"], 243 | stream=True, 244 | ) 245 | 246 | assert router["deployments.predictions.create"].called 247 | request = router["deployments.predictions.create"].calls[0].request 248 | request_body = json.loads(request.content) 249 | assert request_body["input"] == {"text": "world"} 250 | assert request_body["webhook"] == "https://example.com/webhook" 251 | assert request_body["webhook_events_filter"] == ["completed"] 252 | assert request_body["stream"] is True 253 | 254 | assert prediction.id == "p1" 255 | assert prediction.input == {"text": "world"} 256 | 257 | 258 | @pytest.mark.asyncio 259 | @pytest.mark.parametrize("async_flag", [True, False]) 260 | async def test_deployments_predictions_create(async_flag): 261 | client = Client( 262 | api_token="test-token", transport=httpx.MockTransport(router.handler) 263 | ) 264 | 265 | if async_flag: 266 | prediction = await client.deployments.predictions.async_create( 267 | deployment="replicate/my-app-image-generator", 268 | input={"text": "world"}, 269 | webhook="https://example.com/webhook", 270 | webhook_events_filter=["completed"], 271 | stream=True, 272 | ) 273 | else: 274 | prediction = client.deployments.predictions.create( 275 | deployment="replicate/my-app-image-generator", 276 | input={"text": "world"}, 277 | webhook="https://example.com/webhook", 278 | webhook_events_filter=["completed"], 279 | stream=True, 280 | ) 281 | 282 | assert router["deployments.predictions.create"].called 283 | request = router["deployments.predictions.create"].calls[0].request 284 | request_body = json.loads(request.content) 285 | assert request_body["input"] == {"text": "world"} 286 | assert request_body["webhook"] == "https://example.com/webhook" 287 | assert request_body["webhook_events_filter"] == ["completed"] 288 | assert request_body["stream"] is True 289 | 290 | assert prediction.id == "p1" 291 | assert prediction.input == {"text": "world"} 292 | 293 | 294 | @respx.mock 295 | @pytest.mark.asyncio 296 | @pytest.mark.parametrize("async_flag", [True, False]) 297 | async def test_deployments_list(async_flag): 298 | client = Client( 299 | api_token="test-token", transport=httpx.MockTransport(router.handler) 300 | ) 301 | 302 | if async_flag: 303 | deployments = await
client.deployments.async_list() 304 | else: 305 | deployments = client.deployments.list() 306 | 307 | assert router["deployments.list"].called 308 | 309 | assert len(deployments.results) == 2 310 | assert deployments.results[0].owner == "acme" 311 | assert deployments.results[0].name == "image-upscaler" 312 | assert deployments.results[0].current_release is not None 313 | assert deployments.results[0].current_release.number == 1 314 | assert deployments.results[0].current_release.model == "acme/esrgan" 315 | assert deployments.results[1].owner == "acme" 316 | assert deployments.results[1].name == "text-generator" 317 | assert deployments.results[1].current_release is not None 318 | assert deployments.results[1].current_release.number == 2 319 | assert deployments.results[1].current_release.model == "acme/acme-llama" 320 | 321 | 322 | @respx.mock 323 | @pytest.mark.asyncio 324 | @pytest.mark.parametrize("async_flag", [True, False]) 325 | async def test_create_deployment(async_flag): 326 | client = Client( 327 | api_token="test-token", transport=httpx.MockTransport(router.handler) 328 | ) 329 | 330 | config = { 331 | "name": "new-deployment", 332 | "model": "acme/new-model", 333 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 334 | "hardware": "gpu-t4", 335 | "min_instances": 1, 336 | "max_instances": 5, 337 | } 338 | 339 | if async_flag: 340 | deployment = await client.deployments.async_create(**config) 341 | else: 342 | deployment = client.deployments.create(**config) 343 | 344 | assert router["deployments.create"].called 345 | 346 | assert deployment.owner == "acme" 347 | assert deployment.name == "new-deployment" 348 | assert deployment.current_release is not None 349 | assert deployment.current_release.number == 1 350 | assert deployment.current_release.model == "acme/new-model" 351 | assert ( 352 | deployment.current_release.version 353 | == "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 354 | ) 355 | assert deployment.current_release.created_by is not None 356 | assert deployment.current_release.created_by.type == "organization" 357 | assert deployment.current_release.created_by.username == "acme" 358 | assert deployment.current_release.created_by.name == "Acme, Inc." 
359 | assert deployment.current_release.configuration.hardware == "gpu-t4" 360 | assert deployment.current_release.configuration.min_instances == 1 361 | assert deployment.current_release.configuration.max_instances == 5 362 | 363 | 364 | @respx.mock 365 | @pytest.mark.asyncio 366 | @pytest.mark.parametrize("async_flag", [True, False]) 367 | async def test_update_deployment(async_flag): 368 | config = { 369 | "version": "new-version-id", 370 | "hardware": "gpu-v100", 371 | "min_instances": 2, 372 | "max_instances": 10, 373 | } 374 | 375 | client = Client( 376 | api_token="test-token", transport=httpx.MockTransport(router.handler) 377 | ) 378 | 379 | if async_flag: 380 | updated_deployment = await client.deployments.async_update( 381 | deployment_owner="acme", deployment_name="image-upscaler", **config 382 | ) 383 | else: 384 | updated_deployment = client.deployments.update( 385 | deployment_owner="acme", deployment_name="image-upscaler", **config 386 | ) 387 | 388 | assert router["deployments.update"].called 389 | request = router["deployments.update"].calls[0].request 390 | request_body = json.loads(request.content) 391 | assert request_body == config 392 | 393 | assert updated_deployment.owner == "acme" 394 | assert updated_deployment.name == "image-upscaler" 395 | assert updated_deployment.current_release is not None 396 | assert updated_deployment.current_release.number == 2 397 | assert updated_deployment.current_release.model == "acme/esrgan-updated" 398 | assert updated_deployment.current_release.version == "new-version-id" 399 | assert updated_deployment.current_release.configuration.hardware == "gpu-v100" 400 | assert updated_deployment.current_release.configuration.min_instances == 2 401 | assert updated_deployment.current_release.configuration.max_instances == 10 402 | --------------------------------------------------------------------------------
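For orientation, here is a minimal usage sketch of the deployments API exercised by the tests above. It is not part of the repository; the deployment name `acme/image-upscaler`, the scaling values, and the prediction input are placeholders, and the client is assumed to read `REPLICATE_API_TOKEN` from the environment when no token is passed.

from replicate.client import Client

client = Client()  # assumed to pick up REPLICATE_API_TOKEN from the environment

# Fetch a deployment by its qualified "owner/name" identifier.
deployment = client.deployments.get("acme/image-upscaler")

# Create a prediction against the deployment; input keys depend on the deployed model.
prediction = deployment.predictions.create(input={"image": "https://example.com/photo.png"})

# Adjust the deployment's scaling configuration (PATCH /v1/deployments/{owner}/{name}).
client.deployments.update(
    deployment_owner="acme",
    deployment_name="image-upscaler",
    min_instances=1,
    max_instances=3,
)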