├── .eslintrc.json ├── .gitignore ├── .gitpod.yml ├── .husky └── pre-commit ├── .lintstagedrc.json ├── .prettierignore ├── .prettierrc.json ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── README.md ├── flake.lock ├── flake.nix ├── jest.config.json ├── package.json ├── src ├── cancelPrediction.ts ├── getModel.ts ├── getPrediction.ts ├── helpers │ ├── convertPrediction.ts │ ├── convertShallowPrediction.ts │ ├── extractModelAndOwner.ts │ ├── loadFile.ts │ ├── makeApiRequest.ts │ └── uploadFile.ts ├── index.ts ├── listPredictions.ts ├── listVersions.ts ├── pollPrediction.ts ├── predict.ts ├── processWebhook.ts └── tests │ ├── predict.test.ts │ └── upload.test.ts ├── testaudio.mp3 ├── tsconfig.build.json ├── tsconfig.json └── yarn.lock /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["eslint:recommended", "plugin:import/recommended", "plugin:jest/recommended"], 3 | "plugins": ["import"], 4 | "rules": { 5 | "no-undef": "off", 6 | "no-unused-vars": [ 7 | "off", 8 | { 9 | "ignoreRestSiblings": true 10 | } 11 | ], 12 | "import/no-duplicates": [ 13 | "error", 14 | { 15 | "considerQueryString": true 16 | } 17 | ], 18 | "import/named": "off", 19 | "import/first": "error", 20 | "import/no-namespace": "error", 21 | "import/extensions": [ 22 | "error", 23 | "always", 24 | { 25 | "js": "never", 26 | "jsx": "never", 27 | "ts": "never", 28 | "tsx": "never" 29 | } 30 | ], 31 | "import/order": [ 32 | "error", 33 | { 34 | "groups": [ 35 | ["internal", "external", "builtin"], 36 | ["parent", "sibling", "index"] 37 | ], 38 | "alphabetize": { 39 | "order": "asc", 40 | "caseInsensitive": true 41 | } 42 | } 43 | ], 44 | "import/newline-after-import": "error", 45 | "import/no-anonymous-default-export": "error", 46 | "import/no-dynamic-require": "error", 47 | "import/no-self-import": "error", 48 | "import/no-useless-path-segments": [ 49 | "error", 50 | { 51 | "noUselessIndex": true 52 | } 53 | ], 54 | 
"import/no-relative-packages": "error", 55 | "import/no-unused-modules": "error", 56 | "import/no-deprecated": "error", 57 | "import/no-commonjs": "error", 58 | "import/no-amd": "error", 59 | "import/no-mutable-exports": "error", 60 | "import/no-unassigned-import": "error", 61 | "jest/expect-expect": "off" 62 | }, 63 | "settings": { 64 | "import/external-module-folders": ["node_modules", "node_modules/@types"], 65 | "import/parsers": { 66 | "@typescript-eslint/parser": [".ts", ".tsx"] 67 | }, 68 | "import/resolver": { 69 | "typescript": { 70 | "alwaysTryTypes": true, 71 | "project": "/tsconfig.json" 72 | }, 73 | "node": true 74 | } 75 | }, 76 | "overrides": [ 77 | { 78 | "files": ["src/**/*.ts", "src/**/*.tsx"], 79 | "extends": [ 80 | "eslint:recommended", 81 | "plugin:@typescript-eslint/eslint-recommended", 82 | "plugin:@typescript-eslint/recommended", 83 | "plugin:import/recommended", 84 | "plugin:jest/recommended" 85 | ], 86 | "plugins": ["import", "@typescript-eslint"], 87 | "rules": { 88 | "@typescript-eslint/explicit-module-boundary-types": "off", 89 | "@typescript-eslint/no-unused-vars": [ 90 | "error", 91 | { 92 | "ignoreRestSiblings": true 93 | } 94 | ], 95 | "@typescript-eslint/switch-exhaustiveness-check": "error", 96 | "@typescript-eslint/no-empty-function": "off" 97 | }, 98 | "parserOptions": { 99 | "parser": "@typescript-eslint/parser", 100 | "ecmaVersion": 2018, 101 | "sourceType": "module", 102 | "project": "./tsconfig.json" 103 | } 104 | } 105 | ] 106 | } 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # node 2 | node_modules 3 | logs 4 | *.log 5 | npm-debug.log* 6 | yarn-debug.log* 7 | yarn-error.log* 8 | lerna-debug.log* 9 | .pnpm-debug.log* 10 | /.pnp 11 | .pnp.js 12 | # misc 13 | .DS_Store 14 | *.pem 15 | # typescript 16 | *.tsbuildinfo 17 | dist/ 18 | # eslint 19 | .eslintcache 20 | # replicate token 21 | 
src/tests/token.ts -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | tasks: 2 | - init: yarn install 3 | -------------------------------------------------------------------------------- /.husky/pre-commit: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | . "$(dirname -- "$0")/_/husky.sh" 3 | 4 | FORCE_COLOR=1 "$(dirname -- "$0")/../node_modules/.bin/lint-staged" -c .lintstagedrc.json -------------------------------------------------------------------------------- /.lintstagedrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "*.+(ts|tsx)": ["prettier --write", "eslint --cache --fix", "tsc-files --noEmit"], 3 | "*.+(js|jsx)": ["prettier --write", "eslint --cache --fix"], 4 | "*.+(json|css|md|yml|yaml|scss)": ["prettier --write"] 5 | } 6 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | # node 2 | node_modules 3 | logs 4 | *.log 5 | npm-debug.log* 6 | yarn-debug.log* 7 | yarn-error.log* 8 | lerna-debug.log* 9 | .pnpm-debug.log* 10 | /.pnp 11 | .pnp.js 12 | # misc 13 | .DS_Store 14 | *.pem 15 | # typescript 16 | *.tsbuildinfo 17 | dist/ 18 | # eslint 19 | .eslintcache -------------------------------------------------------------------------------- /.prettierrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "arrowParens": "avoid", 3 | "bracketSpacing": true, 4 | "embeddedLanguageFormatting": "auto", 5 | "htmlWhitespaceSensitivity": "css", 6 | "insertPragma": false, 7 | "bracketSameLine": false, 8 | "jsxSingleQuote": false, 9 | "printWidth": 120, 10 | "proseWrap": "always", 11 | "quoteProps": "consistent", 12 | "requirePragma": false, 13 | "semi": false, 
14 | "singleQuote": false, 15 | "tabWidth": 2, 16 | "trailingComma": "es5", 17 | "useTabs": false, 18 | "vueIndentScriptAndStyle": false, 19 | "plugins": ["prettier-plugin-organize-imports"], 20 | "overrides": [ 21 | { 22 | "files": ["flake.lock"], 23 | "options": { 24 | "parser": "json" 25 | } 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "dbaeumer.vscode-eslint", 4 | "eamodio.gitlens", 5 | "wix.vscode-import-cost", 6 | "orta.vscode-jest", 7 | "esbenp.prettier-vscode" 8 | ] 9 | } 10 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.defaultFormatter": "esbenp.prettier-vscode", 3 | "[typescriptreact]": { 4 | "editor.defaultFormatter": "esbenp.prettier-vscode" 5 | }, 6 | "[typescript]": { 7 | "editor.defaultFormatter": "esbenp.prettier-vscode" 8 | }, 9 | "[javascriptreact]": { 10 | "editor.defaultFormatter": "esbenp.prettier-vscode" 11 | }, 12 | "[javascript]": { 13 | "editor.defaultFormatter": "esbenp.prettier-vscode" 14 | }, 15 | "[jsonc]": { 16 | "editor.defaultFormatter": "esbenp.prettier-vscode" 17 | }, 18 | "[json]": { 19 | "editor.defaultFormatter": "esbenp.prettier-vscode" 20 | }, 21 | "editor.formatOnSave": true, 22 | "editor.codeActionsOnSave": { 23 | "source.fixAll.eslint": "explicit" 24 | }, 25 | "gitlens.hovers.currentLine.over": "line", 26 | "editor.tabCompletion": "on", 27 | "editor.inlineSuggest.enabled": true, 28 | "typescript.updateImportsOnFileMove.enabled": "always", 29 | "eslint.packageManager": "yarn", 30 | "eslint.validate": ["javascript", "typescript", "html", "javascriptreact", "typescriptreact"], 31 | "editor.quickSuggestions": { 32 | "strings": true 33 | }, 34 | 
"typescript.preferences.importModuleSpecifier": "non-relative", 35 | "jest.nodeEnv": { 36 | "NODE_OPTIONS": "--experimental-vm-modules" 37 | }, 38 | "jest.autoRun": "off", 39 | "jest.showCoverageOnLoad": false 40 | } 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zebreus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # replicate-api 2 | 3 | A typed client library for the [replicate.com](https://replicate.com/) API. 4 | 5 | You can use this to access the prediction API in a type-safe and convenient way. 
6 | 7 | ## Install 8 | 9 | Just install it with your favorite package manager: 10 | 11 | ```bash 12 | yarn add replicate-api 13 | pnpm add replicate-api 14 | npm install replicate-api 15 | ``` 16 | 17 | The package should work in the browser and in Node.js [versions 18 and up](#older-node-versions). 18 | 19 | ## Obtain an API token 20 | 21 | You need an API token for nearly all operations. You can find the token in your 22 | [account settings](https://replicate.com/account). 23 | 24 | ## Examples 25 | 26 | ### Generate an image with stable-diffusion 27 | 28 | You can create a new prediction using the 29 | [`stability-ai/stable-diffusion`](https://replicate.com/stability-ai/stable-diffusion) model and wait for the result 30 | with: 31 | 32 | ```typescript 33 | const prediction = await predict({ 34 | model: "stability-ai/stable-diffusion", // The model name 35 | input: { prompt: "multicolor hyperspace" }, // The model specific input 36 | token: "...", // You need a token from replicate.com 37 | poll: true, // Wait for the model to finish 38 | }) 39 | 40 | console.log(prediction.output[0]) 41 | // https://replicate.com/api/models/stability-ai/stable-diffusion/files/58a1dcfc-3d5d-4297-bac2-5395294fe463/out-0.png 42 | ``` 43 | 44 | This does some work for you, like resolving the model name to a model version and polling until the prediction is 45 | completed. 46 | 47 | ### Create a new prediction 48 | 49 | ```typescript 50 | const result = await predict({ model: "replicate/hello-world", input: { prompt: "..." }, token: "..." }) 51 | ``` 52 | 53 | Then you can check `result.status` to see if it's `"starting"`, `"processing"`, `"succeeded"`, `"failed"` or `"canceled"`. If it's `"succeeded"` you 54 | can get the outputs with `result.output`. If it's still running, you can check back later with `getPrediction()` and the id from 55 | `result` (`result.id`). 56 | 57 | You can also set `poll: true` in the options of `predict()` to wait until it has finished.
If you don't do that, you can 58 | still use `.poll()` to poll until the prediction is done. 59 | 60 | ### Wait until a prediction is finished 61 | 62 | ```typescript 63 | // If you have a PredictionState: 64 | const finishedPrediction = await prediction.poll() 65 | 66 | // If you only have the prediction ID: 67 | const finishedPrediction = await pollPrediction({ id, token: "..." }) 68 | 69 | // If you are creating a new prediction anyway: 70 | const finishedPrediction = await predict({ ...otherOptions, poll: true }) 71 | ``` 72 | 73 | ### Retrieve the current state of a prediction 74 | 75 | ```typescript 76 | // If you have a PredictionState: 77 | const currentPrediction = await prediction.get() 78 | 79 | // If you only have the prediction ID: 80 | const currentPrediction = await getPrediction({ id, token: "..." }) 81 | ``` 82 | 83 | ### Cancel a running prediction 84 | 85 | ```typescript 86 | // If you have a PredictionState: 87 | const currentPrediction = await prediction.cancel() 88 | 89 | // If you only have the prediction ID: 90 | const currentPrediction = await cancelPrediction({ id, token: "..." }) 91 | ``` 92 | 93 | Canceling the prediction also returns the state of the prediction after canceling. 94 | 95 | ### Get information about a model 96 | 97 | ```typescript 98 | const info = await getModel({ model: "replicate/hello-world", token: "..." }) 99 | ``` 100 | 101 | ### Get a list of all versions of a model 102 | 103 | ```typescript 104 | const versions = await listVersions({ model: "replicate/hello-world", token: "..." }) 105 | ``` 106 | 107 | ### Generate a prediction without using the convenience functions 108 | 109 | The first example used a few convenience functions to make it easier to use the API. You can also use the lower-level 110 | functions that map the API calls more directly. 111 | 112 | ```typescript 113 | const model = await getModel({ model: "stability-ai/stable-diffusion", token: "..."
}) 114 | 115 | let prediction = await predict({ 116 | version: model.version, 117 | input: { prompt: "multicolor hyperspace" }, 118 | token: "...", 119 | }) 120 | 121 | // pollPrediction does this a bit smarter, with increasing backoff 122 | while (prediction.status === "starting" || prediction.status === "processing") { 123 | await new Promise(resolve => setTimeout(resolve, 1000)) 124 | prediction = await getPrediction({ id: prediction.id, token: "..." }) 125 | } 126 | 127 | console.log(prediction.output[0]) 128 | // https://replicate.com/api/models/stability-ai/stable-diffusion/files/58a1dcfc-3d5d-4297-bac2-5395294fe463/out-0.png 129 | ``` 130 | 131 | ### List your past predictions 132 | 133 | ```typescript 134 | const result = await listPredictions({ 135 | token: "...", 136 | }) 137 | ``` 138 | 139 | Returns up to 100 predictions. To get more, use the `next` function: 140 | 141 | ```typescript 142 | const moreResults = await result.next() 143 | ``` 144 | 145 | You can also set `all: true` to get all predictions. 146 | 147 | ### Use files in your inputs 148 | 149 | To use file inputs, you need to pass them as URLs. You can use the `loadFile` function to convert local files to base64 150 | data URLs: 151 | 152 | ```typescript 153 | const testaudioURL = await loadFile("./testaudio.mp3") 154 | // "data:audio/mpeg;base64,..." 155 | ``` 156 | 157 | Instead of a data URL, you can also pass an HTTPS URL to load the file from the web.
158 | 159 | ### Transcribe audio with whisper 160 | 161 | You can create a new prediction for the [`openai/whisper`](https://replicate.com/openai/whisper) model and wait for the 162 | result with: 163 | 164 | ```typescript 165 | const prediction = await predict({ 166 | model: "openai/whisper", // The model name 167 | input: { 168 | audio: await loadFile("./testaudio.mp3"), // Load local file as base64 data URL 169 | // audio: "https://raw.githubusercontent.com/zebreus/replicate-api/master/testaudio.mp3", // Load from a URL 170 | model: "base", 171 | }, // The model specific input 172 | token: "...", // You need a token from replicate.com 173 | poll: true, // Wait for the model to finish 174 | }) 175 | 176 | console.log(prediction.output.transcription) 177 | // Transcribed text 178 | ``` 179 | 180 | ## Related projects 181 | 182 | - [replicate-js](https://github.com/nicholascelestin/replicate-js) - A js object-oriented client for replicate 183 | 184 | ## Older node versions 185 | 186 | This package uses the `fetch` API, which is only built into Node.js 18 and up. If you need to use an older version of 187 | Node.js, you can install `node-fetch`. It will be detected and used automatically if your node does not provide a native 188 | fetch. The Options object supports passing a custom fetch function; you can also try to pass `node-fetch` there. 189 | 190 | ## Building and testing this package 191 | 192 | To run the tests for this package you need an API token from your [account settings](https://replicate.com/account). Then create a `src/tests/token.ts` 193 | file that exports the token as a string like `export const token = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"`. Now you 194 | can run `yarn test` to run the tests.
195 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "locked": { 5 | "lastModified": 1659877975, 6 | "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", 7 | "owner": "numtide", 8 | "repo": "flake-utils", 9 | "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", 10 | "type": "github" 11 | }, 12 | "original": { 13 | "owner": "numtide", 14 | "repo": "flake-utils", 15 | "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", 16 | "type": "github" 17 | } 18 | }, 19 | "nixpkgs": { 20 | "locked": { 21 | "lastModified": 1664383307, 22 | "narHash": "sha256-yvw3b8VOfcZtzoP5OKh0mVvoHglbEQhes6RSERtlxrE=", 23 | "owner": "nixos", 24 | "repo": "nixpkgs", 25 | "rev": "07b207c5e9a47b640fe30861c9eedf419c38dce0", 26 | "type": "github" 27 | }, 28 | "original": { 29 | "owner": "nixos", 30 | "repo": "nixpkgs", 31 | "rev": "07b207c5e9a47b640fe30861c9eedf419c38dce0", 32 | "type": "github" 33 | } 34 | }, 35 | "root": { 36 | "inputs": { 37 | "flake-utils": "flake-utils", 38 | "nixpkgs": "nixpkgs" 39 | } 40 | } 41 | }, 42 | "root": "root", 43 | "version": 7 44 | } 45 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | inputs = { 3 | nixpkgs.url = 4 | "github:nixos/nixpkgs?rev=07b207c5e9a47b640fe30861c9eedf419c38dce0"; 5 | flake-utils.url = 6 | "github:numtide/flake-utils?rev=c0e246b9b83f637f4681389ecabcb2681b4f3af0"; 7 | }; 8 | 9 | outputs = { self, nixpkgs, flake-utils }: 10 | flake-utils.lib.simpleFlake { 11 | inherit self nixpkgs; 12 | name = "Package name"; 13 | shell = { pkgs }: 14 | pkgs.mkShell { 15 | buildInputs = with pkgs; [ nodejs yarn ]; 16 | shellHook = '' 17 | export PATH="$(pwd)/node_modules/.bin:$PATH" 18 | ''; 19 | }; 20 | 21 | }; 22 | } 
-------------------------------------------------------------------------------- /jest.config.json: -------------------------------------------------------------------------------- 1 | { 2 | "extensionsToTreatAsEsm": [".ts"], 3 | "transform": { 4 | "^.+.m?tsx?$": "@zebreus/resolve-tspaths/jest" 5 | }, 6 | "coverageThreshold": { 7 | "global": { 8 | "statements": 95 9 | } 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "replicate-api", 3 | "version": "0.4.5", 4 | "description": "A typed client library for the replicate.com API", 5 | "author": { 6 | "name": "Zebreus", 7 | "email": "lennarteichhorn@googlemail.com" 8 | }, 9 | "repository": { 10 | "type": "git", 11 | "url": "https://github.com/Zebreus/replicate-api" 12 | }, 13 | "license": "MIT", 14 | "type": "module", 15 | "devDependencies": { 16 | "@types/eslint": "^8.4.6", 17 | "@types/jest": "^29.0.3", 18 | "@types/node": "^18.7.18", 19 | "@typescript-eslint/eslint-plugin": "^5.37.0", 20 | "@typescript-eslint/parser": "^5.37.0", 21 | "@zebreus/resolve-tspaths": "^0.8.10", 22 | "eslint": "^8.23.1", 23 | "eslint-import-resolver-typescript": "^3.5.1", 24 | "eslint-plugin-import": "^2.26.0", 25 | "eslint-plugin-jest": "^27.0.4", 26 | "husky": "^8.0.1", 27 | "jest": "^29.0.3", 28 | "lint-staged": "^13.0.3", 29 | "prettier": "^2.7.1", 30 | "prettier-plugin-organize-imports": "^3.1.1", 31 | "ts-jest": "^29.0.1", 32 | "ts-node": "^10.9.1", 33 | "tsc-files": "^1.1.3", 34 | "typescript": "^4.8.3" 35 | }, 36 | "scripts": { 37 | "lint": "tsc --noEmit && prettier . 
--check && eslint --cache --ignore-path .gitignore --ext ts,js,tsx,jsx .", 38 | "build": "rm -rf dist && tsc -p tsconfig.build.json && resolve-tspaths -p tsconfig.build.json", 39 | "prepack": "rm -rf dist && tsc -p tsconfig.build.json && resolve-tspaths -p tsconfig.build.json", 40 | "format": "prettier --write .", 41 | "test": "NODE_OPTIONS='--experimental-vm-modules' jest", 42 | "prepare": "husky install" 43 | }, 44 | "files": [ 45 | "dist/**" 46 | ], 47 | "keywords": [ 48 | "library", 49 | "replicate", 50 | "api", 51 | "stable-diffusion", 52 | "ai" 53 | ], 54 | "main": "dist/index.js", 55 | "engines": { 56 | "node": ">=16.0.0" 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/cancelPrediction.ts: --------------------------------------------------------------------------------
1 | import { getPrediction } from "getPrediction"
2 | import { makeApiRequest, ReplicateRequestOptions } from "helpers/makeApiRequest"
3 | 
4 | export type CancelPredictionOptions = {
5 |   /** The id of a prediction */
6 |   id: string
7 | } & ReplicateRequestOptions
8 | 
9 | /** Cancel a running prediction.
10 |  * ```typescript
11 |  * const result = await cancelPrediction({
12 |  *   id: "ID of your prediction",
13 |  *   token: "Get your token at https://replicate.com/account"
14 |  * })
15 |  * ```
16 |  * If you have a `PredictionState`, you don't have to use this function, just call `.cancel()` on that object.
17 |  *
18 |  * @param options The id of the prediction plus the usual request options (token, optional fetch and apiUrl).
19 |  * @returns A `PredictionState` representing the new state.
20 |  */
21 | export const cancelPrediction = async (options: CancelPredictionOptions) => {
22 |   await makeApiRequest(options, "POST", `predictions/${options.id}/cancel`)
23 |   // The cancel response is discarded; the prediction is fetched again so callers get a full `PredictionState`.
24 |   return await getPrediction(options)
25 | }
26 | 
--------------------------------------------------------------------------------
/src/getModel.ts:
--------------------------------------------------------------------------------
1 | import { extractModelAndOwner } from "helpers/extractModelAndOwner"
2 | import { makeApiRequest, ReplicateRequestOptions } from "helpers/makeApiRequest"
3 | import { ModelNameOptions } from "predict"
4 | 
5 | export type ResolveModelOptions = ModelNameOptions & ReplicateRequestOptions
6 | 
7 | /** A model version in the shape returned by the replicate API */
8 | export type ModelVersionResponse = {
9 |   id: string
10 |   created_at: string
11 |   cog_version: string
12 |   openapi_schema: unknown
13 | }
14 | 
15 | /** Raw response of the `models/{owner}/{model}` endpoint */
16 | type ModelResponse = {
17 |   url: string
18 |   owner: string
19 |   name: string
20 |   description: null | string
21 |   visibility: "public" | "private"
22 |   github_url: null | string
23 |   paper_url: null | string
24 |   license_url: null | string
25 |   latest_version: ModelVersionResponse
26 | }
27 | 
28 | /** Get information about a model
29 |  * ```typescript
30 |  * const model = await getModel({
31 |  *   model: "stability-ai/stable-diffusion",
32 |  *   token: "Get your token at https://replicate.com/account"
33 |  * })
34 |  *
35 |  * const version = model.version
36 |  * ```
37 |  *
38 |  * @returns The model metadata with `null` fields mapped to `undefined` and `version` set to the id of the latest version.
39 |  * @throws If `options.model` is not in the form `owner/model`.
40 |  */
41 | export const getModel = async (options: ResolveModelOptions) => {
42 |   const { owner, model } = extractModelAndOwner(options.model)
43 |   const response = await makeApiRequest<ModelResponse>(options, "GET", `models/${owner}/${model}`)
44 | 
45 |   const result = {
46 |     url: response.url,
47 |     owner: response.owner,
48 |     name: response.name,
49 |     description: response.description ?? undefined,
50 |     visibility: response.visibility,
51 |     github: response.github_url ?? undefined,
52 |     paper: response.paper_url ?? undefined,
53 |     license: response.license_url ?? undefined,
54 |     version: response.latest_version.id,
55 |   }
56 | 
57 |   return result
58 | }
59 | 
--------------------------------------------------------------------------------
/src/getPrediction.ts:
--------------------------------------------------------------------------------
1 | import { convertPrediction, PredictionResponse } from "helpers/convertPrediction"
2 | import { makeApiRequest, ReplicateRequestOptions } from "helpers/makeApiRequest"
3 | 
4 | export type GetPredictionOptions = {
5 |   /** The ID of a prediction */
6 |   id: string
7 | } & ReplicateRequestOptions
8 | 
9 | /** Get the `PredictionState` for a given ID.
10 |  *
11 |  * ```typescript
12 |  * const result = await getPrediction({
13 |  *   id: "ID of your prediction",
14 |  *   token: "Get your token at https://replicate.com/account"
15 |  * })
16 |  * ```
17 |  *
18 |  * @returns A new `PredictionState`.
19 |  */
20 | export const getPrediction = async (options: GetPredictionOptions) => {
21 |   const response = await makeApiRequest<PredictionResponse>(options, "GET", `predictions/${options.id}`)
22 |   return convertPrediction(options, response)
23 | }
24 | 
--------------------------------------------------------------------------------
/src/helpers/convertPrediction.ts:
--------------------------------------------------------------------------------
1 | import { cancelPrediction } from "cancelPrediction"
2 | import { getPrediction } from "getPrediction"
3 | import { ReplicateRequestOptions } from "helpers/makeApiRequest"
4 | import { pollPrediction } from "pollPrediction"
5 | 
6 | export type PredictionStatus = "starting" | "processing" | "succeeded" | "failed" | "canceled"
7 | 
8 | /** Raw prediction object as returned by the replicate API */
9 | export type PredictionResponse = {
10 |   id: string
11 |   version: string
12 |   urls: {
13 |     get: string
14 |     cancel: string
15 |   }
16 |   created_at: string
17 |   started_at: string | null
18 |   completed_at: string | null
19 |   status: PredictionStatus
20 |   input: Record<string, unknown>
21 |   output: unknown
22 |   // NOTE(review): typed as always `null`; failed predictions may report a message here — confirm against the API
23 |   error: null
24 |   logs: null | string
25 |   metrics: {
26 |     /** In seconds */
27 |     predict_time?: number
28 |   }
29 | }
30 | 
31 | /** Status of a prediction
32 |  *
33 |  * The status is not automatically updated. You can use the `.get()` function on this object to get a new object with the current state on `replicate.com`. You can use `.poll()` to wait for the prediction to finish.
34 |  */
35 | export type PredictionState = {
36 |   /** The id of this prediction */
37 |   id: string
38 |   /** The version of the model used to generate this prediction */
39 |   version: string
40 |   /** Get the updated state of this prediction from replicate
41 |    *
42 |    * @returns A new `PredictionState` representing the current state.
43 |    */
44 |   get: () => Promise<PredictionState>
45 |   /** Cancel a running prediction.
46 |    *
47 |    * @returns A new `PredictionState` representing the updated state.
48 |    */
49 |   cancel: () => Promise<PredictionState>
50 |   /** Poll until the prediction is completed or failed
51 |    *
52 |    * If the timeout occurs an error is thrown.
53 |    *
54 |    * @param timeout The timeout in milliseconds
55 |    * @returns The `PredictionState` for the finished prediction. It has a status of either "succeeded", "failed" or "canceled".
56 |    */
57 |   poll: (timeout?: number) => Promise<PredictionState>
58 |   /** When the prediction was created */
59 |   createdAt?: Date
60 |   /** When execution of the prediction was started */
61 |   startedAt?: Date
62 |   /** When execution of the prediction was completed (or canceled) */
63 |   completedAt?: Date
64 |   /** The status of the prediction */
65 |   status: PredictionStatus
66 |   /** The input parameters */
67 |   input: Record<string, unknown>
68 |   /** The output parameters */
69 |   output: unknown
70 |   /** The error reported by replicate (passed through unchanged from the response) */
71 |   error: null
72 |   /** The logs of the prediction. A string separated by newlines */
73 |   logs?: string
74 |   /** Metrics about the prediction */
75 |   metrics: {
76 |     /** In seconds */
77 |     predictTime?: number
78 |   }
79 | }
80 | 
81 | /** Convert the result that we get from replicate to a more idiomatic TypeScript object.
82 |  *
83 |  * Snake-case fields are renamed to camelCase, timestamps are parsed into `Date` objects, and missing timestamps or
84 |  * logs become `undefined`. Also adds `get()`, `cancel()` and `poll()` methods that reuse `options` for their requests.
85 |  */
86 | export const convertPrediction = (
87 |   options: ReplicateRequestOptions,
88 |   prediction: PredictionResponse
89 | ): PredictionState => {
90 |   const state: PredictionState = {
91 |     id: prediction.id,
92 |     version: prediction.version,
93 |     get: async () => await getPrediction({ ...options, id: prediction.id }),
94 |     cancel: async () => await cancelPrediction({ ...options, id: prediction.id }),
95 |     poll: async timeout => await pollPrediction({ ...options, id: prediction.id, timeout: timeout }, state),
96 |     createdAt: prediction.created_at ? new Date(prediction.created_at) : undefined,
97 |     startedAt: prediction.started_at ? new Date(prediction.started_at) : undefined,
98 |     completedAt: prediction.completed_at ? new Date(prediction.completed_at) : undefined,
99 |     status: prediction.status,
100 |     input: prediction.input,
101 |     output: prediction.output,
102 |     error: prediction.error,
103 |     logs: prediction.logs ?? undefined,
104 |     metrics: {
105 |       // Optional chaining guards against a response that lacks `metrics`
106 |       predictTime: prediction.metrics?.predict_time,
107 |     },
108 |   }
109 | 
110 |   return state
111 | }
112 | 
--------------------------------------------------------------------------------
/src/helpers/convertShallowPrediction.ts:
--------------------------------------------------------------------------------
1 | import { cancelPrediction } from "cancelPrediction"
2 | import { getPrediction } from "getPrediction"
3 | import { PredictionResponse, PredictionState } from "helpers/convertPrediction"
4 | import { ReplicateRequestOptions } from "helpers/makeApiRequest"
5 | import { pollPrediction } from "pollPrediction"
6 | 
7 | /** List does not return full predictions. This is the type for those response elements */
8 | export type ShallowPredictionResponse = Pick<
9 |   PredictionResponse,
10 |   "id" | "version" | "created_at" | "started_at" | "completed_at" | "status"
11 | >
12 | 
13 | /** Status of a prediction without the actual results
14 |  *
15 |  * You can use `.get()` to get the full `PredictionState`.
16 |  */
17 | export type ShallowPredictionState = Pick<
18 |   PredictionState,
19 |   "id" | "version" | "createdAt" | "startedAt" | "completedAt" | "status" | "get" | "cancel" | "poll"
20 | >
21 | 
22 | /** Convert prediction list entries from replicate to a more idiomatic TypeScript object.
23 |  *
24 |  * Also adds `get()`, `cancel()` and `poll()` methods that reuse `options` for their requests.
25 |  */
26 | export const convertShallowPrediction = (
27 |   options: ReplicateRequestOptions,
28 |   prediction: ShallowPredictionResponse
29 | ): ShallowPredictionState => {
30 |   const state: ShallowPredictionState = {
31 |     id: prediction.id,
32 |     version: prediction.version,
33 |     get: async () => await getPrediction({ ...options, id: prediction.id }),
34 |     cancel: async () => await cancelPrediction({ ...options, id: prediction.id }),
35 |     // Only a list entry is known here, so no current state is handed to pollPrediction
36 |     poll: async timeout => await pollPrediction({ ...options, id: prediction.id, timeout: timeout }),
37 |     createdAt: prediction.created_at ? new Date(prediction.created_at) : undefined,
38 |     startedAt: prediction.started_at ? new Date(prediction.started_at) : undefined,
39 |     completedAt: prediction.completed_at ? new Date(prediction.completed_at) : undefined,
40 |     status: prediction.status,
41 |   }
42 | 
43 |   return state
44 | }
45 | 
--------------------------------------------------------------------------------
/src/helpers/extractModelAndOwner.ts:
--------------------------------------------------------------------------------
1 | /** Converts a single string in the form of `owner/model` to an object with `owner` and `model` properties.
2 |  *
3 |  * @param ownerModel A string like `stability-ai/stable-diffusion`
4 |  * @throws If the string is not a non-empty owner and a non-empty model separated by exactly one slash
5 |  */
6 | export const extractModelAndOwner = (ownerModel: string) => {
7 |   const parts = ownerModel.split("/")
8 | 
9 |   // Exactly one slash is allowed, so extra segments like in `a/b/c` are rejected instead of silently dropped
10 |   if (parts.length !== 2) {
11 |     throw new Error("model must be in the form owner/model")
12 |   }
13 | 
14 |   const [owner, model] = parts
15 | 
16 |   if (!owner) {
17 |     throw new Error("The model name must contain the owner before the slash")
18 |   }
19 | 
20 |   if (!model) {
21 |     throw new Error("The model name must contain the model after the slash")
22 |   }
23 | 
24 |   return {
25 |     owner,
26 |     model,
27 |   }
28 | }
29 | 
--------------------------------------------------------------------------------
/src/helpers/loadFile.ts:
--------------------------------------------------------------------------------
1 | import { readFile } from "fs/promises"
2 | 
3 | /** Loads a file from the filesystem and returns a base64 encoded data URL.
4 |  *
5 |  * @param path Path of the file to read. Its extension determines the MIME type (see `guessMimeType`).
6 |  * @returns A data URL like `data:audio/mpeg;base64,...`
7 |  */
8 | export const loadFile = async (path: string) => {
9 |   const mimeType = guessMimeType(path)
10 |   const content = await readFile(path, "base64")
11 |   return `data:${mimeType};base64,${content}`
12 | }
13 | 
14 | /** Known file extensions (lowercase, without the leading dot) mapped to their MIME types */
15 | const mimeTypesByExtension: Record<string, string> = {
16 |   "aac": "audio/aac",
17 |   "abw": "application/x-abiword",
18 |   "arc": "application/octet-stream",
19 |   "avi": "video/x-msvideo",
20 |   "azw": "application/vnd.amazon.ebook",
21 |   "bin": "application/octet-stream",
22 |   "bz": "application/x-bzip",
23 |   "bz2": "application/x-bzip2",
24 |   "csh": "application/x-csh",
25 |   "css": "text/css",
26 |   "csv": "text/csv",
27 |   "doc": "application/msword",
28 |   "epub": "application/epub+zip",
29 |   "gif": "image/gif",
30 |   "htm": "text/html",
31 |   "html": "text/html",
32 |   "ico": "image/x-icon",
33 |   "ics": "text/calendar",
34 |   "jar": "application/java-archive",
35 |   "jpeg": "image/jpeg",
36 |   "jpg": "image/jpeg",
37 |   "js": "application/javascript",
38 |   "json": "application/json",
39 |   "mid": "audio/midi",
40 |   "midi": "audio/midi",
41 |   "mpeg": "video/mpeg",
42 |   "mp3": "audio/mpeg",
43 |   "mp4": "video/mp4",
44 |   "mpkg": "application/vnd.apple.installer+xml",
45 |   "odp": "application/vnd.oasis.opendocument.presentation",
46 |   "ods": "application/vnd.oasis.opendocument.spreadsheet",
47 |   "odt": "application/vnd.oasis.opendocument.text",
48 |   "oga": "audio/ogg",
49 |   "ogv": "video/ogg",
50 |   "ogx": "application/ogg",
51 |   "pdf": "application/pdf",
52 |   "ppt": "application/vnd.ms-powerpoint",
53 |   "rar": "application/x-rar-compressed",
54 |   "rtf": "application/rtf",
55 |   "sh": "application/x-sh",
56 |   "svg": "image/svg+xml",
57 |   "swf": "application/x-shockwave-flash",
58 |   "tar": "application/x-tar",
59 |   "tif": "image/tiff",
60 |   "tiff": "image/tiff",
61 |   "ttf": "font/ttf",
62 |   "vsd": "application/vnd.visio",
63 |   "wav": "audio/x-wav",
64 |   "weba": "audio/webm",
65 |   "webm": "video/webm",
66 |   "webp": "image/webp",
67 |   "woff": "font/woff",
68 |   "woff2": "font/woff2",
69 |   "xhtml": "application/xhtml+xml",
70 |   "xls": "application/vnd.ms-excel",
71 |   "xml": "application/xml",
72 |   "xul": "application/vnd.mozilla.xul+xml",
73 |   "zip": "application/zip",
74 |   "3gp": "video/3gpp",
75 |   "3g2": "video/3gpp2",
76 |   "7z": "application/x-7z-compressed",
77 | }
78 | 
79 | /** Guess the MIME type of a file from the extension of its path.
80 |  *
81 |  * The extension is taken from the last path segment and matched case-insensitively, so `Song.MP3` gives `audio/mpeg`.
82 |  *
83 |  * @returns The matching MIME type, or `application/octet-stream` if the extension is missing or unknown.
84 |  */
85 | export const guessMimeType = (path: string) => {
86 |   // Only the last path segment counts, so dots in directory names are ignored
87 |   const fileName = path.slice(Math.max(path.lastIndexOf("/"), path.lastIndexOf("\\")) + 1)
88 |   const dotIndex = fileName.lastIndexOf(".")
89 |   // A name without a dot (e.g. "zip") has no extension and must not be looked up as one
90 |   const extension = dotIndex === -1 ? "" : fileName.slice(dotIndex + 1).toLowerCase()
91 | 
92 |   // Own-property check so names like "constructor" do not resolve to inherited Object.prototype members
93 |   const mimeType = Object.prototype.hasOwnProperty.call(mimeTypesByExtension, extension)
94 |     ? mimeTypesByExtension[extension]
95 |     : undefined
96 | 
97 |   return mimeType ?? "application/octet-stream"
98 | }
99 | 
--------------------------------------------------------------------------------
/src/helpers/makeApiRequest.ts:
--------------------------------------------------------------------------------
1 | /** Interface of a fetch function.
Compatible with the [fetch API](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API/Using_Fetch) */
2 | export type FetchFunction = (
3 |   url: string,
4 |   // eslint-disable-next-line @typescript-eslint/no-explicit-any
5 |   config: Record<string, any>
6 |   // eslint-disable-next-line @typescript-eslint/no-explicit-any
7 | ) => Promise<{ json: () => Promise<any>; ok: boolean; status: number }>
8 |
9 | /** Basic options for every request */
10 | export type ReplicateRequestOptions = {
11 |   /** Use a custom fetch function. Defaults to the native fetch or `node-fetch` */
12 |   fetch?: FetchFunction
13 |   /** You need an https://replicate.com API token for nearly all operations. You can generate the token in your [account settings](https://replicate.com/account). */
14 |   token: string
15 |   /** The actual API endpoint. Defaults to "https://api.replicate.com/v1/". A missing trailing slash is added automatically. */
16 |   apiUrl?: string
17 | }
18 |
19 | // webpackIgnore: true
20 | const nodeFetch = "node-fetch"
21 |
22 | /** Fetch implementation used when no `fetch` option is given: the global `fetch` if there is one;
23 |  * otherwise, if `self` is undefined (presumably a Node-like runtime), a promise for the default export of
24 |  * `node-fetch` that resolves to `undefined` if the import fails; otherwise `undefined`.
25 |  */
26 | const defaultFetch: FetchFunction | undefined | Promise<FetchFunction | undefined> =
27 |   typeof fetch !== "undefined"
28 |     ? fetch
29 |     : typeof self === "undefined"
30 |     ? // eslint-disable-next-line import/no-unresolved
31 |       import(/* webpackIgnore: true */ nodeFetch).then(module => module.default).catch(() => undefined)
32 |     : undefined
33 |
34 | /** Appends `endpoint` to `apiUrl`, inserting a `/` between them if `apiUrl` does not end with one. */
35 | const joinUrl = (apiUrl: string, endpoint: string) => (apiUrl.endsWith("/") ? apiUrl : apiUrl + "/") + endpoint
36 |
37 | /** Parses a response body as JSON, returning `undefined` instead of throwing if it is not valid JSON. */
38 | const tryReadJson = async (response: { json: () => Promise<unknown> }): Promise<unknown> => {
39 |   try {
40 |     return await response.json()
41 |   } catch {
42 |     return undefined
43 |   }
44 | }
45 |
46 | /** Sends a request to the replicate API and returns the parsed JSON body.
47 |  *
48 |  * @param options Request options; see `ReplicateRequestOptions` for the `fetch` and `apiUrl` defaults
49 |  * @param method HTTP method
50 |  * @param endpoint Path relative to `apiUrl`, e.g. `predictions`
51 |  * @param content Request body; serialized with `JSON.stringify` and only sent with `POST`
52 |  * @returns The response body, cast (not validated) to `ExpectedResponse`
53 |  * @throws If no fetch implementation is available, or if the response is not ok. The error message is the
54 |  *   `detail` string of the error body when there is one, otherwise the status code and the body.
55 |  */
56 | export async function makeApiRequest<ExpectedResponse>(
57 |   { fetch: passedFetchFunction, token, apiUrl = "https://api.replicate.com/v1/" }: ReplicateRequestOptions,
58 |   method: "POST" | "GET",
59 |   endpoint: string,
60 |   content?: object
61 | ) {
62 |   const url = joinUrl(apiUrl, endpoint)
63 |   const body = method === "POST" && content ? JSON.stringify(content) : null
64 |   const fetchFunction = await (passedFetchFunction || defaultFetch)
65 |
66 |   if (!fetchFunction) {
67 |     throw new Error("fetch is not available. Use node >= 18 or install node-fetch")
68 |   }
69 |
70 |   const response = await fetchFunction(url, {
71 |     method,
72 |     body,
73 |     headers: {
74 |       "Authorization": `Token ${token}`,
75 |       "Content-Type": "application/json",
76 |     },
77 |   })
78 |
79 |   if (!response.ok) {
80 |     // Error bodies are not guaranteed to be JSON; a parse failure must not hide the HTTP status.
81 |     const errorBody = await tryReadJson(response)
82 |     const detail =
83 |       typeof errorBody === "object" && errorBody !== null ? (errorBody as { detail?: unknown }).detail : undefined
84 |     if (typeof detail === "string") {
85 |       throw new Error(detail)
86 |     }
87 |     const bodyText = errorBody === undefined ? "<no JSON body>" : JSON.stringify(errorBody)
88 |     throw new Error(`Request failed (${response.status}): ${bodyText}`)
89 |   }
90 |
91 |   return (await response.json()) as ExpectedResponse
92 | }
93 |
--------------------------------------------------------------------------------
/src/helpers/uploadFile.ts:
--------------------------------------------------------------------------------
1 | import { createReadStream, statSync } from "fs"
2 | import { guessMimeType } from "helpers/loadFile"
3 | import { basename } from "path"
4 |
5 | /** Upload a file to replicate.com and return the serving URL.
6 |  *
7 |  * For now files are uploaded to replicate.com, using an endpoint that is probably not intended for us.
8 |  * If this breaks, future versions of this function may upload to other hosters.
9 |  *
10 |  * The endpoint is behind cloudflare, so I am not sure if this even works without a browser or captchas.
11 |  * Feel free to open issues for all the problems you encounter with this function https://github.com/zebreus/replicate-api/issues
12 |  *
13 |  * @deprecated This is highly experimental and depends on undocumented endpoints of the replicate.com website. It may break at any time. Please open an issue if it does not work for you.
14 |  * @param path Path to a local file
15 |  * @returns A URL where the file can be downloaded from
16 |  * @throws If requesting the upload URLs or uploading the file fails
17 |  */
18 | export const uploadFile = async (path: string) => {
19 |   const { uploadUrl, servingUrl } = await getFileUrls(path)
20 |
21 |   const fileSizeInBytes = statSync(path).size
22 |   const mimeType = guessMimeType(path)
23 |   const fileStream = createReadStream(path)
24 |
25 |   try {
26 |     const uploadRequest = await fetch(uploadUrl, {
27 |       method: "PUT",
28 |       headers: {
29 |         "Content-length": String(fileSizeInBytes),
30 |         "Content-type": mimeType,
31 |       },
32 |       // The cast only satisfies the DOM `fetch` typings; at runtime the Node read stream is passed as-is.
33 |       body: fileStream as unknown as ReadableStream,
34 |     })
35 |     if (!uploadRequest.ok) {
36 |       console.error(uploadRequest)
37 |       throw new Error("Failed to upload file")
38 |     }
39 |   } finally {
40 |     // Release the file even if the request failed before the stream was fully read.
41 |     fileStream.destroy()
42 |   }
43 |   return servingUrl
44 | }
45 |
46 | /** CSRF token cached for the lifetime of the module after the first successful `getCsrfToken` call. */
47 | let csrfToken: string | undefined = undefined
48 |
49 | /** Gets a CSRF token from the `csrftoken` cookie set by a replicate.com model version page.
50 |  *
51 |  * The token is cached in `csrfToken`, so only the first successful call makes a request.
52 |  *
53 |  * @throws If the response has no `set-cookie` header, or the header contains no `csrftoken`
54 |  */
55 | export const getCsrfToken = async () => {
56 |   if (csrfToken) {
57 |     return csrfToken
58 |   }
59 |   const csrfTokenRequest = await fetch(
60 |     "https://replicate.com/openai/whisper/versions/23241e5731b44fcb5de68da8ebddae1ad97c5094d24f94ccb11f7c1d33d661e2",
61 |     { method: "GET" }
62 |   )
63 |   const setCookieHeader = csrfTokenRequest.headers.get("set-cookie")
64 |   if (!setCookieHeader) {
65 |     throw new Error("Failed to get CSRF token, no set-cookie header")
66 |   }
67 |   const receivedToken = setCookieHeader.match(/csrftoken=([^;]+)/)?.[1]
68 |   if (!receivedToken) {
69 |     throw new Error("Failed to get CSRF token, no token in set-cookie header")
70 |   }
71 |   csrfToken = receivedToken
72 |   return csrfToken
73 | }
74 |
75 | /** Asks replicate.com for a pair of URLs for uploading the file at `path`.
76 |  *
77 |  * @param path Path to a local file; only its name and guessed MIME type are sent
78 |  * @returns `uploadUrl` to `PUT` the file to, and `servingUrl` where it can be downloaded afterwards
79 |  * @throws If the request fails or the response does not contain both URLs as non-empty strings
80 |  */
81 | export const getFileUrls = async (path: string) => {
82 |   const mimeType = guessMimeType(path)
83 |   // The file name becomes a URL path segment, so spaces, `?`, `#` etc. must be escaped.
84 |   const filename = encodeURIComponent(basename(path))
85 |   const url = `https://replicate.com/api/upload/${filename}?content_type=${encodeURIComponent(mimeType)}`
86 |   const token = await getCsrfToken()
87 |   const request = await fetch(url, {
88 |     method: "POST",
89 |     headers: {
90 |       "cookie": `csrftoken=${token};`,
91 |       "origin": "https://replicate.com",
92 |       "x-csrftoken": token,
93 |     },
94 |   })
95 |   if (!request.ok) {
96 |     throw new Error(`Failed to get file URLs, request failed with status ${request.status}`)
97 |   }
98 |   const response: unknown = await request.json()
99 |   // `typeof null` is "object", so null has to be excluded explicitly.
100 |   if (typeof response !== "object" || response === null) {
101 |     throw new Error("Failed to get file URLs")
102 |   }
103 |   const { upload_url: uploadUrl, serving_url: servingUrl } = response as Record<string, unknown>
104 |   // Missing or empty URLs are rejected here instead of failing later with a less helpful error.
105 |   if (typeof uploadUrl !== "string" || !uploadUrl) {
106 |     throw new Error("Failed to get file URLs, got no upload url")
107 |   }
108 |   if (typeof servingUrl !== "string" || !servingUrl) {
109 |     throw new Error("Failed to get file URLs, got no serving url")
110 |   }
111 |   return { uploadUrl, servingUrl }
112 | }
113 |
--------------------------------------------------------------------------------
/src/index.ts:
--------------------------------------------------------------------------------
1 | export { cancelPrediction } from "cancelPrediction"
2 | export type { CancelPredictionOptions } from "cancelPrediction"
3 | export { getModel } from "getModel"
4 | export type { ResolveModelOptions } from "getModel"
5 | export { getPrediction } from "getPrediction"
6 | export type { GetPredictionOptions } from "getPrediction"
7 | export type { PredictionState, PredictionStatus } from "helpers/convertPrediction"
8 | export type { ShallowPredictionState } from "helpers/convertShallowPrediction"
9 | export { loadFile } from "helpers/loadFile"
10 | export type { FetchFunction, ReplicateRequestOptions } from "helpers/makeApiRequest"
11 | export { listPredictions } from "listPredictions"
12 | export type { ListOfPredictions, ListPredictionsOptions } from "listPredictions"
13 | export { listVersions } from "listVersions"
14 | export type { ListOfVersions, ListVersionsOptions } from "listVersions"
15 | export { pollPrediction } from "pollPrediction"
16 | export type { PollPredictionOptions } from "pollPrediction"
17 | export { predict } from "predict"
18 | export type { ModelNameOptions, ModelVersionOptions, PredictOptions } from "predict"
19 | export { processWebhook } from "processWebhook"
20 | export type { ProcessWebhookOptions } from "processWebhook"
21 |
--------------------------------------------------------------------------------
/src/listPredictions.ts:
--------------------------------------------------------------------------------
1 | import {
2 |   convertShallowPrediction,
3 |   ShallowPredictionResponse,
4 |   ShallowPredictionState,
5 | } from "helpers/convertShallowPrediction"
6 | import { makeApiRequest, ReplicateRequestOptions } from "helpers/makeApiRequest"
7 |
8 | export type PagedRequestOptions = {
9 |   // TODO: Maybe replace with a maxResults option?
10 |   /** Set to true to get all results */
11 |   all?: boolean
12 |   /** Request data at this location. You should probably use the `.next()` method instead */
13 |   cursor?: string
14 | }
15 |
16 | export type ListPredictionsOptions = PagedRequestOptions & ReplicateRequestOptions
17 |
18 | type ListPredictionsResponse = {
19 |   previous?: string
20 |   next?: string //"https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
21 |   results: ShallowPredictionResponse[]
22 | }
23 |
24 | export type ListOfPredictions = {
25 |   /** Up to 100 predictions */
26 |   predictions: ShallowPredictionState[]
27 |   /** Get the next predictions */
28 |   next: () => Promise<ListOfPredictions>
29 |   /** Cursor to get the next predictions manually. You should probably use the `.next()` method instead */
30 |   nextCursor?: string
31 | }
32 |
33 | const getEmptyResult = async (): Promise<ListOfPredictions> => ({
34 |   predictions: [],
35 |   next: () => getEmptyResult(),
36 | })
37 |
38 | /** List your past predictions.
39 |  *
40 |  * ```typescript
41 |  * const result = await listPredictions({
42 |  *   token: "Get your token at https://replicate.com/account"
43 |  * })
44 |  * ```
45 |  *
46 |  * Returns up to 100 predictions. To get more, use the `next` function:
47 |  * ```typescript
48 |  * const moreResults = await result.next()
49 |  * ```
50 |  *
51 |  * @returns A `ListOfPredictions` with the `ShallowPredictionState`s of this page and a `next()` function.
52 |  */
53 | export const listPredictions = async (options: ListPredictionsOptions): Promise<ListOfPredictions> => {
54 |   const query = options.cursor ? "?cursor=" + options.cursor : ""
55 |   const { results, next: nextPageUrl } = await makeApiRequest<ListPredictionsResponse>(
56 |     options,
57 |     "GET",
58 |     `predictions${query}`
59 |   )
60 |
61 |   const predictions = results.map(entry => convertShallowPrediction(options, entry))
62 |   // `next` is a full URL; only the part after `cursor=` is kept so it can be passed back as `cursor`.
63 |   const nextCursor = nextPageUrl?.split("cursor=").pop()
64 |   const hasMorePages = predictions.length > 0 && Boolean(nextCursor)
65 |
66 |   const page: ListOfPredictions = {
67 |     predictions,
68 |     next: () => (hasMorePages ? listPredictions({ ...options, cursor: nextCursor }) : getEmptyResult()),
69 |     ...(nextCursor ? { nextCursor } : {}),
70 |   }
71 |
72 |   if (!options.all) {
73 |     return page
74 |   }
75 |
76 |   // `next()` forwards `options` including `all`, so this collects every remaining page recursively.
77 |   const rest = await page.next()
78 |   return {
79 |     predictions: [...predictions, ...rest.predictions],
80 |     next: () => getEmptyResult(),
81 |   }
82 | }
83 |
--------------------------------------------------------------------------------
/src/listVersions.ts:
--------------------------------------------------------------------------------
1 | import { ModelVersionResponse } from "getModel"
2 | import { extractModelAndOwner } from "helpers/extractModelAndOwner"
3 | import { makeApiRequest, ReplicateRequestOptions } from "helpers/makeApiRequest"
4 | import { PagedRequestOptions } from "listPredictions"
5 | import { ModelNameOptions } from "predict"
6 |
7 | /** Options for `listVersions` */
8 | export type ListVersionsOptions = PagedRequestOptions & ModelNameOptions & ReplicateRequestOptions
9 |
10 | type ModelVersionsResponse = {
11 |   previous?: string
12 |   next?: string
13 |   results: Array<ModelVersionResponse>
14 | }
15 |
16 | export type ModelVersion = {
17 |   id: string
18 |   createdAt: Date
19 |   cogVersion: string
20 |   schema: unknown
21 | }
22 |
23 | export type ListOfVersions = {
24 |   /** The id of the latest version */
25 |   version: string | undefined
26 |   /** Up to 100 versions */
27 |   versions: ModelVersion[]
28 |   /** Get
the next versions */
29 |   next: () => Promise<ListOfVersions>
30 |   /** Cursor to get the next versions manually. You should probably use the `.next()` method instead */
31 |   nextCursor?: string
32 | }
33 |
34 | const getEmptyResult = async (): Promise<ListOfVersions> => ({
35 |   version: undefined,
36 |   versions: [],
37 |   next: () => getEmptyResult(),
38 | })
39 |
40 | /** List all versions that are available for a model
41 |  *
42 |  * ```typescript
43 |  * const {versions, version} = await listVersions({
44 |  *   model: "stability-ai/stable-diffusion",
45 |  *   token: "Get your token at https://replicate.com/account"
46 |  * })
47 |  * ```
48 |  *
49 |  * Returns up to 100 versions per page. Set `all` to fetch every page, or call `next()` for the next page.
50 |  * `version` is the id of the first entry of the first page.
51 |  *
52 |  * @throws If `options.model` is not of the form `owner/model`, or if the request fails
53 |  */
54 | export const listVersions = async (options: ListVersionsOptions): Promise<ListOfVersions> => {
55 |   const { owner, model } = extractModelAndOwner(options.model)
56 |   const response = await makeApiRequest<ModelVersionsResponse>(
57 |     options,
58 |     "GET",
59 |     `models/${owner}/${model}/versions${options.cursor ? "?cursor=" + options.cursor : ""}`
60 |   )
61 |
62 |   const versions = response.results.map(version => ({
63 |     id: version.id,
64 |     createdAt: new Date(version.created_at),
65 |     cogVersion: version.cog_version,
66 |     schema: version.openapi_schema,
67 |   }))
68 |   const nextCursor = response.next?.split("cursor=").pop()
69 |   const result = {
70 |     version: versions[0]?.id,
71 |     versions: versions,
72 |     next: () => (versions.length && nextCursor ? listVersions({ ...options, cursor: nextCursor }) : getEmptyResult()),
73 |     ...(nextCursor ? { nextCursor } : {}),
74 |   }
75 |
76 |   if (!options.all) {
77 |     return result
78 |   }
79 |
80 |   // `next()` forwards `options` including `all`, so this collects every remaining page recursively.
81 |   const nextResult = await result.next()
82 |   const allResults = {
83 |     version: result.version,
84 |     versions: [...result.versions, ...nextResult.versions],
85 |     next: () => getEmptyResult(),
86 |   }
87 |   return allResults
88 | }
89 |
--------------------------------------------------------------------------------
/src/pollPrediction.ts:
--------------------------------------------------------------------------------
1 | import { getPrediction } from "getPrediction"
2 | import { PredictionState } from "helpers/convertPrediction"
3 | import { ReplicateRequestOptions } from "helpers/makeApiRequest"
4 |
5 | export type PollPredictionOptions = {
6 |   /** The id of a prediction */
7 |   id: string
8 |   /** Timeout in milliseconds
9 |    * @default 3600000
10 |    */
11 |   timeout?: number
12 | } & ReplicateRequestOptions
13 |
14 | /** Delay between two status requests; it grows with the time the prediction has been running. */
15 | const getSleepDuration = (elapsedTimeMillis: number) => {
16 |   if (elapsedTimeMillis < 10000) {
17 |     return 1000
18 |   }
19 |   if (elapsedTimeMillis < 60000) {
20 |     return 5000
21 |   }
22 |   return 10000
23 | }
24 |
25 | /** Whether the prediction has one of the statuses this module treats as final. */
26 | const isFinished = (prediction: PredictionState) =>
27 |   prediction.status === "succeeded" || prediction.status === "failed" || prediction.status === "canceled"
28 |
29 | const sleep = (millis: number) => new Promise(resolve => setTimeout(resolve, millis))
30 |
31 | /** Poll a prediction by ID.
32 |  *
33 |  * ```typescript
34 |  * const result = await pollPrediction({
35 |  *   id: "ID of your prediction",
36 |  *   token: "Get your token at https://replicate.com/account"
37 |  * })
38 |  * ```
39 |  * If you have a `PredictionState`, you don't have to use this function, just call `.poll()` on that object.
40 |  *
41 |  * If the timeout occurs an error is thrown. A prediction that is already finished, or that finishes on
42 |  * the last request before the deadline, is always returned (even with a timeout of 0).
43 |  *
44 |  * @param options The prediction id, an optional `timeout` and the request options
45 |  * @param initialResult An already fetched state of the prediction; if given, no initial request is made
46 |  * @returns A new `PredictionState`. It has a status of either "succeeded", "failed" or "canceled".
47 |  * @throws `Error("Prediction timed out")` if the prediction is not finished before the timeout
48 |  */
49 | export const pollPrediction = async (options: PollPredictionOptions, initialResult?: PredictionState) => {
50 |   let prediction = initialResult || (await getPrediction({ ...options, id: options.id }))
51 |
52 |   const endAt = Date.now() + (options.timeout ?? 3600000)
53 |
54 |   // The status is checked before the deadline, so a result fetched right at the deadline is not discarded.
55 |   while (!isFinished(prediction)) {
56 |     const remainingMillis = endAt - Date.now()
57 |     if (remainingMillis <= 0) {
58 |       throw new Error("Prediction timed out")
59 |     }
60 |
61 |     // A caller-supplied `initialResult` is refreshed immediately; every later request waits first.
62 |     if (prediction !== initialResult) {
63 |       const elapsedTime = prediction.startedAt ? Date.now() - prediction.startedAt.getTime() : 0
64 |       // Never sleep past the deadline.
65 |       await sleep(Math.min(getSleepDuration(elapsedTime), remainingMillis))
66 |     }
67 |     prediction = await getPrediction({ ...options, id: options.id })
68 |   }
69 |   return prediction
70 | }
71 |
--------------------------------------------------------------------------------
/src/predict.ts:
--------------------------------------------------------------------------------
1 | import { getModel } from "getModel"
2 | import { convertPrediction, PredictionResponse } from "helpers/convertPrediction"
3 | import { makeApiRequest, ReplicateRequestOptions } from "helpers/makeApiRequest"
4 |
5 | /** Option for the model name; e.g. `stability-ai/stable-diffusion` */
6 | export type ModelNameOptions = {
7 |   /** The name of the model; e.g. `stability-ai/stable-diffusion` */
8 |   model: string
9 | }
10 |
11 | /** Option for the model version */
12 | export type ModelVersionOptions = {
13 |   /** The ID of the model version that you want to run. */
14 |   version: string
15 | }
16 |
17 | /** Select which events trigger webhook request
18 |  *
19 |  * - `"start"`: immediately on prediction start
20 |  * - `"output"`: each time a prediction generates an output (note that predictions can generate multiple outputs)
21 |  * - `"logs"`: each time log output is generated by a prediction
22 |  * - `"completed"`: when the prediction reaches a terminal state (succeeded/canceled/failed)
23 |  *
24 |  * See https://replicate.com/docs/reference/http#create-prediction--webhook_events_filter for more information.
25 |  */
26 | export type WebhookEventType = "start" | "output" | "logs" | "completed"
27 |
28 | /** Options for creating a new prediction */
29 | export type PredictOptions = {
30 |   /** The model's input as a JSON object. This differs for each model */
31 |   input: Record<string, unknown>
32 |   /** Set to true to poll until the prediction is completed */
33 |   poll?: boolean
34 |   /** A URL that replicate calls for the events selected in `webhookEvents` (by default only when the prediction has completed). */
35 |   webhook?: string
36 |   /** Select which events trigger webhook request
37 |    *
38 |    * @default ["completed"]
39 |    */
40 |   webhookEvents?: WebhookEventType[]
41 | } & (ModelVersionOptions | ModelNameOptions) &
42 |   ReplicateRequestOptions
43 |
44 | /** Create a new prediction
45 |  *
46 |  * ```typescript
47 |  * const result = await predict({
48 |  *   model: "stability-ai/stable-diffusion",
49 |  *   input: { prompt: "multicolor hyperspace" },
50 |  *   token: "Get your token at https://replicate.com/account",
51 |  *   poll: true,
52 |  * })
53 |  * ```
54 |  *
55 |  * Then you can check `result.status` to see if it's `"starting"`, `"processing"` or `"succeeded"`. If it's `"succeeded"` you can get the outputs with `result.output`. If not you can check back later with `getPrediction` and the id from result (`result.id`).
56 |  *
57 |  * If you set the `poll` option this function will return a promise that waits until the prediction is completed.
58 |  *
59 |  * If only `model` is given, its version is first resolved with `getModel`.
60 |  */
61 | export const predict = async (options: PredictOptions) => {
62 |   const version = "version" in options ? options.version : (await getModel(options)).version
63 |   const response = await makeApiRequest<PredictionResponse>(options, "POST", "predictions", {
64 |     version: version,
65 |     input: options.input,
66 |     webhook: options.webhook,
67 |     // The filter is only sent together with a webhook; `JSON.stringify` omits it when it is undefined.
68 |     webhook_events_filter: options.webhook ? options.webhookEvents ?? ["completed"] : undefined,
69 |   })
70 |
71 |   const prediction = convertPrediction(options, response)
72 |   return options.poll ? await prediction.poll() : prediction
73 | }
74 |
--------------------------------------------------------------------------------
/src/processWebhook.ts:
--------------------------------------------------------------------------------
1 | import { convertPrediction, PredictionResponse } from "helpers/convertPrediction"
2 | import { ReplicateRequestOptions } from "helpers/makeApiRequest"
3 |
4 | export type ProcessWebhookOptions = {
5 |   /** The webhook body as an object */
6 |   body: unknown
7 | } & ReplicateRequestOptions
8 |
9 | /** Convert the body of a replicate callback to a `PredictionState`.
10 |  *
11 |  * When creating a prediction you can set a URL that will be called by replicate (by default once the prediction is completed). This function takes the body of that request and converts it to a `PredictionState`.
12 |  *
13 |  * Only `id` and `created_at` are checked to be strings; the rest of the body is assumed to match `PredictionResponse`.
14 |  *
15 |  * @throws If `body` is not an object with string `id` and `created_at` properties
16 |  */
17 | export const processWebhook = (options: ProcessWebhookOptions) => {
18 |   const { body } = options
19 |
20 |   if (
21 |     typeof body !== "object" ||
22 |     !body ||
23 |     typeof (body as Record<string, unknown>)["id"] !== "string" ||
24 |     typeof (body as Record<string, unknown>)["created_at"] !== "string"
25 |   ) {
26 |     throw new Error("You need to pass a valid PredictionResponse")
27 |   }
28 |
29 |   return convertPrediction(options, body as PredictionResponse)
30 | }
31 |
--------------------------------------------------------------------------------
/src/tests/predict.test.ts:
--------------------------------------------------------------------------------
1 | import { cancelPrediction } from "cancelPrediction"
2 | import { log } from "console"
3 | import { getModel } from "getModel"
4 | import { getPrediction } from "getPrediction"
5 | import { loadFile } from "helpers/loadFile"
6 | import { listPredictions } from "listPredictions"
7 | import { listVersions } from "listVersions"
8 | import { pollPrediction } from "pollPrediction"
9 | import { predict } from "predict"
10 | import { token } from "tests/token"
11 |
12 | test("Call predict", async () => {
13 |   const prediction = predict({
14 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
15 |     token,
16 |     input: {
17 |       prompt: "The quick brown fox jumps over the lazy dog",
18 |     },
19 |   })
20 |   await expect(prediction).resolves.toBeTruthy()
21 |   const result = await prediction
22 |   expect(result.id).toBeTruthy()
23 | })
24 |
25 | test("Polling a prediction works", async () => {
26 |   const prediction = await predict({
27 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
28 |     token,
29 |     input: {
30 |       prompt: "The quick brown fox jumps over the lazy dog",
31 |     },
32 |   })
33 |   expect(prediction.status).not.toBe("succeeded")
34 |   const pollResult = await prediction.poll()
35 |   expect(pollResult.status).toBe("succeeded")
36 | }, 20000)
37 |
38 | test("Polling with the poll function works", async () => {
39 |   const prediction = await predict({
40 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
41 |     token,
42 |     input: {
43 |       prompt: "The quick brown fox jumps over the lazy dog",
44 |     },
45 |   })
46 |   expect(prediction.status).not.toBe("succeeded")
47 |   const pollR = await pollPrediction({ token, id: prediction.id })
48 |   expect(pollR.status).toBe("succeeded")
49 | }, 20000)
50 |
51 | test("Polling with the poll option", async () => {
52 |   const prediction = await predict({
53 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
54 |     token,
55 |     input: {
56 |       prompt: "The quick brown fox jumps over the lazy dog",
57 |     },
58 |     poll: true,
59 |   })
60 |   expect(prediction.status).toBe("succeeded")
61 | }, 20000)
62 |
63 | test("Call predict with a model id instead of a version", async () => {
64 |   const prediction = predict({
65 |     model: "stability-ai/stable-diffusion",
66 |     token,
67 |     input: {
68 |       prompt: "The quick brown fox jumps over the lazy dog",
69 |     },
70 |   })
71 |   await expect(prediction).resolves.toBeTruthy()
72 |   const result = await prediction
73 |   expect(result.id).toBeTruthy()
74 | })
75 |
76 | test("Can retrieve an existing prediction", async () => {
77 |   const prediction = await getPrediction({
78 |     id: "uhi3fggr6fgzbnjl5ccbzp3tme",
79 |     token,
80 |   })
81 |   expect(prediction.version).toBe("a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef")
82 |   expect(prediction.status).toBe("succeeded")
83 |   log(prediction)
84 | })
85 |
86 | test("Canceling an existing prediction works", async () => {
87 |   const prediction = await cancelPrediction({
88 |     id: "uhi3fggr6fgzbnjl5ccbzp3tme",
89 |     token,
90 |   })
91 |   expect(["succeeded", "canceled", "failed"]).toContain(prediction.status)
92 | })
93 |
94 | test("Canceling a running prediction works", async () => {
95 |   const prediction = await predict({
96 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
97 |     token,
98 |     input: {
99 |       prompt: "The quick brown fox jumps over the lazy dog",
100 |     },
101 |   })
102 |   const canceledPrediction = await cancelPrediction({ id: prediction.id, token })
103 |   expect(canceledPrediction.status).toBe("canceled")
104 | }, 20000)
105 |
106 | test("Fails to cancel a nonexistent prediction", async () => {
107 |   const prediction = cancelPrediction({
108 |     id: "uhi3fggr6fgzbnjl5ccbzpaaaa",
109 |     token,
110 |   })
111 |   await expect(prediction).rejects.toBeTruthy()
112 | })
113 |
114 | test("Fails with invalid token", async () => {
115 |   const prediction = predict({
116 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
117 |     token: "sjsakjfsladjf",
118 |     input: {},
119 |   })
120 |   await expect(prediction).rejects.toBeTruthy()
121 | })
122 |
123 | test("Resolving a model works", async () => {
124 |   const prediction = await getModel({
125 |     model: "stability-ai/stable-diffusion",
126 |     token,
127 |   })
128 |   expect(prediction.version).toBeDefined()
129 | })
130 |
131 | test("Resolving model versions works", async () => {
132 |   const prediction = await listVersions({
133 |     model: "stability-ai/stable-diffusion",
134 |     token,
135 |   })
136 |   expect(prediction.versions.map(version => version.id)).toContain(
137 |     "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef"
138 |   )
139 | })
140 |
141 | // These tests require a token that has more than 200 past predictions
142 | test("Listing past predictions returns 100 results", async () => {
143 |   const { predictions } = await listPredictions({ token })
144 |   expect(predictions.length).toBe(100)
145 | })
146 |
147 | test("Listing all past predictions returns a lot of results", async () => {
148 |   const { predictions: allPastPredictions } = await listPredictions({ token, all: true })
149 |   expect(allPastPredictions.length).toBeGreaterThan(200)
150 | }, 120000)
151 |
152 | test("Listing the next predictions returns different results than the first call", async () => {
153 |   const { predictions, next } = await listPredictions({ token })
154 |   expect(predictions.length).toBe(100)
155 |   const { predictions: nextPredictions } = await next()
156 |   expect(nextPredictions.length).toBe(100)
157 |   expect(nextPredictions[0]?.id).not.toBe(predictions[0]?.id)
158 | })
159 |
160 | test("Using a model with a file input works", async () => {
161 |   const prediction = predict({
162 |     model: "openai/whisper",
163 |     token,
164 |     input: {
165 |       audio: await loadFile("./testaudio.mp3"),
166 |       model: "base",
167 |     },
168 |     poll: true,
169 |   })
170 |   const result = await prediction
171 |   const { transcription } = (result.output ?? {}) as { transcription?: unknown }
172 |   expect(transcription).toBeTruthy()
173 |   if (typeof transcription !== "string") {
174 |     throw new Error("Transcription is not a string")
175 |   }
176 |
177 |   expect(transcription).toContain(
178 |     "This is the Cal NEH American English Dialect Recordings Collection, produced with funding from the National Endowment for the Humanities and the Center for Applied Linguistics"
179 |   )
180 | }, 240000)
181 |
182 | test("Calling predict with a webhook does not fail", async () => {
183 |   const prediction = predict({
184 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
185 |     token,
186 |     input: {
187 |       prompt: "The quick brown fox jumps over the lazy dog",
188 |     },
189 |     webhook: "https://google.com",
190 |   })
191 |   await expect(prediction).resolves.toBeTruthy()
192 |   const awaitedPrediction = await prediction
193 |   await expect(awaitedPrediction.cancel()).resolves.toBeTruthy()
194 | })
195 |
196 | test("Calling predict with a webhook and events does not fail", async () => {
197 |   const prediction = predict({
198 |     version: "a9758cbfbd5f3c2094457d996681af52552901775aa2d6dd0b17fd15df959bef",
199 |     token,
200 |     input: {
201 |       prompt: "The quick brown fox jumps over the lazy dog",
202 |     },
203 |     webhook: "https://google.com",
204 |     webhookEvents: ["completed", "start"],
205 |   })
206 |   await expect(prediction).resolves.toBeTruthy()
207 |   const awaitedPrediction = await prediction
208 |   await expect(awaitedPrediction.cancel()).resolves.toBeTruthy()
209 | })
210 |
--------------------------------------------------------------------------------
/src/tests/upload.test.ts:
--------------------------------------------------------------------------------
1 | // eslint-disable-next-line import/no-deprecated
2 | import { getCsrfToken, getFileUrls, uploadFile } from "helpers/uploadFile"
3 |
4 | test("Should be able to obtain CSRF token", async () => {
5 |   const csrfTokenRequest = await fetch(
6 |     "https://replicate.com/openai/whisper/versions/23241e5731b44fcb5de68da8ebddae1ad97c5094d24f94ccb11f7c1d33d661e2",
7 |     { method: "GET" }
8 |   )
9 |   const setCookieHeader = csrfTokenRequest.headers.get("set-cookie")
10 |   expect(setCookieHeader).toBeTruthy()
11 |   if (!setCookieHeader) {
12 |     throw new Error("No set-cookie header")
13 |   }
14 |   const csrfToken = setCookieHeader.match(/csrftoken=([^;]+)/)?.[1]
15 |   expect(csrfToken).toBeTruthy()
16 | })
17 |
18 | test("Function for getting csrf token seems to work", async () => {
19 |   const csrfToken = await getCsrfToken()
20 |
21 |   const uploadRequest = await fetch("https://replicate.com/api/upload/hai.wav?content_type=audio%2Fwav", {
22 |     method: "POST",
23 |     headers: {
24 |       "cookie": `csrftoken=${csrfToken};`,
25 |       "origin": "https://replicate.com",
26 |       "x-csrftoken": csrfToken,
27 |     },
28 |   })
29 |
30 |   const uploadResponse = await uploadRequest.json()
31 |   expect(uploadResponse).toBeTruthy()
32 | })
33 |
34 | test("Function for getting upload URLs works", async () => {
35 |   const urls = await getFileUrls("./testaudio.mp3")
36 |   expect(urls.servingUrl).toBeTruthy()
37 |   expect(urls.uploadUrl).toBeTruthy()
38 | })
39 |
40 | test("Function for uploading files works", async () => {
41 |   // eslint-disable-next-line import/no-deprecated
42 |   const servingUrl = await uploadFile("./testaudio.mp3")
43 |   expect(servingUrl).toBeTruthy()
44 | })
45 |
--------------------------------------------------------------------------------
/tsconfig.build.json:
--------------------------------------------------------------------------------
1 | {
2 |   "extends": "./tsconfig.json",
3 |   "compilerOptions": {
"compilerOptions": { 4 | "noEmit": false, 5 | "sourceMap": false, 6 | "declarationMap": false 7 | }, 8 | "exclude": ["src/tests", "src/**/*.test.ts"] 9 | } 10 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2022", 4 | "lib": ["ES2022", "DOM"], 5 | "module": "ES2022", 6 | "moduleResolution": "Node", 7 | "declaration": true, 8 | "declarationMap": true, 9 | "sourceMap": true, 10 | "outDir": "dist", 11 | "noEmit": true, 12 | "isolatedModules": true, 13 | "esModuleInterop": true, 14 | "forceConsistentCasingInFileNames": true, 15 | "strict": true, 16 | "skipLibCheck": true, 17 | "allowUnusedLabels": false, 18 | "allowUnreachableCode": false, 19 | "noFallthroughCasesInSwitch": true, 20 | "noPropertyAccessFromIndexSignature": true, 21 | "noUncheckedIndexedAccess": true, 22 | "resolveJsonModule": true, 23 | "rootDir": "src", 24 | "baseUrl": ".", 25 | "paths": { 26 | "*": ["./src/*"] 27 | } 28 | }, 29 | "include": ["src/**/*"], 30 | "exclude": ["node_modules"] 31 | } 32 | --------------------------------------------------------------------------------