├── .circleci └── config.yml ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── Setup.hs ├── cabal.project ├── codegen ├── codegen.cabal ├── exe │ └── Main.hs ├── run-ghci.sh ├── src │ ├── ParseClass.hs │ ├── ParseDeclarations.hs │ ├── ParseDerivatives.hs │ ├── ParseFunctionSig.hs │ ├── ParseTuples.hs │ ├── RenderClass.hs │ ├── RenderCommon.hs │ ├── RenderDeclarations.hs │ ├── RenderPure.hs │ └── RenderTuples.hs └── test │ ├── Spec.hs │ └── doctests.hs ├── deps ├── .gitignore └── get-deps.sh ├── examples ├── cnn │ └── Main.hs ├── elman │ ├── Elman.hs │ ├── GRU.hs │ ├── LSTM.hs │ ├── Main.hs │ └── RecurrentLayer.hs ├── examples.cabal ├── gaussian_process │ └── Main.hs ├── regression │ └── Main.hs ├── vae │ └── Main.hs └── xor_mlp │ └── Main.hs ├── ffi ├── ffi.cabal ├── src │ ├── ATen │ │ ├── Cast.hs │ │ ├── Class.hs │ │ ├── Const.hs │ │ ├── GC.hs │ │ ├── Managed │ │ │ ├── Cast.hs │ │ │ ├── NN.hs │ │ │ ├── Native.hs │ │ │ ├── TH.hs │ │ │ └── Type │ │ │ │ ├── Context.hs │ │ │ │ ├── Extra.hs │ │ │ │ ├── Generator.hs │ │ │ │ ├── IntArray.hs │ │ │ │ ├── Scalar.hs │ │ │ │ ├── SparseTensorRef.hs │ │ │ │ ├── StdArray.hs │ │ │ │ ├── StdString.hs │ │ │ │ ├── Storage.hs │ │ │ │ ├── Tensor.hs │ │ │ │ ├── TensorList.hs │ │ │ │ ├── TensorOptions.hs │ │ │ │ └── Tuple.hs │ │ ├── Type.hs │ │ └── Unmanaged │ │ │ ├── NN.hs │ │ │ ├── Native.hs │ │ │ ├── TH.hs │ │ │ └── Type │ │ │ ├── Context.hs │ │ │ ├── Extra.hs │ │ │ ├── Generator.hs │ │ │ ├── IntArray.hs │ │ │ ├── Scalar.hs │ │ │ ├── SparseTensorRef.hs │ │ │ ├── StdArray.hs │ │ │ ├── StdString.hs │ │ │ ├── Storage.hs │ │ │ ├── Tensor.hs │ │ │ ├── TensorList.hs │ │ │ ├── TensorOptions.hs │ │ │ └── Tuple.hs │ └── Torch │ │ ├── Managed │ │ ├── Autograd.hs │ │ ├── NN.hs │ │ ├── Native.hs │ │ └── TH.hs │ │ └── Unmanaged │ │ ├── Autograd.hs │ │ ├── NN.hs │ │ ├── Native.hs │ │ └── TH.hs └── test │ ├── BackwardSpec.hs │ ├── BasicSpec.hs │ ├── CudaSpec.hs │ ├── MemorySpec.hs │ └── Spec.hs ├── hasktorch ├── hasktorch.cabal ├── src │ ├── Torch.hs │ └── Torch │ │ ├── Autograd.hs │ │ ├── Backend.hs │ │ ├── Cast.hs │ │ ├── DType.hs │ │ ├── Functions.hs │ │ ├── Functions │ │ └── Native.hs │ │ ├── Layout.hs │ │ ├── NN.hs │ │ ├── Scalar.hs │ │ ├── Static.hs │ │ ├── Tensor.hs │ │ ├── TensorFactories.hs │ │ └── TensorOptions.hs └── test │ ├── FactorySpec.hs │ ├── FunctionsSpec.hs │ ├── GradSpec.hs │ ├── NNSpec.hs │ ├── SparseSpec.hs │ ├── Spec.hs │ └── TensorSpec.hs ├── setenv ├── setup-cabal.sh ├── spec ├── Declarations.yaml ├── README.md ├── bindings.yaml └── cppclass │ ├── array.yaml │ ├── context.yaml │ ├── gen.sh │ ├── generator.yaml │ ├── intarray.yaml │ ├── scalar.yaml │ ├── sparsetensorref.yaml │ ├── storage.yaml │ ├── tensor.yaml │ ├── tensorlist.yaml │ └── tensoroptions.yaml └── stack.yaml /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | stack-build: 5 | docker: 6 | - image: "ubuntu:18.04" 7 | environment: 8 | LD_LIBRARY_PATH: /root/project/deps/libtorch/lib:/root/project/deps/mklml/lib 9 | steps: 10 | - run: echo $LD_LIBRARY_PATH 11 | - run: apt update -qq && apt install -y cmake wget unzip git libtinfo-dev python3 python3-yaml 12 | - run: update-alternatives --install /usr/bin/python python /usr/bin/python3 1 13 | - checkout 14 | - run: git submodule init && git submodule update 15 | - run: wget -qO- https://get.haskellstack.org/ | sed -e 's/^STACK_VERSION=.*/STACK_VERSION="1.9.3"/g' | sh 16 | - run: gcc --version 17 | - run: stack --version 18 | - run: cd deps/ ; 
./get-deps.sh -a cpu -c 19 | - run: 20 | name: stack build 21 | command: stack build --jobs 2 22 | no_output_timeout: 15m 23 | - run: stack test --jobs 2 24 | - run: stack exec codegen-exe 25 | - run: stack test --jobs 2 26 | - run: stack exec xor_mlp 27 | cabal-build: 28 | docker: 29 | - image: "ubuntu:18.04" 30 | environment: 31 | LD_LIBRARY_PATH: /root/project/deps/libtorch/lib:/root/project/deps/mklml/lib 32 | PATH: /opt/ghc/bin:/bin:/usr/bin:/usr/local/bin:/sbin:/usr/sbin 33 | steps: 34 | - run: echo $LD_LIBRARY_PATH 35 | - run: apt update -qq && apt install -y cmake curl wget unzip git libtinfo-dev python3 python3-yaml 36 | - run: apt -y --allow-downgrades --allow-remove-essential --allow-change-held-packages install locales software-properties-common apt-transport-https 37 | - run: add-apt-repository ppa:hvr/ghc 38 | - run: apt-get update -qq && apt-get -y --allow-downgrades --allow-remove-essential --allow-change-held-packages install build-essential zlib1g-dev liblapack-dev libblas-dev ghc-8.6.4 cabal-install-head devscripts debhelper python3-pip 39 | - run: update-alternatives --install /usr/bin/python python /usr/bin/python3 1 40 | - checkout 41 | - run: git submodule init && git submodule update 42 | - run: gcc --version 43 | - run: cabal --version 44 | - run: cd deps/ ; ./get-deps.sh -a cpu -c 45 | - run: ./setup-cabal.sh 46 | - run: cabal new-update 47 | - run: cabal new-install hspec-discover 48 | - run: 49 | name: cabal new-build all 50 | command: cabal new-build all --jobs=2 --write-ghc-environment-files=always 51 | no_output_timeout: 15m 52 | - run: cabal new-test all --jobs=2 --write-ghc-environment-files=always 53 | - run: cabal new-exec codegen-exe 54 | - run: cabal new-test all --jobs=2 --write-ghc-environment-files=always 55 | - run: cabal exec xor_mlp 56 | osx-stack-build: 57 | macos: 58 | xcode: "10.2.1" 59 | environment: 60 | DYLD_LIBRARY_PATH: /Users/distiller/project/deps/libtorch/lib:/Users/distiller/project/deps/mklml/lib 61 | PATH: /Users/distiller/project/.local/bin:/usr/local/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin 62 | steps: 63 | - run: echo $LD_LIBRARY_PATH 64 | - checkout 65 | - run: git submodule init && git submodule update 66 | - run: /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 67 | - run: brew install wget cmake libomp 68 | - run: pip3 install pyyaml 69 | - run: mkdir -p .local/bin 70 | - run: rm /usr/local/bin/python 71 | - run: ln -s /usr/local/bin/python3 .local/bin/python 72 | - run: wget -qO- https://get.haskellstack.org/ | sed -e 's/^STACK_VERSION=.*/STACK_VERSION="1.9.3"/g' | sh 73 | - run: clang --version 74 | - run: stack --version 75 | - run: cd deps/ ; ./get-deps.sh -a cpu -c 76 | - run: cp -a deps/libtorch/lib/*.dylib deps/mklml/lib/*.dylib /usr/local/lib/ 77 | - run: 78 | name: stack build 79 | command: stack build --jobs 2 80 | no_output_timeout: 15m 81 | - run: stack test --jobs 2 82 | - run: stack exec codegen-exe 83 | - run: stack test --jobs 2 84 | - run: stack exec xor_mlp 85 | - run: stack exec regression 86 | - run: stack exec gaussian_process 87 | - run: stack exec vae 88 | 89 | workflows: 90 | version: 2 91 | build: 92 | jobs: 93 | - "stack-build" 94 | - "cabal-build" 95 | - "osx-stack-build" 96 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | .stack-work/ 3 | dist-newstyle/ 4 | libtorch-test/build 5 | codegen/build/ 6 | 
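# GHC environment files are written by `cabal new-build --write-ghc-environment-files` (see .circleci/config.yml):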
.ghc.environment* 7 | deps/mklml 8 | cabal.project.local 9 | cabal.project.freeze 10 | venv/ 11 | 12 | *.swp 13 | *.lock 14 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "deps/pytorch"] 2 | path = deps/pytorch 3 | url = https://github.com/pytorch/pytorch.git 4 | [submodule "inline-c"] 5 | path = inline-c 6 | url = https://github.com/hasktorch/inline-c.git 7 | [submodule "fficxx"] 8 | path = fficxx 9 | url = git@github.com:hasktorch/fficxx.git 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | * Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | Copyright Austin Huang (c) 2017 31 | 32 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 33 | 34 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 35 | 36 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 37 | 38 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 39 | 40 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 41 | 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (ARCHIVED) Hasktorch 0.2 Libtorch FFI 2 | 3 | __*This repository contains all of the work leading up to Haskell support of PyTorch's C++ API (part of PyTorch's 1.0 4 | release). It is now merged into [hasktorch](https://github.com/hasktorch/hasktorch), where development on this work continues.*__ 5 | 6 | Work on FFI bindings into the C++ libtorch library in preparation for 0.2, which targets PyTorch's post-1.0 libtorch backend. 7 | 8 | The general approach is to use the generated `Declarations.yaml` spec, rather than header parsing, for code generation. 9 | 10 | ## Project Structure 11 | 12 | - `codegen/` - code generation; parses the `Declarations.yaml` spec from PyTorch and produces the `ffi/` contents 13 | - `deps/` - submodules for dependencies - libtorch, mklml, pytorch 14 | - `examples/` - high-level example models (xor mlp, typed cnn) 15 | - `ffi/` - low-level FFI bindings to libtorch 16 | - `hasktorch/` - higher-level user-facing library; calls into `ffi/`, used by `examples/` 17 | - `inline-c/` - submodule for the fork of inline-c (and inline-c-cpp) used for the C++ FFI 18 | - `spec/` - specification files used by `codegen/` 19 | 20 | ## Getting dependencies 21 | 22 | `deps/` holds several external dependencies that are retrieved using the `deps/get-deps.sh` script. 23 | 24 | This script should be run prior to building. 25 | 26 | ## XOR MLP Example 27 | 28 | The following steps build and run the XOR MLP example: 29 | 30 | ``` 31 | # Download libtorch-binary and other shared library dependencies 32 | pushd deps 33 | # For CPU 34 | ./get-deps.sh 35 | # For CUDA-9 36 | # ./get-deps.sh -a cu90 37 | # For CUDA-10 38 | # ./get-deps.sh -a cu100 39 | popd 40 | 41 | # Set shared library environment variables 42 | source setenv 43 | 44 | stack build examples 45 | 46 | stack exec xor_mlp 47 | ``` 48 | 49 | ## Running code generation 50 | 51 | Code generation is used to build the low-level FFI functions. 52 | 53 | Note that the code is already generated in this repo under `ffi`; running this is only needed if changes are being made to the code generation process. 54 | 55 | To run: 56 | 57 | ``` 58 | stack build codegen 59 | stack exec codegen-exe 60 | ``` 61 | 62 | To get CLI options: 63 | 64 | ``` 65 | stack exec codegen-exe -- --help 66 | ``` 67 | 68 | ## Additional Information 69 | 70 | See [the wiki](https://github.com/hasktorch/ffi-experimental/wiki) for developer information. 71 | 72 | ## Contributions 73 | 74 | Contributions/PRs are welcome.
75 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /cabal.project: -------------------------------------------------------------------------------- 1 | packages: 2 | codegen/*.cabal 3 | ffi/*.cabal 4 | hasktorch/*.cabal 5 | examples/*.cabal 6 | inline-c/inline-c/*.cabal 7 | inline-c/inline-c-cpp/*.cabal 8 | -------------------------------------------------------------------------------- /codegen/codegen.cabal: -------------------------------------------------------------------------------- 1 | name: codegen 2 | version: 0.1.0.0 3 | synopsis: parse torch yaml spec files, generate code 4 | -- description: 5 | homepage: https://github.com/githubuser/ffi-experimental#readme 6 | license: BSD3 7 | author: Austin Huang 8 | maintainer: hasktorch@gmail.com 9 | copyright: 2018 Austin Huang 10 | category: Codegen 11 | build-type: Simple 12 | cabal-version: >=1.10 13 | 14 | library 15 | exposed-modules: 16 | ParseDeclarations 17 | ParseDerivatives 18 | ParseFunctionSig 19 | ParseTuples 20 | ParseClass 21 | RenderDeclarations 22 | RenderCommon 23 | RenderTuples 24 | RenderClass 25 | RenderPure 26 | hs-source-dirs: src 27 | default-language: Haskell2010 28 | build-depends: 29 | base >= 4.7 && < 5 30 | , aeson >= 1.4.2.0 31 | , inline-c-cpp >= 0.3.0.1 32 | , megaparsec >= 7.0.4 33 | , show-prettyprint >= 0.2.2 34 | , yaml >= 0.11.0.0 35 | , shakespeare 36 | , text 37 | , string-conversions 38 | , directory 39 | , unordered-containers 40 | extra-libraries: stdc++ 41 | ghc-options: -Wall 42 | 43 | executable codegen-exe 44 | hs-source-dirs: exe 45 | main-is: Main.hs 46 | default-language: Haskell2010 47 | build-depends: 48 | base >= 4.7 && < 5 49 | , codegen 50 | , optparse-applicative >= 0.14.3.0 51 | extra-libraries: stdc++ 52 | 53 | test-suite doctests 54 | default-language: Haskell2010 55 | type: exitcode-stdio-1.0 56 | hs-source-dirs: test 57 | main-is: doctests.hs 58 | ghc-options: -Wall -threaded 59 | build-depends: 60 | base 61 | , doctest 62 | , megaparsec 63 | 64 | test-suite spec 65 | default-language: Haskell2010 66 | type: exitcode-stdio-1.0 67 | hs-source-dirs: test 68 | main-is: Spec.hs 69 | ghc-options: -Wall -threaded 70 | build-depends: base 71 | , codegen 72 | , hspec 73 | , yaml 74 | , safe-exceptions 75 | , directory 76 | , megaparsec 77 | 78 | -------------------------------------------------------------------------------- /codegen/exe/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveAnyClass #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | 5 | module Main where 6 | 7 | import qualified Options.Applicative as O 8 | import qualified ParseFunctionSig as F 9 | import qualified RenderDeclarations as RD 10 | import qualified RenderTuples as RTL 11 | import qualified RenderClass as RC 12 | import qualified RenderPure as RP 13 | 14 | {- CLI options -} 15 | 16 | data Options = Options 17 | { specFileDL :: !String 18 | , outputDir :: !String 19 | } deriving Show 20 | 21 | optsParser :: O.ParserInfo Options 22 | optsParser = O.info 23 | (O.helper <*> versionOption <*> programOptions) 24 | ( O.fullDesc <> O.progDesc "ffi codegen" <> O.header 25 | "codegen for hasktorch 0.0.2" 26 | ) 27 | 28 | versionOption :: O.Parser (a -> a) 29 | 
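-- `O.infoOption` defines a flag that aborts parsing and prints its text, so
-- `codegen-exe --version` prints "0.0.2" and exits, via the same mechanism as --help.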
versionOption = 30 | O.infoOption "0.0.2" (O.long "version" <> O.help "Show version") 31 | 32 | programOptions :: O.Parser Options 33 | programOptions = 34 | Options 35 | <$> O.strOption 36 | ( O.long "declaration-spec" 37 | <> O.short 'd' 38 | <> O.metavar "FILENAME" 39 | <> O.value "spec/Declarations.yaml" 40 | <> O.help "Specification file of Declarations" 41 | ) 42 | <*> O.strOption 43 | ( O.long "output-dir" 44 | <> O.short 'o' 45 | <> O.metavar "DIRNAME" 46 | <> O.value "output" 47 | <> O.help "Output-directory" 48 | ) 49 | 50 | main = do 51 | opts <- O.execParser optsParser 52 | -- RT.tensorBuilder 53 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/tensor.yaml" 54 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/intarray.yaml" 55 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/tensoroptions.yaml" 56 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/generator.yaml" 57 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/scalar.yaml" 58 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/sparsetensorref.yaml" 59 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/storage.yaml" 60 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/tensorlist.yaml" 61 | RC.decodeAndCodeGen (outputDir opts) "spec/cppclass/context.yaml" 62 | RTL.decodeAndCodeGen (outputDir opts) (specFileDL opts) 63 | RD.decodeAndCodeGen (outputDir opts) (specFileDL opts) 64 | RP.decodeAndCodeGen (outputDir opts) (specFileDL opts) "spec/bindings.yaml" 65 | pure () 66 | 67 | -------------------------------------------------------------------------------- /codegen/run-ghci.sh: -------------------------------------------------------------------------------- 1 | stack ghci --ghc-options='-fobject-code' --main-is ffi-experimental:exe:cpp-test -------------------------------------------------------------------------------- /codegen/src/ParseClass.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DeriveAnyClass #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE OverloadedStrings #-} 5 | 6 | module ParseClass where 7 | 8 | import GHC.Generics 9 | import Data.Yaml 10 | 11 | import qualified Data.Yaml as Y 12 | import Text.Show.Prettyprint (prettyPrint) 13 | import qualified ParseFunctionSig as S 14 | 15 | data CppClassSpec = CppClassSpec 16 | { signature :: String 17 | , cppname :: String 18 | , hsname :: String 19 | , constructors :: [S.Function] 20 | , methods :: [S.Function] 21 | , functions :: [S.Function] 22 | } deriving (Show, Eq, Generic) 23 | 24 | instance FromJSON CppClassSpec 25 | 26 | 27 | decodeAndPrint :: String -> IO () 28 | decodeAndPrint fileName = do 29 | file <- Y.decodeFileEither fileName :: IO (Either ParseException CppClassSpec) 30 | prettyPrint file 31 | 32 | trimSpace :: String -> String 33 | trimSpace [] = [] 34 | trimSpace (' ':xs) = trimSpace xs 35 | trimSpace (x:xs) = x:trimSpace xs 36 | 37 | hasSpace :: String -> Bool 38 | hasSpace [] = False 39 | hasSpace (' ':_) = True 40 | hasSpace (_:xs) = hasSpace xs 41 | 42 | hsnameWithoutSpace :: CppClassSpec -> String 43 | hsnameWithoutSpace typ_ = trimSpace $ hsname typ_ 44 | 45 | hsnameWithParens :: CppClassSpec -> String 46 | hsnameWithParens typ_ = if hasSpace name then "(" <> name <> ")" else name 47 | where 48 | name = hsname typ_ 49 | -------------------------------------------------------------------------------- /codegen/src/ParseDeclarations.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DeriveAnyClass #-} 
3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE OverloadedStrings #-} 5 | 6 | module ParseDeclarations where 7 | 8 | import GHC.Generics 9 | import Data.Yaml 10 | import qualified Data.Yaml as Y 11 | import Data.Aeson.Types (defaultOptions, fieldLabelModifier, genericParseJSON) 12 | import Text.Show.Prettyprint (prettyPrint) 13 | import qualified ParseFunctionSig as S 14 | 15 | {- Declarations.yaml -} 16 | {- --A example-- 17 | - name: _th_set_ 18 | matches_jit_signature: false 19 | schema_string: '' 20 | method_prefix_derived: '' 21 | arguments: 22 | - dynamic_type: Tensor 23 | name: self 24 | type: Tensor & 25 | - dynamic_type: Storage 26 | name: source 27 | type: Storage 28 | method_of: 29 | - Type 30 | - namespace 31 | mode: TH 32 | python_module: '' 33 | buffers: [] 34 | returns: 35 | - dynamic_type: Tensor 36 | name: self 37 | type: Tensor & 38 | inplace: true 39 | is_factory_method: false 40 | abstract: true 41 | requires_tensor: false 42 | device_guard: false 43 | with_gil: false 44 | deprecated: false 45 | -} 46 | 47 | data Type = Type 48 | { name' :: String 49 | , dynamic_type' :: S.Parsable 50 | , type' :: String 51 | , size' :: Maybe Int 52 | } deriving (Show, Eq, Generic) 53 | 54 | type2type :: Type -> S.Parsable 55 | type2type typ = 56 | case dynamic_type' typ of 57 | S.TenType S.Scalar -> if type' typ == "Tensor" then S.TenType S.Tensor else S.TenType S.Scalar 58 | S.TenType (S.IntList s) -> 59 | case size' typ of 60 | Nothing -> S.TenType (S.IntList {S.dim = s}) 61 | Just s' -> S.TenType (S.IntList {S.dim = Just [s']}) 62 | a -> a 63 | 64 | data Mode 65 | = TH 66 | | THC 67 | | NN 68 | | Native 69 | deriving (Show, Eq, Generic) 70 | 71 | data Declaration = Declaration 72 | { name :: String 73 | , matches_jit_signature :: Bool 74 | , schema_string :: String 75 | , method_prefix_derived :: String 76 | , arguments :: [Type] 77 | , method_of :: [String] 78 | , mode :: Mode 79 | , python_module :: String 80 | -- , buffers :: [String] 81 | , returns :: [Type] 82 | , inplace :: Bool 83 | , is_factory_method :: Maybe Bool 84 | , abstract :: Bool 85 | , requires_tensor :: Bool 86 | , device_guard :: Maybe Bool 87 | , with_gil :: Maybe Bool 88 | , deprecated :: Maybe Bool 89 | } deriving (Show, Eq, Generic) 90 | 91 | 92 | instance FromJSON Type where 93 | parseJSON = genericParseJSON defaultOptions{ fieldLabelModifier = reverse.(drop 1).reverse } 94 | 95 | instance FromJSON Mode where 96 | parseJSON (String "TH") = pure TH 97 | parseJSON (String "THC") = pure THC 98 | parseJSON (String "NN") = pure NN 99 | parseJSON (String "native") = pure Native 100 | parseJSON v = fail $ show v <> " is not a string of Mode." 
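-- A decoding sketch (hypothetical values): the fieldLabelModifier above strips the
-- trailing prime from the record fields, so a YAML fragment such as
--
--   dynamic_type: Tensor
--   name: self
--   type: Tensor &
--
-- fills Type with name' = "self", type' = "Tensor &", and size' = Nothing.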
101 | 102 | instance FromJSON Declaration 103 | 104 | 105 | decodeAndPrint :: String -> IO () 106 | decodeAndPrint fileName = do 107 | file <- Y.decodeFileEither fileName :: IO (Either ParseException [Declaration]) 108 | prettyPrint file 109 | -------------------------------------------------------------------------------- /codegen/src/ParseDerivatives.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DeriveAnyClass #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE OverloadedStrings #-} 5 | 6 | module ParseDerivatives where 7 | 8 | import GHC.Generics 9 | import Data.Yaml 10 | 11 | import qualified Data.Yaml as Y 12 | import Text.Show.Prettyprint (prettyPrint) 13 | 14 | {- derivatives.yaml -} 15 | 16 | data Derivative = Derivative { 17 | name :: String 18 | , self :: Maybe String 19 | , other :: Maybe String 20 | , tensor1 :: Maybe String 21 | , tensor2 :: Maybe String 22 | , tensors :: Maybe String 23 | , mat1 :: Maybe String 24 | , mat2 :: Maybe String 25 | , vec :: Maybe String 26 | , batch1 :: Maybe String 27 | , batch2 :: Maybe String 28 | , output_differentiability :: Maybe [Bool] 29 | , value :: Maybe String 30 | , exponent :: Maybe String 31 | , src :: Maybe String 32 | , grad_output :: Maybe String 33 | , weight :: Maybe String 34 | , bias :: Maybe String 35 | , input :: Maybe String 36 | , input2 :: Maybe String 37 | , input3 :: Maybe String 38 | , input_gates :: Maybe String 39 | , input_bias :: Maybe String 40 | , hidden_gates :: Maybe String 41 | , hidden_bias :: Maybe String 42 | , cx :: Maybe String 43 | , hx :: Maybe String 44 | , save_mean :: Maybe String 45 | , save_var :: Maybe String 46 | , grid :: Maybe String 47 | , i1 :: Maybe String 48 | , i2 :: Maybe String 49 | , i3 :: Maybe String 50 | } deriving (Show, Generic) 51 | 52 | instance FromJSON Derivative 53 | 54 | decodeAndPrint :: String -> IO () 55 | decodeAndPrint fileName = do 56 | file <- Y.decodeFileEither fileName :: IO (Either ParseException [Derivative]) 57 | prettyPrint file 58 | -------------------------------------------------------------------------------- /codegen/src/ParseTuples.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveAnyClass #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | 5 | module ParseTuples where 6 | 7 | import GHC.Generics 8 | import Data.Yaml 9 | 10 | import qualified ParseFunctionSig as S 11 | 12 | {- spec/tuples.yaml -} 13 | 14 | data Tuple = Tuple { 15 | types :: [S.Parsable] 16 | } deriving (Show, Eq, Generic) 17 | 18 | instance FromJSON Tuple 19 | 20 | -------------------------------------------------------------------------------- /codegen/src/RenderClass.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveAnyClass #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE QuasiQuotes #-} 7 | module RenderClass where 8 | 9 | import Data.Yaml (ParseException) 10 | import qualified Data.Yaml as Y 11 | import Text.Shakespeare.Text (st) 12 | import Data.Text (Text) 13 | import Data.String (fromString) 14 | import qualified Data.Text.IO as T 15 | import System.Directory (createDirectoryIfMissing) 16 | 17 | import qualified ParseClass as PC 18 | import RenderCommon 19 | 20 | renderImport :: Bool -> PC.CppClassSpec -> Text 21 | renderImport is_managed typ_ = if 
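-- (Unmanaged modules bind raw `Ptr`s through inline-c; managed modules wrap the same
-- calls behind `ForeignPtr`s whose finalizers run the generated delete functions, as
-- set up in `renderDestructor` below.)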
is_managed then [st| 22 | import Foreign.C.String 23 | import Foreign.C.Types 24 | import Foreign hiding (newForeignPtr) 25 | import Foreign.Concurrent 26 | import ATen.Type 27 | import ATen.Class 28 | import ATen.Cast 29 | import ATen.Unmanaged.Type.Generator 30 | import ATen.Unmanaged.Type.IntArray 31 | import ATen.Unmanaged.Type.Scalar 32 | import ATen.Unmanaged.Type.SparseTensorRef 33 | import ATen.Unmanaged.Type.Storage 34 | import ATen.Unmanaged.Type.Tensor 35 | import ATen.Unmanaged.Type.TensorList 36 | import ATen.Unmanaged.Type.TensorOptions 37 | import ATen.Unmanaged.Type.Tuple 38 | 39 | import qualified #{"ATen.Unmanaged.Type." <> (PC.hsnameWithoutSpace typ_)} as Unmanaged 40 | |] else [st| 41 | import qualified Language.C.Inline.Cpp as C 42 | import qualified Language.C.Inline.Cpp.Exceptions as C 43 | import qualified Language.C.Inline.Context as C 44 | import qualified Language.C.Types as C 45 | import qualified Data.Map as Map 46 | import Foreign.C.String 47 | import Foreign.C.Types 48 | import Foreign hiding (newForeignPtr) 49 | import Foreign.Concurrent 50 | import ATen.Type 51 | import ATen.Class 52 | 53 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 54 | 55 | C.include "<ATen/ATen.h>" 56 | C.include "<vector>" 57 | |] 58 | 59 | 60 | renderConstructors :: Bool -> PC.CppClassSpec -> Text 61 | renderConstructors is_managed typ_ = mconcat $ map (methodToCpp typ_ True is_managed True "" "") (PC.constructors typ_) 62 | 63 | renderDestructor :: Bool -> PC.CppClassSpec -> Text 64 | renderDestructor is_managed typ_ = if is_managed then "" else [st| 65 | delete#{PC.hsnameWithoutSpace typ_} :: Ptr #{PC.hsnameWithParens typ_} -> IO () 66 | delete#{PC.hsnameWithoutSpace typ_} object = #{bra}C.throwBlock| void { delete $(#{PC.cppname typ_}* object);}|#{cket} 67 | 68 | instance CppObject #{PC.hsnameWithParens typ_} where 69 | fromPtr ptr = newForeignPtr ptr (delete#{PC.hsnameWithoutSpace typ_} ptr) 70 | |] 71 | 72 | 73 | renderMethods :: Bool -> PC.CppClassSpec -> Text 74 | renderMethods is_managed typ_ = mconcat $ map (methodToCpp typ_ False is_managed True "" "") (PC.methods typ_) 75 | 76 | renderFunctions :: Bool -> PC.CppClassSpec -> Text 77 | renderFunctions is_managed typ_ = mconcat $ map (functionToCpp is_managed True "at::" "") (PC.functions typ_) 78 | 79 | decodeAndCodeGen :: String -> String -> IO () 80 | decodeAndCodeGen basedir fileName = do 81 | funcs <- Y.decodeFileEither fileName :: IO (Either ParseException PC.CppClassSpec) 82 | case funcs of 83 | Left err' -> print err' 84 | Right fns -> do 85 | createDirectoryIfMissing True (basedir <> "/ATen/Unmanaged/Type") 86 | T.writeFile (basedir <> "/ATen/Unmanaged/Type/" <> PC.hsnameWithoutSpace fns <> ".hs") $ 87 | template False ("ATen.Unmanaged.Type." <> fromString (PC.hsnameWithoutSpace fns)) fns 88 | createDirectoryIfMissing True (basedir <> "/ATen/Managed/Type") 89 | T.writeFile (basedir <> "/ATen/Managed/Type/" <> PC.hsnameWithoutSpace fns <> ".hs") $ 90 | template True ("ATen.Managed.Type."
<> fromString (PC.hsnameWithoutSpace fns)) fns 91 | 92 | 93 | template :: Bool -> Text -> PC.CppClassSpec -> Text 94 | template is_managed module_name types = [st| 95 | {-# LANGUAGE DataKinds #-} 96 | {-# LANGUAGE PolyKinds #-} 97 | {-# LANGUAGE TemplateHaskell #-} 98 | {-# LANGUAGE QuasiQuotes #-} 99 | {-# LANGUAGE ScopedTypeVariables #-} 100 | {-# LANGUAGE OverloadedStrings #-} 101 | {-# LANGUAGE TypeFamilies #-} 102 | {-# LANGUAGE FlexibleInstances #-} 103 | 104 | module #{module_name} where 105 | 106 | #{renderImport is_managed types} 107 | 108 | #{renderConstructors is_managed types} 109 | 110 | #{renderDestructor is_managed types} 111 | 112 | #{renderMethods is_managed types} 113 | 114 | #{renderFunctions is_managed types} 115 | |] 116 | 117 | -------------------------------------------------------------------------------- /codegen/src/RenderPure.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveAnyClass #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE QuasiQuotes #-} 7 | module RenderPure where 8 | 9 | import Control.Monad (forM_) 10 | import GHC.Generics 11 | import Data.Yaml (ParseException,FromJSON(..)) 12 | import qualified Data.Yaml as Y 13 | import Text.Shakespeare.Text (st) 14 | import Data.Text (Text) 15 | import Data.List (isPrefixOf, isSuffixOf, sort) 16 | import Data.Maybe (isJust) 17 | import qualified Data.Text.IO as T 18 | import System.Directory (createDirectoryIfMissing) 19 | import Data.Aeson.Types -- (defaultOptions, genericParseJSON, constructorTagModifier, sumEncoding(..)) 20 | 21 | import qualified ParseDeclarations as D 22 | import ParseFunctionSig as P 23 | import RenderCommon 24 | 25 | 26 | data Binding 27 | = BindRename { src :: String, dst :: String } 28 | | Bind { src :: String } 29 | | BindRemove { src :: String } 30 | deriving (Show, Eq, Generic) 31 | 32 | instance FromJSON Binding where 33 | parseJSON = genericParseJSON defaultOptions{ 34 | sumEncoding = ObjectWithSingleField, 35 | allNullaryToStringTag = True, 36 | constructorTagModifier = \tag -> 37 | case tag of 38 | "BindRename" -> "rename" 39 | "Bind" -> "bind" 40 | "BindRemove" -> "remove" 41 | a -> a 42 | } 43 | 44 | 45 | 46 | toFunction :: D.Declaration -> P.Function 47 | toFunction dl = P.Function 48 | { P.name = D.name dl 49 | , P.parameters = map (\a -> P.Parameter (D.type2type a) (D.name' a) Nothing) $ D.arguments dl 50 | , P.retType = case D.returns dl of 51 | [a] -> D.type2type a 52 | ax -> P.Tuple $ map D.type2type ax 53 | , P.variant = P.VFunction 54 | } 55 | 56 | renderFunctions :: [(String, D.Declaration)] -> Text 57 | renderFunctions nfs = mconcat $ flip map nfs $ \(n,nf) -> pureFunction n (toFunction nf) 58 | 59 | isRemove :: Binding -> Bool 60 | isRemove (BindRemove _) = True 61 | isRemove _ = False 62 | 63 | isRename :: Binding -> Bool 64 | isRename (BindRename _ _) = True 65 | isRename _ = False 66 | 67 | removeBinding :: Binding -> (String, D.Declaration) -> Bool 68 | removeBinding (BindRemove n) (hsName, _) = n == hsName 69 | removeBinding _ _ = False 70 | 71 | removeBinding' :: [Binding] -> (String, D.Declaration) -> Bool 72 | removeBinding' bindings decl = any (\b -> removeBinding b decl) bindings 73 | 74 | removeFilter :: [Binding] -> [(String, D.Declaration)] -> [(String, D.Declaration)] 75 | removeFilter bindings fns = filter (\v -> not (removeBinding' bindings' v)) fns 76 | where 77 | bindings' 
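-- (Entries in spec/bindings.yaml use the ObjectWithSingleField encoding declared
-- above; as a sketch with hypothetical names: `- remove: {src: someName}` or
-- `- rename: {src: oldName, dst: newName}`.)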
= filter isRemove bindings 78 | 79 | renameBinding :: Binding -> (String, D.Declaration) -> Maybe (String, D.Declaration) 80 | renameBinding (BindRename n new_name) (hsName, decl) = 81 | if n == hsName then Just (new_name,decl) else Nothing 82 | renameBinding _ _ = Nothing 83 | 84 | renameBinding' :: [Binding] -> (String, D.Declaration) -> (String, D.Declaration) 85 | renameBinding' bindings decl@(_,d) = 86 | case foldl (\i b -> let v = (renameBinding b decl) in if isJust v then v else i) Nothing bindings of 87 | Just v -> v 88 | Nothing -> (D.name d,d) 89 | 90 | renameFilter :: [Binding] -> [(String, D.Declaration)] -> [(String, D.Declaration)] 91 | renameFilter bindings fns = map (renameBinding' bindings') fns 92 | where 93 | bindings' = filter isRename bindings 94 | 95 | nativeFunctionsFilter :: [D.Declaration] -> [Binding] -> [(String, D.Declaration)] 96 | nativeFunctionsFilter fns bindings = 97 | filter (\(_,a) -> 98 | D.mode a == D.Native && 99 | "namespace" `elem` (D.method_of a) && 100 | D.is_factory_method a == Nothing && 101 | not (isPrefixOf "_" (D.name a)) && 102 | not (isSuffixOf "_" (D.name a)) && 103 | not (isSuffixOf "_out" (D.name a)) && 104 | all (/= P.Ptr P.GeneratorType) (map D.dynamic_type' (D.arguments a)) 105 | ) $ 106 | renameFilter bindings $ 107 | removeFilter bindings $ 108 | map (\f -> (getSignatures (toFunction f),f)) fns 109 | 110 | notUniqList :: [String] -> [String] 111 | notUniqList lst = notUniq (sort lst) [] 112 | where 113 | notUniq [] a = a 114 | notUniq (x:y:xs) ys = if x == y then notUniq xs (y:ys) else (notUniq (y:xs) ys) 115 | notUniq _ b = b 116 | 117 | decodeAndCodeGen :: String -> String -> String -> IO () 118 | decodeAndCodeGen basedir yamlSpecFileName bindingsFileName = do 119 | funcs <- Y.decodeFileEither yamlSpecFileName :: IO (Either ParseException [D.Declaration]) 120 | bindings <- Y.decodeFileEither bindingsFileName :: IO (Either ParseException [Binding]) 121 | case (funcs,bindings) of 122 | (Left err', _) -> print err' 123 | (Right _ , Left err') -> print err' 124 | (Right fns, Right bnd) -> do 125 | createDirectoryIfMissing True (basedir <> "/Torch/Functions/") 126 | let l = nativeFunctionsFilter fns bnd 127 | 128 | case notUniqList (map fst l) of 129 | [] -> do 130 | T.writeFile (basedir <> "/Torch/Functions/Native.hs") $ 131 | template "Torch.Functions.Native" $ 132 | renderFunctions l 133 | xs -> do 134 | putStrLn "---Duplicated functions are as follows. 
----" 135 | forM_ xs $ \x -> do 136 | putStrLn x 137 | putStrLn "---To generate these functions, add the following entries in spec/bindings.yaml ----" 138 | forM_ (filter (\(i,_) -> i `elem` xs) l) $ \(_,x) -> do 139 | putStrLn $ "- remove: {src: "<> getSignatures (toFunction x) <>"}" 140 | 141 | 142 | renderImport :: Text 143 | renderImport = [st| 144 | import System.IO.Unsafe 145 | import Foreign.ForeignPtr 146 | 147 | import qualified ATen.Managed.Native as ATen 148 | import qualified ATen.Managed.Type.Tensor as ATen 149 | import qualified ATen.Managed.Type.Scalar as ATen 150 | import qualified ATen.Managed.Type.Tuple as ATen 151 | import qualified ATen.Const as ATen 152 | import qualified ATen.Type as ATen 153 | import qualified ATen.Managed.Cast 154 | import ATen.Cast 155 | 156 | import Torch.Tensor 157 | import Torch.Scalar 158 | |] 159 | 160 | template :: Text -> Text -> Text 161 | template module_name functions = [st| 162 | -- generated by using spec/Declarations.yaml 163 | 164 | {-# LANGUAGE DataKinds #-} 165 | {-# LANGUAGE PolyKinds #-} 166 | {-# LANGUAGE TemplateHaskell #-} 167 | {-# LANGUAGE QuasiQuotes #-} 168 | {-# LANGUAGE ScopedTypeVariables #-} 169 | {-# LANGUAGE OverloadedStrings #-} 170 | 171 | module #{module_name} where 172 | 173 | #{renderImport} 174 | #{functions} 175 | |] 176 | -------------------------------------------------------------------------------- /codegen/src/RenderTuples.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveAnyClass #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE QuasiQuotes #-} 7 | module RenderTuples where 8 | 9 | import Data.Yaml (ParseException) 10 | import qualified Data.Yaml as Y 11 | import Text.Shakespeare.Text (st) 12 | import Data.Text (Text) 13 | import Data.Maybe (mapMaybe) 14 | import Data.List (nubBy) 15 | import qualified Data.Text.IO as T 16 | import qualified Data.Text as T 17 | import System.Directory (createDirectoryIfMissing) 18 | 19 | import qualified ParseTuples as PT 20 | import qualified ParseDeclarations as D 21 | import ParseFunctionSig as P 22 | import RenderCommon 23 | 24 | 25 | renderImport :: Bool -> Text 26 | renderImport is_managed = if is_managed then [st| 27 | import Foreign.C.String 28 | import Foreign.C.Types 29 | import Foreign hiding (newForeignPtr) 30 | import Foreign.Concurrent 31 | import ATen.Type 32 | import ATen.Class 33 | import ATen.Cast 34 | 35 | import qualified ATen.Unmanaged.Type.Tuple as Unmanaged 36 | import ATen.Unmanaged.Type.Generator 37 | import ATen.Unmanaged.Type.IntArray 38 | import ATen.Unmanaged.Type.Scalar 39 | import ATen.Unmanaged.Type.SparseTensorRef 40 | import ATen.Unmanaged.Type.Storage 41 | import ATen.Unmanaged.Type.Tensor 42 | import ATen.Unmanaged.Type.TensorList 43 | import ATen.Unmanaged.Type.TensorOptions 44 | import ATen.Unmanaged.Type.Tuple 45 | |] else [st| 46 | import Foreign.C.String 47 | import Foreign.C.Types 48 | import Foreign hiding (newForeignPtr) 49 | import Foreign.Concurrent 50 | import ATen.Type 51 | import ATen.Class 52 | 53 | import qualified Language.C.Inline.Cpp as C 54 | import qualified Language.C.Inline.Cpp.Exceptions as C 55 | import qualified Language.C.Inline.Context as C 56 | import qualified Language.C.Types as C 57 | import qualified Data.Map as Map 58 | 59 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 60 | 61 | C.include "<ATen/ATen.h>" 62 | |] 63 | 64 |
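-- The helpers below render a parsed tuple into its C++ and Haskell spellings. As a
-- sketch (assuming `parsableToCppType` renders a Tensor as "at::Tensor"), a
-- two-tensor tuple becomes "std::tuple<at::Tensor,at::Tensor>" on the C++ side and
-- "(Tensor,Tensor)" on the Haskell side.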
tupleToCpp :: PT.Tuple -> Text 65 | tupleToCpp (PT.Tuple parsables) = [st|std::tuple<#{T.intercalate "," (map parsableToCppType parsables)}>|] 66 | 67 | tupleToHs :: PT.Tuple -> Text 68 | tupleToHs (PT.Tuple parsables) = [st|(#{T.intercalate "," (map parsableToHsType parsables)})|] 69 | 70 | tupleToHs' :: PT.Tuple -> Text 71 | tupleToHs' (PT.Tuple parsables) = [st|#{T.intercalate "" (map parsableToHsType parsables)}|] 72 | 73 | toHs :: P.Parsable -> Text 74 | toHs typ_ = 75 | if isCType typ_ 76 | then [st|#{parsableToHsType typ_}|] 77 | else [st|Ptr #{parsableToHsType typ_}|] 78 | 79 | toManagedHs :: P.Parsable -> Text 80 | toManagedHs typ_ = 81 | if isCType typ_ 82 | then [st|#{parsableToHsType typ_}|] 83 | else [st|ForeignPtr #{parsableToHsType typ_}|] 84 | 85 | 86 | toCpp :: P.Parsable -> Text 87 | toCpp typ_ = 88 | if isCType typ_ 89 | then [st|#{parsableToCppType typ_}|] 90 | else [st|#{parsableToCppType typ_}*|] 91 | 92 | toCpp' :: P.Parsable -> Text 93 | toCpp' typ_ = 94 | if isCType typ_ 95 | then [st||] 96 | else [st|new #{parsableToCppType typ_}|] 97 | 98 | renderCppObject :: PT.Tuple -> Text 99 | renderCppObject typ_ = [st| 100 | 101 | -----------------#{tupleToHs typ_}--------------------- 102 | 103 | delete#{tupleToHs' typ_} :: Ptr #{tupleToHs typ_} -> IO () 104 | delete#{tupleToHs' typ_} ptr = #{bra}C.throwBlock| void { delete $(#{tupleToCpp typ_}* ptr); return; }|#{cket} 105 | 106 | instance CppObject #{tupleToHs typ_} where 107 | fromPtr ptr = newForeignPtr ptr (delete#{tupleToHs' typ_} ptr) 108 | |] 109 | 110 | renderCppTuple2 :: PT.Tuple -> Text 111 | renderCppTuple2 typ_@(PT.Tuple (a:b:_)) = [st| 112 | instance CppTuple2 (Ptr #{tupleToHs typ_}) where 113 | type A (Ptr #{tupleToHs typ_}) = #{toHs a} 114 | type B (Ptr #{tupleToHs typ_}) = #{toHs b} 115 | get0 v = #{bra}C.throwBlock| #{toCpp a} { return #{toCpp' a}(std::get<0>(*$(#{tupleToCpp typ_}* v)));}|#{cket} 116 | get1 v = #{bra}C.throwBlock| #{toCpp b} { return #{toCpp' b}(std::get<1>(*$(#{tupleToCpp typ_}* v)));}|#{cket} 117 | |] 118 | renderCppTuple2 _ = "" 119 | 120 | renderCppTuple3 :: PT.Tuple -> Text 121 | renderCppTuple3 typ_@(PT.Tuple (_:_:c:_)) = [st| 122 | instance CppTuple3 (Ptr #{tupleToHs typ_}) where 123 | type C (Ptr #{tupleToHs typ_}) = #{toHs c} 124 | get2 v = #{bra}C.throwBlock| #{toCpp c} { return #{toCpp' c}(std::get<2>(*$(#{tupleToCpp typ_}* v)));}|#{cket} 125 | |] 126 | renderCppTuple3 _ = "" 127 | 128 | renderCppTuple4 :: PT.Tuple -> Text 129 | renderCppTuple4 typ_@(PT.Tuple (_:_:_:d:_)) = [st| 130 | instance CppTuple4 (Ptr #{tupleToHs typ_}) where 131 | type D (Ptr #{tupleToHs typ_}) = #{toHs d} 132 | get3 v = #{bra}C.throwBlock| #{toCpp d} { return #{toCpp' d}(std::get<3>(*$(#{tupleToCpp typ_}* v)));}|#{cket} 133 | |] 134 | renderCppTuple4 _ = "" 135 | 136 | 137 | renderCppTuple5 :: PT.Tuple -> Text 138 | renderCppTuple5 typ_@(PT.Tuple (_:_:_:_:e:_)) = [st| 139 | instance CppTuple5 (Ptr #{tupleToHs typ_}) where 140 | type E (Ptr #{tupleToHs typ_}) = #{toHs e} 141 | get4 v = #{bra}C.throwBlock| #{toCpp e} { return #{toCpp' e}(std::get<4>(*$(#{tupleToCpp typ_}* v)));}|#{cket} 142 | |] 143 | renderCppTuple5 _ = "" 144 | 145 | renderManagedCppTuple2 :: PT.Tuple -> Text 146 | renderManagedCppTuple2 typ_@(PT.Tuple (a:b:_)) = [st| 147 | instance CppTuple2 (ForeignPtr #{tupleToHs typ_}) where 148 | type A (ForeignPtr #{tupleToHs typ_}) = #{toManagedHs a} 149 | type B (ForeignPtr #{tupleToHs typ_}) = #{toManagedHs b} 150 | get0 v = cast1 (get0 :: Ptr #{tupleToHs typ_} -> IO (#{toHs a})) v 151 | get1 v 
= cast1 (get1 :: Ptr #{tupleToHs typ_} -> IO (#{toHs b})) v 152 | |] 153 | renderManagedCppTuple2 _ = "" 154 | 155 | renderManagedCppTuple3 :: PT.Tuple -> Text 156 | renderManagedCppTuple3 typ_@(PT.Tuple (_:_:c:_)) = [st| 157 | instance CppTuple3 (ForeignPtr #{tupleToHs typ_}) where 158 | type C (ForeignPtr #{tupleToHs typ_}) = #{toManagedHs c} 159 | get2 v = cast1 (get2 :: Ptr #{tupleToHs typ_} -> IO (#{toHs c})) v 160 | |] 161 | renderManagedCppTuple3 _ = "" 162 | 163 | renderManagedCppTuple4 :: PT.Tuple -> Text 164 | renderManagedCppTuple4 typ_@(PT.Tuple (_:_:_:d:_)) = [st| 165 | instance CppTuple4 (ForeignPtr #{tupleToHs typ_}) where 166 | type D (ForeignPtr #{tupleToHs typ_}) = #{toManagedHs d} 167 | get3 v = cast1 (get3 :: Ptr #{tupleToHs typ_} -> IO (#{toHs d})) v 168 | |] 169 | renderManagedCppTuple4 _ = "" 170 | 171 | 172 | renderManagedCppTuple5 :: PT.Tuple -> Text 173 | renderManagedCppTuple5 typ_@(PT.Tuple (_:_:_:_:e:_)) = [st| 174 | instance CppTuple5 (ForeignPtr #{tupleToHs typ_}) where 175 | type E (ForeignPtr #{tupleToHs typ_}) = #{toManagedHs e} 176 | get4 v = cast1 (get4 :: Ptr #{tupleToHs typ_} -> IO (#{toHs e})) v 177 | |] 178 | renderManagedCppTuple5 _ = "" 179 | 180 | 181 | 182 | renderTuples :: Bool -> [PT.Tuple] -> Text 183 | renderTuples True [] = "" 184 | renderTuples True (x:xs) = 185 | renderManagedCppTuple2 x <> 186 | renderManagedCppTuple3 x <> 187 | renderManagedCppTuple4 x <> 188 | renderManagedCppTuple5 x <> 189 | renderTuples True xs 190 | renderTuples False [] = "" 191 | renderTuples False (x:xs) = 192 | renderCppObject x <> 193 | renderCppTuple2 x <> 194 | renderCppTuple3 x <> 195 | renderCppTuple4 x <> 196 | renderCppTuple5 x <> 197 | renderTuples False xs 198 | 199 | 200 | decodeAndCodeGen :: String -> String -> IO () 201 | decodeAndCodeGen basedir fileName = do 202 | maybe_decls <- Y.decodeFileEither fileName :: IO (Either ParseException [D.Declaration]) 203 | --funcs <- Y.decodeFileEither fileName :: IO (Either ParseException [PT.Tuple]) 204 | case maybe_decls of 205 | Left err' -> print err' 206 | Right decls -> do 207 | let tuples = nubBy tupleHsTypeEq $ mapMaybe (getTupleType . 
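-- (one Tuple per multi-value return; deduplicated below by the rendered Haskell
-- type, so each tuple shape gets a single set of instances)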
D.returns) decls 208 | createDirectoryIfMissing True (basedir <> "/ATen/Unmanaged/Type") 209 | T.writeFile (basedir <> "/ATen/Unmanaged/Type/Tuple.hs") $ 210 | template False "ATen.Unmanaged.Type.Tuple" tuples 211 | createDirectoryIfMissing True (basedir <> "/ATen/Managed/Type") 212 | T.writeFile (basedir <> "/ATen/Managed/Type/Tuple.hs") $ 213 | template True "ATen.Managed.Type.Tuple" tuples 214 | where 215 | getTupleType :: [D.Type] -> Maybe PT.Tuple 216 | getTupleType [] = Nothing 217 | getTupleType [_] = Nothing 218 | getTupleType rets = Just $ PT.Tuple $ map D.type2type rets 219 | 220 | tupleHsTypeEq :: PT.Tuple -> PT.Tuple -> Bool 221 | tupleHsTypeEq a b = (fmap parsableToHsType (PT.types a)) == (fmap parsableToHsType (PT.types b)) 222 | 223 | 224 | template :: Bool -> Text -> [PT.Tuple] -> Text 225 | template is_managed module_name types = [st| 226 | -- generated by using spec/tuples.yaml 227 | 228 | {-# LANGUAGE DataKinds #-} 229 | {-# LANGUAGE PolyKinds #-} 230 | {-# LANGUAGE TemplateHaskell #-} 231 | {-# LANGUAGE QuasiQuotes #-} 232 | {-# LANGUAGE ScopedTypeVariables #-} 233 | {-# LANGUAGE OverloadedStrings #-} 234 | {-# LANGUAGE TypeFamilies #-} 235 | {-# LANGUAGE FlexibleInstances #-} 236 | 237 | module #{module_name} where 238 | 239 | #{renderImport is_managed} 240 | 241 | #{renderTuples is_managed types} 242 | |] 243 | 244 | -------------------------------------------------------------------------------- /codegen/test/Spec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE LambdaCase #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE TypeApplications #-} 4 | module Main where 5 | 6 | import Control.Exception.Safe (throwString, throw) 7 | import Data.Proxy 8 | import Text.Megaparsec (parse, errorBundlePretty) 9 | import ParseDerivatives (Derivative) 10 | import ParseDeclarations (Declaration) 11 | import System.Directory (doesFileExist) 12 | import Test.Hspec 13 | import qualified Data.Yaml as Y 14 | import ParseFunctionSig 15 | 16 | main :: IO () 17 | main = hspec $ do 18 | describe "parsing derivatives.yaml" $ do 19 | describe "Derivatives Spec" derivativesSpec 20 | describe "parsing Declarations.yaml" $ do 21 | describe "Declarations Spec" declarationsSpec 22 | 23 | 24 | derivativesPath :: FilePath 25 | derivativesPath = "../deps/pytorch/tools/autograd/derivatives.yaml" 26 | 27 | derivativesSpec :: Spec 28 | derivativesSpec = do 29 | xs <- runIO $ vanillaParse derivativesPath 30 | 31 | it "parses the same number of stringy functions as a vanilla parsing" $ do 32 | fs <- parseWith (Proxy @ Derivative) 33 | (length fs) `shouldBe` (length xs) 34 | 35 | where 36 | parseWith :: forall funtype . Y.FromJSON funtype => Proxy funtype -> IO [funtype] 37 | parseWith _ = do 38 | Y.decodeFileEither derivativesPath >>= \case 39 | Left exception -> throw exception 40 | Right (fs::[funtype]) -> pure fs 41 | 42 | vanillaParse :: FilePath -> IO [Y.Value] 43 | vanillaParse fp = do 44 | doesFileExist fp >>= \case 45 | False -> throwString $ "Spec " ++ fp ++ " doesn't exist! 
Review README to get spec yaml" 46 | True -> Y.decodeFileThrow fp 47 | 48 | 49 | declarationsPath :: FilePath 50 | declarationsPath = "../spec/Declarations.yaml" 51 | 52 | declarationsSpec :: Spec 53 | declarationsSpec = do 54 | xs <- runIO $ vanillaParse declarationsPath 55 | 56 | it "parses the same number of stringy functions as a vanilla parsing" $ do 57 | fs <- parseWith (Proxy @ Declaration) 58 | (length fs) `shouldBe` (length xs) 59 | 60 | where 61 | parseWith :: forall funtype . Y.FromJSON funtype => Proxy funtype -> IO [funtype] 62 | parseWith _ = do 63 | Y.decodeFileEither declarationsPath >>= \case 64 | Left exception -> throw exception 65 | Right (fs::[funtype]) -> pure fs 66 | 67 | -------------------------------------------------------------------------------- /codegen/test/doctests.hs: -------------------------------------------------------------------------------- 1 | module Main where 2 | 3 | import Test.DocTest 4 | 5 | main :: IO () 6 | main = do 7 | doctest $ 8 | [ 9 | "-XOverloadedStrings", 10 | "src/ParseFunctionSig.hs" 11 | ] 12 | -------------------------------------------------------------------------------- /deps/.gitignore: -------------------------------------------------------------------------------- 1 | libtorch/ 2 | mklml_mac_2019.0.1.20181227/ 3 | -------------------------------------------------------------------------------- /deps/get-deps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # - gets submodules recursively to get pytorch and inline-c fork repo dependencies 4 | # - Retrieves a prebuilt libtorch binary per https://pytorch.org/cppdocs/installing.html 5 | # - Retrieves a release binary for mkl https://github.com/intel/mkl-dnn/releases 6 | # which is a runtime dependency that is not packaged w/ libtorch 7 | 8 | set -eu 9 | 10 | usage_exit() { 11 | echo "Usage: $0 [-n] [-c] [-a "cpu" or "cu90" or "cu100"] [-s]" 1>&2 12 | echo " -n # Use nightly libtorch w/ -n" 1>&2 13 | echo " # Use libtorch-1.1.0 w/o -n" 1>&2 14 | echo "" 1>&2 15 | echo " -c # Download libtorch from hasktorch's site w/ -c" 1>&2 16 | echo " # Download libtorch from pytorch's site w/o -c" 1>&2 17 | echo "" 1>&2 18 | echo " -a cpu # Use CPU without CUDA" 1>&2 19 | echo " -a cu90 # Use CUDA-9" 1>&2 20 | echo " -a cu100 # Use CUDA-10" 1>&2 21 | echo "" 1>&2 22 | echo " -s # Skip download" 1>&2 23 | echo "" 1>&2 24 | echo " -h # Show this help" 1>&2 25 | exit 1 26 | } 27 | 28 | USE_NIGHTLY=0 29 | USE_BINARY_FOR_CI=0 30 | COMPUTE_ARCH=cpu 31 | SKIP_DOWNLOAD=0 32 | 33 | while getopts nca:sh OPT 34 | do 35 | case $OPT in 36 | n) USE_NIGHTLY=1 37 | ;; 38 | c) USE_BINARY_FOR_CI=1 39 | ;; 40 | a) COMPUTE_ARCH=$OPTARG 41 | ;; 42 | s) SKIP_DOWNLOAD=1 43 | ;; 44 | h) usage_exit 45 | ;; 46 | \?)
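# Unrecognized flags print the usage message and exit.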
usage_exit 47 | ;; 48 | esac 49 | done 50 | 51 | if [ "$SKIP_DOWNLOAD" = 0 ] ; then 52 | git submodule update --init --recursive 53 | 54 | case "$(uname)" in 55 | "Darwin") 56 | if [ "$USE_NIGHTLY" = 1 ] ; then 57 | wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-macos-latest.zip 58 | unzip libtorch-macos-latest.zip 59 | rm libtorch-macos-latest.zip 60 | elif [ "$USE_BINARY_FOR_CI" = 1 ] ; then 61 | wget https://github.com/hasktorch/libtorch-binary-for-ci/releases/download/1.1.0/cpu-libtorch-macos-latest.zip 62 | unzip cpu-libtorch-macos-latest.zip 63 | rm cpu-libtorch-macos-latest.zip 64 | else 65 | wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-1.1.0.zip 66 | unzip libtorch-macos-1.1.0.zip 67 | rm libtorch-macos-1.1.0.zip 68 | fi 69 | wget https://github.com/intel/mkl-dnn/releases/download/v0.17.2/mklml_mac_2019.0.1.20181227.tgz 70 | tar -xzf mklml_mac_2019.0.1.20181227.tgz 71 | rm -f mklml_mac_2019.0.1.20181227.tgz 72 | rm -f mklml_mac_2019.0.1.20181227.tgz.1 73 | rm -rf mklml 74 | mv mklml_mac_2019.0.1.20181227 mklml 75 | ;; 76 | "Linux") 77 | if [ "$USE_NIGHTLY" = 1 ] ; then 78 | wget https://download.pytorch.org/libtorch/nightly/${COMPUTE_ARCH}/libtorch-shared-with-deps-latest.zip 79 | unzip libtorch-shared-with-deps-latest.zip 80 | rm libtorch-shared-with-deps-latest.zip 81 | elif [ "$USE_BINARY_FOR_CI" = 1 ] ; then 82 | wget https://github.com/hasktorch/libtorch-binary-for-ci/releases/download/1.1.0/${COMPUTE_ARCH}-libtorch-shared-with-deps-latest.zip 83 | unzip ${COMPUTE_ARCH}-libtorch-shared-with-deps-latest.zip 84 | rm ${COMPUTE_ARCH}-libtorch-shared-with-deps-latest.zip 85 | else 86 | wget https://download.pytorch.org/libtorch/${COMPUTE_ARCH}/libtorch-shared-with-deps-1.1.0.zip 87 | unzip libtorch-shared-with-deps-1.1.0.zip 88 | rm libtorch-shared-with-deps-1.1.0.zip 89 | fi 90 | wget https://github.com/intel/mkl-dnn/releases/download/v0.17.2/mklml_lnx_2019.0.1.20181227.tgz 91 | tar -xzf mklml_lnx_2019.0.1.20181227.tgz 92 | rm -f mklml_lnx_2019.0.1.20181227.tgz 93 | rm -f mklml_lnx_2019.0.1.20181227.tgz.1 94 | rm -rf mklml 95 | mv mklml_lnx_2019.0.1.20181227 mklml 96 | ln -s libmklml_intel.so mklml/lib/libmklml.so 97 | ;; 98 | esac 99 | fi 100 | 101 | # Following codes are copied from pytorch/tools/run-clang-tidy-in-ci.sh. 102 | # Generate ATen files. 103 | 104 | echo "Generate ATen files." 105 | pushd pytorch 106 | 107 | if [[ ! 
-d build ]]; then 108 | mkdir build 109 | fi 110 | 111 | python aten/src/ATen/gen.py \ 112 | -s aten/src/ATen \ 113 | -d build/aten/src/ATen \ 114 | aten/src/ATen/Declarations.cwrap \ 115 | aten/src/THNN/generic/THNN.h \ 116 | aten/src/THCUNN/generic/THCUNN.h \ 117 | aten/src/ATen/nn.yaml \ 118 | aten/src/ATen/native/native_functions.yaml 119 | 120 | # Sanitize "name: n" fields to be strings rather than booleans in Declarations.yaml 121 | 122 | case "$(uname)" in 123 | "Darwin") 124 | sed -i '' -e "s/ name: n$/ name: 'n'/g" -e "s/ name: N$/ name: 'N'/g" build/aten/src/ATen/Declarations.yaml 125 | ;; 126 | "Linux") 127 | sed -i -e "s/ name: n$/ name: 'n'/g" -e "s/ name: N$/ name: 'N'/g" build/aten/src/ATen/Declarations.yaml 128 | ;; 129 | esac 130 | 131 | popd 132 | -------------------------------------------------------------------------------- /examples/cnn/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE KindSignatures #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE TypeApplications #-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE TypeFamilies #-} 6 | {-# LANGUAGE PolyKinds #-} 7 | {-# LANGUAGE NoStarIsType #-} 8 | {-# LANGUAGE TypeOperators #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-# LANGUAGE RecordWildCards #-} 11 | {-# LANGUAGE PartialTypeSignatures #-} 12 | {-# LANGUAGE DuplicateRecordFields #-} 13 | 14 | module Main where 15 | 16 | import GHC.TypeLits 17 | import Data.Proxy 18 | 19 | import qualified Torch.Tensor as D 20 | import qualified Torch.Autograd as A 21 | import qualified Torch.DType as DType 22 | import Torch.Static 23 | 24 | -------------------------------------------------------------------------------- 25 | 26 | data Conv2d dtype (in_features :: Nat) (out_features :: Nat) 27 | (kernel_size :: (Nat, Nat)) 28 | (stride :: (Nat, Nat)) 29 | (padding :: (Nat, Nat)) = 30 | Conv2d { weight :: Tensor dtype '[out_features, in_features, Fst kernel_size, Snd kernel_size] 31 | , bias :: Tensor dtype '[out_features] } 32 | 33 | 34 | -- The constraints on this one are _very_ involved, so the partial signatures 35 | -- make the code significantly cleaner. 36 | conv2d :: forall stride padding. 37 | _ => Conv2d _ _ _ _ stride padding -> Tensor _ _ -> Tensor _ _ 38 | conv2d Conv2d{..} input = conv2dBias @stride @padding input weight bias 39 | 40 | -------------------------------------------------------------------------------- 41 | 42 | data Linear dtype (in_features :: Nat) (out_features :: Nat) = 43 | Linear { weight :: Tensor dtype '[in_features, out_features] 44 | , bias :: Tensor dtype '[out_features] 45 | } 46 | 47 | linear :: Linear dtype in_features out_features -> 48 | Tensor dtype [n, in_features] -> 49 | Tensor dtype [n, out_features] 50 | linear Linear{..} input = add (mm input weight) bias 51 | 52 | -------------------------------------------------------------------------------- 53 | 54 | type NoPadding = '(0, 0) 55 | type NoStrides = '(1, 1) 56 | 57 | data Model dtype = Model { conv1 :: Conv2d dtype 1 20 '(5, 5) NoStrides NoPadding 58 | , conv2 :: Conv2d dtype 20 50 '(5, 5) NoStrides NoPadding 59 | , fc1 :: Linear dtype (4*4*50) 500 60 | , fc2 :: Linear dtype 500 10 61 | } 62 | 63 | model :: forall dtype n. 
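-- Shape bookkeeping behind fc1's 4*4*50 input width: 28x28 -(conv1 5x5)-> 24x24
-- -(maxPool 2x2)-> 12x12 -(conv2 5x5)-> 8x8 -(maxPool 2x2)-> 4x4, over 50 channels.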
_ => Model dtype -> Tensor dtype [n, 1, 28, 28] -> Tensor dtype [n, 10] 64 | model Model{..} x = output 65 | where 66 | c1 = relu $ conv2d conv1 x 67 | p1 = maxPool2d @'(2, 2) @'(2, 2) @NoPadding c1 68 | c2 = relu $ conv2d conv2 p1 69 | p2 = maxPool2d @'(2, 2) @'(2, 2) @NoPadding c2 70 | flat = reshape @'[n, 4*4*50] p2 71 | f1 = relu $ linear fc1 flat 72 | logits = linear fc2 f1 73 | output = logSoftmax logits 1 74 | 75 | 76 | main = undefined 77 | -------------------------------------------------------------------------------- /examples/elman/Elman.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | 6 | module Elman where 7 | 8 | import Torch.Tensor 9 | import Torch.DType 10 | import Torch.TensorFactories 11 | import Torch.Functions 12 | import Torch.TensorOptions 13 | import Torch.Autograd 14 | import Torch.NN 15 | 16 | import Control.Monad.State.Strict 17 | import Data.List (foldl', scanl', intersperse) 18 | 19 | import RecurrentLayer 20 | 21 | 22 | data ElmanSpec = ElmanSpec { in_features :: Int, hidden_features :: Int } 23 | 24 | data ElmanCell = ElmanCell { 25 | input_weight :: Parameter, 26 | hidden_weight :: Parameter, 27 | bias :: Parameter 28 | } 29 | 30 | 31 | instance RecurrentCell ElmanCell where 32 | 33 | nextState ElmanCell{..} input hidden = 34 | gate input hidden Torch.Functions.tanh input_weight hidden_weight bias 35 | 36 | 37 | instance Randomizable ElmanSpec ElmanCell where 38 | sample ElmanSpec{..} = do 39 | w_ih <- makeIndependent =<< randn' [in_features, hidden_features] 40 | w_hh <- makeIndependent =<< randn' [hidden_features, hidden_features] 41 | b <- makeIndependent =<< randn' [1, hidden_features] 42 | return $ ElmanCell w_ih w_hh b 43 | 44 | 45 | instance Parameterized ElmanCell where 46 | flattenParameters ElmanCell{..} = [input_weight, hidden_weight, bias] 47 | replaceOwnParameters _ = do 48 | input_weight <- nextParameter 49 | hidden_weight <- nextParameter 50 | bias <- nextParameter 51 | return $ ElmanCell{..} 52 | 53 | 54 | instance Show ElmanCell where 55 | show ElmanCell{..} = 56 | (show input_weight) ++ "\n" ++ 57 | (show hidden_weight) ++ "\n" ++ 58 | (show bias) ++ "\n" 59 | 60 | -------------------------------------------------------------------------------- /examples/elman/GRU.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | 6 | module GRU where 7 | 8 | import Torch.Tensor 9 | import Torch.DType 10 | import Torch.TensorFactories 11 | import Torch.Functions 12 | import Torch.TensorOptions 13 | import Torch.Autograd 14 | import Torch.NN 15 | 16 | import Control.Monad.State.Strict 17 | import Data.List (foldl', scanl', intersperse) 18 | 19 | import RecurrentLayer 20 | 21 | 22 | -- Specifying the shape of the recurrent layer 23 | data GRUSpec = GRUSpec { in_f :: Int, h_f :: Int} 24 | 25 | data GRUCell = GRUCell { 26 | reset_gate :: [Parameter], 27 | update_gate :: [Parameter], 28 | gru_hidden_gate :: [Parameter] 29 | } 30 | 31 | 32 | instance RecurrentCell GRUCell where 33 | nextState GRUCell{..} input hidden = 34 | (ug * hidden) + ((1 - ug) * h') 35 | where 36 | rg = gate input hidden Torch.Functions.sigmoid 37 | (reset_gate !! 0) 38 | (reset_gate !! 
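-- (standard GRU gating as encoded in nextState above: z (update) and r (reset) are
-- sigmoid gates, h' is the tanh candidate over r * hidden, and the next state is
-- z * h + (1 - z) * h')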
1) 39 | (reset_gate !! 2) 40 | ug = gate input hidden Torch.Functions.sigmoid 41 | (update_gate !! 0) 42 | (update_gate !! 1) 43 | (update_gate !! 2) 44 | h' = gate input (rg * hidden) Torch.Functions.tanh 45 | (gru_hidden_gate !! 0) 46 | (gru_hidden_gate !! 1) 47 | (gru_hidden_gate !! 2) 48 | 49 | 50 | instance Randomizable GRUSpec GRUCell where 51 | sample GRUSpec{..} = do 52 | rg_ih <- makeIndependent =<< randn' [in_f, h_f] 53 | rg_hh <- makeIndependent =<< randn' [h_f, h_f] 54 | rg_b <- makeIndependent =<< randn' [1, h_f] 55 | ug_ih <- makeIndependent =<< randn' [in_f, h_f] 56 | ug_hh <- makeIndependent =<< randn' [h_f, h_f] 57 | ug_b <- makeIndependent =<< randn' [1, h_f] 58 | hg_ih <- makeIndependent =<< randn' [in_f, h_f] 59 | hg_hh <- makeIndependent =<< randn' [h_f, h_f] 60 | hg_b <- makeIndependent =<< randn' [1, h_f] 61 | let rg = [rg_ih, rg_hh, rg_b] 62 | let ug = [ug_ih, ug_hh, ug_b] 63 | let hg = [hg_ih, hg_hh, hg_b] 64 | return $ GRUCell rg ug hg 65 | 66 | 67 | instance Parameterized GRUCell where 68 | flattenParameters GRUCell{..} = 69 | reset_gate ++ update_gate ++ gru_hidden_gate 70 | replaceOwnParameters _ = do 71 | rg_ih <- nextParameter 72 | rg_hh <- nextParameter 73 | rg_b <- nextParameter 74 | ug_ih <- nextParameter 75 | ug_hh <- nextParameter 76 | ug_b <- nextParameter 77 | hg_ih <- nextParameter 78 | hg_hh <- nextParameter 79 | hg_b <- nextParameter 80 | let reset_gate = [rg_ih, rg_hh, rg_b] 81 | let update_gate = [ug_ih, ug_hh, ug_b] 82 | let gru_hidden_gate = [hg_ih, hg_hh, hg_b] 83 | return $ GRUCell{..} 84 | 85 | 86 | instance Show GRUCell where 87 | show GRUCell{..} = 88 | (show $ reset_gate) ++ "\n" ++ 89 | (show $ update_gate) ++ "\n" ++ 90 | (show $ gru_hidden_gate) 91 | -------------------------------------------------------------------------------- /examples/elman/LSTM.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | 6 | module LSTM where 7 | 8 | import Torch.Tensor 9 | import Torch.DType 10 | import Torch.TensorFactories 11 | import Torch.Functions 12 | import Torch.TensorOptions 13 | import Torch.Autograd 14 | import Torch.NN 15 | 16 | import Control.Monad.State.Strict 17 | import Data.List (foldl', scanl', intersperse) 18 | 19 | import RecurrentLayer 20 | 21 | 22 | data LSTMSpec = LSTMSpec { inf :: Int, hf :: Int} 23 | 24 | data LSTMCell = LSTMCell { 25 | input_gate :: [Parameter], 26 | forget_gate :: [Parameter], 27 | output_gate :: [Parameter], 28 | hidden_gate :: [Parameter], 29 | cell_state :: Parameter 30 | } 31 | 32 | 33 | newCellState :: LSTMCell -> Tensor -> Tensor -> Tensor 34 | newCellState LSTMCell{..} input hidden = 35 | (fg * (toDependent cell_state)) + (ig * c') 36 | where 37 | ig = gate input hidden Torch.Functions.sigmoid 38 | (input_gate !! 0) 39 | (input_gate !! 1) 40 | (input_gate !! 2) 41 | fg = gate input hidden Torch.Functions.sigmoid 42 | (forget_gate !! 0) 43 | (forget_gate !! 1) 44 | (forget_gate !! 2) 45 | c' = gate input hidden Torch.Functions.sigmoid 46 | (hidden_gate !! 0) 47 | (hidden_gate !! 1) 48 | (hidden_gate !! 2) 49 | 50 | 51 | instance RecurrentCell LSTMCell where 52 | nextState cell input hidden = 53 | matmul og (Torch.Functions.tanh cNew) 54 | where 55 | og' = output_gate cell 56 | og = gate input hidden Torch.Functions.sigmoid 57 | (og' !! 0) 58 | (og' !! 1) 59 | (og' !! 
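-- Note how this LSTM deviates from the textbook cell: cell_state is a
-- learned Parameter of shape [hf, hf] rather than recurrent state threaded
-- through the fold, the candidate c' uses sigmoid where the standard cell
-- uses tanh, and h_new is `matmul og (tanh cNew)` instead of an elementwise
-- product. For comparison, the standard equations are
--   i = sigma(..), f = sigma(..), o = sigma(..), c' = tanh(..)
--   c_new = f * c_old + i * c'
--   h_new = o * tanh(c_new)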
2) 60 | cNew = newCellState cell input hidden 61 | 62 | 63 | instance Randomizable LSTMSpec LSTMCell where 64 | sample LSTMSpec{..} = do 65 | ig_ih <- makeIndependent =<< randn' [inf, hf] 66 | ig_hh <- makeIndependent =<< randn' [hf, hf] 67 | ig_b <- makeIndependent =<< randn' [1, hf] 68 | fg_ih <- makeIndependent =<< randn' [inf, hf] 69 | fg_hh <- makeIndependent =<< randn' [hf, hf] 70 | fg_b <- makeIndependent =<< randn' [1, hf] 71 | og_ih <- makeIndependent =<< randn' [inf, hf] 72 | og_hh <- makeIndependent =<< randn' [hf, hf] 73 | og_b <- makeIndependent =<< randn' [1, hf] 74 | hg_ih <- makeIndependent =<< randn' [inf, hf] 75 | hg_hh <- makeIndependent =<< randn' [hf, hf] 76 | hg_b <- makeIndependent =<< randn' [1, hf] 77 | let ig = [ig_ih, ig_hh, ig_b] 78 | let fg = [fg_ih, fg_hh, fg_b] 79 | let og = [og_ih, og_hh, og_b] 80 | let hg = [hg_ih, hg_hh, hg_b] 81 | c <- makeIndependent =<< randn' [hf, hf] 82 | return $ LSTMCell ig fg og hg c 83 | 84 | 85 | -- Typeclass that allows us to manipulate and update the layer weights 86 | instance Parameterized LSTMCell where 87 | flattenParameters LSTMCell{..} = 88 | input_gate ++ forget_gate ++ hidden_gate ++ 89 | output_gate ++ [cell_state] 90 | replaceOwnParameters _ = do 91 | ig_ih <- nextParameter 92 | ig_hh <- nextParameter 93 | ig_b <- nextParameter 94 | fg_ih <- nextParameter 95 | fg_hh <- nextParameter 96 | fg_b <- nextParameter 97 | hg_ih <- nextParameter 98 | hg_hh <- nextParameter 99 | hg_b <- nextParameter 100 | og_ih <- nextParameter 101 | og_hh <- nextParameter 102 | og_b <- nextParameter 103 | cell_state <- nextParameter 104 | let input_gate = [ig_ih, ig_hh, ig_b] 105 | let forget_gate = [fg_ih, fg_hh, fg_b] 106 | let hidden_gate = [hg_ih, hg_hh, hg_b] 107 | let output_gate = [og_ih, og_hh, og_b] 108 | return $ LSTMCell{..} 109 | 110 | instance Show LSTMCell where 111 | show LSTMCell{..} = 112 | (show $ input_gate) ++ "\n" ++ 113 | (show $ forget_gate) ++ "\n" ++ 114 | (show $ output_gate) ++ "\n" ++ 115 | (show $ hidden_gate) ++ "\n" ++ 116 | (show $ cell_state) 117 | -------------------------------------------------------------------------------- /examples/elman/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE DeriveGeneric #-} 5 | 6 | module Main where 7 | 8 | import Torch.Tensor 9 | import Torch.DType 10 | import Torch.TensorFactories 11 | import Torch.Functions 12 | import Torch.TensorOptions 13 | import Torch.Autograd 14 | import Torch.NN 15 | import GHC.Generics 16 | 17 | import Control.Monad.State.Strict 18 | import Data.List (foldl', scanl', intersperse) 19 | 20 | import RecurrentLayer 21 | import Elman 22 | import LSTM 23 | import GRU 24 | 25 | 26 | num_iters = 10 27 | num_timesteps = 3 28 | 29 | run :: (RecurrentCell a, Parameterized a) 30 | => Tensor 31 | -> Tensor 32 | -> Tensor 33 | -> a 34 | -> Int 35 | -> IO (a) 36 | run input_tensor init_hidden expected_output model i = do 37 | 38 | let output = finalState model input_tensor init_hidden 39 | let loss = mse_loss output expected_output 40 | 41 | print loss 42 | 43 | let flat_parameters = flattenParameters model 44 | let gradients = grad loss flat_parameters 45 | 46 | 47 | -- new parameters returned by the SGD update functions 48 | new_flat_parameters <- mapM makeIndependent $ sgd 5e-2 flat_parameters gradients 49 | 50 | -- return the new model state "to" the next iteration of foldLoop 51 | 
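-- The update above (flattenParameters, grad, sgd, makeIndependent) is the
-- generic training step for any Parameterized model; factored out it would
-- read (a sketch; `lr` is a hypothetical learning-rate argument):
--   sgdStep lr loss m = do
--     let ps = flattenParameters m
--     new <- mapM makeIndependent (sgd lr ps (grad loss ps))
--     pure (replaceParameters m new)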
return $ replaceParameters model new_flat_parameters 52 | 53 | 54 | main :: IO () 55 | main = do 56 | 57 | let foldLoop x count block = foldM block x [1..count] 58 | 59 | -- randomly initializing training values 60 | input_tensor <- randn' [num_timesteps, 2] 61 | init_hidden <- randn' [1, 2] 62 | expected_output <- randn' [1, 2] 63 | 64 | -- randomly initialize a gate 65 | rnnLayer <- sample $ ElmanSpec { in_features = 2, hidden_features = 2 } 66 | lstmLayer <- sample $ LSTMSpec 2 2 67 | gruLayer <- sample $ GRUSpec 2 2 68 | 69 | putStrLn "\nElman Cell Training Loop" 70 | -- training loop for elman cell 71 | foldLoop rnnLayer num_iters (run input_tensor init_hidden expected_output) 72 | 73 | putStrLn "\nLSTM Training Loop" 74 | -- training loop for LSTM cell 75 | foldLoop lstmLayer num_iters (run input_tensor init_hidden expected_output) 76 | 77 | putStrLn "\nGRU Training Loop" 78 | -- training loop for GRU cell 79 | foldLoop gruLayer num_iters (run input_tensor init_hidden expected_output) 80 | 81 | return () 82 | -------------------------------------------------------------------------------- /examples/elman/RecurrentLayer.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | 6 | module RecurrentLayer where 7 | 8 | import Torch.Tensor 9 | import Torch.DType 10 | import Torch.TensorFactories 11 | import Torch.Functions 12 | import Torch.TensorOptions 13 | import Torch.Autograd 14 | import Torch.NN 15 | 16 | import Control.Monad.State.Strict 17 | import Data.List (foldl', scanl', intersperse) 18 | 19 | 20 | class RecurrentCell a where 21 | -- get the hidden state of the cell at the next timestep 22 | nextState :: a -> Tensor -> Tensor -> Tensor 23 | -- function to run the cell over multiple timesteps and get 24 | -- final hidden state 25 | finalState :: a -> Tensor -> Tensor -> Tensor 26 | finalState layer input hidden = 27 | let 28 | -- converting matrix into a list of tensors 29 | -- this hack stays until I can write a Foldable instance 30 | -- for a tensor 31 | inputAsList = [reshape (input @@ x) [1, 2] | x <- [0.. 
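-- finalState is a left fold of nextState over these per-timestep slices;
-- for a three-step input it unrolls to
--   nextState layer x2 (nextState layer x1 (nextState layer x0 hidden))
-- where x0, x1, x2 are the [1, 2] row slices built in this comprehension.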
((size input 0) - 1)]] 32 | in 33 | foldl (\hidden' input' -> nextState layer input' hidden') hidden inputAsList 34 | 35 | 36 | {- 37 | TODO: there should also be a `forward` function here 38 | that uses the rnn forward functions from ATen 39 | but I'll implement that when I can make sense 40 | of the ATen function arguments -} 41 | 42 | 43 | gate :: Tensor 44 | -> Tensor 45 | -> (Tensor -> Tensor) 46 | -> Parameter 47 | -> Parameter 48 | -> Parameter 49 | -> Tensor 50 | gate input hidden nonLinearity inputWt hiddenWt biasWt = 51 | nonLinearity $ (mul input inputWt) + (mul hidden hiddenWt) + (toDependent biasWt) 52 | where 53 | mul features wts = transpose2D $ matmul (toDependent wts) (transpose2D features) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /examples/examples.cabal: -------------------------------------------------------------------------------- 1 | name: examples 2 | version: 0.2.0.0 3 | synopsis: examples for the new version of hasktorch 4 | -- description: 5 | homepage: https://github.com/githubuser/ffi-experimental#readme 6 | license: BSD3 7 | author: Austin Huang 8 | maintainer: hasktorch@gmail.com 9 | copyright: 2019 Austin Huang 10 | category: Codegen 11 | build-type: Simple 12 | cabal-version: >=1.10 13 | 14 | executable xor_mlp 15 | hs-source-dirs: xor_mlp 16 | main-is: Main.hs 17 | default-language: Haskell2010 18 | build-depends: base >= 4.7 && < 5 19 | , hasktorch 20 | , mtl 21 | 22 | executable elman 23 | hs-source-dirs: elman 24 | main-is: Main.hs 25 | other-modules: RecurrentLayer, 26 | Elman, 27 | LSTM, 28 | GRU 29 | default-language: Haskell2010 30 | build-depends: base >= 4.7 && < 5 31 | , hasktorch 32 | , mtl 33 | , ffi 34 | 35 | 36 | executable regression 37 | hs-source-dirs: regression 38 | main-is: Main.hs 39 | default-language: Haskell2010 40 | build-depends: base >= 4.7 && < 5 41 | , hasktorch 42 | , mtl 43 | , ffi 44 | 45 | executable cnn 46 | hs-source-dirs: cnn 47 | main-is: Main.hs 48 | default-language: Haskell2010 49 | ghc-options: -fno-warn-partial-type-signatures 50 | build-depends: base >= 4.7 && < 5 51 | , hasktorch 52 | , mtl 53 | 54 | executable gaussian_process 55 | hs-source-dirs: gaussian_process 56 | main-is: Main.hs 57 | default-language: Haskell2010 58 | ghc-options: -fno-warn-partial-type-signatures 59 | build-depends: base >= 4.7 && < 5 60 | , hasktorch 61 | , mtl 62 | 63 | executable vae 64 | hs-source-dirs: vae 65 | main-is: Main.hs 66 | default-language: Haskell2010 67 | ghc-options: -fno-warn-partial-type-signatures 68 | build-depends: base >= 4.7 && < 5 69 | , hasktorch 70 | , mtl 71 | -------------------------------------------------------------------------------- /examples/gaussian_process/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | 4 | module Main where 5 | 6 | import Control.Monad (foldM) 7 | 8 | import Prelude hiding (exp) 9 | 10 | import Torch.Tensor 11 | import Torch.DType (DType (Float)) 12 | import Torch.TensorFactories (eye', ones', rand', randn', zeros') 13 | import Torch.Functions 14 | import Torch.Autograd 15 | import Torch.NN 16 | 17 | -- | construct pairs of points on the axis 18 | makeAxis :: [Float] -> [Float] -> (Tensor, Tensor) 19 | makeAxis axis1 axis2 = (t, t') 20 | where 21 | t = asTensor (fst <$> rngPairs) 22 | t' = asTensor (snd <$> rngPairs) 23 | pairs axis1' axis2' = [(t, t') | t <- axis1', t' <- axis2'] 24 | rngPairs = pairs axis1 axis2 25 | 26 | -- | 
1-dimensional radial basis function kernel 27 | kernel1d_rbf :: Double -> Double -> Tensor -> Tensor -> Tensor 28 | kernel1d_rbf sigma length t t' = (sigma'^2) * exp eterm 29 | where 30 | sigma' = asTensor sigma 31 | eterm = cmul (- (pow (t - t') (2 :: Int))) (1 / (2 * length^2) ) 32 | 33 | -- | derive a covariance matrix from the kernel for points on the axis 34 | makeCovmatrix :: [Float] -> [Float] -> Tensor 35 | makeCovmatrix axis1 axis2 = 36 | reshape (kernel1d_rbf 1.0 1.0 t t') [length axis1, length axis2] 37 | where 38 | (t, t') = makeAxis axis1 axis2 39 | 40 | -- | Multivariate 0-mean normal via cholesky decomposition 41 | mvnCholesky :: Tensor -> Int -> Int -> IO Tensor 42 | mvnCholesky cov axisDim n = do 43 | samples <- randn' [axisDim, n] 44 | pure $ matmul l samples 45 | where 46 | l = cholesky cov Upper 47 | 48 | -- | Compute posterior mean and covariance parameters based on observed data y 49 | condition :: Tensor -> Tensor -> Tensor -> Tensor -> Tensor -> Tensor -> (Tensor, Tensor) 50 | condition muX muY covXX covXY covYY y = 51 | (postMu, postCov) 52 | where 53 | covYX = transpose2D covXY 54 | invY = inverse covYY 55 | postMu = muX + (matmul covXY (matmul invY (y - muY))) 56 | postCov = covXX - (matmul covXY (matmul invY covYX)) 57 | 58 | 59 | -- | Given observations + points of interest derive covariance terms and condition on observation 60 | computePosterior :: [Float] -> [Float] -> [Float] -> IO (Tensor, Tensor) 61 | computePosterior dataPredictors dataValues tRange = do 62 | let dataDim = length dataPredictors 63 | let axisDim = length tRange 64 | 65 | -- multivariate normal parameters for axis locations 66 | let priorMuAxis = zeros' [axisDim, 1] 67 | let priorCov = makeCovmatrix tRange tRange 68 | 69 | -- multivariate normal parameters for observation locations 70 | let priorMuData = zeros' [dataDim, 1] 71 | let obsCov = makeCovmatrix dataPredictors dataPredictors 72 | putStrLn $ "\nObservation coordinates covariance\n" ++ show obsCov 73 | 74 | -- cross-covariance terms 75 | let crossCov = makeCovmatrix tRange dataPredictors 76 | putStrLn $ "\nCross covariance\n" ++ show crossCov 77 | 78 | -- conditional distribution 79 | let obsVals = reshape (asTensor dataValues) [dataDim, 1] 80 | let (postMu, postCov) = 81 | condition 82 | priorMuAxis priorMuData 83 | priorCov crossCov obsCov 84 | obsVals 85 | pure $ (postMu, postCov) 86 | 87 | main :: IO () 88 | main = do 89 | -- Setup prediction axis 90 | let cov = makeCovmatrix tRange tRange 91 | putStrLn $ "Predictor values\n" ++ show tRange 92 | putStrLn $ "\nCovariance based on radial basis function\n" ++ show cov 93 | 94 | -- Prior GP, take 4 example samples 95 | putStrLn "prior" 96 | let reg = 0.01 * (eye' axisDim axisDim) -- regularization 97 | mvnSampPrior <- mvnCholesky (cov + reg) axisDim 4 98 | putStrLn $ "\nGP Samples (prior, rows = values, cols = realizations)\n" 99 | ++ show mvnSampPrior 100 | 101 | -- Observations 102 | putStrLn $ "\nObservations: predictor coordinates\n" ++ show dataPredictors 103 | putStrLn $ "\nObservations: values\n" ++ show dataValues 104 | (postMu, postCov) <- computePosterior dataPredictors dataValues tRange 105 | 106 | -- Conditional GP 107 | putStrLn $ "\nConditional mu (posterior)\n" ++ show postMu 108 | putStrLn $ "\nConditional covariance (posterior)\n" ++ show postCov 109 | mvnSampPost <- mvnCholesky (postCov + reg) (length tRange) 1 110 | putStrLn "\nGP Conditional Samples (posterior, rows = values, cols = realizations)" 111 | print (postMu + mvnSampPost) 112 | 113 | where 114 | -- 
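-- `condition` above is the Gaussian conditioning rule written out directly:
--   mu_post    = mu_X + S_XY S_YY^-1 (y - mu_Y)
--   Sigma_post = S_XX - S_XY S_YY^-1 S_YX
-- (the code forms S_YY^-1 explicitly with `inverse`, mirroring the math).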
Axis points 115 | scale = 0.1 116 | axisDim = 7 117 | tRange = (*) scale <$> (fromIntegral <$> [0 .. (axisDim - 1)]) 118 | -- Observed data points 119 | dataPredictors = [0.1, 0.3, 0.6] :: [Float] 120 | dataValues = [-2.3, 1.5, -4] :: [Float] 121 | -------------------------------------------------------------------------------- /examples/regression/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | 4 | module Main where 5 | 6 | import Control.Monad (foldM) 7 | 8 | import Torch.Tensor 9 | import Torch.DType (DType (Float)) 10 | import Torch.TensorFactories (ones', rand', randn') 11 | import Torch.Functions 12 | import Torch.Autograd 13 | import Torch.NN 14 | 15 | batch_size = 64 16 | num_iters = 2000 17 | num_features = 3 18 | 19 | model :: Linear -> Tensor -> Tensor 20 | model Linear{..} input = squeezeAll $ matmul input depWeight + depBias 21 | where 22 | (depWeight, depBias) = (toDependent weight, toDependent bias) 23 | 24 | groundTruth :: Tensor -> Tensor 25 | groundTruth t = squeezeAll $ matmul t weight + bias 26 | where 27 | weight = 42.0 * ones' [num_features, 1] 28 | bias = 3.14 * ones' [1] 29 | 30 | printParams :: Linear -> IO () 31 | printParams trained = do 32 | putStrLn "Parameters:" 33 | print $ toDependent $ weight trained 34 | putStrLn "Bias:" 35 | print $ toDependent $ bias trained 36 | 37 | main :: IO () 38 | main = do 39 | init <- sample $ LinearSpec { in_features = num_features, out_features = 1 } 40 | trained <- foldLoop init num_iters $ \state i -> do 41 | input <- randn' [batch_size, num_features] 42 | let expected_output = groundTruth input 43 | output = model state input 44 | loss = mse_loss output expected_output 45 | flat_parameters = flattenParameters state 46 | gradients = grad loss flat_parameters 47 | if i `mod` 100 == 0 then 48 | putStrLn $ "Loss: " ++ show loss 49 | else 50 | pure () 51 | new_flat_parameters <- mapM makeIndependent $ sgd 5e-3 flat_parameters gradients 52 | return $ replaceParameters state $ new_flat_parameters 53 | printParams trained 54 | pure () 55 | where 56 | foldLoop x count block = foldM block x [1..count] 57 | -------------------------------------------------------------------------------- /examples/vae/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveGeneric #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | {-# LANGUAGE FunctionalDependencies #-} 4 | 5 | module Main where 6 | 7 | import Control.Monad (foldM) 8 | import Data.List (foldl', scanl', intersperse) 9 | import GHC.Generics 10 | import Prelude hiding (exp) 11 | 12 | import Torch.Tensor 13 | import Torch.DType (DType (Float)) 14 | import Torch.TensorFactories (ones', rand', randn', randn_like) 15 | import Torch.Functions 16 | import Torch.Autograd 17 | import Torch.NN 18 | 19 | -- Model Specification 20 | 21 | data VAESpec = VAESpec { 22 | encoderSpec :: [LinearSpec], 23 | muSpec :: LinearSpec, 24 | logvarSpec :: LinearSpec, 25 | decoderSpec :: [LinearSpec], 26 | nonlinearitySpec :: Tensor -> Tensor 27 | } deriving (Generic) 28 | 29 | -- Model State 30 | 31 | data VAEState = VAEState { 32 | encoderState :: [Linear], 33 | muFC :: Linear, 34 | logvarFC :: Linear, 35 | decoderState :: [Linear], 36 | nonlinearity :: Tensor -> Tensor 37 | } deriving (Generic) 38 | 39 | instance Randomizable VAESpec VAEState where 40 | sample VAESpec{..} = do 41 | encoderState <- mapM sample encoderSpec 42 | muFC <- 
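-- Randomizable ties a spec (hyperparameters) to a state (parameters):
-- `sample` here draws fresh random weights for every sub-layer listed in the
-- spec, so the whole VAE is initialized by the single `sample` call in main.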
sample muSpec 43 | logvarFC <- sample logvarSpec 44 | decoderState <- mapM sample decoderSpec 45 | let nonlinearity = nonlinearitySpec 46 | pure $ VAEState{..} 47 | 48 | instance Parameterized VAEState 49 | 50 | -- Output including latent mu and logvar used for VAE loss 51 | 52 | data ModelOutput = ModelOutput { 53 | recon :: Tensor, 54 | mu :: Tensor, 55 | logvar :: Tensor 56 | } deriving (Show) 57 | 58 | -- Recon Error + KL Divergence VAE Loss 59 | vaeLoss :: Tensor -> Tensor -> Tensor -> Tensor -> Tensor 60 | vaeLoss recon_x x mu logvar = reconLoss + kld 61 | where 62 | -- reconLoss = binary_cross_entropy_loss recon_x x undefined ReduceSum 63 | reconLoss = mse_loss recon_x x 64 | kld = -0.5 * (sumAll (1 + logvar - pow mu (2 :: Int) - exp logvar)) 65 | 66 | -- | End-to-end function for VAE model 67 | model :: VAEState -> Tensor -> IO ModelOutput 68 | model VAEState{..} input = do 69 | let encoded = mlp encoderState nonlinearity input 70 | mu = (linear muFC) encoded 71 | logvar = (linear logvarFC) encoded 72 | z <- reparamaterize mu logvar 73 | let output = mlp decoderState nonlinearity z -- TODO - try sampling output 74 | pure $ ModelOutput output mu logvar 75 | 76 | -- | MLP helper function for model used by both encoder & decoder 77 | mlp :: [Linear] -> (Tensor -> Tensor) -> Tensor -> Tensor 78 | mlp mlpState nonlin input = foldl' revApply input layerFunctionsList 79 | where 80 | layerFunctionsList = intersperse nonlin $ (map linear mlpState) 81 | revApply x f = f x 82 | 83 | -- | Reparamaterization trick to sample from latent space while allowing differentiation 84 | reparamaterize :: Tensor -> Tensor -> IO Tensor 85 | reparamaterize mu logvar = do 86 | eps <- randn_like mu 87 | pure $ mu + eps * exp (0.5 * logvar) 88 | 89 | -- | Given weights, apply linear layer to an input 90 | linear :: Linear -> Tensor -> Tensor 91 | linear Linear{..} input = squeezeAll $ matmul input depWeight + depBias 92 | where (depWeight, depBias) = (toDependent weight, toDependent bias) 93 | 94 | -- | Multivariate 0-mean normal via cholesky decomposition 95 | mvnCholesky :: Tensor -> Int -> Int -> IO Tensor 96 | mvnCholesky cov n axisDim = do 97 | samples <- randn' [axisDim, n] 98 | pure $ matmul l samples 99 | where 100 | l = cholesky cov Upper 101 | 102 | main :: IO () 103 | main = 104 | let nSamples = 32768 105 | -- TODO - use higher dimensions once functionality works 106 | dataDim = 4 107 | hDim = 2 108 | zDim = 2 109 | batchSize = 256 -- TODO - crashes for case where any batch is of size n=1 110 | numIters = 8000 111 | in do 112 | init <- sample $ VAESpec { 113 | encoderSpec = [LinearSpec dataDim hDim], 114 | muSpec = LinearSpec hDim zDim, 115 | logvarSpec = LinearSpec hDim zDim, 116 | decoderSpec = [LinearSpec zDim hDim, LinearSpec hDim dataDim], 117 | nonlinearitySpec = relu } 118 | 119 | dat <- transpose2D <$> 120 | mvnCholesky (asTensor ([[1.0, 0.3, 0.1, 0.0], 121 | [0.3, 1.0, 0.3, 0.1], 122 | [0.1, 0.3, 1.0, 0.3], 123 | [0.0, 0.1, 0.3, 1.0]] :: [[Float]])) 124 | nSamples dataDim 125 | trained <- foldLoop init numIters $ \vaeState i -> do 126 | let startIndex = mod (batchSize * i) nSamples 127 | endIndex = Prelude.min (startIndex + batchSize) nSamples 128 | input = slice dat 0 startIndex endIndex 1 -- size should be [batchSize, dataDim] 129 | output <- model vaeState input 130 | let (reconX, muVal, logvarVal) = (squeezeAll $ recon output, mu output, logvar output ) 131 | let loss = vaeLoss reconX input muVal logvarVal 132 | let flat_parameters = flattenParameters vaeState 133 | let gradients = grad 
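-- The KL term in vaeLoss above is the closed form of
-- KL(N(mu, sigma^2) || N(0, 1)) summed over latent dimensions, with
-- logvar = log(sigma^2):
--   KL = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-- and reparamaterize samples z = mu + eps * sigma, sigma = exp(0.5 * logvar),
-- keeping the draw differentiable with respect to mu and logvar.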
loss flat_parameters 134 | if i `mod` 100 == 0 135 | then do putStrLn $ show loss 136 | else return () 137 | 138 | new_flat_parameters <- mapM makeIndependent $ sgd 1e-6 flat_parameters gradients 139 | pure $ replaceParameters vaeState $ new_flat_parameters 140 | putStrLn "Done" 141 | where 142 | foldLoop x count block = foldM block x [0..count] 143 | -------------------------------------------------------------------------------- /examples/xor_mlp/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | 5 | module Main where 6 | 7 | import Torch.Tensor 8 | import Torch.DType 9 | import Torch.TensorFactories 10 | import Torch.Functions hiding (linear) 11 | import Torch.TensorOptions 12 | import Torch.Autograd 13 | import Torch.NN 14 | import GHC.Generics 15 | 16 | import Control.Monad (foldM) 17 | import Data.List (foldl', scanl', intersperse) 18 | 19 | -------------------------------------------------------------------------------- 20 | -- MLP 21 | -------------------------------------------------------------------------------- 22 | 23 | linear :: Linear -> Tensor -> Tensor 24 | linear Linear{..} input = (matmul input (toDependent weight)) + (toDependent bias) 25 | 26 | data MLPSpec = MLPSpec { feature_counts :: [Int], nonlinearitySpec :: Tensor -> Tensor } 27 | 28 | data MLP = MLP { layers :: [Linear], nonlinearity :: Tensor -> Tensor } deriving (Generic) 29 | 30 | instance Randomizable MLPSpec MLP where 31 | sample MLPSpec{..} = do 32 | let layer_sizes = mkLayerSizes feature_counts 33 | linears <- mapM sample $ map (uncurry LinearSpec) layer_sizes 34 | return $ MLP { layers = linears, nonlinearity = nonlinearitySpec } 35 | where 36 | mkLayerSizes (a : (b : t)) = 37 | scanl shift (a, b) t 38 | where 39 | shift (a, b) c = (b, c) 40 | 41 | instance Parameterized MLP 42 | -- This instance generates following codes. 43 | -- 44 | --------------------------------------------------- 45 | -- instance Parameterized MLP where 46 | -- flattenParameters MLP{..} = concat $ map flattenParameters layers 47 | -- replaceOwnParameters mlp = do 48 | -- new_layers <- mapM replaceOwnParameters (layers mlp) 49 | -- return $ mlp { layers = new_layers } 50 | 51 | mlp :: MLP -> Tensor -> Tensor 52 | mlp MLP{..} input = foldl' revApply input $ intersperse nonlinearity $ map linear layers 53 | where revApply x f = f x 54 | 55 | -------------------------------------------------------------------------------- 56 | -- Training code 57 | -------------------------------------------------------------------------------- 58 | 59 | batch_size = 32 60 | num_iters = 10000 61 | 62 | model :: MLP -> Tensor -> Tensor 63 | model params t = sigmoid (mlp params t) 64 | 65 | main :: IO () 66 | main = do 67 | init <- sample $ MLPSpec { feature_counts = [2, 20, 20, 1], nonlinearitySpec = Torch.Functions.tanh } 68 | trained <- foldLoop init num_iters $ \state i -> do 69 | input <- rand' [batch_size, 2] >>= return . (toDType Float) . 
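-- Input pipeline: rand' draws uniforms in [0, 1), (gt 0.5) thresholds them
-- to booleans, and toDType Float maps those to 0/1 entries. tensorXOR
-- (defined at the end of this file) then builds xor from smooth primitives:
--   or(a, b) = 1 - (1 - a) * (1 - b), nand(a, b) = 1 - a * b,
--   xor(a, b) = or(a, b) * nand(a, b)
-- check: (0,0) -> 0 * 1 = 0, (1,0) -> 1 * 1 = 1, (1,1) -> 1 * 0 = 0.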
(gt 0.5) 70 | let expected_output = tensorXOR input 71 | 72 | let output = squeezeAll $ model state input 73 | let loss = mse_loss output expected_output 74 | 75 | let flat_parameters = flattenParameters state 76 | let gradients = grad loss flat_parameters 77 | 78 | if i `mod` 100 == 0 79 | then do putStrLn $ show loss 80 | else return () 81 | 82 | new_flat_parameters <- mapM makeIndependent $ sgd 5e-4 flat_parameters gradients 83 | return $ replaceParameters state $ new_flat_parameters 84 | return () 85 | where 86 | foldLoop x count block = foldM block x [1..count] 87 | 88 | tensorXOR :: Tensor -> Tensor 89 | tensorXOR t = (1 - (1 - a) * (1 - b)) * (1 - (a * b)) 90 | where 91 | a = select t 1 0 92 | b = select t 1 1 93 | -------------------------------------------------------------------------------- /ffi/ffi.cabal: -------------------------------------------------------------------------------- 1 | name: ffi 2 | version: 0.1.0.0 3 | synopsis: test out alternative options for ffi interface to libtorch 1.0 4 | -- description: 5 | homepage: https://github.com/githubuser/ffi-experimental#readme 6 | license: BSD3 7 | author: Austin Huang 8 | maintainer: hasktorch@gmail.com 9 | copyright: 2018 Austin Huang 10 | category: Codegen 11 | build-type: Simple 12 | cabal-version: >=1.10 13 | 14 | library 15 | exposed-modules: ATen.Type 16 | , ATen.Const 17 | , ATen.Cast 18 | , ATen.Class 19 | , ATen.GC 20 | , ATen.Unmanaged.NN 21 | , ATen.Unmanaged.TH 22 | , ATen.Unmanaged.Native 23 | , ATen.Unmanaged.Type.Tuple 24 | , ATen.Unmanaged.Type.Generator 25 | , ATen.Unmanaged.Type.IntArray 26 | , ATen.Unmanaged.Type.Scalar 27 | , ATen.Unmanaged.Type.SparseTensorRef 28 | , ATen.Unmanaged.Type.Storage 29 | , ATen.Unmanaged.Type.Tensor 30 | , ATen.Unmanaged.Type.TensorList 31 | , ATen.Unmanaged.Type.TensorOptions 32 | , ATen.Unmanaged.Type.StdString 33 | , ATen.Unmanaged.Type.StdArray 34 | , ATen.Unmanaged.Type.Context 35 | , ATen.Unmanaged.Type.Extra 36 | , ATen.Managed.NN 37 | , ATen.Managed.TH 38 | , ATen.Managed.Cast 39 | , ATen.Managed.Native 40 | , ATen.Managed.Type.Tuple 41 | , ATen.Managed.Type.Generator 42 | , ATen.Managed.Type.IntArray 43 | , ATen.Managed.Type.Scalar 44 | , ATen.Managed.Type.SparseTensorRef 45 | , ATen.Managed.Type.Storage 46 | , ATen.Managed.Type.Tensor 47 | , ATen.Managed.Type.TensorList 48 | , ATen.Managed.Type.TensorOptions 49 | , ATen.Managed.Type.StdString 50 | , ATen.Managed.Type.StdArray 51 | , ATen.Managed.Type.Context 52 | , ATen.Managed.Type.Extra 53 | , Torch.Unmanaged.Autograd 54 | , Torch.Unmanaged.NN 55 | , Torch.Unmanaged.TH 56 | , Torch.Unmanaged.Native 57 | , Torch.Managed.Autograd 58 | , Torch.Managed.NN 59 | , Torch.Managed.TH 60 | , Torch.Managed.Native 61 | hs-source-dirs: src 62 | default-language: Haskell2010 63 | build-depends: base >= 4.7 && < 5 64 | , inline-c-cpp >= 0.3.0.1 65 | , inline-c 66 | , optparse-applicative >= 0.14.3.0 67 | , containers 68 | , template-haskell 69 | , bytestring 70 | , safe-exceptions 71 | , sysinfo 72 | , async 73 | extra-libraries: stdc++ 74 | , c10 75 | , iomp5 76 | , mklml 77 | , caffe2 78 | , torch 79 | extra-ghci-libraries: stdc++ 80 | if os(darwin) 81 | ld-options: -Wl,-keep_dwarf_unwind 82 | ghc-options: -optc-D_GLIBCXX_USE_CXX11_ABI=0 -optc-std=c++11 -optc-xc++ 83 | else 84 | ghc-options: -optc-D_GLIBCXX_USE_CXX11_ABI=0 -optc-std=c++11 85 | cc-options: -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 86 | cxx-options: -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 87 | default-extensions: Strict 88 | , StrictData 89 | 90 | 91 | test-suite 
spec 92 | type: exitcode-stdio-1.0 93 | hs-source-dirs: test 94 | main-is: Spec.hs 95 | other-modules: BasicSpec 96 | , MemorySpec 97 | , BackwardSpec 98 | , CudaSpec 99 | default-language: Haskell2010 100 | build-depends: base >= 4.7 && < 5 101 | , inline-c-cpp >= 0.3.0.1 102 | , inline-c 103 | , optparse-applicative >= 0.14.3.0 104 | , containers 105 | , ffi 106 | , hspec 107 | , hspec-discover 108 | , safe-exceptions 109 | if os(darwin) 110 | ld-options: -Wl,-keep_dwarf_unwind 111 | ghc-options: -optc-D_GLIBCXX_USE_CXX11_ABI=0 -optc-std=c++11 -optc-xc++ 112 | else 113 | ghc-options: -optc-D_GLIBCXX_USE_CXX11_ABI=0 -optc-std=c++11 114 | cc-options: -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 115 | cxx-options: -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++11 116 | default-extensions: Strict 117 | , StrictData 118 | -------------------------------------------------------------------------------- /ffi/src/ATen/Class.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE PolyKinds #-} 3 | {-# LANGUAGE TemplateHaskell #-} 4 | {-# LANGUAGE QuasiQuotes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE OverloadedStrings #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE MultiParamTypeClasses #-} 9 | 10 | module ATen.Class where 11 | 12 | import Foreign (Ptr, ForeignPtr) 13 | 14 | class Castable a b where 15 | cast :: a -> (b -> IO r) -> IO r 16 | uncast :: b -> (a -> IO r) -> IO r 17 | 18 | class CppObject a where 19 | fromPtr :: Ptr a -> IO (ForeignPtr a) 20 | 21 | class CppTuple2 m where 22 | type A m 23 | type B m 24 | get0 :: m -> IO (A m) 25 | get1 :: m -> IO (B m) 26 | 27 | class CppTuple2 m => CppTuple3 m where 28 | type C m 29 | get2 :: m -> IO (C m) 30 | 31 | class CppTuple3 m => CppTuple4 m where 32 | type D m 33 | get3 :: m -> IO (D m) 34 | 35 | class CppTuple4 m => CppTuple5 m where 36 | type E m 37 | get4 :: m -> IO (E m) 38 | 39 | class CppTuple5 m => CppTuple6 m where 40 | type F m 41 | get5 :: m -> IO (F m) 42 | -------------------------------------------------------------------------------- /ffi/src/ATen/Const.hs: -------------------------------------------------------------------------------- 1 | 2 | -- generated by using spec/Declarations.yaml 3 | 4 | {-# LANGUAGE DataKinds #-} 5 | {-# LANGUAGE PolyKinds #-} 6 | {-# LANGUAGE TemplateHaskell #-} 7 | {-# LANGUAGE QuasiQuotes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE OverloadedStrings #-} 10 | 11 | module ATen.Const where 12 | 13 | import qualified Language.C.Inline.Cpp as C 14 | import qualified Language.C.Inline.Cpp.Exceptions as C 15 | import qualified Language.C.Inline.Context as C 16 | import qualified Language.C.Types as C 17 | import qualified Data.Map as Map 18 | 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign 22 | import ATen.Type 23 | 24 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 25 | 26 | C.include "<ATen/ATen.h>" 27 | 28 | kByte :: ScalarType 29 | kByte = [C.pure| int8_t { (int8_t) at::ScalarType::Byte } |] 30 | 31 | kChar :: ScalarType 32 | kChar = [C.pure| int8_t { (int8_t) at::ScalarType::Char } |] 33 | 34 | kDouble :: ScalarType 35 | kDouble = [C.pure| int8_t { (int8_t) at::ScalarType::Double } |] 36 | 37 | kFloat :: ScalarType 38 | kFloat = [C.pure| int8_t { (int8_t) at::ScalarType::Float } |] 39 | 40 | kInt :: ScalarType 41 | kInt = [C.pure| int8_t { (int8_t) at::ScalarType::Int } |] 42 | 43 | kLong :: ScalarType 44 | kLong = [C.pure| int8_t { (int8_t) at::ScalarType::Long } |] 
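-- How the Managed layer wraps the Unmanaged one: each castN helper from
-- ATen.Cast converts arguments to their C++ representations and results back
-- via the Castable class above. A sketch of the one-argument case (cast1' is
-- a hypothetical name; the real definition lives in ATen.Cast):
cast1' :: (Castable a ca, Castable b cb) => (ca -> IO cb) -> a -> IO b
cast1' f a = cast a $ \ca -> f ca >>= \cb -> uncast cb pure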
45 | 46 | kShort :: ScalarType 47 | kShort = [C.pure| int8_t { (int8_t) at::ScalarType::Short } |] 48 | 49 | kHalf :: ScalarType 50 | kHalf = [C.pure| int8_t { (int8_t) at::ScalarType::Half } |] 51 | 52 | kBool :: ScalarType 53 | kBool = [C.pure| int8_t { (int8_t) at::ScalarType::Bool } |] 54 | 55 | kComplexHalf :: ScalarType 56 | kComplexHalf = [C.pure| int8_t { (int8_t) at::ScalarType::ComplexHalf } |] 57 | 58 | kComplexFloat :: ScalarType 59 | kComplexFloat = [C.pure| int8_t { (int8_t) at::ScalarType::ComplexFloat } |] 60 | 61 | kComplexDouble :: ScalarType 62 | kComplexDouble = [C.pure| int8_t { (int8_t) at::ScalarType::ComplexDouble } |] 63 | 64 | kUndefined :: ScalarType 65 | kUndefined = [C.pure| int8_t { (int8_t) at::ScalarType::Undefined } |] 66 | 67 | kCPU :: DeviceType 68 | kCPU = [C.pure| int16_t { (int16_t) at::DeviceType::CPU } |] 69 | 70 | kCUDA :: DeviceType 71 | kCUDA = [C.pure| int16_t { (int16_t) at::DeviceType::CUDA } |] 72 | 73 | kMKLDNN :: DeviceType 74 | kMKLDNN = [C.pure| int16_t { (int16_t) at::DeviceType::MKLDNN } |] 75 | 76 | kOPENGL :: DeviceType 77 | kOPENGL = [C.pure| int16_t { (int16_t) at::DeviceType::OPENGL } |] 78 | 79 | kOPENCL :: DeviceType 80 | kOPENCL = [C.pure| int16_t { (int16_t) at::DeviceType::OPENCL } |] 81 | 82 | kIDEEP :: DeviceType 83 | kIDEEP = [C.pure| int16_t { (int16_t) at::DeviceType::IDEEP } |] 84 | 85 | kHIP :: DeviceType 86 | kHIP = [C.pure| int16_t { (int16_t) at::DeviceType::HIP } |] 87 | 88 | kFPGA :: DeviceType 89 | kFPGA = [C.pure| int16_t { (int16_t) at::DeviceType::FPGA } |] 90 | 91 | kMSNPU :: DeviceType 92 | kMSNPU = [C.pure| int16_t { (int16_t) at::DeviceType::MSNPU } |] 93 | 94 | kXLA :: DeviceType 95 | kXLA = [C.pure| int16_t { (int16_t) at::DeviceType::XLA } |] 96 | 97 | kCOMPILE_TIME_MAX_DEVICE_TYPES :: DeviceType 98 | kCOMPILE_TIME_MAX_DEVICE_TYPES = [C.pure| int16_t { (int16_t) at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES } |] 99 | 100 | kONLY_FOR_TEST :: DeviceType 101 | kONLY_FOR_TEST = [C.pure| int16_t { (int16_t) at::DeviceType::ONLY_FOR_TEST } |] 102 | 103 | -- TODO: add all values for at::Reduction 104 | 105 | kMean :: Int64 106 | kMean = [C.pure| int64_t { (int64_t) Reduction::Mean } |] 107 | 108 | bCPU :: Backend 109 | bCPU = [C.pure| int { (int) at::Backend::CPU } |] 110 | 111 | bCUDA :: Backend 112 | bCUDA = [C.pure| int { (int) at::Backend::CUDA } |] 113 | 114 | bHIP :: Backend 115 | bHIP = [C.pure| int { (int) at::Backend::HIP } |] 116 | 117 | bSparseCPU :: Backend 118 | bSparseCPU = [C.pure| int { (int) at::Backend::SparseCPU } |] 119 | 120 | bSparseCUDA :: Backend 121 | bSparseCUDA = [C.pure| int { (int) at::Backend::SparseCUDA } |] 122 | 123 | bSparseHIP :: Backend 124 | bSparseHIP = [C.pure| int { (int) at::Backend::SparseHIP } |] 125 | 126 | bMSNPU :: Backend 127 | bMSNPU = [C.pure| int { (int) at::Backend::MSNPU } |] 128 | 129 | bXLA :: Backend 130 | bXLA = [C.pure| int { (int) at::Backend::XLA } |] 131 | 132 | bUndefined :: Backend 133 | bUndefined = [C.pure| int { (int) at::Backend::Undefined } |] 134 | 135 | bNumOptions :: Backend 136 | bNumOptions = [C.pure| int { (int) at::Backend::NumOptions } |] 137 | 138 | kStrided :: Layout 139 | kStrided = [C.pure| int8_t { (int8_t) at::kStrided } |] 140 | 141 | kSparse :: Layout 142 | kSparse = [C.pure| int8_t { (int8_t) at::kSparse } |] 143 | 144 | kMkldnn :: Layout 145 | kMkldnn = [C.pure| int8_t { (int8_t) at::kMkldnn } |] 146 | -------------------------------------------------------------------------------- /ffi/src/ATen/GC.hs: 
-------------------------------------------------------------------------------- 1 | {-# LANGUAGE EmptyDataDecls #-} 2 | {-# LANGUAGE ExistentialQuantification #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE ForeignFunctionInterface #-} 5 | {-# LANGUAGE GADTs #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE ScopedTypeVariables #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE TypeSynonymInstances #-} 10 | 11 | module ATen.GC where 12 | 13 | import Control.Exception.Safe (catch,throwIO) 14 | import Data.List (isPrefixOf) 15 | import Language.C.Inline.Cpp.Exceptions (CppException(..)) 16 | import System.Mem (performGC) 17 | import Control.Concurrent (threadDelay) 18 | import Control.Concurrent.Async 19 | import System.SysInfo 20 | 21 | retryWithGC' :: Int -> IO a -> IO a 22 | retryWithGC' count func = 23 | func `catch` \a@(CppStdException message) -> 24 | if isPrefixOf msgOutOfMemory message 25 | then 26 | if count <= 0 27 | then throwIO $ CppStdException $ "Too many calls to performGC, " ++ message 28 | else do 29 | performGC 30 | threadDelay 1000 -- wait 1 ms to give the GC time to run 31 | retryWithGC' (count-1) func 32 | else throwIO a 33 | where 34 | msgOutOfMemory :: String 35 | msgOutOfMemory = "Exception: CUDA out of memory." 36 | 37 | retryWithGC :: IO a -> IO a 38 | retryWithGC = retryWithGC' 10 39 | 40 | checkOSMemoryWithGC :: IO () 41 | checkOSMemoryWithGC = do 42 | v <- sysInfo 43 | case v of 44 | Right stat -> do 45 | let rate = (fromIntegral (freeram stat) / fromIntegral (totalram stat)) 46 | if rate <= 0.5 47 | then performGC 48 | else return () 49 | Left _ -> return () 50 | threadDelay (500*1000) -- wait 500 ms between checks 51 | checkOSMemoryWithGC 52 | 53 | monitorMemory :: IO () -> IO () 54 | monitorMemory func = do 55 | func `race` checkOSMemoryWithGC 56 | return () 57 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Cast.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances #-} 2 | {-# LANGUAGE MultiParamTypeClasses #-} 3 | 4 | module ATen.Managed.Cast where 5 | 6 | import Foreign.ForeignPtr 7 | import Control.Monad 8 | 9 | import ATen.Class 10 | import ATen.Cast 11 | import ATen.Type 12 | import ATen.Managed.Type.IntArray 13 | import ATen.Managed.Type.TensorList 14 | 15 | instance Castable Int (ForeignPtr IntArray) where 16 | cast xs f = do 17 | arr <- newIntArray 18 | intArray_push_back_l arr $ fromIntegral xs 19 | f arr 20 | uncast xs f = do 21 | v <- intArray_at_s xs 0 22 | f (fromIntegral v) 23 | 24 | instance Castable [Int] (ForeignPtr IntArray) where 25 | cast xs f = do 26 | arr <- newIntArray 27 | forM_ xs $ (intArray_push_back_l arr) . fromIntegral 28 | f arr 29 | uncast xs f = do 30 | len <- intArray_size xs 31 | -- NB: This check is necessary, because len is unsigned and it will wrap around if 32 | -- we subtract 1 when it's 0. 33 | if len == 0 34 | then f [] 35 | else f =<< mapM (\i -> intArray_at_s xs i >>= return . 
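-- Typical use of the helpers above (a sketch; `step` stands for any
-- tensor-allocating IO action):
--   retryWithGC (step batch)       -- on "CUDA out of memory": GC, wait, retry
--   monitorMemory (mapM_ step bs)  -- poll free RAM, GC when it falls to half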
fromIntegral) [0..(len - 1)] 36 | 37 | instance Castable [ForeignPtr Tensor] (ForeignPtr TensorList) where 38 | cast xs f = do 39 | l <- newTensorList 40 | forM_ xs $ (tensorList_push_back_t l) 41 | f l 42 | uncast xs f = do 43 | len <- tensorList_size xs 44 | f =<< mapM (tensorList_at_s xs) [0..(len - 1)] 45 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/Context.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.Context where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.Context as Unmanaged 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | init 42 | :: IO (()) 43 | init = cast0 Unmanaged.init 44 | 45 | hasCUDA 46 | :: IO (CBool) 47 | hasCUDA = cast0 Unmanaged.hasCUDA 48 | 49 | hasHIP 50 | :: IO (CBool) 51 | hasHIP = cast0 Unmanaged.hasHIP 52 | 53 | hasXLA 54 | :: IO (CBool) 55 | hasXLA = cast0 Unmanaged.hasXLA 56 | 57 | getNumGPUs 58 | :: IO (CSize) 59 | getNumGPUs = cast0 Unmanaged.getNumGPUs 60 | 61 | hasOpenMP 62 | :: IO (CBool) 63 | hasOpenMP = cast0 Unmanaged.hasOpenMP 64 | 65 | hasMKL 66 | :: IO (CBool) 67 | hasMKL = cast0 Unmanaged.hasMKL 68 | 69 | hasLAPACK 70 | :: IO (CBool) 71 | hasLAPACK = cast0 Unmanaged.hasLAPACK 72 | 73 | hasMAGMA 74 | :: IO (CBool) 75 | hasMAGMA = cast0 Unmanaged.hasMAGMA 76 | 77 | hasMKLDNN 78 | :: IO (CBool) 79 | hasMKLDNN = cast0 Unmanaged.hasMKLDNN 80 | 81 | manual_seed_L 82 | :: Word64 83 | -> IO (()) 84 | manual_seed_L = cast1 Unmanaged.manual_seed_L 85 | 86 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/Extra.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.Extra where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 
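-- Round-tripping through the [Int] <-> IntArray instance above (a sketch):
--   cast ([2, 3, 5] :: [Int]) $ \arr ->
--     uncast (arr :: ForeignPtr IntArray) print   -- prints [2,3,5]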
31 | import qualified ATen.Unmanaged.Type.Extra as Unmanaged 32 | 33 | tensor_assign1_l 34 | :: ForeignPtr Tensor 35 | -> Int64 36 | -> Int64 37 | -> IO () 38 | tensor_assign1_l = cast3 Unmanaged.tensor_assign1_l 39 | 40 | tensor_assign2_l 41 | :: ForeignPtr Tensor 42 | -> Int64 43 | -> Int64 44 | -> Int64 45 | -> IO () 46 | tensor_assign2_l = cast4 Unmanaged.tensor_assign2_l 47 | 48 | tensor_assign1_d 49 | :: ForeignPtr Tensor 50 | -> Int64 51 | -> CDouble 52 | -> IO () 53 | tensor_assign1_d = cast3 Unmanaged.tensor_assign1_d 54 | 55 | tensor_assign2_d 56 | :: ForeignPtr Tensor 57 | -> Int64 58 | -> Int64 59 | -> CDouble 60 | -> IO () 61 | tensor_assign2_d = cast4 Unmanaged.tensor_assign2_d 62 | 63 | tensor_assign1_t 64 | :: ForeignPtr Tensor 65 | -> Int64 66 | -> ForeignPtr Tensor 67 | -> IO () 68 | tensor_assign1_t = cast3 Unmanaged.tensor_assign1_t 69 | 70 | tensor_assign2_t 71 | :: ForeignPtr Tensor 72 | -> Int64 73 | -> Int64 74 | -> ForeignPtr Tensor 75 | -> IO () 76 | tensor_assign2_t = cast4 Unmanaged.tensor_assign2_t 77 | 78 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/Generator.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.Generator where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.Generator as Unmanaged 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/IntArray.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.IntArray where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.IntArray as Unmanaged 32 | 33 | 34 | 35 | newIntArray 36 | :: IO (ForeignPtr 
IntArray) 37 | newIntArray = cast0 Unmanaged.newIntArray 38 | 39 | 40 | 41 | 42 | 43 | intArray_empty 44 | :: ForeignPtr IntArray 45 | -> IO (CBool) 46 | intArray_empty = cast1 Unmanaged.intArray_empty 47 | 48 | intArray_size 49 | :: ForeignPtr IntArray 50 | -> IO (CSize) 51 | intArray_size = cast1 Unmanaged.intArray_size 52 | 53 | intArray_at_s 54 | :: ForeignPtr IntArray 55 | -> CSize 56 | -> IO (Int64) 57 | intArray_at_s = cast2 Unmanaged.intArray_at_s 58 | 59 | intArray_push_back_l 60 | :: ForeignPtr IntArray 61 | -> Int64 62 | -> IO (()) 63 | intArray_push_back_l = cast2 Unmanaged.intArray_push_back_l 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/Scalar.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.Scalar where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.Scalar as Unmanaged 32 | 33 | 34 | 35 | newScalar 36 | :: IO (ForeignPtr Scalar) 37 | newScalar = cast0 Unmanaged.newScalar 38 | 39 | newScalar_i 40 | :: CInt 41 | -> IO (ForeignPtr Scalar) 42 | newScalar_i = cast1 Unmanaged.newScalar_i 43 | 44 | newScalar_d 45 | :: CDouble 46 | -> IO (ForeignPtr Scalar) 47 | newScalar_d = cast1 Unmanaged.newScalar_d 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/SparseTensorRef.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.SparseTensorRef where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.SparseTensorRef as Unmanaged 32 | 33 | 34 | 35 | newSparseTensorRef_t 36 | :: ForeignPtr Tensor 37 | -> IO (ForeignPtr SparseTensorRef) 38 | newSparseTensorRef_t = cast1 
Unmanaged.newSparseTensorRef_t 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/StdArray.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.StdArray where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import qualified ATen.Unmanaged.Type.StdArray as Unmanaged 22 | 23 | 24 | 25 | newStdArrayBool2 26 | :: IO (ForeignPtr (StdArray CBool 2)) 27 | newStdArrayBool2 = cast0 Unmanaged.newStdArrayBool2 28 | 29 | newStdArrayBool2_bb 30 | :: CBool 31 | -> CBool 32 | -> IO (ForeignPtr (StdArray CBool 2)) 33 | newStdArrayBool2_bb = cast2 Unmanaged.newStdArrayBool2_bb 34 | 35 | instance CppTuple2 (ForeignPtr (StdArray CBool 2)) where 36 | type A (ForeignPtr (StdArray CBool 2)) = CBool 37 | type B (ForeignPtr (StdArray CBool 2)) = CBool 38 | get0 v = cast1 (get0 :: Ptr (StdArray CBool 2) -> IO CBool) v 39 | get1 v = cast1 (get1 :: Ptr (StdArray CBool 2) -> IO CBool) v 40 | 41 | newStdArrayBool3 42 | :: IO (ForeignPtr (StdArray CBool 3)) 43 | newStdArrayBool3 = cast0 Unmanaged.newStdArrayBool3 44 | 45 | newStdArrayBool3_bbb 46 | :: CBool 47 | -> CBool 48 | -> CBool 49 | -> IO (ForeignPtr (StdArray CBool 3)) 50 | newStdArrayBool3_bbb = cast3 Unmanaged.newStdArrayBool3_bbb 51 | 52 | instance CppTuple2 (ForeignPtr (StdArray CBool 3)) where 53 | type A (ForeignPtr (StdArray CBool 3)) = CBool 54 | type B (ForeignPtr (StdArray CBool 3)) = CBool 55 | get0 v = cast1 (get0 :: Ptr (StdArray CBool 3) -> IO CBool) v 56 | get1 v = cast1 (get1 :: Ptr (StdArray CBool 3) -> IO CBool) v 57 | 58 | instance CppTuple3 (ForeignPtr (StdArray CBool 3)) where 59 | type C (ForeignPtr (StdArray CBool 3)) = CBool 60 | get2 v = cast1 (get2 :: Ptr (StdArray CBool 3) -> IO CBool) v 61 | 62 | newStdArrayBool4 63 | :: IO (ForeignPtr (StdArray CBool 4)) 64 | newStdArrayBool4 = cast0 Unmanaged.newStdArrayBool4 65 | 66 | newStdArrayBool4_bbbb 67 | :: CBool 68 | -> CBool 69 | -> CBool 70 | -> CBool 71 | -> IO (ForeignPtr (StdArray CBool 4)) 72 | newStdArrayBool4_bbbb = cast4 Unmanaged.newStdArrayBool4_bbbb 73 | 74 | instance CppTuple2 (ForeignPtr (StdArray CBool 4)) where 75 | type A (ForeignPtr (StdArray CBool 4)) = CBool 76 | type B (ForeignPtr (StdArray CBool 4)) = CBool 77 | get0 v = cast1 (get0 :: Ptr (StdArray CBool 4) -> IO CBool) v 78 | get1 v = cast1 (get1 :: Ptr (StdArray CBool 4) -> IO CBool) v 79 | 80 | instance CppTuple3 (ForeignPtr (StdArray CBool 4)) where 81 | type C (ForeignPtr (StdArray CBool 4)) = CBool 82 | get2 v = cast1 (get2 :: Ptr (StdArray CBool 4) -> IO CBool) v 83 | 84 | instance CppTuple4 (ForeignPtr (StdArray CBool 4)) where 85 | type D (ForeignPtr (StdArray CBool 4)) = CBool 86 | get3 v = cast1 (get3 :: Ptr (StdArray CBool 4) -> IO CBool) v 87 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/StdString.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE 
PolyKinds #-} 3 | {-# LANGUAGE TemplateHaskell #-} 4 | {-# LANGUAGE QuasiQuotes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE OverloadedStrings #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE FlexibleInstances #-} 9 | {-# LANGUAGE MultiParamTypeClasses #-} 10 | 11 | module ATen.Managed.Type.StdString where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import qualified ATen.Unmanaged.Type.StdString as Unmanaged 22 | 23 | 24 | 25 | newStdString 26 | :: IO (ForeignPtr StdString) 27 | newStdString = cast0 Unmanaged.newStdString 28 | 29 | newStdString_s 30 | :: String 31 | -> IO (ForeignPtr StdString) 32 | newStdString_s str = cast1 Unmanaged.newStdString_s str 33 | 34 | string_c_str 35 | :: ForeignPtr StdString 36 | -> IO String 37 | string_c_str str = cast1 Unmanaged.string_c_str str 38 | 39 | instance Castable String (ForeignPtr StdString) where 40 | cast str f = newStdString_s str >>= f 41 | uncast xs f = string_c_str xs >>= f 42 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/Storage.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.Storage where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.Storage as Unmanaged 32 | 33 | 34 | 35 | newStorage 36 | :: IO (ForeignPtr Storage) 37 | newStorage = cast0 Unmanaged.newStorage 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/TensorList.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.TensorList where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import 
ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.TensorList as Unmanaged 32 | 33 | 34 | 35 | newTensorList 36 | :: IO (ForeignPtr TensorList) 37 | newTensorList = cast0 Unmanaged.newTensorList 38 | 39 | 40 | 41 | 42 | 43 | tensorList_empty 44 | :: ForeignPtr TensorList 45 | -> IO (CBool) 46 | tensorList_empty = cast1 Unmanaged.tensorList_empty 47 | 48 | tensorList_size 49 | :: ForeignPtr TensorList 50 | -> IO (CSize) 51 | tensorList_size = cast1 Unmanaged.tensorList_size 52 | 53 | tensorList_at_s 54 | :: ForeignPtr TensorList 55 | -> CSize 56 | -> IO (ForeignPtr Tensor) 57 | tensorList_at_s = cast2 Unmanaged.tensorList_at_s 58 | 59 | tensorList_push_back_t 60 | :: ForeignPtr TensorList 61 | -> ForeignPtr Tensor 62 | -> IO (()) 63 | tensorList_push_back_t = cast2 Unmanaged.tensorList_push_back_t 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/TensorOptions.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Managed.Type.TensorOptions where 12 | 13 | 14 | import Foreign.C.String 15 | import Foreign.C.Types 16 | import Foreign hiding (newForeignPtr) 17 | import Foreign.Concurrent 18 | import ATen.Type 19 | import ATen.Class 20 | import ATen.Cast 21 | import ATen.Unmanaged.Type.Generator 22 | import ATen.Unmanaged.Type.IntArray 23 | import ATen.Unmanaged.Type.Scalar 24 | import ATen.Unmanaged.Type.SparseTensorRef 25 | import ATen.Unmanaged.Type.Storage 26 | import ATen.Unmanaged.Type.Tensor 27 | import ATen.Unmanaged.Type.TensorList 28 | import ATen.Unmanaged.Type.TensorOptions 29 | import ATen.Unmanaged.Type.Tuple 30 | 31 | import qualified ATen.Unmanaged.Type.TensorOptions as Unmanaged 32 | 33 | 34 | 35 | newTensorOptions_s 36 | :: ScalarType 37 | -> IO (ForeignPtr TensorOptions) 38 | newTensorOptions_s = cast1 Unmanaged.newTensorOptions_s 39 | 40 | 41 | 42 | 43 | 44 | tensorOptions_device_D 45 | :: ForeignPtr TensorOptions 46 | -> DeviceType 47 | -> IO (ForeignPtr TensorOptions) 48 | tensorOptions_device_D = cast2 Unmanaged.tensorOptions_device_D 49 | 50 | tensorOptions_device_index_s 51 | :: ForeignPtr TensorOptions 52 | -> Int16 53 | -> IO (ForeignPtr TensorOptions) 54 | tensorOptions_device_index_s = cast2 Unmanaged.tensorOptions_device_index_s 55 | 56 | tensorOptions_dtype_s 57 | :: ForeignPtr TensorOptions 58 | -> ScalarType 59 | -> IO (ForeignPtr TensorOptions) 60 | tensorOptions_dtype_s = cast2 Unmanaged.tensorOptions_dtype_s 61 | 62 | tensorOptions_dtype 63 | :: ForeignPtr TensorOptions 64 | -> IO (ForeignPtr TensorOptions) 65 | tensorOptions_dtype = cast1 Unmanaged.tensorOptions_dtype 66 | 67 | tensorOptions_layout_L 68 | :: ForeignPtr TensorOptions 69 | -> Layout 70 | -> IO (ForeignPtr TensorOptions) 71 | tensorOptions_layout_L = cast2 Unmanaged.tensorOptions_layout_L 72 | 73 | tensorOptions_requires_grad_b 74 | :: ForeignPtr TensorOptions 75 | -> CBool 76 | -> IO (ForeignPtr TensorOptions) 77 | tensorOptions_requires_grad_b = cast2 Unmanaged.tensorOptions_requires_grad_b 78 | 79 | tensorOptions_is_variable_b 80 | :: ForeignPtr TensorOptions 81 | -> CBool 82 | -> IO (ForeignPtr TensorOptions) 83 
| tensorOptions_is_variable_b = cast2 Unmanaged.tensorOptions_is_variable_b 84 | 85 | tensorOptions_has_device 86 | :: ForeignPtr TensorOptions 87 | -> IO (CBool) 88 | tensorOptions_has_device = cast1 Unmanaged.tensorOptions_has_device 89 | 90 | tensorOptions_device_index 91 | :: ForeignPtr TensorOptions 92 | -> IO (Int32) 93 | tensorOptions_device_index = cast1 Unmanaged.tensorOptions_device_index 94 | 95 | tensorOptions_has_dtype 96 | :: ForeignPtr TensorOptions 97 | -> IO (CBool) 98 | tensorOptions_has_dtype = cast1 Unmanaged.tensorOptions_has_dtype 99 | 100 | tensorOptions_layout 101 | :: ForeignPtr TensorOptions 102 | -> IO (Layout) 103 | tensorOptions_layout = cast1 Unmanaged.tensorOptions_layout 104 | 105 | tensorOptions_has_layout 106 | :: ForeignPtr TensorOptions 107 | -> IO (CBool) 108 | tensorOptions_has_layout = cast1 Unmanaged.tensorOptions_has_layout 109 | 110 | tensorOptions_requires_grad 111 | :: ForeignPtr TensorOptions 112 | -> IO (CBool) 113 | tensorOptions_requires_grad = cast1 Unmanaged.tensorOptions_requires_grad 114 | 115 | tensorOptions_has_requires_grad 116 | :: ForeignPtr TensorOptions 117 | -> IO (CBool) 118 | tensorOptions_has_requires_grad = cast1 Unmanaged.tensorOptions_has_requires_grad 119 | 120 | tensorOptions_is_variable 121 | :: ForeignPtr TensorOptions 122 | -> IO (CBool) 123 | tensorOptions_is_variable = cast1 Unmanaged.tensorOptions_is_variable 124 | 125 | tensorOptions_has_is_variable 126 | :: ForeignPtr TensorOptions 127 | -> IO (CBool) 128 | tensorOptions_has_is_variable = cast1 Unmanaged.tensorOptions_has_is_variable 129 | 130 | tensorOptions_backend 131 | :: ForeignPtr TensorOptions 132 | -> IO (Backend) 133 | tensorOptions_backend = cast1 Unmanaged.tensorOptions_backend 134 | 135 | 136 | 137 | dtype_s 138 | :: ScalarType 139 | -> IO (ForeignPtr TensorOptions) 140 | dtype_s = cast1 Unmanaged.dtype_s 141 | 142 | layout_L 143 | :: Layout 144 | -> IO (ForeignPtr TensorOptions) 145 | layout_L = cast1 Unmanaged.layout_L 146 | 147 | device_D 148 | :: DeviceType 149 | -> IO (ForeignPtr TensorOptions) 150 | device_D = cast1 Unmanaged.device_D 151 | 152 | device_index_s 153 | :: Int16 154 | -> IO (ForeignPtr TensorOptions) 155 | device_index_s = cast1 Unmanaged.device_index_s 156 | 157 | requires_grad_b 158 | :: CBool 159 | -> IO (ForeignPtr TensorOptions) 160 | requires_grad_b = cast1 Unmanaged.requires_grad_b 161 | 162 | -------------------------------------------------------------------------------- /ffi/src/ATen/Managed/Type/Tuple.hs: -------------------------------------------------------------------------------- 1 | 2 | -- generated by using spec/tuples.yaml 3 | 4 | {-# LANGUAGE DataKinds #-} 5 | {-# LANGUAGE PolyKinds #-} 6 | {-# LANGUAGE TemplateHaskell #-} 7 | {-# LANGUAGE QuasiQuotes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE OverloadedStrings #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE FlexibleInstances #-} 12 | 13 | module ATen.Managed.Type.Tuple where 14 | 15 | 16 | import Foreign.C.String 17 | import Foreign.C.Types 18 | import Foreign hiding (newForeignPtr) 19 | import Foreign.Concurrent 20 | import ATen.Type 21 | import ATen.Class 22 | import ATen.Cast 23 | 24 | import qualified ATen.Unmanaged.Type.Tuple as Unmanaged 25 | import ATen.Unmanaged.Type.Generator 26 | import ATen.Unmanaged.Type.IntArray 27 | import ATen.Unmanaged.Type.Scalar 28 | import ATen.Unmanaged.Type.SparseTensorRef 29 | import ATen.Unmanaged.Type.Storage 30 | import ATen.Unmanaged.Type.Tensor 31 | import ATen.Unmanaged.Type.TensorList 32 | 
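
-- Illustrative aside (not generated code): the managed TensorOptions
-- wrappers in the module above compose in IO, each call returning a fresh
-- ForeignPtr whose C++ object is deleted by a GC finalizer. A minimal
-- sketch, assuming kFloat and kStrided from ATen.Const are in scope:
--
-- > mkFloatOpts :: IO (ForeignPtr TensorOptions)
-- > mkFloatOpts = do
-- >   opts  <- dtype_s kFloat
-- >   opts' <- tensorOptions_layout_L opts kStrided
-- >   tensorOptions_requires_grad_b opts' 1
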
import ATen.Unmanaged.Type.TensorOptions 33 | import ATen.Unmanaged.Type.Tuple 34 | 35 | 36 | 37 | instance CppTuple2 (ForeignPtr (Tensor,Tensor)) where 38 | type A (ForeignPtr (Tensor,Tensor)) = ForeignPtr Tensor 39 | type B (ForeignPtr (Tensor,Tensor)) = ForeignPtr Tensor 40 | get0 v = cast1 (get0 :: Ptr (Tensor,Tensor) -> IO (Ptr Tensor)) v 41 | get1 v = cast1 (get1 :: Ptr (Tensor,Tensor) -> IO (Ptr Tensor)) v 42 | 43 | instance CppTuple2 (ForeignPtr (Tensor,Tensor,Tensor)) where 44 | type A (ForeignPtr (Tensor,Tensor,Tensor)) = ForeignPtr Tensor 45 | type B (ForeignPtr (Tensor,Tensor,Tensor)) = ForeignPtr Tensor 46 | get0 v = cast1 (get0 :: Ptr (Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 47 | get1 v = cast1 (get1 :: Ptr (Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 48 | 49 | instance CppTuple3 (ForeignPtr (Tensor,Tensor,Tensor)) where 50 | type C (ForeignPtr (Tensor,Tensor,Tensor)) = ForeignPtr Tensor 51 | get2 v = cast1 (get2 :: Ptr (Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 52 | 53 | instance CppTuple2 (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) where 54 | type A (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 55 | type B (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 56 | get0 v = cast1 (get0 :: Ptr (Tensor,Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 57 | get1 v = cast1 (get1 :: Ptr (Tensor,Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 58 | 59 | instance CppTuple3 (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) where 60 | type C (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 61 | get2 v = cast1 (get2 :: Ptr (Tensor,Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 62 | 63 | instance CppTuple4 (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) where 64 | type D (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 65 | get3 v = cast1 (get3 :: Ptr (Tensor,Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 66 | 67 | instance CppTuple5 (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) where 68 | type E (ForeignPtr (Tensor,Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 69 | get4 v = cast1 (get4 :: Ptr (Tensor,Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 70 | 71 | instance CppTuple2 (ForeignPtr (Tensor,Tensor,Tensor,TensorList)) where 72 | type A (ForeignPtr (Tensor,Tensor,Tensor,TensorList)) = ForeignPtr Tensor 73 | type B (ForeignPtr (Tensor,Tensor,Tensor,TensorList)) = ForeignPtr Tensor 74 | get0 v = cast1 (get0 :: Ptr (Tensor,Tensor,Tensor,TensorList) -> IO (Ptr Tensor)) v 75 | get1 v = cast1 (get1 :: Ptr (Tensor,Tensor,Tensor,TensorList) -> IO (Ptr Tensor)) v 76 | 77 | instance CppTuple3 (ForeignPtr (Tensor,Tensor,Tensor,TensorList)) where 78 | type C (ForeignPtr (Tensor,Tensor,Tensor,TensorList)) = ForeignPtr Tensor 79 | get2 v = cast1 (get2 :: Ptr (Tensor,Tensor,Tensor,TensorList) -> IO (Ptr Tensor)) v 80 | 81 | instance CppTuple4 (ForeignPtr (Tensor,Tensor,Tensor,TensorList)) where 82 | type D (ForeignPtr (Tensor,Tensor,Tensor,TensorList)) = ForeignPtr TensorList 83 | get3 v = cast1 (get3 :: Ptr (Tensor,Tensor,Tensor,TensorList) -> IO (Ptr TensorList)) v 84 | 85 | instance CppTuple2 (ForeignPtr (Tensor,Tensor,Tensor,Int64)) where 86 | type A (ForeignPtr (Tensor,Tensor,Tensor,Int64)) = ForeignPtr Tensor 87 | type B (ForeignPtr (Tensor,Tensor,Tensor,Int64)) = ForeignPtr Tensor 88 | get0 v = cast1 (get0 :: Ptr (Tensor,Tensor,Tensor,Int64) -> IO (Ptr Tensor)) v 89 | get1 v = cast1 (get1 :: Ptr (Tensor,Tensor,Tensor,Int64) -> IO (Ptr Tensor)) v 90 | 91 | instance CppTuple3 
(ForeignPtr (Tensor,Tensor,Tensor,Int64)) where 92 | type C (ForeignPtr (Tensor,Tensor,Tensor,Int64)) = ForeignPtr Tensor 93 | get2 v = cast1 (get2 :: Ptr (Tensor,Tensor,Tensor,Int64) -> IO (Ptr Tensor)) v 94 | 95 | instance CppTuple4 (ForeignPtr (Tensor,Tensor,Tensor,Int64)) where 96 | type D (ForeignPtr (Tensor,Tensor,Tensor,Int64)) = Int64 97 | get3 v = cast1 (get3 :: Ptr (Tensor,Tensor,Tensor,Int64) -> IO (Int64)) v 98 | 99 | instance CppTuple2 (ForeignPtr (Tensor,Tensor,Tensor,Tensor)) where 100 | type A (ForeignPtr (Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 101 | type B (ForeignPtr (Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 102 | get0 v = cast1 (get0 :: Ptr (Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 103 | get1 v = cast1 (get1 :: Ptr (Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 104 | 105 | instance CppTuple3 (ForeignPtr (Tensor,Tensor,Tensor,Tensor)) where 106 | type C (ForeignPtr (Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 107 | get2 v = cast1 (get2 :: Ptr (Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 108 | 109 | instance CppTuple4 (ForeignPtr (Tensor,Tensor,Tensor,Tensor)) where 110 | type D (ForeignPtr (Tensor,Tensor,Tensor,Tensor)) = ForeignPtr Tensor 111 | get3 v = cast1 (get3 :: Ptr (Tensor,Tensor,Tensor,Tensor) -> IO (Ptr Tensor)) v 112 | 113 | instance CppTuple2 (ForeignPtr (Tensor,Tensor,CDouble,Int64)) where 114 | type A (ForeignPtr (Tensor,Tensor,CDouble,Int64)) = ForeignPtr Tensor 115 | type B (ForeignPtr (Tensor,Tensor,CDouble,Int64)) = ForeignPtr Tensor 116 | get0 v = cast1 (get0 :: Ptr (Tensor,Tensor,CDouble,Int64) -> IO (Ptr Tensor)) v 117 | get1 v = cast1 (get1 :: Ptr (Tensor,Tensor,CDouble,Int64) -> IO (Ptr Tensor)) v 118 | 119 | instance CppTuple3 (ForeignPtr (Tensor,Tensor,CDouble,Int64)) where 120 | type C (ForeignPtr (Tensor,Tensor,CDouble,Int64)) = CDouble 121 | get2 v = cast1 (get2 :: Ptr (Tensor,Tensor,CDouble,Int64) -> IO (CDouble)) v 122 | 123 | instance CppTuple4 (ForeignPtr (Tensor,Tensor,CDouble,Int64)) where 124 | type D (ForeignPtr (Tensor,Tensor,CDouble,Int64)) = Int64 125 | get3 v = cast1 (get3 :: Ptr (Tensor,Tensor,CDouble,Int64) -> IO (Int64)) v 126 | 127 | -------------------------------------------------------------------------------- /ffi/src/ATen/Type.hs: -------------------------------------------------------------------------------- 1 | 2 | -- generated by using spec/Declarations.yaml 3 | 4 | {-# LANGUAGE DataKinds #-} 5 | {-# LANGUAGE PolyKinds #-} 6 | {-# LANGUAGE TemplateHaskell #-} 7 | {-# LANGUAGE QuasiQuotes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE OverloadedStrings #-} 10 | 11 | module ATen.Type where 12 | 13 | import qualified Language.C.Inline.Cpp as C 14 | import qualified Language.C.Inline.Cpp.Exceptions as C 15 | import qualified Language.C.Inline.Context as C 16 | import qualified Language.C.Types as C 17 | import qualified Data.Map as Map 18 | 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign 22 | 23 | type ScalarType = Int8 24 | type DeviceType = Int16 25 | type Backend = CInt 26 | type Layout = Int8 27 | 28 | data Tensor 29 | data Scalar 30 | data TensorOptions 31 | data TensorList 32 | data IntArrayRef 33 | data IntArray 34 | data TensorAVector 35 | data SparseTensorRef 36 | data Storage 37 | data StdArray a b 38 | data StdString 39 | data Generator 40 | data Device 41 | data Context 42 | 43 | typeTable = Map.fromList [ 44 | (C.TypeName "at::Scalar", [t|Scalar|]) 45 | , (C.TypeName "at::Tensor", [t|Tensor|]) 46 | , (C.TypeName 
"at::TensorOptions", [t|TensorOptions|]) 47 | , (C.TypeName "std::vector", [t|TensorList|]) 48 | , (C.TypeName "at::IntArrayRef", [t|IntArrayRef|]) 49 | , (C.TypeName "std::vector", [t|IntArray|]) 50 | , (C.TypeName "at::ScalarType", [t|ScalarType|]) 51 | , (C.TypeName "at::DeviceType", [t|DeviceType|]) 52 | , (C.TypeName "at::SparseTensorRef", [t|SparseTensorRef|]) 53 | , (C.TypeName "at::Storage", [t|Storage|]) 54 | , (C.TypeName "at::Device", [t|Device|]) 55 | , (C.TypeName "at::Generator", [t|Generator|]) 56 | , (C.TypeName "std::string", [t|StdString|]) 57 | , (C.TypeName "std::array", [t|StdArray CBool 2|]) 58 | , (C.TypeName "std::array", [t|StdArray CBool 3|]) 59 | , (C.TypeName "std::array", [t|StdArray CBool 4|]) 60 | , (C.TypeName "std::tuple", [t|(Tensor,Tensor)|]) 61 | , (C.TypeName "std::tuple", [t|(Tensor,Tensor,Tensor)|]) 62 | , (C.TypeName "std::tuple", [t|(Tensor,Tensor,Tensor,Tensor)|]) 63 | , (C.TypeName "std::tuple", [t|(Tensor,Tensor,Tensor,Tensor,Tensor)|]) 64 | , (C.TypeName "std::tuple>", [t|(Tensor,Tensor,Tensor,TensorList)|]) 65 | , (C.TypeName "std::tuple", [t|(Tensor,Tensor,CDouble,Int64)|]) 66 | , (C.TypeName "std::tuple", [t|(Tensor,Tensor,CFloat,CInt)|]) 67 | , (C.TypeName "std::tuple", [t|(Tensor,Tensor,Tensor,Int64)|]) 68 | , (C.TypeName "at::Backend", [t|Backend|]) 69 | , (C.TypeName "at::Layout", [t|Layout|]) 70 | , (C.TypeName "at::Context", [t|Context|]) 71 | ] 72 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/Context.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.Context where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | 34 | 35 | deleteContext :: Ptr Context -> IO () 36 | deleteContext object = [C.throwBlock| void { delete $(at::Context* object);}|] 37 | 38 | instance CppObject Context where 39 | fromPtr ptr = newForeignPtr ptr (deleteContext ptr) 40 | 41 | 42 | 43 | 44 | 45 | init 46 | :: IO (()) 47 | init = 48 | [C.throwBlock| void { (at::init( 49 | )); 50 | }|] 51 | 52 | hasCUDA 53 | :: IO (CBool) 54 | hasCUDA = 55 | [C.throwBlock| bool { return (at::hasCUDA( 56 | )); 57 | }|] 58 | 59 | hasHIP 60 | :: IO (CBool) 61 | hasHIP = 62 | [C.throwBlock| bool { return (at::hasHIP( 63 | )); 64 | }|] 65 | 66 | hasXLA 67 | :: IO (CBool) 68 | hasXLA = 69 | [C.throwBlock| bool { return (at::hasXLA( 70 | )); 71 | }|] 72 | 73 | getNumGPUs 74 | :: IO (CSize) 75 | getNumGPUs = 76 | [C.throwBlock| size_t { return (at::getNumGPUs( 77 | )); 78 | }|] 79 | 80 | hasOpenMP 81 | :: IO (CBool) 82 | hasOpenMP = 83 | [C.throwBlock| bool { return (at::hasOpenMP( 84 | )); 85 | }|] 86 | 87 | hasMKL 88 | :: IO (CBool) 89 | hasMKL = 90 | 
[C.throwBlock| bool { return (at::hasMKL( 91 | )); 92 | }|] 93 | 94 | hasLAPACK 95 | :: IO (CBool) 96 | hasLAPACK = 97 | [C.throwBlock| bool { return (at::hasLAPACK( 98 | )); 99 | }|] 100 | 101 | hasMAGMA 102 | :: IO (CBool) 103 | hasMAGMA = 104 | [C.throwBlock| bool { return (at::hasMAGMA( 105 | )); 106 | }|] 107 | 108 | hasMKLDNN 109 | :: IO (CBool) 110 | hasMKLDNN = 111 | [C.throwBlock| bool { return (at::hasMKLDNN( 112 | )); 113 | }|] 114 | 115 | manual_seed_L 116 | :: Word64 117 | -> IO (()) 118 | manual_seed_L _seed = 119 | [C.throwBlock| void { (at::manual_seed( 120 | $(uint64_t _seed))); 121 | }|] 122 | 123 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/Extra.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.Extra where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | tensor_assign1_l 33 | :: Ptr Tensor 34 | -> Int64 35 | -> Int64 36 | -> IO () 37 | tensor_assign1_l _obj _idx0 _val = 38 | [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)] = $(int64_t _val); }|] 39 | 40 | tensor_assign2_l 41 | :: Ptr Tensor 42 | -> Int64 43 | -> Int64 44 | -> Int64 45 | -> IO () 46 | tensor_assign2_l _obj _idx0 _idx1 _val = 47 | [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)][$(int64_t _idx1)] = $(int64_t _val); }|] 48 | 49 | tensor_assign1_d 50 | :: Ptr Tensor 51 | -> Int64 52 | -> CDouble 53 | -> IO () 54 | tensor_assign1_d _obj _idx0 _val = 55 | [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)] = $(double _val); }|] 56 | 57 | tensor_assign2_d 58 | :: Ptr Tensor 59 | -> Int64 60 | -> Int64 61 | -> CDouble 62 | -> IO () 63 | tensor_assign2_d _obj _idx0 _idx1 _val = 64 | [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)][$(int64_t _idx1)] = $(double _val); }|] 65 | 66 | 67 | tensor_assign1_t 68 | :: Ptr Tensor 69 | -> Int64 70 | -> Ptr Tensor 71 | -> IO () 72 | tensor_assign1_t _obj _idx0 _val = 73 | [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)] = *$(at::Tensor* _val); }|] 74 | 75 | tensor_assign2_t 76 | :: Ptr Tensor 77 | -> Int64 78 | -> Int64 79 | -> Ptr Tensor 80 | -> IO () 81 | tensor_assign2_t _obj _idx0 _idx1 _val = 82 | [C.throwBlock| void { (*$(at::Tensor* _obj))[$(int64_t _idx0)][$(int64_t _idx1)] = *$(at::Tensor* _val); }|] 83 | 84 | 85 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/Generator.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE 
ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.Generator where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | 34 | 35 | deleteGenerator :: Ptr Generator -> IO () 36 | deleteGenerator object = [C.throwBlock| void { delete $(at::Generator* object);}|] 37 | 38 | instance CppObject Generator where 39 | fromPtr ptr = newForeignPtr ptr (deleteGenerator ptr) 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/IntArray.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.IntArray where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newIntArray 34 | :: IO (Ptr IntArray) 35 | newIntArray = 36 | [C.throwBlock| std::vector* { return new std::vector( 37 | ); 38 | }|] 39 | 40 | 41 | 42 | deleteIntArray :: Ptr IntArray -> IO () 43 | deleteIntArray object = [C.throwBlock| void { delete $(std::vector* object);}|] 44 | 45 | instance CppObject IntArray where 46 | fromPtr ptr = newForeignPtr ptr (deleteIntArray ptr) 47 | 48 | 49 | 50 | intArray_empty 51 | :: Ptr IntArray 52 | -> IO (CBool) 53 | intArray_empty _obj = 54 | [C.throwBlock| bool { return (*$(std::vector* _obj)).empty( 55 | ); 56 | }|] 57 | 58 | intArray_size 59 | :: Ptr IntArray 60 | -> IO (CSize) 61 | intArray_size _obj = 62 | [C.throwBlock| size_t { return (*$(std::vector* _obj)).size( 63 | ); 64 | }|] 65 | 66 | intArray_at_s 67 | :: Ptr IntArray 68 | -> CSize 69 | -> IO (Int64) 70 | intArray_at_s _obj _s = 71 | [C.throwBlock| int64_t { return (*$(std::vector* _obj)).at( 72 | $(size_t _s)); 73 | }|] 74 | 75 | intArray_push_back_l 76 | :: Ptr IntArray 77 | -> Int64 78 | -> IO (()) 79 | intArray_push_back_l _obj _v = 80 | [C.throwBlock| void { (*$(std::vector* _obj)).push_back( 81 | $(int64_t _v)); 82 | }|] 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/Scalar.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# 
LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.Scalar where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newScalar 34 | :: IO (Ptr Scalar) 35 | newScalar = 36 | [C.throwBlock| at::Scalar* { return new at::Scalar( 37 | ); 38 | }|] 39 | 40 | newScalar_i 41 | :: CInt 42 | -> IO (Ptr Scalar) 43 | newScalar_i _a = 44 | [C.throwBlock| at::Scalar* { return new at::Scalar( 45 | $(int _a)); 46 | }|] 47 | 48 | newScalar_d 49 | :: CDouble 50 | -> IO (Ptr Scalar) 51 | newScalar_d _a = 52 | [C.throwBlock| at::Scalar* { return new at::Scalar( 53 | $(double _a)); 54 | }|] 55 | 56 | 57 | 58 | deleteScalar :: Ptr Scalar -> IO () 59 | deleteScalar object = [C.throwBlock| void { delete $(at::Scalar* object);}|] 60 | 61 | instance CppObject Scalar where 62 | fromPtr ptr = newForeignPtr ptr (deleteScalar ptr) 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/SparseTensorRef.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.SparseTensorRef where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newSparseTensorRef_t 34 | :: Ptr Tensor 35 | -> IO (Ptr SparseTensorRef) 36 | newSparseTensorRef_t _x = 37 | [C.throwBlock| at::SparseTensorRef* { return new at::SparseTensorRef( 38 | *$(at::Tensor* _x)); 39 | }|] 40 | 41 | 42 | 43 | deleteSparseTensorRef :: Ptr SparseTensorRef -> IO () 44 | deleteSparseTensorRef object = [C.throwBlock| void { delete $(at::SparseTensorRef* object);}|] 45 | 46 | instance CppObject SparseTensorRef where 47 | fromPtr ptr = newForeignPtr ptr (deleteSparseTensorRef ptr) 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/StdArray.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | 
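
-- Illustrative aside (not generated code): every unmanaged constructor in
-- the Scalar and SparseTensorRef modules above pairs with a delete
-- function; raw Ptrs must be freed by the caller, which the managed layer
-- automates. A minimal sketch over the Scalar module above; a real
-- version would use bracket for exception safety:
--
-- > withScalar :: CDouble -> (Ptr Scalar -> IO a) -> IO a
-- > withScalar x act = do
-- >   s <- newScalar_d x
-- >   r <- act s
-- >   deleteScalar s
-- >   pure r
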
{-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.StdArray where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newStdArrayBool2 34 | :: IO (Ptr (StdArray CBool 2)) 35 | newStdArrayBool2 = 36 | [C.throwBlock| std::array* { return new std::array( 37 | ); 38 | }|] 39 | 40 | newStdArrayBool2_bb 41 | :: CBool 42 | -> CBool 43 | -> IO (Ptr (StdArray CBool 2)) 44 | newStdArrayBool2_bb b0 b1 = 45 | [C.throwBlock| std::array* { return new std::array({$(bool b0),$(bool b1)}); }|] 46 | 47 | instance CppTuple2 (Ptr (StdArray CBool 2)) where 48 | type A (Ptr (StdArray CBool 2)) = CBool 49 | type B (Ptr (StdArray CBool 2)) = CBool 50 | get0 v = [C.throwBlock| bool { return std::get<0>(*$(std::array* v));}|] 51 | get1 v = [C.throwBlock| bool { return std::get<1>(*$(std::array* v));}|] 52 | 53 | newStdArrayBool3 54 | :: IO (Ptr (StdArray CBool 3)) 55 | newStdArrayBool3 = 56 | [C.throwBlock| std::array* { return new std::array( 57 | ); 58 | }|] 59 | 60 | newStdArrayBool3_bbb 61 | :: CBool 62 | -> CBool 63 | -> CBool 64 | -> IO (Ptr (StdArray CBool 3)) 65 | newStdArrayBool3_bbb b0 b1 b2 = 66 | [C.throwBlock| std::array* { return new std::array({$(bool b0),$(bool b1),$(bool b2)}); }|] 67 | 68 | instance CppTuple2 (Ptr (StdArray CBool 3)) where 69 | type A (Ptr (StdArray CBool 3)) = CBool 70 | type B (Ptr (StdArray CBool 3)) = CBool 71 | get0 v = [C.throwBlock| bool { return std::get<0>(*$(std::array* v));}|] 72 | get1 v = [C.throwBlock| bool { return std::get<1>(*$(std::array* v));}|] 73 | 74 | instance CppTuple3 (Ptr (StdArray CBool 3)) where 75 | type C (Ptr (StdArray CBool 3)) = CBool 76 | get2 v = [C.throwBlock| bool { return std::get<2>(*$(std::array* v));}|] 77 | 78 | newStdArrayBool4 79 | :: IO (Ptr (StdArray CBool 4)) 80 | newStdArrayBool4 = 81 | [C.throwBlock| std::array* { return new std::array( 82 | ); 83 | }|] 84 | 85 | newStdArrayBool4_bbbb 86 | :: CBool 87 | -> CBool 88 | -> CBool 89 | -> CBool 90 | -> IO (Ptr (StdArray CBool 4)) 91 | newStdArrayBool4_bbbb b0 b1 b2 b3 = 92 | [C.throwBlock| std::array* { return new std::array({$(bool b0),$(bool b1),$(bool b2),$(bool b3)}); }|] 93 | 94 | instance CppTuple2 (Ptr (StdArray CBool 4)) where 95 | type A (Ptr (StdArray CBool 4)) = CBool 96 | type B (Ptr (StdArray CBool 4)) = CBool 97 | get0 v = [C.throwBlock| bool { return std::get<0>(*$(std::array* v));}|] 98 | get1 v = [C.throwBlock| bool { return std::get<1>(*$(std::array* v));}|] 99 | 100 | instance CppTuple3 (Ptr (StdArray CBool 4)) where 101 | type C (Ptr (StdArray CBool 4)) = CBool 102 | get2 v = [C.throwBlock| bool { return std::get<2>(*$(std::array* v));}|] 103 | 104 | instance CppTuple4 (Ptr (StdArray CBool 4)) where 105 | type D (Ptr (StdArray CBool 4)) = CBool 106 | get3 v = [C.throwBlock| bool { return std::get<3>(*$(std::array* v));}|] 107 | 108 | deleteStdArrayBool2 :: Ptr (StdArray CBool 2) -> IO () 109 | 
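
-- Illustrative aside (not generated code): the CppTuple instances above
-- give positional access to std::array fields through std::get<N>. A
-- minimal sketch reading back a freshly built bool pair:
--
-- > readPair :: IO (CBool, CBool)
-- > readPair = do
-- >   arr <- newStdArrayBool2_bb 1 0
-- >   a   <- get0 arr
-- >   b   <- get1 arr
-- >   deleteStdArrayBool2 arr
-- >   pure (a, b)
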
deleteStdArrayBool2 object = [C.throwBlock| void { delete $(std::array* object);}|] 110 | 111 | deleteStdArrayBool3 :: Ptr (StdArray CBool 3) -> IO () 112 | deleteStdArrayBool3 object = [C.throwBlock| void { delete $(std::array* object);}|] 113 | 114 | deleteStdArrayBool4 :: Ptr (StdArray CBool 4) -> IO () 115 | deleteStdArrayBool4 object = [C.throwBlock| void { delete $(std::array* object);}|] 116 | 117 | 118 | instance CppObject (StdArray CBool 2) where 119 | fromPtr ptr = newForeignPtr ptr (deleteStdArrayBool2 ptr) 120 | 121 | instance CppObject (StdArray CBool 3) where 122 | fromPtr ptr = newForeignPtr ptr (deleteStdArrayBool3 ptr) 123 | 124 | instance CppObject (StdArray CBool 4) where 125 | fromPtr ptr = newForeignPtr ptr (deleteStdArrayBool4 ptr) 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/StdString.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.StdString where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newStdString 34 | :: IO (Ptr StdString) 35 | newStdString = 36 | [C.throwBlock| std::string* { return new std::string( 37 | ); 38 | }|] 39 | 40 | newStdString_s 41 | :: String 42 | -> IO (Ptr StdString) 43 | newStdString_s str = 44 | withCString str $ \cstr -> [C.throwBlock| std::string* { return new std::string($(char* cstr));}|] 45 | 46 | deleteStdString :: Ptr StdString -> IO () 47 | deleteStdString object = [C.throwBlock| void { delete $(std::string* object);}|] 48 | 49 | instance CppObject StdString where 50 | fromPtr ptr = newForeignPtr ptr (deleteStdString ptr) 51 | 52 | string_c_str 53 | :: Ptr StdString 54 | -> IO String 55 | string_c_str str = [C.throwBlock| const char* { return (*$(std::string* str)).c_str();}|] >>= peekCString 56 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/Storage.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.Storage where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | 
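
-- Illustrative aside (not generated code): the CppObject instances in the
-- StdArray module above implement fromPtr by attaching the matching C++
-- delete as a ForeignPtr finalizer, so managed values are reclaimed by
-- the Haskell GC. A minimal sketch lifting an unmanaged constructor into
-- the managed layer:
--
-- > managedArray :: IO (ForeignPtr (StdArray CBool 2))
-- > managedArray = newStdArrayBool2_bb 1 1 >>= fromPtr
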
import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newStorage 34 | :: IO (Ptr Storage) 35 | newStorage = 36 | [C.throwBlock| at::Storage* { return new at::Storage( 37 | ); 38 | }|] 39 | 40 | 41 | 42 | deleteStorage :: Ptr Storage -> IO () 43 | deleteStorage object = [C.throwBlock| void { delete $(at::Storage* object);}|] 44 | 45 | instance CppObject Storage where 46 | fromPtr ptr = newForeignPtr ptr (deleteStorage ptr) 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/TensorList.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.TensorList where 12 | 13 | 14 | import qualified Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newTensorList 34 | :: IO (Ptr TensorList) 35 | newTensorList = 36 | [C.throwBlock| std::vector* { return new std::vector( 37 | ); 38 | }|] 39 | 40 | 41 | 42 | deleteTensorList :: Ptr TensorList -> IO () 43 | deleteTensorList object = [C.throwBlock| void { delete $(std::vector* object);}|] 44 | 45 | instance CppObject TensorList where 46 | fromPtr ptr = newForeignPtr ptr (deleteTensorList ptr) 47 | 48 | 49 | 50 | tensorList_empty 51 | :: Ptr TensorList 52 | -> IO (CBool) 53 | tensorList_empty _obj = 54 | [C.throwBlock| bool { return (*$(std::vector* _obj)).empty( 55 | ); 56 | }|] 57 | 58 | tensorList_size 59 | :: Ptr TensorList 60 | -> IO (CSize) 61 | tensorList_size _obj = 62 | [C.throwBlock| size_t { return (*$(std::vector* _obj)).size( 63 | ); 64 | }|] 65 | 66 | tensorList_at_s 67 | :: Ptr TensorList 68 | -> CSize 69 | -> IO (Ptr Tensor) 70 | tensorList_at_s _obj _s = 71 | [C.throwBlock| at::Tensor* { return new at::Tensor((*$(std::vector* _obj)).at( 72 | $(size_t _s))); 73 | }|] 74 | 75 | tensorList_push_back_t 76 | :: Ptr TensorList 77 | -> Ptr Tensor 78 | -> IO (()) 79 | tensorList_push_back_t _obj _v = 80 | [C.throwBlock| void { (*$(std::vector* _obj)).push_back( 81 | *$(at::Tensor* _v)); 82 | }|] 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /ffi/src/ATen/Unmanaged/Type/TensorOptions.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PolyKinds #-} 4 | {-# LANGUAGE TemplateHaskell #-} 5 | {-# LANGUAGE QuasiQuotes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | 11 | module ATen.Unmanaged.Type.TensorOptions where 12 | 13 | 14 | import qualified 
Language.C.Inline.Cpp as C 15 | import qualified Language.C.Inline.Cpp.Exceptions as C 16 | import qualified Language.C.Inline.Context as C 17 | import qualified Language.C.Types as C 18 | import qualified Data.Map as Map 19 | import Foreign.C.String 20 | import Foreign.C.Types 21 | import Foreign hiding (newForeignPtr) 22 | import Foreign.Concurrent 23 | import ATen.Type 24 | import ATen.Class 25 | 26 | C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } 27 | 28 | C.include "" 29 | C.include "" 30 | 31 | 32 | 33 | newTensorOptions_s 34 | :: ScalarType 35 | -> IO (Ptr TensorOptions) 36 | newTensorOptions_s _d = 37 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions( 38 | $(at::ScalarType _d)); 39 | }|] 40 | 41 | 42 | 43 | deleteTensorOptions :: Ptr TensorOptions -> IO () 44 | deleteTensorOptions object = [C.throwBlock| void { delete $(at::TensorOptions* object);}|] 45 | 46 | instance CppObject TensorOptions where 47 | fromPtr ptr = newForeignPtr ptr (deleteTensorOptions ptr) 48 | 49 | 50 | 51 | tensorOptions_device_D 52 | :: Ptr TensorOptions 53 | -> DeviceType 54 | -> IO (Ptr TensorOptions) 55 | tensorOptions_device_D _obj _device = 56 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions((*$(at::TensorOptions* _obj)).device( 57 | $(at::DeviceType _device))); 58 | }|] 59 | 60 | tensorOptions_device_index_s 61 | :: Ptr TensorOptions 62 | -> Int16 63 | -> IO (Ptr TensorOptions) 64 | tensorOptions_device_index_s _obj _device_index = 65 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions((*$(at::TensorOptions* _obj)).device_index( 66 | $(int16_t _device_index))); 67 | }|] 68 | 69 | tensorOptions_dtype_s 70 | :: Ptr TensorOptions 71 | -> ScalarType 72 | -> IO (Ptr TensorOptions) 73 | tensorOptions_dtype_s _obj _dtype = 74 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions((*$(at::TensorOptions* _obj)).dtype( 75 | $(at::ScalarType _dtype))); 76 | }|] 77 | 78 | tensorOptions_dtype 79 | :: Ptr TensorOptions 80 | -> IO (Ptr TensorOptions) 81 | tensorOptions_dtype _obj = 82 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions((*$(at::TensorOptions* _obj)).dtype( 83 | )); 84 | }|] 85 | 86 | tensorOptions_layout_L 87 | :: Ptr TensorOptions 88 | -> Layout 89 | -> IO (Ptr TensorOptions) 90 | tensorOptions_layout_L _obj _layout = 91 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions((*$(at::TensorOptions* _obj)).layout( 92 | $(at::Layout _layout))); 93 | }|] 94 | 95 | tensorOptions_requires_grad_b 96 | :: Ptr TensorOptions 97 | -> CBool 98 | -> IO (Ptr TensorOptions) 99 | tensorOptions_requires_grad_b _obj _requires_grad = 100 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions((*$(at::TensorOptions* _obj)).requires_grad( 101 | $(bool _requires_grad))); 102 | }|] 103 | 104 | tensorOptions_is_variable_b 105 | :: Ptr TensorOptions 106 | -> CBool 107 | -> IO (Ptr TensorOptions) 108 | tensorOptions_is_variable_b _obj _is_variable = 109 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions((*$(at::TensorOptions* _obj)).is_variable( 110 | $(bool _is_variable))); 111 | }|] 112 | 113 | tensorOptions_has_device 114 | :: Ptr TensorOptions 115 | -> IO (CBool) 116 | tensorOptions_has_device _obj = 117 | [C.throwBlock| bool { return (*$(at::TensorOptions* _obj)).has_device( 118 | ); 119 | }|] 120 | 121 | tensorOptions_device_index 122 | :: Ptr TensorOptions 123 | -> IO (Int32) 124 | tensorOptions_device_index _obj = 125 | [C.throwBlock| int32_t { return (*$(at::TensorOptions* 
_obj)).device_index( 126 | ); 127 | }|] 128 | 129 | tensorOptions_has_dtype 130 | :: Ptr TensorOptions 131 | -> IO (CBool) 132 | tensorOptions_has_dtype _obj = 133 | [C.throwBlock| bool { return (*$(at::TensorOptions* _obj)).has_dtype( 134 | ); 135 | }|] 136 | 137 | tensorOptions_layout 138 | :: Ptr TensorOptions 139 | -> IO (Layout) 140 | tensorOptions_layout _obj = 141 | [C.throwBlock| at::Layout { return (*$(at::TensorOptions* _obj)).layout( 142 | ); 143 | }|] 144 | 145 | tensorOptions_has_layout 146 | :: Ptr TensorOptions 147 | -> IO (CBool) 148 | tensorOptions_has_layout _obj = 149 | [C.throwBlock| bool { return (*$(at::TensorOptions* _obj)).has_layout( 150 | ); 151 | }|] 152 | 153 | tensorOptions_requires_grad 154 | :: Ptr TensorOptions 155 | -> IO (CBool) 156 | tensorOptions_requires_grad _obj = 157 | [C.throwBlock| bool { return (*$(at::TensorOptions* _obj)).requires_grad( 158 | ); 159 | }|] 160 | 161 | tensorOptions_has_requires_grad 162 | :: Ptr TensorOptions 163 | -> IO (CBool) 164 | tensorOptions_has_requires_grad _obj = 165 | [C.throwBlock| bool { return (*$(at::TensorOptions* _obj)).has_requires_grad( 166 | ); 167 | }|] 168 | 169 | tensorOptions_is_variable 170 | :: Ptr TensorOptions 171 | -> IO (CBool) 172 | tensorOptions_is_variable _obj = 173 | [C.throwBlock| bool { return (*$(at::TensorOptions* _obj)).is_variable( 174 | ); 175 | }|] 176 | 177 | tensorOptions_has_is_variable 178 | :: Ptr TensorOptions 179 | -> IO (CBool) 180 | tensorOptions_has_is_variable _obj = 181 | [C.throwBlock| bool { return (*$(at::TensorOptions* _obj)).has_is_variable( 182 | ); 183 | }|] 184 | 185 | tensorOptions_backend 186 | :: Ptr TensorOptions 187 | -> IO (Backend) 188 | tensorOptions_backend _obj = 189 | [C.throwBlock| at::Backend { return (*$(at::TensorOptions* _obj)).backend( 190 | ); 191 | }|] 192 | 193 | 194 | 195 | dtype_s 196 | :: ScalarType 197 | -> IO (Ptr TensorOptions) 198 | dtype_s _dtype = 199 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions(at::dtype( 200 | $(at::ScalarType _dtype))); 201 | }|] 202 | 203 | layout_L 204 | :: Layout 205 | -> IO (Ptr TensorOptions) 206 | layout_L _layout = 207 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions(at::layout( 208 | $(at::Layout _layout))); 209 | }|] 210 | 211 | device_D 212 | :: DeviceType 213 | -> IO (Ptr TensorOptions) 214 | device_D _device = 215 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions(at::device( 216 | $(at::DeviceType _device))); 217 | }|] 218 | 219 | device_index_s 220 | :: Int16 221 | -> IO (Ptr TensorOptions) 222 | device_index_s _device_index = 223 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions(at::device_index( 224 | $(int16_t _device_index))); 225 | }|] 226 | 227 | requires_grad_b 228 | :: CBool 229 | -> IO (Ptr TensorOptions) 230 | requires_grad_b _requires_grad = 231 | [C.throwBlock| at::TensorOptions* { return new at::TensorOptions(at::requires_grad( 232 | $(bool _requires_grad))); 233 | }|] 234 | 235 | -------------------------------------------------------------------------------- /ffi/src/Torch/Managed/Autograd.hs: -------------------------------------------------------------------------------- 1 | 2 | module Torch.Managed.Autograd where 3 | 4 | import Foreign.ForeignPtr 5 | 6 | import qualified Torch.Unmanaged.Autograd as Unmanaged 7 | import qualified ATen.Unmanaged.Type.Tensor 8 | import qualified ATen.Unmanaged.Type.TensorList 9 | import ATen.Type 10 | import ATen.Class 11 | import ATen.Cast 12 | 13 | 14 | grad :: ForeignPtr Tensor -> 
ForeignPtr TensorList -> IO (ForeignPtr TensorList)
grad = cast2 Unmanaged.grad


makeIndependent :: ForeignPtr Tensor -> IO (ForeignPtr Tensor)
makeIndependent = cast1 Unmanaged.makeIndependent

dropVariable :: ForeignPtr Tensor -> IO (ForeignPtr Tensor)
dropVariable = cast1 Unmanaged.dropVariable
-------------------------------------------------------------------------------- /ffi/src/Torch/Unmanaged/Autograd.hs: --------------------------------------------------------------------------------
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}

module Torch.Unmanaged.Autograd where

import Foreign.Ptr
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Cpp.Exceptions as C
import qualified Language.C.Inline.Context as C
import qualified Language.C.Types as C

import ATen.Type

C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable }

C.include ""
C.include ""
C.include ""
C.include ""
C.include ""

grad :: Ptr Tensor -> Ptr TensorList -> IO (Ptr TensorList)
grad y inputs = [C.throwBlock| std::vector* {
    torch::autograd::Variable y = *$(at::Tensor* y);
    const auto & inputs = *$(std::vector* inputs);

    torch::autograd::edge_list roots { y.gradient_edge() };
    if (!roots[0].function) {
      throw std::runtime_error("Differentiated tensor does not require grad");
    }

    if (y.numel() != 1) {
      throw std::runtime_error("Differentiated tensor has more than a single element");
    }
    torch::autograd::variable_list grads { torch::ones_like(y) };

    torch::autograd::edge_list output_edges;
    output_edges.reserve(inputs.size());
    for (torch::autograd::Variable input : inputs) {
      const auto output_nr = input.output_nr();
      auto grad_fn = input.grad_fn();
      if (!grad_fn) {
        grad_fn = input.try_get_grad_accumulator();
      }
      if (!input.requires_grad()) {
        throw std::runtime_error("One of the differentiated Tensors does not require grad");
      }
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }

    auto & engine = torch::autograd::Engine::get_default_engine();
    auto outputs = engine.execute(roots, grads,
                                  /*keep_graph=*/true,
                                  /*create_graph=*/false,
                                  output_edges);

    return new std::vector(at::fmap(outputs));
  }|]

makeIndependent :: Ptr Tensor -> IO (Ptr Tensor)
makeIndependent t = [C.throwBlock| at::Tensor* {
    return new at::Tensor($(at::Tensor* t)->detach().set_requires_grad(true));
  }|]

dropVariable :: Ptr Tensor -> IO (Ptr Tensor)
dropVariable t = [C.throwBlock| at::Tensor* {
    auto ret = $(at::Tensor* t)->detach();
    ret.unsafeGetTensorImpl()->set_autograd_meta(nullptr);
    return new at::Tensor(ret);
  }|]
-------------------------------------------------------------------------------- /ffi/test/BackwardSpec.hs: --------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}

-- This test mirrors deps/pytorch/aten/src/ATen/test/basic.cpp

module BackwardSpec (spec) where

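-- Illustrative aside (not generated code): Torch.Unmanaged.Autograd.grad
-- above demands a single-element output with a live graph, seeds the
-- engine with ones_like(y), and returns one gradient per input, in input
-- order. A minimal sketch over the managed wrapper, assuming
-- Torch.Managed.Autograd is imported and using the tensorList helper
-- defined below:
--
-- > gradOf :: ForeignPtr Tensor -> [ForeignPtr Tensor] -> IO [ForeignPtr Tensor]
-- > gradOf y xs = do
-- >   inputs <- tensorList xs
-- >   outs   <- Torch.Managed.Autograd.grad y inputs
-- >   n      <- tensorList_size outs
-- >   mapM (tensorList_at_s outs) (take (fromIntegral n) [0 ..])
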
import Test.Hspec 10 | import Control.Exception.Safe 11 | import qualified Language.C.Inline.Cpp.Exceptions as C 12 | import Control.Monad (forM_,forM,join) 13 | import Data.Int 14 | import Foreign 15 | import ATen.Const 16 | import ATen.Type 17 | import ATen.Class 18 | import ATen.Managed.Type.TensorOptions 19 | import ATen.Managed.Type.Tensor 20 | import ATen.Managed.Type.TensorList 21 | import ATen.Managed.Type.Extra 22 | import ATen.Managed.Type.IntArray 23 | import ATen.Managed.Type.Scalar 24 | import ATen.Managed.Type.Tuple 25 | import qualified ATen.Managed.Native as A 26 | import Torch.Managed.Native 27 | 28 | intArray :: [Int64] -> IO (ForeignPtr IntArray) 29 | intArray dims = do 30 | ary <- newIntArray 31 | forM_ dims $ intArray_push_back_l ary 32 | return ary 33 | 34 | tensorList :: [ForeignPtr Tensor] -> IO (ForeignPtr TensorList) 35 | tensorList dims = do 36 | ary <- newTensorList 37 | forM_ dims $ tensorList_push_back_t ary 38 | return ary 39 | 40 | ap1 fn a0 = join $ fn <$> a0 41 | ap2 fn a0 a1 = join $ fn <$> a0 <*> a1 42 | ap3 fn a0 a1 a2 = join $ fn <$> a0 <*> a1 <*> a2 43 | ap4 fn a0 a1 a2 a3 = join $ fn <$> a0 <*> a1 <*> a2 <*> a3 44 | 45 | at1 tensor i0 = tensor__at__l tensor i0 46 | at2 tensor i0 i1 = ap2 tensor__at__l (at1 tensor i0) (pure i1) 47 | at3 tensor i0 i1 i2 = ap2 tensor__at__l (at2 tensor i0 i1) (pure i2) 48 | 49 | new' fn dsize dtype = ap2 fn (intArray dsize) (options kCPU dtype) 50 | add' a b = join $ A.add_tts <$> pure a <*> pure b <*> newScalar_d 1 51 | addM' a b = join $ A.add_tts <$> a <*> b <*> newScalar_d 1 52 | add_s' a b = join $ A.add_tss <$> pure a <*> pure b <*> newScalar_d 1 53 | addM_s' a b = join $ A.add_tss <$> a <*> b <*> newScalar_d 1 54 | 55 | 56 | options :: DeviceType -> ScalarType -> IO (ForeignPtr TensorOptions) 57 | options dtype stype = ap2 tensorOptions_requires_grad_b (ap2 tensorOptions_dtype_s (device_D dtype) (pure stype)) (pure 1) 58 | 59 | spec :: Spec 60 | spec = forM_ [ 61 | (kFloat,"float"), 62 | (kDouble,"double") 63 | ] $ \(dtype,dtype_str) -> describe ("BasicSpec:" <> dtype_str) $ do 64 | -- torch::Tensor a = torch::ones({2, 2}, torch::requires_grad()); 65 | -- torch::Tensor b = torch::randn({2, 2}); 66 | -- auto c = a + b; 67 | -- c.backward(); 68 | -- std::cout << a << std::endl << b << std::endl << c << std::endl; 69 | it "Backward" $ do 70 | a <- new' ones_lo [2,2] dtype 71 | print "--a--" 72 | forM_ [0..1] $ \i -> 73 | forM_ [0..1] $ \j -> 74 | at2 a i j >>= tensor_item_double >>= print 75 | b <- new' randn_lo [2,2] dtype 76 | print "--b--" 77 | forM_ [0..1] $ \i -> 78 | forM_ [0..1] $ \j -> 79 | at2 b i j >>= tensor_item_double >>= print 80 | print "--c--" 81 | c <- add' a b 82 | forM_ [0..1] $ \i -> 83 | forM_ [0..1] $ \j -> 84 | at2 c i j >>= tensor_item_double >>= print 85 | tensor_print c 86 | tensor_backward c 87 | a' <- tensor_grad a 88 | print "--a'--" 89 | forM_ [0..1] $ \i -> 90 | forM_ [0..1] $ \j -> 91 | (at2 a' i j >>= tensor_item_double) `shouldReturn` 1 92 | 93 | 94 | -------------------------------------------------------------------------------- /ffi/test/CudaSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables #-} 2 | {-# LANGUAGE BangPatterns #-} 3 | {-# LANGUAGE DataKinds #-} 4 | module CudaSpec (main, spec) where 5 | 6 | import Test.Hspec 7 | import Control.Exception.Safe (bracket,catch,throwIO) 8 | import Control.Monad (forM_,forM) 9 | import Data.Int 10 | import Foreign 11 | import ATen.Const 12 | import ATen.Type 13 | 
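
-- Illustrative aside (not generated code): the Backward test in
-- BackwardSpec above runs a full reverse-mode pass through the managed
-- API: build `a` with requires_grad set via options, add a random `b`,
-- call tensor_backward on the result `c`, then read the gradient of `a`.
-- Condensed, reusing the new' and add' helpers defined in BackwardSpec:
--
-- > backwardOnce :: IO (ForeignPtr Tensor)
-- > backwardOnce = do
-- >   a <- new' ones_lo [2,2] kFloat
-- >   b <- new' randn_lo [2,2] kFloat
-- >   c <- add' a b
-- >   tensor_backward c
-- >   tensor_grad a
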
import ATen.Managed.Type.TensorOptions 14 | import ATen.Managed.Type.Tensor 15 | import ATen.Managed.Type.IntArray 16 | import ATen.Managed.Type.Context 17 | import ATen.Managed.Native 18 | import ATen.GC 19 | 20 | main :: IO () 21 | main = hspec spec 22 | 23 | spec :: Spec 24 | spec = do 25 | describe "CudaSpec" $ do 26 | it "When CUDA is out of memory, do GC and retry" $ do 27 | flag <- hasCUDA 28 | monitorMemory $ do 29 | forM_ [0..1000] $ \i -> do -- 80MByte x 1000 = 80GByte 30 | dims <- fromList [1000,1000,10] -- 8 byte x 10M = 80MByte 31 | to <- device_D $ if flag == 0 then kCPU else kCUDA 32 | tod <- tensorOptions_dtype_s to kDouble 33 | zeros_lo dims tod 34 | return () 35 | 36 | fromList :: [Int64] -> IO (ForeignPtr IntArray) 37 | fromList dims = do 38 | ary <- newIntArray 39 | forM_ dims $ intArray_push_back_l ary 40 | return ary 41 | -------------------------------------------------------------------------------- /ffi/test/MemorySpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables #-} 2 | {-# LANGUAGE BangPatterns #-} 3 | {-# LANGUAGE DataKinds #-} 4 | module MemorySpec (main, spec) where 5 | 6 | import Test.Hspec 7 | import Control.Exception (bracket) 8 | import Control.Monad (forM_,forM) 9 | import Data.Int 10 | import Foreign 11 | import ATen.Const 12 | import ATen.Type 13 | import ATen.Managed.Type.TensorOptions 14 | import ATen.Managed.Type.Tensor 15 | import ATen.Managed.Type.IntArray 16 | import ATen.Managed.Type.Context 17 | import ATen.Managed.Native 18 | 19 | import System.Mem () 20 | 21 | -- |Confirm that memory is deallocated (works) 22 | main :: IO () 23 | main = hspec spec 24 | 25 | --type SomeDims = IntArray 26 | 27 | spec :: Spec 28 | spec = do 29 | describe "MemorySpec" $ do 30 | it "scenario: memoryTestMinimal" memoryTestMinimal 31 | 32 | 33 | fromList :: [Int64] -> IO (ForeignPtr IntArray) 34 | fromList dims = do 35 | ary <- newIntArray 36 | forM_ dims $ intArray_push_back_l ary 37 | return ary 38 | 39 | newTensor_zeros :: (ForeignPtr IntArray) -> IO (ForeignPtr Tensor) 40 | newTensor_zeros dims = do 41 | flag <- hasCUDA 42 | to <- device_D $ if flag == 0 then kCPU else kCUDA 43 | tod <- tensorOptions_dtype_s to kByte 44 | zeros_lo dims tod 45 | 46 | totalDim :: (ForeignPtr IntArray) -> IO Int64 47 | totalDim dims = do 48 | size <- intArray_size dims 49 | dims' <- forM [0..(size-1)] $ \i -> intArray_at_s dims i 50 | return $ sum dims' 51 | 52 | iterator :: (ForeignPtr IntArray) -> Int -> IO () 53 | iterator = iteratorBracket 54 | 55 | -- |Leaks memory 56 | iteratorAssign :: (ForeignPtr IntArray) -> Int -> IO () 57 | iteratorAssign d niter = do 58 | size <- memSizeGB d 59 | putStrLn $ show size ++ " GB per allocation x " ++ show niter 60 | forM_ [1..niter] $ \iter -> do 61 | putStr ("Iteration : " ++ show iter ++ " / ") 62 | x <- newTensor_zeros d 63 | v <- tensor_dim x 64 | putStr $ "Printing dummy value: " ++ show v ++ "\r" 65 | putStrLn "Done" 66 | 67 | -- |Releases memory on OSX (but not consistently on linux) 68 | iteratorMonadic :: (ForeignPtr IntArray) -> Int -> IO () 69 | iteratorMonadic d niter = do 70 | size <- memSizeGB d 71 | putStrLn $ show size ++ " GB per allocation x " ++ show niter 72 | forM_ [1..niter] $ \iter -> do 73 | putStr ("Iteration : " ++ show iter ++ " / ") 74 | x <- newTensor_zeros d 75 | v <- tensor_dim x 76 | putStr $ "Printing dummy value: " ++ show v ++ "\r" 77 | putStrLn "Done" 78 | 79 | -- |Releases memory 80 | iteratorBracket :: (ForeignPtr IntArray) -> 
Int -> IO ()
iteratorBracket d niter = do
  size <- memSizeGB d
  putStrLn $ show size ++ " GB per allocation x " ++ show niter
  forM_ [1..niter] $ \iter ->
    bracket (pure iter)
            (\iter -> do
               putStr ("Iteration : " ++ show iter ++ " / ")
               x <- newTensor_zeros d
               v <- tensor_dim x
               putStr $ "Printing dummy value: " ++ show v ++ "\r"
            )
            (const (pure ()))
  putStrLn "Done"


-- |Size in GB of a single allocation
memSizeGB :: (ForeignPtr IntArray) -> IO Double
memSizeGB d = do
  td <- totalDim d
  return $ (fromIntegral td * 8) / 1000000000.0

memoryTestLarge :: IO ()
memoryTestLarge = do
  dims <- fromList [200, 200, 200, 200]
  iterator dims 1000000 -- 12.8 GB x 1M = 12M GB

memoryTestSmall :: IO ()
memoryTestSmall = do
  dims <- fromList [100, 100, 100, 7]
  iterator dims 300 -- 50 MB x 300 = 15 GB

memoryTestFast :: IO ()
memoryTestFast = do
  dims <- fromList [50, 50, 50, 5]
  iterator dims 10000 -- 5 MB x 10000 = 50 GB

memoryTestMinimal :: IO ()
memoryTestMinimal = do
  dims <- fromList [50, 50, 50, 5]
  iterator dims 100 -- 5 MB x 100 = 500 MB
-------------------------------------------------------------------------------- /ffi/test/Spec.hs: --------------------------------------------------------------------------------
{-# OPTIONS_GHC -F -pgmF hspec-discover #-}
-------------------------------------------------------------------------------- /hasktorch/hasktorch.cabal: --------------------------------------------------------------------------------
name:                hasktorch
version:             0.2.0.0
synopsis:            initial implementation for hasktorch based on libtorch
-- description:
homepage:            https://github.com/githubuser/ffi-experimental#readme
license:             BSD3
author:              Austin Huang
maintainer:          hasktorch@gmail.com
copyright:           2019 Austin Huang
category:            Codegen
build-type:          Simple
cabal-version:       >=1.10

library
  exposed-modules:     Torch
                       Torch.Tensor
                     , Torch.TensorOptions
                     , Torch.DType
                     , Torch.TensorFactories
                     , Torch.Functions
                     , Torch.Functions.Native
                     , Torch.Autograd
                     , Torch.Static
                     , Torch.NN
                     , Torch.Scalar
                     , Torch.Backend
                     , Torch.Layout
                     , Torch.Cast
  hs-source-dirs:      src
  default-language:    Haskell2010
  ghc-options:         -fplugin GHC.TypeLits.KnownNat.Solver -fconstraint-solver-iterations=0
  build-depends:       base >= 4.7 && < 5
                     , ffi
                     , finite-typelits
                     , ghc-typelits-knownnat
                     , mtl
                     , safe-exceptions
                     , reflection

test-suite spec
  type:                exitcode-stdio-1.0
  hs-source-dirs:      test
  main-is:             Spec.hs
  other-modules:       FactorySpec
                     , FunctionsSpec
                     , GradSpec
                     , SparseSpec
                     , TensorSpec
                     , NNSpec
  default-language:    Haskell2010
  build-depends:       base >= 4.7 && < 5
                     , hasktorch
                     , hspec
                     , hspec-discover
                     , safe-exceptions
                     , QuickCheck
                     , mtl
-------------------------------------------------------------------------------- /hasktorch/src/Torch.hs: --------------------------------------------------------------------------------
module Torch where

import Torch.Tensor
import Torch.TensorOptions
import Torch.TensorFactories
import Torch.Functions
import Torch.Autograd
--------------------------------------------------------------------------------
/hasktorch/src/Torch/Autograd.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances #-} 2 | {-# LANGUAGE MultiParamTypeClasses #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | 5 | module Torch.Autograd where 6 | 7 | import System.IO.Unsafe 8 | import Foreign.ForeignPtr 9 | 10 | import qualified Torch.Managed.Autograd 11 | import qualified ATen.Managed.Type.Tensor as ATen 12 | import qualified ATen.Type as ATen 13 | import ATen.Class 14 | import ATen.Cast 15 | 16 | import Torch.Tensor 17 | 18 | newtype IndependentTensor = IndependentTensor { toDependent :: Tensor } 19 | deriving (Show) 20 | 21 | grad :: Tensor -> [IndependentTensor] -> [Tensor] 22 | grad y inputs = unsafePerformIO $ (cast2 Torch.Managed.Autograd.grad) y (map toDependent inputs) 23 | 24 | requiresGrad :: Tensor -> Bool 25 | requiresGrad t = unsafePerformIO $ (cast1 ATen.tensor_requires_grad) t 26 | 27 | makeIndependent :: Tensor -> IO IndependentTensor 28 | makeIndependent t = (cast1 Torch.Managed.Autograd.makeIndependent) t >>= return . IndependentTensor 29 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/Backend.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeSynonymInstances #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | 5 | module Torch.Backend where 6 | 7 | import ATen.Class (Castable(..)) 8 | import qualified ATen.Const as ATen 9 | import qualified ATen.Type as ATen 10 | 11 | data Backend = CPU | CUDA | HIP | SparseCPU | SparseCUDA | MSNPU | XLA 12 | deriving (Eq, Show) 13 | 14 | instance Castable Backend ATen.Backend where 15 | cast CPU f = f ATen.bCPU 16 | cast CUDA f = f ATen.bCUDA 17 | cast HIP f = f ATen.bHIP 18 | cast SparseCPU f = f ATen.bSparseCPU 19 | cast SparseCUDA f = f ATen.bSparseCUDA 20 | cast MSNPU f = f ATen.bMSNPU 21 | cast XLA f = f ATen.bXLA 22 | 23 | uncast x f 24 | | x == ATen.bCPU = f CPU 25 | | x == ATen.bCUDA = f CUDA 26 | | x == ATen.bHIP = f HIP 27 | | x == ATen.bSparseCPU = f SparseCPU 28 | | x == ATen.bSparseCUDA = f SparseCUDA 29 | | x == ATen.bMSNPU = f MSNPU 30 | | x == ATen.bXLA = f XLA 31 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/Cast.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE TypeFamilies #-} 5 | 6 | module Torch.Cast where 7 | 8 | import Foreign.ForeignPtr 9 | 10 | import ATen.Managed.Type.IntArray 11 | import ATen.Type 12 | import ATen.Class 13 | import ATen.Cast 14 | 15 | -- define useful casts 16 | 17 | instance CppTuple2 (ForeignPtr IntArray) where 18 | type A (ForeignPtr IntArray) = Int 19 | type B (ForeignPtr IntArray) = Int 20 | get0 v = cast1 (flip intArray_at_s 0) v 21 | get1 v = cast1 (flip intArray_at_s 1) v 22 | 23 | instance CppTuple3 (ForeignPtr IntArray) where 24 | type C (ForeignPtr IntArray) = Int 25 | get2 v = cast1 (flip intArray_at_s 2) v 26 | 27 | instance CppTuple4 (ForeignPtr IntArray) where 28 | type D (ForeignPtr IntArray) = Int 29 | get3 v = cast1 (flip intArray_at_s 3) v 30 | 31 | instance CppTuple5 (ForeignPtr IntArray) where 32 | type E (ForeignPtr IntArray) = Int 33 | get4 v = cast1 (flip intArray_at_s 4) v 34 | 35 | instance CppTuple6 (ForeignPtr IntArray) where 36 | type F (ForeignPtr IntArray) = 
Int 37 | get5 v = cast1 (flip intArray_at_s 5) v 38 | 39 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/DType.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeSynonymInstances #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# LANGUAGE AllowAmbiguousTypes #-} 5 | 6 | module Torch.DType where 7 | 8 | import ATen.Class (Castable(..)) 9 | import qualified ATen.Const as ATen 10 | import qualified ATen.Type as ATen 11 | import Data.Int 12 | import Data.Word 13 | import Data.Reflection 14 | 15 | data DType = UInt8 | Int8 | Int16 | Int32 | Int64 | Half | Float | Double 16 | deriving (Eq, Show) 17 | 18 | instance Reifies Word8 DType where 19 | reflect _ = UInt8 20 | 21 | instance Reifies Int8 DType where 22 | reflect _ = Int8 23 | 24 | instance Reifies Int16 DType where 25 | reflect _ = Int16 26 | 27 | instance Reifies Int32 DType where 28 | reflect _ = Int32 29 | 30 | instance Reifies Int DType where 31 | reflect _ = Int64 32 | 33 | instance Reifies Int64 DType where 34 | reflect _ = Int64 35 | 36 | instance Reifies Float DType where 37 | reflect _ = Float 38 | 39 | instance Reifies Double DType where 40 | reflect _ = Double 41 | 42 | instance Castable DType ATen.ScalarType where 43 | cast UInt8 f = f ATen.kByte 44 | cast Int8 f = f ATen.kChar 45 | cast Int16 f = f ATen.kShort 46 | cast Int32 f = f ATen.kInt 47 | cast Int64 f = f ATen.kLong 48 | cast Half f = f ATen.kHalf 49 | cast Float f = f ATen.kFloat 50 | cast Double f = f ATen.kDouble 51 | 52 | uncast x f 53 | | x == ATen.kByte = f UInt8 54 | | x == ATen.kChar = f Int8 55 | | x == ATen.kShort = f Int16 56 | | x == ATen.kInt = f Int32 57 | | x == ATen.kLong = f Int64 58 | | x == ATen.kHalf = f Half 59 | | x == ATen.kFloat = f Float 60 | | x == ATen.kDouble = f Double 61 | 62 | 63 | isIntegral :: DType -> Bool 64 | isIntegral UInt8 = True 65 | isIntegral Int8 = True 66 | isIntegral Int16 = True 67 | isIntegral Int32 = True 68 | isIntegral Int64 = True 69 | isIntegral Half = False 70 | isIntegral Float = False 71 | isIntegral Double = False 72 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/Layout.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeSynonymInstances #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | 5 | module Torch.Layout where 6 | 7 | import ATen.Class (Castable(..)) 8 | import qualified ATen.Const as ATen 9 | import qualified ATen.Type as ATen 10 | 11 | data Layout = Strided | Sparse | Mkldnn 12 | deriving (Eq, Show) 13 | 14 | instance Castable Layout ATen.Layout where 15 | cast Strided f = f ATen.kStrided 16 | cast Sparse f = f ATen.kSparse 17 | cast Mkldnn f = f ATen.kMkldnn 18 | 19 | uncast x f 20 | | x == ATen.kStrided = f Strided 21 | | x == ATen.kSparse = f Sparse 22 | | x == ATen.kMkldnn = f Mkldnn 23 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/NN.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE RecordWildCards #-} 4 | {-# LANGUAGE DefaultSignatures #-} 5 | {-# LANGUAGE TypeSynonymInstances #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE FlexibleInstances #-} 9 | {-# LANGUAGE 
DeriveGeneric #-} 10 | 11 | module Torch.NN where 12 | 13 | import Control.Monad.State.Strict 14 | 15 | import Torch.Autograd 16 | import Torch.Tensor 17 | import Torch.TensorFactories (ones', rand', randn') 18 | import Torch.Functions 19 | import GHC.Generics 20 | 21 | type Parameter = IndependentTensor 22 | type ParamStream a = State [Parameter] a 23 | 24 | nextParameter :: ParamStream Parameter 25 | nextParameter = do 26 | params <- get 27 | case params of 28 | [] -> error "Not enough parameters supplied to replaceParameters" 29 | (p : t) -> do put t; return p 30 | 31 | class Parameterized f where 32 | flattenParameters :: f -> [Parameter] 33 | default flattenParameters :: (Generic f, Parameterized' (Rep f)) => f -> [Parameter] 34 | flattenParameters f = flattenParameters' (from f) 35 | 36 | replaceOwnParameters :: f -> ParamStream f 37 | default replaceOwnParameters :: (Generic f, Parameterized' (Rep f)) => f -> ParamStream f 38 | replaceOwnParameters f = fmap to $ replaceOwnParameters' (from f) 39 | 40 | instance Parameterized Parameter where 41 | flattenParameters x = [x] 42 | replaceOwnParameters _ = nextParameter 43 | 44 | instance Parameterized [Int] where 45 | flattenParameters x = [] 46 | replaceOwnParameters x = return x 47 | 48 | instance Parameterized (Tensor -> Tensor) where 49 | flattenParameters x = [] 50 | replaceOwnParameters x = return x 51 | 52 | class Parameterized' f where 53 | flattenParameters' :: f a -> [Parameter] 54 | replaceOwnParameters' :: f a -> ParamStream (f a) 55 | 56 | instance Parameterized' U1 where 57 | flattenParameters' U1 = [] 58 | replaceOwnParameters' U1 = return U1 59 | 60 | instance (Parameterized' f, Parameterized' g) => Parameterized' (f :+: g) where 61 | flattenParameters' (L1 x) = flattenParameters' x 62 | flattenParameters' (R1 x) = flattenParameters' x 63 | replaceOwnParameters' (L1 x) = do 64 | x' <- replaceOwnParameters' x 65 | return $ L1 x' 66 | replaceOwnParameters' (R1 x) = do 67 | x' <- replaceOwnParameters' x 68 | return $ R1 x' 69 | 70 | instance (Parameterized' f, Parameterized' g) => Parameterized' (f :*: g) where 71 | flattenParameters' (x :*: y) = flattenParameters' x ++ flattenParameters' y 72 | replaceOwnParameters' (x :*: y) = do 73 | x' <- replaceOwnParameters' x 74 | y' <- replaceOwnParameters' y 75 | return $ x' :*: y' 76 | 77 | instance (Parameterized c) => Parameterized' (K1 i c) where 78 | flattenParameters' (K1 x) = flattenParameters x 79 | replaceOwnParameters' (K1 x) = do 80 | x' <- replaceOwnParameters x 81 | return $ K1 x' 82 | 83 | instance (Parameterized' f) => Parameterized' (M1 i t f) where 84 | flattenParameters' (M1 x) = flattenParameters' x 85 | replaceOwnParameters' (M1 x) = do 86 | x' <- replaceOwnParameters' x 87 | return $ M1 x' 88 | 89 | replaceParameters :: Parameterized f => f -> [Parameter] -> f 90 | replaceParameters f params = 91 | let (f', remaining) = runState (replaceOwnParameters f) params in 92 | if null remaining 93 | then f' 94 | else error "Some parameters in a call to replaceParameters haven't been consumed!" 
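
-- A usage sketch (hypothetical helper, not defined in this module): a
-- gradient-descent step over any 'Parameterized' value can flatten the
-- parameters, update them with 'sgd' (defined at the end of this module),
-- and thread the results back in with 'replaceParameters':
--
-- > gradientStep :: Parameterized f => Tensor -> f -> [Tensor] -> IO f
-- > gradientStep lr model gradients = do
-- >   params' <- mapM makeIndependent (sgd lr (flattenParameters model) gradients)
-- >   return (replaceParameters model params')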
95 | 96 | class Randomizable spec f | spec -> f where 97 | sample :: spec -> IO f 98 | 99 | class (Randomizable spec f, Parameterized f) => Module spec f 100 | 101 | data LinearSpec = LinearSpec { in_features :: Int, out_features :: Int } 102 | deriving (Show, Eq) 103 | 104 | data Linear = Linear { weight :: Parameter, bias :: Parameter } deriving (Show, Generic) 105 | 106 | instance Randomizable LinearSpec Linear where 107 | sample LinearSpec{..} = do 108 | w <- makeIndependent =<< randn' [in_features, out_features] 109 | b <- makeIndependent =<< randn' [out_features] 110 | return $ Linear w b 111 | 112 | instance Parameterized Linear 113 | -- This instance generates the following code. 114 | -- 115 | --------------------------------------------------- 116 | -- instance Parameterized Linear where 117 | -- flattenParameters Linear{..} = [weight, bias] 118 | -- replaceOwnParameters _ = do 119 | -- weight <- nextParameter 120 | -- bias <- nextParameter 121 | -- return $ Linear{..} 122 | 123 | instance Parameterized [Linear] 124 | 125 | sgd :: Tensor -> [Parameter] -> [Tensor] -> [Tensor] 126 | sgd lr parameters gradients = zipWith step depParameters gradients 127 | where 128 | step p dp = p - (lr * dp) 129 | depParameters = (map toDependent parameters) 130 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/Scalar.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | 5 | module Torch.Scalar where 6 | 7 | import Foreign.ForeignPtr 8 | 9 | import qualified ATen.Const as ATen 10 | import qualified ATen.Managed.Type.Scalar as ATen 11 | import qualified ATen.Type as ATen 12 | import ATen.Managed.Cast 13 | import ATen.Class (Castable(..)) 14 | import ATen.Cast 15 | 16 | instance Castable Float (ForeignPtr ATen.Scalar) where 17 | cast x f = ATen.newScalar_d (realToFrac x) >>= f 18 | uncast x f = undefined 19 | 20 | instance Castable Double (ForeignPtr ATen.Scalar) where 21 | cast x f = ATen.newScalar_d (realToFrac x) >>= f 22 | uncast x f = undefined 23 | 24 | instance Castable Int (ForeignPtr ATen.Scalar) where 25 | cast x f = ATen.newScalar_i (fromIntegral x) >>= f 26 | uncast x f = undefined 27 | 28 | class (Castable a (ForeignPtr ATen.Scalar)) => Scalar a 29 | instance Scalar Float 30 | instance Scalar Double 31 | instance Scalar Int 32 | 33 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/TensorFactories.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | 3 | module Torch.TensorFactories where 4 | 5 | import System.IO.Unsafe 6 | import Foreign.ForeignPtr 7 | 8 | import qualified ATen.Const as ATen 9 | import qualified ATen.Managed.Native as ATen 10 | import qualified ATen.Managed.Type.Tensor as ATen 11 | import qualified ATen.Managed.Type.TensorOptions as ATen 12 | import qualified ATen.Type as ATen 13 | import qualified Torch.Managed.Native as LibTorch 14 | import qualified Torch.Managed.Autograd as LibTorch 15 | import ATen.Managed.Cast 16 | import ATen.Class (Castable(..)) 17 | import ATen.Cast 18 | 19 | import Torch.Tensor 20 | import Torch.TensorOptions 21 | import Torch.Scalar 22 | 23 | -- XXX: We use the torch:: constructors, not the at:: constructors, because 24 | -- otherwise we cannot use libtorch's AD. 
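--
-- As a sketch of what that enables (illustration only, not a test in this
-- repository): a tensor produced by one of the factories below can be made
-- independent and differentiated via Torch.Autograd; the expected result
-- here is [2.0], since d(w*w)/dw = 2w and w = 1:
--
-- > import Torch.Tensor
-- > import Torch.TensorFactories (ones')
-- > import Torch.Functions ()
-- > import Torch.Autograd (grad, makeIndependent, toDependent)
-- >
-- > demoGrad :: IO [Tensor]
-- > demoGrad = do
-- >   w <- makeIndependent (ones' [])
-- >   let y = toDependent w * toDependent w
-- >   return (grad y [w])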
25 | 26 | type FactoryType = ForeignPtr ATen.IntArray 27 | -> ForeignPtr ATen.TensorOptions 28 | -> IO (ForeignPtr ATen.Tensor) 29 | 30 | mkFactory :: FactoryType -> [Int] -> TensorOptions -> IO Tensor 31 | mkFactory aten_impl shape opts = (cast2 aten_impl) shape opts 32 | 33 | mkFactoryUnsafe :: FactoryType -> [Int] -> TensorOptions -> Tensor 34 | mkFactoryUnsafe f shape opts = unsafePerformIO $ mkFactory f shape opts 35 | 36 | mkDefaultFactory :: ([Int] -> TensorOptions -> a) -> [Int] -> a 37 | mkDefaultFactory non_default shape = non_default shape defaultOpts 38 | 39 | -------------------- Factories -------------------- 40 | 41 | ones :: [Int] -> TensorOptions -> Tensor 42 | ones = mkFactoryUnsafe LibTorch.ones_lo 43 | 44 | zeros :: [Int] -> TensorOptions -> Tensor 45 | zeros = mkFactoryUnsafe LibTorch.zeros_lo 46 | 47 | rand :: [Int] -> TensorOptions -> IO Tensor 48 | rand = mkFactory LibTorch.rand_lo 49 | 50 | randn :: [Int] -> TensorOptions -> IO Tensor 51 | randn = mkFactory LibTorch.randn_lo 52 | 53 | randn_like :: Tensor -> IO Tensor 54 | randn_like = cast1 ATen.randn_like_t 55 | 56 | linspace :: (Scalar a, Scalar b) => a -> b -> Int -> TensorOptions -> Tensor 57 | linspace start end steps opts = unsafePerformIO $ (cast4 LibTorch.linspace_sslo) start end steps opts 58 | 59 | logspace :: (Scalar a, Scalar b) => a -> b -> Int -> Double -> TensorOptions -> Tensor 60 | logspace start end steps base opts = unsafePerformIO $ (cast5 LibTorch.logspace_ssldo) start end steps base opts 61 | 62 | -- https://github.com/hasktorch/ffi-experimental/pull/57#discussion_r301062033 63 | -- empty :: [Int] -> TensorOptions -> Tensor 64 | -- empty = mkFactoryUnsafe LibTorch.empty_lo 65 | 66 | eyeSquare :: Int -> TensorOptions -> Tensor 67 | eyeSquare dim opts = unsafePerformIO $ (cast2 LibTorch.eye_lo) dim opts 68 | 69 | eye :: Int -> Int -> TensorOptions -> Tensor 70 | eye nrows ncols opts = unsafePerformIO $ (cast3 LibTorch.eye_llo) nrows ncols opts 71 | 72 | full :: Scalar a => [Int] -> a -> TensorOptions -> Tensor 73 | full shape value opts = unsafePerformIO $ (cast3 LibTorch.full_lso) shape value opts 74 | 75 | sparseCooTensor :: Tensor -> Tensor -> [Int] -> TensorOptions -> Tensor 76 | sparseCooTensor indices values size opts = unsafePerformIO $ (cast4 sparse_coo_tensor_ttlo) indices values size opts 77 | where 78 | sparse_coo_tensor_ttlo indices' values' size' opts' = do 79 | i' <- LibTorch.dropVariable indices' 80 | v' <- LibTorch.dropVariable values' 81 | LibTorch.sparse_coo_tensor_ttlo i' v' size' opts' 82 | 83 | -------------------- Factories with default type -------------------- 84 | 85 | ones' :: [Int] -> Tensor 86 | ones' = mkDefaultFactory ones 87 | 88 | zeros' :: [Int] -> Tensor 89 | zeros' = mkDefaultFactory zeros 90 | 91 | rand' :: [Int] -> IO Tensor 92 | rand' = mkDefaultFactory rand 93 | 94 | randn' :: [Int] -> IO Tensor 95 | randn' = mkDefaultFactory randn 96 | 97 | linspace' :: (Scalar a, Scalar b) => a -> b -> Int -> Tensor 98 | linspace' start end steps = linspace start end steps defaultOpts 99 | 100 | logspace' :: (Scalar a, Scalar b) => a -> b -> Int -> Double -> Tensor 101 | logspace' start end steps base = logspace start end steps base defaultOpts 102 | 103 | eyeSquare' :: Int -> Tensor 104 | eyeSquare' dim = eyeSquare dim defaultOpts 105 | 106 | eye' :: Int -> Int -> Tensor 107 | eye' nrows ncols = eye nrows ncols defaultOpts 108 | 109 | full' :: Scalar a => [Int] -> a -> Tensor 110 | full' shape value = full shape value defaultOpts 111 | 112 | sparseCooTensor' :: Tensor 
-> Tensor -> [Int] -> Tensor 113 | sparseCooTensor' indices values size = sparseCooTensor indices values size defaultOpts 114 | -------------------------------------------------------------------------------- /hasktorch/src/Torch/TensorOptions.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeSynonymInstances #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | 5 | module Torch.TensorOptions where 6 | 7 | import Foreign.ForeignPtr 8 | import System.IO.Unsafe 9 | 10 | import ATen.Cast 11 | import ATen.Class (Castable(..)) 12 | import qualified ATen.Type as ATen 13 | import qualified ATen.Const as ATen 14 | import qualified ATen.Managed.Type.TensorOptions as ATen 15 | 16 | import Torch.DType 17 | import Torch.Layout 18 | 19 | type ATenTensorOptions = ForeignPtr ATen.TensorOptions 20 | 21 | data TensorOptions = TensorOptions ATenTensorOptions 22 | 23 | instance Castable TensorOptions ATenTensorOptions where 24 | cast (TensorOptions aten_opts) f = f aten_opts 25 | uncast aten_opts f = f $ TensorOptions aten_opts 26 | 27 | defaultOpts :: TensorOptions 28 | defaultOpts = TensorOptions $ unsafePerformIO $ ATen.newTensorOptions_s ATen.kFloat 29 | 30 | withDType :: DType -> TensorOptions -> TensorOptions 31 | withDType dtype opts = unsafePerformIO $ (cast2 ATen.tensorOptions_dtype_s) opts dtype 32 | 33 | withLayout :: Layout -> TensorOptions -> TensorOptions 34 | withLayout layout opts = unsafePerformIO $ (cast2 ATen.tensorOptions_layout_L) opts layout 35 | -------------------------------------------------------------------------------- /hasktorch/test/FactorySpec.hs: -------------------------------------------------------------------------------- 1 | module FactorySpec (spec) where 2 | 3 | import Test.Hspec 4 | import Control.Exception.Safe 5 | 6 | import Torch.Tensor 7 | import Torch.DType 8 | import Torch.TensorFactories 9 | import Torch.Functions 10 | import Torch.TensorOptions 11 | 12 | spec :: Spec 13 | spec = do 14 | it "ones factory" $ do 15 | let x = ones' [50] 16 | shape x `shouldBe` [50] 17 | it "zeros factory" $ do 18 | let x = zeros' [50] 19 | shape x `shouldBe` [50] 20 | it "rand factory" $ do 21 | x <- rand' [50] 22 | shape x `shouldBe` [50] 23 | it "randn factory" $ do 24 | x <- randn' [50] 25 | shape x `shouldBe` [50] 26 | it "linspace factory" $ do 27 | let start = 5.0 :: Double 28 | let end = 25.0 :: Double 29 | let x = linspace start end 50 defaultOpts 30 | (toDouble $ select x 0 49) `shouldBe` 25.0 31 | it "logspace factory" $ do 32 | let start = 5.0 :: Double 33 | let end = 25.0 :: Double 34 | let x = logspace start end 50 2.0 defaultOpts 35 | (toDouble $ select x 0 0) `shouldBe` 32.0 36 | it "eyeSquare factory" $ do 37 | let x = eyeSquare' 7 38 | shape x `shouldBe` [7, 7] 39 | (toDouble $ select (select x 0 0) 0 0) `shouldBe` 1.0 40 | (toDouble $ select (select x 0 0) 0 1) `shouldBe` 0.0 41 | it "eye factory" $ do 42 | let x = eye' 7 3 43 | shape x `shouldBe` [7, 3] 44 | (toDouble $ select (select x 0 0) 0 0) `shouldBe` 1.0 45 | (toDouble $ select (select x 0 0) 0 1) `shouldBe` 0.0 46 | it "full factory" $ do 47 | let x = full' [5, 2] (15.0 :: Double) 48 | shape x `shouldBe` [5, 2] 49 | (toDouble $ select (select x 0 0) 0 0) `shouldBe` 15.0 50 | -------------------------------------------------------------------------------- /hasktorch/test/FunctionsSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoMonomorphismRestriction #-} 
2 | 3 | module FunctionsSpec(spec) where 4 | 5 | import Prelude hiding (all, abs, exp, floor, log, min, max) 6 | 7 | import Test.Hspec 8 | import Control.Exception.Safe 9 | 10 | import Torch.Tensor 11 | import Torch.DType 12 | import Torch.TensorFactories 13 | import Torch.Functions 14 | import Torch.TensorOptions 15 | 16 | spec :: Spec 17 | spec = do 18 | it "scales and adds" $ do 19 | let x = 2 * ones' [10] + 3 * ones' [10] 20 | (toDouble $ select x 0 4) `shouldBe` 5.0 21 | it "sumAll" $ do 22 | let x = sumAll (2 * ones' [5]) 23 | toDouble x `shouldBe` 10.0 24 | it "abs" $ do 25 | let x = abs $ (-2) * ones' [5] 26 | (toDouble $ select x 0 0) `shouldBe` 2.0 27 | it "add" $ do 28 | let x = (-2) * ones' [5] 29 | let y = abs x 30 | let z = add x y 31 | (toDouble $ select z 0 0) `shouldBe` 0.0 32 | it "sub" $ do 33 | let x = (-2) * ones' [5] 34 | let y = abs x 35 | let z = sub x y 36 | (toDouble $ select z 0 0) `shouldBe` -4.0 37 | it "ceil" $ do 38 | x <- rand' [5] 39 | let y = ceil x 40 | (toDouble $ select y 0 0) `shouldBe` 1.0 41 | it "floor" $ do 42 | x <- rand' [5] 43 | let y = floor x 44 | (toDouble $ select y 0 0) `shouldBe` 0.0 45 | it "takes the minimum of a linspace" $ do 46 | let x = linspace (5.0 :: Double) (25.0 :: Double) 50 defaultOpts 47 | let m = min x 48 | toDouble m `shouldBe` 5.0 49 | it "takes the maximum of a linspace" $ do 50 | let x = linspace (5.0 :: Double) (25.0 :: Double) 50 defaultOpts 51 | let m = max x 52 | toDouble m `shouldBe` 25.0 53 | it "takes the median of a linspace" $ do 54 | let x = linspace (5.0 :: Double) (10.0 :: Double) 5 defaultOpts 55 | let m = median x 56 | toDouble m `shouldBe` 7.5 57 | it "performs matrix vector multiplication" $ do 58 | let m = 3 * ones' [5, 5] 59 | let v = 2 * ones' [5, 1] 60 | let x = matmul m v 61 | (toDouble $ select x 0 0) `shouldBe` 30.0 62 | it "erf" $ do 63 | let x = erf $ zeros' [4] 64 | (toDouble $ select x 0 0) `shouldBe` 0.0 65 | it "exp" $ do 66 | let x = exp $ zeros' [4] 67 | (toDouble $ select x 0 0) `shouldBe` 1.0 68 | it "log1p" $ do 69 | let x = log1p $ zeros' [4] 70 | (toDouble $ select x 0 0) `shouldBe` 0.0 71 | it "log2" $ do 72 | let x = log2 $ 4 * ones' [4] 73 | (toDouble $ select x 0 0) `shouldBe` 2.0 74 | it "log10" $ do 75 | let x = log10 $ 1000 * ones' [4] 76 | (toDouble $ select x 0 0) `shouldBe` 3.0 77 | it "relu (pos)" $ do 78 | let x = relu $ 5 * ones' [4] 79 | (toDouble $ select x 0 0) `shouldBe` 5.0 80 | it "relu (neg)" $ do 81 | let x = relu $ -5 * ones' [4] 82 | (toDouble $ select x 0 0) `shouldBe` 0.0 83 | it "gels" $ do 84 | let (x,qr) = gels (ones' [5,2]) (ones' [5,3]) 85 | shape x `shouldBe` [5,2] 86 | shape qr `shouldBe` [5,3] 87 | it "diag" $ do 88 | let x = ones' [3] 89 | let y = diag x 2 90 | shape y `shouldBe` [5, 5] 91 | 92 | -- decomposition / solvers 93 | 94 | it "cholesky decomposes" $ do 95 | let x = asTensor ([[4.0, 12.0, -16.0], [12.0, 37.0, -43.0], [-16.0, -43.0, 98.0]] :: [[Double]]) 96 | c = cholesky x Upper 97 | c' = asTensor ([[2.0, 6.0, -8.0], [0.0, 1.0, 5.0], [0.0, 0.0, 3.0]] :: [[Double]]) 98 | all (c ==. 
c') `shouldBe` True 99 | it "inverse of an identity matrix is an identity matrix" $ do 100 | let soln = eq (inverse $ eye' 3 3) (eye' 3 3) 101 | all soln `shouldBe` True 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /hasktorch/test/GradSpec.hs: -------------------------------------------------------------------------------- 1 | module GradSpec (spec) where 2 | 3 | import Test.Hspec 4 | import Control.Exception.Safe 5 | 6 | import Torch.Tensor 7 | import Torch.DType 8 | import Torch.TensorFactories 9 | import Torch.Functions 10 | import Torch.TensorOptions 11 | import Torch.Autograd 12 | 13 | spec :: Spec 14 | spec = do 15 | it "grad with ones" $ do 16 | xi <- makeIndependent $ ones' [] 17 | let x = toDependent xi 18 | y = x * x + 5 * x + 3 19 | fmap toDouble (grad y [xi]) `shouldBe` [7.0] 20 | it "grad with two independent variables" $ do 21 | xi <- makeIndependent $ ones' [] 22 | yi <- makeIndependent $ ones' [] 23 | let x = toDependent xi 24 | y = toDependent yi 25 | z = x * x * y 26 | fmap toDouble (grad z [xi]) `shouldBe` [2.0] 27 | fmap toDouble (grad z [yi]) `shouldBe` [1.0] 28 | -------------------------------------------------------------------------------- /hasktorch/test/NNSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE FunctionalDependencies #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE NoMonomorphismRestriction #-} 5 | 6 | module NNSpec(spec) where 7 | 8 | import Test.Hspec 9 | import Control.Exception.Safe 10 | import Control.Monad.State.Strict 11 | 12 | import Torch.Tensor 13 | import Torch.NN 14 | import GHC.Generics 15 | 16 | spec :: Spec 17 | spec = do 18 | it "create flatten-parameters of Linear" $ do 19 | init <- sample $ LinearSpec { in_features = 3, out_features = 1 } 20 | init2 <- sample $ LinearSpec { in_features = 3, out_features = 1 } 21 | length (flattenParameters init) `shouldBe` 2 22 | length (flattenParameters (fst (flip runState (flattenParameters init2) (replaceOwnParameters init)))) `shouldBe` 2 23 | it "create flatten-parameters of [Linear]" $ do 24 | i0 <- sample $ LinearSpec { in_features = 3, out_features = 1 } 25 | i1 <- sample $ LinearSpec { in_features = 3, out_features = 1 } 26 | i2 <- sample $ LinearSpec { in_features = 3, out_features = 1 } 27 | i3 <- sample $ LinearSpec { in_features = 3, out_features = 1 } 28 | let init = [i0,i1] 29 | init2 = [i2,i3] 30 | length (flattenParameters init) `shouldBe` 4 31 | length (flattenParameters (fst (flip runState (flattenParameters init2) (replaceOwnParameters init)))) `shouldBe` 4 32 | -------------------------------------------------------------------------------- /hasktorch/test/SparseSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoMonomorphismRestriction #-} 2 | 3 | module SparseSpec(spec) where 4 | 5 | import Prelude hiding (abs, exp, floor, log, min, max) 6 | 7 | import Test.Hspec 8 | import Control.Exception.Safe 9 | 10 | import Torch.Tensor 11 | import Torch.DType 12 | import Torch.Layout 13 | import Torch.TensorFactories 14 | import Torch.Functions 15 | import Torch.TensorOptions 16 | 17 | spec :: Spec 18 | spec = do 19 | it "create sparse tensor" $ do 20 | let i = [[0, 1, 1], 21 | [2, 0, 2]] :: [[Int]] 22 | v = [3, 4, 5] :: [Float] 23 | let x = sparseCooTensor' (asTensor i) (asTensor v) [2, 3] 24 | (shape (asTensor i)) `shouldBe` [2,3] 25 | (shape (asTensor v)) `shouldBe` [3] 26 | print (toDense x) 
27 | -- When we call print on a sparse tensor, it throws an exception. 28 | (print x) `shouldThrow` anyException 29 | (asValue (toDense x) :: [[Float]]) `shouldBe` [[0.0,0.0,3.0],[4.0,0.0,5.0]] 30 | (asValue (toDense (x+x)) :: [[Float]]) `shouldBe` [[0.0,0.0,6.0],[8.0,0.0,10.0]] 31 | (asValue (toDense (toSparse (toDense (x+x)))) :: [[Float]]) `shouldBe` [[0.0,0.0,6.0],[8.0,0.0,10.0]] 32 | it "zeros sparse tensor" $ do 33 | let x = zeros [2, 3] $ withLayout Sparse defaultOpts 34 | (print x) `shouldThrow` anyException 35 | (asValue (toDense x) :: [[Float]]) `shouldBe` [[0.0,0.0,0.0],[0.0,0.0,0.0]] 36 | it "large sparse tensor" $ do 37 | let x = zeros [1000,1000,1000] $ withLayout Sparse defaultOpts 38 | shape x `shouldBe` [1000,1000,1000] 39 | -------------------------------------------------------------------------------- /hasktorch/test/Spec.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -F -pgmF hspec-discover #-} 2 | -------------------------------------------------------------------------------- /hasktorch/test/TensorSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables #-} 2 | 3 | module TensorSpec (spec) where 4 | 5 | import Test.Hspec 6 | import Test.QuickCheck 7 | import Control.Exception.Safe 8 | 9 | import Torch.Tensor 10 | import Torch.DType 11 | import Torch.TensorFactories 12 | import Torch.Functions 13 | import Torch.TensorOptions 14 | import Data.Word 15 | import Data.Int 16 | 17 | spec :: Spec 18 | spec = do 19 | it "TensorLike Word8" $ property $ 20 | \x -> asValue (asTensor x) `shouldBe` (x :: Word8) 21 | it "TensorLike Int8" $ property $ 22 | \x -> asValue (asTensor x) `shouldBe` (x :: Int8) 23 | it "TensorLike Int16" $ property $ 24 | \x -> asValue (asTensor x) `shouldBe` (x :: Int16) 25 | it "TensorLike Int32" $ property $ 26 | \x -> asValue (asTensor x) `shouldBe` (x :: Int32) 27 | it "TensorLike Int" $ property $ 28 | \x -> asValue (asTensor x) `shouldBe` (x :: Int) 29 | it "TensorLike Int64" $ property $ 30 | \x -> asValue (asTensor x) `shouldBe` (x :: Int64) 31 | it "TensorLike Float" $ property $ 32 | \x -> asValue (asTensor x) `shouldBe` (x :: Float) 33 | it "TensorLike Double" $ property $ 34 | \x -> asValue (asTensor x) `shouldBe` (x :: Double) 35 | 36 | it "TensorLike [Word8]" $ property $ 37 | \(NonEmpty (x :: [Word8])) -> do 38 | asValue (asTensor x) `shouldBe` x 39 | toDouble (select (asTensor x) 0 0) `shouldBe` fromIntegral (head x) 40 | shape (asTensor x) `shouldBe` [length x] 41 | let xx = replicate 5 x 42 | asValue (asTensor xx) `shouldBe` xx 43 | let xxx = replicate 3 xx 44 | asValue (asTensor xxx) `shouldBe` xxx 45 | it "TensorLike [Int8]" $ property $ 46 | \(NonEmpty (x :: [Int8])) -> do 47 | asValue (asTensor x) `shouldBe` x 48 | toDouble (select (asTensor x) 0 0) `shouldBe` fromIntegral (head x) 49 | shape (asTensor x) `shouldBe` [length x] 50 | let xx = replicate 5 x 51 | asValue (asTensor xx) `shouldBe` xx 52 | let xxx = replicate 3 xx 53 | asValue (asTensor xxx) `shouldBe` xxx 54 | it "TensorLike [Int16]" $ property $ 55 | \(NonEmpty (x :: [Int16])) -> do 56 | asValue (asTensor x) `shouldBe` x 57 | toDouble (select (asTensor x) 0 0) `shouldBe` fromIntegral (head x) 58 | shape (asTensor x) `shouldBe` [length x] 59 | let xx = replicate 5 x 60 | asValue (asTensor xx) `shouldBe` xx 61 | let xxx = replicate 3 xx 62 | asValue (asTensor xxx) `shouldBe` xxx 63 | it "TensorLike [Int32]" $ property $ 64 | \(NonEmpty (x :: 
[Int32])) -> do 65 | asValue (asTensor x) `shouldBe` x 66 | toDouble (select (asTensor x) 0 0) `shouldBe` fromIntegral (head x) 67 | shape (asTensor x) `shouldBe` [length x] 68 | let xx = replicate 5 x 69 | asValue (asTensor xx) `shouldBe` xx 70 | let xxx = replicate 3 xx 71 | asValue (asTensor xxx) `shouldBe` xxx 72 | it "TensorLike [Int]" $ property $ 73 | \(NonEmpty (x :: [Int])) -> do 74 | asValue (asTensor x) `shouldBe` x 75 | toDouble (select (asTensor x) 0 0) `shouldBe` fromIntegral (head x) 76 | shape (asTensor x) `shouldBe` [length x] 77 | let xx = replicate 5 x 78 | asValue (asTensor xx) `shouldBe` xx 79 | let xxx = replicate 3 xx 80 | asValue (asTensor xxx) `shouldBe` xxx 81 | it "TensorLike [Int64]" $ property $ 82 | \(NonEmpty (x :: [Int64])) -> do 83 | asValue (asTensor x) `shouldBe` x 84 | toDouble (select (asTensor x) 0 0) `shouldBe` fromIntegral (head x) 85 | shape (asTensor x) `shouldBe` [length x] 86 | let xx = replicate 5 x 87 | asValue (asTensor xx) `shouldBe` xx 88 | let xxx = replicate 3 xx 89 | asValue (asTensor xxx) `shouldBe` xxx 90 | it "TensorLike [Float]" $ property $ 91 | \(NonEmpty (x :: [Float])) -> do 92 | asValue (asTensor x) `shouldBe` x 93 | toDouble (select (asTensor x) 0 0) `shouldBe` realToFrac (head x) 94 | shape (asTensor x) `shouldBe` [length x] 95 | let xx = replicate 5 x 96 | asValue (asTensor xx) `shouldBe` xx 97 | let xxx = replicate 3 xx 98 | asValue (asTensor xxx) `shouldBe` xxx 99 | it "TensorLike [Double]" $ property $ 100 | \(NonEmpty (x :: [Double])) -> do 101 | asValue (asTensor x) `shouldBe` x 102 | toDouble (select (asTensor x) 0 0) `shouldBe` realToFrac (head x) 103 | shape (asTensor x) `shouldBe` [length x] 104 | let xx = replicate 5 x 105 | asValue (asTensor xx) `shouldBe` xx 106 | let xxx = replicate 3 xx 107 | asValue (asTensor xxx) `shouldBe` xxx 108 | 109 | it "invalid cast of TensorLike a" $ do 110 | let x = asTensor (10 :: Int) 111 | (dtype x) `shouldBe` Int64 112 | (print (asValue x :: Double)) `shouldThrow` anyException 113 | it "invalid cast of TensorLike [a]" $ do 114 | let x = asTensor ([0..10] :: [Int]) 115 | (print (asValue x :: [Double])) `shouldThrow` anyException 116 | 117 | it "lists having different length" $ do 118 | (print (asTensor ([[1],[1,2]] :: [[Double]]))) `shouldThrow` anyException 119 | -------------------------------------------------------------------------------- /setenv: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # execute this command if needed when building on OSX if there are linker errors. 4 | # dylib files in extra-lib-dirs don't get forwarded to ghc 5 | # in some versions of OSX. See https://github.com/commercialhaskell/stack/issues/1826 6 | HASKTORCH_LIB_PATH="$(pwd)/deps/libtorch/lib/:$(pwd)/deps/mklml/lib/" 7 | 8 | function add_vendor_lib_path { 9 | case "$(uname)" in 10 | "Darwin") 11 | DYLD_LIBRARY_PATH=$HASKTORCH_LIB_PATH:$DYLD_LIBRARY_PATH 12 | export DYLD_LIBRARY_PATH 13 | ;; 14 | "Linux"|"FreeBSD") 15 | LD_LIBRARY_PATH=$HASKTORCH_LIB_PATH:$LD_LIBRARY_PATH 16 | export LD_LIBRARY_PATH 17 | ;; 18 | *) 19 | echo "OS doesn't have known environment variable hacks to set" 20 | ;; 21 | esac 22 | } 23 | 24 | if ! type git &> /dev/null; then 25 | echo "git is not installed, setenv cannot reliably perform checks to set your system's library path" 26 | fi 27 | 28 | if [[ "$(basename "$(git rev-parse --show-toplevel)")" == "ffi-experimental" ]] &> /dev/null; then 29 | echo "updating library path..." 
30 | add_vendor_lib_path 31 | echo "...done!" 32 | else 33 | echo "couldn't update library path. Please file an issue or adjust this script for your system and submit a pull request" 34 | fi 35 | -------------------------------------------------------------------------------- /setup-cabal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xe 4 | 5 | curl https://www.stackage.org/lts-13.23/cabal.config | grep -v inline-c > cabal.project.freeze 6 | 7 | cat <<EOF > cabal.project.local 8 | 9 | extra-include-dirs: 10 | $(pwd)/deps/libtorch/include/torch/csrc/api/include 11 | , $(pwd)/deps/libtorch/include 12 | 13 | extra-lib-dirs: 14 | $(pwd)/deps/libtorch/lib 15 | , $(pwd)/deps/mklml/lib 16 | 17 | EOF 18 | -------------------------------------------------------------------------------- /spec/Declarations.yaml: -------------------------------------------------------------------------------- 1 | ../deps/pytorch/build/aten/src/ATen/Declarations.yaml -------------------------------------------------------------------------------- /spec/README.md: -------------------------------------------------------------------------------- 1 | # Note on `native_functions.yaml` 2 | 3 | If for some reason you want to use that file instead of Declarations.yaml, be aware that two of the functions listed there (_dimI, _dimV) are problematic. This is because their dispatch field, which is usually a record, is just a string. Fortunately, those functions are deprecated, so it is OK to filter them out. 4 | -------------------------------------------------------------------------------- /spec/bindings.yaml: -------------------------------------------------------------------------------- 1 | # This file controls functions in hasktorch/src/Torch/Functions/Native.hs. 2 | # These functions are pure functions that do not use the IO monad. 3 | - rename: {src: where_ttt, dst: where'} 4 | - rename: {src: min_t, dst: minAll} 5 | - rename: {src: max_t, dst: maxAll} 6 | - rename: {src: median_t, dst: medianAll} 7 | 8 | # detach needs the IO monad because it creates a new tensor. 
9 | - remove: {src: detach_t} 10 | 11 | # remove already registered functions 12 | - remove: {src: transpose_tll} 13 | - remove: {src: select_tll} 14 | - remove: {src: reshape_tl} 15 | - remove: {src: erf_t} 16 | - remove: {src: floor_t} 17 | - remove: {src: cos_t} 18 | - remove: {src: exp_t} 19 | - remove: {src: sin_t} 20 | - remove: {src: sinh_t} 21 | - remove: {src: sqrt_t} 22 | - remove: {src: tanh_t} 23 | - remove: {src: abs_t} 24 | - remove: {src: log1p_t} 25 | - remove: {src: ceil_t} 26 | - remove: {src: conv2d_tttllll} 27 | - remove: {src: diag_tl} 28 | - remove: {src: gels_tt} 29 | - remove: {src: log10_t} 30 | - remove: {src: log2_t} 31 | - remove: {src: matmul_tt} 32 | - remove: {src: mse_loss_ttl} 33 | - remove: {src: relu_t} 34 | - remove: {src: selu_t} 35 | - remove: {src: sigmoid_t} 36 | - remove: {src: size_tl} 37 | - remove: {src: all_t} 38 | - remove: {src: any_t} 39 | - remove: {src: qr_t} 40 | - remove: {src: geqrf_t} 41 | - remove: {src: orgqr_tt} 42 | - remove: {src: sign_t} 43 | - remove: {src: inverse_t} 44 | - remove: {src: cholesky_tb} 45 | - remove: {src: cholesky_inverse_tb} 46 | - remove: {src: cholesky_solve_ttb} 47 | - remove: {src: eig_tb} 48 | - remove: {src: solve_tt} 49 | - remove: {src: svd_tbb} 50 | - remove: {src: symeig_tbb} 51 | 52 | # remove overloaded functions 53 | - remove: {src: add_tts} 54 | - remove: {src: add_tss} 55 | - remove: {src: all_tlb} 56 | - remove: {src: any_tlb} 57 | - remove: {src: cumsum_tls} 58 | - remove: {src: cumsum_tl} 59 | - remove: {src: cumprod_tls} 60 | - remove: {src: cumprod_tl} 61 | - remove: {src: ctc_loss_ttllllb} 62 | - remove: {src: ctc_loss_ttttllb} 63 | - remove: {src: div_tt} 64 | - remove: {src: div_ts} 65 | - remove: {src: log_softmax_tls} 66 | - remove: {src: log_softmax_tl} 67 | - remove: {src: matrix_rank_tdb} 68 | - remove: {src: matrix_rank_tb} 69 | - remove: {src: max_tlb} 70 | - remove: {src: mean_ts} 71 | - remove: {src: mean_t} 72 | - remove: {src: mean_tlbs} 73 | - remove: {src: mean_tlb} 74 | - remove: {src: mean_tls} 75 | - remove: {src: median_tlb} 76 | - remove: {src: min_tlb} 77 | - remove: {src: mul_tt} 78 | - remove: {src: mul_ts} 79 | - remove: {src: randint_like_tl} 80 | - remove: {src: randint_like_tll} 81 | - remove: {src: repeat_interleave_t} 82 | - remove: {src: repeat_interleave_ttl} 83 | - remove: {src: repeat_interleave_tll} 84 | - remove: {src: softmax_tls} 85 | - remove: {src: softmax_tl} 86 | - remove: {src: squeeze_t} 87 | - remove: {src: squeeze_tl} 88 | - remove: {src: sum_t} 89 | - remove: {src: sum_ts} 90 | - remove: {src: sum_tlbs} 91 | - remove: {src: sum_tlb} 92 | - remove: {src: sum_tls} 93 | - remove: {src: std_tb} 94 | - remove: {src: std_tlbb} 95 | - remove: {src: prod_ts} 96 | - remove: {src: prod_t} 97 | - remove: {src: prod_tlbs} 98 | - remove: {src: prod_tlb} 99 | - remove: {src: prod_tls} 100 | - remove: {src: var_tb} 101 | - remove: {src: var_tlbb} 102 | - remove: {src: norm_tss} 103 | - remove: {src: norm_ts} 104 | - remove: {src: norm_tslbs} 105 | - remove: {src: norm_tslb} 106 | - remove: {src: frobenius_norm_t} 107 | - remove: {src: frobenius_norm_tlb} 108 | - remove: {src: pow_ts} 109 | - remove: {src: sub_tts} 110 | - remove: {src: sub_tss} 111 | - remove: {src: rsub_tts} 112 | - remove: {src: rsub_tss} 113 | - remove: {src: lstm_tllbldbbb} 114 | - remove: {src: lstm_ttllbldbb} 115 | - remove: {src: gru_ttlbldbbb} 116 | - remove: {src: gru_tttlbldbb} 117 | - remove: {src: rnn_tanh_ttlbldbbb} 118 | - remove: {src: rnn_tanh_tttlbldbb} 119 | - remove: {src: 
rnn_relu_ttlbldbbb} 120 | - remove: {src: rnn_relu_tttlbldbb} 121 | - remove: {src: masked_fill_tts} 122 | - remove: {src: masked_fill_ttt} 123 | - remove: {src: index_fill_tlts} 124 | - remove: {src: index_fill_tltt} 125 | - remove: {src: scatter_tltt} 126 | - remove: {src: scatter_tlts} 127 | - remove: {src: ne_ts} 128 | - remove: {src: ne_tt} 129 | - remove: {src: eq_ts} 130 | - remove: {src: eq_tt} 131 | - remove: {src: ge_ts} 132 | - remove: {src: ge_tt} 133 | - remove: {src: le_ts} 134 | - remove: {src: le_tt} 135 | - remove: {src: gt_ts} 136 | - remove: {src: gt_tt} 137 | - remove: {src: lt_ts} 138 | - remove: {src: lt_tt} 139 | - remove: {src: lerp_tts} 140 | - remove: {src: lerp_ttt} 141 | - remove: {src: fmod_ts} 142 | - remove: {src: fmod_tt} 143 | - remove: {src: remainder_ts} 144 | - remove: {src: remainder_tt} 145 | - remove: {src: min_tt} 146 | - remove: {src: max_tt} 147 | - remove: {src: pow_tt} 148 | - remove: {src: pow_st} 149 | -------------------------------------------------------------------------------- /spec/cppclass/array.yaml: -------------------------------------------------------------------------------- 1 | name: std::array 2 | hstype: std::array 3 | functions: [] 4 | constructors: 5 | - new(bool,bool,bool,bool) -> std::array 6 | methods: 7 | - empty() -> bool 8 | - size() -> size_t 9 | - at(size_t) -> bool 10 | -------------------------------------------------------------------------------- /spec/cppclass/context.yaml: -------------------------------------------------------------------------------- 1 | signature: Context 2 | cppname: at::Context 3 | hsname: Context 4 | functions: 5 | - init() -> void 6 | - hasCUDA() -> bool 7 | - hasHIP() -> bool 8 | - hasXLA() -> bool 9 | - getNumGPUs() -> size_t 10 | - hasOpenMP() -> bool 11 | - hasMKL() -> bool 12 | - hasLAPACK() -> bool 13 | - hasMAGMA() -> bool 14 | - hasMKLDNN() -> bool 15 | - manual_seed(uint64_t seed) -> void 16 | constructors: [] 17 | methods: [] 18 | -------------------------------------------------------------------------------- /spec/cppclass/gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cat <<EOF > spec/cppclass/tensor.yaml 4 | class: Tensor 5 | constructors: 6 | - new() -> Tensor 7 | - new(Tensor) -> Tensor 8 | methods: 9 | EOF 10 | 11 | cat deps/libtorch/include/ATen/core/Tensor.h \ 12 | | grep -vi c10 \ 13 | | grep -vi TensorImpl \ 14 | | perl -ne 'BEGIN{$f=0;}; {if (/class Tensor/){$f=1;}; if (/^};/ || /Tensor alias/){print $_;$f=0;}; if($f==1){print $_;}};' \ 15 | | grep '^ [^ /].*(.*)' \ 16 | | sed -e 's/).*/)/g' \ 17 | | sed -e 's/const //g' \ 18 | | sed -e 's/inline //g' \ 19 | | sed -e 's/&//g' \ 20 | | sed -e 's/^ *//g' \ 21 | | sed -e 's/^char \*/char*/g' \ 22 | | sed -e 's/^T \*/T*/g' \ 23 | | sed -e 's/ *\([^ ]*\) \(.*\)$/\2 -> \1/g' \ 24 | | sed -e 's/^ *//g' \ 25 | | sed -e 's/^/- /g' \ 26 | >> spec/cppclass/tensor.yaml 27 | 28 | 29 | cat <<EOF > spec/cppclass/tensor.yaml 30 | class: Tensor 31 | constructors: 32 | - new() -> Tensor 33 | - new(Tensor) -> Tensor 34 | methods: 35 | EOF 36 | 37 | cat deps/libtorch/include/ATen/core/Tensor.h \ 38 | | grep -vi c10 \ 39 | | grep -vi TensorImpl \ 40 | | perl -ne 'BEGIN{$f=0;}; {if (/class Tensor/){$f=1;}; if (/^};/ || /Tensor alias/){print $_;$f=0;}; if($f==1){print $_;}};' \ 41 | | grep '^ [^ /].*(.*)' \ 42 | | sed -e 's/).*/)/g' \ 43 | | sed -e 's/const //g' \ 44 | | sed -e 's/inline //g' \ 45 | | sed -e 's/&//g' \ 46 | | sed -e 's/^ *//g' \ 47 | | sed -e 's/^char \*/char*/g' \ 48 | 
| sed -e 's/^T \*/T*/g' \ 49 | | sed -e 's/ *\([^ ]*\) \(.*\)$/\2 -> \1/g' \ 50 | | sed -e 's/^ *//g' \ 51 | | sed -e 's/^/- /g' \ 52 | >> spec/cppclass/tensor.yaml 53 | -------------------------------------------------------------------------------- /spec/cppclass/generator.yaml: -------------------------------------------------------------------------------- 1 | signature: Generator 2 | cppname: at::Generator 3 | hsname: Generator 4 | functions: [] 5 | constructors: [] 6 | methods: [] 7 | -------------------------------------------------------------------------------- /spec/cppclass/intarray.yaml: -------------------------------------------------------------------------------- 1 | signature: IntArray 2 | cppname: std::vector 3 | hsname: IntArray 4 | functions: [] 5 | constructors: 6 | - new() -> IntArray 7 | methods: 8 | - empty() -> bool 9 | - size() -> size_t 10 | - at(size_t s) -> int64_t 11 | - push_back(int64_t v) -> void 12 | -------------------------------------------------------------------------------- /spec/cppclass/scalar.yaml: -------------------------------------------------------------------------------- 1 | signature: Scalar 2 | cppname: at::Scalar 3 | hsname: Scalar 4 | functions: [] 5 | constructors: 6 | - new() -> Scalar 7 | - new(int a) -> Scalar 8 | - new(double a) -> Scalar 9 | methods: [] 10 | -------------------------------------------------------------------------------- /spec/cppclass/sparsetensorref.yaml: -------------------------------------------------------------------------------- 1 | signature: SparseTensorRef 2 | cppname: at::SparseTensorRef 3 | hsname: SparseTensorRef 4 | functions: [] 5 | constructors: 6 | - new(Tensor x) -> SparseTensorRef 7 | methods: [] 8 | -------------------------------------------------------------------------------- /spec/cppclass/storage.yaml: -------------------------------------------------------------------------------- 1 | signature: Storage 2 | cppname: at::Storage 3 | hsname: Storage 4 | functions: [] 5 | constructors: 6 | - new() -> Storage 7 | methods: [] 8 | -------------------------------------------------------------------------------- /spec/cppclass/tensorlist.yaml: -------------------------------------------------------------------------------- 1 | signature: TensorList 2 | cppname: std::vector 3 | hsname: TensorList 4 | functions: [] 5 | constructors: 6 | - new() -> TensorList 7 | methods: 8 | - empty() -> bool 9 | - size() -> size_t 10 | - at(size_t s) -> Tensor 11 | - push_back(Tensor v) -> void 12 | -------------------------------------------------------------------------------- /spec/cppclass/tensoroptions.yaml: -------------------------------------------------------------------------------- 1 | signature: TensorOptions 2 | cppname: at::TensorOptions 3 | hsname: TensorOptions 4 | functions: 5 | - dtype(ScalarType dtype) -> TensorOptions 6 | - layout(Layout layout) -> TensorOptions 7 | - device(Device device) -> TensorOptions 8 | - device_index(int16_t device_index) -> TensorOptions 9 | - requires_grad(bool requires_grad) -> TensorOptions 10 | constructors: 11 | #- new(Device d) -> TensorOptions 12 | #- new(Backend d) -> TensorOptions 13 | - new(ScalarType d) -> TensorOptions 14 | #- new(Layout d) -> TensorOptions 15 | methods: 16 | #- operator==(TensorOptions other) -> bool 17 | #- operator!=(TensorOptions other) -> bool 18 | #- device(c10::optional device) -> TensorOptions 19 | - device(Device device) -> TensorOptions 20 | #- device(Args... 
args) -> TensorOptions 21 | - device_index(int16_t device_index) -> TensorOptions 22 | #- dtype(c10::optional dtype) -> TensorOptions 23 | #- dtype(c10::optional dtype) -> TensorOptions 24 | - dtype(ScalarType dtype) -> TensorOptions 25 | - dtype() -> TensorOptions 26 | #- layout(c10::optional layout) -> TensorOptions 27 | - layout(Layout layout) -> TensorOptions 28 | #- requires_grad(c10::optional requires_grad) -> TensorOptions 29 | - requires_grad(bool requires_grad) -> TensorOptions 30 | #- is_variable(c10::optional is_variable) -> TensorOptions 31 | - is_variable(bool is_variable) -> TensorOptions 32 | #- device() -> Device 33 | - has_device() -> bool 34 | #- device_opt() -> c10::optional 35 | - device_index() -> int32_t 36 | #- dtype() -> caffe2::TypeMeta 37 | - has_dtype() -> bool 38 | #- dtype_opt() -> c10::optional 39 | - layout() -> Layout 40 | - has_layout() -> bool 41 | #- layout_opt() -> c10::optional 42 | - requires_grad() -> bool 43 | - has_requires_grad() -> bool 44 | #- requires_grad_opt() -> c10::optional 45 | - is_variable() -> bool 46 | - has_is_variable() -> bool 47 | #- is_variable_opt() -> c10::optional 48 | - backend() -> Backend 49 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: lts-13.1 2 | 3 | packages: 4 | - codegen 5 | - ffi 6 | - inline-c/inline-c 7 | - inline-c/inline-c-cpp 8 | - hasktorch 9 | - examples 10 | 11 | extra-include-dirs: 12 | - deps/libtorch/include/torch/csrc/api/include 13 | - deps/libtorch/include 14 | 15 | extra-lib-dirs: 16 | - deps/libtorch/lib 17 | - deps/mklml/lib 18 | 19 | # see https://github.com/commercialhaskell/stack/issues/4073 20 | # with-gcc: /usr/local/bin/gcc-7 21 | 22 | extra-deps: 23 | - template-0.2.0.10@sha256:f822de4d34c45bc84b33a61bc112c15fedee6fa6dc414c62b10456395a868f85 24 | --------------------------------------------------------------------------------