├── .github └── workflows │ ├── ci.yml │ └── codeql.yml ├── .gitignore ├── .spi.yml ├── LICENSE.md ├── Package.swift ├── README.md ├── Sources └── Replicate │ ├── Account.swift │ ├── Client.swift │ ├── Deployment.swift │ ├── Error.swift │ ├── Extensions │ └── Data+uriEncoded.swift │ ├── Hardware.swift │ ├── Identifier.swift │ ├── Model.swift │ ├── Predictable.swift │ ├── Prediction.swift │ ├── Status.swift │ ├── Training.swift │ ├── Value.swift │ └── Webhook.swift └── Tests └── ReplicateTests ├── ClientTests.swift ├── DateDecodingTests.swift ├── Helpers └── MockURLProtocol.swift ├── PredictionTests.swift ├── RetryPolicyTests.swift └── URIEncodingTests.swift /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | macos: 11 | runs-on: macos-13-xlarge 12 | 13 | strategy: 14 | matrix: 15 | xcode: 16 | - "14.3.1" # Swift 5.8.1 17 | 18 | name: "macOS (Xcode ${{ matrix.xcode }})" 19 | 20 | env: 21 | DEVELOPER_DIR: /Applications/Xcode_${{ matrix.xcode }}.app/Contents/Developer 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - uses: actions/cache@v3 26 | with: 27 | path: .build 28 | key: ${{ runner.os }}-spm-xcode-${{ matrix.xcode }}-${{ hashFiles('**/Package.resolved') }} 29 | restore-keys: | 30 | ${{ runner.os }}-spm-xcode-${{ matrix.xcode }}- 31 | - name: Build 32 | run: swift build -v 33 | - name: Run tests 34 | run: swift test -v 35 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | schedule: 9 | - cron: "42 4 * * 0" 10 | 11 | jobs: 12 | analyze: 13 | name: Analyze 14 | runs-on: "macos-latest" 15 | timeout-minutes: 120 16 | permissions: 17 | 
actions: read 18 | contents: read 19 | security-events: write 20 | 21 | steps: 22 | - name: Checkout repository 23 | uses: actions/checkout@v3 24 | 25 | # Initializes the CodeQL tools for scanning. 26 | - name: Initialize CodeQL 27 | uses: github/codeql-action/init@v2 28 | with: 29 | languages: swift 30 | # If you wish to specify custom queries, you can do so here or in a config file. 31 | # By default, queries listed here will override any specified in a config file. 32 | # Prefix the list here with "+" to use these queries and those in the config file. 33 | 34 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 35 | # queries: security-extended,security-and-quality 36 | 37 | - uses: actions/cache@v3 38 | with: 39 | path: .build 40 | key: ${{ runner.os }}-spm-${{ hashFiles('**/Package.resolved') }} 41 | restore-keys: | 42 | ${{ runner.os }}-spm- 43 | 44 | - name: Build 45 | run: swift build -v 46 | 47 | - name: Perform CodeQL Analysis 48 | uses: github/codeql-action/analyze@v2 49 | with: 50 | category: "/language:swift" 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /.build 3 | /Packages 4 | /*.xcodeproj 5 | xcuserdata/ 6 | DerivedData/ 7 | .swiftpm/config/registries.json 8 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata 9 | .netrc 10 | /*.playground 11 | -------------------------------------------------------------------------------- /.spi.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | builder: 3 | configs: 4 | - documentation_targets: [Replicate] 5 | metadata: 6 | - authors: Replicate 7 | -------------------------------------------------------------------------------- 
/LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2023, Replicate, Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 5.7 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 3 | 4 | import PackageDescription 5 | // Declares the `Replicate` library product, its source target, and its test target. 6 | let package = Package( 7 | name: "Replicate", 8 | platforms: [ // Minimum OS versions required by the package. 9 | .macOS(.v12), 10 | .iOS(.v15) 11 | ], 12 | products: [ 13 | // Products define the executables and libraries a package produces, and make them visible to other packages. 14 | .library( 15 | name: "Replicate", 16 | targets: ["Replicate"]), 17 | ], 18 | dependencies: [], 19 | targets: [ 20 | // Targets are the basic building blocks of a package. A target can define a module or a test suite. 21 | // Targets can depend on other targets in this package, and on products in packages this package depends on.
22 | .target( 23 | name: "Replicate" 24 | ), 25 | .testTarget( // Test suite; sources live under Tests/ReplicateTests. 26 | name: "ReplicateTests", 27 | dependencies: [ 28 | "Replicate" 29 | ]), 30 | ] 31 | ) 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Replicate Swift client 2 | 3 | [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Freplicate%2Freplicate-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/replicate/replicate-swift) 4 | [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Freplicate%2Freplicate-swift%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/replicate/replicate-swift) 5 | 6 | This is a Swift client for [Replicate]. 7 | It lets you run models from your Swift code, 8 | and do various other things on Replicate. 9 | 10 | To learn how to use it, 11 | [take a look at our guide to building a SwiftUI app with Replicate](https://replicate.com/docs/get-started/swiftui). 12 | 13 | ## Usage 14 | 15 | Grab your API token from [replicate.com/account](https://replicate.com/account) 16 | and pass it to `Client(token:)`: 17 | 18 | ```swift 19 | import Foundation 20 | import Replicate 21 | 22 | let replicate = Replicate.Client(token: <#token#>) 23 | ``` 24 | 25 | > [!WARNING] 26 | > Don't store secrets in code or any other resources bundled with your app. 27 | > Instead, fetch them from CloudKit or another server and store them in the keychain.
28 | 29 | You can run a model and get its output: 30 | 31 | ```swift 32 | let output = try await replicate.run( 33 | "stability-ai/stable-diffusion-3", 34 | input: ["prompt": "a 19th century portrait of a gentleman otter"] 35 | ) { prediction in 36 | // Print the prediction status after each update 37 | print(prediction.status) 38 | } 39 | 40 | print(output) 41 | // ["https://replicate.delivery/yhqm/bh9SsjWXY3pGKJyQzYjQlsZPzcNZ4EYOeEsPjFytc5TjYeNTA/R8_SD3_00001_.webp"] 42 | ``` 43 | 44 | Or fetch a model by name and create a prediction against its latest version: 45 | 46 | ```swift 47 | let model = try await replicate.getModel("stability-ai/stable-diffusion-3") 48 | if let latestVersion = model.latestVersion { 49 | let prompt = "a 19th century portrait of a gentleman otter" 50 | var prediction = try await replicate.createPrediction(version: latestVersion.id, 51 | input: ["prompt": "\(prompt)"]) 52 | try await prediction.wait(with: replicate) 53 | print(prediction.id) 54 | // "s654jhww3hrm60ch11v8t3zpkg" 55 | print(prediction.output) 56 | // ["https://replicate.delivery/yhqm/bh9SsjWXY3pGKJyQzYjQlsZPzcNZ4EYOeEsPjFytc5TjYeNTA/R8_SD3_00001_.webp"] 57 | } 58 | ``` 59 | 60 | Some models, 61 | like [tencentarc/gfpgan](https://replicate.com/tencentarc/gfpgan), 62 | receive images as inputs. 63 | To run a model that takes a file input, you can pass either 64 | a URL to a publicly accessible file on the Internet 65 | or use the `uriEncoded(mimeType:)` helper method to create 66 | a base64-encoded data URL from the contents of a local file. 67 | 68 | ```swift 69 | let model = try await replicate.getModel("tencentarc/gfpgan") 70 | if let latestVersion = model.latestVersion { 71 | let data = try
Data(contentsOf: URL(fileURLWithPath: "/path/to/image.jpg")) 72 | let mimeType = "image/jpeg" 73 | let prediction = try await replicate.createPrediction(version: latestVersion.id, 74 | input: ["img": "\(data.uriEncoded(mimeType: mimeType))"]) 75 | print(prediction.output) 76 | // ["https://replicate.delivery/mgxm/85f53415-0dc7-4703-891f-1e6f912119ad/output.png"] 77 | } 78 | ``` 79 | 80 | You can start a model and run it in the background: 81 | 82 | ```swift 83 | let model = try await replicate.getModel("kvfrans/clipdraw") 84 | 85 | let prompt = "watercolor painting of an underwater submarine" 86 | var prediction = try await replicate.createPrediction(version: model.latestVersion!.id, 87 | input: ["prompt": "\(prompt)"]) 88 | print(prediction.status) 89 | // "starting" 90 | 91 | try await prediction.wait(with: replicate) 92 | print(prediction.status) 93 | // "succeeded" 94 | ``` 95 | 96 | You can cancel a running prediction: 97 | 98 | ```swift 99 | let model = try await replicate.getModel("kvfrans/clipdraw") 100 | 101 | let prompt = "watercolor painting of an underwater submarine" 102 | var prediction = try await replicate.createPrediction(version: model.latestVersion!.id, 103 | input: ["prompt": "\(prompt)"]) 104 | print(prediction.status) 105 | // "starting" 106 | 107 | try await prediction.cancel(with: replicate) 108 | print(prediction.status) 109 | // "canceled" 110 | ``` 111 | 112 | You can list all the predictions you've run: 113 | 114 | ```swift 115 | var predictions: [Prediction] = [] 116 | var cursor: Replicate.Client.Pagination.Cursor?
117 | let limit = 100 118 | 119 | repeat { 120 | let page = try await replicate.getPredictions(cursor: cursor) 121 | predictions.append(contentsOf: page.results) 122 | cursor = page.next 123 | } while predictions.count < limit && cursor != nil 124 | ``` 125 | 126 | ## Adding `Replicate` as a Dependency 127 | 128 | To use the `Replicate` library in a Swift project, 129 | add it to the dependencies for your package and your target: 130 | 131 | ```swift 132 | let package = Package( 133 | // name, platforms, products, etc. 134 | dependencies: [ 135 | // other dependencies 136 | .package(url: "https://github.com/replicate/replicate-swift", from: "0.24.0"), 137 | ], 138 | targets: [ 139 | .target(name: "", dependencies: [ 140 | // other dependencies 141 | .product(name: "Replicate", package: "replicate-swift"), 142 | ]), 143 | // other targets 144 | ] 145 | ) 146 | ``` 147 | 148 | [Replicate]: https://replicate.com 149 | -------------------------------------------------------------------------------- /Sources/Replicate/Account.swift: -------------------------------------------------------------------------------- 1 | import struct Foundation.URL 2 | 3 | /// A Replicate account. 4 | public struct Account: Hashable { 5 | /// The account type. 6 | public enum AccountType: String, CaseIterable, Hashable, Codable { 7 | /// A user. 8 | case user 9 | 10 | /// An organization. 11 | case organization 12 | } 13 | 14 | /// The type of account. 15 | public var type: AccountType 16 | 17 | /// The username of the account. 18 | public var username: String 19 | 20 | /// The name of the account. 21 | public var name: String 22 | 23 | /// The GitHub URL of the account, if any. 24 | public var githubURL: URL?
25 | } 26 | 27 | // MARK: - Identifiable 28 | 29 | extension Account: Identifiable { 30 | public typealias ID = String 31 | /// The account's username, which serves as its identifier. 32 | public var id: ID { 33 | return self.username 34 | } 35 | } 36 | 37 | // MARK: - CustomStringConvertible 38 | 39 | extension Account: CustomStringConvertible { 40 | public var description: String { 41 | return self.username 42 | } 43 | } 44 | 45 | // MARK: - Codable 46 | 47 | extension Account: Codable { 48 | public enum CodingKeys: String, CodingKey { 49 | case type 50 | case username 51 | case name 52 | case githubURL = "github_url" 53 | } 54 | } 55 | 56 | extension Account.AccountType: CustomStringConvertible { 57 | public var description: String { 58 | return self.rawValue 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /Sources/Replicate/Client.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | #if canImport(FoundationNetworking) 4 | import FoundationNetworking 5 | #endif 6 | 7 | /// A Replicate HTTP API client. 8 | /// 9 | /// See https://replicate.com/docs/reference/http 10 | public class Client { 11 | /// The base URL for requests made by the client. 12 | public let baseURLString: String 13 | 14 | /// The value for the `User-Agent` header sent in requests, if any. 15 | public let userAgent: String? 16 | 17 | /// The API token used in the `Authorization` header sent in requests. 18 | private let token: String 19 | 20 | /// The underlying client session. 21 | internal(set) public var session: URLSession 22 | 23 | /// The retry policy for requests made by the client. 24 | public var retryPolicy: RetryPolicy = .default 25 | 26 | /// Creates a client with the specified API token. 27 | /// 28 | /// You can get a Replicate API token on your 29 | /// [account page](https://replicate.com/account). 30 | /// 31 | /// - Parameters: 32 | /// - session: The underlying client session.
Defaults to `URLSession(configuration: .default)`. 33 | /// - baseURLString: The base URL for requests made by the client. Defaults to "https://api.replicate.com/v1/". 34 | /// - userAgent: The value for the `User-Agent` header sent in requests, if any. Defaults to `nil`. 35 | /// - token: The API token. 36 | public init( 37 | session: URLSession = URLSession(configuration: .default), 38 | baseURLString: String = "https://api.replicate.com/v1/", 39 | userAgent: String? = nil, 40 | token: String 41 | ) 42 | { 43 | var baseURLString = baseURLString 44 | if !baseURLString.hasSuffix("/") { 45 | baseURLString = baseURLString.appending("/") 46 | } 47 | 48 | self.baseURLString = baseURLString 49 | self.userAgent = userAgent 50 | self.token = token 51 | self.session = session 52 | } 53 | 54 | // MARK: - 55 | 56 | /// Runs a model and waits for its output. 57 | /// 58 | /// - Parameters: 59 | /// - identifier: 60 | /// The model version identifier in the format 61 | /// `{owner}/{name}` or `{owner}/{name}:{version}`. 62 | /// - input: 63 | /// The input depends on what model you are running. 64 | /// To see the available inputs, 65 | /// click the "Run with API" tab on the model you are running. 66 | /// For example, 67 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3) 68 | /// takes `prompt` as an input. 69 | /// - webhook: 70 | /// A webhook that is called when the prediction has completed. 71 | /// It will be a `POST` request where 72 | /// the request body is the same as 73 | /// the response body of the get prediction endpoint. 74 | /// If there are network problems, 75 | /// we will retry the webhook a few times, 76 | /// so make sure it can be safely called more than once. 77 | /// - type: 78 | /// The expected output type. Defaults to `Value.self`. 79 | /// - updateHandler: 80 | /// A closure that executes with the updated prediction 81 | /// after each polling request to the API. 
82 | /// If the prediction is in a terminal state 83 | /// (e.g. `succeeded`, `failed`, or `canceled`), 84 | /// it's returned immediately and the closure is not executed. 85 | /// Use this to provide feedback to the user 86 | /// about the progress of the prediction, 87 | /// or throw `CancellationError` to stop waiting 88 | /// for the prediction to finish. 89 | /// - Returns: 90 | /// The output of the model, 91 | /// or `nil` if the output couldn't be decoded. 92 | /// - Throws: 93 | /// Any error thrown from the `updateHandler` closure 94 | /// other than ``CancellationError``. 95 | public func run( 96 | _ identifier: Identifier, 97 | input: Input, 98 | webhook: Webhook? = nil, 99 | _ type: Output.Type = Value.self, 100 | updateHandler: @escaping (Prediction) throws -> Void = { _ in () } 101 | ) async throws -> Output? { 102 | var prediction: Prediction 103 | if let version = identifier.version { 104 | prediction = try await createPrediction(Prediction.self, 105 | version: version, 106 | input: input, 107 | webhook: webhook) 108 | } else { 109 | prediction = try await createPrediction(Prediction.self, 110 | model: "\(identifier.owner)/\(identifier.name)", 111 | input: input, 112 | webhook: webhook) 113 | } 114 | 115 | try await prediction.wait(with: self, updateHandler: updateHandler) 116 | 117 | if prediction.status == .failed { 118 | throw prediction.error ?? Error(detail: "Prediction failed") 119 | } 120 | 121 | return prediction.output 122 | } 123 | 124 | // MARK: - 125 | 126 | /// Create a prediction 127 | /// 128 | /// - Parameters: 129 | /// - type: The type of the prediction. Defaults to `AnyPrediction.self`. 130 | /// - id: The ID of the model version that you want to run. 131 | /// You can get your model's versions using the API, 132 | /// or find them on the website by clicking 133 | /// the "Versions" tab on the Replicate model page, 134 | /// e.g. 
replicate.com/replicate/hello-world/versions, 135 | /// then copying the full SHA256 hash from the URL. 136 | /// 137 | /// The version ID is the same as the Docker image ID 138 | /// that's created when you build your model. 139 | /// - input: 140 | /// The input depends on what model you are running. 141 | /// To see the available inputs, 142 | /// click the "Run with API" tab on the model you are running. 143 | /// For example, 144 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3) 145 | /// takes `prompt` as an input. 146 | /// - webhook: 147 | /// A webhook that is called when the prediction has completed. 148 | /// It will be a `POST` request where 149 | /// the request body is the same as 150 | /// the response body of the get prediction endpoint. 151 | /// If there are network problems, 152 | /// we will retry the webhook a few times, 153 | /// so make sure it can be safely called more than once. 154 | /// - wait: 155 | /// If set to `true`, 156 | /// this method refreshes the prediction until it completes 157 | /// (``Prediction/status`` is `.succeeded` or `.failed`). 158 | /// By default, this is `false`, 159 | /// and this method returns the prediction object encoded 160 | /// in the original creation response 161 | /// (``Prediction/status`` is `.starting`). 162 | /// - Returns: The created prediction. 163 | @available(*, deprecated, message: "wait parameter is deprecated; use ``Prediction/wait(with:)`` or ``Client/run(_:input:webhook:_:)``") 164 | public func createPrediction( 165 | _ type: Prediction.Type = AnyPrediction.self, 166 | version id: Model.Version.ID, 167 | input: Input, 168 | webhook: Webhook? 
= nil, 169 | wait: Bool 170 | ) async throws -> Prediction { 171 | var params: [String: Value] = [ 172 | "version": "\(id)", 173 | "input": try Value(input) 174 | ] 175 | 176 | if let webhook { 177 | params["webhook"] = "\(webhook.url.absoluteString)" 178 | params["webhook_events_filter"] = .array(webhook.events.map { "\($0.rawValue)" }) 179 | } 180 | 181 | var prediction: Prediction = try await fetch(.post, "predictions", params: params) 182 | if wait { 183 | try await prediction.wait(with: self) 184 | return prediction 185 | } else { 186 | return prediction 187 | } 188 | } 189 | 190 | /// Create a prediction from a model version 191 | /// 192 | /// - Parameters: 193 | /// - type: 194 | /// The type of the prediction. Defaults to `AnyPrediction.self`. 195 | /// - version: 196 | /// The ID of the model version that you want to run. 197 | /// You can get your model's versions using the API, 198 | /// or find them on the website by clicking 199 | /// the "Versions" tab on the Replicate model page, 200 | /// e.g. replicate.com/replicate/hello-world/versions, 201 | /// then copying the full SHA256 hash from the URL. 202 | /// The version ID is the same as the Docker image ID 203 | /// that's created when you build your model. 204 | /// - input: The input depends on what model you are running. 205 | /// To see the available inputs, 206 | /// click the "Run with API" tab on the model you are running. 207 | /// For example, 208 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3) 209 | /// takes `prompt` as an input. 210 | /// - webhook: A webhook that is called when the prediction has completed. 211 | /// It will be a `POST` request where 212 | /// the request body is the same as 213 | /// the response body of the get prediction endpoint. 214 | /// If there are network problems, 215 | /// we will retry the webhook a few times, 216 | /// so make sure it can be safely called more than once. 
217 | /// - stream: Whether to stream the prediction output. 218 | /// By default, this is `false`. 219 | /// - Returns: The created prediction. 220 | public func createPrediction( 221 | _ type: Prediction.Type = AnyPrediction.self, 222 | version id: Model.Version.ID, 223 | input: Input, 224 | webhook: Webhook? = nil, 225 | stream: Bool = false 226 | ) async throws -> Prediction { 227 | var params: [String: Value] = [ 228 | "version": "\(id)", 229 | "input": try Value(input) 230 | ] 231 | 232 | if let webhook { 233 | params["webhook"] = "\(webhook.url.absoluteString)" 234 | params["webhook_events_filter"] = .array(webhook.events.map { "\($0.rawValue)" }) 235 | } 236 | 237 | if stream { 238 | params["stream"] = true 239 | } 240 | 241 | return try await fetch(.post, "predictions", params: params) 242 | } 243 | 244 | /// Create a prediction from a model 245 | /// 246 | /// - Parameters: 247 | /// - type: 248 | /// The type of the prediction. Defaults to `AnyPrediction.self`. 249 | /// - model: 250 | /// The ID of the model that you want to run. 251 | /// - input: 252 | /// The input depends on what model you are running. 253 | /// To see the available inputs, 254 | /// click the "Run with API" tab on the model you are running. 255 | /// For example, 256 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3) 257 | /// takes `prompt` as an input. 258 | /// - webhook: 259 | /// A webhook that is called when the prediction has completed. 260 | /// It will be a `POST` request where 261 | /// the request body is the same as 262 | /// the response body of the get prediction endpoint. 263 | /// If there are network problems, 264 | /// we will retry the webhook a few times, 265 | /// so make sure it can be safely called more than once. 266 | /// - stream: Whether to stream the prediction output. 267 | /// By default, this is `false`. 268 | /// - Returns: The created prediction. 
269 | public func createPrediction( 270 | _ type: Prediction.Type = AnyPrediction.self, 271 | model id: Model.ID, 272 | input: Input, 273 | webhook: Webhook? = nil, 274 | stream: Bool = false 275 | ) async throws -> Prediction { 276 | var params: [String: Value] = [ 277 | "input": try Value(input) 278 | ] 279 | 280 | if let webhook { 281 | params["webhook"] = "\(webhook.url.absoluteString)" 282 | params["webhook_events_filter"] = .array(webhook.events.map { "\($0.rawValue)" }) 283 | } 284 | 285 | if stream { 286 | params["stream"] = true 287 | } 288 | 289 | return try await fetch(.post, "models/\(id)/predictions", params: params) 290 | } 291 | 292 | /// Create a prediction using a deployment 293 | /// 294 | /// - Parameters: 295 | /// - type: 296 | /// The type of the prediction. Defaults to `AnyPrediction.self`. 297 | /// - deployment: 298 | /// The ID of the deployment. 299 | /// - input: 300 | /// The input depends on what model you are running. 301 | /// To see the available inputs, 302 | /// click the "Run with API" tab on the model you are running. 303 | /// For example, 304 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3) 305 | /// takes `prompt` as an input. 306 | /// - webhook: 307 | /// A webhook that is called when the prediction has completed. 308 | /// It will be a `POST` request where 309 | /// the request body is the same as 310 | /// the response body of the get prediction endpoint. 311 | /// If there are network problems, 312 | /// we will retry the webhook a few times, 313 | /// so make sure it can be safely called more than once. 314 | /// - stream: 315 | /// Whether to stream the prediction output. 316 | /// By default, this is `false`. 317 | /// - Returns: The created prediction. 318 | public func createPrediction( 319 | _ type: Prediction.Type = AnyPrediction.self, 320 | deployment id: Deployment.ID, 321 | input: Input, 322 | webhook: Webhook? 
= nil, 323 | stream: Bool = false 324 | ) async throws -> Prediction { 325 | var params: [String: Value] = [ 326 | "input": try Value(input) 327 | ] 328 | 329 | if let webhook { 330 | params["webhook"] = "\(webhook.url.absoluteString)" 331 | params["webhook_events_filter"] = .array(webhook.events.map { "\($0.rawValue)" }) 332 | } 333 | 334 | if stream { 335 | params["stream"] = true 336 | } 337 | 338 | return try await fetch(.post, "deployments/\(id)/predictions", params: params) 339 | } 340 | 341 | @available(*, deprecated, renamed: "listPredictions(_:cursor:)") 342 | public func getPredictions( 343 | _ type: Prediction.Type = AnyPrediction.self, 344 | cursor: Pagination.Cursor? = nil 345 | ) async throws -> Pagination.Page> 346 | { 347 | return try await listPredictions(type, cursor: cursor) 348 | } 349 | 350 | /// List predictions 351 | /// 352 | /// - Parameters: 353 | /// - type: 354 | /// The type of the predictions. Defaults to `AnyPrediction.self`. 355 | /// - cursor: 356 | /// A pointer to a page of results to fetch. 357 | /// - Returns: A page of predictions. 358 | public func listPredictions( 359 | _ type: Prediction.Type = AnyPrediction.self, 360 | cursor: Pagination.Cursor? = nil 361 | ) async throws -> Pagination.Page> 362 | { 363 | return try await fetch(.get, "predictions", cursor: cursor) 364 | } 365 | 366 | /// Get a prediction 367 | /// 368 | /// - Parameters: 369 | /// - type: The type of the prediction. Defaults to `AnyPrediction.self`. 370 | /// - id: The ID of the prediction you want to fetch. 371 | /// - Returns: The requested prediction. 372 | public func getPrediction( 373 | _ type: Prediction.Type = AnyPrediction.self, 374 | id: Prediction.ID 375 | ) async throws -> Prediction { 376 | return try await fetch(.get, "predictions/\(id)") 377 | } 378 | 379 | /// Cancel a prediction 380 | /// 381 | /// - Parameters: 382 | /// - type: 383 | /// The type of the prediction. Defaults to `AnyPrediction.self`. 
384 | /// - id: 385 | /// The ID of the prediction you want to cancel. 386 | /// - Returns: The canceled prediction. 387 | public func cancelPrediction( 388 | _ type: Prediction.Type = AnyPrediction.self, 389 | id: Prediction.ID 390 | ) async throws -> Prediction { 391 | return try await fetch(.post, "predictions/\(id)/cancel") 392 | } 393 | 394 | // MARK: - 395 | 396 | /// List public models 397 | /// - Parameters: 398 | /// - cursor: 399 | /// A pointer to a page of results to fetch. 400 | /// - Returns: A page of models. 401 | public func listModels(cursor: Pagination.Cursor? = nil) 402 | async throws -> Pagination.Page 403 | { 404 | return try await fetch(.get, "models", cursor: cursor) 405 | } 406 | 407 | /// Search for public models on Replicate. 408 | /// 409 | /// - Parameters: 410 | /// - query: The search query string. 411 | /// - Returns: A page of models matching the search query. 412 | public func searchModels(query: String) async throws -> Pagination.Page { 413 | var request = try createRequest(method: .query, path: "models") 414 | request.addValue("text/plain", forHTTPHeaderField: "Content-Type") 415 | request.httpBody = query.data(using: .utf8) 416 | return try await sendRequest(request) 417 | } 418 | 419 | /// Get a model 420 | /// 421 | /// - Parameters: 422 | /// - id: The model identifier, comprising 423 | /// the name of the user or organization that owns the model and 424 | /// the name of the model. 425 | /// For example, "stability-ai/stable-diffusion-3". 426 | /// - Returns: The requested model. 427 | public func getModel(_ id: Model.ID) 428 | async throws -> Model 429 | { 430 | return try await fetch(.get, "models/\(id)") 431 | } 432 | 433 | /// Create a model 434 | /// 435 | /// - Parameters: 436 | /// - owner: 437 | /// The name of the user or organization that will own the model. 438 | /// This must be the same as the user or organization that is making the API request. 
439 | /// In other words, the API token used in the request must belong to this user or organization. 440 | /// - name: 441 | /// The name of the model. 442 | /// This must be unique among all models owned by the user or organization. 443 | /// - visibility: 444 | /// Whether the model should be public or private. 445 | /// A public model can be viewed and run by anyone, 446 | /// whereas a private model can be viewed and run only by the user or organization members 447 | /// that own the model. 448 | /// - hardware: 449 | /// The SKU for the hardware used to run the model. 450 | /// Possible values can be found by calling ``listHardware()``. 451 | /// - description: 452 | /// A description of the model. 453 | /// - githubURL: 454 | /// A URL for the model's source code on GitHub. 455 | /// - paperURL: 456 | /// A URL for the model's paper. 457 | /// - licenseURL: 458 | /// A URL for the model's license. 459 | /// - coverImageURL: 460 | /// A URL for the model's cover image. 461 | /// This should be an image file. 462 | /// - Returns: The created model. 463 | public func createModel( 464 | owner: String, 465 | name: String, 466 | visibility: Model.Visibility, 467 | hardware: Hardware.ID, 468 | description: String? = nil, 469 | githubURL: URL? = nil, 470 | paperURL: URL? = nil, 471 | licenseURL: URL? = nil, 472 | coverImageURL: URL? 
= nil 473 | ) async throws -> Model 474 | { 475 | var params: [String: Value] = [ 476 | "owner": "\(owner)", 477 | "name": "\(name)", 478 | "visibility": "\(visibility.rawValue)", 479 | "hardware": "\(hardware)" 480 | ] 481 | 482 | if let description { 483 | params["description"] = "\(description)" 484 | } 485 | 486 | if let githubURL { 487 | params["github_url"] = "\(githubURL)" 488 | } 489 | 490 | if let paperURL { 491 | params["paper_url"] = "\(paperURL)" 492 | } 493 | 494 | if let licenseURL { 495 | params["license_url"] = "\(licenseURL)" 496 | } 497 | 498 | if let coverImageURL { 499 | params["cover_image_url"] = "\(coverImageURL)" 500 | } 501 | 502 | return try await fetch(.post, "models", params: params) 503 | } 504 | 505 | // MARK: - 506 | 507 | /// List hardware available for running a model on Replicate. 508 | /// 509 | /// - Returns: An array of hardware. 510 | public func listHardware() async throws -> [Hardware] { 511 | return try await fetch(.get, "hardware") 512 | } 513 | 514 | 515 | // MARK: - 516 | 517 | @available(*, deprecated, renamed: "listModelVersions(_:cursor:)") 518 | public func getModelVersions(_ id: Model.ID, 519 | cursor: Pagination.Cursor? = nil) 520 | async throws -> Pagination.Page 521 | { 522 | return try await listModelVersions(id, cursor: cursor) 523 | } 524 | 525 | /// List model versions 526 | /// 527 | /// - Parameters: 528 | /// - id: 529 | /// The model identifier, comprising 530 | /// the name of the user or organization that owns the model and 531 | /// the name of the model. 532 | /// For example, "stability-ai/stable-diffusion-3". 533 | /// - cursor: 534 | /// A pointer to a page of results to fetch. 535 | /// - Returns: A page of model versions. 536 | public func listModelVersions(_ id: Model.ID, 537 | cursor: Pagination.Cursor? 
= nil) 538 | async throws -> Pagination.Page 539 | { 540 | return try await fetch(.get, "models/\(id)/versions", cursor: cursor) 541 | } 542 | 543 | /// Get a model version 544 | /// 545 | /// - Parameters: 546 | /// - id: 547 | /// The model identifier, comprising 548 | /// the name of the user or organization that owns the model and 549 | /// the name of the model. 550 | /// For example, "stability-ai/stable-diffusion-3". 551 | /// - version: 552 | /// The ID of the version. 553 | public func getModelVersion(_ id: Model.ID, 554 | version: Model.Version.ID) 555 | async throws -> Model.Version 556 | { 557 | return try await fetch(.get, "models/\(id)/versions/\(version)") 558 | } 559 | 560 | // MARK: - 561 | 562 | /// List collections of models 563 | /// - Parameters: 564 | /// - Parameter cursor: A pointer to a page of results to fetch. 565 | public func listModelCollections(cursor: Pagination.Cursor? = nil) 566 | async throws -> Pagination.Page 567 | { 568 | return try await fetch(.get, "collections") 569 | } 570 | 571 | /// Get a collection of models 572 | /// 573 | /// - Parameters: 574 | /// - slug: 575 | /// The slug of the collection, 576 | /// like super-resolution or image-restoration. 577 | /// 578 | /// See 579 | public func getModelCollection(_ slug: String) 580 | async throws -> Model.Collection 581 | { 582 | return try await fetch(.get, "collections/\(slug)") 583 | } 584 | 585 | // MARK: - 586 | 587 | /// Train a model on Replicate. 588 | /// 589 | /// To find out which models can be trained, 590 | /// check out the [trainable language models collection](https://replicate.com/collections/trainable-language-models). 591 | /// 592 | /// - Parameters: 593 | /// - model: 594 | /// The base model used to train a new version. 595 | /// - id: 596 | /// The ID of the base model version 597 | /// that you're using to train a new model version. 
598 | /// 599 | /// You can get your model's versions using the API, 600 | /// or find them on the website by clicking 601 | /// the "Versions" tab on the Replicate model page, 602 | /// e.g. replicate.com/replicate/hello-world/versions, 603 | /// then copying the full SHA256 hash from the URL. 604 | /// 605 | /// The version ID is the same as the Docker image ID 606 | /// that's created when you build your model. 607 | /// - destination: 608 | /// The desired model to push to in the format `{owner}/{model_name}`. 609 | /// This should be an existing model owned by 610 | /// the user or organization making the API request. 611 | /// - input: 612 | /// An object containing inputs to the 613 | /// Cog model's `train()` function. 614 | /// - webhook: 615 | /// A webhook that is called when the training has completed. 616 | /// 617 | /// It will be a `POST` request where 618 | /// the request body is the same as 619 | /// the response body of the get training endpoint. 620 | /// If there are network problems, 621 | /// we will retry the webhook a few times, 622 | /// so make sure it can be safely called more than once. 623 | public func createTraining( 624 | _ type: Training.Type = AnyTraining.self, 625 | model: Model.ID, 626 | version: Model.Version.ID, 627 | destination: Model.ID, 628 | input: Input, 629 | webhook: Webhook? = nil 630 | ) async throws -> Training 631 | { 632 | var params: [String: Value] = [ 633 | "destination": "\(destination)", 634 | "input": try Value(input) 635 | ] 636 | 637 | if let webhook { 638 | params["webhook"] = "\(webhook.url.absoluteString)" 639 | params["webhook_events_filter"] = .array(webhook.events.map { "\($0.rawValue)" }) 640 | } 641 | 642 | return try await fetch(.post, "models/\(model)/versions/\(version)/trainings", params: params) 643 | } 644 | 645 | @available(*, deprecated, renamed: "listTrainings(_:cursor:)") 646 | public func getTrainings( 647 | _ type: Training.Type = AnyTraining.self, 648 | cursor: Pagination.Cursor? 
= nil 649 | ) async throws -> Pagination.Page> 650 | { 651 | return try await listTrainings(type, cursor: cursor) 652 | } 653 | 654 | /// List trainings 655 | /// 656 | /// - Parameters: 657 | /// - type: The type of the training. Defaults to `AnyTraining.self`. 658 | /// - cursor: A pointer to a page of results to fetch. 659 | /// - Returns: A page of trainings. 660 | public func listTrainings( 661 | _ type: Training.Type = AnyTraining.self, 662 | cursor: Pagination.Cursor? = nil 663 | ) async throws -> Pagination.Page> 664 | { 665 | return try await fetch(.get, "trainings", cursor: cursor) 666 | } 667 | 668 | /// Get a training 669 | /// 670 | /// - Parameters: 671 | /// - type: The type of the training. Defaults to `AnyTraining.self`. 672 | /// - id: The ID of the training you want to fetch. 673 | /// - Returns: The requested training. 674 | public func getTraining( 675 | _ type: Training.Type = AnyTraining.self, 676 | id: Training.ID 677 | ) async throws -> Training 678 | { 679 | return try await fetch(.get, "trainings/\(id)") 680 | } 681 | 682 | /// Cancel a training 683 | /// 684 | /// - Parameters: 685 | /// - type: The type of the training. Defaults to `AnyTraining.self`. 686 | /// - id: The ID of the training you want to cancel. 687 | /// - Returns: The canceled training. 688 | public func cancelTraining( 689 | _ type: Training.Type = AnyTraining.self, 690 | id: Training.ID 691 | ) async throws -> Training 692 | { 693 | return try await fetch(.post, "trainings/\(id)/cancel") 694 | } 695 | 696 | // MARK: - 697 | 698 | /// Get the current account 699 | /// 700 | /// - Returns: The current account. 
701 | public func getCurrentAccount() async throws -> Account { 702 | return try await fetch(.get, "account") 703 | } 704 | 705 | // MARK: - 706 | 707 | /// Get a deployment 708 | /// 709 | /// - Parameters: 710 | /// - id: The deployment identifier, comprising 711 | /// the name of the user or organization that owns the deployment and 712 | /// the name of the deployment. 713 | /// For example, "replicate/my-app-image-generator". 714 | /// - Returns: The requested deployment. 715 | public func getDeployment(_ id: Deployment.ID) 716 | async throws -> Deployment 717 | { 718 | return try await fetch(.get, "deployments/\(id)") 719 | } 720 | 721 | // MARK: - 722 | 723 | private enum Method: String, Hashable { 724 | case get = "GET" 725 | case post = "POST" 726 | case query = "QUERY" 727 | } 728 | 729 | private func fetch(_ method: Method, 730 | _ path: String, 731 | cursor: Pagination.Cursor?) 732 | async throws -> Pagination.Page { 733 | var params: [String: Value]? = nil 734 | if let cursor { 735 | params = ["cursor": "\(cursor)"] 736 | } 737 | 738 | let request = try createRequest(method: method, path: path, params: params) 739 | return try await sendRequest(request) 740 | } 741 | 742 | private func fetch(_ method: Method, 743 | _ path: String, 744 | params: [String: Value]? = nil) 745 | async throws -> T { 746 | let request = try createRequest(method: method, path: path, params: params) 747 | return try await sendRequest(request) 748 | } 749 | 750 | private func createRequest(method: Method, path: String, params: [String: Value]? = nil) throws -> URLRequest { 751 | var urlComponents = URLComponents(string: self.baseURLString.appending(path)) 752 | var httpBody: Data? 
= nil 753 | 754 | switch method { 755 | case .get: 756 | if let params { 757 | var queryItems: [URLQueryItem] = [] 758 | for (key, value) in params { 759 | queryItems.append(URLQueryItem(name: key, value: value.description)) 760 | } 761 | urlComponents?.queryItems = queryItems 762 | } 763 | case .post: 764 | if let params { 765 | let encoder = JSONEncoder() 766 | httpBody = try encoder.encode(params) 767 | } 768 | case .query: 769 | if let params, let queryString = params["query"] { 770 | httpBody = queryString.description.data(using: .utf8) 771 | } 772 | } 773 | 774 | guard let url = urlComponents?.url else { 775 | throw Error(detail: "invalid request \(method) \(path)") 776 | } 777 | 778 | var request = URLRequest(url: url) 779 | request.httpMethod = method.rawValue 780 | 781 | if !token.isEmpty { 782 | request.addValue("Bearer \(token)", forHTTPHeaderField: "Authorization") 783 | } 784 | request.addValue("application/json", forHTTPHeaderField: "Accept") 785 | 786 | if let httpBody { 787 | request.httpBody = httpBody 788 | request.addValue("application/json", forHTTPHeaderField: "Content-Type") 789 | } 790 | 791 | if let userAgent { 792 | request.addValue(userAgent, forHTTPHeaderField: "User-Agent") 793 | } 794 | 795 | return request 796 | } 797 | 798 | private func sendRequest(_ request: URLRequest) async throws -> T { 799 | let (data, response) = try await session.data(for: request) 800 | 801 | let decoder = JSONDecoder() 802 | decoder.dateDecodingStrategy = .iso8601WithFractionalSeconds 803 | 804 | switch (response as? HTTPURLResponse)?.statusCode { 805 | case (200..<300)?: 806 | return try decoder.decode(T.self, from: data) 807 | default: 808 | if let error = try? 
decoder.decode(Error.self, from: data) { 809 | throw error 810 | } 811 | 812 | if let string = String(data: data, encoding: .utf8) { 813 | throw Error(detail: "invalid response: \(response) \n \(string)") 814 | } 815 | 816 | throw Error(detail: "invalid response: \(response)") 817 | } 818 | } 819 | } 820 | 821 | // MARK: - Decodable 822 | 823 | extension Client.Pagination.Page: Decodable where Result: Decodable { 824 | private enum CodingKeys: String, CodingKey { 825 | case results 826 | case previous 827 | case next 828 | } 829 | 830 | public init(from decoder: Decoder) throws { 831 | let container = try decoder.container(keyedBy: CodingKeys.self) 832 | self.previous = try? container.decode(Client.Pagination.Cursor.self, forKey: .previous) 833 | self.next = try? container.decode(Client.Pagination.Cursor.self, forKey: .next) 834 | self.results = try container.decode([Result].self, forKey: .results) 835 | } 836 | } 837 | 838 | extension Client.Pagination.Cursor: Decodable { 839 | public init(from decoder: Decoder) throws { 840 | let container = try decoder.singleValueContainer() 841 | let string = try container.decode(String.self) 842 | guard let urlComponents = URLComponents(string: string), 843 | let queryItem = urlComponents.queryItems?.first(where: { $0.name == "cursor" }), 844 | let value = queryItem.value 845 | else { 846 | let context = DecodingError.Context(codingPath: container.codingPath, debugDescription: "invalid cursor") 847 | throw DecodingError.dataCorrupted(context) 848 | } 849 | 850 | self.rawValue = value 851 | } 852 | } 853 | 854 | // MARK: - 855 | 856 | extension Client { 857 | /// A namespace for pagination cursor and page types. 858 | public enum Pagination { 859 | /// A paginated collection of results. 860 | public struct Page { 861 | /// A pointer to the previous page of results 862 | public let previous: Cursor? 863 | 864 | /// A pointer to the next page of results. 865 | public let next: Cursor? 866 | 867 | /// The results for this page. 
868 | public let results: [Result] 869 | } 870 | 871 | /// A pointer to a page of results. 872 | public struct Cursor: RawRepresentable, Hashable { 873 | public var rawValue: String 874 | 875 | public init(rawValue: String) { 876 | self.rawValue = rawValue 877 | } 878 | } 879 | } 880 | } 881 | 882 | extension Client.Pagination.Cursor: CustomStringConvertible { 883 | public var description: String { 884 | return self.rawValue 885 | } 886 | } 887 | 888 | extension Client.Pagination.Cursor: ExpressibleByStringLiteral { 889 | public init(stringLiteral value: String) { 890 | self.init(rawValue: value) 891 | } 892 | } 893 | 894 | // MARK: - 895 | 896 | extension Client { 897 | /// A policy for how often a client should retry a request. 898 | public struct RetryPolicy: Equatable, Sequence { 899 | /// A strategy used to determine how long to wait between retries. 900 | public enum Strategy: Hashable { 901 | /// Wait for a constant interval. 902 | /// 903 | /// This strategy implements constant backoff with jitter 904 | /// as described by the equation: 905 | /// 906 | /// $$ 907 | /// t = d + R([-j/2, j/2]) 908 | /// $$ 909 | /// 910 | /// - Parameters: 911 | /// - duration: The constant interval ($d$). 912 | /// - jitter: The amount of random jitter ($j$). 913 | case constant(duration: TimeInterval = 2.0, 914 | jitter: Double = 0.0) 915 | 916 | /// Wait for an exponentially increasing interval. 917 | /// 918 | /// This strategy implements exponential backoff with jitter 919 | /// as described by the equation: 920 | /// 921 | /// $$ 922 | /// t = b^c + R([-j/2, j/2]) 923 | /// $$ 924 | /// 925 | /// - Parameters: 926 | /// - base: The power base ($b$). 927 | /// - multiplier: The power exponent ($c$). 928 | /// - jitter: The amount of random jitter ($j$). 929 | case exponential(base: TimeInterval = 2.0, 930 | multiplier: Double = 2.0, 931 | jitter: Double = 0.5) 932 | } 933 | 934 | /// The strategy used to determine how long to wait between retries. 
935 | public let strategy: Strategy 936 | 937 | /// The total maximum amount of time to retry requests. 938 | public let timeout: TimeInterval? 939 | 940 | /// The maximum amount of time between requests. 941 | public let maximumInterval: TimeInterval? 942 | 943 | /// The maximum number of requests to make. 944 | public let maximumRetries: Int? 945 | 946 | /// The default retry policy. 947 | static let `default` = RetryPolicy(strategy: .exponential(), 948 | timeout: 300.0, 949 | maximumInterval: 30.0, 950 | maximumRetries: 10) 951 | 952 | /// Creates a new retry policy. 953 | /// 954 | /// - Parameters: 955 | /// - strategy: The strategy used to determine how long to wait between retries. 956 | /// - timeout: The total maximum amount of time to retry requests. 957 | /// Must be greater than zero, if specified. 958 | /// - maximumInterval: The maximum amount of time between requests. 959 | /// Must be greater than zero, if specified. 960 | /// - maximumRetries: The maximum number of requests to make. 961 | /// Must be greater than zero, if specified. 962 | public init(strategy: Strategy, 963 | timeout: TimeInterval?, 964 | maximumInterval: TimeInterval?, 965 | maximumRetries: Int?) 966 | { 967 | precondition(timeout ?? .greatestFiniteMagnitude > 0) 968 | precondition(maximumInterval ?? .greatestFiniteMagnitude > 0) 969 | precondition(maximumRetries ?? .max > 0) 970 | 971 | self.strategy = strategy 972 | self.timeout = timeout 973 | self.maximumInterval = maximumInterval 974 | self.maximumRetries = maximumRetries 975 | } 976 | 977 | /// An instantiation of a retry policy. 978 | /// 979 | /// This type satisfies a requirement for `RetryPolicy` 980 | /// to conform to the `Sequence` protocol. 981 | public struct Retrier: IteratorProtocol { 982 | /// The number of retry attempts made. 983 | public private(set) var retries: Int = 0 984 | 985 | /// The retry policy. 
986 | public let policy: RetryPolicy 987 | 988 | /// The random number generator used to create random values. 989 | private var randomNumberGenerator: any RandomNumberGenerator 990 | 991 | /// A time after which no delay values are produced, if any. 992 | public let deadline: DispatchTime? 993 | 994 | /// Creates a new instantiation of a retry policy. 995 | /// 996 | /// - Parameters: 997 | /// - policy: The retry policy. 998 | /// - randomNumberGenerator: The random number generator used to create random values. 999 | /// - deadline: A time after which no delay values are produced, if any. 1000 | init(policy: RetryPolicy, 1001 | randomNumberGenerator: any RandomNumberGenerator = SystemRandomNumberGenerator(), 1002 | deadline: DispatchTime?) 1003 | { 1004 | self.policy = policy 1005 | self.randomNumberGenerator = randomNumberGenerator 1006 | self.deadline = deadline 1007 | } 1008 | 1009 | /// Returns the next delay amount, or `nil`. 1010 | public mutating func next() -> TimeInterval? { 1011 | guard policy.maximumRetries.flatMap({ $0 > retries }) ?? true else { return nil } 1012 | guard deadline.flatMap({ $0 > .now() }) ?? true else { return nil } 1013 | 1014 | defer { retries += 1 } 1015 | 1016 | let delay: TimeInterval 1017 | switch policy.strategy { 1018 | case .constant(let base, let jitter): 1019 | delay = base + Double.random(jitter: jitter, using: &randomNumberGenerator) 1020 | case .exponential(let base, let multiplier, let jitter): 1021 | delay = base * (pow(multiplier, Double(retries))) + Double.random(jitter: jitter, using: &randomNumberGenerator) 1022 | } 1023 | 1024 | return delay.clamped(to: 0...(policy.maximumInterval ?? .greatestFiniteMagnitude)) 1025 | } 1026 | } 1027 | 1028 | // Returns a new instantiation of the retry policy. 
1029 | public func makeIterator() -> Retrier { 1030 | return Retrier(policy: self, 1031 | deadline: timeout.flatMap { 1032 | .now().advanced(by: .nanoseconds(Int($0 * 1e+9))) 1033 | }) 1034 | } 1035 | } 1036 | } 1037 | 1038 | // MARK: - 1039 | 1040 | extension JSONDecoder.DateDecodingStrategy { 1041 | static let iso8601WithFractionalSeconds = custom { decoder in 1042 | let container = try decoder.singleValueContainer() 1043 | let string = try container.decode(String.self) 1044 | 1045 | let formatter = ISO8601DateFormatter() 1046 | formatter.formatOptions = [.withInternetDateTime, 1047 | .withFractionalSeconds] 1048 | 1049 | if let date = formatter.date(from: string) { 1050 | return date 1051 | } 1052 | 1053 | // Try again without fractional seconds 1054 | formatter.formatOptions = [.withInternetDateTime] 1055 | 1056 | guard let date = formatter.date(from: string) else { 1057 | throw DecodingError.dataCorruptedError(in: container, debugDescription: "Invalid date: \(string)") 1058 | } 1059 | 1060 | return date 1061 | } 1062 | } 1063 | 1064 | private extension Double { 1065 | static func random(jitter amount: Double, 1066 | using generator: inout T) -> Double 1067 | where T : RandomNumberGenerator 1068 | { 1069 | guard !amount.isZero else { return 0.0 } 1070 | return Double.random(in: (-amount / 2)...(amount / 2)) 1071 | } 1072 | } 1073 | 1074 | private extension Comparable { 1075 | func clamped(to range: ClosedRange) -> Self { 1076 | return min(max(self, range.lowerBound), range.upperBound) 1077 | } 1078 | } 1079 | -------------------------------------------------------------------------------- /Sources/Replicate/Deployment.swift: -------------------------------------------------------------------------------- 1 | import struct Foundation.Date 2 | 3 | /// A deployment of a model on Replicate. 4 | public struct Deployment: Hashable { 5 | /// The owner of the deployment. 6 | public let owner: String 7 | 8 | /// The name of the deployment. 
9 | public let name: String 10 | 11 | /// A release of a deployment. 12 | public struct Release: Hashable { 13 | /// The release number. 14 | let number: Int 15 | 16 | /// The model. 17 | let model: Model.ID 18 | 19 | /// The model version. 20 | let version: Model.Version.ID 21 | 22 | /// The time at which the release was created. 23 | let createdAt: Date 24 | 25 | /// The account that created the release 26 | let createdBy: Account 27 | 28 | /// The configuration of a deployment. 29 | public struct Configuration: Hashable { 30 | /// The configured hardware SKU. 31 | public let hardware: Hardware.ID 32 | 33 | /// A scaling configuration for a deployment. 34 | public struct Scaling: Hashable { 35 | /// The maximum number of instances. 36 | public let maxInstances: Int 37 | 38 | /// The minimum number of instances. 39 | public let minInstances: Int 40 | } 41 | 42 | /// The scaling configuration for the deployment. 43 | public let scaling: Scaling 44 | } 45 | 46 | /// The deployment configuration. 47 | public let configuration: Configuration 48 | } 49 | 50 | public let currentRelease: Release? 51 | } 52 | 53 | // MARK: - Identifiable 54 | 55 | extension Deployment: Identifiable { 56 | public typealias ID = String 57 | 58 | /// The ID of the model. 
59 | public var id: ID { "\(owner)/\(name)" } 60 | } 61 | 62 | // MARK: - Codable 63 | 64 | extension Deployment: Codable { 65 | public enum CodingKeys: String, CodingKey { 66 | case owner 67 | case name 68 | case currentRelease = "current_release" 69 | } 70 | } 71 | 72 | extension Deployment.Release: Codable { 73 | public enum CodingKeys: String, CodingKey { 74 | case number 75 | case model 76 | case version 77 | case createdAt = "created_at" 78 | case createdBy = "created_by" 79 | case configuration 80 | } 81 | } 82 | 83 | extension Deployment.Release.Configuration: Codable { 84 | public enum CodingKeys: String, CodingKey { 85 | case hardware 86 | case scaling 87 | } 88 | } 89 | 90 | extension Deployment.Release.Configuration.Scaling: Codable { 91 | public enum CodingKeys: String, CodingKey { 92 | case minInstances = "min_instances" 93 | case maxInstances = "max_instances" 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /Sources/Replicate/Error.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// An error returned by the Replicate HTTP API 4 | public struct Error: Swift.Error, Hashable { 5 | /// A description of the error. 6 | public let detail: String 7 | } 8 | 9 | // MARK: - LocalizedError 10 | 11 | extension Error: LocalizedError { 12 | public var errorDescription: String? { 13 | return self.detail 14 | } 15 | } 16 | 17 | // MARK: - CustomStringConvertible 18 | 19 | extension Error: CustomStringConvertible { 20 | public var description: String { 21 | return self.detail 22 | } 23 | } 24 | 25 | // MARK: - Decodable 26 | 27 | extension Error: Decodable { 28 | private enum CodingKeys: String, CodingKey { 29 | case detail 30 | } 31 | 32 | public init(from decoder: Decoder) throws { 33 | if let container = try? 
decoder.container(keyedBy: CodingKeys.self) { 34 | self.detail = try container.decode(String.self, forKey: .detail) 35 | } else if let container = try? decoder.singleValueContainer() { 36 | self.detail = try container.decode(String.self) 37 | } else { 38 | let context = DecodingError.Context(codingPath: [], debugDescription: "unable to decode error") 39 | throw DecodingError.dataCorrupted(context) 40 | } 41 | } 42 | } 43 | 44 | // MARK: - Encodable 45 | 46 | extension Error: Encodable {} 47 | -------------------------------------------------------------------------------- /Sources/Replicate/Extensions/Data+uriEncoded.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | private let dataURIPrefix = "data:" 4 | 5 | public extension Data { 6 | static func isURIEncoded(string: String) -> Bool { 7 | return string.hasPrefix(dataURIPrefix) 8 | } 9 | 10 | static func decode(uriEncoded string: String) -> (mimeType: String, data: Data)? { 11 | guard isURIEncoded(string: string) else { 12 | return nil 13 | } 14 | 15 | let components = string.dropFirst(dataURIPrefix.count).components(separatedBy: ",") 16 | guard components.count == 2, 17 | let dataScheme = components.first, 18 | let dataBase64 = components.last 19 | else { 20 | return nil 21 | } 22 | 23 | let mimeType: String 24 | if dataScheme.contains(";") { 25 | mimeType = dataScheme.components(separatedBy: ";").first ?? "" 26 | } else { 27 | mimeType = dataScheme 28 | } 29 | 30 | guard let decodedData = Data(base64Encoded: dataBase64) else { 31 | return nil 32 | } 33 | 34 | return (mimeType: mimeType, data: decodedData) 35 | } 36 | 37 | func uriEncoded(mimeType: String?) -> String { 38 | return "data:\(mimeType ?? 
"");base64,\(base64EncodedString())" 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /Sources/Replicate/Hardware.swift: -------------------------------------------------------------------------------- 1 | // Hardware for running a model on Replicate. 2 | public struct Hardware: Hashable, Codable { 3 | public typealias ID = String 4 | 5 | /// The product identifier for the hardware. 6 | /// 7 | /// For example, "gpu-a40-large". 8 | public let sku: String 9 | 10 | /// The name of the hardware. 11 | /// 12 | /// For example, "Nvidia A40 (Large) GPU". 13 | public let name: String 14 | } 15 | 16 | // MARK: - Identifiable 17 | 18 | extension Hardware: Identifiable { 19 | public var id: String { 20 | return self.sku 21 | } 22 | } 23 | 24 | // MARK: - CustomStringConvertible 25 | 26 | extension Hardware: CustomStringConvertible { 27 | public var description: String { 28 | return self.name 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /Sources/Replicate/Identifier.swift: -------------------------------------------------------------------------------- 1 | /// An identifier in the form of "{owner}/{name}:{version}". 2 | public struct Identifier: Hashable { 3 | /// The name of the user or organization that owns the model. 4 | public let owner: String 5 | 6 | /// The name of the model. 7 | public let name: String 8 | 9 | /// The version. 10 | let version: Model.Version.ID? 
11 | } 12 | 13 | // MARK: - Equatable & Comparable 14 | 15 | extension Identifier: Equatable, Comparable { 16 | public static func == (lhs: Identifier, rhs: Identifier) -> Bool { 17 | return lhs.rawValue.caseInsensitiveCompare(rhs.rawValue) == .orderedSame 18 | } 19 | 20 | public static func < (lhs: Identifier, rhs: Identifier) -> Bool { 21 | return lhs.rawValue.caseInsensitiveCompare(rhs.rawValue) == .orderedAscending 22 | } 23 | } 24 | 25 | // MARK: - RawRepresentable 26 | 27 | extension Identifier: RawRepresentable { 28 | public typealias RawValue = String 29 | 30 | public init?(rawValue: RawValue) { 31 | let components = rawValue.split(separator: "/") 32 | guard components.count == 2 else { return nil } 33 | 34 | if components[1].contains(":") { 35 | let nameAndVersion = components[1].split(separator: ":") 36 | guard nameAndVersion.count == 2 else { return nil } 37 | 38 | self.init(owner: String(components[0]), 39 | name: String(nameAndVersion[0]), 40 | version: Model.Version.ID(nameAndVersion[1])) 41 | } else { 42 | self.init(owner: String(components[0]), 43 | name: String(components[1]), 44 | version: nil) 45 | } 46 | } 47 | 48 | public var rawValue: String { 49 | if let version = version { 50 | return "\(owner)/\(name):\(version)" 51 | } else { 52 | return "\(owner)/\(name)" 53 | } 54 | } 55 | } 56 | 57 | // MARK: - ExpressibleByStringLiteral 58 | 59 | extension Identifier: ExpressibleByStringLiteral { 60 | public init!(stringLiteral value: StringLiteralType) { 61 | guard let identifier = Identifier(rawValue: value) else { 62 | fatalError("Invalid Identifier string literal: \(value)") 63 | } 64 | 65 | self = identifier 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Sources/Replicate/Model.swift: -------------------------------------------------------------------------------- 1 | import struct Foundation.Date 2 | import struct Foundation.URL 3 | 4 | /// A machine learning model hosted on Replicate. 
5 | public struct Model: Hashable { 6 | /// The visibility of the model. 7 | public enum Visibility: String, Hashable, Decodable { 8 | /// Public visibility. 9 | case `public` 10 | 11 | /// Private visibility. 12 | case `private` 13 | } 14 | 15 | /// A version of a model. 16 | public struct Version: Hashable { 17 | /// The ID of the version. 18 | public let id: ID 19 | 20 | /// When the version was created. 21 | public let createdAt: Date 22 | 23 | /// An OpenAPI description of the model inputs and outputs. 24 | public let openAPISchema: [String: Value] 25 | } 26 | 27 | /// A collection of models. 28 | public struct Collection: Hashable, Decodable { 29 | /// The name of the collection. 30 | public let name: String 31 | 32 | /// The slug of the collection, 33 | /// like super-resolution or image-restoration. 34 | /// 35 | /// See 36 | public let slug: String 37 | 38 | /// A description for the collection. 39 | public let description: String 40 | 41 | /// A list of models in the collection. 42 | public let models: [Model]? 43 | } 44 | 45 | /// The name of the user or organization that owns the model. 46 | public let owner: String 47 | 48 | /// The name of the model. 49 | public let name: String 50 | 51 | /// A link to the model on Replicate. 52 | public let url: URL 53 | 54 | /// A link to the model source code on GitHub. 55 | public let githubURL: URL? 56 | 57 | /// A link to the model's paper. 58 | public let paperURL: URL? 59 | 60 | /// A link to the model's license. 61 | public let licenseURL: URL? 62 | 63 | /// A link to a cover image for the model. 64 | public let coverImageURL: URL? 65 | 66 | /// A description for the model. 67 | public let description: String? 68 | 69 | /// The visibility of the model. 70 | public let visibility: Visibility 71 | 72 | /// The latest version of the model, if any. 73 | public let latestVersion: Version? 74 | 75 | /// The number of times this model has been run. 76 | public let runCount: Int? 
77 | 78 | public let defaultExample: AnyPrediction? 79 | } 80 | 81 | // MARK: - Identifiable 82 | 83 | extension Model: Identifiable { 84 | public typealias ID = String 85 | 86 | /// The ID of the model. 87 | public var id: ID { "\(owner)/\(name)" } 88 | } 89 | 90 | extension Model.Version: Identifiable { 91 | public typealias ID = String 92 | } 93 | 94 | extension Model.Collection: Identifiable { 95 | public typealias ID = String 96 | 97 | /// The ID of the model collection. 98 | public var id: String { slug } 99 | } 100 | 101 | // MARK: - Decodable 102 | 103 | extension Model: Decodable { 104 | private enum CodingKeys: String, CodingKey { 105 | case owner 106 | case name 107 | case url 108 | case githubURL = "github_url" 109 | case paperURL = "paper_url" 110 | case licenseURL = "license_url" 111 | case coverImageURL = "cover_image_url" 112 | case description 113 | case visibility 114 | case latestVersion = "latest_version" 115 | case runCount = "run_count" 116 | case defaultExample = "default_example" 117 | } 118 | } 119 | 120 | extension Model.Version: Decodable { 121 | private enum CodingKeys: String, CodingKey { 122 | case id 123 | case createdAt = "created_at" 124 | case openAPISchema = "openapi_schema" 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /Sources/Replicate/Predictable.swift: -------------------------------------------------------------------------------- 1 | /// A type that can make predictions with known inputs and outputs. 2 | public protocol Predictable { 3 | /// The type of the input to the model. 4 | associatedtype Input: Codable 5 | 6 | /// The type of the output from the model. 7 | associatedtype Output: Codable 8 | 9 | /// The ID of the model. 10 | static var modelID: Model.ID { get } 11 | 12 | /// The ID of the model version. 
13 | static var versionID: Model.Version.ID { get } 14 | } 15 | 16 | // MARK: - Default Implementations 17 | 18 | extension Predictable { 19 | /// The type of prediction created by the model 20 | public typealias Prediction = Replicate.Prediction 21 | 22 | /// Creates a prediction. 23 | /// 24 | /// - Parameters: 25 | /// - client: The client used to make API requests. 26 | /// - input: The input passed to the model. 27 | /// - wait: 28 | /// If set to `true`, 29 | /// this method refreshes the prediction until it completes 30 | /// (``Prediction/status`` is `.succeeded` or `.failed`). 31 | /// By default, this is `false`, 32 | /// and this method returns the prediction object encoded 33 | /// in the original creation response 34 | /// (``Prediction/status`` is `.starting`). 35 | public static func predict( 36 | with client: Client, 37 | input: Input, 38 | webhook: Webhook? = nil, 39 | wait: Bool = false 40 | ) async throws -> Prediction { 41 | var prediction = try await client.createPrediction(Prediction.self, 42 | version: Self.versionID, 43 | input: input, 44 | webhook: webhook) 45 | 46 | if wait { 47 | try await prediction.wait(with: client) 48 | } 49 | 50 | return prediction 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /Sources/Replicate/Prediction.swift: -------------------------------------------------------------------------------- 1 | import struct Foundation.Date 2 | import struct Foundation.TimeInterval 3 | import struct Foundation.URL 4 | import class Foundation.Progress 5 | import struct Dispatch.DispatchTime 6 | 7 | import RegexBuilder 8 | 9 | /// A prediction with unspecified inputs and outputs. 10 | public typealias AnyPrediction = Prediction<[String: Value], Value> 11 | 12 | /// A prediction made by a model hosted on Replicate. 13 | public struct Prediction: Identifiable where Input: Codable, Output: Codable { 14 | public typealias ID = String 15 | 16 | /// Source for creating a prediction. 
17 | public enum Source: String, Codable { 18 | /// The prediction was made on the web. 19 | case web 20 | 21 | /// The prediction was made using the API. 22 | case api 23 | } 24 | 25 | /// Metrics for the prediction. 26 | public struct Metrics: Hashable { 27 | /// How long it took to create the prediction, in seconds. 28 | public let predictTime: TimeInterval? 29 | } 30 | 31 | /// The unique ID of the prediction. 32 | /// Can be used to get a single prediction. 33 | /// 34 | /// - SeeAlso: ``Client/getPrediction(id:)`` 35 | public let id: ID 36 | 37 | /// The model used to create the prediction. 38 | public let modelID: Model.ID 39 | 40 | /// The version of the model used to create the prediction. 41 | public let versionID: Model.Version.ID 42 | 43 | /// Where the prediction was made. 44 | public let source: Source? 45 | 46 | /// The model's input as a JSON object. 47 | /// 48 | /// The input depends on what model you are running. 49 | /// To see the available inputs, 50 | /// click the "Run with API" tab on the model you are running. 51 | /// For example, 52 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3) 53 | /// takes `prompt` as an input. 54 | /// 55 | /// Files should be passed as data URLs or HTTP URLs. 56 | public let input: Input 57 | 58 | /// The output of the model for the prediction, if completed successfully. 59 | public let output: Output? 60 | 61 | /// The status of the prediction. 62 | public let status: Status 63 | 64 | /// The error encountered during the prediction, if any. 65 | public let error: Error? 66 | 67 | /// Logging output for the prediction. 68 | public let logs: String? 69 | 70 | /// Metrics for the prediction. 71 | public let metrics: Metrics? 72 | 73 | /// When the prediction was created. 74 | public let createdAt: Date 75 | 76 | /// When the prediction was started 77 | public let startedAt: Date? 78 | 79 | /// When the prediction was completed. 80 | public let completedAt: Date? 
81 | 82 | /// A convenience object that can be used to construct new API requests against the given prediction. 83 | public let urls: [String: URL] 84 | 85 | // MARK: - 86 | 87 | @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) 88 | public var progress: Progress? { 89 | guard let logs = self.logs else { return nil } 90 | 91 | let regex: Regex = #/^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/# 92 | 93 | let lines = logs.split(separator: "\n") 94 | guard !lines.isEmpty else { return nil } 95 | 96 | for line in lines.reversed() { 97 | let lineString = String(line).trimmingCharacters(in: .whitespaces) 98 | if let match = try? regex.firstMatch(in: lineString), 99 | let current = Int64(match.output.2), 100 | let total = Int64(match.output.3) 101 | { 102 | let progress = Progress(totalUnitCount: total) 103 | progress.completedUnitCount = current 104 | return progress 105 | } 106 | } 107 | 108 | return nil 109 | } 110 | 111 | // MARK: - 112 | 113 | /// Wait for the prediction to complete. 114 | /// 115 | /// - Parameters: 116 | /// - client: 117 | /// The client used to make API requests. 118 | /// - priority: 119 | /// The task priority. 120 | /// - updateHandler: 121 | /// A closure that executes with the updated prediction 122 | /// after each polling request to the API. 123 | /// If the prediction is in a terminal state 124 | /// (e.g. `succeeded`, `failed`, or `canceled`), 125 | /// it's returned immediately and the closure is not executed. 126 | /// Use this to provide feedback to the user 127 | /// about the progress of the prediction, 128 | /// or throw `CancellationError` to stop waiting 129 | /// for the prediction to finish. 130 | /// - Returns: The completed prediction. 131 | /// - Important: 132 | /// Returning early from the `updateHandler` closure 133 | /// doesn't cancel the prediction. 134 | /// To cancel the prediction, 135 | /// call ``cancel(with:)``. 
136 | /// - Throws: 137 | /// ``CancellationError`` if the prediction was canceled, 138 | /// or any error thrown from the `updateHandler` closure 139 | /// other than ``CancellationError``. 140 | public mutating func wait( 141 | with client: Client, 142 | priority: TaskPriority? = nil, 143 | updateHandler: @escaping (Self) throws -> Void = { _ in () } 144 | ) async throws { 145 | var retrier: Client.RetryPolicy.Retrier = client.retryPolicy.makeIterator() 146 | self = try await Self.wait(for: self, 147 | with: client, 148 | priority: priority, 149 | retrier: &retrier, 150 | updateHandler: updateHandler) 151 | } 152 | 153 | /// Waits for a prediction to complete and returns the updated prediction. 154 | /// 155 | /// - Parameters: 156 | /// - current: 157 | /// The prediction to wait for. 158 | /// - client: 159 | /// The client used to make API requests. 160 | /// - priority: 161 | /// The task priority. 162 | /// - retrier: 163 | /// An instance of the client retry policy. 164 | /// - updateHandler: 165 | /// A closure that executes with the updated prediction 166 | /// after each polling request to the API. 167 | /// If the prediction is in a terminal state 168 | /// (e.g. `succeeded`, `failed`, or `canceled`), 169 | /// it's returned immediately and the closure is not executed. 170 | /// Use this to provide feedback to the user 171 | /// about the progress of the prediction, 172 | /// or throw `CancellationError` to stop waiting 173 | /// for the prediction to finish. 174 | /// - Returns: The completed prediction. 175 | /// - Important: 176 | /// Returning early from the `updateHandler` closure 177 | /// doesn't cancel the prediction. 178 | /// To cancel the prediction, 179 | /// call ``cancel(with:)``. 180 | /// - Throws: 181 | /// ``CancellationError`` if the prediction was canceled, 182 | /// or any error thrown from the `updateHandler` closure 183 | /// other than ``CancellationError``. 
184 | public static func wait( 185 | for current: Self, 186 | with client: Client, 187 | priority: TaskPriority? = nil, 188 | retrier: inout Client.RetryPolicy.Iterator, 189 | updateHandler: @escaping (Self) throws -> Void = { _ in () } 190 | ) async throws -> Self { 191 | guard !current.status.terminated else { return current } 192 | guard let delay = retrier.next() else { throw CancellationError() } 193 | 194 | let id = current.id 195 | let updated = try await withThrowingTaskGroup(of: Self.self) { group in 196 | group.addTask { 197 | try await Task.sleep(nanoseconds: UInt64(delay * 1e+9)) 198 | return try await client.getPrediction(Self.self, id: id) 199 | } 200 | 201 | if let deadline = retrier.deadline { 202 | group.addTask { 203 | try await Task.sleep(nanoseconds: deadline.uptimeNanoseconds - DispatchTime.now().uptimeNanoseconds) 204 | throw CancellationError() 205 | } 206 | } 207 | 208 | let value = try await group.next() 209 | group.cancelAll() 210 | 211 | return value ?? current 212 | } 213 | 214 | if updated.status.terminated { 215 | return updated 216 | } else { 217 | do { 218 | try updateHandler(updated) 219 | } catch is CancellationError { 220 | return current 221 | } catch { 222 | throw error 223 | } 224 | 225 | return try await wait(for: updated, 226 | with: client, 227 | priority: priority, 228 | retrier: &retrier, 229 | updateHandler: updateHandler) 230 | } 231 | } 232 | 233 | /// Cancel the prediction. 234 | /// 235 | /// - Parameters: 236 | /// - client: The client used to make API requests. 
237 | public mutating func cancel(with client: Client) async throws { 238 | self = try await client.cancelPrediction(Self.self, id: id) 239 | } 240 | } 241 | 242 | // MARK: - Decodable 243 | 244 | extension Prediction: Codable { 245 | private enum CodingKeys: String, CodingKey { 246 | case id 247 | case modelID = "model" 248 | case versionID = "version" 249 | case source 250 | case input 251 | case output 252 | case status 253 | case error 254 | case logs 255 | case metrics 256 | case createdAt = "created_at" 257 | case startedAt = "started_at" 258 | case completedAt = "completed_at" 259 | case urls 260 | } 261 | } 262 | 263 | extension Prediction.Metrics: Codable { 264 | private enum CodingKeys: String, CodingKey { 265 | case predictTime = "predict_time" 266 | } 267 | 268 | public init(from decoder: Decoder) throws { 269 | let container = try decoder.container(keyedBy: CodingKeys.self) 270 | self.predictTime = try container.decodeIfPresent(TimeInterval.self, forKey: .predictTime) 271 | } 272 | } 273 | 274 | // MARK: - Hashable 275 | 276 | extension Prediction: Equatable where Input: Equatable, Output: Equatable {} 277 | extension Prediction: Hashable where Input: Hashable, Output: Hashable {} 278 | -------------------------------------------------------------------------------- /Sources/Replicate/Status.swift: -------------------------------------------------------------------------------- 1 | /// The status of the prediction or training. 2 | public enum Status: String, Hashable, Codable { 3 | /// The prediction or training is starting up. 4 | /// If this status lasts longer than a few seconds, 5 | /// then it's typically because a new worker is being started to run the prediction. 6 | case starting 7 | 8 | /// The `predict()` or `train()` method of the model is currently running. 9 | case processing 10 | 11 | /// The prediction or training completed successfully. 12 | case succeeded 13 | 14 | /// The prediction or training encountered an error during processing. 
15 | case failed 16 | 17 | /// The prediction or training was canceled by the user. 18 | case canceled 19 | 20 | public var terminated: Bool { 21 | switch self { 22 | case .starting, .processing: 23 | return false 24 | default: 25 | return true 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /Sources/Replicate/Training.swift: -------------------------------------------------------------------------------- 1 | import struct Foundation.Date 2 | import struct Foundation.URL 3 | import struct Foundation.TimeInterval 4 | import struct Dispatch.DispatchTime 5 | 6 | /// A training with unspecified inputs and outputs. 7 | public typealias AnyTraining = Training<[String: Value]> 8 | 9 | /// A training made by a model hosted on Replicate. 10 | public struct Training: Identifiable where Input: Codable { 11 | public typealias ID = String 12 | 13 | public struct Output: Hashable, Codable { 14 | public var version: Model.Version.ID 15 | public var weights: URL? 16 | } 17 | 18 | /// Source for creating a training. 19 | public enum Source: String, Codable { 20 | /// The training was made on the web. 21 | case web 22 | 23 | /// The training was made using the API. 24 | case api 25 | } 26 | 27 | /// Metrics for the training. 28 | public struct Metrics: Hashable { 29 | /// How long it took to create the training, in seconds. 30 | public let predictTime: TimeInterval? 31 | } 32 | 33 | /// The unique ID of the training. 34 | /// Can be used to get a single training. 35 | /// 36 | /// - SeeAlso: ``Client/getTraining(id:)`` 37 | public let id: ID 38 | 39 | /// The version of the model used to create the training. 40 | public let versionID: Model.Version.ID 41 | 42 | /// Where the training was made. 43 | public let source: Source? 44 | 45 | /// The model's input as a JSON object. 46 | /// 47 | /// The input depends on what model you are running. 
48 | /// To see the available inputs, 49 | /// click the "Run with API" tab on the model you are running. 50 | /// For example, 51 | /// [stability-ai/stable-diffusion](https://replicate.com/stability-ai/stable-diffusion) 52 | /// takes `prompt` as an input. 53 | /// 54 | /// Files should be passed as data URLs or HTTP URLs. 55 | public let input: Input 56 | 57 | /// The output of the model for the training, if completed successfully. 58 | public let output: Output? 59 | 60 | /// The status of the training. 61 | public let status: Status 62 | 63 | /// The error encountered during the training, if any. 64 | public let error: Error? 65 | 66 | /// Logging output for the training. 67 | public let logs: String? 68 | 69 | /// Metrics for the training. 70 | public let metrics: Metrics? 71 | 72 | /// When the training was created. 73 | public let createdAt: Date 74 | 75 | /// When the training was started 76 | public let startedAt: Date? 77 | 78 | /// When the training was completed. 79 | public let completedAt: Date? 80 | 81 | /// A convenience object that can be used to construct new API requests against the given training. 82 | public let urls: [String: URL] 83 | 84 | // MARK: - 85 | 86 | /// Cancel the training. 87 | /// 88 | /// - Parameters: 89 | /// - client: The client used to make API requests. 
90 | public mutating func cancel(with client: Client) async throws { 91 | self = try await client.cancelTraining(Self.self, id: id) 92 | } 93 | } 94 | 95 | // MARK: - Decodable 96 | 97 | extension Training: Codable { 98 | private enum CodingKeys: String, CodingKey { 99 | case id 100 | case versionID = "version" 101 | case source 102 | case input 103 | case output 104 | case status 105 | case error 106 | case logs 107 | case metrics 108 | case createdAt = "created_at" 109 | case startedAt = "started_at" 110 | case completedAt = "completed_at" 111 | case urls 112 | } 113 | } 114 | 115 | extension Training.Metrics: Codable { 116 | private enum CodingKeys: String, CodingKey { 117 | case predictTime = "predict_time" 118 | } 119 | 120 | public init(from decoder: Decoder) throws { 121 | let container = try decoder.container(keyedBy: CodingKeys.self) 122 | self.predictTime = try container.decodeIfPresent(TimeInterval.self, forKey: .predictTime) 123 | } 124 | } 125 | 126 | // MARK: - Hashable 127 | 128 | extension Training: Equatable where Input: Equatable {} 129 | extension Training: Hashable where Input: Hashable {} 130 | -------------------------------------------------------------------------------- /Sources/Replicate/Value.swift: -------------------------------------------------------------------------------- 1 | import struct Foundation.Data 2 | import class Foundation.JSONDecoder 3 | import class Foundation.JSONEncoder 4 | 5 | /// A codable value. 6 | public enum Value: Hashable { 7 | case null 8 | case bool(Bool) 9 | case int(Int) 10 | case double(Double) 11 | case string(String) 12 | case data(mimeType: String? = nil, Data) 13 | case array([Value]) 14 | case object([String: Value]) 15 | 16 | /// Create a `Value` from a `Codable` value. 17 | /// - Parameter value: The codable value 18 | /// - Returns: A value 19 | public init(_ value: T) throws { 20 | if let valueAsValue = value as? 
Value { 21 | self = valueAsValue 22 | } else { 23 | let data = try JSONEncoder().encode(value) 24 | self = try JSONDecoder().decode(Value.self, from: data) 25 | } 26 | } 27 | 28 | /// Returns whether the value is `null`. 29 | public var isNull: Bool { 30 | return self == .null 31 | } 32 | 33 | /// Returns the `Bool` value if the value is a `bool`, 34 | /// otherwise returns `nil`. 35 | public var boolValue: Bool? { 36 | guard case let .bool(value) = self else { return nil } 37 | return value 38 | } 39 | 40 | /// Returns the `Int` value if the value is an `integer`, 41 | /// otherwise returns `nil`. 42 | public var intValue: Int? { 43 | guard case let .int(value) = self else { return nil } 44 | return value 45 | } 46 | 47 | /// Returns the `Double` value if the value is a `double`, 48 | /// otherwise returns `nil`. 49 | public var doubleValue: Double? { 50 | guard case let .double(value) = self else { return nil } 51 | return value 52 | } 53 | 54 | /// Returns the `String` value if the value is a `string`, 55 | /// otherwise returns `nil`. 56 | public var stringValue: String? { 57 | guard case let .string(value) = self else { return nil } 58 | return value 59 | } 60 | 61 | /// Returns the data value and optional MIME type if the value is `data`, 62 | /// otherwise returns `nil`. 63 | public var dataValue: (mimeType: String?, Data)? { 64 | guard case let .data(mimeType: mimeType, data) = self else { return nil } 65 | return (mimeType: mimeType, data) 66 | } 67 | 68 | /// Returns the `[Value]` value if the value is an `array`, 69 | /// otherwise returns `nil`. 70 | public var arrayValue: [Value]? { 71 | guard case let .array(value) = self else { return nil } 72 | return value 73 | } 74 | 75 | /// Returns the `[String: Value]` value if the value is an `object`, 76 | /// otherwise returns `nil`. 77 | public var objectValue: [String: Value]? 
{ 78 | guard case let .object(value) = self else { return nil } 79 | return value 80 | } 81 | } 82 | 83 | // MARK: - Codable 84 | 85 | extension Value: Codable { 86 | public init(from decoder: Decoder) throws { 87 | let container = try decoder.singleValueContainer() 88 | 89 | if container.decodeNil() { 90 | self = .null 91 | } else if let value = try? container.decode(Bool.self) { 92 | self = .bool(value) 93 | } else if let value = try? container.decode(Int.self) { 94 | self = .int(value) 95 | } else if let value = try? container.decode(Double.self) { 96 | self = .double(value) 97 | } else if let value = try? container.decode(String.self) { 98 | if Data.isURIEncoded(string: value), 99 | case let (mimeType, data)? = Data.decode(uriEncoded: value) 100 | { 101 | self = .data(mimeType: mimeType, data) 102 | } else { 103 | self = .string(value) 104 | } 105 | } else if let value = try? container.decode([Value].self) { 106 | self = .array(value) 107 | } else if let value = try? container.decode([String: Value].self) { 108 | self = .object(value) 109 | } else { 110 | throw DecodingError.dataCorruptedError(in: container, debugDescription: "Value type not found") 111 | } 112 | } 113 | 114 | public func encode(to encoder: Encoder) throws { 115 | var container = encoder.singleValueContainer() 116 | 117 | switch self { 118 | case .null: 119 | try container.encodeNil() 120 | case .bool(let value): 121 | try container.encode(value) 122 | case .int(let value): 123 | try container.encode(value) 124 | case .double(let value): 125 | try container.encode(value) 126 | case .string(let value): 127 | try container.encode(value) 128 | case let .data(mimeType, value): 129 | try container.encode(value.uriEncoded(mimeType: mimeType)) 130 | case .array(let value): 131 | try container.encode(value) 132 | case .object(let value): 133 | try container.encode(value) 134 | } 135 | } 136 | } 137 | 138 | extension Value: CustomStringConvertible { 139 | public var description: String { 140 | switch 
self { 141 | case .null: 142 | return "" 143 | case .bool(let value): 144 | return value.description 145 | case .int(let value): 146 | return value.description 147 | case .double(let value): 148 | return value.description 149 | case .string(let value): 150 | return value.description 151 | case let .data(mimeType, value): 152 | return value.uriEncoded(mimeType: mimeType) 153 | case .array(let value): 154 | return value.description 155 | case .object(let value): 156 | return value.description 157 | } 158 | } 159 | } 160 | 161 | // MARK: - ExpressibleByNilLiteral 162 | 163 | extension Value: ExpressibleByNilLiteral { 164 | public init(nilLiteral: ()) { 165 | self = .null 166 | } 167 | } 168 | 169 | // MARK: - ExpressibleByBooleanLiteral 170 | 171 | extension Value: ExpressibleByBooleanLiteral { 172 | public init(booleanLiteral value: Bool) { 173 | self = .bool(value) 174 | } 175 | } 176 | 177 | // MARK: - ExpressibleByIntegerLiteral 178 | 179 | extension Value: ExpressibleByIntegerLiteral { 180 | public init(integerLiteral value: Int) { 181 | self = .int(value) 182 | } 183 | } 184 | 185 | // MARK: - ExpressibleByFloatLiteral 186 | 187 | extension Value: ExpressibleByFloatLiteral { 188 | public init(floatLiteral value: Double) { 189 | self = .double(value) 190 | } 191 | } 192 | 193 | // MARK: - ExpressibleByStringLiteral 194 | 195 | extension Value: ExpressibleByStringLiteral { 196 | public init(stringLiteral value: String) { 197 | self = .string(value) 198 | } 199 | } 200 | 201 | // MARK: - ExpressibleByArrayLiteral 202 | 203 | extension Value: ExpressibleByArrayLiteral { 204 | public init(arrayLiteral elements: Value...) { 205 | self = .array(elements) 206 | } 207 | } 208 | 209 | // MARK: - ExpressibleByDictionaryLiteral 210 | 211 | extension Value: ExpressibleByDictionaryLiteral { 212 | public init(dictionaryLiteral elements: (String, Value)...) 
{ 213 | var dictionary: [String: Value] = [:] 214 | for (key, value) in elements { 215 | dictionary[key] = value 216 | } 217 | self = .object(dictionary) 218 | } 219 | } 220 | 221 | // MARK: - ExpressibleByStringInterpolation 222 | 223 | extension Value: ExpressibleByStringInterpolation { 224 | public struct StringInterpolation: StringInterpolationProtocol { 225 | var stringValue: String 226 | 227 | public init(literalCapacity: Int, interpolationCount: Int) { 228 | self.stringValue = "" 229 | self.stringValue.reserveCapacity(literalCapacity + interpolationCount) 230 | } 231 | 232 | public mutating func appendLiteral(_ literal: String) { 233 | self.stringValue.append(literal) 234 | } 235 | 236 | public mutating func appendInterpolation(_ value: T) { 237 | self.stringValue.append(value.description) 238 | } 239 | } 240 | 241 | public init(stringInterpolation: StringInterpolation) { 242 | self = .string(stringInterpolation.stringValue) 243 | } 244 | } 245 | -------------------------------------------------------------------------------- /Sources/Replicate/Webhook.swift: -------------------------------------------------------------------------------- 1 | import struct Foundation.URL 2 | 3 | /// A structure representing a webhook configuration. 4 | /// 5 | /// A webhook is an HTTPS URL that receives a `POST` request 6 | /// when a prediction or training generates new output, logs, 7 | /// or reaches a terminal state (succeeded / canceled / failed). 8 | /// 9 | /// Example usage: 10 | /// 11 | /// ```swift 12 | /// let webhook = Webhook(url: someURL, events: [.start, .completed]) 13 | /// ``` 14 | public struct Webhook { 15 | /// The events that can trigger a webhook request. 16 | public enum Event: String, Hashable, CaseIterable { 17 | /// Occurs immediately on prediction start. 18 | case start 19 | 20 | /// Occurs each time a prediction generates an output. 21 | case output 22 | 23 | /// Occurs each time log output is generated by a prediction. 
24 | case logs 25 | 26 | /// Occurs when the prediction reaches a terminal state (succeeded / canceled / failed). 27 | /// - SeeAlso: ``Status/terminated`` 28 | case completed 29 | } 30 | 31 | /// The webhook URL. 32 | public let url: URL 33 | 34 | /// A set of events that trigger the webhook. 35 | public let events: Set 36 | 37 | /// Creates a new `Webhook` instance with the specified URL and events. 38 | /// 39 | /// By default, the webhook will be triggered by all event types. 40 | /// You can change which events trigger the webhook by specifying a 41 | /// sequence of events as `events` parameter. 42 | /// 43 | /// - Parameters: 44 | /// - url: The webhook URL. 45 | /// - events: A sequence of events that trigger the webhook. 46 | /// Defaults to all event types. 47 | public init( 48 | url: URL, 49 | events: S = Event.allCases 50 | ) where S.Element == Event 51 | { 52 | self.url = url 53 | self.events = Set(events) 54 | } 55 | } 56 | 57 | // MARK: - CustomStringConvertible 58 | 59 | extension Webhook.Event: CustomStringConvertible { 60 | public var description: String { 61 | return self.rawValue 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Tests/ReplicateTests/ClientTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | @testable import Replicate 3 | 4 | final class ClientTests: XCTestCase { 5 | var client = Client.valid 6 | 7 | static override func setUp() { 8 | URLProtocol.registerClass(MockURLProtocol.self) 9 | } 10 | 11 | func testRunWithVersion() async throws { 12 | let identifier: Identifier = "test/example:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 13 | let output = try await client.run(identifier, input: ["text": "Alice"]) 14 | XCTAssertEqual(output, ["Hello, Alice!"]) 15 | } 16 | 17 | func testRunWithModel() async throws { 18 | let identifier: Identifier = "meta/llama-2-70b-chat" 19 | let output = try await 
client.run(identifier, input: ["prompt": "Please write a haiku about llamas."]) 20 | XCTAssertEqual(output, ["I'm sorry, I'm afraid I can't do that"] ) 21 | } 22 | 23 | func testRunWithInvalidVersion() async throws { 24 | let identifier: Identifier = "test/example:invalid" 25 | do { 26 | _ = try await client.run(identifier, input: ["text": "Alice"]) 27 | XCTFail() 28 | } catch { 29 | XCTAssertEqual(error.localizedDescription, "Invalid version") 30 | } 31 | } 32 | 33 | func testCreatePredictionWithVersion() async throws { 34 | let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 35 | let prediction = try await client.createPrediction(version: version, input: ["text": "Alice"]) 36 | XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq") 37 | XCTAssertEqual(prediction.versionID, version) 38 | XCTAssertEqual(prediction.status, .starting) 39 | } 40 | 41 | func testCreatePredictionWithModel() async throws { 42 | let model: Model.ID = "meta/llama-2-70b-chat" 43 | let prediction = try await client.createPrediction(model: model, input: ["prompt": "Please write a poem about camelids"]) 44 | XCTAssertEqual(prediction.id, "heat2o3bzn3ahtr6bjfftvbaci") 45 | XCTAssertEqual(prediction.modelID, model) 46 | XCTAssertEqual(prediction.status, .starting) 47 | } 48 | 49 | func testCreatePredictionUsingDeployment() async throws { 50 | let deployment: Deployment.ID = "replicate/deployment" 51 | let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 52 | let prediction = try await client.createPrediction(deployment: deployment, input: ["text": "Alice"]) 53 | XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq") 54 | XCTAssertEqual(prediction.versionID, version) 55 | XCTAssertEqual(prediction.status, .starting) 56 | } 57 | 58 | func testCreatePredictionWithInvalidVersion() async throws { 59 | let version: Model.Version.ID = "invalid" 60 | do { 61 | _ = try await 
client.createPrediction(version: version, input: ["text": "Alice"]) 62 | XCTFail() 63 | } catch { 64 | XCTAssertEqual(error.localizedDescription, "Invalid version") 65 | } 66 | } 67 | 68 | func testCreatePredictionAndWait() async throws { 69 | let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 70 | var prediction = try await client.createPrediction(version: version, input: ["text": "Alice"]) 71 | try await prediction.wait(with: client) 72 | XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq") 73 | XCTAssertEqual(prediction.versionID, version) 74 | XCTAssertEqual(prediction.status, .succeeded) 75 | } 76 | 77 | func testPredictionWaitWithStop() async throws { 78 | var prediction = try await client.getPrediction(id: "r6bjfddngldkt2o3bzn3ahtaci") 79 | 80 | var didUpdate = false 81 | try await prediction.wait(with: client) { _ in 82 | didUpdate = true 83 | throw CancellationError() 84 | } 85 | 86 | XCTAssertEqual(prediction.id, "r6bjfddngldkt2o3bzn3ahtaci") 87 | XCTAssertEqual(prediction.status, .processing) 88 | XCTAssertTrue(didUpdate) 89 | } 90 | 91 | func testPredictionWaitWithNonCancellationError() async throws { 92 | var prediction = try await client.getPrediction(id: "r6bjfddngldkt2o3bzn3ahtaci") 93 | 94 | struct CustomError: Swift.Error {} 95 | 96 | do { 97 | try await prediction.wait(with: client) { _ in 98 | throw CustomError() 99 | } 100 | XCTFail("Expected CustomError to be thrown") 101 | } catch { 102 | XCTAssertTrue(error is CustomError, "Expected CustomError, but got \(type(of: error))") 103 | } 104 | 105 | XCTAssertEqual(prediction.id, "r6bjfddngldkt2o3bzn3ahtaci") 106 | XCTAssertEqual(prediction.status, .processing) 107 | } 108 | 109 | 110 | func testGetPrediction() async throws { 111 | let prediction = try await client.getPrediction(id: "ufawqhfynnddngldkgtslldrkq") 112 | XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq") 113 | XCTAssertEqual(prediction.versionID, 
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") 114 | XCTAssertEqual(prediction.source, .web) 115 | XCTAssertEqual(prediction.status, .succeeded) 116 | XCTAssertEqual(prediction.createdAt.timeIntervalSinceReferenceDate, 672703986.224, accuracy: 1) 117 | XCTAssertEqual(prediction.urls["cancel"]?.absoluteString, "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel") 118 | 119 | if #available(macOS 13.0, *) { 120 | XCTAssertEqual(prediction.progress?.completedUnitCount, 5) 121 | XCTAssertEqual(prediction.progress?.totalUnitCount, 5) 122 | } 123 | } 124 | 125 | func testCancelPrediction() async throws { 126 | let prediction = try await client.cancelPrediction(id: "ufawqhfynnddngldkgtslldrkq") 127 | XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq") 128 | } 129 | 130 | func testGetPredictions() async throws { 131 | let predictions = try await client.listPredictions() 132 | XCTAssertNil(predictions.previous) 133 | XCTAssertEqual(predictions.next, "cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") 134 | XCTAssertEqual(predictions.results.count, 1) 135 | } 136 | 137 | func testListModels() async throws { 138 | let models = try await client.listModels() 139 | XCTAssertEqual(models.results.count, 1) 140 | XCTAssertEqual(models.results.first?.owner, "replicate") 141 | XCTAssertEqual(models.results.first?.name, "hello-world") 142 | } 143 | 144 | func testGetModel() async throws { 145 | let model = try await client.getModel("replicate/hello-world") 146 | XCTAssertEqual(model.owner, "replicate") 147 | XCTAssertEqual(model.name, "hello-world") 148 | } 149 | 150 | func testGetModelVersions() async throws { 151 | let versions = try await client.listModelVersions("replicate/hello-world") 152 | XCTAssertNil(versions.previous) 153 | XCTAssertNil(versions.next) 154 | XCTAssertEqual(versions.results.count, 2) 155 | XCTAssertEqual(versions.results.first?.id, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") 
156 | } 157 | 158 | func testGetModelVersion() async throws { 159 | let version = try await client.getModelVersion("replicate/hello-world", version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") 160 | XCTAssertEqual(version.id, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") 161 | } 162 | 163 | func testCreateModel() async throws { 164 | let model = try await client.createModel(owner: "replicate", name: "hello-world", visibility: .public, hardware: "cpu") 165 | XCTAssertEqual(model.owner, "replicate") 166 | XCTAssertEqual(model.name, "hello-world") 167 | } 168 | 169 | func testListModelCollections() async throws { 170 | let collections = try await client.listModelCollections() 171 | XCTAssertEqual(collections.results.count, 1) 172 | 173 | XCTAssertEqual(collections.results.first?.slug, "super-resolution") 174 | XCTAssertEqual(collections.results.first?.models, nil) 175 | 176 | let collection = try await client.getModelCollection(collections.results.first!.slug) 177 | XCTAssertEqual(collection.slug, "super-resolution") 178 | XCTAssertEqual(collection.models, []) 179 | } 180 | 181 | func testGetModelCollection() async throws { 182 | let collection = try await client.getModelCollection("super-resolution") 183 | XCTAssertEqual(collection.slug, "super-resolution") 184 | } 185 | 186 | func testCreateTraining() async throws { 187 | let base: Model.ID = "example/base" 188 | let version: Model.Version.ID = "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a" 189 | let destination: Model.ID = "my/fork" 190 | let training = try await client.createTraining(model: base, version: version, destination: destination, input: ["data": "..."]) 191 | XCTAssertEqual(training.id, "zz4ibbonubfz7carwiefibzgga") 192 | XCTAssertEqual(training.versionID, version) 193 | XCTAssertEqual(training.status, .starting) 194 | } 195 | 196 | func testGetTraining() async throws { 197 | let training = try await client.getTraining(id: 
"zz4ibbonubfz7carwiefibzgga") 198 | XCTAssertEqual(training.id, "zz4ibbonubfz7carwiefibzgga") 199 | XCTAssertEqual(training.versionID, "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a") 200 | XCTAssertEqual(training.source, .web) 201 | XCTAssertEqual(training.status, .succeeded) 202 | XCTAssertEqual(training.createdAt.timeIntervalSinceReferenceDate, 703980786.224, accuracy: 1) 203 | XCTAssertEqual(training.urls["cancel"]?.absoluteString, "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel") 204 | } 205 | 206 | func testCancelTraining() async throws { 207 | let training = try await client.cancelTraining(id: "zz4ibbonubfz7carwiefibzgga") 208 | XCTAssertEqual(training.id, "zz4ibbonubfz7carwiefibzgga") 209 | } 210 | 211 | func testGetTrainings() async throws { 212 | let trainings = try await client.listTrainings() 213 | XCTAssertNil(trainings.previous) 214 | XCTAssertEqual(trainings.next, "g5FWfcbO0EdVeR27rkXr0Z6tI0MjrW34ZejxnGzDeND3phpWWsyMGCQD") 215 | XCTAssertEqual(trainings.results.count, 1) 216 | } 217 | 218 | func testListHardware() async throws { 219 | let hardware = try await client.listHardware() 220 | XCTAssertGreaterThan(hardware.count, 1) 221 | XCTAssertEqual(hardware.first?.name, "CPU") 222 | XCTAssertEqual(hardware.first?.sku, "cpu") 223 | } 224 | 225 | func testCurrentAccount() async throws { 226 | let account = try await client.getCurrentAccount() 227 | XCTAssertEqual(account.type, .organization) 228 | XCTAssertEqual(account.username, "replicate") 229 | XCTAssertEqual(account.name, "Replicate") 230 | XCTAssertEqual(account.githubURL?.absoluteString, "https://github.com/replicate") 231 | } 232 | 233 | func testGetDeployment() async throws { 234 | let deployment = try await client.getDeployment("replicate/my-app-image-generator") 235 | XCTAssertEqual(deployment.owner, "replicate") 236 | XCTAssertEqual(deployment.name, "my-app-image-generator") 237 | XCTAssertEqual(deployment.currentRelease?.number, 1) 238 | 
XCTAssertEqual(deployment.currentRelease?.model, "stability-ai/sdxl") 239 | XCTAssertEqual(deployment.currentRelease?.version, "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf") 240 | XCTAssertEqual(deployment.currentRelease!.createdAt.timeIntervalSinceReferenceDate, 729707577.01, accuracy: 1) 241 | XCTAssertEqual(deployment.currentRelease?.createdBy.type, .organization) 242 | XCTAssertEqual(deployment.currentRelease?.createdBy.username, "replicate") 243 | XCTAssertEqual(deployment.currentRelease?.createdBy.name, "Replicate, Inc.") 244 | XCTAssertEqual(deployment.currentRelease?.createdBy.githubURL?.absoluteString, "https://github.com/replicate") 245 | XCTAssertEqual(deployment.currentRelease?.configuration.hardware, "gpu-t4") 246 | XCTAssertEqual(deployment.currentRelease?.configuration.scaling.minInstances, 1) 247 | XCTAssertEqual(deployment.currentRelease?.configuration.scaling.maxInstances, 5) 248 | } 249 | 250 | func testCustomBaseURL() async throws { 251 | let client = Client(baseURLString: "https://v1.replicate.proxy", token: MockURLProtocol.validToken).mocked 252 | let collection = try await client.getModelCollection("super-resolution") 253 | XCTAssertEqual(collection.slug, "super-resolution") 254 | } 255 | 256 | func testInvalidToken() async throws { 257 | do { 258 | let _ = try await Client.invalid.listPredictions() 259 | XCTFail("unauthenticated requests should fail") 260 | } catch { 261 | guard let error = error as? Replicate.Error else { 262 | return XCTFail("invalid error") 263 | } 264 | 265 | XCTAssertEqual(error.detail, "Invalid token.") 266 | } 267 | } 268 | 269 | func testUnauthenticated() async throws { 270 | do { 271 | let _ = try await Client.unauthenticated.listPredictions() 272 | XCTFail("unauthenticated requests should fail") 273 | } catch { 274 | guard let error = error as? 
Replicate.Error else { 275 | return XCTFail("invalid error") 276 | } 277 | 278 | XCTAssertEqual(error.detail, "Authentication credentials were not provided.") 279 | } 280 | } 281 | 282 | func testSearchModels() async throws { 283 | let models = try await client.searchModels(query: "greeter") 284 | XCTAssertEqual(models.results.count, 1) 285 | XCTAssertEqual(models.results[0].owner, "replicate") 286 | XCTAssertEqual(models.results[0].name, "hello-world") 287 | } 288 | } 289 | -------------------------------------------------------------------------------- /Tests/ReplicateTests/DateDecodingTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | @testable import Replicate 3 | 4 | struct Value: Decodable { 5 | let date: Date 6 | } 7 | 8 | final class DateDecodingTests: XCTestCase { 9 | func testISO8601Timestamp() throws { 10 | let decoder = JSONDecoder() 11 | decoder.dateDecodingStrategy = .iso8601WithFractionalSeconds 12 | 13 | let timestamp = "2023-10-29T01:23:45Z" 14 | let json = #"{"date": "\#(timestamp)"}"# 15 | let value = try decoder.decode(Value.self, from: json.data(using: .utf8)!) 16 | XCTAssertEqual(value.date.timeIntervalSince1970, 1698542625) 17 | } 18 | 19 | func testISO8601TimestampWithFractionalSeconds() throws { 20 | let decoder = JSONDecoder() 21 | decoder.dateDecodingStrategy = .iso8601WithFractionalSeconds 22 | 23 | let timestamp = "2023-10-29T01:23:45.678900Z" 24 | let json = #"{"date": "\#(timestamp)"}"# 25 | let value = try decoder.decode(Value.self, from: json.data(using: .utf8)!) 
26 | XCTAssertEqual(value.date.timeIntervalSince1970, 1698542625.678, accuracy: 0.1) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /Tests/ReplicateTests/Helpers/MockURLProtocol.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | #if canImport(FoundationNetworking) 4 | import FoundationNetworking 5 | #endif 6 | 7 | @testable import Replicate 8 | 9 | class MockURLProtocol: URLProtocol { 10 | static let validToken = "" 11 | 12 | override class func canInit(with task: URLSessionTask) -> Bool { 13 | return true 14 | } 15 | 16 | override class func canInit(with request: URLRequest) -> Bool { 17 | return true 18 | } 19 | 20 | override class func canonicalRequest(for request: URLRequest) -> URLRequest { 21 | return request 22 | } 23 | 24 | override func startLoading() { 25 | let statusCode: Int 26 | let json: String 27 | 28 | switch request.value(forHTTPHeaderField: "Authorization") { 29 | case "Bearer \(Self.validToken)": 30 | switch (request.httpMethod, request.url?.absoluteString) { 31 | case ("GET", "https://api.replicate.com/v1/account"?): 32 | statusCode = 200 33 | json = #""" 34 | { 35 | "type": "organization", 36 | "username": "replicate", 37 | "name": "Replicate", 38 | "github_url": "https://github.com/replicate" 39 | } 40 | """# 41 | case ("GET", "https://api.replicate.com/v1/predictions"?): 42 | statusCode = 200 43 | json = #""" 44 | { 45 | "previous": null, 46 | "next": "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", 47 | "results": [ 48 | { 49 | "id": "ufawqhfynnddngldkgtslldrkq", 50 | "model": "replicate/hello-world", 51 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 52 | "urls": { 53 | "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", 54 | "cancel": 
"https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" 55 | }, 56 | "created_at": "2022-04-26T22:13:06.224088Z", 57 | "completed_at": "2022-04-26T22:13:06.580379Z", 58 | "source": "web", 59 | "status": "starting", 60 | "input": { 61 | "text": "Alice" 62 | }, 63 | "output": null, 64 | "error": null, 65 | "logs": null, 66 | "metrics": {} 67 | } 68 | ] 69 | } 70 | """# 71 | case ("POST", "https://api.replicate.com/v1/predictions"?), 72 | ("POST", "https://api.replicate.com/v1/deployments/replicate/deployment/predictions"?): 73 | 74 | if let body = request.json, 75 | body["version"] as? String == "invalid" 76 | { 77 | statusCode = 400 78 | json = #"{ "detail" : "Invalid version" }"# 79 | } else { 80 | statusCode = 201 81 | 82 | json = #""" 83 | { 84 | "id": "ufawqhfynnddngldkgtslldrkq", 85 | "model": "replicate/hello-world", 86 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 87 | "urls": { 88 | "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", 89 | "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" 90 | }, 91 | "created_at": "2022-04-26T22:13:06.224088Z", 92 | "completed_at": "2022-04-26T22:13:06.580379Z", 93 | "source": "web", 94 | "status": "starting", 95 | "input": { 96 | "text": "Alice" 97 | }, 98 | "output": null, 99 | "error": null, 100 | "logs": null, 101 | "metrics": {} 102 | } 103 | """# 104 | } 105 | case ("POST", "https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions"?): 106 | statusCode = 201 107 | json = #""" 108 | { 109 | "id": "heat2o3bzn3ahtr6bjfftvbaci", 110 | "model": "meta/llama-2-70b-chat", 111 | "version": "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", 112 | "input": { 113 | "prompt": "Please write a haiku about llamas." 
114 | }, 115 | "logs": "", 116 | "error": null, 117 | "status": "starting", 118 | "created_at": "2023-11-27T13:35:45.99397566Z", 119 | "urls": { 120 | "cancel": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel", 121 | "get": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci" 122 | } 123 | } 124 | """# 125 | case ("GET", "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq"?): 126 | statusCode = 200 127 | json = #""" 128 | { 129 | "id": "ufawqhfynnddngldkgtslldrkq", 130 | "model": "replicate/hello-world", 131 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", 132 | "urls": { 133 | "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", 134 | "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" 135 | }, 136 | "created_at": "2022-04-26T22:13:06.224088Z", 137 | "completed_at": "2022-04-26T22:15:06.224088Z", 138 | "source": "web", 139 | "status": "succeeded", 140 | "input": { 141 | "text": "Alice" 142 | }, 143 | "output": ["Hello, Alice!"], 144 | "error": null, 145 | "logs": "Using seed: 12345,\n0%| | 0/5 [00:00").mocked 861 | } 862 | 863 | static var unauthenticated: Client { 864 | return Client(token: "").mocked 865 | } 866 | 867 | var mocked: Self { 868 | let configuration = session.configuration 869 | configuration.protocolClasses = [MockURLProtocol.self] 870 | session = URLSession(configuration: configuration) 871 | return self 872 | } 873 | } 874 | 875 | private extension URLRequest { 876 | var json: [String: Any]? 
{ 877 | var data = httpBody 878 | if let stream = httpBodyStream { 879 | let bufferSize = 1024 880 | data = Data() 881 | stream.open() 882 | 883 | while stream.hasBytesAvailable { 884 | var buffer = [UInt8](repeating: 0, count: bufferSize) 885 | let bytesRead = stream.read(&buffer, maxLength: bufferSize) 886 | if bytesRead > 0 { 887 | data?.append(buffer, count: bytesRead) 888 | } else { 889 | break 890 | } 891 | } 892 | } 893 | 894 | guard let data = data else { return nil } 895 | return try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] 896 | } 897 | } 898 | -------------------------------------------------------------------------------- /Tests/ReplicateTests/PredictionTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | @testable import Replicate 3 | 4 | final class PredictionTests: XCTestCase { 5 | var client = Client.valid 6 | 7 | static override func setUp() { 8 | URLProtocol.registerClass(MockURLProtocol.self) 9 | } 10 | 11 | func testWait() async throws { 12 | let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 13 | var prediction = try await client.createPrediction(version: version, input: ["text": "Alice"]) 14 | XCTAssertEqual(prediction.status, .starting) 15 | 16 | try await prediction.wait(with: client) 17 | XCTAssertEqual(prediction.status, .succeeded) 18 | } 19 | 20 | func testCancel() async throws { 21 | let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" 22 | var prediction = try await client.createPrediction(version: version, input: ["text": "Alice"]) 23 | XCTAssertEqual(prediction.status, .starting) 24 | 25 | try await prediction.cancel(with: client) 26 | XCTAssertEqual(prediction.status, .canceled) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /Tests/ReplicateTests/RetryPolicyTests.swift: 
--------------------------------------------------------------------------------
import XCTest
@testable import Replicate

/// Tests for `Client.RetryPolicy`: the interval sequence each backoff strategy
/// produces, and how `timeout`, `maximumInterval`, and `maximumRetries` bound it.
final class RetryPolicyTests: XCTestCase {
    /// With zero jitter, a constant strategy yields the same interval for every retry.
    func testConstantBackoffStrategy() throws {
        let policy = Client.RetryPolicy(strategy: .constant(duration: 1.0,
                                                            jitter: 0.0),
                                        timeout: nil,
                                        maximumInterval: nil,
                                        maximumRetries: 5)

        XCTAssertEqual(Array(policy), [1.0, 1.0, 1.0, 1.0, 1.0])
    }

    /// With zero jitter, an exponential strategy doubles each interval
    /// (base 1.0, multiplier 2.0) until `maximumInterval` caps it at 30.0.
    func testExponentialBackoffStrategy() throws {
        let policy = Client.RetryPolicy(strategy: .exponential(base: 1.0,
                                                               multiplier: 2.0,
                                                               jitter: 0.0),
                                        timeout: nil,
                                        maximumInterval: 30.0,
                                        maximumRetries: 7)

        XCTAssertEqual(Array(policy), [1.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0])
    }

    /// A policy with a timeout hands out iterators carrying a deadline that is
    /// no later than `timeout` seconds from now.
    func testTimeoutWithDeadline() throws {
        let timeout: TimeInterval = 300.0
        let policy = Client.RetryPolicy(strategy: .constant(),
                                        timeout: timeout,
                                        maximumInterval: nil,
                                        maximumRetries: nil)
        // Unwrap rather than falling back to `.distantFuture`: a missing deadline
        // now fails right here with a clear message instead of via the comparison.
        let deadline = try XCTUnwrap(policy.makeIterator().deadline)

        // `now` is sampled after the iterator was created, so this bound is safe.
        let latestAcceptable = DispatchTime.now().advanced(by: .nanoseconds(Int(timeout * 1e+9)))
        XCTAssertLessThanOrEqual(deadline, latestAcceptable)
    }

    /// Without a timeout, iterators carry no deadline.
    func testTimeoutWithoutDeadline() throws {
        let policy = Client.RetryPolicy(strategy: .constant(),
                                        timeout: nil,
                                        maximumInterval: nil,
                                        maximumRetries: nil)
        let deadline = policy.makeIterator().deadline

        XCTAssertNil(deadline)
    }

    /// No interval produced by the policy exceeds `maximumInterval`.
    func testMaximumInterval() throws {
        let maximumInterval: TimeInterval = 30.0
        let policy = Client.RetryPolicy(strategy: .exponential(),
                                        timeout: nil,
                                        maximumInterval: maximumInterval,
                                        maximumRetries: 10)

        // An empty sequence would be a failure too, so unwrap the largest interval
        // explicitly instead of substituting `.greatestFiniteMagnitude`.
        let largestInterval = try XCTUnwrap(policy.max())
        XCTAssertLessThanOrEqual(largestInterval, maximumInterval)
    }

    /// The policy yields exactly `maximumRetries` intervals.
    func testMaximumRetries() throws {
        let maximumRetries: Int = 5
        let policy = Client.RetryPolicy(strategy: .constant(),
                                        timeout: nil,
                                        maximumInterval: nil,
                                        maximumRetries: maximumRetries)

        XCTAssertEqual(Array(policy).count, maximumRetries)
    }
}

--------------------------------------------------------------------------------
/Tests/ReplicateTests/URIEncodingTests.swift:
--------------------------------------------------------------------------------
import XCTest
@testable import Replicate

/// Tests for the `data:` URI helpers on `Data`.
final class URIEncodingTests: XCTestCase {
    /// Encoding produces `data:<mime>;base64,<payload>`.
    func testDataURIEncoding() throws {
        let string = "Hello, World!"
        let base64Encoded = "SGVsbG8sIFdvcmxkIQ=="
        // `Data(_:)` over the UTF-8 view cannot fail, unlike `data(using:)`,
        // so the assertions below compare non-optional values.
        let data = Data(string.utf8)
        XCTAssertEqual(data.base64EncodedString(), base64Encoded)

        let mimeType = "text/plain"
        XCTAssertEqual(data.uriEncoded(mimeType: mimeType),
                       "data:\(mimeType);base64,\(base64Encoded)")
    }

    /// Only strings in `data:` URI form are recognized as URI-encoded.
    func testISURIEncoded() throws {
        XCTAssertTrue(Data.isURIEncoded(string: "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="))
        XCTAssertFalse(Data.isURIEncoded(string: "Hello, World!"))
        XCTAssertFalse(Data.isURIEncoded(string: ""))
    }

    /// Decoding recovers both the MIME type and the original bytes.
    func testDecodeURIEncoded() throws {
        let encoded = "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="
        guard case let (mimeType, data)? = Data.decode(uriEncoded: encoded) else {
            return XCTFail("failed to decode data URI")
        }

        XCTAssertEqual(mimeType, "text/plain")
        XCTAssertEqual(String(data: data, encoding: .utf8), "Hello, World!")
    }
}

--------------------------------------------------------------------------------