├── .envrc ├── .github └── workflows │ ├── ci.yml │ └── clean.yml ├── .gitignore ├── .scalafmt.conf ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── build.sbt ├── core └── src │ ├── main │ └── scala │ │ └── torch │ │ ├── DType.scala │ │ ├── Device.scala │ │ ├── Generator.scala │ │ ├── Layout.scala │ │ ├── MemoryFormat.scala │ │ ├── Tensor.scala │ │ ├── Types.scala │ │ ├── cuda │ │ └── package.scala │ │ ├── data │ │ ├── DataLoader.scala │ │ ├── Example.scala │ │ ├── TensorDataset.scala │ │ └── TensorSeq.scala │ │ ├── hub.scala │ │ ├── indexing.scala │ │ ├── internal │ │ └── NativeConverters.scala │ │ ├── nn │ │ ├── functional │ │ │ ├── Activations.scala │ │ │ ├── Convolution.scala │ │ │ ├── Dropout.scala │ │ │ ├── Linear.scala │ │ │ ├── Loss.scala │ │ │ ├── Pooling.scala │ │ │ ├── Sparse.scala │ │ │ └── package.scala │ │ ├── init.scala │ │ ├── loss │ │ │ └── CrossEntropyLoss.scala │ │ ├── modules │ │ │ ├── Module.scala │ │ │ ├── activation │ │ │ │ ├── LogSoftmax.scala │ │ │ │ ├── ReLU.scala │ │ │ │ ├── Softmax.scala │ │ │ │ └── Tanh.scala │ │ │ ├── batchnorm │ │ │ │ ├── BatchNorm1d.scala │ │ │ │ └── BatchNorm2d.scala │ │ │ ├── container │ │ │ │ ├── ModuleList.scala │ │ │ │ └── Sequential.scala │ │ │ ├── conv │ │ │ │ └── Conv2d.scala │ │ │ ├── flatten │ │ │ │ └── Flatten.scala │ │ │ ├── linear │ │ │ │ ├── Identity.scala │ │ │ │ └── Linear.scala │ │ │ ├── normalization │ │ │ │ ├── GroupNorm.scala │ │ │ │ └── LayerNorm.scala │ │ │ ├── pooling │ │ │ │ ├── AdaptiveAvgPool2d.scala │ │ │ │ └── MaxPool2d.scala │ │ │ ├── regularization │ │ │ │ └── Dropout.scala │ │ │ └── sparse │ │ │ │ └── Embedding.scala │ │ ├── package.scala │ │ └── utils.scala │ │ ├── ops │ │ ├── BLASOps.scala │ │ ├── ComparisonOps.scala │ │ ├── CreationOps.scala │ │ ├── IndexingSlicingJoiningOps.scala │ │ ├── OtherOps.scala │ │ ├── PointwiseOps.scala │ │ ├── RandomSamplingOps.scala │ │ ├── ReductionOps.scala │ │ └── package.scala │ │ ├── optim │ │ ├── Adam.scala │ │ ├── AdamW.scala │ │ ├── Optimizer.scala │ │ ├── SGD.scala │ │ └── lr_scheduler │ │ │ ├── LRScheduler.scala │ │ │ └── StepLR.scala │ │ ├── package.scala │ │ └── special │ │ └── package.scala │ └── test │ └── scala │ ├── TrainingSuite.scala │ └── torch │ ├── DeviceSuite.scala │ ├── Generators.scala │ ├── TensorCheckSuite.scala │ ├── TensorSuite.scala │ ├── nn │ ├── ConvolutionSuite.scala │ ├── PoolingSuite.scala │ ├── functional │ │ └── SparseSuite.scala │ └── modules │ │ ├── ActivationSuite.scala │ │ ├── BatchNormSuite.scala │ │ ├── EmbeddingSuite.scala │ │ ├── FlattenSuite.scala │ │ ├── LinearSuite.scala │ │ ├── NormalizationSuite.scala │ │ └── PoolingSuite.scala │ └── ops │ ├── ComparisonOpsSuite.scala │ ├── CreationOpsSuite.scala │ ├── IndexingSlicingJoiningOpsSuite.scala │ ├── OtherOpsSuite.scala │ ├── PointwiseOpsSuite.scala │ ├── RandomSamplingOpsSuite.scala │ └── ReductionOpsSuite.scala ├── devenv.lock ├── devenv.nix ├── devenv.yaml ├── docs ├── about.md ├── contributing.md ├── directory.conf ├── examples.md ├── faq.md ├── installation.md ├── modules.md ├── pre-trained-weights.md └── tutorial │ ├── autograd.md │ ├── buildmodel.md │ ├── directory.conf │ ├── img │ └── comp-graph.png │ └── tensors.md ├── examples └── src │ └── main │ └── scala │ ├── ImageClassifier.scala │ ├── LeNet.scala │ └── gpt │ ├── Utils.scala │ └── V2.scala ├── git-hooks └── pre-push-checks ├── project ├── SiteSettings.scala ├── build.properties └── plugins.sbt ├── scripts └── convert-weights │ ├── convert_weights.py │ └── requirements.txt ├── site └── src │ ├── css │ └── custom.css │ ├── 
img │ └── storch.svg │ ├── js │ └── render-katex.js │ └── landing-page.md └── vision └── src ├── main └── scala │ └── torchvision │ ├── datasets │ └── MNIST.scala │ ├── models │ └── resnet.scala │ └── transforms │ ├── functional.scala │ └── presets.scala └── test └── scala └── torchvision └── MNISTSuite.scala /.envrc: -------------------------------------------------------------------------------- 1 | watch_file devenv.nix 2 | watch_file devenv.yaml 3 | watch_file devenv.lock 4 | eval "$(devenv print-dev-env)" -------------------------------------------------------------------------------- /.github/workflows/clean.yml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by sbt-github-actions using the 2 | # githubWorkflowGenerate task. You should add and commit this file to 3 | # your git repository. It goes without saying that you shouldn't edit 4 | # this file by hand! Instead, if you wish to make changes, you should 5 | # change your sbt build configuration to revise the workflow description 6 | # to meet your needs, then regenerate this file. 7 | 8 | name: Clean 9 | 10 | on: push 11 | 12 | jobs: 13 | delete-artifacts: 14 | name: Delete Artifacts 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 18 | steps: 19 | - name: Delete artifacts 20 | run: | 21 | # Customize those three lines with your repository and credentials: 22 | REPO=${GITHUB_API_URL}/repos/${{ github.repository }} 23 | 24 | # A shortcut to call GitHub API. 25 | ghapi() { curl --silent --location --user _:$GITHUB_TOKEN "$@"; } 26 | 27 | # A temporary file which receives HTTP response headers. 28 | TMPFILE=/tmp/tmp.$$ 29 | 30 | # An associative array, key: artifact name, value: number of artifacts of that name. 31 | declare -A ARTCOUNT 32 | 33 | # Process all artifacts on this repository, loop on returned "pages". 34 | URL=$REPO/actions/artifacts 35 | while [[ -n "$URL" ]]; do 36 | 37 | # Get current page, get response headers in a temporary file. 38 | JSON=$(ghapi --dump-header $TMPFILE "$URL") 39 | 40 | # Get URL of next page. Will be empty if we are at the last page. 41 | URL=$(grep '^Link:' "$TMPFILE" | tr ',' '\n' | grep 'rel="next"' | head -1 | sed -e 's/.*<//' -e 's/>.*//') 42 | rm -f $TMPFILE 43 | 44 | # Number of artifacts on this page: 45 | COUNT=$(( $(jq <<<$JSON -r '.artifacts | length') )) 46 | 47 | # Loop on all artifacts on this page. 48 | for ((i=0; $i < $COUNT; i++)); do 49 | 50 | # Get name of artifact and count instances of this name.
51 | name=$(jq <<<$JSON -r ".artifacts[$i].name?") 52 | ARTCOUNT[$name]=$(( $(( ${ARTCOUNT[$name]} )) + 1)) 53 | 54 | id=$(jq <<<$JSON -r ".artifacts[$i].id?") 55 | size=$(( $(jq <<<$JSON -r ".artifacts[$i].size_in_bytes?") )) 56 | printf "Deleting '%s' #%d, %'d bytes\n" $name ${ARTCOUNT[$name]} $size 57 | ghapi -X DELETE $REPO/actions/artifacts/$id 58 | done 59 | done 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .bsp/ 2 | .idea/ 3 | .vscode/ 4 | target/ 5 | .bleep/ 6 | .bloop/ 7 | .metals 8 | metals.sbt 9 | *.worksheet.sc 10 | /data/ 11 | .scala-build/ 12 | 13 | # Devenv 14 | .devenv* 15 | devenv.local.nix 16 | 17 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version = "3.6.1" 2 | maxColumn = 100 3 | runner.dialect = scala3 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Storch 2 | 3 | Please have a look at the contributor docs on the [website](https://storch.dev/contribute.html) 4 | or here in the [repo](docs/contributing.md). 5 | 6 | 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Storch - GPU Accelerated Deep Learning for Scala 3 2 | 3 | Storch is a Scala library for fast tensor computations and deep learning, based on PyTorch. 4 | 5 | Like PyTorch, Storch provides 6 | * A NumPy-like API for working with tensors 7 | * GPU support 8 | * Automatic differentiation 9 | * A neural network API for building and training neural networks. 10 | 11 | Storch aims to stay close to the Python API to make porting existing models and the life of people already familiar with PyTorch easier.
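To try Storch in an sbt project, a minimal dependency sketch (the `dev.storch` organization and the `core` module name are taken from this repository's build.sbt; check the installation docs for the current version and for enabling the native GPU artifacts):

```scala
libraryDependencies += "dev.storch" %% "core" % "<version>"
```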
12 | 13 | For documentation, see https://storch.dev 14 | 15 | ## Example: 16 | 17 | ```scala 18 | val data = Seq(0,1,2,3) 19 | // data: Seq[Int] = List(0, 1, 2, 3) 20 | val t1 = torch.Tensor(data) 21 | // t1: Tensor[Int32] = dtype=int32, shape=[4], device=CPU 22 | // [0, 1, 2, 3] 23 | t1.equal(torch.arange(0,4)) 24 | // res0: Boolean = true 25 | val t2 = t1.to(dtype=float32) 26 | // t2: Tensor[Float32] = dtype=float32, shape=[4], device=CPU 27 | // [0,0000, 1,0000, 2,0000, 3,0000] 28 | val t3 = t1 + t2 29 | // t3: Tensor[Float32] = dtype=float32, shape=[4], device=CPU 30 | // [0,0000, 2,0000, 4,0000, 6,0000] 31 | 32 | val shape = Seq(2l,3l) 33 | // shape: Seq[Long] = List(2, 3) 34 | val randTensor = torch.rand(shape) 35 | // randTensor: Tensor[Float32] = dtype=float32, shape=[2, 3], device=CPU 36 | // [[0,4341, 0,9738, 0,9305], 37 | // [0,8987, 0,1122, 0,3912]] 38 | val zerosTensor = torch.zeros(shape, dtype=torch.int64) 39 | // zerosTensor: Tensor[Int64] = dtype=int64, shape=[2, 3], device=CPU 40 | // [[0, 0, 0], 41 | // [0, 0, 0]] 42 | 43 | val x = torch.ones(Seq(5)) 44 | // x: Tensor[Float32] = dtype=float32, shape=[5], device=CPU 45 | // [1,0000, 1,0000, 1,0000, 1,0000, 1,0000] 46 | val w = torch.randn(Seq(5, 3), requiresGrad=true) 47 | // w: Tensor[Float32] = dtype=float32, shape=[5, 3], device=CPU 48 | // [[0,8975, 0,5484, 0,2307], 49 | // [0,2689, 0,7430, 0,6446], 50 | // [0,9503, 0,6342, 0,7523], 51 | // [0,5332, 0,7497, 0,3665], 52 | // [0,3376, 0,6040, 0,5033]] 53 | val b = torch.randn(Seq(3), requiresGrad=true) 54 | // b: Tensor[Float32] = dtype=float32, shape=[3], device=CPU 55 | // [0,2638, 0,9697, 0,3664] 56 | val z = (x matmul w) + b 57 | // z: Tensor[Float32] = dtype=float32, shape=[3], device=CPU 58 | // [3,2513, 4,2490, 2,8640] 59 | ``` 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | 3 | import Keys._ 4 | import MdocPlugin.autoImport._ 5 | import LaikaPlugin.autoImport._ 6 | 7 | ThisBuild / tlBaseVersion := "0.0" // your current series x.y 8 | 9 | ThisBuild / organization := "dev.storch" 10 | ThisBuild / organizationName := "storch.dev" 11 | ThisBuild / startYear := Some(2022) 12 | ThisBuild / licenses := Seq(License.Apache2) 13 | ThisBuild / developers := List( 14 | // your GitHub handle and name 15 | tlGitHubDev("sbrunk", "Sören Brunk") 16 | ) 17 | 18 | // publish to s01.oss.sonatype.org (set to true to publish to oss.sonatype.org instead) 19 | ThisBuild / tlSonatypeUseLegacyHost := false 20 | 21 | // publish website from this branch 22 | ThisBuild / tlSitePublishBranch := Some("main") 23 | 24 | ThisBuild / apiURL := Some(new URL("https://storch.dev/api/")) 25 | 26 | val scrImageVersion = "4.0.34" 27 | val pytorchVersion = "2.1.2" 28 | val cudaVersion = "12.3-8.9" 29 | val openblasVersion = "0.3.26" 30 | val mklVersion = "2024.0" 31 | ThisBuild / scalaVersion := "3.3.1" 32 | ThisBuild / javaCppVersion := "1.5.10" 33 | ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots") 34 | 35 | ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11")) 36 | ThisBuild / githubWorkflowOSes := Seq("macos-latest", "ubuntu-latest", "windows-latest") 37 | 38 | val enableGPU = settingKey[Boolean]("enable or disable GPU support") 39 | 40 | ThisBuild / enableGPU := false 41 | 42 | val hasMKL = { 43 | val firstPlatform = org.bytedeco.sbt.javacpp.Platform.current.head 44 | firstPlatform == "linux-x86_64" || 
firstPlatform == "windows-x86_64" 45 | } 46 | 47 | lazy val commonSettings = Seq( 48 | Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"), 49 | javaCppVersion := (ThisBuild / javaCppVersion).value, 50 | javaCppPlatform := Seq(), 51 | // This is a hack to avoid depending on the native libs when publishing 52 | // but conveniently have them on the classpath during development. 53 | // There's probably a cleaner way to do this. 54 | tlJdkRelease := Some(11) 55 | ) ++ tlReplaceCommandAlias( 56 | "tlReleaseLocal", 57 | List( 58 | "reload", 59 | "project /", 60 | "set core / javaCppPlatform := Seq()", 61 | "set core / javaCppPresetLibs := Seq()", 62 | "+publishLocal" 63 | ).mkString("; ", "; ", "") 64 | ) ++ tlReplaceCommandAlias( 65 | "tlRelease", 66 | List( 67 | "reload", 68 | "project /", 69 | "set core / javaCppPlatform := Seq()", 70 | "set core / javaCppPresetLibs := Seq()", 71 | "+mimaReportBinaryIssues", 72 | "+publish", 73 | "tlSonatypeBundleReleaseIfRelevant" 74 | ).mkString("; ", "; ", "") 75 | ) 76 | 77 | lazy val core = project 78 | .in(file("core")) 79 | .settings(commonSettings) 80 | .settings( 81 | javaCppPresetLibs ++= Seq( 82 | (if (enableGPU.value) "pytorch-gpu" else "pytorch") -> pytorchVersion, 83 | "openblas" -> openblasVersion 84 | ) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion) else Seq()) 85 | ++ (if (hasMKL) Seq("mkl" -> mklVersion) else Seq()), 86 | javaCppPlatform := org.bytedeco.sbt.javacpp.Platform.current, 87 | fork := true, 88 | Test / fork := true, 89 | libraryDependencies ++= Seq( 90 | "org.bytedeco" % "pytorch" % s"$pytorchVersion-${javaCppVersion.value}", 91 | "org.typelevel" %% "spire" % "0.18.0", 92 | "org.typelevel" %% "shapeless3-typeable" % "3.3.0", 93 | "com.lihaoyi" %% "os-lib" % "0.9.1", 94 | "com.lihaoyi" %% "sourcecode" % "0.3.0", 95 | "dev.dirs" % "directories" % "26", 96 | "org.scalameta" %% "munit" % "0.7.29" % Test, 97 | "org.scalameta" %% "munit-scalacheck" % "0.7.29" % Test 98 | ) 99 | ) 100 | 101 | lazy val vision = project 102 | .in(file("vision")) 103 | .settings(commonSettings) 104 | .settings( 105 | libraryDependencies ++= Seq( 106 | "com.sksamuel.scrimage" % "scrimage-core" % scrImageVersion, 107 | "com.sksamuel.scrimage" % "scrimage-webp" % scrImageVersion, 108 | "org.scalameta" %% "munit" % "0.7.29" % Test 109 | ) 110 | ) 111 | .dependsOn(core) 112 | 113 | lazy val examples = project 114 | .in(file("examples")) 115 | .enablePlugins(NoPublishPlugin) 116 | .settings( 117 | commonSettings, 118 | // disable discarded non-Unit value warnings in examples for now 119 | scalacOptions ~= (_.filterNot(Set("-Wvalue-discard"))) 120 | ) 121 | .settings( 122 | fork := true, 123 | libraryDependencies ++= Seq( 124 | "me.tongfei" % "progressbar" % "0.9.5", 125 | "com.github.alexarchambault" %% "case-app" % "2.1.0-M24", 126 | "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4" 127 | ) 128 | ) 129 | .dependsOn(vision) 130 | 131 | lazy val docs = project 132 | .in(file("site")) 133 | .enablePlugins(ScalaUnidocPlugin, TypelevelSitePlugin, StorchSitePlugin) 134 | .settings(commonSettings) 135 | .settings( 136 | mdocVariables ++= Map( 137 | "JAVACPP_VERSION" -> javaCppVersion.value, 138 | "PYTORCH_VERSION" -> pytorchVersion, 139 | "OPENBLAS_VERSION" -> openblasVersion, 140 | "MKL_VERSION" -> mklVersion, 141 | "CUDA_VERSION" -> cudaVersion 142 | ), 143 | ScalaUnidoc / unidoc / unidocProjectFilter := inAnyProject -- inProjects(examples), 144 | Laika / sourceDirectories ++= Seq(sourceDirectory.value), 145 | 
laikaIncludeAPI := true, 146 | laikaGenerateAPI / mappings := (ScalaUnidoc / packageDoc / mappings).value 147 | ) 148 | .dependsOn(vision) 149 | 150 | lazy val root = project 151 | .enablePlugins(NoPublishPlugin) 152 | .in(file(".")) 153 | .aggregate(core, vision, examples, docs) 154 | .settings( 155 | javaCppVersion := (ThisBuild / javaCppVersion).value 156 | ) 157 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/Device.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | import org.bytedeco.pytorch 20 | import scala.collection.immutable.ArraySeq 21 | 22 | enum DeviceType: 23 | case CPU, CUDA, MKLDNN, OPENGL, OPENCL, IDEEP, HIP, FPGA, ORT, XLA, Vulkan, Metal, XPU, MPS, Meta, 24 | HPU, VE, Lazy, IPU, MTIA, PrivateUse1, COMPILE_TIME_MAX_DEVICE_TYPES 25 | 26 | object DeviceType: 27 | val deviceTypesLowerCase: Seq[String] = 28 | ArraySeq.unsafeWrapArray(DeviceType.values).map(_.toString.toLowerCase) 29 | def apply(v: String): DeviceType = 30 | val index = deviceTypesLowerCase.indexOf(v) 31 | if index == -1 then DeviceType.valueOf(v) 32 | else DeviceType.fromOrdinal(index) 33 | 34 | case class Device(device: DeviceType, index: Byte = -1): 35 | private[torch] def toNative: pytorch.Device = pytorch.Device(device.ordinal.toByte, index) 36 | object Device: 37 | def apply(device: String, index: Byte): Device = Device(DeviceType(device), index) 38 | def apply(device: String): Device = Device(device, -1: Byte) 39 | private[torch] def apply(native: pytorch.Device): Device = 40 | Device(DeviceType.fromOrdinal(native.`type`().value), native.index()) 41 | val CPU = Device(DeviceType.CPU) 42 | val CUDA = Device(DeviceType.CUDA) 43 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/Generator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package torch 18 | 19 | import org.bytedeco.pytorch.global.torch.make_generator_cpu 20 | import org.bytedeco.pytorch.global.torch_cuda.make_generator_cuda 21 | import torch.internal.NativeConverters.fromNative 22 | 23 | /** Creates and returns a generator object that manages the state of the algorithm which produces 24 | * pseudo random numbers. 25 | */ 26 | class Generator(val device: Device = Device.CPU) { 27 | private[torch] val native = device.device match 28 | case DeviceType.CPU => make_generator_cpu 29 | case DeviceType.CUDA => make_generator_cuda 30 | case _ => throw new IllegalArgumentException("Unsupported generator device") 31 | 32 | /** Returns the Generator state as a [[torch.Tensor[UInt8]]]. */ 33 | def getState: Tensor[UInt8] = fromNative(native.get_state()) 34 | 35 | /** Sets the Generator state. 36 | * 37 | * @param newState 38 | * The desired state. 39 | */ 40 | def setState(newState: Tensor[UInt8]) = native.set_state(newState.native) 41 | 42 | /** Returns the initial seed for generating random numbers. */ 43 | def initialSeed: Long = native.seed() 44 | 45 | /** Sets the seed for generating random numbers. Returns a torch.Generator object. 46 | * 47 | * It is recommended to set a large seed, i.e. a number that has a good balance of 0 and 1 bits. 48 | * Avoid having many 0 bits in the seed. 49 | * 50 | * @param seed 51 | * The desired seed. Value must be within the inclusive range 52 | * *[-0x8000_0000_0000_0000,0xffff_ffff_ffff_ffff]*. Otherwise, a RuntimeError is raised. 53 | * Negative inputs are remapped to positive values with the *formula 0xffff_ffff_ffff_ffff + 54 | * seed*. 55 | */ 56 | def manualSeed(seed: Long): Unit = native.set_current_seed(seed) 57 | 58 | /** Gets a non-deterministic random number from std::random_device or the current time and uses it 59 | * to seed a Generator. 60 | */ 61 | def seed: Long = native.current_seed() 62 | } 63 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/Layout.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | import org.bytedeco.pytorch.global.torch as torchNative 19 | 20 | /** A `torch.layout` is an object that represents the memory layout of a torch.Tensor. 21 | * 22 | * Currently, we support ``torch.strided`` (dense Tensors) and have beta support for 23 | * ``torch.sparse_coo`` (sparse COO Tensors). 24 | * 25 | * torch.strided represents dense Tensors and is the memory layout that is most commonly used. Each 26 | * strided tensor has an associated torch.Storage, which holds its data. These tensors provide a 27 | * multi-dimensional, strided view of a storage. Strides are a list of integers: the k-th stride 28 | * represents the jump in the memory necessary to go from one element to the next one in the k-th 29 | * dimension of the Tensor.
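 * For example (an illustration, not from the original docs): a contiguous tensor of shape (2, 3) has strides (3, 1), so stepping along dimension 0 jumps 3 elements in memory, while stepping along dimension 1 jumps a single element.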
This concept makes it possible to perform many tensor operations 30 | * efficiently. 31 | */ 32 | enum Layout: 33 | case Strided, Sparse, SparseCsr, Mkldnn, NumOptions 34 | private[torch] def toNative: torchNative.Layout = torchNative.Layout.valueOf(this.toString) 35 | 36 | object Layout: 37 | private[torch] def fromNative(native: torchNative.Layout) = Layout.valueOf(native.toString) 38 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/MemoryFormat.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | import org.bytedeco.pytorch 20 | import org.bytedeco.pytorch.global.torch as torchNative 21 | 22 | /** A memoryFormat is an object representing the memory format on which a torch.Tensor is or will be 23 | * allocated. 24 | */ 25 | enum MemoryFormat: 26 | /** Tensor is or will be allocated in dense non-overlapping memory. Strides represented by values 27 | * in decreasing order. 28 | */ 29 | case Contiguous 30 | 31 | /** Used in functions like clone to preserve the memory format of the input tensor. If input 32 | * tensor is allocated in dense non-overlapping memory, the output tensor strides will be copied 33 | * from the input. Otherwise output strides will follow torch.contiguous_format 34 | */ 35 | case Preserve 36 | 37 | /** Tensor is or will be allocated in dense non-overlapping memory. Strides represented by values 38 | * in `strides[0] > strides[2] > strides[3] > strides[1] == 1` aka NHWC order. 39 | */ 40 | case ChannelsLast 41 | case ChannelsLast3d 42 | 43 | private[torch] def toNative: torchNative.MemoryFormat = 44 | torchNative.MemoryFormat.valueOf(this.toString) 45 | private[torch] def toNativeOptional: pytorch.MemoryFormatOptional = 46 | pytorch.MemoryFormatOptional(torchNative.MemoryFormat.valueOf(this.toString)) 47 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/Types.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package torch 18 | 19 | import shapeless3.typeable.{TypeCase, Typeable} 20 | import shapeless3.typeable.syntax.typeable.* 21 | import scala.util.NotGiven 22 | import spire.math.Complex 23 | 24 | /* Typeable instance for Array[T] 25 | * NOTE: It needs to iterate through the whole array to validate castability 26 | */ 27 | given iterableTypeable[T](using tt: Typeable[T]): Typeable[Array[T]] with 28 | def castable(t: Any): Boolean = 29 | t match 30 | case (arr: Array[?]) => 31 | arr.forall(_.castable[T]) 32 | case _ => false 33 | def describe = s"Array[${tt.describe}]" 34 | 35 | /* TypeCase helpers to perform pattern matching on `Complex` higher kinded types */ 36 | val complexDoubleArray = TypeCase[Array[Complex[Double]]] 37 | val complexFloatArray = TypeCase[Array[Complex[Float]]] 38 | 39 | /* TypeCase helpers to perform pattern matching on `Seq` higher kinded types */ 40 | val singleSeq = TypeCase[Seq[?]] 41 | val doubleSeq = TypeCase[Seq[Seq[?]]] 42 | val tripleSeq = TypeCase[Seq[Seq[Seq[?]]]] 43 | 44 | /* Type helper to describe inputs that accept Tensor or Real scalars */ 45 | type TensorOrReal[D <: RealNN] = Tensor[D] | Real 46 | 47 | /* Evidence used in operations where Bool is accepted, but only on one of the two inputs, not both 48 | */ 49 | type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool] 50 | 51 | /* Evidence used in operations where at least one Float is required */ 52 | type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN 53 | 54 | /* Evidence used in operations where at least one Float or Complex is required */ 55 | type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | ComplexNN) | 56 | B <:< (FloatNN | ComplexNN) 57 | 58 | /* Evidence that two dtypes are not the same */ 59 | type NotEqual[D <: DType, D2 <: DType] = NotGiven[D =:= D2] 60 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/cuda/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | import org.bytedeco.pytorch.global.torch as torchNative 20 | 21 | /** This package adds support for CUDA tensor types, which implement the same functions as CPU 22 | * tensors but utilize GPUs for computation. 23 | */ 24 | package object cuda { 25 | 26 | /** Returns a Boolean indicating if CUDA is currently available.
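 *
 * A minimal device-selection sketch (illustrative, not from the original docs):
 * {{{
 * val device = if torch.cuda.isAvailable then Device.CUDA else Device.CPU
 * }}}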
*/ 27 | def isAvailable: Boolean = torchNative.cuda_is_available() 28 | } 29 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/data/DataLoader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package data 19 | 20 | import scala.util.Random 21 | 22 | /** Provides an iterable over batches of a given dataset. */ 23 | class DataLoader[Input, Batch]( 24 | dataset: IndexedSeq[Input], 25 | batchSize: Int = 1, 26 | shuffle: Boolean = false, 27 | collateFn: Seq[Input] => Batch 28 | ) extends Iterable[Batch] { 29 | 30 | override def iterator = 31 | (if shuffle then Random.shuffle(dataset) else dataset) 32 | .grouped(batchSize) 33 | .map(collateFn) 34 | } 35 | 36 | class TupleDataLoader[D1 <: DType, D2 <: DType]( 37 | dataset: IndexedSeq[(Tensor[D1], Tensor[D2])], 38 | batchSize: Int = 1, 39 | shuffle: Boolean = false, 40 | collateFn: Seq[(Tensor[D1], Tensor[D2])] => (Tensor[D1], Tensor[D2]) = 41 | (examples: Seq[(Tensor[D1], Tensor[D2])]) => 42 | (torch.stack(examples.map(_._1)), torch.stack(examples.map(_._2))) 43 | ) extends DataLoader[(Tensor[D1], Tensor[D2]), (Tensor[D1], Tensor[D2])]( 44 | dataset, 45 | batchSize, 46 | shuffle, 47 | collateFn 48 | ) 49 | 50 | class ExampleDataLoader[D1 <: DType, D2 <: DType]( 51 | dataset: IndexedSeq[Example[D1, D2]], 52 | batchSize: Int = 1, 53 | shuffle: Boolean = false, 54 | collateFn: Seq[Example[D1, D2]] => (Tensor[D1], Tensor[D2]) = 55 | (examples: Seq[Example[D1, D2]]) => 56 | (torch.stack(examples.map(_.feature)), torch.stack(examples.map(_.target))) 57 | ) extends DataLoader[Example[D1, D2], (Tensor[D1], Tensor[D2])]( 58 | dataset, 59 | batchSize, 60 | shuffle, 61 | collateFn 62 | ) 63 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/data/Example.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package torch 18 | package data 19 | 20 | final case class Example[D1 <: DType, D2 <: DType](feature: Tensor[D1], target: Tensor[D2]) 21 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/data/TensorDataset.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package data 19 | 20 | /** Wraps a pair of tensors as a Seq. 21 | * 22 | * Each sample will be retrieved by indexing tensors along the first dimension. 23 | */ 24 | // TODO can we generalize this to tuples of arbitrary size? 25 | trait TensorDataset[Input <: DType, Target <: DType] 26 | extends IndexedSeq[(Tensor[Input], Tensor[Target])] { 27 | def features: Tensor[Input] 28 | def targets: Tensor[Target] 29 | } 30 | 31 | object TensorDataset { 32 | def apply[Input <: DType, Target <: DType]( 33 | _features: Tensor[Input], 34 | _targets: Tensor[Target] 35 | ): TensorDataset[Input, Target] = new TensorDataset { 36 | val features = _features 37 | val targets = _targets 38 | 39 | require(features.size.length > 0) 40 | require(features.size.head == targets.size.head) 41 | 42 | override def apply(i: Int): (Tensor[Input], Tensor[Target]) = (features(i), targets(i)) 43 | 44 | override def length: Int = features.size.head 45 | 46 | override def toString(): String = 47 | s"TensorDataset(features=${features.info}, targets=${targets.info})" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/data/TensorSeq.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package data 19 | 20 | /** Wraps a tensor as a Seq. 21 | * 22 | * Each sample will be retrieved by indexing tensors along the first dimension. 
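 *
 * An illustrative sketch (assuming a [10, 3] float tensor):
 * {{{
 * val seq = TensorSeq(torch.rand(Seq(10L, 3L)))
 * seq.length // 10
 * seq(0)     // tensor of shape [3]
 * }}}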
23 | * 24 | * @param t 25 | * tensor to be wrapped as a seq 26 | */ 27 | class TensorSeq[D <: DType](t: Tensor[D]) extends IndexedSeq[Tensor[D]] { 28 | 29 | require(t.size.length > 0) 30 | require(t.size.head <= Int.MaxValue) 31 | 32 | override def apply(i: Int): Tensor[D] = t(i) 33 | 34 | override def length: Int = t.size.head.toInt 35 | 36 | } 37 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/hub.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | import dev.dirs.BaseDirectories 20 | import scala.util.Using 21 | import java.nio.file.Files 22 | import java.net.URL 23 | 24 | /** Utilities to download and cache pre-trained model weights. */ 25 | object hub: 26 | private val storchDir = os.Path(BaseDirectories.get().cacheDir) / "storch" 27 | private val hubDir = storchDir / "hub" 28 | private val modelDir = hubDir / "checkpoints" 29 | 30 | def loadStateDictFromUrl(url: String): Map[String, Tensor[DType]] = 31 | os.makeDir.all(modelDir) 32 | val filename = os.Path(URL(url).getPath).last 33 | val cachedFile = (modelDir / filename) 34 | if !os.exists(cachedFile) then 35 | System.err.println(s"Downloading: $url to $cachedFile") 36 | Using.resource(URL(url).openStream()) { inputStream => 37 | val _ = Files.copy(inputStream, cachedFile.toNIO) 38 | } 39 | torch.pickleLoad(cachedFile.toNIO) 40 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/indexing.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | object indexing: 20 | 21 | case class Slice(start: Option[Int], end: Option[Int], step: Option[Int]) 22 | object Slice: 23 | private def extract(index: Option[Int] | Int) = index match 24 | case i: Option[Int] => i 25 | case i: Int => Option(i) 26 | def apply( 27 | start: Option[Int] | Int = None, 28 | end: Option[Int] | Int = None, 29 | step: Option[Int] | Int = None 30 | ): Slice = Slice(extract(start), extract(end), extract(step)) 31 | 32 | /** Ellipsis or ... in Python syntax. */ 33 | sealed class Ellipsis 34 | 35 | /** Ellipsis or ... in Python syntax. 
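 *
 * Illustrative use (assuming a 3-d tensor `t`): {{{ t(---, 0) }}} selects index 0 along the
 * last dimension, like {{{ t(::, ::, 0) }}}.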
*/ 36 | case object Ellipsis extends Ellipsis 37 | 38 | /** Ellipsis or ... in Python syntax. */ 39 | val --- = Ellipsis 40 | 41 | /** Range (colon / :) in python syntax. */ 42 | val :: = Slice() 43 | 44 | /** Allow for {{{t(1.::)}}} and {{{t(1.::(2))}}} */ 45 | extension (start: Int | Option[Int]) 46 | def ::(step: Int | Option[Int]): Slice = 47 | // Note that despite the names, :: reverses the operands, that is a :: b calls b.::(a) 48 | // So step and start are reversed here 49 | Slice(step, None, start) 50 | 51 | def :: : Slice = Slice(start, None, None) 52 | 53 | export indexing.* 54 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/functional/Activations.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package functional 20 | 21 | import Derive.derive 22 | import org.bytedeco.pytorch 23 | import org.bytedeco.pytorch.global.torch as torchNative 24 | import torch.internal.NativeConverters.fromNative 25 | import org.bytedeco.pytorch.ScalarTypeOptional 26 | 27 | private[torch] trait Activations { 28 | 29 | /** Applies a softmax followed by a logarithm. 30 | * 31 | * While mathematically equivalent to log(softmax(x)), doing these two operations separately is 32 | * slower and numerically unstable. This function uses an alternative formulation to compute the 33 | * output and gradient correctly. 34 | * 35 | * See `torch.nn.LogSoftmax` for more details. 36 | * 37 | * @group nn_activation 38 | */ 39 | def logSoftmax[In <: DType, Out <: FloatNN | Derive]( 40 | input: Tensor[In], 41 | dim: Long, 42 | dtype: Out = derive 43 | ): Tensor[DTypeOrDeriveFromTensor[In, Out]] = 44 | val derivedDType = dtype match 45 | case _: Derive => input.dtype 46 | case d: DType => d 47 | val nativeDType = 48 | if dtype == input.dtype then ScalarTypeOptional() 49 | else ScalarTypeOptional(derivedDType.toScalarType) 50 | fromNative(torchNative.log_softmax(input.native, dim, nativeDType)) 51 | 52 | /** Applies the rectified linear unit function element-wise. 53 | * 54 | * See [[torch.nn.ReLU]] for more details. 55 | * 56 | * @group nn_activation 57 | */ 58 | def relu[D <: DType](input: Tensor[D]): Tensor[D] = fromNative(torchNative.relu(input.native)) 59 | 60 | /** Applies the element-wise function $\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}$ 61 | * 62 | * See `torch.nn.Sigmoid` for more details. 63 | * 64 | * @group nn_activation 65 | */ 66 | def sigmoid[D <: DType](input: Tensor[D]): Tensor[D] = fromNative( 67 | torchNative.sigmoid(input.native) 68 | ) 69 | 70 | /** Applies the Sigmoid Linear Unit (SiLU) function, element-wise. The SiLU function is also known 71 | * as the swish function.
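 *
 * A minimal sketch (illustrative): {{{ torch.nn.functional.silu(torch.randn(Seq(4))) }}}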
72 | * 73 | * @group nn_activation 74 | */ 75 | def silu[D <: DType](input: Tensor[D]): Tensor[D] = fromNative(torchNative.silu(input.native)) 76 | 77 | /** Applies a softmax function. 78 | * 79 | * @group nn_activation 80 | */ 81 | def softmax[In <: DType, Out <: FloatNN | Derive]( 82 | input: Tensor[In], 83 | dim: Long, 84 | dtype: Out = derive 85 | ): Tensor[DTypeOrDeriveFromTensor[In, Out]] = 86 | val derivedDType = dtype match 87 | case _: Derive => input.dtype 88 | case d: DType => d 89 | val nativeDType = 90 | if dtype == input.dtype then ScalarTypeOptional() 91 | else ScalarTypeOptional(derivedDType.toScalarType) 92 | fromNative(torchNative.softmax(input.native, dim, nativeDType)) 93 | } 94 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/functional/Convolution.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package functional 20 | 21 | import org.bytedeco.pytorch 22 | import org.bytedeco.pytorch.global.torch as torchNative 23 | import torch.internal.NativeConverters.* 24 | 25 | private[torch] trait Convolution { 26 | 27 | /** Applies a 1D convolution over an input signal composed of several input planes. 28 | * 29 | * @group nn_conv 30 | */ 31 | def conv1d[D <: FloatNN | ComplexNN]( 32 | input: Tensor[D], 33 | weight: Tensor[D], 34 | bias: Tensor[D] | Option[Tensor[D]] = None, 35 | stride: Int = 1, 36 | padding: Int = 0, 37 | dilation: Int = 1, 38 | groups: Int = 1 39 | ): Tensor[D] = 40 | fromNative( 41 | torchNative.conv1d( 42 | input.native, 43 | weight.native, 44 | toOptional(bias), 45 | Array(stride.toLong), 46 | Array(padding.toLong), 47 | Array(dilation.toLong), 48 | groups 49 | ) 50 | ) 51 | 52 | /** Applies a 2D convolution over an input signal composed of several input planes. 53 | * 54 | * @group nn_conv 55 | */ 56 | def conv2d[D <: FloatNN | ComplexNN]( 57 | input: Tensor[D], 58 | weight: Tensor[D], 59 | bias: Tensor[D] | Option[Tensor[D]] = None, 60 | stride: Int | (Int, Int) = 1, 61 | padding: Int | (Int, Int) = 0, 62 | dilation: Int | (Int, Int) = 1, 63 | groups: Int = 1 64 | ): Tensor[D] = 65 | fromNative( 66 | torchNative.conv2d( 67 | input.native, 68 | weight.native, 69 | toOptional(bias), 70 | toArray(stride), 71 | toArray(padding), 72 | toArray(dilation), 73 | groups 74 | ) 75 | ) 76 | 77 | /** Applies a 3D convolution over an input image composed of several input planes. 
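 *
 * An illustrative sketch (assuming NCDHW input and (outChannels, inChannels, kD, kH, kW) weight
 * layout):
 * {{{
 * val out = conv3d(torch.randn(Seq(1, 3, 8, 8, 8)), torch.randn(Seq(6, 3, 3, 3, 3)))
 * // out has shape [1, 6, 6, 6, 6]
 * }}}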
78 | * 79 | * @group nn_conv 80 | */ 81 | def conv3d[D <: FloatNN | ComplexNN]( 82 | input: Tensor[D], 83 | weight: Tensor[D], 84 | bias: Tensor[D] | Option[Tensor[D]] = None, 85 | stride: Int = 1, 86 | padding: Int = 0, 87 | dilation: Int = 1, 88 | groups: Int = 1 89 | ): Tensor[D] = 90 | fromNative( 91 | torchNative.conv3d( 92 | input.native, 93 | weight.native, 94 | toOptional(bias), 95 | Array(stride.toLong), 96 | Array(padding.toLong), 97 | Array(dilation.toLong), 98 | groups 99 | ) 100 | ) 101 | 102 | /** Applies a 1D transposed convolution operator over an input signal composed of several input 103 | * planes, sometimes also called “deconvolution”. 104 | * 105 | * @group nn_conv 106 | */ 107 | def convTranspose1d[D <: FloatNN | ComplexNN]( 108 | input: Tensor[D], 109 | weight: Tensor[D], 110 | bias: Tensor[D] | Option[Tensor[D]] = None, 111 | stride: Int | (Int, Int) = 1, 112 | padding: Int | (Int, Int) = 0, 113 | outputPadding: Int | (Int, Int) = 0, 114 | groups: Int = 1, 115 | dilation: Int | (Int, Int) = 1 116 | ): Tensor[D] = 117 | fromNative( 118 | torchNative.conv_transpose1d( 119 | input.native, 120 | weight.native, 121 | toOptional(bias), 122 | toArray(stride), 123 | toArray(padding), 124 | toArray(outputPadding), 125 | groups, 126 | toArray(dilation): _* 127 | ) 128 | ) 129 | 130 | /** Applies a 2D transposed convolution operator over an input image composed of several input 131 | * planes, sometimes also called “deconvolution”. 132 | * 133 | * @group nn_conv 134 | */ 135 | def convTranspose2d[D <: FloatNN | ComplexNN]( 136 | input: Tensor[D], 137 | weight: Tensor[D], 138 | bias: Tensor[D] | Option[Tensor[D]] = None, 139 | stride: Int | (Int, Int) = 1, 140 | padding: Int | (Int, Int) = 0, 141 | outputPadding: Int | (Int, Int) = 0, 142 | groups: Int = 1, 143 | dilation: Int | (Int, Int) = 1 144 | ): Tensor[D] = 145 | fromNative( 146 | torchNative.conv_transpose2d( 147 | input.native, 148 | weight.native, 149 | toOptional(bias), 150 | toArray(stride), 151 | toArray(padding), 152 | toArray(outputPadding), 153 | groups, 154 | toArray(dilation): _* 155 | ) 156 | ) 157 | 158 | /** Applies a 3D transposed convolution operator over an input image composed of several input 159 | * planes, sometimes also called “deconvolution”. 160 | * 161 | * @group nn_conv 162 | */ 163 | def convTranspose3d[D <: FloatNN | ComplexNN]( 164 | input: Tensor[D], 165 | weight: Tensor[D], 166 | bias: Tensor[D] | Option[Tensor[D]] = None, 167 | stride: Int | (Int, Int, Int) = 1, 168 | padding: Int | (Int, Int, Int) = 0, 169 | outputPadding: Int | (Int, Int, Int) = 0, 170 | groups: Int = 1, 171 | dilation: Int | (Int, Int) = 1 172 | ): Tensor[D] = 173 | fromNative( 174 | torchNative.conv_transpose3d( 175 | input.native, 176 | weight.native, 177 | toOptional(bias), 178 | toArray(stride), 179 | toArray(padding), 180 | toArray(outputPadding), 181 | groups, 182 | toArray(dilation): _* 183 | ) 184 | ) 185 | 186 | // TODO unfold 187 | // TODO fold 188 | } 189 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/functional/Dropout.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package functional 20 | 21 | import org.bytedeco.pytorch.global.torch as torchNative 22 | import torch.internal.NativeConverters.fromNative 23 | 24 | private[torch] trait Dropout { 25 | 26 | /** During training, randomly zeroes some of the elements of the input tensor with probability `p` 27 | * using samples from a Bernoulli distribution. 28 | * 29 | * @see 30 | * [[torch.nn.Dropout]] for details. 31 | * 32 | * @group nn_dropout 33 | */ 34 | def dropout[D <: DType](input: Tensor[D], p: Double = 0.5, training: Boolean = true): Tensor[D] = 35 | fromNative( 36 | torchNative.dropout(input.native, p, training) 37 | ) 38 | 39 | // TODO alpha_dropout Applies alpha dropout to the input. 40 | // TODO feature_alpha_dropout Randomly masks out entire channels (a channel is a feature map, e.g. 41 | // TODO dropout1d Randomly zero out entire channels (a channel is a 1D feature map, e.g., the j-th channel of the i-th sample in the batched input is a 1D tensor input[i,j]) of the input tensor. 42 | // TODO dropout2d Randomly zero out entire channels (a channel is a 2D feature map, e.g., the j-th channel of the i-th sample in the batched input is a 2D tensor input[i,j]) of the input tensor. 43 | // TODO dropout3d Randomly zero out entire channels (a channel is a 3D feature map, e.g., the j-th channel of the i-th sample in the batched input is a 3D tensor input[i,j]) of the input tensor. 44 | } 45 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/functional/Linear.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package functional 20 | 21 | import org.bytedeco.pytorch 22 | import org.bytedeco.pytorch.global.torch as torchNative 23 | import torch.internal.NativeConverters.{fromNative, toOptional} 24 | 25 | // Linear functions 26 | private[torch] trait Linear { 27 | 28 | /** Applies a linear transformation to the incoming data: $y = xA^T + b$. 29 | * 30 | * This operation supports 2-D `weight` with `sparse layout` 31 | * 32 | * Warning 33 | * 34 | * Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be 35 | * supported, or may not have autograd support. If you notice missing functionality please open a 36 | * feature request.
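 *
 * An illustrative sketch: {{{ linear(torch.randn(Seq(2, 4)), torch.randn(Seq(3, 4))) }}}
 * returns a tensor of shape [2, 3].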
37 | * 38 | * This operator supports `TensorFloat32` 39 | * 40 | * Shape: 41 | * 42 | * - Input: $(*, in\_features)$ where $*$ means any number of additional dimensions, including 43 | * none 44 | * - Weight: $(out\_features, in\_features)$ or $(in\_features)$ 45 | * - Bias: $(out\_features)$ or $()$ 46 | * - Output: $(*, out\_features)$ or $(*)$, based on the shape of the weight 47 | * 48 | * @group nn_linear 49 | */ 50 | def linear[D <: DType]( 51 | input: Tensor[D], 52 | weight: Tensor[D], 53 | bias: Tensor[D] | Option[Tensor[D]] = None 54 | ): Tensor[D] = 55 | fromNative( 56 | torchNative.linear(input.native, weight.native, toOptional(bias)) 57 | ) 58 | 59 | /** Applies a bilinear transformation to the incoming data: $y = x_1^T A x_2 + b$ 60 | * 61 | * Shape: 62 | * 63 | * - input1: $(N, *, H_{in1})$ where $H_{in1}=\text{in1\_features}$ and $*$ means any number of 64 | * additional dimensions. All but the last dimension of the inputs should be the same. 65 | * - input2: $(N, *, H_{in2})$ where $H_{in2}=\text{in2\_features}$ 66 | * - weight: $(\text{out\_features}, \text{in1\_features}, \text{in2\_features})$ 67 | * - bias: $(\text{out\_features})$ 68 | * - output: $(N, *, H_{out})$ where $H_{out}=\text{out\_features}$ and all but the last 69 | * dimension are the same shape as the input. 70 | * 71 | * @group nn_linear 72 | */ 73 | def bilinear[D <: DType]( 74 | input1: Tensor[D], 75 | input2: Tensor[D], 76 | weight: Tensor[D], 77 | bias: Tensor[D] | Option[Tensor[D]] = None 78 | ): Tensor[D] = fromNative( 79 | torchNative.bilinear(input1.native, input2.native, weight.native, toOptional(bias)) 80 | ) 81 | 82 | } 83 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/functional/Sparse.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package functional 20 | 21 | import org.bytedeco.pytorch.global.torch as torchNative 22 | import torch.internal.NativeConverters.fromNative 23 | 24 | private[torch] trait Sparse { 25 | 26 | /** Takes LongTensor with index values of shape `(*)` and returns a tensor of shape `(*, 27 | * numClasses)` that has zeros everywhere except where the index of last dimension matches the 28 | * corresponding value of the input tensor, in which case it will be 1.
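 *
 * For example (an illustrative sketch):
 * {{{
 * oneHot(torch.Tensor(Seq(0L, 2L)), numClasses = 3)
 * // [[1, 0, 0],
 * //  [0, 0, 1]]
 * }}}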
29 | * 30 | * @group nn_sparse 31 | */ 32 | def oneHot(input: Tensor[Int64], numClasses: Long = -1): Tensor[Int64] = 33 | fromNative(torchNative.one_hot(input.native, numClasses)) 34 | } 35 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/functional/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | 20 | /** @groupname nn_conv Convolution functions 21 | * @groupname nn_pooling Pooling functions 22 | * @groupname nn_attention Attention mechanisms 23 | * @groupname nn_activation Non-linear activation functions 24 | * @groupname nn_linear Linear functions 25 | * @groupname nn_dropout Dropout functions 26 | * @groupname nn_sparse Sparse functions 27 | * @groupname nn_distance Distance functions 28 | * @groupname nn_loss Loss functions 29 | * @groupname nn_vision Vision functions 30 | */ 31 | package object functional 32 | extends Activations 33 | with Convolution 34 | with Dropout 35 | with Linear 36 | with Loss 37 | with Pooling 38 | with Sparse 39 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/loss/CrossEntropyLoss.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package loss 20 | 21 | import org.bytedeco.pytorch.CrossEntropyLossImpl 22 | import torch.nn.modules.Module 23 | import torch.internal.NativeConverters.fromNative 24 | 25 | /** This criterion computes the cross entropy loss between input and target. 
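 *
 * An illustrative sketch (class indices as targets):
 * {{{
 * val loss = CrossEntropyLoss()
 * val input = torch.randn(Seq(3, 5), requiresGrad = true)
 * val target = torch.Tensor(Seq(1L, 0L, 4L))
 * val output = loss(input, target)
 * }}}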
*/
26 | // TODO optional args
27 | final class CrossEntropyLoss extends Module {
28 |   override private[torch] val nativeModule: CrossEntropyLossImpl = CrossEntropyLossImpl()
29 |
30 |   override def hasBias(): Boolean = false
31 |
32 |   def apply[D <: DType](input: Tensor[D], target: Tensor[?]): Tensor[D] = fromNative(
33 |     nativeModule.forward(input.native, target.native)
34 |   )
35 | }
36 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/Module.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 |
21 | import org.bytedeco.pytorch
22 | import org.bytedeco.pytorch.{InputArchive, OutputArchive}
23 | import Tensor.fromNative
24 |
25 | import scala.collection.immutable.{ArraySeq, SeqMap, TreeSeqMap}
26 |
27 | abstract class Module {
28 |
29 |   protected[torch] var _nativeModule = pytorch.Module()
30 |   private[torch] def nativeModule: pytorch.Module = _nativeModule
31 |   private var childModules: TreeSeqMap[String, Module] = TreeSeqMap.empty
32 |
33 |   def namedBuffers(recurse: Boolean = true): SeqMap[String, Tensor[?]] =
34 |     val buffers = nativeModule.named_buffers(recurse)
35 |     TreeSeqMap.from((0 until buffers.size().toInt).map { i =>
36 |       val item = buffers.get(i)
37 |       (item.key().getString(), fromNative[DType](item.access()))
38 |     })
39 |
40 |   def namedParameters(recurse: Boolean = true): SeqMap[String, Tensor[?]] =
41 |     val params = nativeModule.named_parameters(recurse)
42 |     TreeSeqMap.from((0 until params.size().toInt).map { i =>
43 |       val item = params.get(i)
44 |       (item.key().getString(), fromNative[DType](item.access()))
45 |     })
46 |
47 |   def parameters: Seq[Tensor[?]] = parameters(recurse = true)
48 |
49 |   def parameters(recurse: Boolean): Seq[Tensor[?]] =
50 |     ArraySeq.unsafeWrapArray(nativeModule.parameters(recurse).get).map(fromNative[DType])
51 |
52 |   // TODO make strict a parameter
53 |   // TODO improve error handling
54 |   def loadStateDict(stateDict: Map[String, Tensor[DType]]): Unit =
55 |     val tensorsToLoad = namedParameters() ++ namedBuffers()
56 |     // assert(stateDict.keySet -- tensorsToLoad.keySet == Set.empty, s"keys missing in state dict: ${tensorsToLoad.keySet -- stateDict.keySet}")
57 |     for ((key, param) <- tensorsToLoad if stateDict.contains(key))
58 |       noGrad {
59 |         param.copy_(stateDict(key))
60 |       }
61 |
62 |   def modules(recurse: Boolean): Seq[Module] =
63 |     (if recurse then childModules.values.flatMap(child => child +: child.modules) else childModules.values).toSeq.distinct
64 |   def modules: Seq[Module] = modules(recurse = true)
65 |
66 |   def namedChildren: SeqMap[String, Module] = childModules
67 |   def namedModules: SeqMap[String, Module] = // direct children first, then their descendants
68 |     namedChildren ++ namedChildren.flatMap((_, module) => module.namedModules)
69 |
70 |   def apply(fn: Module => Unit): this.type =
71 |     for (_, module) <- namedChildren
72 |     do module(fn) // apply fn recursively to the children first
73 |     fn(this); this
74 |
75 |   def register[M <: Module](child: M, n: String = "")(using name: sourcecode.Name): M =
76 |     val name_ = if n.trim().isEmpty() then name.value else n.trim()
77 |     // println(s"registering ${name_}:$child")
78 |     childModules = childModules.updated(name_, child)
79 |     nativeModule.register_module(name_, child.nativeModule)
80 |     child
81 |
82 |   def registerModule[M <: Module](child: M, n: String = "")(using name: sourcecode.Name): M =
83 |     register(child, n)(using name)
84 |
85 |   def registerParameter[D <: DType](t: Tensor[D], requiresGrad: Boolean = true, n: String = "")(
86 |       using name: sourcecode.Name
87 |   ): Tensor[D] =
88 |     val name_ = if n.trim().isEmpty() then name.value else n.trim()
89 |     nativeModule.register_parameter(name_, t.native, requiresGrad)
90 |     t
91 |
92 |   def registerBuffer[D <: DType](t: Tensor[D], n: String = "")(using
93 |       name: sourcecode.Name
94 |   ): Tensor[D] =
95 |     val name_ = if n.trim().isEmpty() then name.value else n.trim()
96 |     nativeModule.register_buffer(name_, t.native)
97 |     t
98 |
99 |   /** Adds a buffer to the module. */
100 |   def registerBuffer[D <: DType](name: String, tensor: Tensor[D]): Tensor[D] =
101 |     fromNative(nativeModule.register_buffer(name, tensor.native))
102 |
103 |   def hasBias(): Boolean = modules.exists(_.hasBias())
104 |
105 |   def eval(): Unit = nativeModule.eval()
106 |
107 |   def isTraining: Boolean = nativeModule.is_training
108 |
109 |   def train(on: Boolean = true): Unit = nativeModule.train(on)
110 |
111 |   def to(device: Device): this.type =
112 |     nativeModule.to(device.toNative, false)
113 |     this
114 |
115 |   def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive)
116 |
117 |   def load(inputArchive: InputArchive) = nativeModule.load(inputArchive)
118 |
119 |   override def toString(): String = getClass().getSimpleName()
120 |
121 |   private def doSummarize(indent: Int): String =
122 |     val thisModule = toString
123 |     if modules.isEmpty then thisModule
124 |     else
125 |       thisModule + namedChildren
126 |         .map((name, module) => s"${" " * (indent + 2)}($name): " + module.doSummarize(indent + 2))
127 |         .mkString("(\n", "\n", s"\n${" " * indent})")
128 |   def summarize: String =
129 |     doSummarize(0)
130 | }
131 |
132 | trait HasParams[ParamType <: FloatNN | ComplexNN: Default] extends Module:
133 |   override def parameters(recurse: Boolean): Seq[Tensor[ParamType]] =
134 |     nativeModule.parameters(recurse).get().toSeq.map(fromNative[ParamType])
135 |   override def parameters: Seq[Tensor[ParamType]] = parameters(recurse = true)
136 |   transparent inline def paramType: DType = summon[Default[ParamType]].dtype
137 |
138 | trait HasWeight[ParamType <: FloatNN | ComplexNN]:
139 |   def weight: Tensor[ParamType]
140 |
141 | /** Transforms a single tensor into another one of the same type. */
142 | trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]):
143 |   override def toString(): String = "TensorModule"
144 |
145 | trait TensorModuleBase[D <: DType, D2 <: DType] extends Module with (Tensor[D] => Tensor[D2]) {
146 |   override def toString() = "TensorModuleBase"
147 | }
148 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/activation/LogSoftmax.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package activation
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.LogSoftmaxImpl
24 | import org.bytedeco.pytorch.LogSoftmaxOptions
25 | import torch.internal.NativeConverters.fromNative
26 |
27 | /** Applies the log(Softmax(x)) function to an n-dimensional input Tensor. The LogSoftmax
28 |  * formulation can be simplified as:
29 |  *
30 |  * $$\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)$$
31 |  *
32 |  * Example:
33 |  *
34 |  * ```scala sc
35 |  * import torch.*
36 |  * val m = nn.LogSoftmax(dim = 1)
37 |  * val input = torch.randn(Seq(2, 3))
38 |  * val output = m(input)
39 |  * ```
40 |  */
41 | final class LogSoftmax[D <: DType: Default](dim: Int) extends TensorModule[D]:
42 |   private val options = new LogSoftmaxOptions(dim)
43 |   options.dim().put(dim)
44 |
45 |   override val nativeModule: LogSoftmaxImpl = LogSoftmaxImpl(options)
46 |
47 |   override def hasBias(): Boolean = false
48 |
49 |   def apply(t: Tensor[D]): Tensor[D] = fromNative(nativeModule.forward(t.native))
50 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/activation/ReLU.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package activation
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.{ReLUImpl, ReLUOptions}
24 | import torch.internal.NativeConverters.fromNative
25 |
26 | /** Applies the rectified linear unit function element-wise:
27 |  *
28 |  * $\text{ReLU}(x) = (x)^+ = \max(0, x)$
29 |  */
30 | final class ReLU[D <: DType: Default](inplace: Boolean = false) extends TensorModule[D]:
31 |   private val options = new ReLUOptions()
32 |   options.inplace().put(inplace)
33 |
34 |   override protected[torch] val nativeModule: ReLUImpl = ReLUImpl(options)
35 |
36 |   override def hasBias(): Boolean = false
37 |
38 |   def apply(t: Tensor[D]): Tensor[D] = fromNative(nativeModule.forward(t.native))
39 |
40 |   override def toString = getClass().getSimpleName()
41 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/activation/Softmax.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package activation
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.SoftmaxImpl
24 | import org.bytedeco.pytorch.SoftmaxOptions
25 | import torch.internal.NativeConverters.fromNative
26 |
27 | /** Applies the Softmax function to an n-dimensional input Tensor, rescaling it so that the
28 |  * elements of the n-dimensional output Tensor lie in the range [0,1] and sum to 1.
29 |  *
30 |  * Softmax is defined as: $$\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
31 |  *
32 |  * When the input Tensor is a sparse tensor, the unspecified values are treated as `-inf`.
33 |  */
34 | final class Softmax[D <: DType: Default](dim: Int) extends TensorModule[D]:
35 |   private val options = new SoftmaxOptions(dim)
36 |
37 |   override val nativeModule: SoftmaxImpl = SoftmaxImpl(options)
38 |
39 |   override def hasBias(): Boolean = false
40 |
41 |   def apply(t: Tensor[D]): Tensor[D] = fromNative(nativeModule.forward(t.native))
42 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/activation/Tanh.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package activation
21 |
22 | import org.bytedeco.pytorch.TanhImpl
23 | import torch.internal.NativeConverters.fromNative
24 |
25 | /** Applies the Hyperbolic Tangent (Tanh) function element-wise. Tanh is defined as:
26 |  *
27 |  * $$\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}$$
28 |  *
29 |  * Example:
30 |  *
31 |  * ```scala sc
32 |  * import torch.*
33 |  * val m = nn.Tanh()
34 |  * val input = torch.randn(Seq(2))
35 |  * val output = m(input)
36 |  * ```
37 |  */
38 | final class Tanh[D <: DType: Default] extends TensorModule[D]:
39 |
40 |   override protected[torch] val nativeModule: TanhImpl = new TanhImpl()
41 |
42 |   override def hasBias(): Boolean = false
43 |
44 |   def apply(t: Tensor[D]): Tensor[D] = fromNative(nativeModule.forward(t.native))
45 |
46 |   override def toString = getClass().getSimpleName()
47 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package batchnorm
21 |
22 | import org.bytedeco.pytorch.{BatchNorm1dImpl, BatchNormOptions}
23 | import org.bytedeco.pytorch
24 | import torch.internal.NativeConverters.fromNative
25 |
26 | /** Applies Batch Normalization over a 2D or 3D input as described in the paper [Batch
27 |  * Normalization: Accelerating Deep Network Training by Reducing Internal Covariate
28 |  * Shift](https://arxiv.org/abs/1502.03167).
29 |  *
30 |  * $$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$
31 |  *
32 |  * The mean and standard-deviation are calculated per-dimension over the mini-batches and $\gamma$
33 |  * and $\beta$ are learnable parameter vectors of size [C] (where [C] is the number of features or
34 |  * channels of the input). By default, the elements of $\gamma$ are set to 1 and the elements of
35 |  * $\beta$ are set to 0. The standard-deviation is calculated via the biased estimator, equivalent
36 |  * to *[torch.var(input, unbiased=False)]*.
37 |  *
38 |  * Also by default, during training this layer keeps running estimates of its computed mean and
39 |  * variance, which are then used for normalization during evaluation. The running estimates are
40 |  * kept with a default `momentum` of 0.1.
41 |  *
42 |  * If `trackRunningStats` is set to `false`, this layer then does not keep running estimates, and
43 |  * batch statistics are instead used during evaluation time as well.
44 |  *
45 |  * Example:
46 |  *
47 |  * ```scala sc
48 |  * import torch.nn
49 |  * // With Learnable Parameters
50 |  * var m = nn.BatchNorm1d(numFeatures = 100)
51 |  * // Without Learnable Parameters
52 |  * m = nn.BatchNorm1d(100, affine = false)
53 |  * val input = torch.randn(Seq(20, 100))
54 |  * val output = m(input)
55 |  * ```
56 |  *
57 |  * @note
58 |  * This `momentum` argument is different from the one used in optimizer classes and from the
59 |  * conventional notion of momentum. Mathematically, the update rule for running statistics here
60 |  * is $\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,
61 |  * where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value.
62 |  *
63 |  * Because the Batch Normalization is done over the [C] dimension, computing statistics
64 |  * on [(N, L)] slices, it's common terminology to call this Temporal Batch
65 |  * Normalization.
66 |  *
67 |  * The constructor arguments mirror PyTorch's `BatchNorm1d`:
68 |  * @param numFeatures
69 |  * number of features or channels $C$ of the input
70 |  * @param eps
71 |  * a value added to the denominator for numerical stability. Default: 1e-5
72 |  * @param momentum
73 |  * the value used for the runningMean and runningVar computation (PyTorch also allows `None` here
74 |  * for a cumulative moving average; this port currently expects a `Double`). Default: 0.1
75 |  * @param affine
76 |  * a boolean value that when set to `true`, this module has learnable affine parameters. Default:
77 |  * `true`
78 |  * @param trackRunningStats
79 |  * a boolean value that when set to `true`, this module tracks the running mean and variance, and
80 |  * when set to `false`, this module does not track such statistics, and initializes statistics
81 |  * buffers `runningMean` and `runningVar` as `None`. When these buffers are `None`, this module
82 |  * always uses batch statistics in both training and eval modes. Default: `true`
83 |  *
84 |  * Shape:
85 |  *
86 |  * - Input: $(N, C)$ or $(N, C, L)$, where $N$ is the batch size, $C$ is the number of features
87 |  * or channels, and $L$ is the sequence length
88 |  * - Output: $(N, C)$ or $(N, C, L)$ (same shape as input)
89 |  *
90 |  * @group nn_conv
91 |  *
92 |  * TODO use dtype
93 |  */
94 | final class BatchNorm1d[ParamType <: FloatNN | ComplexNN: Default](
95 |     numFeatures: Int,
96 |     eps: Double = 1e-05,
97 |     momentum: Double = 0.1,
98 |     affine: Boolean = true,
99 |     trackRunningStats: Boolean = true
100 | ) extends HasParams[ParamType]
101 |     with HasWeight[ParamType]
102 |     with TensorModule[ParamType]:
103 |
104 |   private val options = new BatchNormOptions(numFeatures)
105 |   options.eps().put(eps)
106 |   options.momentum().put(momentum)
107 |   options.affine().put(affine)
108 |   options.track_running_stats().put(trackRunningStats)
109 |
110 |   override private[torch] val nativeModule: BatchNorm1dImpl = BatchNorm1dImpl(options)
111 |   nativeModule.to(paramType.toScalarType, false)
112 |
113 |   // TODO weight, bias etc. are undefined if affine = false.
We need to take that into account
114 |   val weight: Tensor[ParamType] = fromNative[ParamType](nativeModule.weight)
115 |   val bias: Tensor[ParamType] = fromNative[ParamType](nativeModule.bias)
116 |   // TODO running_mean, running_var, num_batches_tracked
117 |
118 |   override def hasBias(): Boolean = true
119 |
120 |   def apply(t: Tensor[ParamType]): Tensor[ParamType] = fromNative(nativeModule.forward(t.native))
121 |
122 |   override def toString(): String = s"${getClass().getSimpleName()}(numFeatures=$numFeatures)"
123 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package batchnorm
21 |
22 | import org.bytedeco.pytorch.{BatchNorm2dImpl, BatchNormOptions}
23 | import org.bytedeco.pytorch
24 | import torch.internal.NativeConverters.fromNative
25 |
26 | /** Applies Batch Normalization over a 4D input as described in the paper [Batch Normalization:
27 |  * Accelerating Deep Network Training by Reducing Internal Covariate
28 |  * Shift](https://arxiv.org/abs/1502.03167).
29 |  *
30 |  * $$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$
31 |  *
32 |  * The mean and standard-deviation are calculated per-dimension over the mini-batches and $\gamma$
33 |  * and $\beta$ are learnable parameter vectors of size [C] (where [C] is the number of features or
34 |  * channels of the input). By default, the elements of $\gamma$ are set to 1 and the elements of
35 |  * $\beta$ are set to 0. The standard-deviation is calculated via the biased estimator, equivalent
36 |  * to *[torch.var(input, unbiased=False)]*.
37 |  *
38 |  * Also by default, during training this layer keeps running estimates of its computed mean and
39 |  * variance, which are then used for normalization during evaluation. The running estimates are
40 |  * kept with a default `momentum` of 0.1.
41 |  *
42 |  * If `trackRunningStats` is set to `false`, this layer then does not keep running estimates, and
43 |  * batch statistics are instead used during evaluation time as well.
44 |  *
45 |  * Example:
46 |  *
47 |  * ```scala sc
48 |  * import torch.nn
49 |  * // With Learnable Parameters
50 |  * var m = nn.BatchNorm2d(numFeatures = 100)
51 |  * // Without Learnable Parameters
52 |  * m = nn.BatchNorm2d(100, affine = false)
53 |  * val input = torch.randn(Seq(20, 100, 35, 45))
54 |  * val output = m(input)
55 |  * ```
56 |  *
57 |  * @note
58 |  * This `momentum` argument is different from the one used in optimizer classes and from the
59 |  * conventional notion of momentum.
Mathematically, the update rule for running statistics here is
60 |  * $\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,
61 |  * where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value.
62 |  *
63 |  * Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W)
64 |  * slices, it's common terminology to call this Spatial Batch Normalization.
65 |  *
66 |  * @param numFeatures
67 |  * number of features or channels $C$ of the input
68 |  * @param eps
69 |  * a value added to the denominator for numerical stability. Default: 1e-5
70 |  * @param momentum
71 |  * the value used for the runningMean and runningVar computation (PyTorch also allows `None` here
72 |  * for a cumulative moving average; this port currently expects a `Double`). Default: 0.1
73 |  * @param affine
74 |  * a boolean value that when set to `true`, this module has learnable affine parameters. Default:
75 |  * `true`
76 |  * @param trackRunningStats
77 |  * a boolean value that when set to `true`, this module tracks the running mean and variance, and
78 |  * when set to `false`, this module does not track such statistics, and initializes statistics
79 |  * buffers `runningMean` and `runningVar` as `None`. When these buffers are `None`, this module
80 |  * always uses batch statistics in both training and eval modes. Default: `true`
81 |  *
82 |  * Shape:
83 |  *
84 |  * - Input: $(N, C, H, W)$
85 |  * - Output: $(N, C, H, W)$ (same shape as input)
86 |  *
87 |  * @group nn_conv
88 |  *
89 |  * TODO use dtype
90 |  */
91 | final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default](
92 |     numFeatures: Int,
93 |     eps: Double = 1e-05,
94 |     momentum: Double = 0.1,
95 |     affine: Boolean = true,
96 |     trackRunningStats: Boolean = true
97 | ) extends HasParams[ParamType]
98 |     with HasWeight[ParamType]
99 |     with TensorModule[ParamType]:
100 |
101 |   private val options = new BatchNormOptions(numFeatures)
102 |   options.eps().put(eps)
103 |   options.momentum().put(momentum)
104 |   options.affine().put(affine)
105 |   options.track_running_stats().put(trackRunningStats)
106 |
107 |   override private[torch] val nativeModule: BatchNorm2dImpl = BatchNorm2dImpl(options)
108 |   nativeModule.to(paramType.toScalarType, false)
109 |
110 |   // TODO weight, bias etc. are undefined if affine = false. We need to take that into account
111 |   val weight: Tensor[ParamType] = fromNative[ParamType](nativeModule.weight)
112 |   val bias: Tensor[ParamType] = fromNative[ParamType](nativeModule.bias)
113 |   // TODO running_mean, running_var, num_batches_tracked
114 |
115 |   override def hasBias(): Boolean = true
116 |
117 |   def apply(t: Tensor[ParamType]): Tensor[ParamType] = fromNative(nativeModule.forward(t.native))
118 |
119 |   override def toString(): String = s"${getClass().getSimpleName()}(numFeatures=$numFeatures)"
120 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/container/ModuleList.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package container
21 |
22 | import sourcecode.Name
23 |
24 | /** Holds submodules in a list.
25 |  *
26 |  * It can be indexed like a regular Scala sequence, but the modules it contains are properly
27 |  * registered, and will be visible by all [[torch.nn.Module]] methods.
28 |  *
29 |  * @example
30 |  * ```scala
31 |  * class MyModule extends nn.Module:
32 |  *   val linears = register(nn.ModuleList(Seq.fill(10)(nn.Linear[Float32](10, 10))*))
33 |  *
34 |  *   // ModuleList can act as an iterable, or be indexed using ints
35 |  *   def forward(x: Tensor[Float32]): Tensor[Float32] =
36 |  *     var x_ = x
37 |  *     for l <- linears
38 |  *     do x_ = x_ + l(x_)
39 |  *     x_
40 |  * ```
41 |  *
42 |  * @see
43 |  * [[https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html?highlight=modulelist#torch.nn.ModuleList torch.nn.ModuleList]]
44 |  * @see
45 |  * [[https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#ModuleList container ModuleList]]
46 |  */
47 | final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
48 |     extends Module
49 |     // with TensorModule[D]:
50 |     // TODO
51 |     with TensorModule[D]
52 |     with scala.collection.immutable.Iterable[TensorModule[D]]:
53 |
54 |   modules.zipWithIndex.foreach((module, index) =>
55 |     this.register(module)(using Name(index.toString()))
56 |   )
57 |
58 |   override def iterator: Iterator[TensorModule[D]] = modules.iterator
59 |
60 |   override def apply(input: Tensor[D]): Tensor[D] =
61 |     modules.foldLeft(input)((i, module) => module(i))
62 |
63 |   override def toString = getClass().getSimpleName()
64 |
65 |   /** Inserts a given module before a given index in the list.
66 |    *
67 |    * @param index
68 |    * index to insert at
69 |    * @param module
70 |    * module to insert
71 |    * @return
72 |    * a ModuleList[D] including the new element
73 |    */
74 |   def insert(index: Int, module: TensorModule[D]): ModuleList[D] =
75 |     val (before, after) = modules.splitAt(index)
76 |     val all = before ++ (after.prepended(module))
77 |     // TODO: not in Python code. Note other modules retain index, so we have repeats
78 |     this.register(module)(using Name(index.toString()))
79 |     // TODO: make modules list mutable?
80 |     ModuleList(all: _*)
81 |
82 |   /** Appends a given module to the end of the list.
83 |    *
84 |    * @param module
85 |    * module to append
86 |    * @return
87 |    * a ModuleList[D] including the new element
88 |    */
89 |   def append(module: TensorModule[D]): ModuleList[D] =
90 |     // TODO: not in Module
91 |     // self.add_module(str(len(self)), module)
92 |     // TODO: not in Python code
93 |     val index = modules.length
94 |     this.register(module)(using Name(index.toString()))
95 |     val all = modules.appended(module)
96 |     // TODO: make modules list mutable?
97 |     ModuleList(all: _*)
98 |
99 |   /** Appends the modules from the given iterable to the end of the list.
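   *
   * Example (illustrative only, mirroring the class-level example above; note that this returns a
   * new list rather than mutating in place):
   * ```scala
   * val list = nn.ModuleList(nn.Linear[Float32](10, 10))
   * val longer = list.extend(Seq(nn.Linear[Float32](10, 10), nn.ReLU[Float32]()))
   * ```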
100 |    *
101 |    * @param newModules
102 |    * iterable of modules to append
103 |    * @return the ModuleList extended with the given modules
104 |    */
105 |   def extend(newModules: Iterable[TensorModule[D]]): ModuleList[D] =
106 |     // TODO: not in Module
107 |     // offset = len(self)
108 |     // for i, module in enumerate(modules):
109 |     //   self.add_module(str(offset + i), module)
110 |     // return self
111 |     // val offset = modules.length
112 |     val all = modules ++ newModules
113 |     // Not in Python
114 |     newModules.zipWithIndex.foreach((module, index) =>
115 |       this.register(module)(using Name(index.toString()))
116 |     )
117 |     // TODO: make modules list mutable?
118 |     ModuleList(all: _*)
119 |
120 |   override def hasBias(): Boolean = modules.exists(_.hasBias())
121 |
122 |   def apply(i: Int): torch.nn.modules.TensorModule[D] = modules(i)
123 |
124 |   def length: Int = modules.length
125 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/container/Sequential.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package container
21 |
22 | import sourcecode.Name
23 |
24 | final class Sequential[D <: DType](override val modules: TensorModule[D]*)
25 |     extends Module
26 |     with TensorModule[D]:
27 |   modules.zipWithIndex.foreach((module, index) =>
28 |     this.register(module)(using Name(index.toString()))
29 |   )
30 |
31 |   override def hasBias(): Boolean = modules.exists(_.hasBias())
32 |
33 |   override def apply(input: Tensor[D]): Tensor[D] =
34 |     modules.foldLeft(input)((i, module) => module(i))
35 |
36 |   override def toString = getClass().getSimpleName()
37 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package conv
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.{Conv2dImpl, Conv2dOptions, kZeros, kReflect, kReplicate, kCircular}
24 | import torch.internal.NativeConverters.{fromNative, toNative}
25 | import torch.nn.modules.conv.Conv2d.PaddingMode
26 |
27 | /** Applies a 2D convolution over an input signal composed of several input planes.
28 |  *
29 |  * @group nn_conv
30 |  */
31 | final class Conv2d[ParamType <: FloatNN | ComplexNN: Default](
32 |     inChannels: Long,
33 |     outChannels: Long,
34 |     kernelSize: Int | (Int, Int),
35 |     stride: Int | (Int, Int) = 1,
36 |     padding: Int | (Int, Int) = 0,
37 |     dilation: Int | (Int, Int) = 1,
38 |     groups: Int = 1,
39 |     bias: Boolean = true,
40 |     paddingMode: PaddingMode = PaddingMode.Zeros
41 | ) extends HasParams[ParamType]
42 |     with TensorModule[ParamType]:
43 |
44 |   private val options = new Conv2dOptions(inChannels, outChannels, toNative(kernelSize))
45 |   options.stride().put(toNative(stride))
46 |   options.padding().put(toNative(padding))
47 |   options.dilation().put(toNative(dilation))
48 |   options.groups().put(groups)
49 |   options.bias().put(bias)
50 |   private val paddingModeNative = paddingMode match
51 |     case PaddingMode.Zeros     => new kZeros
52 |     case PaddingMode.Reflect   => new kReflect
53 |     case PaddingMode.Replicate => new kReplicate
54 |     case PaddingMode.Circular  => new kCircular
55 |   options.padding_mode().put(paddingModeNative)
56 |
57 |   override private[torch] val nativeModule: Conv2dImpl = Conv2dImpl(options)
58 |   nativeModule.to(paramType.toScalarType, false)
59 |
60 |   def apply(t: Tensor[ParamType]): Tensor[ParamType] = fromNative(nativeModule.forward(t.native))
61 |
62 |   def weight: Tensor[ParamType] = fromNative(nativeModule.weight)
63 |
64 |   override def hasBias(): Boolean = options.bias().get()
65 |
66 |   override def toString =
67 |     s"Conv2d($inChannels, $outChannels, kernelSize=$kernelSize, stride=$stride, padding=$padding, bias=$bias)"
68 |
69 | object Conv2d:
70 |   enum PaddingMode:
71 |     case Zeros, Reflect, Replicate, Circular
72 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package flatten
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.{FlattenImpl, FlattenOptions}
24 | import torch.internal.NativeConverters.fromNative
25 |
26 | // format: off
27 | /** Flattens a contiguous range of dims into a tensor. For use with [[nn.Sequential]].
28 |  *
29 |  * Shape:
30 |  * \- Input: $(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)$, where $S_{i}$ is the size
31 |  * at dimension $i$ and $*$ means any number of dimensions including none.
32 | * \- Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)$. 33 | * 34 | * Example: 35 | * 36 | * ```scala 37 | * import torch.nn 38 | * 39 | * val input = torch.randn(Seq(32, 1, 5, 5)) 40 | * // With default parameters 41 | * val m1 = nn.Flatten() 42 | * // With non-default parameters 43 | * val m2 = nn.Flatten(0, 2) 44 | * ``` 45 | * 46 | * @group nn_flatten 47 | * 48 | * @param startDim 49 | * first dim to flatten 50 | * @param endDim 51 | * last dim to flatten 52 | */ 53 | // format: on 54 | final class Flatten[D <: DType: Default](startDim: Int = 1, endDim: Int = -1) 55 | extends TensorModule[D]: 56 | 57 | private val options = FlattenOptions() 58 | options.start_dim().put(startDim) 59 | options.end_dim().put(endDim) 60 | 61 | override val nativeModule: FlattenImpl = FlattenImpl(options) 62 | 63 | override def hasBias(): Boolean = false 64 | 65 | def apply(t: Tensor[D]): Tensor[D] = fromNative(nativeModule.forward(t.native)) 66 | 67 | override def toString = getClass().getSimpleName() 68 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/modules/linear/Identity.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package modules 20 | package linear 21 | 22 | import org.bytedeco.pytorch 23 | import org.bytedeco.pytorch.IdentityImpl 24 | import torch.internal.NativeConverters.fromNative 25 | import scala.annotation.nowarn 26 | 27 | /** A placeholder identity operator that is argument-insensitive. 28 | * 29 | * @group nn_linear 30 | */ 31 | final class Identity[D <: DType: Default](@nowarn("msg=unused explicit parameter") args: Any*) 32 | extends TensorModule[D]: 33 | override val nativeModule: IdentityImpl = IdentityImpl() 34 | 35 | override def hasBias(): Boolean = false 36 | 37 | def apply(t: Tensor[D]): Tensor[D] = fromNative(nativeModule.forward(t.native)) 38 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/modules/linear/Linear.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package linear
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.{LinearImpl, LinearOptions}
24 | import internal.NativeConverters.fromNative
25 |
26 | /** Applies a linear transformation to the incoming data: $y = xA^T + b$
27 |  *
28 |  * This module supports `TensorFloat32`.
29 |  *
30 |  * Example:
31 |  *
32 |  * ```scala sc:nocompile
33 |  * import torch.*
34 |  *
35 |  * val linear = nn.Linear[Float32](20, 30)
36 |  * val input = torch.rand(Seq(128, 20))
37 |  * println(linear(input).size) // ArraySeq(128, 30)
38 |  * ```
39 |  *
40 |  * @group nn_linear
41 |  *
42 |  * @param inFeatures
43 |  * size of each input sample
44 |  * @param outFeatures
45 |  * size of each output sample
46 |  * @param addBias
47 |  * If set to `false`, the layer will not learn an additive bias. Default: `true`
48 |  */
49 | final class Linear[ParamType <: FloatNN: Default](
50 |     inFeatures: Long,
51 |     outFeatures: Long,
52 |     addBias: Boolean = true
53 |     // dtype: ParamType = defaultDType[ParamType]
54 | ) extends HasParams[ParamType]
55 |     with HasWeight[ParamType]
56 |     with TensorModule[ParamType]:
57 |
58 |   private val options = new LinearOptions(inFeatures, outFeatures)
59 |   options.bias().put(addBias)
60 |
61 |   override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
62 |   nativeModule.to(paramType.toScalarType, false)
63 |
64 |   override def hasBias(): Boolean = options.bias().get()
65 |
66 |   def weight = fromNative[ParamType](nativeModule.weight())
67 |   def weight_=(t: Tensor[ParamType]): Tensor[ParamType] =
68 |     nativeModule.weight(t.native)
69 |     t
70 |
71 |   def bias = fromNative[ParamType](nativeModule.bias())
72 |   def bias_=(t: Tensor[ParamType]): Tensor[ParamType] =
73 |     nativeModule.bias(t.native)
74 |     t
75 |
76 |   def apply(input: Tensor[ParamType]): Tensor[ParamType] = fromNative(
77 |     nativeModule.forward(input.native)
78 |   )
79 |
80 |   override def toString =
81 |     s"${getClass.getSimpleName}(inFeatures=$inFeatures, outFeatures=$outFeatures, bias=$addBias)"
82 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package normalization
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.{GroupNormImpl, GroupNormOptions}
24 | import torch.internal.NativeConverters.fromNative
25 |
26 | /** Applies Group Normalization over a mini-batch of inputs.
27 |  *
28 |  * @param numGroups
29 |  * number of groups to separate the channels into
30 |  * @param numChannels
31 |  * number of channels expected in input
32 |  * @param eps
33 |  * a value added to the denominator for numerical stability
34 |  * @param affine
35 |  * a boolean value that when set to `true`, this module has learnable per-channel affine
36 |  * parameters initialized to ones (for weights) and zeros (for biases)
37 |  */
38 | final class GroupNorm[ParamType <: FloatNN | ComplexNN: Default](
39 |     numGroups: Int,
40 |     numChannels: Int,
41 |     eps: Double = 1e-05,
42 |     affine: Boolean = true
43 | ) extends HasWeight[ParamType]
44 |     with TensorModule[ParamType]:
45 |   private val options: GroupNormOptions = GroupNormOptions(numGroups, numChannels)
46 |   options.eps().put(eps)
47 |   options.affine().put(affine)
48 |
49 |   override private[torch] val nativeModule: GroupNormImpl = GroupNormImpl(options)
50 |
51 |   val weight: Tensor[ParamType] = fromNative[ParamType](nativeModule.weight)
52 |   val bias: Tensor[ParamType] = fromNative[ParamType](nativeModule.bias)
53 |
54 |   override def hasBias(): Boolean = true
55 |
56 |   def apply(t: Tensor[ParamType]): Tensor[ParamType] =
57 |     fromNative[ParamType](nativeModule.forward(t.native))
58 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | // cSpell:ignore elementwise_affine, nn, shapenormalized, storch, NLP
18 |
19 | package torch
20 | package nn
21 | package modules
22 | package normalization
23 |
24 | import org.bytedeco.pytorch
25 | import org.bytedeco.pytorch.{LayerNormImpl, LayerNormOptions, LongVector}
26 | import internal.NativeConverters.fromNative
27 |
28 | // format: off
29 | /** Applies Layer Normalization over a mini-batch of inputs as described in the paper
30 |  * [[https://arxiv.org/abs/1607.06450 Layer Normalization]]
31 |  *
32 |  * $$
33 |  * y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
34 |  * $$
35 |  *
36 |  * where $\gamma$ and $\beta$ are learnable per-element affine parameters (see
37 |  * `elementWiseAffine` below).
38 |  *
39 |  * The mean and standard-deviation are calculated over the last D dimensions, where D is the
40 |  * dimension of `normalizedShape`. For example, if `normalizedShape` is (3, 5) (a 2-dimensional
41 |  * shape), the mean and standard-deviation are computed over the last 2 dimensions of the input
42 |  * (i.e. input.mean((-2, -1))). γ and β are learnable affine transform parameters of
43 |  * `normalizedShape` if `elementWiseAffine` is `true`.
The standard-deviation is calculated via
44 |  * the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
45 |  *
46 |  * @note
47 |  * Unlike Batch Normalization and Instance Normalization, which apply scalar scale and bias for
48 |  * each entire channel/plane with the `affine` option, Layer Normalization applies per-element
49 |  * scale and bias with `elementWiseAffine`.
50 |  *
51 |  * @variable
52 |  * weight – the learnable weights of the module of shape `normalizedShape`, present when
53 |  * `elementWiseAffine` is set to `true`. The values are initialized to 1.
54 |  * @variable
55 |  * bias – the learnable bias of the module of shape `normalizedShape`, present when `elementWiseAffine` is set to `true`. The values are initialized to 0.
56 |  *
57 |  * @example
58 |  *
59 |  * ```scala
60 |  * // NLP Example
61 |  * val (batch, sentenceLength, embeddingDim) = (20, 5, 10)
62 |  * val embedding = torch.randn(Seq(batch, sentenceLength, embeddingDim))
63 |  * val layerNorm = nn.LayerNorm(Seq(embeddingDim))
64 |  * // Activate module
65 |  * val out = layerNorm(embedding)
66 |  *
67 |  * // Image Example
68 |  * val (n, c, h, w) = (20, 5, 10, 10)
69 |  * val input = torch.randn(Seq(n, c, h, w))
70 |  * // Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
71 |  * val layerNorm2 = nn.LayerNorm(Seq(c, h, w))
72 |  * val output = layerNorm2(input)
73 |  * ```
74 |  *
75 |  * @param normalizedShape
76 |  * input shape from an expected input of size
77 |  * $[* \times \text{normalizedShape}[0] \times \text{normalizedShape}[1] \times \ldots \times \text{normalizedShape}[-1]]$.
78 |  * If a single integer is
79 |  * used, it is treated as a singleton list, and this module will normalize over the last
80 |  * dimension which is expected to be of that specific size.
81 |  * @param eps
82 |  * a value added to the denominator for numerical stability. Default: 1e-5
83 |  * @param elementWiseAffine
84 |  * a boolean value that when set to `true`, this module has learnable per-element affine
85 |  * parameters initialized to ones (for weights) and zeros (for biases). Default: `true`.
86 |  */
87 | // format: on
88 | final class LayerNorm[ParamType <: FloatNN | ComplexNN: Default](
89 |     normalizedShape: Seq[Int],
90 |     eps: Double = 1e-05,
91 |     elementWiseAffine: Boolean = true
92 | ) extends HasWeight[ParamType]
93 |     with TensorModule[ParamType]:
94 |
95 |   private val shape: LongVector = LongVector(normalizedShape.map(_.toLong): _*)
96 |   private val options: LayerNormOptions = LayerNormOptions(shape)
97 |   options.eps().put(eps)
98 |   options.elementwise_affine().put(elementWiseAffine)
99 |
100 |   override private[torch] val nativeModule: LayerNormImpl = LayerNormImpl(options)
101 |
102 |   val weight: Tensor[ParamType] = fromNative[ParamType](nativeModule.weight)
103 |   val bias: Tensor[ParamType] = fromNative[ParamType](nativeModule.bias)
104 |
105 |   override def hasBias(): Boolean = true
106 |
107 |   def apply(t: Tensor[ParamType]): Tensor[ParamType] =
108 |     fromNative[ParamType](nativeModule.forward(t.native))
109 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package pooling
21 |
22 | import org.bytedeco.pytorch.AdaptiveAvgPool2dImpl
23 | import org.bytedeco.pytorch
24 |
25 | import torch.internal.NativeConverters.{fromNative, toOptional}
26 | import org.bytedeco.pytorch.LongOptionalVector
27 | import org.bytedeco.pytorch.LongOptional
28 |
29 | /** Applies a 2D adaptive average pooling over an input signal composed of several input planes.
30 |  *
31 |  * The output is of size H x W, for any input size. The number of output features is equal to the
32 |  * number of input planes.
33 |  */
34 | final class AdaptiveAvgPool2d[D <: BFloat16 | Float32 | Float64: Default](
35 |     outputSize: Int | Option[Int] | (Option[Int], Option[Int]) | (Int, Int)
36 | ) extends Module {
37 |
38 |   private def nativeOutputSize = outputSize match
39 |     case (h: Int, w: Int) => new LongOptionalVector(new LongOptional(h), new LongOptional(w))
40 |     case x: Int           => new LongOptionalVector(new LongOptional(x), new LongOptional(x))
41 |     // We know these can only be Int, so we suppress the "type test for Option[Int] cannot be checked at runtime" warning
42 |     case (h: Option[Int @unchecked], w: Option[Int @unchecked]) =>
43 |       new LongOptionalVector(h.toOptional, w.toOptional)
44 |     case x: Option[Int] =>
45 |       new LongOptionalVector(x.toOptional, x.toOptional)
46 |
47 |   override protected[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl(
48 |     nativeOutputSize.get(0) // TODO: only the first element is passed here, so a differing width in an (h, w) output size is currently ignored
49 |   )
50 |
51 |   override def hasBias(): Boolean = false
52 |
53 |   def apply(t: Tensor[D]): Tensor[D] = fromNative(
54 |     nativeModule.forward(t.native)
55 |   )
56 | }
57 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package pooling
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.{MaxPool2dImpl, MaxPool2dOptions}
24 | import torch.internal.NativeConverters.{fromNative, toNative}
25 |
26 | /** Applies a 2D max pooling over an input signal composed of several input planes.
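 *
 * Example (a brief, hedged sketch in the style of the other module docs; shapes are arbitrary):
 * ```scala
 * import torch.*
 * val m = nn.MaxPool2d[Float32](kernelSize = 3, stride = Some(2))
 * val input = torch.randn(Seq(20, 16, 50, 32))
 * val output = m(input)
 * ```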
 */
27 | final class MaxPool2d[D <: BFloat16 | Float32 | Float64: Default](
28 |     kernelSize: Int | (Int, Int),
29 |     stride: Option[Int | (Int, Int)] = None,
30 |     padding: Int | (Int, Int) = 0,
31 |     dilation: Int | (Int, Int) = 1,
32 |     // returnIndices: Boolean = false,
33 |     ceilMode: Boolean = false
34 | ) extends TensorModule[D]:
35 |
36 |   private val options: MaxPool2dOptions = MaxPool2dOptions(toNative(kernelSize))
37 |   stride.foreach(s => options.stride().put(toNative(s)))
38 |   options.padding().put(toNative(padding))
39 |   options.dilation().put(toNative(dilation))
40 |   options.ceil_mode().put(ceilMode)
41 |
42 |   override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
43 |
44 |   override def hasBias(): Boolean = false
45 |
46 |   override def toString(): String =
47 |     s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)"
48 |
49 |   def apply(t: Tensor[D]): Tensor[D] = fromNative(nativeModule.forward(t.native))
50 |   // TODO forward_with_indices
51 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/regularization/Dropout.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | // cSpell:ignore nn, inplace
18 |
19 | package torch
20 | package nn
21 | package modules
22 | package regularization
23 |
24 | import org.bytedeco.pytorch
25 | import org.bytedeco.pytorch.DropoutImpl
26 | import org.bytedeco.pytorch.DropoutOptions
27 | import torch.internal.NativeConverters.fromNative
28 |
29 | // format: off
30 | /** During training, randomly zeroes some of the elements of the input tensor with probability `p`
31 |  * using samples from a Bernoulli distribution. Each channel will be zeroed out independently on
32 |  * every forward call.
33 |  *
34 |  * This has proven to be an effective technique for regularization and preventing the co-adaptation
35 |  * of neurons as described in the paper [[https://arxiv.org/abs/1207.0580 Improving neural networks
36 |  * by preventing co-adaptation of feature detectors]].
37 |  *
38 |  * Furthermore, the outputs are scaled by a factor of $\frac{1}{1-p}$ during training. This means
39 |  * that during evaluation the module simply computes an identity function.
40 |  *
41 |  * Shape:
42 |  * - Input: $(*)$. Input can be of any shape
43 |  * - Output: $(*)$. Output is of the same shape as input
44 |  *
45 |  * @example
46 |  *
47 |  * ```scala
48 |  * import torch.nn
49 |  *
50 |  * val m = nn.Dropout(p = 0.2)
51 |  * val input = torch.randn(Seq(20, 16))
52 |  * val output = m(input)
53 |  * ```
54 |  *
55 |  * @param p – probability of an element to be zeroed. Default: 0.5
56 |  * @param inplace – If set to `true`, will do this operation in-place.
Default: `false`
57 |  *
58 |  *
59 |  * @see See [[https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html#dropout Pytorch Python Dropout]]
60 |  * @see See [[https://pytorch.org/docs/master/nn.html#torch.nn.Dropout]]
61 |  * @see See [[https://pytorch.org/docs/master/nn.html#torch.nn.Dropout2d]]
62 |  * @see See [[https://pytorch.org/docs/master/nn.html#torch.nn.Dropout3d]]
63 |  * @see See [[https://pytorch.org/docs/stable/generated/torch.nn.functional.dropout.html#torch-nn-functional-dropout]]
64 |  *
65 |  * TODO: https://pytorch.org/docs/master/nn.html#torch.nn.Dropout
66 |  * Add 2D, 3D, Alpha and feature alpha versions
67 |  */
68 | // format: on
69 | final class Dropout[ParamType <: FloatNN | ComplexNN: Default](
70 |     p: Double = 0.5,
71 |     inplace: Boolean = false
72 | ) extends HasParams[ParamType]
73 |     with TensorModule[ParamType]:
74 |
75 |   private val options: DropoutOptions = DropoutOptions(p)
76 |   options.inplace().put(inplace)
77 |
78 |   override private[torch] val nativeModule: DropoutImpl = DropoutImpl(options)
79 |   nativeModule.to(paramType.toScalarType, false)
80 |
81 |   def apply(t: Tensor[ParamType]): Tensor[ParamType] = fromNative(nativeModule.forward(t.native))
82 |
83 |   override def hasBias(): Boolean = false
84 |
85 |   override def toString(): String = s"${getClass().getSimpleName()}(p=$p, inplace=$inplace)"
86 |
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 |
17 | package torch
18 | package nn
19 | package modules
20 | package sparse
21 |
22 | import org.bytedeco.pytorch
23 | import org.bytedeco.pytorch.EmbeddingImpl
24 | import org.bytedeco.pytorch.EmbeddingOptions
25 | import torch.internal.NativeConverters.{fromNative, toNative}
26 |
27 | // format: off
28 | /** A simple lookup table that stores embeddings of a fixed dictionary and size.
29 |  *
30 |  * This module is often used to store word embeddings and retrieve them using indices. The input to
31 |  * the module is a list of indices, and the output is the corresponding word embeddings.
32 |  *
33 |  * @group nn_sparse
34 |  *
35 |  * @param numEmbeddings
36 |  * Size of the dictionary of embeddings
37 |  * @param embeddingDim
38 |  * The size of each embedding vector
39 |  * @param paddingIdx
40 |  * If specified, the entries at `paddingIdx` do not contribute to the gradient; therefore, the
41 |  * embedding vector at `paddingIdx` is not updated during training, i.e. it remains as a fixed
42 |  * "pad". For a newly constructed Embedding, the embedding vector at `paddingIdx` will default to
43 |  * all zeros, but can be updated to another value to be used as the padding vector.
44 |  * @param maxNorm
45 |  *   If given, each embedding vector with norm larger than `maxNorm` is renormalized to have norm
46 |  *   `maxNorm`.
47 |  * @param normType
48 |  *   The p of the p-norm to compute for the `maxNorm` option. Default `2`.
49 |  * @param scaleGradByFreq
50 |  *   If given, this will scale gradients by the inverse of frequency of the words in the
51 |  *   mini-batch. Default `false`.
52 |  * @param sparse
53 |  *   If `true`, the gradient w.r.t. the `weight` matrix will be a sparse tensor. See the PyTorch
54 |  *   notes for more details regarding sparse gradients.
55 |  *
56 |  * @see See [[https://pytorch.org/cppdocs/api/classtorch_1_1nn_1_1_embedding.html#class-embedding Pytorch C++ Embedding]]
57 |  * @see See [[https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html Pytorch Python Embedding]]
58 |  */
59 | // format: on
60 | final class Embedding[ParamType <: FloatNN | ComplexNN: Default](
61 |     numEmbeddings: Int,
62 |     embeddingDim: Int,
63 |     paddingIdx: Option[Int] = None,
64 |     maxNorm: Option[Double] = None,
65 |     normType: Option[Double] = Some(2.0),
66 |     scaleGradByFreq: Boolean = false,
67 |     sparse: Boolean = false
68 | ) extends HasParams[ParamType]
69 |     with HasWeight[ParamType]
70 |     with TensorModuleBase[Int64, ParamType]:
71 | 
72 |   private val options = new EmbeddingOptions(numEmbeddings.toLong, embeddingDim.toLong)
73 |   paddingIdx.foreach(p => options.padding_idx().put(toNative(p)))
74 |   maxNorm.foreach(m => options.max_norm().put(m))
75 |   normType.foreach(n => options.norm_type().put(n))
76 |   options.scale_grad_by_freq().put(scaleGradByFreq)
77 |   options.sparse().put(sparse)
78 | 
79 |   override val nativeModule: EmbeddingImpl = EmbeddingImpl(options)
80 |   nativeModule.to(paramType.toScalarType, false)
81 | 
82 |   override def hasBias(): Boolean = false
83 | 
84 |   def weight: Tensor[ParamType] = fromNative(nativeModule.weight)
85 |   def weight_=(w: Tensor[ParamType]): Tensor[ParamType] =
86 |     nativeModule.weight(w.native)
87 |     w
88 | 
89 |   def apply(t: Tensor[Int64]): Tensor[ParamType] = fromNative(nativeModule.forward(t.native))
90 | 
91 |   override def toString(): String =
92 |     val numEmbed = s"numEmbeddings=$numEmbeddings"
93 |     val dim = s"embeddingDim=$embeddingDim"
94 |     val padding = s"paddingIdx=$paddingIdx"
95 |     val maxN = s"maxNorm=$maxNorm"
96 |     val normT = s"normType=$normType"
97 |     val scale = s"scaleGradByFreq=$scaleGradByFreq"
98 |     val s = s"sparse=$sparse"
99 |     s"${getClass().getSimpleName()}($numEmbed, $dim, $padding, $maxN, $normT, $scale, $s)"
100 | 
--------------------------------------------------------------------------------
/core/src/main/scala/torch/nn/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | 
19 | /** These are the basic building blocks for graphs.
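As a usage sketch for the `Embedding` module above, mirroring the calls exercised in `EmbeddingSuite` further below (a freshly constructed embedding initializes the `paddingIdx` row to zeros):

```scala
import torch.*

// a dictionary of 10 entries with 3-dimensional vectors; index 0 is reserved for padding
val embedding = nn.Embedding(10, 3, paddingIdx = Some(0))
// one batch of four indices (Int64 input, as required by the module)
val indices = torch.Tensor(Seq(Seq(1L, 2, 0, 5)))
val vectors = embedding(indices) // shape Seq(1, 4, 3); the row for index 0 is all zeros
```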
20 | * 21 | * @groupname nn_conv Convolution Layers 22 | * @groupname nn_linear Linear Layers 23 | * @groupname nn_utilities Utilities 24 | */ 25 | package object nn { 26 | 27 | export modules.Module 28 | 29 | export modules.activation.Softmax 30 | export modules.activation.LogSoftmax 31 | export modules.activation.ReLU 32 | export modules.activation.Tanh 33 | export modules.batchnorm.BatchNorm1d 34 | export modules.batchnorm.BatchNorm2d 35 | export modules.container.Sequential 36 | export modules.container.ModuleList 37 | export modules.conv.Conv2d 38 | export modules.flatten.Flatten 39 | export modules.linear.Linear 40 | export modules.linear.Identity 41 | export modules.normalization.GroupNorm 42 | export modules.normalization.LayerNorm 43 | export modules.pooling.AdaptiveAvgPool2d 44 | export modules.pooling.MaxPool2d 45 | export modules.sparse.Embedding 46 | export modules.regularization.Dropout 47 | 48 | export loss.CrossEntropyLoss 49 | } 50 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/nn/utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | 20 | import org.bytedeco.pytorch.global.torch as torchNative 21 | import org.bytedeco.pytorch.TensorVector 22 | 23 | object utils: 24 | def clipGradNorm_( 25 | parameters: Seq[Tensor[?]], 26 | max_norm: Double, 27 | norm_type: Double = 2.0, 28 | error_if_nonfinite: Boolean = false 29 | ): Double = 30 | torchNative.clip_grad_norm_( 31 | TensorVector(parameters.map(_.native).toArray*), 32 | max_norm, 33 | norm_type, 34 | error_if_nonfinite 35 | ) 36 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/ops/BLASOps.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
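A quick sketch of the `clipGradNorm_` helper above in a training step: it rescales all gradients in place when their combined norm exceeds `max_norm`, and returns the total norm it observed. Minimal example, assuming the usual autograd flow shown in `TrainingSuite` further below:

```scala
import torch.*

val w = torch.randn(Seq(3), requiresGrad = true)
w.sum.backward() // gradient of the sum is all ones, so its L2 norm is sqrt(3)
// clip in place so the global gradient norm is at most 1.0; returns the pre-clipping norm
val totalNorm = nn.utils.clipGradNorm_(Seq(w), max_norm = 1.0)
```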
15 |  */
16 | 
17 | package torch
18 | package ops
19 | 
20 | /** BLAS and LAPACK Operations
21 |  *
22 |  * https://pytorch.org/docs/stable/torch.html#blas-and-lapack-operations
23 |  */
24 | private[torch] trait BLASOps {
25 |   def matmul[D1 <: DType, D2 <: DType](t1: Tensor[D1], t2: Tensor[D2]): Tensor[Promoted[D1, D2]] =
26 |     t1.matmul(t2)
27 | }
28 | 
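`BLASOps.matmul` above simply forwards to `Tensor.matmul`, so the function-style and operator-style spellings are interchangeable; a small shape sketch:

```scala
import torch.*

val a = torch.ones(Seq(2, 3))
val b = torch.ones(Seq(3, 4))
val c1 = torch.matmul(a, b) // shape Seq(2, 4), every entry 3.0f
val c2 = a matmul b         // identical result via the Tensor method
```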
--------------------------------------------------------------------------------
/core/src/main/scala/torch/ops/ComparisonOps.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | package ops
19 | 
20 | import org.bytedeco.pytorch.global.torch as torchNative
21 | import internal.NativeConverters.fromNative
22 | 
23 | /** Comparison Ops
24 |  *
25 |  * https://pytorch.org/docs/stable/torch.html#comparison-ops
26 |  */
27 | private[torch] trait ComparisonOps {
28 | 
29 |   def allclose(
30 |       input: Tensor[?],
31 |       other: Tensor[?],
32 |       rtol: Double = 1e-05,
33 |       atol: Double = 1e-08,
34 |       equalNan: Boolean = false
35 |   ): Boolean =
36 |     torchNative.allclose(input.native, other.native, rtol, atol, equalNan)
37 | 
38 |   /** Returns the indices that sort a tensor along a given dimension in ascending order by value.
39 |    *
40 |    * This is the second value returned by `torch.sort`. See its documentation for the exact
41 |    * semantics of this method.
42 |    *
43 |    * If `stable` is `true`, the sorting routine becomes stable, preserving the order of equivalent
44 |    * elements. If `false`, the relative order of values which compare equal is not guaranteed.
45 |    * `true` is slower.
46 |    *
47 |    * @param input the input tensor
48 |    * @param dim the dimension to sort along. Default: -1 (the last dimension)
49 |    * @param descending controls the sorting order (ascending or descending)
50 |    *
51 |    * Example:
52 |    *
53 |    * ```scala sc
54 |    * val a = torch.randn(Seq(4, 4))
55 |    * // tensor dtype=float32, shape=[4, 4], device=CPU
56 |    * // [[ 0.0785, 1.5267, -0.8521, 0.4065],
57 |    * //  [ 0.1598, 0.0788, -0.0745, -1.2700],
58 |    * //  [ 1.2208, 1.0722, -0.7064, 1.2564],
59 |    * //  [ 0.0669, -0.2318, -0.8229, -0.9280]]
60 |    *
61 |    * torch.argsort(a, dim = 1)
62 |    * // tensor dtype=int64, shape=[4, 4], device=CPU
63 |    * // [[2, 0, 3, 1],
64 |    * //  [3, 2, 1, 0],
65 |    * //  [2, 1, 0, 3],
66 |    * //  [3, 2, 1, 0]]
67 |    * ```
68 |    *
69 |    * @group comparison_ops
70 |    */
71 |   def argsort[D <: RealNN](
72 |       input: Tensor[D],
73 |       dim: Int = -1,
74 |       descending: Boolean = false
75 |       // TODO implement stable, there are two boolean args in argsort and are not in order
76 |       // stable: Boolean = false
77 |   ): Tensor[Int64] =
78 |     fromNative(
79 |       torchNative.argsort(input.native, dim.toLong, descending)
80 |     )
81 | }
82 | 
--------------------------------------------------------------------------------
/core/src/main/scala/torch/ops/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | 
19 | import internal.NativeConverters.{fromNative, tensorOptions}
20 | import org.bytedeco.pytorch
21 | import org.bytedeco.pytorch.MemoryFormatOptional
22 | 
23 | package object ops {
24 | 
25 |   private[torch] def xLike[D <: DType, D2 <: DType | Derive](
26 |       input: Tensor[D],
27 |       dtype: D2,
28 |       layout: Layout | Derive,
29 |       device: Device | Derive,
30 |       requiresGrad: Boolean,
31 |       memoryFormat: MemoryFormat,
32 |       nativeFn: (
33 |           pytorch.Tensor,
34 |           pytorch.TensorOptions,
35 |           pytorch.MemoryFormatOptional
36 |       ) => pytorch.Tensor
37 |   ): Tensor[DTypeOrDeriveFromTensor[D, D2]] = {
38 |     val derivedDType = dtype match
39 |       case _: Derive => input.dtype
40 |       case d: DType  => d
41 |     val derivedLayout = layout match
42 |       case _: Derive => input.layout
43 |       case l: Layout => l
44 |     val derivedDevice = device match
45 |       case _: Derive => input.device
46 |       case d: Device => d
47 |     fromNative(
48 |       nativeFn(
49 |         input.native,
50 |         tensorOptions(derivedDType, derivedLayout, derivedDevice, requiresGrad),
51 |         new MemoryFormatOptional(memoryFormat.toNative)
52 |       )
53 |     )
54 |   }
55 | }
56 | 
--------------------------------------------------------------------------------
/core/src/main/scala/torch/optim/Adam.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | package optim
19 | 
20 | import org.bytedeco.pytorch
21 | import org.bytedeco.pytorch.{AdamOptions, TensorVector}
22 | 
23 | import scala.collection.immutable.Iterable
24 | 
25 | // format: off
26 | /** Implements the Adam algorithm.
27 |  *
28 |  * $$
29 |  * \begin{aligned}
30 |  *   &\rule{110mm}{0.4pt} \\
31 |  *   &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
32 |  *     \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
33 |  *   &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
34 |  *     \:\textit{maximize} \\
35 |  *   &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)},
36 |  *     v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
37 |  *   &\rule{110mm}{0.4pt} \\
38 |  *   &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
39 |  *
40 |  *   &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
41 |  *   &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
42 |  *   &\hspace{5mm}\textbf{else} \\
43 |  *   &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
44 |  *   &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
45 |  *   &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
46 |  *   &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
47 |  *   &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
48 |  *   &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
49 |  *   &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
50 |  *   &\hspace{5mm}\textbf{if} \: amsgrad \\
51 |  *   &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
52 |  *     \widehat{v_t}) \\
53 |  *   &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
54 |  *     \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
55 |  *   &\hspace{5mm}\textbf{else} \\
56 |  *   &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\[-1.ex]
57 |  *   &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned}
58 |  * $$
59 |  *
60 |  * For further details regarding the algorithm we refer to
61 |  * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980).
62 | */ 63 | // format: on 64 | final class Adam( 65 | params: Iterable[Tensor[?]], 66 | lr: Double = 1e-3, 67 | betas: (Double, Double) = (0.9, 0.999), 68 | eps: Double = 1e-8, 69 | weightDecay: Double = 0, 70 | amsgrad: Boolean = false 71 | ) extends Optimizer { 72 | private val nativeParams: TensorVector = TensorVector(params.map(_.native).toArray*) 73 | private val options: AdamOptions = AdamOptions(lr) 74 | options.betas().put(Array(betas._1, betas._2)*) 75 | options.eps().put(eps) 76 | options.weight_decay().put(weightDecay) 77 | options.amsgrad().put(amsgrad) 78 | override private[torch] val native: pytorch.Adam = pytorch.Adam(nativeParams, options) 79 | } 80 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/optim/AdamW.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package optim 19 | 20 | import org.bytedeco.pytorch 21 | import org.bytedeco.pytorch.{AdamWOptions, TensorVector} 22 | 23 | import scala.collection.immutable.Iterable 24 | 25 | // format: off 26 | /** Implements the AdamW algorithm. 27 | * 28 | */ 29 | // format: on 30 | final class AdamW( 31 | params: Iterable[Tensor[?]], 32 | lr: Double = 1e-3, 33 | betas: (Double, Double) = (0.9, 0.999), 34 | eps: Double = 1e-8, 35 | weightDecay: Double = 0, 36 | amsgrad: Boolean = false 37 | ) extends Optimizer { 38 | private val nativeParams: TensorVector = TensorVector(params.map(_.native).toArray*) 39 | private val options: AdamWOptions = AdamWOptions(lr) 40 | options.betas().put(Array(betas._1, betas._2)*) 41 | options.eps().put(eps) 42 | options.weight_decay().put(weightDecay) 43 | options.amsgrad().put(amsgrad) 44 | override private[torch] val native: pytorch.AdamW = pytorch.AdamW(nativeParams, options) 45 | } 46 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/optim/Optimizer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package optim 19 | 20 | import org.bytedeco.pytorch 21 | 22 | /** Base class for all optimizers. 
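`AdamW` above is Adam with decoupled weight decay, following [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). A minimal fitting loop as a sketch (assuming the usual scalar arithmetic on tensors):

```scala
import torch.*
import torch.optim.AdamW

val w = torch.randn(Seq(1), requiresGrad = true)
val opt = AdamW(Seq(w), lr = 1e-2, weightDecay = 0.01)
for _ <- 1 to 100 do
  opt.zeroGrad()
  val loss = (w - 3.0f).pow(2).sum // drive w towards 3
  loss.backward()
  opt.step()
```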
 */
23 | abstract class Optimizer {
24 |   private[torch] def native: pytorch.Optimizer
25 | 
26 |   /** Performs a single optimization step (parameter update).
27 |    *
28 |    * @note
29 |    *   Unless otherwise specified, this function should not modify the `.grad` field of the
30 |    *   parameters.
31 |    */
32 |   def step(): Unit =
33 |     native.step()
34 |     // TODO check what tensor is returned by step
35 |     ()
36 | 
37 |   /** Sets the gradients of all optimized `Tensor`s to zero. */
38 |   def zeroGrad(): Unit = native.zero_grad()
39 |   def zeroGrad(setToNone: Boolean = true): Unit = native.zero_grad(setToNone)
40 | 
41 | }
42 | 
--------------------------------------------------------------------------------
/core/src/main/scala/torch/optim/SGD.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | package optim
19 | 
20 | import org.bytedeco.pytorch
21 | import org.bytedeco.pytorch.{SGDOptions, TensorVector}
22 | 
23 | import scala.collection.immutable.Iterable
24 | 
25 | // format: off
26 | /** Implements stochastic gradient descent (optionally with momentum).
27 |  *
28 |  * $$
29 |  * \begin{aligned}
30 |  *   &\rule{110mm}{0.4pt} \\
31 |  *   &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
32 |  *     \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
33 |  *   &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
34 |  *     \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
35 |  *   &\rule{110mm}{0.4pt} \\
36 |  *   &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
37 |  *   &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
38 |  *   &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
39 |  *   &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
40 |  *   &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
41 |  *   &\hspace{10mm}\textbf{if} \: t > 1 \\
42 |  *   &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
43 |  *   &\hspace{10mm}\textbf{else} \\
44 |  *   &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
45 |  *   &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
46 |  *   &\hspace{15mm} g_t \leftarrow g_{t-1} + \mu \textbf{b}_t \\
47 |  *   &\hspace{10mm}\textbf{else} \\[-1.ex]
48 |  *   &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
49 |  *   &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
50 |  *   &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
51 |  *   &\hspace{5mm}\textbf{else} \\[-1.ex]
52 |  *   &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
53 |  *   &\rule{110mm}{0.4pt} \\[-1.ex]
54 |  *   &\bf{return} \: \theta_t \\[-1.ex]
55 |  *   &\rule{110mm}{0.4pt} \\[-1.ex]
56 |  * \end{aligned}
57 |  * $$
58 |  *
59 |  * Nesterov momentum is based on the formula from
60 |  * [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf)
61 |  */
62 | // format: on
63 | // TODO optional
parameters 64 | class SGD(params: Iterable[Tensor[?]], lr: Float) extends Optimizer { 65 | private val nativeParams = TensorVector(params.map(_.native).toArray*) 66 | private val options = SGDOptions(lr) 67 | override private[torch] val native: pytorch.SGD = pytorch.SGD(nativeParams, options) 68 | } 69 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/optim/lr_scheduler/LRScheduler.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package optim 19 | package lr_scheduler 20 | 21 | trait LRScheduler: 22 | def step(): Unit 23 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/optim/lr_scheduler/StepLR.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package optim 19 | package lr_scheduler 20 | 21 | import org.bytedeco.pytorch 22 | 23 | /** Decays the learning rate of each parameter group by gamma every step_size epochs. 24 | * 25 | * Notice that such decay can happen simultaneously with other changes to the learning rate from 26 | * outside this scheduler. 27 | */ 28 | class StepLR(optimizer: Optimizer, step_size: Int, gamma: Float = 0.1) extends LRScheduler { 29 | private[torch] val native = pytorch.StepLR(optimizer.native, step_size, gamma) 30 | 31 | def step(): Unit = native.step() 32 | } 33 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
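How `SGD` and `StepLR` above compose: `scheduler.step()` is called once per epoch, and every `step_size` epochs the learning rate is multiplied by `gamma`. A sketch:

```scala
import torch.*
import torch.optim.SGD
import torch.optim.lr_scheduler.StepLR

val w = torch.randn(Seq(2), requiresGrad = true)
val optimizer = SGD(Seq(w), lr = 0.1f)
val scheduler = StepLR(optimizer, step_size = 10, gamma = 0.5f)

for epoch <- 1 to 30 do
  optimizer.zeroGrad()
  w.sum.backward()
  optimizer.step()
  scheduler.step() // lr: 0.1 for epochs 1-10, then 0.05, then 0.025
```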
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | import scala.util.Using 18 | 19 | /** The torch package contains data structures for multi-dimensional tensors and defines 20 | * mathematical operations over these tensors. Additionally, it provides many utilities for 21 | * efficient serialization of Tensors and arbitrary types, and other useful utilities. 22 | * 23 | * It has a CUDA counterpart, that enables you to run your tensor computations on an NVIDIA GPU 24 | * with compute capability >= 3.0. 25 | * 26 | * @groupname creation_ops Creation Ops 27 | * @groupname pointwise_ops Pointwise Ops 28 | * @groupname reduction_ops Reduction Ops 29 | */ 30 | package object torch 31 | extends ops.BLASOps 32 | with ops.ComparisonOps 33 | with ops.CreationOps 34 | with ops.IndexingSlicingJoiningOps 35 | with ops.PointwiseOps 36 | with ops.RandomSamplingOps 37 | with ops.ReductionOps 38 | with ops.OtherOps { 39 | 40 | /** Disable gradient calculation for [[op]]. 41 | * 42 | * Disabling gradient calculation is useful for inference, when you are sure that you will not 43 | * call `Tensor.backward()`. It will reduce memory consumption for computations that would 44 | * otherwise have `requiresGrad=true`. 45 | * 46 | * In this mode, the result of every computation will have `requiresGrad=false`, even when the 47 | * inputs have `requiresGrad=true`. 48 | * 49 | * This context manager is thread local; it will not affect computation in other threads. 50 | * 51 | * @param op 52 | */ 53 | def noGrad[A](op: => A): A = { 54 | import org.bytedeco.pytorch.NoGradGuard 55 | Using.resource(NoGradGuard()) { _ => 56 | op 57 | } 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /core/src/main/scala/torch/special/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | import org.bytedeco.pytorch.global.torch as torchNative 20 | 21 | import internal.NativeConverters.* 22 | 23 | package object special: 24 | /** Computes the logarithmic derivative of the gamma function on `input`. */ 25 | def digamma[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = 26 | fromNative(torchNative.digamma(input.native)) 27 | 28 | /** Computes the error function of `input`. */ 29 | def erf[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = 30 | fromNative(torchNative.erf(input.native)) 31 | 32 | /** Computes the complementary error function of `input`. */ 33 | def erfc[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = 34 | fromNative(torchNative.erfc(input.native)) 35 | 36 | /** Computes the inverse error function of `input`. 
The inverse error function is defined in the
37 |    * range (-1, 1).
38 |    */
39 |   def erfinv[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
40 |     fromNative(torchNative.erfinv(input.native))
41 | 
42 |   /** Computes the base two exponential function of `input`. */
43 |   def exp2[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
44 |     fromNative(torchNative.exp2(input.native))
45 | 
46 |   /** Computes the exponential of the elements minus 1 of `input`. */
47 |   def expm1[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
48 |     fromNative(torchNative.expm1(input.native))
49 | 
50 |   /** Computes the zeroth order modified Bessel function of the first kind for each element of
51 |    * `input`.
52 |    */
53 |   def i0[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
54 |     fromNative(torchNative.i0(input.native))
55 | 
56 |   /** Computes the regularized lower incomplete gamma function */
57 |   // NOTE it is named `gammainc` in pytorch torch.special
58 |   // TODO Change `D2 <: RealNN` once we fix property testing compilation
59 |   def igamma[D <: RealNN, D2 <: FloatNN](
60 |       input: Tensor[D],
61 |       other: Tensor[D2]
62 |   )(using AtLeastOneFloat[D, D2]): Tensor[FloatPromoted[Promoted[D, D2]]] =
63 |     fromNative(torchNative.igamma(input.native, other.native))
64 | 
65 |   /** Computes the regularized upper incomplete gamma function */
66 |   // NOTE it is named `gammaincc` in pytorch torch.special
67 |   // TODO Change `D2 <: RealNN` once we fix property testing compilation
68 |   def igammac[D <: RealNN, D2 <: FloatNN](
69 |       input: Tensor[D],
70 |       other: Tensor[D2]
71 |   )(using AtLeastOneFloat[D, D2]): Tensor[FloatPromoted[Promoted[D, D2]]] =
72 |     fromNative(torchNative.igammac(input.native, other.native))
73 | 
74 |   /** Returns a new tensor with the logit of the elements of `input`. `input` is clamped to
75 |    * [eps, 1 - eps] when `eps` is defined. When `eps` is empty and `input` < 0 or `input` > 1, the
76 |    * function yields NaN.
77 |    */
78 |   def logit[D <: RealNN](input: Tensor[D], eps: Option[Double]): Tensor[FloatPromoted[D]] =
79 |     fromNative(torchNative.logit(input.native, toOptional(eps)))
80 | 
81 |   /** Computes the multivariate log-gamma function with dimension p element-wise */
82 |   // NOTE it is named `multigammaln` in pytorch torch.special
83 |   def mvlgamma[D <: NumericRealNN](input: Tensor[D], p: Int): Tensor[FloatPromoted[D]] =
84 |     fromNative(torchNative.mvlgamma(input.native, p))
85 | 
86 |   /** Computes the nth derivative of the digamma function on `input`. n ≥ 0 is called the order of
87 |    * the polygamma function.
88 |    */
89 |   def polygamma[D <: RealNN](n: Int, input: Tensor[D]): Tensor[FloatPromoted[D]] =
90 |     fromNative(torchNative.polygamma(n, input.native))
91 | 
92 |   /** Computes the expit (also known as the logistic sigmoid function) of the elements of `input`.
93 |    */
94 |   // NOTE it is named `expit` in pytorch torch.special
95 |   def sigmoid[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
96 |     fromNative(torchNative.sigmoid(input.native))
97 | 
98 |   /** Returns a new tensor with the normalized sinc of the elements of `input`. */
99 |   def sinc[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
100 |     fromNative(torchNative.sinc(input.native))
101 | 
102 |   /** Computes `input * log(other)` with the following cases:
103 |    *   - `NaN` if `other` is `NaN`
104 |    *   - `0` if `input` equals `0.0`
105 |    *   - `input * log(other)` otherwise
106 |    */
107 |   // TODO handle Scalar `input`
108 |   def xlogy[D <: RealNN, D2 <: RealNN](
109 |       input: Tensor[D],
110 |       other: TensorOrReal[D2]
111 |   ): Tensor[FloatPromoted[D]] =
112 |     fromNative(
113 |       other match
114 |         case other: Tensor[D2] =>
115 |           torchNative.xlogy(input.native, other.native)
116 |         case other: Real =>
117 |           torchNative.xlogy(input.native, toScalar(other))
118 |     )
119 | 
--------------------------------------------------------------------------------
/core/src/test/scala/TrainingSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | import torch.data.*
19 | 
20 | class TrainingSuite extends munit.FunSuite {
21 |   test("training") {
22 | 
23 |     val xTrain = torch.arange(end = 10, dtype = float32) // .reshape(10, 1)
24 |     val yTrain = Tensor(Seq(1.0f, 1.3f, 3.1f, 2.0f, 5.0f, 6.3f, 6.6f, 7.4f, 8.0f, 9.0f))
25 |     val xTrainNorm = ((xTrain - xTrain.mean) / xTrain.std)
26 | 
27 |     val ds = TensorSeq(xTrainNorm).zip(TensorSeq(yTrain))
28 | 
29 |     torch.manualSeed(1)
30 | 
31 |     val weight = torch.randn(Seq(1), requiresGrad = true)
32 |     val bias = torch.zeros(Seq(1), requiresGrad = true)
33 | 
34 |     def model(xb: Tensor[Float32]): Tensor[Float32] = (xb matmul weight) + bias
35 | 
36 |     def lossFn(input: Tensor[Float32], target: Tensor[Float32]) = (input - target).pow(2).mean
37 | 
38 |     val learningRate = 0.001f
39 |     val numEpochs = 10
40 |     val logEpochs = 1
41 | 
42 |     Range(0, xTrainNorm.size.head.toInt, 2)
43 | 
44 |     val dl = TupleDataLoader(ds, batchSize = 1, shuffle = true)
45 | 
46 |     val batch = dl.head
47 |     val pred = model(batch._1)
48 |     val loss = lossFn(pred, batch._2)
49 |     loss.backward()
50 | 
51 |     for {
52 |       epoch <- 0 to numEpochs
53 |       loss = dl.map { (x, y) =>
54 |         val pred = model(x)
55 |         val loss = lossFn(pred, y)
56 |         loss.backward()
57 |         noGrad {
58 |           weight.grad.foreach { grad =>
59 |             weight -= grad * learningRate
60 |             grad.zero_()
61 |           }
62 |           bias.grad.foreach { grad =>
63 |             bias -= grad * learningRate
64 |             grad.zero_()
65 |           }
66 |         }
67 |         loss
68 |       }.last
69 |     } {
70 |       if (epoch % logEpochs == 0) println(s"Epoch ${epoch} Loss ${loss.item}")
71 |     }
72 |   }
73 | }
74 | 
--------------------------------------------------------------------------------
/core/src/test/scala/torch/DeviceSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *   http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
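The `TrainingSuite` above shows the general pattern for manual, optimizer-free gradient descent: perform the parameter update inside `noGrad` so it is not recorded by autograd, then reset the gradient. Condensed:

```scala
import torch.*

val w = torch.randn(Seq(1), requiresGrad = true)
val learningRate = 0.1f
w.pow(2).sum.backward()
noGrad {
  w.grad.foreach { grad =>
    w -= grad * learningRate // in-place update, invisible to autograd
    grad.zero_()             // clear the gradient for the next step
  }
}
```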
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | import munit.ScalaCheckSuite 20 | import org.scalacheck.Prop.* 21 | import org.scalacheck._ 22 | import Generators.given 23 | 24 | class DeviceSuite extends ScalaCheckSuite { 25 | test("device native roundtrip") { 26 | val d = Device("cpu") 27 | assertEquals(d, Device(d.toNative)) 28 | } 29 | 30 | property("device native roundtrip for all") { 31 | forAll { (d: Device) => 32 | assertEquals(d, Device(d.toNative)) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/Generators.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | import org.scalacheck.{Arbitrary, Gen} 20 | 21 | import scala.collection.immutable.ArraySeq 22 | 23 | object Generators: 24 | val genDeviceType: Gen[DeviceType] = Gen.oneOf(ArraySeq.unsafeWrapArray(DeviceType.values)) 25 | val genIndex: Gen[Byte] = Gen.chooseNum(-1, Byte.MaxValue) 26 | val genCpuIndex: Gen[Byte] = Gen.chooseNum[Byte](-1, 0) 27 | val genDevice: Gen[Device] = for 28 | deviceType <- genDeviceType // Arbitrary(genDeviceType).arbitrary 29 | i <- if deviceType == DeviceType.CPU then genCpuIndex else genIndex 30 | yield Device(deviceType, i) 31 | val genDimSize = Gen.choose(0, 30) 32 | val genTensorSize = Gen.choose(0, 5).flatMap(listSize => Gen.listOfN(listSize, genDimSize)) 33 | given Arbitrary[Device] = Arbitrary(genDevice) 34 | 35 | val allDTypes: List[DType] = List( 36 | int8, 37 | uint8, 38 | int16, 39 | int32, 40 | int64, 41 | float32, 42 | float64, 43 | // complex32, // NOTE: A lot of CPU operations do not support this dtype yet 44 | complex64, 45 | complex128, 46 | bool, 47 | // qint8, 48 | // quint8, 49 | // qint32, 50 | bfloat16 51 | // quint4x2, 52 | // float16, // NOTE: A lot of CPU operations do not support this dtype yet 53 | // undefined, 54 | // numoptions 55 | ) 56 | 57 | /* This method generates tensors of multiple DTypes, and it casts them to the given concrete subtype of DType, 58 | * so we can use them in operations that require a specific dtype at compile time but may fail with a runtime error. 59 | * It is being used for property testing, and complement-property testing of tensor operations. 
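A sketch of how the generators above drive a ScalaCheck property, in the style of `CreationOpsSuite` further below (assumes a suite extending `ScalaCheckSuite`, and that `torch.zeros` mirrors the `ones(size, dtype)` signature used there):

```scala
import org.scalacheck.Prop.forAll
import torch.*
import torch.Generators.{genTensorSize, genDType}

property("zeros have the requested shape and dtype") {
  forAll(genTensorSize, genDType) { (size, dtype) =>
    val t = torch.zeros(size, dtype)
    assertEquals(t.size, size)
    assertEquals(t.dtype, dtype)
  }
}
```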
60 | */ 61 | inline def genTensor[D <: DType]( 62 | filterDTypes: Boolean = false, 63 | tensorDimensions: Int = 2 64 | ): Gen[Tensor[D]] = 65 | Gen.oneOf(allDTypes.filter(_.isInstanceOf[D] || !filterDTypes)).map { dtype => 66 | ones(Seq.fill(tensorDimensions)(4), dtype = dtype.asInstanceOf[D]) 67 | } 68 | 69 | val genDType = Gen.oneOf(allDTypes) 70 | given Arbitrary[DType] = Arbitrary(genDType) 71 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/TensorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | class TensorSuite extends TensorCheckSuite { 20 | 21 | test("tensor properties") { 22 | val t = ones(Seq(2, 3), dtype = float32) 23 | assertEquals(t.size, Seq[Int](2, 3)) 24 | assertEquals(t.device, Device(DeviceType.CPU, -1: Byte)) 25 | assertEquals(t.numel, 2L * 3) 26 | } 27 | 28 | // property("tensor requiresGrad") { 29 | // forAll { (dtype: FloatNN | ComplexNN, requiresGrad: Boolean) => 30 | // val t = ones(Seq(2, 3), dtype, requiresGrad=requiresGrad) 31 | // assertEquals(t.dtype, dtype) 32 | // } 33 | // } 34 | 35 | test("exp and log") { 36 | val t = Tensor(Seq(1.0, 2.0, 3.0)) 37 | assertEquals(t.log(0), Tensor(0.0)) 38 | assert(torch.allclose(t.log.exp, t)) 39 | } 40 | 41 | test("toBuffer") { 42 | val content = Seq(1, 2, 3, 4) 43 | val t = Tensor(content) 44 | val b = t.toBuffer 45 | val a = new Array[Int](content.length) 46 | b.get(a) 47 | assertEquals(content, a.toSeq) 48 | } 49 | 50 | test("+") { 51 | assertEquals((Tensor(1) + 2).item, 3) 52 | } 53 | 54 | test("grad") { 55 | val t = torch.ones(Seq(3)) * 2 56 | t.requiresGrad = true 57 | val sum = t.sum 58 | assert(t.grad.isEmpty) 59 | sum.backward() 60 | assert(t.grad.isDefined) 61 | t.grad.map { grad => 62 | assertEquals(grad.dtype, float32) 63 | assert(grad.equal(torch.ones(Seq(3)))) 64 | } 65 | } 66 | 67 | test("indexing") { 68 | val tensor = torch.arange(0, 16).reshape(4, 4) 69 | // first row 70 | assertEquals(tensor(0), Tensor(Seq(0, 1, 2, 3))) 71 | // first column 72 | assertEquals(tensor(torch.Slice(), 0), Tensor(Seq(0, 4, 8, 12))) 73 | // last column 74 | assertEquals(tensor(---, -1), Tensor(Seq(3, 7, 11, 15))) 75 | } 76 | 77 | test("update/setter") { 78 | val tensor = torch.arange(0, 16).reshape(4, 4) 79 | tensor(Seq(0)) = 20 80 | assertEquals(tensor(0), torch.full(Seq(4), 20)) 81 | 82 | val updated = Tensor(30) 83 | tensor(Seq(1, 0)) = Tensor(30) 84 | assertEquals(tensor(1, 0), updated) 85 | 86 | // copy column 1 to column 0 87 | tensor(Seq(torch.Slice(), 1)) = tensor(torch.Slice(), 0) 88 | assertEquals(tensor(torch.Slice(), 1), tensor(torch.Slice(), 0)) 89 | } 90 | 91 | test("Tensor creation properly handling buffers") { 92 | val value = 100L 93 | val data = Seq.fill(10000)(value) 94 | val tensors = 1.to(1000).map { _ => 95 | Tensor(data) 96 | } 97 | assert( 
98 | tensors.forall { t => 99 | t.min().item == value && 100 | t.max().item == value 101 | } 102 | ) 103 | } 104 | 105 | test("repeat") { 106 | val t = torch.Tensor(Seq(1, 2, 3)) 107 | val repeated = t.repeat(4, 2) 108 | 109 | val repeatCols = torch.cat(Seq(t, t)) 110 | val repeatRows = torch.stack(Seq.fill(4)(repeatCols)) 111 | 112 | assert(repeated equal repeatRows) 113 | 114 | assertEquals(t.repeat(4, 2, 1).size, Seq(4, 2, 3)) 115 | } 116 | 117 | test("trace") { 118 | val t = torch.eye(3) 119 | assertEquals(t.trace, Tensor(3f)) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/ConvolutionSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | import functional as F 20 | 21 | class ConvolutionSuite extends munit.FunSuite { 22 | 23 | test("mismatchShapeConv2d") { 24 | val dtypes = List[FloatNN | ComplexNN](torch.float32, torch.complex64) 25 | for (dtype <- dtypes) { 26 | val x = torch.randn(Seq(1, 10, 1, 28, 28), dtype) 27 | val w = torch.randn(Seq(6, 1, 5, 5), dtype) 28 | 29 | intercept[RuntimeException](F.conv2d(x, w)) 30 | // TODO find a way to run interceptMessage comparing only the first line/string prefix as we don't care about the c++ stacktrace here 31 | // interceptMessage[RuntimeException] { 32 | // """Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 10, 1, 28, 28]""" 33 | // } { 34 | // conv2d(x, w) 35 | // } 36 | } 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/PoolingSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
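`ConvolutionSuite` above checks the shape validation of `conv2d`; for reference, a well-formed call looks like this (valid cross-correlation without padding, so each spatial dimension shrinks by `kernel - 1`):

```scala
import torch.*
import torch.nn.functional as F

val x = torch.randn(Seq(1, 1, 28, 28)) // batched input: [N, C, H, W]
val w = torch.randn(Seq(6, 1, 5, 5))   // 6 output channels, 5x5 kernels
val y = F.conv2d(x, w)                 // shape Seq(1, 6, 24, 24), since 28 - 5 + 1 = 24
```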
15 | */ 16 | 17 | package torch 18 | package nn 19 | 20 | import functional as F 21 | import org.scalacheck.Gen 22 | import torch.Generators.allDTypes 23 | 24 | class PoolingSuite extends TensorCheckSuite { 25 | test("MaxPool2d output shapes") { 26 | val input = torch.randn(Seq(1, 3, 244, 244)) 27 | // pool of square window of size=3, stride=2 28 | val m1 = MaxPool2d[Float32](3, stride = Some(2)) 29 | assertEquals(m1(input).shape, Seq(1, 3, 121, 121)) 30 | // pool of non-square window 31 | val m2 = MaxPool2d[Float32]((3, 2), stride = Some(2, 1)) 32 | assertEquals(m2(input).shape, Seq(1, 3, 121, 243)) 33 | val m3 = MaxPool2d[Float32](3) 34 | assertEquals(m3(input).shape, Seq(1, 3, 81, 81)) 35 | } 36 | 37 | val shape3d = Seq(16, 50, 32) 38 | propertyTestUnaryOp(F.avgPool1d(_, 3), "avgPool1d", genRandTensor(shape3d)) 39 | propertyTestUnaryOp(F.maxPool1d(_, 3), "maxPool1d", genRandTensor(shape3d)) 40 | propertyTestUnaryOp(F.maxPool1dWithIndices(_, 3), "maxPool1dWithIndices", genRandTensor(shape3d)) 41 | 42 | inline def genRandTensor[D <: FloatNN | ComplexNN](shape: Seq[Int] = Seq(3, 3)): Gen[Tensor[D]] = 43 | Gen.oneOf(allDTypes.filter(_.isInstanceOf[D])).map { dtype => 44 | torch.rand(shape, dtype = dtype.asInstanceOf[D]) 45 | } 46 | 47 | val shape4d = Seq(8, 16, 50, 32) 48 | propertyTestUnaryOp(F.avgPool2d(_, 3), "avgPool2d", genRandTensor(shape4d)) 49 | propertyTestUnaryOp(F.maxPool2d(_, 3), "maxPool2d", genRandTensor(shape4d)) 50 | propertyTestUnaryOp(F.maxPool2dWithIndices(_, 3), "maxPool2dWithIndices", genRandTensor(shape4d)) 51 | 52 | val shape5d = Seq(2, 16, 50, 44, 31) 53 | propertyTestUnaryOp( 54 | F.avgPool3d(_, (3, 2, 2), stride = (2, 1, 2)), 55 | "avgPool3d", 56 | genRandTensor(shape5d) 57 | ) 58 | propertyTestUnaryOp( 59 | F.maxPool3d(_, (3, 2, 2), stride = (2, 1, 2)), 60 | "maxPool3d", 61 | genRandTensor(shape5d) 62 | ) 63 | propertyTestUnaryOp( 64 | F.maxPool3dWithIndices(_, (3, 2, 2), stride = (2, 1, 2)), 65 | "maxPool3dWithIndices", 66 | genRandTensor(shape5d) 67 | ) 68 | } 69 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/functional/SparseSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
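The expected shapes asserted in `PoolingSuite` above follow the usual pooling shape rule; a small helper to check them (floor division, i.e. `ceilMode = false`):

```scala
// output size per spatial dimension of a pooling layer
def poolOutputSize(size: Int, kernelSize: Int, stride: Int, padding: Int = 0, dilation: Int = 1): Int =
  (size + 2 * padding - dilation * (kernelSize - 1) - 1) / stride + 1

poolOutputSize(244, kernelSize = 3, stride = 2) // 121, as asserted for m1
poolOutputSize(244, kernelSize = 3, stride = 3) // 81, as asserted for m3 (stride defaults to kernelSize)
```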
15 | */ 16 | 17 | package torch 18 | package nn 19 | package functional 20 | 21 | import Generators.genTensor 22 | 23 | class SparseSuite extends TensorCheckSuite { 24 | 25 | // TODO Test multi-dimensional tensors 26 | testUnaryOp( 27 | op = nn.functional.oneHot(_, numClasses = 6), 28 | opName = "nn.functional.oneHot", 29 | inputTensor = Tensor(3L), 30 | expectedTensor = Tensor(Seq(0L, 0L, 0L, 1L, 0L, 0L)), 31 | // TODO Fix genTensor for cases where the tensor type is not a union, but a concrete one, such as Tensor[Int64] 32 | genTensor = genTensor[Int64](filterDTypes = true) 33 | ) 34 | 35 | } 36 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/modules/ActivationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package modules 20 | 21 | class ActivationSuite extends munit.FunSuite { 22 | test("LogSoftmax") { 23 | torch.manualSeed(0) 24 | val m = nn.LogSoftmax(dim = 1) 25 | val input = torch.randn(Seq(2, 3)) 26 | val output = m(input) 27 | assertEquals(output.shape, input.shape) 28 | val expectedOutput = Tensor( 29 | Seq( 30 | Seq(-0.1689f, -2.0033f, -3.8886f), 31 | Seq(-0.2862f, -1.9392f, -2.2532f) 32 | ) 33 | ) 34 | assert(torch.allclose(output, expectedOutput, atol = 1e-4)) 35 | } 36 | 37 | // TODO ReLU 38 | // TODO Softmax 39 | 40 | test("Tanh") { 41 | torch.manualSeed(0) 42 | val m = nn.Tanh() 43 | val input = torch.randn(Seq(2)) 44 | val output = m(input) 45 | assertEquals(output.shape, input.shape) 46 | val expectedOutput = Tensor(Seq(0.9123f, -0.2853f)) 47 | assert(torch.allclose(output, expectedOutput, atol = 1e-4)) 48 | } 49 | 50 | } 51 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/modules/BatchNormSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
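`ActivationSuite` above pins down concrete outputs; a complementary consistency sketch relating the two softmax modules (assuming `Softmax` takes a `dim` argument like `LogSoftmax`):

```scala
import torch.*

val input = torch.randn(Seq(2, 3))
val logSoftmax = nn.LogSoftmax(dim = 1)
val softmax = nn.Softmax(dim = 1)
// the fused LogSoftmax should agree with log(softmax(x)) up to numerical error
assert(torch.allclose(logSoftmax(input), softmax(input).log, atol = 1e-6))
```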
15 | */ 16 | 17 | package torch 18 | package nn 19 | package modules 20 | 21 | class BatchNormSuite extends munit.FunSuite { 22 | 23 | test("BatchNorm1d") { 24 | torch.manualSeed(0) 25 | val m = nn.BatchNorm1d(numFeatures = 3) 26 | val input = torch.randn(Seq(3, 3)) 27 | val output = m(input) 28 | assertEquals(output.shape, input.shape) 29 | val expectedOutput = Tensor( 30 | Seq( 31 | Seq(1.4014f, -0.1438f, -1.2519f), 32 | Seq(-0.5362f, -1.1465f, 0.0564f), 33 | Seq(-0.8651f, 1.2903f, 1.1956f) 34 | ) 35 | ) 36 | assert(torch.allclose(output, expectedOutput, atol = 1e-4)) 37 | } 38 | 39 | test("BatchNorm2d") { 40 | torch.manualSeed(0) 41 | val m = nn.BatchNorm2d(numFeatures = 3) 42 | val input = torch.randn(Seq(3, 3, 1, 1)) 43 | val output = m(input) 44 | assertEquals(output.shape, input.shape) 45 | val expectedOutput = Tensor( 46 | Seq( 47 | Seq(1.4014f, -0.1438f, -1.2519f), 48 | Seq(-0.5362f, -1.1465f, 0.0564f), 49 | Seq(-0.8651f, 1.2903f, 1.1956f) 50 | ) 51 | ) 52 | assert(torch.allclose(output.squeeze, expectedOutput, atol = 1e-4)) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/modules/EmbeddingSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
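In training mode, the `BatchNorm1d` exercised above standardizes each feature over the batch, `(x - mean) / sqrt(var + eps)`, then applies the learnable affine transform, which starts out as the identity (gamma = 1, beta = 0). A sketch of the resulting invariant:

```scala
import torch.*

val x = torch.randn(Seq(8, 3))
val bn = nn.BatchNorm1d(numFeatures = 3)
val y = bn(x)
// with the initial identity affine parameters, the normalized output has
// (up to floating point error) zero mean
assert(torch.allclose(y.mean, Tensor(0f), atol = 1e-5))
```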
15 | */ 16 | 17 | package torch 18 | package nn 19 | package modules 20 | 21 | class EmbeddingSuite extends munit.FunSuite { 22 | 23 | test("Embedding") { 24 | { 25 | torch.manualSeed(0) 26 | val embedding = nn.Embedding(10, 3) 27 | // a batch of 2 samples of 4 indices each 28 | val input = torch.Tensor(Seq(Seq(1L, 2, 4, 5), Seq(4L, 3, 2, 9))) 29 | val output = embedding(input) 30 | val expectedOutput = Tensor( 31 | Seq( 32 | Seq( 33 | Seq(-0.4339f, 0.8487f, 0.6920f), 34 | Seq(-0.3160f, -2.1152f, 0.3223f), 35 | Seq(0.1198f, 1.2377f, -0.1435f), 36 | Seq(-0.1116f, -0.6136f, 0.0316f) 37 | ), 38 | Seq( 39 | Seq(0.1198f, 1.2377f, -0.1435f), 40 | Seq(-1.2633f, 0.3500f, 0.3081f), 41 | Seq(-0.3160f, -2.1152f, 0.3223f), 42 | Seq(0.0525f, 0.5229f, 2.3022f) 43 | ) 44 | ) 45 | ) 46 | assert(torch.allclose(output, expectedOutput, atol = 1e-4)) 47 | } 48 | { 49 | torch.manualSeed(0) 50 | // example with padding_idx 51 | val embedding = nn.Embedding(5, 3, paddingIdx = Some(0)) 52 | embedding.weight = Tensor( 53 | Seq( 54 | Seq(0f, 0f, 0f), 55 | Seq(0.5684f, -1.0845f, -1.3986f), 56 | Seq(0.4033f, 0.8380f, -0.7193f), 57 | Seq(0.4033f, 0.8380f, -0.7193f), 58 | Seq(-0.8567f, 1.1006f, -1.0712f) 59 | ) 60 | ) 61 | val input = torch.Tensor(Seq(Seq(0L, 2, 0, 4))) 62 | val output = embedding(input) 63 | 64 | val expectedOutput = Tensor( 65 | Seq( 66 | Seq(0f, 0f, 0f), 67 | Seq(0.4033f, 0.8380f, -0.7193f), 68 | Seq(0f, 0f, 0f), 69 | Seq(-0.8567f, 1.1006f, -1.0712f) 70 | ) 71 | ).unsqueeze(0) 72 | assert(torch.allclose(output, expectedOutput, atol = 1e-4)) 73 | } 74 | { 75 | torch.manualSeed(0) 76 | // example of changing `pad` vector 77 | val paddingIdx = 0 78 | val embedding = nn.Embedding(3, 3, paddingIdx = Some(paddingIdx)) 79 | noGrad { 80 | embedding.weight(Seq(paddingIdx)) = torch.ones(3) 81 | } 82 | val expectedOutput = Tensor( 83 | Seq( 84 | Seq(1f, 1f, 1f), 85 | Seq(0.5684f, -1.0845f, -1.3986f), 86 | Seq(0.4033f, 0.8380f, -0.7193f) 87 | ) 88 | ) 89 | assert(torch.allclose(embedding.weight, expectedOutput, atol = 1e-4)) 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/modules/FlattenSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package torch 18 | package nn 19 | package modules 20 | 21 | class FlattenSuite extends munit.FunSuite { 22 | test("Flatten") { 23 | val input = torch.randn(Seq(32, 1, 5, 5)) 24 | val m1 = nn.Flatten() 25 | val o1 = m1(input) 26 | assertEquals(o1.shape, Seq(32, 25)) 27 | assert(input.reshape(32, 25).equal(o1)) 28 | val m2 = nn.Flatten(0, 2) 29 | val o2 = m2(input) 30 | assertEquals(o2.shape, Seq(160, 5)) 31 | assert(input.reshape(160, 5).equal(o2)) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/modules/LinearSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package modules 20 | 21 | class LinearSuite extends munit.FunSuite { 22 | test("Linear shape") { 23 | val linear = Linear(20, 30) 24 | val input = randn(Seq(128, 20)) 25 | assertEquals(linear(input).shape, Seq(128, 30)) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/modules/NormalizationSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | package nn 19 | package modules 20 | 21 | class NormalizationSuite extends munit.FunSuite { 22 | 23 | test("LayerNorm") { 24 | { 25 | torch.manualSeed(0) 26 | val (batch, sentenceLength, embeddingDim) = (2, 2, 3) 27 | val embedding = torch.randn(Seq(batch, sentenceLength, embeddingDim)) 28 | val layerNorm = nn.LayerNorm(Seq(embeddingDim)) 29 | val output = layerNorm(embedding) 30 | assertEquals(output.shape, embedding.shape) 31 | val expectedOutput = Tensor( 32 | Seq( 33 | Seq( 34 | Seq(1.2191f, 0.0112f, -1.2303f), 35 | Seq(1.3985f, -0.5172f, -0.8813f) 36 | ), 37 | Seq( 38 | Seq(0.3495f, 1.0120f, -1.3615f), 39 | Seq(-0.3948f, -0.9786f, 1.3734f) 40 | ) 41 | ) 42 | ) 43 | assert(torch.allclose(output, expectedOutput, atol = 1e-4)) 44 | } 45 | { 46 | torch.manualSeed(0) 47 | val (n, c, h, w) = (1, 2, 2, 2) 48 | val input = torch.randn(Seq(n, c, h, w)) 49 | // Normalize over the last three dimensions (i.e. 
the channel and spatial dimensions)
50 |       val layerNorm = nn.LayerNorm(Seq(c, h, w))
51 |       val output = layerNorm(input)
52 |       assertEquals(output.shape, Seq(n, c, h, w))
53 |       val expectedOutput = Tensor(
54 |         Seq(
55 |           Seq(
56 |             Seq(1.4715f, -0.0785f),
57 |             Seq(-1.6714f, 0.6497f)
58 |           ),
59 |           Seq(
60 |             Seq(-0.7469f, -1.0122f),
61 |             Seq(0.5103f, 0.8775f)
62 |           )
63 |         )
64 |       ).unsqueeze(0)
65 |       assert(torch.allclose(output, expectedOutput, atol = 1e-4))
66 |     }
67 |   }
68 | 
69 | }
70 | 
-------------------------------------------------------------------------------- /core/src/test/scala/torch/nn/modules/PoolingSuite.scala: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | package nn
19 | package modules
20 | 
21 | class AdaptiveAvgPool2dSuite extends munit.FunSuite {
22 |   test("AdaptiveAvgPool2d output shapes") {
23 |     val m1 = AdaptiveAvgPool2d((5, 7))
24 |     val input = torch.randn(Seq(1, 64, 8, 9))
25 |     assertEquals(m1(input).shape, Seq(1, 64, 5, 7))
26 |     val m2 = nn.AdaptiveAvgPool2d((1, 1))
27 |     assertEquals(m2(input).shape, Seq(1, 64, 1, 1))
28 |   }
29 | }
30 | 
-------------------------------------------------------------------------------- /core/src/test/scala/torch/ops/ComparisonOpsSuite.scala: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | 
19 | class ComparisonOpsSuite extends TensorCheckSuite {
20 | 
21 |   testUnaryOp(
22 |     op = argsort(_),
23 |     opName = "argsort",
24 |     inputTensor = Tensor(Seq(1, 3, 2)),
25 |     expectedTensor = Tensor(Seq(0L, 2L, 1L))
26 |   )
27 | 
28 | }
29 | 
-------------------------------------------------------------------------------- /core/src/test/scala/torch/ops/CreationOpsSuite.scala: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | 
19 | import org.scalacheck.Prop.*
20 | import Generators.*
21 | 
22 | class CreationOpsSuite extends TensorCheckSuite {
23 | 
24 |   test("arange.unit-test") {
25 |     val t0 = arange(0, 10)
26 |     assertEquals(t0.toSeq, Seq.range(0, 10))
27 |     val t1 = arange(0, 10, 2)
28 |     assertEquals(t1.toSeq, Seq.range(0, 10, 2))
29 |   }
30 | 
31 |   property("ones.property-test") {
32 |     forAll(genTensorSize, genDType) { (size, dtype) =>
33 |       val t = ones(size, dtype)
34 |       assertEquals(t.dtype, dtype)
35 |       assertEquals(t.size, size)
36 |       assertEquals(t.numel, size.product.toLong)
37 |       assertEquals(t.toSeq.length, size.product.toInt)
38 |     }
39 |   }
40 | 
41 |   test("ones.unit-test") {
42 |     val t = ones[Float32](Seq(2, 3))
43 |     assertEquals(t.size, Seq(2, 3))
44 |     assertEquals(t.numel, 2L * 3)
45 |     assertEquals(t.toSeq, Seq.fill[Float](2 * 3)(1f))
46 |   }
47 | 
48 | }
49 | 
-------------------------------------------------------------------------------- /core/src/test/scala/torch/ops/OtherOpsSuite.scala: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | package torch
18 | package ops
19 | 
20 | class OtherOpsSuite extends TensorCheckSuite {
21 |   test("einsum") {
22 |     // trace
23 |     assert(torch.einsum("ii", torch.eye(5)).item == 5f)
24 |     val a = torch.arange(end = 25).reshape(5, 5)
25 |     val b = torch.arange(end = 5)
26 |     assert(torch.einsum("ii", a) equal torch.trace(a))
27 |     // diagonal
28 |     assert(torch.einsum("ii->i", a) equal Tensor(Seq(0, 6, 12, 18, 24)))
29 |     // inner product
30 |     assert(torch.einsum("i,i", b, b) equal Tensor(30))
31 |     // matrix vector multiplication
32 |     assert(torch.einsum("ij,j", a, b) equal Tensor(Seq(30, 80, 130, 180, 230)))
33 |   }
34 | 
35 |   test("trace") {
36 |     val t = torch.arange(1f, 10f).view(3, 3)
37 |     assert(torch.trace(t) equal Tensor(15f))
38 |   }
39 | }
40 | 
-------------------------------------------------------------------------------- /core/src/test/scala/torch/ops/RandomSamplingOpsSuite.scala: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torch 18 | 19 | class RandomSamplingOpsSuite extends TensorCheckSuite { 20 | 21 | testUnaryOp( 22 | op = multinomial(_, 2, true), 23 | opName = "multinomial", 24 | inputTensor = Tensor(Seq(0.0, 0.0, 0.0, 1.0)), 25 | expectedTensor = Tensor(Seq(3L, 3L)) 26 | ) 27 | 28 | test("randint.unit-test") { 29 | // randint generates uniform numbers in the range [min, max) 30 | val low = 0 31 | val high = 4 32 | val randintTensor = randint(low, high + 1, Seq(100000)).to(dtype = float32) 33 | val randintMean = randintTensor.mean 34 | val expectedMean = Tensor(high / 2).to(dtype = float32) 35 | 36 | assert(allclose(randintMean, expectedMean, atol = 1e-2)) 37 | 38 | val g1 = torch.Generator() 39 | g1.manualSeed(0) 40 | val t1 = torch.randint(high = 100, Seq(2, 2), generator = g1) 41 | val t2 = torch.randint(high = 100, Seq(2, 2), generator = g1) 42 | assertNotEquals(t1, t2) 43 | 44 | val g2 = torch.Generator() 45 | g2.manualSeed(0) 46 | val t3 = torch.randint(high = 100, Seq(2, 2), generator = g2) 47 | assertEquals(t1, t3) 48 | 49 | } 50 | 51 | test("randn.unit-test") { 52 | val randnTensor = randn(Seq(100000)) 53 | val randnMean = randnTensor.mean 54 | val expectedMean = Tensor(0.0).to(dtype = float32) 55 | val randnVariance = randnTensor.variance 56 | val expectedVariance = Tensor(1.0).to(dtype = float32) 57 | 58 | assert( 59 | allclose(randnMean, expectedMean, atol = 1e-2) && 60 | allclose(randnVariance, expectedVariance, atol = 1e-2) 61 | ) 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /devenv.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "devenv": { 4 | "locked": { 5 | "dir": "src/modules", 6 | "lastModified": 1677225427, 7 | "narHash": "sha256-+M4LGzGVAhkM0HOy2lexDUBwlLnv8UUX32utlSFaMc4=", 8 | "owner": "cachix", 9 | "repo": "devenv", 10 | "rev": "1aa3dbbf745f50e77aadc016d124fed3ca2dc9be", 11 | "type": "github" 12 | }, 13 | "original": { 14 | "dir": "src/modules", 15 | "owner": "cachix", 16 | "repo": "devenv", 17 | "type": "github" 18 | } 19 | }, 20 | "flake-compat": { 21 | "flake": false, 22 | "locked": { 23 | "lastModified": 1673956053, 24 | "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", 25 | "owner": "edolstra", 26 | "repo": "flake-compat", 27 | "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "edolstra", 32 | "repo": "flake-compat", 33 | "type": "github" 34 | } 35 | }, 36 | "flake-utils": { 37 | "locked": { 38 | "lastModified": 1667395993, 39 | "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", 40 | "owner": "numtide", 41 | "repo": "flake-utils", 42 | "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", 43 | "type": "github" 44 | }, 45 | "original": { 46 | "owner": "numtide", 47 | "repo": "flake-utils", 48 | "type": "github" 49 | } 50 | }, 51 | "gitignore": { 52 | "inputs": { 53 | "nixpkgs": [ 54 | "pre-commit-hooks", 55 | "nixpkgs" 56 | ] 57 | }, 58 | "locked": { 59 | "lastModified": 1660459072, 60 | "narHash": 
"sha256-8DFJjXG8zqoONA1vXtgeKXy68KdJL5UaXR8NtVMUbx8=", 61 | "owner": "hercules-ci", 62 | "repo": "gitignore.nix", 63 | "rev": "a20de23b925fd8264fd7fad6454652e142fd7f73", 64 | "type": "github" 65 | }, 66 | "original": { 67 | "owner": "hercules-ci", 68 | "repo": "gitignore.nix", 69 | "type": "github" 70 | } 71 | }, 72 | "nixpkgs": { 73 | "locked": { 74 | "lastModified": 1677352614, 75 | "narHash": "sha256-VYo1cSiCHDXZrHO8pb0c9EGob7C75lCPx1jBMi9UAlU=", 76 | "owner": "NixOS", 77 | "repo": "nixpkgs", 78 | "rev": "bf592ea571b11dfee17a74d022f0b481ca5f1319", 79 | "type": "github" 80 | }, 81 | "original": { 82 | "owner": "NixOS", 83 | "ref": "nixpkgs-unstable", 84 | "repo": "nixpkgs", 85 | "type": "github" 86 | } 87 | }, 88 | "nixpkgs-stable": { 89 | "locked": { 90 | "lastModified": 1673800717, 91 | "narHash": "sha256-SFHraUqLSu5cC6IxTprex/nTsI81ZQAtDvlBvGDWfnA=", 92 | "owner": "NixOS", 93 | "repo": "nixpkgs", 94 | "rev": "2f9fd351ec37f5d479556cd48be4ca340da59b8f", 95 | "type": "github" 96 | }, 97 | "original": { 98 | "owner": "NixOS", 99 | "ref": "nixos-22.11", 100 | "repo": "nixpkgs", 101 | "type": "github" 102 | } 103 | }, 104 | "pre-commit-hooks": { 105 | "inputs": { 106 | "flake-compat": "flake-compat", 107 | "flake-utils": "flake-utils", 108 | "gitignore": "gitignore", 109 | "nixpkgs": [ 110 | "nixpkgs" 111 | ], 112 | "nixpkgs-stable": "nixpkgs-stable" 113 | }, 114 | "locked": { 115 | "lastModified": 1677160285, 116 | "narHash": "sha256-tBzpCjMP+P3Y3nKLYvdBkXBg3KvTMo3gvi8tLQaqXVY=", 117 | "owner": "cachix", 118 | "repo": "pre-commit-hooks.nix", 119 | "rev": "2bd861ab81469428d9c823ef72c4bb08372dd2c4", 120 | "type": "github" 121 | }, 122 | "original": { 123 | "owner": "cachix", 124 | "repo": "pre-commit-hooks.nix", 125 | "type": "github" 126 | } 127 | }, 128 | "root": { 129 | "inputs": { 130 | "devenv": "devenv", 131 | "nixpkgs": "nixpkgs", 132 | "pre-commit-hooks": "pre-commit-hooks" 133 | } 134 | } 135 | }, 136 | "root": "root", 137 | "version": 7 138 | } 139 | -------------------------------------------------------------------------------- /devenv.nix: -------------------------------------------------------------------------------- 1 | { pkgs, inputs, ... }: 2 | 3 | let 4 | packages = if pkgs.stdenv.isDarwin 5 | then inputs.nixpkgs.legacyPackages.x86_64-darwin 6 | else pkgs; 7 | in 8 | { 9 | packages = with packages; [ 10 | sbt 11 | ]; 12 | 13 | scripts.hello.exec = "echo ---STORCH---"; 14 | 15 | enterShell = '' 16 | hello 17 | ''; 18 | } 19 | -------------------------------------------------------------------------------- /devenv.yaml: -------------------------------------------------------------------------------- 1 | inputs: 2 | nixpkgs: 3 | url: github:NixOS/nixpkgs/nixpkgs-unstable 4 | # can't point to the local modules here as it's used as a template -------------------------------------------------------------------------------- /docs/about.md: -------------------------------------------------------------------------------- 1 | # About 2 | 3 | Storch is a Scala library for fast tensor computations and deep learning, based on [PyTorch](https://pytorch.org/). 4 | 5 | Like PyTorch, Storch provides 6 | 7 | * A NumPy like API for working with tensors 8 | * GPU support 9 | * Automatic differentiation 10 | * A neural network API for building and training neural networks. 11 | 12 | Storch aims to stay close to the original PyTorch API to make porting existing models and the life of people already familiar with PyTorch easier. 
13 | 
14 | ```scala mdoc:invisible
15 | torch.manualSeed(0)
16 | ```
17 | 
18 | ```scala mdoc
19 | val data = Seq(0,1,2,3)
20 | val t1 = torch.Tensor(data)
21 | t1.equal(torch.arange(0,4))
22 | val t2 = t1.to(dtype=torch.float32)
23 | val t3 = t1 + t2
24 | 
25 | val shape = Seq(2,3)
26 | val randTensor = torch.rand(shape)
27 | val zerosTensor = torch.zeros(shape, dtype=torch.int64)
28 | 
29 | val x = torch.ones(Seq(5))
30 | val w = torch.randn(Seq(5, 3), requiresGrad=true)
31 | val b = torch.randn(Seq(3), requiresGrad=true)
32 | val z = (x matmul w) + b
33 | ```
34 | 
35 | One notable difference is that tensors in Storch are statically typed regarding the underlying `dtype`.
36 | So you'll see `Tensor[Float32]` or `Tensor[Int8]` instead of just `Tensor`.
37 | 
38 | Tracking the data type at compile time enables us to catch certain errors earlier. For instance, `torch.rand` is only implemented for float types and the following will trigger a runtime error in PyTorch:
39 | ```python
40 | torch.rand([3,3], dtype=torch.int32) # RuntimeError: "check_uniform_bounds" not implemented for 'Int'
41 | ```
42 | 
43 | In Storch, the same code does not compile:
44 | ```scala mdoc:fail
45 | torch.rand(Seq(3,3), dtype=torch.int32)
46 | ```
47 | 
48 | Storch is powered by [LibTorch](https://pytorch.org/cppdocs/index.html), the C++ library underlying PyTorch, and
49 | [JavaCPP](https://github.com/bytedeco/javacpp), which provides generated Java bindings for LibTorch as well as important utilities to integrate with native code on the JVM.
-------------------------------------------------------------------------------- /docs/contributing.md: --------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | Contributions to Storch are most welcome. Please have a look at the existing
4 | [issues](https://github.com/sbrunk/storch/issues) or don't hesitate to open a new one
5 | to discuss ideas.
6 | 
7 | ## Development setup
8 | 
9 | You need to have [sbt](https://www.scala-sbt.org/) installed to build Storch from source.
10 | 
11 | Then clone the git repo and you should be ready to go.
12 | ```bash
13 | git clone https://github.com/sbrunk/storch
14 | sbt
15 | ```
16 | 
17 | ## Enabling GPU support during development
18 | 
19 | GPU support is disabled by default. To enable it, set the following parameter inside the sbt shell:
20 | 
21 | ```scala
22 | set ThisBuild / enableGPU := true
23 | ```
24 | 
25 | You can verify that the GPU is working by running the LeNet example.
26 | If it's working, you should see output like this:
27 | 
28 | ```
29 | sbt examples/runMain LeNetApp
30 | [info] running (fork) LeNetApp
31 | [info] Using device: Device(CUDA,-1)
32 | ...
33 | ```
34 | 
35 | ## Edit the documentation
36 | 
37 | Documentation sources live in the *docs* directory.
38 | We use [mdoc](https://scalameta.org/mdoc/) for typechecked documentation and to embed code output.
39 | The website is rendered by [Laika](https://typelevel.org/Laika/).
40 | To build the documentation locally, you can run the following command:
41 | 
42 | ```bash
43 | sbt ~tlSitePreview
44 | ```
45 | 
46 | Then open http://localhost:4242 in a browser and enjoy a live preview while hacking on the docs.
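
For example, a fenced block marked `scala mdoc` in a page under *docs* is compiled against the Storch API during the site build, and its evaluated output is embedded into the rendered page. A minimal sketch of such a block:

```scala mdoc
val t = torch.ones(Seq(2, 2))
t.shape
```

If a snippet stops compiling, for instance after an API change, the docs build fails, so the published examples can't silently go stale.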
47 | 
48 | To just build Scaladoc for all modules, you can run
49 | 
50 | ```bash
51 | sbt ~unidoc
52 | ```
53 | 
54 | ## Linting
55 | 
56 | Manually run headerCreate + scalafmt on all files:
57 | 
58 | ```bash
59 | sbt 'headerCreateAll ; scalafmtAll'
60 | ```
61 | 
62 | Add useful git pre-push linting checks:
63 | 
64 | ```bash
65 | cp git-hooks/pre-push-checks .git/hooks/ && chmod +x git-hooks/pre-push-checks
66 | ```
67 | 
68 | ## Optional: Install dependencies via nix + devenv
69 | 
70 | You can use nix and devenv to set up your development environment, but it's not required.
71 | 
72 | 1. Install the [nix](https://nixos.org) package manager
73 | 
74 |    ```bash
75 |    sh <(curl -L https://nixos.org/nix/install) --daemon
76 |    ```
77 | 
78 |    For more info, see https://nixos.org/download.html
79 | 
80 | 2. Install [devenv](https://devenv.sh)
81 | 
82 |    ```bash
83 |    nix profile install --accept-flake-config github:cachix/devenv/latest
84 |    ```
85 | 
86 |    For more info, see: https://devenv.sh/getting-started/#installation
87 | 
88 | 3. (Optional) Install [direnv](https://direnv.net/)
89 | 
90 |    This will load the project-specific environment variables upon `cd` into the storch folder.
91 | 
92 |    ```bash
93 |    nix profile install 'nixpkgs#direnv'
94 |    ```
95 | 
96 | 4. Load the environment
97 | 
98 |    If you did not install direnv, run the following in the `storch` root folder:
99 | 
100 |    ```bash
101 |    devenv shell
102 |    ```
103 | 
104 |    If you installed direnv, just `cd` into storch.
-------------------------------------------------------------------------------- /docs/directory.conf: --------------------------------------------------------------------------------
1 | laika.navigationOrder = [
2 |   about.md
3 |   installation.md
4 |   modules.md
5 |   examples.md
6 |   pre-trained-weights.md
7 |   faq.md
8 |   contributing.md
9 | ]
-------------------------------------------------------------------------------- /docs/examples.md: --------------------------------------------------------------------------------
1 | # Examples
2 | 
3 | You can find runnable examples in the [examples](https://github.com/sbrunk/storch/tree/main/examples/src/main/scala) directory.
4 | We are planning to add more examples for different tasks.
5 | If you have an idea or want to contribute or improve an example, please don't hesitate to open an issue or create a PR.
6 | 
7 | ## Running the examples
8 | 
9 | You can clone the repository and run the examples as part of the build, for instance:
10 | 
11 | ```bash
12 | sbt> examples/runMain LeNetApp
13 | ```
14 | 
15 | The examples are also scala-cli scripts, so you can run them with [scala-cli](https://scala-cli.virtuslab.org/), either locally or directly from the repo:
16 | 
17 | ```bash
18 | scala-cli https://raw.githubusercontent.com/sbrunk/storch/main/examples/src/main/scala/ImageClassifier.scala
19 | ```
20 | 
21 | ## Image classifier example
22 | 
23 | Example script for training an image-classification model on your own images and running inference.
24 | It uses the [ResNet](https://github.com/sbrunk/storch/blob/main/vision/src/main/scala/torchvision/models/resnet.scala) model implementation.
25 | 
26 | It will also automatically download converted pre-trained weights from the [releases](https://github.com/sbrunk/storch/releases/tag/pretrained-weights). See [converting pre-trained weights from PyTorch](pre-trained-weights.md) for details.
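
For reference, loading such a converted checkpoint from Scala could look roughly like the following. This is a sketch only: it assumes JavaCPP's `InputArchive` mirrors the `OutputArchive` save call used in `LeNet.scala`, that the model's module exposes a matching `load`, and `resnet50.pt` is a hypothetical file name.

```scala
import org.bytedeco.pytorch.InputArchive

// load converted weights from a serialized archive (sketch)
val archive = new InputArchive()
archive.load_from("resnet50.pt") // hypothetical path to converted weights
model.load(archive)              // assumes a load counterpart to the save shown in LeNet.scala
```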
27 | 
28 | ### Training
29 | 
30 | To train a new image classifier on your own images run:
31 | 
32 | ```bash
33 | scala-cli ImageClassifier.scala -- train --dataset-dir <path-to-dataset>
34 | ```
35 | 
36 | Where the expected dataset is a directory per class with examples, like this:
37 | ```
38 | .
39 | ├── PetImages
40 |     ├── Cat
41 |     │   ├── 1.jpg
42 |     │   ├── 2.jpg
43 |     │   ├── ...
44 |     └── Dog
45 |         ├── 1.jpg
46 |         ├── 2.jpg
47 |         ├── ...
48 | ```
49 | 
50 | Using a smaller base model:
51 | ```bash
52 | scala-cli ImageClassifier.scala -- train --base-model ResNet18 --dataset-dir <path-to-dataset>
53 | ```
54 | 
55 | To see all options, run:
56 | ```bash
57 | scala-cli ImageClassifier.scala -- train -h
58 | ```
59 | 
60 | #### Training on the GPU
61 | 
62 | Right now, if you're using scala-cli, you have to edit the directives at the top of the `ImageClassifier.scala`
63 | script to enable GPU support (see comments in the script).
64 | We're looking for ways to make this easier in the future.
65 | 
66 | ### Inference
67 | 
68 | Once you've trained a model, you can use it for predictions:
69 | ```bash
70 | scala-cli ImageClassifier.scala -- predict --image-path <path-to-image>
71 | ```
72 | 
73 | If you don't have your own images, you can use an example dataset, for instance the [Cat vs. Dog dataset](https://www.kaggle.com/datasets/karakaggle/kaggle-cat-vs-dog-dataset) (alternative [download](https://www.microsoft.com/en-us/download/details.aspx?id=54765) that doesn't require a Kaggle account), which is already in the right format.
74 | 
-------------------------------------------------------------------------------- /docs/faq.md: --------------------------------------------------------------------------------
1 | # Frequently Asked Questions
2 | 
3 | ## Q: I want to run operations on the GPU, but Storch seems to hang?
4 | 
5 | Depending on your hardware, the CUDA version and capability settings, CUDA might need to do
6 | [just-in-time compilation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#just-in-time-compilation)
7 | of your kernels, which can take a few minutes. The result is cached, so it should load faster on subsequent runs.
8 | 
9 | If you're unsure, you can watch the size of the cache:
10 | 
11 | ```bash
12 | watch -d du -sm ~/.nv/ComputeCache
13 | ```
14 | If it's still growing, it's very likely that CUDA is doing just-in-time compilation.
15 | 
16 | You can also increase the cache size to up to 4GB, to avoid recomputation:
17 | 
18 | ```bash
19 | export CUDA_CACHE_MAXSIZE=4294967296
20 | ```
21 | 
22 | 
23 | ## Q: What about GPU support on my Mac?
24 | 
25 | Recent PyTorch versions provide a new backend based on Apple's Metal Performance Shaders (MPS).
26 | The MPS backend enables GPU-accelerated training on the M1/M2 architecture.
27 | While we have an ARM build of PyTorch in JavaCPP as of version `1.5.10`, MPS is not enabled as the CI runners currently run on a macOS version that is too old.
28 | If you want to help getting this to work, check out [the corresponding issue](https://github.com/bytedeco/javacpp-presets/issues/1464).
-------------------------------------------------------------------------------- /docs/modules.md: --------------------------------------------------------------------------------
1 | # Modules
2 | 
3 | Storch provides a neural network module API with building blocks for creating stateful neural network architectures.
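
Networks can be defined by subclassing `nn.Module`, as in the example below, or by stacking existing layers with a container. A minimal sketch, assuming the `nn.Sequential` container behaves like its PyTorch counterpart:

```scala
import torch.nn

// a small multi-layer perceptron built from existing building blocks
val mlp = nn.Sequential(
  nn.Linear(784, 64),
  nn.ReLU(),
  nn.Linear(64, 10)
)
```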
4 | 
5 | ## Simple custom module example
6 | 
7 | ```scala mdoc:invisible
8 | torch.manualSeed(0)
9 | ```
10 | 
11 | ```scala mdoc
12 | import torch.*
13 | import torch.nn
14 | import torch.nn.functional as F
15 | 
16 | class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
17 |   val conv1 = register(nn.Conv2d(1, 6, 5))
18 |   val conv2 = register(nn.Conv2d(6, 16, 5))
19 |   val fc1 = register(nn.Linear(16 * 4 * 4, 120))
20 |   val fc2 = register(nn.Linear(120, 84))
21 |   val fc3 = register(nn.Linear(84, 10))
22 | 
23 |   def apply(i: Tensor[D]): Tensor[D] =
24 |     var x = F.maxPool2d(F.relu(conv1(i)), (2, 2))
25 |     x = F.maxPool2d(F.relu(conv2(x)), 2)
26 |     x = x.view(-1, 16 * 4 * 4)
27 |     x = F.relu(fc1(x))
28 |     x = F.relu(fc2(x))
29 |     x = fc3(x)
30 |     x
31 | ```
32 | 
33 | ```scala mdoc
34 | val model = LeNet()
35 | val input = torch.rand(Seq(1, 1, 28, 28))
36 | model(input)
37 | ```
-------------------------------------------------------------------------------- /docs/pre-trained-weights.md: --------------------------------------------------------------------------------
1 | {%
2 |   laika.title="Converting pre-trained weights"
3 | %}
4 | 
5 | # Converting pre-trained weights from PyTorch
6 | 
7 | Loading pre-trained weights from PyTorch into Storch is an important feature for transfer-learning,
8 | or for doing inference on models trained with PyTorch.
9 | 
10 | Currently, this is not as simple as it should be because serialized weights cannot be loaded into Storch if they
11 | contain weights stored as `torch.nn.parameter.Parameter`, a subclass of `torch.Tensor`.
12 | We have to convert these parameters to regular tensors first to be able to load them in Storch.
13 | 
14 | To help with this task, we provide a simple [conversion script](https://github.com/sbrunk/storch/blob/main/scripts/convert-weights/convert_weights.py).
15 | Currently the script only converts pre-trained ResNet weights, but it shouldn't be too difficult to adapt it to other weights as well.
16 | 
17 | The [converted weights](https://github.com/sbrunk/storch/releases/tag/pretrained-weights) are also available for download
18 | from the Storch GitHub repository.
19 | 
20 | We hope to improve the situation by creating our own reader that allows direct loading of PyTorch weights.
21 | 
-------------------------------------------------------------------------------- /docs/tutorial/directory.conf: --------------------------------------------------------------------------------
1 | laika.navigationOrder = [
2 |   tensors.md
3 |   autograd.md
4 |   buildmodel.md
5 | ]
-------------------------------------------------------------------------------- /docs/tutorial/img/comp-graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbrunk/storch/2dfa3884b9f0f2d1e2566aad791f44535b48bb09/docs/tutorial/img/comp-graph.png
-------------------------------------------------------------------------------- /docs/tutorial/tensors.md: --------------------------------------------------------------------------------
1 | Learn the Basics ||
2 | Quickstart ||
3 | **Tensors** ||
4 | Datasets & DataLoaders ||
5 | Transforms ||
6 | [Build Model](buildmodel.md) ||
7 | [Autograd](autograd.md) ||
8 | Optimization ||
9 | Save & Load Model
10 | 
11 | # Tensors
12 | 
13 | ```scala mdoc:invisible
14 | torch.manualSeed(0)
15 | ```
16 | 
17 | Tensors are a specialized data structure that are very similar to arrays
18 | and matrices.
In PyTorch, we use tensors to encode the inputs and
19 | outputs of a model, as well as the model's parameters.
20 | 
21 | Tensors are similar to [NumPy's](https://numpy.org/) ndarrays, except
22 | that tensors can run on GPUs or other hardware accelerators. Tensors
23 | are also optimized for automatic differentiation (we'll see more about
24 | that later in the [Autograd](autograd.md) section). If
25 | you're familiar with ndarrays, you'll be right at home with the Tensor
26 | API. If not, follow along!
27 | 
28 | ## Initializing a Tensor
29 | 
30 | Tensors can be initialized in various ways. Take a look at the following
31 | examples:
32 | 
33 | **Directly from data**
34 | 
35 | Tensors can be created directly from data. The data type is
36 | automatically inferred.
37 | 
38 | ```scala mdoc
39 | val data = Seq(1, 2, 3, 4)
40 | val xData = torch.Tensor(data).reshape(2,2)
41 | ```
42 | 
43 | **From another tensor:**
44 | 
45 | The new tensor retains the properties (shape, datatype) of the argument
46 | tensor, unless explicitly overridden.
47 | 
48 | 
49 | ```scala mdoc
50 | // Ones Tensor:
51 | val xOnes = torch.onesLike(xData) // retains the properties of xData
52 | ```
53 | 
54 | ```scala mdoc
55 | // Random Tensor:
56 | val xRand = torch.randLike(xData, dtype=torch.float32) // overrides the datatype of xData
57 | ```
58 | 
59 | **With random or constant values:**
60 | 
61 | `shape` is a sequence of tensor dimensions. In the functions below, it
62 | determines the dimensionality of the output tensor.
63 | 
64 | ```scala mdoc
65 | val shape = Seq(2,3)
66 | 
67 | // Random Tensor:
68 | val randTensor = torch.rand(shape)
69 | 
70 | // Ones Tensor:
71 | val onesTensor = torch.ones(shape)
72 | 
73 | // Zeros Tensor:
74 | val zerosTensor = torch.zeros(shape)
75 | ```
76 | 
77 | ## Attributes of a Tensor
78 | 
79 | Tensor attributes describe their shape, datatype, and the device on
80 | which they are stored.
81 | 
82 | ```scala mdoc
83 | var tensor = torch.rand(Seq(3,4))
84 | 
85 | println(s"Shape of tensor: ${tensor.shape}")
86 | println(s"Datatype of tensor: ${tensor.dtype}")
87 | println(s"Device tensor is stored on: ${tensor.device}")
88 | ```
89 | 
90 | ------------------------------------------------------------------------
91 | 
92 | ## Operations on Tensors
93 | 
94 | Over 100 tensor operations, including arithmetic, linear algebra, matrix
95 | manipulation (transposing, indexing, slicing), sampling and more are
96 | comprehensively described
97 | [here](https://pytorch.org/docs/stable/torch.html).
98 | 
99 | Each of these operations can be run on the GPU (at typically higher
100 | speeds than on a CPU), provided your machine has a CUDA-capable
101 | device.
102 | 
103 | By default, tensors are created on the CPU. We need to explicitly move
104 | tensors to the GPU using the `.to` method (after checking for GPU
105 | availability). Keep in mind that copying large tensors across devices
106 | can be expensive in terms of time and memory! We move our tensor to the
107 | GPU if available:
108 | 
109 | ```scala mdoc
110 | if torch.cuda.isAvailable then
111 |   tensor = tensor.to(torch.Device.CUDA)
112 | ```
113 | 
114 | Try out some of the operations from the list. If you're familiar with
115 | the NumPy API, you'll find the Tensor API a breeze to use.
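
Besides the operations shown below, you can always pull a tensor's values back into a Scala collection with `toSeq` (a small sketch; the same call appears in the test suites earlier in this listing, e.g. `CreationOpsSuite`):

```scala
val t = torch.arange(0, 6)
val values = t.toSeq // Seq(0, 1, 2, 3, 4, 5)
```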
116 | 117 | **Standard numpy-like indexing and slicing:** 118 | 119 | ```scala mdoc 120 | import torch.{---, Slice} 121 | tensor = torch.ones(Seq(4, 4)) 122 | println(s"First row: ${tensor(0)}") 123 | println(s"First column: ${tensor(Slice(), 0)}") 124 | println(s"Last column: ${tensor(---, -1)}") 125 | //tensor(---,1) = 0 TODO update op 126 | println(tensor) 127 | ``` 128 | 129 | **Joining tensors** You can use `torch.cat` to concatenate a sequence of 130 | tensors along a given dimension. See also 131 | [torch.stack](https://pytorch.org/docs/stable/generated/torch.stack.html), 132 | another tensor joining op that is subtly different from `torch.cat`. 133 | 134 | ```scala mdoc 135 | val t1 = torch.cat(Seq(tensor, tensor, tensor), dim=1) 136 | println(t1) 137 | ``` 138 | 139 | **Arithmetic operations** 140 | 141 | ```scala mdoc 142 | // This computes the matrix multiplication between two tensors. y1, y2, y3 will 143 | // have the same value 144 | // `tensor.mT` returns the transpose of a tensor 145 | val y1 = tensor `@` tensor.mT 146 | val y2 = tensor.matmul(tensor.mT) 147 | 148 | //val y3 = torch.randLike(y1) 149 | //torch.matmul(tensor, tensor.mT, out=y3) 150 | 151 | // This computes the element-wise product. z1, z2, z3 will have the same value 152 | 153 | val z1 = tensor * tensor 154 | val z2 = tensor.mul(tensor) 155 | 156 | //val z3 = torch.randLike(tensor) 157 | //torch.mul(tensor, tensor, out=z3) 158 | ``` 159 | 160 | **Single-element tensors** If you have a one-element tensor, for example 161 | by aggregating all values of a tensor into one value, you can convert it 162 | to a Scala numerical value using `item()`: 163 | 164 | ```scala mdoc 165 | val agg = tensor.sum 166 | val aggItem = agg.item 167 | print(aggItem) 168 | println(aggItem.getClass) 169 | ``` 170 | 171 | **In-place operations** Operations that store the result into the 172 | operand are called in-place. They are denoted by a `_` suffix. For 173 | example: `x.copy_(y)`, `x.t_()`, will change `x`. 174 | 175 | ```scala mdoc 176 | println(s"$tensor") 177 | tensor -= 5 178 | println(tensor) 179 | ``` 180 | 181 | @:callout(info) 182 | 183 | In-place operations save some memory, but can be problematic when 184 | computing derivatives because of an immediate loss of history. Hence, 185 | their use is discouraged. 186 | 187 | @:@ -------------------------------------------------------------------------------- /examples/src/main/scala/LeNet.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | //> using scala "3.3" 18 | //> using repository "sonatype:snapshots" 19 | //> using repository "sonatype-s01:snapshots" 20 | //> using lib "dev.storch::vision:0.0-2fff591-SNAPSHOT" 21 | // replace with pytorch-platform-gpu if you have a CUDA capable GPU 22 | //> using lib "org.bytedeco:pytorch-platform:2.1.2-1.5.10" 23 | // enable for CUDA support 24 | ////> using lib "org.bytedeco:cuda-platform-redist:12.3-8.9-1.5.10" 25 | // enable for native Apple Silicon support 26 | // will not be needed with newer versions of pytorch-platform 27 | ////> using lib "org.bytedeco:pytorch:2.1.2-1.5.10,classifier=macosx-arm64" 28 | 29 | import torch.* 30 | import torch.nn.functional as F 31 | import torch.optim.Adam 32 | import org.bytedeco.pytorch.OutputArchive 33 | import torchvision.datasets.MNIST 34 | import scala.util.Random 35 | import java.nio.file.Paths 36 | import torch.Device.CUDA 37 | import scala.util.Using 38 | import org.bytedeco.javacpp.PointerScope 39 | import torch.Device.CPU 40 | import torch.nn.modules.HasParams 41 | 42 | // Define the model architecture 43 | class LeNet[D <: BFloat16 | Float32: Default] extends HasParams[D] { 44 | 45 | val conv1 = register(nn.Conv2d(1, 6, 5)) 46 | val conv2 = register(nn.Conv2d(6, 16, 5)) 47 | val fc1 = register(nn.Linear(16 * 4 * 4, 120)) 48 | val fc2 = register(nn.Linear(120, 84)) 49 | val fc3 = register(nn.Linear(84, 10)) 50 | 51 | def apply(i: Tensor[D]): Tensor[D] = 52 | var x = F.maxPool2d(F.relu(conv1(i)), (2, 2)) 53 | x = F.maxPool2d(F.relu(conv2(x)), 2) 54 | x = x.view(-1, 16 * 4 * 4) // all dimensions except the batch dimension 55 | x = F.relu(fc1(x)) 56 | x = F.relu(fc2(x)) 57 | x = fc3(x) 58 | x 59 | } 60 | 61 | /** Shows how to train a simple LeNet on the MNIST dataset */ 62 | object LeNetApp extends App { 63 | val device = if torch.cuda.isAvailable then CUDA else CPU 64 | println(s"Using device: $device") 65 | val model = LeNet().to(device) 66 | 67 | // prepare data 68 | val dataPath = Paths.get("data/mnist") 69 | val mnistTrain = MNIST(dataPath, train = true, download = true) 70 | val mnistEval = MNIST(dataPath, train = false) 71 | val evalFeatures = mnistEval.features.to(device) 72 | val evalTargets = mnistEval.targets.to(device) 73 | val r = Random(seed = 0) 74 | 75 | def dataLoader: Iterator[(Tensor[Float32], Tensor[Int64])] = 76 | r.shuffle(mnistTrain).grouped(32).map { batch => 77 | val (features, targets) = batch.unzip 78 | (torch.stack(features).to(device), torch.stack(targets).to(device)) 79 | } 80 | 81 | val lossFn = torch.nn.loss.CrossEntropyLoss() 82 | // enable AMSGrad to avoid convergence issues 83 | val optimizer = Adam(model.parameters, lr = 1e-3, amsgrad = true) 84 | 85 | // run training 86 | for (epoch <- 1 to 5) do 87 | for (batch <- dataLoader.zipWithIndex) do 88 | // make sure we deallocate intermediate tensors in time 89 | Using.resource(new PointerScope()) { p => 90 | val ((feature, target), batchIndex) = batch 91 | optimizer.zeroGrad() 92 | val prediction = model(feature) 93 | val loss = lossFn(prediction, target) 94 | loss.backward() 95 | optimizer.step() 96 | if batchIndex % 200 == 0 then 97 | // run evaluation 98 | val predictions = model(evalFeatures) 99 | val evalLoss = lossFn(predictions, evalTargets) 100 | val accuracy = 101 | (predictions.argmax(dim = 1).eq(evalTargets).sum / mnistEval.length).item 102 | println( 103 | f"Epoch: $epoch | Batch: $batchIndex%4d | Training loss: ${loss.item}%.4f | Eval loss: ${evalLoss.item}%.4f | Eval accuracy: $accuracy%.4f" 104 | ) 105 | } 106 | val archive 
= new OutputArchive 107 | model.save(archive) 108 | archive.save_to("net.pt") 109 | } 110 | -------------------------------------------------------------------------------- /git-hooks/pre-push-checks: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Go to the repository root 4 | cd "${GIT_DIR}/.." 5 | 6 | # Run the sbt linting checks 7 | sbt 'headerCheckAll ; scalafmtCheckAll ; scalafmtSbtCheck' 8 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.9.8 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") 2 | addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17") 3 | addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.2") 4 | addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0") 5 | addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.6.5") 6 | addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.6.5") 7 | -------------------------------------------------------------------------------- /scripts/convert-weights/convert_weights.py: -------------------------------------------------------------------------------- 1 | # Script for converting weights of class `torch.nn.parameter.Parameter` to regular tensors which enables loading them from libtorch 2 | 3 | from pathlib import Path 4 | from urllib.parse import urlparse 5 | from torchvision.models.resnet import * 6 | import torch 7 | 8 | model_path = Path("models") 9 | model_path.mkdir(parents=True, exist_ok=True) 10 | 11 | for weights_enum in [ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights]: 12 | for weights in weights_enum: 13 | parts = urlparse(weights.url) 14 | filename = Path(parts.path).name 15 | print(f"Converting {filename}") 16 | state_dict = weights.get_state_dict(progress=True) 17 | # Convert from torch.nn.Parameter to torch.Tensor so we can load the state dict from libtorch 18 | converted_state_dict = {k: v.clone().detach() 19 | for k, v in state_dict.items()} 20 | torch.save(converted_state_dict, (model_path / filename)) 21 | -------------------------------------------------------------------------------- /scripts/convert-weights/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision==0.15.1 3 | 4 | 5 | -------------------------------------------------------------------------------- /site/src/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Adjust a few styles from landing-page.css */ 2 | 3 | #header { 4 | padding-top: 20px; 5 | } 6 | 7 | #header-left h1, #header-left h2 { 8 | color: var(--component-color); 9 | line-height: 1; 10 | margin-bottom: 5px; 11 | } 12 | 13 | #header-left h1 { 14 | font-size: 40px; 15 | } 16 | 17 | #header-left h2 { 18 | font-size: 22px; 19 | margin-top: 0.7em; 20 | } 21 | 22 | .teaser h2 { 23 | font-size: 20px; 24 | margin-bottom: 0.25em; 25 | margin-top: 0; 26 | } 27 | 28 | .teaser p { 29 | font-size: 15px; 30 | } 31 | 32 | .teasers { 33 | margin: 15px auto 0 auto; 34 | } 35 | -------------------------------------------------------------------------------- /site/src/js/render-katex.js: 
-------------------------------------------------------------------------------- 1 | document.addEventListener("DOMContentLoaded", function() { 2 | renderMathInElement(document.body, { 3 | // customised options 4 | // • auto-render specific keys, e.g.: 5 | delimiters: [ 6 | {left: '$$', right: '$$', display: true}, 7 | {left: '$', right: '$', display: false}, 8 | ], 9 | // • rendering keys, e.g.: 10 | throwOnError : false 11 | }); 12 | }); -------------------------------------------------------------------------------- /site/src/landing-page.md: -------------------------------------------------------------------------------- 1 |
2 | <!-- landing page HTML layout (header, hero image, and teaser sections) was lost in extraction; only the image credit survives -->
3 | Torch by Mailtoanton / CC BY-SA 3.0
-------------------------------------------------------------------------------- /vision/src/main/scala/torchvision/datasets/MNIST.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torchvision.datasets 18 | 19 | import torch.* 20 | import org.bytedeco.pytorch 21 | import torch.data.TensorDataset 22 | import java.nio.file.Path 23 | import scala.util.Using 24 | import java.net.URL 25 | import java.nio.file.Files 26 | import java.util.zip.GZIPInputStream 27 | import scala.util.Try 28 | import scala.util.Success 29 | import scala.util.Failure 30 | import torch.Tensor.fromNative 31 | 32 | trait MNISTBase( 33 | val mirrors: Seq[String], 34 | val resources: Seq[(String, String)], 35 | val classes: Seq[String], 36 | val root: Path, 37 | val train: Boolean, 38 | val download: Boolean 39 | ) extends TensorDataset[Float32, Int64] { 40 | 41 | private def downloadAndExtractArchive(url: URL, target: Path): Unit = 42 | println(s"downloading from $url") 43 | Using.resource(url.openStream()) { inputStream => 44 | val _ = Files.copy(GZIPInputStream(inputStream), target) 45 | } 46 | 47 | if download then { 48 | Files.createDirectories(root) 49 | for (filename, md5) <- resources do 50 | val finalPath = root.resolve(filename.stripSuffix(".gz")) 51 | if !Files.exists(finalPath) then 52 | println(s"$finalPath not found") 53 | val _ = mirrors.iterator 54 | .map { mirror => 55 | Try(downloadAndExtractArchive(URL(s"$mirror$filename"), finalPath)) 56 | } 57 | .tapEach { 58 | case Failure(exception) => println(exception) 59 | case Success(_) => 60 | } 61 | .collectFirst { case Success(_) => } 62 | } 63 | 64 | private val mode = 65 | if train then pytorch.MNIST.Mode.kTrain.intern().value 66 | else pytorch.MNIST.Mode.kTest.intern().value 67 | 68 | private val native = pytorch.MNIST(root.toString(), mode) 69 | 70 | private val ds = 71 | TensorDataset( 72 | fromNative[Float32](native.images().clone()), 73 | fromNative[Int64](native.targets().clone()) 74 | ) 75 | export ds.{apply, length, features, targets} 76 | 77 | override def toString(): String = ds.toString() 78 | } 79 | 80 | /** The [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. 81 | * 82 | * @param root 83 | * Root directory of dataset where `train-images-idx3-ubyte` `t10k-images-idx3-ubyte` exist. 84 | * @param train 85 | * If true, creates dataset from `train-images-idx3-ubyte`, otherwise from 86 | * `t10k-images-idx3-ubyte`. 
87 | */ 88 | class MNIST(root: Path, train: Boolean = true, download: Boolean = false) 89 | extends MNISTBase( 90 | mirrors = List( 91 | "http://yann.lecun.com/exdb/mnist/", 92 | "https://ossci-datasets.s3.amazonaws.com/mnist/" 93 | ), 94 | resources = List( 95 | ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), 96 | ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), 97 | ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), 98 | ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c") 99 | ), 100 | classes = Seq( 101 | "0 - zero", 102 | "1 - one", 103 | "2 - two", 104 | "3 - three", 105 | "4 - four", 106 | "5 - five", 107 | "6 - six", 108 | "7 - seven", 109 | "8 - eight", 110 | "9 - nine" 111 | ), 112 | root, 113 | train, 114 | download 115 | ) 116 | 117 | /** The [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) Dataset. 118 | * 119 | * @param root 120 | * Root directory of dataset where `train-images-idx3-ubyte` `t10k-images-idx3-ubyte` exist. 121 | * @param train 122 | * If true, creates dataset from `train-images-idx3-ubyte`, otherwise from 123 | * `t10k-images-idx3-ubyte`. 124 | */ 125 | class FashionMNIST(root: Path, train: Boolean = true, download: Boolean = false) 126 | extends MNISTBase( 127 | mirrors = List( 128 | "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" 129 | ), 130 | resources = List( 131 | ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), 132 | ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), 133 | ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), 134 | ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310") 135 | ), 136 | classes = Seq( 137 | "T-shirt/top", 138 | "Trouser", 139 | "Pullover", 140 | "Dress", 141 | "Coat", 142 | "Sandal", 143 | "Shirt", 144 | "Sneaker", 145 | "Bag", 146 | "Ankle boot" 147 | ), 148 | root, 149 | train, 150 | download 151 | ) 152 | -------------------------------------------------------------------------------- /vision/src/main/scala/torchvision/transforms/functional.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torchvision.transforms 18 | 19 | import torch.* 20 | import com.sksamuel.scrimage.ImmutableImage 21 | import scala.collection.immutable.ArraySeq 22 | 23 | object functional: 24 | 25 | private def isTensorATorchImage(x: Tensor[?]): Boolean = x.dim >= 2 26 | 27 | private def assertImageTensor(img: Tensor[?]): Unit = 28 | if !isTensorATorchImage(img) then 29 | throw new IllegalArgumentException("Tensor is not a torch image.") 30 | 31 | def normalize[D <: FloatNN](tensor: Tensor[D], mean: Seq[Float], std: Seq[Float]): Tensor[D] = 32 | assertImageTensor(tensor) 33 | 34 | if tensor.dim < 3 then 35 | throw new IllegalArgumentException( 36 | s"Expected tensor to be a tensor image of size (..., C, H, W). 
Got tensor.size() = ${tensor.size}"
37 |       )
38 | 
39 |     val dtype = tensor.dtype
40 |     var _mean = Tensor(mean, device = tensor.device).to(dtype = dtype)
41 |     var _std = Tensor(std, device = tensor.device).to(dtype = dtype)
42 |     if (_std == 0).any.item then
43 |       throw new IllegalArgumentException(
44 |         s"std evaluated to zero after conversion to $dtype, leading to division by zero."
45 |       )
46 |     if _mean.dim == 1 then _mean = _mean.view(-1, 1, 1)
47 |     if _std.dim == 1 then _std = _std.view(-1, 1, 1)
48 |     (tensor - _mean) / _std
49 | 
50 |   /** Convert an [[ImmutableImage]] (H x W x C) to a [[Tensor[Float32]]] of shape (C x H x W) in the
51 |    *  range `[0.0, 1.0]`.
52 |    */
53 |   def toTensor(pic: ImmutableImage): Tensor[Float32] =
54 |     val bytes = pic.rgb.flatten
55 |     // transpose HxWxC to CxHxW because pytorch expects channels first
56 |     Tensor(ArraySeq.unsafeWrapArray(bytes))
57 |       .reshape(pic.height, pic.width, 3)
58 |       .permute(2, 0, 1)
59 |       .to(dtype = float32) / 255
60 | 
61 |   def toImmutableImage[D <: FloatNN](pic: Tensor[D]): ImmutableImage =
62 |     var _pic = pic
63 |     if !Seq(2, 3).contains(pic.dim) then
64 |       throw new IllegalArgumentException(
65 |         s"pic should be 2/3 dimensional. Got ${pic.dim} dimensions."
66 |       )
67 |     else if pic.dim == 2 then
68 |       // if 2D image, add channel dimension (CHW)
69 |       _pic = pic.unsqueeze(0)
70 |     // check number of channels (on _pic, which is guaranteed to be 3-dimensional here)
71 |     if _pic.shape(-3) > 4 then
72 |       throw new IllegalArgumentException(
73 |         s"pic should not have > 4 channels. Got ${_pic.shape(-3)} channels."
74 |       )
75 |     val intImage = (_pic.permute(1, 2, 0) * 255).to(dtype = int8)
76 |     val bytes = intImage.toArray
77 |     ImmutableImage.loader().fromBytes(bytes)
78 | 
-------------------------------------------------------------------------------- /vision/src/main/scala/torchvision/transforms/presets.scala: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 2022 storch.dev
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 | */ 16 | 17 | package torchvision 18 | package transforms 19 | 20 | import com.sksamuel.scrimage.ImmutableImage 21 | import com.sksamuel.scrimage.ScaleMethod 22 | import torch.Tensor 23 | import torch.Float32 24 | import torchvision.transforms.functional.toTensor 25 | 26 | object Presets: 27 | 28 | class ImageClassification( 29 | cropSize: Int, 30 | resizeSize: Int = 256, 31 | mean: Seq[Float] = Seq(0.485f, 0.456f, 0.406f), 32 | std: Seq[Float] = Seq(0.229f, 0.224f, 0.225f), 33 | interpolation: ScaleMethod = ScaleMethod.Bilinear 34 | ): 35 | def transforms(image: ImmutableImage): Tensor[Float32] = 36 | val scaledImage = 37 | if image.height < image.width then 38 | image.scaleTo( 39 | (resizeSize * (image.width / image.height.toDouble)).toInt, 40 | resizeSize, 41 | interpolation 42 | ) 43 | else 44 | image.scaleTo( 45 | resizeSize, 46 | (resizeSize * (image.height / image.width.toDouble)).toInt, 47 | interpolation 48 | ) 49 | val croppedImage = scaledImage.resizeTo(cropSize, cropSize) 50 | toTensor(croppedImage) 51 | 52 | def batchTransforms(input: Tensor[Float32]): Tensor[Float32] = 53 | torchvision.transforms.functional.normalize( 54 | input, 55 | mean = mean, 56 | std = std 57 | ) 58 | -------------------------------------------------------------------------------- /vision/src/test/scala/torchvision/MNISTSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2022 storch.dev 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package torchvision 18 | package datasets 19 | 20 | import java.nio.file.Paths 21 | 22 | class MNISTSuite extends munit.FunSuite { 23 | test("MNIST download") { 24 | val mnistTrain = MNIST(Paths.get("data/mnist"), download = true) 25 | assertEquals(mnistTrain.features.shape, Seq(60000, 1, 28, 28)) 26 | val mnistTest = MNIST(Paths.get("data/mnist"), train = false, download = true) 27 | assertEquals(mnistTest.features.shape, Seq(10000, 1, 28, 28)) 28 | } 29 | 30 | test("FashionMNIST download") { 31 | val fashionMNISTTrain = FashionMNIST(Paths.get("data/fashion-mnist"), download = true) 32 | assertEquals(fashionMNISTTrain.features.shape, Seq(60000, 1, 28, 28)) 33 | val fashionMNISTTest = 34 | FashionMNIST(Paths.get("data/fashion-mnist"), train = false, download = true) 35 | assertEquals(fashionMNISTTest.features.shape, Seq(10000, 1, 28, 28)) 36 | } 37 | 38 | } 39 | --------------------------------------------------------------------------------