├── .gitignore ├── .npmignore ├── .husky └── pre-commit ├── tests ├── resources │ ├── test_model.pt │ └── torch_module.py ├── tensor.test.js ├── rand.test.js └── scriptmodule.test.js ├── .prettierrc.json ├── .editorconfig ├── tsconfig.json ├── .prettierignore ├── .eslintignore ├── examples ├── Operations.md ├── ScriptModule.md └── Tensor.md ├── lib ├── bindings.d.ts ├── torch.js ├── torch.d.ts └── index.ts ├── src ├── ATen.h ├── addon.cc ├── constants.cc ├── utils.h ├── constants.h ├── ScriptModule.h ├── Tensor.h ├── FunctionWorker.h ├── utils.cc ├── ATen.cc ├── Tensor.cc └── ScriptModule.cc ├── .vscode ├── settings.json └── c_cpp_properties.json ├── .github └── workflows │ ├── publish-npm.yml │ ├── ci.yml │ ├── publish-prebuild-test.yml │ ├── publish-prebuild.yml │ └── scripts │ ├── install_cuda_ubuntu.sh │ └── install_cuda_windows.ps1 ├── LICENSE ├── .eslintrc.json ├── package.json ├── README.md ├── patches └── 0001-Remove-native-image-support.patch └── CMakeLists.txt /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | build 3 | dist 4 | libtorch* 5 | prebuilds -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | build 3 | .vscode 4 | .github 5 | libtorch* 6 | prebuilds -------------------------------------------------------------------------------- /.husky/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | . 
"$(dirname "$0")/_/husky.sh" 3 | 4 | yarn lint-staged 5 | -------------------------------------------------------------------------------- /tests/resources/test_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arition/torch-js/HEAD/tests/resources/test_model.pt -------------------------------------------------------------------------------- /.prettierrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "trailingComma": "es5", 3 | "tabWidth": 2, 4 | "semi": true, 5 | "singleQuote": true 6 | } -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | [*.{js,ts}] 2 | indent_style = space 3 | indent_size = 2 4 | trim_trailing_whitespace = true 5 | insert_final_newline = true 6 | end_of_line = lf -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es6", 4 | "module": "commonjs", 5 | "declaration": true, 6 | "outDir": "./dist", 7 | "strict": true 8 | }, 9 | "include": [ 10 | "lib/**/*" 11 | ] 12 | } -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | # Ignore artifacts: 2 | build 3 | coverage 4 | dist 5 | patches 6 | src 7 | 8 | # Ignore config files: 9 | .* 10 | *.lock 11 | 12 | # Ignore JSON: 13 | *.json 14 | 15 | # Other stuff: 16 | node_modules/ 17 | .github/ 18 | *.log 19 | *.zip 20 | *.txt -------------------------------------------------------------------------------- /.eslintignore: -------------------------------------------------------------------------------- 1 | # Ignore artifacts: 2 | build 3 | coverage 4 | dist 5 | 
patches 6 | src 7 | 8 | # Ignore config files: 9 | .* 10 | *.lock 11 | 12 | # Ignore JSON: 13 | *.json 14 | 15 | # Other stuff: 16 | node_modules/ 17 | .github/ 18 | *.log 19 | *.md 20 | CMakeLists.txt -------------------------------------------------------------------------------- /examples/Operations.md: -------------------------------------------------------------------------------- 1 | # Sample Operation Usage 2 | 3 | ## forward 4 | 5 | The script_module module allows for a forward call, that receives two tensors and returns a Promise as a result. 6 | 7 | ```js 8 | const a = torch.rand(1, 5); 9 | const b = torch.rand(1, 5); 10 | const res = await script_module.forward(a, b); 11 | ``` 12 | -------------------------------------------------------------------------------- /lib/bindings.d.ts: -------------------------------------------------------------------------------- 1 | type BindingFunction = (mod: string) => any; 2 | interface BindingInterface extends BindingFunction { 3 | getRoot: (file: string) => string; 4 | getFileName: (callingFile?: string) => string; 5 | } 6 | declare let bindings: BindingInterface; 7 | 8 | declare module 'bindings' { 9 | export = bindings; 10 | } 11 | -------------------------------------------------------------------------------- /lib/torch.js: -------------------------------------------------------------------------------- 1 | const torch = require('bindings')('torch-js'); 2 | 3 | const bindings = require('bindings'); 4 | const path = require('path'); 5 | 6 | const moduleRoot = bindings.getRoot(bindings.getFileName()); 7 | const buildFolder = path.join(moduleRoot, 'build', 'Release'); 8 | torch.initenv(`${buildFolder};${process.env.path}`); 9 | 10 | module.exports = torch; 11 | -------------------------------------------------------------------------------- /src/ATen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace torchjs 6 | { 7 | namespace aten
8 | { 9 | 10 | Napi::Object Init(Napi::Env env, Napi::Object exports); 11 | 12 | Napi::Value rand(const Napi::CallbackInfo &info); 13 | 14 | Napi::Value initenv(const Napi::CallbackInfo &info); 15 | 16 | } // namespace aten 17 | } // namespace torchjs -------------------------------------------------------------------------------- /src/addon.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "ATen.h" 4 | #include "ScriptModule.h" 5 | #include "Tensor.h" 6 | #include "constants.h" 7 | 8 | using namespace torchjs; 9 | 10 | Napi::Object Init(Napi::Env env, Napi::Object exports) 11 | { 12 | aten::Init(env, exports); 13 | constants::Init(env, exports); 14 | ScriptModule::Init(env, exports); 15 | Tensor::Init(env, exports); 16 | return exports; 17 | } 18 | 19 | NODE_API_MODULE(torchjs, Init); -------------------------------------------------------------------------------- /examples/ScriptModule.md: -------------------------------------------------------------------------------- 1 | # Sample ScriptModule Usage 2 | 3 | ## forward 4 | 5 | The script_module may be used to load up any model in TorchScript, and perform operations on it 6 | 7 | ```js 8 | const torch = require("torch-js"); 9 | const modelPath = `test_model.pt`; 10 | const model = new torch.ScriptModule(modelPath); 11 | const inputA = torch.rand([1, 5]); 12 | const inputB = torch.rand([1, 5]); 13 | const res = await model.forward(inputA, inputB); 14 | ``` 15 | -------------------------------------------------------------------------------- /src/constants.cc: -------------------------------------------------------------------------------- 1 | #include "constants.h" 2 | 3 | #include 4 | 5 | namespace torchjs 6 | { 7 | namespace constants 8 | { 9 | Napi::Object Init(Napi::Env env, Napi::Object exports) 10 | { 11 | exports.Set(kFloat32, static_cast(torch::kFloat32)); 12 | exports.Set(kFloat64, static_cast(torch::kFloat64)); 13 | exports.Set(kInt32,
static_cast(torch::kInt32)); 14 | return exports; 15 | } 16 | } // namespace constants 17 | } // namespace torchjs -------------------------------------------------------------------------------- /tests/resources/torch_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TestModule(torch.nn.Module): 5 | def __init__(self): 6 | super(TestModule, self).__init__() 7 | 8 | def forward(self, input1, input2): 9 | return input1 + input2 10 | 11 | 12 | test_module = TestModule() 13 | 14 | example = (torch.rand(1, 4), torch.rand(1, 4)) 15 | traced_script_module = torch.jit.trace(test_module, example) 16 | 17 | traced_script_module.save("test_model.pt") 18 | 19 | loaded = torch.jit.load("test_model.pt") 20 | 21 | loaded(*example) 22 | -------------------------------------------------------------------------------- /src/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace torchjs 7 | { 8 | 9 | // TODO: Switch to Napi::BigIntArray when it's more stable 10 | using ShapeArrayType = Napi::Array; 11 | 12 | torch::TensorOptions parseTensorOptions(const Napi::Object &); 13 | 14 | ShapeArrayType tensorShapeToArray(Napi::Env, const torch::Tensor &); 15 | std::vector shapeArrayToVector(const ShapeArrayType &); 16 | 17 | template 18 | torch::ScalarType scalarType(); 19 | 20 | } // namespace torchjs -------------------------------------------------------------------------------- /src/constants.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace torchjs 8 | { 9 | namespace constants 10 | { 11 | const static std::string kData = "data"; 12 | const static std::string kDtype = "dtype"; 13 | const static std::string kShape = "shape"; 14 | 15 | const static std::string kFloat32 = "float32"; 16 | const static std::string 
kFloat64 = "float64"; 17 | const static std::string kInt32 = "int32"; 18 | 19 | Napi::Object Init(Napi::Env env, Napi::Object exports); 20 | } // namespace constants 21 | } // namespace torchjs 22 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "__config": "cpp", 4 | "__nullptr": "cpp", 5 | "exception": "cpp", 6 | "initializer_list": "cpp", 7 | "new": "cpp", 8 | "stdexcept": "cpp", 9 | "typeinfo": "cpp", 10 | "algorithm": "cpp", 11 | "type_traits": "cpp", 12 | "vector": "cpp", 13 | "__bit_reference": "cpp", 14 | "__debug": "cpp", 15 | "__functional_03": "cpp", 16 | "__functional_base": "cpp", 17 | "__functional_base_03": "cpp", 18 | "__split_buffer": "cpp", 19 | "__string": "cpp", 20 | "__tuple": "cpp", 21 | "atomic": "cpp", 22 | "cstddef": "cpp", 23 | "functional": "cpp", 24 | "iosfwd": "cpp", 25 | "iterator": "cpp", 26 | "limits": "cpp", 27 | "memory": "cpp", 28 | "string": "cpp", 29 | "string_view": "cpp", 30 | "tuple": "cpp", 31 | "utility": "cpp", 32 | "hashtable": "cpp", 33 | "unordered_map": "cpp", 34 | "unordered_set": "cpp" 35 | } 36 | } -------------------------------------------------------------------------------- /src/ScriptModule.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace torchjs 9 | { 10 | class ScriptModule : public Napi::ObjectWrap 11 | { 12 | public: 13 | static Napi::Object Init(Napi::Env env, Napi::Object exports); 14 | ScriptModule(const Napi::CallbackInfo &info); 15 | 16 | private: 17 | static Napi::FunctionReference constructor; 18 | torch::jit::script::Module module_; 19 | std::string path_; 20 | 21 | Napi::Value forward(const Napi::CallbackInfo &info); 22 | Napi::Value toString(const Napi::CallbackInfo &info); 23 | Napi::Value cpu(const 
Napi::CallbackInfo &info); 24 | Napi::Value cuda(const Napi::CallbackInfo &info); 25 | static Napi::Value isCudaAvailable(const Napi::CallbackInfo &info); 26 | static Napi::Value IValueToJSType(Napi::Env env, const c10::IValue &iValue); 27 | static c10::IValue JSTypeToIValue(Napi::Env env, const Napi::Value &jsValue); 28 | }; 29 | 30 | } // namespace torchjs -------------------------------------------------------------------------------- /.github/workflows/publish-npm.yml: -------------------------------------------------------------------------------- 1 | name: publish-npm 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*.*.*" 7 | 8 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 9 | jobs: 10 | # This workflow contains a single job called "build" 11 | build: 12 | # The type of runner that the job will run on 13 | runs-on: ubuntu-latest 14 | if: "!contains(github.event.head_commit.message, 'publish-npm skip')" 15 | 16 | # Steps represent a sequence of tasks that will be executed as part of the job 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v2 20 | 21 | - name: Install node 22 | uses: actions/setup-node@v1 23 | with: 24 | node-version: "16" 25 | registry-url: "https://registry.npmjs.org" 26 | 27 | - name: Install packages 28 | run: yarn upgrade --dev 29 | 30 | - name: Build 31 | run: yarn build 32 | 33 | - name: Publish 34 | run: yarn publish 35 | env: 36 | NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} 37 | -------------------------------------------------------------------------------- /src/Tensor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace torchjs 9 | { 10 | 11 | class Tensor : public Napi::ObjectWrap 12 | { 13 | public: 14 | static Napi::Object Init(Napi::Env, Napi::Object exports); 15 | static Napi::Object FromTensor(Napi::Env, const torch::Tensor &); 16 | static bool IsInstance(Napi::Object 
&); 17 | Tensor(const Napi::CallbackInfo &); 18 | virtual void Finalize(Napi::Env env) override; 19 | 20 | torch::Tensor tensor(); 21 | 22 | private: 23 | torch::Tensor tensor_; 24 | static Napi::FunctionReference constructor; 25 | Napi::Value toString(const Napi::CallbackInfo &); 26 | Napi::Value toObject(const Napi::CallbackInfo &); 27 | static Napi::Value fromObject(const Napi::CallbackInfo &); 28 | Napi::Value cpu(const Napi::CallbackInfo &info); 29 | Napi::Value cuda(const Napi::CallbackInfo &info); 30 | Napi::Value clone(const Napi::CallbackInfo &info); 31 | void free(const Napi::CallbackInfo &info); 32 | bool is_free; 33 | }; 34 | 35 | } // namespace torchjs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 Kittipat Virochsiri, arition 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /examples/Tensor.md: -------------------------------------------------------------------------------- 1 | # Sample Tensor Creation 2 | 3 | ## Rand 4 | 5 | The rand function may be used to generate a tensor of a given shape and configuration, populated with random values 6 | 7 | ### Calling rand with variable number of arguments 8 | 9 | ```js 10 | const a = torch.rand(1, 5); 11 | ``` 12 | 13 | ### Calling rand with shape array 14 | 15 | ```js 16 | const a = torch.rand([1, 5]); 17 | ``` 18 | 19 | ### Calling rand using option parsing 20 | 21 | ```js 22 | const a = torch.rand([1, 5], { 23 | dtype: torch.float64, 24 | }); 25 | ``` 26 | 27 | ## Tensor 28 | 29 | The tensor function may be used to create a tensor given an array of values that the tensor would consist of. 
30 | 31 | ### Calling tensor given a nested array of tensor values 32 | 33 | ```js 34 | const a = torch.tensor([ 35 | [0.1, 0.2, 0.3], 36 | [0.4, 0.5, 0.6], 37 | ]); 38 | ``` 39 | 40 | ### Calling tensor given a nested array of tensor values and a specified option type 41 | 42 | ```js 43 | const a = torch.tensor( 44 | [ 45 | [0.1, 0.2, 0.3], 46 | [0.4, 0.5, 0.6], 47 | ], 48 | { 49 | dtype: torch.float64, 50 | } 51 | ); 52 | ``` 53 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Mac", 5 | "includePath": [ 6 | "${workspaceFolder}/**", 7 | "../libtorch/**" 8 | ], 9 | "defines": [], 10 | "macFrameworkPath": [ 11 | "/System/Library/Frameworks", 12 | "/Library/Frameworks" 13 | ], 14 | "compilerPath": "/usr/bin/clang", 15 | "cStandard": "c11", 16 | "cppStandard": "c++11", 17 | "intelliSenseMode": "clang-x64" 18 | }, 19 | { 20 | "name": "Win32", 21 | "includePath": [ 22 | "${workspaceFolder}/**" 23 | ], 24 | "defines": [ 25 | "_DEBUG", 26 | "UNICODE", 27 | "_UNICODE" 28 | ], 29 | "windowsSdkVersion": "10.0.18362.0", 30 | "compilerPath": "C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.27.29110/bin/Hostx64/x64/cl.exe", 31 | "cStandard": "c11", 32 | "cppStandard": "c++17", 33 | "intelliSenseMode": "msvc-x64", 34 | "configurationProvider": "ms-vscode.cmake-tools" 35 | } 36 | ], 37 | "version": 4 38 | } -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | 8 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 9 | jobs: 10 | # This workflow contains a single job called "build" 11 | build: 12 | # The type 
of runner that the job will run on 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: [macos-latest, ubuntu-latest, windows-latest] 17 | if: "!contains(github.event.head_commit.message, 'ci skip')" 18 | 19 | # Steps represent a sequence of tasks that will be executed as part of the job 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v2 23 | 24 | - name: Install node 25 | uses: actions/setup-node@v1 26 | with: 27 | node-version: "16" 28 | 29 | - name: Install Python 30 | uses: actions/setup-python@v2 31 | with: 32 | python-version: "3.10" 33 | 34 | - name: Install packages 35 | run: yarn upgrade --dev 36 | 37 | - name: Prebuild 38 | run: yarn build-prebuild 39 | 40 | - name: Build 41 | run: yarn build 42 | 43 | - name: Run test suite 44 | run: yarn test 45 | 46 | - name: Upload artifact 47 | uses: actions/upload-artifact@v2 48 | with: 49 | name: artifact-${{ matrix.os }} 50 | path: prebuilds/**/*.tar.gz 51 | -------------------------------------------------------------------------------- /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "node": true 4 | }, 5 | "extends": [ 6 | "airbnb-base", 7 | "plugin:prettier/recommended" 8 | ], 9 | "overrides": [ 10 | { 11 | "files": [ 12 | "**/*.test.js" 13 | ], 14 | "env": { 15 | "node": true, 16 | "jest": true 17 | } 18 | }, 19 | { 20 | "files": [ 21 | "**/*.ts", 22 | "**/*.tsx" 23 | ], 24 | "env": { 25 | "node": true 26 | }, 27 | "extends": [ 28 | "airbnb-base", 29 | "airbnb-typescript/base", 30 | "plugin:@typescript-eslint/eslint-recommended", 31 | "plugin:@typescript-eslint/recommended", 32 | "plugin:prettier/recommended" 33 | ], 34 | "globals": { 35 | "Atomics": "readonly", 36 | "SharedArrayBuffer": "readonly" 37 | }, 38 | "parser": "@typescript-eslint/parser", 39 | "parserOptions": { 40 | "sourceType": "module", 41 | "project": "./tsconfig.json" 42 | }, 43 | "plugins": [ 44 | "@typescript-eslint" 45 | ], 46 | "rules": { 
47 | "@typescript-eslint/no-explicit-any": 0, 48 | "@typescript-eslint/no-useless-constructor": 0, 49 | "no-param-reassign": 0, 50 | "max-classes-per-file": [ 51 | "error", 52 | 3 53 | ] 54 | } 55 | } 56 | ], 57 | "rules": { 58 | "no-param-reassign": 0, 59 | "max-classes-per-file": [ 60 | "error", 61 | 3 62 | ] 63 | } 64 | } -------------------------------------------------------------------------------- /src/FunctionWorker.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace torchjs 7 | { 8 | template 9 | class FunctionWorker : public Napi::AsyncWorker 10 | { 11 | public: 12 | FunctionWorker(Napi::Env env, std::function _workFunction, std::function _postWorkFunction); 13 | ~FunctionWorker() {} 14 | void Execute() override; 15 | void OnOK() override; 16 | void OnError(const Napi::Error &e) override; 17 | Napi::Promise GetPromise(); 18 | 19 | private: 20 | Napi::Promise::Deferred promise; 21 | std::function workFunction; 22 | std::function postWorkFunction; 23 | T value; 24 | }; 25 | 26 | template 27 | FunctionWorker::FunctionWorker(Napi::Env env, std::function _workFunction, std::function _postWorkFunction) 28 | : promise(Napi::Promise::Deferred::New(env)), workFunction(_workFunction), postWorkFunction(_postWorkFunction), AsyncWorker(env) {} 29 | 30 | template 31 | void FunctionWorker::Execute() 32 | { 33 | value = workFunction(); 34 | } 35 | 36 | template 37 | void FunctionWorker::OnOK() 38 | { 39 | auto result = postWorkFunction(Env(), value); 40 | promise.Resolve(result); 41 | } 42 | 43 | template 44 | void FunctionWorker::OnError(const Napi::Error &e) 45 | { 46 | promise.Reject(e.Value()); 47 | } 48 | 49 | template 50 | Napi::Promise FunctionWorker::GetPromise() 51 | { 52 | return promise.Promise(); 53 | } 54 | } // namespace torchjs -------------------------------------------------------------------------------- /package.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "@arition/torch-js", 3 | "version": "0.14.0", 4 | "main": "dist/index.js", 5 | "types": "dist/index.d.ts", 6 | "author": "Kittipat Virochsiri, arition, raghavmecheri", 7 | "license": "MIT", 8 | "repository": { 9 | "type": "git", 10 | "url": "git+https://github.com/arition/torch-js.git" 11 | }, 12 | "scripts": { 13 | "install": "prebuild-install -r napi || cmake-js compile", 14 | "build": "node build.js", 15 | "compile": "cmake-js compile", 16 | "rebuild": "cmake-js rebuild", 17 | "test": "jest --forceExit --verbose", 18 | "lint": "eslint ./", 19 | "build-prebuild": "prebuild -t 3 -r napi --include-regex \"\\.[nsd][oyl][dl]?[ie]?b?$\" --backend cmake-js", 20 | "prepare": "husky install" 21 | }, 22 | "lint-staged": { 23 | "**/*": [ 24 | "yarn lint", 25 | "prettier --write" 26 | ] 27 | }, 28 | "devDependencies": { 29 | "@arition/prebuild": "^11.0.3", 30 | "@jest/globals": "^27.5.1", 31 | "@types/node": "^16", 32 | "@typescript-eslint/eslint-plugin": "^5.20.0", 33 | "@typescript-eslint/parser": "^5.20.0", 34 | "eslint": "^8.14.0", 35 | "eslint-config-airbnb-base": "^15.0.0", 36 | "eslint-config-airbnb-typescript": "^17.0.0", 37 | "eslint-config-prettier": "^8.5.0", 38 | "eslint-plugin-import": "^2.26.0", 39 | "eslint-plugin-prettier": "^4.0.0", 40 | "husky": "^7.0.4", 41 | "jest": "^27.5.1", 42 | "lint-staged": "^12.4.0", 43 | "prettier": "^2.6.2" 44 | }, 45 | "dependencies": { 46 | "@arition/prebuild-install": "^7.1.0", 47 | "bindings": "^1.5.0", 48 | "cmake-js": "^6.3.0", 49 | "node-addon-api": "^4.3.0", 50 | "typescript": "^4.6.3" 51 | }, 52 | "binary": { 53 | "napi_versions": [ 54 | 3 55 | ] 56 | } 57 | } -------------------------------------------------------------------------------- /lib/torch.d.ts: -------------------------------------------------------------------------------- 1 | // Possible values for dtype 2 | export const float32: number; 3 | export const 
float64: number; 4 | export const int32: number; 5 | 6 | export interface ObjectTensor { 7 | data: Float32Array | Float64Array | Int32Array; 8 | shape: number[]; 9 | } 10 | 11 | export class Tensor { 12 | static fromObject({ data, shape }: ObjectTensor): Tensor; 13 | 14 | toObject(): ObjectTensor; 15 | 16 | toString(): string; 17 | 18 | cpu(): Tensor; 19 | 20 | cuda(): Tensor; 21 | 22 | clone(): Tensor; 23 | 24 | /** 25 | * Free the underlying tensor resources. 26 | * 27 | * DO NOT use the object again after calling the method. 28 | */ 29 | free(): void; 30 | } 31 | 32 | // raghavmecheri - FIXTHIS: Get rid of circular dependancies if possible 33 | 34 | // eslint-disable-next-line @typescript-eslint/no-empty-interface, no-use-before-define 35 | interface TorchTypesArray extends Array {} 36 | interface TorchTypesRecord // eslint-disable-line @typescript-eslint/no-empty-interface 37 | extends Record {} // eslint-disable-line no-use-before-define 38 | 39 | type TorchTypes = 40 | | TorchTypesArray 41 | | TorchTypesRecord 42 | | string 43 | | number 44 | | boolean 45 | | Tensor; 46 | 47 | export class ScriptModule { 48 | constructor(path: string); 49 | 50 | forward(input: TorchTypes): Promise; 51 | 52 | forward(...inputs: TorchTypes[]): Promise; 53 | 54 | toString(): string; 55 | 56 | cpu(): ScriptModule; 57 | 58 | cuda(): ScriptModule; 59 | 60 | static isCudaAvailable(): boolean; 61 | } 62 | 63 | export function rand(shape: number[], options?: { dtype?: number }): Tensor; 64 | // The function actually support options at the end but TS can't express that 65 | export function rand( 66 | ...shape: number[] /** , options?: {dtype?: number} */ 67 | ): Tensor; 68 | -------------------------------------------------------------------------------- /src/utils.cc: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | #include "constants.h" 4 | 5 | namespace torchjs 6 | { 7 | 8 | using namespace constants; 9 | 10 | 
torch::TensorOptions parseTensorOptions(const Napi::Object &obj) 11 | { 12 | torch::TensorOptions options; 13 | if (obj.Has(kDtype)) 14 | { 15 | options = options.dtype(static_cast(obj.Get(kDtype).As().Int32Value())) 16 | .requires_grad(false); 17 | } 18 | return options; 19 | } 20 | 21 | ShapeArrayType tensorShapeToArray(Napi::Env env, const torch::Tensor &tensor) 22 | { 23 | Napi::EscapableHandleScope scope(env); 24 | auto shape = tensor.sizes(); 25 | auto shape_array = ShapeArrayType::New(env, shape.size()); 26 | // Int64 is not well supported in N-API so this can't be memcpy'ed 27 | for (size_t i = 0; i < shape.size(); ++i) 28 | { 29 | shape_array[i] = shape[i]; 30 | } 31 | return scope.Escape(shape_array).As(); 32 | } 33 | 34 | std::vector shapeArrayToVector(const ShapeArrayType &size_array) 35 | { 36 | auto array_len = size_array.Length(); 37 | std::vector sizes; 38 | for (decltype(array_len) i = 0; i < array_len; ++i) 39 | { 40 | Napi::HandleScope scope(size_array.Env()); 41 | sizes.push_back(static_cast(size_array[i]) 42 | .As() 43 | .Int64Value()); 44 | } 45 | return sizes; 46 | } 47 | 48 | template <> 49 | torch::ScalarType scalarType() { return torch::kFloat32; } 50 | template <> 51 | torch::ScalarType scalarType() { return torch::kFloat64; } 52 | template <> 53 | torch::ScalarType scalarType() { return torch::kInt32; } 54 | template <> 55 | torch::ScalarType scalarType() { return torch::kInt64; } 56 | 57 | } // namespace torchjs -------------------------------------------------------------------------------- /lib/index.ts: -------------------------------------------------------------------------------- 1 | import assert = require('assert'); 2 | import torch = require('./torch'); 3 | 4 | export * from './torch'; 5 | 6 | // eslint-disable-next-line no-use-before-define 7 | type Matrix = number[] | NestedArray; 8 | type NestedArray = Array; 9 | 10 | function flattenArray(arr: Matrix) { 11 | const shape = [arr.length]; 12 | while (arr.length > 0 && 
Array.isArray(arr[0])) { 13 | shape.push((arr[0] as Matrix).length); 14 | // Forcing Array is required to make TS happy; otherwise, TS2349 15 | arr = (arr as Array).reduce((acc, val) => acc.concat(val), []); 16 | 17 | // Running assert in every step to make sure that shape is regular 18 | const numel = shape.reduce((acc, cur) => acc * cur); 19 | assert.strictEqual(arr.length, numel); 20 | } 21 | return { 22 | data: arr as number[], 23 | shape, 24 | }; 25 | } 26 | 27 | function typedArrayType(dtype: number) { 28 | switch (dtype) { 29 | case torch.float32: 30 | return Float32Array; 31 | case torch.float64: 32 | return Float64Array; 33 | case torch.int32: 34 | return Int32Array; 35 | default: 36 | throw new TypeError('Unsupported dtype'); 37 | } 38 | } 39 | 40 | type TensorDataCompatible = Float32Array | Float64Array | Int32Array | Matrix; 41 | 42 | export function tensor( 43 | data: TensorDataCompatible, 44 | { dtype = torch.float32, shape }: { dtype?: number; shape?: number[] } = {} 45 | ): torch.Tensor { 46 | if (Array.isArray(data)) { 47 | const arrayAndShape = flattenArray(data); 48 | const typedArray = new (typedArrayType(dtype))(arrayAndShape.data); 49 | return torch.Tensor.fromObject({ 50 | data: typedArray, 51 | shape: arrayAndShape.shape, 52 | }); 53 | } 54 | if (shape === undefined) { 55 | shape = [data.length]; 56 | } 57 | return torch.Tensor.fromObject({ 58 | data, 59 | shape, 60 | }); 61 | } 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deprecated 2 | 3 | The whole repo will not be regularly updated. If someone is still interested in using the torch-js, feel free to open an issue and wait for the author to update. PRs are more welcome. 
4 | 5 | # TorchJS 6 | 7 | [![npm version](https://badge.fury.io/js/%40arition%2Ftorch-js.svg)](https://badge.fury.io/js/%40arition%2Ftorch-js) 8 | 9 | TorchJS is a JS binding for PyTorch. Its primary objective is to allow running [Torch Script](https://pytorch.org/docs/master/jit.html) inside Node.js program. Complete binding of libtorch is possible but is out-of-scope at the moment. 10 | 11 | ## Changes after fork 12 | 13 | - Add support for `List` (Javascript `Array`), `Dict` (Javascript `Object`), `String`, `float` (Javascript `number`) as inputs and outputs. 14 | 15 | - Add CUDA support. 16 | 17 | - Add ops from torchvision. 18 | 19 | - Add async support for `forward` function. 20 | 21 | ## Install 22 | 23 | To install the forked version, you can install it from npm: 24 | 25 | ```bash 26 | yarn add torch-js@npm:@arition/torch-js 27 | ``` 28 | 29 | ## Example 30 | 31 | In `tests/resources/torch_module.py`, you will find the definition of our test module and the code to generate the trace file. 32 | 33 | ```python 34 | class TestModule(torch.nn.Module): 35 | def __init__(self): 36 | super(TestModule, self).__init__() 37 | 38 | def forward(self, input1, input2): 39 | return input1 + input2 40 | ``` 41 | 42 | Once you have the trace file, it may be loaded into NodeJS like this 43 | 44 | ```javascript 45 | const torch = require("torch-js"); 46 | const modelPath = `test_model.pt`; 47 | const model = new torch.ScriptModule(modelPath); 48 | const inputA = torch.rand([1, 5]); 49 | const inputB = torch.rand([1, 5]); 50 | const res = await model.forward(inputA, inputB); 51 | ``` 52 | 53 | More examples regarding tensor creation, ScriptModule operations, and loading models can be found in our [examples](./examples) folder.
54 | -------------------------------------------------------------------------------- /src/ATen.cc: -------------------------------------------------------------------------------- 1 | #include "ATen.h" 2 | 3 | #include 4 | 5 | #include "Tensor.h" 6 | #include "constants.h" 7 | #include "utils.h" 8 | #include 9 | 10 | namespace torchjs 11 | { 12 | 13 | using namespace constants; 14 | 15 | namespace aten 16 | { 17 | 18 | Napi::Object Init(Napi::Env env, Napi::Object exports) 19 | { 20 | exports.Set("rand", Napi::Function::New(env, rand)); 21 | exports.Set("initenv", Napi::Function::New(env, initenv)); 22 | return exports; 23 | } 24 | 25 | Napi::Value initenv(const Napi::CallbackInfo &info) 26 | { 27 | if (info.Length() != 1 || !info[0].IsString()) 28 | throw Napi::Error::New(info.Env(), "Only support one string as parameter"); 29 | auto path = info[0].ToString().Utf8Value(); 30 | #ifdef _WIN32 31 | return Napi::Boolean::New(info.Env(), putenv(std::string("PATH=" + path).c_str()) == 0); 32 | #else 33 | return Napi::Boolean::New(info.Env(), setenv("PATH", path.c_str(), 1) == 0); 34 | #endif 35 | } 36 | 37 | Napi::Value rand(const Napi::CallbackInfo &info) 38 | { 39 | Napi::EscapableHandleScope scope(info.Env()); 40 | auto len = info.Length(); 41 | 42 | torch::TensorOptions options; 43 | if (len > 1 && info[len - 1].IsObject()) 44 | { 45 | auto option_obj = info[len - 1].As(); 46 | options = parseTensorOptions(option_obj); 47 | --len; 48 | } 49 | 50 | std::vector sizes; 51 | 52 | // TODO: Support int64_t sizes 53 | if (len == 1 && info[0].IsArray()) 54 | { 55 | auto size_array = info[0].As(); 56 | sizes = shapeArrayToVector(size_array); 57 | } 58 | else 59 | { 60 | for (size_t i = 0; i < len; ++i) 61 | { 62 | sizes.push_back(info[i].As().Int64Value()); 63 | } 64 | } 65 | auto rand_tensor = torch::rand(sizes, options); 66 | return scope.Escape(Tensor::FromTensor(info.Env(), rand_tensor)); 67 | } 68 | 69 | } // namespace aten 70 | } // namespace torchjs 71 | 
-------------------------------------------------------------------------------- /tests/tensor.test.js: -------------------------------------------------------------------------------- 1 | const torch = require('../dist'); 2 | 3 | describe('Tensor creation', () => { 4 | test('Tensor creation using valid array (1-D)', () => { 5 | const a = torch.tensor([[1, 2, 3, 4, 5, 6]]).toObject(); 6 | expect(a.data.length).toBe(6); 7 | expect(a.shape).toMatchObject([1, 6]); 8 | 9 | const b = torch.tensor([[1, 2, 3]]).toObject(); 10 | expect(b.data.length).toBe(3); 11 | expect(b.shape).toMatchObject([1, 3]); 12 | }); 13 | 14 | test('Tensor creation using valid array (multi-dimensional)', () => { 15 | const a = torch 16 | .tensor([ 17 | [0.1, 0.2, 0.3], 18 | [0.4, 0.5, 0.6], 19 | ]) 20 | .toObject(); 21 | expect(a.data.length).toBe(6); 22 | expect(a.shape).toMatchObject([2, 3]); 23 | 24 | const b = torch 25 | .tensor([ 26 | [0.1, 0.2, 0.3], 27 | [0.4, 0.5, 0.6], 28 | [0.7, 0.8, 0.9], 29 | ]) 30 | .toObject(); 31 | expect(b.data.length).toBe(9); 32 | expect(b.shape).toMatchObject([3, 3]); 33 | }); 34 | 35 | test('Tensor creation using valid object', () => { 36 | const a = torch 37 | .tensor(new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]), { 38 | shape: [1, 5], 39 | }) 40 | .toObject(); 41 | expect(a.data.length).toBe(5); 42 | expect(a.shape).toMatchObject([1, 5]); 43 | 44 | const b = torch 45 | .tensor(new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]), { 46 | shape: [3, 3], 47 | }) 48 | .toObject(); 49 | expect(b.data.length).toBe(9); 50 | expect(b.shape).toMatchObject([3, 3]); 51 | }); 52 | 53 | test('Tensor creation using invalid params', () => { 54 | const t = () => torch.tensor(); 55 | const t2 = () => torch.tensor(123); 56 | const t3 = () => torch.tensor(true); 57 | 58 | expect(t).toThrow("Cannot read properties of undefined (reading 'length')"); 59 | expect(t2).toThrow('Invalid argument'); 60 | expect(t3).toThrow('Invalid argument'); 61 | }); 62 | 63 | test('Tensor 
clone', () => { 64 | const t = torch.rand(2, 3); 65 | const tObject = t.toObject(); 66 | const tCloneObject = t.clone().toObject(); 67 | expect(tObject).toEqual(tCloneObject); 68 | }); 69 | }); 70 | -------------------------------------------------------------------------------- /tests/rand.test.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Test module for Random tensor creation 3 | * TODO: Check why a.shape is of type object, instead of type Array 4 | * TODO: Implement toBeInstanceOf checks for data children of objects 5 | */ 6 | const torch = require('../dist'); 7 | 8 | describe('Random tensor creation', () => { 9 | test('Random tensor creation using variable number of arguements', () => { 10 | const a = torch.rand(1, 5).toObject(); 11 | expect(a.data.length).toBe(5); 12 | expect(a.shape).toMatchObject([1, 5]); 13 | 14 | const b = torch.rand(2, 5).toObject(); 15 | expect(b.data.length).toBe(10); 16 | expect(b.shape).toMatchObject([2, 5]); 17 | 18 | const c = torch.rand(2, 3).toObject(); 19 | expect(c.data.length).toBe(6); 20 | expect(c.shape).toMatchObject([2, 3]); 21 | }); 22 | 23 | test('Random tensor creation using shape array', () => { 24 | const a = torch.rand([1, 5]).toObject(); 25 | expect(a.data.length).toBe(5); 26 | expect(a.shape).toMatchObject([1, 5]); 27 | 28 | const b = torch.rand([2, 5]).toObject(); 29 | expect(b.data.length).toBe(10); 30 | expect(b.shape).toMatchObject([2, 5]); 31 | 32 | const c = torch.rand([2, 3]).toObject(); 33 | expect(c.data.length).toBe(6); 34 | expect(c.shape).toMatchObject([2, 3]); 35 | }); 36 | 37 | test('Random tensor creation using option parsing', () => { 38 | const a = torch 39 | .rand([1, 5], { 40 | dtype: torch.float64, 41 | }) 42 | .toObject(); 43 | expect(a.data.length).toBe(5); 44 | expect(a.shape).toMatchObject([1, 5]); 45 | }); 46 | 47 | test('Random tensor creation using single param', () => { 48 | const a = torch.rand(1).toObject(); 49 | 
expect(a.data.length).toBe(1); 50 | expect(a.shape.length).toBe(1); 51 | }); 52 | 53 | test('Random tensor creation using no params', () => { 54 | const a = torch.rand().toObject(); 55 | expect(a.data.length).toBe(1); 56 | expect(a.shape.length).toBe(0); 57 | }); 58 | 59 | test('Random tensor creation using invalid params', () => { 60 | const t = () => torch.rand('Hello World'); 61 | const t2 = () => torch.rand(true); 62 | expect(t).toThrow(new Error('A number was expected')); 63 | expect(t2).toThrow(new Error('A number was expected')); 64 | }); 65 | }); 66 | -------------------------------------------------------------------------------- /patches/0001-Remove-native-image-support.patch: -------------------------------------------------------------------------------- 1 | From cb1e7910b785fdb14a6deb52631c9415c49b5089 Mon Sep 17 00:00:00 2001 2 | From: arition 3 | Date: Tue, 1 Dec 2020 00:17:44 -0800 4 | Subject: [PATCH] Remove native image support 5 | 6 | --- 7 | CMakeLists.txt | 16 ++++++++++------ 8 | 1 file changed, 10 insertions(+), 6 deletions(-) 9 | 10 | diff --git a/CMakeLists.txt b/CMakeLists.txt 11 | index 4e222dbfa..404f5d4a7 100644 12 | --- a/CMakeLists.txt 13 | +++ b/CMakeLists.txt 14 | @@ -24,10 +24,6 @@ if (USE_PYTHON) 15 | endif() 16 | 17 | find_package(Torch REQUIRED) 18 | -find_package(PNG REQUIRED) 19 | -find_package(JPEG REQUIRED) 20 | -add_definitions(-DJPEG_FOUND) 21 | -add_definitions(-DPNG_FOUND) 22 | 23 | function(CUDA_CONVERT_FLAGS EXISTING_TARGET) 24 | get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS) 25 | @@ -39,6 +35,14 @@ function(CUDA_CONVERT_FLAGS EXISTING_TARGET) 26 | endif() 27 | endfunction() 28 | 29 | +file(GLOB HEADERS torchvision/csrc/*.h) 30 | +file(GLOB OPERATOR_SOURCES torchvision/csrc/cpu/*.h torchvision/csrc/cpu/*.cpp ${HEADERS} torchvision/csrc/*.cpp) 31 | +if(WITH_CUDA) 32 | + file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} torchvision/csrc/cuda/*.h torchvision/csrc/cuda/*.cu) 33 | +endif() 
34 | +file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h) 35 | +file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp) 36 | + 37 | if(MSVC) 38 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819") 39 | if(WITH_CUDA) 40 | @@ -82,7 +86,7 @@ FOREACH(DIR ${ALLOW_LISTED}) 41 | ENDFOREACH() 42 | 43 | add_library(${PROJECT_NAME} SHARED ${ALL_SOURCES}) 44 | -target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${PNG_LIBRARY} ${JPEG_LIBRARIES}) 45 | +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES}) 46 | 47 | if (USE_PYTHON) 48 | target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python) 49 | @@ -92,7 +96,7 @@ set_target_properties(${PROJECT_NAME} PROPERTIES 50 | EXPORT_NAME TorchVision 51 | INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib) 52 | 53 | -include_directories(torchvision/csrc ${JPEG_INCLUDE_DIRS} ${PNG_INCLUDE_DIRS}) 54 | +include_directories(torchvision/csrc) 55 | 56 | set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake") 57 | 58 | -- 59 | 2.37.1.windows.1 60 | 61 | -------------------------------------------------------------------------------- /tests/scriptmodule.test.js: -------------------------------------------------------------------------------- 1 | const torch = require('../dist'); 2 | 3 | const testModelPath = `${__dirname}/resources/test_model.pt`; 4 | 5 | describe('Constructor', () => { 6 | test('Call constructor using valid model path', () => { 7 | const scriptModule = new torch.ScriptModule(testModelPath); 8 | expect(scriptModule.toString()).toMatch(/ScriptModule.*\.pt/); 9 | }); 10 | 11 | // TODO -- Fine tune error message to remove stacktrace for error thrown on invalid model file 12 | test('Call constructor from invalid model path', () => { 13 | const t = () => new torch.ScriptModule('/resources/no_model.pt'); 14 | expect(t).toThrow(/Unrecognized data format*/); 15 | expect(true).toEqual(true); 16 | }); 17 | 18 | test('Call 
constructor with missing params', () => { 19 | const t = () => new torch.ScriptModule(); 20 | expect(t).toThrow(new Error('A string was expected')); 21 | }); 22 | 23 | test('Call constructor with invalid params', () => { 24 | const t = () => new torch.ScriptModule(true); 25 | const t2 = () => new torch.ScriptModule(123); 26 | const t3 = () => new torch.ScriptModule(122.3); 27 | expect(t).toThrow(new Error('A string was expected')); 28 | expect(t2).toThrow(new Error('A string was expected')); 29 | expect(t3).toThrow(new Error('A string was expected')); 30 | }); 31 | }); 32 | 33 | describe('toString', () => { 34 | test('Call toString using valid model path', () => { 35 | const scriptModule = new torch.ScriptModule(testModelPath); 36 | expect(scriptModule.toString()).toMatch(/ScriptModule.*\.pt/); 37 | }); 38 | 39 | test("Call toString using valid model path with variable params (shouldn't make a difference)", () => { 40 | const scriptModule = new torch.ScriptModule(testModelPath); 41 | expect(scriptModule.toString(3)).toMatch(/ScriptModule.*\.pt/); 42 | expect(scriptModule.toString(false)).toMatch(/ScriptModule.*\.pt/); 43 | expect(scriptModule.toString(3.233)).toMatch(/ScriptModule.*\.pt/); 44 | expect(scriptModule.toString(['Hello world'])).toMatch( 45 | /ScriptModule.*\.pt/ 46 | ); 47 | }); 48 | }); 49 | 50 | describe('Forward function', () => { 51 | test('Call to forward using valid tensor params', async () => { 52 | const scriptModule = new torch.ScriptModule(testModelPath); 53 | const a = torch.tensor([ 54 | [0.1, 0.2, 0.3], 55 | [0.4, 0.5, 0.6], 56 | ]); 57 | const b = torch.tensor([ 58 | [0.1, 0.2, 0.3], 59 | [0.4, 0.5, 0.6], 60 | ]); 61 | const res = await scriptModule.forward(a, b); 62 | expect(res.toObject().data.length).toEqual(6); 63 | expect(res.toObject().shape).toMatchObject([2, 3]); 64 | }); 65 | 66 | // TODO -- Fine tune error message to remove stacktrace for missing arguement value -- at the moment, the multiple lines make it difficult to test for 
an error message 67 | test('Call to forward using missing second tensor param', async () => { 68 | const scriptModule = new torch.ScriptModule(testModelPath); 69 | const a = torch.tensor([ 70 | [0.1, 0.2, 0.3], 71 | [0.4, 0.5, 0.6], 72 | ]); 73 | expect(scriptModule.forward(a)).rejects.toThrow(/[\s\S]/); 74 | }); 75 | }); 76 | -------------------------------------------------------------------------------- /.github/workflows/publish-prebuild-test.yml: -------------------------------------------------------------------------------- 1 | name: publish-prebuild-test 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'pytorch/**' 7 | 8 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 9 | jobs: 10 | # This workflow contains a single job called "build" 11 | build: 12 | # The type of runner that the job will run on 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | os: [ubuntu-latest, windows-latest] 18 | if: "!contains(github.event.head_commit.message, 'publish-prebuild skip')" 19 | 20 | # Steps represent a sequence of tasks that will be executed as part of the job 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v2 24 | 25 | - name: Install node 26 | uses: actions/setup-node@v1 27 | with: 28 | node-version: '16' 29 | 30 | - name: Install Python 31 | uses: actions/setup-python@v2 32 | with: 33 | python-version: '3.10' 34 | 35 | - name: Setup Conda to install cudnn on windows-latest 36 | uses: s-weigand/setup-conda@v1 37 | if: ${{ matrix.os == 'windows-latest' }} 38 | with: 39 | activate-conda: false 40 | 41 | - name: Install CUDA on windows-latest 42 | if: ${{ matrix.os == 'windows-latest' }} 43 | shell: powershell 44 | env: 45 | cuda: "11.6.2" 46 | cudnn: "8.4.0.27" 47 | run: | 48 | # Install CUDA via a powershell script 49 | .\.github\workflows\scripts\install_cuda_windows.ps1 50 | if ($?) 
{ 51 | # Set paths for subsequent steps, using $env:CUDA_PATH 52 | echo "Adding CUDA to CUDA_PATH, CUDA_PATH_X_Y and PATH" 53 | echo "CUDA_PATH=$($env:CUDA_PATH)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append 54 | echo "$($env:CUDA_PATH_VX_Y)=$($env:CUDA_PATH)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append 55 | echo "$($env:CUDA_PATH)/bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append 56 | } 57 | 58 | - name: Install CUDA on ubuntu-latest 59 | if: ${{ matrix.os == 'ubuntu-latest' }} 60 | shell: bash 61 | env: 62 | cuda: "11.6.2" 63 | cudnn: "8.4.0.27" 64 | run: | 65 | source ./.github/workflows/scripts/install_cuda_ubuntu.sh 66 | if [[ $? -eq 0 ]]; then 67 | # Set paths for subsequent steps, using ${CUDA_PATH} 68 | echo "Adding CUDA to CUDA_PATH, PATH and LD_LIBRARY_PATH" 69 | echo "CUDA_PATH=${CUDA_PATH}" >> $GITHUB_ENV 70 | echo "${CUDA_PATH}/bin" >> $GITHUB_PATH 71 | echo "LD_LIBRARY_PATH=${CUDA_PATH}/lib:${LD_LIBRARY_PATH}" >> $GITHUB_ENV 72 | fi 73 | 74 | - name: Install packages 75 | run: yarn upgrade --dev 76 | 77 | - name: Prebuild 78 | run: yarn build-prebuild 79 | 80 | - name: Check prebuild bundle size 81 | if: ${{ matrix.os == 'windows-latest' }} 82 | shell: powershell 83 | run: | 84 | ls prebuilds/@arition 85 | ls build/Release 86 | -------------------------------------------------------------------------------- /.github/workflows/publish-prebuild.yml: -------------------------------------------------------------------------------- 1 | name: publish-prebuild 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*.*.*" 7 | 8 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 9 | jobs: 10 | # This workflow contains a single job called "build" 11 | build: 12 | # The type of runner that the job will run on 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | os: [macos-latest, ubuntu-latest, windows-latest] 18 | if: 
"!contains(github.event.head_commit.message, 'publish-prebuild skip')" 19 | 20 | # Steps represent a sequence of tasks that will be executed as part of the job 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v2 24 | 25 | - name: Install node 26 | uses: actions/setup-node@v1 27 | with: 28 | node-version: '16' 29 | 30 | - name: Install Python 31 | uses: actions/setup-python@v2 32 | with: 33 | python-version: '3.10' 34 | 35 | - name: Setup Conda to install cudnn on windows-latest 36 | uses: s-weigand/setup-conda@v1 37 | if: ${{ matrix.os == 'windows-latest' }} 38 | with: 39 | activate-conda: false 40 | 41 | - name: Install CUDA on windows-latest 42 | if: ${{ matrix.os == 'windows-latest' }} 43 | shell: powershell 44 | env: 45 | cuda: "11.6.2" 46 | cudnn: "8.4.0.27" 47 | run: | 48 | # Install CUDA via a powershell script 49 | .\.github\workflows\scripts\install_cuda_windows.ps1 50 | if ($?) { 51 | # Set paths for subsequent steps, using $env:CUDA_PATH 52 | echo "Adding CUDA to CUDA_PATH, CUDA_PATH_X_Y and PATH" 53 | echo "CUDA_PATH=$($env:CUDA_PATH)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append 54 | echo "$($env:CUDA_PATH_VX_Y)=$($env:CUDA_PATH)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append 55 | echo "$($env:CUDA_PATH)/bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append 56 | } 57 | 58 | - name: Install CUDA on ubuntu-latest 59 | if: ${{ matrix.os == 'ubuntu-latest' }} 60 | shell: bash 61 | env: 62 | cuda: "11.6.2" 63 | cudnn: "8.4.0.27" 64 | run: | 65 | source ./.github/workflows/scripts/install_cuda_ubuntu.sh 66 | if [[ $? 
-eq 0 ]]; then 67 | # Set paths for subsequent steps, using ${CUDA_PATH} 68 | echo "Adding CUDA to CUDA_PATH, PATH and LD_LIBRARY_PATH" 69 | echo "CUDA_PATH=${CUDA_PATH}" >> $GITHUB_ENV 70 | echo "${CUDA_PATH}/bin" >> $GITHUB_PATH 71 | echo "LD_LIBRARY_PATH=${CUDA_PATH}/lib:${LD_LIBRARY_PATH}" >> $GITHUB_ENV 72 | fi 73 | 74 | - name: Install packages 75 | run: yarn upgrade --dev 76 | 77 | - name: Prebuild 78 | run: yarn build-prebuild 79 | 80 | - name: Check prebuild bundle size 81 | if: ${{ matrix.os == 'windows-latest' }} 82 | shell: powershell 83 | run: | 84 | ls prebuilds/@arition 85 | ls build/Release 86 | 87 | - name: Publish prebuild 88 | run: yarn build-prebuild -u ${{ secrets.GITHUB_TOKEN }} 89 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | project (torch-js) 3 | 4 | set(PYTORCH_VERSION 1.12.0) 5 | set(PYTORCH_VISION_VERSION 0.13.0) 6 | find_package(CUDA) 7 | if(CUDA_FOUND) 8 | if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch") 9 | if(WIN32) 10 | file(DOWNLOAD https://download.pytorch.org/libtorch/cu${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}/libtorch-win-shared-with-deps-${PYTORCH_VERSION}%2Bcu${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}.zip ${CMAKE_CURRENT_SOURCE_DIR}/libtorch-cu${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}.zip SHOW_PROGRESS) 11 | else() 12 | file(DOWNLOAD https://download.pytorch.org/libtorch/cu${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2Bcu${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}.zip ${CMAKE_CURRENT_SOURCE_DIR}/libtorch-cu${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}.zip SHOW_PROGRESS) 13 | endif() 14 | execute_process( 15 | COMMAND unzip -qq -o ${CMAKE_CURRENT_SOURCE_DIR}/libtorch-cu${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}.zip 16 | ) 17 | endif() 18 | else() 19 | if(NOT EXISTS 
"${CMAKE_CURRENT_SOURCE_DIR}/build/libtorch") 20 | if(WIN32) 21 | file(DOWNLOAD https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip ${CMAKE_CURRENT_SOURCE_DIR}/libtorch-cpu.zip) 22 | elseif(APPLE) 23 | file(DOWNLOAD https://download.pytorch.org/libtorch/cpu/libtorch-macos-${PYTORCH_VERSION}.zip ${CMAKE_CURRENT_SOURCE_DIR}/libtorch-cpu.zip) 24 | else() 25 | file(DOWNLOAD https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip ${CMAKE_CURRENT_SOURCE_DIR}/libtorch-cpu.zip) 26 | endif() 27 | execute_process( 28 | COMMAND unzip -qq -o ${CMAKE_CURRENT_SOURCE_DIR}/libtorch-cpu.zip 29 | ) 30 | endif() 31 | endif(CUDA_FOUND) 32 | 33 | # Use C++14 34 | set(CMAKE_CXX_STANDARD 14) 35 | 36 | if(NOT CMAKE_BUILD_TYPE) 37 | set(CMAKE_BUILD_TYPE Release) 38 | endif() 39 | 40 | set(CMAKE_PREFIX_PATH ${CMAKE_CURRENT_BINARY_DIR}/libtorch/) 41 | find_package(Torch REQUIRED) 42 | 43 | if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/torchvision") 44 | execute_process( 45 | COMMAND git clone --depth 1 --branch v${PYTORCH_VISION_VERSION} https://github.com/pytorch/vision 46 | ) 47 | execute_process( 48 | COMMAND git -C "${CMAKE_CURRENT_BINARY_DIR}/vision/" apply "${CMAKE_CURRENT_SOURCE_DIR}/patches/0001-Remove-native-image-support.patch" 49 | ) 50 | execute_process( 51 | COMMAND ${CMAKE_COMMAND} -DCMAKE_BUILD_TYPE=Release -DUSE_PYTHON=ON -DWITH_CUDA=${CUDA_FOUND} -DCMAKE_PREFIX_PATH=${CMAKE_CURRENT_BINARY_DIR}/libtorch -S "${CMAKE_CURRENT_BINARY_DIR}/vision/" -B "${CMAKE_CURRENT_BINARY_DIR}/vision/build/" 52 | ) 53 | execute_process( 54 | COMMAND ${CMAKE_COMMAND} --build "${CMAKE_CURRENT_BINARY_DIR}/vision/build" --config Release 55 | ) 56 | execute_process( 57 | COMMAND ${CMAKE_COMMAND} --install "${CMAKE_CURRENT_BINARY_DIR}/vision/build" --prefix "${CMAKE_CURRENT_BINARY_DIR}/torchvision" 58 | ) 59 | endif() 60 | 61 | set(CMAKE_PREFIX_PATH 
${CMAKE_PREFIX_PATH};${CMAKE_CURRENT_BINARY_DIR}/torchvision/) 62 | find_package(TorchVision) 63 | 64 | if(WIN32) 65 | elseif(APPLE) 66 | set(CMAKE_BUILD_RPATH @loader_path;@executable_path) 67 | else() 68 | set(CMAKE_BUILD_RPATH $ORIGIN) 69 | endif() 70 | 71 | include_directories(${CMAKE_JS_INC}) 72 | 73 | file(GLOB SOURCE_FILES "src/*.cc" "src/*.h") 74 | add_library(${PROJECT_NAME} SHARED ${SOURCE_FILES} ${CMAKE_JS_SRC}) 75 | set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" SUFFIX ".node") 76 | 77 | # Include N-API wrappers 78 | execute_process(COMMAND node -p "require('node-addon-api').include" 79 | WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} 80 | OUTPUT_VARIABLE NODE_ADDON_API_DIR 81 | ) 82 | string(REPLACE "\n" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR}) 83 | string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR}) 84 | target_include_directories(${PROJECT_NAME} PRIVATE ${NODE_ADDON_API_DIR}) 85 | add_definitions(-DNAPI_VERSION=3) 86 | 87 | target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES}) 88 | 89 | target_include_directories(${PROJECT_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/torchvision/include") 90 | target_link_libraries(${PROJECT_NAME} TorchVision::TorchVision) 91 | 92 | target_link_libraries(${PROJECT_NAME} ${CMAKE_JS_LIB}) 93 | 94 | if(WIN32) 95 | # use move for windows since the libs are too large to be copied on github actions vm 96 | add_custom_command(TARGET torch-js POST_BUILD 97 | COMMAND powershell mv 98 | "\"${CMAKE_CURRENT_BINARY_DIR}/libtorch/lib/*.dll\"" 99 | \"$\") 100 | add_custom_command(TARGET torch-js POST_BUILD 101 | COMMAND ${CMAKE_COMMAND} -E rm -rf 102 | "$/cudnn_adv_train64_8.dll" 103 | "$/cudnn_cnn_train64_8.dll" 104 | "$/cudnn_ops_train64_8.dll") 105 | else() 106 | add_custom_command(TARGET torch-js POST_BUILD 107 | COMMAND ${CMAKE_COMMAND} -E copy_directory 108 | "${CMAKE_CURRENT_BINARY_DIR}/libtorch/lib" 109 | $) 110 | endif() 111 | 112 | if(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/torchvision/bin") 113 | 
add_custom_command(TARGET torch-js POST_BUILD 114 | COMMAND ${CMAKE_COMMAND} -E copy_directory 115 | "${CMAKE_CURRENT_BINARY_DIR}/torchvision/bin" 116 | $) 117 | elseif(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/torchvision/lib") 118 | add_custom_command(TARGET torch-js POST_BUILD 119 | COMMAND ${CMAKE_COMMAND} -E copy_directory 120 | "${CMAKE_CURRENT_BINARY_DIR}/torchvision/lib" 121 | $) 122 | elseif(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/torchvision/lib64") 123 | add_custom_command(TARGET torch-js POST_BUILD 124 | COMMAND ${CMAKE_COMMAND} -E copy_directory 125 | "${CMAKE_CURRENT_BINARY_DIR}/torchvision/lib64" 126 | $) 127 | endif() 128 | -------------------------------------------------------------------------------- /.github/workflows/scripts/install_cuda_ubuntu.sh: -------------------------------------------------------------------------------- 1 | # @todo - better / more robust parsing of inputs from env vars. 2 | ## ------------------- 3 | ## Constants 4 | ## ------------------- 5 | 6 | # @todo - apt repos/known supported versions? 7 | 8 | # @todo - GCC support matrix? 9 | 10 | # List of sub-packages to install. 11 | # @todo - pass this in from outside the script? 12 | # @todo - check the specified subpackages exist via apt pre-install? apt-rdepends cuda-9-0 | grep "^cuda-"? 13 | 14 | # Ideally choose from the list of meta-packages to minimise variance between cuda versions (although it does change too) 15 | CUDA_PACKAGES_IN=( 16 | "command-line-tools" 17 | "libraries-dev" 18 | ) 19 | 20 | ## ------------------- 21 | ## Bash functions 22 | ## ------------------- 23 | # returns 0 (true) if a >= b 24 | function version_ge() { 25 | [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." && exit 1 26 | [ "$(printf '%s\n' "$@" | sort -V | head -n 1)" == "$2" ] 27 | } 28 | # returns 0 (true) if a > b 29 | function version_gt() { 30 | [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." 
&& exit 1 31 | [ "$1" = "$2" ] && return 1 || version_ge $1 $2 32 | } 33 | # returns 0 (true) if a <= b 34 | function version_le() { 35 | [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." && exit 1 36 | [ "$(printf '%s\n' "$@" | sort -V | head -n 1)" == "$1" ] 37 | } 38 | # returns 0 (true) if a < b 39 | function version_lt() { 40 | [ "$#" != "2" ] && echo "${FUNCNAME[0]} requires exactly 2 arguments." && exit 1 41 | [ "$1" = "$2" ] && return 1 || version_le $1 $2 42 | } 43 | 44 | ## ------------------- 45 | ## Select CUDA version 46 | ## ------------------- 47 | 48 | # Get the cuda version from the environment as $cuda. 49 | CUDA_VERSION_MAJOR_MINOR=${cuda} 50 | CUDNN_VERSION=${cudnn} 51 | 52 | # Split the version. 53 | # We (might/probably) don't know PATCH at this point - it depends which version gets installed. 54 | CUDA_MAJOR=$(echo "${CUDA_VERSION_MAJOR_MINOR}" | cut -d. -f1) 55 | CUDA_MINOR=$(echo "${CUDA_VERSION_MAJOR_MINOR}" | cut -d. -f2) 56 | CUDA_PATCH=$(echo "${CUDA_VERSION_MAJOR_MINOR}" | cut -d. -f3) 57 | # use lsb_release to find the OS. 58 | UBUNTU_VERSION=$(lsb_release -sr) 59 | UBUNTU_VERSION="${UBUNTU_VERSION//.}" 60 | 61 | echo "CUDA_MAJOR: ${CUDA_MAJOR}" 62 | echo "CUDA_MINOR: ${CUDA_MINOR}" 63 | echo "CUDA_PATCH: ${CUDA_PATCH}" 64 | # echo "UBUNTU_NAME: ${UBUNTU_NAME}" 65 | echo "UBUNTU_VERSION: ${UBUNTU_VERSION}" 66 | 67 | # If we don't know the CUDA_MAJOR or MINOR, error. 68 | if [ -z "${CUDA_MAJOR}" ] ; then 69 | echo "Error: Unknown CUDA Major version. Aborting." 70 | exit 1 71 | fi 72 | if [ -z "${CUDA_MINOR}" ] ; then 73 | echo "Error: Unknown CUDA Minor version. Aborting." 74 | exit 1 75 | fi 76 | # If we don't know the Ubuntu version, error. 77 | if [ -z ${UBUNTU_VERSION} ]; then 78 | echo "Error: Unknown Ubuntu version. Aborting." 79 | exit 1 80 | fi 81 | 82 | 83 | ## --------------------------- 84 | ## GCC studio support check? 
85 | ## --------------------------- 86 | 87 | # @todo 88 | 89 | ## ------------------------------- 90 | ## Select CUDA packages to install 91 | ## ------------------------------- 92 | CUDA_PACKAGES="" 93 | for package in "${CUDA_PACKAGES_IN[@]}" 94 | do : 95 | # @todo This is not perfect. Should probably provide a separate list for diff versions 96 | # cuda-compiler-X-Y if CUDA >= 9.1 else cuda-nvcc-X-Y 97 | if [[ "${package}" == "nvcc" ]] && version_ge "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then 98 | package="compiler" 99 | elif [[ "${package}" == "compiler" ]] && version_lt "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then 100 | package="nvcc" 101 | fi 102 | # Build the full package name and append to the string. 103 | CUDA_PACKAGES+=" cuda-${package}-${CUDA_MAJOR}-${CUDA_MINOR}" 104 | done 105 | echo "CUDA_PACKAGES ${CUDA_PACKAGES}" 106 | 107 | ## ----------------- 108 | ## Prepare to install 109 | ## ----------------- 110 | 111 | PIN_FILENAME="cuda-ubuntu${UBUNTU_VERSION}.pin" 112 | PIN_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/${PIN_FILENAME}" 113 | APT_KEY_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/7fa2af80.pub" 114 | APT_KEY_URL2="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/3bf863cc.pub" 115 | REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64" 116 | CUDNN_REPO_URL="https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu${UBUNTU_VERSION}/x86_64" 117 | 118 | echo "PIN_FILENAME ${PIN_FILENAME}" 119 | echo "PIN_URL ${PIN_URL}" 120 | echo "APT_KEY_URL ${APT_KEY_URL}" 121 | echo "APT_KEY_URL2 ${APT_KEY_URL2}" 122 | 123 | ## ----------------- 124 | ## Install 125 | ## ----------------- 126 | echo "Adding CUDA Repository" 127 | # wget ${PIN_URL} 128 | # sudo mv ${PIN_FILENAME} /etc/apt/preferences.d/cuda-repository-pin-600 129 | sudo apt-key adv --fetch-keys 
${APT_KEY_URL} 130 | sudo apt-key adv --fetch-keys ${APT_KEY_URL2} 131 | sudo add-apt-repository "deb ${REPO_URL} /" 132 | sudo add-apt-repository "deb ${CUDNN_REPO_URL} /" 133 | sudo apt-get update 134 | 135 | echo "Installing CUDA packages ${CUDA_PACKAGES}" 136 | sudo apt-get -y install ${CUDA_PACKAGES} 137 | 138 | if [[ $? -ne 0 ]]; then 139 | echo "CUDA Installation Error." 140 | exit 1 141 | fi 142 | ## ----------------- 143 | ## Set environment vars / vars to be propagated 144 | ## ----------------- 145 | 146 | CUDA_PATH=/usr/local/cuda-${CUDA_MAJOR}.${CUDA_MINOR} 147 | echo "CUDA_PATH=${CUDA_PATH}" 148 | export CUDA_PATH=${CUDA_PATH} 149 | 150 | 151 | # Quick test. @temp 152 | export PATH="$CUDA_PATH/bin:$PATH" 153 | export LD_LIBRARY_PATH="$CUDA_PATH/lib:$LD_LIBRARY_PATH" 154 | nvcc -V 155 | 156 | sudo apt-get install -y --no-install-recommends \ 157 | libcudnn8=${CUDNN_VERSION}-1+cuda${CUDA_MAJOR}.${CUDA_MINOR} \ 158 | libcudnn8-dev=${CUDNN_VERSION}-1+cuda${CUDA_MAJOR}.${CUDA_MINOR} 159 | -------------------------------------------------------------------------------- /.github/workflows/scripts/install_cuda_windows.ps1: -------------------------------------------------------------------------------- 1 | ## ------------------- 2 | ## Constants 3 | ## ------------------- 4 | 5 | # Dictionary of known cuda versions and thier download URLS, which do not follow a consistent pattern :( 6 | $CUDA_KNOWN_URLS = @{ 7 | "10.1.243" = "https://developer.download.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.243_win10_network.exe"; 8 | "10.2.89" = "https://developer.download.nvidia.com/compute/cuda/10.2/Prod/network_installers/cuda_10.2.89_win10_network.exe"; 9 | "11.0.2" = "https://developer.download.nvidia.com/compute/cuda/11.0.2/network_installers/cuda_11.0.2_win10_network.exe"; 10 | "11.1.1" = "https://developer.download.nvidia.com/compute/cuda/11.1.1/network_installers/cuda_11.1.1_win10_network.exe"; 11 | "11.3.1" = 
"https://developer.download.nvidia.com/compute/cuda/11.3.1/network_installers/cuda_11.3.1_win10_network.exe";
    "11.6.2" = "https://developer.download.nvidia.com/compute/cuda/11.6.2/network_installers/cuda_11.6.2_windows_network.exe";
}

# cuda_runtime.h is in nvcc <= 10.2, but cudart >= 11.0
# Package list: https://docs.nvidia.com/cuda/pdf/CUDA_Installation_Guide_Windows.pdf
# @todo - make this easier to vary per CUDA version.
$CUDA_PACKAGES_IN = @(
    "nvcc";
    "visual_studio_integration";
    "curand_dev";
    "cublas_dev";
    "cufft_dev";
    "cusolver_dev";
    "cusparse_dev";
    "nvrtc_dev";
    "cudart";
    "nsight_nvtx";
    "thrust";
)


## -------------------
## Select CUDA version
## -------------------

# Versions are supplied by the workflow via the environment ($env:cuda / $env:cudnn).
$CUDA_VERSION_FULL = $env:cuda
$CUDNN_VERSION_FULL = $env:cudnn

# Validate CUDA version, extracting components via regex.
# NOTE(review): the named capture groups had been stripped by a markup pass; they are
# restored here because the $Matches.major/.minor/.patch reads below depend on them.
$cuda_ver_matched = $CUDA_VERSION_FULL -match "^(?<major>[1-9][0-9]*)\.(?<minor>[0-9]+)\.(?<patch>[0-9]+)$"
if(-not $cuda_ver_matched){
    Write-Output "Invalid CUDA version specified, <major>.<minor>.<patch> required. '$CUDA_VERSION_FULL'."
    exit 1
}
$CUDA_MAJOR=$Matches.major
$CUDA_MINOR=$Matches.minor
$CUDA_PATCH=$Matches.patch

# Validate CUDNN version, extracting components via regex.
$cudnn_ver_matched = $CUDNN_VERSION_FULL -match "^(?<major>[0-9]+)\.(?<minor>[0-9]+)\.(?<patch>[0-9]+)\.(?<build>[0-9]+)$"
# BUG FIX: this previously re-tested $cuda_ver_matched, so an invalid cuDNN
# version string was never rejected.
if(-not $cudnn_ver_matched){
    Write-Output "Invalid CUDNN version specified, <major>.<minor>.<patch>.<build> required. '$CUDNN_VERSION_FULL'."
    exit 1
}
$CUDNN_MAJOR=$Matches.major
$CUDNN_MINOR=$Matches.minor
$CUDNN_PATCH=$Matches.patch
$CUDNN_BUILD=$Matches.build

## ------------------------------------------------
## Select CUDA packages to install from environment
## ------------------------------------------------

$CUDA_PACKAGES = ""

# for CUDA >= 11 cudart is a required package.
# if([version]$CUDA_VERSION_FULL -ge [version]"11.0") {
#     if(-not $CUDA_PACKAGES_IN -contains "cudart") {
#         $CUDA_PACKAGES_IN += 'cudart'
#     }
# }

Foreach ($package in $CUDA_PACKAGES_IN) {
    # The compiler sub-package was named "compiler" before CUDA 9.1 and "nvcc" from 9.1 on.
    if($package -eq "nvcc" -and [version]$CUDA_VERSION_FULL -lt [version]"9.1"){
        $package="compiler"
    } elseif($package -eq "compiler" -and [version]$CUDA_VERSION_FULL -ge [version]"9.1") {
        $package="nvcc"
    }
    # Installer sub-package names are versioned as <name>_<major>.<minor>.
    $CUDA_PACKAGES += " $($package)_$($CUDA_MAJOR).$($CUDA_MINOR)"
}
echo "$($CUDA_PACKAGES)"

## -----------------
## Prepare download
## -----------------

# Select the download link if known, otherwise have a guess.
$CUDA_REPO_PKG_REMOTE=""
if($CUDA_KNOWN_URLS.containsKey($CUDA_VERSION_FULL)){
    $CUDA_REPO_PKG_REMOTE=$CUDA_KNOWN_URLS[$CUDA_VERSION_FULL]
} else{
    # Guess what the url is given the most recent pattern (at the time of writing, 11).
    # BUG FIX: '${$CUDA_VERSION_FULL}' is not valid interpolation and expanded to nothing.
    Write-Output "note: URL for CUDA $($CUDA_VERSION_FULL) not known, estimating."
    $CUDA_REPO_PKG_REMOTE="https://developer.download.nvidia.com/compute/cuda/$($CUDA_VERSION_FULL)/network_installers/cuda_$($CUDA_VERSION_FULL)_windows_network.exe"
}
$TEMP_PATH = [System.IO.Path]::GetTempPath()
$CUDA_REPO_PKG_LOCAL = Join-Path $TEMP_PATH "cuda_$($CUDA_VERSION_FULL)_windows_network.exe"


## ------------
## Install CUDA
## ------------

# Get CUDA network installer
Write-Output "Downloading CUDA Network Installer for $($CUDA_VERSION_FULL) from: $($CUDA_REPO_PKG_REMOTE)"
Invoke-WebRequest $CUDA_REPO_PKG_REMOTE -OutFile $CUDA_REPO_PKG_LOCAL | Out-Null
if(Test-Path -Path $CUDA_REPO_PKG_LOCAL){
    Write-Output "Downloading Complete"
} else {
    Write-Output "Error: Failed to download $($CUDA_REPO_PKG_LOCAL) from $($CUDA_REPO_PKG_REMOTE)"
    exit 1
}

# Invoke silent install of CUDA (via network installer)
Write-Output "Installing CUDA $($CUDA_VERSION_FULL). Subpackages $($CUDA_PACKAGES)"
# BUG FIX: '$?' after Start-Process only reflects whether the cmdlet itself ran,
# not the installer's exit status; use -PassThru and inspect ExitCode instead.
$installer = Start-Process -Wait -PassThru -FilePath "$($CUDA_REPO_PKG_LOCAL)" -ArgumentList "-s $($CUDA_PACKAGES)"

# Check the return status of the CUDA installer.
if ($installer.ExitCode -ne 0) {
    Write-Output "Error: CUDA installer reported error. $($installer.ExitCode)"
    exit 1
}

# Store the CUDA_PATH in the environment for the current session, to be forwarded in the action.
$CUDA_PATH = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$($CUDA_MAJOR).$($CUDA_MINOR)"
$CUDA_PATH_VX_Y = "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)"

# PATH needs updating elsewhere, anything in here won't persist.
# Append $CUDA_PATH/bin to path.
133 | # Set CUDA_PATH as an environmental variable 134 | 135 | conda install -y -c conda-forge cudnn=$($CUDNN_VERSION_FULL) 136 | 137 | # Set environmental variables in this session 138 | $env:CUDA_PATH = "$($CUDA_PATH)" 139 | $env:CUDA_PATH_VX_Y = "$($CUDA_PATH_VX_Y)" 140 | Write-Output "CUDA_PATH $($CUDA_PATH)" 141 | Write-Output "CUDA_PATH_VX_Y $($CUDA_PATH_VX_Y)" 142 | -------------------------------------------------------------------------------- /src/Tensor.cc: -------------------------------------------------------------------------------- 1 | #include "Tensor.h" 2 | 3 | #include 4 | #include "constants.h" 5 | #include "utils.h" 6 | 7 | namespace torchjs 8 | { 9 | 10 | using namespace constants; 11 | 12 | namespace 13 | { 14 | template 15 | Napi::Value tensorToArray(Napi::Env env, const torch::Tensor &tensor) 16 | { 17 | try 18 | { 19 | Napi::EscapableHandleScope scope(env); 20 | assert(tensor.is_contiguous()); 21 | auto typed_array = Napi::TypedArrayOf::New(env, tensor.numel()); 22 | memcpy(typed_array.Data(), tensor.data_ptr(), sizeof(T) * tensor.numel()); 23 | auto shape_array = tensorShapeToArray(env, tensor); 24 | auto obj = Napi::Object::New(env); 25 | obj.Set(kData, typed_array); 26 | obj.Set(kShape, shape_array); 27 | return scope.Escape(obj); 28 | } 29 | catch (const std::exception &e) 30 | { 31 | throw Napi::Error::New(env, e.what()); 32 | } 33 | } 34 | 35 | template 36 | Napi::Value arrayToTensor(Napi::Env env, const Napi::TypedArray &data, 37 | const ShapeArrayType &shape_array) 38 | { 39 | try 40 | { 41 | Napi::EscapableHandleScope scope(env); 42 | auto *data_ptr = data.As>().Data(); 43 | auto shape = shapeArrayToVector(shape_array); 44 | torch::TensorOptions options(scalarType()); 45 | options = options.requires_grad(false); 46 | auto torch_tensor = torch::empty(shape, options); 47 | memcpy(torch_tensor.data_ptr(), data_ptr, sizeof(T) * torch_tensor.numel()); 48 | return scope.Escape(Tensor::FromTensor(env, torch_tensor)); 49 | } 50 | catch 
(const std::exception &e) 51 | { 52 | throw Napi::Error::New(env, e.what()); 53 | } 54 | } 55 | } // namespace 56 | 57 | Napi::Object Tensor::Init(Napi::Env env, Napi::Object exports) 58 | { 59 | Napi::Function func = DefineClass(env, "Tensor", 60 | { 61 | InstanceMethod("toString", &Tensor::toString), 62 | InstanceMethod("toObject", &Tensor::toObject), 63 | StaticMethod("fromObject", &Tensor::fromObject), 64 | InstanceMethod("cpu", &Tensor::cpu), 65 | InstanceMethod("cuda", &Tensor::cuda), 66 | InstanceMethod("free", &Tensor::free), 67 | InstanceMethod("clone", &Tensor::clone), 68 | }); 69 | 70 | constructor = Napi::Persistent(func); 71 | constructor.SuppressDestruct(); 72 | exports.Set("Tensor", func); 73 | return exports; 74 | } 75 | 76 | Napi::Object Tensor::FromTensor(Napi::Env env, const torch::Tensor &tensor) 77 | { 78 | try 79 | { 80 | Napi::EscapableHandleScope scope(env); 81 | auto obj = constructor.New({}); 82 | Napi::ObjectWrap::Unwrap(obj)->tensor_ = tensor; 83 | return scope.Escape(obj).ToObject(); 84 | } 85 | catch (const std::exception &e) 86 | { 87 | throw Napi::Error::New(env, e.what()); 88 | } 89 | } 90 | 91 | Tensor::Tensor(const Napi::CallbackInfo &info) : ObjectWrap(info), is_free(false) {} 92 | 93 | Napi::FunctionReference Tensor::constructor; 94 | 95 | torch::Tensor Tensor::tensor() 96 | { 97 | return tensor_; 98 | } 99 | 100 | bool Tensor::IsInstance(Napi::Object &obj) 101 | { 102 | return obj.InstanceOf(constructor.Value()); 103 | } 104 | 105 | Napi::Value Tensor::toString(const Napi::CallbackInfo &info) 106 | { 107 | try 108 | { 109 | return Napi::String::New(info.Env(), tensor_.toString()); 110 | } 111 | catch (const std::exception &e) 112 | { 113 | throw Napi::Error::New(info.Env(), e.what()); 114 | } 115 | } 116 | 117 | Napi::Value Tensor::toObject(const Napi::CallbackInfo &info) 118 | { 119 | try 120 | { 121 | auto env = info.Env(); 122 | auto st = tensor_.scalar_type(); 123 | switch (st) 124 | { 125 | case torch::ScalarType::Float: 
126 | return tensorToArray(env, tensor_); 127 | case torch::ScalarType::Double: 128 | return tensorToArray(env, tensor_); 129 | case torch::ScalarType::Int: 130 | return tensorToArray(env, tensor_); 131 | default: 132 | throw Napi::TypeError::New(env, "Unsupported type"); 133 | } 134 | } 135 | catch (const std::exception &e) 136 | { 137 | throw Napi::Error::New(info.Env(), e.what()); 138 | } 139 | } 140 | 141 | Napi::Value Tensor::fromObject(const Napi::CallbackInfo &info) 142 | { 143 | try 144 | { 145 | auto env = info.Env(); 146 | Napi::HandleScope scope(env); 147 | auto obj = info[0].As(); 148 | auto data = obj.Get(kData).As(); 149 | auto shape = obj.Get(kShape).As(); 150 | auto data_type = data.TypedArrayType(); 151 | switch (data_type) 152 | { 153 | case napi_float32_array: 154 | return arrayToTensor(env, data, shape); 155 | case napi_float64_array: 156 | return arrayToTensor(env, data, shape); 157 | case napi_int32_array: 158 | return arrayToTensor(env, data, shape); 159 | default: 160 | throw Napi::TypeError::New(env, "Unsupported type"); 161 | } 162 | } 163 | catch (const std::exception &e) 164 | { 165 | throw Napi::Error::New(info.Env(), e.what()); 166 | } 167 | } 168 | 169 | Napi::Value Tensor::clone(const Napi::CallbackInfo &info) 170 | { 171 | auto newTensor = tensor_.clone(); 172 | return Tensor::FromTensor(info.Env(), newTensor); 173 | } 174 | 175 | Napi::Value Tensor::cpu(const Napi::CallbackInfo &info) 176 | { 177 | try 178 | { 179 | if (tensor_.is_cuda()) 180 | { 181 | return FromTensor(info.Env(), tensor_.cpu()); 182 | } 183 | else 184 | { 185 | return info.This(); 186 | } 187 | } 188 | catch (const std::exception &e) 189 | { 190 | throw Napi::Error::New(info.Env(), e.what()); 191 | } 192 | } 193 | 194 | Napi::Value Tensor::cuda(const Napi::CallbackInfo &info) 195 | { 196 | try 197 | { 198 | if (torch::cuda::is_available()) 199 | { 200 | if (!tensor_.is_cuda()) 201 | { 202 | return FromTensor(info.Env(), tensor_.cuda()); 203 | } 204 | else 205 | { 
206 | return info.This(); 207 | } 208 | } 209 | else 210 | { 211 | throw Napi::Error::New(info.Env(), "CUDA is not aviliable"); 212 | } 213 | return info.This(); 214 | } 215 | catch (const std::exception &e) 216 | { 217 | throw Napi::Error::New(info.Env(), e.what()); 218 | } 219 | } 220 | 221 | void Tensor::Finalize(Napi::Env env) 222 | { 223 | if (!is_free) 224 | { 225 | tensor_.getIntrusivePtr()->release_resources(); 226 | is_free = true; 227 | } 228 | } 229 | 230 | void Tensor::free(const Napi::CallbackInfo &info) 231 | { 232 | if (!is_free) 233 | { 234 | tensor_.getIntrusivePtr()->release_resources(); 235 | is_free = true; 236 | } 237 | } 238 | 239 | } // namespace torchjs -------------------------------------------------------------------------------- /src/ScriptModule.cc: -------------------------------------------------------------------------------- 1 | #include "ScriptModule.h" 2 | 3 | #include 4 | #include "Tensor.h" 5 | #include "utils.h" 6 | #include "FunctionWorker.h" 7 | 8 | namespace torchjs 9 | { 10 | Napi::Object ScriptModule::Init(Napi::Env env, Napi::Object exports) 11 | { 12 | Napi::Function func = DefineClass(env, "ScriptModule", 13 | {InstanceMethod("forward", &ScriptModule::forward), 14 | InstanceMethod("toString", &ScriptModule::toString), 15 | InstanceMethod("cpu", &ScriptModule::cpu), 16 | InstanceMethod("cuda", &ScriptModule::cuda), 17 | StaticMethod("isCudaAvailable", &ScriptModule::isCudaAvailable)}); 18 | 19 | constructor = Napi::Persistent(func); 20 | constructor.SuppressDestruct(); 21 | exports.Set("ScriptModule", func); 22 | return exports; 23 | } 24 | 25 | ScriptModule::ScriptModule(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) 26 | { 27 | try 28 | { 29 | torch::NoGradGuard no_grad; 30 | Napi::HandleScope scope(info.Env()); 31 | Napi::String value = info[0].As(); 32 | path_ = value; 33 | module_ = torch::jit::load(value); 34 | } 35 | catch (const std::exception &e) 36 | { 37 | throw Napi::Error::New(info.Env(), 
e.what()); 38 | } 39 | } 40 | 41 | Napi::FunctionReference ScriptModule::constructor; 42 | 43 | Napi::Value ScriptModule::forward(const Napi::CallbackInfo &info) 44 | { 45 | try 46 | { 47 | torch::NoGradGuard no_grad; 48 | module_.eval(); 49 | 50 | auto len = info.Length(); 51 | auto inputs = std::vector(); 52 | for (size_t i = 0; i < len; ++i) 53 | { 54 | inputs.push_back(JSTypeToIValue(info.Env(), info[i])); 55 | } 56 | 57 | auto worker = new FunctionWorker( 58 | info.Env(), 59 | [=]() -> c10::IValue { 60 | torch::NoGradGuard no_grad; 61 | return module_.forward(inputs); 62 | }, 63 | [=](Napi::Env env, c10::IValue value) -> Napi::Value { 64 | return IValueToJSType(env, value); 65 | }); 66 | 67 | worker->Queue(); 68 | return worker->GetPromise(); 69 | } 70 | catch (const std::exception &e) 71 | { 72 | throw Napi::Error::New(info.Env(), e.what()); 73 | } 74 | } 75 | 76 | c10::IValue ScriptModule::JSTypeToIValue(Napi::Env env, const Napi::Value &jsValue) 77 | { 78 | Napi::HandleScope scope(env); 79 | if (jsValue.IsArray()) 80 | { 81 | auto jsList = jsValue.As(); 82 | auto len = jsList.Length(); 83 | if (len == 0) 84 | { 85 | throw Napi::Error::New(env, "Empty array is not supported"); 86 | } 87 | auto firstElement = JSTypeToIValue(env, jsList[(uint32_t)0]); 88 | c10::List list(firstElement.type()); 89 | for (uint32_t i = 1; i < len; ++i) 90 | { 91 | list.push_back(JSTypeToIValue(env, jsList[i])); 92 | } 93 | return list; 94 | } 95 | else if (jsValue.IsObject()) 96 | { 97 | auto jsObject = jsValue.As(); 98 | if (Tensor::IsInstance(jsObject)) 99 | { 100 | auto tensor = Napi::ObjectWrap::Unwrap(jsObject); 101 | return c10::IValue(tensor->tensor()); 102 | } 103 | throw Napi::Error::New(env, "Object/Dict input is not implemented"); 104 | } 105 | else if (jsValue.IsNumber()) 106 | { 107 | auto jsNumber = jsValue.As().DoubleValue(); 108 | return c10::IValue(jsNumber); 109 | } 110 | else if (jsValue.IsBoolean()) 111 | { 112 | auto jsBool = jsValue.As().Value(); 113 | return 
c10::IValue(jsBool); 114 | } 115 | else if (jsValue.IsString()) 116 | { 117 | auto jsString = jsValue.As().Utf8Value(); 118 | return c10::IValue(jsString); 119 | } 120 | throw Napi::Error::New(env, "Unsupported javascript input type"); 121 | } 122 | 123 | Napi::Value ScriptModule::IValueToJSType(Napi::Env env, const c10::IValue &iValue) 124 | { 125 | Napi::EscapableHandleScope scope(env); 126 | if (iValue.isTensor()) 127 | { 128 | return scope.Escape(Tensor::FromTensor(env, iValue.toTensor())); 129 | } 130 | else if (iValue.isList()) 131 | { 132 | auto list = iValue.toList(); 133 | auto jsList = Napi::Array::New(env); 134 | for (auto i = 0; i < list.size(); i++) 135 | { 136 | jsList[i] = IValueToJSType(env, list[i]); 137 | } 138 | return scope.Escape(jsList); 139 | } 140 | else if (iValue.isGenericDict()) 141 | { 142 | auto dict = iValue.toGenericDict(); 143 | auto jsDict = Napi::Object::New(env); 144 | for (auto iter = dict.begin(); iter != dict.end(); iter++) 145 | { 146 | auto key = IValueToJSType(env, iter->key()); 147 | auto value = IValueToJSType(env, iter->value()); 148 | jsDict.Set(key, value); 149 | } 150 | return scope.Escape(jsDict); 151 | } 152 | else if (iValue.isInt()) 153 | { 154 | return scope.Escape(Napi::Number::New(env, iValue.toInt())); 155 | } 156 | else if (iValue.isDouble()) 157 | { 158 | return scope.Escape(Napi::Number::New(env, iValue.toDouble())); 159 | } 160 | else if (iValue.isBool()) 161 | { 162 | return scope.Escape(Napi::Boolean::New(env, iValue.toBool())); 163 | } 164 | else if (iValue.isString()) 165 | { 166 | return scope.Escape(Napi::String::New(env, iValue.toString().get()->string())); 167 | } 168 | else if (iValue.isTuple()) 169 | { 170 | auto list = iValue.toTuple()->elements(); 171 | auto jsList = Napi::Array::New(env); 172 | for (auto i = 0; i < list.size(); i++) 173 | { 174 | jsList[i] = IValueToJSType(env, list[i]); 175 | } 176 | return scope.Escape(jsList); 177 | } 178 | throw Napi::Error::New(env, "Unsupported output 
type from ScriptModule"); 179 | } 180 | 181 | Napi::Value ScriptModule::toString(const Napi::CallbackInfo &info) 182 | { 183 | return Napi::String::New(info.Env(), "ScriptModule(\"" + path_ + "\")"); 184 | } 185 | 186 | Napi::Value ScriptModule::cpu(const Napi::CallbackInfo &info) 187 | { 188 | try 189 | { 190 | module_.to(at::kCPU); 191 | return info.This(); 192 | } 193 | catch (const std::exception &e) 194 | { 195 | throw Napi::Error::New(info.Env(), e.what()); 196 | } 197 | } 198 | 199 | Napi::Value ScriptModule::cuda(const Napi::CallbackInfo &info) 200 | { 201 | try 202 | { 203 | if (torch::cuda::is_available()) 204 | { 205 | module_.to(at::kCUDA); 206 | } 207 | else 208 | { 209 | throw Napi::Error::New(info.Env(), "CUDA is not aviliable"); 210 | } 211 | return info.This(); 212 | } 213 | catch (const std::exception &e) 214 | { 215 | throw Napi::Error::New(info.Env(), e.what()); 216 | } 217 | } 218 | 219 | Napi::Value ScriptModule::isCudaAvailable(const Napi::CallbackInfo &info) 220 | { 221 | try 222 | { 223 | return Napi::Boolean::New(info.Env(), torch::cuda::is_available()); 224 | } 225 | catch (const std::exception &e) 226 | { 227 | throw Napi::Error::New(info.Env(), e.what()); 228 | } 229 | } 230 | 231 | } // namespace torchjs --------------------------------------------------------------------------------