├── .eslintignore ├── .eslintrc ├── .github └── workflows │ ├── ci.yml │ └── deploy.yml ├── .gitignore ├── LICENSE ├── README.MD ├── RELEASE.MD ├── build.config.js ├── jest.config.cjs ├── package-lock.json ├── package.json ├── pnpm-lock.yaml ├── src ├── __mocks__ │ └── fs.ts ├── clients │ ├── deep-infra.ts │ └── index.ts ├── index.ts └── lib │ ├── constants │ └── client.ts │ ├── models │ └── base │ │ ├── automatic-speech-recognition.ts │ │ ├── base-model.ts │ │ ├── cog-model.ts │ │ ├── custom-model.ts │ │ ├── embeddings.ts │ │ ├── fill-mask.ts │ │ ├── image-base-model.ts │ │ ├── image-classification.ts │ │ ├── index.ts │ │ ├── object-detection.ts │ │ ├── question-answering.ts │ │ ├── sdxl.ts │ │ ├── text-classification.ts │ │ ├── text-generation.ts │ │ ├── text-to-image.ts │ │ ├── token-classification.ts │ │ └── zero-shot-image-classification.ts │ ├── types │ ├── automatic-speech-recognition │ │ ├── request.ts │ │ └── response.ts │ ├── cog │ │ ├── request.ts │ │ ├── response.ts │ │ └── sdxl │ │ │ ├── request.ts │ │ │ └── response.ts │ ├── common │ │ ├── client-config.ts │ │ ├── image-item.ts │ │ ├── image-request.ts │ │ ├── single-text-input-request.ts │ │ └── status.ts │ ├── embeddings │ │ ├── request.ts │ │ └── response.ts │ ├── fill-mask │ │ ├── request.ts │ │ └── response.ts │ ├── image-classification │ │ ├── request.ts │ │ └── response.ts │ ├── object-detection │ │ ├── request.ts │ │ └── response.ts │ ├── questions-answering │ │ ├── request.ts │ │ └── response.ts │ ├── text-classification │ │ ├── request.ts │ │ └── response.ts │ ├── text-generation │ │ ├── request.ts │ │ └── response.ts │ ├── text-to-image │ │ ├── request.ts │ │ └── response.ts │ ├── token-classification │ │ ├── request.ts │ │ └── response.ts │ └── zero-shot-image-classification │ │ ├── request.ts │ │ └── response.ts │ └── utils │ ├── form-data.ts │ ├── read-stream.ts │ └── url.ts ├── test └── base │ ├── automatic-speech-recognition.spec.ts │ ├── embeddings.spec.ts │ ├── fill-mask.spec.ts │ 
├── image-classification.spec.ts │ ├── object-detection.spec.ts │ ├── question-answering.spec.ts │ ├── sdxl.spec.ts │ ├── text-classification.spec.ts │ ├── text-generation.spec.ts │ ├── text-to-image.spec.ts │ ├── token-classification.spec.ts │ └── zero-shot-image-classification.spec.ts └── tsconfig.json /.eslintignore: -------------------------------------------------------------------------------- 1 | **/misc.ts 2 | -------------------------------------------------------------------------------- /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "es6": true, 5 | "node": true 6 | }, 7 | "extends": [ 8 | "eslint:recommended", 9 | "plugin:@typescript-eslint/recommended", 10 | "plugin:unicorn/recommended" 11 | ], 12 | "globals": { 13 | "Atomics": "readonly", 14 | "SharedArrayBuffer": "readonly" 15 | }, 16 | "parserOptions": { 17 | "ecmaVersion": 2020, 18 | "sourceType": "module" 19 | }, 20 | "plugins": [ 21 | "unicorn" 22 | ], 23 | "rules": { 24 | "unicorn/prevent-abbreviations": "off", 25 | "unicorn/filename-case": [ 26 | "error", 27 | { 28 | "case": "kebabCase" 29 | } 30 | ], 31 | "unicorn/prefer-query-selector": "error", 32 | "unicorn/no-null": "off", 33 | "indent": [ 34 | "error", 35 | 2 36 | ], 37 | "linebreak-style": [ 38 | "error", 39 | "unix" 40 | ], 41 | "quotes": [ 42 | "error", 43 | "single" 44 | ], 45 | "semi": [ 46 | "error", 47 | "always" 48 | ] 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Run Unit Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' 7 | pull_request: 8 | branches: 9 | - '**' 10 | 11 | 12 | 13 | jobs: 14 | unit-tests: 15 | if: github.event.pull_request.draft == false 16 | runs-on: ubuntu-latest 17 | 18 | strategy: 19 | matrix: 20 | node-version: ['16.x', '18.x', '20.x'] 21 | 22 | 
steps: 23 | - uses: actions/checkout@v2 24 | 25 | - uses: actions/setup-node@v2 26 | with: 27 | node-version: ${{ matrix.node-version }} 28 | registry-url: 'https://registry.npmjs.org' 29 | 30 | - run: npm install 31 | - run: npm run test 32 | - run: npm run build 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Publish NPM Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | concurrency: 7 | group: ${{ github.ref }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | publish-npm: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - uses: actions/setup-node@v2 17 | with: 18 | node-version: '18.x' 19 | registry-url: 'https://registry.npmjs.org' 20 | 21 | - run: npm install 22 | - run: npm run test 23 | - run: npm run build 24 | - run: npm publish --access=public 25 | env: 26 | NODE_AUTH_TOKEN: ${{secrets.NPM_TOKEN}} 27 | 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/github/gitignore/blob/main/Node.gitignore 2 | 3 | # Logs 4 | logs 5 | *.log 6 | npm-debug.log* 7 | yarn-debug.log* 8 | yarn-error.log* 9 | lerna-debug.log* 10 | .pnpm-debug.log* 11 | 12 | # Diagnostic reports (https://nodejs.org/api/report.html) 13 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 14 | 15 | # Runtime data 16 | pids 17 | *.pid 18 | *.seed 19 | *.pid.lock 20 | 21 | # Directory for instrumented libs generated by jscoverage/JSCover 22 | lib-cov 23 | 24 | # Coverage directory used by tools like istanbul 25 | coverage 26 | *.lcov 27 | 28 | # nyc test coverage 29 | .nyc_output 30 | 31 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 32 | .grunt 33 | 34 | # Bower dependency directory (https://bower.io/) 
35 | bower_components 36 | 37 | # node-waf configuration 38 | .lock-wscript 39 | 40 | # Compiled binary addons (https://nodejs.org/api/addons.html) 41 | build/Release 42 | 43 | # Dependency directories 44 | node_modules/ 45 | jspm_packages/ 46 | 47 | # Snowpack dependency directory (https://snowpack.dev/) 48 | web_modules/ 49 | 50 | # TypeScript cache 51 | *.tsbuildinfo 52 | 53 | # Optional npm cache directory 54 | .npm 55 | 56 | # Optional eslint cache 57 | .eslintcache 58 | 59 | # Optional stylelint cache 60 | .stylelintcache 61 | 62 | # Microbundle cache 63 | .rpt2_cache/ 64 | .rts2_cache_cjs/ 65 | .rts2_cache_es/ 66 | .rts2_cache_umd/ 67 | 68 | # Optional REPL history 69 | .node_repl_history 70 | 71 | # Output of 'npm pack' 72 | *.tgz 73 | 74 | # Yarn Integrity file 75 | .yarn-integrity 76 | 77 | # dotenv environment variable files 78 | .env 79 | .env.development.local 80 | .env.test.local 81 | .env.production.local 82 | .env.local 83 | 84 | # parcel-bundler cache (https://parceljs.org/) 85 | .cache 86 | .parcel-cache 87 | 88 | # Next.js build output 89 | .next 90 | out 91 | 92 | # Nuxt.js build / generate output 93 | .nuxt 94 | dist 95 | 96 | # Gatsby files 97 | .cache/ 98 | # Comment in the public line in if your project uses Gatsby and not Next.js 99 | # https://nextjs.org/blog/next-9-1#public-directory-support 100 | # public 101 | 102 | # vuepress build output 103 | .vuepress/dist 104 | 105 | # vuepress v2.x temp and cache directory 106 | .temp 107 | .cache 108 | 109 | # Docusaurus cache and generated files 110 | .docusaurus 111 | 112 | # Serverless directories 113 | .serverless/ 114 | 115 | # FuseBox cache 116 | .fusebox/ 117 | 118 | # DynamoDB Local files 119 | .dynamodb/ 120 | 121 | # TernJS port file 122 | .tern-port 123 | 124 | # Stores VSCode versions used for testing VSCode extensions 125 | .vscode-test 126 | 127 | # yarn v2 128 | .yarn/cache 129 | .yarn/unplugged 130 | .yarn/build-state.yml 131 | .yarn/install-state.gz 132 | .pnp.* 133 | 134 | # IDE 
files 135 | .idea 136 | *.suo 137 | *.ntvs* 138 | .vscode 139 | *.code-workspace 140 | *.njsproj 141 | *.sln 142 | 143 | docs 144 | dist 145 | **misc.ts 146 | docs 147 | .husky 148 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # DeepInfra Node API Library 2 | ![npm](https://img.shields.io/npm/v/deepinfra) 3 | ![npm](https://img.shields.io/npm/dt/deepinfra)
4 | 5 | 6 | This library provides a simple way to interact with the DeepInfra API. 7 |
8 |
9 | Check out our docs [here](https://deepinfra.github.io/deepinfra-node/). 10 | 11 | ## Installation 12 | 13 | ```bash 14 | npm install deepinfra 15 | ``` 16 | 17 | ## Usage 18 | 19 | ### Use [text generation models](https://deepinfra.com/models/text-generation) 20 | 21 | Mixtral, developed by Mistral AI, is an innovative, experimental sparse mixture-of-experts (MoE) language model that 22 | combines 8 expert networks per layer; the example below uses the 8x22B instruct variant. Its initial release was distributed via a torrent, and the model's 23 | implementation remains in the experimental phase. 24 | 25 | ```typescript 26 | import {TextGeneration} from "deepinfra"; 27 | 28 | const modelName = "mistralai/Mixtral-8x22B-Instruct-v0.1"; 29 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 30 | const main = async () => { 31 | const mixtral = new TextGeneration(modelName, apiKey); 32 | const body = { 33 | input: "What is the capital of France?", 34 | }; 35 | const output = await mixtral.generate(body); 36 | const text = output.results[0].generated_text; 37 | console.log(text); 38 | }; 39 | 40 | main(); 41 | ``` 42 | 43 | ### Use [text embedding models](https://deepinfra.com/models/embeddings) 44 | 45 | GTE Base is a text embedding model, trained by Alibaba DAMO Academy, that generates embeddings for the input text.
46 | 47 | ```typescript 48 | import { Embeddings } from "deepinfra"; 49 | 50 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 51 | const modelName = "thenlper/gte-base"; 52 | const main = async () => { 53 | const gteBase = new Embeddings(modelName, apiKey); 54 | const body = { 55 | inputs: [ 56 | "What is the capital of France?", 57 | "What is the capital of Germany?", 58 | "What is the capital of Italy?", 59 | ], 60 | }; 61 | const output = await gteBase.generate(body); 62 | const embeddings = output.embeddings[0]; 63 | console.log(embeddings); 64 | }; 65 | 66 | main(); 67 | ``` 68 | 69 | ### Use [SDXL](https://deepinfra.com/stability-ai/sdxl) to generate images 70 | 71 | SDXL takes model-specific parameters and is initialized differently: its model name is fixed, so the constructor takes only the API key. 72 | 73 | ```typescript 74 | import { Sdxl } from "deepinfra"; 75 | import axios from "axios"; 76 | import fs from "fs"; 77 | 78 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 79 | 80 | const main = async () => { 81 | const model = new Sdxl(apiKey); 82 | 83 | const input = { 84 | prompt: "The quick brown fox jumps over the lazy dog with", 85 | }; 86 | const response = await model.generate({ input }); 87 | const { output } = response; 88 | const image = output[0]; 89 | 90 | await axios.get(image, { responseType: "arraybuffer" }).then((response) => { 91 | fs.writeFileSync("image.png", response.data); 92 | }); 93 | }; 94 | 95 | main(); 96 | ``` 97 | 98 | ### Use [other text to image models](https://deepinfra.com/models/text-to-image) 99 | 100 | ```typescript 101 | import { TextToImage } from "deepinfra"; 102 | import axios from "axios"; 103 | import fs from "fs"; 104 | 105 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 106 | const modelName = "stabilityai/stable-diffusion-2-1"; 107 | const main = async () => { 108 | const model = new TextToImage(modelName, apiKey); 109 | const input = { 110 | prompt: "The quick brown fox jumps over the lazy dog with", 111 | }; 112 | 113 | const response = await model.generate(input); 114 | const {
output } = response; 115 | const image = output[0]; 116 | 117 | await axios.get(image, { responseType: "arraybuffer" }).then((response) => { 118 | fs.writeFileSync("image.png", response.data); 119 | }); 120 | }; 121 | main(); 122 | ``` 123 | 124 | ### Use [automatic speech recognition models](https://deepinfra.com/models/automatic-speech-recognition) 125 | 126 | ```typescript 127 | import { AutomaticSpeechRecognition } from "deepinfra"; 128 | 129 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 130 | const modelName = "openai/whisper-base"; 131 | 132 | const main = async () => { 133 | const model = new AutomaticSpeechRecognition(modelName, apiKey); 134 | 135 | const input = { 136 | audio: path.join(__dirname, "audio.mp3"), 137 | }; 138 | const response = await model.generate(input); 139 | const { text } = response; 140 | console.log(text); 141 | }; 142 | 143 | main(); 144 | ``` 145 | 146 | ### Use [object detection models](https://deepinfra.com/models/object-detection) 147 | 148 | ```typescript 149 | import { ObjectDetection } from "deepinfra"; 150 | 151 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 152 | const modelName = "hustvl/yolos-tiny"; 153 | const main = async () => { 154 | const model = new ObjectDetection(modelName, apiKey); 155 | 156 | const input = { 157 | image: path.join(__dirname, "image.jpg"), 158 | }; 159 | const response = await model.generate(input); 160 | const { results } = response; 161 | console.log(results); 162 | }; 163 | ``` 164 | 165 | ### Use [token classification models](https://deepinfra.com/models/token-classification) 166 | 167 | ```typescript 168 | import { TokenClassification } from "deepinfra"; 169 | 170 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 171 | const modelName = "Davlan/bert-base-multilingual-cased-ner-hrl"; 172 | 173 | const main = async () => { 174 | const model = new TokenClassification(modelName, apiKey); 175 | 176 | const input = { 177 | text: "The quick brown fox jumps over the lazy dog", 178 | }; 179 | const response = await 
model.generate(input); 180 | const { results } = response; 181 | console.log(results); 182 | }; 183 | ``` 184 | 185 | ### Use [fill mask models](https://deepinfra.com/models/fill-mask) 186 | 187 | ```typescript 188 | import { FillMask } from "deepinfra"; 189 | 190 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 191 | const modelName = "GroNLP/bert-base-dutch-cased"; 192 | 193 | const main = async () => { 194 | const model = new FillMask(modelName, apiKey); 195 | 196 | const body = { 197 | input: "Ik heb een [MASK] gekocht.", 198 | }; 199 | 200 | const { results } = await model.generate(body); 201 | console.log(results); 202 | }; 203 | ``` 204 | 205 | ### Use [image classification models](https://deepinfra.com/models/image-classification) 206 | 207 | ```typescript 208 | import { ImageClassification } from "deepinfra"; 209 | 210 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 211 | const modelName = "google/vit-base-patch16-224"; 212 | 213 | const main = async () => { 214 | const model = new ImageClassification(modelName, apiKey); 215 | 216 | const body = { 217 | image: path.join(__dirname, "image.jpg"), 218 | }; 219 | 220 | const { results } = await model.generate(body); 221 | console.log(results); 222 | }; 223 | ``` 224 | 225 | ### Use [zero-shot image classification models](https://deepinfra.com/models/zero-shot-image-classification) 226 | 227 | ```typescript 228 | import { ZeroShotImageClassification } from "deepinfra"; 229 | 230 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 231 | const modelName = "openai/clip-vit-base-patch32"; 232 | 233 | const main = async () => { 234 | const model = new ZeroShotImageClassification(modelName, apiKey); 235 | 236 | const body = { 237 | image: path.join(__dirname, "image.jpg"), 238 | candidate_labels: ["dog", "cat", "car"], 239 | }; 240 | 241 | const { results } = await model.generate(body); 242 | console.log(results); 243 | }; 244 | ``` 245 | 246 | ### Use [text classification models](https://deepinfra.com/models/text-classification) 247 | 
248 | ```typescript 249 | import { TextClassification } from "deepinfra"; 250 | 251 | const apiKey = "YOUR_DEEPINFRA_API_KEY"; 252 | const modelName = "ProsusAI/finbert"; 253 | 254 | const main = async () => { 255 | const model = new TextClassification(modelName, apiKey); 256 | 257 | const body = { 258 | input: 259 | "DeepInfra emerges from stealth with $8M to make running AI inferences more affordable", 260 | }; 261 | 262 | const { results } = await model.generate(body); 263 | console.log(results); 264 | }; 265 | ``` 266 | 267 | ## Contributors 268 | 269 | [Oguz Vuruskaner](https://github.com/ovuruska) 270 | 271 | [Iskren Ivov Chernev](https://github.com/ichernev) 272 | 273 | ## Contributing 274 | 275 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduct and the process for submitting pull 276 | requests to us. 277 | 278 | ## License 279 | 280 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 281 | -------------------------------------------------------------------------------- /RELEASE.MD: -------------------------------------------------------------------------------- 1 | # Release 2 | This document describes the release process for the `deepinfra` package. 3 | 4 | 5 | ## Checklist 6 | Before releasing, complete the following steps: 7 | 8 | - [ ] Update the `package.json` version number. You can use `npm version --allow-same-version --no-git-tag-version version_number` to update the version number. 9 | - [ ] Create a new release on GitHub with the same version number.
10 | 11 | 12 | 13 | ## Authors 14 | - [Oguz Vuruskaner](https://github.com/ovuruska) 15 | - [Iskren Ivov Chernev](https://github.com/ichernev) 16 | 17 | -------------------------------------------------------------------------------- /build.config.js: -------------------------------------------------------------------------------- 1 | import { defineBuildConfig } from 'unbuild'; 2 | 3 | export default defineBuildConfig({ 4 | entries: ['./src/index'], 5 | outDir: 'dist', 6 | declaration: true, 7 | clean: true, 8 | failOnWarn: true, 9 | rollup: { 10 | emitCJS: true, 11 | esbuild: { 12 | minify: process.env.NODE_ENV === 'production' 13 | }, 14 | }, 15 | }); 16 | -------------------------------------------------------------------------------- /jest.config.cjs: -------------------------------------------------------------------------------- 1 | /** @type {import('ts-jest').JstConfigWithTsJest} */ 2 | module.exports = { 3 | testEnvironment: 'node', 4 | extensionsToTreatAsEsm: [".ts"], 5 | testTimeout: 10000, 6 | coveragePathIgnorePatterns: [ 7 | "/node_modules/", 8 | "/examples/", 9 | "/test/" 10 | ], 11 | moduleNameMapper: { 12 | "^@/(.*)$": "/src/$1" 13 | }, 14 | testMatch: [ 15 | "/test/**/*.spec.ts" 16 | ], 17 | rootDir: ".", 18 | transform: { 19 | "^.+\\.tsx?$": "ts-jest" 20 | }, 21 | 22 | }; 23 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "deepinfra", 3 | "version": "2.0.2", 4 | "description": "Official API wrapper for DeepInfra", 5 | "main": "dist/index.js", 6 | "types": "dist/index.d.ts", 7 | "files": [ 8 | "dist" 9 | ], 10 | "repository": { 11 | "type": "git", 12 | "url": "git+https://github.com/deepinfra/deepinfra-node.git" 13 | }, 14 | "scripts": { 15 | "build": "tsc && tsc-alias -p tsconfig.json", 16 | "misc": "npx ts-node -r tsconfig-paths/register src/misc.ts", 17 | "prepare": "husky", 18 | "test": "jest 
--passWithNoTests", 19 | "lint": "eslint . --ext .ts --fix", 20 | "prettier": "prettier --write ./src ./test && prettier --write README.MD", 21 | "build-docs": "typedoc --out docs src", 22 | "predeploy-docs": "npm run build-docs", 23 | "deploy-docs": "npx gh-pages -d docs" 24 | }, 25 | "config": { 26 | "commitizen": { 27 | "path": "./node_modules/cz-conventional-changelog" 28 | } 29 | }, 30 | "keywords": [ 31 | "llm", 32 | "deepinfra", 33 | "api", 34 | "wrapper", 35 | "typesript", 36 | "large language model", 37 | "deep learning", 38 | "machine learning", 39 | "artificial intelligence", 40 | "ai" 41 | ], 42 | "author": "Oguz Vuruskaner (https://www.oguzvuruskaner.com)", 43 | "license": "MIT", 44 | "dependencies": { 45 | "@swc/core": "^1.4.6", 46 | "@swc/wasm": "^1.4.6", 47 | "axios": "^1.6.7", 48 | "form-data": "^4.0.0" 49 | }, 50 | "devDependencies": { 51 | "@types/jest": "^29.5.12", 52 | "@types/node": "^20.11.26", 53 | "@typescript-eslint/eslint-plugin": "^7.2.0", 54 | "@typescript-eslint/parser": "^7.2.0", 55 | "cz-conventional-changelog": "^3.3.0", 56 | "dotenv": "^16.4.5", 57 | "eslint": "^8.57.0", 58 | "eslint-plugin-unicorn": "^51.0.1", 59 | "husky": "^9.0.11", 60 | "jest": "^29.7.0", 61 | "prettier": "^3.2.5", 62 | "ts-jest": "^29.1.2", 63 | "ts-node": "^10.9.2", 64 | "tsc-alias": "^1.8.8", 65 | "tsconfig-paths": "^4.2.0", 66 | "typedoc": "^0.25.12", 67 | "typescript": "^5.4.2" 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/__mocks__/fs.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * This file is used in unit tests to mock the fs module. 
3 | * 4 | */ 5 | 6 | import * as originalFs from "fs"; 7 | 8 | type MockFiles = Record; 9 | 10 | interface CustomFs 11 | extends Omit { 12 | __setMockFiles: (newMockFiles: MockFiles) => void; 13 | readFileSync: (filePath: string, options?: object) => string; 14 | createReadStream: (filePath: string, options?: object) => any; 15 | } 16 | 17 | const fs: CustomFs = jest.createMockFromModule("fs"); 18 | jest.mock("node:fs", () => fs); 19 | 20 | let mockFiles: MockFiles = Object.create(null); 21 | 22 | function __setMockFiles(newMockFiles: MockFiles): void { 23 | mockFiles = newMockFiles; 24 | } 25 | 26 | function readFileSync(filePath: string): string { 27 | return mockFiles[filePath] || ""; 28 | } 29 | 30 | function createReadStream(filePath: string): any { 31 | return filePath; 32 | } 33 | 34 | fs.__setMockFiles = __setMockFiles; 35 | fs.readFileSync = readFileSync; 36 | fs.createReadStream = createReadStream; 37 | 38 | module.exports = fs; 39 | -------------------------------------------------------------------------------- /src/clients/deep-infra.ts: -------------------------------------------------------------------------------- 1 | import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse } from "axios"; 2 | import { USER_AGENT } from "@/lib/constants/client"; 3 | import { ClientConfig, IClientConfig } from "@/lib/types/common/client-config"; 4 | 5 | /** HTTP client bound to a single inference URL: adds auth and User-Agent headers to every POST and retries failed requests with a backoff delay. */ export class DeepInfraClient { 6 | private axiosClient: AxiosInstance; 7 | private readonly clientConfig: ClientConfig; 8 | 9 | constructor( 10 | private readonly url: string, 11 | private readonly authToken: string, 12 | config?: Partial, 13 | ) { 14 | this.axiosClient = axios.create({ 15 | baseURL: this.url, 16 | }); 17 | this.clientConfig = new ClientConfig(config); 18 | } 19 | 20 | /** Sleeps before the next retry. `attempt` is the 0-based index of the attempt that just failed (post passes its loop index), so the first retry waits initialBackoff and every later retry waits subsequentBackoff. */ private async backoffDelay(attempt: number): Promise { 21 | const delay = 22 | attempt === 0 23 | ?
this.clientConfig.initialBackoff 24 | : this.clientConfig.subsequentBackoff; 25 | return new Promise((resolve) => setTimeout(resolve, delay)); 26 | } 27 | 28 | /** POSTs `data` to this client's URL with a JSON content-type, the library User-Agent and a Bearer token (caller headers may override content-type but not User-Agent or Authorization). Any error thrown by axios triggers a retry, up to clientConfig.maxRetries retries; the last error is rethrown. */ public async post( 29 | data: object, 30 | config?: AxiosRequestConfig, 31 | ): Promise> { 32 | const headers = { 33 | "content-type": "application/json", 34 | ...config?.headers, 35 | "User-Agent": USER_AGENT, 36 | Authorization: `Bearer ${this.authToken}`, 37 | }; 38 | 39 | for (let attempt = 0; attempt <= this.clientConfig.maxRetries; attempt++) { 40 | try { 41 | return await this.axiosClient.post(this.url, data, { 42 | ...config, 43 | headers, 44 | }); 45 | } catch (error) { 46 | if (attempt < this.clientConfig.maxRetries) { 47 | await this.backoffDelay(attempt); 48 | } else { 49 | throw error; 50 | } 51 | } 52 | } 53 | /* Only reachable if maxRetries is negative (the loop body never runs). */ throw new Error("Maximum retries exceeded"); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/clients/index.ts: -------------------------------------------------------------------------------- 1 | export { DeepInfraClient } from "@/clients/deep-infra"; 2 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | export * from "@/lib/models/base"; 2 | -------------------------------------------------------------------------------- /src/lib/constants/client.ts: -------------------------------------------------------------------------------- 1 | export const MAX_RETRIES = 5; 2 | export const INITIAL_BACKOFF = 5000; 3 | export const SUBSEQUENT_BACKOFF = 2000; 4 | export const USER_AGENT = "DeepInfra TypeScript API Client"; 5 | 6 | export const ROOT_URL = "https://api.deepinfra.com/v1/inference/"; 7 | -------------------------------------------------------------------------------- /src/lib/models/base/automatic-speech-recognition.ts: -------------------------------------------------------------------------------- 1 |
import { AutomaticSpeechRecognitionRequest } from "@/lib/types/automatic-speech-recognition/request"; 2 | import { BaseModel } from "@/lib/models/base/base-model"; 3 | import { IClientConfig } from "@/lib/types/common/client-config"; 4 | import { AutomaticSpeechRecognitionResponse } from "@/lib/types/automatic-speech-recognition/response"; 5 | import { FormDataUtils } from "@/lib/utils/form-data"; 6 | 7 | export class AutomaticSpeechRecognition extends BaseModel { 8 | constructor( 9 | modelName: string, 10 | authToken?: string, 11 | config?: Partial, 12 | ) { 13 | super(modelName, authToken, config); 14 | } 15 | 16 | async generate( 17 | body: AutomaticSpeechRecognitionRequest, 18 | ): Promise { 19 | const formData = 20 | await FormDataUtils.prepareFormData( 21 | body, 22 | ["audio"], 23 | ); 24 | const response = await this.client.post( 25 | formData, 26 | { 27 | headers: { 28 | ...formData.getHeaders(), 29 | }, 30 | }, 31 | ); 32 | return response.data; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/lib/models/base/base-model.ts: -------------------------------------------------------------------------------- 1 | import { DeepInfraClient } from "@/clients"; 2 | import { IClientConfig } from "@/lib/types/common/client-config"; 3 | 4 | import { ROOT_URL } from "@/lib/constants/client"; 5 | import { URLUtils } from "@/lib/utils/url"; 6 | 7 | /** Common base for all models: resolves `modelName` to an inference URL (used as-is when it is already a valid URL, otherwise appended to ROOT_URL), takes the auth token from the argument or the DEEPINFRA_API_KEY environment variable (warning and using "" when neither is set), and builds the DeepInfraClient used by subclasses. */ export class BaseModel { 8 | protected client: DeepInfraClient; 9 | protected readonly endpoint: string; 10 | protected authToken: string; 11 | constructor( 12 | readonly modelName: string, 13 | authToken?: string, 14 | config?: Partial, 15 | ) { 16 | this.endpoint = URLUtils.isValidUrl(modelName) 17 | ?
modelName 18 | : ROOT_URL + modelName; 19 | this.authToken = 20 | authToken || this.getAuthTokenFromEnv() || this.warnAboutMissingApiKey(); 21 | this.client = new DeepInfraClient(this.endpoint, this.authToken, config); 22 | } 23 | 24 | private warnAboutMissingApiKey() { 25 | console.warn( 26 | "API key is not provided. Please provide an API key as an argument or set DEEPINFRA_API_KEY environment variable.", 27 | ); 28 | return ""; 29 | } 30 | 31 | private getAuthTokenFromEnv() { 32 | const apiKey = process.env.DEEPINFRA_API_KEY; 33 | return apiKey || ""; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/lib/models/base/cog-model.ts: -------------------------------------------------------------------------------- 1 | import { AxiosError } from "axios"; 2 | import { BaseModel } from "@/lib/models/base"; 3 | import { CogResponse } from "@/lib/types/cog/response"; 4 | import { CogRequest } from "@/lib/types/cog/request"; 5 | import { IClientConfig } from "@/lib/types/common/client-config"; 6 | 7 | /** Base for Cog-style models (e.g. Sdxl). The first constructor argument is a plain parameter forwarded to BaseModel, which resolves it into the inference URL stored in `endpoint`; it must not be a `protected endpoint` parameter property, because that would overwrite the resolved URL with the raw argument once super() returns. */ export class CogBaseModel extends BaseModel { 8 | constructor( 9 | modelName: string, 10 | authToken?: string, 11 | config?: Partial, 12 | ) { 13 | super(modelName, authToken, config); 14 | } 15 | 16 | /** POSTs `body` and returns the response payload; on failure the error details are logged and a generic Error is thrown. */ public async generate( 17 | body: CogRequest, 18 | ): Promise> { 19 | try { 20 | const response = await this.client.post>(body); 21 | return response.data; 22 | } catch (error) { 23 | this.handleError(error); 24 | throw new Error("Failed to generate text"); 25 | } 26 | } 27 | 28 | private handleError(error: unknown) { 29 | if (error instanceof AxiosError) { 30 | if (error.response) { 31 | console.error("Response data:", error.response.data); 32 | console.error("Status:", error.response.status); 33 | console.error("Headers:", error.response.headers); 34 | } else if (error.request) { 35 | console.error("No response received:", error.request); 36 | } else { 37 | console.error("Error message:", error.message); 38 | } 39
| } else { 40 | console.error("An unexpected error occurred:", error); 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/lib/models/base/custom-model.ts: -------------------------------------------------------------------------------- 1 | import { BaseModel } from "@/lib/models/base/base-model"; 2 | import { IClientConfig } from "@/lib/types/common/client-config"; 3 | 4 | /** Generic JSON model: generate() POSTs the request body unchanged and returns the response's `data`. */ export abstract class CustomModel< 5 | Request extends object, 6 | Response extends object, 7 | > extends BaseModel { 8 | protected constructor( 9 | modelName: string, 10 | authToken?: string, 11 | config?: Partial, 12 | ) { 13 | super(modelName, authToken, config); 14 | } 15 | 16 | public async generate(body: Request): Promise { 17 | const response = await this.client.post(body); 18 | return response.data; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/lib/models/base/embeddings.ts: -------------------------------------------------------------------------------- 1 | import { BaseModel } from "@/lib/models/base"; 2 | import { EmbeddingsRequest } from "@/lib/types/embeddings/request"; 3 | import { EmbeddingsResponse } from "@/lib/types/embeddings/response"; 4 | import { IClientConfig } from "@/lib/types/common/client-config"; 5 | 6 | /** Embeddings model: generate() POSTs the request and returns the response payload, logging and rethrowing any error. */ export class Embeddings extends BaseModel { 7 | constructor( 8 | modelName: string, 9 | authToken?: string, 10 | config?: Partial, 11 | ) { 12 | super(modelName, authToken, config); 13 | } 14 | 15 | public async generate(body: EmbeddingsRequest): Promise { 16 | try { 17 | const response = await this.client.post(body); 18 | const { data, status } = response; 19 | /* NOTE(review): axios rejects non-2xx statuses by default, so this presumably only fires for other 2xx codes (e.g. 201/204) — confirm that is intended. */ if (status !== 200) { 20 | throw new Error(`HTTP error!
status: ${status}`); 21 | } 22 | return data; 23 | } catch (error) { 24 | console.error("Error generating embeddings:", error); 25 | throw error; 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/lib/models/base/fill-mask.ts: -------------------------------------------------------------------------------- 1 | import { CustomModel } from "@/lib/models/base/custom-model"; 2 | import { FillMaskRequest } from "@/lib/types/token-classification/request"; 3 | import { TokenClassificationResponse } from "@/lib/types/token-classification/response"; 4 | import { IClientConfig } from "@/lib/types/common/client-config"; 5 | 6 | /* NOTE(review): the request/response types come from the token-classification types even though src/lib/types/fill-mask/ exists — confirm this is intended. */ export class FillMask extends CustomModel< 7 | FillMaskRequest, 8 | TokenClassificationResponse 9 | > { 10 | constructor( 11 | modelName: string, 12 | authToken?: string, 13 | config?: Partial, 14 | ) { 15 | super(modelName, authToken, config); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/lib/models/base/image-base-model.ts: -------------------------------------------------------------------------------- 1 | import { BaseModel } from "@/lib/models/base/base-model"; 2 | import { IClientConfig } from "@/lib/types/common/client-config"; 3 | import { ImageRequest } from "@/lib/types/common/image-request"; 4 | import { FormDataUtils } from "@/lib/utils/form-data"; 5 | 6 | /** Base for image models: generate() converts `body` to multipart form data via FormDataUtils.prepareFormData ("image" is passed as the special field — presumably read as a file; confirm) and POSTs it with the form's own headers. */ export class ImageBaseModel< 7 | RequestType extends ImageRequest, 8 | ResponseType, 9 | > extends BaseModel { 10 | constructor( 11 | modelName: string, 12 | authToken?: string, 13 | config?: Partial, 14 | ) { 15 | super(modelName, authToken, config); 16 | } 17 | 18 | async generate(body: RequestType): Promise { 19 | const formData = await FormDataUtils.prepareFormData(body, [ 20 | "image", 21 | ]); 22 | const response = await this.client.post(formData, { 23 | headers: { 24 | ...formData.getHeaders(), 25 | }, 26 | }); 27 | return response.data; 28 | } 29 | } 30 |
--------------------------------------------------------------------------------
/src/lib/models/base/image-classification.ts:
--------------------------------------------------------------------------------
import { ImageClassificationRequest } from "@/lib/types/image-classification/request";
import { ImageClassificationResponse } from "@/lib/types/image-classification/response";
import { ImageBaseModel } from "@/lib/models/base/image-base-model";

/**
 * Image-classification model. All behavior (multipart upload of `image`,
 * returning the response body) is inherited from ImageBaseModel.
 */
export class ImageClassification extends ImageBaseModel<
  ImageClassificationRequest,
  ImageClassificationResponse
> {}

--------------------------------------------------------------------------------
/src/lib/models/base/index.ts:
--------------------------------------------------------------------------------
// Barrel re-exporting every base model class.
// NOTE(review): keep BaseModel (and CustomModel) exported before the concrete
// models. A model module that imports BaseModel through this barrel
// (`import { BaseModel } from "@/lib/models/base"`) forms an import cycle with
// it, and reordering these lines could leave BaseModel undefined when such a
// class is evaluated. Verify the import graph before reordering.
export { BaseModel } from "@/lib/models/base/base-model";
export { CustomModel } from "@/lib/models/base/custom-model";
export { TextToImage } from "@/lib/models/base/text-to-image";
export { TextGeneration } from "@/lib/models/base/text-generation";
export { Embeddings } from "@/lib/models/base/embeddings";
export { ObjectDetection } from "@/lib/models/base/object-detection";
export { AutomaticSpeechRecognition } from "@/lib/models/base/automatic-speech-recognition";
export { TokenClassification } from "@/lib/models/base/token-classification";
export { FillMask } from "@/lib/models/base/fill-mask";
export { TextClassification } from "@/lib/models/base/text-classification";
export { QuestionAnswering } from "@/lib/models/base/question-answering";
export { Sdxl } from "@/lib/models/base/sdxl";
export { ImageClassification } from "@/lib/models/base/image-classification";
export { ZeroShotImageClassification } from "@/lib/models/base/zero-shot-image-classification";

--------------------------------------------------------------------------------
/src/lib/models/base/object-detection.ts:
--------------------------------------------------------------------------------
import { ImageBaseModel } from "@/lib/models/base/image-base-model";
import type { ObjectDetectionRequest } from "@/lib/types/object-detection/request";
import type { ObjectDetectionResponse } from "@/lib/types/object-detection/response";

/**
 * Object-detection model. Uploads `image` through ImageBaseModel and
 * resolves to the detected items (label, score and bounding box).
 */
export class ObjectDetection extends ImageBaseModel<
  ObjectDetectionRequest,
  ObjectDetectionResponse
> {}

--------------------------------------------------------------------------------
/src/lib/models/base/question-answering.ts:
--------------------------------------------------------------------------------
import { CustomModel } from "@/lib/models/base/custom-model";
import type { IClientConfig } from "@/lib/types/common/client-config";
import type { QuestionAnsweringRequest } from "@/lib/types/questions-answering/request";
import type { QuestionAnsweringResponse } from "@/lib/types/questions-answering/response";

/**
 * Question-answering model: posts `{ question, context }` and resolves to the
 * answer text with its score and `start`/`end` offsets.
 */
export class QuestionAnswering extends CustomModel<
  QuestionAnsweringRequest,
  QuestionAnsweringResponse
> {
  // CustomModel's constructor is protected; this public one is what allows
  // `new QuestionAnswering(...)`.
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }
}

--------------------------------------------------------------------------------
/src/lib/models/base/sdxl.ts:
--------------------------------------------------------------------------------
import { CogBaseModel } from "@/lib/models/base/cog-model";
import type { IClientConfig } from "@/lib/types/common/client-config";
import type { SdxlIn } from "@/lib/types/cog/sdxl/request";
import type { SdxlOut } from "@/lib/types/cog/sdxl/response";

/**
 * Stable Diffusion XL via the Cog-style base model. The model name is fixed,
 * so callers supply only credentials and client settings.
 */
export class Sdxl extends CogBaseModel {
  /** Model identifier shared by every Sdxl instance. */
  public static readonly modelName: string = "stability-ai/sdxl";

  constructor(authToken?: string, config?: Partial<IClientConfig>) {
    super(Sdxl.modelName, authToken, config);
  }
}

--------------------------------------------------------------------------------
/src/lib/models/base/text-classification.ts:
--------------------------------------------------------------------------------
import { TextClassificationRequest } from "@/lib/types/text-classification/request";
import { TextClassificationResponse } from "@/lib/types/text-classification/response";
import { CustomModel } from "@/lib/models/base/custom-model";
import { IClientConfig } from "@/lib/types/common/client-config";

/** Text-classification model: posts a single text `input`, resolves to scored labels. */
export class TextClassification extends CustomModel<
  TextClassificationRequest,
  TextClassificationResponse
> {
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }
}

--------------------------------------------------------------------------------
/src/lib/models/base/text-generation.ts:
--------------------------------------------------------------------------------
import { AxiosError } from "axios";
// Direct import (not the "@/lib/models/base" barrel, which re-exports this
// file) to avoid an import cycle.
import { BaseModel } from "@/lib/models/base/base-model";
import { TextGenerationResponse } from "@/lib/types/text-generation/response";
import { TextGenerationRequest } from "@/lib/types/text-generation/request";
import { IClientConfig } from "@/lib/types/common/client-config";

/** Text-generation model. */
export class TextGeneration extends BaseModel {
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Posts `body` to the model endpoint and resolves with the response body.
   *
   * @throws {Error} "Failed to generate text" on any failure. Details are
   *   logged first, and the original error is attached as `cause`.
   */
  public async generate(
    body: TextGenerationRequest,
  ): Promise<TextGenerationResponse> {
    try {
      const response = await this.client.post(body);
      return response.data;
    } catch (error) {
      this.handleError(error);
      // Keep the underlying failure (status, response body, network error)
      // reachable by callers instead of discarding it.
      throw Object.assign(new Error("Failed to generate text"), {
        cause: error,
      });
    }
  }

  /** Logs whatever detail is available about a failed request. */
  private handleError(error: unknown) {
    if (error instanceof AxiosError) {
      if (error.response) {
        console.error("Response data:", error.response.data);
        console.error("Status:", error.response.status);
        console.error("Headers:", error.response.headers);
      } else if (error.request) {
        console.error("No response received:", error.request);
      } else {
        console.error("Error message:", error.message);
      }
    } else {
      console.error("An unexpected error occurred:", error);
    }
  }
}

--------------------------------------------------------------------------------
/src/lib/models/base/text-to-image.ts:
--------------------------------------------------------------------------------
// Direct import (not the "@/lib/models/base" barrel, which re-exports this
// file) to avoid an import cycle.
import { BaseModel } from "@/lib/models/base/base-model";
import { AxiosResponse } from "axios";
import { TextToImageResponse } from "@/lib/types/text-to-image/response";
import { TextToImageRequest } from "@/lib/types/text-to-image/request";
import { IClientConfig } from "@/lib/types/common/client-config";

/** Text-to-image model. */
export class TextToImage extends BaseModel {
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Posts `body` to the model endpoint and resolves with the response body.
   * Failures are logged and rethrown unchanged.
   */
  public async generate(
    body: TextToImageRequest,
  ): Promise<TextToImageResponse> {
    try {
      const response: AxiosResponse<TextToImageResponse> =
        await this.client.post(body);
      return response.data;
    } catch (error) {
      console.error("Error generating image:", error);
      throw error;
    }
  }
}

--------------------------------------------------------------------------------
/src/lib/models/base/token-classification.ts:
--------------------------------------------------------------------------------
import { CustomModel } from "@/lib/models/base/custom-model";
// The token-classification request type is exported under the misleading
// name `FillMaskRequest`; alias it locally so this class reads correctly.
import { FillMaskRequest as TokenClassificationRequest } from "@/lib/types/token-classification/request";
import { TokenClassificationResponse } from "@/lib/types/token-classification/response";
import { IClientConfig } from "@/lib/types/common/client-config";

/** Token-classification model: posts a single text `input`, resolves to labelled spans. */
export class TokenClassification extends CustomModel<
  TokenClassificationRequest,
  TokenClassificationResponse
> {
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }
}

--------------------------------------------------------------------------------
/src/lib/models/base/zero-shot-image-classification.ts:
--------------------------------------------------------------------------------
import { ZeroShotImageClassificationRequest } from "@/lib/types/zero-shot-image-classification/request";
import { ZeroShotImageClassificationResponse } from "@/lib/types/zero-shot-image-classification/response";
import { ImageBaseModel } from "@/lib/models/base/image-base-model";

/**
 * Zero-shot image classification: uploads `image` and scores it against the
 * caller-supplied `candidate_labels`.
 */
export class ZeroShotImageClassification extends ImageBaseModel<
  ZeroShotImageClassificationRequest,
  ZeroShotImageClassificationResponse
> {}

--------------------------------------------------------------------------------
/src/lib/types/automatic-speech-recognition/request.ts:
--------------------------------------------------------------------------------
import { ReadStreamInput } from "@/lib/utils/read-stream";

/** Request for speech recognition; `audio` accepts any ReadStreamInput. */
export interface AutomaticSpeechRecognitionRequest {
  audio: ReadStreamInput;
  task?: "transcribe" | "translate";
  language?: string;
  temperature?: number;
  patience?: number;
  suppress_tokens?: string;
  initial_prompt?: string;
  condition_on_previous_text?: boolean;
  temperature_increment_on_fallback?: number;
  compression_ratio_threshold?: number;
  logprob_threshold?: number;
  no_speech_threshold?: number;
  webhook?: string;
}

--------------------------------------------------------------------------------
/src/lib/types/automatic-speech-recognition/response.ts:
--------------------------------------------------------------------------------
import { Status } from "@/lib/types/common/status";

interface AutomaticSpeechRecognitionWord {
  text: string;
  start: number;
  end: number;
  confidence: number;
}

interface AutomaticSpeechRecognitionSegment {
  id: number;
  seek: number;
  start: number;
  end: number;
  text: string;
  tokens: number[];
  temperature?: number;
  avg_logprob?: number;
  compression_ratio?: number;
  no_speech_prob?: number;
  confidence?: number;
  words?: AutomaticSpeechRecognitionWord[];
}

/** Transcription result: full text plus per-segment (and optional per-word) detail. */
export interface AutomaticSpeechRecognitionResponse {
  text: string;
  segments: AutomaticSpeechRecognitionSegment[];
  language: string;
  input_length_ms?: number;
  request_id?: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/cog/request.ts:
--------------------------------------------------------------------------------
export type WebhookEventType = "start" | "output" | "logs" | "completed";

/** Envelope for Cog-style requests; the model-specific payload goes in `input`. */
export interface CogRequest {
  id?: string;
  version?: string | null;
  input: Req;
  webhook?: string | null;
  webhook_events_filter?: WebhookEventType[];
}

--------------------------------------------------------------------------------
/src/lib/types/cog/response.ts:
--------------------------------------------------------------------------------
import { CogRequest } from "./request";

export type CogStatus = "succeeded" | "failed";

/** Cog-style prediction response: the request envelope plus status, timing and `output`. */
export interface CogResponse extends CogRequest {
  status: CogStatus;
  created_at: string | null;
  started_at: string;
  /** Completion timestamp (Cog prediction responses name this `completed_at`). */
  completed_at: string;
  /**
   * @deprecated Misspelling of `completed_at`, kept only so existing code that
   * references it still compiles; the service field is `completed_at`, so
   * expect this to be undefined.
   */
  complated_at?: string;
  output: Res;
  error: string | null;
  logs: string;
  inference_status: {
    status: CogStatus;
    runtime_ms: number;
    cost: number;
    tokens_generated: number | null;
    tokens_input: number | null;
  };
  metrics: {
    predict_time: number;
  };
}

--------------------------------------------------------------------------------
/src/lib/types/cog/sdxl/request.ts:
--------------------------------------------------------------------------------
/** Input parameters for the SDXL model (the `input` of its Cog request). */
export interface SdxlIn {
  prompt: string;
  negative_prompt?: string;
  /** Presumably a base64-encoded image (original note: "? base64"); confirm. */
  image?: string;
  /** Presumably a base64-encoded mask image; confirm. */
  mask?: string;
  width?: number;
  height?: number;
  num_outputs?: number;
  scheduler?:
    | "DDIM"
    | "DPMSolverMultistep"
    | "HeunDiscrete"
    | "KarrasDPM"
    | "K_EULER_ANCESTRAL"
    | "K_EULER"
    | "PNDM";
  num_inference_steps?: number;
  guidance_scale?: number;
  prompt_strength?: number;
  seed?: number;
  refine?: "no_refiner" | "expert_ensemble_refiner" | "base_image_refiner";
  high_noise_frac?: number;
  refine_steps?: number;
  apply_watermark?: boolean;
}

--------------------------------------------------------------------------------
/src/lib/types/cog/sdxl/response.ts:
--------------------------------------------------------------------------------
/** SDXL output: a list of strings (presumably one per generated image; confirm format). */
export type SdxlOut = string[];

--------------------------------------------------------------------------------
/src/lib/types/common/client-config.ts:
--------------------------------------------------------------------------------
import {
  INITIAL_BACKOFF,
  MAX_RETRIES,
  SUBSEQUENT_BACKOFF,
} from "@/lib/constants/client";

/** Retry/backoff settings; model constructors accept a Partial of this. */
export interface IClientConfig {
  maxRetries: number;
  initialBackoff: number;
  subsequentBackoff: number;
}

/**
 * Concrete IClientConfig. Any field the caller leaves unset (or sets to
 * null/undefined) falls back to the default from "@/lib/constants/client".
 */
export class ClientConfig implements IClientConfig {
  maxRetries: number;
  initialBackoff: number;
  subsequentBackoff: number;

  constructor(config?: Partial<IClientConfig>) {
    const overrides: Partial<IClientConfig> = config ?? {};
    this.maxRetries = overrides.maxRetries ?? MAX_RETRIES;
    this.initialBackoff = overrides.initialBackoff ?? INITIAL_BACKOFF;
    this.subsequentBackoff =
      overrides.subsequentBackoff ?? SUBSEQUENT_BACKOFF;
  }
}

--------------------------------------------------------------------------------
/src/lib/types/common/image-item.ts:
--------------------------------------------------------------------------------
/** A label with its score, as returned by image models. */
export interface ImageItem {
  label: string;
  score: number;
}

--------------------------------------------------------------------------------
/src/lib/types/common/image-request.ts:
--------------------------------------------------------------------------------
import { ReadStreamInput } from "@/lib/utils/read-stream";

/** Common request for image models; `image` accepts any ReadStreamInput. */
export interface ImageRequest {
  image: ReadStreamInput;
  webhook?: string;
}

--------------------------------------------------------------------------------
/src/lib/types/common/single-text-input-request.ts:
--------------------------------------------------------------------------------
/** Common request for models that take one text `input`. */
export interface SingleTextInputRequest {
  input: string;
  webhook?: string;
}

--------------------------------------------------------------------------------
/src/lib/types/common/status.ts:
--------------------------------------------------------------------------------
/** Inference status block attached to most responses. */
export interface Status {
  status: "unknown" | "queued" | "running" | "succeeded" | "failed";
  runtime_ms: number;
  cost: number;
  tokens_generated: number;
  tokens_input: number;
}

--------------------------------------------------------------------------------
/src/lib/types/embeddings/request.ts:
--------------------------------------------------------------------------------
/** Embeddings request: the texts to embed plus optional flags. */
export interface EmbeddingsRequest {
  inputs: string[];
  normalize?: boolean;
  // NOTE(review): purpose of `image` on a text-embeddings request is not
  // evident here; confirm against the API.
  image?: string;
  webhook?: string;
}

--------------------------------------------------------------------------------
/src/lib/types/embeddings/response.ts:
--------------------------------------------------------------------------------
import { Status } from "@/lib/types/common/status";

/** Embeddings response: one vector per input text. */
export interface EmbeddingsResponse {
  embeddings: number[][];
  input_tokens: number;
  request_id?: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/fill-mask/request.ts:
--------------------------------------------------------------------------------
import { SingleTextInputRequest } from "@/lib/types/common/single-text-input-request";

/** Fill-mask request: a single text `input` containing the mask token. */
export interface FillMaskRequest extends SingleTextInputRequest {}

--------------------------------------------------------------------------------
/src/lib/types/fill-mask/response.ts:
--------------------------------------------------------------------------------
import { Status } from "@/lib/types/common/status";

/** One candidate filling for the masked token. */
export interface FillMaskItem {
  sequence: string;
  score: number;
  token: number;
  token_str: string;
}

/** Fill-mask response: the scored candidate fillings. */
export interface FillMaskResponse {
  results: FillMaskItem[];
  request_id?: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/image-classification/request.ts:
--------------------------------------------------------------------------------
import { ImageRequest } from "@/lib/types/common/image-request";

/** Image-classification request: just the common image request fields. */
export interface ImageClassificationRequest extends ImageRequest {}

--------------------------------------------------------------------------------
/src/lib/types/image-classification/response.ts:
--------------------------------------------------------------------------------
import { Status } from "@/lib/types/common/status";
import { ImageItem } from "@/lib/types/common/image-item";

/** Image-classification response: scored labels for the image. */
export interface ImageClassificationResponse {
  results: ImageItem[];
  request_id?: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/object-detection/request.ts:
--------------------------------------------------------------------------------
import type { ImageRequest } from "@/lib/types/common/image-request";

/** Object-detection request: just the common image request fields. */
export interface ObjectDetectionRequest extends ImageRequest {}

--------------------------------------------------------------------------------
/src/lib/types/object-detection/response.ts:
--------------------------------------------------------------------------------
import type { Status } from "@/lib/types/common/status";
import type { ImageItem } from "@/lib/types/common/image-item";

/** Bounding box given by its min/max corner coordinates. */
export interface ObjectDetectionBox {
  xmin: number;
  ymin: number;
  xmax: number;
  ymax: number;
}

/** A detected object: label and score plus its bounding box. */
export interface ObjectDetectionItem extends ImageItem {
  box: ObjectDetectionBox;
}

/** Object-detection response: every detected object. */
export interface ObjectDetectionResponse {
  results: ObjectDetectionItem[];
  request_id?: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/questions-answering/request.ts:
--------------------------------------------------------------------------------
/** Question-answering request: a question and the context to answer from. */
export interface QuestionAnsweringRequest {
  question: string;
  context: string;
  webhook?: string;
}

--------------------------------------------------------------------------------
/src/lib/types/questions-answering/response.ts:
--------------------------------------------------------------------------------
import type { Status } from "@/lib/types/common/status";

/** Question-answering response: the answer span, its score and offsets. */
export interface QuestionAnsweringResponse {
  answer: string;
  score: number;
  start: number;
  end: number;
  request_id: string | null;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/text-classification/request.ts:
--------------------------------------------------------------------------------
import type { SingleTextInputRequest } from "@/lib/types/common/single-text-input-request";

/** Text-classification request: a single text `input`. */
export interface TextClassificationRequest extends SingleTextInputRequest {}

--------------------------------------------------------------------------------
/src/lib/types/text-classification/response.ts:
--------------------------------------------------------------------------------
import type { Status } from "@/lib/types/common/status";

/** One predicted label with its score. */
export interface TextClassificationItem {
  label: string;
  score: number;
}

/** Text-classification response: the scored labels. */
export interface TextClassificationResponse {
  results: TextClassificationItem[];
  request_id: string | null;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/text-generation/request.ts:
--------------------------------------------------------------------------------
/** Text-generation request: the prompt `input` plus optional sampling controls. */
export interface TextGenerationRequest {
  input: string;
  stream?: boolean;
  max_new_tokens?: number;
  temperature?: number;
  top_p?: number;
  top_k?: number;
  repetition_penalty?: number;
  stop?: string[];
  num_responses?: number;
  response_format?: {
    type: string;
  };
  presence_penalty?: number;
  frequency_penalty?: number;
  webhook?: string;
}

--------------------------------------------------------------------------------
/src/lib/types/text-generation/response.ts:
--------------------------------------------------------------------------------
import type { Status } from "@/lib/types/common/status";

/** One generated completion. */
export interface GeneratedText {
  generated_text: string;
}

/** Text-generation response: the completions plus token counts. */
export interface TextGenerationResponse {
  request_id: string;
  inference_status: Status;
  results: GeneratedText[];
  num_tokens: number;
  num_input_tokens: number;
}

--------------------------------------------------------------------------------
/src/lib/types/text-to-image/request.ts:
--------------------------------------------------------------------------------
export interface TextToImageRequest {
  prompt: string;
  negative_prompt?: string;
  image?: string;
  num_images?: number;
  num_inference_steps?: number;
  guidance_scale?: number;
  strength?: number;
  width?:
    | 128
    | 256
    | 384
    | 448
    | 512
    | 576
    | 640
    | 704
    | 768
    | 832
    | 896
    | 960
    | 1024;
  height?:
    | 128
    | 256
    | 384
    | 448
    | 512
    | 576
    | 640
    | 704
    | 768
    | 832
    | 896
    | 960
    | 1024;
  seed?: number;
  use_compel?: boolean;
  webhook?: string;
}

--------------------------------------------------------------------------------
/src/lib/types/text-to-image/response.ts:
--------------------------------------------------------------------------------
import { Status } from "@/lib/types/common/status";

export interface TextToImageResponse {
  images: string[];
  nsfw_content_detected: boolean[];
  seed: number;

  request_id: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/token-classification/request.ts:
--------------------------------------------------------------------------------
import { SingleTextInputRequest } from "@/lib/types/common/single-text-input-request";

/** Token-classification request: a single text `input`. */
export interface TokenClassificationRequest extends SingleTextInputRequest {}

/**
 * @deprecated Misnamed: this module describes token-classification requests.
 * Use `TokenClassificationRequest`, or `FillMaskRequest` from
 * "@/lib/types/fill-mask/request" for fill-mask models. Kept for
 * backward compatibility; structurally identical.
 */
export interface FillMaskRequest extends TokenClassificationRequest {}

--------------------------------------------------------------------------------
/src/lib/types/token-classification/response.ts:
--------------------------------------------------------------------------------
import { Status } from "@/lib/types/common/status";

export interface TokenClassificationItem {
  entity_group: string;
  score: number;
  word: string;
  start: number;
  end: number;
}

export interface TokenClassificationResponse {
  results: TokenClassificationItem[];
  request_id?: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/types/zero-shot-image-classification/request.ts:
--------------------------------------------------------------------------------
import { ImageRequest } from "@/lib/types/common/image-request";

export interface ZeroShotImageClassificationRequest extends ImageRequest {
  candidate_labels: string[];
}

--------------------------------------------------------------------------------
/src/lib/types/zero-shot-image-classification/response.ts:
--------------------------------------------------------------------------------
import { Status } from "@/lib/types/common/status";
import { ImageItem } from "@/lib/types/common/image-item";

export interface ZeroShotImageClassificationResponse {
  results: ImageItem[];
  request_id?: string;
  inference_status: Status;
}

--------------------------------------------------------------------------------
/src/lib/utils/form-data.ts:
--------------------------------------------------------------------------------
import FormData from "form-data";
import { ReadStreamInput, ReadStreamUtils } from "@/lib/utils/read-stream";

export const FormDataUtils = {
  /**
   * Builds multipart form data from a request object.
   *
   * Keys listed in `blobKeys` are converted to streams with
   * `ReadStreamUtils.getReadStream`. Every other value is appended
   * JSON-encoded. Keys whose value is `undefined` (unset optional fields)
   * are omitted.
   *
   * @param data - The request object to convert.
   * @param blobKeys - Keys whose values hold binary input (Buffer, URL,
   *   data URL or file path).
   * @returns The populated FormData.
   * @throws {Error} If a binary input cannot be converted to a stream.
   */
  async prepareFormData<T extends object>(
    data: T,
    blobKeys: string[] = [],
  ): Promise<FormData> {
    const formData = new FormData();
    for (const [key, value] of Object.entries(data)) {
      // JSON.stringify(undefined) yields undefined, which form-data cannot
      // append (it throws), so unset optional fields are skipped instead.
      if (value === undefined) {
        continue;
      }
      if (blobKeys.includes(key)) {
        const readStream = await ReadStreamUtils.getReadStream(
          value as ReadStreamInput,
        );
        formData.append(key, readStream);
      } else {
        formData.append(key, JSON.stringify(value));
      }
    }
    return formData;
  },
};

--------------------------------------------------------------------------------
/src/lib/utils/read-stream.ts:
--------------------------------------------------------------------------------
import fs from "node:fs";
import axios from "axios";
import { Readable } from "stream";

/**
 * The input types that can be converted to a Readable stream: a Buffer, or a
 * string holding an http(s) URL, a `data:` URL, or a file path.
 * @alias ReadStreamInput
 */
export type ReadStreamInput = Buffer | string;

// Absolute http:// or https:// URL (scheme matched case-insensitively).
const HTTP_URL_PATTERN = /^https?:\/\//i;

/**
 * Helpers that turn binary inputs into Node.js Readable streams.
 *
 * Methods refer to `ReadStreamUtils` explicitly instead of `this`, so they
 * still work when destructured or passed around as callbacks.
 */
export const ReadStreamUtils = {
  /**
   * Creates a Readable stream from a Buffer.
   * @param buffer - The Buffer to be streamed.
   * @returns A Readable that emits the buffer and then ends.
   */
  bufferToStream(buffer: Buffer): Readable {
    const stream = new Readable();
    stream.push(buffer);
    stream.push(null); // Signifies the end of the stream.
    return stream;
  },

  /**
   * Creates a Readable stream from a file path.
   * @param filePath - The path to the file.
   * @returns A Readable of the file contents.
   */
  fileToStream(filePath: string): Readable {
    return fs.createReadStream(filePath);
  },

  /**
   * Downloads a URL with axios and returns the body as a stream.
   * @param url - The URL to fetch.
   * @returns The response body stream.
   */
  async urlToStream(url: string): Promise<Readable> {
    const response = await axios.get(url, { responseType: "stream" });
    return response.data;
  },

  /**
   * Decodes a Base64 string into a Readable stream.
   * Note: Buffer.from does not validate its input and does not throw on
   * malformed Base64, so bad input yields unexpected bytes, not an error.
   * @param base64 - The Base64 payload (without any `data:` prefix).
   * @returns A Readable of the decoded bytes.
   */
  base64ToStream(base64: string): Readable {
    const buffer = Buffer.from(base64, "base64");
    return ReadStreamUtils.bufferToStream(buffer);
  },

  /**
   * Returns a Readable stream for a Buffer, an http(s) URL, a `data:` URL
   * (payload decoded as Base64), or a file path.
   * @param input - The input to stream.
   * @returns A Readable of the input's bytes.
   * @throws {Error} If the input type is invalid or a `data:` URL has no
   *   payload separator.
   */
  async getReadStream(input: ReadStreamInput): Promise<Readable> {
    if (Buffer.isBuffer(input)) {
      return ReadStreamUtils.bufferToStream(input);
    }
    if (typeof input !== "string") {
      throw new Error("Invalid input type");
    }
    // Require a real scheme: a bare "http" prefix would also match relative
    // file paths such as "httpdocs/image.jpg".
    if (HTTP_URL_PATTERN.test(input)) {
      return ReadStreamUtils.urlToStream(input);
    }
    if (input.startsWith("data:")) {
      // The payload is everything after the first comma.
      const commaIndex = input.indexOf(",");
      if (commaIndex === -1) {
        throw new Error("Invalid data URL: missing ',' before the payload");
      }
      return ReadStreamUtils.base64ToStream(input.slice(commaIndex + 1));
    }
    return ReadStreamUtils.fileToStream(input);
  },
};

--------------------------------------------------------------------------------
/src/lib/utils/url.ts:
--------------------------------------------------------------------------------
export const URLUtils = {
  /** True when `urlString` parses as a URL with an http: or https: scheme. */
  isValidUrl: (urlString: string): boolean => {
    try {
      const url = new URL(urlString);
      return ["http:", "https:"].includes(url.protocol);
    } catch {
      return false;
    }
  },
};

--------------------------------------------------------------------------------
/test/base/automatic-speech-recognition.spec.ts:
--------------------------------------------------------------------------------
const postMock = jest
  .fn()
  .mockResolvedValue({ data: { transcription: "example text" } });

import { ROOT_URL } from "@/lib/constants/client";
import { AutomaticSpeechRecognition } from "@/index";
import FormData from "form-data";

jest.mock("axios", () => {
  const mockAxiosInstance = {
    post: postMock,
    get: jest.fn().mockResolvedValue({ data: {} }),
  };
  return {
    get: jest.fn(() => mockAxiosInstance),
    create: jest.fn(() => mockAxiosInstance),
  };
});

describe("AutomaticSpeechRecognition", () => {
  const modelName = "openai/whisper-large";
  const apiKey = "your-api-key";
  let model: AutomaticSpeechRecognition;

  beforeAll(() => {
    model = new AutomaticSpeechRecognition(modelName, apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(AutomaticSpeechRecognition);
  });

  it("should send a request to correct URL", async () => {
    const response = await model.generate({
      audio: "test/assets/audio.mp3",
    });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      `${ROOT_URL}${modelName}`,
      expect.any(FormData),
      expect.objectContaining({
        headers: expect.objectContaining({
          "content-type": expect.stringMatching(/multipart\/form-data/),
        }),
      }),
    );
  });
});

--------------------------------------------------------------------------------
/test/base/embeddings.spec.ts:
--------------------------------------------------------------------------------
const postMock = jest
  .fn()
  .mockResolvedValue({ status: 200, data: { transcription: "example text" } });
jest.mock("axios", () => {
  const mockAxiosInstance = {
    post: postMock,
  };
  return {
    create: jest.fn(() => mockAxiosInstance),
  };
});
import { ROOT_URL } from "@/lib/constants/client";
import { Embeddings } from "@/index";

describe("Embeddings", () => {
  const modelName = "BAAI/bge-large-en-v1.5";
  const apiKey = "your-api-key";
  let model: Embeddings;

  beforeAll(() => {
    model = new Embeddings(modelName, apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(Embeddings);
  });

  it("should send a request to correct URL", async () => {
    const response = await model.generate({
      inputs: ["Hello world", "Hallo Wereld"],
    });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      `${ROOT_URL}${modelName}`,
      expect.any(Object),
      expect.any(Object),
    );
  });
});

--------------------------------------------------------------------------------
/test/base/fill-mask.spec.ts:
--------------------------------------------------------------------------------
const postMock = jest
  .fn()
  .mockResolvedValue({ status: 200, data: { transcription: "example text" } });
jest.mock("axios", () => {
  const mockAxiosInstance = {
    post: postMock,
  };
  return {
    create: jest.fn(() => mockAxiosInstance),
  };
});
import { ROOT_URL } from "@/lib/constants/client";
import { FillMask } from "@/index";

describe("FillMask", () => {
  const modelName = "GroNLP/bert-base-dutch-cased";
  const apiKey = "your-api-key";
  let model: FillMask;

  beforeAll(() => {
    model = new FillMask(modelName, apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(FillMask);
  });

  it("should send a request to correct URL", async () => {
    const response = await model.generate({
      input: "Hello [MASK]",
    });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      `${ROOT_URL}${modelName}`,
      expect.any(Object),
      expect.any(Object),
    );
  });
});

--------------------------------------------------------------------------------
/test/base/image-classification.spec.ts:
--------------------------------------------------------------------------------
// Mocked axios: create() returns an instance whose `post` is `postMock`, so the
// test can inspect the URL, body and config the SDK passes to axios.
// `postMock` is declared before jest.mock and the imports so that it is
// initialized by the time the mock factory runs (presumably when the SDK
// first requires axios). The resolved payload is a placeholder; the tests
// only check that a response comes back and what `post` received.
const postMock = jest
  .fn()
  .mockResolvedValue({ data: { transcription: "example text" } });
jest.mock("axios", () => {
  const mockAxiosInstance = {
    post: postMock,
  };
  return {
    create: jest.fn(() => mockAxiosInstance),
  };
});
import { ROOT_URL } from "@/lib/constants/client";
import { ImageClassification } from "@/index";
import FormData from "form-data";

describe("ImageClassification", () => {
  const modelName = "google/vit-base-patch16-224";
  const apiKey = "your-api-key";
  let model: ImageClassification;

  beforeAll(() => {
    model = new ImageClassification(modelName, apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(ImageClassification);
  });

  // Image models upload multipart form data, so the body must be a FormData
  // and the headers must carry a multipart content-type.
  it("should send a request to correct URL", async () => {
    const response = await model.generate({
      image: "test/assets/image.jpg",
    });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      `${ROOT_URL}${modelName}`,
      expect.any(FormData),
      expect.objectContaining({
        headers: expect.objectContaining({
          "content-type": expect.stringMatching(/multipart\/form-data/),
        }),
      }),
    );
  });
});

--------------------------------------------------------------------------------
/test/base/object-detection.spec.ts:
--------------------------------------------------------------------------------
// Same mocking pattern as the other specs: `postMock` captures the request.
const postMock = jest
  .fn()
  .mockResolvedValue({ data: { transcription: "example text" } });
jest.mock("axios", () => {
  const mockAxiosInstance = {
    post: postMock,
  };
  return {
    create: jest.fn(() => mockAxiosInstance),
  };
});
import { ROOT_URL } from "@/lib/constants/client";
import { ObjectDetection } from "@/index";
import FormData from "form-data";

describe("ObjectDetection", () => {
  const modelName = "hustvl/yolos-base";
  const apiKey = "your-api-key";
  let model: ObjectDetection;

  beforeAll(() => {
    model = new ObjectDetection(modelName, apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(ObjectDetection);
  });

  // Multipart upload expected, as for every ImageBaseModel subclass.
  it("should send a request to correct URL", async () => {
    const response = await model.generate({
      image: "test/assets/image.jpg",
    });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      `${ROOT_URL}${modelName}`,
      expect.any(FormData),
      expect.objectContaining({
        headers: expect.objectContaining({
          "content-type": expect.stringMatching(/multipart\/form-data/),
        }),
      }),
    );
  });
});

--------------------------------------------------------------------------------
/test/base/question-answering.spec.ts:
--------------------------------------------------------------------------------
// Same mocking pattern as the other specs: `postMock` captures the request.
const postMock = jest
  .fn()
  .mockResolvedValue({ data: { transcription: "example text" } });

jest.mock("axios", () => {
  const mockAxiosInstance = {
    post: postMock,
  };
  return {
    create: jest.fn(() => mockAxiosInstance),
  };
});
import { ROOT_URL } from "@/lib/constants/client";
import { QuestionAnswering } from "@/index";

describe("QuestionAnswering", () => {
  const modelName = "bert-large-uncased-whole-word-masking-finetuned-squad";
  const apiKey = "your-api-key";
  let model: QuestionAnswering;

  beforeAll(() => {
    model = new QuestionAnswering(modelName, apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(QuestionAnswering);
  });

  // JSON-body model: only the URL is checked exactly; body and config are
  // accepted as any object.
  it("should send a request to correct URL", async () => {
    const response = await model.generate({
      question: "What is the capital of France?",
      context: "France is a country in Europe.",
    });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      `${ROOT_URL}${modelName}`,
      expect.any(Object),
      expect.any(Object),
    );
  });
});

--------------------------------------------------------------------------------
/test/base/sdxl.spec.ts:
--------------------------------------------------------------------------------
// Same mocking pattern as the other specs: `postMock` captures the request.
const postMock = jest
  .fn()
  .mockResolvedValue({ data: { transcription: "example text" } });

jest.mock("axios", () => {
  const mockAxiosInstance = {
    post: postMock,
  };
  return {
    create: jest.fn(() => mockAxiosInstance),
  };
});
import { ROOT_URL } from "@/lib/constants/client";
import { Sdxl } from "@/index";

describe("Sdxl", () => {
  // Must match Sdxl.modelName, which the class uses for every instance.
  const modelName = "stability-ai/sdxl";
  const apiKey = "your-api-key";
  let model: Sdxl;

  beforeAll(() => {
    // Sdxl fixes its own model name, so only the API key is passed.
    model = new Sdxl(apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(Sdxl);
  });

  // Cog-style request: model parameters are wrapped in `input`.
  it("should send a request to correct URL", async () => {
    const response = await model.generate({
      input: { prompt: "The quick brown fox jumps over the lazy dog" },
    });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      `${ROOT_URL}${modelName}`,
      expect.any(Object),
      expect.any(Object),
    );
  });
});

-------------------------------------------------------------------------------- /test/base/text-classification.spec.ts: -------------------------------------------------------------------------------- 1 | const postMock = jest 2 | .fn() 3 | .mockResolvedValue({ data: { transcription: "example text" } }); 4 | 5 | jest.mock("axios", () => { 6 | const
mockAxiosInstance = { 7 | post: postMock, 8 | }; 9 | return { 10 | create: jest.fn(() => mockAxiosInstance), 11 | }; 12 | }); 13 | import { ROOT_URL } from "@/lib/constants/client"; 14 | import { TextClassification } from "@/index"; 15 | 16 | describe("TextClassification", () => { 17 | const modelName = "ProsusAI/finbert"; 18 | const apiKey = "your-api-key"; 19 | let model: TextClassification; 20 | 21 | beforeAll(() => { 22 | model = new TextClassification(modelName, apiKey); 23 | }); 24 | 25 | it("should create a new instance", () => { 26 | expect(model).toBeInstanceOf(TextClassification); 27 | }); 28 | 29 | it("should send a request to correct URL", async () => { 30 | const response = await model.generate({ 31 | input: "The quick brown fox jumps over the lazy dog", 32 | }); 33 | 34 | expect(response).toBeDefined(); 35 | expect(postMock).toHaveBeenCalledWith( 36 | `${ROOT_URL}${modelName}`, 37 | expect.any(Object), 38 | expect.any(Object), 39 | ); 40 | }); 41 | }); 42 | -------------------------------------------------------------------------------- /test/base/text-generation.spec.ts: -------------------------------------------------------------------------------- 1 | const postMock = jest 2 | .fn() 3 | .mockResolvedValue({ data: { transcription: "example text" } }); 4 | 5 | jest.mock("axios", () => { 6 | const mockAxiosInstance = { 7 | post: postMock, 8 | }; 9 | return { 10 | create: jest.fn(() => mockAxiosInstance), 11 | }; 12 | }); 13 | import { ROOT_URL } from "@/lib/constants/client"; 14 | import { TextGeneration } from "@/index"; 15 | 16 | describe("TextGeneration", () => { 17 | const modelName = "microsoft/WizardLM-2-8x22B"; 18 | const apiKey = "your-api-key"; 19 | let model: TextGeneration; 20 | 21 | beforeAll(() => { 22 | model = new TextGeneration(modelName, apiKey); 23 | }); 24 | 25 | it("should create a new instance", () => { 26 | expect(model).toBeInstanceOf(TextGeneration); 27 | }); 28 | 29 | it("should send a request to correct URL", async () => { 
30 | const response = await model.generate({ 31 | input: "The quick brown fox jumps over the lazy dog", 32 | }); 33 | 34 | expect(response).toBeDefined(); 35 | expect(postMock).toHaveBeenCalledWith( 36 | `${ROOT_URL}${modelName}`, 37 | expect.any(Object), 38 | expect.any(Object), 39 | ); 40 | }); 41 | }); 42 | -------------------------------------------------------------------------------- /test/base/text-to-image.spec.ts: -------------------------------------------------------------------------------- 1 | const postMock = jest 2 | .fn() 3 | .mockResolvedValue({ data: { transcription: "example text" } }); 4 | 5 | jest.mock("axios", () => { 6 | const mockAxiosInstance = { 7 | post: postMock, 8 | }; 9 | return { 10 | create: jest.fn(() => mockAxiosInstance), 11 | }; 12 | }); 13 | import { ROOT_URL } from "@/lib/constants/client"; 14 | import { TextToImage } from "@/index"; 15 | 16 | describe("TextToImage", () => { 17 | const modelName = "runwayml/stable-diffusion-v1-5"; 18 | const apiKey = "your-api-key"; 19 | let model: TextToImage; 20 | 21 | beforeAll(() => { 22 | model = new TextToImage(modelName, apiKey); 23 | }); 24 | 25 | it("should create a new instance", () => { 26 | expect(model).toBeInstanceOf(TextToImage); 27 | }); 28 | 29 | it("should send a request to correct URL", async () => { 30 | const response = await model.generate({ 31 | prompt: "The quick brown fox jumps over the lazy dog", 32 | }); 33 | 34 | expect(response).toBeDefined(); 35 | expect(postMock).toHaveBeenCalledWith( 36 | `${ROOT_URL}${modelName}`, 37 | expect.any(Object), 38 | expect.any(Object), 39 | ); 40 | }); 41 | }); 42 | -------------------------------------------------------------------------------- /test/base/token-classification.spec.ts: -------------------------------------------------------------------------------- 1 | const postMock = jest 2 | .fn() 3 | .mockResolvedValue({ data: { transcription: "example text" } }); 4 | 5 | jest.mock("axios", () => { 6 | const mockAxiosInstance = { 
7 | post: postMock, 8 | }; 9 | return { 10 | create: jest.fn(() => mockAxiosInstance), 11 | }; 12 | }); 13 | import { ROOT_URL } from "@/lib/constants/client"; 14 | import { TokenClassification } from "@/index"; 15 | 16 | describe("TokenClassification", () => { 17 | const modelName = "Davlan/bert-base-multilingual-cased-ner-hrl"; 18 | const apiKey = "your-api-key"; 19 | let model: TokenClassification; 20 | 21 | beforeAll(() => { 22 | model = new TokenClassification(modelName, apiKey); 23 | }); 24 | 25 | it("should create a new instance", () => { 26 | expect(model).toBeInstanceOf(TokenClassification); 27 | }); 28 | 29 | it("should send a request to correct URL", async () => { 30 | const response = await model.generate({ 31 | input: "The quick brown fox jumps over the lazy dog", 32 | }); 33 | 34 | expect(response).toBeDefined(); 35 | expect(postMock).toHaveBeenCalledWith( 36 | `${ROOT_URL}${modelName}`, 37 | expect.any(Object), 38 | expect.any(Object), 39 | ); 40 | }); 41 | 42 | it("should write to console if DEEPINFRA_API_KEY is not set", () => { 43 | const consoleSpy = jest.spyOn(console, "warn"); 44 | const model = new TokenClassification(modelName); 45 | expect(model).toBeDefined(); 46 | expect(consoleSpy).toHaveBeenCalled(); 47 | }); 48 | 49 | it("should be constructed with an API key", () => { 50 | process.env.DEEPINFRA_API_KEY = apiKey; 51 | const model = new TokenClassification(modelName); 52 | expect(model).toBeDefined(); 53 | process.env.DEEPINFRA_API_KEY = ""; 54 | }); 55 | }); 56 | -------------------------------------------------------------------------------- /test/base/zero-shot-image-classification.spec.ts: -------------------------------------------------------------------------------- 1 | const postMock = jest 2 | .fn() 3 | .mockResolvedValue({ data: { transcription: "example text" } }); 4 | jest.mock("axios", () => { 5 | const mockAxiosInstance = { 6 | post: postMock, 7 | }; 8 | return { 9 | create: jest.fn(() => mockAxiosInstance), 10 | }; 11 | 
}); 12 | import { ROOT_URL } from "@/lib/constants/client"; 13 | import { ZeroShotImageClassification } from "@/index"; 14 | import FormData from "form-data"; 15 | 16 | describe("ZeroShotImageClassification", () => { 17 | const modelName = "openai/clip-vit-base-patch32"; 18 | const apiKey = "your-api-key"; 19 | let model: ZeroShotImageClassification; 20 | 21 | beforeAll(() => { 22 | model = new ZeroShotImageClassification(modelName, apiKey); 23 | }); 24 | 25 | it("should create a new instance", () => { 26 | expect(model).toBeInstanceOf(ZeroShotImageClassification); 27 | }); 28 | 29 | it("should send a request to correct URL", async () => { 30 | const response = await model.generate({ 31 | image: "test/assets/image.jpg", 32 | candidate_labels: ["dog", "cat"], 33 | }); 34 | 35 | expect(response).toBeDefined(); 36 | expect(postMock).toHaveBeenCalledWith( 37 | `${ROOT_URL}${modelName}`, 38 | expect.any(FormData), 39 | expect.objectContaining({ 40 | headers: expect.objectContaining({ 41 | "content-type": expect.stringMatching(/multipart\/form-data/), 42 | }), 43 | }), 44 | ); 45 | }); 46 | }); 47 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "esnext", 4 | "declaration": true, 5 | "module": "commonjs", 6 | "moduleResolution": "node", 7 | "allowSyntheticDefaultImports": true, 8 | "experimentalDecorators": true, 9 | "emitDecoratorMetadata": true, 10 | "allowJs": true, 11 | "esModuleInterop": true, 12 | "forceConsistentCasingInFileNames": true, 13 | "strict": true, 14 | "resolveJsonModule": true, 15 | "skipLibCheck": true, 16 | "outDir": "dist", 17 | "rootDir": "src", 18 | "types": [ 19 | "node", 20 | "jest" 21 | ], 22 | "paths": { 23 | "@/*": [ 24 | "src/*" 25 | ], 26 | }, 27 | "baseUrl": "." 
28 | }, 29 | "include": [ 30 | "src", 31 | "src/**/*.json" 32 | ], 33 | "exclude": [ 34 | "**/misc.ts", 35 | "node_modules", 36 | "dist", 37 | "**/*.spec.ts", 38 | ] 39 | } 40 | --------------------------------------------------------------------------------