├── .github └── workflows │ └── run_tests.yml ├── .gitignore ├── .gitmodules ├── .scalafmt.conf ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── build.sbt ├── environment.yml ├── project ├── Build.scala ├── build.properties └── plugins.sbt ├── requirements-dev.txt ├── scala-torch ├── build.sbt ├── project │ └── build.properties └── src │ ├── main │ └── scala │ │ └── com │ │ └── microsoft │ │ └── scalatorch │ │ └── torch │ │ ├── Device.scala │ │ ├── Generator.scala │ │ ├── Model.scala │ │ ├── Reduction.scala │ │ ├── ReferenceManager.scala │ │ ├── Scalar.scala │ │ ├── Size.scala │ │ ├── Tensor.scala │ │ ├── Tensor.scala.in │ │ ├── TensorOptions.scala │ │ ├── dtype.scala │ │ ├── fft │ │ ├── package.scala │ │ └── package.scala.in │ │ ├── jit │ │ ├── CompilationUnit.scala │ │ ├── IValue.scala │ │ ├── Module.scala │ │ ├── OpSymbol.scala │ │ └── Type.scala │ │ ├── linalg │ │ ├── package.scala │ │ └── package.scala.in │ │ ├── nn │ │ ├── functional │ │ │ ├── package.scala │ │ │ └── package.scala.in │ │ └── init │ │ │ └── package.scala │ │ ├── optim │ │ └── Optimizer.scala │ │ ├── package.scala │ │ ├── package.scala.in │ │ ├── special │ │ ├── package.scala │ │ └── package.scala.in │ │ ├── syntax.scala │ │ └── util │ │ ├── Disposer.scala │ │ ├── Implicits.scala │ │ ├── NoGrad.scala │ │ └── Profiler.scala │ └── test │ ├── resources │ └── com │ │ └── microsoft │ │ └── scalatorch │ │ └── torch │ │ └── jit │ │ ├── simple_trace.py │ │ └── traced_model.pt │ └── scala │ └── com │ └── microsoft │ └── scalatorch │ └── torch │ ├── TensorTest.scala │ ├── TorchFunctionTest.scala │ ├── jit │ └── ModuleTest.scala │ └── tutorial │ ├── PyTorchOrgTensorTutorialTest.scala │ ├── PytorchOrgPolynomialNetwork.scala │ ├── PytorchOrgPolynomialNetworkOptim.scala │ └── TutorialTest.scala └── swig ├── build.sbt ├── project └── plugins.sbt └── src ├── main ├── java │ └── com │ │ └── microsoft │ │ └── scalatorch │ │ └── torch │ │ └── internal │ │ ├── Module2.java │ │ ├── Module3.java │ │ ├── Module4.java │ │ └── NativeLoader.java └── swig │ ├── bindgen.py │ ├── generated_bindings.i │ ├── generated_tensor_bindings.i │ ├── torch_array.i │ ├── torch_array_ref.i │ ├── torch_data.i │ ├── torch_equals_hashcode.i │ ├── torch_generator_swig.i │ ├── torch_indexing.i │ ├── torch_init.i │ ├── torch_ir.i │ ├── torch_jit_type.i │ ├── torch_list.h │ ├── torch_optim_swig.i │ ├── torch_optional.i │ ├── torch_primitives.i │ ├── torch_profiler.i │ ├── torch_reduction.i │ ├── torch_scalar.i │ ├── torch_script_swig.i │ ├── torch_serialize_swig.i │ ├── torch_std_array.i │ ├── torch_string_view.i │ ├── torch_swig.i │ ├── torch_tensor.i │ ├── torch_tensor_list.i │ ├── torch_variant_enum.i │ └── tuple.i └── native └── CMakeLists.txt /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Compile and Run Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | TORCH_VERSION: 1.10.2 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-22.04 15 | steps: 16 | - uses: actions/checkout@v2 17 | with: 18 | submodules: recursive 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: '3.10' 23 | - name: Setup JDK 24 | uses: actions/setup-java@v3 25 | with: 26 | distribution: temurin 27 | java-version: 17 28 | - name: Install Python dependencies 29 | run: pip install -r requirements-dev.txt 30 | - name: Dowload libtorch 31 | run: | 32 | set -e 33 | curl 
https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip -o libtorch.zip 34 | unzip libtorch.zip 35 | rm libtorch.zip 36 | - name: Generate Declarations.yaml 37 | run: cd pytorch && python -m tools.codegen.gen -s aten/src/ATen -d torch/share/ATen 38 | - name: Run tests 39 | run: LD_LIBRARY_PATH=$PWD/libtorch/lib:$LD_LIBRARY_PATH sbt test 40 | 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # cmake stuff 2 | Testing/ 3 | dynet/Testing/ 4 | dynet/tests.bin/ 5 | CTestTestfile.cmake 6 | config.h 7 | Makefile 8 | CMakeCache.txt 9 | CMakeFiles 10 | cmake_install.cmake 11 | python/dynet.cpp 12 | python/dist/ 13 | python/dyNET.egg-info/ 14 | 15 | # binaries 16 | build/ 17 | Debug/ 18 | *.model 19 | 20 | #data 21 | rnnlm/ptb-mikolov/ 22 | 23 | # Python temporary files 24 | *.pyc 25 | 26 | # Compiled Object files 27 | *.slo 28 | *.lo 29 | *.o 30 | *.obj 31 | 32 | # Precompiled Headers 33 | *.gch 34 | *.pch 35 | 36 | # Compiled Dynamic libraries 37 | *.so 38 | *.dylib 39 | *.dll 40 | *.jnilib 41 | 42 | # Fortran module files 43 | *.mod 44 | 45 | # Compiled Static libraries 46 | *.lai 47 | *.la 48 | *.a 49 | *.lib 50 | 51 | # Executables 52 | *.exe 53 | *.out 54 | *.app 55 | *.log 56 | 57 | # Editor stuff 58 | *.swp 59 | .vscode 60 | 61 | # Doc stuff 62 | doc/doxygen/xml 63 | doc/source/tutorials_notebooks 64 | 65 | # sbt stuff 66 | target 67 | .idea 68 | .bsp 69 | 70 | .RData 71 | 72 | # platform specific 73 | .DS_Store 74 | 75 | # ignore libtorch. we assume you've put it here. 76 | libtorch/ 77 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorch"] 2 | path = pytorch 3 | url = https://github.com/pytorch/pytorch.git 4 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version = "2.7.4" 2 | maxColumn = 120 3 | rewrite.rules = [PreferCurlyFors, SortModifiers, AvoidInfix, SortImports] 4 | trailingCommas = always 5 | assumeStandardLibraryStripMargin = true 6 | optIn.breakChainOnFirstMethodDot = false 7 | spaces.inImportCurlyBraces = true 8 | includeCurlyBraceInSelectChains = true 9 | spaces.beforeContextBoundColon = Always 10 | newlines.implicitParamListModifierPrefer = before 11 | newlines.beforeCurlyLambdaParams = multilineWithCaseOnly 12 | docstrings.style = SpaceAsterisk 13 | // These are copied from https://scalameta.org/scalafmt/docs/configuration.html#other 14 | rewrite.neverInfix.excludeFilters = [ 15 | until 16 | to 17 | by 18 | eq 19 | ne 20 | "should.*" 21 | "contain.*" 22 | "must.*" 23 | in 24 | ignore 25 | be 26 | taggedAs 27 | thrownBy 28 | synchronized 29 | have 30 | when 31 | size 32 | only 33 | noneOf 34 | oneElementOf 35 | noElementsOf 36 | atLeastOneElementOf 37 | atMostOneElementOf 38 | allElementsOf 39 | inOrderElementsOf 40 | theSameElementsAs 41 | // these are new 42 | throws 43 | returns 44 | satisfy 45 | ] 46 | 47 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of 
Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 4 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 5 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 6 | 7 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 8 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 9 | provided by the bot. You will only need to do this once across all repos using our CLA. 10 | 11 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 12 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 13 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scala-torch 2 | JVM/Scala wrappers for LibTorch. 3 | 4 | ## State of this project 5 | 6 | This project is mature enough to be used regularly in production code. The API exposed is fairly clean 7 | and tries to follow PyTorch syntax as much as possible. The API is a mix of hand-written wrappings and a wrapper 8 | around most of `Declarations.yaml`. 
9 | 10 | That said, some internal documentation is not quite ready for public consumption yet, though there is enough 11 | documentation that people who are already familiar with Scala and LibTorch can probably figure out what's going on. 12 | Code generation is accomplished through a combination of [Swig](https://www.swig.org) and a quick-and-dirty 13 | [Python script](swig/src/main/swig/bindgen.py) that reads in `Declarations.yaml`, which provides a language-independent 14 | API for a large part of LibTorch. This file is [deprecated](https://github.com/pytorch/pytorch/issues/69471) and in the 15 | future, we can hopefully replace `bindgen.py` using the forthcoming [torchgen](https://github.com/pytorch/pytorch/issues/69471#issuecomment-1273642655) 16 | tool provided by PyTorch. 17 | 18 | One major annoyance with Scala in particular is that you cannot define multiple overloads of a method that take default 19 | arguments. Currently, `bindgen.py` uses any defaults present in only the first overload found in `Declarations.yaml`. 20 | In some cases, clever use of Scala's implicit conversions can hide these headaches, but currently, you occasionaly have to write 21 | out the defaults where you would not have to in Python. One potential future option is to give overloads 22 | different names, but we elected not to do that (yet). 23 | 24 | We have not yet published JARs for this project. These are coming soon. 25 | 26 | ## Short tour 27 | 28 | Scala-torch exposes an API that tries to mirror PyTorch as much as Scala syntax 29 | allows. For example, taking some snippets from 30 | [this tutorial](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html): 31 | 32 | PyTorch: 33 | ```python 34 | import torch 35 | 36 | data = [[1, 2],[3, 4]] 37 | x_data = torch.tensor(data) 38 | ``` 39 | 40 | Scala-Torch: 41 | ```scala 42 | import com.microsoft.scalatorch.torch 43 | import com.microsoft.scalatorch.torch.syntax._ 44 | 45 | torch.ReferenceManager.forBlock { implicit rm => 46 | val data = $($(1, 2), $(3, 4)) 47 | val x_data = torch.tensor(data) 48 | } 49 | ``` 50 | 51 | 52 | PyTorch: 53 | ```python 54 | tensor = torch.ones(4, 4) 55 | print(f"First row: {tensor[0]}") 56 | print(f"First column: {tensor[:, 0]}") 57 | print(f"Last column: {tensor[..., -1]}") 58 | tensor[:,1] = 0 59 | print(tensor) 60 | ``` 61 | 62 | Scala-Torch: 63 | ```scala 64 | val tensor = torch.ones($(4, 4)) 65 | println(s"First row: ${tensor(0)}") 66 | println(s"First column: ${tensor(::, 0)}") 67 | println(s"Last column: ${tensor(---, -1)}") 68 | tensor(::, 1) = 0 69 | println(tensor) 70 | ``` 71 | 72 | See [this file](scala-torch/src/test/scala/com/microsoft/scalatorch/torch/tutorial/PyTorchOrgTensorTutorialTest.scala) for 73 | a complete translation of the PyTorch tutorial into Scala-Torch. 74 | 75 | ### Memory management 76 | 77 | One big difference between Scala-Torch and PyTorch is in memory management. Because Python and LibTorch both use 78 | reference counting, memory management is fairly transparent to users. However, since the JVM uses garbage collection 79 | and [finalizers are not guaranteed to run](https://docs.oracle.com/javase/9/docs/api/java/lang/Object.html#finalize--), 80 | it is not easy to make memory management transparent to the user. 
Scala-Torch elects to make memory management something 81 | the user must control by providing [ReferenceManager](scala-torch/src/main/scala/com/microsoft/scalatorch/torch/ReferenceManager.scala)s 82 | that define the lifetime of any LibTorch-allocated object 83 | that is added to it. All Scala-Torch methods that allocate objects from LibTorch take an `implicit` `ReferenceManager`, 84 | so it is the responsibility of the caller to make sure there is a `ReferenceManager` in `implicit` scope (or passed 85 | explicitly) and that that `ReferenceManager` will be `close()`ed when appropriate. See documentation and uses 86 | of `ReferenceManager` for more examples. 87 | 88 | ## Handling of native dependencies 89 | 90 | PyTorch provides pre-built binaries for the native code backing it [here](https://pytorch.org/get-started/locally/). 91 | We make use of the pre-built dynamic libraries by packaging them up in a jar, much like [TensorFlow Scala](http://platanios.org/tensorflow_scala/installation.html). 92 | Downstream 93 | projects have two options for handling the native dependencies: they can either 94 | 1. Declare a dependency on the packaged native dependencies wrapped up with a jar using 95 | ```scala 96 | val osClassifier = System.getProperty("os.name").toLowerCase match { 97 | case os if os.contains("mac") || os.contains("darwin") => "darwin" 98 | case os if os.contains("linux") => "linux" 99 | case os if os.contains("windows") => "windows" 100 | case os => throw new sbt.MessageOnlyException(s"The OS $os is not a supported platform.") 101 | } 102 | libraryDependencies += ("com.microsoft.scalatorch" % "libtorch-jar" % "1.10.0").classifier(osClassifier + "_cpu") 103 | ``` 104 | 2. Ensure that the libtorch dependencies are installed in the OS-dependent way, for example, in `/usr/lib` or in `LD_LIBRARY_PATH` on Linux, 105 | or in `PATH` on Windows. Note that on recent versions of MacOS, [System Integrity Protection](https://developer.apple.com/library/archive/documentation/Security/Conceptual/System_Integrity_Protection_Guide/RuntimeProtections/RuntimeProtections.html) 106 | resets `LD_LIBRARY_PATH` and `DYLD_LIBRARY_PATH` when launching new processes, so it is very hard to use that approach on MacOS. 107 | 108 | The native binaries for the JNI bindings for all three supported OSes are published in `scala-torch-swig.jar`, so there 109 | is no need for OS-specific treatment of those libraries. 110 | 111 | Approach 1 is convenient because sbt will handle the libtorch native dependency for you and users won't need to install 112 | libtorch or set any environment variables. This is the ideal approach for local development. 113 | 114 | There are several downsides of approach 1: 115 | * it may unnecessarily duplicate installation of libtorch if, for example, pytorch is already installed 116 | * jars for GPU builds of libtorch are not provided, so approach 2 is the only option if GPU support is required 117 | * care must be taken when publishing any library that depends on Scala-Torch to not publish the dependency 118 | on the `libtorch-jar`, since that would force the consumer of that library to depend on whatever OS-specific 119 | version of the jar was used at build time. See the use of `pomPostProcess` in [build.sbt](build.sbt) for 120 | how we handle that.
Note that another option is for downstream libraries to exclude the `libtorch-jar` 121 | using something like 122 | ```scala 123 | libraryDependencies += ("com.microsoft" % "scala-torch" % "0.1.0").exclude("com.microsoft.scalatorch", "libtorch-jar") 124 | ``` 125 | 126 | Approach 2 is the better option for CI, remote jobs, production, etc. 127 | 128 | ### Local Development (MacOS) 129 | 130 | You will need to have SWIG installed, which you can 131 | install using `brew install swig`. 132 | 133 | ``` 134 | git submodule update --init --recursive 135 | cd pytorch 136 | python3 -m tools.codegen.gen -s aten/src/ATen -d torch/share/ATen 137 | cd .. 138 | curl https://download.pytorch.org/libtorch/cpu/libtorch-macos-$(pytorchVersion).zip -o libtorch.zip 139 | unzip libtorch.zip 140 | rm -f libtorch.zip 141 | conda env create --name scala-torch --file environment.yml 142 | conda activate scala-torch 143 | export TORCH_DIR=$PWD/libtorch 144 | # This links to the JNI shared library to the absolute paths in the libtorch dir instead of 145 | # using an rpath. 146 | export LINK_TO_BUILD_LIB=true 147 | sbt test 148 | ``` 149 | 150 | A similar setup should work for Linux and Windows. 151 | 152 | #### Troubleshooting 153 | 154 | If you are using Clang 11.0.3 you may run into an error 155 | when compiling the `SobolEngineOps` file. This is most 156 | likely due to an issue with the compiler and it has already 157 | been reported [here](https://github.com/pytorch/pytorch/issues/35478). 158 | A temporary workaround is to install another version of 159 | Clang (e.g., by executing `brew install llvm`). Another option 160 | is to downgrade XCode to a version < 11.4. 161 | 162 | ### Upgrading the LibTorch version 163 | 164 | To upgrade the underlying version of LibTorch: 165 | * `cd pytorch; git checkout ` with the `` of the desired release version, 166 | best found [here](https://github.com/pytorch/pytorch/releases). 167 | * Rerun the steps under **Local Development**. 168 | * Change `TORCH_VERSION` in [run_tests.yml](.github/workflows/run_tests.yml). 169 | * Address compilation errors when running `sbt compile`. Changes to [bindgen.py](swig/src/main/swig/bindgen.py) may 170 | be necessary. 171 | 172 | # Contributors 173 | 174 | Thanks to the following contributors to this project: 175 | 176 | * [Adam Pauls](https://github.com/adampauls) 177 | * [David Hall](https://github.com/dlwh) 178 | * [Theo Lanman](https://github.com/theo-lanman) 179 | * [Alex Kyte](https://github.com/alexanderkyte) 180 | * [Hao Fang](https://github.com/hao-fang) 181 | * [Anthony Platanios](https://github.com/eaplatanios) 182 | * [Dmitrij Peters](https://github.com/Dpetters) 183 | 184 | # Trademarks 185 | 186 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 187 | trademarks or logos is subject to and must follow 188 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 189 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 190 | Any use of third-party trademarks or logos are subject to those third-party's policies. 
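
# Appendix: a minimal `ReferenceManager` sketch

To make the memory-management rules above concrete, here is a small illustrative sketch (not taken from the test suite) that mixes a long-lived `ReferenceManager`, closed explicitly, with a scoped one created by `ReferenceManager.forBlock`. Passing the manager explicitly to `torch.ones` assumes the usual `(implicit rm: ReferenceManager)` parameter list used by the bindings shown in the tutorial above.

```scala
import com.microsoft.scalatorch.torch
import com.microsoft.scalatorch.torch.syntax._

// Long-lived tensors (e.g. parameters) go in a manager that you close yourself.
val longLived = new torch.ReferenceManager
val weights = torch.ones($(4, 4))(longLived)

// Temporaries for a single step live only for the block: forBlock closes its
// ReferenceManager, and with it the underlying native memory, when the block exits.
torch.ReferenceManager.forBlock { implicit rm =>
  val squared = weights * weights // owned by rm, released at the end of the block
  println(squared)
}

// Release the long-lived tensors once they are no longer needed.
longLived.close()
```

This is the same pattern [Model](scala-torch/src/main/scala/com/microsoft/scalatorch/torch/Model.scala) uses internally: it owns a private `ReferenceManager` for its parameters and frees it in `close()`.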
-------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | lazy val Scala2_13Version = "2.13.10" 2 | lazy val Scala2_12Version = "2.12.17" 3 | val commonSettings = Seq( 4 | organization := "com.microsoft", 5 | scalaVersion := Scala2_13Version, 6 | version := "0.0.1-SM-05-SNAPSHOT", 7 | publishMavenStyle := true, 8 | ) 9 | lazy val swig = (project in file("swig")) 10 | .settings(commonSettings: _*) 11 | .enablePlugins(JniNative) 12 | .settings( 13 | name := "swig", 14 | crossScalaVersions := Seq(Scala2_12Version, Scala2_13Version), 15 | crossPaths := true, 16 | ) 17 | 18 | lazy val `scala-torch` = (project in file("scala-torch")) 19 | .settings(commonSettings: _*) 20 | .settings( 21 | name := "scala-torch", 22 | crossScalaVersions := Seq(Scala2_12Version, Scala2_13Version), 23 | ) 24 | .dependsOn(swig) 25 | 26 | lazy val root = (project in file(".")) 27 | .settings(commonSettings: _*) 28 | .settings( 29 | name := "scala-torch-parent", 30 | ) 31 | .dependsOn(swig, `scala-torch`) 32 | .aggregate(swig, `scala-torch`) 33 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - defaults 4 | dependencies: 5 | - python=3.7.3 6 | - ninja=1.10.0 7 | - pyyaml=5.3.1 8 | - setuptools=46.1.3 9 | - cmake=3.17.0 10 | - cffi=1.14.0 11 | - typing_extensions=3.7.4.3 12 | -------------------------------------------------------------------------------- /project/Build.scala: -------------------------------------------------------------------------------- 1 | import scala.util.Try 2 | 3 | import lmcoursier.CoursierConfiguration 4 | import lmcoursier.definitions.Authentication 5 | import sbt.{ File, Logger } 6 | 7 | object Util { 8 | def osCudaClassifier: String = { 9 | val osString = 10 | Option(System.getenv("LIBTORCH_TARGET_OS")).getOrElse(System.getProperty("os.name")).toLowerCase match { 11 | case os if os.contains("mac") || os.contains("darwin") => "darwin" 12 | case os if os.contains("linux") => "linux" 13 | case os if os.contains("windows") => "windows" 14 | case os => throw new sbt.MessageOnlyException(s"The OS $os is not a supported platform.") 15 | } 16 | val cudaString = Option(System.getenv("LIBTORCH_TARGET_CPU")).getOrElse("cpu") 17 | s"${osString}_$cudaString" 18 | } 19 | 20 | private val 
acceptHeader = "Accept" -> "application/octet-stream, application/json, application/xml, */*" 21 | } 22 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | # NOTE: keep in sync with docker/Dockerfile 2 | sbt.version=1.7.3 3 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.github.sbt" % "sbt-jni" % "1.5.3") 2 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.0.2") 3 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | typing_extensions==3.7.4.3 2 | pyyaml==5.3.1 3 | -------------------------------------------------------------------------------- /scala-torch/build.sbt: -------------------------------------------------------------------------------- 1 | scalacOptions += "-target:jvm-1.8" 2 | javacOptions ++= Seq("-target", "1.8", "-source", "1.8") 3 | 4 | libraryDependencies += "org.scalatest" %% "scalatest" % "3.1.4" % "test" 5 | libraryDependencies += "com.michaelpollmeier" %% "scala-arm" % "2.1" 6 | libraryDependencies += "com.lihaoyi" %% "sourcecode" % "0.1.9" 7 | libraryDependencies += "org.scala-lang.modules" %% "scala-collection-compat" % "2.5.0" 8 | 9 | unmanagedSourceDirectories in Compile += { 10 | val sourceDir = (sourceDirectory in Compile).value 11 | CrossVersion.partialVersion(scalaVersion.value) match { 12 | case Some((2, n)) if n >= 13 => 13 | sourceDir / "scala-2.13+" 14 | case _ => 15 | sourceDir / "scala-2.13-" 16 | } 17 | } 18 | 19 | fork := true 20 | -------------------------------------------------------------------------------- /scala-torch/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.3.3 2 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/Device.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch.internal 4 | import com.microsoft.scalatorch.torch.jit.CompilationUnit 5 | import com.microsoft.scalatorch.torch.util.Disposer 6 | import com.microsoft.scalatorch.torch.internal 7 | 8 | case class Device private[torch] (protected[torch] val underlying: internal.Device) 9 | extends TorchReference[internal.Device] { 10 | override protected def delete(): Unit = underlying.delete() 11 | 12 | def tpe: internal.DeviceType = underlyingChecked.`type`() 13 | def index: Short = underlying.index() 14 | def isCPU: Boolean = underlyingChecked.is_cpu() 15 | def isCUDA: Boolean = underlyingChecked.is_cuda() 16 | 17 | override def toString: String = underlyingChecked.str() 18 | } 19 | 20 | object Device { 21 | private[torch] def fromString(str: String): Device = { 22 | val underlying = new internal.Device(str) 23 | Disposer.add(new Device(underlying), () => underlying.delete()) 24 | } 25 | 26 | private[torch] def apply(`type`: internal.DeviceType, index: Int): Device = { 27 | val underlying = new internal.Device(`type`, index.toByte) 28 | Disposer.add(new Device(underlying), () => underlying.delete()) 29 | } 30 | } 31 | 
-------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/Generator.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch.internal 4 | import com.microsoft.scalatorch.torch.util.Disposer 5 | import com.microsoft.scalatorch.torch.internal 6 | 7 | /** @see https://pytorch.org/docs/stable/generated/torch.Generator.html 8 | */ 9 | class Generator private (private[torch] val underlying: internal.Generator) { 10 | def device: device = new Device(underlying.device()) 11 | def get_state(implicit rm: ReferenceManager): Tensor = Tensor(underlying.get_state()) 12 | def set_state(state: Tensor): Generator = { 13 | underlying.set_state(state.underlying) 14 | this 15 | } 16 | def initial_seed: Long = underlying.initial_seed() 17 | def manual_seed(seed: Long): Unit = underlying.manual_seed(seed) 18 | def seed(): Long = underlying.seed() 19 | } 20 | 21 | object Generator { 22 | private[torch] def apply(underlying: internal.Generator): Generator = { 23 | Disposer.add(new Generator(underlying), () => underlying.delete()) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/Model.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import java.io.File 4 | 5 | import scala.collection.JavaConverters._ 6 | 7 | import com.microsoft.scalatorch.torch.internal.{ torch_swig, Layout } 8 | import com.microsoft.scalatorch.torch.jit.{ ClassType, Module } 9 | import com.microsoft.scalatorch.torch.util.NoGrad 10 | import resource.ManagedResource 11 | import com.microsoft.scalatorch.torch.nn.init.ParameterInit 12 | import syntax._ 13 | 14 | /** A wrapper of a Torch Module representing the "root" of the [[Module]] tree. 15 | * 16 | * A note about memory ownership: this class contains a [[ReferenceManager]] for managing all parameters 17 | * stored internally. When you register a [[Parameter]] or [[Module]], you pass a factory that accepts a 18 | * [[ReferenceManager]] and returns a new [[Tensor]] or [[Module]]. The factory is responsible for adding 19 | * itself to the [[ReferenceManager]], but note that methods that make [[Tensor]]s and [[Module]]s typically 20 | * have a signature that takes an implicit [[ReferenceManager]] -- for example, [[Tensor.fromLongArray]], 21 | * and so you simply need to pass that method with the implicit parameter curried. 22 | * 23 | * Somewhat confusingly, the "get" methods like [[getParameter]] take an implicit [[ReferenceManager]] that 24 | * owns the wrapper created by [[getParameter]], but not the underlying storage. Typically, you can pass 25 | * a temporary manager created by [[ReferenceManager.forBlock]] to manage their storage, assuming 26 | * you only use the return [[Tensor]] temporarily of course. 27 | * 28 | * TODO it's unclear if we need this vs just using [[Module]]s directly, but for now it helps organize memory ownership. 
29 | * 30 | * @param module underlying PyTorch [[Module]] 31 | * @see https://pytorch.org/cppdocs/api/classtorch_1_1nn_1_1_module.html?highlight=module 32 | * @see 33 | */ 34 | class Model private[torch] ( 35 | private[torch] val module: jit.Module, 36 | ) extends java.io.Closeable { 37 | 38 | def save(filename: String): Unit = module.save(new File(filename)) 39 | 40 | private[torch] val owner: ReferenceManager = new ReferenceManager {} 41 | 42 | /** The [[ClassType]] of the inner [[Module]] */ 43 | lazy val classType: ClassType = module.`type` 44 | 45 | owner.addReference(module) 46 | 47 | /** Put the model into train mode. See https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.train */ 48 | def train(on: Boolean = true): Unit = module.train(on) 49 | 50 | def isTraining(): Boolean = module.is_training() 51 | 52 | /** Put the model into eval mode. See https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.eval */ 53 | def eval(): Unit = module.eval() 54 | 55 | def addParameters( 56 | shape: Size, 57 | name: String, 58 | init: ParameterInit = com.microsoft.scalatorch.torch.nn.init.glorotUniform(false), 59 | ): Tensor = { 60 | module.addParameter(shape, name, init)(owner) 61 | } 62 | 63 | def registerParameter(name: String, tensor: ReferenceManager => Tensor): Tensor = { 64 | val newTensor = tensor(owner) 65 | module.register_parameter(name, newTensor) 66 | newTensor 67 | } 68 | 69 | def registerModule[M <: jit.Module](name: String, childModule: ReferenceManager => M): M = { 70 | val newModule = childModule(owner) 71 | module.register_module(name, newModule) 72 | newModule 73 | } 74 | 75 | def getParameter(name: String)(implicit manager: ReferenceManager): Option[Tensor] = { 76 | module.getParameter(name) 77 | } 78 | 79 | def getModule(name: String)(implicit manager: ReferenceManager): Option[jit.Module] = { 80 | module.getModule(name) 81 | } 82 | 83 | /** Should be efficient for the 2-norm on sparse gradients. No guarantees for other norms. */ 84 | def gradNorm(p: Float = 2f): Double = NoGrad.noGrad { 85 | ReferenceManager.forBlock { implicit rm => 86 | // TODO this might not be efficient for (sparse) Embeddings 87 | val parameters = module.parameters(recurse = true) 88 | parameters.map { param => 89 | val grad = param.grad 90 | if (!grad.underlying.defined()) { 91 | 0f 92 | } else if (p == 2f && grad.underlying.layout() == Layout.Sparse) { 93 | // TODO does this work? 94 | // TODO handle other sparse norms. If we can get access to the undelrying 95 | // sparse tensor we call norm on that directly. 
96 | val twoNorm = torch_swig._sparse_sum((grad * grad).underlying) 97 | try Math.sqrt(twoNorm.toFloat) 98 | finally twoNorm.delete() 99 | } else { 100 | val tensor = torch_swig.norm(grad.underlying, Scalar.fromFloat(p).underlying) 101 | try tensor.toFloat 102 | finally tensor.delete() 103 | } 104 | }.sum 105 | } 106 | } 107 | 108 | def serializeToByteArray(): Array[Byte] = { 109 | module.serializeToByteArray() 110 | } 111 | 112 | override def close(): Unit = { 113 | owner.close() 114 | } 115 | } 116 | 117 | object Model { 118 | 119 | def managed(name: String): ManagedResource[Model] = { 120 | resource.makeManagedResource(new Model(Module(name)))(_.close())(List.empty) 121 | } 122 | 123 | private[torch] def apply(name: String): Model = { 124 | new Model(Module(name)) 125 | } 126 | 127 | def loadFromByteArray(array: Array[Byte]): Model = { 128 | new Model(jit.Module.load(array)) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/Reduction.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | /** mostly an export of [[Reduction]] 4 | */ 5 | object Reduction { 6 | val None = internal.Reduction.None 7 | val Mean = internal.Reduction.Mean 8 | val Sum = internal.Reduction.Sum 9 | } 10 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/ReferenceManager.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import java.lang.ref._ 5 | 6 | import resource.{ DefaultManagedResource, ManagedResource } 7 | 8 | /** LibTorch uses ref counting to do its memory management, which doesn't interact 9 | * the best with JVM's GC. We use this class to keep track of "generations" of 10 | * related Tensors and the like. 11 | * 12 | * There are three common uses: 13 | * * one for [[Model]]s, whose parameter tensors typically live a long time, 14 | * * one (conceptually) for a "computation graph" that represents 15 | * one coherent unit of execution (e.g. one minibatch of examples), and 16 | * * one for a short lived objects (see uses of [[ReferenceManager.forBlock]]). 17 | * 18 | * The general pattern in Scala-Torch is that any function that returns a torch object ([[Tensor]], [[Scalar]], 19 | * etc) takes an implicit [[ReferenceManager]] as an input and the callee adds its result to the provided 20 | * [[ReferenceManager]] to pass ownership out to the caller. 21 | */ 22 | class ReferenceManager extends AutoCloseable { 23 | private[torch] def addReference(reference: AnyRef): reference.type = { 24 | assertOpen(reference) 25 | references += reference 26 | reference 27 | } 28 | 29 | private def assertOpen(reference: AnyRef): Unit = { 30 | assert(!isClosed, s"attempt to register $reference against closed owner $this") 31 | } 32 | 33 | def close(): Unit = { 34 | references.foreach { 35 | case r: AutoCloseable => r.close() 36 | case _ => () 37 | } 38 | references.clear() 39 | closed = true 40 | } 41 | 42 | override protected def finalize(): Unit = { 43 | if (!isClosed && hasReferences) { 44 | System.err.println(s"Failed to close reference owner $this. 
garbage collecting") 45 | close() 46 | } 47 | } 48 | 49 | private def hasReferences = references.nonEmpty 50 | 51 | def isClosed: Boolean = closed 52 | 53 | // TODO: is this true in torch? 54 | // Tensors sometimes rely on things (e.g. wrapped C++ vectors) that get deleted when the JVM 55 | // garbage collector runs. By explicitly grabbing references to them, we can prevent this 56 | // premature garbage collection. 57 | private val references: ArrayBuffer[AnyRef] = ArrayBuffer() 58 | private var closed = false 59 | 60 | } 61 | 62 | object ReferenceManager { 63 | 64 | /** This function can be used to automatically delete the [[ReferenceManager]] after the block 65 | * 66 | * Use this function with: 67 | * {{{ 68 | * forBlock { implicit cg => do stuff } 69 | * }}} 70 | */ 71 | def forBlock[T](f: ReferenceManager => T): T = { 72 | managed.apply(f) 73 | } 74 | 75 | def managed[T]: ManagedResource[ReferenceManager] = { 76 | resource.makeManagedResource(new ReferenceManager)(_.close())(List.empty) 77 | } 78 | 79 | /** This is a global reference manager if you don't want to use scoped memory management. 80 | * If you're used to Dynet, using scoped memory management isn't *quite* as necessary as 81 | * with PyTorch: Torch doesn't use arenas to do memory management, just a bunch of 82 | * reference counted pointers, and the ReferenceManager only holds weak references 83 | * to tensors and the like, so collection can happen when it's still "in scope" 84 | * 85 | * Nevertheless, it's a good idea to use a scoped ReferenceManager when convenient: 86 | * The JVM doesn't "feel" the native memory pressure created by keeping a bunch 87 | * of "heavy" pointers. As a simple example, a basic MNIST experiment was using something 88 | * like 5 gigs of memory with just a global manager, while switching to scoped 89 | * memory management got it down to 1 gig or so (most of which was JVM heap). 90 | * 91 | * You can also call System.gc() though that's not guaranteed to do anything. 92 | */ 93 | val global = new ReferenceManager 94 | 95 | object Implicits { 96 | implicit val global: ReferenceManager = ReferenceManager.global 97 | } 98 | } 99 | 100 | /** A reference that can be managed by [[ReferenceManager]] */ 101 | trait TorchReference[+Underlying] extends AutoCloseable { 102 | 103 | /** Not thread safe. 
*/ 104 | final def close(): Unit = { 105 | if (!deleted) { 106 | delete() 107 | deleted = true 108 | } 109 | } 110 | private var deleted = false 111 | def isClosed: Boolean = deleted 112 | 113 | /** Should actually free the resource */ 114 | protected def delete(): Unit 115 | 116 | protected[torch] def assertOpen[T](body: => T): T = { 117 | if (isClosed) throw new IllegalStateException("Attempt to access closed resource " + this.getClass.getName()) 118 | body 119 | } 120 | 121 | def underlyingChecked: Underlying = assertOpen(underlying) 122 | 123 | protected[torch] def underlying: Underlying 124 | } 125 | 126 | object TorchReference { 127 | protected[torch] def assertOpen[T](refs: TorchReference[_]*)(body: => T): T = { 128 | refs.foreach(ref => if (ref.isClosed) ref.assertOpen("blah")) 129 | body 130 | } 131 | 132 | // TODO we should add some helpers that make it easy to cleanup native objects 133 | // with syntax like {{{using(method.makeSwigObject())(x => {...})}}} 134 | // that expands to 135 | // {{{ 136 | // val x = method.makeSwigObject() 137 | // try {...} 138 | // finally x.delete() 139 | // }}} 140 | // Unfortunately this is hard without using reflection because swig objects don't inherit from a common 141 | // trait (like AutoCloseable that exposes close()). 142 | } 143 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/Scalar.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch.internal 4 | import com.microsoft.scalatorch.torch.internal 5 | 6 | /** Scalar represents a 0-dimensional tensor which contains a single element. 7 | * Wraps the C++ class of the same name. 
8 | * 9 | * @see https://github.com/pytorch/pytorch/tree/master/c10/core/Scalar.h 10 | */ 11 | class Scalar private (protected[torch] val underlying: internal.Scalar) extends TorchReference[internal.Scalar] { 12 | def toFloat: Float = underlying.toFloat 13 | def toDouble: Double = underlying.toDouble 14 | def toInt: Int = underlying.toInt 15 | def toLong: Long = underlying.toLong 16 | def toBoolean: Boolean = underlying.toBoolean 17 | 18 | override def toString: String = underlying.toString 19 | 20 | def `type`()(implicit rm: ReferenceManager): dtype = dtype(underlying.`type`()) 21 | def unary_-(implicit rm: ReferenceManager): Scalar = Scalar(underlying.unary_minus()) 22 | def conj()(implicit rm: ReferenceManager): Scalar = Scalar(underlying.conj()) 23 | def log()(implicit rm: ReferenceManager): Scalar = Scalar(underlying.log()) 24 | 25 | def isFloatingPoint(): Boolean = underlying.isFloatingPoint() 26 | 27 | def isIntegral(includeBool: Boolean): Boolean = underlying.isIntegral(includeBool) 28 | 29 | def isComplex(): Boolean = underlying.isComplex() 30 | 31 | def isBoolean(): Boolean = underlying.isBoolean() 32 | 33 | override protected def delete(): Unit = underlying.delete() 34 | } 35 | 36 | object Scalar { 37 | private[torch] def apply(underlying: internal.Scalar)(implicit manager: ReferenceManager): Scalar = { 38 | manager.addReference(new Scalar(underlying)) 39 | } 40 | 41 | implicit def fromFloat(f: Float)(implicit manager: ReferenceManager): Scalar = Scalar(new internal.Scalar(f)) 42 | implicit def fromInt(f: Int)(implicit manager: ReferenceManager): Scalar = Scalar(new internal.Scalar(f)) 43 | implicit def fromBoolean(f: Boolean)(implicit manager: ReferenceManager): Scalar = Scalar(new internal.Scalar(f)) 44 | implicit def fromDouble(f: Double)(implicit manager: ReferenceManager): Scalar = Scalar(new internal.Scalar(f)) 45 | implicit def fromLong(f: Long)(implicit manager: ReferenceManager): Scalar = Scalar(new internal.Scalar(f)) 46 | } 47 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/Size.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import scala.collection.compat._ 4 | import scala.collection.JavaConverters._ 5 | 6 | import com.microsoft.scalatorch.torch.internal.LongVector 7 | 8 | /** Same as torch.Size. We don't use a simple alias for an Array[Long] because 9 | * we want a nice toString and structural equality. 
10 | */ 11 | class Size(val underlying: immutable.ArraySeq.ofLong) extends AnyVal { 12 | def rank: Int = underlying.unsafeArray.length 13 | 14 | def sizes: Array[Long] = underlying.unsafeArray 15 | 16 | def numel(): Long = underlying.unsafeArray.product 17 | override def toString(): String = underlying.mkString("Size(", ", ", ")") 18 | 19 | def apply(i: Int): Long = underlying(i) 20 | 21 | } 22 | object Size { 23 | def apply(array: Array[Long]): Size = new Size(new immutable.ArraySeq.ofLong(array)) 24 | def apply(size: Long*): Size = apply(size.toArray) 25 | def apply(dims: LongVector): Size = apply(dims.asScala.map(_.toLong).toArray) 26 | 27 | implicit def unwrapSizeToArray(size: Size): Array[Long] = size.sizes 28 | implicit def unwrapSizeToSeq(size: Size): Seq[Long] = size.underlying 29 | } 30 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/TensorOptions.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch.internal 4 | import com.microsoft.scalatorch.torch.internal.{ Layout, TypeMeta } 5 | import com.microsoft.scalatorch.torch.internal 6 | import resource.ManagedResource 7 | import com.microsoft.scalatorch.torch.syntax._ 8 | 9 | // TODO: devices 10 | /** Class for configuring the underlying properties of a tensor. Most notably the [[dtype]] 11 | * and (eventually) device. 12 | * 13 | * @see https://pytorch.org/cppdocs/notes/tensor_creation.html#configuring-properties-of-the-tensor 14 | */ 15 | case class TensorOptions( 16 | dtype: Option[dtype] = None, 17 | device: Option[Device] = None, 18 | layout: Option[Layout] = None, 19 | requires_grad: Option[Boolean] = None, 20 | pinned_memory: Option[Boolean] = None, 21 | ) { 22 | private[torch] def toInternal: ManagedResource[internal.TensorOptions] = { 23 | def man(o: => internal.TensorOptions) = resource.makeManagedResource(o)(_.delete())(List.empty) 24 | // TODO will this be a perf problem? 
25 | for { 26 | orig <- man(new internal.TensorOptions()) 27 | afterDtype <- man(orig.dtype(dtype.map(_.underlying))) 28 | afterDevice <- man(afterDtype.device(device.map(_.underlying))) 29 | afterLayout <- man(afterDevice.layout(layout)) 30 | afterRequiredsGrad <- man(afterLayout.requires_grad(requires_grad.map(java.lang.Boolean.valueOf))) 31 | afterPinnedMemory <- man(afterRequiredsGrad.pinned_memory(pinned_memory.map(java.lang.Boolean.valueOf))) 32 | } yield afterPinnedMemory 33 | } 34 | } 35 | 36 | object TensorOptions { 37 | // implicits to more or less mirror the torch api 38 | implicit def fromMeta(meta: TypeMeta): TensorOptions = TensorOptions(Some(new dtype(meta))) 39 | 40 | val default: TensorOptions = TensorOptions() 41 | } 42 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/dtype.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch.internal 4 | import com.microsoft.scalatorch.torch.internal.{ torch_swig, ScalarType, TypeMeta } 5 | import com.microsoft.scalatorch.torch.internal 6 | import com.microsoft.scalatorch.torch.internal.{ torch_swig, ScalarType, TypeMeta } 7 | 8 | /** @see https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype 9 | */ 10 | class dtype private[torch] (private[torch] val underlying: TypeMeta) { 11 | 12 | def is_complex: Boolean = { 13 | torch_swig.isComplexType(underlying.toScalarType) 14 | } 15 | 16 | def is_floating_point: Boolean = { 17 | torch_swig.isFloatingType(underlying.toScalarType) 18 | } 19 | 20 | def is_signed: Boolean = { 21 | internal.torch_swig.isSignedType(underlying.toScalarType) 22 | } 23 | 24 | override def toString: String = underlying.name 25 | 26 | override def equals(o: Any): Boolean = o match { 27 | case that: dtype => underlying.equalTo(that.underlying) 28 | } 29 | 30 | def toScalarType: ScalarType = underlying.toScalarType 31 | } 32 | 33 | object dtype { 34 | private[torch] def apply(underlying: ScalarType)(implicit cg: ReferenceManager): dtype = { 35 | cg.addReference(new dtype(TypeMeta.fromScalarType(underlying))) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/fft/package.scala: -------------------------------------------------------------------------------- 1 | // THIS FILE IS AUTO-GENERATED, DO NOT EDIT. 
Changes should be made to package.scala.in 2 | 3 | package com.microsoft.scalatorch.torch 4 | 5 | import com.microsoft.scalatorch.torch 6 | import com.microsoft.scalatorch.torch._ 7 | import com.microsoft.scalatorch.torch.util.Implicits._ 8 | import com.microsoft.scalatorch.torch.internal.{ TensorIndex, TensorVector, TorchTensor, LongVector, torch_swig => swig } 9 | import com.microsoft.scalatorch.torch.util.NoGrad 10 | 11 | package object fft { 12 | // THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT 13 | // See swig/src/main/swig/build.sbt for details 14 | def fft(self: Tensor, n: Option[Long] = None, dim: Long = -1, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_fft(self.underlying, n.asJavaLong, dim, norm)) 15 | def ifft(self: Tensor, n: Option[Long] = None, dim: Long = -1, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_ifft(self.underlying, n.asJavaLong, dim, norm)) 16 | def rfft(self: Tensor, n: Option[Long] = None, dim: Long = -1, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_rfft(self.underlying, n.asJavaLong, dim, norm)) 17 | def irfft(self: Tensor, n: Option[Long] = None, dim: Long = -1, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_irfft(self.underlying, n.asJavaLong, dim, norm)) 18 | def hfft(self: Tensor, n: Option[Long] = None, dim: Long = -1, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_hfft(self.underlying, n.asJavaLong, dim, norm)) 19 | def ihfft(self: Tensor, n: Option[Long] = None, dim: Long = -1, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_ihfft(self.underlying, n.asJavaLong, dim, norm)) 20 | def fft2(self: Tensor, s: Option[Array[Long]] = None, dim: Array[Long] = Array(-2,-1), norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_fft2(self.underlying, s, dim, norm)) 21 | def ifft2(self: Tensor, s: Option[Array[Long]] = None, dim: Array[Long] = Array(-2,-1), norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_ifft2(self.underlying, s, dim, norm)) 22 | def rfft2(self: Tensor, s: Option[Array[Long]] = None, dim: Array[Long] = Array(-2,-1), norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_rfft2(self.underlying, s, dim, norm)) 23 | def irfft2(self: Tensor, s: Option[Array[Long]] = None, dim: Array[Long] = Array(-2,-1), norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_irfft2(self.underlying, s, dim, norm)) 24 | def fftn(self: Tensor, s: Option[Array[Long]] = None, dim: Option[Array[Long]] = None, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_fftn(self.underlying, s, dim, norm)) 25 | def ifftn(self: Tensor, s: Option[Array[Long]] = None, dim: Option[Array[Long]] = None, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_ifftn(self.underlying, s, dim, norm)) 26 | def rfftn(self: Tensor, s: Option[Array[Long]] = None, dim: Option[Array[Long]] = None, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_rfftn(self.underlying, s, dim, norm)) 27 | def irfftn(self: Tensor, s: Option[Array[Long]] = None, dim: Option[Array[Long]] = None, norm: Option[String] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_irfftn(self.underlying, s, dim, norm)) 28 | def fftfreq(n: Long, d: Double = 1.0, 
dtype: Option[dtype] = None, layout: Option[Layout] = None, device: Option[Device] = None, pin_memory: Option[Boolean] = None)(implicit rm: ReferenceManager): Tensor = TensorOptions( 29 | dtype=dtype, 30 | device=device, 31 | layout=layout, 32 | pinned_memory=pin_memory, 33 | ).toInternal.apply { options => 34 | Tensor(swig.fft_fftfreq(n, d, options)) 35 | } 36 | 37 | def rfftfreq(n: Long, d: Double = 1.0, dtype: Option[dtype] = None, layout: Option[Layout] = None, device: Option[Device] = None, pin_memory: Option[Boolean] = None)(implicit rm: ReferenceManager): Tensor = TensorOptions( 38 | dtype=dtype, 39 | device=device, 40 | layout=layout, 41 | pinned_memory=pin_memory, 42 | ).toInternal.apply { options => 43 | Tensor(swig.fft_rfftfreq(n, d, options)) 44 | } 45 | 46 | def fftshift(self: Tensor, dim: Option[Array[Long]] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_fftshift(self.underlying, dim)) 47 | def ifftshift(self: Tensor, dim: Option[Array[Long]] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.fft_ifftshift(self.underlying, dim))} 48 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/fft/package.scala.in: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch 4 | import com.microsoft.scalatorch.torch._ 5 | import com.microsoft.scalatorch.torch.util.Implicits._ 6 | import com.microsoft.scalatorch.torch.internal.{ TensorIndex, TensorVector, TorchTensor, LongVector, torch_swig => swig } 7 | import com.microsoft.scalatorch.torch.util.NoGrad 8 | 9 | package object fft { 10 | // @@@ bindgen.py inserts generated bindings here @@@ 11 | } 12 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/jit/CompilationUnit.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.jit 2 | 3 | import com.microsoft.scalatorch.torch.internal 4 | import com.microsoft.scalatorch.torch.util.Disposer 5 | import com.microsoft.scalatorch.torch.internal 6 | import com.microsoft.scalatorch.torch.util.Disposer 7 | import com.microsoft.scalatorch.torch._ 8 | import com.microsoft.scalatorch.torch.util.Disposer 9 | 10 | class CompilationUnit private (protected[torch] val underlying: internal.CompilationUnit) 11 | extends TorchReference[internal.CompilationUnit] { 12 | override protected def delete(): Unit = underlying.delete() 13 | } 14 | 15 | object CompilationUnit { 16 | 17 | private[torch] def apply(underlying: internal.CompilationUnit): CompilationUnit = { 18 | Disposer.add(new CompilationUnit(underlying), () => underlying.delete()) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/jit/IValue.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.jit 2 | 3 | import java.io.File 4 | import java.lang 5 | 6 | import com.microsoft.scalatorch.torch.internal 7 | import com.microsoft.scalatorch.torch.internal.torch_swig 8 | import com.microsoft.scalatorch.torch.internal 9 | import com.microsoft.scalatorch.torch._ 10 | import com.microsoft.scalatorch.torch.internal.{ torch_swig, IValueVector } 11 | 12 | /** An IValue in C++ land is a "generic 
tagged union used by the interpreter to hold 13 | * all value types." 14 | * 15 | * They're used by [[Module]], which are sort of necessarily dynamically typed. 16 | */ 17 | class IValue private[torch] (protected[torch] val underlying: internal.IValue)(rm: ReferenceManager) 18 | extends TorchReference[internal.IValue] { 19 | def scriptType: Type = Type(underlying.`type`()) 20 | 21 | rm.addReference(this) 22 | override protected def delete(): Unit = underlying.delete() 23 | 24 | def asTensor(implicit rm: ReferenceManager): Tensor = Tensor(underlyingChecked.toTensor) 25 | def asDouble: Double = underlyingChecked.toDouble 26 | def asLong: Long = underlyingChecked.toInt 27 | def asBoolean: Boolean = underlyingChecked.toBool 28 | def asString: String = underlyingChecked.toStringRef 29 | def asScalar(implicit rm: ReferenceManager): Scalar = Scalar(underlyingChecked.toScalar) 30 | } 31 | 32 | object IValue { 33 | def apply(underlying: internal.IValue)(implicit rm: ReferenceManager): IValue = new IValue(underlying)(rm) 34 | def fromPickle(f: File)(implicit rm: ReferenceManager): IValue = { 35 | new IValue(torch_swig.unpickle_from_file(f.toString))(rm) 36 | } 37 | 38 | def none(implicit rm: ReferenceManager): IValue = IValue(new internal.IValue())(rm) 39 | 40 | implicit def fromTensor(t: Tensor)(implicit rm: ReferenceManager): IValue = { 41 | new IValue(new internal.IValue(t.underlyingChecked))(rm) 42 | } 43 | 44 | implicit def fromModule(m: Module)(implicit rm: ReferenceManager): IValue = 45 | IValue(new internal.IValue(m.underlyingChecked))(rm) 46 | 47 | implicit def fromDouble(t: Double)(implicit rm: ReferenceManager): IValue = IValue(new internal.IValue(t))(rm) 48 | 49 | implicit def fromLong(t: Long)(implicit rm: ReferenceManager): IValue = IValue(new internal.IValue(t))(rm) 50 | 51 | implicit def fromBoolean(t: Boolean)(implicit rm: ReferenceManager): IValue = IValue(new internal.IValue(t))(rm) 52 | 53 | implicit def fromLongs(t: Array[Long])( 54 | implicit rm: ReferenceManager, 55 | ): IValue = IValue(new internal.IValue(t))(rm) 56 | 57 | implicit def fromString(t: String)(implicit rm: ReferenceManager): IValue = IValue(new internal.IValue(t))(rm) 58 | 59 | // TODO: fromDoubles 60 | // TODO: TensorList 61 | implicit def fromScalar(t: Scalar)(implicit rm: ReferenceManager): IValue = { 62 | IValue(new internal.IValue(t.underlyingChecked))(rm) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/jit/Module.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.jit 2 | 3 | import java.io.{ Closeable, File } 4 | 5 | import scala.collection.JavaConverters._ 6 | 7 | import com.microsoft.scalatorch.torch.internal 8 | import com.microsoft.scalatorch.torch.internal.{ torch_swig, QualifiedName, TypeMeta } 9 | import com.microsoft.scalatorch.torch.util.Disposer 10 | import com.microsoft.scalatorch.torch.internal 11 | import com.microsoft.scalatorch.torch.internal.{ torch_swig, QualifiedName, TypeMeta } 12 | import com.microsoft.scalatorch.torch.util.Disposer 13 | import resource.ManagedResource 14 | import com.microsoft.scalatorch.torch._ 15 | import com.microsoft.scalatorch.torch.nn.init.ParameterInit 16 | import com.microsoft.scalatorch.torch.util.Disposer 17 | import com.microsoft.scalatorch.torch.syntax._ 18 | 19 | /** A [[Module]] is a wrapper around the torch::jit::script::Module, which is a deserialized 20 | * 
torchscript or traced module (from Python or potentially C++/Scala, though the latter don't support 21 | * tracing yet). 22 | */ 23 | class Module private[torch] ( 24 | protected[torch] val underlying: internal.Module, 25 | ) extends TorchReference[internal.Module] { 26 | 27 | def getParameter(name: String)(implicit rm: ReferenceManager): Option[Tensor] = { 28 | val v = underlying.attr(name, IValue.none.underlying) 29 | try if (v.isNone) None else Some(Tensor(v.toTensor)) 30 | finally v.delete() 31 | } 32 | 33 | def getModule(name: String)(implicit rm: ReferenceManager): Option[Module] = { 34 | val v = underlying.attr(name, IValue.none.underlying) 35 | try if (v.isNone) None else Some(rm.addReference(Module(v.toModule))) 36 | finally v.delete() 37 | } 38 | 39 | def name: String = underlying.name() 40 | 41 | def register_module(name: String, module: Module): Unit = { 42 | underlying.register_module(name, module.underlyingChecked) 43 | } 44 | 45 | def register_parameter(name: String, v: Tensor, is_buffer: Boolean = false): Unit = { 46 | underlying.register_parameter(name, v.underlying, is_buffer) 47 | } 48 | 49 | def register_attribute( 50 | name: String, 51 | t: Type, 52 | v: IValue, 53 | is_param: Boolean = false, 54 | is_buffer: Boolean = false, 55 | ): Unit = { 56 | underlying.register_attribute(name, t.underlying, v.underlying, is_param, is_buffer) 57 | } 58 | 59 | /** Like [[register_parameter]], but initializes a [[Tensor]] for you. 60 | * 61 | * TODO we might want to consider taking a [[TensorInfo]] instead of [[Size]]/[[Device]]/[[TypeMeta]], 62 | * but this is more convenient for now. 63 | */ 64 | def addParameter( 65 | shape: Size, 66 | name: String, 67 | init: ParameterInit = com.microsoft.scalatorch.torch.nn.init.glorotUniform(false), 68 | device: Option[Device] = None, 69 | dtype: Option[dtype] = None, 70 | )( 71 | implicit rm: ReferenceManager, 72 | ): Tensor = { 73 | val tensor = Tensor.empty( 74 | shape, 75 | TensorOptions( 76 | requires_grad = Some(true), 77 | dtype = dtype, 78 | device = device, 79 | ), 80 | )(rm) 81 | init.initializeParams(tensor) 82 | 83 | // synchronized because addAttribute seems to be thread unsafe. 84 | underlying.synchronized { 85 | val tensorType = TensorType.create(shape) 86 | // Be extra cautious and use a global lock, just in case the underlying type might be shared across modules. 87 | // Not sure if this is actually happening, but we have observed some rare segfaults in 88 | // [libtorch_cpu.so] c10::ClassType::addAttribute(std::__cxx11::basic_string, std::allocator > const&, std::shared_ptr const&, bool, bool)+0x319 89 | // and this is our best guess. 
90 | Module.synchronized { 91 | val t = underlying.`type`() 92 | try t.addAttribute(name, tensorType.underlying, /*is_parameter=*/ true) 93 | finally t.delete() 94 | } 95 | 96 | val newValue = IValue.fromTensor(tensor) 97 | underlying.setattr(name, newValue.underlying) 98 | } 99 | tensor 100 | } 101 | 102 | def eval(): Unit = underlyingChecked.eval() 103 | 104 | def is_training(): Boolean = underlyingChecked.is_training() 105 | 106 | def train(on: Boolean = true): Unit = underlyingChecked.train(on) 107 | 108 | lazy val `type`: ClassType = ClassType(underlying.`type`()) 109 | 110 | override protected def delete(): Unit = { 111 | underlyingChecked.delete() 112 | } 113 | 114 | def serializeToByteArray(): Array[Byte] = torch_swig.save_Module_to_byte_array(underlyingChecked) 115 | 116 | def forward(args: Seq[IValue])(implicit rm: ReferenceManager): IValue = { 117 | val iValueVector = new internal.IValueVector(args.map(_.underlying).asJava) 118 | try new IValue(underlyingChecked.forward(iValueVector))(rm) 119 | finally iValueVector.delete() 120 | } 121 | 122 | def invoke(methodName: String, args: IValue*)(implicit rm: ReferenceManager): IValue = { 123 | val iValueVector = new internal.IValueVector(args.map(_.underlying).asJava) 124 | try new IValue(underlying.run_method(methodName, iValueVector))(rm) 125 | finally iValueVector.delete() 126 | 127 | } 128 | 129 | def to(device: Device): Unit = underlyingChecked.to(device.underlyingChecked) 130 | 131 | def save(file: File): Unit = underlyingChecked.save(file.toString) 132 | def serialize(): Array[Byte] = internal.torch_swig.save_Module_to_byte_array(underlyingChecked) 133 | 134 | // TODO: there's no direct way to get a list of attributes out 135 | 136 | def attr(name: String)(implicit rm: ReferenceManager): Option[IValue] = { 137 | val uc = underlyingChecked 138 | if (uc.hasattr(name)) { 139 | Some(new IValue(uc.attr(name))(rm)) 140 | } else { 141 | None 142 | } 143 | } 144 | 145 | def hasattr(name: String): Boolean = underlyingChecked.hasattr(name) 146 | 147 | def parameters(recurse: Boolean = true)(implicit rm: ReferenceManager): Iterable[Tensor] = { 148 | named_parameters(recurse).values 149 | } 150 | 151 | def named_parameters(recurse: Boolean = true)(implicit rm: ReferenceManager): Map[String, Tensor] = { 152 | val np = underlying.named_parameters(recurse) 153 | try { 154 | np.asScala.map(t => (t.getName, Tensor(t.value()))).toMap 155 | } finally { 156 | np.delete() 157 | } 158 | } 159 | 160 | def named_children(implicit rm: ReferenceManager): Map[String, Module] = { 161 | val nc = underlying.named_children() 162 | try { 163 | nc.asScala.map(t => (t.getName, Module(t.value()))).toMap 164 | } finally { 165 | nc.delete() 166 | } 167 | } 168 | } 169 | 170 | object Module { 171 | 172 | private[torch] def apply(underlying: internal.Module): Module = { 173 | Disposer.add(new Module(underlying), () => underlying.delete()) 174 | } 175 | 176 | def apply(name: String): Module = { 177 | val qn = new QualifiedName(name) 178 | try { 179 | Module(new internal.Module(qn)) 180 | } finally { 181 | qn.delete() 182 | } 183 | } 184 | 185 | def load(file: File, device: Option[Device] = None): Module = { 186 | Module(internal.torch_swig.load_script_module(file.toString, device.map(_.underlyingChecked))) 187 | } 188 | 189 | def load(data: Array[Byte]): Module = { 190 | Module(internal.torch_swig.load_Module_from_byte_array(data)) 191 | } 192 | } 193 | -------------------------------------------------------------------------------- 
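As a rough usage sketch (not part of the repository sources), the `Module` wrapper above is normally driven inside a `ReferenceManager` block; the object name `ModuleUsageSketch`, the model path, and the input values below are hypothetical, and the call pattern mirrors `ModuleTest` later in this listing:

import java.io.File
import com.microsoft.scalatorch.torch._
import com.microsoft.scalatorch.torch.jit.Module

object ModuleUsageSketch {
  def main(args: Array[String]): Unit = ReferenceManager.forBlock { implicit rm =>
    // Load a TorchScript module saved from Python (hypothetical path).
    val m = Module.load(new File("traced_model.pt"))
    // forward takes a Seq[IValue]; a Tensor argument is adapted via the implicit IValue.fromTensor.
    val out = m.forward(Seq(Tensor(3.0f, 4.0f, 5.0f, 6.0f)))
    println(out.asTensor)
    // Registered parameters can be inspected by name.
    m.named_parameters().foreach { case (name, p) => println(s"$name: ${p.shape}") }
  }
}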
/scala-torch/src/main/scala/com/microsoft/scalatorch/torch/jit/OpSymbol.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.jit 2 | 3 | import com.microsoft.scalatorch.torch.internal 4 | import com.microsoft.scalatorch.torch.internal.Symbol 5 | import com.microsoft.scalatorch.torch.internal 6 | import com.microsoft.scalatorch.torch.internal 7 | 8 | /** Names a Torch op. */ 9 | case class OpSymbol(private[torch] val underlying: Symbol) { 10 | override def toString: String = underlying.toDisplayString 11 | } 12 | 13 | object OpSymbol { 14 | // TODO think about whether we need to do anything with the "qualified" part of QualifiedName. 15 | def apply(s: String) = new OpSymbol(internal.Symbol.fromQualString(s)) 16 | } 17 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/jit/Type.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.jit 2 | 3 | import com.microsoft.scalatorch.torch.{ internal, Size } 4 | import com.microsoft.scalatorch.torch.internal.{ DeviceType, TypeMeta, TypeVector } 5 | import com.microsoft.scalatorch.torch.util.Disposer 6 | import com.microsoft.scalatorch.torch.{ internal, Size } 7 | import com.microsoft.scalatorch.torch.util.Disposer 8 | import com.microsoft.scalatorch.torch.{ float, internal, Size } 9 | import com.microsoft.scalatorch.torch.util.Disposer 10 | import com.microsoft.scalatorch.torch.syntax._ 11 | import com.microsoft.scalatorch.torch.internal 12 | 13 | /** A wrapper around c10::Type at https://github.com/pytorch/pytorch/blob/v1.4.0/aten/src/ATen/core/jit_type.h#L65. */ 14 | class Type protected (protected[torch] val underlying: internal.Type) { 15 | override def toString = underlying.toString() 16 | override def equals(o: Any): Boolean = o match { 17 | case t: Type => underlying == t.underlying 18 | case _ => false 19 | } 20 | 21 | // TODO figure out a hash code that matches with equals if we ever use this as a key in a HashMap. 22 | // Currently we only support hashing on [[TensorType]]. 23 | override def hashCode: Int = ??? 24 | } 25 | 26 | object Type { 27 | private[torch] def apply(underlying: internal.Type): Type = { 28 | Disposer.add(new Type(underlying), () => underlying.delete()) 29 | } 30 | 31 | def createDict(keyType: Type, valueType: Type): Type = { 32 | Type(internal.Type.createDict(keyType.underlying, valueType.underlying)) 33 | } 34 | 35 | def createList(elementType: Type): Type = { 36 | Type(internal.Type.createList(elementType.underlying)) 37 | } 38 | 39 | val string: Type = Type(internal.Type.getString) 40 | val float: Type = Type(internal.Type.getFloat) 41 | val bool: Type = Type(internal.Type.getBool) 42 | val int: Type = Type(internal.Type.getInt) 43 | val tensor: TensorType = TensorType(internal.TensorType.get()) 44 | } 45 | 46 | // TODO more static types mirroring TorchScript's typesystem. It's inconsistent that we only have a static 47 | // wrapper for TupleType. 
48 | class TupleType(override protected[torch] val underlying: internal.TupleType) extends Type(underlying) 49 | 50 | object TupleType { 51 | private[torch] def apply(underlying: internal.TupleType): TupleType = { 52 | Disposer.add(new TupleType(underlying), () => underlying.delete()) 53 | } 54 | 55 | def create(fieldTypes: Seq[Type]): TupleType = { 56 | val typeVector = new TypeVector(fieldTypes.map(_.underlying).toArray[internal.Type]) 57 | try TupleType(internal.TupleType.create(typeVector)) 58 | finally typeVector.delete() 59 | } 60 | } 61 | 62 | class ClassType private (override protected[torch] val underlying: internal.ClassType) extends Type(underlying) { 63 | lazy val qualifiedName: Option[String] = { 64 | underlying.name().map { name => 65 | try name.qualifiedName() 66 | finally name.delete() 67 | } 68 | } 69 | } 70 | 71 | object ClassType { 72 | private[torch] def apply(underlying: internal.ClassType): ClassType = { 73 | Disposer.add(new ClassType(underlying), () => underlying.delete()) 74 | } 75 | } 76 | 77 | class TensorType private ( 78 | override protected[torch] val underlying: internal.TensorType, 79 | ) extends Type(underlying) { 80 | def shape: Option[Size] = { 81 | val sizes = underlying.sizes() 82 | try sizes.map(Size(_)) 83 | finally sizes.foreach(_.delete()) 84 | } 85 | 86 | override def hashCode: Int = { 87 | (underlying.sizes(), underlying.dtype(), underlying.device()).hashCode 88 | } 89 | } 90 | 91 | object TensorType { 92 | private[torch] def apply(underlying: internal.TensorType): TensorType = { 93 | Disposer.add(new TensorType(underlying), () => underlying.delete()) 94 | } 95 | 96 | def create( 97 | shape: Size, 98 | typeMeta: TypeMeta = float.underlying, 99 | deviceType: DeviceType = DeviceType.CPU, 100 | ): TensorType = { 101 | TensorType(internal.TensorType.createContiguous(typeMeta, deviceType, shape.sizes)) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/linalg/package.scala: -------------------------------------------------------------------------------- 1 | // THIS FILE IS AUTO-GENERATED, DO NOT EDIT. 
Changes should be made to package.scala.in 2 | 3 | package com.microsoft.scalatorch.torch 4 | 5 | import com.microsoft.scalatorch.torch 6 | import com.microsoft.scalatorch.torch._ 7 | import com.microsoft.scalatorch.torch.util.Implicits._ 8 | import com.microsoft.scalatorch.torch.internal.{ TensorIndex, TensorVector, TorchTensor, LongVector, torch_swig => swig } 9 | import com.microsoft.scalatorch.torch.util.NoGrad 10 | 11 | package object linalg { 12 | // THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT 13 | // See swig/src/main/swig/build.sbt for details 14 | def cholesky_ex(self: Tensor, upper: Boolean = false, check_errors: Boolean = false)(implicit rm: ReferenceManager): (Tensor, Tensor) = wrapTensorTuple2(swig.linalg_cholesky_ex(self.underlying, upper, check_errors)) 15 | def cholesky(self: Tensor, upper: Boolean = false)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_cholesky(self.underlying, upper)) 16 | def det(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_det(self.underlying)) 17 | def lstsq(self: Tensor, b: Tensor, rcond: Option[Double] = None, driver: Option[String] = None)(implicit rm: ReferenceManager): (Tensor, Tensor, Tensor, Tensor) = wrapTensorTuple4(swig.linalg_lstsq(self.underlying, b.underlying, rcond.asJavaDouble, driver)) 18 | def matmul(self: Tensor, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_matmul(self.underlying, other.underlying)) 19 | def slogdet(self: Tensor)(implicit rm: ReferenceManager): (Tensor, Tensor) = wrapTensorTuple2(swig.linalg_slogdet(self.underlying)) 20 | def eig(self: Tensor)(implicit rm: ReferenceManager): (Tensor, Tensor) = wrapTensorTuple2(swig.linalg_eig(self.underlying)) 21 | def eigvals(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_eigvals(self.underlying)) 22 | def eigh(self: Tensor, UPLO: String = "L")(implicit rm: ReferenceManager): (Tensor, Tensor) = wrapTensorTuple2(swig.linalg_eigh(self.underlying, UPLO)) 23 | def eigvalsh(self: Tensor, UPLO: String = "L")(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_eigvalsh(self.underlying, UPLO)) 24 | def householder_product(input: Tensor, tau: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_householder_product(input.underlying, tau.underlying)) 25 | def inv_ex(self: Tensor, check_errors: Boolean = false)(implicit rm: ReferenceManager): (Tensor, Tensor) = wrapTensorTuple2(swig.linalg_inv_ex(self.underlying, check_errors)) 26 | def inv(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_inv(self.underlying)) 27 | def norm(self: Tensor, ord: Option[Scalar] = None, dim: Option[Array[Long]] = None, keepdim: Boolean = false, dtype: Option[dtype] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_norm(self.underlying, ord.map(_.underlying), dim, keepdim, dtype.map(_.toScalarType))) 28 | def norm(self: Tensor, ord: String, dim: Option[Array[Long]], keepdim: Boolean, dtype: Option[dtype])(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_norm(self.underlying, ord, dim, keepdim, dtype.map(_.toScalarType))) 29 | def vector_norm(self: Tensor, ord: Double = 2, dim: Option[Array[Long]] = None, keepdim: Boolean = false, dtype: Option[dtype] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_vector_norm(self.underlying, ord.toInternalScalar, dim, keepdim, dtype.map(_.toScalarType))) 30 | def matrix_norm(self: Tensor, ord: Scalar, dim: Array[Long] = Array(-2,-1), keepdim: Boolean = false, dtype: Option[dtype] = 
None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_matrix_norm(self.underlying, ord.underlying, dim, keepdim, dtype.map(_.toScalarType))) 31 | def matrix_norm(self: Tensor, ord: String, dim: Array[Long], keepdim: Boolean, dtype: Option[dtype])(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_matrix_norm(self.underlying, ord, dim, keepdim, dtype.map(_.toScalarType))) 32 | def svd(self: Tensor, full_matrices: Boolean = true)(implicit rm: ReferenceManager): (Tensor, Tensor, Tensor) = wrapTensorTuple3(swig.linalg_svd(self.underlying, full_matrices)) 33 | def svdvals(input: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_svdvals(input.underlying)) 34 | def cond(self: Tensor, p: Option[Scalar] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_cond(self.underlying, p.map(_.underlying))) 35 | def cond(self: Tensor, p: String)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_cond(self.underlying, p)) 36 | def pinv(self: Tensor, rcond: Double = 1e-15, hermitian: Boolean = false)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_pinv(self.underlying, rcond, hermitian)) 37 | def pinv(self: Tensor, rcond: Tensor, hermitian: Boolean)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_pinv(self.underlying, rcond.underlying, hermitian)) 38 | def solve(input: Tensor, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_solve(input.underlying, other.underlying)) 39 | def tensorinv(self: Tensor, ind: Long = 2)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_tensorinv(self.underlying, ind)) 40 | def tensorsolve(self: Tensor, other: Tensor, dims: Option[Array[Long]] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_tensorsolve(self.underlying, other.underlying, dims)) 41 | def qr(self: Tensor, mode: String = "reduced")(implicit rm: ReferenceManager): (Tensor, Tensor) = wrapTensorTuple2(swig.linalg_qr(self.underlying, mode)) 42 | def matrix_power(self: Tensor, n: Long)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_matrix_power(self.underlying, n)) 43 | def matrix_rank(self: Tensor, tol: Option[Double] = None, hermitian: Boolean = false)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_matrix_rank(self.underlying, tol.asJavaDouble, hermitian)) 44 | def matrix_rank(input: Tensor, tol: Tensor, hermitian: Boolean)(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_matrix_rank(input.underlying, tol.underlying, hermitian)) 45 | def multi_dot(tensors: Array[Tensor])(implicit rm: ReferenceManager): Tensor = Tensor(swig.linalg_multi_dot(tensors.map(_.underlyingChecked)))} 46 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/linalg/package.scala.in: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch 4 | import com.microsoft.scalatorch.torch._ 5 | import com.microsoft.scalatorch.torch.util.Implicits._ 6 | import com.microsoft.scalatorch.torch.internal.{ TensorIndex, TensorVector, TorchTensor, LongVector, torch_swig => swig } 7 | import com.microsoft.scalatorch.torch.util.NoGrad 8 | 9 | package object linalg { 10 | // @@@ bindgen.py inserts generated bindings here @@@ 11 | } 12 | -------------------------------------------------------------------------------- 
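As a hedged illustration (a sketch, not part of the generated sources), the `linalg` bindings above can be exercised the same way the hand-written tests later in this listing do; the object name `LinalgUsageSketch` and the concrete values are made up:

import com.microsoft.scalatorch.torch._

object LinalgUsageSketch {
  def main(args: Array[String]): Unit = ReferenceManager.forBlock { implicit rm =>
    // Euclidean norm of a vector: ||(3, 4)|| is 5.
    println(linalg.norm(Tensor(3.0f, 4.0f)).toFloat)

    // Solve A x = b for a small diagonal system, then multiply back to check.
    val A = Tensor.fromFloatArray(Array(2f, 0f, 0f, 4f), Size(2, 2))
    val b = Tensor(6.0f, 8.0f)
    val x = linalg.solve(A, b)
    println(linalg.matmul(A, x)) // should be close to b
  }
}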
/scala-torch/src/main/scala/com/microsoft/scalatorch/torch/nn/functional/package.scala.in: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.nn 2 | 3 | import com.microsoft.scalatorch.torch 4 | import com.microsoft.scalatorch.torch._ 5 | import com.microsoft.scalatorch.torch.util.Implicits._ 6 | import com.microsoft.scalatorch.torch.internal.{ TensorIndex, TensorVector, TorchTensor, LongVector, torch_swig => swig } 7 | import com.microsoft.scalatorch.torch.util.NoGrad 8 | 9 | package object functional { 10 | // This overload is only here because the defaults are missing from the auto-generated version. 11 | // Shouldn't be needed once a version of libtorch with https://github.com/pytorch/pytorch/pull/70156 is released. 12 | def poisson_nll_loss( 13 | input: Tensor, 14 | target: Tensor, 15 | log_input: Boolean = false, 16 | full: Boolean = false, 17 | eps: Double = 1e-8, 18 | reduction: Reduction = Reduction.Mean, 19 | )(implicit cg: ReferenceManager): Tensor = { 20 | torch.poisson_nll_loss( 21 | input, 22 | target, 23 | log_input, 24 | full, 25 | eps, 26 | reduction = reduction.swigValue(), 27 | ) 28 | } 29 | 30 | // @@@ bindgen.py inserts generated bindings here @@@ 31 | } 32 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/nn/init/package.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.nn 2 | 3 | import com.microsoft.scalatorch.torch.{ Scalar, Tensor, internal } 4 | import com.microsoft.scalatorch.torch.internal.torch_swig 5 | import com.microsoft.scalatorch.torch.util.NoGrad 6 | 7 | package object init { 8 | 9 | type Nonlinearity = internal.Nonlinearity 10 | object Nonlinearity { 11 | val Linear = internal.Nonlinearity.Linear 12 | val Conv1D = internal.Nonlinearity.Conv1D 13 | val Conv2D = internal.Nonlinearity.Conv2D 14 | val Conv3D = internal.Nonlinearity.Conv3D 15 | val ConvTranspose1D = internal.Nonlinearity.ConvTranspose1D 16 | val ConvTranspose2D = internal.Nonlinearity.ConvTranspose2D 17 | val ConvTranspose3D = internal.Nonlinearity.ConvTranspose3D 18 | val Sigmoid = internal.Nonlinearity.Sigmoid 19 | val Tanh = internal.Nonlinearity.Tanh 20 | val ReLU = internal.Nonlinearity.ReLU 21 | val LeakyReLU = internal.Nonlinearity.LeakyReLU 22 | } 23 | 24 | type FanMode = internal.FanMode 25 | object FanMode { 26 | val FanIn = internal.FanMode.FanIn 27 | val FanOut = internal.FanMode.FanOut 28 | } 29 | 30 | trait ParameterInit { 31 | def initializeParams(values: Tensor): values.type 32 | } 33 | 34 | private class ParameterInitImpl(initInPlace: internal.TorchTensor => Unit) extends ParameterInit { 35 | override def initializeParams(values: Tensor): values.type = NoGrad.noGrad { 36 | initInPlace(values.underlying) 37 | values 38 | } 39 | } 40 | 41 | /** Initialize a parameter with random normal values. Note that dynet v is variance, but pytorch is std. 
We're using std. */ 42 | def normal(m: Float = 0.0f, std: Float = 1.0f): ParameterInit = 43 | new ParameterInitImpl(torch_swig.normal_(_, m, std)) 44 | 45 | def normal_(tensor: Tensor, m: Float = 0.0f, std: Float = 1.0f): tensor.type = { 46 | torch_swig.normal_(tensor.underlying, m, std) 47 | tensor 48 | } 49 | 50 | /** Initialize a parameter with random uniform [low, high] values */ 51 | def uniform(low: Float = 0f, high: Float = 1f): ParameterInit = 52 | new ParameterInitImpl(torch_swig.uniform_(_, low, high)) 53 | def uniform_(tensor: Tensor, low: Float = 0f, high: Float = 1f): tensor.type = { 54 | torch_swig.uniform_(tensor.underlying, low, high) 55 | tensor 56 | } 57 | 58 | def kaiming_uniform( 59 | a: Float = 0f, 60 | mode: FanMode = FanMode.FanIn, 61 | nonlinearity: Nonlinearity = Nonlinearity.LeakyReLU, 62 | ): ParameterInit = new ParameterInitImpl( 63 | torch_swig.kaiming_uniform_(_, a, mode, nonlinearity), 64 | ) 65 | def kaiming_uniform_( 66 | tensor: Tensor, 67 | a: Float = 0f, 68 | mode: FanMode = FanMode.FanIn, 69 | nonlinearity: Nonlinearity = Nonlinearity.LeakyReLU, 70 | ): tensor.type = { 71 | torch_swig.kaiming_uniform_(tensor.underlying, a, mode, nonlinearity) 72 | tensor 73 | } 74 | 75 | def kaiming_normal( 76 | a: Float = 0f, 77 | mode: FanMode = FanMode.FanIn, 78 | nonlinearity: Nonlinearity = Nonlinearity.LeakyReLU, 79 | ): ParameterInit = new ParameterInitImpl(torch_swig.kaiming_normal_(_, a, mode, nonlinearity)) 80 | def kaiming_normal_( 81 | tensor: Tensor, 82 | a: Float = 0f, 83 | mode: FanMode = FanMode.FanIn, 84 | nonlinearity: Nonlinearity = Nonlinearity.LeakyReLU, 85 | ): tensor.type = { 86 | torch_swig.kaiming_normal_(tensor.underlying, a, mode, nonlinearity) 87 | tensor 88 | } 89 | 90 | /** Initialize a parameter with the constant value c */ 91 | def constant(c: Scalar): ParameterInit = 92 | new ParameterInitImpl(t => { 93 | torch_swig.constant_(t, c.underlying) 94 | }) 95 | def constant(d: Double): ParameterInit = { 96 | new ParameterInitImpl(t => { 97 | val c = new internal.Scalar(d) 98 | try torch_swig.constant_(t, c) 99 | finally c.delete() 100 | }) 101 | 102 | } 103 | 104 | def constant_(tensor: Tensor, c: Scalar): tensor.type = { 105 | torch_swig.constant_(tensor.underlying, c.underlying) 106 | tensor 107 | } 108 | 109 | def sparse_(tensor: Tensor, sparsity: Double, std: Double = 0.01): tensor.type = { 110 | torch_swig.sparse_(tensor.underlying, sparsity, std) 111 | tensor 112 | } 113 | 114 | def eye(): ParameterInit = new ParameterInitImpl(torch_swig.eye_) 115 | def eye_(tensor: Tensor): tensor.type = { torch_swig.eye_(tensor.underlying); tensor } 116 | def ones(): ParameterInit = new ParameterInitImpl(torch_swig.ones_) 117 | def ones_(tensor: Tensor): tensor.type = { torch_swig.ones_(tensor.underlying); tensor } 118 | def zeros(): ParameterInit = new ParameterInitImpl(torch_swig.zeros_) 119 | def zeros_(tensor: Tensor): tensor.type = { torch_swig.zeros_(tensor.underlying); tensor } 120 | def dirac(): ParameterInit = new ParameterInitImpl(torch_swig.zeros_) 121 | 122 | def xavier_uniform(gain: Double = 1.0): ParameterInit = new ParameterInitImpl(torch_swig.xavier_uniform_(_, gain)) 123 | def xavier_uniform_(tensor: Tensor, gain: Double = 1.0): tensor.type = { 124 | torch_swig.xavier_uniform_(tensor.underlying, gain) 125 | tensor 126 | } 127 | 128 | /** Initialize a parameter using Glorot (Xavier) uniform initialization */ 129 | def glorotUniform(isLookup: Boolean): ParameterInit = { 130 | new ParameterInitImpl({ t => 131 | // dynet's code reduces to 
this with a rank 0 or 1 tensor 132 | // (pytorch however gets mad) 133 | val dims = t.sizes() 134 | if (dims.length == 0) { 135 | val u = math.sqrt(3.0f) 136 | torch_swig.uniform_(t, -u, u) 137 | } else if (dims.length == 1 || isLookup) { 138 | require(!isLookup || dims.length == 2) 139 | val denom = if (isLookup) { 140 | // For lookup parameters, we use the same scale for an initialization as 141 | // an individual vector of the embedding size. 142 | dims(1) 143 | } else { 144 | dims.sum 145 | } 146 | val u = math.sqrt(3.0f / denom) 147 | torch_swig.uniform_(t, -u, u) 148 | } else { 149 | torch_swig.xavier_uniform_(t) 150 | } 151 | }) 152 | } 153 | 154 | def fromValues(values: Tensor): ParameterInit = NoGrad.noGrad { 155 | new ParameterInitImpl(_.copy_(values.underlying, false)) 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/special/package.scala: -------------------------------------------------------------------------------- 1 | // THIS FILE IS AUTO-GENERATED, DO NOT EDIT. Changes should be made to package.scala.in 2 | 3 | package com.microsoft.scalatorch.torch 4 | 5 | import com.microsoft.scalatorch.torch 6 | import com.microsoft.scalatorch.torch._ 7 | import com.microsoft.scalatorch.torch.util.Implicits._ 8 | import com.microsoft.scalatorch.torch.internal.{ TensorIndex, TensorVector, TorchTensor, LongVector, torch_swig => swig } 9 | import com.microsoft.scalatorch.torch.util.NoGrad 10 | 11 | package object special { 12 | // THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT 13 | // See swig/src/main/swig/build.sbt for details 14 | def entr(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_entr(self.underlying)) 15 | def ndtri(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_ndtri(self.underlying)) 16 | def expm1(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_expm1(self.underlying)) 17 | def exp2(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_exp2(self.underlying)) 18 | def psi(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_psi(self.underlying)) 19 | def digamma(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_digamma(self.underlying)) 20 | def gammaln(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_gammaln(self.underlying)) 21 | def erf(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_erf(self.underlying)) 22 | def erfc(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_erfc(self.underlying)) 23 | def erfcx(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_erfcx(self.underlying)) 24 | def erfinv(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_erfinv(self.underlying)) 25 | def ndtr(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_ndtr(self.underlying)) 26 | def xlog1py(self: Tensor, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_xlog1py(self.underlying, other.underlying)) 27 | def xlog1py(self: Scalar, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_xlog1py(self.underlying, other.underlying)) 28 | def xlog1py(self: Tensor, other: Scalar)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_xlog1py(self.underlying, other.underlying)) 29 | def xlogy(self: Tensor, other: Tensor)(implicit rm: 
ReferenceManager): Tensor = Tensor(swig.special_xlogy(self.underlying, other.underlying)) 30 | def xlogy(self: Scalar, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_xlogy(self.underlying, other.underlying)) 31 | def xlogy(self: Tensor, other: Scalar)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_xlogy(self.underlying, other.underlying)) 32 | def zeta(self: Tensor, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_zeta(self.underlying, other.underlying)) 33 | def zeta(self: Scalar, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_zeta(self.underlying, other.underlying)) 34 | def zeta(self: Tensor, other: Scalar)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_zeta(self.underlying, other.underlying)) 35 | def i0(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_i0(self.underlying)) 36 | def i0e(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_i0e(self.underlying)) 37 | def i1(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_i1(self.underlying)) 38 | def i1e(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_i1e(self.underlying)) 39 | def logit(self: Tensor, eps: Option[Double] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_logit(self.underlying, eps.asJavaDouble)) 40 | def polygamma(n: Long, self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_polygamma(n, self.underlying)) 41 | def logsumexp(self: Tensor, dim: Array[Long], keepdim: Boolean = false)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_logsumexp(self.underlying, dim, keepdim)) 42 | def expit(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_expit(self.underlying)) 43 | def sinc(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_sinc(self.underlying)) 44 | def round(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_round(self.underlying)) 45 | def log1p(self: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_log1p(self.underlying)) 46 | def log_softmax(self: Tensor, dim: Long, dtype: Option[dtype] = None)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_log_softmax(self.underlying, dim, dtype.map(_.toScalarType))) 47 | def gammainc(self: Tensor, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_gammainc(self.underlying, other.underlying)) 48 | def gammaincc(self: Tensor, other: Tensor)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_gammaincc(self.underlying, other.underlying)) 49 | def multigammaln(self: Tensor, p: Long)(implicit rm: ReferenceManager): Tensor = Tensor(swig.special_multigammaln(self.underlying, p))} 50 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/special/package.scala.in: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch 4 | import com.microsoft.scalatorch.torch._ 5 | import com.microsoft.scalatorch.torch.util.Implicits._ 6 | import com.microsoft.scalatorch.torch.internal.{ TensorIndex, TensorVector, TorchTensor, LongVector, torch_swig => swig } 7 | import com.microsoft.scalatorch.torch.util.NoGrad 8 | 9 | package object special { 10 | // @@@ bindgen.py inserts generated bindings here @@@ 11 | } 
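  // A hedged usage sketch (the values below are made up): once bindgen.py has filled in
  // the bindings, the generated `special` functions are called like any other wrapper,
  // mirroring TorchFunctionTest further down in this listing:
  //
  //   ReferenceManager.forBlock { implicit rm =>
  //     val t = Tensor(1.0f, 3.0f, 4.0f)
  //     special.log1p(t)  // elementwise log(1 + x)
  //     special.expit(t)  // logistic sigmoid of each element
  //   }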
12 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/syntax.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import scala.reflect.ClassTag 4 | 5 | /** Various implicits to provide syntactic sugar that mirrors the PyTorch API as much as one can do with Scala syntax. 6 | * Use with {{{import com.microsoft.scalatorch.torch.syntax._}}} */ 7 | object syntax { 8 | 9 | 10 | /** Since Python has such nice syntax for list literals with [], it's nice to have nearly as short syntax 11 | * in PyTorch. It's not very idiomatic Scala to do this, though; you should really just write `Array`. 12 | * You are welcome to exclude this particular sugar with 13 | * {{{import com.microsoft.scalatorch.torch.syntax.{$ => _, _}}}}*/ 14 | def $[T: ClassTag](ts: T*): Array[T] = Array(ts: _*) 15 | 16 | // We could have a generic some[T], but we prefer to limit the use of implicits to just those necessary 17 | // for pytorch syntax. 18 | implicit def someTensor(tensor: Tensor): Option[Tensor] = Some(tensor) 19 | implicit def someBoolean(b: Boolean): Option[Boolean] = Some(b) 20 | implicit def someDouble(x: Double): Option[Double] = Some(x) 21 | implicit def someGenerator(x: Generator): Option[Generator] = Some(x) 22 | implicit def someLayout(x: Layout): Option[Layout] = Some(x) 23 | implicit def somedtype(x: dtype): Option[dtype] = Some(x) 24 | implicit def someDevice(x: Device): Option[Device] = Some(x) 25 | 26 | implicit def stringToDevice(s: String): Device = device(s) 27 | implicit def stringToOptionDevice(s: String): Option[Device] = Some(device(s)) 28 | 29 | implicit def anyToTensor(a: Any)(implicit rm: ReferenceManager): Tensor = tensor(a) 30 | 31 | implicit def intToArray(int: Int): Array[Long] = Array(int) 32 | implicit def intTupleToArray2(intTuple: (Int, Int)): Array[Long] = Array(intTuple._1, intTuple._2) 33 | implicit def intTupleToArray3(intTuple: (Int, Int, Int)): Array[Long] = Array(intTuple._1, intTuple._2, intTuple._3) 34 | implicit def intTupleToArray4(intTuple: (Int, Int, Int, Int)): Array[Long] = 35 | Array(intTuple._1, intTuple._2, intTuple._3, intTuple._4) 36 | implicit def intTupleToArray5(intTuple: (Int, Int, Int, Int, Int)): Array[Long] = 37 | Array(intTuple._1, intTuple._2, intTuple._3, intTuple._4, intTuple._5) 38 | 39 | implicit def boolTupleToArray2(boolTuple: (Boolean, Boolean)): Array[Boolean] = Array(boolTuple._1, boolTuple._2) 40 | implicit def boolTupleToArray3(boolTuple: (Boolean, Boolean, Boolean)): Array[Boolean] = 41 | Array(boolTuple._1, boolTuple._2, boolTuple._3) 42 | implicit def boolTupleToArray4(boolTuple: (Boolean, Boolean, Boolean, Boolean)): Array[Boolean] = 43 | Array(boolTuple._1, boolTuple._2, boolTuple._3, boolTuple._4) 44 | 45 | implicit def reductionToLong(reduction: Reduction): Long = reduction.swigValue() 46 | 47 | /** Supports the indexing documented in [[Tensor.apply(syntax.Indexer*)(ReferenceManager)]] 48 | */ 49 | sealed trait Indexer 50 | 51 | object Indexer { 52 | 53 | /** Used across various indexing syntaxes; see use cases below. 54 | */ 55 | private[torch] case class RangeStepIndexer( 56 | bottom: java.util.OptionalLong, 57 | top: java.util.OptionalLong, 58 | step: java.util.OptionalLong, 59 | ) extends Indexer 60 | 61 | /** Allows for the syntax {{{x(1 -> 2)}}}. 
Python syntax ({{{x[1:2]}}}) is not possible, both because of 62 | * of the different meaning of square and round parens, and also because : is not available operator in Scala. 63 | * Note that {{{x(1::2)}}} and {{{x(::)}}} are both possible and match the meaning in Python. 64 | */ 65 | implicit def intPairToIndexer(pair: (Int, Int)): Indexer = RangeStepIndexer( 66 | java.util.OptionalLong.of(pair._1), 67 | java.util.OptionalLong.of(pair._2), 68 | java.util.OptionalLong.empty(), 69 | ) 70 | 71 | /** Allows for the syntax {{{x(1 -> 2 -> 3)}}}, matching Python's {{{x[1:2:3]}}} 72 | */ 73 | implicit def intTripleToIndexer(triple: ((Int, Int), Int)): Indexer = RangeStepIndexer( 74 | java.util.OptionalLong.of(triple._1._1), 75 | java.util.OptionalLong.of(triple._1._2), 76 | java.util.OptionalLong.of(triple._2), 77 | ) 78 | 79 | /** Allow for {{{foo(1)}}} 80 | */ 81 | implicit def intToElemIndexer(elem: Int): Indexer = ElemIndexer(elem) 82 | private[torch] case class ElemIndexer(elem: Int) extends Indexer 83 | 84 | /** Allow for {{{foo(true)}}} 85 | */ 86 | implicit def boolToBoolIndexer(bool: Boolean): Indexer = BoolIndexer(bool) 87 | private[torch] case class BoolIndexer(bool: Boolean) extends Indexer 88 | 89 | /** Allow for {{{foo(None)}}} 90 | */ 91 | implicit def noneToRangeIndexer(none: None.type): Indexer = 92 | RangeStepIndexer(java.util.OptionalLong.empty(), java.util.OptionalLong.empty(), java.util.OptionalLong.empty()) 93 | } 94 | 95 | /** Ellipsis (...) in python 96 | */ 97 | case object --- extends Indexer 98 | 99 | /** Range (colon) in python 100 | */ 101 | val :: : Indexer = None 102 | 103 | /** Allow for {{{foo(1.::)}}} and {{{foo(1.::(2)}}} 104 | */ 105 | implicit class RichInteger(val bottom: Int) extends AnyVal { 106 | def ::(step: Int): Indexer = { 107 | // Note that despite the names, :: reverses the operators, that is a :: b calls b.::(a) 108 | // So step and bottom are reversed here 109 | Indexer.RangeStepIndexer( 110 | java.util.OptionalLong.of(step), 111 | java.util.OptionalLong.empty(), 112 | java.util.OptionalLong.of(bottom), 113 | ) 114 | } 115 | 116 | def :: : Indexer = 117 | Indexer.RangeStepIndexer( 118 | java.util.OptionalLong.of(bottom), 119 | java.util.OptionalLong.empty(), 120 | java.util.OptionalLong.empty(), 121 | ) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/util/Disposer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * This class is largely copied from 3 | * https://github.com/eaplatanios/tensorflow_scala/blob/master/modules/api/src/main/scala/org/platanios/tensorflow/api/utilities/Disposer.scala 4 | * with some minor modifications. 5 | * 6 | * Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 9 | * use this file except in compliance with the License. You may obtain a copy of 10 | * the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 17 | * License for the specific language governing permissions and limitations under 18 | * the License. 
19 | */ 20 | 21 | package com.microsoft.scalatorch.torch.util 22 | 23 | import java.lang.Thread.currentThread 24 | import java.lang.ref.{ PhantomReference, ReferenceQueue } 25 | import java.security.{ AccessController, PrivilegedAction } 26 | import java.util 27 | import java.util.concurrent.ConcurrentHashMap 28 | 29 | import scala.annotation.tailrec 30 | 31 | /** This class is used for registering and disposing the native data associated with Scala objects. 32 | * It is almost identical to [[Cleaner]], except that it uses a [[ConcurrentHashMap]] to store references 33 | * to avoid the locking in [[Cleaner]]. 34 | * 35 | * The object can register itself by calling the [[Disposer.add]] method and providing a disposing function to it. This 36 | * function will be called in order to dispose the native data. It accepts no arguments and returns nothing. 37 | * 38 | * When the object becomes unreachable, the provided disposing function for that object will be called. 39 | * Note that because the garbage collector does not know how much memory the native objects take up, it may accumulate 40 | * many of them before triggering a GC. We make use of [[ReferenceManager]]s to manage the lifetime of frequently 41 | * allocated objects like [[Tensor]]s, but rely on the [[Disposer]] for objects that are unlikely to result in 42 | * memory pressure, but still should be cleaned up when possible. 43 | */ 44 | private[torch] object Disposer { 45 | 46 | private val queue: ReferenceQueue[Any] = new ReferenceQueue[Any] 47 | private val records: util.Map[PhantomReference[Any], () => Unit] = 48 | new ConcurrentHashMap[PhantomReference[Any], () => Unit] 49 | 50 | /** Performs the actual registration of the target object to be disposed. 51 | * Somewhat confusingly, the `disposer` argument *cannot* be retrieved out of `target` or reference it in any way 52 | * because by the time `disposer` runs, `target` will already have been garbage collected 53 | * (hence the slightly funny interface). 54 | * 55 | * The typical pattern is that every Swig-generated type internal.Foo will have a wrapper class called 56 | * Foo with an apply method that looks like 57 | * {{{ 58 | * object Foo { 59 | * def apply(arg1: Int, arg2: String): Foo = { 60 | * val underlying = new internal.Foo(arg1, arg2) 61 | * Disposer.add(new Foo(underlying), () => underlying.delete()) 62 | * } 63 | * } 64 | * }}} 65 | * 66 | * @param target Wrapper object that manages the lifetime of the underlying object 67 | * @param disposer Closure that will clean up any underlying memory. 68 | * @return target for easier chaining. 69 | */ 70 | def add(target: AnyRef, disposer: () => Unit): target.type = { 71 | val reference = new PhantomReference[Any](target, queue) 72 | records.put(reference, disposer) 73 | target 74 | } 75 | 76 | AccessController.doPrivileged(new PrivilegedAction[Unit] { 77 | override def run(): Unit = { 78 | // The thread must be a member of a thread group which will not get GCed before the VM exit. For this reason, we 79 | // make its parent the top-level thread group. 80 | @tailrec def rootThreadGroup(group: ThreadGroup = currentThread.getThreadGroup): ThreadGroup = { 81 | group.getParent match { 82 | case null => group 83 | case parent => rootThreadGroup(parent) 84 | } 85 | } 86 | 87 | new Thread(rootThreadGroup(), "Torch Disposer") { 88 | override def run(): Unit = while (true) { 89 | // Blocks until there is a reference in the queue. 
90 | val referenceToDispose = queue.remove 91 | records.remove(referenceToDispose).apply() 92 | referenceToDispose.clear() 93 | } 94 | 95 | setContextClassLoader(null) 96 | setDaemon(true) 97 | // Let Cleaner, which runs at priority MAX_PRIORITY - 2, take precedence. 98 | setPriority(Thread.MAX_PRIORITY - 3) 99 | start() 100 | } 101 | } 102 | }) 103 | } 104 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/util/Implicits.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.util 2 | 3 | import java.util.{ OptionalDouble, OptionalLong } 4 | 5 | import com.microsoft.scalatorch.torch.internal 6 | import com.microsoft.scalatorch.torch.internal 7 | import com.microsoft.scalatorch.torch.{ internal, ReferenceManager, Scalar } 8 | 9 | private[torch] object Implicits { 10 | implicit class RichOptionDouble(private val option: Option[Double]) extends AnyVal { 11 | def asJavaDouble: OptionalDouble = { 12 | option.fold(OptionalDouble.empty())(d => OptionalDouble.of(d)) 13 | } 14 | } 15 | 16 | implicit class RichOptionLong(private val option: Option[Long]) extends AnyVal { 17 | def asJavaLong: OptionalLong = { 18 | option.fold(OptionalLong.empty())(l => OptionalLong.of(l)) 19 | } 20 | } 21 | 22 | implicit class RichDouble(private val d: Double) extends AnyVal { 23 | def toInternalScalar(implicit cg: ReferenceManager): internal.Scalar = { 24 | Scalar.fromDouble(d).underlying 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/util/NoGrad.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.util 2 | 3 | import com.microsoft.scalatorch.torch.internal.GradMode 4 | 5 | /** Like PyTorch's no_grad. Sets a threadlocal variable, then unsets at the end (unless the bit was 6 | * already enabled at invocation). 7 | * 8 | * Useful for directly manipulating parameters. Generally you should avoid this unless you know what you're doing. 9 | * https://datascience.stackexchange.com/questions/32651/what-is-the-use-of-torch-no-grad-in-pytorch 10 | */ 11 | object NoGrad { 12 | def noGrad[T](body: => T): T = { 13 | val wasEnabled = GradMode.is_enabled() 14 | GradMode.set_enabled(false) 15 | 16 | val x = 17 | try { 18 | body 19 | } finally { 20 | if (GradMode.is_enabled()) { 21 | throw new IllegalStateException("Inconsistent state with GradMode, someone turned it on that shouldn't have.") 22 | } 23 | GradMode.set_enabled(wasEnabled) 24 | } 25 | x 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /scala-torch/src/main/scala/com/microsoft/scalatorch/torch/util/Profiler.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.util 2 | 3 | import com.microsoft.scalatorch.torch.internal.RecordProfile 4 | 5 | /** Profiles `body` and dumps the output to `file`. File can be viewed in chrome://tracing. 6 | * 7 | * Copied from profiler.h: 8 | * NOTE: changing profiler modes is **NOT THREAD SAFE**. You should ensure that 9 | * there no autograd functions are being executed when these function are used. 
10 | */ 11 | object Profiler { 12 | def profile[T](file: String)(body: => T): T = { 13 | resource.makeManagedResource(new RecordProfile(file))(_.delete())(List.empty).apply(_ => body) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /scala-torch/src/test/resources/com/microsoft/scalatorch/torch/jit/simple_trace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MyModule(torch.nn.Module): 4 | def __init__(self, N, M): 5 | super(MyModule, self).__init__() 6 | self.weight = torch.nn.Parameter(torch.ones(N, M)) 7 | 8 | def forward(self, input): 9 | return self.weight.mv(input) 10 | 11 | @torch.jit.export 12 | def foo(self, a: float): return a + 4 13 | 14 | my_module = MyModule(3,4) 15 | sm = torch.jit.script(my_module) 16 | sm.save("traced_model.pt") 17 | -------------------------------------------------------------------------------- /scala-torch/src/test/resources/com/microsoft/scalatorch/torch/jit/traced_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/scala_torch/9091473e3c19ccc542baff12523d024e440dfa32/scala-torch/src/test/resources/com/microsoft/scalatorch/torch/jit/traced_model.pt -------------------------------------------------------------------------------- /scala-torch/src/test/scala/com/microsoft/scalatorch/torch/TensorTest.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch 4 | import com.microsoft.scalatorch.torch.syntax.{ ---, _ } 5 | import org.scalatest._ 6 | 7 | class TensorTest extends FunSpec with BeforeAndAfterAll { 8 | import TensorTestUtils._ 9 | 10 | implicit val rm = new ReferenceManager 11 | 12 | override def afterAll(): Unit = { 13 | rm.close() 14 | } 15 | 16 | describe("create") { 17 | it("can create from different types of scalars") { 18 | val a = Tensor.fromLongs(3, 4, 5) 19 | assert(a.dtype == torch.long) 20 | 21 | val b = Tensor(3.0f, 4.0f, 5.0f) 22 | assert(b.dtype == torch.float) 23 | } 24 | 25 | } 26 | 27 | describe("tensor functions") { 28 | val e1 = Tensor(1) 29 | val e2 = Tensor(2) 30 | val e3 = Tensor(3) 31 | 32 | it("arithmetic") { 33 | assertIsClose((-e1).toFloat, -1f) 34 | assertIsClose((e1 + e2).toFloat, 3f) 35 | assertIsClose((e1 + 10).toFloat, 11f) 36 | assertIsClose((10 + e1).toFloat, 11f) 37 | assertIsClose((e2 - e1).toFloat, 1f) 38 | assertIsClose((e2 - 10).toFloat, -8f) 39 | //isClose(( 10 - e2).toFloat,8f) 40 | assertIsClose((e1 * e2).toFloat, 2f) 41 | assertIsClose((10 * e2).toFloat, 20f) 42 | assertIsClose((e2 * 10).toFloat, 20f) 43 | assertIsClose((e2 / 10).toFloat, 0.2f) 44 | } 45 | 46 | // affine transform 47 | describe("affine transform") { 48 | it("affine transform") { 49 | assertIsClose(nn.functional.linear(e1, e2.reshape(Size(1, 1)), e3).toFloat, 5) 50 | } 51 | } 52 | 53 | it("pow") { 54 | val sqrt2 = torch.sqrt(e2) 55 | assertIsClose((sqrt2 * sqrt2).toFloat, 2) 56 | 57 | assertIsClose(torch.pow(e2, e3).toFloat, 8) 58 | assertIsClose(torch.pow(e3, e2).toFloat, 9) 59 | } 60 | 61 | it("min/max") { 62 | assertIsClose(torch.min(e1, e3).toFloat, 1) 63 | assertIsClose(torch.max(e1, e3).toFloat, 3) 64 | } 65 | 66 | it("function that returns tuple") { 67 | val t = Tensor.fromFloatArray(Array(3f, 4f, 1f, 2f), Size(2, 2)) 68 | val (max1, index1) = torch.max(t, 1) 69 | assert(max1.toArray[Float].toSeq == Seq(4f, 2f)) 70 | 
assert(index1.toArray[Long].toSeq == Seq(1L, 1L)) 71 | assert(max1.shape == Size(2)) 72 | 73 | val (max2, index2) = torch.max(t, 0) 74 | assert(max2.toArray[Float].toSeq == Seq(3f, 4f)) 75 | assert(index2.toArray[Long].toSeq == Seq(0L, 0L)) 76 | 77 | val (max3, index3) = torch.max(t, 1, keepdim = true) 78 | assert(max3.toArray[Float].toSeq == Seq(4f, 2f)) 79 | assert(index3.toArray[Long].toSeq == Seq(1L, 1L)) 80 | assert(max3.shape == Size(2, 1)) 81 | } 82 | 83 | // TODO: write more tests 84 | } 85 | 86 | describe("concatenate") { 87 | it("fail gracefully with empty list") { 88 | ReferenceManager.forBlock { implicit rm => 89 | assertThrows[RuntimeException] { 90 | val foo = torch.cat(Array()) 91 | } 92 | } 93 | } 94 | } 95 | 96 | it("transpose should work") { 97 | ReferenceManager.forBlock { implicit rm => 98 | val e1 = Tensor.zeros(Size(10, 1)) 99 | val transposed = torch.transpose(e1, 0, 1) 100 | assert(transposed.shape == Size(1, 10)) 101 | 102 | val e2 = Tensor.zeros(Size(1, 10)) 103 | val transposed2 = torch.transpose(e2, 0, 1) 104 | assert(transposed2.shape == Size(10, 1)) 105 | } 106 | } 107 | 108 | it("lists of tensors should get converted to vectors") { 109 | def sum(tensors: Array[Tensor]): Tensor = { 110 | if (tensors.isEmpty) Tensor.zeros(Size()) else torch.stack(tensors).sum() 111 | } 112 | ReferenceManager.forBlock { implicit rm => 113 | val exprs = for (i <- 1 to 100) yield Tensor(i.toFloat) 114 | 115 | val sums = for (i <- 1 to 50) yield sum(exprs.toArray) 116 | val expected = (1 to 100).sum 117 | 118 | sums.foreach(s => assertIsClose(s.toFloat, expected)) 119 | 120 | val uberSum = sum((for { 121 | _ <- 1 to 1000 122 | i1 = scala.util.Random.nextInt(100) 123 | i2 = scala.util.Random.nextInt(100) 124 | i3 = scala.util.Random.nextInt(100) 125 | } yield sum(Array(exprs(i1), exprs(i2), exprs(i3)))).toArray) 126 | val value = uberSum.toFloat 127 | assert(value > 30f * 1000 * 3) 128 | assert(value < 70f * 1000 * 3) 129 | } 130 | } 131 | 132 | it("empty") { 133 | // This test fails if you the native code calls at::empty instead of torch::empty 134 | assert(Tensor.empty(Size(10), TensorOptions(requires_grad = true)).shape == Size(10)) 135 | assert(torch.empty(Size(10), layout = Layout.Sparse).layout == Layout.Sparse) 136 | } 137 | 138 | it("normal") { 139 | // This one requires special treatment in swig so we give it a special test 140 | assert(torch.normal(1, 1, Array(1, 1), dtype = torch.float16).dtype == torch.float16) 141 | } 142 | 143 | it("sum should basically function") { 144 | ReferenceManager.forBlock { implicit rm => 145 | val expr = Tensor(Array.range(0, 100).map(_.toFloat): _*) 146 | val total = expr.sum() 147 | 148 | assertIsClose(total.toFloat, (0 until 100).map(_.toFloat).sum) 149 | } 150 | } 151 | 152 | it("index") { 153 | // index is special because it does the sketchy stuff with c10::List> 154 | ReferenceManager.forBlock { implicit rm => 155 | val expr = Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat), Size(10, 10)) 156 | val indexedOnce = expr.index(Array(Some(Tensor.fromLongArray(Array(0L, 3L), Size(2))), None)) 157 | assert( 158 | indexedOnce == Tensor 159 | .fromFloatArray(Array.range(0, 10).map(_.toFloat) ++ Array.range(30, 40).map(_.toFloat), Size(2, 10)), 160 | ) 161 | val indexedTwice = expr.index( 162 | Array(Some(Tensor.fromLongArray(Array(0L, 3L), Size(2))), Some(Tensor.fromLongArray(Array(1L, 4L), Size(2)))), 163 | ) 164 | assert(indexedTwice == Tensor(1f, 34f)) 165 | 166 | // This crashes, but I think it's a bug in Torch. 
167 | // assert(expr.index(Array(None, None)) == expr) 168 | } 169 | } 170 | 171 | describe("indexing and slicing") { 172 | it("single coords") { 173 | ReferenceManager.forBlock { implicit rm => 174 | val expr = Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat), Size(10, 10)) 175 | assert(expr(1, 1).toFloat == 11) 176 | expr(1, 1) = Tensor(-11) 177 | assert(expr(1, 1).toFloat == -11) 178 | expr(::) = torch.zeros(Array(10, 10)) 179 | assert(expr == torch.zeros(Array(10, 10))) 180 | } 181 | } 182 | it("other slicers") { 183 | ReferenceManager.forBlock { implicit rm => 184 | val expr = Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat), Size(10, 10)) 185 | assert(expr(1, None) == Tensor.fromFloatArray(Array.range(10, 20).map(_.toFloat), Size(10))) 186 | assert(expr(None, 1) == Tensor.fromFloatArray(Array.range(0, 10).map(x => x * 10f + 1), Size(10))) 187 | assert(expr(::, 1) == Tensor.fromFloatArray(Array.range(0, 10).map(x => x * 10f + 1), Size(10))) 188 | assert(expr(1, 5.::) == Tensor.fromFloatArray(Array.range(15, 20).map(_.toFloat), Size(5))) 189 | // If you really want the python syntax 190 | import scala.language.postfixOps 191 | assert(expr(1, 5 ::) == Tensor.fromFloatArray(Array.range(15, 20).map(_.toFloat), Size(5))) 192 | // format: off 193 | assert(expr(1, 5::) == Tensor.fromFloatArray(Array.range(15, 20).map(_.toFloat), Size(5))) 194 | // format: on 195 | 196 | assert(expr(1, 5 -> 8) == Tensor(15f, 16f, 17f)) 197 | assert(expr(1, 5 -> -2) == Tensor(15f, 16f, 17f)) 198 | assert(expr(1, 5 -> 8 -> 2) == Tensor(15f, 17f)) 199 | assert(expr(1, 5 :: 2) == Tensor(15f, 17f, 19f)) 200 | 201 | val expr3 = Tensor.fromFloatArray(Array.range(0, 1000).map(_.toFloat), Size(10, 10, 10)) 202 | assert(expr3(1, ---) == Tensor.fromFloatArray(Array.range(100, 200).map(_.toFloat), Size(10, 10))) 203 | assert(expr3(---, 1) == Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat).map(_ * 10 + 1), Size(10, 10))) 204 | } 205 | } 206 | } 207 | 208 | it("torch.tensor") { 209 | assert(torch.tensor(1) == Tensor.fromIntArray(Array(1), Size())) 210 | assert(torch.tensor(Array(1)) == Tensor.fromIntArray(Array(1), Size(1))) 211 | assert(torch.tensor(Array(1, 2, 3)) == Tensor.fromIntArray(Array(1, 2, 3), Size(3))) 212 | assert(torch.tensor(Array(Array(1, 2), Array(3, 4))) == Tensor.fromIntArray(Array(1, 2, 3, 4), Size(2, 2))) 213 | assert( 214 | torch.tensor(Array(Array(Array(1, 2), Array(3, 4)), Array(Array(1, 2), Array(3, 4)))) == Tensor 215 | .fromIntArray(Array(1, 2, 3, 4, 1, 2, 3, 4), Size(2, 2, 2)), 216 | ) 217 | 218 | assert( 219 | torch.tensor(Array(Array(Array(1, 2, 3), Array(3, 4, 5)), Array(Array(1, 2, 3), Array(3, 4, 5)))) == Tensor 220 | .fromIntArray(Array(1, 2, 3, 3, 4, 5, 1, 2, 3, 3, 4, 5), Size(2, 2, 3)), 221 | ) 222 | 223 | assertThrows[IllegalArgumentException](torch.tensor(Array(Array(1, 2), Array(3)))) 224 | assertThrows[IllegalArgumentException](torch.tensor(Array.empty[Int])) 225 | 226 | assert(torch.tensor(Array(1f, 2f, 3f)) == Tensor.fromFloatArray(Array(1f, 2f, 3f), Size(3))) 227 | assert(torch.tensor(Array(1.0, 2.0, 3.0)) == Tensor.fromDoubleArray(Array(1.0, 2.0, 3.0), Size(3))) 228 | assert(torch.tensor(Array(1L, 2L, 3L)) == Tensor.fromLongArray(Array(1L, 2L, 3L), Size(3))) 229 | assert(torch.tensor(Array[Byte](1, 2, 3)) == Tensor.fromByteArray(Array[Byte](1, 2, 3), Size(3))) 230 | 231 | } 232 | 233 | } 234 | 235 | object TensorTestUtils extends Assertions { 236 | private val DefaultEpsilon = 1e-4f 237 | 238 | def assertIsClose(value: Float, expected: Float, epsilon: 
Float = DefaultEpsilon): Assertion = { 239 | assert((value - expected).abs <= epsilon) 240 | } 241 | } 242 | -------------------------------------------------------------------------------- /scala-torch/src/test/scala/com/microsoft/scalatorch/torch/TorchFunctionTest.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch 2 | 3 | import com.microsoft.scalatorch.torch.syntax._ 4 | import org.scalatest.FunSpec 5 | 6 | class TorchFunctionTest extends FunSpec { 7 | 8 | describe("loss functions") { 9 | it("should accept long as second arg") { 10 | ReferenceManager.forBlock { implicit rm => 11 | assert( 12 | nn.functional.nll_loss(Tensor(-1.0f, -3.0f, -4.0f).reshape(Size(1, 3)), Tensor.fromLongs(1)).toFloat == 3.0f, 13 | ) 14 | } 15 | } 16 | 17 | it("should accept Reduction") { 18 | ReferenceManager.forBlock { implicit rm => 19 | assert( 20 | nn.functional 21 | .poisson_nll_loss(Tensor(1.0f, 3.0f, 4.0f).reshape(Size(1, 3)), Tensor.fromLongs(1)) 22 | .toFloat == 1.8383645f, 23 | ) 24 | assert( 25 | nn.functional 26 | .binary_cross_entropy( 27 | Tensor(0f, 0.5f, 1f).reshape(Size(1, 3)), 28 | Tensor.ones(Size(1)), 29 | reduction = Reduction.Mean, 30 | ) 31 | .toFloat == 33.56438f, 32 | ) 33 | } 34 | } 35 | } 36 | 37 | describe("packages") { 38 | it("linear") { 39 | ReferenceManager.forBlock { implicit rm => 40 | TensorTestUtils.assertIsClose(linalg.norm(Tensor(1.0f, 3.0f, 4.0f)).toFloat, 5.09902f) 41 | } 42 | } 43 | it("special") { 44 | ReferenceManager.forBlock { implicit rm => 45 | assert(special.log1p(Tensor(1.0f, 3.0f, 4.0f)).toArray[Float].toSeq == Seq(0.6931472f, 1.3862944f, 1.609438f)) 46 | } 47 | } 48 | 49 | it("fft") { 50 | ReferenceManager.forBlock { implicit rm => 51 | assert(fft.fft(Tensor(1.0f, 3.0f, 4.0f)).apply(0).real.toFloat == 8.0000f) 52 | } 53 | } 54 | } 55 | 56 | } 57 | -------------------------------------------------------------------------------- /scala-torch/src/test/scala/com/microsoft/scalatorch/torch/jit/ModuleTest.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.jit 2 | 3 | import java.io.File 4 | 5 | import com.microsoft.scalatorch.torch._ 6 | import org.scalatest.FunSpec 7 | 8 | class ModuleTest extends FunSpec { 9 | describe("Module") { 10 | lazy val f = new File(getClass.getResource("traced_model.pt").toURI) 11 | it("should load") { 12 | ReferenceManager.forBlock { implicit rm => 13 | val m = Module.load(f) 14 | val t = Tensor(3.0f, 4.0f, 5.0f, 6.0f) 15 | val r = m.forward(Seq(t)) 16 | // the actual script generates a 3x4 ones matrix 17 | val o = Tensor.ones(Size(3, 4)) 18 | assert(r.asTensor == (o.matmul(t))) 19 | } 20 | 21 | } 22 | 23 | it("should allow us to invoke other methods") { 24 | ReferenceManager.forBlock { implicit rm => 25 | val m = Module.load(f) 26 | val r = m.invoke("foo", 4.0) 27 | assert(r.asDouble == 8) 28 | } 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /scala-torch/src/test/scala/com/microsoft/scalatorch/torch/tutorial/PyTorchOrgTensorTutorialTest.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.tutorial 2 | 3 | import com.microsoft.scalatorch.torch 4 | import com.microsoft.scalatorch.torch.ReferenceManager 5 | import com.microsoft.scalatorch.torch.syntax._ 6 | 7 | /** Tries to follow 
https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html 8 | * 9 | * See also [[TutorialTest]] in this same directory. 10 | */ 11 | object PyTorchOrgTensorTutorial { 12 | def main(args: Array[String]): Unit = { 13 | ReferenceManager.forBlock { implicit rm => 14 | val data = $($(1, 2), $(3, 4)) 15 | val x_data = torch.tensor(data) 16 | assert(x_data.sum() == torch.tensor(10L)) 17 | 18 | val x_ones = torch.ones_like(x_data) 19 | // retains the properties of x_data 20 | println(s"Ones Tensor: \n ${x_ones} \n") 21 | 22 | val x_rand = torch.rand_like(x_data, dtype = torch.float) 23 | // overrides the datatype of x_data 24 | println(s"Random Tensor: \n ${x_rand} \n") 25 | 26 | val shape = (2, 3) 27 | val rand_tensor = torch.rand(shape) 28 | val ones_tensor = torch.ones(shape) 29 | val zeros_tensor = torch.zeros(shape) 30 | 31 | println(s"Random Tensor: \n ${rand_tensor} \n") 32 | println(s"Ones Tensor: \n ${ones_tensor} \n") 33 | println(s"Zeros Tensor: \n ${zeros_tensor}") 34 | 35 | var tensor = torch.rand($(3, 4)) 36 | 37 | println(s"Shape of tensor: ${tensor.shape}") 38 | println(s"Datatype of tensor: ${tensor.dtype}") 39 | println(s"Device tensor is stored on: ${tensor.device}") 40 | 41 | if (torch.cuda.is_available()) 42 | tensor = tensor.to(device = "cuda") 43 | 44 | // Standard numpy-like indexing and slicing 45 | tensor = torch.ones($(4, 4)) 46 | println(s"First row: ${tensor(0)}") 47 | println(s"First column: ${tensor(::, 0)}") 48 | println(s"Last column: ${tensor(---, -1)}") 49 | tensor(::, 1) = 0 50 | println(tensor) 51 | 52 | val t1 = torch.cat($(tensor, tensor, tensor), dim = 1) 53 | println(t1) 54 | 55 | // Arithmetic operations 56 | 57 | // This computes the matrix multiplication between two tensors 58 | // y1, y2, y3 will have the same value 59 | val y1 = tensor *@* tensor.T 60 | val y2 = tensor.matmul(tensor.T) 61 | 62 | val y3 = torch.rand_like(y1) 63 | // out params not supported yet 64 | // torch.matmul(tensor, tensor.T, out = y3) 65 | val out = torch.matmul(tensor, tensor.T) 66 | 67 | // This computes the element -wise product 68 | // z1, z2, z3 will have the same value 69 | val z1 = tensor * tensor 70 | val z2 = tensor.mul(tensor) 71 | 72 | val z3 = torch.rand_like(tensor) 73 | // out params not supported yet 74 | // torch.mul(tensor, tensor, out = z3) 75 | val out2 = torch.mul(tensor, tensor) 76 | 77 | val agg = tensor.sum() 78 | val agg_item = agg.item() 79 | println(s"${agg_item}, ${agg_item.getClass}") 80 | 81 | // In-place operations 82 | println(s"${tensor} \n") 83 | tensor.add_(5) 84 | println(tensor) 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /scala-torch/src/test/scala/com/microsoft/scalatorch/torch/tutorial/PytorchOrgPolynomialNetwork.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.tutorial 2 | import com.microsoft.scalatorch.torch 3 | import com.microsoft.scalatorch.torch.{ Reduction, ReferenceManager, Tensor } 4 | import com.microsoft.scalatorch.torch.syntax._ 5 | 6 | /** Follows the example at https://pytorch.org/tutorials/beginner/pytorch_with_examples.html 7 | * Same as [[PytorchOrgPolynomialNetworkOptimTutorial]], but uses manual updates instead of a 8 | * [[com.microsoft.scalatorch.torch.optim.Optimizer]]. 9 | */ 10 | object PytorchOrgPolynomialNetworkTutorial { 11 | def main(args: Array[String]): Unit = ReferenceManager.forBlock { implicit rm => 12 | // Create Tensors to hold input and outputs. 
13 | val x = torch.linspace(-Math.PI, Math.PI, Some(2000)) 14 | val y = torch.sin(x) 15 | 16 | // For this example, the output y is a linear function of(x, x ^ 2, x ^ 3), so 17 | // we can consider it as a linear layer neural network 18 | // Let 's prepare the tensor(x, x ^ 2, x ^ 3). 19 | val p = torch.tensor($(1, 2, 3)) 20 | val xx = x.unsqueeze(-1).pow(p) 21 | 22 | // In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape 23 | // (3,), for this case, broadcasting semantics will apply to obtain a tensor 24 | // of shape (2000, 3) 25 | // 26 | // Here we depart from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html because 27 | // we have not (yet) made a wrapper for Modules. We manage the weight and bias parameters manually. 28 | val weight = torch.ones($(1, 3)) 29 | val bias = torch.ones($(1)) 30 | weight.requires_grad_(true) 31 | bias.requires_grad_(true) 32 | val parameters = Seq(weight, bias) 33 | val model = (x: torch.Tensor) => { 34 | torch.flatten(torch.nn.functional.linear(x, weight, Some(bias)), 0, 1) 35 | } 36 | 37 | // The nn package also contains definitions of popular loss functions; in this 38 | // case we will use Mean Squared Error (MSE) as our loss function. 39 | val loss_fn = torch.nn.functional.mse_loss(_, _, reduction = Reduction.Sum) 40 | 41 | val learning_rate = 1e-6 42 | for (t <- 0 until 2000) { 43 | 44 | // Forward pass: compute predicted y by passing x to the model. Module objects 45 | // override the __call__ operator so you can call them like functions. When 46 | // doing so you pass a Tensor of input data to the Module and it produces 47 | // a Tensor of output data. 48 | val y_pred = model(xx) 49 | // Compute and print loss. We pass Tensors containing the predicted and true 50 | // values of y, and the loss function returns a Tensor containing the 51 | // loss. 52 | val loss = loss_fn(y_pred, y) 53 | if (t % 100 == 99) 54 | println(s"iteration=$t loss=${loss.item()}") 55 | 56 | // Zero the gradients before running the backward pass. 57 | torch.no_grad { 58 | parameters.foreach(p => if (p.grad.defined()) p.grad.zero_()) 59 | } 60 | 61 | // Backward pass: compute gradient of the loss with respect to all the learnable 62 | // parameters of the model. Internally, the parameters of each Module are stored 63 | // in Tensors with requires_grad=True, so this call will compute gradients for 64 | // all learnable parameters in the model. 65 | loss.backward() 66 | // Update the weights using gradient descent. Each parameter is a Tensor, so 67 | // we can access its gradients like we did before. 
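// (In symbols, this is plain SGD: param <- param - learning_rate * d(loss)/d(param). The update happens in place
// inside torch.no_grad so that the parameter write itself is not recorded by autograd; only the forward pass and
// loss.backward() participate in gradient tracking.)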
68 | torch.no_grad { 69 | for (param <- parameters) { 70 | param -= learning_rate * param.grad 71 | } 72 | } 73 | 74 | println( 75 | s"Result: y = ${bias.item()} + ${weight(::, 0).item()} x + ${weight(::, 1).item()} x^2 + ${weight(::, 2).item()} x^3", 76 | ) 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /scala-torch/src/test/scala/com/microsoft/scalatorch/torch/tutorial/PytorchOrgPolynomialNetworkOptim.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.tutorial 2 | 3 | import com.microsoft.scalatorch.torch 4 | import com.microsoft.scalatorch.torch.optim.SGD 5 | import com.microsoft.scalatorch.torch.{ Reduction, ReferenceManager, Tensor } 6 | import com.microsoft.scalatorch.torch.syntax._ 7 | 8 | /** Follows the example at https://pytorch.org/tutorials/beginner/pytorch_with_examples.html 9 | * Same as [[PytorchOrgPolynomialNetworkTutorial]], but uses an [[com.microsoft.scalatorch.torch.optim.Optimizer]] 10 | * instead of manual updates */ 11 | object PytorchOrgPolynomialNetworkOptimTutorial { 12 | def main(args: Array[String]): Unit = ReferenceManager.forBlock { implicit rm => 13 | // Create Tensors to hold input and outputs. 14 | val x = torch.linspace(-Math.PI, Math.PI, Some(2000)) 15 | val y = torch.sin(x) 16 | 17 | // For this example, the output y is a linear function of(x, x ^ 2, x ^ 3), so 18 | // we can consider it as a linear layer neural network 19 | // Let 's prepare the tensor(x, x ^ 2, x ^ 3). 20 | val p = torch.tensor($(1, 2, 3)) 21 | val xx = x.unsqueeze(-1).pow(p) 22 | 23 | // In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape 24 | // (3,), for this case, broadcasting semantics will apply to obtain a tensor 25 | // of shape (2000, 3) 26 | // 27 | // Here we depart from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html because 28 | // we have not (yet) made a wrapper for Modules. We manage the weight and bias parameters manually. 29 | val weight = torch.ones($(1, 3)) 30 | val bias = torch.ones($(1)) 31 | weight.requires_grad_(true) 32 | bias.requires_grad_(true) 33 | val parameters = Seq(weight, bias) 34 | val model = (x: torch.Tensor) => { 35 | torch.flatten(torch.nn.functional.linear(x, weight, Some(bias)), 0, 1) 36 | } 37 | 38 | // The nn package also contains definitions of popular loss functions; in this 39 | // case we will use Mean Squared Error (MSE) as our loss function. 40 | val loss_fn = torch.nn.functional.mse_loss(_, _, reduction = Reduction.Sum) 41 | val optim = SGD(parameters, SGD.Options(learningRate = 1e-6)) 42 | for (t <- 0 until 2000) { 43 | 44 | // Forward pass: compute predicted y by passing x to the model. Module objects 45 | // override the __call__ operator so you can call them like functions. When 46 | // doing so you pass a Tensor of input data to the Module and it produces 47 | // a Tensor of output data. 48 | val y_pred = model(xx) 49 | // Compute and print loss. We pass Tensors containing the predicted and true 50 | // values of y, and the loss function returns a Tensor containing the 51 | // loss. 52 | val loss = loss_fn(y_pred, y) 53 | if (t % 100 == 99) 54 | println(s"iteration=$t loss=${loss.item()}") 55 | 56 | optim.zeroGrad() 57 | 58 | // Backward pass: compute gradient of the loss with respect to all the learnable 59 | // parameters of the model. 
Internally, the parameters of each Module are stored 60 | // in Tensors with requires_grad=True, so this call will compute gradients for 61 | // all learnable parameters in the model. 62 | loss.backward() 63 | // Update the weights using gradient descent. Each parameter is a Tensor, so 64 | // we can access its gradients like we did before. 65 | optim.step() 66 | 67 | println( 68 | s"Result: y = ${bias.item()} + ${weight(::, 0).item()} x + ${weight(::, 1).item()} x^2 + ${weight(::, 2).item()} x^3", 69 | ) 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /scala-torch/src/test/scala/com/microsoft/scalatorch/torch/tutorial/TutorialTest.scala: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.tutorial 2 | 3 | import com.microsoft.scalatorch.torch.ReferenceManager 4 | 5 | class TutorialTest extends org.scalatest.funspec.AnyFunSpec { 6 | 7 | it("functions with simple calls ") { 8 | import com.microsoft.scalatorch.torch 9 | // There should always be a ReferenceManager in implicit scope for memory management. 10 | ReferenceManager.forBlock { implicit rm => 11 | val tensor: torch.Tensor = torch.eye(2) 12 | assert(tensor.numel() == 4) 13 | assert(tensor.size() == torch.Size(2, 2)) 14 | 15 | // OOP style 16 | assert(tensor.sum().toFloat == 2f) 17 | // Static function style 18 | assert(torch.sum(tensor).toFloat == 2f) 19 | 20 | // Unfortunately, Scala does not allow multiple overloads with default values, so for some overloads, 21 | // you must redundantly specify defaults. 22 | assert(tensor.sum(dim = Array(1L), false, None) == torch.ones(Array(2L))) 23 | } 24 | } 25 | 26 | describe("Python-like slicing syntax") { 27 | it("works for reads") { 28 | import com.microsoft.scalatorch.torch 29 | import com.microsoft.scalatorch.torch.syntax._ 30 | ReferenceManager.forBlock { implicit rm => 31 | val expr = torch.Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat), torch.Size(10, 10)) 32 | assert(expr(1, 1).toFloat == 11) 33 | 34 | assert(expr(1, None) == torch.Tensor.fromFloatArray(Array.range(10, 20).map(_.toFloat), torch.Size(10))) 35 | assert(expr(None, 1) == torch.Tensor.fromFloatArray(Array.range(0, 10).map(x => x * 10f + 1), torch.Size(10))) 36 | assert(expr(::, 1) == torch.Tensor.fromFloatArray(Array.range(0, 10).map(x => x * 10f + 1), torch.Size(10))) 37 | 38 | assert(expr(1, 5 -> 8) == torch.Tensor(15f, 16f, 17f)) 39 | assert(expr(1, 5 -> -2) == torch.Tensor(15f, 16f, 17f)) 40 | assert(expr(1, 5 -> 8 -> 2) == torch.Tensor(15f, 17f)) 41 | 42 | assert(expr(1, 5 :: 2) == torch.Tensor(15f, 17f, 19f)) 43 | 44 | val expr2 = torch.Tensor.fromFloatArray(Array.range(0, 1000).map(_.toFloat), torch.Size(10, 10, 10)) 45 | assert(expr2(1, ---) == torch.Tensor.fromFloatArray(Array.range(100, 200).map(_.toFloat), torch.Size(10, 10))) 46 | assert( 47 | expr2(---, 1) == torch.Tensor 48 | .fromFloatArray(Array.range(0, 100).map(_.toFloat).map(_ * 10 + 1), torch.Size(10, 10)), 49 | ) 50 | } 51 | } 52 | 53 | it("works for writes") { 54 | import com.microsoft.scalatorch.torch 55 | import com.microsoft.scalatorch.torch.syntax._ 56 | ReferenceManager.forBlock { implicit rm => 57 | val expr = torch.Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat), torch.Size(10, 10)) 58 | assert(expr(1, 1).toFloat == 11) 59 | expr(1, 1) = torch.Tensor(-11) 60 | assert(expr(1, 1).toFloat == -11) 61 | expr(::) = torch.zeros(Array(10, 10)) 62 | assert(expr == torch.zeros(Array(10, 10))) 63 | } 64 | } 
65 | } 66 | it("is possible with dot postfix") { 67 | import com.microsoft.scalatorch.torch 68 | import com.microsoft.scalatorch.torch.syntax._ 69 | ReferenceManager.forBlock { implicit rm => 70 | val expr = torch.Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat), torch.Size(10, 10)) 71 | // You can write `5 ::` instead of `5.::` if you enable postfixOps (see below) 72 | assert(expr(1, 5.::) == torch.Tensor.fromFloatArray(Array.range(15, 20).map(_.toFloat), torch.Size(5))) 73 | } 74 | } 75 | it("is possible with space postfix") { 76 | import com.microsoft.scalatorch.torch 77 | import com.microsoft.scalatorch.torch.syntax._ 78 | // If you really want the python syntax 79 | import scala.language.postfixOps 80 | ReferenceManager.forBlock { implicit rm => 81 | val expr = torch.Tensor.fromFloatArray(Array.range(0, 100).map(_.toFloat), torch.Size(10, 10)) 82 | // format: off 83 | assert(expr(1, 5 ::) == torch.Tensor.fromFloatArray(Array.range(15, 20).map(_.toFloat), torch.Size(5))) 84 | // format: on 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /swig/build.sbt: -------------------------------------------------------------------------------- 1 | name := "scala-torch-swig" 2 | 3 | libraryDependencies += "com.github.fommil" % "jniloader" % "1.1" 4 | 5 | // for some reason, SBT passes javacOptions to javadoc, but javadoc doesn't understand -target, so: 6 | javacOptions in (Compile, doc) := Seq.empty 7 | 8 | val packageName = "com.microsoft.scalatorch.torch.internal" 9 | 10 | val torchDir = baseDirectory { d => Option(System.getenv("TORCH_DIR")).map(file).getOrElse(d / "../libtorch") } 11 | 12 | val includeDirs = torchDir { d => Seq(d / "include", d / "include/torch/csrc/api/include/") } 13 | 14 | // Swig stuff: generate into managed sources, try to be a good SBT citizen 15 | val Swig = config("swig") 16 | 17 | val generate = TaskKey[Seq[File]]("generate") 18 | 19 | inConfig(Swig)(Defaults.paths) 20 | 21 | // Output for this config is version-independent 22 | Swig / crossTarget := (Swig / target).value 23 | 24 | Swig / sourceDirectory := (sourceDirectory in Compile).value / "swig" 25 | 26 | Compile / sourceManaged := (Swig / sourceManaged).value 27 | 28 | Swig / generate := { 29 | val tgt = (Swig / sourceManaged).value 30 | tgt.mkdirs() 31 | 32 | // the cxx file will go here 33 | val native = tgt / "native" 34 | val include = includeDirs.value 35 | native.mkdirs() 36 | 37 | val out = streams.value 38 | val bindingGenScript = (sourceDirectory in Swig).value / "bindgen.py" 39 | val declarationsFile = bindingGenScript 40 | .getParentFile() / ".." / ".." / ".." / ".." 
/ "pytorch" / "torch" / "share" / "ATen" / "Declarations.yaml" 41 | 42 | val cachedGen = FileFunction.cached( 43 | out.cacheDirectory / "bindgen", 44 | inStyle = FilesInfo.lastModified, 45 | outStyle = FilesInfo.exists, 46 | ) { (in: Set[File]) => 47 | assert(in.contains(bindingGenScript), s"$bindingGenScript should be in ${in}") 48 | assert(in.contains(declarationsFile), s"$declarationsFile should be in ${in}") 49 | genSwigAndScala(bindingGenScript, declarationsFile, tgt, out.log) 50 | } 51 | val genned = cachedGen(((sourceDirectory in Swig).value ** ("*.py")).get.toSet + bindingGenScript + declarationsFile) 52 | 53 | val cachedRunSwig = FileFunction.cached( 54 | out.cacheDirectory / "swig", 55 | inStyle = FilesInfo.lastModified, 56 | outStyle = FilesInfo.exists, 57 | ) { (in: Set[File]) => 58 | val theMainFile = in.find(_.getName == "torch_swig.i").get 59 | runSwig(theMainFile, tgt, native, include, out.log) 60 | } 61 | cachedRunSwig(genned).toSeq 62 | } 63 | 64 | /** This is pretty convoluted right now. This code calls bindgen.py, which generates swig declarations, which are turned 65 | * by the `swig` command into both a .cxx file and Java bindings. This code also generates 66 | * Scala wrappers for those Java bindings. Unfortunately, those Scala wrappers are defined 67 | * in a downstream project (scala-torch). In an ideal world, we would separately generate 68 | * the scala bindings in the scala-torch project, and also put the generate the scala files 69 | * under the [[sourceManaged]] directory. For now, out of laziness, we directly 70 | * read in .scala.in files from the downstream project and produce .scala files in that same project. We 71 | * have checked in the generated Scala files for now so the API is easy to see, but we proably shouldn't do that 72 | * in the long-run. 73 | */ 74 | def genSwigAndScala( 75 | bindingGenScript: File, 76 | declarationsFile: File, 77 | target: File, 78 | logger: Logger, 79 | ): Set[File] = { 80 | import scala.sys.process._ 81 | 82 | val realTarget = target / packageName.replace(".", "/") 83 | realTarget.mkdirs() 84 | 85 | val scalaDir = 86 | bindingGenScript 87 | .getParentFile() / ".." / ".." / ".." / ".." / "scala-torch" / "src" / "main" / "scala" / "com" / "microsoft" / "scalatorch" / "torch" 88 | 89 | val bindingGenCmd = 90 | s"""python3 $bindingGenScript $declarationsFile ${bindingGenScript.getParentFile()} $scalaDir""" 91 | 92 | logger.info("Generating auto-generated swig bindings") 93 | logger.info(bindingGenCmd) 94 | 95 | val pyErrorCode = bindingGenCmd ! 
logger 96 | if (pyErrorCode != 0) { 97 | sys.error(s"aborting generation of swig files because $bindingGenScript failed") 98 | } 99 | 100 | (bindingGenScript.getParentFile() ** ("*.i")).get.toSet 101 | } 102 | 103 | def runSwig( 104 | swigFile: File, 105 | target: File, 106 | nativeTarget: File, 107 | includeDirs: Seq[File], 108 | logger: Logger, 109 | ): Set[File] = { 110 | import scala.sys.process._ 111 | def stripExtension(file: File) = ext.matcher(file.getName).replaceAll("") 112 | 113 | val realTarget = target / packageName.replace(".", "/") 114 | realTarget.mkdirs() 115 | 116 | val cxx = s"${nativeTarget}/${stripExtension(swigFile)}.cxx" 117 | val totalInclude = includeDirs 118 | val includeI = totalInclude.mkString("-I", " -I", "") 119 | 120 | val cmd = 121 | s"""swig -DSWIGWORDSIZE64 -v -c++ -java -package $packageName $includeI -o $cxx -outdir $realTarget $swigFile""" 122 | 123 | logger.info(s"generating SWIG: ${swigFile}") 124 | logger.info(cmd) 125 | 126 | val errorCode = cmd ! logger 127 | if (errorCode != 0) { 128 | sys.error(s"aborting generation for $swigFile because swig was unhappy") 129 | } 130 | 131 | (target ** "*.java").get.toSet 132 | } 133 | 134 | unmanagedSourceDirectories in Compile += (sourceDirectory in Swig).value 135 | sourceGenerators in Compile += (generate in Swig).taskValue 136 | cleanFiles += (target in Swig).value 137 | 138 | import java.util.regex.Pattern 139 | 140 | val ext = Pattern.compile("(?<=.)\\.[^.]+$") 141 | 142 | // native compilation should wait on swig, since swig makes the cxx file 143 | nativeCompile := nativeCompile.dependsOn(generate in Swig).value 144 | 145 | // this is a bit of a hack to make the naming consistent with JniLoader 146 | resourceGenerators in Compile += Def.task { 147 | val libraries: Seq[(File, String)] = (nativeLibraries in Compile).value 148 | val resources: Seq[File] = for ((file, _) <- libraries) yield { 149 | 150 | val newName = file.getParentFile.getParentFile.getName match { 151 | case "x86_64-darwin" => "libtorch_swig.pred-osx-x86_64.jnilib" 152 | case "x86_64-linux" => "libtorch_swig.pred-linux-x86_64.so" 153 | // TODO: flesh this out 154 | } 155 | val resource = (resourceManaged in Compile).value / newName 156 | 157 | // copy native library to a managed resource, so that it is always available 158 | // on the classpath, even when not packaged as a jar 159 | IO.copyFile(file, resource) 160 | resource 161 | } 162 | resources 163 | }.taskValue 164 | -------------------------------------------------------------------------------- /swig/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.github.sbt" % "sbt-jni" % "1.5.3") 2 | -------------------------------------------------------------------------------- /swig/src/main/java/com/microsoft/scalatorch/torch/internal/Module2.java: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.internal; 2 | 3 | public interface Module2 { 4 | Output forward(Input1 input1, Input2 input2); 5 | } 6 | -------------------------------------------------------------------------------- /swig/src/main/java/com/microsoft/scalatorch/torch/internal/Module3.java: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.internal; 2 | 3 | public interface Module3 { 4 | Output forward(Input1 input1, Input2 input2, Input3 input3); 5 | } 6 | 
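The Module2 and Module3 interfaces above (and Module4 below) are plain single-method interfaces whose only member is an arity-N forward; their generic type parameters appear to have been stripped in this dump. As a rough, hypothetical sketch of how such an interface could be satisfied from Scala, assuming the missing type parameters are [Input1, ..., InputN, Output] in declaration order and that an implicit ReferenceManager is in scope as in the tests above (the class name and this usage are illustrative, not taken from the repo):

import com.microsoft.scalatorch.torch
import com.microsoft.scalatorch.torch.{ ReferenceManager, Tensor }
import com.microsoft.scalatorch.torch.internal.Module2

// Illustrative only: a two-input forward that stacks its arguments and sums them.
class SumOfPair(implicit rm: ReferenceManager) extends Module2[Tensor, Tensor, Tensor] {
  override def forward(a: Tensor, b: Tensor): Tensor = torch.stack(Array(a, b)).sum()
}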
-------------------------------------------------------------------------------- /swig/src/main/java/com/microsoft/scalatorch/torch/internal/Module4.java: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.internal; 2 | 3 | public interface Module4 { 4 | Output forward(Input1 input1, Input2 input2, Input3 input3, Input4 input4); 5 | } 6 | -------------------------------------------------------------------------------- /swig/src/main/java/com/microsoft/scalatorch/torch/internal/NativeLoader.java: -------------------------------------------------------------------------------- 1 | package com.microsoft.scalatorch.torch.internal; 2 | 3 | import java.io.File; 4 | import java.io.FileOutputStream; 5 | import java.io.IOException; 6 | import java.io.InputStream; 7 | import java.net.URL; 8 | import java.nio.channels.FileChannel; 9 | import java.nio.channels.Channels; 10 | import java.nio.channels.ReadableByteChannel; 11 | import java.nio.file.Path; 12 | 13 | /** Much like https://github.com/sbt/sbt-jni or https://github.com/scijava/native-lib-loader, 14 | * but both of them are a little too opinionated. */ 15 | public class NativeLoader { 16 | 17 | /** 18 | * Extracts the library given by the path elem +: elems. For example, 19 | * {{{extractAndLoadLibrary(Paths.get("tmp"), true, "native", "c10")}}} will look for a resource 20 | * called `native/libc10.dylib` on a Mac and `native/c10.dll` on Windows. 21 | * 22 | * @param dir The directory to extract the native library to. 23 | * @param isLibName If true, the last element of the path is treated as a library name and [[System.mapLibraryName]] 24 | * will be called on it first. Otherwise, it is taken as is, so it must have the correct 25 | * extensions (e.g. libfoo.so.1). 26 | * @return The extracted file, or null if no resource could be found. 27 | * @throws IOException Also might throw other RuntimeExceptions if library loading fails. 28 | */ 29 | public static File extractAndLoadLibrary(Path dir, boolean isLibName, String elem, String... elems) throws IOException { 30 | String maybeName = (elems.length == 0) ? elem : elems[elems.length - 1]; 31 | String libName = isLibName ?
System.mapLibraryName(maybeName) : maybeName; 32 | String[] resourcePathElems = new String[elems.length + 1]; 33 | if (elems.length > 0) { 34 | resourcePathElems[0] = elem; 35 | System.arraycopy(elems, 0, resourcePathElems, 1, elems.length - 1); 36 | } 37 | resourcePathElems[elems.length] = libName; 38 | String resourcePath = String.join("/", resourcePathElems); 39 | File result = extract(dir, resourcePath); 40 | if (result == null) return null; 41 | System.load(result.toString()); 42 | return result; 43 | } 44 | 45 | private static File extract(Path dir, String resourcePath) throws IOException { 46 | URL url = NativeLoader.class.getResource("/" + resourcePath); 47 | if (url == null) return null; 48 | 49 | try(InputStream in = NativeLoader.class.getResourceAsStream("/" + resourcePath)) { 50 | File file = file(dir, resourcePath); 51 | file.deleteOnExit(); 52 | 53 | ReadableByteChannel src = Channels.newChannel(in); 54 | try (FileChannel dest = new FileOutputStream(file).getChannel()) { 55 | dest.transferFrom(src, 0, Long.MAX_VALUE); 56 | 57 | return file; 58 | } 59 | } 60 | } 61 | 62 | private static File file(Path dir, String path) throws IOException { 63 | String name = new File(path).getName(); 64 | 65 | File file = dir.resolve(name).toFile(); 66 | if (file.exists() && !file.isFile()) 67 | throw new IllegalArgumentException(file.getAbsolutePath() + " is not a file."); 68 | if (!file.exists()) file.createNewFile(); 69 | return file; 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_array.i: -------------------------------------------------------------------------------- 1 | // Conversion to/from jvm arrays and tensors. creates methods {to,from}_{float}_array that we can use 2 | %define ARRAY_INPUT_OUTPUT(ctype, jnitype, javatype, jniarraytype, capType, torchScalarType) 3 | %native (to_##javatype##_array) jniarraytype to_##javatype##_array(const torch::Tensor& t); 4 | %native (from_##javatype##_array) torch::Tensor from_##javatype##_array(jniarraytype data, jlongArray shape, const TensorOptions& options={}); 5 | 6 | %{ 7 | 8 | // the two jni methods in this block are implementations of the two functions above 9 | extern "C" { 10 | SWIGEXPORT jlong JNICALL Java_com_microsoft_scalatorch_torch_internal_torch_1swigJNI_from_1##javatype##_1array(JNIEnv * env, jclass, jniarraytype data, jlongArray shape, jlong pTensorOptions, jobject) { 11 | auto shapeElems = (env)->GetLongArrayElements(shape, nullptr); 12 | size_t len = (env)->GetArrayLength(shape); 13 | 14 | TensorOptions* opts = *(TensorOptions **)&pTensorOptions; 15 | TensorOptions withDType = opts->dtype(at::k##torchScalarType); 16 | 17 | auto tshape = IntArrayRef((int64_t*)shapeElems, len); 18 | 19 | auto result = torch::empty(tshape, withDType); 20 | 21 | auto tdata = (jnitype*)result.data_ptr(); 22 | (env)->Get##capType##ArrayRegion(data, 0, result.numel(), tdata); 23 | 24 | jlong jresult = 0; 25 | *(torch::Tensor **)&jresult = new torch::Tensor((const torch::Tensor &)result); 26 | (env)->ReleaseLongArrayElements(shape, shapeElems, 0); 27 | return jresult; 28 | } 29 | 30 | SWIGEXPORT jniarraytype JNICALL Java_com_microsoft_scalatorch_torch_internal_torch_1swigJNI_to_1##javatype##_1array(JNIEnv * env, jclass, jlong pT, jobject pT_) { 31 | torch::Tensor* v = *(torch::Tensor **)&pT; 32 | size_t size = v->numel(); 33 | jniarraytype result = env->New##capType##Array(size); 34 | env->Set##capType##ArrayRegion(result, 0, size, (const jnitype*)v->data_ptr()); 35 | return 
result; 36 | } 37 | } 38 | %} 39 | 40 | %enddef 41 | 42 | ARRAY_INPUT_OUTPUT(float, jfloat, float, jfloatArray, Float, Float) 43 | ARRAY_INPUT_OUTPUT(int64_t, jlong, long, jlongArray, Long, Long) 44 | ARRAY_INPUT_OUTPUT(double, jdouble, double, jdoubleArray, Double, Double) 45 | ARRAY_INPUT_OUTPUT(signed char, jbyte, byte, jbyteArray, Byte, Char) 46 | ARRAY_INPUT_OUTPUT(int32_t, jint, int, jintArray, Int, Int) 47 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_array_ref.i: -------------------------------------------------------------------------------- 1 | namespace c10 { 2 | 3 | template 4 | class ArrayRef final { 5 | private: 6 | constexpr ArrayRef(); 7 | }; 8 | } 9 | 10 | 11 | // This makes e.g. IntArrayRef look like long[] to the jvm 12 | %define ARRAY_REF(TYPE, SCALAR) 13 | %typemap(in) TYPE { 14 | auto elems$1 = (SCALAR*)(jenv)->Get##$typemap(jboxtype, SCALAR)##ArrayElements($input, nullptr); 15 | size_t len$1 = (jenv)->GetArrayLength($input); 16 | $1 = TYPE(elems$1, len$1); 17 | } 18 | 19 | %typemap(freearg) TYPE { 20 | (jenv)->Release##$typemap(jboxtype, SCALAR)##ArrayElements($input, ($typemap(jni, SCALAR)*)$1.data(), 0); 21 | } 22 | 23 | %typemap(out) TYPE { 24 | $result = (jenv)->New##$typemap(jboxtype, SCALAR)##Array($1.size()); 25 | (jenv)->Set##$typemap(jboxtype, SCALAR)##ArrayRegion($result, 0, $1.size(), (const $typemap(jni, SCALAR)*)$1.data()); 26 | } 27 | 28 | %typemap(jni) TYPE "$typemap(jni, SCALAR)""Array" 29 | %typemap(jtype) TYPE "$typemap(jtype, SCALAR)[]" 30 | %typemap(jstype) TYPE "$typemap(jtype, SCALAR)[]" 31 | %typemap(javain) TYPE "$javainput" 32 | %typemap(javaout) TYPE { 33 | return $jnicall; 34 | } 35 | 36 | %enddef 37 | 38 | ARRAY_REF(IntArrayRef, int64_t) 39 | ARRAY_REF(ArrayRef, double) 40 | 41 | DEFINE_OPTIONAL(OptDoubleArrayRef, c10::ArrayRef) 42 | DEFINE_OPTIONAL(OptIntArrayRef, IntArrayRef) 43 | 44 | %define ARRAY_REF_OF_OBJECT(ListT, T) 45 | 46 | %template(ListT) c10::ArrayRef< T >; 47 | %naturalvar c10::ArrayRef< T >; 48 | 49 | %typemap(jni) c10::ArrayRef< T > "jlongArray" // Use jlongArray for CPtrs, really these are objects 50 | %typemap(jstype) c10::ArrayRef< T > "$typemap(jboxtype, T)[]" 51 | %typemap(jtype) c10::ArrayRef< T > "long[]" 52 | %typemap(javain) c10::ArrayRef< T > "$javainput" 53 | 54 | %typemap(javain, 55 | pre=" long[] cptrs$javainput = new long[$javainput.length]; for (int i = 0; i < $javainput.length; ++i) { cptrs$javainput[i] = $typemap(jboxtype, T).getCPtr($javainput[i]); }", 56 | //post=" opt$javainput.delete();", 57 | pgcppname="cptrs$javainput") 58 | c10::ArrayRef< T > "cptrs$javainput" 59 | 60 | %typemap(javaout) c10::ArrayRef< T > { 61 | throw new java.lang.IllegalStateException("There should never be a need to return an ArrayRef of objects. 
"); 62 | } 63 | 64 | %typemap(in) c10::ArrayRef< T > { 65 | 66 | size_t len$1 = (jenv)->GetArrayLength($input); 67 | // https://stackoverflow.com/questions/4754763/object-array-initialization-without-default-constructor 68 | void *raw_memory = operator new[](len$1 * sizeof(T)); 69 | T* array$1 = static_cast(raw_memory); 70 | jlong* elems$1 = (jenv)->GetLongArrayElements($input, nullptr); 71 | for (size_t i = 0; i < len$1; ++i) { 72 | new(&array$1[i]) T(*(T*)elems$1[i]); 73 | } 74 | (jenv)->ReleaseLongArrayElements($input, elems$1, 0); 75 | $1 = c10::ArrayRef(array$1, len$1); 76 | } 77 | 78 | %typemap(freearg) c10::ArrayRef< T > { 79 | delete (T*)$1.data(); 80 | } 81 | 82 | %enddef 83 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_data.i: -------------------------------------------------------------------------------- 1 | %javaconst(1) Layout; 2 | enum class Layout : int8_t { Strided, Sparse, Mkldnn }; 3 | DEFINE_OPTIONAL(OptLayout, Layout) 4 | 5 | %javaconst(1) QScheme; 6 | enum class QScheme : uint8_t { 7 | PER_TENSOR_AFFINE = 0, 8 | PER_CHANNEL_AFFINE = 1, 9 | PER_TENSOR_SYMMETRIC = 2, 10 | PER_CHANNEL_SYMMETRIC = 3, 11 | PER_CHANNEL_AFFINE_FLOAT_QPARAMS = 4, 12 | COMPILE_TIME_NUM_QSCHEMES = 5, 13 | }; 14 | 15 | struct TypeMeta { 16 | TypeMeta() = delete; // this is actually available, but we don't need it. 17 | c10::string_view name() const; 18 | ScalarType toScalarType(); 19 | size_t itemsize() const; 20 | 21 | static inline TypeMeta fromScalarType(ScalarType scalar_type); 22 | 23 | }; 24 | EQUALS_FROM_EQ(TypeMeta) 25 | DEFINE_OPTIONAL(OptTypeMeta, TypeMeta) 26 | 27 | // this block defines a bunch of typemetas like kLongMeta to be used with TensorOptions 28 | // cribbed from c10/core/ScalarType.h, but we can't %import it because of swig's limitations 29 | #define AT_FORALL_SCALAR_TYPES(_) \ 30 | _(uint8_t, Byte) \ 31 | _(int8_t, Char) \ 32 | _(int16_t, Short) \ 33 | _(int, Int) \ 34 | _(int64_t, Long) \ 35 | _(float, Float) \ 36 | _(double, Double) 37 | 38 | #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ 39 | _(uint8_t, Byte) /* 0 */ \ 40 | _(int8_t, Char) /* 1 */ \ 41 | _(int16_t, Short) /* 2 */ \ 42 | _(int, Int) /* 3 */ \ 43 | _(int64_t, Long) /* 4 */ \ 44 | _(at::Half, Half) /* 5 */ \ 45 | _(float, Float) /* 6 */ \ 46 | _(double, Double) /* 7 */ \ 47 | _(at::ComplexHalf, ComplexHalf) /* 8 */ \ 48 | _(std::complex, ComplexFloat) /* 9 */ \ 49 | _(std::complex, ComplexDouble) /* 10 */ \ 50 | _(bool, Bool) /* 11 */ \ 51 | _(c10::qint8, QInt8) /* 12 */ \ 52 | _(c10::quint8, QUInt8) /* 13 */ \ 53 | _(c10::qint32, QInt32) /* 14 */ \ 54 | _(at::BFloat16, BFloat16) /* 15 */ 55 | 56 | %inline %{ 57 | #define TYPE_META_FOR(tpe, name) const TypeMeta k##name##Meta = caffe2::TypeMeta::Make(); 58 | 59 | AT_FORALL_SCALAR_TYPES(TYPE_META_FOR) 60 | 61 | #undef TYPE_META_FOR 62 | 63 | %} 64 | 65 | enum class ScalarType : int8_t { 66 | #define DEFINE_ENUM(_1, n) n, 67 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM) 68 | #undef DEFINE_ENUM 69 | Undefined, 70 | NumOptions 71 | }; 72 | DEFINE_OPTIONAL(OptScalarType, ScalarType) 73 | 74 | namespace c10 { 75 | bool isFloatingType(ScalarType t); 76 | bool isSignedType(ScalarType t); 77 | bool isComplexType(ScalarType t); 78 | } 79 | 80 | enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast }; 81 | DEFINE_OPTIONAL(OptMemoryFormat, MemoryFormat) 82 | 83 | %include 84 | %include 85 | namespace c10 { 86 | EQUALS_FROM_EQ(Device) 87 | 
HASHCODE_FROM_STD_HASH(Device) 88 | } 89 | 90 | DEFINE_OPTIONAL(OptDevice, Device) 91 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_equals_hashcode.i: -------------------------------------------------------------------------------- 1 | // make a nice toString method using the ostream stuff they use everywhere 2 | %define TO_STRING_FROM_OSTREAM(T) 3 | %extend T { 4 | std::string toString() const { 5 | std::ostringstream ss; 6 | ss << *($self); 7 | return ss.str(); 8 | } 9 | } 10 | %enddef 11 | 12 | %define EQUALS_FROM_EQ(T) 13 | 14 | %extend T { 15 | bool equalTo(const T& o) const { 16 | return (*$self) == o; 17 | } 18 | 19 | 20 | %proxycode %{ 21 | @Override public boolean equals(Object o) { 22 | if (o instanceof $javaclassname) { 23 | return equalTo(($javaclassname)o); 24 | } else { 25 | return false; 26 | } 27 | } 28 | %} 29 | } 30 | %enddef 31 | 32 | // Unlike other macros in this file, you should use this one inside the class declaration 33 | %define EQUALS_AND_HASH_CODE_FROM_PTR_EQUALITY(T) 34 | 35 | %proxycode %{ 36 | @Override public boolean equals(Object obj) { 37 | boolean equal = false; 38 | if (obj instanceof $javaclassname) { 39 | return ((($javaclassname)obj).swigCPtr == this.swigCPtr); 40 | } else { 41 | return false; 42 | } 43 | } 44 | @Override public int hashCode() { 45 | return (int)this.swigCPtr; 46 | } 47 | %} 48 | %enddef 49 | 50 | %define HASHCODE_FROM_STD_HASH(T) 51 | 52 | %extend T { 53 | size_t hash() const { 54 | return std::hash()(*($self)); 55 | } 56 | 57 | 58 | %proxycode %{ 59 | @Override public int hashCode() { 60 | return (int)hash(); 61 | } 62 | %} 63 | } 64 | %enddef 65 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_generator_swig.i: -------------------------------------------------------------------------------- 1 | %{ 2 | #include 3 | %} 4 | 5 | %ignore at::Generator::mutex_; 6 | 7 | struct TORCH_API Generator { 8 | Generator(); 9 | 10 | 11 | Device device() const; 12 | torch::Tensor get_state() const; 13 | void set_state(const torch::Tensor& new_state); 14 | int64_t seed(); 15 | 16 | %extend { 17 | void manual_seed(int64_t seed) { 18 | // See Note [Acquire lock when using random generators] 19 | std::lock_guard lock($self->mutex()); 20 | $self->set_current_seed(seed); 21 | } 22 | 23 | int64_t initial_seed() const { 24 | return $self->current_seed(); 25 | } 26 | } 27 | }; 28 | 29 | %extend Generator { 30 | explicit Generator(c10::Device device) { 31 | if (device.type() == at::kCPU) { 32 | return new Generator(c10::make_intrusive(device.index())); 33 | #ifdef USE_CUDA 34 | } else if (device.type() == at::kCUDA) { 35 | return new Generator(c10::make_intrusive(device.index())); 36 | #endif 37 | } else { 38 | AT_ERROR("Device type ", c10::DeviceTypeName(device.type()), 39 | " is not supported for torch.Generator() api."); 40 | } 41 | } 42 | } 43 | 44 | DEFINE_OPTIONAL(OptGenerator, Generator) 45 | 46 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_indexing.i: -------------------------------------------------------------------------------- 1 | namespace at { 2 | namespace indexing { 3 | struct TORCH_API EllipsisIndexType final { EllipsisIndexType(); }; 4 | 5 | struct TORCH_API Slice final { 6 | Slice( 7 | c10::optional start_index = c10::nullopt, 8 | c10::optional stop_index = c10::nullopt, 9 | c10::optional step_index = c10::nullopt); 10 | }; 11 | 12 | struct TORCH_API 
TensorIndex final { 13 | 14 | // Case 1: `at::indexing::None` 15 | //TensorIndex(c10::nullopt_t); 16 | 17 | // Case 2: "..." / `at::indexing::Ellipsis` 18 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.UninitializedObject) 19 | TensorIndex(EllipsisIndexType); 20 | 21 | // Case 3: Integer value 22 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.UninitializedObject) 23 | TensorIndex(int64_t integer); 24 | 25 | // Case 4: Boolean value 26 | TensorIndex(bool boolean); 27 | 28 | // Case 5: Slice represented in `at::indexing::Slice` form 29 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.UninitializedObject) 30 | TensorIndex(Slice slice); 31 | 32 | // Case 6: Tensor value 33 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.UninitializedObject) 34 | TensorIndex(Tensor tensor); 35 | }; 36 | } 37 | } 38 | 39 | ARRAY_REF_OF_OBJECT(TensorIndexArrayRef, at::indexing::TensorIndex) 40 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_init.i: -------------------------------------------------------------------------------- 1 | %javaconst(1) FanMode; 2 | %inline %{ 3 | namespace FanMode { 4 | enum FanMode { 5 | FanIn, 6 | FanOut 7 | }; 8 | } 9 | %} 10 | 11 | typedef c10::variant FanModeType; 12 | 13 | %{ 14 | FanModeType to_fan_mode_t(unsigned r) { 15 | switch(r) { 16 | case 0: return torch::enumtype::kFanIn(); break; 17 | case 1: return torch::enumtype::kFanOut(); break; 18 | default: throw std::invalid_argument("Bad argument for FanMode"); 19 | } 20 | } 21 | 22 | %} 23 | 24 | VARIANT_ENUM(FanModeType, FanMode, to_fan_mode_t) 25 | 26 | 27 | %javaconst(1) Nonlinearity; 28 | %inline %{ 29 | namespace Nonlinearity { 30 | enum Nonlinearity { 31 | Linear, 32 | Conv1D, 33 | Conv2D, 34 | Conv3D, 35 | ConvTranspose1D, 36 | ConvTranspose2D, 37 | ConvTranspose3D, 38 | Sigmoid, 39 | Tanh, 40 | ReLU, 41 | LeakyReLU 42 | }; 43 | } 44 | %} 45 | 46 | typedef c10::variant NonlinearityType; 57 | 58 | %{ 59 | NonlinearityType to_nonlinearity_t(unsigned r) { 60 | switch(r) { 61 | case 0: return torch::enumtype::kLinear(); break; 62 | case 1: return torch::enumtype::kConv1D(); break; 63 | case 2: return torch::enumtype::kConv2D(); break; 64 | case 3: return torch::enumtype::kConv3D(); break; 65 | case 4: return torch::enumtype::kConvTranspose1D(); break; 66 | case 5: return torch::enumtype::kConvTranspose2D(); break; 67 | case 6: return torch::enumtype::kConvTranspose3D(); break; 68 | case 7: return torch::enumtype::kSigmoid(); break; 69 | case 8: return torch::enumtype::kTanh(); break; 70 | case 9: return torch::enumtype::kReLU(); break; 71 | case 10: return torch::enumtype::kLeakyReLU(); break; 72 | default: throw std::invalid_argument("Bad argument for Nonlinearity"); 73 | } 74 | } 75 | 76 | %} 77 | 78 | VARIANT_ENUM(NonlinearityType, Nonlinearity, to_nonlinearity_t) 79 | 80 | 81 | 82 | namespace torch { 83 | namespace nn { 84 | namespace init { 85 | /// Return the recommended gain value for the given nonlinearity function. 
86 | TORCH_API double calculate_gain(NonlinearityType nonlinearity, double param = 0.01); 87 | 88 | TORCH_API void constant_(torch::Tensor tensor, torch::Scalar value); 89 | 90 | TORCH_API void dirac_(torch::Tensor tensor); 91 | 92 | TORCH_API void eye_(torch::Tensor matrix); 93 | 94 | TORCH_API void normal_(torch::Tensor tensor, double mean = 0, double std = 1); 95 | 96 | TORCH_API void ones_(torch::Tensor tensor); 97 | 98 | TORCH_API void orthogonal_(torch::Tensor tensor, double gain = 1.0); 99 | 100 | TORCH_API void sparse_(torch::Tensor tensor, double sparsity, double std = 0.01); 101 | 102 | TORCH_API void uniform_(torch::Tensor tensor, double low = 0, double high = 1); 103 | 104 | TORCH_API void kaiming_normal_( 105 | torch::Tensor tensor, 106 | double a = 0, 107 | FanModeType mode = torch::kFanIn, 108 | NonlinearityType nonlinearity = torch::kLeakyReLU); 109 | 110 | TORCH_API void kaiming_uniform_( 111 | torch::Tensor tensor, 112 | double a = 0, 113 | FanModeType mode = torch::kFanIn, 114 | NonlinearityType nonlinearity = torch::kLeakyReLU); 115 | 116 | TORCH_API void xavier_normal_(torch::Tensor tensor, double gain = 1.0); 117 | 118 | TORCH_API void xavier_uniform_(torch::Tensor tensor, double gain = 1.0); 119 | 120 | TORCH_API void zeros_(torch::Tensor tensor); 121 | } // namespace init 122 | } // namespace nn 123 | } // namespace torch 124 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_ir.i: -------------------------------------------------------------------------------- 1 | %{ 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | using torch::jit::SourceRange; 8 | using torch::jit::Source; 9 | using torch::jit::Graph; 10 | using torch::jit::Node; 11 | using torch::jit::Value; 12 | using torch::jit::NamedValue; 13 | %} 14 | 15 | DEFINE_OPTIONAL(OptIValue, torch::jit::IValue) 16 | 17 | %shared_ptr(torch::jit::Source) 18 | 19 | namespace std { 20 | %template(NodeVector) vector; 21 | %template(JitValueVector) vector; 22 | } // namepsace std { 23 | 24 | namespace torch { 25 | namespace jit { 26 | 27 | struct Source { 28 | Source( 29 | std::string text, 30 | c10::optional filename, 31 | size_t starting_line_no 32 | ); 33 | 34 | private: 35 | Source(); 36 | }; 37 | 38 | struct SourceRange { 39 | SourceRange( 40 | std::shared_ptr source, 41 | size_t start, 42 | size_t end); 43 | 44 | private: 45 | SourceRange(); 46 | }; 47 | 48 | struct Value { 49 | 50 | EQUALS_AND_HASH_CODE_FROM_PTR_EQUALITY(Value) 51 | 52 | torch::jit::Node* node(); 53 | TORCH_API void replaceAllUsesWith(torch::jit::Value* newValue); 54 | Value* setType(std::shared_ptr tpe); 55 | const std::shared_ptr& type() const; 56 | std::string debugName() const; 57 | %extend { 58 | 59 | c10::optional maybeIValue() { 60 | return torch::jit::toIValue($self); 61 | } 62 | 63 | std::string toString() const { 64 | return $self->debugName(); 65 | } 66 | 67 | std::vector uses() { 68 | std::vector ret; 69 | for (auto i: $self->uses()) { 70 | ret.push_back(i.user); 71 | } 72 | return ret; 73 | } 74 | } 75 | private: 76 | Value(); 77 | }; 78 | 79 | struct NamedValue { 80 | NamedValue(Value* value); 81 | NamedValue(const SourceRange& loc, Value* value); 82 | 83 | const std::string& name() const; 84 | 85 | %extend { 86 | Value* value(std::shared_ptr g) const { 87 | return $self->value(*g); 88 | } 89 | } 90 | 91 | const SourceRange& loc() const; 92 | }; 93 | 94 | struct TORCH_API Operator { 95 | %extend { 96 | Symbol op() const { 97 | return 
Symbol::fromQualString($self->schema().name()); 98 | } 99 | } 100 | private: 101 | Operator(); 102 | 103 | }; 104 | 105 | struct TORCH_API Node { 106 | EQUALS_AND_HASH_CODE_FROM_PTR_EQUALITY(Value) 107 | SourceRange sourceRange(); const 108 | void setSourceRange(SourceRange r); 109 | torch::jit::Value* output(); 110 | TORCH_API void replaceInputWith(torch::jit::Value* from, torch::jit::Value* to); 111 | const Operator* maybeOperator() const; 112 | 113 | void moveBefore(Node* n); 114 | void moveAfter(Node* n); 115 | 116 | bool isBefore(const Node* n) const; 117 | 118 | bool isAfter(const Node* n) const; 119 | 120 | // Declaration has NodeKind but it's a typedef 121 | Symbol kind() const; 122 | 123 | %extend { 124 | 125 | size_t numOutputs() { 126 | return $self->outputs().size(); 127 | } 128 | 129 | std::vector outputs() { 130 | return std::vector($self->outputs().begin(), $self->outputs().end()); 131 | } 132 | 133 | std::vector inputs() { 134 | return std::vector($self->inputs().begin(), $self->inputs().end()); 135 | } 136 | } 137 | 138 | private: 139 | Node(); 140 | }; 141 | 142 | TO_STRING_FROM_OSTREAM(Node); 143 | 144 | struct Graph { 145 | Graph(); 146 | Value* addInput(std::string name = ""); 147 | TORCH_API const std::string toString(bool print_source_locations = true) const; 148 | TORCH_API Value* insertConstant( 149 | const IValue& val, 150 | c10::optional loc = c10::nullopt); 151 | TORCH_API Value* insertGetAttr(Value* obj, const std::string& field); 152 | 153 | %extend { 154 | std::vector inputs() { 155 | return std::vector($self->inputs().begin(), $self->inputs().end()); 156 | } 157 | 158 | std::vector nodes() { 159 | return std::vector($self-> nodes().begin(), $self->nodes().end()); 160 | } 161 | 162 | std::vector outputs() { 163 | return std::vector($self->outputs().begin(), $self->outputs().end()); 164 | } 165 | 166 | std::vector insertGraph( 167 | std::shared_ptr callee, 168 | const std::vector& inputs 169 | ) { 170 | return torch::jit::insertGraph(*($self), *callee, ArrayRef(inputs)); 171 | } 172 | 173 | Node* insertConstantChunk(Value* v, size_t size, int64_t dim) { 174 | auto* newNode = $self->create(prim::ConstantChunk, {v}, size); 175 | newNode->i_(attr::chunks, size); 176 | newNode->i_(attr::dim, dim); 177 | $self->insertNode(newNode); 178 | return newNode; 179 | } 180 | 181 | Value* insertObject(std::shared_ptr type) { 182 | return $self->createObject(type)->output(); 183 | } 184 | 185 | Value* insertList(const std::shared_ptr& elem_type, const std::vector& inputs) { 186 | auto* created = $self->createList(elem_type, ArrayRef(inputs)); 187 | return $self->insertNode(created)->output(); 188 | } 189 | 190 | // This is the same as torch::jit::Graph::insertMethodCall, but avoids 191 | // taking a MatchedSchema and takes arguments and a return type directly. 
192 | Value* insertMethodCall( 193 | std::string method_name, 194 | const std::vector& arguments, 195 | const std::shared_ptr returnType 196 | ) { 197 | Value* result = $self->insertNode($self->create(prim::CallMethod, arguments)) 198 | ->s_(attr::name, std::move(method_name)) 199 | ->output() 200 | ->setType(returnType); 201 | return result; 202 | } 203 | } 204 | 205 | size_t registerOutput(Value* n); 206 | 207 | %extend { 208 | 209 | Value* insert( 210 | Symbol opname, 211 | const std::vector& args, 212 | const SourceRange& range) { 213 | return $self->insert(opname, args, {}, range); 214 | } 215 | } 216 | 217 | }; 218 | 219 | } // jit 220 | } // torch 221 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_jit_type.i: -------------------------------------------------------------------------------- 1 | %{ 2 | #include 3 | #include 4 | using torch::jit::CompilationUnit; 5 | %} 6 | 7 | namespace c10 { 8 | struct QualifiedName { 9 | QualifiedName(const std::string& name); 10 | explicit QualifiedName(std::vector atoms); 11 | explicit QualifiedName(const QualifiedName& prefix, std::string name); 12 | const std::string& qualifiedName() const; 13 | const std::string& name() const; 14 | const std::string& prefix(); 15 | }; 16 | EQUALS_FROM_EQ(QualifiedName) 17 | HASHCODE_FROM_STD_HASH(QualifiedName) 18 | } 19 | 20 | 21 | DEFINE_OPTIONAL(OptQualifiedName, c10::QualifiedName) 22 | DEFINE_OPTIONAL(OptSizes, std::vector) 23 | 24 | namespace std { 25 | %template(FunctionVector) vector; 26 | } // namepsace std { 27 | 28 | 29 | namespace torch { 30 | namespace jit { 31 | 32 | struct CompilationUnit { 33 | 34 | torch::jit::Function* create_function( 35 | c10::QualifiedName name, 36 | std::shared_ptr graph, 37 | bool shouldMangle = false); 38 | 39 | %extend { 40 | 41 | IValue run_method(const c10::QualifiedName& method_name, const std::vector& args) { 42 | return $self->get_function(method_name)(args); 43 | } 44 | 45 | IValue run_method(const std::string& method_name, const std::vector& args) { 46 | return $self->get_function(method_name)(args); 47 | } 48 | } 49 | }; 50 | } // namespace jit 51 | } // namespace torch 52 | 53 | namespace c10 { 54 | enum class TypeKind { 55 | AnyType, 56 | TensorType, 57 | TupleType, 58 | ListType, 59 | DictType, 60 | NumberType, 61 | FloatType, 62 | FutureType, 63 | IntType, 64 | NoneType, 65 | StringType, 66 | GeneratorType, 67 | BoolType, 68 | OptionalType, 69 | VarType, 70 | DeviceObjType, 71 | FunctionType, 72 | ClassType, 73 | CapsuleType, 74 | InterfaceType 75 | }; 76 | 77 | struct CAFFE2_API Type { 78 | TypeKind kind() const; 79 | 80 | %extend { 81 | 82 | std::shared_ptr expectTensor() { 83 | return $self->expect(); 84 | } 85 | 86 | // TODO out of laziness we don't expose any C++ types other than Type and TupleType. 87 | // We could expose them all. 
88 | 89 | // TODO more types 90 | static std::shared_ptr createDict(std::shared_ptr keyType, std::shared_ptr valueType) { 91 | return DictType::create(keyType, valueType); 92 | } 93 | 94 | static std::shared_ptr createList(std::shared_ptr elementType) { 95 | return ListType::create(elementType); 96 | } 97 | 98 | static std::shared_ptr getString() { 99 | return StringType::get(); 100 | } 101 | 102 | static std::shared_ptr getFloat() { 103 | return FloatType::get(); 104 | } 105 | 106 | static std::shared_ptr getInt() { 107 | return IntType::get(); 108 | } 109 | 110 | static std::shared_ptr getBool() { 111 | return BoolType::get(); 112 | } 113 | } 114 | private: 115 | Type(); 116 | }; 117 | EQUALS_FROM_EQ(Type) 118 | 119 | TO_STRING_FROM_OSTREAM(Type); 120 | 121 | // TupleType actually inherits from NamedType (which inherits from Type) but swig is happy with this declaration. 122 | struct CAFFE2_API TupleType: public Type { 123 | static std::shared_ptr create(const std::vector>& types); 124 | }; 125 | 126 | // ClassType actually inherits from NamedType (which inherits from Type) but swig is happy with this declaration. 127 | struct CAFFE2_API ClassType: public Type { 128 | 129 | const c10::optional& name() const; 130 | 131 | const std::vector& methods() const; 132 | 133 | size_t addAttribute(const std::string& name, const std::shared_ptr& type, bool is_parameter = false); 134 | }; 135 | 136 | struct CAFFE2_API TensorType: public Type { 137 | // Dim/device/type unspecified 138 | static std::shared_ptr get(); 139 | c10::optional device(); 140 | %extend { 141 | 142 | static std::shared_ptr createContiguous(TypeMeta typeMeta, DeviceType deviceType, IntArrayRef dim) { 143 | return TensorType::createContiguous(c10::typeMetaToScalarType(typeMeta), deviceType, dim); 144 | } 145 | 146 | // It seems that ScalarType is quasi-deprecated (https://pytorch.org/cppdocs/notes/tensor_creation.html), 147 | // so convert to a type meta here. 148 | c10::optional dtype() const { 149 | if ($self->scalarType()) { 150 | return c10::scalarTypeToTypeMeta(*($self->scalarType())); 151 | } else { 152 | return nullopt; 153 | } 154 | } 155 | 156 | c10::optional> sizes() const { 157 | return $self->sizes().concrete_sizes(); 158 | } 159 | } 160 | }; 161 | 162 | } // namespace c10 163 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_list.h: -------------------------------------------------------------------------------- 1 | %{ 2 | #include 3 | %} 4 | 5 | namespace c10 { 6 | template 7 | class List final { 8 | public: 9 | //List(std::shared_ptr elementType); 10 | void push_back(const T& value) const; 11 | void reserve(size_t new_cap) const; 12 | %extend{ 13 | 14 | // This is present in master but not in the version (1.4) that we're pinned to. 
15 | std::vector vec() const { 16 | std::vector result(($self)->begin(), ($self)->end()); 17 | return result; 18 | } 19 | } 20 | 21 | private: 22 | explicit List(); 23 | }; 24 | } 25 | 26 | %define DEFINE_LIST_OF_OPTIONAL(ListT, T) 27 | 28 | %template(ListT) c10::List< c10::optional< T > >; 29 | %naturalvar c10::List< T >; 30 | 31 | %typemap(jni) c10::List< c10::optional< T > >, const c10::List< c10::optional< T > > & "jlongArray" // Use jlongArray for CPtrs, really these are objects 32 | %typemap(jstype) c10::List< c10::optional< T > >, const c10::List< c10::optional< T > > & "scala.Option<$typemap(jboxtype, T)>[]" 33 | %typemap(jtype) c10::List< c10::optional< T > >, const c10::List< c10::optional< T > > & "long[]" 34 | %typemap(javain) c10::List< c10::optional< T > >, const c10::List< c10::optional< T > > & "$javainput" 35 | 36 | // Here, we pass in an array of long representing pointers to native objects. We use -1 to indicate an optional 37 | // value not present. I am pretty sure that pointers must be positive and so -1 is safe, but we're in 38 | // for a vanishingly rare but really bad time if that ends up not being true. 39 | %typemap(javain, 40 | pre=" long[] cptrs$javainput = new long[$javainput.length]; for (int i = 0; i < $javainput.length; ++i) { if ($javainput[i].isEmpty()) { cptrs$javainput[i] = -1; } else { cptrs$javainput[i] = $typemap(jboxtype, T).getCPtr($javainput[i].get()); } }", 41 | pgcppname="cptrs$javainput") 42 | c10::List< c10::optional< T > >, const c10::List< c10::optional< T > > & "cptrs$javainput" 43 | 44 | %typemap(in) c10::List< c10::optional< T > >, const c10::List< c10::optional< T > > & { 45 | 46 | $1 = new c10::List< c10::optional< T > >(); 47 | auto elems$1 = (jenv)->GetLongArrayElements($input, nullptr); 48 | size_t len$1 = (jenv)->GetArrayLength($input); 49 | $1->reserve(len$1); 50 | for (size_t i = 0; i < len$1; ++i) { 51 | if (elems$1[i] == -1) { 52 | $1->push_back(c10::nullopt); 53 | } else { 54 | $1->push_back(*(T*)elems$1[i]); 55 | } 56 | } 57 | (jenv)->ReleaseLongArrayElements($input, elems$1, 0); 58 | } 59 | 60 | %typemap(freearg) c10::List< c10::optional< T > >, const c10::List< c10::optional< T > > & { 61 | delete $1; 62 | } 63 | 64 | %enddef 65 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_optim_swig.i: -------------------------------------------------------------------------------- 1 | // optimization include for torch_swig 2 | 3 | %{ 4 | using namespace torch::nn; 5 | #include 6 | %} 7 | 8 | namespace torch { 9 | namespace nn { 10 | namespace utils { 11 | inline double clip_grad_norm_( 12 | std::vector parameters, 13 | double max_norm, 14 | double norm_type = 2.0); 15 | } // namespace utils 16 | } // namespace nn 17 | } // namespace torch 18 | 19 | namespace torch { 20 | namespace optim { 21 | 22 | class TORCH_API OptimizerOptions { 23 | }; 24 | 25 | // TODO handle param groups better. 26 | // A note about options: Torch has switched to using parameter groups, each with their own options 27 | // instead of having one set of parameters per Optimizer. There is backwards compatibility of sorts 28 | // for the one parameter setup: there is still a constructor for each optimizer that takes a single param 29 | // list and uses the default options. So, each optimizer implementation exposes its options 30 | // by calling through to the defaults.
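// (For orientation: the Scala wrapper in this repo goes through that single-param-list constructor.
// The optimizer tutorial earlier in this dump drives it as, in Scala:
//   val optim = SGD(parameters, SGD.Options(learningRate = 1e-6))
//   optim.zeroGrad(); loss.backward(); optim.step()
// Illustrative cross-reference only; the Scala-side surface is defined in the scala-torch module.)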
31 | 32 | class TORCH_API OptimizerParamGroup { 33 | public: 34 | OptimizerParamGroup(const OptimizerParamGroup& param_group); 35 | OptimizerParamGroup(std::vector params); 36 | 37 | bool has_options() const; 38 | OptimizerOptions& options(); 39 | std::vector& params(); 40 | }; 41 | 42 | struct Optimizer { 43 | Optimizer() = delete; 44 | virtual void step(); 45 | virtual void zero_grad(); 46 | // std::vector doesn't work out of the box, so we'll hack around it 47 | //const std::vector& param_groups() const noexcept; 48 | %extend { 49 | int num_param_groups() const { 50 | return $self->param_groups().size(); 51 | } 52 | OptimizerParamGroup& param_group(int i) { 53 | return $self->param_groups()[i]; 54 | } 55 | } 56 | 57 | %extend { 58 | // Same as Optimizer::parameters, which is going to be deleted in 1.6 59 | // Remove if we handel param groups. 60 | const std::vector& all_parameters() const noexcept { 61 | return $self->param_groups().at(0).params(); 62 | } 63 | } 64 | }; 65 | 66 | 67 | struct SGDOptions : public OptimizerOptions { 68 | /* implicit */ SGDOptions(double lr); 69 | TORCH_ARG(SGDOptions, double, lr); 70 | TORCH_ARG(SGDOptions, double, momentum); 71 | TORCH_ARG(SGDOptions, double, dampening); 72 | TORCH_ARG(SGDOptions, double, weight_decay); 73 | TORCH_ARG(SGDOptions, bool, nesterov); 74 | %extend { 75 | // TODO make a macro or find some otherway of generalizing this casting 76 | static SGDOptions& cast(OptimizerOptions& opts) { 77 | return static_cast(opts); 78 | } 79 | } 80 | }; 81 | 82 | struct SGD: Optimizer { 83 | SGD(std::vector parameters, const SGDOptions& options_); 84 | %extend { 85 | SGDOptions& getOptions() { return static_cast($self->defaults()); } 86 | } 87 | }; 88 | 89 | struct AdagradOptions : public OptimizerOptions { 90 | /* implicit */ AdagradOptions(double lr); 91 | TORCH_ARG(AdagradOptions, double, lr); 92 | TORCH_ARG(AdagradOptions, double, lr_decay); 93 | TORCH_ARG(AdagradOptions, double, weight_decay); 94 | %extend { 95 | // TODO make a macro or find some otherway of generalizing this casting 96 | static AdagradOptions& cast(OptimizerOptions& opts) { 97 | return static_cast(opts); 98 | } 99 | } 100 | }; 101 | 102 | struct Adagrad: Optimizer { 103 | Adagrad(std::vector parameters, const AdagradOptions& options_); 104 | %extend { 105 | AdagradOptions& getOptions() { return static_cast($self->defaults()); } 106 | } 107 | }; 108 | 109 | struct AdamOptions : public OptimizerOptions { 110 | /* implicit */ AdamOptions(double lr); 111 | TORCH_ARG(AdamOptions, double, lr); 112 | %extend { 113 | double beta1() const { return std::get<0>($self->betas()); } // 0.9 114 | AdamOptions beta1(double v) { return $self->betas(std::make_tuple(v, std::get<1>($self->betas()))); } 115 | double beta2() const { return std::get<1>($self->betas()); } // 0.999 116 | AdamOptions beta2(double v) { return $self->betas(std::make_tuple(std::get<0>($self->betas()), v)); } 117 | } 118 | TORCH_ARG(AdamOptions, double, weight_decay); // 0 119 | TORCH_ARG(AdamOptions, double, eps); // 1E-8 120 | TORCH_ARG(AdamOptions, bool, amsgrad) // false; 121 | %extend { 122 | // TODO make a macro or find some otherway of generalizing this casting 123 | static AdamOptions& cast(OptimizerOptions& opts) { 124 | return static_cast(opts); 125 | } 126 | } 127 | }; 128 | 129 | struct Adam: Optimizer { 130 | Adam(std::vector parameters, const AdamOptions& options_); 131 | %extend { 132 | AdamOptions& getOptions() { 133 | return static_cast($self->defaults()); 134 | } 135 | } 136 | }; 137 | 138 | struct 
RMSpropOptions : public OptimizerOptions { 139 | /* implicit */ RMSpropOptions(double lr); 140 | TORCH_ARG(RMSpropOptions, double, lr); 141 | TORCH_ARG(RMSpropOptions, double, alpha); // 0.99 142 | TORCH_ARG(RMSpropOptions, double, eps); // 1E-8 143 | TORCH_ARG(RMSpropOptions, double, weight_decay); // 0 144 | TORCH_ARG(RMSpropOptions, double, momentum); // 0 145 | TORCH_ARG(RMSpropOptions, bool, centered) // false; 146 | %extend { 147 | // TODO make a macro or find some otherway of generalizing this casting 148 | static RMSpropOptions& cast(OptimizerOptions& opts) { 149 | return static_cast(opts); 150 | } 151 | } 152 | }; 153 | 154 | struct RMSprop: Optimizer { 155 | RMSprop(std::vector parameters, const RMSpropOptions& options_); 156 | %extend { 157 | RMSpropOptions& getOptions() { 158 | return static_cast($self->defaults()); 159 | } 160 | } 161 | }; 162 | 163 | } // namespace optim 164 | } // namespace torch 165 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_optional.i: -------------------------------------------------------------------------------- 1 | // based on https://gist.github.com/vadz/7ba2bfb73d04483e2f6254b05feb7e1f 2 | 3 | // Typemap that makes the c10::optional type in swig appear as scala.Option 4 | 5 | namespace c10 { 6 | template 7 | struct optional { 8 | optional(); 9 | optional(T value); 10 | bool has_value() const noexcept; 11 | T value(); 12 | }; 13 | } 14 | 15 | %define DEFINE_OPTIONAL(OptT, T) 16 | 17 | // Use reference, not pointer, typemaps for member variables of this type. 18 | %naturalvar c10::optional< T >; 19 | 20 | %template(OptT) c10::optional< T >; 21 | 22 | // Note the use of "jboxtype" instead of just "jstype": for primitive types, 23 | // such as "double", they're different and we must use the former as 24 | // Optional can only be used with reference types in Java. 25 | %typemap(jstype) c10::optional< T >, const c10::optional< T >& "scala.Option<$typemap(jboxtype, T)>" 26 | 27 | // This typemap is used for function arguments of this type. 28 | %typemap(javain, 29 | pre=" OptT opt$javainput = $javainput.isDefined() ? new OptT($javainput.get()) : new OptT();", 30 | post=" opt$javainput.delete();", 31 | pgcppname="opt$javainput") 32 | c10::optional< T >, const c10::optional< T >& "$javaclassname.getCPtr(opt$javainput)" 33 | 34 | // This typemap is for functions returning objects of this type. 35 | %typemap(javaout) c10::optional< T >, const c10::optional< T >& { 36 | OptT optValue = new OptT($jnicall, $owner); 37 | if (optValue.has_value()) { 38 | scala.Option<$typemap(jboxtype, T)> someValue = new scala.Some<$typemap(jboxtype, T)>(optValue.value()); 39 | optValue.delete(); 40 | return someValue; 41 | } else { 42 | return scala.Option.apply(null); 43 | } 44 | } 45 | 46 | %enddef 47 | 48 | %define PRIMITIVE_OPTIONAL(OptT, T) 49 | 50 | // Use reference, not pointer, typemaps for member variables of this type. 51 | %naturalvar c10::optional< T >; 52 | 53 | %template(OptT) c10::optional< T >; 54 | 55 | // Note the use of "jboxtype" instead of just "jstype": for primitive types, 56 | // such as "double", they're different and we must use the former as 57 | // Optional can only be used with reference types in Java. 58 | %typemap(jstype) c10::optional< T >, const c10::optional< T >& "java.util.Optional""$typemap(jboxtype, T)" 59 | 60 | // This typemap is used for function arguments of this type. 61 | %typemap(javain, 62 | pre= " OptT opt$javainput = $javainput.isPresent() ? 
new OptT($javainput.getAs$typemap(jboxtype, T)()) : new OptT();", 63 | post=" opt$javainput.delete();", 64 | pgcppname="opt$javainput") 65 | c10::optional< T >, const c10::optional< T >& "$javaclassname.getCPtr(opt$javainput)" 66 | 67 | // This typemap is for functions returning objects of this type. 68 | %typemap(javaout) c10::optional< T >, const c10::optional< T >& { 69 | OptT optValue = new OptT($jnicall, $owner); 70 | if (optValue.has_value()) { 71 | java.util.Optional$typemap(jboxtype, T) someValue = java.util.Optional$typemap(jboxtype, T).of(optValue.value()); 72 | optValue.delete(); 73 | return someValue; 74 | } else { 75 | return java.util.Optional$typemap(jboxtype, T).empty(); 76 | } 77 | } 78 | 79 | %enddef 80 | 81 | DEFINE_OPTIONAL(OptBool, bool) 82 | DEFINE_OPTIONAL(OptFloat, float) 83 | PRIMITIVE_OPTIONAL(OptDouble, double) 84 | PRIMITIVE_OPTIONAL(OptLong, int64_t) 85 | PRIMITIVE_OPTIONAL(OptInt, int32_t) 86 | DEFINE_OPTIONAL(OptString, std::string) 87 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_primitives.i: -------------------------------------------------------------------------------- 1 | // this is garbage under linux gcc. In particular, it decides that int64_t is long, when it's 2 | // actually long long and gcc complains about various coercions. 3 | // see https://github.com/swig/swig/issues/568 which is closed but it's still busted. 4 | // So rather than use it, we use enough typemaps (below) to make swig happy 5 | //%include 6 | 7 | typedef jint int32_t; 8 | typedef jshort int16_t; 9 | 10 | %typemap(jboxtype) int64_t, const int64_t & "Long" 11 | %typemap(jni) int64_t, const int64_t & "jlong" 12 | %typemap(javaout) int64_t, const int64_t & { return $jnicall; } 13 | %typemap(javain) int64_t, const int64_t & "$javainput" 14 | %typemap(jstype) int64_t, const int64_t & "long" 15 | %typemap(jtype) int64_t, const int64_t & "long" 16 | 17 | %typemap(in) int64_t %{$1 = ($1_ltype)$input; %} 18 | %typemap(out) int64_t %{ $result = (jlong)$1; %} 19 | // Reference types get treated like pointers, so we need to take addresses and dereference. 20 | %typemap(in) const int64_t & %{$1 = ($1_ltype)(&$input); %} 21 | %typemap(out) const int64_t & %{ $result = (jlong)(*$1); %} 22 | 23 | %typemap(jboxtype) int8_t, const int8_t & "Byte" 24 | %typemap(jni) int8_t, const int8_t & "jbyte" 25 | %typemap(javaout) int8_t, const int8_t & { return $jnicall; } 26 | %typemap(javain) int8_t, const int8_t & "$javainput" 27 | %typemap(jstype) int8_t, const int8_t & "byte" 28 | %typemap(jtype) int8_t, const int8_t & "byte" 29 | 30 | %typemap(in) int8_t %{$1 = ($1_ltype)$input; %} 31 | %typemap(out) int8_t %{ $result = (jbyte)$1; %} 32 | // Reference types get treated like pointers, so we need to take addresses and dereference. 
33 | %typemap(in) const int8_t & %{$1 = ($1_ltype)(&$input); %} 34 | %typemap(out) const int8_t & %{ $result = (jbyte)(*$1); %} 35 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_profiler.i: -------------------------------------------------------------------------------- 1 | #include 2 | namespace torch { 3 | namespace autograd { 4 | namespace profiler { 5 | 6 | struct TORCH_API RecordProfile { 7 | RecordProfile(const std::string& filename); 8 | 9 | ~RecordProfile(); 10 | }; 11 | } // namespace profiler 12 | } // namespace autograd 13 | } // namespace torch 14 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_reduction.i: -------------------------------------------------------------------------------- 1 | %javaconst(1) Reduction; 2 | namespace Reduction { 3 | enum Reduction { 4 | None, // Do not reduce 5 | Mean, // (Possibly weighted) mean of losses 6 | Sum//, // Sum losses 7 | //END 8 | }; 9 | } 10 | 11 | // torch is moving to this type, which is hard to map to Java, so we use the above enum in Java land and map it back 12 | %naturalvar default_reduction_t; 13 | %inline %{ 14 | typedef c10::variant default_reduction_t; 15 | %} 16 | 17 | 18 | %{ 19 | default_reduction_t to_default_reduction_t(unsigned r) { 20 | switch(r) { 21 | case Reduction::None: return torch::kNone; break; 22 | case Reduction::Mean: return torch::kMean; break; 23 | case Reduction::Sum: return torch::kSum; break; 24 | default: throw std::invalid_argument("Bad argument for Reduction"); 25 | } 26 | } 27 | 28 | %} 29 | 30 | VARIANT_ENUM(default_reduction_t, Reduction, to_default_reduction_t) 31 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_scalar.i: -------------------------------------------------------------------------------- 1 | namespace torch { 2 | 3 | struct Scalar { 4 | Scalar(float value); 5 | Scalar(double value); 6 | Scalar(int value); 7 | Scalar(short value); 8 | Scalar(int64_t value); 9 | Scalar(bool value); 10 | 11 | template T to() const; 12 | %template(toFloat) to; 13 | %template(toDouble) to; 14 | %template(toInt) to; 15 | %template(toLong) to; 16 | %template(toBoolean) to; 17 | 18 | ScalarType type() const; 19 | %rename(unary_minus) operator-; 20 | Scalar operator-() const; 21 | Scalar conj() const; 22 | Scalar log() const; 23 | 24 | bool isFloatingPoint() const; 25 | bool isIntegral(bool includeBool) const; 26 | bool isComplex() const; 27 | bool isBoolean() const; 28 | }; 29 | TO_STRING_FROM_OSTREAM(Scalar); 30 | } 31 | 32 | ARRAY_REF_OF_OBJECT(ScalarList, torch::Scalar) 33 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_script_swig.i: -------------------------------------------------------------------------------- 1 | 2 | %{ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | using torch::jit::IValue; 18 | 19 | using torch::jit::Function; 20 | 21 | using torch::jit::ExtraFilesMap; 22 | #include 23 | #include 24 | %} 25 | 26 | namespace c10 { 27 | 28 | template 29 | class Dict final { 30 | public: 31 | explicit Dict(std::shared_ptr keyType, std::shared_ptr valueType); 32 | // TODO more methods 33 | 34 | // TODO this should return an iterator 35 | void insert_or_assign(Key&& key, Value&& value) const; 36 | 37 | private: 38 | explicit 
Dict(); 39 | }; 40 | 41 | } // namespace c10 42 | 43 | using ExtraFilesMap = std::unordered_map; 44 | 45 | namespace std { 46 | %template(ExtraFilesMap) unordered_map; 47 | %template(TensorTypeVector) vector>; 48 | } 49 | 50 | namespace c10 { 51 | struct CAFFE2_API Symbol { 52 | static Symbol fromQualString(const std::string& s); 53 | const char * toDisplayString() const; 54 | }; 55 | EQUALS_FROM_EQ(Symbol) 56 | HASHCODE_FROM_STD_HASH(Symbol) 57 | } 58 | 59 | // https://pytorch.org/tutorials/advanced/cpp_export.html 60 | struct IValue { 61 | // Constructs None 62 | IValue(); 63 | ~IValue(); 64 | 65 | std::shared_ptr type() const; 66 | 67 | // Tensor 68 | IValue(torch::Tensor t); 69 | bool isTensor() const; 70 | torch::Tensor toTensor() &&; 71 | torch::Tensor toTensor() const &; 72 | 73 | // TODO: figure out intrusive_ptrs 74 | // TODO? blobs 75 | // TODO? capsules 76 | 77 | 78 | IValue(c10::intrusive_ptr v); 79 | 80 | // Double 81 | IValue(double t); 82 | bool isDouble() const; 83 | double toDouble(); 84 | 85 | // TODO? futures 86 | 87 | // Int 88 | IValue(int64_t t); 89 | bool isInt() const; 90 | int64_t toInt(); 91 | 92 | // Bool 93 | IValue(bool t); 94 | bool isBool() const; 95 | bool toBool(); 96 | 97 | // IntList 98 | IValue(IntArrayRef v); 99 | bool isIntList() const { return Tag::IntList == tag; } 100 | // TODO: c10::List ? 101 | 102 | // ConstantString 103 | IValue(std::string v); 104 | bool isString() const; 105 | const std::string& toStringRef() const; 106 | 107 | 108 | // DoubleList 109 | // TODO c10::list 110 | IValue(std::vector v); 111 | bool isDoubleList() const; 112 | 113 | //TensorList 114 | // TODO c10::list 115 | IValue(const std::vector& v); 116 | bool isTensorList() const; 117 | // TODO: fix swig output generation code for htis 118 | // TensorList toTensorListRef() const; 119 | 120 | // GenericList 121 | IValue(c10::List v); 122 | 123 | // GenericDict 124 | IValue(c10::Dict v); 125 | bool isGenericDict() const; 126 | c10::Dict toGenericDict() const &; 127 | 128 | template 129 | IValue(c10::Dict v); 130 | 131 | // IValue(c10::intrusive_ptr v); 132 | %extend { 133 | IValue(torch::jit::Module* v) { 134 | return new IValue(v->_ivalue()); 135 | } 136 | } 137 | bool isModule() const; 138 | torch::jit::Module toModule() const; 139 | 140 | bool isNone() const; 141 | 142 | static IValue uninitialized(); 143 | 144 | IValue(torch::Scalar s); 145 | bool isScalar() const; 146 | torch::Scalar toScalar() const; 147 | 148 | // perhaps counterintuitively, an IValue can represent a device too. Basically anything 149 | // that can be an argument to a torch function can be an IValue 150 | // Device 151 | IValue(Device s); 152 | bool isDevice() const; 153 | Device toDevice() const; 154 | // TODO: ScalarType? 155 | // TODO: Layout? 156 | // TODO: MemoryFormat? 157 | // TODO: QScheme? 
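// Illustrative round trip (a sketch, not an exhaustive tour of the constructors above):
// anything that can be passed to or returned from a TorchScript function is boxed as an IValue.
//
//   torch::Tensor t = torch::ones({2, 2});
//   IValue boxed(t);                        // boxed.isTensor() == true
//   torch::Tensor back = boxed.toTensor();
//   IValue d(1.5);                          // d.isDouble() == true, d.toDouble() == 1.5
//   IValue none;                            // none.isNone() == true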
158 | std::string tagKind() const; 159 | bool isSameIdentity(const IValue& rhs) const; 160 | }; 161 | 162 | namespace c10 { 163 | 164 | %template(IValueList) List; 165 | %template(IValueDict) Dict; 166 | %template(IValueArrayRef) ArrayRef; 167 | 168 | } // namespace c10 169 | 170 | namespace std { 171 | %template(IValueVector) vector; 172 | %template(NamedValueVector) vector; 173 | %template(TypeVector) vector>; 174 | %template(NamedModuleVector) vector>; 175 | %template(NamedIValueVector) vector>; 176 | %template(NamedTensorVector) vector>; 177 | } 178 | 179 | namespace torch { 180 | namespace jit { 181 | struct TORCH_API Function { 182 | std::shared_ptr graph() const; 183 | const std::string& name() const; 184 | private: 185 | Function(); 186 | }; 187 | 188 | template 189 | struct Named { 190 | std::string name; 191 | %extend { 192 | // Exposing value like this instead of the raw member makes sure that swig 193 | // allocates a new T. Because all the types we use Named for are internally reference-counted pointers 194 | // (Module, Tensor, IValue), it's important that we do this to get reference-counting right. 195 | T value() { 196 | return $self->value; 197 | } 198 | } 199 | }; 200 | 201 | %template(NamedModule) Named; 202 | %template(NamedIValue) Named; 203 | %template(NamedTensor) Named; 204 | 205 | // avoid clash with java.lang.Object 206 | %rename(ScriptObject) Object; 207 | struct Object { 208 | void setattr(const std::string& name, IValue v); 209 | 210 | IValue attr(const std::string& name) const; 211 | IValue attr(const std::string& name, IValue or_else) const; 212 | bool hasattr(const std::string& name) const; 213 | 214 | %extend { 215 | 216 | std::shared_ptr compilation_unit() { 217 | return ($self)->_ivalue()->compilation_unit(); 218 | } 219 | 220 | std::string name() const { 221 | return ($self)->_ivalue()->name(); 222 | } 223 | 224 | std::shared_ptr slot_type(const std::string& name) { 225 | size_t slot = ($self)->_ivalue()->type()->getAttributeSlot(name); 226 | return ($self)->_ivalue()->type()->getAttribute(slot); 227 | } 228 | 229 | const std::vector get_method_functions() const { 230 | std::vector result; 231 | for (const auto& m : ($self)->get_methods()) { 232 | result.push_back(&(m.function())); 233 | } 234 | return result; 235 | } 236 | 237 | IValue run_method(const std::string& method_name, std::vector inputs) { 238 | return ($self)->get_method(method_name)(std::move(inputs)); 239 | } 240 | } 241 | }; 242 | 243 | struct Module: Object { 244 | 245 | explicit Module(c10::QualifiedName class_name); 246 | 247 | Module( 248 | c10::QualifiedName, 249 | std::shared_ptr cu, 250 | bool shouldMangle = false 251 | ); 252 | 253 | Module(std::shared_ptr cu, std::shared_ptr type); 254 | 255 | IValue forward(std::vector inputs); 256 | 257 | std::shared_ptr type() const; 258 | 259 | c10::IValue attr(const std::string& name) const; 260 | 261 | c10::IValue attr(const std::string& name, c10::IValue or_else) const; 262 | 263 | void setattr(const std::string& name, c10::IValue v); 264 | 265 | void register_parameter(const std::string& name, torch::Tensor v, bool is_buffer); 266 | void register_attribute( 267 | const std::string& name, 268 | const std::shared_ptr t, 269 | IValue v, 270 | bool is_param = false, 271 | bool is_buffer = false); 272 | 273 | void register_module(const std::string& name, const Module& m); 274 | 275 | %extend { 276 | std::vector> named_children() const { 277 | std::vector> ret; 278 | ret.reserve($self->named_children().size()); 279 | for (const auto& 
named_child: $self->named_children()) { 280 | ret.push_back(named_child); 281 | } 282 | return ret; 283 | } 284 | 285 | std::vector> named_parameters(bool recurse = true) const { 286 | const auto& params = $self->named_parameters(recurse); 287 | std::vector> ret; 288 | ret.reserve(params.size()); 289 | for (const auto& named_param: params) { 290 | ret.push_back(named_param); 291 | } 292 | return ret; 293 | } 294 | } 295 | 296 | 297 | void define(const std::string& src); 298 | 299 | %extend { 300 | // TODO I (@adampauls) don't know if there's a better way to define a TorchScript "method" 301 | // (a function that is a member of a class) directly. Best I could come up with is calling 302 | // CompilationUnit.create_function and then passing the result to this method, which I pieced together 303 | // by reading torch/csrc/jit/script/compiler.cpp. 304 | void define_method(Function* fn) { 305 | const auto selfRef = torch::jit::SimpleSelf($self->type()); 306 | selfRef.getClassType()->addMethod(fn); 307 | } 308 | 309 | bool has_method(const std::string& basename) const { 310 | return $self->find_method(basename).has_value(); 311 | } 312 | 313 | torch::jit::Function* find_function(const std::string& basename) const { 314 | return &($self->find_method(basename)->function()); 315 | } 316 | 317 | } 318 | 319 | void save( 320 | const std::string& filename, 321 | const ExtraFilesMap& extra_files = ExtraFilesMap()) const; 322 | 323 | Module clone() const; 324 | 325 | void train(bool on = true); 326 | void eval(); 327 | 328 | bool is_training(); 329 | 330 | void to(Device device, bool non_blocking = false); 331 | void to(Device device, ScalarType dtype, bool non_blocking = false); 332 | void to(ScalarType dtype, bool non_blocking = false); 333 | }; 334 | 335 | %rename(load_script_module) load; 336 | TORCH_API Module load( 337 | const std::string& filename, 338 | c10::optional device = c10::nullopt, 339 | ExtraFilesMap& extra_files = default_extra_files); 340 | 341 | } // namespace jit 342 | } // namespace torch 343 | 344 | namespace torch { 345 | namespace jit { 346 | std::shared_ptr compile(const std::string &source); 347 | TORCH_API void runRequiredPasses(const std::shared_ptr& g); 348 | } 349 | } 350 | 351 | %inline { 352 | // Copied verbatim from graph_executor.cpp. It is file-scoped unfortunately. 353 | // // TODO try to expose this properly in torch. 354 | void runOptimization(std::shared_ptr graph) { 355 | // Basic graph preprocessing to eliminate noise. 356 | EliminateDeadCode(graph); 357 | EliminateCommonSubexpression(graph); 358 | ConstantPooling(graph); 359 | 360 | PeepholeOptimize(graph); 361 | ConstantPropagation(graph); 362 | 363 | // Unroll small loops, and eliminate expressions that are the same at every 364 | // iteration. 365 | UnrollLoops(graph); 366 | EliminateCommonSubexpression(graph); 367 | 368 | CheckInplace(graph); 369 | } 370 | 371 | // Copied from a snippet ("Phase 2") inside compileSpec in graph_executor.cpp. 372 | // TODO this is buggy -- it does the wrong thing for aten::matmul at least. 
373 | void runTensorShapePropagation(std::shared_ptr opt_graph) { 374 | ConstantPropagation(opt_graph); 375 | PropagateInputShapes(opt_graph); 376 | PropagateRequiresGrad(opt_graph); 377 | } 378 | 379 | } 380 | 381 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_serialize_swig.i: -------------------------------------------------------------------------------- 1 | %inline %{ 2 | #include 3 | #include 4 | %} 5 | 6 | // This set of macros defines the overloads for loading and saving using the torch serialization framework 7 | %define SERIALIZATION_SET(NAME, TPE, ARG) 8 | namespace torch { 9 | void load(TPE & ARG, const std::string& file_name, c10::optional device = c10::nullopt); 10 | void load(TPE & ARG, const char* data, size_t size, c10::optional device = c10::nullopt); 11 | void save(const TPE & ARG, const std::string& file_name); 12 | } 13 | 14 | /** uses the above to write to a java byte array */ 15 | %native (save_##NAME##_to_byte_array) jbyteArray save_##NAME##_to_byte_array(const TPE & ARG); 16 | /** uses the above to load from a java byte array */ 17 | %native (load_##NAME##_from_byte_array) void load_##NAME##_from_byte_array(TPE & ARG, jbyteArray arr); 18 | 19 | %{ 20 | extern "C" { 21 | // TODO: it would probably be good to do this in a way that allows for writing to an OutputStream 22 | SWIGEXPORT jbyteArray JNICALL Java_com_microsoft_scalatorch_torch_internal_torch_1swigJNI_save_1##NAME##_1to_1byte_1array(JNIEnv * env, jclass, jlong pA, jobject pA_) { 23 | std::ostringstream out; 24 | TPE* v = *(TPE **)&pA; 25 | torch::save(*v, out); 26 | // TODO: it would probably be better to do this in a way that allows for zero copy 27 | auto str = std::move(out.str()); 28 | size_t size = str.size(); 29 | jbyteArray result = env->NewByteArray(size); 30 | if (result == nullptr) { 31 | java_throw(env, "java/lang/OutOfMemoryError", "Unable to allocate new byte array"); 32 | return nullptr; 33 | } 34 | env->SetByteArrayRegion(result, 0, size, (const jbyte*)str.c_str()); 35 | return result; 36 | } 37 | 38 | SWIGEXPORT void JNICALL Java_com_microsoft_scalatorch_torch_internal_torch_1swigJNI_load_1##NAME##_1from_1byte_1array(JNIEnv * env, jclass, jlong pA, jobject pA_, jbyteArray arr) { 39 | TPE* v = *(TPE **)&pA; 40 | 41 | size_t len = env->GetArrayLength(arr); 42 | char* buf = static_cast(env->GetPrimitiveArrayCritical(arr, 0)); 43 | if (buf == nullptr) { 44 | java_throw(env, "java/lang/OutOfMemoryError", "Unable to get JNI array"); 45 | return; 46 | } 47 | std::istringstream in(std::string(buf, len)); 48 | env->ReleasePrimitiveArrayCritical(arr, buf, 0); 49 | 50 | torch::load(*v, in); 51 | return; 52 | } 53 | } 54 | 55 | %} 56 | 57 | 58 | %enddef 59 | 60 | 61 | SERIALIZATION_SET(TensorVector, std::vector, tensor_vec) 62 | SERIALIZATION_SET(Optimizer, torch::optim::Optimizer, opt) 63 | SERIALIZATION_SET(Tensor, torch::Tensor, t) 64 | 65 | %{ 66 | extern "C" { 67 | // TODO: it would probably be good to do this in a way that allows for writing to an OutputStream 68 | SWIGEXPORT jbyteArray JNICALL Java_com_microsoft_scalatorch_torch_internal_torch_1swigJNI_save_1Module_1to_1byte_1array(JNIEnv * env, jclass, jlong pA, jobject pA_) { 69 | std::ostringstream out; 70 | torch::jit::Module* v = *(torch::jit::Module **)&pA; 71 | v->save(out); 72 | // TODO: it would probably be better to do this in a way that allows for zero copy 73 | auto str = std::move(out.str()); 74 | size_t size = str.size(); 75 | jbyteArray result = env->NewByteArray(size); 
76 | if (result == nullptr) { 77 | java_throw(env, "java/lang/OutOfMemoryError", "Unable to allocate new byte array"); 78 | return nullptr; 79 | } 80 | env->SetByteArrayRegion(result, 0, size, (const jbyte*)str.c_str()); 81 | return result; 82 | } 83 | 84 | SWIGEXPORT jlong JNICALL Java_com_microsoft_scalatorch_torch_internal_torch_1swigJNI_load_1Module_1from_1byte_1array(JNIEnv * env, jclass, jbyteArray arr) { 85 | jlong jresult = 0; 86 | 87 | size_t len = env->GetArrayLength(arr); 88 | char* buf = static_cast(env->GetPrimitiveArrayCritical(arr, 0)); 89 | if (buf == nullptr) { 90 | java_throw(env, "java/lang/OutOfMemoryError", "Unable to get JNI array"); 91 | return jresult; 92 | } 93 | std::istringstream in(std::string(buf, len)); 94 | env->ReleasePrimitiveArrayCritical(arr, buf, 0); 95 | 96 | *(torch::jit::Module **)&jresult = new torch::jit::Module(torch::jit::load(in)); 97 | return jresult; 98 | } 99 | } 100 | 101 | %} 102 | 103 | /** uses the above to write to a java byte array */ 104 | %native (save_Module_to_byte_array) jbyteArray save_Module_to_byte_array(const torch::jit::Module & module); 105 | /** uses the above to load from a java byte array */ 106 | %native (load_Module_from_byte_array) torch::jit::Module load_Module_from_byte_array(jbyteArray arr); 107 | 108 | // IValue/pickles 109 | namespace torch { 110 | namespace jit { 111 | std::vector pickle_save(const IValue &ivalue); 112 | } 113 | } 114 | 115 | %inline %{ 116 | IValue unpickle_from_file(const std::string& path) { 117 | std::vector vec; 118 | std::ifstream file(path, std::ios::binary); 119 | if (!file.eof() && !file.fail()) { 120 | file.seekg(0, std::ios_base::end); 121 | std::streampos fileSize = file.tellg(); 122 | vec.resize(fileSize); 123 | 124 | file.seekg(0, std::ios_base::beg); 125 | file.read(&vec[0], fileSize); 126 | } 127 | 128 | return torch::jit::unpickle(vec.data(), vec.size()); 129 | } 130 | 131 | void pickle_save_to_file(const std::string& path, const IValue& ivalue) { 132 | std::vector vec = torch::jit::pickle_save(ivalue); 133 | std::ofstream file(path, std::ios::binary); 134 | if (!file.eof() && !file.fail()) { 135 | file.write( &vec[0], vec.size() ); 136 | } 137 | } 138 | 139 | %} 140 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_std_array.i: -------------------------------------------------------------------------------- 1 | %define STD_ARRAY(SCALAR,SIZE) 2 | %typemap(in) std::array { 3 | auto elems$1 = (SCALAR*)(jenv)->Get##$typemap(jboxtype, SCALAR)##ArrayElements($input, nullptr); 4 | size_t len$1 = (jenv)->GetArrayLength($input); 5 | if (len$1 != SIZE) { 6 | throw std::invalid_argument("Wrong size for fixed size array"); 7 | } 8 | for (int i = 0; i < SIZE; ++i) { 9 | $1[i] = elems$1[i]; 10 | } 11 | (jenv)->Release##$typemap(jboxtype, SCALAR)##ArrayElements($input, ($typemap(jni, SCALAR)*)elems$1, 0); 12 | } 13 | 14 | 15 | %typemap(jni) std::array "$typemap(jni, SCALAR)""Array" 16 | %typemap(jtype) std::array "$typemap(jtype, SCALAR)[]" 17 | %typemap(jstype) std::array "$typemap(jtype, SCALAR)[]" 18 | %typemap(javain) std::array "$javainput" 19 | 20 | %enddef 21 | 22 | STD_ARRAY(bool, 2) 23 | STD_ARRAY(bool, 3) 24 | STD_ARRAY(bool, 4) 25 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_string_view.i: -------------------------------------------------------------------------------- 1 | %{ 2 | #include 3 | %} 4 | 5 | %typemap(jni) c10::string_view "jstring" 6 | %typemap(jtype) 
c10::string_view "String" 7 | %typemap(jstype) c10::string_view "String" 8 | %typemap(out) c10::string_view %{ $result = jenv->NewStringUTF($1.data()); %} 9 | 10 | %typemap(javain) c10::string_view "$javainput" 11 | %typemap(javaout) c10::string_view { return $jnicall; } -------------------------------------------------------------------------------- /swig/src/main/swig/torch_tensor.i: -------------------------------------------------------------------------------- 1 | namespace torch { 2 | 3 | %include "generated_tensor_bindings.i" 4 | %include "generated_bindings.i" 5 | TO_STRING_FROM_OSTREAM(Tensor); 6 | } 7 | 8 | // These need to be manually defined because of namespace collisions. 9 | // I(adpauls) don't totally understand what's going on, but there are torch:: and at:: versions 10 | // of all of the ATen functions. The torch:: ones should always take precedence, but in some cases 11 | // there's a need to fallback to at:: Mysteriously, there are some clashes just for normal we can't 12 | // invoke torch::normal explicitly 13 | torch::Tensor normal(const torch::Tensor & mean, double std, c10::optional generator); 14 | torch::Tensor normal(double mean, const torch::Tensor & std, c10::optional generator); 15 | torch::Tensor normal(const torch::Tensor & mean, const torch::Tensor & std, c10::optional generator); 16 | torch::Tensor normal(double mean, double std, IntArrayRef size, c10::optional generator, TensorOptions options); 17 | 18 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_tensor_list.i: -------------------------------------------------------------------------------- 1 | %typemap(in) TensorList { 2 | size_t len$1 = (jenv)->GetArrayLength($input); 3 | auto elems$1 = new torch::Tensor[len$1]; 4 | for (size_t i = 0; i < len$1; ++i) { 5 | auto obj = (jenv)->GetObjectArrayElement($input, i); 6 | jclass cls = (jenv)->GetObjectClass(obj); 7 | auto fid = (jenv)->GetFieldID(cls, "swigCPtr", "J"); 8 | (jenv)->DeleteLocalRef(cls); 9 | elems$1[i] = *(torch::Tensor *)(jenv)->GetLongField(obj, fid); 10 | (jenv)->DeleteLocalRef(obj); 11 | } 12 | $1 = TensorList(elems$1, len$1); 13 | } 14 | 15 | %typemap(freearg) TensorList { 16 | delete[] $1.data(); 17 | } 18 | 19 | %typemap(jni) TensorList "jobjectArray" 20 | %typemap(jtype) TensorList "TorchTensor[]" 21 | %typemap(jstype) TensorList "TorchTensor[]" 22 | %typemap(javain) TensorList "$javainput" 23 | -------------------------------------------------------------------------------- /swig/src/main/swig/torch_variant_enum.i: -------------------------------------------------------------------------------- 1 | %define VARIANT_ENUM(ENUM_T, ENUM_CLASS, CONVERT) 2 | 3 | %typemap(jni) const ENUM_T&, ENUM_T "jint" 4 | %typemap(jtype) const ENUM_T&, ENUM_T "int" 5 | %typemap(jstype) const ENUM_T&, ENUM_T "ENUM_CLASS" 6 | %typemap(javain) const ENUM_T&, ENUM_T "$javainput.swigValue()" 7 | %typemap(javaout) const ENUM_T&, ENUM_T { return ENUM_CLASS.swigToEnum($jnicall); } 8 | %typemap(in) const ENUM_T&, ENUM_T %{ 9 | $1 = CONVERT($input); 10 | %} 11 | 12 | %enddef 13 | -------------------------------------------------------------------------------- /swig/src/main/swig/tuple.i: -------------------------------------------------------------------------------- 1 | /* ----------------------------------------------------------------------------- 2 | * Like std_pair.i, but specialized for std::tuples with two arguments 3 | * Swig can't handle variadic templated types like std::tuple. 
So, instead, 4 | * we make an individual class for each arity. Importantly, each std::tuple 5 | * is implicitly convertible (https://en.cppreference.com/w/cpp/language/implicit_conversion) 6 | * to each tupleN declared here. This means that, somewhat sneakily, 7 | * if a C++ function has a signature like 8 | * std::tuple func(); 9 | * then you can write the following declaration in a swig file 10 | * tuple2 func(); 11 | * and the swig-generated code will still compile because the implicit conversion 12 | * from std::tuple -> tuple2. 13 | * 14 | * ----------------------------------------------------------------------------- */ 15 | 16 | %inline { 17 | template struct tuple2 { 18 | T first; 19 | U second; 20 | 21 | tuple2(const T& f, const U& s):first(f), second(s) {} 22 | 23 | // can be used for implicit conversion from std::tuple<,> 24 | tuple2(const std::tuple &other):tuple2(std::get<0>(other), std::get<1>(other)) {} 25 | }; 26 | } 27 | 28 | // For concrete instantation tuple2, you will need to call this macr. 29 | // e.g. DEFINE_TUPLE2(TensorBoolTuple, Tensor, bool) 30 | %define DEFINE_TUPLE_2(JTuple, T1, T2) 31 | 32 | %template(JTuple) tuple2< T1, T2 >; 33 | 34 | // Wrap with an actual Scala tuple 35 | %typemap(jstype) tuple2< T1, T2 >, const tuple2< T1, T2 >& "scala.Tuple2<$typemap(jboxtype, T1), $typemap(jboxtype, T2)>" 36 | %typemap(javaout) tuple2< T1, T2 >, const tuple2< T1, T2 >& { 37 | JTuple jTuple = new JTuple($jnicall, $owner); 38 | try { 39 | return new scala.Tuple2<$typemap(jboxtype, T1), $typemap(jboxtype, T2)>(jTuple.getFirst(), jTuple.getSecond()); 40 | } finally { 41 | jTuple.delete(); 42 | } 43 | } 44 | 45 | %enddef 46 | 47 | %inline { 48 | template struct tuple3 { 49 | T first; 50 | U second; 51 | V third; 52 | 53 | tuple3(const T& f, const U& s, const V& v):first(f), second(s), third(v) {} 54 | tuple3(const std::tuple &other):tuple3(std::get<0>(other), std::get<1>(other), std::get<2>(other)) {} 55 | }; 56 | } 57 | 58 | %define DEFINE_TUPLE_3(JTuple, T1, T2, T3) 59 | 60 | %template(JTuple) tuple3< T1, T2, T3 >; 61 | 62 | // Wrap with an actual Scala tuple 63 | %typemap(jstype) tuple3< T1, T2, T3 >, const tuple3< T1, T2, T3 >& "scala.Tuple3<$typemap(jboxtype, T1), $typemap(jboxtype, T2), $typemap(jboxtype, T3)>" 64 | %typemap(javaout) tuple3< T1, T2, T3 >, const tuple3< T1, T2, T3 >& { 65 | JTuple jTuple = new JTuple($jnicall, $owner); 66 | try { 67 | return new scala.Tuple3<$typemap(jboxtype, T1), $typemap(jboxtype, T2), $typemap(jboxtype, T3)>(jTuple.getFirst(), jTuple.getSecond(), jTuple.getThird()); 68 | } finally { 69 | jTuple.delete(); 70 | } 71 | } 72 | 73 | %enddef 74 | 75 | %inline { 76 | template struct tuple4 { 77 | T first; 78 | U second; 79 | V third; 80 | W fourth; 81 | 82 | tuple4(const T& f, const U& s, const V& v, const W& w):first(f), second(s), third(v), fourth(w) {} 83 | tuple4(const std::tuple &other):tuple4(std::get<0>(other), std::get<1>(other), std::get<2>(other), std::get<3>(other)) {} 84 | }; 85 | } 86 | 87 | %define DEFINE_TUPLE_4(JTuple, T1, T2, T3, T4) 88 | 89 | %template(JTuple) tuple4< T1, T2, T3, T4 >; 90 | 91 | // Wrap with an actual Scala tuple 92 | %typemap(jstype) tuple4< T1, T2, T3, T4 >, const tuple4< T1, T2, T3, T4 >& "scala.Tuple4<$typemap(jboxtype, T1), $typemap(jboxtype, T2), $typemap(jboxtype, T3), $typemap(jboxtype, T4)>" 93 | %typemap(javaout) tuple4< T1, T2, T3, T4 >, const tuple4< T1, T2, T3, T4 >& { 94 | JTuple jTuple = new JTuple($jnicall, $owner); 95 | try { 96 | return new scala.Tuple4<$typemap(jboxtype, T1), 
$typemap(jboxtype, T2), $typemap(jboxtype, T3), $typemap(jboxtype, T4)>(jTuple.getFirst(), jTuple.getSecond(), jTuple.getThird(), jTuple.getFourth()); 97 | } finally { 98 | jTuple.delete(); 99 | } 100 | } 101 | 102 | %enddef 103 | 104 | %inline { 105 | template struct tuple5 { 106 | T first; 107 | U second; 108 | V third; 109 | W fourth; 110 | X fifth; 111 | 112 | tuple5(const T& f, const U& s, const V& v, const W& w, const X&x ):first(f), second(s), third(v), fourth(w), fifth(x) {} 113 | tuple5(const std::tuple &other):tuple5(std::get<0>(other), std::get<1>(other), std::get<2>(other), std::get<3>(other), std::get<4>(other)) {} 114 | }; 115 | } 116 | 117 | %define DEFINE_TUPLE_5(JTuple, T1, T2, T3, T4, T5) 118 | 119 | %template(JTuple) tuple5< T1, T2, T3, T4, T5 >; 120 | 121 | // Wrap with an actual Scala tuple 122 | %typemap(jstype) tuple5< T1, T2, T3, T4, T5 >, const tuple5< T1, T2, T3, T4, T5 >& "scala.Tuple5<$typemap(jboxtype, T1), $typemap(jboxtype, T2), $typemap(jboxtype, T3), $typemap(jboxtype, T4), $typemap(jboxtype, T5)>" 123 | %typemap(javaout) tuple5< T1, T2, T3, T4, T5 >, const tuple5< T1, T2, T3, T4, T5 >& { 124 | JTuple jTuple = new JTuple($jnicall, $owner); 125 | try { 126 | return new scala.Tuple5<$typemap(jboxtype, T1), $typemap(jboxtype, T2), $typemap(jboxtype, T3), $typemap(jboxtype, T4), $typemap(jboxtype, T5)>(jTuple.getFirst(), jTuple.getSecond(), jTuple.getThird(), jTuple.getFourth(), jTuple.getFifth()); 127 | } finally { 128 | jTuple.delete(); 129 | } 130 | } 131 | 132 | %enddef 133 | -------------------------------------------------------------------------------- /swig/src/native/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | # A minimal CMake file that is compatible with sbt-jni # 3 | # # 4 | # All settings required by sbt-jni have been marked so, please # 5 | # add/modify/remove settings to build your specific library. # 6 | ################################################################ 7 | 8 | cmake_minimum_required(VERSION 3.15.0) 9 | if (WIN32) 10 | # Explained here: 11 | # https://cmake.org/cmake/help/v3.15/prop_tgt/MSVC_RUNTIME_LIBRARY.html 12 | # The policy is needed to set the runtime library. This chosen runtime 13 | # library is described here: 14 | # https://docs.microsoft.com/en-us/cpp/build/reference/md-mt-ld-use-run-time-library?view=vs-2019 15 | # Selecting it fixes a crash observed when running libtorch with multiple threads 16 | # after compilation. 17 | set(CMAKE_MSVC_RUNTIME_LIBRARY MultiThreadedDLL) 18 | cmake_policy(SET CMP0091 NEW) 19 | endif() 20 | 21 | ## See https://discuss.pytorch.org/t/libtorch-static-library/73178/9 22 | 23 | ## https://stackoverflow.com/questions/51047978/cmake-could-not-find-jni 24 | set(JAVA_AWT_LIBRARY NotNeeded) 25 | set(JAVA_JVM_LIBRARY NotNeeded) 26 | 27 | # For local development, it's easiest to build a library that remembers absolute paths 28 | # to the torch libs instead of using an @rpath 29 | if(DEFINED ENV{LINK_TO_BUILD_LIB}) 30 | set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) 31 | endif() 32 | 33 | # Note that we use the value of the environment variable STATIC_PYTORCH if available 34 | # TODO Currently, STATIC_PYTORCH ON only works for MacOS. Windows compiles but 35 | # fails to load at runtime with "DLL initializer failed" 36 | option(STATIC_PYTORCH "Is Pytorch Statically linked?" 
OFF) 37 | 38 | if(DEFINED ENV{STATIC_PYTORCH}) 39 | set(STATIC_PYTORCH $ENV{STATIC_PYTORCH}) 40 | endif() 41 | if(DEFINED ENV{TORCH_DIR}) 42 | # Fix the path separator characters, so that cmake doesn't 43 | # terminate with "/s isn't a valid control character"-style errors 44 | file(TO_CMAKE_PATH "$ENV{TORCH_DIR}" TORCH_DIR) 45 | endif() 46 | message (STATUS "Static Pytorch: ${STATIC_PYTORCH}") 47 | message (STATUS "torch dir?: ${TORCH_DIR}") 48 | 49 | if (NOT TORCH_DIR) 50 | get_filename_component(TORCH_DIR 51 | "${CMAKE_CURRENT_LIST_DIR}/../../../libtorch" 52 | ABSOLUTE) 53 | endif() 54 | message (STATUS "final torch dir: ${TORCH_DIR}") 55 | 56 | list(APPEND CMAKE_MODULE_PATH "${TORCH_DIR}/../cmake/Modules") 57 | 58 | if (NOT WIN32) 59 | list(APPEND CMAKE_MODULE_PATH "/usr/local/cmake/Modules") 60 | endif() 61 | 62 | # If we're statically linking pytorch, we link in some dynamic libraries 63 | # from conda such as mkl and mkl-include 64 | if (STATIC_PYTORCH) 65 | 66 | if (NOT WIN32) 67 | if (NOT CONDA_PATH) 68 | execute_process(COMMAND which conda OUTPUT_VARIABLE CONDA_PATH) 69 | get_filename_component(CONDA_PATH 70 | "${CONDA_PATH}/../../" 71 | ABSOLUTE) 72 | endif() 73 | message (STATUS "Conda Path: ${CONDA_PATH}") 74 | 75 | list(APPEND CMAKE_PREFIX_PATH "${CONDA_PATH}/share/cmake-3.14") 76 | endif() 77 | endif() 78 | 79 | message (STATUS "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}") 80 | message (STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}/share/cmake-3.14") 81 | 82 | 83 | # Define project and related variables 84 | # (required by sbt-jni) please use semantic versioning 85 | # 86 | project (torch_swig) 87 | set(PROJECT_VERSION_MAJOR 0) 88 | set(PROJECT_VERSION_MINOR 1) 89 | set(PROJECT_VERSION_PATCH 0) 90 | 91 | # Taken from pytorch/CMakeLists.txt 92 | set(CMAKE_CXX_STANDARD 14) 93 | 94 | # Setup JNI 95 | find_package(JNI REQUIRED) 96 | if (JNI_FOUND) 97 | message (STATUS "JNI include directories: ${JNI_INCLUDE_DIRS}") 98 | endif() 99 | 100 | 101 | # Torch 102 | find_package(Torch REQUIRED PATHS "${TORCH_DIR}/share/cmake") 103 | if (TORCH_FOUND) 104 | message (STATUS "Torch include directories: ${TORCH_INCLUDE_DIRS}") 105 | endif() 106 | 107 | if (STATIC_PYTORCH) 108 | #Protobuf 109 | set(Protobuf_DEBUG ON) 110 | if (WIN32) 111 | find_package(Protobuf REQUIRED PATHS "${TORCH_DIR}/cmake") 112 | else() 113 | find_package(Protobuf REQUIRED PATHS "${TORCH_DIR}/lib/cmake") 114 | endif() 115 | message (STATUS "${Protobuf_FOUND}") 116 | message (STATUS "${Protobuf_VERSION}") 117 | # TODO: these aren't getting set??? 118 | message (STATUS "Protobuf include directories: ${Protobuf_INCLUDE_DIRS}") 119 | message (STATUS "Protobuf libraries: ${Protobuf_LIBRARIES}") 120 | message (STATUS "Protobuf libraries: ${PROTOBUF_LIBRARIES}") 121 | 122 | endif(STATIC_PYTORCH) 123 | 124 | # Include directories 125 | include_directories(.) 126 | include_directories(include) 127 | include_directories(${JNI_INCLUDE_DIRS}) 128 | include_directories(${TORCH_INCLUDE_DIRS}) 129 | include_directories(${Protobuf_INCLUDE_DIRS}) 130 | include_directories(${PROTOBUF_INCLUDE_DIRS}) 131 | if (NOT STATIC_PYTORCH) 132 | include_directories(${GFLAGS_INCLUDE_DIR}) 133 | endif (NOT STATIC_PYTORCH) 134 | 135 | # Sources 136 | file(GLOB LIB_SRC 137 | "../../target/src_managed/native/*.cxx" 138 | ) 139 | 140 | # compiler flags 141 | add_compile_options(-O2) 142 | if (MSVC) 143 | add_compile_options(/bigobj) 144 | else() 145 | # swig depends on type punning to do its work and you have to tell gcc you're going to do that. 
146 | add_compile_options(-fno-strict-aliasing) 147 | endif() 148 | # -fpermissive: gcc (correctly, but pedantically) regards long and long long as different types, but 149 | # they're actually the same under all sane modern UNIXes. 150 | # (They are different under windows and we'll need to address that) 151 | # (adpauls): I'm not sure why -fpermissive is gone but I'll leave this comment 152 | # here anyways 153 | 154 | 155 | # TODO: this shouldn't be necessary, but FindTorch and FindProtobuf aren't working 156 | # and there are probably other libs (onnx, fbgemm) that aren't getting picked up either. 157 | # a bunch of things aren't getting found because pytorch really doesn't like static linking, but we will prevail 158 | if (STATIC_PYTORCH) 159 | set(ABS_LIB_DIR "${TORCH_DIR}/lib/") 160 | link_directories("${ABS_LIB_DIR}") 161 | endif (STATIC_PYTORCH) 162 | 163 | # Setup installation targets 164 | # (required by sbt-jni) major version should always be appended to library name 165 | # 166 | set (LIB_NAME ${PROJECT_NAME}${PROJECT_VERSION_MAJOR}) 167 | add_library(${LIB_NAME} SHARED ${LIB_SRC}) 168 | 169 | if (STATIC_PYTORCH) 170 | if (WIN32) 171 | target_link_libraries(${LIB_NAME} libprotobuf onnx onnx_proto caffe2_protos clog cpuinfo c10 foxi_loader Caffe2_perfkernels_avx Caffe2_perfkernels_avx2 Caffe2_perfkernels_avx512) 172 | else() 173 | # TODO "-framework Accelerate" is MacOS specific. Will have to figure out what to do for Linux. 174 | target_link_libraries(${LIB_NAME} -lprotobuf -lasmjit -lfbgemm -lcaffe2_protos -lonnx -lonnx_proto -lfoxi_loader -lqnnpack -lsleef "-framework Accelerate" -Wl,-rpath,${CONDA_PATH}/lib) 175 | endif() 176 | endif (STATIC_PYTORCH) 177 | 178 | # libs 179 | if (NOT STATIC_PYTORCH) 180 | target_link_libraries(${LIB_NAME} "${MKL_LIBRARIES}") 181 | target_link_libraries(${LIB_NAME} "${GFLAGS_LIBRARIES}") 182 | endif (NOT STATIC_PYTORCH) 183 | target_link_libraries(${LIB_NAME} "${TORCH_LIBRARIES}") 184 | if (NOT WIN32) 185 | target_link_libraries(${LIB_NAME} "-lpthread") 186 | endif() 187 | 188 | message (STATUS "torch libs : ${TORCH_LIBRARIES}") 189 | install(TARGETS ${LIB_NAME} LIBRARY DESTINATION .) 190 | 191 | --------------------------------------------------------------------------------