├── .gitattributes ├── .github └── workflows │ ├── ci.yaml │ └── release.yaml ├── .gitignore ├── .vscode ├── extensions.json └── settings.json ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── pyproject.toml ├── replicate ├── __about__.py ├── __init__.py ├── account.py ├── client.py ├── collection.py ├── deployment.py ├── exceptions.py ├── files.py ├── hardware.py ├── identifier.py ├── json.py ├── model.py ├── pagination.py ├── prediction.py ├── resource.py ├── run.py ├── schema.py ├── stream.py ├── training.py └── version.py ├── requirements-dev.txt ├── requirements.txt ├── script ├── format ├── lint ├── setup └── test └── tests ├── __init__.py ├── cassettes ├── collections-get.yaml ├── collections-list.yaml ├── hardware-list.yaml ├── models-create.yaml ├── models-get.yaml ├── models-list.yaml ├── models-list__pagination.yaml ├── models-predictions-create.yaml ├── predictions-cancel.yaml ├── predictions-create.yaml ├── predictions-get.yaml ├── predictions-stream.yaml ├── run.yaml ├── run__concurrently.yaml ├── run__invalid-token.yaml ├── test_predictions_cancel[False].yaml ├── test_predictions_cancel[True].yaml ├── test_predictions_create_by_model[False].yaml ├── test_predictions_create_by_model[True].yaml ├── trainings-cancel.yaml ├── trainings-create.yaml ├── trainings-create__invalid-destination.yaml └── trainings-get.yaml ├── conftest.py ├── test_account.py ├── test_client.py ├── test_collection.py ├── test_deployment.py ├── test_hardware.py ├── test_identifier.py ├── test_model.py ├── test_pagination.py ├── test_prediction.py ├── test_run.py ├── test_stream.py ├── test_training.py └── test_version.py /.gitattributes: -------------------------------------------------------------------------------- 1 | tests/cassettes/** binary 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | 7 | pull_request: 8 | branches: ["main"] 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | 14 | name: "Test Python ${{ matrix.python-version }}" 15 | 16 | env: 17 | REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} 18 | 19 | timeout-minutes: 10 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 25 | 26 | defaults: 27 | run: 28 | shell: bash 29 | 30 | steps: 31 | - uses: actions/checkout@v3 32 | - uses: actions/setup-python@v3 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | cache: "pip" 36 | 37 | - name: Setup 38 | run: ./script/setup 39 | 40 | - name: Test 41 | run: ./script/test 42 | 43 | - name: Lint 44 | run: ./script/lint 45 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: ["*"] 6 | 7 | jobs: 8 | release: 9 | runs-on: ubuntu-latest 10 | 11 | name: "Publish to PyPI" 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - uses: actions/setup-python@v3 16 | with: 17 | python-version: "3.10" 18 | - name: Install pypa/build 19 | run: python -m pip install build --user 20 | - name: Build a package 21 | run: python -m build 22 | - name: Publish distribution 📦 to PyPI 23 | uses: pypa/gh-action-pypi-publish@release/v1 24 | with: 25 | password: ${{ secrets.PYPI_API_TOKEN }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff", 4 | "ms-python.python", 5 | "ms-python.vscode-pylance", 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "editor.formatOnType": true, 4 | "editor.formatOnPaste": true, 5 | "editor.renderControlCharacters": true, 6 | "editor.suggest.localityBonus": true, 7 | "files.insertFinalNewline": true, 8 | "files.trimFinalNewlines": true, 9 | "[python]": { 10 | "editor.defaultFormatter": "charliermarsh.ruff", 11 | "editor.formatOnSave": true, 12 | "editor.codeActionsOnSave": { 13 | "source.fixAll": "explicit", 14 | "source.organizeImports": "explicit" 15 | } 16 | }, 17 | "python.languageServer": "Pylance", 18 | "python.analysis.typeCheckingMode": "basic", 19 | "python.testing.pytestArgs": [ 20 | "-vvv", 21 | "python" 22 | ], 23 | "python.testing.unittestEnabled": false, 24 | "python.testing.pytestEnabled": true, 25 | "ruff.lint.args": [ 26 | "--config=pyproject.toml" 27 | ], 28 | "ruff.format.args": [ 29 | "--config=pyproject.toml" 30 | ], 31 | } 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | - [Making a contribution](#making-a-contribution) 4 | - [Signing your work](#signing-your-work) 5 | - [How to sign off your commits](#how-to-sign-off-your-commits) 6 | - [Development](#development) 7 | - [Environment variables](#environment-variables) 8 | - [Publishing a release](#publishing-a-release) 9 | 10 | ## Making a contribution 11 | 12 | ### Signing your work 13 | 14 | Each commit you contribute to this repo must be signed off (not to be confused with **[signing](https://git-scm.com/book/en/v2/Git-Tools-Signing-Your-Work)**). It certifies that you wrote the patch, or have the right to contribute it. It is called the [Developer Certificate of Origin](https://developercertificate.org/) and was originally developed for the Linux kernel. 15 | 16 | If you can certify the following: 17 | 18 | ``` 19 | By making a contribution to this project, I certify that: 20 | 21 | (a) The contribution was created in whole or in part by me and I 22 | have the right to submit it under the open source license 23 | indicated in the file; or 24 | 25 | (b) The contribution is based upon previous work that, to the best 26 | of my knowledge, is covered under an appropriate open source 27 | license and I have the right under that license to submit that 28 | work with modifications, whether created in whole or in part 29 | by me, under the same open source license (unless I am 30 | permitted to submit under a different license), as indicated 31 | in the file; or 32 | 33 | (c) The contribution was provided directly to me by some other 34 | person who certified (a), (b) or (c) and I have not modified 35 | it. 36 | 37 | (d) I understand and agree that this project and the contribution 38 | are public and that a record of the contribution (including all 39 | personal information I submit with it, including my sign-off) is 40 | maintained indefinitely and may be redistributed consistent with 41 | this project or the open source license(s) involved. 42 | ``` 43 | 44 | Then add this line to each of your Git commit messages, with your name and email: 45 | 46 | ``` 47 | Signed-off-by: Sam Smith 48 | ``` 49 | 50 | ### How to sign off your commits 51 | 52 | If you're using the `git` CLI, you can sign a commit by passing the `-s` option: `git commit -s -m "Reticulate splines"` 53 | 54 | You can also create a git hook which will sign off all your commits automatically. Using hooks also allows you to sign off commits when using non-command-line tools like GitHub Desktop or VS Code. 55 | 56 | First, create the hook file and make it executable: 57 | 58 | ```sh 59 | cd your/checkout/of/replicate-python 60 | touch .git/hooks/prepare-commit-msg 61 | chmod +x .git/hooks/prepare-commit-msg 62 | ``` 63 | 64 | Then paste the following into the file: 65 | 66 | ``` 67 | #!/bin/sh 68 | 69 | NAME=$(git config user.name) 70 | EMAIL=$(git config user.email) 71 | 72 | if [ -z "$NAME" ]; then 73 | echo "empty git config user.name" 74 | exit 1 75 | fi 76 | 77 | if [ -z "$EMAIL" ]; then 78 | echo "empty git config user.email" 79 | exit 1 80 | fi 81 | 82 | git interpret-trailers --if-exists doNothing --trailer \ 83 | "Signed-off-by: $NAME <$EMAIL>" \ 84 | --in-place "$1" 85 | ``` 86 | 87 | ## Development 88 | 89 | To run the tests: 90 | 91 | ```sh 92 | pip install -r requirements-dev.txt 93 | pytest 94 | ``` 95 | 96 | To install the package in development: 97 | 98 | ```sh 99 | pip install -e . 100 | ``` 101 | 102 | ### Environment variables 103 | 104 | - `REPLICATE_API_BASE_URL`: Defaults to `https://api.replicate.com` but can be overriden to point the client at a development host. 105 | - `REPLICATE_API_TOKEN`: Required. Find your token at https://replicate.com/#token 106 | 107 | ## Publishing a release 108 | 109 | This project has a [GitHub Actions workflow](/.github/workflows/ci.yaml) that publishes the `replicate` package to PyPI. The release process is triggered by manually creating and pushing a new git tag. 110 | 111 | First, set the version number in [pyproject.toml](pyproject.toml) and commit it to the `main` branch: 112 | 113 | ``` 114 | version = "0.7.0" 115 | ``` 116 | 117 | Then run the following in your local checkout: 118 | 119 | ```sh 120 | git checkout main 121 | git fetch --all --tags 122 | git tag 0.7.0 123 | git push --tags 124 | ``` 125 | 126 | Then visit [github.com/replicate/replicate-python/actions](https://github.com/replicate/replicate-python/actions) to monitor the release process. 127 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2022, Replicate, Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Replicate Python client 2 | 3 | This is a Python client for [Replicate](https://replicate.com). It lets you run models from your Python code or Jupyter notebook, and do various other things on Replicate. 4 | 5 | > **👋** Check out an interactive version of this tutorial on [Google Colab](https://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c). 6 | > 7 | > [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c) 8 | 9 | 10 | ## Install 11 | 12 | ```sh 13 | pip install replicate 14 | ``` 15 | 16 | ## Authenticate 17 | 18 | Before running any Python scripts that use the API, you need to set your Replicate API token in your environment. 19 | 20 | Grab your token from [replicate.com/account](https://replicate.com/account) and set it as an environment variable: 21 | 22 | ``` 23 | export REPLICATE_API_TOKEN= 24 | ``` 25 | 26 | We recommend not adding the token directly to your source code, because you don't want to put your credentials in source control. If anyone used your API key, their usage would be charged to your account. 27 | 28 | ## Run a model 29 | 30 | Create a new Python file and add the following code, replacing the model identifier and input with your own: 31 | 32 | ```python 33 | >>> import replicate 34 | >>> replicate.run( 35 | "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", 36 | input={"prompt": "a 19th century portrait of a wombat gentleman"} 37 | ) 38 | 39 | ['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png'] 40 | ``` 41 | 42 | Some models, particularly language models, may not require the version string. Refer to the API documentation for the model for more on the specifics: 43 | 44 | ```python 45 | replicate.run( 46 | "meta/meta-llama-3-70b-instruct", 47 | input={ 48 | "prompt": "Can you write a poem about open source machine learning?", 49 | "system_prompt": "You are a helpful, respectful and honest assistant.", 50 | }, 51 | ) 52 | ``` 53 | 54 | Some models, like [andreasjansson/blip-2](https://replicate.com/andreasjansson/blip-2), have files as inputs. 55 | To run a model that takes a file input, 56 | pass a URL to a publicly accessible file. 57 | Or, for smaller files (<10MB), you can pass a file handle directly. 58 | 59 | ```python 60 | >>> output = replicate.run( 61 | "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9", 62 | input={ "image": open("path/to/mystery.jpg") } 63 | ) 64 | 65 | "an astronaut riding a horse" 66 | ``` 67 | 68 | > [!NOTE] 69 | > You can also use the Replicate client asynchronously by prepending `async_` to the method name. 70 | > 71 | > Here's an example of how to run several predictions concurrently and wait for them all to complete: 72 | > 73 | > ```python 74 | > import asyncio 75 | > import replicate 76 | > 77 | > # https://replicate.com/stability-ai/sdxl 78 | > model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" 79 | > prompts = [ 80 | > f"A chariot pulled by a team of {count} rainbow unicorns" 81 | > for count in ["two", "four", "six", "eight"] 82 | > ] 83 | > 84 | > async with asyncio.TaskGroup() as tg: 85 | > tasks = [ 86 | > tg.create_task(replicate.async_run(model_version, input={"prompt": prompt})) 87 | > for prompt in prompts 88 | > ] 89 | > 90 | > results = await asyncio.gather(*tasks) 91 | > print(results) 92 | > ``` 93 | 94 | ## Run a model and stream its output 95 | 96 | Replicate’s API supports server-sent event streams (SSEs) for language models. 97 | Use the `stream` method to consume tokens as they're produced by the model. 98 | 99 | ```python 100 | import replicate 101 | 102 | for event in replicate.stream( 103 | "meta/meta-llama-3-70b-instruct", 104 | input={ 105 | "prompt": "Please write a haiku about llamas.", 106 | }, 107 | ): 108 | print(str(event), end="") 109 | ``` 110 | 111 | You can also stream the output of a prediction you create. 112 | This is helpful when you want the ID of the prediction separate from its output. 113 | 114 | ```python 115 | version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" 116 | prediction = replicate.predictions.create( 117 | version=version, 118 | input={"prompt": "Please write a haiku about llamas."}, 119 | stream=True, 120 | ) 121 | 122 | for event in prediction.stream(): 123 | print(str(event), end="") 124 | ``` 125 | 126 | For more information, see 127 | ["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs. 128 | 129 | 130 | ## Run a model in the background 131 | 132 | You can start a model and run it in the background: 133 | 134 | ```python 135 | >>> model = replicate.models.get("kvfrans/clipdraw") 136 | >>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") 137 | >>> prediction = replicate.predictions.create( 138 | version=version, 139 | input={"prompt":"Watercolor painting of an underwater submarine"}) 140 | 141 | >>> prediction 142 | Prediction(...) 143 | 144 | >>> prediction.status 145 | 'starting' 146 | 147 | >>> dict(prediction) 148 | {"id": "...", "status": "starting", ...} 149 | 150 | >>> prediction.reload() 151 | >>> prediction.status 152 | 'processing' 153 | 154 | >>> print(prediction.logs) 155 | iteration: 0, render:loss: -0.6171875 156 | iteration: 10, render:loss: -0.92236328125 157 | iteration: 20, render:loss: -1.197265625 158 | iteration: 30, render:loss: -1.3994140625 159 | 160 | >>> prediction.wait() 161 | 162 | >>> prediction.status 163 | 'succeeded' 164 | 165 | >>> prediction.output 166 | 'https://.../output.png' 167 | ``` 168 | 169 | ## Run a model in the background and get a webhook 170 | 171 | You can run a model and get a webhook when it completes, instead of waiting for it to finish: 172 | 173 | ```python 174 | model = replicate.models.get("ai-forever/kandinsky-2.2") 175 | version = model.versions.get("ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463") 176 | prediction = replicate.predictions.create( 177 | version=version, 178 | input={"prompt":"Watercolor painting of an underwater submarine"}, 179 | webhook="https://example.com/your-webhook", 180 | webhook_events_filter=["completed"] 181 | ) 182 | ``` 183 | 184 | For details on receiving webhooks, see [replicate.com/docs/webhooks](https://replicate.com/docs/webhooks). 185 | 186 | ## Compose models into a pipeline 187 | 188 | You can run a model and feed the output into another model: 189 | 190 | ```python 191 | laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05") 192 | swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a") 193 | image = laionide.predict(prompt="avocado armchair") 194 | upscaled_image = swinir.predict(image=image) 195 | ``` 196 | 197 | ## Get output from a running model 198 | 199 | Run a model and get its output while it's running: 200 | 201 | ```python 202 | iterator = replicate.run( 203 | "pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf", 204 | input={"prompts": "san francisco sunset"} 205 | ) 206 | 207 | for image in iterator: 208 | display(image) 209 | ``` 210 | 211 | ## Cancel a prediction 212 | 213 | You can cancel a running prediction: 214 | 215 | ```python 216 | >>> model = replicate.models.get("kvfrans/clipdraw") 217 | >>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") 218 | >>> prediction = replicate.predictions.create( 219 | version=version, 220 | input={"prompt":"Watercolor painting of an underwater submarine"} 221 | ) 222 | 223 | >>> prediction.status 224 | 'starting' 225 | 226 | >>> prediction.cancel() 227 | 228 | >>> prediction.reload() 229 | >>> prediction.status 230 | 'canceled' 231 | ``` 232 | 233 | ## List predictions 234 | 235 | You can list all the predictions you've run: 236 | 237 | ```python 238 | replicate.predictions.list() 239 | # [, ] 240 | ``` 241 | 242 | Lists of predictions are paginated. You can get the next page of predictions by passing the `next` property as an argument to the `list` method: 243 | 244 | ```python 245 | page1 = replicate.predictions.list() 246 | 247 | if page1.next: 248 | page2 = replicate.predictions.list(page1.next) 249 | ``` 250 | 251 | ## Load output files 252 | 253 | Output files are returned as HTTPS URLs. You can load an output file as a buffer: 254 | 255 | ```python 256 | import replicate 257 | from PIL import Image 258 | from urllib.request import urlretrieve 259 | 260 | out = replicate.run( 261 | "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", 262 | input={"prompt": "wavy colorful abstract patterns, oceans"} 263 | ) 264 | 265 | urlretrieve(out[0], "/tmp/out.png") 266 | background = Image.open("/tmp/out.png") 267 | ``` 268 | 269 | ## List models 270 | 271 | You can the models you've created: 272 | 273 | ```python 274 | replicate.models.list() 275 | ``` 276 | 277 | Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method, or you can use the `paginate` method to fetch pages automatically. 278 | 279 | ```python 280 | # Automatic pagination using `replicate.paginate` (recommended) 281 | models = [] 282 | for page in replicate.paginate(replicate.models.list): 283 | models.extend(page.results) 284 | if len(models) > 100: 285 | break 286 | 287 | # Manual pagination using `next` cursors 288 | page = replicate.models.list() 289 | while page: 290 | models.extend(page.results) 291 | if len(models) > 100: 292 | break 293 | page = replicate.models.list(page.next) if page.next else None 294 | ``` 295 | 296 | You can also find collections of featured models on Replicate: 297 | 298 | ```python 299 | >>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page] 300 | >>> collections[0].slug 301 | "vision-models" 302 | >>> collections[0].description 303 | "Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)" 304 | 305 | >>> replicate.collections.get("text-to-image").models 306 | [, ...] 307 | ``` 308 | 309 | ## Create a model 310 | 311 | You can create a model for a user or organization 312 | with a given name, visibility, and hardware SKU: 313 | 314 | ```python 315 | import replicate 316 | 317 | model = replicate.models.create( 318 | owner="your-username", 319 | name="my-model", 320 | visibility="public", 321 | hardware="gpu-a40-large" 322 | ) 323 | ``` 324 | 325 | Here's how to list of all the available hardware for running models on Replicate: 326 | 327 | ```python 328 | >>> [hw.sku for hw in replicate.hardware.list()] 329 | ['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large'] 330 | ``` 331 | 332 | ## Fine-tune a model 333 | 334 | Use the [training API](https://replicate.com/docs/fine-tuning) 335 | to fine-tune models to make them better at a particular task. 336 | To see what **language models** currently support fine-tuning, 337 | check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models). 338 | 339 | If you're looking to fine-tune **image models**, 340 | check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model). 341 | 342 | Here's how to fine-tune a model on Replicate: 343 | 344 | ```python 345 | training = replicate.trainings.create( 346 | model="stability-ai/sdxl", 347 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 348 | input={ 349 | "input_images": "https://my-domain/training-images.zip", 350 | "token_string": "TOK", 351 | "caption_prefix": "a photo of TOK", 352 | "max_train_steps": 1000, 353 | "use_face_detection_instead": False 354 | }, 355 | # You need to create a model on Replicate that will be the destination for the trained version. 356 | destination="your-username/model-name" 357 | ) 358 | ``` 359 | 360 | ## Development 361 | 362 | See [CONTRIBUTING.md](CONTRIBUTING.md) 363 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "replicate" 7 | version = "0.25.2" 8 | description = "Python client for Replicate" 9 | readme = "README.md" 10 | license = { file = "LICENSE" } 11 | authors = [{ name = "Replicate, Inc." }] 12 | requires-python = ">=3.8" 13 | dependencies = [ 14 | "httpx>=0.21.0,<1", 15 | "packaging", 16 | "pydantic>1.10.7", 17 | "typing_extensions>=4.5.0", 18 | ] 19 | optional-dependencies = { dev = [ 20 | "pylint", 21 | "pyright", 22 | "pytest", 23 | "pytest-asyncio", 24 | "pytest-recording", 25 | "respx", 26 | "ruff>=0.3.3", 27 | ] } 28 | 29 | [project.urls] 30 | homepage = "https://replicate.com" 31 | repository = "https://github.com/replicate/replicate-python" 32 | 33 | [tool.pytest.ini_options] 34 | testpaths = "tests/" 35 | 36 | [tool.setuptools] 37 | packages = ["replicate"] 38 | 39 | [tool.setuptools.package-data] 40 | "replicate" = ["py.typed"] 41 | 42 | [tool.pylint.main] 43 | disable = [ 44 | "C0301", # Line too long 45 | "C0413", # Import should be placed at the top of the module 46 | "C0114", # Missing module docstring 47 | "R0801", # Similar lines in N files 48 | "W0212", # Access to a protected member 49 | "W0622", # Redefining built-in 50 | "R0903", # Too few public methods 51 | ] 52 | good-names = ["id"] 53 | 54 | [tool.ruff.lint] 55 | select = [ 56 | "E", # pycodestyle error 57 | "F", # Pyflakes 58 | "I", # isort 59 | "W", # pycodestyle warning 60 | "UP", # pyupgrade 61 | "S", # flake8-bandit 62 | "BLE", # flake8-blind-except 63 | "FBT", # flake8-boolean-trap 64 | "B", # flake8-bugbear 65 | "ANN", # flake8-annotations 66 | ] 67 | ignore = [ 68 | "E501", # Line too long 69 | "S113", # Probable use of requests call without timeout 70 | "ANN001", # Missing type annotation for function argument 71 | "ANN002", # Missing type annotation for `*args` 72 | "ANN003", # Missing type annotation for `**kwargs` 73 | "ANN101", # Missing type annotation for self in method 74 | "ANN102", # Missing type annotation for cls in classmethod 75 | "W191", # Indentation contains tabs 76 | ] 77 | 78 | [tool.ruff.lint.per-file-ignores] 79 | "tests/*" = [ 80 | "S101", # Use of assert 81 | "S106", # Possible use of hard-coded password function arguments 82 | "ANN201", # Missing return type annotation for public function 83 | ] 84 | -------------------------------------------------------------------------------- /replicate/__about__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | __version__ = version(__package__) 4 | -------------------------------------------------------------------------------- /replicate/__init__.py: -------------------------------------------------------------------------------- 1 | from replicate.client import Client 2 | from replicate.pagination import async_paginate as _async_paginate 3 | from replicate.pagination import paginate as _paginate 4 | 5 | default_client = Client() 6 | 7 | run = default_client.run 8 | async_run = default_client.async_run 9 | 10 | stream = default_client.stream 11 | async_stream = default_client.async_stream 12 | 13 | paginate = _paginate 14 | async_paginate = _async_paginate 15 | 16 | collections = default_client.collections 17 | hardware = default_client.hardware 18 | deployments = default_client.deployments 19 | models = default_client.models 20 | predictions = default_client.predictions 21 | trainings = default_client.trainings 22 | -------------------------------------------------------------------------------- /replicate/account.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Literal, Optional 2 | 3 | from replicate.resource import Namespace, Resource 4 | 5 | 6 | class Account(Resource): 7 | """ 8 | A user or organization account on Replicate. 9 | """ 10 | 11 | type: Literal["user", "organization"] 12 | """The type of account.""" 13 | 14 | username: str 15 | """The username of the account.""" 16 | 17 | name: str 18 | """The name of the account.""" 19 | 20 | github_url: Optional[str] 21 | """The GitHub URL of the account.""" 22 | 23 | 24 | class Accounts(Namespace): 25 | """ 26 | Namespace for operations related to accounts. 27 | """ 28 | 29 | def current(self) -> Account: 30 | """ 31 | Get the current account. 32 | 33 | Returns: 34 | Account: The current account. 35 | """ 36 | 37 | resp = self._client._request("GET", "/v1/account") 38 | obj = resp.json() 39 | 40 | return _json_to_account(obj) 41 | 42 | async def async_current(self) -> Account: 43 | """ 44 | Get the current account. 45 | 46 | Returns: 47 | Account: The current account. 48 | """ 49 | 50 | resp = await self._client._async_request("GET", "/v1/account") 51 | obj = resp.json() 52 | 53 | return _json_to_account(obj) 54 | 55 | 56 | def _json_to_account(json: Dict[str, Any]) -> Account: 57 | return Account(**json) 58 | -------------------------------------------------------------------------------- /replicate/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import random 4 | import time 5 | from datetime import datetime 6 | from typing import ( 7 | TYPE_CHECKING, 8 | Any, 9 | AsyncIterator, 10 | Dict, 11 | Iterable, 12 | Iterator, 13 | Mapping, 14 | Optional, 15 | Type, 16 | Union, 17 | ) 18 | 19 | import httpx 20 | from typing_extensions import Unpack 21 | 22 | from replicate.__about__ import __version__ 23 | from replicate.account import Accounts 24 | from replicate.collection import Collections 25 | from replicate.deployment import Deployments 26 | from replicate.exceptions import ReplicateError 27 | from replicate.hardware import HardwareNamespace as Hardware 28 | from replicate.model import Models 29 | from replicate.prediction import Predictions 30 | from replicate.run import async_run, run 31 | from replicate.stream import async_stream, stream 32 | from replicate.training import Trainings 33 | 34 | if TYPE_CHECKING: 35 | from replicate.stream import ServerSentEvent 36 | 37 | 38 | class Client: 39 | """A Replicate API client library""" 40 | 41 | __client: Optional[httpx.Client] = None 42 | __async_client: Optional[httpx.AsyncClient] = None 43 | 44 | def __init__( 45 | self, 46 | api_token: Optional[str] = None, 47 | *, 48 | base_url: Optional[str] = None, 49 | timeout: Optional[httpx.Timeout] = None, 50 | **kwargs, 51 | ) -> None: 52 | super().__init__() 53 | 54 | self._api_token = api_token 55 | self._base_url = base_url 56 | self._timeout = timeout 57 | self._client_kwargs = kwargs 58 | 59 | self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) 60 | 61 | @property 62 | def _client(self) -> httpx.Client: 63 | if not self.__client: 64 | self.__client = _build_httpx_client( 65 | httpx.Client, 66 | self._api_token, 67 | self._base_url, 68 | self._timeout, 69 | **self._client_kwargs, 70 | ) # type: ignore[assignment] 71 | return self.__client # type: ignore[return-value] 72 | 73 | @property 74 | def _async_client(self) -> httpx.AsyncClient: 75 | if not self.__async_client: 76 | self.__async_client = _build_httpx_client( 77 | httpx.AsyncClient, 78 | self._api_token, 79 | self._base_url, 80 | self._timeout, 81 | **self._client_kwargs, 82 | ) # type: ignore[assignment] 83 | return self.__async_client # type: ignore[return-value] 84 | 85 | def _request(self, method: str, path: str, **kwargs) -> httpx.Response: 86 | resp = self._client.request(method, path, **kwargs) 87 | _raise_for_status(resp) 88 | 89 | return resp 90 | 91 | async def _async_request(self, method: str, path: str, **kwargs) -> httpx.Response: 92 | resp = await self._async_client.request(method, path, **kwargs) 93 | _raise_for_status(resp) 94 | 95 | return resp 96 | 97 | @property 98 | def accounts(self) -> Accounts: 99 | """ 100 | Namespace for operations related to accounts. 101 | """ 102 | 103 | return Accounts(client=self) 104 | 105 | @property 106 | def collections(self) -> Collections: 107 | """ 108 | Namespace for operations related to collections of models. 109 | """ 110 | return Collections(client=self) 111 | 112 | @property 113 | def deployments(self) -> Deployments: 114 | """ 115 | Namespace for operations related to deployments. 116 | """ 117 | return Deployments(client=self) 118 | 119 | @property 120 | def hardware(self) -> Hardware: 121 | """ 122 | Namespace for operations related to hardware. 123 | """ 124 | return Hardware(client=self) 125 | 126 | @property 127 | def models(self) -> Models: 128 | """ 129 | Namespace for operations related to models. 130 | """ 131 | return Models(client=self) 132 | 133 | @property 134 | def predictions(self) -> Predictions: 135 | """ 136 | Namespace for operations related to predictions. 137 | """ 138 | return Predictions(client=self) 139 | 140 | @property 141 | def trainings(self) -> Trainings: 142 | """ 143 | Namespace for operations related to trainings. 144 | """ 145 | return Trainings(client=self) 146 | 147 | def run( 148 | self, 149 | ref: str, 150 | input: Optional[Dict[str, Any]] = None, 151 | **params: Unpack["Predictions.CreatePredictionParams"], 152 | ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 153 | """ 154 | Run a model and wait for its output. 155 | """ 156 | 157 | return run(self, ref, input, **params) 158 | 159 | async def async_run( 160 | self, 161 | ref: str, 162 | input: Optional[Dict[str, Any]] = None, 163 | **params: Unpack["Predictions.CreatePredictionParams"], 164 | ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 165 | """ 166 | Run a model and wait for its output asynchronously. 167 | """ 168 | 169 | return await async_run(self, ref, input, **params) 170 | 171 | def stream( 172 | self, 173 | ref: str, 174 | input: Optional[Dict[str, Any]] = None, 175 | **params: Unpack["Predictions.CreatePredictionParams"], 176 | ) -> Iterator["ServerSentEvent"]: 177 | """ 178 | Stream a model's output. 179 | """ 180 | 181 | return stream(self, ref, input, **params) 182 | 183 | async def async_stream( 184 | self, 185 | ref: str, 186 | input: Optional[Dict[str, Any]] = None, 187 | **params: Unpack["Predictions.CreatePredictionParams"], 188 | ) -> AsyncIterator["ServerSentEvent"]: 189 | """ 190 | Stream a model's output asynchronously. 191 | """ 192 | 193 | return async_stream(self, ref, input, **params) 194 | 195 | 196 | # Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155 197 | class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport): 198 | """A custom HTTP transport that automatically retries requests using an exponential backoff strategy 199 | for specific HTTP status codes and request methods. 200 | """ 201 | 202 | RETRYABLE_METHODS = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]) 203 | RETRYABLE_STATUS_CODES = frozenset( 204 | [ 205 | 429, # Too Many Requests 206 | 503, # Service Unavailable 207 | 504, # Gateway Timeout 208 | ] 209 | ) 210 | MAX_BACKOFF_WAIT = 60 211 | 212 | def __init__( # pylint: disable=too-many-arguments 213 | self, 214 | wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], 215 | *, 216 | max_attempts: int = 10, 217 | max_backoff_wait: float = MAX_BACKOFF_WAIT, 218 | backoff_factor: float = 0.1, 219 | jitter_ratio: float = 0.1, 220 | retryable_methods: Optional[Iterable[str]] = None, 221 | retry_status_codes: Optional[Iterable[int]] = None, 222 | ) -> None: 223 | self._wrapped_transport = wrapped_transport 224 | 225 | if jitter_ratio < 0 or jitter_ratio > 0.5: 226 | raise ValueError( 227 | f"jitter ratio should be between 0 and 0.5, actual {jitter_ratio}" 228 | ) 229 | 230 | self.max_attempts = max_attempts 231 | self.backoff_factor = backoff_factor 232 | self.retryable_methods = ( 233 | frozenset(retryable_methods) 234 | if retryable_methods 235 | else self.RETRYABLE_METHODS 236 | ) 237 | self.retry_status_codes = ( 238 | frozenset(retry_status_codes) 239 | if retry_status_codes 240 | else self.RETRYABLE_STATUS_CODES 241 | ) 242 | self.jitter_ratio = jitter_ratio 243 | self.max_backoff_wait = max_backoff_wait 244 | 245 | def _calculate_sleep( 246 | self, attempts_made: int, headers: Union[httpx.Headers, Mapping[str, str]] 247 | ) -> float: 248 | retry_after_header = (headers.get("Retry-After") or "").strip() 249 | if retry_after_header: 250 | if retry_after_header.isdigit(): 251 | return float(retry_after_header) 252 | 253 | try: 254 | parsed_date = datetime.fromisoformat(retry_after_header).astimezone() 255 | diff = (parsed_date - datetime.now().astimezone()).total_seconds() 256 | if diff > 0: 257 | return min(diff, self.max_backoff_wait) 258 | except ValueError: 259 | pass 260 | 261 | backoff = self.backoff_factor * (2 ** (attempts_made - 1)) 262 | jitter = (backoff * self.jitter_ratio) * random.choice([1, -1]) # noqa: S311 263 | total_backoff = backoff + jitter 264 | return min(total_backoff, self.max_backoff_wait) 265 | 266 | def handle_request(self, request: httpx.Request) -> httpx.Response: 267 | response = self._wrapped_transport.handle_request(request) # type: ignore 268 | 269 | if request.method not in self.retryable_methods: 270 | return response 271 | 272 | remaining_attempts = self.max_attempts - 1 273 | attempts_made = 1 274 | 275 | while True: 276 | if ( 277 | remaining_attempts < 1 278 | or response.status_code not in self.retry_status_codes 279 | ): 280 | return response 281 | 282 | sleep_for = self._calculate_sleep(attempts_made, response.headers) 283 | time.sleep(sleep_for) 284 | 285 | response = self._wrapped_transport.handle_request(request) # type: ignore 286 | 287 | attempts_made += 1 288 | remaining_attempts -= 1 289 | 290 | async def handle_async_request(self, request: httpx.Request) -> httpx.Response: 291 | response = await self._wrapped_transport.handle_async_request(request) # type: ignore 292 | 293 | if request.method not in self.retryable_methods: 294 | return response 295 | 296 | remaining_attempts = self.max_attempts - 1 297 | attempts_made = 1 298 | 299 | while True: 300 | if ( 301 | remaining_attempts < 1 302 | or response.status_code not in self.retry_status_codes 303 | ): 304 | return response 305 | 306 | sleep_for = self._calculate_sleep(attempts_made, response.headers) 307 | await asyncio.sleep(sleep_for) 308 | 309 | response = await self._wrapped_transport.handle_async_request(request) # type: ignore 310 | 311 | attempts_made += 1 312 | remaining_attempts -= 1 313 | 314 | def close(self) -> None: 315 | self._wrapped_transport.close() # type: ignore 316 | 317 | async def aclose(self) -> None: 318 | await self._wrapped_transport.aclose() # type: ignore 319 | 320 | 321 | def _build_httpx_client( 322 | client_type: Type[Union[httpx.Client, httpx.AsyncClient]], 323 | api_token: Optional[str] = None, 324 | base_url: Optional[str] = None, 325 | timeout: Optional[httpx.Timeout] = None, 326 | **kwargs, 327 | ) -> Union[httpx.Client, httpx.AsyncClient]: 328 | headers = kwargs.pop("headers", {}) 329 | headers["User-Agent"] = f"replicate-python/{__version__}" 330 | 331 | if ( 332 | api_token := api_token or os.environ.get("REPLICATE_API_TOKEN") 333 | ) and api_token != "": 334 | headers["Authorization"] = f"Bearer {api_token}" 335 | 336 | base_url = ( 337 | base_url or os.environ.get("REPLICATE_BASE_URL") or "https://api.replicate.com" 338 | ) 339 | if base_url == "": 340 | base_url = "https://api.replicate.com" 341 | 342 | timeout = timeout or httpx.Timeout( 343 | 5.0, read=30.0, write=30.0, connect=5.0, pool=10.0 344 | ) 345 | 346 | transport = kwargs.pop("transport", None) or ( 347 | httpx.HTTPTransport() 348 | if client_type is httpx.Client 349 | else httpx.AsyncHTTPTransport() 350 | ) 351 | 352 | return client_type( 353 | base_url=base_url, 354 | headers=headers, 355 | timeout=timeout, 356 | transport=RetryTransport(wrapped_transport=transport), # type: ignore[arg-type] 357 | **kwargs, 358 | ) 359 | 360 | 361 | def _raise_for_status(resp: httpx.Response) -> None: 362 | if 400 <= resp.status_code < 600: 363 | raise ReplicateError.from_response(resp) 364 | -------------------------------------------------------------------------------- /replicate/collection.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterator, List, Optional, Union, overload 2 | 3 | from typing_extensions import deprecated 4 | 5 | from replicate.model import Model 6 | from replicate.pagination import Page 7 | from replicate.resource import Namespace, Resource 8 | 9 | 10 | class Collection(Resource): 11 | """ 12 | A collection of models on Replicate. 13 | """ 14 | 15 | slug: str 16 | """The slug used to identify the collection.""" 17 | 18 | name: str 19 | """The name of the collection.""" 20 | 21 | description: str 22 | """A description of the collection.""" 23 | 24 | models: Optional[List[Model]] = None 25 | """The models in the collection.""" 26 | 27 | @property 28 | @deprecated("Use `slug` instead of `id`") 29 | def id(self) -> str: 30 | """ 31 | DEPRECATED: Use `slug` instead. 32 | """ 33 | return self.slug 34 | 35 | def __iter__(self) -> Iterator[Model]: 36 | if self.models is not None: 37 | return iter(self.models) 38 | return iter([]) 39 | 40 | @overload 41 | def __getitem__(self, index: int) -> Optional[Model]: ... 42 | 43 | @overload 44 | def __getitem__(self, index: slice) -> Optional[List[Model]]: ... 45 | 46 | def __getitem__( 47 | self, index: Union[int, slice] 48 | ) -> Union[Optional[Model], Optional[List[Model]]]: 49 | if self.models is not None: 50 | return self.models[index] 51 | return None 52 | 53 | def __len__(self) -> int: 54 | if self.models is not None: 55 | return len(self.models) 56 | 57 | return 0 58 | 59 | 60 | class Collections(Namespace): 61 | """ 62 | A namespace for operations related to collections of models. 63 | """ 64 | 65 | def list( 66 | self, 67 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 68 | ) -> Page[Collection]: 69 | """ 70 | List collections of models. 71 | 72 | Parameters: 73 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 74 | Returns: 75 | Page[Collection]: A page of of model collections. 76 | Raises: 77 | ValueError: If `cursor` is `None`. 78 | """ 79 | 80 | if cursor is None: 81 | raise ValueError("cursor cannot be None") 82 | 83 | resp = self._client._request( 84 | "GET", "/v1/collections" if cursor is ... else cursor 85 | ) 86 | 87 | obj = resp.json() 88 | obj["results"] = [_json_to_collection(result) for result in obj["results"]] 89 | 90 | return Page[Collection](**obj) 91 | 92 | async def async_list( 93 | self, 94 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 95 | ) -> Page[Collection]: 96 | """ 97 | List collections of models. 98 | 99 | Parameters: 100 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 101 | Returns: 102 | Page[Collection]: A page of of model collections. 103 | Raises: 104 | ValueError: If `cursor` is `None`. 105 | """ 106 | 107 | if cursor is None: 108 | raise ValueError("cursor cannot be None") 109 | 110 | resp = await self._client._async_request( 111 | "GET", "/v1/collections" if cursor is ... else cursor 112 | ) 113 | 114 | obj = resp.json() 115 | obj["results"] = [_json_to_collection(result) for result in obj["results"]] 116 | 117 | return Page[Collection](**obj) 118 | 119 | def get(self, slug: str) -> Collection: 120 | """Get a model by name. 121 | 122 | Args: 123 | name: The name of the model, in the format `owner/model-name`. 124 | Returns: 125 | The model. 126 | """ 127 | 128 | resp = self._client._request("GET", f"/v1/collections/{slug}") 129 | 130 | return _json_to_collection(resp.json()) 131 | 132 | async def async_get(self, slug: str) -> Collection: 133 | """Get a model by name. 134 | 135 | Args: 136 | name: The name of the model, in the format `owner/model-name`. 137 | Returns: 138 | The model. 139 | """ 140 | 141 | resp = await self._client._async_request("GET", f"/v1/collections/{slug}") 142 | 143 | return _json_to_collection(resp.json()) 144 | 145 | 146 | def _json_to_collection(json: Dict[str, Any]) -> Collection: 147 | return Collection(**json) 148 | -------------------------------------------------------------------------------- /replicate/deployment.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union 2 | 3 | from typing_extensions import Unpack, deprecated 4 | 5 | from replicate.account import Account 6 | from replicate.pagination import Page 7 | from replicate.prediction import ( 8 | Prediction, 9 | _create_prediction_body, 10 | _json_to_prediction, 11 | ) 12 | from replicate.resource import Namespace, Resource 13 | 14 | try: 15 | from pydantic import v1 as pydantic # type: ignore 16 | except ImportError: 17 | import pydantic # type: ignore 18 | 19 | 20 | if TYPE_CHECKING: 21 | from replicate.client import Client 22 | from replicate.prediction import Predictions 23 | 24 | 25 | class Deployment(Resource): 26 | """ 27 | A deployment of a model hosted on Replicate. 28 | """ 29 | 30 | _client: "Client" = pydantic.PrivateAttr() 31 | 32 | owner: str 33 | """ 34 | The name of the user or organization that owns the deployment. 35 | """ 36 | 37 | name: str 38 | """ 39 | The name of the deployment. 40 | """ 41 | 42 | class Release(Resource): 43 | """ 44 | A release of a deployment. 45 | """ 46 | 47 | number: int 48 | """ 49 | The release number. 50 | """ 51 | 52 | model: str 53 | """ 54 | The model identifier string in the format of `{model_owner}/{model_name}`. 55 | """ 56 | 57 | version: str 58 | """ 59 | The ID of the model version used in the release. 60 | """ 61 | 62 | created_at: str 63 | """ 64 | The time the release was created. 65 | """ 66 | 67 | created_by: Optional[Account] 68 | """ 69 | The account that created the release. 70 | """ 71 | 72 | class Configuration(Resource): 73 | """ 74 | A configuration for a deployment. 75 | """ 76 | 77 | hardware: str 78 | """ 79 | The SKU for the hardware used to run the model. 80 | """ 81 | 82 | min_instances: int 83 | """ 84 | The minimum number of instances for scaling. 85 | """ 86 | 87 | max_instances: int 88 | """ 89 | The maximum number of instances for scaling. 90 | """ 91 | 92 | configuration: Configuration 93 | """ 94 | The deployment configuration. 95 | """ 96 | 97 | current_release: Optional[Release] 98 | """ 99 | The current release of the deployment. 100 | """ 101 | 102 | @property 103 | @deprecated("Use `deployment.owner` instead.") 104 | def username(self) -> str: 105 | """ 106 | The name of the user or organization that owns the deployment. 107 | This attribute is deprecated and will be removed in future versions. 108 | """ 109 | return self.owner 110 | 111 | @property 112 | def id(self) -> str: 113 | """ 114 | Return the qualified deployment name, in the format `owner/name`. 115 | """ 116 | return f"{self.owner}/{self.name}" 117 | 118 | @property 119 | def predictions(self) -> "DeploymentPredictions": 120 | """ 121 | Get the predictions for this deployment. 122 | """ 123 | 124 | return DeploymentPredictions(client=self._client, deployment=self) 125 | 126 | 127 | class Deployments(Namespace): 128 | """ 129 | Namespace for operations related to deployments. 130 | """ 131 | 132 | _client: "Client" 133 | 134 | def list( 135 | self, 136 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 137 | ) -> Page[Deployment]: 138 | """ 139 | List all deployments. 140 | 141 | Returns: 142 | A page of Deployments. 143 | """ 144 | 145 | if cursor is None: 146 | raise ValueError("cursor cannot be None") 147 | 148 | resp = self._client._request( 149 | "GET", "/v1/deployments" if cursor is ... else cursor 150 | ) 151 | 152 | obj = resp.json() 153 | obj["results"] = [ 154 | _json_to_deployment(self._client, result) for result in obj["results"] 155 | ] 156 | 157 | return Page[Deployment](**obj) 158 | 159 | async def async_list( 160 | self, 161 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 162 | ) -> Page[Deployment]: 163 | """ 164 | List all deployments. 165 | 166 | Returns: 167 | A page of Deployments. 168 | """ 169 | if cursor is None: 170 | raise ValueError("cursor cannot be None") 171 | 172 | resp = await self._client._async_request( 173 | "GET", "/v1/deployments" if cursor is ... else cursor 174 | ) 175 | 176 | obj = resp.json() 177 | obj["results"] = [ 178 | _json_to_deployment(self._client, result) for result in obj["results"] 179 | ] 180 | 181 | return Page[Deployment](**obj) 182 | 183 | def get(self, name: str) -> Deployment: 184 | """ 185 | Get a deployment by name. 186 | 187 | Args: 188 | name: The name of the deployment, in the format `owner/model-name`. 189 | Returns: 190 | The model. 191 | """ 192 | 193 | owner, name = name.split("/", 1) 194 | 195 | resp = self._client._request( 196 | "GET", 197 | f"/v1/deployments/{owner}/{name}", 198 | ) 199 | 200 | return _json_to_deployment(self._client, resp.json()) 201 | 202 | async def async_get(self, name: str) -> Deployment: 203 | """ 204 | Get a deployment by name. 205 | 206 | Args: 207 | name: The name of the deployment, in the format `owner/model-name`. 208 | Returns: 209 | The model. 210 | """ 211 | 212 | owner, name = name.split("/", 1) 213 | 214 | resp = await self._client._async_request( 215 | "GET", 216 | f"/v1/deployments/{owner}/{name}", 217 | ) 218 | 219 | return _json_to_deployment(self._client, resp.json()) 220 | 221 | class CreateDeploymentParams(TypedDict): 222 | """ 223 | Parameters for creating a new deployment. 224 | """ 225 | 226 | name: str 227 | """The name of the deployment.""" 228 | 229 | model: str 230 | """The model identifier string in the format of `{model_owner}/{model_name}`.""" 231 | 232 | version: str 233 | """The version of the model to deploy.""" 234 | 235 | hardware: str 236 | """The SKU for the hardware used to run the model.""" 237 | 238 | min_instances: int 239 | """The minimum number of instances for scaling.""" 240 | 241 | max_instances: int 242 | """The maximum number of instances for scaling.""" 243 | 244 | def create(self, **params: Unpack[CreateDeploymentParams]) -> Deployment: 245 | """ 246 | Create a new deployment. 247 | 248 | Args: 249 | params: Configuration for the new deployment. 250 | Returns: 251 | The newly created Deployment. 252 | """ 253 | 254 | if name := params.get("name", None): 255 | if "/" in name: 256 | _, name = name.split("/", 1) 257 | params["name"] = name 258 | 259 | resp = self._client._request( 260 | "POST", 261 | "/v1/deployments", 262 | json=params, 263 | ) 264 | 265 | return _json_to_deployment(self._client, resp.json()) 266 | 267 | async def async_create( 268 | self, **params: Unpack[CreateDeploymentParams] 269 | ) -> Deployment: 270 | """ 271 | Create a new deployment. 272 | 273 | Args: 274 | params: Configuration for the new deployment. 275 | Returns: 276 | The newly created Deployment. 277 | """ 278 | 279 | if name := params.get("name", None): 280 | if "/" in name: 281 | _, name = name.split("/", 1) 282 | params["name"] = name 283 | 284 | resp = await self._client._async_request( 285 | "POST", 286 | "/v1/deployments", 287 | json=params, 288 | ) 289 | 290 | return _json_to_deployment(self._client, resp.json()) 291 | 292 | class UpdateDeploymentParams(TypedDict, total=False): 293 | """ 294 | Parameters for updating an existing deployment. 295 | """ 296 | 297 | version: str 298 | """The version of the model to deploy.""" 299 | 300 | hardware: str 301 | """The SKU for the hardware used to run the model.""" 302 | 303 | min_instances: int 304 | """The minimum number of instances for scaling.""" 305 | 306 | max_instances: int 307 | """The maximum number of instances for scaling.""" 308 | 309 | def update( 310 | self, 311 | deployment_owner: str, 312 | deployment_name: str, 313 | **params: Unpack[UpdateDeploymentParams], 314 | ) -> Deployment: 315 | """ 316 | Update an existing deployment. 317 | 318 | Args: 319 | deployment_owner: The owner of the deployment. 320 | deployment_name: The name of the deployment. 321 | params: Configuration updates for the deployment. 322 | Returns: 323 | The updated Deployment. 324 | """ 325 | 326 | resp = self._client._request( 327 | "PATCH", 328 | f"/v1/deployments/{deployment_owner}/{deployment_name}", 329 | json=params, 330 | ) 331 | 332 | return _json_to_deployment(self._client, resp.json()) 333 | 334 | async def async_update( 335 | self, 336 | deployment_owner: str, 337 | deployment_name: str, 338 | **params: Unpack[UpdateDeploymentParams], 339 | ) -> Deployment: 340 | """ 341 | Update an existing deployment. 342 | 343 | Args: 344 | deployment_owner: The owner of the deployment. 345 | deployment_name: The name of the deployment. 346 | params: Configuration updates for the deployment. 347 | Returns: 348 | The updated Deployment. 349 | """ 350 | 351 | resp = await self._client._async_request( 352 | "PATCH", 353 | f"/v1/deployments/{deployment_owner}/{deployment_name}", 354 | json=params, 355 | ) 356 | 357 | return _json_to_deployment(self._client, resp.json()) 358 | 359 | @property 360 | def predictions(self) -> "DeploymentsPredictions": 361 | """ 362 | Get predictions for deployments. 363 | """ 364 | 365 | return DeploymentsPredictions(client=self._client) 366 | 367 | 368 | def _json_to_deployment(client: "Client", json: Dict[str, Any]) -> Deployment: 369 | deployment = Deployment(**json) 370 | deployment._client = client 371 | return deployment 372 | 373 | 374 | class DeploymentPredictions(Namespace): 375 | """ 376 | Namespace for operations related to predictions in a deployment. 377 | """ 378 | 379 | _deployment: Deployment 380 | 381 | def __init__(self, client: "Client", deployment: Deployment) -> None: 382 | super().__init__(client=client) 383 | self._deployment = deployment 384 | 385 | def create( 386 | self, 387 | input: Dict[str, Any], 388 | **params: Unpack["Predictions.CreatePredictionParams"], 389 | ) -> Prediction: 390 | """ 391 | Create a new prediction with the deployment. 392 | """ 393 | 394 | body = _create_prediction_body(version=None, input=input, **params) 395 | 396 | resp = self._client._request( 397 | "POST", 398 | f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", 399 | json=body, 400 | ) 401 | 402 | return _json_to_prediction(self._client, resp.json()) 403 | 404 | async def async_create( 405 | self, 406 | input: Dict[str, Any], 407 | **params: Unpack["Predictions.CreatePredictionParams"], 408 | ) -> Prediction: 409 | """ 410 | Create a new prediction with the deployment. 411 | """ 412 | 413 | body = _create_prediction_body(version=None, input=input, **params) 414 | 415 | resp = await self._client._async_request( 416 | "POST", 417 | f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", 418 | json=body, 419 | ) 420 | 421 | return _json_to_prediction(self._client, resp.json()) 422 | 423 | 424 | class DeploymentsPredictions(Namespace): 425 | """ 426 | Namespace for operations related to predictions in deployments. 427 | """ 428 | 429 | def create( 430 | self, 431 | deployment: Union[str, Tuple[str, str], Deployment], 432 | input: Dict[str, Any], 433 | **params: Unpack["Predictions.CreatePredictionParams"], 434 | ) -> Prediction: 435 | """ 436 | Create a new prediction with the deployment. 437 | """ 438 | 439 | url = _create_prediction_url_from_deployment(deployment) 440 | body = _create_prediction_body(version=None, input=input, **params) 441 | 442 | resp = self._client._request( 443 | "POST", 444 | url, 445 | json=body, 446 | ) 447 | 448 | return _json_to_prediction(self._client, resp.json()) 449 | 450 | async def async_create( 451 | self, 452 | deployment: Union[str, Tuple[str, str], Deployment], 453 | input: Dict[str, Any], 454 | **params: Unpack["Predictions.CreatePredictionParams"], 455 | ) -> Prediction: 456 | """ 457 | Create a new prediction with the deployment. 458 | """ 459 | 460 | url = _create_prediction_url_from_deployment(deployment) 461 | body = _create_prediction_body(version=None, input=input, **params) 462 | 463 | resp = await self._client._async_request( 464 | "POST", 465 | url, 466 | json=body, 467 | ) 468 | 469 | return _json_to_prediction(self._client, resp.json()) 470 | 471 | 472 | def _create_prediction_url_from_deployment( 473 | deployment: Union[str, Tuple[str, str], Deployment], 474 | ) -> str: 475 | owner, name = None, None 476 | if isinstance(deployment, Deployment): 477 | owner, name = deployment.owner, deployment.name 478 | elif isinstance(deployment, tuple): 479 | owner, name = deployment[0], deployment[1] 480 | elif isinstance(deployment, str): 481 | owner, name = deployment.split("/", 1) 482 | 483 | if owner is None or name is None: 484 | raise ValueError( 485 | "deployment must be a Deployment, a tuple of (owner, name), or a string in the format 'owner/name'" 486 | ) 487 | 488 | return f"/v1/deployments/{owner}/{name}/predictions" 489 | -------------------------------------------------------------------------------- /replicate/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import httpx 4 | 5 | 6 | class ReplicateException(Exception): 7 | """A base class for all Replicate exceptions.""" 8 | 9 | 10 | class ModelError(ReplicateException): 11 | """An error from user's code in a model.""" 12 | 13 | 14 | class ReplicateError(ReplicateException): 15 | """ 16 | An error from Replicate's API. 17 | 18 | This class represents a problem details response as defined in RFC 7807. 19 | """ 20 | 21 | type: Optional[str] 22 | """A URI that identifies the error type.""" 23 | 24 | title: Optional[str] 25 | """A short, human-readable summary of the error.""" 26 | 27 | status: Optional[int] 28 | """The HTTP status code.""" 29 | 30 | detail: Optional[str] 31 | """A human-readable explanation specific to this occurrence of the error.""" 32 | 33 | instance: Optional[str] 34 | """A URI that identifies the specific occurrence of the error.""" 35 | 36 | def __init__( # pylint: disable=too-many-arguments 37 | self, 38 | type: Optional[str] = None, 39 | title: Optional[str] = None, 40 | status: Optional[int] = None, 41 | detail: Optional[str] = None, 42 | instance: Optional[str] = None, 43 | ) -> None: 44 | self.type = type 45 | self.title = title 46 | self.status = status 47 | self.detail = detail 48 | self.instance = instance 49 | 50 | @classmethod 51 | def from_response(cls, response: httpx.Response) -> "ReplicateError": 52 | """Create a ReplicateError from an HTTP response.""" 53 | 54 | try: 55 | data = response.json() 56 | except ValueError: 57 | data = {} 58 | 59 | return cls( 60 | type=data.get("type"), 61 | title=data.get("title"), 62 | detail=data.get("detail"), 63 | status=response.status_code, 64 | instance=data.get("instance"), 65 | ) 66 | 67 | def to_dict(self) -> dict: 68 | """Get a dictionary representation of the error.""" 69 | 70 | return { 71 | key: value 72 | for key, value in { 73 | "type": self.type, 74 | "title": self.title, 75 | "status": self.status, 76 | "detail": self.detail, 77 | "instance": self.instance, 78 | }.items() 79 | if value is not None 80 | } 81 | 82 | def __str__(self) -> str: 83 | return "ReplicateError Details:\n" + "\n".join( 84 | [f"{key}: {value}" for key, value in self.to_dict().items()] 85 | ) 86 | 87 | def __repr__(self) -> str: 88 | class_name = self.__class__.__name__ 89 | params = ", ".join( 90 | [ 91 | f"type={repr(self.type)}", 92 | f"title={repr(self.title)}", 93 | f"status={repr(self.status)}", 94 | f"detail={repr(self.detail)}", 95 | f"instance={repr(self.instance)}", 96 | ] 97 | ) 98 | return f"{class_name}({params})" 99 | -------------------------------------------------------------------------------- /replicate/files.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import mimetypes 4 | import os 5 | from typing import Optional 6 | 7 | import httpx 8 | 9 | 10 | def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str: 11 | """ 12 | Upload a file to the server. 13 | 14 | Args: 15 | file: A file handle to upload. 16 | output_file_prefix: A string to prepend to the output file name. 17 | Returns: 18 | str: A URL to the uploaded file. 19 | """ 20 | # Lifted straight from cog.files 21 | 22 | file.seek(0) 23 | 24 | if output_file_prefix is not None: 25 | name = getattr(file, "name", "output") 26 | url = output_file_prefix + os.path.basename(name) 27 | resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore 28 | resp.raise_for_status() 29 | 30 | return url 31 | 32 | body = file.read() 33 | # Ensure the file handle is in bytes 34 | body = body.encode("utf-8") if isinstance(body, str) else body 35 | encoded_body = base64.b64encode(body).decode("utf-8") 36 | # Use getattr to avoid mypy complaints about io.IOBase having no attribute name 37 | mime_type = ( 38 | mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream" 39 | ) 40 | return f"data:{mime_type};base64,{encoded_body}" 41 | -------------------------------------------------------------------------------- /replicate/hardware.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, List 2 | 3 | from typing_extensions import deprecated 4 | 5 | from replicate.resource import Namespace, Resource 6 | 7 | if TYPE_CHECKING: 8 | pass 9 | 10 | 11 | class Hardware(Resource): 12 | """ 13 | Hardware for running a model on Replicate. 14 | """ 15 | 16 | sku: str 17 | """ 18 | The SKU of the hardware. 19 | """ 20 | 21 | name: str 22 | """ 23 | The name of the hardware. 24 | """ 25 | 26 | @property 27 | @deprecated("Use `sku` instead of `id`") 28 | def id(self) -> str: 29 | """ 30 | DEPRECATED: Use `sku` instead. 31 | """ 32 | return self.sku 33 | 34 | 35 | class HardwareNamespace(Namespace): 36 | """ 37 | Namespace for operations related to hardware. 38 | """ 39 | 40 | def list(self) -> List[Hardware]: 41 | """ 42 | List all hardware available for you to run models on Replicate. 43 | 44 | Returns: 45 | List[Hardware]: A list of hardware. 46 | """ 47 | 48 | resp = self._client._request("GET", "/v1/hardware") 49 | obj = resp.json() 50 | 51 | return [_json_to_hardware(entry) for entry in obj] 52 | 53 | async def async_list(self) -> List[Hardware]: 54 | """ 55 | List all hardware available for you to run models on Replicate. 56 | 57 | Returns: 58 | List[Hardware]: A list of hardware. 59 | """ 60 | 61 | resp = await self._client._async_request("GET", "/v1/hardware") 62 | obj = resp.json() 63 | 64 | return [_json_to_hardware(entry) for entry in obj] 65 | 66 | 67 | def _json_to_hardware(json: Dict[str, Any]) -> Hardware: 68 | return Hardware(**json) 69 | -------------------------------------------------------------------------------- /replicate/identifier.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union 3 | 4 | if TYPE_CHECKING: 5 | from replicate.model import Model 6 | from replicate.version import Version 7 | 8 | 9 | class ModelVersionIdentifier(NamedTuple): 10 | """ 11 | A reference to a model version in the format owner/name or owner/name:version. 12 | """ 13 | 14 | owner: str 15 | name: str 16 | version: Optional[str] = None 17 | 18 | @classmethod 19 | def parse(cls, ref: str) -> "ModelVersionIdentifier": 20 | """ 21 | Split a reference in the format owner/name:version into its components. 22 | """ 23 | 24 | match = re.match(r"^(?P[^/]+)/(?P[^/:]+)(:(?P.+))?$", ref) 25 | if not match: 26 | raise ValueError( 27 | f"Invalid reference to model version: {ref}. Expected format: owner/name:version" 28 | ) 29 | 30 | return cls(match.group("owner"), match.group("name"), match.group("version")) 31 | 32 | 33 | def _resolve( 34 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 35 | ) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]: 36 | from replicate.model import Model # pylint: disable=import-outside-toplevel 37 | from replicate.version import Version # pylint: disable=import-outside-toplevel 38 | 39 | version = None 40 | owner, name, version_id = None, None, None 41 | if isinstance(ref, Model): 42 | owner, name = ref.owner, ref.name 43 | elif isinstance(ref, Version): 44 | version = ref 45 | version_id = ref.id 46 | elif isinstance(ref, ModelVersionIdentifier): 47 | owner, name, version_id = ref 48 | elif isinstance(ref, str): 49 | owner, name, version_id = ModelVersionIdentifier.parse(ref) 50 | return version, owner, name, version_id 51 | -------------------------------------------------------------------------------- /replicate/json.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from types import GeneratorType 4 | from typing import Any, Callable 5 | 6 | try: 7 | import numpy as np # type: ignore 8 | 9 | HAS_NUMPY = True 10 | except ImportError: 11 | HAS_NUMPY = False 12 | 13 | 14 | # pylint: disable=too-many-return-statements 15 | def encode_json( 16 | obj: Any, # noqa: ANN401 17 | upload_file: Callable[[io.IOBase], str], 18 | ) -> Any: # noqa: ANN401 19 | """ 20 | Return a JSON-compatible version of the object. 21 | """ 22 | # Effectively the same thing as cog.json.encode_json. 23 | 24 | if isinstance(obj, dict): 25 | return {key: encode_json(value, upload_file) for key, value in obj.items()} 26 | if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): 27 | return [encode_json(value, upload_file) for value in obj] 28 | if isinstance(obj, Path): 29 | with obj.open("rb") as file: 30 | return upload_file(file) 31 | if isinstance(obj, io.IOBase): 32 | return upload_file(obj) 33 | if HAS_NUMPY: 34 | if isinstance(obj, np.integer): # type: ignore 35 | return int(obj) 36 | if isinstance(obj, np.floating): # type: ignore 37 | return float(obj) 38 | if isinstance(obj, np.ndarray): # type: ignore 39 | return obj.tolist() 40 | return obj 41 | -------------------------------------------------------------------------------- /replicate/model.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union 2 | 3 | from typing_extensions import NotRequired, TypedDict, Unpack, deprecated 4 | 5 | from replicate.exceptions import ReplicateException 6 | from replicate.identifier import ModelVersionIdentifier 7 | from replicate.pagination import Page 8 | from replicate.prediction import ( 9 | Prediction, 10 | _create_prediction_body, 11 | _json_to_prediction, 12 | ) 13 | from replicate.resource import Namespace, Resource 14 | from replicate.version import Version, Versions 15 | 16 | try: 17 | from pydantic import v1 as pydantic # type: ignore 18 | except ImportError: 19 | import pydantic # type: ignore 20 | 21 | 22 | if TYPE_CHECKING: 23 | from replicate.client import Client 24 | from replicate.prediction import Predictions 25 | 26 | 27 | class Model(Resource): 28 | """ 29 | A machine learning model hosted on Replicate. 30 | """ 31 | 32 | _client: "Client" = pydantic.PrivateAttr() 33 | 34 | url: str 35 | """ 36 | The URL of the model. 37 | """ 38 | 39 | owner: str 40 | """ 41 | The owner of the model. 42 | """ 43 | 44 | name: str 45 | """ 46 | The name of the model. 47 | """ 48 | 49 | description: Optional[str] 50 | """ 51 | The description of the model. 52 | """ 53 | 54 | visibility: Literal["public", "private"] 55 | """ 56 | The visibility of the model. Can be 'public' or 'private'. 57 | """ 58 | 59 | github_url: Optional[str] 60 | """ 61 | The GitHub URL of the model. 62 | """ 63 | 64 | paper_url: Optional[str] 65 | """ 66 | The URL of the paper related to the model. 67 | """ 68 | 69 | license_url: Optional[str] 70 | """ 71 | The URL of the license for the model. 72 | """ 73 | 74 | run_count: int 75 | """ 76 | The number of runs of the model. 77 | """ 78 | 79 | cover_image_url: Optional[str] 80 | """ 81 | The URL of the cover image for the model. 82 | """ 83 | 84 | default_example: Optional[Prediction] 85 | """ 86 | The default example of the model. 87 | """ 88 | 89 | latest_version: Optional[Version] 90 | """ 91 | The latest version of the model. 92 | """ 93 | 94 | @property 95 | def id(self) -> str: 96 | """ 97 | Return the qualified model name, in the format `owner/name`. 98 | """ 99 | return f"{self.owner}/{self.name}" 100 | 101 | @property 102 | @deprecated("Use `model.owner` instead.") 103 | def username(self) -> str: 104 | """ 105 | The name of the user or organization that owns the model. 106 | This attribute is deprecated and will be removed in future versions. 107 | """ 108 | return self.owner 109 | 110 | @username.setter 111 | @deprecated("Use `model.owner` instead.") 112 | def username(self, value: str) -> None: 113 | self.owner = value 114 | 115 | def predict(self, *args, **kwargs) -> None: 116 | """ 117 | DEPRECATED: Use `replicate.run()` instead. 118 | """ 119 | 120 | raise ReplicateException( 121 | "The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `replicate.run()` instead. See https://github.com/replicate/replicate-python#readme" 122 | ) 123 | 124 | @property 125 | def versions(self) -> Versions: 126 | """ 127 | Get the versions of this model. 128 | """ 129 | 130 | return Versions(client=self._client, model=self) 131 | 132 | def reload(self) -> None: 133 | """ 134 | Load this object from the server. 135 | """ 136 | 137 | obj = self._client.models.get(f"{self.owner}/{self.name}") 138 | for name, value in obj.dict().items(): 139 | setattr(self, name, value) 140 | 141 | 142 | class Models(Namespace): 143 | """ 144 | Namespace for operations related to models. 145 | """ 146 | 147 | model = Model 148 | 149 | @property 150 | def predictions(self) -> "ModelsPredictions": 151 | """ 152 | Get a namespace for operations related to predictions on a model. 153 | """ 154 | 155 | return ModelsPredictions(client=self._client) 156 | 157 | def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821 158 | """ 159 | List all public models. 160 | 161 | Parameters: 162 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 163 | Returns: 164 | Page[Model]: A page of of models. 165 | Raises: 166 | ValueError: If `cursor` is `None`. 167 | """ 168 | 169 | if cursor is None: 170 | raise ValueError("cursor cannot be None") 171 | 172 | resp = self._client._request("GET", "/v1/models" if cursor is ... else cursor) 173 | 174 | obj = resp.json() 175 | obj["results"] = [ 176 | _json_to_model(self._client, result) for result in obj["results"] 177 | ] 178 | 179 | return Page[Model](**obj) 180 | 181 | async def async_list( 182 | self, 183 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 184 | ) -> Page[Model]: 185 | """ 186 | List all public models. 187 | 188 | Parameters: 189 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 190 | Returns: 191 | Page[Model]: A page of of models. 192 | Raises: 193 | ValueError: If `cursor` is `None`. 194 | """ 195 | 196 | if cursor is None: 197 | raise ValueError("cursor cannot be None") 198 | 199 | resp = await self._client._async_request( 200 | "GET", "/v1/models" if cursor is ... else cursor 201 | ) 202 | 203 | obj = resp.json() 204 | obj["results"] = [ 205 | _json_to_model(self._client, result) for result in obj["results"] 206 | ] 207 | 208 | return Page[Model](**obj) 209 | 210 | def get(self, key: str) -> Model: 211 | """ 212 | Get a model by name. 213 | 214 | Args: 215 | key: The qualified name of the model, in the format `owner/model-name`. 216 | Returns: 217 | The model. 218 | """ 219 | 220 | resp = self._client._request("GET", f"/v1/models/{key}") 221 | 222 | return _json_to_model(self._client, resp.json()) 223 | 224 | async def async_get(self, key: str) -> Model: 225 | """ 226 | Get a model by name. 227 | 228 | Args: 229 | key: The qualified name of the model, in the format `owner/model-name`. 230 | Returns: 231 | The model. 232 | """ 233 | 234 | resp = await self._client._async_request("GET", f"/v1/models/{key}") 235 | 236 | return _json_to_model(self._client, resp.json()) 237 | 238 | class CreateModelParams(TypedDict): 239 | """Parameters for creating a model.""" 240 | 241 | hardware: str 242 | """The SKU for the hardware used to run the model. 243 | 244 | Possible values can be found by calling `replicate.hardware.list()`.""" 245 | 246 | visibility: Literal["public", "private"] 247 | """Whether the model should be public or private.""" 248 | 249 | description: NotRequired[str] 250 | """The description of the model.""" 251 | 252 | github_url: NotRequired[str] 253 | """A URL for the model's source code on GitHub.""" 254 | 255 | paper_url: NotRequired[str] 256 | """A URL for the model's paper.""" 257 | 258 | license_url: NotRequired[str] 259 | """A URL for the model's license.""" 260 | 261 | cover_image_url: NotRequired[str] 262 | """A URL for the model's cover image.""" 263 | 264 | def create( 265 | self, 266 | owner: str, 267 | name: str, 268 | **params: Unpack["Models.CreateModelParams"], 269 | ) -> Model: 270 | """ 271 | Create a model. 272 | """ 273 | 274 | body = _create_model_body(owner, name, **params) 275 | resp = self._client._request("POST", "/v1/models", json=body) 276 | 277 | return _json_to_model(self._client, resp.json()) 278 | 279 | async def async_create( 280 | self, owner: str, name: str, **params: Unpack["Models.CreateModelParams"] 281 | ) -> Model: 282 | """ 283 | Create a model. 284 | """ 285 | 286 | body = body = _create_model_body(owner, name, **params) 287 | resp = await self._client._async_request("POST", "/v1/models", json=body) 288 | 289 | return _json_to_model(self._client, resp.json()) 290 | 291 | 292 | class ModelsPredictions(Namespace): 293 | """ 294 | Namespace for operations related to predictions in a deployment. 295 | """ 296 | 297 | def create( 298 | self, 299 | model: Union[str, Tuple[str, str], "Model"], 300 | input: Dict[str, Any], 301 | **params: Unpack["Predictions.CreatePredictionParams"], 302 | ) -> Prediction: 303 | """ 304 | Create a new prediction with the deployment. 305 | """ 306 | 307 | url = _create_prediction_url_from_model(model) 308 | body = _create_prediction_body(version=None, input=input, **params) 309 | 310 | resp = self._client._request( 311 | "POST", 312 | url, 313 | json=body, 314 | ) 315 | 316 | return _json_to_prediction(self._client, resp.json()) 317 | 318 | async def async_create( 319 | self, 320 | model: Union[str, Tuple[str, str], "Model"], 321 | input: Dict[str, Any], 322 | **params: Unpack["Predictions.CreatePredictionParams"], 323 | ) -> Prediction: 324 | """ 325 | Create a new prediction with the deployment. 326 | """ 327 | 328 | url = _create_prediction_url_from_model(model) 329 | body = _create_prediction_body(version=None, input=input, **params) 330 | 331 | resp = await self._client._async_request( 332 | "POST", 333 | url, 334 | json=body, 335 | ) 336 | 337 | return _json_to_prediction(self._client, resp.json()) 338 | 339 | 340 | def _create_model_body( # pylint: disable=too-many-arguments 341 | owner: str, 342 | name: str, 343 | *, 344 | visibility: str, 345 | hardware: str, 346 | description: Optional[str] = None, 347 | github_url: Optional[str] = None, 348 | paper_url: Optional[str] = None, 349 | license_url: Optional[str] = None, 350 | cover_image_url: Optional[str] = None, 351 | ) -> Dict[str, Any]: 352 | body = { 353 | "owner": owner, 354 | "name": name, 355 | "visibility": visibility, 356 | "hardware": hardware, 357 | } 358 | 359 | if description is not None: 360 | body["description"] = description 361 | 362 | if github_url is not None: 363 | body["github_url"] = github_url 364 | 365 | if paper_url is not None: 366 | body["paper_url"] = paper_url 367 | 368 | if license_url is not None: 369 | body["license_url"] = license_url 370 | 371 | if cover_image_url is not None: 372 | body["cover_image_url"] = cover_image_url 373 | 374 | return body 375 | 376 | 377 | def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: 378 | model = Model(**json) 379 | model._client = client 380 | if model.default_example is not None: 381 | model.default_example._client = client 382 | return model 383 | 384 | 385 | def _create_prediction_url_from_model( 386 | model: Union[str, Tuple[str, str], "Model"], 387 | ) -> str: 388 | owner, name = None, None 389 | if isinstance(model, Model): 390 | owner, name = model.owner, model.name 391 | elif isinstance(model, tuple): 392 | owner, name = model[0], model[1] 393 | elif isinstance(model, str): 394 | owner, name, version_id = ModelVersionIdentifier.parse(model) 395 | if version_id is not None: 396 | raise ValueError( 397 | f"Invalid reference to model version: {model}. Expected model or reference in the format owner/name" 398 | ) 399 | 400 | if owner is None or name is None: 401 | raise ValueError( 402 | "model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'" 403 | ) 404 | 405 | return f"/v1/models/{owner}/{name}/predictions" 406 | -------------------------------------------------------------------------------- /replicate/pagination.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | AsyncGenerator, 4 | Awaitable, 5 | Callable, 6 | Generator, 7 | Generic, 8 | List, 9 | Optional, 10 | TypeVar, 11 | Union, 12 | ) 13 | 14 | try: 15 | from pydantic import v1 as pydantic # type: ignore 16 | except ImportError: 17 | import pydantic # type: ignore 18 | 19 | from replicate.resource import Resource 20 | 21 | T = TypeVar("T", bound=Resource) 22 | 23 | if TYPE_CHECKING: 24 | pass 25 | 26 | 27 | class Page(pydantic.BaseModel, Generic[T]): # type: ignore 28 | """ 29 | A page of results from the API. 30 | """ 31 | 32 | previous: Optional[str] = None 33 | """A pointer to the previous page of results""" 34 | 35 | next: Optional[str] = None 36 | """A pointer to the next page of results""" 37 | 38 | results: List[T] 39 | """The results on this page""" 40 | 41 | def __iter__(self): # noqa: ANN204 42 | return iter(self.results) 43 | 44 | def __getitem__(self, index: int) -> T: 45 | return self.results[index] 46 | 47 | def __len__(self) -> int: 48 | return len(self.results) 49 | 50 | 51 | def paginate( 52 | list_method: Callable[[Union[str, "ellipsis", None]], Page[T]], # noqa: F821 53 | ) -> Generator[Page[T], None, None]: 54 | """ 55 | Iterate over all items using the provided list method. 56 | 57 | Args: 58 | list_method: A method that takes a cursor argument and returns a Page of items. 59 | """ 60 | cursor: Union[str, "ellipsis", None] = ... # noqa: F821 61 | while cursor is not None: 62 | page = list_method(cursor) 63 | yield page 64 | cursor = page.next 65 | 66 | 67 | async def async_paginate( 68 | list_method: Callable[[Union[str, "ellipsis", None]], Awaitable[Page[T]]], # noqa: F821 69 | ) -> AsyncGenerator[Page[T], None]: 70 | """ 71 | Asynchronously iterate over all items using the provided list method. 72 | 73 | Args: 74 | list_method: An async method that takes a cursor argument and returns a Page of items. 75 | """ 76 | cursor: Union[str, "ellipsis", None] = ... # noqa: F821 77 | while cursor is not None: 78 | page = await list_method(cursor) 79 | yield page 80 | cursor = page.next 81 | -------------------------------------------------------------------------------- /replicate/resource.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import TYPE_CHECKING 3 | 4 | try: 5 | from pydantic import v1 as pydantic # type: ignore 6 | except ImportError: 7 | import pydantic # type: ignore 8 | 9 | if TYPE_CHECKING: 10 | from replicate.client import Client 11 | 12 | 13 | class Resource(pydantic.BaseModel): # type: ignore 14 | """ 15 | A base class for representing a single object on the server. 16 | """ 17 | 18 | 19 | class Namespace(abc.ABC): 20 | """ 21 | A base class for representing objects of a particular type on the server. 22 | """ 23 | 24 | _client: "Client" 25 | 26 | def __init__(self, client: "Client") -> None: 27 | self._client = client 28 | -------------------------------------------------------------------------------- /replicate/run.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | AsyncIterator, 5 | Dict, 6 | Iterator, 7 | List, 8 | Optional, 9 | Union, 10 | ) 11 | 12 | from typing_extensions import Unpack 13 | 14 | from replicate import identifier 15 | from replicate.exceptions import ModelError 16 | from replicate.model import Model 17 | from replicate.prediction import Prediction 18 | from replicate.schema import make_schema_backwards_compatible 19 | from replicate.version import Version, Versions 20 | 21 | if TYPE_CHECKING: 22 | from replicate.client import Client 23 | from replicate.identifier import ModelVersionIdentifier 24 | from replicate.prediction import Predictions 25 | 26 | 27 | def run( 28 | client: "Client", 29 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 30 | input: Optional[Dict[str, Any]] = None, 31 | **params: Unpack["Predictions.CreatePredictionParams"], 32 | ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 33 | """ 34 | Run a model and wait for its output. 35 | """ 36 | 37 | version, owner, name, version_id = identifier._resolve(ref) 38 | 39 | if version_id is not None: 40 | prediction = client.predictions.create( 41 | version=version_id, input=input or {}, **params 42 | ) 43 | elif owner and name: 44 | prediction = client.models.predictions.create( 45 | model=(owner, name), input=input or {}, **params 46 | ) 47 | else: 48 | raise ValueError( 49 | f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" 50 | ) 51 | 52 | if not version and (owner and name and version_id): 53 | version = Versions(client, model=(owner, name)).get(version_id) 54 | 55 | if version and (iterator := _make_output_iterator(version, prediction)): 56 | return iterator 57 | 58 | prediction.wait() 59 | 60 | if prediction.status == "failed": 61 | raise ModelError(prediction.error) 62 | 63 | return prediction.output 64 | 65 | 66 | async def async_run( 67 | client: "Client", 68 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 69 | input: Optional[Dict[str, Any]] = None, 70 | **params: Unpack["Predictions.CreatePredictionParams"], 71 | ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 72 | """ 73 | Run a model and wait for its output asynchronously. 74 | """ 75 | 76 | version, owner, name, version_id = identifier._resolve(ref) 77 | 78 | if version or version_id: 79 | prediction = await client.predictions.async_create( 80 | version=(version or version_id), input=input or {}, **params 81 | ) 82 | elif owner and name: 83 | prediction = await client.models.predictions.async_create( 84 | model=(owner, name), input=input or {}, **params 85 | ) 86 | else: 87 | raise ValueError( 88 | f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" 89 | ) 90 | 91 | if not version and (owner and name and version_id): 92 | version = await Versions(client, model=(owner, name)).async_get(version_id) 93 | 94 | if version and (iterator := _make_async_output_iterator(version, prediction)): 95 | return iterator 96 | 97 | await prediction.async_wait() 98 | 99 | if prediction.status == "failed": 100 | raise ModelError(prediction.error) 101 | 102 | return prediction.output 103 | 104 | 105 | def _has_output_iterator_array_type(version: Version) -> bool: 106 | schema = make_schema_backwards_compatible( 107 | version.openapi_schema, version.cog_version 108 | ) 109 | output = schema.get("components", {}).get("schemas", {}).get("Output", {}) 110 | return ( 111 | output.get("type") == "array" and output.get("x-cog-array-type") == "iterator" 112 | ) 113 | 114 | 115 | def _make_output_iterator( 116 | version: Version, prediction: Prediction 117 | ) -> Optional[Iterator[Any]]: 118 | if _has_output_iterator_array_type(version): 119 | return prediction.output_iterator() 120 | 121 | return None 122 | 123 | 124 | def _make_async_output_iterator( 125 | version: Version, prediction: Prediction 126 | ) -> Optional[AsyncIterator[Any]]: 127 | if _has_output_iterator_array_type(version): 128 | return prediction.async_output_iterator() 129 | 130 | return None 131 | 132 | 133 | __all__: List = [] 134 | -------------------------------------------------------------------------------- /replicate/schema.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from packaging import version 4 | 5 | # TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth? 6 | 7 | 8 | def version_has_no_array_type(cog_version: str) -> Optional[bool]: 9 | """Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward""" 10 | try: 11 | return version.parse(cog_version) < version.parse("0.3.9") 12 | except version.InvalidVersion: 13 | return None 14 | 15 | 16 | def make_schema_backwards_compatible( 17 | schema: dict, 18 | cog_version: str, 19 | ) -> dict: 20 | """A place to add backwards compatibility logic for our openapi schema""" 21 | 22 | # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type 23 | if version_has_no_array_type(cog_version): 24 | output = schema["components"]["schemas"]["Output"] 25 | if output.get("type") == "array": 26 | output["x-cog-array-type"] = "iterator" 27 | return schema 28 | -------------------------------------------------------------------------------- /replicate/stream.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import ( 3 | TYPE_CHECKING, 4 | Any, 5 | AsyncIterator, 6 | Dict, 7 | Iterator, 8 | List, 9 | Optional, 10 | Union, 11 | ) 12 | 13 | from typing_extensions import Unpack 14 | 15 | from replicate import identifier 16 | from replicate.exceptions import ReplicateError 17 | 18 | try: 19 | from pydantic import v1 as pydantic # type: ignore 20 | except ImportError: 21 | import pydantic # type: ignore 22 | 23 | 24 | if TYPE_CHECKING: 25 | import httpx 26 | 27 | from replicate.client import Client 28 | from replicate.identifier import ModelVersionIdentifier 29 | from replicate.model import Model 30 | from replicate.prediction import Predictions 31 | from replicate.version import Version 32 | 33 | 34 | class ServerSentEvent(pydantic.BaseModel): # type: ignore 35 | """ 36 | A server-sent event. 37 | """ 38 | 39 | class EventType(Enum): 40 | """ 41 | A server-sent event type. 42 | """ 43 | 44 | OUTPUT = "output" 45 | LOGS = "logs" 46 | ERROR = "error" 47 | DONE = "done" 48 | 49 | event: EventType 50 | data: str 51 | id: str 52 | retry: Optional[int] 53 | 54 | def __str__(self) -> str: 55 | if self.event == ServerSentEvent.EventType.OUTPUT: 56 | return self.data 57 | 58 | return "" 59 | 60 | 61 | class EventSource: 62 | """ 63 | A server-sent event source. 64 | """ 65 | 66 | response: "httpx.Response" 67 | 68 | def __init__(self, response: "httpx.Response") -> None: 69 | self.response = response 70 | content_type, _, _ = response.headers["content-type"].partition(";") 71 | if content_type != "text/event-stream": 72 | raise ValueError( 73 | "Expected response Content-Type to be 'text/event-stream', " 74 | f"got {content_type!r}" 75 | ) 76 | 77 | class Decoder: 78 | """ 79 | A decoder for server-sent events. 80 | """ 81 | 82 | event: Optional["ServerSentEvent.EventType"] 83 | data: List[str] 84 | last_event_id: Optional[str] 85 | retry: Optional[int] 86 | 87 | def __init__(self) -> None: 88 | self.event = None 89 | self.data = [] 90 | self.last_event_id = None 91 | self.retry = None 92 | 93 | def decode(self, line: str) -> Optional[ServerSentEvent]: 94 | """ 95 | Decode a line and return a server-sent event if applicable. 96 | """ 97 | 98 | if not line: 99 | if ( 100 | not any([self.event, self.data, self.last_event_id, self.retry]) 101 | or self.event is None 102 | or self.last_event_id is None 103 | ): 104 | return None 105 | 106 | sse = ServerSentEvent( 107 | event=self.event, 108 | data="\n".join(self.data), 109 | id=self.last_event_id, 110 | retry=self.retry, 111 | ) 112 | 113 | self.event = None 114 | self.data = [] 115 | self.retry = None 116 | 117 | return sse 118 | 119 | if line.startswith(":"): 120 | return None 121 | 122 | fieldname, _, value = line.partition(":") 123 | value = value[1:] if value.startswith(" ") else value 124 | 125 | if fieldname == "event": 126 | if event := ServerSentEvent.EventType(value): 127 | self.event = event 128 | elif fieldname == "data": 129 | self.data.append(value) 130 | elif fieldname == "id": 131 | if "\0" not in value: 132 | self.last_event_id = value 133 | elif fieldname == "retry": 134 | try: 135 | self.retry = int(value) 136 | except (TypeError, ValueError): 137 | pass 138 | 139 | return None 140 | 141 | def __iter__(self) -> Iterator[ServerSentEvent]: 142 | decoder = EventSource.Decoder() 143 | 144 | for line in self.response.iter_lines(): 145 | line = line.rstrip("\n") 146 | sse = decoder.decode(line) 147 | if sse is not None: 148 | if sse.event == ServerSentEvent.EventType.ERROR: 149 | raise RuntimeError(sse.data) 150 | 151 | yield sse 152 | 153 | if sse.event == ServerSentEvent.EventType.DONE: 154 | return 155 | 156 | async def __aiter__(self) -> AsyncIterator[ServerSentEvent]: 157 | decoder = EventSource.Decoder() 158 | async for line in self.response.aiter_lines(): 159 | line = line.rstrip("\n") 160 | sse = decoder.decode(line) 161 | if sse is not None: 162 | if sse.event == ServerSentEvent.EventType.ERROR: 163 | raise RuntimeError(sse.data) 164 | 165 | yield sse 166 | 167 | if sse.event == ServerSentEvent.EventType.DONE: 168 | return 169 | 170 | 171 | def stream( 172 | client: "Client", 173 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 174 | input: Optional[Dict[str, Any]] = None, 175 | **params: Unpack["Predictions.CreatePredictionParams"], 176 | ) -> Iterator[ServerSentEvent]: 177 | """ 178 | Run a model and stream its output. 179 | """ 180 | 181 | params = params or {} 182 | params["stream"] = True 183 | 184 | version, owner, name, version_id = identifier._resolve(ref) 185 | 186 | if version or version_id: 187 | prediction = client.predictions.create( 188 | version=(version or version_id), input=input or {}, **params 189 | ) 190 | elif owner and name: 191 | prediction = client.models.predictions.create( 192 | model=(owner, name), input=input or {}, **params 193 | ) 194 | else: 195 | raise ValueError( 196 | f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" 197 | ) 198 | 199 | url = prediction.urls and prediction.urls.get("stream", None) 200 | if not url or not isinstance(url, str): 201 | raise ReplicateError("Model does not support streaming") 202 | 203 | headers = {} 204 | headers["Accept"] = "text/event-stream" 205 | headers["Cache-Control"] = "no-store" 206 | 207 | with client._client.stream("GET", url, headers=headers) as response: 208 | yield from EventSource(response) 209 | 210 | 211 | async def async_stream( 212 | client: "Client", 213 | ref: Union["Model", "Version", "ModelVersionIdentifier", str], 214 | input: Optional[Dict[str, Any]] = None, 215 | **params: Unpack["Predictions.CreatePredictionParams"], 216 | ) -> AsyncIterator[ServerSentEvent]: 217 | """ 218 | Run a model and stream its output asynchronously. 219 | """ 220 | 221 | params = params or {} 222 | params["stream"] = True 223 | 224 | version, owner, name, version_id = identifier._resolve(ref) 225 | 226 | if version or version_id: 227 | prediction = await client.predictions.async_create( 228 | version=(version or version_id), input=input or {}, **params 229 | ) 230 | elif owner and name: 231 | prediction = await client.models.predictions.async_create( 232 | model=(owner, name), input=input or {}, **params 233 | ) 234 | else: 235 | raise ValueError( 236 | f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" 237 | ) 238 | 239 | url = prediction.urls and prediction.urls.get("stream", None) 240 | if not url or not isinstance(url, str): 241 | raise ReplicateError("Model does not support streaming") 242 | 243 | headers = {} 244 | headers["Accept"] = "text/event-stream" 245 | headers["Cache-Control"] = "no-store" 246 | 247 | async with client._async_client.stream("GET", url, headers=headers) as response: 248 | async for event in EventSource(response): 249 | yield event 250 | 251 | 252 | __all__ = ["ServerSentEvent"] 253 | -------------------------------------------------------------------------------- /replicate/training.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | Dict, 5 | List, 6 | Literal, 7 | Optional, 8 | Tuple, 9 | TypedDict, 10 | Union, 11 | overload, 12 | ) 13 | 14 | from typing_extensions import NotRequired, Unpack 15 | 16 | from replicate.files import upload_file 17 | from replicate.identifier import ModelVersionIdentifier 18 | from replicate.json import encode_json 19 | from replicate.model import Model 20 | from replicate.pagination import Page 21 | from replicate.resource import Namespace, Resource 22 | from replicate.version import Version 23 | 24 | try: 25 | from pydantic import v1 as pydantic # type: ignore 26 | except ImportError: 27 | import pydantic # type: ignore 28 | 29 | if TYPE_CHECKING: 30 | from replicate.client import Client 31 | 32 | 33 | class Training(Resource): 34 | """ 35 | A training made for a model hosted on Replicate. 36 | """ 37 | 38 | _client: "Client" = pydantic.PrivateAttr() 39 | 40 | id: str 41 | """The unique ID of the training.""" 42 | 43 | model: str 44 | """An identifier for the model used to create the prediction, in the form `owner/name`.""" 45 | 46 | version: Union[str, Version] 47 | """The version of the model used to create the training.""" 48 | 49 | destination: Optional[str] 50 | """The model destination of the training.""" 51 | 52 | status: Literal["starting", "processing", "succeeded", "failed", "canceled"] 53 | """The status of the training.""" 54 | 55 | input: Optional[Dict[str, Any]] 56 | """The input to the training.""" 57 | 58 | output: Optional[Any] 59 | """The output of the training.""" 60 | 61 | logs: Optional[str] 62 | """The logs of the training.""" 63 | 64 | error: Optional[str] 65 | """The error encountered during the training, if any.""" 66 | 67 | created_at: Optional[str] 68 | """When the training was created.""" 69 | 70 | started_at: Optional[str] 71 | """When the training was started.""" 72 | 73 | completed_at: Optional[str] 74 | """When the training was completed, if finished.""" 75 | 76 | urls: Optional[Dict[str, str]] 77 | """ 78 | URLs associated with the training. 79 | 80 | The following keys are available: 81 | - `get`: A URL to fetch the training. 82 | - `cancel`: A URL to cancel the training. 83 | """ 84 | 85 | def cancel(self) -> None: 86 | """ 87 | Cancel a running training. 88 | """ 89 | 90 | canceled = self._client.trainings.cancel(self.id) 91 | for name, value in canceled.dict().items(): 92 | setattr(self, name, value) 93 | 94 | async def async_cancel(self) -> None: 95 | """ 96 | Cancel a running training asynchronously. 97 | """ 98 | 99 | canceled = await self._client.trainings.async_cancel(self.id) 100 | for name, value in canceled.dict().items(): 101 | setattr(self, name, value) 102 | 103 | def reload(self) -> None: 104 | """ 105 | Load the training from the server. 106 | """ 107 | 108 | updated = self._client.trainings.get(self.id) 109 | for name, value in updated.dict().items(): 110 | setattr(self, name, value) 111 | 112 | async def async_reload(self) -> None: 113 | """ 114 | Load the training from the server asynchronously. 115 | """ 116 | 117 | updated = await self._client.trainings.async_get(self.id) 118 | for name, value in updated.dict().items(): 119 | setattr(self, name, value) 120 | 121 | 122 | class Trainings(Namespace): 123 | """ 124 | Namespace for operations related to trainings. 125 | """ 126 | 127 | def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Training]: # noqa: F821 128 | """ 129 | List your trainings. 130 | 131 | Parameters: 132 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 133 | Returns: 134 | Page[Training]: A page of trainings. 135 | Raises: 136 | ValueError: If `cursor` is `None`. 137 | """ 138 | 139 | if cursor is None: 140 | raise ValueError("cursor cannot be None") 141 | 142 | resp = self._client._request( 143 | "GET", "/v1/trainings" if cursor is ... else cursor 144 | ) 145 | 146 | obj = resp.json() 147 | obj["results"] = [ 148 | _json_to_training(self._client, result) for result in obj["results"] 149 | ] 150 | 151 | return Page[Training](**obj) 152 | 153 | async def async_list( 154 | self, 155 | cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 156 | ) -> Page[Training]: 157 | """ 158 | List your trainings. 159 | 160 | Parameters: 161 | cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`. 162 | Returns: 163 | Page[Training]: A page of trainings. 164 | Raises: 165 | ValueError: If `cursor` is `None`. 166 | """ 167 | 168 | if cursor is None: 169 | raise ValueError("cursor cannot be None") 170 | 171 | resp = await self._client._async_request( 172 | "GET", "/v1/trainings" if cursor is ... else cursor 173 | ) 174 | 175 | obj = resp.json() 176 | obj["results"] = [ 177 | _json_to_training(self._client, result) for result in obj["results"] 178 | ] 179 | 180 | return Page[Training](**obj) 181 | 182 | def get(self, id: str) -> Training: 183 | """ 184 | Get a training by ID. 185 | 186 | Args: 187 | id: The ID of the training. 188 | Returns: 189 | Training: The training object. 190 | """ 191 | 192 | resp = self._client._request( 193 | "GET", 194 | f"/v1/trainings/{id}", 195 | ) 196 | 197 | return _json_to_training(self._client, resp.json()) 198 | 199 | async def async_get(self, id: str) -> Training: 200 | """ 201 | Get a training by ID. 202 | 203 | Args: 204 | id: The ID of the training. 205 | Returns: 206 | Training: The training object. 207 | """ 208 | 209 | resp = await self._client._async_request( 210 | "GET", 211 | f"/v1/trainings/{id}", 212 | ) 213 | 214 | return _json_to_training(self._client, resp.json()) 215 | 216 | class CreateTrainingParams(TypedDict): 217 | """Parameters for creating a training.""" 218 | 219 | destination: Union[str, Tuple[str, str], "Model"] 220 | webhook: NotRequired[str] 221 | webhook_completed: NotRequired[str] 222 | webhook_events_filter: NotRequired[List[str]] 223 | 224 | @overload 225 | def create( # pylint: disable=too-many-arguments 226 | self, 227 | version: str, 228 | input: Dict[str, Any], 229 | destination: str, 230 | webhook: Optional[str] = None, 231 | webhook_events_filter: Optional[List[str]] = None, 232 | **kwargs, 233 | ) -> Training: ... 234 | 235 | @overload 236 | def create( 237 | self, 238 | model: Union[str, Tuple[str, str], "Model"], 239 | version: Union[str, Version], 240 | input: Optional[Dict[str, Any]] = None, 241 | **params: Unpack["Trainings.CreateTrainingParams"], 242 | ) -> Training: ... 243 | 244 | def create( # type: ignore 245 | self, 246 | *args, 247 | model: Optional[Union[str, Tuple[str, str], "Model"]] = None, 248 | version: Optional[Union[str, Version]] = None, 249 | input: Optional[Dict[str, Any]] = None, 250 | **params: Unpack["Trainings.CreateTrainingParams"], 251 | ) -> Training: 252 | """ 253 | Create a new training using the specified model version as a base. 254 | """ 255 | 256 | url = None 257 | 258 | # Support positional arguments for backwards compatibility 259 | if args: 260 | if shorthand := args[0] if len(args) > 0 else None: 261 | url = _create_training_url_from_shorthand(shorthand) 262 | 263 | input = args[1] if len(args) > 1 else input 264 | if len(args) > 2: 265 | params["destination"] = args[2] 266 | if len(args) > 3: 267 | params["webhook"] = args[3] 268 | if len(args) > 4: 269 | params["webhook_completed"] = args[4] 270 | if len(args) > 5: 271 | params["webhook_events_filter"] = args[5] 272 | elif model and version: 273 | url = _create_training_url_from_model_and_version(model, version) 274 | elif model is None and isinstance(version, str): 275 | url = _create_training_url_from_shorthand(version) 276 | 277 | if not url: 278 | raise ValueError("model and version or shorthand version must be specified") 279 | 280 | body = _create_training_body(input, **params) 281 | resp = self._client._request( 282 | "POST", 283 | url, 284 | json=body, 285 | ) 286 | 287 | return _json_to_training(self._client, resp.json()) 288 | 289 | async def async_create( 290 | self, 291 | model: Union[str, Tuple[str, str], "Model"], 292 | version: Union[str, Version], 293 | input: Dict[str, Any], 294 | **params: Unpack["Trainings.CreateTrainingParams"], 295 | ) -> Training: 296 | """ 297 | Create a new training using the specified model version as a base. 298 | 299 | Args: 300 | version: The ID of the base model version that you're using to train a new model version. 301 | input: The input to the training. 302 | destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. 303 | webhook: The URL to send a POST request to when the training is completed. Defaults to None. 304 | webhook_completed: The URL to receive a POST request when the prediction is completed. 305 | webhook_events_filter: The events to send to the webhook. Defaults to None. 306 | Returns: 307 | The training object. 308 | """ 309 | 310 | url = _create_training_url_from_model_and_version(model, version) 311 | body = _create_training_body(input, **params) 312 | resp = await self._client._async_request( 313 | "POST", 314 | url, 315 | json=body, 316 | ) 317 | 318 | return _json_to_training(self._client, resp.json()) 319 | 320 | def cancel(self, id: str) -> Training: 321 | """ 322 | Cancel a training. 323 | 324 | Args: 325 | id: The ID of the training to cancel. 326 | Returns: 327 | Training: The canceled training object. 328 | """ 329 | 330 | resp = self._client._request( 331 | "POST", 332 | f"/v1/trainings/{id}/cancel", 333 | ) 334 | 335 | return _json_to_training(self._client, resp.json()) 336 | 337 | async def async_cancel(self, id: str) -> Training: 338 | """ 339 | Cancel a training. 340 | 341 | Args: 342 | id: The ID of the training to cancel. 343 | Returns: 344 | Training: The canceled training object. 345 | """ 346 | 347 | resp = await self._client._async_request( 348 | "POST", 349 | f"/v1/trainings/{id}/cancel", 350 | ) 351 | 352 | return _json_to_training(self._client, resp.json()) 353 | 354 | 355 | def _create_training_body( 356 | input: Optional[Dict[str, Any]] = None, 357 | *, 358 | destination: Optional[Union[str, Tuple[str, str], "Model"]] = None, 359 | webhook: Optional[str] = None, 360 | webhook_completed: Optional[str] = None, 361 | webhook_events_filter: Optional[List[str]] = None, 362 | ) -> Dict[str, Any]: 363 | body = {} 364 | 365 | if input is not None: 366 | body["input"] = encode_json(input, upload_file=upload_file) 367 | 368 | if destination is None: 369 | raise ValueError( 370 | "A destination must be provided as a positional or keyword argument." 371 | ) 372 | if isinstance(destination, Model): 373 | destination = f"{destination.owner}/{destination.name}" 374 | elif isinstance(destination, tuple): 375 | destination = f"{destination[0]}/{destination[1]}" 376 | body["destination"] = destination 377 | 378 | if webhook is not None: 379 | body["webhook"] = webhook 380 | 381 | if webhook_completed is not None: 382 | body["webhook_completed"] = webhook_completed 383 | 384 | if webhook_events_filter is not None: 385 | body["webhook_events_filter"] = webhook_events_filter 386 | 387 | return body 388 | 389 | 390 | def _create_training_url_from_shorthand(ref: str) -> str: 391 | owner, name, version_id = ModelVersionIdentifier.parse(ref) 392 | return f"/v1/models/{owner}/{name}/versions/{version_id}/trainings" 393 | 394 | 395 | def _create_training_url_from_model_and_version( 396 | model: Union[str, Tuple[str, str], "Model"], 397 | version: Union[str, "Version"], 398 | ) -> str: 399 | if isinstance(model, Model): 400 | owner, name = model.owner, model.name 401 | elif isinstance(model, tuple): 402 | owner, name = model[0], model[1] 403 | elif isinstance(model, str): 404 | owner, name, _ = ModelVersionIdentifier.parse(model) 405 | else: 406 | raise ValueError( 407 | "model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'" 408 | ) 409 | 410 | if isinstance(version, Version): 411 | version_id = version.id 412 | else: 413 | version_id = version 414 | 415 | return f"/v1/models/{owner}/{name}/versions/{version_id}/trainings" 416 | 417 | 418 | def _json_to_training(client: "Client", json: Dict[str, Any]) -> Training: 419 | training = Training(**json) 420 | training._client = client 421 | return training 422 | -------------------------------------------------------------------------------- /replicate/version.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import TYPE_CHECKING, Any, Dict, Tuple, Union 3 | 4 | if TYPE_CHECKING: 5 | from replicate.client import Client 6 | from replicate.model import Model 7 | 8 | from replicate.pagination import Page 9 | from replicate.resource import Namespace, Resource 10 | 11 | 12 | class Version(Resource): 13 | """ 14 | A version of a model. 15 | """ 16 | 17 | id: str 18 | """The unique ID of the version.""" 19 | 20 | created_at: datetime.datetime 21 | """When the version was created.""" 22 | 23 | cog_version: str 24 | """The version of the Cog used to create the version.""" 25 | 26 | openapi_schema: dict 27 | """An OpenAPI description of the model inputs and outputs.""" 28 | 29 | 30 | class Versions(Namespace): 31 | """ 32 | Namespace for operations related to model versions. 33 | """ 34 | 35 | model: Tuple[str, str] 36 | 37 | def __init__( 38 | self, client: "Client", model: Union[str, Tuple[str, str], "Model"] 39 | ) -> None: 40 | super().__init__(client=client) 41 | 42 | from replicate.model import Model # pylint: disable=import-outside-toplevel 43 | 44 | if isinstance(model, Model): 45 | self.model = (model.owner, model.name) 46 | elif isinstance(model, str): 47 | owner, name = model.split("/", 1) 48 | self.model = (owner, name) 49 | else: 50 | self.model = model 51 | 52 | def get(self, id: str) -> Version: 53 | """ 54 | Get a specific model version. 55 | 56 | Args: 57 | id: The version ID. 58 | Returns: 59 | The model version. 60 | """ 61 | 62 | resp = self._client._request( 63 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 64 | ) 65 | 66 | return _json_to_version(resp.json()) 67 | 68 | async def async_get(self, id: str) -> Version: 69 | """ 70 | Get a specific model version. 71 | 72 | Args: 73 | id: The version ID. 74 | Returns: 75 | The model version. 76 | """ 77 | 78 | resp = await self._client._async_request( 79 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 80 | ) 81 | 82 | return _json_to_version(resp.json()) 83 | 84 | def list(self) -> Page[Version]: 85 | """ 86 | Return a list of all versions for a model. 87 | 88 | Returns: 89 | List[Version]: A list of version objects. 90 | """ 91 | 92 | resp = self._client._request( 93 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions" 94 | ) 95 | obj = resp.json() 96 | obj["results"] = [_json_to_version(result) for result in obj["results"]] 97 | 98 | return Page[Version](**obj) 99 | 100 | async def async_list(self) -> Page[Version]: 101 | """ 102 | Return a list of all versions for a model. 103 | 104 | Returns: 105 | List[Version]: A list of version objects. 106 | """ 107 | 108 | resp = await self._client._async_request( 109 | "GET", f"/v1/models/{self.model[0]}/{self.model[1]}/versions" 110 | ) 111 | obj = resp.json() 112 | obj["results"] = [_json_to_version(result) for result in obj["results"]] 113 | 114 | return Page[Version](**obj) 115 | 116 | def delete(self, id: str) -> bool: 117 | """ 118 | Delete a model version and all associated predictions, including all output files. 119 | 120 | Model version deletion has some restrictions: 121 | 122 | * You can only delete versions from models you own. 123 | * You can only delete versions from private models. 124 | * You cannot delete a version if someone other than you 125 | has run predictions with it. 126 | 127 | Args: 128 | id: The version ID. 129 | """ 130 | 131 | resp = self._client._request( 132 | "DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 133 | ) 134 | return resp.status_code == 204 135 | 136 | async def async_delete(self, id: str) -> bool: 137 | """ 138 | Delete a model version and all associated predictions, including all output files. 139 | 140 | Model version deletion has some restrictions: 141 | 142 | * You can only delete versions from models you own. 143 | * You can only delete versions from private models. 144 | * You cannot delete a version if someone other than you 145 | has run predictions with it. 146 | 147 | Args: 148 | id: The version ID. 149 | """ 150 | 151 | resp = await self._client._async_request( 152 | "DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" 153 | ) 154 | return resp.status_code == 204 155 | 156 | 157 | def _json_to_version(json: Dict[str, Any]) -> Version: 158 | return Version(**json) 159 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile --extra=dev --output-file=requirements-dev.txt pyproject.toml 6 | # 7 | annotated-types==0.5.0 8 | # via pydantic 9 | anyio==3.7.1 10 | # via httpcore 11 | astroid==3.0.1 12 | # via pylint 13 | certifi==2023.7.22 14 | # via 15 | # httpcore 16 | # httpx 17 | dill==0.3.7 18 | # via pylint 19 | h11==0.14.0 20 | # via httpcore 21 | httpcore==0.17.3 22 | # via httpx 23 | httpx==0.24.1 24 | # via 25 | # replicate (pyproject.toml) 26 | # respx 27 | idna==3.4 28 | # via 29 | # anyio 30 | # httpx 31 | # yarl 32 | iniconfig==2.0.0 33 | # via pytest 34 | isort==5.12.0 35 | # via pylint 36 | mccabe==0.7.0 37 | # via pylint 38 | multidict==6.0.4 39 | # via yarl 40 | nodeenv==1.8.0 41 | # via pyright 42 | packaging==23.1 43 | # via 44 | # pytest 45 | # replicate (pyproject.toml) 46 | platformdirs==3.11.0 47 | # via pylint 48 | pluggy==1.2.0 49 | # via pytest 50 | pydantic==2.0.3 51 | # via replicate (pyproject.toml) 52 | pydantic-core==2.3.0 53 | # via pydantic 54 | pylint==3.0.2 55 | # via replicate (pyproject.toml) 56 | pyright==1.1.337 57 | # via replicate (pyproject.toml) 58 | pytest==7.4.0 59 | # via 60 | # pytest-asyncio 61 | # pytest-recording 62 | # replicate (pyproject.toml) 63 | pytest-asyncio==0.21.1 64 | # via replicate (pyproject.toml) 65 | pytest-recording==0.13.0 66 | # via replicate (pyproject.toml) 67 | pyyaml==6.0.1 68 | # via vcrpy 69 | respx==0.20.2 70 | # via replicate (pyproject.toml) 71 | ruff==0.3.3 72 | # via replicate (pyproject.toml) 73 | sniffio==1.3.0 74 | # via 75 | # anyio 76 | # httpcore 77 | # httpx 78 | tomlkit==0.12.1 79 | # via pylint 80 | typing-extensions==4.7.1 81 | # via 82 | # pydantic 83 | # pydantic-core 84 | # replicate (pyproject.toml) 85 | vcrpy==5.1.0 86 | # via pytest-recording 87 | wrapt==1.15.0 88 | # via vcrpy 89 | yarl==1.9.2 90 | # via vcrpy 91 | 92 | # The following packages are considered to be unsafe in a requirements file: 93 | # setuptools 94 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.11 3 | # by the following command: 4 | # 5 | # pip-compile --output-file=requirements.txt --resolver=backtracking pyproject.toml 6 | # 7 | annotated-types==0.5.0 8 | # via pydantic 9 | anyio==3.7.1 10 | # via httpcore 11 | certifi==2023.7.22 12 | # via 13 | # httpcore 14 | # httpx 15 | h11==0.14.0 16 | # via httpcore 17 | httpcore==0.17.3 18 | # via httpx 19 | httpx==0.24.1 20 | # via replicate (pyproject.toml) 21 | idna==3.4 22 | # via 23 | # anyio 24 | # httpx 25 | packaging==23.1 26 | # via replicate (pyproject.toml) 27 | pydantic==2.0.3 28 | # via replicate (pyproject.toml) 29 | pydantic-core==2.3.0 30 | # via pydantic 31 | sniffio==1.3.0 32 | # via 33 | # anyio 34 | # httpcore 35 | # httpx 36 | typing-extensions==4.7.1 37 | # via 38 | # pydantic 39 | # pydantic-core 40 | # replicate (pyproject.toml) 41 | -------------------------------------------------------------------------------- /script/format: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python -m ruff format . 6 | -------------------------------------------------------------------------------- /script/lint: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | STATUS=0 6 | 7 | echo "Running pyright" 8 | python -m pyright replicate || STATUS=$? 9 | echo "" 10 | 11 | echo "Running pylint" 12 | python -m pylint --exit-zero replicate || STATUS=$? 13 | echo "" 14 | 15 | echo "Running ruff check" 16 | python -m ruff check . || STATUS=$? 17 | echo "" 18 | 19 | echo "Running ruff format check" 20 | python -m ruff format --check . || STATUS=$? 21 | echo "" 22 | 23 | exit $STATUS 24 | -------------------------------------------------------------------------------- /script/setup: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python -m pip install -r requirements.txt -r requirements-dev.txt . 6 | -------------------------------------------------------------------------------- /script/test: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python -m pytest -v 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/super-dev03/python-test/891390d1b2e484a121097fde896bf378145e2af3/tests/__init__.py -------------------------------------------------------------------------------- /tests/cassettes/collections-list.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.15.6 15 | method: GET 16 | uri: https://api.replicate.com/v1/collections 17 | response: 18 | content: '{"next":null,"previous":null,"results":[{"name":"Vision models","slug":"vision-models","description":"Multimodal 19 | large language models with vision capabilities like object detection and optical 20 | character recognition (OCR)"},{"name":"T2I-Adapter","slug":"t2i-adapter","description":"T2I-Adapter 21 | models to modify images"},{"name":"Language models with support for grammars 22 | and jsonschema","slug":"language-models-with-grammar","description":"Language 23 | models that support grammar-based decoding as well as jsonschema constraints."},{"name":"SDXL 24 | fine-tunes","slug":"sdxl-fine-tunes","description":"Some of our favorite SDXL 25 | fine-tunes."},{"name":"Streaming language models","slug":"streaming-language-models","description":"Language 26 | models that support streaming responses. See https://replicate.com/docs/streaming"},{"name":"Image 27 | editing","slug":"image-editing","description":"Tools for manipulating images."},{"name":"Embedding 28 | models","slug":"embedding-models","description":"Models that generate embeddings 29 | from inputs"},{"name":"Trainable language models","slug":"trainable-language-models","description":"Language 30 | models that you can fine-tune using Replicate''s training API."},{"name":"Language 31 | models","slug":"language-models","description":"Models that can understand and 32 | generate text"},{"name":"ControlNet","slug":"control-net","description":"Control 33 | diffusion models"},{"name":"Audio generation","slug":"audio-generation","description":"Models 34 | to generate and modify audio"},{"name":"Diffusion models","slug":"diffusion-models","description":"Image 35 | and video generation models trained with diffusion processes"},{"name":"Videos","slug":"text-to-video","description":"Models 36 | that create and edit videos"},{"name":"Image to text","slug":"image-to-text","description":"Models 37 | that generate text prompts and captions from images"},{"name":"Super resolution","slug":"super-resolution","description":"Upscaling 38 | models that create high-quality images from low-quality images"},{"name":"Style 39 | transfer","slug":"style-transfer","description":"Models that take a content 40 | image and a style reference to produce a new image"},{"name":"ML makeovers","slug":"ml-makeovers","description":"Models 41 | that let you change facial features"},{"name":"Image restoration","slug":"image-restoration","description":"Models 42 | that improve or restore images by deblurring, colorization, and removing noise"},{"name":"Text 43 | to image","slug":"text-to-image","description":"Models that generate images 44 | from text prompts"}]}' 45 | headers: 46 | CF-Cache-Status: 47 | - DYNAMIC 48 | CF-RAY: 49 | - 827025392eae200a-IAD 50 | Connection: 51 | - keep-alive 52 | Content-Encoding: 53 | - gzip 54 | Content-Type: 55 | - application/json 56 | Date: 57 | - Thu, 16 Nov 2023 13:40:22 GMT 58 | Server: 59 | - cloudflare 60 | Strict-Transport-Security: 61 | - max-age=15552000 62 | Transfer-Encoding: 63 | - chunked 64 | allow: 65 | - GET, HEAD, OPTIONS 66 | content-security-policy-report-only: 67 | - 'media-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery 68 | https://*.mux.com https://*.sentry.io; default-src ''self''; script-src ''report-sample'' 69 | ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; img-src 70 | ''report-sample'' ''self'' data: https://replicate.delivery https://*.replicate.delivery 71 | https://*.githubusercontent.com https://github.com; worker-src ''none''; style-src 72 | ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample'' 73 | ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com 74 | https://*.rudderstack.com https://*.mux.com https://*.sentry.io; font-src 75 | ''report-sample'' ''self'' data:; report-uri' 76 | cross-origin-opener-policy: 77 | - same-origin 78 | nel: 79 | - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' 80 | ratelimit-remaining: 81 | - '2999' 82 | ratelimit-reset: 83 | - '1' 84 | referrer-policy: 85 | - same-origin 86 | report-to: 87 | - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1700142022&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=AUQIFO072WbZjKq785Xqd67vUUwGAhLFqu5%2BlLug%2BWE%3D"}]}' 88 | reporting-endpoints: 89 | - heroku-nel=https://nel.heroku.com/reports?ts=1700142022&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=AUQIFO072WbZjKq785Xqd67vUUwGAhLFqu5%2BlLug%2BWE%3D 90 | vary: 91 | - Cookie, origin 92 | via: 93 | - 1.1 vegur, 1.1 google 94 | x-content-type-options: 95 | - nosniff 96 | x-frame-options: 97 | - DENY 98 | http_version: HTTP/1.1 99 | status_code: 200 100 | version: 1 101 | -------------------------------------------------------------------------------- /tests/cassettes/hardware-list.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.15.5 15 | method: GET 16 | uri: https://api.replicate.com/v1/hardware 17 | response: 18 | content: '[{"sku":"cpu","name":"CPU"},{"sku":"gpu-t4","name":"Nvidia T4 GPU"},{"sku":"gpu-a40-small","name":"Nvidia 19 | A40 GPU"},{"sku":"gpu-a40-large","name":"Nvidia A40 (Large) GPU"}]' 20 | headers: 21 | CF-Cache-Status: 22 | - DYNAMIC 23 | CF-RAY: 24 | - 81fbfed29fe1c58a-SEA 25 | Connection: 26 | - keep-alive 27 | Content-Encoding: 28 | - gzip 29 | Content-Type: 30 | - application/json 31 | Date: 32 | - Thu, 02 Nov 2023 11:21:41 GMT 33 | Server: 34 | - cloudflare 35 | Strict-Transport-Security: 36 | - max-age=15552000 37 | Transfer-Encoding: 38 | - chunked 39 | allow: 40 | - OPTIONS, GET 41 | content-security-policy-report-only: 42 | - 'connect-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery 43 | https://*.rudderlabs.com https://*.rudderstack.com https://*.mux.com https://*.sentry.io; 44 | worker-src ''none''; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; 45 | style-src ''report-sample'' ''self'' ''unsafe-inline''; font-src ''report-sample'' 46 | ''self'' data:; img-src ''report-sample'' ''self'' data: https://replicate.delivery 47 | https://*.replicate.delivery https://*.githubusercontent.com https://github.com; 48 | default-src ''self''; media-src ''report-sample'' ''self'' https://replicate.delivery 49 | https://*.replicate.delivery https://*.mux.com https://*.sentry.io; report-uri' 50 | cross-origin-opener-policy: 51 | - same-origin 52 | nel: 53 | - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' 54 | ratelimit-remaining: 55 | - '2999' 56 | ratelimit-reset: 57 | - '1' 58 | referrer-policy: 59 | - same-origin 60 | report-to: 61 | - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D"}]}' 62 | reporting-endpoints: 63 | - heroku-nel=https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D 64 | vary: 65 | - Cookie, origin 66 | via: 67 | - 1.1 vegur, 1.1 google 68 | x-content-type-options: 69 | - nosniff 70 | x-frame-options: 71 | - DENY 72 | http_version: HTTP/1.1 73 | status_code: 200 74 | version: 1 75 | -------------------------------------------------------------------------------- /tests/cassettes/models-create.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '{"owner": "test", "name": "python-example", "visibility": "private", "hardware": 4 | "cpu", "description": "An example model"}' 5 | headers: 6 | accept: 7 | - '*/*' 8 | accept-encoding: 9 | - gzip, deflate 10 | connection: 11 | - keep-alive 12 | content-length: 13 | - '123' 14 | content-type: 15 | - application/json 16 | host: 17 | - api.replicate.com 18 | user-agent: 19 | - replicate-python/0.15.6 20 | method: POST 21 | uri: https://api.replicate.com/v1/models 22 | response: 23 | content: '{"url": "https://replicate.com/test/python-example", "owner": "test", 24 | "name": "python-example", "description": "An example model", "visibility": "private", 25 | "github_url": null, "paper_url": null, "license_url": null, "run_count": 0, 26 | "cover_image_url": null, "default_example": null, "latest_version": null}' 27 | headers: 28 | CF-Cache-Status: 29 | - DYNAMIC 30 | CF-RAY: 31 | - 81ff2e098ec0eb5b-SEA 32 | Connection: 33 | - keep-alive 34 | Content-Length: 35 | - '307' 36 | Content-Type: 37 | - application/json 38 | Date: 39 | - Thu, 02 Nov 2023 20:38:12 GMT 40 | Server: 41 | - cloudflare 42 | Strict-Transport-Security: 43 | - max-age=15552000 44 | allow: 45 | - GET, POST, HEAD, OPTIONS 46 | content-security-policy-report-only: 47 | - 'font-src ''report-sample'' ''self'' data:; img-src ''report-sample'' ''self'' 48 | data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com 49 | https://github.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; 50 | style-src ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample'' 51 | ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com 52 | https://*.rudderstack.com https://*.mux.com https://*.sentry.io; worker-src 53 | ''none''; media-src ''report-sample'' ''self'' https://replicate.delivery 54 | https://*.replicate.delivery https://*.mux.com https://*.sentry.io; default-src 55 | ''self''; report-uri' 56 | cross-origin-opener-policy: 57 | - same-origin 58 | nel: 59 | - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' 60 | ratelimit-remaining: 61 | - '2999' 62 | ratelimit-reset: 63 | - '1' 64 | referrer-policy: 65 | - same-origin 66 | report-to: 67 | - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D"}]}' 68 | reporting-endpoints: 69 | - heroku-nel=https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D 70 | vary: 71 | - Cookie, origin 72 | via: 73 | - 1.1 vegur, 1.1 google 74 | x-content-type-options: 75 | - nosniff 76 | x-frame-options: 77 | - DENY 78 | http_version: HTTP/1.1 79 | status_code: 201 80 | version: 1 81 | -------------------------------------------------------------------------------- /tests/cassettes/models-predictions-create.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: '{"input": {"prompt": "Please write a haiku about llamas."}}' 4 | headers: 5 | accept: 6 | - '*/*' 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | content-length: 12 | - '59' 13 | content-type: 14 | - application/json 15 | host: 16 | - api.replicate.com 17 | user-agent: 18 | - replicate-python/0.21.0 19 | method: POST 20 | uri: https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions 21 | response: 22 | content: '{"id":"heat2o3bzn3ahtr6bjfftvbaci","model":"replicate/lifeboat-70b","version":"d-c6559c5791b50af57b69f4a73f8e021c","input":{"prompt":"Please 23 | write a haiku about llamas."},"logs":"","error":null,"status":"starting","created_at":"2023-11-27T13:35:45.99397566Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel","get":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"}} 24 | 25 | ' 26 | headers: 27 | CF-Cache-Status: 28 | - DYNAMIC 29 | CF-RAY: 30 | - 82cac197efaec53d-SEA 31 | Connection: 32 | - keep-alive 33 | Content-Length: 34 | - '431' 35 | Content-Type: 36 | - application/json 37 | Date: 38 | - Mon, 27 Nov 2023 13:35:46 GMT 39 | NEL: 40 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 41 | Report-To: 42 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=7R5RONMF6xaGRc39n0wnSe3jU1FbpX64Xz4U%2B%2F2nasvFaz0pKARxPhnzDgYkLaWgdK9zWrD2jxU04aKOy5HMPHAXboJ993L4zfsOyto56lBtdqSjNgkptzzxYEsKD%2FxIhe2F"}],"group":"cf-nel","max_age":604800}' 43 | Server: 44 | - cloudflare 45 | Strict-Transport-Security: 46 | - max-age=15552000 47 | ratelimit-remaining: 48 | - '599' 49 | ratelimit-reset: 50 | - '1' 51 | via: 52 | - 1.1 google 53 | http_version: HTTP/1.1 54 | status_code: 201 55 | version: 1 56 | -------------------------------------------------------------------------------- /tests/cassettes/predictions-get.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: "" 4 | headers: 5 | accept: 6 | - "*/*" 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.11.0 15 | method: GET 16 | uri: https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu 17 | response: 18 | content: 19 | "{\"id\":\"vgcm4plb7tgzlyznry5d5jkgvu\",\"model\":\"stability-ai/sdxl\",\"version\":\"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b\",\"input\":{\"height\":512,\"prompt\":\"a 20 | studio photo of a rainbow colored corgi\",\"seed\":42069,\"width\":512},\"logs\":\"Using 21 | seed: 42069\\nPrompt: a studio photo of a rainbow colored corgi\\ntxt2img mode\\n 22 | \ 0%| | 0/50 [00:00\\u003c?, ?it/s]\\n 4%|\u258D | 2/50 [00:00\\u003c00:02, 23 | 16.47it/s]\\n 8%|\u258A | 4/50 [00:00\\u003c00:02, 16.39it/s]\\n 12%|\u2588\u258F 24 | \ | 6/50 [00:00\\u003c00:02, 16.60it/s]\\n 16%|\u2588\u258C | 8/50 25 | [00:00\\u003c00:02, 16.53it/s]\\n 20%|\u2588\u2588 | 10/50 [00:00\\u003c00:02, 26 | 16.76it/s]\\n 24%|\u2588\u2588\u258D | 12/50 [00:00\\u003c00:02, 16.93it/s]\\n 27 | 28%|\u2588\u2588\u258A | 14/50 [00:00\\u003c00:02, 17.04it/s]\\n 32%|\u2588\u2588\u2588\u258F 28 | \ | 16/50 [00:00\\u003c00:01, 17.10it/s]\\n 36%|\u2588\u2588\u2588\u258C 29 | \ | 18/50 [00:01\\u003c00:01, 17.12it/s]\\n 40%|\u2588\u2588\u2588\u2588 30 | \ | 20/50 [00:01\\u003c00:01, 17.15it/s]\\n 44%|\u2588\u2588\u2588\u2588\u258D 31 | \ | 22/50 [00:01\\u003c00:01, 17.16it/s]\\n 48%|\u2588\u2588\u2588\u2588\u258A 32 | \ | 24/50 [00:01\\u003c00:01, 17.17it/s]\\n 52%|\u2588\u2588\u2588\u2588\u2588\u258F 33 | \ | 26/50 [00:01\\u003c00:01, 17.20it/s]\\n 56%|\u2588\u2588\u2588\u2588\u2588\u258C 34 | \ | 28/50 [00:01\\u003c00:01, 17.21it/s]\\n 60%|\u2588\u2588\u2588\u2588\u2588\u2588 35 | \ | 30/50 [00:01\\u003c00:01, 17.19it/s]\\n 64%|\u2588\u2588\u2588\u2588\u2588\u2588\u258D 36 | \ | 32/50 [00:01\\u003c00:01, 17.18it/s]\\n 68%|\u2588\u2588\u2588\u2588\u2588\u2588\u258A 37 | \ | 34/50 [00:01\\u003c00:00, 17.18it/s]\\n 72%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258F 38 | \ | 36/50 [00:02\\u003c00:00, 17.20it/s]\\n 76%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258C 39 | \ | 38/50 [00:02\\u003c00:00, 17.21it/s]\\n 80%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 40 | \ | 40/50 [00:02\\u003c00:00, 17.19it/s]\\n 84%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258D 41 | | 42/50 [00:02\\u003c00:00, 17.19it/s]\\n 88%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258A 42 | | 44/50 [00:02\\u003c00:00, 17.19it/s]\\n 92%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258F| 43 | 46/50 [00:02\\u003c00:00, 17.20it/s]\\n 96%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258C| 44 | 48/50 [00:02\\u003c00:00, 17.22it/s]\\n100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 45 | 50/50 [00:02\\u003c00:00, 17.19it/s]\\n100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 46 | 50/50 [00:02\\u003c00:00, 17.09it/s]\\n\",\"output\":[\"https://replicate.delivery/pbxt/9inf36wjsEWuQ6XTf84iezPftv9QZdePfGySnU5tUai3BOrWE/out-0.png\"],\"error\":null,\"status\":\"succeeded\",\"created_at\":\"2023-08-16T18:57:08.360785Z\",\"started_at\":\"2023-08-16T18:57:08.366092Z\",\"completed_at\":\"2023-08-16T18:57:12.17042Z\",\"metrics\":{\"predict_time\":3.804328},\"urls\":{\"cancel\":\"https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu/cancel\",\"get\":\"https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu\"}}\n" 47 | headers: 48 | CF-Cache-Status: 49 | - DYNAMIC 50 | CF-RAY: 51 | - 7f7be8b17b47f8d1-SEA 52 | Connection: 53 | - keep-alive 54 | Content-Encoding: 55 | - gzip 56 | Content-Type: 57 | - application/json 58 | Date: 59 | - Wed, 16 Aug 2023 18:58:28 GMT 60 | NEL: 61 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 62 | Report-To: 63 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=UuflCrg7N4NAE6re4SQXd4aIWksjHl5BGKMFC3j9Rh9twIDDFOZzdAiU%2F%2FcWNM%2FofzDCdfSG628ZUoRTySoOZY04dhbVtbL4FCJ6YCsfEfkB%2B282Tfjs0VSoavvJmvBcSN%2B0"}],"group":"cf-nel","max_age":604800}' 64 | Server: 65 | - cloudflare 66 | Strict-Transport-Security: 67 | - max-age=15552000 68 | Transfer-Encoding: 69 | - chunked 70 | ratelimit-remaining: 71 | - "59999" 72 | ratelimit-reset: 73 | - "1" 74 | via: 75 | - 1.1 google 76 | http_version: HTTP/1.1 77 | status_code: 200 78 | - request: 79 | body: "" 80 | headers: 81 | accept: 82 | - "*/*" 83 | accept-encoding: 84 | - gzip, deflate 85 | connection: 86 | - keep-alive 87 | host: 88 | - api.replicate.com 89 | user-agent: 90 | - replicate-python/0.11.0 91 | method: GET 92 | uri: https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu 93 | response: 94 | content: 95 | "{\"completed_at\":\"2023-08-16T18:57:12.170420Z\",\"created_at\":\"2023-08-16T18:57:08.394251Z\",\"error\":null,\"id\":\"vgcm4plb7tgzlyznry5d5jkgvu\",\"input\":{\"seed\":42069,\"width\":512,\"height\":512,\"prompt\":\"a 96 | studio photo of a rainbow colored corgi\"},\"logs\":\"Using seed: 42069\\nPrompt: 97 | a studio photo of a rainbow colored corgi\\ntxt2img mode\\n 0%| | 98 | 0/50 [00:00"}' 6 | headers: 7 | accept: 8 | - "*/*" 9 | accept-encoding: 10 | - gzip, deflate 11 | connection: 12 | - keep-alive 13 | content-length: 14 | - "148" 15 | content-type: 16 | - application/json 17 | host: 18 | - api.replicate.com 19 | user-agent: 20 | - replicate-python/0.11.0 21 | method: POST 22 | uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b/trainings 23 | response: 24 | content: 25 | '{"detail":"The specified training destination does not exist","status":404} 26 | 27 | ' 28 | headers: 29 | CF-Cache-Status: 30 | - DYNAMIC 31 | CF-RAY: 32 | - 7f7c2190ed8c281a-SEA 33 | Connection: 34 | - keep-alive 35 | Content-Length: 36 | - "76" 37 | Content-Type: 38 | - application/problem+json 39 | Date: 40 | - Wed, 16 Aug 2023 19:37:18 GMT 41 | NEL: 42 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 43 | Report-To: 44 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=0vMWFlGDyffyF0A%2FL4%2FH830OVHnZd0gZDww4oocSSHq7eMAt327ut6v%2B2qAda7fThmH4WcElLTM%2B3PFyrsa1w1SHgfEdWyJSv8TYYi2nWXMqeP5EJc1SDjV958HGKSKDnjH5"}],"group":"cf-nel","max_age":604800}' 45 | Server: 46 | - cloudflare 47 | Strict-Transport-Security: 48 | - max-age=15552000 49 | ratelimit-remaining: 50 | - "2999" 51 | ratelimit-reset: 52 | - "1" 53 | via: 54 | - 1.1 google 55 | http_version: HTTP/1.1 56 | status_code: 404 57 | version: 1 58 | -------------------------------------------------------------------------------- /tests/cassettes/trainings-get.yaml: -------------------------------------------------------------------------------- 1 | interactions: 2 | - request: 3 | body: "" 4 | headers: 5 | accept: 6 | - "*/*" 7 | accept-encoding: 8 | - gzip, deflate 9 | connection: 10 | - keep-alive 11 | host: 12 | - api.replicate.com 13 | user-agent: 14 | - replicate-python/0.11.0 15 | method: GET 16 | uri: https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte 17 | response: 18 | content: '{"completed_at":null,"created_at":"2023-08-16T19:33:26.906823Z","error":null,"id":"medrnz3bm5dd6ultvad2tejrte","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{},"output":null,"started_at":"2023-08-16T19:33:42.114513Z","status":"processing","urls":{"get":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte","cancel":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte/cancel"},"model":"stability-ai/sdxl","version":"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b","webhook_completed":null}' 19 | headers: 20 | CF-Cache-Status: 21 | - DYNAMIC 22 | CF-RAY: 23 | - 7f7c1beaedff279c-SEA 24 | Connection: 25 | - keep-alive 26 | Content-Encoding: 27 | - gzip 28 | Content-Type: 29 | - application/json 30 | Date: 31 | - Wed, 16 Aug 2023 19:33:26 GMT 32 | NEL: 33 | - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' 34 | Report-To: 35 | - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=SntiwLHCR4wiv49Qmn%2BR1ZblcX%2FgoVlIgsek4yZliZiWts2SqPjqTjrSkB%2Bwch8oHqR%2BBNVs1cSbihlHd8MWPXsbwC2uShz0c6tD4nclaecblb3FnEp4Mccy9hlZ39izF9Tm"}],"group":"cf-nel","max_age":604800}' 36 | Server: 37 | - cloudflare 38 | Strict-Transport-Security: 39 | - max-age=15552000 40 | Transfer-Encoding: 41 | - chunked 42 | allow: 43 | - OPTIONS, GET 44 | content-security-policy-report-only: 45 | - "style-src 'report-sample' 'self' 'unsafe-inline' https://fonts.googleapis.com; 46 | img-src 'report-sample' 'self' data: https://replicate.delivery https://*.replicate.delivery 47 | https://*.githubusercontent.com https://github.com; worker-src 'none'; media-src 48 | 'report-sample' 'self' https://replicate.delivery https://*.replicate.delivery 49 | https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src 'report-sample' 50 | 'self' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com 51 | https://*.rudderstack.com https://*.mux.com https://*.sentry.io; script-src 52 | 'report-sample' 'self' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; 53 | font-src 'report-sample' 'self' data: https://fonts.replicate.ai https://fonts.gstatic.com; 54 | default-src 'self'; report-uri" 55 | cross-origin-opener-policy: 56 | - same-origin 57 | ratelimit-remaining: 58 | - "2999" 59 | ratelimit-reset: 60 | - "1" 61 | referrer-policy: 62 | - same-origin 63 | vary: 64 | - Cookie, origin 65 | via: 66 | - 1.1 vegur, 1.1 google 67 | x-content-type-options: 68 | - nosniff 69 | x-frame-options: 70 | - DENY 71 | http_version: HTTP/1.1 72 | status_code: 200 73 | version: 1 74 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from unittest import mock 4 | 5 | import pytest 6 | import pytest_asyncio 7 | 8 | 9 | @pytest_asyncio.fixture(scope="session", autouse=True) 10 | def event_loop(): 11 | event_loop_policy = asyncio.get_event_loop_policy() 12 | loop = event_loop_policy.new_event_loop() 13 | yield loop 14 | loop.close() 15 | 16 | 17 | @pytest.fixture(scope="session") 18 | def mock_replicate_api_token(scope="class"): 19 | if os.environ.get("REPLICATE_API_TOKEN", "") != "": 20 | yield 21 | else: 22 | with mock.patch.dict( 23 | os.environ, 24 | {"REPLICATE_API_TOKEN": "test-token", "REPLICATE_POLL_INTERVAL": "0.0"}, 25 | ): 26 | yield 27 | 28 | 29 | @pytest.fixture(scope="module") 30 | def vcr_config(): 31 | return {"allowed_hosts": ["api.replicate.com"], "filter_headers": ["authorization"]} 32 | 33 | 34 | @pytest.fixture(scope="module") 35 | def vcr_cassette_dir(request): 36 | module = request.node.fspath 37 | return os.path.join(module.dirname, "cassettes") 38 | -------------------------------------------------------------------------------- /tests/test_account.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | import respx 4 | 5 | from replicate.account import Account 6 | from replicate.client import Client 7 | 8 | router = respx.Router(base_url="https://api.replicate.com/v1") 9 | router.route( 10 | method="GET", 11 | path="/account", 12 | name="accounts.current", 13 | ).mock( 14 | return_value=httpx.Response( 15 | 200, 16 | json={ 17 | "type": "organization", 18 | "username": "replicate", 19 | "name": "Replicate", 20 | "github_url": "https://github.com/replicate", 21 | }, 22 | ) 23 | ) 24 | router.route(host="api.replicate.com").pass_through() 25 | 26 | 27 | @pytest.mark.asyncio 28 | @pytest.mark.parametrize("async_flag", [True, False]) 29 | async def test_account_current(async_flag): 30 | client = Client( 31 | api_token="test-token", transport=httpx.MockTransport(router.handler) 32 | ) 33 | 34 | if async_flag: 35 | account = await client.accounts.async_current() 36 | else: 37 | account = client.accounts.current() 38 | 39 | assert router["accounts.current"].called 40 | assert isinstance(account, Account) 41 | assert account.type == "organization" 42 | assert account.username == "replicate" 43 | assert account.name == "Replicate" 44 | assert account.github_url == "https://github.com/replicate" 45 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import mock 3 | 4 | import httpx 5 | import pytest 6 | import respx 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_authorization_when_setting_environ_after_import(): 11 | import replicate 12 | 13 | router = respx.Router() 14 | router.route( 15 | method="GET", 16 | url="https://api.replicate.com/", 17 | headers={"Authorization": "Bearer test-set-after-import"}, 18 | ).mock( 19 | return_value=httpx.Response( 20 | 200, 21 | json={}, 22 | ) 23 | ) 24 | 25 | token = "test-set-after-import" # noqa: S105 26 | 27 | with mock.patch.dict( 28 | os.environ, 29 | {"REPLICATE_API_TOKEN": token}, 30 | ): 31 | client = replicate.Client(transport=httpx.MockTransport(router.handler)) 32 | resp = client._request("GET", "/") 33 | assert resp.status_code == 200 34 | 35 | 36 | @pytest.mark.asyncio 37 | async def test_client_error_handling(): 38 | import replicate 39 | from replicate.exceptions import ReplicateError 40 | 41 | router = respx.Router() 42 | router.route( 43 | method="GET", 44 | url="https://api.replicate.com/", 45 | headers={"Authorization": "Bearer test-client-error"}, 46 | ).mock( 47 | return_value=httpx.Response( 48 | 400, 49 | json={"detail": "Client error occurred"}, 50 | ) 51 | ) 52 | 53 | token = "test-client-error" # noqa: S105 54 | 55 | with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): 56 | client = replicate.Client(transport=httpx.MockTransport(router.handler)) 57 | with pytest.raises(ReplicateError) as exc_info: 58 | client._request("GET", "/") 59 | assert "status: 400" in str(exc_info.value) 60 | assert "detail: Client error occurred" in str(exc_info.value) 61 | 62 | 63 | @pytest.mark.asyncio 64 | async def test_server_error_handling(): 65 | import replicate 66 | from replicate.exceptions import ReplicateError 67 | 68 | router = respx.Router() 69 | router.route( 70 | method="GET", 71 | url="https://api.replicate.com/", 72 | headers={"Authorization": "Bearer test-server-error"}, 73 | ).mock( 74 | return_value=httpx.Response( 75 | 500, 76 | json={"detail": "Server error occurred"}, 77 | ) 78 | ) 79 | 80 | token = "test-server-error" # noqa: S105 81 | 82 | with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): 83 | client = replicate.Client(transport=httpx.MockTransport(router.handler)) 84 | with pytest.raises(ReplicateError) as exc_info: 85 | client._request("GET", "/") 86 | assert "status: 500" in str(exc_info.value) 87 | assert "detail: Server error occurred" in str(exc_info.value) 88 | 89 | 90 | def test_custom_headers_are_applied(): 91 | import replicate 92 | from replicate.exceptions import ReplicateError 93 | 94 | custom_headers = {"Custom-Header": "CustomValue"} 95 | 96 | def mock_send(request: httpx.Request, **kwargs) -> httpx.Response: 97 | assert "Custom-Header" in request.headers 98 | assert request.headers["Custom-Header"] == "CustomValue" 99 | 100 | return httpx.Response(401, json={}) 101 | 102 | client = replicate.Client( 103 | api_token="dummy_token", 104 | headers=custom_headers, 105 | transport=httpx.MockTransport(mock_send), 106 | ) 107 | 108 | try: 109 | client.accounts.current() 110 | except ReplicateError: 111 | pass 112 | -------------------------------------------------------------------------------- /tests/test_collection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.vcr("collections-list.yaml") 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_collections_list(async_flag): 10 | if async_flag: 11 | page = await replicate.collections.async_list() 12 | else: 13 | page = replicate.collections.list() 14 | 15 | assert page.next is None 16 | assert page.previous is None 17 | 18 | found = False 19 | for collection in page.results: 20 | if collection.slug == "text-to-image": 21 | found = True 22 | break 23 | 24 | assert found 25 | 26 | 27 | @pytest.mark.vcr("collections-get.yaml") 28 | @pytest.mark.asyncio 29 | @pytest.mark.parametrize("async_flag", [True, False]) 30 | async def test_collections_get(async_flag): 31 | if async_flag: 32 | collection = await replicate.collections.async_get("text-to-image") 33 | else: 34 | collection = replicate.collections.get("text-to-image") 35 | 36 | assert collection.slug == "text-to-image" 37 | assert collection.name == "Text to image" 38 | assert collection.models is not None 39 | assert len(collection.models) > 0 40 | 41 | found = False 42 | for model in collection.models: 43 | if model.name == "stable-diffusion": 44 | found = True 45 | break 46 | 47 | assert found 48 | -------------------------------------------------------------------------------- /tests/test_deployment.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import httpx 4 | import pytest 5 | import respx 6 | 7 | from replicate.client import Client 8 | 9 | router = respx.Router(base_url="https://api.replicate.com/v1") 10 | 11 | router.route( 12 | method="GET", 13 | path="/deployments/replicate/my-app-image-generator", 14 | name="deployments.get", 15 | ).mock( 16 | return_value=httpx.Response( 17 | 201, 18 | json={ 19 | "owner": "replicate", 20 | "name": "my-app-image-generator", 21 | "current_release": { 22 | "number": 1, 23 | "model": "stability-ai/sdxl", 24 | "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", 25 | "created_at": "2024-02-15T16:32:57.018467Z", 26 | "created_by": { 27 | "type": "organization", 28 | "username": "acme", 29 | "name": "Acme Corp, Inc.", 30 | "github_url": "https://github.com/acme", 31 | }, 32 | "configuration": { 33 | "hardware": "gpu-t4", 34 | "min_instances": 1, 35 | "max_instances": 5, 36 | }, 37 | }, 38 | }, 39 | ) 40 | ) 41 | router.route( 42 | method="POST", 43 | path="/deployments/replicate/my-app-image-generator/predictions", 44 | name="deployments.predictions.create", 45 | ).mock( 46 | return_value=httpx.Response( 47 | 201, 48 | json={ 49 | "id": "p1", 50 | "model": "replicate/my-app-image-generator", 51 | "version": "v1", 52 | "urls": { 53 | "get": "https://api.replicate.com/v1/predictions/p1", 54 | "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", 55 | }, 56 | "created_at": "2022-04-26T20:00:40.658234Z", 57 | "source": "api", 58 | "status": "processing", 59 | "input": {"text": "world"}, 60 | "output": None, 61 | "error": None, 62 | "logs": "", 63 | }, 64 | ) 65 | ) 66 | router.route( 67 | method="GET", 68 | path="/deployments", 69 | name="deployments.list", 70 | ).mock( 71 | return_value=httpx.Response( 72 | 200, 73 | json={ 74 | "results": [ 75 | { 76 | "owner": "acme", 77 | "name": "image-upscaler", 78 | "current_release": { 79 | "number": 1, 80 | "model": "acme/esrgan", 81 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 82 | "created_at": "2022-01-01T00:00:00Z", 83 | "created_by": { 84 | "type": "organization", 85 | "username": "acme", 86 | "name": "Acme, Inc.", 87 | }, 88 | "configuration": { 89 | "hardware": "gpu-t4", 90 | "min_instances": 1, 91 | "max_instances": 5, 92 | }, 93 | }, 94 | }, 95 | { 96 | "owner": "acme", 97 | "name": "text-generator", 98 | "current_release": { 99 | "number": 2, 100 | "model": "acme/acme-llama", 101 | "version": "4b7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccbb", 102 | "created_at": "2022-02-02T00:00:00Z", 103 | "created_by": { 104 | "type": "organization", 105 | "username": "acme", 106 | "name": "Acme, Inc.", 107 | }, 108 | "configuration": { 109 | "hardware": "cpu", 110 | "min_instances": 2, 111 | "max_instances": 10, 112 | }, 113 | }, 114 | }, 115 | ] 116 | }, 117 | ) 118 | ) 119 | 120 | router.route( 121 | method="POST", 122 | path="/deployments", 123 | name="deployments.create", 124 | ).mock( 125 | return_value=httpx.Response( 126 | 201, 127 | json={ 128 | "owner": "acme", 129 | "name": "new-deployment", 130 | "current_release": { 131 | "number": 1, 132 | "model": "acme/new-model", 133 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 134 | "created_at": "2022-01-01T00:00:00Z", 135 | "created_by": { 136 | "type": "organization", 137 | "username": "acme", 138 | "name": "Acme, Inc.", 139 | }, 140 | "configuration": { 141 | "hardware": "gpu-t4", 142 | "min_instances": 1, 143 | "max_instances": 5, 144 | }, 145 | }, 146 | }, 147 | ) 148 | ) 149 | 150 | 151 | router.route( 152 | method="PATCH", 153 | path="/deployments/acme/image-upscaler", 154 | name="deployments.update", 155 | ).mock( 156 | return_value=httpx.Response( 157 | 200, 158 | json={ 159 | "owner": "acme", 160 | "name": "image-upscaler", 161 | "current_release": { 162 | "number": 2, 163 | "model": "acme/esrgan-updated", 164 | "version": "new-version-id", 165 | "created_at": "2022-02-02T00:00:00Z", 166 | "created_by": { 167 | "type": "organization", 168 | "username": "acme", 169 | "name": "Acme, Inc.", 170 | }, 171 | "configuration": { 172 | "hardware": "gpu-v100", 173 | "min_instances": 2, 174 | "max_instances": 10, 175 | }, 176 | }, 177 | }, 178 | ) 179 | ) 180 | 181 | 182 | router.route(host="api.replicate.com").pass_through() 183 | 184 | 185 | @pytest.mark.asyncio 186 | @pytest.mark.parametrize("async_flag", [True, False]) 187 | async def test_deployment_get(async_flag): 188 | client = Client( 189 | api_token="test-token", transport=httpx.MockTransport(router.handler) 190 | ) 191 | 192 | if async_flag: 193 | deployment = await client.deployments.async_get( 194 | "replicate/my-app-image-generator" 195 | ) 196 | else: 197 | deployment = client.deployments.get("replicate/my-app-image-generator") 198 | 199 | assert router["deployments.get"].called 200 | 201 | assert deployment.owner == "replicate" 202 | assert deployment.name == "my-app-image-generator" 203 | assert deployment.current_release is not None 204 | assert deployment.current_release.number == 1 205 | assert deployment.current_release.model == "stability-ai/sdxl" 206 | assert ( 207 | deployment.current_release.version 208 | == "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf" 209 | ) 210 | assert deployment.current_release is not None 211 | assert deployment.current_release.created_by is not None 212 | assert deployment.current_release.created_by.type == "organization" 213 | assert deployment.current_release.created_by.username == "acme" 214 | assert deployment.current_release.created_by.name == "Acme Corp, Inc." 215 | assert deployment.current_release.created_by.github_url == "https://github.com/acme" 216 | 217 | 218 | @pytest.mark.asyncio 219 | @pytest.mark.parametrize("async_flag", [True, False]) 220 | async def test_deployment_predictions_create(async_flag): 221 | client = Client( 222 | api_token="test-token", transport=httpx.MockTransport(router.handler) 223 | ) 224 | 225 | if async_flag: 226 | deployment = await client.deployments.async_get( 227 | "replicate/my-app-image-generator" 228 | ) 229 | 230 | prediction = await deployment.predictions.async_create( 231 | input={"text": "world"}, 232 | webhook="https://example.com/webhook", 233 | webhook_events_filter=["completed"], 234 | stream=True, 235 | ) 236 | else: 237 | deployment = client.deployments.get("replicate/my-app-image-generator") 238 | 239 | prediction = deployment.predictions.create( 240 | input={"text": "world"}, 241 | webhook="https://example.com/webhook", 242 | webhook_events_filter=["completed"], 243 | stream=True, 244 | ) 245 | 246 | assert router["deployments.predictions.create"].called 247 | request = router["deployments.predictions.create"].calls[0].request 248 | request_body = json.loads(request.content) 249 | assert request_body["input"] == {"text": "world"} 250 | assert request_body["webhook"] == "https://example.com/webhook" 251 | assert request_body["webhook_events_filter"] == ["completed"] 252 | assert request_body["stream"] is True 253 | 254 | assert prediction.id == "p1" 255 | assert prediction.input == {"text": "world"} 256 | 257 | 258 | @pytest.mark.asyncio 259 | @pytest.mark.parametrize("async_flag", [True, False]) 260 | async def test_deploymentspredictions_create(async_flag): 261 | client = Client( 262 | api_token="test-token", transport=httpx.MockTransport(router.handler) 263 | ) 264 | 265 | if async_flag: 266 | prediction = await client.deployments.predictions.async_create( 267 | deployment="replicate/my-app-image-generator", 268 | input={"text": "world"}, 269 | webhook="https://example.com/webhook", 270 | webhook_events_filter=["completed"], 271 | stream=True, 272 | ) 273 | else: 274 | prediction = await client.deployments.predictions.async_create( 275 | deployment="replicate/my-app-image-generator", 276 | input={"text": "world"}, 277 | webhook="https://example.com/webhook", 278 | webhook_events_filter=["completed"], 279 | stream=True, 280 | ) 281 | 282 | assert router["deployments.predictions.create"].called 283 | request = router["deployments.predictions.create"].calls[0].request 284 | request_body = json.loads(request.content) 285 | assert request_body["input"] == {"text": "world"} 286 | assert request_body["webhook"] == "https://example.com/webhook" 287 | assert request_body["webhook_events_filter"] == ["completed"] 288 | assert request_body["stream"] is True 289 | 290 | assert prediction.id == "p1" 291 | assert prediction.input == {"text": "world"} 292 | 293 | 294 | @respx.mock 295 | @pytest.mark.asyncio 296 | @pytest.mark.parametrize("async_flag", [True, False]) 297 | async def test_deployments_list(async_flag): 298 | client = Client( 299 | api_token="test-token", transport=httpx.MockTransport(router.handler) 300 | ) 301 | 302 | if async_flag: 303 | deployments = await client.deployments.async_list() 304 | else: 305 | deployments = client.deployments.list() 306 | 307 | assert router["deployments.list"].called 308 | 309 | assert len(deployments.results) == 2 310 | assert deployments.results[0].owner == "acme" 311 | assert deployments.results[0].name == "image-upscaler" 312 | assert deployments.results[0].current_release is not None 313 | assert deployments.results[0].current_release.number == 1 314 | assert deployments.results[0].current_release.model == "acme/esrgan" 315 | assert deployments.results[1].owner == "acme" 316 | assert deployments.results[1].name == "text-generator" 317 | assert deployments.results[1].current_release is not None 318 | assert deployments.results[1].current_release.number == 2 319 | assert deployments.results[1].current_release.model == "acme/acme-llama" 320 | 321 | 322 | @respx.mock 323 | @pytest.mark.asyncio 324 | @pytest.mark.parametrize("async_flag", [True, False]) 325 | async def test_create_deployment(async_flag): 326 | client = Client( 327 | api_token="test-token", transport=httpx.MockTransport(router.handler) 328 | ) 329 | 330 | config = { 331 | "name": "new-deployment", 332 | "model": "acme/new-model", 333 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 334 | "hardware": "gpu-t4", 335 | "min_instances": 1, 336 | "max_instances": 5, 337 | } 338 | 339 | if async_flag: 340 | deployment = await client.deployments.async_create(**config) 341 | else: 342 | deployment = client.deployments.create(**config) 343 | 344 | assert router["deployments.create"].called 345 | 346 | assert deployment.owner == "acme" 347 | assert deployment.name == "new-deployment" 348 | assert deployment.current_release is not None 349 | assert deployment.current_release.number == 1 350 | assert deployment.current_release.model == "acme/new-model" 351 | assert ( 352 | deployment.current_release.version 353 | == "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 354 | ) 355 | assert deployment.current_release.created_by is not None 356 | assert deployment.current_release.created_by.type == "organization" 357 | assert deployment.current_release.created_by.username == "acme" 358 | assert deployment.current_release.created_by.name == "Acme, Inc." 359 | assert deployment.current_release.configuration.hardware == "gpu-t4" 360 | assert deployment.current_release.configuration.min_instances == 1 361 | assert deployment.current_release.configuration.max_instances == 5 362 | 363 | 364 | @respx.mock 365 | @pytest.mark.asyncio 366 | @pytest.mark.parametrize("async_flag", [True, False]) 367 | async def test_update_deployment(async_flag): 368 | config = { 369 | "version": "new-version-id", 370 | "hardware": "gpu-v100", 371 | "min_instances": 2, 372 | "max_instances": 10, 373 | } 374 | 375 | client = Client( 376 | api_token="test-token", transport=httpx.MockTransport(router.handler) 377 | ) 378 | 379 | if async_flag: 380 | updated_deployment = await client.deployments.async_update( 381 | deployment_owner="acme", deployment_name="image-upscaler", **config 382 | ) 383 | else: 384 | updated_deployment = client.deployments.update( 385 | deployment_owner="acme", deployment_name="image-upscaler", **config 386 | ) 387 | 388 | assert router["deployments.update"].called 389 | request = router["deployments.update"].calls[0].request 390 | request_body = json.loads(request.content) 391 | assert request_body == config 392 | 393 | assert updated_deployment.owner == "acme" 394 | assert updated_deployment.name == "image-upscaler" 395 | assert updated_deployment.current_release is not None 396 | assert updated_deployment.current_release.number == 2 397 | assert updated_deployment.current_release.model == "acme/esrgan-updated" 398 | assert updated_deployment.current_release.version == "new-version-id" 399 | assert updated_deployment.current_release.configuration.hardware == "gpu-v100" 400 | assert updated_deployment.current_release.configuration.min_instances == 2 401 | assert updated_deployment.current_release.configuration.max_instances == 10 402 | -------------------------------------------------------------------------------- /tests/test_hardware.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.vcr("hardware-list.yaml") 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_hardware_list(async_flag): 10 | if async_flag: 11 | hardware = await replicate.hardware.async_list() 12 | else: 13 | hardware = replicate.hardware.list() 14 | 15 | assert hardware is not None 16 | assert isinstance(hardware, list) 17 | assert len(hardware) > 0 18 | -------------------------------------------------------------------------------- /tests/test_identifier.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from replicate.identifier import ModelVersionIdentifier 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "id, expected", 8 | [ 9 | ( 10 | "meta/llama-2-70b-chat", 11 | { 12 | "owner": "meta", 13 | "name": "llama-2-70b-chat", 14 | "version": None, 15 | "error": False, 16 | }, 17 | ), 18 | ( 19 | "mistralai/mistral-7b-instruct-v1.4", 20 | { 21 | "owner": "mistralai", 22 | "name": "mistral-7b-instruct-v1.4", 23 | "version": None, 24 | "error": False, 25 | }, 26 | ), 27 | ( 28 | "nateraw/video-llava:a494250c04691c458f57f2f8ef5785f25bc851e0c91fd349995081d4362322dd", 29 | { 30 | "owner": "nateraw", 31 | "name": "video-llava", 32 | "version": "a494250c04691c458f57f2f8ef5785f25bc851e0c91fd349995081d4362322dd", 33 | "error": False, 34 | }, 35 | ), 36 | ( 37 | "", 38 | {"error": True}, 39 | ), 40 | ( 41 | "invalid", 42 | {"error": True}, 43 | ), 44 | ( 45 | "invalid/id/format", 46 | {"error": True}, 47 | ), 48 | ], 49 | ) 50 | def test_parse_model_id(id, expected): 51 | try: 52 | result = ModelVersionIdentifier.parse(id) 53 | assert result.owner == expected["owner"] 54 | assert result.name == expected["name"] 55 | assert result.version == expected["version"] 56 | except ValueError: 57 | assert expected["error"] 58 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.vcr("models-get.yaml") 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_models_get(async_flag): 10 | if async_flag: 11 | sdxl = await replicate.models.async_get("stability-ai/sdxl") 12 | else: 13 | sdxl = replicate.models.get("stability-ai/sdxl") 14 | 15 | assert sdxl is not None 16 | assert sdxl.owner == "stability-ai" 17 | assert sdxl.name == "sdxl" 18 | assert sdxl.visibility == "public" 19 | 20 | if async_flag: 21 | empty = await replicate.models.async_get("mattt/empty") 22 | else: 23 | empty = replicate.models.get("mattt/empty") 24 | 25 | assert empty.default_example is None 26 | 27 | 28 | @pytest.mark.vcr("models-list.yaml") 29 | @pytest.mark.asyncio 30 | @pytest.mark.parametrize("async_flag", [True, False]) 31 | async def test_models_list(async_flag): 32 | if async_flag: 33 | models = await replicate.models.async_list() 34 | else: 35 | models = replicate.models.list() 36 | 37 | assert len(models) > 0 38 | assert models[0].owner is not None 39 | assert models[0].name is not None 40 | assert models[0].visibility == "public" 41 | 42 | 43 | @pytest.mark.vcr("models-list__pagination.yaml") 44 | @pytest.mark.asyncio 45 | @pytest.mark.parametrize("async_flag", [True, False]) 46 | async def test_models_list_pagination(async_flag): 47 | if async_flag: 48 | page1 = await replicate.models.async_list() 49 | else: 50 | page1 = replicate.models.list() 51 | assert len(page1) > 0 52 | assert page1.next is not None 53 | 54 | if async_flag: 55 | page2 = await replicate.models.async_list(cursor=page1.next) 56 | else: 57 | page2 = replicate.models.list(cursor=page1.next) 58 | assert len(page2) > 0 59 | assert page2.previous is not None 60 | 61 | 62 | @pytest.mark.vcr("models-create.yaml") 63 | @pytest.mark.asyncio 64 | @pytest.mark.parametrize("async_flag", [True, False]) 65 | async def test_models_create(async_flag): 66 | if async_flag: 67 | model = await replicate.models.async_create( 68 | owner="test", 69 | name="python-example", 70 | visibility="private", 71 | hardware="cpu", 72 | description="An example model", 73 | ) 74 | else: 75 | model = replicate.models.create( 76 | owner="test", 77 | name="python-example", 78 | visibility="private", 79 | hardware="cpu", 80 | description="An example model", 81 | ) 82 | 83 | assert model.owner == "test" 84 | assert model.name == "python-example" 85 | assert model.visibility == "private" 86 | 87 | 88 | @pytest.mark.vcr("models-create.yaml") 89 | @pytest.mark.asyncio 90 | @pytest.mark.parametrize("async_flag", [True, False]) 91 | async def test_models_create_with_positional_arguments(async_flag): 92 | if async_flag: 93 | model = await replicate.models.async_create( 94 | "test", 95 | "python-example", 96 | visibility="private", 97 | hardware="cpu", 98 | ) 99 | else: 100 | model = replicate.models.create( 101 | "test", 102 | "python-example", 103 | visibility="private", 104 | hardware="cpu", 105 | ) 106 | 107 | assert model.owner == "test" 108 | assert model.name == "python-example" 109 | assert model.visibility == "private" 110 | 111 | 112 | @pytest.mark.vcr("models-predictions-create.yaml") 113 | @pytest.mark.asyncio 114 | @pytest.mark.parametrize("async_flag", [True, False]) 115 | async def test_models_predictions_create(async_flag): 116 | input = { 117 | "prompt": "Please write a haiku about llamas.", 118 | } 119 | 120 | if async_flag: 121 | prediction = await replicate.models.predictions.async_create( 122 | "meta/llama-2-70b-chat", input=input 123 | ) 124 | else: 125 | prediction = replicate.models.predictions.create( 126 | "meta/llama-2-70b-chat", input=input 127 | ) 128 | 129 | assert prediction.id is not None 130 | # assert prediction.model == "meta/llama-2-70b-chat" 131 | assert prediction.model == "replicate/lifeboat-70b" # FIXME: this is temporary 132 | assert prediction.status == "starting" 133 | -------------------------------------------------------------------------------- /tests/test_pagination.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | 5 | 6 | @pytest.mark.asyncio 7 | async def test_paginate_with_none_cursor(mock_replicate_api_token): 8 | with pytest.raises(ValueError): 9 | replicate.models.list(None) 10 | 11 | 12 | @pytest.mark.vcr("collections-list.yaml") 13 | @pytest.mark.asyncio 14 | @pytest.mark.parametrize("async_flag", [True, False]) 15 | async def test_paginate(async_flag): 16 | found = False 17 | 18 | if async_flag: 19 | async for page in replicate.async_paginate(replicate.collections.async_list): 20 | assert page.next is None 21 | assert page.previous is None 22 | 23 | for collection in page: 24 | if collection.slug == "text-to-image": 25 | found = True 26 | break 27 | 28 | else: 29 | for page in replicate.paginate(replicate.collections.list): 30 | assert page.next is None 31 | assert page.previous is None 32 | 33 | for collection in page: 34 | if collection.slug == "text-to-image": 35 | found = True 36 | break 37 | 38 | assert found 39 | -------------------------------------------------------------------------------- /tests/test_run.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import sys 3 | 4 | import httpx 5 | import pytest 6 | import respx 7 | 8 | import replicate 9 | from replicate.client import Client 10 | from replicate.exceptions import ReplicateError 11 | 12 | 13 | @pytest.mark.vcr("run.yaml") 14 | @pytest.mark.asyncio 15 | @pytest.mark.parametrize("async_flag", [True, False]) 16 | async def test_run(async_flag, record_mode): 17 | if record_mode == "none": 18 | replicate.default_client.poll_interval = 0.001 19 | 20 | version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" 21 | 22 | input = { 23 | "prompt": "a studio photo of a rainbow colored corgi", 24 | "width": 512, 25 | "height": 512, 26 | "seed": 42069, 27 | } 28 | 29 | if async_flag: 30 | output = await replicate.async_run( 31 | f"stability-ai/sdxl:{version}", 32 | input=input, 33 | ) 34 | else: 35 | output = replicate.run( 36 | f"stability-ai/sdxl:{version}", 37 | input=input, 38 | ) 39 | 40 | assert output is not None 41 | assert isinstance(output, list) 42 | assert len(output) > 0 43 | assert output[0].startswith("https://") 44 | 45 | 46 | @pytest.mark.vcr("run__concurrently.yaml") 47 | @pytest.mark.asyncio 48 | @pytest.mark.skipif( 49 | sys.version_info < (3, 11), reason="asyncio.TaskGroup requires Python 3.11" 50 | ) 51 | async def test_run_concurrently(mock_replicate_api_token, record_mode): 52 | client = replicate.Client() 53 | if record_mode == "none": 54 | client.poll_interval = 0.001 55 | 56 | version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" 57 | 58 | prompts = [ 59 | f"A chariot pulled by a team of {count} rainbow unicorns" 60 | for count in ["two", "four", "six", "eight"] 61 | ] 62 | 63 | async with asyncio.TaskGroup() as tg: 64 | tasks = [ 65 | tg.create_task( 66 | client.async_run( 67 | f"stability-ai/sdxl:{version}", input={"prompt": prompt} 68 | ) 69 | ) 70 | for prompt in prompts 71 | ] 72 | 73 | results = await asyncio.gather(*tasks) 74 | assert len(results) == len(prompts) 75 | assert all(isinstance(result, list) for result in results) 76 | assert all(len(result) > 0 for result in results) 77 | 78 | 79 | @pytest.mark.vcr("run.yaml") 80 | @pytest.mark.asyncio 81 | async def test_run_with_invalid_identifier(mock_replicate_api_token): 82 | with pytest.raises(ValueError): 83 | replicate.run("invalid") 84 | 85 | 86 | @pytest.mark.vcr("run__invalid-token.yaml") 87 | @pytest.mark.asyncio 88 | async def test_run_with_invalid_token(): 89 | with pytest.raises(ReplicateError) as excinfo: 90 | client = replicate.Client(api_token="invalid") 91 | 92 | version = "73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9" 93 | client.run( 94 | f"meta/llama-2-7b:{version}", 95 | ) 96 | 97 | assert "You did not pass a valid authentication token" in str(excinfo.value) 98 | 99 | 100 | @pytest.mark.asyncio 101 | async def test_run_version_with_invalid_cog_version(mock_replicate_api_token): 102 | def prediction_with_status(status: str) -> dict: 103 | return { 104 | "id": "p1", 105 | "model": "test/example", 106 | "version": "v1", 107 | "urls": { 108 | "get": "https://api.replicate.com/v1/predictions/p1", 109 | "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", 110 | }, 111 | "created_at": "2023-10-05T12:00:00.000000Z", 112 | "source": "api", 113 | "status": status, 114 | "input": {"text": "world"}, 115 | "output": "Hello, world!" if status == "succeeded" else None, 116 | "error": None, 117 | "logs": "", 118 | } 119 | 120 | router = respx.Router(base_url="https://api.replicate.com/v1") 121 | router.route(method="POST", path="/predictions").mock( 122 | return_value=httpx.Response( 123 | 201, 124 | json=prediction_with_status("processing"), 125 | ) 126 | ) 127 | router.route(method="GET", path="/predictions/p1").mock( 128 | return_value=httpx.Response( 129 | 200, 130 | json=prediction_with_status("succeeded"), 131 | ) 132 | ) 133 | router.route( 134 | method="GET", 135 | path="/models/test/example/versions/invalid", 136 | ).mock( 137 | return_value=httpx.Response( 138 | 201, 139 | json={ 140 | "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", 141 | "created_at": "2022-03-16T00:35:56.210272Z", 142 | "cog_version": "dev", 143 | "openapi_schema": { 144 | "openapi": "3.0.2", 145 | "info": {"title": "Cog", "version": "0.1.0"}, 146 | "paths": {}, 147 | "components": { 148 | "schemas": { 149 | "Input": { 150 | "type": "object", 151 | "title": "Input", 152 | "required": ["text"], 153 | "properties": { 154 | "text": { 155 | "type": "string", 156 | "title": "Text", 157 | "x-order": 0, 158 | "description": "The text input", 159 | }, 160 | }, 161 | }, 162 | "Output": { 163 | "type": "string", 164 | "title": "Output", 165 | }, 166 | } 167 | }, 168 | }, 169 | }, 170 | ) 171 | ) 172 | router.route(host="api.replicate.com").pass_through() 173 | 174 | client = Client( 175 | api_token="test-token", transport=httpx.MockTransport(router.handler) 176 | ) 177 | client.poll_interval = 0.001 178 | 179 | output = client.run( 180 | "test/example:invalid", 181 | input={ 182 | "text": "Hello, world!", 183 | }, 184 | ) 185 | 186 | assert output == "Hello, world!" 187 | -------------------------------------------------------------------------------- /tests/test_stream.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | from replicate.stream import ServerSentEvent 5 | 6 | 7 | @pytest.mark.asyncio 8 | @pytest.mark.parametrize("async_flag", [True, False]) 9 | async def test_stream(async_flag, record_mode): 10 | model = "replicate/canary:30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" 11 | input = { 12 | "text": "Hello", 13 | } 14 | 15 | events = [] 16 | 17 | if async_flag: 18 | async for event in await replicate.async_stream( 19 | model, 20 | input=input, 21 | ): 22 | events.append(event) 23 | else: 24 | for event in replicate.stream( 25 | model, 26 | input=input, 27 | ): 28 | events.append(event) 29 | 30 | assert len(events) > 0 31 | assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events) 32 | assert any(event.event == ServerSentEvent.EventType.DONE for event in events) 33 | 34 | 35 | @pytest.mark.asyncio 36 | @pytest.mark.parametrize("async_flag", [True, False]) 37 | async def test_stream_prediction(async_flag, record_mode): 38 | version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" 39 | input = { 40 | "text": "Hello", 41 | } 42 | 43 | events = [] 44 | 45 | if async_flag: 46 | async for event in replicate.predictions.create( 47 | version=version, input=input, stream=True 48 | ).async_stream(): 49 | events.append(event) 50 | else: 51 | for event in replicate.predictions.create( 52 | version=version, input=input, stream=True 53 | ).stream(): 54 | events.append(event) 55 | 56 | assert len(events) > 0 57 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import replicate 4 | from replicate.exceptions import ReplicateException 5 | 6 | input_images_url = "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip" 7 | 8 | 9 | @pytest.mark.vcr("trainings-create.yaml") 10 | @pytest.mark.asyncio 11 | @pytest.mark.parametrize("async_flag", [True, False]) 12 | async def test_trainings_create(async_flag, mock_replicate_api_token): 13 | if async_flag: 14 | training = await replicate.trainings.async_create( 15 | model="stability-ai/sdxl", 16 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 17 | input={ 18 | "input_images": input_images_url, 19 | "use_face_detection_instead": True, 20 | }, 21 | destination="replicate/dreambooth-sdxl", 22 | ) 23 | else: 24 | training = replicate.trainings.create( 25 | model="stability-ai/sdxl", 26 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 27 | input={ 28 | "input_images": input_images_url, 29 | "use_face_detection_instead": True, 30 | }, 31 | destination="replicate/dreambooth-sdxl", 32 | ) 33 | 34 | assert training.id is not None 35 | assert training.status == "starting" 36 | 37 | 38 | @pytest.mark.vcr("trainings-create.yaml") 39 | @pytest.mark.asyncio 40 | @pytest.mark.parametrize("async_flag", [True, False]) 41 | async def test_trainings_create_with_named_version_argument( 42 | async_flag, mock_replicate_api_token 43 | ): 44 | if async_flag: 45 | # The overload with a model version identifier is soft-deprecated 46 | # and not supported in the async version. 47 | return 48 | else: 49 | training = replicate.trainings.create( 50 | version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 51 | input={ 52 | "input_images": input_images_url, 53 | "use_face_detection_instead": True, 54 | }, 55 | destination="replicate/dreambooth-sdxl", 56 | ) 57 | 58 | assert training.id is not None 59 | assert training.status == "starting" 60 | 61 | 62 | @pytest.mark.vcr("trainings-create.yaml") 63 | @pytest.mark.asyncio 64 | @pytest.mark.parametrize("async_flag", [True, False]) 65 | async def test_trainings_create_with_positional_argument( 66 | async_flag, mock_replicate_api_token 67 | ): 68 | if async_flag: 69 | # The overload with positional arguments is soft-deprecated 70 | # and not supported in the async version. 71 | return 72 | else: 73 | training = replicate.trainings.create( 74 | "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 75 | { 76 | "input_images": input_images_url, 77 | "use_face_detection_instead": True, 78 | }, 79 | "replicate/dreambooth-sdxl", 80 | ) 81 | 82 | assert training.id is not None 83 | assert training.status == "starting" 84 | 85 | 86 | @pytest.mark.vcr("trainings-create__invalid-destination.yaml") 87 | @pytest.mark.asyncio 88 | @pytest.mark.parametrize("async_flag", [True, False]) 89 | async def test_trainings_create_with_invalid_destination( 90 | async_flag, mock_replicate_api_token 91 | ): 92 | with pytest.raises(ReplicateException): 93 | if async_flag: 94 | await replicate.trainings.async_create( 95 | model="stability-ai/sdxl", 96 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 97 | input={ 98 | "input_images": input_images_url, 99 | "use_face_detection_instead": True, 100 | }, 101 | destination="", 102 | ) 103 | else: 104 | replicate.trainings.create( 105 | model="stability-ai/sdxl", 106 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 107 | input={ 108 | "input_images": input_images_url, 109 | }, 110 | destination="", 111 | ) 112 | 113 | 114 | @pytest.mark.vcr("trainings-get.yaml") 115 | @pytest.mark.asyncio 116 | @pytest.mark.parametrize("async_flag", [True, False]) 117 | async def test_trainings_get(async_flag, mock_replicate_api_token): 118 | id = "medrnz3bm5dd6ultvad2tejrte" 119 | 120 | if async_flag: 121 | training = await replicate.trainings.async_get(id) 122 | else: 123 | training = replicate.trainings.get(id) 124 | 125 | assert training.id == id 126 | assert training.status == "processing" 127 | 128 | 129 | @pytest.mark.vcr("trainings-cancel.yaml") 130 | @pytest.mark.asyncio 131 | @pytest.mark.parametrize("async_flag", [True, False]) 132 | async def test_trainings_cancel(async_flag, mock_replicate_api_token): 133 | input = { 134 | "input_images": input_images_url, 135 | "use_face_detection_instead": True, 136 | } 137 | 138 | destination = "replicate/dreambooth-sdxl" 139 | 140 | if async_flag: 141 | training = await replicate.trainings.async_create( 142 | model="stability-ai/sdxl", 143 | version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 144 | input=input, 145 | destination=destination, 146 | ) 147 | 148 | assert training.status == "starting" 149 | 150 | training = replicate.trainings.cancel(training.id) 151 | assert training.status == "canceled" 152 | else: 153 | training = replicate.trainings.create( 154 | version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 155 | destination=destination, 156 | input=input, 157 | ) 158 | 159 | assert training.status == "starting" 160 | 161 | training = replicate.trainings.cancel(training.id) 162 | assert training.status == "canceled" 163 | 164 | 165 | @pytest.mark.vcr("trainings-cancel.yaml") 166 | @pytest.mark.asyncio 167 | @pytest.mark.parametrize("async_flag", [True, False]) 168 | async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_token): 169 | input = { 170 | "input_images": input_images_url, 171 | "use_face_detection_instead": True, 172 | } 173 | 174 | destination = "replicate/dreambooth-sdxl" 175 | 176 | if async_flag: 177 | # The cancel instance method is soft-deprecated, 178 | # and not supported in the async version. 179 | return 180 | else: 181 | training = replicate.trainings.create( 182 | version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 183 | destination=destination, 184 | input=input, 185 | ) 186 | 187 | assert training.status == "starting" 188 | 189 | training.cancel() 190 | assert training.status == "canceled" 191 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | import respx 4 | 5 | from replicate.client import Client 6 | 7 | router = respx.Router(base_url="https://api.replicate.com/v1") 8 | 9 | router.route( 10 | method="GET", 11 | path="/models/replicate/hello-world", 12 | name="models.get", 13 | ).mock( 14 | return_value=httpx.Response( 15 | 200, 16 | json={ 17 | "owner": "replicate", 18 | "name": "hello-world", 19 | "description": "A tiny model that says hello", 20 | "visibility": "public", 21 | "run_count": 1e10, 22 | "url": "https://replicate.com/replicate/hello-world", 23 | "created_at": "2022-04-26T19:13:45.911328Z", 24 | "latest_version": { 25 | "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 26 | "cog_version": "0.3.0", 27 | "openapi_schema": { 28 | "openapi": "3.0.2", 29 | "info": {"title": "Cog", "version": "0.1.0"}, 30 | "components": { 31 | "schemas": { 32 | "Input": { 33 | "type": "object", 34 | "title": "Input", 35 | "required": ["text"], 36 | "properties": { 37 | "text": { 38 | "type": "string", 39 | "title": "Text", 40 | "x-order": 0, 41 | "description": "Text to prefix with 'hello '", 42 | } 43 | }, 44 | }, 45 | "Output": {"type": "string", "title": "Output"}, 46 | } 47 | }, 48 | }, 49 | "created_at": "2022-04-26T19:29:04.418669Z", 50 | }, 51 | }, 52 | ) 53 | ) 54 | 55 | router.route( 56 | method="DELETE", 57 | path__regex=r"^/models/replicate/hello-world/versions/(?P\w+)/?", 58 | name="models.versions.delete", 59 | ).mock( 60 | return_value=httpx.Response( 61 | 202, 62 | ) 63 | ) 64 | 65 | 66 | @pytest.mark.asyncio 67 | @pytest.mark.parametrize("async_flag", [True, False]) 68 | async def test_version_delete(async_flag): 69 | client = Client( 70 | api_token="test-token", transport=httpx.MockTransport(router.handler) 71 | ) 72 | 73 | if async_flag: 74 | model = await client.models.async_get("replicate/hello-world") 75 | assert model is not None 76 | assert model.latest_version is not None 77 | 78 | await model.versions.async_delete(model.latest_version.id) 79 | else: 80 | model = client.models.get("replicate/hello-world") 81 | assert model is not None 82 | assert model.latest_version is not None 83 | 84 | model.versions.delete(model.latest_version.id) 85 | 86 | assert router["models.get"].called 87 | assert router["models.versions.delete"].called 88 | --------------------------------------------------------------------------------