├── csrc ├── .gitignore ├── include │ └── lltm │ │ ├── lltm.h │ │ └── exports.h ├── src │ ├── lltm.def │ ├── exports.cpp │ └── lltm.cpp └── CMakeLists.txt ├── .github ├── .gitignore └── workflows │ └── R-CMD-check.yaml ├── src ├── .gitignore ├── lltm_types.h ├── Makevars ├── lltm.cpp ├── Makevars.win ├── exports.cpp └── RcppExports.cpp ├── LICENSE ├── tests ├── testthat.R └── testthat │ └── test-lltm.R ├── man └── figures │ ├── packaging.png │ └── high-level.png ├── .Rbuildignore ├── .gitignore ├── NAMESPACE ├── inst ├── include │ └── lltm │ │ ├── lltm.h │ │ └── exports.h └── def │ └── lltm.def ├── lltm.Rproj ├── R ├── RcppExports.R ├── lltm.R └── package.R ├── DESCRIPTION ├── LICENSE.md ├── README.Rmd └── README.md /csrc/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.so 3 | *.dll 4 | -------------------------------------------------------------------------------- /src/lltm_types.h: -------------------------------------------------------------------------------- 1 | #include 2 | -------------------------------------------------------------------------------- /src/Makevars: -------------------------------------------------------------------------------- 1 | PKG_CPPFLAGS = -I../inst/include/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2021 2 | COPYRIGHT HOLDER: lltm authors 3 | -------------------------------------------------------------------------------- /tests/testthat.R: 
-------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(lltm) 3 | 4 | test_check("lltm") 5 | -------------------------------------------------------------------------------- /man/figures/packaging.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/lltm/HEAD/man/figures/packaging.png -------------------------------------------------------------------------------- /man/figures/high-level.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlverse/lltm/HEAD/man/figures/high-level.png -------------------------------------------------------------------------------- /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^lltm\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^LICENSE\.md$ 4 | ^\.github$ 5 | ^csrc$ 6 | ^inst/lib$ 7 | ^README\.Rmd$ 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .Rdata 4 | .httr-oauth 5 | .DS_Store 6 | inst/lib 7 | inst/bin 8 | src/lltm.lib 9 | 10 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | importFrom(Rcpp,sourceCpp) 4 | importFrom(utils,download.file) 5 | importFrom(utils,packageDescription) 6 | importFrom(utils,unzip) 7 | -------------------------------------------------------------------------------- /csrc/include/lltm/lltm.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | LLTM_API int _raise_exception (); 4 | inline int raise_exception () { 5 | _raise_exception(); 6 | host_exception_handler(); 7 | return 1; 
8 | } 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /inst/include/lltm/lltm.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | LLTM_API int _raise_exception (); 4 | inline int raise_exception () { 5 | _raise_exception(); 6 | host_exception_handler(); 7 | return 1; 8 | } 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /csrc/src/lltm.def: -------------------------------------------------------------------------------- 1 | LIBRARY LLTM 2 | EXPORTS 3 | 4 | ;------ autogenerated ------------------------- 5 | ; don't modify between the autogenerated lines 6 | _lltm_forward 7 | _lltm_backward 8 | lltm_last_error 9 | lltm_last_error_clear 10 | ;------ autogenerated ------------------------- 11 | 12 | _raise_exception 13 | -------------------------------------------------------------------------------- /inst/def/lltm.def: -------------------------------------------------------------------------------- 1 | LIBRARY LLTM 2 | EXPORTS 3 | 4 | ;------ autogenerated ------------------------- 5 | ; don't modify between the autogenerated lines 6 | _lltm_forward 7 | _lltm_backward 8 | lltm_last_error 9 | lltm_last_error_clear 10 | ;------ autogenerated ------------------------- 11 | 12 | _raise_exception 13 | -------------------------------------------------------------------------------- /src/lltm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #define LLTM_HEADERS_ONLY 3 | #include 4 | #define TORCH_IMPL 5 | #define IMPORT_TORCH 6 | #include 7 | 8 | void host_exception_handler () 9 | { 10 | if (lltm_last_error()) 11 | { 12 | auto msg = Rcpp::as(torch::string(lltm_last_error())); 13 | lltm_last_error_clear(); 14 | Rcpp::stop(msg); 15 | } 16 | } 17 | 18 | // [[Rcpp::export]] 19 | void lltm_raise_exception () 20 | { 21 | raise_exception(); 22 | } 23 | 
-------------------------------------------------------------------------------- /lltm.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | LineEndingConversion: Posix 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | -------------------------------------------------------------------------------- /src/Makevars.win: -------------------------------------------------------------------------------- 1 | PKG_CPPFLAGS = -I../inst/include/ 2 | 3 | PKG_LIBS=\ 4 | -L. \ 5 | -llltm 6 | 7 | .PHONY: all lltm 8 | all: clean lltm 9 | 10 | # Creates the import library from the .def file. 11 | # Keeping the def file in the source folder makes more sense because it's easy 12 | # to inspect it in version control systems and it's easier to edit - since it's 13 | # a plain text file. 14 | # But MinGW can't take the .def file as an input to the linker, so we need to 15 | # create the import library.
16 | lltm: clean 17 | $(DLLTOOL) -d ../inst/def/lltm.def -l lltm.lib 18 | 19 | clean: 20 | rm -rf lltm.lib 21 | 22 | -------------------------------------------------------------------------------- /tests/testthat/test-lltm.R: -------------------------------------------------------------------------------- 1 | test_that("multiplication works", { 2 | batch_size = 16 3 | input_features = 32 4 | state_size = 128 5 | 6 | X = torch::torch_randn(batch_size, input_features) 7 | h = torch::torch_randn(batch_size, state_size) 8 | C = torch::torch_randn(batch_size, state_size) 9 | 10 | rnn = nn_lltm(input_features, state_size) 11 | 12 | 13 | 14 | out = rnn(X, list(h, C)) 15 | l <- out[[1]]$sum() + out[[2]]$sum() 16 | l$backward() 17 | 18 | expect_equal(rnn$weights$grad$shape, c(384, 160)) 19 | expect_equal(rnn$bias$grad$shape, c(384)) 20 | }) 21 | 22 | test_that("raise exceptions", { 23 | 24 | expect_error(lltm_raise_exception(), "Error from LLTM") 25 | 26 | }) 27 | -------------------------------------------------------------------------------- /R/RcppExports.R: -------------------------------------------------------------------------------- 1 | # Generated by using Rcpp::compileAttributes() -> do not edit by hand 2 | # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 3 | 4 | rcpp_lltm_forward <- function(input, weights, bias, old_h, old_cell) { 5 | .Call('_lltm_rcpp_lltm_forward', PACKAGE = 'lltm', input, weights, bias, old_h, old_cell) 6 | } 7 | 8 | rcpp_lltm_backward <- function(grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights) { 9 | .Call('_lltm_rcpp_lltm_backward', PACKAGE = 'lltm', grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights) 10 | } 11 | 12 | lltm_raise_exception <- function() { 13 | invisible(.Call('_lltm_lltm_raise_exception', PACKAGE = 'lltm')) 14 | } 15 | 16 | -------------------------------------------------------------------------------- /DESCRIPTION: 
-------------------------------------------------------------------------------- 1 | Package: lltm 2 | Title: Long Long Term Memory Neural Network Cells 3 | Version: 0.0.0.9000 4 | Authors@R: c( 5 | person("Daniel", "Falbel", email = "daniel@rstudio.com", role = c("aut", "cre", "cph")), 6 | person(family = "RStudio", role = c("cph")) 7 | ) 8 | Description: Implements the Long Long Term Memory Neural Network Cells ('LLTM'). 9 | It demonstrates how to implement and distribute 'C++' extensions for 'torch'. 10 | Also implements 'JIT' operators for 'torch'. 11 | License: MIT + file LICENSE 12 | Encoding: UTF-8 13 | Roxygen: list(markdown = TRUE) 14 | SystemRequirements: C++11 15 | RoxygenNote: 7.1.1 16 | LinkingTo: 17 | Rcpp, torch 18 | Imports: 19 | Rcpp, torch 20 | Remotes: 21 | mlverse/torch 22 | Suggests: 23 | testthat (>= 3.0.0) 24 | Config/testthat/edition: 3 25 | -------------------------------------------------------------------------------- /src/exports.cpp: -------------------------------------------------------------------------------- 1 | // Generated by using torchexport::export() -> do not edit by hand 2 | #include 3 | #include 4 | #define LLTM_HEADERS_ONLY 5 | #include 6 | 7 | // [[Rcpp::export]] 8 | torch::TensorList rcpp_lltm_forward (torch::Tensor input, torch::Tensor weights, torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell) { 9 | return lltm_forward(input.get(), weights.get(), bias.get(), old_h.get(), old_cell.get()); 10 | } 11 | // [[Rcpp::export]] 12 | torch::TensorList rcpp_lltm_backward (torch::Tensor grad_h, torch::Tensor grad_cell, torch::Tensor new_cell, torch::Tensor input_gate, torch::Tensor output_gate, torch::Tensor candidate_cell, torch::Tensor X, torch::Tensor gate_weights, torch::Tensor weights) { 13 | return lltm_backward(grad_h.get(), grad_cell.get(), new_cell.get(), input_gate.get(), output_gate.get(), candidate_cell.get(), X.get(), gate_weights.get(), weights.get()); 14 | } 15 | 
-------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2021 lltm authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /csrc/src/exports.cpp: -------------------------------------------------------------------------------- 1 | // Generated by using torchexport::export() -> do not edit by hand 2 | #include "lltm/exports.h" 3 | #include 4 | void * p_lltm_last_error = NULL; 5 | 6 | LLTM_API void* lltm_last_error() 7 | { 8 | return p_lltm_last_error; 9 | } 10 | 11 | LLTM_API void lltm_last_error_clear() 12 | { 13 | p_lltm_last_error = NULL; 14 | } 15 | 16 | std::vector lltm_forward (torch::Tensor input, torch::Tensor weights, torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell); 17 | LLTM_API void* _lltm_forward (void* input, void* weights, void* bias, void* old_h, void* old_cell) { 18 | try { 19 | return make_raw::TensorList(lltm_forward(from_raw::Tensor(input), from_raw::Tensor(weights), from_raw::Tensor(bias), from_raw::Tensor(old_h), from_raw::Tensor(old_cell))); 20 | } LLTM_HANDLE_EXCEPTION 21 | return ( void* ) NULL; 22 | } 23 | std::vector lltm_backward (torch::Tensor grad_h, torch::Tensor grad_cell, torch::Tensor new_cell, torch::Tensor input_gate, torch::Tensor output_gate, torch::Tensor candidate_cell, torch::Tensor X, torch::Tensor gate_weights, torch::Tensor weights); 24 | LLTM_API void* _lltm_backward (void* grad_h, void* grad_cell, void* new_cell, void* input_gate, void* output_gate, void* candidate_cell, void* X, void* gate_weights, void* weights) { 25 | try { 26 | return make_raw::TensorList(lltm_backward(from_raw::Tensor(grad_h), from_raw::Tensor(grad_cell), from_raw::Tensor(new_cell), from_raw::Tensor(input_gate), from_raw::Tensor(output_gate), from_raw::Tensor(candidate_cell), from_raw::Tensor(X), from_raw::Tensor(gate_weights), from_raw::Tensor(weights))); 27 | } LLTM_HANDLE_EXCEPTION 28 | return ( void* ) NULL; 29 | } 30 | -------------------------------------------------------------------------------- /R/lltm.R: 
-------------------------------------------------------------------------------- 1 | lltm_function <- torch::autograd_function( 2 | forward = function(ctx, input, weights, bias, old_h, old_cell) { 3 | outputs <- rcpp_lltm_forward(input, weights, bias, old_h, old_cell) 4 | names(outputs) <- c("new_h", "new_cell", "input_gate", "output_gate", 5 | "candidate_cell", "X", "gate_weights") 6 | 7 | variables <- append(outputs, list(weights = weights)) 8 | ctx$save_for_backward(!!!variables) 9 | 10 | outputs[c("new_h", "new_cell")] 11 | }, 12 | backward = function(ctx, grad_h, grad_cell) { 13 | outputs <- rcpp_lltm_backward( 14 | grad_h = grad_h$contiguous(), 15 | grad_cell = grad_cell$contiguous(), 16 | new_cell = ctx$saved_variables$new_cell, 17 | input_gate = ctx$saved_variables$input_gate, 18 | output_gate = ctx$saved_variables$output_gate, 19 | candidate_cell = ctx$saved_variables$candidate_cell, 20 | X = ctx$saved_variables$X, 21 | gate_weights = ctx$saved_variables$gate_weights, 22 | weights = ctx$saved_variables$weights 23 | ) 24 | 25 | names(outputs) <- c("old_h", "input", "weights", "bias", "old_cell") 26 | outputs 27 | } 28 | ) 29 | 30 | nn_lltm <- torch::nn_module( 31 | initialize = function(input_features, state_size) { 32 | self$input_features <- input_features 33 | self$state_size <- state_size 34 | self$weights <- torch::nn_parameter( 35 | torch::torch_empty(3 * state_size, input_features + state_size)) 36 | self$bias <- torch::nn_parameter(torch::torch_empty(3 * state_size)) 37 | self$reset_parameters() 38 | }, 39 | reset_parameters = function() { 40 | stdv = 1.0 / sqrt(self$state_size) 41 | lapply(self$parameters, function(x) { 42 | torch::nn_init_uniform_(x, a = -stdv, b = stdv) 43 | }) 44 | }, 45 | forward = function(input, state) { 46 | lltm_function(input, self$weights, self$bias, state[[1]], state[[2]]) 47 | } 48 | ) 49 | -------------------------------------------------------------------------------- /csrc/include/lltm/exports.h: 
-------------------------------------------------------------------------------- 1 | // Generated by using torchexport::export() -> do not edit by hand 2 | #ifdef _WIN32 3 | #ifndef LLTM_HEADERS_ONLY 4 | #define LLTM_API extern "C" __declspec(dllexport) 5 | #else 6 | #define LLTM_API extern "C" __declspec(dllimport) 7 | #endif 8 | #else 9 | #define LLTM_API extern "C" 10 | #endif 11 | 12 | #ifndef LLTM_HANDLE_EXCEPTION 13 | #define LLTM_HANDLE_EXCEPTION \ 14 | catch(const std::exception& ex) { \ 15 | p_lltm_last_error = make_raw::string(ex.what()); \ 16 | } catch (std::string& ex) { \ 17 | p_lltm_last_error = make_raw::string(ex); \ 18 | } catch (...) { \ 19 | p_lltm_last_error = make_raw::string("Unknown error. "); \ 20 | } 21 | #endif 22 | 23 | void host_exception_handler (); 24 | extern void* p_lltm_last_error; 25 | LLTM_API void* lltm_last_error (); 26 | LLTM_API void lltm_last_error_clear(); 27 | 28 | LLTM_API void* _lltm_forward (void* input, void* weights, void* bias, void* old_h, void* old_cell); 29 | LLTM_API void* _lltm_backward (void* grad_h, void* grad_cell, void* new_cell, void* input_gate, void* output_gate, void* candidate_cell, void* X, void* gate_weights, void* weights); 30 | 31 | #ifdef RCPP_VERSION 32 | inline void* lltm_forward (void* input, void* weights, void* bias, void* old_h, void* old_cell) { 33 | auto ret = _lltm_forward(input, weights, bias, old_h, old_cell); 34 | host_exception_handler(); 35 | return ret; 36 | } 37 | inline void* lltm_backward (void* grad_h, void* grad_cell, void* new_cell, void* input_gate, void* output_gate, void* candidate_cell, void* X, void* gate_weights, void* weights) { 38 | auto ret = _lltm_backward(grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights); 39 | host_exception_handler(); 40 | return ret; 41 | } 42 | #endif // RCPP_VERSION 43 | -------------------------------------------------------------------------------- /inst/include/lltm/exports.h: 
-------------------------------------------------------------------------------- 1 | // Generated by using torchexport::export() -> do not edit by hand 2 | #ifdef _WIN32 3 | #ifndef LLTM_HEADERS_ONLY 4 | #define LLTM_API extern "C" __declspec(dllexport) 5 | #else 6 | #define LLTM_API extern "C" __declspec(dllimport) 7 | #endif 8 | #else 9 | #define LLTM_API extern "C" 10 | #endif 11 | 12 | #ifndef LLTM_HANDLE_EXCEPTION 13 | #define LLTM_HANDLE_EXCEPTION \ 14 | catch(const std::exception& ex) { \ 15 | p_lltm_last_error = make_raw::string(ex.what()); \ 16 | } catch (std::string& ex) { \ 17 | p_lltm_last_error = make_raw::string(ex); \ 18 | } catch (...) { \ 19 | p_lltm_last_error = make_raw::string("Unknown error. "); \ 20 | } 21 | #endif 22 | 23 | void host_exception_handler (); 24 | extern void* p_lltm_last_error; 25 | LLTM_API void* lltm_last_error (); 26 | LLTM_API void lltm_last_error_clear(); 27 | 28 | LLTM_API void* _lltm_forward (void* input, void* weights, void* bias, void* old_h, void* old_cell); 29 | LLTM_API void* _lltm_backward (void* grad_h, void* grad_cell, void* new_cell, void* input_gate, void* output_gate, void* candidate_cell, void* X, void* gate_weights, void* weights); 30 | 31 | #ifdef RCPP_VERSION 32 | inline void* lltm_forward (void* input, void* weights, void* bias, void* old_h, void* old_cell) { 33 | auto ret = _lltm_forward(input, weights, bias, old_h, old_cell); 34 | host_exception_handler(); 35 | return ret; 36 | } 37 | inline void* lltm_backward (void* grad_h, void* grad_cell, void* new_cell, void* input_gate, void* output_gate, void* candidate_cell, void* X, void* gate_weights, void* weights) { 38 | auto ret = _lltm_backward(grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights); 39 | host_exception_handler(); 40 | return ret; 41 | } 42 | #endif // RCPP_VERSION 43 | -------------------------------------------------------------------------------- /R/package.R: 
-------------------------------------------------------------------------------- 1 | ## usethis namespace: start 2 | #' @importFrom Rcpp sourceCpp 3 | #' @importFrom utils download.file packageDescription unzip 4 | ## usethis namespace: end 5 | NULL 6 | 7 | .onLoad <- function(lib, pkg) { 8 | if (torch::torch_is_installed()) { 9 | 10 | if (!lltm_is_installed()) 11 | install_lltm() 12 | 13 | if (!lltm_is_installed()) { 14 | if (interactive()) 15 | warning("liblltm is not installed. Run `install_lltm()` before using the package.") 16 | } else { 17 | dyn.load(lib_path(), local = FALSE) 18 | 19 | # when using devtools::load_all() the library might be available in 20 | # `lib/pkg/src` 21 | pkgload <- file.path(lib, pkg, "src", paste0(pkg, .Platform$dynlib.ext)) 22 | if (file.exists(pkgload)) 23 | dyn.load(pkgload) 24 | else 25 | library.dynam("lltm", pkg, lib) 26 | } 27 | } 28 | } 29 | 30 | inst_path <- function() { 31 | install_path <- Sys.getenv("LLTM_HOME") 32 | if (nzchar(install_path)) return(install_path) 33 | 34 | system.file("", package = "lltm") 35 | } 36 | 37 | lib_path <- function() { 38 | install_path <- inst_path() 39 | 40 | if (.Platform$OS.type == "unix") { 41 | file.path(install_path, "lib", paste0("liblltm", lib_ext())) 42 | } else { 43 | file.path(install_path, "bin", paste0("lltm", lib_ext())) 44 | } 45 | } 46 | 47 | lib_ext <- function() { 48 | if (grepl("darwin", version$os)) 49 | ".dylib" 50 | else if (grepl("linux", version$os)) 51 | ".so" 52 | else 53 | ".dll" 54 | } 55 | 56 | lltm_is_installed <- function() { 57 | file.exists(lib_path()) 58 | } 59 | 60 | install_lltm <- function(url = Sys.getenv("LLTM_URL", unset = NA)) { 61 | 62 | if (!interactive() && Sys.getenv("TORCH_INSTALL", unset = 0) == "0") return() 63 | 64 | if (is.na(url)) { 65 | tmp <- tempfile(fileext = ".zip") 66 | version <- packageDescription("lltm")$Version 67 | os <- get_cmake_style_os() 68 | dev <- if (torch::cuda_is_available()) "cu" else "cpu" 69 | 70 | url <-
sprintf("https://github.com/mlverse/lltm/releases/download/liblltm/lltm-%s+%s-%s.zip", 71 | version, dev, os) 72 | } 73 | 74 | if (is_url(url)) { 75 | file <- tempfile(fileext = ".zip") 76 | on.exit(unlink(file), add = TRUE) 77 | download.file(url = url, destfile = file) 78 | } else { 79 | message('Using file ', url) 80 | file <- url 81 | } 82 | 83 | tmp <- tempfile() 84 | on.exit(unlink(tmp), add = TRUE) 85 | unzip(file, exdir = tmp) 86 | 87 | file.copy( 88 | list.files(list.files(tmp, full.names = TRUE), full.names = TRUE), 89 | inst_path(), 90 | recursive = TRUE 91 | ) 92 | } 93 | 94 | get_cmake_style_os <- function() { 95 | os <- version$os 96 | if (grepl("darwin", os)) { 97 | "Darwin" 98 | } else if (grepl("linux", os)) { 99 | "Linux" 100 | } else { 101 | "win64" 102 | } 103 | } 104 | 105 | is_url <- function(x) { 106 | grepl("^https", x) || grepl("^http", x) 107 | } 108 | 109 | -------------------------------------------------------------------------------- /csrc/src/lltm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #define LANTERN_TYPES_IMPL // Should be defined only in a single file. 
3 | #include 4 | #include 5 | #include 6 | #include "lltm/lltm.h" 7 | 8 | torch::Tensor d_sigmoid(torch::Tensor z) { 9 | auto s = torch::sigmoid(z); 10 | return (1 - s) * s; 11 | } 12 | 13 | // [[torch::export]] 14 | std::vector lltm_forward( 15 | torch::Tensor input, 16 | torch::Tensor weights, 17 | torch::Tensor bias, 18 | torch::Tensor old_h, 19 | torch::Tensor old_cell) { 20 | auto X = torch::cat({old_h, input}, /*dim=*/1); 21 | 22 | auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); 23 | auto gates = gate_weights.chunk(3, /*dim=*/1); 24 | 25 | auto input_gate = torch::sigmoid(gates[0]); 26 | auto output_gate = torch::sigmoid(gates[1]); 27 | auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0); 28 | 29 | auto new_cell = old_cell + candidate_cell * input_gate; 30 | auto new_h = torch::tanh(new_cell) * output_gate; 31 | 32 | return {new_h, 33 | new_cell, 34 | input_gate, 35 | output_gate, 36 | candidate_cell, 37 | X, 38 | gate_weights}; 39 | } 40 | 41 | // tanh'(z) = 1 - tanh^2(z) 42 | torch::Tensor d_tanh(torch::Tensor z) { 43 | return 1 - z.tanh().pow(2); 44 | } 45 | 46 | // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0} 47 | torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) { 48 | auto e = z.exp(); 49 | auto mask = (alpha * (e - 1)) < 0; 50 | return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e); 51 | } 52 | 53 | // [[torch::export]] 54 | std::vector lltm_backward( 55 | torch::Tensor grad_h, 56 | torch::Tensor grad_cell, 57 | torch::Tensor new_cell, 58 | torch::Tensor input_gate, 59 | torch::Tensor output_gate, 60 | torch::Tensor candidate_cell, 61 | torch::Tensor X, 62 | torch::Tensor gate_weights, 63 | torch::Tensor weights) { 64 | auto d_output_gate = torch::tanh(new_cell) * grad_h; 65 | auto d_tanh_new_cell = output_gate * grad_h; 66 | auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell; 67 | 68 | auto d_old_cell = d_new_cell; 69 | auto d_candidate_cell = input_gate * 
d_new_cell; 70 | auto d_input_gate = candidate_cell * d_new_cell; 71 | 72 | auto gates = gate_weights.chunk(3, /*dim=*/1); 73 | d_input_gate *= d_sigmoid(gates[0]); 74 | d_output_gate *= d_sigmoid(gates[1]); 75 | d_candidate_cell *= d_elu(gates[2]); 76 | 77 | auto d_gates = 78 | torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1); 79 | 80 | auto d_weights = d_gates.t().mm(X); 81 | auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true); 82 | 83 | auto d_X = d_gates.mm(weights); 84 | const auto state_size = grad_h.size(1); 85 | auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size); 86 | auto d_input = d_X.slice(/*dim=*/1, state_size); 87 | 88 | return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; 89 | } 90 | 91 | LLTM_API int _raise_exception () 92 | { 93 | try { 94 | throw std::runtime_error("Error from LLTM"); 95 | } LLTM_HANDLE_EXCEPTION 96 | return 1; 97 | } 98 | -------------------------------------------------------------------------------- /csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Project name and cmake minimum requirement. 2 | project(lltm) 3 | cmake_minimum_required(VERSION 3.16) 4 | 5 | # We find a LibTorch installation trough the torch package. 6 | # This is the best approach if we want to make sure we are 7 | # targetting the same LibTorch version as used by torch. 8 | execute_process ( 9 | COMMAND Rscript -e "cat(torch::torch_install_path())" 10 | OUTPUT_VARIABLE TORCH_HOME 11 | ) 12 | set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "${TORCH_HOME}") 13 | 14 | # Now that the prefix path is set we can tell cmake to go 15 | # and find Torch. 16 | find_package(Torch REQUIRED) 17 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 18 | 19 | # Here we tell CMake what are the source files of our package. 20 | # If you want to separate your implementation in multiple files 21 | # add their paths after `src/lltm.cpp`, the spearator is a simple 22 | # space. 
23 | set(LLTM_SRC src/lltm.cpp src/exports.cpp) 24 | set(LLTM_HEADERS include/lltm/lltm.h include/lltm/exports.h) 25 | 26 | # On Windows we use module definition files to declare what are 27 | # the exported functions from the library. It's similar to the 28 | # Namespace file in the R ecosystem. 29 | # We need to fill it manually as it's used to link the Rcpp interface 30 | # with the `csrc` library. 31 | if(WIN32) 32 | set(LLTM_SRC ${LLTM_SRC} src/lltm.def) 33 | endif() 34 | 35 | # Tell cmake to build the shared library. 36 | add_library(lltm SHARED ${LLTM_SRC}) 37 | add_library(lltm::library ALIAS lltm) 38 | 39 | # Tell cmake what are the include files. 40 | target_include_directories(lltm PUBLIC 41 | ${PROJECT_SOURCE_DIR}/include 42 | ) 43 | set_property(TARGET lltm 44 | PROPERTY PUBLIC_HEADER ${LLTM_HEADERS}) 45 | 46 | # Tell cmake the libraries we want to link to. 47 | message(STATUS "${TORCH_LIBRARIES}") 48 | target_link_libraries(lltm "${TORCH_LIBRARIES}") 49 | 50 | set_property(TARGET lltm PROPERTY CXX_STANDARD 17) 51 | 52 | add_custom_target(lltmExport 53 | COMMAND Rscript -e "torchexport::export()" 54 | COMMENT "Regenerating export code." 55 | VERBATIM 56 | ) 57 | add_dependencies(lltm lltmExport) 58 | 59 | # Synchronize the headers and the def file with the Rcpp 60 | # interface. 61 | add_custom_command(TARGET lltm POST_BUILD 62 | COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/src/lltm.def ${PROJECT_SOURCE_DIR}/../inst/def/lltm.def 63 | COMMENT "Copied def file to inst folder." 64 | ) 65 | 66 | # Set CPack related code to automatically generate installation bundles. 67 | # The bundle name will have the same version as defined in the R DESCRIPTION 68 | # file. 69 | # The cpack configuration is used by the CI/CD workflows to create the pre-built 70 | # binaries bundles and upload them to the GitHub Releases page.
set(CPACK_GENERATOR ZIP) 72 | execute_process ( 73 | COMMAND Rscript -e "cat(desc::description$new(file = '../../DESCRIPTION')$get('Version'))" 74 | OUTPUT_VARIABLE CPACK_PACKAGE_VERSION 75 | ) 76 | 77 | if(DEFINED CUDA_VERSION_STRING) 78 | set(CPACK_PACKAGE_VERSION ${CPACK_PACKAGE_VERSION}+cu${CUDA_VERSION_STRING}) 79 | else() 80 | set(CPACK_PACKAGE_VERSION ${CPACK_PACKAGE_VERSION}+cpu) 81 | endif() 82 | 83 | 84 | include(CPack) 85 | 86 | set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR}/../inst) 87 | install(TARGETS lltm LIBRARY PUBLIC_HEADER DESTINATION include/lltm) 88 | -------------------------------------------------------------------------------- /src/RcppExports.cpp: -------------------------------------------------------------------------------- 1 | // Generated by using Rcpp::compileAttributes() -> do not edit by hand 2 | // Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 3 | 4 | #include "lltm_types.h" 5 | #include 6 | 7 | using namespace Rcpp; 8 | 9 | #ifdef RCPP_USE_GLOBAL_ROSTREAM 10 | Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); 11 | Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); 12 | #endif 13 | 14 | // rcpp_lltm_forward 15 | torch::TensorList rcpp_lltm_forward(torch::Tensor input, torch::Tensor weights, torch::Tensor bias, torch::Tensor old_h, torch::Tensor old_cell); 16 | RcppExport SEXP _lltm_rcpp_lltm_forward(SEXP inputSEXP, SEXP weightsSEXP, SEXP biasSEXP, SEXP old_hSEXP, SEXP old_cellSEXP) { 17 | BEGIN_RCPP 18 | Rcpp::RObject rcpp_result_gen; 19 | Rcpp::RNGScope rcpp_rngScope_gen; 20 | Rcpp::traits::input_parameter< torch::Tensor >::type input(inputSEXP); 21 | Rcpp::traits::input_parameter< torch::Tensor >::type weights(weightsSEXP); 22 | Rcpp::traits::input_parameter< torch::Tensor >::type bias(biasSEXP); 23 | Rcpp::traits::input_parameter< torch::Tensor >::type old_h(old_hSEXP); 24 | Rcpp::traits::input_parameter< torch::Tensor >::type old_cell(old_cellSEXP); 25 | rcpp_result_gen =
Rcpp::wrap(rcpp_lltm_forward(input, weights, bias, old_h, old_cell)); 26 | return rcpp_result_gen; 27 | END_RCPP 28 | } 29 | // rcpp_lltm_backward 30 | torch::TensorList rcpp_lltm_backward(torch::Tensor grad_h, torch::Tensor grad_cell, torch::Tensor new_cell, torch::Tensor input_gate, torch::Tensor output_gate, torch::Tensor candidate_cell, torch::Tensor X, torch::Tensor gate_weights, torch::Tensor weights); 31 | RcppExport SEXP _lltm_rcpp_lltm_backward(SEXP grad_hSEXP, SEXP grad_cellSEXP, SEXP new_cellSEXP, SEXP input_gateSEXP, SEXP output_gateSEXP, SEXP candidate_cellSEXP, SEXP XSEXP, SEXP gate_weightsSEXP, SEXP weightsSEXP) { 32 | BEGIN_RCPP 33 | Rcpp::RObject rcpp_result_gen; 34 | Rcpp::RNGScope rcpp_rngScope_gen; 35 | Rcpp::traits::input_parameter< torch::Tensor >::type grad_h(grad_hSEXP); 36 | Rcpp::traits::input_parameter< torch::Tensor >::type grad_cell(grad_cellSEXP); 37 | Rcpp::traits::input_parameter< torch::Tensor >::type new_cell(new_cellSEXP); 38 | Rcpp::traits::input_parameter< torch::Tensor >::type input_gate(input_gateSEXP); 39 | Rcpp::traits::input_parameter< torch::Tensor >::type output_gate(output_gateSEXP); 40 | Rcpp::traits::input_parameter< torch::Tensor >::type candidate_cell(candidate_cellSEXP); 41 | Rcpp::traits::input_parameter< torch::Tensor >::type X(XSEXP); 42 | Rcpp::traits::input_parameter< torch::Tensor >::type gate_weights(gate_weightsSEXP); 43 | Rcpp::traits::input_parameter< torch::Tensor >::type weights(weightsSEXP); 44 | rcpp_result_gen = Rcpp::wrap(rcpp_lltm_backward(grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights)); 45 | return rcpp_result_gen; 46 | END_RCPP 47 | } 48 | // lltm_raise_exception 49 | void lltm_raise_exception(); 50 | RcppExport SEXP _lltm_lltm_raise_exception() { 51 | BEGIN_RCPP 52 | Rcpp::RNGScope rcpp_rngScope_gen; 53 | lltm_raise_exception(); 54 | return R_NilValue; 55 | END_RCPP 56 | } 57 | 58 | static const R_CallMethodDef CallEntries[] = { 59 | 
{"_lltm_rcpp_lltm_forward", (DL_FUNC) &_lltm_rcpp_lltm_forward, 5}, 60 | {"_lltm_rcpp_lltm_backward", (DL_FUNC) &_lltm_rcpp_lltm_backward, 9}, 61 | {"_lltm_lltm_raise_exception", (DL_FUNC) &_lltm_lltm_raise_exception, 0}, 62 | {NULL, NULL, 0} 63 | }; 64 | 65 | RcppExport void R_init_lltm(DllInfo *dll) { 66 | R_registerRoutines(dll, NULL, CallEntries, NULL, NULL); 67 | R_useDynamicSymbols(dll, FALSE); 68 | } 69 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/master/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow is overkill for most R packages and 5 | # check-standard.yaml is likely a better choice. 6 | # usethis::use_github_action("check-standard") will install it. 7 | on: 8 | push: 9 | branches: [main, master] 10 | pull_request: 11 | branches: [main, master] 12 | 13 | name: R-CMD-check 14 | 15 | jobs: 16 | 17 | Build-Libs: 18 | 19 | if: github.ref == 'refs/heads/main' 20 | runs-on: ${{ matrix.config.os }} 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | config: 26 | - {os: macOS-latest} 27 | - {os: windows-latest} 28 | - {os: ubuntu-18.04} 29 | 30 | env: 31 | TORCH_INSTALL: 1 32 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 33 | 34 | steps: 35 | - uses: actions/checkout@v2 36 | 37 | - uses: r-lib/actions/setup-r@v1 38 | with: 39 | r-version: ${{ matrix.config.r }} 40 | http-user-agent: ${{ matrix.config.http-user-agent }} 41 | use-public-rspm: true 42 | 43 | - name: Install dependencies 44 | run: | 45 | Rscript -e "install.packages(c('remotes', 'desc', 'rcmdcheck'))" -e "remotes::install_deps(dependencies = TRUE, INSTALL_opts='--no-multiarch')" 46 | Rscript -e "remotes::install_github('mlverse/torchexport')" 47 | 48 | - run: | 49 | cd csrc 50 | 
mkdir build && cd build 51 | cmake .. 52 | cmake --build . --target package --config Release 53 | 54 | - uses: svenstaro/upload-release-action@v2 55 | with: 56 | repo_token: ${{ secrets.GITHUB_TOKEN }} 57 | file: csrc/build/*.zip 58 | overwrite: true 59 | file_glob: true 60 | tag: liblltm 61 | 62 | 63 | R-CMD-check: 64 | runs-on: ${{ matrix.config.os }} 65 | 66 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 67 | 68 | strategy: 69 | fail-fast: false 70 | matrix: 71 | config: 72 | - {os: macOS-latest, r: 'release'} 73 | - {os: windows-latest, r: 'release'} 74 | - {os: ubuntu-18.04, r: 'release'} 75 | - {os: windows-latest, r: '3.6'} 76 | 77 | env: 78 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 79 | R_KEEP_PKG_SOURCE: yes 80 | TORCH_INSTALL: 1 81 | 82 | steps: 83 | - uses: actions/checkout@v2 84 | 85 | - uses: r-lib/actions/setup-pandoc@v1 86 | 87 | - uses: r-lib/actions/setup-r@v1 88 | with: 89 | r-version: ${{ matrix.config.r }} 90 | http-user-agent: ${{ matrix.config.http-user-agent }} 91 | use-public-rspm: true 92 | 93 | 94 | - run: | 95 | Rscript -e "install.packages(c('remotes', 'desc', 'rcmdcheck'))" -e "remotes::install_deps(dependencies = TRUE, INSTALL_opts='--no-multiarch')" 96 | Rscript -e "remotes::install_github('mlverse/torchexport')" 97 | 98 | - run: | 99 | cd csrc 100 | mkdir build && cd build 101 | cmake .. 102 | cmake --build . 
--target package --config Release 103 | Rscript -e "cat('ZIP file:', normalizePath(list.files(pattern='zip', full.names=TRUE)))" 104 | Rscript -e "cat('ENV file:', Sys.getenv('GITHUB_ENV'))" 105 | Rscript -e "writeLines(paste0('LLTM_URL=',normalizePath(list.files(pattern='zip', full.names=TRUE))), Sys.getenv('GITHUB_ENV'))" 106 | 107 | - uses: r-lib/actions/check-r-package@v1 108 | with: 109 | error-on: '"error"' 110 | args: 'c("--no-multiarch", "--no-manual")' 111 | 112 | - name: Show testthat output 113 | if: always() 114 | run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true 115 | shell: bash 116 | 117 | - name: Upload check results 118 | if: failure() 119 | uses: actions/upload-artifact@main 120 | with: 121 | name: ${{ runner.os }}-r${{ matrix.config.r }}-results 122 | path: check 123 | -------------------------------------------------------------------------------- /README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, include = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "man/figures/README-", 12 | out.width = "100%" 13 | ) 14 | ``` 15 | 16 | # lltm 17 | 18 | 19 | 20 | 21 | The goal of lltm is to be a minimal implementation of an extension for [torch](https://github.com/mlverse/torch) that interfaces with the underlying 22 | C++ interface, called LibTorch. 23 | 24 | In this pakage we provide an implementation of a new recurrent unit that is similar 25 | to a LSTM but it lacks a *forget gate* and uses an *Exponential Linear Unit* (ELU) as its internal activation function. Because this unit never forgets, we’ll call it LLTM, or **Long-Long-Term-Memory unit**. 26 | 27 | The example implemented here is a port of the official PyTorch [tutorial](https://pytorch.org/tutorials/advanced/cpp_extension.html) on custom 28 | C++ and CUDA extensions. 
29 | 30 | ## High-Level overview 31 | 32 | Writing C++ extensions for torch requires us to coordinate the communication 33 | between multiple agents in the torch ecossytem. The following diagram is a high-level overview on how they communicate in this package. 34 | 35 | On the torch package side the agents that appear are: 36 | 37 | - **LibTorch**: The PyTorch's C++ interface. This is the library implementing all 38 | the heavy computations and the data structures like tensors. 39 | - **Lantern**: Is a C wrapper for LibTorch and is a part of the torch for R 40 | project. We had to develop Lantern because on Windows LibTorch can only be 41 | compiled with the MSVC compiler while R is compiled with MinGW. Because of 42 | the different compilers, only C interfaces (not C++) are compatible. 43 | - **torchpkg.so**: This is how we are referring to the C++ library, implemented 44 | with Rcpp that allows the R API to make calls to Lantern functions. Another 45 | important feature it provides is custom Rcpp types that allows users to easily 46 | manage memory life time of objects returned by Lantern. 47 | 48 | In the extension side the actors are: 49 | 50 | - **csrc**: What we are calling `csrc` here is the equivalent to Lantern in the 51 | torch project. It's a C interface for calling functions from LibTorch that 52 | implement the desidered extension functionality. The library produced here 53 | must also be compiled with MSVC on Windows thus the C interface is required. 54 | - **lltm.so**: This is the C++ library implemented using Rcpp that allows the R 55 | API to call the `csrc` functionality. Here, in general, we want to use the 56 | `torchpkg.so` features to manage memory instead of re-implementing that functionality. 57 | 58 | [![](man/figures/high-level.png)](https://excalidraw.com/#json=6114208240369664,J9vJ8KK7VOBqgn7Nex5Huw) 59 | 60 | ## Project structure 61 | 62 | - **csrc**: The directory containing library that will call efficient LibTorch code. 
See the section `csrc` for details. 63 | - **src**: Rcpp code that interfaces the `csrc` library and exports functionality 64 | to the R API. 65 | - **R/package.R**: Definitions for correctly downloading pre-built binaries, 66 | and dynamically loading the `csrc` library as well as the C++ library. 67 | 68 | ### csrc: Implementing the operators and their C wrappers. 69 | 70 | - **CMakeLists.txt**: The first important file that you should get familiar with in this directory is the [CMakeLists.txt](https://github.com/mlverse/lltm/blob/main/csrc/CMakeLists.txt) file. This is the [CMake](https://cmake.org/) configuration file defining how the 71 | project must be compiled and its dependencies. You can refer to comments in 72 | the [file](https://github.com/mlverse/lltm/blob/main/csrc/CMakeLists.txt) for almost line by line explanation of definitions. 73 | 74 | - **csrc/src/lltm.cpp**: In this file we define the LibTorch implementation of the 75 | operations we want to export. We can use as many functions as we want in the implementation 76 | and we mark the functions we want to make available in the R package with `// [[torch::export]]`, similar to what we do when exporting functions with Rcpp. For example 77 | we define the `lltm_forward` implementation with: (For details on the `lltm_forward` 78 | implementation refer to the [official guide](https://pytorch.org/tutorials/advanced/cpp_extension.html).) 79 | 80 | The `// [[torch::export]]` marks will allow [torchexport](https://github.com/mlverse/torchexport) that is called during when 81 | building with cmake to autogenerate C wrappers necessary to handle errors and 82 | to correctly pass data between this library and the R package. 
83 | 84 | ```cpp 85 | // [[torch::export]] 86 | std::vector lltm_forward( 87 | torch::Tensor input, 88 | torch::Tensor weights, 89 | torch::Tensor bias, 90 | torch::Tensor old_h, 91 | torch::Tensor old_cell) { 92 | auto X = torch::cat({old_h, input}, /*dim=*/1); 93 | 94 | auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); 95 | auto gates = gate_weights.chunk(3, /*dim=*/1); 96 | 97 | auto input_gate = torch::sigmoid(gates[0]); 98 | auto output_gate = torch::sigmoid(gates[1]); 99 | auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0); 100 | 101 | auto new_cell = old_cell + candidate_cell * input_gate; 102 | auto new_h = torch::tanh(new_cell) * output_gate; 103 | 104 | return {new_h, 105 | new_cell, 106 | input_gate, 107 | output_gate, 108 | candidate_cell, 109 | X, 110 | gate_weights}; 111 | } 112 | ``` 113 | 114 | - **csrc/src/exports.cpp**: This file is autogenerated by `torchexport` and 115 | should not be modified manually. It wrapps the function that uses LibTorch's 116 | API into C API functions that can be called in the R/Rcpp side. 117 | 118 | - **csrc/include/lltm/exports.h** This file includes declarations used by 119 | functions defined in `exports.cpp`. It should always be included in `lltm.h`. 120 | Note that this file is also autogenerated. 121 | 122 | - **csrc/src/lltm.def**: This file is automaticaly generated by a custom CMake command. 123 | It lists the functions from `lltm.cpp` that we 124 | want to export. This is only required for Windows, but it's a good practice to 125 | keep it up to date. 
See more information on Module definition files in this 126 | [link](https://docs.microsoft.com/en-us/cpp/build/reference/module-definition-dot-def-files?view=msvc-160) 127 | 128 | For example, the current definition is: 129 | 130 | ``` 131 | LIBRARY LLTM 132 | EXPORTS 133 | _lltm_forward 134 | _lltm_backward 135 | ``` 136 | 137 | - **csrc/include/lltm/lltm.h**: In a minimal setup this file only needs to 138 | include the `lltm/exports.h` headers that is auto-generated by 139 | [`torchexport`](https:://github.com/mlverse/torchexport). 140 | You might want add other function declarations here, if for some reason 141 | you had to bypass the code autogeneration. 142 | 143 | 144 | The library implemented in `csrc` can be compiled with CMake. We use the following 145 | commands to compile and install it locally: 146 | 147 | ``` 148 | cd csrc && mkdir build 149 | cmake .. && cmake --build . --target install --config Release 150 | ``` 151 | 152 | ### src: Wrapping the library with Rcpp 153 | 154 | Now that we implemented the operators that we wanted to call from R, we can now 155 | implement the Rcpp wrappers that will allow us to call those operators from R. 156 | 157 | - **src/exports.cpp**: This file is autogenerated and defines Rcpp wrappers for 158 | the functions that have been marked with `[[torch::export]]` in your library. 159 | The wrappers defined in this file take R objects 160 | and convert them to the correct C type that we need to pass to the C library. 161 | Remember that the C library return `void*` pointers and we need to make sure to 162 | free this objects when they are no longer in use, otherwise we will leak memory. 163 | The `torch.h` headers provides Rcpp extension types that act like *smart pointers* 164 | and make sure that the objects created in the C library are correctly freed when 165 | they are no longer in use. 
The types implemented in `torch.h` also implement 166 | convertion from and to `SEXP`s so we don't need to implement them on our own. 167 | 168 | You can find all the available types in the `torch` namespace available when 169 | you include ``. 170 | 171 | - **src/lltm.cpp**: In a minimal setup this file only needs to include the header 172 | files from the torch package as well as from your library and specify a few 173 | variables that make sure the implementations are included. It also must define 174 | a `host_exception_handler` that is used to correctly raise exceptions from your 175 | C library to the R runtime - in general you don't need to modify the one that's 176 | already defined in this template. 177 | 178 | ```cpp 179 | #include 180 | #define LLTM_HEADERS_ONLY // should only be defined in a single file 181 | #include 182 | #define TORCH_IMPL // should only be defined in a single file 183 | #define IMPORT_TORCH // should only be defined in a single file 184 | #include 185 | ``` 186 | 187 | - **src/Makevars.win**: On Windows, the normal compilation workflow wouldn't work 188 | as Windows wouldn't be able to find the implementations of `_lltm_forward` (as it 189 | only sees the headers), so we convert the `.def` file created in `csrc` to a `.lib` 190 | file and use this as an argument to the linker. That's what `Makevars.win` implements. 191 | In most cases you won't need to modify this file. 192 | 193 | 194 | ### R API 195 | 196 | Now the Rcpp wrappers are implemented and exported you have now access to `lltm_forward` 197 | in the R side. 198 | 199 | - **R/lltm.R**: In this package we wanted to provide a new autograd function and 200 | a `nn_module` that uses it and we implemented it in this file. This is normal 201 | R code and we won't discuss the actual implementation. 202 | 203 | ## Packaging 204 | 205 | It's not trivial to package torch extensions because they can't be entirely built 206 | on CRAN machines. 
We would need to include pre-built binaries in the package tarball 207 | but for security reasons that's not accepted on CRAN. 208 | 209 | In this package we implement a suggested way of packaging torch extensions that makes 210 | it really easy for users to install your package without having to use custom 211 | installation steps or building libraries from source. The diagram below shows an 212 | overview of the packaging process. 213 | 214 | ![](man/figures/packaging.png) 215 | 216 | - **R/package.R**: implements the suggested installation logic - including downloading 217 | from GitHub Releases and dynamically loading the shared libraries. 218 | 219 | - **.github/workflows/R-CMD-check.yaml**: the job called *Build-Libs* implements 220 | the logic for building the binaries from `csrc` for each operating system and 221 | uploading to GH Releases. 222 | 223 | ## Installation 224 | 225 | ~~You can install the released version of lltm from [CRAN](https://CRAN.R-project.org) with:~~ 226 | 227 | ``` r 228 | install.packages("lltm") 229 | ``` 230 | 231 | And the development version from [GitHub](https://github.com/) with: 232 | 233 | ``` r 234 | # install.packages("devtools") 235 | devtools::install_github("mlverse/lltm") 236 | ``` 237 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # lltm 5 | 6 | 7 | 8 | 9 | The goal of lltm is to be a minimal implementation of an extension for 10 | [torch](https://github.com/mlverse/torch) that interfaces with the 11 | underlying C++ interface, called LibTorch. 12 | 13 | In this pakage we provide an implementation of a new recurrent unit that 14 | is similar to a LSTM but it lacks a *forget gate* and uses an 15 | *Exponential Linear Unit* (ELU) as its internal activation function. 16 | Because this unit never forgets, we’ll call it LLTM, or 17 | **Long-Long-Term-Memory unit**. 
18 | 19 | The example implemented here is a port of the official PyTorch 20 | [tutorial](https://pytorch.org/tutorials/advanced/cpp_extension.html) on 21 | custom C++ and CUDA extensions. 22 | 23 | ## High-Level overview 24 | 25 | Writing C++ extensions for torch requires us to coordinate the 26 | communication between multiple agents in the torch ecossytem. The 27 | following diagram is a high-level overview on how they communicate in 28 | this package. 29 | 30 | On the torch package side the agents that appear are: 31 | 32 | - **LibTorch**: The PyTorch’s C++ interface. This is the library 33 | implementing all the heavy computations and the data structures like 34 | tensors. 35 | - **Lantern**: Is a C wrapper for LibTorch and is a part of the torch 36 | for R project. We had to develop Lantern because on Windows LibTorch 37 | can only be 38 | compiled with the MSVC compiler while R is compiled with MinGW. 39 | Because of the different compilers, only C interfaces (not C++) are 40 | compatible. 41 | - **torchpkg.so**: This is how we are referring to the C++ library, 42 | implemented with Rcpp that allows the R API to make calls to Lantern 43 | functions. Another important feature it provides is custom Rcpp 44 | types that allows users to easily manage memory life time of objects 45 | returned by Lantern. 46 | 47 | In the extension side the actors are: 48 | 49 | - **csrc**: What we are calling `csrc` here is the equivalent to 50 | Lantern in the torch project. It’s a C interface for calling 51 | functions from LibTorch that implement the desidered extension 52 | functionality. The library produced here must also be compiled with 53 | MSVC on Windows thus the C interface is required. 54 | - **lltm.so**: This is the C++ library implemented using Rcpp that 55 | allows the R API to call the `csrc` functionality. Here, in general, 56 | we want to use the `torchpkg.so` features to manage memory instead 57 | of re-implementing that functionality. 
58 | 59 | [![](man/figures/high-level.png)](https://excalidraw.com/#json=6114208240369664,J9vJ8KK7VOBqgn7Nex5Huw) 60 | 61 | ## Project structure 62 | 63 | - **csrc**: The directory containing library that will call efficient 64 | LibTorch code. See the section `csrc` for details. 65 | - **src**: Rcpp code that interfaces the `csrc` library and exports 66 | functionality to the R API. 67 | - **R/package.R**: Definitions for correctly downloading pre-built 68 | binaries, and dynamically loading the `csrc` library as well as the 69 | C++ library. 70 | 71 | ### csrc: Implementing the operators and their C wrappers. 72 | 73 | - **CMakeLists.txt**: The first important file that you should get 74 | familiar with in this directory is the 75 | [CMakeLists.txt](https://github.com/mlverse/lltm/blob/main/csrc/CMakeLists.txt) 76 | file. This is the [CMake](https://cmake.org/) configuration file 77 | defining how the project must be compiled and its dependencies. You 78 | can refer to comments in the 79 | [file](https://github.com/mlverse/lltm/blob/main/csrc/CMakeLists.txt) 80 | for almost line by line explanation of definitions. 81 | 82 | - **csrc/src/lltm.cpp**: In this file we define the LibTorch 83 | implementation of the operations we want to export. We can use as 84 | many functions as we want in the implementation and we mark the 85 | functions we want to make available in the R package with 86 | `// [[torch::export]]`, similar to what we do when exporting 87 | functions with Rcpp. For example we define the `lltm_forward` 88 | implementation with: (For details on the `lltm_forward` 89 | implementation refer to the [official 90 | guide](https://pytorch.org/tutorials/advanced/cpp_extension.html).) 
91 | 92 | The `// [[torch::export]]` marks will allow 93 | [torchexport](https://github.com/mlverse/torchexport) that is called 94 | during when building with cmake to autogenerate C wrappers necessary 95 | to handle errors and to correctly pass data between this library and 96 | the R package. 97 | 98 | ``` cpp 99 | // [[torch::export]] 100 | std::vector lltm_forward( 101 | torch::Tensor input, 102 | torch::Tensor weights, 103 | torch::Tensor bias, 104 | torch::Tensor old_h, 105 | torch::Tensor old_cell) { 106 | auto X = torch::cat({old_h, input}, /*dim=*/1); 107 | 108 | auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); 109 | auto gates = gate_weights.chunk(3, /*dim=*/1); 110 | 111 | auto input_gate = torch::sigmoid(gates[0]); 112 | auto output_gate = torch::sigmoid(gates[1]); 113 | auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0); 114 | 115 | auto new_cell = old_cell + candidate_cell * input_gate; 116 | auto new_h = torch::tanh(new_cell) * output_gate; 117 | 118 | return {new_h, 119 | new_cell, 120 | input_gate, 121 | output_gate, 122 | candidate_cell, 123 | X, 124 | gate_weights}; 125 | } 126 | ``` 127 | 128 | - **csrc/src/exports.cpp**: This file is autogenerated by 129 | `torchexport` and should not be modified manually. It wrapps the 130 | function that uses LibTorch’s API into C API functions that can be 131 | called in the R/Rcpp side. 132 | 133 | - **csrc/include/lltm/exports.h** This file includes declarations used 134 | by functions defined in `exports.cpp`. It should always be included 135 | in `lltm.h`. Note that this file is also autogenerated. 136 | 137 | - **csrc/src/lltm.def**: This file is automaticaly generated by a 138 | custom CMake command. It lists the functions from `lltm.cpp` that we 139 | want to export. This is only required for Windows, but it’s a good 140 | practice to keep it up to date. 
See more information on Module 141 | definition files in this 142 | [link](https://docs.microsoft.com/en-us/cpp/build/reference/module-definition-dot-def-files?view=msvc-160) 143 | 144 | For example, the current definition is: 145 | 146 | LIBRARY LLTM 147 | EXPORTS 148 | _lltm_forward 149 | _lltm_backward 150 | 151 | - **csrc/include/lltm/lltm.h**: In a minimal setup this file only 152 | needs to include the `lltm/exports.h` headers that is auto-generated 153 | by [`torchexport`](https:://github.com/mlverse/torchexport). You 154 | might want add other function declarations here, if for some reason 155 | you had to bypass the code autogeneration. 156 | 157 | The library implemented in `csrc` can be compiled with CMake. We use the 158 | following commands to compile and install it locally: 159 | 160 | cd csrc && mkdir build 161 | cmake .. && cmake --build . --target install --config Release 162 | 163 | ### src: Wrapping the library with Rcpp 164 | 165 | Now that we implemented the operators that we wanted to call from R, we 166 | can now implement the Rcpp wrappers that will allow us to call those 167 | operators from R. 168 | 169 | - **src/exports.cpp**: This file is autogenerated and defines Rcpp 170 | wrappers for the functions that have been marked with 171 | `[[torch::export]]` in your library. The wrappers defined in this 172 | file take R objects and convert them to the correct C type that we 173 | need to pass to the C library. Remember that the C library return 174 | `void*` pointers and we need to make sure to free this objects when 175 | they are no longer in use, otherwise we will leak memory. The 176 | `torch.h` headers provides Rcpp extension types that act like *smart 177 | pointers* and make sure that the objects created in the C library 178 | are correctly freed when they are no longer in use. The types 179 | implemented in `torch.h` also implement convertion from and to 180 | `SEXP`s so we don’t need to implement them on our own. 
181 | 182 | You can find all the available types in the `torch` namespace 183 | available when you include ``. 184 | 185 | - **src/lltm.cpp**: In a minimal setup this file only needs to include 186 | the header files from the torch package as well as from your library 187 | and specify a few variables that make sure the implementations are 188 | included. It also must define a `host_exception_handler` that is 189 | used to correctly raise exceptions from your C library to the R 190 | runtime - in general you don’t need to modify the one that’s already 191 | defined in this template. 192 | 193 | ``` cpp 194 | #include 195 | #define LLTM_HEADERS_ONLY // should only be defined in a single file 196 | #include 197 | #define TORCH_IMPL // should only be defined in a single file 198 | #define IMPORT_TORCH // should only be defined in a single file 199 | #include 200 | ``` 201 | 202 | - **src/Makevars.win**: On Windows, the normal compilation workflow 203 | wouldn’t work as Windows wouldn’t be able to find the 204 | implementations of `_lltm_forward` (as it only sees the headers), so 205 | we convert the `.def` file created in `csrc` to a `.lib` file and 206 | use this as an argument to the linker. That’s what `Makevars.win` 207 | implements. In most cases you won’t need to modify this file. 208 | 209 | ### R API 210 | 211 | Now the Rcpp wrappers are implemented and exported you have now access 212 | to `lltm_forward` in the R side. 213 | 214 | - **R/lltm.R**: In this package we wanted to provide a new autograd 215 | function and a `nn_module` that uses it and we implemented it in 216 | this file. This is normal R code and we won’t discuss the actual 217 | implementation. 218 | 219 | ## Packaging 220 | 221 | It’s not trivial to package torch extensions because they can’t be 222 | entirely built on CRAN machines. We would need to include pre-built 223 | binaries in the package tarball but for security reasons that’s not 224 | accepted on CRAN. 
225 | 226 | In this package we implement a suggested way of packaging torch 227 | extensions that makes it really easy for users to install your package 228 | without having to use custom installation steps or building libraries 229 | from source. The diagram below shows an overview of the packaging 230 | process. 231 | 232 | ![](man/figures/packaging.png) 233 | 234 | - **R/package.R**: implements the suggested installation logic - 235 | including downloading from GitHub Releases and dynamically loading 236 | the shared libraries. 237 | 238 | - **.github/workflows/R-CMD-check.yaml**: the job called *Build-Libs* 239 | implements the logic for building the binaries from `csrc` for each 240 | operating system and uploading to GH Releases. 241 | 242 | ## Installation 243 | 244 | ~~You can install the released version of lltm from 245 | [CRAN](https://CRAN.R-project.org) with:~~ 246 | 247 | ``` r 248 | install.packages("lltm") 249 | ``` 250 | 251 | And the development version from [GitHub](https://github.com/) with: 252 | 253 | ``` r 254 | # install.packages("devtools") 255 | devtools::install_github("mlverse/lltm") 256 | ``` 257 | --------------------------------------------------------------------------------