├── .eslintignore
├── .eslintrc
├── .github
└── workflows
│ ├── ci.yml
│ └── deploy.yml
├── .gitignore
├── LICENSE
├── README.MD
├── RELEASE.MD
├── build.config.js
├── jest.config.cjs
├── package-lock.json
├── package.json
├── pnpm-lock.yaml
├── src
├── __mocks__
│ └── fs.ts
├── clients
│ ├── deep-infra.ts
│ └── index.ts
├── index.ts
└── lib
│ ├── constants
│ └── client.ts
│ ├── models
│ └── base
│ │ ├── automatic-speech-recognition.ts
│ │ ├── base-model.ts
│ │ ├── cog-model.ts
│ │ ├── custom-model.ts
│ │ ├── embeddings.ts
│ │ ├── fill-mask.ts
│ │ ├── image-base-model.ts
│ │ ├── image-classification.ts
│ │ ├── index.ts
│ │ ├── object-detection.ts
│ │ ├── question-answering.ts
│ │ ├── sdxl.ts
│ │ ├── text-classification.ts
│ │ ├── text-generation.ts
│ │ ├── text-to-image.ts
│ │ ├── token-classification.ts
│ │ └── zero-shot-image-classification.ts
│ ├── types
│ ├── automatic-speech-recognition
│ │ ├── request.ts
│ │ └── response.ts
│ ├── cog
│ │ ├── request.ts
│ │ ├── response.ts
│ │ └── sdxl
│ │ │ ├── request.ts
│ │ │ └── response.ts
│ ├── common
│ │ ├── client-config.ts
│ │ ├── image-item.ts
│ │ ├── image-request.ts
│ │ ├── single-text-input-request.ts
│ │ └── status.ts
│ ├── embeddings
│ │ ├── request.ts
│ │ └── response.ts
│ ├── fill-mask
│ │ ├── request.ts
│ │ └── response.ts
│ ├── image-classification
│ │ ├── request.ts
│ │ └── response.ts
│ ├── object-detection
│ │ ├── request.ts
│ │ └── response.ts
│ ├── questions-answering
│ │ ├── request.ts
│ │ └── response.ts
│ ├── text-classification
│ │ ├── request.ts
│ │ └── response.ts
│ ├── text-generation
│ │ ├── request.ts
│ │ └── response.ts
│ ├── text-to-image
│ │ ├── request.ts
│ │ └── response.ts
│ ├── token-classification
│ │ ├── request.ts
│ │ └── response.ts
│ └── zero-shot-image-classification
│ │ ├── request.ts
│ │ └── response.ts
│ └── utils
│ ├── form-data.ts
│ ├── read-stream.ts
│ └── url.ts
├── test
└── base
│ ├── automatic-speech-recognition.spec.ts
│ ├── embeddings.spec.ts
│ ├── fill-mask.spec.ts
│ ├── image-classification.spec.ts
│ ├── object-detection.spec.ts
│ ├── question-answering.spec.ts
│ ├── sdxl.spec.ts
│ ├── text-classification.spec.ts
│ ├── text-generation.spec.ts
│ ├── text-to-image.spec.ts
│ ├── token-classification.spec.ts
│ └── zero-shot-image-classification.spec.ts
└── tsconfig.json
/.eslintignore:
--------------------------------------------------------------------------------
1 | **/misc.ts
2 |
--------------------------------------------------------------------------------
/.eslintrc:
--------------------------------------------------------------------------------
1 | {
2 | "env": {
3 | "browser": true,
4 | "es6": true,
5 | "node": true
6 | },
7 | "extends": [
8 | "eslint:recommended",
9 | "plugin:@typescript-eslint/recommended",
10 | "plugin:unicorn/recommended"
11 | ],
12 | "globals": {
13 | "Atomics": "readonly",
14 | "SharedArrayBuffer": "readonly"
15 | },
16 | "parserOptions": {
17 | "ecmaVersion": 2020,
18 | "sourceType": "module"
19 | },
20 | "plugins": [
21 | "unicorn"
22 | ],
23 | "rules": {
24 | "unicorn/prevent-abbreviations": "off",
25 | "unicorn/filename-case": [
26 | "error",
27 | {
28 | "case": "kebabCase"
29 | }
30 | ],
31 | "unicorn/prefer-query-selector": "error",
32 | "unicorn/no-null": "off",
33 | "indent": [
34 | "error",
35 | 2
36 | ],
37 | "linebreak-style": [
38 | "error",
39 | "unix"
40 | ],
41 | "quotes": [
42 | "error",
43 | "single"
44 | ],
45 | "semi": [
46 | "error",
47 | "always"
48 | ]
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: Run Unit Tests

on:
  push:
    branches:
      - '**'
  pull_request:
    branches:
      - '**'

jobs:
  unit-tests:
    # Skip draft PRs. On push events there is no pull_request payload; the
    # resulting null compares equal to false in Actions expressions, so
    # pushes still run.
    if: github.event.pull_request.draft == false
    runs-on: ubuntu-latest

    strategy:
      matrix:
        # Node 16 is end-of-life and below the supported engine range of the
        # dev tooling in package.json (husky 9, typescript-eslint 7).
        node-version: ['18.x', '20.x', '22.x']

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-node@v4
        with:
          node-version: ${{ matrix.node-version }}
          registry-url: 'https://registry.npmjs.org'

      - run: npm install
      - run: npm run test
      - run: npm run build
33 |
34 |
--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
name: Publish NPM Package

on:
  release:
    types: [created]
concurrency:
  group: ${{ github.ref }}
  cancel-in-progress: true

jobs:
  publish-npm:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-node@v4
        with:
          node-version: '18.x'
          # registry-url makes setup-node write an .npmrc that reads the
          # token from NODE_AUTH_TOKEN (set on the publish step below).
          registry-url: 'https://registry.npmjs.org'

      - run: npm install
      - run: npm run test
      - run: npm run build
      - run: npm publish --access=public
        env:
          NODE_AUTH_TOKEN: ${{secrets.NPM_TOKEN}}
27 |
28 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Taken from https://github.com/github/gitignore/blob/main/Node.gitignore
2 |
3 | # Logs
4 | logs
5 | *.log
6 | npm-debug.log*
7 | yarn-debug.log*
8 | yarn-error.log*
9 | lerna-debug.log*
10 | .pnpm-debug.log*
11 |
12 | # Diagnostic reports (https://nodejs.org/api/report.html)
13 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
14 |
15 | # Runtime data
16 | pids
17 | *.pid
18 | *.seed
19 | *.pid.lock
20 |
21 | # Directory for instrumented libs generated by jscoverage/JSCover
22 | lib-cov
23 |
24 | # Coverage directory used by tools like istanbul
25 | coverage
26 | *.lcov
27 |
28 | # nyc test coverage
29 | .nyc_output
30 |
31 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
32 | .grunt
33 |
34 | # Bower dependency directory (https://bower.io/)
35 | bower_components
36 |
37 | # node-waf configuration
38 | .lock-wscript
39 |
40 | # Compiled binary addons (https://nodejs.org/api/addons.html)
41 | build/Release
42 |
43 | # Dependency directories
44 | node_modules/
45 | jspm_packages/
46 |
47 | # Snowpack dependency directory (https://snowpack.dev/)
48 | web_modules/
49 |
50 | # TypeScript cache
51 | *.tsbuildinfo
52 |
53 | # Optional npm cache directory
54 | .npm
55 |
56 | # Optional eslint cache
57 | .eslintcache
58 |
59 | # Optional stylelint cache
60 | .stylelintcache
61 |
62 | # Microbundle cache
63 | .rpt2_cache/
64 | .rts2_cache_cjs/
65 | .rts2_cache_es/
66 | .rts2_cache_umd/
67 |
68 | # Optional REPL history
69 | .node_repl_history
70 |
71 | # Output of 'npm pack'
72 | *.tgz
73 |
74 | # Yarn Integrity file
75 | .yarn-integrity
76 |
77 | # dotenv environment variable files
78 | .env
79 | .env.development.local
80 | .env.test.local
81 | .env.production.local
82 | .env.local
83 |
84 | # parcel-bundler cache (https://parceljs.org/)
85 | .cache
86 | .parcel-cache
87 |
88 | # Next.js build output
89 | .next
90 | out
91 |
92 | # Nuxt.js build / generate output
93 | .nuxt
94 | dist
95 |
96 | # Gatsby files
97 | .cache/
98 | # Comment in the public line in if your project uses Gatsby and not Next.js
99 | # https://nextjs.org/blog/next-9-1#public-directory-support
100 | # public
101 |
102 | # vuepress build output
103 | .vuepress/dist
104 |
105 | # vuepress v2.x temp and cache directory
106 | .temp
107 | .cache
108 |
109 | # Docusaurus cache and generated files
110 | .docusaurus
111 |
112 | # Serverless directories
113 | .serverless/
114 |
115 | # FuseBox cache
116 | .fusebox/
117 |
118 | # DynamoDB Local files
119 | .dynamodb/
120 |
121 | # TernJS port file
122 | .tern-port
123 |
124 | # Stores VSCode versions used for testing VSCode extensions
125 | .vscode-test
126 |
127 | # yarn v2
128 | .yarn/cache
129 | .yarn/unplugged
130 | .yarn/build-state.yml
131 | .yarn/install-state.gz
132 | .pnp.*
133 |
134 | # IDE files
135 | .idea
136 | *.suo
137 | *.ntvs*
138 | .vscode
139 | *.code-workspace
140 | *.njsproj
141 | *.sln
142 |
143 | docs
144 | dist
145 | **misc.ts
146 | docs
147 | .husky
148 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2024 The MIT License (MIT)
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # DeepInfra Node API Library
2 | 
3 | 
4 |
5 |
6 | This library provides a simple way to interact with the DeepInfra API.
7 |
8 |
9 | Check out our docs [here.](https://deepinfra.github.io/deepinfra-node/)
10 |
11 | ## Installation
12 |
13 | ```bash
14 | npm install deepinfra
15 | ```
16 |
17 | ## Usage
18 |
19 | ### Use [text generation models](https://deepinfra.com/models/text-generation)
20 |
21 | The Mixtral mixture-of-experts model, developed by Mistral AI, is an innovative experimental machine learning model that
22 | leverages a mixture of 8 experts (MoE); the example below uses the Mixtral-8x22B-Instruct variant. Its release was facilitated via a torrent, and the model's
23 | implementation remains in the experimental phase.
24 |
25 | ```typescript
26 | import {TextGeneration} from "deepinfra";
27 |
28 | const modelName = "mistralai/Mixtral-8x22B-Instruct-v0.1";
29 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
30 | const main = async () => {
31 | const mixtral = new TextGeneration(modelName, apiKey);
32 | const body = {
33 | input: "What is the capital of France?",
34 | };
35 | const output = await mixtral.generate(body);
36 | const text = output.results[0].generated_text;
37 | console.log(text);
38 | };
39 |
40 | main();
41 | ```
42 |
43 | ### Use [text embedding models](https://deepinfra.com/models/embeddings)
44 |
45 | GTE Base is a text embedding model that generates embeddings for the input text. The model is trained by Alibaba DAMO Academy.
46 |
47 | ```typescript
48 | import { Embeddings } from "deepinfra";
49 |
50 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
51 | const modelName = "thenlper/gte-base";
52 | const main = async () => {
53 | const gteBase = new Embeddings(modelName, apiKey);
54 | const body = {
55 | inputs: [
56 | "What is the capital of France?",
57 | "What is the capital of Germany?",
58 | "What is the capital of Italy?",
59 | ],
60 | };
61 | const output = await gteBase.generate(body);
62 | const embeddings = output.embeddings[0];
63 | console.log(embeddings);
64 | };
65 |
66 | main();
67 | ```
68 |
69 | ### Use [SDXL](https://deepinfra.com/stability-ai/sdxl) to generate images
70 |
71 | SDXL takes model-specific parameters, so it is initialized differently: its constructor takes only the API key, and the request wraps the prompt in an `input` object.
72 |
73 | ```typescript
74 | import { Sdxl } from "deepinfra";
75 | import axios from "axios";
76 | import fs from "fs";
77 |
78 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
79 |
80 | const main = async () => {
81 | const model = new Sdxl(apiKey);
82 |
83 | const input = {
84 | prompt: "The quick brown fox jumps over the lazy dog with",
85 | };
86 | const response = await model.generate({ input });
87 | const { output } = response;
88 | const image = output[0];
89 |
90 | await axios.get(image, { responseType: "arraybuffer" }).then((response) => {
91 | fs.writeFileSync("image.png", response.data);
92 | });
93 | };
94 |
95 | main();
96 | ```
97 |
98 | ### Use [other text to image models](https://deepinfra.com/models/text-to-image)
99 |
100 | ```typescript
101 | import { TextToImage } from "deepinfra";
102 | import axios from "axios";
103 | import fs from "fs";
104 |
105 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
106 | const modelName = "stabilityai/stable-diffusion-2-1";
107 | const main = async () => {
108 | const model = new TextToImage(modelName, apiKey);
109 | const input = {
110 | prompt: "The quick brown fox jumps over the lazy dog with",
111 | };
112 |
113 | const response = await model.generate(input);
114 | const { output } = response;
115 | const image = output[0];
116 |
117 | await axios.get(image, { responseType: "arraybuffer" }).then((response) => {
118 | fs.writeFileSync("image.png", response.data);
119 | });
120 | };
121 | main();
122 | ```
123 |
124 | ### Use [automatic speech recognition models](https://deepinfra.com/models/automatic-speech-recognition)
125 |
126 | ```typescript
127 | import { AutomaticSpeechRecognition } from "deepinfra";
128 | import path from "path";
129 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
130 | const modelName = "openai/whisper-base";
131 |
132 | const main = async () => {
133 | const model = new AutomaticSpeechRecognition(modelName, apiKey);
134 |
135 | const input = {
136 | audio: path.join(__dirname, "audio.mp3"),
137 | };
138 | const response = await model.generate(input);
139 | const { text } = response;
140 | console.log(text);
141 | };
142 |
143 | main();
144 | ```
145 |
146 | ### Use [object detection models](https://deepinfra.com/models/object-detection)
147 |
148 | ```typescript
149 | import { ObjectDetection } from "deepinfra";
150 | import path from "path";
151 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
152 | const modelName = "hustvl/yolos-tiny";
153 | const main = async () => {
154 | const model = new ObjectDetection(modelName, apiKey);
155 |
156 | const input = {
157 | image: path.join(__dirname, "image.jpg"),
158 | };
159 | const response = await model.generate(input);
160 | const { results } = response;
161 | console.log(results);
162 | };
163 | ```
164 |
165 | ### Use [token classification models](https://deepinfra.com/models/token-classification)
166 |
167 | ```typescript
168 | import { TokenClassification } from "deepinfra";
169 |
170 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
171 | const modelName = "Davlan/bert-base-multilingual-cased-ner-hrl";
172 |
173 | const main = async () => {
174 | const model = new TokenClassification(modelName, apiKey);
175 |
176 | const input = {
177 | text: "The quick brown fox jumps over the lazy dog",
178 | };
179 | const response = await model.generate(input);
180 | const { results } = response;
181 | console.log(results);
182 | };
183 | ```
184 |
185 | ### Use [fill mask models](https://deepinfra.com/models/fill-mask)
186 |
187 | ```typescript
188 | import { FillMask } from "deepinfra";
189 |
190 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
191 | const modelName = "GroNLP/bert-base-dutch-cased";
192 |
193 | const main = async () => {
194 | const model = new FillMask(modelName, apiKey);
195 |
196 | const body = {
197 | input: "Ik heb een [MASK] gekocht.",
198 | };
199 |
200 | const { results } = await model.generate(body);
201 | console.log(results);
202 | };
203 | ```
204 |
205 | ### Use [image classification models](https://deepinfra.com/models/image-classification)
206 |
207 | ```typescript
208 | import { ImageClassification } from "deepinfra";
209 | import path from "path";
210 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
211 | const modelName = "google/vit-base-patch16-224";
212 |
213 | const main = async () => {
214 | const model = new ImageClassification(modelName, apiKey);
215 |
216 | const body = {
217 | image: path.join(__dirname, "image.jpg"),
218 | };
219 |
220 | const { results } = await model.generate(body);
221 | console.log(results);
222 | };
223 | ```
224 |
225 | ### Use [zero-shot image classification models](https://deepinfra.com/models/zero-shot-image-classification)
226 |
227 | ```typescript
228 | import { ZeroShotImageClassification } from "deepinfra";
229 | import path from "path";
230 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
231 | const modelName = "openai/clip-vit-base-patch32";
232 |
233 | const main = async () => {
234 | const model = new ZeroShotImageClassification(modelName, apiKey);
235 |
236 | const body = {
237 | image: path.join(__dirname, "image.jpg"),
238 | candidate_labels: ["dog", "cat", "car"],
239 | };
240 |
241 | const { results } = await model.generate(body);
242 | console.log(results);
243 | };
244 | ```
245 |
246 | ### Use [text classification models](https://deepinfra.com/models/text-classification)
247 |
248 | ```typescript
249 | import { TextClassification } from "deepinfra";
250 |
251 | const apiKey = "YOUR_DEEPINFRA_API_KEY";
252 | const modelName = "ProsusAI/finbert";
253 |
254 | const main = async () => {
255 | const model = new TextClassification(modelName, apiKey);
256 |
257 | const body = {
258 | input:
259 | "DeepInfra emerges from stealth with $8M to make running AI inferences more affordable",
260 | };
261 |
262 | const { results } = await model.generate(body);
263 | console.log(results);
264 | };
265 | ```
266 |
267 | ## Contributors
268 |
269 | [Oguz Vuruskaner](https://github.com/ovuruska)
270 |
271 | [Iskren Ivov Chernev](https://github.com/ichernev)
272 |
273 | ## Contributing
274 |
275 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull
276 | requests to us.
277 |
278 | ## License
279 |
280 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
281 |
--------------------------------------------------------------------------------
/RELEASE.MD:
--------------------------------------------------------------------------------
1 | # Release
2 | This document describes the release process for the `deepinfra` package.
3 |
4 |
5 | ## Checklist
6 | Before release, make sure to do the following steps:
7 |
8 | - [ ] Update the `package.json` version number. You can use `npm version --allow-same-version --no-git-tag-version version_number` to update the version number.
9 | - [ ] Create a new release on GitHub with the same version number.
10 |
11 |
12 |
13 | ## Authors
14 | - [Oguz Vuruskaner](https://github.com/ovuruska)
15 | - [Iskren Ivov Chernev](https://github.com/ichernev)
16 |
17 |
--------------------------------------------------------------------------------
/build.config.js:
--------------------------------------------------------------------------------
1 | import { defineBuildConfig } from 'unbuild';
2 |
// Read once at load time; only production builds are minified.
const isProductionBuild = process.env.NODE_ENV === 'production';

/**
 * unbuild configuration: builds `src/index` into `dist/` with type
 * declarations, additionally emitting CommonJS output (`emitCJS`), and fails
 * the build on any warning.
 *
 * NOTE(review): the `build` script in package.json runs `tsc && tsc-alias`,
 * not unbuild, and unbuild is not listed in devDependencies — confirm whether
 * this file is still used.
 */
export default defineBuildConfig({
  entries: ['./src/index'],
  outDir: 'dist',
  declaration: true,
  clean: true,
  failOnWarn: true,
  rollup: {
    emitCJS: true,
    esbuild: {
      minify: isProductionBuild,
    },
  },
});
16 |
--------------------------------------------------------------------------------
/jest.config.cjs:
--------------------------------------------------------------------------------
1 | /** @type {import('ts-jest').JstConfigWithTsJest} */
2 | module.exports = {
3 | testEnvironment: 'node',
4 | extensionsToTreatAsEsm: [".ts"],
5 | testTimeout: 10000,
6 | coveragePathIgnorePatterns: [
7 | "/node_modules/",
8 | "/examples/",
9 | "/test/"
10 | ],
11 | moduleNameMapper: {
12 | "^@/(.*)$": "/src/$1"
13 | },
14 | testMatch: [
15 | "/test/**/*.spec.ts"
16 | ],
17 | rootDir: ".",
18 | transform: {
19 | "^.+\\.tsx?$": "ts-jest"
20 | },
21 |
22 | };
23 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "deepinfra",
3 | "version": "2.0.2",
4 | "description": "Official API wrapper for DeepInfra",
5 | "main": "dist/index.js",
6 | "types": "dist/index.d.ts",
7 | "files": [
8 | "dist"
9 | ],
10 | "repository": {
11 | "type": "git",
12 | "url": "git+https://github.com/deepinfra/deepinfra-node.git"
13 | },
14 | "scripts": {
15 | "build": "tsc && tsc-alias -p tsconfig.json",
16 | "misc": "npx ts-node -r tsconfig-paths/register src/misc.ts",
17 | "prepare": "husky",
18 | "test": "jest --passWithNoTests",
19 | "lint": "eslint . --ext .ts --fix",
20 | "prettier": "prettier --write ./src ./test && prettier --write README.MD",
21 | "build-docs": "typedoc --out docs src",
22 | "predeploy-docs": "npm run build-docs",
23 | "deploy-docs": "npx gh-pages -d docs"
24 | },
25 | "config": {
26 | "commitizen": {
27 | "path": "./node_modules/cz-conventional-changelog"
28 | }
29 | },
30 | "keywords": [
31 | "llm",
32 | "deepinfra",
33 | "api",
34 | "wrapper",
35 | "typescript",
36 | "large language model",
37 | "deep learning",
38 | "machine learning",
39 | "artificial intelligence",
40 | "ai"
41 | ],
42 | "author": "Oguz Vuruskaner (https://www.oguzvuruskaner.com)",
43 | "license": "MIT",
44 | "dependencies": {
45 | "@swc/core": "^1.4.6",
46 | "@swc/wasm": "^1.4.6",
47 | "axios": "^1.6.7",
48 | "form-data": "^4.0.0"
49 | },
50 | "devDependencies": {
51 | "@types/jest": "^29.5.12",
52 | "@types/node": "^20.11.26",
53 | "@typescript-eslint/eslint-plugin": "^7.2.0",
54 | "@typescript-eslint/parser": "^7.2.0",
55 | "cz-conventional-changelog": "^3.3.0",
56 | "dotenv": "^16.4.5",
57 | "eslint": "^8.57.0",
58 | "eslint-plugin-unicorn": "^51.0.1",
59 | "husky": "^9.0.11",
60 | "jest": "^29.7.0",
61 | "prettier": "^3.2.5",
62 | "ts-jest": "^29.1.2",
63 | "ts-node": "^10.9.2",
64 | "tsc-alias": "^1.8.8",
65 | "tsconfig-paths": "^4.2.0",
66 | "typedoc": "^0.25.12",
67 | "typescript": "^5.4.2"
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/src/__mocks__/fs.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * This file is used in unit tests to mock the fs module.
3 | *
4 | */
5 |
6 | import * as originalFs from "fs";
7 |
/** In-memory virtual file system: file path -> file contents. */
type MockFiles = Record<string, string>;

/**
 * Shape of the mocked `fs` module: the real module's API minus the two
 * functions that are replaced below, plus a test-only setter.
 */
interface CustomFs
  extends Omit<typeof originalFs, "readFileSync" | "createReadStream"> {
  __setMockFiles: (newMockFiles: MockFiles) => void;
  readFileSync: (filePath: string, options?: object) => string;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  createReadStream: (filePath: string, options?: object) => any;
}

// Start from Jest's automock of `fs`, and serve the same object for "node:fs".
const fs: CustomFs = jest.createMockFromModule("fs");
jest.mock("node:fs", () => fs);

// Prototype-less map so lookups never hit inherited keys.
let mockFiles: MockFiles = Object.create(null);

// Replaces the whole virtual file system for the current test.
fs.__setMockFiles = (newMockFiles: MockFiles): void => {
  mockFiles = newMockFiles;
};

// Unknown paths read as an empty string instead of throwing.
fs.readFileSync = (filePath: string): string => mockFiles[filePath] || "";

// No real stream: the path itself is handed back to the caller.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
fs.createReadStream = (filePath: string): any => filePath;

module.exports = fs;
39 |
--------------------------------------------------------------------------------
/src/clients/deep-infra.ts:
--------------------------------------------------------------------------------
1 | import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse } from "axios";
2 | import { USER_AGENT } from "@/lib/constants/client";
3 | import { ClientConfig, IClientConfig } from "@/lib/types/common/client-config";
4 |
/**
 * HTTP client bound to a single DeepInfra inference endpoint.
 *
 * Every request is a POST to the URL given at construction time, carrying the
 * client's User-Agent and a bearer token. Failed requests are retried up to
 * `maxRetries` times, waiting `initialBackoff` before the first retry and
 * `subsequentBackoff` before each later one.
 */
export class DeepInfraClient {
  private readonly axiosClient: AxiosInstance;
  private readonly clientConfig: ClientConfig;

  /**
   * @param url - Endpoint URL; used both as the axios `baseURL` and as the
   *   POST target.
   * @param authToken - API key, sent as `Authorization: Bearer <token>`.
   * @param config - Optional overrides for retry/backoff settings.
   */
  constructor(
    private readonly url: string,
    private readonly authToken: string,
    config?: Partial<IClientConfig>,
  ) {
    this.axiosClient = axios.create({
      baseURL: this.url,
    });
    this.clientConfig = new ClientConfig(config);
  }

  /**
   * Sleeps before the next retry.
   *
   * @param attempt - Zero-based index of the attempt that just failed. After
   *   attempt 0 (the first retry) the wait is `initialBackoff`; after every
   *   later attempt it is `subsequentBackoff`.
   */
  private async backoffDelay(attempt: number): Promise<void> {
    const delay =
      attempt === 0
        ? this.clientConfig.initialBackoff
        : this.clientConfig.subsequentBackoff;
    return new Promise<void>((resolve) => setTimeout(resolve, delay));
  }

  /**
   * POSTs `data` to the endpoint, retrying on any error.
   *
   * Caller headers are merged over a JSON `content-type` default, but
   * `User-Agent` and `Authorization` always take the client's own values.
   *
   * @typeParam T - Expected shape of the response body.
   * @param data - Request body (a plain object or a form-data instance).
   * @param config - Extra axios request options.
   * @returns The axios response of the first successful attempt.
   * @throws The error from the final attempt once all retries are used up.
   */
  public async post<T>(
    data: object,
    config?: AxiosRequestConfig,
  ): Promise<AxiosResponse<T>> {
    const headers = {
      "content-type": "application/json",
      ...config?.headers,
      "User-Agent": USER_AGENT,
      Authorization: `Bearer ${this.authToken}`,
    };

    // attempt 0 is the initial request; attempts 1..maxRetries are retries.
    for (let attempt = 0; attempt <= this.clientConfig.maxRetries; attempt++) {
      try {
        return await this.axiosClient.post<T>(this.url, data, {
          ...config,
          headers,
        });
      } catch (error) {
        if (attempt >= this.clientConfig.maxRetries) {
          throw error;
        }
        await this.backoffDelay(attempt);
      }
    }
    // Only reachable when maxRetries is negative and the loop never runs.
    throw new Error("Maximum retries exceeded");
  }
}
56 |
--------------------------------------------------------------------------------
/src/clients/index.ts:
--------------------------------------------------------------------------------
1 | export { DeepInfraClient } from "@/clients/deep-infra";
2 |
--------------------------------------------------------------------------------
/src/index.ts:
--------------------------------------------------------------------------------
1 | export * from "@/lib/models/base";
2 |
--------------------------------------------------------------------------------
/src/lib/constants/client.ts:
--------------------------------------------------------------------------------
/** Number of retries after a failed request (presumably the ClientConfig default — confirm). */
export const MAX_RETRIES = 5;

/** Wait in milliseconds before the first retry (presumably the ClientConfig default — confirm). */
export const INITIAL_BACKOFF = 5_000;

/** Wait in milliseconds before every later retry (presumably the ClientConfig default — confirm). */
export const SUBSEQUENT_BACKOFF = 2_000;

/** Value sent in the `User-Agent` header of every request. */
export const USER_AGENT = "DeepInfra TypeScript API Client";

/** Base inference URL; a model name such as "openai/whisper-base" is appended to it. */
export const ROOT_URL = "https://api.deepinfra.com/v1/inference/";
7 |
--------------------------------------------------------------------------------
/src/lib/models/base/automatic-speech-recognition.ts:
--------------------------------------------------------------------------------
1 | import { AutomaticSpeechRecognitionRequest } from "@/lib/types/automatic-speech-recognition/request";
2 | import { BaseModel } from "@/lib/models/base/base-model";
3 | import { IClientConfig } from "@/lib/types/common/client-config";
4 | import { AutomaticSpeechRecognitionResponse } from "@/lib/types/automatic-speech-recognition/response";
5 | import { FormDataUtils } from "@/lib/utils/form-data";
6 |
/**
 * Client for automatic-speech-recognition (speech-to-text) models.
 *
 * Requests are sent as multipart form data built by `FormDataUtils`, with
 * `audio` marked as the upload field.
 */
export class AutomaticSpeechRecognition extends BaseModel {
  /**
   * @param modelName - Model identifier (e.g. "openai/whisper-base") or an
   *   absolute endpoint URL.
   * @param authToken - API key; when omitted the base class looks it up in
   *   the environment.
   * @param config - Optional client settings.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Transcribes the audio referenced by `body.audio`.
   *
   * @param body - Request fields. `audio` is passed to
   *   `FormDataUtils.prepareFormData` as the upload field (presumably a local
   *   file path, as in the README example — confirm in utils/form-data).
   * @returns The response body returned by the endpoint.
   */
  async generate(
    body: AutomaticSpeechRecognitionRequest,
  ): Promise<AutomaticSpeechRecognitionResponse> {
    const formData = await FormDataUtils.prepareFormData(body, ["audio"]);
    // The multipart boundary lives in these headers, so they must be forwarded.
    const multipartHeaders = formData.getHeaders();
    const { data } =
      await this.client.post<AutomaticSpeechRecognitionResponse>(formData, {
        headers: { ...multipartHeaders },
      });
    return data;
  }
}
35 |
--------------------------------------------------------------------------------
/src/lib/models/base/base-model.ts:
--------------------------------------------------------------------------------
1 | import { DeepInfraClient } from "@/clients";
2 | import { IClientConfig } from "@/lib/types/common/client-config";
3 |
4 | import { ROOT_URL } from "@/lib/constants/client";
5 | import { URLUtils } from "@/lib/utils/url";
6 |
/**
 * Common base for every DeepInfra model wrapper.
 *
 * Resolves the endpoint URL and API key once, at construction time, and
 * hands subclasses a `DeepInfraClient` bound to both.
 */
export class BaseModel {
  /** HTTP client bound to `endpoint` and `authToken`. */
  protected client: DeepInfraClient;
  /** URL the client posts to. */
  protected readonly endpoint: string;
  /** API key sent with every request; empty string when none was found. */
  protected authToken: string;

  /**
   * @param modelName - Either a model identifier (e.g.
   *   "mistralai/Mixtral-8x22B-Instruct-v0.1"), which is appended to
   *   `ROOT_URL`, or a URL accepted by `URLUtils.isValidUrl`, which is used
   *   unchanged.
   * @param authToken - API key. When falsy, `DEEPINFRA_API_KEY` from the
   *   environment is used; when that is also unset, a warning is logged and
   *   an empty key is used.
   * @param config - Optional client settings forwarded to `DeepInfraClient`.
   */
  constructor(
    readonly modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    const isCustomUrl = URLUtils.isValidUrl(modelName);
    this.endpoint = isCustomUrl ? modelName : `${ROOT_URL}${modelName}`;
    this.authToken = this.resolveAuthToken(authToken);
    this.client = new DeepInfraClient(this.endpoint, this.authToken, config);
  }

  /** Returns the first non-empty key: explicit argument, then environment, else "" (after warning). */
  private resolveAuthToken(explicitToken?: string): string {
    if (explicitToken) {
      return explicitToken;
    }
    const envToken = this.getAuthTokenFromEnv();
    return envToken ? envToken : this.warnAboutMissingApiKey();
  }

  /** Logs a warning about the missing key and yields an empty key. */
  private warnAboutMissingApiKey(): string {
    console.warn(
      "API key is not provided. Please provide an API key as an argument or set DEEPINFRA_API_KEY environment variable.",
    );
    return "";
  }

  /** Reads `DEEPINFRA_API_KEY`, treating an unset variable as "". */
  private getAuthTokenFromEnv(): string {
    return process.env.DEEPINFRA_API_KEY || "";
  }
}
36 |
--------------------------------------------------------------------------------
/src/lib/models/base/cog-model.ts:
--------------------------------------------------------------------------------
1 | import { AxiosError } from "axios";
2 | import { BaseModel } from "@/lib/models/base";
3 | import { CogResponse } from "@/lib/types/cog/response";
4 | import { CogRequest } from "@/lib/types/cog/request";
5 | import { IClientConfig } from "@/lib/types/common/client-config";
6 |
/**
 * Base class for Cog-style models (e.g. SDXL), whose request and response
 * bodies are wrapped in the `CogRequest` / `CogResponse` envelopes.
 *
 * @typeParam RequestInput - Model-specific input payload.
 * @typeParam ResponseOutput - Model-specific output payload.
 */
export class CogBaseModel<RequestInput, ResponseOutput> extends BaseModel {
  /**
   * @param endpoint - Model identifier or absolute endpoint URL, resolved
   *   into `this.endpoint` by `BaseModel`. This is deliberately NOT a
   *   parameter property: that would re-assign `this.endpoint` to the raw
   *   argument after `super()` had stored the resolved URL.
   * @param authToken - API key; when omitted the base class looks it up in
   *   the environment.
   * @param config - Optional client settings.
   */
  constructor(
    endpoint: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(endpoint, authToken, config);
  }

  /**
   * Runs inference and resolves with the parsed response body.
   *
   * @throws Error - when the request ultimately fails. Diagnostic details are
   *   logged first, and the underlying error message is included in the
   *   thrown error's message.
   */
  public async generate(
    body: CogRequest<RequestInput>,
  ): Promise<CogResponse<ResponseOutput>> {
    try {
      const response =
        await this.client.post<CogResponse<ResponseOutput>>(body);
      return response.data;
    } catch (error) {
      this.handleError(error);
      // Keep the underlying reason; Cog models are not limited to text output.
      const reason = error instanceof Error ? error.message : String(error);
      throw new Error(`Failed to generate output: ${reason}`);
    }
  }

  /** Logs whatever diagnostic detail is available for a failed request. */
  private handleError(error: unknown): void {
    if (error instanceof AxiosError) {
      if (error.response) {
        console.error("Response data:", error.response.data);
        console.error("Status:", error.response.status);
        console.error("Headers:", error.response.headers);
      } else if (error.request) {
        console.error("No response received:", error.request);
      } else {
        console.error("Error message:", error.message);
      }
    } else {
      console.error("An unexpected error occurred:", error);
    }
  }
}
44 |
--------------------------------------------------------------------------------
/src/lib/models/base/custom-model.ts:
--------------------------------------------------------------------------------
1 | import { BaseModel } from "@/lib/models/base/base-model";
2 | import { IClientConfig } from "@/lib/types/common/client-config";
3 |
/**
 * Base for models that post the request object as-is and return the
 * response body unchanged.
 *
 * @typeParam Request - Shape of the body posted to the endpoint.
 * @typeParam Response - Shape of the body the endpoint returns.
 */
export abstract class CustomModel<
  Request extends object,
  Response extends object,
> extends BaseModel {
  /**
   * @param modelName - Model identifier or absolute endpoint URL.
   * @param authToken - API key; when omitted the base class looks it up in
   *   the environment.
   * @param config - Optional client settings.
   */
  protected constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Runs inference. Errors from the underlying client propagate unchanged.
   *
   * @param body - Request payload for the model.
   * @returns The response body returned by the endpoint.
   */
  public async generate(body: Request): Promise<Response> {
    const { data } = await this.client.post<Response>(body);
    return data;
  }
}
21 |
--------------------------------------------------------------------------------
/src/lib/models/base/embeddings.ts:
--------------------------------------------------------------------------------
1 | import { BaseModel } from "@/lib/models/base";
2 | import { EmbeddingsRequest } from "@/lib/types/embeddings/request";
3 | import { EmbeddingsResponse } from "@/lib/types/embeddings/response";
4 | import { IClientConfig } from "@/lib/types/common/client-config";
5 |
/** Client for text-embedding models. */
export class Embeddings extends BaseModel {
  /**
   * @param modelName - Model identifier (e.g. "thenlper/gte-base") or an
   *   absolute endpoint URL.
   * @param authToken - API key; when omitted the base class looks it up in
   *   the environment.
   * @param config - Optional client settings.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Computes embeddings for the request's inputs.
   *
   * @param body - Embedding request payload.
   * @returns The response body when the endpoint answers with status 200.
   * @throws Error `HTTP error! status: <status>` for any other resolved
   *   status; any error (including that one) is logged before being rethrown.
   */
  public async generate(body: EmbeddingsRequest): Promise<EmbeddingsResponse> {
    try {
      const response = await this.client.post<EmbeddingsResponse>(body);
      if (response.status === 200) {
        return response.data;
      }
      throw new Error(`HTTP error! status: ${response.status}`);
    } catch (error) {
      console.error("Error generating embeddings:", error);
      throw error;
    }
  }
}
29 |
--------------------------------------------------------------------------------
/src/lib/models/base/fill-mask.ts:
--------------------------------------------------------------------------------
import { CustomModel } from "@/lib/models/base/custom-model";
import { IClientConfig } from "@/lib/types/common/client-config";
import { FillMaskResponse } from "@/lib/types/fill-mask/response";
import { FillMaskRequest } from "@/lib/types/token-classification/request";
import { TokenClassificationResponse } from "@/lib/types/token-classification/response";
5 |
/**
 * Client for fill-mask models.
 *
 * `generate` posts a single-text request (`input`, optional `webhook`) and
 * resolves with a `FillMaskResponse`. Its `results` are `FillMaskItem`
 * entries (`sequence`, `score`, `token`, `token_str`), not token-classification
 * entity spans.
 */
export class FillMask extends CustomModel<FillMaskRequest, FillMaskResponse> {
  /**
   * @param modelName - Model identifier passed to `CustomModel`.
   * @param authToken - Optional API token.
   * @param config - Optional client overrides.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }
}
18 |
--------------------------------------------------------------------------------
/src/lib/models/base/image-base-model.ts:
--------------------------------------------------------------------------------
1 | import { BaseModel } from "@/lib/models/base/base-model";
2 | import { IClientConfig } from "@/lib/types/common/client-config";
3 | import { ImageRequest } from "@/lib/types/common/image-request";
4 | import { FormDataUtils } from "@/lib/utils/form-data";
5 |
/**
 * Shared implementation for image-input models.
 *
 * `generate` turns the request into multipart form data via
 * `FormDataUtils.prepareFormData`. The `image` field is sent as a binary
 * stream. The result is posted with the form's own headers (boundary /
 * content-type).
 *
 * @typeParam RequestType - Request shape; must include `image`.
 * @typeParam ResponseType - Shape of the response payload.
 */
export class ImageBaseModel<
  RequestType extends ImageRequest,
  ResponseType,
> extends BaseModel {
  /**
   * @param modelName - Model identifier passed to `BaseModel`.
   * @param authToken - Optional API token passed to `BaseModel`.
   * @param config - Optional client overrides passed to `BaseModel`.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Upload the request as multipart form data and return the response payload.
   * @param body - Image request to encode and send.
   * @returns The `data` field of the client response.
   */
  async generate(body: RequestType): Promise<ResponseType> {
    const binaryFields = ["image"];
    const payload = await FormDataUtils.prepareFormData(body, binaryFields);
    const multipartHeaders = { ...payload.getHeaders() };
    const { data } = await this.client.post(payload, {
      headers: multipartHeaders,
    });
    return data;
  }
}
30 |
--------------------------------------------------------------------------------
/src/lib/models/base/image-classification.ts:
--------------------------------------------------------------------------------
1 | import { ImageClassificationRequest } from "@/lib/types/image-classification/request";
2 | import { ImageClassificationResponse } from "@/lib/types/image-classification/response";
3 | import { ImageBaseModel } from "@/lib/models/base/image-base-model";
4 |
/**
 * Client for image-classification models.
 *
 * All behaviour comes from `ImageBaseModel`: the request's `image` is uploaded
 * as multipart form data and the response is typed as
 * `ImageClassificationResponse`.
 */
export class ImageClassification extends ImageBaseModel<
  ImageClassificationRequest,
  ImageClassificationResponse
> {
  // Intentionally empty: no task-specific behaviour.
}
9 |
--------------------------------------------------------------------------------
/src/lib/models/base/index.ts:
--------------------------------------------------------------------------------
1 | export { BaseModel } from "@/lib/models/base/base-model";
2 | export { CustomModel } from "@/lib/models/base/custom-model";
3 | export { TextToImage } from "@/lib/models/base/text-to-image";
4 | export { TextGeneration } from "@/lib/models/base/text-generation";
5 | export { Embeddings } from "@/lib/models/base/embeddings";
6 | export { ObjectDetection } from "@/lib/models/base/object-detection";
7 | export { AutomaticSpeechRecognition } from "@/lib/models/base/automatic-speech-recognition";
8 | export { TokenClassification } from "@/lib/models/base/token-classification";
9 | export { FillMask } from "@/lib/models/base/fill-mask";
10 | export { TextClassification } from "@/lib/models/base/text-classification";
11 | export { QuestionAnswering } from "@/lib/models/base/question-answering";
12 | export { Sdxl } from "@/lib/models/base/sdxl";
13 | export { ImageClassification } from "@/lib/models/base/image-classification";
14 | export { ZeroShotImageClassification } from "@/lib/models/base/zero-shot-image-classification";
15 |
--------------------------------------------------------------------------------
/src/lib/models/base/object-detection.ts:
--------------------------------------------------------------------------------
1 | import { ObjectDetectionRequest } from "@/lib/types/object-detection/request";
2 | import { ObjectDetectionResponse } from "@/lib/types/object-detection/response";
3 | import { ImageBaseModel } from "@/lib/models/base/image-base-model";
4 |
/**
 * Client for object-detection models.
 *
 * All behaviour comes from `ImageBaseModel`. The response is typed as
 * `ObjectDetectionResponse`, whose results carry a label, a score and a
 * bounding `box`.
 */
export class ObjectDetection extends ImageBaseModel<
  ObjectDetectionRequest,
  ObjectDetectionResponse
> {
  // Intentionally empty: no task-specific behaviour.
}
9 |
--------------------------------------------------------------------------------
/src/lib/models/base/question-answering.ts:
--------------------------------------------------------------------------------
1 | import { QuestionAnsweringRequest } from "@/lib/types/questions-answering/request";
2 | import { QuestionAnsweringResponse } from "@/lib/types/questions-answering/response";
3 | import { CustomModel } from "@/lib/models/base/custom-model";
4 | import { IClientConfig } from "@/lib/types/common/client-config";
5 |
/**
 * Client for extractive question-answering models.
 *
 * Sends a `question` plus `context` and resolves with a
 * `QuestionAnsweringResponse` (`answer`, `score`, `start`, `end`).
 */
export class QuestionAnswering extends CustomModel<
  QuestionAnsweringRequest,
  QuestionAnsweringResponse
> {
  /**
   * Public constructor. `CustomModel`'s own constructor is protected.
   * @param modelName - Model identifier.
   * @param authToken - Optional API token.
   * @param config - Optional client overrides.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }
}
18 |
--------------------------------------------------------------------------------
/src/lib/models/base/sdxl.ts:
--------------------------------------------------------------------------------
1 | import { CogBaseModel } from "@/lib/models/base/cog-model";
2 | import { SdxlIn } from "@/lib/types/cog/sdxl/request";
3 | import { SdxlOut } from "@/lib/types/cog/sdxl/response";
4 | import { IClientConfig } from "@/lib/types/common/client-config";
5 |
/**
 * Client bound to the fixed `stability-ai/sdxl` Cog model.
 *
 * Inputs are typed as `SdxlIn` and outputs as `SdxlOut` (a string array).
 */
export class Sdxl extends CogBaseModel<SdxlIn, SdxlOut> {
  /** Model identifier this client always targets. */
  public static readonly modelName: string = "stability-ai/sdxl";

  /**
   * @param authToken - Optional API token.
   * @param config - Optional client overrides.
   */
  constructor(authToken?: string, config?: Partial<IClientConfig>) {
    super(Sdxl.modelName, authToken, config);
  }
}
12 |
--------------------------------------------------------------------------------
/src/lib/models/base/text-classification.ts:
--------------------------------------------------------------------------------
1 | import { TextClassificationRequest } from "@/lib/types/text-classification/request";
2 | import { TextClassificationResponse } from "@/lib/types/text-classification/response";
3 | import { CustomModel } from "@/lib/models/base/custom-model";
4 | import { IClientConfig } from "@/lib/types/common/client-config";
5 |
/**
 * Client for text-classification models.
 *
 * Sends a single `input` string and resolves with a
 * `TextClassificationResponse` of `{ label, score }` items.
 */
export class TextClassification extends CustomModel<
  TextClassificationRequest,
  TextClassificationResponse
> {
  /**
   * Public constructor. `CustomModel`'s own constructor is protected.
   * @param modelName - Model identifier.
   * @param authToken - Optional API token.
   * @param config - Optional client overrides.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }
}
18 |
--------------------------------------------------------------------------------
/src/lib/models/base/text-generation.ts:
--------------------------------------------------------------------------------
1 | import { AxiosError } from "axios";
2 | import { BaseModel } from "@/lib/models/base";
3 | import { TextGenerationResponse } from "@/lib/types/text-generation/response";
4 | import { TextGenerationRequest } from "@/lib/types/text-generation/request";
5 | import { IClientConfig } from "@/lib/types/common/client-config";
6 |
/**
 * Client for text-generation models.
 */
export class TextGeneration extends BaseModel {
  /**
   * @param modelName - Model identifier passed to `BaseModel`.
   * @param authToken - Optional API token passed to `BaseModel`.
   * @param config - Optional client overrides passed to `BaseModel`.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Generate text for the given request.
   * @param body - Text-generation request.
   * @returns The response payload.
   * @throws {Error} With message "Failed to generate text" on any failure.
   *   The underlying error is logged via `handleError` and attached as `cause`,
   *   so callers can still inspect the HTTP status or response body.
   */
  public async generate(
    body: TextGenerationRequest,
  ): Promise<TextGenerationResponse> {
    try {
      const response = await this.client.post(body);
      return response.data;
    } catch (error) {
      this.handleError(error);
      // Keep the stable message but do not discard the original failure.
      // Object.assign avoids depending on the ES2022 `ErrorOptions` typing.
      throw Object.assign(new Error("Failed to generate text"), {
        cause: error,
      });
    }
  }

  /**
   * Log diagnostic details for a failed request.
   * - Axios errors with a response: the response data, status and headers.
   * - Axios errors without a response: the request object.
   * - Other Axios errors: the error message.
   * - Anything else: logged as-is.
   */
  private handleError(error: unknown): void {
    if (error instanceof AxiosError) {
      if (error.response) {
        console.error("Response data:", error.response.data);
        console.error("Status:", error.response.status);
        console.error("Headers:", error.response.headers);
      } else if (error.request) {
        console.error("No response received:", error.request);
      } else {
        console.error("Error message:", error.message);
      }
    } else {
      console.error("An unexpected error occurred:", error);
    }
  }
}
44 |
--------------------------------------------------------------------------------
/src/lib/models/base/text-to-image.ts:
--------------------------------------------------------------------------------
1 | import { BaseModel } from "@/lib/models/base";
2 | import { AxiosResponse } from "axios";
3 | import { TextToImageResponse } from "@/lib/types/text-to-image/response";
4 | import { TextToImageRequest } from "@/lib/types/text-to-image/request";
5 | import { IClientConfig } from "@/lib/types/common/client-config";
6 |
/**
 * Client for text-to-image models.
 */
export class TextToImage extends BaseModel {
  /**
   * @param modelName - Model identifier passed to `BaseModel`.
   * @param authToken - Optional API token passed to `BaseModel`.
   * @param config - Optional client overrides passed to `BaseModel`.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }

  /**
   * Generate image(s) from a prompt.
   * @param body - Text-to-image request.
   * @returns The response payload.
   * @throws Rethrows any client error after logging it.
   */
  public async generate(
    body: TextToImageRequest,
  ): Promise<TextToImageResponse> {
    try {
      const { data: imagePayload }: AxiosResponse<TextToImageResponse> =
        await this.client.post(body);
      return imagePayload;
    } catch (failure) {
      console.error("Error generating image:", failure);
      throw failure;
    }
  }
}
29 |
--------------------------------------------------------------------------------
/src/lib/models/base/token-classification.ts:
--------------------------------------------------------------------------------
1 | import { CustomModel } from "@/lib/models/base/custom-model";
2 | import { FillMaskRequest } from "@/lib/types/token-classification/request";
3 | import { TokenClassificationResponse } from "@/lib/types/token-classification/response";
4 | import { IClientConfig } from "@/lib/types/common/client-config";
5 |
/**
 * Client for token-classification (entity-span) models.
 *
 * The request type is imported under the name `FillMaskRequest` from the
 * token-classification types module. It is a single-text request (`input`,
 * optional `webhook`). Responses are `TokenClassificationResponse` items with
 * `entity_group`, `word`, `score`, `start` and `end`.
 */
export class TokenClassification extends CustomModel<
  FillMaskRequest,
  TokenClassificationResponse
> {
  /**
   * Public constructor. `CustomModel`'s own constructor is protected.
   * @param modelName - Model identifier.
   * @param authToken - Optional API token.
   * @param config - Optional client overrides.
   */
  constructor(
    modelName: string,
    authToken?: string,
    config?: Partial<IClientConfig>,
  ) {
    super(modelName, authToken, config);
  }
}
18 |
--------------------------------------------------------------------------------
/src/lib/models/base/zero-shot-image-classification.ts:
--------------------------------------------------------------------------------
1 | import { ZeroShotImageClassificationRequest } from "@/lib/types/zero-shot-image-classification/request";
2 | import { ZeroShotImageClassificationResponse } from "@/lib/types/zero-shot-image-classification/response";
3 | import { ImageBaseModel } from "@/lib/models/base/image-base-model";
4 |
/**
 * Client for zero-shot image-classification models.
 *
 * All behaviour comes from `ImageBaseModel`. The request adds
 * `candidate_labels`, which is JSON-encoded into the multipart form alongside
 * the `image` stream.
 */
export class ZeroShotImageClassification extends ImageBaseModel<
  ZeroShotImageClassificationRequest,
  ZeroShotImageClassificationResponse
> {
  // Intentionally empty: no task-specific behaviour.
}
9 |
--------------------------------------------------------------------------------
/src/lib/types/automatic-speech-recognition/request.ts:
--------------------------------------------------------------------------------
1 | import { ReadStreamInput } from "@/lib/utils/read-stream";
2 |
/**
 * Request body for automatic-speech-recognition models.
 * Every field except `audio` is optional.
 */
export interface AutomaticSpeechRecognitionRequest {
  /** Audio input: a `Buffer` or a string (see `ReadStreamInput`). */
  audio: ReadStreamInput;

  // Task selection.
  task?: "transcribe" | "translate";
  language?: string;
  initial_prompt?: string;

  // Decoding controls.
  temperature?: number;
  temperature_increment_on_fallback?: number;
  patience?: number;
  suppress_tokens?: string;
  condition_on_previous_text?: boolean;

  // Quality thresholds.
  compression_ratio_threshold?: number;
  logprob_threshold?: number;
  no_speech_threshold?: number;

  /** Optional callback URL. */
  webhook?: string;
}
18 |
--------------------------------------------------------------------------------
/src/lib/types/automatic-speech-recognition/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** One recognised word. Timing units are as reported by the API (presumably seconds). */
interface AutomaticSpeechRecognitionWord {
  text: string;
  confidence: number;
  start: number;
  end: number;
}

/** One transcript segment, optionally broken down into words. */
interface AutomaticSpeechRecognitionSegment {
  id: number;
  seek: number;
  text: string;
  start: number;
  end: number;
  tokens: number[];
  // Optional per-segment decoding statistics.
  temperature?: number;
  avg_logprob?: number;
  compression_ratio?: number;
  no_speech_prob?: number;
  confidence?: number;
  words?: AutomaticSpeechRecognitionWord[];
}

/** Response payload of automatic-speech-recognition models. */
export interface AutomaticSpeechRecognitionResponse {
  /** Full transcript. */
  text: string;
  language: string;
  segments: AutomaticSpeechRecognitionSegment[];
  input_length_ms?: number;
  request_id?: string;
  inference_status: Status;
}
33 |
--------------------------------------------------------------------------------
/src/lib/types/cog/request.ts:
--------------------------------------------------------------------------------
/** Webhook events a caller may subscribe to via `webhook_events_filter`. */
export type WebhookEventType = "start" | "output" | "logs" | "completed";

/**
 * Envelope for a Cog prediction request.
 * @typeParam Req - Shape of the model-specific `input` object.
 */
export interface CogRequest<Req> {
  /** Model-specific input payload. */
  input: Req;
  id?: string;
  version?: string | null;
  /** Optional callback URL. */
  webhook?: string | null;
  /** Restricts which events are delivered to `webhook`. */
  webhook_events_filter?: WebhookEventType[];
}
10 |
--------------------------------------------------------------------------------
/src/lib/types/cog/response.ts:
--------------------------------------------------------------------------------
1 | import { CogRequest } from "./request";
2 |
/** Status values a Cog prediction response can report. */
export type CogStatus = "succeeded" | "failed";

/**
 * Cog prediction response: the request envelope plus status, timestamps,
 * output, logs and accounting data.
 * @typeParam Req - Shape of the prediction `input`.
 * @typeParam Res - Shape of the prediction `output`.
 */
export interface CogResponse<Req, Res> extends CogRequest<Req> {
  status: CogStatus;
  created_at: string | null;
  started_at: string;
  /** Completion timestamp (Cog's field name is `completed_at`). */
  completed_at: string;
  /**
   * @deprecated Misspelling of `completed_at`, kept only for source
   * compatibility. Presumably never populated by the API; read
   * `completed_at` instead.
   */
  complated_at?: string;
  output: Res;
  error: string | null;
  logs: string;
  inference_status: {
    status: CogStatus;
    runtime_ms: number;
    cost: number;
    tokens_generated: number | null;
    tokens_input: number | null;
  };
  metrics: {
    predict_time: number;
  };
}
23 |
--------------------------------------------------------------------------------
/src/lib/types/cog/sdxl/request.ts:
--------------------------------------------------------------------------------
/** Input parameters for the SDXL Cog model. Only `prompt` is required. */
export interface SdxlIn {
  // Prompting.
  prompt: string;
  negative_prompt?: string;

  // Optional source image / mask (string-encoded; presumably base64).
  image?: string;
  mask?: string;
  prompt_strength?: number;

  // Output geometry and count.
  width?: number;
  height?: number;
  num_outputs?: number;

  // Sampling.
  scheduler?:
    | "DDIM"
    | "DPMSolverMultistep"
    | "HeunDiscrete"
    | "KarrasDPM"
    | "K_EULER_ANCESTRAL"
    | "K_EULER"
    | "PNDM";
  num_inference_steps?: number;
  guidance_scale?: number;
  seed?: number;

  // Refiner.
  refine?: "no_refiner" | "expert_ensemble_refiner" | "base_image_refiner";
  high_noise_frac?: number;
  refine_steps?: number;

  apply_watermark?: boolean;
}
26 |
--------------------------------------------------------------------------------
/src/lib/types/cog/sdxl/response.ts:
--------------------------------------------------------------------------------
/** Output of the SDXL Cog model: one string per generated image. */
export type SdxlOut = Array<string>;
2 |
--------------------------------------------------------------------------------
/src/lib/types/common/client-config.ts:
--------------------------------------------------------------------------------
1 | import {
2 | INITIAL_BACKOFF,
3 | MAX_RETRIES,
4 | SUBSEQUENT_BACKOFF,
5 | } from "@/lib/constants/client";
6 |
/** Retry/backoff settings for the HTTP client. */
export interface IClientConfig {
  maxRetries: number;
  initialBackoff: number;
  subsequentBackoff: number;
}

/**
 * Concrete client configuration. Each field not supplied (or supplied as
 * `null`/`undefined`) falls back to the matching constant from
 * `@/lib/constants/client`.
 */
export class ClientConfig implements IClientConfig {
  maxRetries: number;
  initialBackoff: number;
  subsequentBackoff: number;

  /** @param config - Partial overrides; missing fields use the defaults. */
  constructor(config?: Partial<IClientConfig>) {
    const overrides: Partial<IClientConfig> = config ?? {};
    this.maxRetries = overrides.maxRetries ?? MAX_RETRIES;
    this.initialBackoff = overrides.initialBackoff ?? INITIAL_BACKOFF;
    this.subsequentBackoff = overrides.subsequentBackoff ?? SUBSEQUENT_BACKOFF;
  }
}
24 |
--------------------------------------------------------------------------------
/src/lib/types/common/image-item.ts:
--------------------------------------------------------------------------------
/** A labelled score, shared by the image-classification style responses. */
export interface ImageItem {
  /** Predicted label. */
  label: string;
  /** Score reported for `label`. */
  score: number;
}
5 |
--------------------------------------------------------------------------------
/src/lib/types/common/image-request.ts:
--------------------------------------------------------------------------------
1 | import { ReadStreamInput } from "@/lib/utils/read-stream";
2 |
/** Base request for image-input models. */
export interface ImageRequest {
  /** Image input: a `Buffer` or a string (see `ReadStreamInput`). */
  image: ReadStreamInput;
  /** Optional callback URL. */
  webhook?: string;
}
7 |
--------------------------------------------------------------------------------
/src/lib/types/common/single-text-input-request.ts:
--------------------------------------------------------------------------------
/** Base request for models that take one text `input`. */
export interface SingleTextInputRequest {
  /** Text to process. */
  input: string;
  /** Optional callback URL. */
  webhook?: string;
}
5 |
--------------------------------------------------------------------------------
/src/lib/types/common/status.ts:
--------------------------------------------------------------------------------
/** Inference bookkeeping carried by model responses as `inference_status`. */
export interface Status {
  /** Lifecycle state of the request. */
  status: "unknown" | "queued" | "running" | "succeeded" | "failed";
  /** Runtime in milliseconds, per the field name. */
  runtime_ms: number;
  /** Cost as reported by the API. */
  cost: number;
  tokens_input: number;
  tokens_generated: number;
}
8 |
--------------------------------------------------------------------------------
/src/lib/types/embeddings/request.ts:
--------------------------------------------------------------------------------
/** Request body for embedding models. */
export interface EmbeddingsRequest {
  /** Texts to embed. */
  inputs: string[];
  /** Whether to normalize the embeddings (API-defined). */
  normalize?: boolean;
  /** Optional string-encoded image input (encoding defined by the API). */
  image?: string;
  /** Optional callback URL. */
  webhook?: string;
}
7 |
--------------------------------------------------------------------------------
/src/lib/types/embeddings/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** Response payload of embedding models. */
export interface EmbeddingsResponse {
  /** One numeric vector per input. */
  embeddings: number[][];
  /** Token count of the inputs. */
  input_tokens: number;
  request_id?: string;
  inference_status: Status;
}
9 |
--------------------------------------------------------------------------------
/src/lib/types/fill-mask/request.ts:
--------------------------------------------------------------------------------
1 | import { SingleTextInputRequest } from "@/lib/types/common/single-text-input-request";
2 |
/**
 * Request body for fill-mask models: a single text `input` plus an optional
 * `webhook`.
 */
export interface FillMaskRequest extends SingleTextInputRequest {
  // No fields beyond SingleTextInputRequest.
}
4 |
--------------------------------------------------------------------------------
/src/lib/types/fill-mask/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** One fill-mask candidate. */
export interface FillMaskItem {
  /** Full text with the candidate filled in. */
  sequence: string;
  /** Candidate token id. */
  token: number;
  /** Candidate token as text. */
  token_str: string;
  score: number;
}

/** Response payload of fill-mask models. */
export interface FillMaskResponse {
  results: FillMaskItem[];
  request_id?: string;
  inference_status: Status;
}
15 |
--------------------------------------------------------------------------------
/src/lib/types/image-classification/request.ts:
--------------------------------------------------------------------------------
1 | import { ImageRequest } from "@/lib/types/common/image-request";
2 |
/** Request body for image-classification models (`image`, optional `webhook`). */
export interface ImageClassificationRequest extends ImageRequest {
  // No fields beyond ImageRequest.
}
4 |
--------------------------------------------------------------------------------
/src/lib/types/image-classification/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 | import { ImageItem } from "@/lib/types/common/image-item";
3 |
/** Response payload of image-classification models. */
export interface ImageClassificationResponse {
  /** Labelled scores. */
  results: ImageItem[];
  request_id?: string;
  inference_status: Status;
}
9 |
--------------------------------------------------------------------------------
/src/lib/types/object-detection/request.ts:
--------------------------------------------------------------------------------
1 | import { ImageRequest } from "@/lib/types/common/image-request";
2 |
/** Request body for object-detection models (`image`, optional `webhook`). */
export interface ObjectDetectionRequest extends ImageRequest {
  // No fields beyond ImageRequest.
}
4 |
--------------------------------------------------------------------------------
/src/lib/types/object-detection/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 | import { ImageItem } from "@/lib/types/common/image-item";
3 |
/** Axis-aligned bounding box given by its min/max corner coordinates. */
export interface ObjectDetectionBox {
  xmin: number;
  xmax: number;
  ymin: number;
  ymax: number;
}

/** A detected object: a labelled score plus its bounding box. */
export interface ObjectDetectionItem extends ImageItem {
  box: ObjectDetectionBox;
}

/** Response payload of object-detection models. */
export interface ObjectDetectionResponse {
  results: ObjectDetectionItem[];
  request_id?: string;
  inference_status: Status;
}
20 |
--------------------------------------------------------------------------------
/src/lib/types/questions-answering/request.ts:
--------------------------------------------------------------------------------
/** Request body for question-answering models. */
export interface QuestionAnsweringRequest {
  /** Passage the answer is drawn from. */
  context: string;
  /** Question to answer. */
  question: string;
  /** Optional callback URL. */
  webhook?: string;
}
6 |
--------------------------------------------------------------------------------
/src/lib/types/questions-answering/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** Response payload of question-answering models. */
export interface QuestionAnsweringResponse {
  /** Extracted answer text. */
  answer: string;
  score: number;
  /** Answer span offsets (presumably character offsets into `context`). */
  start: number;
  end: number;
  request_id: string | null;
  inference_status: Status;
}
11 |
--------------------------------------------------------------------------------
/src/lib/types/text-classification/request.ts:
--------------------------------------------------------------------------------
1 | import { SingleTextInputRequest } from "@/lib/types/common/single-text-input-request";
2 |
/** Request body for text-classification models (`input`, optional `webhook`). */
export interface TextClassificationRequest extends SingleTextInputRequest {
  // No fields beyond SingleTextInputRequest.
}
4 |
--------------------------------------------------------------------------------
/src/lib/types/text-classification/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** One predicted class with its score. */
export interface TextClassificationItem {
  label: string;
  score: number;
}

/** Response payload of text-classification models. */
export interface TextClassificationResponse {
  results: TextClassificationItem[];
  request_id: string | null;
  inference_status: Status;
}
13 |
--------------------------------------------------------------------------------
/src/lib/types/text-generation/request.ts:
--------------------------------------------------------------------------------
/** Request body for text-generation models. Only `input` is required. */
export interface TextGenerationRequest {
  /** Prompt text. */
  input: string;

  // Length and sampling controls.
  max_new_tokens?: number;
  temperature?: number;
  top_p?: number;
  top_k?: number;
  repetition_penalty?: number;
  presence_penalty?: number;
  frequency_penalty?: number;
  /** Stop sequences. */
  stop?: string[];
  num_responses?: number;
  /** Output-format selector; accepted `type` values are defined by the API. */
  response_format?: {
    type: string;
  };

  // Delivery.
  stream?: boolean;
  webhook?: string;
}
18 |
--------------------------------------------------------------------------------
/src/lib/types/text-generation/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** One generated completion. */
export interface GeneratedText {
  generated_text: string;
}

/** Response payload of text-generation models. */
export interface TextGenerationResponse {
  results: GeneratedText[];
  num_tokens: number;
  num_input_tokens: number;
  request_id: string;
  inference_status: Status;
}
14 |
--------------------------------------------------------------------------------
/src/lib/types/text-to-image/request.ts:
--------------------------------------------------------------------------------
/** Output side lengths accepted for `width` / `height`. */
type TextToImageDimension =
  | 128
  | 256
  | 384
  | 448
  | 512
  | 576
  | 640
  | 704
  | 768
  | 832
  | 896
  | 960
  | 1024;

/** Request body for text-to-image models. Only `prompt` is required. */
export interface TextToImageRequest {
  /** Text prompt describing the image. */
  prompt: string;
  negative_prompt?: string;

  // Optional source image (string-encoded; encoding defined by the API).
  image?: string;
  strength?: number;

  // Generation controls.
  num_images?: number;
  num_inference_steps?: number;
  guidance_scale?: number;
  width?: TextToImageDimension;
  height?: TextToImageDimension;
  seed?: number;
  use_compel?: boolean;

  /** Optional callback URL. */
  webhook?: string;
}
41 |
--------------------------------------------------------------------------------
/src/lib/types/text-to-image/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** Response payload of text-to-image models. */
export interface TextToImageResponse {
  /** Generated images, as strings (encoding defined by the API). */
  images: string[];
  /** One flag per image. */
  nsfw_content_detected: boolean[];
  /** Seed used for generation. */
  seed: number;
  request_id: string;
  inference_status: Status;
}
11 |
--------------------------------------------------------------------------------
/src/lib/types/token-classification/request.ts:
--------------------------------------------------------------------------------
1 | import { SingleTextInputRequest } from "@/lib/types/common/single-text-input-request";
2 |
/**
 * Request body for token-classification models: a single text `input` plus an
 * optional `webhook`.
 */
export interface TokenClassificationRequest extends SingleTextInputRequest {}

/**
 * @deprecated Misnamed export of the token-classification types module, kept
 * for backward compatibility. Use `TokenClassificationRequest`; fill-mask
 * models have their own `FillMaskRequest` in the fill-mask types module.
 */
export interface FillMaskRequest extends TokenClassificationRequest {}
4 |
--------------------------------------------------------------------------------
/src/lib/types/token-classification/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 |
/** One labelled span found in the input text. */
export interface TokenClassificationItem {
  /** Entity label for the span. */
  entity_group: string;
  /** Span text. */
  word: string;
  score: number;
  /** Span offsets (presumably character offsets into the input). */
  start: number;
  end: number;
}

/** Response payload of token-classification models. */
export interface TokenClassificationResponse {
  results: TokenClassificationItem[];
  request_id?: string;
  inference_status: Status;
}
16 |
--------------------------------------------------------------------------------
/src/lib/types/zero-shot-image-classification/request.ts:
--------------------------------------------------------------------------------
1 | import { ImageRequest } from "@/lib/types/common/image-request";
2 |
/** Request body for zero-shot image classification. */
export interface ZeroShotImageClassificationRequest extends ImageRequest {
  /** Labels the image is scored against. */
  candidate_labels: string[];
}
6 |
--------------------------------------------------------------------------------
/src/lib/types/zero-shot-image-classification/response.ts:
--------------------------------------------------------------------------------
1 | import { Status } from "@/lib/types/common/status";
2 | import { ImageItem } from "@/lib/types/common/image-item";
3 |
/** Response payload of zero-shot image-classification models. */
export interface ZeroShotImageClassificationResponse {
  /** Scores for the candidate labels. */
  results: ImageItem[];
  request_id?: string;
  inference_status: Status;
}
9 |
--------------------------------------------------------------------------------
/src/lib/utils/form-data.ts:
--------------------------------------------------------------------------------
1 | import FormData from "form-data";
2 | import { ReadStreamInput, ReadStreamUtils } from "@/lib/utils/read-stream";
3 |
export const FormDataUtils = {
  /**
   * Build a multipart `FormData` payload from a plain request object.
   *
   * - Keys listed in `blobKeys` are binary inputs. Each is converted to a
   *   readable stream with `ReadStreamUtils.getReadStream`.
   * - Every other value is appended as `JSON.stringify(value)`.
   * - Properties whose value is `undefined` (e.g. an optional field set
   *   explicitly to `undefined`) are skipped. They carry no data, and
   *   `JSON.stringify(undefined)` returns `undefined`, which `form-data`'s
   *   `append` cannot handle.
   *
   * @param data - The request object to encode.
   * @param blobKeys - Keys of `data` whose values hold binary input.
   * @returns A FormData object ready to be posted.
   * @throws {Error} If a binary value cannot be converted to a stream.
   */
  async prepareFormData<T extends object>(
    data: T,
    blobKeys: string[] = [],
  ): Promise<FormData> {
    const formData = new FormData();
    for (const [key, value] of Object.entries(data)) {
      if (value === undefined) {
        continue;
      }
      if (blobKeys.includes(key)) {
        const readStream = await ReadStreamUtils.getReadStream(
          value as ReadStreamInput,
        );
        formData.append(key, readStream);
      } else {
        // NOTE(review): plain strings are JSON-quoted here (e.g. "en" is sent
        // as "\"en\""). Confirm the API JSON-decodes multipart text fields.
        formData.append(key, JSON.stringify(value));
      }
    }
    return formData;
  },
};
31 |
--------------------------------------------------------------------------------
/src/lib/utils/read-stream.ts:
--------------------------------------------------------------------------------
1 | import fs from "node:fs";
2 | import axios from "axios";
3 | import { Readable } from "stream";
4 |
/**
 * Binary inputs accepted wherever the SDK needs a readable stream. Either a
 * `Buffer`, or a string that holds an http(s) URL, a `data:` URL, or a local
 * file path.
 * @alias ReadStreamInput
 */
export type ReadStreamInput = Buffer | string;

/** Absolute http:// or https:// URL (scheme matched case-insensitively). */
const HTTP_URL_PATTERN = /^https?:\/\//i;

/** Base64 marker at the end of a data-URL header (`data:<type>;base64`). */
const DATA_URL_BASE64_PATTERN = /;base64$/i;

/**
 * Helpers that turn binary inputs (buffers, URLs, data URLs, file paths) into
 * Node.js `Readable` streams, e.g. for multipart uploads.
 * @category Utils
 */
export const ReadStreamUtils = {
  /**
   * Creates an already-ended Readable that yields `buffer` once.
   * @param buffer - The Buffer to be streamed.
   * @returns A Readable stream.
   */
  bufferToStream(buffer: Buffer): Readable {
    const stream = new Readable();
    stream.push(buffer);
    stream.push(null); // Signifies the end of the stream.
    return stream;
  },

  /**
   * Creates a Readable stream from a file path. Open errors (e.g. ENOENT) are
   * emitted asynchronously on the stream, not thrown.
   * @param filePath - Path of the file to read.
   * @returns A Readable stream of the file contents.
   */
  fileToStream(filePath: string): Readable {
    return fs.createReadStream(filePath);
  },

  /**
   * Downloads a URL with axios and returns the body as a stream.
   * @param url - The URL to fetch.
   * @returns A Readable stream of the response body.
   */
  async urlToStream(url: string): Promise<Readable> {
    const response = await axios.get(url, { responseType: "stream" });
    return response.data;
  },

  /**
   * Decodes a Base64 string into a Readable stream.
   * @param base64 - The Base64 payload (without a `data:` header).
   * @returns A Readable stream of the decoded bytes.
   */
  base64ToStream(base64: string): Readable {
    const buffer = Buffer.from(base64, "base64");
    return this.bufferToStream(buffer);
  },

  /**
   * Decodes the payload of a `data:` URL (RFC 2397) into a Readable stream.
   * A `;base64` header means the payload is Base64-decoded. Otherwise it is
   * percent-decoded and UTF-8 encoded.
   * @param dataUrl - A string starting with `data:`.
   * @returns A Readable stream of the payload bytes.
   * @throws {Error} If the URL has no `,` separating header and payload.
   * @throws {URIError} If a non-Base64 payload has malformed percent-escapes.
   */
  dataUrlToStream(dataUrl: string): Readable {
    const commaIndex = dataUrl.indexOf(",");
    if (commaIndex === -1) {
      throw new Error("Invalid data URL: missing ',' separator");
    }
    const header = dataUrl.slice(0, commaIndex);
    const payload = dataUrl.slice(commaIndex + 1);
    if (DATA_URL_BASE64_PATTERN.test(header)) {
      return this.base64ToStream(payload);
    }
    return this.bufferToStream(Buffer.from(decodeURIComponent(payload), "utf8"));
  },

  /**
   * Returns a Readable stream for any supported input:
   * - Buffer: streamed directly.
   * - `http://` / `https://` URL: downloaded.
   * - `data:` URL: decoded.
   * - Any other string: treated as a file path. A bare `startsWith("http")`
   *   check is not used, because it would misroute paths like
   *   `httpdocs/a.png`.
   * @param input - Buffer, URL, data URL, or file path.
   * @returns A Readable stream of the input's bytes.
   * @throws {Error} If the input type is invalid or a data URL is malformed.
   */
  async getReadStream(input: ReadStreamInput): Promise<Readable> {
    if (Buffer.isBuffer(input)) {
      return this.bufferToStream(input);
    }
    if (typeof input === "string") {
      if (HTTP_URL_PATTERN.test(input)) {
        return this.urlToStream(input);
      }
      if (input.startsWith("data:")) {
        return this.dataUrlToStream(input);
      }
      return this.fileToStream(input);
    }
    throw new Error("Invalid input type");
  },
};
85 |
--------------------------------------------------------------------------------
/src/lib/utils/url.ts:
--------------------------------------------------------------------------------
/** URL helpers. */
export const URLUtils = {
  /**
   * Check whether a string parses as an absolute URL with an http or https
   * scheme.
   * @param urlString - Candidate URL.
   * @returns `true` for parseable http(s) URLs; `false` for other schemes or
   *   unparseable input.
   */
  isValidUrl: (urlString: string): boolean => {
    let parsed: URL;
    try {
      parsed = new URL(urlString);
    } catch {
      return false;
    }
    return parsed.protocol === "http:" || parsed.protocol === "https:";
  },
};
11 |
--------------------------------------------------------------------------------
/test/base/automatic-speech-recognition.spec.ts:
--------------------------------------------------------------------------------
// Axios `post` stub shared with the jest.mock factory below. It is declared
// before the imports so it already exists when that factory runs.
const postMock = jest.fn();
postMock.mockResolvedValue({ data: { transcription: "example text" } });
4 |
5 | import { ROOT_URL } from "@/lib/constants/client";
6 | import { AutomaticSpeechRecognition } from "@/index";
7 | import FormData from "form-data";
8 |
// Replace axios: both `create()` and `get()` return a fake instance whose
// `post` is `postMock`, so the model never touches the network.
jest.mock("axios", () => {
  const fakeGet = jest.fn().mockResolvedValue({ data: {} });
  const mockAxiosInstance = {
    post: postMock,
    get: fakeGet,
  };
  return {
    get: jest.fn(() => mockAxiosInstance),
    create: jest.fn(() => mockAxiosInstance),
  };
});
19 |
describe("AutomaticSpeechRecognition", () => {
  const modelName = "openai/whisper-large";
  const apiKey = "your-api-key";
  let model: AutomaticSpeechRecognition;

  beforeAll(() => {
    model = new AutomaticSpeechRecognition(modelName, apiKey);
  });

  it("should create a new instance", () => {
    expect(model).toBeInstanceOf(AutomaticSpeechRecognition);
  });

  it("should send a request to correct URL", async () => {
    // The upload must hit the model URL as multipart form data.
    const expectedUrl = `${ROOT_URL}${modelName}`;
    const multipartHeaders = expect.objectContaining({
      "content-type": expect.stringMatching(/multipart\/form-data/),
    });

    const response = await model.generate({ audio: "test/assets/audio.mp3" });

    expect(response).toBeDefined();
    expect(postMock).toHaveBeenCalledWith(
      expectedUrl,
      expect.any(FormData),
      expect.objectContaining({ headers: multipartHeaders }),
    );
  });
});
50 |
--------------------------------------------------------------------------------
/test/base/embeddings.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ status: 200, data: { transcription: "example text" } });
4 | jest.mock("axios", () => {
5 | const mockAxiosInstance = {
6 | post: postMock,
7 | };
8 | return {
9 | create: jest.fn(() => mockAxiosInstance),
10 | };
11 | });
12 | import { ROOT_URL } from "@/lib/constants/client";
13 | import { Embeddings } from "@/index";
14 |
15 | describe("Embeddings", () => {
16 | const modelName = "BAAI/bge-large-en-v1.5";
17 | const apiKey = "your-api-key";
18 | let model: Embeddings;
19 |
20 | beforeAll(() => {
21 | model = new Embeddings(modelName, apiKey);
22 | });
23 |
24 | it("should create a new instance", () => {
25 | expect(model).toBeInstanceOf(Embeddings);
26 | });
27 |
28 | it("should send a request to correct URL", async () => {
29 | const response = await model.generate({
30 | inputs: ["Hello world", "Hallo Wereld"],
31 | });
32 |
33 | expect(response).toBeDefined();
34 | expect(postMock).toHaveBeenCalledWith(
35 | `${ROOT_URL}${modelName}`,
36 | expect.any(Object),
37 | expect.any(Object),
38 | );
39 | });
40 | });
41 |
--------------------------------------------------------------------------------
/test/base/fill-mask.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ status: 200, data: { transcription: "example text" } });
4 | jest.mock("axios", () => {
5 | const mockAxiosInstance = {
6 | post: postMock,
7 | };
8 | return {
9 | create: jest.fn(() => mockAxiosInstance),
10 | };
11 | });
12 | import { ROOT_URL } from "@/lib/constants/client";
13 | import { FillMask } from "@/index";
14 |
15 | describe("FillMask", () => {
16 | const modelName = "GroNLP/bert-base-dutch-cased";
17 | const apiKey = "your-api-key";
18 | let model: FillMask;
19 |
20 | beforeAll(() => {
21 | model = new FillMask(modelName, apiKey);
22 | });
23 |
24 | it("should create a new instance", () => {
25 | expect(model).toBeInstanceOf(FillMask);
26 | });
27 |
28 | it("should send a request to correct URL", async () => {
29 | const response = await model.generate({
30 | input: "Hello [MASK]",
31 | });
32 |
33 | expect(response).toBeDefined();
34 | expect(postMock).toHaveBeenCalledWith(
35 | `${ROOT_URL}${modelName}`,
36 | expect.any(Object),
37 | expect.any(Object),
38 | );
39 | });
40 | });
41 |
--------------------------------------------------------------------------------
/test/base/image-classification.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 | jest.mock("axios", () => {
5 | const mockAxiosInstance = {
6 | post: postMock,
7 | };
8 | return {
9 | create: jest.fn(() => mockAxiosInstance),
10 | };
11 | });
12 | import { ROOT_URL } from "@/lib/constants/client";
13 | import { ImageClassification } from "@/index";
14 | import FormData from "form-data";
15 |
16 | describe("ImageClassification", () => {
17 | const modelName = "google/vit-base-patch16-224";
18 | const apiKey = "your-api-key";
19 | let model: ImageClassification;
20 |
21 | beforeAll(() => {
22 | model = new ImageClassification(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(ImageClassification);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | image: "test/assets/image.jpg",
32 | });
33 |
34 | expect(response).toBeDefined();
35 | expect(postMock).toHaveBeenCalledWith(
36 | `${ROOT_URL}${modelName}`,
37 | expect.any(FormData),
38 | expect.objectContaining({
39 | headers: expect.objectContaining({
40 | "content-type": expect.stringMatching(/multipart\/form-data/),
41 | }),
42 | }),
43 | );
44 | });
45 | });
46 |
--------------------------------------------------------------------------------
/test/base/object-detection.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 | jest.mock("axios", () => {
5 | const mockAxiosInstance = {
6 | post: postMock,
7 | };
8 | return {
9 | create: jest.fn(() => mockAxiosInstance),
10 | };
11 | });
12 | import { ROOT_URL } from "@/lib/constants/client";
13 | import { ObjectDetection } from "@/index";
14 | import FormData from "form-data";
15 |
16 | describe("ObjectDetection", () => {
17 | const modelName = "hustvl/yolos-base";
18 | const apiKey = "your-api-key";
19 | let model: ObjectDetection;
20 |
21 | beforeAll(() => {
22 | model = new ObjectDetection(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(ObjectDetection);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | image: "test/assets/image.jpg",
32 | });
33 |
34 | expect(response).toBeDefined();
35 | expect(postMock).toHaveBeenCalledWith(
36 | `${ROOT_URL}${modelName}`,
37 | expect.any(FormData),
38 | expect.objectContaining({
39 | headers: expect.objectContaining({
40 | "content-type": expect.stringMatching(/multipart\/form-data/),
41 | }),
42 | }),
43 | );
44 | });
45 | });
46 |
--------------------------------------------------------------------------------
/test/base/question-answering.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 |
5 | jest.mock("axios", () => {
6 | const mockAxiosInstance = {
7 | post: postMock,
8 | };
9 | return {
10 | create: jest.fn(() => mockAxiosInstance),
11 | };
12 | });
13 | import { ROOT_URL } from "@/lib/constants/client";
14 | import { QuestionAnswering } from "@/index";
15 |
16 | describe("QuestionAnswering", () => {
17 | const modelName = "bert-large-uncased-whole-word-masking-finetuned-squad";
18 | const apiKey = "your-api-key";
19 | let model: QuestionAnswering;
20 |
21 | beforeAll(() => {
22 | model = new QuestionAnswering(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(QuestionAnswering);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | question: "What is the capital of France?",
32 | context: "France is a country in Europe.",
33 | });
34 |
35 | expect(response).toBeDefined();
36 | expect(postMock).toHaveBeenCalledWith(
37 | `${ROOT_URL}${modelName}`,
38 | expect.any(Object),
39 | expect.any(Object),
40 | );
41 | });
42 | });
43 |
--------------------------------------------------------------------------------
/test/base/sdxl.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 |
5 | jest.mock("axios", () => {
6 | const mockAxiosInstance = {
7 | post: postMock,
8 | };
9 | return {
10 | create: jest.fn(() => mockAxiosInstance),
11 | };
12 | });
13 | import { ROOT_URL } from "@/lib/constants/client";
14 | import { Sdxl } from "@/index";
15 |
16 | describe("Sdxl", () => {
17 | const modelName = "stability-ai/sdxl";
18 | const apiKey = "your-api-key";
19 | let model: Sdxl;
20 |
21 | beforeAll(() => {
22 | model = new Sdxl(apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(Sdxl);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | input: { prompt: "The quick brown fox jumps over the lazy dog" },
32 | });
33 |
34 | expect(response).toBeDefined();
35 | expect(postMock).toHaveBeenCalledWith(
36 | `${ROOT_URL}${modelName}`,
37 | expect.any(Object),
38 | expect.any(Object),
39 | );
40 | });
41 | });
42 |
--------------------------------------------------------------------------------
/test/base/text-classification.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 |
5 | jest.mock("axios", () => {
6 | const mockAxiosInstance = {
7 | post: postMock,
8 | };
9 | return {
10 | create: jest.fn(() => mockAxiosInstance),
11 | };
12 | });
13 | import { ROOT_URL } from "@/lib/constants/client";
14 | import { TextClassification } from "@/index";
15 |
16 | describe("TextClassification", () => {
17 | const modelName = "ProsusAI/finbert";
18 | const apiKey = "your-api-key";
19 | let model: TextClassification;
20 |
21 | beforeAll(() => {
22 | model = new TextClassification(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(TextClassification);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | input: "The quick brown fox jumps over the lazy dog",
32 | });
33 |
34 | expect(response).toBeDefined();
35 | expect(postMock).toHaveBeenCalledWith(
36 | `${ROOT_URL}${modelName}`,
37 | expect.any(Object),
38 | expect.any(Object),
39 | );
40 | });
41 | });
42 |
--------------------------------------------------------------------------------
/test/base/text-generation.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 |
5 | jest.mock("axios", () => {
6 | const mockAxiosInstance = {
7 | post: postMock,
8 | };
9 | return {
10 | create: jest.fn(() => mockAxiosInstance),
11 | };
12 | });
13 | import { ROOT_URL } from "@/lib/constants/client";
14 | import { TextGeneration } from "@/index";
15 |
16 | describe("TextGeneration", () => {
17 | const modelName = "microsoft/WizardLM-2-8x22B";
18 | const apiKey = "your-api-key";
19 | let model: TextGeneration;
20 |
21 | beforeAll(() => {
22 | model = new TextGeneration(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(TextGeneration);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | input: "The quick brown fox jumps over the lazy dog",
32 | });
33 |
34 | expect(response).toBeDefined();
35 | expect(postMock).toHaveBeenCalledWith(
36 | `${ROOT_URL}${modelName}`,
37 | expect.any(Object),
38 | expect.any(Object),
39 | );
40 | });
41 | });
42 |
--------------------------------------------------------------------------------
/test/base/text-to-image.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 |
5 | jest.mock("axios", () => {
6 | const mockAxiosInstance = {
7 | post: postMock,
8 | };
9 | return {
10 | create: jest.fn(() => mockAxiosInstance),
11 | };
12 | });
13 | import { ROOT_URL } from "@/lib/constants/client";
14 | import { TextToImage } from "@/index";
15 |
16 | describe("TextToImage", () => {
17 | const modelName = "runwayml/stable-diffusion-v1-5";
18 | const apiKey = "your-api-key";
19 | let model: TextToImage;
20 |
21 | beforeAll(() => {
22 | model = new TextToImage(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(TextToImage);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | prompt: "The quick brown fox jumps over the lazy dog",
32 | });
33 |
34 | expect(response).toBeDefined();
35 | expect(postMock).toHaveBeenCalledWith(
36 | `${ROOT_URL}${modelName}`,
37 | expect.any(Object),
38 | expect.any(Object),
39 | );
40 | });
41 | });
42 |
--------------------------------------------------------------------------------
/test/base/token-classification.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 |
5 | jest.mock("axios", () => {
6 | const mockAxiosInstance = {
7 | post: postMock,
8 | };
9 | return {
10 | create: jest.fn(() => mockAxiosInstance),
11 | };
12 | });
13 | import { ROOT_URL } from "@/lib/constants/client";
14 | import { TokenClassification } from "@/index";
15 |
16 | describe("TokenClassification", () => {
17 | const modelName = "Davlan/bert-base-multilingual-cased-ner-hrl";
18 | const apiKey = "your-api-key";
19 | let model: TokenClassification;
20 |
21 | beforeAll(() => {
22 | model = new TokenClassification(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(TokenClassification);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | input: "The quick brown fox jumps over the lazy dog",
32 | });
33 |
34 | expect(response).toBeDefined();
35 | expect(postMock).toHaveBeenCalledWith(
36 | `${ROOT_URL}${modelName}`,
37 | expect.any(Object),
38 | expect.any(Object),
39 | );
40 | });
41 |
42 | it("should write to console if DEEPINFRA_API_KEY is not set", () => {
43 | const consoleSpy = jest.spyOn(console, "warn");
44 | const model = new TokenClassification(modelName);
45 | expect(model).toBeDefined();
46 | expect(consoleSpy).toHaveBeenCalled();
47 | });
48 |
49 | it("should be constructed with an API key", () => {
50 | process.env.DEEPINFRA_API_KEY = apiKey;
51 | const model = new TokenClassification(modelName);
52 | expect(model).toBeDefined();
53 | process.env.DEEPINFRA_API_KEY = "";
54 | });
55 | });
56 |
--------------------------------------------------------------------------------
/test/base/zero-shot-image-classification.spec.ts:
--------------------------------------------------------------------------------
1 | const postMock = jest
2 | .fn()
3 | .mockResolvedValue({ data: { transcription: "example text" } });
4 | jest.mock("axios", () => {
5 | const mockAxiosInstance = {
6 | post: postMock,
7 | };
8 | return {
9 | create: jest.fn(() => mockAxiosInstance),
10 | };
11 | });
12 | import { ROOT_URL } from "@/lib/constants/client";
13 | import { ZeroShotImageClassification } from "@/index";
14 | import FormData from "form-data";
15 |
16 | describe("ZeroShotImageClassification", () => {
17 | const modelName = "openai/clip-vit-base-patch32";
18 | const apiKey = "your-api-key";
19 | let model: ZeroShotImageClassification;
20 |
21 | beforeAll(() => {
22 | model = new ZeroShotImageClassification(modelName, apiKey);
23 | });
24 |
25 | it("should create a new instance", () => {
26 | expect(model).toBeInstanceOf(ZeroShotImageClassification);
27 | });
28 |
29 | it("should send a request to correct URL", async () => {
30 | const response = await model.generate({
31 | image: "test/assets/image.jpg",
32 | candidate_labels: ["dog", "cat"],
33 | });
34 |
35 | expect(response).toBeDefined();
36 | expect(postMock).toHaveBeenCalledWith(
37 | `${ROOT_URL}${modelName}`,
38 | expect.any(FormData),
39 | expect.objectContaining({
40 | headers: expect.objectContaining({
41 | "content-type": expect.stringMatching(/multipart\/form-data/),
42 | }),
43 | }),
44 | );
45 | });
46 | });
47 |
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "esnext",
4 | "declaration": true,
5 | "module": "commonjs",
6 | "moduleResolution": "node",
7 | "allowSyntheticDefaultImports": true,
8 | "experimentalDecorators": true,
9 | "emitDecoratorMetadata": true,
10 | "allowJs": true,
11 | "esModuleInterop": true,
12 | "forceConsistentCasingInFileNames": true,
13 | "strict": true,
14 | "resolveJsonModule": true,
15 | "skipLibCheck": true,
16 | "outDir": "dist",
17 | "rootDir": "src",
18 | "types": [
19 | "node",
20 | "jest"
21 | ],
22 | "paths": {
23 | "@/*": [
24 | "src/*"
25 | ],
26 | },
27 | "baseUrl": "."
28 | },
29 | "include": [
30 | "src",
31 | "src/**/*.json"
32 | ],
33 | "exclude": [
34 | "**/misc.ts",
35 | "node_modules",
36 | "dist",
37 | "**/*.spec.ts",
38 | ]
39 | }
40 |
--------------------------------------------------------------------------------