├── .Rbuildignore ├── .github ├── ISSUE_TEMPLATE │ └── installation-issue.md └── workflows │ ├── R-CMD-check.yaml │ ├── awaiting-response-close.yml │ └── awaiting-response-remove-label.yml ├── .gitignore ├── DESCRIPTION ├── LICENSE ├── NAMESPACE ├── NEWS.md ├── R ├── arguments.R ├── compat.R ├── deprecated.R ├── eager.R ├── estimator-generics.R ├── extract.R ├── flags.R ├── generics.R ├── help.R ├── install.R ├── modules.R ├── package.R ├── probability.R ├── reexports.R ├── save.R ├── seed.R ├── shape.R ├── tensorboard.R └── utils.R ├── README.md ├── cran-comments.md ├── docs └── index.html ├── man ├── all_dims.Rd ├── as_tensor.Rd ├── evaluate.Rd ├── export_savedmodel.Rd ├── figures │ ├── lifecycle-archived.svg │ ├── lifecycle-defunct.svg │ ├── lifecycle-deprecated.svg │ ├── lifecycle-experimental.svg │ ├── lifecycle-maturing.svg │ ├── lifecycle-questioning.svg │ ├── lifecycle-soft-deprecated.svg │ ├── lifecycle-stable.svg │ └── lifecycle-superseded.svg ├── install_tensorflow.Rd ├── parse_arguments.Rd ├── parse_flags.Rd ├── reexports.Rd ├── set_random_seed.Rd ├── shape.Rd ├── sub-.tensorflow.tensor.Rd ├── tensorboard.Rd ├── tensorflow.Rd ├── tf.Rd ├── tf_config.Rd ├── tf_extract_opts.Rd ├── tf_function.Rd ├── tf_gpu_configured.Rd ├── tf_probability.Rd ├── train.Rd ├── train_and_evaluate.Rd ├── use_compat.Rd ├── use_session_with_seed.Rd └── view_savedmodel.Rd ├── pkgdown └── _pkgdown.yml ├── tensorflow.Rproj └── tests ├── testthat.R └── testthat ├── .gitignore ├── helper-utils.R ├── setup.R ├── test-arguments.R ├── test-as_tensor.R ├── test-data-structures.R ├── test-examples.R ├── test-export-savedmodel.R ├── test-extract-syntax.R ├── test-generic-methods.R ├── test-seed.R ├── test-shape.R └── test-types.R /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^docs$ 4 | ^issues$ 5 | ^pkgdown$ 6 | ^inst/examples/mnist/MNIST-data$ 7 | ^tests/testthat/MNIST-data$ 8 | README.md 9 | 
README.html 10 | .Renviron 11 | LICENSE 12 | .travis.yml 13 | travis_install.sh 14 | run_tests.sh 15 | testenv 16 | MNIST-data 17 | ^appveyor\.yml$ 18 | ^CRAN-RELEASE$ 19 | ^\.github.*$ 20 | ^cran-comments\.md$ 21 | ^CRAN-SUBMISSION$ 22 | ^revdep$ 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/installation-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Installation Issue 3 | about: Trouble installing TensorFlow 4 | title: Installation Issue 5 | labels: 'Installation' 6 | assignees: '' 7 | --- 8 | 9 | Many installation issues are resolved by running the following in a **fresh R session** (you can restart R in RStudio with Ctrl+Shift+F10): 10 | ```R 11 | # install the development version of packages, in case the 12 | # issue is already fixed but not on CRAN yet. 13 | install.packages("pak") 14 | pak::pak(sprintf("rstudio/%s", c("reticulate", "tensorflow", "keras"))) 15 | if (is.null(reticulate::virtualenv_starter())) 16 | reticulate::install_python() 17 | tensorflow::install_tensorflow() 18 | ``` 19 | 20 | Test to see if installation was successful. 21 | ```R 22 | tensorflow::as_tensor("Hello World") 23 | ``` 24 | 25 | If the above snippet succeeded and you saw something like `tf.Tensor(b'Hello World', shape=(), dtype=string)`, then :tada:, you've successfully installed TensorFlow. 26 | 27 | If the above installation failed, please gather some diagnostic info: 28 | ```R 29 | reticulate::py_config() 30 | tensorflow::tf_config() 31 | reticulate::import("tensorflow") 32 | reticulate::py_last_error() 33 | sessionInfo() 34 | ``` 35 | 36 | Please copy and paste the FULL OUTPUT of running all three snippets, and be sure to enclose the output lines with three backticks (```) for monospace formatting.
37 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | workflow_dispatch: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | schedule: 8 | - cron: '21 3 * * Fri' 9 | 10 | name: R-CMD-check 11 | 12 | defaults: 13 | run: 14 | shell: Rscript {0} 15 | 16 | jobs: 17 | R-CMD-check: 18 | name: ${{ matrix.os }}, tf-${{ matrix.tf }}, R-${{ matrix.r}} 19 | timeout-minutes: 30 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | include: 24 | 25 | ## the happy path - default everything 26 | - {os: 'ubuntu-latest' , tf: 'default', r: 'release'} 27 | - {os: 'windows-latest', tf: 'default', r: 'release'} 28 | - {os: 'macOS-latest' , tf: 'default', r: 'release'} 29 | 30 | ## old R versions 31 | - {os: 'windows-latest', tf: 'default', r: 'oldrel'} 32 | - {os: 'macOS-latest' , tf: 'default', r: 'oldrel'} 33 | - {os: 'ubuntu-latest' , tf: 'default', r: 'oldrel'} 34 | - {os: 'ubuntu-latest' , tf: 'default', r: 'oldrel-1'} 35 | - {os: 'ubuntu-latest' , tf: 'default', r: 'oldrel-2'} 36 | - {os: 'ubuntu-latest' , tf: 'default', r: 'oldrel-3'} 37 | 38 | ## release keras/tf version (if different from 'default') 39 | # - {os: 'ubuntu-latest' , tf: 'release', r: 'release'} 40 | # - {os: 'windows-latest', tf: 'release', r: 'release'} 41 | # - {os: 'macOS-latest' , tf: 'release', r: 'release'} 42 | 43 | ## old keras/tf versions 44 | # - {os: 'ubuntu-latest', tf: '2.18', r: 'release'} 45 | - {os: 'ubuntu-latest', tf: '2.17', r: 'release'} 46 | - {os: 'ubuntu-latest', tf: '2.16', r: 'release'} 47 | - {os: 'ubuntu-latest', tf: '2.15', r: 'release'} 48 | - {os: 'ubuntu-latest', tf: '2.14', r: 'release'} 49 | - {os: 'ubuntu-latest', tf: '2.13', r: 'release'} 50 | - {os: 'ubuntu-latest', tf: '2.12', r: 'release'} 51 | 52 | # these are allowed to fail 53 | # - {os: 'ubuntu-latest', tf: '2.14.0rc1', r: 'release'} 54 | # - {os: 
'ubuntu-20.04', tf: 'default', r: 'devel'} 55 | # - {os: 'ubuntu-20.04', tf: 'nightly' , r: 'release'} 56 | 57 | runs-on: ${{ matrix.os }} 58 | continue-on-error: ${{ matrix.tf == 'nightly' || contains(matrix.tf, 'rc') || matrix.r == 'devel' }} 59 | env: 60 | R_KEEP_PKG_SOURCE: yes 61 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 62 | # R_REMOTES_NO_ERRORS_FROM_WARNINGS: 'true' 63 | # R_COMPILE_AND_INSTALL_PACKAGES: 'never' 64 | 65 | steps: 66 | 67 | - uses: actions/checkout@v3 68 | 69 | - uses: r-lib/actions/setup-pandoc@v2 70 | 71 | - uses: r-lib/actions/setup-r@v2 72 | id: setup-r 73 | with: 74 | r-version: ${{ matrix.r }} 75 | use-public-rspm: true 76 | Ncpus: '2L' 77 | 78 | # - name: Get Date 79 | # id: get-date 80 | # shell: bash 81 | # run: | 82 | # echo "::set-output name=year-week::$(date -u "+%Y-%U")" 83 | # echo "::set-output name=date::$(date -u "+%F")" 84 | # 85 | # - name: Restore R package cache 86 | # uses: actions/cache@v2 87 | # id: r-package-cache 88 | # with: 89 | # path: ${{ env.R_LIBS_USER }} 90 | # key: ${{ matrix.os }}-${{ steps.setup-r.outputs.installed-r-version }}-${{ steps.get-date.outputs.year-week }}-2 91 | 92 | - uses: r-lib/actions/setup-r-dependencies@v2 93 | with: 94 | extra-packages: any::rcmdcheck local::. 
rstudio/reticulate 95 | cache-version: 4 96 | upgrade: 'TRUE' 97 | 98 | - name: Install TensorFlow 99 | run: | 100 | print(sessionInfo()) 101 | print(Sys.info()) 102 | version <- '${{ matrix.tf }}' 103 | if (version != "default") 104 | tensorflow::install_tensorflow(version = '${{ matrix.tf }}') 105 | tensorflow::tf_config() 106 | 107 | - uses: r-lib/actions/check-r-package@v2 108 | with: 109 | upload-snapshots: true 110 | 111 | -------------------------------------------------------------------------------- /.github/workflows/awaiting-response-close.yml: -------------------------------------------------------------------------------- 1 | name: Close 'awaiting response' labeled issues 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: '0 6 * * *' 7 | 8 | jobs: 9 | close-issues-with-no-response: 10 | 11 | runs-on: ubuntu-latest 12 | permissions: 13 | issues: write 14 | pull-requests: write 15 | 16 | steps: 17 | - uses: actions/stale@v4 18 | with: 19 | repo-token: ${{ secrets.GITHUB_TOKEN }} 20 | 21 | stale-issue-label: 'awaiting response' 22 | stale-pr-label: 'awaiting response' 23 | days-before-stale: -1 24 | 25 | days-before-close: 30 26 | close-issue-message: > 27 | Automatically closed because there has not been a response for 30 days. 28 | When you're ready to work on this further, please comment here and the 29 | issue will automatically reopen. 
30 | -------------------------------------------------------------------------------- /.github/workflows/awaiting-response-remove-label.yml: -------------------------------------------------------------------------------- 1 | name: Remove 'awaiting response' label 2 | 3 | on: 4 | issue_comment: 5 | types: [created, edited] 6 | 7 | jobs: 8 | remove-awaiting-response-label: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Remove 'awaiting response' label and maybe reopen 12 | if: contains(github.event.issue.labels.*.name, 'awaiting response') 13 | run: | 14 | gh issue edit $ISSUE --remove-label 'awaiting response' 15 | gh issue reopen $ISSUE 16 | env: 17 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 18 | ISSUE: ${{ github.event.issue.html_url }} 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | **/config.log 3 | **/config.status 4 | .Rproj.user 5 | .Rhistory 6 | .RData 7 | .Ruserdata 8 | src/*.o 9 | src/*.so 10 | src/*.dll 11 | README.html 12 | .Renviron 13 | MNIST-data 14 | issues/ 15 | logs 16 | revdep 17 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: tensorflow 2 | Type: Package 3 | Title: R Interface to 'TensorFlow' 4 | Version: 2.16.0.9000 5 | Authors@R: c( 6 | person("JJ", "Allaire", role = c("aut", "cph")), 7 | person("Tomasz", "Kalinowski", role = c("ctb", "cph", "cre"), 8 | email = "tomasz.kalinowski@posit.co"), 9 | person("Daniel", "Falbel", role = c("ctb", "cph"), email = "daniel@posit.co"), 10 | person("Dirk", "Eddelbuettel", role = c("ctb", "cph"), 11 | email = "edd@debian.org"), 12 | person("Yuan", "Tang", role = c("aut", "cph"), 13 | email = "terrytangyuan@gmail.com", 14 | comment = c(ORCID = "0000-0001-5243-233X")), 15 | person("Nick", "Golding", role = 
c("ctb", "cph"), 16 | email = "nick.golding.research@gmail.com"), 17 | person("Google Inc.", role = c("ctb", "cph"), 18 | comment = "Examples and Tutorials"), 19 | person("Posit, PBC", role = c("cph", "fnd")) 20 | ) 21 | Description: Interface to 'TensorFlow' <https://www.tensorflow.org/>, 22 | an open source software library for numerical computation using data 23 | flow graphs. Nodes in the graph represent mathematical operations, 24 | while the graph edges represent the multidimensional data arrays 25 | (tensors) communicated between them. The flexible architecture allows 26 | you to deploy computation to one or more 'CPUs' or 'GPUs' in a desktop, 27 | server, or mobile device with a single 'API'. 'TensorFlow' was originally 28 | developed by researchers and engineers working on the Google Brain Team 29 | within Google's Machine Intelligence research organization for the 30 | purposes of conducting machine learning and deep neural networks research, 31 | but the system is general enough to be applicable in a wide variety 32 | of other domains as well. 33 | License: Apache License 2.0 34 | URL: https://github.com/rstudio/tensorflow 35 | BugReports: https://github.com/rstudio/tensorflow/issues 36 | SystemRequirements: TensorFlow (https://www.tensorflow.org/) 37 | Encoding: UTF-8 38 | Depends: R (>= 3.6) 39 | Imports: 40 | config, 41 | processx, 42 | reticulate (>= 1.41.0), 43 | tfruns (>= 1.0), 44 | utils, 45 | yaml, 46 | grDevices, 47 | tfautograph (>= 0.3.1), 48 | rstudioapi (>= 0.7), 49 | lifecycle 50 | Roxygen: list(markdown = TRUE) 51 | Suggests: 52 | testthat (>= 2.1.0), 53 | keras3, 54 | pillar, 55 | withr, 56 | callr 57 | RoxygenNote: 7.3.2 58 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2015 The TensorFlow Authors. All rights reserved. 2 | Copyright 2016 RStudio, Inc. All rights reserved.
3 | 4 | Apache License 5 | Version 2.0, January 2004 6 | http://www.apache.org/licenses/ 7 | 8 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 9 | 10 | 1. Definitions. 11 | 12 | "License" shall mean the terms and conditions for use, reproduction, 13 | and distribution as defined by Sections 1 through 9 of this document. 14 | 15 | "Licensor" shall mean the copyright owner or entity authorized by 16 | the copyright owner that is granting the License. 17 | 18 | "Legal Entity" shall mean the union of the acting entity and all 19 | other entities that control, are controlled by, or are under common 20 | control with that entity. For the purposes of this definition, 21 | "control" means (i) the power, direct or indirect, to cause the 22 | direction or management of such entity, whether by contract or 23 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 24 | outstanding shares, or (iii) beneficial ownership of such entity. 25 | 26 | "You" (or "Your") shall mean an individual or Legal Entity 27 | exercising permissions granted by this License. 28 | 29 | "Source" form shall mean the preferred form for making modifications, 30 | including but not limited to software source code, documentation 31 | source, and configuration files. 32 | 33 | "Object" form shall mean any form resulting from mechanical 34 | transformation or translation of a Source form, including but 35 | not limited to compiled object code, generated documentation, 36 | and conversions to other media types. 37 | 38 | "Work" shall mean the work of authorship, whether in Source or 39 | Object form, made available under the License, as indicated by a 40 | copyright notice that is included in or attached to the work 41 | (an example is provided in the Appendix below). 
42 | 43 | "Derivative Works" shall mean any work, whether in Source or Object 44 | form, that is based on (or derived from) the Work and for which the 45 | editorial revisions, annotations, elaborations, or other modifications 46 | represent, as a whole, an original work of authorship. For the purposes 47 | of this License, Derivative Works shall not include works that remain 48 | separable from, or merely link (or bind by name) to the interfaces of, 49 | the Work and Derivative Works thereof. 50 | 51 | "Contribution" shall mean any work of authorship, including 52 | the original version of the Work and any modifications or additions 53 | to that Work or Derivative Works thereof, that is intentionally 54 | submitted to Licensor for inclusion in the Work by the copyright owner 55 | or by an individual or Legal Entity authorized to submit on behalf of 56 | the copyright owner. For the purposes of this definition, "submitted" 57 | means any form of electronic, verbal, or written communication sent 58 | to the Licensor or its representatives, including but not limited to 59 | communication on electronic mailing lists, source code control systems, 60 | and issue tracking systems that are managed by, or on behalf of, the 61 | Licensor for the purpose of discussing and improving the Work, but 62 | excluding communication that is conspicuously marked or otherwise 63 | designated in writing by the copyright owner as "Not a Contribution." 64 | 65 | "Contributor" shall mean Licensor and any individual or Legal Entity 66 | on behalf of whom a Contribution has been received by Licensor and 67 | subsequently incorporated within the Work. 68 | 69 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 70 | this License, each Contributor hereby grants to You a perpetual, 71 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 72 | copyright license to reproduce, prepare Derivative Works of, 73 | publicly display, publicly perform, sublicense, and distribute the 74 | Work and such Derivative Works in Source or Object form. 75 | 76 | 3. Grant of Patent License. Subject to the terms and conditions of 77 | this License, each Contributor hereby grants to You a perpetual, 78 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 79 | (except as stated in this section) patent license to make, have made, 80 | use, offer to sell, sell, import, and otherwise transfer the Work, 81 | where such license applies only to those patent claims licensable 82 | by such Contributor that are necessarily infringed by their 83 | Contribution(s) alone or by combination of their Contribution(s) 84 | with the Work to which such Contribution(s) was submitted. If You 85 | institute patent litigation against any entity (including a 86 | cross-claim or counterclaim in a lawsuit) alleging that the Work 87 | or a Contribution incorporated within the Work constitutes direct 88 | or contributory patent infringement, then any patent licenses 89 | granted to You under this License for that Work shall terminate 90 | as of the date such litigation is filed. 91 | 92 | 4. Redistribution. 
You may reproduce and distribute copies of the 93 | Work or Derivative Works thereof in any medium, with or without 94 | modifications, and in Source or Object form, provided that You 95 | meet the following conditions: 96 | 97 | (a) You must give any other recipients of the Work or 98 | Derivative Works a copy of this License; and 99 | 100 | (b) You must cause any modified files to carry prominent notices 101 | stating that You changed the files; and 102 | 103 | (c) You must retain, in the Source form of any Derivative Works 104 | that You distribute, all copyright, patent, trademark, and 105 | attribution notices from the Source form of the Work, 106 | excluding those notices that do not pertain to any part of 107 | the Derivative Works; and 108 | 109 | (d) If the Work includes a "NOTICE" text file as part of its 110 | distribution, then any Derivative Works that You distribute must 111 | include a readable copy of the attribution notices contained 112 | within such NOTICE file, excluding those notices that do not 113 | pertain to any part of the Derivative Works, in at least one 114 | of the following places: within a NOTICE text file distributed 115 | as part of the Derivative Works; within the Source form or 116 | documentation, if provided along with the Derivative Works; or, 117 | within a display generated by the Derivative Works, if and 118 | wherever such third-party notices normally appear. The contents 119 | of the NOTICE file are for informational purposes only and 120 | do not modify the License. You may add Your own attribution 121 | notices within Derivative Works that You distribute, alongside 122 | or as an addendum to the NOTICE text from the Work, provided 123 | that such additional attribution notices cannot be construed 124 | as modifying the License. 
125 | 126 | You may add Your own copyright statement to Your modifications and 127 | may provide additional or different license terms and conditions 128 | for use, reproduction, or distribution of Your modifications, or 129 | for any such Derivative Works as a whole, provided Your use, 130 | reproduction, and distribution of the Work otherwise complies with 131 | the conditions stated in this License. 132 | 133 | 5. Submission of Contributions. Unless You explicitly state otherwise, 134 | any Contribution intentionally submitted for inclusion in the Work 135 | by You to the Licensor shall be under the terms and conditions of 136 | this License, without any additional terms or conditions. 137 | Notwithstanding the above, nothing herein shall supersede or modify 138 | the terms of any separate license agreement you may have executed 139 | with Licensor regarding such Contributions. 140 | 141 | 6. Trademarks. This License does not grant permission to use the trade 142 | names, trademarks, service marks, or product names of the Licensor, 143 | except as required for reasonable and customary use in describing the 144 | origin of the Work and reproducing the content of the NOTICE file. 145 | 146 | 7. Disclaimer of Warranty. Unless required by applicable law or 147 | agreed to in writing, Licensor provides the Work (and each 148 | Contributor provides its Contributions) on an "AS IS" BASIS, 149 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 150 | implied, including, without limitation, any warranties or conditions 151 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 152 | PARTICULAR PURPOSE. You are solely responsible for determining the 153 | appropriateness of using or redistributing the Work and assume any 154 | risks associated with Your exercise of permissions under this License. 155 | 156 | 8. Limitation of Liability. 
In no event and under no legal theory, 157 | whether in tort (including negligence), contract, or otherwise, 158 | unless required by applicable law (such as deliberate and grossly 159 | negligent acts) or agreed to in writing, shall any Contributor be 160 | liable to You for damages, including any direct, indirect, special, 161 | incidental, or consequential damages of any character arising as a 162 | result of this License or out of the use or inability to use the 163 | Work (including but not limited to damages for loss of goodwill, 164 | work stoppage, computer failure or malfunction, or any and all 165 | other commercial damages or losses), even if such Contributor 166 | has been advised of the possibility of such damages. 167 | 168 | 9. Accepting Warranty or Additional Liability. While redistributing 169 | the Work or Derivative Works thereof, You may choose to offer, 170 | and charge a fee for, acceptance of support, warranty, indemnity, 171 | or other liability obligations and/or rights consistent with this 172 | License. However, in accepting such obligations, You may act only 173 | on Your own behalf and on Your sole responsibility, not on behalf 174 | of any other Contributor, and only if You agree to indemnify, 175 | defend, and hold each Contributor harmless for any liability 176 | incurred by, or claims asserted against, such Contributor by reason 177 | of your accepting any such warranty or additional liability. 178 | 179 | END OF TERMS AND CONDITIONS 180 | 181 | APPENDIX: How to apply the Apache License to your work. 182 | 183 | To apply the Apache License to your work, attach the following 184 | boilerplate notice, with the fields enclosed by brackets "[]" 185 | replaced with your own identifying information. (Don't include 186 | the brackets!) The text should be enclosed in the appropriate 187 | comment syntax for the file format. 
We also recommend that a 188 | file or class name and description of purpose be included on the 189 | same "printed page" as the copyright notice for easier 190 | identification within third-party archives. 191 | 192 | Copyright 2015, The TensorFlow Authors. 193 | 194 | Licensed under the Apache License, Version 2.0 (the "License"); 195 | you may not use this file except in compliance with the License. 196 | You may obtain a copy of the License at 197 | 198 | http://www.apache.org/licenses/LICENSE-2.0 199 | 200 | Unless required by applicable law or agreed to in writing, software 201 | distributed under the License is distributed on an "AS IS" BASIS, 202 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 203 | See the License for the specific language governing permissions and 204 | limitations under the License. 205 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method("!",tensorflow.tensor) 4 | S3method("!=",tensorflow.python.framework.tensor_shape.TensorShape) 5 | S3method("!=",tensorflow.tensor) 6 | S3method("%%",tensorflow.tensor) 7 | S3method("%/%",tensorflow.tensor) 8 | S3method("&",tensorflow.tensor) 9 | S3method("*",tensorflow.tensor) 10 | S3method("+",tensorflow.tensor) 11 | S3method("-",tensorflow.tensor) 12 | S3method("/",tensorflow.tensor) 13 | S3method("<",tensorflow.tensor) 14 | S3method("<=",tensorflow.tensor) 15 | S3method("==",tensorflow.python.framework.tensor_shape.TensorShape) 16 | S3method("==",tensorflow.tensor) 17 | S3method(">",tensorflow.tensor) 18 | S3method(">=",tensorflow.tensor) 19 | S3method("[",tensorflow.python.framework.tensor_shape.TensorShape) 20 | S3method("[",tensorflow.tensor) 21 | S3method("[<-",tensorflow.python.framework.tensor_shape.TensorShape) 22 | S3method("[<-",tensorflow.tensor) 23 | 
S3method("[[",tensorflow.python.framework.tensor_shape.TensorShape) 24 | S3method("[[<-",tensorflow.python.framework.tensor_shape.TensorShape) 25 | S3method("^",tensorflow.tensor) 26 | S3method("|",tensorflow.tensor) 27 | S3method(.DollarNames,tensorflow.python.platform.flags._FlagValues) 28 | S3method(Arg,tensorflow.tensor) 29 | S3method(Conj,tensorflow.tensor) 30 | S3method(Im,tensorflow.tensor) 31 | S3method(Mod,tensorflow.tensor) 32 | S3method(Re,tensorflow.tensor) 33 | S3method(abs,tensorflow.tensor) 34 | S3method(acos,tensorflow.tensor) 35 | S3method(all,tensorflow.tensor) 36 | S3method(any,tensorflow.tensor) 37 | S3method(aperm,tensorflow.tensor) 38 | S3method(as.array,python.builtin.EagerTensor) 39 | S3method(as.array,tensorflow.python.framework.ops.EagerTensor) 40 | S3method(as.array,tensorflow.python.ops.variables.Variable) 41 | S3method(as.character,python.builtin.EagerTensor) 42 | S3method(as.character,tensorflow.python.framework.ops.EagerTensor) 43 | S3method(as.character,tensorflow.python.ops.variables.Variable) 44 | S3method(as.double,python.builtin.EagerTensor) 45 | S3method(as.double,tensorflow.python.framework.ops.EagerTensor) 46 | S3method(as.double,tensorflow.python.framework.tensor_shape.TensorShape) 47 | S3method(as.double,tensorflow.python.ops.variables.Variable) 48 | S3method(as.integer,python.builtin.EagerTensor) 49 | S3method(as.integer,tensorflow.python.framework.ops.EagerTensor) 50 | S3method(as.integer,tensorflow.python.framework.tensor_shape.TensorShape) 51 | S3method(as.integer,tensorflow.python.ops.variables.Variable) 52 | S3method(as.list,tensorflow.python.framework.tensor_shape.TensorShape) 53 | S3method(as.logical,python.builtin.EagerTensor) 54 | S3method(as.logical,tensorflow.python.framework.ops.EagerTensor) 55 | S3method(as.logical,tensorflow.python.ops.variables.Variable) 56 | S3method(as.matrix,python.builtin.EagerTensor) 57 | S3method(as.matrix,tensorflow.python.framework.ops.EagerTensor) 58 | 
S3method(as.matrix,tensorflow.python.ops.variables.Variable) 59 | S3method(as.numeric,python.builtin.EagerTensor) 60 | S3method(as.numeric,tensorflow.python.framework.ops.EagerTensor) 61 | S3method(as.numeric,tensorflow.python.framework.tensor_shape.TensorShape) 62 | S3method(as.numeric,tensorflow.python.ops.variables.Variable) 63 | S3method(as.raster,python.builtin.EagerTensor) 64 | S3method(as.raster,tensorflow.python.framework.ops.EagerTensor) 65 | S3method(as.raster,tensorflow.python.ops.variables.Variable) 66 | S3method(as.vector,python.builtin.EagerTensor) 67 | S3method(as.vector,tensorflow.python.framework.ops.EagerTensor) 68 | S3method(as.vector,tensorflow.python.ops.variables.Variable) 69 | S3method(as_tensor,default) 70 | S3method(as_tensor,double) 71 | S3method(as_tensor,tensorflow.python.framework.tensor_shape.TensorShape) 72 | S3method(asin,tensorflow.tensor) 73 | S3method(atan,tensorflow.tensor) 74 | S3method(c,tensorflow.python.framework.tensor_shape.TensorShape) 75 | S3method(cbind,tensorflow.tensor) 76 | S3method(ceiling,tensorflow.tensor) 77 | S3method(cos,tensorflow.tensor) 78 | S3method(cospi,tensorflow.tensor) 79 | S3method(digamma,tensorflow.tensor) 80 | S3method(dim,tensorflow.tensor) 81 | S3method(exp,tensorflow.tensor) 82 | S3method(expm1,tensorflow.tensor) 83 | S3method(export_savedmodel,tensorflow.python.client.session.Session) 84 | S3method(floor,tensorflow.tensor) 85 | S3method(format,tensorflow.python.framework.tensor_shape.TensorShape) 86 | S3method(is.finite,tensorflow.tensor) 87 | S3method(is.infinite,tensorflow.tensor) 88 | S3method(is.nan,tensorflow.tensor) 89 | S3method(length,tensorflow.python.framework.tensor_shape.TensorShape) 90 | S3method(length,tensorflow.tensor) 91 | S3method(lgamma,tensorflow.tensor) 92 | S3method(log,tensorflow.tensor) 93 | S3method(log10,tensorflow.tensor) 94 | S3method(log1p,tensorflow.tensor) 95 | S3method(log2,tensorflow.tensor) 96 | S3method(max,tensorflow.tensor) 97 | 
S3method(mean,tensorflow.tensor) 98 | S3method(merge,tensorflow.python.framework.tensor_shape.TensorShape) 99 | S3method(min,tensorflow.tensor) 100 | S3method(pillar::type_sum,tensorflow.tensor) 101 | S3method(print,tensorflow.python.framework.tensor_shape.TensorShape) 102 | S3method(print,tensorflow.tensor) 103 | S3method(print,tensorflow_config) 104 | S3method(prod,tensorflow.tensor) 105 | S3method(py_str,tensorflow.python.framework.tensor_shape.TensorShape) 106 | S3method(py_to_r,keras.src.utils.tracking.TrackedDict) 107 | S3method(py_to_r,keras.src.utils.tracking.TrackedList) 108 | S3method(py_to_r,keras.src.utils.tracking.TrackedSet) 109 | S3method(py_to_r,tensorflow.python.trackable.data_structures.ListWrapper) 110 | S3method(py_to_r,tensorflow.python.trackable.data_structures._DictWrapper) 111 | S3method(py_to_r,tensorflow.python.training.tracking.data_structures.ListWrapper) 112 | S3method(py_to_r,tensorflow.python.training.tracking.data_structures._DictWrapper) 113 | S3method(range,tensorflow.tensor) 114 | S3method(rbind,tensorflow.tensor) 115 | S3method(rep,tensorflow.tensor) 116 | S3method(round,tensorflow.tensor) 117 | S3method(sign,tensorflow.tensor) 118 | S3method(sin,tensorflow.tensor) 119 | S3method(sinpi,tensorflow.tensor) 120 | S3method(sort,tensorflow.tensor) 121 | S3method(sqrt,tensorflow.tensor) 122 | S3method(str,tensorflow.tensor) 123 | S3method(sum,tensorflow.tensor) 124 | S3method(t,tensorflow.tensor) 125 | S3method(tan,tensorflow.tensor) 126 | S3method(tanpi,tensorflow.tensor) 127 | export("%as%") 128 | export(all_dims) 129 | export(array_reshape) 130 | export(as_tensor) 131 | export(dict) 132 | export(evaluate) 133 | export(export_savedmodel) 134 | export(flag_boolean) 135 | export(flag_integer) 136 | export(flag_numeric) 137 | export(flag_string) 138 | export(flags) 139 | export(import) 140 | export(install_tensorflow) 141 | export(iterate) 142 | export(np_array) 143 | export(parse_arguments) 144 | export(parse_flags) 145 | 
export(run_dir) 146 | export(set_random_seed) 147 | export(shape) 148 | export(tensorboard) 149 | export(tf) 150 | export(tf_config) 151 | export(tf_extract_opts) 152 | export(tf_function) 153 | export(tf_gpu_configured) 154 | export(tf_probability) 155 | export(tf_version) 156 | export(train) 157 | export(train_and_evaluate) 158 | export(tuple) 159 | export(use_compat) 160 | export(use_condaenv) 161 | export(use_python) 162 | export(use_session_with_seed) 163 | export(use_virtualenv) 164 | export(view_savedmodel) 165 | import(reticulate) 166 | importFrom(grDevices,as.raster) 167 | importFrom(lifecycle,deprecated) 168 | importFrom(reticulate,"%as%") 169 | importFrom(reticulate,array_reshape) 170 | importFrom(reticulate,dict) 171 | importFrom(reticulate,import) 172 | importFrom(reticulate,iterate) 173 | importFrom(reticulate,np_array) 174 | importFrom(reticulate,tuple) 175 | importFrom(reticulate,use_condaenv) 176 | importFrom(reticulate,use_python) 177 | importFrom(reticulate,use_virtualenv) 178 | importFrom(tfruns,flag_boolean) 179 | importFrom(tfruns,flag_integer) 180 | importFrom(tfruns,flag_numeric) 181 | importFrom(tfruns,flag_string) 182 | importFrom(tfruns,flags) 183 | importFrom(tfruns,run_dir) 184 | importFrom(utils,.DollarNames) 185 | importFrom(utils,str) 186 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # tensorflow (development version) 2 | 3 | - Updates for reticulate 1.41. The tensorflow R package now calls 4 | `reticulate::py_require()` when it is loaded. Calling `install_tensorflow()` 5 | in most circumstances is no longer necessary. 6 | - `install_tensorflow()` installs TensorFlow v2.18 by default.
7 | - Fixed an issue where GPUs would not be found when running on Windows 8 | WSL Linux (reported in rstudio/keras3#1456, fixed in #599) 9 | - Fixes for NumPy 2.0 (#601) 10 | - Fixes for R-devel (4.5) 11 | 12 | # tensorflow 2.16.0 13 | 14 | - The package now Suggests 'keras3' instead of 'keras' 15 | - `install_tensorflow()` installs TensorFlow v2.16 by default. 16 | - If `install_tensorflow()` detects a GPU on Linux, it will automatically 17 | install the cuda package and configure required symlinks for cudnn and ptxas. 18 | 19 | # tensorflow 2.15.0 20 | 21 | - `install_tensorflow()` installs TensorFlow v2.15 by default 22 | - Added compatibility with the latest release of reticulate (> 1.34). 23 | 24 | # tensorflow 2.14.0 25 | 26 | - `install_tensorflow()` changes: 27 | - Installs TensorFlow v2.14 by default. 28 | - Now will automatically install the required Nvidia CUDA runtime as a pip 29 | package if on Linux and a GPU is detected. You can opt out by passing 30 | `install_tensorflow(cuda = FALSE)`. Aside from the Nvidia driver, no other 31 | pre-existing Nvidia CUDA packages are now necessary. 32 | - The `configure_cudnn` argument is now superseded by the new argument `cuda`. 33 | - New argument `metal`, for specifying if the `tensorflow-metal` pip package 34 | should be installed on Arm Macs. Defaults to `TRUE` on Arm Macs. 35 | 36 | - Fixed an issue where `as.array()` and other methods might fail if the tensor 37 | had conversion disabled via `r_to_py()` or `convert = FALSE`. 38 | - Fixed an issue where Ops group generic dispatch would error when one object was a tensor 39 | and the other was a non-tensor Python object (e.g., a numpy array). 40 | - Removed long deprecated symbols: 41 | `install_tensorflow_extras()`, `tfe_enable_eager_execution()` 42 | - tfestimator generics `train()` and `train_and_evaluate()` now warn about 43 | their deprecation status when called. They will be removed in a future release.
44 | 45 | # tensorflow 2.13.0 46 | 47 | - `install_tensorflow()` changes: 48 | - Installs TensorFlow v2.13 by default now. 49 | - The `envname` argument's new default is `"r-tensorflow"`. This means that 50 | unless the `envname` argument is supplied, `install_tensorflow()` will now 51 | install into the `"r-tensorflow"` environment, bootstrapping a venv of 52 | that name if necessary. 53 | - Gains a `new_env` argument. If `TRUE`, any existing environment 54 | specified by `envname` is deleted and created anew. Defaults to `TRUE` if 55 | envname is `"r-tensorflow"`, `FALSE` otherwise. 56 | - If running on Linux, now detects whether NVIDIA GPUs are installed, 57 | and if so, installs cuDNN (via pip), configures symlinks for tensorflow 58 | to find cuDNN, and emits additional instructions for how to install the necessary CUDA 59 | drivers to enable GPU usage. Set new arg `configure_cudnn = FALSE` to disable. 60 | - `pip_ignore_installed` default is now `FALSE` again. 61 | - On Arm Macs (M1/M2), the default tensorflow package is once again installed, 62 | rather than `tensorflow-macos` and `tensorflow-metal`. 63 | 64 | - New `pillar::type_sum()` method for Tensors, giving a 65 | more informative printout of Tensors in R tracebacks and tibbles. 66 | 67 | # tensorflow 2.11.0 68 | 69 | - `install_tensorflow()` now installs TF v2.11 by default. 70 | 71 | - `as_tensor()` now coerces bare R atomic vectors to R arrays before conversion. 72 | As a consequence, by default, R atomic double vectors now coerce to 73 | 'float64' dtype tensors instead of 'float32'. 74 | 75 | - `shape()` gains the ability to accept vectors of length > 1 in `...`, 76 | including other `tf.TensorShape`s. Shapes are automatically flattened. 77 | 78 | - Fixed an issue where a `ListWrapper` object of trackable keras layers 79 | (e.g., as part of a keras model) would not convert to an R list.
80 | 81 | # tensorflow 2.9.0 82 | 83 | - Generic method updates: 84 | - New methods: 85 | all(), any(), sum(), prod(), min(), max(), mean(), range(), 86 | cbind(), rbind(), t(), aperm(), sort(), 87 | as.vector(), as.character(), as.raster(), 88 | is.infinite(), is.finite(), is.nan() 89 | - `^` will now invoke `tf.square()` or `tf.sqrt()` directly when appropriate 90 | - `|`, `&`, and `!` now cast arguments to 'bool' dtype. 91 | - `print()` now shows 1d shapes without a trailing comma. 92 | - `str()` method for tensors now returns only a single compact line; 93 | `str()` on a list of tensors now does something sensible. 94 | 95 | - `install_tensorflow()` now installs TensorFlow 2.9 by default. 96 | 97 | - `install_tensorflow()` no longer requires conda on Windows, now works in a regular venv. 98 | 99 | - Comparing two partially-defined `TensorShape`s now returns TRUE if each dimension matches. 100 | e.g.: `shape(NA, 4) == shape(NA, 4)` now returns TRUE, previously FALSE. 101 | 102 | - Tensors with dtype 'string' now convert to R character vectors by methods 103 | `as.array()` and `as.matrix()`. (previously they converted to python.builtin.bytes, 104 | or an R list of python.builtin.bytes objects) 105 | 106 | - `as_tensor()`: 107 | - atomic R integer vectors now convert to 'int32', not 'int64' 108 | - casting between integer and floating dtypes is now done via 109 | `tf$dtypes$saturate_cast()` instead of `tf$cast()`. 110 | - `shape` argument now accepts a tensor. 111 | - fixed issue where expanding a scalar tensor to an nd-array with 112 | `shape` provided as a tensor would raise an error. 113 | 114 | - `tf.SparseTensor` objects now inherit from `"tensorflow.tensor"`. 115 | 116 | # tensorflow 2.8.0 117 | 118 | - Updated default Tensorflow version installed by `install_tensorflow()` to 2.8. 119 | 120 | - `as_tensor()` gains a `shape` argument, which can be used to fill or reshape tensors.
121 | Scalars can be recycled to a tensor of arbitrary `shape`, otherwise 122 | supplied objects are reshaped using row-major (C-style) semantics. 123 | 124 | - `install_tensorflow()` now provides experimental support for Arm Macs, 125 | with the following restrictions: 126 | - "conda" is the only supported installation method. 127 | - requests for non-default or older tensorflow versions are not supported. 128 | 129 | - `install_tensorflow()` default conda_python_version changes from 3.7 to NULL. 130 | 131 | - `tf.TensorShape()`'s gain `format()` and `print()` S3 methods. 132 | 133 | - `[` method for slicing tensors now accepts `NA` as a synonym for a missing or `NULL` spec. 134 | For example `x[NA:3]` is now valid, equivalent to `x[:3]` in Python. 135 | 136 | # tensorflow 2.7.0 137 | 138 | - Default Tensorflow version installed by `install_tensorflow()` updated to 2.7 139 | 140 | - Breaking changes: 141 | - `shape()` now returns a `tf.TensorShape()` object 142 | (Previously an R-list of `NULL`s or integers). 143 | - `[` method for `tf.TensorShape()` objects also now returns a `tf.TensorShape()`. 144 | Use `[[`, `as.numeric`, `as.integer`, and/or `as.list` to convert to R objects. 145 | - `length()` method for `tensorflow.tensor` now returns `NA_integer_` for 146 | tensors with not fully defined shapes. (previously a zero length integer vector). 147 | - `dim()` method for `tensorflow.tensor` now returns an R integer vector 148 | with `NA` for dimensions that are undefined. 149 | (previously an R list with `NULL` for undefined dimension) 150 | 151 | - New S3 generics for `tf.TensorShape()`'s: 152 | `c`, `length`, `[<-`, `[[<-`, `merge`, `==`, `!=`, `as_tensor()`, 153 | `as.list`, `as.integer`, `as.numeric`, `as.double`, `py_str` 154 | (joining previous generics `[` and `[[`). 155 | See `?shape` for extended examples. 
156 | 157 | - Ops S3 generics for `tensorflow.tensor`s that take two arguments now 158 | automatically cast a supplied non-tensor to the dtype of the supplied tensor 159 | that triggered the S3 dispatch. Casting is done via `as_tensor()`. 160 | e.g., this now works: 161 | ``` 162 | as_tensor(5L) - 2 # now returns tf.Tensor(3, shape=(), dtype=int32) 163 | ``` 164 | previously it would raise an error: 165 | ``` 166 | TypeError: `x` and `y` must have the same dtype, got tf.int32 != tf.float32 167 | ``` 168 | Generics that now do autocasting: 169 | +, -, *, /, %/%, %%, ^, &, |, ==, !=, <, <=, >, >= 170 | 171 | - `install_tensorflow()`: new argument with default `pip_ignore_installed = TRUE`. 172 | This ensures that all Tensorflow dependencies like Numpy are installed by pip 173 | rather than conda. 174 | 175 | - A message with the Tensorflow version is now shown when the 176 | python module is loaded, e.g: "Loaded Tensorflow version 2.6.0" 177 | 178 | # tensorflow 2.6.0 179 | 180 | - Updated default Tensorflow version to 2.6. 181 | 182 | - Changed default in `tf_function()` to `autograph=TRUE`. 183 | 184 | - Added S3 generic `as_tensor()`. 185 | 186 | - tfautograph added to Imports 187 | 188 | - jsonlite removed from Imports, tfestimators removed from Suggests 189 | 190 | - Refactored `install_tensorflow()`. 191 | - Potentially breaking change: numeric versions supplied without a patchlevel now automatically pull the latest patch release. 192 | (e.g. `install_tensorflow(version="2.4")` will install `"2.4.2"`. Previously it would install "2.4.0") 193 | 194 | - Removed "Config/reticulate" declaration from DESCRIPTION. 195 | - Setting `RETICULATE_AUTOCONFIGURE=FALSE` environment variable when using non-default tensorflow installations (e.g., 'tensorflow-cpu') no longer required. 196 | - Users will have to call `install_tensorflow()` for automatic installation. 
197 | 198 | - Refactored automated tests to more closely match the default installation procedure 199 | and compute environment of most users. 200 | 201 | - Expanded CI test coverage to include R devel, oldrel and 3.6. 202 | 203 | - Fixed an issue where extra packages with version constraints like 204 | `install_tensorflow(extra_packages = "Pillow<8.3")` were not quoted properly. 205 | 206 | - Fixed an issue where valid tensor-like objects supplied to 207 | `log(x, base)`, `cospi()`, `tanpi()`, and `sinpi()` would raise an error. 208 | 209 | 210 | # tensorflow 2.5.0 211 | 212 | - Updated default Tensorflow version to 2.5. 213 | - Added support for additional arguments in `tf_function()` (e.g., `jit_compile`) 214 | - Added support for `expm1` S3 generic. 215 | - `tfe_enable_eager_execution` is deprecated. Eager mode has been the default since TF version 2.0. 216 | - Improved error message in `tf_config()` on unsuccessful installation. 217 | 218 | # tensorflow 2.4.0 219 | 220 | - Fixed error with `use_session_with_seed` (#428) 221 | - Added a new `set_random_seed` function that makes more sense for TensorFlow >= 2.0 (#442) 222 | - Updated the default version of TensorFlow to 2.4 as well as the default Python to 3.7 (#454) 223 | 224 | # TensorFlow 2.2.0 (CRAN) 225 | 226 | - Bugfix with `all_dims` (#398) 227 | 228 | - Indexing for TensorShape & `py_to_r` conversion (#379, #388) 229 | 230 | # TensorFlow 2.0.0 (CRAN) 231 | 232 | - Upgraded default installed version to 2.0.0. 233 | 234 | - Tensorboard log directory path fixes (#360). 235 | 236 | - Allow for `v1` and `v2` compat (#358). 237 | 238 | - `install_tensorflow` no longer installs `tfprobability`, `tfhub` and other 239 | related packages.
240 | 241 | # TensorFlow 1.14.1 (CRAN) 242 | 243 | - Upgraded default installed version to 1.14.0 244 | 245 | - Refactored the `install_tensorflow` code delegating to `reticulate` (#333, #341): installation is now delegated entirely to `reticulate::py_install`; the main difference is that the default environment name to install into is now `r-reticulate` and not `r-tensorflow`. 246 | 247 | # TensorFlow 1.13.1 (CRAN) 248 | 249 | - Added option to silence TF CPP info output 250 | 251 | - `tf_gpu_configured` function to check if GPU was correctly configured. 252 | -------------------------------------------------------------------------------- /R/arguments.R: -------------------------------------------------------------------------------- 1 | #' Parse Command Line Arguments 2 | #' 3 | #' Parse command line arguments of the form `--key=value` and 4 | #' `--key value`. The values are assumed to be valid `yaml` and 5 | #' will be converted using [yaml::yaml.load()]. 6 | #' 7 | #' @param arguments A vector of command line arguments. When 8 | #' `NULL` (the default), the command line arguments received 9 | #' by the current \R process are used.
10 | #' 11 | #' @export 12 | parse_arguments <- function(arguments = NULL) { 13 | arguments <- arguments %||% commandArgs(TRUE) 14 | 15 | # initialize some state 16 | values <- list() 17 | 18 | i <- 0; n <- length(arguments) 19 | while (i < n) { 20 | i <- i + 1 21 | argument <- arguments[[i]] 22 | 23 | # skip any command line arguments without a '--' prefix 24 | if (!grepl("^--", argument)) 25 | next 26 | 27 | # check to see if an '=' was specified for this argument 28 | equals_idx <- regexpr("=", argument) 29 | if (identical(c(equals_idx), -1L)) { 30 | # no '='; the next argument is the value for this key, so a trailing '--key' with nothing after it is an error 31 | key <- substring(argument, 3) 32 | val <- if (i < n) arguments[[i + 1]] else stop(sprintf("missing value for command line argument '%s'", argument), call. = FALSE) 33 | i <- i + 1 34 | } else { 35 | # found a '='; the value is all the text following 36 | # that character 37 | key <- substring(argument, 3, equals_idx - 1) 38 | val <- substring(argument, equals_idx + 1) 39 | } 40 | 41 | # convert '-' to '_' in key 42 | key <- gsub("-", "_", key) 43 | 44 | # update our map of argument values (each value is parsed as YAML) 45 | values[[key]] <- yaml::yaml.load(val) 46 | } 47 | 48 | values 49 | 50 | } 51 | -------------------------------------------------------------------------------- /R/compat.R: -------------------------------------------------------------------------------- 1 | 2 | #' Use Compatibility 3 | #' 4 | #' Enables TensorFlow to run under a different API version for compatibility 5 | #' with previous versions. For instance, this is useful to run TensorFlow 1.x 6 | #' code when using TensorFlow 2.x. 7 | #' 8 | #' @param version The version to activate.
Must be `"v1"` or `"v2"` 9 | #' 10 | #' @examples 11 | #' \dontrun{ 12 | #' library(tensorflow) 13 | #' use_compat("v1") 14 | #' } 15 | #' @return The `tf` module that was active before the switch, invisibly (or `NULL` when `tf_version()` is `NULL`). 16 | #' @export 17 | use_compat <- function(version = c("v1", "v2")) { 18 | if (is.null(tf_version())) return() # presumably NULL when TensorFlow is unavailable; nothing to switch 19 | 20 | version <- match.arg(version) 21 | 22 | tf2 <- tf # remember the currently active module so it can be returned 23 | 24 | unlock <- get("unlockBinding") # fetched via get(), presumably to keep R CMD check quiet 25 | lock <- get("lockBinding") 26 | 27 | unlock("tf", as.environment("package:tensorflow")) 28 | on.exit(lock("tf", as.environment("package:tensorflow"))) 29 | 30 | assign("tf", tf$compat[[version]], as.environment("package:tensorflow")) 31 | 32 | invisible(tf2) 33 | } 34 | -------------------------------------------------------------------------------- /R/deprecated.R: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rstudio/tensorflow/94bccf65824a51cefb6c55d2c6a99d4e4dbaa6bb/R/deprecated.R -------------------------------------------------------------------------------- /R/eager.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | #' @export 4 | as.array.python.builtin.EagerTensor <- function(x, ...) { 5 | if (py_is_null_xptr(x)) 6 | return(NULL) 7 | if(as_r_value(x$dtype$name) == "string") 8 | array(as.character(x, ...), 9 | dim = if(length(dx <- dim(x))) dx else 1L) 10 | else 11 | as_r_value(x$numpy()) 12 | } 13 | 14 | #' @export 15 | as.array.tensorflow.python.framework.ops.EagerTensor <- as.array.python.builtin.EagerTensor 16 | 17 | #' @export 18 | as.array.tensorflow.python.ops.variables.Variable <- as.array.python.builtin.EagerTensor 19 | 20 | 21 | 22 | #' @export 23 | as.vector.python.builtin.EagerTensor <- function(x, ...)
{ 24 | as.vector(as.array(x, ...)) 25 | } 26 | 27 | #' @export 28 | as.vector.tensorflow.python.framework.ops.EagerTensor <- as.vector.python.builtin.EagerTensor 29 | 30 | #' @export 31 | as.vector.tensorflow.python.ops.variables.Variable <- as.vector.python.builtin.EagerTensor 32 | 33 | 34 | #' @export 35 | as.matrix.python.builtin.EagerTensor <- function(x, ...) { 36 | if (py_is_null_xptr(x)) 37 | return(NULL) 38 | 39 | as.matrix(as.array(x, ...)) 40 | } 41 | 42 | #' @export 43 | as.matrix.tensorflow.python.framework.ops.EagerTensor <- as.matrix.python.builtin.EagerTensor 44 | 45 | #' @export 46 | as.matrix.tensorflow.python.ops.variables.Variable <- as.matrix.python.builtin.EagerTensor 47 | 48 | 49 | #' @export 50 | as.integer.python.builtin.EagerTensor <- function(x, ...) { 51 | if (py_is_null_xptr(x)) 52 | NULL 53 | else 54 | as.integer(as.array(x)) 55 | } 56 | 57 | #' @export 58 | as.integer.tensorflow.python.framework.ops.EagerTensor <- as.integer.python.builtin.EagerTensor 59 | 60 | #' @export 61 | as.integer.tensorflow.python.ops.variables.Variable <- as.integer.python.builtin.EagerTensor 62 | 63 | 64 | #' @export 65 | as.numeric.python.builtin.EagerTensor <- function(x, ...) { 66 | if (py_is_null_xptr(x)) 67 | NULL 68 | else 69 | as.numeric(as.array(x)) 70 | } 71 | 72 | #' @export 73 | as.numeric.tensorflow.python.framework.ops.EagerTensor <- as.numeric.python.builtin.EagerTensor 74 | 75 | #' @export 76 | as.numeric.tensorflow.python.ops.variables.Variable <- as.numeric.python.builtin.EagerTensor 77 | 78 | 79 | #' @export 80 | as.double.python.builtin.EagerTensor <- function(x, ...)
{ 81 | if (py_is_null_xptr(x)) 82 | NULL 83 | else 84 | as.double(as.array(x)) 85 | } 86 | 87 | #' @export 88 | as.double.tensorflow.python.framework.ops.EagerTensor <- as.double.python.builtin.EagerTensor 89 | 90 | #' @export 91 | as.double.tensorflow.python.ops.variables.Variable <- as.double.python.builtin.EagerTensor 92 | 93 | 94 | #' @export 95 | as.logical.python.builtin.EagerTensor <- function(x, ...) { 96 | if (py_is_null_xptr(x)) 97 | NULL 98 | else 99 | as.logical(as.array(x)) 100 | } 101 | 102 | #' @export 103 | as.logical.tensorflow.python.framework.ops.EagerTensor <- as.logical.python.builtin.EagerTensor 104 | 105 | #' @export 106 | as.logical.tensorflow.python.ops.variables.Variable <- as.logical.python.builtin.EagerTensor 107 | 108 | #' @export 109 | as.character.python.builtin.EagerTensor <- function(x, ...) { 110 | if (py_is_null_xptr(x)) # same null-pointer guard as the sibling conversion methods 111 | return(NULL) 112 | out <- as_r_value(x$numpy()) 113 | # as.character() on python bytes dispatches to reticulate:::as.character.python.builtin.bytes, 114 | # which calls x$decode(encoding = "utf-8", errors = "strict") 115 | if(is.list(out)) 116 | vapply(out, as.character, "", ..., USE.NAMES = FALSE) 117 | else as.character(out, ...) 118 | } 119 | 120 | #' @export 121 | as.character.tensorflow.python.framework.ops.EagerTensor <- 122 | as.character.python.builtin.EagerTensor 123 | 124 | #' @export 125 | as.character.tensorflow.python.ops.variables.Variable <- 126 | as.character.python.builtin.EagerTensor 127 | 128 | ## @exportS3Method grDevices::as.raster A delayed registration like this requires R>=3.6 129 | 130 | #' @importFrom grDevices as.raster 131 | #' @export 132 | as.raster.python.builtin.EagerTensor <- 133 | function(x, max = if(as_r_value(x$dtype$is_integer)) as_r_value(x$dtype$max) else 1, ...) 134 | as.raster(as.array(x), max = max, ...)
135 | 136 | #' @export 137 | as.raster.tensorflow.python.framework.ops.EagerTensor <- 138 | as.raster.python.builtin.EagerTensor 139 | 140 | #' @export 141 | as.raster.tensorflow.python.ops.variables.Variable <- 142 | as.raster.python.builtin.EagerTensor 143 | 144 | 145 | 146 | #' Creates a callable TensorFlow graph from an R function. 147 | #' 148 | #' `tf_function` constructs a callable that executes a TensorFlow graph created 149 | #' by tracing the TensorFlow operations in `f`. This allows the TensorFlow 150 | #' runtime to apply optimizations and exploit parallelism in the computation 151 | #' defined by `f`. 152 | #' 153 | #' A guide to getting started with 154 | #' [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) can 155 | #' be found [here](https://www.tensorflow.org/guide/function). 156 | #' 157 | #' @param f An R function to be compiled into a TensorFlow graph. 158 | #' @param input_signature A possibly nested sequence of `tf$TensorSpec` objects 159 | #' specifying the shapes and dtypes of the tensors that will be supplied to 160 | #' this function. If `NULL`, a separate function is instantiated for each 161 | #' inferred input signature. If `input_signature` is specified, every input to 162 | #' `f` must be a tensor. 163 | #' @param autograph TRUE or FALSE. If TRUE (the default), you can use tensors in 164 | #' R control flow expressions `if`, `while`, `for` and `break` and they will 165 | #' be traced into the tensorflow graph. A guide to getting started and 166 | #' additional details can be found: 167 | #' [here](https://t-kalinowski.github.io/tfautograph/) 168 | #' @param ... additional arguments passed on to `tf.function` (vary based on 169 | #' Tensorflow version). See 170 | #' [here](https://www.tensorflow.org/api_docs/python/tf/function#args_1) for 171 | #' details. 172 | #' @return The callable object returned by `tf$function()`. 173 | #' @export 174 | tf_function <- function(f, 175 | input_signature = NULL, 176 | autograph = TRUE, 177 | ...)
{ 178 | if (!is.function(f)) 179 | stop("`f` must be an R function") 180 | 181 | if (!(isTRUE(autograph) || isFALSE(autograph))) 182 | stop("`autograph` must be TRUE or FALSE") 183 | 184 | if (autograph) { 185 | # Can't register tfautograph in Imports yet due to circular dependency 186 | if(!requireNamespace("tfautograph", quietly=TRUE)) 187 | stop('"tfautograph" package required if autograph=TRUE. Please run install.packages("tfautograph")') 188 | f <- tfautograph::autograph(f) 189 | } 190 | 191 | args <- list(py_func(f), input_signature, FALSE, ...) # 3rd positional arg is presumably tf.function's Python-side `autograph`, disabled because R-level autograph was applied above -- verify against the TF signature 192 | do.call(tf$`function`, args) 193 | } 194 | 195 | # TODO: calling tf_function() with `f` missing should return 196 | # a decorator with args partially pre-specified 197 | -------------------------------------------------------------------------------- /R/estimator-generics.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | #' (Deprecated) Train a Model 4 | #' 5 | #' `r lifecycle::badge('deprecated')` 6 | #' 7 | #' Train a model object. See implementation in the 8 | #' `tfestimators::train.tf_estimator()` package. 9 | #' 10 | #' @param object A trainable \R object. 11 | #' @param ... Optional arguments passed on to implementing methods. 12 | #' 13 | #' @export 14 | #' @keywords internal 15 | train <- function(object, ...) { 16 | lifecycle::deprecate_warn("2.9", "train()", "fit()") 17 | UseMethod("train") 18 | } 19 | 20 | #' Evaluate a Model 21 | #' 22 | #' Evaluate a model object. See implementations in the 23 | #' [keras3][keras3::evaluate.keras.src.models.model.Model()] package. 24 | #' 25 | #' @param object An evaluatable \R object. 26 | #' @param ... Optional arguments passed on to implementing methods. 27 | #' 28 | #' @section Implementations: 29 | #' 30 | #' - [keras3][keras3::evaluate.keras.src.models.model.Model()] 31 | #' 32 | #' @export 33 | evaluate <- function(object, ...)
{ 34 | UseMethod("evaluate") 35 | } 36 | 37 | #' (Deprecated) Simultaneously Train and Evaluate a Model 38 | #' 39 | #' `r lifecycle::badge('deprecated')` 40 | #' 41 | #' Train and evaluate a model object. See implementation in the 42 | #' `tfestimators::train_and_evaluate.tf_estimator()` package. 43 | #' 44 | #' @param object An \R object. 45 | #' @param ... Optional arguments passed on to implementing methods. 46 | #' 47 | #' @keywords internal 48 | #' @export 49 | train_and_evaluate <- function(object, ...) { 50 | lifecycle::deprecate_warn("2.9", "train_and_evaluate()") 51 | UseMethod("train_and_evaluate") 52 | } 53 | 54 | 55 | #' Export a Saved Model 56 | #' 57 | #' Serialize a model to disk. See implementations in the 58 | #' [keras3][keras3::export_savedmodel.keras.src.models.model.Model()] 59 | #' package. 60 | #' 61 | #' @param object An \R object. 62 | #' @param export_dir_base A string containing a directory in which to export the 63 | #' SavedModel. 64 | #' @param ... Optional arguments passed on to implementing methods. 65 | #' 66 | #' @return The path to the exported directory, as a string. 67 | #' 68 | #' @section Implementations: 69 | #' 70 | #' - [keras3][keras3::export_savedmodel.keras.src.models.model.Model()] 71 | #' 72 | #' @keywords internal 73 | #' @export 74 | export_savedmodel <- function( 75 | object, 76 | export_dir_base, 77 | ...) { 78 | UseMethod("export_savedmodel") 79 | } 80 | 81 | -------------------------------------------------------------------------------- /R/flags.R: -------------------------------------------------------------------------------- 1 | #' Parse Configuration Flags for a TensorFlow Application 2 | #' 3 | #' Parse configuration flags for a TensorFlow application. Use 4 | #' this to parse and unify the configuration(s) specified through 5 | #' a `flags.yml` configuration file, alongside other arguments 6 | #' set through the command line. 7 | #' 8 | #' @param config The configuration to use.
Defaults to the 9 | #' active configuration for the current environment (as 10 | #' specified by the `R_CONFIG_ACTIVE` environment 11 | #' variable), or `default` when unset. 12 | #' @param file The configuration file to read. 13 | #' @param arguments The command line arguments (as a 14 | #' character vector) to be parsed. 15 | #' 16 | #' @return A named \R list, mapping configuration keys to values. 17 | #' 18 | #' @examples 19 | #' \dontrun{ 20 | #' # examine an example configuration file provided by tensorflow 21 | #' file <- system.file("examples/config/flags.yml", package = "tensorflow") 22 | #' cat(readLines(file), sep = "\n") 23 | #' 24 | #' # read the default configuration 25 | #' FLAGS <- tensorflow::parse_flags("default", file = file) 26 | #' str(FLAGS) 27 | #' 28 | #' # read the alternate configuration: note that 29 | #' # the default configuration is inherited, but 30 | #' # we override the 'string' configuration here 31 | #' FLAGS <- tensorflow::parse_flags("alternate", file = file) 32 | #' str(FLAGS) 33 | #' 34 | #' # override configuration values using command 35 | #' # line arguments (normally, these would be 36 | #' # passed in through the command line invocation 37 | #' # used to start the process) 38 | #' FLAGS <- tensorflow::parse_flags( 39 | #' "alternate", 40 | #' file = file, 41 | #' arguments = c("--foo=1") 42 | #' ) 43 | #' str(FLAGS) 44 | #' 45 | #' } 46 | #' @export 47 | parse_flags <- 48 | function(config = Sys.getenv("R_CONFIG_ACTIVE", unset = "default"), 49 | file = "flags.yml", 50 | arguments = commandArgs(TRUE)) 51 | { 52 | flags <- list() 53 | 54 | # warn if the user has supplied a 'file' argument but no such file exists 55 | if (!missing(file) && !file.exists(file)) 56 | warning(sprintf("configuration file '%s' does not exist", file)) 57 | 58 | # read configuration file if it does exist 59 | if (file.exists(file)) 60 | flags <- config::get(config = config, file = file) 61 | 62 | # backwards compatibility -- if the user is using the 63
| # TensorFlow FLAGS system for handling their 64 | # configuration, then explicitly read the configuration 65 | # from that; otherwise run our own parser. `tf$app` is absent in the TF 2.x namespace (and TF may not be loadable), so a failed lookup also falls back to our parser 66 | actions <- tryCatch(tf$app$flags$`_global_parser`$`_actions`, error = function(e) NULL) 67 | if (length(actions) > 1) { 68 | flags <- config::merge(flags, parse_tensorflow_flags(arguments)) 69 | } else { 70 | flags <- config::merge(flags, parse_arguments(arguments)) 71 | } 72 | 73 | # return generated config 74 | flags 75 | } 76 | 77 | parse_tensorflow_flags <- function(args = commandArgs(TRUE)) { 78 | 79 | # parse known arguments using the global parser 80 | parser <- tf$app$flags$`_global_parser` 81 | result <- tryCatch( 82 | parser$parse_known_args(as.list(args)), 83 | error = function(e) NULL 84 | ) 85 | 86 | # check for error (means user invoked --help) 87 | if (is.null(result)) { 88 | if (interactive()) 89 | return(NULL) 90 | else 91 | quit(save = "no") 92 | } 93 | 94 | # return parsed flags as named R list 95 | result[[1]]$`__dict__` 96 | } 97 | 98 | -------------------------------------------------------------------------------- /R/help.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | register_tf_help_handler <- function() { 4 | reticulate::register_module_help_handler("tensorflow", function(name, subtopic = NULL) { 5 | 6 | # # Version specific URLs are disabled because 7 | # # upstream TF is missing public docs for later version 8 | # # https://github.com/tensorflow/tensorflow/issues/89084 9 | # # get the base tensorflow help url 10 | # version <- tf$`__version__` 11 | # version <- strsplit(version, ".", fixed = TRUE)[[1]] 12 | # help_url <- paste0("https://www.tensorflow.org/versions/r", 13 | # version[1], ".", version[2], "/api_docs/python/") 14 | 15 | help_url <- "https://www.tensorflow.org/api_docs/python/" 16 | 17 | # some adjustments 18 | name <- sub("^tensorflow\\._api\\.v2\\.", "tensorflow.", name) 19 | name <- sub("^tensorflow", "tf", name) 20 | name <-
sub("python.client.session.", "", name, fixed = TRUE) 21 | name <- sub("python.ops.", "", name, fixed = TRUE) 22 | if (grepl("tf.contrib.opt", name)) { 23 | components <- strsplit(name, ".", fixed = TRUE)[[1]] 24 | class_name <- components[[length(components)]] 25 | name <- paste0("tf.contrib.opt", ".", class_name) 26 | } 27 | 28 | # form topic url 29 | topic_url <- gsub(".", "/", name, fixed = TRUE) 30 | if (!is.null(subtopic)) 31 | topic_url <- paste0(topic_url, "#", subtopic) 32 | 33 | # return the full url 34 | paste0(help_url, topic_url) 35 | }) 36 | } 37 | 38 | -------------------------------------------------------------------------------- /R/modules.R: -------------------------------------------------------------------------------- 1 | 2 | #' Main TensorFlow module 3 | #' 4 | #' Interface to main TensorFlow module. Provides access to top level classes 5 | #' and functions as well as sub-modules (e.g. \code{tf$nn}, 6 | #' \code{tf$math}, etc.). 7 | #' 8 | #' @format TensorFlow module 9 | #' 10 | #' @examples 11 | #' \dontrun{ 12 | #' library(tensorflow) 13 | #' 14 | #' hello <- tf$constant('Hello, TensorFlow!') 15 | #' zeros <- tf$Variable(tf$zeros(shape(1L))) 16 | #' 17 | #' tf$print(hello) 18 | #' tf$print(zeros) 19 | #' } 20 | #' @export 21 | tf <- NULL # placeholder; .onLoad() replaces it with the delay-loaded Python module 22 | -------------------------------------------------------------------------------- /R/package.R: -------------------------------------------------------------------------------- 1 | 2 | #' TensorFlow for R 3 | #' 4 | #' \href{https://www.tensorflow.org}{TensorFlow} is an open source software library 5 | #' for numerical computation using data flow graphs. Nodes in the graph 6 | #' represent mathematical operations, while the graph edges represent the 7 | #' multidimensional data arrays (tensors) communicated between them. The 8 | #' flexible architecture allows you to deploy computation to one or more CPUs or 9 | #' GPUs in a desktop, server, or mobile device with a single API.
10 | #' 11 | #' The \href{https://www.tensorflow.org/api_docs/python/tf/all_symbols}{TensorFlow 12 | #' API} is composed of a set of Python modules that enable constructing and 13 | #' executing TensorFlow graphs. The tensorflow package provides access to the 14 | #' complete TensorFlow API from within R. 15 | #' 16 | #' For additional documentation on the tensorflow package see 17 | #' \href{https://tensorflow.rstudio.com}{https://tensorflow.rstudio.com} 18 | #' 19 | #' @import reticulate 20 | #' 21 | #' @docType package 22 | #' @aliases tensorflow-package 23 | #' @name tensorflow 24 | "_PACKAGE" 25 | 26 | ## usethis namespace: start 27 | #' @importFrom lifecycle deprecated 28 | ## usethis namespace: end 29 | 30 | 31 | tf_v2 <- function() { # TRUE when the active TensorFlow version is 2.0 or later 32 | # 1.14 already shows deprecation warnings. 33 | package_version(tf_version()) >= "2.0" 34 | } 35 | 36 | # globals (kept in an environment, presumably so they stay mutable after the namespace is locked) 37 | .globals <- new.env(parent = emptyenv()) 38 | .globals$tensorboard <- NULL 39 | 40 | 41 | .onLoad <- function(libname, pkgname) { 42 | 43 | # if (is.na(Sys.getenv("TF_CPP_MIN_LOG_LEVEL", NA))) { 44 | # ## Doesn't seem to make a difference 45 | # Sys.setenv("TF_CPP_MIN_LOG_LEVEL" = "3") 46 | # } 47 | # 0 = all messages are logged (default behavior) 48 | # 1 = INFO messages are not printed 49 | # 2 = INFO and WARNING messages are not printed 50 | # 3 = INFO, WARNING, and ERROR messages are not printed 51 | 52 | 53 | # if TENSORFLOW_PYTHON is defined then forward it to RETICULATE_PYTHON 54 | tensorflow_python <- Sys.getenv("TENSORFLOW_PYTHON", unset = NA) 55 | if (!is.na(tensorflow_python)) 56 | Sys.setenv(RETICULATE_PYTHON = tensorflow_python) 57 | 58 | # honor the option that sets the TF C++ min log level, clamped to the valid range [0, 3] 59 | # NOTE(review): the clamp allows 3, which per the level table above also hides ERROR logs; confirm whether a lower cap (e.g. 
1, INFO only) was intended 60 | cpp_log_opt <- getOption("tensorflow.core.cpp_min_log_level") 61 | if (!is.null(cpp_log_opt)) 62 | Sys.setenv(TF_CPP_MIN_LOG_LEVEL = max(min(cpp_log_opt, 3), 0)) 63 | 64 | # register requirements with py_require() 65 | reqs <-
get_py_requirements() 66 | reticulate::py_require(reqs$packages, reqs$python_version) 67 | 68 | # delay load tensorflow 69 | tryCatch({ 70 | 71 | tf <<- import("tensorflow", delay_load = list( 72 | 73 | priority = 5, # keras sets priority = 10 74 | 75 | environment = c( 76 | "r-tensorflow", 77 | if (as.package_version(getNamespaceVersion("reticulate")) >= "1.36.0") # 'r-keras' is only added to the search list with reticulate >= 1.36.0 78 | "r-keras" 79 | ), 80 | 81 | # before_load = function() { 82 | # 83 | # }, 84 | 85 | on_load = function() { 86 | 87 | # register warning suppression handler 88 | register_suppress_warnings_handler(list( 89 | suppress = function() { 90 | if (tf_v2()) { 91 | tf_logger <- tf$get_logger() 92 | logging <- reticulate::import("logging") 93 | 94 | old_verbosity <- tf_logger$level 95 | tf_logger$setLevel(logging$ERROR) 96 | old_verbosity 97 | } else { 98 | old_verbosity <- tf$logging$get_verbosity() 99 | tf$logging$set_verbosity(tf$logging$ERROR) 100 | old_verbosity 101 | } 102 | }, 103 | restore = function(context) { 104 | if (tf_v2()) { 105 | tf_logger <- tf$get_logger() 106 | tf_logger$setLevel(context) 107 | } else { 108 | tf$logging$set_verbosity(context) 109 | } 110 | } 111 | )) 112 | 113 | # if we loaded tensorflow then register tf help handler 114 | register_tf_help_handler() 115 | 116 | # workaround to silence crash-causing deprecation warnings 117 | tryCatch(tf$python$util$deprecation$silence()$`__enter__`(), 118 | error = function(e) NULL) 119 | 120 | # TODO: move this into .onAttach, where you either emit immediately if 121 | # already loaded otherwise register emit hook for reticulate 122 | # emit <- get("packageStartupMessage") # R CMD check 123 | # emit("Loaded TensorFlow version ", tf$version$VERSION) 124 | } 125 | , 126 | 127 | on_error = function(e) { 128 | stop(tf_config_error_message(), call.
= FALSE) 129 | } 130 | )) 131 | }, 132 | python.builtin.ModuleNotFoundError = function(e) { 133 | warning(e$message, "\n", 134 | "Restart the R session and load the tensorflow R package before ", 135 | "reticulate has initialized Python, or ensure reticulate initialized ", 136 | "a Python installation where the tensorflow module is installed.", call. = FALSE) 137 | }) 138 | 139 | 140 | # provide a common base S3 class for tensors 141 | reticulate::register_class_filter(function(classes) { 142 | if (any(c("tensorflow.python.ops.variables.Variable", 143 | "tensorflow.python.types.core.Tensor", # 2.14 144 | "tensorflow.python.framework.tensor.Tensor", # 2.14 145 | "tensorflow.python.framework.ops.Tensor", 146 | "tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor", 147 | "tensorflow.python.framework.sparse_tensor.SparseTensor") 148 | %in% 149 | classes)) { 150 | c("tensorflow.tensor", classes) 151 | } else { 152 | classes 153 | } 154 | }) 155 | 156 | } 157 | 158 | 159 | is_string <- function(x) { # TRUE only for a length-1, non-NA character vector 160 | is.character(x) && length(x) == 1L && !is.na(x) 161 | } 162 | 163 | #' TensorFlow configuration information 164 | #' 165 | #' @return List with information on the current configuration of TensorFlow. 166 | #' You can determine whether TensorFlow was found using the `available` 167 | #' member (other members vary depending on whether `available` is `TRUE` 168 | #' or `FALSE`) 169 | #' 170 | #' @keywords internal 171 | #' @export 172 | tf_config <- function() { 173 | 174 | # first check if we found tensorflow 175 | have_tensorflow <- py_module_available("tensorflow") 176 | 177 | # get py config 178 | config <- py_config() 179 | 180 | # found it!
181 | if (have_tensorflow) { 182 | 183 | # get version 184 | if (reticulate::py_has_attr(tf, "version")) 185 | version_raw <- tf$version$VERSION 186 | else 187 | version_raw <- tf$VERSION 188 | 189 | tfv <- strsplit(version_raw, ".", fixed = TRUE)[[1]] 190 | version <- package_version(paste(tfv[[1]], tfv[[2]], sep = ".")) 191 | 192 | structure(class = "tensorflow_config", list( 193 | available = TRUE, 194 | version = version, 195 | version_str = version_raw, 196 | location = config$required_module_path, 197 | python = config$python, 198 | python_version = config$version 199 | )) 200 | 201 | # didn't find it 202 | } else { 203 | structure(class = "tensorflow_config", list( 204 | available = FALSE, 205 | python_versions = config$python_versions, 206 | error_message = tf_config_error_message() 207 | )) 208 | } 209 | } 210 | 211 | 212 | #' @rdname tf_config 213 | #' @keywords internal 214 | #' @export 215 | tf_version <- function() { 216 | config <- tf_config() 217 | if (config$available) 218 | config$version 219 | else 220 | NULL 221 | } 222 | 223 | #' @export 224 | print.tensorflow_config <- function(x, ...) { 225 | if (x$available) { 226 | aliased <- function(path) sub(Sys.getenv("HOME"), "~", path) 227 | cat("TensorFlow v", x$version_str, " (", aliased(x$location), ")\n", sep = "") 228 | cat("Python v", as.character(x$python_version), " (", aliased(x$python), ")\n", sep = "") 229 | } else { 230 | cat(x$error_message, "\n") 231 | } 232 | } 233 | 234 | #' TensorFlow GPU configuration information 235 | #' 236 | #' @return A bool, whether GPU is configured or not, or NA if could not be 237 | #' determined. 238 | #' 239 | #' @keywords internal 240 | #' @param verbose boolean. Whether to show extra GPU info. 241 | #' @export 242 | tf_gpu_configured <- function(verbose=TRUE) { 243 | res <- tryCatch({ 244 | tf$test$is_gpu_available() 245 | }, error = function(e) { 246 | warning("Can not determine if GPU is configured.", call. 
= FALSE); 247 | NA 248 | }) 249 | 250 | if (!is.na(verbose) && is.logical(verbose) &&verbose) { 251 | tryCatch({ 252 | cat(paste("TensorFlow built with CUDA: ", tf$test$is_built_with_cuda()), 253 | "\n"); 254 | cat(paste("GPU device name: ", tf$test$gpu_device_name(), 255 | collapse = "\n")) 256 | }, error = function(e) {}) 257 | } 258 | res 259 | } 260 | 261 | 262 | # Build error message for TensorFlow configuration errors 263 | tf_config_error_message <- function() { 264 | message <- "Valid installation of TensorFlow not found." 265 | config <- py_config() 266 | if (!is.null(config)) { 267 | if (length(config$python_versions) > 0) { 268 | message <- paste0(message, 269 | "\n\nPython environments searched for 'tensorflow' package:\n") 270 | python_versions <- paste0(" ", normalizePath(config$python_versions, mustWork = FALSE), 271 | collapse = "\n") 272 | message <- paste0(message, python_versions, sep = "\n") 273 | } 274 | } 275 | 276 | python_error <- tryCatch({ 277 | import("tensorflow") 278 | list(message = NULL) 279 | }, 280 | error = function(e) { 281 | on.exit(py_clear_last_error()) 282 | py_last_error() 283 | }) 284 | 285 | message <- paste0(message, 286 | "\nPython exception encountered:\n ", 287 | python_error$message, "\n") 288 | 289 | message <- paste0(message, 290 | "\nYou can install TensorFlow using the install_tensorflow() function.\n") 291 | message 292 | } 293 | -------------------------------------------------------------------------------- /R/probability.R: -------------------------------------------------------------------------------- 1 | 2 | #' TensorFlow Probability Module 3 | #' 4 | #' @return Reference to [TensorFlow Probability](https://www.tensorflow.org/probability) 5 | #' functions and classes 6 | #' 7 | #' @examples \dontrun{ 8 | #' library(tensorflow) 9 | #' ## one time setup: 10 | #' # reticulate::py_install("tensorflow_probability") 11 | #' tfp <- tf_probability() 12 | #' tfp$distributions$Normal(loc = 0, scale = 1) 13 | #' } 14 | 
#'
#' @export
tf_probability <- function() {
  module_name <- "tensorflow_probability"

  # TensorFlow itself has to be initialized before an add-on module is imported
  ensure_loaded()

  # hand back the Python module object
  reticulate::import(module_name)
}



--------------------------------------------------------------------------------
/R/reexports.R:
--------------------------------------------------------------------------------


# ---- re-exports from reticulate -------------------------------------------

#' @export
reticulate::import

#' @export
reticulate::dict

#' @export
reticulate::tuple

#' @export
reticulate::np_array

#' @export
reticulate::array_reshape

#' @export
reticulate::iterate

#' @export
reticulate::`%as%`

#' @export
reticulate::use_python

#' @export
reticulate::use_virtualenv

#' @export
reticulate::use_condaenv

# ---- re-exports from tfruns -----------------------------------------------

#' @importFrom tfruns flags flag_numeric flag_integer flag_string flag_boolean run_dir
#' @export
tfruns::flags

#' @export
tfruns::flag_numeric

#' @export
tfruns::flag_integer

#' @export
tfruns::flag_string

#' @export
tfruns::flag_boolean

#' @export
tfruns::run_dir



--------------------------------------------------------------------------------
/R/save.R:
--------------------------------------------------------------------------------
#' View a Saved Model
#'
#' View a serialized model from disk.
#'
#' @param model_dir The path to the exported model, as a string.
#'
#' @return URL for browsing TensorBoard (invisibly).
#'
#' @export
view_savedmodel <- function(
  model_dir
) {
  log_dir <- tempfile()

  # locate the (single) SavedModel protobuf under model_dir
  export_files <- dir(model_dir, full.names = TRUE, recursive = TRUE)
  export_pb <- export_files[grepl("\\.pb$", export_files)]
  if (length(export_pb) != 1) stop("Failed to find 'pb' file under ", model_dir)

  # From TF 1.14 the graph-mode APIs used here live under tf$compat$v1; TF 2.x
  # no longer exports tf$GraphDef, tf$import_graph_def or tf$summary$FileWriter
  # at the top level.
  v1 <- if (tf_version() >= "1.14") tf$compat$v1 else tf
  saved_model_pb2 <- import("tensorflow.core.protobuf.saved_model_pb2")

  # read and parse the SavedModel protobuf (file closed on exit)
  f <- v1$gfile$FastGFile(export_pb, "rb")
  on.exit(f$close(), add = TRUE)

  sm <- saved_model_pb2$SavedModel()
  sm$ParseFromString(tf$compat$as_bytes(f$read()))

  n_graphs <- sm$meta_graphs$`__len__`()
  if (n_graphs > 1) stop("Saved model contains more than one graph")
  if (n_graphs < 1) stop("Saved model contains no graph")

  # Import into a dedicated graph and write it out as a TensorBoard event file.
  # Both steps run with that graph as the default, i.e. outside eager
  # execution, which the v1 FileWriter requires under TF 2.x.
  graph <- v1$Graph()
  with(graph$as_default(), {
    v1$import_graph_def(sm$meta_graphs[[0]]$graph_def)
    train_writer <- v1$summary$FileWriter(log_dir)
    train_writer$add_graph(graph)
    train_writer$close()
  })

  tensorboard(log_dir = log_dir)
}

#' @export
export_savedmodel.tensorflow.python.client.session.Session <- function(
  object,
  export_dir_base,
  inputs,
  outputs,
  overwrite = TRUE,
  versioned = !overwrite,
  as_text = FALSE,
  ...) {

  # Versioned exports go to a timestamped subdirectory. Use whole seconds
  # ("%S", not "%OS", which honours options(digits.secs) and can append a
  # fractional part): model servers expect integer version directory names.
  if (versioned) {
    export_dir_base <- file.path(export_dir_base, format(Sys.time(), "%Y%m%d%H%M%S", tz = "GMT"))
  }

  sess <- object
  if (!overwrite && !versioned && dir.exists(export_dir_base))
    stop("Directory ", export_dir_base, " already exists.")

  # From TF 1.14 the SavedModel builder APIs live under
  # tf$compat$v1$saved_model. TF 2.x's tf$saved_model no longer exports
  # `utils`, `tag_constants` or the method-name constants, so everything
  # below goes through this one alias.
  saved_model <- if (tf_version() >= "1.14") tf$compat$v1$saved_model else tf$saved_model
  signature_constants <- saved_model$signature_constants

  tensor_inputs_info <- lapply(inputs, function(i) saved_model$utils$build_tensor_info(i))
  tensor_outputs_info <- lapply(outputs, function(o) saved_model$utils$build_tensor_info(o))

  prediction_signature <- saved_model$signature_def_utils$build_signature_def(
    inputs = tensor_inputs_info,
    outputs = tensor_outputs_info,
    method_name = signature_constants$PREDICT_METHOD_NAME)

  # register the prediction signature under the default serving key
  signature <- list()
  signature[[signature_constants$DEFAULT_SERVING_SIGNATURE_DEF_KEY]] <- prediction_signature

  if (overwrite && dir.exists(export_dir_base))
    unlink(export_dir_base, recursive = TRUE)

  builder <- saved_model$builder$SavedModelBuilder(export_dir_base)

  builder$add_meta_graph_and_variables(
    sess,
    list(
      saved_model$tag_constants$SERVING
    ),
    signature_def_map = signature
  )

  invisible(builder$save(as_text = as_text))
}

--------------------------------------------------------------------------------
/R/seed.R:
--------------------------------------------------------------------------------

#' Use a session with a random seed
#'
#' Set various random seeds required to ensure reproducible results. The
#' provided `seed` value will establish a new random seed for R, Python, NumPy,
#' and TensorFlow. GPU computations and CPU parallelism will also be disabled by
#' default.
#'
#'
#' @param seed A single value, interpreted as an integer
#' @param disable_gpu `TRUE` to disable GPU execution (see *Parallelism* below).
#' @param disable_parallel_cpu `TRUE` to disable CPU parallelism (see
#'   *Parallelism* below).
#' @param quiet `TRUE` to suppress printing of messages.
#'
#' @details This function must be called at the very top of your script (i.e.
#'   immediately after `library(tensorflow)`, `library(keras)`, etc.). Any
#'   existing TensorFlow session is torn down via `tf$reset_default_graph()`.
#'
#'   This function takes all measures known to promote reproducible results from
#'   TensorFlow sessions; however, it's possible that various individual
#'   TensorFlow features or dependent libraries escape its effects. If you
#'   encounter non-reproducible results please investigate the possible sources
#'   of the problem; contributions via pull request are very welcome!
#'
#' @section Parallelism: By default the `use_session_with_seed()` function
#'   disables GPU and CPU parallelism, since both can result in
#'   non-deterministic execution patterns (see
#'   <https://stackoverflow.com/questions/42022950/>). You can optionally enable
#'   GPU or CPU parallelism by setting the `disable_gpu` and/or
#'   `disable_parallel_cpu` parameters to `FALSE`.
#'
#' @return TensorFlow session object, invisibly
#'
#' @details Packages which need to be notified before and after the seed is set
#'   can register for the "tensorflow.on_before_use_session" and
#'   "tensorflow.on_use_session" hooks (see [setHook()] for additional
#'   details on hooks).
#'
#' @examples
#' \dontrun{
#' library(tensorflow)
#' use_session_with_seed(42)
#' }
#'
#' @export
use_session_with_seed <- function(seed,
                                  disable_gpu = TRUE,
                                  disable_parallel_cpu = TRUE,
                                  quiet = FALSE) {


  msg <- "use_session_with_seed will be deprecated in the future. use tensorflow::set_random_seed instead."

  # Sessions are a TF1 concept: refuse outright on TF >= 2.3 and route through
  # tf$compat$v1 (with a deprecation warning) on TF 2.0 - 2.2. tf_version() is
  # NULL when TensorFlow is not installed; fall through in that case so that
  # front-end packages relying on the hooks below still get R/Python seeding.
  tf_ver <- tf_version()
  if (!is.null(tf_ver)) {
    if (tf_ver >= "2.3")
      stop("use_session_with_seed() is not supported with TensorFlow >= 2.3; ",
           "use tensorflow::set_random_seed() instead.", call. = FALSE)
    if (tf_ver >= "2.0") {
      tf <- tf$compat$v1
      warning(msg)
    }
  }

  # cast seed to integer
  seed <- as.integer(seed)

  # call hook (returns TRUE if TF seed should be set, this allows users to
  # call this function even when using front-end packages like keras that
  # may not use TF as their backend)
  using_tf <- call_hook("tensorflow.on_before_use_session", quiet)

  # destroy existing session call before hook
  if (using_tf)
    tf$reset_default_graph()

  # note what has been disabled
  disabled <- character()

  # disable CUDA if requested
  if (disable_gpu) {
    Sys.setenv(CUDA_VISIBLE_DEVICES = "-1")
    disabled <- c(disabled, "GPU")
  }

  # set R random seed
  set.seed(seed)

  # set Python/NumPy random seed
  py_set_seed(seed)

  # TF if we are using tf
  if (using_tf) {
    # Force TensorFlow to use single thread as multiple threads are a potential
    # source of non-reproducible results. For further details, see:
    # https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res

    # disable parallelism if requested
    config <- list()
    if (disable_gpu) {
      # ConfigProto's device_count is keyed by device type name, which
      # TensorFlow spells "GPU"; a lower-case "gpu" key is never matched
      config$device_count <- list(GPU = 0L)
    }
    if (disable_parallel_cpu) {
      config$intra_op_parallelism_threads <- 1L
      config$inter_op_parallelism_threads <- 1L
      disabled <- c(disabled, "CPU parallelism")
    }
    session_conf <- do.call(tf$ConfigProto, config)

    # The below tf$set_random_seed() will make random number generation in the
    # TensorFlow backend have a well-defined initial state. For further details,
    # see: https://www.tensorflow.org/api_docs/python/tf/set_random_seed
    tf$set_random_seed(seed)

    # create session
    sess <- tf$Session(graph = tf$get_default_graph(), config = session_conf)

  } else {
    sess <- NULL
  }

  # show message
  msg <- paste("Set session seed to", seed)
  if (length(disabled) > 0)
    msg <- paste0(msg, " (disabled ", paste(disabled, collapse = ", "), ")")
  if (!quiet)
    message(msg)

  # call after hook
  call_hook("tensorflow.on_use_session", sess, quiet)

  # return session invisibly
  invisible(sess)
}

#' Set random seed for TensorFlow
#'
#' Sets all random seeds needed to make TensorFlow code reproducible.
#'
#' @details
#'
#' This function should be used instead of [use_session_with_seed()] if
#' you are using TensorFlow >= 2.0, as the concept of `session` doesn't
#' really make sense anymore.
#'
#' This function sets:
#'
#' - The R random seed with [set.seed()].
#' - The Python and NumPy seeds via [reticulate::py_set_seed()].
#' - The TensorFlow seed with `tf$random$set_seed()`.
#'
#' It also optionally disables the GPU execution as this is a potential
#' source of non-reproducibility.
#'
#' @param seed A single value, interpreted as an integer
#' @param disable_gpu `TRUE` to disable GPU execution by setting the
#'   `CUDA_VISIBLE_DEVICES` environment variable to `"-1"`. Environment-based
#'   GPU masking generally has no effect once TensorFlow has already
#'   initialized its GPU devices in the current session.
#'
#' @export
set_random_seed <- function(seed, disable_gpu = TRUE) {

  tf_ver <- tf_version()
  if (is.null(tf_ver))
    stop("TensorFlow not installed, please run `tensorflow::install_tensorflow()`")
  if (tf_ver < "2.0")
    stop("set_random_seed() only works for TF >= 2.0")

  # cast seed to integer and validate it up front, rather than failing later
  # inside set.seed() with a less specific message
  seed <- as.integer(seed)
  if (length(seed) != 1L || is.na(seed))
    stop("`seed` must be a single value coercible to an integer")

  # mask GPUs before any TensorFlow RNG call below; CUDA reads
  # CUDA_VISIBLE_DEVICES when it initializes
  if (disable_gpu) {
    Sys.setenv(CUDA_VISIBLE_DEVICES = "-1")
  }

  # set R random seed
  set.seed(seed)

  # set Python/NumPy random seed
  py_set_seed(seed)

  # set tensorflow random seed
  tf$random$set_seed(seed)

  invisible(NULL)
}

--------------------------------------------------------------------------------
/R/shape.R:
--------------------------------------------------------------------------------
#' Create a `tf.TensorShape` object
#'
#' @param ... Tensor dimensions as integers or `NULL` for an unknown
#'   dimension. `NA` and `-1` are synonyms for `NULL`.
#' @param dims Tensor dimensions as a vector.
#'
#' @seealso <https://www.tensorflow.org/api_docs/python/tf/TensorShape>
#'
#' @export
#' @examples
#' \dontrun{
#'
#' # --- construct ---
#' shape() # tf.TensorShape() # scalar
#' shape(NULL) # tf.TensorShape([None]) # 1-D array of unknown length
#' shape(NA) # tf.TensorShape([None]) # 1-D array of unknown length, NA is a synonym for NULL
#'
#' shape(dims = NULL) # TensorShape(None) # Unknown rank, unknown size
#' shape(3, 4) # TensorShape([3, 4]) # 2-D array (matrix) with 3 rows, 4 columns
#' shape(NA, 4) # TensorShape([None, 4]) # 2-D array (matrix) with unknown rows, 4 columns
#' shape(dims = c(NA, 4)) # TensorShape([None, 4]) # same as above; bypass ... and pass dims directly
#'
#' # --- inspect ---
#' length(shape(dims = NULL)) # NA_integer_
#' length(shape(1,2,3,NA)) # 4L
#'
#' # --- convert ---
#' x <- shape(dims = list(3L, 5L))
#' as.list(x) # list(3L, 5L)
#' as.integer(x) # c(3L, 5L)
#' as.numeric(x) # c(3, 5)
#' as.double(x) # c(3, 5) # alias for as.numeric
#' as_tensor(x) # tf.Tensor([3 5], shape=(2,), dtype=int32)
#'
#' # convert partially undefined shapes
#' x <- shape(NA, 3)
#' as.list(x) # list(NULL, 3L)
#' as.integer(x) # c(NA, 3L)
#' as_tensor(x) # tf.Tensor([-1 3], shape=(2,), dtype=int32) # unspecified dims default is -1
#'
#' # as_tensor() converts undefined dimensions to -1, which is useful for
#' # tf functions that only accept tensors for shapes, e.g.,
#' tf$reshape(tf$zeros(shape(8)),
#'            as_tensor(shape(NA, 4)))
#' # tf.Tensor([[0. 0. 0. 0.]
#' #            [0. 0. 0. 0.]], shape=(2, 4), dtype=float32)
#'
#' # converting fully unknown shapes raises an error
#' try(as.list(shape(dims = NULL))) # ValueError: as_list() is not defined on an unknown TensorShape.
#' # test for rank first if this is a concern:
#' as.list_or_null <- function(x) if(is.na(length(x))) NULL else as.list(x)
#' as.list_or_null(shape(dims = NULL))
#'
#'
#' # --- compare ---
#' # Fully known shapes return TRUE if and only if each element is equal
#' shape(3, 4) == shape(3, 4) # TRUE
#' shape(3, 4) == shape(4, 4) # FALSE
#'
#' # two unknown dimensions are treated as equal
#' shape(NA, 4) == shape(NA, 4) # TRUE
#' shape(NA, 4) == shape(3, 4) # FALSE
#'
#' # Two unknown shapes, return TRUE
#' shape(dims = NULL) == shape(dims = NULL) # TRUE
#'
#' # Comparing an unknown shape to a partially or fully defined shape returns FALSE
#' shape(dims = NULL) == shape(NULL) # FALSE
#' shape(dims = NULL) == shape(4) # FALSE
#'
#'
#' # values of length greater than one supplied to `...` are automatically flattened
#' shape(1, c(2, 3), 4) # shape(1, 2, 3, 4)
#' shape(1, shape(2, 3), 4) # shape(1, 2, 3, 4)
#' shape(1, as_tensor(c(2, 3)), 4) # shape(1, 2, 3, 4)
#' shape(1, c(NA, 3), 4) # shape(1, NA, 3, 4) # NA / -1 inside vectors mean unknown too
#'
#' # --- extract or replace ---
#' # regular R-list semantics for `[`, `[[`, `[<-`, `[[<-`
#' x <- shape(1, 2, 3)
#' x[1] # TensorShape([1])
#' x[[1]] # 1L
#' x[2:3] # TensorShape([2, 3])
#' x[-1] # TensorShape([2, 3])
#'
#' x[1] <- 11 ; x # TensorShape([11, 2, 3])
#' x[1] <- shape(11) ; x # TensorShape([11, 2, 3])
#' x[1] <- list(11) ; x # TensorShape([11, 2, 3])
#'
#' x[[1]] <- 22 ; x # TensorShape([22, 2, 3])
#' x[1:2] <- c(NA, 99) ; x # TensorShape([None, 99, 3])
#' x[1:2] <- shape(33, 44) ; x # TensorShape([33, 44, 3])
#'
#' # --- concatenate ---
#' c(shape(1), shape(2, 3), shape(4, NA)) # TensorShape([1, 2, 3, 4, None])
#'
#' # --- merge ---
#' merge(shape(NA, 2),
#'       shape(1 , 2)) # TensorShape([1, 2])
#'
#' try(merge(shape(2, 2),
#'           shape(1, 2))) # ValueError: Shapes (2, 2) and (1, 2) are not compatible
#'
#' rm(x) # cleanup
#' }
shape <- function(..., dims = list(...)) {
  if (is.null(dims))
    return(tf$TensorShape(NULL))

  if (inherits(dims, "tensorflow.tensor") && tf$executing_eagerly())
    dims <- as_r_value(dims$numpy())

  names(dims) <- NULL

  # Turn every supplied value into a list of per-dimension entries. Multi-element
  # values (vectors, shapes, tensors) are thus flattened, and *each* element
  # gets the NA / -1 -> NULL ("unknown dimension") mapping, not only scalars.
  dims <- lapply(dims, function(d) {
    d <- as_r_value(d)
    if (is.null(d))
      return(list(NULL))
    lapply(as.integer(d), function(e) if (is.na(e) || e == -1L) NULL else e)
  })
  dims <- unlist(dims, recursive = FALSE, use.names = FALSE)

  # unlist() of an empty list is NULL, which would mean "unknown rank";
  # no dimensions at all is a scalar shape
  if (is.null(dims))
    dims <- list()

  tf$TensorShape(dims)
}


# Coerce `x` to a TensorShape: TensorShape objects pass through unchanged,
# anything else goes through shape(dims = x).
as_shape <- function(x) {
  if(inherits(x, "tensorflow.python.framework.tensor_shape.TensorShape"))
    x
  else
    shape(dims = x)
}

# Convert a Python object to its R equivalent; R values are returned as-is.
as_r_value <- function (x) {
  if (inherits(x, "python.builtin.object"))
    py_to_r(x)
  else
    x
}



#' @export
as.list.tensorflow.python.framework.tensor_shape.TensorShape <- function(x, ...) {
  as.list(as_r_value(x$as_list())) # raises an exception for unknown rank
}

#' @export
#' @method as.integer tensorflow.python.framework.tensor_shape.TensorShape
as.integer.tensorflow.python.framework.tensor_shape.TensorShape <- function(x, ...) {
  # unknown dimensions (None, i.e. NULL in R) become NA_integer_
  vapply(as.list(as_r_value(x$as_list())),
         function(e) e %||% NA_integer_,
         0L)
}

#' @export
#' @method as.numeric tensorflow.python.framework.tensor_shape.TensorShape
as.numeric.tensorflow.python.framework.tensor_shape.TensorShape <- function(x, ...)
  as.numeric(as.integer.tensorflow.python.framework.tensor_shape.TensorShape(x), ...)
#' @export
#' @method as.double tensorflow.python.framework.tensor_shape.TensorShape
as.double.tensorflow.python.framework.tensor_shape.TensorShape <-
  as.numeric.tensorflow.python.framework.tensor_shape.TensorShape

#' @export
as_tensor.tensorflow.python.framework.tensor_shape.TensorShape <-
  function(x, dtype = NULL, ..., name = NULL) {
    if(x$is_fully_defined())
      return(NextMethod())

    # encode unknown dimensions as -1, the convention tf ops such as
    # tf$reshape accept for "infer this dimension"
    x <- as.integer.tensorflow.python.framework.tensor_shape.TensorShape(x)
    x[is.na(x)] <- -1L
    as_tensor.default(x, dtype, ..., name = name)
  }


#' @export
`[.tensorflow.python.framework.tensor_shape.TensorShape` <- function(x, i) {
  x <- as.list.tensorflow.python.framework.tensor_shape.TensorShape(x)
  as_shape(x[i])
}

#' @export
`[[.tensorflow.python.framework.tensor_shape.TensorShape` <- function(x, i) {
  x <- as.list.tensorflow.python.framework.tensor_shape.TensorShape(x)
  x[[i]]
}

#' @export
`[<-.tensorflow.python.framework.tensor_shape.TensorShape` <- function(x, ..., value) {
  x <- as.list.tensorflow.python.framework.tensor_shape.TensorShape(x)
  # As in shape(), NULL means "unknown dimension". as.list(NULL) would be
  # list() and fail as a zero-length replacement.
  if (is.null(value))
    value <- list(NULL)
  x[...] <- as.list(value)
  shape(dims = x)
}

#' @export
`[[<-.tensorflow.python.framework.tensor_shape.TensorShape` <- function(x, ..., value) {
  x <- as.list.tensorflow.python.framework.tensor_shape.TensorShape(x)
  if (is.null(value)) {
    # `x[[i]] <- NULL` would *delete* the dimension; keep the slot and mark
    # it unknown instead
    x[...] <- list(NULL)
  } else {
    x[[...]] <- value
  }
  shape(dims = x)
}


#' @export
`c.tensorflow.python.framework.tensor_shape.TensorShape` <- function(...) {
  # dispatch guarantees ..1 is a TensorShape; the remaining arguments are
  # coerced and appended in order
  x <- ..1
  for(other in list(...)[-1])
    x <- x$concatenate(as_shape(other))
  x
}

# `c.tensorflow.python.framework.tensor_shape.TensorShape` <- function(...) {
#   x <- do.call(c, lapply(unname(list(...)), as.list))
#   shape(dims = x)
# }


#' @export
merge.tensorflow.python.framework.tensor_shape.TensorShape <- function(x, y, ...)
  x$merge_with(as_shape(y))


#' @export
length.tensorflow.python.framework.tensor_shape.TensorShape <- function(x) {
  # x$rank returns NULL on tensor of unknown rank
  # x$`__len__`()) raises ValueError on tensor of unknown rank
  # dim(tensor) returns NULL on tensor of unknown rank (for reference)
  as_r_value(x$rank) %||% NA_integer_
}

#' @export
format.tensorflow.python.framework.tensor_shape.TensorShape <-
  function(x, ...) {
    if (identical(as_r_value(x$rank), NULL))
      "()"
    else
      sprintf("(%s)", paste0(as.integer(x), collapse = ", "))
  }

#' @export
print.tensorflow.python.framework.tensor_shape.TensorShape <-
  function(x, ...) {
    writeLines(import_builtins()$repr(x))
    invisible(x)
  }


#' @export
py_str.tensorflow.python.framework.tensor_shape.TensorShape <-
  function(object, ...) as_r_value(object$`__repr__`())

## reticulate already dispatches to __eq__, but we need to do
## additional coercion on a and b
#' @export
`==.tensorflow.python.framework.tensor_shape.TensorShape` <- function(a, b) {
  a <- as_shape(a)
  b <- as_shape(b)
  as_r_value(a$`__eq__`(b))
}

# != is not defined as the negation of == in python, tricky!
#' @export
`!=.tensorflow.python.framework.tensor_shape.TensorShape` <- function(a, b) {
  a <- as_shape(a)
  b <- as_shape(b)
  as_r_value(a$`__ne__`(b))
}




# old shape def, retained in namespace in case it's needed for easy back compat
shape_v1 <- function(...) {
  values <- list(...)
  lapply(values, function(value) {
    if (!is.null(value))
      as.integer(value)
    else
      NULL
  })
}

--------------------------------------------------------------------------------
/R/tensorboard.R:
--------------------------------------------------------------------------------
#' TensorBoard Visualization Tool
#'
#' TensorBoard is a tool for inspecting and understanding your TensorFlow runs
#' and graphs.
#'
#' @param log_dir Directories to scan for training logs. If this is a named
#'   character vector then the specified names will be used as aliases within
#'   TensorBoard.
#' @param action Specify whether to start or stop TensorBoard (TensorBoard will
#'   be stopped automatically when the R session from which it is launched is
#'   terminated).
#' @param host Host for serving TensorBoard
#' @param port Port for serving TensorBoard. If "auto" is specified (the
#'   default) then an unused port will be chosen automatically.
#' @param launch_browser Open a web browser for TensorBoard after launching.
#'   Defaults to `TRUE` in interactive sessions. When running under RStudio uses
#'   an RStudio window by default (pass a function e.g. [utils::browseURL()] to
#'   open in an external browser). Use the `tensorflow.tensorboard.browser`
#'   option to establish a global default behavior.
#' @param reload_interval How often (in seconds) the backend should load more
#'   data.
#' @param purge_orphaned_data Whether to purge data that may have been orphaned
#'   due to TensorBoard restarts. Disabling purge_orphaned_data can be used to
#'   debug data disappearance.
#'
#' @return URL for browsing TensorBoard (invisibly).
#'
#' @details When TensorBoard is passed a logdir at startup, it recursively walks
#' the directory tree rooted at logdir looking for subdirectories that contain
#' tfevents data.
Every time it encounters such a subdirectory, it loads it as
30 | #' a new run, and the frontend will organize the data accordingly.
31 | #'
32 | #' The TensorBoard process will be automatically destroyed when the R session
33 | #' in which it is launched exits. You can pass `action = "stop"` to manually
34 | #' terminate TensorBoard.
35 | #'
36 | #' @export
37 | tensorboard <- function(log_dir, action = c("start", "stop"),
38 |                         host = "127.0.0.1", port = "auto",
39 |                         launch_browser = getOption("tensorflow.tensorboard.browser",
40 |                                                    interactive()),
41 |                         reload_interval = 5,
42 |                         purge_orphaned_data = TRUE
43 |                         ) {
44 |
45 |   # validate action up front so that "stop" can skip all log_dir handling
46 |   action <- match.arg(action)
47 |
48 |   # ensure that tensorflow initializes (so we get tensorboard on our path)
49 |   ensure_loaded()
50 |
51 |   # verify we can find tensorboard
52 |   if (!nzchar(Sys.which("tensorboard")))
53 |     stop("Unable to find tensorboard on PATH")
54 |
55 |   # log_dir is only resolved (and created) when starting; stopping needs none
56 |   if (identical(action, "start")) {
57 |     # if log_dir is missing try to find a "latest run"
58 |     if (missing(log_dir)) {
59 |       latest <- tfruns::latest_run()
60 |       if (!is.null(latest))
61 |         log_dir <- latest$run_dir
62 |       else
63 |         stop("A log_dir must be specified for tensorboard")
64 |     }
65 |     # convert input to run_dir and expand the path
66 |     log_dir <- path.expand(tfruns::as_run_dir(log_dir))
67 |
68 |     # create log_dir(s) if necessary (a loop rather than as.character(lapply()),
69 |     # which would drop any names that serve as TensorBoard aliases)
70 |     for (dir in log_dir) {
71 |       if (!utils::file_test("-d", dir))
72 |         dir.create(dir, recursive = TRUE)
73 |     }
74 |   }
75 |
76 |   # if we already have a tensorboard for this session then kill it and re-use its port
77 |   if (!is.null(.globals$tensorboard)) {
78 |     p <- .globals$tensorboard$process
79 |     if (p$is_alive()) {
80 |       p$kill()
81 |       p$wait(1000L)
82 |     }
83 |     if (identical(port, "auto"))
84 |       port <- .globals$tensorboard$port
85 |     .globals$tensorboard <- NULL
86 |   }
87 |
88 |   # exit if this was action = "stop"
89 |   if (identical(action, "stop")) {
90 |     cat("TensorBoard stopped.\n")
91 |     return(invisible(NULL))
92 |   }
93 |
94 |   # for port = "auto", attempt to find a port up to 20 times
95 |   if (identical(port, "auto")) {
96 |
97 |     for (i in 1:20) {
98 |
99 |       # determine the port (exclude those considered unsafe by Chrome)
100 |       while(TRUE) {
101 |         port <- 3000 + sample(5000, 1)
102 |         if (!port %in% c(3659, 4045, 6000, 6665:6669))
103 |           break
104 |       }
105 |
106 |       # attempt to launch
107 |       p <- launch_tensorboard(log_dir, host, port, FALSE, reload_interval, purge_orphaned_data)
108 |       if (p$is_alive())
109 |         break
110 |     }
111 |
112 |   } else {
113 |     p <- launch_tensorboard(log_dir, host, port, TRUE, reload_interval, purge_orphaned_data)
114 |   }
115 |
116 |   if (p$is_alive()) {
117 |
118 |     # close connections
119 |     close(p$get_output_connection())
120 |     close(p$get_error_connection())
121 |
122 |     # save as global tensorboard
123 |     .globals$tensorboard <- list(process = p, port = port)
124 |
125 |     # browse the url if requested
126 |     url <- paste0("http://", host, ":", port)
127 |     cat("Started TensorBoard at", url, "\n")
128 |     if (isTRUE(launch_browser)) {
129 |       utils::browseURL(url)  # handles a function or a command-string "browser" option
130 |     } else if (is.function(launch_browser)) {
131 |       launch_browser(url)
132 |     }
133 |
134 |     # return the url invisibly
135 |     invisible(url)
136 |
137 |   } else {
138 |     stop("Unable to launch tensorboard")
139 |   }
140 | }
141 |
142 | # Installed TensorBoard version (queried once via `tensorboard --version_tb`, then cached in .globals).
143 | tensorboard_version <- function() {
144 |   if (is.null(ver <- .globals$tensorboard_version)) {
145 |     ver <- package_version(system("tensorboard --version_tb", intern = TRUE, ignore.stderr = TRUE))
146 |     .globals$tensorboard_version <- ver
147 |   }
148 |   ver
149 | }
150 |
151 | # Start a tensorboard process and wait until its HTTP server responds (or the process exits); returns the processx process.
152 | launch_tensorboard <- function(log_dir, host, port, explicit_port, reload_interval, purge_orphaned_data) {
153 |
154 |   if (tensorboard_version() < "2.0") {
155 |     # check for names (aliases), defaulting any missing ones to the
156 |     # directory's basename (an unnamed entry would otherwise become ":dir")
157 |     names <- names(log_dir) %||% rep("", length(log_dir))
158 |     unnamed <- is.na(names) | !nzchar(names)
159 |     names[unnamed] <- basename(log_dir)[unnamed]
160 |
161 |     # concatenate names with their directories (name:dir)
162 |     log_dir <- paste0(names, ":", log_dir)
163 |
164 |     # build log_dir
165 |     log_dir <- paste(log_dir, collapse = ",")
166 |   }
167 |
168 |   # start the process
169 |   p <- processx::process$new("tensorboard",
170 |                              c("--logdir", log_dir,
171 |                                "--host", host,
172 |                                "--port", as.character(port),
173 |                                "--reload_interval", as.integer(reload_interval),
174 |                                "--purge_orphaned_data", purge_orphaned_data),
175 |                              stdout = "|", stderr = "|")
176 |
177 |   # poll for availability of the http server (continue as long as the
178 |   # process is still alive). note that we used to poll for stdout however
179 |   # tensorflow v1.3 stopped writing a newline after printing the host:port
180 |   # and caused us to hang in p$read_output_lines()
181 |   started <- FALSE
182 |   Sys.sleep(0.25)
183 |   conn <- url(paste0("http://", host, ":", as.character(port)))
184 |   on.exit(close(conn), add = TRUE)
185 |   while(!started && p$is_alive()) {
186 |     Sys.sleep(0.25)
187 |     tryCatch({
188 |       suppressWarnings(readLines(conn, n = -1))
189 |       started <- TRUE
190 |     },
191 |     error = function(e) {}
192 |     )
193 |   }
194 |
195 |   # poll for error messages
196 |   res <- p$poll_io(100L)
197 |
198 |   # see if we have stderr
199 |   if (identical(res[["error"]], "ready")) {
200 |
201 |     # capture error output
202 |     err <- p$read_error_lines()
203 |
204 |     # write it unless it's a port in use error when we are auto-binding
205 |     if (explicit_port || !any(grepl(paste0("^.*", port, ".*already in use.*$"), err)))
206 |       write(err, stderr())
207 |   }
208 |
209 |   # return the process
210 |   p
211 | }
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
--------------------------------------------------------------------------------
/R/utils.R:
--------------------------------------------------------------------------------
1 | "%||%" <- function(x, y) if (is.null(x)) y else x
2 |
3 | is_windows <- function() {
4 |   identical(.Platform$OS.type, "windows")
5 | }
6 |
7 | is_unix <- function() {
8 |   identical(.Platform$OS.type, "unix")
9 | }
10 |
11 | is_osx <- function() {
12 |   identical(Sys.info()[["sysname"]], "Darwin")
13 | }
14 |
15 | is_linux <- function() {
16 |   identical(tolower(Sys.info()[["sysname"]]), "linux")
17 | }
18 |
19 | is_ubuntu <- function() {
20 |   # check /etc/lsb-release
21 |   if (is_unix() && file.exists("/etc/lsb-release")) {
22 |     lsbRelease <- readLines("/etc/lsb-release")
23 |     any(grepl("Ubuntu", lsbRelease))
24 |   } else {
25 |     FALSE
26 |   }
27 | }
28 |
29 |
30 | is_mac_arm64 <- function() {
31 |   sys_info <- Sys.info()
32 |   sys_info[["sysname"]] == "Darwin" &&
33 |     sys_info[["machine"]] == "arm64"
34 | }
35 |
36 | dir_exists <- function(x) {
37 |   utils::file_test('-d', x)
38 | }
39 |
40 | ensure_loaded <- function() {
41 |   invisible(tf$`__version__`)
42 | }
43 |
44 | # Abbreviate a leading HOME directory in `path` to "~" (e.g. for display).
45 | aliased <- function(path) {
46 |   if (!is.character(path))
47 |     path <- as.character(path)
48 |   home <- Sys.getenv("HOME")
49 |   # with HOME unset, sub("", "~", path) would prepend "~" to every path
50 |   if (!nzchar(home))
51 |     return(path)
52 |   # match HOME literally (it may contain regex metacharacters such as "."
53 |   # or "\\") and only as a whole leading path component
54 |   rest <- substring(path, nchar(home) + 1L)
55 |   at_home <- !is.na(path) & startsWith(path, home) &
56 |     (!nzchar(rest) | substr(rest, 1L, 1L) %in% c("/", "\\"))
57 |   path[at_home] <- paste0("~", rest[at_home])
58 |   path
59 | }
60 |
61 |
62 | # Call every hook registered under `name` (as returned by getHook()) with `...`;
63 | # returns TRUE if any hook returned TRUE. All hooks run (no short-circuit).
64 | call_hook <- function(name, ...) {
65 |   hooks <- getHook(name)
66 |   if (!is.list(hooks))
67 |     hooks <- list(hooks)
68 |   results <- vapply(hooks, function(hook) isTRUE(hook(...)), logical(1))
69 |   any(results)
70 | }
71 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## TensorFlow for R
3 | [![R build status](https://github.com/rstudio/tensorflow/workflows/R-CMD-check/badge.svg)](https://github.com/rstudio/tensorflow/actions?workflow=R-CMD-check) [![CRAN\_Status\_Badge](https://www.r-pkg.org/badges/version/tensorflow)](https://cran.r-project.org/package=tensorflow)
4 |
5 | [TensorFlow™](https://www.tensorflow.org) is an open source software library for numerical computation using data flow graphs. Nodes in the graph represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) communicated between them.
The flexible architecture allows you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device with a single API. 6 | 7 | The [TensorFlow API](https://www.tensorflow.org/api_docs/python/tf/all_symbols) is composed of a set of Python modules that enable constructing and executing TensorFlow graphs. The tensorflow package provides access to the complete TensorFlow API from within R. 8 | 9 | ## Installation 10 | 11 | To get started, install the tensorflow R package from GitHub as follows: 12 | 13 | ```r 14 | devtools::install_github("rstudio/tensorflow") 15 | ``` 16 | 17 | Then, use the `install_tensorflow()` function to install TensorFlow: 18 | 19 | ```r 20 | library(tensorflow) 21 | install_tensorflow() 22 | ``` 23 | 24 | You can confirm that the installation succeeded with: 25 | 26 | ```r 27 | hello <- tf$constant("Hello") 28 | print(hello) 29 | ``` 30 | 31 | This will provide you with a default installation of TensorFlow suitable for getting started with the tensorflow R package. See the [article on installation](https://tensorflow.rstudio.com/install/) to learn about more advanced options, including installing a version of TensorFlow that takes advantage of Nvidia GPUs if you have the correct CUDA libraries installed. 32 | 33 | ## Documentation 34 | 35 | See the package website for additional details on using the TensorFlow API from R: 36 | 37 | See the TensorFlow API reference for details on all of the modules, classes, and functions within the API: 38 | 39 | The tensorflow package provides code completion and inline help for the TensorFlow API when running within the RStudio IDE. In order to take advantage of these features you should also install the [Current Release](https://posit.co/download/rstudio-desktop/) of RStudio. 
40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | New release, bugfixes and updates. 2 | 3 | Details in NEWS.md 4 | 5 | ## revdepcheck results 6 | 7 | We checked 64 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 8 | 9 | * We saw 0 new problems 10 | * We failed to check 0 packages 11 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | HTML Meta Tag 5 | 6 | 7 | 8 |

The TensorFlow for R website has been moved to here.

9 | 10 | 11 | -------------------------------------------------------------------------------- /man/all_dims.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/extract.R 3 | \name{all_dims} 4 | \alias{all_dims} 5 | \title{All dims} 6 | \usage{ 7 | all_dims() 8 | } 9 | \description{ 10 | This function returns an object that can be used when subsetting tensors with 11 | \code{[}. If you are familiar with Python, this is equivalent to the Python Ellipsis 12 | \code{...} (not to be confused with \code{...} in \code{R}). 13 | } 14 | \examples{ 15 | \dontrun{ 16 | # in python, if x is a numpy array or tensorflow tensor 17 | x[..., i] 18 | # the ellipsis means "expand to match number of dimensions of x". 19 | # to translate the above python expression to R, write: 20 | x[all_dims(), i] 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /man/as_tensor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/generics.R 3 | \name{as_tensor} 4 | \alias{as_tensor} 5 | \alias{as_tensor.default} 6 | \alias{as_tensor.double} 7 | \title{as_tensor} 8 | \usage{ 9 | as_tensor(x, dtype = NULL, ..., name = NULL) 10 | 11 | \method{as_tensor}{default}(x, dtype = NULL, ..., shape = NULL, name = NULL) 12 | 13 | \method{as_tensor}{double}(x, dtype = NULL, ..., name = NULL) 14 | } 15 | \arguments{ 16 | \item{x}{object to convert} 17 | 18 | \item{dtype}{\code{NULL}, a tensorflow dtype (\code{tf$int32}), or something coercible 19 | to one (e.g. a string \code{"int32"})} 20 | 21 | \item{..., }{ignored} 22 | 23 | \item{name}{\code{NULL} or a string. Useful for debugging in graph mode, ignored 24 | while in eager mode.} 25 | 26 | \item{shape}{an integer vector, tensor, or \code{tf.TensorShape}.
Can contain up 27 | to 1 unspecified dimension, encoded as a \code{-1} or \code{NA}. This will reshape 28 | \code{x} using row-major (C-style) semantics. It will prefer reshaping using 29 | non-graph operations if possible, but will otherwise invoke \code{tf$reshape()}. 30 | If \code{x} is a scalar and the requested \code{shape} is fully defined or a tensor, 31 | the value of \code{x} will be recycled to fill a tensor of the requested shape 32 | (it will dispatch to \code{tf$fill()}).} 33 | } 34 | \value{ 35 | a tensorflow tensor 36 | } 37 | \description{ 38 | Coerce objects to tensorflow tensors (potentially of a specific dtype or shape). The 39 | provided default methods will call 40 | \href{https://www.tensorflow.org/api_docs/python/tf/convert_to_tensor}{\code{tf$convert_to_tensor}}. Depending on arguments supplied it may also call some combination of 41 | \itemize{ 42 | \item \href{https://www.tensorflow.org/api_docs/python/tf/dtypes/saturate_cast}{\code{tf$saturate_cast}} or 43 | \href{https://www.tensorflow.org/api_docs/python/tf/cast}{\code{tf$cast}} 44 | \item \href{https://www.tensorflow.org/api_docs/python/tf/fill}{\code{tf$fill}} or 45 | \href{https://www.tensorflow.org/api_docs/python/tf/reshape}{\code{tf$reshape}} 46 | } 47 | } 48 | \examples{ 49 | \dontrun{ 50 | as_tensor(42, "int32") 51 | as_tensor(as_tensor(42)) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /man/evaluate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/estimator-generics.R 3 | \name{evaluate} 4 | \alias{evaluate} 5 | \title{Evaluate a Model} 6 | \usage{ 7 | evaluate(object, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An evaluatable \R object.} 11 | 12 | \item{...}{Optional arguments passed on to implementing methods.} 13 | } 14 | \description{ 15 | Evaluate a model object. 
See implementations in the 16 | \link[keras3:evaluate.keras.src.models.model.Model]{keras3} package. 17 | } 18 | \section{Implementations}{ 19 | 20 | \itemize{ 21 | \item \link[keras3:evaluate.keras.src.models.model.Model]{keras3} 22 | } 23 | } 24 | 25 | -------------------------------------------------------------------------------- /man/export_savedmodel.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/estimator-generics.R 3 | \name{export_savedmodel} 4 | \alias{export_savedmodel} 5 | \title{Export a Saved Model} 6 | \usage{ 7 | export_savedmodel(object, export_dir_base, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An \R object.} 11 | 12 | \item{export_dir_base}{A string containing a directory in which to export the 13 | SavedModel.} 14 | 15 | \item{...}{Optional arguments passed on to implementing methods.} 16 | } 17 | \value{ 18 | The path to the exported directory, as a string. 19 | } 20 | \description{ 21 | Serialize a model to disk. See implementations in the 22 | \link[keras3:export_savedmodel.keras.src.models.model.Model]{keras3} 23 | package. 
24 | } 25 | \section{Implementations}{ 26 | 27 | \itemize{ 28 | \item \link[keras3:export_savedmodel.keras.src.models.model.Model]{keras3} 29 | } 30 | } 31 | 32 | \keyword{internal} 33 | -------------------------------------------------------------------------------- /man/figures/lifecycle-archived.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: archived 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | archived 20 | 21 | 22 | -------------------------------------------------------------------------------- /man/figures/lifecycle-defunct.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: defunct 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | defunct 20 | 21 | 22 | -------------------------------------------------------------------------------- /man/figures/lifecycle-deprecated.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: deprecated 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | deprecated 20 | 21 | 22 | -------------------------------------------------------------------------------- /man/figures/lifecycle-experimental.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: experimental 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | experimental 20 | 21 | 22 | -------------------------------------------------------------------------------- /man/figures/lifecycle-maturing.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: maturing 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | maturing 20 | 21 | 22 | -------------------------------------------------------------------------------- 
/man/figures/lifecycle-questioning.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: questioning 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | questioning 20 | 21 | 22 | -------------------------------------------------------------------------------- /man/figures/lifecycle-soft-deprecated.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: soft-deprecated 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | soft-deprecated 20 | 21 | 22 | -------------------------------------------------------------------------------- /man/figures/lifecycle-stable.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: stable 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 19 | 20 | lifecycle 21 | 22 | 25 | 26 | stable 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /man/figures/lifecycle-superseded.svg: -------------------------------------------------------------------------------- 1 | 2 | lifecycle: superseded 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | lifecycle 18 | 19 | superseded 20 | 21 | 22 | -------------------------------------------------------------------------------- /man/install_tensorflow.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/install.R 3 | \name{install_tensorflow} 4 | \alias{install_tensorflow} 5 | \title{Install TensorFlow and its dependencies} 6 | \usage{ 7 | install_tensorflow( 8 | method = c("auto", "virtualenv", "conda"), 9 | conda = "auto", 10 | version = "default", 11 | envname = "r-tensorflow", 12 | extra_packages = NULL, 13 | restart_session = TRUE, 14 | conda_python_version = NULL, 15 | ..., 16 | cuda = 
NULL, 17 | metal = FALSE, 18 | pip_ignore_installed = FALSE, 19 | new_env = identical(envname, "r-tensorflow"), 20 | python_version = NULL 21 | ) 22 | } 23 | \arguments{ 24 | \item{method}{Installation method. By default, "auto" automatically finds a 25 | method that will work in the local environment. Change the default to force 26 | a specific installation method. Note that the "virtualenv" method is not 27 | available on Windows.} 28 | 29 | \item{conda}{The path to a \code{conda} executable. Use \code{"auto"} to allow 30 | \code{reticulate} to automatically find an appropriate \code{conda} binary. 31 | See \strong{Finding Conda} and \code{\link[reticulate:conda_binary]{conda_binary()}} for more details.} 32 | 33 | \item{version}{TensorFlow version to install. Valid values include: 34 | \itemize{ 35 | \item \code{"default"} installs 2.18 36 | \item \code{"release"} installs the latest release version of tensorflow (which may 37 | be incompatible with the current version of the R package) 38 | \item A version specification like \code{"2.4"} or \code{"2.4.0"}. Note that if the patch 39 | version is not supplied, the latest patch release is installed (e.g., 40 | \code{"2.4"} today installs version "2.4.2") 41 | \item \code{nightly} for the latest available nightly build. 42 | \item To any specification, you can append "-cpu" to install the cpu version 43 | only of the package (e.g., \code{"2.4-cpu"}) 44 | \item The full URL or path to an installer binary or python *.whl file. 45 | }} 46 | 47 | \item{envname}{The name, or full path, of the environment in which Python 48 | packages are to be installed.
When \code{NULL}, the active 49 | environment as set by the \code{RETICULATE_PYTHON_ENV} variable will be used; 50 | if that is unset, then the \code{r-reticulate} environment will be used.} 51 | 52 | \item{extra_packages}{Additional Python packages to install along with 53 | TensorFlow.} 54 | 55 | \item{restart_session}{Restart R session after installing (note this will 56 | only occur within RStudio).} 57 | 58 | \item{conda_python_version}{Passed to conda (only applicable if \code{method = "conda"})} 59 | 60 | \item{...}{other arguments passed to \code{\link[reticulate:conda-tools]{reticulate::conda_install()}} or 61 | \code{\link[reticulate:virtualenv-tools]{reticulate::virtualenv_install()}}, depending on the \code{method} used.} 62 | 63 | \item{cuda}{logical \code{TRUE} or \code{FALSE}. If \code{install_tensorflow()} detects the 64 | platform is Linux, an Nvidia GPU is available, and the TensorFlow version 65 | is 2.14 (the default), it will also install the required CUDA 66 | libraries through pip.} 67 | 68 | \item{metal}{Whether to install \code{tensorflow-metal} pip package on Arm Macs. 69 | This enables tensorflow to use the GPU. Pass a string to install a specific 70 | version like \verb{"tensorflow-metal==0.7.*"}.} 71 | 72 | \item{pip_ignore_installed}{Whether pip should ignore installed python 73 | packages and reinstall all already installed python packages.} 74 | 75 | \item{new_env}{If \code{TRUE}, any existing Python virtual environment and/or 76 | conda environment specified by \code{envname} is deleted first.} 77 | 78 | \item{python_version}{Select the Python that will be used to create the 79 | virtualenv. Pass a string with version constraints like \code{"3.8"}, or 80 | \code{">=3.9,<=3.11"} or a file path to a \code{python} executable like 81 | \code{"/path/to/bin/python3"}. The supplied value is passed on to 82 | \code{reticulate::virtualenv_starter()}.
Note that the Python version must be 83 | compatible with the requested TensorFlow version, documented here: 84 | \url{https://www.tensorflow.org/install/pip#system-requirements}} 85 | } 86 | \description{ 87 | Beginning with reticulate version 1.41, in most circumstances, calling the 88 | \code{install_tensorflow()} function is no longer necessary, because reticulate 89 | automatically registers python requirements with \code{reticulate::py_require()} 90 | when tensorflow is loaded. 91 | 92 | The Python packages registered with \code{py_require()} by the tensorflow R 93 | package: 94 | \itemize{ 95 | \item On Linux: if a GPU is detected: \code{"tensorflow[and-cuda]"}, otherwise, 96 | \code{"tensorflow-cpu"}. 97 | \item On macOS: \code{"tensorflow"} is declared. The default package is not capable 98 | of using the GPU. To enable TensorFlow usage of the GPU, call 99 | \code{reticulate::py_require("tensorflow-metal")} before reticulate has 100 | initialized Python. Note that not all features of TensorFlow work correctly 101 | if \code{tensorflow-metal} is installed. There are known issues with random number 102 | generators like \code{tf$random$stateless_uniform()}, likely others as well. 103 | \item On Windows: \code{"tensorflow"} and \code{"numpy<2"} are declared. Note that 104 | TensorFlow GPU usage on Windows is no longer supported (since TensorFlow 105 | 2.10). To use a GPU on Windows, use TensorFlow via WSL. \code{"numpy<2"} is 106 | declared because at the time of publishing, the pre-built binaries of 107 | \code{tensorflow} for Windows are not compatible with \code{numpy>=2}. 108 | } 109 | 110 | \code{install_tensorflow()} creates a new virtual environment containing the 111 | \code{tensorflow} python package and its direct dependencies. For creating a 112 | virtual environment with a more complete set of packages that includes additional 113 | optional dependencies, use \code{\link[keras3:install_keras]{keras3::install_keras()}}.
114 | } 115 | \details{ 116 | You may be prompted to download and install miniconda if reticulate 117 | did not find a non-system installation of python. Miniconda is the 118 | recommended installation method for most users, as it ensures that the R 119 | python installation is isolated from other python installations. All python 120 | packages will by default be installed into a self-contained conda or venv 121 | environment named "r-reticulate". Note that "conda" is the only supported 122 | method on M1 Mac. 123 | 124 | If you initially declined the miniconda installation prompt, you can later 125 | manually install miniconda by running \code{\link[reticulate:install_miniconda]{reticulate::install_miniconda()}}. 126 | } 127 | \section{Custom Installation}{ 128 | \code{install_tensorflow()} or 129 | \code{keras3::install_keras()} isn't required to use tensorflow with the 130 | package. If you manually configure a python environment with the required 131 | dependencies, you can tell R to use it by pointing reticulate at it, 132 | commonly by setting an environment variable: 133 | 134 | \if{html}{\out{
}}\preformatted{Sys.setenv("RETICULATE_PYTHON" = "~/path/to/python-env/bin/python") 135 | }\if{html}{\out{
}} 136 | } 137 | 138 | \section{Apple Silicon}{ 139 | Beginning with Tensorflow version 2.13, the default 140 | tensorflow package now works on Apple Silicon. See 141 | \url{https://developer.apple.com/metal/tensorflow-plugin/} for instructions 142 | on how to install older versions of Tensorflow on macOS. Please note that 143 | not all operations are supported on Arm Mac GPUs. You can work around the 144 | missing operations by pinning operations to CPU. For example: 145 | 146 | \if{html}{\out{
}}\preformatted{x <- array(runif(64*64), c(1, 64, 64)) 147 | keras3::layer_random_rotation(x, .5) # Error: 148 | # No registered 'RngReadAndSkip' OpKernel for 'GPU' devices 149 | # Pin the operation to the CPU to avoid the error 150 | with(tf$device("CPU"), keras3::layer_random_rotation(x, .5) ) # No Error 151 | }\if{html}{\out{
}} 152 | } 153 | 154 | \section{Additional Packages}{ 155 | 156 | 157 | If you wish to add additional PyPI packages to your Keras / TensorFlow 158 | environment you can either specify the packages in the \code{extra_packages} 159 | argument of \code{install_tensorflow()} or \code{install_keras()}, or alternatively 160 | install them into an existing environment using the 161 | \code{\link[reticulate:py_install]{reticulate::py_install()}} function. Note that \code{install_keras()} includes a 162 | set of additional python packages by default, see \code{?keras3::install_keras} 163 | for details. 164 | } 165 | 166 | \seealso{ 167 | \itemize{ 168 | \item \code{\link[keras3:install_keras]{keras3::install_keras()}} 169 | \item \url{https://tensorflow.rstudio.com/reference/tensorflow/install_tensorflow} 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /man/parse_arguments.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/arguments.R 3 | \name{parse_arguments} 4 | \alias{parse_arguments} 5 | \title{Parse Command Line Arguments} 6 | \usage{ 7 | parse_arguments(arguments = NULL) 8 | } 9 | \arguments{ 10 | \item{arguments}{A vector of command line arguments. When 11 | \code{NULL} (the default), the command line arguments received 12 | by the current \R process are used.} 13 | } 14 | \description{ 15 | Parse command line arguments of the form \code{--key=value} and 16 | \verb{--key value}. The values are assumed to be valid \code{yaml} and 17 | will be converted using \code{\link[yaml:yaml.load]{yaml::yaml.load()}}. 
18 | } 19 | -------------------------------------------------------------------------------- /man/parse_flags.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/flags.R 3 | \name{parse_flags} 4 | \alias{parse_flags} 5 | \title{Parse Configuration Flags for a TensorFlow Application} 6 | \usage{ 7 | parse_flags( 8 | config = Sys.getenv("R_CONFIG_ACTIVE", unset = "default"), 9 | file = "flags.yml", 10 | arguments = commandArgs(TRUE) 11 | ) 12 | } 13 | \arguments{ 14 | \item{config}{The configuration to use. Defaults to the 15 | active configuration for the current environment (as 16 | specified by the \code{R_CONFIG_ACTIVE} environment 17 | variable), or \code{default} when unset.} 18 | 19 | \item{file}{The configuration file to read.} 20 | 21 | \item{arguments}{The command line arguments (as a 22 | character vector) to be parsed.} 23 | } 24 | \value{ 25 | A named \R list, mapping configuration keys to values. 26 | } 27 | \description{ 28 | Parse configuration flags for a TensorFlow application. Use 29 | this to parse and unify the configuration(s) specified through 30 | a \code{flags.yml} configuration file, alongside other arguments 31 | set through the command line. 
32 | } 33 | \examples{ 34 | \dontrun{ 35 | # examine an example configuration file provided by tensorflow 36 | file <- system.file("examples/config/flags.yml", package = "tensorflow") 37 | cat(readLines(file), sep = "\n") 38 | 39 | # read the default configuration 40 | FLAGS <- tensorflow::parse_flags("default", file = file) 41 | str(FLAGS) 42 | 43 | # read the alternate configuration: note that 44 | # the default configuration is inherited, but 45 | # we override the 'string' configuration here 46 | FLAGS <- tensorflow::parse_flags("alternate", file = file) 47 | str(FLAGS) 48 | 49 | # override configuration values using command 50 | # line arguments (normally, these would be 51 | # passed in through the command line invocation 52 | # used to start the process) 53 | FLAGS <- tensorflow::parse_flags( 54 | "alternate", 55 | file = file, 56 | arguments = c("--foo=1") 57 | ) 58 | str(FLAGS) 59 | 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /man/reexports.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/reexports.R 3 | \docType{import} 4 | \name{reexports} 5 | \alias{reexports} 6 | \alias{import} 7 | \alias{dict} 8 | \alias{tuple} 9 | \alias{np_array} 10 | \alias{array_reshape} 11 | \alias{iterate} 12 | \alias{\%as\%} 13 | \alias{use_python} 14 | \alias{use_virtualenv} 15 | \alias{use_condaenv} 16 | \alias{flags} 17 | \alias{flag_numeric} 18 | \alias{flag_integer} 19 | \alias{flag_string} 20 | \alias{flag_boolean} 21 | \alias{run_dir} 22 | \title{Objects exported from other packages} 23 | \keyword{internal} 24 | \description{ 25 | These objects are imported from other packages. Follow the links 26 | below to see their documentation. 
27 | 28 | \describe{ 29 | \item{reticulate}{\code{\link[reticulate:with-as-operator]{\%as\%}}, \code{\link[reticulate]{array_reshape}}, \code{\link[reticulate]{dict}}, \code{\link[reticulate]{import}}, \code{\link[reticulate]{iterate}}, \code{\link[reticulate]{np_array}}, \code{\link[reticulate]{tuple}}, \code{\link[reticulate:use_python]{use_condaenv}}, \code{\link[reticulate]{use_python}}, \code{\link[reticulate:use_python]{use_virtualenv}}} 30 | 31 | \item{tfruns}{\code{\link[tfruns:flags]{flag_boolean}}, \code{\link[tfruns:flags]{flag_integer}}, \code{\link[tfruns:flags]{flag_numeric}}, \code{\link[tfruns:flags]{flag_string}}, \code{\link[tfruns]{flags}}, \code{\link[tfruns]{run_dir}}} 32 | }} 33 | 34 | -------------------------------------------------------------------------------- /man/set_random_seed.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/seed.R 3 | \name{set_random_seed} 4 | \alias{set_random_seed} 5 | \title{Set random seed for TensorFlow} 6 | \usage{ 7 | set_random_seed(seed, disable_gpu = TRUE) 8 | } 9 | \arguments{ 10 | \item{seed}{A single value, interpreted as an integer} 11 | 12 | \item{disable_gpu}{\code{TRUE} to disable GPU execution (see \emph{Details}).} 13 | } 14 | \description{ 15 | Sets all random seeds needed to make TensorFlow code reproducible. 16 | } 17 | \details{ 18 | This function should be used instead of \code{\link[=use_session_with_seed]{use_session_with_seed()}} if 19 | you are using TensorFlow >= 2.0, as the concept of \code{session} doesn't 20 | really make sense anymore. 21 | 22 | This function sets: 23 | \itemize{ 24 | \item The R random seed with \code{\link[=set.seed]{set.seed()}}. 25 | \item The Python and NumPy seeds via \code{\link[reticulate:py_set_seed]{reticulate::py_set_seed()}}.
26 | \item The TensorFlow seed with (\code{tf$random$set_seed()}) 27 | } 28 | 29 | It also optionally disables the GPU execution as this is a potential 30 | source of non-reproducibility. 31 | } 32 | -------------------------------------------------------------------------------- /man/shape.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/shape.R 3 | \name{shape} 4 | \alias{shape} 5 | \title{Create a \code{tf.TensorShape} object} 6 | \usage{ 7 | shape(..., dims = list(...)) 8 | } 9 | \arguments{ 10 | \item{...}{Tensor dimensions as integers or \code{NULL} for an unknown 11 | dimensions. \code{NA} and \code{-1} are synonyms for \code{NULL}.} 12 | 13 | \item{dims}{Tensor dimensions as a vector.} 14 | } 15 | \description{ 16 | Create a \code{tf.TensorShape} object 17 | } 18 | \examples{ 19 | \dontrun{ 20 | 21 | # --- construct --- 22 | shape() # tf.TensorShape() # scalar 23 | shape(NULL) # tf.TensorShape([None]) # 1-D array of unknown length 24 | shape(NA) # tf.TensorShape([None]) # 1-D array of unknown length, NA is a synonym for NULL 25 | 26 | shape(dims = NULL) # TensorShape(None) # Unknown rank, unknown size 27 | shape(3, 4) # TensorShape([3, 4]) # 2-D array (matrix) with 3 rows, 4 columns 28 | shape(NA, 4) # TensorShape([None, 4]) # 2-D array (matrix) with unknown rows, 4 columns 29 | shape(dims = c(NA, 4)) # TensorShape([None, 4]) # same as above; bypass ... 
and pass dims directly 30 | 31 | # --- inspect --- 32 | length(shape(dims = NULL)) # NA_integer_ 33 | length(shape(1,2,3,NA)) # 4L 34 | 35 | # --- convert --- 36 | x <- shape(dims = list(3L, 5L)) 37 | as.list(x) # list(3L, 5L) 38 | as.integer(x) # c(3L, 5L) 39 | as.numeric(x) # c(3, 5) 40 | as.double(x) # c(3, 5) # alias for as.numeric 41 | as_tensor(x) # tf.Tensor([3 5], shape=(2,), dtype=int32) 42 | 43 | # convert partially undefined shapes 44 | x <- shape(NA, 3) 45 | as.list(x) # list(NULL, 3L) 46 | as.integer(x) # c(NA, 3L) 47 | as_tensor(x) # tf.Tensor([-1 3], shape=(2,), dtype=int32) # unspecified dims default is -1 48 | 49 | # as_tensor() converts undefined dimensions to -1, which is useful for 50 | # tf functions that only accept tensors for shapes, e.g., 51 | tf$reshape(tf$zeros(shape(8)), 52 | as_tensor(shape(NA, 4))) 53 | # tf.Tensor([[0. 0. 0. 0.] 54 | # [0. 0. 0. 0.]], shape=(2, 4), dtype=float32) 55 | 56 | # converting fully unknown shapes raises an error 57 | try(as.list(shape(dims = NULL))) # ValueError: as_list() is not defined on an unknown TensorShape.
58 | # test for rank first if this is a concern: 59 | as.list_or_null <- function(x) if(is.na(length(x))) NULL else as.list(x) 60 | as.list_or_null(shape(dims = NULL)) 61 | 62 | 63 | # --- compare --- 64 | # Fully known shapes return TRUE if and only if each element is equal 65 | shape(3, 4) == shape(3, 4) # TRUE 66 | shape(3, 4) == shape(4, 4) # FALSE 67 | 68 | # two unknown dimensions are treated as equal 69 | shape(NA, 4) == shape(NA, 4) # TRUE 70 | shape(NA, 4) == shape(3, 4) # FALSE 71 | 72 | # Two unknown shapes return TRUE 73 | shape(dims = NULL) == shape(dims = NULL) # TRUE 74 | 75 | # Comparing an unknown shape to a partially or fully defined shape returns FALSE 76 | shape(dims = NULL) == shape(NULL) # FALSE 77 | shape(dims = NULL) == shape(4) # FALSE 78 | 79 | 80 | # values of length greater than one supplied to `...` are automatically flattened 81 | shape(1, c(2, 3), 4) # shape(1, 2, 3, 4) 82 | shape(1, shape(2, 3), 4) # shape(1, 2, 3, 4) 83 | shape(1, as_tensor(c(2L, 3L)), 4) # shape(1, 2, 3, 4) 84 | 85 | # --- extract or replace --- 86 | # regular R-list semantics for `[`, `[[`, `[<-`, `[[<-` 87 | x <- shape(1, 2, 3) 88 | x[1] # TensorShape([1]) 89 | x[[1]] # 1L 90 | x[2:3] # TensorShape([2, 3]) 91 | x[-1] # TensorShape([2, 3]) 92 | 93 | x[1] <- 11 ; x # TensorShape([11, 2, 3]) 94 | x[1] <- shape(11) ; x # TensorShape([11, 2, 3]) 95 | x[1] <- list(11) ; x # TensorShape([11, 2, 3]) 96 | 97 | x[[1]] <- 22 ; x # TensorShape([22, 2, 3]) 98 | x[1:2] <- c(NA, 99) ; x # TensorShape([None, 99, 3]) 99 | x[1:2] <- shape(33, 44) ; x # TensorShape([33, 44, 3]) 100 | 101 | # --- concatenate --- 102 | c(shape(1), shape(2, 3), shape(4, NA)) # TensorShape([1, 2, 3, 4, None]) 103 | 104 | # --- merge --- 105 | merge(shape(NA, 2), 106 | shape(1 , 2)) # TensorShape([1, 2]) 107 | 108 | try(merge(shape(2, 2), 109 | shape(1, 2))) # ValueError: Shapes (2, 2) and (1, 2) are not compatible 110 | 111 | rm(x) # cleanup 112 | } 113 | } 114 | \seealso{ 115 |
\url{https://www.tensorflow.org/api_docs/python/tf/TensorShape} 116 | } 117 | -------------------------------------------------------------------------------- /man/sub-.tensorflow.tensor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/extract.R 3 | \name{[.tensorflow.tensor} 4 | \alias{[.tensorflow.tensor} 5 | \title{Subset tensors with \code{[}} 6 | \usage{ 7 | \method{[}{tensorflow.tensor}( 8 | x, 9 | ..., 10 | drop = TRUE, 11 | style = getOption("tensorflow.extract.style"), 12 | options = tf_extract_opts(style) 13 | ) 14 | } 15 | \arguments{ 16 | \item{x}{Tensorflow tensor} 17 | 18 | \item{...}{slicing specs. See examples and details.} 19 | 20 | \item{drop}{whether to drop scalar dimensions} 21 | 22 | \item{style}{One of \code{"python"} or \code{"R"}.} 23 | 24 | \item{options}{An object returned by \code{tf_extract_opts()}} 25 | } 26 | \description{ 27 | Subset tensors with \code{[} 28 | } 29 | \examples{ 30 | \dontrun{ 31 | 32 | x <- as_tensor(array(1:15, dim = c(3, 5))) 33 | x 34 | # by default, numerics supplied to [...] 
are interpreted R style 35 | x[,1] # first column 36 | x[1:2,] # first two rows 37 | x[,1, drop = FALSE] # 1 column matrix 38 | 39 | # strided steps can be specified in R syntax or python syntax 40 | x[, seq(1, 5, by = 2)] 41 | x[, 1:5:2] 42 | # if you are unfamiliar with python-style strided steps, see: 43 | # https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing 44 | 45 | # missing arguments for python syntax are valid, but they must be backticked 46 | # or supplied as NULL 47 | x[, `::2`] 48 | x[, NULL:NULL:2] 49 | x[, `2:`] 50 | 51 | 52 | # all_dims() expands to the shape of the tensor 53 | # (equivalent to a python ellipsis `...`) 54 | # (not to be confused with R dots `...`) 55 | y <- as_tensor(array(1:(3^5), dim = c(3,3,3,3,3))) 56 | all.equal(y[all_dims(), 1], 57 | y[, , , , 1]) 58 | 59 | # tf$newaxis is valid (equivalent to NULL) 60 | x[,, tf$newaxis] 61 | x[,, NULL] 62 | 63 | 64 | # negative numbers are always interpreted python style 65 | # The first time a negative number is supplied to `[`, a warning is issued 66 | # about the non-standard behavior.
67 | x[-1,] # last row, with a warning 68 | x[-1,] # the warning is only issued once 69 | 70 | # specifying `style = 'python'` changes the following: 71 | # + zero-based indexing is used 72 | # + slice sequences in the form of `start:stop` do not include `stop` 73 | # in the returned value 74 | # + out-of-bounds indices in a slice are valid 75 | 76 | # The style argument can be supplied to individual calls of `[` or set 77 | # as a global option 78 | 79 | # example of zero based indexing 80 | x[0, , style = 'python'] # first row 81 | x[1, , style = 'python'] # second row 82 | 83 | # example of slices with exclusive stop 84 | options(tensorflow.extract.style = 'python') 85 | x[, 0:1] # just the first column 86 | x[, 0:2] # first and second column 87 | 88 | # example of out-of-bounds index 89 | x[, 0:10] 90 | options(tensorflow.extract.style = NULL) 91 | 92 | # slicing with tensors is valid too, but note, tensors are never 93 | # translated and are always interpreted python-style. 94 | # A warning is issued the first time a tensor is passed to `[` 95 | x[, tf$constant(0L):tf$constant(2L)] 96 | # just as in python, only scalar tensors are valid 97 | # https://www.tensorflow.org/api_docs/python/tf/Tensor#__getitem__ 98 | 99 | # To silence the warnings about tensors being passed as-is and negative numbers 100 | # being interpreted python-style, set 101 | options(tensorflow.extract.style = 'R') 102 | 103 | # clean up from examples 104 | options(tensorflow.extract.style = NULL) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /man/tensorboard.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/tensorboard.R 3 | \name{tensorboard} 4 | \alias{tensorboard} 5 | \title{TensorBoard Visualization Tool} 6 | \usage{ 7 | tensorboard( 8 | log_dir, 9 | action = c("start", "stop"), 10 | host = "127.0.0.1", 11 | 
port = "auto", 12 | launch_browser = getOption("tensorflow.tensorboard.browser", interactive()), 13 | reload_interval = 5, 14 | purge_orphaned_data = TRUE 15 | ) 16 | } 17 | \arguments{ 18 | \item{log_dir}{Directories to scan for training logs. If this is a named 19 | character vector then the specified names will be used as aliases within 20 | TensorBoard.} 21 | 22 | \item{action}{Specify whether to start or stop TensorBoard (TensorBoard will 23 | be stopped automatically when the R session from which it is launched is 24 | terminated).} 25 | 26 | \item{host}{Host for serving TensorBoard} 27 | 28 | \item{port}{Port for serving TensorBoard. If "auto" is specified (the 29 | default) then an unused port will be chosen automatically.} 30 | 31 | \item{launch_browser}{Open a web browser for TensorBoard after launching. 32 | Defaults to \code{TRUE} in interactive sessions. When running under RStudio, an 33 | RStudio window is used by default (pass a function, e.g. \code{\link[utils:browseURL]{utils::browseURL()}}, to 34 | open in an external browser). Use the \code{tensorflow.tensorboard.browser} 35 | option to establish a global default behavior.} 36 | 37 | \item{reload_interval}{How often the backend should load more data.} 38 | 39 | \item{purge_orphaned_data}{Whether to purge data that may have been orphaned 40 | due to TensorBoard restarts. Disabling purge_orphaned_data can be used to 41 | debug data disappearance.} 42 | } 43 | \value{ 44 | URL for browsing TensorBoard (invisibly). 45 | } 46 | \description{ 47 | TensorBoard is a tool for inspecting and understanding your TensorFlow runs and 48 | graphs. 49 | } 50 | \details{ 51 | When TensorBoard is passed a logdir at startup, it recursively walks 52 | the directory tree rooted at logdir looking for subdirectories that contain 53 | tfevents data. Every time it encounters such a subdirectory, it loads it as 54 | a new run, and the frontend will organize the data accordingly.
55 | 56 | The TensorBoard process will be automatically destroyed when the R session 57 | in which it is launched exits. You can pass \code{action = "stop"} to manually 58 | terminate TensorBoard. 59 | } 60 | -------------------------------------------------------------------------------- /man/tensorflow.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/package.R 3 | \docType{package} 4 | \name{tensorflow} 5 | \alias{tensorflow} 6 | \alias{tensorflow-package} 7 | \title{TensorFlow for R} 8 | \description{ 9 | \href{https://www.tensorflow.org}{TensorFlow} is an open source software library 10 | for numerical computation using data flow graphs. Nodes in the graph 11 | represent mathematical operations, while the graph edges represent the 12 | multidimensional data arrays (tensors) communicated between them. The 13 | flexible architecture allows you to deploy computation to one or more CPUs or 14 | GPUs in a desktop, server, or mobile device with a single API. 15 | } 16 | \details{ 17 | The \href{https://www.tensorflow.org/api_docs/python/tf/all_symbols}{TensorFlow 18 | API} is composed of a set of Python modules that enable constructing and 19 | executing TensorFlow graphs. The tensorflow package provides access to the 20 | complete TensorFlow API from within R. 
21 | 22 | For additional documentation on the tensorflow package see 23 | \href{https://tensorflow.rstudio.com}{https://tensorflow.rstudio.com} 24 | } 25 | \seealso{ 26 | Useful links: 27 | \itemize{ 28 | \item \url{https://github.com/rstudio/tensorflow} 29 | \item Report bugs at \url{https://github.com/rstudio/tensorflow/issues} 30 | } 31 | 32 | } 33 | \author{ 34 | \strong{Maintainer}: Tomasz Kalinowski \email{tomasz.kalinowski@posit.co} [contributor, copyright holder] 35 | 36 | Authors: 37 | \itemize{ 38 | \item JJ Allaire [copyright holder] 39 | \item Yuan Tang \email{terrytangyuan@gmail.com} (\href{https://orcid.org/0000-0001-5243-233X}{ORCID}) [copyright holder] 40 | } 41 | 42 | Other contributors: 43 | \itemize{ 44 | \item Daniel Falbel \email{daniel@posit.co} [contributor, copyright holder] 45 | \item Dirk Eddelbuettel \email{edd@debian.org} [contributor, copyright holder] 46 | \item Nick Golding \email{nick.golding.research@gmail.com} [contributor, copyright holder] 47 | \item Google Inc. (Examples and Tutorials) [contributor, copyright holder] 48 | \item Posit, PBC [copyright holder, funder] 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /man/tf.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/modules.R 3 | \docType{data} 4 | \name{tf} 5 | \alias{tf} 6 | \title{Main TensorFlow module} 7 | \format{ 8 | TensorFlow module 9 | } 10 | \usage{ 11 | tf 12 | } 13 | \description{ 14 | Interface to main TensorFlow module. Provides access to top level classes 15 | and functions as well as sub-modules (e.g. \code{tf$nn}, 16 | \code{tf$contrib$learn}, etc.). 
17 | } 18 | \examples{ 19 | \dontrun{ 20 | library(tensorflow) 21 | 22 | hello <- tf$constant('Hello, TensorFlow!') 23 | zeros <- tf$Variable(tf$zeros(shape(1L))) 24 | 25 | tf$print(hello) 26 | tf$print(zeros) 27 | } 28 | } 29 | \keyword{datasets} 30 | -------------------------------------------------------------------------------- /man/tf_config.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/package.R 3 | \name{tf_config} 4 | \alias{tf_config} 5 | \alias{tf_version} 6 | \title{TensorFlow configuration information} 7 | \usage{ 8 | tf_config() 9 | 10 | tf_version() 11 | } 12 | \value{ 13 | List with information on the current configuration of TensorFlow. 14 | You can determine whether TensorFlow was found using the \code{available} 15 | member (other members vary depending on whether \code{available} is \code{TRUE} 16 | or \code{FALSE}) 17 | } 18 | \description{ 19 | TensorFlow configuration information 20 | } 21 | \keyword{internal} 22 | -------------------------------------------------------------------------------- /man/tf_extract_opts.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/extract.R 3 | \name{tf_extract_opts} 4 | \alias{tf_extract_opts} 5 | \title{Tensor extract options} 6 | \usage{ 7 | tf_extract_opts( 8 | style = getOption("tensorflow.extract.style"), 9 | ..., 10 | one_based = getOption("tensorflow.extract.one_based", TRUE), 11 | inclusive_stop = getOption("tensorflow.extract.inclusive_stop", TRUE), 12 | disallow_out_of_bounds = getOption("tensorflow.extract.dissallow_out_of_bounds", TRUE), 13 | warn_tensors_passed_asis = getOption("tensorflow.extract.warn_tensors_passed_asis", 14 | TRUE), 15 | warn_negatives_pythonic = getOption("tensorflow.extract.warn_negatives_pythonic", TRUE) 16 | ) 17 | } 18 | 
\arguments{ 19 | \item{style}{one of \code{NULL} (the default), \code{"R"}, or \code{"python"}. If supplied, 20 | this overrides all other options. \code{"python"} is equivalent to all the other 21 | arguments being \code{FALSE}. \code{"R"} is equivalent to 22 | \code{warn_tensors_passed_asis} and \code{warn_negatives_pythonic} 23 | set to \code{FALSE}} 24 | 25 | \item{...}{ignored} 26 | 27 | \item{one_based}{TRUE or FALSE, whether one-based indexing should be used} 28 | 29 | \item{inclusive_stop}{TRUE or FALSE, whether slices like \code{start:stop} should be 30 | inclusive of \code{stop}} 31 | 32 | \item{disallow_out_of_bounds}{TRUE or FALSE, whether checks are performed on 33 | the slicing index to ensure it is within bounds.} 34 | 35 | \item{warn_tensors_passed_asis}{TRUE or FALSE, whether to emit a warning, the 36 | first time a tensor is supplied to \code{[}, that tensors are passed as-is, with 37 | no R to python translation} 38 | 39 | \item{warn_negatives_pythonic}{TRUE or FALSE, whether to emit 40 | a warning the first time a negative number is supplied to \code{[} about the 41 | non-standard (python-style) interpretation} 42 | } 43 | \value{ 44 | an object with class "tf_extract_opts", suitable for passing to 45 | \verb{[.tensorflow.tensor()} 46 | } 47 | \description{ 48 | Tensor extract options 49 | } 50 | \examples{ 51 | \dontrun{ 52 | x <- tf$constant(1:10) 53 | 54 | opts <- tf_extract_opts("R") 55 | x[1, options = opts] 56 | 57 | # or for more fine-grained control 58 | opts <- tf_extract_opts( 59 | one_based = FALSE, 60 | warn_tensors_passed_asis = FALSE, 61 | warn_negatives_pythonic = FALSE 62 | ) 63 | x[0:2, options = opts] 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /man/tf_function.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/eager.R 3 | \name{tf_function} 4 | \alias{tf_function} 5 |
\title{Creates a callable TensorFlow graph from an R function.} 6 | \usage{ 7 | tf_function(f, input_signature = NULL, autograph = TRUE, ...) 8 | } 9 | \arguments{ 10 | \item{f}{the function to be compiled} 11 | 12 | \item{input_signature}{A possibly nested sequence of \code{tf$TensorSpec} objects 13 | specifying the shapes and dtypes of the tensors that will be supplied to 14 | this function. If \code{NULL}, a separate function is instantiated for each 15 | inferred input signature. If \code{input_signature} is specified, every input to 16 | \code{f} must be a tensor.} 17 | 18 | \item{autograph}{TRUE or FALSE. If TRUE (the default), you can use tensors in 19 | R control flow expressions \code{if}, \code{while}, \code{for} and \code{break} and they will 20 | be traced into the tensorflow graph. A guide to getting started and 21 | additional details can be found: 22 | \href{https://t-kalinowski.github.io/tfautograph/}{here}} 23 | 24 | \item{...}{additional arguments passed on to \code{tf.function} (vary based on 25 | Tensorflow version). See 26 | \href{https://www.tensorflow.org/api_docs/python/tf/function#args_1}{here} for 27 | details.} 28 | } 29 | \description{ 30 | \code{tf_function} constructs a callable that executes a TensorFlow graph created 31 | by tracing the TensorFlow operations in \code{f}. This allows the TensorFlow 32 | runtime to apply optimizations and exploit parallelism in the computation 33 | defined by \code{f}. 34 | } 35 | \details{ 36 | A guide to getting started with 37 | \href{https://www.tensorflow.org/api_docs/python/tf/function}{\code{tf.function}} can 38 | be found \href{https://www.tensorflow.org/guide/function}{here}. 
39 | } 40 | -------------------------------------------------------------------------------- /man/tf_gpu_configured.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/package.R 3 | \name{tf_gpu_configured} 4 | \alias{tf_gpu_configured} 5 | \title{TensorFlow GPU configuration information} 6 | \usage{ 7 | tf_gpu_configured(verbose = TRUE) 8 | } 9 | \arguments{ 10 | \item{verbose}{boolean. Whether to show extra GPU info.} 11 | } 12 | \value{ 13 | A logical value indicating whether a GPU is configured, or NA if this could not be 14 | determined. 15 | } 16 | \description{ 17 | TensorFlow GPU configuration information 18 | } 19 | \keyword{internal} 20 | -------------------------------------------------------------------------------- /man/tf_probability.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/probability.R 3 | \name{tf_probability} 4 | \alias{tf_probability} 5 | \title{TensorFlow Probability Module} 6 | \usage{ 7 | tf_probability() 8 | } 9 | \value{ 10 | Reference to \href{https://www.tensorflow.org/probability}{TensorFlow Probability} 11 | functions and classes 12 | } 13 | \description{ 14 | TensorFlow Probability Module 15 | } 16 | \examples{ 17 | \dontrun{ 18 | library(tensorflow) 19 | ## one time setup: 20 | # reticulate::py_install("tensorflow_probability") 21 | tfp <- tf_probability() 22 | tfp$distributions$Normal(loc = 0, scale = 1) 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /man/train.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/estimator-generics.R 3 | \name{train} 4 | \alias{train} 5 | \title{(Deprecated) Train a Model} 6 | \usage{ 7 | train(object,
...) 8 | } 9 | \arguments{ 10 | \item{object}{A trainable \R object.} 11 | 12 | \item{...}{Optional arguments passed on to implementing methods. 13 | 14 | \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}}} 15 | } 16 | \description{ 17 | Train a model object. See implementation in the 18 | \code{tfestimators::train.tf_estimator()} package. 19 | } 20 | \keyword{internal} 21 | -------------------------------------------------------------------------------- /man/train_and_evaluate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/estimator-generics.R 3 | \name{train_and_evaluate} 4 | \alias{train_and_evaluate} 5 | \title{(Deprecated) Simultaneously Train and Evaluate a Model} 6 | \usage{ 7 | train_and_evaluate(object, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An \R object.} 11 | 12 | \item{...}{Optional arguments passed on to implementing methods. 13 | 14 | \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}}} 15 | } 16 | \description{ 17 | Train and evaluate a model object. See implementation in the 18 | \code{tfestimators::train_and_evaluate.tf_estimator()} package. 19 | } 20 | \keyword{internal} 21 | -------------------------------------------------------------------------------- /man/use_compat.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/compat.R 3 | \name{use_compat} 4 | \alias{use_compat} 5 | \title{Use Compatibility} 6 | \usage{ 7 | use_compat(version = c("v1", "v2")) 8 | } 9 | \arguments{ 10 | \item{version}{The version to activate. 
Must be \code{"v1"} or \code{"v2"}} 11 | } 12 | \description{ 13 | Enables TensorFlow to run under a different API version for compatibility 14 | with previous versions. For instance, this is useful to run TensorFlow 1.x 15 | code when using TensorFlow 2.x. 16 | } 17 | \examples{ 18 | \dontrun{ 19 | library(tensorflow) 20 | use_compat("v1") 21 | } 22 | 23 | } 24 | -------------------------------------------------------------------------------- /man/use_session_with_seed.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/seed.R 3 | \name{use_session_with_seed} 4 | \alias{use_session_with_seed} 5 | \title{Use a session with a random seed} 6 | \usage{ 7 | use_session_with_seed( 8 | seed, 9 | disable_gpu = TRUE, 10 | disable_parallel_cpu = TRUE, 11 | quiet = FALSE 12 | ) 13 | } 14 | \arguments{ 15 | \item{seed}{A single value, interpreted as an integer} 16 | 17 | \item{disable_gpu}{\code{TRUE} to disable GPU execution (see \emph{Parallelism} below).} 18 | 19 | \item{disable_parallel_cpu}{\code{TRUE} to disable CPU parallelism (see 20 | \emph{Parallelism} below).} 21 | 22 | \item{quiet}{\code{TRUE} to suppress printing of messages.} 23 | } 24 | \value{ 25 | TensorFlow session object, invisibly 26 | } 27 | \description{ 28 | Set various random seeds required to ensure reproducible results. The 29 | provided \code{seed} value will establish a new random seed for R, Python, NumPy, 30 | and TensorFlow. GPU computations and CPU parallelism will also be disabled by 31 | default. 32 | } 33 | \details{ 34 | This function must be called at the very top of your script (i.e. 35 | immediately after \code{library(tensorflow)}, \code{library(keras)}, etc.). Any 36 | existing TensorFlow session is torn down via \code{tf$reset_default_graph()}. 
37 | 38 | This function takes all measures known to promote reproducible results from 39 | TensorFlow sessions; however, it's possible that various individual 40 | TensorFlow features or dependent libraries escape its effects. If you 41 | encounter non-reproducible results, please investigate the possible sources 42 | of the problem; contributions via pull request are very welcome! 43 | 44 | Packages which need to be notified before and after the seed is set 45 | can register for the "tensorflow.on_before_use_session" and 46 | "tensorflow.on_use_session" hooks (see \code{\link[=setHook]{setHook()}} for additional 47 | details on hooks). 48 | } 49 | \section{Parallelism}{ 50 | By default the \code{use_session_with_seed()} function 51 | disables GPU and CPU parallelism, since both can result in 52 | non-deterministic execution patterns (see 53 | \url{https://stackoverflow.com/questions/42022950/}). You can optionally enable 54 | GPU or CPU parallelism by setting the \code{disable_gpu} and/or 55 | \code{disable_parallel_cpu} parameters to \code{FALSE}. 56 | } 57 | 58 | \examples{ 59 | \dontrun{ 60 | library(tensorflow) 61 | use_session_with_seed(42) 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /man/view_savedmodel.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/save.R 3 | \name{view_savedmodel} 4 | \alias{view_savedmodel} 5 | \title{View a Saved Model} 6 | \usage{ 7 | view_savedmodel(model_dir) 8 | } 9 | \arguments{ 10 | \item{model_dir}{The path to the exported model, as a string.} 11 | } 12 | \value{ 13 | URL for browsing TensorBoard (invisibly). 14 | } 15 | \description{ 16 | View a serialized model from disk.
17 | } 18 | -------------------------------------------------------------------------------- /pkgdown/_pkgdown.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | reference: 4 | - title: "Installation" 5 | contents: 6 | - install_tensorflow 7 | - tf_version 8 | - tf_config 9 | 10 | - title: "Configuration" 11 | contents: 12 | - parse_flags 13 | - parse_arguments 14 | 15 | - title: "TensorFlow API" 16 | contents: 17 | - tf 18 | - shape 19 | 20 | - title: "Utilities" 21 | contents: 22 | - tensorboard 23 | - use_session_with_seed 24 | 25 | -------------------------------------------------------------------------------- /tensorflow.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | ProjectId: 9b7991ac-25ac-456a-aa6e-b7af13061228 3 | 4 | RestoreWorkspace: No 5 | SaveWorkspace: No 6 | AlwaysSaveHistory: Yes 7 | 8 | EnableCodeIndexing: Yes 9 | UseSpacesForTab: Yes 10 | NumSpacesForTab: 2 11 | Encoding: UTF-8 12 | 13 | RnwWeave: Sweave 14 | LaTeX: pdfLaTeX 15 | 16 | AutoAppendNewline: Yes 17 | StripTrailingWhitespace: Yes 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(tensorflow) 3 | 4 | if (identical(Sys.getenv("NOT_CRAN"), "true")) 5 | test_check("tensorflow") 6 | -------------------------------------------------------------------------------- /tests/testthat/.gitignore: -------------------------------------------------------------------------------- 1 | MNIST-data 2 | -------------------------------------------------------------------------------- /tests/testthat/helper-utils.R: 
-------------------------------------------------------------------------------- 1 | Sys.setenv(TF_CPP_MIN_LOG_LEVEL = 1) 2 | options(warn = 1L) 3 | 4 | if(py_module_available("tensorflow")) 5 | tf$abs(1) # initialize on load_all() 6 | 7 | .SESS <- NULL 8 | grab <- function(x) { 9 | if(!inherits(x, "tensorflow.tensor")) 10 | return(x) 11 | 12 | if(tf$executing_eagerly()) 13 | return(as.array(x)) 14 | 15 | if (is.null(.SESS)) { 16 | if (tf_version() >= "1.14") 17 | .SESS <<- tf$compat$v1$Session() 18 | else 19 | .SESS <<- tf$Session() 20 | } 21 | 22 | .SESS$run(x) 23 | } 24 | 25 | skip_if_no_tensorflow <- function() { 26 | if (!reticulate::py_module_available("tensorflow")) 27 | skip("TensorFlow not available for testing") 28 | } 29 | 30 | 31 | arr <- function(..., mode = "double", gen = seq_len) 32 | array(as.vector(gen(prod(unlist(c(...)))), mode = mode), unlist(c(...))) 33 | 34 | set.seed(42) 35 | rarr <- function(...) arr(..., gen=runif) 36 | 37 | expect_near <- function(..., tol = 1e-5) expect_equal(..., tolerance = tol) 38 | 39 | 40 | suppress_warning_NaNs_produced <- function(expr) { 41 | withCallingHandlers( 42 | expr, 43 | warning = function(w) { 44 | if(inherits(w, "warning") && grepl("NaNs produced", w$message)) 45 | invokeRestart("muffleWarning") 46 | }) 47 | } 48 | -------------------------------------------------------------------------------- /tests/testthat/setup.R: -------------------------------------------------------------------------------- 1 | 2 | clean_python_tmp_dir <- function() { 3 | if (!reticulate::py_available()) 4 | return() 5 | 6 | tryCatch({ 7 | python_temp_dir <- dirname( 8 | reticulate::py_run_string( 9 | "import tempfile; x=tempfile.NamedTemporaryFile().name", 10 | local = TRUE 11 | )$x 12 | ) 13 | detritus <- list.files(path = python_temp_dir, 14 | pattern = "__autograph_generated_file|__pycache__", 15 | all.files = TRUE, include.dirs = TRUE, no.. 
= TRUE, 16 | full.names = TRUE) 17 | 18 | if(length(detritus)) { 19 | # cat("Unlinking:\n", 20 | # paste("-", detritus, "\n"), sep = "") 21 | unlink(detritus, TRUE, TRUE) 22 | } 23 | }, 24 | 25 | error = function(e) { 26 | warning(e) 27 | }) 28 | } 29 | 30 | withr::defer(clean_python_tmp_dir(), teardown_env()) 31 | -------------------------------------------------------------------------------- /tests/testthat/test-arguments.R: -------------------------------------------------------------------------------- 1 | context("Arguments") 2 | 3 | expect_arguments <- function(arguments, expected) { 4 | splat <- strsplit(arguments, "[[:space:]]+")[[1]] 5 | parsed <- parse_arguments(splat) 6 | expect_equal(parsed, expected) 7 | } 8 | 9 | test_that("sample command line arguments are parsed as expected", { 10 | 11 | expect_arguments( 12 | "--alpha=beta --gamma=delta", 13 | list(alpha = "beta", gamma = "delta") 14 | ) 15 | 16 | expect_arguments( 17 | "--alpha beta --gamma delta", 18 | list(alpha = "beta", gamma = "delta") 19 | ) 20 | 21 | expect_arguments( 22 | "--nested=a=b", 23 | list(nested = "a=b") 24 | ) 25 | 26 | expect_arguments( 27 | "--array=[1,2,3]", 28 | list(array = c(1, 2, 3)) 29 | ) 30 | 31 | expect_arguments( 32 | "--number=1000", 33 | list(number = 1000) 34 | ) 35 | 36 | expect_arguments( 37 | "--nested-dashes=1000", 38 | list(nested_dashes = 1000) 39 | ) 40 | 41 | }) 42 | -------------------------------------------------------------------------------- /tests/testthat/test-as_tensor.R: -------------------------------------------------------------------------------- 1 | test_that("as_tensor works", { 2 | 3 | skip_if_no_tensorflow() 4 | 5 | test_is_tensor <- function(x, dtype=NULL) { 6 | expect("tensorflow.tensor" %in% class(x), 7 | paste("Wrong S3 class, expected 'tensorflow.tensor', actual", class(x))) 8 | 9 | failure_message <- sprintf( 10 | "wrong type attributes. 
expected '%s', encountered '%s'", dtype, x$dtype) 11 | if(is.character(dtype)) 12 | expect(x$dtype[[dtype]], failure_message) 13 | else if(!is.null(dtype)) 14 | expect(x$dtype == tf$as_dtype(dtype), failure_message) 15 | } 16 | 17 | test_is_tensor(as_tensor(3), 'is_floating') 18 | test_is_tensor(as_tensor(3L), tf$int32) 19 | test_is_tensor(as_tensor("foo"), tf$string) 20 | test_is_tensor(as_tensor(TRUE), tf$bool) 21 | test_is_tensor(as_tensor(1+1i), 'is_complex') 22 | 23 | test_is_tensor(as_tensor(3L, tf$int32) , tf$int32) 24 | test_is_tensor(as_tensor(3L, tf$int64) , tf$int64) 25 | test_is_tensor(as_tensor(3L, tf$float32) , tf$float32) 26 | test_is_tensor(as_tensor(3L, tf$float64) , tf$float64) 27 | test_is_tensor(as_tensor(3L, tf$int8) , tf$int8) 28 | 29 | test_is_tensor(as_tensor(3.0, tf$float32), tf$float32) 30 | test_is_tensor(as_tensor(3.0, tf$float64), tf$float64) 31 | test_is_tensor(as_tensor(3.0, tf$int32) , tf$int32) 32 | test_is_tensor(as_tensor(3.0, tf$int64) , tf$int64) 33 | test_is_tensor(as_tensor(3.0, tf$int8) , tf$int8) 34 | 35 | # currently scalars -> float32; arrays -> float64 36 | test_is_tensor(as_tensor(arr(3)) , 'is_floating') 37 | test_is_tensor(as_tensor(arr(3, 3)) , 'is_floating') 38 | test_is_tensor(as_tensor(arr(3, 3, 3)) , 'is_floating') 39 | 40 | x <- tf$constant(3) 41 | test_is_tensor(as_tensor(x, tf$int32), tf$int32) 42 | test_is_tensor(as_tensor(x, tf$int64), tf$int64) 43 | 44 | shps <- list(c(-1, 4), 45 | shape(-1, 4), 46 | c(NA, 4), 47 | list(NULL, 4), 48 | list(3, 4), 49 | as_tensor(c(-1L, 4L))) 50 | for (shp in shps) { 51 | x <- as_tensor(1:12, shape = shp) 52 | expect_identical(dim(x), c(3L, 4L)) 53 | } 54 | 55 | 56 | # can call tf$fill() to expand scalars 57 | expect_identical( 58 | tf$convert_to_tensor(array(0, c(3,4)))$numpy(), 59 | as_tensor(0, shape = c(3, 4))$numpy() 60 | ) 61 | 62 | expect_identical( 63 | tf$zeros(shape(3,4))$numpy(), 64 | as_tensor(0, shape = c(3, 4))$numpy() 65 | ) 66 | 67 | # atomic vectors are 
converted to python as R arrays 68 | x <- as_tensor(1:6, shape = c(2, 3)) 69 | expect_equal(x$dtype$name, "int32") # a list of ints would convert to int64 70 | 71 | 72 | i <- as_tensor(seq(4)) 73 | j <- as_tensor(rep(3L, 4)) 74 | r <- as_tensor(i >= j, dtype = "int32") 75 | expect_equal(r$dtype$name, "int32") # casting bool to int doesn't do a saturate cast 76 | 77 | r <- as_tensor(i-1, "bool") 78 | expect_equal(as.logical(r), c(FALSE, TRUE, TRUE, TRUE)) 79 | 80 | x <- as_tensor(c(-1, 0, 1, 254, 255, 256), "uint8") 81 | expect_equal(as.integer(x), as.integer(c(0,0, 1, 254, 255, 255))) 82 | 83 | x <- as_tensor(c(-1, 0, 1, 254, 255, 256)) 84 | x <- as_tensor(x, "uint8") 85 | expect_equal(as.integer(x), as.integer(c(0,0, 1, 254, 255, 255))) 86 | 87 | # tf.fill works with a scalar tensor 88 | x <- as_tensor(as_tensor(3), shape = 32) 89 | expect_equal(dim(x), 32L) 90 | expect_equal(as.numeric(x), rep(3, 32)) 91 | 92 | # supplied shape can be a tensor 93 | x <- as_tensor(as_tensor(3), shape = as_tensor(32L)) 94 | expect_equal(dim(x), 32L) 95 | expect_equal(as.numeric(x), rep(3, 32)) 96 | 97 | }) 98 | 99 | 100 | test_that("conversion of tf.string dtype tensors", { 101 | 102 | xx <- list(array("foo"), 103 | array(c("foo", "bar", "baz")), 104 | array(as.character(1:12), c(3, 4)), 105 | array(as.character(1:12), c(2, 3, 2))) 106 | 107 | for (x in xx) 108 | expect_identical(x, as.array(as_tensor(x))) 109 | 110 | for (x in xx) 111 | expect_identical(as.character(x), 112 | as.character(as_tensor(x))) 113 | 114 | }) 115 | -------------------------------------------------------------------------------- /tests/testthat/test-data-structures.R: -------------------------------------------------------------------------------- 1 | test_that("cast List wrapers", { 2 | model <- tf$keras$models$Sequential(list()) 3 | model$denses <- list(tf$keras$layers$Dense(10L), 4 | tf$keras$layers$Dense(10L)) 5 | 6 | model$denses_dict <- list(abc = tf$keras$layers$Dense(10L), 7 | def = 
tf$keras$layers$Dense(10L)) 8 | 9 | expect_true(is.list(model$denses)) 10 | expect_true(length(model$denses) == 2) 11 | 12 | expect_true(is.list(model$denses_dict)) 13 | expect_true(length(model$denses_dict) == 2) 14 | expect_named(model$denses_dict, c("abc", "def")) 15 | }) 16 | -------------------------------------------------------------------------------- /tests/testthat/test-examples.R: -------------------------------------------------------------------------------- 1 | context("examples") 2 | 3 | 4 | # some helpers 5 | run_example <- function(example) { 6 | env <- new.env() 7 | capture.output({ 8 | example_path <- system.file("examples", example, package = "tensorflow") 9 | old_wd <- setwd(dirname(example_path)) 10 | on.exit(setwd(old_wd), add = TRUE) 11 | source(basename(example_path), local = env) 12 | }, type = "output") 13 | rm(list = ls(env), envir = env) 14 | gc() 15 | } 16 | 17 | examples <- if (nzchar(Sys.getenv("TENSORFLOW_TEST_EXAMPLES"))) { 18 | examples <- c("hello.R", 19 | "introduction.R", 20 | "mnist/mnist_softmax.R", 21 | "mnist/fully_connected_feed.R", 22 | "regression/tensorflow_linear_regression.R") 23 | 24 | if (tf_version() >= "2.0") { 25 | # disable examples since tf_compat() requires session restart 26 | examples <- NULL 27 | } 28 | 29 | examples 30 | } 31 | 32 | for (example in examples) { 33 | test_that(paste(example, "example runs successfully"), { 34 | skip_if_no_tensorflow() 35 | expect_error(run_example(example), NA) 36 | }) 37 | } 38 | 39 | -------------------------------------------------------------------------------- /tests/testthat/test-export-savedmodel.R: -------------------------------------------------------------------------------- 1 | context("Save") 2 | 3 | train_mnist_graph <- function(sess) { 4 | 5 | IPython <- IPython <- reticulate::import("IPython") 6 | py_capture_output <- IPython$utils$capture$capture_output 7 | 8 | with(py_capture_output(), { 9 | datasets <- tf$contrib$learn$datasets 10 | mnist <- 
datasets$mnist$read_data_sets("MNIST-data", one_hot = TRUE) 11 | }) 12 | 13 | if (tf_version() >= "1.14") 14 | placeholder <- tf$compat$v1$placeholder 15 | else 16 | placeholder <- tf$placeholder 17 | 18 | x <- placeholder(tf$float32, shape(NULL, 784L)) 19 | 20 | W <- tf$Variable(tf$zeros(shape(784L, 10L))) 21 | b <- tf$Variable(tf$zeros(shape(10L))) 22 | 23 | y <- tf$nn$softmax(tf$matmul(x, W) + b) 24 | 25 | y_ <- placeholder(tf$float32, shape(NULL, 10L)) 26 | cross_entropy <- tf$reduce_mean(-tf$reduce_sum(y_ * log(y), reduction_indices = 1L)) 27 | 28 | if (tf_version() >= "1.14") 29 | optimizer <- tf$compat$v1$train$GradientDescentOptimizer(0.5) 30 | else 31 | optimizer <- tf$train$GradientDescentOptimizer(0.5) 32 | 33 | train_step <- optimizer$minimize(cross_entropy) 34 | 35 | if (tf_version() >= "1.14") 36 | init <- tf$compat$v1$global_variables_initializer() 37 | else 38 | init <- tf$global_variables_initializer() 39 | 40 | sess$run(init) 41 | 42 | for (i in 1:1000) { 43 | batches <- mnist$train$next_batch(100L) 44 | batch_xs <- batches[[1]] 45 | batch_ys <- batches[[2]] 46 | sess$run(train_step, 47 | feed_dict = dict(x = batch_xs, y_ = batch_ys)) 48 | } 49 | 50 | correct_prediction <- tf$equal(tf$argmax(y, 1L), tf$argmax(y_, 1L)) 51 | accuracy <- tf$reduce_mean(tf$cast(correct_prediction, tf$float32)) 52 | 53 | sess$run(accuracy, feed_dict = dict(x = mnist$test$images, y_ = mnist$test$labels)) 54 | 55 | list(input = x, output = y) 56 | } 57 | 58 | # TODO: consider testing in a new R session with tf$compat$v1$disable_eager_execution() 59 | # skip("Don't have sessions to export when running eager.") 60 | if (!tf$executing_eagerly()) 61 | test_that("export_savedmodel() works with MNIST", { 62 | skip_if_no_tensorflow() 63 | 64 | temp_path <- tempfile() 65 | 66 | if (tf_version() >= "1.14") 67 | sess <- tf$compat$v1$Session() 68 | else 69 | sess <- tf$Session() 70 | 71 | tensors <- train_mnist_graph(sess) 72 | 73 | export_savedmodel( 74 | sess, 75 | temp_path, 76 
| inputs = list(images = tensors$input), 77 | outputs = list(scores = tensors$output) 78 | ) 79 | 80 | expect_true(file.exists(file.path(temp_path, "saved_model.pb"))) 81 | 82 | }) 83 | -------------------------------------------------------------------------------- /tests/testthat/test-extract-syntax.R: -------------------------------------------------------------------------------- 1 | context("extract syntax") 2 | 3 | 4 | null_out_all_extract_opts <- function() { 5 | opts <- options() 6 | opts[grepl("^tensorflow[.]extract", names(opts))] <- list(NULL) 7 | options(opts) 8 | } 9 | 10 | arr <- function (...) { 11 | # create an array with the specified dimensions, and fill it with consecutive 12 | # increasing integers 13 | dims <- unlist(list(...)) 14 | array(1:prod(dims), dim = dims) 15 | } 16 | 17 | randn <- function (...) { 18 | dim <- c(...) 19 | array(rnorm(prod(dim)), dim = dim) 20 | } 21 | 22 | # check a simple (one-object) expression produces the same result when done on 23 | # an R array, and when done on a tensor, with results ported back to R 24 | # e.g. 
check_expr(a[1:3], name = "a") 25 | check_expr <- function (expr, name = "x") { 26 | 27 | call <- substitute(expr) 28 | r_out <- as.array(eval(expr)) 29 | 30 | # swap the array for a constant, run, and convert back to an array 31 | obj <- get(name, parent.frame()) 32 | swapsies <- list(tf$constant(obj)) 33 | names(swapsies) <- name 34 | tf_out <- with(swapsies, grab(eval(call))) 35 | 36 | # check the results are identical 37 | expect_identical(r_out, tf_out) 38 | 39 | } 40 | 41 | reset_warnings <- function() { 42 | e <- tensorflow:::warned_about 43 | e$negative_indices <- FALSE 44 | e$tensors_passed_asis <- FALSE 45 | } 46 | 47 | 48 | # capture previous r-like extraction method, set to default, and return later 49 | # old_extract_method <- options("tensorflow.extract.one_based") 50 | # options(tensorflow.extract.one_based = NULL) 51 | # options(tensorflow.extract.style = 'R') 52 | 53 | 54 | # test indexing for unknown dimensions 55 | 56 | test_that('extract works for unknown dimensions', { 57 | 58 | skip_if_no_tensorflow() 59 | 60 | oopt <- options(tensorflow.extract.style = "R") 61 | 62 | # expected values with 5 rows 63 | x_vals <- matrix(rnorm(50), 5, 10) 64 | y1_exp <- as.array(x_vals[, 1]) 65 | y2_exp <- as.array(x_vals[, 1, drop = FALSE]) 66 | 67 | 68 | if(tf$executing_eagerly()) { 69 | 70 | t <- tf$convert_to_tensor(x_vals) 71 | y1_obs <- t[, 1] %>% as.array() 72 | y2_obs <- t[, 1, drop = FALSE] %>% as.array() 73 | 74 | } else { 75 | 76 | if (tf_version() >= "1.14") 77 | placeholder <- tf$compat$v1$placeholder 78 | else 79 | placeholder <- tf$placeholder 80 | 81 | # the output should retain the missing dimension 82 | x <- placeholder(tf$float64, shape(NULL, 10)) 83 | y1 <- x[, 1] 84 | y2 <- x[, 1, drop = FALSE] 85 | expect_identical(dim(y1), list(NULL)) 86 | expect_identical(dim(y2), list(NULL, 1L)) 87 | 88 | # get observed values for these 89 | if (tf_version() >= "1.14") 90 | sess <- 
tf$compat$v1$Session() 91 | else 92 | sess <- tf$Session() 93 | 94 | 
y1_obs <- sess$run(y1, 95 | feed_dict = dict(x = x_vals)) 96 | y2_obs <- sess$run(y2, 97 | feed_dict = dict(x = x_vals)) 98 | 99 | } 100 | 101 | expect_identical(y1_obs, y1_exp) 102 | expect_identical(y2_obs, y2_exp) 103 | 104 | options(oopt) 105 | }) 106 | 107 | test_that("scalar indexing works", { 108 | 109 | skip_if_no_tensorflow() 110 | oopt <- options(tensorflow.extract.style = "R") 111 | # set up arrays 112 | x1_ <- arr(3) 113 | x2_ <- arr(3, 3) 114 | x3_ <- arr(3, 3, 3) 115 | 116 | # cast to Tensors 117 | x1 <- tf$constant(x1_) 118 | x2 <- tf$constant(x2_) 119 | x3 <- tf$constant(x3_) 120 | 121 | # extract as arrays 122 | y1_ <- x1_[1] 123 | y2_ <- x2_[1, 2] 124 | y3_ <- x3_[1, 2, 3] 125 | 126 | # extract as Tensors 127 | y1 <- x1[1] 128 | y2 <- x2[1, 2] 129 | y3 <- x3[1, 2, 3] 130 | 131 | # they should be equivalent 132 | expect_equal(y1_, grab(y1)) 133 | expect_equal(y2_, grab(y2)) 134 | expect_equal(y3_, grab(y3)) 135 | 136 | options(oopt) 137 | }) 138 | 139 | # tests for 0-based indexing 140 | 141 | # options(tensorflow.extract.one_based = FALSE) 142 | 143 | test_that("vector indexing works", { 144 | skip_if_no_tensorflow() 145 | 146 | oopt <- options(tensorflow.extract.one_based = FALSE) 147 | # set up arrays 148 | x1_ <- arr(3) 149 | x2_ <- arr(3, 3) 150 | 151 | # cast to Tensors 152 | x1 <- tf$constant(x1_) 153 | x2 <- tf$constant(x2_) 154 | 155 | # extract as arrays 156 | y1_ <- x1_[2:3] 157 | y2_ <- x2_[2:3, 1] 158 | 159 | # extract as Tensors 160 | y1 <- x1[1:2] 161 | y2 <- x2[1:2, 0] 162 | 163 | # these should be equivalent (need to coerce R version back to arrays) 164 | expect_equal(y1_, grab(y1)) 165 | expect_equal(array(y2_), grab(y2)) 166 | 167 | options(oopt) 168 | }) 169 | 170 | test_that("blank indices retain all elements", { 171 | skip_if_no_tensorflow() 172 | 173 | oopt <- options(tensorflow.extract.one_based = FALSE) 174 | 175 | # set up arrays 176 | x1_ <- arr(3) 177 | x2_ <- arr(3, 3) 178 | x3_ <- arr(3, 3, 3) 179 | x4_ <- arr(3, 3, 3, 
3) 180 | 181 | # cast to Tensors 182 | x1 <- tf$constant(x1_) 183 | x2 <- tf$constant(x2_) 184 | x3 <- tf$constant(x3_) 185 | x4 <- tf$constant(x4_) 186 | 187 | # extract as arrays 188 | y1_ <- x1_[] 189 | y2_a <- x2_[2:3, ] 190 | y2_b <- x2_[, 1:2] 191 | y3_a <- x3_[2:3, 1, ] 192 | y3_b <- x3_[2:3, , 1] 193 | y4_ <- x4_[2:3, 1, , 2:3] 194 | 195 | # extract as Tensors 196 | y1 <- x1[] 197 | y2a <- x2[1:2, ] # j missing 198 | y2b <- x2[, 0:1] 199 | y3a <- x3[1:2, 0, ] 200 | y3b <- x3[1:2, , 0] 201 | y4 <- x4[1:2, 0, , 1:2] 202 | 203 | # these should be equivalent 204 | expect_equal(y1_, grab(y1)) 205 | expect_equal(y2_a, grab(y2a)) 206 | expect_equal(y2_b, grab(y2b)) # 207 | expect_equal(y3_a, grab(y3a)) 208 | expect_equal(y3_b, grab(y3b)) # 209 | expect_equal(y4_, grab(y4)) 210 | 211 | options(oopt) 212 | }) 213 | 214 | test_that("indexing works within functions", { 215 | skip_if_no_tensorflow() 216 | 217 | # tensorflow.extract.style = "python", 218 | oopt <- options(tensorflow.extract.one_based = FALSE) 219 | 220 | # set up arrays 221 | x1_ <- arr(3) 222 | x2_ <- arr(3, 3) 223 | x3_ <- arr(3, 3, 3) 224 | 225 | # cast to Tensors 226 | x1 <- tf$constant(x1_) 227 | x2 <- tf$constant(x2_) 228 | x3 <- tf$constant(x3_) 229 | 230 | # set up functions 231 | sub1 <- function (x, a) 232 | x[a - 1] 233 | sub2 <- function (x, a, b) 234 | x[a - 1, b - 1] 235 | sub3 <- function (x, b, c) 236 | x[, b - 1, c - 1] # skip first element 237 | 238 | # extract as arrays 239 | y1_ <- x1_[1:3] 240 | y2_ <- x2_[, 1:2] 241 | y3_a <- x3_[, 1:2, ] 242 | y3_b <- x3_[, , 1] 243 | 244 | # extract as Tensors 245 | y1 <- sub1(x1, 1:3) 246 | y2 <- sub2(x2, 1:3, 1:2) 247 | y3a <- sub3(x3, 1:2, 1:3) 248 | y3b <- sub3(x3, 1:3, 1) 249 | 250 | # these should be equivalent 251 | expect_equal(y1_, grab(y1)) 252 | expect_equal(y2_, grab(y2)) 253 | expect_equal(y3_a, grab(y3a)) 254 | expect_equal(y3_b, grab(y3b)) 255 | 256 | options(oopt) 257 | }) 258 | 259 | 260 | test_that("indexing works with 
variables", { 261 | skip_if_no_tensorflow() 262 | 263 | expect_ok <- function (expr) { 264 | expect_is(expr, "tensorflow.tensor") 265 | } 266 | 267 | # set up tensors 268 | x1 <- tf$constant(arr(3)) 269 | x2 <- tf$constant(arr(3, 3)) 270 | x3 <- tf$constant(arr(3, 3, 3)) 271 | 272 | # extract with index (these shouldn't error) 273 | index <- 2 274 | expect_ok(x1[index]) # i 275 | expect_ok(x2[, index]) # j 276 | expect_ok(x3[, , index]) # dots 277 | 278 | }) 279 | 280 | test_that("indexing with negative sequences errors", { 281 | skip_if_no_tensorflow() 282 | 283 | oopt <- options(tensorflow.extract.style = "R") 284 | # set up Tensors 285 | x1 <- tf$constant(arr(3)) 286 | x2 <- tf$constant(arr(3, 3)) 287 | 288 | # extract with negative indices (where : is not the top level call) 289 | expect_error(x1[-(1:2)], 'positive') 290 | expect_error(x2[-(1:2), ], 'positive') 291 | 292 | options(oopt) 293 | }) 294 | 295 | test_that("incorrect number of indices errors", { 296 | skip_if_no_tensorflow() 297 | 298 | # set up Tensor 299 | x <- tf$constant(arr(3, 3, 3)) 300 | # options(tensorflow.extract.one_based = TRUE) 301 | # too many 302 | expect_error(x[1:2, 2, 1:2, 3], 303 | 'Incorrect number of dimensions') 304 | expect_error(x[1:2, 2, 1:2, 3, , ], 305 | 'Incorrect number of dimensions') 306 | expect_error(x[1:2, 2, 1:2, 3, , drop = TRUE], 307 | 'Incorrect number of dimensions') 308 | # too few 309 | expect_warning(x[], 310 | 'Incorrect number of dimensions') 311 | expect_warning(x[1:2, ], 312 | 'Incorrect number of dimensions') 313 | expect_warning(x[1:2, 2], 314 | 'Incorrect number of dimensions') 315 | 316 | }) 317 | 318 | test_that("silly indices error", { 319 | skip_if_no_tensorflow() 320 | 321 | # set up Tensor 322 | x <- tf$constant(arr(3, 3, 3)) 323 | 324 | # these should all error and notify the user of the failing index 325 | expect_error(x[1:2, NA, 2], 'NA') 326 | expect_error(x[1:2, Inf, 2], 'Inf') 327 | expect_error(x[1:2, 'apple', 2], 'character') 328 | 
expect_error(x[1:2, mean, 2], 'function') 329 | }) 330 | 331 | test_that("passing non-vector indices errors", { 332 | skip_if_no_tensorflow() 333 | 334 | # set up Tensor 335 | x1 <- tf$constant(arr(3, 3)) 336 | x2 <- tf$constant(arr(3, 3, 3)) 337 | 338 | # block indices 339 | block_idx_1 <- rbind(c(1, 2), c(0, 1)) 340 | block_idx_2 <- rbind(c(1, 2, 1), c(0, 1, 2)) 341 | 342 | # indexing with matrices should fail 343 | expect_error(x1[block_idx_1], 344 | 'not currently supported') 345 | expect_error(x2[block_idx_2], 346 | 'not currently supported') 347 | 348 | }) 349 | 350 | # thanks to @dfalbel https://github.com/rstudio/tensorflow/issues/139 351 | # also check it returns the correct dimensions to R 352 | test_that("undefined extensions extract", { 353 | 354 | skip_if_no_tensorflow() 355 | oopt <- options(tensorflow.extract.style = 'python') 356 | 357 | x_ <- matrix(seq_len(3), ncol = 1) 358 | 359 | if(tf$executing_eagerly()) { 360 | 361 | t <- tf$convert_to_tensor(x_) 362 | result <- t[, 0L] %>% as.array() 363 | 364 | } else { 365 | 366 | if (tf_version() >= "1.14") 367 | placeholder <- tf$compat$v1$placeholder 368 | else 369 | placeholder <- tf$placeholder 370 | 371 | x <- placeholder(tf$int16, shape = list(NULL, 1L)) 372 | sub <- x[, 0L] 373 | 374 | if (tf_version() >= "1.14") 375 | sess <- tf$compat$v1$Session() 376 | else 377 | sess <- tf$Session() 378 | 379 | result <- sess$run(sub, dict(x = x_)) 380 | 381 | } 382 | 383 | expectation <- array(x_[, 1, drop = TRUE]) 384 | expect_equal(result, expectation) 385 | 386 | options(oopt) 387 | 388 | }) 389 | 390 | 391 | test_that("dim(), length(), nrow(), and ncol() work on tensors", { 392 | 393 | skip_if_no_tensorflow() 394 | 395 | a_matrix <- matrix(rnorm(100), ncol = 2) 396 | a_tensor <- tf$constant(a_matrix) 397 | expect_equal(dim(a_matrix), dim(a_tensor)) 398 | expect_equal(length(a_matrix), length(a_tensor)) 399 | expect_equal(nrow(a_matrix), nrow(a_tensor)) 400 | expect_equal(ncol(a_matrix), ncol(a_tensor)) 401 
| 402 | }) 403 | 404 | 405 | 406 | test_that("all_dims()", { 407 | 408 | skip_if_no_tensorflow() 409 | 410 | x1.r <- arr(3) 411 | x2.r <- arr(3, 3) 412 | x3.r <- arr(3, 3, 3) 413 | x4.r <- arr(3, 3, 3, 3) 414 | 415 | x1.t <- tf$constant(x1.r) 416 | x2.t <- tf$constant(x2.r) 417 | x3.t <- tf$constant(x3.r) 418 | x4.t <- tf$constant(x4.r) 419 | 420 | options(tensorflow.extract.one_based = TRUE) 421 | 422 | expect_equal(grab( x1.t[all_dims()] ), x1.r[] ) 423 | expect_equal(grab( x1.t[1, all_dims()] ), x1.r[1] ) 424 | expect_equal(grab( x1.t[all_dims(), 1] ), x1.r[1] ) 425 | 426 | # as.array() because tf returns 1d arrays, not bare atomic vectors 427 | expect_equal(grab( x2.t[all_dims()] ), as.array( x2.r[,] )) 428 | expect_equal(grab( x2.t[1, all_dims()] ), as.array( x2.r[1,] )) 429 | expect_equal(grab( x2.t[ all_dims(), 1] ), as.array( x2.r[,1] )) 430 | 431 | expect_equal(grab( x3.t[all_dims()] ), as.array( x3.r[,,] )) 432 | expect_equal(grab( x3.t[1, all_dims()] ), as.array( x3.r[1,,] )) 433 | expect_equal(grab( x3.t[1, 1, all_dims()] ), as.array( x3.r[1,1,] )) 434 | expect_equal(grab( x3.t[1, all_dims(), 1] ), as.array( x3.r[1,,1] )) 435 | expect_equal(grab( x3.t[all_dims(), 1] ), as.array( x3.r[,,1] )) 436 | expect_equal(grab( x3.t[all_dims(), 1, 1] ), as.array( x3.r[,1,1] )) 437 | 438 | expect_equal(grab( x4.t[all_dims()] ), as.array( x4.r[,,,] )) 439 | expect_equal(grab( x4.t[1, all_dims()] ), as.array( x4.r[1,,,] )) 440 | expect_equal(grab( x4.t[1, 1, all_dims()] ), as.array( x4.r[1,1,,] )) 441 | expect_equal(grab( x4.t[1, all_dims(), 1] ), as.array( x4.r[1,,,1] )) 442 | expect_equal(grab( x4.t[all_dims(), 1] ), as.array( x4.r[,,,1] )) 443 | expect_equal(grab( x4.t[all_dims(), 1, 1] ), as.array( x4.r[,,1,1] )) 444 | 445 | }) 446 | 447 | 448 | test_that("negative-integers work python style", { 449 | 450 | skip_if_no_tensorflow() 451 | options(tensorflow.extract.warn_negatives_pythonic = FALSE) 452 | # options(tensorflow.warn_negative_extract_is_python_style = 
FALSE) 453 | 454 | x1.r <- arr(4) 455 | x2.r <- arr(4, 4) 456 | 457 | x1.t <- tf$constant(x1.r) 458 | x2.t <- tf$constant(x2.r) 459 | 460 | options(tensorflow.extract.one_based = TRUE) 461 | expect_equal(grab( x1.t[-1] ), x1.r[4] ) 462 | expect_equal(grab( x1.t[-2] ), x1.r[3] ) 463 | expect_equal(grab( x2.t[-2, -2] ), x2.r[3, 3] ) 464 | expect_equal(grab( x2.t[-1, ] ), as.array( x2.r[4,] )) 465 | 466 | options(tensorflow.extract.one_based = FALSE) 467 | # same as above 468 | expect_equal(grab( x1.t[-1] ), x1.r[4] ) 469 | expect_equal(grab( x1.t[-2] ), x1.r[3] ) 470 | 471 | expect_equal(grab( x1.t[NULL:-2] ), x1.r[1:3] ) 472 | expect_equal(grab( x1.t[NULL:-1] ), x1.r[] ) 473 | 474 | expect_equal(grab( x2.t[-2, -2] ), x2.r[3, 3] ) 475 | expect_equal(grab( x2.t[-1, ] ), as.array( x2.r[4,] )) 476 | 477 | null_out_all_extract_opts() 478 | }) 479 | 480 | 481 | test_that("python-style strided slice", { 482 | 483 | skip_if_no_tensorflow() 484 | oopts <- options() 485 | options(tensorflow.extract.warn_negatives_pythonic = FALSE) 486 | 487 | x.r <- arr(20, 2) # 2nd dim to keep R from dropping (since tf always returns 1d array) 488 | x.t <- tf$constant(x.r) 489 | 490 | options(tensorflow.extract.style = "R") 491 | 492 | expect_equal(grab( x.t[ `5:` ,] ), x.r[ 5:20,]) 493 | expect_equal(grab( x.t[ `5:NULL` ,] ), x.r[ 5:20,]) 494 | expect_equal(grab( x.t[ 5:NULL ,] ), x.r[ 5:20,]) 495 | expect_equal(grab( x.t[ 5:NA ,] ), x.r[ 5:20,]) 496 | expect_equal(grab( x.t[ `5:NULL:` ,] ), x.r[ 5:20,]) 497 | expect_equal(grab( x.t[ 5:NULL:NULL ,] ), x.r[ 5:20,]) 498 | expect_equal(grab( x.t[ 5:NA:NA ,] ), x.r[ 5:20,]) 499 | expect_equal(grab( x.t[ 5:NA:NA_integer_ ,] ), x.r[ 5:20,]) 500 | expect_equal(grab( x.t[ 5:NA_real_:NA ,] ), x.r[ 5:20,]) 501 | expect_equal(grab( x.t[ `5:NULL:NULL` ,] ), x.r[ 5:20,]) 502 | 503 | expect_equal(grab( x.t[ `5::` ,] ), x.r[ 5:20,]) 504 | expect_equal(grab( x.t[ `:5:` ,] ), x.r[ 1:5,]) 505 | expect_equal(grab( x.t[ `:5` ,] ), x.r[ 1:5,]) 506 | 
expect_equal(grab( x.t[ `2:5` ,] ), x.r[ 2:5,]) 507 | expect_equal(grab( x.t[ 2:5 ,] ), x.r[ 2:5,]) 508 | 509 | expect_equal(grab( x.t[ `::2` ,] ), x.r[ seq.int(1, 20, by = 2) ,]) 510 | expect_equal(grab( x.t[ NULL:NULL:2 ,] ), x.r[ seq.int(1, 20, by = 2) ,]) 511 | 512 | # non syntantic names or function calls can work too 513 | `_idx` <- 1 514 | expect_equal(grab( x.t[ `_idx`:(identity(5)+1L),]), x.r[ 1:6, ] ) 515 | 516 | 517 | expect_equal(grab( x.t[ `2:6:2`,]), x.r[ seq.int(2, 6, 2) ,]) 518 | expect_equal(grab( x.t[ 2:6:2 ,]), x.r[ seq.int(2, 6, 2) ,]) 519 | 520 | 521 | # decreasing indexes work 522 | expect_equal(grab( x.t[ `6:2:-2`,]), x.r[ seq.int(6, 2, -2) ,]) 523 | expect_equal(grab( x.t[ 6:2:-2 ,]), x.r[ seq.int(6, 2, -2) ,]) 524 | 525 | # sign of step gets automatically inverted on decreasing indexes 526 | expect_equal(grab( x.t[ `6:2:2` ,]), x.r[ seq.int(6, 2, -2) ,]) 527 | expect_equal(grab( x.t[ 6:2:2 ,]), x.r[ seq.int(6, 2, -2) ,]) 528 | expect_equal(grab( x.t[ 6:2 ,]), x.r[ 6:2 ,]) 529 | expect_equal(grab( x.t[ 6:2:1 ,]), x.r[ 6:2 ,]) 530 | expect_equal(grab( x.t[ 6:2:-1 ,]), x.r[ 6:2 ,]) 531 | 532 | 533 | options(tensorflow.extract.style = "python") 534 | # options set to match python 535 | # helper to actually test in python 536 | test_in_python <- (function() { 537 | # main <- reticulate::import_main() 538 | reticulate::py_run_string(paste( 539 | "import numpy as np", 540 | "x = np.array(range(1, 41))", 541 | "x.shape = (2, 20)", 542 | "x = x.transpose()", sep = "\n")) 543 | function(chr) { 544 | reticulate::py_eval(chr) 545 | } 546 | })() 547 | 548 | 549 | expect_equal(grab( x.t[ 2:5,] ), test_in_python("x[2:5,]")) 550 | expect_equal(grab( x.t[ 2:-5 ,] ), test_in_python("x[ 2:-5 ,]")) 551 | expect_equal(grab( x.t[ 2:5:2 ,] ), test_in_python("x[ 2:5:2 ,]")) 552 | expect_equal(grab( x.t[ -2:-5:-1 ,] ), test_in_python("x[ -2:-5:-1 ,]")) 553 | expect_equal(grab( x.t[ 5:2:-1 ,] ), test_in_python("x[ 5:2:-1 ,]")) 554 | expect_equal(grab( x.t[ 5:2:-2 ,] 
), test_in_python("x[ 5:2:-2 ,]")) 555 | 556 | 557 | # indexing with tensors 558 | expect_equal(grab( x.t[tf$constant(2L),] ), as.array(x.r[3,])) 559 | expect_equal(grab( x.t[tf$constant(2L):tf$constant(5L),] ), x.r[3:5,]) 560 | 561 | # expect warning that no translation on tensors performed 562 | null_out_all_extract_opts() 563 | expect_warning(grab( x.t[tf$constant(2L),] ), "ignored") 564 | 565 | # warn only once 566 | expect_silent(grab( x.t[tf$constant(2L),] )) 567 | 568 | # warn in slice syntax too 569 | reset_warnings() 570 | null_out_all_extract_opts() 571 | expect_warning(grab( x.t[tf$constant(2L):tf$constant(5L),] ), "ignored") 572 | 573 | reset_warnings() 574 | options(tensorflow.extract.warn_tensors_passed_asis = FALSE) 575 | expect_silent(grab( x.t[tf$constant(2L):tf$constant(5L),] )) 576 | 577 | 578 | null_out_all_extract_opts() 579 | }) 580 | 581 | 582 | 583 | # test warnings for extraction that looks like it might be 0-based 584 | 585 | test_that('extract warns when indices look 0-based', { 586 | 587 | skip_if_no_tensorflow() 588 | oopts <- options() 589 | 590 | x <- tf$constant(matrix(0, 2, 2)) 591 | i0 <- 0:1 592 | i1 <- 1:2 593 | 594 | # explicit 0-indexing shouldn't warn 595 | options(tensorflow.extract.one_based = FALSE) 596 | expect_silent(x[i0, i0]) 597 | 598 | # explicit 1-indexing shouldn't warn 599 | options(tensorflow.extract.one_based = TRUE) 600 | # expect_silent(x[i0, i0]) # expect error 601 | 602 | # default 1-indexing should warn only if there's a zero in there 603 | options(tensorflow.extract.one_based = NULL) 604 | expect_silent(x[i1, i1]) 605 | # expect_warning(x[i0, i0], # expect error 606 | # "It looks like you might be using 0-based indexing") 607 | 608 | options(oopts) 609 | }) 610 | 611 | test_that('extract errors when indices have missing elements at variable steps', { 612 | 613 | skip_if_no_tensorflow() 614 | 615 | x <- tf$constant(array(0, dim = c(2, 4, 2))) 616 | 617 | # indexing with sequential values shouldn't error 618 
| expect_silent(x[1, c(1, 2, 3), ]) 619 | expect_error( x[1, c(1, 3, 4),]) 620 | 621 | }) 622 | 623 | 624 | # reset user's extract method 625 | # options(tensorflow.extract.one_based = old_extract_method) 626 | -------------------------------------------------------------------------------- /tests/testthat/test-generic-methods.R: -------------------------------------------------------------------------------- 1 | context("generic methods") 2 | 3 | 4 | test_that("log with supplied base works", { 5 | 6 | skip_if_no_tensorflow() 7 | 8 | r <- array(as.double(1:20)) 9 | t <- as_tensor(r, dtype = tf$float32) 10 | 11 | expect_near(r, grab(log(as_tensor(exp(r))))) 12 | expect_near(r, grab(log2(as_tensor(2 ^ r)))) 13 | expect_near(r, grab(log10(as_tensor(10 ^ r)))) 14 | 15 | expect_near(r, grab(log(exp(t)))) 16 | expect_near(r, grab(log2(2 ^ t))) 17 | expect_near(r, grab(log10(10 ^ t))) 18 | 19 | # log() dispatches correctly without trying to change base 20 | expect_identical(grab(tf$math$log(t)), grab(log(t))) 21 | 22 | expect_near(log(r), grab(log(t))) 23 | expect_near(log(r, base = 3), grab(log(t, base = 3))) 24 | 25 | }) 26 | 27 | 28 | 29 | test_generic <- function(fn, ..., namespace = "base") { 30 | name <- gsub("[\"']", "", deparse(substitute(fn))) 31 | if(!is.function(fn)) 32 | name <- fn 33 | test_that(paste("Generic", name, "works"), { 34 | skip_if_no_tensorflow() 35 | if(!is.function(fn)) 36 | fn <- get(name, envir = asNamespace(namespace), 37 | mode = 'function') 38 | 39 | suppress_warning_NaNs_produced({ 40 | out_r <- do.call(fn, list(...)) 41 | }) 42 | 43 | if(length(list(...)) == 1) { 44 | out_tf <- grab(fn(tf$constant(..1))) 45 | expect_equal(out_tf, out_r) 46 | return() 47 | } 48 | 49 | if(length(list(...)) == 2) { 50 | expect_equal(out_r, grab(fn(tf$constant(..1), ..2))) 51 | expect_equal(out_r, grab(fn(..1, tf$constant(..2)))) 52 | expect_equal(out_r, grab(fn(tf$constant(..1), tf$constant(..2)))) 53 | return() 54 | } 55 | 56 | stop("bad test call, only 
unary and binary S3 generics supported") 57 | 58 | }) 59 | } 60 | 61 | # --------- binary operators ---------------- 62 | 63 | binary_arith_generics <- c("+", "-", "*", "/", "^", "%%", "%/%") 64 | binary_compr_generics <- c("==", "!=", "<", "<=", ">", ">=") 65 | 66 | 67 | for (fn in c(binary_arith_generics, binary_compr_generics)) { 68 | test_generic(fn, rarr(1,2,3), rarr(1,2,3)) 69 | 70 | # test automatic type casting 71 | fn <- get(fn, envir = asNamespace("base")) 72 | expect_equal(fn(5L, 3), grab(fn(as_tensor(5L), 3))) 73 | expect_equal(fn(5L, 3), grab(fn(5L, as_tensor(3, "float64")))) 74 | 75 | expect_equal(fn(5, 3L), grab(fn(5, as_tensor(3L)))) 76 | expect_equal(fn(5, 3L), grab(fn(as_tensor(5, "float64"), 3L))) 77 | } 78 | 79 | 80 | if(getRversion() >= "4.3.0") { 81 | test_generic("%*%", rarr(3, 3), rarr(3, 3)) 82 | } 83 | 84 | expect_equal(as.numeric(as_tensor(3) ^ 2), 3^2) 85 | expect_equal(as.numeric(as_tensor(3, "float64") ^ .5), 3^.5) 86 | 87 | binary_logic_generics <- c("&", "|") 88 | 89 | x <- lapply(expand.grid(e1 = c(TRUE, FALSE), e2 = c(TRUE, FALSE)), 90 | as.array) 91 | x$e1.num <- x$e2.num <- 1:4 92 | x$e1.num[!x$e1] <- 0L 93 | x$e2.num[!x$e2] <- 0 94 | 95 | for (fn in binary_logic_generics) { 96 | test_generic(fn, x$e1, x$e2) 97 | 98 | # test automatic type casting 99 | fn <- get(fn, envir = asNamespace("base")) 100 | expect_equal(fn(x$e1.num, x$e2), grab(fn(x$e1.num, as_tensor(x$e2)))) 101 | expect_equal(fn(x$e1, x$e2.num), grab(fn(as_tensor(x$e1), x$e2.num))) 102 | } 103 | 104 | 105 | # ---------- unary operators --------------- 106 | unary_logic_generics <- c("!") 107 | 108 | for (fn in unary_logic_generics) 109 | test_generic(fn, x$e1) 110 | 111 | 112 | unary_shape_generics <- c("dim", "length") 113 | 114 | for (fn in unary_shape_generics) { 115 | test_generic(fn, arr(1)) 116 | test_generic(fn, arr(1, 2)) 117 | test_generic(fn, arr(1, 2, 3)) 118 | test_generic(fn, arr(3)) 119 | } 120 | 121 | expect_identical(dim(as_tensor(arr(3, 3))), c(3L, 
3L)) 122 | # Symbolic tensors traced by tf_function(): unknown dims appear as NA in dim(), an unknown rank gives a NULL dim(), and length() is NA in both cases 123 | f <- tf_function(function(x) { 124 | expect_identical(dim(x), NA_integer_) 125 | expect_identical(length(x), NA_integer_) 126 | x + 1 127 | }, input_signature = list(tf$TensorSpec(shape(NA)))) 128 | f(as_tensor(array(3), "float32")) 129 | 130 | f <- tf_function(function(x) { 131 | expect_identical(dim(x), c(NA_integer_, 1L, NA_integer_)) 132 | expect_identical(length(x), NA_integer_) 133 | x + 1 134 | }, input_signature = list(tf$TensorSpec(shape(NA, 1, NA)))) 135 | f(as_tensor(array(3, dim = c(1, 1, 1)), "float32")) 136 | 137 | 138 | f <- tf_function(function(x) { 139 | expect_identical(dim(x), NULL) 140 | expect_identical(length(x), NA_integer_) 141 | x + 1 142 | }, input_signature = list(tf$TensorSpec(shape(dims = NULL)))) 143 | f(as_tensor(array(3, dim = c(1, 1, 1)), "float32")) 144 | 145 | 146 | unary_math_generics <- c( 147 | 148 | "-", 149 | "+", 150 | 151 | "abs", 152 | "sign", 153 | "sqrt", 154 | "floor", 155 | "ceiling", 156 | "round", 157 | 158 | "log", 159 | "log1p", 160 | "log2", 161 | "log10", 162 | 163 | "exp", 164 | "expm1", 165 | 166 | "cos", 167 | "sin", 168 | "tan", 169 | 170 | "sinpi", 171 | "cospi", 172 | "tanpi", 173 | 174 | "acos", 175 | "asin", 176 | "atan", 177 | 178 | "lgamma", 179 | "digamma" 180 | ) 181 | 182 | for (fn in unary_math_generics) { 183 | test_generic(fn, arr(20)) 184 | test_generic(fn, rarr(20)) 185 | } 186 | 187 | 188 | unary_complex_generics <- c("Re", "Im", "Conj", "Arg", "Mod") 189 | 190 | for (fn in unary_complex_generics) 191 | test_generic(fn, 1 + 2i) 192 | 193 | 194 | 195 | numeric_reduce_generics <- 196 | list(sum, prod, min, max, mean, range) 197 | 198 | 199 | x <- arr(3, 4) 200 | xt <- as_tensor(x) 201 | 202 | for (fn in numeric_reduce_generics) 203 | expect_equal(fn(x), as.numeric(fn(as_tensor(x)))) 204 | 205 | for (fn in list(sum, prod, min, max, range)) # not mean: mean()'s 2nd argument is `trim`, not more data 206 | expect_equal(fn(x, x), as.numeric(fn(as_tensor(x), as_tensor(x)))) 207 | 208 | for (fn in list(sum, prod, min, max, mean)) { # not range
209 | expect_equal(dim(fn(xt, axis = 1)), 4L) 210 | expect_equal(dim(fn(xt, axis = 2)), 3L) 211 | expect_equal(dim(fn(xt, axis = 1, keepdims = TRUE)), c(1L, 4L)) 212 | expect_equal(dim(fn(xt, axis = 2, keepdims = TRUE)), c(3L, 1L)) 213 | } 214 | 215 | 216 | bool_reduce_generics <- list(all, any) 217 | for (fn in bool_reduce_generics) { 218 | tt <- rep(TRUE, 5) 219 | ff <- rep(FALSE, 5) 220 | mx <- rep(c(TRUE, FALSE), 4) 221 | for (x in list(tt, ff, mx)) { 222 | expect_equal(fn(x), as.logical(fn(as_tensor(x)))) 223 | expect_equal(fn(x, x), as.logical(fn(as_tensor(x), as_tensor(x)))) 224 | expect_equal(fn(x, x), as.logical(fn(as_tensor(x), x))) 225 | } 226 | } 227 | # Check that a bind-style `fn` (cbind/rbind) returns the same array (dimnames dropped) for plain R inputs as when all inputs, only the 1st, or only the 2nd input is a tensor 228 | expect_equivalent_bind_generic <- function(fn, ...) { 229 | res1 <- fn(...) 230 | dimnames(res1) <- NULL 231 | res2 <- as.array(do.call(fn, lapply(list(...), as_tensor))) 232 | if (is_windows()) # https://github.com/rstudio/reticulate/issues/1071 233 | storage.mode(res2) <- "integer" 234 | expect_identical(res1, res2) 235 | 236 | dots <- list(...) 237 | dots[[1L]] <- as_tensor(..1) 238 | res3 <- as.array(do.call(fn, dots)) 239 | if (is_windows()) # https://github.com/rstudio/reticulate/issues/1071 240 | storage.mode(res3) <- "integer" 241 | expect_identical(res1, res3) 242 | 243 | dots <- list(...)
244 | dots[[2L]] <- as_tensor(..2) 245 | res4 <- as.array(do.call(fn, dots)) 246 | if (is_windows()) # https://github.com/rstudio/reticulate/issues/1071 247 | storage.mode(res4) <- "integer" 248 | expect_identical(res1, res4) 249 | } 250 | 251 | m <- matrix(1:9, nrow = 3) 252 | v <- 1:3 253 | v1 <- as.array(1:3) 254 | for (fn in list(cbind, rbind)) { 255 | expect_equivalent_bind_generic(fn, v, v, v) 256 | expect_equivalent_bind_generic(fn, v1, v, m) 257 | expect_equivalent_bind_generic(fn, m, v, v) 258 | expect_equivalent_bind_generic(fn, m, m) 259 | expect_equivalent_bind_generic(fn, m, v) 260 | expect_equivalent_bind_generic(fn, 1L, 1L) 261 | expect_equivalent_bind_generic(fn, 1L, as.matrix(1L)) 262 | expect_equivalent_bind_generic(fn, as.array(1L), 1L) 263 | expect_equal(fn(as_tensor(1:3), 1:3, dtype = "int64")$dtype$name, "int64") 264 | expect_equal(fn(as_tensor(1:3), 1:3, dtype = "int16")$dtype$name, "int16") 265 | expect_equal(fn(as_tensor(1:3), 1:3, dtype = "float32")$dtype$name, "float32") 266 | } 267 | 268 | 269 | test_generic("t", 1) 270 | test_generic("t", array(1)) 271 | test_generic("t", matrix(1)) 272 | test_generic("t", 1:3) 273 | test_generic("t", array(1:3)) 274 | test_generic("t", matrix(1:3)) 275 | test_generic("t", m) 276 | 277 | test_generic("aperm", array(1)) 278 | test_generic("aperm", matrix(1)) 279 | test_generic("aperm", array(1:3)) 280 | test_generic("aperm", matrix(1:3)) 281 | test_generic("aperm", m) 282 | 283 | a <- arr(3, 4, 5) 284 | r1 <- aperm(a, c(2, 1, 3)) 285 | r2 <- as.array(aperm(as_tensor(a), c(2, 1, 3))) 286 | expect_identical(r1, r2) 287 | 288 | 289 | x <- array(c(0, 1, Inf, NaN)) 290 | test_generic("is.finite", x) 291 | test_generic("is.infinite", x) 292 | test_generic("is.nan", x) 293 | 294 | x <- array(c(2L, 10L, 3L, 1L, 7L, 4L, 6L, 8L, 9L, 5L)) 295 | test_generic("sort", x) 296 | decreasing_sort <- function(x) sort(x, decreasing = TRUE) 297 | test_generic(decreasing_sort, x) 298 | 299 | 300 | xx <- list(array(1:3), 301 | 1) 302
| 303 | for (x in xx) { 304 | test_generic(function(a) as.array(rep(a, 3)), x) 305 | test_generic(function(a) as.array(rep(as_tensor(a), as_tensor(3L))), x) 306 | 307 | test_generic("as.vector", x) 308 | } 309 | 310 | 311 | test_that("generics can handle tensors w/ convert=FALSE", { 312 | 313 | skip_if_no_tensorflow() 314 | 315 | # this tests that `*` dispatches correctly even if both x and y provide Ops methods (gated on R >= 4.3.0, presumably because that release added chooseOpsMethod()) 316 | if (getRversion() >= "4.3.0") { 317 | x <- tf$ones(shape(5, 5)) * r_to_py(array(1, dim = c(5, 5))) 318 | expect_true(as.logical(all(x == 1))) 319 | } 320 | 321 | # test that as.array / as.raster work even with convert=FALSE 322 | img <- tf$cast(tf$random$uniform(shape(256, 256, 4), maxval = 256), 323 | "uint8") 324 | x <- np_array(array(c(2, 1, 1, 1), dim = c(1, 1, 4)), dtype = "uint8") # convert=FALSE 325 | 326 | expect_no_error(as.raster(img)) 327 | expect_no_error(as.raster(r_to_py(img))) 328 | 329 | if (getRversion() >= "4.3.0") { 330 | expect_no_error(as.raster(img %/% x)) 331 | expect_no_error(as.raster(r_to_py(img %/% x))) 332 | expect_no_error(as.raster(r_to_py(r_to_py(img) %/% x))) 333 | expect_no_error(as.raster(r_to_py(img %/% r_to_py(x)))) 334 | } 335 | 336 | }) 337 | 338 | 339 | -------------------------------------------------------------------------------- /tests/testthat/test-seed.R: -------------------------------------------------------------------------------- 1 | # skip("use_session_with seed doesn't work with TF >= 2.3") 2 | if (!is.null(tf_version()) && tf_version() < "2.3") 3 | test_that("use_session_with_seed works", { 4 | skip_if_no_tensorflow() 5 | skip_if_not_installed("keras3") 6 | skip_if_not_installed("callr") 7 | f <- function() { 8 | library(keras3) 9 | use_session_with_seed(seed = 1) 10 | model <- keras_model_sequential() %>% 11 | layer_dense(units = 1) 12 | predict(model, matrix(1, ncol = 1)) 13 | } 14 | 15 | run1 <- callr::r(f) 16 | run2 <- callr::r(f) 17 | 18 | expect_equal(run1, run2) 19 | }) 20 | 21 | test_that("set_random_seed", { 22 | 23 | skip_if_no_tensorflow() 24 | skip_if_not_installed("keras3") 25 | if (tf_version() < "2.0")
26 | skip("set_random_seed only works for TF >= 2.0") 27 | skip_if_not_installed("callr") 28 | f <- function() { 29 | library(keras3) 30 | tensorflow::set_random_seed(seed = 1) 31 | model <- keras_model_sequential() %>% 32 | layer_dense(units = 1) 33 | predict(model, matrix(1, ncol = 1)) 34 | } 35 | 36 | run1 <- callr::r(f) 37 | run2 <- callr::r(f) 38 | 39 | expect_equal(run1, run2) 40 | }) 41 | -------------------------------------------------------------------------------- /tests/testthat/test-shape.R: -------------------------------------------------------------------------------- 1 | test_that("shape() works", { 2 | 3 | 4 | skip_if_no_tensorflow() 5 | 6 | # Expect `x` to be a TensorShape; when `dims` is supplied, also check its dims (NULL means unknown rank) 7 | expect_tensor_shape <- function(x, dims) { 8 | expect_s3_class(x, "tensorflow.python.framework.tensor_shape.TensorShape") 9 | if (missing(dims)) 10 | return() 11 | 12 | if (is.null(dims)) { 13 | expect_null(x$rank) 14 | expect_null(x$dims) 15 | } else { 16 | expect_identical(as.list(x$as_list()), dims) 17 | } 18 | } 19 | 20 | x <- shape(1, NA, 2, NULL, 3) 21 | expect_tensor_shape(x, list(1L, NULL, 2L, NULL, 3L)) 22 | 23 | expect_identical(as.list(x), list(1L, NULL, 2L, NULL, 3L)) 24 | 25 | expect_true(shape() == tf$TensorShape(list())) 26 | expect_true(shape(dims = NULL) == tf$TensorShape(NULL)) 27 | 28 | 29 | # --- construct --- 30 | expect_tensor_shape(shape() , list()) 31 | expect_tensor_shape(shape(NULL) , list(NULL)) 32 | expect_tensor_shape(shape(dims = NULL) , NULL) 33 | expect_tensor_shape(shape(3, 4) , list(3L, 4L)) 34 | expect_tensor_shape(shape(NA, 4) , list(NULL, 4L)) 35 | expect_tensor_shape(shape(dims = c(NA, 4)) , list(NULL, 4L)) 36 | 37 | # --- inspect --- 38 | expect_identical(length(shape(dims = NULL)), NA_integer_) 39 | expect_identical(length(shape(1, 2, 3, NA)), 4L) 40 | 41 | 42 | # --- convert --- 43 | x <- shape(dims = list(3L, 5L)) 44 | expect_identical(as.list(x) , list(3L, 5L)) 45 | expect_identical(as.integer(x), c(3L, 5L)) 46 | expect_identical(as.numeric(x), c(3, 5)) 47 | expect_identical(as.double(x) ,
c(3, 5)) 48 | 49 | x <- shape(NA, 3) 50 | expect_identical(as.list(x), list(NULL, 3L)) 51 | expect_identical(as.integer(x), c(NA, 3L)) 52 | expect_identical(as.double(x), c(NA, 3)) 53 | 54 | x2 <- as_tensor(shape(NA, 3)) 55 | expect_equal(x2$numpy(), array(c(-1L, 3L))) 56 | expect_identical(x2$dtype$name, "int32") 57 | expect_identical(as.list(x2$shape), list(2L)) 58 | 59 | x <- shape(dims = NULL) 60 | expect_error(as.list(x)) 61 | expect_error(as.numeric(x)) 62 | expect_error(as_tensor(x)) 63 | 64 | x <- shape(NA, 3) 65 | # as_tensor() converts undefined dims to -1 66 | expect_identical(as.integer(as_tensor(x)), c(-1L, 3L)) 67 | # shapes can round-trip: shape -> tensor -> shape 68 | expect_tensor_shape(shape(dims = as_tensor(x)), list(NULL, 3L)) 69 | 70 | 71 | # --- compare --- 72 | # Fully known shapes return TRUE if and only if each element is equal 73 | expect_true(shape(3, 4) == shape(3, 4)) # TRUE 74 | expect_false(shape(3, 4) == shape(4, 4)) # FALSE 75 | 76 | # Identical partially-known shapes compare equal in TF >= 2.9; earlier versions always return FALSE 77 | if (tf_version() >= "2.9") 78 | expect_true(shape(NA, 4) == shape(NA, 4)) 79 | else 80 | expect_false(shape(NA, 4) == shape(NA, 4)) 81 | 82 | expect_false(shape(NA, 4) == shape(3, 4)) 83 | 84 | # Two unknown shapes compare equal (TRUE) 85 | expect_true(shape(dims = NULL) == shape(dims = NULL)) 86 | 87 | # Comparing an unknown shape to a partially or fully defined shape returns FALSE 88 | expect_false(shape(dims = NULL) == shape(NULL)) 89 | expect_false(shape(dims = NULL) == shape(4)) 90 | 91 | if (tf_version() < "2.9") { 92 | # in TF >= 2.9, != is just the negation of == 93 | # prior versions: != is mostly the inverse of ==, with one difference: 94 | # it raises an error when an unknown-rank shape is compared 95 | expect_error(shape(dims = NULL) != shape(dims = NULL)) # ValueError: The inequality of unknown TensorShapes is undefined. 96 | expect_error(shape(dims = NULL) != shape()) # ValueError: The inequality of unknown TensorShapes is undefined.
97 | } 98 | 99 | 100 | # --- extract or replace --- 101 | # regular R-list semantics for `[`, `[[`, `[<-`, `[[<-` 102 | x <- shape(1, 2, 3) 103 | expect_tensor_shape(x[1], list(1L)) 104 | expect_identical(x[[1]], 1L) 105 | 106 | x_slice <- x[2:3] 107 | expect_tensor_shape(x_slice, list(2L, 3L)) 108 | expect_true(x_slice == c(2, 3)) 109 | expect_true(x_slice == x[-1]) 110 | 111 | x <- shape(1, 2, 3) 112 | x[1] <- 11 113 | expect_tensor_shape(x, list(11L, 2L, 3L)) 114 | expect_true(x == c(11, 2, 3)) 115 | 116 | x[1] <- shape(22) 117 | expect_tensor_shape(x, list(22L, 2L, 3L)) 118 | expect_true(x == c(22, 2, 3)) 119 | 120 | x[1] <- list(33) 121 | expect_tensor_shape(x, list(33L, 2L, 3L)) 122 | expect_true(x == c(33, 2, 3)) 123 | 124 | x[[1]] <- 44 125 | expect_true(x == c(44, 2, 3)) 126 | x[1:2] <- c(NA, 99) 127 | expect_identical(as.numeric(x), c(NA, 99, 3)) 128 | x[1:2] <- shape(33, 44) 129 | expect_tensor_shape(x, list(33L, 44L, 3L)) 130 | expect_identical(as.numeric(x), c(33, 44, 3)) 131 | 132 | # --- concatenate --- 133 | x <- 134 | c(shape(1), shape(2, 3), shape(4, NA)) # TensorShape([1, 2, 3, 4, None]) 135 | expect_identical(as.list(x), list(1L, 2L, 3L, 4L, NULL)) 136 | 137 | # --- merge --- 138 | x <- merge(shape(NA, 2), 139 | shape(1 , 2)) # TensorShape([1, 2]) 140 | expect_tensor_shape(x, list(1L, 2L)) 141 | expect_true(x == c(1, 2)) 142 | 143 | expect_error(merge(shape(2, 2), 144 | shape(1, 2))) # ValueError: Shapes (2, 2) and (1, 2) are not compatible 145 | 146 | # --- print / format --- 147 | expect_output(print(shape(3)), "TensorShape([3])", fixed = TRUE) 148 | expect_output(print(shape(3, NA)), "TensorShape([3, None])", fixed = TRUE) 149 | expect_output(print(shape(3, NULL)), "TensorShape([3, None])", fixed = TRUE) 150 | 151 | expect_equal(format(shape(3)), "(3)") 152 | expect_equal(format(shape(3, NA)), "(3, NA)") 153 | expect_equal(format(shape(3, NULL)), "(3, NA)") 154 | 155 | # shape() accepts tf.TensorShapes as arguments and flattens them 156 | expect_equal(as.list(shape(shape(3))),
list(3L)) 157 | expect_equal(as.list(shape(shape(3, 4))), list(3L, 4L)) 158 | expect_equal(as.list(shape(shape(3, 4), 5)), list(3L, 4L, 5L)) 159 | expect_equal(as.list(shape(NA, shape(3, 4), 5)), list(NULL, 3L, 4L, 5L)) 160 | 161 | }) 162 | -------------------------------------------------------------------------------- /tests/testthat/test-types.R: -------------------------------------------------------------------------------- 1 | 2 | test_that("TensorShapes are not converted to lists", { 3 | 4 | skip_if_no_tensorflow() 5 | 6 | x <- tf$constant(10, shape = shape(5, 10)) 7 | expect_true(x$shape[2] == shape(10)) 8 | expect_identical(x$shape[[2]], 10L) 9 | 10 | y <- tf$TensorShape(shape(5L, 10L)) 11 | expect_s3_class(y, "tensorflow.python.framework.tensor_shape.TensorShape") 12 | expect_true(y[2] == shape(10)) 13 | expect_identical(y[[2]], 10L) 14 | }) 15 | 16 | 17 | test_that("tf.random works", { 18 | skip_if_no_tensorflow() 19 | x <- tf$random$stateless_uniform( # this call errors when tensorflow-metal 1.2.0 is installed 20 | shape = tuple(10L), 21 | seed = tuple(2L, 3L), 22 | minval = 0L, 23 | maxval = 10L, 24 | dtype = tf$dtypes$int32) 25 | expect_s3_class(x, "tensorflow.tensor") 26 | x <- as.array(x) 27 | if (!is_windows()) 28 | expect_type(x, "integer") 29 | expect_identical(dim(x), 10L) 30 | }) 31 | --------------------------------------------------------------------------------