├── .editorconfig ├── .gitattributes ├── .gitignore ├── .vscode ├── extensions.json └── settings.json ├── .yarn └── sdks │ ├── integrations.yml │ └── typescript │ ├── bin │ ├── tsc │ └── tsserver │ ├── lib │ ├── tsc.js │ ├── tsserver.js │ ├── tsserverlibrary.js │ └── typescript.js │ └── package.json ├── .yarnrc.yml ├── LICENSE ├── README.md ├── bundle.ts ├── package.json ├── src ├── index.ts └── nodes │ ├── GetOllamaModelNode.ts │ ├── ListOllamaModelsNode.ts │ ├── OllamaChatNode.ts │ ├── OllamaEmbeddingNode.ts │ ├── OllamaGenerateNode.ts │ └── PullModelToOllamaNode.ts ├── tsconfig.json └── yarn.lock /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | 7 | [*.{js,json,yml}] 8 | charset = utf-8 9 | indent_style = space 10 | indent_size = 2 11 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | /.yarn/** linguist-vendored 2 | /.yarn/releases/* binary 3 | /.yarn/plugins/**/* binary 4 | /.pnp.* binary linguist-generated 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .yarn/* 2 | !.yarn/patches 3 | !.yarn/plugins 4 | !.yarn/releases 5 | !.yarn/sdks 6 | !.yarn/versions 7 | 8 | # Swap the comments on the following lines if you don't wish to use zero-installs 9 | # Documentation here: https://yarnpkg.com/features/zero-installs 10 | # !.yarn/cache 11 | .pnp.* 12 | node_modules 13 | dist 14 | 15 | # src/**/*.js 16 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "arcanis.vscode-zipfs" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "search.exclude": { 3 | "**/.yarn": true, 4 | "**/.pnp.*": true 5 | }, 6 | "files.exclude": { 7 | "**/.git": true, 8 | "**/.svn": true, 9 | "**/.hg": true, 10 | "**/CVS": true, 11 | "**/.DS_Store": true, 12 | "**/Thumbs.db": true 13 | }, 14 | "typescript.tsdk": ".yarn/sdks/typescript/lib", 15 | "typescript.enablePromptUseWorkspaceTsdk": true 16 | } 17 | -------------------------------------------------------------------------------- /.yarn/sdks/integrations.yml: -------------------------------------------------------------------------------- 1 | # This file is automatically generated by @yarnpkg/sdks. 2 | # Manual changes might be lost! 
3 | 4 | integrations: 5 | - vscode 6 | -------------------------------------------------------------------------------- /.yarn/sdks/typescript/bin/tsc: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | const {existsSync} = require(`fs`); 4 | const {createRequire} = require(`module`); 5 | const {resolve} = require(`path`); 6 | 7 | const relPnpApiPath = "../../../../.pnp.cjs"; 8 | 9 | const absPnpApiPath = resolve(__dirname, relPnpApiPath); 10 | const absRequire = createRequire(absPnpApiPath); 11 | 12 | if (existsSync(absPnpApiPath)) { 13 | if (!process.versions.pnp) { 14 | // Setup the environment to be able to require typescript/bin/tsc 15 | require(absPnpApiPath).setup(); 16 | } 17 | } 18 | 19 | // Defer to the real typescript/bin/tsc your application uses 20 | module.exports = absRequire(`typescript/bin/tsc`); 21 | -------------------------------------------------------------------------------- /.yarn/sdks/typescript/bin/tsserver: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | const {existsSync} = require(`fs`); 4 | const {createRequire} = require(`module`); 5 | const {resolve} = require(`path`); 6 | 7 | const relPnpApiPath = "../../../../.pnp.cjs"; 8 | 9 | const absPnpApiPath = resolve(__dirname, relPnpApiPath); 10 | const absRequire = createRequire(absPnpApiPath); 11 | 12 | if (existsSync(absPnpApiPath)) { 13 | if (!process.versions.pnp) { 14 | // Setup the environment to be able to require typescript/bin/tsserver 15 | require(absPnpApiPath).setup(); 16 | } 17 | } 18 | 19 | // Defer to the real typescript/bin/tsserver your application uses 20 | module.exports = absRequire(`typescript/bin/tsserver`); 21 | -------------------------------------------------------------------------------- /.yarn/sdks/typescript/lib/tsc.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | const {existsSync} = require(`fs`); 4 | const {createRequire} = require(`module`); 5 | const {resolve} = require(`path`); 6 | 7 | const relPnpApiPath = "../../../../.pnp.cjs"; 8 | 9 | const absPnpApiPath = resolve(__dirname, relPnpApiPath); 10 | const absRequire = createRequire(absPnpApiPath); 11 | 12 | if (existsSync(absPnpApiPath)) { 13 | if (!process.versions.pnp) { 14 | // Setup the environment to be able to require typescript/lib/tsc.js 15 | require(absPnpApiPath).setup(); 16 | } 17 | } 18 | 19 | // Defer to the real typescript/lib/tsc.js your application uses 20 | module.exports = absRequire(`typescript/lib/tsc.js`); 21 | -------------------------------------------------------------------------------- /.yarn/sdks/typescript/lib/tsserver.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | const {existsSync} = require(`fs`); 4 | const {createRequire} = require(`module`); 5 | const {resolve} = require(`path`); 6 | 7 | const relPnpApiPath = "../../../../.pnp.cjs"; 8 | 9 | const absPnpApiPath = resolve(__dirname, relPnpApiPath); 10 | const absRequire = createRequire(absPnpApiPath); 11 | 12 | const moduleWrapper = tsserver => { 13 | if (!process.versions.pnp) { 14 | return tsserver; 15 | } 16 | 17 | const {isAbsolute} = require(`path`); 18 | const pnpApi = require(`pnpapi`); 19 | 20 | const isVirtual = str => str.match(/\/(\$\$virtual|__virtual__)\//); 21 | const isPortal = str => str.startsWith("portal:/"); 22 | const normalize = str => str.replace(/\\/g, 
`/`).replace(/^\/?/, `/`); 23 | 24 | const dependencyTreeRoots = new Set(pnpApi.getDependencyTreeRoots().map(locator => { 25 | return `${locator.name}@${locator.reference}`; 26 | })); 27 | 28 | // VSCode sends the zip paths to TS using the "zip://" prefix, that TS 29 | // doesn't understand. This layer makes sure to remove the protocol 30 | // before forwarding it to TS, and to add it back on all returned paths. 31 | 32 | function toEditorPath(str) { 33 | // We add the `zip:` prefix to both `.zip/` paths and virtual paths 34 | if (isAbsolute(str) && !str.match(/^\^?(zip:|\/zip\/)/) && (str.match(/\.zip\//) || isVirtual(str))) { 35 | // We also take the opportunity to turn virtual paths into physical ones; 36 | // this makes it much easier to work with workspaces that list peer 37 | // dependencies, since otherwise Ctrl+Click would bring us to the virtual 38 | // file instances instead of the real ones. 39 | // 40 | // We only do this to modules owned by the the dependency tree roots. 41 | // This avoids breaking the resolution when jumping inside a vendor 42 | // with peer dep (otherwise jumping into react-dom would show resolution 43 | // errors on react). 44 | // 45 | const resolved = isVirtual(str) ? pnpApi.resolveVirtual(str) : str; 46 | if (resolved) { 47 | const locator = pnpApi.findPackageLocator(resolved); 48 | if (locator && (dependencyTreeRoots.has(`${locator.name}@${locator.reference}`) || isPortal(locator.reference))) { 49 | str = resolved; 50 | } 51 | } 52 | 53 | str = normalize(str); 54 | 55 | if (str.match(/\.zip\//)) { 56 | switch (hostInfo) { 57 | // Absolute VSCode `Uri.fsPath`s need to start with a slash. 58 | // VSCode only adds it automatically for supported schemes, 59 | // so we have to do it manually for the `zip` scheme. 60 | // The path needs to start with a caret otherwise VSCode doesn't handle the protocol 61 | // 62 | // Ref: https://github.com/microsoft/vscode/issues/105014#issuecomment-686760910 63 | // 64 | // 2021-10-08: VSCode changed the format in 1.61. 65 | // Before | ^zip:/c:/foo/bar.zip/package.json 66 | // After | ^/zip//c:/foo/bar.zip/package.json 67 | // 68 | // 2022-04-06: VSCode changed the format in 1.66. 
69 | // Before | ^/zip//c:/foo/bar.zip/package.json 70 | // After | ^/zip/c:/foo/bar.zip/package.json 71 | // 72 | // 2022-05-06: VSCode changed the format in 1.68 73 | // Before | ^/zip/c:/foo/bar.zip/package.json 74 | // After | ^/zip//c:/foo/bar.zip/package.json 75 | // 76 | case `vscode <1.61`: { 77 | str = `^zip:${str}`; 78 | } break; 79 | 80 | case `vscode <1.66`: { 81 | str = `^/zip/${str}`; 82 | } break; 83 | 84 | case `vscode <1.68`: { 85 | str = `^/zip${str}`; 86 | } break; 87 | 88 | case `vscode`: { 89 | str = `^/zip/${str}`; 90 | } break; 91 | 92 | // To make "go to definition" work, 93 | // We have to resolve the actual file system path from virtual path 94 | // and convert scheme to supported by [vim-rzip](https://github.com/lbrayner/vim-rzip) 95 | case `coc-nvim`: { 96 | str = normalize(resolved).replace(/\.zip\//, `.zip::`); 97 | str = resolve(`zipfile:${str}`); 98 | } break; 99 | 100 | // Support neovim native LSP and [typescript-language-server](https://github.com/theia-ide/typescript-language-server) 101 | // We have to resolve the actual file system path from virtual path, 102 | // everything else is up to neovim 103 | case `neovim`: { 104 | str = normalize(resolved).replace(/\.zip\//, `.zip::`); 105 | str = `zipfile://${str}`; 106 | } break; 107 | 108 | default: { 109 | str = `zip:${str}`; 110 | } break; 111 | } 112 | } else { 113 | str = str.replace(/^\/?/, process.platform === `win32` ? `` : `/`); 114 | } 115 | } 116 | 117 | return str; 118 | } 119 | 120 | function fromEditorPath(str) { 121 | switch (hostInfo) { 122 | case `coc-nvim`: { 123 | str = str.replace(/\.zip::/, `.zip/`); 124 | // The path for coc-nvim is in format of //zipfile://.yarn/... 125 | // So in order to convert it back, we use .* to match all the thing 126 | // before `zipfile:` 127 | return process.platform === `win32` 128 | ? str.replace(/^.*zipfile:\//, ``) 129 | : str.replace(/^.*zipfile:/, ``); 130 | } break; 131 | 132 | case `neovim`: { 133 | str = str.replace(/\.zip::/, `.zip/`); 134 | // The path for neovim is in format of zipfile:////.yarn/... 135 | return str.replace(/^zipfile:\/\//, ``); 136 | } break; 137 | 138 | case `vscode`: 139 | default: { 140 | return str.replace(/^\^?(zip:|\/zip(\/ts-nul-authority)?)\/+/, process.platform === `win32` ? `` : `/`) 141 | } break; 142 | } 143 | } 144 | 145 | // Force enable 'allowLocalPluginLoads' 146 | // TypeScript tries to resolve plugins using a path relative to itself 147 | // which doesn't work when using the global cache 148 | // https://github.com/microsoft/TypeScript/blob/1b57a0395e0bff191581c9606aab92832001de62/src/server/project.ts#L2238 149 | // VSCode doesn't want to enable 'allowLocalPluginLoads' due to security concerns but 150 | // TypeScript already does local loads and if this code is running the user trusts the workspace 151 | // https://github.com/microsoft/vscode/issues/45856 152 | const ConfiguredProject = tsserver.server.ConfiguredProject; 153 | const {enablePluginsWithOptions: originalEnablePluginsWithOptions} = ConfiguredProject.prototype; 154 | ConfiguredProject.prototype.enablePluginsWithOptions = function() { 155 | this.projectService.allowLocalPluginLoads = true; 156 | return originalEnablePluginsWithOptions.apply(this, arguments); 157 | }; 158 | 159 | // And here is the point where we hijack the VSCode <-> TS communications 160 | // by adding ourselves in the middle. We locate everything that looks 161 | // like an absolute path of ours and normalize it. 
162 | 163 | const Session = tsserver.server.Session; 164 | const {onMessage: originalOnMessage, send: originalSend} = Session.prototype; 165 | let hostInfo = `unknown`; 166 | 167 | Object.assign(Session.prototype, { 168 | onMessage(/** @type {string | object} */ message) { 169 | const isStringMessage = typeof message === 'string'; 170 | const parsedMessage = isStringMessage ? JSON.parse(message) : message; 171 | 172 | if ( 173 | parsedMessage != null && 174 | typeof parsedMessage === `object` && 175 | parsedMessage.arguments && 176 | typeof parsedMessage.arguments.hostInfo === `string` 177 | ) { 178 | hostInfo = parsedMessage.arguments.hostInfo; 179 | if (hostInfo === `vscode` && process.env.VSCODE_IPC_HOOK) { 180 | const [, major, minor] = (process.env.VSCODE_IPC_HOOK.match( 181 | // The RegExp from https://semver.org/ but without the caret at the start 182 | /(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$/ 183 | ) ?? []).map(Number) 184 | 185 | if (major === 1) { 186 | if (minor < 61) { 187 | hostInfo += ` <1.61`; 188 | } else if (minor < 66) { 189 | hostInfo += ` <1.66`; 190 | } else if (minor < 68) { 191 | hostInfo += ` <1.68`; 192 | } 193 | } 194 | } 195 | } 196 | 197 | const processedMessageJSON = JSON.stringify(parsedMessage, (key, value) => { 198 | return typeof value === 'string' ? fromEditorPath(value) : value; 199 | }); 200 | 201 | return originalOnMessage.call( 202 | this, 203 | isStringMessage ? processedMessageJSON : JSON.parse(processedMessageJSON) 204 | ); 205 | }, 206 | 207 | send(/** @type {any} */ msg) { 208 | return originalSend.call(this, JSON.parse(JSON.stringify(msg, (key, value) => { 209 | return typeof value === `string` ? toEditorPath(value) : value; 210 | }))); 211 | } 212 | }); 213 | 214 | return tsserver; 215 | }; 216 | 217 | if (existsSync(absPnpApiPath)) { 218 | if (!process.versions.pnp) { 219 | // Setup the environment to be able to require typescript/lib/tsserver.js 220 | require(absPnpApiPath).setup(); 221 | } 222 | } 223 | 224 | // Defer to the real typescript/lib/tsserver.js your application uses 225 | module.exports = moduleWrapper(absRequire(`typescript/lib/tsserver.js`)); 226 | -------------------------------------------------------------------------------- /.yarn/sdks/typescript/lib/tsserverlibrary.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | const {existsSync} = require(`fs`); 4 | const {createRequire} = require(`module`); 5 | const {resolve} = require(`path`); 6 | 7 | const relPnpApiPath = "../../../../.pnp.cjs"; 8 | 9 | const absPnpApiPath = resolve(__dirname, relPnpApiPath); 10 | const absRequire = createRequire(absPnpApiPath); 11 | 12 | const moduleWrapper = tsserver => { 13 | if (!process.versions.pnp) { 14 | return tsserver; 15 | } 16 | 17 | const {isAbsolute} = require(`path`); 18 | const pnpApi = require(`pnpapi`); 19 | 20 | const isVirtual = str => str.match(/\/(\$\$virtual|__virtual__)\//); 21 | const isPortal = str => str.startsWith("portal:/"); 22 | const normalize = str => str.replace(/\\/g, `/`).replace(/^\/?/, `/`); 23 | 24 | const dependencyTreeRoots = new Set(pnpApi.getDependencyTreeRoots().map(locator => { 25 | return `${locator.name}@${locator.reference}`; 26 | })); 27 | 28 | // VSCode sends the zip paths to TS using the "zip://" prefix, that TS 29 | // doesn't understand. 
This layer makes sure to remove the protocol 30 | // before forwarding it to TS, and to add it back on all returned paths. 31 | 32 | function toEditorPath(str) { 33 | // We add the `zip:` prefix to both `.zip/` paths and virtual paths 34 | if (isAbsolute(str) && !str.match(/^\^?(zip:|\/zip\/)/) && (str.match(/\.zip\//) || isVirtual(str))) { 35 | // We also take the opportunity to turn virtual paths into physical ones; 36 | // this makes it much easier to work with workspaces that list peer 37 | // dependencies, since otherwise Ctrl+Click would bring us to the virtual 38 | // file instances instead of the real ones. 39 | // 40 | // We only do this to modules owned by the the dependency tree roots. 41 | // This avoids breaking the resolution when jumping inside a vendor 42 | // with peer dep (otherwise jumping into react-dom would show resolution 43 | // errors on react). 44 | // 45 | const resolved = isVirtual(str) ? pnpApi.resolveVirtual(str) : str; 46 | if (resolved) { 47 | const locator = pnpApi.findPackageLocator(resolved); 48 | if (locator && (dependencyTreeRoots.has(`${locator.name}@${locator.reference}`) || isPortal(locator.reference))) { 49 | str = resolved; 50 | } 51 | } 52 | 53 | str = normalize(str); 54 | 55 | if (str.match(/\.zip\//)) { 56 | switch (hostInfo) { 57 | // Absolute VSCode `Uri.fsPath`s need to start with a slash. 58 | // VSCode only adds it automatically for supported schemes, 59 | // so we have to do it manually for the `zip` scheme. 60 | // The path needs to start with a caret otherwise VSCode doesn't handle the protocol 61 | // 62 | // Ref: https://github.com/microsoft/vscode/issues/105014#issuecomment-686760910 63 | // 64 | // 2021-10-08: VSCode changed the format in 1.61. 65 | // Before | ^zip:/c:/foo/bar.zip/package.json 66 | // After | ^/zip//c:/foo/bar.zip/package.json 67 | // 68 | // 2022-04-06: VSCode changed the format in 1.66. 69 | // Before | ^/zip//c:/foo/bar.zip/package.json 70 | // After | ^/zip/c:/foo/bar.zip/package.json 71 | // 72 | // 2022-05-06: VSCode changed the format in 1.68 73 | // Before | ^/zip/c:/foo/bar.zip/package.json 74 | // After | ^/zip//c:/foo/bar.zip/package.json 75 | // 76 | case `vscode <1.61`: { 77 | str = `^zip:${str}`; 78 | } break; 79 | 80 | case `vscode <1.66`: { 81 | str = `^/zip/${str}`; 82 | } break; 83 | 84 | case `vscode <1.68`: { 85 | str = `^/zip${str}`; 86 | } break; 87 | 88 | case `vscode`: { 89 | str = `^/zip/${str}`; 90 | } break; 91 | 92 | // To make "go to definition" work, 93 | // We have to resolve the actual file system path from virtual path 94 | // and convert scheme to supported by [vim-rzip](https://github.com/lbrayner/vim-rzip) 95 | case `coc-nvim`: { 96 | str = normalize(resolved).replace(/\.zip\//, `.zip::`); 97 | str = resolve(`zipfile:${str}`); 98 | } break; 99 | 100 | // Support neovim native LSP and [typescript-language-server](https://github.com/theia-ide/typescript-language-server) 101 | // We have to resolve the actual file system path from virtual path, 102 | // everything else is up to neovim 103 | case `neovim`: { 104 | str = normalize(resolved).replace(/\.zip\//, `.zip::`); 105 | str = `zipfile://${str}`; 106 | } break; 107 | 108 | default: { 109 | str = `zip:${str}`; 110 | } break; 111 | } 112 | } else { 113 | str = str.replace(/^\/?/, process.platform === `win32` ? 
`` : `/`); 114 | } 115 | } 116 | 117 | return str; 118 | } 119 | 120 | function fromEditorPath(str) { 121 | switch (hostInfo) { 122 | case `coc-nvim`: { 123 | str = str.replace(/\.zip::/, `.zip/`); 124 | // The path for coc-nvim is in format of //zipfile://.yarn/... 125 | // So in order to convert it back, we use .* to match all the thing 126 | // before `zipfile:` 127 | return process.platform === `win32` 128 | ? str.replace(/^.*zipfile:\//, ``) 129 | : str.replace(/^.*zipfile:/, ``); 130 | } break; 131 | 132 | case `neovim`: { 133 | str = str.replace(/\.zip::/, `.zip/`); 134 | // The path for neovim is in format of zipfile:////.yarn/... 135 | return str.replace(/^zipfile:\/\//, ``); 136 | } break; 137 | 138 | case `vscode`: 139 | default: { 140 | return str.replace(/^\^?(zip:|\/zip(\/ts-nul-authority)?)\/+/, process.platform === `win32` ? `` : `/`) 141 | } break; 142 | } 143 | } 144 | 145 | // Force enable 'allowLocalPluginLoads' 146 | // TypeScript tries to resolve plugins using a path relative to itself 147 | // which doesn't work when using the global cache 148 | // https://github.com/microsoft/TypeScript/blob/1b57a0395e0bff191581c9606aab92832001de62/src/server/project.ts#L2238 149 | // VSCode doesn't want to enable 'allowLocalPluginLoads' due to security concerns but 150 | // TypeScript already does local loads and if this code is running the user trusts the workspace 151 | // https://github.com/microsoft/vscode/issues/45856 152 | const ConfiguredProject = tsserver.server.ConfiguredProject; 153 | const {enablePluginsWithOptions: originalEnablePluginsWithOptions} = ConfiguredProject.prototype; 154 | ConfiguredProject.prototype.enablePluginsWithOptions = function() { 155 | this.projectService.allowLocalPluginLoads = true; 156 | return originalEnablePluginsWithOptions.apply(this, arguments); 157 | }; 158 | 159 | // And here is the point where we hijack the VSCode <-> TS communications 160 | // by adding ourselves in the middle. We locate everything that looks 161 | // like an absolute path of ours and normalize it. 162 | 163 | const Session = tsserver.server.Session; 164 | const {onMessage: originalOnMessage, send: originalSend} = Session.prototype; 165 | let hostInfo = `unknown`; 166 | 167 | Object.assign(Session.prototype, { 168 | onMessage(/** @type {string | object} */ message) { 169 | const isStringMessage = typeof message === 'string'; 170 | const parsedMessage = isStringMessage ? JSON.parse(message) : message; 171 | 172 | if ( 173 | parsedMessage != null && 174 | typeof parsedMessage === `object` && 175 | parsedMessage.arguments && 176 | typeof parsedMessage.arguments.hostInfo === `string` 177 | ) { 178 | hostInfo = parsedMessage.arguments.hostInfo; 179 | if (hostInfo === `vscode` && process.env.VSCODE_IPC_HOOK) { 180 | const [, major, minor] = (process.env.VSCODE_IPC_HOOK.match( 181 | // The RegExp from https://semver.org/ but without the caret at the start 182 | /(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$/ 183 | ) ?? []).map(Number) 184 | 185 | if (major === 1) { 186 | if (minor < 61) { 187 | hostInfo += ` <1.61`; 188 | } else if (minor < 66) { 189 | hostInfo += ` <1.66`; 190 | } else if (minor < 68) { 191 | hostInfo += ` <1.68`; 192 | } 193 | } 194 | } 195 | } 196 | 197 | const processedMessageJSON = JSON.stringify(parsedMessage, (key, value) => { 198 | return typeof value === 'string' ? 
fromEditorPath(value) : value; 199 | }); 200 | 201 | return originalOnMessage.call( 202 | this, 203 | isStringMessage ? processedMessageJSON : JSON.parse(processedMessageJSON) 204 | ); 205 | }, 206 | 207 | send(/** @type {any} */ msg) { 208 | return originalSend.call(this, JSON.parse(JSON.stringify(msg, (key, value) => { 209 | return typeof value === `string` ? toEditorPath(value) : value; 210 | }))); 211 | } 212 | }); 213 | 214 | return tsserver; 215 | }; 216 | 217 | if (existsSync(absPnpApiPath)) { 218 | if (!process.versions.pnp) { 219 | // Setup the environment to be able to require typescript/lib/tsserverlibrary.js 220 | require(absPnpApiPath).setup(); 221 | } 222 | } 223 | 224 | // Defer to the real typescript/lib/tsserverlibrary.js your application uses 225 | module.exports = moduleWrapper(absRequire(`typescript/lib/tsserverlibrary.js`)); 226 | -------------------------------------------------------------------------------- /.yarn/sdks/typescript/lib/typescript.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | const {existsSync} = require(`fs`); 4 | const {createRequire} = require(`module`); 5 | const {resolve} = require(`path`); 6 | 7 | const relPnpApiPath = "../../../../.pnp.cjs"; 8 | 9 | const absPnpApiPath = resolve(__dirname, relPnpApiPath); 10 | const absRequire = createRequire(absPnpApiPath); 11 | 12 | if (existsSync(absPnpApiPath)) { 13 | if (!process.versions.pnp) { 14 | // Setup the environment to be able to require typescript/lib/typescript.js 15 | require(absPnpApiPath).setup(); 16 | } 17 | } 18 | 19 | // Defer to the real typescript/lib/typescript.js your application uses 20 | module.exports = absRequire(`typescript/lib/typescript.js`); 21 | -------------------------------------------------------------------------------- /.yarn/sdks/typescript/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "typescript", 3 | "version": "5.2.2-sdk", 4 | "main": "./lib/typescript.js", 5 | "type": "commonjs" 6 | } 7 | -------------------------------------------------------------------------------- /.yarnrc.yml: -------------------------------------------------------------------------------- 1 | compressionLevel: mixed 2 | 3 | enableGlobalCache: false 4 | 5 | nodeLinker: node-modules 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Ironclad 2 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 3 | 4 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 5 | 6 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

[Rivet logo image]
9 | 10 | # Rivet Ollama Plugin 11 | 12 | The Rivet Ollama Plugin is a plugin for [Rivet](https://rivet.ironcladapp.com) that lets you use [Ollama](https://ollama.ai/) to easily run and chat with LLMs locally. It adds the following nodes: 13 | 14 | - Ollama Chat 15 | - Ollama Embedding 16 | - Get Ollama Model 17 | - List Ollama Models 18 | - Pull Model to Ollama 19 | 20 | **Table of Contents** 21 | 22 | - [Running Ollama](#running-ollama) 23 | - [Using the plugin](#using-the-plugin) 24 | - [In Rivet](#in-rivet) 25 | - [In the SDK](#in-the-sdk) 26 | - [Configuration](#configuration) 27 | - [In Rivet](#in-rivet-1) 28 | - [In the SDK](#in-the-sdk-1) 29 | - [Nodes](#nodes) 30 | - [Ollama Chat](#ollama-chat) 31 | - [Inputs](#inputs) 32 | - [Outputs](#outputs) 33 | - [Editor Settings](#editor-settings) 34 | - [Ollama Embedding](#ollama-embedding) 35 | - [Inputs](#inputs-1) 36 | - [Outputs](#outputs-1) 37 | - [Editor Settings](#editor-settings-1) 38 | - [Ollama Generate](#ollama-generate) 39 | - [Inputs](#inputs-2) 40 | - [Outputs](#outputs-2) 41 | - [Editor Settings](#editor-settings-2) 42 | - [List Ollama Models](#list-ollama-models) 43 | - [Inputs](#inputs-3) 44 | - [Outputs](#outputs-3) 45 | - [Editor Settings](#editor-settings-3) 46 | - [Get Ollama Model](#get-ollama-model) 47 | - [Inputs](#inputs-4) 48 | - [Outputs](#outputs-4) 49 | - [Editor Settings](#editor-settings-4) 50 | - [Pull Model to Ollama](#pull-model-to-ollama) 51 | - [Inputs](#inputs-5) 52 | - [Outputs](#outputs-5) 53 | - [Editor Settings](#editor-settings-5) 54 | - [Local Development](#local-development) 55 | 56 | ## Running Ollama 57 | 58 | To run Ollama so that Rivet's default [browser executor](https://rivet.ironcladapp.com/docs/user-guide/executors#browser) can communicate with it, you will want to start it with the following command: 59 | 60 | ```bash 61 | OLLAMA_ORIGINS=* ollama serve 62 | ``` 63 | 64 | If you are using the [node executor](https://rivet.ironcladapp.com/docs/user-guide/executors#node), you can omit the `OLLAMA_ORIGINS` environment variable. 65 | 66 | ## Using the plugin 67 | 68 | ### In Rivet 69 | 70 | To use this plugin in Rivet: 71 | 72 | 1. Open the plugins overlay at the top of the screen. 73 | 2. Search for "rivet-plugin-ollama". 74 | 3. Click the "Add" button to install the plugin into your current project. 75 | 76 | ### In the SDK 77 | 78 | 1. Import the plugin and Rivet into your project: 79 | 80 | ```ts 81 | import * as Rivet from "@ironclad/rivet-node"; 82 | import RivetPluginOllama from "rivet-plugin-ollama"; 83 | ``` 84 | 85 | 2. Initialize the plugin and register the nodes with the `globalRivetNodeRegistry`: 86 | 87 | ```ts 88 | Rivet.globalRivetNodeRegistry.registerPlugin(RivetPluginOllama(Rivet)); 89 | ``` 90 | 91 | (You may also use your own node registry if you wish, instead of the global one.) 92 | 93 | 3. The nodes will now work when run with `runGraphInFile` or `createProcessor`. 94 | 95 | ## Configuration 96 | 97 | ### In Rivet 98 | 99 | By default, the plugin will attempt to connect to Ollama at `http://localhost:11434`. If you would like to change this, open the Settings window, navigate to the Plugins area, and you will see a `Host` setting for Ollama. You can change this to the URL of your Ollama instance. On some systems, `http://127.0.0.1:11434` works where `localhost` does not.
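If Rivet cannot connect, it can help to first confirm that the host is reachable at all. The sketch below (not part of the plugin) queries Ollama's standard `/api/tags` endpoint, which lists the locally installed models; the `OLLAMA_HOST` variable name here is just an illustrative convention:

```ts
// check-ollama.ts: a minimal reachability check for an Ollama host.
// Assumes Node 18+ (built-in fetch) running as an ES module.
const host = process.env.OLLAMA_HOST ?? "http://localhost:11434";

const response = await fetch(`${host}/api/tags`);
if (!response.ok) {
  throw new Error(`Ollama responded with status ${response.status}`);
}

// /api/tags returns the models installed locally.
const { models } = (await response.json()) as { models: { name: string }[] };
console.log(`Reachable. Installed models: ${models.map((m) => m.name).join(", ")}`);
```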
100 | 101 | ### In the SDK 102 | 103 | When using the SDK, you can pass a `host` option to the plugin to configure the host: 104 | 105 | Using `createProcessor` or `runGraphInFile`, pass it in via `pluginSettings` in `RunGraphOptions`: 106 | 107 | ```ts 108 | await createProcessor(project, { 109 | ...etc, 110 | pluginSettings: { 111 | ollama: { 112 | host: "http://localhost:11434", 113 | }, 114 | }, 115 | }); 116 | ``` 117 | 118 | ## Nodes 119 | 120 | ### Ollama Chat 121 | 122 | The main node of the plugin. Functions similarly to the [Chat Node](https://rivet.ironcladapp.com/docs/node-reference/chat) built in to Rivet. It uses the `/api/chat` route. 123 | 124 | #### Inputs 125 | 126 | | Title | Data Type | Description | Default Value | Notes | 127 | | ------------- | ---------------- | --------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------------- | 128 | | System Prompt | `string` | The system prompt to prepend to the messages list. | (none) | Optional. | 129 | | Messages | `chat-message[]` | The chat messages to use as the prompt for the LLM. | (none) | Chat messages are converted to the OpenAI message format using "role" and "content" keys. | 130 | 131 | #### Outputs 132 | 133 | | Title | Data Type | Description | Notes | 134 | | ------------- | ---------------- | ----------------------------------------------- | ----- | 135 | | Output | `string` | The response text from the LLM. | | 136 | | Messages Sent | `chat-message[]` | The messages that were sent to Ollama. | | 137 | | All Messages | `chat-message[]` | All messages, including the reply from the LLM. | | 138 | 139 | #### Editor Settings 140 | 141 | | Setting | Description | Default Value | Use Input Toggle | Input Data Type | 142 | | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------- | ---------------- | --------------- | 143 | | Model | The name of the LLM model to use in Ollama. | (Empty) | Yes | `string` | 144 | | Prompt Format | The way to format chat messages for the prompt being sent to the Ollama model. Raw means no formatting is applied. Llama 2 Instruct follows the [Llama 2 prompt format](https://gpus.llm-utils.org/llama-2-prompt-template/). | Llama 2 Instruct | No | N/A | 145 | | JSON Mode | Activates JSON output mode. | false | Yes | `boolean` | 146 | | Parameters Group | | | | | 147 | | Mirostat | Enable Mirostat sampling for controlling perplexity. (Default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | (unset) | Yes | `number` | 148 | | Mirostat Eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | (unset) | Yes | `number` | 149 | | Mirostat Tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | (unset) | Yes | `number` | 150 | | Num Ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | (unset) | Yes | `number` | 151 | | Num GQA | The number of GQA groups in the transformer layer. Required for some models; for example, it is 8 for llama2:70b. | (unset) | Yes | `number` | 152 | | Num GPUs | The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable Metal support, 0 to disable. | (unset) | Yes | `number` | 153 | | Num Threads | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | (unset) | Yes | `number` | 154 | | Repeat Last N | Sets how far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | (unset) | Yes | `number` | 155 | | Repeat Penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | (unset) | Yes | `number` | 156 | | Temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | (unset) | Yes | `number` | 157 | | Seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | (unset) | Yes | `number` | 158 | | Stop | Sets the stop sequences to use. When this pattern is encountered, the LLM will stop generating text and return. | (unset) | Yes | `string` | 159 | | TFS Z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (Default: 1) | (unset) | Yes | `number` | 160 | | Num Predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | (unset) | Yes | `number` | 161 | | Top K | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | (unset) | Yes | `number` | 162 | | Top P | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | (unset) | Yes | `number` | 163 | | Additional Parameters | Additional parameters to pass to Ollama. Numbers will be parsed and sent as numbers; otherwise they will be sent as strings. [See all supported parameters in Ollama](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md) | (none) | Yes | `object` | 164 |
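For a concrete sense of how the settings and the "role"/"content" message format come together, the sketch below issues the kind of request that Ollama's documented `/api/chat` endpoint expects. It is a hand-written illustration, not the node's actual implementation (see `src/nodes/OllamaChatNode.ts` for that); `llama2` is a placeholder model name.

```ts
// A sketch of the request shape, per Ollama's /api/chat documentation.
const response = await fetch("http://localhost:11434/api/chat", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    model: "llama2",
    // The System Prompt input becomes a leading "system" message.
    messages: [
      { role: "system", content: "You are a helpful assistant." },
      { role: "user", content: "Why is the sky blue?" },
    ],
    stream: false,
    // Editor settings such as Temperature and Num Ctx map onto "options".
    options: { temperature: 0.8, num_ctx: 2048 },
  }),
});

const { message } = (await response.json()) as {
  message: { role: string; content: string };
};
console.log(message.content); // Roughly what the node surfaces on its "Output" port.
```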
165 | ### Ollama Embedding 166 | 167 | Embedding models are trained specifically to generate vector embeddings: long arrays of numbers that represent semantic meaning for a given sequence of text. The resulting vectors can be stored in a database, which compares them to search for data that is similar in meaning. 168 | 169 | #### Inputs 170 | 171 | See Editor Settings for all possible inputs. 172 | 173 | #### Outputs 174 | 175 | | Title | Data Type | Description | Notes | 176 | | --------- | --------- | ------------------------------------------------------------------------------ | ----- | 177 | | Embedding | `vector` | Array of numbers that represent semantic meaning for a given sequence of text. | | 178 | 179 | #### Editor Settings 180 | 181 | | Setting | Description | Default Value | Use Input Toggle | Input Data Type | 182 | | ---------- | ----------------------------------------------- | ------------- | ----------------- | --------------- | 183 | | Model Name | The name of the model to use for the embedding. | (Empty) | Yes (default off) | `string` | 184 | | Text | The text to embed. | (Empty) | Yes (default off) | `string` | 185 |
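To make the idea concrete, here is a minimal sketch that requests embeddings (assuming Ollama's standard `/api/embeddings` endpoint and a placeholder `llama2` model) and compares two of them with cosine similarity, the measure a vector database would typically use. It is illustrative only, not the node's source:

```ts
// Request an embedding from Ollama for a piece of text.
async function embed(text: string): Promise<number[]> {
  const response = await fetch("http://localhost:11434/api/embeddings", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ model: "llama2", prompt: text }),
  });
  const { embedding } = (await response.json()) as { embedding: number[] };
  return embedding;
}

// Cosine similarity: close to 1 for semantically similar text, near 0 for unrelated text.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

const [cats, kittens] = await Promise.all([embed("cats"), embed("kittens")]);
console.log(cosineSimilarity(cats, kittens));
```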
186 | ### Ollama Generate 187 | 188 | Previously the main node of the plugin. Allows you to send prompts to Ollama and receive responses from the installed LLMs, with deep customization options, including custom prompt formats. It uses the `/api/generate` route. 189 | 190 | #### Inputs 191 | 192 | | Title | Data Type | Description | Default Value | Notes | 193 | | ------------- | ---------------- | --------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 194 | | System Prompt | `string` | The system prompt to prepend to the messages list. | (none) | Optional. | 195 | | Messages | `chat-message[]` | The chat messages to use as the prompt for the LLM. | (none) | Chat messages are converted to a prompt in Ollama based on the "Prompt Format" editor setting. If "Raw" is selected, no formatting is performed on the chat messages, and you are expected to have already formatted them in your Rivet graphs. | 196 | 197 | Additional inputs are available with toggles in the editor. 198 | 199 | #### Outputs 200 | 201 | | Title | Data Type | Description | Notes | 202 | | -------------------- | ---------------- | ---------------------------------------------------------- | ----------------------------------------------------------- | 203 | | Output | `string` | The response text from the LLM. | | 204 | | Prompt | `string` | The full prompt, with formatting, that was sent to Ollama. | | 205 | | Messages Sent | `chat-message[]` | The messages that were sent to Ollama. | | 206 | | All Messages | `chat-message[]` | All messages, including the reply from the LLM. | | 207 | | Total Duration | `number` | Time spent in nanoseconds generating the response. | Only available if the "Advanced Outputs" toggle is enabled. | 208 | | Load Duration | `number` | Time spent in nanoseconds loading the model. | Only available if the "Advanced Outputs" toggle is enabled. | 209 | | Sample Count | `number` | Number of samples generated. | Only available if the "Advanced Outputs" toggle is enabled. | 210 | | Sample Duration | `number` | Time spent in nanoseconds generating samples. | Only available if the "Advanced Outputs" toggle is enabled. | 211 | | Prompt Eval Count | `number` | Number of tokens in the prompt. | Only available if the "Advanced Outputs" toggle is enabled. | 212 | | Prompt Eval Duration | `number` | Time spent in nanoseconds evaluating the prompt. | Only available if the "Advanced Outputs" toggle is enabled. | 213 | | Eval Count | `number` | Number of tokens in the response. | Only available if the "Advanced Outputs" toggle is enabled. | 214 | | Eval Duration | `number` | Time spent in nanoseconds evaluating the response. | Only available if the "Advanced Outputs" toggle is enabled. | 215 | | Tokens Per Second | `number` | Number of tokens generated per second. | Only available if the "Advanced Outputs" toggle is enabled.
| 216 | | Parameters | `object` | The parameters used to generate the response. | Only available if the "Advanced Outputs" toggle is enabled. | 217 | 218 | #### Editor Settings 219 | 220 | | Setting | Description | Default Value | Use Input Toggle | Input Data Type | 221 | | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------- | ---------------- | --------------- | 222 | | Model | The name of the LLM model to use in Ollama. | (Empty) | Yes | `string` | 223 | | Prompt Format | The way to format chat messages for the prompt being sent to the Ollama model. Raw means no formatting is applied. Llama 2 Instruct follows the [Llama 2 prompt format](https://gpus.llm-utils.org/llama-2-prompt-template/). | Llama 2 Instruct | No | N/A | 224 | | JSON Mode | Activates JSON output mode. | false | Yes | `boolean` | 225 | | Advanced Outputs | Add additional outputs with detailed information about the Ollama execution. | No | No | N/A | 226 | | Parameters Group | | | | | 227 | | Mirostat | Enable Mirostat sampling for controlling perplexity. (Default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | (unset) | Yes | `number` | 228 | | Mirostat Eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | (unset) | Yes | `number` | 229 | | Mirostat Tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | (unset) | Yes | `number` | 230 | | Num Ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | (unset) | Yes | `number` | 231 | | Num GQA | The number of GQA groups in the transformer layer. Required for some models; for example, it is 8 for llama2:70b. | (unset) | Yes | `number` | 232 | | Num GPUs | The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable Metal support, 0 to disable. | (unset) | Yes | `number` | 233 | | Num Threads | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | (unset) | Yes | `number` | 234 | | Repeat Last N | Sets how far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | (unset) | Yes | `number` | 235 | | Repeat Penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | (unset) | Yes | `number` | 236 | | Temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | (unset) | Yes | `number` | 237 | | Seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | (unset) | Yes | `number` | 238 | | Stop | Sets the stop sequences to use. When this pattern is encountered, the LLM will stop generating text and return. | (unset) | Yes | `string` | 239 | | TFS Z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (Default: 1) | (unset) | Yes | `number` | 240 | | Num Predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | (unset) | Yes | `number` | 241 | | Top K | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | (unset) | Yes | `number` | 242 | | Top P | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | (unset) | Yes | `number` | 243 | | Additional Parameters | Additional parameters to pass to Ollama. Numbers will be parsed and sent as numbers; otherwise they will be sent as strings. [See all supported parameters in Ollama](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md) | (none) | Yes | `object` | 244 |
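To illustrate what the "Llama 2 Instruct" prompt format produces, below is a simplified formatter following the Llama 2 template linked above. It is a sketch for intuition only; the node's actual formatting logic may differ in detail, and the `ChatMessage` shape is a stand-in rather than Rivet's real chat-message type.

```ts
// A stand-in message shape for illustration; not Rivet's actual chat-message type.
type ChatMessage = { role: "user" | "assistant"; content: string };

// Simplified Llama 2 Instruct formatting: the system prompt goes inside <<SYS>> tags
// in the first [INST] block, and every user turn is wrapped in [INST] ... [/INST].
function formatLlama2Instruct(system: string | undefined, messages: ChatMessage[]): string {
  const sys = system ? `<<SYS>>\n${system}\n<</SYS>>\n\n` : "";
  let prompt = "";
  let firstUser = true;
  for (const msg of messages) {
    if (msg.role === "user") {
      prompt += `[INST] ${firstUser ? sys : ""}${msg.content} [/INST]`;
      firstUser = false;
    } else {
      prompt += ` ${msg.content} `;
    }
  }
  return prompt;
}

console.log(formatLlama2Instruct("Be concise.", [{ role: "user", content: "Hi!" }]));
// => [INST] <<SYS>>
//    Be concise.
//    <</SYS>>
//
//    Hi! [/INST]
```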
245 | ### List Ollama Models 246 | 247 | Lists the models installed in Ollama. 248 | 249 | #### Inputs 250 | 251 | This node has no inputs. 252 | 253 | #### Outputs 254 | 255 | | Title | Data Type | Description | Notes | 256 | | ----------- | ---------- | -------------------------------------------- | ----- | 257 | | Model Names | `string[]` | The names of the models installed in Ollama. | | 258 | 259 | #### Editor Settings 260 | 261 | This node has no editor settings. 262 | 263 | ### Get Ollama Model 264 | 265 | Gets the model with the given name from Ollama. 266 | 267 | #### Inputs 268 | 269 | See Editor Settings for all possible inputs. 270 | 271 | #### Outputs 272 | 273 | | Title | Data Type | Description | Notes | 274 | | ---------- | --------- | ------------------------------------------- | ----- | 275 | | License | `string` | Contents of the license block of the model. | | 276 | | Modelfile | `string` | The Ollama modelfile for the model. | | 277 | | Parameters | `string` | The parameters for the model. | | 278 | | Template | `string` | The template for the model. | | 279 | 280 | #### Editor Settings 281 | 282 | | Setting | Description | Default Value | Use Input Toggle | Input Data Type | 283 | | ---------- | ----------------------------- | ------------- | ---------------- | --------------- | 284 | | Model Name | The name of the model to get. | (Empty) | Yes (default on) | `string` | 285 | 286 | ### Pull Model to Ollama 287 | 288 | Downloads a model from the Ollama library to the Ollama server. 289 | 290 | #### Inputs 291 | 292 | See Editor Settings for all possible inputs. 293 | 294 | #### Outputs 295 | 296 | | Title | Data Type | Description | Notes | 297 | | ---------- | --------- | -------------------------------------- | ----- | 298 | | Model Name | `string` | The name of the model that was pulled. | | 299 | 300 | #### Editor Settings 301 | 302 | | Setting | Description | Default Value | Use Input Toggle | Input Data Type | 303 | | ---------- | ----------------------------------------------------------------------------------------------------------------------- | ------------- | ---------------- | --------------- | 304 | | Model Name | The name of the model to pull. | (Empty) | Yes (default on) | `string` | 305 | | Insecure | Allow insecure connections to the library. Only use this if you are pulling from your own library during development. | No | No | N/A | 306 |
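Behind the scenes, pulling corresponds to Ollama's `/api/pull` endpoint. The sketch below shows an equivalent non-streaming request, based on Ollama's API documentation rather than this node's source; `llama2` is a placeholder model name:

```ts
// A non-streaming pull request (assumes Ollama's standard /api/pull endpoint).
const response = await fetch("http://localhost:11434/api/pull", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    name: "llama2",  // Placeholder; mirrors the node's "Model Name" setting.
    insecure: false, // Mirrors the node's "Insecure" editor setting.
    stream: false,   // With streaming off, Ollama responds once the pull completes.
  }),
});

const { status } = (await response.json()) as { status: string };
console.log(status); // "success" once the model has been downloaded.
```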
307 | ## Local Development 308 | 309 | 1. Run `yarn dev` to start the compiler and bundler in watch mode. This will automatically recompile and rebundle your changes into the `dist` folder, and will also copy the bundled files into the plugin install directory. 310 | 2. After each change, you must restart Rivet to see the changes. 311 | -------------------------------------------------------------------------------- /bundle.ts: -------------------------------------------------------------------------------- 1 | import * as esbuild from "esbuild"; 2 | import { match } from "ts-pattern"; 3 | import { join, dirname } from "node:path"; 4 | import copy from "recursive-copy"; 5 | import { platform, homedir } from "node:os"; 6 | import { readFile, rm, mkdir, copyFile } from "node:fs/promises"; 7 | import { fileURLToPath } from "node:url"; 8 | 9 | const __dirname = dirname(fileURLToPath(import.meta.url)); 10 | 11 | // Roughly https://github.com/demurgos/appdata-path/blob/master/lib/index.js, but using AppData/Local and .local/share to match `dirs` from Rust. 12 | function getAppDataLocalPath() { 13 | const identifier = "com.ironcladapp.rivet"; 14 | return match(platform()) 15 | .with("win32", () => join(homedir(), "AppData", "Local", identifier)) 16 | .with("darwin", () => 17 | join(homedir(), "Library", "Application Support", identifier) 18 | ) 19 | .with("linux", () => join(homedir(), ".local", "share", identifier)) 20 | .otherwise(() => { 21 | if (platform().startsWith("win")) { 22 | return join(homedir(), "AppData", "Local", identifier); 23 | } else { 24 | return join(homedir(), ".local", "share", identifier); 25 | } 26 | }); 27 | } 28 | 29 | const syncPlugin: esbuild.Plugin = { 30 | name: "onBuild", 31 | setup(build) { 32 | build.onEnd(async () => { 33 | const packageJson = JSON.parse( 34 | await readFile(join(__dirname, "package.json"), "utf-8") 35 | ); 36 | const pluginName = packageJson.name; 37 | 38 | const rivetPluginsDirectory = join(getAppDataLocalPath(), "plugins"); 39 | const thisPluginDirectory = join( 40 | rivetPluginsDirectory, 41 | `${pluginName}-latest` 42 | ); 43 | 44 | await rm(join(thisPluginDirectory, "package"), { 45 | recursive: true, 46 | force: true, 47 | }); 48 | await mkdir(join(thisPluginDirectory, "package"), { recursive: true }); 49 | 50 | await copy( 51 | join(__dirname, "dist"), 52 | join(thisPluginDirectory, "package", "dist") 53 | ); 54 | await copyFile( 55 | join(__dirname, "package.json"), 56 | join(thisPluginDirectory, "package", "package.json") 57 | ); 58 | 59 | // Copy .git to mark as locally installed plugin 60 | await copy( 61 | join(__dirname, ".git"), 62 | join(thisPluginDirectory, "package", ".git") 63 | ); 64 | 65 | console.log( 66 | `Synced ${pluginName} to Rivet at ${thisPluginDirectory}.
Refresh or restart Rivet to see changes.` 67 | ); 68 | }); 69 | }, 70 | }; 71 | 72 | const options = { 73 | entryPoints: ["src/index.ts"], 74 | bundle: true, 75 | platform: "neutral", 76 | target: "es2020", 77 | outfile: "dist/bundle.js", 78 | format: "esm", 79 | logLevel: "info", 80 | plugins: [] as esbuild.Plugin[], 81 | } satisfies esbuild.BuildOptions; 82 | 83 | if (process.argv.includes("--sync")) { 84 | options.plugins.push(syncPlugin); 85 | } 86 | 87 | if (process.argv.includes("--watch")) { 88 | const context = await esbuild.context(options); 89 | 90 | await context.watch(); 91 | 92 | console.log("Watching for changes..."); 93 | } else { 94 | await esbuild.build(options); 95 | } 96 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "rivet-plugin-ollama", 3 | "packageManager": "yarn@3.5.0", 4 | "version": "0.4.0", 5 | "type": "module", 6 | "main": "dist/bundle.js", 7 | "license": "MIT", 8 | "files": [ 9 | "dist" 10 | ], 11 | "scripts": { 12 | "build": "tsc -b && tsx bundle.ts", 13 | "dev": "run-p watch:tsc watch:esbuild:sync", 14 | "watch:tsc": "tsc -b -w --preserveWatchOutput", 15 | "watch:esbuild": "tsx bundle.ts --watch", 16 | "watch:esbuild:sync": "tsx bundle.ts --watch --sync", 17 | "publish": "yarn npm publish --access public" 18 | }, 19 | "dependencies": { 20 | "@ironclad/rivet-core": "^1.13.2", 21 | "ts-pattern": "^5.0.5" 22 | }, 23 | "devDependencies": { 24 | "esbuild": "^0.19.2", 25 | "npm-run-all": "^4.1.5", 26 | "recursive-copy": "^2.0.14", 27 | "tsx": "^3.12.10", 28 | "typescript": "^5.2.2" 29 | }, 30 | "volta": { 31 | "node": "20.6.1" 32 | }, 33 | "rivet": { 34 | "skipInstall": true 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | import type { RivetPlugin, RivetPluginInitializer } from "@ironclad/rivet-core"; 2 | import { ollamaChat } from "./nodes/OllamaGenerateNode"; 3 | import { ollamaChat2 } from "./nodes/OllamaChatNode"; 4 | import { ollamaEmbed } from "./nodes/OllamaEmbeddingNode"; 5 | import { getOllamaModel } from "./nodes/GetOllamaModelNode"; 6 | import { listOllamaModels } from "./nodes/ListOllamaModelsNode"; 7 | import { pullModelToOllama } from "./nodes/PullModelToOllamaNode"; 8 | 9 | const plugin: RivetPluginInitializer = (rivet) => { 10 | const examplePlugin: RivetPlugin = { 11 | id: "ollama", 12 | name: "Ollama Plugin", 13 | 14 | configSpec: { 15 | host: { 16 | label: "Host", 17 | type: "string", 18 | default: "http://localhost:11434", 19 | description: 20 | "The host to use for the Ollama API. Defaults to http://localhost:11434.", 21 | helperText: 22 | "The host to use for the Ollama API. 
Defaults to http://localhost:11434.", 23 | }, 24 | apiKey: { 25 | label: "API Key", 26 | type: "secret", 27 | default: "", 28 | description: 29 | "Optional API key for authentication with Ollama instances that require it.", 30 | helperText: 31 | "Leave empty if your Ollama instance doesn't require authentication.", 32 | }, 33 | }, 34 | 35 | contextMenuGroups: [ 36 | { 37 | id: "ollama", 38 | label: "Ollama", 39 | }, 40 | ], 41 | 42 | register: (register) => { 43 | register(ollamaChat(rivet)); 44 | register(ollamaChat2(rivet)); 45 | register(ollamaEmbed(rivet)); 46 | register(getOllamaModel(rivet)); 47 | register(listOllamaModels(rivet)); 48 | register(pullModelToOllama(rivet)); 49 | }, 50 | }; 51 | 52 | return examplePlugin; 53 | }; 54 | 55 | export default plugin; 56 | -------------------------------------------------------------------------------- /src/nodes/GetOllamaModelNode.ts: -------------------------------------------------------------------------------- 1 | import type { 2 | ChartNode, 3 | NodeId, 4 | NodeInputDefinition, 5 | NodeUIData, 6 | PluginNodeImpl, 7 | PortId, 8 | Rivet, 9 | pluginNodeDefinition, 10 | } from "@ironclad/rivet-core"; 11 | 12 | export type GetOllamaModelNode = ChartNode< 13 | "getOllamaModel", 14 | { 15 | modelName: string; 16 | useModelNameInput?: boolean; 17 | 18 | host?: string; 19 | useHostInput?: boolean; 20 | 21 | apiKey?: string; 22 | useApiKeyInput?: boolean; 23 | 24 | headers?: { key: string; value: string }[]; 25 | useHeadersInput?: boolean; 26 | } 27 | >; 28 | 29 | export const getOllamaModel = (rivet: typeof Rivet) => { 30 | const impl: PluginNodeImpl<GetOllamaModelNode> = { 31 | create() { 32 | return { 33 | id: rivet.newId<NodeId>(), 34 | data: { 35 | modelName: "", 36 | useModelNameInput: true, 37 | }, 38 | title: "Get Ollama Model", 39 | type: "getOllamaModel", 40 | visualData: { 41 | x: 0, 42 | y: 0, 43 | width: 250, 44 | }, 45 | } satisfies GetOllamaModelNode; 46 | }, 47 | 48 | getInputDefinitions(data) { 49 | const inputs: NodeInputDefinition[] = []; 50 | 51 | if (data.useModelNameInput) { 52 | inputs.push({ 53 | id: "modelName" as PortId, 54 | dataType: "string", 55 | title: "Model Name", 56 | description: "The name of the model to get.", 57 | }); 58 | } 59 | 60 | if (data.useHostInput) { 61 | inputs.push({ 62 | dataType: "string", 63 | id: "host" as PortId, 64 | title: "Host", 65 | description: 66 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API.
Leave blank for the default: http://localhost:11434", 67 | }); 68 | } 69 | 70 | if (data.useApiKeyInput) { 71 | inputs.push({ 72 | dataType: "string", 73 | id: "apiKey" as PortId, 74 | title: "API Key", 75 | description: 76 | "Optional API key for authentication with Ollama instances that require it.", 77 | }); 78 | } 79 | 80 | if (data.useHeadersInput) { 81 | inputs.push({ 82 | dataType: 'object', 83 | id: 'headers' as PortId, 84 | title: 'Headers', 85 | description: 'Additional headers to send to the API.', 86 | }); 87 | } 88 | 89 | return inputs; 90 | }, 91 | 92 | getOutputDefinitions() { 93 | return [ 94 | { 95 | id: "license" as PortId, 96 | dataType: "string", 97 | title: "License", 98 | description: "Contents of the license block of the model", 99 | }, 100 | { 101 | id: "modelfile" as PortId, 102 | dataType: "string", 103 | title: "Modelfile", 104 | description: "The Ollama modelfile for the model", 105 | }, 106 | { 107 | id: "parameters" as PortId, 108 | dataType: "string", 109 | title: "Parameters", 110 | description: "The parameters for the model", 111 | }, 112 | { 113 | id: "template" as PortId, 114 | dataType: "string", 115 | title: "Template", 116 | description: "The template for the model", 117 | }, 118 | ]; 119 | }, 120 | 121 | getEditors() { 122 | return [ 123 | { 124 | type: "string", 125 | dataKey: "modelName", 126 | useInputToggleDataKey: "useModelNameInput", 127 | label: "Model Name", 128 | helperMessage: "The name of the model to get.", 129 | placeholder: "Model Name", 130 | }, 131 | { 132 | type: "group", 133 | label: "Advanced", 134 | editors: [ 135 | { 136 | type: "string", 137 | label: "Host", 138 | dataKey: "host", 139 | useInputToggleDataKey: "useHostInput", 140 | helperMessage: 141 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 142 | }, 143 | { 144 | type: "string", 145 | label: "API Key", 146 | dataKey: "apiKey", 147 | useInputToggleDataKey: "useApiKeyInput", 148 | helperMessage: 149 | "Optional API key for authentication with Ollama instances that require it. Will be sent as Authorization Bearer token.", 150 | }, 151 | { 152 | type: "keyValuePair", 153 | label: "Headers", 154 | dataKey: "headers", 155 | useInputToggleDataKey: "useHeadersInput", 156 | keyPlaceholder: "Header Name", 157 | valuePlaceholder: "Header Value", 158 | helperMessage: 159 | "Additional headers to send to the API.", 160 | }, 161 | ], 162 | }, 163 | ]; 164 | }, 165 | 166 | getBody(data) { 167 | return rivet.dedent` 168 | Model: ${ 169 | data.useModelNameInput ? "(From Input)" : data.modelName || "Unset!" 
170 | } 171 | `; 172 | }, 173 | 174 | getUIData(): NodeUIData { 175 | return { 176 | contextMenuTitle: "Get Ollama Model", 177 | group: "Ollama", 178 | infoBoxTitle: "Get Ollama Model Node", 179 | infoBoxBody: "Gets information about a model from Ollama.", 180 | }; 181 | }, 182 | 183 | async process(data, inputData, context) { 184 | const hostInput = rivet.getInputOrData(data, inputData, "host", "string"); 185 | const host = 186 | hostInput || 187 | context.getPluginConfig("host") || 188 | "http://localhost:11434"; 189 | 190 | if (!host.trim()) { 191 | throw new Error("No host set!"); 192 | } 193 | 194 | const apiKeyInput = rivet.getInputOrData( 195 | data, 196 | inputData, 197 | "apiKey", 198 | "string", 199 | ); 200 | const apiKey = apiKeyInput || context.getPluginConfig("apiKey"); 201 | 202 | const modelName = rivet.getInputOrData(data, inputData, "modelName"); 203 | 204 | const headers: Record<string, string> = { 205 | "Content-Type": "application/json", 206 | }; 207 | 208 | if (apiKey && apiKey.trim()) { 209 | headers["Authorization"] = `Bearer ${apiKey}`; 210 | } 211 | 212 | // Add headers from data or input 213 | let additionalHeaders: Record<string, string> = {}; 214 | if (data.useHeadersInput) { 215 | const headersInput = rivet.coerceTypeOptional( 216 | inputData["headers" as PortId], 217 | "object", 218 | ) as Record<string, string> | undefined; 219 | if (headersInput) { 220 | additionalHeaders = headersInput; 221 | } 222 | } else if (data.headers) { 223 | additionalHeaders = data.headers.reduce( 224 | (acc, { key, value }) => { 225 | acc[key] = value; 226 | return acc; 227 | }, 228 | {} as Record<string, string>, 229 | ); 230 | } 231 | 232 | Object.assign(headers, additionalHeaders); 233 | 234 | const response = await fetch(`${host}/api/show`, { 235 | method: "POST", 236 | headers, 237 | body: JSON.stringify({ 238 | name: modelName, 239 | }), 240 | }); 241 | 242 | if (response.status === 404) { 243 | return { 244 | ["license" as PortId]: { 245 | type: "control-flow-excluded", 246 | value: undefined, 247 | }, 248 | ["modelfile" as PortId]: { 249 | type: "control-flow-excluded", 250 | value: undefined, 251 | }, 252 | ["parameters" as PortId]: { 253 | type: "control-flow-excluded", 254 | value: undefined, 255 | }, 256 | ["template" as PortId]: { 257 | type: "control-flow-excluded", 258 | value: undefined, 259 | }, 260 | }; 261 | } 262 | 263 | if (!response.ok) { 264 | try { 265 | const body = await response.text(); 266 | throw new Error(`Error from Ollama: ${body}`); 267 | } catch (err) { 268 | throw new Error( 269 | `Error ${response.status} from Ollama: ${ 270 | rivet.getError(err).message 271 | }` 272 | ); 273 | } 274 | } 275 | 276 | const { license, modelfile, parameters, template } = 277 | (await response.json()) as { 278 | license: string; 279 | modelfile: string; 280 | parameters: string; 281 | template: string; 282 | }; 283 | 284 | return { 285 | ["license" as PortId]: { 286 | type: "string", 287 | value: license, 288 | }, 289 | ["modelfile" as PortId]: { 290 | type: "string", 291 | value: modelfile, 292 | }, 293 | ["parameters" as PortId]: { 294 | type: "string", 295 | value: parameters, 296 | }, 297 | ["template" as PortId]: { 298 | type: "string", 299 | value: template, 300 | }, 301 | }; 302 | }, 303 | }; 304 | 305 | return rivet.pluginNodeDefinition(impl, "Get Ollama Model"); 306 | }; 307 | -------------------------------------------------------------------------------- /src/nodes/ListOllamaModelsNode.ts: -------------------------------------------------------------------------------- 1 | import type { 2 | ChartNode, 3 | NodeId, 4 
| NodeInputDefinition, 5 | NodeUIData, 6 | PluginNodeImpl, 7 | PortId, 8 | Rivet, 9 | } from "@ironclad/rivet-core"; 10 | 11 | export type ListOllamaModelsNode = ChartNode< 12 | "listOllamaModels", 13 | { 14 | host?: string; 15 | useHostInput?: boolean; 16 | 17 | apiKey?: string; 18 | useApiKeyInput?: boolean; 19 | 20 | headers?: { key: string; value: string }[]; 21 | useHeadersInput?: boolean; 22 | } 23 | >; 24 | 25 | export const listOllamaModels = (rivet: typeof Rivet) => { 26 | const impl: PluginNodeImpl = { 27 | create() { 28 | return { 29 | id: rivet.newId(), 30 | data: {}, 31 | title: "List Ollama Models", 32 | type: "listOllamaModels", 33 | visualData: { 34 | x: 0, 35 | y: 0, 36 | width: 300, 37 | }, 38 | } satisfies ListOllamaModelsNode; 39 | }, 40 | 41 | getInputDefinitions(data) { 42 | const inputs: NodeInputDefinition[] = []; 43 | 44 | if (data.useHostInput) { 45 | inputs.push({ 46 | dataType: "string", 47 | id: "host" as PortId, 48 | title: "Host", 49 | description: 50 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 51 | }); 52 | } 53 | 54 | if (data.useApiKeyInput) { 55 | inputs.push({ 56 | dataType: "string", 57 | id: "apiKey" as PortId, 58 | title: "API Key", 59 | description: 60 | "Optional API key for authentication with Ollama instances that require it.", 61 | }); 62 | } 63 | 64 | if (data.useHeadersInput) { 65 | inputs.push({ 66 | dataType: 'object', 67 | id: 'headers' as PortId, 68 | title: 'Headers', 69 | description: 'Additional headers to send to the API.', 70 | }); 71 | } 72 | 73 | return inputs; 74 | }, 75 | 76 | getOutputDefinitions() { 77 | return [ 78 | { 79 | id: "modelNames" as PortId, 80 | dataType: "string[]", 81 | title: "Model Names", 82 | }, 83 | ]; 84 | }, 85 | 86 | getEditors() { 87 | return [ 88 | { 89 | type: "group", 90 | label: "Advanced", 91 | editors: [ 92 | { 93 | type: "string", 94 | label: "Host", 95 | dataKey: "host", 96 | useInputToggleDataKey: "useHostInput", 97 | helperMessage: 98 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 99 | }, 100 | { 101 | type: "string", 102 | label: "API Key", 103 | dataKey: "apiKey", 104 | useInputToggleDataKey: "useApiKeyInput", 105 | helperMessage: 106 | "Optional API key for authentication with Ollama instances that require it. 
Will be sent as Authorization Bearer token.", 107 | }, 108 | { 109 | type: "keyValuePair", 110 | label: "Headers", 111 | dataKey: "headers", 112 | useInputToggleDataKey: "useHeadersInput", 113 | keyPlaceholder: "Header Name", 114 | valuePlaceholder: "Header Value", 115 | helperMessage: 116 | "Additional headers to send to the API.", 117 | }, 118 | ], 119 | }, 120 | ]; 121 | }, 122 | 123 | getBody() { 124 | return ""; 125 | }, 126 | 127 | getUIData(): NodeUIData { 128 | return { 129 | contextMenuTitle: "List Ollama Models", 130 | group: "Ollama", 131 | infoBoxTitle: "List Ollama Models Node", 132 | infoBoxBody: "Lists all models that are available in Ollama.", 133 | }; 134 | }, 135 | 136 | async process(data, inputData, context) { 137 | const hostInput = rivet.getInputOrData(data, inputData, "host", "string"); 138 | const host = 139 | hostInput || 140 | context.getPluginConfig("host") || 141 | "http://localhost:11434"; 142 | 143 | if (!host.trim()) { 144 | throw new Error("No host set!"); 145 | } 146 | 147 | const apiKeyInput = rivet.getInputOrData( 148 | data, 149 | inputData, 150 | "apiKey", 151 | "string", 152 | ); 153 | const apiKey = apiKeyInput || context.getPluginConfig("apiKey"); 154 | 155 | const headers: Record = {}; 156 | 157 | if (apiKey && apiKey.trim()) { 158 | headers["Authorization"] = `Bearer ${apiKey}`; 159 | } 160 | 161 | // Add headers from data or input 162 | let additionalHeaders: Record = {}; 163 | if (data.useHeadersInput) { 164 | const headersInput = rivet.coerceTypeOptional( 165 | inputData["headers" as PortId], 166 | "object", 167 | ) as Record | undefined; 168 | if (headersInput) { 169 | additionalHeaders = headersInput; 170 | } 171 | } else if (data.headers) { 172 | additionalHeaders = data.headers.reduce( 173 | (acc, { key, value }) => { 174 | acc[key] = value; 175 | return acc; 176 | }, 177 | {} as Record, 178 | ); 179 | } 180 | 181 | Object.assign(headers, additionalHeaders); 182 | 183 | const response = await fetch(`${host}/api/tags`, { 184 | method: "GET", 185 | headers, 186 | }); 187 | 188 | if (!response.ok) { 189 | try { 190 | const body = await response.text(); 191 | throw new Error(`Error from Ollama: ${body}`); 192 | } catch (err) { 193 | throw new Error( 194 | `Error ${response.status} from Ollama: ${ 195 | rivet.getError(err).message 196 | }` 197 | ); 198 | } 199 | } 200 | 201 | const { models } = (await response.json()) as { 202 | models: { 203 | name: string; 204 | modified_at: string; 205 | size: number; 206 | }[]; 207 | }; 208 | 209 | return { 210 | ["modelNames" as PortId]: { 211 | type: "string[]", 212 | value: models.map((model) => model.name), 213 | }, 214 | }; 215 | }, 216 | }; 217 | 218 | return rivet.pluginNodeDefinition(impl, "List Ollama Models"); 219 | }; 220 | -------------------------------------------------------------------------------- /src/nodes/OllamaChatNode.ts: -------------------------------------------------------------------------------- 1 | import type { 2 | ChartNode, 3 | ChatMessage, 4 | EditorDefinition, 5 | NodeId, 6 | NodeInputDefinition, 7 | NodeOutputDefinition, 8 | NodeUIData, 9 | Outputs, 10 | PluginNodeImpl, 11 | PortId, 12 | Rivet, 13 | } from "@ironclad/rivet-core"; 14 | 15 | export type OllamaChatNodeData = { 16 | model: string; 17 | useModelInput?: boolean; 18 | 19 | jsonMode: boolean; 20 | 21 | advancedOutputs: boolean; 22 | 23 | numPredict?: number; 24 | useNumPredictInput?: boolean; 25 | 26 | temperature?: number; 27 | useTemperatureInput?: boolean; 28 | 29 | // PARAMETERS 30 | 31 | mirostat?: number; 32 | 
useMirostatInput?: boolean; 33 | 34 | mirostatEta?: number; 35 | useMirostatEtaInput?: boolean; 36 | 37 | mirostatTau?: number; 38 | useMirostatTauInput?: boolean; 39 | 40 | numCtx?: number; 41 | useNumCtxInput?: boolean; 42 | 43 | numGqa?: number; 44 | useNumGqaInput?: boolean; 45 | 46 | numGpu?: number; 47 | useNumGpuInput?: boolean; 48 | 49 | numThread?: number; 50 | useNumThreadInput?: boolean; 51 | 52 | repeatLastN?: number; 53 | useRepeatLastNInput?: boolean; 54 | 55 | repeatPenalty?: number; 56 | useRepeatPenaltyInput?: boolean; 57 | 58 | seed?: number; 59 | useSeedInput?: boolean; 60 | 61 | stop: string; 62 | useStopInput?: boolean; 63 | 64 | tfsZ?: number; 65 | useTfsZInput?: boolean; 66 | 67 | topK?: number; 68 | useTopKInput?: boolean; 69 | 70 | topP?: number; 71 | useTopPInput?: boolean; 72 | 73 | additionalParameters?: { key: string; value: string }[]; 74 | useAdditionalParametersInput?: boolean; 75 | 76 | host?: string; 77 | useHostInput?: boolean; 78 | 79 | apiKey?: string; 80 | useApiKeyInput?: boolean; 81 | 82 | headers?: { key: string; value: string }[]; 83 | useHeadersInput?: boolean; 84 | }; 85 | 86 | export type OllamaChatNode = ChartNode<"ollamaChat2", OllamaChatNodeData>; 87 | 88 | type OllamaStreamingContentResponse = { 89 | model: string; 90 | created_at: string; 91 | done: false; 92 | message: { 93 | role: string; 94 | content: string; 95 | }; 96 | }; 97 | 98 | type OllamaStreamingFinalResponse = { 99 | model: string; 100 | created_at: string; 101 | message: { 102 | role: string; 103 | content: string; 104 | }; 105 | done: true; 106 | total_duration: number; 107 | load_duration: number; 108 | prompt_eval_count: number; 109 | prompt_eval_duration: number; 110 | eval_count: number; 111 | eval_duration: number; 112 | }; 113 | 114 | type OllamaStreamingGenerateResponse = 115 | | OllamaStreamingContentResponse 116 | | OllamaStreamingFinalResponse; 117 | 118 | export const ollamaChat2 = (rivet: typeof Rivet) => { 119 | const impl: PluginNodeImpl = { 120 | create(): OllamaChatNode { 121 | const node: OllamaChatNode = { 122 | id: rivet.newId(), 123 | data: { 124 | model: "", 125 | useModelInput: false, 126 | numPredict: 1024, 127 | jsonMode: false, 128 | advancedOutputs: false, 129 | stop: "", 130 | }, 131 | title: "Ollama Chat", 132 | type: "ollamaChat2", 133 | visualData: { 134 | x: 0, 135 | y: 0, 136 | width: 250, 137 | }, 138 | }; 139 | return node; 140 | }, 141 | 142 | getInputDefinitions(data): NodeInputDefinition[] { 143 | const inputs: NodeInputDefinition[] = []; 144 | 145 | inputs.push({ 146 | id: "system-prompt" as PortId, 147 | dataType: "string", 148 | title: "System Prompt", 149 | description: "The system prompt to prepend to the messages list.", 150 | required: false, 151 | coerced: true, 152 | }); 153 | 154 | inputs.push({ 155 | id: "messages" as PortId, 156 | dataType: ["chat-message[]", "chat-message"], 157 | title: "Messages", 158 | description: "The chat messages to use as the prompt.", 159 | }); 160 | 161 | if (data.useModelInput) { 162 | inputs.push({ 163 | id: "model" as PortId, 164 | dataType: "string", 165 | title: "Model", 166 | }); 167 | } 168 | 169 | if (data.useMirostatInput) { 170 | inputs.push({ 171 | id: "mirostat" as PortId, 172 | dataType: "number", 173 | title: "Mirostat", 174 | description: 'The "mirostat" parameter.', 175 | }); 176 | } 177 | 178 | if (data.useMirostatEtaInput) { 179 | inputs.push({ 180 | id: "mirostatEta" as PortId, 181 | dataType: "number", 182 | title: "Mirostat Eta", 183 | description: 'The "mirostat_eta" 
parameter.', 184 | }); 185 | } 186 | 187 | if (data.useMirostatTauInput) { 188 | inputs.push({ 189 | id: "mirostatTau" as PortId, 190 | dataType: "number", 191 | title: "Mirostat Tau", 192 | description: 'The "mirostat_tau" parameter.', 193 | }); 194 | } 195 | 196 | if (data.useNumCtxInput) { 197 | inputs.push({ 198 | id: "numCtx" as PortId, 199 | dataType: "number", 200 | title: "Num Ctx", 201 | description: 'The "num_ctx" parameter.', 202 | }); 203 | } 204 | 205 | if (data.useNumGqaInput) { 206 | inputs.push({ 207 | id: "numGqa" as PortId, 208 | dataType: "number", 209 | title: "Num GQA", 210 | description: 'The "num_gqa" parameter.', 211 | }); 212 | } 213 | 214 | if (data.useNumGpuInput) { 215 | inputs.push({ 216 | id: "numGpu" as PortId, 217 | dataType: "number", 218 | title: "Num GPUs", 219 | description: 'The "num_gpu" parameter.', 220 | }); 221 | } 222 | 223 | if (data.useNumThreadInput) { 224 | inputs.push({ 225 | id: "numThread" as PortId, 226 | dataType: "number", 227 | title: "Num Threads", 228 | description: 'The "num_thread" parameter.', 229 | }); 230 | } 231 | 232 | if (data.useRepeatLastNInput) { 233 | inputs.push({ 234 | id: "repeatLastN" as PortId, 235 | dataType: "number", 236 | title: "Repeat Last N", 237 | description: 'The "repeat_last_n" parameter.', 238 | }); 239 | } 240 | 241 | if (data.useRepeatPenaltyInput) { 242 | inputs.push({ 243 | id: "repeatPenalty" as PortId, 244 | dataType: "number", 245 | title: "Repeat Penalty", 246 | description: 'The "repeat_penalty" parameter.', 247 | }); 248 | } 249 | 250 | if (data.useTemperatureInput) { 251 | inputs.push({ 252 | id: "temperature" as PortId, 253 | dataType: "number", 254 | title: "Temperature", 255 | description: 'The "temperature" parameter.', 256 | }); 257 | } 258 | 259 | if (data.useSeedInput) { 260 | inputs.push({ 261 | id: "seed" as PortId, 262 | dataType: "number", 263 | title: "Seed", 264 | description: 'The "seed" parameter.', 265 | }); 266 | } 267 | 268 | if (data.useStopInput) { 269 | inputs.push({ 270 | id: "stop" as PortId, 271 | dataType: "string[]", 272 | title: "Stop", 273 | description: 'The "stop" parameter.', 274 | }); 275 | } 276 | 277 | if (data.useTfsZInput) { 278 | inputs.push({ 279 | id: "tfsZ" as PortId, 280 | dataType: "number", 281 | title: "TFS Z", 282 | description: 'The "tfs_z" parameter.', 283 | }); 284 | } 285 | 286 | if (data.useNumPredictInput) { 287 | inputs.push({ 288 | id: "numPredict" as PortId, 289 | dataType: "number", 290 | title: "Num Predict", 291 | description: 'The "num_predict" parameter.', 292 | }); 293 | } 294 | 295 | if (data.useTopKInput) { 296 | inputs.push({ 297 | id: "topK" as PortId, 298 | dataType: "number", 299 | title: "Top K", 300 | description: 'The "top_k" parameter.', 301 | }); 302 | } 303 | 304 | if (data.useTopPInput) { 305 | inputs.push({ 306 | id: "topP" as PortId, 307 | dataType: "number", 308 | title: "Top P", 309 | description: 'The "top_p" parameter.', 310 | }); 311 | } 312 | 313 | if (data.useHostInput) { 314 | inputs.push({ 315 | dataType: "string", 316 | id: "host" as PortId, 317 | title: "Host", 318 | description: 319 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. 
Leave blank for the default: http://localhost:11434", 320 | }); 321 | } 322 | 323 | if (data.useApiKeyInput) { 324 | inputs.push({ 325 | dataType: "string", 326 | id: "apiKey" as PortId, 327 | title: "API Key", 328 | description: 329 | "Optional API key for authentication with Ollama instances that require it.", 330 | }); 331 | } 332 | 333 | if (data.useAdditionalParametersInput) { 334 | inputs.push({ 335 | id: "additionalParameters" as PortId, 336 | dataType: "object", 337 | title: "Additional Parameters", 338 | description: "Additional parameters to pass to Ollama.", 339 | }); 340 | } 341 | 342 | if (data.useHeadersInput) { 343 | inputs.push({ 344 | dataType: 'object', 345 | id: 'headers' as PortId, 346 | title: 'Headers', 347 | description: 'Additional headers to send to the API.', 348 | }); 349 | } 350 | 351 | return inputs; 352 | }, 353 | 354 | getOutputDefinitions(data): NodeOutputDefinition[] { 355 | let outputs: NodeOutputDefinition[] = [ 356 | { 357 | id: "output" as PortId, 358 | dataType: "string", 359 | title: "Output", 360 | description: "The output from Ollama.", 361 | }, 362 | { 363 | id: "messages-sent" as PortId, 364 | dataType: "chat-message[]", 365 | title: "Messages Sent", 366 | description: 367 | "The messages sent to Ollama, including the system prompt.", 368 | }, 369 | { 370 | id: "all-messages" as PortId, 371 | dataType: "chat-message[]", 372 | title: "All Messages", 373 | description: "All messages, including the reply from Ollama.", 374 | }, 375 | ]; 376 | 377 | return outputs; 378 | }, 379 | 380 | getEditors(): EditorDefinition[] { 381 | return [ 382 | { 383 | type: "string", 384 | dataKey: "model", 385 | label: "Model", 386 | useInputToggleDataKey: "useModelInput", 387 | helperMessage: "The LLM model to use in Ollama.", 388 | }, 389 | { 390 | type: "toggle", 391 | dataKey: "jsonMode", 392 | label: "JSON mode", 393 | helperMessage: 394 | "Activates Ollama's JSON mode. Make sure to also instruct the model to return JSON.", 395 | }, 396 | { 397 | type: "number", 398 | dataKey: "numPredict", 399 | useInputToggleDataKey: "useNumPredictInput", 400 | label: "maxTokens (num Predict)", 401 | helperMessage: 402 | "The maximum number of tokens to generate in the chat completion.", 403 | allowEmpty: false, 404 | defaultValue: 1024, 405 | }, 406 | { 407 | type: "number", 408 | dataKey: "temperature", 409 | useInputToggleDataKey: "useTemperatureInput", 410 | label: "Temperature", 411 | helperMessage: 412 | "The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)", 413 | allowEmpty: true, 414 | }, 415 | { 416 | type: "group", 417 | label: "Parameters", 418 | editors: [ 419 | { 420 | type: "number", 421 | dataKey: "mirostat", 422 | useInputToggleDataKey: "useMirostatInput", 423 | label: "Mirostat", 424 | helperMessage: 425 | "Enable Mirostat sampling for controlling perplexity. (Default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", 426 | min: 0, 427 | max: 2, 428 | step: 1, 429 | allowEmpty: true, 430 | }, 431 | { 432 | type: "number", 433 | dataKey: "mirostatEta", 434 | useInputToggleDataKey: "useMirostatEtaInput", 435 | label: "Mirostat Eta", 436 | helperMessage: 437 | "Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. 
(Default: 0.1)", 438 | allowEmpty: true, 439 | }, 440 | { 441 | type: "number", 442 | dataKey: "mirostatTau", 443 | useInputToggleDataKey: "useMirostatTauInput", 444 | label: "Mirostat Tau", 445 | helperMessage: 446 | "Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0)", 447 | allowEmpty: true, 448 | }, 449 | { 450 | type: "number", 451 | dataKey: "numCtx", 452 | useInputToggleDataKey: "useNumCtxInput", 453 | label: "Num Ctx", 454 | helperMessage: 455 | "Sets the size of the context window used to generate the next token. (Default: 2048)", 456 | 457 | allowEmpty: true, 458 | }, 459 | { 460 | type: "number", 461 | dataKey: "numGqa", 462 | useInputToggleDataKey: "useNumGqaInput", 463 | label: "Num GQA", 464 | helperMessage: 465 | "The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b", 466 | allowEmpty: true, 467 | }, 468 | { 469 | type: "number", 470 | dataKey: "numGpu", 471 | useInputToggleDataKey: "useNumGpuInput", 472 | label: "Num GPUs", 473 | helperMessage: 474 | "The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable.", 475 | allowEmpty: true, 476 | }, 477 | { 478 | type: "number", 479 | dataKey: "numThread", 480 | useInputToggleDataKey: "useNumThreadInput", 481 | label: "Num Threads", 482 | helperMessage: 483 | "Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).", 484 | allowEmpty: true, 485 | }, 486 | { 487 | type: "number", 488 | dataKey: "repeatLastN", 489 | useInputToggleDataKey: "useRepeatLastNInput", 490 | label: "Repeat Last N", 491 | helperMessage: 492 | "Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)", 493 | allowEmpty: true, 494 | }, 495 | { 496 | type: "number", 497 | dataKey: "repeatPenalty", 498 | useInputToggleDataKey: "useRepeatPenaltyInput", 499 | label: "Repeat Penalty", 500 | helperMessage: 501 | "Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)", 502 | allowEmpty: true, 503 | }, 504 | { 505 | type: "number", 506 | dataKey: "seed", 507 | useInputToggleDataKey: "useSeedInput", 508 | label: "Seed", 509 | helperMessage: 510 | "Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0)", 511 | allowEmpty: true, 512 | }, 513 | { 514 | type: "string", 515 | dataKey: "stop", 516 | useInputToggleDataKey: "useStopInput", 517 | label: "Stop", 518 | helperMessage: 519 | "Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return.", 520 | }, 521 | { 522 | type: "number", 523 | dataKey: "tfsZ", 524 | useInputToggleDataKey: "useTfsZInput", 525 | label: "TFS Z", 526 | helperMessage: 527 | "Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. 
(default: 1)", 528 | allowEmpty: true, 529 | }, 530 | { 531 | type: "number", 532 | dataKey: "numPredict", 533 | useInputToggleDataKey: "useNumPredictInput", 534 | label: "Num Predict", 535 | helperMessage: 536 | "Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)", 537 | allowEmpty: true, 538 | }, 539 | { 540 | type: "number", 541 | dataKey: "topK", 542 | useInputToggleDataKey: "useTopKInput", 543 | label: "Top K", 544 | helperMessage: 545 | "Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)", 546 | allowEmpty: true, 547 | }, 548 | { 549 | type: "number", 550 | dataKey: "topP", 551 | useInputToggleDataKey: "useTopPInput", 552 | label: "Top P", 553 | helperMessage: 554 | "Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)", 555 | allowEmpty: true, 556 | }, 557 | { 558 | type: "keyValuePair", 559 | dataKey: "additionalParameters", 560 | useInputToggleDataKey: "useAdditionalParametersInput", 561 | label: "Additional Parameters", 562 | keyPlaceholder: "Parameter", 563 | valuePlaceholder: "Value", 564 | helperMessage: 565 | "Additional parameters to pass to Ollama. Numbers will be parsed and sent as numbers, otherwise they will be sent as strings.", 566 | }, 567 | ], 568 | }, 569 | { 570 | type: "group", 571 | label: "Advanced", 572 | editors: [ 573 | { 574 | type: "string", 575 | label: "Host", 576 | dataKey: "host", 577 | useInputToggleDataKey: "useHostInput", 578 | helperMessage: 579 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 580 | }, 581 | { 582 | type: "string", 583 | label: "API Key", 584 | dataKey: "apiKey", 585 | useInputToggleDataKey: "useApiKeyInput", 586 | helperMessage: 587 | "Optional API key for authentication with Ollama instances that require it. Will be sent as Authorization Bearer token.", 588 | }, 589 | { 590 | type: "keyValuePair", 591 | label: "Headers", 592 | dataKey: "headers", 593 | useInputToggleDataKey: "useHeadersInput", 594 | keyPlaceholder: "Header Name", 595 | valuePlaceholder: "Header Value", 596 | helperMessage: 597 | "Additional headers to send to the API.", 598 | }, 599 | ], 600 | }, 601 | ]; 602 | }, 603 | 604 | getBody(data) { 605 | return rivet.dedent` 606 | Model: ${data.useModelInput ? 
"(From Input)" : data.model || "Unset!"} 607 | Max tokens: ${data.numPredict || 1024} 608 | JSON Mode: ${data.jsonMode || false} 609 | `; 610 | }, 611 | 612 | getUIData(): NodeUIData { 613 | return { 614 | contextMenuTitle: "Ollama Chat", 615 | group: "Ollama", 616 | infoBoxBody: "This is an Ollama Chat node using /api/chat.", 617 | infoBoxTitle: "Ollama Chat Node", 618 | }; 619 | }, 620 | 621 | async process(data, inputData, context) { 622 | let outputs: Outputs = {}; 623 | 624 | const hostInput = rivet.getInputOrData(data, inputData, "host", "string"); 625 | const host = 626 | hostInput || 627 | context.getPluginConfig("host") || 628 | "http://localhost:11434"; 629 | 630 | if (!host.trim()) { 631 | throw new Error("No host set!"); 632 | } 633 | 634 | const apiKeyInput = rivet.getInputOrData( 635 | data, 636 | inputData, 637 | "apiKey", 638 | "string", 639 | ); 640 | const apiKey = apiKeyInput || context.getPluginConfig("apiKey"); 641 | 642 | const model = rivet.getInputOrData(data, inputData, "model", "string"); 643 | if (!model) { 644 | throw new Error("No model set!"); 645 | } 646 | 647 | const systemPrompt = rivet.coerceTypeOptional( 648 | inputData["system-prompt" as PortId], 649 | "string", 650 | ); 651 | 652 | const chatMessages = 653 | rivet.coerceTypeOptional( 654 | inputData["messages" as PortId], 655 | "chat-message[]", 656 | ) ?? []; 657 | const allMessages: ChatMessage[] = systemPrompt 658 | ? [{ type: "system", message: systemPrompt }, ...chatMessages] 659 | : chatMessages; 660 | 661 | const inputMessages: InputMessage[] = allMessages.map((message) => { 662 | if (typeof message.message === "string") { 663 | return { type: message.type, message: message.message }; 664 | } else { 665 | return { 666 | type: message.type, 667 | message: JSON.stringify(message.message), 668 | }; 669 | } 670 | }); 671 | 672 | let additionalParameters: Record = ( 673 | data.additionalParameters ?? [] 674 | ).reduce( 675 | (acc, { key, value }) => { 676 | const parsedValue = Number(value); 677 | acc[key] = isNaN(parsedValue) ? value : parsedValue; 678 | return acc; 679 | }, 680 | {} as Record, 681 | ); 682 | 683 | if (data.useAdditionalParametersInput) { 684 | additionalParameters = (rivet.coerceTypeOptional( 685 | inputData["additionalParameters" as PortId], 686 | "object", 687 | ) ?? {}) as Record; 688 | } 689 | 690 | let stop: string[] | undefined = undefined; 691 | if (data.useStopInput) { 692 | stop = rivet.coerceTypeOptional( 693 | inputData["stop" as PortId], 694 | "string[]", 695 | ); 696 | } else { 697 | stop = data.stop ? 
[data.stop] : undefined; 698 | } 699 | 700 | const openAiMessages = formatChatMessages(inputMessages); 701 | 702 | const parameters = { 703 | mirostat: rivet.getInputOrData(data, inputData, "mirostat", "number"), 704 | mirostat_eta: rivet.getInputOrData( 705 | data, 706 | inputData, 707 | "mirostatEta", 708 | "number", 709 | ), 710 | mirostat_tau: rivet.getInputOrData( 711 | data, 712 | inputData, 713 | "mirostatTau", 714 | "number", 715 | ), 716 | num_ctx: rivet.getInputOrData(data, inputData, "numCtx", "number"), 717 | num_gqa: rivet.getInputOrData(data, inputData, "numGqa", "number"), 718 | num_gpu: rivet.getInputOrData(data, inputData, "numGpu", "number"), 719 | num_thread: rivet.getInputOrData( 720 | data, 721 | inputData, 722 | "numThread", 723 | "number", 724 | ), 725 | repeat_last_n: rivet.getInputOrData( 726 | data, 727 | inputData, 728 | "repeatLastN", 729 | "number", 730 | ), 731 | repeat_penalty: rivet.getInputOrData( 732 | data, 733 | inputData, 734 | "repeatPenalty", 735 | "number", 736 | ), 737 | temperature: rivet.getInputOrData( 738 | data, 739 | inputData, 740 | "temperature", 741 | "number", 742 | ), 743 | seed: rivet.getInputOrData(data, inputData, "seed", "number"), 744 | stop, 745 | tfs_z: rivet.getInputOrData(data, inputData, "tfsZ", "number"), 746 | num_predict: rivet.getInputOrData( 747 | data, 748 | inputData, 749 | "numPredict", 750 | "number", 751 | ), 752 | top_k: rivet.getInputOrData(data, inputData, "topK", "number"), 753 | top_p: rivet.getInputOrData(data, inputData, "topP", "number"), 754 | ...additionalParameters, 755 | }; 756 | 757 | let apiResponse: Response; 758 | 759 | type RequestBodyType = { 760 | model: string; 761 | messages: OutputMessage[]; 762 | format?: string; 763 | options: any; 764 | stream: boolean; 765 | }; 766 | 767 | const requestBody: RequestBodyType = { 768 | model, 769 | messages: openAiMessages, 770 | stream: true, 771 | options: parameters, 772 | }; 773 | 774 | if (data.jsonMode === true) { 775 | requestBody.format = "json"; 776 | } 777 | 778 | try { 779 | const headers: Record = { 780 | "Content-Type": "application/json", 781 | }; 782 | 783 | if (apiKey && apiKey.trim()) { 784 | headers["Authorization"] = `Bearer ${apiKey}`; 785 | } 786 | 787 | // Add headers from data or input 788 | let additionalHeaders: Record = {}; 789 | if (data.useHeadersInput) { 790 | const headersInput = rivet.coerceTypeOptional( 791 | inputData["headers" as PortId], 792 | "object", 793 | ) as Record | undefined; 794 | if (headersInput) { 795 | additionalHeaders = headersInput; 796 | } 797 | } else if (data.headers) { 798 | additionalHeaders = data.headers.reduce( 799 | (acc, { key, value }) => { 800 | acc[key] = value; 801 | return acc; 802 | }, 803 | {} as Record, 804 | ); 805 | } 806 | 807 | Object.assign(headers, additionalHeaders); 808 | 809 | apiResponse = await fetch(`${host}/api/chat`, { 810 | method: "POST", 811 | headers, 812 | body: JSON.stringify(requestBody), 813 | }); 814 | } catch (err) { 815 | throw new Error(`Error from Ollama: ${rivet.getError(err).message}`); 816 | } 817 | 818 | if (!apiResponse.ok) { 819 | try { 820 | const error = await apiResponse.json(); 821 | throw new Error(`Error from Ollama: ${error.message}`); 822 | } catch (err) { 823 | throw new Error(`Error from Ollama: ${apiResponse.statusText}`); 824 | } 825 | } 826 | 827 | const reader = apiResponse.body?.getReader(); 828 | 829 | if (!reader) { 830 | throw new Error("No response body!"); 831 | } 832 | 833 | let streamingResponseText = ""; 834 | let llmResponseText = ""; 835 
| 836 | let finalResponse: OllamaStreamingFinalResponse | undefined; 837 | 838 | while (true) { 839 | const { value, done } = await reader.read(); 840 | if (done) { 841 | break; 842 | } 843 | 844 | if (value) { 845 | const chunk = new TextDecoder().decode(value); 846 | 847 | streamingResponseText += chunk; 848 | 849 | const lines = streamingResponseText.split("\n"); 850 | streamingResponseText = lines.pop() ?? ""; 851 | 852 | for (const line of lines) { 853 | try { 854 | const json = JSON.parse(line) as OllamaStreamingGenerateResponse; 855 | 856 | if (!("done" in json)) { 857 | throw new Error(`Invalid response from Ollama: ${line}`); 858 | } 859 | 860 | if (!json.done) { 861 | if (llmResponseText === "") { 862 | llmResponseText += ( 863 | json.message.content as string 864 | ).trimStart(); 865 | } else { 866 | llmResponseText += json.message.content; 867 | } 868 | } else { 869 | finalResponse = json; 870 | } 871 | } catch (err) { 872 | throw new Error( 873 | `Error parsing line from Ollama streaming response: ${line}`, 874 | ); 875 | } 876 | } 877 | 878 | outputs["output" as PortId] = { 879 | type: "string", 880 | value: llmResponseText, 881 | }; 882 | 883 | context.onPartialOutputs?.(outputs); 884 | } 885 | } 886 | 887 | if (!finalResponse) { 888 | throw new Error("No final response from Ollama!"); 889 | } 890 | 891 | outputs["messages-sent" as PortId] = { 892 | type: "chat-message[]", 893 | value: allMessages, 894 | }; 895 | 896 | outputs["all-messages" as PortId] = { 897 | type: "chat-message[]", 898 | value: [ 899 | ...allMessages, 900 | { 901 | type: "assistant", 902 | message: llmResponseText, 903 | function_call: undefined, 904 | }, 905 | ], 906 | }; 907 | 908 | return outputs; 909 | }, 910 | }; 911 | 912 | return rivet.pluginNodeDefinition(impl, "Ollama Chat"); 913 | }; 914 | 915 | type InputMessage = { 916 | type: string; 917 | message: string; 918 | }; 919 | 920 | type OutputMessage = { 921 | role: string; 922 | content: string; 923 | }; 924 | 925 | function formatChatMessages(messages: InputMessage[]): OutputMessage[] { 926 | return messages.map((message) => ({ 927 | role: message.type, 928 | content: message.message, 929 | })); 930 | } 931 | -------------------------------------------------------------------------------- /src/nodes/OllamaEmbeddingNode.ts: -------------------------------------------------------------------------------- 1 | import type { 2 | ChartNode, 3 | ChatMessage, 4 | ChatMessageMessagePart, 5 | EditorDefinition, 6 | NodeId, 7 | NodeInputDefinition, 8 | NodeOutputDefinition, 9 | NodeUIData, 10 | Outputs, 11 | PluginNodeImpl, 12 | PortId, 13 | Rivet, 14 | } from "@ironclad/rivet-core"; 15 | import { match } from "ts-pattern"; 16 | 17 | export type OllamaEmbeddingNodeData = { 18 | model: string; 19 | useModelInput?: boolean; 20 | embedding: number[]; 21 | text: string; 22 | useTextInput?: boolean; 23 | 24 | host?: string; 25 | useHostInput?: boolean; 26 | 27 | apiKey?: string; 28 | useApiKeyInput?: boolean; 29 | 30 | headers?: { key: string; value: string }[]; 31 | useHeadersInput?: boolean; 32 | }; 33 | 34 | export type OllamaEmbeddingNode = ChartNode< 35 | "ollamaEmbed", 36 | OllamaEmbeddingNodeData 37 | >; 38 | 39 | type OllamaEmbeddingResponse = { 40 | embedding: number[]; 41 | }; 42 | 43 | export const ollamaEmbed = (rivet: typeof Rivet) => { 44 | const impl: PluginNodeImpl = { 45 | create(): OllamaEmbeddingNode { 46 | const node: OllamaEmbeddingNode = { 47 | id: rivet.newId(), 48 | data: { 49 | model: "", 50 | useModelInput: false, 51 | embedding: [], 
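// Note: "embedding" above is declared in the node's saved data but is never written back at runtime; process() emits the computed vector on the "embedding" output port instead.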
52 | text: "", 53 | useTextInput: false, 54 | }, 55 | title: "Ollama Embedding", 56 | type: "ollamaEmbed", 57 | visualData: { 58 | x: 0, 59 | y: 0, 60 | width: 250, 61 | }, 62 | }; 63 | return node; 64 | }, 65 | 66 | getInputDefinitions(data): NodeInputDefinition[] { 67 | const inputs: NodeInputDefinition[] = []; 68 | 69 | if (data.useModelInput) { 70 | inputs.push({ 71 | id: "model" as PortId, 72 | dataType: "string", 73 | title: "Model", 74 | }); 75 | } 76 | 77 | if (data.useTextInput) { 78 | inputs.push({ 79 | id: "text" as PortId, 80 | dataType: "string", 81 | title: "Text", 82 | }); 83 | } 84 | 85 | if (data.useHostInput) { 86 | inputs.push({ 87 | dataType: "string", 88 | id: "host" as PortId, 89 | title: "Host", 90 | description: 91 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 92 | }); 93 | } 94 | 95 | if (data.useApiKeyInput) { 96 | inputs.push({ 97 | dataType: "string", 98 | id: "apiKey" as PortId, 99 | title: "API Key", 100 | description: 101 | "Optional API key for authentication with Ollama instances that require it.", 102 | }); 103 | } 104 | 105 | if (data.useHeadersInput) { 106 | inputs.push({ 107 | dataType: 'object', 108 | id: 'headers' as PortId, 109 | title: 'Headers', 110 | description: 'Additional headers to send to the API.', 111 | }); 112 | } 113 | 114 | return inputs; 115 | }, 116 | 117 | getOutputDefinitions(data): NodeOutputDefinition[] { 118 | let outputs: NodeOutputDefinition[] = [ 119 | { 120 | id: "embedding" as PortId, 121 | dataType: "vector", 122 | title: "Embedding", 123 | description: "The embedding output from Ollama.", 124 | }, 125 | ]; 126 | 127 | return outputs; 128 | }, 129 | 130 | getEditors(): EditorDefinition[] { 131 | return [ 132 | { 133 | type: "string", 134 | dataKey: "model", 135 | useInputToggleDataKey: "useModelInput", 136 | label: "Model", 137 | }, 138 | { 139 | type: "string", 140 | dataKey: "text", 141 | useInputToggleDataKey: "useTextInput", 142 | label: "Text", 143 | }, 144 | { 145 | type: "group", 146 | label: "Advanced", 147 | editors: [ 148 | { 149 | type: "string", 150 | label: "Host", 151 | dataKey: "host", 152 | useInputToggleDataKey: "useHostInput", 153 | helperMessage: 154 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 155 | }, 156 | { 157 | type: "string", 158 | label: "API Key", 159 | dataKey: "apiKey", 160 | useInputToggleDataKey: "useApiKeyInput", 161 | helperMessage: 162 | "Optional API key for authentication with Ollama instances that require it. Will be sent as Authorization Bearer token.", 163 | }, 164 | { 165 | type: "keyValuePair", 166 | label: "Headers", 167 | dataKey: "headers", 168 | useInputToggleDataKey: "useHeadersInput", 169 | keyPlaceholder: "Header Name", 170 | valuePlaceholder: "Header Value", 171 | helperMessage: 172 | "Additional headers to send to the API.", 173 | }, 174 | ], 175 | }, 176 | ]; 177 | }, 178 | 179 | getBody(data) { 180 | return rivet.dedent` 181 | Model: ${data.useModelInput ? "(From Input)" : data.model || "Unset!"} 182 | Text: ${data.useTextInput ? 
"(From Input)" : data.text || "Unset!"} 183 | `; 184 | }, 185 | 186 | getUIData(): NodeUIData { 187 | return { 188 | contextMenuTitle: "Ollama Embedding", 189 | group: "Ollama", 190 | infoBoxBody: "This is an Ollama Embedding node using /api/embeddings.", 191 | infoBoxTitle: "Ollama Embedding Node", 192 | }; 193 | }, 194 | 195 | async process(data, inputData, context) { 196 | let outputs: Outputs = {}; 197 | 198 | const hostInput = rivet.getInputOrData(data, inputData, "host", "string"); 199 | const host = 200 | hostInput || 201 | context.getPluginConfig("host") || 202 | "http://localhost:11434"; 203 | 204 | if (!host.trim()) { 205 | throw new Error("No host set!"); 206 | } 207 | 208 | const apiKeyInput = rivet.getInputOrData( 209 | data, 210 | inputData, 211 | "apiKey", 212 | "string", 213 | ); 214 | const apiKey = apiKeyInput || context.getPluginConfig("apiKey"); 215 | 216 | const model = rivet.getInputOrData(data, inputData, "model", "string"); 217 | if (!model) { 218 | throw new Error("No model set!"); 219 | } 220 | 221 | const prompt = rivet.getInputOrData(data, inputData, "text", "string"); 222 | let apiResponse: Response; 223 | 224 | type RequestBodyType = { 225 | model: string; 226 | prompt: string; 227 | }; 228 | 229 | const requestBody: RequestBodyType = { 230 | model, 231 | prompt, 232 | }; 233 | 234 | try { 235 | const headers: Record = { 236 | "Content-Type": "application/json", 237 | }; 238 | 239 | if (apiKey && apiKey.trim()) { 240 | headers["Authorization"] = `Bearer ${apiKey}`; 241 | } 242 | 243 | // Add headers from data or input 244 | let additionalHeaders: Record = {}; 245 | if (data.useHeadersInput) { 246 | const headersInput = rivet.coerceTypeOptional( 247 | inputData["headers" as PortId], 248 | "object", 249 | ) as Record | undefined; 250 | if (headersInput) { 251 | additionalHeaders = headersInput; 252 | } 253 | } else if (data.headers) { 254 | additionalHeaders = data.headers.reduce( 255 | (acc, { key, value }) => { 256 | acc[key] = value; 257 | return acc; 258 | }, 259 | {} as Record, 260 | ); 261 | } 262 | 263 | Object.assign(headers, additionalHeaders); 264 | 265 | apiResponse = await fetch(`${host}/api/embeddings`, { 266 | method: "POST", 267 | headers, 268 | body: JSON.stringify(requestBody), 269 | }); 270 | } catch (err) { 271 | throw new Error( 272 | `Error from Ollama {POST}: ${rivet.getError(err).message}`, 273 | ); 274 | } 275 | 276 | if (!apiResponse.ok) { 277 | try { 278 | const error = await apiResponse.json(); 279 | throw new Error(`Error from Ollama {JSON}: ${error.message}`); 280 | } catch (err) { 281 | throw new Error(`Error from Ollama {RAW}: ${apiResponse.statusText}`); 282 | } 283 | } 284 | 285 | const reader = apiResponse.body?.getReader(); 286 | 287 | if (!reader) { 288 | throw new Error("No response body!"); 289 | } 290 | 291 | let streamingResponseText = ""; 292 | let llmResponseText = ""; 293 | const { value, done } = await reader.read(); 294 | const line = new TextDecoder().decode(value); 295 | const response = JSON.parse(line) as OllamaEmbeddingResponse; 296 | 297 | outputs["embedding" as PortId] = { 298 | type: "vector", 299 | value: response.embedding, 300 | }; 301 | 302 | return outputs; 303 | }, 304 | }; 305 | 306 | return rivet.pluginNodeDefinition(impl, "Ollama Embedding"); 307 | }; 308 | -------------------------------------------------------------------------------- /src/nodes/OllamaGenerateNode.ts: -------------------------------------------------------------------------------- 1 | import type { 2 | ChartNode, 3 | ChatMessage, 4 
| ChatMessageMessagePart, 5 | EditorDefinition, 6 | NodeId, 7 | NodeInputDefinition, 8 | NodeOutputDefinition, 9 | NodeUIData, 10 | Outputs, 11 | PluginNodeImpl, 12 | PortId, 13 | Rivet, 14 | } from "@ironclad/rivet-core"; 15 | import { match } from "ts-pattern"; 16 | 17 | export type OllamaGenerateNodeData = { 18 | model: string; 19 | useModelInput?: boolean; 20 | 21 | promptFormat: string; 22 | 23 | jsonMode: boolean; 24 | 25 | outputFormat: string; 26 | 27 | advancedOutputs: boolean; 28 | 29 | // PARAMETERS 30 | 31 | mirostat?: number; 32 | useMirostatInput?: boolean; 33 | 34 | mirostatEta?: number; 35 | useMirostatEtaInput?: boolean; 36 | 37 | mirostatTau?: number; 38 | useMirostatTauInput?: boolean; 39 | 40 | numCtx?: number; 41 | useNumCtxInput?: boolean; 42 | 43 | numGqa?: number; 44 | useNumGqaInput?: boolean; 45 | 46 | numGpu?: number; 47 | useNumGpuInput?: boolean; 48 | 49 | numThread?: number; 50 | useNumThreadInput?: boolean; 51 | 52 | repeatLastN?: number; 53 | useRepeatLastNInput?: boolean; 54 | 55 | repeatPenalty?: number; 56 | useRepeatPenaltyInput?: boolean; 57 | 58 | temperature?: number; 59 | useTemperatureInput?: boolean; 60 | 61 | seed?: number; 62 | useSeedInput?: boolean; 63 | 64 | stop: string; 65 | useStopInput?: boolean; 66 | 67 | tfsZ?: number; 68 | useTfsZInput?: boolean; 69 | 70 | numPredict?: number; 71 | useNumPredictInput?: boolean; 72 | 73 | topK?: number; 74 | useTopKInput?: boolean; 75 | 76 | topP?: number; 77 | useTopPInput?: boolean; 78 | 79 | additionalParameters?: { key: string; value: string }[]; 80 | useAdditionalParametersInput?: boolean; 81 | 82 | host?: string; 83 | useHostInput?: boolean; 84 | 85 | apiKey?: string; 86 | useApiKeyInput?: boolean; 87 | 88 | headers?: { key: string; value: string }[]; 89 | useHeadersInput?: boolean; 90 | }; 91 | 92 | export type OllamaGenerateNode = ChartNode<"ollamaChat", OllamaGenerateNodeData>; 93 | 94 | type OllamaStreamingContentResponse = { 95 | model: string; 96 | created_at: string; 97 | done: false; 98 | response: string; 99 | }; 100 | 101 | type OllamaStreamingFinalResponse = { 102 | model: string; 103 | created_at: string; 104 | response: ""; 105 | done: true; 106 | total_duration: number; 107 | load_duration: number; 108 | sample_count?: number; 109 | sample_duration?: number; 110 | prompt_eval_count: number; 111 | prompt_eval_duration: number; 112 | eval_count: number; 113 | eval_duration: number; 114 | context: number[]; 115 | }; 116 | 117 | type OllamaStreamingGenerateResponse = 118 | | OllamaStreamingContentResponse 119 | | OllamaStreamingFinalResponse; 120 | 121 | export const ollamaChat = (rivet: typeof Rivet) => { 122 | const impl: PluginNodeImpl = { 123 | create(): OllamaGenerateNode { 124 | const node: OllamaGenerateNode = { 125 | id: rivet.newId(), 126 | data: { 127 | model: "", 128 | useModelInput: false, 129 | promptFormat: "auto", 130 | jsonMode: false, 131 | outputFormat: "", 132 | numPredict: 1024, 133 | advancedOutputs: false, 134 | stop: "", 135 | }, 136 | title: "Ollama Generate", 137 | type: "ollamaChat", 138 | visualData: { 139 | x: 0, 140 | y: 0, 141 | width: 250, 142 | }, 143 | }; 144 | return node; 145 | }, 146 | 147 | getInputDefinitions(data): NodeInputDefinition[] { 148 | const inputs: NodeInputDefinition[] = []; 149 | 150 | inputs.push({ 151 | id: "system-prompt" as PortId, 152 | dataType: "string", 153 | title: "System Prompt", 154 | description: "The system prompt to prepend to the messages list.", 155 | required: false, 156 | coerced: true, 157 | }); 158 | 159 | 
inputs.push({ 160 | id: "messages" as PortId, 161 | dataType: ["chat-message[]", "chat-message"], 162 | title: "Messages", 163 | description: "The chat messages to use as the prompt.", 164 | }); 165 | 166 | if (data.useModelInput) { 167 | inputs.push({ 168 | id: "model" as PortId, 169 | dataType: "string", 170 | title: "Model", 171 | }); 172 | } 173 | 174 | if (data.useMirostatInput) { 175 | inputs.push({ 176 | id: "mirostat" as PortId, 177 | dataType: "number", 178 | title: "Mirostat", 179 | description: 'The "mirostat" parameter.', 180 | }); 181 | } 182 | 183 | if (data.useMirostatEtaInput) { 184 | inputs.push({ 185 | id: "mirostatEta" as PortId, 186 | dataType: "number", 187 | title: "Mirostat Eta", 188 | description: 'The "mirostat_eta" parameter.', 189 | }); 190 | } 191 | 192 | if (data.useMirostatTauInput) { 193 | inputs.push({ 194 | id: "mirostatTau" as PortId, 195 | dataType: "number", 196 | title: "Mirostat Tau", 197 | description: 'The "mirostat_tau" parameter.', 198 | }); 199 | } 200 | 201 | if (data.useNumCtxInput) { 202 | inputs.push({ 203 | id: "numCtx" as PortId, 204 | dataType: "number", 205 | title: "Num Ctx", 206 | description: 'The "num_ctx" parameter.', 207 | }); 208 | } 209 | 210 | if (data.useNumGqaInput) { 211 | inputs.push({ 212 | id: "numGqa" as PortId, 213 | dataType: "number", 214 | title: "Num GQA", 215 | description: 'The "num_gqa" parameter.', 216 | }); 217 | } 218 | 219 | if (data.useNumGpuInput) { 220 | inputs.push({ 221 | id: "numGpu" as PortId, 222 | dataType: "number", 223 | title: "Num GPUs", 224 | description: 'The "num_gpu" parameter.', 225 | }); 226 | } 227 | 228 | if (data.useNumThreadInput) { 229 | inputs.push({ 230 | id: "numThread" as PortId, 231 | dataType: "number", 232 | title: "Num Threads", 233 | description: 'The "num_thread" parameter.', 234 | }); 235 | } 236 | 237 | if (data.useRepeatLastNInput) { 238 | inputs.push({ 239 | id: "repeatLastN" as PortId, 240 | dataType: "number", 241 | title: "Repeat Last N", 242 | description: 'The "repeat_last_n" parameter.', 243 | }); 244 | } 245 | 246 | if (data.useRepeatPenaltyInput) { 247 | inputs.push({ 248 | id: "repeatPenalty" as PortId, 249 | dataType: "number", 250 | title: "Repeat Penalty", 251 | description: 'The "repeat_penalty" parameter.', 252 | }); 253 | } 254 | 255 | if (data.useTemperatureInput) { 256 | inputs.push({ 257 | id: "temperature" as PortId, 258 | dataType: "number", 259 | title: "Temperature", 260 | description: 'The "temperature" parameter.', 261 | }); 262 | } 263 | 264 | if (data.useSeedInput) { 265 | inputs.push({ 266 | id: "seed" as PortId, 267 | dataType: "number", 268 | title: "Seed", 269 | description: 'The "seed" parameter.', 270 | }); 271 | } 272 | 273 | if (data.useStopInput) { 274 | inputs.push({ 275 | id: "stop" as PortId, 276 | dataType: "string[]", 277 | title: "Stop", 278 | description: 'The "stop" parameter.', 279 | }); 280 | } 281 | 282 | if (data.useTfsZInput) { 283 | inputs.push({ 284 | id: "tfsZ" as PortId, 285 | dataType: "number", 286 | title: "TFS Z", 287 | description: 'The "tfs_z" parameter.', 288 | }); 289 | } 290 | 291 | if (data.useNumPredictInput) { 292 | inputs.push({ 293 | id: "numPredict" as PortId, 294 | dataType: "number", 295 | title: "Num Predict", 296 | description: 'The "num_predict" parameter.', 297 | }); 298 | } 299 | 300 | if (data.useTopKInput) { 301 | inputs.push({ 302 | id: "topK" as PortId, 303 | dataType: "number", 304 | title: "Top K", 305 | description: 'The "top_k" parameter.', 306 | }); 307 | } 308 | 309 | if (data.useTopPInput) { 
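// Same input-toggle pattern as the parameters above: the "Top P" port only exists while its editor toggle is set to input mode; otherwise process() reads the value stored on the node via rivet.getInputOrData.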
310 | inputs.push({ 311 | id: "topP" as PortId, 312 | dataType: "number", 313 | title: "Top P", 314 | description: 'The "top_p" parameter.', 315 | }); 316 | } 317 | 318 | if (data.useHostInput) { 319 | inputs.push({ 320 | dataType: "string", 321 | id: "host" as PortId, 322 | title: "Host", 323 | description: 324 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 325 | }); 326 | } 327 | 328 | if (data.useApiKeyInput) { 329 | inputs.push({ 330 | dataType: "string", 331 | id: "apiKey" as PortId, 332 | title: "API Key", 333 | description: 334 | "Optional API key for authentication with Ollama instances that require it.", 335 | }); 336 | } 337 | 338 | if (data.useAdditionalParametersInput) { 339 | inputs.push({ 340 | id: "additionalParameters" as PortId, 341 | dataType: "object", 342 | title: "Additional Parameters", 343 | description: "Additional parameters to pass to Ollama.", 344 | }); 345 | } 346 | 347 | if (data.useHeadersInput) { 348 | inputs.push({ 349 | dataType: 'object', 350 | id: 'headers' as PortId, 351 | title: 'Headers', 352 | description: 'Additional headers to send to the API.', 353 | }); 354 | } 355 | 356 | return inputs; 357 | }, 358 | 359 | getOutputDefinitions(data): NodeOutputDefinition[] { 360 | let outputs: NodeOutputDefinition[] = [ 361 | { 362 | id: "output" as PortId, 363 | dataType: "string", 364 | title: "Output", 365 | description: "The output from Ollama.", 366 | }, 367 | { 368 | id: "prompt" as PortId, 369 | dataType: "string", 370 | title: "Prompt", 371 | description: 372 | "The full prompt, with formatting, that was sent to Ollama.", 373 | }, 374 | { 375 | id: "messages-sent" as PortId, 376 | dataType: "chat-message[]", 377 | title: "Messages Sent", 378 | description: 379 | "The messages sent to Ollama, including the system prompt.", 380 | }, 381 | { 382 | id: "all-messages" as PortId, 383 | dataType: "chat-message[]", 384 | title: "All Messages", 385 | description: "All messages, including the reply from Ollama.", 386 | }, 387 | ]; 388 | 389 | if (data.advancedOutputs) { 390 | outputs = [ 391 | ...outputs, 392 | { 393 | id: "total-duration" as PortId, 394 | dataType: "number", 395 | title: "Total Duration", 396 | description: "Total time spent in nanoseconds generating the response", 397 | }, 398 | { 399 | id: "load-duration" as PortId, 400 | dataType: "number", 401 | title: "Load Duration", 402 | description: "Time spent in nanoseconds loading the model", 403 | }, 404 | { 405 | id: "sample-count" as PortId, 406 | dataType: "number", 407 | title: "Sample Count", 408 | description: "Number of samples generated", 409 | }, 410 | { 411 | id: "sample-duration" as PortId, 412 | dataType: "number", 413 | title: "Sample Duration", 414 | description: "Time spent generating samples", 415 | }, 416 | { 417 | id: "prompt-eval-count" as PortId, 418 | dataType: "number", 419 | title: "Prompt Eval Count", 420 | description: "Number of tokens in the prompt", 421 | }, 422 | { 423 | id: "prompt-eval-duration" as PortId, 424 | dataType: "number", 425 | title: "Prompt Eval Duration", 426 | description: "Time spent in nanoseconds evaluating the prompt", 427 | }, 428 | { 429 | id: "eval-count" as PortId, 430 | dataType: "number", 431 | title: "Eval Count", 432 | description: "Number of tokens in the response", 433 | }, 434 | { 435 | id: "eval-duration" as PortId, 436 | dataType: "number", 437 | title: "Eval Duration", 438 | description: "Time in nanoseconds spent generating the response", 439 | }, 440 | 
{ 441 | id: "tokens-per-second" as PortId, 442 | dataType: "number", 443 | title: "Tokens Per Second", 444 | description: "Tokens generated per second", 445 | }, 446 | { 447 | id: "parameters" as PortId, 448 | dataType: "object", 449 | title: "Parameters", 450 | description: "The parameters sent to Ollama", 451 | }, 452 | ]; 453 | } 454 | 455 | return outputs; 456 | }, 457 | 458 | getEditors(): EditorDefinition[] { 459 | return [ 460 | { 461 | type: "string", 462 | dataKey: "model", 463 | label: "Model", 464 | useInputToggleDataKey: "useModelInput", 465 | helperMessage: "The LLM model to use in Ollama.", 466 | }, 467 | { 468 | type: "dropdown", 469 | dataKey: "promptFormat", 470 | label: "Prompt Format", 471 | options: [ 472 | { value: "auto", label: "Auto"}, 473 | { value: "", label: "Raw" }, 474 | { value: "llama2", label: "Llama 2 Instruct" }, 475 | ], 476 | defaultValue: "", 477 | helperMessage: "The way to format chat messages for the prompt being sent to the ollama model. Raw means no formatting is applied. Auto means ollama will take care of it." 478 | }, 479 | { 480 | type: "toggle", 481 | dataKey: "jsonMode", 482 | label: "JSON mode", 483 | helperMessage: "Activates Ollamas JSON mode. Make sure to also instruct the model to return JSON" 484 | }, 485 | { 486 | type: "number", 487 | dataKey: "numPredict", 488 | useInputToggleDataKey: "useNumPredictInput", 489 | label: "maxTokens (num Predict)", 490 | helperMessage: 491 | "The maximum number of tokens to generate in the chat completion.", 492 | allowEmpty: false, 493 | defaultValue: 1024, 494 | }, 495 | { 496 | type: "toggle", 497 | dataKey: "advancedOutputs", 498 | label: "Advanced Outputs", 499 | helperMessage: 500 | "Add additional outputs with detailed information about the Ollama execution.", 501 | }, 502 | { 503 | type: "group", 504 | label: "Parameters", 505 | editors: [ 506 | { 507 | type: "number", 508 | dataKey: "mirostat", 509 | useInputToggleDataKey: "useMirostatInput", 510 | label: "Mirostat", 511 | helperMessage: 512 | "Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", 513 | min: 0, 514 | max: 1, 515 | step: 1, 516 | allowEmpty: true, 517 | }, 518 | { 519 | type: "number", 520 | dataKey: "mirostatEta", 521 | useInputToggleDataKey: "useMirostatEtaInput", 522 | label: "Mirostat Eta", 523 | helperMessage: 524 | "Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1)", 525 | allowEmpty: true, 526 | }, 527 | { 528 | type: "number", 529 | dataKey: "mirostatTau", 530 | useInputToggleDataKey: "useMirostatTauInput", 531 | label: "Mirostat Tau", 532 | helperMessage: 533 | "Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0)", 534 | allowEmpty: true, 535 | }, 536 | { 537 | type: "number", 538 | dataKey: "numCtx", 539 | useInputToggleDataKey: "useNumCtxInput", 540 | label: "Num Ctx", 541 | helperMessage: 542 | "Sets the size of the context window used to generate the next token. (Default: 2048)", 543 | 544 | allowEmpty: true, 545 | }, 546 | { 547 | type: "number", 548 | dataKey: "numGqa", 549 | useInputToggleDataKey: "useNumGqaInput", 550 | label: "Num GQA", 551 | helperMessage: 552 | "The number of GQA groups in the transformer layer. 
Required for some models, for example it is 8 for llama2:70b", 553 | allowEmpty: true, 554 | }, 555 | { 556 | type: "number", 557 | dataKey: "numGpu", 558 | useInputToggleDataKey: "useNumGpuInput", 559 | label: "Num GPUs", 560 | helperMessage: 561 | "The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable.", 562 | allowEmpty: true, 563 | }, 564 | { 565 | type: "number", 566 | dataKey: "numThread", 567 | useInputToggleDataKey: "useNumThreadInput", 568 | label: "Num Threads", 569 | helperMessage: 570 | "Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).", 571 | allowEmpty: true, 572 | }, 573 | { 574 | type: "number", 575 | dataKey: "repeatLastN", 576 | useInputToggleDataKey: "useRepeatLastNInput", 577 | label: "Repeat Last N", 578 | helperMessage: 579 | "Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)", 580 | allowEmpty: true, 581 | }, 582 | { 583 | type: "number", 584 | dataKey: "repeatPenalty", 585 | useInputToggleDataKey: "useRepeatPenaltyInput", 586 | label: "Repeat Penalty", 587 | helperMessage: 588 | "Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)", 589 | allowEmpty: true, 590 | }, 591 | { 592 | type: "number", 593 | dataKey: "temperature", 594 | useInputToggleDataKey: "useTemperatureInput", 595 | label: "Temperature", 596 | helperMessage: 597 | "The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)", 598 | allowEmpty: true, 599 | }, 600 | { 601 | type: "number", 602 | dataKey: "seed", 603 | useInputToggleDataKey: "useSeedInput", 604 | label: "Seed", 605 | helperMessage: 606 | "Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0)", 607 | allowEmpty: true, 608 | }, 609 | { 610 | type: "string", 611 | dataKey: "stop", 612 | useInputToggleDataKey: "useStopInput", 613 | label: "Stop", 614 | helperMessage: 615 | "Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return.", 616 | }, 617 | { 618 | type: "number", 619 | dataKey: "tfsZ", 620 | useInputToggleDataKey: "useTfsZInput", 621 | label: "TFS Z", 622 | helperMessage: 623 | "Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1)", 624 | allowEmpty: true, 625 | }, 626 | { 627 | type: "number", 628 | dataKey: "numPredict", 629 | useInputToggleDataKey: "useNumPredictInput", 630 | label: "Num Predict", 631 | helperMessage: 632 | "Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)", 633 | allowEmpty: true, 634 | }, 635 | { 636 | type: "number", 637 | dataKey: "topK", 638 | useInputToggleDataKey: "useTopKInput", 639 | label: "Top K", 640 | helperMessage: 641 | "Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. 
(Default: 40)", 642 | allowEmpty: true, 643 | }, 644 | { 645 | type: "number", 646 | dataKey: "topP", 647 | useInputToggleDataKey: "useTopPInput", 648 | label: "Top P", 649 | helperMessage: 650 | "Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)", 651 | allowEmpty: true, 652 | }, 653 | { 654 | type: "keyValuePair", 655 | dataKey: "additionalParameters", 656 | useInputToggleDataKey: "useAdditionalParametersInput", 657 | label: "Additional Parameters", 658 | keyPlaceholder: "Parameter", 659 | valuePlaceholder: "Value", 660 | helperMessage: 661 | "Additional parameters to pass to Ollama. Numbers will be parsed and sent as numbers, otherwise they will be sent as strings.", 662 | }, 663 | ], 664 | }, 665 | { 666 | type: "group", 667 | label: "Advanced", 668 | editors: [ 669 | { 670 | type: "string", 671 | label: "Host", 672 | dataKey: "host", 673 | useInputToggleDataKey: "useHostInput", 674 | helperMessage: 675 | "The host to use for the Ollama API. You can use this to replace with any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 676 | }, 677 | { 678 | type: "string", 679 | label: "API Key", 680 | dataKey: "apiKey", 681 | useInputToggleDataKey: "useApiKeyInput", 682 | helperMessage: 683 | "Optional API key for authentication with Ollama instances that require it. Will be sent as Authorization Bearer token.", 684 | }, 685 | { 686 | type: "keyValuePair", 687 | label: "Headers", 688 | dataKey: "headers", 689 | useInputToggleDataKey: "useHeadersInput", 690 | keyPlaceholder: "Header Name", 691 | valuePlaceholder: "Header Value", 692 | helperMessage: 693 | "Additional headers to send to the API.", 694 | }, 695 | ], 696 | }, 697 | ]; 698 | }, 699 | 700 | getBody(data) { 701 | return rivet.dedent` 702 | Model: ${data.useModelInput ? "(From Input)" : data.model || "Unset!"} 703 | Format: ${data.promptFormat || "Raw"} 704 | Max tokens: ${data.numPredict || 1024} 705 | `; 706 | }, 707 | 708 | getUIData(): NodeUIData { 709 | return { 710 | contextMenuTitle: "Ollama Generate", 711 | group: "Ollama", 712 | infoBoxBody: "This is an Ollama Generate node using /api/generate.", 713 | infoBoxTitle: "Ollama Generate Node", 714 | }; 715 | }, 716 | 717 | async process(data, inputData, context) { 718 | let outputs: Outputs = {}; 719 | 720 | const hostInput = rivet.getInputOrData(data, inputData, "host", "string"); 721 | const host = 722 | hostInput || 723 | context.getPluginConfig("host") || 724 | "http://localhost:11434"; 725 | 726 | if (!host.trim()) { 727 | throw new Error("No host set!"); 728 | } 729 | 730 | const apiKeyInput = rivet.getInputOrData( 731 | data, 732 | inputData, 733 | "apiKey", 734 | "string", 735 | ); 736 | const apiKey = apiKeyInput || context.getPluginConfig("apiKey"); 737 | 738 | const model = rivet.getInputOrData(data, inputData, "model", "string"); 739 | if (!model) { 740 | throw new Error("No model set!"); 741 | } 742 | 743 | const systemPrompt = rivet.coerceTypeOptional( 744 | inputData["system-prompt" as PortId], 745 | "string" 746 | ); 747 | const messages = 748 | rivet.coerceTypeOptional( 749 | inputData["messages" as PortId], 750 | "chat-message[]" 751 | ) ?? []; 752 | const allMessages: ChatMessage[] = systemPrompt 753 | ? 
[{ type: "system", message: systemPrompt }, ...messages] 754 | : messages; 755 | 756 | const prompt = formatChatMessages(allMessages, data.promptFormat); 757 | 758 | let additionalParameters: Record = ( 759 | data.additionalParameters ?? [] 760 | ).reduce((acc, { key, value }) => { 761 | const parsedValue = Number(value); 762 | acc[key] = isNaN(parsedValue) ? value : parsedValue; 763 | return acc; 764 | }, {} as Record); 765 | 766 | if (data.useAdditionalParametersInput) { 767 | additionalParameters = (rivet.coerceTypeOptional( 768 | inputData["additionalParameters" as PortId], 769 | "object" 770 | ) ?? {}) as Record; 771 | } 772 | 773 | let stop: string[] | undefined = undefined; 774 | if (data.useStopInput) { 775 | stop = rivet.coerceTypeOptional( 776 | inputData["stop" as PortId], 777 | "string[]" 778 | ); 779 | } else { 780 | stop = data.stop ? [data.stop] : undefined; 781 | } 782 | 783 | const parameters = { 784 | mirostat: rivet.getInputOrData(data, inputData, "mirostat", "number"), 785 | mirostat_eta: rivet.getInputOrData( 786 | data, 787 | inputData, 788 | "mirostatEta", 789 | "number" 790 | ), 791 | mirostat_tau: rivet.getInputOrData( 792 | data, 793 | inputData, 794 | "mirostatTau", 795 | "number" 796 | ), 797 | num_ctx: rivet.getInputOrData(data, inputData, "numCtx", "number"), 798 | num_gqa: rivet.getInputOrData(data, inputData, "numGqa", "number"), 799 | num_gpu: rivet.getInputOrData(data, inputData, "numGpu", "number"), 800 | num_thread: rivet.getInputOrData( 801 | data, 802 | inputData, 803 | "numThread", 804 | "number" 805 | ), 806 | repeat_last_n: rivet.getInputOrData( 807 | data, 808 | inputData, 809 | "repeatLastN", 810 | "number" 811 | ), 812 | repeat_penalty: rivet.getInputOrData( 813 | data, 814 | inputData, 815 | "repeatPenalty", 816 | "number" 817 | ), 818 | temperature: rivet.getInputOrData( 819 | data, 820 | inputData, 821 | "temperature", 822 | "number" 823 | ), 824 | seed: rivet.getInputOrData(data, inputData, "seed", "number"), 825 | stop, 826 | tfs_z: rivet.getInputOrData(data, inputData, "tfsZ", "number"), 827 | num_predict: rivet.getInputOrData( 828 | data, 829 | inputData, 830 | "numPredict", 831 | "number" 832 | ), 833 | top_k: rivet.getInputOrData(data, inputData, "topK", "number"), 834 | top_p: rivet.getInputOrData(data, inputData, "topP", "number"), 835 | ...additionalParameters, 836 | }; 837 | 838 | let apiResponse: Response; 839 | 840 | type RequestBodyType = { 841 | model: string; 842 | prompt: string; 843 | raw: boolean; 844 | stream: boolean; 845 | options: any; 846 | format?: string; 847 | }; 848 | 849 | const requestBody: RequestBodyType = { 850 | model, 851 | prompt, 852 | raw: data.promptFormat === "auto" ? 
false : true, 853 | stream: true, 854 | options: parameters 855 | }; 856 | if (data.jsonMode === true) { 857 | requestBody.format = "json"; 858 | } // test 859 | 860 | try { 861 | const headers: Record = { 862 | "Content-Type": "application/json", 863 | }; 864 | 865 | if (apiKey && apiKey.trim()) { 866 | headers["Authorization"] = `Bearer ${apiKey}`; 867 | } 868 | 869 | // Add headers from data or input 870 | let additionalHeaders: Record = {}; 871 | if (data.useHeadersInput) { 872 | const headersInput = rivet.coerceTypeOptional( 873 | inputData["headers" as PortId], 874 | "object", 875 | ) as Record | undefined; 876 | if (headersInput) { 877 | additionalHeaders = headersInput; 878 | } 879 | } else if (data.headers) { 880 | additionalHeaders = data.headers.reduce( 881 | (acc, { key, value }) => { 882 | acc[key] = value; 883 | return acc; 884 | }, 885 | {} as Record, 886 | ); 887 | } 888 | 889 | Object.assign(headers, additionalHeaders); 890 | 891 | apiResponse = await fetch(`${host}/api/generate`, { 892 | method: "POST", 893 | headers, 894 | body: JSON.stringify(requestBody) 895 | }); 896 | } catch (err) { 897 | throw new Error(`Error from Ollama: ${rivet.getError(err).message}`); 898 | } 899 | 900 | if (!apiResponse.ok) { 901 | try { 902 | const error = await apiResponse.json(); 903 | throw new Error(`Error from Ollama: ${error.message}`); 904 | } catch (err) { 905 | throw new Error(`Error from Ollama: ${apiResponse.statusText}`); 906 | } 907 | } 908 | 909 | const reader = apiResponse.body?.getReader(); 910 | 911 | if (!reader) { 912 | throw new Error("No response body!"); 913 | } 914 | 915 | let streamingResponseText = ""; 916 | let llmResponseText = ""; 917 | 918 | let finalResponse: OllamaStreamingFinalResponse | undefined; 919 | 920 | while (true) { 921 | const { value, done } = await reader.read(); 922 | if (done) { 923 | break; 924 | } 925 | 926 | if (value) { 927 | const chunk = new TextDecoder().decode(value); 928 | 929 | streamingResponseText += chunk; 930 | 931 | const lines = streamingResponseText.split("\n"); 932 | streamingResponseText = lines.pop() ?? 
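// Keep the trailing (possibly incomplete) line in the buffer until the next chunk completes it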
""; 933 | 934 | for (const line of lines) { 935 | try { 936 | const json = JSON.parse(line) as OllamaStreamingGenerateResponse; 937 | 938 | if (!("done" in json)) { 939 | throw new Error(`Invalid response from Ollama: ${line}`); 940 | } 941 | 942 | if (!json.done) { 943 | if (llmResponseText === "") { 944 | llmResponseText += (json.response as string).trimStart(); 945 | } else { 946 | llmResponseText += json.response; 947 | } 948 | } else { 949 | finalResponse = json; 950 | } 951 | } catch (err) { 952 | throw new Error( 953 | `Error parsing line from Ollama streaming response: ${line}` 954 | ); 955 | } 956 | } 957 | 958 | outputs["output" as PortId] = { 959 | type: "string", 960 | value: llmResponseText, 961 | }; 962 | 963 | context.onPartialOutputs?.(outputs); 964 | } 965 | } 966 | 967 | if (!finalResponse) { 968 | throw new Error("No final response from Ollama!"); 969 | } 970 | 971 | outputs["prompt" as PortId] = { 972 | type: "string", 973 | value: prompt, 974 | }; 975 | 976 | outputs["messages-sent" as PortId] = { 977 | type: "chat-message[]", 978 | value: allMessages, 979 | }; 980 | 981 | outputs["all-messages" as PortId] = { 982 | type: "chat-message[]", 983 | value: [ 984 | ...allMessages, 985 | { 986 | type: "assistant", 987 | message: llmResponseText, 988 | function_call: undefined, 989 | }, 990 | ], 991 | }; 992 | 993 | if (data.advancedOutputs) { 994 | outputs["total-duration" as PortId] = { 995 | type: "number", 996 | value: finalResponse.total_duration, 997 | }; 998 | 999 | outputs["load-duration" as PortId] = { 1000 | type: "number", 1001 | value: finalResponse.load_duration, 1002 | }; 1003 | 1004 | outputs["sample-count" as PortId] = { 1005 | type: "number", 1006 | value: finalResponse.sample_count ?? 0, 1007 | }; 1008 | 1009 | outputs["sample-duration" as PortId] = { 1010 | type: "number", 1011 | value: finalResponse.sample_duration ?? 0, 1012 | }; 1013 | 1014 | outputs["prompt-eval-count" as PortId] = { 1015 | type: "number", 1016 | value: finalResponse.prompt_eval_count, 1017 | }; 1018 | 1019 | outputs["prompt-eval-duration" as PortId] = { 1020 | type: "number", 1021 | value: finalResponse.prompt_eval_duration, 1022 | }; 1023 | 1024 | outputs["eval-count" as PortId] = { 1025 | type: "number", 1026 | value: finalResponse.eval_count, 1027 | }; 1028 | 1029 | outputs["eval-duration" as PortId] = { 1030 | type: "number", 1031 | value: finalResponse.eval_duration, 1032 | }; 1033 | 1034 | outputs["tokens-per-second" as PortId] = { 1035 | type: "number", 1036 | value: finalResponse.eval_count / (finalResponse.eval_duration / 1e9), 1037 | }; 1038 | 1039 | outputs["parameters" as PortId] = { 1040 | type: "object", 1041 | value: parameters, 1042 | }; 1043 | } 1044 | 1045 | return outputs; 1046 | }, 1047 | }; 1048 | 1049 | return rivet.pluginNodeDefinition(impl, "Ollama Chat"); 1050 | }; 1051 | 1052 | function formatChatMessages(messages: ChatMessage[], format: string): string { 1053 | return match(format) 1054 | .with( 1055 | "", 1056 | () => 1057 | messages.map((message) => formatChatMessage(message, format)).join("\n") // Hopefully \n is okay? Instead of joining with empty string? 
1058 | ) 1059 | .with( 1060 | "auto", 1061 | () => 1062 | messages.map((message) => formatChatMessage(message, format)).join("\n") 1063 | ) 1064 | .with("llama2", () => formatLlama2Instruct(messages)) 1065 | .otherwise(() => { 1066 | throw new Error(`Unsupported format: ${format}`); 1067 | }); 1068 | } 1069 | 1070 | function formatLlama2Instruct(messages: ChatMessage[]): string { 1071 | let inMessage = false; 1072 | let inInstruction = false; 1073 | let prompt = ""; 1074 | 1075 | for (const message of messages) { 1076 | if (!inMessage) { 1077 | prompt += "<s>"; 1078 | inMessage = true; 1079 | } 1080 | 1081 | if (message.type === "system" || message.type === "user") { 1082 | if (inInstruction) { 1083 | prompt += "\n\n"; 1084 | } else { 1085 | prompt += "[INST] "; 1086 | inInstruction = true; 1087 | } 1088 | 1089 | prompt += formatChatMessage(message, "llama2"); 1090 | } else if (message.type === "assistant") { 1091 | if (inInstruction) { 1092 | prompt += " [/INST] "; 1093 | inInstruction = false; 1094 | } 1095 | 1096 | prompt += formatChatMessage(message, "llama2"); 1097 | prompt += " </s>"; 1098 | inMessage = false; 1099 | } else { 1100 | throw new Error(`Unsupported message type: ${message.type}`); 1101 | } 1102 | } 1103 | 1104 | if (inInstruction) { 1105 | prompt += " [/INST] "; 1106 | inInstruction = false; 1107 | } 1108 | 1109 | // Make sure there's always an unterminated <s> for the LLM to fill in itself 1110 | if (!inMessage) { 1111 | prompt += "<s>"; 1112 | inMessage = true; 1113 | } 1114 | 1115 | return prompt; 1116 | } 1117 | 1118 | function formatChatMessage(message: ChatMessage, format: string): string { 1119 | return match(format) 1120 | .with("", (): string => chatMessageToString(message.message)) 1121 | .with("llama2", (): string => 1122 | match(message) 1123 | .with({ type: "user" }, (message) => 1124 | chatMessageToString(message.message) 1125 | ) 1126 | .with( 1127 | { type: "system" }, 1128 | (message) => 1129 | `<<SYS>>\n${chatMessageToString(message.message)}\n<</SYS>>\n` // Two more \n added by formatLlama2Instruct to make 3 total 1130 | ) 1131 | .with({ type: "assistant" }, (message) => 1132 | chatMessageToString(message.message) 1133 | ) 1134 | .otherwise(() => "") 1135 | ) 1136 | .otherwise(() => chatMessageToString(message.message)); 1137 | } 1138 | 1139 | function chatMessageToString( 1140 | messageParts: ChatMessageMessagePart[] | ChatMessageMessagePart 1141 | ): string { 1142 | const parts = Array.isArray(messageParts) ? 
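// Normalize a single part to an array so both cases share one code path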
messageParts : [messageParts]; 1143 | 1144 | const stringMessage = parts 1145 | .map((part): string => { 1146 | if (typeof part === "string") { 1147 | return part; 1148 | } else if (part.type === "url") { 1149 | return `(Image at ${part.url})`; 1150 | } else if (part.type === "image") { 1151 | return `(Embedded Image)`; 1152 | } else { 1153 | return `(Unknown Message Part)`; 1154 | } 1155 | }) 1156 | .join("\n\n"); 1157 | 1158 | return stringMessage; 1159 | } -------------------------------------------------------------------------------- /src/nodes/PullModelToOllamaNode.ts: -------------------------------------------------------------------------------- 1 | import type { 2 | ChartNode, 3 | NodeId, 4 | NodeInputDefinition, 5 | NodeUIData, 6 | PluginNodeImpl, 7 | PortId, 8 | Rivet, 9 | } from "@ironclad/rivet-core"; 10 | 11 | export type PullModelToOllamaNode = ChartNode< 12 | "pullModelToOllama", 13 | { 14 | modelName: string; 15 | useModelNameInput?: boolean; 16 | 17 | insecure: boolean; 18 | 19 | host?: string; 20 | useHostInput?: boolean; 21 | 22 | apiKey?: string; 23 | useApiKeyInput?: boolean; 24 | 25 | headers?: { key: string; value: string }[]; 26 | useHeadersInput?: boolean; 27 | } 28 | >; 29 | 30 | export const pullModelToOllama = (rivet: typeof Rivet) => { 31 | const impl: PluginNodeImpl = { 32 | create() { 33 | return { 34 | id: rivet.newId(), 35 | data: { 36 | modelName: "", 37 | useModelNameInput: true, 38 | insecure: false, 39 | }, 40 | title: "Pull Model to Ollama", 41 | type: "pullModelToOllama", 42 | visualData: { 43 | x: 0, 44 | y: 0, 45 | width: 250, 46 | }, 47 | } satisfies PullModelToOllamaNode; 48 | }, 49 | 50 | getInputDefinitions(data) { 51 | const inputs: NodeInputDefinition[] = []; 52 | 53 | if (data.useModelNameInput) { 54 | inputs.push({ 55 | id: "modelName" as PortId, 56 | dataType: "string", 57 | title: "Model Name", 58 | description: "The name of the model to pull from the Ollama library.", 59 | }); 60 | } 61 | 62 | if (data.useHostInput) { 63 | inputs.push({ 64 | dataType: "string", 65 | id: "host" as PortId, 66 | title: "Host", 67 | description: 68 | "The host to use for the Ollama API. You can point this at any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 69 | }); 70 | } 71 | 72 | if (data.useApiKeyInput) { 73 | inputs.push({ 74 | dataType: "string", 75 | id: "apiKey" as PortId, 76 | title: "API Key", 77 | description: 78 | "Optional API key for authentication with Ollama instances that require it.", 79 | }); 80 | } 81 | 82 | if (data.useHeadersInput) { 83 | inputs.push({ 84 | dataType: "object", 85 | id: "headers" as PortId, 86 | title: "Headers", 87 | description: "Additional headers to send to the API.", 88 | }); 89 | } 90 | 91 | return inputs; 92 | }, 93 | 94 | getOutputDefinitions() { 95 | return [ 96 | { 97 | id: "modelName" as PortId, 98 | dataType: "string", 99 | title: "Model Name", 100 | description: "The name of the model that was pulled.", 101 | }, 102 | ]; 103 | }, 104 | 105 | getEditors() { 106 | return [ 107 | { 108 | type: "string", 109 | dataKey: "modelName", 110 | useInputToggleDataKey: "useModelNameInput", 111 | label: "Model Name", 112 | helperMessage: "The name of the model to pull.", 113 | placeholder: "Model Name", 114 | }, 115 | { 116 | type: "toggle", 117 | dataKey: "insecure", 118 | label: "Insecure", 119 | helperMessage: 120 | "Allow insecure connections to the library. 
Only use this if you are pulling from your own library during development.", 121 | }, 122 | { 123 | type: "group", 124 | label: "Advanced", 125 | editors: [ 126 | { 127 | type: "string", 128 | label: "Host", 129 | dataKey: "host", 130 | useInputToggleDataKey: "useHostInput", 131 | helperMessage: 132 | "The host to use for the Ollama API. You can point this at any Ollama-compatible API. Leave blank for the default: http://localhost:11434", 133 | }, 134 | { 135 | type: "string", 136 | label: "API Key", 137 | dataKey: "apiKey", 138 | useInputToggleDataKey: "useApiKeyInput", 139 | helperMessage: 140 | "Optional API key for authentication with Ollama instances that require it. Will be sent as Authorization Bearer token.", 141 | }, 142 | { 143 | type: "keyValuePair", 144 | label: "Headers", 145 | dataKey: "headers", 146 | useInputToggleDataKey: "useHeadersInput", 147 | keyPlaceholder: "Header Name", 148 | valuePlaceholder: "Header Value", 149 | helperMessage: 150 | "Additional headers to send to the API.", 151 | }, 152 | ], 153 | }, 154 | ]; 155 | }, 156 | 157 | getBody(data) { 158 | return rivet.dedent` 159 | Model: ${ 160 | data.useModelNameInput ? "(From Input)" : data.modelName || "Unset!" 161 | } 162 | `; 163 | }, 164 | 165 | getUIData(): NodeUIData { 166 | return { 167 | contextMenuTitle: "Pull Model to Ollama", 168 | group: "Ollama", 169 | infoBoxTitle: "Pull Model to Ollama Node", 170 | infoBoxBody: 171 | "Downloads a model from the Ollama library to the Ollama server.", 172 | }; 173 | }, 174 | 175 | async process(data, inputData, context) { 176 | const hostInput = rivet.getInputOrData(data, inputData, "host", "string"); 177 | const host = 178 | hostInput || 179 | context.getPluginConfig("host") || 180 | "http://localhost:11434"; 181 | 182 | if (!host.trim()) { 183 | throw new Error("No host set!"); 184 | } 185 | 186 | const apiKeyInput = rivet.getInputOrData( 187 | data, 188 | inputData, 189 | "apiKey", 190 | "string", 191 | ); 192 | const apiKey = apiKeyInput || context.getPluginConfig("apiKey"); 193 | 194 | const modelName = rivet.getInputOrData(data, inputData, "modelName"); 195 | 196 | const headers: Record<string, string> = { 197 | "Content-Type": "application/json", 198 | }; 199 | 200 | if (apiKey && apiKey.trim()) { 201 | headers["Authorization"] = `Bearer ${apiKey}`; 202 | } 203 | 204 | // Add headers from data or input 205 | let additionalHeaders: Record<string, string> = {}; 206 | if (data.useHeadersInput) { 207 | const headersInput = rivet.coerceTypeOptional( 208 | inputData["headers" as PortId], 209 | "object", 210 | ) as Record<string, string> | undefined; 211 | if (headersInput) { 212 | additionalHeaders = headersInput; 213 | } 214 | } else if (data.headers) { 215 | additionalHeaders = data.headers.reduce( 216 | (acc, { key, value }) => { 217 | acc[key] = value; 218 | return acc; 219 | }, 220 | {} as Record<string, string>, 221 | ); 222 | } 223 | 224 | Object.assign(headers, additionalHeaders); 225 | 226 | const response = await fetch(`${host}/api/pull`, { 227 | method: "POST", 228 | headers, 229 | body: JSON.stringify({ 230 | name: modelName, 231 | insecure: data.insecure, 232 | stream: true, 233 | }), 234 | }); 235 | 236 | if (!response.ok) { 237 | let detail = response.statusText; 238 | try { 239 | detail = await response.text(); 240 | } catch { 241 | // Keep the status text if the body cannot be read 242 | } 243 | throw new Error( 244 | `Error ${response.status} from Ollama: ${detail}` 245 | ); 246 | } 247 | 248 | 249 | // Stream the response to avoid fetch timeout 250 | const reader = response.body?.getReader(); 251 | 252 | 
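// Each streamed line is a JSON progress object; they are read and discarded
// below, and the loop simply waits for the pull to finish.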
if (!reader) { 253 | throw new Error("Response body was not readable."); 254 | } 255 | 256 | while (true) { 257 | const { done } = await reader.read(); 258 | 259 | if (done) { 260 | break; 261 | } 262 | 263 | // Nothing to do with the value right now 264 | } 265 | 266 | return { 267 | ["modelName" as PortId]: { 268 | type: "string", 269 | value: modelName, 270 | }, 271 | }; 272 | }, 273 | }; 274 | 275 | return rivet.pluginNodeDefinition(impl, "Pull Model to Ollama"); 276 | }; 277 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ESNext", 4 | "module": "ESNext", 5 | "moduleResolution": "bundler", 6 | "esModuleInterop": true, 7 | "forceConsistentCasingInFileNames": true, 8 | "strict": true, 9 | "skipLibCheck": true, 10 | "noEmit": true 11 | } 12 | } 13 | --------------------------------------------------------------------------------
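// ---------------------------------------------------------------------------
// A minimal standalone sketch of the /api/generate request that
// OllamaGenerateNode assembles above. Assumptions: a local Ollama server at
// the default host, and an already-pulled model ("llama2" is illustrative).
// ---------------------------------------------------------------------------
const host = "http://localhost:11434";
const res = await fetch(`${host}/api/generate`, {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    model: "llama2",
    prompt: "<s>[INST] Say hello. [/INST] ",
    raw: true, // the node sends raw unless promptFormat is "auto"
    stream: false, // the node streams; a single JSON response keeps this sketch short
    options: { temperature: 0.8, num_predict: 128 },
  }),
});
console.log((await res.json()).response);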