├── .spi.yml
├── .gitignore
├── Sources
│   └── Replicate
│       ├── Hardware.swift
│       ├── Status.swift
│       ├── Extensions
│       │   └── Data+uriEncoded.swift
│       ├── Error.swift
│       ├── Account.swift
│       ├── Predictable.swift
│       ├── Webhook.swift
│       ├── Identifier.swift
│       ├── Deployment.swift
│       ├── Model.swift
│       ├── Training.swift
│       ├── Value.swift
│       ├── Prediction.swift
│       └── Client.swift
├── .github
│   └── workflows
│       ├── ci.yml
│       └── codeql.yml
├── Package.swift
├── Tests
│   └── ReplicateTests
│       ├── DateDecodingTests.swift
│       ├── PredictionTests.swift
│       ├── URIEncodingTests.swift
│       ├── RetryPolicyTests.swift
│       ├── ClientTests.swift
│       └── Helpers
│           └── MockURLProtocol.swift
├── README.md
└── LICENSE.md
/.spi.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | builder:
3 |   configs:
4 |     - documentation_targets: [Replicate]
5 | metadata:
6 |   - authors: Replicate
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | /.build
3 | /Packages
4 | /*.xcodeproj
5 | xcuserdata/
6 | DerivedData/
7 | .swiftpm/config/registries.json
8 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
9 | .netrc
10 | /*.playground
11 |
--------------------------------------------------------------------------------
/Sources/Replicate/Hardware.swift:
--------------------------------------------------------------------------------
1 | /// Hardware for running a model on Replicate.
2 | public struct Hardware: Hashable, Codable {
3 | public typealias ID = String
4 | 
5 | /// The product identifier for the hardware.
6 | ///
7 | /// For example, "gpu-a40-large".
8 | public let sku: String
9 | 
10 | /// The name of the hardware.
11 | ///
12 | /// For example, "Nvidia A40 (Large) GPU".
13 | public let name: String
14 | }
15 |
16 | // MARK: - Identifiable
17 |
18 | extension Hardware: Identifiable {
19 | public var id: String {
20 | return self.sku
21 | }
22 | }
23 |
24 | // MARK: - CustomStringConvertible
25 |
26 | extension Hardware: CustomStringConvertible {
27 | public var description: String {
28 | return self.name
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | 
3 | on:
4 |   push:
5 |     branches: ["main"]
6 |   pull_request:
7 |     branches: ["main"]
8 | 
9 | jobs:
10 |   macos:
11 |     runs-on: macos-13-xlarge
12 | 
13 |     strategy:
14 |       matrix:
15 |         xcode:
16 |           - "14.3.1" # Swift 5.8.1
17 | 
18 |     name: "macOS (Xcode ${{ matrix.xcode }})"
19 | 
20 |     env:
21 |       DEVELOPER_DIR: /Applications/Xcode_${{ matrix.xcode }}.app/Contents/Developer
22 | 
23 |     steps:
24 |       - uses: actions/checkout@v3
25 |       - uses: actions/cache@v3
26 |         with:
27 |           path: .build
28 |           key: ${{ runner.os }}-spm-xcode-${{ matrix.xcode }}-${{ hashFiles('**/Package.resolved') }}
29 |           restore-keys: |
30 |             ${{ runner.os }}-spm-xcode-${{ matrix.xcode }}-
31 |       - name: Build
32 |         run: swift build -v
33 |       - name: Run tests
34 |         run: swift test -v
35 |
--------------------------------------------------------------------------------
/Sources/Replicate/Status.swift:
--------------------------------------------------------------------------------
1 | /// The status of the prediction or training.
2 | public enum Status: String, Hashable, Codable {
3 |     /// The prediction or training is starting up.
4 |     /// If this status lasts longer than a few seconds,
5 |     /// then it's typically because a new worker is being started to run the prediction.
6 |     case starting
7 | 
8 |     /// The `predict()` or `train()` method of the model is currently running.
9 |     case processing
10 | 
11 |     /// The prediction or training completed successfully.
12 |     case succeeded
13 | 
14 |     /// The prediction or training encountered an error during processing.
15 |     case failed
16 | 
17 |     /// The prediction or training was canceled by the user.
18 |     case canceled
19 | 
20 |     /// Whether the prediction or training has reached a terminal state,
21 |     /// that is, it succeeded, failed, or was canceled.
22 |     public var terminated: Bool {
23 |         // Exhaustive on purpose: a newly added case must be classified here
24 |         // explicitly instead of silently counting as terminated.
25 |         switch self {
26 |         case .starting, .processing:
27 |             return false
28 |         case .succeeded, .failed, .canceled:
29 |             return true
30 |         }
31 |     }
32 | }
33 | 
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
1 | // swift-tools-version: 5.7
2 | // The swift-tools-version declares the minimum version of Swift required to build this package.
3 |
4 | import PackageDescription
5 |
6 | let package = Package(
7 | name: "Replicate",
8 | platforms: [
9 | .macOS(.v12),
10 | .iOS(.v15)
11 | ],
12 | products: [
13 | // Products define the executables and libraries a package produces, and make them visible to other packages.
14 | .library(
15 | name: "Replicate",
16 | targets: ["Replicate"]),
17 | ],
18 | dependencies: [],
19 | targets: [
20 | // Targets are the basic building blocks of a package. A target can define a module or a test suite.
21 | // Targets can depend on other targets in this package, and on products in packages this package depends on.
22 | .target(
23 | name: "Replicate"
24 | ),
25 | .testTarget(
26 | name: "ReplicateTests",
27 | dependencies: [
28 | "Replicate"
29 | ]),
30 | ]
31 | )
32 |
--------------------------------------------------------------------------------
/Tests/ReplicateTests/DateDecodingTests.swift:
--------------------------------------------------------------------------------
1 | import XCTest
2 | @testable import Replicate
3 |
4 | struct Value: Decodable {
5 | let date: Date
6 | }
7 |
8 | final class DateDecodingTests: XCTestCase {
9 | func testISO8601Timestamp() throws {
10 | let decoder = JSONDecoder()
11 | decoder.dateDecodingStrategy = .iso8601WithFractionalSeconds
12 |
13 | let timestamp = "2023-10-29T01:23:45Z"
14 | let json = #"{"date": "\#(timestamp)"}"#
15 | let value = try decoder.decode(Value.self, from: json.data(using: .utf8)!)
16 | XCTAssertEqual(value.date.timeIntervalSince1970, 1698542625)
17 | }
18 |
19 | func testISO8601TimestampWithFractionalSeconds() throws {
20 | let decoder = JSONDecoder()
21 | decoder.dateDecodingStrategy = .iso8601WithFractionalSeconds
22 |
23 | let timestamp = "2023-10-29T01:23:45.678900Z"
24 | let json = #"{"date": "\#(timestamp)"}"#
25 | let value = try decoder.decode(Value.self, from: json.data(using: .utf8)!)
26 | XCTAssertEqual(value.date.timeIntervalSince1970, 1698542625.678, accuracy: 0.1)
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/Tests/ReplicateTests/PredictionTests.swift:
--------------------------------------------------------------------------------
1 | import XCTest
2 | @testable import Replicate
3 |
4 | final class PredictionTests: XCTestCase {
5 | var client = Client.valid
6 |
7 | static override func setUp() {
8 | URLProtocol.registerClass(MockURLProtocol.self)
9 | }
10 |
11 | func testWait() async throws {
12 | let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
13 | var prediction = try await client.createPrediction(version: version, input: ["text": "Alice"])
14 | XCTAssertEqual(prediction.status, .starting)
15 |
16 | try await prediction.wait(with: client)
17 | XCTAssertEqual(prediction.status, .succeeded)
18 | }
19 |
20 | func testCancel() async throws {
21 | let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
22 | var prediction = try await client.createPrediction(version: version, input: ["text": "Alice"])
23 | XCTAssertEqual(prediction.status, .starting)
24 |
25 | try await prediction.cancel(with: client)
26 | XCTAssertEqual(prediction.status, .canceled)
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/Tests/ReplicateTests/URIEncodingTests.swift:
--------------------------------------------------------------------------------
1 | import XCTest
2 | @testable import Replicate
3 |
4 | final class URIEncodingTests: XCTestCase {
5 | func testDataURIEncoding() throws {
6 | let string = "Hello, World!"
7 | let base64Encoded = "SGVsbG8sIFdvcmxkIQ=="
8 | XCTAssertEqual(string.data(using: .utf8)?.base64EncodedString(), base64Encoded)
9 |
10 | let mimeType = "text/plain"
11 | XCTAssertEqual(string.data(using: .utf8)?.uriEncoded(mimeType: mimeType),
12 | "data:\(mimeType);base64,\(base64Encoded)")
13 | }
14 |
15 | func testISURIEncoded() throws {
16 | XCTAssertTrue(Data.isURIEncoded(string: "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="))
17 | XCTAssertFalse(Data.isURIEncoded(string: "Hello, World!"))
18 | XCTAssertFalse(Data.isURIEncoded(string: ""))
19 | }
20 |
21 | func testDecodeURIEncoded() throws {
22 | let encoded = "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="
23 | guard case let (mimeType, data)? = Data.decode(uriEncoded: encoded) else {
24 | return XCTFail("failed to decode data URI")
25 | }
26 |
27 | XCTAssertEqual(mimeType, "text/plain")
28 | XCTAssertEqual(String(data: data, encoding: .utf8), "Hello, World!")
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/Sources/Replicate/Extensions/Data+uriEncoded.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | private let dataURIPrefix = "data:"
4 |
5 | public extension Data {
6 | static func isURIEncoded(string: String) -> Bool {
7 | return string.hasPrefix(dataURIPrefix)
8 | }
9 |
10 | static func decode(uriEncoded string: String) -> (mimeType: String, data: Data)? {
11 | guard isURIEncoded(string: string) else {
12 | return nil
13 | }
14 |
15 | let components = string.dropFirst(dataURIPrefix.count).components(separatedBy: ",")
16 | guard components.count == 2,
17 | let dataScheme = components.first,
18 | let dataBase64 = components.last
19 | else {
20 | return nil
21 | }
22 |
23 | let mimeType: String
24 | if dataScheme.contains(";") {
25 | mimeType = dataScheme.components(separatedBy: ";").first ?? ""
26 | } else {
27 | mimeType = dataScheme
28 | }
29 |
30 | guard let decodedData = Data(base64Encoded: dataBase64) else {
31 | return nil
32 | }
33 |
34 | return (mimeType: mimeType, data: decodedData)
35 | }
36 |
37 | func uriEncoded(mimeType: String?) -> String {
38 | return "data:\(mimeType ?? "");base64,\(base64EncodedString())"
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/Sources/Replicate/Error.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | /// An error returned by the Replicate HTTP API
4 | public struct Error: Swift.Error, Hashable {
5 | /// A description of the error.
6 | public let detail: String
7 | }
8 |
9 | // MARK: - LocalizedError
10 |
11 | extension Error: LocalizedError {
12 | public var errorDescription: String? {
13 | return self.detail
14 | }
15 | }
16 |
17 | // MARK: - CustomStringConvertible
18 |
19 | extension Error: CustomStringConvertible {
20 | public var description: String {
21 | return self.detail
22 | }
23 | }
24 |
25 | // MARK: - Decodable
26 |
27 | extension Error: Decodable {
28 | private enum CodingKeys: String, CodingKey {
29 | case detail
30 | }
31 |
32 | public init(from decoder: Decoder) throws {
33 | if let container = try? decoder.container(keyedBy: CodingKeys.self) {
34 | self.detail = try container.decode(String.self, forKey: .detail)
35 | } else if let container = try? decoder.singleValueContainer() {
36 | self.detail = try container.decode(String.self)
37 | } else {
38 | let context = DecodingError.Context(codingPath: [], debugDescription: "unable to decode error")
39 | throw DecodingError.dataCorrupted(context)
40 | }
41 | }
42 | }
43 |
44 | // MARK: - Encodable
45 |
46 | extension Error: Encodable {}
47 |
--------------------------------------------------------------------------------
/Sources/Replicate/Account.swift:
--------------------------------------------------------------------------------
1 | import struct Foundation.URL
2 |
3 | /// A Replicate account.
4 | public struct Account: Hashable {
5 | /// The account type.
6 | public enum AccountType: String, CaseIterable, Hashable, Codable {
7 | /// A user.
8 | case user
9 | 
10 | /// An organization.
11 | case organization
12 | }
13 | 
14 | /// The type of account.
15 | var type: AccountType
16 | 
17 | /// The username of the account.
18 | var username: String
19 | 
20 | /// The name of the account.
21 | var name: String
22 | 
23 | /// The GitHub URL of the account.
24 | var githubURL: URL?
25 | }
26 |
27 | // MARK: - Identifiable
28 |
29 | extension Account: Identifiable {
30 | public typealias ID = String
31 |
32 | public var id: ID {
33 | return self.username
34 | }
35 | }
36 |
37 | // MARK: - CustomStringConvertible
38 |
39 | extension Account: CustomStringConvertible {
40 | public var description: String {
41 | return self.username
42 | }
43 | }
44 |
45 | // MARK: - Codable
46 |
47 | extension Account: Codable {
48 | public enum CodingKeys: String, CodingKey {
49 | case type
50 | case username
51 | case name
52 | case githubURL = "github_url"
53 | }
54 | }
55 |
56 | extension Account.AccountType: CustomStringConvertible {
57 | public var description: String {
58 | return self.rawValue
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | name: "CodeQL"
2 | 
3 | on:
4 |   push:
5 |     branches: ["main"]
6 |   pull_request:
7 |     branches: ["main"]
8 |   schedule:
9 |     - cron: "42 4 * * 0"
10 | 
11 | jobs:
12 |   analyze:
13 |     name: Analyze
14 |     runs-on: "macos-latest"
15 |     timeout-minutes: 120
16 |     permissions:
17 |       actions: read
18 |       contents: read
19 |       security-events: write
20 | 
21 |     steps:
22 |       - name: Checkout repository
23 |         uses: actions/checkout@v3
24 | 
25 |       # Initializes the CodeQL tools for scanning.
26 |       - name: Initialize CodeQL
27 |         uses: github/codeql-action/init@v2
28 |         with:
29 |           languages: swift
30 |           # If you wish to specify custom queries, you can do so here or in a config file.
31 |           # By default, queries listed here will override any specified in a config file.
32 |           # Prefix the list here with "+" to use these queries and those in the config file.
33 | 
34 |           # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
35 |           # queries: security-extended,security-and-quality
36 | 
37 |       - uses: actions/cache@v3
38 |         with:
39 |           path: .build
40 |           key: ${{ runner.os }}-spm-${{ hashFiles('**/Package.resolved') }}
41 |           restore-keys: |
42 |             ${{ runner.os }}-spm-
43 | 
44 |       - name: Build
45 |         run: swift build -v
46 | 
47 |       - name: Perform CodeQL Analysis
48 |         uses: github/codeql-action/analyze@v2
49 |         with:
50 |           category: "/language:swift"
51 |
--------------------------------------------------------------------------------
/Sources/Replicate/Predictable.swift:
--------------------------------------------------------------------------------
1 | /// A type that can make predictions with known inputs and outputs.
2 | public protocol Predictable {
3 | /// The type of the input to the model.
4 | associatedtype Input: Codable
5 |
6 | /// The type of the output from the model.
7 | associatedtype Output: Codable
8 |
9 | /// The ID of the model.
10 | static var modelID: Model.ID { get }
11 |
12 | /// The ID of the model version.
13 | static var versionID: Model.Version.ID { get }
14 | }
15 |
16 | // MARK: - Default Implementations
17 |
18 | extension Predictable {
19 | /// The type of prediction created by the model
20 | public typealias Prediction = Replicate.Prediction
21 | 
22 | /// Creates a prediction.
23 | ///
24 | /// - Parameters:
25 | ///   - client: The client used to make API requests.
26 | ///   - input: The input passed to the model.
27 | ///   - webhook: An optional webhook passed along when creating the prediction.
28 | ///   - wait:
29 | ///     If set to `true`, this method refreshes the prediction until it completes
30 | ///     (``Prediction/status`` is `.succeeded` or `.failed`).
31 | ///     By default, this is `false`, and this method returns the prediction
32 | ///     object encoded in the original creation response
33 | ///     (``Prediction/status`` is `.starting`).
34 | /// - Returns: The created prediction.
35 | public static func predict(
36 | with client: Client,
37 | input: Input,
38 | webhook: Webhook? = nil,
39 | wait: Bool = false
40 | ) async throws -> Prediction {
41 | var prediction = try await client.createPrediction(Prediction.self,
42 | version: Self.versionID,
43 | input: input,
44 | webhook: webhook)
45 | 
46 | if wait {
47 | try await prediction.wait(with: client)
48 | }
49 | 
50 | return prediction
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/Sources/Replicate/Webhook.swift:
--------------------------------------------------------------------------------
1 | import struct Foundation.URL
2 |
3 | /// A structure representing a webhook configuration.
4 | ///
5 | /// A webhook is an HTTPS URL that receives a `POST` request
6 | /// when a prediction or training generates new output, logs,
7 | /// or reaches a terminal state (succeeded / canceled / failed).
8 | ///
9 | /// Example usage:
10 | ///
11 | /// ```swift
12 | /// let webhook = Webhook(url: someURL, events: [.start, .completed])
13 | /// ```
14 | public struct Webhook {
15 |     /// The events that can trigger a webhook request.
16 |     public enum Event: String, Hashable, CaseIterable {
17 |         /// Occurs immediately on prediction start.
18 |         case start
19 | 
20 |         /// Occurs each time a prediction generates an output.
21 |         case output
22 | 
23 |         /// Occurs each time log output is generated by a prediction.
24 |         case logs
25 | 
26 |         /// Occurs when the prediction reaches a terminal state (succeeded / canceled / failed).
27 |         /// - SeeAlso: ``Status/terminated``
28 |         case completed
29 |     }
30 | 
31 |     /// The webhook URL.
32 |     public let url: URL
33 | 
34 |     /// A set of events that trigger the webhook.
35 |     public let events: Set<Event>
36 | 
37 |     /// Creates a new `Webhook` instance with the specified URL and events.
38 |     ///
39 |     /// By default, the webhook will be triggered by all event types.
40 |     /// You can change which events trigger the webhook by specifying a
41 |     /// sequence of events as `events` parameter.
42 |     ///
43 |     /// - Parameters:
44 |     ///   - url: The webhook URL.
45 |     ///   - events: A sequence of events that trigger the webhook.
46 |     ///             Defaults to all event types.
47 |     public init<S: Sequence>(
48 |         url: URL,
49 |         events: S = Event.allCases
50 |     ) where S.Element == Event
51 |     {
52 |         self.url = url
53 |         self.events = Set(events)
54 |     }
55 | }
56 |
57 | // MARK: - CustomStringConvertible
58 |
59 | extension Webhook.Event: CustomStringConvertible {
60 | public var description: String {
61 | return self.rawValue
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Sources/Replicate/Identifier.swift:
--------------------------------------------------------------------------------
1 | /// An identifier in the form of "{owner}/{name}:{version}".
2 | ///
3 | /// Identifiers are compared, ordered, and hashed on their ``rawValue``,
4 | /// ignoring case.
5 | public struct Identifier: Hashable {
6 |     /// The name of the user or organization that owns the model.
7 |     public let owner: String
8 | 
9 |     /// The name of the model.
10 |     public let name: String
11 | 
12 |     /// The version, or `nil` when the identifier has no version component.
13 |     let version: Model.Version.ID?
14 | }
15 | 
16 | // MARK: - Equatable, Hashable & Comparable
17 | 
18 | extension Identifier: Equatable, Comparable {
19 |     /// Returns whether two identifiers have the same raw value, ignoring case.
20 |     public static func == (lhs: Identifier, rhs: Identifier) -> Bool {
21 |         return lhs.rawValue.caseInsensitiveCompare(rhs.rawValue) == .orderedSame
22 |     }
23 | 
24 |     /// Orders identifiers by raw value, ignoring case.
25 |     public static func < (lhs: Identifier, rhs: Identifier) -> Bool {
26 |         return lhs.rawValue.caseInsensitiveCompare(rhs.rawValue) == .orderedAscending
27 |     }
28 | 
29 |     /// Hashes the case-folded raw value so that hashing agrees with `==`.
30 |     ///
31 |     /// Without this, the default hash would be case-sensitive,
32 |     /// so identifiers that compare equal under `==` could produce
33 |     /// different hashes and misbehave as `Set` members or `Dictionary` keys.
34 |     public func hash(into hasher: inout Hasher) {
35 |         hasher.combine(rawValue.folding(options: .caseInsensitive, locale: nil))
36 |     }
37 | }
38 | 
39 | // MARK: - RawRepresentable
40 | 
41 | extension Identifier: RawRepresentable {
42 |     public typealias RawValue = String
43 | 
44 |     /// Creates an identifier from a string in the form
45 |     /// "{owner}/{name}" or "{owner}/{name}:{version}".
46 |     ///
47 |     /// Returns `nil` unless the string splits on "/" into exactly two parts
48 |     /// and, when the second part contains ":", that part splits on ":"
49 |     /// into exactly two parts.
50 |     public init?(rawValue: RawValue) {
51 |         let components = rawValue.split(separator: "/")
52 |         guard components.count == 2 else { return nil }
53 | 
54 |         if components[1].contains(":") {
55 |             let nameAndVersion = components[1].split(separator: ":")
56 |             guard nameAndVersion.count == 2 else { return nil }
57 | 
58 |             self.init(owner: String(components[0]),
59 |                       name: String(nameAndVersion[0]),
60 |                       version: Model.Version.ID(nameAndVersion[1]))
61 |         } else {
62 |             self.init(owner: String(components[0]),
63 |                       name: String(components[1]),
64 |                       version: nil)
65 |         }
66 |     }
67 | 
68 |     /// The string form of the identifier:
69 |     /// "{owner}/{name}:{version}" when a version is set, otherwise "{owner}/{name}".
70 |     public var rawValue: String {
71 |         if let version = version {
72 |             return "\(owner)/\(name):\(version)"
73 |         } else {
74 |             return "\(owner)/\(name)"
75 |         }
76 |     }
77 | }
78 | 
79 | // MARK: - ExpressibleByStringLiteral
80 | 
81 | extension Identifier: ExpressibleByStringLiteral {
82 |     /// Creates an identifier from a string literal.
83 |     ///
84 |     /// Traps with `fatalError` if the literal is not a valid identifier,
85 |     /// since an invalid literal is a programmer error.
86 |     public init!(stringLiteral value: StringLiteralType) {
87 |         guard let identifier = Identifier(rawValue: value) else {
88 |             fatalError("Invalid Identifier string literal: \(value)")
89 |         }
90 | 
91 |         self = identifier
92 |     }
93 | }
94 | 
--------------------------------------------------------------------------------
/Sources/Replicate/Deployment.swift:
--------------------------------------------------------------------------------
1 | import struct Foundation.Date
2 |
3 | /// A deployment of a model on Replicate.
4 | public struct Deployment: Hashable {
5 | /// The owner of the deployment.
6 | public let owner: String
7 |
8 | /// The name of the deployment.
9 | public let name: String
10 |
11 | /// A release of a deployment.
12 | public struct Release: Hashable {
13 | /// The release number.
14 | let number: Int
15 |
16 | /// The model.
17 | let model: Model.ID
18 |
19 | /// The model version.
20 | let version: Model.Version.ID
21 |
22 | /// The time at which the release was created.
23 | let createdAt: Date
24 |
25 | /// The account that created the release
26 | let createdBy: Account
27 |
28 | /// The configuration of a deployment.
29 | public struct Configuration: Hashable {
30 | /// The configured hardware SKU.
31 | public let hardware: Hardware.ID
32 |
33 | /// A scaling configuration for a deployment.
34 | public struct Scaling: Hashable {
35 | /// The maximum number of instances.
36 | public let maxInstances: Int
37 |
38 | /// The minimum number of instances.
39 | public let minInstances: Int
40 | }
41 |
42 | /// The scaling configuration for the deployment.
43 | public let scaling: Scaling
44 | }
45 |
46 | /// The deployment configuration.
47 | public let configuration: Configuration
48 | }
49 |
50 | public let currentRelease: Release?
51 | }
52 |
53 | // MARK: - Identifiable
54 |
55 | extension Deployment: Identifiable {
56 | public typealias ID = String
57 | 
58 | /// The ID of the deployment, in the form "{owner}/{name}".
59 | public var id: ID { "\(owner)/\(name)" }
60 | }
61 |
62 | // MARK: - Codable
63 |
64 | extension Deployment: Codable {
65 | public enum CodingKeys: String, CodingKey {
66 | case owner
67 | case name
68 | case currentRelease = "current_release"
69 | }
70 | }
71 |
72 | extension Deployment.Release: Codable {
73 | public enum CodingKeys: String, CodingKey {
74 | case number
75 | case model
76 | case version
77 | case createdAt = "created_at"
78 | case createdBy = "created_by"
79 | case configuration
80 | }
81 | }
82 |
83 | extension Deployment.Release.Configuration: Codable {
84 | public enum CodingKeys: String, CodingKey {
85 | case hardware
86 | case scaling
87 | }
88 | }
89 |
90 | extension Deployment.Release.Configuration.Scaling: Codable {
91 | public enum CodingKeys: String, CodingKey {
92 | case minInstances = "min_instances"
93 | case maxInstances = "max_instances"
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/Tests/ReplicateTests/RetryPolicyTests.swift:
--------------------------------------------------------------------------------
1 | import XCTest
2 | @testable import Replicate
3 |
4 | final class RetryPolicyTests: XCTestCase {
5 | func testConstantBackoffStrategy() throws {
6 | let policy = Client.RetryPolicy(strategy: .constant(duration: 1.0,
7 | jitter: 0.0),
8 | timeout: nil,
9 | maximumInterval: nil,
10 | maximumRetries: 5)
11 |
12 | XCTAssertEqual(Array(policy), [1.0, 1.0, 1.0, 1.0, 1.0])
13 | }
14 |
15 | func testExponentialBackoffStrategy() throws {
16 | let policy = Client.RetryPolicy(strategy: .exponential(base: 1.0,
17 | multiplier: 2.0,
18 | jitter: 0.0),
19 | timeout: nil,
20 | maximumInterval: 30.0,
21 | maximumRetries: 7)
22 |
23 | XCTAssertEqual(Array(policy), [1.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0])
24 | }
25 |
26 | func testTimeoutWithDeadline() throws {
27 | let timeout: TimeInterval = 300.0
28 | let policy = Client.RetryPolicy(strategy: .constant(),
29 | timeout: timeout,
30 | maximumInterval: nil,
31 | maximumRetries: nil)
32 | let deadline = policy.makeIterator().deadline
33 |
34 | XCTAssertNotNil(deadline)
35 | XCTAssertLessThanOrEqual(deadline ?? .distantFuture, DispatchTime.now().advanced(by: .nanoseconds(Int(timeout * 1e+9))))
36 | }
37 |
38 | func testTimeoutWithoutDeadline() throws {
39 | let policy = Client.RetryPolicy(strategy: .constant(),
40 | timeout: nil,
41 | maximumInterval: nil,
42 | maximumRetries: nil)
43 | let deadline = policy.makeIterator().deadline
44 |
45 | XCTAssertNil(deadline)
46 | }
47 |
48 | func testMaximumInterval() throws {
49 | let maximumInterval: TimeInterval = 30.0
50 | let policy = Client.RetryPolicy(strategy: .exponential(),
51 | timeout: nil,
52 | maximumInterval: maximumInterval,
53 | maximumRetries: 10)
54 |
55 | XCTAssertGreaterThanOrEqual(maximumInterval, policy.max() ?? .greatestFiniteMagnitude)
56 | }
57 |
58 | func testMaximumRetries() throws {
59 | let maximumRetries: Int = 5
60 | let policy = Client.RetryPolicy(strategy: .constant(),
61 | timeout: nil,
62 | maximumInterval: nil,
63 | maximumRetries: maximumRetries)
64 |
65 | XCTAssertEqual(Array(policy).count, maximumRetries)
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/Sources/Replicate/Model.swift:
--------------------------------------------------------------------------------
1 | import struct Foundation.Date
2 | import struct Foundation.URL
3 |
4 | /// A machine learning model hosted on Replicate.
5 | public struct Model: Hashable {
6 | /// The visibility of the model.
7 | public enum Visibility: String, Hashable, Decodable {
8 | /// Public visibility.
9 | case `public`
10 |
11 | /// Private visibility.
12 | case `private`
13 | }
14 |
15 | /// A version of a model.
16 | public struct Version: Hashable {
17 | /// The ID of the version.
18 | public let id: ID
19 |
20 | /// When the version was created.
21 | public let createdAt: Date
22 |
23 | /// An OpenAPI description of the model inputs and outputs.
24 | public let openAPISchema: [String: Value]
25 | }
26 |
27 | /// A collection of models.
28 | public struct Collection: Hashable, Decodable {
29 | /// The name of the collection.
30 | public let name: String
31 | 
32 | /// The slug of the collection,
33 | /// like super-resolution or image-restoration.
34 | ///
35 | /// See <https://replicate.com/collections>
36 | public let slug: String
37 | 
38 | /// A description for the collection.
39 | public let description: String
40 | 
41 | /// A list of models in the collection.
42 | public let models: [Model]?
43 | }
44 |
45 | /// The name of the user or organization that owns the model.
46 | public let owner: String
47 |
48 | /// The name of the model.
49 | public let name: String
50 |
51 | /// A link to the model on Replicate.
52 | public let url: URL
53 |
54 | /// A link to the model source code on GitHub.
55 | public let githubURL: URL?
56 |
57 | /// A link to the model's paper.
58 | public let paperURL: URL?
59 |
60 | /// A link to the model's license.
61 | public let licenseURL: URL?
62 |
63 | /// A link to a cover image for the model.
64 | public let coverImageURL: URL?
65 |
66 | /// A description for the model.
67 | public let description: String?
68 |
69 | /// The visibility of the model.
70 | public let visibility: Visibility
71 |
72 | /// The latest version of the model, if any.
73 | public let latestVersion: Version?
74 |
75 | /// The number of times this model has been run.
76 | public let runCount: Int?
77 |
78 | public let defaultExample: AnyPrediction?
79 | }
80 |
81 | // MARK: - Identifiable
82 |
83 | extension Model: Identifiable {
84 | public typealias ID = String
85 |
86 | /// The ID of the model.
87 | public var id: ID { "\(owner)/\(name)" }
88 | }
89 |
90 | extension Model.Version: Identifiable {
91 | public typealias ID = String
92 | }
93 |
94 | extension Model.Collection: Identifiable {
95 | public typealias ID = String
96 |
97 | /// The ID of the model collection.
98 | public var id: String { slug }
99 | }
100 |
101 | // MARK: - Decodable
102 |
103 | extension Model: Decodable {
104 | private enum CodingKeys: String, CodingKey {
105 | case owner
106 | case name
107 | case url
108 | case githubURL = "github_url"
109 | case paperURL = "paper_url"
110 | case licenseURL = "license_url"
111 | case coverImageURL = "cover_image_url"
112 | case description
113 | case visibility
114 | case latestVersion = "latest_version"
115 | case runCount = "run_count"
116 | case defaultExample = "default_example"
117 | }
118 | }
119 |
120 | extension Model.Version: Decodable {
121 | private enum CodingKeys: String, CodingKey {
122 | case id
123 | case createdAt = "created_at"
124 | case openAPISchema = "openapi_schema"
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
/Sources/Replicate/Training.swift:
--------------------------------------------------------------------------------
1 | import struct Foundation.Date
2 | import struct Foundation.URL
3 | import struct Foundation.TimeInterval
4 | import struct Dispatch.DispatchTime
5 |
6 | /// A training with unspecified inputs and outputs.
7 | public typealias AnyTraining = Training<[String: Value]>
8 |
9 | /// A training made by a model hosted on Replicate.
10 | public struct Training<Input>: Identifiable where Input: Codable {
11 |     public typealias ID = String
12 | 
13 |     public struct Output: Hashable, Codable {
14 |         public var version: Model.Version.ID
15 |         public var weights: URL?
16 |     }
17 | 
18 |     /// Source for creating a training.
19 |     public enum Source: String, Codable {
20 |         /// The training was made on the web.
21 |         case web
22 | 
23 |         /// The training was made using the API.
24 |         case api
25 |     }
26 | 
27 |     /// Metrics for the training.
28 |     public struct Metrics: Hashable {
29 |         /// How long it took to create the training, in seconds.
30 |         public let predictTime: TimeInterval?
31 |     }
32 | 
33 |     /// The unique ID of the training.
34 |     /// Can be used to get a single training.
35 |     ///
36 |     /// - SeeAlso: ``Client/getTraining(id:)``
37 |     public let id: ID
38 | 
39 |     /// The version of the model used to create the training.
40 |     public let versionID: Model.Version.ID
41 | 
42 |     /// Where the training was made.
43 |     public let source: Source?
44 | 
45 |     /// The model's input as a JSON object.
46 |     ///
47 |     /// The input depends on what model you are running.
48 |     /// To see the available inputs,
49 |     /// click the "Run with API" tab on the model you are running.
50 |     /// For example,
51 |     /// [stability-ai/stable-diffusion](https://replicate.com/stability-ai/stable-diffusion)
52 |     /// takes `prompt` as an input.
53 |     ///
54 |     /// Files should be passed as data URLs or HTTP URLs.
55 |     public let input: Input
56 | 
57 |     /// The output of the model for the training, if completed successfully.
58 |     public let output: Output?
59 | 
60 |     /// The status of the training.
61 |     public let status: Status
62 | 
63 |     /// The error encountered during the training, if any.
64 |     public let error: Error?
65 | 
66 |     /// Logging output for the training.
67 |     public let logs: String?
68 | 
69 |     /// Metrics for the training.
70 |     public let metrics: Metrics?
71 | 
72 |     /// When the training was created.
73 |     public let createdAt: Date
74 | 
75 |     /// When the training was started
76 |     public let startedAt: Date?
77 | 
78 |     /// When the training was completed.
79 |     public let completedAt: Date?
80 | 
81 |     /// A convenience object that can be used to construct new API requests against the given training.
82 |     public let urls: [String: URL]
83 | 
84 |     // MARK: -
85 | 
86 |     /// Cancel the training.
87 |     ///
88 |     /// Replaces `self` with the training returned by the cancel request.
89 |     /// - Parameters:
90 |     ///   - client: The client used to make API requests.
91 |     public mutating func cancel(with client: Client) async throws {
92 |         self = try await client.cancelTraining(Self.self, id: id)
93 |     }
94 | }
94 |
95 | // MARK: - Codable
96 |
97 | extension Training: Codable {
98 | private enum CodingKeys: String, CodingKey { // maps stored properties to their JSON keys
99 | case id
100 | case versionID = "version" // JSON key differs from the property name
101 | case source
102 | case input
103 | case output
104 | case status
105 | case error
106 | case logs
107 | case metrics
108 | case createdAt = "created_at" // snake_case in JSON
109 | case startedAt = "started_at"
110 | case completedAt = "completed_at"
111 | case urls
112 | }
113 | }
114 |
115 | extension Training.Metrics: Codable {
116 | private enum CodingKeys: String, CodingKey {
117 | case predictTime = "predict_time"
118 | }
119 | /// Decodes the metrics; `predictTime` is read with `decodeIfPresent`, so it may end up `nil`.
120 | public init(from decoder: Decoder) throws {
121 | let values = try decoder.container(keyedBy: CodingKeys.self)
122 | predictTime = try values.decodeIfPresent(TimeInterval.self, forKey: .predictTime)
123 | }
124 | }
125 |
126 | // MARK: - Equatable, Hashable
127 |
128 | extension Training: Equatable where Input: Equatable {} // empty body: no hand-written ==
129 | extension Training: Hashable where Input: Hashable {} // empty body: no hand-written hash(into:)
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Replicate Swift client
2 |
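3 | [![Swift versions](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Freplicate%2Freplicate-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/replicate/replicate-swift)
4 | [![Platforms](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Freplicate%2Freplicate-swift%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/replicate/replicate-swift)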
5 |
6 | This is a Swift client for [Replicate].
7 | It lets you run models from your Swift code,
8 | and do various other things on Replicate.
9 |
10 | To learn how to use it,
11 | [take a look at our guide to building a SwiftUI app with Replicate](https://replicate.com/docs/get-started/swiftui).
12 |
13 | ## Usage
14 |
15 | Grab your API token from [replicate.com/account](https://replicate.com/account)
16 | and pass it to `Client(token:)`:
17 |
18 | ```swift
19 | import Foundation
20 | import Replicate
21 |
22 | let replicate = Replicate.Client(token: <#token#>)
23 | ```
24 |
25 | > [!WARNING]
26 | > Don't store secrets in code or any other resources bundled with your app.
27 | > Instead, fetch them from CloudKit or another server and store them in the keychain.
28 |
29 | You can run a model and get its output:
30 |
31 | ```swift
32 | let output = try await replicate.run(
33 | "stability-ai/stable-diffusion-3",
34 | ["prompt": "a 19th century portrait of a gentleman otter"]
35 | ) { prediction in
36 | // Print the prediction status after each update
37 | print(prediction.status)
38 | }
39 |
40 | print(output)
41 | // ["https://replicate.delivery/yhqm/bh9SsjWXY3pGKJyQzYjQlsZPzcNZ4EYOeEsPjFytc5TjYeNTA/R8_SD3_00001_.webp"]
42 | ```
43 |
44 | Or fetch a model by name and create a prediction against its latest version:
45 |
46 | ```swift
47 | let model = try await replicate.getModel("stability-ai/stable-diffusion-3")
48 | if let latestVersion = model.latestVersion {
49 | let prompt = "a 19th century portrait of a gentleman otter"
50 | let prediction = try await replicate.createPrediction(version: latestVersion.id,
51 | input: ["prompt": "\(prompt)"],
52 | wait: true)
53 | print(prediction.id)
54 | // "s654jhww3hrm60ch11v8t3zpkg"
55 | print(prediction.output)
56 | // ["https://replicate.delivery/yhqm/bh9SsjWXY3pGKJyQzYjQlsZPzcNZ4EYOeEsPjFytc5TjYeNTA/R8_SD3_00001_.webp"]
57 | }
58 | ```
59 |
60 | Some models,
61 | like [tencentarc/gfpgan](https://replicate.com/tencentarc/gfpgan),
62 | receive images as inputs.
63 | To run a model that takes a file input you can pass either
64 | a URL to a publicly accessible file on the Internet
65 | or use the `uriEncoded(mimeType:)` helper method to create
66 | a base64-encoded data URL from the contents of a local file.
67 |
68 | ```swift
69 | let model = try await replicate.getModel("tencentarc/gfpgan")
70 | if let latestVersion = model.latestVersion {
71 | let data = try! Data(contentsOf: URL(fileURLWithPath: "/path/to/image.jpg"))
72 | let mimeType = "image/jpeg"
73 | let prediction = try await replicate.createPrediction(version: latestVersion.id,
74 | input: ["img": "\(data.uriEncoded(mimeType: mimeType))"])
75 | print(prediction.output)
76 | // ["https://replicate.delivery/mgxm/85f53415-0dc7-4703-891f-1e6f912119ad/output.png"]
77 | }
78 | ```
79 |
80 | You can start a model and run it in the background:
81 |
82 | ```swift
83 | let model = try await replicate.getModel("kvfrans/clipdraw")
84 |
85 | let prompt = "watercolor painting of an underwater submarine"
86 | var prediction = try await replicate.createPrediction(version: model.latestVersion!.id,
87 | input: ["prompt": "\(prompt)"])
88 | print(prediction.status)
89 | // "starting"
90 |
91 | try await prediction.wait(with: replicate)
92 | print(prediction.status)
93 | // "succeeded"
94 | ```
95 |
96 | You can cancel a running prediction:
97 |
98 | ```swift
99 | let model = try await replicate.getModel("kvfrans/clipdraw")
100 |
101 | let prompt = "watercolor painting of an underwater submarine"
102 | var prediction = try await replicate.createPrediction(version: model.latestVersion!.id,
103 | input: ["prompt": "\(prompt)"])
104 | print(prediction.status)
105 | // "starting"
106 |
107 | try await prediction.cancel(with: replicate)
108 | print(prediction.status)
109 | // "canceled"
110 | ```
111 |
112 | You can list all the predictions you've run:
113 |
114 | ```swift
115 | var predictions: [Prediction] = []
116 | var cursor: Replicate.Client.Pagination.Cursor?
117 | let limit = 100
118 |
119 | repeat {
120 | let page = try await replicate.getPredictions(cursor: cursor)
121 | predictions.append(contentsOf: page.results)
122 | cursor = page.next
123 | } while predictions.count < limit && cursor != nil
124 | ```
125 |
126 | ## Adding `Replicate` as a Dependency
127 |
128 | To use the `Replicate` library in a Swift project,
129 | add it to the dependencies for your package and your target:
130 |
131 | ```swift
132 | let package = Package(
133 | // name, platforms, products, etc.
134 | dependencies: [
135 | // other dependencies
136 | .package(url: "https://github.com/replicate/replicate-swift", from: "0.24.0"),
137 | ],
138 | targets: [
139 | .target(name: "<target>", dependencies: [
140 | // other dependencies
141 | .product(name: "Replicate", package: "replicate-swift"),
142 | ]),
143 | // other targets
144 | ]
145 | )
146 | ```
147 |
148 | [Replicate]: https://replicate.com
149 |
--------------------------------------------------------------------------------
/Sources/Replicate/Value.swift:
--------------------------------------------------------------------------------
1 | import struct Foundation.Data
2 | import class Foundation.JSONDecoder
3 | import class Foundation.JSONEncoder
4 |
5 | /// A codable value.
6 | public enum Value: Hashable {
7 | case null
8 | case bool(Bool)
9 | case int(Int)
10 | case double(Double)
11 | case string(String)
12 | case data(mimeType: String? = nil, Data)
13 | case array([Value])
14 | case object([String: Value])
15 |
16 | /// Create a `Value` from a `Codable` value.
17 | /// - Parameter value: The codable value; a `Value` is used as-is.
18 | /// - Throws: Any error from JSON-encoding `value` or from decoding that JSON as a `Value`.
19 | public init<T: Codable>(_ value: T) throws {
20 | if let valueAsValue = value as? Value {
21 | self = valueAsValue
22 | } else {
23 | let data = try JSONEncoder().encode(value) // convert by round-tripping through JSON
24 | self = try JSONDecoder().decode(Value.self, from: data)
25 | }
26 | }
27 |
28 | /// Returns whether the value is `null`.
29 | public var isNull: Bool {
30 | return self == .null
31 | }
32 |
33 | /// Returns the `Bool` value if the value is a `bool`,
34 | /// otherwise returns `nil`.
35 | public var boolValue: Bool? {
36 | guard case let .bool(value) = self else { return nil }
37 | return value
38 | }
39 |
40 | /// Returns the `Int` value if the value is an `int`,
41 | /// otherwise returns `nil`.
42 | public var intValue: Int? {
43 | guard case let .int(value) = self else { return nil }
44 | return value
45 | }
46 |
47 | /// Returns the `Double` value if the value is a `double`,
48 | /// otherwise returns `nil`.
49 | public var doubleValue: Double? {
50 | guard case let .double(value) = self else { return nil }
51 | return value
52 | }
53 |
54 | /// Returns the `String` value if the value is a `string`,
55 | /// otherwise returns `nil`.
56 | public var stringValue: String? {
57 | guard case let .string(value) = self else { return nil }
58 | return value
59 | }
60 |
61 | /// Returns the data value and optional MIME type if the value is `data`,
62 | /// otherwise returns `nil`.
63 | public var dataValue: (mimeType: String?, Data)? {
64 | guard case let .data(mimeType: mimeType, data) = self else { return nil }
65 | return (mimeType: mimeType, data)
66 | }
67 |
68 | /// Returns the `[Value]` value if the value is an `array`,
69 | /// otherwise returns `nil`.
70 | public var arrayValue: [Value]? {
71 | guard case let .array(value) = self else { return nil }
72 | return value
73 | }
74 |
75 | /// Returns the `[String: Value]` value if the value is an `object`,
76 | /// otherwise returns `nil`.
77 | public var objectValue: [String: Value]? {
78 | guard case let .object(value) = self else { return nil }
79 | return value
80 | }
81 | }
82 |
83 | // MARK: - Codable
84 |
85 | extension Value: Codable {
86 | public init(from decoder: Decoder) throws { // tries each case in a fixed order; the first successful decode wins
87 | let container = try decoder.singleValueContainer()
88 |
89 | if container.decodeNil() {
90 | self = .null
91 | } else if let value = try? container.decode(Bool.self) {
92 | self = .bool(value)
93 | } else if let value = try? container.decode(Int.self) { // Int is attempted before Double
94 | self = .int(value)
95 | } else if let value = try? container.decode(Double.self) {
96 | self = .double(value)
97 | } else if let value = try? container.decode(String.self) {
98 | if Data.isURIEncoded(string: value), // helpers declared outside this listing
99 | case let (mimeType, data)? = Data.decode(uriEncoded: value)
100 | {
101 | self = .data(mimeType: mimeType, data) // both checks passed: keep the decoded bytes
102 | } else {
103 | self = .string(value) // otherwise keep the string unchanged
104 | }
105 | } else if let value = try? container.decode([Value].self) {
106 | self = .array(value)
107 | } else if let value = try? container.decode([String: Value].self) {
108 | self = .object(value)
109 | } else {
110 | throw DecodingError.dataCorruptedError(in: container, debugDescription: "Value type not found")
111 | }
112 | }
113 |
114 | public func encode(to encoder: Encoder) throws { // writes the payload into a single-value container
115 | var container = encoder.singleValueContainer()
116 |
117 | switch self {
118 | case .null:
119 | try container.encodeNil()
120 | case .bool(let value):
121 | try container.encode(value)
122 | case .int(let value):
123 | try container.encode(value)
124 | case .double(let value):
125 | try container.encode(value)
126 | case .string(let value):
127 | try container.encode(value)
128 | case let .data(mimeType, value):
129 | try container.encode(value.uriEncoded(mimeType: mimeType)) // data is written in its URI-encoded form
130 | case .array(let value):
131 | try container.encode(value)
132 | case .object(let value):
133 | try container.encode(value)
134 | }
135 | }
136 | }
137 |
138 | extension Value: CustomStringConvertible {
139 | /// Text form of the payload: empty for `null`, the URI-encoded form for `data`,
140 | /// and the wrapped value's own `description` for every other case.
141 | public var description: String {
142 | switch self {
143 | case .null: return ""
144 | case let .bool(flag): return flag.description
145 | case let .int(number): return number.description
146 | case let .double(number): return number.description
147 | case let .string(text): return text.description
148 | case let .data(mimeType, bytes): return bytes.uriEncoded(mimeType: mimeType)
149 | case let .array(items): return items.description
150 | case let .object(members): return members.description
151 | }
152 | }
153 | }
160 |
161 | // MARK: - ExpressibleByNilLiteral
162 |
163 | extension Value: ExpressibleByNilLiteral {
164 | /// Builds `.null` from a `nil` literal.
165 | public init(nilLiteral: ()) { self = .null }
166 | }
167 |
168 | // MARK: - ExpressibleByBooleanLiteral
169 |
170 | extension Value: ExpressibleByBooleanLiteral {
171 | /// Builds `.bool` from a boolean literal.
172 | public init(booleanLiteral flag: Bool) { self = .bool(flag) }
173 | }
174 |
175 | // MARK: - ExpressibleByIntegerLiteral
176 |
177 | extension Value: ExpressibleByIntegerLiteral {
178 | /// Builds `.int` from an integer literal.
179 | public init(integerLiteral number: Int) { self = .int(number) }
180 | }
181 |
182 | // MARK: - ExpressibleByFloatLiteral
183 |
184 | extension Value: ExpressibleByFloatLiteral {
185 | /// Builds `.double` from a floating-point literal.
186 | public init(floatLiteral number: Double) { self = .double(number) }
187 | }
188 |
189 | // MARK: - ExpressibleByStringLiteral
190 |
191 | extension Value: ExpressibleByStringLiteral {
192 | /// Builds `.string` from a string literal.
193 | public init(stringLiteral text: String) { self = .string(text) }
194 | }
195 |
196 | // MARK: - ExpressibleByArrayLiteral
197 |
198 | extension Value: ExpressibleByArrayLiteral {
199 | /// Builds `.array` from the elements of an array literal.
200 | public init(arrayLiteral elements: Value...) { self = .array(elements) }
201 | }
202 |
203 | // MARK: - ExpressibleByDictionaryLiteral
204 |
205 | extension Value: ExpressibleByDictionaryLiteral {
206 | /// Builds `.object` from a dictionary literal; if a key repeats, the last value wins.
207 | public init(dictionaryLiteral elements: (String, Value)...) {
208 | let members = elements.reduce(into: [String: Value]()) { table, pair in
209 | table[pair.0] = pair.1
210 | }
211 | self = .object(members)
212 | }
213 | }
220 |
221 | // MARK: - ExpressibleByStringInterpolation
222 |
223 | extension Value: ExpressibleByStringInterpolation {
224 | public struct StringInterpolation: StringInterpolationProtocol {
225 | var stringValue: String // text accumulated from literal segments and interpolations
226 | /// Starts with an empty string, reserving capacity from the two counts supplied.
227 | public init(literalCapacity: Int, interpolationCount: Int) {
228 | self.stringValue = ""
229 | self.stringValue.reserveCapacity(literalCapacity + interpolationCount)
230 | }
231 | /// Appends a literal segment unchanged.
232 | public mutating func appendLiteral(_ literal: String) {
233 | self.stringValue.append(literal)
234 | }
235 | /// Appends the `description` of an interpolated value.
236 | public mutating func appendInterpolation<T: CustomStringConvertible>(_ value: T) {
237 | self.stringValue.append(value.description)
238 | }
239 | }
240 | /// Builds `.string` from the text accumulated by the interpolation.
241 | public init(stringInterpolation: StringInterpolation) {
242 | self = .string(stringInterpolation.stringValue)
243 | }
244 | }
245 |
--------------------------------------------------------------------------------
/Sources/Replicate/Prediction.swift:
--------------------------------------------------------------------------------
1 | import struct Foundation.Date
2 | import struct Foundation.TimeInterval
3 | import struct Foundation.URL
4 | import class Foundation.Progress
5 | import struct Dispatch.DispatchTime
6 |
7 | import RegexBuilder
8 |
9 | /// A prediction with unspecified inputs and outputs.
10 | public typealias AnyPrediction = Prediction<[String: Value], Value>
11 |
12 | /// A prediction made by a model hosted on Replicate.
13 | public struct Prediction<Input, Output>: Identifiable where Input: Codable, Output: Codable {
14 | public typealias ID = String
15 |
16 | /// Source for creating a prediction.
17 | public enum Source: String, Codable {
18 | /// The prediction was made on the web.
19 | case web
20 |
21 | /// The prediction was made using the API.
22 | case api
23 | }
24 |
25 | /// Metrics for the prediction.
26 | public struct Metrics: Hashable {
27 | /// How long it took to create the prediction, in seconds.
28 | public let predictTime: TimeInterval?
29 | }
30 |
31 | /// The unique ID of the prediction.
32 | /// Can be used to get a single prediction.
33 | ///
34 | /// - SeeAlso: ``Client/getPrediction(id:)``
35 | public let id: ID
36 |
37 | /// The model used to create the prediction.
38 | public let modelID: Model.ID
39 |
40 | /// The version of the model used to create the prediction.
41 | public let versionID: Model.Version.ID
42 |
43 | /// Where the prediction was made.
44 | public let source: Source?
45 |
46 | /// The model's input as a JSON object.
47 | ///
48 | /// The input depends on what model you are running.
49 | /// To see the available inputs,
50 | /// click the "Run with API" tab on the model you are running.
51 | /// For example,
52 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3)
53 | /// takes `prompt` as an input.
54 | ///
55 | /// Files should be passed as data URLs or HTTP URLs.
56 | public let input: Input
57 |
58 | /// The output of the model for the prediction, if completed successfully.
59 | public let output: Output?
60 |
61 | /// The status of the prediction.
62 | public let status: Status
63 |
64 | /// The error encountered during the prediction, if any.
65 | public let error: Error?
66 |
67 | /// Logging output for the prediction.
68 | public let logs: String?
69 |
70 | /// Metrics for the prediction.
71 | public let metrics: Metrics?
72 |
73 | /// When the prediction was created.
74 | public let createdAt: Date
75 |
76 | /// When the prediction was started
77 | public let startedAt: Date?
78 |
79 | /// When the prediction was completed.
80 | public let completedAt: Date?
81 |
82 | /// A convenience object that can be used to construct new API requests against the given prediction.
83 | public let urls: [String: URL]
84 |
85 | // MARK: -
86 |
87 | @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
88 | public var progress: Progress? {
89 | guard let text = logs else { return nil }
90 |
91 | let pattern: Regex = #/^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/#
92 |
93 | let entries = text.split(separator: "\n")
94 | if entries.isEmpty { return nil }
95 |
96 | // Walk the log backwards so the most recent matching line wins.
97 | for entry in entries.reversed() {
98 | let candidate = String(entry).trimmingCharacters(in: .whitespaces)
99 | guard let found = try? pattern.firstMatch(in: candidate),
100 | let completed = Int64(found.output.2),
101 | let expected = Int64(found.output.3)
102 | else { continue }
103 | let result = Progress(totalUnitCount: expected)
104 | result.completedUnitCount = completed
105 | return result
106 | }
107 |
108 | return nil
109 | }
110 |
111 | // MARK: -
112 |
113 | /// Waits for the prediction to complete, replacing `self` with the result.
114 | ///
115 | /// - Parameters:
116 | /// - client:
117 | /// The client used to make API requests.
118 | /// - priority:
119 | /// The task priority.
120 | /// - updateHandler:
121 | /// A closure that executes with the updated prediction
122 | /// after each polling request to the API.
123 | /// If the prediction is in a terminal state
124 | /// (e.g. `succeeded`, `failed`, or `canceled`),
125 | /// it's returned immediately and the closure is not executed.
126 | /// Use this to provide feedback to the user
127 | /// about the progress of the prediction,
128 | /// or throw `CancellationError` to stop waiting
129 | /// for the prediction to finish.
130 | /// - Note: This method returns nothing; the resulting prediction is assigned to `self`.
131 | /// - Important:
132 | /// Returning early from the `updateHandler` closure
133 | /// doesn't cancel the prediction.
134 | /// To cancel the prediction,
135 | /// call ``cancel(with:)``.
136 | /// - Throws:
137 | /// ``CancellationError`` if the retrier yields no further delay or its deadline passes,
138 | /// or any error thrown from the `updateHandler` closure
139 | /// other than ``CancellationError``.
140 | public mutating func wait(
141 | with client: Client,
142 | priority: TaskPriority? = nil,
143 | updateHandler: @escaping (Self) throws -> Void = { _ in () }
144 | ) async throws {
145 | var retrier: Client.RetryPolicy.Retrier = client.retryPolicy.makeIterator() // new iterator for each call
146 | self = try await Self.wait(for: self,
147 | with: client,
148 | priority: priority,
149 | retrier: &retrier,
150 | updateHandler: updateHandler)
151 | }
152 |
153 | /// Waits for a prediction to complete and returns the updated prediction.
154 | ///
155 | /// - Parameters:
156 | /// - current:
157 | /// The prediction to wait for.
158 | /// - client:
159 | /// The client used to make API requests.
160 | /// - priority:
161 | /// The task priority.
162 | /// - retrier:
163 | /// An instance of the client retry policy.
164 | /// - updateHandler:
165 | /// A closure that executes with the updated prediction
166 | /// after each polling request to the API.
167 | /// If the prediction is in a terminal state
168 | /// (e.g. `succeeded`, `failed`, or `canceled`),
169 | /// it's returned immediately and the closure is not executed.
170 | /// Use this to provide feedback to the user
171 | /// about the progress of the prediction,
172 | /// or throw `CancellationError` to stop waiting
173 | /// for the prediction to finish.
174 | /// - Returns: The prediction once it is in a terminal state, or `current` if `updateHandler` throws `CancellationError`.
175 | /// - Important:
176 | /// Returning early from the `updateHandler` closure
177 | /// doesn't cancel the prediction.
178 | /// To cancel the prediction,
179 | /// call ``cancel(with:)``.
180 | /// - Throws:
181 | /// ``CancellationError`` if the retrier yields no further delay or its deadline passes (even if already in the past),
182 | /// any error from fetching the prediction, or any error thrown from the `updateHandler` closure
183 | /// other than ``CancellationError``.
184 | public static func wait(
185 | for current: Self,
186 | with client: Client,
187 | priority: TaskPriority? = nil,
188 | retrier: inout Client.RetryPolicy.Iterator,
189 | updateHandler: @escaping (Self) throws -> Void = { _ in () }
190 | ) async throws -> Self {
191 | guard !current.status.terminated else { return current }
192 | guard let delay = retrier.next() else { throw CancellationError() }
193 |
194 | let id = current.id
195 | let updated = try await withThrowingTaskGroup(of: Self.self) { group in
196 | group.addTask {
197 | try await Task.sleep(nanoseconds: UInt64(delay * 1e+9))
198 | return try await client.getPrediction(Self.self, id: id)
199 | }
200 |
201 | if let deadline = retrier.deadline {
202 | group.addTask {
203 | // UInt64 subtraction traps on underflow, so clamp to zero once the deadline has passed.
204 | let now = DispatchTime.now().uptimeNanoseconds
205 | let remaining = deadline.uptimeNanoseconds > now ? deadline.uptimeNanoseconds - now : 0
206 | try await Task.sleep(nanoseconds: remaining)
207 | throw CancellationError()
208 | }
209 | }
210 |
211 | let value = try await group.next()
212 | group.cancelAll()
213 |
214 | return value ?? current
215 | }
216 |
217 | guard !updated.status.terminated else { return updated }
218 |
219 | do {
220 | try updateHandler(updated)
221 | } catch is CancellationError {
222 | // The handler asked to stop waiting; any other error propagates to the caller.
223 | return current
224 | }
225 |
226 | return try await wait(for: updated,
227 | with: client,
228 | priority: priority,
229 | retrier: &retrier,
230 | updateHandler: updateHandler)
231 | }
232 |
233 | /// Cancels the prediction and replaces `self` with the prediction returned by the client.
234 | ///
235 | /// - Parameters:
236 | /// - client: The client used to make API requests.
237 | public mutating func cancel(with client: Client) async throws {
238 | self = try await client.cancelPrediction(Self.self, id: id)
239 | }
240 | }
241 |
242 | // MARK: - Codable
243 |
244 | extension Prediction: Codable {
245 | private enum CodingKeys: String, CodingKey { // maps stored properties to their JSON keys
246 | case id
247 | case modelID = "model" // JSON key differs from the property name
248 | case versionID = "version" // JSON key differs from the property name
249 | case source
250 | case input
251 | case output
252 | case status
253 | case error
254 | case logs
255 | case metrics
256 | case createdAt = "created_at" // snake_case in JSON
257 | case startedAt = "started_at"
258 | case completedAt = "completed_at"
259 | case urls
260 | }
261 | }
262 |
263 | extension Prediction.Metrics: Codable {
264 | private enum CodingKeys: String, CodingKey {
265 | case predictTime = "predict_time"
266 | }
267 | /// Decodes the metrics; `predictTime` is read with `decodeIfPresent`, so it may end up `nil`.
268 | public init(from decoder: Decoder) throws {
269 | let values = try decoder.container(keyedBy: CodingKeys.self)
270 | predictTime = try values.decodeIfPresent(TimeInterval.self, forKey: .predictTime)
271 | }
272 | }
273 |
274 | // MARK: - Equatable, Hashable
275 |
276 | extension Prediction: Equatable where Input: Equatable, Output: Equatable {} // empty body: no hand-written ==
277 | extension Prediction: Hashable where Input: Hashable, Output: Hashable {} // empty body: no hand-written hash(into:)
278 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright 2023, Replicate, Inc.
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/Tests/ReplicateTests/ClientTests.swift:
--------------------------------------------------------------------------------
1 | import XCTest
2 | @testable import Replicate
3 |
4 | /// Exercises the `Client` API against canned responses (see `MockURLProtocol`).
5 | final class ClientTests: XCTestCase {
6 |     var client = Client.valid
7 |
8 |     // Class-level setup: registers the mock protocol with the URL loading system.
9 |     static override func setUp() {
10 |         URLProtocol.registerClass(MockURLProtocol.self)
11 |     }
12 |
13 |     func testRunWithVersion() async throws {
14 |         let identifier: Identifier = "test/example:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
15 |         let output = try await client.run(identifier, input: ["text": "Alice"])
16 |         XCTAssertEqual(output, ["Hello, Alice!"])
17 |     }
18 |
19 |     func testRunWithModel() async throws {
20 |         let identifier: Identifier = "meta/llama-2-70b-chat"
21 |         let output = try await client.run(identifier, input: ["prompt": "Please write a haiku about llamas."])
22 |         XCTAssertEqual(output, ["I'm sorry, I'm afraid I can't do that"])
23 |     }
24 |
25 |     func testRunWithInvalidVersion() async throws {
26 |         let identifier: Identifier = "test/example:invalid"
27 |         do {
28 |             _ = try await client.run(identifier, input: ["text": "Alice"])
29 |             XCTFail("running an invalid version should fail")
30 |         } catch {
31 |             XCTAssertEqual(error.localizedDescription, "Invalid version")
32 |         }
33 |     }
34 |
35 |     func testCreatePredictionWithVersion() async throws {
36 |         let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
37 |         let prediction = try await client.createPrediction(version: version, input: ["text": "Alice"])
38 |         XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq")
39 |         XCTAssertEqual(prediction.versionID, version)
40 |         XCTAssertEqual(prediction.status, .starting)
41 |     }
42 |
43 |     func testCreatePredictionWithModel() async throws {
44 |         let model: Model.ID = "meta/llama-2-70b-chat"
45 |         let prediction = try await client.createPrediction(model: model, input: ["prompt": "Please write a poem about camelids"])
46 |         XCTAssertEqual(prediction.id, "heat2o3bzn3ahtr6bjfftvbaci")
47 |         XCTAssertEqual(prediction.modelID, model)
48 |         XCTAssertEqual(prediction.status, .starting)
49 |     }
50 |
51 |     func testCreatePredictionUsingDeployment() async throws {
52 |         let deployment: Deployment.ID = "replicate/deployment"
53 |         let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
54 |         let prediction = try await client.createPrediction(deployment: deployment, input: ["text": "Alice"])
55 |         XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq")
56 |         XCTAssertEqual(prediction.versionID, version)
57 |         XCTAssertEqual(prediction.status, .starting)
58 |     }
59 |
60 |     func testCreatePredictionWithInvalidVersion() async throws {
61 |         let version: Model.Version.ID = "invalid"
62 |         do {
63 |             _ = try await client.createPrediction(version: version, input: ["text": "Alice"])
64 |             XCTFail("creating a prediction for an invalid version should fail")
65 |         } catch {
66 |             XCTAssertEqual(error.localizedDescription, "Invalid version")
67 |         }
68 |     }
69 |
70 |     func testCreatePredictionAndWait() async throws {
71 |         let version: Model.Version.ID = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
72 |         var prediction = try await client.createPrediction(version: version, input: ["text": "Alice"])
73 |         try await prediction.wait(with: client)
74 |         XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq")
75 |         XCTAssertEqual(prediction.versionID, version)
76 |         XCTAssertEqual(prediction.status, .succeeded)
77 |     }
78 |
79 |     func testPredictionWaitWithStop() async throws {
80 |         var prediction = try await client.getPrediction(id: "r6bjfddngldkt2o3bzn3ahtaci")
81 |         var didUpdate = false
82 |
83 |         // Expectation: a CancellationError from the handler stops waiting without propagating to the caller.
84 |         try await prediction.wait(with: client) { _ in
85 |             didUpdate = true
86 |             throw CancellationError()
87 |         }
88 |
89 |         XCTAssertEqual(prediction.id, "r6bjfddngldkt2o3bzn3ahtaci")
90 |         XCTAssertEqual(prediction.status, .processing)
91 |         XCTAssertTrue(didUpdate)
92 |     }
93 |
94 |     func testPredictionWaitWithNonCancellationError() async throws {
95 |         struct CustomError: Swift.Error {}
96 |
97 |         var prediction = try await client.getPrediction(id: "r6bjfddngldkt2o3bzn3ahtaci")
98 |         do {
99 |             try await prediction.wait(with: client) { _ in
100 |                 throw CustomError()
101 |             }
102 |             XCTFail("Expected CustomError to be thrown")
103 |         } catch {
104 |             XCTAssertTrue(error is CustomError, "Expected CustomError, but got \(type(of: error))")
105 |         }
106 |
107 |         XCTAssertEqual(prediction.id, "r6bjfddngldkt2o3bzn3ahtaci")
108 |         XCTAssertEqual(prediction.status, .processing)
109 |     }
110 |
111 |     func testGetPrediction() async throws {
112 |         let prediction = try await client.getPrediction(id: "ufawqhfynnddngldkgtslldrkq")
113 |         XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq")
114 |         XCTAssertEqual(prediction.versionID, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
115 |         XCTAssertEqual(prediction.source, .web)
116 |         XCTAssertEqual(prediction.status, .succeeded)
117 |         XCTAssertEqual(prediction.createdAt.timeIntervalSinceReferenceDate, 672703986.224, accuracy: 1)
118 |         XCTAssertEqual(prediction.urls["cancel"]?.absoluteString, "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel")
119 |
120 |         if #available(macOS 13.0, *) {
121 |             XCTAssertEqual(prediction.progress?.completedUnitCount, 5)
122 |             XCTAssertEqual(prediction.progress?.totalUnitCount, 5)
123 |         }
124 |     }
125 |
126 |     func testCancelPrediction() async throws {
127 |         let prediction = try await client.cancelPrediction(id: "ufawqhfynnddngldkgtslldrkq")
128 |         XCTAssertEqual(prediction.id, "ufawqhfynnddngldkgtslldrkq")
129 |     }
130 |
131 |     func testGetPredictions() async throws {
132 |         let predictions = try await client.listPredictions()
133 |         XCTAssertNil(predictions.previous)
134 |         XCTAssertEqual(predictions.next, "cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw")
135 |         XCTAssertEqual(predictions.results.count, 1)
136 |     }
137 |
138 |     func testListModels() async throws {
139 |         let models = try await client.listModels()
140 |         XCTAssertEqual(models.results.count, 1)
141 |         XCTAssertEqual(models.results.first?.owner, "replicate")
142 |         XCTAssertEqual(models.results.first?.name, "hello-world")
143 |     }
144 |
145 |     func testGetModel() async throws {
146 |         let model = try await client.getModel("replicate/hello-world")
147 |         XCTAssertEqual(model.owner, "replicate")
148 |         XCTAssertEqual(model.name, "hello-world")
149 |     }
150 |
151 |     func testGetModelVersions() async throws {
152 |         let versions = try await client.listModelVersions("replicate/hello-world")
153 |         XCTAssertNil(versions.previous)
154 |         XCTAssertNil(versions.next)
155 |         XCTAssertEqual(versions.results.count, 2)
156 |         XCTAssertEqual(versions.results.first?.id, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
157 |     }
158 |
159 |     func testGetModelVersion() async throws {
160 |         let version = try await client.getModelVersion("replicate/hello-world", version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
161 |         XCTAssertEqual(version.id, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
162 |     }
163 |
164 |     func testCreateModel() async throws {
165 |         let model = try await client.createModel(owner: "replicate", name: "hello-world", visibility: .public, hardware: "cpu")
166 |         XCTAssertEqual(model.owner, "replicate")
167 |         XCTAssertEqual(model.name, "hello-world")
168 |     }
169 |
170 |     func testListModelCollections() async throws {
171 |         let collections = try await client.listModelCollections()
172 |         XCTAssertEqual(collections.results.count, 1)
173 |         let first = try XCTUnwrap(collections.results.first) // fail the test instead of trapping on an empty list
174 |         XCTAssertEqual(first.slug, "super-resolution")
175 |         XCTAssertNil(first.models)
176 |
177 |         let collection = try await client.getModelCollection(first.slug)
178 |         XCTAssertEqual(collection.slug, "super-resolution")
179 |         XCTAssertEqual(collection.models, [])
180 |     }
181 |
182 |     func testGetModelCollection() async throws {
183 |         let collection = try await client.getModelCollection("super-resolution")
184 |         XCTAssertEqual(collection.slug, "super-resolution")
185 |     }
186 |
187 |     func testCreateTraining() async throws {
188 |         let base: Model.ID = "example/base"
189 |         let version: Model.Version.ID = "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a"
190 |         let destination: Model.ID = "my/fork"
191 |         let training = try await client.createTraining(model: base, version: version, destination: destination, input: ["data": "..."])
192 |         XCTAssertEqual(training.id, "zz4ibbonubfz7carwiefibzgga")
193 |         XCTAssertEqual(training.versionID, version)
194 |         XCTAssertEqual(training.status, .starting)
195 |     }
196 |
197 |     func testGetTraining() async throws {
198 |         let training = try await client.getTraining(id: "zz4ibbonubfz7carwiefibzgga")
199 |         XCTAssertEqual(training.id, "zz4ibbonubfz7carwiefibzgga")
200 |         XCTAssertEqual(training.versionID, "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a")
201 |         XCTAssertEqual(training.source, .web)
202 |         XCTAssertEqual(training.status, .succeeded)
203 |         XCTAssertEqual(training.createdAt.timeIntervalSinceReferenceDate, 703980786.224, accuracy: 1)
204 |         XCTAssertEqual(training.urls["cancel"]?.absoluteString, "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel")
205 |     }
206 |
207 |     func testCancelTraining() async throws {
208 |         let training = try await client.cancelTraining(id: "zz4ibbonubfz7carwiefibzgga")
209 |         XCTAssertEqual(training.id, "zz4ibbonubfz7carwiefibzgga")
210 |     }
211 |
212 |     func testGetTrainings() async throws {
213 |         let trainings = try await client.listTrainings()
214 |         XCTAssertNil(trainings.previous)
215 |         XCTAssertEqual(trainings.next, "g5FWfcbO0EdVeR27rkXr0Z6tI0MjrW34ZejxnGzDeND3phpWWsyMGCQD")
216 |         XCTAssertEqual(trainings.results.count, 1)
217 |     }
218 |
219 |     func testListHardware() async throws {
220 |         let hardware = try await client.listHardware()
221 |         XCTAssertGreaterThan(hardware.count, 1)
222 |         XCTAssertEqual(hardware.first?.name, "CPU")
223 |         XCTAssertEqual(hardware.first?.sku, "cpu")
224 |     }
225 |
226 |     func testCurrentAccount() async throws {
227 |         let account = try await client.getCurrentAccount()
228 |         XCTAssertEqual(account.type, .organization)
229 |         XCTAssertEqual(account.username, "replicate")
230 |         XCTAssertEqual(account.name, "Replicate")
231 |         XCTAssertEqual(account.githubURL?.absoluteString, "https://github.com/replicate")
232 |     }
233 |
234 |     func testGetDeployment() async throws {
235 |         let deployment = try await client.getDeployment("replicate/my-app-image-generator")
236 |         XCTAssertEqual(deployment.owner, "replicate")
237 |         XCTAssertEqual(deployment.name, "my-app-image-generator")
238 |         let release = try XCTUnwrap(deployment.currentRelease) // a missing release fails the test rather than crashing
239 |         XCTAssertEqual(release.number, 1)
240 |         XCTAssertEqual(release.model, "stability-ai/sdxl")
241 |         XCTAssertEqual(release.version, "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf")
242 |         XCTAssertEqual(release.createdAt.timeIntervalSinceReferenceDate, 729707577.01, accuracy: 1)
243 |         XCTAssertEqual(release.createdBy.type, .organization)
244 |         XCTAssertEqual(release.createdBy.username, "replicate")
245 |         XCTAssertEqual(release.createdBy.name, "Replicate, Inc.")
246 |         XCTAssertEqual(release.createdBy.githubURL?.absoluteString, "https://github.com/replicate")
247 |         XCTAssertEqual(release.configuration.hardware, "gpu-t4")
248 |         XCTAssertEqual(release.configuration.scaling.minInstances, 1)
249 |         XCTAssertEqual(release.configuration.scaling.maxInstances, 5)
250 |     }
251 |
252 |     func testCustomBaseURL() async throws {
253 |         let client = Client(baseURLString: "https://v1.replicate.proxy", token: MockURLProtocol.validToken).mocked
254 |         let collection = try await client.getModelCollection("super-resolution")
255 |         XCTAssertEqual(collection.slug, "super-resolution")
256 |     }
257 |
258 |     func testInvalidToken() async throws {
259 |         do {
260 |             _ = try await Client.invalid.listPredictions()
261 |             XCTFail("requests with an invalid token should fail")
262 |         } catch {
263 |             guard let error = error as? Replicate.Error else {
264 |                 return XCTFail("invalid error")
265 |             }
266 |             XCTAssertEqual(error.detail, "Invalid token.")
267 |         }
268 |     }
269 |
270 |     func testUnauthenticated() async throws {
271 |         do {
272 |             _ = try await Client.unauthenticated.listPredictions()
273 |             XCTFail("unauthenticated requests should fail")
274 |         } catch {
275 |             guard let error = error as? Replicate.Error else {
276 |                 return XCTFail("invalid error")
277 |             }
278 |             XCTAssertEqual(error.detail, "Authentication credentials were not provided.")
279 |         }
280 |     }
281 |
282 |     func testSearchModels() async throws {
283 |         let models = try await client.searchModels(query: "greeter")
284 |         XCTAssertEqual(models.results.count, 1)
285 |         XCTAssertEqual(models.results.first?.owner, "replicate") // `.first?` cannot trap on an empty result
286 |         XCTAssertEqual(models.results.first?.name, "hello-world")
287 |     }
288 | }
289 |
--------------------------------------------------------------------------------
/Tests/ReplicateTests/Helpers/MockURLProtocol.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | #if canImport(FoundationNetworking)
4 | import FoundationNetworking
5 | #endif
6 |
7 | @testable import Replicate
8 |
9 | class MockURLProtocol: URLProtocol {
10 | static let validToken = ""
11 |
12 |     /// Accepts every `URLSessionTask` unconditionally.
13 |     /// No filtering happens here; requests are told apart later, in `startLoading()`.
14 |     override class func canInit(with task: URLSessionTask) -> Bool { true }
15 |
16 |     /// Accepts every `URLRequest` unconditionally.
17 |     /// No filtering happens here; requests are told apart later, in `startLoading()`.
18 |     override class func canInit(with request: URLRequest) -> Bool { true }
19 |
20 |     /// Returns `request` exactly as received.
21 |     /// No canonicalization is applied, so `startLoading()` sees the caller's original method, URL and headers.
22 |     override class func canonicalRequest(for request: URLRequest) -> URLRequest { request }
23 |
24 | override func startLoading() {
25 | let statusCode: Int
26 | let json: String
27 |
28 | switch request.value(forHTTPHeaderField: "Authorization") {
29 | case "Bearer \(Self.validToken)":
30 | switch (request.httpMethod, request.url?.absoluteString) {
31 | case ("GET", "https://api.replicate.com/v1/account"?):
32 | statusCode = 200
33 | json = #"""
34 | {
35 | "type": "organization",
36 | "username": "replicate",
37 | "name": "Replicate",
38 | "github_url": "https://github.com/replicate"
39 | }
40 | """#
41 | case ("GET", "https://api.replicate.com/v1/predictions"?):
42 | statusCode = 200
43 | json = #"""
44 | {
45 | "previous": null,
46 | "next": "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
47 | "results": [
48 | {
49 | "id": "ufawqhfynnddngldkgtslldrkq",
50 | "model": "replicate/hello-world",
51 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
52 | "urls": {
53 | "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
54 | "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"
55 | },
56 | "created_at": "2022-04-26T22:13:06.224088Z",
57 | "completed_at": "2022-04-26T22:13:06.580379Z",
58 | "source": "web",
59 | "status": "starting",
60 | "input": {
61 | "text": "Alice"
62 | },
63 | "output": null,
64 | "error": null,
65 | "logs": null,
66 | "metrics": {}
67 | }
68 | ]
69 | }
70 | """#
71 | case ("POST", "https://api.replicate.com/v1/predictions"?),
72 | ("POST", "https://api.replicate.com/v1/deployments/replicate/deployment/predictions"?):
73 |
74 | if let body = request.json,
75 | body["version"] as? String == "invalid"
76 | {
77 | statusCode = 400
78 | json = #"{ "detail" : "Invalid version" }"#
79 | } else {
80 | statusCode = 201
81 |
82 | json = #"""
83 | {
84 | "id": "ufawqhfynnddngldkgtslldrkq",
85 | "model": "replicate/hello-world",
86 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
87 | "urls": {
88 | "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
89 | "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"
90 | },
91 | "created_at": "2022-04-26T22:13:06.224088Z",
92 | "completed_at": "2022-04-26T22:13:06.580379Z",
93 | "source": "web",
94 | "status": "starting",
95 | "input": {
96 | "text": "Alice"
97 | },
98 | "output": null,
99 | "error": null,
100 | "logs": null,
101 | "metrics": {}
102 | }
103 | """#
104 | }
105 | case ("POST", "https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions"?):
106 | statusCode = 201
107 | json = #"""
108 | {
109 | "id": "heat2o3bzn3ahtr6bjfftvbaci",
110 | "model": "meta/llama-2-70b-chat",
111 | "version": "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
112 | "input": {
113 | "prompt": "Please write a haiku about llamas."
114 | },
115 | "logs": "",
116 | "error": null,
117 | "status": "starting",
118 | "created_at": "2023-11-27T13:35:45.99397566Z",
119 | "urls": {
120 | "cancel": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel",
121 | "get": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"
122 | }
123 | }
124 | """#
125 | case ("GET", "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq"?):
126 | statusCode = 200
127 | json = #"""
128 | {
129 | "id": "ufawqhfynnddngldkgtslldrkq",
130 | "model": "replicate/hello-world",
131 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
132 | "urls": {
133 | "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
134 | "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"
135 | },
136 | "created_at": "2022-04-26T22:13:06.224088Z",
137 | "completed_at": "2022-04-26T22:15:06.224088Z",
138 | "source": "web",
139 | "status": "succeeded",
140 | "input": {
141 | "text": "Alice"
142 | },
143 | "output": ["Hello, Alice!"],
144 | "error": null,
145 | "logs": "Using seed: 12345,\n0%| | 0/5 [00:00, ?it/s]\n20%|██ | 1/5 [00:00<00:01, 21.38it/s]\n40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]\n60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]\n 80%|████████ | 4/5 [00:01<00:00, 22.86it/s]\n100%|██████████| 5/5 [00:02<00:00, 22.26it/s]",
146 | "metrics": {
147 | "predict_time": 10.0
148 | }
149 | }
150 | """#
151 | case ("GET", "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"?):
152 | statusCode = 200
153 | json = #"""
154 | {
155 | "id": "heat2o3bzn3ahtr6bjfftvbaci",
156 | "model": "meta/llama-2-70b-chat",
157 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
158 | "urls": {
159 | "get": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci",
160 | "cancel": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel"
161 | },
162 | "created_at": "2022-04-26T22:13:06.224088Z",
163 | "completed_at": "2022-04-26T22:15:06.224088Z",
164 | "source": "web",
165 | "status": "succeeded",
166 | "input": {
167 | "prompt": "Please write a haiku about llamas."
168 | },
169 | "output": ["I'm sorry, I'm afraid I can't do that"],
170 | "error": null,
171 | "logs": "",
172 | "metrics": {
173 | "predict_time": 1.0
174 | }
175 | }
176 | """#
177 | case ("GET", "https://api.replicate.com/v1/predictions/r6bjfddngldkt2o3bzn3ahtaci"?):
178 | statusCode = 200
179 | json = #"""
180 | {
181 | "id": "r6bjfddngldkt2o3bzn3ahtaci",
182 | "model": "meta/llama-2-70b-chat",
183 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
184 | "urls": {
185 | "get": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci",
186 | "cancel": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel"
187 | },
188 | "created_at": "2022-04-26T22:13:06.224088Z",
189 | "completed_at": "2022-04-26T22:15:06.224088Z",
190 | "source": "web",
191 | "status": "processing",
192 | "input": {
193 | "prompt": "Please write a haiku about llamas."
194 | },
195 | "output": null,
196 | "error": null,
197 | "logs": "",
198 | "metrics": {
199 | "predict_time": 1.0
200 | }
201 | }
202 | """#
203 | case ("POST", "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"?):
204 | statusCode = 200
205 | json = #"""
206 | {
207 | "id": "ufawqhfynnddngldkgtslldrkq",
208 | "model": "replicate/hello-world",
209 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
210 | "urls": {
211 | "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
212 | "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"
213 | },
214 | "created_at": "2022-04-26T22:13:06.224088Z",
215 | "completed_at": "2022-04-26T22:15:06.224088Z",
216 | "source": "web",
217 | "status": "canceled",
218 | "input": {
219 | "text": "Alice"
220 | },
221 | "output": null,
222 | "error": null,
223 | "logs": "",
224 | "metrics": {}
225 | }
226 | """#
227 | case ("GET", "https://api.replicate.com/v1/hardware"?):
228 | statusCode = 200
229 | json = #"""
230 | [
231 | { "name": "CPU", "sku": "cpu" },
232 | { "name": "Nvidia T4 GPU", "sku": "gpu-t4" },
233 | { "name": "Nvidia A40 GPU", "sku": "gpu-a40-small" },
234 | { "name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" }
235 | ]
236 | """#
237 | case ("GET", "https://api.replicate.com/v1/models"?):
238 | statusCode = 200
239 | json = #"""
240 | {
241 | "next": null,
242 | "previous": null,
243 | "results": [
244 | {
245 | "url": "https://replicate.com/replicate/hello-world",
246 | "owner": "replicate",
247 | "name": "hello-world",
248 | "description": "A tiny model that says hello",
249 | "visibility": "public",
250 | "github_url": "https://github.com/replicate/cog-examples",
251 | "paper_url": null,
252 | "license_url": null,
253 | "run_count": 930512,
254 | "cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/9c1f748e-a9fc-4cfd-a497-68262ee6151a/replicate-prediction-caujujsgrng7.png",
255 | "default_example": {
256 | "completed_at": "2022-04-26T19:30:10.926419Z",
257 | "created_at": "2022-04-26T19:30:10.761396Z",
258 | "error": null,
259 | "id": "3s2vyrb3pfblrnyp2smdsxxjvu",
260 | "input": {
261 | "text": "Alice"
262 | },
263 | "logs": null,
264 | "metrics": {
265 | "predict_time": 2e-06
266 | },
267 | "output": "hello Alice",
268 | "started_at": "2022-04-26T19:30:10.926417Z",
269 | "status": "succeeded",
270 | "urls": {
271 | "get": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu",
272 | "cancel": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu/cancel"
273 | },
274 | "model": "replicate/hello-world",
275 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
276 | "webhook_completed": null
277 | },
278 | "latest_version": {
279 | "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
280 | "created_at": "2022-04-26T19:29:04.418669Z",
281 | "cog_version": "0.3.0",
282 | "openapi_schema": {
283 | "openapi": "3.1.0",
284 | "components": {
285 | "schemas": {
286 | "Input": {
287 | "type": "object",
288 | "title": "Input",
289 | "required": [
290 | "text"
291 | ],
292 | "properties": {
293 | "text": {
294 | "type": "string",
295 | "title": "Text",
296 | "x-order": 0,
297 | "description": "Text to prefix with 'hello '"
298 | }
299 | }
300 | },
301 | "Output": {
302 | "type": "string",
303 | "title": "Output"
304 | }
305 | }
306 | }
307 | }
308 | }
309 | }
310 | ]
311 | }
312 | """#
313 | case ("QUERY", "https://api.replicate.com/v1/models"?):
314 | statusCode = 200
315 | json = #"""
316 | {
317 | "next": null,
318 | "previous": null,
319 | "results": [
320 | {
321 | "url": "https://replicate.com/replicate/hello-world",
322 | "owner": "replicate",
323 | "name": "hello-world",
324 | "description": "A tiny model that says hello",
325 | "visibility": "public",
326 | "github_url": "https://github.com/replicate/cog-examples",
327 | "paper_url": null,
328 | "license_url": null,
329 | "run_count": 930512,
330 | "cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/9c1f748e-a9fc-4cfd-a497-68262ee6151a/replicate-prediction-caujujsgrng7.png",
331 | "default_example": {
332 | "completed_at": "2022-04-26T19:30:10.926419Z",
333 | "created_at": "2022-04-26T19:30:10.761396Z",
334 | "error": null,
335 | "id": "3s2vyrb3pfblrnyp2smdsxxjvu",
336 | "input": {
337 | "text": "Alice"
338 | },
339 | "logs": null,
340 | "metrics": {
341 | "predict_time": 2e-06
342 | },
343 | "output": "hello Alice",
344 | "started_at": "2022-04-26T19:30:10.926417Z",
345 | "status": "succeeded",
346 | "urls": {
347 | "get": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu",
348 | "cancel": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu/cancel"
349 | },
350 | "model": "replicate/hello-world",
351 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
352 | "webhook_completed": null
353 | },
354 | "latest_version": {
355 | "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
356 | "created_at": "2022-04-26T19:29:04.418669Z",
357 | "cog_version": "0.3.0",
358 | "openapi_schema": {
359 | "openapi": "3.1.0",
360 | "components": {
361 | "schemas": {
362 | "Input": {
363 | "type": "object",
364 | "title": "Input",
365 | "required": [
366 | "text"
367 | ],
368 | "properties": {
369 | "text": {
370 | "type": "string",
371 | "title": "Text",
372 | "x-order": 0,
373 | "description": "Text to prefix with 'hello '"
374 | }
375 | }
376 | },
377 | "Output": {
378 | "type": "string",
379 | "title": "Output"
380 | }
381 | }
382 | }
383 | }
384 | }
385 | }
386 | ]
387 | }
388 | """#
389 | case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world"?):
390 | statusCode = 200
391 | json = #"""
392 | {
393 | "url": "https://replicate.com/replicate/hello-world",
394 | "owner": "replicate",
395 | "name": "hello-world",
396 | "description": "A tiny model that says hello",
397 | "visibility": "public",
398 | "github_url": "https://github.com/replicate/cog-examples",
399 | "paper_url": null,
400 | "license_url": null,
401 | "run_count": 930512,
402 | "cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/9c1f748e-a9fc-4cfd-a497-68262ee6151a/replicate-prediction-caujujsgrng7.png",
403 | "default_example": {
404 | "completed_at": "2022-04-26T19:30:10.926419Z",
405 | "created_at": "2022-04-26T19:30:10.761396Z",
406 | "error": null,
407 | "id": "3s2vyrb3pfblrnyp2smdsxxjvu",
408 | "input": {
409 | "text": "Alice"
410 | },
411 | "logs": null,
412 | "metrics": {
413 | "predict_time": 2e-06
414 | },
415 | "output": "hello Alice",
416 | "started_at": "2022-04-26T19:30:10.926417Z",
417 | "status": "succeeded",
418 | "urls": {
419 | "get": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu",
420 | "cancel": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu/cancel"
421 | },
422 | "model": "replicate/hello-world",
423 | "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
424 | "webhook_completed": null
425 | },
426 | "latest_version": {
427 | "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
428 | "created_at": "2022-04-26T19:29:04.418669Z",
429 | "cog_version": "0.3.0",
430 | "openapi_schema": {
431 | "info": {
432 | "title": "Cog",
433 | "version": "0.1.0"
434 | },
435 | "paths": {
436 | "/": {
437 | "get": {
438 | "summary": "Root",
439 | "responses": {
440 | "200": {
441 | "content": {
442 | "application/json": {
443 | "schema": {}
444 | }
445 | },
446 | "description": "Successful Response"
447 | }
448 | },
449 | "operationId": "root__get"
450 | }
451 | },
452 | "/predictions": {
453 | "post": {
454 | "summary": "Predict",
455 | "responses": {
456 | "200": {
457 | "content": {
458 | "application/json": {
459 | "schema": {
460 | "$ref": "#/components/schemas/Response"
461 | }
462 | }
463 | },
464 | "description": "Successful Response"
465 | },
466 | "422": {
467 | "content": {
468 | "application/json": {
469 | "schema": {
470 | "$ref": "#/components/schemas/HTTPValidationError"
471 | }
472 | }
473 | },
474 | "description": "Validation Error"
475 | }
476 | },
477 | "description": "Run a single prediction on the model",
478 | "operationId": "predict_predictions_post",
479 | "requestBody": {
480 | "content": {
481 | "application/json": {
482 | "schema": {
483 | "$ref": "#/components/schemas/Request"
484 | }
485 | }
486 | }
487 | }
488 | }
489 | }
490 | },
491 | "openapi": "3.0.2",
492 | "components": {
493 | "schemas": {
494 | "Input": {
495 | "type": "object",
496 | "title": "Input",
497 | "required": [
498 | "text"
499 | ],
500 | "properties": {
501 | "text": {
502 | "type": "string",
503 | "title": "Text",
504 | "x-order": 0,
505 | "description": "Text to prefix with 'hello '"
506 | }
507 | }
508 | },
509 | "Output": {
510 | "type": "string",
511 | "title": "Output"
512 | },
513 | "Status": {
514 | "enum": [
515 | "processing",
516 | "succeeded",
517 | "failed"
518 | ],
519 | "type": "string",
520 | "title": "Status",
521 | "description": "An enumeration."
522 | },
523 | "Request": {
524 | "type": "object",
525 | "title": "Request",
526 | "properties": {
527 | "input": {
528 | "$ref": "#/components/schemas/Input"
529 | },
530 | "output_file_prefix": {
531 | "type": "string",
532 | "title": "Output File Prefix"
533 | }
534 | },
535 | "description": "The request body for a prediction"
536 | },
537 | "Response": {
538 | "type": "object",
539 | "title": "Response",
540 | "required": [
541 | "status"
542 | ],
543 | "properties": {
544 | "error": {
545 | "type": "string",
546 | "title": "Error"
547 | },
548 | "output": {
549 | "$ref": "#/components/schemas/Output"
550 | },
551 | "status": {
552 | "$ref": "#/components/schemas/Status"
553 | }
554 | },
555 | "description": "The response body for a prediction"
556 | },
557 | "ValidationError": {
558 | "type": "object",
559 | "title": "ValidationError",
560 | "required": [
561 | "loc",
562 | "msg",
563 | "type"
564 | ],
565 | "properties": {
566 | "loc": {
567 | "type": "array",
568 | "items": {
569 | "anyOf": [
570 | {
571 | "type": "string"
572 | },
573 | {
574 | "type": "integer"
575 | }
576 | ]
577 | },
578 | "title": "Location"
579 | },
580 | "msg": {
581 | "type": "string",
582 | "title": "Message"
583 | },
584 | "type": {
585 | "type": "string",
586 | "title": "Error Type"
587 | }
588 | }
589 | },
590 | "HTTPValidationError": {
591 | "type": "object",
592 | "title": "HTTPValidationError",
593 | "properties": {
594 | "detail": {
595 | "type": "array",
596 | "items": {
597 | "$ref": "#/components/schemas/ValidationError"
598 | },
599 | "title": "Detail"
600 | }
601 | }
602 | }
603 | }
604 | }
605 | }
606 | }
607 | }
608 | """#
609 | case ("POST", "https://api.replicate.com/v1/models"?):
610 | statusCode = 200
611 | json = #"""
612 | {
613 | "url": "https://replicate.com/replicate/hello-world",
614 | "owner": "replicate",
615 | "name": "hello-world",
616 | "description": "A tiny model that says hello",
617 | "visibility": "public",
618 | "github_url": null,
619 | "paper_url": null,
620 | "license_url": null,
621 | "run_count": 0,
622 | "cover_image_url": null,
623 | "default_example": null
624 | }
625 | """#
626 | case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world/versions"?):
627 | statusCode = 200
628 | json = #"""
629 | {
630 | "previous": null,
631 | "next": null,
632 | "results": [
633 | {
634 | "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
635 | "created_at": "2022-04-26T19:29:04.418669Z",
636 | "cog_version": "0.3.0",
637 | "openapi_schema": {}
638 | },
639 | {
640 | "id": "e2e8c39e0f77177381177ba8c4025421ec2d7e7d3c389a9b3d364f8de560024f",
641 | "created_at": "2022-03-21T13:01:04.418669Z",
642 | "cog_version": "0.3.0",
643 | "openapi_schema": {}
644 | }
645 | ]
646 | }
647 | """#
648 | case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"?):
649 | statusCode = 200
650 | json = #"""
651 | {
652 | "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
653 | "created_at": "2022-04-26T19:29:04.418669Z",
654 | "cog_version": "0.3.0",
655 | "openapi_schema": {}
656 | }
657 | """#
658 | case ("GET", "https://api.replicate.com/v1/collections"?):
659 | statusCode = 200
660 | json = #"""
661 | {
662 | "results": [
663 | {
664 | "name": "Super resolution",
665 | "slug": "super-resolution",
666 | "description": "Upscaling models that create high-quality images from low-quality images.",
667 | }
668 | ],
669 | "next": null,
670 | "previous": null
671 | }
672 | """#
673 | case ("GET", "https://api.replicate.com/v1/collections/super-resolution"?),
674 | ("GET", "https://v1.replicate.proxy/collections/super-resolution"?):
675 | statusCode = 200
676 | json = #"""
677 | {
678 | "name": "Super resolution",
679 | "slug": "super-resolution",
680 | "description": "Upscaling models that create high-quality images from low-quality images.",
681 | "models": []
682 | }
683 | """#
684 | case ("GET", "https://api.replicate.com/v1/trainings"?):
685 | statusCode = 200
686 | json = #"""
687 | {
688 | "previous": null,
689 | "next": "https://api.replicate.com/v1/trainings?cursor=g5FWfcbO0EdVeR27rkXr0Z6tI0MjrW34ZejxnGzDeND3phpWWsyMGCQD",
690 | "results": [
691 | {
692 | "id": "zz4ibbonubfz7carwiefibzgga",
693 | "model": "replicate/hello-world",
694 | "version": "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a",
695 | "urls": {
696 | "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
697 | "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
698 | },
699 | "created_at": "2022-04-26T22:13:06.224088Z",
700 | "completed_at": "2022-04-26T22:13:06.580379Z",
701 | "source": "web",
702 | "status": "starting",
703 | "input": {
704 | "data": "..."
705 | },
706 | "output": null,
707 | "error": null,
708 | "logs": null,
709 | "metrics": {}
710 | }
711 | ]
712 | }
713 | """#
714 | case ("POST", "https://api.replicate.com/v1/models/example/base/versions/4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a/trainings"?):
715 | statusCode = 201
716 |
717 | json = #"""
718 | {
719 | "id": "zz4ibbonubfz7carwiefibzgga",
720 | "version": "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a",
721 | "urls": {
722 | "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
723 | "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
724 | },
725 | "created_at": "2022-04-26T22:13:06.224088Z",
726 | "completed_at": "2022-04-26T22:13:06.580379Z",
727 | "source": "web",
728 | "status": "starting",
729 | "input": {
730 | "data": "..."
731 | },
732 | "output": null,
733 | "error": null,
734 | "logs": null,
735 | "metrics": {}
736 | }
737 | """#
738 | case ("GET", "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga"?):
739 | statusCode = 200
740 | json = #"""
741 | {
742 | "id": "zz4ibbonubfz7carwiefibzgga",
743 | "model": "replicate/hello-world",
744 | "version": "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a",
745 | "urls": {
746 | "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
747 | "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
748 | },
749 | "created_at": "2023-04-23T22:13:06.224088Z",
750 | "completed_at": "2023-04-23T22:15:06.224088Z",
751 | "source": "web",
752 | "status": "succeeded",
753 | "input": {
754 | "data": "..."
755 | },
756 | "output": {
757 | "version": "b024d792ace1084d2504b2fc3012f013cef3b99842add1e7d82d2136ea1b78ac",
758 | "weights": "https://relicate.delivery/example-weights.tar.gz"
759 | },
760 | "error": null,
761 | "logs": "",
762 | "metrics": {}
763 | }
764 | """#
765 | case ("POST", "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"?):
766 | statusCode = 200
767 | json = #"""
768 | {
769 | "id": "zz4ibbonubfz7carwiefibzgga",
770 | "model": "replicate/hello-world",
771 | "version": "4a056052b8b98f6db8d011a450abbcd09a408ec9280c29f22d3538af1099646a",
772 | "urls": {
773 | "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
774 | "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
775 | },
776 | "created_at": "2023-04-23T22:13:06.224088Z",
777 | "completed_at": "2023-04-23T22:15:06.224088Z",
778 | "source": "web",
779 | "status": "canceled",
780 | "input": {
781 | "data": "..."
782 | },
783 | "output": null,
784 | "error": null,
785 | "logs": "",
786 | "metrics": {}
787 | }
788 | """#
789 | case ("GET", "https://api.replicate.com/v1/deployments/replicate/my-app-image-generator"?):
790 | statusCode = 200
791 | json = #"""
792 | {
793 | "owner": "replicate",
794 | "name": "my-app-image-generator",
795 | "current_release": {
796 | "number": 1,
797 | "model": "stability-ai/sdxl",
798 | "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
799 | "created_at": "2024-02-15T16:32:57.018467Z",
800 | "created_by": {
801 | "type": "organization",
802 | "username": "replicate",
803 | "name": "Replicate, Inc.",
804 | "github_url": "https://github.com/replicate",
805 | },
806 | "configuration": {
807 | "hardware": "gpu-t4",
808 | "scaling": {
809 | "min_instances": 1,
810 | "max_instances": 5
811 | }
812 | }
813 | }
814 | }
815 | """#
816 | default:
817 | client?.urlProtocol(self, didFailWithError: URLError(.badURL))
818 | return
819 | }
820 | case nil:
821 | statusCode = 401
822 | json = #"""
823 | { "detail" : "Authentication credentials were not provided." }
824 | """#
825 | default:
826 | statusCode = 401
827 | json = #"""
828 | { "detail" : "Invalid token." }
829 | """#
830 | }
831 |
832 | guard let data = json.data(using: .utf8),
833 | let response = HTTPURLResponse(url: request.url!,
834 | statusCode: statusCode,
835 | httpVersion: "1.1",
836 | headerFields: [
837 | "Content-Type": "application/json"
838 | ])
839 | else {
840 | client?.urlProtocol(self, didFailWithError: URLError(.badServerResponse))
841 | return
842 | }
843 |
844 | client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
845 | client?.urlProtocol(self, didLoad: data)
846 | client?.urlProtocolDidFinishLoading(self)
847 | }
848 |
849 | override func stopLoading() {} // Intentionally empty: this mock does nothing when loading is stopped.
850 | }
851 |
852 | // MARK: -
853 |
854 | extension Client {
855 |     /// A mocked client configured with `MockURLProtocol.validToken`.
856 |     static var valid: Client { Client(token: MockURLProtocol.validToken).mocked }
857 |
858 |     /// A mocked client configured with a non-empty token other than the valid one.
859 |     /// It must differ from `unauthenticated`, which is configured with an empty token.
860 |     static var invalid: Client { Client(token: "invalid-token").mocked }
861 |
862 |     /// A mocked client configured with an empty token.
863 |     static var unauthenticated: Client { Client(token: "").mocked }
864 |
865 |     /// This client, after replacing its session with one whose configuration
866 |     /// lists `MockURLProtocol` as its only protocol class.
867 |     var mocked: Self {
868 |         let configuration = session.configuration
869 |         configuration.protocolClasses = [MockURLProtocol.self]
870 |         session = URLSession(configuration: configuration)
871 |         return self
872 |     }
873 | }
874 |
875 | private extension URLRequest {
876 |     /// The request body parsed as a top-level JSON object, taken from `httpBodyStream`
877 |     /// when one is set and from `httpBody` otherwise; `nil` if there is no body or it is not a JSON object.
878 |     var json: [String: Any]? {
879 |         var data = httpBody
880 |         if let stream = httpBodyStream {
881 |             stream.open()
882 |             defer { stream.close() } // close the stream once the body has been read
883 |
884 |             var buffer = [UInt8](repeating: 0, count: 1024)
885 |             var streamed = Data()
886 |             while stream.hasBytesAvailable {
887 |                 let bytesRead = stream.read(&buffer, maxLength: buffer.count)
888 |                 guard bytesRead > 0 else { break } // 0: end of stream, negative: read error
889 |                 streamed.append(buffer, count: bytesRead)
890 |             }
891 |             data = streamed
892 |         }
893 |
894 |         guard let data = data else { return nil }
895 |         return try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any]
896 |     }
897 | }
898 |
--------------------------------------------------------------------------------
/Sources/Replicate/Client.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | #if canImport(FoundationNetworking)
4 | import FoundationNetworking
5 | #endif
6 |
7 | /// A Replicate HTTP API client.
8 | ///
9 | /// See https://replicate.com/docs/reference/http
10 | public class Client {
11 | /// The base URL for requests made by the client.
12 | public let baseURLString: String
13 |
14 | /// The value for the `User-Agent` header sent in requests, if any.
15 | public let userAgent: String?
16 |
17 | /// The API token used in the `Authorization` header sent in requests.
18 | private let token: String
19 |
20 | /// The underlying client session.
21 | internal(set) public var session: URLSession
22 |
23 | /// The retry policy for requests made by the client.
24 | public var retryPolicy: RetryPolicy = .default
25 |
26 | /// Creates a client with the specified API token.
27 | ///
28 | /// You can get a Replicate API token on your
29 | /// [account page](https://replicate.com/account).
30 | ///
31 | /// - Parameters:
32 | /// - session: The underlying client session. Defaults to `URLSession(configuration: .default)`.
33 | /// - baseURLString: The base URL for requests made by the client. Defaults to "https://api.replicate.com/v1/".
34 | /// - userAgent: The value for the `User-Agent` header sent in requests, if any. Defaults to `nil`.
35 | /// - token: The API token.
36 | public init(
37 |     session: URLSession = URLSession(configuration: .default),
38 |     baseURLString: String = "https://api.replicate.com/v1/",
39 |     userAgent: String? = nil,
40 |     token: String
41 | )
42 | {
43 |     // Normalize the base URL so the stored string always ends in "/":
44 |     // a slash is appended only when the caller's value lacks one.
45 |     let hasTrailingSlash = baseURLString.hasSuffix("/")
46 |     self.baseURLString = hasTrailingSlash ? baseURLString : baseURLString.appending("/")
47 |
48 |     // The remaining arguments are stored unchanged.
49 |     self.userAgent = userAgent
50 |     self.token = token
51 |     self.session = session
52 | }
53 |
54 | // MARK: -
55 |
56 | /// Runs a model and waits for its output.
57 | ///
58 | /// - Parameters:
59 | /// - identifier:
60 | /// The model version identifier in the format
61 | /// `{owner}/{name}` or `{owner}/{name}:{version}`.
62 | /// - input:
63 | /// The input depends on what model you are running.
64 | /// To see the available inputs,
65 | /// click the "Run with API" tab on the model you are running.
66 | /// For example,
67 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3)
68 | /// takes `prompt` as an input.
69 | /// - webhook:
70 | /// A webhook that is called when the prediction has completed.
71 | /// It will be a `POST` request where
72 | /// the request body is the same as
73 | /// the response body of the get prediction endpoint.
74 | /// If there are network problems,
75 | /// we will retry the webhook a few times,
76 | /// so make sure it can be safely called more than once.
77 | /// - type:
78 | /// The expected output type. Defaults to `Value.self`.
79 | /// - updateHandler:
80 | /// A closure that executes with the updated prediction
81 | /// after each polling request to the API.
82 | /// If the prediction is in a terminal state
83 | /// (e.g. `succeeded`, `failed`, or `canceled`),
84 | /// it's returned immediately and the closure is not executed.
85 | /// Use this to provide feedback to the user
86 | /// about the progress of the prediction,
87 | /// or throw `CancellationError` to stop waiting
88 | /// for the prediction to finish.
89 | /// - Returns:
90 | /// The output of the model,
91 | /// or `nil` if the output couldn't be decoded.
92 | /// - Throws:
93 | /// Any error thrown from the `updateHandler` closure
94 | /// other than ``CancellationError``.
95 | public func run(
96 |     _ identifier: Identifier,
97 |     input: Input,
98 |     webhook: Webhook? = nil,
99 |     _ type: Output.Type = Value.self,
100 |     updateHandler: @escaping (Prediction) throws -> Void = { _ in () }
101 | ) async throws -> Output? {
102 |     // With a version, the prediction is created by version ID;
103 |     // otherwise it is created by the "{owner}/{name}" model ID.
104 |     var prediction: Prediction
105 |     switch identifier.version {
106 |     case let version?:
107 |         prediction = try await createPrediction(
108 |             Prediction.self, version: version, input: input, webhook: webhook)
109 |     case nil:
110 |         prediction = try await createPrediction(
111 |             Prediction.self, model: "\(identifier.owner)/\(identifier.name)",
112 |             input: input, webhook: webhook)
113 |     }
114 |
115 |     // Wait on the prediction, handing the caller's `updateHandler` to `wait`.
116 |     try await prediction.wait(with: self, updateHandler: updateHandler)
117 |
118 |     guard prediction.status != .failed else {
119 |         throw prediction.error ?? Error(detail: "Prediction failed")
120 |     }
121 |     return prediction.output
122 | }
123 |
124 | // MARK: -
125 |
126 | /// Create a prediction
127 | ///
128 | /// - Parameters:
129 | /// - type: The type of the prediction. Defaults to `AnyPrediction.self`.
130 | /// - id: The ID of the model version that you want to run.
131 | /// You can get your model's versions using the API,
132 | /// or find them on the website by clicking
133 | /// the "Versions" tab on the Replicate model page,
134 | /// e.g. replicate.com/replicate/hello-world/versions,
135 | /// then copying the full SHA256 hash from the URL.
136 | ///
137 | /// The version ID is the same as the Docker image ID
138 | /// that's created when you build your model.
139 | /// - input:
140 | /// The input depends on what model you are running.
141 | /// To see the available inputs,
142 | /// click the "Run with API" tab on the model you are running.
143 | /// For example,
144 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3)
145 | /// takes `prompt` as an input.
146 | /// - webhook:
147 | /// A webhook that is called when the prediction has completed.
148 | /// It will be a `POST` request where
149 | /// the request body is the same as
150 | /// the response body of the get prediction endpoint.
151 | /// If there are network problems,
152 | /// we will retry the webhook a few times,
153 | /// so make sure it can be safely called more than once.
154 | /// - wait:
155 | /// If set to `true`,
156 | /// this method refreshes the prediction until it completes
157 | /// (``Prediction/status`` is `.succeeded` or `.failed`).
158 | /// By default, this is `false`,
159 | /// and this method returns the prediction object encoded
160 | /// in the original creation response
161 | /// (``Prediction/status`` is `.starting`).
162 | /// - Returns: The created prediction.
163 | @available(*, deprecated, message: "wait parameter is deprecated; use ``Prediction/wait(with:)`` or ``Client/run(_:input:webhook:_:)``")
164 | public func createPrediction(
165 |     _ type: Prediction.Type = AnyPrediction.self,
166 |     version id: Model.Version.ID,
167 |     input: Input,
168 |     webhook: Webhook? = nil,
169 |     wait: Bool
170 | ) async throws -> Prediction {
171 |     var body: [String: Value] = [
172 |         "version": "\(id)",
173 |         "input": try Value(input)
174 |     ]
175 |
176 |     // Webhook fields are added only when a webhook was supplied.
177 |     if let hook = webhook {
178 |         body["webhook"] = "\(hook.url.absoluteString)"
179 |         body["webhook_events_filter"] = .array(hook.events.map { "\($0.rawValue)" })
180 |     }
181 |
182 |     var prediction: Prediction = try await fetch(.post, "predictions", params: body)
183 |     // The same value is returned either way; `wait` only decides whether to wait on it first.
184 |     if wait {
185 |         try await prediction.wait(with: self)
186 |     }
187 |     return prediction
188 | }
189 |
190 | /// Create a prediction from a model version
191 | ///
192 | /// - Parameters:
193 | /// - type:
194 | /// The type of the prediction. Defaults to `AnyPrediction.self`.
195 | /// - version:
196 | /// The ID of the model version that you want to run.
197 | /// You can get your model's versions using the API,
198 | /// or find them on the website by clicking
199 | /// the "Versions" tab on the Replicate model page,
200 | /// e.g. replicate.com/replicate/hello-world/versions,
201 | /// then copying the full SHA256 hash from the URL.
202 | /// The version ID is the same as the Docker image ID
203 | /// that's created when you build your model.
204 | /// - input: The input depends on what model you are running.
205 | /// To see the available inputs,
206 | /// click the "Run with API" tab on the model you are running.
207 | /// For example,
208 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3)
209 | /// takes `prompt` as an input.
210 | /// - webhook: A webhook that is called when the prediction has completed.
211 | /// It will be a `POST` request where
212 | /// the request body is the same as
213 | /// the response body of the get prediction endpoint.
214 | /// If there are network problems,
215 | /// we will retry the webhook a few times,
216 | /// so make sure it can be safely called more than once.
217 | /// - stream: Whether to stream the prediction output.
218 | /// By default, this is `false`.
219 | /// - Returns: The created prediction.
220 | public func createPrediction(
221 |     _ type: Prediction.Type = AnyPrediction.self,
222 |     version id: Model.Version.ID,
223 |     input: Input,
224 |     webhook: Webhook? = nil,
225 |     stream: Bool = false
226 | ) async throws -> Prediction {
227 |     // Required fields: the version to run and its encoded input.
228 |     var body: [String: Value] = ["version": "\(id)", "input": try Value(input)]
229 |
230 |     // Webhook fields are added only when a webhook was supplied.
231 |     if let hook = webhook {
232 |         body["webhook"] = "\(hook.url.absoluteString)"
233 |         body["webhook_events_filter"] = .array(hook.events.map { "\($0.rawValue)" })
234 |     }
235 |
236 |     // "stream" is sent only when streaming was requested; it is omitted otherwise.
237 |     if stream {
238 |         body["stream"] = true
239 |     }
240 |
241 |     return try await fetch(.post, "predictions", params: body)
242 | }
243 |
244 | /// Create a prediction from a model
245 | ///
246 | /// - Parameters:
247 | /// - type:
248 | /// The type of the prediction. Defaults to `AnyPrediction.self`.
249 | /// - model:
250 | /// The ID of the model that you want to run.
251 | /// - input:
252 | /// The input depends on what model you are running.
253 | /// To see the available inputs,
254 | /// click the "Run with API" tab on the model you are running.
255 | /// For example,
256 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3)
257 | /// takes `prompt` as an input.
258 | /// - webhook:
259 | /// A webhook that is called when the prediction has completed.
260 | /// It will be a `POST` request where
261 | /// the request body is the same as
262 | /// the response body of the get prediction endpoint.
263 | /// If there are network problems,
264 | /// we will retry the webhook a few times,
265 | /// so make sure it can be safely called more than once.
266 | /// - stream: Whether to stream the prediction output.
267 | /// By default, this is `false`.
268 | /// - Returns: The created prediction.
269 | public func createPrediction(
270 |     _ type: Prediction.Type = AnyPrediction.self,
271 |     model id: Model.ID,
272 |     input: Input,
273 |     webhook: Webhook? = nil,
274 |     stream: Bool = false
275 | ) async throws -> Prediction {
276 |     // The model ID goes into the request path, so the body starts with the input alone.
277 |     var body: [String: Value] = ["input": try Value(input)]
278 |
279 |     // Webhook fields are added only when a webhook was supplied.
280 |     if let hook = webhook {
281 |         body["webhook"] = "\(hook.url.absoluteString)"
282 |         body["webhook_events_filter"] = .array(hook.events.map { "\($0.rawValue)" })
283 |     }
284 |
285 |     // "stream" is sent only when streaming was requested; it is omitted otherwise.
286 |     if stream { body["stream"] = true }
287 |
288 |     // The request path embeds the model ID.
289 |     return try await fetch(.post, "models/\(id)/predictions", params: body)
290 | }
291 |
292 | /// Create a prediction using a deployment
293 | ///
294 | /// - Parameters:
295 | /// - type:
296 | /// The type of the prediction. Defaults to `AnyPrediction.self`.
297 | /// - deployment:
298 | /// The ID of the deployment.
299 | /// - input:
300 | /// The input depends on what model you are running.
301 | /// To see the available inputs,
302 | /// click the "Run with API" tab on the model you are running.
303 | /// For example,
304 | /// [stability-ai/stable-diffusion-3](https://replicate.com/stability-ai/stable-diffusion-3)
305 | /// takes `prompt` as an input.
306 | /// - webhook:
307 | /// A webhook that is called when the prediction has completed.
308 | /// It will be a `POST` request where
309 | /// the request body is the same as
310 | /// the response body of the get prediction endpoint.
311 | /// If there are network problems,
312 | /// we will retry the webhook a few times,
313 | /// so make sure it can be safely called more than once.
314 | /// - stream:
315 | /// Whether to stream the prediction output.
316 | /// By default, this is `false`.
317 | /// - Returns: The created prediction.
318 | public func createPrediction(
319 |     _ type: Prediction.Type = AnyPrediction.self,
320 |     deployment id: Deployment.ID,
321 |     input: Input,
322 |     webhook: Webhook? = nil,
323 |     stream: Bool = false
324 | ) async throws -> Prediction {
325 |     // The deployment ID goes into the request path, so the body starts with the input alone.
326 |     var body: [String: Value] = ["input": try Value(input)]
327 |
328 |     // Webhook fields are added only when a webhook was supplied.
329 |     if let hook = webhook {
330 |         body["webhook"] = "\(hook.url.absoluteString)"
331 |         body["webhook_events_filter"] = .array(hook.events.map { "\($0.rawValue)" })
332 |     }
333 |
334 |     // "stream" is sent only when streaming was requested; it is omitted otherwise.
335 |     if stream { body["stream"] = true }
336 |
337 |     // The request path embeds the deployment ID.
338 |     return try await fetch(.post, "deployments/\(id)/predictions", params: body)
339 | }
340 |
341 | @available(*, deprecated, renamed: "listPredictions(_:cursor:)")
342 | public func getPredictions(
343 |     _ type: Prediction.Type = AnyPrediction.self,
344 |     cursor: Pagination.Cursor? = nil
345 | ) async throws -> Pagination.Page> // NOTE(review): stray `>` — the generic arguments of this return type appear to have been lost; confirm against the original source
346 | {
347 |     return try await listPredictions(type, cursor: cursor) // deprecated alias: forwards both arguments unchanged to `listPredictions`
348 | }
349 |
350 | /// List predictions
351 | ///
352 | /// - Parameters:
353 | /// - type:
354 | /// The type of the predictions. Defaults to `AnyPrediction.self`.
355 | /// - cursor:
356 | /// A pointer to a page of results to fetch.
357 | /// - Returns: A page of predictions.
358 | public func listPredictions(
359 |     _ type: Prediction.Type = AnyPrediction.self,
360 |     cursor: Pagination.Cursor? = nil
361 | ) async throws -> Pagination.Page> // NOTE(review): stray `>` — the generic arguments of this return type appear to have been lost; confirm against the original source
362 | {
363 |     return try await fetch(.get, "predictions", cursor: cursor) // `type` is not passed on; only the cursor is handed to `fetch`
364 | }
365 |
366 | /// Get a prediction
367 | ///
368 | /// - Parameters:
369 | /// - type: The type of the prediction. Defaults to `AnyPrediction.self`.
370 | /// - id: The ID of the prediction you want to fetch.
371 | /// - Returns: The requested prediction.
372 | public func getPrediction(
373 |     _ type: Prediction.Type = AnyPrediction.self,
374 |     id: Prediction.ID
375 | ) async throws -> Prediction {
376 |     try await fetch(.get, "predictions/\(id)") // single-expression body: the result of `fetch` is returned
377 | }
378 |
379 | /// Cancel a prediction
380 | ///
381 | /// - Parameters:
382 | /// - type:
383 | /// The type of the prediction. Defaults to `AnyPrediction.self`.
384 | /// - id:
385 | /// The ID of the prediction you want to cancel.
386 | /// - Returns: The canceled prediction.
387 | public func cancelPrediction(
388 |     _ type: Prediction.Type = AnyPrediction.self,
389 |     id: Prediction.ID
390 | ) async throws -> Prediction {
391 |     try await fetch(.post, "predictions/\(id)/cancel") // single-expression body: the result of `fetch` is returned
392 | }
393 |
394 | // MARK: -
395 |
396 | /// List public models
397 | /// - Parameters:
398 | /// - cursor:
399 | /// A pointer to a page of results to fetch.
400 | /// - Returns: A page of models.
401 | public func listModels(cursor: Pagination.Cursor? = nil)
402 |     async throws -> Pagination.Page
403 | {
404 |     try await fetch(.get, "models", cursor: cursor) // the cursor, if any, is passed through to `fetch`
405 | }
406 |
407 | /// Search for public models on Replicate.
408 | ///
409 | /// - Parameters:
410 | /// - query: The search query string.
411 | /// - Returns: A page of models matching the search query.
412 | public func searchModels(query: String) async throws -> Pagination.Page {
413 |     var searchRequest = try createRequest(method: .query, path: "models")
414 |     searchRequest.addValue("text/plain", forHTTPHeaderField: "Content-Type")
415 |     searchRequest.httpBody = Data(query.utf8) // the search text is sent as the request body
416 |     return try await sendRequest(searchRequest)
417 | }
418 |
419 | /// Get a model
420 | ///
421 | /// - Parameters:
422 | /// - id: The model identifier, comprising
423 | /// the name of the user or organization that owns the model and
424 | /// the name of the model.
425 | /// For example, "stability-ai/stable-diffusion-3".
426 | /// - Returns: The requested model.
427 | public func getModel(_ id: Model.ID)
428 |     async throws -> Model
429 | {
430 |     try await fetch(.get, "models/\(id)") // single-expression body: the result of `fetch` is returned
431 | }
432 |
433 | /// Create a model
434 | ///
435 | /// - Parameters:
436 | /// - owner:
437 | /// The name of the user or organization that will own the model.
438 | /// This must be the same as the user or organization that is making the API request.
439 | /// In other words, the API token used in the request must belong to this user or organization.
440 | /// - name:
441 | /// The name of the model.
442 | /// This must be unique among all models owned by the user or organization.
443 | /// - visibility:
444 | /// Whether the model should be public or private.
445 | /// A public model can be viewed and run by anyone,
446 | /// whereas a private model can be viewed and run only by the user or organization members
447 | /// that own the model.
448 | /// - hardware:
449 | /// The SKU for the hardware used to run the model.
450 | /// Possible values can be found by calling ``listHardware()``.
451 | /// - description:
452 | /// A description of the model.
453 | /// - githubURL:
454 | /// A URL for the model's source code on GitHub.
455 | /// - paperURL:
456 | /// A URL for the model's paper.
457 | /// - licenseURL:
458 | /// A URL for the model's license.
459 | /// - coverImageURL:
460 | /// A URL for the model's cover image.
461 | /// This should be an image file.
462 | /// - Returns: The created model.
463 | public func createModel(
464 |     owner: String,
465 |     name: String,
466 |     visibility: Model.Visibility,
467 |     hardware: Hardware.ID,
468 |     description: String? = nil,
469 |     githubURL: URL? = nil,
470 |     paperURL: URL? = nil,
471 |     licenseURL: URL? = nil,
472 |     coverImageURL: URL? = nil
473 | ) async throws -> Model
474 | {
475 |     var params: [String: Value] = [
476 |         "owner": "\(owner)",
477 |         "name": "\(name)",
478 |         "visibility": "\(visibility.rawValue)",
479 |         "hardware": "\(hardware)"
480 |     ]
481 |
482 |     if let description {
483 |         params["description"] = "\(description)"
484 |     }
485 |
486 |     // Optional URLs are sent as their `absoluteString`, as webhook URLs are elsewhere in
487 |     // this client. Interpolating a `URL` directly uses its `description`, which is not
488 |     // the absolute URL when the URL was created relative to a base URL.
489 |     if let githubURL {
490 |         params["github_url"] = "\(githubURL.absoluteString)"
491 |     }
492 |     if let paperURL {
493 |         params["paper_url"] = "\(paperURL.absoluteString)"
494 |     }
495 |     if let licenseURL {
496 |         params["license_url"] = "\(licenseURL.absoluteString)"
497 |     }
498 |     if let coverImageURL {
499 |         params["cover_image_url"] = "\(coverImageURL.absoluteString)"
500 |     }
501 |
502 |     return try await fetch(.post, "models", params: params)
503 | }
504 |
505 | // MARK: -
506 |
507 | /// List hardware available for running a model on Replicate.
508 | ///
509 | /// - Returns: An array of hardware.
510 | public func listHardware() async throws -> [Hardware] {
511 |     try await fetch(.get, "hardware") // single-expression body: the result of `fetch` is returned
512 | }
513 |
514 |
515 | // MARK: -
516 |
517 | @available(*, deprecated, renamed: "listModelVersions(_:cursor:)")
518 | public func getModelVersions(_ id: Model.ID,
519 |                              cursor: Pagination.Cursor? = nil)
520 |     async throws -> Pagination.Page
521 | {
522 |     try await listModelVersions(id, cursor: cursor) // deprecated alias: forwards both arguments unchanged
523 | }
524 |
525 | /// List model versions
526 | ///
527 | /// - Parameters:
528 | /// - id:
529 | /// The model identifier, comprising
530 | /// the name of the user or organization that owns the model and
531 | /// the name of the model.
532 | /// For example, "stability-ai/stable-diffusion-3".
533 | /// - cursor:
534 | /// A pointer to a page of results to fetch.
535 | /// - Returns: A page of model versions.
536 | public func listModelVersions(_ id: Model.ID,
537 |                               cursor: Pagination.Cursor? = nil)
538 |     async throws -> Pagination.Page
539 | {
540 |     try await fetch(.get, "models/\(id)/versions", cursor: cursor) // the cursor, if any, is passed through to `fetch`
541 | }
542 |
543 | /// Get a model version
544 | ///
545 | /// - Parameters:
546 | /// - id:
547 | /// The model identifier, comprising
548 | /// the name of the user or organization that owns the model and
549 | /// the name of the model.
550 | /// For example, "stability-ai/stable-diffusion-3".
551 | /// - version:
552 | /// The ID of the version.
553 | public func getModelVersion(_ id: Model.ID,
554 |                             version: Model.Version.ID)
555 |     async throws -> Model.Version
556 | {
557 |     try await fetch(.get, "models/\(id)/versions/\(version)") // path embeds both the model ID and the version ID
558 | }
559 |
560 | // MARK: -
561 |
562 | /// List collections of models
563 | /// - Parameter cursor: A pointer to a page of results to fetch.
564 | /// - Returns: A page of model collections.
565 | public func listModelCollections(cursor: Pagination.Cursor? = nil)
566 |     async throws -> Pagination.Page
567 | {
568 |     return try await fetch(.get, "collections", cursor: cursor) // pass the cursor on, as the other list methods do
569 | }
570 |
571 | /// Get a collection of models
572 | ///
573 | /// - Parameters:
574 | /// - slug:
575 | /// The slug of the collection,
576 | /// like super-resolution or image-restoration.
577 | ///
578 | /// See https://replicate.com/collections
579 | public func getModelCollection(_ slug: String)
580 |     async throws -> Model.Collection
581 | {
582 |     try await fetch(.get, "collections/\(slug)") // single-expression body: the result of `fetch` is returned
583 | }
584 |
585 | // MARK: -
586 |
587 | /// Train a model on Replicate.
588 | ///
589 | /// To find out which models can be trained,
590 | /// check out the [trainable language models collection](https://replicate.com/collections/trainable-language-models).
591 | ///
592 | /// - Parameters:
593 | /// - model:
594 | /// The base model used to train a new version.
595 | /// - id:
596 | /// The ID of the base model version
597 | /// that you're using to train a new model version.
598 | ///
599 | /// You can get your model's versions using the API,
600 | /// or find them on the website by clicking
601 | /// the "Versions" tab on the Replicate model page,
602 | /// e.g. replicate.com/replicate/hello-world/versions,
603 | /// then copying the full SHA256 hash from the URL.
604 | ///
605 | /// The version ID is the same as the Docker image ID
606 | /// that's created when you build your model.
607 | /// - destination:
608 | /// The desired model to push to in the format `{owner}/{model_name}`.
609 | /// This should be an existing model owned by
610 | /// the user or organization making the API request.
611 | /// - input:
612 | /// An object containing inputs to the
613 | /// Cog model's `train()` function.
614 | /// - webhook:
615 | /// A webhook that is called when the training has completed.
616 | ///
617 | /// It will be a `POST` request where
618 | /// the request body is the same as
619 | /// the response body of the get training endpoint.
620 | /// If there are network problems,
621 | /// we will retry the webhook a few times,
622 | /// so make sure it can be safely called more than once.
623 | public func createTraining(
624 | _ type: Training.Type = AnyTraining.self,
625 | model: Model.ID,
626 | version: Model.Version.ID,
627 | destination: Model.ID,
628 | input: Input,
629 | webhook: Webhook? = nil
630 | ) async throws -> Training
631 | {
632 | var params: [String: Value] = [
633 | "destination": "\(destination)",
634 | "input": try Value(input)
635 | ]
636 |
637 | if let webhook {
638 | params["webhook"] = "\(webhook.url.absoluteString)"
639 | params["webhook_events_filter"] = .array(webhook.events.map { "\($0.rawValue)" })
640 | }
641 |
642 | return try await fetch(.post, "models/\(model)/versions/\(version)/trainings", params: params)
643 | }
644 |
645 | @available(*, deprecated, renamed: "listTrainings(_:cursor:)")
646 | public func getTrainings(
647 | _ type: Training.Type = AnyTraining.self,
648 | cursor: Pagination.Cursor? = nil
649 | ) async throws -> Pagination.Page>
650 | {
651 | return try await listTrainings(type, cursor: cursor)
652 | }
653 |
654 | /// List trainings
655 | ///
656 | /// - Parameters:
657 | /// - type: The type of the training. Defaults to `AnyTraining.self`.
658 | /// - cursor: A pointer to a page of results to fetch.
659 | /// - Returns: A page of trainings.
660 | public func listTrainings(
661 | _ type: Training.Type = AnyTraining.self,
662 | cursor: Pagination.Cursor? = nil
663 | ) async throws -> Pagination.Page>
664 | {
665 | return try await fetch(.get, "trainings", cursor: cursor)
666 | }
667 |
668 | /// Get a training
669 | ///
670 | /// - Parameters:
671 | /// - type: The type of the training. Defaults to `AnyTraining.self`.
672 | /// - id: The ID of the training you want to fetch.
673 | /// - Returns: The requested training.
674 | public func getTraining(
675 | _ type: Training.Type = AnyTraining.self,
676 | id: Training.ID
677 | ) async throws -> Training
678 | {
679 | return try await fetch(.get, "trainings/\(id)")
680 | }
681 |
682 | /// Cancel a training
683 | ///
684 | /// - Parameters:
685 | /// - type: The type of the training. Defaults to `AnyTraining.self`.
686 | /// - id: The ID of the training you want to cancel.
687 | /// - Returns: The canceled training.
688 | public func cancelTraining(
689 | _ type: Training.Type = AnyTraining.self,
690 | id: Training.ID
691 | ) async throws -> Training
692 | {
693 | return try await fetch(.post, "trainings/\(id)/cancel")
694 | }
695 |
696 | // MARK: -
697 |
698 | /// Get the current account
699 | ///
700 | /// - Returns: The current account.
701 | public func getCurrentAccount() async throws -> Account {
702 | return try await fetch(.get, "account")
703 | }
704 |
705 | // MARK: -
706 |
707 | /// Get a deployment
708 | ///
709 | /// - Parameters:
710 | /// - id: The deployment identifier, comprising
711 | /// the name of the user or organization that owns the deployment and
712 | /// the name of the deployment.
713 | /// For example, "replicate/my-app-image-generator".
714 | /// - Returns: The requested deployment.
715 | public func getDeployment(_ id: Deployment.ID)
716 | async throws -> Deployment
717 | {
718 | return try await fetch(.get, "deployments/\(id)")
719 | }
720 |
721 | // MARK: -
722 |
723 | private enum Method: String, Hashable {
724 | case get = "GET"
725 | case post = "POST"
726 | case query = "QUERY"
727 | }
728 |
729 | private func fetch(_ method: Method,
730 | _ path: String,
731 | cursor: Pagination.Cursor?)
732 | async throws -> Pagination.Page {
733 | var params: [String: Value]? = nil
734 | if let cursor {
735 | params = ["cursor": "\(cursor)"]
736 | }
737 |
738 | let request = try createRequest(method: method, path: path, params: params)
739 | return try await sendRequest(request)
740 | }
741 |
742 | private func fetch(_ method: Method,
743 | _ path: String,
744 | params: [String: Value]? = nil)
745 | async throws -> T {
746 | let request = try createRequest(method: method, path: path, params: params)
747 | return try await sendRequest(request)
748 | }
749 |
750 | private func createRequest(method: Method, path: String, params: [String: Value]? = nil) throws -> URLRequest {
751 | var urlComponents = URLComponents(string: self.baseURLString.appending(path))
752 | var httpBody: Data? = nil
753 |
754 | switch method {
755 | case .get:
756 | if let params {
757 | var queryItems: [URLQueryItem] = []
758 | for (key, value) in params {
759 | queryItems.append(URLQueryItem(name: key, value: value.description))
760 | }
761 | urlComponents?.queryItems = queryItems
762 | }
763 | case .post:
764 | if let params {
765 | let encoder = JSONEncoder()
766 | httpBody = try encoder.encode(params)
767 | }
768 | case .query:
769 | if let params, let queryString = params["query"] {
770 | httpBody = queryString.description.data(using: .utf8)
771 | }
772 | }
773 |
774 | guard let url = urlComponents?.url else {
775 | throw Error(detail: "invalid request \(method) \(path)")
776 | }
777 |
778 | var request = URLRequest(url: url)
779 | request.httpMethod = method.rawValue
780 |
781 | if !token.isEmpty {
782 | request.addValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
783 | }
784 | request.addValue("application/json", forHTTPHeaderField: "Accept")
785 |
786 | if let httpBody {
787 | request.httpBody = httpBody
788 | request.addValue("application/json", forHTTPHeaderField: "Content-Type")
789 | }
790 |
791 | if let userAgent {
792 | request.addValue(userAgent, forHTTPHeaderField: "User-Agent")
793 | }
794 |
795 | return request
796 | }
797 |
798 | private func sendRequest(_ request: URLRequest) async throws -> T {
799 | let (data, response) = try await session.data(for: request)
800 |
801 | let decoder = JSONDecoder()
802 | decoder.dateDecodingStrategy = .iso8601WithFractionalSeconds
803 |
804 | switch (response as? HTTPURLResponse)?.statusCode {
805 | case (200..<300)?:
806 | return try decoder.decode(T.self, from: data)
807 | default:
808 | if let error = try? decoder.decode(Error.self, from: data) {
809 | throw error
810 | }
811 |
812 | if let string = String(data: data, encoding: .utf8) {
813 | throw Error(detail: "invalid response: \(response) \n \(string)")
814 | }
815 |
816 | throw Error(detail: "invalid response: \(response)")
817 | }
818 | }
819 | }
820 |
821 | // MARK: - Decodable
822 |
823 | extension Client.Pagination.Page: Decodable where Result: Decodable {
824 | private enum CodingKeys: String, CodingKey {
825 | case results
826 | case previous
827 | case next
828 | }
829 |
830 | public init(from decoder: Decoder) throws {
831 | let container = try decoder.container(keyedBy: CodingKeys.self)
832 | self.previous = try? container.decode(Client.Pagination.Cursor.self, forKey: .previous)
833 | self.next = try? container.decode(Client.Pagination.Cursor.self, forKey: .next)
834 | self.results = try container.decode([Result].self, forKey: .results)
835 | }
836 | }
837 |
838 | extension Client.Pagination.Cursor: Decodable {
839 | public init(from decoder: Decoder) throws {
840 | let container = try decoder.singleValueContainer()
841 | let string = try container.decode(String.self)
842 | guard let urlComponents = URLComponents(string: string),
843 | let queryItem = urlComponents.queryItems?.first(where: { $0.name == "cursor" }),
844 | let value = queryItem.value
845 | else {
846 | let context = DecodingError.Context(codingPath: container.codingPath, debugDescription: "invalid cursor")
847 | throw DecodingError.dataCorrupted(context)
848 | }
849 |
850 | self.rawValue = value
851 | }
852 | }
853 |
854 | // MARK: -
855 |
856 | extension Client {
857 | /// A namespace for pagination cursor and page types.
858 | public enum Pagination {
859 | /// A paginated collection of results.
860 | public struct Page {
861 | /// A pointer to the previous page of results
862 | public let previous: Cursor?
863 |
864 | /// A pointer to the next page of results.
865 | public let next: Cursor?
866 |
867 | /// The results for this page.
868 | public let results: [Result]
869 | }
870 |
871 | /// A pointer to a page of results.
872 | public struct Cursor: RawRepresentable, Hashable {
873 | public var rawValue: String
874 |
875 | public init(rawValue: String) {
876 | self.rawValue = rawValue
877 | }
878 | }
879 | }
880 | }
881 |
882 | extension Client.Pagination.Cursor: CustomStringConvertible {
883 | public var description: String {
884 | return self.rawValue
885 | }
886 | }
887 |
888 | extension Client.Pagination.Cursor: ExpressibleByStringLiteral {
889 | public init(stringLiteral value: String) {
890 | self.init(rawValue: value)
891 | }
892 | }
893 |
894 | // MARK: -
895 |
896 | extension Client {
897 | /// A policy for how often a client should retry a request.
898 | public struct RetryPolicy: Equatable, Sequence {
899 | /// A strategy used to determine how long to wait between retries.
900 | public enum Strategy: Hashable {
901 | /// Wait for a constant interval.
902 | ///
903 | /// This strategy implements constant backoff with jitter
904 | /// as described by the equation:
905 | ///
906 | /// $$
907 | /// t = d + R([-j/2, j/2])
908 | /// $$
909 | ///
910 | /// - Parameters:
911 | /// - duration: The constant interval ($d$).
912 | /// - jitter: The amount of random jitter ($j$).
913 | case constant(duration: TimeInterval = 2.0,
914 | jitter: Double = 0.0)
915 |
916 | /// Wait for an exponentially increasing interval.
917 | ///
918 | /// This strategy implements exponential backoff with jitter
919 | /// as described by the equation:
920 | ///
921 | /// $$
922 | /// t = b \cdot m^n + R([-j/2, j/2])
923 | /// $$
924 | ///
925 | /// - Parameters:
926 | /// - base: The initial delay ($b$).
927 | /// - multiplier: The growth factor ($m$), raised to the number of retries already made ($n$).
928 | /// - jitter: The amount of random jitter ($j$).
929 | case exponential(base: TimeInterval = 2.0,
930 | multiplier: Double = 2.0,
931 | jitter: Double = 0.5)
932 | }
933 |
934 | /// The strategy used to determine how long to wait between retries.
935 | public let strategy: Strategy
936 |
937 | /// The total maximum amount of time to retry requests.
938 | public let timeout: TimeInterval?
939 |
940 | /// The maximum amount of time between requests.
941 | public let maximumInterval: TimeInterval?
942 |
943 | /// The maximum number of requests to make.
944 | public let maximumRetries: Int?
945 |
946 | /// The default retry policy.
947 | static let `default` = RetryPolicy(strategy: .exponential(),
948 | timeout: 300.0,
949 | maximumInterval: 30.0,
950 | maximumRetries: 10)
951 |
952 | /// Creates a new retry policy.
953 | ///
954 | /// - Parameters:
955 | /// - strategy: The strategy used to determine how long to wait between retries.
956 | /// - timeout: The total maximum amount of time to retry requests.
957 | /// Must be greater than zero, if specified.
958 | /// - maximumInterval: The maximum amount of time between requests.
959 | /// Must be greater than zero, if specified.
960 | /// - maximumRetries: The maximum number of requests to make.
961 | /// Must be greater than zero, if specified.
962 | public init(strategy: Strategy,
963 | timeout: TimeInterval?,
964 | maximumInterval: TimeInterval?,
965 | maximumRetries: Int?)
966 | {
967 | precondition(timeout ?? .greatestFiniteMagnitude > 0)
968 | precondition(maximumInterval ?? .greatestFiniteMagnitude > 0)
969 | precondition(maximumRetries ?? .max > 0)
970 |
971 | self.strategy = strategy
972 | self.timeout = timeout
973 | self.maximumInterval = maximumInterval
974 | self.maximumRetries = maximumRetries
975 | }
976 |
977 | /// An instantiation of a retry policy.
978 | ///
979 | /// This type satisfies a requirement for `RetryPolicy`
980 | /// to conform to the `Sequence` protocol.
981 | public struct Retrier: IteratorProtocol {
982 | /// The number of retry attempts made.
983 | public private(set) var retries: Int = 0
984 |
985 | /// The retry policy.
986 | public let policy: RetryPolicy
987 |
988 | /// The random number generator used to create random values.
989 | private var randomNumberGenerator: any RandomNumberGenerator
990 |
991 | /// A time after which no delay values are produced, if any.
992 | public let deadline: DispatchTime?
993 |
994 | /// Creates a new instantiation of a retry policy.
995 | ///
996 | /// - Parameters:
997 | /// - policy: The retry policy.
998 | /// - randomNumberGenerator: The random number generator used to create random values.
999 | /// - deadline: A time after which no delay values are produced, if any.
1000 | init(policy: RetryPolicy,
1001 | randomNumberGenerator: any RandomNumberGenerator = SystemRandomNumberGenerator(),
1002 | deadline: DispatchTime?)
1003 | {
1004 | self.policy = policy
1005 | self.randomNumberGenerator = randomNumberGenerator
1006 | self.deadline = deadline
1007 | }
1008 |
1009 | /// Returns the next delay amount, or `nil`.
1010 | public mutating func next() -> TimeInterval? {
1011 | guard policy.maximumRetries.flatMap({ $0 > retries }) ?? true else { return nil }
1012 | guard deadline.flatMap({ $0 > .now() }) ?? true else { return nil }
1013 |
1014 | defer { retries += 1 }
1015 |
1016 | let delay: TimeInterval
1017 | switch policy.strategy {
1018 | case .constant(let base, let jitter):
1019 | delay = base + Double.random(jitter: jitter, using: &randomNumberGenerator)
1020 | case .exponential(let base, let multiplier, let jitter):
1021 | delay = base * (pow(multiplier, Double(retries))) + Double.random(jitter: jitter, using: &randomNumberGenerator)
1022 | }
1023 |
1024 | return delay.clamped(to: 0...(policy.maximumInterval ?? .greatestFiniteMagnitude))
1025 | }
1026 | }
1027 |
1028 | /// Returns a new instantiation of the retry policy.
1029 | public func makeIterator() -> Retrier {
1030 | return Retrier(policy: self,
1031 | deadline: timeout.flatMap {
1032 | .now().advanced(by: .nanoseconds(Int($0 * 1e+9)))
1033 | })
1034 | }
1035 | }
1036 | }
1037 |
1038 | // MARK: -
1039 |
1040 | extension JSONDecoder.DateDecodingStrategy {
1041 | static let iso8601WithFractionalSeconds = custom { decoder in
1042 | let container = try decoder.singleValueContainer()
1043 | let string = try container.decode(String.self)
1044 |
1045 | let formatter = ISO8601DateFormatter()
1046 | formatter.formatOptions = [.withInternetDateTime,
1047 | .withFractionalSeconds]
1048 |
1049 | if let date = formatter.date(from: string) {
1050 | return date
1051 | }
1052 |
1053 | // Try again without fractional seconds
1054 | formatter.formatOptions = [.withInternetDateTime]
1055 |
1056 | guard let date = formatter.date(from: string) else {
1057 | throw DecodingError.dataCorruptedError(in: container, debugDescription: "Invalid date: \(string)")
1058 | }
1059 |
1060 | return date
1061 | }
1062 | }
1063 |
1064 | private extension Double { // Jitter helper: returns a random offset in [-amount/2, amount/2].
1065 | static func random(jitter amount: Double,
1066 | using generator: inout T) -> Double
1067 | where T : RandomNumberGenerator
1068 | {
1069 | guard !amount.isZero else { return 0.0 } // no jitter requested: return 0 without consuming the generator
1070 | return Double.random(in: (-amount / 2)...(amount / 2), using: &generator) // draw from the injected generator (previously ignored) so seeded runs are reproducible
1071 | }
1072 | }
1073 |
1074 | private extension Comparable {
1075 | func clamped(to range: ClosedRange) -> Self {
1076 | return min(max(self, range.lowerBound), range.upperBound)
1077 | }
1078 | }
1079 |
--------------------------------------------------------------------------------