├── .editorconfig ├── .gitattributes ├── .github └── workflows │ └── npmpublish.yml ├── .gitignore ├── .npmignore ├── .prettierrc.json ├── CMakeLists.txt ├── LICENSE ├── NodeJS.cmake ├── README.md ├── binding.gyp ├── csrc ├── jit.cc ├── jit.h ├── script_module.cc ├── script_module.h ├── tchjs.cc ├── tensor.cc ├── tensor.h ├── utils.cc └── utils.h ├── package-lock.json ├── package.json ├── src ├── binding.ts ├── index.ts └── promisify.ts ├── test ├── binding.test.ts ├── data │ ├── matmul.pt │ └── trace.py └── tensor.test.ts ├── tsconfig.json └── types ├── node-pre-gyp.d.ts └── tch.d.ts /.editorconfig: -------------------------------------------------------------------------------- 1 | # editorconfig.org 2 | root = true 3 | 4 | [*] 5 | indent_size = 2 6 | indent_style = space 7 | end_of_line = lf 8 | charset = utf-8 9 | trim_trailing_whitespace = true 10 | insert_final_newline = true 11 | 12 | [*.md] 13 | trim_trailing_whitespace = false 14 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.cmake linguist-detectable=false 2 | *.ts linguist-detectable=true 3 | *.py linguist-detectable=false 4 | -------------------------------------------------------------------------------- /.github/workflows/npmpublish.yml: -------------------------------------------------------------------------------- 1 | name: Node.js Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | build-linux: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-node@v1 13 | with: 14 | node-version: 10 15 | - run: sudo apt-get update 16 | - run: sudo apt-get install -y cmake build-essential unzip curl 17 | - run: curl -o libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.7.1%2Bcpu.zip 18 | - run: unzip libtorch.zip 19 | - run: npm i --ignore-scripts 20 | - run: npm run pre-build 
21 | - run: npm test 22 | 23 | publish: 24 | needs: build-linux 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v2 28 | - uses: actions/setup-node@v1 29 | with: 30 | node-version: 10 31 | registry-url: https://registry.npmjs.org/ 32 | - run: npm publish 33 | env: 34 | NODE_AUTH_TOKEN: ${{secrets.npm_token}} 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | dist 3 | libtorch 4 | *.tgz 5 | *.log 6 | build* 7 | cmake-build* 8 | lib/binding 9 | .DS_Store 10 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | * 2 | !README.md 3 | !package.json 4 | !package-lock.json 5 | !LICENSE 6 | !lib/*.js 7 | -------------------------------------------------------------------------------- /.prettierrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "tabWidth": 2, 3 | "semi": true, 4 | "singleQuote": true, 5 | "useTabs": false, 6 | "printWidth": 100 7 | } 8 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.1 FATAL_ERROR) 2 | 3 | project(tchjs) 4 | 5 | # Node-cmake 6 | include(NodeJS.cmake) 7 | nodejs_init() 8 | 9 | file(GLOB SRC_FILES "${PROJECT_SOURCE_DIR}/csrc/*.cc" "${PROJECT_SOURCE_DIR}/csrc/*.h") 10 | add_nodejs_module(${PROJECT_NAME} ${SRC_FILES}) 11 | 12 | # Look for shared libs in the same directory 13 | IF (UNIX) 14 | set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -Wl,-rpath,$ORIGIN") 15 | set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-rpath,$ORIGIN") 16 | ENDIF(UNIX) 17 | 18 | # Include N-API wrappers 19 | target_include_directories(${PROJECT_NAME} PRIVATE
"${CMAKE_SOURCE_DIR}/node_modules/node-addon-api") 20 | 21 | find_package(Torch PATHS "./libtorch/" REQUIRED) 22 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 23 | set_target_properties(${PROJECT_NAME} PROPERTIES CXX_STANDARD 14) 24 | target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}") 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Libtorch copyrights: 2 | https://github.com/pytorch/pytorch/blob/master/LICENSE 3 | 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 2. Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 3. Neither the name of Eugene Ware nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY EUGENE WARE ''AS IS'' AND ANY 19 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL EUGENE WARE BE LIABLE FOR ANY 22 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 25 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /NodeJS.cmake: -------------------------------------------------------------------------------- 1 | # Defaults for standard Node.js builds 2 | set(NODEJS_DEFAULT_URL https://nodejs.org/download/release) 3 | set(NODEJS_DEFAULT_VERSION installed) 4 | set(NODEJS_VERSION_FALLBACK latest) 5 | set(NODEJS_DEFAULT_NAME node) 6 | set(NODEJS_DEFAULT_CHECKSUM SHASUMS256.txt) 7 | set(NODEJS_DEFAULT_CHECKTYPE SHA256) 8 | 9 | include(CMakeParseArguments) 10 | 11 | # Find a path by walking upward from a base directory until the path is 12 | # found. 
Sets the variable ${PATH} to False if the path can't 13 | # be determined 14 | function(find_path_parent NAME BASE PATH) 15 | set(ROOT ${BASE}) 16 | set(${PATH} ${ROOT}/${NAME} PARENT_SCOPE) 17 | set(DRIVE "^[A-Za-z]?:?/$") 18 | while(NOT ROOT MATCHES ${DRIVE} AND NOT EXISTS ${ROOT}/${NAME}) 19 | get_filename_component(ROOT ${ROOT} DIRECTORY) 20 | set(${PATH} ${ROOT}/${NAME} PARENT_SCOPE) 21 | endwhile() 22 | if(ROOT MATCHES ${DRIVE}) 23 | set(${PATH} False PARENT_SCOPE) 24 | endif() 25 | endfunction() 26 | 27 | # Shortcut for finding standard node module locations 28 | macro(find_nodejs_module NAME BASE PATH) 29 | find_path_parent(node_modules/${NAME} ${BASE} ${PATH}) 30 | endmacro() 31 | 32 | # Download with a bit of nice output (without spewing progress) 33 | function(download_file URL) 34 | message(STATUS "Downloading: ${URL}") 35 | file(APPEND ${TEMP}/download.log "Downloading: ${URL}\n") 36 | file(APPEND ${TEMP}/download.log "----------------------------------------\n") 37 | file(DOWNLOAD 38 | ${URL} 39 | ${ARGN} 40 | LOG DOWNLOAD_LOG 41 | ) 42 | file(APPEND ${TEMP}/download.log ${DOWNLOAD_LOG}) 43 | file(APPEND ${TEMP}/download.log "----------------------------------------\n") 44 | endfunction() 45 | 46 | # Embedded win_delay_load_hook file so that this file can be copied 47 | # into projects directly (recommended practice) 48 | function(nodejs_generate_delayload_hook OUTPUT) 49 | file(WRITE ${OUTPUT} "") 50 | file(APPEND ${OUTPUT} "/*\n") 51 | file(APPEND ${OUTPUT} " * When this file is linked to a DLL, it sets up a delay-load hook that\n") 52 | file(APPEND ${OUTPUT} " * intervenes when the DLL is trying to load the main node binary\n") 53 | file(APPEND ${OUTPUT} " * dynamically. 
Instead of trying to locate the .exe file it'll just return\n") 54 | file(APPEND ${OUTPUT} " * a handle to the process image.\n") 55 | file(APPEND ${OUTPUT} " *\n") 56 | file(APPEND ${OUTPUT} " * This allows compiled addons to work when node.exe or iojs.exe is renamed.\n") 57 | file(APPEND ${OUTPUT} " */\n") 58 | file(APPEND ${OUTPUT} "\n") 59 | file(APPEND ${OUTPUT} "#ifdef _MSC_VER\n") 60 | file(APPEND ${OUTPUT} "\n") 61 | file(APPEND ${OUTPUT} "#ifndef DELAYIMP_INSECURE_WRITABLE_HOOKS\n") 62 | file(APPEND ${OUTPUT} "#define DELAYIMP_INSECURE_WRITABLE_HOOKS\n") 63 | file(APPEND ${OUTPUT} "#endif\n") 64 | file(APPEND ${OUTPUT} "\n") 65 | file(APPEND ${OUTPUT} "#ifndef WIN32_LEAN_AND_MEAN\n") 66 | file(APPEND ${OUTPUT} "#define WIN32_LEAN_AND_MEAN\n") 67 | file(APPEND ${OUTPUT} "#endif\n") 68 | file(APPEND ${OUTPUT} "\n") 69 | file(APPEND ${OUTPUT} "#include \n") 70 | file(APPEND ${OUTPUT} "#include \n") 71 | file(APPEND ${OUTPUT} "#include \n") 72 | file(APPEND ${OUTPUT} "#include \n") 73 | file(APPEND ${OUTPUT} "#include \n") 74 | file(APPEND ${OUTPUT} "\n") 75 | file(APPEND ${OUTPUT} "static FARPROC WINAPI load_exe_hook(unsigned int event, DelayLoadInfo* info) {\n") 76 | file(APPEND ${OUTPUT} " if (event != dliNotePreLoadLibrary) return NULL;\n") 77 | file(APPEND ${OUTPUT} "\n") 78 | file(APPEND ${OUTPUT} " if (_stricmp(info->szDll, \"iojs.exe\") != 0 &&\n") 79 | file(APPEND ${OUTPUT} " _stricmp(info->szDll, \"node.exe\") != 0 &&\n") 80 | file(APPEND ${OUTPUT} " _stricmp(info->szDll, \"node.dll\") != 0)\n") 81 | file(APPEND ${OUTPUT} " return NULL;\n") 82 | file(APPEND ${OUTPUT} "\n") 83 | file(APPEND ${OUTPUT} " // Get a handle to the current process executable.\n") 84 | file(APPEND ${OUTPUT} " HMODULE processModule = GetModuleHandle(NULL);\n") 85 | file(APPEND ${OUTPUT} "\n") 86 | file(APPEND ${OUTPUT} " // Get the path to the executable.\n") 87 | file(APPEND ${OUTPUT} " TCHAR processPath[_MAX_PATH];\n") 88 | file(APPEND ${OUTPUT} " 
GetModuleFileName(processModule, processPath, _MAX_PATH);\n") 89 | file(APPEND ${OUTPUT} "\n") 90 | file(APPEND ${OUTPUT} " // Get the name of the current executable.\n") 91 | file(APPEND ${OUTPUT} " LPTSTR processName = PathFindFileName(processPath);\n") 92 | file(APPEND ${OUTPUT} "\n") 93 | file(APPEND ${OUTPUT} " // If the current process is node or iojs, then just return the proccess \n") 94 | file(APPEND ${OUTPUT} " // module.\n") 95 | file(APPEND ${OUTPUT} " if (_tcsicmp(processName, TEXT(\"node.exe\")) == 0 ||\n") 96 | file(APPEND ${OUTPUT} " _tcsicmp(processName, TEXT(\"iojs.exe\")) == 0) {\n") 97 | file(APPEND ${OUTPUT} " return (FARPROC) processModule;\n") 98 | file(APPEND ${OUTPUT} " }\n") 99 | file(APPEND ${OUTPUT} "\n") 100 | file(APPEND ${OUTPUT} " // If it is another process, attempt to load 'node.dll' from the same \n") 101 | file(APPEND ${OUTPUT} " // directory.\n") 102 | file(APPEND ${OUTPUT} " PathRemoveFileSpec(processPath);\n") 103 | file(APPEND ${OUTPUT} " PathAppend(processPath, TEXT(\"node.dll\"));\n") 104 | file(APPEND ${OUTPUT} "\n") 105 | file(APPEND ${OUTPUT} " HMODULE nodeDllModule = GetModuleHandle(processPath);\n") 106 | file(APPEND ${OUTPUT} " if(nodeDllModule != NULL) {\n") 107 | file(APPEND ${OUTPUT} " // This application has a node.dll in the same directory as the executable,\n") 108 | file(APPEND ${OUTPUT} " // use that.\n") 109 | file(APPEND ${OUTPUT} " return (FARPROC) nodeDllModule;\n") 110 | file(APPEND ${OUTPUT} " }\n") 111 | file(APPEND ${OUTPUT} "\n") 112 | file(APPEND ${OUTPUT} " // Fallback to the current executable, which must statically link to \n") 113 | file(APPEND ${OUTPUT} " // node.lib\n") 114 | file(APPEND ${OUTPUT} " return (FARPROC) processModule;\n") 115 | file(APPEND ${OUTPUT} "}\n") 116 | file(APPEND ${OUTPUT} "\n") 117 | file(APPEND ${OUTPUT} "PfnDliHook __pfnDliNotifyHook2 = load_exe_hook;\n") 118 | file(APPEND ${OUTPUT} "\n") 119 | file(APPEND ${OUTPUT} "#endif\n") 120 | endfunction() 121 | 122 | # Sets 
up a project to build Node.js native modules 123 | # - Downloads required dependencies and unpacks them to the build directory. 124 | # Internet access is required the first invocation but not after ( 125 | # provided the download is successful) 126 | # - Sets up several variables for building against the downloaded 127 | # dependencies 128 | # - Guarded to prevent multiple executions, so a single project hierarchy 129 | # will only call this once 130 | function(nodejs_init) 131 | # Prevents this function from executing more than once 132 | if(NODEJS_INIT) 133 | return() 134 | endif() 135 | 136 | # Regex patterns used by the init function for component extraction 137 | set(HEADERS_MATCH "^([A-Fa-f0-9]+)[ \t]+([^-]+)-(headers|v?[0-9.]+)-(headers|v?[0-9.]+)([.]tar[.]gz)$") 138 | set(LIB32_MATCH "(^[0-9A-Fa-f]+)[\t ]+(win-x86)?(/)?([^/]*)(.lib)$") 139 | set(LIB64_MATCH "(^[0-9A-Fa-f]+)[\t ]+(win-)?(x64/)(.*)(.lib)$") 140 | 141 | # Parse function arguments 142 | cmake_parse_arguments(nodejs_init 143 | "" "URL;NAME;VERSION;CHECKSUM;CHECKTYPE" "" ${ARGN} 144 | ) 145 | 146 | # Allow the download URL to be overridden by command line argument 147 | # NODEJS_URL 148 | if(NODEJS_URL) 149 | set(URL ${NODEJS_URL}) 150 | else() 151 | # Use the argument if specified, falling back to the default 152 | set(URL ${NODEJS_DEFAULT_URL}) 153 | if(nodejs_init_URL) 154 | set(URL ${nodejs_init_URL}) 155 | endif() 156 | endif() 157 | 158 | # Allow name to be overridden by command line argument NODEJS_NAME 159 | if(NODEJS_NAME) 160 | set(NAME ${NODEJS_NAME}) 161 | else() 162 | # Use the argument if specified, falling back to the default 163 | set(NAME ${NODEJS_DEFAULT_NAME}) 164 | if(nodejs_init_NAME) 165 | set(NAME ${nodejs_init_NAME}) 166 | endif() 167 | endif() 168 | 169 | # Allow the checksum file to be overridden by command line argument 170 | # NODEJS_CHECKSUM 171 | if(NODEJS_CHECKSUM) 172 | set(CHECKSUM ${NODEJS_CHECKSUM}) 173 | else() 174 | # Use the argument if specified, falling 
back to the default 175 | set(CHECKSUM ${NODEJS_DEFAULT_CHECKSUM}) 176 | if(nodejs_init_CHECKSUM) 177 | set(CHECKSUM ${nodejs_init_CHECKSUM}) 178 | endif() 179 | endif() 180 | 181 | # Allow the checksum type to be overriden by the command line argument 182 | # NODEJS_CHECKTYPE 183 | if(NODEJS_CHECKTYPE) 184 | set(CHECKTYPE ${NODEJS_CHECKTYPE}) 185 | else() 186 | # Use the argument if specified, falling back to the default 187 | set(CHECKTYPE ${NODEJS_DEFAULT_CHECKTYPE}) 188 | if(nodejs_init_CHECKTYPE) 189 | set(CHECKTYPE ${nodejs_init_CHECKTYPE}) 190 | endif() 191 | endif() 192 | 193 | # Allow the version to be overridden by the command line argument 194 | # NODEJS_VERSION 195 | if(NODEJS_VERSION) 196 | set(VERSION ${NODEJS_VERSION}) 197 | else() 198 | # Use the argument if specified, falling back to the default 199 | set(VERSION ${NODEJS_DEFAULT_VERSION}) 200 | if(nodejs_init_VERSION) 201 | set(VERSION ${nodejs_init_VERSION}) 202 | endif() 203 | endif() 204 | 205 | # "installed" is a special version that tries to use the currently 206 | # installed version (determined by running node) 207 | set(NODEJS_INSTALLED False CACHE BOOL "Node.js install status" FORCE) 208 | if(VERSION STREQUAL "installed") 209 | if(NOT NAME STREQUAL ${NODEJS_DEFAULT_NAME}) 210 | message(FATAL_ERROR 211 | "'Installed' version identifier can only be used with" 212 | "the core Node.js library" 213 | ) 214 | endif() 215 | # Fall back to the "latest" version if node isn't installed 216 | set(VERSION ${NODEJS_VERSION_FALLBACK}) 217 | # This has all of the implications of why the binary is called nodejs in the first place 218 | # https://lists.debian.org/debian-devel-announce/2012/07/msg00002.html 219 | # However, with nvm/n, its nearly standard to have a proper 'node' binary now (since the 220 | # apt-based one is so out of date), so for now just assume that this rare binary conflict 221 | # case is the degenerate case. May need a more complicated solution later. 
222 | find_program(NODEJS_BINARY NAMES node nodejs) 223 | if(NODEJS_BINARY) 224 | execute_process( 225 | COMMAND ${NODEJS_BINARY} --version 226 | RESULT_VARIABLE INSTALLED_VERSION_RESULT 227 | OUTPUT_VARIABLE INSTALLED_VERSION 228 | OUTPUT_STRIP_TRAILING_WHITESPACE 229 | ) 230 | if(INSTALLED_VERSION_RESULT STREQUAL "0") 231 | set(NODEJS_INSTALLED True CACHE BOOL 232 | "Node.js install status" FORCE 233 | ) 234 | set(VERSION ${INSTALLED_VERSION}) 235 | endif() 236 | endif() 237 | endif() 238 | 239 | # Create a temporary download directory 240 | set(TEMP ${CMAKE_CURRENT_BINARY_DIR}/temp) 241 | if(EXISTS ${TEMP}) 242 | file(REMOVE_RECURSE ${TEMP}) 243 | endif() 244 | file(MAKE_DIRECTORY ${TEMP}) 245 | 246 | # Unless the target is special version "latest", the parameters 247 | # necessary to construct the root path are known 248 | if(NOT VERSION STREQUAL "latest") 249 | set(ROOT ${CMAKE_CURRENT_BINARY_DIR}/${NAME}/${VERSION}) 250 | # Extract checksums from the existing checksum file 251 | set(CHECKSUM_TARGET ${ROOT}/CHECKSUM) 252 | endif() 253 | 254 | # If we're trying to determine the version or we haven't saved the 255 | # checksum file for this version, download it from the specified server 256 | if(VERSION STREQUAL "latest" OR 257 | (DEFINED ROOT AND NOT EXISTS ${ROOT}/CHECKSUM)) 258 | if(DEFINED ROOT) 259 | # Clear away the old checksum in case the new one is different 260 | # and/or it fails to download 261 | file(REMOVE ${ROOT}/CHECKSUM) 262 | endif() 263 | file(REMOVE ${TEMP}/CHECKSUM) 264 | download_file( 265 | ${URL}/${VERSION}/${CHECKSUM} 266 | ${TEMP}/CHECKSUM 267 | INACTIVITY_TIMEOUT 10 268 | STATUS CHECKSUM_STATUS 269 | ) 270 | list(GET CHECKSUM_STATUS 0 CHECKSUM_STATUS) 271 | if(CHECKSUM_STATUS GREATER 0) 272 | file(REMOVE ${TEMP}/CHECKSUM) 273 | message(FATAL_ERROR 274 | "Unable to download checksum file" 275 | ) 276 | endif() 277 | # Extract checksums from the temporary file 278 | set(CHECKSUM_TARGET ${TEMP}/CHECKSUM) 279 | endif() 280 | 281 | # 
Extract the version, name, header archive and archive checksum 282 | # from the file. This first extract is what defines / specifies the 283 | # actual version number and name. 284 | file(STRINGS 285 | ${CHECKSUM_TARGET} HEADERS_CHECKSUM 286 | REGEX ${HEADERS_MATCH} 287 | LIMIT_COUNT 1 288 | ) 289 | if(NOT HEADERS_CHECKSUM) 290 | file(REMOVE ${TEMP}/CHECKSUM) 291 | if(DEFINED ROOT) 292 | file(REMOVE ${ROOT}/CHECKSUM) 293 | endif() 294 | message(FATAL_ERROR "Unable to extract header archive checksum") 295 | endif() 296 | string(REGEX MATCH ${HEADERS_MATCH} HEADERS_CHECKSUM ${HEADERS_CHECKSUM}) 297 | set(HEADERS_CHECKSUM ${CMAKE_MATCH_1}) 298 | set(NAME ${CMAKE_MATCH_2}) 299 | if(CMAKE_MATCH_3 STREQUAL "headers") 300 | set(VERSION ${CMAKE_MATCH_4}) 301 | else() 302 | set(VERSION ${CMAKE_MATCH_3}) 303 | endif() 304 | set(HEADERS_ARCHIVE 305 | ${CMAKE_MATCH_2}-${CMAKE_MATCH_3}-${CMAKE_MATCH_4}${CMAKE_MATCH_5} 306 | ) 307 | # Make sure that the root directory exists, and that the checksum 308 | # file has been moved over from temp 309 | if(DEFINED ROOT) 310 | set(OLD_ROOT ${ROOT}) 311 | endif() 312 | set(ROOT ${CMAKE_CURRENT_BINARY_DIR}/${NAME}/${VERSION}) 313 | if(DEFINED OLD_ROOT AND NOT ROOT STREQUAL "${OLD_ROOT}") 314 | file(REMOVE ${TEMP}/CHECKSUM) 315 | file(REMOVE ${ROOT}/CHECKSUM) 316 | message(FATAL_ERROR "Version/Name mismatch") 317 | endif() 318 | file(MAKE_DIRECTORY ${ROOT}) 319 | if(EXISTS ${TEMP}/CHECKSUM) 320 | file(REMOVE ${ROOT}/CHECKSUM) 321 | file(RENAME ${TEMP}/CHECKSUM ${ROOT}/CHECKSUM) 322 | endif() 323 | 324 | # Now that its fully resolved, report the name and version of Node.js being 325 | # used 326 | message(STATUS "NodeJS: Using ${NAME}, version ${VERSION}") 327 | 328 | # Download the headers for the version being used 329 | # Theoretically, these could be found by searching the installed 330 | # system, but in practice, this can be error prone. They're provided 331 | # on the download servers, so just use the ones there. 
332 | if(NOT EXISTS ${ROOT}/include) 333 | file(REMOVE ${TEMP}/${HEADERS_ARCHIVE}) 334 | download_file( 335 | ${URL}/${VERSION}/${HEADERS_ARCHIVE} 336 | ${TEMP}/${HEADERS_ARCHIVE} 337 | INACTIVITY_TIMEOUT 10 338 | EXPECTED_HASH ${CHECKTYPE}=${HEADERS_CHECKSUM} 339 | STATUS HEADERS_STATUS 340 | ) 341 | list(GET HEADERS_STATUS 0 HEADERS_STATUS) 342 | if(HEADERS_STATUS GREATER 0) 343 | file(REMOVE ${TEMP}/${HEADERS_ARCHIVE}) 344 | message(FATAL_ERROR "Unable to download Node.js headers") 345 | endif() 346 | execute_process( 347 | COMMAND ${CMAKE_COMMAND} -E tar xfz ${TEMP}/${HEADERS_ARCHIVE} 348 | WORKING_DIRECTORY ${TEMP} 349 | ) 350 | 351 | # This adapts the header extraction to support a number of different 352 | # header archive contents in addition to the one used by the 353 | # default Node.js library 354 | unset(NODEJS_HEADERS_PATH CACHE) 355 | find_path(NODEJS_HEADERS_PATH 356 | NAMES src include 357 | PATHS 358 | ${TEMP}/${NAME}-${VERSION}-headers 359 | ${TEMP}/${NAME}-${VERSION} 360 | ${TEMP}/${NODEJS_DEFAULT_NAME}-${VERSION}-headers 361 | ${TEMP}/${NODEJS_DEFAULT_NAME}-${VERSION} 362 | ${TEMP}/${NODEJS_DEFAULT_NAME} 363 | ${TEMP} 364 | NO_DEFAULT_PATH 365 | ) 366 | if(NOT NODEJS_HEADERS_PATH) 367 | message(FATAL_ERROR "Unable to find extracted headers directory") 368 | endif() 369 | 370 | # Move the headers into a standard location with a standard layout 371 | file(REMOVE ${TEMP}/${HEADERS_ARCHIVE}) 372 | file(REMOVE_RECURSE ${ROOT}/include) 373 | if(EXISTS ${NODEJS_HEADERS_PATH}/include/node) 374 | file(RENAME ${NODEJS_HEADERS_PATH}/include/node ${ROOT}/include) 375 | elseif(EXISTS ${NODEJS_HEADERS_PATH}/src) 376 | file(MAKE_DIRECTORY ${ROOT}/include) 377 | if(NOT EXISTS ${NODEJS_HEADERS_PATH}/src) 378 | file(REMOVE_RECURSE ${ROOT}/include) 379 | message(FATAL_ERROR "Unable to find core headers") 380 | endif() 381 | file(COPY ${NODEJS_HEADERS_PATH}/src/ 382 | DESTINATION ${ROOT}/include 383 | ) 384 | if(NOT EXISTS ${NODEJS_HEADERS_PATH}/deps/uv/include) 385
| file(REMOVE_RECURSE ${ROOT}/include) 386 | message(FATAL_ERROR "Unable to find libuv headers") 387 | endif() 388 | file(COPY ${NODEJS_HEADERS_PATH}/deps/uv/include/ 389 | DESTINATION ${ROOT}/include 390 | ) 391 | if(NOT EXISTS ${NODEJS_HEADERS_PATH}/deps/v8/include) 392 | file(REMOVE_RECURSE ${ROOT}/include) 393 | message(FATAL_ERROR "Unable to find v8 headers") 394 | endif() 395 | file(COPY ${NODEJS_HEADERS_PATH}/deps/v8/include/ 396 | DESTINATION ${ROOT}/include 397 | ) 398 | if(NOT EXISTS ${NODEJS_HEADERS_PATH}/deps/zlib) 399 | file(REMOVE_RECURSE ${ROOT}/include) 400 | message(FATAL_ERROR "Unable to find zlib headers") 401 | endif() 402 | file(COPY ${NODEJS_HEADERS_PATH}/deps/zlib/ 403 | DESTINATION ${ROOT}/include 404 | ) 405 | endif() 406 | file(REMOVE_RECURSE ${NODEJS_HEADERS_PATH}) 407 | unset(NODEJS_HEADERS_PATH CACHE) 408 | endif() 409 | 410 | # Only download the libraries on windows, since its the only place 411 | # its necessary. Note, this requires rerunning CMake if moving 412 | # a module from one platform to another (should happen automatically 413 | # with most generators) 414 | if(WIN32) 415 | # Download the win32 library for linking 416 | file(STRINGS 417 | ${ROOT}/CHECKSUM LIB32_CHECKSUM 418 | LIMIT_COUNT 1 419 | REGEX ${LIB32_MATCH} 420 | ) 421 | if(NOT LIB32_CHECKSUM) 422 | message(FATAL_ERROR "Unable to extract x86 library checksum") 423 | endif() 424 | string(REGEX MATCH ${LIB32_MATCH} LIB32_CHECKSUM ${LIB32_CHECKSUM}) 425 | set(LIB32_CHECKSUM ${CMAKE_MATCH_1}) 426 | set(LIB32_PATH win-x86) 427 | set(LIB32_NAME ${CMAKE_MATCH_4}${CMAKE_MATCH_5}) 428 | set(LIB32_TARGET ${CMAKE_MATCH_2}${CMAKE_MATCH_3}${LIB32_NAME}) 429 | if(NOT EXISTS ${ROOT}/${LIB32_PATH}) 430 | file(REMOVE_RECURSE ${TEMP}/${LIB32_PATH}) 431 | download_file( 432 | ${URL}/${VERSION}/${LIB32_TARGET} 433 | ${TEMP}/${LIB32_PATH}/${LIB32_NAME} 434 | INACTIVITY_TIMEOUT 10 435 | EXPECTED_HASH ${CHECKTYPE}=${LIB32_CHECKSUM} 436 | STATUS LIB32_STATUS 437 | ) 438 | list(GET 
LIB32_STATUS 0 LIB32_STATUS) 439 | if(LIB32_STATUS GREATER 0) 440 | message(FATAL_ERROR 441 | "Unable to download Node.js windows library (32-bit)" 442 | ) 443 | endif() 444 | file(REMOVE_RECURSE ${ROOT}/${LIB32_PATH}) 445 | file(MAKE_DIRECTORY ${ROOT}/${LIB32_PATH}) 446 | file(RENAME 447 | ${TEMP}/${LIB32_PATH}/${LIB32_NAME} 448 | ${ROOT}/${LIB32_PATH}/${LIB32_NAME} 449 | ) 450 | file(REMOVE_RECURSE ${TEMP}/${LIB32_PATH}) 451 | endif() 452 | 453 | # Download the win64 library for linking 454 | file(STRINGS 455 | ${ROOT}/CHECKSUM LIB64_CHECKSUM 456 | LIMIT_COUNT 1 457 | REGEX ${LIB64_MATCH} 458 | ) 459 | if(NOT LIB64_CHECKSUM) 460 | message(FATAL_ERROR "Unable to extract x64 library checksum") 461 | endif() 462 | string(REGEX MATCH ${LIB64_MATCH} LIB64_CHECKSUM ${LIB64_CHECKSUM}) 463 | set(LIB64_CHECKSUM ${CMAKE_MATCH_1}) 464 | set(LIB64_PATH win-x64) 465 | set(LIB64_NAME ${CMAKE_MATCH_4}${CMAKE_MATCH_5}) 466 | set(LIB64_TARGET ${CMAKE_MATCH_2}${CMAKE_MATCH_3}${LIB64_NAME}) 467 | if(NOT EXISTS ${ROOT}/${LIB64_PATH}) 468 | file(REMOVE_RECURSE ${TEMP}/${LIB64_PATH}) 469 | download_file( 470 | ${URL}/${VERSION}/${LIB64_TARGET} 471 | ${TEMP}/${LIB64_PATH}/${LIB64_NAME} 472 | INACTIVITY_TIMEOUT 10 473 | EXPECTED_HASH ${CHECKTYPE}=${LIB64_CHECKSUM} 474 | STATUS LIB64_STATUS 475 | ) 476 | list(GET LIB64_STATUS 0 LIB64_STATUS) 477 | if(LIB64_STATUS GREATER 0) 478 | message(FATAL_ERROR 479 | "Unable to download Node.js windows library (64-bit)" 480 | ) 481 | endif() 482 | file(REMOVE_RECURSE ${ROOT}/${LIB64_PATH}) 483 | file(MAKE_DIRECTORY ${ROOT}/${LIB64_PATH}) 484 | file(RENAME 485 | ${TEMP}/${LIB64_PATH}/${LIB64_NAME} 486 | ${ROOT}/${LIB64_PATH}/${LIB64_NAME} 487 | ) 488 | file(REMOVE_RECURSE ${TEMP}/${LIB64_PATH}) 489 | endif() 490 | endif() 491 | 492 | # The downloaded headers should always be set for inclusion 493 | list(APPEND INCLUDE_DIRS ${ROOT}/include) 494 | 495 | # Look for the NAN module, and add it to the includes 496 | find_nodejs_module( 497 | nan 498 | 
${CMAKE_CURRENT_SOURCE_DIR} 499 | NODEJS_NAN_DIR 500 | ) 501 | if(NODEJS_NAN_DIR) 502 | list(APPEND INCLUDE_DIRS ${NODEJS_NAN_DIR}) 503 | endif() 504 | 505 | # Under windows, we need a bunch of libraries (due to the way 506 | # dynamic linking works) 507 | if(WIN32) 508 | # Generate and use a delay load hook to allow the node binary 509 | # name to be changed while still loading native modules 510 | set(DELAY_LOAD_HOOK ${CMAKE_CURRENT_BINARY_DIR}/win_delay_load_hook.c) 511 | nodejs_generate_delayload_hook(${DELAY_LOAD_HOOK}) 512 | set(SOURCES ${DELAY_LOAD_HOOK}) 513 | 514 | # Necessary flags to get delayload working correctly 515 | list(APPEND LINK_FLAGS 516 | "-IGNORE:4199" 517 | "-DELAYLOAD:iojs.exe" 518 | "-DELAYLOAD:node.exe" 519 | "-DELAYLOAD:node.dll" 520 | ) 521 | 522 | # Core system libraries used by node 523 | list(APPEND LIBRARIES 524 | kernel32.lib user32.lib gdi32.lib winspool.lib comdlg32.lib 525 | advapi32.lib shell32.lib ole32.lib oleaut32.lib uuid.lib 526 | odbc32.lib Shlwapi.lib DelayImp.lib 527 | ) 528 | 529 | # Also link to the node stub itself (downloaded above) 530 | if(CMAKE_CL_64) 531 | list(APPEND LIBRARIES ${ROOT}/${LIB64_PATH}/${LIB64_NAME}) 532 | else() 533 | list(APPEND LIBRARIES ${ROOT}/${LIB32_PATH}/${LIB32_NAME}) 534 | endif() 535 | else() 536 | # Non-windows platforms should use these flags 537 | list(APPEND DEFINITIONS _LARGEFILE_SOURCE _FILE_OFFSET_BITS=64) 538 | endif() 539 | 540 | # Special handling for OSX / clang to allow undefined symbols 541 | # Define is required by node on OSX 542 | if(APPLE) 543 | list(APPEND LINK_FLAGS "-undefined dynamic_lookup") 544 | list(APPEND DEFINITIONS _DARWIN_USE_64_BIT_INODE=1) 545 | endif() 546 | 547 | # Export all settings for use as arguments in the rest of the build 548 | set(NODEJS_VERSION ${VERSION} PARENT_SCOPE) 549 | set(NODEJS_SOURCES ${SOURCES} PARENT_SCOPE) 550 | set(NODEJS_INCLUDE_DIRS ${INCLUDE_DIRS} PARENT_SCOPE) 551 | set(NODEJS_LIBRARIES ${LIBRARIES} PARENT_SCOPE) 552 | 
set(NODEJS_LINK_FLAGS ${LINK_FLAGS} PARENT_SCOPE) 553 | set(NODEJS_DEFINITIONS ${DEFINITIONS} PARENT_SCOPE) 554 | 555 | # Prevents this function from executing more than once 556 | set(NODEJS_INIT TRUE PARENT_SCOPE) 557 | endfunction() 558 | 559 | # Helper function for defining a node module 560 | # After nodejs_init, all of the settings and dependencies necessary to do 561 | # this yourself are defined, but this helps make sure everything is configured 562 | # correctly. Feel free to use it as a model to do this by hand (or to 563 | # tweak this configuration if you need something custom). 564 | function(add_nodejs_module NAME) 565 | # Validate name parameter (must be a valid C identifier) 566 | string(MAKE_C_IDENTIFIER ${NAME} ${NAME}_SYMBOL_CHECK) 567 | if(NOT "${NAME}" STREQUAL "${${NAME}_SYMBOL_CHECK}") 568 | message(FATAL_ERROR 569 | "Module name must be a valid C identifier. " 570 | "Suggested alternative: '${${NAME}_SYMBOL_CHECK}'" 571 | ) 572 | endif() 573 | # Make sure node is initialized (variables set) before defining the module 574 | if(NOT NODEJS_INIT) 575 | message(FATAL_ERROR 576 | "Node.js has not been initialized. " 577 | "Call nodejs_init before adding any modules" 578 | ) 579 | endif() 580 | # In order to match node-gyp, we need to build into type specific directories 581 | # ncmake takes care of this, but be sure to set CMAKE_BUILD_TYPE yourself 582 | # if invoking CMake directly 583 | if(NOT CMAKE_CONFIGURATION_TYPES AND NOT CMAKE_BUILD_TYPE) 584 | message(FATAL_ERROR 585 | "Configuration type must be specified. 
" 586 | "Set CMAKE_BUILD_TYPE or use a different generator" 587 | ) 588 | endif() 589 | 590 | # A node module is a shared library 591 | add_library(${NAME} SHARED ${NODEJS_SOURCES} ${ARGN}) 592 | # Add compiler defines for the module 593 | # Two helpful ones: 594 | # MODULE_NAME must match the name of the build library, define that here 595 | # ${NAME}_BUILD is for symbol visibility under windows 596 | string(TOUPPER "${NAME}_BUILD" ${NAME}_BUILD_DEF) 597 | target_compile_definitions(${NAME} 598 | PRIVATE MODULE_NAME=${NAME} 599 | PRIVATE ${${NAME}_BUILD_DEF} 600 | PUBLIC ${NODEJS_DEFINITIONS} 601 | ) 602 | # This properly defines includes for the module 603 | target_include_directories(${NAME} PUBLIC ${NODEJS_INCLUDE_DIRS}) 604 | 605 | # Add link flags to the module 606 | target_link_libraries(${NAME} ${NODEJS_LIBRARIES}) 607 | 608 | # Set required properties for the module to build properly 609 | # Correct naming, symbol visibility and C++ standard 610 | set_target_properties(${NAME} PROPERTIES 611 | OUTPUT_NAME ${NAME} 612 | PREFIX "" 613 | SUFFIX ".node" 614 | MACOSX_RPATH ON 615 | C_VISIBILITY_PRESET hidden 616 | CXX_VISIBILITY_PRESET hidden 617 | POSITION_INDEPENDENT_CODE TRUE 618 | CXX_STANDARD_REQUIRED TRUE 619 | CXX_STANDARD 11 620 | ) 621 | 622 | # Handle link flag cases properly 623 | # When there are link flags, they should be appended to LINK_FLAGS with space separation 624 | # If the list is empty (true for most *NIX platforms), this is a no-op 625 | foreach(NODEJS_LINK_FLAG IN LISTS NODEJS_LINK_FLAGS) 626 | set_property(TARGET ${NAME} APPEND_STRING PROPERTY LINK_FLAGS " ${NODEJS_LINK_FLAG}") 627 | endforeach() 628 | 629 | # Make sure we're building in a build specific output directory 630 | # Only necessary on single-target generators (Make, Ninja) 631 | # Multi-target generators do this automatically 632 | # This (luckily) mirrors node-gyp conventions 633 | if(NOT CMAKE_CONFIGURATION_TYPES) 634 | set_property(TARGET ${NAME} PROPERTY 635 |
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BUILD_TYPE} 636 | ) 637 | endif() 638 | endfunction() 639 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tch-js 2 | 3 | An unofficial JavaScript and TypeScript port of PyTorch C++ library (libtorch). 4 | 5 | ## Install 6 | 7 | Package publishing to NPM is still work-in-progress. In the meantime, you can 8 | install the package directly from GitHub. 9 | 10 | ```sh 11 | $ npm i git+https://github.com/cedrickchee/tch-js.git 12 | ``` 13 | 14 | The package will download the pre-built binary during installation. You don't 15 | have to download PyTorch/libtorch or install tools for compiling from source. 16 | 17 | ## Versions Supported 18 | 19 | - Node.js: 20 | - 10 (tested) 21 | - 12 (tested) 22 | - 14 (WIP) 23 | - PyTorch (CPU): 24 | - 1.4.X (tested) 25 | - 1.5.X 26 | - 1.6.X 27 | - 1.7.X (tested) 28 | 29 | ## Code Examples 30 | 31 | ```javascript 32 | // This is a real example from an audio source separation model. 33 | const { tch, load, Tensor } = require('tch-js'); 34 | const fs = require('fs'); 35 | const wav = require('node-wav'); 36 | 37 | // WAV samples of length 269973 38 | const monoAudioChan = new Float32Array([ 39 | 1.100000023841858, 1.590000033378601, 40 | 2.049999952316284, 0.18000000715255737 41 | ]); 42 | // Flat tensor 43 | let audio = tch.tensor(monoAudioChan); // tensor of size [269973] 44 | 45 | // Reshape to 1xSample Length to match model input 46 | audio = audio.view([1, monoAudioChan.length]); // tensor of size [1, 269973] 47 | 48 | // Load PyTorch traced model async from file and return resulting ScripModule. 49 | const model = await load('sound-model.pt'); 50 | // Forward tensor async and return resulting Tensor. 
51 | const getResult = (err, result) => { 52 | if (err) { console.error(err); return; } 53 | 54 | // result is a tensor of size [1, 1, 269973] 55 | const out = result.toFloat32Array(); // convert Tensor to JS TypedArray 56 | 57 | // Encode output as 16-bit PCM WAV and write to file. 58 | const buf = wav.encode([out], { sampleRate: 44100, float: false, bitDepth: 16}); 59 | fs.writeFileSync("out.wav", Buffer.from(buf)); 60 | }; 61 | model.forward(audio, getResult); // call only after getResult is defined 62 | ``` 63 | 64 | ## Build It Yourself 65 | 66 | Currently, only Linux builds are available. 67 | 68 | If you want to build the package yourself, below are the steps to reproduce the 69 | build. 70 | 71 | Installing on Linux: 72 | - Ubuntu 18.04 (tested) 73 | - Ubuntu 20.04 (tested) 74 | 75 | 1. Install build tools 76 | 77 | ```sh 78 | $ sudo apt install -y cmake build-essential unzip curl 79 | ``` 80 | 81 | 2. Install Node.js 10 82 | 3. Download libtorch 83 | Download [libtorch pre-cxx11 ABI and CPU version](https://pytorch.org/get-started/locally/#start-locally). 84 | 85 | ```sh 86 | $ curl -o libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.7.1%2Bcpu.zip 87 | $ unzip libtorch.zip 88 | ``` 89 | 90 | 4. Install Node.js packages 91 | 92 | ```sh 93 | $ npm i --ignore-scripts 94 | ``` 95 | 96 | 5. Build 97 | 98 | ```sh 99 | $ npm run pre-build 100 | ``` 101 | 102 | 6.
Test 103 | 104 | ```sh 105 | $ npm run test 106 | ``` 107 | 108 | ## The Plan 109 | 110 | The library should: 111 | - expose more libtorch types and APIs for inference 112 | - support Windows 113 | - build binaries automatically using CI 114 | - ship TypeScript types 115 | 116 | ## Research 117 | 118 | - [PyTorch 1.0 tracing JIT and LibTorch C++ API to integrate PyTorch into NodeJS](https://blog.christianperone.com/2018/10/pytorch-1-0-tracing-jit-and-libtorch-c-api-to-integrate-pytorch-into-nodejs/) 119 | - [tch-rs](https://github.com/LaurentMazare/tch-rs) - Rust bindings for the C++ API of PyTorch 120 | - [torch-js](https://github.com/arition/torch-js) - Node.js binding for PyTorch 121 | - [pytorchjs](https://github.com/raghavmecheri/pytorchjs) - Torch and TorchVision, but for NodeJS 122 | - [Using the PyTorch C++ Frontend tutorial](https://pytorch.org/tutorials/advanced/cpp_frontend.html) 123 | - [Loading a TorchScript Model in C++](https://pytorch.org/tutorials/advanced/cpp_export.html) 124 | - [Torchaudio.load() in C++](https://discuss.pytorch.org/t/torchaudio-load-in-c/62400) 125 | - [Minimal PyTorch 1.0 -> C++ full example demo](https://gist.github.com/zeryx/526dbc05479e166ca7d512a670e6b82d) 126 | - [Convert libtorch tensor into an array](https://discuss.pytorch.org/t/convert-tensor-into-an-array/56721) 127 | - [JavaScript TypedArray](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypedArray) 128 | - [Napi::TypedArray](https://github.com/nodejs/node-addon-api/blob/master/doc/typed_array_of.md) - Node.js N-API doc 129 | - [How to use WebAssembly from node.js?](https://stackoverflow.com/questions/51403326/how-to-use-webassembly-from-node-js) - C++ to WebAssembly to Node.js 130 | - [node-pre-gyp](https://github.com/mapbox/node-pre-gyp) - Node.js tool for easy binary deployment of C++ addons 131 | -------------------------------------------------------------------------------- /binding.gyp:
-------------------------------------------------------------------------------- 1 | { 2 | "targets": [ 3 | { 4 | "target_name": "<(module_name)", 5 | "product_dir": "<(module_path)", 6 | "type": "none", 7 | "actions": [ 8 | { 9 | "action_name": "ncmake", 10 | "inputs": [""], 11 | "outputs": [""], 12 | "conditions": [ 13 | [ "OS=='linux'", 14 | {"action": ["npm", "run", "cmake-rebuild", "-DNAPI_VERSION=<(napi_build_version)"]} 15 | ] 16 | ] 17 | } 18 | ] 19 | }, 20 | { 21 | "target_name": "action_after_build", 22 | "type": "none", 23 | "dependencies": [ "<(module_name)" ], 24 | "copies": [ 25 | { 26 | "files": [ "<(PRODUCT_DIR)/<(module_name).node" ], 27 | "destination": "<(module_path)" 28 | }, 29 | { 30 | "files": [ ], 31 | "conditions": [ 32 | [ "OS=='linux'", 33 | { "files+": [ 34 | "libtorch/lib/libc10.so", 35 | "libtorch/lib/libcaffe2_detectron_ops.so", 36 | "libtorch/lib/libcaffe2_module_test_dynamic.so", 37 | "libtorch/lib/libfbjni.so", 38 | "libtorch/lib/libgomp-7c85b1e2.so.1", 39 | "libtorch/lib/libpytorch_jni.so", 40 | "libtorch/lib/libtorch.so" 41 | ] 42 | } 43 | ] 44 | ], 45 | "destination": "<(module_path)" 46 | } 47 | ] 48 | } 49 | ] 50 | } 51 | -------------------------------------------------------------------------------- /csrc/jit.cc: -------------------------------------------------------------------------------- 1 | #include "jit.h" 2 | #include "script_module.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tchjs { 9 | class LoadWorker : public Napi::AsyncWorker { 10 | public: 11 | static void load(const Napi::CallbackInfo &info) { 12 | std::string filename = info[0].As().Utf8Value(); 13 | Napi::Function cb = info[1].As(); 14 | auto *worker = new LoadWorker(cb, filename); 15 | worker->Queue(); 16 | } 17 | 18 | protected: 19 | void Execute() override { 20 | module = torch::jit::load(filename); 21 | } 22 | 23 | void OnOK() override { 24 | Napi::HandleScope scope(Env()); 25 | auto scriptModule = 
ScriptModule::NewInstance(); 26 | Napi::ObjectWrap::Unwrap(scriptModule)->setModule(module); 27 | Callback().Call({Env().Undefined(), scriptModule}); 28 | } 29 | 30 | private: 31 | LoadWorker(Napi::Function cb, std::string &filename) 32 | : Napi::AsyncWorker(cb), filename(filename) {} 33 | 34 | torch::jit::script::Module module; 35 | std::string filename; 36 | }; 37 | 38 | Napi::Object JitInit(Napi::Env env, Napi::Object exports) { 39 | exports.Set(Napi::String::New(env, "load"), Napi::Function::New(env, LoadWorker::load)); 40 | return exports; 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /csrc/jit.h: -------------------------------------------------------------------------------- 1 | #ifndef TCHJS_JIT_H 2 | #define TCHJS_JIT_H 3 | 4 | #include 5 | 6 | namespace tchjs { 7 | Napi::Object JitInit(Napi::Env env, Napi::Object exports); 8 | } 9 | 10 | #endif 11 | -------------------------------------------------------------------------------- /csrc/script_module.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "script_module.h" 4 | #include "tensor.h" 5 | #include 6 | #include 7 | 8 | namespace tchjs { 9 | class ScriptModuleForwardAsyncWorker : public Napi::AsyncWorker { 10 | public: 11 | static void forward(torch::jit::script::Module module, 12 | torch::Tensor tensor, Napi::Function cb) { 13 | auto *worker = new ScriptModuleForwardAsyncWorker(module, tensor, cb); 14 | worker->Queue(); 15 | } 16 | 17 | protected: 18 | void Execute() override { 19 | outTensor = module.forward({inTensor}).toTensor(); 20 | } 21 | 22 | void OnOK() override { 23 | Napi::HandleScope scope(Env()); 24 | auto napiTensor = Tensor::NewInstance(); 25 | Napi::ObjectWrap::Unwrap(napiTensor)->setTensor(outTensor); 26 | Callback().Call({Env().Undefined(), napiTensor}); 27 | } 28 | 29 | private: 30 | ScriptModuleForwardAsyncWorker(torch::jit::script::Module module, 31 | 
torch::Tensor tensor, Napi::Function cb) 32 | : Napi::AsyncWorker(cb), inTensor(tensor), module(module) {} 33 | 34 | torch::jit::script::Module module; 35 | torch::Tensor inTensor; 36 | torch::Tensor outTensor; 37 | }; 38 | 39 | 40 | Napi::FunctionReference ScriptModule::constructor; 41 | 42 | Napi::Object ScriptModule::Init(Napi::Env env, Napi::Object exports) { 43 | Napi::HandleScope scope(env); 44 | 45 | Napi::Function func = DefineClass(env, "ScriptModule", { 46 | InstanceMethod("forward", &ScriptModule::forward), 47 | }); 48 | 49 | constructor = Napi::Persistent(func); 50 | constructor.SuppressDestruct(); 51 | 52 | exports.Set("ScriptModule", func); 53 | return exports; 54 | } 55 | 56 | ScriptModule::ScriptModule(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) { 57 | Napi::Env env = info.Env(); 58 | Napi::HandleScope scope(env); 59 | } 60 | 61 | Napi::Object ScriptModule::NewInstance() { 62 | return constructor.New({}); 63 | } 64 | 65 | void ScriptModule::forward(const Napi::CallbackInfo &info) { 66 | Napi::Env env = info.Env(); 67 | Napi::HandleScope scope(env); 68 | Tensor* napiTensor = Napi::ObjectWrap::Unwrap(info[0].As()); 69 | Napi::Function cb = info[1].As(); 70 | torch::Tensor torchTensor = napiTensor->getTensor(); 71 | ScriptModuleForwardAsyncWorker::forward(this->module, torchTensor, cb); 72 | } 73 | 74 | void ScriptModule::setModule(torch::jit::script::Module module) { 75 | this->module = module; 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /csrc/script_module.h: -------------------------------------------------------------------------------- 1 | #ifndef TCHJS_JIT_SCRIPT_MODULE_H 2 | #define TCHJS_JIT_SCRIPT_MODULE_H 3 | 4 | #include 5 | #include 6 | 7 | namespace tchjs { 8 | class ScriptModule : public Napi::ObjectWrap { 9 | public: 10 | static Napi::Object Init(Napi::Env env, Napi::Object exports); 11 | 12 | explicit ScriptModule(const Napi::CallbackInfo &info); 13 | 14 | void 
setModule(torch::jit::script::Module module); 15 | 16 | static Napi::Object NewInstance(); 17 | 18 | private: 19 | static Napi::FunctionReference constructor; 20 | 21 | void forward(const Napi::CallbackInfo &info); 22 | 23 | torch::jit::script::Module module; 24 | }; 25 | } 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /csrc/tchjs.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "jit.h" 3 | #include "script_module.h" 4 | #include "tensor.h" 5 | #include "utils.h" 6 | #include 7 | 8 | namespace tchjs { 9 | Napi::Object randn(const Napi::CallbackInfo &info) { 10 | Napi::Array shape = info[0].As(); 11 | at::Tensor tensor = torch::randn(napiArrayToVector(shape), torch::requires_grad(false)); 12 | auto napiTensor = Tensor::NewInstance(); 13 | Napi::ObjectWrap::Unwrap(napiTensor)->setTensor(tensor); 14 | return napiTensor; 15 | } 16 | 17 | Napi::Object ones(const Napi::CallbackInfo &info) { 18 | Napi::Array shape = info[0].As(); 19 | at::Tensor tensor = torch::ones(napiArrayToVector(shape), torch::requires_grad(false)); 20 | auto napiTensor = Tensor::NewInstance(); 21 | Napi::ObjectWrap::Unwrap(napiTensor)->setTensor(tensor); 22 | return napiTensor; 23 | } 24 | 25 | Napi::Object tensor(const Napi::CallbackInfo &info) { 26 | Napi::Float32Array arr = info[0].As(); 27 | size_t elements = arr.ElementLength(); 28 | torch::TensorOptions options; 29 | at::Tensor tensor = torch::tensor(at::ArrayRef(arr.Data(), elements), options); 30 | auto napiTensor = Tensor::NewInstance(); 31 | Napi::ObjectWrap::Unwrap(napiTensor)->setTensor(tensor); 32 | return napiTensor; 33 | } 34 | 35 | Napi::Object InitAll(Napi::Env env, Napi::Object exports) { 36 | JitInit(env, exports); 37 | ScriptModule::Init(env, exports); 38 | Tensor::Init(env, exports); 39 | exports.Set(Napi::String::New(env, "randn"), Napi::Function::New(env, randn)); 40 | exports.Set(Napi::String::New(env, 
"ones"), Napi::Function::New(env, ones)); 41 | exports.Set(Napi::String::New(env, "tensor"), Napi::Function::New(env, tensor)); 42 | return exports; 43 | } 44 | 45 | NODE_API_MODULE(NODE_GYP_MODULE_NAME, InitAll) 46 | } 47 | -------------------------------------------------------------------------------- /csrc/tensor.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "tensor.h" 3 | 4 | namespace tchjs { 5 | Napi::FunctionReference Tensor::constructor; 6 | 7 | Napi::Object Tensor::Init(Napi::Env env, Napi::Object exports) { 8 | Napi::HandleScope scope(env); 9 | 10 | Napi::Function func = DefineClass(env, "Tensor", { 11 | InstanceMethod("toString", &Tensor::toString), 12 | InstanceMethod("toUint8Array", &Tensor::toUint8Array), 13 | InstanceMethod("toFloat32Array", &Tensor::toFloat32Array), 14 | InstanceMethod("view", &Tensor::view) 15 | }); 16 | 17 | constructor = Napi::Persistent(func); 18 | constructor.SuppressDestruct(); 19 | 20 | exports.Set("Tensor", func); 21 | return exports; 22 | } 23 | 24 | Tensor::Tensor(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) { 25 | Napi::Env env = info.Env(); 26 | Napi::HandleScope scope(env); 27 | } 28 | 29 | Napi::Object Tensor::NewInstance() { 30 | return constructor.New({}); 31 | } 32 | 33 | void Tensor::setTensor(at::Tensor tensor) { 34 | this->tensor = tensor; 35 | } 36 | 37 | torch::Tensor Tensor::getTensor() { 38 | return this->tensor; 39 | } 40 | 41 | Napi::Value Tensor::toString(const Napi::CallbackInfo &info) { 42 | Napi::Env env = info.Env(); 43 | Napi::HandleScope scope(env); 44 | 45 | std::stringstream ss; 46 | ss << "Type=" << this->tensor.options() << std::endl; 47 | ss << ", Size=" << this->tensor.sizes() << std::endl; 48 | return Napi::String::New(env, ss.str()); 49 | } 50 | 51 | Napi::Value Tensor::toFloat32Array(const Napi::CallbackInfo &info) { 52 | Napi::Env env = info.Env(); 53 | Napi::HandleScope scope(env); 54 | 55 | uint64_t size = 
this->tensor.numel(); 56 | // Make float32 type tensor 57 | auto floatData = this->tensor.contiguous().data_ptr(); 58 | // Wrap in NAPI float32 array 59 | auto arr = Napi::Float32Array::New(env, size); 60 | for (uint64_t i = 0; i < size; i++) { 61 | arr[i] = floatData[i]; 62 | } 63 | return arr; 64 | } 65 | 66 | Napi::Value Tensor::toUint8Array(const Napi::CallbackInfo &info) { 67 | Napi::Env env = info.Env(); 68 | Napi::HandleScope scope(env); 69 | 70 | uint64_t size = this->tensor.numel(); 71 | // Make uint8 type tensor 72 | auto byteTensor = this->tensor.clamp(0, 255).to(at::ScalarType::Byte); 73 | auto byteData = byteTensor.contiguous().data_ptr(); 74 | // Wrap in NAPI uint8 array 75 | auto arr = Napi::Uint8Array::New(env, size); 76 | for (uint64_t i = 0; i < size; i++) { 77 | arr[i] = byteData[i]; 78 | } 79 | return arr; 80 | } 81 | 82 | Napi::Value Tensor::view(const Napi::CallbackInfo &info) { 83 | Napi::Env env = info.Env(); 84 | Napi::HandleScope scope(env); 85 | Napi::Array shape = info[0].As(); 86 | std::vector vshape; 87 | uint32_t len = shape.Length(); 88 | for (uint32_t i = 0; i < len; i++) { 89 | vshape.push_back(shape.Get(i).ToNumber()); 90 | } 91 | torch::Tensor tensor = this->tensor.view(vshape); 92 | auto napiTensor = Tensor::NewInstance(); 93 | Napi::ObjectWrap::Unwrap(napiTensor)->setTensor(tensor); 94 | return napiTensor; 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /csrc/tensor.h: -------------------------------------------------------------------------------- 1 | #ifndef TCHJS_TENSOR_H 2 | #define TCHJS_TENSOR_H 3 | 4 | #include 5 | #include 6 | 7 | namespace tchjs { 8 | class Tensor : public Napi::ObjectWrap { 9 | public: 10 | static Napi::Object Init(Napi::Env env, Napi::Object exports); 11 | 12 | explicit Tensor(const Napi::CallbackInfo &info); 13 | 14 | void setTensor(at::Tensor tensor); 15 | 16 | Napi::Value toString(const Napi::CallbackInfo &info); 17 | 18 | Napi::Value 
toFloat32Array(const Napi::CallbackInfo &info); 19 | 20 | Napi::Value toUint8Array(const Napi::CallbackInfo &info); 21 | 22 | Napi::Value view(const Napi::CallbackInfo &info); 23 | 24 | torch::Tensor getTensor(); 25 | 26 | static Napi::Object NewInstance(); 27 | 28 | private: 29 | static Napi::FunctionReference constructor; 30 | 31 | torch::Tensor tensor; 32 | }; 33 | } 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /csrc/utils.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "utils.h" 4 | 5 | namespace tchjs { 6 | std::vector napiArrayToVector(Napi::Array arr) { 7 | std::vector vec; 8 | size_t len = arr.Length(); 9 | for (size_t i = 0; i < len; i++) { 10 | vec.push_back(arr.Get(i).ToNumber().Int64Value()); 11 | } 12 | return vec; 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /csrc/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef TCHJS_UTILS_H 2 | #define TCHJS_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | namespace tchjs { 8 | std::vector napiArrayToVector(Napi::Array arr); 9 | } 10 | 11 | #endif 12 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tch-js", 3 | "version": "0.0.1-alpha.1", 4 | "description": "A JavaScript port of PyTorch C++ frontend (libtorch)", 5 | "main": "dist/index.js", 6 | "typings": "dist/index.d.ts", 7 | "files": [ 8 | "dist", 9 | "src" 10 | ], 11 | "author": "Cedric Chee ", 12 | "license": "MIT", 13 | "repository": { 14 | "type": "git", 15 | "url": "https://github.com/cedrickchee/tch-js.git" 16 | }, 17 | "engines": { 18 | "node": "^10 || >=12" 19 | }, 20 | "os": [ 21 | "linux", 22 | "win32" 23 | ], 24 | "cpu": [ 25 | "x64" 26 | ], 27 | "keywords": [ 28 | 
"n-api", 29 | "pytorch", 30 | "libtorch", 31 | "tensor", 32 | "ScriptModule", 33 | "neural-network", 34 | "deep-learning" 35 | ], 36 | "scripts": { 37 | "install": "node-pre-gyp install --fallback-to-build=false", 38 | "pre-build": "node-pre-gyp rebuild package", 39 | "pre-publish": "node-pre-gyp-github publish", 40 | "pre-unpublish": "node-pre-gyp unpublish", 41 | "cmake-rebuild": "ncmake rebuild", 42 | "start": "tsdx watch", 43 | "build": "tsdx build", 44 | "test": "tsdx test", 45 | "lint": "tsdx lint", 46 | "prepare": "tsdx build", 47 | "size": "size-limit", 48 | "analyze": "size-limit --why" 49 | }, 50 | "dependencies": { 51 | "node-pre-gyp": "^0.14.0", 52 | "bindings": "^1.3.1" 53 | }, 54 | "devDependencies": { 55 | "@size-limit/preset-small-lib": "^4.9.1", 56 | "husky": "^4.3.8", 57 | "size-limit": "^4.9.1", 58 | "tsdx": "^0.14.1", 59 | "tslib": "^2.1.0", 60 | "typescript": "^4.1.3", 61 | "node-addon-api": "^1.6.2", 62 | "node-cmake": "^2.5.1" 63 | }, 64 | "peerDependencies": {}, 65 | "binary": { 66 | "module_name": "tchjs", 67 | "module_path": "./dist/binding/napi-v{napi_build_version}", 68 | "remote_path": "{version}", 69 | "package_name": "{platform}-{arch}-napi-v{napi_build_version}.tar.gz", 70 | "host": "https://github.com/cedrickchee/tch-js/releases/download/", 71 | "napi_versions": [ 72 | 3 73 | ] 74 | }, 75 | "husky": { 76 | "hooks": { 77 | "pre-commit": "tsdx lint" 78 | } 79 | }, 80 | "prettier": { 81 | "printWidth": 80, 82 | "semi": true, 83 | "singleQuote": true, 84 | "trailingComma": "es5" 85 | }, 86 | "module": "dist/tch-js.esm.js", 87 | "size-limit": [ 88 | { 89 | "path": "dist/tch-js.cjs.production.min.js", 90 | "limit": "10 KB" 91 | }, 92 | { 93 | "path": "dist/tch-js.esm.js", 94 | "limit": "10 KB" 95 | } 96 | ] 97 | } 98 | -------------------------------------------------------------------------------- /src/binding.ts: -------------------------------------------------------------------------------- 1 | import path from 'path'; 2 | import 
binary from 'node-pre-gyp'; 3 | 4 | const binding_path = binary.find( 5 | path.resolve(path.join(__dirname, '../package.json')) 6 | ); 7 | const binding = require(binding_path); 8 | 9 | export default binding; 10 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | import tch from './binding'; 2 | // import * as types from 'types'; 3 | import { promisify } from './promisify'; 4 | 5 | // const load = (filePath: string) => { 6 | // return new Promise((resolve, reject) => { 7 | // tch.load(filePath, (err: Error | null, model: Model) => { 8 | // if (err) reject(err); 9 | // return resolve(model); 10 | // }) 11 | // }); 12 | // }; 13 | 14 | // [path: string] 15 | let loadAsPromise = promisify(tch.load); 16 | type tsr = Tensor; 17 | 18 | export { tch, loadAsPromise as load, tsr as Tensor }; 19 | -------------------------------------------------------------------------------- /src/promisify.ts: -------------------------------------------------------------------------------- 1 | type CallbackErr = Error | string | null; 2 | type Callback = (err: CallbackErr, result: R) => any; 3 | 4 | type FunctionWithCallback = ( 5 | ...args: [Callback, ...T[]] 6 | ) => any; 7 | 8 | const promisifyCallback = ( 9 | func: FunctionWithCallback 10 | ) => (...args: I): Promise => { 11 | return new Promise((resolve, reject) => { 12 | func((err: CallbackErr, result: R) => { 13 | return err ? 
reject(err) : resolve(result); 14 | }, ...args); 15 | }); 16 | }; 17 | 18 | export { promisifyCallback as promisify }; 19 | -------------------------------------------------------------------------------- /test/binding.test.ts: -------------------------------------------------------------------------------- 1 | import binding from '../src/binding'; 2 | 3 | describe('binding', () => { 4 | it('module should load', () => { 5 | expect(binding).not.toBeNull(); 6 | console.log('binding:', binding); 7 | }); 8 | }); 9 | -------------------------------------------------------------------------------- /test/data/matmul.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedrickchee/tch-js/1d4e0c40b28f8a2244c1263581be755fba851ed1/test/data/matmul.pt -------------------------------------------------------------------------------- /test/data/trace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def matmul(x): 4 | return x * 2 5 | 6 | torch.jit.trace(matmul, torch.randn(3, 3)).save("matmul.pt") -------------------------------------------------------------------------------- /test/tensor.test.ts: -------------------------------------------------------------------------------- 1 | import { tch, load, Tensor } from '../src'; 2 | import path from 'path'; 3 | 4 | describe('tch-js', () => { 5 | it('ones', () => { 6 | const ones = tch.ones([3, 2]); 7 | 8 | expect([...ones.toUint8Array()]).toEqual([1, 1, 1, 1, 1, 1]); 9 | }); 10 | 11 | describe('tensor', () => { 12 | it('toFloat32Array', () => { 13 | const arr = new Float32Array([ 14 | 1.100000023841858, 15 | 1.590000033378601, 16 | 2.049999952316284, 17 | 0.18000000715255737, 18 | ]); 19 | const tensor = tch.tensor(arr); 20 | 21 | expect([...tensor.toFloat32Array()]).toEqual([ 22 | 1.100000023841858, 23 | 1.590000033378601, 24 | 2.049999952316284, 25 | 0.18000000715255737, 26 | ]); 27 | }); 28 | 29 | 
it('toUint8Array', () => { 30 | const arr = new Float32Array([1.1, 2.0, 3.0, 4.1]); 31 | const tensor = tch.tensor(arr); 32 | 33 | expect([...tensor.toUint8Array()]).toEqual([1, 2, 3, 4]); 34 | }); 35 | }); 36 | 37 | it('load', async done => { 38 | const input = tch.tensor(new Float32Array([1.5, 5.5])); 39 | 40 | const model = await load(path.join(__dirname, 'data', 'matmul.pt')); 41 | const getResult = (err: Error, result: Tensor) => { 42 | if (err) return; 43 | 44 | const output = result.toUint8Array(); 45 | 46 | expect([...output]).toEqual([3, 11]); 47 | 48 | done(); 49 | }; 50 | model.forward(input, getResult); 51 | }); 52 | }); 53 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | // see https://www.typescriptlang.org/tsconfig to better understand tsconfigs 3 | "include": ["src", "types"], 4 | "compilerOptions": { 5 | "module": "esnext", 6 | "lib": ["dom", "esnext"], 7 | "importHelpers": true, 8 | // output .d.ts declaration files for consumers 9 | "declaration": true, 10 | // output .js.map sourcemap files for consumers 11 | "sourceMap": true, 12 | // match output dir to input dir. e.g. dist/index instead of dist/src/index 13 | "rootDir": "./src", 14 | // stricter type-checking for stronger correctness. Recommended by TS 15 | "strict": true, 16 | // linter checks for common issues 17 | "noImplicitReturns": true, 18 | "noFallthroughCasesInSwitch": true, 19 | // noUnused* overlap with @typescript-eslint/no-unused-vars, can disable if duplicative 20 | "noUnusedLocals": true, 21 | "noUnusedParameters": true, 22 | // use Node's module resolution algorithm, instead of the legacy TS one 23 | "moduleResolution": "node", 24 | // transpile JSX to React.createElement 25 | "jsx": "react", 26 | // interop between ESM and CJS modules. 
Recommended by TS 27 | "esModuleInterop": true, 28 | // significant perf increase by skipping checking .d.ts files, particularly those in node_modules. Recommended by TS 29 | "skipLibCheck": true, 30 | // error out if import and file system have a casing mismatch. Recommended by TS 31 | "forceConsistentCasingInFileNames": true, 32 | // `tsdx build` ignores this option, but it is commonly used when type-checking separately with `tsc` 33 | "noEmit": true, 34 | "downlevelIteration": true, 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /types/node-pre-gyp.d.ts: -------------------------------------------------------------------------------- 1 | declare module 'node-pre-gyp'; 2 | -------------------------------------------------------------------------------- /types/tch.d.ts: -------------------------------------------------------------------------------- 1 | interface ScriptModule { 2 | forward(input: any, callback: Result): void; 3 | } 4 | interface Tensor { 5 | toUint8Array(): Uint8Array; 6 | toFloat32Array(): Float32Array; 7 | } 8 | type Result = (err: Error, result: Tensor) => void; 9 | --------------------------------------------------------------------------------