├── dataframe ├── README.md ├── Setup.hs ├── src │ ├── Barbies │ │ ├── TH │ │ │ └── Config.hs │ │ └── FieldName.hs │ └── Generics │ │ └── SOP │ │ └── Record │ │ ├── Combination.hs │ │ └── SubTyping.hs ├── package.yaml └── dataframe.cabal ├── default-type-plugin ├── README.md ├── Setup.hs ├── default-type-plugin.cabal ├── LICENSE └── src │ └── Plugin │ └── DefaultType.hs ├── haskell-torch-models ├── README.md ├── Setup.hs ├── package.yaml ├── haskell-torch-models.cabal └── src │ └── Torch │ └── Models │ └── Vision │ └── AlexNet.hs ├── haskell-torch-datasets ├── README.md ├── Setup.hs ├── src │ └── Torch │ │ ├── Datasets.hs │ │ └── Datasets │ │ └── Vision │ │ └── CIFAR.hs ├── package.yaml └── haskell-torch-datasets.cabal ├── logo.png ├── shell-with-jupyter.nix ├── simplify-nat-algebra-plugin ├── README.md ├── tests │ └── Main.hs ├── LICENSE └── simplify-nat-algebra-plugin.cabal ├── interpolateIO ├── test │ ├── Spec.hs │ ├── doctests.hs │ └── Data │ │ └── String │ │ ├── InterpolateIO │ │ ├── IsStringSpec.hs │ │ ├── ParseSpec.hs │ │ ├── UtilSpec.hs │ │ └── Internal │ │ │ └── UtilSpec.hs │ │ └── InterpolateIOSpec.hs ├── .ghci ├── Setup.lhs ├── src │ └── Data │ │ └── String │ │ ├── InterpolateIO │ │ ├── Parse.hs │ │ ├── IsString.hs │ │ ├── Compat.hs │ │ └── Util.hs │ │ ├── InterpolateIO.hs │ │ └── ShowIO.hs ├── package.yaml ├── LICENSE ├── README.markdown └── interpolateIO.cabal ├── .ghci ├── haskell-torch-tools ├── Setup.hs ├── README.md ├── package.yaml └── haskell-torch-tools.cabal ├── logo-with-text.png ├── haskell-torch ├── test │ ├── Doctest.hs │ ├── doctest.json │ └── Spec.hs ├── Setup.hs ├── README.md ├── src │ └── Torch.hs ├── package.yaml └── haskell-torch.cabal ├── haskell-torch-matio ├── Setup.hs ├── README.md ├── package.yaml └── haskell-torch-matio.cabal ├── haskell-torch-cbindings ├── Setup.hs ├── README.md ├── src │ └── Torch │ │ └── C │ │ ├── Language.hs │ │ ├── CUDA.hs │ │ ├── Generator.hs │ │ ├── Scalar.hs │ │ └── Types.hs ├── package.yaml └── 
haskell-torch-cbindings.cabal ├── haskell-torch-imagemagick ├── Setup.hs ├── README.md ├── package.yaml └── haskell-torch-imagemagick.cabal ├── haskell-torch-examples ├── Setup.hs ├── README.md ├── src │ └── Torch │ │ └── Tutorial │ │ ├── RL │ │ └── Simple.hs │ │ ├── README.md │ │ ├── Intro │ │ ├── T02_LinearRegression.hs │ │ ├── T03_LogisticRegression.hs │ │ ├── T04_FeedforwardNN.hs │ │ ├── T07_RNN.hs │ │ ├── T08_BiRNN.hs │ │ ├── T05_CNN.hs │ │ └── T11_VAE.hs │ │ └── Tensorboard.hs ├── package.yaml └── haskell-torch-examples.cabal ├── ihaskell-matplotlib ├── Setup.hs ├── IHaskell │ └── Display │ │ └── Matplotlib.hs ├── package.yaml └── ihaskell-matplotlib.cabal ├── haskell-torch-tensorboard-proto ├── Setup.hs ├── README.md ├── src │ └── Tensorboard │ │ └── Proto │ │ ├── Event.hs │ │ ├── Attributes.hs │ │ ├── Tensor.hs │ │ ├── Summary.hs │ │ └── Graph.hs ├── proto │ └── tensorboard │ │ └── src │ │ ├── resource_handle.proto │ │ ├── versions.proto │ │ ├── tensor_shape.proto │ │ ├── types.proto │ │ ├── graph.proto │ │ ├── event.proto │ │ ├── attr_value.proto │ │ ├── node_def.proto │ │ ├── tensor.proto │ │ └── summary.proto ├── package.yaml └── haskell-torch-tensorboard-proto.cabal ├── common-paths.yaml ├── ChangeLog.md ├── cabal.project ├── .gitignore ├── ihaskell-dynamic.diff ├── makefile ├── patches ├── jupyter-nbconvert-fix-theme-6.1.patch ├── jupytext-add-haskell.patch ├── ihaskell-fixup-set-master-nov-01-2021.patch └── ihaskell-fixup-set-0.10.2.1.patch ├── docs ├── index.rst ├── Makefile └── conf.py ├── hie.yaml ├── stack.yaml ├── LICENSE ├── .hlint.yaml ├── setup.sh ├── stack.yaml.lock ├── .circleci └── config.yml ├── nix └── sources.json └── Plugins.ipynb /dataframe/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch datasets 2 | -------------------------------------------------------------------------------- /default-type-plugin/README.md: 
-------------------------------------------------------------------------------- 1 | # Type defaulting plugin 2 | -------------------------------------------------------------------------------- /haskell-torch-models/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch models 2 | -------------------------------------------------------------------------------- /haskell-torch-datasets/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch datasets 2 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abarbu/haskell-torch/HEAD/logo.png -------------------------------------------------------------------------------- /shell-with-jupyter.nix: -------------------------------------------------------------------------------- 1 | (import ./shell.nix) { withJupyter = true; } 2 | -------------------------------------------------------------------------------- /simplify-nat-algebra-plugin/README.md: -------------------------------------------------------------------------------- 1 | # Simplify type-level algebra 2 | -------------------------------------------------------------------------------- /interpolateIO/test/Spec.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -F -pgmF hspec-discover #-} 2 | -------------------------------------------------------------------------------- /.ghci: -------------------------------------------------------------------------------- 1 | :set -fobject-code 2 | :set -O0 3 | :set prompt-cont | 4 | :set prompt > 5 | -------------------------------------------------------------------------------- /default-type-plugin/Setup.hs: -------------------------------------------------------------------------------- 1 | import 
Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /haskell-torch-tools/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /logo-with-text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abarbu/haskell-torch/HEAD/logo-with-text.png -------------------------------------------------------------------------------- /interpolateIO/.ghci: -------------------------------------------------------------------------------- 1 | :set -isrc -itest -optP-include -optPdist/build/autogen/cabal_macros.h 2 | -------------------------------------------------------------------------------- /haskell-torch/test/Doctest.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -F -pgmF doctest-discover -optF test/doctest.json #-} 2 | -------------------------------------------------------------------------------- /interpolateIO/Setup.lhs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env runhaskell 2 | > import Distribution.Simple 3 | > main = defaultMain 4 | -------------------------------------------------------------------------------- /haskell-torch-matio/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | 4 | main = defaultMain 5 | 6 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | 4 | main = defaultMain 5 | 6 | 
-------------------------------------------------------------------------------- /haskell-torch-imagemagick/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | 4 | main = defaultMain 5 | 6 | -------------------------------------------------------------------------------- /haskell-torch/test/doctest.json: -------------------------------------------------------------------------------- 1 | {"doctestOptions": ["-XScopedTypeVariables","-XScopedTypeVariables","-XTypeApplications","-XDataKinds"] } 2 | -------------------------------------------------------------------------------- /dataframe/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | import Data.ProtoLens.Setup 4 | 5 | main = defaultMain 6 | 7 | -------------------------------------------------------------------------------- /haskell-torch/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | import Data.ProtoLens.Setup 4 | 5 | main = defaultMain 6 | 7 | -------------------------------------------------------------------------------- /haskell-torch-datasets/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | import Data.ProtoLens.Setup 4 | 5 | main = defaultMain 6 | 7 | -------------------------------------------------------------------------------- /haskell-torch-examples/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | import Data.ProtoLens.Setup 4 | 5 | main = defaultMain 6 | 7 | -------------------------------------------------------------------------------- /haskell-torch-models/Setup.hs: 
-------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | import Data.ProtoLens.Setup 4 | 5 | main = defaultMain 6 | 7 | -------------------------------------------------------------------------------- /ihaskell-matplotlib/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | import Data.ProtoLens.Setup 4 | 5 | main = defaultMain 6 | 7 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | import System.Environment 3 | import Data.ProtoLens.Setup 4 | 5 | main = defaultMainGeneratingProtos "proto" 6 | -------------------------------------------------------------------------------- /haskell-torch-matio/README.md: -------------------------------------------------------------------------------- 1 | # Haskell matio bindings to read matlab files 2 | 3 | This package is part of the Haskell-Torch ecosystem. Check out the documentation 4 | [there](https://github.com/abarbu/haskell-torch). 5 | -------------------------------------------------------------------------------- /haskell-torch/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch package 2 | 3 | This is the user-facing part of Haskell-Torch. [Look at the root of the 4 | repository for information on how to use it](https://github.com/abarbu/haskell-torch). 
5 | -------------------------------------------------------------------------------- /common-paths.yaml: -------------------------------------------------------------------------------- 1 | include-dirs: 2 | - /nix/store/crrwcbzgbbcjivf36kmminsmv9a0jv7v-imagemagick-7.1.0-4-dev/include/ImageMagick 3 | extra-lib-dirs: 4 | - /nix/store/agp5fqd06czyd21ivxw6qxm4xh94c8sc-imagemagick-7.1.0-4/lib/ 5 | -------------------------------------------------------------------------------- /haskell-torch-examples/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch package 2 | 3 | This is the user-facing part of Haskell-Torch. [Look at the root of the 4 | repository for information on how to use it](https://github.com/abarbu/haskell-torch). 5 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Tensorboard protocol buffers bindings 2 | 3 | This package is part of the Haskell-Torch ecosystem. Check out the documentation 4 | [there](https://github.com/abarbu/haskell-torch). 5 | -------------------------------------------------------------------------------- /interpolateIO/test/doctests.hs: -------------------------------------------------------------------------------- 1 | module Main where 2 | 3 | import Test.DocTest 4 | 5 | main :: IO () 6 | main = doctest ["-isrc", "-optP-include", "-optPdist/build/autogen/cabal_macros.h", "src/Data/String/Interpolate.hs"] 7 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch C bindings 2 | 3 | This package is part of the Haskell-Torch ecosystem. Check out the documentation 4 | [there](https://github.com/abarbu/haskell-torch). 
You will not use these 5 | bindings directly. 6 | -------------------------------------------------------------------------------- /ChangeLog.md: -------------------------------------------------------------------------------- 1 | # Changelog for Haskell-Torch 2 | 3 | ## [0.85] 2021-11-01 4 | 5 | Nix migration. PyTorch 1.8 support. 6 | Custom GHC with ambiguity plugins. 7 | 8 | ## [0.8] 2020-09-15 9 | 10 | The initial 0.8 release is in prep with PyTorch 1.5 support. 11 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/src/Tensorboard/Proto/Event.hs: -------------------------------------------------------------------------------- 1 | module Tensorboard.Proto.Event( 2 | module Proto.Tensorboard.Src.Event 3 | ,module Proto.Tensorboard.Src.Event_Fields 4 | ) 5 | where 6 | import Proto.Tensorboard.Src.Event 7 | import Proto.Tensorboard.Src.Event_Fields 8 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/src/Tensorboard/Proto/Attributes.hs: -------------------------------------------------------------------------------- 1 | module Tensorboard.Proto.Attributes( 2 | module Proto.Tensorboard.Src.AttrValue 3 | ,module Proto.Tensorboard.Src.AttrValue_Fields 4 | ) 5 | where 6 | import Proto.Tensorboard.Src.AttrValue 7 | import Proto.Tensorboard.Src.AttrValue_Fields 8 | -------------------------------------------------------------------------------- /cabal.project: -------------------------------------------------------------------------------- 1 | packages: 2 | haskell-torch-tools/*.cabal 3 | haskell-torch-cbindings/*.cabal 4 | haskell-torch-imagemagick/*.cabal 5 | haskell-torch-matio/*.cabal 6 | haskell-torch-tensorboard-proto/*.cabal 7 | simplify-nat-algebra-plugin/*.cabal 8 | interpolateIO/*.cabal 9 | default-type-plugin/*.cabal 10 | haskell-torch/*.cabal 11 | -------------------------------------------------------------------------------- 
/haskell-torch-tools/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch tools 2 | 3 | This package is part of the Haskell-Torch ecosystem. Check out the documentation 4 | [there](https://github.com/abarbu/haskell-torch). 5 | 6 | This contains tools required to generate bindings for Haskell-Torch. There's 7 | nothing user-facing or reusable here. It's only used by the build system. 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | dist-* 3 | cabal-dev 4 | *.o 5 | *.hi 6 | *.hie 7 | *.chi 8 | *.chs.h 9 | *.dyn_o 10 | *.dyn_hi 11 | .hpc 12 | .hsenv 13 | .cabal-sandbox/ 14 | cabal.sandbox.config 15 | *.prof 16 | *.aux 17 | *.hp 18 | *.eventlog 19 | .stack-work/ 20 | cabal.project.local 21 | cabal.project.local~ 22 | .HTF/ 23 | .ghc.environment.* 24 | result*/ 25 | *.sqlite* 26 | notes* 27 | datasets/ 28 | -------------------------------------------------------------------------------- /ihaskell-dynamic.diff: -------------------------------------------------------------------------------- 1 | --- ihaskell/ihaskell.cabal 2 | +++ ihaskell/ihaskell.cabal 3 | @@ -123,7 +123,7 @@ executable ihaskell 4 | hs-source-dirs: main 5 | other-modules: 6 | Paths_ihaskell 7 | - ghc-options: -threaded -rtsopts -Wall 8 | + ghc-options: -threaded -rtsopts -Wall -dynamic 9 | 10 | if os(darwin) 11 | ghc-options: -optP-Wno-nonportable-include-path 12 | -------------------------------------------------------------------------------- /haskell-torch-imagemagick/README.md: -------------------------------------------------------------------------------- 1 | # Haskell Torch Imagemagick bindings 2 | 3 | This package is part of the Haskell-Torch ecosystem. Check out the documentation 4 | [there](https://github.com/abarbu/haskell-torch). 5 | 6 | It's only meant to be used as part of Haskell-Torch. 
The interface isn't user 7 | friendly, only relevant functions are included, and it's not safe. It's an 8 | internal, not a user-facing library. 9 | -------------------------------------------------------------------------------- /haskell-torch-datasets/src/Torch/Datasets.hs: -------------------------------------------------------------------------------- 1 | module Torch.Datasets( 2 | module Torch.Datasets.Common 3 | ,module Torch.Datasets.Augmentation 4 | ,module Torch.Datasets.Vision.MNIST 5 | ,module Torch.Datasets.Vision.CIFAR)where 6 | 7 | import Torch.Datasets.Augmentation 8 | import Torch.Datasets.Common 9 | import Torch.Datasets.Vision.CIFAR 10 | import Torch.Datasets.Vision.MNIST 11 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | TORCH_ROOT=$(shell python -c 'import torch; import inspect; import os; print(os.path.dirname(inspect.getfile(torch)))') 2 | 3 | all: haskell-torch-cbindings/src/Torch/C/Tensor.hs 4 | 5 | haskell-torch-cbindings/src/Torch/C/Tensor.hs: VariableType.processed 6 | stack build haskell-torch-tools && \ 7 | cd haskell-torch-cbindings && \ 8 | stack exec haskell-torch-tools-generate-ctensor src/Torch/C/Tensor.hs ../VariableType.processed 9 | -------------------------------------------------------------------------------- /haskell-torch/test/Spec.hs: -------------------------------------------------------------------------------- 1 | main :: IO () 2 | main = putStrLn "Test suite not yet implemented" 3 | 4 | -- TODO Move these to a real test suite, for now they're just snippets to try things out. 
5 | 6 | -- runWithCuda (pure ()) 7 | -- runWithCpu (pure ()) 8 | 9 | -- Torch.C.State.runWithCuda (Torch.C.CUDA.mkTensorOnDevice (3,4) 0) 10 | 11 | -- z :: IO (Ptr HTensor) 12 | -- z = do 13 | -- c <- runWithCuda (mkTensorOnDevice (3::CLong,4::CLong) (Just 0)) 14 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/src/Tensorboard/Proto/Tensor.hs: -------------------------------------------------------------------------------- 1 | module Tensorboard.Proto.Tensor( 2 | module Proto.Tensorboard.Src.Tensor 3 | ,module Proto.Tensorboard.Src.Tensor_Fields 4 | ,module Proto.Tensorboard.Src.TensorShape 5 | ,module Proto.Tensorboard.Src.TensorShape_Fields 6 | ) 7 | where 8 | import Proto.Tensorboard.Src.Tensor 9 | import Proto.Tensorboard.Src.Tensor_Fields 10 | import Proto.Tensorboard.Src.TensorShape 11 | import Proto.Tensorboard.Src.TensorShape_Fields 12 | -------------------------------------------------------------------------------- /patches/jupyter-nbconvert-fix-theme-6.1.patch: -------------------------------------------------------------------------------- 1 | --- a/nbconvert/preprocessors/csshtmlheader.py 1969-12-31 19:00:01.000000000 -0500 2 | +++ b/nbconvert/preprocessors/csshtmlheader.py 2021-11-03 11:57:05.270846344 -0400 3 | @@ -32,7 +32,7 @@ 4 | 5 | style = Union([Unicode('default'), Type(klass=Style)], 6 | help='Name of the pygments style to use', 7 | - default_value=JupyterStyle 8 | + default_value='vs' 9 | ).tag(config=True) 10 | 11 | def __init__(self, *pargs, **kwargs): 12 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Haskell-Torch documentation master file, created by 2 | sphinx-quickstart on Wed Dec 18 00:21:59 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | Welcome to Haskell-Torch's documentation! 7 | ========================================= 8 | 9 | Coming soon! 10 | 11 | .. toctree:: 12 | :maxdepth: 1 13 | :caption: Contents: 14 | 15 | Basics.lhs 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /hie.yaml: -------------------------------------------------------------------------------- 1 | cradle: 2 | stack: 3 | - path: "./haskell-torch-cbindings" 4 | component: "haskell-torch-cbindings:lib" 5 | - path: "./haskell-torch-tools" 6 | component: "haskell-torch-tools:exe:haskell-torch-tools-generate-ctensor" 7 | - path: "./haskell-torch" 8 | component: "haskell-torch:lib" 9 | - path: "./default-type-plugin" 10 | component: "default-type-plugin:lib" 11 | - path: "./simplify-nat-algebra-plugin" 12 | component: "simplify-nat-algebra-plugin:lib" 13 | - path: "./haskell-notebook-filter" 14 | component: "haskell-notebook-filter:exe:haskell-filter-jupyter" 15 | -------------------------------------------------------------------------------- /haskell-torch/src/Torch.hs: -------------------------------------------------------------------------------- 1 | module Torch(module X) where 2 | import Torch.Images as X 3 | import Torch.Indexing as X 4 | import Torch.Initialization as X 5 | import Torch.Inplace as X 6 | import Torch.Misc as X 7 | import Torch.Operators as X 8 | import Torch.Optimizer as X 9 | import Torch.StoredModel as X 10 | import Torch.Tensor as X 11 | import Torch.Tensorboard as X 12 | import Torch.Types as X 13 | import Torch.Visualization as X 14 | import Torch.Datasets.Augmentation as X 15 | import Torch.Datasets.Common as X 16 | -------------------------------------------------------------------------------- /patches/jupytext-add-haskell.patch: -------------------------------------------------------------------------------- 1 | diff --git 
a/jupytext/languages.py b/jupytext/languages.py 2 | index b1a2b8d..dc2f5a4 100644 3 | --- a/jupytext/languages.py 4 | +++ b/jupytext/languages.py 5 | @@ -48,6 +48,7 @@ _SCRIPT_EXTENSIONS = { 6 | ".js": {"language": "javascript", "comment": "//"}, 7 | ".ts": {"language": "typescript", "comment": "//"}, 8 | ".scala": {"language": "scala", "comment": "//"}, 9 | + ".hs": {"language": "haskell", "comment": "--"}, 10 | ".rs": {"language": "rust", "comment": "//"}, 11 | ".robot": {"language": "robotframework", "comment": "#"}, 12 | ".resource": {"language": "robotframework", "comment": "#"}, 13 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/src/Tensorboard/Proto/Summary.hs: -------------------------------------------------------------------------------- 1 | module Tensorboard.Proto.Summary( 2 | module Proto.Tensorboard.Src.Summary 3 | ,module Proto.Tensorboard.Src.Summary_Fields 4 | ,module Proto.Tensorboard.Src.Types 5 | ,module Proto.Tensorboard.Src.Types_Fields 6 | ,module Proto.Tensorboard.Src.ResourceHandle 7 | ,module Proto.Tensorboard.Src.ResourceHandle_Fields 8 | ) 9 | where 10 | import Proto.Tensorboard.Src.Summary 11 | import Proto.Tensorboard.Src.Summary_Fields 12 | import Proto.Tensorboard.Src.Types 13 | import Proto.Tensorboard.Src.Types_Fields 14 | import Proto.Tensorboard.Src.ResourceHandle 15 | import Proto.Tensorboard.Src.ResourceHandle_Fields 16 | 17 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /interpolateIO/test/Data/String/InterpolateIO/IsStringSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE QuasiQuotes #-} 2 | module Data.String.InterpolateIO.IsStringSpec (main, spec) where 3 | 4 | import Test.Hspec 5 | 6 | import qualified Data.Text as T 7 | import Data.String.InterpolateIO.IsString 8 | import System.IO.Unsafe 9 | 10 | main :: IO () 11 | main = hspec spec 12 | 13 | spec :: Spec 14 | spec = do 15 | describe "[c|...|]" $ do 16 | it "can be used to construct String literals" $ do 17 | (unsafePerformIO [c|foo #{23 :: Int} bar|]) `shouldBe` "foo 23 bar" 18 | it "can be used to construct Text literals" $ do 19 | (unsafePerformIO [c|foo #{23 :: Int} bar|]) `shouldBe` T.pack "foo 23 bar" 20 | -------------------------------------------------------------------------------- /interpolateIO/src/Data/String/InterpolateIO/Parse.hs: -------------------------------------------------------------------------------- 1 | module Data.String.InterpolateIO.Parse where 2 | 3 | import Data.String.InterpolateIO.Internal.Util 4 | 5 | data Node = Literal String | Expression String 6 | 7 | parseNodes :: String -> [Node] 8 | parseNodes = go "" 9 | where 10 | go :: String -> String -> [Node] 11 | go acc input = case input of 12 | "" -> [(lit . reverse) acc] 13 | '\\':x:xs -> go (x:'\\':acc) xs 14 | '#':'{':xs -> case span (/= '}') xs of 15 | (ys, _:zs) -> (lit . 
reverse) acc : Expression ys : go "" zs 16 | (_, "") -> [lit (reverse acc ++ input)] 17 | x:xs -> go (x:acc) xs 18 | 19 | lit :: String -> Node 20 | lit = Literal . unescape 21 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/src/Torch/C/Language.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE PatternSynonyms, TemplateHaskell #-} 2 | 3 | -- | Some helpful @inline-c@ QuasiQuoters 4 | 5 | module Torch.C.Language(cstorable) where 6 | import Foreign 7 | import qualified Language.C.Inline as C 8 | import Language.Haskell.TH 9 | import Language.Haskell.TH.Quote 10 | 11 | cstorable :: Name -> String -> DecsQ 12 | cstorable ty cname = 13 | [d|instance Storable $(conT ty) where 14 | sizeOf _ = fromIntegral $(quoteExp C.pure str) 15 | alignment _ = alignment (undefined :: Ptr ()) 16 | peek = error "not implemented" 17 | poke = error "not implemented"|] 18 | where str = "size_t { sizeof(" ++ cname ++ ") }" 19 | -------------------------------------------------------------------------------- /interpolateIO/src/Data/String/InterpolateIO/IsString.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | module Data.String.InterpolateIO.IsString (c, fromStringIO) where 3 | 4 | import Data.String.ShowIO(fromStringIO) 5 | import Language.Haskell.TH.Quote (QuasiQuoter(..)) 6 | 7 | import qualified Data.String.InterpolateIO as I 8 | 9 | -- | 10 | -- Like `I.c`, but constructs a value of type 11 | -- 12 | -- > IsString a => a 13 | c :: QuasiQuoter 14 | c = QuasiQuoter { 15 | quoteExp = \s -> [|fromStringIO =<< $(quoteExp I.c $ s)|] 16 | , quotePat = err "pattern" 17 | , quoteType = err "type" 18 | , quoteDec = err "declaration" 19 | } 20 | where 21 | err name = error ("Data.String.Interpolate.IsString.c: This QuasiQuoter can not be used as a " ++ name ++ "!") 22 | 
-------------------------------------------------------------------------------- /ihaskell-matplotlib/IHaskell/Display/Matplotlib.hs: -------------------------------------------------------------------------------- 1 | module IHaskell.Display.Matplotlib where 2 | import Graphics.Matplotlib 3 | import Graphics.Matplotlib.Internal 4 | import IHaskell.Display 5 | 6 | {-# LANGUAGE ExtendedDefaultRules #-} 7 | 8 | -- Only bleeding edge matplotlib has these helpers as of 2021. Delete these 9 | -- after a few years. 10 | 11 | -- | Get the SVG for a figure 12 | toSvg' :: Matplotlib -> IO (Either String String) 13 | toSvg' m = withMplot m (\s -> python $ pyIncludes "" ++ s ++ pySVG) 14 | 15 | pySVG' :: [[Char]] 16 | pySVG' = 17 | ["import io" 18 | ,"i = io.StringIO()" 19 | ,"plot.savefig(i, format='svg')" 20 | ,"print(i.getvalue())"] 21 | 22 | instance IHaskellDisplay Matplotlib where 23 | display m = do 24 | r <- toSvg' m 25 | case r of 26 | Left v -> error v 27 | Right v -> return $ Display [svg v] 28 | -------------------------------------------------------------------------------- /interpolateIO/test/Data/String/InterpolateIO/ParseSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE StandaloneDeriving #-} 2 | {-# OPTIONS_GHC -fno-warn-orphans #-} 3 | module Data.String.InterpolateIO.ParseSpec (main, spec) where 4 | 5 | import Test.Hspec 6 | 7 | import Data.String.InterpolateIO.Parse 8 | 9 | deriving instance Eq Node 10 | deriving instance Show Node 11 | 12 | main :: IO () 13 | main = hspec spec 14 | 15 | spec :: Spec 16 | spec = do 17 | describe "parseNodes" $ do 18 | it "parses string literals" $ do 19 | parseNodes "foo" `shouldBe` [Literal "foo"] 20 | 21 | it "parses embedded expressions" $ do 22 | parseNodes "foo #{bar} baz" `shouldBe` [Literal "foo ", Expression "bar", Literal " baz"] 23 | 24 | context "when given an unterminated expression" $ do 25 | it "parses it as a string literal" $ do 26 | parseNodes "foo 
#{bar" `shouldBe` [Literal "foo #{bar"] 27 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/src/Tensorboard/Proto/Graph.hs: -------------------------------------------------------------------------------- 1 | module Tensorboard.Proto.Graph( 2 | module Proto.Tensorboard.Src.Graph 3 | ,module Proto.Tensorboard.Src.Graph_Fields 4 | ,module Proto.Tensorboard.Src.NodeDef 5 | ,module Proto.Tensorboard.Src.NodeDef_Fields 6 | ,module Proto.Tensorboard.Src.Versions 7 | ,module Proto.Tensorboard.Src.Versions_Fields 8 | -- ,module Proto.Tensorboard.Src.AttrValue 9 | -- ,module Proto.Tensorboard.Src.AttrValue_Fields 10 | ) 11 | where 12 | import Proto.Tensorboard.Src.Graph 13 | import Proto.Tensorboard.Src.Graph_Fields 14 | import Proto.Tensorboard.Src.NodeDef 15 | import Proto.Tensorboard.Src.NodeDef_Fields 16 | import Proto.Tensorboard.Src.Versions 17 | import Proto.Tensorboard.Src.Versions_Fields 18 | -- import Proto.Tensorboard.Src.AttrValue 19 | -- import Proto.Tensorboard.Src.AttrValue_Fields 20 | -------------------------------------------------------------------------------- /ihaskell-matplotlib/package.yaml: -------------------------------------------------------------------------------- 1 | name: ihaskell-matplotlib 2 | version: 0.8.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Using matplotlib with IHaskell 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | # To avoid duplicated efforts in documentation and dealing with the 17 | # complications of embedding Haddock markup inside cabal files, it is 18 | # common to point users to the README.md file. 
19 | description: Please see the README on Github at 20 | 21 | dependencies: 22 | - base >= 4.7 && < 5 23 | 24 | library: 25 | source-dirs: . 26 | dependencies: 27 | - matplotlib 28 | - ihaskell >= 0.6.2 29 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "ResourceHandle"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Protocol buffer representing a handle to a tensorflow resource. Handles are 10 | // not valid across executions, but can be serialized back and forth from within 11 | // a single run. 12 | message ResourceHandleProto { 13 | // Unique name for the device containing the resource. 14 | string device = 1; 15 | 16 | // Container in which this resource is placed. 17 | string container = 2; 18 | 19 | // Unique name of this resource. 20 | string name = 3; 21 | 22 | // Hash code for the type of the resource. Is only valid in the same device 23 | // and in the same execution. 24 | uint64 hash_code = 4; 25 | 26 | // For debug-only, the name of the type pointed to by this handle, if 27 | // available. 
28 | string maybe_type_name = 5; 29 | }; 30 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: lts-18.13 2 | 3 | packages: 4 | - haskell-torch-tools 5 | - haskell-torch-cbindings 6 | - haskell-torch-imagemagick 7 | - haskell-torch-matio 8 | - haskell-torch-tensorboard-proto 9 | - simplify-nat-algebra-plugin 10 | - interpolateIO 11 | - default-type-plugin 12 | - haskell-torch 13 | - haskell-torch-models 14 | - haskell-torch-datasets 15 | - dataframe 16 | # for ihaskell integration 17 | - haskell-notebook-filter 18 | - ihaskell-matplotlib 19 | - gym-haskell 20 | - haskell-torch-examples 21 | 22 | extra-deps: 23 | - git: https://github.com/abarbu/matplotlib-haskell 24 | commit: 5c186cb5b9e80212c92c72a68e9ebbc260d413a8 25 | - git: https://github.com/docopt/docopt.hs.git 26 | commit: bdc4c679bf0185ab6c1895172f011193d9e9922c 27 | - proto-lens-setup-0.4.0.5 28 | - proto-lens-0.7.1.0 29 | - proto-lens-runtime-0.7.0.1 30 | - proto-lens-protoc-0.7.1.0 31 | - git: https://github.com/abarbu/haskell-cpython 32 | commit: 3c3c89acbc5a5fa6d60fc23a148f39eb330ecfac 33 | 34 | allow-newer: true 35 | 36 | nix: 37 | enable: true 38 | shell-file: shell.nix 39 | -------------------------------------------------------------------------------- /haskell-torch-matio/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-matio 2 | version: 0.5 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: Foreign 11 | synopsis: Manipulate Matlab .mat files with matio 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | description: Please see the README on Github at 17 | 18 
| dependencies: 19 | - base >= 4.7 && < 5 20 | 21 | library: 22 | source-dirs: src 23 | extra-libraries: 24 | - matio 25 | dependencies: 26 | - inline-c 27 | - inline-c-cpp 28 | - containers 29 | - text 30 | - bytestring 31 | - aeson 32 | - half 33 | - vector 34 | - extra 35 | - directory 36 | - filepath 37 | - ieee754 38 | - template-haskell 39 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/src/Torch/C/CUDA.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP, FlexibleContexts, QuasiQuotes, ScopedTypeVariables, TemplateHaskell #-} 2 | 3 | module Torch.C.CUDA where 4 | import Data.Monoid ((<>)) 5 | import Foreign.C.Types 6 | import qualified Language.C.Inline as C 7 | import qualified Language.C.Inline.Cpp as C 8 | import Torch.C.Types 9 | 10 | C.context (C.cppCtx <> tensorCtx) 11 | 12 | #if WITH_CUDA 13 | C.include "" 14 | C.include "" 15 | 16 | C.using "namespace at" 17 | C.using "namespace torch::autograd" 18 | 19 | hasCUDA :: IO CBool 20 | hasCUDA = [C.exp|bool{at::cuda::is_available()}|] 21 | 22 | currentDevice :: IO CInt 23 | currentDevice = [C.exp|int{at::cuda::current_device()}|] 24 | 25 | deviceCount :: IO CInt 26 | deviceCount = [C.exp|int{at::cuda::device_count()}|] 27 | #else 28 | hasCUDA :: IO CBool 29 | hasCUDA = [C.exp|bool{0}|] 30 | 31 | currentDevice :: IO CInt 32 | currentDevice = [C.exp|int{0}|] 33 | 34 | deviceCount :: IO CInt 35 | deviceCount = [C.exp|int{0}|] 36 | #endif 37 | -------------------------------------------------------------------------------- /interpolateIO/package.yaml: -------------------------------------------------------------------------------- 1 | name: interpolateIO 2 | version: 0.2.0 3 | category: Data, Text 4 | stability: experimental 5 | synopsis: String interpolation in IO based on interpolate 6 | description: String interpolation in IO based on interpolate 7 | license: MIT 8 | copyright: (c) 2018 Andrei 
Barbu, 2013-2015 Simon Hengel 9 | author: Andrei Barbu 10 | maintainer: Andrei Barbu 11 | 12 | github: abarbu/interpolateIO 13 | 14 | ghc-options: -Wall 15 | 16 | dependencies: 17 | - base == 4.* 18 | - template-haskell 19 | - haskell-src-meta >= 0.8 20 | - generics-eot 21 | 22 | source-dirs: src 23 | 24 | library: 25 | exposed-modules: 26 | - Data.String.InterpolateIO 27 | - Data.String.InterpolateIO.IsString 28 | - Data.String.InterpolateIO.Util 29 | - Data.String.InterpolateIO.Internal.Util 30 | - Data.String.ShowIO 31 | 32 | tests: 33 | spec: 34 | source-dirs: test 35 | main: Spec.hs 36 | dependencies: 37 | - text 38 | - bytestring 39 | - hspec >= 1.5 40 | - QuickCheck 41 | - quickcheck-instances 42 | - base-compat 43 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/versions.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "VersionsProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Version information for a piece of serialized data 10 | // 11 | // There are different types of versions for each type of data 12 | // (GraphDef, etc.), but they all have the same common shape 13 | // described here. 14 | // 15 | // Each consumer has "consumer" and "min_producer" versions (specified 16 | // elsewhere). A consumer is allowed to consume this data if 17 | // 18 | // producer >= min_producer 19 | // consumer >= min_consumer 20 | // consumer not in bad_consumers 21 | // 22 | message VersionDef { 23 | // The version of the code that produced this data. 24 | int32 producer = 1; 25 | 26 | // Any consumer below this version is not allowed to consume this data. 27 | int32 min_consumer = 2; 28 | 29 | // Specific consumer versions which are disallowed (e.g. due to bugs). 
30 | repeated int32 bad_consumers = 3; 31 | }; 32 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/src/Torch/C/Generator.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP, QuasiQuotes, TemplateHaskell #-} 2 | 3 | module Torch.C.Generator where 4 | import Data.Monoid ((<>)) 5 | import Data.Word 6 | import Foreign.Ptr 7 | import qualified Language.C.Inline as C 8 | import qualified Language.C.Inline.Cpp as C 9 | import Torch.C.Types 10 | 11 | C.context (tensorCtx <> C.funCtx) 12 | 13 | C.include "" 14 | 15 | C.using "namespace at" 16 | 17 | seed :: Ptr CGenerator -> IO Word64 18 | seed g = [C.exp|uint64_t { $(Generator *g)->seed() }|] 19 | 20 | initialSeed :: IO Word64 21 | initialSeed = [C.exp|uint64_t { at::detail::getDefaultCPUGenerator().current_seed() }|] 22 | 23 | setSeed :: Ptr CGenerator -> Word64 -> IO () 24 | setSeed g s = [C.exp|void { $(Generator *g)->set_current_seed($(uint64_t s)) }|] 25 | 26 | cpuGenerator :: IO (Ptr CGenerator) 27 | cpuGenerator = [C.exp|const Generator *{ &at::globalContext().defaultGenerator(kCPU) }|] 28 | 29 | #if WITH_CUDA 30 | cudaGenerator :: IO (Ptr CGenerator) 31 | cudaGenerator = [C.exp|const Generator *{ &at::globalContext().defaultGenerator(kCUDA) }|] 32 | #endif 33 | -------------------------------------------------------------------------------- /ihaskell-matplotlib/ihaskell-matplotlib.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 
4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: ihaskell-matplotlib 8 | version: 0.8.0.0 9 | synopsis: Using matplotlib with IHaskell 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | IHaskell.Display.Matplotlib 29 | other-modules: 30 | Paths_ihaskell_matplotlib 31 | hs-source-dirs: 32 | ./ 33 | build-depends: 34 | base >=4.7 && <5 35 | , ihaskell >=0.6.2 36 | , matplotlib 37 | default-language: Haskell2010 38 | -------------------------------------------------------------------------------- /interpolateIO/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 Andrei Barbu , 2013-2015 Simon Hengel 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /interpolateIO/src/Data/String/InterpolateIO/Compat.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module Data.String.InterpolateIO.Compat ( 3 | readMaybe 4 | , module Language.Haskell.TH 5 | ) where 6 | 7 | import Language.Haskell.TH 8 | import Text.Read 9 | 10 | #if !MIN_VERSION_base(4,6,0) 11 | import qualified Text.ParserCombinators.ReadP as P 12 | #endif 13 | 14 | #if !MIN_VERSION_base(4,6,0) 15 | -- | Parse a string using the 'Read' instance. 16 | -- Succeeds if there is exactly one valid result. 17 | -- A 'Left' value indicates a parse error. 18 | readEither :: Read a => String -> Either String a 19 | readEither s = 20 | case [ x | (x,"") <- readPrec_to_S read' minPrec s ] of 21 | [x] -> Right x 22 | [] -> Left "Prelude.read: no parse" 23 | _ -> Left "Prelude.read: ambiguous parse" 24 | where 25 | read' = 26 | do x <- readPrec 27 | lift P.skipSpaces 28 | return x 29 | 30 | -- | Parse a string using the 'Read' instance. 31 | -- Succeeds if there is exactly one valid result. 32 | readMaybe :: Read a => String -> Maybe a 33 | readMaybe s = case readEither s of 34 | Left _ -> Nothing 35 | Right a -> Just a 36 | #endif 37 | -------------------------------------------------------------------------------- /interpolateIO/README.markdown: -------------------------------------------------------------------------------- 1 | # String interpolation in IO! 2 | 3 | This package is part of the Haskell-Torch ecosystem. Check out the documentation 4 | [there](https://github.com/abarbu/haskell-torch). 
5 | 6 | It is a fork of [interpolate](http://hackage.haskell.org/package/interpolate) by 7 | Simon Hengel. For when you have values that can only be read by IO operations 8 | like IORefs or matrices backed by C data. 9 | 10 | ## Examples 11 | 12 | >>> :set -XQuasiQuotes 13 | >>> import Data.String.InterpolateIO 14 | 15 | Interpolates strings 16 | 17 | >>> let name = "Marvin" 18 | >>> putStrLn =<< [c|name: #{name}|] 19 | name: Marvin 20 | 21 | or integers 22 | 23 | >>> let age = 23 24 | >>> putStrLn =<< [c|age: #{age}|] 25 | age: 23 26 | 27 | or arbitrary Haskell expressions 28 | 29 | >>> let profession = "\955-scientist" 30 | >>> putStrLn =<< [c|profession: #{unwords [name, "the", profession]}|] 31 | profession: Marvin the λ-scientist 32 | 33 | or values in IO 34 | 35 | >>> import System.Environment 36 | >>> let profession = "\955-scientist" 37 | >>> putStrLn =<< [c|home directory: #{getEnv "HOME"}|] 38 | profession: 39 | -------------------------------------------------------------------------------- /simplify-nat-algebra-plugin/tests/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances, FlexibleContexts #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | {-# LANGUAGE StandaloneDeriving #-} 5 | {-# LANGUAGE TypeFamilies #-} 6 | {-# LANGUAGE TypeInType #-} 7 | {-# LANGUAGE TypeOperators, PartialTypeSignatures #-} 8 | {-# LANGUAGE RankNTypes, ScopedTypeVariables, TypeOperators, TypeApplications, KindSignatures #-} 9 | {-# LANGUAGE DataKinds, GADTs, TypeOperators #-} 10 | {-# OPTIONS -fplugin GHC.TypeLits.KnownNat.Solver -fplugin Plugin.SimplifyNat #-} 11 | 12 | import Data.Proxy 13 | import GHC.TypeLits 14 | import GHC.TypeNats 15 | 16 | w :: forall (x :: Nat). KnownNat x => Proxy x -> Proxy (x + 1) 17 | w = undefined 18 | 19 | e :: forall (x :: Nat). KnownNat x => Proxy x -> Proxy (x - 1) 20 | e = undefined 21 | 22 | f :: forall (x :: Nat). ((1 <=? 
x) ~ 'True, KnownNat x) => Proxy x -> Proxy x 23 | f = undefined 24 | 25 | -- When this type is inferred you should get 26 | -- q :: (KnownNat x, (1 <=? (x + 1)) ~ 'True, (1 <=? ((x + 1) - 1)) ~ 'True) => Proxy x -> Proxy ((x + 1) - 1) 27 | -- without the plugin, and 28 | -- q :: (KnownNat x, (1 <=? x) ~ 'True) => Proxy x -> Proxy ((x + 1) - 1) 29 | -- with the plugin 30 | q :: _ => _ 31 | q i = f $ e $ w i 32 | 33 | main = pure 0 34 | -------------------------------------------------------------------------------- /default-type-plugin/default-type-plugin.cabal: -------------------------------------------------------------------------------- 1 | name: default-type-plugin 2 | version: 0.1 3 | synopsis: Unlock the power of type defaulting 4 | description: Tired of ambiguity errors? Want a more powerful defaulting mechanism? 5 | Look no further. 6 | homepage: http://github.com/abarbu/haskell-torch/default-type-plugin 7 | bug-reports: http://github.com/abarbu/haskell-torch/default-type-plugin 8 | license: BSD2 9 | license-file: LICENSE 10 | author: Andrei Barbu 11 | maintainer: andrei@0xab.com 12 | copyright: Copyright © 2020 Andrei Barbu 13 | category: Type System 14 | build-type: Simple 15 | extra-source-files: README.md 16 | CHANGELOG.md 17 | cabal-version: >=1.10 18 | 19 | source-repository head 20 | type: git 21 | location: https://github.com/abarbu/default-type-plugin.git 22 | 23 | flag deverror 24 | description: 25 | Enables `-Werror` for development mode and TravisCI 26 | default: False 27 | manual: True 28 | 29 | library 30 | exposed-modules: Plugin.DefaultType 31 | build-depends: base >= 4.8, 32 | ghc >= 8.10.2, 33 | containers 34 | hs-source-dirs: src 35 | default-language: Haskell2010 36 | -------------------------------------------------------------------------------- /haskell-torch-tools/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-tools 2 | version: 0.1.0.0 3 | github: 
"abarbu/haskell-torch/haskell-torch-tools" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Support tools to generate code for haskell-torch 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | description: Please see the README on Github at 17 | 18 | dependencies: 19 | - base >= 4.7 && < 5 20 | 21 | executables: 22 | haskell-torch-tools-generate-ctensor: 23 | main: GenerateCTensor 24 | source-dirs: 25 | - app 26 | - src 27 | ghc-options: 28 | - -threaded 29 | - -rtsopts 30 | - -with-rtsopts=-N 31 | dependencies: 32 | - inline-c 33 | - inline-c-cpp 34 | - template-haskell 35 | - containers 36 | - text 37 | - yaml 38 | - bytestring 39 | - aeson 40 | - docopt 41 | - extra 42 | - directory 43 | - filepath 44 | - unordered-containers 45 | - vector 46 | - lens 47 | - lens-aeson 48 | - stache 49 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-cbindings 2 | version: 0.8.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Bindings to the C PyTorch library for Haskell-Torch 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | description: Please see the README on Github at 17 | 18 | dependencies: 19 | - base >= 4.7 && < 5 20 | 21 | library: 22 | source-dirs: src 23 | dependencies: 24 | - inline-c 25 | - inline-c-cpp 26 | - template-haskell 27 | - containers 28 | - text 29 | - bytestring 30 | - half 31 | - vector 32 | - extra 33 | - 
ieee754 34 | - safe-exceptions 35 | extra-libraries: 36 | - stdc++ 37 | - hdf5 38 | - c10 39 | - torch 40 | - torch_cpu 41 | 42 | when: 43 | - condition: flag(cuda) 44 | cpp-options: -DWITH_CUDA 45 | extra-libraries: torch_cuda 46 | 47 | flags: 48 | cuda: 49 | description: If your pytorch is CUDA-enabled, you can set this 50 | manual: true 51 | default: false 52 | -------------------------------------------------------------------------------- /simplify-nat-algebra-plugin/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Andrei Barbu, MIT 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | 1. Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the 14 | distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 20 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 22 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | -------------------------------------------------------------------------------- /haskell-torch-matio/haskell-torch-matio.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-matio 8 | version: 0.5 9 | synopsis: Manipulate Matlab .mat files with matio 10 | description: Please see the README on Github at 11 | category: Foreign 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | Foreign.Matio 29 | Foreign.Matio.Types 30 | other-modules: 31 | Paths_haskell_torch_matio 32 | hs-source-dirs: 33 | src 34 | extra-libraries: 35 | matio 36 | build-depends: 37 | aeson 38 | , base >=4.7 && <5 39 | , bytestring 40 | , containers 41 | , directory 42 | , extra 43 | , filepath 44 | , half 45 | , ieee754 46 | , inline-c 47 | , inline-c-cpp 48 | , template-haskell 49 | , text 50 | , vector 51 | default-language: Haskell2010 52 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/RL/Simple.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, ExtendedDefaultRules, FlexibleContexts, FlexibleInstances, GADTs, OverloadedStrings, MultiParamTypeClasses #-} 2 | {-# LANGUAGE PolyKinds, QuasiQuotes, RankNTypes, ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies #-} 3 | {-# LANGUAGE TypeFamilyDependencies, 
TypeInType, TypeOperators, UndecidableInstances #-} 4 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver -fplugin Plugin.DefaultType #-} 5 | 6 | -- | This example shows you Haskell & PyTorch code right next to one 7 | -- another. You get an idea of how the two are related and how to do so some of 8 | -- he most basic operations. 9 | 10 | module Torch.Tutorial.RL.Simple where 11 | import Control.Monad 12 | import Data.Default 13 | import Data.Kind 14 | import Data.Maybe 15 | import Data.Singletons 16 | import Data.String.InterpolateIO 17 | import qualified Data.Vector as V' 18 | import Data.Vector.Storable (Vector) 19 | import qualified Data.Vector.Storable as V 20 | import Foreign.C.Types 21 | import Pipes 22 | import qualified Pipes.Prelude as P 23 | import Torch 24 | import qualified Data.Vector.Storable as VS 25 | import qualified Simulator.Gym as G 26 | -------------------------------------------------------------------------------- /default-type-plugin/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015-2016, University of Twente, 2 | 2017-2018, QBayLogic 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are 7 | met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the 15 | distribution. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /haskell-torch-imagemagick/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-imagemagick 2 | version: 0.1.0.5 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: Foreign 11 | synopsis: Basic image loading/saving with ImageMagick, minimal bindings for Haskell-Torch 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | description: Please see the README on Github at 17 | 18 | dependencies: 19 | - base >= 4.7 && < 5 20 | 21 | library: 22 | source-dirs: src 23 | extra-libraries: 24 | - MagickWand-7.Q16HDRI 25 | - MagickCore-7.Q16HDRI 26 | dependencies: 27 | - inline-c 28 | - inline-c-cpp 29 | - containers 30 | - transformers 31 | - lifted-base 32 | - mtl 33 | - text 34 | - bytestring 35 | - aeson 36 | - half 37 | - vector 38 | - extra 39 | - directory 40 | - filepath 41 | - ieee754 42 | - template-haskell 43 | cpp-options: 44 | - -DMAGICKCORE_QUANTUM_DEPTH=16 45 | - -DMAGICKCORE_HDRI_ENABLE=0 46 | include-dirs: 47 | - /nix/store/crrwcbzgbbcjivf36kmminsmv9a0jv7v-imagemagick-7.1.0-4-dev/include/ImageMagick 48 | 
-------------------------------------------------------------------------------- /interpolateIO/test/Data/String/InterpolateIOSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE QuasiQuotes #-} 2 | module Data.String.InterpolateIOSpec (main, spec) where 3 | 4 | import Test.Hspec 5 | import Test.QuickCheck 6 | import System.IO.Unsafe 7 | 8 | import Data.String.InterpolateIO 9 | 10 | main :: IO () 11 | main = hspec spec 12 | 13 | spec :: Spec 14 | spec = do 15 | describe "[c|...|]" $ do 16 | it "interpolates an expression of type Int" $ do 17 | property $ \x y -> unsafePerformIO [c|foo #{x + y :: Int} bar|] `shouldBe` "foo " ++ show (x + y) ++ " bar" 18 | 19 | it "interpolates an expression of type String" $ do 20 | property $ \xs ys -> unsafePerformIO [c|foo #{xs ++ ys} bar|] `shouldBe` "foo " ++ xs ++ ys ++ " bar" 21 | 22 | it "accepts character escapes" $ do 23 | unsafePerformIO [c|foo \955 bar|] `shouldBe` "foo \955 bar" 24 | 25 | it "accepts character escapes in interpolated expressions" $ do 26 | unsafePerformIO [c|foo #{"\955" :: String} bar|] `shouldBe` "foo \955 bar" 27 | 28 | it "dose not strip backslashes (issue #1)" $ do 29 | unsafePerformIO [c|foo\\bar|] `shouldBe` "foo\\bar" 30 | 31 | it "allows to prevent interpolation by escaping the hash with a backslash" $ do 32 | unsafePerformIO [c|foo \#{23 :: Int} bar|] `shouldBe` "foo #{23 :: Int} bar" 33 | 34 | it "does not prevent interpolation on literal backslash" $ do 35 | unsafePerformIO [c|foo \\#{23 :: Int} bar|] `shouldBe` "foo \\23 bar" 36 | 37 | 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017-2019 Andrei Barbu (MIT) 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | 1. 
Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation and/or 11 | other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software without 15 | specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /patches/ihaskell-fixup-set-master-nov-01-2021.patch: -------------------------------------------------------------------------------- 1 | diff --git a/src/IHaskell/Eval/Util.hs b/src/IHaskell/Eval/Util.hs 2 | index ce3200d..d22a57e 100644 3 | --- a/src/IHaskell/Eval/Util.hs 4 | +++ b/src/IHaskell/Eval/Util.hs 5 | @@ -232,7 +232,19 @@ setExtension ext = do 6 | -- (newDynFlags). It returns a list of error messages. 7 | setFlags :: GhcMonad m => [String] -> m [String] 8 | setFlags ext = do 9 | - -- Try to parse flags. 
10 | + -- Interactive flags first 11 | + -- Warnings and unrecognized flags will be handled when parsing again with session flags below 12 | + iflags <- getInteractiveDynFlags 13 | +#if MIN_VERSION_ghc(9,2,0) 14 | + logger <- getLogger 15 | + (iflags', _, _) <- parseDynamicFlags logger iflags (map noLoc ext) 16 | +#else 17 | + (iflags', _, _) <- parseDynamicFlags iflags (map noLoc ext) 18 | +#endif 19 | + let irestoredPkgs = iflags' { packageFlags = packageFlags iflags } 20 | + GHC.setInteractiveDynFlags irestoredPkgs 21 | + 22 | + -- Session flags next 23 | flags <- getSessionDynFlags 24 | #if MIN_VERSION_ghc(9,2,0) 25 | logger <- getLogger 26 | @@ -240,11 +252,9 @@ setFlags ext = do 27 | #else 28 | (flags', unrecognized, warnings) <- parseDynamicFlags flags (map noLoc ext) 29 | #endif 30 | - 31 | -- First, try to check if this flag matches any extension name. 32 | let restoredPkgs = flags' { packageFlags = packageFlags flags } 33 | _ <- GHC.setProgramDynFlags restoredPkgs 34 | - GHC.setInteractiveDynFlags restoredPkgs 35 | 36 | -- Create the parse errors. 37 | let noParseErrs = map (("Could not parse: " ++) . 
unLoc) unrecognized 38 | -------------------------------------------------------------------------------- /patches/ihaskell-fixup-set-0.10.2.1.patch: -------------------------------------------------------------------------------- 1 | diff --git a/src/IHaskell/Eval/Util.hs b/src/IHaskell/Eval/Util.hs 2 | index 945033a..c0de7d3 100644 3 | --- a/src/IHaskell/Eval/Util.hs 4 | +++ b/src/IHaskell/Eval/Util.hs 5 | @@ -57,6 +57,7 @@ import InstEnv (ClsInst(..)) 6 | import Unify (tcMatchTys) 7 | import qualified Pretty 8 | import qualified Outputable as O 9 | +import DynamicLoading (initializePlugins) 10 | #endif 11 | #if MIN_VERSION_ghc(8,6,0) 12 | #else 13 | @@ -220,10 +221,10 @@ setFlags ext = do 14 | flags <- getSessionDynFlags 15 | (flags', unrecognized, warnings) <- parseDynamicFlags flags (map noLoc ext) 16 | 17 | - -- First, try to check if this flag matches any extension name. 18 | - let restoredPkgs = flags' { packageFlags = packageFlags flags } 19 | - _ <- GHC.setProgramDynFlags restoredPkgs 20 | - GHC.setInteractiveDynFlags restoredPkgs 21 | + hsc_env <- GHC.getSession 22 | + flags'' <- liftIO (initializePlugins hsc_env (flags' { packageFlags = packageFlags flags })) 23 | + _ <- GHC.setProgramDynFlags flags'' 24 | + GHC.setInteractiveDynFlags flags'' 25 | 26 | -- Create the parse errors. 27 | let noParseErrs = map (("Could not parse: " ++) . 
unLoc) unrecognized 28 | @@ -232,7 +233,7 @@ setFlags ext = do 29 | #else 30 | allWarns = map unLoc warnings ++ 31 | #endif 32 | - ["-package not supported yet" | packageFlags flags /= packageFlags flags'] 33 | + ["-package not supported yet" | packageFlags flags /= packageFlags flags''] 34 | warnErrs = map ("Warning: " ++) allWarns 35 | return $ noParseErrs ++ warnErrs 36 | 37 | -------------------------------------------------------------------------------- /dataframe/src/Barbies/TH/Config.hs: -------------------------------------------------------------------------------- 1 | module Barbies.TH.Config 2 | ( DeclareBareBConfig(..) 3 | , classic 4 | , passthrough 5 | ) where 6 | import Language.Haskell.TH 7 | 8 | -- | Keep it in a separate module until NoFieldSelectors gets widespread 9 | data DeclareBareBConfig = DeclareBareBConfig 10 | { friends :: [Name] -- ^ Members with these types won't be wrapped with 'Wear' 11 | , bareName :: String -> Maybe String 12 | -- ^ generate a type synonym for the 'Barbies.Bare.Bare' type? 13 | , coveredName :: String -> Maybe String 14 | -- ^ generate a type synonym for the 'Barbies.Bare.Covered' type? 15 | , barbieName :: String -> String 16 | -- ^ modify the name of the datatype 17 | , switchName :: Q Name 18 | -- ^ the name of the type parameter to toggle between Bare and covered 19 | , wrapperName :: Q Name 20 | -- ^ the name of the type parameter of the wrapper for each field 21 | } 22 | 23 | -- | Does not define any type synonyms 24 | classic :: DeclareBareBConfig 25 | classic = DeclareBareBConfig 26 | { friends = [] 27 | , bareName = const Nothing 28 | , coveredName = const Nothing 29 | , barbieName = id 30 | , switchName = newName "sw" 31 | , wrapperName = newName "h" 32 | } 33 | 34 | -- | Defines a synonym for the bare type with the same name. 35 | -- The strippable definition is suffixed by B, and the covered type is suffixed by H. 
36 | passthrough :: DeclareBareBConfig 37 | passthrough = DeclareBareBConfig 38 | { friends = [] 39 | , bareName = Just 40 | , coveredName = Just . (++"H") 41 | , barbieName = (++"B") 42 | , switchName = newName "sw" 43 | , wrapperName = newName "h" 44 | } 45 | -------------------------------------------------------------------------------- /haskell-torch-tools/haskell-torch-tools.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-tools 8 | version: 0.1.0.0 9 | synopsis: Support tools to generate code for haskell-torch 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | subdir: haskell-torch-tools 26 | 27 | executable haskell-torch-tools-generate-ctensor 28 | main-is: GenerateCTensor.hs 29 | other-modules: 30 | Types 31 | Paths_haskell_torch_tools 32 | hs-source-dirs: 33 | app 34 | src 35 | ghc-options: -threaded -rtsopts -with-rtsopts=-N -main-is GenerateCTensor 36 | build-depends: 37 | aeson 38 | , base >=4.7 && <5 39 | , bytestring 40 | , containers 41 | , directory 42 | , docopt 43 | , extra 44 | , filepath 45 | , inline-c 46 | , inline-c-cpp 47 | , lens 48 | , lens-aeson 49 | , stache 50 | , template-haskell 51 | , text 52 | , unordered-containers 53 | , vector 54 | , yaml 55 | default-language: Haskell2010 56 | -------------------------------------------------------------------------------- 
/haskell-torch-tensorboard-proto/proto/tensorboard/src/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorShapeProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | package tensorboard; 10 | 11 | // Dimensions of a tensor. 12 | message TensorShapeProto { 13 | // One dimension of the tensor. 14 | message Dim { 15 | // Size of the tensor in that dimension. 16 | // This value must be >= -1, but values of -1 are reserved for "unknown" 17 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers 18 | // that work with TensorShapeProto may fail at runtime when deserializing 19 | // a TensorShapeProto containing a dim value of -1. 20 | int64 size = 1; 21 | 22 | // Optional name of the tensor dimension. 23 | string name = 2; 24 | }; 25 | 26 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 27 | // for a 30 x 40 2D tensor. If an entry has size -1, this 28 | // corresponds to a dimension of unknown size. The names are 29 | // optional. 30 | // 31 | // The order of entries in "dim" matters: It indicates the layout of the 32 | // values in the tensor in-memory representation. 33 | // 34 | // The first entry in "dim" is the outermost dimension used to layout the 35 | // values, the last entry is the innermost dimension. This matches the 36 | // in-memory layout of RowMajor Eigen tensors. 37 | // 38 | // If "dim.size()" > 0, "unknown_rank" must be false. 39 | repeated Dim dim = 2; 40 | 41 | // If true, the number of dimensions in the shape is unknown. 42 | // 43 | // If true, "dim.size()" must be 0. 
44 | bool unknown_rank = 3; 45 | }; 46 | -------------------------------------------------------------------------------- /haskell-torch-imagemagick/haskell-torch-imagemagick.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-imagemagick 8 | version: 0.1.0.5 9 | synopsis: Basic image loading/saving with ImageMagick, minimal bindings for Haskell-Torch 10 | description: Please see the README on Github at 11 | category: Foreign 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | Foreign.ImageMagick 29 | Foreign.ImageMagick.Types 30 | other-modules: 31 | Paths_haskell_torch_imagemagick 32 | hs-source-dirs: 33 | src 34 | cpp-options: -DMAGICKCORE_QUANTUM_DEPTH=16 -DMAGICKCORE_HDRI_ENABLE=0 35 | include-dirs: 36 | /nix/store/crrwcbzgbbcjivf36kmminsmv9a0jv7v-imagemagick-7.1.0-4-dev/include/ImageMagick 37 | extra-libraries: 38 | MagickWand-7.Q16HDRI 39 | MagickCore-7.Q16HDRI 40 | build-depends: 41 | aeson 42 | , base >=4.7 && <5 43 | , bytestring 44 | , containers 45 | , directory 46 | , extra 47 | , filepath 48 | , half 49 | , ieee754 50 | , inline-c 51 | , inline-c-cpp 52 | , lifted-base 53 | , mtl 54 | , template-haskell 55 | , text 56 | , transformers 57 | , vector 58 | default-language: Haskell2010 59 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/haskell-torch-cbindings.cabal: 
-------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-cbindings 8 | version: 0.8.0.0 9 | synopsis: Bindings to the C PyTorch library for Haskell-Torch 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | flag cuda 27 | description: If your pytorch is CUDA-enabled, you can set this 28 | manual: True 29 | default: False 30 | 31 | library 32 | exposed-modules: 33 | Torch.C.CUDA 34 | Torch.C.Generator 35 | Torch.C.Language 36 | Torch.C.Scalar 37 | Torch.C.Tensor 38 | Torch.C.Types 39 | Torch.C.Variable 40 | other-modules: 41 | Paths_haskell_torch_cbindings 42 | hs-source-dirs: 43 | src 44 | extra-libraries: 45 | stdc++ 46 | hdf5 47 | c10 48 | torch 49 | torch_cpu 50 | build-depends: 51 | base >=4.7 && <5 52 | , bytestring 53 | , containers 54 | , extra 55 | , half 56 | , ieee754 57 | , inline-c 58 | , inline-c-cpp 59 | , safe-exceptions 60 | , template-haskell 61 | , text 62 | , vector 63 | if flag(cuda) 64 | cpp-options: -DWITH_CUDA 65 | extra-libraries: 66 | torch_cuda 67 | default-language: Haskell2010 68 | -------------------------------------------------------------------------------- /.hlint.yaml: -------------------------------------------------------------------------------- 1 | # HLint configuration file 2 | # https://github.com/ndmitchell/hlint 3 | ########################## 4 | 5 | # This file contains a template configuration file, 
which is typically 6 | # placed as .hlint.yaml in the root of your project 7 | 8 | 9 | # Specify additional command line arguments 10 | # 11 | # - arguments: [--color, --cpp-simple, -XQuasiQuotes] 12 | 13 | 14 | # Control which extensions/flags/modules/functions can be used 15 | # 16 | # - extensions: 17 | # - default: false # all extension are banned by default 18 | # - name: [PatternGuards, ViewPatterns] # only these listed extensions can be used 19 | # - {name: CPP, within: CrossPlatform} # CPP can only be used in a given module 20 | # 21 | # - flags: 22 | # - {name: -w, within: []} # -w is allowed nowhere 23 | # 24 | # - modules: 25 | # - {name: [Data.Set, Data.HashSet], as: Set} # if you import Data.Set qualified, it must be as 'Set' 26 | # - {name: Control.Arrow, within: []} # Certain modules are banned entirely 27 | # 28 | # - functions: 29 | # - {name: unsafePerformIO, within: []} # unsafePerformIO can only appear in no modules 30 | 31 | 32 | # Add custom hints for this project 33 | # 34 | # Will suggest replacing "wibbleMany [myvar]" with "wibbleOne myvar" 35 | # - error: {lhs: "wibbleMany [x]", rhs: wibbleOne x} 36 | 37 | 38 | # Turn on hints that are off by default 39 | # 40 | # Ban "module X(module X) where", to require a real export list 41 | # - warn: {name: Use explicit module export list} 42 | # 43 | # Replace a $ b $ c with a . 
b $ c 44 | # - group: {name: dollar, enabled: true} 45 | 46 | # Generalise map to fmap, ++ to <> 47 | # - group: {name: generalise, enabled: true} 48 | 49 | 50 | # Ignore some builtin hints 51 | - ignore: {name: Eta reduce} 52 | - ignore: {name: Functor law} 53 | - ignore: {name: Reduce duplication} 54 | # - ignore: {name: Use let} 55 | # - ignore: {name: Use const, within: SpecialModule} # Only within certain modules 56 | 57 | 58 | # Define some custom infix operators 59 | # - fixity: infixr 3 ~^#^~ 60 | -------------------------------------------------------------------------------- /simplify-nat-algebra-plugin/simplify-nat-algebra-plugin.cabal: -------------------------------------------------------------------------------- 1 | name: simplify-nat-algebra-plugin 2 | version: 0.1 3 | synopsis: Computer algebra to simplify long type-level Nat expressions 4 | description: Operations on datatypes with type-level dimensions often lead to very long expressions that can easily be simplified.
5 | homepage: http://github.com/abarbu/haskell-torch/simplify-nat-algebra-plugin 6 | bug-reports: http://github.com/abarbu/haskell-torch/simplify-nat-algebra-plugin 7 | license: BSD2 8 | license-file: LICENSE 9 | author: Andrei Barbu 10 | maintainer: andrei@0xab.com 11 | copyright: Copyright © 2020 Andrei Barbu 12 | category: Type System 13 | build-type: Simple 14 | extra-source-files: README.md 15 | cabal-version: >=1.10 16 | 17 | source-repository head 18 | type: git 19 | location: https://github.com/abarbu/haskell-torch/simplify-nat-algebra-plugin.git 20 | 21 | flag deverror 22 | description: 23 | Enables `-Werror` for development mode and TravisCI 24 | default: False 25 | manual: True 26 | 27 | library 28 | exposed-modules: Plugin.SimplifyNat 29 | build-depends: base >= 4.8, 30 | ghc >= 8.10.2, 31 | containers 32 | hs-source-dirs: src 33 | default-language: Haskell2010 34 | 35 | test-suite test-simplify-nat-algebra-plugin 36 | type: exitcode-stdio-1.0 37 | main-is: Main.hs 38 | build-depends: base >= 4.8 && <5, 39 | simplify-nat-algebra-plugin, 40 | ghc-typelits-knownnat, 41 | tasty >= 0.10, 42 | tasty-hunit >= 0.9, 43 | template-haskell >= 2.11.0.0 44 | hs-source-dirs: tests 45 | default-language: Haskell2010 46 | if flag(deverror) 47 | ghc-options: -O0 -dcore-lint 48 | -------------------------------------------------------------------------------- /dataframe/src/Barbies/FieldName.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, ExtendedDefaultRules, FlexibleContexts, FlexibleInstances, GADTs, OverloadedStrings, MultiParamTypeClasses #-} 2 | {-# LANGUAGE PolyKinds, QuasiQuotes, RankNTypes, ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies, DeriveAnyClass #-} 3 | {-# LANGUAGE TypeFamilyDependencies, TypeInType, TypeOperators, UndecidableInstances, StandaloneDeriving, DeriveGeneric, QuantifiedConstraints #-} 4 | {-# LANGUAGE CPP, DuplicateRecordFields, 
RecordWildCards, FunctionalDependencies, DeriveDataTypeable, DefaultSignatures #-} 5 | module Barbies.FieldName where 6 | import GHC.TypeLits(Symbol) 7 | import GHC.Generics (Generic) 8 | import qualified GHC.Generics as GHC 9 | import Data.Proxy 10 | import Control.Applicative 11 | import Barbies 12 | import Data.String(IsString(..)) 13 | import GHC.TypeLits(KnownSymbol(..),symbolVal) 14 | 15 | class FieldNameB b where 16 | bfieldName' :: b p 17 | 18 | instance FieldNameB GHC.U1 where 19 | bfieldName' = GHC.U1 20 | 21 | instance (FieldNameB t) => FieldNameB (GHC.M1 GHC.C m t) where 22 | bfieldName' = GHC.M1 bfieldName' 23 | 24 | instance (FieldNameB t) => FieldNameB (GHC.M1 GHC.D m t) where 25 | bfieldName' = GHC.M1 bfieldName' 26 | 27 | instance (FieldNameB f, FieldNameB g) => FieldNameB (f GHC.:*: g) where 28 | bfieldName' = bfieldName' GHC.:*: bfieldName' 29 | 30 | instance (m ~ 'GHC.MetaSel ('Just name) su ss ds, IsString a, KnownSymbol name) => FieldNameB (GHC.M1 GHC.S m (GHC.Rec1 (Const a))) where 31 | bfieldName' = GHC.M1 $ GHC.Rec1 $ Const $ fromString $ symbolVal (Proxy :: Proxy name) 32 | 33 | instance (m ~ 'GHC.MetaSel ('Just name) su ss ds, IsString a, KnownSymbol name) => FieldNameB (GHC.M1 GHC.S m (GHC.Rec0 (Const a x))) where 34 | bfieldName' = GHC.M1 $ GHC.K1 $ Const $ fromString $ symbolVal (Proxy :: Proxy name) 35 | 36 | bfieldName :: (Generic (b (Const a)), FieldNameB (GHC.Rep (b (Const a)))) => IsString a => b (Const a) 37 | bfieldName = GHC.to bfieldName' 38 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/README.md: -------------------------------------------------------------------------------- 1 | # An intro to Haskell-Torch 2 | 3 | Tutorials in Intro directory are reimplementations of the [yunjey 4 | tutorials](https://github.com/yunjey/pytorch-tutorial). All credit goes to the 5 | original author. 
6 | 7 | They will teach you how to train your own networks with Haskell-Torch on 8 | standard datasets. They also serve as an end-to-end test and benchmark 9 | suite. We tried to translate the original code rather faithfully. 10 | 11 | 1. [The basics](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T01_Basics.hs) 12 | 2. [Linear regression](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T02_LinearRegression.hs) 13 | 3. [Logistic regression](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T03_LogisticRegression.hs) 14 | 4. [Feedforward networks](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T04_FeedforwardNN.hs) 15 | 5. [CNNs](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T05_CNN.hs) 16 | 6. [ResNet](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T06_ResNet.hs) 17 | 7. [RNNs](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T07_RNN.hs) 18 | 8. [BiRNNs](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T08_BiRNN.hs) 19 | 9. [Generate language](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T09_LanguageModel.hs) 20 | 10. [GANs](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T10_GAN.hs) 21 | 11.
[VAE](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Intro/T11_VAE.hs) 22 | 23 | We also have a basic introduction to how you should integrate with 24 | [Tensorboard](https://github.com/abarbu/haskell-torch/blob/master/haskell-torch/src/Torch/Tutorial/Tensorboard.hs) 25 | -------------------------------------------------------------------------------- /interpolateIO/test/Data/String/InterpolateIO/UtilSpec.hs: -------------------------------------------------------------------------------- 1 | module Data.String.InterpolateIO.UtilSpec (main, spec) where 2 | 3 | import Prelude () 4 | import Prelude.Compat 5 | 6 | import Test.Hspec 7 | import Test.QuickCheck 8 | 9 | import Data.String.InterpolateIO.Util 10 | 11 | main :: IO () 12 | main = hspec spec 13 | 14 | emptyLine :: Gen String 15 | emptyLine = (++ "\n") <$> listOf (elements " \t") 16 | 17 | spec :: Spec 18 | spec = do 19 | describe "unindent" $ do 20 | it "removes indentation" $ do 21 | let xs = " foo\n bar\n baz \n" 22 | unindent xs `shouldBe` " foo\nbar\n baz \n" 23 | 24 | it "removes the first line of the string if it is empty" $ do 25 | forAll emptyLine $ \xs -> do 26 | let ys = " foo\nbar\n baz\n" 27 | unindent (xs ++ ys) `shouldBe` ys 28 | 29 | it "does not affect additional empty lines at the beginning" $ do 30 | unindent " \n \nfoo" `shouldBe` " \nfoo" 31 | 32 | it "empties the last line if it only consists of spaces" $ do 33 | let xs = "foo\n " 34 | unindent xs `shouldBe` "foo\n" 35 | 36 | it "does not affect other whitespace lines at the end" $ do 37 | unindent "foo\n \n " `shouldBe` "foo\n \n" 38 | 39 | it "disregards empty lines when calculating indentation" $ do 40 | let xs = " foo\n\n \n bar\n" 41 | unindent xs `shouldBe` "foo\n\n\nbar\n" 42 | 43 | it "correctly handles strings that do not end with a newline" $ do 44 | let xs = "foo" 45 | unindent xs `shouldBe` xs 46 | 47 | it "does not affect lines consisting of whitespace (apart from unindenting)" $ do 48 | 
unindent " foo\n \n bar" `shouldBe` "foo\n \nbar" 49 | 50 | it "is total" $ do 51 | property $ \xs -> length (unindent xs) `shouldSatisfy` (>= 0) 52 | 53 | context "when all lines are empty" $ do 54 | it "does not unindent at all" $ do 55 | forAll emptyLine $ \x -> (forAll $ listOf emptyLine) $ \xs -> do 56 | let ys = concat xs 57 | unindent (x ++ ys) `shouldBe` ys 58 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/types.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TypesProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // LINT.IfChange 10 | enum DataType { 11 | // Not a legal value for DataType. Used to indicate a DataType field 12 | // has not been set. 13 | DT_INVALID = 0; 14 | 15 | // Data types that all computation devices are expected to be 16 | // capable to support. 17 | DT_FLOAT = 1; 18 | DT_DOUBLE = 2; 19 | DT_INT32 = 3; 20 | DT_UINT8 = 4; 21 | DT_INT16 = 5; 22 | DT_INT8 = 6; 23 | DT_STRING = 7; 24 | DT_COMPLEX64 = 8; // Single-precision complex 25 | DT_INT64 = 9; 26 | DT_BOOL = 10; 27 | DT_QINT8 = 11; // Quantized int8 28 | DT_QUINT8 = 12; // Quantized uint8 29 | DT_QINT32 = 13; // Quantized int32 30 | DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. 31 | DT_QINT16 = 15; // Quantized int16 32 | DT_QUINT16 = 16; // Quantized uint16 33 | DT_UINT16 = 17; 34 | DT_COMPLEX128 = 18; // Double-precision complex 35 | DT_HALF = 19; 36 | DT_RESOURCE = 20; 37 | 38 | // TODO(josh11b): DT_GENERIC_PROTO = ??; 39 | // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? 40 | 41 | // Do not use! These are only for parameters. Every enum above 42 | // should have a corresponding value below (verified by types_test). 
43 | DT_FLOAT_REF = 101; 44 | DT_DOUBLE_REF = 102; 45 | DT_INT32_REF = 103; 46 | DT_UINT8_REF = 104; 47 | DT_INT16_REF = 105; 48 | DT_INT8_REF = 106; 49 | DT_STRING_REF = 107; 50 | DT_COMPLEX64_REF = 108; 51 | DT_INT64_REF = 109; 52 | DT_BOOL_REF = 110; 53 | DT_QINT8_REF = 111; 54 | DT_QUINT8_REF = 112; 55 | DT_QINT32_REF = 113; 56 | DT_BFLOAT16_REF = 114; 57 | DT_QINT16_REF = 115; 58 | DT_QUINT16_REF = 116; 59 | DT_UINT16_REF = 117; 60 | DT_COMPLEX128_REF = 118; 61 | DT_HALF_REF = 119; 62 | DT_RESOURCE_REF = 120; 63 | } 64 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) 65 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-tensorboard-proto 2 | version: 0.1.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Talk to tensorboard from Haskell; the protobuf bindings 12 | 13 | # To avoid duplicated efforts in documentation and dealing with the 14 | # complications of embedding Haddock markup inside cabal files, it is 15 | # common to point users to the README.md file. 
16 | description: Please see the README on Github at 17 | 18 | dependencies: 19 | - base >= 4.7 && < 5 20 | 21 | extra-source-files: proto/**/*.proto 22 | 23 | custom-setup: 24 | dependencies: 25 | - base 26 | - Cabal 27 | - proto-lens-setup 28 | 29 | library: 30 | source-dirs: src 31 | dependencies: 32 | - data-default 33 | - microlens 34 | - proto-lens 35 | - proto-lens-runtime 36 | - proto-lens-protoc 37 | - text 38 | other-modules: 39 | - Proto.Tensorboard.Src.Summary 40 | - Proto.Tensorboard.Src.Summary_Fields 41 | - Proto.Tensorboard.Src.Tensor 42 | - Proto.Tensorboard.Src.Tensor_Fields 43 | - Proto.Tensorboard.Src.Types 44 | - Proto.Tensorboard.Src.Types_Fields 45 | - Proto.Tensorboard.Src.ResourceHandle 46 | - Proto.Tensorboard.Src.ResourceHandle_Fields 47 | - Proto.Tensorboard.Src.TensorShape 48 | - Proto.Tensorboard.Src.TensorShape_Fields 49 | - Proto.Tensorboard.Src.Event 50 | - Proto.Tensorboard.Src.Event_Fields 51 | - Proto.Tensorboard.Src.Graph 52 | - Proto.Tensorboard.Src.Graph_Fields 53 | - Proto.Tensorboard.Src.NodeDef 54 | - Proto.Tensorboard.Src.NodeDef_Fields 55 | - Proto.Tensorboard.Src.Versions 56 | - Proto.Tensorboard.Src.Versions_Fields 57 | - Proto.Tensorboard.Src.AttrValue 58 | - Proto.Tensorboard.Src.AttrValue_Fields 59 | -------------------------------------------------------------------------------- /dataframe/package.yaml: -------------------------------------------------------------------------------- 1 | name: dataframe 2 | version: 0.8.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Haskell data frames, the equivalent of Pandas 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | # To avoid duplicated efforts in documentation and dealing with the 17 | # complications of 
embedding Haddock markup inside cabal files, it is 18 | # common to point users to the README.md file. 19 | description: Please see the README on Github at 20 | 21 | dependencies: 22 | - base >= 4.7 && < 5 23 | 24 | library: 25 | source-dirs: src 26 | dependencies: 27 | - aeson 28 | - array 29 | - binary 30 | - bytestring 31 | - containers 32 | - data-default 33 | - directory 34 | - extra 35 | - filepath 36 | - generics-eot 37 | - hashable 38 | - hashtables 39 | - haskell-src-exts 40 | - haskell-src-meta 41 | - hostname 42 | - ieee754 43 | - interpolateIO 44 | - matplotlib 45 | - megaparsec 46 | - microlens 47 | - monad-control 48 | - monad-loops 49 | - mtl 50 | - parser-combinators 51 | - pipes 52 | - pipes-aeson 53 | - pipes-bytestring 54 | - pipes-concurrency 55 | - pipes-csv 56 | - pipes-extras 57 | - pipes-group 58 | - pipes-parse 59 | - pipes-safe 60 | - proto-lens 61 | - random 62 | - safe-exceptions 63 | - shelly 64 | - singletons 65 | - statistics 66 | - string-qq 67 | - syb 68 | - template-haskell 69 | - text 70 | - time 71 | - unix 72 | - vector 73 | - yaml 74 | - zlib 75 | - monad-logger 76 | - temporary 77 | - generics-sop 78 | - deepseq 79 | - ghc-prim 80 | - barbies 81 | - vector-algorithms 82 | - prettyprinter 83 | - split 84 | - pptable 85 | - lens 86 | - boxes 87 | - cassava 88 | 89 | -------------------------------------------------------------------------------- /haskell-torch-models/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-models 2 | version: 0.8.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Deep learning in Haskell on top of Torch and PyTorch 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | # To avoid 
duplicated efforts in documentation and dealing with the 17 | # complications of embedding Haddock markup inside cabal files, it is 18 | # common to point users to the README.md file. 19 | description: Please see the README on Github at 20 | 21 | dependencies: 22 | - base >= 4.7 && < 5 23 | 24 | library: 25 | source-dirs: src 26 | dependencies: 27 | - haskell-torch 28 | - haskell-torch-datasets 29 | - haskell-torch-imagemagick 30 | - haskell-torch-matio 31 | - haskell-torch-tensorboard-proto 32 | - aeson 33 | - array 34 | - binary 35 | - bytestring 36 | - containers 37 | - data-default 38 | - directory 39 | - extra 40 | - filepath 41 | - generics-eot 42 | - ghc-typelits-knownnat 43 | - ghc-typelits-natnormalise 44 | - half 45 | - hashable 46 | - hashtables 47 | - haskell-src-exts 48 | - haskell-src-meta 49 | - hostname 50 | - ieee754 51 | - interpolateIO 52 | - matplotlib 53 | - megaparsec 54 | - microlens 55 | - monad-control 56 | - monad-loops 57 | - mtl 58 | - parser-combinators 59 | - pipes 60 | - pipes-aeson 61 | - pipes-bytestring 62 | - pipes-concurrency 63 | - pipes-csv 64 | - pipes-extras 65 | - pipes-group 66 | - pipes-parse 67 | - pipes-safe 68 | - proto-lens 69 | - random 70 | - safe-exceptions 71 | - shelly 72 | - singletons 73 | - statistics 74 | - string-qq 75 | - syb 76 | - template-haskell 77 | - text 78 | - time 79 | - unix 80 | - vector 81 | - yaml 82 | - zlib 83 | - monad-logger 84 | - docopt 85 | - default-type-plugin 86 | - simplify-nat-algebra-plugin 87 | - temporary 88 | -------------------------------------------------------------------------------- /haskell-torch-examples/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-examples 2 | version: 0.8.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: 
https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Examples of how to use Haskell-Torch 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | # To avoid duplicated efforts in documentation and dealing with the 17 | # complications of embedding Haddock markup inside cabal files, it is 18 | # common to point users to the README.md file. 19 | description: Please see the README on Github at 20 | 21 | dependencies: 22 | - base >= 4.7 && < 5 23 | 24 | library: 25 | source-dirs: src 26 | dependencies: 27 | - haskell-torch 28 | - haskell-torch-models 29 | - haskell-torch-datasets 30 | - haskell-torch-imagemagick 31 | - haskell-torch-matio 32 | - haskell-torch-tensorboard-proto 33 | - aeson 34 | - array 35 | - binary 36 | - bytestring 37 | - containers 38 | - data-default 39 | - directory 40 | - extra 41 | - filepath 42 | - generics-eot 43 | - ghc-typelits-knownnat 44 | - ghc-typelits-natnormalise 45 | - half 46 | - hashable 47 | - hashtables 48 | - haskell-src-exts 49 | - haskell-src-meta 50 | - hostname 51 | - ieee754 52 | - interpolateIO 53 | - matplotlib 54 | - megaparsec 55 | - microlens 56 | - monad-control 57 | - monad-loops 58 | - mtl 59 | - parser-combinators 60 | - pipes 61 | - pipes-aeson 62 | - pipes-bytestring 63 | - pipes-concurrency 64 | - pipes-csv 65 | - pipes-extras 66 | - pipes-group 67 | - pipes-parse 68 | - pipes-safe 69 | - proto-lens 70 | - random 71 | - safe-exceptions 72 | - shelly 73 | - singletons 74 | - statistics 75 | - string-qq 76 | - syb 77 | - template-haskell 78 | - text 79 | - time 80 | - unix 81 | - vector 82 | - yaml 83 | - zlib 84 | - monad-logger 85 | - docopt 86 | - default-type-plugin 87 | - simplify-nat-algebra-plugin 88 | - gym 89 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package 
tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "GraphProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboard/src/node_def.proto"; 10 | //import "tensorflow/core/framework/function.proto"; 11 | import "tensorboard/src/versions.proto"; 12 | 13 | // Represents the graph of operations 14 | message GraphDef { 15 | repeated NodeDef node = 1; 16 | 17 | // Compatibility versions of the graph. See core/public/version.h for version 18 | // history. The GraphDef version is distinct from the TensorFlow version, and 19 | // each release of TensorFlow will support a range of GraphDef versions. 20 | VersionDef versions = 4; 21 | 22 | // Deprecated single version field; use versions above instead. Since all 23 | // GraphDef changes before "versions" was introduced were forward 24 | // compatible, this field is entirely ignored. 25 | int32 version = 3 [deprecated = true]; 26 | 27 | // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. 28 | // 29 | // "library" provides user-defined functions. 30 | // 31 | // Naming: 32 | // * library.function.name are in a flat namespace. 33 | // NOTE: We may need to change it to be hierarchical to support 34 | // different orgs. E.g., 35 | // { "/google/nn", { ... }}, 36 | // { "/google/vision", { ... }} 37 | // { "/org_foo/module_bar", { ... }} 38 | // map named_lib; 39 | // * If node[i].op is the name of one function in "library", 40 | // node[i] is deemed as a function call. Otherwise, node[i].op 41 | // must be a primitive operation supported by the runtime. 42 | // 43 | // 44 | // Function call semantics: 45 | // 46 | // * The callee may start execution as soon as some of its inputs 47 | // are ready. The caller may want to use Tuple() mechanism to 48 | // ensure all inputs are ready in the same time. 
49 | // 50 | // * The consumer of return values may start executing as soon as 51 | // the return values the consumer depends on are ready. The 52 | // consumer may want to use Tuple() mechanism to ensure the 53 | // consumer does not start until all return values of the callee 54 | // function are ready. 55 | //FunctionDefLibrary library = 2; 56 | }; 57 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | from recommonmark.parser import CommonMarkParser 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'Haskell-Torch' 22 | copyright = '2019, Andrei Barbu' 23 | author = 'Andrei Barbu' 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | ] 33 | 34 | # Add any paths that contain templates here, relative to this directory. 
35 | templates_path = ['_templates'] 36 | 37 | # List of patterns, relative to source directory, that match files and 38 | # directories to ignore when looking for source files. 39 | # This pattern also affects html_static_path and html_extra_path. 40 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 41 | 42 | 43 | # -- Options for HTML output ------------------------------------------------- 44 | 45 | # The theme to use for HTML and HTML Help pages. See the documentation for 46 | # a list of builtin themes. 47 | # 48 | html_theme = 'sphinx_rtd_theme' 49 | 50 | # Add any paths that contain custom static files (such as style sheets) here, 51 | # relative to this directory. They are copied after the builtin static files, 52 | # so a file named "default.css" will overwrite the builtin "default.css". 53 | html_static_path = ['_static'] 54 | 55 | # The suffix(es) of source filenames. 56 | # You can specify multiple suffix as a list of string: 57 | # 58 | source_suffix = ['.rst', '.md', '.lhs'] 59 | 60 | # The master toctree document. 
61 | master_doc = 'index' 62 | 63 | source_parsers = { 64 | '.lhs': CommonMarkParser, 65 | } 66 | 67 | -------------------------------------------------------------------------------- /haskell-torch-datasets/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch-datasets 2 | version: 0.8.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Deep learning in Haskell on top of Torch and PyTorch 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | # To avoid duplicated efforts in documentation and dealing with the 17 | # complications of embedding Haddock markup inside cabal files, it is 18 | # common to point users to the README.md file. 19 | description: Please see the README on Github at 20 | 21 | dependencies: 22 | - base >= 4.7 && < 5 23 | 24 | library: 25 | source-dirs: src 26 | dependencies: 27 | - haskell-torch 28 | - haskell-torch-imagemagick 29 | - haskell-torch-matio 30 | - haskell-torch-tensorboard-proto 31 | - aeson 32 | - array 33 | - binary 34 | - bytestring 35 | - containers 36 | - data-default 37 | - directory 38 | - extra 39 | - filepath 40 | - generics-eot 41 | - ghc-typelits-knownnat 42 | - ghc-typelits-natnormalise 43 | - half 44 | - hashable 45 | - hashtables 46 | - haskell-src-exts 47 | - haskell-src-meta 48 | - hostname 49 | - ieee754 50 | - interpolateIO 51 | - matplotlib 52 | - megaparsec 53 | - microlens 54 | - monad-control 55 | - monad-loops 56 | - mtl 57 | - parser-combinators 58 | - pipes 59 | - pipes-aeson 60 | - pipes-bytestring 61 | - pipes-concurrency 62 | - pipes-csv 63 | - pipes-extras 64 | - pipes-group 65 | - pipes-parse 66 | - pipes-safe 67 | - proto-lens 68 | - random 69 | - safe-exceptions 70 | - shelly 71 | - 
singletons 72 | - statistics 73 | - string-qq 74 | - syb 75 | - template-haskell 76 | - text 77 | - time 78 | - unix 79 | - vector 80 | - yaml 81 | - zlib 82 | - monad-logger 83 | - docopt 84 | - default-type-plugin 85 | - simplify-nat-algebra-plugin 86 | - temporary 87 | - generics-sop 88 | - deepseq 89 | - ghc-prim 90 | - barbies 91 | - vector-algorithms 92 | - prettyprinter 93 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/event.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "EventProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.util"; 8 | 9 | import "tensorboard/src/summary.proto"; 10 | 11 | // Protocol buffer representing an event that happened during 12 | // the execution of a Brain model. 13 | message Event { 14 | // Timestamp of the event. 15 | double wall_time = 1; 16 | 17 | // Global step of the event. 18 | int64 step = 2; 19 | 20 | oneof what { 21 | // An event file was started, with the specified version. 22 | // This is use to identify the contents of the record IO files 23 | // easily. Current version is "brain.Event:2". All versions 24 | // start with "brain.Event:". 25 | string file_version = 3; 26 | // An encoded version of a GraphDef. 27 | bytes graph_def = 4; 28 | // A summary was generated. 29 | Summary summary = 5; 30 | // The user output a log message. Not all messages are logged, only ones 31 | // generated via the Python tensorboard_logging module. 32 | LogMessage log_message = 6; 33 | // The state of the session which can be used for restarting after crashes. 34 | SessionLog session_log = 7; 35 | // The metadata returned by running a session.run() call. 36 | TaggedRunMetadata tagged_run_metadata = 8; 37 | // An encoded version of a MetaGraphDef. 
38 | bytes meta_graph_def = 9; 39 | } 40 | } 41 | 42 | // Protocol buffer used for logging messages to the events file. 43 | message LogMessage { 44 | enum Level { 45 | UNKNOWN = 0; 46 | DEBUG = 10; 47 | INFO = 20; 48 | WARN = 30; 49 | ERROR = 40; 50 | FATAL = 50; 51 | } 52 | Level level = 1; 53 | string message = 2; 54 | } 55 | 56 | // Protocol buffer used for logging session state. 57 | message SessionLog { 58 | enum SessionStatus { 59 | STATUS_UNSPECIFIED = 0; 60 | START = 1; 61 | STOP = 2; 62 | CHECKPOINT = 3; 63 | } 64 | 65 | SessionStatus status = 1; 66 | // This checkpoint_path contains both the path and filename. 67 | string checkpoint_path = 2; 68 | string msg = 3; 69 | } 70 | 71 | // For logging the metadata output for a single session.run() call. 72 | message TaggedRunMetadata { 73 | // Tag name associated with this metadata. 74 | string tag = 1; 75 | // Byte-encoded version of the `RunMetadata` proto in order to allow lazy 76 | // deserialization. 77 | bytes run_metadata = 2; 78 | } 79 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/Intro/T02_LinearRegression.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, DataKinds, ExtendedDefaultRules, FlexibleContexts, FlexibleInstances, GADTs #-} 2 | {-# LANGUAGE OverloadedLists, OverloadedStrings, PolyKinds, QuasiQuotes, RankNTypes, ScopedTypeVariables, TemplateHaskell #-} 3 | {-# LANGUAGE TypeApplications, TypeFamilies, TypeFamilyDependencies, TypeInType, TypeOperators, UndecidableInstances #-} 4 | module Torch.Tutorial.Intro.T02_LinearRegression where 5 | import Data.Default 6 | import Data.String.InterpolateIO 7 | import Graphics.Matplotlib (o1, o2, (%), (@@)) 8 | import qualified Graphics.Matplotlib as M 9 | import Pipes 10 | import qualified Pipes.Prelude as P 11 | import Torch 12 | import Torch.Datasets 13 | 14 | ex = do 15 | -- 
hyperparameters 16 | let epochs = 60 17 | let learningRate = 0.001 18 | -- 19 | let train = 20 | ((yield $ 21 | let x = [3.3, 4.4, 5.5, 6.71, 6.93, 4.168 22 | ,9.779, 6.182, 7.59, 2.167, 7.042 23 | ,10.791, 5.313, 7.997, 3.1] 24 | y = [1.7, 2.76, 2.09, 3.19, 1.694, 1.573 25 | ,3.366, 2.596, 2.53, 1.221, 2.827 26 | ,3.465, 1.65, 2.904, 1.3] 27 | in DataSample @Train () (fromVector x) (fromVector y)) 28 | :: Producer (DataSample 'Train () 29 | (Tensor TDouble KCpu '[15,1]) 30 | (Tensor TDouble KCpu '[15,1])) 31 | IO ()) 32 | w <- gradP 33 | let model = linear (inFeatures_ @1) (outFeatures_ @1) w 34 | let criterion y = mseLoss y def 35 | params <- toParameters w 36 | mapM_ (\epoch -> do 37 | zeroGradients_ params 38 | loss <- lossForEachData (\d -> do 39 | o <- dataObject d 40 | l <- dataLabel d 41 | pred <- model o 42 | criterion l pred) train 43 | backward1 loss False False 44 | _ <- step_ (sgd (def { sgdLearningRate = learningRate }) params) 45 | putStrLn =<< [c|Epoch #{epoch}/#{epochs} loss #{loss}|]) 46 | [0..epochs-1] 47 | (Just e) <- P.head train 48 | xs <- dataObject e >>= toVector 49 | ys <- dataLabel e >>= toVector 50 | ysPred <- dataObject e >>= model >>= toVector 51 | M.onscreen $ M.plot xs ysPred @@ [o1 "go-", o2 "linewidth" 2] 52 | % M.plot xs ys @@ [o1 "ro"] 53 | pure () 54 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export WITH_JUPYTER=NO 4 | export WITH_CUDA=IF_PRESENT 5 | export CONDA_ENV=haskell-torch 6 | export QUICK_GHC=NO 7 | export FAST= 8 | 9 | while test $# -gt 0; do 10 | case "$1" in 11 | -h|--help) 12 | echo "$(basename $0) - set up HaskellTorch" 13 | echo " by default CUDA will be used if present and jupyter will be installed" 14 | echo " " 15 | echo "$(basename $0) [options]" 16 | echo " " 17 | echo "options:" 18 | echo "-h, --help show brief help" 19 | echo "--with-jupyter install 
jupyter" 20 | echo "--without-jupyter don't install jupyter" 21 | echo "--with-cuda install CUDA" 22 | echo "--without-cuda don't install CUDA" 23 | echo "--in-conda-base install in base conda, not haskell-torch" 24 | echo "--quick-ghc fast ghc builds, but the compiler will be rather slow" 25 | echo "--fast like stack --fast disables optimizations for app code" 26 | exit 0 27 | ;; 28 | --with-jupyter) 29 | export WITH_JUPYTER=YES 30 | shift 31 | ;; 32 | --without-jupyter) 33 | export WITH_JUPYTER=NO 34 | shift 35 | ;; 36 | --with-cuda) 37 | export WITH_CUDA=YES 38 | shift 39 | ;; 40 | --without-cuda) 41 | export WITH_CUDA=NO 42 | shift 43 | ;; 44 | --in-conda-base) 45 | export CONDA_ENV=base 46 | shift 47 | ;; 48 | --quick-ghc) 49 | export QUICK_GHC=YES 50 | shift 51 | ;; 52 | --fast) 53 | export FAST=--fast 54 | shift 55 | ;; 56 | *) 57 | break 58 | ;; 59 | esac 60 | done 61 | 62 | bash scripts/setup-initial.sh "$@" 63 | bash scripts/setup-conda.sh "$@" 64 | bash scripts/setup-haskell.sh "$@" 65 | bash scripts/setup-jupyter.sh "$@" 66 | 67 | echo "======================================================================" 68 | echo " Haskell-Torch is set up!" 69 | echo "" 70 | echo "Configured with:" 71 | echo " WITH_JUPYTER=$WITH_JUPYTER" 72 | echo " WITH_CUDA=$WITH_CUDA" 73 | echo " QUICK_GHC=$QUICK_GHC" 74 | echo " FAST=$FAST" 75 | echo "======================================================================" 76 | echo 77 | echo " Check above to see if you have CUDA support" 78 | echo " If you want to regenerate the bindings against a new PyTorch, run make" 79 | echo " Next up activate the conda environment. 
You can later build the code with:" 80 | echo " conda activate haskell-torch && stack build ${FAST}" 81 | -------------------------------------------------------------------------------- /haskell-torch-models/haskell-torch-models.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-models 8 | version: 0.8.0.0 9 | synopsis: Deep learning in Haskell on top of Torch and PyTorch 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | Torch.Models.Vision.AlexNet 29 | Torch.Models.Vision.ResNet 30 | Torch.Models.Vision.VGG19 31 | other-modules: 32 | Paths_haskell_torch_models 33 | hs-source-dirs: 34 | src 35 | build-depends: 36 | aeson 37 | , array 38 | , base >=4.7 && <5 39 | , binary 40 | , bytestring 41 | , containers 42 | , data-default 43 | , default-type-plugin 44 | , directory 45 | , docopt 46 | , extra 47 | , filepath 48 | , generics-eot 49 | , ghc-typelits-knownnat 50 | , ghc-typelits-natnormalise 51 | , half 52 | , hashable 53 | , hashtables 54 | , haskell-src-exts 55 | , haskell-src-meta 56 | , haskell-torch 57 | , haskell-torch-datasets 58 | , haskell-torch-imagemagick 59 | , haskell-torch-matio 60 | , haskell-torch-tensorboard-proto 61 | , hostname 62 | , ieee754 63 | , interpolateIO 64 | , matplotlib 65 | , megaparsec 66 | , microlens 67 | , monad-control 68 | , monad-logger 69 | , monad-loops 
70 | , mtl 71 | , parser-combinators 72 | , pipes 73 | , pipes-aeson 74 | , pipes-bytestring 75 | , pipes-concurrency 76 | , pipes-csv 77 | , pipes-extras 78 | , pipes-group 79 | , pipes-parse 80 | , pipes-safe 81 | , proto-lens 82 | , random 83 | , safe-exceptions 84 | , shelly 85 | , simplify-nat-algebra-plugin 86 | , singletons 87 | , statistics 88 | , string-qq 89 | , syb 90 | , template-haskell 91 | , temporary 92 | , text 93 | , time 94 | , unix 95 | , vector 96 | , yaml 97 | , zlib 98 | default-language: Haskell2010 99 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/Tensorboard.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, CPP, ConstraintKinds, DataKinds, EmptyCase, FlexibleContexts, FlexibleInstances #-} 2 | {-# LANGUAGE FunctionalDependencies, GADTs, KindSignatures, MultiParamTypeClasses, OverloadedLabels, OverloadedStrings #-} 3 | {-# LANGUAGE PartialTypeSignatures, PolyKinds, QuasiQuotes, RankNTypes, ScopedTypeVariables, TemplateHaskell, TypeApplications #-} 4 | {-# LANGUAGE TypeFamilies, TypeFamilyDependencies, TypeInType, TypeOperators, UndecidableInstances #-} 5 | {-# options_ghc -fplugin GHC.TypeLits.KnownNat.Solver #-} 6 | 7 | module Torch.Tutorial.Tensorboard where 8 | import Control.Monad 9 | import Control.Monad.Logger 10 | import Data.Default 11 | import Data.ProtoLens.Labels () 12 | import Torch as T 13 | import Torch.Datasets.Vision.MNIST 14 | 15 | ex = do 16 | sw <- summaryWriter "/tmp/qlog" "tester" 17 | (tr, te) <- mnist "datasets/image/" 18 | (Right test') <- fetchDataset te 19 | let test = transformObjectStream (T.view @'[1,28,28] 20 | >=> greyTensorToImage 21 | >=> randomHorizontalFlip_ 0.5 22 | >=> randomVerticalFlip_ 0.5 23 | >=> randomContrastJitter_ 100 24 | >=> randomBrightnessJitter_ 100 25 | >=> greyImageToTensor @28 @28 @TFloat 26 | ) test' 27 | writeEvent sw (EventMessage 
LevelInfo "Woof") 28 | t0 <- randn @TFloat @KCpu @'[100] 29 | x <- typed @TFloat <$> stored @KCpu <$> sized (size_ @'[10,3]) <$> randn 30 | y <- sized (size_ @'[10,2]) <$> randn 31 | w1 <- noGradP 32 | w2 <- noGradP 33 | let model = linear (inFeatures_ @3) (outFeatures_ @10) w1 34 | >=> relu 35 | >=> linear (inFeatures_ @10) (outFeatures_ @2) w2 36 | let criterion = mseLoss y def 37 | (loss, trace) <- withTracing [AnyTensor x, AnyTensor y] $ do 38 | pred <- model x 39 | criterion pred 40 | printTrace trace 41 | addGraph sw "grph" trace 42 | forEachDataUntil 43 | (\step _ -> pure (step == 10)) 44 | (\_ _ stream -> pure stream) 45 | (\step epoch value -> do 46 | addScalar sw "bloop/blimp" 3.0 47 | addHistogram sw "bloop/h1" t0 48 | t <- dataObject value ..*@ pure 255 49 | addImageGrey sw "bloop/mnist" t 50 | t0 .*=@ 1.2 51 | nextStep sw) 52 | test 53 | pure () 54 | -------------------------------------------------------------------------------- /dataframe/dataframe.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 
4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: dataframe 8 | version: 0.8.0.0 9 | synopsis: Haskell data frames, the equivalent of Pandas 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | Barbies.FieldName 29 | Barbies.TH 30 | Barbies.TH.Config 31 | Generics.SOP.Record 32 | Generics.SOP.Record.Combination 33 | Generics.SOP.Record.SubTyping 34 | Text.Tabulate 35 | other-modules: 36 | Paths_dataframe 37 | hs-source-dirs: 38 | src 39 | build-depends: 40 | aeson 41 | , array 42 | , barbies 43 | , base >=4.7 && <5 44 | , binary 45 | , boxes 46 | , bytestring 47 | , cassava 48 | , containers 49 | , data-default 50 | , deepseq 51 | , directory 52 | , extra 53 | , filepath 54 | , generics-eot 55 | , generics-sop 56 | , ghc-prim 57 | , hashable 58 | , hashtables 59 | , haskell-src-exts 60 | , haskell-src-meta 61 | , hostname 62 | , ieee754 63 | , interpolateIO 64 | , lens 65 | , matplotlib 66 | , megaparsec 67 | , microlens 68 | , monad-control 69 | , monad-logger 70 | , monad-loops 71 | , mtl 72 | , parser-combinators 73 | , pipes 74 | , pipes-aeson 75 | , pipes-bytestring 76 | , pipes-concurrency 77 | , pipes-csv 78 | , pipes-extras 79 | , pipes-group 80 | , pipes-parse 81 | , pipes-safe 82 | , pptable 83 | , prettyprinter 84 | , proto-lens 85 | , random 86 | , safe-exceptions 87 | , shelly 88 | , singletons 89 | , split 90 | , statistics 91 | , string-qq 92 | , syb 93 | , template-haskell 94 | , temporary 95 | , text 96 | , time 97 | , unix 98 | , vector 99 | , vector-algorithms 100 | , yaml 101 | , 
zlib 102 | default-language: Haskell2010 103 | -------------------------------------------------------------------------------- /interpolateIO/interpolateIO.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.31.1. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | -- 7 | -- hash: 9446b16930763d30ef397e2d48f1e06d77b41005867a420a38d4bca2798f56bc 8 | 9 | name: interpolateIO 10 | version: 0.2.0 11 | synopsis: String interpolation in IO based on interpolate 12 | description: String interpolation in IO based on interpolate 13 | category: Data, Text 14 | stability: experimental 15 | homepage: https://github.com/abarbu/interpolateIO#readme 16 | bug-reports: https://github.com/abarbu/interpolateIO/issues 17 | author: Andrei Barbu 18 | maintainer: Andrei Barbu 19 | copyright: (c) 2018 Andrei Barbu, 2013-2015 Simon Hengel 20 | license: MIT 21 | license-file: LICENSE 22 | build-type: Simple 23 | 24 | source-repository head 25 | type: git 26 | location: https://github.com/abarbu/interpolateIO 27 | 28 | library 29 | hs-source-dirs: 30 | src 31 | ghc-options: -Wall 32 | build-depends: 33 | base ==4.* 34 | , generics-eot 35 | , haskell-src-meta >=0.8 36 | , template-haskell 37 | exposed-modules: 38 | Data.String.InterpolateIO 39 | Data.String.InterpolateIO.IsString 40 | Data.String.InterpolateIO.Util 41 | Data.String.InterpolateIO.Internal.Util 42 | Data.String.ShowIO 43 | other-modules: 44 | Data.String.InterpolateIO.Compat 45 | Data.String.InterpolateIO.Parse 46 | Paths_interpolateIO 47 | default-language: Haskell2010 48 | 49 | test-suite spec 50 | type: exitcode-stdio-1.0 51 | main-is: Spec.hs 52 | hs-source-dirs: 53 | src 54 | test 55 | ghc-options: -Wall 56 | build-depends: 57 | QuickCheck 58 | , base ==4.* 59 | , base-compat 60 | , bytestring 61 | , generics-eot 62 | , haskell-src-meta >=0.8 63 | , hspec >=1.5 64 | , quickcheck-instances 
65 | , template-haskell 66 | , text 67 | other-modules: 68 | Data.String.InterpolateIO 69 | Data.String.InterpolateIO.Compat 70 | Data.String.InterpolateIO.Internal.Util 71 | Data.String.InterpolateIO.IsString 72 | Data.String.InterpolateIO.Parse 73 | Data.String.InterpolateIO.Util 74 | Data.String.ShowIO 75 | Data.String.InterpolateIO.Internal.UtilSpec 76 | Data.String.InterpolateIO.IsStringSpec 77 | Data.String.InterpolateIO.ParseSpec 78 | Data.String.InterpolateIO.UtilSpec 79 | Data.String.InterpolateIOSpec 80 | Paths_interpolateIO 81 | default-language: Haskell2010 82 | -------------------------------------------------------------------------------- /interpolateIO/src/Data/String/InterpolateIO/Util.hs: -------------------------------------------------------------------------------- 1 | module Data.String.InterpolateIO.Util (unindent) where 2 | 3 | import Control.Arrow ((>>>)) 4 | import Data.Char 5 | 6 | -- | Remove indentation as much as possible while preserving relative 7 | -- indentation levels. 8 | -- 9 | -- `unindent` is useful in combination with `Data.String.Interpolate.c` to remove leading spaces that 10 | -- resulted from code indentation. That way you can freely indent your string 11 | -- literals without the indentation ending up in the resulting strings. 12 | -- 13 | -- Here is an example: 14 | -- 15 | -- >>> :set -XQuasiQuotes 16 | -- >>> import Data.String.Interpolate 17 | -- >>> import Data.String.Interpolate.Util 18 | -- >>> :{ 19 | -- putStr $ unindent [i| 20 | -- def foo 21 | -- 23 22 | -- end 23 | -- |] 24 | -- :} 25 | -- def foo 26 | -- 23 27 | -- end 28 | -- 29 | -- To allow this, two additional things are being done, apart from removing 30 | -- indentation: 31 | -- 32 | -- - One empty line at the beginning will be removed and 33 | -- - if the last newline character (@"\\n"@) is followed by spaces, the spaces are removed. 
34 | unindent :: String -> String 35 | unindent = 36 | lines_ 37 | >>> removeLeadingEmptyLine 38 | >>> trimLastLine 39 | >>> removeIndentation 40 | >>> concat 41 | where 42 | isEmptyLine :: String -> Bool 43 | isEmptyLine = all isSpace 44 | 45 | lines_ :: String -> [String] 46 | lines_ [] = [] 47 | lines_ s = case span (/= '\n') s of 48 | (first, '\n' : rest) -> (first ++ "\n") : lines_ rest 49 | (first, rest) -> first : lines_ rest 50 | 51 | removeLeadingEmptyLine :: [String] -> [String] 52 | removeLeadingEmptyLine xs = case xs of 53 | y:ys | isEmptyLine y -> ys 54 | _ -> xs 55 | 56 | trimLastLine :: [String] -> [String] 57 | trimLastLine (a : b : r) = a : trimLastLine (b : r) 58 | trimLastLine [a] = if all (== ' ') a 59 | then [] 60 | else [a] 61 | trimLastLine [] = [] 62 | 63 | removeIndentation :: [String] -> [String] 64 | removeIndentation ys = map (dropSpaces indentation) ys 65 | where 66 | dropSpaces 0 s = s 67 | dropSpaces n (' ' : r) = dropSpaces (n - 1) r 68 | dropSpaces _ s = s 69 | indentation = minimalIndentation ys 70 | minimalIndentation = 71 | safeMinimum 0 72 | . map (length . takeWhile (== ' ')) 73 | . removeEmptyLines 74 | removeEmptyLines = filter (not . 
isEmptyLine) 75 | 76 | safeMinimum :: Ord a => a -> [a] -> a 77 | safeMinimum x xs = case xs of 78 | [] -> x 79 | _ -> minimum xs 80 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/attr_value.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "AttrValueProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboard/src/tensor.proto"; 10 | import "tensorboard/src/tensor_shape.proto"; 11 | import "tensorboard/src/types.proto"; 12 | 13 | // Protocol buffer representing the value for an attr used to configure an Op. 14 | // Comment indicates the corresponding attr type. Only the field matching the 15 | // attr type may be filled. 16 | message AttrValue { 17 | // LINT.IfChange 18 | message ListValue { 19 | repeated bytes s = 2; // "list(string)" 20 | repeated int64 i = 3 [packed = true]; // "list(int)" 21 | repeated float f = 4 [packed = true]; // "list(float)" 22 | repeated bool b = 5 [packed = true]; // "list(bool)" 23 | repeated DataType type = 6 [packed = true]; // "list(type)" 24 | repeated TensorShapeProto shape = 7; // "list(shape)" 25 | repeated TensorProto tensor = 8; // "list(tensor)" 26 | repeated NameAttrList func = 9; // "list(attr)" 27 | } 28 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) 29 | 30 | oneof value { 31 | bytes s = 2; // "string" 32 | int64 i = 3; // "int" 33 | float f = 4; // "float" 34 | bool b = 5; // "bool" 35 | DataType type = 6; // "type" 36 | TensorShapeProto shape = 7; // "shape" 37 | TensorProto tensor = 8; // "tensor" 38 | ListValue list = 1; // any "list(...)" 39 | 40 | // "func" represents a function. func.name is a function's name or 41 | // a primitive op's name. 
func.attr.first is the name of an attr 42 | // defined for that function. func.attr.second is the value for 43 | // that attr in the instantiation. 44 | NameAttrList func = 10; 45 | 46 | // This is a placeholder only used in nodes defined inside a 47 | // function. It indicates the attr value will be supplied when 48 | // the function is instantiated. For example, let us suppose a 49 | // node "N" in function "FN". "N" has an attr "A" with value 50 | // placeholder = "foo". When FN is instantiated with attr "foo" 51 | // set to "bar", the instantiated node N's attr A will have been 52 | // given the value "bar". 53 | string placeholder = 9; 54 | } 55 | } 56 | 57 | // A list of attr names and their values. The whole list is attached 58 | // with a string name. E.g., MatMul[T=float]. 59 | message NameAttrList { 60 | string name = 1; 61 | map attr = 2; 62 | } 63 | -------------------------------------------------------------------------------- /haskell-torch-datasets/haskell-torch-datasets.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 
4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-datasets 8 | version: 0.8.0.0 9 | synopsis: Deep learning in Haskell on top of Torch and PyTorch 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | Torch.Datasets 29 | Torch.Datasets.Vision.CIFAR 30 | Torch.Datasets.Vision.MNIST 31 | other-modules: 32 | Paths_haskell_torch_datasets 33 | hs-source-dirs: 34 | src 35 | build-depends: 36 | aeson 37 | , array 38 | , barbies 39 | , base >=4.7 && <5 40 | , binary 41 | , bytestring 42 | , containers 43 | , data-default 44 | , deepseq 45 | , default-type-plugin 46 | , directory 47 | , docopt 48 | , extra 49 | , filepath 50 | , generics-eot 51 | , generics-sop 52 | , ghc-prim 53 | , ghc-typelits-knownnat 54 | , ghc-typelits-natnormalise 55 | , half 56 | , hashable 57 | , hashtables 58 | , haskell-src-exts 59 | , haskell-src-meta 60 | , haskell-torch 61 | , haskell-torch-imagemagick 62 | , haskell-torch-matio 63 | , haskell-torch-tensorboard-proto 64 | , hostname 65 | , ieee754 66 | , interpolateIO 67 | , matplotlib 68 | , megaparsec 69 | , microlens 70 | , monad-control 71 | , monad-logger 72 | , monad-loops 73 | , mtl 74 | , parser-combinators 75 | , pipes 76 | , pipes-aeson 77 | , pipes-bytestring 78 | , pipes-concurrency 79 | , pipes-csv 80 | , pipes-extras 81 | , pipes-group 82 | , pipes-parse 83 | , pipes-safe 84 | , prettyprinter 85 | , proto-lens 86 | , random 87 | , safe-exceptions 88 | , shelly 89 | , simplify-nat-algebra-plugin 90 | , singletons 91 | , statistics 92 | , string-qq 93 | 
, syb 94 | , template-haskell 95 | , temporary 96 | , text 97 | , time 98 | , unix 99 | , vector 100 | , vector-algorithms 101 | , yaml 102 | , zlib 103 | default-language: Haskell2010 104 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/node_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "NodeProto"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboard/src/attr_value.proto"; 10 | 11 | message NodeDef { 12 | // The name given to this operator. Used for naming inputs, 13 | // logging, visualization, etc. Unique within a single GraphDef. 14 | // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". 15 | string name = 1; 16 | 17 | // The operation name. There may be custom parameters in attrs. 18 | // Op names starting with an underscore are reserved for internal use. 19 | string op = 2; 20 | 21 | // Each input is "node:src_output" with "node" being a string name and 22 | // "src_output" indicating which output tensor to use from "node". If 23 | // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs 24 | // may optionally be followed by control inputs that have the format 25 | // "^node". 26 | repeated string input = 3; 27 | 28 | // A (possibly partial) specification for the device on which this 29 | // node should be placed. 
30 | // The expected syntax for this string is as follows: 31 | // 32 | // DEVICE_SPEC ::= PARTIAL_SPEC 33 | // 34 | // PARTIAL_SPEC ::= ("/" CONSTRAINT) * 35 | // CONSTRAINT ::= ("job:" JOB_NAME) 36 | // | ("replica:" [1-9][0-9]*) 37 | // | ("task:" [1-9][0-9]*) 38 | // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) 39 | // 40 | // Valid values for this string include: 41 | // * "/job:worker/replica:0/task:1/gpu:3" (full specification) 42 | // * "/job:worker/gpu:3" (partial specification) 43 | // * "" (no specification) 44 | // 45 | // If the constraints do not resolve to a single device (or if this 46 | // field is empty or not present), the runtime will attempt to 47 | // choose a device automatically. 48 | string device = 4; 49 | 50 | // Operation-specific graph-construction-time configuration. 51 | // Note that this should include all attrs defined in the 52 | // corresponding OpDef, including those with a value matching 53 | // the default -- this allows the default to change and makes 54 | // NodeDefs easier to interpret on their own. However, if 55 | // an attr with a default is not specified in this list, the 56 | // default will be used. 57 | // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and 58 | // one of the names from the corresponding OpDef's attr field). 59 | // The values must have a type matching the corresponding OpDef 60 | // attr's type field. 61 | // TODO(josh11b): Add some examples here showing best practices. 62 | map attr = 5; 63 | }; 64 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/haskell-torch-tensorboard-proto.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.24 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 
4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-tensorboard-proto 8 | version: 0.1.0.0 9 | synopsis: Talk to tensorboard from Haskell; the protobuf bindings 10 | description: Please see the README on Github at <https://github.com/abarbu/haskell-torch#readme> 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Custom 19 | extra-source-files: 20 | proto/tensorboard/src/attr_value.proto 21 | proto/tensorboard/src/event.proto 22 | proto/tensorboard/src/graph.proto 23 | proto/tensorboard/src/node_def.proto 24 | proto/tensorboard/src/onnx.proto 25 | proto/tensorboard/src/resource_handle.proto 26 | proto/tensorboard/src/summary.proto 27 | proto/tensorboard/src/tensor.proto 28 | proto/tensorboard/src/tensor_shape.proto 29 | proto/tensorboard/src/types.proto 30 | proto/tensorboard/src/versions.proto 31 | 32 | source-repository head 33 | type: git 34 | location: https://github.com/abarbu/haskell-torch 35 | 36 | custom-setup 37 | setup-depends: 38 | Cabal 39 | , base 40 | , proto-lens-setup 41 | 42 | library 43 | exposed-modules: 44 | Tensorboard.Proto.Attributes 45 | Tensorboard.Proto.Event 46 | Tensorboard.Proto.Graph 47 | Tensorboard.Proto.Summary 48 | Tensorboard.Proto.Tensor 49 | other-modules: 50 | Proto.Tensorboard.Src.Summary 51 | Proto.Tensorboard.Src.Summary_Fields 52 | Proto.Tensorboard.Src.Tensor 53 | Proto.Tensorboard.Src.Tensor_Fields 54 | Proto.Tensorboard.Src.Types 55 | Proto.Tensorboard.Src.Types_Fields 56 | Proto.Tensorboard.Src.ResourceHandle 57 | Proto.Tensorboard.Src.ResourceHandle_Fields 58 | Proto.Tensorboard.Src.TensorShape 59 | Proto.Tensorboard.Src.TensorShape_Fields 60 | Proto.Tensorboard.Src.Event 61 | Proto.Tensorboard.Src.Event_Fields 62 | Proto.Tensorboard.Src.Graph 63 | Proto.Tensorboard.Src.Graph_Fields 64 | Proto.Tensorboard.Src.NodeDef 65 | 
Proto.Tensorboard.Src.NodeDef_Fields 66 | Proto.Tensorboard.Src.Versions 67 | Proto.Tensorboard.Src.Versions_Fields 68 | Proto.Tensorboard.Src.AttrValue 69 | Proto.Tensorboard.Src.AttrValue_Fields 70 | hs-source-dirs: 71 | src 72 | build-depends: 73 | base >=4.7 && <5 74 | , data-default 75 | , microlens 76 | , proto-lens 77 | , proto-lens-protoc 78 | , proto-lens-runtime 79 | , text 80 | default-language: Haskell2010 81 | -------------------------------------------------------------------------------- /interpolateIO/test/Data/String/InterpolateIO/Internal/UtilSpec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | module Data.String.InterpolateIO.Internal.UtilSpec where 3 | 4 | import Test.Hspec 5 | import Test.QuickCheck 6 | import Test.QuickCheck.Instances () 7 | 8 | import qualified Data.Text as T 9 | import qualified Data.Text.Lazy as LT 10 | import qualified Data.ByteString.Char8 as B 11 | import qualified Data.ByteString.Lazy.Char8 as LB 12 | 13 | import Data.String.InterpolateIO.Internal.Util 14 | 15 | main :: IO () 16 | main = hspec spec 17 | 18 | spec :: Spec 19 | spec = do 20 | describe "toString" $ do 21 | it "behaves like `show`" $ do 22 | property $ \n -> toString (n :: Int) `shouldBe` show n 23 | 24 | context "when used with String" $ do 25 | it "behaves like `id`" $ do 26 | property $ \s -> toString s `shouldBe` s 27 | 28 | context "when used with Text" $ do 29 | it "behaves like `unpack`" $ do 30 | property $ \s -> toString s `shouldBe` T.unpack s 31 | 32 | context "when used with lazy Text" $ do 33 | it "behaves like `unpack`" $ do 34 | property $ \s -> toString s `shouldBe` LT.unpack s 35 | 36 | context "when used with ByteString" $ do 37 | it "behaves like `unpack`" $ do 38 | property $ \s -> toString s `shouldBe` B.unpack s 39 | 40 | context "when used with lazy ByteString" $ do 41 | it "behaves like `unpack`" $ do 42 | property $ \s -> do 43 | #if __GLASGOW_HASKELL__ < 706 44 | 
pendingWith "Does not work with GHC < 7.6" 45 | #endif 46 | toString s `shouldBe` LB.unpack s 47 | 48 | describe "unescape" $ do 49 | it "unescapes single-character escape codes" $ do 50 | unescape "\\n" `shouldBe` "\n" 51 | 52 | it "unescapes ASCII control code abbreviations" $ do 53 | unescape "\\BEL" `shouldBe` "\BEL" 54 | 55 | it "unescapes decimal character literals" $ do 56 | unescape "\\955" `shouldBe` "\955" 57 | 58 | it "unescapes hexadecimal character literals" $ do 59 | unescape "\\xbeef" `shouldBe` "\xbeef" 60 | 61 | it "unescapes octal character literals" $ do 62 | unescape "\\o1234" `shouldBe` "\o1234" 63 | 64 | context "with control escape sequences" $ do 65 | it "unescapes null character" $ do 66 | unescape "\\^@" `shouldBe` "\^@" 67 | 68 | it "unescapes control codes" $ do 69 | unescape "\\^A" `shouldBe` "\^A" 70 | 71 | it "unescapes escape" $ do 72 | unescape "\\^[" `shouldBe` "\^[" 73 | 74 | it "unescapes file separator" $ do 75 | unescape "\\^\\ x" `shouldBe` "\^\ x" 76 | 77 | it "unescapes group separator" $ do 78 | unescape "\\^]" `shouldBe` "\^]" 79 | 80 | it "unescapes record separator" $ do 81 | unescape "\\^^" `shouldBe` "\^^" 82 | 83 | it "unescapes unit separator" $ do 84 | unescape "\\^_" `shouldBe` "\^_" 85 | -------------------------------------------------------------------------------- /interpolateIO/src/Data/String/InterpolateIO.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | module Data.String.InterpolateIO ( 3 | -- * String interpolation done right for ShowIO 4 | -- | 5 | -- The examples in this module use `QuasiQuotes`. Make sure to enable the 6 | -- corresponding language extension. 
7 | -- 8 | -- >>> :set -XQuasiQuotes 9 | -- >>> import Data.String.InterpolateIO 10 | -- >>> import System.Environment(getEnv, setEnv) 11 | c, toStringShowStringIO 12 | ) where 13 | 14 | import Language.Haskell.TH.Quote (QuasiQuoter(..)) 15 | import Language.Haskell.Meta.Parse (parseExp) 16 | 17 | import Data.String.InterpolateIO.Internal.Util 18 | import Data.String.InterpolateIO.Parse 19 | import Data.String.InterpolateIO.Compat (Q, Exp, appE) 20 | import Data.String.ShowIO 21 | 22 | -- | 23 | -- A `QuasiQuoter` for string interpolation. Expressions enclosed within 24 | -- @#{...}@ are interpolated, the result has to be in the `ShowIO` class. 25 | -- 26 | -- It interpolates strings in IO 27 | -- 28 | -- >>> setEnv "TESTVAR" "XYZ" 29 | -- >>> putStrLn =<< [c|lang: #{getEnv "TESTVAR"}|] 30 | -- lang: XYZ 31 | -- 32 | -- or integers that are pure 33 | -- 34 | -- >>> let age = 23 35 | -- >>> putStrLn =<< [c|age: #{age}|] 36 | -- age: 23 37 | -- 38 | -- or arbitrary Haskell pure or IO expressions 39 | -- 40 | -- >>> let profession = "\955-scientist" 41 | -- >>> putStrLn =<< [c|profession: #{unwords ["Marvin", "the", profession]}|] 42 | -- profession: Marvin the λ-scientist 43 | c :: QuasiQuoter 44 | c = QuasiQuoter { 45 | quoteExp = toExp . parseNodes . decodeNewlines 46 | , quotePat = err "pattern" 47 | , quoteType = err "type" 48 | , quoteDec = err "declaration" 49 | } 50 | where 51 | err name = error ("Data.String.InterpolateIO.c: This QuasiQuoter can not be used as a " ++ name ++ "!") 52 | 53 | -- | Translate the parsed interpolation nodes into an @IO String@ expression, 54 | -- chaining each literal and interpolated expression through 'showStringIO'. 53 | toExp :: [Node] -> Q Exp 54 | toExp nodes = case nodes of 55 | [] -> [|pure ""|] 56 | (x:xs) -> f x `appE` toExp xs 57 | where 58 | -- f (Literal s) = [|(\z -> showStringIO s =<< z)|] 59 | f (Literal s) = [|(showStringIO s =<<)|] 60 | -- f (Expression e) = [|((=<<) . showStringIO . 
toStringIO) $(reifyExpression e)|] 61 | -- f (Expression e) = [|(\z_ -> toStringIO $(reifyExpression e) >>= \y_ -> (showStringIO y_ =<< z_))|] 62 | f (Expression e) = [|toStringShowStringIO $(reifyExpression e)|] 63 | 64 | reifyExpression :: String -> Q Exp 65 | reifyExpression s = case parseExp s of 66 | Left _ -> do 67 | fail "Parse error in expression!" :: Q Exp 68 | Right e -> return e 69 | 70 | toStringShowStringIO :: ShowIO a => a -> IO String -> IO String 71 | toStringShowStringIO e n = toStringIO e >>= \y -> (showStringIO y =<< n) 72 | 73 | decodeNewlines :: String -> String 74 | decodeNewlines = go 75 | where 76 | go xs = case xs of 77 | '\r' : '\n' : ys -> '\n' : go ys 78 | y : ys -> y : go ys 79 | [] -> [] 80 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboard/src/resource_handle.proto"; 10 | import "tensorboard/src/tensor_shape.proto"; 11 | import "tensorboard/src/types.proto"; 12 | 13 | // Protocol buffer representing a tensor. 14 | message TensorProto { 15 | DataType dtype = 1; 16 | 17 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues. 18 | TensorShapeProto tensor_shape = 2; 19 | 20 | // Only one of the representations below is set, one of "tensor_contents" and 21 | // the "xxx_val" attributes. We are not using oneof because as oneofs cannot 22 | // contain repeated fields it would require another extra set of messages. 23 | 24 | // Version number. 25 | // 26 | // In version 0, if the "repeated xxx" representations contain only one 27 | // element, that element is repeated to fill the shape. 
This makes it easy 28 | // to represent a constant Tensor with a single value. 29 | int32 version_number = 3; 30 | 31 | // Serialized raw tensor content from either Tensor::AsProtoTensorContent or 32 | // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation 33 | // can be used for all tensor types. The purpose of this representation is to 34 | // reduce serialization overhead during RPC call by avoiding serialization of 35 | // many repeated small items. 36 | bytes tensor_content = 4; 37 | 38 | // Type specific representations that make it easy to create tensor protos in 39 | // all languages. Only the representation corresponding to "dtype" can 40 | // be set. The values hold the flattened representation of the tensor in 41 | // row major order. 42 | 43 | // DT_HALF. Note that since protobuf has no int16 type, we'll have some 44 | // pointless zero padding for each value here. 45 | repeated int32 half_val = 13 [packed = true]; 46 | 47 | // DT_FLOAT. 48 | repeated float float_val = 5 [packed = true]; 49 | 50 | // DT_DOUBLE. 51 | repeated double double_val = 6 [packed = true]; 52 | 53 | // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. 54 | repeated int32 int_val = 7 [packed = true]; 55 | 56 | // DT_STRING 57 | repeated bytes string_val = 8; 58 | 59 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real 60 | // and imaginary parts of i-th single precision complex. 61 | repeated float scomplex_val = 9 [packed = true]; 62 | 63 | // DT_INT64 64 | repeated int64 int64_val = 10 [packed = true]; 65 | 66 | // DT_BOOL 67 | repeated bool bool_val = 11 [packed = true]; 68 | 69 | // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real 70 | // and imaginary parts of i-th double precision complex. 
71 | repeated double dcomplex_val = 12 [packed = true]; 72 | 73 | // DT_RESOURCE 74 | repeated ResourceHandleProto resource_handle_val = 14; 75 | }; 76 | -------------------------------------------------------------------------------- /stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: 7 | - completed: 8 | name: matplotlib 9 | version: 0.7.7 10 | git: https://github.com/abarbu/matplotlib-haskell 11 | pantry-tree: 12 | size: 1059 13 | sha256: 56fec510202db68bf4ad28e0add00dd14437cfe5a321600f056be34790a8f921 14 | commit: 5c186cb5b9e80212c92c72a68e9ebbc260d413a8 15 | original: 16 | git: https://github.com/abarbu/matplotlib-haskell 17 | commit: 5c186cb5b9e80212c92c72a68e9ebbc260d413a8 18 | - completed: 19 | name: docopt 20 | version: 0.7.0.7 21 | git: https://github.com/docopt/docopt.hs.git 22 | pantry-tree: 23 | size: 1808 24 | sha256: 4f12f39fd4d913aafb82b21843a27230c0a9bdd2231ebac090dbe70b9fcacdae 25 | commit: bdc4c679bf0185ab6c1895172f011193d9e9922c 26 | original: 27 | git: https://github.com/docopt/docopt.hs.git 28 | commit: bdc4c679bf0185ab6c1895172f011193d9e9922c 29 | - completed: 30 | hackage: proto-lens-setup-0.4.0.5@sha256:ae4514963a6c20ad059bba427cd14b94c6007f614d797ebecae3c37f8bf0fa96,3108 31 | pantry-tree: 32 | size: 235 33 | sha256: 14982fbc9ee0c6f9f9a59c2639b647613eb9c2cfa1d5b1b323077a15ae285ccf 34 | original: 35 | hackage: proto-lens-setup-0.4.0.5 36 | - completed: 37 | hackage: proto-lens-0.7.1.0@sha256:b151890929e71db5b8c2ad86cd758bcdf1dfcf25f34eb6c9ce19e3d7cd4eae39,2959 38 | pantry-tree: 39 | size: 1857 40 | sha256: 2f1199d04d0588805e06faa0bf9a75898584d76243d4f945acbcc0e93913732e 41 | original: 42 | hackage: proto-lens-0.7.1.0 43 | - completed: 44 | hackage: 
proto-lens-runtime-0.7.0.1@sha256:703f327422b2e204f8ea13c5e178edd8ab42ccd01c7a5a89fcb1b37ab474c68a,3038 45 | pantry-tree: 46 | size: 168 47 | sha256: 145cb9a15b73d45b07cb3f9f0716256b2ed9e27ac296268ce100a4f0e477e110 48 | original: 49 | hackage: proto-lens-runtime-0.7.0.1 50 | - completed: 51 | hackage: proto-lens-protoc-0.7.1.0@sha256:b0b92498af74fc4bb770d51f84405d591e73c085a0c1d9952dd3e14ce07b538f,2220 52 | pantry-tree: 53 | size: 1219 54 | sha256: 1aaa82cb2823b33c9ea96608cc7ec245e54a0d14de8f18736be03220d8dbe683 55 | original: 56 | hackage: proto-lens-protoc-0.7.1.0 57 | - completed: 58 | name: cpython 59 | version: 3.5.1 60 | git: https://github.com/abarbu/haskell-cpython 61 | pantry-tree: 62 | size: 3087 63 | sha256: 814ac53161b9ab0214cf99a1773482ae418457e4086a9bcf821c44d3eff49ec7 64 | commit: 3c3c89acbc5a5fa6d60fc23a148f39eb330ecfac 65 | original: 66 | git: https://github.com/abarbu/haskell-cpython 67 | commit: 3c3c89acbc5a5fa6d60fc23a148f39eb330ecfac 68 | snapshots: 69 | - completed: 70 | size: 586268 71 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/18/13.yaml 72 | sha256: d9e658a22cfe8d87a64fdf219885f942fef5fe2bcb156a9800174911c5da2443 73 | original: lts-18.13 74 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/src/Torch/C/Scalar.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE QuasiQuotes, TemplateHaskell #-} 2 | 3 | module Torch.C.Scalar where 4 | import Data.Int 5 | import Data.Monoid ((<>)) 6 | import Foreign.C.Types 7 | import Foreign.Ptr 8 | import qualified Language.C.Inline as C 9 | import qualified Language.C.Inline.Cpp as C 10 | import Torch.C.Types 11 | 12 | C.context (C.cppCtx <> tensorCtx) 13 | 14 | C.include "" 15 | 16 | C.using "namespace at" 17 | C.using "namespace torch::autograd" 18 | 19 | C.verbatim "extern \"C\" void delete_scalar(Scalar* o) { delete o; }" 20 | 21 | foreign import ccall 
"&delete_scalar" deleteScalar :: FunPtr (Ptr CScalar -> IO ()) 22 | 23 | mkScalarCPUBool :: CBool -> IO (Ptr CScalar) 24 | mkScalarCPUBool x = [C.exp|Scalar *{ new Scalar($(bool x)) }|] 25 | 26 | mkScalarCPUByte :: CUChar -> IO (Ptr CScalar) 27 | mkScalarCPUByte x = [C.exp|Scalar *{ new Scalar($(unsigned char x)) }|] 28 | 29 | mkScalarCPUChar :: CChar -> IO (Ptr CScalar) 30 | mkScalarCPUChar x = [C.exp|Scalar *{ new Scalar($(char x)) }|] 31 | 32 | mkScalarCPUShort :: CShort -> IO (Ptr CScalar) 33 | mkScalarCPUShort x = [C.exp|Scalar *{ new Scalar($(short x)) }|] 34 | 35 | mkScalarCPUInt :: CInt -> IO (Ptr CScalar) 36 | mkScalarCPUInt x = [C.exp|Scalar *{ new Scalar($(int x)) }|] 37 | 38 | mkScalarCPULong :: CLong -> IO (Ptr CScalar) 39 | mkScalarCPULong x = [C.exp|Scalar *{ new Scalar($(long x)) }|] 40 | 41 | mkScalarCPUHalf :: Int16 -> IO (Ptr CScalar) 42 | mkScalarCPUHalf x = [C.exp|Scalar *{ new Scalar($(int16_t x)) }|] 43 | 44 | mkScalarCPUFloat :: CFloat -> IO (Ptr CScalar) 45 | mkScalarCPUFloat x = [C.exp|Scalar *{ new Scalar($(float x)) }|] 46 | 47 | mkScalarCPUDouble :: CDouble -> IO (Ptr CScalar) 48 | mkScalarCPUDouble x = [C.exp|Scalar *{ new Scalar($(double x)) }|] 49 | 50 | mkScalarCUDABool :: CBool -> IO (Ptr CScalar) 51 | mkScalarCUDABool x = [C.exp|Scalar *{ new Scalar($(bool x)) }|] 52 | 53 | mkScalarCUDAByte :: CUChar -> IO (Ptr CScalar) 54 | mkScalarCUDAByte x = [C.exp|Scalar *{ new Scalar($(unsigned char x)) }|] 55 | 56 | mkScalarCUDAChar :: CChar -> IO (Ptr CScalar) 57 | mkScalarCUDAChar x = [C.exp|Scalar *{ new Scalar($(char x)) }|] 58 | 59 | mkScalarCUDAShort :: CShort -> IO (Ptr CScalar) 60 | mkScalarCUDAShort x = [C.exp|Scalar *{ new Scalar($(short x)) }|] 61 | 62 | mkScalarCUDAInt :: CInt -> IO (Ptr CScalar) 63 | mkScalarCUDAInt x = [C.exp|Scalar *{ new Scalar($(int x)) }|] 64 | 65 | mkScalarCUDALong :: CLong -> IO (Ptr CScalar) 66 | mkScalarCUDALong x = [C.exp|Scalar *{ new Scalar($(long x)) }|] 67 | 68 | mkScalarCUDAHalf :: Int16 -> IO 
(Ptr CScalar) 69 | mkScalarCUDAHalf x = [C.exp|Scalar *{ new Scalar($(int16_t x)) }|] 70 | 71 | mkScalarCUDAFloat :: CFloat -> IO (Ptr CScalar) 72 | mkScalarCUDAFloat x = [C.exp|Scalar *{ new Scalar($(float x)) }|] 73 | 74 | mkScalarCUDADouble :: CDouble -> IO (Ptr CScalar) 75 | mkScalarCUDADouble x = [C.exp|Scalar *{ new Scalar($(double x)) }|] 76 | -------------------------------------------------------------------------------- /haskell-torch/package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-torch 2 | version: 0.8.0.0 3 | github: "abarbu/haskell-torch" 4 | license: BSD3 5 | author: "Andrei Barbu" 6 | maintainer: "andrei@0xab.com" 7 | copyright: "2018 Andrei Barbu" 8 | homepage: https://github.com/abarbu/haskell-torch 9 | bug-reports: https://github.com/abarbu/haskell-torch/issues 10 | category: AI 11 | synopsis: Deep learning in Haskell on top of Torch and PyTorch 12 | 13 | extra-source-files: 14 | - README.md 15 | 16 | # To avoid duplicated efforts in documentation and dealing with the 17 | # complications of embedding Haddock markup inside cabal files, it is 18 | # common to point users to the README.md file. 
19 | description: Please see the README on Github at 20 | 21 | dependencies: 22 | - base >= 4.7 && < 5 23 | 24 | library: 25 | source-dirs: src 26 | dependencies: 27 | - haskell-torch-cbindings >= 0.8.0.0 28 | - haskell-torch-imagemagick 29 | - haskell-torch-matio 30 | - haskell-torch-tensorboard-proto 31 | - aeson 32 | - array 33 | - binary 34 | - bytestring 35 | - containers 36 | - data-default 37 | - directory 38 | - extra 39 | - filepath 40 | - generics-eot 41 | - ghc-typelits-knownnat 42 | - ghc-typelits-natnormalise 43 | - half 44 | - hashable 45 | - hashtables 46 | - haskell-src-exts 47 | - haskell-src-meta 48 | - hostname 49 | - ieee754 50 | - interpolateIO 51 | - matplotlib 52 | - megaparsec 53 | - microlens 54 | - monad-control 55 | - monad-loops 56 | - mtl 57 | - parser-combinators 58 | - pipes 59 | - pipes-aeson 60 | - pipes-bytestring 61 | - pipes-concurrency 62 | - pipes-csv 63 | - pipes-extras 64 | - pipes-group 65 | - pipes-parse 66 | - pipes-safe 67 | - proto-lens 68 | - random 69 | - safe-exceptions 70 | - shelly 71 | - singletons 72 | - statistics 73 | - string-qq 74 | - syb 75 | - template-haskell 76 | - text 77 | - time 78 | - unix 79 | - vector 80 | - yaml 81 | - zlib 82 | - monad-logger 83 | - docopt 84 | - default-type-plugin 85 | - simplify-nat-algebra-plugin 86 | - temporary 87 | - barbies 88 | - distributive 89 | - generics-sop 90 | - vector-algorithms 91 | 92 | tests: 93 | doctest: 94 | main: Doctest.hs 95 | source-dirs: test 96 | # main-is: Doctest.hs 97 | other-modules: [] 98 | dependencies: 99 | - doctest 100 | - doctest-discover 101 | # TODO 102 | # haskell-torch-test: 103 | # main: Spec.hs 104 | # source-dirs: test 105 | # dependencies: 106 | # - haskell-torch 107 | # - tasty 108 | # - tasty-quickcheck 109 | # - tasty-discover 110 | # - tasty-stats 111 | # - tasty-dejafu 112 | # - tasty-golden 113 | -------------------------------------------------------------------------------- /dataframe/src/Generics/SOP/Record/Combination.hs: 
-------------------------------------------------------------------------------- 1 | {-# LANGUAGE 2 | AllowAmbiguousTypes 3 | , FlexibleContexts 4 | , FlexibleInstances 5 | , MultiParamTypeClasses 6 | , ScopedTypeVariables 7 | , TypeApplications 8 | , TypeInType 9 | , TypeFamilies 10 | , TypeOperators 11 | , UndecidableInstances 12 | #-} 13 | {-# OPTIONS_GHC 14 | -fno-warn-unticked-promoted-constructors 15 | #-} 16 | module Generics.SOP.Record.Combination 17 | ( combination 18 | , IsCombinationOf 19 | , IsElemOf2 20 | , get2 21 | , getField2 22 | ) 23 | where 24 | 25 | import Data.Type.Equality 26 | import Generics.SOP.NP 27 | import GHC.Types 28 | 29 | import Generics.SOP.Record 30 | 31 | 32 | class IsElemOf2 (s :: Symbol) (a :: Type) (r1 :: RecordCode) (r2 :: RecordCode) where 33 | get2 :: Record r1 -> Record r2 -> a 34 | 35 | class IsElemOfIf2 (b :: Bool) 36 | (targetSymbol :: FieldLabel) (targetType :: Type) 37 | (currentSymbol :: FieldLabel) (currentType :: Type) 38 | (r1 :: RecordCode) (r2 :: RecordCode) where 39 | get2_1' :: Record ( '(currentSymbol, currentType) : r1 ) -> Record r2 -> targetType 40 | get2_2' :: Record r1 -> Record ( '(currentSymbol, currentType) : r2 ) -> targetType 41 | 42 | -- Traverse 43 | instance IsElemOfIf2 (targetSymbol == currentSymbol) targetSymbol targetType currentSymbol currentType r1 r2 44 | => IsElemOf2 targetSymbol targetType ( '(currentSymbol, currentType) : r1 ) r2 where 45 | get2 r1 r2 = get2_1' @(targetSymbol == currentSymbol) @targetSymbol @targetType r1 r2 46 | 47 | instance IsElemOfIf2 (targetSymbol == currentSymbol) targetSymbol targetType currentSymbol currentType '[] r2 48 | => IsElemOf2 targetSymbol targetType '[] ( '(currentSymbol, currentType) : r2 ) where 49 | get2 r1 r2 = get2_2' @(targetSymbol == currentSymbol) @targetSymbol @targetType r1 r2 50 | 51 | instance (targetType ~ currentType) => IsElemOfIf2 True s targetType s currentType r1 r2 where 52 | get2_1' (P a :* _) _ = a 53 | get2_2' _ (P a :* _) = a 54 | 
55 | instance IsElemOf2 targetSymbol targetType r1 r2 => IsElemOfIf2 False targetSymbol targetType currentSymbol currentType r1 r2 where 56 | get2_1' (_ :* r1) r2 = get2 @targetSymbol @targetType r1 r2 57 | get2_2' r1 (_ :* r2) = get2 @targetSymbol @targetType r1 r2 58 | 59 | getField2 :: forall s a b o ra rb. (IsRecord a ra, IsRecord b rb, IsElemOf2 s o ra rb) => a -> b -> o 60 | getField2 r1 r2 = get2 @s (toRecord r1) (toRecord r2) 61 | 62 | class IsCombinationOf (r1 :: RecordCode) (r2 :: RecordCode) (r :: RecordCode) where 63 | combinationRecords :: Record r1 -> Record r2 -> Record r 64 | 65 | instance IsCombinationOf r1 r2 '[] where 66 | combinationRecords _ _ = Nil 67 | 68 | instance (IsCombinationOf r1 r2 r, IsElemOf2 s2 a2 r1 r2) => IsCombinationOf r1 r2 ( '(s2, a2) : r ) where 69 | combinationRecords r1 r2 = P (get2 @s2 r1 r2) :* combinationRecords r1 r2 70 | 71 | combination :: (IsRecord a ra, IsRecord b rb, IsRecord c rc, IsCombinationOf ra rb rc) => a -> b -> c 72 | combination r1 r2 = fromRecord $ combinationRecords (toRecord r1) (toRecord r2) 73 | 74 | -------------------------------------------------------------------------------- /dataframe/src/Generics/SOP/Record/SubTyping.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE 2 | AllowAmbiguousTypes 3 | , FlexibleContexts 4 | , FlexibleInstances 5 | , MultiParamTypeClasses 6 | , ScopedTypeVariables 7 | , TypeApplications 8 | , TypeInType 9 | , TypeFamilies 10 | , TypeOperators 11 | , UndecidableInstances 12 | #-} 13 | {-# OPTIONS_GHC 14 | -fno-warn-unticked-promoted-constructors 15 | #-} 16 | module Generics.SOP.Record.SubTyping 17 | ( cast 18 | , IsSubTypeOf 19 | , IsElemOf 20 | , get 21 | , getField 22 | ) 23 | where 24 | 25 | import Data.Type.Equality 26 | import Generics.SOP.NP 27 | import GHC.Types 28 | 29 | import Generics.SOP.Record 30 | 31 | -- | Cast one record type to another if there is a subtype relationship 32 | -- between them. 
Currently, only width subtyping is considered, which means 33 | -- that we can forget and reorder fields. 34 | -- 35 | cast :: (IsRecord a ra, IsRecord b rb, IsSubTypeOf ra rb) => a -> b 36 | cast = fromRecord . castRecord . toRecord 37 | 38 | -- | Extract a record field based on the symbolic name of a field. 39 | -- Requires an explicit type application for the field name. 40 | -- 41 | getField :: forall s a b ra . (IsRecord a ra, IsElemOf s b ra) => a -> b 42 | getField = get @s . toRecord 43 | 44 | -- | Class that checks whether one record code is convertible into another. 45 | -- 46 | -- Conversion works if the first record contains at least the labels of the 47 | -- second record, and if the types of the corresponding fields match exactly. 48 | -- 49 | class IsSubTypeOf (r1 :: RecordCode) (r2 :: RecordCode) where 50 | -- | Perform a safe cast between two records. 51 | castRecord :: Record r1 -> Record r2 52 | 53 | instance IsSubTypeOf r1 '[] where 54 | castRecord _ = Nil 55 | 56 | instance (IsSubTypeOf r1 r2, IsElemOf s2 a2 r1) => IsSubTypeOf r1 ( '(s2, a2) : r2 ) where 57 | castRecord r = P (get @s2 r) :* castRecord r 58 | 59 | -- | Class that checks whether a field of a particular type is contained 60 | -- in a record. 61 | -- 62 | class IsElemOf (s :: Symbol) (a :: Type) (r :: RecordCode) where 63 | -- | Perform an extraction of a given field. Field name has to be passed 64 | -- via type application. 65 | -- 66 | get :: Record r -> a 67 | 68 | -- | Helper class. Isn't strictly needed, but allows us to avoid 69 | -- overlapping instances for the 'IsElemOf' class. 
70 | -- 71 | class IsElemOf' (b :: Bool) 72 | (s1 :: FieldLabel) (a1 :: Type) 73 | (s2 :: FieldLabel) (a2 :: Type) 74 | (r :: RecordCode) 75 | where 76 | get' :: Record ( '(s2, a2) : r ) -> a1 77 | 78 | instance 79 | IsElemOf' (SameFieldLabel s1 s2) s1 a1 s2 a2 r => 80 | IsElemOf s1 a1 ( '(s2, a2) : r ) 81 | where 82 | get = get' @(SameFieldLabel s1 s2) @s1 83 | 84 | instance (a1 ~ a2) => IsElemOf' True s a1 s a2 r where 85 | get' (P a :* _) = a 86 | 87 | instance IsElemOf s1 a1 r => IsElemOf' False s1 a1 s2 a2 r where 88 | get' (_ :* r) = get @s1 r 89 | 90 | -- | Decide the equality of two field labels. 91 | -- 92 | -- Just a special case of polymorphic type equality. 93 | -- 94 | type family 95 | SameFieldLabel (s1 :: FieldLabel) (s2 :: FieldLabel) :: Bool where 96 | SameFieldLabel s1 s2 = s1 == s2 97 | -------------------------------------------------------------------------------- /haskell-torch-examples/haskell-torch-examples.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 
4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch-examples 8 | version: 0.8.0.0 9 | synopsis: Examples of how to use Haskell-Torch 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | Torch.Tutorial.Intro.T01_Basics 29 | Torch.Tutorial.Intro.T02_LinearRegression 30 | Torch.Tutorial.Intro.T03_LogisticRegression 31 | Torch.Tutorial.Intro.T04_FeedforwardNN 32 | Torch.Tutorial.Intro.T05_CNN 33 | Torch.Tutorial.Intro.T06_ResNet 34 | Torch.Tutorial.Intro.T07_RNN 35 | Torch.Tutorial.Intro.T08_BiRNN 36 | Torch.Tutorial.Intro.T09_LanguageModel 37 | Torch.Tutorial.Intro.T10_GAN 38 | Torch.Tutorial.Intro.T11_VAE 39 | Torch.Tutorial.RL.Simple 40 | Torch.Tutorial.Tensorboard 41 | other-modules: 42 | Paths_haskell_torch_examples 43 | hs-source-dirs: 44 | src 45 | build-depends: 46 | aeson 47 | , array 48 | , base >=4.7 && <5 49 | , binary 50 | , bytestring 51 | , containers 52 | , data-default 53 | , default-type-plugin 54 | , directory 55 | , docopt 56 | , extra 57 | , filepath 58 | , generics-eot 59 | , ghc-typelits-knownnat 60 | , ghc-typelits-natnormalise 61 | , gym 62 | , half 63 | , hashable 64 | , hashtables 65 | , haskell-src-exts 66 | , haskell-src-meta 67 | , haskell-torch 68 | , haskell-torch-datasets 69 | , haskell-torch-imagemagick 70 | , haskell-torch-matio 71 | , haskell-torch-models 72 | , haskell-torch-tensorboard-proto 73 | , hostname 74 | , ieee754 75 | , interpolateIO 76 | , matplotlib 77 | , megaparsec 78 | , microlens 79 | , monad-control 80 | , monad-logger 81 | , monad-loops 
82 | , mtl 83 | , parser-combinators 84 | , pipes 85 | , pipes-aeson 86 | , pipes-bytestring 87 | , pipes-concurrency 88 | , pipes-csv 89 | , pipes-extras 90 | , pipes-group 91 | , pipes-parse 92 | , pipes-safe 93 | , proto-lens 94 | , random 95 | , safe-exceptions 96 | , shelly 97 | , simplify-nat-algebra-plugin 98 | , singletons 99 | , statistics 100 | , string-qq 101 | , syb 102 | , template-haskell 103 | , text 104 | , time 105 | , unix 106 | , vector 107 | , yaml 108 | , zlib 109 | default-language: Haskell2010 110 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/Intro/T03_LogisticRegression.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, DataKinds, ExtendedDefaultRules, FlexibleContexts, FlexibleInstances, OverloadedStrings, PolyKinds #-} 2 | {-# LANGUAGE QuasiQuotes, ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies #-} 3 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 4 | module Torch.Tutorial.Intro.T03_LogisticRegression where 5 | import Control.Monad 6 | import Data.Default 7 | import Data.IORef 8 | import Data.String.InterpolateIO 9 | import Torch 10 | import Torch.Datasets 11 | 12 | ex = do 13 | let epochs = 5 14 | let learningRate = 0.001 15 | (tr, te) <- mnist "datasets/image/" 16 | (Right tes) <- fetchDataset te 17 | let testStream = batchTensors (batchSize_ @100) tes 18 | (Right trs) <- fetchDataset tr 19 | let trainStream = batchTensors (batchSize_ @100) (shuffle 1000 trs) 20 | w <- gradP 21 | let model = linear (inFeatures_ @784) (outFeatures_ @10) w 22 | let criterion y ypred = crossEntropyLoss y def def def ypred 23 | params <- toParameters w 24 | -- 25 | optimizer <- newIORef (sgd (def { sgdLearningRate = learningRate }) params) 26 | withGrad 27 | $ mapM_ (\epoch -> do 28 | forEachDataN 29 | (\d n -> do 30 | zeroGradients_ params 31 | loss <- do 32 | o <- view =<< 
dataObject d 33 | l <- view =<< dataLabel d 34 | pred <- model (sized (size_ @'[100,784]) o) 35 | criterion (sized (size_ @'[100]) l) pred 36 | backward1 loss False False 37 | step_ optimizer 38 | when (n `rem` 100 == 0) $ putStrLn =<< [c|Epoch #{epoch+1}/#{epochs} loss #{loss}|] 39 | pure ()) 40 | trainStream) 41 | [0..epochs-1] 42 | x <- withoutGrad 43 | $ foldData 44 | (\(nr, correct :: Int) d -> do 45 | o <- copy @TFloat @KCpu =<< dataObject d 46 | l <- sized (size_ @'[100]) <$> (view =<< dataLabel d) 47 | pred <- model (sized (size_ @'[100,784]) o) 48 | (_, is) <- Torch.maxDim @1 pred 49 | s <- fromIntegral <$> (fromScalar =<< Torch.sum =<< toType @TInt =<< (is .== l)) 50 | pure $ (nr + 100, correct + s)) 51 | (0, 0) 52 | trainStream 53 | putStrLn =<< [c|Training set accuracy #{100 * (fromIntegral (snd x) / fromIntegral (fst x))}% #{fromIntegral (fst x)} images|] 54 | y <- withoutGrad 55 | $ foldData 56 | (\(nr, correct) d -> do 57 | o <- copy @TFloat @KCpu =<< dataObject d 58 | l <- sized (size_ @'[100]) <$> (view =<< dataLabel d) 59 | pred <- model (sized (size_ @'[100,784]) o) 60 | (_, is) <- Torch.maxDim @1 pred 61 | s <- fromIntegral <$> (fromScalar =<< Torch.sum =<< toType @TInt =<< (is .== l)) 62 | pure (nr + 100, correct + s)) 63 | (0, 0) 64 | testStream 65 | putStrLn =<< [c|Test set accuracy #{100 * (fromIntegral (snd y) / fromIntegral (fst y))}% on #{fromIntegral (fst y)} images|] 66 | pure () 67 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: "ubuntu:18.04" 6 | steps: 7 | - run: apt update 8 | - run: apt install -y git 9 | - checkout 10 | - run: 11 | name: System-wide dependencies 12 | command: |- 13 | DEBIAN_FRONTEND=noninteractive apt-get install -y wget 14 | - restore_cache: 15 | name: Restore Cached Dependencies 16 | keys: 17 | - 
haskell-torch-anaconda-{{ .Branch }}-{{ checksum "environment.yml" }}-{{ checksum "generate-config.py" }}-linux 18 | - restore_cache: 19 | name: Restore Cached Dependencies 20 | keys: 21 | - haskell-torch-stack-{{ .Branch }}-{{ checksum "environment.yml" }}-{{ checksum "generate-config.py" }}-linux 22 | - run: 23 | name: Get Conda 24 | command: | 25 | ([ ! -d ~/anaconda ] && wget https://repo.anaconda.com/archive/Anaconda3-2019.07-Linux-x86_64.sh -O ~/anaconda.sh) || echo "Ok" 26 | # Not needed with anaconda right now, but miniconda needs it on some platforms 27 | # - run: 28 | # name: Workaround for conda bug in some containers 29 | # command: mkdir -p ~/.conda 30 | - run: 31 | name: Install Conda 32 | command: | 33 | ([ ! -d ~/anaconda ] && bash ~/anaconda.sh -b -p $HOME/anaconda) || echo "Ok" 34 | - run: 35 | name: Initialize Conda 36 | command: eval "$(~/anaconda/bin/conda shell.bash hook)" && conda init bash 37 | - run: 38 | name: Install stack 39 | command: wget -qO- https://get.haskellstack.org/ | sh 40 | - run: 41 | name: Set stack up 42 | command: stack setup 43 | - run: 44 | name: Set up package 45 | command: eval "$(~/anaconda/bin/conda shell.bash hook)" && bash setup.sh 46 | - save_cache: 47 | name: Save cache 48 | key: haskell-torch-anaconda-{{ .Branch }}-{{ checksum "environment.yml" }}-{{ checksum "generate-config.py" }}-linux 49 | paths: 50 | - "~/anaconda" 51 | - "~/anaconda.sh" 52 | - "~/.conda" 53 | - "~/.stack" 54 | - "~/.bashrc" 55 | - run: 56 | name: Build 57 | command: | 58 | eval "$(~/anaconda/bin/conda shell.bash hook)" 59 | conda activate haskell-torch 60 | export 61 | stack build --fast -j1 haskell-torch 62 | - save_cache: 63 | name: Save cache 64 | key: haskell-torch-stack-{{ .Branch }}-{{ checksum "environment.yml" }}-{{ checksum "generate-config.py" }}-linux 65 | paths: 66 | - "~/.stack" 67 | # doctests are too slow for CircleCI :( 68 | # TODO Turn the tutorials into end to end tests 69 | # - run: 70 | # name: Run tests 71 | # command: 
eval "$(~/anaconda/bin/conda shell.bash hook)" && conda activate haskell-torch && stack test --fast -j1 haskell-torch 72 | - store_artifacts: # upload build artifact for display in CircleCi 73 | # TODO Do something useful 74 | path: ~/.local/bin/circleci-demo-haskell-exe 75 | destination: haskell-torch-haskell 76 | -------------------------------------------------------------------------------- /haskell-torch/haskell-torch.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.34.4. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-torch 8 | version: 0.8.0.0 9 | synopsis: Deep learning in Haskell on top of Torch and PyTorch 10 | description: Please see the README on Github at 11 | category: AI 12 | homepage: https://github.com/abarbu/haskell-torch 13 | bug-reports: https://github.com/abarbu/haskell-torch/issues 14 | author: Andrei Barbu 15 | maintainer: andrei@0xab.com 16 | copyright: 2018 Andrei Barbu 17 | license: BSD3 18 | build-type: Simple 19 | extra-source-files: 20 | README.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/abarbu/haskell-torch 25 | 26 | library 27 | exposed-modules: 28 | Torch 29 | Torch.Datasets.Augmentation 30 | Torch.Datasets.Common 31 | Torch.Images 32 | Torch.Indexing 33 | Torch.Initialization 34 | Torch.Inplace 35 | Torch.Internal.CRC32C 36 | Torch.Misc 37 | Torch.Operators 38 | Torch.Optimizer 39 | Torch.StoredModel 40 | Torch.Tensor 41 | Torch.Tensorboard 42 | Torch.Types 43 | Torch.Visualization 44 | other-modules: 45 | Paths_haskell_torch 46 | hs-source-dirs: 47 | src 48 | build-depends: 49 | aeson 50 | , array 51 | , barbies 52 | , base >=4.7 && <5 53 | , binary 54 | , bytestring 55 | , containers 56 | , data-default 57 | , default-type-plugin 58 | , directory 59 | , distributive 60 | , docopt 61 | , extra 62 | , filepath 63 | , generics-eot 64 
| , generics-sop 65 | , ghc-typelits-knownnat 66 | , ghc-typelits-natnormalise 67 | , half 68 | , hashable 69 | , hashtables 70 | , haskell-src-exts 71 | , haskell-src-meta 72 | , haskell-torch-cbindings >=0.8.0.0 73 | , haskell-torch-imagemagick 74 | , haskell-torch-matio 75 | , haskell-torch-tensorboard-proto 76 | , hostname 77 | , ieee754 78 | , interpolateIO 79 | , matplotlib 80 | , megaparsec 81 | , microlens 82 | , monad-control 83 | , monad-logger 84 | , monad-loops 85 | , mtl 86 | , parser-combinators 87 | , pipes 88 | , pipes-aeson 89 | , pipes-bytestring 90 | , pipes-concurrency 91 | , pipes-csv 92 | , pipes-extras 93 | , pipes-group 94 | , pipes-parse 95 | , pipes-safe 96 | , proto-lens 97 | , random 98 | , safe-exceptions 99 | , shelly 100 | , simplify-nat-algebra-plugin 101 | , singletons 102 | , statistics 103 | , string-qq 104 | , syb 105 | , template-haskell 106 | , temporary 107 | , text 108 | , time 109 | , unix 110 | , vector 111 | , vector-algorithms 112 | , yaml 113 | , zlib 114 | default-language: Haskell2010 115 | 116 | test-suite doctest 117 | type: exitcode-stdio-1.0 118 | main-is: Doctest.hs 119 | hs-source-dirs: 120 | test 121 | build-depends: 122 | base >=4.7 && <5 123 | , doctest 124 | , doctest-discover 125 | default-language: Haskell2010 126 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/Intro/T04_FeedforwardNN.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, DataKinds, ExtendedDefaultRules, FlexibleContexts, FlexibleInstances, OverloadedStrings, PolyKinds #-} 2 | {-# LANGUAGE QuasiQuotes, ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies #-} 3 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 4 | 5 | module Torch.Tutorial.Intro.T04_FeedforwardNN where 6 | import Control.Monad 7 | import Data.Default 8 | import Data.IORef 9 | import 
Data.String.InterpolateIO 10 | import Torch 11 | import Torch.Datasets 12 | 13 | ex = do 14 | let epochs = 5 15 | let learningRate = 0.001 16 | -- 17 | (tr, te) <- mnist "datasets/image/" 18 | (Right tes) <- fetchDataset te 19 | let testStream = batchTensors (batchSize_ @100) tes 20 | (Right trs) <- fetchDataset tr 21 | let trainStream = batchTensors (batchSize_ @100) trs 22 | -- 23 | w1 <- gradP 24 | w2 <- gradP 25 | -- 26 | let model :: (Tensor 'TFloat 'KCpu '[100, 784] -> IO (Tensor 'TFloat 'KCpu '[100, 10])) 27 | = linear (inFeatures_ @784) (outFeatures_ @500) w1 28 | >=> relu 29 | >=> linear (inFeatures_ @500) (outFeatures_ @10) w2 30 | -- 31 | let criterion y ypred = crossEntropyLoss y def def def ypred 32 | -- 33 | w1' <- toParameters w1 34 | w2' <- toParameters w2 35 | let params = w1' ++ w2' 36 | optimizer <- newIORef (adam (def { adamLearningRate = learningRate }) params) 37 | -- 38 | withGrad 39 | $ mapM_ (\epoch -> do 40 | print (epoch :: Int) 41 | forEachDataN 42 | (\d n -> do 43 | zeroGradients_ params 44 | loss <- do 45 | o <- view =<< dataObject d 46 | l <- view =<< dataLabel d 47 | pred <- model (sized (size_ @'[100,784]) o) 48 | criterion (sized (size_ @'[100]) l) pred 49 | backward1 loss False False 50 | step_ optimizer 51 | when (n `rem` 100 == 0) $ putStrLn =<< [c|Epoch #{epoch+1}/#{epochs} loss #{loss}|] 52 | pure ()) 53 | trainStream) 54 | [0..epochs-1] 55 | x <- withoutGrad $ do 56 | foldData 57 | (\(nr, correct :: Int) d -> do 58 | o <- copy @TFloat @KCpu =<< dataObject d 59 | l <- sized (size_ @'[100]) <$> (view =<< dataLabel d) 60 | pred <- model (sized (size_ @'[100,784]) o) 61 | (_, is) <- Torch.maxDim @1 pred 62 | s <- fromIntegral <$> (fromScalar =<< Torch.sum =<< toType @TInt =<< (is .== l)) 63 | pure (nr + 100, correct + s)) 64 | (0, 0) 65 | trainStream 66 | putStrLn =<< [c|Training set accuracy #{100 * (fromIntegral (snd x) / fromIntegral (fst x))}% #{fromIntegral (fst x)} images|] 67 | y <- withoutGrad 68 | $ foldData 69 | (\(nr, 
correct :: Int) d -> do 70 | o <- copy @TFloat @KCpu =<< dataObject d 71 | l <- sized (size_ @'[100]) <$> (view =<< dataLabel d) 72 | pred <- model (sized (size_ @'[100,784]) o) 73 | (_, is) <- Torch.maxDim @1 pred 74 | s <- fromIntegral <$> (fromScalar =<< Torch.sum =<< toType @TInt =<< (is .== l)) 75 | pure (nr + 100, correct + s)) 76 | (0, 0) 77 | testStream 78 | putStrLn =<< [c|Test set accuracy #{100 * (fromIntegral (snd y) / fromIntegral (fst y))}% on #{fromIntegral (fst y)} images|] 79 | pure () 80 | -------------------------------------------------------------------------------- /nix/sources.json: -------------------------------------------------------------------------------- 1 | { 2 | "gitignore.nix": { 3 | "branch": "master", 4 | "description": "Nix functions for filtering local git sources", 5 | "homepage": "", 6 | "owner": "hercules-ci", 7 | "repo": "gitignore", 8 | "rev": "80463148cd97eebacf80ba68cf0043598f0d7438", 9 | "sha256": "1l34rmh4lf4w8a1r8vsvkmg32l1chl0p593fl12r28xx83vn150v", 10 | "type": "tarball", 11 | "url": "https://github.com/hercules-ci/gitignore/archive/80463148cd97eebacf80ba68cf0043598f0d7438.tar.gz", 12 | "url_template": "https://github.com///archive/.tar.gz" 13 | }, 14 | "haskell.nix": { 15 | "branch": "master", 16 | "description": "Alternative Haskell Infrastructure for Nixpkgs", 17 | "homepage": "https://input-output-hk.github.io/haskell.nix", 18 | "owner": "input-output-hk", 19 | "repo": "haskell.nix", 20 | "rev": "89a69afd820506f6032cd805bc18e127c2af47a5", 21 | "sha256": "0qr5wlypvxwqy8kqd7524xdbqcd9s47rhnpvsa2wf60jrs4axbb9", 22 | "type": "tarball", 23 | "url": "https://github.com/input-output-hk/haskell.nix/archive/89a69afd820506f6032cd805bc18e127c2af47a5.tar.gz", 24 | "url_template": "https://github.com///archive/.tar.gz" 25 | }, 26 | "jupyterWith.nix": { 27 | "branch": "master", 28 | "description": "declarative and reproducible Jupyter environments - powered by Nix", 29 | "homepage": "", 30 | "owner": "tweag", 31 | "repo": 
"jupyterWith", 32 | "rev": "73bdac9ca036c0303fc3a487129e23f9c4ad0bcf", 33 | "sha256": "1ypkr5xg03hgpc15iv961x4plllbvdx4bqa2fjppvhn3335hfxmn", 34 | "type": "tarball", 35 | "url": "https://github.com/tweag/jupyterWith/archive/73bdac9ca036c0303fc3a487129e23f9c4ad0bcf.tar.gz", 36 | "url_template": "https://github.com///archive/.tar.gz" 37 | }, 38 | "niv": { 39 | "branch": "master", 40 | "description": "Easy dependency management for Nix projects", 41 | "homepage": "https://github.com/nmattia/niv", 42 | "owner": "nmattia", 43 | "repo": "niv", 44 | "rev": "65a61b147f307d24bfd0a5cd56ce7d7b7cc61d2e", 45 | "sha256": "17mirpsx5wyw262fpsd6n6m47jcgw8k2bwcp1iwdnrlzy4dhcgqh", 46 | "type": "tarball", 47 | "url": "https://github.com/nmattia/niv/archive/65a61b147f307d24bfd0a5cd56ce7d7b7cc61d2e.tar.gz", 48 | "url_template": "https://github.com///archive/.tar.gz" 49 | }, 50 | "nixpkgs": { 51 | "branch": "release-20.03", 52 | "description": "Nix Packages collection", 53 | "homepage": "", 54 | "owner": "NixOS", 55 | "repo": "nixpkgs", 56 | "rev": "eb73405ecceb1dc505b7cbbd234f8f94165e2696", 57 | "sha256": "06k21wbyhhvq2f1xczszh3c2934p0m02by3l2ixvd6nkwrqklax7", 58 | "type": "tarball", 59 | "url": "https://github.com/NixOS/nixpkgs/archive/eb73405ecceb1dc505b7cbbd234f8f94165e2696.tar.gz", 60 | "url_template": "https://github.com///archive/.tar.gz" 61 | }, 62 | "stackage.nix": { 63 | "branch": "master", 64 | "description": "Automatically generated Nix expressions of Stackage snapshots", 65 | "homepage": "", 66 | "owner": "input-output-hk", 67 | "repo": "stackage.nix", 68 | "rev": "1301f5d364ed6c704103a558e49b08b63096b810", 69 | "sha256": "0l7wslsm8ipci2bsc7j87wa7f9qf5an4qpp4s15i60ij5lfyphvk", 70 | "type": "tarball", 71 | "url": "https://github.com/input-output-hk/stackage.nix/archive/1301f5d364ed6c704103a558e49b08b63096b810.tar.gz", 72 | "url_template": "https://github.com///archive/.tar.gz" 73 | } 74 | } 75 | -------------------------------------------------------------------------------- 
/haskell-torch-examples/src/Torch/Tutorial/Intro/T07_RNN.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, CPP, ConstraintKinds, DataKinds, DeriveAnyClass, DeriveGeneric, FlexibleContexts, FlexibleInstances #-} 2 | {-# LANGUAGE FunctionalDependencies, GADTs, OverloadedLabels, OverloadedStrings, PartialTypeSignatures, PolyKinds, QuasiQuotes #-} 3 | {-# LANGUAGE RankNTypes, RecordWildCards, ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies, TypeFamilyDependencies #-} 4 | {-# LANGUAGE TypeInType, TypeOperators, UndecidableInstances #-} 5 | {-# OPTIONS_GHC -fconstraint-solver-iterations=10 -fdefer-typed-holes #-} 6 | {-# OPTIONS_GHC -fplugin-opt GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-} 7 | {-# OPTIONS_GHC -Wno-partial-type-signatures #-} 8 | {-# OPTIONS_GHC -fplugin Plugin.SimplifyNat #-} 9 | 10 | module Torch.Tutorial.Intro.T07_RNN where 11 | import Control.Monad 12 | import Data.Default 13 | import Data.IORef 14 | import Data.String.InterpolateIO 15 | import GHC.Generics 16 | import Torch 17 | import Torch.Datasets 18 | 19 | data Model = Model { w1 :: LSTMParams TFloat KBest 28 128 2 False True, 20 | w2 :: LinearParam TFloat KBest 128 10 } 21 | deriving(Generic,ParameterNames,Stored,Initialize,ToTensors,ToParameters) 22 | 23 | forward :: _ -> DataPurpose -> Tensor TFloat KBest '[100, 28, 28] -> IO (Tensor TFloat KBest '[100, 10]) 24 | forward (Model{..},state1) isTraining = 25 | lstmBatchFirst (inF_ @28) (hiddenF_ @128) (nrLayers_ @2) (isBidirectional_ @False) 0 isTraining w1 state1 26 | >=> (\(a,s) -> select @1 a (-1)) 27 | >=> linear (inF_ @128) (outFeatures_ @10) w2 28 | 29 | ex = do 30 | let epochs = 2 :: Int 31 | -- 32 | (tr, te) <- mnist "datasets/image/" 33 | (Right tes) <- fetchDataset te 34 | let testStream = batchTensors (batchSize_ @100) tes 35 | (Right trs) <- fetchDataset tr 36 | let trainStream = batchTensors (batchSize_ @100) 37 | $ 
shuffle 1000 trs 38 | -- 39 | net <- gradP 40 | params <- toParameters net 41 | optimizer <- newIORef (adam (def { adamLearningRate = 0.01 }) params) 42 | -- 43 | let criterion y ypred = crossEntropyLoss y def def def ypred 44 | -- 45 | withGrad 46 | $ mapM_ (\epoch -> do 47 | forEachDataN 48 | (\d n -> do 49 | initialState <- gradP 50 | zeroGradients_ params 51 | loss <- do 52 | o <- view =<< toDevice =<< dataObject d 53 | l <- view =<< toDevice =<< dataLabel d 54 | pred <- (forward (net,initialState) (dataPurpose d) . sized (size_ @'[100, 28, 28])) =<< view (sized (size_ @'[100,784]) o) 55 | criterion (sized (size_ @'[100]) l) pred 56 | backward1 loss False False 57 | step_ optimizer 58 | when (n `rem` 100 == 0) $ putStrLn =<< [c|Epoch #{epoch+1}/#{epochs} loss #{loss}|] 59 | putStrLn =<< [c|Step1 #{epoch+1}/#{epochs} #{n+1} loss #{loss}|] 60 | pure ()) 61 | trainStream) 62 | [0..epochs-1] 63 | -- 64 | y <- withoutGrad 65 | $ foldData 66 | (\(nr, correct) d -> do 67 | initialState <- gradP 68 | images <- view =<< toDevice =<< dataObject d 69 | labels <- view =<< toDevice =<< dataLabel d 70 | pred <- forward (net,initialState) (dataPurpose d) images 71 | (_, predictionIndices) <- Torch.maxDim @1 pred 72 | nrCorrect <- fromIntegral <$> (fromScalar =<< Torch.sum =<< toType @TInt =<< (predictionIndices .== sized (size_ @'[100]) labels)) 73 | pure (nr + 100, correct + nrCorrect)) 74 | (0, 0) 75 | testStream 76 | putStrLn =<< [c|Test set accuracy #{((100 * (fromIntegral (snd y) / fromIntegral (fst y))) :: Double)}% on #{((fromIntegral (fst y))::Double)} images|] 77 | -- 78 | writeModelToFile net "lstm.ht" 79 | pure () 80 | -------------------------------------------------------------------------------- /default-type-plugin/src/Plugin/DefaultType.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses, KindSignatures, FlexibleInstances, DataKinds, PatternSynonyms, StandaloneDeriving, 
GeneralizedNewtypeDeriving, PolyKinds #-} 2 | module Plugin.DefaultType(DefaultType,plugin) where 3 | import GhcPlugins 4 | import TcRnTypes 5 | import Constraint 6 | import TcPluginM 7 | import qualified Inst 8 | import InstEnv 9 | import TcSimplify (approximateWC) 10 | import qualified Finder 11 | import Panic (panicDoc) 12 | import Data.List 13 | import TcType 14 | import qualified Data.Map as M 15 | import TyCoRep (Type(..)) 16 | import TyCon (TyCon(..)) 17 | import Control.Monad (liftM2) 18 | import GHC.TypeLits 19 | 20 | class DefaultType x (y :: x) 21 | 22 | instance Eq Type where 23 | (==) = eqType 24 | instance Ord Type where 25 | compare = nonDetCmpType 26 | instance Semigroup (TcPluginM [a]) where 27 | (<>) = liftM2 (++) 28 | instance Monoid (TcPluginM [a]) where 29 | mempty = pure mempty 30 | 31 | plugin :: Plugin 32 | plugin = defaultPlugin { 33 | defaultingPlugin = install, 34 | pluginRecompile = purePlugin 35 | } 36 | 37 | install args = Just $ DefaultingPlugin { dePluginInit = initialize 38 | , dePluginRun = run 39 | , dePluginStop = stop 40 | } 41 | 42 | pattern FoundModule :: Module -> FindResult 43 | pattern FoundModule a <- Found _ a 44 | fr_mod :: a -> a 45 | fr_mod = id 46 | 47 | lookupModule :: ModuleName -- ^ Name of the module 48 | -> TcPluginM Module 49 | lookupModule mod_nm = do 50 | hsc_env <- TcPluginM.getTopEnv 51 | found_module <- TcPluginM.tcPluginIO $ Finder.findPluginModule hsc_env mod_nm 52 | case found_module of 53 | FoundModule h -> return (fr_mod h) 54 | _ -> do 55 | found_module' <- TcPluginM.findImportedModule mod_nm $ Just $ fsLit "this" 56 | case found_module' of 57 | FoundModule h -> return (fr_mod h) 58 | _ -> panicDoc "Unable to resolve module looked up by plugin: " 59 | (ppr mod_nm) 60 | 61 | data PluginState = PluginState { defaultClassName :: Name } 62 | 63 | -- | Find a 'Name' in a 'Module' given an 'OccName' 64 | lookupName :: Module -> OccName -> TcPluginM Name 65 | lookupName md occ = lookupOrig md occ 66 | 67 | 
solveDefaultType :: PluginState -> [Ct] -> TcPluginM DefaultingPluginResult 68 | solveDefaultType _ [] = return [] 69 | solveDefaultType state wanteds = do 70 | envs <- getInstEnvs 71 | insts <- classInstances envs <$> tcLookupClass (defaultClassName state) 72 | let defaults = foldl' (\m inst -> 73 | case is_tys inst of 74 | [matchty, replacety] -> 75 | M.insertWith (++) matchty [replacety] m) M.empty insts 76 | let groups = 77 | foldl' (\m wanted -> 78 | foldl' (\m var -> M.insertWith (++) var [wanted] m) 79 | m 80 | (filter (isVariableDefaultable defaults) $ tyCoVarsOfCtList wanted)) 81 | M.empty wanteds 82 | M.foldMapWithKey (\var cts -> 83 | case M.lookup (tyVarKind var) defaults of 84 | Nothing -> error "Bug, we already checked that this variable has a default" 85 | Just deftys -> do 86 | pure [(deftys, (var, cts))]) 87 | groups 88 | where isVariableDefaultable defaults v = isAmbiguousTyVar v && M.member (tyVarKind v) defaults 89 | 90 | lookupDefaultTypes :: TcPluginM PluginState 91 | lookupDefaultTypes = do 92 | md <- lookupModule (mkModuleName "Plugin.DefaultType") 93 | name <- lookupName md (mkTcOcc "DefaultType") 94 | pure $ PluginState { defaultClassName = name } 95 | 96 | initialize = do 97 | lookupDefaultTypes 98 | 99 | run s ws = do 100 | solveDefaultType s (ctsElts $ approximateWC False ws) 101 | 102 | stop _ = do 103 | return () 104 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/Intro/T08_BiRNN.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, CPP, ConstraintKinds, DataKinds, DeriveAnyClass, DeriveGeneric, FlexibleContexts, FlexibleInstances #-} 2 | {-# LANGUAGE FunctionalDependencies, GADTs, OverloadedLabels, OverloadedStrings, PartialTypeSignatures, PolyKinds, QuasiQuotes #-} 3 | {-# LANGUAGE RankNTypes, RecordWildCards, ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies, 
TypeFamilyDependencies #-} 4 | {-# LANGUAGE TypeInType, TypeOperators, UndecidableInstances #-} 5 | {-# OPTIONS_GHC -fconstraint-solver-iterations=10 -fdefer-typed-holes #-} 6 | {-# OPTIONS_GHC -fplugin-opt GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-} 7 | {-# OPTIONS_GHC -Wno-partial-type-signatures #-} 8 | 9 | module Torch.Tutorial.Intro.T08_BiRNN where 10 | import Control.Monad 11 | import Data.Default 12 | import Data.IORef 13 | import Data.String.InterpolateIO 14 | import GHC.Generics 15 | import qualified GHC.TypeLits as TL 16 | import Torch 17 | import Torch.Datasets 18 | 19 | data Model = Model { w1 :: LSTMParams TFloat KBest 28 128 2 'True 'True, 20 | w2 :: LinearParam TFloat KBest (2 TL.* 128) 10 } 21 | deriving(Generic,ParameterNames,Stored,Initialize,ToTensors,ToParameters) 22 | 23 | forward :: _ -> DataPurpose -> Tensor TFloat KBest '[100, 28, 28] -> IO (Tensor TFloat KBest '[100, 10]) 24 | forward (Model{..},state1) isTraining = do 25 | lstmBatchFirst (inF_ @28) (hiddenF_ @128) (nrLayers_ @2) (isBidirectional_ @True) 0 isTraining w1 state1 26 | >=> (\(a,s) -> select @1 a (-1)) 27 | >=> linear (inF_ @(2 TL.* 128)) (outFeatures_ @10) w2 28 | 29 | ex = do 30 | let epochs = 2 :: Int 31 | let learningRate = 0.001 32 | -- 33 | (tr, te) <- mnist "datasets/image/" 34 | (Right tes) <- fetchDataset te 35 | let testStream = batchTensors (batchSize_ @100) tes 36 | (Right trs) <- fetchDataset tr 37 | let trainStream = batchTensors (batchSize_ @100) 38 | $ shuffle 1000 trs 39 | -- 40 | net <- gradP 41 | params <- toParameters net 42 | optimizer <- newIORef (adam (def { adamLearningRate = learningRate }) params) 43 | -- 44 | let criterion y ypred = crossEntropyLoss y def def def ypred 45 | -- 46 | withGrad 47 | $ mapM_ (\epoch -> 48 | forEachDataN 49 | (\d n -> do 50 | initialState <- gradP 51 | zeroGradients_ params 52 | loss <- do 53 | o <- view =<< toDevice =<< dataObject d 54 | l <- view =<< toDevice =<< dataLabel d 55 | pred <- (forward 
(net,initialState) (dataPurpose d) . sized (size_ @'[100, 28, 28])) =<< view (sized (size_ @'[100,784]) o) 56 | criterion (sized (size_ @'[100]) l) pred 57 | backward1 loss False False 58 | step_ optimizer 59 | when (n `rem` 100 == 0) $ putStrLn =<< [c|Epoch #{epoch+1}/#{epochs} loss #{loss}|] 60 | putStrLn =<< [c|Step1 #{epoch+1}/#{epochs} #{n+1} loss #{loss}|] 61 | pure ()) 62 | trainStream) 63 | [0..epochs-1] 64 | -- 65 | y <- withoutGrad 66 | $ foldData 67 | (\(nr, correct) d -> do 68 | initialState <- gradP 69 | images <- view =<< toDevice =<< dataObject d 70 | labels <- view =<< toDevice =<< dataLabel d 71 | pred <- forward (net,initialState) (dataPurpose d) images 72 | (_, predictionIndices) <- Torch.maxDim @1 pred 73 | nrCorrect <- fromIntegral <$> (fromScalar =<< Torch.sum =<< toType @TInt =<< (predictionIndices .== sized (size_ @'[100]) labels)) 74 | pure (nr + 100, correct + nrCorrect)) 75 | (0, 0) 76 | testStream 77 | putStrLn =<< [c|Test set accuracy #{((100 * (fromIntegral (snd y) / fromIntegral (fst y))) :: Double)}% on #{((fromIntegral (fst y))::Double)} images|] 78 | -- 79 | writeModelToFile net "lstm.ht" 80 | pure () 81 | -------------------------------------------------------------------------------- /haskell-torch-datasets/src/Torch/Datasets/Vision/CIFAR.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, DataKinds, FlexibleContexts, FlexibleInstances, GADTs, KindSignatures #-} 2 | {-# LANGUAGE OverloadedLabels, OverloadedStrings, PartialTypeSignatures, PolyKinds, RankNTypes, ScopedTypeVariables, TypeApplications #-} 3 | {-# LANGUAGE TypeFamilyDependencies, TypeInType, TypeOperators, UndecidableInstances #-} 4 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-} 5 | {-# OPTIONS_GHC -fdefer-type-errors #-} 6 | 7 | -- CIFAR10 Dataset 8 | 9 | module Torch.Datasets.Vision.CIFAR where 10 | import Control.Monad.Except 11 | 
import Data.Map.Strict (Map) 12 | import qualified Data.Map.Strict as M 13 | import Data.Singletons 14 | import Data.Text (Text) 15 | import qualified Data.Text as T 16 | import qualified Data.Text.IO as T 17 | import qualified Data.Text.Read as T 18 | import qualified Foreign.ImageMagick as I 19 | import Pipes 20 | import Torch.Datasets.Common 21 | import Torch.Images 22 | import Torch.Misc 23 | import Torch.Tensor 24 | import Torch.Types 25 | 26 | type CIFAR dataPurpose = Dataset (Map Int Text) dataPurpose Int (Image 3 32 32) (Tensor TLong KCpu '[]) 27 | 28 | cifar10 :: Path -> IO (CIFAR Train ,CIFAR Test) 29 | cifar10 path = 30 | pure (remoteDatasetCIFAR @'Train path filename directory url md5 31 | ,remoteDatasetCIFAR @'Test path filename directory url md5) 32 | where url = "https://pjreddie.com/media/files/cifar.tgz" 33 | filename = "cifar.tgz" 34 | md5 = "a00ceaeb02303e3ff0d0011b38b465fa" 35 | directory = "cifar" 36 | 37 | remoteDatasetCIFAR :: forall (dataPurpose :: DataPurpose). (SingI dataPurpose) 38 | => Text -> Text -> Text -> Text -> Text 39 | -> CIFAR dataPurpose 40 | remoteDatasetCIFAR root filename directory url md5 = 41 | Dataset 42 | { checkDataset = canFail checkAll 43 | , fetchDataset = canFail $ fetchAll False >> access @dataPurpose 44 | , forceFetchDataset = canFail $ fetchAll True >> access @dataPurpose 45 | , accessDataset = canFail (access @dataPurpose) 46 | , metadataDataset = canFail (fst <$> metadata) 47 | } 48 | where path = root filename 49 | dirpath x = root directory x 50 | checkAll = checkMD5 path md5 51 | fetchAll force = do 52 | liftIO $ createDirectoryIfMissing' root 53 | checkMD5 path md5 `retryAfter` downloadUrl url path 54 | e <- liftIO $ doesDirectoryExist (root directory) 55 | when (force || not e) $ extractTar path root 56 | metadata = do 57 | c <- liftIO $ T.readFile (T.unpack $ dirpath "labels.txt") 58 | pure (M.fromList $ zipWith (\l n -> (n,l)) (T.lines c) [0..] 
59 | ,M.fromList $ zipWith (\l n -> (l,n)) (T.lines c) [0..]) 60 | pipeCifar :: forall dataPurpose. (SingI dataPurpose) 61 | => M.Map Text Int 62 | -> Pipe Text (DataSample dataPurpose Int (Image 3 32 32) (Tensor TLong KCpu '[])) IO () 63 | pipeCifar labtonr = forever $ do 64 | fname <- await 65 | let [nrtext, label] = T.splitOn "_" $ dropExtension $ takeBaseName $ fname 66 | let (Right (nr, "")) = T.decimal nrtext 67 | case M.lookup label labtonr of 68 | Nothing -> error $ "CIFAR: Don't know what the label of this image should be! " ++ show fname 69 | Just labelnr -> 70 | yield $ DataSample @dataPurpose nr (readImageFromFile fname) 71 | (toScalar $ fromIntegral labelnr) 72 | access :: forall dataPurpose. (SingI dataPurpose) 73 | => CanFail (DataStream dataPurpose Int (Image 3 32 32) (Tensor TLong KCpu '[])) 74 | access = do 75 | liftIO $ I.initialize 76 | (_,labtonr) <- metadata 77 | case demote @dataPurpose of 78 | Test -> 79 | pure $ pipeListDirectory (dirpath "test") 80 | >-> pipeCifar @dataPurpose labtonr 81 | Train -> 82 | pure $ pipeListDirectory (dirpath "train") 83 | >-> pipeCifar @dataPurpose labtonr 84 | -------------------------------------------------------------------------------- /haskell-torch-models/src/Torch/Models/Vision/AlexNet.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, DataKinds, DeriveAnyClass, DeriveGeneric, ExtendedDefaultRules, FlexibleContexts #-} 2 | {-# LANGUAGE FlexibleInstances, GADTs, KindSignatures, MultiParamTypeClasses, OverloadedStrings, PolyKinds, RankNTypes #-} 3 | {-# LANGUAGE ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies, TypeFamilyDependencies, TypeInType, TypeOperators #-} 4 | {-# LANGUAGE UndecidableInstances #-} 5 | {-# OPTIONS_GHC -fconstraint-solver-iterations=50 #-} 6 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-} 7 | 8 | -- | AlexNet using the pretrained 
weights from torchvision. Don't forget to use 9 | -- @standardRGBNormalization@ when running this on any images. 10 | 11 | module Torch.Models.Vision.AlexNet where 12 | import Control.Monad 13 | import Data.Singletons 14 | import Data.Singletons.Prelude.Ord 15 | import Generics.Eot 16 | import GHC.TypeLits as TL 17 | import Torch 18 | import Torch.Datasets 19 | 20 | data AlexNet ki = AlexNet 21 | (ConvParam TFloat ki 64 '[3, 11, 11] 22 | ,ConvParam TFloat ki 192 '[64, 5, 5] 23 | ,ConvParam TFloat ki 384 '[192, 3, 3] 24 | ,ConvParam TFloat ki 256 '[384, 3, 3] 25 | ,ConvParam TFloat ki 256 '[256, 3, 3]) 26 | (LinearParam TFloat ki 9216 4096 27 | ,LinearParam TFloat ki 4096 4096, 28 | LinearParam TFloat ki 4096 1000) 29 | deriving (Generic,ParameterNames,Stored) 30 | 31 | pathToAlexNet = getCachedFileOrDownload "https://www.mediafire.com/file/4p6acunb5ykle5a/alexnet.pt/file" "ba9248ae47a1887ed6623cc568a9755e" "alexnet.pt" "models" 32 | 33 | loadAlexNet :: IO (AlexNet KCpu) 34 | loadAlexNet = do 35 | fname <- pathToAlexNet 36 | m <- readStoredModel fname 37 | a <- loadWithNames m ((("features.0.weight", "features.0.bias"), 38 | ("features.3.weight", "features.3.bias"), 39 | ("features.6.weight", "features.6.bias"), 40 | ("features.8.weight", "features.8.bias"), 41 | ("features.10.weight", "features.10.bias")), 42 | (("classifier.1.weight", "classifier.1.bias"), 43 | ("classifier.4.weight", "classifier.4.bias"), 44 | ("classifier.6.weight", "classifier.6.bias"))) 45 | pure a 46 | 47 | alexNetFeatures (w1, w2, w3, w4, w5) = 48 | conv2d (inChannels_ @3) (outChannels_ @64) (kernel_ @'(11,11)) (stride_ @'(4,4)) (padding_ @'(2,2)) (dilation_ @'(1,1)) (groups_ @1) w1 49 | >=> relu_ 50 | >=> maxPool2d (kernel_ @'(3,3)) (stride_ @'(2,2)) (padding_ @'(0,0)) (dilation_ @'(1,1)) (ceilMode_ @False) 51 | >=> pure . 
fst 52 | >=> conv2d (inChannels_ @64) (outChannels_ @192) (kernel_ @'(5,5)) (stride_ @'(1,1)) (padding_ @'(2,2)) (dilation_ @'(1,1)) (groups_ @1) w2 53 | >=> relu_ 54 | >=> maxPool2d (kernel_ @'(3,3)) (stride_ @'(2,2)) (padding_ @'(0,0)) (dilation_ @'(1,1)) (ceilMode_ @False) 55 | >=> pure . fst 56 | >=> conv2d (inChannels_ @192) (outChannels_ @384) (kernel_ @'(3,3)) (stride_ @'(1,1)) (padding_ @'(1,1)) (dilation_ @'(1,1)) (groups_ @1) w3 57 | >=> relu_ 58 | >=> conv2d (inChannels_ @384) (outChannels_ @256) (kernel_ @'(3,3)) (stride_ @'(1,1)) (padding_ @'(1,1)) (dilation_ @'(1,1)) (groups_ @1) w4 59 | >=> relu_ 60 | >=> conv2d (inChannels_ @256) (outChannels_ @256) (kernel_ @'(3,3)) (stride_ @'(1,1)) (padding_ @'(1,1)) (dilation_ @'(1,1)) (groups_ @1) w5 61 | >=> relu_ 62 | >=> maxPool2d (kernel_ @'(3,3)) (stride_ @'(2,2)) (padding_ @'(0,0)) (dilation_ @'(1,1)) (ceilMode_ @False) 63 | >=> pure . fst 64 | 65 | alexNetClassifier (w1, w2, w3) dataPurpose = 66 | dropout 0.5 dataPurpose 67 | >=> linear (inFeatures_ @(256 TL.* 6 TL.* 6)) (outFeatures_ @4096) w1 68 | >=> relu_ 69 | >=> dropout 0.5 dataPurpose 70 | >=> linear (inFeatures_ @4096) (outFeatures_ @4096) w2 71 | >=> relu_ 72 | >=> linear (inFeatures_ @4096) (outFeatures_ @1000) w3 73 | 74 | alexNetForward (AlexNet w1s w2s) dataPurpose = 75 | alexNetFeatures w1s 76 | >=> adaptiveAvgPool2d (outFeatures_ @'[6,6]) 77 | >=> flatten 78 | >=> alexNetClassifier w2s dataPurpose 79 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/Intro/T05_CNN.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, DataKinds, DeriveAnyClass, DeriveGeneric, ExtendedDefaultRules, FlexibleContexts, FlexibleInstances #-} 2 | {-# LANGUAGE OverloadedStrings, PolyKinds, QuasiQuotes, RecordWildCards, ScopedTypeVariables, TemplateHaskell, TypeApplications #-} 3 | {-# LANGUAGE TypeFamilies, TypeOperators 
#-} 4 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver -fplugin Plugin.SimplifyNat -fconstraint-solver-iterations=10000 #-} 5 | 6 | module Torch.Tutorial.Intro.T05_CNN where 7 | import Control.Monad 8 | import Data.Default 9 | import Data.IORef 10 | import Data.String.InterpolateIO 11 | import GHC.Generics 12 | import GHC.TypeNats 13 | import Torch 14 | import Torch.Datasets 15 | 16 | data SmallCNN = SmallCNN { 17 | w1 :: ConvParam TFloat KCpu 16 '[1, 5, 5] 18 | , w2 :: AffineParam TFloat KCpu '[16] 19 | , w3 :: ConvParam TFloat KCpu 32 '[16, 5, 5] 20 | , w4 :: AffineParam TFloat KCpu '[32] 21 | , w5 :: LinearParam TFloat KCpu 1568 10 22 | , bn1 :: BatchNormState 'TFloat 'KCpu '[16] 23 | , bn2 :: BatchNormState 'TFloat 'KCpu '[32] 24 | } 25 | deriving(Generic,ParameterNames,Stored,Initialize,ToTensors,ToParameters) 26 | 27 | forward :: SmallCNN -> DataPurpose -> Tensor 'TFloat 'KCpu '[100, 1, 28, 28] -> IO (Tensor 'TFloat 'KCpu '[100, 10]) 28 | forward SmallCNN{..} isTraining = 29 | conv2d InChannels (outChannels_ @16) (kernel_ @'(5,5)) (stride_ @'(1,1)) (padding_ @'(2,2)) (dilation_ @'(1,1)) (groups_ @1) w1 30 | >=> batchNorm2d_ bn1 (Just w2) def def isTraining 31 | >=> relu 32 | >=> maxPool2d (kernel_ @'(2,2)) (stride_ @'(2,2)) (padding_ @'(0,0)) (dilation_ @'(1,1)) (ceilMode_ @False) 33 | >=> pure . fst 34 | >=> conv2d InChannels (outChannels_ @32) (kernel_ @'(5,5)) (stride_ @'(1,1)) (padding_ @'(2,2)) (dilation_ @'(1,1)) (groups_ @1) w3 35 | >=> batchNorm2d_ bn2 (Just w4) def def isTraining 36 | >=> relu 37 | >=> maxPool2d (kernel_ @'(2,2)) (stride_ @'(2,2)) (padding_ @'(0,0)) (dilation_ @'(1,1)) (ceilMode_ @False) 38 | >=> pure . 
fst 39 | >=> view @'[100, 1568] 40 | >=> linear (inFeatures_ @1568) (outFeatures_ @10) w5 41 | 42 | ex = do 43 | let epochs = 2 44 | let learningRate = 0.001 45 | -- 46 | (tr, te) <- mnist "datasets/image/" 47 | (Right tes) <- fetchDataset te 48 | let testStream = batchTensors (batchSize_ @100) tes 49 | (Right trs) <- fetchDataset tr 50 | let trainStream = batchTensors (batchSize_ @100) 51 | $ shuffle 1000 trs 52 | -- 53 | net <- gradP @SmallCNN 54 | params <- toParameters net 55 | optimizer <- newIORef (adam (def { adamLearningRate = learningRate }) params) 56 | -- 57 | let criterion y ypred = crossEntropyLoss y def def def ypred 58 | -- 59 | withGrad 60 | $ mapM_ (\epoch -> 61 | forEachDataN 62 | (\d n -> do 63 | loss <- do 64 | o <- view =<< toCpu =<< dataObject d 65 | l <- view =<< toCpu =<< dataLabel d 66 | pred <- (forward net (dataPurpose d) . sized (size_ @'[100, 1, 28, 28])) =<< view (sized (size_ @'[100,784]) o) 67 | criterion (sized (size_ @'[100]) l) pred 68 | zeroGradients_ params 69 | backward1 loss False False 70 | step_ optimizer 71 | when (n `rem` 100 == 0) $ putStrLn =<< [c|Epoch #{epoch+1}/#{epochs} loss #{loss}|] 72 | putStrLn =<< [c|Step #{epoch+1}/#{epochs} #{n+1} loss #{loss}|] 73 | pure ()) 74 | trainStream) 75 | [0..epochs-1] 76 | -- 77 | y <- withoutGrad 78 | $ foldData 79 | (\(nr, correct) d -> do 80 | o <- copy @TFloat @KCpu =<< dataObject d 81 | l <- sized (size_ @'[100]) <$> (view =<< toCpu =<< dataLabel d) 82 | pred <- (forward net (dataPurpose d) . 
sized (size_ @'[100, 1, 28, 28])) =<< view (sized (size_ @'[100,784]) o) 83 | (_, is) <- Torch.maxDim @1 pred 84 | s <- fromIntegral <$> (fromScalar =<< Torch.sum =<< toType @TInt =<< (is .== l)) 85 | pure (nr + 100, correct + s)) 86 | (0, 0) 87 | testStream 88 | putStrLn =<< [c|Test set accuracy #{100 * (fromIntegral (snd y) / fromIntegral (fst y))}% on #{fromIntegral (fst y)} images|] 89 | pure () 90 | -------------------------------------------------------------------------------- /haskell-torch-tensorboard-proto/proto/tensorboard/src/summary.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "SummaryProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboard/src/tensor.proto"; 10 | 11 | // Metadata associated with a series of Summary data 12 | message SummaryDescription { 13 | // Hint on how plugins should process the data in this series. 14 | // Supported values include "scalar", "histogram", "image", "audio" 15 | string type_hint = 1; 16 | } 17 | 18 | // Serialization format for histogram module in 19 | // core/lib/histogram/histogram.h 20 | message HistogramProto { 21 | double min = 1; 22 | double max = 2; 23 | double num = 3; 24 | double sum = 4; 25 | double sum_squares = 5; 26 | 27 | // Parallel arrays encoding the bucket boundaries and the bucket values. 28 | // bucket(i) is the count for the bucket i. The range for 29 | // a bucket is: 30 | // i == 0: -DBL_MAX .. bucket_limit(0) 31 | // i != 0: bucket_limit(i-1) .. bucket_limit(i) 32 | repeated double bucket_limit = 6 [packed = true]; 33 | repeated double bucket = 7 [packed = true]; 34 | }; 35 | 36 | // A SummaryMetadata encapsulates information on which plugins are able to make 37 | // use of a certain summary value. 
38 | message SummaryMetadata { 39 | message PluginData { 40 | // The name of the plugin this data pertains to. 41 | string plugin_name = 1; 42 | 43 | // The content to store for the plugin. The best practice is for this JSON 44 | // string to be the canonical JSON serialization of a protocol buffer 45 | // defined by the plugin. Converting that protobuf to and from JSON is the 46 | // responsibility of the plugin code, and is not enforced by 47 | // TensorFlow/TensorBoard. 48 | string content = 2; 49 | } 50 | 51 | // A list of plugin data. A single summary value instance may be used by more 52 | // than 1 plugin. 53 | repeated PluginData plugin_data = 1; 54 | }; 55 | 56 | // A Summary is a set of named values to be displayed by the 57 | // visualizer. 58 | // 59 | // Summaries are produced regularly during training, as controlled by 60 | // the "summary_interval_secs" attribute of the training operation. 61 | // Summaries are also produced at the end of an evaluation. 62 | message Summary { 63 | message Image { 64 | // Dimensions of the image. 65 | int32 height = 1; 66 | int32 width = 2; 67 | // Valid colorspace values are 68 | // 1 - grayscale 69 | // 2 - grayscale + alpha 70 | // 3 - RGB 71 | // 4 - RGBA 72 | // 5 - DIGITAL_YUV 73 | // 6 - BGRA 74 | int32 colorspace = 3; 75 | // Image data in encoded format. All image formats supported by 76 | // image_codec::CoderUtil can be stored here. 77 | bytes encoded_image_string = 4; 78 | } 79 | 80 | message Audio { 81 | // Sample rate of the audio in Hz. 82 | float sample_rate = 1; 83 | // Number of channels of audio. 84 | int64 num_channels = 2; 85 | // Length of the audio in frames (samples per channel). 86 | int64 length_frames = 3; 87 | // Encoded audio data and its associated RFC 2045 content type (e.g. 88 | // "audio/wav"). 
89 | bytes encoded_audio_string = 4; 90 | string content_type = 5; 91 | } 92 | 93 | message Value { 94 | // Name of the node that output this summary; in general, the name of a 95 | // TensorSummary node. If the node in question has multiple outputs, then 96 | // a ":\d+" suffix will be appended, like "some_op:13". 97 | // Might not be set for legacy summaries (i.e. those not using the tensor 98 | // value field) 99 | string node_name = 7; 100 | 101 | // Tag name for the data. Will only be used by legacy summaries 102 | // (ie. those not using the tensor value field) 103 | // For legacy summaries, will be used as the title of the graph 104 | // in the visualizer. 105 | // 106 | // Tag is usually "op_name:value_name", where "op_name" itself can have 107 | // structure to indicate grouping. 108 | string tag = 1; 109 | SummaryMetadata metadata = 9; 110 | // Value associated with the tag. 111 | oneof value { 112 | float simple_value = 2; 113 | bytes obsolete_old_style_histogram = 3; 114 | Image image = 4; 115 | HistogramProto histo = 5; 116 | Audio audio = 6; 117 | TensorProto tensor = 8; 118 | } 119 | } 120 | 121 | // Set of values for the summary. 
122 | repeated Value value = 1; 123 | } 124 | -------------------------------------------------------------------------------- /Plugins.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a76a50e7-fa95-4cdc-86e4-45fa3d8cfb58", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | ":set -XFlexibleContexts -XFlexibleInstances\n", 11 | ":set -XPolyKinds -XRankNTypes -XScopedTypeVariables -XTypeApplications -XTypeFamilies\n", 12 | ":set -XTypeFamilyDependencies -XTypeInType -XTypeOperators -XUndecidableInstances\n", 13 | "import Data.Proxy\n", 14 | "import GHC.TypeLits" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "01fc228d-bf31-4756-858d-bf532e152866", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "let f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer; f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "04bae7f2-3443-444e-9432-083409dc2d21", 30 | "metadata": {}, 31 | "source": [ 32 | "This isn't supposed to work" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "bb3eed7e-7167-4133-8ff8-e7e982c1bd94", 38 | "metadata": {}, 39 | "source": [ 40 | "let f :: forall n . 
KnownNat n => Proxy n -> Integer; f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 19, 46 | "id": "81fe3842-bf7a-461f-b11e-12e1b2ab5964", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | ":set -fplugin=GHC.TypeLits.KnownNat.Solver" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "a41ca15e-5de6-4320-97e6-3a3ce9b06518", 56 | "metadata": {}, 57 | "source": [ 58 | "This should work now" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 20, 64 | "id": "e56d9088-061c-40bd-a7d7-2707a401ac4c", 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "ename": "", 69 | "evalue": "", 70 | "header": "MessageHeader {mhIdentifiers = [\"17991769-231d-47e7-aadd-8368a29084e1\"], mhParentHeader = Just (MessageHeader {mhIdentifiers = [\"17991769-231d-47e7-aadd-8368a29084e1\"], mhParentHeader = Nothing, mhMetadata = Metadata (fromList [(\"recordTiming\",Bool False),(\"deletedCells\",Array []),(\"cellId\",String \"e56d9088-061c-40bd-a7d7-2707a401ac4c\")]), mhMessageId = UUID {uuidToString = \"fc830778-70d8-4c91-bf2a-2f9e1c724ee2\"}, mhSessionId = UUID {uuidToString = \"17991769-231d-47e7-aadd-8368a29084e1\"}, mhUsername = \"\", mhMsgType = ExecuteRequestMessage, mhBuffers = []}), mhMetadata = Metadata (fromList []), mhMessageId = UUID {uuidToString = \"3bd3b7bb-47da-4dff-8472-f25e5292ddb5\"}, mhSessionId = UUID {uuidToString = \"17991769-231d-47e7-aadd-8368a29084e1\"}, mhUsername = \"\", mhMsgType = ExecuteErrorMessage, mhBuffers = []}", 71 | "output_type": "error", 72 | "traceback": [ 73 | ":1:89: error:\n • Could not deduce (KnownNat (n + 2)) arising from a use of ‘natVal’\n from the context: KnownNat n\n bound by the type signature for:\n f :: forall (n :: Nat). 
KnownNat n => Proxy n -> Integer\n at :1:5-52\n • In the second argument of ‘(+)’, namely ‘natVal (Proxy :: Proxy (n + 2))’\n In the expression: natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n + 2))\n In an equation for ‘f’: f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n + 2))" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "let f :: forall n . KnownNat n => Proxy n -> Integer; f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "53bbd76b-8f3b-4d42-9f89-dffbf9ffceb5", 84 | "metadata": { 85 | "tags": [] 86 | }, 87 | "source": [ 88 | "import GHC.TypeLits.KnownNat.Solver" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "66cc2535-7f6f-434c-a5f7-64c922e7a079", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | ":set -fplugin GHC.TypeLits.KnownNat.Solver" 99 | ] 100 | } 101 | ], 102 | "metadata": { 103 | "jupytext": { 104 | "cell_metadata_filter": "-all" 105 | }, 106 | "kernelspec": { 107 | "display_name": "Haskell - haskell", 108 | "language": "haskell", 109 | "name": "ihaskell_haskell" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": "Haskell", 113 | "file_extension": ".hs", 114 | "mimetype": "text/x-haskell", 115 | "name": "haskell", 116 | "pygments_lexer": "Haskell", 117 | "version": "8.10.7" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 5 122 | } 123 | -------------------------------------------------------------------------------- /haskell-torch-cbindings/src/Torch/C/Types.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings, QuasiQuotes, ScopedTypeVariables, TemplateHaskell #-} 2 | 3 | -- | All of the types we need to talk to PyTorch. These will be wrapped and not 4 | -- shown to users. 
5 | module Torch.C.Types where 6 | import qualified Data.Map as Map 7 | import Data.Monoid (mempty, (<>)) 8 | import Foreign.C.Types 9 | import qualified Language.C.Inline as C 10 | import qualified Language.C.Inline.Context as C 11 | import qualified Language.C.Inline.Cpp as C 12 | import qualified Language.C.Types as C 13 | import qualified Language.Haskell.TH as TH 14 | 15 | C.context C.cppCtx 16 | 17 | data CStorage 18 | data CGenerator 19 | data CVariableType 20 | data CType 21 | data CDevice 22 | data CVariable 23 | 24 | -- Tensors 25 | data CTensorOptions 26 | data CTensor 27 | data CScalar 28 | 29 | -- Jit 30 | data CTracingState 31 | data CGraph 32 | data CEdge 33 | data CNode 34 | data CJitNode 35 | data CJitValue 36 | data CJitIValue 37 | data CJitBlock 38 | data CJitAttributeKind 39 | data CJitScriptModule 40 | 41 | tensorCtx :: C.Context 42 | tensorCtx = C.cppCtx <> C.funCtx <> C.vecCtx <> C.fptrCtx <> ctx 43 | where ctx = mempty 44 | { C.ctxTypesTable = tensorTypesTable } 45 | 46 | tensorTypesTable :: Map.Map C.TypeSpecifier TH.TypeQ 47 | tensorTypesTable = Map.fromList 48 | [ (C.TypeName "bool", [t| C.CBool |]) 49 | -- tensors 50 | , (C.TypeName "TensorOptions", [t| CTensorOptions |]) 51 | , (C.TypeName "Tensor", [t| CTensor |]) 52 | , (C.TypeName "Scalar", [t| CScalar |]) 53 | , (C.TypeName "Storage", [t| CStorage |]) 54 | , (C.TypeName "Generator", [t| CGenerator |]) 55 | , (C.TypeName "JitType", [t| CType |]) 56 | , (C.TypeName "Device", [t| CDevice |]) 57 | -- variables 58 | , (C.TypeName "VariableType", [t| CVariableType|]) 59 | , (C.TypeName "Variable", [t| CVariable |]) 60 | , (C.TypeName "Edge", [t| CEdge |]) 61 | , (C.TypeName "TracingState", [t| CTracingState |]) 62 | , (C.TypeName "Graph", [t| CGraph |]) 63 | , (C.TypeName "Node", [t| CNode |]) 64 | , (C.TypeName "JitNode", [t| CJitNode |]) 65 | , (C.TypeName "JitValue", [t| CJitValue |]) 66 | , (C.TypeName "JitIValue", [t| CJitIValue |]) 67 | , (C.TypeName "JitBlock", [t| CJitBlock |]) 68 
| , (C.TypeName "JitAttributeKind", [t| CJitAttributeKind |]) 69 | , (C.TypeName "JitScriptModule", [t| CJitScriptModule |]) 70 | ] 71 | 72 | data Backend = BackendCPU 73 | | BackendCUDA 74 | deriving (Show, Eq) 75 | 76 | data Layout = LayoutStrided 77 | | LayoutSparse 78 | | LayoutMlkdnn 79 | deriving (Show, Eq) 80 | 81 | data ScalarType = ScalarTypeBool 82 | | ScalarTypeByte 83 | | ScalarTypeChar 84 | | ScalarTypeShort 85 | | ScalarTypeInt 86 | | ScalarTypeLong 87 | | ScalarTypeHalf 88 | | ScalarTypeFloat 89 | | ScalarTypeDouble 90 | | ScalarTypeUndefined 91 | deriving (Show, Eq, Ord) 92 | 93 | data TypeKind = TypeKindAny 94 | | TypeKindTensor 95 | | TypeKindTuple 96 | | TypeKindList 97 | | TypeKindDict 98 | | TypeKindNumber 99 | | TypeKindFloat 100 | | TypeKindFuture 101 | | TypeKindInt 102 | | TypeKindNone 103 | | TypeKindString 104 | | TypeKindGenerator 105 | | TypeKindBool 106 | | TypeKindOptional 107 | | TypeKindVar 108 | | TypeKindDeviceObj 109 | | TypeKindFunction 110 | | TypeKindClass 111 | | TypeKindCapsule 112 | | TypeKindInterface 113 | deriving (Show, Eq) 114 | 115 | data Reduction = ReductionNone 116 | | ReductionMean 117 | | ReductionSum 118 | deriving (Show, Eq) 119 | 120 | data MemoryFormat = MemoryFormatContiguous 121 | | MemoryFormatPreserve 122 | | MemoryFormatChannelsLast 123 | | MemoryFormatChannelsLast3d 124 | deriving (Show, Eq) 125 | 126 | data AttributeKind = AttributeKindFloat 127 | | AttributeKindFloats 128 | | AttributeKindInt 129 | | AttributeKindInts 130 | | AttributeKindString 131 | | AttributeKindStrings 132 | | AttributeKindTensor 133 | | AttributeKindTensors 134 | | AttributeKindGraph 135 | | AttributeKindGraphs 136 | deriving (Show, Eq) 137 | 138 | data ModuleEntityType = ModuleEntityType 139 | | ParameterEntityType 140 | | MethodEntityType 141 | deriving (Show, Eq) 142 | -------------------------------------------------------------------------------- /interpolateIO/src/Data/String/ShowIO.hs: 
-------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables, MultiParamTypeClasses #-} 2 | {-# LANGUAGE FlexibleContexts, DefaultSignatures #-} 3 | {-# LANGUAGE FlexibleInstances, UndecidableInstances #-} 4 | 5 | module Data.String.ShowIO ( 6 | -- * Just like Show but in IO. Takes the tedium out of printing impure values. 7 | ShowIOS, ShowIO, showIOs, showIO, showsPrecIO, showListIO, showStringIO, IsStringIO, fromStringIO 8 | ) where 9 | 10 | import Data.String 11 | import Generics.Eot as GE 12 | 13 | -- | The @showsIO@ functions return a function that prepends the 14 | -- output 'String' to an existing 'String'. This allows constant-time 15 | -- concatenation of results using lifted function composition @>=>@. 16 | type ShowIOS = String -> IO String 17 | 18 | -- | equivalent to 'showsPrec' with a precedence of 0. 19 | showIOs :: (ShowIO a) => a -> ShowIOS 20 | showIOs = showsPrecIO 0 21 | 22 | -- | utility function converting a 'String' to a show function that 23 | -- simply prepends the string unchanged. 24 | showStringIO :: String -> ShowIOS 25 | showStringIO s s' = pure $ s ++ s' 26 | 27 | -- | Conversion of values to readable 'IO String's. 28 | class ShowIO a where 29 | -- | Convert a value to a readable 'IO String'. 30 | -- 31 | -- 'showsPrecIO' should satisfy the law 32 | -- 33 | -- > do 34 | -- > sr <- showsPrecIO d x r 35 | -- > srs <- showsPrecIO d x (r ++ s) 36 | -- > pure $ sr ++ s == srs 37 | 38 | showsPrecIO :: Int -- ^ the operator precedence of the enclosing 39 | -- context (a number from @0@ to @11@). 40 | -- Function application has precedence @10@. 41 | -> a -- ^ the value to be converted to a 'IO String' 42 | -> ShowIOS 43 | 44 | -- | A specialised variant of 'showsPrecIO', using precedence context 45 | -- zero, and returning an ordinary 'IO String'.
46 | showIO :: a -> IO String 47 | default showIO :: (HasEot a, EotShowIO GE.Datatype (Eot a)) => a -> IO String 48 | showIO a = eotShowIO (datatype (Proxy :: Proxy a)) (toEot a) 49 | 50 | -- | The method 'showListIO' is provided to allow the programmer to 51 | -- give a specialised way of showing lists of values. 52 | -- For example, this is used by the predefined 'Show' instance of 53 | -- the 'Char' type, where values of type 'String' should be shown 54 | -- in double quotes, rather than between square brackets. 55 | showListIO :: [a] -> ShowIOS 56 | 57 | showsPrecIO _ x s = (++ s) <$> showIO x 58 | showListIO ls s = showListIO__ showIOs ls s 59 | 60 | instance {-# OVERLAPS #-} Show a => ShowIO a where 61 | showIO x = pure $ show x 62 | 63 | instance {-# OVERLAPS #-} ShowIO a => ShowIO (IO a) where 64 | showIO x = x >>= showIO 65 | 66 | showListIO__ :: (a -> ShowIOS) -> [a] -> ShowIOS 67 | showListIO__ _ [] s = pure $ "[]" ++ s 68 | showListIO__ showx (x:xs) s = ('[' :) <$> (showx x =<< showl xs) 69 | where 70 | showl [] = pure (']' : s) 71 | showl (y:ys) = (',' :) <$> (showx y =<< showl ys) 72 | 73 | class IsStringIO a where 74 | fromStringIO :: String -> IO a 75 | 76 | instance IsString a => IsStringIO a where 77 | fromStringIO x = pure $ fromString x 78 | 79 | class EotShowIO meta eot where 80 | eotShowIO :: meta -> eot -> IO String 81 | 82 | instance (EotShowIO [GE.Constructor] a) => EotShowIO GE.Datatype a where 83 | eotShowIO meta a = eotShowIO (constructors meta) a 84 | 85 | instance (EotShowIO [String] this, EotShowIO Int this, EotShowIO [GE.Constructor] next) 86 | => EotShowIO [GE.Constructor] (Either this next) where 87 | eotShowIO (m:_) (Left this) = case m of 88 | Constructor con (Selectors fieldNames) -> 89 | (\x -> con <> " { " <> x <> " } ") <$> eotShowIO fieldNames this 90 | Constructor con (NoSelectors nr) -> 91 | (\x -> con <> " { " <> x <> " } ") <$> eotShowIO nr this 92 | Constructor con NoFields -> 93 | pure con 94 | eotShowIO (_:ms) (Right 
next) = eotShowIO ms next 95 | eotShowIO [] _ = error "Impossible" 96 | 97 | instance {-# OVERLAPS #-} ShowIO x => EotShowIO Int (x, ()) where 98 | eotShowIO _ (x, ()) = showIO x 99 | 100 | instance {-# OVERLAPS #-} (ShowIO x, EotShowIO Int xs) => EotShowIO Int (x, xs) where 101 | eotShowIO n (x, xs) = do 102 | ps <- showIO x 103 | ps' <- eotShowIO (n + 1) xs 104 | pure $ ps <> " " <> ps' 105 | 106 | instance {-# OVERLAPS #-} ShowIO x => EotShowIO [String] (x, ()) where 107 | eotShowIO [name] (x, ()) = (\v -> name <> " = " <> v) <$> showIO x 108 | eotShowIO _ _ = error "Impossible" 109 | 110 | instance {-# OVERLAPS #-} (ShowIO x, EotShowIO [String] xs) => EotShowIO [String] (x, xs) where 111 | eotShowIO (name:names) (x, xs) = do 112 | ps <- showIO x 113 | ps' <- eotShowIO names xs 114 | pure $ name <> " = " <> ps <> ", " <> ps' 115 | eotShowIO [] _ = error "Impossible" 116 | 117 | instance EotShowIO meta () where 118 | eotShowIO _ _ = pure "" 119 | 120 | instance EotShowIO meta Void where 121 | eotShowIO _ _ = pure "" 122 | -------------------------------------------------------------------------------- /haskell-torch-examples/src/Torch/Tutorial/Intro/T11_VAE.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes, CPP, ConstraintKinds, DataKinds, DeriveAnyClass, DeriveGeneric, FlexibleContexts, FlexibleInstances #-} 2 | {-# LANGUAGE FunctionalDependencies, GADTs, OverloadedLabels, OverloadedStrings, PartialTypeSignatures, PolyKinds, QuasiQuotes #-} 3 | {-# LANGUAGE RankNTypes, RecordWildCards, ScopedTypeVariables, TemplateHaskell, TypeApplications, TypeFamilies, TypeFamilyDependencies #-} 4 | {-# LANGUAGE TypeInType, TypeOperators, UndecidableInstances #-} 5 | {-# OPTIONS_GHC -fconstraint-solver-iterations=10 -fdefer-typed-holes #-} 6 | {-# OPTIONS_GHC -fplugin-opt GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-} 7 | {-# OPTIONS_GHC -Wno-partial-type-signatures #-} 8 | 9 | module 
Torch.Tutorial.Intro.T11_VAE where 10 | import Control.Monad 11 | import Control.Monad.Extra 12 | import Data.Default 13 | import Data.IORef 14 | import Data.String.InterpolateIO 15 | import System.Directory as D 16 | import Torch 17 | import Torch.Tensor as T 18 | import Torch.Datasets 19 | 20 | type ImageSize = 784 21 | type HDim = 400 22 | type ZDim = 20 23 | type BatchSz = 128 24 | 25 | encode :: _ => _ 26 | -> Tensor TFloat ki '[batch, ImageSize] 27 | -> IO (Tensor TFloat ki '[batch, ZDim], Tensor TFloat ki '[batch, ZDim]) 28 | encode (w1,w2,w3,_,_) x = do 29 | h <- linear (inF_ @ImageSize) (outF_ @HDim) w1 x 30 | a <- linear (inF_ @HDim) (outF_ @ZDim) w2 h 31 | b <- linear (inF_ @HDim) (outF_ @ZDim) w3 h 32 | pure (a,b) 33 | 34 | decode :: _ 35 | => _ 36 | -> Tensor TFloat ki '[batch, ZDim] 37 | -> IO (Tensor TFloat ki '[batch, ImageSize]) 38 | decode (_,_,_,w4,w5) x = do 39 | h <- relu =<< linear (inF_ @ZDim) (outF_ @HDim) w4 x 40 | sigmoid =<< linear (inF_ @HDim) (outF_ @ImageSize) w5 h 41 | 42 | reparameterize :: _ 43 | => Tensor TFloat ki '[batch, ZDim] 44 | -> Tensor TFloat ki '[batch, ZDim] 45 | -> IO (Tensor TFloat ki '[batch, ZDim]) 46 | reparameterize mu logVar = do 47 | std <- T.exp =<< logVar ./@ 2 48 | eps <- randn 49 | pure mu ..+ (like std eps .* std) 50 | 51 | forward ws x = do 52 | (mu, logVar) <- encode ws x 53 | z <- reparameterize mu logVar 54 | x' <- decode ws z 55 | pure (x', mu, logVar) 56 | 57 | ex = do 58 | let epochs = 200 :: Int 59 | -- 60 | whenM (D.doesDirectoryExist "generated-images") $ 61 | D.removeDirectoryRecursive "generated-images" 62 | D.createDirectory "generated-images" 63 | net <- gradP 64 | params <- toParameters net 65 | optimizer <- newIORef (adam (def { adamLearningRate = 1e-3 }) params) 66 | -- 67 | (tr, _) <- mnist "datasets/image/" 68 | (Right trs) <- fetchDataset tr 69 | let trainStream = batchTensors (batchSize_ @BatchSz) 70 | $ shuffle 5000 trs 71 | -- 72 | withGrad 73 | $ mapM_ 74 | (\epoch -> 75 | 
forEachDataN 76 | (\d n -> do 77 | images <- reshape =<< dataObject d 78 | (images', mu, logVar) <- forward net images 79 | -- Compute reconstruction loss and kl divergence 80 | -- For KL divergence, see Appendix B in VAE paper 81 | -- or http://yunjey47.tistory.com/43 82 | reconstructionLoss <- binaryCrossEntropyLoss images def (SizeAverage False) images' 83 | -- TODO This could be a lot cleaner 84 | klDivergence <- 85 | ((-0.5) @*. ) =<< T.sum =<< (1 @+. logVar ..- (T.pow mu =<< toScalar 2) ..- T.exp logVar) 86 | -- Backprop and optimize 87 | loss <- reconstructionLoss .+ klDivergence 88 | zeroGradients_ params 89 | backward1 loss False False 90 | step_ optimizer 91 | when (n `rem` 100 == 0) $ 92 | putStrLn =<< [c|Epoch #{epoch+1}/#{epochs} Reconstruction loss #{reconstructionLoss} KL divergence #{klDivergence}|] 93 | putStrLn =<< [c|Step #{n+1} #{epoch+1}/#{epochs} Reconstruction loss #{reconstructionLoss} KL divergence #{klDivergence}|] 94 | when (n `rem` 100 == 0) $ withoutGrad $ do 95 | -- Sample an image from the latent space 96 | z <- sized (size_ @'[BatchSz, ZDim]) <$> randn 97 | sampledImages <- decode net z 98 | writeGreyTensorToFile ("generated-images/sampled-"<>show' epoch<>"@"<>show' n<>".jpg") 99 | =<< makeGreyGrid (size_ @8) (padding_ @2) 0 100 | =<< reshape @'[BatchSz, 1, 28, 28] sampledImages 101 | -- Reconstruct images from the dataset 102 | (images', _, _) <- forward net images 103 | is <- reshape @'[BatchSz, 1, 28, 28] images 104 | is' <- reshape @'[BatchSz, 1, 28, 28] images' 105 | writeGreyTensorToFile ("generated-images/reconstructed-"<>show' epoch<>"@"<>show' n<>".jpg") 106 | =<< makeGreyGrid (size_ @8) (padding_ @2) 0 107 | =<< cat2 @0 is is' 108 | pure () 109 | pure ()) 110 | trainStream) 111 | [0..epochs-1] 112 | pure () 113 | --------------------------------------------------------------------------------