├── .Rbuildignore ├── .github └── workflows │ └── R-CMD-check.yml ├── .gitignore ├── CRAN-SUBMISSION ├── DESCRIPTION ├── NAMESPACE ├── NEWS.md ├── R ├── columns.R ├── install.R ├── layer.R ├── load.R ├── package.R ├── recipe.R └── reexports.R ├── README.md ├── cran-comments.md ├── man ├── bake.step_pretrained_text_embedding.Rd ├── hub_image_embedding_column.Rd ├── hub_load.Rd ├── hub_sparse_text_embedding_column.Rd ├── hub_text_embedding_column.Rd ├── install_tfhub.Rd ├── layer_hub.Rd ├── pipe.Rd ├── prep.step_pretrained_text_embedding.Rd ├── reexports.Rd └── step_pretrained_text_embedding.Rd ├── tests ├── testthat.R └── testthat │ ├── test-columns.R │ ├── test-layer-hub.R │ ├── test-load.R │ ├── test-recipe.R │ └── utils.R ├── tfhub.Rproj └── vignettes ├── .gitignore ├── archived ├── feature_column.R └── feature_column.Rmd ├── examples ├── .gitignore ├── biggan_image_generation.R ├── biggan_image_generation.Rmd ├── image_classification.R ├── image_classification.Rmd ├── index.Rmd ├── recipes.R ├── recipes.Rmd ├── text_classification.R ├── text_classification.Rmd └── using_bert_tfhub.R ├── hub-with-keras.Rmd ├── intro.Rmd └── key-concepts.Rmd /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^examples$ 4 | 5 | 6 | ^\.travis\.yml$ 7 | ^README\.Rmd$ 8 | ^appveyor\.yml$ 9 | 10 | ^.*saved_model.pb 11 | ^.*train.csv.zip 12 | ^doc$ 13 | ^Meta$ 14 | ^.*flower_photos.*$ 15 | ^.*train_images.*$ 16 | ^.*train\.zip$ 17 | ^.*train\.csv$ 18 | ^\.github.*$ 19 | ^cran-comments\.md$ 20 | ^.*\.zip$ 21 | ^CRAN-SUBMISSION$ 22 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yml: -------------------------------------------------------------------------------- 1 | 2 | name: R-CMD-check 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | schedule: 10 | - cron: '0 1 * * *' 11 | 12 | 13 | defaults: 14 | run: 15 | shell: Rscript {0} 16 | 17 | 18 | jobs: 19 | R-CMD-check: 20 | runs-on: ${{ matrix.os }} 21 | 22 | name: ${{ matrix.os }} (TF ${{ matrix.tf }}) (TFHUB ${{ matrix.tfhub }}) 23 | 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | include: 28 | - { os: windows-latest, tf: 'release', tfhub: 'release', r: 'release'} 29 | - { os: macOS-latest , tf: 'release', tfhub: 'release', r: 'release'} 30 | - { os: ubuntu-latest , tf: 'release', tfhub: 'release', r: 'release'} 31 | - { os: ubuntu-latest , tf: 'nightly', tfhub: 'release', r: 'release', allow_failure: true} 32 | 33 | env: 34 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: 'true' 35 | R_COMPILE_AND_INSTALL_PACKAGES: 'never' 36 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 37 | 38 | steps: 39 | - uses: actions/checkout@v2 40 | 41 | - uses: r-lib/actions/setup-r@v2 42 | with: 43 | r-version: ${{ matrix.r }} 44 | use-public-rspm: true 45 | Ncpus: '2' 46 | 47 | - uses: r-lib/actions/setup-pandoc@v2 48 | 49 | - uses: r-lib/actions/setup-r-dependencies@v2 50 | with: 51 | extra-packages: rcmdcheck remotes 52 | 53 | - name: Install dev r-tensorflow 54 | run: remotes::install_github(paste0("rstudio/", c("tensorflow", "keras"))) 55 | 56 | - name: Install Miniconda 57 | run: reticulate::install_miniconda() 58 | 59 | - name: Install Tensorflow 60 | # run: Rscript -e 'keras::install_keras(version = "${{ matrix.tf }}-cpu")' 61 | run: | 62 | tensorflow::install_tensorflow( 63 | version = "${{ matrix.tf }}-cpu", 64 | extra_packages = c("Pillow", "scipy")) 65 | 66 | - name: Install tfhub r-pkg 67 | run: 
remotes::install_local()
68 | 
69 |       - name: Install tfhub py module
70 |         run: tfhub::install_tfhub("${{ matrix.tfhub }}")
71 | 
72 |       - name: Check
73 |         continue-on-error: ${{ matrix.allow_failure }}
74 |         run: rcmdcheck::rcmdcheck(args = '--no-manual', error_on = 'warning', check_dir = 'check')
75 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .Rproj.user
2 | .Rhistory
3 | .RData
4 | .Ruserdata
5 | examples/text_classification/aclimdb
6 | inst/doc
7 | .DS_Store
8 | my_module
9 | train.csv.zip
10 | doc
11 | Meta
12 | flower_photos
13 | jigsaw-toxic-comment-classification-challenge.zip
14 | *.zip
15 | 
--------------------------------------------------------------------------------
/CRAN-SUBMISSION:
--------------------------------------------------------------------------------
1 | Version: 0.8.1
2 | Date: 2021-12-15 14:25:15 UTC
3 | SHA: 3dfe200753fe741eaaf893d9e403e9789b0fa30d
4 | 
--------------------------------------------------------------------------------
/DESCRIPTION:
--------------------------------------------------------------------------------
1 | Package: tfhub
2 | Type: Package
3 | Title: Interface to 'TensorFlow' Hub
4 | Version: 0.8.1.9000
5 | Authors@R: c(
6 |     person("Tomasz", "Kalinowski", role = c("aut", "cre"),
7 |            email = "tomasz.kalinowski@rstudio.com"),
8 |     person("Daniel", "Falbel", role = c("aut"),
9 |            email = "daniel@rstudio.com"),
10 |     person("JJ", "Allaire", role = c("aut"),
11 |            email = "jj@rstudio.com"),
12 |     person("RStudio", role = c("cph", "fnd")),
13 |     person(family = "Google Inc.", role = c("cph"))
14 |     )
15 | Description: 'TensorFlow' Hub is a library for the publication, discovery, and
16 |     consumption of reusable parts of machine learning models. A module is a
17 |     self-contained piece of a 'TensorFlow' graph, along with its weights and
18 |     assets, that can be reused across different tasks in a process known as
19 |     transfer learning. Transfer learning can train a model with a smaller dataset,
20 |     improve generalization, and speed up training.
21 | License: Apache License 2.0
22 | URL: https://github.com/rstudio/tfhub
23 | BugReports: https://github.com/rstudio/tfhub/issues
24 | SystemRequirements: TensorFlow >= 2.0 (https://www.tensorflow.org/)
25 | Encoding: UTF-8
26 | LazyData: true
27 | RoxygenNote: 7.1.2
28 | Imports:
29 |     reticulate (>= 1.9.0.9002),
30 |     tensorflow (>= 1.8.0.9006),
31 |     magrittr,
32 |     rstudioapi (>= 0.7),
33 |     vctrs
34 | Suggests:
35 |     testthat (>= 2.1.0),
36 |     knitr,
37 |     tfestimators,
38 |     keras,
39 |     rmarkdown,
40 |     callr,
41 |     recipes,
42 |     tibble,
43 |     abind,
44 |     fs
45 | VignetteBuilder: knitr
46 | Config/reticulate:
47 |   list(
48 |     packages = list(
49 |       list(package = "tensorflow_hub", pip = TRUE)
50 |     )
51 |   )
52 | 
--------------------------------------------------------------------------------
/NAMESPACE:
--------------------------------------------------------------------------------
1 | # Generated by roxygen2: do not edit by hand
2 | 
3 | export("%>%")
4 | export(hub_image_embedding_column)
5 | export(hub_load)
6 | export(hub_sparse_text_embedding_column)
7 | export(hub_text_embedding_column)
8 | export(install_tensorflow)
9 | export(install_tfhub)
10 | export(layer_hub)
11 | export(shape)
12 | export(step_pretrained_text_embedding)
13 | export(tf)
14 | importFrom(magrittr,"%>%")
15 | importFrom(tensorflow,install_tensorflow)
16 | importFrom(tensorflow,shape)
17 | importFrom(tensorflow,tf)
18 | 
--------------------------------------------------------------------------------
/NEWS.md:
--------------------------------------------------------------------------------
1 | # tfhub (development version)
2 | 
3 | # tfhub 0.8.1
4 | 
5 | - `install_tfhub()` now defaults to the release version of tensorflow_hub.
6 | - Removed the {pins} dependency.
7 | 
8 | # tfhub 0.8.0
9 | 
10 | * Added a `NEWS.md` file to track changes to the package.
11 | 
--------------------------------------------------------------------------------
/R/columns.R:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | #' Module to construct a dense representation from a text feature.
4 | #'
5 | #' This feature column can be used on an input feature whose values are strings of
6 | #' arbitrary size.
7 | #'
8 | #' @inheritParams hub_sparse_text_embedding_column
9 | #'
10 | #' @export
11 | hub_text_embedding_column <- function(key, module_spec, trainable = FALSE) {
12 |   tfhub$text_embedding_column(
13 |     key = key,
14 |     module_spec = module_spec,
15 |     trainable = trainable
16 |   )
17 | }
18 | 
19 | #' Module to construct dense representations from sparse text features.
20 | #'
21 | #' The input to this feature column is a batch of multiple strings with
22 | #' arbitrary size, assuming the input is a SparseTensor.
23 | #'
24 | #' This type of feature column is typically suited for modules that operate
25 | #' on pre-tokenized text to produce token level embeddings which are combined
26 | #' with the combiner into a text embedding. The combiner always treats the tokens
27 | #' as a bag of words rather than a sequence.
28 | #'
29 | #' The output (i.e., transformed input layer) is a DenseTensor, with
30 | #' shape [batch_size, num_embedding_dim].
31 | #'
32 | #' @param key A string or [feature_column](https://tensorflow.rstudio.com/tfestimators/articles/feature_columns.html)
33 | #' identifying the text feature.
34 | #' @param module_spec A string handle or a _ModuleSpec identifying the module.
35 | #' @param combiner a string specifying reducing op for embeddings in the same Example.
36 | #' Currently, 'mean', 'sqrtn', 'sum' are supported. Using `combiner = NULL` is
37 | #' undefined.
38 | #' @param default_value default value for Examples where the text feature is empty.
39 | #' Note, it's recommended to have default_value consistent with OOV tokens, in case
40 | #' there was special handling of OOV in the text module. If `NULL`, the text
41 | #' feature is assumed to be non-empty for each Example.
42 | #' @param trainable Whether or not the Module is trainable. `FALSE` by default,
43 | #' meaning the pre-trained weights are frozen. This is different from the ordinary
44 | #' `tf.feature_column.embedding_column()`, but that one is intended for training
45 | #' from scratch.
46 | #'
47 | #' @export
48 | hub_sparse_text_embedding_column <- function(key, module_spec, combiner,
49 |                                              default_value, trainable = FALSE) {
50 |   tfhub$sparse_text_embedding_column(
51 |     key = key,
52 |     module_spec = module_spec,
53 |     combiner = combiner,
54 |     default_value = default_value,
55 |     trainable = trainable
56 |   )
57 | }
58 | 
59 | 
60 | #' Module to construct a dense 1-D representation from the pixels of images.
61 | #'
62 | #' @inheritParams hub_sparse_text_embedding_column
63 | #'
64 | #' @details
65 | #' This feature column can be used on images, represented as float32 tensors of RGB pixel
66 | #' data in the range [0,1].
67 | #'
68 | #' @export
69 | hub_image_embedding_column <- function(key, module_spec) {
70 |   tfhub$image_embedding_column(
71 |     key = key,
72 |     module_spec = module_spec
73 |   )
74 | }
75 | 
76 | 
77 | 
--------------------------------------------------------------------------------
/R/install.R:
--------------------------------------------------------------------------------
1 | 
2 | #' Install TensorFlow Hub
3 | #'
4 | #' This function is used to install the TensorFlow Hub python module.
5 | #'
6 | #' @param version version of TensorFlow Hub to be installed.
7 | #' @param ... other arguments passed to [reticulate::py_install()].
8 | #' @param restart_session Restart R session after installing (note this will
9 | #' only occur within RStudio).
10 | #'
11 | #' @export
12 | install_tfhub <- function(version = "release", ..., restart_session = TRUE) {
13 | 
14 |   if (version == "nightly")
15 |     module_string <- "tf-hub-nightly"
16 |   else if (is.null(version) || version %in% c("release", "default", ""))
17 |     module_string <- "tensorflow_hub"
18 |   else
19 |     module_string <- paste0("tensorflow_hub==", version)
20 | 
21 |   reticulate::py_install(packages = module_string, pip = TRUE, ...)
22 | 
23 |   if (restart_session && rstudioapi::hasFun("restartSession"))
24 |     rstudioapi::restartSession()
25 | }
26 | 
--------------------------------------------------------------------------------
/R/layer.R:
--------------------------------------------------------------------------------
1 | #' Hub Layer
2 | #'
3 | #' Wraps a Hub module (or a similar callable) for TF2 as a Keras Layer.
4 | #'
5 | #' This layer wraps a callable object for use as a Keras layer. The callable
6 | #' object can be passed directly, or be specified by a string with a handle
7 | #' that gets passed to `hub_load()`.
8 | #'
9 | #' The callable object is expected to follow the conventions detailed below.
10 | #' (These are met by TF2-compatible modules loaded from TensorFlow Hub.)
11 | #'
12 | #' The callable is invoked with a single positional argument set to one tensor or
13 | #' a list of tensors containing the inputs to the layer. If the callable accepts
14 | #' a training argument, a boolean is passed for it.
It is `TRUE` if this layer 15 | #' is marked trainable and called for training. 16 | #' 17 | #' If present, the following attributes of callable are understood to have special 18 | #' meanings: variables: a list of all tf.Variable objects that the callable depends on. 19 | #' trainable_variables: those elements of variables that are reported as trainable 20 | #' variables of this Keras Layer when the layer is trainable. regularization_losses: 21 | #' a list of callables to be added as losses of this Keras Layer when the layer is 22 | #' trainable. Each one must accept zero arguments and return a scalar tensor. 23 | #' 24 | #' @param object Model or layer object 25 | #' @param handle a callable object (subject to the conventions above), or a string 26 | #' for which `hub_load()` returns such a callable. A string is required to save 27 | #' the Keras config of this Layer. 28 | #' @param trainable Boolean controlling whether this layer is trainable. 29 | #' @param arguments optionally, a list with additional keyword arguments passed to 30 | #' the callable. These must be JSON-serializable to save the Keras config of 31 | #' this layer. 32 | #' @param ... Other arguments that are passed to the TensorFlow Hub module. 33 | #' 34 | #' @examples 35 | #' 36 | #' \dontrun{ 37 | #' 38 | #' library(keras) 39 | #' 40 | #' model <- keras_model_sequential() %>% 41 | #' layer_hub( 42 | #' handle = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4", 43 | #' input_shape = c(224, 224, 3) 44 | #' ) %>% 45 | #' layer_dense(1) 46 | #' 47 | #' } 48 | #' 49 | #' @export 50 | layer_hub <- function(object, handle, trainable = FALSE, arguments = NULL, ...) { 51 | 52 | args <- list(...) 53 | 54 | if (!is.null(args$input_shape)) 55 | args$input_shape <- lapply(args$input_shape, as_nullable_integer) 56 | 57 | keras::create_layer( 58 | tfhub$KerasLayer, 59 | object, 60 | append( 61 | list( 62 | handle = handle, 63 | trainable = trainable, 64 | arguments = arguments 65 | ), 66 | args 67 | ) 68 | ) 69 | } 70 | 71 | as_nullable_integer <- function(x) { 72 | if (is.null(x)) 73 | x 74 | else 75 | as.integer(x) 76 | } 77 | -------------------------------------------------------------------------------- /R/load.R: -------------------------------------------------------------------------------- 1 | #' Hub Load 2 | #' 3 | #' Loads a module from a handle. 4 | #' 5 | #' Currently this method is fully supported only with Tensorflow 2.x and with 6 | #' modules created by calling `export_savedmodel`. The method works in 7 | #' both eager and graph modes. 8 | #' 9 | #' Depending on the type of handle used, the call may involve downloading a 10 | #' TensorFlow Hub module to a local cache location specified by the 11 | #' `TFHUB_CACHE_DIR` environment variable. If a copy of the module is already 12 | #' present in the TFHUB_CACHE_DIR, the download step is skipped. 13 | #' 14 | #' Currently, three types of module handles are supported: 1) Smart URL resolvers 15 | #' such as tfhub.dev, e.g.: https://tfhub.dev/google/nnlm-en-dim128/1. 2) A directory 16 | #' on a file system supported by Tensorflow containing module files. This may include 17 | #' a local directory (e.g. /usr/local/mymodule) or a Google Cloud Storage bucket 18 | #' (gs://mymodule). 3) A URL pointing to a TGZ archive of a module, e.g. 19 | #' https://example.com/mymodule.tar.gz. 20 | #' 21 | #' @param handle (string) the Module handle to resolve. 22 | #' @param tags A set of strings specifying the graph variant to use, if loading 23 | #' from a v1 module. 
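#' @return The loaded module. For TF2-style modules this is typically a callable
#' object that can be used directly, as in the example below.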
24 | #'
25 | #' @examples
26 | #'
27 | #' \dontrun{
28 | #'
29 | #' model <- hub_load('https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4')
30 | #'
31 | #' }
32 | #'
33 | #' @export
34 | hub_load <- function(handle, tags = NULL) {
35 |   tfhub$load(handle, tags)
36 | }
37 | 
38 | 
39 | 
40 | 
41 | 
42 | 
43 | 
44 | 
--------------------------------------------------------------------------------
/R/package.R:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # main tfhub module
4 | tfhub <- NULL
5 | 
6 | .onLoad <- function(libname, pkgname) {
7 |   tfhub <<- reticulate::import("tensorflow_hub", delay_load = list(
8 |     priority = 10,
9 |     environment = "r-tensorflow"
10 |   ))
11 | 
12 |   vctrs::s3_register("recipes::prep", "step_pretrained_text_embedding")
13 |   vctrs::s3_register("recipes::bake", "step_pretrained_text_embedding")
14 | }
15 | 
--------------------------------------------------------------------------------
/R/recipe.R:
--------------------------------------------------------------------------------
1 | #' Pretrained text-embeddings
2 | #'
3 | #' `step_pretrained_text_embedding` creates a *specification* of a
4 | #' recipe step that will transform text data into its numerical
5 | #' representation based on a pretrained model.
6 | #'
7 | #' @param recipe A recipe object. The step will be added to the
8 | #' sequence of operations for this recipe.
9 | #' @param ... One or more selector functions to choose variables.
10 | #' @param role Role for the created variables
11 | #' @param trained A logical to indicate if the quantities for
12 | #' preprocessing have been estimated.
13 | #' @param skip A logical. Should the step be skipped when the
14 | #' recipe is baked by [recipes::bake.recipe()]? While all operations are baked
15 | #' when [recipes::prep.recipe()] is run, some operations may not be able to be
16 | #' conducted on new data (e.g. processing the outcome variable(s)).
17 | #' Care should be taken when using `skip = TRUE` as it may affect
18 | #' the computations for subsequent operations
19 | #' @param handle the Module handle to resolve.
20 | #' @param args other arguments passed to [hub_load()].
21 | #' @param id A character string that is unique to this step to identify it.
22 | #'
23 | #' @examples
24 | #'
25 | #' \dontrun{
26 | #' library(tibble)
27 | #' library(recipes)
28 | #' df <- tibble(text = c('hi', "heello", "goodbye"), y = 0)
29 | #'
30 | #' rec <- recipe(y ~ text, df)
31 | #' rec <- rec %>% step_pretrained_text_embedding(
32 | #'  text,
33 | #'  handle = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1"
34 | #' )
35 | #'
36 | #' }
37 | #'
38 | #' @export
39 | step_pretrained_text_embedding <- function(
40 |   recipe, ...,
41 |   role = "predictor",
42 |   trained = FALSE,
43 |   handle,
44 |   args = NULL,
45 |   skip = FALSE,
46 |   id = recipes::rand_id("pretrained_text_embedding")
47 | ) {
48 | 
49 |   terms <- recipes::ellipse_check(...)
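  # `...` is only captured here as selector terms; the selectors are resolved
  # against the training data later, when prep() calls recipes::terms_select().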
50 | 
51 |   recipes::add_step(
52 |     recipe,
53 |     step_pretrained_text_embedding_new(
54 |       terms = terms,
55 |       trained = trained,
56 |       role = role,
57 |       vars = NULL,
58 |       handle = handle,
59 |       args = args,
60 |       skip = skip,
61 |       id = id
62 |     )
63 |   )
64 | }
65 | 
66 | step_pretrained_text_embedding_new <- function(terms, role, trained, vars,
67 |                                                handle, args, skip, id) {
68 |   recipes::step(
69 |     subclass = "pretrained_text_embedding",
70 |     terms = terms,
71 |     role = role,
72 |     trained = trained,
73 |     vars = vars,
74 |     handle = handle,
75 |     args = args,
76 |     skip = skip,
77 |     id = id
78 |   )
79 | }
80 | 
81 | #' Prep method for step_pretrained_text_embedding
82 | #'
83 | #' @param x object
84 | #' @param info variables state
85 | #' @param training whether or not it's training
86 | #'
87 | #' @inheritParams step_pretrained_text_embedding
88 | #'
89 | prep.step_pretrained_text_embedding <- function(x, training, info = NULL, ...) {
90 |   col_names <- recipes::terms_select(terms = x$terms, info = info)
91 | 
92 |   step_pretrained_text_embedding_new(
93 |     terms = x$terms,
94 |     trained = TRUE,
95 |     role = x$role,
96 |     vars = col_names,
97 |     handle = x$handle,
98 |     args = x$args,
99 |     skip = x$skip,
100 |     id = x$id
101 |   )
102 | }
103 | 
104 | get_embedding <- function(column, module) {
105 |   out <- module(as.character(column))
106 | 
107 |   if (!tensorflow::tf$executing_eagerly()) {
108 |     sess <- tensorflow::tf$compat$v1$Session()
109 |     sess$run(tensorflow::tf$compat$v1$global_variables_initializer())
110 |     sess$run(tensorflow::tf$compat$v1$tables_initializer())
111 |     out <- sess$run(out)
112 |     sess$close()
113 |   } else {
114 |     out <- as.matrix(out)
115 |   }
116 | 
117 |   out
118 | }
119 | 
120 | #' Bake method for step_pretrained_text_embedding
121 | #'
122 | #' @param object object
123 | #' @param new_data new data to apply transformations
124 | #'
125 | #' @inheritParams step_pretrained_text_embedding
126 | #'
127 | bake.step_pretrained_text_embedding <- function(object, new_data, ...) {
128 | 
129 |   module <- do.call(hub_load, append(list(handle = object$handle), object$args))
130 | 
131 |   embeddings <- lapply(object$vars, function(x) {
132 |     embedding <- get_embedding(new_data[[x]], module)
133 |     colnames(embedding) <- sprintf("%s_txt_emb_%04d", x, 1:ncol(embedding))
134 |     tibble::as_tibble(embedding)
135 |   })
136 | 
137 |   out <- do.call(cbind, append(list(new_data), embeddings))
138 | 
139 |   # remove text columns
140 |   for (i in object$vars) {
141 |     out[[i]] <- NULL
142 |   }
143 | 
144 |   out
145 | }
146 | 
147 | 
148 | 
--------------------------------------------------------------------------------
/R/reexports.R:
--------------------------------------------------------------------------------
1 | #' Pipe operator
2 | #'
3 | #' See \code{\link[magrittr]{\%>\%}} for more details.
4 | #' 5 | #' @name %>% 6 | #' @rdname pipe 7 | #' @keywords internal 8 | #' @export 9 | #' @importFrom magrittr %>% 10 | #' @usage lhs \%>\% rhs 11 | NULL 12 | 13 | 14 | #' @importFrom tensorflow install_tensorflow 15 | #' @export 16 | tensorflow::install_tensorflow 17 | 18 | #' @importFrom tensorflow tf 19 | #' @export 20 | tensorflow::tf 21 | 22 | #' @importFrom tensorflow shape 23 | #' @export 24 | tensorflow::shape 25 | 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tfhub 2 | 3 | 4 | [](https://github.com/rstudio/tfhub/) 5 | 6 | 7 | The tfhub package provides R wrappers to [TensorFlow Hub](https://www.tensorflow.org/hub). 8 | 9 | [TensorFlow Hub](https://www.tensorflow.org/hub) is a library for reusable machine learning modules. 10 | 11 | TensorFlow Hub is a library for the publication, discovery, and consumption of reusable parts of machine learning models. A module is a self-contained piece of a TensorFlow graph, along with its weights and assets, that can be reused across different tasks in a process known as transfer learning. Transfer learning can: 12 | 13 | * Train a model with a smaller dataset, 14 | * Improve generalization, and 15 | * Speed up training. 16 | 17 | ## Installation 18 | 19 | You can install the development version from [GitHub](https://github.com/) with: 20 | 21 | ``` r 22 | # install.packages("devtools") 23 | devtools::install_github("rstudio/tfhub") 24 | ``` 25 | 26 | After installing the tfhub package you need to install the TensorFlow Hub python 27 | module: 28 | 29 | ``` r 30 | library(tfhub) 31 | install_tfhub() 32 | ``` 33 | 34 | Go to [**the website**](https://tensorflow.rstudio.com/guide/tfhub/intro/) for more information. 35 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | Minor release, bugfixes and updates. 2 | 3 | Details in NEWS.md 4 | -------------------------------------------------------------------------------- /man/bake.step_pretrained_text_embedding.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe.R 3 | \name{bake.step_pretrained_text_embedding} 4 | \alias{bake.step_pretrained_text_embedding} 5 | \title{Bake method for step_pretrained_text_embedding} 6 | \usage{ 7 | bake.step_pretrained_text_embedding(object, new_data, ...) 
8 | } 9 | \arguments{ 10 | \item{object}{object} 11 | 12 | \item{new_data}{new data to apply transformations} 13 | 14 | \item{...}{One or more selector functions to choose variables.} 15 | } 16 | \description{ 17 | Bake method for step_pretrained_text_embedding 18 | } 19 | -------------------------------------------------------------------------------- /man/hub_image_embedding_column.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/columns.R 3 | \name{hub_image_embedding_column} 4 | \alias{hub_image_embedding_column} 5 | \title{Module to construct a dense 1-D representation from the pixels of images.} 6 | \usage{ 7 | hub_image_embedding_column(key, module_spec) 8 | } 9 | \arguments{ 10 | \item{key}{A string or [feature_column](https://tensorflow.rstudio.com/tfestimators/articles/feature_columns.html) 11 | identifying the text feature.} 12 | 13 | \item{module_spec}{A string handle or a _ModuleSpec identifying the module.} 14 | } 15 | \description{ 16 | Module to construct a dense 1-D representation from the pixels of images. 17 | } 18 | \details{ 19 | This feature column can be used on images, represented as float32 tensors of RGB pixel 20 | data in the range [0,1]. 21 | } 22 | -------------------------------------------------------------------------------- /man/hub_load.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/load.R 3 | \name{hub_load} 4 | \alias{hub_load} 5 | \title{Hub Load} 6 | \usage{ 7 | hub_load(handle, tags = NULL) 8 | } 9 | \arguments{ 10 | \item{handle}{(string) the Module handle to resolve.} 11 | 12 | \item{tags}{A set of strings specifying the graph variant to use, if loading 13 | from a v1 module.} 14 | } 15 | \description{ 16 | Loads a module from a handle. 17 | } 18 | \details{ 19 | Currently this method is fully supported only with Tensorflow 2.x and with 20 | modules created by calling `export_savedmodel`. The method works in 21 | both eager and graph modes. 22 | 23 | Depending on the type of handle used, the call may involve downloading a 24 | TensorFlow Hub module to a local cache location specified by the 25 | `TFHUB_CACHE_DIR` environment variable. If a copy of the module is already 26 | present in the TFHUB_CACHE_DIR, the download step is skipped. 27 | 28 | Currently, three types of module handles are supported: 1) Smart URL resolvers 29 | such as tfhub.dev, e.g.: https://tfhub.dev/google/nnlm-en-dim128/1. 2) A directory 30 | on a file system supported by Tensorflow containing module files. This may include 31 | a local directory (e.g. /usr/local/mymodule) or a Google Cloud Storage bucket 32 | (gs://mymodule). 3) A URL pointing to a TGZ archive of a module, e.g. 33 | https://example.com/mymodule.tar.gz. 
34 | } 35 | \examples{ 36 | 37 | \dontrun{ 38 | 39 | model <- hub_load('https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4') 40 | 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /man/hub_sparse_text_embedding_column.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/columns.R 3 | \name{hub_sparse_text_embedding_column} 4 | \alias{hub_sparse_text_embedding_column} 5 | \title{Module to construct dense representations from sparse text features.} 6 | \usage{ 7 | hub_sparse_text_embedding_column( 8 | key, 9 | module_spec, 10 | combiner, 11 | default_value, 12 | trainable = FALSE 13 | ) 14 | } 15 | \arguments{ 16 | \item{key}{A string or [feature_column](https://tensorflow.rstudio.com/tfestimators/articles/feature_columns.html) 17 | identifying the text feature.} 18 | 19 | \item{module_spec}{A string handle or a _ModuleSpec identifying the module.} 20 | 21 | \item{combiner}{a string specifying reducing op for embeddings in the same Example. 22 | Currently, 'mean', 'sqrtn', 'sum' are supported. Using `combiner = NULL` is 23 | undefined.} 24 | 25 | \item{default_value}{default value for Examples where the text feature is empty. 26 | Note, it's recommended to have default_value consistent OOV tokens, in case 27 | there was special handling of OOV in the text module. If `NULL`, the text 28 | feature is assumed be non-empty for each Example.} 29 | 30 | \item{trainable}{Whether or not the Module is trainable. `FALSE` by default, 31 | meaning the pre-trained weights are frozen. This is different from the ordinary 32 | `tf.feature_column.embedding_column()`, but that one is intended for training 33 | from scratch.} 34 | } 35 | \description{ 36 | The input to this feature column is a batch of multiple strings with 37 | arbitrary size, assuming the input is a SparseTensor. 38 | } 39 | \details{ 40 | This type of feature column is typically suited for modules that operate 41 | on pre-tokenized text to produce token level embeddings which are combined 42 | with the combiner into a text embedding. The combiner always treats the tokens 43 | as a bag of words rather than a sequence. 44 | 45 | The output (i.e., transformed input layer) is a DenseTensor, with 46 | shape [batch_size, num_embedding_dim]. 47 | } 48 | -------------------------------------------------------------------------------- /man/hub_text_embedding_column.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/columns.R 3 | \name{hub_text_embedding_column} 4 | \alias{hub_text_embedding_column} 5 | \title{Module to construct a dense representation from a text feature.} 6 | \usage{ 7 | hub_text_embedding_column(key, module_spec, trainable = FALSE) 8 | } 9 | \arguments{ 10 | \item{key}{A string or [feature_column](https://tensorflow.rstudio.com/tfestimators/articles/feature_columns.html) 11 | identifying the text feature.} 12 | 13 | \item{module_spec}{A string handle or a _ModuleSpec identifying the module.} 14 | 15 | \item{trainable}{Whether or not the Module is trainable. `FALSE` by default, 16 | meaning the pre-trained weights are frozen. 
This is different from the ordinary 17 | `tf.feature_column.embedding_column()`, but that one is intended for training 18 | from scratch.} 19 | } 20 | \description{ 21 | This feature column can be used on an input feature whose values are strings of 22 | arbitrary size. 23 | } 24 | -------------------------------------------------------------------------------- /man/install_tfhub.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/install.R 3 | \name{install_tfhub} 4 | \alias{install_tfhub} 5 | \title{Install TensorFlow Hub} 6 | \usage{ 7 | install_tfhub(version = "release", ..., restart_session = TRUE) 8 | } 9 | \arguments{ 10 | \item{version}{version of TensorFlow Hub to be installed.} 11 | 12 | \item{...}{other arguments passed to [reticulate::py_install()].} 13 | 14 | \item{restart_session}{Restart R session after installing (note this will 15 | only occur within RStudio).} 16 | } 17 | \description{ 18 | This function is used to install the TensorFlow Hub python module. 19 | } 20 | -------------------------------------------------------------------------------- /man/layer_hub.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/layer.R 3 | \name{layer_hub} 4 | \alias{layer_hub} 5 | \title{Hub Layer} 6 | \usage{ 7 | layer_hub(object, handle, trainable = FALSE, arguments = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{object}{Model or layer object} 11 | 12 | \item{handle}{a callable object (subject to the conventions above), or a string 13 | for which `hub_load()` returns such a callable. A string is required to save 14 | the Keras config of this Layer.} 15 | 16 | \item{trainable}{Boolean controlling whether this layer is trainable.} 17 | 18 | \item{arguments}{optionally, a list with additional keyword arguments passed to 19 | the callable. These must be JSON-serializable to save the Keras config of 20 | this layer.} 21 | 22 | \item{...}{Other arguments that are passed to the TensorFlow Hub module.} 23 | } 24 | \description{ 25 | Wraps a Hub module (or a similar callable) for TF2 as a Keras Layer. 26 | } 27 | \details{ 28 | This layer wraps a callable object for use as a Keras layer. The callable 29 | object can be passed directly, or be specified by a string with a handle 30 | that gets passed to `hub_load()`. 31 | 32 | The callable object is expected to follow the conventions detailed below. 33 | (These are met by TF2-compatible modules loaded from TensorFlow Hub.) 34 | 35 | The callable is invoked with a single positional argument set to one tensor or 36 | a list of tensors containing the inputs to the layer. If the callable accepts 37 | a training argument, a boolean is passed for it. It is `TRUE` if this layer 38 | is marked trainable and called for training. 39 | 40 | If present, the following attributes of callable are understood to have special 41 | meanings: variables: a list of all tf.Variable objects that the callable depends on. 42 | trainable_variables: those elements of variables that are reported as trainable 43 | variables of this Keras Layer when the layer is trainable. regularization_losses: 44 | a list of callables to be added as losses of this Keras Layer when the layer is 45 | trainable. Each one must accept zero arguments and return a scalar tensor. 
46 | } 47 | \examples{ 48 | 49 | \dontrun{ 50 | 51 | library(keras) 52 | 53 | model <- keras_model_sequential() \%>\% 54 | layer_hub( 55 | handle = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4", 56 | input_shape = c(224, 224, 3) 57 | ) \%>\% 58 | layer_dense(1) 59 | 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /man/pipe.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/reexports.R 3 | \name{\%>\%} 4 | \alias{\%>\%} 5 | \title{Pipe operator} 6 | \usage{ 7 | lhs \%>\% rhs 8 | } 9 | \description{ 10 | See \code{\link[magrittr]{\%>\%}} for more details. 11 | } 12 | \keyword{internal} 13 | -------------------------------------------------------------------------------- /man/prep.step_pretrained_text_embedding.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe.R 3 | \name{prep.step_pretrained_text_embedding} 4 | \alias{prep.step_pretrained_text_embedding} 5 | \title{Prep method for step_pretrained_text_embedding} 6 | \usage{ 7 | prep.step_pretrained_text_embedding(x, training, info = NULL, ...) 8 | } 9 | \arguments{ 10 | \item{x}{object} 11 | 12 | \item{training}{wether or not it's training} 13 | 14 | \item{info}{variables state} 15 | 16 | \item{...}{One or more selector functions to choose variables.} 17 | } 18 | \description{ 19 | Prep method for step_pretrained_text_embedding 20 | } 21 | -------------------------------------------------------------------------------- /man/reexports.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/reexports.R 3 | \docType{import} 4 | \name{reexports} 5 | \alias{reexports} 6 | \alias{install_tensorflow} 7 | \alias{tf} 8 | \alias{shape} 9 | \title{Objects exported from other packages} 10 | \keyword{internal} 11 | \description{ 12 | These objects are imported from other packages. Follow the links 13 | below to see their documentation. 14 | 15 | \describe{ 16 | \item{tensorflow}{\code{\link[tensorflow]{install_tensorflow}}, \code{\link[tensorflow]{shape}}, \code{\link[tensorflow]{tf}}} 17 | }} 18 | 19 | -------------------------------------------------------------------------------- /man/step_pretrained_text_embedding.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/recipe.R 3 | \name{step_pretrained_text_embedding} 4 | \alias{step_pretrained_text_embedding} 5 | \title{Pretrained text-embeddings} 6 | \usage{ 7 | step_pretrained_text_embedding( 8 | recipe, 9 | ..., 10 | role = "predictor", 11 | trained = FALSE, 12 | handle, 13 | args = NULL, 14 | skip = FALSE, 15 | id = recipes::rand_id("pretrained_text_embedding") 16 | ) 17 | } 18 | \arguments{ 19 | \item{recipe}{A recipe object. 
The step will be added to the 20 | sequence of operations for this recipe.} 21 | 22 | \item{...}{One or more selector functions to choose variables.} 23 | 24 | \item{role}{Role for the created variables} 25 | 26 | \item{trained}{A logical to indicate if the quantities for 27 | preprocessing have been estimated.} 28 | 29 | \item{handle}{the Module handle to resolve.} 30 | 31 | \item{args}{other arguments passed to [hub_load()].} 32 | 33 | \item{skip}{A logical. Should the step be skipped when the 34 | recipe is baked by [recipes::bake.recipe()]? While all operations are baked 35 | when [recipes::prep.recipe()] is run, some operations may not be able to be 36 | conducted on new data (e.g. processing the outcome variable(s)). 37 | Care should be taken when using `skip = TRUE` as it may affect 38 | the computations for subsequent operations} 39 | 40 | \item{id}{A character string that is unique to this step to identify it.} 41 | } 42 | \description{ 43 | `step_pretrained_text_embedding` creates a *specification* of a 44 | recipe step that will transform text data into its numerical 45 | transformation based on a pretrained model. 46 | } 47 | \examples{ 48 | 49 | \dontrun{ 50 | library(tibble) 51 | library(recipes) 52 | df <- tibble(text = c('hi', "heello", "goodbye"), y = 0) 53 | 54 | rec <- recipe(y ~ text, df) 55 | rec <- rec \%>\% step_pretrained_text_embedding( 56 | text, 57 | handle = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1" 58 | ) 59 | 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | 2 | library(testthat) 3 | library(tensorflow) 4 | library(tfhub) 5 | 6 | test_check("tfhub") 7 | 8 | -------------------------------------------------------------------------------- /tests/testthat/test-columns.R: -------------------------------------------------------------------------------- 1 | context("columns") 2 | 3 | source("utils.R") 4 | 5 | test_succeeds("hub_text_embedding_column can load a module", { 6 | 7 | column <- hub_text_embedding_column("sentence", 8 | "https://tfhub.dev/google/nnlm-en-dim128/1") 9 | expect_true(reticulate::py_has_attr(column, "name")) 10 | }) 11 | 12 | 13 | -------------------------------------------------------------------------------- /tests/testthat/test-layer-hub.R: -------------------------------------------------------------------------------- 1 | source("utils.R") 2 | 3 | test_succeeds("layer_hub works with sequential models", { 4 | 5 | library(keras) 6 | 7 | model <- keras_model_sequential() %>% 8 | layer_hub( 9 | handle = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4", 10 | input_shape = c(224, 224, 3) 11 | ) %>% 12 | layer_dense(1) 13 | 14 | a <- tf$constant(array(0, dim = as.integer(c(1, 224, 224, 3))), dtype = "float32") 15 | 16 | res <- as.numeric(model(a)) 17 | 18 | expect_is(res, "numeric") 19 | }) 20 | 21 | test_succeeds("layer_hub works with functional API", { 22 | 23 | input <- layer_input(shape = c(224, 224, 3)) 24 | 25 | output <- input %>% 26 | layer_hub( 27 | handle = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4" 28 | ) %>% 29 | layer_dense(1) 30 | 31 | model <- keras_model(input, output) 32 | 33 | a <- tf$constant(array(0, dim = c(1, 224, 224, 3)), dtype = "float32") 34 | 35 | res <- as.numeric(model(a)) 36 | 37 | expect_is(res, "numeric") 38 | }) 39 | 40 | test_succeeds("can initialiaze the layer_hub", { 41 | 42 | 
features <- layer_hub( 43 | handle = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4" 44 | ) 45 | 46 | input <- layer_input(shape = c(224, 224, 3)) 47 | 48 | output <- input %>% 49 | features() %>% 50 | layer_dense(1) 51 | 52 | model <- keras_model(input, output) 53 | 54 | a <- tf$constant(array(0, dim = c(1, 224, 224, 3)), dtype = "float32") 55 | 56 | res <- as.numeric(model(a)) 57 | 58 | expect_is(res, "numeric") 59 | }) 60 | 61 | 62 | -------------------------------------------------------------------------------- /tests/testthat/test-load.R: -------------------------------------------------------------------------------- 1 | context("load") 2 | 3 | test_succeeds("Can load module from URL", { 4 | module <- hub_load("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4") 5 | expect_s3_class(module, "tensorflow.python.saved_model.load._UserObject") 6 | }) 7 | 8 | test_succeeds("Can load module from file path", { 9 | 10 | skip("Currently skipping due to bug exporting models on Windows") 11 | 12 | library(keras) 13 | 14 | input <- layer_input(shape = shape(1)) 15 | input2 <- layer_input(shape = shape(1)) 16 | output <- layer_add(list(input, input2)) 17 | 18 | model <- keras_model(list(input, input2), output) 19 | 20 | tmp <- tempfile() 21 | dir.create(tmp) 22 | 23 | export_savedmodel(model, tmp, remove_learning_phase = FALSE) 24 | 25 | module <- hub_load(tmp) 26 | expect_s3_class(module, "tensorflow.python.saved_model.load._UserObject") 27 | 28 | expect_equal( 29 | as.numeric(module(list(tf$ones(shape = c(1,1)), tf$ones(shape = c(1,1))))), 30 | 2 31 | ) 32 | }) 33 | 34 | test_succeeds("hub_load correctly uses the env var", { 35 | 36 | tmp <- tempfile() 37 | 38 | x <- callr::r( 39 | function() { 40 | tfhub::hub_load('https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4') 41 | }, 42 | env = c(TFHUB_CACHE_DIR = tmp) 43 | ) 44 | 45 | expect_length(list.files(tmp), 2) 46 | }) 47 | 48 | 49 | -------------------------------------------------------------------------------- /tests/testthat/test-recipe.R: -------------------------------------------------------------------------------- 1 | context("recipe") 2 | 3 | test_succeeds("Can use with recipes", { 4 | library(tibble) 5 | library(recipes) 6 | df <- tibble(text = c('hi', "heello", "goodbye"), y = 0) 7 | 8 | rec <- recipe(y ~ text, df) 9 | rec <- rec %>% step_pretrained_text_embedding( 10 | text, 11 | handle = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim-with-oov/1" 12 | ) 13 | 14 | rec <- prep(rec) 15 | 16 | x <- bake(rec, df) 17 | 18 | expect_s3_class(x, "data.frame") 19 | }) 20 | 21 | -------------------------------------------------------------------------------- /tests/testthat/utils.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | skip_if_no_tfhub <- function(required_version = NULL) { 4 | if (!reticulate::py_module_available("tensorflow_hub")) 5 | skip("TensorFlow Hub not available for testing") 6 | } 7 | 8 | skip_if_no_tf_version <- function(required_version) { 9 | if (!reticulate::py_module_available("tensorflow")) 10 | skip("TensorFlow is not available.") 11 | 12 | if (tensorflow::tf_version() < required_version) 13 | skip(paste0("Needs TF version >= ", required_version)) 14 | } 15 | 16 | test_succeeds <- function(desc, expr) { 17 | test_that(desc, { 18 | skip_if_no_tfhub() 19 | skip_if_no_tf_version("2.0") 20 | expect_error(force(expr), NA) 21 | }) 22 | } 23 | 
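# Illustrative usage sketch (not an additional test): test_succeeds() wraps
# test_that() so a test is skipped cleanly whenever the 'tensorflow_hub' Python
# module or TensorFlow >= 2.0 is unavailable, e.g.
#
#   test_succeeds("hub_load() resolves a tfhub.dev handle", {
#     module <- hub_load("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4")
#     expect_s3_class(module, "tensorflow.python.saved_model.load._UserObject")
#   })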
-------------------------------------------------------------------------------- /tfhub.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageRoxygenize: rd,collate,namespace 22 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | hub-with-keras_files 3 | flower_photos 4 | hub-with-keras.R 5 | .build.timestamp 6 | -------------------------------------------------------------------------------- /vignettes/archived/feature_column.R: -------------------------------------------------------------------------------- 1 | #' In this example we will use the PetFinder dataset to demonstrate the 2 | #' feature_spec functionality with TensorFlow Hub. 3 | #' 4 | #' Currently, we need TensorFlow 2.0 nightly and disable eager execution 5 | #' in order for this example to work. 6 | #' 7 | #' Waiting for https://github.com/tensorflow/hub/issues/333 8 | #' 9 | #' 10 | #' 11 | 12 | # Notes about why this was archived: https://github.com/rstudio/tfdatasets/issues/81 13 | # 14 | # Snippets to download the data if we want to restore this example: 15 | # 16 | # pip install kaggle 17 | # login to kaggle.com, download API 'kaggle.json' file 18 | # mkdir ~/.kaggle 19 | # mv ~/kaggle.json ~/.kaggle/ 20 | # chmod 600 ~/.kaggle/kaggle.json 21 | # kaggle competitions download -c petfinder-adoption-prediction 22 | 23 | # unzip("petfinder-adoption-prediction.zip", exdir = "petfinder") 24 | 25 | 26 | library(keras) 27 | library(tfhub) 28 | library(tfdatasets) 29 | library(readr) 30 | library(dplyr) 31 | 32 | tf$compat$v1$disable_eager_execution() 33 | 34 | # Read data --------------------------------------------------------------- 35 | 36 | dataset <- read_csv("petfinder/train/train.csv") %>% 37 | filter(PhotoAmt > 0) %>% 38 | mutate(img_path = path.expand(paste0("petfinder/train_images/", PetID, "-1.jpg"))) %>% 39 | mutate_at(vars(Breed1:Health, State), as.character) %>% 40 | sample_n(size = nrow(.)) # shuffle 41 | 42 | dataset_tf <- dataset %>% 43 | tensor_slices_dataset() %>% 44 | dataset_map(function(x) { 45 | img <- tf$io$read_file(filename = x$img_path) %>% 46 | tf$image$decode_jpeg(channels = 3L) %>% 47 | tf$image$resize(size = c(224L, 224L)) 48 | x[["img"]] <- img/255 49 | x 50 | }) 51 | 52 | dataset_test <- dataset_tf %>% 53 | dataset_take(nrow(dataset)*0.2) %>% 54 | dataset_batch(512) 55 | 56 | dataset_train <- dataset_tf %>% 57 | dataset_skip(nrow(dataset)*0.2) %>% 58 | dataset_batch(32) 59 | 60 | # Build the feature spec -------------------------------------------------- 61 | 62 | spec <- dataset_train %>% 63 | feature_spec(AdoptionSpeed ~ .) 
%>% 64 | step_text_embedding_column( 65 | Description, 66 | module_spec = "https://tfhub.dev/google/universal-sentence-encoder/2" 67 | ) %>% 68 | step_image_embedding_column( 69 | img, 70 | module_spec = "https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/3" 71 | ) %>% 72 | # step_pretrained_text_embedding( 73 | # Description, 74 | # handle = "https://tfhub.dev/google/universal-sentence-encoder/2" 75 | # ) %>% 76 | # step_pretrained_text_embedding( 77 | # img, 78 | # handle = "https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/3" 79 | # ) %>% 80 | step_numeric_column(Age, Fee, Quantity, normalizer_fn = scaler_standard()) %>% 81 | step_categorical_column_with_vocabulary_list( 82 | has_type("string"), -Description, -RescuerID, -img_path, -PetID, -Name 83 | ) %>% 84 | step_embedding_column(Breed1:Health, State) 85 | 86 | spec <- fit(spec) 87 | 88 | # Build the model --------------------------------------------------------- 89 | 90 | inputs <- layer_input_from_dataset(dataset_train) %>% reticulate::py_to_r() 91 | inputs <- inputs[-which(names(inputs) == "AdoptionSpeed")] 92 | 93 | output <- inputs %>% 94 | layer_dense_features(spec$dense_features()) %>% 95 | layer_dropout(0.25) %>% 96 | layer_dense(units = 32, activation = "relu") %>% 97 | layer_dense(units = 5, activation = "softmax") 98 | 99 | model <- keras_model(inputs, output) 100 | 101 | model %>% 102 | compile( 103 | loss = "sparse_categorical_crossentropy", 104 | optimizer = "adam", 105 | metrics = "accuracy" 106 | ) 107 | 108 | # Fit the model ----------------------------------------------------------- 109 | 110 | sess <- k_get_session() 111 | sess$run(tf$compat$v1$initialize_all_variables()) 112 | sess$run(tf$compat$v1$initialize_all_tables()) 113 | 114 | model %>% 115 | fit( 116 | x = dataset_use_spec(dataset_train, spec), 117 | validation_data = dataset_use_spec(dataset_test, spec), 118 | epochs = 5 119 | ) 120 | 121 | 122 | -------------------------------------------------------------------------------- /vignettes/archived/feature_column.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: feature_column 3 | type: docs 4 | repo: https://github.com/rstudio/tfhub 5 | menu: 6 | main: 7 | parent: tfhub-examples 8 | --- 9 | 10 |