├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── _config.yml ├── bench ├── bench-lstm.hs └── bench.hs ├── cbits ├── gradient_descent.c ├── gradient_descent.h ├── im2col.c ├── im2col.h ├── pad.c └── pad.h ├── examples ├── LICENSE ├── grenade-examples.cabal ├── mafia └── main │ ├── feedforward.hs │ ├── gan-mnist.hs │ ├── mnist.hs │ ├── recurrent.hs │ └── shakespeare.hs ├── framework └── mafia ├── grenade.cabal ├── mafia ├── src ├── Grenade.hs └── Grenade │ ├── Core.hs │ ├── Core │ ├── Layer.hs │ ├── LearningParameters.hs │ ├── Network.hs │ ├── Runner.hs │ └── Shape.hs │ ├── Layers.hs │ ├── Layers │ ├── Concat.hs │ ├── Convolution.hs │ ├── Crop.hs │ ├── Deconvolution.hs │ ├── Dropout.hs │ ├── Elu.hs │ ├── FullyConnected.hs │ ├── Inception.hs │ ├── Internal │ │ ├── Convolution.hs │ │ ├── Pad.hs │ │ ├── Pooling.hs │ │ └── Update.hs │ ├── Logit.hs │ ├── Merge.hs │ ├── Pad.hs │ ├── Pooling.hs │ ├── Relu.hs │ ├── Reshape.hs │ ├── Sinusoid.hs │ ├── Softmax.hs │ ├── Tanh.hs │ └── Trivial.hs │ ├── Recurrent.hs │ ├── Recurrent │ ├── Core.hs │ ├── Core │ │ ├── Layer.hs │ │ ├── Network.hs │ │ └── Runner.hs │ ├── Layers.hs │ └── Layers │ │ ├── BasicRecurrent.hs │ │ ├── ConcatRecurrent.hs │ │ └── LSTM.hs │ └── Utils │ └── OneHot.hs ├── stack.yaml ├── stack.yaml.lock └── test ├── Test ├── Grenade │ ├── Layers │ │ ├── Convolution.hs │ │ ├── FullyConnected.hs │ │ ├── Internal │ │ │ ├── Convolution.hs │ │ │ ├── Pooling.hs │ │ │ └── Reference.hs │ │ ├── Nonlinear.hs │ │ ├── PadCrop.hs │ │ └── Pooling.hs │ ├── Network.hs │ └── Recurrent │ │ └── Layers │ │ ├── LSTM.hs │ │ └── LSTM │ │ └── Reference.hs └── Hedgehog │ ├── Compat.hs │ ├── Hmatrix.hs │ └── TypeLits.hs └── test.hs /.gitignore: -------------------------------------------------------------------------------- 1 | cabal.project.local 2 | .cabal-sandbox/ 3 | cabal.sandbox.config 4 | dist/ 5 | dist-newstyle/ 6 | .stack-work/ 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # NB: don't set `language: haskell` here 2 | 3 | # The following enables several GHC versions to be tested; often it's enough to test only against the last release in a major GHC version. Feel free to omit lines listings versions you don't need/want testing for. 4 | env: 5 | - GHCVER=8.0.2 6 | - GHCVER=8.2.2 7 | - GHCVER=8.4.4 8 | - GHCVER=8.6.5 9 | - GHCVER=8.8.3 10 | - GHCVER=8.10.1 11 | 12 | # Note: the distinction between `before_install` and `install` is not important. 13 | before_install: 14 | - travis_retry sudo add-apt-repository -y ppa:hvr/ghc 15 | - travis_retry sudo apt-get update 16 | - travis_retry sudo apt-get install cabal-install-3.2 ghc-$GHCVER libblas-dev liblapack-dev 17 | - export PATH=/opt/cabal/bin:/opt/ghc/$GHCVER/bin:$PATH 18 | 19 | install: 20 | - echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]" 21 | - travis_retry cabal-3.2 update 22 | 23 | notifications: 24 | email: false 25 | 26 | # Here starts the actual work to be performed for the package under test; any command which exits with a non-zero exit code causes the build to fail. 
27 | script: 28 | - cabal-3.2 configure --enable-tests 29 | - cabal-3.2 build all 30 | - cabal-3.2 test --test-show-details=direct -j1 31 | 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016-2017, Huw Campbell 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Grenade 2 | ======= 3 | 4 | [![Build Status](https://api.travis-ci.org/HuwCampbell/grenade.svg?branch=master)](https://travis-ci.org/HuwCampbell/grenade) 5 | [![Hackage page (downloads and API reference)][hackage-png]][hackage] 6 | [![Hackage-Deps][hackage-deps-png]][hackage-deps] 7 | 8 | 9 | ``` 10 | First shalt thou take out the Holy Pin, then shalt thou count to three, no more, no less. 11 | Three shall be the number thou shalt count, and the number of the counting shall be three. 12 | Four shalt thou not count, neither count thou two, excepting that thou then proceed to three. 13 | Five is right out. 14 | ``` 15 | 16 | 💣 Machine learning which might blow up in your face 💣 17 | 18 | Grenade is a composable, dependently typed, practical, and fast recurrent neural network library 19 | for concise and precise specifications of complex networks in Haskell. 20 | 21 | As an example, a network which can achieve ~1.5% error on MNIST can be 22 | specified and initialised with random weights in a few lines of code with 23 | ```haskell 24 | type MNIST 25 | = Network 26 | '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu 27 | , Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu 28 | , FullyConnected 256 80, Logit, FullyConnected 80 10, Logit] 29 | '[ 'D2 28 28 30 | , 'D3 24 24 10, 'D3 12 12 10 , 'D3 12 12 10 31 | , 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256 32 | , 'D1 80, 'D1 80, 'D1 10, 'D1 10] 33 | 34 | randomMnist :: MonadRandom m => m MNIST 35 | randomMnist = randomNetwork 36 | ``` 37 | 38 | And that's it. 
Because the types are so rich, there's no specific term level code 39 | required to construct this network, although it is of course possible and 40 | easy to construct and deconstruct the networks and layers explicitly oneself. 41 | 42 | If recurrent neural networks are more your style, you can try defining something 43 | ["unreasonably effective"](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) 44 | with 45 | ```haskell 46 | type Shakespeare 47 | = RecurrentNetwork 48 | '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit] 49 | '[ 'D1 40, 'D1 80, 'D1 40, 'D1 40, 'D1 40 ] 50 | ``` 51 | 52 | Design 53 | ------ 54 | 55 | Networks in Grenade can be thought of as heterogeneous lists of layers, where 56 | their type includes not only the layers of the network, but also the shapes of 57 | data that are passed between the layers. 58 | 59 | The definition of a network is surprisingly simple: 60 | ```haskell 61 | data Network :: [*] -> [Shape] -> * where 62 | NNil :: SingI i 63 | => Network '[] '[i] 64 | 65 | (:~>) :: (SingI i, SingI h, Layer x i h) 66 | => !x 67 | -> !(Network xs (h ': hs)) 68 | -> Network (x ': xs) (i ': h ': hs) 69 | ``` 70 | 71 | The `Layer x i o` constraint ensures that the layer `x` can sensibly perform a 72 | transformation between the input and output shapes `i` and `o`. 73 | 74 | The lifted data kind `Shape` defines our 1, 2, and 3 dimension types, used to 75 | declare what shape of data is passed between the layers. 76 | 77 | In the MNIST example above, the input layer can be seen to be a two dimensional 78 | (`D2`) image with 28 by 28 pixels. When the first *Convolution* layer runs, it 79 | outputs a three dimensional (`D3`) 24x24x10 image. The last item in the list is 80 | one dimensional (`D1`) with 10 values, representing the categories of the MNIST 81 | data. 82 | 83 | Usage 84 | ----- 85 | 86 | To perform back propagation, one can call the eponymous function 87 | ```haskell 88 | backPropagate :: forall shapes layers. 89 | Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Gradients layers 90 | ``` 91 | which takes a network, appropriate input and target data, and returns the 92 | back propagated gradients for the network. The shapes of the gradients are 93 | appropriate for each layer, and may be trivial for layers like `Relu` which 94 | have no learnable parameters. 95 | 96 | The gradients, however, can always be applied, yielding a new (hopefully better) 97 | network with 98 | ```haskell 99 | applyUpdate :: LearningParameters -> Network ls ss -> Gradients ls -> Network ls ss 100 | ``` 101 | 102 | Layers in Grenade are ordinary data types with an instance of the `Layer` class, so creating one's own is 103 | easy in downstream code. If the shapes of a network are not specified correctly 104 | and a layer cannot sensibly perform the operation between two shapes, then 105 | the result is a compile time error. 106 | 107 | Composition 108 | ----------- 109 | 110 | Networks and Layers in Grenade are easily composed at the type level. As a `Network` 111 | is an instance of `Layer`, one can use a trained network as a small component in a 112 | larger network. Furthermore, we provide two layers which are designed to run 113 | layers in parallel and merge their output (either by concatenating them across one 114 | dimension or by pointwise addition of their activations). This allows one to 115 | write any network which can be expressed as a 116 | [series-parallel graph](https://en.wikipedia.org/wiki/Series-parallel_graph).
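As one example, the (deliberately overengineered) MNIST network in `examples/main/mnist.hs` uses `Concat` to run an inception style branch next to a `Trivial` pass-through and stack their channels. A sketch of that block, with an illustrative alias name, looks like:

```haskell
-- One branch passes the 'D3 28 28 1 input through untouched (Trivial), the
-- other runs an inception style stack producing 14 channels, and Concat
-- stacks the two outputs into a single 'D3 28 28 15 image.
type InceptionBlock
  = Concat ('D3 28 28 1) Trivial
           ('D3 28 28 14) (InceptionMini 28 28 1 5 9)
```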
117 | 118 | A residual network layer specification for instance could be written as 119 | ```haskell 120 | type Residual net = Merge Trivial net 121 | ``` 122 | If the type `net` is an instance of `Layer`, then `Residual net` will be too. It will 123 | run the network, while retaining its input by passing it through the `Trivial` layer, 124 | and merge the original input with the output. 125 | 126 | See the [MNIST](https://github.com/HuwCampbell/grenade/blob/master/examples/main/mnist.hs) 127 | example, which has been overengineered to contain both residual style learning 128 | and inception style convolutions. 129 | 130 | Generative Adversarial Networks 131 | ------------------------------- 132 | 133 | As Grenade is purely functional, one can compose its training functions in flexible 134 | ways. The [GAN-MNIST](https://github.com/HuwCampbell/grenade/blob/master/examples/main/gan-mnist.hs) 135 | example displays an interesting, type-safe way of writing a generative adversarial 136 | training function in 10 lines of code. 137 | 138 | Layer Zoo 139 | --------- 140 | 141 | Grenade layers are normal Haskell data types which are an instance of `Layer`, so 142 | it's easy to build one's own in downstream code. We do however provide a decent set 143 | of layers, including convolution, deconvolution, pooling, pad, crop, logit, relu, 144 | elu, tanh, and fully connected. 145 | 146 | Build Instructions 147 | ------------------ 148 | Grenade is most easily built with the [mafia](https://github.com/ambiata/mafia) 149 | script that is located in the repository. You will also need the `lapack` and 150 | `blas` libraries and development tools. Once you have all that, Grenade can be 151 | built using: 152 | 153 | ``` 154 | ./mafia build 155 | ``` 156 | 157 | and the tests run using: 158 | 159 | ``` 160 | ./mafia test 161 | ``` 162 | 163 | Grenade builds with GHC 7.10, 8.0, 8.2, and 8.4. 164 | 165 | Thanks 166 | ------ 167 | Writing a library like this has been on my mind for a while now, but a big shout 168 | out must go to [Justin Le](https://github.com/mstksg), whose 169 | [dependently typed fully connected network](https://blog.jle.im/entry/practical-dependent-types-in-haskell-1.html) 170 | inspired me to get cracking, gave many ideas for the type level tools I 171 | needed, and was a great starting point for writing this library. 172 | 173 | Performance 174 | ----------- 175 | Grenade is backed by hmatrix, BLAS, and LAPACK, with critical functions optimised 176 | in C. Using the im2col trick popularised by Caffe, it should be sufficient for 177 | many problems. 178 | 179 | Being purely functional, it should also be easy to run batches in parallel, which 180 | would be appropriate for larger networks; my current examples, however, are single 181 | threaded. 182 | 183 | Training 15 generations over Kaggle's 41000 sample MNIST training set on a single 184 | core took around 12 minutes, achieving a 1.5% error rate on a 1000 sample holdout set. 185 | 186 | Contributing 187 | ------------ 188 | Contributions are welcome. 
189 | 190 | [hackage]: http://hackage.haskell.org/package/grenade 191 | [hackage-png]: http://img.shields.io/hackage/v/grenade.svg 192 | [hackage-deps]: http://packdeps.haskellers.com/reverse/grenade 193 | [hackage-deps-png]: https://img.shields.io/hackage-deps/v/grenade.svg 194 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /bench/bench-lstm.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE BangPatterns #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | import Criterion.Main 5 | 6 | import Grenade 7 | import Grenade.Recurrent 8 | 9 | main :: IO () 10 | main = do 11 | input40 :: S ('D1 40) <- randomOfShape 12 | lstm :: RecNet <- randomRecurrent 13 | 14 | defaultMain [ 15 | bgroup "train" [ bench "one-time-step" $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)] 16 | , bench "ten-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 10 (input40, Just input40) 17 | , bench "fifty-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 50 (input40, Just input40) 18 | ] 19 | ] 20 | 21 | nfT2 :: (a, b) -> (a, b) 22 | nfT2 (!a, !b) = (a, b) 23 | 24 | 25 | type R = Recurrent 26 | type RecNet = RecurrentNetwork '[ R (LSTM 40 512), R (LSTM 512 40) ] 27 | '[ 'D1 40, 'D1 512, 'D1 40 ] 28 | 29 | lp :: LearningParameters 30 | lp = LearningParameters 0.1 0 0 31 | -------------------------------------------------------------------------------- /bench/bench.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | import Criterion.Main 4 | 5 | import Grenade 6 | 7 | import Grenade.Layers.Internal.Convolution 8 | import Grenade.Layers.Internal.Pooling 9 | 10 | import Numeric.LinearAlgebra 11 | 12 | main :: IO () 13 | main = do 14 | x :: S ('D2 60 60 ) <- randomOfShape 15 | y :: S ('D3 60 60 1) <- randomOfShape 16 | 17 | defaultMain [ 18 | bgroup "im2col" [ bench "im2col 3x4" $ whnf (im2col 2 2 1 1) ((3><4) [1..]) 19 | , bench "im2col 28x28" $ whnf (im2col 5 5 1 1) ((28><28) [1..]) 20 | , bench "im2col 100x100" $ whnf (im2col 10 10 1 1) ((100><100) [1..]) 21 | ] 22 | , bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2im 2 2 1 1 3 4) ((6><4) [1..]) 23 | , bench "col2im 28x28" $ whnf (col2im 5 5 1 1 28 28) ((576><25) [1..]) 24 | , bench "col2im 100x100" $ whnf (col2im 10 10 1 1 100 100) ((8281><100) [1..]) 25 | ] 26 | , bgroup "poolfw" [ bench "poolforwards 3x4" $ whnf (poolForward 1 3 4 2 2 1 1) ((3><4) [1..]) 27 | , bench "poolforwards 28x28" $ whnf (poolForward 1 28 28 5 5 1 1) ((28><28) [1..]) 28 | , bench "poolforwards 100x100" $ whnf (poolForward 1 100 100 10 10 1 1) ((100><100) [1..]) 29 | ] 30 | , bgroup "poolbw" [ bench "poolbackwards 3x4" $ whnf (poolBackward 1 3 4 2 2 1 1 ((3><4) [1..])) ((2><3) [1..]) 31 | , bench "poolbackwards 28x28" $ whnf (poolBackward 1 28 28 5 5 1 1 ((28><28) [1..])) ((24><24) [1..]) 32 | , bench "poolbackwards 100x100" $ whnf (poolBackward 1 100 100 10 10 1 1 ((100><100) [1..])) ((91><91) [1..]) 33 | ] 34 | , bgroup "padcrop" [ bench "pad 2D 60x60" $ whnf (testRun2D Pad) x 35 | , bench "pad 3D 60x60" $ whnf (testRun3D Pad) y 36 | , bench "crop 2D 60x60" $ whnf (testRun2D' Crop) x 37 | , bench "crop 3D 60x60" $ whnf 
(testRun3D' Crop) y 38 | ] 39 | ] 40 | 41 | 42 | testRun2D :: Pad 1 1 1 1 -> S ('D2 60 60) -> S ('D2 62 62) 43 | testRun2D = snd ... runForwards 44 | 45 | testRun3D :: Pad 1 1 1 1 -> S ('D3 60 60 1) -> S ('D3 62 62 1) 46 | testRun3D = snd ... runForwards 47 | 48 | testRun2D' :: Crop 1 1 1 1 -> S ('D2 60 60) -> S ('D2 58 58) 49 | testRun2D' = snd ... runForwards 50 | 51 | testRun3D' :: Crop 1 1 1 1 -> S ('D3 60 60 1) -> S ('D3 58 58 1) 52 | testRun3D' = snd ... runForwards 53 | 54 | (...) :: (a -> b) -> (c -> d -> a) -> c -> d -> b 55 | (...) = (.) . (.) 56 | -------------------------------------------------------------------------------- /cbits/gradient_descent.c: -------------------------------------------------------------------------------- 1 | #include "gradient_descent.h" 2 | 3 | void descend_cpu(int len, double rate, double momentum, double regulariser, 4 | const double* weights, 5 | const double* gradient, 6 | const double* last, 7 | double* outputWeights, double* outputMomentum) { 8 | 9 | for (int i = 0; i < len; i++) { 10 | outputMomentum[i] = momentum * last[i] - rate * gradient[i]; 11 | outputWeights[i] = weights[i] + outputMomentum[i] - (rate * regulariser) * weights[i]; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /cbits/gradient_descent.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void descend_cpu(int len, double rate, double momentum, double regulariser, 5 | const double* weights, 6 | const double* gradient, 7 | const double* last, 8 | double* outputWeights, double* outputMomentum); 9 | 10 | -------------------------------------------------------------------------------- /cbits/im2col.c: -------------------------------------------------------------------------------- 1 | #include "im2col.h" 2 | 3 | void im2col_cpu(const double* data_im, const int channels, 4 | const int height, const int width, const int kernel_h, const int kernel_w, 5 | const int stride_h, const int stride_w, 6 | double* data_col) { 7 | 8 | const int channel_size = height * width; 9 | 10 | for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) { 11 | for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) { 12 | for (int channel = 0; channel < channels; channel++) { 13 | for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { 14 | for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { 15 | int input_row = fitting_height + kernel_row; 16 | int input_col = fitting_width + kernel_col; 17 | *(data_col++) = data_im[input_row * width + input_col + channel_size * channel]; 18 | } 19 | } 20 | } 21 | } 22 | } 23 | } 24 | 25 | void col2im_cpu(const double* data_col, const int channels, 26 | const int height, const int width, const int kernel_h, const int kernel_w, 27 | const int stride_h, const int stride_w, 28 | double* data_im) { 29 | 30 | memset(data_im, 0, height * width * channels * sizeof(double)); 31 | 32 | const int channel_size = height * width; 33 | 34 | for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) { 35 | for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) { 36 | for (int channel = 0; channel < channels; channel++) { 37 | for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { 38 | for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { 39 | int input_row = fitting_height + 
kernel_row; 40 | int input_col = fitting_width + kernel_col; 41 | data_im[input_row * width + input_col + channel_size * channel] += *(data_col++); 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | inline double max ( double a, double b ) { return a > b ? a : b; } 50 | 51 | void pool_forwards_cpu(const double* data_im, const int channels, 52 | const int height, const int width, const int kernel_h, const int kernel_w, 53 | const int stride_h, const int stride_w, 54 | double* data_pooled) { 55 | 56 | const int channel_size = height * width; 57 | 58 | for (int channel = 0; channel < channels; channel++) { 59 | for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) { 60 | for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) { 61 | // Start with the value in 0,0 62 | int max_index = fitting_height * width + fitting_width + channel_size * channel; 63 | double max_value = data_im[max_index]; 64 | // Initial row, skipping the corner we've done 65 | for (int kernel_col = 1; kernel_col < kernel_w; kernel_col++) { 66 | int input_row = fitting_height; 67 | int input_col = fitting_width + kernel_col; 68 | int data_index = input_row * width + input_col + channel_size * channel; 69 | double data_value = data_im[data_index]; 70 | max_value = max ( max_value, data_value ); 71 | } 72 | // The remaining rows 73 | for (int kernel_row = 1; kernel_row < kernel_h; kernel_row++) { 74 | for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { 75 | int input_row = fitting_height + kernel_row; 76 | int input_col = fitting_width + kernel_col; 77 | int data_index = input_row * width + input_col + channel_size * channel; 78 | double data_value = data_im[data_index]; 79 | max_value = max ( max_value, data_value ); 80 | } 81 | } 82 | *(data_pooled++) = max_value; 83 | } 84 | } 85 | } 86 | } 87 | 88 | void pool_backwards_cpu(const double* data_im, const double* data_pooled, 89 | const int channels, const int height, const int width, const int kernel_h, 90 | const int kernel_w, const int stride_h, const int stride_w, 91 | double* data_backgrad ) { 92 | 93 | memset(data_backgrad, 0, height * width * channels * sizeof(double)); 94 | 95 | const int channel_size = height * width; 96 | 97 | for (int channel = 0; channel < channels; channel++) { 98 | for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) { 99 | for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) { 100 | int max_index = fitting_height * width + fitting_width + channel_size * channel; 101 | double max_value = data_im[max_index]; 102 | for (int kernel_col = 1; kernel_col < kernel_w; kernel_col++) { 103 | int input_row = fitting_height; 104 | int input_col = fitting_width + kernel_col; 105 | int data_index = input_row * width + input_col + channel_size * channel; 106 | double data_value = data_im[data_index]; 107 | if ( data_value > max_value ) { 108 | max_index = data_index; 109 | max_value = data_value; 110 | } 111 | } 112 | for (int kernel_row = 1; kernel_row < kernel_h; kernel_row++) { 113 | for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { 114 | int input_row = fitting_height + kernel_row; 115 | int input_col = fitting_width + kernel_col; 116 | int data_index = input_row * width + input_col + channel_size * channel; 117 | double data_value = data_im[data_index]; 118 | if ( data_value > max_value ) { 119 | max_index = data_index; 120 | max_value = data_value; 121 | } 122 | } 123 | 
} 124 | data_backgrad[max_index] += *(data_pooled++); 125 | } 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /cbits/im2col.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | void im2col_cpu(const double* data_im, const int channels, 6 | const int height, const int width, const int kernel_h, const int kernel_w, 7 | const int stride_h, const int stride_w, 8 | double* data_col); 9 | 10 | void col2im_cpu(const double* data_col, const int channels, 11 | const int height, const int width, const int kernel_h, const int kernel_w, 12 | const int stride_h, const int stride_w, 13 | double* data_im); 14 | 15 | void pool_forwards_cpu(const double* data_im, const int channels, 16 | const int height, const int width, const int kernel_h, const int kernel_w, 17 | const int stride_h, const int stride_w, 18 | double* data_pooled); 19 | 20 | void pool_backwards_cpu(const double* data_im, const double* data_pooled, 21 | const int channels, const int height, const int width, const int kernel_h, 22 | const int kernel_w, const int stride_h, const int stride_w, 23 | double* data_backgrad ); 24 | -------------------------------------------------------------------------------- /cbits/pad.c: -------------------------------------------------------------------------------- 1 | #include "pad.h" 2 | 3 | void pad_cpu(double* data, const int channels, 4 | const int height, const int width, const int pad_left, const int pad_top, 5 | const int pad_right, const int pad_bottom, 6 | double* data_padded) { 7 | 8 | const int pad_width = width + pad_left + pad_right; 9 | const int pad_height = height + pad_top + pad_bottom; 10 | 11 | memset(data_padded, 0, pad_height * pad_width * channels * sizeof(double)); 12 | 13 | for (int channel = 0; channel < channels; channel++) { 14 | double* px = data_padded + (pad_width * pad_top + pad_left) + channel * (pad_width * pad_height); 15 | for (int y = 0; y < height; y++) { 16 | memcpy(px, data, sizeof(double) * width); 17 | px += pad_width; 18 | data += width; 19 | } 20 | } 21 | } 22 | 23 | void crop_cpu(double* data, const int channels, 24 | const int height, const int width, const int crop_left, const int crop_top, 25 | const int crop_right, const int crop_bottom, 26 | double* data_cropped) { 27 | 28 | const int crop_width = width + crop_left + crop_right; 29 | const int crop_height = height + crop_top + crop_bottom; 30 | 31 | for (int channel = 0; channel < channels; channel++) { 32 | double* px = data + (crop_width * crop_top + crop_left) + channel * (crop_width * crop_height); 33 | for (int y = 0; y < height; y++) { 34 | memcpy(data_cropped, px, sizeof(double) * width); 35 | px += crop_width; 36 | data_cropped += width; 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /cbits/pad.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | void pad_cpu(double* data_im, const int channels, 6 | const int height, const int width, const int pad_left, const int pad_top, 7 | const int pad_right, const int pad_bottom, 8 | double* data_col); 9 | 10 | void crop_cpu(double* data_im, const int channels, 11 | const int height, const int width, const int crop_left, const int crop_top, 12 | const int crop_right, const int crop_bottom, 13 | double* data_col); 14 | 
-------------------------------------------------------------------------------- /examples/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016-2017, Huw Campbell 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 11 | -------------------------------------------------------------------------------- /examples/grenade-examples.cabal: -------------------------------------------------------------------------------- 1 | name: grenade-examples 2 | version: 0.0.1 3 | license: BSD2 4 | license-file: LICENSE 5 | author: Huw Campbell 6 | maintainer: Huw Campbell 7 | copyright: (c) 2016-2017 Huw Campbell. 
8 | synopsis: grenade-examples 9 | category: System 10 | cabal-version: >= 1.8 11 | build-type: Simple 12 | description: grenade-examples 13 | 14 | source-repository head 15 | type: git 16 | location: https://github.com/HuwCampbell/grenade.git 17 | 18 | library 19 | 20 | executable feedforward 21 | ghc-options: -Wall -threaded -O2 22 | main-is: main/feedforward.hs 23 | build-depends: base 24 | , grenade 25 | , attoparsec 26 | , bytestring 27 | , cereal 28 | , either 29 | , optparse-applicative >= 0.13 && < 0.18 30 | , text >= 1.2 31 | , mtl >= 2.2.1 && < 2.3 32 | , hmatrix 33 | , transformers 34 | , singletons 35 | , semigroups 36 | , MonadRandom 37 | 38 | executable mnist 39 | ghc-options: -Wall -threaded -O2 40 | main-is: main/mnist.hs 41 | build-depends: base 42 | , grenade 43 | , attoparsec 44 | , either 45 | , optparse-applicative >= 0.13 && < 0.18 46 | , text >= 1.2 47 | , mtl >= 2.2.1 && < 2.3 48 | , hmatrix >= 0.18 && < 0.21 49 | , transformers 50 | , semigroups 51 | , singletons 52 | , MonadRandom 53 | , vector 54 | 55 | executable gan-mnist 56 | ghc-options: -Wall -threaded -O2 57 | main-is: main/gan-mnist.hs 58 | build-depends: base 59 | , grenade 60 | , attoparsec 61 | , bytestring 62 | , cereal 63 | , either 64 | , optparse-applicative >= 0.13 && < 0.18 65 | , text >= 1.2 66 | , mtl >= 2.2.1 && < 2.3 67 | , hmatrix >= 0.18 && < 0.21 68 | , transformers 69 | , semigroups 70 | , singletons 71 | , MonadRandom 72 | , vector 73 | 74 | executable recurrent 75 | ghc-options: -Wall -threaded -O2 76 | main-is: main/recurrent.hs 77 | build-depends: base 78 | , grenade 79 | , attoparsec 80 | , either 81 | , optparse-applicative >= 0.13 && < 0.18 82 | , text >= 1.2 83 | , mtl >= 2.2.1 && < 2.3 84 | , hmatrix >= 0.18 && < 0.21 85 | , transformers 86 | , semigroups 87 | , singletons 88 | , MonadRandom 89 | 90 | executable shakespeare 91 | ghc-options: -Wall -threaded -O2 92 | main-is: main/shakespeare.hs 93 | build-depends: base 94 | , grenade 95 | , attoparsec 96 | , bytestring 97 | , cereal 98 | , either 99 | , optparse-applicative >= 0.13 && < 0.18 100 | , text >= 1.2 101 | , mtl >= 2.2.1 && < 2.3 102 | , hmatrix >= 0.18 && < 0.21 103 | , transformers 104 | , semigroups 105 | , singletons 106 | , singletons-base 107 | , vector 108 | , MonadRandom 109 | , containers 110 | -------------------------------------------------------------------------------- /examples/mafia: -------------------------------------------------------------------------------- 1 | #!/bin/sh -eu 2 | 3 | : ${MAFIA_HOME:=$HOME/.mafia} 4 | : ${MAFIA_VERSIONS:=$MAFIA_HOME/versions} 5 | 6 | latest_version () { 7 | git ls-remote https://github.com/haskell-mafia/mafia | grep refs/heads/master | cut -f 1 8 | } 9 | 10 | build_version() { 11 | MAFIA_VERSION="$1" 12 | MAFIA_TEMP=$(mktemp -d 2>/dev/null || mktemp -d -t 'exec_mafia') 13 | MAFIA_FILE=mafia-$MAFIA_VERSION 14 | MAFIA_PATH=$MAFIA_VERSIONS/$MAFIA_FILE 15 | mkdir -p $MAFIA_VERSIONS 16 | echo "Building $MAFIA_FILE in $MAFIA_TEMP" 17 | git clone https://github.com/haskell-mafia/mafia $MAFIA_TEMP 18 | git --git-dir="$MAFIA_TEMP/.git" --work-tree="$MAFIA_TEMP" reset --hard $MAFIA_VERSION || { 19 | echo "mafia version ($MAFIA_VERSION) could not be found." >&2 20 | exit 1 21 | } 22 | (cd "$MAFIA_TEMP" && ./bin/bootstrap) || { 23 | got=$? 24 | echo "mafia version ($MAFIA_VERSION) could not be built." >&2 25 | exit "$got" 26 | } 27 | chmod +x "$MAFIA_TEMP/.cabal-sandbox/bin/mafia" 28 | # Ensure executable is on same file-system so final mv is atomic. 
29 | mv -f "$MAFIA_TEMP/.cabal-sandbox/bin/mafia" "$MAFIA_PATH.$$" 30 | mv "$MAFIA_PATH.$$" "$MAFIA_PATH" || { 31 | rm -f "$MAFIA_PATH.$$" 32 | echo "INFO: mafia version ($MAFIA_VERSION) already exists not overiding," >&2 33 | echo "INFO: this is expected if parallel builds of the same version of" >&2 34 | echo "INFO: mafia occur, we are playing by first in, wins." >&2 35 | exit 0 36 | } 37 | } 38 | 39 | enable_version() { 40 | if [ $# -eq 0 ]; then 41 | MAFIA_VERSION="$(latest_version)" 42 | echo "INFO: No explicit mafia version requested installing latest ($MAFIA_VERSION)." >&2 43 | else 44 | MAFIA_VERSION="$1" 45 | fi 46 | [ -x "$MAFIA_HOME/versions/mafia-$MAFIA_VERSION" ] || build_version "$MAFIA_VERSION" 47 | ln -sf "$MAFIA_HOME/versions/mafia-$MAFIA_VERSION" "$MAFIA_HOME/versions/mafia" 48 | } 49 | 50 | exec_mafia () { 51 | [ -x "$MAFIA_HOME/versions/mafia" ] || enable_version 52 | "$MAFIA_HOME/versions/mafia" "$@" 53 | } 54 | 55 | # 56 | # The actual start of the script..... 57 | # 58 | 59 | case "${1:-}" in 60 | upgrade) shift; enable_version "$@" ;; 61 | *) exec_mafia "$@" 62 | esac 63 | -------------------------------------------------------------------------------- /examples/main/feedforward.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE TupleSections #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | import Control.Monad 9 | import Control.Monad.Random 10 | import Data.List ( foldl' ) 11 | 12 | import qualified Data.ByteString as B 13 | import Data.Serialize 14 | #if ! MIN_VERSION_base(4,13,0) 15 | import Data.Semigroup ( (<>) ) 16 | #endif 17 | import GHC.TypeLits 18 | 19 | import qualified Numeric.LinearAlgebra.Static as SA 20 | 21 | import Options.Applicative 22 | 23 | import Grenade 24 | 25 | 26 | -- The defininition for our simple feed forward network. 27 | -- The type level lists represents the layers and the shapes passed through the layers. 28 | -- One can see that for this demonstration we are using relu, tanh and logit non-linear 29 | -- units, which can be easily subsituted for each other in and out. 30 | -- 31 | -- With around 100000 examples, this should show two clear circles which have been learned by the network. 
32 | type FFNet = Network '[ FullyConnected 2 40, Tanh, FullyConnected 40 10, Relu, FullyConnected 10 1, Logit ] 33 | '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1] 34 | 35 | randomNet :: MonadRandom m => m FFNet 36 | randomNet = randomNetwork 37 | 38 | netTrain :: FFNet -> LearningParameters -> Int -> IO FFNet 39 | netTrain net0 rate n = do 40 | inps <- replicateM n $ do 41 | s <- getRandom 42 | return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1 43 | let outs = flip map inps $ \(S1D v) -> 44 | if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33) 45 | then S1D $ fromRational 1 46 | else S1D $ fromRational 0 47 | 48 | let trained = foldl' trainEach net0 (zip inps outs) 49 | return trained 50 | 51 | where 52 | inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool 53 | v `inCircle` (o, r) = SA.norm_2 (v - o) <= r 54 | trainEach !network (i,o) = train rate network i o 55 | 56 | netLoad :: FilePath -> IO FFNet 57 | netLoad modelPath = do 58 | modelData <- B.readFile modelPath 59 | either fail return $ runGet (get :: Get FFNet) modelData 60 | 61 | netScore :: FFNet -> IO () 62 | netScore network = do 63 | let testIns = [ [ (x,y) | x <- [0..50] ] 64 | | y <- [0..20] ] 65 | outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet network (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns 66 | putStrLn $ unlines outMat 67 | 68 | where 69 | render n' | n' <= 0.2 = ' ' 70 | | n' <= 0.4 = '.' 71 | | n' <= 0.6 = '-' 72 | | n' <= 0.8 = '=' 73 | | otherwise = '#' 74 | 75 | normx :: S ('D1 1) -> Double 76 | normx (S1D r) = SA.mean r 77 | 78 | data FeedForwardOpts = FeedForwardOpts Int LearningParameters (Maybe FilePath) (Maybe FilePath) 79 | 80 | feedForward' :: Parser FeedForwardOpts 81 | feedForward' = 82 | FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 100000) 83 | <*> (LearningParameters 84 | <$> option auto (long "train_rate" <> short 'r' <> value 0.01) 85 | <*> option auto (long "momentum" <> value 0.9) 86 | <*> option auto (long "l2" <> value 0.0005) 87 | ) 88 | <*> optional (strOption (long "load")) 89 | <*> optional (strOption (long "save")) 90 | 91 | main :: IO () 92 | main = do 93 | FeedForwardOpts examples rate load save <- execParser (info (feedForward' <**> helper) idm) 94 | net0 <- case load of 95 | Just loadFile -> netLoad loadFile 96 | Nothing -> randomNet 97 | 98 | net <- netTrain net0 rate examples 99 | netScore net 100 | 101 | case save of 102 | Just saveFile -> B.writeFile saveFile $ runPut (put net) 103 | Nothing -> return () 104 | -------------------------------------------------------------------------------- /examples/main/gan-mnist.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE TupleSections #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE FlexibleContexts #-} 10 | 11 | -- This is a simple generative adversarial network to make pictures 12 | -- of numbers similar to those in MNIST. 13 | -- 14 | -- It demonstrates a different usage of the library. Within about 15 15 | -- minutes it was producing examples like this: 16 | -- 17 | -- --. 18 | -- .=-.--..#=### 19 | -- -##==#########. 20 | -- #############- 21 | -- -###-.=..-.-== 22 | -- ###- 23 | -- .###- 24 | -- .####...==-. 
25 | -- -####=--.=##= 26 | -- -##=- -## 27 | -- =## 28 | -- -##= 29 | -- -###- 30 | -- .####. 31 | -- .#####. 32 | -- ...---=#####- 33 | -- .=#########. . 34 | -- .#######=. . 35 | -- . =-. 36 | -- 37 | -- It's a 5! 38 | -- 39 | import Control.Applicative 40 | import Control.Monad 41 | import Control.Monad.Random 42 | import Control.Monad.Trans.Except 43 | 44 | import qualified Data.Attoparsec.Text as A 45 | import qualified Data.ByteString as B 46 | import Data.List ( foldl' ) 47 | #if ! MIN_VERSION_base(4,13,0) 48 | import Data.Semigroup ( (<>) ) 49 | #endif 50 | import Data.Serialize 51 | import qualified Data.Text as T 52 | import qualified Data.Text.IO as T 53 | import qualified Data.Vector.Storable as V 54 | 55 | import qualified Numeric.LinearAlgebra.Static as SA 56 | import Numeric.LinearAlgebra.Data ( toLists ) 57 | 58 | import Options.Applicative 59 | 60 | import Grenade 61 | import Grenade.Utils.OneHot 62 | 63 | type Discriminator = 64 | Network 65 | '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu 66 | , Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Relu 67 | , Reshape, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit] 68 | '[ 'D2 28 28 69 | , 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10 70 | , 'D3 8 8 16, 'D3 4 4 16, 'D3 4 4 16 71 | , 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1] 72 | 73 | type Generator = 74 | Network 75 | '[ FullyConnected 80 256, Relu, Reshape 76 | , Deconvolution 16 10 5 5 2 2, Relu 77 | , Deconvolution 10 1 8 8 2 2, Logit] 78 | '[ 'D1 80 79 | , 'D1 256, 'D1 256, 'D3 4 4 16 80 | , 'D3 11 11 10, 'D3 11 11 10 81 | , 'D2 28 28, 'D2 28 28 ] 82 | 83 | randomDiscriminator :: MonadRandom m => m Discriminator 84 | randomDiscriminator = randomNetwork 85 | 86 | randomGenerator :: MonadRandom m => m Generator 87 | randomGenerator = randomNetwork 88 | 89 | trainExample :: LearningParameters -> Discriminator -> Generator -> S ('D2 28 28) -> S ('D1 80) -> ( Discriminator, Generator ) 90 | trainExample rate discriminator generator realExample noiseSource 91 | = let (generatorTape, fakeExample) = runNetwork generator noiseSource 92 | 93 | (discriminatorTapeReal, guessReal) = runNetwork discriminator realExample 94 | (discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample 95 | 96 | (discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 ) 97 | (discriminator'fake, _) = runGradient discriminator discriminatorTapeFake guessFake 98 | (_, push) = runGradient discriminator discriminatorTapeFake ( guessFake - 1) 99 | 100 | (generator', _) = runGradient generator generatorTape push 101 | 102 | newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ] 103 | newGenerator = applyUpdate rate generator generator' 104 | in ( newDiscriminator, newGenerator ) 105 | 106 | 107 | ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> ExceptT String IO (Discriminator, Generator) 108 | ganTest (discriminator0, generator0) iterations trainFile rate = do 109 | trainData <- fmap fst <$> readMNIST trainFile 110 | 111 | lift $ foldM (runIteration trainData) ( discriminator0, generator0 ) [1..iterations] 112 | 113 | where 114 | 115 | showShape' :: S ('D2 a b) -> IO () 116 | showShape' (S2D mm) = putStrLn $ 117 | let m = SA.extract mm 118 | ms = toLists m 119 | render n' | n' <= 0.2 = ' ' 120 | | n' <= 0.4 = '.' 121 | | n' <= 0.6 = '-' 122 | | n' <= 0.8 = '=' 123 | | otherwise = '#' 124 | 125 | px = (fmap . 
fmap) render ms 126 | in unlines px 127 | 128 | runIteration :: [S ('D2 28 28)] -> (Discriminator, Generator) -> Int -> IO (Discriminator, Generator) 129 | runIteration trainData ( !discriminator, !generator ) _ = do 130 | trained' <- foldM ( \(!discriminatorX, !generatorX ) realExample -> do 131 | fakeExample <- randomOfShape 132 | return $ trainExample rate discriminatorX generatorX realExample fakeExample 133 | ) ( discriminator, generator ) trainData 134 | 135 | 136 | showShape' . snd . runNetwork (snd trained') =<< randomOfShape 137 | 138 | return trained' 139 | 140 | data GanOpts = GanOpts FilePath Int LearningParameters (Maybe FilePath) (Maybe FilePath) 141 | 142 | mnist' :: Parser GanOpts 143 | mnist' = GanOpts <$> argument str (metavar "TRAIN") 144 | <*> option auto (long "iterations" <> short 'i' <> value 15) 145 | <*> (LearningParameters 146 | <$> option auto (long "train_rate" <> short 'r' <> value 0.01) 147 | <*> option auto (long "momentum" <> value 0.9) 148 | <*> option auto (long "l2" <> value 0.0005) 149 | ) 150 | <*> optional (strOption (long "load")) 151 | <*> optional (strOption (long "save")) 152 | 153 | 154 | main :: IO () 155 | main = do 156 | GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm) 157 | putStrLn "Training stupidly simply GAN" 158 | nets0 <- case load of 159 | Just loadFile -> netLoad loadFile 160 | Nothing -> (,) <$> randomDiscriminator <*> randomGenerator 161 | 162 | res <- runExceptT $ ganTest nets0 iter mnist rate 163 | case res of 164 | Right nets1 -> case save of 165 | Just saveFile -> B.writeFile saveFile $ runPut (put nets1) 166 | Nothing -> return () 167 | 168 | Left err -> putStrLn err 169 | 170 | readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))] 171 | readMNIST mnist = ExceptT $ do 172 | mnistdata <- T.readFile mnist 173 | return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata) 174 | 175 | parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10)) 176 | parseMNIST = do 177 | Just lab <- oneHot <$> A.decimal 178 | pixels <- many (A.char ',' >> A.double) 179 | image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels) 180 | return (image, lab) 181 | 182 | netLoad :: FilePath -> IO (Discriminator, Generator) 183 | netLoad modelPath = do 184 | modelData <- B.readFile modelPath 185 | either fail return $ runGet (get :: Get (Discriminator, Generator)) modelData 186 | -------------------------------------------------------------------------------- /examples/main/mnist.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE TupleSections #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE FlexibleContexts #-} 9 | 10 | import Control.Applicative 11 | import Control.Monad 12 | import Control.Monad.Random 13 | import Control.Monad.Trans.Except 14 | 15 | import qualified Data.Attoparsec.Text as A 16 | import Data.List ( foldl' ) 17 | #if ! 
MIN_VERSION_base(4,13,0) 18 | import Data.Semigroup ( (<>) ) 19 | #endif 20 | 21 | import qualified Data.Text as T 22 | import qualified Data.Text.IO as T 23 | import qualified Data.Vector.Storable as V 24 | 25 | import Numeric.LinearAlgebra ( maxIndex ) 26 | import qualified Numeric.LinearAlgebra.Static as SA 27 | 28 | import Options.Applicative 29 | 30 | import Grenade 31 | import Grenade.Utils.OneHot 32 | 33 | -- It's logistic regression! 34 | -- 35 | -- This network is used to show how we can embed a Network as a layer in the larger MNIST 36 | -- type. 37 | type FL i o = 38 | Network 39 | '[ FullyConnected i o, Logit ] 40 | '[ 'D1 i, 'D1 o, 'D1 o ] 41 | 42 | -- The definition of our convolutional neural network. 43 | -- In the type signature, we have a type level list of shapes which are passed between the layers. 44 | -- One can see that the images we are inputing are two dimensional with 28 * 28 pixels. 45 | 46 | -- It's important to keep the type signatures, as there's many layers which can "squeeze" into the gaps 47 | -- between the shapes, so inference can't do it all for us. 48 | 49 | -- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations, 50 | -- this network should get down to about a 1.3% error rate. 51 | -- 52 | -- /NOTE:/ This model is actually too complex for MNIST, and one should use the type given in the readme instead. 53 | -- This one is just here to demonstrate Inception layers in use. 54 | -- 55 | type MNIST = 56 | Network 57 | '[ Reshape, 58 | Concat ('D3 28 28 1) Trivial ('D3 28 28 14) (InceptionMini 28 28 1 5 9), 59 | Pooling 2 2 2 2, Relu, 60 | Concat ('D3 14 14 3) (Convolution 15 3 1 1 1 1) ('D3 14 14 15) (InceptionMini 14 14 15 5 10), Crop 1 1 1 1, Pooling 3 3 3 3, Relu, 61 | Reshape, FL 288 80, FL 80 10 ] 62 | '[ 'D2 28 28, 'D3 28 28 1, 63 | 'D3 28 28 15, 'D3 14 14 15, 'D3 14 14 15, 'D3 14 14 18, 64 | 'D3 12 12 18, 'D3 4 4 18, 'D3 4 4 18, 65 | 'D1 288, 'D1 80, 'D1 10 ] 66 | 67 | randomMnist :: MonadRandom m => m MNIST 68 | randomMnist = randomNetwork 69 | 70 | convTest :: Int -> FilePath -> FilePath -> LearningParameters -> ExceptT String IO () 71 | convTest iterations trainFile validateFile rate = do 72 | net0 <- lift randomMnist 73 | trainData <- readMNIST trainFile 74 | validateData <- readMNIST validateFile 75 | lift $ foldM_ (runIteration trainData validateData) net0 [1..iterations] 76 | 77 | where 78 | trainEach rate' !network (i, o) = train rate' network i o 79 | 80 | runIteration trainRows validateRows net i = do 81 | let trained' = foldl' (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows 82 | let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows 83 | let res' = fmap (\(S1D label, S1D prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res 84 | print trained' 85 | putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res') 86 | return trained' 87 | 88 | data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters 89 | 90 | mnist' :: Parser MnistOpts 91 | mnist' = MnistOpts <$> argument str (metavar "TRAIN") 92 | <*> argument str (metavar "VALIDATE") 93 | <*> option auto (long "iterations" <> short 'i' <> value 15) 94 | <*> (LearningParameters 95 | <$> option auto (long "train_rate" <> short 'r' <> value 0.01) 96 | <*> option auto (long "momentum" <> value 0.9) 97 | <*> option auto (long "l2" <> value 0.0005) 98 | ) 99 | 100 | main :: IO () 
101 | main = do 102 | MnistOpts mnist vali iter rate <- execParser (info (mnist' <**> helper) idm) 103 | putStrLn "Training convolutional neural network..." 104 | 105 | res <- runExceptT $ convTest iter mnist vali rate 106 | case res of 107 | Right () -> pure () 108 | Left err -> putStrLn err 109 | 110 | readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))] 111 | readMNIST mnist = ExceptT $ do 112 | mnistdata <- T.readFile mnist 113 | return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata) 114 | 115 | parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10)) 116 | parseMNIST = do 117 | Just lab <- oneHot <$> A.decimal 118 | pixels <- many (A.char ',' >> A.double) 119 | image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels) 120 | return (image, lab) 121 | -------------------------------------------------------------------------------- /examples/main/recurrent.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE TupleSections #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | 9 | import Control.Monad ( foldM ) 10 | import Control.Monad.Random ( MonadRandom, getRandomR ) 11 | 12 | #if MIN_VERSION_base(4,13,0) 13 | import Data.List ( unfoldr ) 14 | #else 15 | import Data.List ( cycle, unfoldr ) 16 | import Data.Semigroup ( (<>) ) 17 | #endif 18 | 19 | import qualified Numeric.LinearAlgebra.Static as SA 20 | 21 | import Options.Applicative 22 | 23 | import Grenade 24 | import Grenade.Recurrent 25 | 26 | -- The defininition for our simple recurrent network. 27 | -- This file just trains a network to generate a repeating sequence 28 | -- of 0 0 1. 29 | -- 30 | -- The F and R types are Tagging types to ensure that the runner and 31 | -- creation function know how to treat the layers. 32 | type R = Recurrent 33 | 34 | type RecNet = RecurrentNetwork '[ R (LSTM 1 1)] 35 | '[ 'D1 1, 'D1 1] 36 | 37 | type RecInput = RecurrentInputs '[ R (LSTM 1 1)] 38 | 39 | randomNet :: MonadRandom m => m RecNet 40 | randomNet = randomRecurrent 41 | 42 | netTest :: MonadRandom m => RecNet -> RecInput -> LearningParameters -> Int -> m (RecNet, RecInput) 43 | netTest net0 i0 rate iterations = 44 | foldM trainIteration (net0,i0) [1..iterations] 45 | where 46 | trainingCycle = cycle [c 0, c 0, c 1] 47 | 48 | trainIteration (net, io) _ = do 49 | dropping <- getRandomR (0, 2) 50 | count <- getRandomR (5, 30) 51 | let t = drop dropping trainingCycle 52 | let example = ((,Nothing) <$> take count t) ++ [(t !! count, Just $ t !! 
(count + 1))] 53 | return $ trainEach net io example 54 | 55 | trainEach !nt !io !ex = trainRecurrent rate nt io ex 56 | 57 | data FeedForwardOpts = FeedForwardOpts Int LearningParameters 58 | 59 | feedForward' :: Parser FeedForwardOpts 60 | feedForward' = FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 40000) 61 | <*> (LearningParameters 62 | <$> option auto (long "train_rate" <> short 'r' <> value 0.01) 63 | <*> option auto (long "momentum" <> value 0.9) 64 | <*> option auto (long "l2" <> value 0.0005) 65 | ) 66 | 67 | generateRecurrent :: RecNet -> RecInput -> S ('D1 1) -> [Int] 68 | generateRecurrent n s i = 69 | unfoldr go (s, i) 70 | where 71 | go (x, y) = 72 | do let (_, ns, o) = runRecurrent n x y 73 | o' = heat o 74 | Just (o', (ns, fromIntegral o')) 75 | 76 | heat :: S ('D1 1) -> Int 77 | heat x = case x of 78 | (S1D v) -> round (SA.mean v) 79 | 80 | main :: IO () 81 | main = do 82 | FeedForwardOpts examples rate <- execParser (info (feedForward' <**> helper) idm) 83 | putStrLn "Training network..." 84 | 85 | net0 <- randomNet 86 | (trained, bestInput) <- netTest net0 0 rate examples 87 | 88 | let results = generateRecurrent trained bestInput (c 1) 89 | 90 | print . take 50 . drop 100 $ results 91 | 92 | c :: Double -> S ('D1 1) 93 | c = S1D . SA.konst 94 | -------------------------------------------------------------------------------- /examples/main/shakespeare.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE TupleSections #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE LambdaCase #-} 9 | 10 | import Control.Monad.Random 11 | import Control.Monad.Trans.Except 12 | 13 | import Data.Char ( isUpper, toUpper, toLower ) 14 | import Data.List ( foldl' ) 15 | import Data.Maybe ( fromMaybe ) 16 | 17 | 18 | import qualified Data.Vector as V 19 | import Data.Vector ( Vector ) 20 | 21 | import qualified Data.Map as M 22 | 23 | import qualified Data.ByteString as B 24 | import Data.Serialize 25 | 26 | import GHC.TypeLits 27 | 28 | import Numeric.LinearAlgebra.Static ( konst ) 29 | 30 | import Options.Applicative 31 | 32 | import Grenade 33 | import Grenade.Recurrent 34 | import Grenade.Utils.OneHot 35 | 36 | import System.IO.Unsafe ( unsafeInterleaveIO ) 37 | import Data.Proxy 38 | import Prelude.Singletons 39 | 40 | -- The defininition for our natural language recurrent network. 41 | -- This network is able to learn and generate simple words in 42 | -- about an hour. 43 | -- 44 | -- Grab the input from 45 | -- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt 46 | -- 47 | -- This is a first class recurrent net. 48 | -- 49 | -- The F and R types are tagging types to ensure that the runner and 50 | -- creation function know how to treat the layers. 51 | -- 52 | -- As an example, here's a short sequence generated. 
53 | -- 54 | -- > KING RICHARD III: 55 | -- > And as the heaven her his words, we the son, I show sand stape but the lament to shall were the sons with a strend 56 | 57 | type F = FeedForward 58 | type R = Recurrent 59 | 60 | -- The definition of our network 61 | type Shakespeare = RecurrentNetwork '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit] 62 | '[ 'D1 40, 'D1 80, 'D1 40, 'D1 40, 'D1 40 ] 63 | 64 | -- The definition of the "sideways" input, which the network is fed recurrently. 65 | type Shakespearian = RecurrentInputs '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit] 66 | 67 | randomNet :: MonadRandom m => m Shakespeare 68 | randomNet = randomRecurrent 69 | 70 | -- | Load the data files and prepare a map of characters to a compressed int representation. 71 | loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Vector Char) 72 | loadShakespeare path = do 73 | contents <- lift $ readFile path 74 | let annotated = annotateCapitals contents 75 | (m,cs) <- ExceptT . return $ hotMap (Proxy :: Proxy 40) annotated 76 | hot <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated 77 | return (V.fromList hot, m, cs) 78 | 79 | trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian) 80 | trainSlice !lrate !net !recIns input offset size = 81 | let e = fmap (x . oneHot) . V.toList $ V.slice offset size input 82 | in case reverse e of 83 | (o : l : xs) -> 84 | let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs) 85 | in trainRecurrent lrate net recIns examples 86 | _ -> error "Not enough input" 87 | where 88 | x = fromMaybe (error "Hot variable didn't fit.") 89 | 90 | runShakespeare :: ShakespeareOpts -> ExceptT String IO () 91 | runShakespeare opts = do 92 | (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare $ trainingFile opts 93 | (net0, i0) <- lift $ 94 | case loadPath opts of 95 | Just loadFile -> netLoad loadFile 96 | Nothing -> (,0) <$> randomNet 97 | 98 | (trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do 99 | xs <- take (iterations opts `div` 10) <$> getRandomRs (0, length shakespeare - size - 1) 100 | let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice (rate opts) n i shakespeare offset size) (net, io) xs 101 | results <- take 1000 <$> generateParagraph trained bestInput (temperature opts) oneHotMap oneHotDictionary ( S1D $ konst 0) 102 | putStrLn ("TRAINING STEP WITH SIZE: " ++ show size) 103 | putStrLn (unAnnotateCapitals results) 104 | return (trained, bestInput) 105 | ) (net0, i0) $ replicate 10 (sequenceSize opts) 106 | 107 | case savePath opts of 108 | Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput) 109 | Nothing -> return () 110 | 111 | generateParagraph :: forall layers shapes n a. 
(Last shapes ~ 'D1 n, Head shapes ~ 'D1 n, KnownNat n, Ord a) 112 | => RecurrentNetwork layers shapes 113 | -> RecurrentInputs layers 114 | -> Double 115 | -> M.Map a Int 116 | -> Vector a 117 | -> S ('D1 n) 118 | -> IO [a] 119 | generateParagraph n s temp hotmap hotdict = 120 | go s 121 | where 122 | go x y = 123 | do let (_, ns, o) = runRecurrent n x y 124 | un <- sample temp hotdict o 125 | Just re <- return $ makeHot hotmap un 126 | rest <- unsafeInterleaveIO $ go ns re 127 | return (un : rest) 128 | 129 | data ShakespeareOpts = ShakespeareOpts { 130 | trainingFile :: FilePath 131 | , iterations :: Int 132 | , rate :: LearningParameters 133 | , sequenceSize :: Int 134 | , temperature :: Double 135 | , loadPath :: Maybe FilePath 136 | , savePath :: Maybe FilePath 137 | } 138 | 139 | shakespeare' :: Parser ShakespeareOpts 140 | shakespeare' = ShakespeareOpts <$> argument str (metavar "TRAIN") 141 | <*> option auto (long "examples" <> short 'e' <> value 1000000) 142 | <*> (LearningParameters 143 | <$> option auto (long "train_rate" <> short 'r' <> value 0.01) 144 | <*> option auto (long "momentum" <> value 0.95) 145 | <*> option auto (long "l2" <> value 0.000001) 146 | ) 147 | <*> option auto (long "sequence-length" <> short 's' <> value 50) 148 | <*> option auto (long "temperature" <> short 't' <> value 0.4) 149 | <*> optional (strOption (long "load")) 150 | <*> optional (strOption (long "save")) 151 | 152 | main :: IO () 153 | main = do 154 | shopts <- execParser (info (shakespeare' <**> helper) idm) 155 | res <- runExceptT $ runShakespeare shopts 156 | case res of 157 | Right () -> pure () 158 | Left err -> putStrLn err 159 | 160 | 161 | netLoad :: FilePath -> IO (Shakespeare, Shakespearian) 162 | netLoad modelPath = do 163 | modelData <- B.readFile modelPath 164 | either fail return $ runGet get modelData 165 | 166 | -- Replace capitals with an annotation and the lower case letter 167 | -- http://fastml.com/one-weird-trick-for-training-char-rnns/ 168 | annotateCapitals :: String -> String 169 | annotateCapitals (x : rest) 170 | | isUpper x 171 | = '^' : toLower x : annotateCapitals rest 172 | | otherwise 173 | = x : annotateCapitals rest 174 | annotateCapitals [] 175 | = [] 176 | 177 | unAnnotateCapitals :: String -> String 178 | unAnnotateCapitals ('^' : x : rest) 179 | = toUpper x : unAnnotateCapitals rest 180 | unAnnotateCapitals (x : rest) 181 | = x : unAnnotateCapitals rest 182 | unAnnotateCapitals [] 183 | = [] 184 | 185 | -- | Tag the 'Nothing' value of a 'Maybe' 186 | note :: a -> Maybe b -> Either a b 187 | note a = maybe (Left a) Right 188 | -------------------------------------------------------------------------------- /framework/mafia: -------------------------------------------------------------------------------- 1 | #!/bin/sh -eu 2 | 3 | : ${MAFIA_HOME:=$HOME/.mafia} 4 | : ${MAFIA_VERSIONS:=$MAFIA_HOME/versions} 5 | 6 | latest_version () { 7 | git ls-remote https://github.com/haskell-mafia/mafia | grep refs/heads/master | cut -f 1 8 | } 9 | 10 | build_version() { 11 | MAFIA_VERSION="$1" 12 | MAFIA_TEMP=$(mktemp -d 2>/dev/null || mktemp -d -t 'exec_mafia') 13 | MAFIA_FILE=mafia-$MAFIA_VERSION 14 | MAFIA_PATH=$MAFIA_VERSIONS/$MAFIA_FILE 15 | mkdir -p $MAFIA_VERSIONS 16 | echo "Building $MAFIA_FILE in $MAFIA_TEMP" 17 | git clone https://github.com/haskell-mafia/mafia $MAFIA_TEMP 18 | git --git-dir="$MAFIA_TEMP/.git" --work-tree="$MAFIA_TEMP" reset --hard $MAFIA_VERSION || { 19 | echo "mafia version ($MAFIA_VERSION) could not be found." 
>&2 20 | exit 1 21 | } 22 | (cd "$MAFIA_TEMP" && ./bin/bootstrap) || { 23 | got=$? 24 | echo "mafia version ($MAFIA_VERSION) could not be built." >&2 25 | exit "$got" 26 | } 27 | chmod +x "$MAFIA_TEMP/.cabal-sandbox/bin/mafia" 28 | # Ensure executable is on same file-system so final mv is atomic. 29 | mv -f "$MAFIA_TEMP/.cabal-sandbox/bin/mafia" "$MAFIA_PATH.$$" 30 | mv "$MAFIA_PATH.$$" "$MAFIA_PATH" || { 31 | rm -f "$MAFIA_PATH.$$" 32 | echo "INFO: mafia version ($MAFIA_VERSION) already exists not overiding," >&2 33 | echo "INFO: this is expected if parallel builds of the same version of" >&2 34 | echo "INFO: mafia occur, we are playing by first in, wins." >&2 35 | exit 0 36 | } 37 | } 38 | 39 | enable_version() { 40 | if [ $# -eq 0 ]; then 41 | MAFIA_VERSION="$(latest_version)" 42 | echo "INFO: No explicit mafia version requested installing latest ($MAFIA_VERSION)." >&2 43 | else 44 | MAFIA_VERSION="$1" 45 | fi 46 | [ -x "$MAFIA_HOME/versions/mafia-$MAFIA_VERSION" ] || build_version "$MAFIA_VERSION" 47 | ln -sf "$MAFIA_HOME/versions/mafia-$MAFIA_VERSION" "$MAFIA_HOME/versions/mafia" 48 | } 49 | 50 | exec_mafia () { 51 | [ -x "$MAFIA_HOME/versions/mafia" ] || enable_version 52 | "$MAFIA_HOME/versions/mafia" "$@" 53 | } 54 | 55 | # 56 | # The actual start of the script..... 57 | # 58 | 59 | case "${1:-}" in 60 | upgrade) shift; enable_version "$@" ;; 61 | *) exec_mafia "$@" 62 | esac 63 | -------------------------------------------------------------------------------- /grenade.cabal: -------------------------------------------------------------------------------- 1 | name: grenade 2 | version: 0.1.0 3 | license: BSD2 4 | license-file: LICENSE 5 | author: Huw Campbell 6 | maintainer: Huw Campbell 7 | copyright: (c) 2016-2017 Huw Campbell. 8 | synopsis: Practical Deep Learning in Haskell 9 | category: AI, Machine Learning 10 | cabal-version: >= 1.10 11 | build-type: Simple 12 | description: 13 | Grenade is a composable, dependently typed, practical, and fast 14 | recurrent neural network library for precise specifications and 15 | complex deep neural networks in Haskell. 16 | . 17 | Grenade provides an API for composing layers of a neural network 18 | into a sequence parallel graph in a type safe manner; running 19 | networks with reverse automatic differentiation to calculate their 20 | gradients; and applying gradient descent for learning. 21 | . 22 | Documentation and examples are available on github 23 | . 24 | 25 | extra-source-files: 26 | README.md 27 | cbits/im2col.h 28 | cbits/im2col.c 29 | cbits/gradient_descent.h 30 | cbits/gradient_descent.c 31 | cbits/pad.h 32 | cbits/pad.c 33 | 34 | source-repository head 35 | type: git 36 | location: https://github.com/HuwCampbell/grenade.git 37 | 38 | library 39 | build-depends: 40 | base >= 4.8 && < 5 41 | , bytestring >= 0.10 && < 0.13 42 | , containers >= 0.5 && < 0.8 43 | , cereal >= 0.5 && < 0.6 44 | , deepseq >= 1.4 && < 1.6 45 | , hmatrix >= 0.18 && < 0.21 46 | , MonadRandom >= 0.4 && < 0.7 47 | , primitive >= 0.6 && < 0.10 48 | -- Versions of singletons are *tightly* coupled with the 49 | -- GHC version so its fine to drop version bounds. 
50 | , singletons 51 | , singletons-base 52 | , vector >= 0.11 && < 0.14 53 | 54 | ghc-options: 55 | -Wall 56 | hs-source-dirs: 57 | src 58 | 59 | default-language: Haskell2010 60 | 61 | if impl(ghc < 8.0) 62 | ghc-options: -fno-warn-incomplete-patterns 63 | cpp-options: -DType=* 64 | 65 | if impl(ghc >= 8.6) 66 | default-extensions: NoStarIsType 67 | 68 | exposed-modules: 69 | Grenade 70 | Grenade.Core 71 | Grenade.Core.Layer 72 | Grenade.Core.LearningParameters 73 | Grenade.Core.Network 74 | Grenade.Core.Runner 75 | Grenade.Core.Shape 76 | 77 | Grenade.Layers 78 | Grenade.Layers.Concat 79 | Grenade.Layers.Convolution 80 | Grenade.Layers.Crop 81 | Grenade.Layers.Deconvolution 82 | Grenade.Layers.Dropout 83 | Grenade.Layers.Elu 84 | Grenade.Layers.FullyConnected 85 | Grenade.Layers.Inception 86 | Grenade.Layers.Logit 87 | Grenade.Layers.Merge 88 | Grenade.Layers.Pad 89 | Grenade.Layers.Pooling 90 | Grenade.Layers.Relu 91 | Grenade.Layers.Reshape 92 | Grenade.Layers.Sinusoid 93 | Grenade.Layers.Softmax 94 | Grenade.Layers.Tanh 95 | Grenade.Layers.Trivial 96 | 97 | Grenade.Layers.Internal.Convolution 98 | Grenade.Layers.Internal.Pad 99 | Grenade.Layers.Internal.Pooling 100 | Grenade.Layers.Internal.Update 101 | 102 | Grenade.Recurrent 103 | 104 | Grenade.Recurrent.Core 105 | Grenade.Recurrent.Core.Layer 106 | Grenade.Recurrent.Core.Network 107 | Grenade.Recurrent.Core.Runner 108 | 109 | Grenade.Recurrent.Layers 110 | Grenade.Recurrent.Layers.BasicRecurrent 111 | Grenade.Recurrent.Layers.ConcatRecurrent 112 | Grenade.Recurrent.Layers.LSTM 113 | 114 | Grenade.Utils.OneHot 115 | 116 | includes: cbits/im2col.h 117 | cbits/gradient_descent.h 118 | cbits/pad.h 119 | c-sources: cbits/im2col.c 120 | cbits/gradient_descent.c 121 | cbits/pad.c 122 | 123 | cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1 124 | 125 | test-suite test 126 | type: exitcode-stdio-1.0 127 | 128 | main-is: test.hs 129 | 130 | ghc-options: -Wall -threaded -O2 131 | 132 | hs-source-dirs: 133 | test 134 | 135 | default-language: Haskell2010 136 | 137 | other-modules: Test.Hedgehog.Compat 138 | Test.Hedgehog.Hmatrix 139 | Test.Hedgehog.TypeLits 140 | 141 | Test.Grenade.Network 142 | Test.Grenade.Layers.Convolution 143 | Test.Grenade.Layers.FullyConnected 144 | Test.Grenade.Layers.Nonlinear 145 | Test.Grenade.Layers.PadCrop 146 | Test.Grenade.Layers.Pooling 147 | Test.Grenade.Layers.Internal.Convolution 148 | Test.Grenade.Layers.Internal.Pooling 149 | Test.Grenade.Layers.Internal.Reference 150 | 151 | Test.Grenade.Recurrent.Layers.LSTM 152 | Test.Grenade.Recurrent.Layers.LSTM.Reference 153 | 154 | if impl(ghc < 8.0) 155 | ghc-options: -fno-warn-incomplete-patterns 156 | cpp-options: -DType=* 157 | 158 | if impl(ghc >= 8.6) 159 | default-extensions: NoStarIsType 160 | 161 | build-depends: 162 | base 163 | , grenade 164 | , hedgehog >= 1.0 && < 1.5 165 | , hmatrix 166 | , mtl 167 | , singletons 168 | , text >= 1.2 169 | , typelits-witnesses < 0.5 170 | , transformers 171 | , constraints 172 | , MonadRandom 173 | , random 174 | , ad 175 | , reflection 176 | , vector 177 | 178 | 179 | benchmark bench 180 | type: exitcode-stdio-1.0 181 | 182 | main-is: bench.hs 183 | 184 | ghc-options: -Wall -threaded -O2 185 | 186 | hs-source-dirs: 187 | bench 188 | 189 | default-language: Haskell2010 190 | 191 | build-depends: 192 | base 193 | , bytestring 194 | , criterion >= 1.1 && < 1.7 195 | , grenade 196 | , hmatrix 197 | 198 | benchmark bench-lstm 199 | type: exitcode-stdio-1.0 200 | 201 | main-is: bench-lstm.hs 202 | 203 | 
ghc-options: -Wall -threaded -O2 204 | 205 | hs-source-dirs: 206 | bench 207 | 208 | default-language: Haskell2010 209 | 210 | build-depends: 211 | base 212 | , bytestring 213 | , criterion 214 | , grenade 215 | , hmatrix 216 | -------------------------------------------------------------------------------- /mafia: -------------------------------------------------------------------------------- 1 | framework/mafia -------------------------------------------------------------------------------- /src/Grenade.hs: -------------------------------------------------------------------------------- 1 | module Grenade ( 2 | -- | This is an empty module which simply re-exports public definitions 3 | -- for machine learning with Grenade. 4 | 5 | -- * Exported modules 6 | -- 7 | -- | The core types and runners for Grenade. 8 | module Grenade.Core 9 | 10 | -- | The neural network layer zoo 11 | , module Grenade.Layers 12 | 13 | 14 | -- * Overview of the library 15 | -- $library 16 | 17 | -- * Example usage 18 | -- $example 19 | 20 | ) where 21 | 22 | import Grenade.Core 23 | import Grenade.Layers 24 | 25 | {- $library 26 | Grenade is a purely functional deep learning library. 27 | 28 | It provides an expressive type level API for the construction 29 | of complex neural network architectures. Backing this API is and 30 | implementation written using BLAS and LAPACK, mostly provided by 31 | the hmatrix library. 32 | 33 | -} 34 | 35 | {- $example 36 | A few examples are provided at https://github.com/HuwCampbell/grenade 37 | under the examples folder. 38 | 39 | The starting place is to write your neural network type and a 40 | function to create a random layer of that type. The following 41 | is a simple example which runs a logistic regression. 42 | 43 | > type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ] 44 | > 45 | > randomMyNet :: MonadRandom MyNet 46 | > randomMyNet = randomNetwork 47 | 48 | The function `randomMyNet` witnesses the `CreatableNetwork` 49 | constraint of the neural network, and in doing so, ensures the network 50 | can be built, and hence, that the architecture is sound. 51 | -} 52 | 53 | 54 | -------------------------------------------------------------------------------- /src/Grenade/Core.hs: -------------------------------------------------------------------------------- 1 | module Grenade.Core ( 2 | module Grenade.Core.Layer 3 | , module Grenade.Core.LearningParameters 4 | , module Grenade.Core.Network 5 | , module Grenade.Core.Runner 6 | , module Grenade.Core.Shape 7 | ) where 8 | 9 | import Grenade.Core.Layer 10 | import Grenade.Core.LearningParameters 11 | import Grenade.Core.Network 12 | import Grenade.Core.Runner 13 | import Grenade.Core.Shape 14 | -------------------------------------------------------------------------------- /src/Grenade/Core/Layer.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE TypeOperators #-} 5 | {-# LANGUAGE TypeFamilies #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | {-| 11 | Module : Grenade.Core.Layer 12 | Description : Defines the Layer Classes 13 | Copyright : (c) Huw Campbell, 2016-2017 14 | License : BSD2 15 | Stability : experimental 16 | 17 | This module defines what a Layer is in a Grenade 18 | neural network. 
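
To give a flavour up front (a sketch only: @Identity@ is an illustrative name,
not a layer shipped with Grenade, though it behaves much like the library's
own Trivial layer), a minimal shape-preserving layer instantiates the two
classes described below:

> data Identity = Identity
>
> instance UpdateLayer Identity where
>   type Gradient Identity = ()
>   runUpdate _ x _ = x
>   createRandom    = return Identity
>
> instance Layer Identity ('D1 i) ('D1 i) where
>   type Tape Identity ('D1 i) ('D1 i) = ()
>   runForwards  _ x   = ((), x)
>   runBackwards _ _ g = ((), g)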
19 | 20 | There are two classes of interest: `UpdateLayer` and `Layer`. 21 | 22 | `UpdateLayer` is required for all types which are used as a layer 23 | in a network. Having no shape information, this class is agnostic 24 | to the input and output data of the layer. 25 | 26 | An instance of `Layer` on the other hand is required for usage in 27 | a neural network, but also specifies the shapes of data that the 28 | network can transform. Multiple instance of `Layer` are permitted 29 | for a single type, to transform different shapes. The `Reshape` layer 30 | for example can act as a flattening layer, and its inverse, projecting 31 | a 1D shape up to 2 or 3 dimensions. 32 | 33 | Instances of `Layer` should be as strict as possible, and not emit 34 | runtime errors. 35 | -} 36 | module Grenade.Core.Layer ( 37 | Layer (..) 38 | , UpdateLayer (..) 39 | ) where 40 | 41 | import Control.Monad.Random ( MonadRandom ) 42 | 43 | import Data.List ( foldl' ) 44 | 45 | #if MIN_VERSION_base(4,9,0) 46 | import Data.Kind (Type) 47 | #endif 48 | 49 | import Grenade.Core.Shape 50 | import Grenade.Core.LearningParameters 51 | 52 | -- | Class for updating a layer. All layers implement this, as it 53 | -- describes how to create and update the layer. 54 | -- 55 | class UpdateLayer x where 56 | -- | The type for the gradient for this layer. 57 | -- Unit if there isn't a gradient to pass back. 58 | type Gradient x :: Type 59 | 60 | -- | Update a layer with its gradient and learning parameters 61 | runUpdate :: LearningParameters -> x -> Gradient x -> x 62 | 63 | -- | Create a random layer, many layers will use pure 64 | createRandom :: MonadRandom m => m x 65 | 66 | -- | Update a layer with many Gradients 67 | runUpdates :: LearningParameters -> x -> [Gradient x] -> x 68 | runUpdates rate = foldl' (runUpdate rate) 69 | 70 | {-# MINIMAL runUpdate, createRandom #-} 71 | 72 | -- | Class for a layer. All layers implement this, however, they don't 73 | -- need to implement it for all shapes, only ones which are 74 | -- appropriate. 75 | -- 76 | class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where 77 | -- | The Wengert tape for this layer. Includes all that is required 78 | -- to generate the back propagated gradients efficiently. As a 79 | -- default, `S i` is fine. 80 | type Tape x i o :: Type 81 | 82 | -- | Used in training and scoring. Take the input from the previous 83 | -- layer, and give the output from this layer. 84 | runForwards :: x -> S i -> (Tape x i o, S o) 85 | 86 | -- | Back propagate a step. Takes the current layer, the input that 87 | -- the layer gave from the input and the back propagated derivatives 88 | -- from the layer above. 89 | -- 90 | -- Returns the gradient layer and the derivatives to push back 91 | -- further. 92 | runBackwards :: x -> Tape x i o -> S o -> (Gradient x, S i) 93 | -------------------------------------------------------------------------------- /src/Grenade/Core/LearningParameters.hs: -------------------------------------------------------------------------------- 1 | {-| 2 | Module : Grenade.Core.LearningParameters 3 | Description : Stochastic gradient descent learning parameters 4 | Copyright : (c) Huw Campbell, 2016-2017 5 | License : BSD2 6 | Stability : experimental 7 | -} 8 | module Grenade.Core.LearningParameters ( 9 | -- | This module contains learning algorithm specific 10 | -- code. Currently, this module should be considered 11 | -- unstable, due to issue #26. 12 | 13 | LearningParameters (..) 
14 | ) where 15 | 16 | -- | Learning parameters for stochastic gradient descent. 17 | data LearningParameters = LearningParameters { 18 | learningRate :: Double 19 | , learningMomentum :: Double 20 | , learningRegulariser :: Double 21 | } deriving (Eq, Show) 22 | -------------------------------------------------------------------------------- /src/Grenade/Core/Network.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE BangPatterns #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE MultiParamTypeClasses #-} 9 | {-# LANGUAGE FlexibleContexts #-} 10 | {-# LANGUAGE FlexibleInstances #-} 11 | {-| 12 | Module : Grenade.Core.Network 13 | Description : Core definition of a Neural Network 14 | Copyright : (c) Huw Campbell, 2016-2017 15 | License : BSD2 16 | Stability : experimental 17 | 18 | This module defines the core data types and functions 19 | for non-recurrent neural networks. 20 | -} 21 | 22 | module Grenade.Core.Network ( 23 | Network (..) 24 | , Gradients (..) 25 | , Tapes (..) 26 | 27 | , runNetwork 28 | , runGradient 29 | , applyUpdate 30 | 31 | , randomNetwork 32 | ) where 33 | 34 | import Control.Monad.Random ( MonadRandom ) 35 | 36 | import Data.Singletons 37 | import Data.Serialize 38 | 39 | #if MIN_VERSION_base(4,9,0) 40 | import Data.Kind (Type) 41 | #endif 42 | 43 | import Grenade.Core.Layer 44 | import Grenade.Core.LearningParameters 45 | import Grenade.Core.Shape 46 | import Prelude.Singletons 47 | 48 | -- | Type of a network. 49 | -- 50 | -- The @[*]@ type specifies the types of the layers. 51 | -- 52 | -- The @[Shape]@ type specifies the shapes of data passed between the layers. 53 | -- 54 | -- Can be considered to be a heterogeneous list of layers which are able to 55 | -- transform the data shapes of the network. 56 | data Network :: [Type] -> [Shape] -> Type where 57 | NNil :: SingI i 58 | => Network '[] '[i] 59 | 60 | (:~>) :: (SingI i, SingI h, Layer x i h) 61 | => !x 62 | -> !(Network xs (h ': hs)) 63 | -> Network (x ': xs) (i ': h ': hs) 64 | infixr 5 :~> 65 | 66 | instance Show (Network '[] '[i]) where 67 | show NNil = "NNil" 68 | instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) where 69 | show (x :~> xs) = show x ++ "\n~>\n" ++ show xs 70 | 71 | -- | Gradient of a network. 72 | -- 73 | -- Parameterised on the layers of the network. 74 | data Gradients :: [Type] -> Type where 75 | GNil :: Gradients '[] 76 | 77 | (:/>) :: UpdateLayer x 78 | => Gradient x 79 | -> Gradients xs 80 | -> Gradients (x ': xs) 81 | 82 | -- | Wegnert Tape of a network. 83 | -- 84 | -- Parameterised on the layers and shapes of the network. 85 | data Tapes :: [Type] -> [Shape] -> Type where 86 | TNil :: SingI i 87 | => Tapes '[] '[i] 88 | 89 | (:\>) :: (SingI i, SingI h, Layer x i h) 90 | => !(Tape x i h) 91 | -> !(Tapes xs (h ': hs)) 92 | -> Tapes (x ': xs) (i ': h ': hs) 93 | 94 | 95 | -- | Running a network forwards with some input data. 96 | -- 97 | -- This gives the output, and the Wengert tape required for back 98 | -- propagation. 99 | runNetwork :: forall layers shapes. 100 | Network layers shapes 101 | -> S (Head shapes) 102 | -> (Tapes layers shapes, S (Last shapes)) 103 | runNetwork = 104 | go 105 | where 106 | go :: forall js ss. 
(Last js ~ Last shapes) 107 | => Network ss js 108 | -> S (Head js) 109 | -> (Tapes ss js, S (Last js)) 110 | go (layer :~> n) !x = 111 | let (tape, forward) = runForwards layer x 112 | (tapes, answer) = go n forward 113 | in (tape :\> tapes, answer) 114 | 115 | go NNil !x 116 | = (TNil, x) 117 | 118 | 119 | -- | Running a loss gradient back through the network. 120 | -- 121 | -- This requires a Wengert tape, generated with the appropriate input 122 | -- for the loss. 123 | -- 124 | -- Gives the gradients for the layer, and the gradient across the 125 | -- input (which may not be required). 126 | runGradient :: forall layers shapes. 127 | Network layers shapes 128 | -> Tapes layers shapes 129 | -> S (Last shapes) 130 | -> (Gradients layers, S (Head shapes)) 131 | runGradient net tapes o = 132 | go net tapes 133 | where 134 | go :: forall js ss. (Last js ~ Last shapes) 135 | => Network ss js 136 | -> Tapes ss js 137 | -> (Gradients ss, S (Head js)) 138 | go (layer :~> n) (tape :\> nt) = 139 | let (gradients, feed) = go n nt 140 | (layer', backGrad) = runBackwards layer tape feed 141 | in (layer' :/> gradients, backGrad) 142 | 143 | go NNil TNil 144 | = (GNil, o) 145 | 146 | 147 | -- | Apply one step of stochastic gradient descent across the network. 148 | applyUpdate :: LearningParameters 149 | -> Network layers shapes 150 | -> Gradients layers 151 | -> Network layers shapes 152 | applyUpdate rate (layer :~> rest) (gradient :/> grest) 153 | = runUpdate rate layer gradient :~> applyUpdate rate rest grest 154 | 155 | applyUpdate _ NNil GNil 156 | = NNil 157 | 158 | -- | A network can easily be created by hand with (:~>), but an easy way to 159 | -- initialise a random network is with the randomNetwork. 160 | class CreatableNetwork (xs :: [Type]) (ss :: [Shape]) where 161 | -- | Create a network with randomly initialised weights. 162 | -- 163 | -- Calls to this function will not compile if the type of the neural 164 | -- network is not sound. 165 | randomNetwork :: MonadRandom m => m (Network xs ss) 166 | 167 | instance SingI i => CreatableNetwork '[] '[i] where 168 | randomNetwork = return NNil 169 | 170 | instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': rs)) => CreatableNetwork (x ': xs) (i ': o ': rs) where 171 | randomNetwork = (:~>) <$> createRandom <*> randomNetwork 172 | 173 | -- | Add very simple serialisation to the network 174 | instance SingI i => Serialize (Network '[] '[i]) where 175 | put NNil = pure () 176 | get = return NNil 177 | 178 | instance (SingI i, SingI o, Layer x i o, Serialize x, Serialize (Network xs (o ': rs))) => Serialize (Network (x ': xs) (i ': o ': rs)) where 179 | put (x :~> r) = put x >> put r 180 | get = (:~>) <$> get <*> get 181 | 182 | 183 | -- | Ultimate composition. 184 | -- 185 | -- This allows a complete network to be treated as a layer in a larger network. 186 | instance CreatableNetwork sublayers subshapes => UpdateLayer (Network sublayers subshapes) where 187 | type Gradient (Network sublayers subshapes) = Gradients sublayers 188 | runUpdate = applyUpdate 189 | createRandom = randomNetwork 190 | 191 | -- | Ultimate composition. 192 | -- 193 | -- This allows a complete network to be treated as a layer in a larger network. 
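--
-- For example (a sketch; the layer sizes are illustrative), a small
-- sub-network can be used directly as one entry in the layer list of a
-- larger network, with its first and last shapes acting as that layer's
-- input and output:
--
-- > type Inner = Network '[ FullyConnected 20 10, Relu ]
-- >                      '[ 'D1 20, 'D1 10, 'D1 10 ]
-- >
-- > type Outer = Network '[ FullyConnected 40 20, Inner, Logit ]
-- >                      '[ 'D1 40, 'D1 20, 'D1 10, 'D1 10 ]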
194 | instance (CreatableNetwork sublayers subshapes, i ~ (Head subshapes), o ~ (Last subshapes)) => Layer (Network sublayers subshapes) i o where 195 | type Tape (Network sublayers subshapes) i o = Tapes sublayers subshapes 196 | runForwards = runNetwork 197 | runBackwards = runGradient 198 | -------------------------------------------------------------------------------- /src/Grenade/Core/Runner.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-| 4 | Module : Grenade.Core.Runner 5 | Description : Functions to perform training and backpropagation 6 | Copyright : (c) Huw Campbell, 2016-2017 7 | License : BSD2 8 | Stability : experimental 9 | -} 10 | module Grenade.Core.Runner ( 11 | train 12 | , backPropagate 13 | , runNet 14 | ) where 15 | 16 | 17 | import Grenade.Core.LearningParameters 18 | import Grenade.Core.Network 19 | import Grenade.Core.Shape 20 | import Data.Singletons 21 | import Prelude.Singletons 22 | 23 | -- | Perform reverse automatic differentiation on the network 24 | -- for the current input and expected output. 25 | -- 26 | -- /Note:/ The loss function pushed backwards is appropriate 27 | -- for both regression and classification as a squared loss 28 | -- or log-loss respectively. 29 | -- 30 | -- For other loss functions, use runNetwork and runGradient 31 | -- with the back propagated gradient of your loss. 32 | -- 33 | backPropagate :: SingI (Last shapes) 34 | => Network layers shapes 35 | -> S (Head shapes) 36 | -> S (Last shapes) 37 | -> Gradients layers 38 | backPropagate network input target = 39 | let (tapes, output) = runNetwork network input 40 | (grads, _) = runGradient network tapes (output - target) 41 | in grads 42 | 43 | 44 | -- | Update a network with new weights after training with an instance. 45 | train :: SingI (Last shapes) 46 | => LearningParameters 47 | -> Network layers shapes 48 | -> S (Head shapes) 49 | -> S (Last shapes) 50 | -> Network layers shapes 51 | train rate network input output = 52 | let grads = backPropagate network input output 53 | in applyUpdate rate network grads 54 | 55 | 56 | -- | Run the network with input and return the given output. 57 | runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes) 58 | runNet net = snd . runNetwork net 59 | -------------------------------------------------------------------------------- /src/Grenade/Core/Shape.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE KindSignatures #-} 5 | {-# LANGUAGE TypeFamilies #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE StandaloneDeriving #-} 8 | {-# LANGUAGE FlexibleContexts #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-# LANGUAGE RankNTypes #-} 11 | {-# LANGUAGE UndecidableInstances #-} 12 | {-| 13 | Module : Grenade.Core.Shape 14 | Description : Dependently typed shapes of data which are passed between layers of a network 15 | Copyright : (c) Huw Campbell, 2016-2017 16 | License : BSD2 17 | Stability : experimental 18 | 19 | 20 | -} 21 | module Grenade.Core.Shape ( 22 | S (..) 23 | , Shape (..) 24 | , Sing (..) 25 | , SShape (..) 
26 | , randomOfShape 27 | , fromStorable 28 | ) where 29 | 30 | import Control.DeepSeq (NFData (..)) 31 | import Control.Monad.Random ( MonadRandom, getRandom ) 32 | import Data.Kind (Type) 33 | import Data.Proxy 34 | import Data.Serialize 35 | import Data.Singletons 36 | import Data.Vector.Storable ( Vector ) 37 | import qualified Data.Vector.Storable as V 38 | import GHC.TypeLits 39 | import qualified Numeric.LinearAlgebra.Static as H 40 | import Numeric.LinearAlgebra.Static 41 | import qualified Numeric.LinearAlgebra as NLA 42 | 43 | -- | The current shapes we accept. 44 | -- at the moment this is just one, two, and three dimensional 45 | -- Vectors/Matricies. 46 | -- 47 | -- These are only used with DataKinds, as Kind `Shape`, with Types 'D1, 'D2, 'D3. 48 | data Shape 49 | = D1 Nat 50 | -- ^ One dimensional vector 51 | | D2 Nat Nat 52 | -- ^ Two dimensional matrix. Row, Column. 53 | | D3 Nat Nat Nat 54 | -- ^ Three dimensional matrix. Row, Column, Channels. 55 | 56 | -- | Concrete data structures for a Shape. 57 | -- 58 | -- All shapes are held in contiguous memory. 59 | -- 3D is held in a matrix (usually row oriented) which has height depth * rows. 60 | data S (n :: Shape) where 61 | S1D :: ( KnownNat len ) 62 | => R len 63 | -> S ('D1 len) 64 | 65 | S2D :: ( KnownNat rows, KnownNat columns ) 66 | => L rows columns 67 | -> S ('D2 rows columns) 68 | 69 | S3D :: ( KnownNat rows 70 | , KnownNat columns 71 | , KnownNat depth 72 | , KnownNat (rows * depth)) 73 | => L (rows * depth) columns 74 | -> S ('D3 rows columns depth) 75 | 76 | deriving instance Show (S n) 77 | 78 | -- Singleton instances. 79 | -- 80 | -- These could probably be derived with template haskell, but this seems 81 | -- clear and makes adding the KnownNat constraints simple. 82 | -- We can also keep our code TH free, which is great. 83 | #if MIN_VERSION_singletons(2,6,0) 84 | -- In singletons 2.6 Sing switched from a data family to a type family. 
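--
-- Pattern matching on the constructors defined just below recovers the
-- KnownNat evidence for each dimension, which is how randomOfShape and
-- fromStorable later in this module are written. A sketch of the idiom
-- (zeroOfShape is illustrative, not part of this module):
--
-- > zeroOfShape :: forall x. SingI x => S x
-- > zeroOfShape = case (sing :: Sing x) of
-- >   D1Sing -> S1D (konst 0)
-- >   D2Sing -> S2D (konst 0)
-- >   D3Sing -> S3D (konst 0)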
85 | type instance Sing = SShape 86 | 87 | data SShape :: Shape -> Type where 88 | D1Sing :: KnownNat a => SShape ('D1 a) 89 | D2Sing :: (KnownNat a, KnownNat b) => SShape ('D2 a b) 90 | D3Sing :: (KnownNat (a * c), KnownNat a, KnownNat b, KnownNat c) => SShape ('D3 a b c) 91 | #else 92 | data instance Sing (n :: Shape) where 93 | D1Sing :: Sing a -> Sing ('D1 a) 94 | D2Sing :: Sing a -> Sing b -> Sing ('D2 a b) 95 | D3Sing :: KnownNat (a * c) => Sing a -> Sing b -> Sing c -> Sing ('D3 a b c) 96 | #endif 97 | 98 | instance KnownNat a => SingI ('D1 a) where 99 | sing = D1Sing 100 | instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where 101 | sing = D2Sing 102 | instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where 103 | sing = D3Sing 104 | 105 | instance SingI x => Num (S x) where 106 | (+) = n2 (+) 107 | (-) = n2 (-) 108 | (*) = n2 (*) 109 | abs = n1 abs 110 | signum = n1 signum 111 | fromInteger x = nk (fromInteger x) 112 | 113 | instance SingI x => Fractional (S x) where 114 | (/) = n2 (/) 115 | recip = n1 recip 116 | fromRational x = nk (fromRational x) 117 | 118 | instance SingI x => Floating (S x) where 119 | pi = nk pi 120 | exp = n1 exp 121 | log = n1 log 122 | sqrt = n1 sqrt 123 | (**) = n2 (**) 124 | logBase = n2 logBase 125 | sin = n1 sin 126 | cos = n1 cos 127 | tan = n1 tan 128 | asin = n1 asin 129 | acos = n1 acos 130 | atan = n1 atan 131 | sinh = n1 sinh 132 | cosh = n1 cosh 133 | tanh = n1 tanh 134 | asinh = n1 asinh 135 | acosh = n1 acosh 136 | atanh = n1 atanh 137 | 138 | -- 139 | -- I haven't made shapes strict, as sometimes they're not needed 140 | -- (the last input gradient back for instance) 141 | -- 142 | instance NFData (S x) where 143 | rnf (S1D x) = rnf x 144 | rnf (S2D x) = rnf x 145 | rnf (S3D x) = rnf x 146 | 147 | -- | Generate random data of the desired shape 148 | randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x) 149 | randomOfShape = do 150 | seed :: Int <- getRandom 151 | return $ case (sing :: Sing x) of 152 | D1Sing -> 153 | S1D (randomVector seed Uniform * 2 - 1) 154 | 155 | D2Sing -> 156 | S2D (uniformSample seed (-1) 1) 157 | 158 | D3Sing -> 159 | S3D (uniformSample seed (-1) 1) 160 | 161 | -- | Generate a shape from a Storable Vector. 162 | -- 163 | -- Returns Nothing if the vector is of the wrong size. 164 | fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x) 165 | fromStorable xs = case sing :: Sing x of 166 | D1Sing -> 167 | S1D <$> H.create xs 168 | 169 | D2Sing -> 170 | S2D <$> mkL xs 171 | 172 | D3Sing -> 173 | S3D <$> mkL xs 174 | where 175 | mkL :: forall rows columns. (KnownNat rows, KnownNat columns) 176 | => Vector Double -> Maybe (L rows columns) 177 | mkL v = 178 | let rows = fromIntegral $ natVal (Proxy :: Proxy rows) 179 | columns = fromIntegral $ natVal (Proxy :: Proxy columns) 180 | in if rows * columns == V.length v 181 | then H.create $ NLA.reshape columns v 182 | else Nothing 183 | 184 | 185 | instance SingI x => Serialize (S x) where 186 | put i = (case i of 187 | (S1D x) -> putListOf put . NLA.toList . H.extract $ x 188 | (S2D x) -> putListOf put . NLA.toList . NLA.flatten . H.extract $ x 189 | (S3D x) -> putListOf put . NLA.toList . NLA.flatten . H.extract $ x 190 | ) :: PutM () 191 | 192 | get = do 193 | Just i <- fromStorable . V.fromList <$> getListOf get 194 | return i 195 | 196 | -- Helper function for creating the number instances 197 | n1 :: ( forall a. 
Floating a => a -> a ) -> S x -> S x 198 | n1 f (S1D x) = S1D (f x) 199 | n1 f (S2D x) = S2D (f x) 200 | n1 f (S3D x) = S3D (f x) 201 | 202 | -- Helper function for creating the number instances 203 | n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x 204 | n2 f (S1D x) (S1D y) = S1D (f x y) 205 | n2 f (S2D x) (S2D y) = S2D (f x y) 206 | n2 f (S3D x) (S3D y) = S3D (f x y) 207 | 208 | -- Helper function for creating the number instances 209 | nk :: forall x. (SingI x) => Double -> S x 210 | nk x = case (sing :: Sing x) of 211 | D1Sing -> 212 | S1D (konst x) 213 | 214 | D2Sing -> 215 | S2D (konst x) 216 | 217 | D3Sing -> 218 | S3D (konst x) 219 | -------------------------------------------------------------------------------- /src/Grenade/Layers.hs: -------------------------------------------------------------------------------- 1 | module Grenade.Layers ( 2 | module Grenade.Layers.Concat 3 | , module Grenade.Layers.Convolution 4 | , module Grenade.Layers.Crop 5 | , module Grenade.Layers.Deconvolution 6 | , module Grenade.Layers.Elu 7 | , module Grenade.Layers.FullyConnected 8 | , module Grenade.Layers.Inception 9 | , module Grenade.Layers.Logit 10 | , module Grenade.Layers.Merge 11 | , module Grenade.Layers.Pad 12 | , module Grenade.Layers.Pooling 13 | , module Grenade.Layers.Reshape 14 | , module Grenade.Layers.Relu 15 | , module Grenade.Layers.Softmax 16 | , module Grenade.Layers.Tanh 17 | , module Grenade.Layers.Trivial 18 | ) where 19 | 20 | import Grenade.Layers.Concat 21 | import Grenade.Layers.Convolution 22 | import Grenade.Layers.Crop 23 | import Grenade.Layers.Deconvolution 24 | import Grenade.Layers.Elu 25 | import Grenade.Layers.Pad 26 | import Grenade.Layers.FullyConnected 27 | import Grenade.Layers.Inception 28 | import Grenade.Layers.Logit 29 | import Grenade.Layers.Merge 30 | import Grenade.Layers.Pooling 31 | import Grenade.Layers.Reshape 32 | import Grenade.Layers.Relu 33 | import Grenade.Layers.Softmax 34 | import Grenade.Layers.Tanh 35 | import Grenade.Layers.Trivial 36 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Concat.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE TypeOperators #-} 5 | {-# LANGUAGE TypeFamilies #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | {-# LANGUAGE ScopedTypeVariables #-} 11 | {-# LANGUAGE StandaloneDeriving #-} 12 | {-# LANGUAGE UndecidableInstances #-} 13 | {-| 14 | Module : Grenade.Layers.Concat 15 | Description : Concatenation layer 16 | Copyright : (c) Huw Campbell, 2016-2017 17 | License : BSD2 18 | Stability : experimental 19 | 20 | This module provides the concatenation layer, which runs two chilld layers in parallel and combines their outputs. 21 | -} 22 | module Grenade.Layers.Concat ( 23 | Concat (..) 24 | ) where 25 | 26 | import Data.Serialize 27 | 28 | import Data.Singletons 29 | import GHC.TypeLits 30 | 31 | #if MIN_VERSION_base(4,9,0) 32 | import Data.Kind (Type) 33 | #endif 34 | 35 | import Grenade.Core 36 | 37 | import Numeric.LinearAlgebra.Static ( row, (===), splitRows, unrow, (#), split, R ) 38 | 39 | -- | A Concatentating Layer. 40 | -- 41 | -- This layer shares it's input state between two sublayers, and concatenates their output. 
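--
-- For example (a sketch; the sizes are illustrative), running a Relu and a
-- Logit over the same ('D1 10) input:
--
-- > type Both = Concat ('D1 10) Relu ('D1 10) Logit
--
-- Depending on which instance below is selected, @Both@ acts as a layer
-- from ('D1 10) to ('D2 2 10) (outputs stacked as rows) or to ('D1 20)
-- (outputs appended).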
42 | -- 43 | -- With Networks able to be Layers, this allows for very expressive composition of complex Networks. 44 | -- 45 | -- The Concat layer has a few instances, which allow one to flexibly "bash" together the outputs. 46 | -- 47 | -- Two 1D vectors, can go to a 2D shape with 2 rows if their lengths are identical. 48 | -- Any 2 1D vectors can also become a longer 1D Vector. 49 | -- 50 | -- 3D images become 3D images with more channels. The sizes must be the same, one can use Pad 51 | -- and Crop layers to ensure this is the case. 52 | data Concat :: Shape -> Type -> Shape -> Type -> Type where 53 | Concat :: x -> y -> Concat m x n y 54 | 55 | instance (Show x, Show y) => Show (Concat m x n y) where 56 | show (Concat x y) = "Concat\n" ++ show x ++ "\n" ++ show y 57 | 58 | -- | Run two layers in parallel, combining their outputs. 59 | instance (UpdateLayer x, UpdateLayer y) => UpdateLayer (Concat m x n y) where 60 | type Gradient (Concat m x n y) = (Gradient x, Gradient y) 61 | runUpdate lr (Concat x y) (x', y') = Concat (runUpdate lr x x') (runUpdate lr y y') 62 | createRandom = Concat <$> createRandom <*> createRandom 63 | 64 | instance ( SingI i 65 | , Layer x i ('D1 o) 66 | , Layer y i ('D1 o) 67 | ) => Layer (Concat ('D1 o) x ('D1 o) y) i ('D2 2 o) where 68 | type Tape (Concat ('D1 o) x ('D1 o) y) i ('D2 2 o) = (Tape x i ('D1 o), Tape y i ('D1 o)) 69 | 70 | runForwards (Concat x y) input = 71 | let (xT, xOut :: S ('D1 o)) = runForwards x input 72 | (yT, yOut :: S ('D1 o)) = runForwards y input 73 | in case (xOut, yOut) of 74 | (S1D xOut', S1D yOut') -> 75 | ((xT, yT), S2D (row xOut' === row yOut')) 76 | 77 | runBackwards (Concat x y) (xTape, yTape) (S2D o) = 78 | let (ox, oy) = splitRows o 79 | (x', xB :: S i) = runBackwards x xTape (S1D $ unrow ox) 80 | (y', yB :: S i) = runBackwards y yTape (S1D $ unrow oy) 81 | in ((x', y'), xB + yB) 82 | 83 | instance ( SingI i 84 | , Layer x i ('D1 m) 85 | , Layer y i ('D1 n) 86 | , KnownNat o 87 | , KnownNat m 88 | , KnownNat n 89 | , o ~ (m + n) 90 | , n ~ (o - m) 91 | , (m <=? o) ~ 'True 92 | ) => Layer (Concat ('D1 m) x ('D1 n) y) i ('D1 o) where 93 | type Tape (Concat ('D1 m) x ('D1 n) y) i ('D1 o) = (Tape x i ('D1 m), Tape y i ('D1 n)) 94 | 95 | runForwards (Concat x y) input = 96 | let (xT, xOut :: S ('D1 m)) = runForwards x input 97 | (yT, yOut :: S ('D1 n)) = runForwards y input 98 | in case (xOut, yOut) of 99 | (S1D xOut', S1D yOut') -> 100 | ((xT, yT), S1D (xOut' # yOut')) 101 | 102 | runBackwards (Concat x y) (xTape, yTape) (S1D o) = 103 | let (ox :: R m , oy :: R n) = split o 104 | (x', xB :: S i) = runBackwards x xTape (S1D ox) 105 | (y', yB :: S i) = runBackwards y yTape (S1D oy) 106 | in ((x', y'), xB + yB) 107 | 108 | -- | Concat 3D shapes, increasing the number of channels. 109 | instance ( SingI i 110 | , Layer x i ('D3 rows cols m) 111 | , Layer y i ('D3 rows cols n) 112 | , KnownNat (rows * n) 113 | , KnownNat (rows * m) 114 | , KnownNat (rows * o) 115 | , KnownNat o 116 | , KnownNat m 117 | , KnownNat n 118 | , ((rows * m) + (rows * n)) ~ (rows * o) 119 | , ((rows * o) - (rows * m)) ~ (rows * n) 120 | , ((rows * m) <=? 
(rows * o)) ~ 'True 121 | ) => Layer (Concat ('D3 rows cols m) x ('D3 rows cols n) y) i ('D3 rows cols o) where 122 | type Tape (Concat ('D3 rows cols m) x ('D3 rows cols n) y) i ('D3 rows cols o) = (Tape x i ('D3 rows cols m), Tape y i ('D3 rows cols n)) 123 | 124 | runForwards (Concat x y) input = 125 | let (xT, xOut :: S ('D3 rows cols m)) = runForwards x input 126 | (yT, yOut :: S ('D3 rows cols n)) = runForwards y input 127 | in case (xOut, yOut) of 128 | (S3D xOut', S3D yOut') -> 129 | ((xT, yT), S3D (xOut' === yOut')) 130 | 131 | runBackwards (Concat x y) (xTape, yTape) (S3D o) = 132 | let (ox, oy) = splitRows o 133 | (x', xB :: S i) = runBackwards x xTape (S3D ox :: S ('D3 rows cols m)) 134 | (y', yB :: S i) = runBackwards y yTape (S3D oy :: S ('D3 rows cols n)) 135 | in ((x', y'), xB + yB) 136 | 137 | instance (Serialize a, Serialize b) => Serialize (Concat sa a sb b) where 138 | put (Concat a b) = put a *> put b 139 | get = Concat <$> get <*> get 140 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Crop.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE TypeFamilies #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE MultiParamTypeClasses #-} 9 | {-# LANGUAGE UndecidableInstances #-} 10 | {-| 11 | Module : Grenade.Layers.Crop 12 | Description : Cropping layer 13 | Copyright : (c) Huw Campbell, 2016-2017 14 | License : BSD2 15 | Stability : experimental 16 | -} 17 | module Grenade.Layers.Crop ( 18 | Crop (..) 19 | ) where 20 | 21 | import Data.Maybe 22 | import Data.Proxy 23 | import Data.Serialize 24 | 25 | import GHC.TypeLits 26 | import Data.Kind (Type) 27 | 28 | import Grenade.Core 29 | import Grenade.Layers.Internal.Pad 30 | 31 | import Numeric.LinearAlgebra (konst, subMatrix, diagBlock) 32 | import Numeric.LinearAlgebra.Static (extract, create) 33 | 34 | -- | A cropping layer for a neural network. 35 | data Crop :: Nat 36 | -> Nat 37 | -> Nat 38 | -> Nat -> Type where 39 | Crop :: Crop cropLeft cropTop cropRight cropBottom 40 | 41 | instance Show (Crop cropLeft cropTop cropRight cropBottom) where 42 | show Crop = "Crop" 43 | 44 | instance UpdateLayer (Crop l t r b) where 45 | type Gradient (Crop l t r b) = () 46 | runUpdate _ x _ = x 47 | createRandom = return Crop 48 | 49 | instance Serialize (Crop l t r b) where 50 | put _ = return () 51 | get = return Crop 52 | 53 | -- | A two dimentional image can be cropped. 
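--
-- For example (a sketch): @Crop 1 2 3 4@ drops 1 column on the left, 2 rows
-- from the top, 3 columns on the right and 4 rows from the bottom, so it is
-- a layer from ('D2 34 32) to ('D2 28 28): the constraints below require
-- 28 + 2 + 4 = 34 input rows and 28 + 1 + 3 = 32 input columns.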
54 | instance ( KnownNat cropLeft 55 | , KnownNat cropTop 56 | , KnownNat cropRight 57 | , KnownNat cropBottom 58 | , KnownNat inputRows 59 | , KnownNat inputColumns 60 | , KnownNat outputRows 61 | , KnownNat outputColumns 62 | , (outputRows + cropTop + cropBottom) ~ inputRows 63 | , (outputColumns + cropLeft + cropRight) ~ inputColumns 64 | ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where 65 | type Tape (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) = () 66 | runForwards Crop (S2D input) = 67 | let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft) 68 | cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop) 69 | nrows = fromIntegral $ natVal (Proxy :: Proxy outputRows) 70 | ncols = fromIntegral $ natVal (Proxy :: Proxy outputColumns) 71 | m = extract input 72 | r = subMatrix (cropt, cropl) (nrows, ncols) m 73 | in ((), S2D . fromJust . create $ r) 74 | runBackwards _ _ (S2D dEdy) = 75 | let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft) 76 | cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop) 77 | cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight) 78 | cropb = fromIntegral $ natVal (Proxy :: Proxy cropBottom) 79 | eo = extract dEdy 80 | vs = diagBlock [konst 0 (cropt,cropl), eo, konst 0 (cropb,cropr)] 81 | in ((), S2D . fromJust . create $ vs) 82 | 83 | 84 | -- | A two dimentional image can be cropped. 85 | instance ( KnownNat cropLeft 86 | , KnownNat cropTop 87 | , KnownNat cropRight 88 | , KnownNat cropBottom 89 | , KnownNat inputRows 90 | , KnownNat inputColumns 91 | , KnownNat outputRows 92 | , KnownNat outputColumns 93 | , KnownNat channels 94 | , KnownNat (inputRows * channels) 95 | , KnownNat (outputRows * channels) 96 | , (outputRows + cropTop + cropBottom) ~ inputRows 97 | , (outputColumns + cropLeft + cropRight) ~ inputColumns 98 | ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where 99 | type Tape (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = () 100 | runForwards Crop (S3D input) = 101 | let padl = fromIntegral $ natVal (Proxy :: Proxy cropLeft) 102 | padt = fromIntegral $ natVal (Proxy :: Proxy cropTop) 103 | padr = fromIntegral $ natVal (Proxy :: Proxy cropRight) 104 | padb = fromIntegral $ natVal (Proxy :: Proxy cropBottom) 105 | inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) 106 | inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 107 | outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) 108 | outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) 109 | ch = fromIntegral $ natVal (Proxy :: Proxy channels) 110 | m = extract input 111 | cropped = crop ch padl padt padr padb outr outc inr inc m 112 | in ((), S3D . fromJust . 
create $ cropped) 113 | 114 | runBackwards Crop () (S3D gradient) = 115 | let padl = fromIntegral $ natVal (Proxy :: Proxy cropLeft) 116 | padt = fromIntegral $ natVal (Proxy :: Proxy cropTop) 117 | padr = fromIntegral $ natVal (Proxy :: Proxy cropRight) 118 | padb = fromIntegral $ natVal (Proxy :: Proxy cropBottom) 119 | inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) 120 | inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 121 | outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) 122 | outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) 123 | ch = fromIntegral $ natVal (Proxy :: Proxy channels) 124 | m = extract gradient 125 | padded = pad ch padl padt padr padb outr outc inr inc m 126 | in ((), S3D . fromJust . create $ padded) 127 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Dropout.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | module Grenade.Layers.Dropout ( 6 | Dropout (..) 7 | , randomDropout 8 | ) where 9 | 10 | import Control.Monad.Random hiding (fromList) 11 | 12 | import GHC.TypeLits 13 | import Grenade.Core 14 | 15 | -- Dropout layer help to reduce overfitting. 16 | -- Idea here is that the vector is a shape of 1s and 0s, which we multiply the input by. 17 | -- After backpropogation, we return a new matrix/vector, with different bits dropped out. 18 | -- Double is the proportion to drop in each training iteration (like 1% or 5% would be 19 | -- reasonable). 20 | data Dropout = Dropout { 21 | dropoutRate :: Double 22 | , dropoutSeed :: Int 23 | } deriving Show 24 | 25 | instance UpdateLayer Dropout where 26 | type Gradient Dropout = () 27 | runUpdate _ x _ = x 28 | createRandom = randomDropout 0.95 29 | 30 | randomDropout :: MonadRandom m 31 | => Double -> m Dropout 32 | randomDropout rate = Dropout rate <$> getRandom 33 | 34 | instance (KnownNat i) => Layer Dropout ('D1 i) ('D1 i) where 35 | type Tape Dropout ('D1 i) ('D1 i) = () 36 | runForwards (Dropout _ _) (S1D x) = ((), S1D x) 37 | runBackwards (Dropout _ _) _ (S1D x) = ((), S1D x) 38 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Elu.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-| 6 | Module : Grenade.Layers.Logit 7 | Description : Exponential linear unit layer 8 | Copyright : (c) Huw Campbell, 2016-2017 9 | License : BSD2 10 | Stability : experimental 11 | -} 12 | module Grenade.Layers.Elu ( 13 | Elu (..) 14 | ) where 15 | 16 | import Data.Serialize 17 | 18 | import GHC.TypeLits 19 | import Grenade.Core 20 | 21 | import qualified Numeric.LinearAlgebra.Static as LAS 22 | 23 | -- | An exponential linear unit. 24 | -- A layer which can act between any shape of the same dimension, acting as a 25 | -- diode on every neuron individually. 
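--
-- Pointwise, the function applied below is @a@ for @a > 0@ and @exp a - 1@
-- for @a <= 0@.
--
-- A sketch of usage (the sizes are illustrative):
--
-- > type Net = Network '[ FullyConnected 20 10, Elu ]
-- >                    '[ 'D1 20, 'D1 10, 'D1 10 ]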
26 | data Elu = Elu 27 | deriving Show 28 | 29 | instance UpdateLayer Elu where 30 | type Gradient Elu = () 31 | runUpdate _ _ _ = Elu 32 | createRandom = return Elu 33 | 34 | instance Serialize Elu where 35 | put _ = return () 36 | get = return Elu 37 | 38 | instance ( KnownNat i) => Layer Elu ('D1 i) ('D1 i) where 39 | type Tape Elu ('D1 i) ('D1 i) = LAS.R i 40 | 41 | runForwards _ (S1D y) = (y, S1D (elu y)) 42 | where 43 | elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a) 44 | runBackwards _ y (S1D dEdy) = ((), S1D (elu' y * dEdy)) 45 | where 46 | elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1) 47 | 48 | instance (KnownNat i, KnownNat j) => Layer Elu ('D2 i j) ('D2 i j) where 49 | type Tape Elu ('D2 i j) ('D2 i j) = S ('D2 i j) 50 | 51 | runForwards _ (S2D y) = (S2D y, S2D (elu y)) 52 | where 53 | elu = LAS.dmmap (\a -> if a <= 0 then exp a - 1 else a) 54 | runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (elu' y * dEdy)) 55 | where 56 | elu' = LAS.dmmap (\a -> if a <= 0 then exp a else 1) 57 | 58 | instance (KnownNat i, KnownNat j, KnownNat k) => Layer Elu ('D3 i j k) ('D3 i j k) where 59 | 60 | type Tape Elu ('D3 i j k) ('D3 i j k) = S ('D3 i j k) 61 | 62 | runForwards _ (S3D y) = (S3D y, S3D (elu y)) 63 | where 64 | elu = LAS.dmmap (\a -> if a <= 0 then exp a - 1 else a) 65 | runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (elu' y * dEdy)) 66 | where 67 | elu' = LAS.dmmap (\a -> if a <= 0 then exp a else 1) 68 | -------------------------------------------------------------------------------- /src/Grenade/Layers/FullyConnected.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | module Grenade.Layers.FullyConnected ( 7 | FullyConnected (..) 8 | , FullyConnected' (..) 9 | , randomFullyConnected 10 | ) where 11 | 12 | import Control.Monad.Random hiding (fromList) 13 | 14 | import Data.Proxy 15 | import Data.Serialize 16 | 17 | import qualified Numeric.LinearAlgebra as LA 18 | import Numeric.LinearAlgebra.Static 19 | 20 | import Grenade.Core 21 | 22 | import Grenade.Layers.Internal.Update 23 | import GHC.TypeLits 24 | 25 | -- | A basic fully connected (or inner product) neural network layer. 
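--
-- The forward pass below computes @b + w #> x@; the second field of the
-- constructor holds the momentum values used when descending the gradient.
--
-- For example (a sketch; the sizes are illustrative), @FullyConnected 784 10@
-- is a layer from ('D1 784) to ('D1 10) and can be dropped into a network
-- type such as:
--
-- > type Classifier = Network '[ FullyConnected 784 10, Logit ]
-- >                           '[ 'D1 784, 'D1 10, 'D1 10 ]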
26 | data FullyConnected i o = FullyConnected 27 | !(FullyConnected' i o) -- Neuron weights 28 | !(FullyConnected' i o) -- Neuron momentum 29 | 30 | data FullyConnected' i o = FullyConnected' 31 | !(R o) -- Bias 32 | !(L o i) -- Activations 33 | 34 | instance Show (FullyConnected i o) where 35 | show FullyConnected {} = "FullyConnected" 36 | 37 | instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where 38 | type Gradient (FullyConnected i o) = (FullyConnected' i o) 39 | 40 | runUpdate lp (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) = 41 | let (newBias, newBiasMomentum) = descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldBias biasGradient oldBiasMomentum 42 | (newActivations, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldActivations activationGradient oldMomentum 43 | in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum) 44 | 45 | createRandom = randomFullyConnected 46 | 47 | instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where 48 | type Tape (FullyConnected i o) ('D1 i) ('D1 o) = R i 49 | -- Do a matrix vector multiplication and return the result. 50 | runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (v, S1D (wB + wN #> v)) 51 | 52 | -- Run a backpropogation step for a full connected layer. 53 | runBackwards (FullyConnected (FullyConnected' _ wN) _) x (S1D dEdy) = 54 | let wB' = dEdy 55 | mm' = dEdy `outer` x 56 | -- calculate derivatives for next step 57 | dWs = tr wN #> dEdy 58 | in (FullyConnected' wB' mm', S1D dWs) 59 | 60 | instance (KnownNat i, KnownNat o) => Serialize (FullyConnected i o) where 61 | put (FullyConnected (FullyConnected' b w) _) = do 62 | putListOf put . LA.toList . extract $ b 63 | putListOf put . LA.toList . LA.flatten . extract $ w 64 | 65 | get = do 66 | let f = fromIntegral $ natVal (Proxy :: Proxy i) 67 | b <- maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get 68 | k <- maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get 69 | let bm = konst 0 70 | let mm = konst 0 71 | return $ FullyConnected (FullyConnected' b k) (FullyConnected' bm mm) 72 | 73 | randomFullyConnected :: (MonadRandom m, KnownNat i, KnownNat o) 74 | => m (FullyConnected i o) 75 | randomFullyConnected = do 76 | s1 <- getRandom 77 | s2 <- getRandom 78 | let wB = randomVector s1 Uniform * 2 - 1 79 | wN = uniformSample s2 (-1) 1 80 | bm = konst 0 81 | mm = konst 0 82 | return $ FullyConnected (FullyConnected' wB wN) (FullyConnected' bm mm) 83 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Inception.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE TypeOperators #-} 4 | {-# LANGUAGE TypeFamilies #-} 5 | {-# LANGUAGE MultiParamTypeClasses #-} 6 | {-# LANGUAGE FlexibleContexts #-} 7 | {-# LANGUAGE RankNTypes #-} 8 | {-# LANGUAGE FlexibleInstances #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-| 11 | Module : Grenade.Core.Network 12 | Description : Inception style parallel convolutional network composition. 
13 | Copyright : (c) Huw Campbell, 2016-2017 14 | License : BSD2 15 | Stability : experimental 16 | 17 | Export an Inception style type, which can be used to build up 18 | complex multiconvolution size networks. 19 | -} 20 | module Grenade.Layers.Inception ( 21 | Inception 22 | , InceptionMini 23 | , Resnet 24 | ) where 25 | 26 | import GHC.TypeLits 27 | 28 | import Grenade.Core 29 | import Grenade.Layers.Convolution 30 | import Grenade.Layers.Pad 31 | import Grenade.Layers.Concat 32 | import Grenade.Layers.Merge 33 | import Grenade.Layers.Trivial 34 | 35 | -- | Type of an inception layer. 36 | -- 37 | -- It looks like a bit of a handful, but is actually pretty easy to use. 38 | -- 39 | -- The first three type parameters are the size of the (3D) data the 40 | -- inception layer will take. It will emit 3D data with the number of 41 | -- channels being the sum of @chx@, @chy@, @chz@, which are the number 42 | -- of convolution filters in the 3x3, 5x5, and 7x7 convolutions Layers 43 | -- respectively. 44 | -- 45 | -- The network get padded effectively before each convolution filters 46 | -- such that the output dimension is the same x and y as the input. 47 | type Inception rows cols channels chx chy chz 48 | = Network '[ Concat ('D3 rows cols (chx + chy)) (InceptionMini rows cols channels chx chy) ('D3 rows cols chz) (Inception7x7 rows cols channels chz) ] 49 | '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy + chz) ] 50 | 51 | type InceptionMini rows cols channels chx chy 52 | = Network '[ Concat ('D3 rows cols chx) (Inception3x3 rows cols channels chx) ('D3 rows cols chy) (Inception5x5 rows cols channels chy) ] 53 | '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy) ] 54 | 55 | type Inception3x3 rows cols channels chx 56 | = Network '[ Pad 1 1 1 1, Convolution channels chx 3 3 1 1 ] 57 | '[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 rows cols chx ] 58 | 59 | type Inception5x5 rows cols channels chx 60 | = Network '[ Pad 2 2 2 2, Convolution channels chx 5 5 1 1 ] 61 | '[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 rows cols chx ] 62 | 63 | type Inception7x7 rows cols channels chx 64 | = Network '[ Pad 3 3 3 3, Convolution channels chx 7 7 1 1 ] 65 | '[ 'D3 rows cols channels, 'D3 (rows + 6) (cols + 6) channels, 'D3 rows cols chx ] 66 | 67 | type Resnet branch = Merge Trivial branch 68 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Internal/Convolution.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ForeignFunctionInterface #-} 2 | module Grenade.Layers.Internal.Convolution ( 3 | im2col 4 | , col2im 5 | , col2vid 6 | , vid2col 7 | ) where 8 | 9 | import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 ) 10 | 11 | import Foreign ( mallocForeignPtrArray, withForeignPtr ) 12 | import Foreign.Ptr ( Ptr ) 13 | 14 | import Numeric.LinearAlgebra ( Matrix, flatten, rows, cols ) 15 | import qualified Numeric.LinearAlgebra.Devel as U 16 | 17 | import System.IO.Unsafe ( unsafePerformIO ) 18 | 19 | col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 20 | col2vid kernelRows kernelColumns strideRows strideColumns height width dataCol = 21 | let channels = cols dataCol `div` (kernelRows * kernelColumns) 22 | in col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol 23 | 24 | col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 
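-- A note on sizes (a sketch of the intended semantics, mirroring im2col_c
-- below): im2col lays an image out with one row per kernel position, i.e.
-- ((height - kernelRows) `div` strideRows + 1) * ((width - kernelColumns) `div` strideColumns + 1)
-- rows of kernelRows * kernelColumns values; col2im scatters such a matrix
-- back into a height * width image, summing entries that came from
-- overlapping positions.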
25 | col2im kernelRows kernelColumns strideRows strideColumns height width dataCol = 26 | let channels = 1 27 | in col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol 28 | 29 | col2im_c :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 30 | col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol = 31 | let vec = flatten dataCol 32 | in unsafePerformIO $ do 33 | outPtr <- mallocForeignPtrArray (height * width * channels) 34 | let (inPtr, _) = U.unsafeToForeignPtr0 vec 35 | 36 | withForeignPtr inPtr $ \inPtr' -> 37 | withForeignPtr outPtr $ \outPtr' -> 38 | col2im_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 39 | 40 | let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels) 41 | return $ U.matrixFromVector U.RowMajor (height * channels) width matVec 42 | 43 | foreign import ccall unsafe 44 | col2im_cpu 45 | :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 46 | 47 | vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 48 | vid2col kernelRows kernelColumns strideRows strideColumns height width dataVid = 49 | let channels = rows dataVid `div` height 50 | in im2col_c channels height width kernelRows kernelColumns strideRows strideColumns dataVid 51 | 52 | 53 | im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 54 | im2col kernelRows kernelColumns strideRows strideColumns dataIm = 55 | let channels = 1 56 | height = rows dataIm 57 | width = cols dataIm 58 | in im2col_c channels height width kernelRows kernelColumns strideRows strideColumns dataIm 59 | 60 | im2col_c :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 61 | im2col_c channels height width kernelRows kernelColumns strideRows strideColumns dataIm = 62 | let vec = flatten dataIm 63 | rowOut = (height - kernelRows) `div` strideRows + 1 64 | colOut = (width - kernelColumns) `div` strideColumns + 1 65 | kernelSize = kernelRows * kernelColumns 66 | numberOfPatches = rowOut * colOut 67 | in unsafePerformIO $ do 68 | outPtr <- mallocForeignPtrArray (numberOfPatches * kernelSize * channels) 69 | let (inPtr, _) = U.unsafeToForeignPtr0 vec 70 | 71 | withForeignPtr inPtr $ \inPtr' -> 72 | withForeignPtr outPtr $ \outPtr' -> 73 | im2col_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 74 | 75 | let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * kernelSize * channels) 76 | return $ U.matrixFromVector U.RowMajor numberOfPatches (kernelSize * channels) matVec 77 | 78 | foreign import ccall unsafe 79 | im2col_cpu 80 | :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 81 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Internal/Pad.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ForeignFunctionInterface #-} 2 | module Grenade.Layers.Internal.Pad ( 3 | pad 4 | , crop 5 | ) where 6 | 7 | import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 ) 8 | 9 | import Foreign ( mallocForeignPtrArray, withForeignPtr ) 10 | import Foreign.Ptr ( Ptr ) 11 | 12 | import Numeric.LinearAlgebra ( flatten, Matrix ) 13 | import qualified Numeric.LinearAlgebra.Devel as U 14 | 15 | import System.IO.Unsafe ( unsafePerformIO ) 16 | 17 | pad :: Int -> Int -> Int -> Int -> Int 
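As a point of reference for the FFI wrappers above, this is roughly what `im2col` computes for a single channel, written with plain hmatrix operations: each kernel-sized window, visited in stride steps in reading order, becomes one row of the result. It is a slow sketch for intuition (and for property-testing against the C routine), not the code Grenade runs, and the element order within each row is this sketch's own convention.

```haskell
import Numeric.LinearAlgebra

-- Single-channel reference: every row of the result is one flattened
-- kernelRows x kernelColumns patch, with window origins stepped by the strides.
im2colRef :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2colRef kernelRows kernelColumns strideRows strideColumns im =
  let starts = [ (r, c) | r <- [0, strideRows    .. rows im - kernelRows]
                        , c <- [0, strideColumns .. cols im - kernelColumns] ]
  in fromRows [ flatten (subMatrix (r, c) (kernelRows, kernelColumns) im)
              | (r, c) <- starts ]
```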
-> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 18 | pad channels padLeft padTop padRight padBottom rows cols rows' cols' m 19 | = let outMatSize = rows' * cols' * channels 20 | vec = flatten m 21 | in unsafePerformIO $ do 22 | outPtr <- mallocForeignPtrArray outMatSize 23 | let (inPtr, _) = U.unsafeToForeignPtr0 vec 24 | 25 | withForeignPtr inPtr $ \inPtr' -> 26 | withForeignPtr outPtr $ \outPtr' -> 27 | pad_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr' 28 | 29 | let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize 30 | return (U.matrixFromVector U.RowMajor (rows' * channels) cols' matVec) 31 | {-# INLINE pad #-} 32 | 33 | foreign import ccall unsafe 34 | pad_cpu 35 | :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 36 | 37 | crop :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 38 | crop channels padLeft padTop padRight padBottom rows cols _ _ m 39 | = let outMatSize = rows * cols * channels 40 | vec = flatten m 41 | in unsafePerformIO $ do 42 | outPtr <- mallocForeignPtrArray outMatSize 43 | let (inPtr, _) = U.unsafeToForeignPtr0 vec 44 | 45 | withForeignPtr inPtr $ \inPtr' -> 46 | withForeignPtr outPtr $ \outPtr' -> 47 | crop_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr' 48 | 49 | let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize 50 | return (U.matrixFromVector U.RowMajor (rows * channels) cols matVec) 51 | 52 | foreign import ccall unsafe 53 | crop_cpu 54 | :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 55 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Internal/Pooling.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ForeignFunctionInterface #-} 2 | module Grenade.Layers.Internal.Pooling ( 3 | poolForward 4 | , poolBackward 5 | ) where 6 | 7 | import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 ) 8 | 9 | import Foreign ( mallocForeignPtrArray, withForeignPtr ) 10 | import Foreign.Ptr ( Ptr ) 11 | 12 | import Numeric.LinearAlgebra ( Matrix , flatten ) 13 | import qualified Numeric.LinearAlgebra.Devel as U 14 | 15 | import System.IO.Unsafe ( unsafePerformIO ) 16 | 17 | poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 18 | poolForward channels height width kernelRows kernelColumns strideRows strideColumns dataIm = 19 | let vec = flatten dataIm 20 | rowOut = (height - kernelRows) `div` strideRows + 1 21 | colOut = (width - kernelColumns) `div` strideColumns + 1 22 | numberOfPatches = rowOut * colOut 23 | in unsafePerformIO $ do 24 | outPtr <- mallocForeignPtrArray (numberOfPatches * channels) 25 | let (inPtr, _) = U.unsafeToForeignPtr0 vec 26 | 27 | withForeignPtr inPtr $ \inPtr' -> 28 | withForeignPtr outPtr $ \outPtr' -> 29 | pool_forwards_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 30 | 31 | let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * channels) 32 | return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec 33 | 34 | foreign import ccall unsafe 35 | pool_forwards_cpu 36 | :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 37 | 38 | poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double 39 | poolBackward channels height width kernelRows 
kernelColumns strideRows strideColumns dataIm dataGrad = 40 | let vecIm = flatten dataIm 41 | vecGrad = flatten dataGrad 42 | in unsafePerformIO $ do 43 | outPtr <- mallocForeignPtrArray (height * width * channels) 44 | let (imPtr, _) = U.unsafeToForeignPtr0 vecIm 45 | let (gradPtr, _) = U.unsafeToForeignPtr0 vecGrad 46 | 47 | withForeignPtr imPtr $ \imPtr' -> 48 | withForeignPtr gradPtr $ \gradPtr' -> 49 | withForeignPtr outPtr $ \outPtr' -> 50 | pool_backwards_cpu imPtr' gradPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 51 | 52 | let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels) 53 | return $ U.matrixFromVector U.RowMajor (height * channels) width matVec 54 | 55 | foreign import ccall unsafe 56 | pool_backwards_cpu 57 | :: Ptr Double -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 58 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Internal/Update.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ForeignFunctionInterface #-} 2 | module Grenade.Layers.Internal.Update ( 3 | descendMatrix 4 | , descendVector 5 | ) where 6 | 7 | import Data.Maybe ( fromJust ) 8 | import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 ) 9 | 10 | import Foreign ( mallocForeignPtrArray, withForeignPtr ) 11 | import Foreign.Ptr ( Ptr ) 12 | import GHC.TypeLits 13 | 14 | import Numeric.LinearAlgebra ( Vector, flatten ) 15 | import Numeric.LinearAlgebra.Static 16 | import qualified Numeric.LinearAlgebra.Devel as U 17 | 18 | import System.IO.Unsafe ( unsafePerformIO ) 19 | 20 | descendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns) 21 | descendMatrix rate momentum regulariser weights gradient lastUpdate = 22 | let (rows, cols) = size weights 23 | len = rows * cols 24 | -- Most gradients come in in ColumnMajor, 25 | -- so we'll transpose here before flattening them 26 | -- into a vector to prevent a copy. 27 | -- 28 | -- This gives ~15% speed improvement for LSTMs. 29 | weights' = flatten . tr . extract $ weights 30 | gradient' = flatten . tr . extract $ gradient 31 | lastUpdate' = flatten . tr . extract $ lastUpdate 32 | (vw, vm) = descendUnsafe len rate momentum regulariser weights' gradient' lastUpdate' 33 | 34 | -- Note that it's ColumnMajor, as we did a transpose before 35 | -- using the internal vectors. 36 | mw = U.matrixFromVector U.ColumnMajor rows cols vw 37 | mm = U.matrixFromVector U.ColumnMajor rows cols vm 38 | in (fromJust . create $ mw, fromJust . 
create $ mm) 39 | 40 | descendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r) 41 | descendVector rate momentum regulariser weights gradient lastUpdate = 42 | let len = size weights 43 | weights' = extract weights 44 | gradient' = extract gradient 45 | lastUpdate' = extract lastUpdate 46 | (vw, vm) = descendUnsafe len rate momentum regulariser weights' gradient' lastUpdate' 47 | in (fromJust $ create vw, fromJust $ create vm) 48 | 49 | descendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double) 50 | descendUnsafe len rate momentum regulariser weights gradient lastUpdate = 51 | unsafePerformIO $ do 52 | outWPtr <- mallocForeignPtrArray len 53 | outMPtr <- mallocForeignPtrArray len 54 | let (wPtr, _) = U.unsafeToForeignPtr0 weights 55 | let (gPtr, _) = U.unsafeToForeignPtr0 gradient 56 | let (lPtr, _) = U.unsafeToForeignPtr0 lastUpdate 57 | 58 | withForeignPtr wPtr $ \wPtr' -> 59 | withForeignPtr gPtr $ \gPtr' -> 60 | withForeignPtr lPtr $ \lPtr' -> 61 | withForeignPtr outWPtr $ \outWPtr' -> 62 | withForeignPtr outMPtr $ \outMPtr' -> 63 | descend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr' 64 | 65 | return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len) 66 | 67 | foreign import ccall unsafe 68 | descend_cpu 69 | :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO () 70 | 71 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Logit.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE FlexibleInstances #-} 5 | {-# LANGUAGE MultiParamTypeClasses #-} 6 | {-| 7 | Module : Grenade.Layers.Logit 8 | Description : Sigmoid nonlinear layer 9 | Copyright : (c) Huw Campbell, 2016-2017 10 | License : BSD2 11 | Stability : experimental 12 | -} 13 | module Grenade.Layers.Logit ( 14 | Logit (..) 15 | ) where 16 | 17 | 18 | import Data.Serialize 19 | import Data.Singletons 20 | 21 | import Grenade.Core 22 | 23 | -- | A Logit layer. 24 | -- 25 | -- A layer which can act between any shape of the same dimension, perfoming an sigmoid function. 26 | -- This layer should be used as the output layer of a network for logistic regression (classification) 27 | -- problems. 28 | data Logit = Logit 29 | deriving Show 30 | 31 | instance UpdateLayer Logit where 32 | type Gradient Logit = () 33 | runUpdate _ _ _ = Logit 34 | createRandom = return Logit 35 | 36 | instance (a ~ b, SingI a) => Layer Logit a b where 37 | -- Wengert tape optimisation: 38 | -- 39 | -- Derivative of the sigmoid function is 40 | -- d σ(x) / dx = σ(x) • (1 - σ(x)) 41 | -- but we have already calculated σ(x) in 42 | -- the forward pass, so just store that 43 | -- and use it in the backwards pass. 
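That comment is the whole trick: because σ(x) is already computed on the way forward, the backward pass never needs x itself, only the stored activation. The same idea on plain `Double`s, with hypothetical names rather than the library's shape-indexed types:

```haskell
sigmoid :: Double -> Double
sigmoid x = 1 / (1 + exp (-x))

-- Forward: return the activation twice, once as the output and once as the tape.
logitForward :: Double -> (Double, Double)
logitForward x = let s = sigmoid x in (s, s)

-- Backward: the tape s is enough, because d sigma / dx = sigma(x) * (1 - sigma(x)).
logitBackward :: Double -> Double -> Double
logitBackward s dEdy = s * (1 - s) * dEdy
```

The `Logit` instance below does exactly this, with shapes `S a` in place of `Double`.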
44 | type Tape Logit a b = S a 45 | runForwards _ a = 46 | let l = sigmoid a 47 | in (l, l) 48 | runBackwards _ l g = 49 | let sigmoid' = l * (1 - l) 50 | in ((), sigmoid' * g) 51 | 52 | instance Serialize Logit where 53 | put _ = return () 54 | get = return Logit 55 | 56 | sigmoid :: Floating a => a -> a 57 | sigmoid x = 1 / (1 + exp (-x)) 58 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Merge.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE TypeOperators #-} 5 | {-# LANGUAGE TypeFamilies #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE FlexibleInstances #-} 10 | {-# LANGUAGE ScopedTypeVariables #-} 11 | {-# LANGUAGE StandaloneDeriving #-} 12 | {-| 13 | Module : Grenade.Core.Network 14 | Description : Merging layer for parallel network composition 15 | Copyright : (c) Huw Campbell, 2016-2017 16 | License : BSD2 17 | Stability : experimental 18 | -} 19 | module Grenade.Layers.Merge ( 20 | Merge (..) 21 | ) where 22 | 23 | import Data.Serialize 24 | 25 | import Data.Singletons 26 | 27 | #if MIN_VERSION_base(4,9,0) 28 | import Data.Kind (Type) 29 | #endif 30 | 31 | import Grenade.Core 32 | 33 | -- | A Merging layer. 34 | -- 35 | -- Similar to Concat layer, except sums the activations instead of creating a larger 36 | -- shape. 37 | data Merge :: Type -> Type -> Type where 38 | Merge :: x -> y -> Merge x y 39 | 40 | instance (Show x, Show y) => Show (Merge x y) where 41 | show (Merge x y) = "Merge\n" ++ show x ++ "\n" ++ show y 42 | 43 | -- | Run two layers in parallel, combining their outputs. 44 | -- This just kind of "smooshes" the weights together. 
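Concretely, that makes a residual-style block nothing more than `Merge Trivial branch`, which is exactly the `Resnet` alias from the Inception module earlier. A sketch with made-up sizes (not type-checked here):

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- Runs the identity (Trivial) and the inner network side by side on the same
-- input and sums the two outputs, so the block computes  x + f x.
type ResBlock
  = Merge Trivial
          (Network '[ FullyConnected 64 64, Relu ]
                   '[ 'D1 64, 'D1 64, 'D1 64 ])
```

Both branches must accept and produce the same shape, which the `Layer (Merge x y) i o` instance below enforces.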
45 | instance (UpdateLayer x, UpdateLayer y) => UpdateLayer (Merge x y) where 46 | type Gradient (Merge x y) = (Gradient x, Gradient y) 47 | runUpdate lr (Merge x y) (x', y') = Merge (runUpdate lr x x') (runUpdate lr y y') 48 | createRandom = Merge <$> createRandom <*> createRandom 49 | 50 | -- | Combine the outputs and the inputs, summing the output shape 51 | instance (SingI i, SingI o, Layer x i o, Layer y i o) => Layer (Merge x y) i o where 52 | type Tape (Merge x y) i o = (Tape x i o, Tape y i o) 53 | 54 | runForwards (Merge x y) input = 55 | let (xT, xOut) = runForwards x input 56 | (yT, yOut) = runForwards y input 57 | in ((xT, yT), xOut + yOut) 58 | 59 | runBackwards (Merge x y) (xTape, yTape) o = 60 | let (x', xB) = runBackwards x xTape o 61 | (y', yB) = runBackwards y yTape o 62 | in ((x', y'), xB + yB) 63 | 64 | instance (Serialize a, Serialize b) => Serialize (Merge a b) where 65 | put (Merge a b) = put a *> put b 66 | get = Merge <$> get <*> get 67 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Pad.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE TypeFamilies #-} 7 | {-# LANGUAGE MultiParamTypeClasses #-} 8 | {-# LANGUAGE FlexibleContexts #-} 9 | {-# LANGUAGE UndecidableInstances #-} 10 | {-| 11 | Module : Grenade.Core.Pad 12 | Description : Padding layer for 2D and 3D images 13 | Copyright : (c) Huw Campbell, 2016-2017 14 | License : BSD2 15 | Stability : experimental 16 | -} 17 | module Grenade.Layers.Pad ( 18 | Pad (..) 19 | ) where 20 | 21 | import Data.Maybe 22 | import Data.Proxy 23 | import Data.Serialize 24 | 25 | import GHC.TypeLits 26 | import Data.Kind (Type) 27 | 28 | import Grenade.Core 29 | import Grenade.Layers.Internal.Pad 30 | 31 | import Numeric.LinearAlgebra (konst, subMatrix, diagBlock) 32 | import Numeric.LinearAlgebra.Static (extract, create) 33 | 34 | -- | A padding layer for a neural network. 35 | -- 36 | -- Pads on the X and Y dimension of an image. 37 | data Pad :: Nat 38 | -> Nat 39 | -> Nat 40 | -> Nat -> Type where 41 | Pad :: Pad padLeft padTop padRight padBottom 42 | 43 | instance Show (Pad padLeft padTop padRight padBottom) where 44 | show Pad = "Pad" 45 | 46 | instance UpdateLayer (Pad l t r b) where 47 | type Gradient (Pad l t r b) = () 48 | runUpdate _ x _ = x 49 | createRandom = return Pad 50 | 51 | instance Serialize (Pad l t r b) where 52 | put _ = return () 53 | get = return Pad 54 | 55 | -- | A two dimentional image can be padded. 
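Before the two-dimensional instance, a quick sketch of how the padding arithmetic surfaces in a network type. The sizes are hypothetical and the alias has not been type-checked here; the point is that the output shape must equal the input shape plus the four pad amounts, which is the equality the instance below demands.

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- Pad 1 1 1 1 turns a 28x28 image into a 30x30 one (28 + 1 + 1 in each
-- dimension); a 3x3 convolution with stride 1 then brings it back to 28x28.
-- This is the "same size" padding trick used by the Inception towers above.
type SamePad3x3
  = Network '[ Pad 1 1 1 1, Convolution 1 8 3 3 1 1 ]
            '[ 'D2 28 28, 'D2 30 30, 'D3 28 28 8 ]
```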
56 | instance ( KnownNat padLeft 57 | , KnownNat padTop 58 | , KnownNat padRight 59 | , KnownNat padBottom 60 | , KnownNat inputRows 61 | , KnownNat inputColumns 62 | , KnownNat outputRows 63 | , KnownNat outputColumns 64 | , (inputRows + padTop + padBottom) ~ outputRows 65 | , (inputColumns + padLeft + padRight) ~ outputColumns 66 | ) => Layer (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where 67 | type Tape (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) = () 68 | runForwards Pad (S2D input) = 69 | let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) 70 | padt = fromIntegral $ natVal (Proxy :: Proxy padTop) 71 | padr = fromIntegral $ natVal (Proxy :: Proxy padRight) 72 | padb = fromIntegral $ natVal (Proxy :: Proxy padBottom) 73 | m = extract input 74 | r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)] 75 | in ((), S2D . fromJust . create $ r) 76 | runBackwards Pad _ (S2D dEdy) = 77 | let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) 78 | padt = fromIntegral $ natVal (Proxy :: Proxy padTop) 79 | nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows) 80 | ncols = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 81 | m = extract dEdy 82 | vs = subMatrix (padt, padl) (nrows, ncols) m 83 | in ((), S2D . fromJust . create $ vs) 84 | 85 | -- | A two dimentional image can be padded. 86 | instance ( KnownNat padLeft 87 | , KnownNat padTop 88 | , KnownNat padRight 89 | , KnownNat padBottom 90 | , KnownNat inputRows 91 | , KnownNat inputColumns 92 | , KnownNat outputRows 93 | , KnownNat outputColumns 94 | , KnownNat channels 95 | , KnownNat (inputRows * channels) 96 | , KnownNat (outputRows * channels) 97 | , (inputRows + padTop + padBottom) ~ outputRows 98 | , (inputColumns + padLeft + padRight) ~ outputColumns 99 | ) => Layer (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where 100 | type Tape (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = () 101 | runForwards Pad (S3D input) = 102 | let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) 103 | padt = fromIntegral $ natVal (Proxy :: Proxy padTop) 104 | padr = fromIntegral $ natVal (Proxy :: Proxy padRight) 105 | padb = fromIntegral $ natVal (Proxy :: Proxy padBottom) 106 | outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) 107 | outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) 108 | inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) 109 | inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 110 | ch = fromIntegral $ natVal (Proxy :: Proxy channels) 111 | m = extract input 112 | padded = pad ch padl padt padr padb inr inc outr outc m 113 | in ((), S3D . fromJust . 
create $ padded) 114 | 115 | runBackwards Pad () (S3D gradient) = 116 | let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) 117 | padt = fromIntegral $ natVal (Proxy :: Proxy padTop) 118 | padr = fromIntegral $ natVal (Proxy :: Proxy padRight) 119 | padb = fromIntegral $ natVal (Proxy :: Proxy padBottom) 120 | outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) 121 | outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) 122 | inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) 123 | inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 124 | ch = fromIntegral $ natVal (Proxy :: Proxy channels) 125 | m = extract gradient 126 | cropped = crop ch padl padt padr padb inr inc outr outc m 127 | in ((), S3D . fromJust . create $ cropped) 128 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Pooling.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE StandaloneDeriving #-} 5 | {-# LANGUAGE GADTs #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE MultiParamTypeClasses #-} 9 | {-# LANGUAGE FlexibleContexts #-} 10 | {-# LANGUAGE UndecidableInstances #-} 11 | {-| 12 | Module : Grenade.Core.Pooling 13 | Description : Max Pooling layer for 2D and 3D images 14 | Copyright : (c) Huw Campbell, 2016-2017 15 | License : BSD2 16 | Stability : experimental 17 | -} 18 | module Grenade.Layers.Pooling ( 19 | Pooling (..) 20 | ) where 21 | 22 | import Data.Maybe 23 | import Data.Proxy 24 | import Data.Serialize 25 | 26 | import GHC.TypeLits 27 | import Data.Kind (Type) 28 | 29 | import Grenade.Core 30 | import Grenade.Layers.Internal.Pooling 31 | 32 | import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows) 33 | 34 | -- | A pooling layer for a neural network. 35 | -- 36 | -- Does a max pooling, looking over a kernel similarly to the convolution network, but returning 37 | -- maxarg only. This layer is often used to provide minor amounts of translational invariance. 38 | -- 39 | -- The kernel size dictates which input and output sizes will "fit". Fitting the equation: 40 | -- `out = (in - kernel) / stride + 1` for both dimensions. 41 | -- 42 | data Pooling :: Nat -> Nat -> Nat -> Nat -> Type where 43 | Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns 44 | 45 | instance Show (Pooling k k' s s') where 46 | show Pooling = "Pooling" 47 | 48 | instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where 49 | type Gradient (Pooling kernelRows kernelColumns strideRows strideColumns) = () 50 | runUpdate _ Pooling _ = Pooling 51 | createRandom = return Pooling 52 | 53 | instance Serialize (Pooling kernelRows kernelColumns strideRows strideColumns) where 54 | put _ = return () 55 | get = return Pooling 56 | 57 | -- | A two dimentional image can be pooled. 
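Before the pooling instances, the fitting equation is easiest to see with concrete numbers. A hypothetical halving layer, using illustrative sizes that have not been type-checked here:

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- out = (in - kernel) / stride + 1, so a 2x2 kernel with stride 2 maps
-- 28 -> (28 - 2) / 2 + 1 = 14 in each dimension.
type Halve = Network '[ Pooling 2 2 2 2 ] '[ 'D2 28 28, 'D2 14 14 ]
```

The instance that follows states the same relationship as the equality `((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)`, which rules out kernel and stride combinations that do not tile the input.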
58 | instance ( KnownNat kernelRows 59 | , KnownNat kernelColumns 60 | , KnownNat strideRows 61 | , KnownNat strideColumns 62 | , KnownNat inputRows 63 | , KnownNat inputColumns 64 | , KnownNat outputRows 65 | , KnownNat outputColumns 66 | , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows) 67 | , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns) 68 | ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where 69 | type Tape (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) = S ('D2 inputRows inputColumns) 70 | runForwards Pooling (S2D input) = 71 | let height = fromIntegral $ natVal (Proxy :: Proxy inputRows) 72 | width = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 73 | kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) 74 | ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns) 75 | sx = fromIntegral $ natVal (Proxy :: Proxy strideRows) 76 | sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns) 77 | ex = extract input 78 | r = poolForward 1 height width kx ky sx sy ex 79 | rs = fromJust . create $ r 80 | in (S2D input, S2D rs) 81 | runBackwards Pooling (S2D input) (S2D dEdy) = 82 | let height = fromIntegral $ natVal (Proxy :: Proxy inputRows) 83 | width = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 84 | kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) 85 | ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns) 86 | sx = fromIntegral $ natVal (Proxy :: Proxy strideRows) 87 | sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns) 88 | ex = extract input 89 | eo = extract dEdy 90 | vs = poolBackward 1 height width kx ky sx sy ex eo 91 | in ((), S2D . fromJust . create $ vs) 92 | 93 | 94 | -- | A three dimensional image can be pooled on each layer. 95 | instance ( KnownNat kernelRows 96 | , KnownNat kernelColumns 97 | , KnownNat strideRows 98 | , KnownNat strideColumns 99 | , KnownNat inputRows 100 | , KnownNat inputColumns 101 | , KnownNat outputRows 102 | , KnownNat outputColumns 103 | , KnownNat channels 104 | , KnownNat (outputRows * channels) 105 | , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows) 106 | , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns) 107 | ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where 108 | type Tape (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = S ('D3 inputRows inputColumns channels) 109 | runForwards Pooling (S3D input) = 110 | let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows) 111 | iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 112 | kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) 113 | ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns) 114 | sx = fromIntegral $ natVal (Proxy :: Proxy strideRows) 115 | sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns) 116 | ch = fromIntegral $ natVal (Proxy :: Proxy channels) 117 | ex = extract input 118 | r = poolForward ch ix iy kx ky sx sy ex 119 | rs = fromJust . 
create $ r 120 | in (S3D input, S3D rs) 121 | runBackwards Pooling (S3D input) (S3D dEdy) = 122 | let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows) 123 | iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 124 | kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) 125 | ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns) 126 | sx = fromIntegral $ natVal (Proxy :: Proxy strideRows) 127 | sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns) 128 | ch = fromIntegral $ natVal (Proxy :: Proxy channels) 129 | ex = extract input 130 | eo = extract dEdy 131 | vs = poolBackward ch ix iy kx ky sx sy ex eo 132 | in ((), S3D . fromJust . create $ vs) 133 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Relu.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-| 6 | Module : Grenade.Layers.Relu 7 | Description : Rectifying linear unit layer 8 | Copyright : (c) Huw Campbell, 2016-2017 9 | License : BSD2 10 | Stability : experimental 11 | -} 12 | module Grenade.Layers.Relu ( 13 | Relu (..) 14 | ) where 15 | 16 | import Data.Serialize 17 | 18 | import GHC.TypeLits 19 | import Grenade.Core 20 | 21 | import qualified Numeric.LinearAlgebra.Static as LAS 22 | 23 | -- | A rectifying linear unit. 24 | -- A layer which can act between any shape of the same dimension, acting as a 25 | -- diode on every neuron individually. 26 | data Relu = Relu 27 | deriving Show 28 | 29 | instance UpdateLayer Relu where 30 | type Gradient Relu = () 31 | runUpdate _ _ _ = Relu 32 | createRandom = return Relu 33 | 34 | instance Serialize Relu where 35 | put _ = return () 36 | get = return Relu 37 | 38 | instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where 39 | type Tape Relu ('D1 i) ('D1 i) = S ('D1 i) 40 | 41 | runForwards _ (S1D y) = (S1D y, S1D (relu y)) 42 | where 43 | relu = LAS.dvmap (\a -> if a <= 0 then 0 else a) 44 | runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (relu' y * dEdy)) 45 | where 46 | relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1) 47 | 48 | instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where 49 | type Tape Relu ('D2 i j) ('D2 i j) = S ('D2 i j) 50 | 51 | runForwards _ (S2D y) = (S2D y, S2D (relu y)) 52 | where 53 | relu = LAS.dmmap (\a -> if a <= 0 then 0 else a) 54 | runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (relu' y * dEdy)) 55 | where 56 | relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1) 57 | 58 | instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j k) where 59 | 60 | type Tape Relu ('D3 i j k) ('D3 i j k) = S ('D3 i j k) 61 | 62 | runForwards _ (S3D y) = (S3D y, S3D (relu y)) 63 | where 64 | relu = LAS.dmmap (\a -> if a <= 0 then 0 else a) 65 | runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (relu' y * dEdy)) 66 | where 67 | relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1) 68 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Reshape.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-# LANGUAGE FlexibleContexts #-} 6 | {-# LANGUAGE UndecidableInstances #-} 7 | {-| 8 | Module : Grenade.Layers.Reshape 9 | Description : Multipurpose reshaping layer 10 | Copyright : 
(c) Huw Campbell, 2016-2017 11 | License : BSD2 12 | Stability : experimental 13 | -} 14 | module Grenade.Layers.Reshape ( 15 | Reshape (..) 16 | ) where 17 | 18 | import Data.Serialize 19 | 20 | import GHC.TypeLits 21 | 22 | import Numeric.LinearAlgebra.Static 23 | import Numeric.LinearAlgebra.Data as LA ( flatten ) 24 | 25 | import Grenade.Core 26 | 27 | -- | Reshape Layer 28 | -- 29 | -- The Reshape layer can flatten any 2D or 3D image to 1D vector with the 30 | -- same number of activations, as well as cast up from 1D to a 2D or 3D 31 | -- shape. 32 | -- 33 | -- Can also be used to turn a 3D image with only one channel into a 2D image 34 | -- or vice versa. 35 | data Reshape = Reshape 36 | deriving Show 37 | 38 | instance UpdateLayer Reshape where 39 | type Gradient Reshape = () 40 | runUpdate _ _ _ = Reshape 41 | createRandom = return Reshape 42 | 43 | instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer Reshape ('D2 x y) ('D1 a) where 44 | type Tape Reshape ('D2 x y) ('D1 a) = () 45 | runForwards _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y) 46 | runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) 47 | 48 | instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer Reshape ('D3 x y z) ('D1 a) where 49 | type Tape Reshape ('D3 x y z) ('D1 a) = () 50 | runForwards _ (S3D y) = ((), fromJust' . fromStorable . flatten . extract $ y) 51 | runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) 52 | 53 | instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer Reshape ('D3 x y z) ('D2 x y) where 54 | type Tape Reshape ('D3 x y z) ('D2 x y) = () 55 | runForwards _ (S3D y) = ((), S2D y) 56 | runBackwards _ _ (S2D y) = ((), S3D y) 57 | 58 | instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer Reshape ('D2 x y) ('D3 x y z) where 59 | type Tape Reshape ('D2 x y) ('D3 x y z) = () 60 | runForwards _ (S2D y) = ((), S3D y) 61 | runBackwards _ _ (S3D y) = ((), S2D y) 62 | 63 | instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer Reshape ('D1 a) ('D2 x y) where 64 | type Tape Reshape ('D1 a) ('D2 x y) = () 65 | runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) 66 | runBackwards _ _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y) 67 | 68 | instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer Reshape ('D1 a) ('D3 x y z) where 69 | type Tape Reshape ('D1 a) ('D3 x y z) = () 70 | runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) 71 | runBackwards _ _ (S3D y) = ((), fromJust' . fromStorable . flatten . extract $ y) 72 | 73 | instance Serialize Reshape where 74 | put _ = return () 75 | get = return Reshape 76 | 77 | 78 | fromJust' :: Maybe x -> x 79 | fromJust' (Just x) = x 80 | fromJust' Nothing = error $ "Reshape error: data shape couldn't be converted." 
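A short sketch of the typical use of `Reshape`: flattening a convolutional volume so a fully connected head can follow. The sizes are made up and the alias has not been type-checked here; the only requirement the instances above encode is that the activation count is preserved.

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- 4 * 4 * 16 = 256 activations before and after the reshape; only the
-- shape index changes.
type FlattenHead
  = Network '[ Reshape, FullyConnected 256 10, Logit ]
            '[ 'D3 4 4 16, 'D1 256, 'D1 10, 'D1 10 ]
```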
81 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Sinusoid.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# LANGUAGE TypeFamilies #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-| 7 | Module : Grenade.Layers.Sinusoid 8 | Description : Sinusoid nonlinear layer 9 | Copyright : (c) Manuel Schneckenreither, 2018 10 | License : BSD2 11 | Stability : experimental 12 | -} 13 | module Grenade.Layers.Sinusoid ( 14 | Sinusoid (..) 15 | ) where 16 | 17 | import Data.Serialize 18 | import Data.Singletons 19 | 20 | import Grenade.Core 21 | 22 | -- | A Sinusoid layer. 23 | -- A layer which can act between any shape of the same dimension, performing a sin function. 24 | data Sinusoid = Sinusoid 25 | deriving Show 26 | 27 | instance UpdateLayer Sinusoid where 28 | type Gradient Sinusoid = () 29 | runUpdate _ _ _ = Sinusoid 30 | createRandom = return Sinusoid 31 | 32 | instance Serialize Sinusoid where 33 | put _ = return () 34 | get = return Sinusoid 35 | 36 | instance (a ~ b, SingI a) => Layer Sinusoid a b where 37 | type Tape Sinusoid a b = S a 38 | runForwards _ a = (a, sin a) 39 | runBackwards _ a g = ((), cos a * g) 40 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Softmax.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE FlexibleInstances #-} 5 | {-# LANGUAGE MultiParamTypeClasses #-} 6 | {-| 7 | Module : Grenade.Core.Softmax 8 | Description : Softmax loss layer 9 | Copyright : (c) Huw Campbell, 2016-2017 10 | License : BSD2 11 | Stability : experimental 12 | -} 13 | module Grenade.Layers.Softmax ( 14 | Softmax (..) 15 | , softmax 16 | , softmax' 17 | ) where 18 | 19 | import Data.Serialize 20 | 21 | import GHC.TypeLits 22 | import Grenade.Core 23 | 24 | import Numeric.LinearAlgebra.Static as LAS 25 | 26 | -- | A Softmax layer 27 | -- 28 | -- This layer is like a logit layer, but normalises 29 | -- a set of matricies to be probabilities. 30 | -- 31 | -- One can use this layer as the last layer in a network 32 | -- if they need normalised probabilities. 
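What "normalises to probabilities" means is easiest to see on plain lists. The helper below is a throwaway sketch with a made-up name, not the statically sized version defined underneath. (A common refinement, not used here, is to subtract the maximum entry before exponentiating to avoid overflow on large inputs.)

```haskell
-- Exponentiate, then divide by the total so the entries sum to one.
softmaxList :: [Double] -> [Double]
softmaxList xs =
  let es    = map exp xs
      total = sum es
  in map (/ total) es

-- >>> softmaxList [1, 2, 3]   -- roughly [0.09, 0.24, 0.67]
```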
33 | data Softmax = Softmax 34 | deriving Show 35 | 36 | instance UpdateLayer Softmax where 37 | type Gradient Softmax = () 38 | runUpdate _ _ _ = Softmax 39 | createRandom = return Softmax 40 | 41 | instance ( KnownNat i ) => Layer Softmax ('D1 i) ('D1 i) where 42 | type Tape Softmax ('D1 i) ('D1 i) = S ('D1 i) 43 | 44 | runForwards _ (S1D y) = (S1D y, S1D (softmax y)) 45 | runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (softmax' y dEdy)) 46 | 47 | instance Serialize Softmax where 48 | put _ = return () 49 | get = return Softmax 50 | 51 | softmax :: KnownNat i => LAS.R i -> LAS.R i 52 | softmax xs = 53 | let xs' = LAS.dvmap exp xs 54 | s = LAS.dot xs' 1 55 | in LAS.dvmap (/ s) xs' 56 | 57 | softmax' :: KnownNat i => LAS.R i -> LAS.R i -> LAS.R i 58 | softmax' x grad = 59 | let yTy = outer sm sm 60 | d = diag sm 61 | g = d - yTy 62 | in g #> grad 63 | where 64 | sm = softmax x 65 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Tanh.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# LANGUAGE TypeFamilies #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-| 7 | Module : Grenade.Layers.Tanh 8 | Description : Hyperbolic tangent nonlinear layer 9 | Copyright : (c) Huw Campbell, 2016-2017 10 | License : BSD2 11 | Stability : experimental 12 | -} 13 | module Grenade.Layers.Tanh ( 14 | Tanh (..) 15 | ) where 16 | 17 | import Data.Serialize 18 | import Data.Singletons 19 | 20 | import Grenade.Core 21 | 22 | -- | A Tanh layer. 23 | -- A layer which can act between any shape of the same dimension, performing a tanh function. 24 | data Tanh = Tanh 25 | deriving Show 26 | 27 | instance UpdateLayer Tanh where 28 | type Gradient Tanh = () 29 | runUpdate _ _ _ = Tanh 30 | createRandom = return Tanh 31 | 32 | instance Serialize Tanh where 33 | put _ = return () 34 | get = return Tanh 35 | 36 | instance (a ~ b, SingI a) => Layer Tanh a b where 37 | type Tape Tanh a b = S a 38 | runForwards _ a = (a, tanh a) 39 | runBackwards _ a g = ((), tanh' a * g) 40 | 41 | tanh' :: (Floating a) => a -> a 42 | tanh' t = 1 - s ^ (2 :: Int) where s = tanh t 43 | -------------------------------------------------------------------------------- /src/Grenade/Layers/Trivial.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-# LANGUAGE FlexibleInstances #-} 6 | {-| 7 | Module : Grenade.Core.Trivial 8 | Description : Trivial layer which perfoms no operations on the data 9 | Copyright : (c) Huw Campbell, 2016-2017 10 | License : BSD2 11 | Stability : experimental 12 | -} 13 | module Grenade.Layers.Trivial ( 14 | Trivial (..) 15 | ) where 16 | 17 | import Data.Serialize 18 | 19 | import Grenade.Core 20 | 21 | -- | A Trivial layer. 22 | -- 23 | -- This can be used to pass an unchanged value up one side of a 24 | -- graph, for a Residual network for example. 
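The nonlinear layers above (Logit, Sinusoid, Tanh) and the Softmax Jacobian all hard-code their derivatives. A cheap way to gain confidence in such a rule is a finite-difference comparison; the sketch below, with hypothetical helper names, checks the tanh rule. (The Trivial layer that follows needs no such check, since it is the identity in both directions.)

```haskell
-- Analytic derivative used by the Tanh layer's backward pass.
tanhGrad :: Double -> Double
tanhGrad t = 1 - tanh t ^ (2 :: Int)

-- Central finite difference, accurate to roughly h^2.
numericGrad :: (Double -> Double) -> Double -> Double
numericGrad f x = (f (x + h) - f (x - h)) / (2 * h)
  where h = 1e-6

-- >>> abs (tanhGrad 0.3 - numericGrad tanh 0.3) < 1e-8
-- True
```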
25 | data Trivial = Trivial 26 | deriving Show 27 | 28 | instance Serialize Trivial where 29 | put _ = return () 30 | get = return Trivial 31 | 32 | instance UpdateLayer Trivial where 33 | type Gradient Trivial = () 34 | runUpdate _ _ _ = Trivial 35 | createRandom = return Trivial 36 | 37 | instance (a ~ b) => Layer Trivial a b where 38 | type Tape Trivial a b = () 39 | runForwards _ a = ((), a) 40 | runBackwards _ _ y = ((), y) 41 | -------------------------------------------------------------------------------- /src/Grenade/Recurrent.hs: -------------------------------------------------------------------------------- 1 | module Grenade.Recurrent ( 2 | -- | This is an empty module which simply re-exports public definitions 3 | -- for recurrent networks in Grenade. 4 | 5 | -- * Exported modules 6 | -- 7 | -- | The core types and runners for Recurrent Networks. 8 | module Grenade.Recurrent.Core 9 | 10 | -- | The recurrent neural network layer zoo 11 | , module Grenade.Recurrent.Layers 12 | 13 | -- * Overview of recurrent Networks 14 | -- $recurrent 15 | 16 | ) where 17 | 18 | import Grenade.Recurrent.Core 19 | import Grenade.Recurrent.Layers 20 | 21 | {- $recurrent 22 | There are two ways in which deep learning libraries choose to represent 23 | recurrent Neural Networks, as an unrolled graph, or at a first class 24 | level. Grenade chooses the latter representation, and provides a network 25 | type which is specifically suited for recurrent neural networks. 26 | 27 | Currently grenade supports two layers, a basic recurrent layer, and an 28 | LSTM layer. 29 | -} 30 | -------------------------------------------------------------------------------- /src/Grenade/Recurrent/Core.hs: -------------------------------------------------------------------------------- 1 | module Grenade.Recurrent.Core ( 2 | module Grenade.Recurrent.Core.Layer 3 | , module Grenade.Recurrent.Core.Network 4 | , module Grenade.Recurrent.Core.Runner 5 | ) where 6 | 7 | import Grenade.Recurrent.Core.Layer 8 | import Grenade.Recurrent.Core.Network 9 | import Grenade.Recurrent.Core.Runner 10 | -------------------------------------------------------------------------------- /src/Grenade/Recurrent/Core/Layer.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-# LANGUAGE FlexibleContexts #-} 6 | {-# LANGUAGE FlexibleInstances #-} 7 | module Grenade.Recurrent.Core.Layer ( 8 | RecurrentLayer (..) 9 | , RecurrentUpdateLayer (..) 10 | ) where 11 | 12 | #if MIN_VERSION_base(4,9,0) 13 | import Data.Kind (Type) 14 | #endif 15 | 16 | import Grenade.Core 17 | 18 | -- | Class for a recurrent layer. 19 | -- It's quite similar to a normal layer but for the input and output 20 | -- of an extra recurrent data shape. 21 | class UpdateLayer x => RecurrentUpdateLayer x where 22 | -- | Shape of data that is passed between each subsequent run of the layer 23 | type RecurrentShape x :: Type 24 | 25 | class (RecurrentUpdateLayer x, Num (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where 26 | -- | Wengert Tape 27 | type RecTape x i o :: Type 28 | -- | Used in training and scoring. Take the input from the previous 29 | -- layer, and give the output from this layer. 30 | runRecurrentForwards :: x -> RecurrentShape x -> S i -> (RecTape x i o, RecurrentShape x, S o) 31 | -- | Back propagate a step. 
Takes the current layer, the input that the 32 | -- layer gave from the input and the back propagated derivatives from 33 | -- the layer above. 34 | -- Returns the gradient layer and the derivatives to push back further. 35 | runRecurrentBackwards :: x -> RecTape x i o -> RecurrentShape x -> S o -> (Gradient x, RecurrentShape x, S i) 36 | -------------------------------------------------------------------------------- /src/Grenade/Recurrent/Core/Runner.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE TypeFamilies #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | 10 | module Grenade.Recurrent.Core.Runner ( 11 | runRecurrentExamples 12 | , runRecurrentBackprop 13 | , backPropagateRecurrent 14 | , trainRecurrent 15 | 16 | , RecurrentGradients 17 | ) where 18 | 19 | import Data.List ( foldl' ) 20 | import Grenade.Core 21 | 22 | import Grenade.Recurrent.Core.Network 23 | import Prelude.Singletons 24 | 25 | type RecurrentGradients layers = [RecurrentGradient layers] 26 | 27 | runRecurrentExamples :: forall shapes layers. 28 | RecurrentNetwork layers shapes 29 | -> RecurrentInputs layers 30 | -> [S (Head shapes)] 31 | -> ([(RecurrentTape layers shapes, S (Last shapes))], RecurrentInputs layers) 32 | runRecurrentExamples net = 33 | go 34 | where 35 | go !side [] = ([], side) 36 | go !side (!x : xs) = 37 | let (!tape, !side', !o) = runRecurrent net side x 38 | (!res, !finalSide) = go side' xs 39 | in (( tape, o ) : res, finalSide) 40 | 41 | runRecurrentBackprop :: forall layers shapes. 42 | RecurrentNetwork layers shapes 43 | -> RecurrentInputs layers 44 | -> [(RecurrentTape layers shapes, S (Last shapes))] 45 | -> ([(RecurrentGradient layers, S (Head shapes))], RecurrentInputs layers) 46 | runRecurrentBackprop net = 47 | go 48 | where 49 | go !side [] = ([], side) 50 | go !side ((!tape,!x):xs) = 51 | let (res, !side') = go side xs 52 | (!grad, !finalSide, !o) = runRecurrent' net tape side' x 53 | in (( grad, o ) : res, finalSide) 54 | 55 | 56 | -- | Drive and network and collect its back propogated gradients. 57 | backPropagateRecurrent :: forall shapes layers. (SingI (Last shapes), Fractional (RecurrentInputs layers)) 58 | => RecurrentNetwork layers shapes 59 | -> RecurrentInputs layers 60 | -> [(S (Head shapes), Maybe (S (Last shapes)))] 61 | -> (RecurrentGradients layers, RecurrentInputs layers) 62 | backPropagateRecurrent network recinputs examples = 63 | let (outForwards, _) = runRecurrentExamples network recinputs inputs 64 | 65 | backPropagations = zipWith makeError outForwards targets 66 | 67 | (outBackwards, input') = runRecurrentBackprop network 0 backPropagations 68 | 69 | gradients = fmap fst outBackwards 70 | in (gradients, input') 71 | 72 | where 73 | 74 | inputs = fst <$> examples 75 | targets = snd <$> examples 76 | 77 | makeError :: (x, S (Last shapes)) -> Maybe (S (Last shapes)) -> (x, S (Last shapes)) 78 | makeError (x, _) Nothing = (x, 0) 79 | makeError (x, y) (Just t) = (x, y - t) 80 | 81 | 82 | trainRecurrent :: forall shapes layers. 
(SingI (Last shapes), Fractional (RecurrentInputs layers)) 83 | => LearningParameters 84 | -> RecurrentNetwork layers shapes 85 | -> RecurrentInputs layers 86 | -> [(S (Head shapes), Maybe (S (Last shapes)))] 87 | -> (RecurrentNetwork layers shapes, RecurrentInputs layers) 88 | trainRecurrent rate network recinputs examples = 89 | let (gradients, recinputs') = backPropagateRecurrent network recinputs examples 90 | 91 | newInputs = updateRecInputs rate recinputs recinputs' 92 | 93 | newNetwork = foldl' (applyRecurrentUpdate rate) network gradients 94 | 95 | in (newNetwork, newInputs) 96 | 97 | updateRecInputs :: Fractional (RecurrentInputs sublayers) 98 | => LearningParameters 99 | -> RecurrentInputs sublayers 100 | -> RecurrentInputs sublayers 101 | -> RecurrentInputs sublayers 102 | 103 | updateRecInputs lp (() :~~+> xs) (() :~~+> ys) 104 | = () :~~+> updateRecInputs lp xs ys 105 | 106 | updateRecInputs lp (x :~@+> xs) (y :~@+> ys) 107 | = (realToFrac (1 - learningRate lp * learningRegulariser lp) * x - realToFrac (learningRate lp) * y) :~@+> updateRecInputs lp xs ys 108 | 109 | updateRecInputs _ RINil RINil 110 | = RINil 111 | -------------------------------------------------------------------------------- /src/Grenade/Recurrent/Layers.hs: -------------------------------------------------------------------------------- 1 | module Grenade.Recurrent.Layers ( 2 | module Grenade.Recurrent.Layers.BasicRecurrent 3 | , module Grenade.Recurrent.Layers.ConcatRecurrent 4 | , module Grenade.Recurrent.Layers.LSTM 5 | ) where 6 | 7 | import Grenade.Recurrent.Layers.BasicRecurrent 8 | import Grenade.Recurrent.Layers.ConcatRecurrent 9 | import Grenade.Recurrent.Layers.LSTM 10 | -------------------------------------------------------------------------------- /src/Grenade/Recurrent/Layers/BasicRecurrent.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE TypeOperators #-} 5 | {-# LANGUAGE TypeFamilies #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE UndecidableInstances #-} 9 | 10 | module Grenade.Recurrent.Layers.BasicRecurrent ( 11 | BasicRecurrent (..) 
12 | , randomBasicRecurrent 13 | ) where 14 | 15 | 16 | 17 | import Control.Monad.Random ( MonadRandom, getRandom ) 18 | 19 | 20 | import Data.Kind (Type) 21 | 22 | import Numeric.LinearAlgebra.Static 23 | 24 | import GHC.TypeLits 25 | 26 | import Grenade.Core 27 | import Grenade.Recurrent.Core 28 | 29 | data BasicRecurrent :: Nat -- Input layer size 30 | -> Nat -- Output layer size 31 | -> Type where 32 | BasicRecurrent :: ( KnownNat input 33 | , KnownNat output 34 | , KnownNat matrixCols 35 | , matrixCols ~ (input + output)) 36 | => !(R output) -- Bias neuron weights 37 | -> !(R output) -- Bias neuron momentum 38 | -> !(L output matrixCols) -- Activation 39 | -> !(L output matrixCols) -- Momentum 40 | -> BasicRecurrent input output 41 | 42 | data BasicRecurrent' :: Nat -- Input layer size 43 | -> Nat -- Output layer size 44 | -> Type where 45 | BasicRecurrent' :: ( KnownNat input 46 | , KnownNat output 47 | , KnownNat matrixCols 48 | , matrixCols ~ (input + output)) 49 | => !(R output) -- Bias neuron gradients 50 | -> !(L output matrixCols) 51 | -> BasicRecurrent' input output 52 | 53 | instance Show (BasicRecurrent i o) where 54 | show BasicRecurrent {} = "BasicRecurrent" 55 | 56 | instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where 57 | type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o) 58 | 59 | runUpdate lp (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) = 60 | let newBiasMomentum = konst (learningMomentum lp) * oldBiasMomentum - konst (learningRate lp) * biasGradient 61 | newBias = oldBias + newBiasMomentum 62 | newMomentum = konst (learningMomentum lp) * oldMomentum - konst (learningRate lp) * activationGradient 63 | regulariser = konst (learningRegulariser lp * learningRate lp) * oldActivations 64 | newActivations = oldActivations + newMomentum - regulariser 65 | in BasicRecurrent newBias newBiasMomentum newActivations newMomentum 66 | 67 | createRandom = randomBasicRecurrent 68 | 69 | instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentUpdateLayer (BasicRecurrent i o) where 70 | type RecurrentShape (BasicRecurrent i o) = S ('D1 o) 71 | 72 | instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentLayer (BasicRecurrent i o) ('D1 i) ('D1 o) where 73 | 74 | type RecTape (BasicRecurrent i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i)) 75 | -- Do a matrix vector multiplication and return the result. 76 | runRecurrentForwards (BasicRecurrent wB _ wN _) (S1D lastOutput) (S1D thisInput) = 77 | let thisOutput = S1D $ wB + wN #> (thisInput # lastOutput) 78 | in ((S1D lastOutput, S1D thisInput), thisOutput, thisOutput) 79 | 80 | -- Run a backpropogation step for a full connected layer. 
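The `runUpdate` above is ordinary stochastic gradient descent with momentum and L2 weight decay, applied elementwise via `konst`. Written out for a single scalar weight (a sketch with made-up names, not library code), the rule is:

```haskell
-- One elementwise step of the update applied by runUpdate above:
--   m' = momentum * m - rate * g
--   w' = w + m' - rate * regulariser * w
stepScalar :: Double   -- learning rate
           -> Double   -- momentum coefficient
           -> Double   -- regulariser (weight decay)
           -> Double   -- current weight
           -> Double   -- accumulated momentum for this weight
           -> Double   -- gradient for this weight
           -> (Double, Double)  -- (new weight, new momentum)
stepScalar rate momentum regulariser w m g =
  let m' = momentum * m - rate * g
      w' = w + m' - rate * regulariser * w
  in (w', m')
```

The rate, momentum and regulariser arguments play the same roles as the fields of `LearningParameters` used above.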
81 | runRecurrentBackwards (BasicRecurrent _ _ wN _) (S1D lastOutput, S1D thisInput) (S1D dRec) (S1D dEdy) = 82 | let biasGradient = (dRec + dEdy) 83 | layerGrad = (dRec + dEdy) `outer` (thisInput # lastOutput) 84 | -- calcluate derivatives for next step 85 | (backGrad, recGrad) = split $ tr wN #> (dRec + dEdy) 86 | in (BasicRecurrent' biasGradient layerGrad, S1D recGrad, S1D backGrad) 87 | 88 | randomBasicRecurrent :: (MonadRandom m, KnownNat i, KnownNat o, KnownNat x, x ~ (i + o)) 89 | => m (BasicRecurrent i o) 90 | randomBasicRecurrent = do 91 | seed1 <- getRandom 92 | seed2 <- getRandom 93 | let wB = randomVector seed1 Uniform * 2 - 1 94 | wN = uniformSample seed2 (-1) 1 95 | bm = konst 0 96 | mm = konst 0 97 | return $ BasicRecurrent wB bm wN mm 98 | -------------------------------------------------------------------------------- /src/Grenade/Utils/OneHot.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# LANGUAGE TypeOperators #-} 5 | {-# LANGUAGE FlexibleContexts #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE RankNTypes #-} 8 | 9 | module Grenade.Utils.OneHot ( 10 | oneHot 11 | , hotMap 12 | , makeHot 13 | , unHot 14 | , sample 15 | ) where 16 | 17 | import qualified Control.Monad.Random as MR 18 | 19 | import Data.List ( group, sort ) 20 | 21 | import Data.Map ( Map ) 22 | import qualified Data.Map as M 23 | 24 | import Data.Proxy 25 | 26 | import Data.Vector ( Vector ) 27 | import qualified Data.Vector as V 28 | import qualified Data.Vector.Storable as VS 29 | 30 | import Numeric.LinearAlgebra ( maxIndex ) 31 | import Numeric.LinearAlgebra.Devel 32 | import Numeric.LinearAlgebra.Static 33 | import GHC.TypeLits 34 | 35 | import Grenade.Core.Shape 36 | 37 | -- | From an int which is hot, create a 1D Shape 38 | -- with one index hot (1) with the rest 0. 39 | -- Rerurns Nothing if the hot number is larger 40 | -- than the length of the vector. 41 | oneHot :: forall n. (KnownNat n) 42 | => Int -> Maybe (S ('D1 n)) 43 | oneHot hot = 44 | let len = fromIntegral $ natVal (Proxy :: Proxy n) 45 | in if hot < len 46 | then 47 | fmap S1D . create $ runSTVector $ do 48 | vec <- newVector 0 len 49 | writeVector vec hot 1 50 | return vec 51 | else Nothing 52 | 53 | -- | Create a one hot map from any enumerable. 54 | -- Returns a map, and the ordered list for the reverse transformation 55 | hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Either String (Map a Int, Vector a) 56 | hotMap n as = 57 | let len = fromIntegral $ natVal n 58 | uniq = [ c | (c:_) <- group $ sort as] 59 | hotl = length uniq 60 | in if hotl == len 61 | then 62 | Right (M.fromList $ zip uniq [0..], V.fromList uniq) 63 | else 64 | Left ("Couldn't create hotMap of size " ++ show len ++ " from vector with " ++ show hotl ++ " unique characters") 65 | 66 | -- | From a map and value, create a 1D Shape 67 | -- with one index hot (1) with the rest 0. 68 | -- Rerurns Nothing if the hot number is larger 69 | -- than the length of the vector or the map 70 | -- doesn't contain the value. 71 | makeHot :: forall a n. (Ord a, KnownNat n) 72 | => Map a Int -> a -> Maybe (S ('D1 n)) 73 | makeHot m x = do 74 | hot <- M.lookup x m 75 | let len = fromIntegral $ natVal (Proxy :: Proxy n) 76 | if hot < len 77 | then 78 | fmap S1D . create $ runSTVector $ do 79 | vec <- newVector 0 len 80 | writeVector vec hot 1 81 | return vec 82 | else Nothing 83 | 84 | unHot :: forall a n. 
KnownNat n 85 | => Vector a -> S ('D1 n) -> Maybe a 86 | unHot v (S1D xs) 87 | = (V.!?) v 88 | $ maxIndex (extract xs) 89 | 90 | sample :: forall a n m. (KnownNat n, MR.MonadRandom m) 91 | => Double -> Vector a -> S ('D1 n) -> m a 92 | sample temperature v (S1D xs) = do 93 | ix <- MR.fromList . zip [0..] . fmap (toRational . exp . (/ temperature) . log) . VS.toList . extract $ xs 94 | return $ v V.! ix 95 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # A warning or info to be displayed to the user on config load. 8 | # user-message: | 9 | 10 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 11 | # A snapshot resolver dictates the compiler version and the set of packages 12 | # to be used for project dependencies. For example: 13 | # 14 | # resolver: lts-3.5 15 | # resolver: nightly-2015-09-21 16 | # resolver: ghc-7.10.2 17 | # 18 | # The location of a snapshot can be provided as a file or url. Stack assumes 19 | # a snapshot provided as a file might change, whereas a url resource does not. 20 | # 21 | # resolver: ./custom-snapshot.yaml 22 | # resolver: https://example.com/snapshots/2018-01-01.yaml 23 | resolver: lts-20.18 24 | 25 | # User packages to be built. 26 | # Various formats can be used as shown in the example below. 27 | # 28 | # packages: 29 | # - some-directory 30 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | packages: 35 | - . 36 | - examples 37 | 38 | # Dependency packages to be pulled from upstream that are not in the resolver. 39 | # These entries can reference officially published versions as well as 40 | # forks / in-progress versions pinned to a git hash. For example: 41 | # 42 | # extra-deps: 43 | # - acme-missiles-0.3 44 | # - git: https://github.com/commercialhaskell/stack.git 45 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 46 | # 47 | extra-deps: 48 | - typelits-witnesses-0.3.0.3 49 | 50 | # Override default flag values for local packages and extra-deps 51 | # flags: {} 52 | 53 | # Extra package databases containing global packages 54 | # extra-package-dbs: [] 55 | 56 | # Control whether we use the GHC we find on the path 57 | # system-ghc: true 58 | # 59 | # Require a specific version of stack, using version ranges 60 | # require-stack-version: -any # Default 61 | # require-stack-version: ">=2.1" 62 | # 63 | # Override the architecture used by stack, especially useful on Windows 64 | # arch: i386 65 | # arch: x86_64 66 | # 67 | # Extra directories used by stack for building 68 | # extra-include-dirs: [/path/to/dir] 69 | # extra-lib-dirs: [/path/to/dir] 70 | # 71 | # Allow a newer minor version of GHC than the snapshot specifies 72 | # compiler-check: newer-minor 73 | 74 | nix: 75 | enable: true 76 | packages: 77 | - blas 78 | - lapack 79 | -------------------------------------------------------------------------------- /stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 
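Returning to `Grenade.Utils.OneHot` above: the helpers are intended as a round trip, from a value to its index, to a one-hot vector, and back via the arg-max. A minimal sketch, assuming the usual exports of `Grenade.Core.Shape` and a made-up three-character alphabet:

```haskell
{-# LANGUAGE DataKinds #-}

import Data.Proxy (Proxy (..))

import Grenade.Core.Shape (S, Shape (..))
import Grenade.Utils.OneHot (hotMap, makeHot, unHot)

-- Encode 'b' over the alphabet "abc" as a one-hot 'D1 3 vector, then decode
-- the arg-max back to the character.  Expected result: Right (Just 'b').
roundTrip :: Either String (Maybe Char)
roundTrip = do
  (toIndex, fromIndex) <- hotMap (Proxy :: Proxy 3) "abc"
  let encoded = makeHot toIndex 'b' :: Maybe (S ('D1 3))
  pure (unHot fromIndex =<< encoded)
```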
3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: 7 | - completed: 8 | hackage: typelits-witnesses-0.3.0.3@sha256:2d9df4ac6ff3077bfd2bf659e4b495e157723ac5b45c519762853f55df5c16db,2738 9 | pantry-tree: 10 | sha256: 6a42a462f98e94933b6e9721acd912c6c6b6a4743635efd15cc5871908c816a0 11 | size: 469 12 | original: 13 | hackage: typelits-witnesses-0.3.0.3 14 | snapshots: 15 | - completed: 16 | sha256: 9fa4bece7acfac1fc7930c5d6e24606004b09e80aa0e52e9f68b148201008db9 17 | size: 649606 18 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/18.yaml 19 | original: lts-20.18 20 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/Convolution.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE KindSignatures #-} 7 | {-# LANGUAGE ConstraintKinds #-} 8 | {-# LANGUAGE TypeOperators #-} 9 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 10 | module Test.Grenade.Layers.Convolution where 11 | 12 | import Unsafe.Coerce 13 | import Data.Constraint 14 | import Data.Proxy 15 | import Data.Singletons () 16 | import GHC.TypeLits 17 | import GHC.TypeLits.Witnesses 18 | 19 | #if MIN_VERSION_base(4,9,0) 20 | import Data.Kind (Type) 21 | #endif 22 | 23 | import Grenade.Core 24 | import Grenade.Layers.Convolution 25 | 26 | import Hedgehog 27 | import qualified Hedgehog.Gen as Gen 28 | 29 | import Test.Hedgehog.Hmatrix 30 | import Test.Hedgehog.TypeLits 31 | import Test.Hedgehog.Compat 32 | 33 | data OpaqueConvolution :: Type where 34 | OpaqueConvolution :: Convolution channels filters kernelRows kernelColumns strideRows strideColumns -> OpaqueConvolution 35 | 36 | instance Show OpaqueConvolution where 37 | show (OpaqueConvolution n) = show n 38 | 39 | genConvolution :: ( KnownNat channels 40 | , KnownNat filters 41 | , KnownNat kernelRows 42 | , KnownNat kernelColumns 43 | , KnownNat strideRows 44 | , KnownNat strideColumns 45 | , KnownNat kernelFlattened 46 | , kernelFlattened ~ (kernelRows * kernelColumns * channels) 47 | ) => Gen (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) 48 | genConvolution = Convolution <$> uniformSample <*> uniformSample 49 | 50 | genOpaqueOpaqueConvolution :: Gen OpaqueConvolution 51 | genOpaqueOpaqueConvolution = do 52 | channels <- genNat 53 | filters <- genNat 54 | kernel_h <- genNat 55 | kernel_w <- genNat 56 | stride_h <- genNat 57 | stride_w <- genNat 58 | case (channels, filters, kernel_h, kernel_w, stride_h, stride_w) of 59 | ( SomeNat (pch :: Proxy ch), SomeNat (_ :: Proxy fl), 60 | SomeNat (pkr :: Proxy kr), SomeNat (pkc :: Proxy kc), 61 | SomeNat (_ :: Proxy sr), SomeNat (_ :: Proxy sc)) -> 62 | let p1 = natDict pkr 63 | p2 = natDict pkc 64 | p3 = natDict pch 65 | in case p1 %* p2 %* p3 of 66 | Dict -> OpaqueConvolution <$> (genConvolution :: Gen (Convolution ch fl kr kc sr sc)) 67 | 68 | prop_conv_net_witness = property $ 69 | blindForAll genOpaqueOpaqueConvolution >>= \onet -> 70 | case onet of 71 | (OpaqueConvolution ((Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) -> success 72 | 73 | 74 | prop_conv_net = property $ 75 | blindForAll genOpaqueOpaqueConvolution >>= \onet -> 76 | case onet of 77 | 
(OpaqueConvolution (convLayer@(Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) -> 78 | let ok stride kernel = [extent | extent <- [(kernel + 1) .. 30 ], (extent - kernel) `mod` stride == 0] 79 | kr = fromIntegral $ natVal (Proxy :: Proxy kernelRows) 80 | kc = fromIntegral $ natVal (Proxy :: Proxy kernelCols) 81 | sr = fromIntegral $ natVal (Proxy :: Proxy strideRows) 82 | sc = fromIntegral $ natVal (Proxy :: Proxy strideCols) 83 | 84 | in forAll (Gen.element (ok sr kr)) >>= \er -> 85 | forAll (Gen.element (ok sc kc)) >>= \ec -> 86 | let rr = ((er - kr) `div` sr) + 1 87 | rc = ((ec - kc) `div` sc) + 1 88 | Just er' = someNatVal er 89 | Just ec' = someNatVal ec 90 | Just rr' = someNatVal rr 91 | Just rc' = someNatVal rc 92 | in case (er', ec', rr', rc') of 93 | ( SomeNat (pinr :: Proxy inRows), SomeNat (_ :: Proxy inCols), SomeNat (pour :: Proxy outRows), SomeNat (_ :: Proxy outCols)) -> 94 | case ( natDict pinr %* natDict (Proxy :: Proxy channels) 95 | , natDict pour %* natDict (Proxy :: Proxy filters) 96 | -- Fake it till you make it. 97 | , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows))) 98 | , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of 99 | (Dict, Dict, Dict, Dict) -> 100 | blindForAll (S3D <$> uniformSample) >>= \(input :: S ('D3 inRows inCols channels)) -> 101 | let (tape, output :: S ('D3 outRows outCols filters)) = runForwards convLayer input 102 | backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels)) 103 | = runBackwards convLayer tape output 104 | in backed `seq` success 105 | 106 | 107 | tests :: IO Bool 108 | tests = checkParallel $$(discover) 109 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/FullyConnected.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE BangPatterns #-} 3 | {-# LANGUAGE TemplateHaskell #-} 4 | {-# LANGUAGE DataKinds #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# LANGUAGE GADTs #-} 7 | {-# LANGUAGE ScopedTypeVariables #-} 8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9 | module Test.Grenade.Layers.FullyConnected where 10 | 11 | import Data.Proxy 12 | import Data.Singletons () 13 | 14 | import GHC.TypeLits 15 | 16 | #if MIN_VERSION_base(4,9,0) 17 | import Data.Kind (Type) 18 | #endif 19 | 20 | import Grenade.Core 21 | import Grenade.Layers.FullyConnected 22 | 23 | import Hedgehog 24 | 25 | import Test.Hedgehog.Compat 26 | import Test.Hedgehog.Hmatrix 27 | 28 | data OpaqueFullyConnected :: Type where 29 | OpaqueFullyConnected :: (KnownNat i, KnownNat o) => FullyConnected i o -> OpaqueFullyConnected 30 | 31 | instance Show OpaqueFullyConnected where 32 | show (OpaqueFullyConnected n) = show n 33 | 34 | genOpaqueFullyConnected :: Gen OpaqueFullyConnected 35 | genOpaqueFullyConnected = do 36 | input :: Integer <- choose 2 100 37 | output :: Integer <- choose 1 100 38 | let Just input' = someNatVal input 39 | let Just output' = someNatVal output 40 | case (input', output') of 41 | (SomeNat (Proxy :: Proxy i'), SomeNat (Proxy :: Proxy o')) -> do 42 | wB <- randomVector 43 | bM <- randomVector 44 | wN <- uniformSample 45 | kM <- uniformSample 46 | return . 
OpaqueFullyConnected $ (FullyConnected (FullyConnected' wB wN) (FullyConnected' bM kM) :: FullyConnected i' o') 47 | 48 | prop_fully_connected_forwards :: Property 49 | prop_fully_connected_forwards = property $ do 50 | OpaqueFullyConnected (fclayer :: FullyConnected i o) <- blindForAll genOpaqueFullyConnected 51 | input :: S ('D1 i) <- blindForAll (S1D <$> randomVector) 52 | let (tape, output :: S ('D1 o)) = runForwards fclayer input 53 | backed :: (Gradient (FullyConnected i o), S ('D1 i)) 54 | = runBackwards fclayer tape output 55 | backed `seq` success 56 | 57 | tests :: IO Bool 58 | tests = checkParallel $$(discover) 59 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/Internal/Convolution.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE ConstraintKinds #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | 8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9 | module Test.Grenade.Layers.Internal.Convolution where 10 | 11 | import Grenade.Layers.Internal.Convolution 12 | 13 | import Numeric.LinearAlgebra hiding (uniformSample, konst, (===)) 14 | 15 | import Hedgehog 16 | import qualified Hedgehog.Gen as Gen 17 | import qualified Hedgehog.Range as Range 18 | 19 | import qualified Test.Grenade.Layers.Internal.Reference as Reference 20 | import Test.Hedgehog.Compat 21 | 22 | prop_im2col_col2im_symmetrical_with_kernel_stride = 23 | let factors n = [x | x <- [1..n], n `mod` x == 0] 24 | in property $ do 25 | height <- forAll $ choose 2 100 26 | width <- forAll $ choose 2 100 27 | kernel_h <- forAll $ (height `div`) <$> Gen.element (factors height) 28 | kernel_w <- forAll $ (width `div`) <$> Gen.element (factors width) 29 | input <- forAll $ (height >< width) <$> Gen.list (Range.singleton $ height * width) (Gen.realFloat $ Range.linearFracFrom 0 (-100) 100) 30 | 31 | let stride_h = kernel_h 32 | let stride_w = kernel_w 33 | let out = col2im kernel_h kernel_w stride_h stride_w height width . 
im2col kernel_h kernel_w stride_h stride_w $ input 34 | input === out 35 | 36 | prop_im2col_col2im_behaves_as_reference = 37 | let ok extent kernel = [stride | stride <- [1..extent], (extent - kernel) `mod` stride == 0] 38 | in property $ do 39 | height <- forAll (choose 2 100) 40 | width <- forAll (choose 2 100) 41 | kernel_h <- forAll (choose 2 (height - 1)) 42 | kernel_w <- forAll (choose 2 (width - 1)) 43 | stride_h <- forAll (Gen.element (ok height kernel_h)) 44 | stride_w <- forAll (Gen.element (ok width kernel_w)) 45 | input <- forAll ((height >< width) <$> Gen.list (Range.singleton $ height * width) (Gen.realFloat $ Range.linearFracFrom 0 (-100) 100)) 46 | 47 | let outFast = im2col kernel_h kernel_w stride_h stride_w input 48 | let retFast = col2im kernel_h kernel_w stride_h stride_w height width outFast 49 | 50 | let outReference = Reference.im2col kernel_h kernel_w stride_h stride_w input 51 | let retReference = Reference.col2im kernel_h kernel_w stride_h stride_w height width outReference 52 | 53 | outFast === outReference 54 | retFast === retReference 55 | 56 | tests :: IO Bool 57 | tests = checkParallel $$(discover) 58 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/Internal/Pooling.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE ConstraintKinds #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | 8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9 | module Test.Grenade.Layers.Internal.Pooling where 10 | 11 | import Grenade.Layers.Internal.Pooling 12 | 13 | import Numeric.LinearAlgebra hiding (uniformSample, konst, (===)) 14 | 15 | import Hedgehog 16 | import qualified Hedgehog.Gen as Gen 17 | import qualified Hedgehog.Range as Range 18 | 19 | import qualified Test.Grenade.Layers.Internal.Reference as Reference 20 | import Test.Hedgehog.Compat 21 | 22 | prop_poolForwards_poolBackwards_behaves_as_reference = 23 | let ok extent kernel = [stride | stride <- [1..extent], (extent - kernel) `mod` stride == 0] 24 | output extent kernel stride = (extent - kernel) `div` stride + 1 25 | in property $ do 26 | height <- forAll $ choose 2 100 27 | width <- forAll $ choose 2 100 28 | kernel_h <- forAll $ choose 1 (height - 1) 29 | kernel_w <- forAll $ choose 1 (width - 1) 30 | stride_h <- forAll $ Gen.element (ok height kernel_h) 31 | stride_w <- forAll $ Gen.element (ok width kernel_w) 32 | input <- forAll $ (height >< width) <$> Gen.list (Range.singleton $ height * width) (Gen.realFloat $ Range.linearFracFrom 0 (-100) 100) 33 | 34 | let outFast = poolForward 1 height width kernel_h kernel_w stride_h stride_w input 35 | let retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input outFast 36 | 37 | let outReference = Reference.poolForward kernel_h kernel_w stride_h stride_w (output height kernel_h stride_h) (output width kernel_w stride_w) input 38 | let retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input outReference 39 | 40 | outFast === outReference 41 | retFast === retReference 42 | 43 | 44 | tests :: IO Bool 45 | tests = checkParallel $$(discover) 46 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/Internal/Reference.hs: -------------------------------------------------------------------------------- 1 | module 
Test.Grenade.Layers.Internal.Reference where 2 | 3 | import Numeric.LinearAlgebra 4 | 5 | im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 6 | im2col nrows ncols srows scols m = 7 | let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols 8 | in im2colFit starts nrows ncols m 9 | 10 | vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double 11 | vid2col nrows ncols srows scols inputrows inputcols ms = 12 | let starts = fittingStarts inputrows nrows srows inputcols ncols scols 13 | subs = fmap (im2colFit starts nrows ncols) ms 14 | in foldl1 (|||) subs 15 | 16 | im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double 17 | im2colFit starts nrows ncols m = 18 | let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts 19 | in fromRows imRows 20 | 21 | col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double] 22 | col2vid nrows ncols srows scols drows dcols m = 23 | let starts = fittingStart (cols m) (nrows * ncols) (nrows * ncols) 24 | r = rows m 25 | mats = fmap (\s -> subMatrix (0,s) (r, nrows * ncols) m) starts 26 | colSts = fittingStarts drows nrows srows dcols ncols scols 27 | in fmap (col2imfit colSts nrows ncols drows dcols) mats 28 | 29 | col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 30 | col2im krows kcols srows scols drows dcols m = 31 | let starts = fittingStarts drows krows srows dcols kcols scols 32 | in col2imfit starts krows kcols drows dcols m 33 | 34 | col2imfit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 35 | col2imfit starts krows kcols drows dcols m = 36 | let indicies = (\[a,b] -> (a,b)) <$> sequence [[0..(krows-1)], [0..(kcols-1)]] 37 | convs = fmap (zip indicies . toList) . 
toRows $ m 38 | pairs = zip convs starts 39 | accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs 40 | in accum (konst 0 (drows, dcols)) (+) accums 41 | 42 | poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 43 | poolForward nrows ncols srows scols outputRows outputCols m = 44 | let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols 45 | in poolForwardFit starts nrows ncols outputRows outputCols m 46 | 47 | poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double) 48 | poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms = 49 | let starts = fittingStarts inRows nrows srows inCols ncols scols 50 | in poolForwardFit starts nrows ncols outputRows outputCols <$> ms 51 | 52 | poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double 53 | poolForwardFit starts nrows ncols _ outputCols m = 54 | let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts 55 | in matrix outputCols els 56 | 57 | poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double 58 | poolBackward krows kcols srows scols inputMatrix gradientMatrix = 59 | let inRows = rows inputMatrix 60 | inCols = cols inputMatrix 61 | starts = fittingStarts inRows krows srows inCols kcols scols 62 | in poolBackwardFit starts krows kcols inputMatrix gradientMatrix 63 | 64 | poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double) 65 | poolBackwardList krows kcols srows scols inRows inCols inputMatrices = 66 | let starts = fittingStarts inRows krows srows inCols kcols scols 67 | in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices 68 | 69 | poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double 70 | poolBackwardFit starts krows kcols inputMatrix gradientMatrix = 71 | let inRows = rows inputMatrix 72 | inCols = cols inputMatrix 73 | inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts 74 | grads = toList $ flatten gradientMatrix 75 | grads' = zip3 starts grads inds 76 | accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads' 77 | in accum (konst 0 (inRows, inCols)) (+) accums 78 | 79 | -- | These functions are not even remotely safe, but it's only called from the statically typed 80 | -- commands, so we should be good ?!?!? 81 | -- Returns the starting sub matrix locations which fit inside the larger matrix for the 82 | -- convolution. Takes into account the stride and kernel size. 83 | fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)] 84 | fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh = 85 | let rs = fittingStart nrows kernelrows steprows 86 | cs = fittingStart ncols kernelcols stepcolsh 87 | ls = sequence [rs, cs] 88 | in fmap (\[a,b] -> (a,b)) ls 89 | 90 | -- | Returns the starting sub vector which fit inside the larger vector for the 91 | -- convolution. Takes into account the stride and kernel size. 
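-- As a worked example (numbers chosen purely for illustration):
--   fittingStart 5 3 1 == [0,1,2]   -- 3-wide windows at offsets 0, 1 and 2 fit a width of 5
--   fittingStart 6 3 2 == [0,2]     -- the offset-4 window would overrun, so it is dropped
-- Any window that would only partially fit is silently discarded rather than padded.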
92 | fittingStart :: Int -> Int -> Int -> [Int] 93 | fittingStart width kernel steps = 94 | let go left | left + kernel < width 95 | = left : go (left + steps) 96 | | left + kernel == width 97 | = [left] 98 | | otherwise 99 | = [] 100 | in go 0 101 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/Nonlinear.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE KindSignatures #-} 5 | {-# LANGUAGE GADTs #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE LambdaCase #-} 8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9 | module Test.Grenade.Layers.Nonlinear where 10 | 11 | import Data.Singletons 12 | 13 | import Grenade 14 | 15 | import Hedgehog 16 | 17 | import Test.Hedgehog.Compat 18 | import Test.Hedgehog.Hmatrix 19 | import Test.Hedgehog.TypeLits 20 | 21 | import Numeric.LinearAlgebra.Static ( norm_Inf ) 22 | 23 | prop_sigmoid_grad :: Property 24 | prop_sigmoid_grad = property $ 25 | forAllWith rss genShape >>= \case 26 | (SomeSing (r :: Sing s)) -> 27 | withSingI r $ 28 | blindForAll genOfShape >>= \(ds :: S s) -> 29 | let (tape, f :: S s) = runForwards Logit ds 30 | ((), ret :: S s) = runBackwards Logit tape (1 :: S s) 31 | (_, numer :: S s) = runForwards Logit (ds + 0.0001) 32 | numericalGradient = (numer - f) * 10000 33 | in assert ((case numericalGradient - ret of 34 | (S1D x) -> norm_Inf x < 0.0001 35 | (S2D x) -> norm_Inf x < 0.0001 36 | (S3D x) -> norm_Inf x < 0.0001) :: Bool) 37 | 38 | prop_tanh_grad :: Property 39 | prop_tanh_grad = property $ 40 | forAllWith rss genShape >>= \case 41 | (SomeSing (r :: Sing s)) -> 42 | withSingI r $ 43 | blindForAll genOfShape >>= \(ds :: S s) -> 44 | let (tape, f :: S s) = runForwards Tanh ds 45 | ((), ret :: S s) = runBackwards Tanh tape (1 :: S s) 46 | (_, numer :: S s) = runForwards Tanh (ds + 0.0001) 47 | numericalGradient = (numer - f) * 10000 48 | in assert ((case numericalGradient - ret of 49 | (S1D x) -> norm_Inf x < 0.001 50 | (S2D x) -> norm_Inf x < 0.001 51 | (S3D x) -> norm_Inf x < 0.001) :: Bool) 52 | 53 | tests :: IO Bool 54 | tests = checkParallel $$(discover) 55 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/PadCrop.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE CPP #-} 3 | {-# LANGUAGE TemplateHaskell #-} 4 | {-# LANGUAGE DataKinds #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# LANGUAGE GADTs #-} 7 | {-# LANGUAGE ScopedTypeVariables #-} 8 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9 | 10 | module Test.Grenade.Layers.PadCrop where 11 | 12 | import Grenade 13 | 14 | import Hedgehog 15 | 16 | import Numeric.LinearAlgebra.Static ( norm_Inf ) 17 | 18 | import Test.Hedgehog.Hmatrix 19 | 20 | prop_pad_crop :: Property 21 | prop_pad_crop = 22 | let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D3 7 9 5, 'D3 16 15 5, 'D3 7 9 5 ] 23 | net = Pad :~> Crop :~> NNil 24 | in property $ 25 | forAll genOfShape >>= \(d :: S ('D3 7 9 5)) -> 26 | let (tapes, res) = runForwards net d 27 | (_ , grad) = runBackwards net tapes d 28 | in do assert $ d ~~~ res 29 | assert $ grad ~~~ d 30 | 31 | prop_pad_crop_2d :: Property 32 | prop_pad_crop_2d = 33 | let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ] 34 | net = Pad :~> Crop :~> NNil 35 | in 
property $ 36 | forAll genOfShape >>= \(d :: S ('D2 7 9)) -> 37 | let (tapes, res) = runForwards net d 38 | (_ , grad) = runBackwards net tapes d 39 | in do assert $ d ~~~ res 40 | assert $ grad ~~~ d 41 | 42 | (~~~) :: S x -> S x -> Bool 43 | (S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001 44 | (S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001 45 | (S3D x) ~~~ (S3D y) = norm_Inf (x - y) < 0.00001 46 | 47 | 48 | tests :: IO Bool 49 | tests = checkParallel $$(discover) 50 | -------------------------------------------------------------------------------- /test/Test/Grenade/Layers/Pooling.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE KindSignatures #-} 5 | {-# LANGUAGE GADTs #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 8 | module Test.Grenade.Layers.Pooling where 9 | 10 | import Data.Proxy 11 | import Data.Singletons () 12 | 13 | #if MIN_VERSION_base(4,9,0) 14 | import Data.Kind (Type) 15 | #endif 16 | 17 | import GHC.TypeLits 18 | import Grenade.Layers.Pooling 19 | 20 | import Hedgehog 21 | 22 | import Test.Hedgehog.Compat 23 | 24 | data OpaquePooling :: Type where 25 | OpaquePooling :: (KnownNat kh, KnownNat kw, KnownNat sh, KnownNat sw) => Pooling kh kw sh sw -> OpaquePooling 26 | 27 | instance Show OpaquePooling where 28 | show (OpaquePooling n) = show n 29 | 30 | genOpaquePooling :: Gen OpaquePooling 31 | genOpaquePooling = do 32 | ~(Just kernelHeight) <- someNatVal <$> choose 2 15 33 | ~(Just kernelWidth ) <- someNatVal <$> choose 2 15 34 | ~(Just strideHeight) <- someNatVal <$> choose 2 15 35 | ~(Just strideWidth ) <- someNatVal <$> choose 2 15 36 | 37 | case (kernelHeight, kernelWidth, strideHeight, strideWidth) of 38 | (SomeNat (_ :: Proxy kh), SomeNat (_ :: Proxy kw), SomeNat (_ :: Proxy sh), SomeNat (_ :: Proxy sw)) -> 39 | return $ OpaquePooling (Pooling :: Pooling kh kw sh sw) 40 | 41 | prop_pool_layer_witness = 42 | property $ do 43 | onet <- forAll genOpaquePooling 44 | case onet of 45 | (OpaquePooling (Pooling :: Pooling kernelRows kernelCols strideRows strideCols)) -> 46 | assert True 47 | 48 | tests :: IO Bool 49 | tests = checkParallel $$(discover) 50 | -------------------------------------------------------------------------------- /test/Test/Grenade/Recurrent/Layers/LSTM.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TemplateHaskell #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE ConstraintKinds #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | 10 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 11 | module Test.Grenade.Recurrent.Layers.LSTM where 12 | 13 | import Hedgehog 14 | import Hedgehog.Internal.Source 15 | import Hedgehog.Internal.Show 16 | import Hedgehog.Internal.Property ( failWith, Diff (..) ) 17 | 18 | import Data.Foldable ( toList ) 19 | import Data.Singletons.TypeLits 20 | 21 | import Grenade 22 | import Grenade.Recurrent 23 | 24 | import qualified Numeric.LinearAlgebra as H 25 | import qualified Numeric.LinearAlgebra.Static as S 26 | 27 | 28 | import qualified Test.Grenade.Recurrent.Layers.LSTM.Reference as Reference 29 | import Test.Hedgehog.Hmatrix 30 | 31 | genLSTM :: forall i o. 
(KnownNat i, KnownNat o) => Gen (LSTM i o) 32 | genLSTM = do 33 | let w = uniformSample 34 | u = uniformSample 35 | v = randomVector 36 | 37 | w0 = S.konst 0 38 | u0 = S.konst 0 39 | v0 = S.konst 0 40 | 41 | LSTM <$> (LSTMWeights <$> w <*> u <*> v <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v) 42 | <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0) 43 | 44 | prop_lstm_reference_forwards = 45 | property $ do 46 | input :: S.R 3 <- forAll randomVector 47 | cell :: S.R 2 <- forAll randomVector 48 | net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM 49 | 50 | let actual = runRecurrentForwards net (S1D cell) (S1D input) 51 | case actual of 52 | (_, (S1D cellOut) :: S ('D1 2), (S1D output) :: S ('D1 2)) -> 53 | let cellOut' = Reference.Vector . H.toList . S.extract $ cellOut 54 | output' = Reference.Vector . H.toList . S.extract $ output 55 | refNet = Reference.lstmToReference lstmWeights 56 | refCell = Reference.Vector . H.toList . S.extract $ cell 57 | refInput = Reference.Vector . H.toList . S.extract $ input 58 | (refCO, refO) = Reference.runLSTM refNet refCell refInput 59 | in do toList refCO ~~~ toList cellOut' 60 | toList refO ~~~ toList output' 61 | 62 | 63 | prop_lstm_reference_backwards = 64 | property $ do 65 | input :: S.R 3 <- forAll randomVector 66 | cell :: S.R 2 <- forAll randomVector 67 | net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM 68 | let (tape, _ :: S ('D1 2), _ :: S ('D1 2)) 69 | = runRecurrentForwards net (S1D cell) (S1D input) 70 | actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) 71 | case actualBacks of 72 | (actualGradients, _, _ :: S ('D1 3)) -> 73 | let refNet = Reference.lstmToReference lstmWeights 74 | refCell = Reference.Vector . H.toList . S.extract $ cell 75 | refInput = Reference.Vector . H.toList . S.extract $ input 76 | refGradients = Reference.runLSTMback refCell refInput refNet 77 | in toList refGradients ~~~ toList (Reference.lstmToReference actualGradients) 78 | 79 | prop_lstm_reference_backwards_input = 80 | property $ do 81 | input :: S.R 3 <- forAll randomVector 82 | cell :: S.R 2 <- forAll randomVector 83 | net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM 84 | let (tape, _ :: S ('D1 2), _ :: S ('D1 2)) 85 | = runRecurrentForwards net (S1D cell) (S1D input) 86 | actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) 87 | case actualBacks of 88 | (_, _, S1D actualGradients :: S ('D1 3)) -> 89 | let refNet = Reference.lstmToReference lstmWeights 90 | refCell = Reference.Vector . H.toList . S.extract $ cell 91 | refInput = Reference.Vector . H.toList . S.extract $ input 92 | refGradients = Reference.runLSTMbackOnInput refCell refNet refInput 93 | in toList refGradients ~~~ H.toList (S.extract actualGradients) 94 | 95 | prop_lstm_reference_backwards_cell = 96 | property $ do 97 | input :: S.R 3 <- forAll randomVector 98 | cell :: S.R 2 <- forAll randomVector 99 | net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM 100 | let (tape, _ :: S ('D1 2), _ :: S ('D1 2)) 101 | = runRecurrentForwards net (S1D cell) (S1D input) 102 | actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) 103 | case actualBacks of 104 | (_, S1D actualGradients, _ :: S ('D1 3)) -> 105 | let refNet = Reference.lstmToReference lstmWeights 106 | refCell = Reference.Vector . H.toList . S.extract $ cell 107 | refInput = Reference.Vector . H.toList . 
S.extract $ input 108 | refGradients = Reference.runLSTMbackOnCell refInput refNet refCell 109 | in toList refGradients ~~~ H.toList (S.extract actualGradients) 110 | 111 | (~~~) :: (Monad m, Eq a, Ord a, Num a, Fractional a, Show a, HasCallStack) => [a] -> [a] -> PropertyT m () 112 | (~~~) x y = 113 | if all ((< 1e-8) . abs) (zipWith (-) x y) then 114 | success 115 | else 116 | case valueDiff <$> mkValue x <*> mkValue y of 117 | Nothing -> 118 | withFrozenCallStack $ 119 | failWith Nothing $ unlines [ 120 | "━━━ Not Similar ━━━" 121 | , showPretty x 122 | , showPretty y 123 | ] 124 | Just differ -> 125 | withFrozenCallStack $ 126 | failWith (Just $ Diff "Failed (" "- lhs" "~/~" "+ rhs" ")" differ) "" 127 | infix 4 ~~~ 128 | 129 | tests :: IO Bool 130 | tests = checkParallel $$(discover) 131 | -------------------------------------------------------------------------------- /test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE ConstraintKinds #-} 5 | {-# LANGUAGE TypeOperators #-} 6 | {-# LANGUAGE DeriveFunctor #-} 7 | {-# LANGUAGE DeriveFoldable #-} 8 | {-# LANGUAGE DeriveTraversable #-} 9 | {-# LANGUAGE FlexibleContexts #-} 10 | {-# LANGUAGE RankNTypes #-} 11 | 12 | {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 13 | module Test.Grenade.Recurrent.Layers.LSTM.Reference where 14 | 15 | import Data.Reflection 16 | import Numeric.AD.Mode.Reverse 17 | import Numeric.AD.Internal.Reverse (Tape) 18 | 19 | import GHC.TypeLits (KnownNat) 20 | 21 | import Grenade.Recurrent.Layers.LSTM (LSTMWeights (..)) 22 | import qualified Grenade.Recurrent.Layers.LSTM as LSTM 23 | 24 | import qualified Numeric.LinearAlgebra.Static as S 25 | import qualified Numeric.LinearAlgebra as H 26 | 27 | -- 28 | -- This module contains a set of list only versions of 29 | -- an LSTM layer which can be used with the AD library. 30 | -- 31 | -- Using this, we can check to make sure that our fast 32 | -- back propagation implementation is correct. 33 | -- 34 | 35 | -- | List only matrix deriving functor 36 | data Matrix a = Matrix { 37 | matrixWeights :: [[a]] 38 | } deriving (Functor, Foldable, Traversable, Eq, Show) 39 | 40 | -- | List only vector deriving functor 41 | data Vector a = Vector { 42 | vectorWeights :: [a] 43 | } deriving (Functor, Foldable, Traversable, Eq, Show) 44 | 45 | -- | List only LSTM weights 46 | data RefLSTM a = RefLSTM 47 | { refLstmWf :: Matrix a -- Weight Forget (W_f) 48 | , refLstmUf :: Matrix a -- Cell State Forget (U_f) 49 | , refLstmBf :: Vector a -- Bias Forget (b_f) 50 | , refLstmWi :: Matrix a -- Weight Input (W_i) 51 | , refLstmUi :: Matrix a -- Cell State Input (U_i) 52 | , refLstmBi :: Vector a -- Bias Input (b_i) 53 | , refLstmWo :: Matrix a -- Weight Output (W_o) 54 | , refLstmUo :: Matrix a -- Cell State Output (U_o) 55 | , refLstmBo :: Vector a -- Bias Output (b_o) 56 | , refLstmWc :: Matrix a -- Weight Cell (W_c) 57 | , refLstmBc :: Vector a -- Bias Cell (b_c) 58 | } deriving (Functor, Foldable, Traversable, Eq, Show) 59 | 60 | lstmToReference :: (KnownNat a, KnownNat b) => LSTM.LSTMWeights a b -> RefLSTM Double 61 | lstmToReference lw = 62 | RefLSTM 63 | { refLstmWf = Matrix . H.toLists . S.extract $ lstmWf lw -- Weight Forget (W_f) 64 | , refLstmUf = Matrix . H.toLists . S.extract $ lstmUf lw -- Cell State Forget (U_f) 65 | , refLstmBf = Vector . H.toList . 
S.extract $ lstmBf lw -- Bias Forget (b_f) 66 | , refLstmWi = Matrix . H.toLists . S.extract $ lstmWi lw -- Weight Input (W_i) 67 | , refLstmUi = Matrix . H.toLists . S.extract $ lstmUi lw -- Cell State Input (U_i) 68 | , refLstmBi = Vector . H.toList . S.extract $ lstmBi lw -- Bias Input (b_i) 69 | , refLstmWo = Matrix . H.toLists . S.extract $ lstmWo lw -- Weight Output (W_o) 70 | , refLstmUo = Matrix . H.toLists . S.extract $ lstmUo lw -- Cell State Output (U_o) 71 | , refLstmBo = Vector . H.toList . S.extract $ lstmBo lw -- Bias Output (b_o) 72 | , refLstmWc = Matrix . H.toLists . S.extract $ lstmWc lw -- Weight Cell (W_c) 73 | , refLstmBc = Vector . H.toList . S.extract $ lstmBc lw -- Bias Cell (b_c) 74 | } 75 | 76 | runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a) 77 | runLSTM rl cell input = 78 | let -- Forget state vector 79 | f_t = sigmoid $ refLstmBf rl #+ refLstmWf rl #> input #+ refLstmUf rl #> cell 80 | -- Input state vector 81 | i_t = sigmoid $ refLstmBi rl #+ refLstmWi rl #> input #+ refLstmUi rl #> cell 82 | -- Output state vector 83 | o_t = sigmoid $ refLstmBo rl #+ refLstmWo rl #> input #+ refLstmUo rl #> cell 84 | -- Cell input state vector 85 | c_x = fmap tanh $ refLstmBc rl #+ refLstmWc rl #> input 86 | -- Cell state 87 | c_t = f_t #* cell #+ i_t #* c_x 88 | -- Output (it's sometimes recommended to use tanh c_t) 89 | h_t = o_t #* c_t 90 | in (c_t, h_t) 91 | 92 | runLSTMback :: forall a. Floating a => Vector a -> Vector a -> RefLSTM a -> RefLSTM a 93 | runLSTMback cell input = 94 | grad f 95 | where 96 | f :: forall s. Reifies s Tape => RefLSTM (Reverse s a) -> Reverse s a 97 | f net = 98 | let cell' = fmap auto cell 99 | input' = fmap auto input 100 | (cells, forwarded) = runLSTM net cell' input' 101 | in sum forwarded + sum cells 102 | 103 | runLSTMbackOnInput :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a 104 | runLSTMbackOnInput cell net = 105 | grad f 106 | where 107 | f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a 108 | f input = 109 | let cell' = fmap auto cell 110 | net' = fmap auto net 111 | (cells, forwarded) = runLSTM net' cell' input 112 | in sum forwarded + sum cells 113 | 114 | runLSTMbackOnCell :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a 115 | runLSTMbackOnCell input net = 116 | grad f 117 | where 118 | f :: forall s. 
Reifies s Tape => Vector (Reverse s a) -> Reverse s a 119 | f cell = 120 | let input' = fmap auto input 121 | net' = fmap auto net 122 | (cells, forwarded) = runLSTM net' cell input' 123 | in sum forwarded + sum cells 124 | 125 | -- | Helper to multiply a matrix by a vector 126 | matMult :: Num a => Matrix a -> Vector a -> Vector a 127 | matMult (Matrix m) (Vector v) = Vector result 128 | where 129 | lrs = map length m 130 | l = length v 131 | result = if all (== l) lrs 132 | then map (\r -> sum $ zipWith (*) r v) m 133 | else error $ "Matrix has rows of length " ++ show lrs ++ 134 | " but vector is of length " ++ show l 135 | 136 | (#>) :: Num a => Matrix a -> Vector a -> Vector a 137 | (#>) = matMult 138 | infixr 8 #> 139 | 140 | (#+) :: Num a => Vector a -> Vector a -> Vector a 141 | (#+) (Vector as) (Vector bs) = Vector $ zipWith (+) as bs 142 | infixl 6 #+ 143 | 144 | (#-) :: Num a => Vector a -> Vector a -> Vector a 145 | (#-) (Vector as) (Vector bs) = Vector $ zipWith (-) as bs 146 | infixl 6 #- 147 | 148 | (#*) :: Num a => Vector a -> Vector a -> Vector a 149 | (#*) (Vector as) (Vector bs) = Vector $ zipWith (*) as bs 150 | infixl 7 #* 151 | 152 | sigmoid :: (Functor f, Floating a) => f a -> f a 153 | sigmoid xs = (\x -> 1 / (1 + exp (-x))) <$> xs 154 | -------------------------------------------------------------------------------- /test/Test/Hedgehog/Compat.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | module Test.Hedgehog.Compat ( 3 | (...) 4 | , choose 5 | , blindForAll 6 | )where 7 | 8 | import Hedgehog (Gen) 9 | import qualified Hedgehog.Gen as Gen 10 | import qualified Hedgehog.Range as Range 11 | import Hedgehog.Internal.Property 12 | 13 | (...) :: (c -> d) -> (a -> b -> c) -> a -> b -> d 14 | (...) = (.) . (.) 15 | {-# INLINE (...) #-} 16 | 17 | choose :: Integral a => a -> a -> Gen a 18 | choose = Gen.integral ... Range.constant 19 | 20 | blindForAll :: Monad m => Gen a -> PropertyT m a 21 | blindForAll = forAllWith (const "blind") 22 | -------------------------------------------------------------------------------- /test/Test/Hedgehog/Hmatrix.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE KindSignatures #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | 7 | module Test.Hedgehog.Hmatrix where 8 | 9 | import Grenade 10 | import Data.Singletons 11 | import Data.Singletons.TypeLits 12 | 13 | import Hedgehog (Gen) 14 | import qualified Hedgehog.Gen as Gen 15 | import qualified Hedgehog.Range as Range 16 | 17 | import qualified Numeric.LinearAlgebra.Static as HStatic 18 | 19 | randomVector :: forall n. ( KnownNat n ) => Gen (HStatic.R n) 20 | randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> Gen.int Range.linearBounded 21 | 22 | uniformSample :: forall m n. ( KnownNat m, KnownNat n ) => Gen (HStatic.L m n) 23 | uniformSample = (\s -> HStatic.uniformSample s (-1) 1 ) <$> Gen.int Range.linearBounded 24 | 25 | -- | Generate random data of the desired shape 26 | genOfShape :: forall x. 
( SingI x ) => Gen (S x) 27 | genOfShape = 28 | case (sing :: Sing x) of 29 | D1Sing l -> 30 | withKnownNat l $ 31 | S1D <$> randomVector 32 | D2Sing r c -> 33 | withKnownNat r $ withKnownNat c $ 34 | S2D <$> uniformSample 35 | D3Sing r c d -> 36 | withKnownNat r $ withKnownNat c $ withKnownNat d $ 37 | S3D <$> uniformSample 38 | 39 | nice :: S shape -> String 40 | nice (S1D x) = show . HStatic.extract $ x 41 | nice (S2D x) = show . HStatic.extract $ x 42 | nice (S3D x) = show . HStatic.extract $ x 43 | -------------------------------------------------------------------------------- /test/Test/Hedgehog/TypeLits.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE RankNTypes #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE PolyKinds #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE GADTs #-} 8 | module Test.Hedgehog.TypeLits where 9 | 10 | import Data.Constraint (Dict (..)) 11 | import Data.Singletons (Proxy (..), Sing (..), SomeSing (..), sing) 12 | #if MIN_VERSION_singletons(2,6,0) 13 | import Data.Singletons.TypeLits (SNat (..)) 14 | #endif 15 | 16 | import Hedgehog (Gen) 17 | import qualified Hedgehog.Gen as Gen 18 | 19 | import Grenade 20 | 21 | import GHC.TypeLits (SomeNat (..), natVal, someNatVal) 22 | import GHC.TypeLits.Witnesses ((%*), natDict) 23 | import Test.Hedgehog.Compat (choose) 24 | 25 | genNat :: Gen SomeNat 26 | genNat = do 27 | ~(Just n) <- someNatVal <$> choose 1 10 28 | return n 29 | 30 | genShape :: Gen (SomeSing Shape) 31 | genShape 32 | = Gen.choice [ 33 | genD1 34 | , genD2 35 | , genD3 36 | ] 37 | 38 | genD1 :: Gen (SomeSing Shape) 39 | genD1 = do 40 | n <- genNat 41 | return $ case n of 42 | SomeNat (_ :: Proxy x) -> SomeSing (sing :: Sing ('D1 x)) 43 | 44 | genD2 :: Gen (SomeSing Shape) 45 | genD2 = do 46 | n <- genNat 47 | m <- genNat 48 | return $ case (n, m) of 49 | (SomeNat (_ :: Proxy x), SomeNat (_ :: Proxy y)) -> SomeSing (sing :: Sing ('D2 x y)) 50 | 51 | genD3 :: Gen (SomeSing Shape) 52 | genD3 = do 53 | n <- genNat 54 | m <- genNat 55 | o <- genNat 56 | return $ case (n, m, o) of 57 | (SomeNat (px :: Proxy x), SomeNat (_ :: Proxy y), SomeNat (pz :: Proxy z)) -> 58 | case natDict px %* natDict pz of 59 | Dict -> SomeSing (sing :: Sing ('D3 x y z)) 60 | 61 | rss :: SomeSing Shape -> String 62 | rss (SomeSing (r :: Sing s)) = case r of 63 | (D1Sing a@SNat) -> "D1 " ++ show (natVal a) 64 | (D2Sing a@SNat b@SNat) -> "D2 " ++ show (natVal a) ++ " " ++ show (natVal b) 65 | (D3Sing a@SNat b@SNat c@SNat) -> "D3 " ++ show (natVal a) ++ " " ++ show (natVal b) ++ " " ++ show (natVal c) 66 | -------------------------------------------------------------------------------- /test/test.hs: -------------------------------------------------------------------------------- 1 | import Control.Monad 2 | 3 | import qualified Test.Grenade.Network 4 | 5 | import qualified Test.Grenade.Layers.Pooling 6 | import qualified Test.Grenade.Layers.Convolution 7 | import qualified Test.Grenade.Layers.FullyConnected 8 | import qualified Test.Grenade.Layers.Nonlinear 9 | import qualified Test.Grenade.Layers.PadCrop 10 | 11 | import qualified Test.Grenade.Layers.Internal.Convolution 12 | import qualified Test.Grenade.Layers.Internal.Pooling 13 | 14 | import qualified Test.Grenade.Recurrent.Layers.LSTM 15 | 16 | import System.Exit 17 | import System.IO 18 | 19 | main :: IO () 20 | main = 21 | disorderMain [ 22 | Test.Grenade.Network.tests 23 | 24 | , 
Test.Grenade.Layers.Pooling.tests 25 | , Test.Grenade.Layers.Convolution.tests 26 | , Test.Grenade.Layers.FullyConnected.tests 27 | , Test.Grenade.Layers.Nonlinear.tests 28 | , Test.Grenade.Layers.PadCrop.tests 29 | 30 | , Test.Grenade.Layers.Internal.Convolution.tests 31 | , Test.Grenade.Layers.Internal.Pooling.tests 32 | 33 | , Test.Grenade.Recurrent.Layers.LSTM.tests 34 | 35 | ] 36 | 37 | disorderMain :: [IO Bool] -> IO () 38 | disorderMain tests = do 39 | lineBuffer 40 | rs <- sequence tests 41 | unless (and rs) exitFailure 42 | 43 | 44 | lineBuffer :: IO () 45 | lineBuffer = do 46 | hSetBuffering stdout LineBuffering 47 | hSetBuffering stderr LineBuffering 48 | --------------------------------------------------------------------------------
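As a usage sketch for the one-hot utilities in `src/Grenade/Utils/OneHot.hs`: the snippet below round-trips a character through `hotMap`, `makeHot` and `unHot`. It is illustrative only; the `Main` module, the alphabet `"abcde"` and the size `5` are assumptions made for the example, not part of the repository.

```haskell
{-# LANGUAGE DataKinds #-}
module Main where

import Data.Proxy (Proxy (..))

import Grenade.Core.Shape (S, Shape (..))
import Grenade.Utils.OneHot (hotMap, makeHot, unHot)

-- Build a hot map over five distinct characters, encode one of them as a
-- one-hot vector of length 5, then decode it again.
main :: IO ()
main =
  case hotMap (Proxy :: Proxy 5) "abcde" of
    Left err -> putStrLn err
    Right (m, v) ->
      let encoded = makeHot m 'c' :: Maybe (S ('D1 5))
      in print (unHot v =<< encoded) -- expected to print: Just 'c'
```

Compiled against the grenade library, this is expected to print `Just 'c'`; an out-of-range index or a character missing from the map would yield `Nothing` instead.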